import inspect
import logging
import os
from collections import defaultdict
from dataclasses import field
from typing import Any, Dict, List, Optional, Tuple

import torch.optim
from accelerate import Accelerator
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.tools import model_io
from pytorch3d.implicitron.tools.config import (
    registry,
    ReplaceableBase,
    run_auto_creation,
)

logger = logging.getLogger(__name__)


class OptimizerFactoryBase(ReplaceableBase):
    def __call__(
        self, model: ImplicitronModelBase, **kwargs
    ) -> Tuple[torch.optim.Optimizer, Any]:
        """
        Initialize the optimizer and lr scheduler.

        Args:
            model: The model with optionally loaded weights.

        Returns:
            An optimizer module (optionally loaded from a checkpoint) and
            a learning rate scheduler module (should be a subclass of torch.optim's
            lr_scheduler._LRScheduler).
        """
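
        Example:
            A custom factory can be plugged in by registering a replacement of
            this base class; the class and settings below are a hypothetical
            sketch, not an existing implementation:

                @registry.register
                class MyOptimizerFactory(OptimizerFactoryBase):
                    def __call__(self, model, **kwargs):
                        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
                        scheduler = torch.optim.lr_scheduler.StepLR(
                            optimizer, step_size=100
                        )
                        return optimizer, scheduler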
| """ |
| raise NotImplementedError() |
|
|
|
|
| @registry.register |
| class ImplicitronOptimizerFactory(OptimizerFactoryBase): |
| """ |
| A factory that initializes the optimizer and lr scheduler. |
| |
| Members: |
| betas: Beta parameters for the Adam optimizer. |
| breed: The type of optimizer to use. We currently support SGD, Adagrad |
| and Adam. |
| exponential_lr_step_size: With Exponential policy only, |
| lr = lr * gamma ** (epoch/step_size) |
| gamma: Multiplicative factor of learning rate decay. |
| lr: The value for the initial learning rate. |
| lr_policy: The policy to use for learning rate. We currently support |
| MultiStepLR and Exponential policies. |
| momentum: A momentum value (for SGD only). |
| multistep_lr_milestones: With MultiStepLR policy only: list of |
| increasing epoch indices at which the learning rate is modified. |
| momentum: Momentum factor for SGD optimizer. |
| weight_decay: The optimizer weight_decay (L2 penalty on model weights). |
| foreach: Whether to use new "foreach" implementation of optimizer where |
| available (e.g. requires PyTorch 1.12.0 for Adam) |
| group_learning_rates: Parameters or modules can be assigned to parameter |
| groups. This dictionary has names of those parameter groups as keys |
| and learning rates as values. All parameter group names have to be |
| defined in this dictionary. Parameters which do not have predefined |
| parameter group are put into "default" parameter group which has |
| `lr` as its learning rate. |
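
    Example:
        A rough configuration sketch; the group names and values below are
        made up for illustration:

            group_learning_rates = {"color": 5e-4, "geometry": 1e-4}

        A module opts its parameters into one of these groups by exposing a
        `param_groups` member (see `_get_param_groups` below). Parameters
        without an assigned group fall back to "default" and use `lr`.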
| """ |
|
|
| betas: Tuple[float, ...] = (0.9, 0.999) |
| breed: str = "Adam" |
| exponential_lr_step_size: int = 250 |
| gamma: float = 0.1 |
| lr: float = 0.0005 |
| lr_policy: str = "MultiStepLR" |
| momentum: float = 0.9 |
| multistep_lr_milestones: tuple = () |
| weight_decay: float = 0.0 |
| linear_exponential_lr_milestone: int = 200 |
| linear_exponential_start_gamma: float = 0.1 |
| foreach: Optional[bool] = True |
| group_learning_rates: Dict[str, float] = field(default_factory=lambda: {}) |
|
|
| def __post_init__(self): |
| run_auto_creation(self) |

    def __call__(
        self,
        last_epoch: int,
        model: ImplicitronModelBase,
        accelerator: Optional[Accelerator] = None,
        exp_dir: Optional[str] = None,
        resume: bool = True,
        resume_epoch: int = -1,
        **kwargs,
    ) -> Tuple[torch.optim.Optimizer, Any]:
        """
        Initialize the optimizer (optionally from a checkpoint) and the lr scheduler.

        Args:
            last_epoch: If the model was loaded from checkpoint this will be the
                number of the last epoch that was saved.
            model: The model with optionally loaded weights.
            accelerator: An optional Accelerator instance.
            exp_dir: Root experiment directory.
            resume: If True, attempt to load optimizer checkpoint from exp_dir.
                Failure to do so will return a newly initialized optimizer.
            resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
                `resume_epoch` <= 0, then resume from the latest checkpoint.

        Returns:
            An optimizer module (optionally loaded from a checkpoint) and
            a learning rate scheduler module (should be a subclass of torch.optim's
            lr_scheduler._LRScheduler).
        """
        # Get the parameters to optimize.
        if hasattr(model, "_get_param_groups"):
            # The model defines its own parameter grouping.
            p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
        else:
            p_groups = [
                {"params": params, "lr": self._get_group_learning_rate(group)}
                for group, params in self._get_param_groups(model).items()
            ]

        # Initialize the optimizer.
        optimizer_kwargs: Dict[str, Any] = {
            "lr": self.lr,
            "weight_decay": self.weight_decay,
        }
        if self.breed == "SGD":
            optimizer_class = torch.optim.SGD
            optimizer_kwargs["momentum"] = self.momentum
        elif self.breed == "Adagrad":
            optimizer_class = torch.optim.Adagrad
        elif self.breed == "Adam":
            optimizer_class = torch.optim.Adam
            optimizer_kwargs["betas"] = self.betas
        else:
            raise ValueError(f"No such solver type {self.breed}")

        # Pass `foreach` only to optimizers that support it
        # (e.g. Adam requires PyTorch 1.12.0).
        if "foreach" in inspect.signature(optimizer_class.__init__).parameters:
            optimizer_kwargs["foreach"] = self.foreach
        optimizer = optimizer_class(p_groups, **optimizer_kwargs)
        logger.info(f"Solver type = {self.breed}")

        # Load the optimizer state from a checkpoint, if available.
        optimizer_state = self._get_optimizer_state(
            exp_dir,
            accelerator,
            resume_epoch=resume_epoch,
            resume=resume,
        )
        if optimizer_state is not None:
            logger.info("Setting loaded optimizer state.")
            optimizer.load_state_dict(optimizer_state)

        # Initialize the learning rate scheduler.
        if self.lr_policy.casefold() == "MultiStepLR".casefold():
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=self.multistep_lr_milestones,
                gamma=self.gamma,
            )
        elif self.lr_policy.casefold() == "Exponential".casefold():
            scheduler = torch.optim.lr_scheduler.LambdaLR(
                optimizer,
                lambda epoch: self.gamma ** (epoch / self.exponential_lr_step_size),
                verbose=False,
            )
        elif self.lr_policy.casefold() == "LinearExponential".casefold():
            # The learning rate multiplier ramps up linearly from
            # `linear_exponential_start_gamma` to 1 until
            # `linear_exponential_lr_milestone` epochs, then decays
            # exponentially as in the Exponential policy.
            def _get_lr(epoch: int):
                m = self.linear_exponential_lr_milestone
                if epoch < m:
                    w = (m - epoch) / m
                    gamma = w * self.linear_exponential_start_gamma + (1 - w)
                else:
                    epoch_rest = epoch - m
                    gamma = self.gamma ** (epoch_rest / self.exponential_lr_step_size)
                return gamma
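            # Rough worked example of the schedule above, using the default
            # settings (linear_exponential_lr_milestone=200,
            # linear_exponential_start_gamma=0.1, gamma=0.1,
            # exponential_lr_step_size=250):
            #   epoch   0 -> multiplier 0.10 (start of the linear ramp)
            #   epoch 100 -> multiplier 0.55
            #   epoch 200 -> multiplier 1.00 (exponential decay starts)
            #   epoch 450 -> multiplier 0.10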
            scheduler = torch.optim.lr_scheduler.LambdaLR(
                optimizer, _get_lr, verbose=False
            )
        else:
            raise ValueError(f"no such lr policy {self.lr_policy}")

        # When resuming from a checkpoint, advance the scheduler to the epoch
        # at which training restarts so that the learning rate is set correctly.
        for _ in range(last_epoch):
            scheduler.step()

        optimizer.zero_grad()

        return optimizer, scheduler

    def _get_optimizer_state(
        self,
        exp_dir: Optional[str],
        accelerator: Optional[Accelerator] = None,
        resume: bool = True,
        resume_epoch: int = -1,
    ) -> Optional[Dict[str, Any]]:
        """
        Load an optimizer state from a checkpoint.

        Args:
            exp_dir: Root experiment directory in which to look for checkpoints.
            accelerator: An optional Accelerator instance, used to map the
                loaded state to the correct local device.
            resume: If True, attempt to load the last checkpoint from `exp_dir`
                passed to __call__. Failure to do so will return a newly
                initialized optimizer.
            resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
                `resume_epoch` <= 0, then resume from the latest checkpoint.
        Returns:
            The optimizer state dict, or None if there is nothing to resume from.
        """
        if exp_dir is None or not resume:
            return None
        if resume_epoch > 0:
            save_path = model_io.get_checkpoint(exp_dir, resume_epoch)
            if not os.path.isfile(save_path):
                raise FileNotFoundError(
                    f"Cannot find optimizer from epoch {resume_epoch}."
                )
        else:
            save_path = model_io.find_last_checkpoint(exp_dir)
        optimizer_state = None
        if save_path is not None:
            logger.info(f"Found previous optimizer state {save_path} -> resuming.")
            opt_path = model_io.get_optimizer_path(save_path)

            if os.path.isfile(opt_path):
                map_location = None
                if accelerator is not None and not accelerator.is_local_main_process:
                    # Remap the state stored on cuda:0 to this process's
                    # local device.
                    map_location = {
                        "cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
                    }
                optimizer_state = torch.load(opt_path, map_location)
            else:
                raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")
        return optimizer_state

    def _get_param_groups(
        self, module: torch.nn.Module
    ) -> Dict[str, List[torch.nn.Parameter]]:
        """
        Recursively visits all the modules inside the `module` and sorts all the
        parameters into parameter groups.

        Uses the `param_groups` dictionary member, where keys are names of
        individual parameters or module members and values are the names of the
        parameter groups for those parameters or members. The "self" key is used
        to denote the parameter group at the module level. No key, including
        "self", has to be defined. By default all parameters use the learning
        rate defined in the optimizer. This can be overridden by setting the
        parameter group in the `param_groups` member of a specific module.
        Values are a parameter group name. The keys specify which parameters
        will be affected as follows:
            - "self": all the parameters of the module and its child modules
            - name of a parameter: a parameter with that name.
            - name of a module member: all the parameters of the module and its
                child modules. This is useful if members do not have
                `param_groups`, for example torch.nn.Linear.
            - <name of module member>.<something>: recursive. Same as if
                <something> was used in `param_groups` of that submodule/member.

        Args:
            module: module from which to extract the parameters and their
                parameter groups
        Returns:
            dictionary with parameter groups as keys and lists of parameters
            as values
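
        Example:
            A rough sketch with made-up module and group names:

                class MyModel(torch.nn.Module):
                    param_groups = {"self": "default", "background": "slow"}

            would put every parameter of the hypothetical `background` child
            module into the "slow" group and all remaining parameters into
            "default".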
| """ |
|
|
| param_groups = defaultdict(list) |
|
|
| def traverse(module, default_group: str, mapping: Dict[str, str]) -> None: |
| """ |
| Visitor for module to assign its parameters to the relevant member of |
| param_groups. |
| |
| Args: |
| module: the module being visited in a depth-first search |
| default_group: the param group to assign parameters to unless |
| otherwise overriden. |
| mapping: known mappings of parameters to groups for this module, |
| destructively modified by this function. |
| """ |
| |
| |
| if hasattr(module, "param_groups") and "self" in module.param_groups: |
| default_group = module.param_groups["self"] |
|
|
| |
| |
| |
| if hasattr(module, "param_groups"): |
| mapping.update(module.param_groups) |
|
|
| for name, param in module.named_parameters(recurse=False): |
| if param.requires_grad: |
| group_name = mapping.get(name, default_group) |
| logger.debug(f"Assigning {name} to param_group {group_name}") |
| param_groups[group_name].append(param) |
|
|
| |
| |
| for child_name, child in module.named_children(): |
| mapping_to_add = { |
| name[len(child_name) + 1 :]: group |
| for name, group in mapping.items() |
| if name.startswith(child_name + ".") |
| } |
| traverse(child, mapping.get(child_name, default_group), mapping_to_add) |
|
|
| traverse(module, "default", {}) |
| return param_groups |

    def _get_group_learning_rate(self, group_name: str) -> float:
        """
        Looks up the learning rate of a parameter group in
        `group_learning_rates`. The "default" group uses `self.lr`; an unknown
        group name raises a ValueError.

        Args:
            group_name: a string representing the name of the group
        Returns:
            learning rate for the specific group
        """
        if group_name == "default":
            return self.lr
        lr = self.group_learning_rates.get(group_name, None)
        if lr is None:
            raise ValueError(f"no learning rate given for group {group_name}")
        return lr