| """Script to download the pre-trained tensorflow weights and convert them to pytorch weights.""" |
| import os |
| import argparse |
| import torch |
| import numpy as np |
| from tensorflow.python.training import py_checkpoint_reader |
|
|
| from repnet import utils |
| from repnet.model import RepNet |
|
|
|
|
| |
# Relevant paths: everything is resolved relative to the directory containing this script.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
# Base URL the TensorFlow checkpoint files are downloaded from (one request per file).
TF_CHECKPOINT_BASE_URL = 'https://storage.googleapis.com/repnet_ckpt'
# File names appended to TF_CHECKPOINT_BASE_URL and stored under the local checkpoint directory.
TF_CHECKPOINT_FILES = [
    'checkpoint',
    'ckpt-88.data-00000-of-00002',
    'ckpt-88.data-00001-of-00002',
    'ckpt-88.index',
]
# Output directory for both the downloaded TF files and the converted PyTorch weights.
OUT_CHECKPOINTS_DIR = os.path.join(PROJECT_ROOT, 'checkpoints')
|
|
| |
# Mapping of kernel ndim -> axis permutation applied to TensorFlow 'kernel' arrays
# when converting them to PyTorch tensors.
# NOTE(review): presumably these move TF's trailing (in, out) axes to PyTorch's leading
# (out, in) axes for dense / 2D conv / 3D conv kernels respectively - confirm.
WEIGHTS_PERMUTATION = {
    2: (1, 0),
    4: (3, 2, 0, 1),
    5: (4, 3, 0, 1, 2),
}
|
|
| |
# Mapping of TensorFlow attribute names -> PyTorch attribute names.
# Attributes that are not listed here keep their original name during conversion.
ATTR_MAPPING = {
    # Learnable parameters
    'kernel': 'weight',
    'bias': 'bias',
    'beta': 'bias',
    'gamma': 'weight',
    # Running statistics
    'moving_mean': 'running_mean',
    'moving_variance': 'running_var',
}
|
|
| |
# Mapping of (TF checkpoint key, TF layer name, PyTorch module/parameter path).
# The conversion loop only reads the first and third fields; the middle field is kept
# for reference (it appears to be the Keras layer name, None where not applicable).
WEIGHTS_MAPPING = [
    # Frame encoder ('base_model.*' -> 'encoder.*')
    ('base_model.layer-2', 'conv1_conv', 'encoder.stem.conv'),
    ('base_model.layer-5', 'conv2_block1_preact_bn', 'encoder.stages.0.blocks.0.norm1'),
    ('base_model.layer-7', 'conv2_block1_1_conv', 'encoder.stages.0.blocks.0.conv1'),
    ('base_model.layer-8', 'conv2_block1_1_bn', 'encoder.stages.0.blocks.0.norm2'),
    ('base_model.layer_with_weights-4', 'conv2_block1_2_conv', 'encoder.stages.0.blocks.0.conv2'),
    ('base_model.layer_with_weights-5', 'conv2_block1_2_bn', 'encoder.stages.0.blocks.0.norm3'),
    ('base_model.layer_with_weights-6', 'conv2_block1_0_conv', 'encoder.stages.0.blocks.0.downsample.conv'),
    ('base_model.layer_with_weights-7', 'conv2_block1_3_conv', 'encoder.stages.0.blocks.0.conv3'),
    ('base_model.layer_with_weights-8', 'conv2_block2_preact_bn', 'encoder.stages.0.blocks.1.norm1'),
    ('base_model.layer_with_weights-9', 'conv2_block2_1_conv', 'encoder.stages.0.blocks.1.conv1'),
    ('base_model.layer_with_weights-10', 'conv2_block2_1_bn', 'encoder.stages.0.blocks.1.norm2'),
    ('base_model.layer_with_weights-11', 'conv2_block2_2_conv', 'encoder.stages.0.blocks.1.conv2'),
    ('base_model.layer_with_weights-12', 'conv2_block2_2_bn', 'encoder.stages.0.blocks.1.norm3'),
    ('base_model.layer_with_weights-13', 'conv2_block2_3_conv', 'encoder.stages.0.blocks.1.conv3'),
    ('base_model.layer_with_weights-14', 'conv2_block3_preact_bn', 'encoder.stages.0.blocks.2.norm1'),
    ('base_model.layer_with_weights-15', 'conv2_block3_1_conv', 'encoder.stages.0.blocks.2.conv1'),
    ('base_model.layer_with_weights-16', 'conv2_block3_1_bn', 'encoder.stages.0.blocks.2.norm2'),
    ('base_model.layer_with_weights-17', 'conv2_block3_2_conv', 'encoder.stages.0.blocks.2.conv2'),
    ('base_model.layer_with_weights-18', 'conv2_block3_2_bn', 'encoder.stages.0.blocks.2.norm3'),
    ('base_model.layer_with_weights-19', 'conv2_block3_3_conv', 'encoder.stages.0.blocks.2.conv3'),
    ('base_model.layer_with_weights-20', 'conv3_block1_preact_bn', 'encoder.stages.1.blocks.0.norm1'),
    ('base_model.layer_with_weights-21', 'conv3_block1_1_conv', 'encoder.stages.1.blocks.0.conv1'),
    ('base_model.layer_with_weights-22', 'conv3_block1_1_bn', 'encoder.stages.1.blocks.0.norm2'),
    ('base_model.layer_with_weights-23', 'conv3_block1_2_conv', 'encoder.stages.1.blocks.0.conv2'),
    ('base_model.layer-47', 'conv3_block1_2_bn', 'encoder.stages.1.blocks.0.norm3'),
    ('base_model.layer_with_weights-25', 'conv3_block1_0_conv', 'encoder.stages.1.blocks.0.downsample.conv'),
    ('base_model.layer_with_weights-26', 'conv3_block1_3_conv', 'encoder.stages.1.blocks.0.conv3'),
    ('base_model.layer_with_weights-27', 'conv3_block2_preact_bn', 'encoder.stages.1.blocks.1.norm1'),
    ('base_model.layer_with_weights-28', 'conv3_block2_1_conv', 'encoder.stages.1.blocks.1.conv1'),
    ('base_model.layer_with_weights-29', 'conv3_block2_1_bn', 'encoder.stages.1.blocks.1.norm2'),
    ('base_model.layer_with_weights-30', 'conv3_block2_2_conv', 'encoder.stages.1.blocks.1.conv2'),
    ('base_model.layer_with_weights-31', 'conv3_block2_2_bn', 'encoder.stages.1.blocks.1.norm3'),
    ('base_model.layer-61', 'conv3_block2_3_conv', 'encoder.stages.1.blocks.1.conv3'),
    ('base_model.layer-63', 'conv3_block3_preact_bn', 'encoder.stages.1.blocks.2.norm1'),
    ('base_model.layer-65', 'conv3_block3_1_conv', 'encoder.stages.1.blocks.2.conv1'),
    ('base_model.layer-66', 'conv3_block3_1_bn', 'encoder.stages.1.blocks.2.norm2'),
    ('base_model.layer-69', 'conv3_block3_2_conv', 'encoder.stages.1.blocks.2.conv2'),
    ('base_model.layer-70', 'conv3_block3_2_bn', 'encoder.stages.1.blocks.2.norm3'),
    ('base_model.layer_with_weights-38', 'conv3_block3_3_conv', 'encoder.stages.1.blocks.2.conv3'),
    ('base_model.layer-74', 'conv3_block4_preact_bn', 'encoder.stages.1.blocks.3.norm1'),
    ('base_model.layer_with_weights-40', 'conv3_block4_1_conv', 'encoder.stages.1.blocks.3.conv1'),
    ('base_model.layer_with_weights-41', 'conv3_block4_1_bn', 'encoder.stages.1.blocks.3.norm2'),
    ('base_model.layer_with_weights-42', 'conv3_block4_2_conv', 'encoder.stages.1.blocks.3.conv2'),
    ('base_model.layer_with_weights-43', 'conv3_block4_2_bn', 'encoder.stages.1.blocks.3.norm3'),
    ('base_model.layer_with_weights-44', 'conv3_block4_3_conv', 'encoder.stages.1.blocks.3.conv3'),
    ('base_model.layer_with_weights-45', 'conv4_block1_preact_bn', 'encoder.stages.2.blocks.0.norm1'),
    ('base_model.layer_with_weights-46', 'conv4_block1_1_conv', 'encoder.stages.2.blocks.0.conv1'),
    ('base_model.layer_with_weights-47', 'conv4_block1_1_bn', 'encoder.stages.2.blocks.0.norm2'),
    ('base_model.layer-92', 'conv4_block1_2_conv', 'encoder.stages.2.blocks.0.conv2'),
    ('base_model.layer-93', 'conv4_block1_2_bn', 'encoder.stages.2.blocks.0.norm3'),
    ('base_model.layer-95', 'conv4_block1_0_conv', 'encoder.stages.2.blocks.0.downsample.conv'),
    ('base_model.layer-96', 'conv4_block1_3_conv', 'encoder.stages.2.blocks.0.conv3'),
    ('base_model.layer-98', 'conv4_block2_preact_bn', 'encoder.stages.2.blocks.1.norm1'),
    ('base_model.layer-100', 'conv4_block2_1_conv', 'encoder.stages.2.blocks.1.conv1'),
    ('base_model.layer-101', 'conv4_block2_1_bn', 'encoder.stages.2.blocks.1.norm2'),
    ('base_model.layer-104', 'conv4_block2_2_conv', 'encoder.stages.2.blocks.1.conv2'),
    ('base_model.layer-105', 'conv4_block2_2_bn', 'encoder.stages.2.blocks.1.norm3'),
    ('base_model.layer-107', 'conv4_block2_3_conv', 'encoder.stages.2.blocks.1.conv3'),
    ('base_model.layer-109', 'conv4_block3_preact_bn', 'encoder.stages.2.blocks.2.norm1'),
    ('base_model.layer-111', 'conv4_block3_1_conv', 'encoder.stages.2.blocks.2.conv1'),
    ('base_model.layer-112', 'conv4_block3_1_bn', 'encoder.stages.2.blocks.2.norm2'),
    ('base_model.layer-115', 'conv4_block3_2_conv', 'encoder.stages.2.blocks.2.conv2'),
    ('base_model.layer-116', 'conv4_block3_2_bn', 'encoder.stages.2.blocks.2.norm3'),
    ('base_model.layer-118', 'conv4_block3_3_conv', 'encoder.stages.2.blocks.2.conv3'),

    # Temporal convolution and 3x3 convolution ('temporal_conv.*', 'tsm_conv.*')
    ('temporal_conv_layers.0', 'conv3d', 'temporal_conv.0'),
    ('temporal_bn_layers.0', 'batch_normalization', 'temporal_conv.1'),
    ('conv_3x3_layer', 'conv2d', 'tsm_conv.0'),

    # Period length head ('period_length_head.*')
    ('input_projection', 'dense', 'period_length_head.0.input_projection'),
    ('pos_encoding', None, 'period_length_head.0.pos_encoding'),
    ('transformer_layers.0.ffn.layer-0', None, 'period_length_head.0.transformer_layer.linear1'),
    ('transformer_layers.0.ffn.layer-1', None, 'period_length_head.0.transformer_layer.linear2'),
    ('transformer_layers.0.layernorm1', None, 'period_length_head.0.transformer_layer.norm1'),
    ('transformer_layers.0.layernorm2', None, 'period_length_head.0.transformer_layer.norm2'),
    ('transformer_layers.0.mha.w_weight', None, 'period_length_head.0.transformer_layer.self_attn.in_proj_weight'),
    ('transformer_layers.0.mha.w_bias', None, 'period_length_head.0.transformer_layer.self_attn.in_proj_bias'),
    ('transformer_layers.0.mha.dense', None, 'period_length_head.0.transformer_layer.self_attn.out_proj'),
    ('fc_layers.0', 'dense_14', 'period_length_head.1'),
    ('fc_layers.1', 'dense_15', 'period_length_head.3'),
    ('fc_layers.2', 'dense_16', 'period_length_head.5'),

    # Periodicity head ('periodicity_head.*')
    ('input_projection2', 'dense_1', 'periodicity_head.0.input_projection'),
    ('pos_encoding2', None, 'periodicity_head.0.pos_encoding'),
    ('transformer_layers2.0.ffn.layer-0', None, 'periodicity_head.0.transformer_layer.linear1'),
    ('transformer_layers2.0.ffn.layer-1', None, 'periodicity_head.0.transformer_layer.linear2'),
    ('transformer_layers2.0.layernorm1', None, 'periodicity_head.0.transformer_layer.norm1'),
    ('transformer_layers2.0.layernorm2', None, 'periodicity_head.0.transformer_layer.norm2'),
    ('transformer_layers2.0.mha.w_weight', None, 'periodicity_head.0.transformer_layer.self_attn.in_proj_weight'),
    ('transformer_layers2.0.mha.w_bias', None, 'periodicity_head.0.transformer_layer.self_attn.in_proj_bias'),
    ('transformer_layers2.0.mha.dense', None, 'periodicity_head.0.transformer_layer.self_attn.out_proj'),
    ('within_period_fc_layers.0', 'dense_17', 'periodicity_head.1'),
    ('within_period_fc_layers.1', 'dense_18', 'periodicity_head.3'),
    ('within_period_fc_layers.2', 'dense_19', 'periodicity_head.5'),
]
|
|
| |
# Script arguments: no options are declared, so the parser only provides --help
# and rejects unexpected command-line arguments.
parser = argparse.ArgumentParser(
    description='Download and convert the pre-trained weights from tensorflow to pytorch.',
)
|
|
|
|
def _download_tf_checkpoint(checkpoint_dir):
    """Download every file of the TensorFlow checkpoint that is missing from `checkpoint_dir`.

    Args:
        checkpoint_dir: Local directory the checkpoint files are stored in (created if needed).
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    for file_name in TF_CHECKPOINT_FILES:
        dst = os.path.join(checkpoint_dir, file_name)
        # Files that already exist are not downloaded again.
        if not os.path.exists(dst):
            utils.download_file(f'{TF_CHECKPOINT_BASE_URL}/{file_name}', dst)


def _load_tf_state_dict(checkpoint_path):
    """Read the model variables of a TensorFlow checkpoint into a nested dict.

    Variable names are split on '/'; the first component and the '.ATTRIBUTES' /
    'VARIABLE_VALUE' markers are dropped, and the remaining components become nested keys.

    Args:
        checkpoint_path: Checkpoint prefix (e.g. '<dir>/ckpt-88').

    Returns:
        Nested dict whose leaves are the arrays returned by the checkpoint reader.
    """
    reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
    state = {}
    for var_name in sorted(reader.get_variable_to_shape_map()):
        # Filter first so that optimizer slots and non-model variables are never read.
        if not var_name.startswith('model') or '.OPTIMIZER_SLOT' in var_name:
            continue
        var_path = [p for p in var_name.split('/')[1:] if p not in ('.ATTRIBUTES', 'VARIABLE_VALUE')]
        node = state
        for key in var_path[:-1]:
            node = node.setdefault(key, {})
        node[var_path[-1]] = reader.get_tensor(var_name)
    return state


def _fuse_attention_projections(tf_state_dict):
    """Replace the separate wq/wk/wv projections with stacked `w_weight` / `w_bias` entries.

    The transposed kernels and the biases are concatenated along axis 0 in q, k, v order,
    which are the entries WEIGHTS_MAPPING maps to `self_attn.in_proj_weight` / `in_proj_bias`.
    The dict is modified in place.
    """
    for layers_key in ('transformer_layers', 'transformer_layers2'):
        mha = tf_state_dict[layers_key]['0']['mha']
        mha['w_weight'] = np.concatenate([mha[name]['kernel'].T for name in ('wq', 'wk', 'wv')], axis=0)
        mha['w_bias'] = np.concatenate([mha[name]['bias'].T for name in ('wq', 'wk', 'wv')], axis=0)
        del mha['wk'], mha['wq'], mha['wv']


def _convert_to_pytorch(tf_state_dict):
    """Build the flat PyTorch state dict from the flattened TensorFlow state dict.

    Args:
        tf_state_dict: Dict mapping each TF key of WEIGHTS_MAPPING to a dict of
            attribute name -> numpy array.

    Returns:
        The result of `utils.flatten_dict(..., skip_none=True)` on the converted tensors.

    Raises:
        ValueError: If WEIGHTS_MAPPING lists the same PyTorch target more than once.
    """
    pt_state_dict = {}
    for k_tf, _, k_pt in WEIGHTS_MAPPING:
        # Explicit check instead of `assert`, which would be stripped when running with -O.
        if k_pt in pt_state_dict:
            raise ValueError(f'Duplicate PyTorch key in WEIGHTS_MAPPING: {k_pt}')
        converted = {}
        for attr, value in tf_state_dict[k_tf].items():
            tensor = torch.from_numpy(value)
            if attr == 'kernel':
                # Kernels need their axes reordered; the permutation depends on the rank.
                tensor = tensor.permute(WEIGHTS_PERMUTATION[tensor.ndim])
            converted[ATTR_MAPPING.get(attr, attr)] = tensor
        pt_state_dict[k_pt] = converted
    return utils.flatten_dict(pt_state_dict, skip_none=True)


def main():
    """Download the TensorFlow checkpoint, convert it to PyTorch weights and verify the result."""
    # No options are declared, but parsing still handles --help and rejects unknown arguments.
    parser.parse_args()

    # Download tensorflow checkpoints
    print('Downloading checkpoints...')
    tf_checkpoint_dir = os.path.join(OUT_CHECKPOINTS_DIR, 'tf_checkpoint')
    _download_tf_checkpoint(tf_checkpoint_dir)

    # Load tensorflow weights into a dictionary
    print('Loading tensorflow checkpoint...')
    tf_state_dict = _load_tf_state_dict(os.path.join(tf_checkpoint_dir, 'ckpt-88'))

    # Merge the attention projections, then flatten the nested dict
    _fuse_attention_projections(tf_state_dict)
    tf_state_dict = utils.flatten_dict(tf_state_dict, keep_last=True)
    # Entries whose value is a bare array are wrapped so that every entry is a dict of
    # attribute -> array; the None attribute is presumably dropped again by
    # `flatten_dict(..., skip_none=True)` - verify against utils.
    tf_state_dict = {k: (v if isinstance(v, dict) else {None: v}) for k, v in tf_state_dict.items()}

    # Convert to PyTorch format and save
    print('Converting to PyTorch format...')
    pt_checkpoint_path = os.path.join(OUT_CHECKPOINTS_DIR, 'pytorch_weights.pth')
    torch.save(_convert_to_pytorch(tf_state_dict), pt_checkpoint_path)

    # Reload the saved file to make sure it matches the model definition
    print('Check that the weights can be loaded into the model...')
    model = RepNet()
    model.load_state_dict(torch.load(pt_checkpoint_path))

    print(f'Done. PyTorch weights saved to {pt_checkpoint_path}.')


if __name__ == '__main__':
    main()
|
|