"""
Training Queue Service - Manages asynchronous training jobs.
"""
|
|
import asyncio
import inspect
import json
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Dict, Optional
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
class JobPriority(int, Enum):
    """Job priority levels.

    The class mixes in ``int``, so every member is also a plain integer and
    can be compared with, or passed in place of, an ``int`` priority.
    Values grow with urgency: ``LOW < NORMAL < HIGH < URGENT``.
    """

    LOW = 1
    NORMAL = 5
    HIGH = 10
    URGENT = 20
|
|
|
|
@dataclass(order=True)
class QueuedJob:
    """Represents a job in the queue.

    Only ``priority`` takes part in comparisons (every other field is declared
    with ``compare=False``), so two jobs with the same priority compare equal
    and ``<`` / ``>`` order jobs by ascending priority value.
    """

    # Sole comparison key for ordering and equality.
    priority: int
    # Identifier of the job; excluded from comparisons.
    job_id: str = field(compare=False)
    # Creation time as a naive UTC datetime (``datetime.utcnow``).
    created_at: datetime = field(compare=False, default_factory=datetime.utcnow)
    # Job configuration; each instance gets its own dict by default.
    config: Dict[str, Any] = field(compare=False, default_factory=dict)
    # Optional callable attached to the job.
    callback: Optional[Callable[..., Any]] = field(compare=False, default=None)
|
|
|
|
class TrainingQueue:
    """Async training queue with priority support.

    Jobs are handed to a fixed pool of worker tasks. A job with a larger
    ``priority`` value is started before one with a smaller value, and jobs
    of equal priority are started in submission order.

    Per-job state is kept in memory as plain dicts whose ``"status"`` key is
    one of ``"queued"``, ``"running"``, ``"completed"``, ``"failed"`` or
    ``"cancelled"``. Timestamps in those dicts are naive-UTC ISO-8601
    strings produced by ``datetime.utcnow().isoformat()``.
    """

    # Statuses that are eligible for removal by the cleanup loop.
    _TERMINAL_STATUSES = ("completed", "failed", "cancelled")

    def __init__(
        self,
        max_concurrent: int = 3,
        *,
        cleanup_interval: float = 300.0,
        status_ttl: float = 3600.0,
        max_history: int = 100,
    ):
        """Create an idle queue; call :meth:`start` to launch the workers.

        Args:
            max_concurrent: Number of worker tasks, i.e. the maximum number of
                jobs running at the same time.
            cleanup_interval: Seconds between two cleanup passes.
            status_ttl: Seconds a finished job's status is kept before a
                cleanup pass may drop it.
            max_history: Maximum number of entries kept in the job history.
        """
        self.max_concurrent = max_concurrent
        self._cleanup_interval = cleanup_interval
        self._status_ttl = status_ttl
        self._max_history = max_history
        # Entries are ``(-priority, sequence, QueuedJob)`` tuples.
        # asyncio.PriorityQueue pops the *smallest* entry first, so the
        # priority is negated to serve the highest priority first, and the
        # unique sequence number keeps equal priorities in FIFO order.
        self._queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
        self._sequence = 0
        self._active_jobs: Dict[str, asyncio.Task] = {}
        self._job_status: Dict[str, Dict] = {}
        self._workers: list = []
        self._cleanup_task: Optional[asyncio.Task] = None
        self._running = False
        self._lock = asyncio.Lock()
        self._job_history: list = []
        self._callbacks: Dict[str, Callable] = {}

    @staticmethod
    def _now_iso() -> str:
        """Return the current naive-UTC time as an ISO-8601 string."""
        return datetime.utcnow().isoformat()

    async def start(self):
        """Start the queue workers and the periodic cleanup task.

        Calling this while the queue is already running is a no-op.
        """
        if self._running:
            return

        self._running = True
        logger.info(f"Starting training queue with {self.max_concurrent} workers")

        for i in range(self.max_concurrent):
            self._workers.append(asyncio.create_task(self._worker(i)))

        # Keep a reference: the event loop only holds weak references to
        # tasks, and stop() needs the handle to cancel the loop.
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop(self):
        """Stop all workers gracefully.

        Running jobs are cancelled, the workers are awaited until they exit,
        and the cleanup task is cancelled. Jobs still waiting in the queue
        are left there.
        """
        logger.info("Stopping training queue...")
        self._running = False

        for job_id, task in list(self._active_jobs.items()):
            task.cancel()
            logger.info(f"Cancelled job {job_id}")

        await asyncio.gather(*self._workers, return_exceptions=True)
        self._workers.clear()

        cleanup_task, self._cleanup_task = self._cleanup_task, None
        if cleanup_task is not None:
            cleanup_task.cancel()
            # return_exceptions=True absorbs the resulting CancelledError.
            await asyncio.gather(cleanup_task, return_exceptions=True)

        logger.info("Training queue stopped")

    async def submit(
        self,
        config: Dict,
        priority: int = JobPriority.NORMAL,
        callback: Optional[Callable] = None
    ) -> str:
        """Submit a new training job to the queue.

        Uses job_id from config if provided (for DB consistency),
        otherwise generates a new one.

        Args:
            config: Job configuration, passed through to the training service.
            priority: Larger values are served first.
            callback: Optional callable invoked as ``callback(job_id, result)``
                after the job completes successfully; it may be a plain
                function or return an awaitable.

        Returns:
            The id of the queued job.
        """
        job_id = config.get("job_id") or str(uuid.uuid4())

        job = QueuedJob(
            priority=priority,
            job_id=job_id,
            config=config,
            callback=callback
        )

        async with self._lock:
            self._job_status[job_id] = {
                "status": "queued",
                "position": self._queue.qsize(),
                "created_at": self._now_iso(),
                "config": config
            }
            self._callbacks[job_id] = callback
            self._sequence += 1
            entry = (-priority, self._sequence, job)

        await self._queue.put(entry)
        logger.info(f"Job {job_id} submitted with priority {priority}")

        return job_id

    async def get_status(self, job_id: str) -> Optional[Dict]:
        """Get the status of a job.

        Returns the stored status dict itself (not a copy), or ``None`` when
        the job id is unknown.
        """
        async with self._lock:
            return self._job_status.get(job_id)

    async def cancel_job(self, job_id: str) -> bool:
        """Cancel a queued or running job.

        Returns:
            ``True`` if the job was cancelled by this call or had already been
            cancelled; ``False`` if the job is unknown or has already
            completed or failed.
        """
        async with self._lock:
            status = self._job_status.get(job_id)
            task = self._active_jobs.get(job_id)

            if task is not None:
                # Task.cancel() returns False when the task already finished,
                # in which case there is nothing left to cancel.
                if not task.cancel():
                    return False
                if status is not None:
                    status["status"] = "cancelled"
                return True

            if status is None:
                return False

            current = status.get("status")
            if current == "cancelled":
                return True
            if current != "queued":
                # Never rewrite the outcome of a completed or failed job.
                return False

            status["status"] = "cancelled"
            # Timestamp lets the cleanup loop expire jobs that never ran.
            status["cancelled_at"] = self._now_iso()
            return True

    async def get_queue_size(self) -> int:
        """Get current queue size."""
        return self._queue.qsize()

    async def get_active_jobs(self) -> list:
        """Get list of active job IDs."""
        async with self._lock:
            return list(self._active_jobs.keys())

    async def _worker(self, worker_id: int):
        """Worker coroutine for processing jobs.

        Repeatedly takes the next entry from the queue, runs it through
        :meth:`_run_job`, records the outcome and finally invokes the job's
        callback. The loop ends once ``self._running`` is false.
        """
        logger.info(f"Worker {worker_id} started")

        while self._running:
            try:
                # Poll with a timeout so the loop notices a stop() request.
                try:
                    entry = await asyncio.wait_for(self._queue.get(), timeout=1.0)
                except asyncio.TimeoutError:
                    continue

                job = entry[-1]

                # The cancelled check, the "running" mark and the task
                # registration share one lock acquisition so that cancel_job()
                # can never slip in between them and be silently lost.
                async with self._lock:
                    if not self._running:
                        # stop() ran after this entry was dequeued: put it
                        # back instead of starting work stop() cannot cancel.
                        self._queue.put_nowait(entry)
                        break

                    current = self._job_status.get(job.job_id)
                    if current and current.get("status") == "cancelled":
                        task = None
                    else:
                        self._job_status[job.job_id] = {
                            "status": "running",
                            "started_at": self._now_iso(),
                            "config": job.config
                        }
                        task = asyncio.create_task(self._run_job(job))
                        self._active_jobs[job.job_id] = task

                if task is None:
                    logger.info(f"Worker {worker_id}: Job {job.job_id} was cancelled, skipping")
                    continue

                logger.info(f"Worker {worker_id}: Starting job {job.job_id}")

                outcome: Optional[Dict] = None
                succeeded = False
                result = None
                try:
                    result = await task
                    outcome = {
                        "status": "completed",
                        "completed_at": self._now_iso(),
                        "result": result
                    }
                    succeeded = True
                    logger.info(f"Worker {worker_id}: Job {job.job_id} completed")

                except asyncio.CancelledError:
                    outcome = {
                        "status": "cancelled",
                        "cancelled_at": self._now_iso()
                    }
                    logger.info(f"Worker {worker_id}: Job {job.job_id} cancelled")

                except Exception as e:
                    outcome = {
                        "status": "failed",
                        "failed_at": self._now_iso(),
                        "error": str(e)
                    }
                    logger.error(f"Worker {worker_id}: Job {job.job_id} failed: {e}")

                finally:
                    await self._finish_job(job.job_id, outcome)

                # Run the callback only after the job is no longer registered
                # as active, so a late cancel_job() cannot relabel it.
                if succeeded and job.callback:
                    await self._invoke_callback(job, result)

            except Exception as e:
                logger.error(f"Worker {worker_id} error: {e}", exc_info=True)
                await asyncio.sleep(1)

        logger.info(f"Worker {worker_id} stopped")

    async def _finish_job(self, job_id: str, outcome: Optional[Dict]) -> None:
        """Unregister a job, store its final status and append it to history."""
        async with self._lock:
            self._active_jobs.pop(job_id, None)

            if outcome is not None:
                self._job_status[job_id] = outcome

            final = self._job_status.get(job_id)
            if final is not None:
                self._job_history.append(final)

            overflow = len(self._job_history) - self._max_history
            if overflow > 0:
                del self._job_history[:overflow]

    async def _invoke_callback(self, job: QueuedJob, result: Any) -> None:
        """Call the job's callback, awaiting its return value when awaitable.

        Exceptions raised by the callback are logged and swallowed so that a
        faulty callback cannot take the worker down.
        """
        try:
            returned = job.callback(job.job_id, result)
            if inspect.isawaitable(returned):
                await returned
        except Exception as cb_err:
            logger.error(f"Callback error: {cb_err}")

    async def _run_job(self, job: QueuedJob) -> Dict:
        """Execute the actual training job.

        Delegates to ``TrainingService.train(job_id, config)`` and returns
        whatever that call returns.
        """
        # Imported here rather than at module level, as in the original code.
        from app.services.training_service import TrainingService

        training_service = TrainingService()

        result = await training_service.train(job.job_id, job.config)

        return result

    async def _cleanup_loop(self):
        """Periodic cleanup of old job statuses.

        Every ``cleanup_interval`` seconds, drops the status (and stored
        callback) of each finished job that is older than ``status_ttl``.
        """
        while self._running:
            await asyncio.sleep(self._cleanup_interval)

            async with self._lock:
                expired = self._expired_job_ids()

                for job_id in expired:
                    self._job_status.pop(job_id, None)
                    self._callbacks.pop(job_id, None)

                if expired:
                    logger.info(f"Cleaned up {len(expired)} old job statuses")

    def _expired_job_ids(self) -> list:
        """Return ids of finished jobs older than ``status_ttl`` seconds.

        The caller must hold ``self._lock``.
        """
        now = datetime.utcnow()
        expired = []

        for job_id, status in self._job_status.items():
            if status.get("status") not in self._TERMINAL_STATUSES:
                continue

            finished_str = (
                status.get("completed_at")
                or status.get("failed_at")
                or status.get("cancelled_at")
            )
            if not finished_str:
                continue

            try:
                # Both values are naive UTC, so the difference is exact and
                # does not depend on the local time zone.
                age = (now - datetime.fromisoformat(finished_str)).total_seconds()
            except (TypeError, ValueError):
                # Unparseable or incompatible timestamp: keep the entry.
                continue

            if age > self._status_ttl:
                expired.append(job_id)

        return expired