| | import io |
| | import os |
| | import pickle |
| | from typing import Optional, Dict, Callable |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import torchaudio |
| |
|
| | from transformers import PreTrainedModel |
| |
|
| | from .configuration_speech_encoder import SpeechEncoderConfig |
| |
|
| |
|
| | def wrap_bos_eos( |
| | units: torch.Tensor, |
| | durations: torch.Tensor, |
| | f0: torch.Tensor | None, |
| | dense_features: torch.Tensor, |
| | bos: torch.Tensor, |
| | eos: torch.Tensor, |
| | ): |
| | |
| | one = durations.new_ones(1) |
| | units = torch.cat([bos.to(units.device), units, eos.to(units.device)], dim=0) |
| | durations = torch.cat([one, durations, one], dim=0) |
| | if f0 is not None: |
| | |
| | f0 = torch.cat([f0[:1], f0, f0[-1:]], dim=0) |
| | return units, durations, f0, dense_features |
| |
|
| |
|
| | |
| | |
| | |
| |
|
class _FairseqHubertDense(nn.Module):
    """Dense feature extractor backed by a fairseq HuBERT checkpoint.

    Loads a fairseq ``.pt`` checkpoint and exposes, via ``forward``, the
    hidden states produced by ``extract_features`` at a chosen transformer
    layer.  The model is frozen (eval mode, gradients disabled).
    """

    def __init__(self, ckpt_path: str, layer: int, expected_sr: int = 16000, hop: int = 320):
        super().__init__()
        # Import lazily so the module can be used without fairseq when the
        # transformers backend is selected instead.
        try:
            from fairseq import checkpoint_utils
        except Exception as e:
            raise ImportError(
                "fairseq is required to load a .pt HuBERT checkpoint. "
                "Please `pip install fairseq` in your runtime."
            ) from e

        ensemble, _, _ = checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
        self.model = ensemble[0]
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad_(False)

        self.output_layer = int(layer)          # transformer layer to tap
        self.expected_sample_rate = int(expected_sr)
        self.code_hop_size = int(hop)           # samples per output frame

    @torch.no_grad()
    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return (T, D) features for a mono waveform (channels are averaged)."""
        if waveform.ndim > 1:
            waveform = waveform.mean(0)
        batched = waveform.unsqueeze(0)
        feats, _ = self.model.extract_features(batched, output_layer=self.output_layer)
        return feats[0]
| |
|
| |
|
class _TransformersHubertDense(nn.Module):
    """Dense feature extractor backed by a transformers HuBERT checkpoint.

    Wraps a ``facebook/hubert-*`` model (frozen, eval mode) and returns the
    hidden states of a chosen transformer layer from ``output_hidden_states``.
    """

    def __init__(self, hf_name: str, layer: int, expected_sr: int = 16000, hop: int = 320):
        super().__init__()
        from transformers import AutoModel
        self.backbone = AutoModel.from_pretrained(hf_name)
        self.backbone.eval()
        for p in self.backbone.parameters():
            p.requires_grad_(False)
        # hidden_states[0] is the pre-transformer feature projection, so
        # index `layer` selects the output of the layer-th transformer block.
        self.layer = int(layer)
        self.expected_sample_rate = int(expected_sr)
        self.code_hop_size = int(hop)

    @torch.no_grad()
    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return (T, D) features for a mono waveform (channels are averaged)."""
        if waveform.ndim > 1:
            waveform = waveform.mean(0)
        # BUGFIX: HubertModel.forward does not accept an `inputs_embeds`
        # keyword; passing it (even as None) raised TypeError. Only the
        # supported kwargs are passed now.
        out = self.backbone(
            input_values=waveform.unsqueeze(0),
            output_hidden_states=True,
        )
        hidden = out.hidden_states[self.layer]
        return hidden[0]
| |
|
| |
|
| | |
| | |
| | |
| |
|
class KMeansQuantizer(nn.Module):
    """
    Simple KMeans quantizer: nearest center assignment per frame.
    Loads centers from:
      * .pt (Tensor or dict with keys: cluster_centers, cluster_centers_, centroids, centers)
      * .npy
      * pickle/joblib of a scikit KMeans with .cluster_centers_
    """

    # Recognized dict keys for cluster centers, in lookup order.
    _CENTER_KEYS = ("cluster_centers", "cluster_centers_", "centroids", "centers")

    def __init__(self, centers: torch.Tensor):
        super().__init__()
        # Explicit check (not `assert`) so validation survives `python -O`.
        if centers.ndim != 2:
            raise ValueError("centers must be (K, D)")
        self.register_buffer("centers", centers.float())

    @property
    def vocab_size(self) -> int:
        """Number of clusters K."""
        return int(self.centers.size(0))

    @staticmethod
    def _to_tensor(x) -> torch.Tensor:
        """Coerce an array-like / tensor to a torch.Tensor without copying tensors."""
        if torch.is_tensor(x):
            return x
        return torch.from_numpy(np.asarray(x))

    @classmethod
    def from_file(cls, path: str, key: str = "") -> "KMeansQuantizer":
        """Load cluster centers from `path`.

        `key`, when non-empty, is tried first as the dict key holding the
        centers in a .pt/.pth checkpoint. Raises FileNotFoundError when the
        file is missing and ValueError when no centers can be extracted.
        """
        path = os.fspath(path)
        if not os.path.exists(path):
            raise FileNotFoundError(f"KMeans file not found: {path}")

        centers = None

        if path.endswith((".pt", ".pth")):
            obj = torch.load(path, map_location="cpu")
            if torch.is_tensor(obj):
                centers = obj
            elif isinstance(obj, dict):
                for k in (key, *cls._CENTER_KEYS):
                    if k and k in obj:
                        centers = cls._to_tensor(obj[k])
                        break
                if centers is None:
                    # Fallback: scan one level of nesting for a known key.
                    for v in obj.values():
                        if isinstance(v, dict):
                            for k in cls._CENTER_KEYS:
                                if k in v:
                                    centers = cls._to_tensor(v[k])
                                    break
                        if centers is not None:
                            break

        if centers is None and path.endswith(".npy"):
            centers = torch.from_numpy(np.load(path))

        if centers is None:
            # Last resort: a pickled/joblib'd sklearn KMeans estimator.
            # SECURITY NOTE: unpickling executes arbitrary code — only load
            # files from trusted sources.
            try:
                import joblib
                obj = joblib.load(path)
            except Exception:
                with open(path, "rb") as f:
                    obj = pickle.load(f)
            if hasattr(obj, "cluster_centers_"):
                centers = torch.from_numpy(np.asarray(obj.cluster_centers_))

        if centers is None:
            raise ValueError(
                f"Could not load KMeans centers from {path}. "
                "Supported: .pt (tensor/dict), .npy, pickled sklearn KMeans."
            )

        return cls(centers.float())

    @torch.no_grad()
    def forward(self, dense_features: torch.Tensor) -> torch.Tensor:
        """
        dense_features: (T, D) or (B, T, D) -> returns (T,) or (B,T,) int64
        ids of the nearest (Euclidean) cluster center per frame.
        """
        x = dense_features
        if x.ndim == 2:
            dist = torch.cdist(x.to(self.centers.dtype), self.centers)
            return torch.argmin(dist, dim=-1).to(torch.long)
        if x.ndim == 3:
            B, T, D = x.shape
            flat = x.reshape(B * T, D)
            dist = torch.cdist(flat.to(self.centers.dtype), self.centers)
            return torch.argmin(dist, dim=-1).to(torch.long).view(B, T)
        raise ValueError("dense_features must be (T,D) or (B,T,D)")
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | F0_FRAME_SPACE = 0.01 |
| |
|
| |
|
class SpeechEncoder(PreTrainedModel):
    """
    Hugging Face–ready port of the Textless 'SpeechEncoder'.

    * Has the same public methods as the original (by_name, maybe_resample, forward, properties).
    * Loads your uploaded HuBERT checkpoint and KMeans centers from the repo.
    * `need_f0` is supported as a flag, but F0 extraction is not implemented in this minimal port.
    """
    config_class = SpeechEncoderConfig

    def __init__(
        self,
        dense_model: nn.Module,
        quantizer_model: nn.Module,
        deduplicate: bool,
        add_bos_eos: bool = False,
        need_f0: bool = False,
        f0_normalizer: Optional[Callable] = None,
        f0_quantizer: Optional[Callable] = None,
        config: Optional[SpeechEncoderConfig] = None,
    ):
        """Build an encoder from an already-constructed dense model and quantizer.

        Args:
            dense_model: module mapping a waveform to (T, D) frame features.
            quantizer_model: module mapping (T, D) features to frame ids;
                must expose a ``vocab_size`` attribute/property.
            deduplicate: collapse consecutive identical units into one unit
                with a frame-count duration.
            add_bos_eos: wrap the unit sequence with BOS/EOS sentinels.
            need_f0: accepted for API compatibility; forward() raises
                NotImplementedError when it is True.
            f0_normalizer / f0_quantizer: stored but unused in this port.
            config: optional SpeechEncoderConfig; a default one is created
                when omitted.
        """
        super().__init__(config if config is not None else SpeechEncoderConfig())
        self.dense_model = dense_model
        self.quantizer_model = quantizer_model

        self.deduplicate = bool(deduplicate)
        self.add_bos_eos = bool(add_bos_eos)
        self.need_f0 = bool(need_f0)
        self.f0_normalizer = f0_normalizer
        self.f0_quantizer = f0_quantizer

        self.unit_vocab_size = int(self.quantizer_model.vocab_size)

        # BOS/EOS ids come from the config when set, otherwise they default
        # to the two ids immediately past the unit vocabulary.
        bos_id = self.config.bos_id if self.config and self.config.bos_id is not None else self.unit_vocab_size
        eos_id = self.config.eos_id if self.config and self.config.eos_id is not None else self.unit_vocab_size + 1
        self.register_buffer("bos", torch.tensor([bos_id], dtype=torch.long))
        self.register_buffer("eos", torch.tensor([eos_id], dtype=torch.long))

        # Dummy buffer whose device tracks the module's device (see `device`).
        self.register_buffer("_float_tensor", torch.tensor([0.0], dtype=torch.float))

        # Optional feature normalization mode read from the config:
        # "unit" (L2-normalize) or "layernorm" (per-frame standardize); any
        # other value (incl. None) disables normalization. See forward().
        self._feature_norm = getattr(self.config, "feature_norm", None)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Loads config, constructs dense+quantizer from files inside the repo,
        and returns a ready-to-use SpeechEncoder (no weights to load into state_dict).
        """
        config = SpeechEncoderConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # NOTE(review): this resolves checkpoint paths with os.path.join, so it
        # assumes a *local* directory, not a hub repo id — confirm intended usage.
        repo_root = os.fspath(pretrained_model_name_or_path)
        hubert_path = os.path.join(repo_root, config.hubert_ckpt)
        quant_path = os.path.join(repo_root, config.quantizer_file)

        # Pick the dense backbone according to the configured backend.
        if config.hubert_backend == "fairseq":
            dense = _FairseqHubertDense(
                ckpt_path=hubert_path,
                layer=config.hubert_layer,
                expected_sr=config.expected_sample_rate,
                hop=config.code_hop_size,
            )
        elif config.hubert_backend == "transformers":
            dense = _TransformersHubertDense(
                hf_name=config.hubert_hf_name,
                layer=config.hubert_layer,
                expected_sr=config.expected_sample_rate,
                hop=config.code_hop_size,
            )
        else:
            raise ValueError("hubert_backend must be 'fairseq' or 'transformers'")

        quant = KMeansQuantizer.from_file(quant_path, key=config.quantizer_key)

        # F0 normalizer/quantizer are not reconstructable from files here.
        model = cls(
            dense_model=dense,
            quantizer_model=quant,
            deduplicate=config.deduplicate,
            add_bos_eos=config.add_bos_eos,
            need_f0=config.need_f0,
            f0_normalizer=None,
            f0_quantizer=None,
            config=config,
        )
        return model

    @classmethod
    def by_name(
        cls,
        dense_model_name: str,
        quantizer_model_name: str,
        vocab_size: int,
        deduplicate: bool,
        add_bos_eos: bool = False,
        need_f0: bool = False,
        f0_normalizer: Optional[Callable] = None,
        f0_quantizer: Optional[Callable] = None,
        # Backend-specific construction options (extensions over textlesslib):
        hubert_backend: str = "fairseq",
        hubert_ckpt: Optional[str] = None,
        hubert_hf_name: str = "facebook/hubert-base-ls960",
        hubert_layer: int = 9,
        quantizer_file: Optional[str] = None,
        quantizer_key: str = "",
        expected_sample_rate: int = 16000,
        code_hop_size: int = 320,
    ) -> "SpeechEncoder":
        """
        Mirrors textlesslib's SpeechEncoder.by_name. For HF usage prefer:
            AutoModel.from_pretrained(repo, trust_remote_code=True)
        """
        # NOTE: dense_model_name is accepted for signature compatibility but
        # the backbone is actually selected via `hubert_backend`.
        if hubert_backend == "fairseq":
            if not hubert_ckpt:
                raise ValueError("Provide hubert_ckpt (path to .pt) when hubert_backend='fairseq'.")
            dense = _FairseqHubertDense(hubert_ckpt, layer=hubert_layer,
                                        expected_sr=expected_sample_rate, hop=code_hop_size)
        elif hubert_backend == "transformers":
            dense = _TransformersHubertDense(hubert_hf_name, layer=hubert_layer,
                                             expected_sr=expected_sample_rate, hop=code_hop_size)
        else:
            raise ValueError("hubert_backend must be 'fairseq' or 'transformers'")

        if quantizer_model_name.lower() != "kmeans":
            raise ValueError("Only 'kmeans' quantizer is supported in this port.")
        if not quantizer_file:
            raise ValueError("Provide quantizer_file (path to centers).")
        quant = KMeansQuantizer.from_file(quantizer_file, key=quantizer_key)

        # Sanity check: the requested vocab size must match the loaded centers.
        if vocab_size is not None and int(vocab_size) != quant.vocab_size:
            raise ValueError(f"vocab_size={vocab_size} does not match centers K={quant.vocab_size}")

        cfg = SpeechEncoderConfig(
            hubert_backend=hubert_backend,
            hubert_ckpt=hubert_ckpt or "",
            hubert_hf_name=hubert_hf_name,
            hubert_layer=hubert_layer,
            expected_sample_rate=expected_sample_rate,
            code_hop_size=code_hop_size,
            quantizer_file=os.path.basename(quantizer_file),
            deduplicate=deduplicate,
            add_bos_eos=add_bos_eos,
            need_f0=need_f0,
        )
        return cls(dense, quant, deduplicate, add_bos_eos, need_f0, f0_normalizer, f0_quantizer, config=cfg)

    @property
    def device(self) -> torch.device:
        # The `_float_tensor` buffer moves with the module, so its device is
        # the module's device.
        return self._float_tensor.device

    @property
    def vocab_size(self) -> int:
        # Number of KMeans clusters (excludes the BOS/EOS sentinel ids).
        return self.quantizer_model.vocab_size

    @property
    def code_hop_size(self) -> int:
        # Samples per code frame; falls back to 320 if the dense model
        # does not expose it.
        return getattr(self.dense_model, "code_hop_size", 320)

    @property
    def expected_sample_rate(self) -> int:
        # Input sample rate the dense model expects; 16 kHz fallback.
        return getattr(self.dense_model, "expected_sample_rate", 16000)

    @property
    def f0_code_ratio(self) -> float:
        # (seconds per code frame) / (seconds per F0 frame): how many F0
        # frames correspond to one code frame.
        return self.code_hop_size / self.expected_sample_rate / F0_FRAME_SPACE

    def maybe_resample(self, waveform: torch.Tensor, input_sample_rate: int) -> torch.Tensor:
        """Resample `waveform` to the expected sample rate; no-op if it matches."""
        if int(input_sample_rate) == int(self.expected_sample_rate):
            return waveform
        return torchaudio.functional.resample(
            waveform, int(input_sample_rate), int(self.expected_sample_rate)
        )

    @torch.no_grad()
    def forward(self, waveform: torch.Tensor, speaker: Optional[str] = None) -> Dict[str, torch.Tensor]:
        """
        Returns:
            {
                "units": LongTensor [U],
                "durations": LongTensor [U],  (frame counts)
                "dense": FloatTensor [T, D],
                (optional) "f0": FloatTensor [U] or [T_f0] if implemented
            }
        """
        # NOTE: `speaker` is accepted for API compatibility but unused here.
        # The waveform is NOT resampled here — callers are expected to use
        # maybe_resample() first.
        dense_features = self.dense_model(waveform)

        # Optional feature normalization selected by config.feature_norm.
        if self._feature_norm == "unit":
            eps = 1e-6
            dense_features = dense_features / (dense_features.norm(dim=-1, keepdim=True) + eps)
        elif self._feature_norm == "layernorm":
            mean = dense_features.mean(dim=-1, keepdim=True)
            std = dense_features.std(dim=-1, keepdim=True).clamp_min(1e-5)
            dense_features = (dense_features - mean) / std

        # Per-frame nearest-center cluster ids.
        ids_per_frame = self.quantizer_model(dense_features)

        # Run-length encode when deduplicating; otherwise one frame per unit.
        if self.deduplicate:
            units, durations = torch.unique_consecutive(ids_per_frame, return_counts=True)
        else:
            units = ids_per_frame
            durations = torch.ones_like(units, dtype=torch.long)

        # F0 is not computed in this port; the flag fails loudly.
        f0 = None
        if self.need_f0:
            raise NotImplementedError(
                "F0 extraction is not included in this minimal HF port. "
                "Set need_f0=False (as in the reference pipeline)."
            )

        # Optionally frame the sequence with BOS/EOS sentinel ids.
        if self.add_bos_eos:
            units, durations, f0, dense_features = wrap_bos_eos(
                units, durations, f0, dense_features, self.bos, self.eos
            )

        item = {
            "units": units.to(self.device),
            "durations": durations.to(self.device),
            "dense": dense_features.to(self.device),
        }
        if f0 is not None:
            item["f0"] = f0.to(self.device)
        return item
| |
|