# Copyright (c) OpenMMLab. All rights reserved.
import copy
import inspect
from typing import List, Union

import torch
import torch.nn as nn

from mmengine.config import Config, ConfigDict
from mmengine.device import is_npu_available, is_npu_support_full_precision
from mmengine.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS
from .optimizer_wrapper import OptimWrapper


def register_torch_optimizers() -> List[str]:
    """Register optimizers in ``torch.optim`` to the ``OPTIMIZERS`` registry.

    Returns:
        List[str]: A list of registered optimizer names.
    """
    torch_optimizers = []
    for module_name in dir(torch.optim):
        if module_name.startswith('__'):
            continue
        _optim = getattr(torch.optim, module_name)
        if inspect.isclass(_optim) and issubclass(_optim,
                                                  torch.optim.Optimizer):
            if module_name == 'Adafactor':
                # ``transformers`` also provides an ``Adafactor`` optimizer,
                # registered below under the name ``Adafactor``, so the
                # PyTorch implementation is registered as ``torch_Adafactor``
                # to avoid a name clash.
                OPTIMIZERS.register_module(
                    name='torch_Adafactor', module=_optim)
            else:
                OPTIMIZERS.register_module(module=_optim)
            torch_optimizers.append(module_name)
    return torch_optimizers


TORCH_OPTIMIZERS = register_torch_optimizers()


def register_torch_npu_optimizers() -> List[str]:
    """Register optimizers in ``torch_npu`` to the ``OPTIMIZERS`` registry.

    Returns:
        List[str]: A list of registered optimizer names.
    """
    if not is_npu_available():
        return []

    import torch_npu
    if not hasattr(torch_npu, 'optim'):
        return []

    torch_npu_optimizers = []
    for module_name in dir(torch_npu.optim):
        if module_name.startswith('__') or module_name in OPTIMIZERS:
            continue
        _optim = getattr(torch_npu.optim, module_name)
        if inspect.isclass(_optim) and issubclass(_optim,
                                                  torch.optim.Optimizer):
            OPTIMIZERS.register_module(module=_optim)
            torch_npu_optimizers.append(module_name)
    return torch_npu_optimizers


NPU_OPTIMIZERS = register_torch_npu_optimizers()


def register_dadaptation_optimizers() -> List[str]:
    """Register optimizers in ``dadaptation`` to the ``OPTIMIZERS`` registry.

    Returns:
        List[str]: A list of registered optimizer names.
    """
    dadaptation_optimizers = []
    try:
        import dadaptation
    except ImportError:
        pass
    else:
        for module_name in ['DAdaptAdaGrad', 'DAdaptAdam', 'DAdaptSGD']:
            _optim = getattr(dadaptation, module_name)
            if inspect.isclass(_optim) and issubclass(
                    _optim, torch.optim.Optimizer):
                OPTIMIZERS.register_module(module=_optim)
                dadaptation_optimizers.append(module_name)
    return dadaptation_optimizers


DADAPTATION_OPTIMIZERS = register_dadaptation_optimizers()


def register_lion_optimizers() -> List[str]:
    """Register the Lion optimizer to the ``OPTIMIZERS`` registry.

    Returns:
        List[str]: A list of registered optimizer names.
    """
    optimizers = []
    try:
        from lion_pytorch import Lion
    except ImportError:
        pass
    else:
        OPTIMIZERS.register_module(module=Lion)
        optimizers.append('Lion')
    return optimizers


LION_OPTIMIZERS = register_lion_optimizers()


def register_sophia_optimizers() -> List[str]:
    """Register Sophia optimizers to the ``OPTIMIZERS`` registry.

    Returns:
        List[str]: A list of registered optimizer names.
    """
    optimizers = []
    try:
        import Sophia
    except ImportError:
        pass
    else:
        for module_name in dir(Sophia):
            _optim = getattr(Sophia, module_name)
            if inspect.isclass(_optim) and issubclass(
                    _optim, torch.optim.Optimizer):
                OPTIMIZERS.register_module(module=_optim)
                optimizers.append(module_name)
    return optimizers


SOPHIA_OPTIMIZERS = register_sophia_optimizers()
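
# Illustrative note (assumption: the corresponding third-party package, e.g.
# ``lion_pytorch``, is installed): once an optimizer is registered by the
# helpers above, it can be referenced by name in an optimizer wrapper config
# and is resolved through the ``OPTIMIZERS`` registry by
# ``build_optim_wrapper`` below, e.g.
#
#     optim_wrapper = dict(
#         type='OptimWrapper',
#         optimizer=dict(type='Lion', lr=1e-4, weight_decay=1e-2))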


def register_bitsandbytes_optimizers() -> List[str]:
    """Register optimizers in ``bitsandbytes`` to the ``OPTIMIZERS`` registry.

    In the ``bitsandbytes`` library, optimizers that have the same name as a
    default PyTorch optimizer are registered with a ``bnb_`` prefix, e.g.
    ``bnb_Adagrad``.

    Returns:
        List[str]: A list of registered optimizer names.
    """
    bitsandbytes_optimizers = []
    try:
        import bitsandbytes as bnb
    except ImportError:
        pass
    else:
        optim_classes = inspect.getmembers(
            bnb.optim, lambda _optim: (inspect.isclass(_optim) and issubclass(
                _optim, torch.optim.Optimizer)))
        for name, optim_cls in optim_classes:
            if name in OPTIMIZERS:
                name = f'bnb_{name}'
            OPTIMIZERS.register_module(module=optim_cls, name=name)
            bitsandbytes_optimizers.append(name)
    return bitsandbytes_optimizers


BITSANDBYTES_OPTIMIZERS = register_bitsandbytes_optimizers()


def register_transformers_optimizers() -> List[str]:
    """Register the ``Adafactor`` optimizer in ``transformers`` to the
    ``OPTIMIZERS`` registry.

    Returns:
        List[str]: A list of registered optimizer names.
    """
    transformer_optimizers = []
    try:
        from transformers import Adafactor
    except ImportError:
        pass
    else:
        OPTIMIZERS.register_module(name='Adafactor', module=Adafactor)
        transformer_optimizers.append('Adafactor')
    return transformer_optimizers


TRANSFORMERS_OPTIMIZERS = register_transformers_optimizers()


def build_optim_wrapper(model: nn.Module,
                        cfg: Union[dict, Config, ConfigDict]) -> OptimWrapper:
    """Build function of OptimWrapper.

    If ``constructor`` is set in ``cfg``, this method will build an optimizer
    wrapper constructor and use it to build the optimizer wrapper. If
    ``constructor`` is not set, the ``DefaultOptimWrapperConstructor`` will be
    used by default.

    Args:
        model (nn.Module): Model to be optimized.
        cfg (dict): Config of the optimizer wrapper, optimizer constructor
            and optimizer.

    Returns:
        OptimWrapper: The built optimizer wrapper.
    """
    optim_wrapper_cfg = copy.deepcopy(cfg)
    constructor_type = optim_wrapper_cfg.pop('constructor',
                                             'DefaultOptimWrapperConstructor')
    paramwise_cfg = optim_wrapper_cfg.pop('paramwise_cfg', None)

    # Since the current generation of NPUs (Ascend 910) only supports mixed
    # precision training, turn on mixed precision here so that training
    # works as expected.
    if is_npu_available() and not is_npu_support_full_precision():
        optim_wrapper_cfg['type'] = 'AmpOptimWrapper'

    optim_wrapper_constructor = OPTIM_WRAPPER_CONSTRUCTORS.build(
        dict(
            type=constructor_type,
            optim_wrapper_cfg=optim_wrapper_cfg,
            paramwise_cfg=paramwise_cfg))
    optim_wrapper = optim_wrapper_constructor(model)
    return optim_wrapper
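

# Minimal usage sketch (illustrative only; the toy model and hyperparameters
# below are hypothetical, not defaults of this module): build an optimizer
# wrapper for a small model from a plain config dict and run one update step.
#
#     import torch
#     import torch.nn as nn
#     from mmengine.optim import build_optim_wrapper
#
#     model = nn.Linear(4, 2)
#     optim_wrapper = build_optim_wrapper(
#         model,
#         dict(
#             type='OptimWrapper',
#             optimizer=dict(type='SGD', lr=0.01, momentum=0.9)))
#     loss = model(torch.randn(8, 4)).sum()
#     optim_wrapper.update_params(loss)  # backward + step + zero_grad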