""" 本地进度管理适配器 基于内存队列实现的进度管理适配器,适用于本地单实例场景。 """ import asyncio from collections import defaultdict from datetime import datetime from typing import Any, AsyncGenerator, Dict, List, Optional from ..base import ProgressAdapter class LocalProgressAdapter(ProgressAdapter): """ 本地内存进度管理适配器 特点: 1. 使用内存字典存储最新进度 2. 使用 asyncio.Queue 实现订阅者模式 3. 支持多订阅者同时订阅同一任务 4. 与 AsyncTrainingManager 的进度推送机制兼容 注意: - 进程重启后进度数据会丢失 - 仅适用于单实例部署 - 服务器模式应使用 RedisProgressAdapter Example: >>> adapter = LocalProgressAdapter() >>> await adapter.update_progress("task-123", { ... "stage": "sovits_train", ... "progress": 0.5, ... "message": "Epoch 8/16" ... }) >>> >>> # 订阅进度 >>> async for progress in adapter.subscribe("task-123"): ... print(f"{progress['stage']}: {progress['progress']*100:.1f}%") """ def __init__(self): """初始化本地进度适配器""" # 存储每个任务的最新进度 self.progress_store: Dict[str, Dict[str, Any]] = {} # 存储每个任务的订阅者队列列表 self.subscribers: Dict[str, List[asyncio.Queue]] = defaultdict(list) # 锁,用于保护订阅者列表的并发访问 self._lock = asyncio.Lock() async def update_progress(self, task_id: str, progress: Dict[str, Any]) -> None: """ 更新进度 Args: task_id: 任务ID progress: 进度信息字典,可包含: - type: 消息类型 ("progress", "log", "error", "heartbeat") - stage: 当前阶段 - progress: 阶段进度 (0.0-1.0) - overall_progress: 总体进度 (0.0-1.0) - message: 进度消息 - status: 状态 ("running", "completed", "failed", "cancelled") """ # 添加时间戳 if "timestamp" not in progress: progress["timestamp"] = datetime.utcnow().isoformat() # 存储最新进度 self.progress_store[task_id] = progress # 通知所有订阅者 async with self._lock: if task_id in self.subscribers: for queue in self.subscribers[task_id]: try: await queue.put(progress) except asyncio.QueueFull: # 队列满了,跳过(避免阻塞) pass async def get_progress(self, task_id: str) -> Optional[Dict[str, Any]]: """ 获取当前进度 Args: task_id: 任务ID Returns: 最新进度信息,不存在则返回 None """ return self.progress_store.get(task_id) async def subscribe(self, task_id: str) -> AsyncGenerator[Dict[str, Any], None]: """ 订阅进度更新 创建一个异步生成器,持续接收指定任务的进度更新。 当任务进入终态(completed, failed, cancelled)时自动结束。 Args: task_id: 任务ID Yields: 进度信息字典 Example: >>> async for progress in adapter.subscribe("task-123"): ... print(progress) ... if progress.get("status") == "completed": ... break """ # 创建订阅者队列 queue: asyncio.Queue = asyncio.Queue(maxsize=100) async with self._lock: self.subscribers[task_id].append(queue) try: # 首先发送当前进度(如果有) current = self.progress_store.get(task_id) if current: yield current # 如果已经是终态,直接返回 if current.get("status") in ("completed", "failed", "cancelled"): return # 持续接收更新 while True: try: # 30秒超时,发送心跳 progress = await asyncio.wait_for(queue.get(), timeout=30.0) yield progress # 检查是否为终态 if progress.get("status") in ("completed", "failed", "cancelled"): break except asyncio.TimeoutError: # 发送心跳保持连接 yield { "type": "heartbeat", "timestamp": datetime.utcnow().isoformat(), } finally: # 清理订阅者 async with self._lock: if task_id in self.subscribers: try: self.subscribers[task_id].remove(queue) except ValueError: pass # 如果没有订阅者了,清理列表 if not self.subscribers[task_id]: del self.subscribers[task_id] async def clear_progress(self, task_id: str) -> None: """ 清除任务进度数据 Args: task_id: 任务ID """ self.progress_store.pop(task_id, None) async with self._lock: self.subscribers.pop(task_id, None) async def get_subscriber_count(self, task_id: str) -> int: """ 获取任务的订阅者数量 Args: task_id: 任务ID Returns: 订阅者数量 """ async with self._lock: return len(self.subscribers.get(task_id, [])) async def broadcast_to_all(self, message: Dict[str, Any]) -> int: """ 向所有任务的订阅者广播消息 用于系统级通知,如服务器关闭警告等。 Args: message: 消息内容 Returns: 发送成功的订阅者数量 """ if "timestamp" not in message: message["timestamp"] = datetime.utcnow().isoformat() count = 0 async with self._lock: for task_id, queues in self.subscribers.items(): for queue in queues: try: await queue.put(message) count += 1 except asyncio.QueueFull: pass return count def get_active_tasks(self) -> List[str]: """ 获取有活跃订阅者的任务列表 Returns: 任务ID列表 """ return list(self.subscribers.keys()) def get_stats(self) -> Dict[str, Any]: """ 获取适配器统计信息 Returns: 统计信息字典 """ total_subscribers = sum( len(queues) for queues in self.subscribers.values() ) return { "stored_progress_count": len(self.progress_store), "active_tasks": len(self.subscribers), "total_subscribers": total_subscribers, }