# universal-model-trainer / app/services/queue_service.py
# Commit 297c5d3 (verified, vectorplasticity):
#   CRITICAL FIX: Use job_id from config to match DB record - fixes job tracking issue
"""
Training Queue Service - Manages asynchronous training jobs
"""
import asyncio
import inspect
import json
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Callable, Dict, Optional
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class JobPriority(int, Enum):
    """Job priority levels.

    Members are ints (``int`` mixin), so they can be passed anywhere a plain
    integer priority is accepted. By name, the values increase with urgency:
    LOW (1) < NORMAL (5) < HIGH (10) < URGENT (20).
    """
    LOW = 1
    NORMAL = 5
    HIGH = 10
    URGENT = 20
@dataclass(order=True)
class QueuedJob:
    """A single unit of work waiting in the training queue.

    ``order=True`` generates rich comparisons, but every field except
    ``priority`` is declared with ``compare=False``. Two jobs with the same
    priority therefore compare equal regardless of id, timestamp, config or
    callback.
    """

    # The only field that participates in ==, <, <=, >, >=.
    priority: int
    # Identifier of the job; required (no default).
    job_id: str = field(compare=False)
    # Set at construction time from datetime.utcnow() (naive UTC).
    created_at: datetime = field(default_factory=datetime.utcnow, compare=False)
    # Job configuration; each instance gets its own fresh dict by default.
    config: Dict = field(default_factory=dict, compare=False)
    # Optional callable attached to the job; None when not supplied.
    callback: Optional[Callable] = field(default=None, compare=False)
class TrainingQueue:
    """Async training queue with priority support.

    Jobs with a larger ``priority`` value (e.g. ``JobPriority.URGENT``) are
    dispatched before jobs with a smaller one (e.g. ``JobPriority.LOW``).
    Jobs of equal priority run in submission order. Up to ``max_concurrent``
    worker coroutines pull from the queue. Each job's lifecycle ("queued" ->
    "running" -> "completed" / "failed" / "cancelled") is tracked in an
    in-memory status table guarded by ``self._lock``.
    """

    # Number of finished-job status snapshots kept in ``_job_history``.
    _HISTORY_LIMIT = 100
    # Finished-job statuses older than this are purged by the cleanup loop.
    _STATUS_RETENTION = timedelta(hours=1)
    # Seconds between cleanup sweeps.
    _CLEANUP_INTERVAL = 300

    def __init__(self, max_concurrent: int = 3):
        self.max_concurrent = max_concurrent
        # asyncio.PriorityQueue pops the *smallest* entry first. Entries are
        # therefore stored as (-priority, sequence, job): higher priorities
        # come out first. The increasing sequence number keeps FIFO order among
        # equal priorities, so QueuedJob objects are never compared.
        self._queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
        self._submit_seq = 0
        self._active_jobs: Dict[str, asyncio.Task] = {}
        self._job_status: Dict[str, Dict] = {}
        self._workers: list = []
        self._cleanup_task: Optional[asyncio.Task] = None
        self._running = False
        self._lock = asyncio.Lock()
        self._job_history: list = []
        self._callbacks: Dict[str, Callable] = {}

    async def start(self):
        """Start the worker coroutines and the periodic status-cleanup task.

        Calling this while the queue is already running is a no-op.
        """
        if self._running:
            return
        self._running = True
        logger.info(f"Starting training queue with {self.max_concurrent} workers")
        for i in range(self.max_concurrent):
            self._workers.append(asyncio.create_task(self._worker(i)))
        # Keep a reference to the task. The event loop only holds weak
        # references to tasks, and stop() needs the handle to cancel the
        # long sleep.
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop(self):
        """Stop all workers gracefully.

        Cancels every running job, waits for the worker coroutines to exit,
        then cancels the cleanup task. Jobs still waiting in the queue are
        left there.
        """
        logger.info("Stopping training queue...")
        self._running = False
        # Iterate over a snapshot, because workers remove entries from
        # _active_jobs as their jobs finish.
        for job_id, task in list(self._active_jobs.items()):
            task.cancel()
            logger.info(f"Cancelled job {job_id}")
        # Each worker re-checks self._running at least once per second
        # (the queue.get timeout).
        await asyncio.gather(*self._workers, return_exceptions=True)
        self._workers.clear()
        if self._cleanup_task is not None:
            self._cleanup_task.cancel()
            await asyncio.gather(self._cleanup_task, return_exceptions=True)
            self._cleanup_task = None
        logger.info("Training queue stopped")

    async def submit(
        self,
        config: Dict,
        priority: int = JobPriority.NORMAL,
        callback: Optional[Callable] = None
    ) -> str:
        """Submit a new training job to the queue.

        Uses job_id from config if provided (for DB consistency),
        otherwise generates a new one.

        Args:
            config: Training configuration, handed unchanged to
                ``TrainingService.train``. An optional ``"job_id"`` key
                selects the id used to track the job.
            priority: Larger values are dispatched sooner (see ``JobPriority``).
            callback: Optional ``callback(job_id, result)`` called after the
                job completes successfully. It may be a plain function or a
                coroutine function.

        Returns:
            The job id under which the job's status can be queried.
        """
        # Reuse the caller's job_id (e.g. a DB record id) so lookups match.
        job_id = config.get("job_id") or str(uuid.uuid4())
        job = QueuedJob(
            priority=priority,
            job_id=job_id,
            config=config,
            callback=callback
        )
        async with self._lock:
            self._job_status[job_id] = {
                "status": "queued",
                "position": self._queue.qsize(),
                "created_at": datetime.utcnow().isoformat(),
                "config": config
            }
            self._callbacks[job_id] = callback
            self._submit_seq += 1
            entry = (-priority, self._submit_seq, job)
        await self._queue.put(entry)
        logger.info(f"Job {job_id} submitted with priority {priority}")
        return job_id

    async def get_status(self, job_id: str) -> Optional[Dict]:
        """Return the live status dict for ``job_id``, or None if unknown."""
        async with self._lock:
            return self._job_status.get(job_id)

    async def cancel_job(self, job_id: str) -> bool:
        """Cancel a queued or running job.

        Returns:
            True if the job was queued or running and is now cancelled.
            False if the id is unknown or the job already reached a final
            state; in that case its final status is left untouched.
        """
        async with self._lock:
            task = self._active_jobs.get(job_id)
            if task is not None:
                # cancel() returns False when the task has already finished,
                # e.g. while the worker is still running the completion
                # callback. Don't overwrite "completed"/"failed" in that case.
                if not task.cancel():
                    return False
                status = self._job_status.get(job_id)
                if status is not None:
                    status["status"] = "cancelled"
                return True
            status = self._job_status.get(job_id)
            if status is not None and status.get("status") == "queued":
                # The entry stays in the PriorityQueue. The worker that
                # dequeues it sees this flag and skips it. The timestamp lets
                # the cleanup loop purge the status entry later.
                status["status"] = "cancelled"
                status["cancelled_at"] = datetime.utcnow().isoformat()
                return True
            return False

    async def get_queue_size(self) -> int:
        """Get current queue size (including entries cancelled while queued)."""
        return self._queue.qsize()

    async def get_active_jobs(self) -> list:
        """Get list of active job IDs."""
        async with self._lock:
            return list(self._active_jobs.keys())

    async def _worker(self, worker_id: int):
        """Pull jobs from the queue and process them until the queue stops."""
        logger.info(f"Worker {worker_id} started")
        while self._running:
            try:
                try:
                    # Time out periodically so the loop re-checks self._running.
                    _, _, job = await asyncio.wait_for(self._queue.get(), timeout=1.0)
                except asyncio.TimeoutError:
                    continue
                try:
                    await self._process_job(worker_id, job)
                finally:
                    # Balance every successful get() so Queue.join() works.
                    self._queue.task_done()
            except Exception as e:
                logger.error(f"Worker {worker_id} error: {e}", exc_info=True)
                await asyncio.sleep(1)
        logger.info(f"Worker {worker_id} stopped")

    async def _process_job(self, worker_id: int, job: QueuedJob) -> None:
        """Run one dequeued job and record its outcome.

        Skips the job if it was cancelled while queued. Otherwise it marks the
        job running, awaits it, stores the final status, invokes the
        completion callback on success, and snapshots the final status into
        the history.
        """
        async with self._lock:
            # Check and claim the job under one lock acquisition. cancel_job()
            # then either sees "queued" (and the job is skipped here) or finds
            # the task in _active_jobs (and cancels it); nothing falls between.
            status = self._job_status.get(job.job_id)
            if status and status.get("status") == "cancelled":
                logger.info(f"Worker {worker_id}: Job {job.job_id} was cancelled, skipping")
                return
            self._job_status[job.job_id] = {
                "status": "running",
                "started_at": datetime.utcnow().isoformat(),
                "config": job.config
            }
            logger.info(f"Worker {worker_id}: Starting job {job.job_id}")
            task = asyncio.create_task(self._run_job(job))
            self._active_jobs[job.job_id] = task
        try:
            result = await task
        except asyncio.CancelledError:
            async with self._lock:
                self._job_status[job.job_id] = {
                    "status": "cancelled",
                    "cancelled_at": datetime.utcnow().isoformat()
                }
            logger.info(f"Worker {worker_id}: Job {job.job_id} cancelled")
        except Exception as e:
            async with self._lock:
                self._job_status[job.job_id] = {
                    "status": "failed",
                    "failed_at": datetime.utcnow().isoformat(),
                    "error": str(e)
                }
            logger.error(f"Worker {worker_id}: Job {job.job_id} failed: {e}")
        else:
            async with self._lock:
                self._job_status[job.job_id] = {
                    "status": "completed",
                    "completed_at": datetime.utcnow().isoformat(),
                    "result": result
                }
            logger.info(f"Worker {worker_id}: Job {job.job_id} completed")
            await self._invoke_callback(job, result)
        finally:
            async with self._lock:
                self._active_jobs.pop(job.job_id, None)
                self._record_history(job.job_id)

    async def _invoke_callback(self, job: QueuedJob, result: Any) -> None:
        """Call ``job.callback(job_id, result)``; log (never raise) on failure.

        Works with both coroutine functions and plain callables. The return
        value is awaited only when it is awaitable.
        """
        if not job.callback:
            return
        try:
            outcome = job.callback(job.job_id, result)
            if inspect.isawaitable(outcome):
                await outcome
        except Exception as cb_err:
            logger.error(f"Callback error: {cb_err}")

    def _record_history(self, job_id: str) -> None:
        """Append a snapshot of the job's current status to the bounded history.

        Must be called with ``self._lock`` held.
        """
        status = self._job_status.get(job_id)
        if status is None:
            return
        # Copy the dict so later changes to the live status can't rewrite
        # history. Tag it with the id, since status dicts don't carry one.
        self._job_history.append({"job_id": job_id, **status})
        if len(self._job_history) > self._HISTORY_LIMIT:
            del self._job_history[:-self._HISTORY_LIMIT]

    async def _run_job(self, job: QueuedJob) -> Dict:
        """Execute the actual training job via TrainingService."""
        from app.services.training_service import TrainingService
        training_service = TrainingService()
        # Pass the job_id from the queued job (which matches the DB record).
        result = await training_service.train(job.job_id, job.config)
        return result

    async def _cleanup_loop(self):
        """Periodically purge statuses of long-finished jobs.

        Every ``_CLEANUP_INTERVAL`` seconds, this drops the status and stored
        callback of each completed/failed/cancelled job that finished more
        than ``_STATUS_RETENTION`` ago.
        """
        while self._running:
            await asyncio.sleep(self._CLEANUP_INTERVAL)
            async with self._lock:
                # All timestamps here come from datetime.utcnow() (naive UTC),
                # so compare datetimes directly rather than via .timestamp(),
                # which would interpret them as local time.
                cutoff = datetime.utcnow() - self._STATUS_RETENTION
                to_remove = []
                for job_id, status in self._job_status.items():
                    if status.get("status") not in ("completed", "failed", "cancelled"):
                        continue
                    finished_str = (
                        status.get("completed_at")
                        or status.get("failed_at")
                        or status.get("cancelled_at")
                    )
                    if not finished_str:
                        continue
                    try:
                        expired = datetime.fromisoformat(finished_str) < cutoff
                    except (TypeError, ValueError):
                        # Unparseable or incomparable timestamp: keep the
                        # entry rather than abort the whole sweep.
                        continue
                    if expired:
                        to_remove.append(job_id)
                for job_id in to_remove:
                    self._job_status.pop(job_id, None)
                    self._callbacks.pop(job_id, None)
                if to_remove:
                    logger.info(f"Cleaned up {len(to_remove)} old job statuses")