| """ |
| Quick Mode 任务 API |
| |
| 小白用户一键训练 API 端点 |
| |
| API 列表: |
| - POST /tasks 创建一键训练任务 |
| - GET /tasks 获取任务列表 |
| - GET /tasks/{task_id} 获取任务详情 |
| - DELETE /tasks/{task_id} 取消任务 |
| - GET /tasks/{task_id}/progress SSE 进度订阅 |
| - GET /tasks/{task_id}/outputs 获取推理输出列表 |
| - GET /tasks/{task_id}/outputs/{file_type}/{filename} 下载任务相关文件 |
| """ |
|
|
| import json |
| from typing import Literal, Optional |
|
|
| from fastapi import APIRouter, Depends, HTTPException, Path, Query |
| from fastapi.responses import StreamingResponse, Response |
|
|
| from ....models.schemas.task import ( |
| QuickModeRequest, |
| TaskResponse, |
| TaskListResponse, |
| InferenceOutputsResponse, |
| ) |
| from ....models.schemas.common import SuccessResponse, ErrorResponse |
| from ....services.task_service import TaskService |
| from ...deps import get_task_service |
|
|
| router = APIRouter() |
|
|
|
|
| @router.post( |
| "", |
| response_model=TaskResponse, |
| summary="创建一键训练任务", |
| description=""" |
| 创建一键训练任务(小白用户)。 |
| |
| 上传音频文件后,系统自动配置所有参数并执行完整训练流程: |
| `audio_slice -> asr -> text_feature -> hubert_feature -> semantic_token -> sovits_train -> gpt_train -> inference` |
| |
| 训练完成后会自动进行推理测试,生成测试音频文件。可通过 `options.inference` 配置推理参数或禁用推理阶段。 |
| |
| **质量预设**: |
| - `fast`: SoVITS 4 epochs, GPT 8 epochs, 约10分钟 |
| - `standard`: SoVITS 8 epochs, GPT 15 epochs, 约20分钟 |
| - `high`: SoVITS 16 epochs, GPT 30 epochs, 约40分钟 |
| """, |
| responses={ |
| 200: {"model": TaskResponse, "description": "任务创建成功"}, |
| 400: {"model": ErrorResponse, "description": "请求参数错误"}, |
| 404: {"model": ErrorResponse, "description": "音频文件不存在"}, |
| 409: {"model": ErrorResponse, "description": "实验名称已存在"}, |
| }, |
| ) |
| async def create_task( |
| request: QuickModeRequest, |
| service: TaskService = Depends(get_task_service), |
| ) -> TaskResponse: |
| """ |
| 创建一键训练任务 |
| """ |
| |
| if await service.check_exp_name_exists(request.exp_name): |
| raise HTTPException( |
| status_code=409, |
| detail=f"实验名称 '{request.exp_name}' 已存在,请使用不同的名称" |
| ) |
| |
| |
| file_exists, audio_path = await service.validate_audio_file(request.audio_file_id) |
| if not file_exists: |
| raise HTTPException( |
| status_code=404, |
| detail=f"音频文件不存在: {request.audio_file_id}" |
| ) |
| |
| return await service.create_quick_task(request) |
|
|
|
|
| @router.get( |
| "", |
| response_model=TaskListResponse, |
| summary="获取任务列表", |
| description="获取所有训练任务列表,支持按状态筛选和分页。", |
| ) |
| async def list_tasks( |
| status: Optional[str] = Query( |
| None, |
| description="按状态筛选: queued, running, completed, failed, cancelled, interrupted" |
| ), |
| limit: int = Query(50, ge=1, le=100, description="每页数量"), |
| offset: int = Query(0, ge=0, description="偏移量"), |
| service: TaskService = Depends(get_task_service), |
| ) -> TaskListResponse: |
| """ |
| 获取任务列表 |
| """ |
| return await service.list_tasks(status=status, limit=limit, offset=offset) |
|
|
|
|
| @router.get( |
| "/{task_id}", |
| response_model=TaskResponse, |
| summary="获取任务详情", |
| description="获取指定任务的详细状态信息。", |
| responses={ |
| 200: {"model": TaskResponse, "description": "任务详情"}, |
| 404: {"model": ErrorResponse, "description": "任务不存在"}, |
| }, |
| ) |
| async def get_task( |
| task_id: str, |
| service: TaskService = Depends(get_task_service), |
| ) -> TaskResponse: |
| """ |
| 获取任务详情 |
| """ |
| task = await service.get_task(task_id) |
| if not task: |
| raise HTTPException(status_code=404, detail="任务不存在") |
| return task |
|
|
|
|
| @router.delete( |
| "/{task_id}", |
| response_model=SuccessResponse, |
| summary="取消任务", |
| description="取消排队中或运行中的任务。已完成、失败或已取消的任务无法取消。", |
| responses={ |
| 200: {"model": SuccessResponse, "description": "任务取消成功"}, |
| 400: {"model": ErrorResponse, "description": "任务无法取消"}, |
| 404: {"model": ErrorResponse, "description": "任务不存在"}, |
| }, |
| ) |
| async def cancel_task( |
| task_id: str, |
| service: TaskService = Depends(get_task_service), |
| ) -> SuccessResponse: |
| """ |
| 取消任务 |
| """ |
| |
| task = await service.get_task(task_id) |
| if not task: |
| raise HTTPException(status_code=404, detail="任务不存在") |
| |
| success = await service.cancel_task(task_id) |
| if not success: |
| raise HTTPException(status_code=400, detail="任务无法取消(可能已完成或已取消)") |
| |
| return SuccessResponse(message="任务已取消") |
|
|
|
|
| @router.get( |
| "/{task_id}/progress", |
| summary="SSE 进度订阅", |
| description=""" |
| 订阅任务进度更新(Server-Sent Events)。 |
| |
| 返回的事件流格式: |
| ``` |
| event: progress |
| data: {"stage": "sovits_train", "progress": 0.45, "message": "Epoch 8/16"} |
| |
| event: progress |
| data: {"stage": "sovits_train", "progress": 0.50, "message": "Epoch 9/16"} |
| |
| event: completed |
| data: {"status": "completed", "message": "训练完成"} |
| ``` |
| |
| 可能的事件类型: |
| - `progress`: 进度更新 |
| - `log`: 日志消息 |
| - `heartbeat`: 心跳(保持连接) |
| - `completed`: 任务完成 |
| - `failed`: 任务失败 |
| - `cancelled`: 任务取消 |
| """, |
| responses={ |
| 200: {"description": "SSE 事件流"}, |
| 404: {"model": ErrorResponse, "description": "任务不存在"}, |
| }, |
| ) |
| async def subscribe_progress( |
| task_id: str, |
| service: TaskService = Depends(get_task_service), |
| ) -> StreamingResponse: |
| """ |
| SSE 进度订阅 |
| """ |
| |
| task = await service.get_task(task_id) |
| if not task: |
| raise HTTPException(status_code=404, detail="任务不存在") |
| |
| async def event_generator(): |
| """生成 SSE 事件流""" |
| async for progress in service.subscribe_progress(task_id): |
| |
| event_type = progress.get("type", "progress") |
| status = progress.get("status") |
| |
| if status == "completed": |
| event_type = "completed" |
| elif status == "failed": |
| event_type = "failed" |
| elif status == "cancelled": |
| event_type = "cancelled" |
| elif event_type == "heartbeat": |
| event_type = "heartbeat" |
| |
| |
| data = json.dumps(progress, ensure_ascii=False) |
| yield f"event: {event_type}\ndata: {data}\n\n" |
| |
| |
| if status in ("completed", "failed", "cancelled"): |
| break |
| |
| return StreamingResponse( |
| event_generator(), |
| media_type="text/event-stream", |
| headers={ |
| "Cache-Control": "no-cache", |
| "Connection": "keep-alive", |
| "X-Accel-Buffering": "no", |
| }, |
| ) |
|
|
|
|
| @router.get( |
| "/{task_id}/outputs", |
| response_model=InferenceOutputsResponse, |
| summary="获取推理输出列表", |
| description=""" |
| 获取任务的推理输出文件列表及推理配置信息。 |
| |
| 训练任务完成后,推理阶段会生成测试音频文件。此端点返回所有生成的音频文件元信息, |
| 包括文件名、使用的模型路径、文件大小等,以及推理使用的参考音频和文本信息。 |
| |
| **推理配置**: |
| - `ref_text`: 参考音频的文本内容 |
| - `ref_audio_path`: 参考音频文件路径 |
| - `target_text`: 合成的目标文本 |
| |
| **输出文件信息**: |
| - `filename`: 文件名 |
| - `gpt_model`: 使用的 GPT 模型名称 |
| - `sovits_model`: 使用的 SoVITS 模型名称 |
| - `gpt_path`: GPT 模型完整路径 |
| - `sovits_path`: SoVITS 模型完整路径 |
| - `file_path`: 输出文件相对路径 |
| - `size_bytes`: 文件大小(字节) |
| - `created_at`: 创建时间 |
| |
| **下载文件**: |
| 使用 `/tasks/{task_id}/outputs/{file_type}/{filename}` 端点下载相关文件。 |
| """, |
| responses={ |
| 200: {"model": InferenceOutputsResponse, "description": "推理输出列表"}, |
| 404: {"model": ErrorResponse, "description": "任务不存在"}, |
| }, |
| ) |
| async def get_task_outputs( |
| task_id: str, |
| service: TaskService = Depends(get_task_service), |
| ) -> InferenceOutputsResponse: |
| """ |
| 获取任务的推理输出列表 |
| """ |
| result = await service.get_inference_outputs(task_id) |
| if result is None: |
| raise HTTPException(status_code=404, detail="任务不存在") |
| return result |
|
|
|
|
| |
| FileType = Literal["output", "ref_audio", "gpt_model", "sovits_model"] |
|
|
|
|
| @router.get( |
| "/{task_id}/outputs/{file_type}/{filename:path}", |
| summary="下载任务相关文件", |
| description=""" |
| 下载任务相关的各类文件。 |
| |
| **文件类型 (file_type)**: |
| - `output` - 推理输出音频文件 (.wav) |
| - `ref_audio` - 参考音频文件 (.wav) |
| - `gpt_model` - GPT 模型文件 (.ckpt) |
| - `sovits_model` - SoVITS 模型文件 (.pth) |
| |
| **文件名来源**: |
| - `output`: 从 `/tasks/{task_id}/outputs` 端点的 `outputs[].filename` 获取 |
| - `ref_audio`: 从 `/tasks/{task_id}/outputs` 端点的 `ref_audio_path` 获取 |
| - `gpt_model`: 从 `/tasks/{task_id}/outputs` 端点的 `outputs[].gpt_path` 获取文件名部分 |
| - `sovits_model`: 从 `/tasks/{task_id}/outputs` 端点的 `outputs[].sovits_path` 获取文件名部分 |
| |
| **返回**: |
| - 音频文件: Content-Type: audio/wav |
| - 模型文件: Content-Type: application/octet-stream |
| """, |
| responses={ |
| 200: {"description": "文件内容"}, |
| 404: {"model": ErrorResponse, "description": "任务或文件不存在"}, |
| }, |
| ) |
| async def download_task_file( |
| task_id: str, |
| file_type: FileType = Path(..., description="文件类型: output/ref_audio/gpt_model/sovits_model"), |
| filename: str = Path(..., description="文件名或路径"), |
| service: TaskService = Depends(get_task_service), |
| ) -> Response: |
| """ |
| 下载任务相关文件(推理输出、参考音频、模型文件) |
| """ |
| result = await service.download_file(task_id, file_type, filename) |
| if result is None: |
| raise HTTPException(status_code=404, detail="任务或文件不存在") |
| |
| file_data, file_name, content_type = result |
| |
| return Response( |
| content=file_data, |
| media_type=content_type, |
| headers={ |
| "Content-Disposition": f'attachment; filename="{file_name}"', |
| "Content-Length": str(len(file_data)), |
| }, |
| ) |
|
|