| | """ |
| | 本地异步任务管理器 |
| | |
| | 基于 asyncio.subprocess + SQLite 的本地任务队列实现。 |
| | 适用于 macOS 本地训练和 Electron 集成场景。 |
| | """ |
| |
|
import asyncio
import json
import logging
import os
import sqlite3
import sys
import uuid
from contextlib import suppress
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional

import aiosqlite

from project_config import settings, PROJECT_ROOT, get_pythonpath
from ..base import TaskQueueAdapter
| |
|
| | if TYPE_CHECKING: |
| | from ..base import DatabaseAdapter |
| |
|
| | |
| | PROGRESS_PREFIX = "##PROGRESS##" |
| | PROGRESS_SUFFIX = "##" |
| |
|
| |
|
| | class AsyncTrainingManager(TaskQueueAdapter): |
| | """ |
| | 基于 asyncio.subprocess 的异步任务管理器 |
| | |
| | 特点: |
| | 1. 使用 asyncio.create_subprocess_exec() 异步启动训练子进程 |
| | 2. 完全非阻塞,与 FastAPI 异步模型完美契合 |
| | 3. SQLite 持久化任务状态,支持应用重启后恢复 |
| | 4. 实时解析子进程输出获取进度 |
| | |
| | Example: |
| | >>> manager = AsyncTrainingManager(db_path="./data/tasks.db") |
| | >>> job_id = await manager.enqueue("task-123", {"exp_name": "test", ...}) |
| | >>> |
| | >>> # 订阅进度 |
| | >>> async for progress in manager.subscribe_progress("task-123"): |
| | ... print(f"{progress['stage']}: {progress['progress']*100:.1f}%") |
| | >>> |
| | >>> # 取消任务 |
| | >>> await manager.cancel(job_id) |
| | """ |
| |
|
    def __init__(
        self,
        db_path: Optional[str] = None,
        max_concurrent: int = 1,
        database_adapter: Optional["DatabaseAdapter"] = None
    ):
        """
        Initialize the task manager.

        Args:
            db_path: SQLite database path; defaults to settings.SQLITE_PATH.
            max_concurrent: Maximum number of concurrent jobs (usually 1 locally).
                NOTE(review): this value is stored but not enforced anywhere in
                this class — confirm whether enforcement is intended.
            database_adapter: Adapter used to mirror status updates into the
                application's tasks table (optional).
        """
        self.db_path = db_path or str(settings.SQLITE_PATH)
        self.max_concurrent = max_concurrent
        self._database_adapter = database_adapter

        # task_id -> handle of the running training subprocess
        self.running_processes: Dict[str, asyncio.subprocess.Process] = {}
        # task_id -> queue of progress payloads consumed by SSE subscribers
        self.progress_channels: Dict[str, asyncio.Queue] = {}
        self._running_count = 0
        self._queue_lock = asyncio.Lock()

        # task_id -> job_id (populated by enqueue())
        self._task_job_mapping: Dict[str, str] = {}

        # Create the SQLite schema up front (blocking; runs once at startup).
        self._init_db_sync()
| |
|
| | def _init_db_sync(self) -> None: |
| | """同步初始化数据库(启动时调用)""" |
| | Path(self.db_path).parent.mkdir(parents=True, exist_ok=True) |
| |
|
| | with sqlite3.connect(self.db_path) as conn: |
| | conn.execute(''' |
| | CREATE TABLE IF NOT EXISTS task_queue ( |
| | job_id TEXT PRIMARY KEY, |
| | task_id TEXT NOT NULL UNIQUE, |
| | exp_name TEXT NOT NULL, |
| | config TEXT NOT NULL, |
| | status TEXT DEFAULT 'queued', |
| | current_stage TEXT, |
| | progress REAL DEFAULT 0, |
| | overall_progress REAL DEFAULT 0, |
| | message TEXT, |
| | error_message TEXT, |
| | created_at TEXT NOT NULL, |
| | started_at TEXT, |
| | completed_at TEXT |
| | ) |
| | ''') |
| | conn.execute('CREATE INDEX IF NOT EXISTS idx_task_queue_status ON task_queue(status)') |
| | conn.execute('CREATE INDEX IF NOT EXISTS idx_task_queue_task_id ON task_queue(task_id)') |
| | conn.execute('CREATE INDEX IF NOT EXISTS idx_task_queue_created ON task_queue(created_at)') |
| | conn.commit() |
| |
|
| | async def enqueue(self, task_id: str, config: Dict, priority: str = "normal") -> str: |
| | """ |
| | 将任务加入队列并异步启动 |
| | |
| | Args: |
| | task_id: 任务唯一标识 |
| | config: 任务配置,需包含: |
| | - exp_name: 实验名称 |
| | - version: 模型版本 |
| | - stages: 阶段配置列表 |
| | priority: 优先级(当前实现忽略此参数) |
| | |
| | Returns: |
| | job_id: 作业ID |
| | """ |
| | job_id = str(uuid.uuid4()) |
| | exp_name = config.get("exp_name", "unknown") |
| |
|
| | |
| | async with aiosqlite.connect(self.db_path) as db: |
| | await db.execute( |
| | '''INSERT INTO task_queue |
| | (job_id, task_id, exp_name, config, status, created_at) |
| | VALUES (?, ?, ?, ?, 'queued', ?)''', |
| | (job_id, task_id, exp_name, json.dumps(config, ensure_ascii=False), |
| | datetime.utcnow().isoformat()) |
| | ) |
| | await db.commit() |
| |
|
| | |
| | self._task_job_mapping[task_id] = job_id |
| |
|
| | |
| | self.progress_channels[task_id] = asyncio.Queue() |
| |
|
| | |
| | asyncio.create_task(self._run_training_async(job_id, task_id, config)) |
| |
|
| | return job_id |
| |
|
    async def _run_training_async(self, job_id: str, task_id: str, config: Dict) -> None:
        """
        Run the full training pipeline for one job in a subprocess.

        Marks the job 'running', spawns the pipeline script, streams its
        output for progress, then records the terminal state
        ('completed' / 'failed' / 'cancelled').

        Args:
            job_id: Job identifier (task_queue primary key).
            task_id: Task identifier (used for progress channels / process map).
            config: Task configuration written to a temp JSON file.
        """
        config_path = None

        try:
            # Transition to 'running' and notify subscribers before spawning.
            await self._update_status(
                job_id,
                status='running',
                started_at=datetime.utcnow().isoformat()
            )
            await self._send_progress(task_id, {
                "type": "progress",
                "status": "running",
                "message": "训练任务启动中...",
                "progress": 0.0,
                "overall_progress": 0.0,
            })

            # Config is passed to the subprocess via a temp file, not argv.
            config_path = await self._write_config_file(task_id, config)

            script_path = self._get_pipeline_script_path()

            # Child needs the project on PYTHONPATH to import its modules.
            env = os.environ.copy()
            env['PYTHONPATH'] = get_pythonpath()

            process = await asyncio.create_subprocess_exec(
                sys.executable, script_path,
                '--config', config_path,
                '--task-id', task_id,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                env=env,
                cwd=str(PROJECT_ROOT),
            )

            self.running_processes[task_id] = process
            self._running_count += 1

            # Blocks until both stdout and stderr reach EOF.
            await self._monitor_process_output(task_id, job_id, process)

            returncode = await process.wait()

            if returncode == 0:
                await self._update_status(
                    job_id,
                    status='completed',
                    progress=1.0,
                    overall_progress=1.0,
                    message='训练完成',
                    completed_at=datetime.utcnow().isoformat()
                )
                await self._send_progress(task_id, {
                    "type": "progress",
                    "status": "completed",
                    "message": "训练完成",
                    "progress": 1.0,
                    "overall_progress": 1.0,
                })
            else:
                # NOTE(review): _monitor_process_output has already drained
                # stderr to EOF, so this read() usually returns b'' and the
                # fallback "exit code" message is what gets recorded.
                stderr_data = await process.stderr.read()
                error_msg = stderr_data.decode() if stderr_data else f"进程退出码: {returncode}"

                await self._update_status(
                    job_id,
                    status='failed',
                    error_message=error_msg,
                    completed_at=datetime.utcnow().isoformat()
                )
                await self._send_progress(task_id, {
                    "type": "progress",
                    "status": "failed",
                    "message": f"训练失败: {error_msg[:200]}",
                    "error": error_msg,
                })

        except asyncio.CancelledError:
            # Cancellation of this coroutine (not of the subprocess itself).
            await self._update_status(
                job_id,
                status='cancelled',
                message='任务已取消',
                completed_at=datetime.utcnow().isoformat()
            )
            await self._send_progress(task_id, {
                "type": "progress",
                "status": "cancelled",
                "message": "任务已取消",
            })

        except Exception as e:
            # Top-level boundary for the background task: record and report.
            error_msg = str(e)
            await self._update_status(
                job_id,
                status='failed',
                error_message=error_msg,
                completed_at=datetime.utcnow().isoformat()
            )
            await self._send_progress(task_id, {
                "type": "progress",
                "status": "failed",
                "message": f"任务执行出错: {error_msg}",
                "error": error_msg,
            })

        finally:
            # Always drop the process handle and decrement the running count.
            self.running_processes.pop(task_id, None)
            self._running_count = max(0, self._running_count - 1)

            # Best-effort removal of the temp config file.
            if config_path:
                await self._cleanup_config_file(config_path)
| |
|
| | async def _monitor_process_output( |
| | self, |
| | task_id: str, |
| | job_id: str, |
| | process: asyncio.subprocess.Process |
| | ) -> None: |
| | """ |
| | 异步监控子进程输出并解析进度 |
| | |
| | Args: |
| | task_id: 任务ID |
| | job_id: 作业ID |
| | process: 子进程对象 |
| | """ |
| | async def read_stdout(): |
| | """读取 stdout 并解析进度""" |
| | while True: |
| | line = await process.stdout.readline() |
| | if not line: |
| | break |
| |
|
| | text = line.decode('utf-8', errors='replace').strip() |
| | if not text: |
| | continue |
| |
|
| | |
| | if text.startswith(PROGRESS_PREFIX) and text.endswith(PROGRESS_SUFFIX): |
| | json_str = text[len(PROGRESS_PREFIX):-len(PROGRESS_SUFFIX)] |
| | try: |
| | progress_info = json.loads(json_str) |
| | await self._handle_progress(task_id, job_id, progress_info) |
| | except json.JSONDecodeError as e: |
| | |
| | await self._send_progress(task_id, { |
| | "type": "log", |
| | "level": "warning", |
| | "message": f"进度解析失败: {e}", |
| | }) |
| | else: |
| | |
| | await self._send_progress(task_id, { |
| | "type": "log", |
| | "level": "info", |
| | "message": text, |
| | }) |
| |
|
| | async def read_stderr(): |
| | """读取 stderr 作为错误日志""" |
| | while True: |
| | line = await process.stderr.readline() |
| | if not line: |
| | break |
| |
|
| | text = line.decode('utf-8', errors='replace').strip() |
| | if text: |
| | await self._send_progress(task_id, { |
| | "type": "log", |
| | "level": "error", |
| | "message": text, |
| | }) |
| |
|
| | |
| | await asyncio.gather( |
| | read_stdout(), |
| | read_stderr(), |
| | return_exceptions=True |
| | ) |
| |
|
| | async def _handle_progress( |
| | self, |
| | task_id: str, |
| | job_id: str, |
| | progress_info: Dict |
| | ) -> None: |
| | """ |
| | 处理进度信息 |
| | |
| | Args: |
| | task_id: 任务ID |
| | job_id: 作业ID |
| | progress_info: 进度信息字典 |
| | """ |
| | |
| | await self._send_progress(task_id, progress_info) |
| |
|
| | |
| | updates = {} |
| |
|
| | if 'stage' in progress_info: |
| | updates['current_stage'] = progress_info['stage'] |
| | if 'progress' in progress_info: |
| | updates['progress'] = progress_info['progress'] |
| | if 'overall_progress' in progress_info: |
| | updates['overall_progress'] = progress_info['overall_progress'] |
| | if 'message' in progress_info: |
| | updates['message'] = progress_info['message'] |
| | if 'status' in progress_info: |
| | updates['status'] = progress_info['status'] |
| | if 'error' in progress_info: |
| | updates['error_message'] = progress_info['error'] |
| |
|
| | if updates: |
| | await self._update_status(job_id, **updates) |
| |
|
| | async def _send_progress(self, task_id: str, progress_info: Dict) -> None: |
| | """ |
| | 发送进度到订阅队列 |
| | |
| | Args: |
| | task_id: 任务ID |
| | progress_info: 进度信息 |
| | """ |
| | if task_id in self.progress_channels: |
| | |
| | if 'timestamp' not in progress_info: |
| | progress_info['timestamp'] = datetime.utcnow().isoformat() |
| |
|
| | await self.progress_channels[task_id].put(progress_info) |
| |
|
| | async def _update_status(self, job_id: str, **kwargs) -> None: |
| | """ |
| | 更新任务状态 |
| | |
| | 同时更新 task_queue 表和 tasks 表(通过 DatabaseAdapter)。 |
| | |
| | Args: |
| | job_id: 作业ID |
| | **kwargs: 要更新的字段 |
| | """ |
| | if not kwargs: |
| | return |
| |
|
| | |
| | task_id = None |
| | async with aiosqlite.connect(self.db_path) as db: |
| | updates = [] |
| | values = [] |
| |
|
| | for key, value in kwargs.items(): |
| | updates.append(f"{key} = ?") |
| | values.append(value) |
| |
|
| | values.append(job_id) |
| |
|
| | await db.execute( |
| | f"UPDATE task_queue SET {', '.join(updates)} WHERE job_id = ?", |
| | values |
| | ) |
| | await db.commit() |
| | |
| | |
| | async with db.execute( |
| | "SELECT task_id FROM task_queue WHERE job_id = ?", (job_id,) |
| | ) as cursor: |
| | row = await cursor.fetchone() |
| | if row: |
| | task_id = row[0] |
| |
|
| | |
| | if self._database_adapter and task_id: |
| | await self._sync_to_tasks_table(task_id, kwargs) |
| | |
| | async def _sync_to_tasks_table(self, task_id: str, updates: Dict) -> None: |
| | """ |
| | 同步状态更新到 tasks 表 |
| | |
| | 字段映射: |
| | - task_queue.progress -> tasks.stage_progress |
| | - task_queue.overall_progress -> tasks.progress |
| | - 其他字段直接映射 |
| | |
| | Args: |
| | task_id: 任务ID |
| | updates: 要更新的字段字典 |
| | """ |
| | if not self._database_adapter: |
| | return |
| | |
| | |
| | tasks_updates = {} |
| | |
| | for key, value in updates.items(): |
| | if key == 'progress': |
| | |
| | tasks_updates['stage_progress'] = value |
| | elif key == 'overall_progress': |
| | |
| | tasks_updates['progress'] = value |
| | elif key in ('status', 'current_stage', 'message', 'error_message', |
| | 'started_at', 'completed_at'): |
| | |
| | tasks_updates[key] = value |
| | |
| | if tasks_updates: |
| | try: |
| | await self._database_adapter.update_task(task_id, tasks_updates) |
| | except Exception as e: |
| | |
| | import logging |
| | logging.warning(f"Failed to sync task status to tasks table: {e}") |
| |
|
| | async def get_status(self, job_id: str) -> Dict: |
| | """ |
| | 获取任务状态 |
| | |
| | Args: |
| | job_id: 作业ID |
| | |
| | Returns: |
| | 状态字典 |
| | """ |
| | async with aiosqlite.connect(self.db_path) as db: |
| | db.row_factory = aiosqlite.Row |
| | async with db.execute( |
| | "SELECT * FROM task_queue WHERE job_id = ?", (job_id,) |
| | ) as cursor: |
| | row = await cursor.fetchone() |
| | if row: |
| | return dict(row) |
| |
|
| | return {"status": "not_found", "message": "任务不存在"} |
| |
|
| | async def get_status_by_task_id(self, task_id: str) -> Dict: |
| | """ |
| | 通过 task_id 获取任务状态 |
| | |
| | Args: |
| | task_id: 任务ID |
| | |
| | Returns: |
| | 状态字典 |
| | """ |
| | async with aiosqlite.connect(self.db_path) as db: |
| | db.row_factory = aiosqlite.Row |
| | async with db.execute( |
| | "SELECT * FROM task_queue WHERE task_id = ?", (task_id,) |
| | ) as cursor: |
| | row = await cursor.fetchone() |
| | if row: |
| | return dict(row) |
| |
|
| | return {"status": "not_found", "message": "任务不存在"} |
| |
|
| | async def cancel(self, job_id: str) -> bool: |
| | """ |
| | 取消任务 |
| | |
| | Args: |
| | job_id: 作业ID |
| | |
| | Returns: |
| | 是否成功取消 |
| | """ |
| | |
| | async with aiosqlite.connect(self.db_path) as db: |
| | async with db.execute( |
| | "SELECT task_id, status FROM task_queue WHERE job_id = ?", (job_id,) |
| | ) as cursor: |
| | row = await cursor.fetchone() |
| | if not row: |
| | return False |
| | task_id, status = row |
| |
|
| | |
| | if status in ('completed', 'failed', 'cancelled'): |
| | return False |
| |
|
| | |
| | if task_id in self.running_processes: |
| | process = self.running_processes[task_id] |
| |
|
| | |
| | process.terminate() |
| |
|
| | try: |
| | |
| | await asyncio.wait_for(process.wait(), timeout=5.0) |
| | except asyncio.TimeoutError: |
| | |
| | process.kill() |
| | await process.wait() |
| |
|
| | return True |
| |
|
| | |
| | await self._update_status( |
| | job_id, |
| | status='cancelled', |
| | message='任务已取消', |
| | completed_at=datetime.utcnow().isoformat() |
| | ) |
| |
|
| | |
| | if task_id in self.progress_channels: |
| | await self._send_progress(task_id, { |
| | "type": "progress", |
| | "status": "cancelled", |
| | "message": "任务已取消", |
| | }) |
| |
|
| | return True |
| |
|
| | async def subscribe_progress(self, task_id: str) -> AsyncGenerator[Dict, None]: |
| | """ |
| | 订阅任务进度(用于 SSE 流) |
| | |
| | Args: |
| | task_id: 任务ID |
| | |
| | Yields: |
| | 进度信息字典 |
| | """ |
| | |
| | if task_id not in self.progress_channels: |
| | self.progress_channels[task_id] = asyncio.Queue() |
| |
|
| | queue = self.progress_channels[task_id] |
| |
|
| | |
| | status = await self.get_status_by_task_id(task_id) |
| | if status.get("status") != "not_found": |
| | yield { |
| | "type": "progress", |
| | "status": status.get("status"), |
| | "stage": status.get("current_stage"), |
| | "progress": status.get("progress", 0.0), |
| | "overall_progress": status.get("overall_progress", 0.0), |
| | "message": status.get("message"), |
| | "timestamp": datetime.utcnow().isoformat(), |
| | } |
| |
|
| | |
| | while True: |
| | try: |
| | |
| | progress = await asyncio.wait_for(queue.get(), timeout=30.0) |
| | yield progress |
| |
|
| | |
| | if progress.get('status') in ('completed', 'failed', 'cancelled'): |
| | break |
| |
|
| | except asyncio.TimeoutError: |
| | |
| | yield { |
| | "type": "heartbeat", |
| | "timestamp": datetime.utcnow().isoformat(), |
| | } |
| |
|
| | async def list_tasks( |
| | self, |
| | status: Optional[str] = None, |
| | limit: int = 50, |
| | offset: int = 0 |
| | ) -> List[Dict]: |
| | """ |
| | 列出任务 |
| | |
| | Args: |
| | status: 按状态筛选 |
| | limit: 返回数量限制 |
| | offset: 偏移量 |
| | |
| | Returns: |
| | 任务列表 |
| | """ |
| | async with aiosqlite.connect(self.db_path) as db: |
| | db.row_factory = aiosqlite.Row |
| |
|
| | if status: |
| | query = """ |
| | SELECT * FROM task_queue |
| | WHERE status = ? |
| | ORDER BY created_at DESC |
| | LIMIT ? OFFSET ? |
| | """ |
| | params = (status, limit, offset) |
| | else: |
| | query = """ |
| | SELECT * FROM task_queue |
| | ORDER BY created_at DESC |
| | LIMIT ? OFFSET ? |
| | """ |
| | params = (limit, offset) |
| |
|
| | async with db.execute(query, params) as cursor: |
| | rows = await cursor.fetchall() |
| | return [dict(row) for row in rows] |
| |
|
| | async def recover_pending_tasks(self) -> int: |
| | """ |
| | 应用重启后恢复未完成的任务 |
| | |
| | 将 running 状态的任务标记为 interrupted, |
| | 可选择重新启动 queued 状态的任务。 |
| | |
| | Returns: |
| | 恢复的任务数量 |
| | """ |
| | async with aiosqlite.connect(self.db_path) as db: |
| | |
| | await db.execute( |
| | """UPDATE task_queue |
| | SET status = 'interrupted', |
| | message = '应用重启导致任务中断' |
| | WHERE status = 'running'""" |
| | ) |
| | await db.commit() |
| |
|
| | |
| | db.row_factory = aiosqlite.Row |
| | async with db.execute( |
| | "SELECT * FROM task_queue WHERE status = 'queued' ORDER BY created_at" |
| | ) as cursor: |
| | queued_tasks = await cursor.fetchall() |
| |
|
| | |
| | recovered = 0 |
| | for task in queued_tasks: |
| | task_id = task['task_id'] |
| | job_id = task['job_id'] |
| | config = json.loads(task['config']) |
| |
|
| | self.progress_channels[task_id] = asyncio.Queue() |
| | asyncio.create_task(self._run_training_async(job_id, task_id, config)) |
| | recovered += 1 |
| |
|
| | return recovered |
| |
|
| | async def cleanup_old_tasks(self, days: int = 7) -> int: |
| | """ |
| | 清理旧任务记录 |
| | |
| | Args: |
| | days: 保留天数 |
| | |
| | Returns: |
| | 删除的任务数量 |
| | """ |
| | from datetime import timedelta |
| |
|
| | cutoff = (datetime.utcnow() - timedelta(days=days)).isoformat() |
| |
|
| | async with aiosqlite.connect(self.db_path) as db: |
| | cursor = await db.execute( |
| | """DELETE FROM task_queue |
| | WHERE status IN ('completed', 'failed', 'cancelled') |
| | AND completed_at < ?""", |
| | (cutoff,) |
| | ) |
| | deleted = cursor.rowcount |
| | await db.commit() |
| |
|
| | return deleted |
| |
|
| | def _get_pipeline_script_path(self) -> str: |
| | """获取 run_pipeline.py 脚本路径""" |
| | return str(settings.PIPELINE_SCRIPT_PATH) |
| |
|
| | async def _write_config_file(self, task_id: str, config: Dict) -> str: |
| | """ |
| | 写入临时配置文件 |
| | |
| | Args: |
| | task_id: 任务ID |
| | config: 配置字典 |
| | |
| | Returns: |
| | 配置文件路径 |
| | """ |
| | config_path = settings.CONFIGS_DIR / f"{task_id}.json" |
| |
|
| | with open(config_path, 'w', encoding='utf-8') as f: |
| | json.dump(config, f, ensure_ascii=False, indent=2) |
| |
|
| | return str(config_path) |
| |
|
| | async def _cleanup_config_file(self, config_path: str) -> None: |
| | """ |
| | 清理临时配置文件 |
| | |
| | Args: |
| | config_path: 配置文件路径 |
| | """ |
| | try: |
| | path = Path(config_path) |
| | if path.exists(): |
| | path.unlink() |
| | except Exception: |
| | pass |
| |
|