|
|
""" |
|
|
本地进度管理适配器 |
|
|
|
|
|
基于内存队列实现的进度管理适配器,适用于本地单实例场景。 |
|
|
""" |
|
|
|
|
|
import asyncio |
|
|
from collections import defaultdict |
|
|
from datetime import datetime |
|
|
from typing import Any, AsyncGenerator, Dict, List, Optional |
|
|
|
|
|
from ..base import ProgressAdapter |
|
|
|
|
|
|
|
|
class LocalProgressAdapter(ProgressAdapter):
    """
    In-memory progress adapter for local, single-instance deployments.

    Features:
        1. Keeps the latest progress snapshot per task in a plain dict.
        2. Implements publish/subscribe with one ``asyncio.Queue`` per subscriber.
        3. Supports multiple concurrent subscribers for the same task.
        4. Intended to be compatible with AsyncTrainingManager's progress-push
           mechanism (per the original design notes).

    Notes:
        - Progress data lives only in process memory and is lost on restart.
        - Only suitable for single-instance deployments.
        - Server mode should use RedisProgressAdapter instead.
        - Publishing never blocks: when a subscriber's queue is full, the
          message is dropped for that subscriber only.

    Example:
        >>> adapter = LocalProgressAdapter()
        >>> await adapter.update_progress("task-123", {
        ...     "stage": "sovits_train",
        ...     "progress": 0.5,
        ...     "message": "Epoch 8/16"
        ... })
        >>>
        >>> # Subscribe to progress
        >>> async for progress in adapter.subscribe("task-123"):
        ...     print(f"{progress['stage']}: {progress['progress']*100:.1f}%")
    """

    # Statuses after which a subscription ends automatically.
    _TERMINAL_STATUSES = frozenset({"completed", "failed", "cancelled"})

    def __init__(
        self,
        queue_maxsize: int = 100,
        heartbeat_interval: float = 30.0,
    ):
        """
        Initialize the local progress adapter.

        Args:
            queue_maxsize: Capacity of each subscriber's queue. Updates that
                arrive while a subscriber's queue is full are dropped for that
                subscriber. A value <= 0 makes the queues unbounded.
            heartbeat_interval: Seconds a subscriber waits for an update before
                a heartbeat message is yielded instead.
        """
        # task_id -> latest progress snapshot
        self.progress_store: Dict[str, Dict[str, Any]] = {}

        # task_id -> queues of the currently registered subscribers
        self.subscribers: Dict[str, List[asyncio.Queue]] = defaultdict(list)

        # Guards registration/removal of subscriber queues and fan-out over them.
        self._lock = asyncio.Lock()

        self._queue_maxsize = queue_maxsize
        self._heartbeat_interval = heartbeat_interval

    @staticmethod
    def _offer(queue: asyncio.Queue, message: Dict[str, Any]) -> bool:
        """Enqueue ``message`` without waiting; return False if the queue is full."""
        try:
            # put_nowait (not ``await put``): a stalled subscriber must never
            # block the publisher, especially while ``_lock`` is held.
            queue.put_nowait(message)
        except asyncio.QueueFull:
            return False
        return True

    async def update_progress(self, task_id: str, progress: Dict[str, Any]) -> None:
        """
        Store the latest progress for a task and push it to its subscribers.

        Args:
            task_id: Task ID.
            progress: Progress dict. It may contain:
                - type: message type ("progress", "log", "error", "heartbeat")
                - stage: current stage
                - progress: stage progress (0.0-1.0)
                - overall_progress: overall progress (0.0-1.0)
                - message: progress message
                - status: status ("running", "completed", "failed", "cancelled")
                A "timestamp" (naive UTC ISO-8601) is added if absent. The
                caller's dict is not modified; a shallow copy is stored.
        """
        # Copy so the caller's dict is neither mutated here nor able to alter
        # the stored snapshot afterwards.
        progress = dict(progress)
        if "timestamp" not in progress:
            progress["timestamp"] = datetime.utcnow().isoformat()

        self.progress_store[task_id] = progress

        async with self._lock:
            for queue in self.subscribers.get(task_id, ()):
                # Best effort: a full queue just misses this update.
                self._offer(queue, progress)

    async def get_progress(self, task_id: str) -> Optional[Dict[str, Any]]:
        """
        Get the current progress of a task.

        Args:
            task_id: Task ID.

        Returns:
            The latest progress dict, or None if the task has none stored.
        """
        return self.progress_store.get(task_id)

    async def subscribe(self, task_id: str) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Subscribe to progress updates for a task.

        Returns an async generator that yields the current snapshot (if any)
        and then every subsequent update. It stops automatically after yielding
        a message whose status is terminal (completed, failed, cancelled). If
        no update arrives within ``heartbeat_interval`` seconds, a
        ``{"type": "heartbeat", "timestamp": ...}`` message is yielded instead.

        Args:
            task_id: Task ID.

        Yields:
            Progress dicts (or heartbeat dicts).

        Example:
            >>> async for progress in adapter.subscribe("task-123"):
            ...     print(progress)
            ...     if progress.get("status") == "completed":
            ...         break
        """
        queue: asyncio.Queue = asyncio.Queue(maxsize=self._queue_maxsize)

        # Register before reading the snapshot so no update published in
        # between is missed (at worst the latest update is delivered twice).
        async with self._lock:
            self.subscribers[task_id].append(queue)

        try:
            current = self.progress_store.get(task_id)
            if current:
                yield current
                if current.get("status") in self._TERMINAL_STATUSES:
                    return

            while True:
                try:
                    progress = await asyncio.wait_for(
                        queue.get(), timeout=self._heartbeat_interval
                    )
                except asyncio.TimeoutError:
                    # Idle period: emit a heartbeat so the consumer still gets
                    # periodic output while the task is quiet.
                    yield {
                        "type": "heartbeat",
                        "timestamp": datetime.utcnow().isoformat(),
                    }
                    continue

                yield progress

                if progress.get("status") in self._TERMINAL_STATUSES:
                    break

        finally:
            # Unregister this subscriber; drop the task entry once empty.
            async with self._lock:
                queues = self.subscribers.get(task_id)
                if queues is not None:
                    try:
                        queues.remove(queue)
                    except ValueError:
                        # Already detached (e.g. by clear_progress).
                        pass
                    if not queues:
                        del self.subscribers[task_id]

    async def clear_progress(self, task_id: str) -> None:
        """
        Remove the stored progress and subscriber registrations of a task.

        Subscribers that are still iterating are detached: they receive no
        further updates for this task, only heartbeats, until their consumer
        stops iterating.

        Args:
            task_id: Task ID.
        """
        self.progress_store.pop(task_id, None)

        async with self._lock:
            self.subscribers.pop(task_id, None)

    async def get_subscriber_count(self, task_id: str) -> int:
        """
        Get the number of subscribers registered for a task.

        Args:
            task_id: Task ID.

        Returns:
            Number of subscribers.
        """
        async with self._lock:
            return len(self.subscribers.get(task_id, []))

    async def broadcast_to_all(self, message: Dict[str, Any]) -> int:
        """
        Broadcast a message to every subscriber of every task.

        Intended for system-level notices such as server-shutdown warnings.

        Args:
            message: Message content. A "timestamp" is added if absent; the
                caller's dict is not modified.

        Returns:
            Number of subscribers the message was successfully enqueued for
            (subscribers with a full queue are skipped).
        """
        message = dict(message)
        if "timestamp" not in message:
            message["timestamp"] = datetime.utcnow().isoformat()

        count = 0
        async with self._lock:
            for queues in self.subscribers.values():
                for queue in queues:
                    if self._offer(queue, message):
                        count += 1

        return count

    def get_active_tasks(self) -> List[str]:
        """
        Get the IDs of tasks that currently have registered subscribers.

        Returns:
            List of task IDs.
        """
        return list(self.subscribers.keys())

    def get_stats(self) -> Dict[str, Any]:
        """
        Get adapter statistics.

        Returns:
            Dict with the number of stored progress snapshots, the number of
            tasks with subscribers, and the total number of subscribers.
        """
        total_subscribers = sum(
            len(queues) for queues in self.subscribers.values()
        )

        return {
            "stored_progress_count": len(self.progress_store),
            "active_tasks": len(self.subscribers),
            "total_subscribers": total_subscribers,
        }
|
|
|