Zhyw0 commited on
Commit
3699820
·
1 Parent(s): ca9f763

Add application file

Browse files
Files changed (2) hide show
  1. .gitignore +2 -0
  2. app.py +1532 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ asset/*.mp3
2
+ asset/*.wav
app.py ADDED
@@ -0,0 +1,1532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import base64
3
+ import functools
4
+ import json
5
+ import sys
6
+ import threading
7
+ import time
8
+ from collections import OrderedDict
9
+ from dataclasses import dataclass
10
+ from pathlib import Path
11
+ from typing import Iterator, Sequence
12
+
13
+ import gradio as gr
14
+ import numpy as np
15
+
16
+ import torch
17
+ import torchaudio
18
+ import torch._dynamo
19
+ from transformers import AutoModel, AutoTokenizer
20
+ from mossttsrealtime import MossTTSRealtime, MossTTSRealtimeProcessor
21
+ from mossttsrealtime.streaming_mossttsrealtime import (
22
+ AudioStreamDecoder,
23
+ MossTTSRealtimeInference,
24
+ MossTTSRealtimeStreamingSession,
25
+ )
26
+
27
# Raise torch.compile's recompile budget — presumably because the streaming
# path triggers many shape specializations during warmup; TODO confirm.
torch._dynamo.config.cache_size_limit = 64

APP_DIR = Path(__file__).resolve().parent
AUDIO_DIR = APP_DIR / "asset"  # bundled prompt/user audio assets
LOG_DIR = APP_DIR / "logs"  # jsonl metric logs written by StreamRTFLogger
SAMPLE_RATE = 24000  # codec sample rate in Hz

# Hugging Face repo ids (or local paths) for the backend components.
CODEC_MODEL_PATH = "OpenMOSS-Team/MOSS-Audio-Tokenizer"
MODEL_PATH = "OpenMOSS-Team/MOSS-TTS-Realtime"
TOKENIZER_PATH = "OpenMOSS-Team/MOSS-TTS-Realtime"

# Default audio files used when a request opts into the bundled defaults.
PROMPT_WAV = "asset/prompt_audio.mp3"
USER_WAV = "asset/user1.wav"

WARMUP_POLL_INTERVAL_SECONDS = 0.5
DEFAULT_REPETITION_WINDOW = 50
# One token past the repetition window, so warmup also exercises the
# post-window steady-state decode step (see WarmupManager._warmup_step_detail).
WARMUP_STEP_TOKENS = DEFAULT_REPETITION_WINDOW + 1
WARMUP_USER_TEXT = "Hello!"
WARMUP_BASE_ASSISTANT_TEXT = (
    "This startup warmup request primes the streaming text to speech path "
    "so the first real user request avoids the cold compile stall."
)
49
+
50
+
51
+ def _apply_seed(seed: int | None) -> None:
52
+ if seed is None:
53
+ return
54
+ torch.manual_seed(seed)
55
+ torch.cuda.manual_seed_all(seed)
56
+
57
+
58
def _load_audio(path: Path, target_sample_rate: int = SAMPLE_RATE) -> torch.Tensor:
    """Load *path* as a mono [1, T] waveform resampled to *target_sample_rate*."""
    waveform, native_rate = torchaudio.load(path)
    if native_rate != target_sample_rate:
        waveform = torchaudio.functional.resample(waveform, native_rate, target_sample_rate)
    # Collapse multi-channel audio to mono by averaging channels.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    return waveform
65
+
66
+
67
def _load_codec(device: torch.device, codec_model_path: str):
    """Instantiate the audio codec in eval mode and place it on *device*."""
    codec_model = AutoModel.from_pretrained(codec_model_path, trust_remote_code=True)
    return codec_model.eval().to(device)
70
+
71
+
72
+ def _extract_codes(encode_result):
73
+ if isinstance(encode_result, dict):
74
+ codes = encode_result["audio_codes"]
75
+
76
+ elif isinstance(encode_result, (list, tuple)) and encode_result:
77
+ codes = encode_result[0]
78
+ else:
79
+ codes = encode_result
80
+
81
+ if isinstance(codes, np.ndarray):
82
+ codes = torch.from_numpy(codes)
83
+
84
+ if isinstance(codes, torch.Tensor) and codes.dim() == 3:
85
+ if codes.shape[1] == 1:
86
+ codes = codes[:, 0, :]
87
+ elif codes.shape[0] == 1:
88
+ codes = codes[0]
89
+ else:
90
+ raise ValueError(f"Unsupported 3D audio code shape: {tuple(codes.shape)}")
91
+
92
+ return codes
93
+
94
+
95
@dataclass(frozen=True)
class BackendPaths:
    """Identifies which model/tokenizer/codec backend to load and where to run it."""

    # HF repo id or local path of the MossTTSRealtime model weights.
    model_path: str
    # HF repo id or local path of the text tokenizer.
    tokenizer_path: str
    # HF repo id or local path of the audio codec model.
    codec_model_path: str
    # torch.device string, e.g. "cuda" — CUDA is required by _load_backend.
    device_str: str
    # Attention implementation name; "none" or "" selects the default path.
    attn_impl: str
102
+
103
+
104
@dataclass(frozen=True)
class GenerationConfig:
    """Sampling parameters forwarded to the streaming inference session."""

    temperature: float
    top_p: float
    top_k: int
    repetition_penalty: float
    # Number of recent tokens the repetition penalty looks back over.
    repetition_window: int
    do_sample: bool
    # Maximum sequence length passed to MossTTSRealtimeInference.
    max_length: int
    # Global RNG seed; None leaves torch RNG state untouched (see _apply_seed).
    seed: int | None
114
+
115
+
116
@dataclass(frozen=True)
class StreamingConfig:
    """Chunking and pacing knobs for the streaming pipeline."""

    # Assistant text tokens pushed per step; <= 0 means one single chunk
    # (see TokenChunkStream).
    text_chunk_tokens: int
    # Artificial sleep between text chunks, in seconds (0 disables).
    input_delay: float
    # Codec frames decoded per waveform chunk.
    decode_chunk_frames: int
    # Overlapping frames between consecutive decode chunks.
    decode_overlap_frames: int
    # chunk_duration forwarded to the codec encode kwargs, in seconds.
    chunk_duration: float
    # Seconds of audio accumulated before the first chunks are released.
    prebuffer_seconds: float
    # Back-pressure threshold in buffered seconds; 0 disables waiting.
    buffer_threshold_seconds: float = 0.0
125
+
126
+
127
@dataclass(frozen=True)
class StreamingRequest:
    """One complete TTS streaming request, as assembled from the UI inputs."""

    user_text: str
    assistant_text: str
    # Optional explicit audio file paths; empty/None falls back to defaults
    # when the corresponding use_default_* flag is set (_resolve_audio_path).
    prompt_audio: str | None
    user_audio: str | None
    use_default_prompt: bool
    use_default_user: bool
    generation: GenerationConfig
    streaming: StreamingConfig
    backend: BackendPaths
138
+
139
+
140
@dataclass(frozen=True)
class StreamEvent:
    """A single streamed update: a status message plus an optional audio chunk."""

    message: str
    # (sample_rate, mono waveform) pair; None for status-only events.
    # Presumably consumed by a Gradio streaming audio output — confirm at call site.
    audio: tuple[int, np.ndarray] | None = None
144
+
145
+
146
@dataclass(frozen=True)
class WarmupSnapshot:
    """Point-in-time, immutable view of the startup warmup state machine."""

    # States observed in this file: "pending", "running", "ready", "failed".
    state: str
    # Progress fraction clamped to [0, 1] by WarmupManager._set_state.
    progress: float
    message: str
    detail: str | None = None
    error: str | None = None

    @property
    def ready(self) -> bool:
        """True once warmup finished successfully."""
        return self.state == "ready"

    @property
    def failed(self) -> bool:
        """True if warmup aborted with an error."""
        return self.state == "failed"
161
+
162
+
163
def _make_log_path(prefix: str) -> Path:
    """Return a unique ``.jsonl`` log path under LOG_DIR, creating the directory.

    Uniqueness combines a second-resolution timestamp with the sub-second
    part of ``time.time_ns`` to separate logs created in the same second.
    """
    LOG_DIR.mkdir(parents=True, exist_ok=True)
    timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())
    sub_second = time.time_ns() % 1_000_000_000
    return LOG_DIR / f"{prefix}_{timestamp}_{sub_second:09d}.jsonl"
167
+
168
+
169
+ def _compute_rtf_metrics(sample_count: int, sample_rate: int, started_at: float) -> dict[str, float | None]:
170
+ elapsed_s = max(0.0, time.monotonic() - started_at)
171
+ audio_s = float(sample_count) / float(sample_rate) if sample_count > 0 and sample_rate > 0 else 0.0
172
+ rtf = (elapsed_s / audio_s) if audio_s > 0 else None
173
+ return {
174
+ "elapsed_s": elapsed_s,
175
+ "audio_s": audio_s,
176
+ "rtf": rtf,
177
+ }
178
+
179
+
180
class StreamRTFLogger:
    """Appends per-request streaming metrics (RTF, TTFA, chunk stats) to a jsonl file.

    One logger instance corresponds to one streaming request. Every record is
    written by _append, which stamps a wall-clock "ts" field; timing metrics
    are measured from the monotonic ``started_at`` supplied at construction.
    """

    def __init__(self, path: Path, started_at: float):
        self.path = path
        # Monotonic timestamp of request start; basis for elapsed/RTF/TTFA.
        self.started_at = started_at
        self.chunk_count = 0
        # Updated from the first logged chunk; defaults to the codec rate.
        self.sample_rate = SAMPLE_RATE
        self.samples_emitted = 0

    @classmethod
    def create(cls, request: "StreamingRequest", started_at: float) -> "StreamRTFLogger":
        """Build a logger with a fresh log path and immediately record the request."""
        logger = cls(_make_log_path("rtf"), started_at)
        logger.log_request_started(request)
        print(f"[MossTTSRealtime][rtf-log] {logger.path}", flush=True)
        return logger

    def log_request_started(self, request: "StreamingRequest") -> None:
        """Record the request's full generation/streaming/backend configuration."""
        self._append(
            {
                "event": "request_started",
                "user_text_chars": len(request.user_text),
                "assistant_text_chars": len(request.assistant_text),
                "text_chunk_tokens": request.streaming.text_chunk_tokens,
                "decode_chunk_frames": request.streaming.decode_chunk_frames,
                "decode_overlap_frames": request.streaming.decode_overlap_frames,
                "chunk_duration_s": request.streaming.chunk_duration,
                "prebuffer_seconds": request.streaming.prebuffer_seconds,
                "temperature": request.generation.temperature,
                "top_p": request.generation.top_p,
                "top_k": request.generation.top_k,
                "repetition_penalty": request.generation.repetition_penalty,
                "repetition_window": request.generation.repetition_window,
                "do_sample": request.generation.do_sample,
                "max_length": request.generation.max_length,
                "seed": request.generation.seed,
                "device": request.backend.device_str,
                "attn_implementation": request.backend.attn_impl,
            }
        )

    def log_chunk(
        self,
        *,
        event_message: str,
        sample_rate: int,
        chunk: np.ndarray,
        first_audio_time: float | None,
    ) -> None:
        """Record one emitted audio chunk; empty chunks are silently ignored."""
        chunk = np.asarray(chunk).reshape(-1)
        if chunk.size == 0:
            return
        self.chunk_count += 1
        self.sample_rate = int(sample_rate)
        self.samples_emitted += int(chunk.size)
        metrics = _compute_rtf_metrics(self.samples_emitted, self.sample_rate, self.started_at)
        record = {
            "event": "stream_chunk",
            "message": event_message,
            "chunk_idx": self.chunk_count,
            "chunk_audio_s": float(chunk.size) / float(self.sample_rate),
            "audio_s_emitted": metrics["audio_s"],
            "elapsed_s": metrics["elapsed_s"],
            "rtf": metrics["rtf"],
        }
        if first_audio_time is not None:
            # Time-to-first-audio, clamped so clock jitter never logs negative.
            record["time_to_first_audio_ms"] = max(0.0, (first_audio_time - self.started_at) * 1000.0)
        self._append(record)

    def log_completion(self, *, first_audio_time: float | None) -> None:
        """Record the final totals for a successful stream."""
        metrics = _compute_rtf_metrics(self.samples_emitted, self.sample_rate, self.started_at)
        record = {
            "event": "stream_complete",
            "chunk_count": self.chunk_count,
            "audio_s_total": metrics["audio_s"],
            "elapsed_s": metrics["elapsed_s"],
            "rtf": metrics["rtf"],
        }
        if first_audio_time is not None:
            record["time_to_first_audio_ms"] = max(0.0, (first_audio_time - self.started_at) * 1000.0)
        self._append(record)

    def log_no_audio(self) -> None:
        """Record a completed stream that produced no audio at all."""
        metrics = _compute_rtf_metrics(0, self.sample_rate, self.started_at)
        self._append(
            {
                "event": "stream_complete",
                "chunk_count": 0,
                "audio_s_total": 0.0,
                "elapsed_s": metrics["elapsed_s"],
                "rtf": None,
                "warning": "No audio chunks emitted.",
            }
        )

    def log_error(self, exc: Exception, *, first_audio_time: float | None) -> None:
        """Record a stream aborted by *exc*, including partial progress so far."""
        metrics = _compute_rtf_metrics(self.samples_emitted, self.sample_rate, self.started_at)
        record = {
            "event": "stream_error",
            "error_type": type(exc).__name__,
            "error": str(exc),
            "chunk_count": self.chunk_count,
            "audio_s_emitted": metrics["audio_s"],
            "elapsed_s": metrics["elapsed_s"],
            "rtf": metrics["rtf"],
        }
        if first_audio_time is not None:
            record["time_to_first_audio_ms"] = max(0.0, (first_audio_time - self.started_at) * 1000.0)
        self._append(record)

    def _append(self, payload: dict[str, object]) -> None:
        """Append *payload* (plus a wall-clock "ts") as one JSON line."""
        record = {
            "ts": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
            **payload,
        }
        with self.path.open("a", encoding="utf-8") as handle:
            handle.write(json.dumps(record, ensure_ascii=False) + "\n")
295
+
296
+
297
class TokenChunkStream:
    """Re-iterable view over a token list that yields fixed-size chunks.

    A non-positive chunk size yields the whole token list as a single chunk;
    an empty token list yields nothing.
    """

    def __init__(
        self,
        tokens: Sequence[int],
        chunk_size: int,
    ):
        self._tokens = list(tokens)
        self._chunk_size = int(chunk_size)

    def __iter__(self) -> Iterator[list[int]]:
        total = len(self._tokens)
        if total == 0:
            return
        size = self._chunk_size if self._chunk_size > 0 else total
        start = 0
        while start < total:
            yield self._tokens[start : start + size]
            start += size
312
+
313
+
314
class BufferedAudioTracker:
    """Tracks how far the emitted audio runs ahead of real-time playback."""

    def __init__(self, sample_rate: int):
        self.sample_rate = sample_rate
        # Monotonic time of the first non-empty chunk; None until then.
        self.start_time: float | None = None
        self.samples_emitted = 0

    def add_chunk(self, chunk: np.ndarray) -> None:
        """Account for a chunk; empty chunks neither count nor start the clock."""
        if chunk.size:
            if self.start_time is None:
                self.start_time = time.monotonic()
            self.samples_emitted += int(chunk.size)

    def buffered_seconds(self) -> float:
        """Seconds of emitted audio not yet consumed by real-time playback."""
        if self.start_time is None:
            return 0.0
        played = time.monotonic() - self.start_time
        ahead = self.samples_emitted / self.sample_rate - played
        return ahead if ahead > 0.0 else 0.0
333
+
334
+
335
class AudioFrameDecoder:
    """Bridges raw audio-token frames to waveform chunks via AudioStreamDecoder.

    Each incoming frame is sanitized (_sanitize_tokens trims at EOS or
    out-of-range ids) before being pushed into the underlying streaming
    decoder; decoded waveforms are yielded as flat numpy arrays.
    """

    def __init__(
        self,
        decoder: AudioStreamDecoder,
        codebook_size: int,
        audio_eos_token: int,
    ):
        self.decoder = decoder
        # Valid token ids are [0, codebook_size); anything else ends the stream.
        self.codebook_size = codebook_size
        # Token id in channel 0 that marks end-of-audio.
        self.audio_eos_token = audio_eos_token

    def decode_frames(self, audio_frames: list[torch.Tensor]) -> Iterator[np.ndarray]:
        """Yield 1-D waveform chunks for *audio_frames*, stopping at EOS.

        Once a frame trips the stop condition, any remaining frames in the
        list are dropped (the generator breaks out of the loop).
        """
        for frame in audio_frames:
            tokens = frame
            # Accept an optional leading batch dimension; first element taken.
            if tokens.dim() == 3:
                tokens = tokens[0]
            if tokens.dim() != 2:
                raise ValueError(f"Expected [T, C] audio tokens, got {tuple(tokens.shape)}")
            tokens, stop = _sanitize_tokens(tokens, self.codebook_size, self.audio_eos_token)
            if tokens.numel() == 0:
                # Nothing decodable in this frame; end or skip accordingly.
                if stop:
                    break
                continue
            self.decoder.push_tokens(tokens.detach())
            for wav in self.decoder.audio_chunks():
                if wav.numel() == 0:
                    continue
                yield wav.detach().cpu().numpy().reshape(-1)
            if stop:
                break

    def flush(self) -> Iterator[np.ndarray]:
        """Yield the decoder's final buffered chunk, if any remains."""
        final_chunk = self.decoder.flush()
        if final_chunk is not None and final_chunk.numel() > 0:
            yield final_chunk.detach().cpu().numpy().reshape(-1)
370
+
371
+
372
class StreamAudioEmitter:
    """Wraps audio chunks into StreamEvents, with optional startup prebuffering.

    While prebuffering, chunks are held back until at least
    ``prebuffer_seconds`` of audio has accumulated, then released in order;
    afterwards chunks pass straight through. A BufferedAudioTracker records
    emitted audio for back-pressure via wait_for_capacity.
    """

    def __init__(self, sample_rate: int, prebuffer_seconds: float):
        self.sample_rate = sample_rate
        self._buffer_tracker = BufferedAudioTracker(sample_rate)
        self._prebuffer_target = max(0.0, float(prebuffer_seconds))
        # Prebuffering is active only when a positive target was requested.
        self._prebuffering = self._prebuffer_target > 0.0
        self._pending_chunks: list[np.ndarray] = []
        self._pending_samples = 0
        self.chunk_count = 0
        self.has_audio = False

    def wait_for_capacity(self, threshold_seconds: float) -> None:
        """Block while more than *threshold_seconds* of audio is buffered ahead."""
        _maybe_wait_for_buffer(self._buffer_tracker, threshold_seconds)

    def emit_many(self, chunks: Iterator[np.ndarray], message_prefix: str) -> Iterator[StreamEvent]:
        """Emit every chunk from *chunks* through emit()."""
        for chunk in chunks:
            yield from self.emit(chunk, message_prefix)

    def emit(self, chunk: np.ndarray, message_prefix: str) -> Iterator[StreamEvent]:
        """Yield zero or more StreamEvents for one raw chunk.

        Empty chunks yield nothing. During prebuffering the chunk is queued;
        when the accumulated audio reaches the target, all queued chunks are
        released at once and prebuffering ends.
        """
        chunk = np.asarray(chunk).reshape(-1)
        if chunk.size == 0:
            return
        if self._prebuffering:
            self._pending_chunks.append(chunk)
            self._pending_samples += int(chunk.size)
            if (self._pending_samples / self.sample_rate) < self._prebuffer_target:
                return
            # Target reached: drain the queue in arrival order.
            self._prebuffering = False
            pending_chunks = self._pending_chunks
            self._pending_chunks = []
            self._pending_samples = 0
            for pending in pending_chunks:
                yield self._make_event(pending, message_prefix)
            return
        yield self._make_event(chunk, message_prefix)

    def flush(self, message_prefix: str) -> Iterator[StreamEvent]:
        """Release any chunks still held by an unfinished prebuffer."""
        if not self._prebuffering or not self._pending_chunks:
            self._prebuffering = False
            return
        self._prebuffering = False
        pending_chunks = self._pending_chunks
        self._pending_chunks = []
        self._pending_samples = 0
        for chunk in pending_chunks:
            yield self._make_event(chunk, message_prefix)

    def _make_event(self, chunk: np.ndarray, message_prefix: str) -> StreamEvent:
        """Build the numbered StreamEvent for *chunk* and update emit state."""
        self.chunk_count += 1
        self.has_audio = True
        self._buffer_tracker.add_chunk(chunk)
        return StreamEvent(
            message=f"{message_prefix} chunk {self.chunk_count}",
            audio=(self.sample_rate, chunk),
        )
427
+
428
+
429
def _maybe_wait_for_buffer(buffer_tracker: BufferedAudioTracker, threshold_seconds: float) -> None:
    """Sleep in 10 ms steps until buffered audio drops to *threshold_seconds*.

    A non-positive threshold disables back-pressure entirely.
    """
    if threshold_seconds > 0:
        while buffer_tracker.buffered_seconds() > threshold_seconds:
            time.sleep(0.01)
434
+
435
+
436
+ def _sanitize_tokens(
437
+ tokens: torch.Tensor,
438
+ codebook_size: int,
439
+ audio_eos_token: int,
440
+ ) -> tuple[torch.Tensor, bool]:
441
+ if tokens.dim() == 1:
442
+ tokens = tokens.unsqueeze(0)
443
+ if tokens.numel() == 0:
444
+ return tokens, False
445
+ eos_rows = (tokens[:, 0] == audio_eos_token).nonzero(as_tuple=False)
446
+ invalid_rows = ((tokens < 0) | (tokens >= codebook_size)).any(dim=1)
447
+ stop_idx = None
448
+ if eos_rows.numel() > 0:
449
+ stop_idx = int(eos_rows[0].item())
450
+ if invalid_rows.any():
451
+ invalid_idx = int(invalid_rows.nonzero(as_tuple=False)[0].item())
452
+ stop_idx = invalid_idx if stop_idx is None else min(stop_idx, invalid_idx)
453
+ if stop_idx is not None:
454
+ tokens = tokens[:stop_idx]
455
+ return tokens, True
456
+ return tokens, False
457
+
458
+
459
def _build_streaming_session(
    model: MossTTSRealtime,
    tokenizer,
    processor: MossTTSRealtimeProcessor,
    codec,
    *,
    max_length: int,
    chunk_duration: float,
    temperature: float,
    top_p: float,
    top_k: int,
    do_sample: bool,
    repetition_penalty: float,
    repetition_window: int,
) -> tuple[MossTTSRealtimeStreamingSession, MossTTSRealtimeInference]:
    """Create a fresh streaming session plus its underlying inferencer.

    The inferencer's generation state is reset (cache dropped) before the
    session is built, so every request starts from a clean slate. Returns
    (session, inferencer) — callers need the inferencer for the frame
    decoder and for the is_finished check during draining.
    """
    inferencer = MossTTSRealtimeInference(model, tokenizer, max_length=max_length)
    inferencer.reset_generation_state(keep_cache=False)
    session = MossTTSRealtimeStreamingSession(
        inferencer,
        processor,
        codec=codec,
        codec_sample_rate=SAMPLE_RATE,
        codec_encode_kwargs={"chunk_duration": chunk_duration},
        prefill_text_len=processor.delay_tokens_len,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
        repetition_window=repetition_window,
    )
    return session, inferencer
491
+
492
+
493
def _build_frame_decoder(
    codec,
    inferencer: MossTTSRealtimeInference,
    device: torch.device,
    *,
    chunk_frames: int,
    overlap_frames: int,
) -> AudioFrameDecoder:
    """Build the AudioFrameDecoder used to turn token frames into waveforms.

    decode_kwargs uses chunk_duration=-1 — presumably "no fixed duration"
    for streaming decode; confirm against the codec API. The codebook size
    and EOS token fall back to 1024/1026 when the codec/inferencer do not
    expose them.
    """
    decoder = AudioStreamDecoder(
        codec,
        chunk_frames=chunk_frames,
        overlap_frames=overlap_frames,
        decode_kwargs={"chunk_duration": -1},
        device=device,
    )
    return AudioFrameDecoder(
        decoder,
        int(getattr(codec, "codebook_size", 1024)),
        int(getattr(inferencer, "audio_eos_token", 1026)),
    )
513
+
514
+
515
+ def _normalize_seed(value: float | int | None) -> int | None:
516
+ if value is None:
517
+ return None
518
+ seed = int(value)
519
+ return None if seed == 0 else seed
520
+
521
+
522
+ def _format_completion_status(
523
+ chunk_count: int,
524
+ sample_rate: int,
525
+ full_audio: np.ndarray,
526
+ started_at: float,
527
+ first_audio_time: float | None,
528
+ ) -> str:
529
+ elapsed = time.monotonic() - started_at
530
+ audio_seconds = float(full_audio.size) / float(sample_rate) if full_audio.size > 0 else 0.0
531
+ rtf = (elapsed / audio_seconds) if audio_seconds > 0 else float("inf")
532
+ parts = [
533
+ "Done",
534
+ f"chunks={chunk_count}",
535
+ f"audio={audio_seconds:.2f}s",
536
+ f"elapsed={elapsed:.2f}s",
537
+ f"RTF={rtf:.3f}" if np.isfinite(rtf) else "RTF=inf",
538
+ ]
539
+ if first_audio_time is not None:
540
+ parts.append(f"TTFA={(first_audio_time - started_at) * 1000.0:.0f}ms")
541
+ return " | ".join(parts)
542
+
543
+
544
# maxsize=1: only the most recently requested backend configuration stays
# resident — switching configs evicts (and reloads) the previous one.
@functools.lru_cache(maxsize=1)
def _load_backend(
    model_path: str,
    tokenizer_path: str,
    codec_model_path: str,
    device_str: str,
    attn_impl: str,
):
    """Load and cache (model, tokenizer, processor, codec, device).

    Raises RuntimeError when CUDA is unavailable. Uses bf16 when the GPU
    supports it, otherwise fp16. attn_impl of ""/"none" (any case) selects
    the default attention path.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for the MossTTSRealtime streaming demo.")

    device = torch.device(device_str)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    processor = MossTTSRealtimeProcessor(tokenizer)

    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    if attn_impl and attn_impl.lower() not in {"none", ""}:
        model = MossTTSRealtime.from_pretrained(model_path, attn_implementation=attn_impl, torch_dtype=dtype).to(device)
        # Mirror the flag onto the inner LM config — presumably needed so the
        # nested language model actually uses flash attention; confirm upstream.
        if (
            attn_impl.lower() == "flash_attention_2"
            and hasattr(model, "language_model")
            and hasattr(model.language_model, "config")
        ):
            model.language_model.config.attn_implementation = "flash_attention_2"
    else:
        model = MossTTSRealtime.from_pretrained(model_path, torch_dtype=dtype).to(device)
    model.eval()

    codec = _load_codec(device, codec_model_path)
    return model, tokenizer, processor, codec, device
574
+
575
+
576
+ def _resolve_audio_path(audio_path: str | None, use_default: bool, default_path: str | Path) -> Path | None:
577
+ if audio_path:
578
+ return Path(audio_path).expanduser()
579
+ if use_default:
580
+ return Path(default_path).expanduser()
581
+ return None
582
+
583
+
584
class StreamingTTSDemo:
    """Drives one streaming TTS request end-to-end.

    Responsibilities: request validation, prompt/user audio tokenization
    (with a small LRU cache keyed on path+mtime+chunk_duration), streaming
    session setup, and chunked audio emission as StreamEvents.
    """

    def __init__(self, audio_token_cache_size: int = 8):
        # Cache is clamped to hold at least one entry.
        self._audio_token_cache_size = max(1, int(audio_token_cache_size))
        # LRU of encoded audio tokens: (resolved path, mtime_ns, chunk_duration) -> tokens.
        self._audio_token_cache: OrderedDict[tuple[str, int, float], np.ndarray] = OrderedDict()

    def get_or_load_backend(self, backend: BackendPaths):
        """Return the (cached) backend tuple for *backend* via _load_backend."""
        return _load_backend(
            backend.model_path,
            backend.tokenizer_path,
            backend.codec_model_path,
            backend.device_str,
            backend.attn_impl,
        )

    def _validate_request(self, request: StreamingRequest) -> tuple[Path | None, Path | None]:
        """Validate texts and streaming knobs; resolve and check audio paths.

        Returns (prompt_path, user_path), either of which may be None.
        Raises ValueError for bad inputs, FileNotFoundError for missing files.
        """
        if not request.user_text.strip():
            raise ValueError("user_text is required.")
        if not request.assistant_text.strip():
            raise ValueError("assistant_text is required.")
        if request.streaming.text_chunk_tokens <= 0:
            raise ValueError("text_chunk_tokens must be greater than 0.")
        if request.streaming.decode_chunk_frames <= 0:
            raise ValueError("decode_chunk_frames must be greater than 0.")
        if request.streaming.chunk_duration <= 0:
            raise ValueError("chunk_duration must be greater than 0.")

        prompt_path = _resolve_audio_path(request.prompt_audio, request.use_default_prompt, PROMPT_WAV)
        user_path = _resolve_audio_path(request.user_audio, request.use_default_user, USER_WAV)

        if prompt_path is not None and not prompt_path.exists():
            raise FileNotFoundError(f"Prompt wav not found: {prompt_path}")
        if user_path is not None and not user_path.exists():
            raise FileNotFoundError(f"User wav not found: {user_path}")

        return prompt_path, user_path

    def _encode_audio_tokens(
        self,
        path: Path,
        codec,
        device: torch.device,
        chunk_duration: float,
    ) -> np.ndarray:
        """Encode the audio file at *path* into codec tokens, with LRU caching.

        The cache key includes the file's mtime so edited files re-encode.
        NOTE(review): the cached numpy array is returned directly — callers
        are expected not to mutate it.
        """
        resolved_path = path.expanduser().resolve()
        cache_key = (str(resolved_path), int(resolved_path.stat().st_mtime_ns), float(chunk_duration))
        cached_tokens = self._audio_token_cache.get(cache_key)
        if cached_tokens is not None:
            self._audio_token_cache.move_to_end(cache_key)
            return cached_tokens

        with torch.inference_mode():
            audio_tensor = _load_audio(resolved_path)
            waveform = audio_tensor.to(device)
            # Codec expects a batch dim: [1, channels, T].
            if waveform.dim() == 2:
                waveform = waveform.unsqueeze(0)
            encode_result = codec.encode(waveform, chunk_duration=chunk_duration)

        tokens = _extract_codes(encode_result)
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.detach().cpu().numpy()
        else:
            tokens = np.asarray(tokens)

        self._audio_token_cache[cache_key] = tokens
        self._audio_token_cache.move_to_end(cache_key)
        # Evict least-recently-used entries beyond the configured size.
        while len(self._audio_token_cache) > self._audio_token_cache_size:
            self._audio_token_cache.popitem(last=False)

        return tokens

    @staticmethod
    def _build_text_only_turn_input(
        processor: MossTTSRealtimeProcessor,
        user_text: str,
        prompt_tokens: np.ndarray | None,
    ) -> np.ndarray:
        """Build multi-channel input ids for a turn with no user audio.

        The text channel (column 0) carries the chat-template tokens; the
        audio channels are filled with the processor's pad value.
        """
        system_prompt = processor.make_ensemble(prompt_tokens)
        user_prompt_text = "<|im_end|>\n<|im_start|>user\n" + user_text + "<|im_end|>\n<|im_start|>assistant\n"
        user_prompt_tokens = processor.tokenizer(user_prompt_text)["input_ids"]
        user_prompt = np.full(
            shape=(len(user_prompt_tokens), processor.channels + 1),
            fill_value=processor.audio_channel_pad,
            dtype=np.int64,
        )
        user_prompt[:, 0] = np.asarray(user_prompt_tokens, dtype=np.int64)
        return np.concatenate([system_prompt, user_prompt], axis=0)

    def _prepare_session_turn(
        self,
        session: MossTTSRealtimeStreamingSession,
        processor: MossTTSRealtimeProcessor,
        user_text: str,
        prompt_tokens: np.ndarray | None,
        user_tokens: np.ndarray | None,
    ) -> str | None:
        """Reset the session for this turn; returns a status note or None.

        Without user audio the turn is driven by explicitly built input ids;
        with audio the session assembles the turn itself.
        """
        if user_tokens is None:
            turn_input_ids = self._build_text_only_turn_input(processor, user_text, prompt_tokens)
            session.reset_turn(input_ids=turn_input_ids, include_system_prompt=True, reset_cache=True)
            return "No user audio provided, running text-only turn."

        session.reset_turn(
            user_text=user_text,
            user_audio_tokens=user_tokens,
            include_system_prompt=True,
            reset_cache=True,
        )
        return None

    def run_stream(self, request: StreamingRequest) -> Iterator[StreamEvent]:
        """Run the full streaming pipeline, yielding StreamEvents.

        Order of operations: validate → load backend → seed RNG → encode
        prompt/user audio → build session → feed assistant text in chunks →
        finalize/drain → flush decoder and emitter. Raises RuntimeError when
        the stream completes without producing any audio.
        """
        prompt_path, user_path = self._validate_request(request)
        model, tokenizer, processor, codec, device = self.get_or_load_backend(request.backend)
        _apply_seed(request.generation.seed)

        prompt_tokens = (
            self._encode_audio_tokens(
                prompt_path,
                codec,
                device,
                chunk_duration=request.streaming.chunk_duration,
            )
            if prompt_path is not None
            else None
        )
        user_tokens = (
            self._encode_audio_tokens(
                user_path,
                codec,
                device,
                chunk_duration=request.streaming.chunk_duration,
            )
            if user_path is not None
            else None
        )

        session, inferencer = _build_streaming_session(
            model,
            tokenizer,
            processor,
            codec,
            max_length=request.generation.max_length,
            chunk_duration=request.streaming.chunk_duration,
            temperature=request.generation.temperature,
            top_p=request.generation.top_p,
            top_k=request.generation.top_k,
            do_sample=request.generation.do_sample,
            repetition_penalty=request.generation.repetition_penalty,
            repetition_window=request.generation.repetition_window,
        )
        if prompt_tokens is not None:
            session.set_voice_prompt_tokens(prompt_tokens)
        else:
            session.clear_voice_prompt()

        turn_message = self._prepare_session_turn(
            session,
            processor,
            request.user_text,
            prompt_tokens,
            user_tokens,
        )
        if turn_message:
            yield StreamEvent(message=turn_message)

        frame_decoder = _build_frame_decoder(
            codec,
            inferencer,
            device,
            chunk_frames=request.streaming.decode_chunk_frames,
            overlap_frames=request.streaming.decode_overlap_frames,
        )

        text_tokens = tokenizer.encode(request.assistant_text, add_special_tokens=False)
        if not text_tokens:
            raise RuntimeError("Assistant text tokenization returned no tokens.")

        token_stream = TokenChunkStream(text_tokens, request.streaming.text_chunk_tokens)
        audio_emitter = StreamAudioEmitter(SAMPLE_RATE, request.streaming.prebuffer_seconds)

        with codec.streaming(batch_size=1):
            for token_chunk in token_stream:
                # Back-pressure: don't run ahead of playback by more than the threshold.
                audio_emitter.wait_for_capacity(request.streaming.buffer_threshold_seconds)
                audio_frames = session.push_text_tokens(token_chunk)
                yield from audio_emitter.emit_many(frame_decoder.decode_frames(audio_frames), "Streaming")
                if request.streaming.input_delay > 0:
                    time.sleep(request.streaming.input_delay)

            final_frames = session.end_text()
            yield from audio_emitter.emit_many(frame_decoder.decode_frames(final_frames), "Finalizing")

            # Drain one generation step at a time until the session stops
            # producing frames or reports completion.
            while True:
                drain_frames = session.drain(max_steps=1)
                if not drain_frames:
                    break
                yield from audio_emitter.emit_many(frame_decoder.decode_frames(drain_frames), "Finalizing")
                if session.inferencer.is_finished:
                    break

            yield from audio_emitter.emit_many(frame_decoder.flush(), "Final")
            yield from audio_emitter.flush("Final")

        if not audio_emitter.has_audio:
            raise RuntimeError("No audio waveform chunks decoded from streaming inference.")

        yield StreamEvent(message="Streaming complete.")
788
+
789
+
790
+ class WarmupManager:
791
    def __init__(self, tts_demo: "StreamingTTSDemo", backend: BackendPaths):
        self.tts_demo = tts_demo
        self.backend = backend
        # Guards all mutable warmup state below (thread + status fields).
        self._lock = threading.Lock()
        self._thread: threading.Thread | None = None
        self._started = False
        # Lifecycle state; starts "pending", moves to "running" in _run.
        self._state = "pending"
        self._progress = 0.0
        self._message = "Waiting for startup warmup."
        self._detail = "The app warms the streaming path before the first real request."
        self._error: str | None = None
802
+
803
    def start(self) -> None:
        """Launch the warmup daemon thread once; subsequent calls are no-ops."""
        with self._lock:
            if self._started:
                return
            self._started = True
            self._thread = threading.Thread(target=self._run, name="tts-startup-warmup", daemon=True)
            self._thread.start()
810
+
811
    def snapshot(self) -> WarmupSnapshot:
        """Return an immutable, lock-consistent copy of the current warmup state."""
        with self._lock:
            return WarmupSnapshot(
                state=self._state,
                progress=self._progress,
                message=self._message,
                detail=self._detail,
                error=self._error,
            )
820
+
821
    def _set_state(
        self,
        *,
        state: str | None = None,
        progress: float | None = None,
        message: str | None = None,
        detail: str | None = None,
        error: str | None = None,
    ) -> None:
        """Update warmup status fields under the lock.

        Each field is only written when a non-None value is passed — EXCEPT
        ``error``, which is written unconditionally: omitting it (or passing
        None) clears any previous error.
        """
        with self._lock:
            if state is not None:
                self._state = state
            if progress is not None:
                # Clamp to the [0, 1] range WarmupSnapshot.progress promises.
                self._progress = max(0.0, min(1.0, float(progress)))
            if message is not None:
                self._message = message
            if detail is not None:
                self._detail = detail
            self._error = error
840
+
841
    @staticmethod
    def _consume_audio(chunks: Iterator[np.ndarray]) -> None:
        """Exhaust a chunk iterator, discarding the audio (warmup only needs the work done)."""
        for _chunk in chunks:
            pass
845
+
846
+ @staticmethod
847
+ def _ensure_warmup_text(tokenizer, minimum_tokens: int) -> tuple[str, list[int]]:
848
+ text = WARMUP_BASE_ASSISTANT_TEXT
849
+ tokens = tokenizer.encode(text, add_special_tokens=False)
850
+ while len(tokens) < minimum_tokens:
851
+ text = f"{text} {WARMUP_BASE_ASSISTANT_TEXT}"
852
+ tokens = tokenizer.encode(text, add_special_tokens=False)
853
+ return text, tokens
854
+
855
+ @staticmethod
856
+ def _warmup_step_detail(step_idx: int, total_steps: int) -> str:
857
+ if step_idx == 1:
858
+ return "First incremental step is compiling the cold streaming path."
859
+ if step_idx == 2:
860
+ return "Second incremental step is warming the next steady-state path."
861
+ if step_idx == DEFAULT_REPETITION_WINDOW:
862
+ return "Warming the first full repetition-window step."
863
+ if step_idx == WARMUP_STEP_TOKENS:
864
+ return "Confirming the post-window steady-state step."
865
+ return f"Warming token step {step_idx}/{total_steps}."
866
+
867
+ def _run(self) -> None:
868
+ try:
869
+ self._set_state(
870
+ state="running",
871
+ progress=0.02,
872
+ message="Starting startup warmup.",
873
+ detail="Preparing backend state for the first real request.",
874
+ error=None,
875
+ )
876
+
877
+ self._set_state(
878
+ progress=0.08,
879
+ message="Loading backend.",
880
+ detail="Model, tokenizer, codec, and CUDA runtime are warming up.",
881
+ error=None,
882
+ )
883
+ model, tokenizer, processor, codec, device = self.tts_demo.get_or_load_backend(self.backend)
884
+
885
+ self._set_state(
886
+ progress=0.32,
887
+ message="Preparing streaming session.",
888
+ detail="Building a text-only warmup turn and its decoder.",
889
+ error=None,
890
+ )
891
+ session, inferencer = _build_streaming_session(
892
+ model,
893
+ tokenizer,
894
+ processor,
895
+ codec,
896
+ max_length=256,
897
+ chunk_duration=0.24,
898
+ temperature=0.8,
899
+ top_p=0.6,
900
+ top_k=30,
901
+ do_sample=True,
902
+ repetition_penalty=1.1,
903
+ repetition_window=DEFAULT_REPETITION_WINDOW,
904
+ )
905
+ session.clear_voice_prompt()
906
+ session.reset_turn(
907
+ input_ids=self.tts_demo._build_text_only_turn_input(processor, WARMUP_USER_TEXT, None),
908
+ include_system_prompt=True,
909
+ reset_cache=True,
910
+ )
911
+
912
+ frame_decoder = _build_frame_decoder(
913
+ codec,
914
+ inferencer,
915
+ device,
916
+ chunk_frames=WARMUP_STEP_TOKENS,
917
+ overlap_frames=0,
918
+ )
919
+
920
+ _, warmup_tokens = self._ensure_warmup_text(
921
+ tokenizer,
922
+ processor.delay_tokens_len + WARMUP_STEP_TOKENS,
923
+ )
924
+
925
+ with codec.streaming(batch_size=1):
926
+ self._set_state(
927
+ progress=0.45,
928
+ message="Running prefill.",
929
+ detail="Building the first KV cache and warming the backbone path.",
930
+ error=None,
931
+ )
932
+ prefill_frames = session.push_text_tokens(warmup_tokens[: processor.delay_tokens_len])
933
+ self._consume_audio(frame_decoder.decode_frames(prefill_frames))
934
+
935
+ step_tokens = warmup_tokens[
936
+ processor.delay_tokens_len : processor.delay_tokens_len + WARMUP_STEP_TOKENS
937
+ ]
938
+ total_steps = max(1, len(step_tokens))
939
+ for idx, token in enumerate(step_tokens, start=1):
940
+ self._set_state(
941
+ progress=0.55 + 0.25 * (idx - 1) / total_steps,
942
+ message="Compiling first streaming steps.",
943
+ detail=self._warmup_step_detail(idx, total_steps),
944
+ error=None,
945
+ )
946
+ step_frames = session.push_text_tokens([token])
947
+ self._consume_audio(frame_decoder.decode_frames(step_frames))
948
+
949
+ self._set_state(
950
+ progress=0.86,
951
+ message="Warming finalization path.",
952
+ detail="Priming end-text, drain, and decoder flush before user traffic.",
953
+ error=None,
954
+ )
955
+ final_frames = session.end_text()
956
+ self._consume_audio(frame_decoder.decode_frames(final_frames))
957
+ drain_frames = session.drain(max_steps=1)
958
+ self._consume_audio(frame_decoder.decode_frames(drain_frames))
959
+ self._consume_audio(frame_decoder.flush())
960
+
961
+ self._set_state(
962
+ state="ready",
963
+ progress=1.0,
964
+ message="Warmup complete.",
965
+ detail="The first real request should avoid the cold-start stall.",
966
+ error=None,
967
+ )
968
+ except Exception as exc:
969
+ self._set_state(
970
+ state="failed",
971
+ progress=1.0,
972
+ message="Warmup failed.",
973
+ detail="The app did not finish startup warmup.",
974
+ error=str(exc),
975
+ )
976
+ print(f"[MossTTSRealtime][warmup-error] {exc}", file=sys.stderr, flush=True)
977
+
978
+
979
def _warmup_button_update(snapshot: WarmupSnapshot):
    """Return the Generate-button update that matches the warmup state."""
    if snapshot.ready:
        label, clickable = "Generate", True
    elif snapshot.failed:
        label, clickable = "Warmup Failed", False
    else:
        label, clickable = "Warming Up...", False
    return gr.update(value=label, interactive=clickable)
985
+
986
+
987
def _warmup_gate_message(snapshot: WarmupSnapshot) -> str:
    """Describe why generation is gated, or report that warmup failed."""
    if snapshot.failed:
        reason = snapshot.error or snapshot.message
        return f"Warmup failed: {reason}"
    clamped = min(1.0, max(0.0, snapshot.progress))
    pct = int(round(clamped * 100.0))
    return f"Warmup in progress ({pct}%): {snapshot.message}"
992
+
993
+
994
def _status_from_snapshot(snapshot: WarmupSnapshot) -> str:
    """Status-box text: 'Ready.' once warm, otherwise the gate message."""
    if snapshot.ready:
        return "Ready."
    return _warmup_gate_message(snapshot)
996
+
997
+
998
def _warmup_status_update(snapshot: WarmupSnapshot):
    """Gradio update carrying the current status text for the status box."""
    text = _status_from_snapshot(snapshot)
    return gr.update(value=text)
1000
+
1001
+
1002
def _warmup_timer_update(snapshot: WarmupSnapshot):
    """Keep the poll timer active only while warmup is still in flight."""
    finished = snapshot.ready or snapshot.failed
    return gr.update(active=not finished)
1004
+
1005
+
1006
+ def _encode_chunk(sr: int, chunk: np.ndarray, idx: int) -> str:
1007
+ if chunk.dtype != np.float32:
1008
+ chunk = chunk.astype(np.float32)
1009
+ if chunk.ndim != 1:
1010
+ chunk = chunk.reshape(-1)
1011
+ payload = {
1012
+ "sr": int(sr),
1013
+ "idx": int(idx),
1014
+ "data": base64.b64encode(chunk.tobytes()).decode("ascii"),
1015
+ }
1016
+ return json.dumps(payload)
1017
+
1018
+
1019
def _build_request(
    args: argparse.Namespace,
    *,
    user_text: str | None,
    assistant_text: str | None,
    prompt_audio: str | None,
    user_audio: str | None,
    use_default_prompt: bool,
    use_default_user: bool,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
    repetition_window: int,
    do_sample: bool,
    max_length: int,
    seed: float | int | None,
    text_chunk_tokens: int,
    input_delay: float,
    decode_chunk_frames: int,
    decode_overlap_frames: int,
    chunk_duration: float,
    prebuffer_seconds: float,
) -> StreamingRequest:
    """Assemble a StreamingRequest from raw UI values plus CLI backend paths.

    All numeric UI values are explicitly coerced, since Gradio widgets may
    deliver them as floats or strings.
    """
    generation = GenerationConfig(
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
        repetition_penalty=float(repetition_penalty),
        repetition_window=int(repetition_window),
        do_sample=bool(do_sample),
        max_length=int(max_length),
        seed=_normalize_seed(seed),
    )
    streaming = StreamingConfig(
        text_chunk_tokens=int(text_chunk_tokens),
        input_delay=float(input_delay),
        decode_chunk_frames=int(decode_chunk_frames),
        decode_overlap_frames=int(decode_overlap_frames),
        chunk_duration=float(chunk_duration),
        prebuffer_seconds=float(prebuffer_seconds),
    )
    backend = BackendPaths(
        model_path=args.model_path,
        tokenizer_path=args.tokenizer_path,
        codec_model_path=args.codec_model_path,
        device_str=args.device,
        attn_impl=args.attn_implementation,
    )
    return StreamingRequest(
        user_text=str(user_text or "Hello!"),
        assistant_text=str(assistant_text or ""),
        prompt_audio=prompt_audio,
        user_audio=user_audio,
        use_default_prompt=use_default_prompt,
        use_default_user=use_default_user,
        generation=generation,
        streaming=streaming,
        backend=backend,
    )
1076
+
1077
+
1078
# CSS injected into the page to hide the "pcm_stream" textbox that carries
# PCM JSON payloads. NOTE(review): it is moved off-screen with tiny size and
# zero opacity — presumably instead of display:none so the field keeps
# receiving value updates; confirm against the Gradio frontend behavior.
STREAM_PLAYER_HTML = """
<style>
#pcm_stream {
    position: absolute !important;
    left: -9999px !important;
    width: 1px !important;
    height: 1px !important;
    opacity: 0 !important;
    pointer-events: none !important;
}
#pcm_stream textarea, #pcm_stream input {
    width: 1px !important;
    height: 1px !important;
    opacity: 0 !important;
}
</style>
"""
1095
+
1096
# Browser-side player script (run once on page load). It watches the hidden
# "pcm_stream" field for JSON payloads written by the server, decodes the
# base64 float32 samples, and schedules gapless playback via the Web Audio
# API with short gain fades at chunk edges. It prefers a native value-setter
# hook on the field and falls back to polling when the hook is unavailable.
STREAM_PLAYER_JS = r"""
const elemId = "pcm_stream";
if (window.__pcm_streaming_inited__) {
    return;
}
window.__pcm_streaming_inited__ = true;

let audioCtx = null;
let nextTime = 0;
let lastIdx = -1;
let lastValue = "";
let boundField = null;
let usingSetterHook = false;
const FADE_MS = 6;
const MIN_BUFFER_SEC = 0.25;

function initAudio(sr) {
    if (audioCtx && audioCtx.sampleRate !== sr) {
        audioCtx.close();
        audioCtx = null;
    }
    if (!audioCtx) {
        audioCtx = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: sr });
        nextTime = audioCtx.currentTime;
    }
    if (audioCtx.state === "suspended") {
        audioCtx.resume();
    }
}

function decodeBase64ToFloat32(base64) {
    const binary = atob(base64);
    const len = binary.length;
    const bytes = new Uint8Array(len);
    for (let i = 0; i < len; i++) {
        bytes[i] = binary.charCodeAt(i);
    }
    return new Float32Array(bytes.buffer);
}

function playChunk(samples, sr, idx) {
    initAudio(sr);
    const buffer = audioCtx.createBuffer(1, samples.length, sr);
    buffer.copyToChannel(samples, 0);
    const source = audioCtx.createBufferSource();
    source.buffer = buffer;
    const gain = audioCtx.createGain();
    source.connect(gain);
    gain.connect(audioCtx.destination);
    const now = audioCtx.currentTime;
    if (nextTime < now + MIN_BUFFER_SEC) {
        nextTime = now + MIN_BUFFER_SEC;
    }
    const startTime = Math.max(now, nextTime);
    const endTime = startTime + buffer.duration;
    const fade = Math.min(FADE_MS / 1000.0, buffer.duration / 4);
    gain.gain.setValueAtTime(0.0, startTime);
    gain.gain.linearRampToValueAtTime(1.0, startTime + fade);
    gain.gain.setValueAtTime(1.0, Math.max(startTime + fade, endTime - fade));
    gain.gain.linearRampToValueAtTime(0.0, endTime);
    source.start(startTime);
    nextTime = endTime;
}

function handlePayload(text) {
    if (!text) return;
    let payload;
    try {
        payload = JSON.parse(text);
    } catch (e) {
        return;
    }
    if (Array.isArray(payload)) {
        for (const item of payload) {
            handlePayloadObject(item);
        }
        return;
    }
    handlePayloadObject(payload);
}

function handlePayloadObject(payload) {
    if (!payload) return;
    if (payload.reset) {
        lastIdx = -1;
        lastValue = "";
        if (audioCtx) {
            audioCtx.close();
            audioCtx = null;
        }
        return;
    }
    const idx = payload.idx ?? 0;
    if (idx <= lastIdx) return;
    lastIdx = idx;
    const sr = payload.sr || 24000;
    const samples = decodeBase64ToFloat32(payload.data);
    playChunk(samples, sr, idx);
}

function hookField(field) {
    if (!field || field === boundField) return;
    boundField = field;
    const proto = field.tagName === "TEXTAREA" ? HTMLTextAreaElement.prototype : HTMLInputElement.prototype;
    const desc = Object.getOwnPropertyDescriptor(proto, "value");
    if (!desc || !desc.get || !desc.set) {
        usingSetterHook = false;
        return;
    }
    usingSetterHook = true;
    const nativeGet = desc.get;
    const nativeSet = desc.set;
    Object.defineProperty(field, "value", {
        configurable: true,
        get() {
            return nativeGet.call(field);
        },
        set(v) {
            nativeSet.call(field, v);
            if (v && v !== lastValue) {
                lastValue = v;
                handlePayload(v);
            }
        },
    });

    const initial = field.value;
    if (initial && initial !== lastValue) {
        lastValue = initial;
        handlePayload(initial);
    }
}

function pollField() {
    const field = document.querySelector(`#${elemId} textarea, #${elemId} input`);
    if (!field) {
        boundField = null;
        usingSetterHook = false;
        setTimeout(pollField, 300);
        return;
    }
    if (field !== boundField) {
        hookField(field);
    }
    setTimeout(pollField, 300);
}

function pollValue() {
    if (usingSetterHook) {
        setTimeout(pollValue, 500);
        return;
    }
    const field = document.querySelector(`#${elemId} textarea, #${elemId} input`);
    if (!field) {
        setTimeout(pollValue, 300);
        return;
    }
    const value = field.value;
    if (value && value !== lastValue) {
        lastValue = value;
        handlePayload(value);
    }
    setTimeout(pollValue, 40);
}

function tryUnlockAudio() {
    if (!audioCtx) {
        audioCtx = new (window.AudioContext || window.webkitAudioContext)();
    }
    if (audioCtx.state === "suspended") {
        audioCtx.resume();
    }
}

document.addEventListener("click", (event) => {
    const btn = event.target.closest("#tts_generate");
    if (btn) {
        tryUnlockAudio();
    }
});

pollField();
pollValue();
"""
1280
+
1281
+
1282
def _build_demo(
    args: argparse.Namespace,
    tts_demo: StreamingTTSDemo,
    warmup_manager: WarmupManager,
):
    """Build the Gradio Blocks UI for the streaming TTS demo.

    Wires together: the input widgets feeding the ``_on_generate`` streaming
    handler; a hidden ``pcm_stream`` textbox used as the transport for the
    base64 PCM chunks consumed by ``STREAM_PLAYER_JS`` in the browser; and a
    poll timer that gates the Generate button on warmup progress.
    """
    initial_warmup_snapshot = warmup_manager.snapshot()
    with gr.Blocks(title="MossTTSRealtime") as demo:
        gr.Markdown("MossTTSRealtime demo")
        gr.Markdown("Note: The first run may take a while to load the model.")
        gr.HTML(STREAM_PLAYER_HTML, js_on_load=STREAM_PLAYER_JS)

        with gr.Row():
            with gr.Column():
                user_text = gr.Textbox(label="User Text(optional)", lines=2)
                assistant_text = gr.Textbox(label="Assistant Text", lines=6)
                prompt_audio = gr.Audio(label="Prompt WAV (optional)", type="filepath")
                user_audio = gr.Audio(label="User WAV (optional)", type="filepath")
                use_default_prompt = gr.Checkbox(label="Use Default Prompt WAV (fallback)", value=False)
                use_default_user = gr.Checkbox(label="Use Default User WAV (fallback)", value=False)

                with gr.Accordion("Generation Options", open=False):
                    temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.05, label="Temperature")
                    top_p = gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="Top P")
                    top_k = gr.Slider(1, 100, value=30, step=1, label="Top K")
                    repetition_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.05, label="Repetition Penalty")
                    repetition_window = gr.Slider(
                        1, 200, value=DEFAULT_REPETITION_WINDOW, step=1, label="Repetition Window"
                    )
                    do_sample = gr.Checkbox(label="Do Sample", value=True)
                    max_length = gr.Slider(100, 10000, value=2000, step=10, label="Max Length")
                    seed = gr.Number(value=0, precision=0, label="Seed (0 for random)")

                with gr.Accordion("Streaming Options", open=False):
                    stream_text_chunk_tokens = gr.Slider(1, 64, value=12, step=1, label="Text Chunk Tokens")
                    stream_input_delay = gr.Slider(0.0, 0.5, value=0.0, step=0.05, label="Input Delay (s)")
                    stream_decode_chunk_frames = gr.Slider(1, 20, value=12, step=1, label="Decode Chunk Frames")
                    stream_decode_overlap_frames = gr.Slider(0, 10, value=0, step=1, label="Decode Overlap Frames")
                    chunk_duration = gr.Slider(0.01, 1.0, value=0.24, step=0.01, label="Codec Chunk Duration (s)")
                    stream_prebuffer_seconds = gr.Slider(0.0, 20.0, value=0.0, step=0.05, label="Initial Buffer (s)")

                # Button starts disabled unless warmup already finished; the
                # timer below re-enables it once the snapshot reports ready.
                run_btn = gr.Button(
                    "Generate" if initial_warmup_snapshot.ready else "Warming Up...",
                    elem_id="tts_generate",
                    interactive=initial_warmup_snapshot.ready,
                )

            with gr.Column():
                # Hidden transport textbox; CSS in STREAM_PLAYER_HTML moves it
                # off-screen, and the page JS consumes each JSON value set here.
                stream_data = gr.Textbox(label="PCM Stream (JSON)", elem_id="pcm_stream", interactive=False, lines=6)
                output_audio = gr.Audio(label="Final Audio", type="numpy")
                initial_status = _status_from_snapshot(initial_warmup_snapshot)
                status = gr.Textbox(label="Status", lines=3, value=initial_status)

        warmup_timer = gr.Timer(value=WARMUP_POLL_INTERVAL_SECONDS, active=True)

        def _poll_warmup_state():
            # Returns updates in the same order as outputs=[run_btn, status, warmup_timer].
            snapshot = warmup_manager.snapshot()
            return (
                _warmup_button_update(snapshot),
                _warmup_status_update(snapshot),
                _warmup_timer_update(snapshot),
            )

        def _on_generate(
            user_text_value,
            assistant_text_value,
            prompt_audio_value,
            user_audio_value,
            use_default_prompt_value,
            use_default_user_value,
            temperature_value,
            top_p_value,
            top_k_value,
            repetition_penalty_value,
            repetition_window_value,
            do_sample_value,
            max_length_value,
            seed_value,
            stream_text_chunk_tokens_value,
            stream_input_delay_value,
            stream_decode_chunk_frames_value,
            stream_decode_overlap_frames_value,
            chunk_duration_value,
            stream_prebuffer_seconds_value,
        ):
            # Streaming generator: every yield is (stream_data, output_audio, status).
            # Refuse to run while warmup is still in flight (double-guard in
            # case the button was clickable due to a stale UI state).
            warmup_snapshot = warmup_manager.snapshot()
            if not warmup_snapshot.ready:
                yield json.dumps({"reset": True}), gr.update(value=None), _warmup_gate_message(warmup_snapshot)
                return
            try:
                started_at = time.monotonic()
                full_chunks: list[np.ndarray] = []
                first_audio_time: float | None = None
                sample_rate = SAMPLE_RATE
                rtf_logger: StreamRTFLogger | None = None
                # "reset" tells the browser player to drop its audio context
                # and chunk index before the new stream begins.
                yield json.dumps({"reset": True}), gr.update(value=None), "Started"

                request = _build_request(
                    args,
                    user_text=user_text_value,
                    assistant_text=assistant_text_value,
                    prompt_audio=prompt_audio_value,
                    user_audio=user_audio_value,
                    use_default_prompt=bool(use_default_prompt_value),
                    use_default_user=bool(use_default_user_value),
                    temperature=float(temperature_value),
                    top_p=float(top_p_value),
                    top_k=int(top_k_value),
                    repetition_penalty=float(repetition_penalty_value),
                    repetition_window=int(repetition_window_value),
                    do_sample=bool(do_sample_value),
                    max_length=int(max_length_value),
                    seed=seed_value,
                    text_chunk_tokens=int(stream_text_chunk_tokens_value),
                    input_delay=float(stream_input_delay_value),
                    decode_chunk_frames=int(stream_decode_chunk_frames_value),
                    decode_overlap_frames=int(stream_decode_overlap_frames_value),
                    chunk_duration=float(chunk_duration_value),
                    prebuffer_seconds=float(stream_prebuffer_seconds_value),
                )
                rtf_logger = StreamRTFLogger.create(request, started_at)

                for event in tts_demo.run_stream(request):
                    if event.audio is None:
                        # Status-only event: forward the message, touch nothing else.
                        yield gr.update(), gr.update(), event.message
                        continue

                    sr, chunk = event.audio
                    chunk = np.asarray(chunk).reshape(-1)
                    if chunk.size == 0:
                        continue
                    full_chunks.append(chunk)
                    sample_rate = sr
                    idx = len(full_chunks)  # 1-based monotonically increasing chunk index
                    if first_audio_time is None:
                        first_audio_time = time.monotonic()
                    if rtf_logger is not None:
                        rtf_logger.log_chunk(
                            event_message=event.message,
                            sample_rate=sr,
                            chunk=chunk,
                            first_audio_time=first_audio_time,
                        )
                    payload = _encode_chunk(sr, chunk, idx)
                    # ttfa = time-to-first-audio since the button click.
                    ttfa_ms = (first_audio_time - started_at) * 1000.0 if first_audio_time is not None else float("nan")
                    status_msg = f"{event.message} | chunks={idx} | ttfa={ttfa_ms:.0f}ms"
                    yield payload, gr.update(), status_msg

                if full_chunks:
                    # Final yield publishes the fully concatenated waveform to
                    # the normal (non-streaming) audio widget.
                    full_audio = np.concatenate(full_chunks)
                    if rtf_logger is not None:
                        rtf_logger.log_completion(first_audio_time=first_audio_time)
                    done_msg = _format_completion_status(
                        len(full_chunks),
                        sample_rate,
                        full_audio,
                        started_at,
                        first_audio_time,
                    )
                    yield gr.update(), (sample_rate, full_audio), done_msg
                else:
                    if rtf_logger is not None:
                        rtf_logger.log_no_audio()
                    yield gr.update(), gr.update(), "Done | no audio chunks emitted"
            except Exception as exc:
                import traceback
                traceback.print_exc()
                if rtf_logger is not None:
                    rtf_logger.log_error(exc, first_audio_time=first_audio_time)
                yield gr.update(), gr.update(), f"Error: {exc}"

        run_btn.click(
            _on_generate,
            inputs=[
                user_text,
                assistant_text,
                prompt_audio,
                user_audio,
                use_default_prompt,
                use_default_user,
                temperature,
                top_p,
                top_k,
                repetition_penalty,
                repetition_window,
                do_sample,
                max_length,
                seed,
                stream_text_chunk_tokens,
                stream_input_delay,
                stream_decode_chunk_frames,
                stream_decode_overlap_frames,
                chunk_duration,
                stream_prebuffer_seconds,
            ],
            outputs=[stream_data, output_audio, status],
        )
        # Poll on page load and on every timer tick; queue=False keeps the
        # polls from waiting behind a running generation job.
        demo.load(
            _poll_warmup_state,
            outputs=[run_btn, status, warmup_timer],
            queue=False,
            show_progress="hidden",
        )
        warmup_timer.tick(
            _poll_warmup_state,
            outputs=[run_btn, status, warmup_timer],
            queue=False,
            show_progress="hidden",
        )

    return demo
1492
+
1493
+
1494
def _parse_cli_args() -> argparse.Namespace:
    """Parse the command-line options for the streaming TTS demo server."""
    parser = argparse.ArgumentParser(description="MossTTSRealtime streaming TTS Gradio demo")
    parser.add_argument("--model_path", type=str, default=MODEL_PATH)
    parser.add_argument("--tokenizer_path", type=str, default=TOKENIZER_PATH)
    parser.add_argument("--codec_model_path", type=str, default=CODEC_MODEL_PATH)
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="sdpa",
        choices=["sdpa", "flash_attention_2", "eager", "none"],
    )
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8082)
    parser.add_argument("--share", action="store_true")
    return parser.parse_args()


def main():
    """Entry point: start background warmup, build the UI, and serve it."""
    args = _parse_cli_args()
    backend = BackendPaths(
        model_path=args.model_path,
        tokenizer_path=args.tokenizer_path,
        codec_model_path=args.codec_model_path,
        device_str=args.device,
        attn_impl=args.attn_implementation,
    )
    engine = StreamingTTSDemo()
    warmer = WarmupManager(engine, backend)
    warmer.start()
    ui = _build_demo(args, engine, warmer)
    ui.queue(max_size=10, default_concurrency_limit=1).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
    )


if __name__ == "__main__":
    main()