# Page residue from the hosting site: author ishaq101, commit 74d7562 ("update fast intent").
"""Chat endpoint with streaming support."""
import asyncio
import json
import re
import uuid
from typing import List, Dict, Any, Optional

from fastapi import APIRouter, Depends, HTTPException
from langchain_core.messages import HumanMessage, AIMessage
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sse_starlette.sse import EventSourceResponse

from src.db.postgres.connection import get_db
from src.db.postgres.models import ChatMessage, MessageSource, Room
from src.agents.orchestration import orchestrator
from src.agents.chatbot import chatbot
from src.rag.retriever import retriever
from src.db.redis.connection import get_redis
from src.config.settings import settings
from src.middlewares.logging import get_logger, log_execution
_GREETINGS = frozenset(["hi", "hello", "hey", "halo", "hai", "hei"])
_THANKS = frozenset(["thanks", "thank you", "terima kasih", "makasih", "thx"])
_GOODBYES = frozenset(["bye", "goodbye", "sampai jumpa", "dadah", "see you"])

# (phrase set, intent label, canned reply) — evaluated in order by _fast_intent.
_FAST_INTENT_RULES = (
    (_GREETINGS, "greeting", "Halo! Ada yang bisa saya bantu?"),
    (_THANKS, "thanks", "Sama-sama! Ada yang bisa saya bantu lagi?"),
    (_GOODBYES, "goodbye", "Sampai jumpa! Semoga harimu menyenangkan."),
)


def _fast_intent(message: str) -> Optional[dict]:
    """Bypass the LLM orchestrator for obvious greetings, thanks, and farewells.

    The message is lower-cased, runs of whitespace are collapsed to a single
    space, and trailing punctuation/whitespace is dropped before it is compared
    against the known phrase sets.

    Args:
        message: Raw user message.

    Returns:
        A fresh intent dict with the keys ``intent``, ``needs_search``,
        ``direct_response`` and ``search_query`` when the whole message is one
        of the known phrases, otherwise ``None``.
    """
    # Collapse whitespace first so "thank   you" still matches, then strip
    # punctuation *and* spaces together so "hello !" does not leave "hello ".
    normalized = " ".join(message.lower().split()).rstrip("!.,? ")
    for phrases, intent, reply in _FAST_INTENT_RULES:
        if normalized in phrases:
            return {
                "intent": intent,
                "needs_search": False,
                "direct_response": reply,
                "search_query": "",
            }
    return None
# Module-level logger, obtained under the name "chat_api".
logger = get_logger("chat_api")
# Router for this module: every route below is served under /api/v1 and tagged "Chat".
router = APIRouter(prefix="/api/v1", tags=["Chat"])
class ChatRequest(BaseModel):
    """Request body for the chat streaming endpoint."""

    user_id: str  # passed to retrieval and stored on the room when it is created
    room_id: str  # chat room whose history is loaded and to which messages are saved
    message: str  # raw user message text
class ClearCacheRequest(BaseModel):
    """Request body for the cache-clearing endpoint; both fields are optional."""

    room_id: Optional[str] = None  # when set, cached chat responses for this room are deleted
    user_id: Optional[str] = None  # when set, cached retrieval results for this user are deleted
_INJECTION_PHRASES = [
    "ignore previous instructions",
    "ignore all prior",
    "disregard the above",
    "disregard previous",
    "you are now",
    "your new instructions are",
    "new system prompt",
    "override your instructions",
]

# One case-insensitive alternation over all phrases, compiled once at import.
_INJECTION_PATTERN = re.compile(
    "|".join(re.escape(phrase) for phrase in _INJECTION_PHRASES),
    re.IGNORECASE,
)
_INJECTION_PLACEHOLDER = "[content removed]"


def _sanitize_content(text: str) -> str:
    """Escape XML metacharacters and neutralize known prompt-injection phrases.

    Args:
        text: Untrusted document text.

    Returns:
        The text with ``&``, ``<`` and ``>`` replaced by XML entities, every
        case-insensitive occurrence of a phrase in ``_INJECTION_PHRASES``
        replaced by ``[content removed]``, and surrounding whitespace stripped.
    """
    # "&" must be escaped first so the entities produced below are not re-escaped.
    escaped = text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
    # Match directly on the text rather than on a lower-cased copy: str.lower()
    # can change the string length (e.g. U+0130), which would make offsets
    # found in the copy point at the wrong characters of the original.
    return _INJECTION_PATTERN.sub(_INJECTION_PLACEHOLDER, escaped).strip()
def _format_context(relevant_docs: List[Dict[str, Any]], fallback_docs: List[Dict[str, Any]]) -> str:
    """Format retrieval results as XML-delimited context for the LLM.

    A ``<context_status>`` element is emitted first so the system prompt can
    enforce the matching behavior:

    - ``relevant``: ``relevant_docs`` is non-empty; those documents are rendered.
    - ``not_relevant``: only ``fallback_docs`` is non-empty; those are rendered.
    - ``no_documents``: both lists are empty; no ``<documents>`` element is added.

    Args:
        relevant_docs: Retrieval results that should be used to answer.
        fallback_docs: Retrieval results used only when ``relevant_docs`` is empty.
            Each result is a dict with ``"content"`` and ``"metadata"`` keys;
            document fields are read from ``metadata["data"]`` when present,
            otherwise from ``metadata`` itself.

    Returns:
        The context string to hand to the LLM.
    """

    def _escape_attr(value: Any) -> str:
        # The label ends up inside a double-quoted XML attribute, so it needs
        # the same treatment as document text plus quote escaping; otherwise a
        # crafted filename could close the attribute and inject markup.
        return _sanitize_content(str(value)).replace('"', "&quot;")

    def _render_docs(docs: List[Dict[str, Any]]) -> str:
        parts = []
        for i, result in enumerate(docs, start=1):
            data = result["metadata"].get("data", result["metadata"])
            filename = data.get("filename", "Unknown")
            page = data.get("page_label")
            source_label = f"{filename}, p.{page}" if page else filename
            sanitized = _sanitize_content(result["content"])
            parts.append(
                f' <document index="{i}" source="{_escape_attr(source_label)}">\n'
                f' {sanitized}\n'
                f' </document>'
            )
        return "<documents>\n" + "\n".join(parts) + "\n</documents>"

    if relevant_docs:
        return "<context_status>relevant</context_status>\n" + _render_docs(relevant_docs)
    if fallback_docs:
        return "<context_status>not_relevant</context_status>\n" + _render_docs(fallback_docs)
    return "<context_status>no_documents</context_status>"
def _extract_sources(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Collect one source reference per distinct (document_id, page_label) pair.

    Document fields are read from ``metadata["data"]`` when that key exists,
    otherwise from ``metadata`` directly. The first occurrence of each pair
    wins and the original ordering is kept.
    """
    unique: Dict[Any, Dict[str, Any]] = {}
    for item in results:
        meta = item["metadata"]
        fields = meta.get("data", meta)
        doc_id = fields.get("document_id")
        page_label = fields.get("page_label")
        # setdefault keeps the entry already stored for a repeated pair.
        unique.setdefault(
            (doc_id, page_label),
            {
                "document_id": doc_id,
                "filename": fields.get("filename", "Unknown"),
                "page_label": page_label,
            },
        )
    return list(unique.values())
async def get_cached_response(redis, cache_key: str) -> Optional[str]:
    """Fetch ``cache_key`` from Redis and JSON-decode it.

    Returns ``None`` when the key is absent or holds an empty value.
    """
    raw = await redis.get(cache_key)
    return json.loads(raw) if raw else None
async def cache_response(redis, cache_key: str, response: str, ttl_seconds: int = 86400) -> None:
    """Store ``response`` in Redis as JSON under ``cache_key`` with an expiry.

    Args:
        redis: Async Redis client exposing ``setex``.
        cache_key: Key to write.
        response: Value to cache; it is JSON-encoded before being stored.
        ttl_seconds: Expiry in seconds. Defaults to 86400 (24 hours), the value
            that was previously hard-coded.
    """
    await redis.setex(cache_key, ttl_seconds, json.dumps(response))
async def load_history(db: AsyncSession, room_id: str, limit: int = 10) -> list:
    """Load the most recent chat messages of a room as LangChain messages.

    Args:
        db: Async SQLAlchemy session.
        room_id: Room whose messages are loaded.
        limit: Maximum number of messages to return.

    Returns:
        Up to ``limit`` messages, oldest first. Rows whose role is ``"user"``
        become ``HumanMessage``; every other role becomes ``AIMessage``.
    """
    # LIMIT is applied after ORDER BY, so ordering ascending would return the
    # *oldest* rows of the room. Take the newest rows instead and flip them
    # back into chronological order below.
    result = await db.execute(
        select(ChatMessage)
        .where(ChatMessage.room_id == room_id)
        .order_by(ChatMessage.created_at.desc())
        .limit(limit)
    )
    newest_first = result.scalars().all()
    return [
        HumanMessage(content=row.content) if row.role == "user" else AIMessage(content=row.content)
        for row in reversed(newest_first)
    ]
async def _ensure_room(db: AsyncSession, room_id: str, user_id: str) -> None:
    """Add a "New Chat" room for ``room_id`` to the session unless one already exists.

    The new room is only added to the session; committing is left to the caller.
    """
    lookup = select(Room).where(Room.id == room_id)
    existing = (await db.execute(lookup)).scalar_one_or_none()
    if existing is not None:
        return
    db.add(Room(id=room_id, user_id=user_id, title="New Chat"))
async def save_messages(
    db: AsyncSession,
    room_id: str,
    user_id: str,
    user_content: str,
    assistant_content: str,
    audio_text: str = "",
    sources: Optional[List[Dict[str, Any]]] = None,
):
    """Store one user/assistant exchange and the sources cited by the answer.

    The room is created first when it does not exist. The user message, the
    assistant message (carrying ``audio_text``) and one ``MessageSource`` row
    per entry in ``sources`` are then added and committed together.
    """
    await _ensure_room(db, room_id, user_id)

    user_message = ChatMessage(
        id=str(uuid.uuid4()),
        room_id=room_id,
        role="user",
        content=user_content,
    )
    db.add(user_message)

    assistant_id = str(uuid.uuid4())
    assistant_message = ChatMessage(
        id=assistant_id,
        room_id=room_id,
        role="assistant",
        content=assistant_content,
        audio_text=audio_text,
    )
    db.add(assistant_message)

    for source in sources or ():
        raw_page = source.get("page_label")
        # page_label is stored as text; a missing page stays NULL.
        page_label = None if raw_page is None else str(raw_page)
        db.add(
            MessageSource(
                id=str(uuid.uuid4()),
                message_id=assistant_id,
                document_id=source.get("document_id"),
                filename=source.get("filename"),
                page_label=page_label,
            )
        )

    await db.commit()
async def _discard_task(task: "asyncio.Task") -> None:
    """Cancel ``task`` and wait until it has actually stopped running."""
    task.cancel()
    # gather(return_exceptions=True) absorbs the task's own CancelledError (or
    # an error it already raised) but still lets cancellation of the *caller*
    # propagate, unlike a bare ``except CancelledError: pass``.
    await asyncio.gather(task, return_exceptions=True)


@router.post("/chat/stream")
@log_execution(logger)
async def chat_stream(request: ChatRequest, db: AsyncSession = Depends(get_db)):
    """Chat endpoint with a Server-Sent Events streaming response.

    SSE event sequence:
        1. ``sources``    — JSON array of {document_id, filename, page_label}
           (always ``[]`` for cached and direct responses)
        2. ``chunk``      — text fragments of the answer; direct responses send
           a single ``message`` event with the whole text instead
        3. ``audio_text`` — audio-friendly rendering of the answer
           (not sent for cached responses)
        4. ``done``       — signals end of stream

    Raises:
        HTTPException: 500 when intent analysis, retrieval or history loading
            fails before the stream is created.
    """
    redis = await get_redis()
    cache_key = f"{settings.redis_prefix}chat:{request.room_id}:{request.message}"

    cached = await get_cached_response(redis, cache_key)
    if cached:
        logger.info("Returning cached response")

        async def stream_cached():
            yield {"event": "sources", "data": json.dumps([])}
            # Replay the cached answer in 50-character slices.
            for i in range(0, len(cached), 50):
                yield {"event": "chunk", "data": cached[i:i + 50]}
            yield {"event": "done", "data": ""}

        return EventSourceResponse(stream_cached())

    retrieval_task = None
    try:
        # Step 1: Fast local intent check (skips LLM for greetings/farewells)
        intent_result = _fast_intent(request.message)
        context = ""
        sources: List[Dict[str, Any]] = []

        if intent_result is None:
            # Step 2: Load the short history first, then overlap retrieval with
            # the orchestrator call. History and retrieval both receive the
            # same ``db`` session, and an AsyncSession must not be used by two
            # tasks at once, so they are not run concurrently with each other.
            history = await load_history(db, request.room_id, limit=6)  # 6 msgs (3 pairs) for orchestrator
            retrieval_task = asyncio.create_task(
                retriever.retrieve(request.message, request.user_id, db)
            )
            intent_result = await orchestrator.analyze_message(request.message, history)

            if not intent_result.get("needs_search"):
                await _discard_task(retrieval_task)
                relevant_docs, fallback_docs = [], []
            else:
                # Fall back to the raw message when the orchestrator returns an
                # empty or missing search query.
                search_query = intent_result.get("search_query") or request.message
                logger.info(f"Searching for: {search_query}")
                if search_query != request.message:
                    # The speculative retrieval used the wrong query; make sure
                    # it has fully stopped before the session is reused.
                    await _discard_task(retrieval_task)
                    relevant_docs, fallback_docs = await retriever.retrieve(
                        query=search_query,
                        user_id=request.user_id,
                        db=db,
                    )
                else:
                    relevant_docs, fallback_docs = await retrieval_task

            context = _format_context(relevant_docs, fallback_docs)
            logger.info(f"assembled context ({context})")
            sources = _extract_sources(relevant_docs)

        # Step 3: Direct response for greetings / non-document intents
        if intent_result.get("direct_response"):
            response = intent_result["direct_response"]
            await cache_response(redis, cache_key, response)

            async def stream_direct():
                audio_text = await chatbot.generate_audio_text(response)
                yield {"event": "sources", "data": json.dumps([])}
                yield {"event": "message", "data": response}
                yield {"event": "audio_text", "data": audio_text}
                yield {"event": "done", "data": ""}
                # NOTE(review): persistence runs after `done`; if the client
                # disconnects on `done` and the generator is cancelled, this
                # write may not complete — confirm this is acceptable.
                await save_messages(db, request.room_id, request.user_id, request.message, response, audio_text=audio_text, sources=[])

            return EventSourceResponse(stream_direct())

        # Step 4: Stream answer token-by-token as LLM generates it
        # Load full history (10 msgs) for chatbot — richer context than the 6 used by orchestrator
        full_history = await load_history(db, request.room_id, limit=10)
        messages = full_history + [HumanMessage(content=request.message)]

        async def stream_response():
            tokens = []
            yield {"event": "sources", "data": json.dumps(sources)}
            async for token in chatbot.astream_response(messages, context):
                tokens.append(token)
                yield {"event": "chunk", "data": token}
            full_response = "".join(tokens)

            # Fire audio_text generation and cache write concurrently once streaming completes
            audio_text_task = asyncio.create_task(chatbot.generate_audio_text(full_response))
            cache_task = asyncio.create_task(cache_response(redis, cache_key, full_response))
            audio_text = await audio_text_task
            yield {"event": "audio_text", "data": audio_text}
            yield {"event": "done", "data": ""}
            # NOTE(review): as in the direct path, these writes happen after
            # `done` and may be cut short if the client disconnects — confirm.
            await cache_task
            await save_messages(db, request.room_id, request.user_id, request.message, full_response, audio_text=audio_text, sources=sources)

        return EventSourceResponse(stream_response())
    except Exception as e:
        # Do not leave the speculative retrieval running (or its exception
        # unretrieved) when setup fails part-way through.
        if retrieval_task is not None:
            await _discard_task(retrieval_task)
        logger.error("Chat failed", error=str(e))
        raise HTTPException(status_code=500, detail=f"Chat failed: {str(e)}") from e
def _escape_redis_glob(value: str) -> str:
    """Backslash-escape Redis glob metacharacters so ``value`` matches literally."""
    return "".join("\\" + ch if ch in "*?[]\\" else ch for ch in value)


@router.delete("/cache")
@log_execution(logger)
async def clear_cache(request: ClearCacheRequest):
    """Clear cached entries in Redis for a room and/or a user.

    - ``room_id`` only: delete cached chat responses for that room.
    - ``user_id`` only: delete cached retrieval results for that user.
    - both: delete the room's chat cache and the user's retrieval cache.
    - neither: rejected with HTTP 400; clearing everything is done through
      ``DELETE /cache/all``.

    Returns:
        A dict with ``deleted_keys`` (number of keys removed) and the
        ``room_id`` / ``user_id`` that were supplied.

    Raises:
        HTTPException: 400 when neither ``room_id`` nor ``user_id`` is given.
    """
    if not request.room_id and not request.user_id:
        raise HTTPException(
            status_code=400,
            detail="Sediakan minimal salah satu: room_id atau user_id. Untuk clear semua cache gunakan endpoint DELETE /cache/all."
        )

    redis = await get_redis()

    # The ids come from the request body and are embedded in a glob pattern;
    # escape them so that e.g. room_id="*" cannot match other rooms' keys.
    patterns = []
    if request.room_id:
        patterns.append(f"{settings.redis_prefix}chat:{_escape_redis_glob(request.room_id)}:*")
    if request.user_id:
        patterns.append(f"{settings.redis_prefix}retrieval:{_escape_redis_glob(request.user_id)}:*")

    deleted = 0
    for pattern in patterns:
        # NOTE(review): KEYS scans the whole keyspace; consider SCAN if the
        # Redis instance grows large.
        keys = await redis.keys(pattern)
        if keys:
            deleted += await redis.delete(*keys)

    return {"deleted_keys": deleted, "room_id": request.room_id, "user_id": request.user_id}
@router.delete("/cache/all")
@log_execution(logger)
async def clear_all_cache():
    """Delete all cached data in Redis.

    Two key families are removed, in this order: application keys that start
    with ``settings.redis_prefix`` (chat responses and retrieval results), and
    LangChain LLM response-cache keys that start with ``langchain:``.

    Returns:
        A dict with ``deleted_keys``, the total number of keys removed.
    """
    redis = await get_redis()
    deleted = 0
    for pattern in (f"{settings.redis_prefix}*", "langchain:*"):
        matched = await redis.keys(pattern)
        if matched:
            deleted += await redis.delete(*matched)
    return {"deleted_keys": deleted}