| | import torch |
| | import torch.nn as nn |
| | import whisper |
| | from whisper.model import AudioEncoder, ModelDimensions |
| | from typing import Dict, Optional |
| | from whisperspeech.vq_stoks import RQBottleneckTransformer, Tunables |
| | from huggingface_hub import hf_hub_download |
| | import torch.nn.functional as F |
| | import os |
| | from typing import List, Optional, Union |
| | import io |
| | import urllib |
| | from tqdm import tqdm |
| | import torchaudio |
| |
|
| | _HF_MODELS = { |
| | "medium": "https://huggingface.co/jan-hq/WhisperVQ/resolve/main/medium_encoder_only.pt", |
| | } |
| |
|
| |
|
| | def available_models() -> List[str]: |
| | """Returns the names of available models""" |
| | return list(_HF_MODELS.keys()) |
| |
|
| |
|
| | def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: |
| | os.makedirs(root, exist_ok=True) |
| |
|
| | expected_sha256 = url.split("/")[-2] |
| | download_target = os.path.join(root, os.path.basename(url)) |
| |
|
| | if os.path.exists(download_target) and not os.path.isfile(download_target): |
| | raise RuntimeError( |
| | f"{download_target} exists and is not a regular file") |
| |
|
| | if os.path.isfile(download_target): |
| | with open(download_target, "rb") as f: |
| | model_bytes = f.read() |
| | return model_bytes if in_memory else download_target |
| | import ssl |
| | ssl._create_default_https_context = ssl._create_unverified_context |
| | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: |
| | with tqdm( |
| | total=int(source.info().get("Content-Length")), |
| | ncols=80, |
| | unit="iB", |
| | unit_scale=True, |
| | unit_divisor=1024, |
| | ) as loop: |
| | while True: |
| | buffer = source.read(8192) |
| | if not buffer: |
| | break |
| |
|
| | output.write(buffer) |
| | loop.update(len(buffer)) |
| |
|
| | model_bytes = open(download_target, "rb").read() |
| | return model_bytes if in_memory else download_target |
| |
|
| |
|
class CustomWhisperEncoder(nn.Module):
    """Lightweight wrapper that loads only the AudioEncoder part of a
    Whisper checkpoint, skipping the text decoder entirely.

    Args:
        name: Either a key of ``_HF_MODELS`` or a path to a checkpoint file.
        device: Target device; defaults to CUDA when available, else CPU.
        download_root: Cache directory for downloaded checkpoints; defaults
            to the grandparent directory of this source file.
        in_memory: If True, hold the raw checkpoint bytes in memory instead
            of reading from the cached file path.

    Raises:
        RuntimeError: If ``name`` is neither a known model name nor an
            existing file.
    """

    def __init__(self, name: str, device: Optional[str] = None,
                 download_root: Optional[str] = None, in_memory: bool = False):
        super().__init__()
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        if download_root is None:
            # Cache checkpoints two levels above this file by default.
            # (The original also computed an unused "~/.cache" default here.)
            download_root = os.path.dirname(
                os.path.dirname(os.path.realpath(__file__)))

        if name in _HF_MODELS:
            checkpoint_file = _download(
                _HF_MODELS[name], download_root, in_memory)
        elif os.path.isfile(name):
            if in_memory:
                # Context manager avoids leaking the file handle.
                with open(name, "rb") as f:
                    checkpoint_file = f.read()
            else:
                checkpoint_file = name
        else:
            raise RuntimeError(
                f"Model {name} not found; available models = {available_models()}"
            )

        with (
            io.BytesIO(checkpoint_file) if in_memory else open(
                checkpoint_file, "rb")
        ) as fp:
            checkpoint = torch.load(fp, map_location=device)
        # Free the (possibly large) raw bytes before building the model.
        del checkpoint_file

        # Rebuild only the audio-encoder half of the Whisper architecture.
        dims = ModelDimensions(**checkpoint["dims"])
        self.encoder = AudioEncoder(
            dims.n_mels,
            dims.n_audio_ctx,
            dims.n_audio_state,
            dims.n_audio_head,
            dims.n_audio_layer,
        )
        self.encoder.load_state_dict(checkpoint["model_state_dict"])

        if device:
            self.to(device)

        # Inference-only wrapper: start in eval mode.
        self.eval()

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """Encode a log-mel spectrogram into Whisper audio embeddings."""
        return self.encoder(mel)
| |
|
| |
|
class CustomRQBottleneckTransformer(RQBottleneckTransformer):
    """RQBottleneckTransformer variant that pairs the quantizer weights with
    the lightweight :class:`CustomWhisperEncoder` instead of a full Whisper
    model, for encoder-only semantic-token extraction."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @classmethod
    def load_vq_only(cls, ref="collabora/spear-tts-pytorch:whisper-vq-stoks-medium-en+pl.model",
                     repo_id=None, filename=None, local_filename=None):
        """Build an instance restoring only the vector-quantizer components.

        ``ref`` may be ``"repo_id:filename"`` (fetched from the Hugging Face
        Hub) or a local path. Only weights whose keys start with ``rq``,
        ``mlp`` or ``mlp_ln`` are loaded; the Whisper encoder itself is
        created lazily by :meth:`load_encoder`.
        """
        if repo_id is None and filename is None and local_filename is None:
            if ":" in ref:
                repo_id, filename = ref.split(":", 1)
            else:
                local_filename = ref
        if not local_filename:
            local_filename = hf_hub_download(
                repo_id=repo_id, filename=filename)

        # NOTE(review): no map_location is given, so tensors are restored to
        # the device they were saved from — confirm this is intended for
        # CPU-only hosts before changing it.
        spec = torch.load(local_filename)

        instance = cls(**spec['config'], tunables=Tunables(**
                       Tunables.upgrade(spec.get('tunables', {}))))

        # Keep only quantizer-related weights; everything else in the
        # checkpoint (e.g. full Whisper weights) is dropped to save memory.
        required_components = {
            'rq', 'mlp', 'mlp_ln'
        }
        filtered_state_dict = {
            k: v for k, v in spec['state_dict'].items()
            if any(k.startswith(comp) for comp in required_components)
        }

        # strict=False because the encoder weights are deliberately absent.
        instance.load_state_dict(filtered_state_dict, strict=False)
        instance.eval()
        return instance

    def load_encoder(self, device=None):
        """Lazily create the encoder-only Whisper model and its tokenizer.

        Idempotent: returns immediately if the encoder is already loaded.
        """
        if self.whmodel is not None:
            return
        device = device or self.device
        # (The original re-checked `self.whmodel is None` here, which is
        # always true after the early return above.)
        self.whmodel = CustomWhisperEncoder(
            self.whisper_model_name, device=device)
        # ".en" checkpoints are English-only; all others are multilingual.
        multilingual = not self.whisper_model_name.endswith('.en')
        self.tokenizer = whisper.tokenizer.get_tokenizer(multilingual)

    def optimzed_encode_mel(self, mel):
        """Quantize a batch of log-mel spectrograms into semantic tokens.

        (The "optimzed" spelling is kept for caller compatibility.)
        """
        assert len(
            mel.shape) == 3, "invalid mel spectrogram shape, expect (batch,chn,time)"
        self.load_encoder()
        n = mel.shape[-1]
        if n > whisper.audio.N_FRAMES:
            # Longer than one Whisper window: truncate to a single window.
            padded = mel[:, :, :whisper.audio.N_FRAMES]
        else:
            # Pad up to a full window with the log-mel "silence" value.
            padding = -n % whisper.audio.N_FRAMES
            padded = F.pad(mel, (0, padding), value=-1.5)

        embs = self.whmodel.encoder(padded)
        stoks = self.quantize(embs)
        if self.tunables.mask_embs:
            # Drop tokens that correspond purely to padding frames.
            return stoks[:, :n//2//self.downsample]
        else:
            return stoks

    def encode_audio(self, audio):
        """Encode a waveform tensor, or an audio file path, into semantic
        tokens; file input is resampled to 16 kHz first."""
        if isinstance(audio, str):
            x, sr = torchaudio.load(audio)
            x = torchaudio.transforms.Resample(sr, 16000)(x)[0]
            audio = x.unsqueeze(0)
        return self.optimzed_encode_mel(self.log_mel_spectrogram(audio).to(self.device))
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: restore the quantizer weights, then bring up the
    # encoder-only Whisper model on the GPU in inference mode.
    device = "cuda"
    vqmodel = CustomRQBottleneckTransformer.load_vq_only(
        "whisper-vq-stoks-v3-7lang-fixed.model"
    ).to(device)
    vqmodel.load_encoder(device)
    vqmodel.eval()
| |
|