Spaces:
Running
Running
| """ | |
| api/websocket_manager.py | |
| ──────────────────────── | |
| WebSocket connection manager for streaming execution logs. | |
| Each task_id has a list of connected WebSocket clients. | |
| When the Celery worker emits an event, it's broadcast to all | |
| connected clients watching that task. | |
| Pattern: pub/sub via Redis — worker publishes to Redis channel, | |
| FastAPI subscribes and forwards to WebSocket clients. | |
| Fallback: in-memory queue (single-process mode for development). | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import json | |
| import logging | |
| from collections import defaultdict | |
| from typing import TYPE_CHECKING | |
| from fastapi import WebSocket | |
| if TYPE_CHECKING: | |
| pass | |
| logger = logging.getLogger(__name__) | |
class WebSocketManager:
    """
    Track the WebSocket clients watching each task and push events to them.

    Connections are grouped by ``task_id``. ``broadcast`` / ``emit`` send a
    JSON-encoded event to every client registered for a task. ``enqueue`` and
    ``drain_queue`` form the in-memory fallback: events are buffered per task
    in an ``asyncio.Queue`` and forwarded to a WebSocket by ``drain_queue``.

    Usage:
        manager = WebSocketManager()
        # In WebSocket endpoint:
        await manager.connect(task_id, websocket)
        # In Celery task (via Redis pub/sub):
        await manager.broadcast(task_id, event_dict)
    """

    # Seconds to wait before re-checking an empty in-memory queue.
    _POLL_INTERVAL = 0.05

    def __init__(self, max_queue_size: int = 0) -> None:
        """
        Create an empty manager.

        Args:
            max_queue_size: Maximum number of buffered events per task in the
                in-memory fallback queue. ``0`` (the default) means unbounded,
                in which case ``enqueue`` never drops an event.
        """
        self._max_queue_size = max_queue_size
        # task_id -> list of active WebSocket connections
        self._connections: dict[str, list[WebSocket]] = defaultdict(list)
        # task_id -> event queue (for in-memory fallback)
        self._queues: dict[str, asyncio.Queue] = defaultdict(
            lambda: asyncio.Queue(maxsize=max_queue_size)
        )

    async def connect(self, task_id: str, websocket: WebSocket) -> None:
        """Accept ``websocket`` and register it as a watcher of ``task_id``."""
        await websocket.accept()
        conns = self._connections[task_id]
        conns.append(websocket)
        logger.info("WS connected: task=%s | total=%d", task_id, len(conns))

    def disconnect(self, task_id: str, websocket: WebSocket) -> None:
        """
        Unregister ``websocket`` from ``task_id``.

        Unknown tasks and sockets that are not registered are ignored. When
        the last watcher of a task leaves, the task's entry is removed so the
        registry does not grow with every task ever seen.
        """
        conns = self._connections.get(task_id)
        if conns is None:
            return
        if websocket in conns:
            conns.remove(websocket)
            logger.info(
                "WS disconnected: task=%s | remaining=%d", task_id, len(conns)
            )
        if not conns:
            self._connections.pop(task_id, None)

    async def broadcast(self, task_id: str, event: dict) -> None:
        """
        Send ``event`` (JSON-encoded) to all WebSocket clients watching task_id.

        Clients whose send raises are treated as dead and unregistered after
        the pass over all clients has finished.
        """
        message = json.dumps(event)
        # Iterate over a snapshot: every ``await`` below yields to the event
        # loop, and a connect()/disconnect() running in that window would
        # otherwise mutate the list mid-iteration and cause clients to be
        # skipped.
        targets = list(self._connections.get(task_id, ()))
        dead = []
        for ws in targets:
            try:
                await ws.send_text(message)
            except Exception as exc:  # any send failure marks the socket dead
                logger.debug("WS send failed: %s", exc)
                dead.append(ws)
        for ws in dead:
            self.disconnect(task_id, ws)

    async def emit(self, task_id: str, event_type: str, data: dict) -> None:
        """
        Convenience: wrap ``data`` in an event envelope and broadcast it.

        The envelope has the keys ``event`` (``event_type``), ``data`` and
        ``timestamp`` (current UTC time in ISO 8601 format).
        """
        from datetime import datetime, timezone

        event = {
            "event": event_type,
            "data": data,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        await self.broadcast(task_id, event)

    def enqueue(self, task_id: str, event: dict) -> None:
        """
        Non-async version for Celery workers.

        The event is stored in the task's ``asyncio.Queue`` and later
        forwarded by ``drain_queue``. If the queue is bounded (see
        ``max_queue_size``) and full, the event is dropped with a warning.
        """
        try:
            self._queues[task_id].put_nowait(event)
        except asyncio.QueueFull:
            logger.warning(
                "Event queue full for task %s — dropping event", task_id
            )

    async def drain_queue(self, task_id: str, websocket: WebSocket) -> None:
        """
        Drain events from the in-memory queue and forward them to ``websocket``.

        Runs until encoding or sending an event raises, then returns; the
        failure is logged at debug level. While the queue is empty the loop
        sleeps briefly and polls again.
        """
        queue = self._queues[task_id]
        while True:
            try:
                event = queue.get_nowait()
            except asyncio.QueueEmpty:
                # NOTE(review): polling (rather than ``await queue.get()``) is
                # kept from the original implementation; presumably because
                # enqueue() is synchronous and may run outside the event
                # loop — confirm before changing.
                await asyncio.sleep(self._POLL_INTERVAL)
                continue
            try:
                await websocket.send_text(json.dumps(event))
            except Exception as exc:
                # The socket is unusable (or the event is not serialisable):
                # stop draining instead of spinning on a broken connection.
                logger.debug("WS drain stopped: task=%s | %s", task_id, exc)
                break

    def active_tasks(self) -> list[str]:
        """Return the task_ids that currently have at least one connection."""
        return [tid for tid, conns in self._connections.items() if conns]
| # Shared module-level WebSocketManager instance (singleton) used across the app. | |
| ws_manager = WebSocketManager() | |