| | import argparse |
| | from contextlib import nullcontext |
| |
|
| | |
| | import collections |
| | import json |
| | import os |
| | import re |
| |
|
| | import torch |
| | from time import time |
| | from src.text_utils.logging import get_logger |
| | from contextlib import contextmanager |
| | from timeit import default_timer |
| |
|
# Module-level logger; get_logger is a project helper (src.text_utils.logging).
logger = get_logger(__name__)
| | |
| | |
| |
|
| |
|
@contextmanager
def elapsed_timer():
    """Context manager that yields a zero-arg callable reporting elapsed seconds.

    Inside the ``with`` block the callable returns the live elapsed time;
    after the block exits it returns the frozen total duration. The yielded
    lambda indirects through the local ``measure`` name on purpose: rebinding
    ``measure`` after the yield retroactively freezes what callers see.
    """
    start = default_timer()

    def measure():
        return default_timer() - start

    yield lambda: measure()

    stop = default_timer()

    def measure():  # noqa: F811 — deliberate rebind to freeze the reading
        return stop - start
| |
|
class AverageMeter(object):
    """Tracks the latest value and a running (weighted) average.

    Typical use: one meter per metric, ``update(value, batch_size)`` each
    step, format with ``str(meter)`` for progress logging.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt  # format spec used by __str__, e.g. ':f' or ':.4e'
        self.reset()

    def reset(self):
        # Zero every running statistic.
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        # Record `val` observed `n` times and refresh the running average.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return template.format(**self.__dict__)
| |
|
| |
|
def save_args_to_json(args, output_json_path):
    """Persist an argparse.Namespace as JSON, skipping unserializable values.

    Each value is individually JSON-encoded to a string (so the paired
    load_args_from_json can decode them back with json.loads); values that
    json.dumps cannot handle are silently dropped.

    Args:
        args: argparse.Namespace (or any object vars() accepts).
        output_json_path: destination file path, overwritten if present.
    """
    serializable_args = {}
    for k, v in vars(args).items():
        # json.dumps raises TypeError for unsupported types and ValueError
        # for e.g. circular references — catch only those, not everything.
        try:
            serializable_args[k] = json.dumps(v)
        except (TypeError, ValueError):
            continue
    with open(output_json_path, 'w') as arg_json:
        json.dump(serializable_args, arg_json)
| |
|
| |
|
def load_args_from_json(output_json_path):
    """Load arguments saved by save_args_to_json back into a Namespace.

    Args:
        output_json_path: path to the JSON file, or a directory containing
            'train_args.json' (joined with os.path.join, so a trailing
            separator is no longer required — the old `path += 'file'`
            broke for 'run_dir' without a trailing slash).

    Returns:
        argparse.Namespace with each value decoded via json.loads — the
        exact inverse of the per-value json.dumps done at save time.
        Values that fail to decode are kept as raw strings. This replaces
        the previous eval(), which was unsafe and raised uncaught
        NameError on valid JSON such as '{"a": true}'.
    """
    if os.path.isdir(output_json_path):
        output_json_path = os.path.join(output_json_path, 'train_args.json')
    with open(output_json_path, 'r') as arg_json:
        kwargs = json.load(arg_json)
    _kwargs = {}
    for k, v in kwargs.items():
        try:
            v = json.loads(v)  # decodes 'null'/'true'/'false'/numbers/lists too
        except (json.JSONDecodeError, TypeError):
            pass  # best-effort: keep the raw string
        _kwargs[k] = v
    return argparse.Namespace(**_kwargs)
| |
|
def tensor_norm(input, input_mask=None):
    """Mean L2 norm of `input` along dim=1, optionally restricted by a mask.

    Args:
        input: tensor whose dim-1 slices are normed.
        input_mask: optional mask broadcast against `input` (after
            unsqueeze(-1)); positions where the flattened mask is falsy are
            dropped before averaging. NOTE(review): the masked branch assumes
            the flattened mask lines up element-for-element with the dim-1
            norms — confirm expected shapes against callers.

    Returns:
        Scalar tensor: the mean of the selected norms.
    """
    if input_mask is None:
        norms = torch.linalg.norm(input, dim=1, ord=2)
        return norms.mean()
    masked_input = input * input_mask.unsqueeze(-1)
    norms = torch.linalg.norm(masked_input, dim=1)
    kept = torch.masked_select(norms, input_mask.bool().reshape(-1))
    return kept.mean()
| |
|
| |
|
class print_time():
    """Context manager that announces a task and logs its wall-clock duration.

    Messages go through print_master, so in a distributed run only rank 0
    logs them.
    """

    def __init__(self, task):
        self.task = task  # human-readable task label

    def __enter__(self):
        print_master(self.task)
        self.t = time()  # wall-clock start, read back in __exit__

    def __exit__(self, type, value, traceback):
        elapsed = time() - self.t
        print_master(f'{self.task} took {elapsed:.02f}s')
| |
|
| |
|
def print_rank(message):
    """Log `message`, prefixed with the caller's rank when distributed is up."""
    if not torch.distributed.is_initialized():
        logger.info(message)
        return
    logger.info(f'rank{torch.distributed.get_rank()}: ' + message)
| |
|
| |
|
def print_master(message):
    """Log `message`, but only on rank 0 when torch.distributed is initialized."""
    if not torch.distributed.is_initialized():
        logger.info(message)
        return
    if torch.distributed.get_rank() == 0:
        logger.info(message)
| |
|
| |
|
def str2bool(v):
    """argparse-friendly boolean parser.

    Accepts real bools unchanged; otherwise maps common yes/no spellings
    (case-insensitive) to True/False.

    Raises:
        argparse.ArgumentTypeError: if `v` is not a recognized spelling.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in {'yes', 'true', 't', 'y', '1'}:
        return True
    if lowered in {'no', 'false', 'f', 'n', '0'}:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
| |
|
| |
|
def calc_gradient_norm(model, return_param_norm=False, return_details=True, is_deepspeed=False):
    '''
    Compute the global L2 norm of a model's gradients (or parameters),
    optionally bucketed by module group.

    Args:
        model: module whose named_parameters() are inspected; only params
            with requires_grad=True and a non-None .grad contribute.
        return_param_norm: if True it returns the norm of parameters,
            otherwise grad. No effect for DeepSpeed as it handles
            parameters differently.
        return_details: if True, also accumulate per-group subtotals keyed
            by embedding / addon_layer / layer-index / model-submodule name.
        is_deepspeed: unused here; kept for interface compatibility.

    Returns:
        collections.defaultdict(float) mapping group name -> summed norms;
        key 'total' holds the global L2 norm (sqrt of the sum of squared
        per-tensor norms).
    '''
    total_norm = 0.0
    n_parameter = 0  # counted for parity with the original; not returned
    group_norm = collections.defaultdict(float)
    group_norm['total'] = 0.0
    # Raw string avoids the invalid-escape warning on '\d' (Python 3.12+);
    # the '.' is left unescaped on purpose so 'layers.3' etc. still match.
    # Compiled once, hoisted out of the parameter loop.
    layer_pattern = re.compile(r'layers.\d+|layer.\d+')
    for n, p in model.named_parameters():
        # (The original wrapped this in a no-op `with nullcontext():` — removed.)
        if not (p.requires_grad and p.grad is not None):
            continue
        tensor = p if return_param_norm else p.grad
        param_norm = tensor.detach().data.norm(p=2).item()
        total_norm += param_norm ** 2
        n_parameter += torch.numel(p.grad)
        module_name = 'q_encoder'
        if return_details:
            if 'embed' in n:
                group_norm[f'{module_name}-embeddings'] += param_norm
            elif 'addon_layer' in n:
                group_norm[f'{module_name}-addon_layer'] += param_norm
            elif 'layer' in n:
                match = layer_pattern.search(n)
                part_name = match.group(0) if match else 'unknown_group'
                group_norm[f'{module_name}-{part_name}'] += param_norm
            if 'model' in n:
                # Slice off everything through 'model.' then strip framework noise.
                part_name = n[n.rfind('model') + 6:]
                for noise in ('module.', '.dense', '.weight', '.bias', '.pytorch', '.default'):
                    part_name = part_name.replace(noise, '')
                group_norm[f'{part_name}'] += param_norm

    group_norm['total'] = total_norm ** 0.5
    return group_norm
| |
|
| |
|
def get_gradient_norm(model):
    """Global L2 norm over all gradients currently stored on `model`.

    Parameters without a gradient contribute zero. Returns a plain float:
    sqrt of the sum of squared per-tensor L2 norms.
    """
    squared_sum = 0.0
    for p in model.parameters():
        if p.grad is None:
            continue
        squared_sum += p.grad.data.norm(2).item() ** 2
    return squared_sum ** 0.5
| |
|
| |
|
def count_parameters(model):
    """Print and return the model's parameter counts.

    Args:
        model: any torch.nn.Module.

    Returns:
        (total_num, grad_num): total parameter count and the count of
        parameters with requires_grad=True. Previously the counts were
        only printed; returning them is backward-compatible because
        existing callers simply ignore the return value.
    """
    total_num = sum(p.numel() for p in model.parameters())
    grad_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'#Total parameters: {total_num}')
    print(f'#Parameters require gradient: {grad_num}')
    return total_num, grad_num
| |
|