"""Chat endpoint with streaming support."""

import asyncio
import html
import json
import re
import uuid
from typing import Any, Dict, List, Optional, Tuple

from fastapi import APIRouter, Depends, HTTPException
from langchain_core.messages import AIMessage, HumanMessage
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sse_starlette.sse import EventSourceResponse

from src.db.postgres.connection import get_db
from src.db.postgres.models import ChatMessage, MessageSource, Room
from src.agents.orchestration import orchestrator
from src.agents.chatbot import chatbot
from src.rag.retriever import retriever
from src.db.redis.connection import get_redis
from src.config.settings import settings
from src.middlewares.logging import get_logger, log_execution

logger = get_logger("chat_api")
router = APIRouter(prefix="/api/v1", tags=["Chat"])

# How long a cached chat response stays in Redis, in seconds (24 hours).
_CACHE_TTL_SECONDS = 86400

# Translation table that backslash-escapes Redis glob metacharacters, so IDs
# interpolated into SCAN match patterns only match themselves.
_GLOB_ESCAPES = str.maketrans({ch: "\\" + ch for ch in "\\*?[]"})


# ---------------------------------------------------------------------------
# Fast-path intent detection (no LLM call)
# ---------------------------------------------------------------------------

_GREETINGS = frozenset(["hi", "hello", "hey", "halo", "hai", "hei"])
_THANKS = frozenset(["thanks", "thank you", "terima kasih", "makasih", "thx"])
_GOODBYES = frozenset(["bye", "goodbye", "sampai jumpa", "dadah", "see you"])


def _fast_intent(message: str) -> Optional[dict]:
    """Return a canned intent for a bare greeting, thanks, or farewell, else None.

    The message is lowercased, stripped, and trailing ``!.,?`` characters are
    removed, then matched *exactly* against the word sets above. Only messages
    that consist solely of such a phrase bypass the LLM orchestrator.

    Returns:
        A dict shaped like the orchestrator's result (``intent``, ``needs_search``,
        ``direct_response``, ``search_query``), or None when no fast path applies.
    """
    lower = message.lower().strip().rstrip("!.,?")
    if lower in _GREETINGS:
        return {"intent": "greeting", "needs_search": False,
                "direct_response": "Halo! Ada yang bisa saya bantu?", "search_query": ""}
    if lower in _THANKS:
        return {"intent": "thanks", "needs_search": False,
                "direct_response": "Sama-sama! Ada yang bisa saya bantu lagi?", "search_query": ""}
    if lower in _GOODBYES:
        return {"intent": "goodbye", "needs_search": False,
                "direct_response": "Sampai jumpa! Semoga harimu menyenangkan.", "search_query": ""}
    return None


# ---------------------------------------------------------------------------
# Request models
# ---------------------------------------------------------------------------

class ChatRequest(BaseModel):
    user_id: str
    room_id: str
    message: str


class ClearCacheRequest(BaseModel):
    room_id: Optional[str] = None
    user_id: Optional[str] = None


# ---------------------------------------------------------------------------
# Retrieval context formatting
# ---------------------------------------------------------------------------

_INJECTION_PHRASES = [
    "ignore previous instructions",
    "ignore all prior",
    "disregard the above",
    "disregard previous",
    "you are now",
    "your new instructions are",
    "new system prompt",
    "override your instructions",
]

# One case-insensitive alternation over all phrases. Matching directly on the
# original text avoids slicing it with indices computed on ``text.lower()``,
# which can differ in length from ``text`` for some Unicode characters.
_INJECTION_RE = re.compile(
    "|".join(re.escape(phrase) for phrase in _INJECTION_PHRASES), re.IGNORECASE
)


def _sanitize_content(text: str) -> str:
    """Escape XML metacharacters and neutralize prompt-injection phrases.

    ``&``, ``<`` and ``>`` become ``&amp;``, ``&lt;`` and ``&gt;``, so document
    text cannot open or close tags in the surrounding context. Every
    case-insensitive occurrence of a phrase in ``_INJECTION_PHRASES`` is replaced
    with ``[content removed]``. The result is stripped of surrounding whitespace.
    """
    escaped = html.escape(text, quote=False)
    return _INJECTION_RE.sub("[content removed]", escaped).strip()


def _format_context(relevant_docs: List[Dict[str, Any]], fallback_docs: List[Dict[str, Any]]) -> str:
    """Format retrieval results as XML-delimited context for the LLM.

    Injects a ``<status>`` marker so the system prompt can enforce the correct behavior:
      - relevant: docs passed the similarity threshold → answer from them
      - not_relevant: no docs passed threshold but fallback docs exist → suggest questions
      - no_documents: nothing retrieved at all → ask user to upload docs

    NOTE(review): the tag names (<status>, <documents>, <document index source>)
    must match what the chatbot's system prompt expects — confirm against it.
    """

    def _render_docs(docs: List[Dict[str, Any]]) -> str:
        parts = []
        for i, result in enumerate(docs, start=1):
            data = result["metadata"].get("data", result["metadata"])
            filename = data.get("filename", "Unknown")
            page = data.get("page_label")
            source_label = f"{filename}, p.{page}" if page else filename
            # Attribute context: quotes must be escaped too, not just <, >, &.
            source_attr = html.escape(str(source_label), quote=True)
            sanitized = _sanitize_content(result["content"])
            parts.append(
                f'  <document index="{i}" source="{source_attr}">\n'
                f"    {sanitized}\n"
                f"  </document>"
            )
        return "<documents>\n" + "\n".join(parts) + "\n</documents>"

    if relevant_docs:
        return "<status>relevant</status>\n" + _render_docs(relevant_docs)
    if fallback_docs:
        return "<status>not_relevant</status>\n" + _render_docs(fallback_docs)
    return "<status>no_documents</status>"


def _extract_sources(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Extract source references from retrieval results, deduplicated by (document_id, page_label).

    First occurrence wins and input order is preserved.
    """
    seen = set()
    sources = []
    for result in results:
        data = result["metadata"].get("data", result["metadata"])
        key = (data.get("document_id"), data.get("page_label"))
        if key not in seen:
            seen.add(key)
            sources.append({
                "document_id": data.get("document_id"),
                "filename": data.get("filename", "Unknown"),
                "page_label": data.get("page_label"),
            })
    return sources


# ---------------------------------------------------------------------------
# Redis response cache
# ---------------------------------------------------------------------------

async def get_cached_response(redis, cache_key: str) -> Optional[str]:
    """Return the JSON-decoded value stored under *cache_key*, or None if absent/empty."""
    cached = await redis.get(cache_key)
    if cached:
        return json.loads(cached)
    return None


async def cache_response(redis, cache_key: str, response: str, ttl_seconds: int = _CACHE_TTL_SECONDS):
    """Store *response* JSON-encoded under *cache_key* with an expiry of *ttl_seconds*."""
    await redis.setex(cache_key, ttl_seconds, json.dumps(response))


# ---------------------------------------------------------------------------
# Persistence
# ---------------------------------------------------------------------------

async def load_history(db: AsyncSession, room_id: str, limit: int = 10) -> list:
    """Load the *limit* most recent chat messages of a room, returned oldest-first.

    Rows with role ``"user"`` become ``HumanMessage``; every other role becomes
    ``AIMessage``.

    NOTE(review): if user and assistant rows of one exchange receive identical
    ``created_at`` values (e.g. a DB-side transaction timestamp), their relative
    order is not determined by this query — confirm how created_at is populated.
    """
    result = await db.execute(
        select(ChatMessage)
        .where(ChatMessage.room_id == room_id)
        # Newest first so LIMIT keeps the latest messages...
        .order_by(ChatMessage.created_at.desc())
        .limit(limit)
    )
    # ...then flip back to chronological order for the LLM.
    rows = list(reversed(result.scalars().all()))
    return [
        HumanMessage(content=row.content) if row.role == "user" else AIMessage(content=row.content)
        for row in rows
    ]


async def _ensure_room(db: AsyncSession, room_id: str, user_id: str) -> None:
    """Add a ``Room`` (title "New Chat") to the session if no room with *room_id* exists.

    An existing room is not checked against *user_id*. The caller commits.
    """
    result = await db.execute(select(Room).where(Room.id == room_id))
    if result.scalar_one_or_none() is None:
        db.add(Room(id=room_id, user_id=user_id, title="New Chat"))


async def save_messages(
    db: AsyncSession,
    room_id: str,
    user_id: str,
    user_content: str,
    assistant_content: str,
    audio_text: str = "",
    sources: Optional[List[Dict[str, Any]]] = None,
):
    """Persist one user/assistant exchange and its sources in a single commit.

    Ensures the room exists, adds the user message and the assistant message
    (with *audio_text*), and one ``MessageSource`` row per entry in *sources*
    linked to the assistant message. ``page_label`` is stored as a string, or None.
    """
    await _ensure_room(db, room_id, user_id)
    db.add(ChatMessage(id=str(uuid.uuid4()), room_id=room_id, role="user", content=user_content))
    assistant_id = str(uuid.uuid4())
    db.add(ChatMessage(id=assistant_id, room_id=room_id, role="assistant",
                       content=assistant_content, audio_text=audio_text))
    for src in (sources or []):
        page = src.get("page_label")
        db.add(MessageSource(
            id=str(uuid.uuid4()),
            message_id=assistant_id,
            document_id=src.get("document_id"),
            filename=src.get("filename"),
            page_label=str(page) if page is not None else None,
        ))
    await db.commit()


# ---------------------------------------------------------------------------
# Chat endpoint
# ---------------------------------------------------------------------------

async def _discard_task(task: asyncio.Task) -> None:
    """Cancel *task* (no-op if already done) and wait until it has fully finished.

    Waiting matters because the task may be using the request's AsyncSession,
    which must be idle before anything else uses it. ``return_exceptions=True``
    also consumes the task's outcome (result, error, or cancellation), so nothing
    is reported as "never retrieved".
    """
    task.cancel()
    await asyncio.gather(task, return_exceptions=True)


async def _analyze_and_retrieve(
    request: ChatRequest, db: AsyncSession
) -> Tuple[dict, str, List[Dict[str, Any]]]:
    """Run the LLM orchestrator and, when it asks for a search, document retrieval.

    Returns:
        ``(intent_result, context, sources)`` where *context* is the output of
        ``_format_context`` and *sources* are extracted from the relevant docs only.
    """
    # History is loaded BEFORE retrieval starts: both use the same AsyncSession,
    # and a session does not support concurrent operations from multiple tasks.
    history = await load_history(db, request.room_id, limit=6)  # 6 msgs (3 pairs) for orchestrator

    # Speculatively retrieve with the raw message while the orchestrator runs.
    # The orchestrator is not passed `db`, so this overlap does not share the session.
    retrieval_task = asyncio.create_task(
        retriever.retrieve(request.message, request.user_id, db)
    )
    try:
        intent_result = await orchestrator.analyze_message(request.message, history)

        if not intent_result.get("needs_search"):
            await _discard_task(retrieval_task)
            relevant_docs, fallback_docs = [], []
        else:
            # `or` (not a .get default) so an empty search_query also falls back.
            search_query = intent_result.get("search_query") or request.message
            logger.info(f"Searching for: {search_query}")
            if search_query != request.message:
                # The speculative result is for the wrong query; make sure it has
                # released the session before retrieving again.
                await _discard_task(retrieval_task)
                relevant_docs, fallback_docs = await retriever.retrieve(
                    query=search_query,
                    user_id=request.user_id,
                    db=db,
                )
            else:
                relevant_docs, fallback_docs = await retrieval_task
    except BaseException:
        # Never leave the speculative task running against the session.
        await _discard_task(retrieval_task)
        raise

    context = _format_context(relevant_docs, fallback_docs)
    # Log the size only: the context contains full document text.
    logger.info(f"assembled context ({len(context)} chars)")
    return intent_result, context, _extract_sources(relevant_docs)


@router.post("/chat/stream")
@log_execution(logger)
async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
    """Chat endpoint with streaming response.

    SSE event sequence for an LLM-generated answer:
      1. sources    — JSON array of {document_id, filename, page_label}
      2. chunk      — text fragments of the answer, in generation order
      3. audio_text — chatbot.generate_audio_text() applied to the full answer
      4. done       — signals end of stream

    Direct answers (fast-path or orchestrator ``direct_response``) send one
    ``message`` event in place of ``chunk`` events. Cache hits send ``sources``
    (always an empty list), ``chunk`` events of up to 50 characters, then ``done``
    (no ``audio_text``).

    NOTE(review): the generators below use `db` (save_messages) after this
    function has returned; whether get_db's session is still open at that point
    depends on get_db and the FastAPI version in use — confirm.
    """
    redis = await get_redis()
    # The key depends only on room + message text, not on conversation history,
    # so repeating a message in the same room replays the earlier answer.
    cache_key = f"{settings.redis_prefix}chat:{request.room_id}:{request.message}"

    cached = await get_cached_response(redis, cache_key)
    if cached:
        logger.info("Returning cached response")

        async def stream_cached():
            yield {"event": "sources", "data": json.dumps([])}
            for start in range(0, len(cached), 50):
                yield {"event": "chunk", "data": cached[start:start + 50]}
            yield {"event": "done", "data": ""}
            # Persist like the other paths so later turns see this exchange in history.
            await save_messages(db, request.room_id, request.user_id, request.message,
                                cached, audio_text="", sources=[])

        return EventSourceResponse(stream_cached())

    # Only the setup below is guarded; exceptions raised while a generator is
    # being streamed happen after this function returned and are not caught here.
    try:
        # Step 1: Fast local intent check (skips LLM for greetings/farewells)
        intent_result = _fast_intent(request.message)
        context = ""
        sources: List[Dict[str, Any]] = []

        if intent_result is None:
            # Step 2: orchestrator + (optional) retrieval
            intent_result, context, sources = await _analyze_and_retrieve(request, db)

        # Step 3: Direct response for greetings / non-document intents
        if intent_result.get("direct_response"):
            response = intent_result["direct_response"]
            await cache_response(redis, cache_key, response)

            async def stream_direct():
                audio_text = await chatbot.generate_audio_text(response)
                yield {"event": "sources", "data": json.dumps([])}
                yield {"event": "message", "data": response}
                yield {"event": "audio_text", "data": audio_text}
                yield {"event": "done", "data": ""}
                await save_messages(db, request.room_id, request.user_id, request.message,
                                    response, audio_text=audio_text, sources=[])

            return EventSourceResponse(stream_direct())

        # Step 4: Stream answer token-by-token as LLM generates it.
        # Full history (10 msgs) for the chatbot — richer than the 6 used by the orchestrator.
        full_history = await load_history(db, request.room_id, limit=10)
        messages = full_history + [HumanMessage(content=request.message)]

        async def stream_response():
            full_response = ""
            yield {"event": "sources", "data": json.dumps(sources)}

            async for token in chatbot.astream_response(messages, context):
                full_response += token
                yield {"event": "chunk", "data": token}

            # Audio-text generation and the cache write run concurrently once streaming completes.
            audio_text_task = asyncio.create_task(chatbot.generate_audio_text(full_response))
            cache_task = asyncio.create_task(cache_response(redis, cache_key, full_response))

            audio_text = await audio_text_task
            yield {"event": "audio_text", "data": audio_text}
            yield {"event": "done", "data": ""}

            await cache_task
            await save_messages(db, request.room_id, request.user_id, request.message,
                                full_response, audio_text=audio_text, sources=sources)

        return EventSourceResponse(stream_response())

    except Exception as e:
        logger.error("Chat failed", error=str(e))
        # NOTE(review): the detail exposes the internal error text to the client.
        raise HTTPException(status_code=500, detail=f"Chat failed: {str(e)}") from e


# ---------------------------------------------------------------------------
# Cache maintenance endpoints
# ---------------------------------------------------------------------------

def _escape_glob(value: str) -> str:
    """Backslash-escape Redis glob metacharacters (``\\ * ? [ ]``) in *value*."""
    return value.translate(_GLOB_ESCAPES)


async def _delete_matching(redis, pattern: str, batch_size: int = 500) -> int:
    """Delete every key matching *pattern*; return the number of keys deleted.

    Uses incremental SCAN instead of KEYS, so Redis is not blocked while walking
    a large keyspace. Keys are deleted in batches of *batch_size*.
    NOTE(review): assumes get_redis() returns a redis-py asyncio client
    (``scan_iter``) — confirm.
    """
    deleted = 0
    batch: list = []
    async for key in redis.scan_iter(match=pattern, count=batch_size):
        batch.append(key)
        if len(batch) >= batch_size:
            deleted += await redis.delete(*batch)
            batch.clear()
    if batch:
        deleted += await redis.delete(*batch)
    return deleted


@router.delete("/cache")
@log_execution(logger)
async def clear_cache(request: ClearCacheRequest):
    """Clear selected Redis cache entries.

    - room_id only: delete the chat-response cache of that room
    - user_id only: delete the retrieval cache of that user
    - both: delete the room's chat cache and the user's retrieval cache
    - neither: rejected with HTTP 400; use ``DELETE /cache/all`` to clear everything

    IDs are glob-escaped, so e.g. ``room_id="*"`` only matches a room literally named "*".
    """
    if not request.room_id and not request.user_id:
        raise HTTPException(
            status_code=400,
            detail="Sediakan minimal salah satu: room_id atau user_id. Untuk clear semua cache gunakan endpoint DELETE /cache/all."
        )

    redis = await get_redis()
    prefix = _escape_glob(settings.redis_prefix)
    deleted = 0

    if request.room_id:
        deleted += await _delete_matching(redis, f"{prefix}chat:{_escape_glob(request.room_id)}:*")

    if request.user_id:
        deleted += await _delete_matching(redis, f"{prefix}retrieval:{_escape_glob(request.user_id)}:*")

    return {"deleted_keys": deleted, "room_id": request.room_id, "user_id": request.user_id}


@router.delete("/cache/all")
@log_execution(logger)
async def clear_all_cache():
    """Delete all Redis cache: app cache (``settings.redis_prefix*``) and LangChain LLM cache (``langchain:*``)."""
    redis = await get_redis()

    # App-level cache (chat responses + retrieval results)
    deleted = await _delete_matching(redis, f"{_escape_glob(settings.redis_prefix)}*")

    # LangChain LLM response cache
    deleted += await _delete_matching(redis, "langchain:*")

    return {"deleted_keys": deleted}