Spaces:
Running
Running
| import asyncio | |
| import json | |
| import re | |
| import time | |
| from fastapi import APIRouter, Request, Depends | |
| from fastapi.responses import StreamingResponse | |
| from app.models.chat import ChatRequest | |
| from app.models.pipeline import PipelineState | |
| from app.security.rate_limiter import chat_rate_limit | |
| from app.security.jwt_auth import verify_jwt | |
# Module-level FastAPI router for this file's routes.
router: APIRouter = APIRouter()
| # Phrases a visitor uses when telling the bot it gave a wrong answer. | |
| # Matched on the lowercased raw message before any LLM call β O(1), zero cost. | |
| _CRITICISM_SIGNALS: frozenset[str] = frozenset({ | |
| "that's wrong", "thats wrong", "you're wrong", "youre wrong", | |
| "not right", "wrong answer", "you got it wrong", "that is wrong", | |
| "that's incorrect", "you're incorrect", "thats incorrect", "youre incorrect", | |
| "fix that", "fix your answer", "actually no", "no that's", "no thats", | |
| "that was wrong", "your answer was wrong", "wrong information", | |
| "incorrect information", "that's not right", "thats not right", | |
| }) | |
| def _is_criticism(message: str) -> bool: | |
| lowered = message.lower() | |
| return any(sig in lowered for sig in _CRITICISM_SIGNALS) | |
| async def _generate_follow_ups( | |
| query: str, | |
| answer: str, | |
| sources: list, | |
| llm_client, | |
| ) -> list[str]: | |
| """ | |
| Generates 3 specific follow-up questions after the main answer is complete. | |
| Runs after the answer stream finishes β zero added latency before first token. | |
| Questions MUST: | |
| - Be grounded in the source documents that were actually retrieved (not hypothetical). | |
| - Lead the visitor deeper into content the knowledge base ALREADY contains. | |
| - Never venture into topics not covered by the retrieved sources (no hallucinated follow-ups). | |
| - Be specific (< 12 words, no generic "tell me more" style). | |
| """ | |
| # Collect source titles AND types so the LLM knows what was actually retrieved. | |
| source_info = [] | |
| for s in sources[:4]: | |
| title = s.title if hasattr(s, "title") else s.get("title", "") | |
| src_type = s.source_type if hasattr(s, "source_type") else s.get("source_type", "") | |
| url = s.url if hasattr(s, "url") else s.get("url", "") | |
| if title: | |
| source_info.append(f"{title} ({src_type})" if src_type else title) | |
| sources_str = "\n".join(f"- {si}" for si in source_info) if source_info else "- (no specific sources)" | |
| prompt = ( | |
| f"Visitor's question: {query}\n\n" | |
| f"Answer given (excerpt): {answer[:500]}\n\n" | |
| f"Sources that were retrieved and cited in the answer:\n{sources_str}\n\n" | |
| "Write exactly 3 follow-up questions the visitor would logically ask NEXT, " | |
| "based ONLY on what was found in the sources above. " | |
| "Each question must be clearly answerable from the retrieved sources β " | |
| "do NOT invent topics that are not present in the sources listed. " | |
| "Each question must be under 12 words. " | |
| "Output ONLY the 3 questions, one per line, no numbering or bullet points." | |
| ) | |
| system = ( | |
| "You write concise follow-up questions for a portfolio chatbot. " | |
| "CRITICAL RULE: every question you write must be answerable from the source documents listed. " | |
| "Never invent follow-ups about topics, projects, or facts not mentioned in the retrieved sources. " | |
| "Never write generic questions like 'tell me more' or 'what else can you tell me'. " | |
| "Each question must be under 12 words and reference specifics from the answer and sources." | |
| ) | |
| try: | |
| stream = llm_client.complete_with_complexity( | |
| prompt=prompt, system=system, stream=True, complexity="simple" | |
| ) | |
| raw = "" | |
| async for token in stream: | |
| raw += token | |
| questions = [q.strip() for q in raw.strip().splitlines() if q.strip()][:3] | |
| return questions | |
| except Exception: | |
| return [] | |
| async def _update_summary_async( | |
| conv_store, | |
| gemini_client, | |
| session_id: str, | |
| previous_summary: str | None, | |
| query: str, | |
| answer: str, | |
| processing_api_key: str | None, | |
| ) -> None: | |
| """ | |
| Triggered post-response to update the rolling conversation summary. | |
| Failures are silently swallowed β summary is best-effort context, not critical. | |
| """ | |
| try: | |
| new_summary = await gemini_client.update_conversation_summary( | |
| previous_summary=previous_summary or "", | |
| new_turn_q=query, | |
| new_turn_a=answer[:600], # cap answer chars sent to Gemini | |
| processing_api_key=processing_api_key, | |
| ) | |
| if new_summary: | |
| conv_store.set_summary(session_id, new_summary) | |
| except Exception: | |
| pass | |
# NOTE(review): `router` and `chat_rate_limit` are imported, but no route or
# rate-limit decorator is visible on this function in this view. Confirm how
# chat_endpoint is registered.
async def chat_endpoint(
    request: Request,
    request_data: ChatRequest,
    token_payload: dict = Depends(verify_jwt),
) -> StreamingResponse:
    """Stream a RAG answer as typed Server-Sent Events.

    ``token_payload`` comes from the ``verify_jwt`` dependency. It is not read
    here; presumably it exists only to enforce authentication.

    Event sequence for a full RAG request:
        event: status     - guard label, cache miss, gemini routing, retrieve labels
        event: reading    - one per unique source found in Qdrant (before rerank)
        event: sources    - final selected sources array (after rerank)
        event: thinking   - CoT scratchpad tokens (70B only)
        event: token      - answer tokens
        data: {"done": true, ...}
                          - unnamed terminal event with the cited sources,
                            cached flag, latency_ms and interaction_id
        event: follow_ups - three suggested follow-up questions
    For cache hits: status -> status -> token
    For Gemini fast-path: status -> status -> token
    On an unexpected failure an unnamed ``data: {"error": ...}`` event is sent.
    """
    start_time = time.monotonic()
    app_state = request.app.state
    pipeline = app_state.pipeline
    conv_store = app_state.conversation_store
    llm_client = app_state.llm_client
    gemini_client = getattr(app_state, "gemini_client", None)
    gemini_ready = bool(gemini_client and gemini_client.is_configured)

    session_id = request_data.session_id
    message = request_data.message
    conversation_history = conv_store.get_recent(session_id)
    conversation_summary = conv_store.get_summary(session_id)

    criticism = _is_criticism(message)
    if criticism and conversation_history:
        # Only flag a previous turn when there is one to flag.
        conv_store.mark_last_negative(session_id)

    # Start both best-effort Gemini calls right away so they overlap each other.
    # - Stage 2: decontextualize_query rewrites reference-heavy follow-ups
    #   ("tell me more about that project") into a self-contained query that
    #   embeds well. Only useful when a rolling summary exists.
    # - Bug 4: expand_query produces canonical name forms (for BM25) and
    #   semantic expansions (for dense multi-search).
    # Both are awaited below BEFORE the pipeline (and therefore Guard) starts,
    # so they do not overlap Guard/Cache. Worst case they add about
    # 3.0 s + 0.8 s before the first SSE byte.
    decontext_task: asyncio.Task | None = None
    if conversation_summary and gemini_ready:
        decontext_task = asyncio.create_task(
            gemini_client.decontextualize_query(message, conversation_summary)
        )
    expansion_task: asyncio.Task | None = None
    if gemini_ready:
        expansion_task = asyncio.create_task(gemini_client.expand_query(message))

    # Use the rewrite only when it is non-blank and actually differs from the
    # raw message. Otherwise the retrieve node falls back to the raw query.
    decontextualized_query: str | None = None
    rewrite = await _await_best_effort(decontext_task, timeout=3.0)
    if isinstance(rewrite, str):
        rewrite = rewrite.strip()
        if rewrite and rewrite.lower() != message.strip().lower():
            decontextualized_query = rewrite

    # The expansion gets 0.8 s beyond the time already spent on decontextualization.
    expansion = await _await_best_effort(expansion_task, timeout=0.8) or {}

    initial_state: PipelineState = {  # type: ignore[assignment]
        "query": message,
        "session_id": session_id,
        "query_complexity": "simple",
        # Bug 4: seed expanded_queries with Gemini semantic expansions so the
        # retrieve node issues one dense search per expansion (up to 3 extras).
        # operator.add in PipelineState merges these with any queries added later
        # (e.g. the rag_query from gemini_fast routing to RAG). `or []` guards
        # against an explicit None from Gemini, which operator.add cannot merge.
        "expanded_queries": expansion.get("semantic_expansions") or [],
        "retrieved_chunks": [],
        "reranked_chunks": [],
        "answer": "",
        "sources": [],
        "cached": False,
        "cache_key": None,
        "guard_passed": False,
        "thinking": False,
        "conversation_history": conversation_history,
        "is_criticism": criticism,
        "latency_ms": 0,
        "error": None,
        "interaction_id": None,
        "retrieval_attempts": 0,
        "rewritten_query": None,
        "follow_ups": [],
        "path": None,
        "query_topic": None,
        # Stage 1: follow-up bypass for Gemini fast-path
        "is_followup": request_data.is_followup,
        # Stage 2: progressive history summarisation
        "conversation_summary": conversation_summary or None,
        "decontextualized_query": decontextualized_query,
        # Stage 3: SELF-RAG critic scores (populated by generate node)
        "critic_groundedness": None,
        "critic_completeness": None,
        "critic_specificity": None,
        "critic_quality": None,
        # Fix 1: enumeration classifier, populated by enumerate_query node
        "is_enumeration_query": False,
        # Bug 4: query expansion, canonical name forms for BM25 union search.
        "query_canonical_forms": expansion.get("canonical_forms") or [],
    }

    async def sse_generator():
        final_sources: list = []
        is_cached = False
        final_answer = ""
        interaction_id = None
        try:
            # stream_mode=["custom", "updates"] yields (mode, data) tuples:
            #   mode="custom"  -> data is whatever writer(payload) was called with
            #   mode="updates" -> data is {node_name: state_updates_dict}
            async for mode, data in pipeline.astream(
                initial_state,
                stream_mode=["custom", "updates"],
            ):
                if await request.is_disconnected():
                    break
                if mode == "custom":
                    # Forward writer events as named SSE events. Each node emits
                    # {"type": "<event_name>", ...payload}; the "type" key is
                    # stripped so the client receives a clean payload.
                    event_type = data.get("type", "status")
                    payload = {k: v for k, v in data.items() if k != "type"}
                    yield f"event: {event_type}\ndata: {json.dumps(payload)}\n\n"
                elif mode == "updates":
                    # Capture terminal state for the done event; tokens were
                    # already streamed through "custom" events.
                    for updates in data.values():
                        if not updates:
                            # A node that returned nothing reports None here.
                            continue
                        if updates.get("sources"):
                            final_sources = updates["sources"]
                        if "cached" in updates:
                            is_cached = updates["cached"]
                        if updates.get("interaction_id") is not None:
                            interaction_id = updates["interaction_id"]
                        if updates.get("answer"):
                            final_answer = updates["answer"]

            elapsed_ms = int((time.monotonic() - start_time) * 1000)
            # Citation-index filtering: the single serialisation-time safety net
            # for every path (RAG, Gemini fast-path, enumeration).
            final_sources = _filter_cited_sources(final_answer, final_sources)
            sources_list = [_serialise_source(s) for s in final_sources]
            # The done event uses plain `data:` (no `event:` type) for backward
            # compatibility with widgets that listen on the raw data channel.
            done_payload = {
                "done": True,
                "sources": sources_list,
                "cached": is_cached,
                "latency_ms": elapsed_ms,
                "interaction_id": interaction_id,
            }
            yield f"data: {json.dumps(done_payload)}\n\n"

            # Follow-up questions are generated after the done event so they
            # never delay answer delivery.
            if final_answer and not await request.is_disconnected():
                follow_ups = await _generate_follow_ups(
                    message, final_answer, final_sources, llm_client
                )
                if follow_ups:
                    yield f"event: follow_ups\ndata: {json.dumps({'questions': follow_ups})}\n\n"

            # Stage 2: update the rolling summary in the background once the
            # response has been delivered, so it adds no latency to this turn.
            if final_answer and gemini_ready:
                processing_key = getattr(app_state, "gemini_processing_api_key", None)
                _spawn_background(
                    _update_summary_async(
                        conv_store=conv_store,
                        gemini_client=gemini_client,
                        session_id=session_id,
                        previous_summary=conversation_summary,
                        query=message,
                        answer=final_answer,
                        processing_api_key=processing_key,
                    )
                )
        except Exception as exc:
            # NOTE(review): this sends the raw exception text to the client.
            # Consider a generic message plus server-side logging.
            yield f"data: {json.dumps({'error': str(exc) or 'Generation failed'})}\n\n"

    return StreamingResponse(
        sse_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )


async def _await_best_effort(task: asyncio.Task | None, timeout: float):
    """Wait at most *timeout* seconds for *task*; return None if absent or failed.

    ``asyncio.wait_for`` cancels the task on timeout, so a slow call does not
    keep running after its result is no longer wanted.
    """
    if task is None:
        return None
    try:
        return await asyncio.wait_for(task, timeout=timeout)
    except Exception:
        return None


def _filter_cited_sources(answer: str, sources: list) -> list:
    """Keep only the sources whose 1-based index is cited as ``[n]`` in *answer*.

    If the answer cites [3][5], only sources 3 and 5 are kept. The input is
    returned unchanged when the answer or sources are empty, or when the
    answer contains no ``[n]`` citations at all.
    """
    if not (answer and sources):
        return sources
    cited = {int(num) for num in re.findall(r"\[(\d+)\]", answer)}
    if not cited:
        return sources
    return [src for idx, src in enumerate(sources, start=1) if idx in cited]


def _serialise_source(source):
    """Convert a source to JSON-ready data (pydantic v2, then v1, else as-is)."""
    if hasattr(source, "model_dump"):
        return source.model_dump()
    if hasattr(source, "dict"):
        return source.dict()
    return source


# Strong references to fire-and-forget tasks. The event loop keeps only weak
# references, so an unreferenced task may be garbage-collected before it
# finishes.
_BACKGROUND_TASKS: set[asyncio.Task] = set()


def _spawn_background(coro) -> asyncio.Task:
    """Schedule *coro* as a background task and keep it referenced until done."""
    task = asyncio.create_task(coro)
    _BACKGROUND_TASKS.add(task)
    task.add_done_callback(_BACKGROUND_TASKS.discard)
    return task