"""
Training Queue Service - Manages asynchronous training jobs
"""
import asyncio
import inspect
import itertools
import json
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Dict, Optional
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
class JobPriority(int, Enum):
    """Named priority levels for training jobs.

    Each member is also an ``int`` (via the ``int`` mixin), so members compare
    and sort like plain numbers. By naming, a larger value denotes a more
    urgent job.
    """

    LOW = 1
    NORMAL = 5   # the default priority used by TrainingQueue.submit
    HIGH = 10
    URGENT = 20
# Shared counter giving every QueuedJob a unique, increasing sequence number;
# used as a FIFO tie-breaker between jobs of equal priority.
_JOB_SEQUENCE = itertools.count()


@dataclass(order=True)
class QueuedJob:
    """Represents a job in the queue.

    Instances are compared for use in :class:`asyncio.PriorityQueue`, which
    always pops the *smallest* item. The comparison key is therefore
    ``(-priority, sequence)``:

    * larger ``priority`` values (e.g. ``JobPriority.URGENT``) pop first;
    * jobs with equal priority pop in the order they were constructed.

    ``priority`` keeps the caller's value and is not compared directly. The
    sort key is captured at construction time, so mutating ``priority``
    afterwards does not reorder the job.
    """
    priority: int = field(compare=False)
    job_id: str = field(compare=False)
    created_at: datetime = field(compare=False, default_factory=datetime.utcnow)
    config: Dict = field(compare=False, default_factory=dict)
    callback: Optional[Callable] = field(compare=False, default=None)
    # Primary comparison key: the negated priority (set in __post_init__).
    _sort_priority: int = field(init=False, repr=False)
    # Secondary comparison key: construction order, for FIFO among equals.
    _sequence: int = field(
        init=False, repr=False, default_factory=lambda: next(_JOB_SEQUENCE)
    )

    def __post_init__(self):
        # Negate so that the min-heap behind PriorityQueue yields the highest
        # priority first.
        self._sort_priority = -self.priority
class TrainingQueue:
    """Async training queue with priority support.

    Jobs submitted via :meth:`submit` are placed on an
    :class:`asyncio.PriorityQueue` as :class:`QueuedJob` items and executed by
    ``max_concurrent`` worker tasks started by :meth:`start`. Per-job state is
    kept in an in-memory status dict, written under ``_lock``, which a
    background task prunes periodically.
    """

    # Statuses a job never leaves once it has reached them.
    _TERMINAL_STATES = frozenset({"completed", "failed", "cancelled"})

    def __init__(
        self,
        max_concurrent: int = 3,
        *,
        history_limit: int = 100,
        cleanup_interval: float = 300.0,
        status_ttl: float = 3600.0,
    ):
        """Create an idle queue; call :meth:`start` to begin processing.

        Args:
            max_concurrent: Number of worker tasks, i.e. the maximum number of
                jobs executing at the same time.
            history_limit: Maximum number of finished-job records kept in
                ``_job_history``; the oldest are dropped first.
            cleanup_interval: Seconds between passes of the status cleanup
                task.
            status_ttl: Seconds a finished job's status is retained before
                the cleanup task removes it.
        """
        self.max_concurrent = max_concurrent
        self.history_limit = history_limit
        self.cleanup_interval = cleanup_interval
        self.status_ttl = status_ttl
        self._queue: asyncio.PriorityQueue = asyncio.PriorityQueue()
        self._active_jobs: Dict[str, asyncio.Task] = {}
        self._job_status: Dict[str, Dict] = {}
        self._workers: list = []
        self._cleanup_task: Optional[asyncio.Task] = None
        self._running = False
        self._lock = asyncio.Lock()
        self._job_history: list = []
        self._callbacks: Dict[str, Callable] = {}

    async def start(self):
        """Start the worker tasks and the periodic status-cleanup task.

        Calling this while the queue is already running is a no-op.
        """
        if self._running:
            return
        self._running = True
        logger.info(f"Starting training queue with {self.max_concurrent} workers")
        for i in range(self.max_concurrent):
            self._workers.append(asyncio.create_task(self._worker(i)))
        # Keep a strong reference: the event loop only holds weak references
        # to tasks, so an unreferenced task can vanish mid-run. The reference
        # also lets stop() cancel it instead of waiting out its sleep.
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop(self):
        """Stop the queue: cancel running jobs and wait for workers to exit.

        Idle workers notice the cleared ``_running`` flag within about one
        second (their queue poll timeout). Jobs still waiting in the queue are
        left there; workers started by a later :meth:`start` will pick them up.
        """
        logger.info("Stopping training queue...")
        self._running = False
        # Cancel all active jobs. Nothing in this loop awaits, so workers
        # cannot modify _active_jobs while it is being iterated.
        for job_id, task in self._active_jobs.items():
            task.cancel()
            logger.info(f"Cancelled job {job_id}")
        # The cleanup task sleeps for cleanup_interval between passes, so
        # cancel it rather than waiting for it to observe _running.
        if self._cleanup_task is not None:
            self._cleanup_task.cancel()
            await asyncio.gather(self._cleanup_task, return_exceptions=True)
            self._cleanup_task = None
        # Wait for workers to finish.
        await asyncio.gather(*self._workers, return_exceptions=True)
        self._workers.clear()
        logger.info("Training queue stopped")

    async def submit(
        self,
        config: Dict,
        priority: int = JobPriority.NORMAL,
        callback: Optional[Callable] = None
    ) -> str:
        """Submit a new training job to the queue.

        Args:
            config: Job configuration, passed unchanged to the training
                service. If it contains a truthy ``"job_id"`` that id is
                reused (for DB consistency); otherwise a new UUID4 string is
                generated.
            priority: Scheduling priority stored on the :class:`QueuedJob`;
                the queue's ordering is defined by QueuedJob's comparison.
            callback: Optional ``callback(job_id, result)`` invoked after the
                job completes successfully. Either a coroutine function or a
                plain callable; errors it raises are logged, not propagated.

        Returns:
            The job id.
        """
        job_id = config.get("job_id") or str(uuid.uuid4())
        job = QueuedJob(
            priority=priority,
            job_id=job_id,
            config=config,
            callback=callback
        )
        async with self._lock:
            self._job_status[job_id] = {
                "status": "queued",
                "position": self._queue.qsize(),
                "created_at": datetime.utcnow().isoformat(),
                "config": config
            }
            self._callbacks[job_id] = callback
        await self._queue.put(job)
        logger.info(f"Job {job_id} submitted with priority {priority}")
        return job_id

    async def get_status(self, job_id: str) -> Optional[Dict]:
        """Return the status dict for ``job_id``, or None if unknown/purged."""
        async with self._lock:
            return self._job_status.get(job_id)

    async def cancel_job(self, job_id: str) -> bool:
        """Cancel a queued or running job.

        A running job's task is cancelled. A queued job is marked cancelled
        and skipped by the worker that later dequeues it.

        Returns:
            True if the job was queued or running and is now cancelled;
            False if the id is unknown, the job already reached a terminal
            state, or its task has already finished.
        """
        async with self._lock:
            status = self._job_status.get(job_id)
            if status is None or status.get("status") in self._TERMINAL_STATES:
                return False
            task = self._active_jobs.get(job_id)
            if task is not None:
                if task.done():
                    # Finished; the worker is about to record its outcome.
                    return False
                task.cancel()
            status["status"] = "cancelled"
            # Timestamp lets the cleanup loop eventually purge this record.
            status["cancelled_at"] = datetime.utcnow().isoformat()
            return True

    async def get_queue_size(self) -> int:
        """Get current queue size (includes queued jobs already cancelled)."""
        return self._queue.qsize()

    async def get_active_jobs(self) -> list:
        """Get list of active job IDs."""
        async with self._lock:
            return list(self._active_jobs.keys())

    async def _worker(self, worker_id: int):
        """Worker coroutine: process queued jobs until ``_running`` clears."""
        logger.info(f"Worker {worker_id} started")
        while self._running:
            try:
                try:
                    job = await asyncio.wait_for(self._queue.get(), timeout=1.0)
                except asyncio.TimeoutError:
                    # Poll with a timeout so the loop re-checks _running.
                    continue
                task = await self._claim_job(worker_id, job)
                if task is not None:
                    await self._process_job(worker_id, job, task)
            except Exception as e:
                logger.error(f"Worker {worker_id} error: {e}", exc_info=True)
                await asyncio.sleep(1)
        logger.info(f"Worker {worker_id} stopped")

    async def _claim_job(self, worker_id: int, job: QueuedJob) -> Optional[asyncio.Task]:
        """Mark a dequeued job running and start its task, or skip it.

        Returns the started task, or None when the job must not run.
        """
        async with self._lock:
            status = self._job_status.get(job.job_id)
            # submit() always records a status before enqueueing, and the
            # cleanup loop only purges terminal statuses, so a missing record
            # means this queued job was cancelled and later purged.
            if status is None or status.get("status") == "cancelled":
                logger.info(f"Worker {worker_id}: Job {job.job_id} was cancelled, skipping")
                return None
            self._job_status[job.job_id] = {
                "status": "running",
                "started_at": datetime.utcnow().isoformat(),
                "config": job.config
            }
            # Status update, task creation and registration share one lock
            # hold (create_task does not yield), so cancel_job() never sees a
            # "running" job that is missing from _active_jobs.
            task = asyncio.create_task(self._run_job(job))
            self._active_jobs[job.job_id] = task
        logger.info(f"Worker {worker_id}: Starting job {job.job_id}")
        return task

    async def _process_job(self, worker_id: int, job: QueuedJob, task: asyncio.Task):
        """Await a claimed job's task, record its outcome, then run its callback."""
        job_id = job.job_id
        try:
            result = await task
        except asyncio.CancelledError:
            await self._finish_job(job_id, {
                "status": "cancelled",
                "cancelled_at": datetime.utcnow().isoformat()
            })
            logger.info(f"Worker {worker_id}: Job {job_id} cancelled")
            # cancel_job()/stop() cancel only the job task. If this worker
            # task itself is being cancelled (detectable via
            # Task.cancelling(), Python 3.11+), propagate so the worker can
            # actually stop; older versions lack the API and absorb it.
            cancelling = getattr(asyncio.current_task(), "cancelling", None)
            if cancelling is not None and cancelling():
                raise
            return
        except Exception as e:
            await self._finish_job(job_id, {
                "status": "failed",
                "failed_at": datetime.utcnow().isoformat(),
                "error": str(e)
            })
            logger.error(f"Worker {worker_id}: Job {job_id} failed: {e}", exc_info=True)
            return

        await self._finish_job(job_id, {
            "status": "completed",
            "completed_at": datetime.utcnow().isoformat(),
            "result": result
        })
        logger.info(f"Worker {worker_id}: Job {job_id} completed")
        # Runs after the job left _active_jobs, so a cancel_job() during the
        # callback cannot overwrite the "completed" status.
        if job.callback is not None:
            await self._invoke_callback(job.callback, job_id, result)

    async def _finish_job(self, job_id: str, final_status: Dict) -> None:
        """Store a job's final status, unregister its task, append to history."""
        async with self._lock:
            self._job_status[job_id] = final_status
            self._active_jobs.pop(job_id, None)
            self._job_history.append(final_status)
            # Keep only the most recent history_limit entries.
            overflow = len(self._job_history) - self.history_limit
            if overflow > 0:
                del self._job_history[:overflow]

    @staticmethod
    async def _invoke_callback(callback: Callable, job_id: str, result: Any) -> None:
        """Call ``callback(job_id, result)``, awaiting it if it returns an awaitable.

        Exceptions raised by the callback are logged and swallowed so a faulty
        callback cannot take down the worker.
        """
        try:
            outcome = callback(job_id, result)
            if inspect.isawaitable(outcome):
                await outcome
        except Exception as cb_err:
            logger.error(f"Callback error: {cb_err}", exc_info=True)

    async def _run_job(self, job: QueuedJob) -> Dict:
        """Execute the actual training job via ``TrainingService.train``."""
        # Imported lazily at call time — NOTE(review): presumably to avoid an
        # import cycle with the service layer; confirm.
        from app.services.training_service import TrainingService
        training_service = TrainingService()
        # Pass the job_id from the queued job (which matches the DB record).
        return await training_service.train(job.job_id, job.config)

    async def _cleanup_loop(self):
        """Periodically drop statuses of jobs that finished long ago.

        Every ``cleanup_interval`` seconds, removes the status and stored
        callback of each job in a terminal state whose completion, failure or
        cancellation timestamp is more than ``status_ttl`` seconds old. Exits
        when ``_running`` is cleared or when :meth:`stop` cancels it.
        """
        while self._running:
            await asyncio.sleep(self.cleanup_interval)
            async with self._lock:
                now = datetime.utcnow()
                to_remove = []
                for job_id, status in self._job_status.items():
                    if status.get("status") not in self._TERMINAL_STATES:
                        continue
                    finished_str = (
                        status.get("completed_at")
                        or status.get("failed_at")
                        or status.get("cancelled_at")
                    )
                    if not finished_str:
                        continue
                    try:
                        age = (now - datetime.fromisoformat(finished_str)).total_seconds()
                    except (TypeError, ValueError):
                        # Unparseable or tz-mismatched timestamp: keep the
                        # entry rather than guess its age.
                        continue
                    if age > self.status_ttl:
                        to_remove.append(job_id)
                for job_id in to_remove:
                    self._job_status.pop(job_id, None)
                    self._callbacks.pop(job_id, None)
                if to_remove:
                    logger.info(f"Cleaned up {len(to_remove)} old job statuses")