| | """ |
| | Quick Mode 任务服务 |
| | |
| | 处理一键训练任务的业务逻辑 |
| | """ |
| |
|
| | import re |
| | import uuid |
| | from datetime import datetime |
| | from pathlib import Path |
| | from typing import AsyncGenerator, Dict, Optional, Any, Tuple |
| |
|
| | from project_config import settings |
| | from ..core.adapters import get_database_adapter, get_task_queue_adapter, get_storage_adapter |
| | from ..models.domain import Task, TaskStatus |
| | from ..models.schemas.task import ( |
| | QuickModeRequest, |
| | TaskResponse, |
| | TaskListResponse, |
| | InferenceOutputItem, |
| | InferenceOutputsResponse, |
| | ) |
| |
|
| | |
# Training presets selectable through ``options.quality``. Each preset gives the
# epoch counts for the SoVITS and GPT training stages plus a user-facing
# description of the expected duration.
QUALITY_PRESETS = dict(
    fast=dict(
        sovits_epochs=4,
        gpt_epochs=8,
        description="快速训练,约10分钟",
    ),
    standard=dict(
        sovits_epochs=8,
        gpt_epochs=15,
        description="标准训练,约20分钟",
    ),
    high=dict(
        sovits_epochs=16,
        gpt_epochs=30,
        description="高质量训练,约40分钟",
    ),
)
| |
|
| | |
# Default text synthesised by the optional "inference" stage when the request
# does not supply one, keyed by the training language code.
DEFAULT_TARGET_TEXTS = dict(
    zh="这是一段测试语音合成的文本。请你用自然、清晰、不过度夸张的语气朗读,并在逗号和句号处做适当停顿:先慢一点,再稍微快一点,最后恢复正常语速。",
    en="This is a test text for speech synthesis. Please read it naturally and clearly, without exaggeration, pausing appropriately at commas and periods: start slowly, then speed up a bit, and finally return to normal pace.",
    ja="これは音声合成のテストテキストです。自然で明瞭に、大げさにならないように朗読してください。読点と句点で適切に間を置いて:最初はゆっくり、少し速く、最後は普通の速さに戻してください。",
    ko="이것은 음성 합성을 위한 테스트 텍스트입니다. 자연스럽고 명확하게, 과장하지 않고 읽어주세요. 쉼표와 마침표에서 적절히 멈추며: 먼저 천천히, 그 다음 조금 빠르게, 마지막으로 보통 속도로 돌아오세요.",
    yue="呢段係測試語音合成嘅文字。請你用自然、清楚、唔好太誇張嘅語氣讀出嚟,喺逗號同句號嗰度要適當噉停一停:開頭慢啲,跟住快少少,最後返返正常語速。",
)
| |
|
| |
|
class TaskService:
    """
    Quick Mode task service.

    Manages the full lifecycle of one-click training tasks:
    - creating tasks
    - querying task status
    - cancelling tasks
    - subscribing to progress updates
    - listing and downloading inference outputs / model files

    The database, task-queue and storage adapters are resolved lazily on
    first use.

    Example:
        >>> service = TaskService()
        >>> task = await service.create_quick_task(request)
        >>> status = await service.get_task(task.id)
        >>> await service.cancel_task(task.id)
    """

    # Model version -> GPT weight directory name (inside the experiment dir).
    _GPT_WEIGHT_DIRS = {
        "v1": "GPT_weights",
        "v2": "GPT_weights_v2",
        "v3": "GPT_weights_v3",
        "v4": "GPT_weights_v4",
        "v2Pro": "GPT_weights_v2Pro",
        "v2ProPlus": "GPT_weights_v2ProPlus",
    }

    # Model version -> SoVITS weight directory name (inside the experiment dir).
    _SOVITS_WEIGHT_DIRS = {
        "v1": "SoVITS_weights",
        "v2": "SoVITS_weights_v2",
        "v3": "SoVITS_weights_v3",
        "v4": "SoVITS_weights_v4",
        "v2Pro": "SoVITS_weights_v2Pro",
        "v2ProPlus": "SoVITS_weights_v2ProPlus",
    }

    # Progress-event "status" values that end a progress subscription.
    _TERMINAL_PROGRESS_STATUSES = ("completed", "failed", "cancelled")

    def __init__(self):
        """Initialise the service; adapters are created on first access."""
        self._db = None
        self._queue = None
        self._storage = None

    @property
    def db(self):
        """Database adapter, obtained lazily."""
        if self._db is None:
            self._db = get_database_adapter()
        return self._db

    @property
    def queue(self):
        """Task-queue adapter, obtained lazily."""
        if self._queue is None:
            self._queue = get_task_queue_adapter()
        return self._queue

    @property
    def storage(self):
        """Storage adapter, obtained lazily."""
        if self._storage is None:
            self._storage = get_storage_adapter()
        return self._storage

    async def check_exp_name_exists(self, exp_name: str) -> bool:
        """
        Check whether a task with the given experiment name already exists.

        Args:
            exp_name: Experiment name.

        Returns:
            True if a task with this experiment name exists, otherwise False.
        """
        existing_task = await self.db.get_task_by_exp_name(exp_name)
        return existing_task is not None

    async def validate_audio_file(self, audio_file_id: str) -> tuple[bool, str]:
        """
        Check whether an audio file exists.

        ``audio_file_id`` is treated as a storage file ID when the storage
        adapter has metadata for it (the path is then resolved against
        ``storage.base_path``); otherwise it is treated as a filesystem path.

        Args:
            audio_file_id: Audio file ID or path.

        Returns:
            (exists, resolved file path)
        """
        import os

        file_metadata = await self.storage.get_file_metadata(audio_file_id)
        if file_metadata:
            audio_file_path = str(self.storage.base_path / audio_file_id)
        else:
            audio_file_path = audio_file_id
        return os.path.exists(audio_file_path), audio_file_path

    async def create_quick_task(self, request: QuickModeRequest) -> TaskResponse:
        """
        Create a one-click training task.

        Training parameters are derived from the request and the selected
        quality preset; the task is persisted and then enqueued.

        Args:
            request: Quick-mode request.

        Returns:
            TaskResponse for the newly created task.

        Raises:
            Whatever the queue adapter raises from ``enqueue``. In that case
            the persisted task is marked FAILED before the error propagates.
        """
        task_id = f"task-{uuid.uuid4().hex[:12]}"

        quality = request.options.quality
        # Unknown quality names fall back to the "standard" preset.
        preset = QUALITY_PRESETS.get(quality, QUALITY_PRESETS["standard"])

        audio_file_id = request.audio_file_id
        # Only the resolved path is used here; existence is not enforced.
        _, audio_file_path = await self.validate_audio_file(audio_file_id)

        stages = [
            "audio_slice",
            "asr",
            "text_feature",
            "hubert_feature",
            "semantic_token",
            "sovits_train",
            "gpt_train",
        ]

        inference_opts = request.options.inference
        # Inference runs by default unless the request explicitly disables it.
        inference_enabled = inference_opts is None or inference_opts.enabled
        if inference_enabled:
            stages.append("inference")

        config = {
            "exp_name": request.exp_name,
            "audio_file_id": audio_file_id,
            "input_path": audio_file_path,
            "version": request.options.version,
            "language": request.options.language,
            "quality": quality,
            # Epoch settings from the preset.
            "total_epoch": preset["sovits_epochs"],
            "sovits_epochs": preset["sovits_epochs"],
            "gpt_epochs": preset["gpt_epochs"],
            # Pretrained model locations.
            "bert_pretrained_dir": str(settings.BERT_PRETRAINED_DIR),
            "ssl_pretrained_dir": str(settings.SSL_PRETRAINED_DIR),
            "pretrained_s2G": str(settings.PRETRAINED_S2G),
            "pretrained_s2D": str(settings.PRETRAINED_S2D),
            "pretrained_s1": str(settings.PRETRAINED_S1),
            "stages": stages,
        }

        if inference_enabled:
            config.update(
                self._build_inference_config(inference_opts, request.options.language)
            )

        task = Task(
            id=task_id,
            exp_name=request.exp_name,
            config=config,
            status=TaskStatus.QUEUED,
            created_at=datetime.utcnow(),
        )

        await self.db.create_task(task)

        try:
            job_id = await self.queue.enqueue(task_id, config)
        except Exception as exc:
            # Without this the persisted task would remain QUEUED forever
            # with no job behind it; record the failure, then re-raise.
            await self.db.update_task(task_id, {
                "status": TaskStatus.FAILED,
                "completed_at": datetime.utcnow(),
                "error_message": f"任务入队失败: {exc}",
            })
            raise

        await self.db.update_task(task_id, {"job_id": job_id})
        task.job_id = job_id

        return self._task_to_response(task)

    @staticmethod
    def _build_inference_config(inference_opts: Any, language: Any) -> Dict[str, str]:
        """Build the ref_text / ref_audio_path / target_text entries for the inference stage."""
        default_text = DEFAULT_TARGET_TEXTS.get(language, DEFAULT_TARGET_TEXTS["zh"])
        if inference_opts is None:
            # No inference options at all: empty reference (filled from the
            # ASR list later) and the language's default target text.
            return {"ref_text": "", "ref_audio_path": "", "target_text": default_text}
        return {
            "ref_text": inference_opts.ref_text or "",
            "ref_audio_path": inference_opts.ref_audio_path or "",
            # Fall back to the default rather than storing None / "".
            "target_text": inference_opts.target_text or default_text,
        }

    async def get_task(self, task_id: str) -> Optional[TaskResponse]:
        """
        Get task details.

        Args:
            task_id: Task ID.

        Returns:
            TaskResponse, or None if the task does not exist.
        """
        task = await self.db.get_task(task_id)
        if not task:
            return None
        return self._task_to_response(task)

    async def list_tasks(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0
    ) -> TaskListResponse:
        """
        List tasks.

        Args:
            status: Optional status filter.
            limit: Page size.
            offset: Offset.

        Returns:
            TaskListResponse with the page items and the total count.
        """
        tasks = await self.db.list_tasks(status=status, limit=limit, offset=offset)
        total = await self.db.count_tasks(status=status)

        return TaskListResponse(
            items=[self._task_to_response(t) for t in tasks],
            total=total,
            limit=limit,
            offset=offset,
        )

    async def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a task.

        Only tasks that are QUEUED or RUNNING can be cancelled. The queue job
        (if any) is cancelled and the task is marked CANCELLED.

        Args:
            task_id: Task ID.

        Returns:
            True if the task was cancelled, False if it does not exist or is
            not in a cancellable state.
        """
        task = await self.db.get_task(task_id)
        if not task:
            return False

        cancellable = (TaskStatus.QUEUED.value, TaskStatus.RUNNING.value)
        # Compare by value so raw-string statuses are handled like enum members.
        if self._status_value(task.status) not in cancellable:
            return False

        if task.job_id:
            await self.queue.cancel(task.job_id)

        await self.db.update_task(task_id, {
            "status": TaskStatus.CANCELLED,
            "completed_at": datetime.utcnow(),
            "message": "任务已取消",
        })

        return True

    async def subscribe_progress(
        self,
        task_id: str
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Subscribe to task progress (SSE stream).

        Yields a single error event for unknown tasks and a single "final"
        event for tasks already in a terminal state. Otherwise it relays the
        queue's progress events until one reports a terminal status.

        Args:
            task_id: Task ID.

        Yields:
            Progress dictionaries.
        """
        task = await self.db.get_task(task_id)
        if not task:
            yield {"type": "error", "message": "任务不存在"}
            return

        status = self._status_value(task.status)
        terminal = (
            TaskStatus.COMPLETED.value,
            TaskStatus.FAILED.value,
            TaskStatus.CANCELLED.value,
        )
        if status in terminal:
            yield {
                "type": "final",
                "status": status,
                "message": task.message or task.error_message,
                "progress": task.progress,
            }
            return

        async for progress in self.queue.subscribe_progress(task_id):
            yield progress
            # Stop once the queue reports that the task has finished.
            if progress.get("status") in self._TERMINAL_PROGRESS_STATUSES:
                break

    async def get_inference_outputs(self, task_id: str) -> Optional[InferenceOutputsResponse]:
        """
        List a task's inference outputs.

        Scans ``{EXP_ROOT}/{exp_name}/inference/`` for ``*.wav`` files and
        returns their metadata together with the reference/target text used.

        Args:
            task_id: Task ID.

        Returns:
            InferenceOutputsResponse, or None if the task does not exist.
        """
        task = await self.db.get_task(task_id)
        if not task:
            return None

        exp_name = task.exp_name
        exp_dir = Path(settings.EXP_ROOT) / exp_name
        inference_dir = exp_dir / "inference"

        ref_text = task.config.get("ref_text", "")
        ref_audio_path = task.config.get("ref_audio_path", "")
        target_text = task.config.get("target_text") or ""

        # If either half of the reference pair is missing, take both from the
        # ASR list so the reference audio and its transcript stay matched.
        if not ref_text or not ref_audio_path:
            ref_audio_path, ref_text = self._parse_list_file(exp_name)

        version = task.config.get("version", "v2")
        gpt_dir = exp_dir / self._get_gpt_weight_dir(version)
        sovits_dir = exp_dir / self._get_sovits_weight_dir(version)

        outputs = []
        if inference_dir.is_dir():
            # Sorted for a stable listing order.
            for file_path in sorted(inference_dir.glob("*.wav")):
                filename = file_path.name
                # Names look like {exp_name}_gpt-{gpt}_sovits-{sovits}.wav.
                gpt_model, sovits_model = self._parse_inference_filename(filename, exp_name)
                stat = file_path.stat()

                outputs.append(InferenceOutputItem(
                    filename=filename,
                    gpt_model=gpt_model,
                    sovits_model=sovits_model,
                    gpt_path=str(gpt_dir / f"{gpt_model}.ckpt"),
                    sovits_path=str(sovits_dir / f"{sovits_model}.pth"),
                    file_path=self._project_relative(file_path),
                    size_bytes=stat.st_size,
                    created_at=datetime.fromtimestamp(stat.st_ctime),
                ))

        return InferenceOutputsResponse(
            task_id=task_id,
            exp_name=exp_name,
            ref_text=ref_text,
            ref_audio_path=ref_audio_path,
            target_text=target_text,
            outputs=outputs,
            total=len(outputs),
        )

    async def download_inference_output(
        self,
        task_id: str,
        filename: str
    ) -> Optional[Tuple[bytes, str, str]]:
        """
        Download one inference output file.

        Args:
            task_id: Task ID.
            filename: File name inside the task's ``inference`` directory.

        Returns:
            (file bytes, file name, content type), or None if the task or file
            does not exist or the name escapes the inference directory.
        """
        task = await self.db.get_task(task_id)
        if not task:
            return None

        inference_dir = Path(settings.EXP_ROOT) / task.exp_name / "inference"
        file_path = self._resolve_within(inference_dir, filename)
        if file_path is None:
            return None

        file_data = await self._read_bytes(file_path)
        return file_data, filename, "audio/wav"

    async def download_file(
        self,
        task_id: str,
        file_type: str,
        filename: str
    ) -> Optional[Tuple[bytes, str, str]]:
        """
        Download a file of the given type.

        ``output``, ``gpt_model`` and ``sovits_model`` files are looked up
        inside their directory under the task's experiment directory and may
        not escape it. ``ref_audio`` is addressed by path but is only served
        if it lies inside the experiment directory or is one of the task's own
        reference-audio paths.

        Args:
            task_id: Task ID.
            file_type: One of output / ref_audio / gpt_model / sovits_model.
            filename: File name (or path, for ref_audio).

        Returns:
            (file bytes, file name, content type), or None if the task or file
            does not exist, the type is unknown, or access is not allowed.
        """
        task = await self.db.get_task(task_id)
        if not task:
            return None

        exp_dir = Path(settings.EXP_ROOT) / task.exp_name
        version = task.config.get("version", "v2")

        if file_type == "ref_audio":
            file_path = self._resolve_ref_audio(task, filename)
            content_type = "audio/wav"
        else:
            # file_type -> (sub-directory of the experiment dir, content type)
            locations = {
                "output": ("inference", "audio/wav"),
                "gpt_model": (self._get_gpt_weight_dir(version), "application/octet-stream"),
                "sovits_model": (self._get_sovits_weight_dir(version), "application/octet-stream"),
            }
            if file_type not in locations:
                return None
            sub_dir, content_type = locations[file_type]
            file_path = self._resolve_within(exp_dir / sub_dir, filename)

        if file_path is None:
            return None

        file_data = await self._read_bytes(file_path)
        return file_data, file_path.name, content_type

    @staticmethod
    def _resolve_within(base_dir: Path, name: str) -> Optional[Path]:
        """Resolve ``base_dir / name``; return it only if it is an existing file inside base_dir."""
        try:
            base = base_dir.resolve()
            candidate = (base_dir / name).resolve()
        except (ValueError, OSError, RuntimeError):
            # e.g. embedded NUL bytes or symlink loops.
            return None
        # Component-wise containment: a plain string-prefix test would also
        # accept sibling directories such as ".../inference_other/".
        if not candidate.is_relative_to(base):
            return None
        if not candidate.is_file():
            return None
        return candidate

    def _resolve_ref_audio(self, task: Task, filename: str) -> Optional[Path]:
        """
        Resolve a reference-audio path requested for download.

        Allowed only if the path lies inside the task's experiment directory
        or equals one of the reference paths this task reports (config
        ``ref_audio_path`` or the first entry of the ASR list file).
        """
        try:
            candidate = Path(filename).resolve()
            exp_dir = (Path(settings.EXP_ROOT) / task.exp_name).resolve()
        except (ValueError, OSError, RuntimeError):
            return None

        allowed = candidate.is_relative_to(exp_dir)
        if not allowed:
            known_refs = (
                task.config.get("ref_audio_path", ""),
                self._parse_list_file(task.exp_name)[0],
            )
            for ref in known_refs:
                if not ref:
                    continue
                try:
                    if Path(ref).resolve() == candidate:
                        allowed = True
                        break
                except (ValueError, OSError, RuntimeError):
                    continue

        if not allowed or not candidate.is_file():
            return None
        return candidate

    @staticmethod
    async def _read_bytes(file_path: Path) -> bytes:
        """Read a file's bytes without blocking the event loop."""
        import asyncio

        # Model checkpoints can be large; read them in a worker thread.
        return await asyncio.to_thread(file_path.read_bytes)

    @staticmethod
    def _project_relative(path: Path) -> str:
        """Return ``path`` relative to PROJECT_ROOT, or unchanged if it is not under it."""
        try:
            return str(path.relative_to(settings.PROJECT_ROOT))
        except ValueError:
            # EXP_ROOT may live outside PROJECT_ROOT; don't fail the listing.
            return str(path)

    def _parse_inference_filename(self, filename: str, exp_name: str) -> Tuple[str, str]:
        """
        Extract the GPT and SoVITS model names from an inference output file name.

        Expected format: ``{exp_name}_gpt-{gpt_name}_sovits-{sovits_name}.wav``.
        Names that do not match fall back to a looser search, and "unknown" is
        returned for any part that cannot be found.

        Args:
            filename: File name.
            exp_name: Experiment name.

        Returns:
            (gpt_model, sovits_model)
        """
        # Strip the extension.
        name = filename.rsplit(".", 1)[0] if "." in filename else filename

        pattern = rf"^{re.escape(exp_name)}_gpt-(.+)_sovits-(.+)$"
        match = re.match(pattern, name)
        if match:
            return match.group(1), match.group(2)

        # Fallback: look for the gpt-/sovits- markers anywhere in the name.
        gpt_match = re.search(r"gpt-([^_]+)", name)
        sovits_match = re.search(r"sovits-([^_]+)", name)

        gpt_model = gpt_match.group(1) if gpt_match else "unknown"
        sovits_model = sovits_match.group(1) if sovits_match else "unknown"

        return gpt_model, sovits_model

    def _parse_list_file(self, exp_name: str) -> Tuple[str, str]:
        """
        Read ref_audio_path and ref_text from the first line of asr_opt/slicer_opt.list.

        The line is split on "|"; field 0 is the audio path and field 3 the
        transcript.

        Args:
            exp_name: Experiment name.

        Returns:
            (ref_audio_path, ref_text); both empty strings if the file is
            missing, unreadable, or its first line has fewer than 4 fields.
        """
        list_path = Path(settings.EXP_ROOT) / exp_name / "asr_opt" / "slicer_opt.list"
        try:
            with open(list_path, "r", encoding="utf-8") as f:
                first_line = f.readline().strip()
        except (OSError, UnicodeDecodeError):
            # Missing / unreadable list file means "no reference available".
            return "", ""

        parts = first_line.split("|")
        if len(parts) >= 4:
            return parts[0], parts[3]
        return "", ""

    def _get_gpt_weight_dir(self, version: str) -> str:
        """Return the GPT weight directory name for a model version (default: v2)."""
        return self._GPT_WEIGHT_DIRS.get(version, "GPT_weights_v2")

    def _get_sovits_weight_dir(self, version: str) -> str:
        """Return the SoVITS weight directory name for a model version (default: v2)."""
        return self._SOVITS_WEIGHT_DIRS.get(version, "SoVITS_weights_v2")

    @staticmethod
    def _status_value(status: Any) -> Any:
        """Return the plain value of a status that may be a TaskStatus member or a raw value."""
        return status.value if isinstance(status, TaskStatus) else status

    def _task_to_response(self, task: Task) -> TaskResponse:
        """Convert a Task domain model into a TaskResponse."""
        return TaskResponse(
            id=task.id,
            exp_name=task.exp_name,
            status=self._status_value(task.status),
            current_stage=task.current_stage,
            progress=task.stage_progress,
            overall_progress=task.progress,
            message=task.message,
            error_message=task.error_message,
            created_at=task.created_at,
            started_at=task.started_at,
            completed_at=task.completed_at,
        )
| |
|