import os
import warnings

import torch
def get_rank():
    """Get the rank of the current process."""

    if "SLURM_PROCID" in os.environ:
        return int(os.environ["SLURM_PROCID"])

    if not torch.distributed.is_available() or not torch.distributed.is_initialized():
        return 0

    return torch.distributed.get_rank()
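
# Hedged usage sketch (not part of this module): get_rank() is typically used to
# restrict side effects such as logging or checkpoint writes to a single process.
#
#     if get_rank() == 0:
#         print("running on the main process")
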

class InverseLR(torch.optim.lr_scheduler._LRScheduler):
    """Implements an inverse decay learning rate schedule with an optional exponential
    warmup. When last_epoch=-1, sets initial lr as lr.
    inv_gamma is the number of steps/epochs required for the learning rate to decay to
    (1 / 2)**power of its original value.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
        power (float): Exponential factor of learning rate decay. Default: 1.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable).
            Default: 0.
        final_lr (float): The final learning rate. Default: 0.
        last_epoch (int): The index of the last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for each update.
            Default: ``False``.
    """

    def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., final_lr=0.,
                 last_epoch=-1, verbose=False):
        self.inv_gamma = inv_gamma
        self.power = power
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        self.warmup = warmup
        self.final_lr = final_lr
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        warmup = 1 - self.warmup ** (self.last_epoch + 1)
        lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
        return [warmup * max(self.final_lr, base_lr * lr_mult)
                for base_lr in self.base_lrs]
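
# Hedged usage sketch (hyperparameter values are illustrative, not prescribed by this
# module): with inv_gamma=10000 and power=0.5 the learning rate decays to (1/2)**0.5
# of its initial value after 10000 steps, scaled by the warmup factor
# 1 - warmup**(step + 1).
#
#     model = torch.nn.Linear(8, 8)
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#     scheduler = InverseLR(optimizer, inv_gamma=10000, power=0.5, warmup=0.99)
#     for _ in range(100):
#         optimizer.step()
#         scheduler.step()
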

def copy_state_dict(model, state_dict):
    """Load state_dict into model, copying only keys that exist in the model with
    matching tensor shapes.

    Args:
        model (nn.Module): model to load state_dict into.
        state_dict (OrderedDict): state_dict to load.
    """
    model_state_dict = model.state_dict()
    for key in state_dict:
        if key in model_state_dict and state_dict[key].shape == model_state_dict[key].shape:
            if isinstance(state_dict[key], torch.nn.Parameter):
                # unwrap nn.Parameter values to their underlying tensors before copying
                state_dict[key] = state_dict[key].data
            model_state_dict[key] = state_dict[key]

    model.load_state_dict(model_state_dict, strict=False)
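
# Hedged usage sketch (file name and checkpoint layout are assumptions, not part of
# this module): partially load a checkpoint, skipping keys that are missing from the
# model or whose shapes do not match.
#
#     checkpoint = torch.load("pretrained.ckpt", map_location="cpu")
#     copy_state_dict(model, checkpoint["state_dict"])  # assumes a "state_dict" key
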

def create_optimizer_from_config(optimizer_config, parameters):
    """Create optimizer from config.

    Args:
        optimizer_config (dict): optimizer config.
        parameters (iterable): parameters to optimize.

    Returns:
        torch.optim.Optimizer: optimizer.
    """
    optimizer_type = optimizer_config["type"]

    if optimizer_type == "FusedAdam":
        from deepspeed.ops.adam import FusedAdam
        optimizer = FusedAdam(parameters, **optimizer_config["config"])
    else:
        optimizer_fn = getattr(torch.optim, optimizer_type)
        optimizer = optimizer_fn(parameters, **optimizer_config["config"])
    return optimizer
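
# Hedged usage sketch (values are illustrative): the config is expected to carry a
# "type" naming an optimizer class and a "config" dict of its keyword arguments.
#
#     optimizer_config = {"type": "AdamW", "config": {"lr": 1e-4, "betas": (0.9, 0.999)}}
#     optimizer = create_optimizer_from_config(optimizer_config, model.parameters())
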

def create_scheduler_from_config(scheduler_config, optimizer):
    """Create scheduler from config.

    Args:
        scheduler_config (dict): scheduler config.
        optimizer (torch.optim.Optimizer): optimizer.

    Returns:
        torch.optim.lr_scheduler._LRScheduler: scheduler.
    """
    if scheduler_config["type"] == "InverseLR":
        scheduler_fn = InverseLR
    else:
        scheduler_fn = getattr(torch.optim.lr_scheduler, scheduler_config["type"])
    scheduler = scheduler_fn(optimizer, **scheduler_config["config"])
    return scheduler
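
# Hedged usage sketch (values are illustrative): "type" may name the InverseLR class
# defined above or any scheduler class from torch.optim.lr_scheduler.
#
#     scheduler_config = {"type": "InverseLR",
#                         "config": {"inv_gamma": 1000000, "power": 0.5, "warmup": 0.99}}
#     scheduler = create_scheduler_from_config(scheduler_config, optimizer)
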