| from datetime import datetime, timedelta, timezone |
| from pathlib import Path |
| from typing import Dict, List, Optional |
| from uuid import uuid4 |
| import asyncio |
| import base64 |
| import binascii |
| import contextlib |
| import json |
| import logging |
| import mimetypes |
| import os |
|
|
| import httpx |
| import uvicorn |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect |
| from fastapi.responses import HTMLResponse |
| from fastapi.staticfiles import StaticFiles |
|
|
|
|
def load_env_file(path: str = ".env") -> None:
    """Load ``KEY=VALUE`` pairs from a dotenv-style file into ``os.environ``.

    Variables already present in the environment are kept (``setdefault``),
    so real environment settings override the file. Blank lines, ``#``
    comment lines, lines without ``=`` and lines with an empty key are
    skipped. An optional leading ``export`` is ignored. One pair of
    *matching* surrounding quotes (``'…'`` or ``"…"``) is removed from the
    value. A missing path (or a path that is not a regular file) is ignored.
    """
    env_path = Path(path)
    if not env_path.is_file():
        return

    # utf-8-sig drops a BOM that would otherwise become part of the first key.
    for raw_line in env_path.read_text(encoding="utf-8-sig").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        if line.startswith("export "):
            line = line[len("export "):].lstrip()

        key, _, value = line.partition("=")
        key = key.strip()
        if not key:
            # os.environ rejects an empty variable name with ValueError.
            continue

        value = value.strip()
        # Strip quotes only as a matching pair, so a double-quoted value that
        # itself contains single quotes (e.g. "'x'") keeps its inner quotes.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ("'", '"'):
            value = value[1:-1]
        os.environ.setdefault(key, value)
|
|
|
|
# Populate os.environ from .env before any configuration below is read.
load_env_file()


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("trichat")


def _env_int(name: str, default: int, minimum: Optional[int] = None) -> int:
    """Read the integer setting *name* from the environment.

    Returns *default* when the variable is unset, blank, or not an integer
    (the latter is logged), instead of crashing at import time. If *minimum*
    is given, smaller values are raised to it (also logged).
    """
    raw = os.getenv(name, "").strip()
    if not raw:
        value = default
    else:
        try:
            value = int(raw)
        except ValueError:
            logger.warning("Ignoring non-integer %s=%r; using %s", name, raw, default)
            value = default
    if minimum is not None and value < minimum:
        logger.warning("%s=%s is below the minimum %s; using %s", name, value, minimum, minimum)
        value = minimum
    return value


MAX_HISTORY_MESSAGES = 100
MAX_FILE_SIZE = 5 * 1024 * 1024
ROOM_HISTORY_HOURS = _env_int("ROOM_HISTORY_HOURS", 5)
HISTORY_CACHE_SECONDS = _env_int("HISTORY_CACHE_SECONDS", 30)
# At least one second between cleanup passes: 0 or a negative value would make
# asyncio.sleep return immediately and the cleanup loop hammer Supabase.
CLEANUP_INTERVAL_SECONDS = _env_int("CLEANUP_INTERVAL_SECONDS", 600, minimum=1)
SUPABASE_URL = os.getenv("SUPABASE_URL", "").rstrip("/")
SUPABASE_KEY = os.getenv("SUPABASE_SERVICE_ROLE_KEY", "")
SUPABASE_BUCKET = os.getenv("SUPABASE_BUCKET", "chat-files")


app = FastAPI(title="Tri-Chat API", description="Temporary anonymous chat rooms")
app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
|
|
def now_utc() -> datetime:
    """Return the current moment as a timezone-aware ``datetime`` in UTC."""
    utc = timezone.utc
    return datetime.now(tz=utc)
|
|
|
|
def utc_now() -> str:
    """Return the current UTC time (from ``now_utc``) as an ISO 8601 string."""
    current = now_utc()
    return current.isoformat()
|
|
|
|
def expiry_time() -> str:
    """Return, as an ISO 8601 string, the moment ``ROOM_HISTORY_HOURS`` hours from now."""
    lifetime = timedelta(hours=ROOM_HISTORY_HOURS)
    deadline = now_utc() + lifetime
    return deadline.isoformat()
|
|
|
|
def clean_text(value: object, limit: int, default: str = "") -> str:
    """Normalize untrusted client text.

    Returns *value* stripped of surrounding whitespace and cut to at most
    *limit* characters. Whitespace exposed by the cut is stripped as well.
    *default* is returned when *value* is not a string or nothing remains
    after cleaning, so callers always get either real content or their
    chosen fallback.
    """
    if not isinstance(value, str):
        return default
    cleaned = value.strip()[:limit].strip()
    return cleaned or default
|
|
|
|
def safe_file_name(file_name: str) -> str:
    """Reduce a client-supplied file name to a bare, non-empty base name.

    Directory components are removed for both ``/`` and ``\\`` separators.
    Without this, a Windows-style path would survive intact on POSIX, where
    ``\\`` is not a separator. The result is stripped and limited to 100
    characters. ``"upload.bin"`` is returned when nothing usable remains
    (empty input, ``"."`` or ``".."``).
    """
    base_name = Path(file_name.replace("\\", "/")).name
    cleaned = base_name.strip()[:100].strip()
    if cleaned in ("", ".", ".."):
        return "upload.bin"
    return cleaned
|
|
|
|
def storage_path_for(room: str, file_name: str) -> str:
    """Build a unique object key ``<room_prefix>/<uuid>-<file_name>`` for an upload.

    Both components are reduced to ASCII letters and digits plus ``-`` and
    ``_`` (and ``.`` in the file-name part); every other character becomes
    ``_``. The key is interpolated into storage URLs, where a ``#`` or ``?``
    from a user-supplied name would cut the URL short. Non-ASCII characters
    may also be rejected as object keys — NOTE(review): confirm against the
    storage backend's key rules. The random uuid keeps keys unique even for
    identical names. The human-readable name is carried separately in the
    chat message, so sanitizing the key loses nothing for display.
    """
    def _sanitize(text: str, allowed_punctuation: str) -> str:
        # Keep only ASCII alphanumerics and the given punctuation.
        return "".join(
            char if (char.isascii() and char.isalnum()) or char in allowed_punctuation else "_"
            for char in text
        )

    room_prefix = _sanitize(room, "-_") or "room"
    name_part = _sanitize(safe_file_name(file_name), "-_.")
    return f"{room_prefix}/{uuid4().hex}-{name_part}"
|
|
|
|
def parse_iso_datetime(value: str) -> Optional[datetime]:
    """Parse an ISO 8601 timestamp into a timezone-aware ``datetime``.

    - A ``Z`` suffix is accepted as UTC.
    - Fractional seconds of any length are accepted. Values with trailing
      zeros dropped (e.g. ``.12345``) are rejected by
      ``datetime.fromisoformat`` before Python 3.11, so the fraction is
      padded or trimmed to 6 digits first.
    - Timestamps without an offset are assumed to be UTC, so the result can
      always be compared with an aware datetime such as ``now_utc()``.

    Returns ``None`` for empty, non-string or unparseable input.
    """
    import re

    if not value or not isinstance(value, str):
        return None

    normalized = value.strip().replace("Z", "+00:00")
    normalized = re.sub(
        r"(\d{2}:\d{2}:\d{2})\.(\d+)",
        lambda match: f"{match.group(1)}.{(match.group(2) + '000000')[:6]}",
        normalized,
        count=1,
    )

    try:
        parsed = datetime.fromisoformat(normalized)
    except ValueError:
        return None

    if parsed.tzinfo is None:
        # Naive vs aware comparison raises TypeError; treat naive as UTC.
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed
|
|
|
|
def is_expired(message: dict) -> bool:
    """Return True when *message*'s ``expiresAt`` timestamp is at or before now.

    A message with no parseable ``expiresAt`` (e.g. a system message) is
    never considered expired.
    """
    deadline = parse_iso_datetime(message.get("expiresAt", ""))
    if deadline is None:
        return False
    return deadline <= now_utc()
|
|
|
|
class SupabaseStore:
    """Async access to Supabase for chat messages and uploaded files.

    Messages are rows of the ``messages`` table, reached through PostgREST at
    ``/rest/v1/messages``. Files are objects in the Storage bucket *bucket*.

    When the store is not configured (see ``enabled``), reads return empty
    results and writes are skipped. The exception is ``upload_file``, which
    raises ``RuntimeError``. Failed requests surface as ``httpx.HTTPError``
    subclasses, including via ``raise_for_status``.
    """

    # Max row ids per ``id=in.(...)`` delete request; keeps the query string
    # short even if ids are long strings such as UUIDs.
    DELETE_BATCH_SIZE = 100

    def __init__(self, url: str, key: str, bucket: str):
        self.url = url  # project base URL; module config strips any trailing "/"
        self.key = key  # API key, sent as both ``apikey`` and bearer token
        self.bucket = bucket  # Storage bucket that receives uploads

    @property
    def enabled(self) -> bool:
        """True when both the project URL and the API key are non-empty."""
        return bool(self.url and self.key)

    @property
    def headers(self) -> Dict[str, str]:
        """Authentication headers sent with every Supabase request."""
        return {
            "apikey": self.key,
            "Authorization": f"Bearer {self.key}",
        }

    async def fetch_history(self, room: str) -> List[dict]:
        """Return up to ``MAX_HISTORY_MESSAGES`` unexpired messages of *room*, oldest first."""
        if not self.enabled:
            return []

        params = {
            "select": "*",
            "room": f"eq.{room}",
            "expires_at": f"gt.{utc_now()}",
            "order": "created_at.desc",
            "limit": str(MAX_HISTORY_MESSAGES),
        }

        async with httpx.AsyncClient(timeout=10) as client:
            response = await client.get(
                f"{self.url}/rest/v1/messages",
                headers=self.headers,
                params=params,
            )
            response.raise_for_status()

        # Fetched newest-first so the limit keeps the latest messages; flip
        # back to chronological order for replay.
        rows = list(reversed(response.json()))
        return [self.row_to_message(row) for row in rows]

    async def save_message(self, message: dict) -> None:
        """Insert a chat *message* into ``messages``; system messages are not stored."""
        if not self.enabled or message.get("type") == "system":
            return

        payload = {
            "room": message["room"],
            "username": message["username"],
            "message_type": message["type"],
            "text": message.get("text"),
            "file_url": message.get("fileUrl"),
            "file_path": message.get("filePath"),
            "file_name": message.get("fileName"),
            "file_type": message.get("fileType"),
            "file_size": message.get("fileSize"),
            "created_at": message["timestamp"],
            "expires_at": message["expiresAt"],
        }

        async with httpx.AsyncClient(timeout=10) as client:
            response = await client.post(
                f"{self.url}/rest/v1/messages",
                headers={**self.headers, "Content-Type": "application/json"},
                json=payload,
            )
            response.raise_for_status()

    async def upload_file(self, room: str, file_name: str, file_type: str, file_bytes: bytes) -> tuple[str, str]:
        """Upload *file_bytes* under a fresh object key for *room*.

        The content type is *file_type*, otherwise one guessed from
        *file_name*, otherwise ``application/octet-stream``.

        Returns:
            ``(public_url, object_path)``. The raw object path is what gets
            stored and later passed to ``delete_storage_objects``.

        Raises:
            RuntimeError: if the store is not configured.
            httpx.HTTPError: if the upload request fails.
        """
        from urllib.parse import quote

        if not self.enabled:
            raise RuntimeError("Supabase is not configured")

        path = storage_path_for(room, file_name)
        content_type = file_type or mimetypes.guess_type(file_name)[0] or "application/octet-stream"
        # Percent-encode the key before it goes into a URL. A "#", "?" or "%"
        # in a client-supplied name would otherwise truncate the URL or change
        # which object is addressed. "/" stays a path separator.
        url_path = quote(path, safe="/")

        async with httpx.AsyncClient(timeout=30) as client:
            response = await client.post(
                f"{self.url}/storage/v1/object/{self.bucket}/{url_path}",
                headers={
                    **self.headers,
                    "Content-Type": content_type,
                    "x-upsert": "false",
                },
                content=file_bytes,
            )
            response.raise_for_status()

        public_url = f"{self.url}/storage/v1/object/public/{self.bucket}/{url_path}"
        return public_url, path

    async def delete_storage_objects(self, paths: List[str]) -> None:
        """Delete the objects with the given raw keys from the bucket (no-op for none)."""
        if not self.enabled or not paths:
            return

        async with httpx.AsyncClient(timeout=30) as client:
            response = await client.request(
                "DELETE",
                f"{self.url}/storage/v1/object/{self.bucket}",
                headers={**self.headers, "Content-Type": "application/json"},
                json={"prefixes": paths},
            )
            response.raise_for_status()

    async def cleanup_expired(self) -> int:
        """Delete one batch (up to 500) of expired messages and their files.

        Returns the number of message rows removed (0 when none or disabled).
        Files are deleted before their rows, so a failed file deletion leaves
        the rows in place for a retry on the next pass.
        """
        if not self.enabled:
            return 0

        async with httpx.AsyncClient(timeout=30) as client:
            select_response = await client.get(
                f"{self.url}/rest/v1/messages",
                headers=self.headers,
                params={
                    "select": "id,file_path",
                    "expires_at": f"lte.{utc_now()}",
                    "limit": "500",
                },
            )
            select_response.raise_for_status()
            expired_rows = select_response.json()

            if not expired_rows:
                return 0

            file_paths = [row["file_path"] for row in expired_rows if row.get("file_path")]
            if file_paths:
                await self.delete_storage_objects(file_paths)

            # Delete exactly the rows selected above, not every expired row.
            # Expired rows beyond the select limit have not had their files
            # removed yet; deleting them here would orphan those files.
            # NOTE(review): assumes ids are integers or UUIDs (no commas), as
            # they are listed unquoted in the PostgREST ``in`` filter.
            row_ids = [str(row["id"]) for row in expired_rows]
            for start in range(0, len(row_ids), self.DELETE_BATCH_SIZE):
                batch = row_ids[start:start + self.DELETE_BATCH_SIZE]
                delete_response = await client.delete(
                    f"{self.url}/rest/v1/messages",
                    headers=self.headers,
                    params={"id": f"in.({','.join(batch)})"},
                )
                delete_response.raise_for_status()

        return len(expired_rows)

    def row_to_message(self, row: dict) -> dict:
        """Convert a ``messages`` row into the camelCase dict sent to clients."""
        message = {
            "type": row["message_type"],
            "username": row["username"],
            "timestamp": row["created_at"],
            "expiresAt": row["expires_at"],
            "room": row["room"],
        }

        if row["message_type"] == "text":
            message["text"] = row.get("text") or ""
        elif row["message_type"] == "file":
            message.update(
                {
                    "fileUrl": row.get("file_url") or "",
                    "fileName": row.get("file_name") or "download",
                    "fileType": row.get("file_type") or "application/octet-stream",
                    "fileSize": row.get("file_size") or 0,
                }
            )

        return message
|
|
|
|
# Process-wide Supabase client configured from the environment settings above.
store = SupabaseStore(url=SUPABASE_URL, key=SUPABASE_KEY, bucket=SUPABASE_BUCKET)
|
|
|
|
class ConnectionManager:
    """Registry of live WebSocket connections per room, plus room history.

    History is read from Supabase when it is configured (``store.enabled``)
    and the request succeeds. Otherwise it comes from an in-memory per-room
    list, which also receives messages that Supabase failed to save.
    Fetched history is cached per room for ``HISTORY_CACHE_SECONDS``.
    """

    def __init__(self):
        # room -> [{"websocket": ..., "username": ..., "joined_at": ...}, ...]
        self.active_connections: Dict[str, List[Dict]] = {}
        # room -> messages kept in memory when Supabase is off or saving failed
        self.fallback_history: Dict[str, List[Dict]] = {}
        # room -> {"messages": [...], "expires_at": datetime}
        self.history_cache: Dict[str, Dict] = {}

    async def connect(self, websocket: WebSocket, room: str, username: str):
        """Accept *websocket*, replay the room history to it, and announce the join."""
        await websocket.accept()

        # setdefault rather than initializing only for brand-new rooms: a room
        # that emptied out and is rejoined must keep its unexpired in-memory
        # history instead of having it reset to [].
        self.active_connections.setdefault(room, []).append(
            {
                "websocket": websocket,
                "username": username,
                "joined_at": utc_now(),
            }
        )
        self.fallback_history.setdefault(room, [])

        for message in await self.get_history(room):
            await websocket.send_text(json.dumps(message))

        await self.broadcast_to_room(
            room,
            {
                "type": "system",
                "message": f"{username} joined the room",
                "timestamp": utc_now(),
                "room": room,
            },
            persist=False,
        )
        await self.broadcast_user_list(room)

    async def get_history(self, room: str) -> List[dict]:
        """Return the room's unexpired history, oldest first.

        Sources are tried in order:
        1. The per-room cache, while it is fresh.
        2. Supabase.
        3. The in-memory fallback list.

        A new list is returned each time, so callers can iterate it across
        ``await`` points while other coroutines append to the cache.
        """
        self.prune_fallback_history(room)
        cached = self.history_cache.get(room)
        if cached and cached["expires_at"] > now_utc():
            return list(cached["messages"])

        if store.enabled:
            try:
                messages = await store.fetch_history(room)
            except (httpx.HTTPError, ValueError, KeyError) as exc:
                # ValueError: malformed JSON body; KeyError: row missing a
                # column. Either way fall back rather than failing the join.
                logger.warning("Could not fetch Supabase history: %s", exc)
            else:
                self.cache_history(room, messages)
                return messages

        messages = list(self.fallback_history.get(room, []))
        self.cache_history(room, messages)
        return messages

    def cache_history(self, room: str, messages: List[dict]) -> None:
        """Cache the unexpired subset of *messages* for ``HISTORY_CACHE_SECONDS``."""
        self.history_cache[room] = {
            "messages": [message for message in messages if not is_expired(message)],
            "expires_at": now_utc() + timedelta(seconds=HISTORY_CACHE_SECONDS),
        }

    def append_to_cache(self, room: str, message: dict) -> None:
        """Add *message* to the room's cached history, if a cache entry exists."""
        cached = self.history_cache.get(room)
        if not cached:
            return

        cached["messages"].append(message)
        cached["messages"] = cached["messages"][-MAX_HISTORY_MESSAGES:]

    def prune_fallback_history(self, room: str) -> None:
        """Drop expired messages from the room's fallback list and cap its length."""
        if room in self.fallback_history:
            self.fallback_history[room] = [
                message for message in self.fallback_history[room] if not is_expired(message)
            ][-MAX_HISTORY_MESSAGES:]

    def disconnect(self, websocket: WebSocket, room: str) -> Optional[str]:
        """Remove *websocket* from *room*; return its username, or None if not found.

        When the last connection leaves, the room's idle bookkeeping is
        released (see ``_forget_idle_room``).
        """
        connections = self.active_connections.get(room)
        if not connections:
            return None

        for index, conn in enumerate(connections):
            # Identity, not ==: WebSocket objects are Mappings whose == compares
            # their connection scopes rather than the objects themselves.
            if conn["websocket"] is websocket:
                del connections[index]
                if not connections:
                    self._forget_idle_room(room)
                return conn["username"]

        return None

    def _forget_idle_room(self, room: str) -> None:
        """Release per-room state once the room has no connections left.

        The history cache is dropped. The fallback list is dropped only when
        nothing unexpired remains, so rooms nobody uses stop holding memory
        while still-valid in-memory history survives a rejoin.
        """
        self.active_connections.pop(room, None)
        self.history_cache.pop(room, None)
        self.prune_fallback_history(room)
        if not self.fallback_history.get(room):
            self.fallback_history.pop(room, None)

    async def broadcast_to_room(self, room: str, message: dict, persist: bool = True):
        """Send *message* to everyone in *room*, optionally persisting it first.

        With *persist*, the message is saved to Supabase. If Supabase is not
        configured or the save fails, it goes to the in-memory fallback list.
        The room's history cache is updated either way.
        """
        if persist:
            await self._persist(room, message)

        self.append_to_cache(room, message)
        await self._send_to_room(room, message)

    async def _persist(self, room: str, message: dict) -> None:
        """Save *message* to Supabase, falling back to the in-memory list."""
        if store.enabled:
            try:
                await store.save_message(message)
                return
            except httpx.HTTPError as exc:
                logger.warning("Could not save message to Supabase: %s", exc)
        self.add_fallback_message(room, message)

    async def _send_to_room(self, room: str, payload: dict) -> None:
        """Send *payload* as JSON to every socket in *room*; drop sockets whose send fails."""
        connections = self.active_connections.get(room)
        if connections is None:
            return

        text = json.dumps(payload)  # serialize once for all recipients
        dead_ids = set()
        for connection_info in list(connections):
            try:
                await connection_info["websocket"].send_text(text)
            except Exception as exc:
                # Per-recipient isolation: a closed peer can raise RuntimeError,
                # WebSocketDisconnect or a transport error. None of these may
                # abort the fan-out or propagate to the sender's handler, where
                # a WebSocketDisconnect would be mistaken for the sender leaving.
                logger.info("Dropping connection of %s in %s: %r", connection_info["username"], room, exc)
                dead_ids.add(id(connection_info))

        if dead_ids:
            current = self.active_connections.get(room)
            if current is not None:
                current[:] = [conn for conn in current if id(conn) not in dead_ids]

    def add_fallback_message(self, room: str, message: dict) -> None:
        """Append *message* to the room's in-memory history and prune it."""
        self.fallback_history.setdefault(room, []).append(message)
        self.prune_fallback_history(room)

    async def broadcast_user_list(self, room: str):
        """Send the current list of usernames in *room* to everyone in it."""
        if room not in self.active_connections:
            return

        users_message = {
            "type": "user_list",
            "users": self.get_room_users(room),
            "room": room,
        }
        await self._send_to_room(room, users_message)

    def get_room_users(self, room: str) -> List[str]:
        """Return the usernames connected to *room* (empty if the room is unknown)."""
        if room not in self.active_connections:
            return []
        return [conn["username"] for conn in self.active_connections[room]]

    def clear_cache(self) -> None:
        """Forget all cached room histories so the next join refetches."""
        self.history_cache.clear()
|
|
|
|
# Single in-process registry shared by every WebSocket handler.
manager: ConnectionManager = ConnectionManager()
# Background cleanup task; assigned by startup_event, cancelled by shutdown_event.
cleanup_task: Optional[asyncio.Task] = None
|
|
|
|
async def cleanup_loop() -> None:
    """Purge expired Supabase messages/files every ``CLEANUP_INTERVAL_SECONDS`` until cancelled.

    After a pass that deleted anything, cached room histories are cleared so
    they are refetched without the purged messages. Errors are logged, and
    the loop keeps running rather than letting the background task die.
    """
    while True:
        try:
            deleted_count = await store.cleanup_expired()
        except httpx.HTTPError as exc:
            logger.warning("Expired cleanup failed: %s", exc)
        except Exception:
            # Anything else (e.g. a malformed response) would otherwise end
            # the task silently and stop cleanup for the process lifetime.
            # CancelledError is a BaseException and still stops the loop.
            logger.exception("Unexpected error during expired cleanup")
        else:
            if deleted_count:
                manager.clear_cache()
                logger.info("Deleted %s expired messages/files", deleted_count)

        await asyncio.sleep(CLEANUP_INTERVAL_SECONDS)
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Start the background expiry cleanup when Supabase is configured."""
    global cleanup_task
    if not store.enabled:
        logger.warning("Supabase is not configured. Using in-memory fallback only.")
        return

    cleanup_task = asyncio.create_task(cleanup_loop())
    logger.info("Temporary cleanup is running every %s seconds", CLEANUP_INTERVAL_SECONDS)
|
|
|
|
@app.on_event("shutdown")
async def shutdown_event():
    """Cancel the background cleanup task, if one was started, and wait for it to end."""
    global cleanup_task
    task, cleanup_task = cleanup_task, None
    if task is None:
        return

    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    except Exception:
        # The task may already have died with its own error; awaiting it
        # re-raises that error, which must not abort the rest of shutdown.
        logger.exception("Cleanup task had failed before shutdown")
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def get_chat_page():
    """Serve the chat UI from ``templates/index.html``, or a 404 page if it is missing."""
    template = Path("templates/index.html")
    try:
        page = template.read_text(encoding="utf-8")
    except FileNotFoundError:
        return HTMLResponse(
            content="<h1>Error: templates/index.html not found</h1>",
            status_code=404,
        )
    return HTMLResponse(content=page)
|
|
|
|
@app.websocket("/ws/{room}")
async def websocket_endpoint(websocket: WebSocket, room: str, username: str):
    """Chat socket for *room*: join, relay text/file messages, announce leave.

    ``username`` and ``room`` are trimmed and length-limited; a blank
    username is refused with close code 1008.

    Each client frame must be a JSON object whose ``type`` is ``"text"`` or
    ``"file"``. Other types are ignored, and malformed frames get an
    ``{"type": "error"}`` reply.

    The connection is removed from the room whenever the handler exits —
    including a drop during the initial history replay or an unexpected
    error in a frame handler — so no ghost users linger in the room.
    """
    username = clean_text(username, 20)
    room = clean_text(room, 30, "global") or "global"

    if not username:
        await websocket.close(code=1008, reason="Username is required")
        return

    try:
        # Inside the try: connect() replays history, and the client may
        # disconnect during that replay after it was added to the room.
        await manager.connect(websocket, room, username)

        while True:
            data = await websocket.receive_text()

            try:
                message_data = json.loads(data)
            except json.JSONDecodeError:
                await websocket.send_text(json.dumps({"type": "error", "message": "Invalid message"}))
                continue

            if not isinstance(message_data, dict):
                # Valid JSON that is not an object (list, number, ...) has no .get().
                await websocket.send_text(json.dumps({"type": "error", "message": "Invalid message"}))
                continue

            message_type = message_data.get("type")
            if message_type == "text":
                await handle_text_message(room, username, message_data)
            elif message_type == "file":
                await handle_file_message(websocket, room, username, message_data)
    except WebSocketDisconnect:
        pass
    except Exception:
        # Top-level boundary for this connection: log, then clean up below.
        logger.exception("Unexpected error on WebSocket in room %s", room)
    finally:
        disconnected_username = manager.disconnect(websocket, room)
        if disconnected_username:
            await manager.broadcast_to_room(
                room,
                {
                    "type": "system",
                    "message": f"{disconnected_username} left the room",
                    "timestamp": utc_now(),
                    "room": room,
                },
                persist=False,
            )
            await manager.broadcast_user_list(room)
|
|
|
|
async def handle_text_message(room: str, username: str, message_data: dict):
    """Broadcast and persist a chat text (trimmed, max 500 chars); blank text is dropped."""
    text = clean_text(message_data.get("text"), 500)
    if text:
        outgoing = {
            "type": "text",
            "username": username,
            "text": text,
            "timestamp": utc_now(),
            "expiresAt": expiry_time(),
            "room": room,
        }
        await manager.broadcast_to_room(room, outgoing)
|
|
|
|
async def handle_file_message(websocket: WebSocket, room: str, username: str, message_data: dict):
    """Upload a client-sent file to Supabase Storage and announce it to the room.

    Reads ``fileName``, ``fileType`` and ``fileData`` from the decoded client
    frame. ``fileData`` may be plain base64 or a ``data:<mime>;base64,<payload>``
    URL. Problems are reported only to the sender as ``{"type": "error"}``
    frames. On success the file message is broadcast to the room and
    persisted.
    """
    file_name = safe_file_name(clean_text(message_data.get("fileName"), 100, "upload.bin"))
    file_type = clean_text(message_data.get("fileType"), 120, "application/octet-stream")
    file_data = message_data.get("fileData", "")

    async def send_error(text: str) -> None:
        await websocket.send_text(json.dumps({"type": "error", "message": text}))

    if not store.enabled:
        await send_error("File uploads need Supabase configured")
        return

    # Accept data URLs (e.g. from FileReader.readAsDataURL). Plain base64 never
    # contains ":" or ",", so this cannot misfire on a bare payload.
    if isinstance(file_data, str) and file_data.startswith("data:") and "," in file_data:
        file_data = file_data.split(",", 1)[1]

    # Reject oversized payloads before decoding them into memory. base64 uses
    # 4 characters per 3 bytes, and validate=True forbids whitespace, so any
    # longer string must decode to more than MAX_FILE_SIZE bytes.
    max_encoded_length = 4 * -(-MAX_FILE_SIZE // 3)
    if isinstance(file_data, str) and len(file_data) > max_encoded_length:
        await send_error("File size exceeds 5MB limit")
        return

    try:
        file_bytes = base64.b64decode(file_data, validate=True)
    except (ValueError, TypeError):
        # ValueError covers binascii.Error and non-ASCII strings;
        # TypeError covers non-string payloads (numbers, lists, None).
        await send_error("Invalid file data")
        return

    if len(file_bytes) > MAX_FILE_SIZE:
        await send_error("File size exceeds 5MB limit")
        return

    try:
        file_url, file_path = await store.upload_file(room, file_name, file_type, file_bytes)
    except httpx.HTTPError as exc:
        logger.warning("File upload failed: %s", exc)
        await send_error("File upload failed")
        return

    await manager.broadcast_to_room(
        room,
        {
            "type": "file",
            "username": username,
            "fileName": file_name,
            "fileType": file_type,
            "fileSize": len(file_bytes),
            "fileUrl": file_url,
            "filePath": file_path,
            "timestamp": utc_now(),
            "expiresAt": expiry_time(),
            "room": room,
        },
    )
|
|
|
|
@app.get("/api/rooms")
async def get_active_rooms():
    """List every room that currently has at least one connected user."""
    return {
        "rooms": [
            {
                "name": name,
                "user_count": len(members),
                "users": [member["username"] for member in members],
            }
            for name, members in manager.active_connections.items()
            if members
        ]
    }
|
|
|
|
@app.get("/api/rooms/{room}/users")
async def get_room_users(room: str):
    """Return the usernames connected to *room* (empty for unknown rooms)."""
    names = manager.get_room_users(room)
    payload = {"room": room, "users": names, "user_count": len(names)}
    return payload
|
|
|
|
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT (default 7860).
    listen_port = int(os.getenv("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
|
|