| | """ |
| | 本地进度管理适配器 |
| | |
| | 基于内存队列实现的进度管理适配器,适用于本地单实例场景。 |
| | """ |
| |
|
| | import asyncio |
| | from collections import defaultdict |
| | from datetime import datetime |
| | from typing import Any, AsyncGenerator, Dict, List, Optional |
| |
|
| | from ..base import ProgressAdapter |
| |
|
| |
|
| | class LocalProgressAdapter(ProgressAdapter): |
| | """ |
| | 本地内存进度管理适配器 |
| | |
| | 特点: |
| | 1. 使用内存字典存储最新进度 |
| | 2. 使用 asyncio.Queue 实现订阅者模式 |
| | 3. 支持多订阅者同时订阅同一任务 |
| | 4. 与 AsyncTrainingManager 的进度推送机制兼容 |
| | |
| | 注意: |
| | - 进程重启后进度数据会丢失 |
| | - 仅适用于单实例部署 |
| | - 服务器模式应使用 RedisProgressAdapter |
| | |
| | Example: |
| | >>> adapter = LocalProgressAdapter() |
| | >>> await adapter.update_progress("task-123", { |
| | ... "stage": "sovits_train", |
| | ... "progress": 0.5, |
| | ... "message": "Epoch 8/16" |
| | ... }) |
| | >>> |
| | >>> # 订阅进度 |
| | >>> async for progress in adapter.subscribe("task-123"): |
| | ... print(f"{progress['stage']}: {progress['progress']*100:.1f}%") |
| | """ |
| | |
| | def __init__(self): |
| | """初始化本地进度适配器""" |
| | |
| | self.progress_store: Dict[str, Dict[str, Any]] = {} |
| | |
| | |
| | self.subscribers: Dict[str, List[asyncio.Queue]] = defaultdict(list) |
| | |
| | |
| | self._lock = asyncio.Lock() |
| | |
| | async def update_progress(self, task_id: str, progress: Dict[str, Any]) -> None: |
| | """ |
| | 更新进度 |
| | |
| | Args: |
| | task_id: 任务ID |
| | progress: 进度信息字典,可包含: |
| | - type: 消息类型 ("progress", "log", "error", "heartbeat") |
| | - stage: 当前阶段 |
| | - progress: 阶段进度 (0.0-1.0) |
| | - overall_progress: 总体进度 (0.0-1.0) |
| | - message: 进度消息 |
| | - status: 状态 ("running", "completed", "failed", "cancelled") |
| | """ |
| | |
| | if "timestamp" not in progress: |
| | progress["timestamp"] = datetime.utcnow().isoformat() |
| | |
| | |
| | self.progress_store[task_id] = progress |
| | |
| | |
| | async with self._lock: |
| | if task_id in self.subscribers: |
| | for queue in self.subscribers[task_id]: |
| | try: |
| | await queue.put(progress) |
| | except asyncio.QueueFull: |
| | |
| | pass |
| | |
| | async def get_progress(self, task_id: str) -> Optional[Dict[str, Any]]: |
| | """ |
| | 获取当前进度 |
| | |
| | Args: |
| | task_id: 任务ID |
| | |
| | Returns: |
| | 最新进度信息,不存在则返回 None |
| | """ |
| | return self.progress_store.get(task_id) |
| | |
| | async def subscribe(self, task_id: str) -> AsyncGenerator[Dict[str, Any], None]: |
| | """ |
| | 订阅进度更新 |
| | |
| | 创建一个异步生成器,持续接收指定任务的进度更新。 |
| | 当任务进入终态(completed, failed, cancelled)时自动结束。 |
| | |
| | Args: |
| | task_id: 任务ID |
| | |
| | Yields: |
| | 进度信息字典 |
| | |
| | Example: |
| | >>> async for progress in adapter.subscribe("task-123"): |
| | ... print(progress) |
| | ... if progress.get("status") == "completed": |
| | ... break |
| | """ |
| | |
| | queue: asyncio.Queue = asyncio.Queue(maxsize=100) |
| | |
| | async with self._lock: |
| | self.subscribers[task_id].append(queue) |
| | |
| | try: |
| | |
| | current = self.progress_store.get(task_id) |
| | if current: |
| | yield current |
| | |
| | if current.get("status") in ("completed", "failed", "cancelled"): |
| | return |
| | |
| | |
| | while True: |
| | try: |
| | |
| | progress = await asyncio.wait_for(queue.get(), timeout=30.0) |
| | yield progress |
| | |
| | |
| | if progress.get("status") in ("completed", "failed", "cancelled"): |
| | break |
| | |
| | except asyncio.TimeoutError: |
| | |
| | yield { |
| | "type": "heartbeat", |
| | "timestamp": datetime.utcnow().isoformat(), |
| | } |
| | |
| | finally: |
| | |
| | async with self._lock: |
| | if task_id in self.subscribers: |
| | try: |
| | self.subscribers[task_id].remove(queue) |
| | except ValueError: |
| | pass |
| | |
| | |
| | if not self.subscribers[task_id]: |
| | del self.subscribers[task_id] |
| | |
| | async def clear_progress(self, task_id: str) -> None: |
| | """ |
| | 清除任务进度数据 |
| | |
| | Args: |
| | task_id: 任务ID |
| | """ |
| | self.progress_store.pop(task_id, None) |
| | |
| | async with self._lock: |
| | self.subscribers.pop(task_id, None) |
| | |
| | async def get_subscriber_count(self, task_id: str) -> int: |
| | """ |
| | 获取任务的订阅者数量 |
| | |
| | Args: |
| | task_id: 任务ID |
| | |
| | Returns: |
| | 订阅者数量 |
| | """ |
| | async with self._lock: |
| | return len(self.subscribers.get(task_id, [])) |
| | |
| | async def broadcast_to_all(self, message: Dict[str, Any]) -> int: |
| | """ |
| | 向所有任务的订阅者广播消息 |
| | |
| | 用于系统级通知,如服务器关闭警告等。 |
| | |
| | Args: |
| | message: 消息内容 |
| | |
| | Returns: |
| | 发送成功的订阅者数量 |
| | """ |
| | if "timestamp" not in message: |
| | message["timestamp"] = datetime.utcnow().isoformat() |
| | |
| | count = 0 |
| | async with self._lock: |
| | for task_id, queues in self.subscribers.items(): |
| | for queue in queues: |
| | try: |
| | await queue.put(message) |
| | count += 1 |
| | except asyncio.QueueFull: |
| | pass |
| | |
| | return count |
| | |
| | def get_active_tasks(self) -> List[str]: |
| | """ |
| | 获取有活跃订阅者的任务列表 |
| | |
| | Returns: |
| | 任务ID列表 |
| | """ |
| | return list(self.subscribers.keys()) |
| | |
| | def get_stats(self) -> Dict[str, Any]: |
| | """ |
| | 获取适配器统计信息 |
| | |
| | Returns: |
| | 统计信息字典 |
| | """ |
| | total_subscribers = sum( |
| | len(queues) for queues in self.subscribers.values() |
| | ) |
| | |
| | return { |
| | "stored_progress_count": len(self.progress_store), |
| | "active_tasks": len(self.subscribers), |
| | "total_subscribers": total_subscribers, |
| | } |
| |
|