| """Chat endpoint with streaming support.""" |
|
|
| import uuid |
| import json |
| from typing import List, Dict, Any, Optional |
|
|
| from fastapi import APIRouter, Depends, HTTPException |
| from langchain_core.messages import HumanMessage, AIMessage |
| from pydantic import BaseModel |
| from sqlalchemy import select |
| from sqlalchemy.ext.asyncio import AsyncSession |
| from sse_starlette.sse import EventSourceResponse |
|
|
| from src.agents.chat_handler import ChatHandler |
| from src.config.settings import settings |
| from src.db.postgres.connection import get_db |
| from src.db.postgres.models import ChatMessage, MessageSource |
| from src.db.redis.connection import get_redis |
| from src.middlewares.logging import get_logger, log_execution |
|
|
# Module-level logger obtained from the project's get_logger helper under the name "chat_api".
logger = get_logger("chat_api")


# Router for this module; every route registered on it is served under the
# "/api/v1" prefix and tagged "Chat".
router = APIRouter(prefix="/api/v1", tags=["Chat"])
|
|
| _GREETINGS = frozenset(["hi", "hello", "hey", "halo", "hai", "hei"]) |
| _GOODBYES = frozenset(["bye", "goodbye", "thanks", "thank you", "terima kasih", "sampai jumpa"]) |
|
|
|
|
| def _fast_intent(message: str) -> Optional[str]: |
| """Return a direct response for obvious greetings/farewells, else None.""" |
| lower = message.lower().strip().rstrip("!.,?") |
| if lower in _GREETINGS: |
| return "Hello! How can I assist you today?" |
| if lower in _GOODBYES: |
| return "Goodbye! Have a great day!" |
| return None |
|
|
|
|
class ChatRequest(BaseModel):
    """Request body accepted by the ``chat_stream`` endpoint."""

    # Passed to ChatHandler.handle and included in the stream's log entries.
    user_id: str
    # Identifies the conversation: part of the Redis cache key and the key
    # under which chat history is loaded and saved.
    room_id: str
    # The user's message text; also part of the Redis cache key.
    message: str
|
|
|
|
async def get_cached_response(redis, cache_key: str) -> Optional[dict]:
    """Fetch a cached chat reply and normalize it to a dict.

    Args:
        redis: Client exposing an awaitable ``get(key)``.
        cache_key: Key to look up.

    Returns:
        ``None`` when the key is missing, empty, or holds a value that is not
        valid JSON. Otherwise a dict that always contains both ``"response"``
        and ``"sources"``.
    """
    cached = await redis.get(cache_key)
    if not cached:
        return None

    try:
        data = json.loads(cached)
    except (TypeError, ValueError):
        # A corrupt or non-JSON entry is treated as a cache miss so the reply
        # is regenerated instead of failing the whole request.
        return None

    if isinstance(data, dict) and "response" in data:
        # Callers index "sources" directly, so guarantee the key is usable.
        if data.get("sources") is None:
            data["sources"] = []
        return data

    # Any other JSON value is wrapped as the response text with no sources
    # (presumably an older cache format that stored the bare reply — confirm).
    return {"response": data, "sources": []}
|
|
|
|
# Default lifetime of a cached chat reply: 24 hours.
_CACHE_TTL_SECONDS = 86400


async def cache_response(
    redis,
    cache_key: str,
    response: str,
    sources: list,
    ttl_seconds: int = _CACHE_TTL_SECONDS,
):
    """Store a chat reply and its sources in Redis as JSON with an expiry.

    Args:
        redis: Client exposing an awaitable ``setex(key, seconds, value)``.
        cache_key: Key to write.
        response: Full reply text.
        sources: Source references for the reply; ``None`` is stored as ``[]``.
        ttl_seconds: Expiry in seconds. Defaults to 24 hours.
    """
    payload = json.dumps({"response": response, "sources": sources if sources is not None else []})
    await redis.setex(cache_key, ttl_seconds, payload)
|
|
|
|
async def load_history(db: AsyncSession, room_id: str, limit: int = 10) -> list:
    """Load the most recent chat messages for a room as LangChain messages.

    Args:
        db: Async SQLAlchemy session used to run the query.
        room_id: Room whose messages are loaded.
        limit: Maximum number of messages to return.

    Returns:
        Up to ``limit`` of the newest messages, ordered oldest-first. Rows with
        role ``"user"`` become ``HumanMessage``; every other role becomes
        ``AIMessage``.
    """
    # Sort newest-first so LIMIT keeps the latest rows. Ascending order with
    # LIMIT would always return the first messages ever sent in the room.
    result = await db.execute(
        select(ChatMessage)
        .where(ChatMessage.room_id == room_id)
        .order_by(ChatMessage.created_at.desc())
        .limit(limit)
    )
    newest_first = result.scalars().all()
    # NOTE(review): rows sharing the same created_at have no defined relative
    # order here — confirm whether a tie-breaker column exists.
    return [
        HumanMessage(content=row.content) if row.role == "user" else AIMessage(content=row.content)
        for row in reversed(newest_first)
    ]
|
|
|
|
async def save_messages(
    db: AsyncSession,
    room_id: str,
    user_content: str,
    assistant_content: str,
    sources: Optional[List[Dict[str, Any]]] = None,
):
    """Persist user and assistant messages, and attach sources to the assistant message.

    Args:
        db: Async SQLAlchemy session; everything is committed in one transaction.
        room_id: Room both messages belong to.
        user_content: Text of the user's message.
        assistant_content: Text of the assistant's reply.
        sources: Optional source dicts; ``document_id``, ``filename`` and
            ``page_label`` are read from each. ``None`` means no sources.

    Raises:
        Exception: Whatever the commit raises, re-raised after the session
            has been rolled back.
    """
    db.add(ChatMessage(id=str(uuid.uuid4()), room_id=room_id, role="user", content=user_content))
    assistant_id = str(uuid.uuid4())
    db.add(ChatMessage(id=assistant_id, room_id=room_id, role="assistant", content=assistant_content))
    for src in sources or []:
        page = src.get("page_label")
        db.add(MessageSource(
            id=str(uuid.uuid4()),
            message_id=assistant_id,
            document_id=src.get("document_id"),
            filename=src.get("filename"),
            # page_label is stored as text; keep a missing page as NULL, not "None".
            page_label=str(page) if page is not None else None,
        ))
    try:
        await db.commit()
    except Exception:
        # A failed flush/commit leaves the session's transaction unusable until
        # it is rolled back, and some callers catch this error and carry on.
        await db.rollback()
        raise
|
|
|
|
@router.delete("/chat/cache")
async def clear_chat_cache(room_id: str, message: str):
    """Remove the cached reply stored for one room + message pair.

    Returns a dict with ``deleted`` (True when Redis reports at least one key
    removed) and ``cache_key`` (the key that was targeted).
    """
    client = await get_redis()
    # Same key layout the chat_stream endpoint uses when it reads and writes the cache.
    cache_key = f"{settings.redis_prefix}chat:{room_id}:{message}"
    removed_count = await client.delete(cache_key)
    was_deleted = removed_count > 0
    return {"deleted": was_deleted, "cache_key": cache_key}
|
|
|
|
@router.delete("/chat/cache/room/{room_id}")
async def clear_room_cache(room_id: str):
    """Delete all Redis cache entries for a room.

    Args:
        room_id: Room whose cached chat replies are removed. It is matched
            literally; glob metacharacters in it are escaped.

    Returns:
        A dict with ``deleted_count`` (number of distinct matching keys found)
        and the ``room_id`` that was requested.
    """
    redis = await get_redis()

    # room_id comes straight from the URL. Unescaped, a value such as "*"
    # would turn the pattern into a wildcard over every room's cache.
    glob_escapes = str.maketrans({ch: "\\" + ch for ch in "\\*?[]"})
    pattern = f"{settings.redis_prefix}chat:{room_id.translate(glob_escapes)}:*"

    # SCAN walks the keyspace incrementally, whereas KEYS blocks the server
    # while it scans everything. SCAN may report a key more than once, hence
    # the set.
    # NOTE(review): assumes get_redis() returns a redis-py asyncio client
    # (scan_iter as an async iterator) — confirm.
    batch_size = 500
    keys = list({key async for key in redis.scan_iter(match=pattern, count=batch_size)})

    # Delete in bounded batches so one command never carries an unbounded
    # argument list.
    for start in range(0, len(keys), batch_size):
        await redis.delete(*keys[start:start + batch_size])
    return {"deleted_count": len(keys), "room_id": room_id}
|
|
|
|
@router.delete("/retrieval/cache/{user_id}")
async def clear_retrieval_cache(user_id: str):
    """Drop every cached retrieval result for a user.

    Intended to be called after new documents have been uploaded/processed.
    Returns a dict with ``deleted_count`` (the value reported by the retrieval
    router's ``invalidate_cache``) and the ``user_id``.
    """
    # Imported at call time rather than at module load, as in the original
    # (presumably to avoid an import cycle — confirm).
    from src.retrieval.router import retrieval_router

    removed = await retrieval_router.invalidate_cache(user_id)
    result = {"deleted_count": removed}
    result["user_id"] = user_id
    return result
|
|
|
|
@router.post("/chat/stream")
@log_execution(logger)
async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
    """Chat endpoint with streaming response.

    SSE event sequence:
      1. sources — JSON array of source refs from ChatHandler (table for
         structured; deduped document_id/page_label for unstructured)
      2. chunk   — text fragments of the answer
      3. done    — signals end of stream

    An ``error`` event ends the stream instead of ``done`` when the handler
    emits one or when an exception is raised while streaming.

    The reply is resolved in this order: a cached reply for the same
    room + message, a canned greeting/farewell reply, then ChatHandler.

    Raises:
        HTTPException: 500 when anything fails before streaming starts.
    """
    try:
        redis = await get_redis()
        cache_key = f"{settings.redis_prefix}chat:{request.room_id}:{request.message}"

        # 1) Cached reply: replay it and still record the exchange in history.
        cached = await get_cached_response(redis, cache_key)
        logger.info("cache check", cache_key=cache_key, cache_hit=cached is not None)
        if cached:
            logger.info("Returning cached response")
            cached_text = cached["response"]
            # Tolerate entries written without a "sources" field.
            cached_sources = cached.get("sources") or []
            await save_messages(db, request.room_id, request.message, cached_text, sources=cached_sources)

            async def stream_cached():
                yield {"event": "sources", "data": json.dumps(cached_sources)}
                # Replay in 50-character slices so clients see chunked output.
                for i in range(0, len(cached_text), 50):
                    yield {"event": "chunk", "data": cached_text[i:i + 50]}
                yield {"event": "done", "data": ""}

            return EventSourceResponse(stream_cached())

        # 2) Canned reply for obvious greetings/farewells; skips the handler.
        direct = _fast_intent(request.message)
        if direct:
            await cache_response(redis, cache_key, direct, sources=[])
            await save_messages(db, request.room_id, request.message, direct, sources=[])

            async def stream_direct():
                yield {"event": "sources", "data": json.dumps([])}
                yield {"event": "chunk", "data": direct}
                yield {"event": "done", "data": ""}

            return EventSourceResponse(stream_direct())

        # 3) Full handler path with the room's history as context.
        history = await load_history(db, request.room_id, limit=10)
        handler = ChatHandler()

        async def stream_response():
            logger.info("stream_response started", room_id=request.room_id, user_id=request.user_id)
            parts: List[str] = []
            sources: List[Dict[str, Any]] = []
            # This body runs after the endpoint has returned, so the outer
            # try/except cannot see failures raised here; handle them locally.
            try:
                async for event in handler.handle(request.message, request.user_id, history):
                    kind = event["event"]
                    if kind == "sources":
                        try:
                            sources = json.loads(event["data"]) or []
                        except (TypeError, ValueError):
                            sources = []
                        # Only a JSON array can be persisted as source rows.
                        if not isinstance(sources, list):
                            sources = []
                        yield event
                    elif kind == "chunk":
                        parts.append(event["data"])
                        yield event
                    elif kind == "done":
                        full_response = "".join(parts)
                        # Caching and persistence are best-effort: the client
                        # already has the answer, so never drop "done" for them.
                        if full_response:
                            # An empty answer is not cached; it would be
                            # replayed for the whole TTL.
                            try:
                                await cache_response(redis, cache_key, full_response, sources=sources)
                            except Exception as e:
                                logger.error("cache_response failed", cache_key=cache_key, error=str(e))
                        logger.info("saving messages", sources_count=len(sources), sources=sources)
                        try:
                            await save_messages(db, request.room_id, request.message, full_response, sources=sources)
                        except Exception as e:
                            logger.error("save_messages failed", room_id=request.room_id, error=str(e))
                        yield event
                    elif kind == "error":
                        # Handler-reported failure: forward it and stop without
                        # caching or saving a partial answer.
                        yield event
                        return
                    # Any other event type is dropped, as before.
            except Exception as e:
                logger.error("Chat stream failed", room_id=request.room_id, error=str(e))
                # NOTE(review): the payload format of handler-emitted error
                # events is not visible here — confirm clients accept plain text.
                yield {"event": "error", "data": f"Chat failed: {str(e)}"}

        return EventSourceResponse(stream_response())

    except Exception as e:
        logger.error("Chat failed", error=str(e))
        raise HTTPException(status_code=500, detail=f"Chat failed: {str(e)}") from e
|
|