| |
| |
| |
| |
| |
| |
|
|
| import numpy as np |
| from torch import optim as optim |
|
|
def build_optimizer(config, model):
    """
    Build optimizer, set weight decay of normalization to 0 by default.

    Parameter grouping:
      * ``config.MODEL.TYPE`` starting with ``"revcol"``: layer-wise lr-decay
        groups from ``param_groups_lrd`` (each group carries its own
        ``weight_decay`` and an ``lr_scale`` entry). The model's
        ``no_weight_decay`` / ``no_weight_decay_keywords`` hooks are NOT used
        on this path.
      * otherwise: two groups from ``set_weight_decay`` (decayed /
        non-decayed), honouring the model's optional ``no_weight_decay()`` and
        ``no_weight_decay_keywords()`` hooks.

    Args:
        config: config node; reads ``MODEL.TYPE``, ``TRAIN.BASE_LR``,
            ``TRAIN.WEIGHT_DECAY`` and ``TRAIN.OPTIMIZER.NAME / MOMENTUM /
            EPS / BETAS / LAYER_DECAY``.
        model: module whose parameters are to be optimized.

    Returns:
        ``torch.optim.SGD`` (Nesterov momentum) or ``torch.optim.AdamW``
        depending on ``config.TRAIN.OPTIMIZER.NAME`` (case-insensitive), or
        ``None`` for any other name.
    """
    skip = {}
    skip_keywords = {}
    if hasattr(model, 'no_weight_decay'):
        skip = model.no_weight_decay()
    if hasattr(model, 'no_weight_decay_keywords'):
        skip_keywords = model.no_weight_decay_keywords()

    # Independent decision: must always assign ``parameters`` regardless of
    # which optional hooks the model defines.
    if config.MODEL.TYPE.startswith("revcol"):
        parameters = param_groups_lrd(
            model,
            weight_decay=config.TRAIN.WEIGHT_DECAY,
            no_weight_decay_list=[],
            layer_decay=config.TRAIN.OPTIMIZER.LAYER_DECAY,
        )
    else:
        parameters = set_weight_decay(model, skip, skip_keywords)

    opt_lower = config.TRAIN.OPTIMIZER.NAME.lower()
    optimizer = None
    if opt_lower == 'sgd':
        optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,
                              lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
    elif opt_lower == 'adamw':
        # The configured weight decay is passed explicitly (as for SGD): the
        # decayed group built by set_weight_decay has no 'weight_decay' key
        # and would otherwise inherit AdamW's own default instead of the
        # configured value. Groups that set 'weight_decay' themselves (all
        # groups on the revcol path, the no-decay group) are unaffected.
        optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
                                lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)

    return optimizer
|
|
|
|
def set_weight_decay(model, skip_list=(), skip_keywords=()):
    """Split the model's trainable parameters into decayed / non-decayed groups.

    A parameter lands in the non-decayed group (``weight_decay=0.``) when it
    is 1-D, its name ends with ``.bias``, its name is in ``skip_list``, or its
    name contains one of ``skip_keywords``. Frozen parameters
    (``requires_grad`` false) and the ``linear_eval`` weight/bias are left out
    of both groups.

    Returns:
        ``[{'params': decayed}, {'params': not_decayed, 'weight_decay': 0.}]``
    """
    excluded_names = ("linear_eval.weight", "linear_eval.bias")
    decayed, not_decayed = [], []

    for pname, tensor in model.named_parameters():
        if not tensor.requires_grad:
            continue
        if pname in excluded_names:
            continue

        exempt = (
            len(tensor.shape) == 1
            or pname.endswith(".bias")
            or pname in skip_list
            or check_keywords_in_name(pname, skip_keywords)
        )
        target = not_decayed if exempt else decayed
        target.append(tensor)

    decay_group = {'params': decayed}
    no_decay_group = {'params': not_decayed, 'weight_decay': 0.}
    return [decay_group, no_decay_group]
|
|
|
|
def check_keywords_in_name(name, keywords=()):
    """Return True if any entry of ``keywords`` occurs as a substring of ``name``.

    Args:
        name: parameter name to test.
        keywords: iterable of substrings (list, tuple, set, or the keys of a
            dict). ``None`` is treated as "no keywords".

    Returns:
        ``True`` when at least one keyword is contained in ``name``,
        otherwise ``False``.
    """
    if keywords is None:
        return False
    return any(keyword in name for keyword in keywords)
|
|
def cal_model_depth(columns, layers):
    """Return the (block, column) depth table used for layer-wise lr decay.

    The table has ``sum(layers)`` rows (one per block, counted across all
    levels) and ``columns`` columns (callers pass the number of sub-networks).
    It is defined by the recurrence
    ``dp[i][j] = min(dp[i][j-1], dp[i-1][j]) + 1`` with ``dp[i][0] = i`` and
    ``dp[0][j] = j``, whose closed form is simply ``dp[i][j] = i + j``; the
    table is built directly from that closed form instead of a Python double
    loop.

    Args:
        columns: number of columns.
        layers: per-level block counts; the row count is their sum.

    Returns:
        Integer ``numpy.ndarray`` of shape ``(sum(layers), columns)``. Empty
        when either dimension is zero.

    Raises:
        ValueError: if ``columns`` or ``sum(layers)`` is negative.
    """
    depth = sum(layers)
    if depth < 0 or columns < 0:
        raise ValueError("negative dimensions are not allowed")
    return np.add.outer(np.arange(depth), np.arange(columns)).astype(int)
|
|
|
|
def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=(), layer_decay=.75):
    """
    Parameter groups for layer-wise lr decay
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58

    Each trainable parameter gets a layer id from ``get_layer_id``, computed
    against ``cal_model_depth(model.num_subnet, model.layers) + 1``. It is then
    grouped by (layer id, decay / no-decay). A group with layer id ``k`` gets
    ``lr_scale = layer_decay ** (num_layers - k)``, where ``num_layers`` is the
    largest shifted depth plus one. The highest id therefore has scale 1, and
    lower ids get geometrically smaller scales. ``lr_scale`` is presumably
    applied by the LR scheduler; that code is not shown here, so confirm it
    there.

    Args:
        model: module exposing ``num_subnet``, ``layers`` and
            ``named_parameters()``.
        weight_decay: decay applied to parameters that are not exempt.
        no_weight_decay_list: parameter names exempt from weight decay
            (1-D parameters are always exempt).
        layer_decay: per-layer multiplicative lr decay factor.

    Returns:
        List of param-group dicts with keys ``lr_scale``, ``weight_decay`` and
        ``params``, in first-seen order.
    """
    # Set for O(1) membership; a tuple default avoids a shared mutable default.
    no_decay_names = set(no_weight_decay_list)

    dp = cal_model_depth(model.num_subnet, model.layers) + 1
    num_layers = dp[-1][-1] + 1
    # Index k -> scale for layer id k; ids range over 0..num_layers inclusive.
    layer_scales = [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]

    param_groups = {}
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue

        if p.ndim == 1 or n in no_decay_names:
            g_decay, this_decay = "no_decay", 0.
        else:
            g_decay, this_decay = "decay", weight_decay

        layer_id = get_layer_id(n, dp, model.layers)
        group_name = "layer_%d_%s" % (layer_id, g_decay)

        if group_name not in param_groups:
            param_groups[group_name] = {
                "lr_scale": layer_scales[layer_id],
                "weight_decay": this_decay,
                "params": [],
            }
        param_groups[group_name]["params"].append(p)

    return list(param_groups.values())
|
|
def get_layer_id(n, dp, layers):
    """Map a parameter name to its layer id for layer-wise lr decay.

    Name patterns handled (as parsed below):
      * ``subnet<S>.alpha...``               -> ``dp[0][S]``
      * ``subnet<S>.level<L>.blocks.<B>...``  -> ``dp[sum(layers[:L]) + min(B, layers[L] - 1)][S]``
      * ``subnet<S>.level<L>.fusion...``      -> ``dp[sum(layers[:L])][S]``
      * ``stem...``                          -> ``0``
      * any other name                       -> ``dp[-1][-1] + 1``

    ``<L>`` is read from the last character of the second dotted component,
    so only single-digit level indices are supported.

    Args:
        n: parameter name as yielded by ``named_parameters()``.
        dp: 2-D depth table indexed ``dp[block][subnet]``.
        layers: per-level block counts.

    Returns:
        The layer id (an element of ``dp``, ``0``, or ``dp[-1][-1] + 1``).

    Raises:
        ValueError: for a ``subnet<S>.level<L>.<X>`` name whose ``<X>`` is
            neither ``blocks...`` nor ``fusion...``.
    """
    prefix = "subnet"
    if n.startswith(prefix):
        name_part = n.split('.')
        subnet = int(name_part[0][len(prefix):])
        if name_part[1].startswith("alpha"):
            return dp[0][subnet]

        level = int(name_part[1][-1])
        first_block = sum(layers[:level])
        if name_part[2].startswith("blocks"):
            # Block indices past this level's count are clamped onto its last block.
            sub = min(int(name_part[3]), layers[level] - 1)
            return dp[first_block + sub][subnet]
        if name_part[2].startswith("fusion"):
            return dp[first_block][subnet]
        raise ValueError("cannot assign a layer id to parameter %r" % (n,))

    if n.startswith("stem"):
        return 0
    return dp[-1][-1] + 1