Spaces:
Sleeping
Sleeping
| """任务管理模块 | |
| 提供任务状态跟踪、进度管理和任务队列功能。 | |
| """ | |
| import asyncio | |
| import time | |
| import uuid | |
| from dataclasses import dataclass, field | |
| from datetime import datetime, timedelta | |
| from enum import Enum | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Callable, Any | |
| from concurrent.futures import ThreadPoolExecutor | |
| from ..core.config import get_config | |
| from ..utils.logger import get_task_logger | |
| from ..services.file_validator import get_file_validator | |
| from ..services.oss_service import get_oss_service | |
| from ..services.paraformer_service import get_paraformer_service | |
class TaskStatus(Enum):
    """Lifecycle state of a task.

    The member value is the lowercase string that ``Task.to_dict()`` exposes
    under the ``status`` key.
    """

    # Queued, not yet picked up by the processor.
    PENDING = "pending"

    # Active stages, listed in the order TaskManager moves a task through them.
    VALIDATING = "validating"
    UPLOADING = "uploading"
    TRANSCRIBING = "transcribing"

    # Terminal states: a task in one of these can no longer be cancelled.
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
class TaskPriority(Enum):
    """Priority attached to a task; values increase from LOW to URGENT.

    NOTE(review): TaskManager enqueues task IDs into a FIFO ``asyncio.Queue``,
    so this value is recorded and reported but does not change processing order.
    """

    LOW = 1
    NORMAL = 2  # default used by Task and TaskManager.create_task
    HIGH = 3
    URGENT = 4
@dataclass
class TaskProgress:
    """Progress of the stage a task is currently in."""

    stage: str = ""          # human-readable name of the current stage
    current: int = 0         # units of work completed in this stage
    total: int = 100         # total units of work in this stage
    message: str = ""        # most recent progress message
    percentage: float = 0.0  # current / total * 100, clamped to [0, 100]

    def update(self, current: Optional[int] = None, total: Optional[int] = None,
               message: Optional[str] = None):
        """Update the progress counters and recompute ``percentage``.

        Args:
            current: new amount of completed work; ``None`` keeps the old value.
            total: new total amount of work; ``None`` keeps the old value.
            message: new progress message; ``None`` keeps the old value.

        When ``total`` is not positive the percentage is left unchanged, which
        avoids a division by zero.
        """
        if current is not None:
            self.current = current
        if total is not None:
            self.total = total
        if message is not None:
            self.message = message
        if self.total > 0:
            ratio = (self.current / self.total) * 100
            # Clamp both ends so out-of-range counters never yield nonsense.
            self.percentage = max(0.0, min(100.0, ratio))
@dataclass
class TaskResult:
    """Outcome of a task."""

    success: bool = False
    data: Optional[Dict] = None
    error_message: Optional[str] = None
    # Files successfully processed / files that failed.
    # The dataclass decorator turns these default factories into fresh per-instance lists.
    processed_files: List[str] = field(default_factory=list)
    failed_files: List[str] = field(default_factory=list)
    transcription_results: Optional[Dict] = None
    duration: float = 0.0  # seconds

    def to_dict(self) -> Dict:
        """Return the result as a plain dict.

        The file lists are copied so callers cannot mutate this result through
        the returned dict.
        """
        return {
            'success': self.success,
            'data': self.data,
            'error_message': self.error_message,
            'processed_files': list(self.processed_files),
            'failed_files': list(self.failed_files),
            'transcription_results': self.transcription_results,
            'duration': self.duration
        }
@dataclass
class Task:
    """A transcription task: its input files, lifecycle status, progress and result."""

    # First 8 characters of a random UUID4 string.
    # NOTE(review): only 32 random bits; collisions become plausible with tens of
    # thousands of live tasks — confirm this is acceptable.
    id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])
    status: TaskStatus = TaskStatus.PENDING
    priority: TaskPriority = TaskPriority.NORMAL
    file_paths: List[Path] = field(default_factory=list)
    progress: TaskProgress = field(default_factory=TaskProgress)
    result: TaskResult = field(default_factory=TaskResult)
    created_at: datetime = field(default_factory=datetime.now)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    # NOTE(review): not referenced by TaskManager in this file; confirm it is still used.
    callback: Optional[Callable] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict:
        """Serialise the task into a JSON-friendly dict.

        Enums are emitted as their values and timestamps as ISO-8601 strings
        (``None`` when unset).
        """
        def _iso(moment: Optional[datetime]) -> Optional[str]:
            return moment.isoformat() if moment else None

        return {
            'id': self.id,
            'status': self.status.value,
            'priority': self.priority.value,
            'file_count': len(self.file_paths),
            # Path() also accepts plain strings, so str paths are handled too.
            'file_names': [Path(fp).name for fp in self.file_paths],
            'progress': {
                'stage': self.progress.stage,
                'current': self.progress.current,
                'total': self.progress.total,
                'percentage': self.progress.percentage,
                'message': self.progress.message
            },
            'result': self.result.to_dict(),
            'created_at': _iso(self.created_at),
            'started_at': _iso(self.started_at),
            'completed_at': _iso(self.completed_at),
            'metadata': self.metadata
        }
class TaskManager:
    """Task manager.

    Keeps an in-memory registry of tasks and an asyncio queue of task IDs. A
    background coroutine takes IDs off the queue and drives each task through
    file validation, OSS upload and Paraformer transcription, one task at a time.
    """

    # Statuses after which a task can no longer be cancelled or processed.
    _TERMINAL_STATUSES = (TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED)

    def __init__(self):
        """Initialise the manager and start the processor if an event loop is running."""
        self.config = get_config()
        self.logger = get_task_logger(logger_name="transcript_service.task")

        # Task registry (task_id -> Task) and the queue of task IDs to process.
        # NOTE(review): asyncio.Queue is FIFO, so Task.priority does not affect order.
        self.tasks: Dict[str, Task] = {}
        self.task_queue: asyncio.Queue = asyncio.Queue(maxsize=self.config.task.queue_size)

        # Service instances
        self.file_validator = get_file_validator()
        self.oss_service = get_oss_service()
        self.paraformer_service = get_paraformer_service()

        # Worker thread pool.
        # NOTE(review): no method in this class submits work to it.
        self.executor = ThreadPoolExecutor(max_workers=self.config.app.concurrent_tasks)

        # Callbacks called with the Task whenever its status or progress changes.
        self.status_callbacks: List[Callable] = []

        # Background processor state.
        self._processor_started = False
        self._processor_task: Optional[asyncio.Task] = None

        # Start the processor now if possible; otherwise create_task() starts it.
        self._start_task_processor()

    def add_status_callback(self, callback: Callable):
        """Register a callback invoked as ``callback(task)`` on status/progress changes.

        Args:
            callback: callable taking the changed Task.
        """
        self.status_callbacks.append(callback)

    def _notify_status_change(self, task: Task):
        """Invoke every registered callback; a failing callback is logged, not raised."""
        for callback in self.status_callbacks:
            try:
                callback(task)
            except Exception as e:
                self.logger.exception(f"回调函数执行失败: {str(e)}")

    async def create_task(self, file_paths: List[Path], priority: TaskPriority = TaskPriority.NORMAL,
                          metadata: Optional[Dict[str, Any]] = None) -> str:
        """Create a task, register it and enqueue it for processing.

        Args:
            file_paths: paths of the files to process.
            priority: task priority (stored on the task; the queue is FIFO).
            metadata: optional task metadata; a ``paraformer_params`` entry is
                forwarded to the transcription service.

        Returns:
            The ID of the new task.
        """
        # Make sure something is consuming the queue before we add to it.
        self._ensure_processor_started()

        task = Task(
            file_paths=list(file_paths),
            priority=priority,
            metadata=metadata or {}
        )
        self.tasks[task.id] = task

        # Waits while the queue is full (maxsize = config.task.queue_size).
        await self.task_queue.put(task.id)

        self.logger.info(f"创建任务: {task.id}, 文件数量: {len(task.file_paths)}")
        return task.id

    def get_task(self, task_id: str) -> Optional[Task]:
        """Return the task with ``task_id``, or None if it is unknown.

        Args:
            task_id: task ID.

        Returns:
            The Task object, or None.
        """
        return self.tasks.get(task_id)

    def get_all_tasks(self) -> List[Task]:
        """Return all registered tasks."""
        return list(self.tasks.values())

    def get_tasks_by_status(self, status: TaskStatus) -> List[Task]:
        """Return the tasks currently in ``status``."""
        return [task for task in self.tasks.values() if task.status == status]

    async def cancel_task(self, task_id: str) -> bool:
        """Cancel a task that has not reached a terminal state.

        A running task is not interrupted immediately; it stops at the next
        stage boundary and keeps the CANCELLED status.

        Args:
            task_id: task ID.

        Returns:
            True if the task was cancelled, False if it is unknown or already finished.
        """
        task = self.get_task(task_id)
        if task is None or task.status in self._TERMINAL_STATUSES:
            return False

        task.status = TaskStatus.CANCELLED
        task.completed_at = datetime.now()
        task.progress.message = "任务已取消"

        self._notify_status_change(task)
        self.logger.info(f"任务已取消: {task_id}")
        return True

    def _spawn_processor(self) -> bool:
        """Start ``_process_tasks`` on the running loop unless it is already running.

        Returns:
            True if a processor is running after the call, False when the current
            thread has no running event loop.
        """
        if self._processor_task is not None and not self._processor_task.done():
            return True
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            self._processor_started = False
            return False
        # Keep a strong reference: the event loop only holds weak references to
        # tasks, so an unreferenced task may be garbage-collected mid-execution.
        self._processor_task = loop.create_task(self._process_tasks())
        self._processor_started = True
        return True

    def _start_task_processor(self):
        """Start the task processor if an event loop is running; otherwise defer it."""
        if not self._spawn_processor():
            # No running event loop; the processor will be started on demand.
            self.logger.debug("没有运行的事件循环,任务处理器将在需要时启动")

    def _ensure_processor_started(self):
        """Make sure the task processor is running; warn if it cannot be started."""
        if not self._spawn_processor():
            self.logger.warning("无法启动任务处理器:没有运行的事件循环")

    async def _process_tasks(self):
        """Consume task IDs from the queue forever, executing one task at a time."""
        while True:
            task_id = await self.task_queue.get()
            try:
                task = self.get_task(task_id)
                # Skip tasks cancelled (or cleaned up) while they were queued.
                if task is None or task.status == TaskStatus.CANCELLED:
                    continue
                await self._execute_task(task)
            except Exception as e:
                self.logger.exception(f"处理任务队列时发生错误: {str(e)}")
                await asyncio.sleep(1)
            finally:
                # Exactly one task_done() per get(), whatever happened above.
                self.task_queue.task_done()

    def _finish_task(self, task: Task, status: TaskStatus, error_message: Optional[str] = None):
        """Put ``task`` into terminal ``status``, record timing/error and notify listeners."""
        task.status = status
        task.completed_at = datetime.now()
        if task.started_at is not None:
            task.result.duration = (task.completed_at - task.started_at).total_seconds()
        if error_message is not None:
            task.result.error_message = error_message
        self._notify_status_change(task)

    async def _execute_task(self, task: Task):
        """Run one task through validation, upload and transcription.

        Every outcome is recorded on ``task`` and reported through the status
        callbacks. Exceptions are logged and turned into a FAILED status rather
        than propagated. If cancel_task() is called while the task runs, the task
        stops at the next stage boundary and stays CANCELLED.

        Args:
            task: the task to execute.
        """
        try:
            # Tag log records emitted while this task runs with its ID.
            self.logger.set_task_id(task.id)

            task.status = TaskStatus.VALIDATING
            task.started_at = datetime.now()
            task.progress.stage = "文件验证"
            task.progress.update(0, 100, "开始验证文件")
            self._notify_status_change(task)

            # 1. File validation
            valid_files, invalid_files = await self._validate_files(task)
            # Entries of invalid_files carry the offending path at index 0.
            rejected = [str(f[0]) for f in invalid_files]
            if task.status == TaskStatus.CANCELLED:
                return
            if not valid_files:
                task.result.failed_files = rejected
                self._finish_task(task, TaskStatus.FAILED, "没有有效的文件")
                return

            # 2. Upload to OSS
            task.status = TaskStatus.UPLOADING
            task.progress.stage = "文件上传"
            task.progress.update(0, len(valid_files), "开始上传文件到OSS")
            self._notify_status_change(task)

            upload_results = await self._upload_files(task, valid_files)
            if task.status == TaskStatus.CANCELLED:
                return
            # Each upload result is (file_name, success, url_or_error, object_key).
            successful_uploads = [r for r in upload_results if r[1]]
            task.result.failed_files = rejected + [r[0] for r in upload_results if not r[1]]
            if not successful_uploads:
                self._finish_task(task, TaskStatus.FAILED, "文件上传失败")
                return

            # 3. Transcription
            task.status = TaskStatus.TRANSCRIBING
            task.progress.stage = "语音转录"
            task.progress.update(0, 100, "开始语音转录")
            self._notify_status_change(task)

            file_urls = [r[2] for r in successful_uploads]
            success, transcription_result, error = await self._transcribe_audio(task, file_urls)
            if task.status == TaskStatus.CANCELLED:
                return

            # 4. Record the outcome
            if success:
                task.result.success = True
                task.result.transcription_results = transcription_result
                task.result.processed_files = [r[0] for r in successful_uploads]
                task.progress.update(100, 100, "转录完成")
                self._finish_task(task, TaskStatus.COMPLETED)
            else:
                self._finish_task(task, TaskStatus.FAILED, error)

        except Exception as e:
            self.logger.exception(f"执行任务时发生错误: {task.id}")
            # Never turn a user cancellation into a failure.
            if task.status != TaskStatus.CANCELLED:
                self._finish_task(task, TaskStatus.FAILED, f"任务执行失败: {str(e)}")
        finally:
            self.logger.clear_task_id()

    async def _validate_files(self, task: Task) -> tuple:
        """Validate the task's files.

        Returns:
            ``(valid_files, invalid_files)`` as returned by the file validator.
        """
        self.logger.info(f"开始验证 {len(task.file_paths)} 个文件")

        # NOTE(review): validate_multiple_files is called synchronously and will
        # block the event loop while it runs — consider offloading if it is slow.
        valid_files, invalid_files = self.file_validator.validate_multiple_files(task.file_paths)

        task.progress.update(100, 100, f"验证完成: {len(valid_files)} 个有效文件")
        self.logger.info(f"文件验证完成: {len(valid_files)} 个有效文件, {len(invalid_files)} 个无效文件")

        return valid_files, invalid_files

    async def _upload_files(self, task: Task, file_paths: List[Path]) -> List[tuple]:
        """Upload files one by one, stopping early if the task is cancelled.

        Returns:
            A list of ``(file_name, success, url_or_error, object_key)`` tuples.
        """
        self.logger.info(f"开始上传 {len(file_paths)} 个文件")

        results = []
        total = len(file_paths)
        for i, file_path in enumerate(file_paths, start=1):
            if task.status == TaskStatus.CANCELLED:
                break

            success, url_or_error, object_key = await self.oss_service.upload_file(file_path, task.id)
            results.append((file_path.name, success, url_or_error, object_key))

            # Update progress
            task.progress.update(i, total, f"已上传 {i}/{total} 个文件")
            self._notify_status_change(task)

        self.logger.info(f"文件上传完成: {len([r for r in results if r[1]])} 个成功")
        return results

    async def _transcribe_audio(self, task: Task, file_urls: List[str]) -> tuple:
        """Transcribe the uploaded files with the Paraformer service.

        Returns:
            ``(success, results, error)`` as returned by ``batch_process_with_retry``.
        """
        self.logger.info(f"开始转录 {len(file_urls)} 个音频文件")

        # Optional per-task Paraformer parameters supplied via metadata.
        paraformer_params = task.metadata.get('paraformer_params')
        if 'paraformer_params' in task.metadata:
            self.logger.info(f"使用自定义Paraformer参数: {paraformer_params}")

        success, results, error = await self.paraformer_service.batch_process_with_retry(
            file_urls, task.id, paraformer_params
        )

        if success:
            task.progress.update(100, 100, "转录完成")
            self.logger.info(f"转录完成: {len(file_urls)} 个文件")
        else:
            self.logger.error(f"转录失败: {error}")

        return success, results, error

    def cleanup_completed_tasks(self, hours: int = 24) -> int:
        """Remove finished tasks whose completion time is older than ``hours``.

        Args:
            hours: retention period in hours.

        Returns:
            The number of tasks removed.
        """
        cutoff_time = datetime.now() - timedelta(hours=hours)
        expired = [
            task_id for task_id, task in self.tasks.items()
            if task.status in self._TERMINAL_STATUSES
            and task.completed_at is not None
            and task.completed_at < cutoff_time
        ]

        for task_id in expired:
            del self.tasks[task_id]

        self.logger.info(f"清理了 {len(expired)} 个过期任务")
        return len(expired)

    def get_statistics(self) -> Dict:
        """Return task counts per status plus the total and the queue length.

        Keys: ``total_tasks``, one key per TaskStatus value, and ``queue_size``.
        """
        # Single pass over the tasks instead of one scan per status.
        counts = dict.fromkeys(TaskStatus, 0)
        for task in self.tasks.values():
            counts[task.status] += 1

        stats = {'total_tasks': len(self.tasks)}
        stats.update({status.value: counts[status] for status in TaskStatus})
        stats['queue_size'] = self.task_queue.qsize()
        return stats
# Module-wide TaskManager, created lazily by get_task_manager().
task_manager = None


def get_task_manager() -> "TaskManager":
    """Return the shared TaskManager, constructing it on first use.

    Returns:
        The module-level TaskManager instance.
    """
    global task_manager
    if task_manager is not None:
        return task_manager
    task_manager = TaskManager()
    return task_manager