"""Chatbot agent with RAG capabilities.""" import re from langchain_openai import AzureChatOpenAI from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.output_parsers import StrOutputParser from langchain_core.callbacks import BaseCallbackHandler from src.config.settings import settings from src.middlewares.logging import get_logger logger = get_logger("chatbot") class _CacheHitLogger(BaseCallbackHandler): """Logs Azure prompt cache hits from response usage metadata.""" def on_llm_end(self, response, **_): try: for gen_list in response.generations: for gen in gen_list: msg = getattr(gen, "message", None) if msg is None: continue usage = getattr(msg, "usage_metadata", None) if usage: cached = usage.get("input_token_details", {}).get("cache_read", 0) if cached > 0: logger.info(f"Azure prompt cache hit: {cached} cached tokens") except Exception: pass class ChatbotAgent: """Chatbot agent with RAG capabilities.""" def __init__(self): self.llm = AzureChatOpenAI( azure_deployment=settings.azureai_deployment_name_54mini, openai_api_version=settings.azureai_api_version_54mini, azure_endpoint=settings.azureai_endpoint_url_54mini, api_key=settings.azureai_api_key_54mini, temperature=0.6, callbacks=[_CacheHitLogger()], ) # Read system prompt try: with open("src/config/agents/system_prompt.md", "r") as f: system_prompt = f.read() except FileNotFoundError: system_prompt = "You are a helpful AI assistant with access to user's uploaded documents." try: with open("src/config/agents/guardrails_prompt.md", "r") as f: guardrails_prompt = f.read() except FileNotFoundError: guardrails_prompt = "" if guardrails_prompt: combined_prompt = ( system_prompt.rstrip() + "\n\n---\n\n## Safety and Behavioral Guidelines\n\n" + guardrails_prompt ) else: combined_prompt = system_prompt # Create prompt template self.prompt = ChatPromptTemplate.from_messages([ ("system", combined_prompt), MessagesPlaceholder(variable_name="messages"), ("system", "Relevant documents:\n{context}") ]) # Create chain self.chain = self.prompt | self.llm | StrOutputParser() async def generate_response( self, messages: list, context: str = "" ) -> str: """Generate response with optional RAG context.""" try: logger.info("Generating chatbot response") # Generate response response = await self.chain.ainvoke({ "messages": messages, "context": context }) logger.info(f"Generated response: {response[:100]}...") return response except Exception as e: logger.error("Response generation failed", error=str(e)) raise def language_hint(self, full_response: str) -> str: text = full_response.lower() words = set(re.findall(r"\b[\w']+\b", text, flags=re.UNICODE)) indo_markers = { "yang", "dan", "untuk", "tidak", "akan", "saya", "kamu", "kita", "mereka", "adalah", "ini", "itu", "dengan", "karena", "sebagai", "oleh", "pada", "dari", "ke", "di" } eng_markers = { "the", "and", "for", "you", "your", "i", "is", "are", "will", "not", "this", "that", "with", "because", "as", "from", "to", "in", "of" } indo_count = sum(1 for w in words if w in indo_markers) eng_count = sum(1 for w in words if w in eng_markers) if indo_count > eng_count and indo_count >= 2: return "Indonesian" if eng_count > indo_count and eng_count >= 2: return "English" if indo_count > 0 and eng_count == 0: return "Indonesian" if eng_count > 0 and indo_count == 0: return "English" return "the same language as the response" async def generate_audio_text(self, full_response: str) -> str: """Generate a 2-3 sentence TTS-friendly summary of the assistant response.""" try: lang = self.language_hint(full_response) prompt = ( "You are a text to speech assistant. Given the following AI response, " "write a plain language summary in exactly 2 or 3 sentences. " "Output only the summary text. Allowed characters are letters numbers spaces, periods only, commas or dots for decimal number formatting" "Do not output any other characters. Do not name symbols. " f"The response language is {lang}. Write the summary in {lang} only. Do not translate.\n\n" f"Response:\n{full_response}\n\n" "Summary:" ) result = await self.llm.ainvoke(prompt) logger.info(f"Generated audio text: {str(result)[:250]}...") text = result.content if hasattr(result, "content") else str(result) text = text.replace("!", ".").replace("?", ".").replace(";", ".").replace("*", "") # Split on sentence boundaries (period + space/newline), not decimal dots like "184.900" sentences = [s.strip() for s in re.split(r'\.\s+', text) if s.strip()][:3] def sanitize(sentence: str) -> str: sentence = re.sub(r"[^A-Za-z0-9 .,]", " ", sentence) sentence = re.sub(r"\s+", " ", sentence).strip() return sentence sanitized = [sanitize(s) for s in sentences if s] if not sanitized: return "" output = ". ".join(sanitized).strip() if output and not output.endswith("."): output += "." return output except Exception as e: logger.error("Audio text generation failed", error=str(e)) return "" async def astream_response(self, messages: list, context: str = ""): """Stream response tokens as they are generated.""" try: logger.info("Streaming chatbot response") async for token in self.chain.astream({"messages": messages, "context": context}): yield token except Exception as e: logger.error("Response streaming failed", error=str(e)) raise chatbot = ChatbotAgent()