Fabrice-TIERCELIN committed
Commit 954e141 · verified · 1 Parent(s): 8b495b2

Upload qwen3_tts_tokenizer.py

qwen_tts/inference/qwen3_tts_tokenizer.py ADDED
@@ -0,0 +1,411 @@
+ # coding=utf-8
+ # Copyright 2026 The Alibaba Qwen team.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import base64
+ import io
+ import urllib.request
+ from typing import List, Optional, Tuple, Union
+ from urllib.parse import urlparse
+
+ import librosa
+ import numpy as np
+ import soundfile as sf
+ import torch
+ from torch.nn.utils.rnn import pad_sequence
+ from transformers import AutoConfig, AutoFeatureExtractor, AutoModel
+
+ from ..core import (
+     Qwen3TTSTokenizerV1Config,
+     Qwen3TTSTokenizerV1Model,
+     Qwen3TTSTokenizerV2Config,
+     Qwen3TTSTokenizerV2Model,
+ )
+
+ AudioInput = Union[
+     str,  # wav path, or base64 string
+     np.ndarray,  # 1-D float array
+     List[str],
+     List[np.ndarray],
+ ]
+
+
+ class Qwen3TTSTokenizer:
+     """
+     A wrapper for the Qwen3 TTS Tokenizer (25Hz/12Hz) with HuggingFace-style loading.
+
+     - from_pretrained(): loads the speech tokenizer model via AutoModel and the feature extractor via AutoFeatureExtractor.
+     - encode(): supports wav path(s), base64 audio string(s), and numpy array(s).
+     - decode(): accepts either the raw model encode output, or a minimal dict/list-of-dicts.
+
+     Notes:
+         - For numpy array input, you must pass `sr` so the audio can be resampled to the model sample rate.
+         - Returned audio is a list of float32 numpy arrays, together with the output sample rate.
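+
+     Example (a minimal end-to-end sketch; the repo id and file names below are
+     placeholders, not verified):
+         >>> tok = Qwen3TTSTokenizer.from_pretrained("Qwen/Qwen3-TTS-Tokenizer-12Hz")
+         >>> enc = tok.encode("speech.wav")  # encode to discrete codes
+         >>> wavs, sr = tok.decode(enc)      # reconstruct waveforms
+         >>> sf.write("recon.wav", wavs[0], sr)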
+     """
+
+     def __init__(self):
+         self.model = None
+         self.feature_extractor = None
+         self.config = None
+         self.device = None
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "Qwen3TTSTokenizer":
+         """
+         Initialize the tokenizer in the HuggingFace `from_pretrained` style.
+
+         Args:
+             pretrained_model_name_or_path (str):
+                 HuggingFace repo id or local directory.
+             **kwargs (Any):
+                 Forwarded directly to `AutoModel.from_pretrained(...)`.
+                 Typical examples: device_map="cuda:0", dtype=torch.bfloat16, attn_implementation="eager".
+
+         Returns:
+             Qwen3TTSTokenizer:
+                 Initialized instance with `model`, `feature_extractor`, and `config` set.
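+
+         Example (sketch; the repo id is a placeholder):
+             >>> tok = Qwen3TTSTokenizer.from_pretrained(
+             ...     "Qwen/Qwen3-TTS-Tokenizer-12Hz",
+             ...     device_map="cuda:0",
+             ...     dtype=torch.bfloat16,
+             ... )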
+         """
+         inst = cls()
+
+         AutoConfig.register("qwen3_tts_tokenizer_25hz", Qwen3TTSTokenizerV1Config)
+         AutoModel.register(Qwen3TTSTokenizerV1Config, Qwen3TTSTokenizerV1Model)
+
+         AutoConfig.register("qwen3_tts_tokenizer_12hz", Qwen3TTSTokenizerV2Config)
+         AutoModel.register(Qwen3TTSTokenizerV2Config, Qwen3TTSTokenizerV2Model)
+
+         inst.feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path)
+         inst.model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
+         inst.config = inst.model.config
+
+         inst.device = getattr(inst.model, "device", None)
+         if inst.device is None:
+             # Fallback: infer the device from the first model parameter.
+             try:
+                 inst.device = next(inst.model.parameters()).device
+             except StopIteration:
+                 inst.device = torch.device("cpu")
+
+         return inst
+
+     def _is_probably_base64(self, s: str) -> bool:
+         if s.startswith("data:audio"):
+             return True
+         # Heuristic: no filesystem path separators and long enough.
+         if ("/" not in s and "\\" not in s) and len(s) > 256:
+             return True
+         return False
+
+     def _is_url(self, s: str) -> bool:
+         try:
+             u = urlparse(s)
+             return u.scheme in ("http", "https") and bool(u.netloc)
+         except Exception:
+             return False
+
+     def _decode_base64_to_wav_bytes(self, b64: str) -> bytes:
+         # Accept both "data:audio/wav;base64,..." and raw base64.
+         if "," in b64 and b64.strip().startswith("data:"):
+             b64 = b64.split(",", 1)[1]
+         return base64.b64decode(b64)
+
+     def load_audio(
+         self,
+         x: str,
+         target_sr: int,
+     ) -> np.ndarray:
+         """
+         Load audio from a wav path, URL, or base64 string, then resample to target_sr.
+
+         Args:
+             x (str):
+                 A wav file path, an http(s) URL, or a base64 audio string (raw or data URL).
+             target_sr (int):
+                 Target sampling rate.
+
+         Returns:
+             np.ndarray:
+                 1-D float32 waveform at target_sr.
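+
+         Example (sketch; the URL is a placeholder):
+             >>> wav = tok.load_audio("https://example.com/ref.wav", target_sr=16000)
+             >>> wav.shape  # (num_samples,), float32, at 16 kHz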
+         """
+         if self._is_url(x):
+             with urllib.request.urlopen(x) as resp:
+                 audio_bytes = resp.read()
+             with io.BytesIO(audio_bytes) as f:
+                 audio, sr = sf.read(f, dtype="float32", always_2d=False)
+         elif self._is_probably_base64(x):
+             wav_bytes = self._decode_base64_to_wav_bytes(x)
+             with io.BytesIO(wav_bytes) as f:
+                 audio, sr = sf.read(f, dtype="float32", always_2d=False)
+         else:
+             audio, sr = librosa.load(x, sr=None, mono=True)
+
+         if audio.ndim > 1:
+             audio = np.mean(audio, axis=-1)
+
+         if sr != target_sr:
+             audio = librosa.resample(y=audio, orig_sr=sr, target_sr=target_sr)
+
+         return audio.astype(np.float32)
+
+     def _normalize_audio_inputs(
+         self,
+         audios: AudioInput,
+         sr: Optional[int],
+     ) -> List[np.ndarray]:
+         """
+         Normalize all supported input types into a list of 1-D numpy float32 waveforms
+         at `self.feature_extractor.sampling_rate`.
+
+         Args:
+             audios (AudioInput):
+                 - str: wav path OR base64 audio string
+                 - np.ndarray: raw waveform (sr must be provided)
+                 - list[str] / list[np.ndarray]
+             sr (Optional[int]):
+                 Sampling rate for raw numpy input. Required if the input is np.ndarray or list[np.ndarray].
+
+         Returns:
+             List[np.ndarray]:
+                 List of float32 waveforms resampled to the model input sample rate.
+         """
+         target_sr = int(self.feature_extractor.sampling_rate)
+
+         if isinstance(audios, (str, np.ndarray)):
+             audios = [audios]
+
+         if len(audios) == 0:
+             return []
+
+         if isinstance(audios[0], str):
+             # wav path list or base64 list
+             return [self.load_audio(x, target_sr=target_sr) for x in audios]  # type: ignore[arg-type]
+
+         # numpy list
+         if sr is None:
+             raise ValueError("For numpy waveform input, you must provide `sr` (the original sampling rate).")
+
+         out: List[np.ndarray] = []
+         for a in audios:  # type: ignore[assignment]
+             if not isinstance(a, np.ndarray):
+                 raise TypeError("Mixed input types are not supported. Use all paths/base64 or all numpy arrays.")
+             if a.ndim > 1:
+                 a = np.mean(a, axis=-1)
+             if int(sr) != target_sr:
+                 a = librosa.resample(y=a.astype(np.float32), orig_sr=int(sr), target_sr=target_sr)
+             out.append(a.astype(np.float32))
+         return out
+
+     def encode(
+         self,
+         audios: AudioInput,
+         sr: Optional[int] = None,
+         return_dict: bool = True,
+     ):
+         """
+         Batch-encode audio into discrete codes (and optional conditioning, depending on 25Hz/12Hz).
+
+         Args:
+             audios (AudioInput):
+                 Supported forms:
+                 - np.ndarray: waveform (requires sr)
+                 - list[np.ndarray]: waveforms (requires sr)
+                 - str: wav path OR base64 audio string
+                 - list[str]: wav paths and/or base64 strings
+             sr (Optional[int], default=None):
+                 Original sampling rate for numpy waveform input.
+             return_dict (bool, default=True):
+                 Forwarded to model.encode(...). If True, returns a ModelOutput.
+
+         Returns:
+             25Hz:
+                 Qwen3TTSTokenizerV1EncoderOutput (if return_dict=True) with fields:
+                 - audio_codes: List[torch.LongTensor], each (codes_len,)
+                 - xvectors: List[torch.FloatTensor], each (xvector_dim,)
+                 - ref_mels: List[torch.FloatTensor], each (mel_len, mel_dim)
+             12Hz:
+                 Qwen3TTSTokenizerV2EncoderOutput (if return_dict=True) with fields:
+                 - audio_codes: List[torch.LongTensor], each (codes_len, num_quantizers)
+
+             If return_dict=False, returns the raw tuple from model.encode.
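+
+         Example (sketch; `wave` is assumed to be a 1-D float32 numpy array sampled at 22050 Hz):
+             >>> enc = tok.encode(wave, sr=22050)
+             >>> enc.audio_codes[0].shape  # (codes_len,) at 25Hz, (codes_len, num_quantizers) at 12Hz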
+         """
+         wavs = self._normalize_audio_inputs(audios, sr=sr)
+
+         inputs = self.feature_extractor(
+             raw_audio=wavs,
+             sampling_rate=int(self.feature_extractor.sampling_rate),
+             return_tensors="pt",
+         )
+         inputs = inputs.to(self.device).to(self.model.dtype)
+
+         with torch.inference_mode():
+             # model.encode expects input_values (B, T) and padding_mask (B, T).
+             enc = self.model.encode(
+                 inputs["input_values"].squeeze(1),
+                 inputs["padding_mask"].squeeze(1),
+                 return_dict=return_dict,
+             )
+         return enc
+
+     def decode(
+         self,
+         encoded,
+     ) -> Tuple[List[np.ndarray], int]:
+         """
+         Decode codes back to waveforms.
+
+         Usage:
+             1) Pass the raw output of `encode(...)` directly (recommended).
+                - 25Hz: expects fields audio_codes, xvectors, ref_mels
+                - 12Hz: expects field audio_codes
+             2) Pass a dict or list[dict] (minimal form) for custom pipelines:
+                - 25Hz dict keys: {"audio_codes", "xvectors", "ref_mels"}
+                - 12Hz dict keys: {"audio_codes"}
+                Values can be torch tensors or numpy arrays.
+
+         Args:
+             encoded (Any):
+                 - ModelOutput returned by `encode()`, OR
+                 - dict, OR
+                 - list[dict]
+
+         Returns:
+             Tuple[List[np.ndarray], int]:
+                 - wavs: list of 1-D float32 numpy arrays
+                 - sample_rate: int, the model output sampling rate
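+
+         Example (sketch; 12Hz minimal-dict form, with `codes` an integer tensor of
+         shape (codes_len, num_quantizers)):
+             >>> wavs, sr = tok.decode({"audio_codes": codes})
+             >>> sf.write("out.wav", wavs[0], sr)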
+         """
+         model_type = self.model.get_model_type()
+
+         def _to_tensor(x, dtype=None):
+             if isinstance(x, torch.Tensor):
+                 return x
+             x = np.asarray(x)
+             t = torch.from_numpy(x)
+             if dtype is not None:
+                 t = t.to(dtype)
+             return t
+
+         # Normalize `encoded` into the same shapes as the official demo uses.
+         if hasattr(encoded, "audio_codes"):
+             # ModelOutput from encode()
+             audio_codes_list = encoded.audio_codes
+             xvectors_list = getattr(encoded, "xvectors", None)
+             ref_mels_list = getattr(encoded, "ref_mels", None)
+         elif isinstance(encoded, dict):
+             audio_codes_list = encoded["audio_codes"]
+             xvectors_list = encoded.get("xvectors", None)
+             ref_mels_list = encoded.get("ref_mels", None)
+         elif isinstance(encoded, list):
+             # list of dicts
+             audio_codes_list = [e["audio_codes"] for e in encoded]
+             xvectors_list = [e["xvectors"] for e in encoded] if ("xvectors" in encoded[0]) else None
+             ref_mels_list = [e["ref_mels"] for e in encoded] if ("ref_mels" in encoded[0]) else None
+         else:
+             raise TypeError("`encoded` must be an encode output, a dict, or a list of dicts.")
+
+         # Ensure list form for per-sample tensors
+         if isinstance(audio_codes_list, torch.Tensor):
+             # Could be a single-sample tensor or an already padded batch tensor.
+             t = audio_codes_list
+             if t.dim() == 1:
+                 # 25Hz single sample: (C,) -> (1, C)
+                 t = t.unsqueeze(0)
+             elif t.dim() == 2:
+                 # 12Hz single sample: (C, Q) -> (1, C, Q)
+                 t = t.unsqueeze(0)
+             audio_codes_padded = t.to(self.device)
+         else:
+             # List[Tensor/np]
+             audio_codes_list = [_to_tensor(c, dtype=torch.long) for c in audio_codes_list]
+             audio_codes_padded = pad_sequence(audio_codes_list, batch_first=True, padding_value=0).to(self.device)
+
+         with torch.inference_mode():
+             if model_type == "qwen3_tts_tokenizer_25hz":
+                 if xvectors_list is None or ref_mels_list is None:
+                     raise ValueError("25Hz decode requires `xvectors` and `ref_mels`.")
+
+                 if isinstance(xvectors_list, torch.Tensor):
+                     xvectors_batch = xvectors_list
+                     if xvectors_batch.dim() == 1:  # (D,) -> (1, D)
+                         xvectors_batch = xvectors_batch.unsqueeze(0)
+                     xvectors_batch = xvectors_batch.to(self.device).to(self.model.dtype)
+                 else:
+                     xvectors_list = [_to_tensor(x, dtype=torch.float32) for x in xvectors_list]
+                     xvectors_batch = torch.stack(xvectors_list, dim=0).to(self.device).to(self.model.dtype)
+
+                 if isinstance(ref_mels_list, torch.Tensor):
+                     ref_mels_padded = ref_mels_list
+                     if ref_mels_padded.dim() == 2:  # (T, M) -> (1, T, M)
+                         ref_mels_padded = ref_mels_padded.unsqueeze(0)
+                     ref_mels_padded = ref_mels_padded.to(self.device).to(self.model.dtype)
+                 else:
+                     ref_mels_list = [_to_tensor(m, dtype=torch.float32) for m in ref_mels_list]
+                     ref_mels_padded = pad_sequence(ref_mels_list, batch_first=True, padding_value=0).to(self.device).to(self.model.dtype)
+
+                 dec = self.model.decode(audio_codes_padded, xvectors_batch, ref_mels_padded, return_dict=True)
+                 wav_tensors = dec.audio_values
+
+             elif model_type == "qwen3_tts_tokenizer_12hz":
+                 dec = self.model.decode(audio_codes_padded, return_dict=True)
+                 wav_tensors = dec.audio_values
+
+             else:
+                 raise ValueError(f"Unknown model type: {model_type}")
+
+         wavs = [w.to(torch.float32).detach().cpu().numpy() for w in wav_tensors]
+         return wavs, int(self.model.get_output_sample_rate())
+
+     def get_model_type(self) -> str:
+         """
+         Get the underlying tokenizer model type.
+
+         Returns:
+             str: Model type string from `self.model.config.model_type`
+                 (e.g. "qwen3_tts_tokenizer_25hz" / "qwen3_tts_tokenizer_12hz").
+         """
+         return self.model.get_model_type()
+
+     def get_input_sample_rate(self) -> int:
+         """
+         Get the expected input sample rate for encoding.
+
+         Returns:
+             int: Input sample rate (Hz).
+         """
+         return int(self.model.get_input_sample_rate())
+
+     def get_output_sample_rate(self) -> int:
+         """
+         Get the output sample rate for decoded waveforms.
+
+         Returns:
+             int: Output sample rate (Hz).
+         """
+         return int(self.model.get_output_sample_rate())
+
+     def get_encode_downsample_rate(self) -> int:
+         """
+         Get the encoder downsample rate (waveform samples per code step).
+
+         Returns:
+             int: Encode downsample rate.
+         """
+         return int(self.model.get_encode_downsample_rate())
+
+     def get_decode_upsample_rate(self) -> int:
+         """
+         Get the decoder upsample rate (waveform samples per code step).
+
+         Returns:
+             int: Decode upsample rate.
+         """
+         return int(self.model.get_decode_upsample_rate())
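+
+
+ # Illustrative note (derived from the getter docstrings above): for a waveform of
+ # N input samples, the encoded code sequence has roughly
+ # N / get_encode_downsample_rate() steps, and decoding produces roughly
+ # codes_len * get_decode_upsample_rate() output samples.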