| import sys |
|
|
| import wandb |
| from time import sleep |
| import os |
|
|
def init_wandb(project_name, model_name, config, **wandb_kwargs):
    """Start a wandb run, retrying until ``wandb.init`` succeeds.

    Args:
        project_name: Forwarded to ``wandb.init`` as ``project``.
        model_name: Forwarded to ``wandb.init`` as ``name``.
        config: Forwarded to ``wandb.init`` as ``config``.
        **wandb_kwargs: Extra keyword arguments passed through to
            ``wandb.init`` unchanged.

    Returns:
        Whatever ``wandb.init`` returns on the first successful call.

    Note:
        There is no retry limit: any ``Exception`` raised by the call is
        reported on stderr and the call is attempted again after one second.
    """
    # NOTE(review): presumably lengthens the wandb service start-up timeout
    # to 300 seconds -- confirm against the wandb documentation.
    os.environ['WANDB__SERVICE_WAIT'] = '300'

    while True:
        try:
            return wandb.init(
                project=project_name,
                name=model_name,
                save_code=True,
                config=config,
                **wandb_kwargs,
            )
        except Exception as err:
            # Report the failure, pause briefly, then go around again.
            for message in ('wandb connection error', f'error: {err}'):
                print(message, file=sys.stderr)
            sleep(1)
            print('retrying..', file=sys.stderr)
|
|
def str2bool(v):
    """Convert a textual flag into a ``bool``.

    Args:
        v: Either a ``bool`` (returned unchanged) or a string. String
            matching is case-insensitive and ignores surrounding whitespace.

    Returns:
        ``True`` for 'yes', 'true', 't', 'y', '1';
        ``False`` for 'no', 'false', 'f', 'n', '0'.

    Raises:
        ValueError: If ``v`` is neither a bool nor a string, or is a string
            that is not one of the recognised spellings.
    """
    if isinstance(v, bool):
        return v
    if not isinstance(v, str):
        # Fail with the function's own error type instead of an accidental
        # AttributeError from calling .lower() on a non-string.
        raise ValueError(
            f'expected a bool or str, got {type(v).__name__}: {v!r}'
        )

    token = v.strip().lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise ValueError(f'cannot interpret {v!r} as a boolean')
|
|
class AverageMeter:
    """Track the latest value and the running weighted average of a metric.

    Attributes are plain and public: ``val`` (last value passed to
    ``update``), ``sum`` (accumulated ``val * n``), ``count`` (accumulated
    ``n``) and ``avg`` (``sum / count``).
    """

    def __init__(self, name, fmt=':f'):
        """Create a meter.

        Args:
            name: Label printed in front of the values by ``__str__``.
            fmt: Format spec including the leading colon (e.g. ``':.3f'``),
                applied to both the current value and the average.
        """
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the current value, average, sum and count."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest value, weighted by ``n``.

        Args:
            val: The new measurement.
            n: Weight of the measurement (e.g. a batch size).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total count (e.g. update(..., n=0) on a fresh meter) would
        # otherwise raise ZeroDivisionError; report the reset-state average.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        """Return ``'<name> <val> (<avg>)'`` rendered with ``self.fmt``."""
        template = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return template.format(name=self.name, val=self.val, avg=self.avg)