"""Chat endpoint with streaming support.""" import uuid import json from typing import List, Dict, Any, Optional from fastapi import APIRouter, Depends, HTTPException from langchain_core.messages import HumanMessage, AIMessage from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sse_starlette.sse import EventSourceResponse from src.agents.chat_handler import ChatHandler from src.config.settings import settings from src.db.postgres.connection import get_db from src.db.postgres.models import ChatMessage, MessageSource from src.db.redis.connection import get_redis from src.middlewares.logging import get_logger, log_execution logger = get_logger("chat_api") router = APIRouter(prefix="/api/v1", tags=["Chat"]) _GREETINGS = frozenset(["hi", "hello", "hey", "halo", "hai", "hei"]) _GOODBYES = frozenset(["bye", "goodbye", "thanks", "thank you", "terima kasih", "sampai jumpa"]) def _fast_intent(message: str) -> Optional[str]: """Return a direct response for obvious greetings/farewells, else None.""" lower = message.lower().strip().rstrip("!.,?") if lower in _GREETINGS: return "Hello! How can I assist you today?" if lower in _GOODBYES: return "Goodbye! Have a great day!" return None class ChatRequest(BaseModel): user_id: str room_id: str message: str async def get_cached_response(redis, cache_key: str) -> Optional[dict]: cached = await redis.get(cache_key) if cached: data = json.loads(cached) if isinstance(data, dict) and "response" in data: return data # legacy: plain string cached before this change return {"response": data, "sources": []} return None async def cache_response(redis, cache_key: str, response: str, sources: list): await redis.setex(cache_key, 86400, json.dumps({"response": response, "sources": sources})) async def load_history(db: AsyncSession, room_id: str, limit: int = 10) -> list: """Load recent chat messages for a room as LangChain message objects (oldest-first).""" result = await db.execute( select(ChatMessage) .where(ChatMessage.room_id == room_id) .order_by(ChatMessage.created_at.asc()) .limit(limit) ) rows = result.scalars().all() return [ HumanMessage(content=row.content) if row.role == "user" else AIMessage(content=row.content) for row in rows ] async def save_messages( db: AsyncSession, room_id: str, user_content: str, assistant_content: str, sources: Optional[List[Dict[str, Any]]] = None, ): """Persist user and assistant messages, and attach sources to the assistant message.""" db.add(ChatMessage(id=str(uuid.uuid4()), room_id=room_id, role="user", content=user_content)) assistant_id = str(uuid.uuid4()) db.add(ChatMessage(id=assistant_id, room_id=room_id, role="assistant", content=assistant_content)) for src in (sources or []): page = src.get("page_label") db.add(MessageSource( id=str(uuid.uuid4()), message_id=assistant_id, document_id=src.get("document_id"), filename=src.get("filename"), page_label=str(page) if page is not None else None, )) await db.commit() @router.delete("/chat/cache") async def clear_chat_cache(room_id: str, message: str): """Delete the Redis cache entry for a specific room + message pair.""" redis = await get_redis() cache_key = f"{settings.redis_prefix}chat:{room_id}:{message}" deleted = await redis.delete(cache_key) return {"deleted": deleted > 0, "cache_key": cache_key} @router.delete("/chat/cache/room/{room_id}") async def clear_room_cache(room_id: str): """Delete all Redis cache entries for a room.""" redis = await get_redis() pattern = f"{settings.redis_prefix}chat:{room_id}:*" keys = await redis.keys(pattern) if keys: await redis.delete(*keys) return {"deleted_count": len(keys), "room_id": room_id} @router.delete("/retrieval/cache/{user_id}") async def clear_retrieval_cache(user_id: str): """Delete all cached retrieval results for a user. Call this after uploading/processing new documents.""" from src.retrieval.router import retrieval_router deleted = await retrieval_router.invalidate_cache(user_id) return {"deleted_count": deleted, "user_id": user_id} @router.post("/chat/stream") @log_execution(logger) async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)): """Chat endpoint with streaming response. SSE event sequence: 1. sources — JSON array of source refs from ChatHandler (table for structured; deduped document_id/page_label for unstructured) 2. chunk — text fragments of the answer 3. done — signals end of stream """ redis = await get_redis() cache_key = f"{settings.redis_prefix}chat:{request.room_id}:{request.message}" # Redis cache hit cached = await get_cached_response(redis, cache_key) logger.info("cache check", cache_key=cache_key, cache_hit=cached is not None) if cached: logger.info("Returning cached response") cached_text = cached["response"] cached_sources = cached["sources"] await save_messages(db, request.room_id, request.message, cached_text, sources=cached_sources) async def stream_cached(): yield {"event": "sources", "data": json.dumps(cached_sources)} for i in range(0, len(cached_text), 50): yield {"event": "chunk", "data": cached_text[i:i + 50]} yield {"event": "done", "data": ""} return EventSourceResponse(stream_cached()) try: # Fast intent: greetings/farewells bypass LLM entirely direct = _fast_intent(request.message) if direct: await cache_response(redis, cache_key, direct, sources=[]) await save_messages(db, request.room_id, request.message, direct, sources=[]) async def stream_direct(): yield {"event": "sources", "data": json.dumps([])} yield {"event": "chunk", "data": direct} yield {"event": "done", "data": ""} return EventSourceResponse(stream_direct()) history = await load_history(db, request.room_id, limit=10) handler = ChatHandler() async def stream_response(): logger.info("stream_response started", room_id=request.room_id, user_id=request.user_id) full_response = "" sources: List[Dict[str, Any]] = [] async for event in handler.handle(request.message, request.user_id, history): if event["event"] == "sources": try: sources = json.loads(event["data"]) or [] except (TypeError, ValueError): sources = [] yield event elif event["event"] == "chunk": full_response += event["data"] yield event elif event["event"] == "done": await cache_response(redis, cache_key, full_response, sources=sources) logger.info("saving messages", sources_count=len(sources), sources=sources) try: await save_messages(db, request.room_id, request.message, full_response, sources=sources) except Exception as e: logger.error("save_messages failed", room_id=request.room_id, error=str(e)) yield event elif event["event"] == "error": yield event return # "intent" event: consumed internally, not forwarded to frontend return EventSourceResponse(stream_response()) except Exception as e: logger.error("Chat failed", error=str(e)) raise HTTPException(status_code=500, detail=f"Chat failed: {str(e)}")