| | |
| | import copy |
| | import logging |
| | from abc import ABCMeta |
| | from collections import defaultdict |
| | from logging import FileHandler |
| | from typing import Iterable, List, Optional, Union |
| |
|
| | import torch.nn as nn |
| |
|
| | from mmengine.dist import master_only |
| | from mmengine.logging import MMLogger, print_log |
| | from .weight_init import PretrainedInit, initialize, update_init_info |
| | from .wrappers.utils import is_model_wrapper |
| |
|
| |
|
class BaseModule(nn.Module, metaclass=ABCMeta):
    """Base module for all modules in openmmlab. ``BaseModule`` is a wrapper of
    ``torch.nn.Module`` with additional functionality of parameter
    initialization. Compared with ``torch.nn.Module``, ``BaseModule`` mainly
    adds three attributes.

    - ``init_cfg``: the config to control the initialization.
    - ``init_weights``: The function of parameter initialization and recording
      initialization information.
    - ``_params_init_info``: Used to track the parameter initialization
      information. This attribute only exists during executing the
      ``init_weights``.

    Note:
        :obj:`PretrainedInit` has a higher priority than any other
        initializer. The loaded pretrained weights will overwrite
        the previous initialized weights.

    Args:
        init_cfg (dict or List[dict], optional): Initialization config dict.
    """

    def __init__(self, init_cfg: Union[dict, List[dict], None] = None):
        """Initialize BaseModule, inherited from `torch.nn.Module`.

        Args:
            init_cfg (dict or List[dict], optional): Initialization config
                dict(s). It is deep-copied, so later changes to the caller's
                object do not affect this module. Defaults to None.
        """
        super().__init__()
        # Set to True by ``init_weights``; any later call then only logs a
        # warning instead of initializing again.
        self._is_init = False
        self.init_cfg = copy.deepcopy(init_cfg)

    @property
    def is_init(self):
        """bool: Whether ``init_weights`` has already run for this module."""
        return self._is_init

    @is_init.setter
    def is_init(self, value):
        self._is_init = value

    def init_weights(self):
        """Initialize the weights.

        If this module has not been initialized yet:

        1. every non-``Pretrained`` entry of ``self.init_cfg`` is applied via
           ``initialize``;
        2. ``init_weights`` is called on each direct child that defines it
           and whose ``is_init`` is not True (a model wrapper without
           ``init_weights`` is unwrapped through its ``.module`` first);
        3. ``Pretrained`` entries of ``self.init_cfg`` are applied last, so
           loaded weights overwrite any earlier initialization.

        Otherwise a warning is logged and nothing is re-initialized.

        The module on which this is first called (the one that does not yet
        carry a shared ``_params_init_info``) acts as the top-level module:
        it creates the per-parameter record shared by all submodules, dumps
        it via :meth:`_dump_init_info`, and finally removes the temporary
        ``_params_init_info`` attribute from every submodule -- also when
        initialization raises, so a retry starts from a clean state.
        """
        is_top_level_module = not hasattr(self, '_params_init_info')
        if is_top_level_module:
            # Maps each nn.Parameter to a dict with:
            # - 'init_info': text describing how the parameter was set;
            # - 'tmp_mean_value': its mean before initialization (presumably
            #   compared by ``update_init_info`` to detect changed
            #   parameters -- the helper is defined elsewhere).
            self._params_init_info = defaultdict(dict)
            default_info = (f'The value is the same before and '
                            f'after calling `init_weights` '
                            f'of {self.__class__.__name__} ')
            for param in self.parameters():
                record = self._params_init_info[param]
                record['init_info'] = default_info
                record['tmp_mean_value'] = param.data.mean().cpu()

            # Every submodule shares the same record, so updates made at any
            # level of the model are visible to the top-level module.
            for sub_module in self.modules():
                sub_module._params_init_info = self._params_init_info

        try:
            module_name = self.__class__.__name__
            if not self._is_init:
                pretrained_cfgs = []
                if self.init_cfg:
                    print_log(
                        f'initialize {module_name} with init_cfg '
                        f'{self.init_cfg}',
                        logger='current',
                        level=logging.DEBUG)

                    init_cfgs = self.init_cfg
                    if isinstance(init_cfgs, dict):
                        init_cfgs = [init_cfgs]

                    # Pretrained configs are applied after the children are
                    # initialized so that they overwrite everything else.
                    other_cfgs = []
                    for init_cfg in init_cfgs:
                        assert isinstance(init_cfg, dict)
                        if (init_cfg['type'] == 'Pretrained'
                                or init_cfg['type'] is PretrainedInit):
                            pretrained_cfgs.append(init_cfg)
                        else:
                            other_cfgs.append(init_cfg)

                    initialize(self, other_cfgs)

                for m in self.children():
                    if is_model_wrapper(m) and not hasattr(m, 'init_weights'):
                        m = m.module
                    if hasattr(m, 'init_weights') and not getattr(
                            m, 'is_init', False):
                        m.init_weights()
                        # The child may override `init_weights`; record that
                        # its parameters may have been set by user code.
                        update_init_info(
                            m,
                            init_info=f'Initialized by '
                            f'user-defined `init_weights`'
                            f' in {m.__class__.__name__} ')

                if pretrained_cfgs:
                    initialize(self, pretrained_cfgs)
                self._is_init = True
            else:
                print_log(
                    f'init_weights of {self.__class__.__name__} has '
                    f'been called more than once.',
                    logger='current',
                    level=logging.WARNING)

            if is_top_level_module:
                self._dump_init_info()
        finally:
            if is_top_level_module:
                # The record only lives for the duration of this call.
                for sub_module in self.modules():
                    if hasattr(sub_module, '_params_init_info'):
                        del sub_module._params_init_info

    @master_only
    def _dump_init_info(self):
        """Dump the parameter initialization information.

        For every named parameter, its name, shape and recorded
        ``init_info`` are written directly to the stream of each
        ``FileHandler`` attached to the current :class:`MMLogger`. If no such
        handler has an open stream, the same entries are emitted through
        ``logger.info`` instead.
        """
        logger = MMLogger.get_current_instance()
        with_file_handler = False
        for handler in logger.handlers:
            if not isinstance(handler, FileHandler):
                continue
            stream = handler.stream
            if stream is None:
                # A FileHandler created with ``delay=True`` has not opened
                # its file yet; writing to ``None`` would crash.
                continue
            stream.write('Name of parameter - Initialization information\n')
            for name, param in self.named_parameters():
                stream.write(
                    f'\n{name} - {param.shape}: '
                    f"\n{self._params_init_info[param]['init_info']} \n")
            stream.flush()
            with_file_handler = True
        if not with_file_handler:
            for name, param in self.named_parameters():
                logger.info(
                    f'\n{name} - {param.shape}: '
                    f"\n{self._params_init_info[param]['init_info']} \n ")

    def __repr__(self):
        """Extend ``nn.Module.__repr__`` with ``init_cfg`` when it is set."""
        s = super().__repr__()
        if self.init_cfg:
            s += f'\ninit_cfg={self.init_cfg}'
        return s
| |
|
| |
|
class Sequential(BaseModule, nn.Sequential):
    """Sequential module in openmmlab.

    Ensures that all modules in ``Sequential`` have a different initialization
    strategy than the outer model.

    Args:
        *args: Positional arguments forwarded to :class:`torch.nn.Sequential`.
        init_cfg (dict or List[dict], optional): Initialization config dict.
    """

    def __init__(self, *args, init_cfg: Union[dict, List[dict], None] = None):
        # Keep this order: BaseModule.__init__ reaches the nn.Module
        # initialisation through super(); running it after the container is
        # populated would likely discard the registered submodules.
        BaseModule.__init__(self, init_cfg)
        nn.Sequential.__init__(self, *args)
| |
|
| |
|
class ModuleList(BaseModule, nn.ModuleList):
    """ModuleList in openmmlab.

    Ensures that all modules in ``ModuleList`` have a different initialization
    strategy than the outer model.

    Args:
        modules (iterable, optional): An iterable of modules to add.
        init_cfg (dict or List[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 modules: Optional[Iterable] = None,
                 init_cfg: Union[dict, List[dict], None] = None):
        # Keep this order: BaseModule.__init__ reaches the nn.Module
        # initialisation through super(); running it after the container is
        # populated would likely discard the registered submodules.
        BaseModule.__init__(self, init_cfg)
        nn.ModuleList.__init__(self, modules)
| |
|
| |
|
class ModuleDict(BaseModule, nn.ModuleDict):
    """ModuleDict in openmmlab.

    Ensures that all modules in ``ModuleDict`` have a different initialization
    strategy than the outer model.

    Args:
        modules (dict, optional): A mapping (dictionary) of (string: module)
            or an iterable of key-value pairs of type (string, module).
        init_cfg (dict or List[dict], optional): Initialization config dict.
    """

    def __init__(self,
                 modules: Union[dict, Iterable, None] = None,
                 init_cfg: Union[dict, List[dict], None] = None):
        # Keep this order: BaseModule.__init__ reaches the nn.Module
        # initialisation through super(); running it after the container is
        # populated would likely discard the registered submodules.
        BaseModule.__init__(self, init_cfg)
        nn.ModuleDict.__init__(self, modules)
| |
|