MoYoYo.tts / api_server /app /models /schemas /experiment.py
liumaolin
feat(api): implement local training MVP with adapter pattern
e054d0c
"""
Advanced Mode 实验/阶段 Schema
专家用户分阶段训练模式的请求/响应模型
参考文档: development.md 4.6.2
"""
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional

from pydantic import BaseModel, Field, model_validator
# ============================================================
# 枚举类型
# ============================================================
class StageType(str, Enum):
    """
    Enumeration of every stage in the staged (Advanced Mode) training flow.

    Members mix in ``str``, so each one compares equal to its plain string
    value, e.g. ``StageType.ASR == "asr"``.
    """

    AUDIO_SLICE = "audio_slice"        # slice long audio into clips
    ASR = "asr"                        # speech recognition / transcription
    TEXT_FEATURE = "text_feature"      # text feature extraction
    HUBERT_FEATURE = "hubert_feature"  # HuBERT feature extraction
    SEMANTIC_TOKEN = "semantic_token"  # semantic token extraction
    SOVITS_TRAIN = "sovits_train"      # SoVITS model training
    GPT_TRAIN = "gpt_train"            # GPT model training
# Direct prerequisites of each stage, keyed by StageType value.
# NOTE(review): presumably enforced by the stage executor before a stage is
# started — the enforcement itself is not in this module.
STAGE_DEPENDENCIES: Dict[str, List[str]] = dict(
    audio_slice=[],
    asr=["audio_slice"],
    text_feature=["asr"],
    hubert_feature=["audio_slice"],
    semantic_token=["hubert_feature"],
    sovits_train=["text_feature", "semantic_token"],
    gpt_train=["text_feature", "semantic_token"],
)
# ============================================================
# 实验管理
# ============================================================
class ExperimentCreate(BaseModel):
    """
    Request to create an experiment.

    Creating an experiment does not start any work; the caller then drives
    the training pipeline one stage at a time.

    Attributes:
        exp_name: Experiment name, 1-100 characters. Forward slashes,
            backslashes and NUL characters are rejected, as are names made
            up only of dots (e.g. ``.`` or ``..``).
        version: Model version (``v1``, ``v2``, ``v2Pro``, ``v3`` or ``v4``).
        gpu_numbers: GPU index string; multiple GPUs are comma separated.
        is_half: Whether to use half precision (FP16).
        audio_file_id: ID of a previously uploaded audio file.
    """

    # NOTE(review): exp_name is assumed to become a directory name on disk
    # (example output paths look like ``logs/<exp_name>/...``) — confirm
    # against the training adapter. The pattern requires at least one
    # character that is neither a path separator, NUL nor a dot, so values
    # such as "a/../b", "..", or "." cannot escape the experiment directory.
    exp_name: str = Field(
        ...,
        min_length=1,
        max_length=100,
        pattern=r"^[^/\\\x00]*[^/\\\x00.][^/\\\x00]*$",
        description="实验名称"
    )
    version: Literal["v1", "v2", "v2Pro", "v3", "v4"] = Field(
        default="v2",
        description="模型版本"
    )
    gpu_numbers: str = Field(
        default="0",
        description="GPU 编号,多个 GPU 用逗号分隔,如 '0,1'"
    )
    is_half: bool = Field(
        default=True,
        description="是否使用半精度(FP16),可节省显存"
    )
    audio_file_id: str = Field(
        ...,
        description="已上传音频文件的 ID"
    )

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "exp_name": "my_voice_custom",
                    "version": "v2",
                    "gpu_numbers": "0",
                    "is_half": True,
                    "audio_file_id": "550e8400-e29b-41d4-a716-446655440000"
                }
            ]
        }
    }
class ExperimentUpdate(BaseModel):
    """
    Request to update an experiment's base settings (not stage parameters).

    Every field is optional; ``None`` means "leave unchanged" is presumably
    how the handler treats it — NOTE(review): confirm in the route handler.

    Attributes:
        exp_name: New experiment name, 1-100 characters. Forward slashes,
            backslashes and NUL characters are rejected, as are names made
            up only of dots (e.g. ``.`` or ``..``).
        gpu_numbers: New GPU index string.
        is_half: New half-precision setting.
    """

    # Same path-safety constraint as on experiment creation: a renamed
    # experiment must not be able to point outside its directory.
    exp_name: Optional[str] = Field(
        default=None,
        min_length=1,
        max_length=100,
        pattern=r"^[^/\\\x00]*[^/\\\x00.][^/\\\x00]*$",
        description="实验名称"
    )
    gpu_numbers: Optional[str] = Field(
        default=None,
        description="GPU 编号"
    )
    is_half: Optional[bool] = Field(
        default=None,
        description="是否使用半精度"
    )
class StageStatus(BaseModel):
    """
    Execution state of a single pipeline stage.

    Holds the lifecycle status, optional progress, start/finish timestamps,
    the configuration the stage ran with, its outputs, and an error message
    when the stage failed.
    """

    stage_type: str = Field(..., description="阶段类型")
    # Lifecycle state; defaults to "pending" for a stage that has not run.
    status: Literal["pending", "running", "completed", "failed", "cancelled"] = Field(
        "pending", description="阶段状态"
    )
    # Fractional progress constrained to [0.0, 1.0]; None when not reported.
    progress: Optional[float] = Field(
        None, ge=0.0, le=1.0, description="阶段进度 (0.0-1.0)"
    )
    message: Optional[str] = Field(None, description="状态消息")
    started_at: Optional[datetime] = Field(None, description="开始时间")
    completed_at: Optional[datetime] = Field(None, description="完成时间")
    config: Optional[Dict[str, Any]] = Field(None, description="阶段配置参数")
    outputs: Optional[Dict[str, Any]] = Field(None, description="阶段输出结果")
    error_message: Optional[str] = Field(None, description="错误消息(失败时)")

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "stage_type": "sovits_train",
                    "status": "completed",
                    "progress": 1.0,
                    "message": "训练完成",
                    "started_at": "2024-01-01T10:30:00Z",
                    "completed_at": "2024-01-01T11:00:00Z",
                    "config": {"batch_size": 8, "total_epoch": 16},
                    "outputs": {
                        "model_path": "logs/my_voice/sovits_e16.pth",
                        "metrics": {"final_loss": 0.023},
                    },
                }
            ]
        }
    }
class ExperimentResponse(BaseModel):
    """
    Full representation of an experiment as returned by the API.

    Combines the experiment's base settings with the status of each of its
    stages, keyed by stage type.
    """

    id: str = Field(..., description="实验唯一标识")
    exp_name: str = Field(..., description="实验名称")
    version: str = Field(..., description="模型版本")
    status: str = Field(..., description="实验状态")
    gpu_numbers: str = Field("0", description="GPU 编号")
    is_half: bool = Field(True, description="是否使用半精度")
    audio_file_id: str = Field(..., description="音频文件 ID")
    # Per-stage status keyed by stage type; a fresh empty dict per instance.
    stages: Dict[str, StageStatus] = Field(
        default_factory=dict, description="各阶段状态"
    )
    created_at: datetime = Field(..., description="创建时间")
    updated_at: Optional[datetime] = Field(None, description="更新时间")

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "id": "exp-abc123",
                    "exp_name": "my_voice_custom",
                    "version": "v2",
                    "status": "created",
                    "gpu_numbers": "0",
                    "is_half": True,
                    "audio_file_id": "550e8400-e29b-41d4-a716-446655440000",
                    "stages": {
                        "audio_slice": {"stage_type": "audio_slice", "status": "pending"},
                        "asr": {"stage_type": "asr", "status": "pending"},
                        "sovits_train": {"stage_type": "sovits_train", "status": "pending"},
                    },
                    "created_at": "2024-01-01T10:00:00Z",
                }
            ]
        }
    }
class ExperimentListResponse(BaseModel):
    """
    One page of experiments together with its pagination metadata.
    """

    items: List[ExperimentResponse] = Field(
        default_factory=list, description="实验列表"
    )
    # Presumably the overall count across all pages — confirm in the handler.
    total: int = Field(0, ge=0, description="总数量")
    # Page size, constrained to 1..100.
    limit: int = Field(50, ge=1, le=100, description="每页数量")
    offset: int = Field(0, ge=0, description="偏移量")
# ============================================================
# 阶段执行参数
# ============================================================
class StageExecuteRequest(BaseModel):
    """
    Common base for every stage-execution parameter model.

    Undeclared keys are kept instead of rejected (``extra="allow"``), so a
    request may carry stage-specific parameters beyond the declared fields.
    """

    model_config = dict(extra="allow")
class AudioSliceParams(StageExecuteRequest):
    """
    Parameters for the audio-slicing stage, which splits long recordings
    into short clips using silence detection.

    Reference: development.md 4.5.2
    """

    # Silence detection threshold in dB, within [-60, 0].
    threshold: int = Field(-34, ge=-60, le=0, description="静音检测阈值 (dB)")
    # Minimum clip length in milliseconds.
    min_length: int = Field(4000, ge=1000, le=10000, description="最小切片长度 (ms)")
    # Minimum silence gap in milliseconds.
    min_interval: int = Field(300, ge=100, le=1000, description="最小静音间隔 (ms)")
    # Detection hop size in milliseconds.
    hop_size: int = Field(10, ge=5, le=50, description="检测步长 (ms)")
    # Longest stretch of silence kept inside a clip, in milliseconds.
    max_sil_kept: int = Field(500, ge=100, le=2000, description="切片保留的最大静音长度 (ms)")

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "threshold": -34,
                    "min_length": 4000,
                    "min_interval": 300,
                    "hop_size": 10,
                    "max_sil_kept": 500,
                }
            ]
        }
    }
class ASRParams(StageExecuteRequest):
    """
    Parameters for the speech-recognition (ASR) stage.
    """

    # Display name of the ASR model to use.
    model: str = Field("达摩 ASR (中文)", description="ASR 模型名称")
    # Recognition language code.
    language: str = Field("zh", description="识别语言")

    model_config = {
        "json_schema_extra": {
            "examples": [{"model": "达摩 ASR (中文)", "language": "zh"}]
        }
    }
class TextFeatureParams(StageExecuteRequest):
    """
    Parameters for the text-feature extraction stage.
    """

    # BERT pretrained model directory; None selects the default.
    bert_pretrained_dir: Optional[str] = Field(
        None, description="BERT 预训练模型目录,为空使用默认"
    )

    model_config = {
        "json_schema_extra": {"examples": [{"bert_pretrained_dir": None}]}
    }
class HubertFeatureParams(StageExecuteRequest):
    """
    Parameters for the HuBERT feature extraction stage.
    """

    # SSL pretrained model directory; None selects the default.
    ssl_pretrained_dir: Optional[str] = Field(
        None, description="SSL 预训练模型目录,为空使用默认"
    )

    model_config = {
        "json_schema_extra": {"examples": [{"ssl_pretrained_dir": None}]}
    }
class SemanticTokenParams(StageExecuteRequest):
    """
    Parameters for the semantic-token extraction stage.

    The stage currently declares no fields of its own. The class exists as
    an extension point, and extra keys are still accepted through the base
    model's ``extra="allow"`` setting.
    """
class SoVITSTrainParams(StageExecuteRequest):
    """
    Parameters for the SoVITS training stage.

    Reference: development.md 4.5.2

    Attributes:
        batch_size: Batch size (1-32); lower it when GPU memory is short.
        total_epoch: Number of training epochs (1-100).
        save_every_epoch: Save a checkpoint every N epochs. Must not exceed
            ``total_epoch``, because a larger interval is never reached
            within the run.
        pretrained_s2G: Pretrained generator path; ``None`` uses the default.
        pretrained_s2D: Pretrained discriminator path; ``None`` uses the default.

    Raises:
        pydantic.ValidationError: On construction, if ``save_every_epoch``
            is greater than ``total_epoch`` (or any field bound is violated).
    """

    batch_size: int = Field(
        default=4,
        ge=1,
        le=32,
        description="批次大小,显存不足时减小"
    )
    total_epoch: int = Field(
        default=8,
        ge=1,
        le=100,
        description="训练总轮数"
    )
    save_every_epoch: int = Field(
        default=4,
        ge=1,
        description="每 N 轮保存一次模型"
    )
    pretrained_s2G: Optional[str] = Field(
        default=None,
        description="预训练生成器模型路径,为空使用默认"
    )
    pretrained_s2D: Optional[str] = Field(
        default=None,
        description="预训练判别器模型路径,为空使用默认"
    )

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "batch_size": 8,
                    "total_epoch": 16,
                    "save_every_epoch": 4,
                    "pretrained_s2G": None,
                    "pretrained_s2D": None
                }
            ]
        }
    }

    @model_validator(mode="after")
    def check_save_interval(self) -> "SoVITSTrainParams":
        """Reject a save interval longer than the whole training run."""
        # NOTE(review): if the trainer only saves on epochs divisible by
        # save_every_epoch, an interval > total_epoch would produce no
        # checkpoint at all — confirm against the training backend.
        if self.save_every_epoch > self.total_epoch:
            raise ValueError(
                f"save_every_epoch ({self.save_every_epoch}) must not exceed "
                f"total_epoch ({self.total_epoch})"
            )
        return self
class GPTTrainParams(StageExecuteRequest):
    """
    Parameters for the GPT training stage.

    Attributes:
        batch_size: Batch size (1-32).
        total_epoch: Number of training epochs (1-100).
        save_every_epoch: Save a checkpoint every N epochs. Must not exceed
            ``total_epoch``, because a larger interval is never reached
            within the run.
        pretrained_s1: Pretrained model path; ``None`` uses the default.

    Raises:
        pydantic.ValidationError: On construction, if ``save_every_epoch``
            is greater than ``total_epoch`` (or any field bound is violated).
    """

    batch_size: int = Field(
        default=4,
        ge=1,
        le=32,
        description="批次大小"
    )
    total_epoch: int = Field(
        default=15,
        ge=1,
        le=100,
        description="训练总轮数"
    )
    save_every_epoch: int = Field(
        default=5,
        ge=1,
        description="每 N 轮保存一次模型"
    )
    pretrained_s1: Optional[str] = Field(
        default=None,
        description="预训练模型路径,为空使用默认"
    )

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "batch_size": 4,
                    "total_epoch": 15,
                    "save_every_epoch": 5,
                    "pretrained_s1": None
                }
            ]
        }
    }

    @model_validator(mode="after")
    def check_save_interval(self) -> "GPTTrainParams":
        """Reject a save interval longer than the whole training run."""
        # NOTE(review): if the trainer only saves on epochs divisible by
        # save_every_epoch, an interval > total_epoch would produce no
        # checkpoint at all — confirm against the training backend.
        if self.save_every_epoch > self.total_epoch:
            raise ValueError(
                f"save_every_epoch ({self.save_every_epoch}) must not exceed "
                f"total_epoch ({self.total_epoch})"
            )
        return self
class StageExecuteResponse(BaseModel):
    """
    Reply returned once a stage has been submitted for execution.
    """

    exp_id: str = Field(..., description="实验 ID")
    stage_type: str = Field(..., description="阶段类型")
    # Either started immediately ("running") or waiting ("queued").
    status: Literal["running", "queued"] = Field(..., description="执行状态")
    job_id: str = Field(..., description="作业 ID")
    # Stage configuration; a fresh empty dict per instance when omitted.
    config: Dict[str, Any] = Field(default_factory=dict, description="阶段配置")
    # True when this execution repeats a stage that ran before.
    rerun: bool = Field(False, description="是否为重新执行")
    # Information about the previous run; populated for re-runs.
    previous_run: Optional[Dict[str, Any]] = Field(
        None, description="上次执行的信息(重新执行时)"
    )
    started_at: datetime = Field(..., description="开始时间")

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "exp_id": "exp-abc123",
                    "stage_type": "sovits_train",
                    "status": "running",
                    "job_id": "job-xyz789",
                    "config": {"batch_size": 8, "total_epoch": 16},
                    "rerun": False,
                    "started_at": "2024-01-01T10:30:00Z",
                }
            ]
        }
    }
class StagesListResponse(BaseModel):
    """
    Status of every stage belonging to one experiment.
    """

    exp_id: str = Field(..., description="实验 ID")
    # Stage statuses as a list; a fresh empty list per instance when omitted.
    stages: List[StageStatus] = Field(default_factory=list, description="阶段状态列表")

    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "exp_id": "exp-abc123",
                    "stages": [
                        {"stage_type": "audio_slice", "status": "completed"},
                        {"stage_type": "asr", "status": "completed"},
                        {"stage_type": "sovits_train", "status": "running", "progress": 0.45},
                    ],
                }
            ]
        }
    }
# Stage type value -> Pydantic model that validates that stage's parameters.
STAGE_PARAMS_MAP: Dict[str, type] = dict(
    audio_slice=AudioSliceParams,
    asr=ASRParams,
    text_feature=TextFeatureParams,
    hubert_feature=HubertFeatureParams,
    semantic_token=SemanticTokenParams,
    sovits_train=SoVITSTrainParams,
    gpt_train=GPTTrainParams,
)
def get_stage_params_class(stage_type: str) -> type:
    """
    Look up the parameter model registered for a stage.

    Args:
        stage_type: Stage type value, e.g. ``"sovits_train"``.

    Returns:
        The Pydantic model class mapped to ``stage_type`` in
        ``STAGE_PARAMS_MAP``.

    Raises:
        ValueError: If ``stage_type`` is not a registered stage.
    """
    # No mapped value is None, so a None result means "not registered".
    params_cls = STAGE_PARAMS_MAP.get(stage_type)
    if params_cls is None:
        raise ValueError(f"Invalid stage type: {stage_type}")
    return params_cls