import logging
from typing import Optional

from fastapi import HTTPException, Query, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from modules.api.Api import APIManager
from modules.api.impl.handler.TTSHandler import TTSHandler
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
from modules.api.impl.model.enhancer_model import EnhancerConfig
from modules.speaker import speaker_mgr
| |
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
| |
|
| |
|
class XTTS_V2_Settings:
    """Mutable defaults for the XTTS v2 compatible endpoints.

    The first group of attributes mirrors the settings an XTTS v2 client can
    send. The second group holds ChatTTS-specific extras (batching, end-of-
    sentence token, seed, prompts, style) that have no XTTS counterpart.

    NOTE(review): in this module only style/temperature/top_k/top_p/prefix/
    prompt1/prompt2/batch_size/spliter_threshold/eos/infer_seed/speed are read
    when building a synthesis request. stream_chunk_size, length_penalty,
    repetition_penalty, enable_text_splitting and use_decoder are stored but
    not read by the routes here.
    """

    def __init__(self):
        # --- XTTS v2 compatible settings ---
        self.stream_chunk_size: int = 100
        self.temperature: float = 0.3
        self.speed: float = 1
        self.length_penalty: float = 0.5
        self.repetition_penalty: float = 1.0
        self.top_p: float = 0.7
        self.top_k: int = 20
        self.enable_text_splitting: bool = True

        # --- ChatTTS-specific extras ---
        self.batch_size: int = 4
        self.eos: str = "[uv_break]"
        self.infer_seed: int = 42
        self.use_decoder: bool = True
        self.prompt1: str = ""
        self.prompt2: str = ""
        self.prefix: str = ""
        self.spliter_threshold: int = 100
        self.style: str = ""
| |
|
| |
|
class TTSSettingsRequest(BaseModel):
    """Body of ``POST /v1/xtts_v2/set_tts_settings``.

    The first group of fields mirrors the XTTS v2 settings API and is required.
    The second group holds ChatTTS-specific extras. These are optional and
    default to ``None``, which means "keep the current value".
    """

    # XTTS v2 compatible settings (required).
    stream_chunk_size: int
    temperature: float
    speed: float
    length_penalty: float
    repetition_penalty: float
    top_p: float
    top_k: int
    enable_text_splitting: bool

    # ChatTTS-specific extras (optional). They are declared Optional so that
    # an explicit JSON ``null`` validates the same as omitting the field.
    batch_size: Optional[int] = None
    eos: Optional[str] = None
    infer_seed: Optional[int] = None
    use_decoder: Optional[bool] = None
    prompt1: Optional[str] = None
    prompt2: Optional[str] = None
    prefix: Optional[str] = None
    spliter_threshold: Optional[int] = None
    style: Optional[str] = None
| |
|
| |
|
class SynthesisRequest(BaseModel):
    """Body of ``POST /v1/xtts_v2/tts_to_audio``."""

    # Text to synthesize.
    text: str
    # Despite the XTTS field name, this is resolved as a speaker id or
    # speaker-manager key, not read as a WAV file path.
    speaker_wav: str
    # Required by the schema but not read by the tts_to_audio route here.
    language: str
| |
|
| |
|
# Smallest stream_chunk_size accepted by set_tts_settings.
_MIN_STREAM_CHUNK_SIZE = 50

# Required numeric settings that must not be negative.
_NON_NEGATIVE_SETTING_FIELDS = (
    "temperature",
    "speed",
    "length_penalty",
    "repetition_penalty",
    "top_p",
    "top_k",
)

# Optional integer settings that, when sent, must be at least 1.
_POSITIVE_OPTIONAL_SETTING_FIELDS = ("batch_size", "spliter_threshold")

# Optional ChatTTS-specific settings: copied onto the shared settings object
# only when the client actually sends them (value is not None).
_OPTIONAL_SETTING_FIELDS = (
    "batch_size",
    "eos",
    "infer_seed",
    "use_decoder",
    "prompt1",
    "prompt2",
    "prefix",
    "spliter_threshold",
    "style",
)


def _resolve_speaker(voice_id):
    """Return the speaker for *voice_id*, or raise HTTP 400 if none matches.

    Tries ``speaker_mgr.get_speaker_by_id`` first. If that returns nothing,
    it falls back to ``speaker_mgr.get_speaker`` (presumably a lookup by
    name, to be confirmed against speaker_mgr).
    """
    spk = speaker_mgr.get_speaker_by_id(voice_id) or speaker_mgr.get_speaker(
        voice_id
    )
    if spk is None:
        raise HTTPException(status_code=400, detail="Invalid speaker id")
    return spk


def _build_handler(settings, text, spk):
    """Create a TTSHandler for *text* spoken by *spk*, using *settings* as they are now."""
    tts_config = ChatTTSConfig(
        style=settings.style,
        temperature=settings.temperature,
        top_k=settings.top_k,
        top_p=settings.top_p,
        prefix=settings.prefix,
        prompt1=settings.prompt1,
        prompt2=settings.prompt2,
    )
    infer_config = InferConfig(
        batch_size=settings.batch_size,
        spliter_threshold=settings.spliter_threshold,
        eos=settings.eos,
        seed=settings.infer_seed,
    )
    adjust_config = AdjustConfig(speed_rate=settings.speed)
    # This API exposes no enhancer options, so EnhancerConfig defaults are used.
    enhancer_config = EnhancerConfig()
    return TTSHandler(
        text_content=text,
        spk=spk,
        tts_config=tts_config,
        infer_config=infer_config,
        adjust_config=adjust_config,
        enhancer_config=enhancer_config,
    )


def _validate_settings(request):
    """Raise HTTP 400 for any out-of-range value before anything is applied."""
    if request.stream_chunk_size < _MIN_STREAM_CHUNK_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"stream_chunk_size must be at least {_MIN_STREAM_CHUNK_SIZE}",
        )
    for field in _NON_NEGATIVE_SETTING_FIELDS:
        if getattr(request, field) < 0:
            raise HTTPException(
                status_code=400, detail=f"{field} must not be negative"
            )
    for field in _POSITIVE_OPTIONAL_SETTING_FIELDS:
        value = getattr(request, field)
        if value is not None and value < 1:
            raise HTTPException(
                status_code=400, detail=f"{field} must be greater than 0"
            )


def _apply_settings(settings, request):
    """Copy an already-validated settings request onto *settings*."""
    settings.stream_chunk_size = request.stream_chunk_size
    # The XTTS chunk size doubles as the ChatTTS splitting threshold, unless
    # the client also sends an explicit spliter_threshold (applied below).
    settings.spliter_threshold = request.stream_chunk_size

    settings.temperature = request.temperature
    settings.speed = request.speed
    settings.length_penalty = request.length_penalty
    settings.repetition_penalty = request.repetition_penalty
    settings.top_p = request.top_p
    settings.top_k = request.top_k
    settings.enable_text_splitting = request.enable_text_splitting

    for field in _OPTIONAL_SETTING_FIELDS:
        value = getattr(request, field)
        # Use ``is not None`` rather than truthiness, so that False, 0 and ""
        # can be set explicitly (e.g. use_decoder=False, infer_seed=0).
        if value is not None:
            setattr(settings, field, value)


def setup(app: APIManager):
    """Register the XTTS v2 compatible routes on *app*.

    All routes share the single XTTS_V2_Settings instance created here.
    ``POST /v1/xtts_v2/set_tts_settings`` mutates it, and the synthesis
    routes read it each time they build a handler.
    """
    XTTSV2 = XTTS_V2_Settings()

    @app.get("/v1/xtts_v2/speakers")
    async def speakers():
        """List the available speakers as ``name`` / ``voice_id`` entries."""
        spks = speaker_mgr.list_speakers()
        return [
            {
                "name": spk.name,
                "voice_id": spk.id,
                # Always empty: no preview audio is served by this API.
                "preview_url": "",
            }
            for spk in spks
        ]

    @app.post("/v1/xtts_v2/tts_to_audio", response_class=StreamingResponse)
    async def tts_to_audio(request: SynthesisRequest):
        """Synthesize ``request.text`` and return it as an MP3 response.

        ``speaker_wav`` is resolved as a speaker id or key. ``language`` is
        accepted but not used.

        Raises:
            HTTPException: 400 if the speaker cannot be found.
        """
        spk = _resolve_speaker(request.speaker_wav)
        handler = _build_handler(XTTSV2, request.text, spk)

        # NOTE(review): enqueue_to_buffer is called synchronously inside an
        # async route. If it runs the whole synthesis, it blocks the event
        # loop meanwhile; confirm and consider a threadpool.
        buffer = handler.enqueue_to_buffer(AudioFormat.mp3)

        return StreamingResponse(buffer, media_type="audio/mpeg")

    @app.get("/v1/xtts_v2/tts_stream")
    async def tts_stream(
        request: Request,
        text: str = Query(),
        speaker_wav: str = Query(),
        language: str = Query(),
    ):
        """Stream MP3 chunks for *text* as they are produced.

        ``language`` is accepted for XTTS v2 compatibility but not used.

        Raises:
            HTTPException: 400 if the speaker cannot be found.
        """
        spk = _resolve_speaker(speaker_wav)
        handler = _build_handler(XTTSV2, text, spk)

        async def generator():
            # NOTE(review): enqueue_to_stream is a synchronous iterator, so
            # producing each chunk runs on the event-loop thread; confirm cost.
            for chunk in handler.enqueue_to_stream(AudioFormat.mp3):
                # Stop pulling chunks once the client has disconnected.
                if await request.is_disconnected():
                    break
                yield chunk

        return StreamingResponse(generator(), media_type="audio/mpeg")

    @app.post("/v1/xtts_v2/set_tts_settings")
    async def set_tts_settings(request: TTSSettingsRequest):
        """Validate *request* and apply it to the shared settings.

        Every value is validated before anything is mutated, so a rejected
        request leaves the settings untouched.

        Raises:
            HTTPException: 400 for out-of-range values; 500 for any
                unexpected error while applying them.
        """
        try:
            _validate_settings(request)
            _apply_settings(XTTSV2, request)
        except HTTPException:
            raise
        except Exception as e:
            logger.exception("Failed to apply xtts_v2 settings")
            raise HTTPException(status_code=500, detail=str(e)) from e

        return {"message": "Settings successfully applied"}
| |
|