"""
Advanced Mode experiment / stage schemas.

Request and response models for the staged training workflow offered to
expert users: an experiment is created first, and each training stage is
then configured and executed individually.

Reference: development.md 4.6.2
"""

from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional

from pydantic import BaseModel, Field


# ============================================================
# Enumerations
# ============================================================


class StageType(str, Enum):
    """
    Identifiers of the training stages.

    Enumerates every stage of the complete training pipeline; each member's
    value is the plain string used as a key in the stage tables below.
    """

    AUDIO_SLICE = "audio_slice"  # audio slicing
    ASR = "asr"  # speech recognition
    TEXT_FEATURE = "text_feature"  # text feature extraction
    HUBERT_FEATURE = "hubert_feature"  # HuBERT feature extraction
    SEMANTIC_TOKEN = "semantic_token"  # semantic token extraction
    SOVITS_TRAIN = "sovits_train"  # SoVITS training
    GPT_TRAIN = "gpt_train"  # GPT training


# Stage dependency table: each stage's string value maps to the list of stage
# values it depends on. Built from StageType values so the keys cannot drift
# from the enum; every key and list entry is a plain ``str``.
STAGE_DEPENDENCIES: Dict[str, List[str]] = {
    StageType.AUDIO_SLICE.value: [],
    StageType.ASR.value: [StageType.AUDIO_SLICE.value],
    StageType.TEXT_FEATURE.value: [StageType.ASR.value],
    StageType.HUBERT_FEATURE.value: [StageType.AUDIO_SLICE.value],
    StageType.SEMANTIC_TOKEN.value: [StageType.HUBERT_FEATURE.value],
    StageType.SOVITS_TRAIN.value: [
        StageType.TEXT_FEATURE.value,
        StageType.SEMANTIC_TOKEN.value,
    ],
    StageType.GPT_TRAIN.value: [
        StageType.TEXT_FEATURE.value,
        StageType.SEMANTIC_TOKEN.value,
    ],
}


# ============================================================
# Experiment management
# ============================================================


class ExperimentCreate(BaseModel):
    """
    Request body for creating an experiment.

    The experiment is only created, not run; the user then drives the
    training pipeline one stage at a time.

    Attributes:
        exp_name: experiment name (1-100 characters)
        version: model version
        gpu_numbers: GPU index string
        is_half: whether half precision is used
        audio_file_id: ID of a previously uploaded audio file
    """

    exp_name: str = Field(
        ..., min_length=1, max_length=100, description="实验名称"
    )
    version: Literal["v1", "v2", "v2Pro", "v3", "v4"] = Field(
        "v2", description="模型版本"
    )
    gpu_numbers: str = Field(
        "0", description="GPU 编号,多个 GPU 用逗号分隔,如 '0,1'"
    )
    is_half: bool = Field(
        True, description="是否使用半精度(FP16),可节省显存"
    )
    audio_file_id: str = Field(..., description="已上传音频文件的 ID")

    model_config = dict(
        json_schema_extra={
            "examples": [
                {
                    "exp_name": "my_voice_custom",
                    "version": "v2",
                    "gpu_numbers": "0",
                    "is_half": True,
                    "audio_file_id": "550e8400-e29b-41d4-a716-446655440000",
                }
            ]
        }
    )


class ExperimentUpdate(BaseModel):
    """
    Request body for updating an experiment's base configuration.

    Covers experiment-level settings only, not per-stage parameters. Every
    field is optional and defaults to None.
    """

    exp_name: Optional[str] = Field(
        None, min_length=1, max_length=100, description="实验名称"
    )
    gpu_numbers: Optional[str] = Field(None, description="GPU 编号")
    is_half: Optional[bool] = Field(None, description="是否使用半精度")


class StageStatus(BaseModel):
    """
    Status of a single stage.

    Holds the execution state of one stage together with its config and
    outputs.
    """

    stage_type: str = Field(..., description="阶段类型")
    status: Literal["pending", "running", "completed", "failed", "cancelled"] = Field(
        "pending", description="阶段状态"
    )
    # Fraction complete, constrained to the closed interval [0.0, 1.0].
    progress: Optional[float] = Field(
        None, ge=0.0, le=1.0, description="阶段进度 (0.0-1.0)"
    )
    message: Optional[str] = Field(None, description="状态消息")
    started_at: Optional[datetime] = Field(None, description="开始时间")
    completed_at: Optional[datetime] = Field(None, description="完成时间")
    config: Optional[Dict[str, Any]] = Field(None, description="阶段配置参数")
    outputs: Optional[Dict[str, Any]] = Field(None, description="阶段输出结果")
    error_message: Optional[str] = Field(None, description="错误消息(失败时)")

    model_config = dict(
        json_schema_extra={
            "examples": [
                {
                    "stage_type": "sovits_train",
                    "status": "completed",
                    "progress": 1.0,
                    "message": "训练完成",
                    "started_at": "2024-01-01T10:30:00Z",
                    "completed_at": "2024-01-01T11:00:00Z",
                    "config": {"batch_size": 8, "total_epoch": 16},
                    "outputs": {
                        "model_path": "logs/my_voice/sovits_e16.pth",
                        "metrics": {"final_loss": 0.023},
                    },
                }
            ]
        }
    )


class ExperimentResponse(BaseModel):
    """
    Experiment response.

    Carries the full experiment record plus the status of each stage,
    keyed by stage name.
    """

    id: str = Field(..., description="实验唯一标识")
    exp_name: str = Field(..., description="实验名称")
    version: str = Field(..., description="模型版本")
    status: str = Field(..., description="实验状态")
    gpu_numbers: str = Field("0", description="GPU 编号")
    is_half: bool = Field(True, description="是否使用半精度")
    audio_file_id: str = Field(..., description="音频文件 ID")
    stages: Dict[str, StageStatus] = Field(
        default_factory=dict, description="各阶段状态"
    )
    created_at: datetime = Field(..., description="创建时间")
    updated_at: Optional[datetime] = Field(None, description="更新时间")

    model_config = dict(
        json_schema_extra={
            "examples": [
                {
                    "id": "exp-abc123",
                    "exp_name": "my_voice_custom",
                    "version": "v2",
                    "status": "created",
                    "gpu_numbers": "0",
                    "is_half": True,
                    "audio_file_id": "550e8400-e29b-41d4-a716-446655440000",
                    "stages": {
                        "audio_slice": {"stage_type": "audio_slice", "status": "pending"},
                        "asr": {"stage_type": "asr", "status": "pending"},
                        "sovits_train": {"stage_type": "sovits_train", "status": "pending"},
                    },
                    "created_at": "2024-01-01T10:00:00Z",
                }
            ]
        }
    )


class ExperimentListResponse(BaseModel):
    """
    Paginated list of experiments.
    """

    items: List[ExperimentResponse] = Field(
        default_factory=list, description="实验列表"
    )
    total: int = Field(0, ge=0, description="总数量")
    # Page size, bounded to 1..100.
    limit: int = Field(50, ge=1, le=100, description="每页数量")
    offset: int = Field(0, ge=0, description="偏移量")


# ============================================================
# Stage execution parameters
# ============================================================


class StageExecuteRequest(BaseModel):
    """
    Base class for stage execution requests.

    Undeclared fields are accepted and kept, so callers may pass
    stage-specific parameters beyond those a subclass declares.
    """

    # Pydantic merges a subclass's own model_config with this one, so every
    # *Params subclass below still allows extra fields.
    model_config = dict(extra="allow")


class AudioSliceParams(StageExecuteRequest):
    """
    Audio slicing parameters.

    Used to cut long recordings into short clips.
    Reference: development.md 4.5.2
    """

    threshold: int = Field(-34, ge=-60, le=0, description="静音检测阈值 (dB)")
    min_length: int = Field(
        4000, ge=1000, le=10000, description="最小切片长度 (ms)"
    )
    min_interval: int = Field(
        300, ge=100, le=1000, description="最小静音间隔 (ms)"
    )
    hop_size: int = Field(10, ge=5, le=50, description="检测步长 (ms)")
    max_sil_kept: int = Field(
        500, ge=100, le=2000, description="切片保留的最大静音长度 (ms)"
    )

    model_config = dict(
        json_schema_extra={
            "examples": [
                {
                    "threshold": -34,
                    "min_length": 4000,
                    "min_interval": 300,
                    "hop_size": 10,
                    "max_sil_kept": 500,
                }
            ]
        }
    )


class ASRParams(StageExecuteRequest):
    """
    Speech recognition (ASR) parameters.
    """

    model: str = Field("达摩 ASR (中文)", description="ASR 模型名称")
    language: str = Field("zh", description="识别语言")

    model_config = dict(
        json_schema_extra={
            "examples": [
                {"model": "达摩 ASR (中文)", "language": "zh"}
            ]
        }
    )


class TextFeatureParams(StageExecuteRequest):
    """
    Text feature extraction parameters.
    """

    bert_pretrained_dir: Optional[str] = Field(
        None, description="BERT 预训练模型目录,为空使用默认"
    )

    model_config = dict(
        json_schema_extra={"examples": [{"bert_pretrained_dir": None}]}
    )


class HubertFeatureParams(StageExecuteRequest):
    """
    HuBERT feature extraction parameters.
    """

    ssl_pretrained_dir: Optional[str] = Field(
        None, description="SSL 预训练模型目录,为空使用默认"
    )

    model_config = dict(
        json_schema_extra={"examples": [{"ssl_pretrained_dir": None}]}
    )


class SemanticTokenParams(StageExecuteRequest):
    """
    Semantic token extraction parameters.

    This stage declares no parameters of its own yet; the class exists so
    the stage has a dedicated params model that can be extended later.
    """


class SoVITSTrainParams(StageExecuteRequest):
    """
    SoVITS training parameters.

    Reference: development.md 4.5.2
    """

    batch_size: int = Field(
        4, ge=1, le=32, description="批次大小,显存不足时减小"
    )
    total_epoch: int = Field(8, ge=1, le=100, description="训练总轮数")
    save_every_epoch: int = Field(4, ge=1, description="每 N 轮保存一次模型")
    pretrained_s2G: Optional[str] = Field(
        None, description="预训练生成器模型路径,为空使用默认"
    )
    pretrained_s2D: Optional[str] = Field(
        None, description="预训练判别器模型路径,为空使用默认"
    )

    model_config = dict(
        json_schema_extra={
            "examples": [
                {
                    "batch_size": 8,
                    "total_epoch": 16,
                    "save_every_epoch": 4,
                    "pretrained_s2G": None,
                    "pretrained_s2D": None,
                }
            ]
        }
    )


class GPTTrainParams(StageExecuteRequest):
    """
    GPT training parameters.
    """

    batch_size: int = Field(4, ge=1, le=32, description="批次大小")
    total_epoch: int = Field(15, ge=1, le=100, description="训练总轮数")
    save_every_epoch: int = Field(5, ge=1, description="每 N 轮保存一次模型")
    pretrained_s1: Optional[str] = Field(
        None, description="预训练模型路径,为空使用默认"
    )

    model_config = dict(
        json_schema_extra={
            "examples": [
                {
                    "batch_size": 4,
                    "total_epoch": 15,
                    "save_every_epoch": 5,
                    "pretrained_s1": None,
                }
            ]
        }
    )


class StageExecuteResponse(BaseModel):
    """
    Response returned when a stage is submitted for execution.
    """

    exp_id: str = Field(..., description="实验 ID")
    stage_type: str = Field(..., description="阶段类型")
    status: Literal["running", "queued"] = Field(..., description="执行状态")
    job_id: str = Field(..., description="作业 ID")
    config: Dict[str, Any] = Field(default_factory=dict, description="阶段配置")
    rerun: bool = Field(False, description="是否为重新执行")
    previous_run: Optional[Dict[str, Any]] = Field(
        None, description="上次执行的信息(重新执行时)"
    )
    started_at: datetime = Field(..., description="开始时间")

    model_config = dict(
        json_schema_extra={
            "examples": [
                {
                    "exp_id": "exp-abc123",
                    "stage_type": "sovits_train",
                    "status": "running",
                    "job_id": "job-xyz789",
                    "config": {"batch_size": 8, "total_epoch": 16},
                    "rerun": False,
                    "started_at": "2024-01-01T10:30:00Z",
                }
            ]
        }
    )


class StagesListResponse(BaseModel):
    """
    Status of every stage of one experiment, as a list.
    """

    exp_id: str = Field(..., description="实验 ID")
    stages: List[StageStatus] = Field(
        default_factory=list, description="阶段状态列表"
    )

    model_config = dict(
        json_schema_extra={
            "examples": [
                {
                    "exp_id": "exp-abc123",
                    "stages": [
                        {"stage_type": "audio_slice", "status": "completed"},
                        {"stage_type": "asr", "status": "completed"},
                        {"stage_type": "sovits_train", "status": "running", "progress": 0.45},
                    ],
                }
            ]
        }
    )


# Stage value -> pydantic model class that validates that stage's parameters.
STAGE_PARAMS_MAP: Dict[str, type] = {
    StageType.AUDIO_SLICE.value: AudioSliceParams,
    StageType.ASR.value: ASRParams,
    StageType.TEXT_FEATURE.value: TextFeatureParams,
    StageType.HUBERT_FEATURE.value: HubertFeatureParams,
    StageType.SEMANTIC_TOKEN.value: SemanticTokenParams,
    StageType.SOVITS_TRAIN.value: SoVITSTrainParams,
    StageType.GPT_TRAIN.value: GPTTrainParams,
}


def get_stage_params_class(stage_type: str) -> type:
    """
    Return the parameter model class registered for a stage.

    Args:
        stage_type: stage identifier string, e.g. "asr"

    Returns:
        The pydantic class from STAGE_PARAMS_MAP for that stage.

    Raises:
        ValueError: if ``stage_type`` is not a key of STAGE_PARAMS_MAP.
    """
    params_cls = STAGE_PARAMS_MAP.get(stage_type)
    # Every registered value is a class, so None only means "not registered".
    if params_cls is None:
        raise ValueError(f"Invalid stage type: {stage_type}")
    return params_cls