import torch
from torch import nn
from transformers import HubertConfig, HubertModel
import logging

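# Silence noisy loggers that fairseq and torch.distributed emit at import time.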
logging.getLogger("fairseq").setLevel(logging.WARNING)
logging.getLogger("torch.distributed.nn.jit.instantiator").setLevel(logging.WARNING)

from fairseq import checkpoint_utils

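# Load the original fairseq checkpoint and take the first model of the ensemble.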
models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
    ["checkpoint_best_legacy_500.pt"], suffix=""
)
model = models[0]
model.eval()


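# transformers' HubertModel has no final_proj layer, so subclass it and add the
# layer back so those weights from the fairseq checkpoint can be converted too.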
class HubertModelWithFinalProj(HubertModel):
    def __init__(self, config):
        super().__init__(config)

        self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)


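# Freshly initialized model with the default HubertConfig (base architecture:
# 12 encoder layers, hidden size 768); the weights are filled in below.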
hubert = HubertModelWithFinalProj(HubertConfig())
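# Top-level weights: transformers parameter name -> fairseq parameter name.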
mapping = {
    "masked_spec_embed": "mask_emb",
    "encoder.layer_norm.bias": "encoder.layer_norm.bias",
    "encoder.layer_norm.weight": "encoder.layer_norm.weight",
    "encoder.pos_conv_embed.conv.bias": "encoder.pos_conv.0.bias",
    "encoder.pos_conv_embed.conv.weight_g": "encoder.pos_conv.0.weight_g",
    "encoder.pos_conv_embed.conv.weight_v": "encoder.pos_conv.0.weight_v",
    "feature_projection.layer_norm.bias": "layer_norm.bias",
    "feature_projection.layer_norm.weight": "layer_norm.weight",
    "feature_projection.projection.bias": "post_extract_proj.bias",
    "feature_projection.projection.weight": "post_extract_proj.weight",
    "final_proj.bias": "final_proj.bias",
    "final_proj.weight": "final_proj.weight",
}
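# Per-layer weights for the 12 transformer encoder layers.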
for layer in range(12):
    for j in ["q", "k", "v"]:
        mapping[
            f"encoder.layers.{layer}.attention.{j}_proj.weight"
        ] = f"encoder.layers.{layer}.self_attn.{j}_proj.weight"
        mapping[
            f"encoder.layers.{layer}.attention.{j}_proj.bias"
        ] = f"encoder.layers.{layer}.self_attn.{j}_proj.bias"

    mapping[
        f"encoder.layers.{layer}.final_layer_norm.bias"
    ] = f"encoder.layers.{layer}.final_layer_norm.bias"
    mapping[
        f"encoder.layers.{layer}.final_layer_norm.weight"
    ] = f"encoder.layers.{layer}.final_layer_norm.weight"

    mapping[
        f"encoder.layers.{layer}.layer_norm.bias"
    ] = f"encoder.layers.{layer}.self_attn_layer_norm.bias"
    mapping[
        f"encoder.layers.{layer}.layer_norm.weight"
    ] = f"encoder.layers.{layer}.self_attn_layer_norm.weight"

    mapping[
        f"encoder.layers.{layer}.attention.out_proj.bias"
    ] = f"encoder.layers.{layer}.self_attn.out_proj.bias"
    mapping[
        f"encoder.layers.{layer}.attention.out_proj.weight"
    ] = f"encoder.layers.{layer}.self_attn.out_proj.weight"

    mapping[
        f"encoder.layers.{layer}.feed_forward.intermediate_dense.bias"
    ] = f"encoder.layers.{layer}.fc1.bias"
    mapping[
        f"encoder.layers.{layer}.feed_forward.intermediate_dense.weight"
    ] = f"encoder.layers.{layer}.fc1.weight"

    mapping[
        f"encoder.layers.{layer}.feed_forward.output_dense.bias"
    ] = f"encoder.layers.{layer}.fc2.bias"
    mapping[
        f"encoder.layers.{layer}.feed_forward.output_dense.weight"
    ] = f"encoder.layers.{layer}.fc2.weight"

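# Feature extractor: 7 conv layers; in the default (group-norm) configuration
# only the first conv layer carries a norm, hence the early `continue` below.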
for layer in range(7):
    mapping[
        f"feature_extractor.conv_layers.{layer}.conv.weight"
    ] = f"feature_extractor.conv_layers.{layer}.0.weight"

    if layer != 0:
        continue

    mapping[
        f"feature_extractor.conv_layers.{layer}.layer_norm.weight"
    ] = f"feature_extractor.conv_layers.{layer}.2.weight"
    mapping[
        f"feature_extractor.conv_layers.{layer}.layer_norm.bias"
    ] = f"feature_extractor.conv_layers.{layer}.2.bias"

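# Print any parameters on either side that the mapping does not cover.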
hf_keys = set(hubert.state_dict().keys())
fair_keys = set(model.state_dict().keys())

hf_keys -= set(mapping.keys())
fair_keys -= set(mapping.values())

for i, j in zip(sorted(hf_keys), sorted(fair_keys)):
    print(i, j)

print(hf_keys, fair_keys)
print(len(hf_keys), len(fair_keys))

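# Copy each fairseq tensor into a new state dict under its transformers name.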
new_state_dict = {}
for k, v in mapping.items():
    new_state_dict[k] = model.state_dict()[v]
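# strict=False: anything the mapping does not fill keeps its random init;
# the returned object lists the missing/unexpected keys.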
x = hubert.load_state_dict(new_state_dict, strict=False)
print(x)
hubert.eval()

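# Sanity check: the converted model and the original fairseq model should
# produce (nearly) identical projected features for the same random input.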
with torch.no_grad():
    new_input = torch.randn(1, 16384)

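    # hidden_states[0] is the output before the first transformer layer, so
    # index 9 lines up with fairseq's output_layer=9 below.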
    result1 = hubert(new_input, output_hidden_states=True)["hidden_states"][9]
    result1 = hubert.final_proj(result1)

    result2 = model.extract_features(
        **{
            "source": new_input,
            "padding_mask": torch.zeros(1, 16384, dtype=torch.bool),
            "output_layer": 9,
        }
    )[0]
    result2 = model.final_proj(result2)

    assert torch.allclose(result1, result2, atol=1e-3)

print("Sanity check passed")
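# Write the converted config and weights to the current directory.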
hubert.save_pretrained(".")
print("Saved model")
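# To reuse the converted checkpoint, load it through the same subclass so that
# final_proj is restored as well (a sketch, not part of the conversion itself;
# `waveform` stands for a (batch, samples) float tensor you provide):
#   hubert = HubertModelWithFinalProj.from_pretrained(".")
#   feats = hubert(waveform, output_hidden_states=True)["hidden_states"][9]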