| import torch |
| from transformers import HubertModel, HubertConfig |
| import safetensors.torch as st |
| from safetensors import safe_open |
| import json |
|
|
class HubertModelWithFinalProj(HubertModel):
    """HuBERT encoder that keeps the fairseq-style ``final_proj`` head.

    Plain ``transformers.HubertModel`` drops the final linear projection
    present in fairseq/ContentVec checkpoints.  This subclass restores it so
    "v1" content features (intermediate hidden states projected down to
    ``classifier_proj_size``) can be reproduced, and adds safetensors
    round-tripping that embeds the model config in the file's metadata.
    """

    def __init__(self, config):
        super().__init__(config)
        # Projection from hidden_size -> classifier_proj_size; required to
        # reproduce v1 features (see extract_features).
        self.final_proj = torch.nn.Linear(config.hidden_size, config.classifier_proj_size)

    @staticmethod
    def load_safetensors(path: str, device="cpu"):
        """Load a model previously written by :meth:`save_safetensors`.

        The ``HubertConfig`` is recovered from the JSON stored under the
        ``"config"`` key of the safetensors metadata, so no separate config
        file is needed.

        Args:
            path: File path; must end with ``.safetensors``.
            device: Device the loaded model is moved to (default ``"cpu"``).

        Returns:
            A ``HubertModelWithFinalProj`` in eval mode on ``device``.

        Raises:
            ValueError: If ``path`` has the wrong suffix or the file carries
                no ``"config"`` metadata entry.
        """
        # Explicit raise instead of assert: asserts are stripped under -O.
        if not path.endswith(".safetensors"):
            raise ValueError(f"{path} must end with '.safetensors'")

        with safe_open(path, framework="pt", device="cpu") as f:
            metadata = f.metadata()
            if not metadata or "config" not in metadata:
                # f.metadata() is None for files saved without metadata;
                # fail with a clear message rather than a TypeError/KeyError.
                raise ValueError(f"{path} has no 'config' entry in its safetensors metadata")
            state_dict = {key: f.get_tensor(key) for key in f.keys()}

        model = HubertModelWithFinalProj(HubertConfig.from_dict(json.loads(metadata["config"])))
        model.load_state_dict(state_dict=state_dict)
        # Direct construction leaves nn.Module in train mode (dropout on),
        # which would make extract_features() nondeterministic even under
        # no_grad; mirror from_pretrained() and switch to eval.
        model.eval()
        return model.to(device)

    def save_safetensors(self, path: str):
        """Serialize the weights with the JSON-encoded config as metadata.

        Args:
            path: Destination file path; must end with ``.safetensors``.

        Raises:
            ValueError: If ``path`` has the wrong suffix.
        """
        if not path.endswith(".safetensors"):
            raise ValueError(f"{path} must end with '.safetensors'")
        # save_file handles the file I/O; equivalent to open(...,"wb") +
        # f.write(st.save(...)) but canonical and exception-safe.
        st.save_file(self.state_dict(), path, metadata=dict(config=json.dumps(self.config.to_dict())))

    def extract_features(self, source: torch.Tensor, version="v2", **kwargs):
        """Extract content features from a batch of waveforms.

        Args:
            source: Input waveform tensor — assumed (batch, samples) as
                expected by ``HubertModel``; TODO confirm at call sites.
            version: ``"v1"`` -> hidden layer 9 projected through
                ``final_proj`` (classifier_proj_size dims);
                ``"v2"`` -> hidden layer 12 (final layer), unprojected.
            **kwargs: Accepted for call-site compatibility; currently unused.

        Returns:
            Feature tensor of shape (batch, frames, dim).
        """
        with torch.no_grad():
            # hidden_states[0] is the feature-embedding output, so index i
            # corresponds to the output of transformer layer i.
            output_layer = 9 if version == "v1" else 12
            # config.torch_dtype may be None (e.g. not recorded in the saved
            # config); fall back to the model's own parameter dtype rather
            # than forwarding None into .to().
            dtype = self.config.torch_dtype or next(self.parameters()).dtype
            hidden = self(source.to(dtype), output_hidden_states=True)["hidden_states"][output_layer]
            return self.final_proj(hidden) if version == "v1" else hidden