| import argparse |
| import collections |
| import warnings |
|
|
| import numpy as np |
| import torch |
|
|
| import hw_asr.loss as module_loss |
| import hw_asr.metric as module_metric |
| import hw_asr.model as module_arch |
| from hw_asr.trainer import Trainer |
| from hw_asr.utils import prepare_device |
| from hw_asr.utils.object_loading import get_dataloaders |
| from hw_asr.utils.parse_config import ConfigParser |
|
|
# Suppress noisy UserWarnings emitted by third-party libs during training.
warnings.filterwarnings("ignore", category=UserWarning)

# Reproducibility: pin every RNG we use and force cuDNN into its
# deterministic mode (benchmark autotuning is disabled because it can
# pick different kernels run-to-run).
SEED = 123
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
|
|
|
|
def main(config):
    """Build every training component described by *config* and run training.

    Args:
        config: project ``ConfigParser`` holding the experiment description
            (architecture, data, loss, metrics, optimizer, scheduler, trainer).
    """
    logger = config.get_logger("train")

    # The text encoder defines the output alphabet; its size fixes the
    # model's number of output classes below.
    text_encoder = config.get_text_encoder()

    dataloaders = get_dataloaders(config, text_encoder)

    # Instantiate the network architecture and log its structure.
    model = config.init_obj(config["arch"], module_arch, n_class=len(text_encoder))
    logger.info(model)

    # Place the model on the prepared device; wrap in DataParallel when
    # more than one GPU index was requested.
    device, device_ids = prepare_device(config["n_gpu"])
    model = model.to(device)
    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids)

    # Loss lives on the same device as the model; metrics get the encoder
    # so they can decode predictions back to text.
    criterion = config.init_obj(config["loss"], module_loss).to(device)
    metric_fns = [
        config.init_obj(metric_cfg, module_metric, text_encoder=text_encoder)
        for metric_cfg in config["metrics"]
    ]

    # Only parameters that require gradients are handed to the optimizer
    # (frozen sub-modules are skipped).
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = config.init_obj(config["optimizer"], torch.optim, params)
    scheduler = config.init_obj(
        config["lr_scheduler"], torch.optim.lr_scheduler, optimizer
    )

    trainer = Trainer(
        model,
        criterion,
        metric_fns,
        optimizer,
        text_encoder=text_encoder,
        config=config,
        device=device,
        dataloaders=dataloaders,
        lr_scheduler=scheduler,
        len_epoch=config["trainer"].get("len_epoch", None),
    )

    trainer.train()
|
|
|
|
if __name__ == "__main__":
    # Command-line interface: a config file is required in practice,
    # optionally resuming from a checkpoint on selected devices.
    parser = argparse.ArgumentParser(description="PyTorch Template")
    parser.add_argument(
        "-c",
        "--config",
        default=None,
        type=str,
        help="config file path (default: None)",
    )
    parser.add_argument(
        "-r",
        "--resume",
        default=None,
        type=str,
        help="path to latest checkpoint (default: None)",
    )
    parser.add_argument(
        "-d",
        "--device",
        default=None,
        type=str,
        help="indices of GPUs to enable (default: all)",
    )

    # Extra flags that override nested config values; ``target`` is a
    # semicolon-separated path into the config dict.
    CustomArgs = collections.namedtuple("CustomArgs", "flags type target")
    options = [
        CustomArgs(["--lr", "--learning_rate"], type=float, target="optimizer;args;lr"),
        CustomArgs(
            ["--bs", "--batch_size"], type=int, target="data_loader;args;batch_size"
        ),
    ]

    main(ConfigParser.from_args(parser, options))
|
|