Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Qwen3-TTS Demo Server (Oficial) | |
| CPU Optimized & RAM Safe | |
| """ | |
| import argparse | |
| import asyncio | |
| import base64 | |
| from collections import OrderedDict | |
| import hashlib | |
| import io | |
| import json | |
| import os | |
| import sys | |
| import tempfile | |
| import threading | |
| import time | |
| import gc | |
| from pathlib import Path | |
| import numpy as np | |
| import soundfile as sf | |
| import torch | |
| import torchaudio | |
| import uvicorn | |
| from fastapi import FastAPI, File, Form, HTTPException, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse, JSONResponse, StreamingResponse | |
# Make the sibling package importable regardless of the launch directory.
sys.path.insert(0, str(Path(__file__).parent.parent))

# CPU optimization: cap PyTorch's intra-op thread pool.
torch.set_num_threads(4)

try:
    from qwen_tts import Qwen3TTSModel
except ImportError:
    print("Error: qwen-tts no está instalado.")
    sys.exit(1)

from nano_parakeet import from_pretrained as _parakeet_from_pretrained
# Full catalogue of official Qwen3-TTS checkpoints this demo knows about.
_ALL_MODELS = [
    "Qwen/Qwen3-TTS-12Hz-0.6B-Base",
    "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
    "Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice",
    "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
    "Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign",
]

# The ACTIVE_MODELS env var (comma-separated model ids) narrows the selectable
# set; when unset or empty, every known model is offered.
_active_models_env = os.environ.get("ACTIVE_MODELS", "")
if _active_models_env:
    _allowed = {entry.strip() for entry in _active_models_env.split(",") if entry.strip()}
    AVAILABLE_MODELS = [model_id for model_id in _ALL_MODELS if model_id in _allowed]
else:
    AVAILABLE_MODELS = list(_ALL_MODELS)
BASE_DIR = Path(__file__).resolve().parent

# Writable location for downloaded demo assets (overridable via ASSET_DIR).
_ASSET_DIR = Path(os.environ.get("ASSET_DIR", "/tmp/qwen3-tts-assets"))
PRESET_TRANSCRIPTS = _ASSET_DIR / "samples" / "parity" / "icl_transcripts.txt"

# Built-in reference voices for cloning: (key, local wav path, UI label).
# These are exactly the original clone voices, in the original order.
PRESET_REFS = [
    ("ref_audio_3", _ASSET_DIR / "ref_audio_3.wav", "Clone 1"),
    ("ref_audio_2", _ASSET_DIR / "ref_audio_2.wav", "Clone 2"),
    ("ref_audio", _ASSET_DIR / "ref_audio.wav", "Clone 3"),
]

_GITHUB_RAW = "https://raw.githubusercontent.com/andimarafioti/faster-qwen3-tts/main"
# Remote fallback URL for each preset wav, keyed by preset id.
_PRESET_REMOTE = {key: f"{_GITHUB_RAW}/{key}.wav" for key in ("ref_audio", "ref_audio_2", "ref_audio_3")}
_TRANSCRIPT_REMOTE = f"{_GITHUB_RAW}/samples/parity/icl_transcripts.txt"
def _fetch_preset_assets() -> None:
    """Download preset wav files and transcripts from GitHub if not present locally."""
    import urllib.request

    _ASSET_DIR.mkdir(parents=True, exist_ok=True)
    PRESET_TRANSCRIPTS.parent.mkdir(parents=True, exist_ok=True)

    # Transcript file first; a failure is non-fatal (presets just lack ref text).
    if not PRESET_TRANSCRIPTS.exists():
        try:
            urllib.request.urlretrieve(_TRANSCRIPT_REMOTE, PRESET_TRANSCRIPTS)
        except Exception as exc:
            print(f"Warning: could not fetch transcripts: {exc}")

    # Then each preset wav that is missing locally and has a known remote URL.
    for preset_key, wav_path, _label in PRESET_REFS:
        if wav_path.exists() or preset_key not in _PRESET_REMOTE:
            continue
        try:
            urllib.request.urlretrieve(_PRESET_REMOTE[preset_key], wav_path)
            print(f"Downloaded {wav_path.name}")
        except Exception as exc:
            print(f"Warning: could not fetch {preset_key}: {exc}")
# preset id -> metadata dict (label, cached path, transcript, base64 audio);
# populated at import time by _load_preset_refs().
_preset_refs: dict[str, dict] = {}


def _load_preset_transcripts() -> dict[str, str]:
    """Parse the transcript file into ``{preset_key: transcript_text}``.

    Expected line format is ``key(optional note): transcript``; lines without
    a colon are skipped, and any parenthesized note after the key is dropped.
    """
    if not PRESET_TRANSCRIPTS.exists():
        return {}
    result: dict[str, str] = {}
    for raw_line in PRESET_TRANSCRIPTS.read_text(encoding="utf-8").splitlines():
        if ":" not in raw_line:
            continue
        head, _, body = raw_line.partition(":")
        result[head.split("(")[0].strip()] = body.strip()
    return result
def _load_preset_refs() -> None:
    """Populate ``_preset_refs`` from whichever preset wav files exist on disk."""
    transcripts = _load_preset_transcripts()
    for preset_id, wav_path, label in PRESET_REFS:
        if not wav_path.exists():
            continue
        wav_bytes = wav_path.read_bytes()
        _preset_refs[preset_id] = {
            "id": preset_id,
            "label": label,
            "filename": wav_path.name,
            # Stable temp-file copy, shared with the uploaded-reference cache.
            "path": _get_cached_ref_path(wav_bytes),
            "ref_text": transcripts.get(preset_id, ""),
            "audio_b64": base64.b64encode(wav_bytes).decode(),
        }
| app = FastAPI(title="Qwen3-TTS Demo") | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| _model_cache: OrderedDict[str, Qwen3TTSModel] = OrderedDict() | |
| _active_model_name: str | None = None | |
| _loading = False | |
| _ref_cache: dict[str, str] = {} | |
| _ref_cache_lock = threading.Lock() | |
| _parakeet = None | |
| _generation_lock = asyncio.Lock() | |
| _generation_waiters: int = 0 | |
| MAX_TEXT_CHARS = 1000 | |
| MAX_AUDIO_BYTES = 10 * 1024 * 1024 | |
| _AUDIO_TOO_LARGE_MSG = ( | |
| "Audio file too large ({size_mb:.1f} MB). " | |
| "Voice cloning works best with short clips under 1 minute — please upload a shorter recording." | |
| ) | |
| # ─── Helpers ────────────────────────────────────────────────────────────────── | |
# ─── Helpers ──────────────────────────────────────────────────────────────────
def _to_wav_b64(audio: np.ndarray, sr: int) -> str:
    """Encode a waveform as a base64 string of 16-bit PCM WAV bytes."""
    mono = audio if audio.dtype == np.float32 else audio.astype(np.float32)
    if mono.ndim > 1:
        mono = mono.squeeze()
    buf = io.BytesIO()
    sf.write(buf, mono, sr, format="WAV", subtype="PCM_16")
    return base64.b64encode(buf.getvalue()).decode()
| def _concat_audio(audio_list) -> np.ndarray: | |
| if isinstance(audio_list, np.ndarray): | |
| return audio_list.astype(np.float32).squeeze() | |
| parts =[np.array(a, dtype=np.float32).squeeze() for a in audio_list if len(a) > 0] | |
| return np.concatenate(parts) if parts else np.zeros(0, dtype=np.float32) | |
def _get_cached_ref_path(content: bytes) -> str:
    """Write *content* to a content-addressed temp wav file and return its path.

    Identical audio bytes always map to the same file, so repeated requests
    with the same reference clip reuse one on-disk copy instead of rewriting.
    """
    digest = hashlib.sha1(content).hexdigest()
    with _ref_cache_lock:
        known = _ref_cache.get(digest)
        if known and os.path.exists(known):
            return known
        target = Path(tempfile.gettempdir()) / f"qwen3_tts_ref_{digest}.wav"
        if not target.exists():
            target.write_bytes(content)
        _ref_cache[digest] = str(target)
        return str(target)
# ─── Routes ───────────────────────────────────────────────────────────────────
# Materialize preset assets at import time so the routes below can serve them.
_fetch_preset_assets()
_load_preset_refs()
async def root():
    """Serve the single-page demo UI."""
    index_path = Path(__file__).parent / "index.html"
    return FileResponse(index_path)
async def transcribe_audio(audio: UploadFile = File(...)):
    """Transcribe an uploaded audio clip with the nano-parakeet ASR model.

    Returns ``{"text": transcript}``. Responds 503 when the ASR model is not
    loaded and 400 when the upload exceeds the size cap.
    """
    if _parakeet is None:
        raise HTTPException(status_code=503, detail="Transcription model not loaded")
    content = await audio.read()
    if len(content) > MAX_AUDIO_BYTES:
        raise HTTPException(
            status_code=400,
            detail=_AUDIO_TOO_LARGE_MSG.format(size_mb=len(content) / 1024 / 1024),
        )

    def run():
        # Decode, downmix to mono, and resample to the 16 kHz the ASR expects.
        wav, sr = sf.read(io.BytesIO(content), dtype="float32", always_2d=False)
        if wav.ndim > 1:
            wav = wav.mean(axis=1)
        wav_t = torch.from_numpy(wav)
        if sr != 16000:
            wav_t = torchaudio.functional.resample(wav_t.unsqueeze(0), sr, 16000).squeeze(0)
        return _parakeet.transcribe(wav_t)

    # Blocking decode + inference runs off the event loop.
    return {"text": await asyncio.to_thread(run)}
async def get_status():
    """Report model/loading state, available voices, and queue depth to the UI."""
    active = _model_cache.get(_active_model_name) if _active_model_name else None
    model_type = None
    speakers = []
    if active is not None:
        model_type = "official"
        try:
            speakers = active.get_supported_speakers() or []
        except Exception:
            # Some checkpoints expose no fixed speaker list; report none.
            speakers = []
    return {
        "loaded": active is not None,
        "model": _active_model_name,
        "loading": _loading,
        "available_models": AVAILABLE_MODELS,
        "model_type": model_type,
        "speakers": speakers,
        "transcription_available": _parakeet is not None,
        "preset_refs": [
            {"id": p["id"], "label": p["label"], "ref_text": p["ref_text"]}
            for p in _preset_refs.values()
        ],
        "queue_depth": _generation_waiters,
        "cached_models": list(_model_cache.keys()),
    }
async def get_preset_ref(preset_id: str):
    """Return full preset metadata (including base64 audio) for one preset id."""
    preset = _preset_refs.get(preset_id)
    if preset is None:
        raise HTTPException(status_code=404, detail="Preset not found")
    # Project only the public fields; internal keys like "path" stay private.
    return {field: preset[field] for field in ("id", "label", "filename", "ref_text", "audio_b64")}
async def load_model(model_id: str = Form(...)):
    """Load (or re-activate) a TTS model, evicting any cached model first.

    Only one model is kept in memory at a time to bound RAM usage. The load
    is serialized behind the generation lock so a swap never races an
    in-flight synthesis. Exceptions from ``from_pretrained`` propagate to
    the caller (FastAPI turns them into a 500).
    """
    global _active_model_name, _loading
    if model_id in _model_cache:
        _active_model_name = model_id
        return {"status": "already_loaded", "model": model_id}
    _loading = True

    def _do_load():
        global _active_model_name, _loading
        try:
            # CRITICAL RAM PROTECTION: destroy any previous model and force a
            # collection before allocating the new one.
            if _model_cache:
                _model_cache.clear()
                # Bug fix: also clear the active name. Previously a failed
                # load left _active_model_name pointing at the evicted model,
                # so /status reported a stale model id.
                _active_model_name = None
                gc.collect()
            new_model = Qwen3TTSModel.from_pretrained(
                model_id,
                device_map="cpu",
                dtype=torch.float32,
            )
            _model_cache[model_id] = new_model
            _active_model_name = model_id
            print(f"Modelo {model_id} cargado exitosamente en CPU.")
        finally:
            # Always drop the loading flag, even when from_pretrained raises.
            _loading = False

    async with _generation_lock:
        await asyncio.to_thread(_do_load)
    return {"status": "loaded", "model": model_id}
async def generate_stream(
    text: str = Form(...),
    language: str = Form("Spanish"),
    mode: str = Form("voice_clone"),
    ref_text: str = Form(""),
    speaker: str = Form(""),
    instruct: str = Form(""),
    xvec_only: bool = Form(True),
    chunk_size: int = Form(8),
    temperature: float = Form(0.9),
    top_k: int = Form(50),
    repetition_penalty: float = Form(1.05),
    ref_preset: str = Form(""),
    ref_audio: UploadFile = File(None),
):
    """Synthesize speech and stream the result as Server-Sent Events.

    Event sequence: optional ``queued`` (with position), then one ``chunk``
    (base64 WAV plus timing metrics), then ``done`` — or ``error`` with a
    traceback if generation fails. ``mode`` selects voice_clone / custom /
    voice design generation.
    """
    if not _active_model_name or _active_model_name not in _model_cache:
        raise HTTPException(status_code=400, detail="Model not loaded. Click 'Load' first.")
    if len(text) > MAX_TEXT_CHARS:
        raise HTTPException(status_code=400, detail="Text too long.")

    # Resolve the reference audio: a preset takes precedence over an upload.
    tmp_path = None
    tmp_is_cached = False
    if ref_preset and ref_preset in _preset_refs:
        preset = _preset_refs[ref_preset]
        tmp_path = preset["path"]
        tmp_is_cached = True
        if not ref_text:
            ref_text = preset["ref_text"]
    elif ref_audio and ref_audio.filename:
        content = await ref_audio.read()
        if len(content) > MAX_AUDIO_BYTES:
            # Bug fix: the template was previously raised unformatted, leaking
            # the literal "{size_mb:.1f}" placeholder to the client.
            raise HTTPException(
                status_code=400,
                detail=_AUDIO_TOO_LARGE_MSG.format(size_mb=len(content) / 1024 / 1024),
            )
        tmp_path = _get_cached_ref_path(content)
        tmp_is_cached = True

    # get_running_loop() replaces the deprecated get_event_loop() inside a coroutine.
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue[str | None] = asyncio.Queue()

    def run_generation():
        """Worker thread: run blocking synthesis and push SSE payloads to the queue."""
        try:
            model = _model_cache.get(_active_model_name)
            t0 = time.perf_counter()
            # On CPU with the official library we synthesize everything and emit
            # one chunk; this keeps the frontend stable (avoids "NoneType").
            if mode == "voice_clone":
                audio_list, sr = model.generate_voice_clone(
                    text=text, language=language, ref_audio=tmp_path, ref_text=ref_text,
                    x_vector_only_mode=xvec_only, temperature=temperature, top_k=top_k,
                    repetition_penalty=repetition_penalty, max_new_tokens=360, device="cpu"
                )
            elif mode == "custom":
                if not speaker:
                    raise ValueError("Speaker ID is required")
                audio_list, sr = model.generate_custom_voice(
                    text=text, speaker=speaker, language=language, instruct=instruct,
                    temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty,
                    max_new_tokens=360, device="cpu"
                )
            else:
                audio_list, sr = model.generate_voice_design(
                    text=text, instruct=instruct, language=language,
                    temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty,
                    max_new_tokens=360, device="cpu"
                )
            elapsed = time.perf_counter() - t0
            chunk_audio = _concat_audio(audio_list)
            dur = len(chunk_audio) / sr
            rtf = dur / elapsed if elapsed > 0 else 0.0
            ttfa_ms = round(elapsed * 1000)
            audio_b64 = _to_wav_b64(chunk_audio, sr)
            payload = {
                "type": "chunk", "audio_b64": audio_b64, "sample_rate": sr,
                "ttfa_ms": ttfa_ms, "voice_clone_ms": 0, "rtf": round(rtf, 3),
                "total_audio_s": round(dur, 3), "elapsed_ms": ttfa_ms
            }
            loop.call_soon_threadsafe(queue.put_nowait, json.dumps(payload))
            done_payload = {
                "type": "done", "ttfa_ms": ttfa_ms, "voice_clone_ms": 0,
                "rtf": round(rtf, 3), "total_audio_s": round(dur, 3), "total_ms": ttfa_ms
            }
            loop.call_soon_threadsafe(queue.put_nowait, json.dumps(done_payload))
        except Exception as e:
            import traceback
            err = {"type": "error", "message": str(e), "detail": traceback.format_exc()}
            loop.call_soon_threadsafe(queue.put_nowait, json.dumps(err))
        finally:
            # None is the sentinel that ends the SSE loop; only one-off temp
            # files are deleted (content-addressed cached refs persist).
            loop.call_soon_threadsafe(queue.put_nowait, None)
            if tmp_path and os.path.exists(tmp_path) and not tmp_is_cached:
                os.unlink(tmp_path)

    async def sse():
        """Yield SSE frames, enforcing one generation at a time via the lock."""
        global _generation_waiters
        lock_acquired = False
        _generation_waiters += 1
        people_ahead = _generation_waiters - 1 + (1 if _generation_lock.locked() else 0)
        try:
            if people_ahead > 0:
                yield f"data: {json.dumps({'type': 'queued', 'position': people_ahead})}\n\n"
            await _generation_lock.acquire()
            lock_acquired = True
            _generation_waiters -= 1
            thread = threading.Thread(target=run_generation, daemon=True)
            thread.start()
            while True:
                msg = await queue.get()
                if msg is None:
                    break
                yield f"data: {msg}\n\n"
        except asyncio.CancelledError:
            # Client disconnected; the daemon worker finishes in the background.
            pass
        finally:
            if lock_acquired:
                _generation_lock.release()
            else:
                # We never got the lock, so undo our waiter registration.
                _generation_waiters -= 1

    return StreamingResponse(
        sse(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
async def generate_non_streaming(
    text: str = Form(...), language: str = Form("Spanish"), mode: str = Form("voice_clone"),
    ref_text: str = Form(""), speaker: str = Form(""), instruct: str = Form(""),
    xvec_only: bool = Form(True), temperature: float = Form(0.9), top_k: int = Form(50),
    repetition_penalty: float = Form(1.05), ref_preset: str = Form(""), ref_audio: UploadFile = File(None),
):
    """Synthesize speech and return the complete result as one JSON response.

    Mirrors ``generate_stream`` (same validation and generation paths) but
    blocks until the audio is finished instead of streaming SSE events.
    Returns ``{"audio_b64", "sample_rate", "metrics"}``.
    """
    if not _active_model_name or _active_model_name not in _model_cache:
        raise HTTPException(status_code=400, detail="Model not loaded. Click 'Load' first.")
    # Consistency fix: generate_stream enforced these limits but this route
    # did not — reject oversized text, oversized uploads, and a missing
    # speaker with proper 400s instead of failing deep inside generation.
    if len(text) > MAX_TEXT_CHARS:
        raise HTTPException(status_code=400, detail="Text too long.")
    if mode == "custom" and not speaker:
        raise HTTPException(status_code=400, detail="Speaker ID is required")

    tmp_path = None
    tmp_is_cached = False
    if ref_preset and ref_preset in _preset_refs:
        preset = _preset_refs[ref_preset]
        tmp_path = preset["path"]
        tmp_is_cached = True
        # Consistency fix: fall back to the preset transcript like the
        # streaming route does when the caller supplies no ref_text.
        if not ref_text:
            ref_text = preset["ref_text"]
    elif ref_audio and ref_audio.filename:
        content = await ref_audio.read()
        if len(content) > MAX_AUDIO_BYTES:
            raise HTTPException(
                status_code=400,
                detail=_AUDIO_TOO_LARGE_MSG.format(size_mb=len(content) / 1024 / 1024),
            )
        tmp_path = _get_cached_ref_path(content)
        tmp_is_cached = True

    def run():
        """Blocking synthesis executed off the event loop."""
        model = _model_cache.get(_active_model_name)
        t0 = time.perf_counter()
        if mode == "voice_clone":
            audio_list, sr = model.generate_voice_clone(
                text=text, language=language, ref_audio=tmp_path, ref_text=ref_text,
                x_vector_only_mode=xvec_only, temperature=temperature, top_k=top_k,
                repetition_penalty=repetition_penalty, max_new_tokens=360, device="cpu"
            )
        elif mode == "custom":
            audio_list, sr = model.generate_custom_voice(
                text=text, speaker=speaker, language=language, instruct=instruct,
                temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty,
                max_new_tokens=360, device="cpu"
            )
        else:
            audio_list, sr = model.generate_voice_design(
                text=text, instruct=instruct, language=language,
                temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty,
                max_new_tokens=360, device="cpu"
            )
        elapsed = time.perf_counter() - t0
        audio = _concat_audio(audio_list)
        dur = len(audio) / sr
        return audio, sr, elapsed, dur

    global _generation_waiters
    _generation_waiters += 1
    lock_acquired = False
    try:
        await _generation_lock.acquire()
        lock_acquired = True
        _generation_waiters -= 1
        audio, sr, elapsed, dur = await asyncio.to_thread(run)
        rtf = dur / elapsed if elapsed > 0 else 0.0
        return JSONResponse({
            "audio_b64": _to_wav_b64(audio, sr),
            "sample_rate": sr,
            "metrics": {"total_ms": round(elapsed * 1000), "audio_duration_s": round(dur, 3), "rtf": round(rtf, 3)},
        })
    finally:
        if lock_acquired:
            _generation_lock.release()
        else:
            # We never got the lock, so undo our waiter registration.
            _generation_waiters -= 1
        # Remove one-off temp files; content-addressed cached refs persist.
        if tmp_path and os.path.exists(tmp_path) and not tmp_is_cached:
            os.unlink(tmp_path)
def main():
    """CLI entry point: parse arguments, optionally preload models, run uvicorn."""
    parser = argparse.ArgumentParser(description="Qwen3-TTS Demo Server")
    parser.add_argument("--model", default="Qwen/Qwen3-TTS-12Hz-0.6B-Base", help="Model to preload at startup")
    parser.add_argument("--port", type=int, default=int(os.environ.get("PORT", 7860)))
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--no-preload", action="store_true", help="Skip model loading at startup")
    args = parser.parse_args()

    if not args.no_preload:
        global _active_model_name, _parakeet
        # Warm both the TTS model and the ASR model before serving traffic.
        print(f"Loading model: {args.model}")
        _model_cache[args.model] = Qwen3TTSModel.from_pretrained(
            args.model, device_map="cpu", dtype=torch.float32
        )
        _active_model_name = args.model
        print("Loading transcription model (nano-parakeet)…")
        _parakeet = _parakeet_from_pretrained(device="cpu")
        print("Transcription model ready on CPU.")

    print(f"Ready. Open http://localhost:{args.port}")
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")


if __name__ == "__main__":
    main()