| """Chat endpoint with streaming support.""" |
|
|
| import asyncio |
| import uuid |
| from fastapi import APIRouter, Depends, HTTPException |
| from sqlalchemy.ext.asyncio import AsyncSession |
| from src.db.postgres.connection import get_db |
| from src.db.postgres.models import ChatMessage, MessageSource, Room |
| from src.agents.orchestration import orchestrator |
| from src.agents.chatbot import chatbot |
| from src.rag.retriever import retriever |
| from src.db.redis.connection import get_redis |
| from src.config.settings import settings |
| from src.middlewares.logging import get_logger, log_execution |
| from sse_starlette.sse import EventSourceResponse |
| from langchain_core.messages import HumanMessage, AIMessage |
| from sqlalchemy import select |
| from pydantic import BaseModel |
| from typing import List, Dict, Any, Optional |
| import json |
|
|
| _GREETINGS = frozenset(["hi", "hello", "hey", "halo", "hai", "hei"]) |
| _THANKS = frozenset(["thanks", "thank you", "terima kasih", "makasih", "thx"]) |
| _GOODBYES = frozenset(["bye", "goodbye", "sampai jumpa", "dadah", "see you"]) |
|
|
|
|
def _fast_intent(message: str) -> Optional[dict]:
    """Answer obvious greetings, thanks and farewells without the LLM orchestrator.

    The message is lowercased, runs of whitespace are collapsed to single
    spaces, and trailing ``! . , ?`` (plus any whitespace they leave behind)
    are removed. The result is then looked up exactly in the phrase sets.

    Args:
        message: Raw user message.

    Returns:
        A dict shaped like the orchestrator's result (``intent``,
        ``needs_search``, ``direct_response``, ``search_query``) when the whole
        message is a known phrase. Otherwise ``None``, so the caller falls
        back to the LLM.
    """
    # Previously any stray space ("hi !", "terima  kasih") made the exact-match
    # lookup miss. Normalising whitespace first lets those hit the fast path too.
    normalized = " ".join(message.lower().split()).rstrip("!.,? ")
    canned_replies = (
        (_GREETINGS, "greeting", "Halo! Ada yang bisa saya bantu?"),
        (_THANKS, "thanks", "Sama-sama! Ada yang bisa saya bantu lagi?"),
        (_GOODBYES, "goodbye", "Sampai jumpa! Semoga harimu menyenangkan."),
    )
    for phrases, intent, reply in canned_replies:
        if normalized in phrases:
            return {"intent": intent, "needs_search": False,
                    "direct_response": reply, "search_query": ""}
    return None
|
|
# Module logger; also handed to @log_execution on each endpoint in this module.
logger = get_logger("chat_api")


# Every route below is mounted under /api/v1 and grouped under the "Chat" tag.
router = APIRouter(prefix="/api/v1", tags=["Chat"])
|
|
|
|
class ChatRequest(BaseModel):
    """Request body for POST /api/v1/chat/stream."""

    # Requesting user. Passed to the retriever (presumably to scope the search
    # to this user's documents) and used as owner when a new room is created.
    user_id: str
    # Conversation id. Keys chat history, message persistence and the response cache.
    room_id: str
    # The user's text for this turn.
    message: str
|
|
|
|
class ClearCacheRequest(BaseModel):
    """Request body for DELETE /api/v1/cache; at least one field must be provided."""

    # When set, the chat-response cache entries of this room are deleted.
    room_id: Optional[str] = None
    # When set, the retrieval cache entries of this user are deleted.
    user_id: Optional[str] = None
|
|
|
|
# Phrases that _sanitize_content replaces with "[content removed]" in retrieved
# document text before it is embedded in the LLM context. Matching is
# case-insensitive; keep entries lowercase.
# NOTE(review): a fixed phrase list is trivially evaded by paraphrasing. Treat it
# as a tripwire, not as the primary defence against prompt injection.
_INJECTION_PHRASES = [
    "ignore previous instructions",
    "ignore all prior",
    "disregard the above",
    "disregard previous",
    "you are now",
    "your new instructions are",
    "new system prompt",
    "override your instructions",
]
|
|
|
|
| def _sanitize_content(text: str) -> str: |
| """Escape XML metacharacters and neutralize prompt injection phrases. Pure string ops.""" |
| text = text.replace("&", "&").replace("<", "<").replace(">", ">") |
| lower = text.lower() |
| for phrase in _INJECTION_PHRASES: |
| idx = lower.find(phrase) |
| while idx != -1: |
| text = text[:idx] + "[content removed]" + text[idx + len(phrase):] |
| lower = text.lower() |
| idx = lower.find(phrase, idx + len("[content removed]")) |
| return text.strip() |
|
|
|
|
|
|
def _format_context(relevant_docs: List[Dict[str, Any]], fallback_docs: List[Dict[str, Any]]) -> str:
    """Format retrieval results as XML-delimited context for the LLM.

    Prepends a <context_status> tag so the system prompt can enforce the right behavior:
      - relevant: docs passed the similarity threshold -> answer from them
      - not_relevant: nothing passed the threshold but fallback docs exist -> suggest questions
      - no_documents: nothing retrieved at all -> ask the user to upload documents

    Args:
        relevant_docs: Results that passed the threshold. Each item has "content" and
            "metadata"; document fields live in metadata["data"] when present,
            otherwise directly in metadata.
        fallback_docs: Results rendered only when relevant_docs is empty.

    Returns:
        The status tag, followed by a <documents> element unless the status is no_documents.
    """
    import html

    def _render_docs(docs: List[Dict[str, Any]]) -> str:
        parts = []
        for i, result in enumerate(docs, start=1):
            metadata = result["metadata"]
            data = metadata.get("data", metadata)
            filename = data.get("filename", "Unknown")
            page = data.get("page_label")
            # Only a missing or empty label means "no page", so page 0 is still shown.
            source_label = f"{filename}, p.{page}" if page not in (None, "") else str(filename)
            # The label sits in a double-quoted attribute and the filename presumably
            # comes from the uploaded document. Escape quotes and &<> so a crafted name
            # cannot end the attribute and inject tags into the prompt.
            safe_label = html.escape(source_label, quote=True)
            body = _sanitize_content(result["content"])
            parts.append(
                f' <document index="{i}" source="{safe_label}">\n'
                f' {body}\n'
                f' </document>'
            )
        return "<documents>\n" + "\n".join(parts) + "\n</documents>"

    if relevant_docs:
        return "<context_status>relevant</context_status>\n" + _render_docs(relevant_docs)
    if fallback_docs:
        return "<context_status>not_relevant</context_status>\n" + _render_docs(fallback_docs)
    return "<context_status>no_documents</context_status>"
|
|
|
|
| def _extract_sources(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: |
| """Extract deduplicated source references from retrieval results.""" |
| seen = set() |
| sources = [] |
| for result in results: |
| data = result["metadata"].get("data", result["metadata"]) |
| key = (data.get("document_id"), data.get("page_label")) |
| if key not in seen: |
| seen.add(key) |
| sources.append({ |
| "document_id": data.get("document_id"), |
| "filename": data.get("filename", "Unknown"), |
| "page_label": data.get("page_label"), |
| }) |
| return sources |
|
|
|
|
async def get_cached_response(redis, cache_key: str) -> Optional[str]:
    """Fetch and JSON-decode the value stored at *cache_key*.

    Returns ``None`` when the key is missing or holds an empty value.
    """
    raw = await redis.get(cache_key)
    return json.loads(raw) if raw else None
|
|
|
|
async def cache_response(redis, cache_key: str, response: str, *, ttl_seconds: int = 86400):
    """Store *response* in Redis as JSON under *cache_key* with an expiry.

    Args:
        redis: Async Redis client exposing ``setex``.
        cache_key: Key to write.
        response: Value to cache. It is JSON-encoded, matching the
            ``json.loads`` in get_cached_response.
        ttl_seconds: Expiry in seconds. Defaults to 86400 (24 hours), the value
            that used to be hard-coded here.

    Raises:
        ValueError: If ``ttl_seconds`` is not positive. Redis rejects a
            non-positive SETEX expiry, so fail early with a clear message.
    """
    if ttl_seconds <= 0:
        raise ValueError(f"ttl_seconds must be positive, got {ttl_seconds}")
    await redis.setex(cache_key, ttl_seconds, json.dumps(response))
|
|
|
|
async def load_history(db: AsyncSession, room_id: str, limit: int = 10) -> list:
    """Load the most recent ``limit`` messages of a room as LangChain messages, oldest-first.

    Args:
        db: Async SQLAlchemy session.
        room_id: Room whose messages are loaded.
        limit: Maximum number of messages to return (the newest ones).

    Returns:
        HumanMessage for rows with role "user", AIMessage for every other role,
        in chronological order.
    """
    result = await db.execute(
        select(ChatMessage)
        .where(ChatMessage.room_id == room_id)
        # Sort newest first so LIMIT keeps the latest turns. The previous ascending
        # sort returned the room's OLDEST messages once it exceeded ``limit``.
        # NOTE(review): a user/assistant pair is saved in one commit. If created_at
        # defaults to the transaction time, both share a timestamp and their
        # relative order is undefined. Consider a tie-breaker column.
        .order_by(ChatMessage.created_at.desc())
        .limit(limit)
    )
    rows = result.scalars().all()
    # Reverse back to chronological order, which is what the LLM expects.
    return [
        HumanMessage(content=row.content) if row.role == "user" else AIMessage(content=row.content)
        for row in reversed(rows)
    ]
|
|
|
|
async def _ensure_room(db: AsyncSession, room_id: str, user_id: str) -> None:
    """Stage a new Room titled "New Chat" owned by *user_id* unless *room_id* already exists.

    The new room is only added to the session; committing is the caller's job.
    """
    lookup = await db.execute(select(Room).where(Room.id == room_id))
    existing_room = lookup.scalar_one_or_none()
    if existing_room is not None:
        # NOTE(review): the existing room's owner is not compared with user_id, and
        # check-then-insert can race when two first messages arrive concurrently.
        return
    db.add(Room(id=room_id, user_id=user_id, title="New Chat"))
|
|
|
|
async def save_messages(
    db: AsyncSession,
    room_id: str,
    user_id: str,
    user_content: str,
    assistant_content: str,
    audio_text: str = "",
    sources: Optional[List[Dict[str, Any]]] = None,
):
    """Store one user/assistant exchange in a single commit.

    Creates the room on first use. Adds the user message and the assistant reply
    (with its audio transcript), links each source reference to the assistant
    reply (page labels stored as strings, ``None`` kept as ``None``), then commits.
    """
    await _ensure_room(db, room_id, user_id)

    user_message_id = str(uuid.uuid4())
    reply_id = str(uuid.uuid4())
    pending = [
        ChatMessage(id=user_message_id, room_id=room_id, role="user", content=user_content),
        ChatMessage(
            id=reply_id,
            room_id=room_id,
            role="assistant",
            content=assistant_content,
            audio_text=audio_text,
        ),
    ]
    for ref in sources or []:
        label = ref.get("page_label")
        pending.append(
            MessageSource(
                id=str(uuid.uuid4()),
                message_id=reply_id,
                document_id=ref.get("document_id"),
                filename=ref.get("filename"),
                page_label=None if label is None else str(label),
            )
        )

    for entity in pending:
        db.add(entity)
    await db.commit()
|
|
|
|
async def _discard_task(task: "asyncio.Task[Any]") -> None:
    """Cancel *task* and wait until it has actually finished.

    The speculative retrieval task shares the request's AsyncSession. The session
    must not be reused (or closed) while that task may still be mid-query.
    """
    task.cancel()
    # gather(return_exceptions=True) absorbs the task's CancelledError or any other
    # exception it raised, so nothing is reported as "never retrieved".
    await asyncio.gather(task, return_exceptions=True)


async def _resolve_retrieval(
    request: ChatRequest,
    db: AsyncSession,
    intent_result: Dict[str, Any],
    retrieval_task: "asyncio.Task[Any]",
) -> "tuple[str, List[Dict[str, Any]]]":
    """Turn the speculative retrieval into ``(context, sources)`` per the orchestrator result.

    The speculative task searched with the raw message. Its result is reused when
    the orchestrator asks for exactly that query. Otherwise the task is cancelled
    and awaited, and a fresh search runs with the orchestrator's query. No
    retrieval is used when no search is needed or a direct reply will be sent.
    """
    if intent_result.get("direct_response") or not intent_result.get("needs_search"):
        await _discard_task(retrieval_task)
        relevant_docs: List[Dict[str, Any]] = []
        fallback_docs: List[Dict[str, Any]] = []
    else:
        # `or` rather than a .get() default, so an empty/None search_query falls
        # back to the message instead of searching for "".
        search_query = intent_result.get("search_query") or request.message
        logger.info(f"Searching for: {search_query}")
        if search_query == request.message:
            relevant_docs, fallback_docs = await retrieval_task
        else:
            await _discard_task(retrieval_task)
            relevant_docs, fallback_docs = await retriever.retrieve(
                query=search_query,
                user_id=request.user_id,
                db=db,
            )

    context = _format_context(relevant_docs, fallback_docs)
    # Log only the size: the full context can be large and contains document text.
    logger.info(f"assembled context ({len(context)} chars)")
    return context, _extract_sources(relevant_docs)


async def _persist_turn(
    db: AsyncSession,
    request: ChatRequest,
    assistant_content: str,
    *,
    audio_text: str = "",
    sources: Optional[List[Dict[str, Any]]] = None,
) -> None:
    """Save the user/assistant exchange after it has been streamed.

    This runs after the final SSE event, so the client already has the answer.
    A persistence failure is logged instead of being raised into the stream.
    """
    try:
        await save_messages(
            db,
            request.room_id,
            request.user_id,
            request.message,
            assistant_content,
            audio_text=audio_text,
            sources=sources or [],
        )
    except Exception as e:
        logger.error("Failed to save chat messages", error=str(e))


@router.post("/chat/stream")
@log_execution(logger)
async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
    """Chat endpoint with streaming response.

    SSE event sequence:
      1. sources    - JSON array of {document_id, filename, page_label}
      2. chunk      - text fragments of the answer (a direct reply is sent as a
                      single ``message`` event instead)
      3. audio_text - audio transcript of the answer (not sent for cached replies)
      4. done       - signals end of stream

    Every path persists the exchange after ``done``, including cache hits, which
    were previously shown but never saved.

    Raises:
        HTTPException: 500 if anything fails before streaming starts.
    """
    redis = await get_redis()

    # NOTE(review): the key ignores conversation history, so repeating a
    # context-dependent follow-up in the same room replays the earlier answer.
    cache_key = f"{settings.redis_prefix}chat:{request.room_id}:{request.message}"
    cached = await get_cached_response(redis, cache_key)
    if cached:
        logger.info("Returning cached response")

        async def stream_cached():
            yield {"event": "sources", "data": json.dumps([])}
            for i in range(0, len(cached), 50):
                yield {"event": "chunk", "data": cached[i:i + 50]}
            yield {"event": "done", "data": ""}
            await _persist_turn(db, request, cached)

        return EventSourceResponse(stream_cached())

    # NOTE(review): the generators below use `db` after this function returns.
    # This relies on the get_db session staying open while the response streams;
    # confirm against the FastAPI version and the get_db implementation.
    try:
        intent_result = _fast_intent(request.message)
        context = ""
        sources: List[Dict[str, Any]] = []
        history: Optional[list] = None

        if intent_result is None:
            # Load history BEFORE starting retrieval: both use the same AsyncSession,
            # which must not be used by concurrent tasks.
            history = await load_history(db, request.room_id, limit=10)
            # Retrieve speculatively with the raw message while the orchestrator,
            # which does not touch the DB, classifies it.
            retrieval_task = asyncio.create_task(
                retriever.retrieve(request.message, request.user_id, db)
            )
            try:
                # The orchestrator only needs the 6 most recent messages.
                intent_result = await orchestrator.analyze_message(request.message, history[-6:])
            except BaseException:
                # Don't leave the task running on the session after we bail out.
                await _discard_task(retrieval_task)
                raise
            context, sources = await _resolve_retrieval(request, db, intent_result, retrieval_task)

        if intent_result.get("direct_response"):
            response = intent_result["direct_response"]
            await cache_response(redis, cache_key, response)

            async def stream_direct():
                audio_text = await chatbot.generate_audio_text(response)
                yield {"event": "sources", "data": json.dumps([])}
                yield {"event": "message", "data": response}
                yield {"event": "audio_text", "data": audio_text}
                yield {"event": "done", "data": ""}
                await _persist_turn(db, request, response, audio_text=audio_text, sources=[])

            return EventSourceResponse(stream_direct())

        if history is None:
            history = await load_history(db, request.room_id, limit=10)
        messages = history + [HumanMessage(content=request.message)]

        async def stream_response():
            tokens: List[str] = []
            yield {"event": "sources", "data": json.dumps(sources)}
            async for token in chatbot.astream_response(messages, context):
                tokens.append(token)
                yield {"event": "chunk", "data": token}
            full_response = "".join(tokens)
            # The audio transcript and the cache write are independent; overlap them.
            audio_text_task = asyncio.create_task(chatbot.generate_audio_text(full_response))
            cache_task = asyncio.create_task(cache_response(redis, cache_key, full_response))
            audio_text = await audio_text_task
            yield {"event": "audio_text", "data": audio_text}
            yield {"event": "done", "data": ""}
            await cache_task
            await _persist_turn(db, request, full_response, audio_text=audio_text, sources=sources)

        return EventSourceResponse(stream_response())

    except Exception as e:
        logger.error("Chat failed", error=str(e))
        # NOTE(review): echoing str(e) to the client may leak internal details.
        raise HTTPException(status_code=500, detail=f"Chat failed: {str(e)}") from e
|
|
|
|
def _escape_redis_glob(value: str) -> str:
    """Backslash-escape the Redis glob metacharacters (* ? [ ] and backslash) in *value*."""
    return "".join("\\" + ch if ch in "*?[]\\" else ch for ch in value)


@router.delete("/cache")
@log_execution(logger)
async def clear_cache(request: ClearCacheRequest):
    """Clear Redis cache entries for a room and/or a user.

    - room_id only: delete the chat-response cache of that room
    - user_id only: delete the retrieval cache of that user
    - both: delete the room's chat cache and the user's retrieval cache
    - neither: rejected with 400; use DELETE /cache/all to clear everything

    The ids are glob-escaped before being used in KEYS patterns. An id such as
    "*" therefore matches only itself and can no longer wipe every room or user.

    Returns:
        ``{"deleted_keys": <count>, "room_id": ..., "user_id": ...}``.

    Raises:
        HTTPException: 400 when neither room_id nor user_id is provided.
    """
    if not request.room_id and not request.user_id:
        raise HTTPException(
            status_code=400,
            detail="Sediakan minimal salah satu: room_id atau user_id. Untuk clear semua cache gunakan endpoint DELETE /cache/all."
        )

    redis = await get_redis()
    deleted = 0

    patterns = []
    if request.room_id:
        # NOTE(review): room ids containing ':' can still prefix-match other rooms'
        # keys (room "a" vs "a:b"), because the message follows the id in the key.
        patterns.append(f"{settings.redis_prefix}chat:{_escape_redis_glob(request.room_id)}:*")
    if request.user_id:
        patterns.append(f"{settings.redis_prefix}retrieval:{_escape_redis_glob(request.user_id)}:*")

    for pattern in patterns:
        # NOTE(review): KEYS is O(N) over the whole keyspace and blocks Redis.
        # Consider SCAN-based iteration if the client supports it.
        keys = await redis.keys(pattern)
        if keys:
            deleted += await redis.delete(*keys)

    return {"deleted_keys": deleted, "room_id": request.room_id, "user_id": request.user_id}
|
|
|
|
@router.delete("/cache/all")
@log_execution(logger)
async def clear_all_cache():
    """Delete the application cache and the LangChain LLM cache from Redis.

    The application cache is every key starting with ``settings.redis_prefix``.
    The LangChain cache is every key starting with ``langchain:``.

    Returns:
        ``{"deleted_keys": <total number of keys removed>}``.
    """
    # NOTE(review): "langchain:*" is not namespaced by this service's prefix. On a
    # shared Redis it presumably also removes other services' LangChain entries.
    client = await get_redis()
    removed = 0
    # Application keys first, then LangChain's.
    for pattern in (f"{settings.redis_prefix}*", "langchain:*"):
        matches = await client.keys(pattern)
        if not matches:
            continue
        removed += await client.delete(*matches)
    return {"deleted_keys": removed}
|
|