|
|
""" |
|
|
本地异步任务管理器 |
|
|
|
|
|
基于 asyncio.subprocess + SQLite 的本地任务队列实现。 |
|
|
适用于 macOS 本地训练和 Electron 集成场景。 |
|
|
""" |
|
|
|
|
|
import asyncio |
|
|
import json |
|
|
import os |
|
|
import sqlite3 |
|
|
import sys |
|
|
import uuid |
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
from typing import TYPE_CHECKING, Dict, Optional, AsyncGenerator, List |
|
|
|
|
|
import aiosqlite |
|
|
|
|
|
from project_config import settings, PROJECT_ROOT, get_pythonpath |
|
|
from ..base import TaskQueueAdapter |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from ..base import DatabaseAdapter |
|
|
|
|
|
|
|
|
PROGRESS_PREFIX = "##PROGRESS##" |
|
|
PROGRESS_SUFFIX = "##" |
|
|
|
|
|
|
|
|
class AsyncTrainingManager(TaskQueueAdapter):
    """
    Asynchronous task manager built on asyncio.subprocess.

    Features:
        1. Launches training subprocesses asynchronously via
           asyncio.create_subprocess_exec()
        2. Fully non-blocking; fits FastAPI's async execution model
        3. Persists task state in SQLite so tasks survive app restarts
        4. Parses subprocess output in real time to track progress

    Example:
        >>> manager = AsyncTrainingManager(db_path="./data/tasks.db")
        >>> job_id = await manager.enqueue("task-123", {"exp_name": "test", ...})
        >>>
        >>> # Subscribe to progress updates
        >>> async for progress in manager.subscribe_progress("task-123"):
        ...     print(f"{progress['stage']}: {progress['progress']*100:.1f}%")
        >>>
        >>> # Cancel the task
        >>> await manager.cancel(job_id)
    """
|
|
|
|
|
    def __init__(
        self,
        db_path: Optional[str] = None,
        max_concurrent: int = 1,
        database_adapter: Optional["DatabaseAdapter"] = None
    ):
        """
        Initialize the task manager.

        Args:
            db_path: Path to the SQLite database file; defaults to
                settings.SQLITE_PATH when not given.
            max_concurrent: Maximum number of concurrent tasks (usually 1
                for local runs). NOTE(review): this value is stored but
                never enforced anywhere in this class — confirm intent.
            database_adapter: Adapter used to mirror status updates into
                the main ``tasks`` table (optional).
        """
        self.db_path = db_path or str(settings.SQLITE_PATH)
        self.max_concurrent = max_concurrent
        self._database_adapter = database_adapter

        # Live subprocess handles and per-task progress queues, keyed by task_id.
        self.running_processes: Dict[str, asyncio.subprocess.Process] = {}
        self.progress_channels: Dict[str, asyncio.Queue] = {}
        self._running_count = 0
        self._queue_lock = asyncio.Lock()

        # task_id -> job_id lookup for tasks enqueued in this process.
        self._task_job_mapping: Dict[str, str] = {}

        # Create the task_queue table/indexes up front (blocking; startup only).
        self._init_db_sync()
|
|
|
|
|
def _init_db_sync(self) -> None: |
|
|
"""同步初始化数据库(启动时调用)""" |
|
|
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
with sqlite3.connect(self.db_path) as conn: |
|
|
conn.execute(''' |
|
|
CREATE TABLE IF NOT EXISTS task_queue ( |
|
|
job_id TEXT PRIMARY KEY, |
|
|
task_id TEXT NOT NULL UNIQUE, |
|
|
exp_name TEXT NOT NULL, |
|
|
config TEXT NOT NULL, |
|
|
status TEXT DEFAULT 'queued', |
|
|
current_stage TEXT, |
|
|
progress REAL DEFAULT 0, |
|
|
overall_progress REAL DEFAULT 0, |
|
|
message TEXT, |
|
|
error_message TEXT, |
|
|
created_at TEXT NOT NULL, |
|
|
started_at TEXT, |
|
|
completed_at TEXT |
|
|
) |
|
|
''') |
|
|
conn.execute('CREATE INDEX IF NOT EXISTS idx_task_queue_status ON task_queue(status)') |
|
|
conn.execute('CREATE INDEX IF NOT EXISTS idx_task_queue_task_id ON task_queue(task_id)') |
|
|
conn.execute('CREATE INDEX IF NOT EXISTS idx_task_queue_created ON task_queue(created_at)') |
|
|
conn.commit() |
|
|
|
|
|
async def enqueue(self, task_id: str, config: Dict, priority: str = "normal") -> str: |
|
|
""" |
|
|
将任务加入队列并异步启动 |
|
|
|
|
|
Args: |
|
|
task_id: 任务唯一标识 |
|
|
config: 任务配置,需包含: |
|
|
- exp_name: 实验名称 |
|
|
- version: 模型版本 |
|
|
- stages: 阶段配置列表 |
|
|
priority: 优先级(当前实现忽略此参数) |
|
|
|
|
|
Returns: |
|
|
job_id: 作业ID |
|
|
""" |
|
|
job_id = str(uuid.uuid4()) |
|
|
exp_name = config.get("exp_name", "unknown") |
|
|
|
|
|
|
|
|
async with aiosqlite.connect(self.db_path) as db: |
|
|
await db.execute( |
|
|
'''INSERT INTO task_queue |
|
|
(job_id, task_id, exp_name, config, status, created_at) |
|
|
VALUES (?, ?, ?, ?, 'queued', ?)''', |
|
|
(job_id, task_id, exp_name, json.dumps(config, ensure_ascii=False), |
|
|
datetime.utcnow().isoformat()) |
|
|
) |
|
|
await db.commit() |
|
|
|
|
|
|
|
|
self._task_job_mapping[task_id] = job_id |
|
|
|
|
|
|
|
|
self.progress_channels[task_id] = asyncio.Queue() |
|
|
|
|
|
|
|
|
asyncio.create_task(self._run_training_async(job_id, task_id, config)) |
|
|
|
|
|
return job_id |
|
|
|
|
|
async def _run_training_async(self, job_id: str, task_id: str, config: Dict) -> None: |
|
|
""" |
|
|
异步执行训练 Pipeline |
|
|
|
|
|
Args: |
|
|
job_id: 作业ID |
|
|
task_id: 任务ID |
|
|
config: 任务配置 |
|
|
""" |
|
|
config_path = None |
|
|
|
|
|
try: |
|
|
|
|
|
await self._update_status( |
|
|
job_id, |
|
|
status='running', |
|
|
started_at=datetime.utcnow().isoformat() |
|
|
) |
|
|
await self._send_progress(task_id, { |
|
|
"type": "progress", |
|
|
"status": "running", |
|
|
"message": "训练任务启动中...", |
|
|
"progress": 0.0, |
|
|
"overall_progress": 0.0, |
|
|
}) |
|
|
|
|
|
|
|
|
config_path = await self._write_config_file(task_id, config) |
|
|
|
|
|
|
|
|
script_path = self._get_pipeline_script_path() |
|
|
|
|
|
|
|
|
env = os.environ.copy() |
|
|
env['PYTHONPATH'] = get_pythonpath() |
|
|
|
|
|
|
|
|
process = await asyncio.create_subprocess_exec( |
|
|
sys.executable, script_path, |
|
|
'--config', config_path, |
|
|
'--task-id', task_id, |
|
|
stdout=asyncio.subprocess.PIPE, |
|
|
stderr=asyncio.subprocess.PIPE, |
|
|
env=env, |
|
|
cwd=str(PROJECT_ROOT), |
|
|
) |
|
|
|
|
|
self.running_processes[task_id] = process |
|
|
self._running_count += 1 |
|
|
|
|
|
|
|
|
await self._monitor_process_output(task_id, job_id, process) |
|
|
|
|
|
|
|
|
returncode = await process.wait() |
|
|
|
|
|
if returncode == 0: |
|
|
await self._update_status( |
|
|
job_id, |
|
|
status='completed', |
|
|
progress=1.0, |
|
|
overall_progress=1.0, |
|
|
message='训练完成', |
|
|
completed_at=datetime.utcnow().isoformat() |
|
|
) |
|
|
await self._send_progress(task_id, { |
|
|
"type": "progress", |
|
|
"status": "completed", |
|
|
"message": "训练完成", |
|
|
"progress": 1.0, |
|
|
"overall_progress": 1.0, |
|
|
}) |
|
|
else: |
|
|
|
|
|
stderr_data = await process.stderr.read() |
|
|
error_msg = stderr_data.decode() if stderr_data else f"进程退出码: {returncode}" |
|
|
|
|
|
await self._update_status( |
|
|
job_id, |
|
|
status='failed', |
|
|
error_message=error_msg, |
|
|
completed_at=datetime.utcnow().isoformat() |
|
|
) |
|
|
await self._send_progress(task_id, { |
|
|
"type": "progress", |
|
|
"status": "failed", |
|
|
"message": f"训练失败: {error_msg[:200]}", |
|
|
"error": error_msg, |
|
|
}) |
|
|
|
|
|
except asyncio.CancelledError: |
|
|
await self._update_status( |
|
|
job_id, |
|
|
status='cancelled', |
|
|
message='任务已取消', |
|
|
completed_at=datetime.utcnow().isoformat() |
|
|
) |
|
|
await self._send_progress(task_id, { |
|
|
"type": "progress", |
|
|
"status": "cancelled", |
|
|
"message": "任务已取消", |
|
|
}) |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = str(e) |
|
|
await self._update_status( |
|
|
job_id, |
|
|
status='failed', |
|
|
error_message=error_msg, |
|
|
completed_at=datetime.utcnow().isoformat() |
|
|
) |
|
|
await self._send_progress(task_id, { |
|
|
"type": "progress", |
|
|
"status": "failed", |
|
|
"message": f"任务执行出错: {error_msg}", |
|
|
"error": error_msg, |
|
|
}) |
|
|
|
|
|
finally: |
|
|
|
|
|
self.running_processes.pop(task_id, None) |
|
|
self._running_count = max(0, self._running_count - 1) |
|
|
|
|
|
|
|
|
if config_path: |
|
|
await self._cleanup_config_file(config_path) |
|
|
|
|
|
async def _monitor_process_output( |
|
|
self, |
|
|
task_id: str, |
|
|
job_id: str, |
|
|
process: asyncio.subprocess.Process |
|
|
) -> None: |
|
|
""" |
|
|
异步监控子进程输出并解析进度 |
|
|
|
|
|
Args: |
|
|
task_id: 任务ID |
|
|
job_id: 作业ID |
|
|
process: 子进程对象 |
|
|
""" |
|
|
async def read_stdout(): |
|
|
"""读取 stdout 并解析进度""" |
|
|
while True: |
|
|
line = await process.stdout.readline() |
|
|
if not line: |
|
|
break |
|
|
|
|
|
text = line.decode('utf-8', errors='replace').strip() |
|
|
if not text: |
|
|
continue |
|
|
|
|
|
|
|
|
if text.startswith(PROGRESS_PREFIX) and text.endswith(PROGRESS_SUFFIX): |
|
|
json_str = text[len(PROGRESS_PREFIX):-len(PROGRESS_SUFFIX)] |
|
|
try: |
|
|
progress_info = json.loads(json_str) |
|
|
await self._handle_progress(task_id, job_id, progress_info) |
|
|
except json.JSONDecodeError as e: |
|
|
|
|
|
await self._send_progress(task_id, { |
|
|
"type": "log", |
|
|
"level": "warning", |
|
|
"message": f"进度解析失败: {e}", |
|
|
}) |
|
|
else: |
|
|
|
|
|
await self._send_progress(task_id, { |
|
|
"type": "log", |
|
|
"level": "info", |
|
|
"message": text, |
|
|
}) |
|
|
|
|
|
async def read_stderr(): |
|
|
"""读取 stderr 作为错误日志""" |
|
|
while True: |
|
|
line = await process.stderr.readline() |
|
|
if not line: |
|
|
break |
|
|
|
|
|
text = line.decode('utf-8', errors='replace').strip() |
|
|
if text: |
|
|
await self._send_progress(task_id, { |
|
|
"type": "log", |
|
|
"level": "error", |
|
|
"message": text, |
|
|
}) |
|
|
|
|
|
|
|
|
await asyncio.gather( |
|
|
read_stdout(), |
|
|
read_stderr(), |
|
|
return_exceptions=True |
|
|
) |
|
|
|
|
|
async def _handle_progress( |
|
|
self, |
|
|
task_id: str, |
|
|
job_id: str, |
|
|
progress_info: Dict |
|
|
) -> None: |
|
|
""" |
|
|
处理进度信息 |
|
|
|
|
|
Args: |
|
|
task_id: 任务ID |
|
|
job_id: 作业ID |
|
|
progress_info: 进度信息字典 |
|
|
""" |
|
|
|
|
|
await self._send_progress(task_id, progress_info) |
|
|
|
|
|
|
|
|
updates = {} |
|
|
|
|
|
if 'stage' in progress_info: |
|
|
updates['current_stage'] = progress_info['stage'] |
|
|
if 'progress' in progress_info: |
|
|
updates['progress'] = progress_info['progress'] |
|
|
if 'overall_progress' in progress_info: |
|
|
updates['overall_progress'] = progress_info['overall_progress'] |
|
|
if 'message' in progress_info: |
|
|
updates['message'] = progress_info['message'] |
|
|
if 'status' in progress_info: |
|
|
updates['status'] = progress_info['status'] |
|
|
if 'error' in progress_info: |
|
|
updates['error_message'] = progress_info['error'] |
|
|
|
|
|
if updates: |
|
|
await self._update_status(job_id, **updates) |
|
|
|
|
|
async def _send_progress(self, task_id: str, progress_info: Dict) -> None: |
|
|
""" |
|
|
发送进度到订阅队列 |
|
|
|
|
|
Args: |
|
|
task_id: 任务ID |
|
|
progress_info: 进度信息 |
|
|
""" |
|
|
if task_id in self.progress_channels: |
|
|
|
|
|
if 'timestamp' not in progress_info: |
|
|
progress_info['timestamp'] = datetime.utcnow().isoformat() |
|
|
|
|
|
await self.progress_channels[task_id].put(progress_info) |
|
|
|
|
|
async def _update_status(self, job_id: str, **kwargs) -> None: |
|
|
""" |
|
|
更新任务状态 |
|
|
|
|
|
同时更新 task_queue 表和 tasks 表(通过 DatabaseAdapter)。 |
|
|
|
|
|
Args: |
|
|
job_id: 作业ID |
|
|
**kwargs: 要更新的字段 |
|
|
""" |
|
|
if not kwargs: |
|
|
return |
|
|
|
|
|
|
|
|
task_id = None |
|
|
async with aiosqlite.connect(self.db_path) as db: |
|
|
updates = [] |
|
|
values = [] |
|
|
|
|
|
for key, value in kwargs.items(): |
|
|
updates.append(f"{key} = ?") |
|
|
values.append(value) |
|
|
|
|
|
values.append(job_id) |
|
|
|
|
|
await db.execute( |
|
|
f"UPDATE task_queue SET {', '.join(updates)} WHERE job_id = ?", |
|
|
values |
|
|
) |
|
|
await db.commit() |
|
|
|
|
|
|
|
|
async with db.execute( |
|
|
"SELECT task_id FROM task_queue WHERE job_id = ?", (job_id,) |
|
|
) as cursor: |
|
|
row = await cursor.fetchone() |
|
|
if row: |
|
|
task_id = row[0] |
|
|
|
|
|
|
|
|
if self._database_adapter and task_id: |
|
|
await self._sync_to_tasks_table(task_id, kwargs) |
|
|
|
|
|
async def _sync_to_tasks_table(self, task_id: str, updates: Dict) -> None: |
|
|
""" |
|
|
同步状态更新到 tasks 表 |
|
|
|
|
|
字段映射: |
|
|
- task_queue.progress -> tasks.stage_progress |
|
|
- task_queue.overall_progress -> tasks.progress |
|
|
- 其他字段直接映射 |
|
|
|
|
|
Args: |
|
|
task_id: 任务ID |
|
|
updates: 要更新的字段字典 |
|
|
""" |
|
|
if not self._database_adapter: |
|
|
return |
|
|
|
|
|
|
|
|
tasks_updates = {} |
|
|
|
|
|
for key, value in updates.items(): |
|
|
if key == 'progress': |
|
|
|
|
|
tasks_updates['stage_progress'] = value |
|
|
elif key == 'overall_progress': |
|
|
|
|
|
tasks_updates['progress'] = value |
|
|
elif key in ('status', 'current_stage', 'message', 'error_message', |
|
|
'started_at', 'completed_at'): |
|
|
|
|
|
tasks_updates[key] = value |
|
|
|
|
|
if tasks_updates: |
|
|
try: |
|
|
await self._database_adapter.update_task(task_id, tasks_updates) |
|
|
except Exception as e: |
|
|
|
|
|
import logging |
|
|
logging.warning(f"Failed to sync task status to tasks table: {e}") |
|
|
|
|
|
async def get_status(self, job_id: str) -> Dict: |
|
|
""" |
|
|
获取任务状态 |
|
|
|
|
|
Args: |
|
|
job_id: 作业ID |
|
|
|
|
|
Returns: |
|
|
状态字典 |
|
|
""" |
|
|
async with aiosqlite.connect(self.db_path) as db: |
|
|
db.row_factory = aiosqlite.Row |
|
|
async with db.execute( |
|
|
"SELECT * FROM task_queue WHERE job_id = ?", (job_id,) |
|
|
) as cursor: |
|
|
row = await cursor.fetchone() |
|
|
if row: |
|
|
return dict(row) |
|
|
|
|
|
return {"status": "not_found", "message": "任务不存在"} |
|
|
|
|
|
async def get_status_by_task_id(self, task_id: str) -> Dict: |
|
|
""" |
|
|
通过 task_id 获取任务状态 |
|
|
|
|
|
Args: |
|
|
task_id: 任务ID |
|
|
|
|
|
Returns: |
|
|
状态字典 |
|
|
""" |
|
|
async with aiosqlite.connect(self.db_path) as db: |
|
|
db.row_factory = aiosqlite.Row |
|
|
async with db.execute( |
|
|
"SELECT * FROM task_queue WHERE task_id = ?", (task_id,) |
|
|
) as cursor: |
|
|
row = await cursor.fetchone() |
|
|
if row: |
|
|
return dict(row) |
|
|
|
|
|
return {"status": "not_found", "message": "任务不存在"} |
|
|
|
|
|
    async def cancel(self, job_id: str) -> bool:
        """
        Cancel a task.

        For a running task the subprocess is terminated (SIGTERM, then
        SIGKILL after a 5s grace period). For a queued task the row is
        marked cancelled directly.

        Args:
            job_id: Job identifier.

        Returns:
            True when cancellation was performed; False when the job does
            not exist or has already reached a terminal state.
        """
        # Resolve the task and its current state.
        async with aiosqlite.connect(self.db_path) as db:
            async with db.execute(
                "SELECT task_id, status FROM task_queue WHERE job_id = ?", (job_id,)
            ) as cursor:
                row = await cursor.fetchone()
                if not row:
                    return False
                task_id, status = row

        # Terminal states cannot be cancelled again.
        if status in ('completed', 'failed', 'cancelled'):
            return False

        if task_id in self.running_processes:
            process = self.running_processes[task_id]

            # Ask politely first (SIGTERM)...
            process.terminate()

            try:
                # ...then give it 5 seconds to exit before force-killing.
                await asyncio.wait_for(process.wait(), timeout=5.0)
            except asyncio.TimeoutError:
                process.kill()
                await process.wait()

            # NOTE(review): this early return skips the 'cancelled' status
            # update below; the monitoring coroutine in _run_training_async
            # will observe the non-zero exit code and record the job as
            # 'failed' rather than 'cancelled' — confirm this is intended.
            return True

        # Queued (not yet running): mark cancelled in the queue table.
        await self._update_status(
            job_id,
            status='cancelled',
            message='任务已取消',
            completed_at=datetime.utcnow().isoformat()
        )

        # Notify any live subscriber.
        if task_id in self.progress_channels:
            await self._send_progress(task_id, {
                "type": "progress",
                "status": "cancelled",
                "message": "任务已取消",
            })

        return True
|
|
|
|
|
async def subscribe_progress(self, task_id: str) -> AsyncGenerator[Dict, None]: |
|
|
""" |
|
|
订阅任务进度(用于 SSE 流) |
|
|
|
|
|
Args: |
|
|
task_id: 任务ID |
|
|
|
|
|
Yields: |
|
|
进度信息字典 |
|
|
""" |
|
|
|
|
|
if task_id not in self.progress_channels: |
|
|
self.progress_channels[task_id] = asyncio.Queue() |
|
|
|
|
|
queue = self.progress_channels[task_id] |
|
|
|
|
|
|
|
|
status = await self.get_status_by_task_id(task_id) |
|
|
if status.get("status") != "not_found": |
|
|
yield { |
|
|
"type": "progress", |
|
|
"status": status.get("status"), |
|
|
"stage": status.get("current_stage"), |
|
|
"progress": status.get("progress", 0.0), |
|
|
"overall_progress": status.get("overall_progress", 0.0), |
|
|
"message": status.get("message"), |
|
|
"timestamp": datetime.utcnow().isoformat(), |
|
|
} |
|
|
|
|
|
|
|
|
while True: |
|
|
try: |
|
|
|
|
|
progress = await asyncio.wait_for(queue.get(), timeout=30.0) |
|
|
yield progress |
|
|
|
|
|
|
|
|
if progress.get('status') in ('completed', 'failed', 'cancelled'): |
|
|
break |
|
|
|
|
|
except asyncio.TimeoutError: |
|
|
|
|
|
yield { |
|
|
"type": "heartbeat", |
|
|
"timestamp": datetime.utcnow().isoformat(), |
|
|
} |
|
|
|
|
|
async def list_tasks( |
|
|
self, |
|
|
status: Optional[str] = None, |
|
|
limit: int = 50, |
|
|
offset: int = 0 |
|
|
) -> List[Dict]: |
|
|
""" |
|
|
列出任务 |
|
|
|
|
|
Args: |
|
|
status: 按状态筛选 |
|
|
limit: 返回数量限制 |
|
|
offset: 偏移量 |
|
|
|
|
|
Returns: |
|
|
任务列表 |
|
|
""" |
|
|
async with aiosqlite.connect(self.db_path) as db: |
|
|
db.row_factory = aiosqlite.Row |
|
|
|
|
|
if status: |
|
|
query = """ |
|
|
SELECT * FROM task_queue |
|
|
WHERE status = ? |
|
|
ORDER BY created_at DESC |
|
|
LIMIT ? OFFSET ? |
|
|
""" |
|
|
params = (status, limit, offset) |
|
|
else: |
|
|
query = """ |
|
|
SELECT * FROM task_queue |
|
|
ORDER BY created_at DESC |
|
|
LIMIT ? OFFSET ? |
|
|
""" |
|
|
params = (limit, offset) |
|
|
|
|
|
async with db.execute(query, params) as cursor: |
|
|
rows = await cursor.fetchall() |
|
|
return [dict(row) for row in rows] |
|
|
|
|
|
async def recover_pending_tasks(self) -> int: |
|
|
""" |
|
|
应用重启后恢复未完成的任务 |
|
|
|
|
|
将 running 状态的任务标记为 interrupted, |
|
|
可选择重新启动 queued 状态的任务。 |
|
|
|
|
|
Returns: |
|
|
恢复的任务数量 |
|
|
""" |
|
|
async with aiosqlite.connect(self.db_path) as db: |
|
|
|
|
|
await db.execute( |
|
|
"""UPDATE task_queue |
|
|
SET status = 'interrupted', |
|
|
message = '应用重启导致任务中断' |
|
|
WHERE status = 'running'""" |
|
|
) |
|
|
await db.commit() |
|
|
|
|
|
|
|
|
db.row_factory = aiosqlite.Row |
|
|
async with db.execute( |
|
|
"SELECT * FROM task_queue WHERE status = 'queued' ORDER BY created_at" |
|
|
) as cursor: |
|
|
queued_tasks = await cursor.fetchall() |
|
|
|
|
|
|
|
|
recovered = 0 |
|
|
for task in queued_tasks: |
|
|
task_id = task['task_id'] |
|
|
job_id = task['job_id'] |
|
|
config = json.loads(task['config']) |
|
|
|
|
|
self.progress_channels[task_id] = asyncio.Queue() |
|
|
asyncio.create_task(self._run_training_async(job_id, task_id, config)) |
|
|
recovered += 1 |
|
|
|
|
|
return recovered |
|
|
|
|
|
async def cleanup_old_tasks(self, days: int = 7) -> int: |
|
|
""" |
|
|
清理旧任务记录 |
|
|
|
|
|
Args: |
|
|
days: 保留天数 |
|
|
|
|
|
Returns: |
|
|
删除的任务数量 |
|
|
""" |
|
|
from datetime import timedelta |
|
|
|
|
|
cutoff = (datetime.utcnow() - timedelta(days=days)).isoformat() |
|
|
|
|
|
async with aiosqlite.connect(self.db_path) as db: |
|
|
cursor = await db.execute( |
|
|
"""DELETE FROM task_queue |
|
|
WHERE status IN ('completed', 'failed', 'cancelled') |
|
|
AND completed_at < ?""", |
|
|
(cutoff,) |
|
|
) |
|
|
deleted = cursor.rowcount |
|
|
await db.commit() |
|
|
|
|
|
return deleted |
|
|
|
|
|
def _get_pipeline_script_path(self) -> str: |
|
|
"""获取 run_pipeline.py 脚本路径""" |
|
|
return str(settings.PIPELINE_SCRIPT_PATH) |
|
|
|
|
|
async def _write_config_file(self, task_id: str, config: Dict) -> str: |
|
|
""" |
|
|
写入临时配置文件 |
|
|
|
|
|
Args: |
|
|
task_id: 任务ID |
|
|
config: 配置字典 |
|
|
|
|
|
Returns: |
|
|
配置文件路径 |
|
|
""" |
|
|
config_path = settings.CONFIGS_DIR / f"{task_id}.json" |
|
|
|
|
|
with open(config_path, 'w', encoding='utf-8') as f: |
|
|
json.dump(config, f, ensure_ascii=False, indent=2) |
|
|
|
|
|
return str(config_path) |
|
|
|
|
|
async def _cleanup_config_file(self, config_path: str) -> None: |
|
|
""" |
|
|
清理临时配置文件 |
|
|
|
|
|
Args: |
|
|
config_path: 配置文件路径 |
|
|
""" |
|
|
try: |
|
|
path = Path(config_path) |
|
|
if path.exists(): |
|
|
path.unlink() |
|
|
except Exception: |
|
|
pass |
|
|
|