ishaq101's picture
[KM-582][DED][AI] Fix Retrieval in Agentic Service
61c746f
"""Chat endpoint with streaming support."""
import uuid
import json
from typing import List, Dict, Any, Optional
from fastapi import APIRouter, Depends, HTTPException
from langchain_core.messages import HumanMessage, AIMessage
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sse_starlette.sse import EventSourceResponse
from src.agents.chat_handler import ChatHandler
from src.config.settings import settings
from src.db.postgres.connection import get_db
from src.db.postgres.models import ChatMessage, MessageSource
from src.db.redis.connection import get_redis
from src.middlewares.logging import get_logger, log_execution
logger = get_logger("chat_api")
router = APIRouter(prefix="/api/v1", tags=["Chat"])

# Exact-match vocabularies (English and Indonesian) used by _fast_intent,
# compared against the lower-cased, punctuation-stripped user message.
_GREETINGS = frozenset({"hi", "hello", "hey", "halo", "hai", "hei"})
_GOODBYES = frozenset(
    {"bye", "goodbye", "thanks", "thank you", "terima kasih", "sampai jumpa"}
)
def _fast_intent(message: str) -> Optional[str]:
    """Return a canned reply for an obvious greeting or farewell, else ``None``.

    The message is lower-cased, trimmed, and stripped of any trailing run of
    ``!.,?`` and whitespace. The result is then exact-matched against
    ``_GREETINGS`` / ``_GOODBYES``. So "Hello!", "hi !" and "thanks ." all
    match, but "hello there" does not and goes to the full chat pipeline.

    Args:
        message: Raw user message.

    Returns:
        The direct response string, or ``None`` when the message needs the LLM.
    """
    # Whitespace is part of the trailing strip set. Otherwise "hi !" would
    # normalize to "hi " and miss the exact-match lookup.
    normalized = message.lower().strip().rstrip("!.,? \t\r\n")
    if normalized in _GREETINGS:
        return "Hello! How can I assist you today?"
    if normalized in _GOODBYES:
        return "Goodbye! Have a great day!"
    return None
class ChatRequest(BaseModel):
    """Request body for ``POST /api/v1/chat/stream``."""

    # Caller identity; forwarded to ChatHandler.handle.
    user_id: str
    # Chat room: used for history lookup, message persistence and the Redis cache key.
    room_id: str
    # The user's message text.
    message: str
async def get_cached_response(redis, cache_key: str) -> Optional[dict]:
    """Fetch a cached chat answer and normalize it to ``{"response", "sources"}``.

    Supports the current JSON-object format written by ``cache_response``.
    Also supports legacy entries whose JSON payload is a bare string.

    Args:
        redis: Async Redis client (must provide ``get``).
        cache_key: Key to read.

    Returns:
        ``{"response": str, "sources": list}`` on a usable hit. ``None`` when
        the key is missing, the payload is not valid JSON, or it carries no
        string response. In those cases the caller regenerates the answer and
        overwrites the entry instead of failing the request.
    """
    cached = await redis.get(cache_key)
    if not cached:
        return None
    try:
        data = json.loads(cached)
    except (TypeError, ValueError):
        # Corrupt or non-JSON entry: treat it as a miss rather than raising.
        return None

    if isinstance(data, dict):
        response = data.get("response")
        sources = data.get("sources") or []
    else:
        # legacy: plain string cached before sources were stored alongside it
        response, sources = data, []

    if not isinstance(response, str):
        return None
    if not isinstance(sources, list):
        sources = []
    return {"response": response, "sources": sources}
async def cache_response(redis, cache_key: str, response: str, sources: list, *, ttl: int = 86400):
    """Store a chat answer and its sources in Redis as one JSON payload.

    The payload shape ``{"response": ..., "sources": [...]}`` is what
    ``get_cached_response`` reads back.

    Args:
        redis: Async Redis client (must provide ``setex``).
        cache_key: Key to write.
        response: Full assistant answer text.
        sources: JSON-serializable source references; ``None`` is stored as ``[]``.
        ttl: Expiry in seconds. Defaults to 86400 (24 hours).
    """
    payload = {"response": response, "sources": sources or []}
    await redis.setex(cache_key, ttl, json.dumps(payload))
async def load_history(db: AsyncSession, room_id: str, limit: int = 10) -> list:
    """Load the most recent chat messages for a room as LangChain messages.

    The newest ``limit`` rows are selected newest-first, so LIMIT keeps the
    latest turns. They are then reversed so the returned list is
    chronological (oldest-first).

    Args:
        db: Async SQLAlchemy session.
        room_id: Room whose history to load.
        limit: Maximum number of messages to return.

    Returns:
        ``HumanMessage`` for rows with role ``"user"``, ``AIMessage`` for any
        other role, ordered oldest to newest.
    """
    # NOTE(review): save_messages adds the user and assistant rows in the same
    # commit. If created_at is a DB-side transaction timestamp they may tie,
    # and their relative order is then undefined. TODO confirm schema default.
    result = await db.execute(
        select(ChatMessage)
        .where(ChatMessage.room_id == room_id)
        # DESC + LIMIT keeps the newest rows. ASC + LIMIT would return the
        # room's *first* messages, so long chats would lose recent context.
        .order_by(ChatMessage.created_at.desc())
        .limit(limit)
    )
    newest_first = result.scalars().all()
    return [
        HumanMessage(content=row.content) if row.role == "user" else AIMessage(content=row.content)
        for row in reversed(newest_first)
    ]
async def save_messages(
    db: AsyncSession,
    room_id: str,
    user_content: str,
    assistant_content: str,
    sources: Optional[List[Dict[str, Any]]] = None,
):
    """Persist one user/assistant exchange plus the assistant message's sources.

    Adds a user ``ChatMessage``, an assistant ``ChatMessage``, and one
    ``MessageSource`` per entry in ``sources``. Each source is linked to the
    assistant message by id. Everything is committed in a single commit.

    Args:
        db: Async SQLAlchemy session.
        room_id: Room the messages belong to.
        user_content: The user's message text.
        assistant_content: The assistant's full answer text.
        sources: Source dicts; ``document_id``, ``filename`` and ``page_label``
            are read from each (missing keys become ``None``).

    Raises:
        Exception: Whatever ``db.commit()`` raises. The session is rolled back
            first so the caller can keep using it.
    """
    db.add(ChatMessage(id=str(uuid.uuid4()), room_id=room_id, role="user", content=user_content))
    assistant_id = str(uuid.uuid4())
    db.add(ChatMessage(id=assistant_id, room_id=room_id, role="assistant", content=assistant_content))
    # "source" rather than "src", which would shadow the top-level `src` package name.
    for source in sources or []:
        page = source.get("page_label")
        db.add(MessageSource(
            id=str(uuid.uuid4()),
            message_id=assistant_id,
            document_id=source.get("document_id"),
            filename=source.get("filename"),
            # Page labels may arrive as ints; the column receives a string.
            page_label=str(page) if page is not None else None,
        ))
    try:
        await db.commit()
    except Exception:
        # A failed commit leaves the session unusable until rolled back. Some
        # callers log and continue, so clean up before propagating.
        await db.rollback()
        raise
@router.delete("/chat/cache")
async def clear_chat_cache(room_id: str, message: str):
    """Remove the cached answer for a single (room_id, message) pair.

    Returns whether a key was actually removed, together with the exact key
    that was targeted.
    """
    key = f"{settings.redis_prefix}chat:{room_id}:{message}"
    client = await get_redis()
    removed = await client.delete(key)
    return {"deleted": removed > 0, "cache_key": key}
# Characters with special meaning in Redis glob-style MATCH/KEYS patterns.
_REDIS_GLOB_SPECIALS = frozenset("*?[]\\")


def _escape_redis_glob(text: str) -> str:
    """Backslash-escape Redis glob metacharacters so ``text`` matches literally."""
    return "".join("\\" + ch if ch in _REDIS_GLOB_SPECIALS else ch for ch in text)


@router.delete("/chat/cache/room/{room_id}")
async def clear_room_cache(room_id: str):
    """Delete all Redis cache entries for a room.

    Targets keys of the form ``{redis_prefix}chat:{room_id}:<message>``, the
    format ``chat_stream`` writes.

    Returns:
        ``{"deleted_count": <number of keys removed>, "room_id": room_id}``.
    """
    redis = await get_redis()
    # Escape the literal prefix. room_id comes from the URL, and an
    # unescaped value such as "*" would match (and delete) every room's cache.
    pattern = _escape_redis_glob(f"{settings.redis_prefix}chat:{room_id}:") + "*"
    # SCAN walks the keyspace incrementally. KEYS does it in one call that
    # blocks the Redis server for the whole scan.
    keys = [key async for key in redis.scan_iter(match=pattern, count=500)]
    deleted = await redis.delete(*keys) if keys else 0
    return {"deleted_count": deleted, "room_id": room_id}
@router.delete("/retrieval/cache/{user_id}")
async def clear_retrieval_cache(user_id: str):
    """Drop every cached retrieval result belonging to ``user_id``.

    Meant to be called after new documents are uploaded/processed, so later
    retrievals are not served from stale cache entries.
    """
    # Imported at call time rather than module import; presumably to avoid an
    # import cycle or eager initialization (TODO confirm).
    from src.retrieval.router import retrieval_router

    removed_count = await retrieval_router.invalidate_cache(user_id)
    return {"deleted_count": removed_count, "user_id": user_id}
@router.post("/chat/stream")
@log_execution(logger)
async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
    """Chat endpoint with streaming response.

    SSE event sequence:
        1. sources — JSON array of source refs from ChatHandler (table for
           structured; deduped document_id/page_label for unstructured)
        2. chunk — text fragments of the answer
        3. done — signals end of stream

    An ``error`` event can end the stream instead of ``done``. It is either
    forwarded from ChatHandler or emitted here when the handler raises after
    streaming has started.

    Paths:
        * Redis cache hit on (room_id, message): replay the cached answer.
        * Greeting/farewell (``_fast_intent``): answer without the LLM.
        * Otherwise: stream ChatHandler output, then cache and persist it.

    Raises:
        HTTPException: 500 if anything fails before the stream is returned.
    """
    try:
        redis = await get_redis()
        # NOTE(review): the key is room + exact message text only (no history),
        # so repeating a question in a room replays the first cached answer.
        cache_key = f"{settings.redis_prefix}chat:{request.room_id}:{request.message}"

        # Redis cache hit
        cached = await get_cached_response(redis, cache_key)
        logger.info("cache check", cache_key=cache_key, cache_hit=cached is not None)
        if cached:
            logger.info("Returning cached response")
            cached_text = cached["response"]
            # .get: entries stored without a "sources" field must not KeyError.
            cached_sources = cached.get("sources") or []
            await save_messages(db, request.room_id, request.message, cached_text, sources=cached_sources)

            async def stream_cached():
                yield {"event": "sources", "data": json.dumps(cached_sources)}
                # Replay in 50-character chunks to mimic live streaming.
                for start in range(0, len(cached_text), 50):
                    yield {"event": "chunk", "data": cached_text[start:start + 50]}
                yield {"event": "done", "data": ""}

            return EventSourceResponse(stream_cached())

        # Fast intent: greetings/farewells bypass LLM entirely
        direct = _fast_intent(request.message)
        if direct:
            await cache_response(redis, cache_key, direct, sources=[])
            await save_messages(db, request.room_id, request.message, direct, sources=[])

            async def stream_direct():
                yield {"event": "sources", "data": json.dumps([])}
                yield {"event": "chunk", "data": direct}
                yield {"event": "done", "data": ""}

            return EventSourceResponse(stream_direct())

        history = await load_history(db, request.room_id, limit=10)
        handler = ChatHandler()

        async def stream_response():
            logger.info("stream_response started", room_id=request.room_id, user_id=request.user_id)
            full_response = ""
            sources: List[Dict[str, Any]] = []
            try:
                async for event in handler.handle(request.message, request.user_id, history):
                    kind = event["event"]
                    if kind == "sources":
                        try:
                            sources = json.loads(event["data"]) or []
                        except (TypeError, ValueError):
                            sources = []
                        yield event
                    elif kind == "chunk":
                        full_response += event["data"]
                        yield event
                    elif kind == "done":
                        # Don't pin an empty answer in the cache for 24h. A cache
                        # outage must not block persistence or the done event.
                        if full_response:
                            try:
                                await cache_response(redis, cache_key, full_response, sources=sources)
                            except Exception as e:
                                logger.error("cache_response failed", cache_key=cache_key, error=str(e))
                        logger.info("saving messages", sources_count=len(sources), sources=sources)
                        try:
                            await save_messages(db, request.room_id, request.message, full_response, sources=sources)
                        except Exception as e:
                            logger.error("save_messages failed", room_id=request.room_id, error=str(e))
                        yield event
                    elif kind == "error":
                        yield event
                        return
                    # "intent" event: consumed internally, not forwarded to frontend
            except Exception as e:
                # This runs after EventSourceResponse was returned, so the outer
                # try/except below cannot see it. Report in-band instead of
                # silently truncating the stream. (Client disconnects raise
                # CancelledError, a BaseException, and are not caught here.)
                # NOTE(review): align this payload with ChatHandler's own
                # "error" event format if the frontend parses it.
                logger.error("stream_response failed", room_id=request.room_id, error=str(e))
                yield {"event": "error", "data": "An error occurred while generating the response."}

        return EventSourceResponse(stream_response())
    except Exception as e:
        logger.error("Chat failed", error=str(e))
        raise HTTPException(status_code=500, detail=f"Chat failed: {str(e)}") from e