| """ |
| Advanced Mode 实验/阶段 Schema |
| |
| 专家用户分阶段训练模式的请求/响应模型 |
| |
| 参考文档: development.md 4.6.2 |
| """ |
|
|
| from datetime import datetime |
| from enum import Enum |
| from typing import Any, Dict, List, Literal, Optional |
| from pydantic import BaseModel, Field |
|
|
|
|
| |
| |
| |
|
|
class StageType(str, Enum):
    """Enumeration of the stages that make up the full training pipeline.

    Members also subclass ``str``, so each member compares equal to its
    plain string value (for example ``StageType.ASR == "asr"``).
    """

    AUDIO_SLICE = "audio_slice"        # audio slicing
    ASR = "asr"                        # speech recognition
    TEXT_FEATURE = "text_feature"      # text feature extraction
    HUBERT_FEATURE = "hubert_feature"  # HuBERT feature extraction
    SEMANTIC_TOKEN = "semantic_token"  # semantic token extraction
    SOVITS_TRAIN = "sovits_train"      # SoVITS training
    GPT_TRAIN = "gpt_train"            # GPT training
|
|
|
|
| |
# Maps each stage name to the list of stage names it depends on.
# NOTE(review): presumably the listed stages must finish before the keyed
# stage may run; the enforcement logic is not part of this module.
STAGE_DEPENDENCIES: Dict[str, List[str]] = {
    # Entry point: no prerequisites.
    "audio_slice": [],
    # Text branch.
    "asr": ["audio_slice"],
    "text_feature": ["asr"],
    # Audio-feature branch.
    "hubert_feature": ["audio_slice"],
    "semantic_token": ["hubert_feature"],
    # Both training stages list the same two prerequisites.
    "sovits_train": ["text_feature", "semantic_token"],
    "gpt_train": ["text_feature", "semantic_token"],
}
|
|
|
|
| |
| |
| |
|
|
class ExperimentCreate(BaseModel):
    """Request body for creating an experiment.

    The experiment is created but not executed right away; the user then
    controls the training pipeline one stage at a time.

    Attributes:
        exp_name: Experiment name (1 to 100 characters).
        version: Model version; one of ``v1``, ``v2``, ``v2Pro``, ``v3``, ``v4``.
        gpu_numbers: GPU indices as a string, comma-separated for several
            GPUs (e.g. ``"0,1"``).
        is_half: Whether to use half precision (FP16).
        audio_file_id: ID of an already-uploaded audio file.
    """

    exp_name: str = Field(..., min_length=1, max_length=100, description="实验名称")
    version: Literal["v1", "v2", "v2Pro", "v3", "v4"] = Field("v2", description="模型版本")
    gpu_numbers: str = Field(
        "0",
        description="GPU 编号,多个 GPU 用逗号分隔,如 '0,1'",
    )
    is_half: bool = Field(True, description="是否使用半精度(FP16),可节省显存")
    audio_file_id: str = Field(..., description="已上传音频文件的 ID")

    # Example payload attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "exp_name": "my_voice_custom",
                    "version": "v2",
                    "gpu_numbers": "0",
                    "is_half": True,
                    "audio_file_id": "550e8400-e29b-41d4-a716-446655440000",
                },
            ],
        },
    }
|
|
|
|
class ExperimentUpdate(BaseModel):
    """Request body for updating an experiment's base configuration.

    Covers the experiment-level settings only, not per-stage parameters.
    Every field is optional and defaults to ``None``.
    NOTE(review): ``None`` presumably means "leave unchanged" — confirm
    against the handler that consumes this model.
    """

    exp_name: Optional[str] = Field(None, min_length=1, max_length=100, description="实验名称")
    gpu_numbers: Optional[str] = Field(None, description="GPU 编号")
    is_half: Optional[bool] = Field(None, description="是否使用半精度")
|
|
|
|
class StageStatus(BaseModel):
    """Execution status and result of a single stage.

    Only ``stage_type`` is required; ``status`` defaults to ``"pending"``
    and every other field defaults to ``None``.
    """

    # Identity and lifecycle state.
    stage_type: str = Field(..., description="阶段类型")
    status: Literal["pending", "running", "completed", "failed", "cancelled"] = Field(
        "pending",
        description="阶段状态",
    )

    # Progress is a fraction constrained to the closed range [0.0, 1.0].
    progress: Optional[float] = Field(None, ge=0.0, le=1.0, description="阶段进度 (0.0-1.0)")
    message: Optional[str] = Field(None, description="状态消息")

    # Timing.
    started_at: Optional[datetime] = Field(None, description="开始时间")
    completed_at: Optional[datetime] = Field(None, description="完成时间")

    # Free-form input configuration and output payloads.
    config: Optional[Dict[str, Any]] = Field(None, description="阶段配置参数")
    outputs: Optional[Dict[str, Any]] = Field(None, description="阶段输出结果")
    error_message: Optional[str] = Field(None, description="错误消息(失败时)")

    # Example payload attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "stage_type": "sovits_train",
                    "status": "completed",
                    "progress": 1.0,
                    "message": "训练完成",
                    "started_at": "2024-01-01T10:30:00Z",
                    "completed_at": "2024-01-01T11:00:00Z",
                    "config": {"batch_size": 8, "total_epoch": 16},
                    "outputs": {
                        "model_path": "logs/my_voice/sovits_e16.pth",
                        "metrics": {"final_loss": 0.023},
                    },
                },
            ],
        },
    }
|
|
|
|
class ExperimentResponse(BaseModel):
    """Response model describing one experiment.

    Carries the experiment's full information together with the status of
    each of its stages, stored in a dict of ``StageStatus`` values.
    """

    # Identity and overall state.
    id: str = Field(..., description="实验唯一标识")
    exp_name: str = Field(..., description="实验名称")
    version: str = Field(..., description="模型版本")
    status: str = Field(..., description="实验状态")

    # Base configuration.
    gpu_numbers: str = Field("0", description="GPU 编号")
    is_half: bool = Field(True, description="是否使用半精度")
    audio_file_id: str = Field(..., description="音频文件 ID")

    # Per-stage status; defaults to an empty dict.
    stages: Dict[str, StageStatus] = Field(default_factory=dict, description="各阶段状态")

    # Timestamps.
    created_at: datetime = Field(..., description="创建时间")
    updated_at: Optional[datetime] = Field(None, description="更新时间")

    # Example payload attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "id": "exp-abc123",
                    "exp_name": "my_voice_custom",
                    "version": "v2",
                    "status": "created",
                    "gpu_numbers": "0",
                    "is_half": True,
                    "audio_file_id": "550e8400-e29b-41d4-a716-446655440000",
                    "stages": {
                        "audio_slice": {"stage_type": "audio_slice", "status": "pending"},
                        "asr": {"stage_type": "asr", "status": "pending"},
                        "sovits_train": {"stage_type": "sovits_train", "status": "pending"},
                    },
                    "created_at": "2024-01-01T10:00:00Z",
                },
            ],
        },
    }
|
|
|
|
class ExperimentListResponse(BaseModel):
    """Response model for a list of experiments.

    Holds the experiments themselves plus ``total``, ``limit`` and
    ``offset`` counters. ``limit`` is constrained to 1-100 (default 50);
    ``total`` and ``offset`` must be non-negative.
    """

    items: List[ExperimentResponse] = Field(default_factory=list, description="实验列表")
    total: int = Field(0, ge=0, description="总数量")
    limit: int = Field(50, ge=1, le=100, description="每页数量")
    offset: int = Field(0, ge=0, description="偏移量")
|
|
|
|
| |
| |
| |
|
|
class StageExecuteRequest(BaseModel):
    """Base class for stage-execution requests.

    Configured with ``extra = "allow"`` so that arbitrary additional
    parameters may be passed in alongside the declared fields.
    """

    model_config = {"extra": "allow"}
|
|
|
|
class AudioSliceParams(StageExecuteRequest):
    """Parameters for the audio-slicing stage.

    This stage cuts long audio into short segments.

    Reference: development.md 4.5.2
    """

    # Silence-detection threshold in dB, limited to [-60, 0].
    threshold: int = Field(-34, ge=-60, le=0, description="静音检测阈值 (dB)")
    # Minimum slice length in ms, limited to [1000, 10000].
    min_length: int = Field(4000, ge=1000, le=10000, description="最小切片长度 (ms)")
    # Minimum silence interval in ms, limited to [100, 1000].
    min_interval: int = Field(300, ge=100, le=1000, description="最小静音间隔 (ms)")
    # Detection hop size in ms, limited to [5, 50].
    hop_size: int = Field(10, ge=5, le=50, description="检测步长 (ms)")
    # Maximum silence kept in a slice in ms, limited to [100, 2000].
    max_sil_kept: int = Field(500, ge=100, le=2000, description="切片保留的最大静音长度 (ms)")

    # Example payload attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "threshold": -34,
                    "min_length": 4000,
                    "min_interval": 300,
                    "hop_size": 10,
                    "max_sil_kept": 500,
                },
            ],
        },
    }
|
|
|
|
class ASRParams(StageExecuteRequest):
    """Parameters for the ASR (speech recognition) stage."""

    model: str = Field("达摩 ASR (中文)", description="ASR 模型名称")
    language: str = Field("zh", description="识别语言")

    # Example payload attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {"model": "达摩 ASR (中文)", "language": "zh"},
            ],
        },
    }
|
|
|
|
class TextFeatureParams(StageExecuteRequest):
    """Parameters for the text feature extraction stage."""

    # Directory of the pretrained BERT model; None selects the default.
    bert_pretrained_dir: Optional[str] = Field(
        None,
        description="BERT 预训练模型目录,为空使用默认",
    )

    # Example payload attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {"bert_pretrained_dir": None},
            ],
        },
    }
|
|
|
|
class HubertFeatureParams(StageExecuteRequest):
    """Parameters for the HuBERT feature extraction stage."""

    # Directory of the pretrained SSL model; None selects the default.
    ssl_pretrained_dir: Optional[str] = Field(
        None,
        description="SSL 预训练模型目录,为空使用默认",
    )

    # Example payload attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {"ssl_pretrained_dir": None},
            ],
        },
    }
|
|
|
|
class SemanticTokenParams(StageExecuteRequest):
    """Parameters for the semantic token extraction stage.

    Declares no fields of its own; everything comes from
    ``StageExecuteRequest``.
    """
|
|
|
|
class SoVITSTrainParams(StageExecuteRequest):
    """Parameters for the SoVITS training stage.

    Reference: development.md 4.5.2
    """

    # Batch size in [1, 32]; reduce it when GPU memory is insufficient.
    batch_size: int = Field(4, ge=1, le=32, description="批次大小,显存不足时减小")
    # Total number of training epochs, in [1, 100].
    total_epoch: int = Field(8, ge=1, le=100, description="训练总轮数")
    # Save the model every N epochs (N >= 1, no upper bound declared).
    save_every_epoch: int = Field(4, ge=1, description="每 N 轮保存一次模型")
    # Pretrained generator / discriminator paths; None selects the default.
    pretrained_s2G: Optional[str] = Field(
        None,
        description="预训练生成器模型路径,为空使用默认",
    )
    pretrained_s2D: Optional[str] = Field(
        None,
        description="预训练判别器模型路径,为空使用默认",
    )

    # Example payload attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "batch_size": 8,
                    "total_epoch": 16,
                    "save_every_epoch": 4,
                    "pretrained_s2G": None,
                    "pretrained_s2D": None,
                },
            ],
        },
    }
|
|
|
|
class GPTTrainParams(StageExecuteRequest):
    """Parameters for the GPT training stage."""

    # Batch size, in [1, 32].
    batch_size: int = Field(4, ge=1, le=32, description="批次大小")
    # Total number of training epochs, in [1, 100].
    total_epoch: int = Field(15, ge=1, le=100, description="训练总轮数")
    # Save the model every N epochs (N >= 1, no upper bound declared).
    save_every_epoch: int = Field(5, ge=1, description="每 N 轮保存一次模型")
    # Pretrained model path; None selects the default.
    pretrained_s1: Optional[str] = Field(
        None,
        description="预训练模型路径,为空使用默认",
    )

    # Example payload attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "batch_size": 4,
                    "total_epoch": 15,
                    "save_every_epoch": 5,
                    "pretrained_s1": None,
                },
            ],
        },
    }
|
|
|
|
class StageExecuteResponse(BaseModel):
    """Response model for a stage-execution request.

    ``status`` is restricted to ``"running"`` or ``"queued"``.
    ``previous_run`` carries information about the prior execution when
    the stage is being re-run and is ``None`` by default.
    """

    # What was started and where.
    exp_id: str = Field(..., description="实验 ID")
    stage_type: str = Field(..., description="阶段类型")
    status: Literal["running", "queued"] = Field(..., description="执行状态")
    job_id: str = Field(..., description="作业 ID")

    # Configuration of the stage; defaults to an empty dict.
    config: Dict[str, Any] = Field(default_factory=dict, description="阶段配置")

    # Re-run bookkeeping.
    rerun: bool = Field(False, description="是否为重新执行")
    previous_run: Optional[Dict[str, Any]] = Field(
        None,
        description="上次执行的信息(重新执行时)",
    )

    started_at: datetime = Field(..., description="开始时间")

    # Example payload attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "exp_id": "exp-abc123",
                    "stage_type": "sovits_train",
                    "status": "running",
                    "job_id": "job-xyz789",
                    "config": {"batch_size": 8, "total_epoch": 16},
                    "rerun": False,
                    "started_at": "2024-01-01T10:30:00Z",
                },
            ],
        },
    }
|
|
|
|
class StagesListResponse(BaseModel):
    """Response model listing the status of all stages of one experiment."""

    exp_id: str = Field(..., description="实验 ID")
    # One StageStatus per stage; defaults to an empty list.
    stages: List[StageStatus] = Field(default_factory=list, description="阶段状态列表")

    # Example payload attached to the generated JSON schema.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "exp_id": "exp-abc123",
                    "stages": [
                        {"stage_type": "audio_slice", "status": "completed"},
                        {"stage_type": "asr", "status": "completed"},
                        {"stage_type": "sovits_train", "status": "running", "progress": 0.45},
                    ],
                },
            ],
        },
    }
|
|
|
|
| |
# Maps each stage name to the parameter model used for that stage.
# Keys are the plain string values of the StageType members, so this table
# stays aligned with the enum.
STAGE_PARAMS_MAP: Dict[str, type] = {
    StageType.AUDIO_SLICE.value: AudioSliceParams,
    StageType.ASR.value: ASRParams,
    StageType.TEXT_FEATURE.value: TextFeatureParams,
    StageType.HUBERT_FEATURE.value: HubertFeatureParams,
    StageType.SEMANTIC_TOKEN.value: SemanticTokenParams,
    StageType.SOVITS_TRAIN.value: SoVITSTrainParams,
    StageType.GPT_TRAIN.value: GPTTrainParams,
}
|
|
|
|
def get_stage_params_class(stage_type: str) -> type:
    """Look up the parameter class registered for a stage.

    Args:
        stage_type: Stage name; must be a key of ``STAGE_PARAMS_MAP``.

    Returns:
        The parameter class mapped to ``stage_type``.

    Raises:
        ValueError: If ``stage_type`` is not a key of ``STAGE_PARAMS_MAP``.
    """
    # Single lookup; every registered value is a class, so None can only
    # mean "key not present".
    params_cls = STAGE_PARAMS_MAP.get(stage_type)
    if params_cls is None:
        raise ValueError(f"Invalid stage type: {stage_type}")
    return params_cls
|
|