| | """A collection of useful helper functions""" |
| |
|
| | import os |
| | import logging |
| | import json |
| |
|
| | import torch |
| | from torch.profiler import profile, record_function, ProfilerActivity |
| | import pandas as pd |
| | from torchmetrics.functional import( |
| | scale_invariant_signal_noise_ratio as si_snr, |
| | signal_noise_ratio as snr, |
| | signal_distortion_ratio as sdr, |
| | scale_invariant_signal_distortion_ratio as si_sdr) |
| | import matplotlib.pyplot as plt |
| |
|
class Params:
    """Hyperparameter container backed by a JSON file.

    Example:
    ```
    params = Params(json_path)
    print(params.learning_rate)
    params.learning_rate = 0.5  # change the value of learning_rate in params
    ```
    """

    def __init__(self, json_path):
        # Initial load is just an update onto an empty instance.
        self.update(json_path)

    def save(self, json_path):
        """Write the current parameters to `json_path` as pretty-printed JSON."""
        with open(json_path, 'w') as f:
            json.dump(self.__dict__, f, indent=4)

    def update(self, json_path):
        """Merge parameters read from the JSON file at `json_path` into this instance."""
        with open(json_path) as f:
            self.__dict__.update(json.load(f))

    @property
    def dict(self):
        """Gives dict-like access to Params instance by `params.dict['learning_rate']`"""
        return self.__dict__
| |
|
def save_graph(train_metrics, test_metrics, save_dir):
    """Dump train/test loss and metric curves to `save_dir` as CSV and PNG.

    Args:
        train_metrics: dict mapping 'loss' and metric-function names to
            per-epoch value lists
        test_metrics: same structure as `train_metrics`, for the test split
        save_dir: directory that receives 'results.csv' and 'results.png'
    """
    metrics = [snr, si_snr]

    # Gather every curve under a 'train_'/'test_' prefixed column name.
    results = {'train_loss': train_metrics['loss'],
               'test_loss': test_metrics['loss']}
    for m_fn in metrics:
        name = m_fn.__name__
        results["train_" + name] = train_metrics[name]
        results["test_" + name] = test_metrics[name]

    pd.DataFrame(results).to_csv(os.path.join(save_dir, 'results.csv'))

    # 2x3 grid: first axis shows the loss, the following ones hold one
    # metric each (surplus axes stay empty).
    fig, grid = plt.subplots(2, 3, figsize=(15, 10))
    axs = list(grid.flat)

    epochs = range(len(train_metrics['loss']))
    axs[0].plot(epochs, train_metrics['loss'], label='train')
    axs[0].plot(epochs, test_metrics['loss'], label='test')
    axs[0].set(ylabel='Loss')
    axs[0].set(xlabel='Epoch')
    axs[0].set_title('loss', fontweight='bold')
    axs[0].legend()

    for ax, m_fn in zip(axs[1:], metrics):
        name = m_fn.__name__
        ax.plot(epochs, train_metrics[name], label='train')
        ax.plot(epochs, test_metrics[name], label='test')
        ax.set(xlabel='Epoch')
        ax.set_title(name, fontweight='bold')
        ax.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'results.png'))
    plt.close(fig)
| |
|
def set_logger(log_path):
    """Configure the root logger to log to the terminal and to `log_path`.

    In general, it is useful to have a logger so that every output to the
    terminal is saved in a permanent file. Here we save it to
    `model_dir/train.log`.

    Example:
    ```
    logging.info("Starting training...")
    ```

    Args:
        log_path: (string) where to log
    """
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    # Drop handlers from any previous call so records are not duplicated.
    root.handlers.clear()

    # Permanent record on disk, with timestamp and level.
    file_handler = logging.FileHandler(log_path)
    file_handler.setFormatter(
        logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
    root.addHandler(file_handler)

    # Plain, unprefixed echo to the console.
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter('%(message)s'))
    root.addHandler(stream_handler)
| |
|
def load_checkpoint(checkpoint, model, optim=None, lr_sched=None, data_parallel=False):
    """Loads model parameters (state_dict) from file_path.

    Args:
        checkpoint: (string) filename which needs to be loaded
        model: (torch.nn.Module) model for which the parameters are loaded
        optim: (torch.optim.Optimizer) optional optimizer whose state is restored
        lr_sched: optional LR scheduler whose state is restored
        data_parallel: (bool) if the model is a data parallel model

    Returns:
        (epoch, train_metrics, val_metrics) stored in the checkpoint.

    Raises:
        FileNotFoundError: if `checkpoint` does not exist.
    """
    if not os.path.exists(checkpoint):
        # BUG FIX: `raise("...")` raised a TypeError (a str is not an
        # exception class); raise a proper, catchable exception instead.
        raise FileNotFoundError("File doesn't exist {}".format(checkpoint))

    state_dict = torch.load(checkpoint)

    if data_parallel:
        # A DataParallel-wrapped model expects its parameter names to carry
        # a 'module.' prefix, which a checkpoint saved from a plain model lacks.
        state_dict['model_state_dict'] = {
            'module.' + k: v
            for k, v in state_dict['model_state_dict'].items()}
    model.load_state_dict(state_dict['model_state_dict'])

    if optim is not None:
        optim.load_state_dict(state_dict['optim_state_dict'])

    if lr_sched is not None:
        lr_sched.load_state_dict(state_dict['lr_sched_state_dict'])

    return state_dict['epoch'], state_dict['train_metrics'], \
        state_dict['val_metrics']
| |
|
def save_checkpoint(checkpoint, epoch, model, optim=None, lr_sched=None,
                    train_metrics=None, val_metrics=None, data_parallel=False):
    """Saves model parameters (state_dict) to file_path.

    Args:
        checkpoint: (string) filename to write; must not already exist
        epoch: (int) epoch number stored alongside the weights
        model: (torch.nn.Module) model whose parameters are saved
        optim: (torch.optim.Optimizer) optional optimizer whose state is saved
        lr_sched: optional LR scheduler whose state is saved
        train_metrics: optional training metrics stored in the checkpoint
        val_metrics: optional validation metrics stored in the checkpoint
        data_parallel: (bool) if the model is a data parallel model

    Raises:
        FileExistsError: if `checkpoint` already exists (never overwrites).
    """
    if os.path.exists(checkpoint):
        # BUG FIX: `raise("...")` raised a TypeError (a str is not an
        # exception class); raise a proper, catchable exception instead.
        raise FileExistsError("File already exists {}".format(checkpoint))

    model_state_dict = model.state_dict()
    if data_parallel:
        # Strip the 'module.' prefix that nn.DataParallel adds so the
        # checkpoint can later be loaded into a plain (unwrapped) model.
        model_state_dict = {
            k.partition('module.')[2]: v
            for k, v in model_state_dict.items()}

    optim_state_dict = None if not optim else optim.state_dict()
    lr_sched_state_dict = None if not lr_sched else lr_sched.state_dict()

    state_dict = {
        'epoch': epoch,
        'model_state_dict': model_state_dict,
        'optim_state_dict': optim_state_dict,
        'lr_sched_state_dict': lr_sched_state_dict,
        'train_metrics': train_metrics,
        'val_metrics': val_metrics
    }

    torch.save(state_dict, checkpoint)
| |
|
def model_size(model):
    """
    Returns size of the `model` in millions of (trainable) parameters.
    """
    trainable = (p.numel() for p in model.parameters() if p.requires_grad)
    return sum(trainable) / 1e6
| |
|
def run_time(model, inputs, profiling=False):
    """
    Returns runtime of a model in ms.

    Args:
        model: callable module to time
        inputs: tuple of positional arguments passed to `model`
        profiling: (bool) if True, also print the profiler's per-op table
    """
    # Warm-up passes so one-time costs are excluded from the measurement.
    for _ in range(100):
        model(*inputs)

    # Profile a single forward pass on the CPU.
    with profile(activities=[ProfilerActivity.CPU],
                 record_shapes=True) as prof:
        with record_function("model_inference"):
            model(*inputs)

    if profiling:
        print(prof.key_averages().table(sort_by="self_cpu_time_total",
                                        row_limit=20))

    # self_cpu_time_total is in microseconds; convert to milliseconds.
    return prof.profiler.self_cpu_time_total / 1000
| |
|
def format_lr_info(optimizer):
    """Return a one-line summary of each param group's size (in M params) and lr."""
    pieces = []
    for i, pg in enumerate(optimizer.param_groups):
        n_params = sum(p.numel() for p in pg['params'])
        pieces.append(" {group %d: params=%.5fM lr=%.1E}"
                      % (i, n_params / (1024 ** 2), pg['lr']))
    return "".join(pieces)
| |
|
| |
|