| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import logging |
| import os |
| import random |
| import uuid as uuid_module |
| from collections import OrderedDict, defaultdict |
| from pathlib import Path |
| from typing import List, Optional, Sequence, Tuple, Union |
|
|
| import numpy as np |
| import onnxruntime |
| from hyperpyyaml import load_hyperpyyaml |
|
|
| import torch |
| import torchaudio |
| import torchaudio.compliance.kaldi as kaldi |
| from safetensors.torch import load_file |
| from torch import nn |
| from transformers import PreTrainedModel, WhisperFeatureExtractor |
|
|
| from .configuration_moss_speech_codec import MossSpeechCodecConfig |
| from .modeling_whisper import WhisperVQEncoder, WhisperVQConfig |
| from .utils import extract_speech_token |
|
|
# Module-level logger named after this module; the functions and classes in
# this file report progress and warnings through it.
logger = logging.getLogger(__name__)
|
|
def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Besides seeding, this exports ``PYTHONHASHSEED`` and
    ``TF_CUDNN_DETERMINISTIC`` into ``os.environ`` and, when CUDA is
    available, switches cuDNN to deterministic mode with benchmarking off.

    Args:
        seed: Seed value shared by every generator.

    Raises:
        TypeError: If ``seed`` is not an ``int``.
    """
    if not isinstance(seed, int):
        raise TypeError("Seed must be an integer.")

    logger.info("Setting random seed to %s", seed)
    random.seed(seed)
    np.random.seed(seed)

    # Seed the default torch generator unconditionally. The previous version
    # only did this when CUDA was *not* available, so on a GPU machine every
    # CPU-side torch random op stayed unseeded.
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    os.environ["PYTHONHASHSEED"] = str(seed)
    os.environ["TF_CUDNN_DETERMINISTIC"] = "1"
|
|
|
|
def fade_in_out(fade_in_mel, fade_out_mel, window):
    """Cross-fade the start of ``fade_in_mel`` with the end of ``fade_out_mel``.

    The first half of ``window`` weights the leading frames of
    ``fade_in_mel``; the second half weights the trailing frames of
    ``fade_out_mel``. The weighted sum overwrites the leading frames of
    ``fade_in_mel`` along its last axis.

    The blend is computed on the CPU and the result is moved back to the
    device ``fade_in_mel`` came from. ``Tensor.cpu()`` returns the same
    object for a tensor that is already on the CPU, so in that case the
    caller's ``fade_in_mel`` is modified in place.

    Args:
        fade_in_mel: Tensor whose leading frames are faded in.
        fade_out_mel: Tensor whose trailing frames are faded out.
        window: 1-D array or tensor of weights; half its length is the
            number of overlapping frames.

    Returns:
        ``fade_in_mel`` with its leading overlap blended, on its original
        device.
    """
    overlap = window.shape[0] // 2
    if overlap == 0:
        # Nothing to blend. Without this guard `[..., -0:]` would select the
        # whole of fade_out_mel rather than an empty tail and the shapes
        # would no longer line up.
        return fade_in_mel

    target_device = fade_in_mel.device
    fade_in_mel = fade_in_mel.cpu()
    fade_out_mel = fade_out_mel.cpu()
    weights = torch.as_tensor(window)

    rising = fade_in_mel[..., :overlap] * weights[:overlap]
    falling = fade_out_mel[..., -overlap:] * weights[overlap:]
    fade_in_mel[..., :overlap] = rising + falling
    return fade_in_mel.to(target_device)
|
|
|
|
# Module-level placeholders, both starting out empty. They are not referenced
# anywhere else in the visible part of this module; kept in case external
# code imports them.
tts_speech_prev = tts_mel_prev = None
|
|
|
|
class AudioDecoder(nn.Module):
    """Token-to-waveform decoder built from a flow model and a HiFT model.

    ``flow.inference`` turns codec tokens into features (called "mel" in this
    file) and ``hift.inference`` turns those features into audio. State for
    chunked decoding is kept per request in two dictionaries keyed by a
    caller-supplied ``uuid``:

    - ``mel_overlap_dict``: trailing frames held back from the previous
      chunk so the next chunk can be cross-faded with them.
    - ``hift_cache_dict``: trailing mel / source / speech slices of the
      previous vocoder call.

    Both entries are dropped once a request is decoded with
    ``finalize=True``.
    """

    def __init__(
        self,
        config_path: Union[str, os.PathLike],
        flow_ckpt_path: Union[str, os.PathLike],
        hift_ckpt_path: Union[str, os.PathLike],
        campplus_model: Union[str, os.PathLike],
        device: Union[str, torch.device] = "cuda",
    ) -> None:
        """Load the decoder configuration, checkpoints and speaker model.

        Args:
            config_path: HyperPyYAML file providing the ``flow``, ``hift``,
                ``sample_rate`` and ``feat_extractor`` entries.
            flow_ckpt_path: Checkpoint loaded (non-strictly) into ``flow``.
            hift_ckpt_path: Checkpoint loaded (strictly) into ``hift``.
            campplus_model: ONNX model opened with onnxruntime on the CPU
                execution provider.
            device: Device the flow and HiFT models are moved to.
        """
        super().__init__()
        self.device = torch.device(device) if isinstance(device, str) else device

        with open(config_path, "r", encoding="utf-8") as config_file:
            logger.info("Loading decoder configurations from %s", config_path)
            self.scratch_configs = load_hyperpyyaml(config_file)

        # NOTE(security): torch.load unpickles the checkpoint; only point
        # these paths at files from a trusted source.
        self.flow = self.scratch_configs["flow"]
        self.flow.load_state_dict(torch.load(flow_ckpt_path, map_location=self.device), strict=False)
        # This class only runs inference, so keep the flow model in eval mode
        # exactly like the HiFT model below (previously only HiFT was).
        self.flow = self.flow.eval()
        self.hift = self.scratch_configs["hift"]
        self.hift.load_state_dict(torch.load(hift_ckpt_path, map_location=self.device))
        self.hift = self.hift.eval()
        self.sample_rate = self.scratch_configs["sample_rate"]
        self.feat_extractor = self.scratch_configs["feat_extractor"]

        self.flow.to(self.device)
        self.hift.to(self.device)

        # Per-request streaming state; a missing request reads as None.
        self.mel_overlap_dict = defaultdict(lambda: None)
        self.hift_cache_dict = defaultdict(lambda: None)

        frame_rate = self.flow.input_frame_rate
        self.token_min_hop_len = 2 * frame_rate
        self.token_max_hop_len = 4 * frame_rate
        self.token_overlap_len = 3.5
        # The constants presumably encode a 24 kHz output rate and a
        # 480-sample hop -- TODO confirm against the flow/HiFT config.
        self.mel_overlap_len = int(self.token_overlap_len / frame_rate * 24000 / (480 * 2))
        self.mel_window = np.hamming(2 * self.mel_overlap_len)

        # One cached mel frame corresponds to 480 cached samples here.
        self.mel_cache_len = 1
        self.source_cache_len = int(self.mel_cache_len * 480)

        session_options = onnxruntime.SessionOptions()
        session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        session_options.intra_op_num_threads = 1
        self.campplus_session = onnxruntime.InferenceSession(
            str(campplus_model),
            sess_options=session_options,
            providers=["CPUExecutionProvider"],
        )
        self.speech_window = np.hamming(2 * self.source_cache_len)

    def _length_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return the size of axis 1 of ``tensor`` as an int32 tensor on ``self.device``."""
        return torch.tensor([tensor.shape[1]], dtype=torch.int32, device=self.device)

    def token2wav(
        self,
        token: torch.Tensor,
        uuid: str,
        prompt_token: Optional[torch.Tensor] = None,
        prompt_feat: Optional[torch.Tensor] = None,
        embedding: Optional[torch.Tensor] = None,
        finalize: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Decode one chunk of tokens into audio.

        Args:
            token: Tokens to decode; axis 1 is used as the sequence length.
            uuid: Key under which streaming state for this request is kept.
            prompt_token: Prompt tokens; defaults to an empty ``(1, 0)``
                int32 tensor.
            prompt_feat: Prompt features; defaults to an empty
                ``(1, 0, 80)`` tensor.
            embedding: Speaker embedding; defaults to zeros of shape
                ``(1, 192)``.
            finalize: ``True`` for the last (or only) chunk: nothing is held
                back and the state stored for ``uuid`` is discarded.

        Returns:
            ``(tts_speech, tts_mel)``. When ``finalize`` is ``False`` the
            last ``source_cache_len`` samples and the last
            ``mel_overlap_len`` frames are withheld and cached for the next
            chunk.
        """
        if prompt_token is None:
            prompt_token = torch.zeros(1, 0, dtype=torch.int32)
        if prompt_feat is None:
            prompt_feat = torch.zeros(1, 0, 80)
        if embedding is None:
            embedding = torch.zeros(1, 192)

        flow_output = self.flow.inference(
            token=token.to(self.device),
            token_len=self._length_tensor(token),
            prompt_token=prompt_token.to(self.device),
            prompt_token_len=self._length_tensor(prompt_token),
            prompt_feat=prompt_feat.to(self.device),
            prompt_feat_len=self._length_tensor(prompt_feat),
            embedding=embedding.to(self.device),
            streaming=False,
            finalize=finalize,
        )
        tts_mel = flow_output[0]

        # `.get` reads the state without inserting a None placeholder.
        previous_overlap = self.mel_overlap_dict.get(uuid)
        if previous_overlap is not None:
            tts_mel = fade_in_out(tts_mel, previous_overlap, self.mel_window)

        vocoder_cache = self.hift_cache_dict.get(uuid)
        if vocoder_cache is not None:
            hift_cache_source = vocoder_cache["source"]
            tts_mel = torch.cat([vocoder_cache["mel"], tts_mel], dim=2)
        else:
            hift_cache_source = torch.zeros(1, 1, 0)

        if finalize:
            tts_speech, _ = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
            self.hift_cache_dict.pop(uuid, None)
            self.mel_overlap_dict.pop(uuid, None)
            return tts_speech, tts_mel

        # Hold back the tail so the next chunk can be cross-faded with it.
        self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
        tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
        tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
        self.hift_cache_dict[uuid] = {
            "mel": tts_mel[:, :, -self.mel_cache_len:],
            "source": tts_source[:, :, -self.source_cache_len:],
            "speech": tts_speech[:, -self.source_cache_len:],
        }
        return tts_speech[:, :-self.source_cache_len], tts_mel

    def offline_inference(self, token: torch.Tensor) -> torch.Tensor:
        """Decode ``token`` in a single finalized pass and return the audio on the CPU."""
        request_id = str(uuid_module.uuid1())
        tts_speech, _ = self.token2wav(token, uuid=request_id, finalize=True)
        return tts_speech.cpu()

    def stream_inference(
        self,
        token: torch.Tensor,
        prompt_token: Optional[torch.Tensor] = None,
        prompt_feat: Optional[torch.Tensor] = None,
        embedding: Optional[torch.Tensor] = None,
        block_size: int = 8,
    ) -> torch.Tensor:
        """Decode ``token`` block by block and return the concatenated audio.

        Every block after the first is conditioned on the original prompt
        extended with all tokens and mels decoded so far.

        Args:
            token: Tokens to decode; axis 1 is split into blocks.
            prompt_token: Optional prompt tokens.
            prompt_feat: Optional prompt features, time on axis 1.
            embedding: Optional speaker embedding.
            block_size: Number of tokens decoded per block.

        Returns:
            The audio of all blocks joined on the last axis, on the CPU. An
            empty token sequence yields an empty ``(batch, 0)`` tensor.
        """
        token = token.to(self.device)
        if token.size(1) == 0:
            # No block would be decoded and torch.cat([]) would raise.
            return torch.zeros(token.size(0), 0)

        request_id = str(uuid_module.uuid1())

        if prompt_token is not None:
            base_prompt_tensor = prompt_token.to(self.device)
        else:
            base_prompt_tensor = torch.zeros(1, 0, dtype=torch.int32, device=self.device)
        if prompt_feat is not None:
            base_prompt_feat = prompt_feat.to(self.device)
        else:
            base_prompt_feat = torch.zeros(1, 0, 80, device=self.device)
        if embedding is not None:
            embedding = embedding.to(self.device)
        else:
            embedding = torch.zeros(1, 192, device=self.device)

        tts_speechs: List[torch.Tensor] = []
        tts_mels: List[torch.Tensor] = []
        total_tokens = token.size(-1)

        for start in range(0, token.size(1), block_size):
            block = token[:, start : start + block_size]

            block_prompt_tensor = base_prompt_tensor
            block_prompt_feat = base_prompt_feat
            if tts_mels:
                # Decoded mels carry time on the last axis while prompt
                # features carry it on axis 1, hence the transposes.
                block_prompt_feat = torch.cat(
                    [base_prompt_feat.transpose(1, 2), *tts_mels],
                    dim=-1,
                ).transpose(1, 2)
                block_prompt_tensor = torch.cat([base_prompt_tensor, token[:, :start]], dim=-1)

            tts_speech, tts_mel = self.token2wav(
                block,
                uuid=request_id,
                prompt_token=block_prompt_tensor,
                prompt_feat=block_prompt_feat,
                embedding=embedding,
                finalize=start + block_size >= total_tokens,
            )
            tts_speechs.append(tts_speech)
            tts_mels.append(tts_mel)

        return torch.cat(tts_speechs, dim=-1).cpu()

    def streaming_inference(
        self,
        token: torch.Tensor,
        prompt_token: Optional[torch.Tensor] = None,
        prompt_feat: Optional[torch.Tensor] = None,
        embedding: Optional[torch.Tensor] = None,
        uuid: Optional[str] = None,
        prev_mel: Optional[torch.Tensor] = None,
        prev_token: Optional[torch.Tensor] = None,
        is_finalize: bool = True,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Decode one chunk for a caller that drives the streaming loop itself.

        Args:
            token: Tokens of the current chunk.
            prompt_token: Prompt tokens, used when ``prev_token`` is ``None``.
            prompt_feat: Prompt features (time on axis 1), used when
                ``prev_mel`` is ``None``.
            embedding: Optional speaker embedding.
            uuid: Request key; a fresh one is generated when omitted.
            prev_mel: Accumulated mels as returned by an earlier call (time
                on the last axis). When given it replaces ``prompt_feat``.
            prev_token: Accumulated tokens as returned by an earlier call.
                When given it replaces ``prompt_token``.
            is_finalize: Forwarded to ``token2wav`` as ``finalize``.

        Returns:
            ``(audio_on_cpu, prev_mel, prev_token)`` where the last two are
            the inputs extended with this chunk, ready for the next call.
        """
        token = token.to(self.device)
        request_id = uuid or str(uuid_module.uuid1())

        if prompt_feat is not None:
            prompt_speech_feat = prompt_feat.to(self.device)
        else:
            prompt_speech_feat = torch.zeros(1, 0, 80, device=self.device)
        if prompt_token is not None:
            flow_prompt_speech_token = prompt_token.to(self.device)
        else:
            flow_prompt_speech_token = torch.zeros(1, 0, dtype=torch.int32, device=self.device)
        if embedding is not None:
            embedding_tensor = embedding.to(self.device)
        else:
            embedding_tensor = torch.zeros(1, 192, device=self.device)

        if prev_mel is not None:
            # token2wav returns mels with time on the last axis (it slices
            # and concatenates them on dim 2) whereas prompt features carry
            # time on axis 1, so transpose before using them as the prompt --
            # the same conversion stream_inference performs.
            prompt_speech_feat = prev_mel.transpose(1, 2)
        if prev_token is not None:
            flow_prompt_speech_token = prev_token

        tts_speech, tts_mel = self.token2wav(
            token,
            uuid=request_id,
            prompt_token=flow_prompt_speech_token,
            prompt_feat=prompt_speech_feat,
            embedding=embedding_tensor,
            finalize=is_finalize,
        )

        # Accumulate along time (the last axis); dim=1 would stack the
        # feature axis instead.
        prev_mel = tts_mel if prev_mel is None else torch.cat([prev_mel, tts_mel], dim=-1)
        prev_token = token if prev_token is None else torch.cat([prev_token, token], dim=-1)

        return tts_speech.cpu(), prev_mel, prev_token
|
|
|
|
class MossSpeechCodec(PreTrainedModel):
    """MossSpeech codec model (Whisper-VQ encoder + Flow/HiFT decoder).

    Notes
    - API is designed to be compatible with the existing
      `MossSpeechProcessor` usages, while adopting a Transformers-style layout
      similar to HF codec models (`xcodec`, `encodec`).
    - `encode` accepts raw audio tensors or file paths. It returns a Python
      list of codec token ids per input sample for backward-compatibility.
    - `decode` accepts a LongTensor of shape `(B, 1, T)` or `(B, T)`, or a
      nested list of token ids, and returns a dict with a list of waveforms
      under `"syn_wav_list"` (matching current processor expectations).
    """

    config_class = MossSpeechCodecConfig

    def __init__(
        self,
        encoder_weight_path: Union[str, os.PathLike],
        encoder_config_path: Union[str, os.PathLike],
        encoder_feature_extractor_path: Union[str, os.PathLike],
        flow_path: Union[str, os.PathLike],
    ) -> None:
        """Build the encoder and decoder from files on disk.

        Args:
            encoder_weight_path: safetensors file; only entries whose key
                starts with ``"encoder."`` are loaded, with the prefix removed.
            encoder_config_path: Source for ``WhisperVQConfig.from_pretrained``.
            encoder_feature_extractor_path: Source for
                ``WhisperFeatureExtractor.from_pretrained``.
            flow_path: Directory holding ``config.yaml``, ``flow.pt``,
                ``hift.pt`` and ``campplus.onnx``.
        """
        super().__init__(config=MossSpeechCodecConfig())

        # Encoder side.
        self.sample_rate = 16000
        encoder_config = WhisperVQConfig.from_pretrained(str(encoder_config_path))
        self.whisper_vqmodel = WhisperVQEncoder(encoder_config)

        prefix = "encoder."
        checkpoint = load_file(str(encoder_weight_path))
        encoder_state = OrderedDict(
            (key[len(prefix):], value) for key, value in checkpoint.items() if key.startswith(prefix)
        )
        self.whisper_vqmodel.load_state_dict(encoder_state, strict=False)

        self.feature_extractor = WhisperFeatureExtractor.from_pretrained(
            str(encoder_feature_extractor_path)
        )

        # Decoder side. AudioDecoder defaults to "cuda", which cannot work on
        # a CPU-only host; pick the device from what is actually available,
        # in line with the CPU fallback `decode` already applies.
        self.flow_path = str(flow_path)
        decoder_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.audio_decoder = AudioDecoder(
            config_path=os.path.join(self.flow_path, "config.yaml"),
            flow_ckpt_path=os.path.join(self.flow_path, "flow.pt"),
            hift_ckpt_path=os.path.join(self.flow_path, "hift.pt"),
            campplus_model=os.path.join(self.flow_path, "campplus.onnx"),
            device=decoder_device,
        ).eval()

    @torch.no_grad()
    def encode(
        self,
        inputs: Union[
            Sequence[Union[str, os.PathLike, Tuple[torch.Tensor, int], torch.Tensor]],
            torch.Tensor,
        ],
        *,
        sampling_rate: Optional[int] = None,
        batch_size: int = 128,
    ) -> List[List[int]]:
        """Encode audio into codec token ids.

        Accepts one of:
        - a list of file paths
        - a list of `(waveform, sr)` tuples
        - a list of 1D/2D waveforms (sr assumed 16k)
        - a batched tensor with shape `(B, C, T)` or `(B, T)`

        Args:
            inputs: Audio in one of the forms above.
            sampling_rate: Sampling rate attached to every sample of a
                batched tensor; falls back to ``self.sample_rate``. Not used
                for list inputs.
            batch_size: Forwarded to ``extract_speech_token``.

        Returns:
            One list of token ids per input sample.

        Raises:
            ValueError: If a tensor input is neither 2-D nor 3-D.
        """
        if isinstance(inputs, torch.Tensor):
            if inputs.dim() == 2:
                inputs = inputs.unsqueeze(1)
            if inputs.dim() != 3:
                raise ValueError("`inputs` must be (B, C, T) when passing a tensor.")
            sr = sampling_rate or self.sample_rate
            # Split the batch into per-sample (waveform, sr) pairs on the CPU.
            items = [(sample.squeeze(0).cpu(), sr) for sample in inputs]
        else:
            items = list(inputs)

        return extract_speech_token(
            self.whisper_vqmodel, self.feature_extractor, items, batch_size=batch_size
        )

    def _extract_speech_feat(self, speech: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the decoder's feature extractor on ``speech``.

        Returns:
            ``(features, length)`` where ``features`` has a leading batch
            axis of size 1 and ``length`` is an int32 tensor holding the size
            of axis 1 of ``features``.
        """
        extracted = self.audio_decoder.feat_extractor(speech)
        speech_feat = extracted.squeeze(dim=0).transpose(0, 1).unsqueeze(dim=0)
        speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32)
        return speech_feat, speech_feat_len

    def _extract_spk_embedding(self, speech_16k: torch.Tensor) -> torch.Tensor:
        """Compute a speaker embedding with the campplus ONNX session.

        The input is turned into mean-normalised 80-bin Kaldi fbank features
        (computed for a 16 kHz sample frequency) before being fed to the
        session.

        Returns:
            The flattened first session output as a tensor with a leading
            axis of size 1.
        """
        feat = kaldi.fbank(speech_16k, num_mel_bins=80, dither=0, sample_frequency=16000)
        feat = feat - feat.mean(dim=0, keepdim=True)
        session = self.audio_decoder.campplus_session
        input_name = session.get_inputs()[0].name
        outputs = session.run(None, {input_name: feat.unsqueeze(0).cpu().numpy()})
        embedding = outputs[0].flatten().tolist()
        return torch.tensor([embedding])

    @torch.no_grad()
    def decode(
        self,
        audio_codes: Union[Sequence[Sequence[int]], torch.LongTensor],
        *,
        prompt_speech: Optional[Union[str, os.PathLike]] = None,
        prompt_speech_sample_rate: Optional[int] = None,
        use_spk_embedding: bool = True,
        use_prompt_speech: bool = True,
        finalize: bool = True,
        device: torch.device = torch.device("cuda"),
    ) -> dict:
        """Decode codec token ids back to waveform(s).

        Args
        - audio_codes: `(B, 1, T)` or `(B, T)` tensor, or Python nested
          lists per sample.
        - prompt_speech: path to the enrollment audio used for conditioning.
          Always required, even when both `use_*` flags are off.
        - prompt_speech_sample_rate: currently unused; the rate is read from
          the audio file.
        - use_spk_embedding: pass the prompt's speaker embedding to the
          decoder.
        - use_prompt_speech: pass the prompt's tokens and features to the
          decoder.
        - finalize: forwarded to `AudioDecoder.token2wav`.
        - device: device for the tensors built here (a `torch.device` or a
          device string); falls back to CPU when CUDA is unavailable.
        Returns
        - {"syn_wav_list": List[Tensor(T)]}
        Raises
        - ValueError: for an unsupported `audio_codes` tensor shape, or when
          `prompt_speech` is missing or does not exist.
        """
        if isinstance(audio_codes, torch.Tensor):
            if audio_codes.dim() == 3 and audio_codes.size(1) == 1:
                rows = audio_codes[:, 0]
            elif audio_codes.dim() == 2:
                rows = audio_codes
            else:
                raise ValueError("`audio_codes` must be (B, 1, T) or (B, T) when passing a tensor.")
            codes_list: List[List[int]] = [row.detach().cpu().tolist() for row in rows]
        else:
            codes_list = [list(c) for c in audio_codes]

        if prompt_speech is None or not os.path.exists(str(prompt_speech)):
            raise ValueError("`prompt_speech` path is required for decoding and must exist.")

        prompt_wav, orig_sr = torchaudio.load(str(prompt_speech))
        target_sr = self.audio_decoder.sample_rate
        if orig_sr != target_sr:
            prompt_wav = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)(prompt_wav)

        # Accept device strings as well, then fall back to the CPU when a
        # non-CPU device is requested but CUDA is not available.
        device = torch.device(device)
        if not torch.cuda.is_available() and device.type != "cpu":
            device = torch.device("cpu")

        speech_token = torch.tensor(self.encode([str(prompt_speech)])[0], device=device).unsqueeze(0)
        speech_feat, _ = self._extract_speech_feat(prompt_wav)

        if target_sr == 24000:
            # Keep exactly 4 feature frames per prompt token. The 4:1 ratio
            # is taken as-is from the original code -- presumably tied to the
            # 24 kHz configuration, TODO confirm.
            token_len = min(int(speech_feat.shape[1] / 4), speech_token.shape[1])
            speech_feat = speech_feat[:, : 4 * token_len]
            speech_token = speech_token[:, :token_len]

        prompt_16k = torchaudio.transforms.Resample(orig_freq=target_sr, new_freq=16000)(prompt_wav)
        embedding = self._extract_spk_embedding(prompt_16k).to(device)
        speech_feat = speech_feat.to(device)

        conditioning = {}
        if use_prompt_speech:
            conditioning.update({"prompt_token": speech_token, "prompt_feat": speech_feat})
        if use_spk_embedding:
            conditioning.update({"embedding": embedding})

        syn_wav_list: List[torch.Tensor] = []
        for codes in codes_list:
            codes_t = torch.tensor(codes, device=device).unsqueeze(0)
            request_id = os.urandom(16).hex()

            tts_speech, _ = self.audio_decoder.token2wav(
                codes_t, uuid=request_id, finalize=finalize, **conditioning
            )
            if not finalize:
                # The request id never leaves this method, so the streaming
                # state token2wav stored under it can never be resumed; drop
                # it instead of letting it accumulate.
                self.audio_decoder.hift_cache_dict.pop(request_id, None)
                self.audio_decoder.mel_overlap_dict.pop(request_id, None)
            syn_wav_list.append(tts_speech.squeeze())

        return {"syn_wav_list": syn_wav_list}

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        *,
        revision: Optional[str] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        use_auth_token: Optional[Union[str, bool]] = None,
        subfolder: Optional[str] = None,
        **kwargs,
    ):
        """Instantiate codec from a local directory or a Hugging Face Hub repo.

        This mirrors the typical Hugging Face ``from_pretrained`` behavior:
        - If ``pretrained_model_name_or_path`` is a local folder, files are loaded from it.
        - Otherwise, it is treated as a Hub repo ID and downloaded with ``snapshot_download``.

        Expected layout inside the resolved base folder:
        - ``model.safetensors`` (Whisper VQ encoder weights)
        - ``config.json`` (Whisper VQ config)
        - ``preprocessor_config.json`` (WhisperFeatureExtractor params)
        - ``flow/{config.yaml, flow.pt, hift.pt, campplus.onnx}``

        ``revision``, ``cache_dir``, ``force_download``, ``local_files_only``
        and ``token`` are only used for Hub downloads; ``use_auth_token`` is
        used as the token when ``token`` is not given. Extra ``**kwargs`` are
        accepted and ignored.

        Raises:
            RuntimeError: If a Hub download is needed but ``huggingface_hub``
                cannot be imported.
            FileNotFoundError: If a required file other than
                ``campplus.onnx`` is missing (that one only logs a warning).
        """
        base: Path
        path_str = str(pretrained_model_name_or_path)
        if os.path.isdir(path_str):
            base = Path(path_str)
        else:
            try:
                from huggingface_hub import snapshot_download
            except ImportError as exc:
                raise RuntimeError(
                    "huggingface_hub is required to load from a repo id; please `pip install huggingface_hub`."
                ) from exc

            if token is None and use_auth_token is not None:
                token = use_auth_token
            snapshot_path = snapshot_download(
                repo_id=path_str,
                revision=revision,
                cache_dir=str(cache_dir) if cache_dir is not None else None,
                force_download=force_download,
                local_files_only=local_files_only,
                token=token,
            )
            base = Path(snapshot_path)
        if subfolder:
            base = base / subfolder

        tokenizer_dir = base
        flow_dir = base / "flow"

        required = [
            tokenizer_dir / "model.safetensors",
            tokenizer_dir / "config.json",
            tokenizer_dir / "preprocessor_config.json",
        ]
        required.extend(flow_dir / fname for fname in ("config.yaml", "flow.pt", "hift.pt"))
        missing = [str(path) for path in required if not path.exists()]
        if missing:
            raise FileNotFoundError(
                "Missing required codec assets under resolved path. The following files were not found: "
                + ", ".join(missing)
            )
        if not (flow_dir / "campplus.onnx").exists():
            logger.warning("campplus.onnx not found under %s; decoding speaker embedding may fail.", flow_dir)

        return cls(
            encoder_weight_path=str(tokenizer_dir / "model.safetensors"),
            encoder_config_path=str(tokenizer_dir / "config.json"),
            encoder_feature_extractor_path=str(tokenizer_dir),
            flow_path=str(flow_dir),
        )
|
|