"""
Training Queue Service - Manages asynchronous training jobs
"""
import asyncio
import inspect
import itertools
import json
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set

logger = logging.getLogger(__name__)

# Status values after which a job will never run again.
_TERMINAL_STATES = ("completed", "failed", "cancelled")
# Keys that may hold the ISO timestamp at which a job reached a terminal state.
_FINISH_KEYS = ("completed_at", "failed_at", "cancelled_at")


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Produces the same values (and the same ``isoformat()`` strings) as the
    deprecated ``datetime.utcnow()``.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class JobPriority(int, Enum):
    """Job priority levels; a larger value is dispatched earlier."""

    LOW = 1
    NORMAL = 5
    HIGH = 10
    URGENT = 20


@dataclass(order=True)
class QueuedJob:
    """Represents a job in the queue.

    Only ``priority`` takes part in comparisons. ``TrainingQueue`` does not
    rely on this ordering for dispatch; it wraps each job in its own sort key.
    """

    priority: int
    job_id: str = field(compare=False)
    created_at: datetime = field(compare=False, default_factory=_utcnow)
    config: Dict = field(compare=False, default_factory=dict)
    callback: Optional[Callable] = field(compare=False, default=None)


class TrainingQueue:
    """Async training queue with priority support.

    Jobs with a higher ``priority`` value are started first; jobs with equal
    priority are started in submission order. Up to ``max_concurrent`` jobs
    run at the same time, one per worker task.
    """

    # Seconds between two passes of the status cleanup loop.
    CLEANUP_INTERVAL_SECONDS = 300
    # How long the status of a finished job is kept before being dropped.
    STATUS_RETENTION = timedelta(hours=1)
    # Maximum number of final statuses kept in the in-memory history.
    HISTORY_LIMIT = 100
    # Seconds an idle worker waits for a job before re-checking ``_running``.
    POLL_TIMEOUT_SECONDS = 1.0

    def __init__(self, max_concurrent: int = 3):
        """Create a stopped queue.

        Args:
            max_concurrent: Number of worker tasks started by ``start()``.
        """
        self.max_concurrent = max_concurrent
        # Entries are ``(-priority, sequence, job)`` tuples, see ``submit``.
        self._queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
        self._sequence = itertools.count()
        self._active_jobs: Dict[str, asyncio.Task] = {}
        self._job_status: Dict[str, Dict] = {}
        # IDs cancelled while still queued. Kept separately from
        # ``_job_status`` so that status cleanup cannot "un-cancel" a job
        # that is still waiting in the queue.
        self._cancelled_queued: Set[str] = set()
        self._workers: list = []
        self._cleanup_task: Optional[asyncio.Task] = None
        self._running = False
        self._lock = asyncio.Lock()
        self._job_history: list = []
        self._callbacks: Dict[str, Callable] = {}

    async def start(self):
        """Start the queue workers and the cleanup task (no-op if running)."""
        if self._running:
            return

        self._running = True
        logger.info(f"Starting training queue with {self.max_concurrent} workers")

        # Start worker tasks
        for i in range(self.max_concurrent):
            self._workers.append(asyncio.create_task(self._worker(i)))

        # Keep a reference: the event loop only holds weak references to
        # tasks, and ``stop()`` needs the handle to cancel the loop.
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop(self):
        """Stop all workers gracefully.

        Cancels every running job and the cleanup task, then waits for the
        workers to exit. Jobs still waiting in the queue are left untouched.
        """
        logger.info("Stopping training queue...")
        self._running = False

        # Snapshot under the lock so a worker that is registering a task
        # right now is either fully visible or not visible at all.
        async with self._lock:
            active = list(self._active_jobs.items())
        for job_id, task in active:
            task.cancel()
            logger.info(f"Cancelled job {job_id}")

        pending: List[asyncio.Task] = list(self._workers)
        cleanup_task, self._cleanup_task = self._cleanup_task, None
        if cleanup_task is not None:
            cleanup_task.cancel()
            pending.append(cleanup_task)

        # Wait for workers (and the cancelled cleanup task) to finish
        await asyncio.gather(*pending, return_exceptions=True)
        self._workers.clear()
        logger.info("Training queue stopped")

    async def submit(
        self,
        config: Dict,
        priority: int = JobPriority.NORMAL,
        callback: Optional[Callable] = None
    ) -> str:
        """Submit a new training job to the queue.

        Uses job_id from config if provided (for DB consistency),
        otherwise generates a new one.

        Args:
            config: Job configuration handed to the training service.
            priority: Larger values are dispatched earlier.
            callback: Optional callable invoked as ``callback(job_id, result)``
                after successful completion; may be sync or async.

        Returns:
            The ID of the queued job.
        """
        job_id = config.get("job_id") or str(uuid.uuid4())

        job = QueuedJob(
            priority=priority,
            job_id=job_id,
            config=config,
            callback=callback
        )

        async with self._lock:
            self._job_status[job_id] = {
                "status": "queued",
                "position": self._queue.qsize(),
                "created_at": _utcnow().isoformat(),
                "config": config
            }
            self._callbacks[job_id] = callback

        # asyncio.PriorityQueue pops the *smallest* entry first, so the
        # priority is negated to run high-priority jobs first. The sequence
        # number keeps equal priorities FIFO and guarantees the job object
        # itself is never compared.
        await self._queue.put((-int(priority), next(self._sequence), job))
        logger.info(f"Job {job_id} submitted with priority {priority}")

        return job_id

    async def get_status(self, job_id: str) -> Optional[Dict]:
        """Get the status of a job, or ``None`` if the ID is unknown."""
        async with self._lock:
            return self._job_status.get(job_id)

    async def cancel_job(self, job_id: str) -> bool:
        """Cancel a queued or running job.

        Returns:
            ``True`` if the job was cancelled by this call or had already
            been cancelled; ``False`` if the job is unknown or has already
            completed or failed (its status is left untouched).
        """
        async with self._lock:
            task = self._active_jobs.get(job_id)
            # ``Task.cancel()`` returns False when the task has already
            # finished; in that case the final status must not be overwritten.
            if task is not None and task.cancel():
                self._job_status[job_id]["status"] = "cancelled"
                return True

            status = self._job_status.get(job_id)
            if status is None:
                return False

            state = status.get("status")
            if state == "cancelled":
                return True
            if state != "queued":
                return False

            status["status"] = "cancelled"
            # Timestamp it so the cleanup loop can eventually drop the record.
            status["cancelled_at"] = _utcnow().isoformat()
            self._cancelled_queued.add(job_id)
            return True

    async def get_queue_size(self) -> int:
        """Get current queue size."""
        return self._queue.qsize()

    async def get_active_jobs(self) -> list:
        """Get list of active job IDs."""
        async with self._lock:
            return list(self._active_jobs.keys())

    async def _worker(self, worker_id: int):
        """Worker coroutine: pull jobs off the queue until the queue stops."""
        logger.info(f"Worker {worker_id} started")

        while self._running:
            try:
                # Poll with a timeout so the loop notices ``_running``
                # being cleared even when the queue is empty.
                try:
                    _, _, job = await asyncio.wait_for(
                        self._queue.get(), timeout=self.POLL_TIMEOUT_SECONDS
                    )
                except asyncio.TimeoutError:
                    continue

                await self._process_job(worker_id, job)

            except Exception as e:
                logger.error(f"Worker {worker_id} error: {e}", exc_info=True)
                await asyncio.sleep(1)

        logger.info(f"Worker {worker_id} stopped")

    async def _process_job(self, worker_id: int, job: QueuedJob) -> None:
        """Run one dequeued job and record its final status and history."""
        # Check for cancellation, mark as running and register the task in a
        # single critical section so a concurrent ``cancel_job`` is never lost.
        async with self._lock:
            current = self._job_status.get(job.job_id)
            was_cancelled = job.job_id in self._cancelled_queued or (
                current is not None and current.get("status") == "cancelled"
            )
            if was_cancelled:
                self._cancelled_queued.discard(job.job_id)
                logger.info(f"Worker {worker_id}: Job {job.job_id} was cancelled, skipping")
                return

            self._job_status[job.job_id] = {
                "status": "running",
                "started_at": _utcnow().isoformat(),
                "config": job.config
            }
            task = asyncio.create_task(self._run_job(job))
            self._active_jobs[job.job_id] = task

        logger.info(f"Worker {worker_id}: Starting job {job.job_id}")

        try:
            result = await task

            async with self._lock:
                self._job_status[job.job_id] = {
                    "status": "completed",
                    "completed_at": _utcnow().isoformat(),
                    "result": result
                }

            logger.info(f"Worker {worker_id}: Job {job.job_id} completed")

            # Execute callback
            if job.callback:
                try:
                    outcome = job.callback(job.job_id, result)
                    # Accept plain functions as well as coroutine functions.
                    if inspect.isawaitable(outcome):
                        await outcome
                except Exception as cb_err:
                    logger.error(f"Callback error: {cb_err}")

        except asyncio.CancelledError:
            async with self._lock:
                self._job_status[job.job_id] = {
                    "status": "cancelled",
                    "cancelled_at": _utcnow().isoformat()
                }
            logger.info(f"Worker {worker_id}: Job {job.job_id} cancelled")

        except Exception as e:
            async with self._lock:
                self._job_status[job.job_id] = {
                    "status": "failed",
                    "failed_at": _utcnow().isoformat(),
                    "error": str(e)
                }
            logger.error(f"Worker {worker_id}: Job {job.job_id} failed: {e}")

        finally:
            async with self._lock:
                self._active_jobs.pop(job.job_id, None)

                # Store in history
                if job.job_id in self._job_status:
                    self._job_history.append(self._job_status[job.job_id])

                # Keep only the most recent entries in history
                if len(self._job_history) > self.HISTORY_LIMIT:
                    del self._job_history[:-self.HISTORY_LIMIT]

    async def _run_job(self, job: QueuedJob) -> Dict:
        """Execute the actual training job via ``TrainingService.train``."""
        # Imported at call time, as in the original; the reason is not
        # visible from this file.
        from app.services.training_service import TrainingService

        training_service = TrainingService()

        # Pass the job_id from the queued job (which matches DB record)
        result = await training_service.train(job.job_id, job.config)

        return result

    @staticmethod
    def _finished_at(status: Dict) -> Optional[datetime]:
        """Return when a terminal status was reached, or ``None`` if unknown."""
        if status.get("status") not in _TERMINAL_STATES:
            return None
        stamp = next((status[key] for key in _FINISH_KEYS if status.get(key)), None)
        if not stamp:
            return None
        try:
            return datetime.fromisoformat(stamp)
        except (TypeError, ValueError):
            # Unparseable timestamp: keep the record rather than guess its age.
            return None

    async def _cleanup_loop(self):
        """Periodically drop statuses of jobs that finished long ago."""
        while self._running:
            await asyncio.sleep(self.CLEANUP_INTERVAL_SECONDS)

            # Compare naive UTC datetimes directly; calling ``.timestamp()``
            # on them would reinterpret the values as local time.
            cutoff = _utcnow() - self.STATUS_RETENTION

            async with self._lock:
                to_remove = []
                for job_id, status in self._job_status.items():
                    finished = self._finished_at(status)
                    if finished is not None and finished < cutoff:
                        to_remove.append(job_id)

                for job_id in to_remove:
                    self._job_status.pop(job_id, None)
                    self._callbacks.pop(job_id, None)

                if to_remove:
                    logger.info(f"Cleaned up {len(to_remove)} old job statuses")