| """ |
| 适配器抽象基类模块 |
| |
| 定义任务队列、存储、数据库等适配器的抽象接口 |
| """ |
|
|
| from abc import ABC, abstractmethod |
| from typing import TYPE_CHECKING, Dict, List, Optional, AsyncGenerator, Any |
|
|
| if TYPE_CHECKING: |
| from ..models.domain import Task |
|
|
|
|
class TaskQueueAdapter(ABC):
    """Abstract base class for task-queue adapters.

    Defines the common task-queue interface so that a local implementation
    (asyncio.subprocess) and a server implementation (Celery) can be used
    interchangeably.

    Example:
        >>> adapter = AsyncTrainingManager(db_path="./data/tasks.db")
        >>> job_id = await adapter.enqueue("task-123", {"exp_name": "test"})
        >>> status = await adapter.get_status(job_id)
        >>> async for progress in adapter.subscribe_progress("task-123"):
        ...     print(progress)
    """

    @abstractmethod
    async def enqueue(
        self,
        task_id: str,
        config: Dict[str, Any],
        priority: str = "normal",
    ) -> str:
        """Add a task to the queue.

        Args:
            task_id: Unique identifier of the task.
            config: Task configuration dictionary holding every parameter
                required for training.
            priority: Task priority ("low", "normal", "high").

        Returns:
            job_id: ID of the job inside the queue.

        Raises:
            ValueError: If the configuration is invalid.
        """
        pass

    @abstractmethod
    async def get_status(self, job_id: str) -> Dict[str, Any]:
        """Return the status of a job.

        Args:
            job_id: Job ID.

        Returns:
            Status dictionary containing:
                - status: job status (queued, running, completed, failed,
                  cancelled)
                - progress: progress (0.0-1.0)
                - current_stage: name of the current stage
                - message: status message
                - error_message: error details (if the job failed)
        """
        pass

    @abstractmethod
    async def cancel(self, job_id: str) -> bool:
        """Cancel a job.

        Args:
            job_id: Job ID.

        Returns:
            Whether the job was cancelled successfully.
        """
        pass

    @abstractmethod
    async def subscribe_progress(
        self, task_id: str
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Subscribe to a task's progress (used for SSE streams).

        Args:
            task_id: Task ID.

        Yields:
            Progress dictionaries containing:
                - type: message type ("progress", "log", "heartbeat")
                - stage: current stage
                - progress: progress value
                - message: progress message
                - status: status (running, completed, failed, cancelled)

        Note:
            When ``status`` reaches a terminal state the generator ends
            automatically.
        """
        # The never-executed ``yield`` turns this declaration into an async
        # generator function instead of a coroutine function. Without it the
        # method would have to be awaited before iteration, contradicting
        # both the AsyncGenerator annotation and the ``async for`` usage that
        # subclasses and callers rely on. The base itself yields nothing.
        if False:  # pragma: no cover
            yield {}
|
|
|
|
class ProgressAdapter(ABC):
    """Abstract base class for progress-management adapters.

    Used to update and subscribe to task progress; supports a local
    implementation (in-memory queue) and a server implementation
    (Redis Pub/Sub).
    """

    @abstractmethod
    async def update_progress(self, task_id: str, progress: Dict[str, Any]) -> None:
        """Update the progress of a task.

        Args:
            task_id: Task ID.
            progress: Progress information dictionary.
        """
        pass

    @abstractmethod
    async def get_progress(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Return the current progress of a task.

        Args:
            task_id: Task ID.

        Returns:
            The latest progress information, or None if there is none.
        """
        pass

    @abstractmethod
    async def subscribe(self, task_id: str) -> AsyncGenerator[Dict[str, Any], None]:
        """Subscribe to progress updates of a task.

        Args:
            task_id: Task ID.

        Yields:
            Progress information dictionaries.
        """
        # The never-executed ``yield`` makes this an async generator function
        # rather than a coroutine function, so the abstract declaration
        # agrees with its AsyncGenerator annotation and with overrides that
        # are consumed via ``async for``. The base itself yields nothing.
        if False:  # pragma: no cover
            yield {}
|
|
|
|
class StorageAdapter(ABC):
    """Abstract base class for storage adapters.

    Defines the common file-storage interface; supports a local file-system
    implementation and an object-storage (S3/MinIO) implementation.

    Example:
        >>> adapter = LocalStorageAdapter(base_path="./data/files")
        >>> file_id = await adapter.upload_file(data, "audio.wav", {"purpose": "training"})
        >>> content = await adapter.download_file(file_id)
        >>> await adapter.delete_file(file_id)
    """

    @abstractmethod
    async def upload_file(self, file_data: bytes, filename: str, metadata: Dict[str, Any]) -> str:
        """Upload a file.

        Args:
            file_data: Binary content of the file.
            filename: Original file name.
            metadata: File metadata, which may include:
                - content_type: MIME type
                - purpose: file purpose (training, reference, output)
                - any other custom fields

        Returns:
            file_id: Unique identifier of the file.

        Raises:
            IOError: If storing the file fails.
        """
        ...

    @abstractmethod
    async def download_file(self, file_id: str) -> bytes:
        """Download a file.

        Args:
            file_id: Unique identifier of the file.

        Returns:
            Binary content of the file.

        Raises:
            FileNotFoundError: If the file does not exist.
        """
        ...

    @abstractmethod
    async def delete_file(self, file_id: str) -> bool:
        """Delete a file.

        Args:
            file_id: Unique identifier of the file.

        Returns:
            Whether the file was deleted successfully.
        """
        ...

    @abstractmethod
    async def get_file_metadata(self, file_id: str) -> Optional[Dict[str, Any]]:
        """Return the metadata of a file.

        Args:
            file_id: Unique identifier of the file.

        Returns:
            Metadata dictionary containing:
                - id: file ID
                - filename: original file name
                - content_type: MIME type
                - size_bytes: file size
                - purpose: file purpose
                - uploaded_at: upload time
                - audio files additionally include: duration_seconds,
                  sample_rate

            None is returned when the file does not exist.
        """
        ...

    @abstractmethod
    async def list_files(
        self,
        purpose: Optional[str] = None,
        limit: int = 50,
        offset: int = 0,
    ) -> List[Dict[str, Any]]:
        """List files.

        Args:
            purpose: Filter by purpose (training, reference, output).
            limit: Maximum number of entries to return.
            offset: Offset into the result set.

        Returns:
            List of file metadata dictionaries.
        """
        ...

    @abstractmethod
    async def file_exists(self, file_id: str) -> bool:
        """Check whether a file exists.

        Args:
            file_id: Unique identifier of the file.

        Returns:
            Whether the file exists.
        """
        ...
|
|
|
|
class DatabaseAdapter(ABC):
    """Abstract base class for database adapters.

    Defines the common data-persistence interface; supports a SQLite
    implementation and a PostgreSQL implementation.

    Manages the following entities:
        - Task: Quick Mode one-click training task
        - Experiment: Advanced Mode experiment
        - Stage: the individual stages of an experiment
        - File: records of uploaded files (optional, used together with
          StorageAdapter)

    Example:
        >>> adapter = SQLiteAdapter(db_path="./data/app.db")
        >>> task = await adapter.create_task(task_data)
        >>> task = await adapter.get_task(task_id)
        >>> await adapter.update_task(task_id, {"status": "completed"})
    """

    # ------------------------------------------------------------------
    # Task operations
    # ------------------------------------------------------------------

    @abstractmethod
    async def create_task(self, task: "Task") -> "Task":
        """Create a task.

        Args:
            task: Task domain-model instance.

        Returns:
            The created Task instance (including generated fields such as
            created_at).
        """
        ...

    @abstractmethod
    async def get_task(self, task_id: str) -> Optional["Task"]:
        """Fetch a task.

        Args:
            task_id: Unique identifier of the task.

        Returns:
            The Task instance, or None if it does not exist.
        """
        ...

    @abstractmethod
    async def update_task(self, task_id: str, updates: Dict[str, Any]) -> Optional["Task"]:
        """Update a task.

        Args:
            task_id: Unique identifier of the task.
            updates: Dictionary of the fields to update.

        Returns:
            The updated Task instance, or None if it does not exist.
        """
        ...

    @abstractmethod
    async def list_tasks(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0,
    ) -> List["Task"]:
        """Query the list of tasks.

        Args:
            status: Filter by status.
            limit: Maximum number of entries to return.
            offset: Offset into the result set.

        Returns:
            List of Task instances.
        """
        ...

    @abstractmethod
    async def delete_task(self, task_id: str) -> bool:
        """Delete a task.

        Args:
            task_id: Unique identifier of the task.

        Returns:
            Whether the task was deleted successfully.
        """
        ...

    @abstractmethod
    async def count_tasks(self, status: Optional[str] = None) -> int:
        """Count tasks.

        Args:
            status: Filter by status.

        Returns:
            Number of tasks.
        """
        ...

    @abstractmethod
    async def get_task_by_exp_name(self, exp_name: str) -> Optional["Task"]:
        """Fetch a task by its experiment name.

        Used to check whether an exp_name already exists.

        Args:
            exp_name: Experiment name.

        Returns:
            The Task instance, or None if it does not exist.
        """
        ...

    # ------------------------------------------------------------------
    # Experiment operations
    # ------------------------------------------------------------------

    @abstractmethod
    async def create_experiment(self, experiment: Dict[str, Any]) -> Dict[str, Any]:
        """Create an experiment.

        Args:
            experiment: Experiment data dictionary.

        Returns:
            The created experiment data.
        """
        ...

    @abstractmethod
    async def get_experiment(self, exp_id: str) -> Optional[Dict[str, Any]]:
        """Fetch an experiment.

        Args:
            exp_id: Unique identifier of the experiment.

        Returns:
            Experiment data dictionary, or None if it does not exist.
        """
        ...

    @abstractmethod
    async def update_experiment(self, exp_id: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Update an experiment.

        Args:
            exp_id: Unique identifier of the experiment.
            updates: Dictionary of the fields to update.

        Returns:
            The updated experiment data, or None if it does not exist.
        """
        ...

    @abstractmethod
    async def list_experiments(
        self,
        status: Optional[str] = None,
        limit: int = 50,
        offset: int = 0,
    ) -> List[Dict[str, Any]]:
        """Query the list of experiments.

        Args:
            status: Filter by status.
            limit: Maximum number of entries to return.
            offset: Offset into the result set.

        Returns:
            List of experiment data dictionaries.
        """
        ...

    @abstractmethod
    async def delete_experiment(self, exp_id: str) -> bool:
        """Delete an experiment.

        Args:
            exp_id: Unique identifier of the experiment.

        Returns:
            Whether the experiment was deleted successfully.
        """
        ...

    # ------------------------------------------------------------------
    # Stage operations
    # ------------------------------------------------------------------

    @abstractmethod
    async def update_stage(
        self,
        exp_id: str,
        stage_type: str,
        updates: Dict[str, Any],
    ) -> Optional[Dict[str, Any]]:
        """Update the state of a stage.

        Args:
            exp_id: Unique identifier of the experiment.
            stage_type: Stage type.
            updates: Dictionary of the fields to update.

        Returns:
            The updated stage data, or None if it does not exist.
        """
        ...

    @abstractmethod
    async def get_stage(self, exp_id: str, stage_type: str) -> Optional[Dict[str, Any]]:
        """Fetch the state of a stage.

        Args:
            exp_id: Unique identifier of the experiment.
            stage_type: Stage type.

        Returns:
            Stage data dictionary, or None if it does not exist.
        """
        ...

    @abstractmethod
    async def get_all_stages(self, exp_id: str) -> List[Dict[str, Any]]:
        """Fetch the state of every stage of an experiment.

        Args:
            exp_id: Unique identifier of the experiment.

        Returns:
            List of stage data dictionaries.
        """
        ...

    # ------------------------------------------------------------------
    # File-record operations
    # ------------------------------------------------------------------

    @abstractmethod
    async def create_file_record(self, file_data: Dict[str, Any]) -> Dict[str, Any]:
        """Create a file record.

        Args:
            file_data: File metadata.

        Returns:
            The created file record.
        """
        ...

    @abstractmethod
    async def get_file_record(self, file_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a file record.

        Args:
            file_id: Unique identifier of the file.

        Returns:
            The file record, or None if it does not exist.
        """
        ...

    @abstractmethod
    async def delete_file_record(self, file_id: str) -> bool:
        """Delete a file record.

        Args:
            file_id: Unique identifier of the file.

        Returns:
            Whether the record was deleted successfully.
        """
        ...

    @abstractmethod
    async def list_file_records(
        self,
        purpose: Optional[str] = None,
        limit: int = 50,
        offset: int = 0,
    ) -> List[Dict[str, Any]]:
        """Query the list of file records.

        Args:
            purpose: Filter by purpose.
            limit: Maximum number of entries to return.
            offset: Offset into the result set.

        Returns:
            List of file records.
        """
        ...
|
|