Spaces:
Build error
Build error
| from torch.optim import SGD, lr_scheduler | |
| from configs.qsenn_training_params import QSENNScheduler | |
| from configs.sldd_training_params import OptimizationScheduler | |
| from training.img_net import get_default_img_schedule, get_default_img_optimizer | |
def get_optimizer(model, schedulingClass):
    """Build the SGD optimizer and LR schedule described by ``schedulingClass``.

    ``schedulingClass.get_params()`` must return the 6-tuple
    ``(lr, weight_decay, step_lr, step_lr_gamma, n_epochs, finetune)``.

    Three configurations are produced:

    * ``lr is None`` (dense training on ImageNet): the default ImageNet
      optimizer and schedule from ``training.img_net`` are returned together
      with a fixed epoch count of 600.
    * ``finetune`` truthy: a single SGD group (momentum 0.95) over only the
      parameters that currently have ``requires_grad`` set.
    * otherwise: SGD (momentum 0.9) with two groups -- all non-classifier
      parameters at ``lr`` and the ``linear.weight`` / ``linear.bias``
      classifier parameters at a fixed learning rate of 0.01.

    In the last two cases a ``StepLR`` schedule with ``step_size=step_lr`` and
    ``gamma=step_lr_gamma`` is attached.

    Returns:
        tuple: ``(optimizer, schedule, n_epochs)``.
    """
    lr, weight_decay, step_lr, step_lr_gamma, n_epochs, finetune = schedulingClass.get_params()
    print("Optimizer LR set to ", lr)

    if lr is None:  # Dense Training on ImageNet
        print("Learning rate is None, using Default Recipe for Resnet50")
        default_img_optimizer = get_default_img_optimizer(model)
        default_img_schedule = get_default_img_schedule(default_img_optimizer)
        return default_img_optimizer, default_img_schedule, 600

    if finetune:
        # Only parameters that are still trainable are handed to the optimizer.
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        optimizer = SGD(trainable_params, lr, momentum=0.95,
                        weight_decay=weight_decay)
    else:
        # Split parameters in a single pass; iteration order of
        # named_parameters() is preserved inside each group.
        classifier_params_name = ("linear.bias", "linear.weight")
        base_params = []
        classifier_params = []
        for name, param in model.named_parameters():
            if name in classifier_params_name:
                classifier_params.append(param)
            else:
                base_params.append(param)
        optimizer = SGD([
            {'params': base_params},
            # The classifier group overrides the default lr with a fixed value.
            {"params": classifier_params, 'lr': 0.01},
        ], momentum=0.9, lr=lr, weight_decay=weight_decay)

    # Make schedule
    schedule = lr_scheduler.StepLR(optimizer, step_size=step_lr, gamma=step_lr_gamma)
    return optimizer, schedule, n_epochs
def get_scheduler_for_model(model, dataset):
    """Return the training-parameter scheduler for the given model type.

    Args:
        model: Model type name; ``"qsenn"`` and ``"sldd"`` are supported.
        dataset: Passed unchanged to the scheduler's constructor.

    Returns:
        A ``QSENNScheduler`` for ``"qsenn"`` or an ``OptimizationScheduler``
        for ``"sldd"``.

    Raises:
        ValueError: If ``model`` is not one of the supported names. Raising
            here replaces the previous implicit ``None`` return, which only
            failed later when the result was used.
    """
    if model == "qsenn":
        return QSENNScheduler(dataset)
    if model == "sldd":
        return OptimizationScheduler(dataset)
    raise ValueError(
        f"Unknown model {model!r}; expected 'qsenn' or 'sldd'"
    )