File size: 3,950 Bytes
dc71cad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
api/websocket_manager.py
──────────────────────────
WebSocket connection manager for streaming execution logs.

Each task_id has a list of connected WebSocket clients.
When the Celery worker emits an event, it's broadcast to all
connected clients watching that task.

Pattern: pub/sub via Redis — the worker publishes to a Redis channel,
and FastAPI subscribes and forwards events to WebSocket clients.
Fallback: in-memory queue (single-process mode for development).
"""
from __future__ import annotations

import asyncio
import json
import logging
from collections import defaultdict
from typing import TYPE_CHECKING

from fastapi import WebSocket

# Module-level logger; the dead ``if TYPE_CHECKING: pass`` guard that used to
# precede it imported nothing and has been removed.
logger = logging.getLogger(__name__)


class WebSocketManager:
    """
    Manages active WebSocket connections per task_id.

    Usage:
        manager = WebSocketManager()

        # In WebSocket endpoint:
        await manager.connect(task_id, websocket)

        # In Celery task (via Redis pub/sub):
        await manager.broadcast(task_id, event_dict)

    The in-memory fallback (``enqueue`` / ``drain_queue``) keeps one
    ``asyncio.Queue`` per task_id. ``max_queue_size`` bounds each of those
    queues. The default ``0`` means unbounded, which is ``asyncio.Queue``
    semantics and the previous behaviour. With a positive value, ``enqueue``
    drops events once a queue is full instead of letting it grow without
    limit.
    """

    def __init__(self, max_queue_size: int = 0) -> None:
        # task_id → list of active WebSocket connections
        self._connections: dict[str, list[WebSocket]] = defaultdict(list)
        # task_id → event queue (for in-memory fallback). Queues are created
        # lazily on first access. The factory closes over the local
        # ``max_queue_size`` rather than ``self`` to avoid a reference cycle.
        self._queues: dict[str, asyncio.Queue] = defaultdict(
            lambda: asyncio.Queue(maxsize=max_queue_size)
        )

    async def connect(self, task_id: str, websocket: WebSocket) -> None:
        """Accept the WebSocket handshake and register it under ``task_id``."""
        await websocket.accept()
        self._connections[task_id].append(websocket)
        logger.info("WS connected: task=%s | total=%d",
                    task_id, len(self._connections[task_id]))

    def disconnect(self, task_id: str, websocket: WebSocket) -> None:
        """
        Unregister ``websocket`` from ``task_id``.

        This is a no-op if the socket (or the task) is not registered.
        Once the last socket for a task is removed, the task's entry is
        deleted so finished tasks do not accumulate in the registry.
        """
        # ``.get`` (not ``[]``) so an unknown task_id does not create an
        # entry through the defaultdict.
        conns = self._connections.get(task_id)
        remaining = 0
        if conns is not None:
            if websocket in conns:
                conns.remove(websocket)
            remaining = len(conns)
            if not conns:
                del self._connections[task_id]
        logger.info("WS disconnected: task=%s | remaining=%d", task_id, remaining)

    async def broadcast(self, task_id: str, event: dict) -> None:
        """
        Send ``event`` (JSON-encoded) to every WebSocket client watching ``task_id``.

        Clients whose send raises are treated as dead and unregistered after
        the broadcast. ``json.dumps`` errors (non-serializable events)
        propagate to the caller before anything is sent.
        """
        message = json.dumps(event)
        # Iterate over a snapshot. Each ``await`` yields to the event loop,
        # where another coroutine may connect()/disconnect() and mutate the
        # live list. Iterating the live list would then skip clients.
        targets = list(self._connections.get(task_id, ()))
        dead = []
        for ws in targets:
            try:
                await ws.send_text(message)
            except Exception as e:
                logger.debug("WS send failed: %s", e)
                dead.append(ws)
        for ws in dead:
            self.disconnect(task_id, ws)

    async def emit(self, task_id: str, event_type: str, data: dict) -> None:
        """
        Wrap ``data`` in an event envelope and broadcast it.

        The envelope is ``{"event": event_type, "data": data, "timestamp": ...}``.
        The timestamp is the current UTC time in ISO 8601 format.
        """
        from datetime import datetime, timezone
        event = {
            "event": event_type,
            "data": data,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        await self.broadcast(task_id, event)

    def enqueue(self, task_id: str, event: dict) -> None:
        """
        Non-async version for Celery workers.

        Events are stored in an asyncio.Queue and drained by the WS listener.
        If the queue is bounded (``max_queue_size > 0``) and full, the event
        is dropped with a warning.

        NOTE(review): ``asyncio.Queue`` is neither thread-safe nor shared
        across processes. This only reaches a ``drain_queue`` listener running
        in the same process, consistent with the "single-process mode"
        fallback described in the module docstring. Confirm the caller does
        not run in a separate worker process.
        """
        try:
            self._queues[task_id].put_nowait(event)
        except asyncio.QueueFull:
            logger.warning("Event queue full for task %s — dropping event", task_id)

    async def drain_queue(self, task_id: str, websocket: WebSocket) -> None:
        """
        Drain events from the in-memory queue and forward them to ``websocket``.

        This is called by the WebSocket endpoint's receive loop. It runs until
        serializing or sending an event fails; the failure is logged and the
        loop stops. The event that failed has already been dequeued and is
        not retried.

        All drainers of the same task_id share one queue. If several sockets
        drain the same task concurrently, each event goes to only one of them.
        """
        queue = self._queues[task_id]
        while True:
            try:
                event = queue.get_nowait()
            except asyncio.QueueEmpty:
                # Poll with a short sleep instead of ``await queue.get()``.
                # Presumably this lets events put without waking this loop
                # still be picked up. NOTE(review): confirm intent.
                await asyncio.sleep(0.05)
                continue
            try:
                await websocket.send_text(json.dumps(event))
            except Exception:
                logger.debug("WS drain stopped for task %s", task_id, exc_info=True)
                break

    def active_tasks(self) -> list[str]:
        """Return the task_ids that currently have at least one connected client."""
        return [tid for tid, conns in self._connections.items() if conns]


# Process-wide shared instance: import this instead of constructing a new
# manager, so every endpoint reads and writes the same connection registry.
ws_manager: WebSocketManager = WebSocketManager()