| | """ |
| | Advanced Mode 实验 API |
| | |
| | 专家用户分阶段训练 API 端点 |
| | |
| | API 列表: |
| | - POST /experiments 创建实验 |
| | - GET /experiments 获取实验列表 |
| | - GET /experiments/{exp_id} 获取实验详情 |
| | - PATCH /experiments/{exp_id} 更新实验配置 |
| | - DELETE /experiments/{exp_id} 删除实验 |
| | - POST /experiments/{exp_id}/stages/{stage_type} 执行阶段 |
| | - GET /experiments/{exp_id}/stages 获取所有阶段状态 |
| | - GET /experiments/{exp_id}/stages/{stage_type} 获取阶段详情 |
| | - DELETE /experiments/{exp_id}/stages/{stage_type} 取消阶段 |
| | - GET /experiments/{exp_id}/stages/{stage_type}/progress SSE 阶段进度 |
| | """ |
| |
|
| | import json |
| | from typing import Any, Dict, Optional |
| |
|
| | from fastapi import APIRouter, Body, Depends, HTTPException, Query |
| | from fastapi.responses import StreamingResponse |
| |
|
| | from ....models.schemas.experiment import ( |
| | ExperimentCreate, |
| | ExperimentUpdate, |
| | ExperimentResponse, |
| | ExperimentListResponse, |
| | StageStatus, |
| | StageExecuteResponse, |
| | StagesListResponse, |
| | STAGE_DEPENDENCIES, |
| | get_stage_params_class, |
| | ) |
| | from ....models.schemas.common import SuccessResponse, ErrorResponse |
| | from ....services.experiment_service import ExperimentService |
| | from ...deps import get_experiment_service |
| |
|
| | router = APIRouter() |
| |
|
| | |
| | VALID_STAGE_TYPES = list(STAGE_DEPENDENCIES.keys()) |
| |
|
| |
|
| | @router.post( |
| | "", |
| | response_model=ExperimentResponse, |
| | summary="创建实验", |
| | description=""" |
| | 创建实验(专家用户)。 |
| | |
| | 创建实验但不立即执行,用户可以逐阶段控制训练流程。 |
| | 实验创建后,所有阶段状态为 `pending`,需要手动触发执行。 |
| | |
| | **训练阶段**: |
| | - `audio_slice`: 音频切片 |
| | - `asr`: 语音识别 |
| | - `text_feature`: 文本特征提取 |
| | - `hubert_feature`: HuBERT 特征提取 |
| | - `semantic_token`: 语义 Token 提取 |
| | - `sovits_train`: SoVITS 训练 |
| | - `gpt_train`: GPT 训练 |
| | """, |
| | ) |
| | async def create_experiment( |
| | request: ExperimentCreate, |
| | service: ExperimentService = Depends(get_experiment_service), |
| | ) -> ExperimentResponse: |
| | """ |
| | 创建实验 |
| | """ |
| | return await service.create_experiment(request) |
| |
|
| |
|
| | @router.get( |
| | "", |
| | response_model=ExperimentListResponse, |
| | summary="获取实验列表", |
| | description="获取所有实验列表,支持按状态筛选和分页。", |
| | ) |
| | async def list_experiments( |
| | status: Optional[str] = Query(None, description="按状态筛选"), |
| | limit: int = Query(50, ge=1, le=100, description="每页数量"), |
| | offset: int = Query(0, ge=0, description="偏移量"), |
| | service: ExperimentService = Depends(get_experiment_service), |
| | ) -> ExperimentListResponse: |
| | """ |
| | 获取实验列表 |
| | """ |
| | return await service.list_experiments(status=status, limit=limit, offset=offset) |
| |
|
| |
|
| | @router.get( |
| | "/{exp_id}", |
| | response_model=ExperimentResponse, |
| | summary="获取实验详情", |
| | description="获取指定实验的详细信息,包括所有阶段状态。", |
| | responses={ |
| | 404: {"model": ErrorResponse, "description": "实验不存在"}, |
| | }, |
| | ) |
| | async def get_experiment( |
| | exp_id: str, |
| | service: ExperimentService = Depends(get_experiment_service), |
| | ) -> ExperimentResponse: |
| | """ |
| | 获取实验详情 |
| | """ |
| | experiment = await service.get_experiment(exp_id) |
| | if not experiment: |
| | raise HTTPException(status_code=404, detail="实验不存在") |
| | return experiment |
| |
|
| |
|
| | @router.patch( |
| | "/{exp_id}", |
| | response_model=ExperimentResponse, |
| | summary="更新实验配置", |
| | description="更新实验的基础配置(非阶段参数)。", |
| | responses={ |
| | 404: {"model": ErrorResponse, "description": "实验不存在"}, |
| | }, |
| | ) |
| | async def update_experiment( |
| | exp_id: str, |
| | request: ExperimentUpdate, |
| | service: ExperimentService = Depends(get_experiment_service), |
| | ) -> ExperimentResponse: |
| | """ |
| | 更新实验配置 |
| | """ |
| | experiment = await service.update_experiment(exp_id, request) |
| | if not experiment: |
| | raise HTTPException(status_code=404, detail="实验不存在") |
| | return experiment |
| |
|
| |
|
| | @router.delete( |
| | "/{exp_id}", |
| | response_model=SuccessResponse, |
| | summary="删除实验", |
| | description="删除实验及其所有阶段数据。如果有正在运行的阶段,会先取消执行。", |
| | responses={ |
| | 404: {"model": ErrorResponse, "description": "实验不存在"}, |
| | }, |
| | ) |
| | async def delete_experiment( |
| | exp_id: str, |
| | service: ExperimentService = Depends(get_experiment_service), |
| | ) -> SuccessResponse: |
| | """ |
| | 删除实验 |
| | """ |
| | success = await service.delete_experiment(exp_id) |
| | if not success: |
| | raise HTTPException(status_code=404, detail="实验不存在") |
| | return SuccessResponse(message="实验已删除") |
| |
|
| |
|
| | @router.post( |
| | "/{exp_id}/stages/{stage_type}", |
| | response_model=StageExecuteResponse, |
| | summary="执行阶段", |
| | description=""" |
| | 执行指定阶段。 |
| | |
| | **阶段依赖关系**: |
| | - `audio_slice`: 无依赖 |
| | - `asr`: 依赖 audio_slice |
| | - `text_feature`: 依赖 asr |
| | - `hubert_feature`: 依赖 audio_slice |
| | - `semantic_token`: 依赖 hubert_feature |
| | - `sovits_train`: 依赖 text_feature, semantic_token |
| | - `gpt_train`: 依赖 text_feature, semantic_token |
| | |
| | 如果依赖阶段未完成,会返回 400 错误。 |
| | 如果阶段已完成,会重新执行(返回 `rerun: true`)。 |
| | """, |
| | responses={ |
| | 400: {"model": ErrorResponse, "description": "阶段类型无效或依赖未满足"}, |
| | 404: {"model": ErrorResponse, "description": "实验不存在"}, |
| | }, |
| | ) |
| | async def execute_stage( |
| | exp_id: str, |
| | stage_type: str, |
| | params: Dict[str, Any] = Body(default={}), |
| | service: ExperimentService = Depends(get_experiment_service), |
| | ) -> StageExecuteResponse: |
| | """ |
| | 执行阶段 |
| | """ |
| | |
| | if stage_type not in VALID_STAGE_TYPES: |
| | raise HTTPException( |
| | status_code=400, |
| | detail=f"无效的阶段类型: {stage_type}。有效类型: {', '.join(VALID_STAGE_TYPES)}" |
| | ) |
| | |
| | |
| | experiment = await service.get_experiment(exp_id) |
| | if not experiment: |
| | raise HTTPException(status_code=404, detail="实验不存在") |
| | |
| | |
| | deps = await service.check_stage_dependencies(exp_id, stage_type) |
| | if not deps["satisfied"]: |
| | raise HTTPException( |
| | status_code=400, |
| | detail=f"依赖阶段未完成: {', '.join(deps['missing'])}" |
| | ) |
| | |
| | |
| | try: |
| | params_class = get_stage_params_class(stage_type) |
| | validated_params = params_class(**params) |
| | params = validated_params.model_dump(exclude_unset=True) |
| | except ValueError as e: |
| | raise HTTPException(status_code=400, detail=str(e)) |
| | |
| | |
| | result = await service.execute_stage(exp_id, stage_type, params) |
| | if not result: |
| | raise HTTPException(status_code=404, detail="实验不存在") |
| | |
| | return result |
| |
|
| |
|
| | @router.get( |
| | "/{exp_id}/stages", |
| | response_model=StagesListResponse, |
| | summary="获取所有阶段状态", |
| | description="获取实验的所有阶段状态列表。", |
| | responses={ |
| | 404: {"model": ErrorResponse, "description": "实验不存在"}, |
| | }, |
| | ) |
| | async def get_all_stages( |
| | exp_id: str, |
| | service: ExperimentService = Depends(get_experiment_service), |
| | ) -> StagesListResponse: |
| | """ |
| | 获取所有阶段状态 |
| | """ |
| | result = await service.get_all_stages(exp_id) |
| | if not result: |
| | raise HTTPException(status_code=404, detail="实验不存在") |
| | return result |
| |
|
| |
|
| | @router.get( |
| | "/{exp_id}/stages/{stage_type}", |
| | response_model=StageStatus, |
| | summary="获取阶段详情", |
| | description="获取指定阶段的详细状态和结果。", |
| | responses={ |
| | 400: {"model": ErrorResponse, "description": "阶段类型无效"}, |
| | 404: {"model": ErrorResponse, "description": "实验或阶段不存在"}, |
| | }, |
| | ) |
| | async def get_stage( |
| | exp_id: str, |
| | stage_type: str, |
| | service: ExperimentService = Depends(get_experiment_service), |
| | ) -> StageStatus: |
| | """ |
| | 获取阶段详情 |
| | """ |
| | |
| | if stage_type not in VALID_STAGE_TYPES: |
| | raise HTTPException( |
| | status_code=400, |
| | detail=f"无效的阶段类型: {stage_type}" |
| | ) |
| | |
| | stage = await service.get_stage(exp_id, stage_type) |
| | if not stage: |
| | raise HTTPException(status_code=404, detail="实验或阶段不存在") |
| | return stage |
| |
|
| |
|
| | @router.delete( |
| | "/{exp_id}/stages/{stage_type}", |
| | response_model=SuccessResponse, |
| | summary="取消阶段", |
| | description="取消正在执行的阶段。只有运行中的阶段可以取消。", |
| | responses={ |
| | 400: {"model": ErrorResponse, "description": "阶段未运行或无法取消"}, |
| | 404: {"model": ErrorResponse, "description": "实验或阶段不存在"}, |
| | }, |
| | ) |
| | async def cancel_stage( |
| | exp_id: str, |
| | stage_type: str, |
| | service: ExperimentService = Depends(get_experiment_service), |
| | ) -> SuccessResponse: |
| | """ |
| | 取消阶段 |
| | """ |
| | |
| | if stage_type not in VALID_STAGE_TYPES: |
| | raise HTTPException( |
| | status_code=400, |
| | detail=f"无效的阶段类型: {stage_type}" |
| | ) |
| | |
| | success = await service.cancel_stage(exp_id, stage_type) |
| | if not success: |
| | raise HTTPException( |
| | status_code=400, |
| | detail="阶段未运行或无法取消" |
| | ) |
| | |
| | return SuccessResponse(message=f"阶段 {stage_type} 已取消") |
| |
|
| |
|
| | @router.get( |
| | "/{exp_id}/stages/{stage_type}/progress", |
| | summary="SSE 阶段进度订阅", |
| | description=""" |
| | 订阅阶段进度更新(Server-Sent Events)。 |
| | |
| | 返回的事件流格式: |
| | ``` |
| | event: progress |
| | data: {"epoch": 8, "total_epochs": 16, "progress": 0.50, "loss": 0.034} |
| | |
| | event: checkpoint |
| | data: {"epoch": 8, "model_path": "logs/my_voice/sovits_e8.pth"} |
| | |
| | event: completed |
| | data: {"status": "completed", "final_loss": 0.023} |
| | ``` |
| | """, |
| | responses={ |
| | 400: {"model": ErrorResponse, "description": "阶段类型无效"}, |
| | 404: {"model": ErrorResponse, "description": "实验或阶段不存在"}, |
| | }, |
| | ) |
| | async def subscribe_stage_progress( |
| | exp_id: str, |
| | stage_type: str, |
| | service: ExperimentService = Depends(get_experiment_service), |
| | ) -> StreamingResponse: |
| | """ |
| | SSE 阶段进度订阅 |
| | """ |
| | |
| | if stage_type not in VALID_STAGE_TYPES: |
| | raise HTTPException( |
| | status_code=400, |
| | detail=f"无效的阶段类型: {stage_type}" |
| | ) |
| | |
| | |
| | experiment = await service.get_experiment(exp_id) |
| | if not experiment: |
| | raise HTTPException(status_code=404, detail="实验不存在") |
| | |
| | async def event_generator(): |
| | """生成 SSE 事件流""" |
| | async for progress in service.subscribe_stage_progress(exp_id, stage_type): |
| | |
| | event_type = progress.get("type", "progress") |
| | status = progress.get("status") |
| | |
| | if status == "completed": |
| | event_type = "completed" |
| | elif status == "failed": |
| | event_type = "failed" |
| | elif status == "cancelled": |
| | event_type = "cancelled" |
| | elif progress.get("model_path"): |
| | event_type = "checkpoint" |
| | |
| | |
| | data = json.dumps(progress, ensure_ascii=False) |
| | yield f"event: {event_type}\ndata: {data}\n\n" |
| | |
| | |
| | if status in ("completed", "failed", "cancelled"): |
| | break |
| | |
| | return StreamingResponse( |
| | event_generator(), |
| | media_type="text/event-stream", |
| | headers={ |
| | "Cache-Control": "no-cache", |
| | "Connection": "keep-alive", |
| | "X-Accel-Buffering": "no", |
| | }, |
| | ) |
| |
|