# generaqtts / server.py — "Update server.py" (commit a74d59a, verified)
#!/usr/bin/env python3
"""
Qwen3-TTS Demo Server (Oficial)
CPU Optimized & RAM Safe
"""
import argparse
import asyncio
import base64
from collections import OrderedDict
import hashlib
import io
import json
import os
import sys
import tempfile
import threading
import time
import gc
from pathlib import Path
import numpy as np
import soundfile as sf
import torch
import torchaudio
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
# OPTIMIZACIÓN CPU
torch.set_num_threads(4)
# Allow running from any directory
sys.path.insert(0, str(Path(__file__).parent.parent))
# Import the official Qwen3-TTS wrapper; fail fast with a clear message if the
# package is missing so the server never starts half-configured.
try:
    from qwen_tts import Qwen3TTSModel
except ImportError:
    print("Error: qwen-tts no está instalado.")
    sys.exit(1)

# Lightweight CPU ASR model used by the /transcribe endpoint.
from nano_parakeet import from_pretrained as _parakeet_from_pretrained
_ALL_MODELS =[
"Qwen/Qwen3-TTS-12Hz-0.6B-Base",
"Qwen/Qwen3-TTS-12Hz-1.7B-Base",
"Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice",
"Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice",
"Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign",
]
_active_models_env = os.environ.get("ACTIVE_MODELS", "")
if _active_models_env:
_allowed = {m.strip() for m in _active_models_env.split(",") if m.strip()}
AVAILABLE_MODELS =[m for m in _ALL_MODELS if m in _allowed]
else:
AVAILABLE_MODELS = list(_ALL_MODELS)
BASE_DIR = Path(__file__).resolve().parent
# Where downloaded reference audio/transcripts live; overridable via ASSET_DIR.
_ASSET_DIR = Path(os.environ.get("ASSET_DIR", "/tmp/qwen3-tts-assets"))
PRESET_TRANSCRIPTS = _ASSET_DIR / "samples" / "parity" / "icl_transcripts.txt"
# Exactly the original clone voices restored: (key, wav path, UI label).
PRESET_REFS =[
("ref_audio_3", _ASSET_DIR / "ref_audio_3.wav", "Clone 1"),
("ref_audio_2", _ASSET_DIR / "ref_audio_2.wav", "Clone 2"),
("ref_audio", _ASSET_DIR / "ref_audio.wav", "Clone 3"),
]
# Remote fallbacks fetched at startup when the local assets are missing.
_GITHUB_RAW = "https://raw.githubusercontent.com/andimarafioti/faster-qwen3-tts/main"
_PRESET_REMOTE = {
"ref_audio": f"{_GITHUB_RAW}/ref_audio.wav",
"ref_audio_2": f"{_GITHUB_RAW}/ref_audio_2.wav",
"ref_audio_3": f"{_GITHUB_RAW}/ref_audio_3.wav",
}
_TRANSCRIPT_REMOTE = f"{_GITHUB_RAW}/samples/parity/icl_transcripts.txt"
def _fetch_preset_assets() -> None:
    """Download preset wav files and transcripts from GitHub if not present locally.

    Best-effort: each download failure is logged as a warning and skipped so
    the server still starts without the preset voices.
    """
    import urllib.request

    _ASSET_DIR.mkdir(parents=True, exist_ok=True)
    PRESET_TRANSCRIPTS.parent.mkdir(parents=True, exist_ok=True)
    if not PRESET_TRANSCRIPTS.exists():
        try:
            urllib.request.urlretrieve(_TRANSCRIPT_REMOTE, PRESET_TRANSCRIPTS)
        except Exception as e:
            print(f"Warning: could not fetch transcripts: {e}")
    for key, path, _ in PRESET_REFS:
        if not path.exists() and key in _PRESET_REMOTE:
            try:
                urllib.request.urlretrieve(_PRESET_REMOTE[key], path)
                print(f"Downloaded {path.name}")
            except Exception as e:
                print(f"Warning: could not fetch {key}: {e}")
# preset-id -> metadata dict for the bundled reference voices (filled at startup).
_preset_refs: dict[str, dict] = {}


def _load_preset_transcripts() -> dict[str, str]:
    """Parse icl_transcripts.txt into ``{preset_key: transcript}``.

    Each line looks like ``key(optional note): transcript text``; anything in
    parentheses after the key is dropped and lines without a colon are skipped.
    Returns an empty dict when the transcript file is absent.
    """
    if not PRESET_TRANSCRIPTS.exists():
        return {}
    transcripts = {}
    for line in PRESET_TRANSCRIPTS.read_text(encoding="utf-8").splitlines():
        if ":" not in line:
            continue
        key_part, text = line.split(":", 1)
        key = key_part.split("(")[0].strip()
        transcripts[key] = text.strip()
    return transcripts
def _load_preset_refs() -> None:
    """Populate ``_preset_refs`` from the wav files listed in PRESET_REFS.

    Missing files are skipped. The audio bytes are written to the content-hash
    temp cache (so generation can pass a file path) and also embedded as
    base64 so the frontend can preview each voice without another request.
    """
    transcripts = _load_preset_transcripts()
    for key, path, label in PRESET_REFS:
        if not path.exists():
            continue
        content = path.read_bytes()
        cached_path = _get_cached_ref_path(content)
        _preset_refs[key] = {
            "id": key,
            "label": label,
            "filename": path.name,
            "path": cached_path,
            "ref_text": transcripts.get(key, ""),
            "audio_b64": base64.b64encode(content).decode(),
        }
app = FastAPI(title="Qwen3-TTS Demo")
# Demo server: allow any origin so the static frontend can be hosted anywhere.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
# Loaded TTS models keyed by model id (cleared down to one entry on /load).
_model_cache: OrderedDict[str, Qwen3TTSModel] = OrderedDict()
# Id of the model generation requests should use; None until first load.
_active_model_name: str | None = None
# True while a /load request is running (reported by /status).
_loading = False
# sha1(content) -> temp-file path for deduplicated reference audio.
_ref_cache: dict[str, str] = {}
_ref_cache_lock = threading.Lock()
# nano-parakeet ASR model; None when transcription is unavailable.
_parakeet = None
# Serializes generation/loading so only one heavy job runs at a time.
_generation_lock = asyncio.Lock()
# Number of requests currently waiting on _generation_lock (queue depth).
_generation_waiters: int = 0
MAX_TEXT_CHARS = 1000
MAX_AUDIO_BYTES = 10 * 1024 * 1024
# Template for oversized uploads; must be filled via .format(size_mb=...).
_AUDIO_TOO_LARGE_MSG = (
"Audio file too large ({size_mb:.1f} MB). "
"Voice cloning works best with short clips under 1 minute — please upload a shorter recording."
)
# ─── Helpers ──────────────────────────────────────────────────────────────────
def _to_wav_b64(audio: np.ndarray, sr: int) -> str:
if audio.dtype != np.float32:
audio = audio.astype(np.float32)
if audio.ndim > 1:
audio = audio.squeeze()
buf = io.BytesIO()
sf.write(buf, audio, sr, format="WAV", subtype="PCM_16")
b64 = base64.b64encode(buf.getvalue()).decode()
return b64
def _concat_audio(audio_list) -> np.ndarray:
if isinstance(audio_list, np.ndarray):
return audio_list.astype(np.float32).squeeze()
parts =[np.array(a, dtype=np.float32).squeeze() for a in audio_list if len(a) > 0]
return np.concatenate(parts) if parts else np.zeros(0, dtype=np.float32)
def _get_cached_ref_path(content: bytes) -> str:
    """Write reference-audio bytes to a content-addressed temp file.

    Identical uploads map to the same SHA-1-named file, so repeated cloning
    with the same reference never rewrites the wav. The lock guards the
    in-process ``_ref_cache`` index against concurrent requests.
    """
    digest = hashlib.sha1(content).hexdigest()
    with _ref_cache_lock:
        cached = _ref_cache.get(digest)
        if cached and os.path.exists(cached):
            return cached
        tmp_dir = Path(tempfile.gettempdir())
        path = tmp_dir / f"qwen3_tts_ref_{digest}.wav"
        if not path.exists():
            path.write_bytes(content)
        _ref_cache[digest] = str(path)
        return str(path)
# ─── Routes ───────────────────────────────────────────────────────────────────
# Module-import-time side effects: fetch missing preset assets (network,
# best-effort) and index the preset reference voices before routes are served.
_fetch_preset_assets()
_load_preset_refs()
@app.get("/")
async def root():
    """Serve the single-page frontend that lives next to this script."""
    return FileResponse(Path(__file__).parent / "index.html")
@app.post("/transcribe")
async def transcribe_audio(audio: UploadFile = File(...)):
    """Transcribe an uploaded audio clip with the nano-parakeet ASR model.

    Returns ``{"text": ...}``; 503 when the ASR model was not preloaded and
    400 when the upload exceeds MAX_AUDIO_BYTES.
    """
    if _parakeet is None:
        raise HTTPException(status_code=503, detail="Transcription model not loaded")
    content = await audio.read()
    if len(content) > MAX_AUDIO_BYTES:
        raise HTTPException(
            status_code=400,
            detail=_AUDIO_TOO_LARGE_MSG.format(size_mb=len(content) / 1024 / 1024),
        )

    def run():
        # Decode, downmix to mono, and resample to the model's 16 kHz input rate.
        wav, sr = sf.read(io.BytesIO(content), dtype="float32", always_2d=False)
        if wav.ndim > 1:
            wav = wav.mean(axis=1)
        wav_t = torch.from_numpy(wav)
        if sr != 16000:
            wav_t = torchaudio.functional.resample(wav_t.unsqueeze(0), sr, 16000).squeeze(0)
        return _parakeet.transcribe(wav_t)

    # Run the CPU-heavy decode/ASR off the event loop.
    text = await asyncio.to_thread(run)
    return {"text": text}
@app.get("/status")
async def get_status():
    """Report server state: active model, queue depth, presets, capabilities."""
    speakers = []
    model_type = None
    active = _model_cache.get(_active_model_name) if _active_model_name else None
    if active is not None:
        try:
            model_type = "official"
            # Only CustomVoice checkpoints expose a speaker list; tolerate failures.
            speakers = active.get_supported_speakers() or []
        except Exception:
            speakers = []
    return {
        "loaded": active is not None,
        "model": _active_model_name,
        "loading": _loading,
        "available_models": AVAILABLE_MODELS,
        "model_type": model_type,
        "speakers": speakers,
        "transcription_available": _parakeet is not None,
        "preset_refs": [
            # Audio payloads are deliberately omitted; fetch via /preset_ref/{id}.
            {"id": p["id"], "label": p["label"], "ref_text": p["ref_text"]}
            for p in _preset_refs.values()
        ],
        "queue_depth": _generation_waiters,
        "cached_models": list(_model_cache.keys()),
    }
@app.get("/preset_ref/{preset_id}")
async def get_preset_ref(preset_id: str):
    """Return one preset reference voice, including its base64 wav payload."""
    preset = _preset_refs.get(preset_id)
    if not preset:
        raise HTTPException(status_code=404, detail="Preset not found")
    return {
        "id": preset["id"],
        "label": preset["label"],
        "filename": preset["filename"],
        "ref_text": preset["ref_text"],
        "audio_b64": preset["audio_b64"],
    }
@app.post("/load")
async def load_model(model_id: str = Form(...)):
    """Load ``model_id`` on CPU, evicting any previously loaded model first.

    Loading happens under the generation lock so it never races an in-flight
    synthesis; the blocking download/instantiation runs in a worker thread.
    """
    global _active_model_name, _loading
    if model_id in _model_cache:
        _active_model_name = model_id
        return {"status": "already_loaded", "model": model_id}
    _loading = True

    def _do_load():
        global _active_model_name, _loading
        try:
            # Critical RAM protection: destroy any previous model and force a
            # garbage-collection pass before allocating the new one.
            if len(_model_cache) > 0:
                _model_cache.clear()
                gc.collect()
            new_model = Qwen3TTSModel.from_pretrained(
                model_id,
                device_map="cpu",
                dtype=torch.float32,
            )
            _model_cache[model_id] = new_model
            _active_model_name = model_id
            print(f"Modelo {model_id} cargado exitosamente en CPU.")
        finally:
            # Always clear the loading flag, even if from_pretrained raised.
            _loading = False

    async with _generation_lock:
        await asyncio.to_thread(_do_load)
    return {"status": "loaded", "model": model_id}
@app.post("/generate/stream")
async def generate_stream(
    text: str = Form(...),
    language: str = Form("Spanish"),
    mode: str = Form("voice_clone"),
    ref_text: str = Form(""),
    speaker: str = Form(""),
    instruct: str = Form(""),
    xvec_only: bool = Form(True),
    chunk_size: int = Form(8),
    temperature: float = Form(0.9),
    top_k: int = Form(50),
    repetition_penalty: float = Form(1.05),
    ref_preset: str = Form(""),
    ref_audio: UploadFile = File(None),
):
    """Generate speech and stream the result to the client as SSE events.

    Emits ``queued`` (optional), one ``chunk`` with the full audio, then
    ``done`` — or ``error`` with a traceback. Requests are serialized through
    ``_generation_lock``; ``_generation_waiters`` tracks the queue depth.
    """
    if not _active_model_name or _active_model_name not in _model_cache:
        raise HTTPException(status_code=400, detail="Model not loaded. Click 'Load' first.")
    if len(text) > MAX_TEXT_CHARS:
        raise HTTPException(status_code=400, detail="Text too long.")

    # Resolve the reference audio: preset beats upload; both paths go through
    # the content-hash cache, so the file must NOT be deleted afterwards.
    tmp_path = None
    tmp_is_cached = False
    if ref_preset and ref_preset in _preset_refs:
        preset = _preset_refs[ref_preset]
        tmp_path = preset["path"]
        tmp_is_cached = True
        if not ref_text:
            ref_text = preset["ref_text"]
    elif ref_audio and ref_audio.filename:
        content = await ref_audio.read()
        if len(content) > MAX_AUDIO_BYTES:
            # BUG FIX: the template was previously sent unformatted, leaking
            # the literal "{size_mb:.1f}" placeholder to the client.
            raise HTTPException(
                status_code=400,
                detail=_AUDIO_TOO_LARGE_MSG.format(size_mb=len(content) / 1024 / 1024),
            )
        tmp_path = _get_cached_ref_path(content)
        tmp_is_cached = True

    # get_running_loop() is the non-deprecated form inside a coroutine.
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue[str | None] = asyncio.Queue()

    def run_generation():
        """Worker thread: synthesize, then push JSON payloads onto the queue."""
        try:
            model = _model_cache.get(_active_model_name)
            t0 = time.perf_counter()
            # On CPU with the official library we process everything and send a
            # single block, keeping the frontend stable and avoiding "NoneType".
            if mode == "voice_clone":
                audio_list, sr = model.generate_voice_clone(
                    text=text, language=language, ref_audio=tmp_path, ref_text=ref_text,
                    x_vector_only_mode=xvec_only, temperature=temperature, top_k=top_k,
                    repetition_penalty=repetition_penalty, max_new_tokens=360, device="cpu"
                )
            elif mode == "custom":
                if not speaker:
                    raise ValueError("Speaker ID is required")
                audio_list, sr = model.generate_custom_voice(
                    text=text, speaker=speaker, language=language, instruct=instruct,
                    temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty,
                    max_new_tokens=360, device="cpu"
                )
            else:
                audio_list, sr = model.generate_voice_design(
                    text=text, instruct=instruct, language=language,
                    temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty,
                    max_new_tokens=360, device="cpu"
                )
            elapsed = time.perf_counter() - t0
            chunk_audio = _concat_audio(audio_list)
            dur = len(chunk_audio) / sr
            rtf = dur / elapsed if elapsed > 0 else 0.0
            ttfa_ms = round(elapsed * 1000)
            audio_b64 = _to_wav_b64(chunk_audio, sr)
            payload = {
                "type": "chunk", "audio_b64": audio_b64, "sample_rate": sr,
                "ttfa_ms": ttfa_ms, "voice_clone_ms": 0, "rtf": round(rtf, 3),
                "total_audio_s": round(dur, 3), "elapsed_ms": ttfa_ms
            }
            loop.call_soon_threadsafe(queue.put_nowait, json.dumps(payload))
            done_payload = {
                "type": "done", "ttfa_ms": ttfa_ms, "voice_clone_ms": 0,
                "rtf": round(rtf, 3), "total_audio_s": round(dur, 3), "total_ms": ttfa_ms
            }
            loop.call_soon_threadsafe(queue.put_nowait, json.dumps(done_payload))
        except Exception as e:
            import traceback
            err = {"type": "error", "message": str(e), "detail": traceback.format_exc()}
            loop.call_soon_threadsafe(queue.put_nowait, json.dumps(err))
        finally:
            # Sentinel tells the SSE loop to stop; then drop any non-cached temp file.
            loop.call_soon_threadsafe(queue.put_nowait, None)
            if tmp_path and os.path.exists(tmp_path) and not tmp_is_cached:
                os.unlink(tmp_path)

    async def sse():
        """Async generator yielding SSE frames; manages queueing around the lock."""
        global _generation_waiters
        lock_acquired = False
        _generation_waiters += 1
        people_ahead = _generation_waiters - 1 + (1 if _generation_lock.locked() else 0)
        try:
            if people_ahead > 0:
                yield f"data: {json.dumps({'type': 'queued', 'position': people_ahead})}\n\n"
            await _generation_lock.acquire()
            lock_acquired = True
            _generation_waiters -= 1
            thread = threading.Thread(target=run_generation, daemon=True)
            thread.start()
            while True:
                msg = await queue.get()
                if msg is None:
                    break
                yield f"data: {msg}\n\n"
        except asyncio.CancelledError:
            pass
        finally:
            if lock_acquired:
                _generation_lock.release()
            else:
                # Never got the lock, so we are still counted as a waiter.
                _generation_waiters -= 1

    return StreamingResponse(sse(), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
@app.post("/generate")
async def generate_non_streaming(
    text: str = Form(...), language: str = Form("Spanish"), mode: str = Form("voice_clone"),
    ref_text: str = Form(""), speaker: str = Form(""), instruct: str = Form(""),
    xvec_only: bool = Form(True), temperature: float = Form(0.9), top_k: int = Form(50),
    repetition_penalty: float = Form(1.05), ref_preset: str = Form(""), ref_audio: UploadFile = File(None),
):
    """Blocking JSON variant of /generate/stream: returns the whole wav at once.

    Response: ``{"audio_b64", "sample_rate", "metrics"}``. Serialized through
    the same generation lock / waiter counter as the streaming endpoint.
    """
    if not _active_model_name or _active_model_name not in _model_cache:
        raise HTTPException(status_code=400, detail="Model not loaded. Click 'Load' first.")
    # Consistency with /generate/stream: enforce the same input limits here.
    if len(text) > MAX_TEXT_CHARS:
        raise HTTPException(status_code=400, detail="Text too long.")

    tmp_path = None
    tmp_is_cached = False
    if ref_preset and ref_preset in _preset_refs:
        preset = _preset_refs[ref_preset]
        tmp_path = preset["path"]
        tmp_is_cached = True
        # Fall back to the preset transcript when the caller sent none
        # (matches the streaming endpoint's behavior).
        if not ref_text:
            ref_text = preset["ref_text"]
    elif ref_audio and ref_audio.filename:
        content = await ref_audio.read()
        if len(content) > MAX_AUDIO_BYTES:
            raise HTTPException(
                status_code=400,
                detail=_AUDIO_TOO_LARGE_MSG.format(size_mb=len(content) / 1024 / 1024),
            )
        tmp_path = _get_cached_ref_path(content)
        tmp_is_cached = True

    def run():
        """Worker thread: synthesize and return (audio, sr, elapsed, duration)."""
        model = _model_cache.get(_active_model_name)
        t0 = time.perf_counter()
        if mode == "voice_clone":
            audio_list, sr = model.generate_voice_clone(
                text=text, language=language, ref_audio=tmp_path, ref_text=ref_text,
                x_vector_only_mode=xvec_only, temperature=temperature, top_k=top_k,
                repetition_penalty=repetition_penalty, max_new_tokens=360, device="cpu"
            )
        elif mode == "custom":
            audio_list, sr = model.generate_custom_voice(
                text=text, speaker=speaker, language=language, instruct=instruct,
                temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty,
                max_new_tokens=360, device="cpu"
            )
        else:
            audio_list, sr = model.generate_voice_design(
                text=text, instruct=instruct, language=language,
                temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty,
                max_new_tokens=360, device="cpu"
            )
        elapsed = time.perf_counter() - t0
        audio = _concat_audio(audio_list)
        dur = len(audio) / sr
        return audio, sr, elapsed, dur

    global _generation_waiters
    _generation_waiters += 1
    lock_acquired = False
    try:
        await _generation_lock.acquire()
        lock_acquired = True
        _generation_waiters -= 1
        audio, sr, elapsed, dur = await asyncio.to_thread(run)
        rtf = dur / elapsed if elapsed > 0 else 0.0
        return JSONResponse({
            "audio_b64": _to_wav_b64(audio, sr),
            "sample_rate": sr,
            "metrics": {"total_ms": round(elapsed * 1000), "audio_duration_s": round(dur, 3), "rtf": round(rtf, 3)},
        })
    finally:
        if lock_acquired:
            _generation_lock.release()
        else:
            _generation_waiters -= 1
        # BUG FIX: this cleanup previously sat after the return inside the try
        # body and was unreachable; run it here so non-cached temp files are
        # always removed.
        if tmp_path and os.path.exists(tmp_path) and not tmp_is_cached:
            os.unlink(tmp_path)
def main():
    """Parse CLI options, optionally preload the TTS/ASR models, run uvicorn."""
    global _active_model_name, _parakeet
    parser = argparse.ArgumentParser(description="Qwen3-TTS Demo Server")
    parser.add_argument("--model", default="Qwen/Qwen3-TTS-12Hz-0.6B-Base", help="Model to preload at startup")
    parser.add_argument("--port", type=int, default=int(os.environ.get("PORT", 7860)))
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--no-preload", action="store_true", help="Skip model loading at startup")
    args = parser.parse_args()
    if not args.no_preload:
        print(f"Loading model: {args.model}")
        _startup_model = Qwen3TTSModel.from_pretrained(args.model, device_map="cpu", dtype=torch.float32)
        _model_cache[args.model] = _startup_model
        _active_model_name = args.model
        print("Loading transcription model (nano-parakeet)…")
        _parakeet = _parakeet_from_pretrained(device="cpu")
        print("Transcription model ready on CPU.")
    print(f"Ready. Open http://localhost:{args.port}")
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")


if __name__ == "__main__":
    main()