"""
Grounded Answer Generator

Generates answers from retrieved context with citations.
Uses local LLMs (Ollama) or cloud APIs.
"""

import re
from typing import List, Optional, Dict, Any

from loguru import logger
from pydantic import BaseModel, Field

from .retriever import RetrievedChunk, DocumentRetriever, get_document_retriever

try:
    import httpx

    HTTPX_AVAILABLE = True
except ImportError:
    HTTPX_AVAILABLE = False


class GeneratorConfig(BaseModel):
    """Configuration for grounded generator."""

    llm_provider: str = Field(
        default="ollama",
        description="LLM provider: ollama, openai",
    )
    ollama_base_url: str = Field(
        default="http://localhost:11434",
        description="Ollama API base URL",
    )
    ollama_model: str = Field(
        default="llama3.2:3b",
        description="Ollama model for generation",
    )

    openai_model: str = Field(
        default="gpt-4o-mini",
        description="OpenAI model for generation",
    )
    openai_api_key: Optional[str] = Field(
        default=None,
        description="OpenAI API key",
    )

    temperature: float = Field(default=0.1, ge=0.0, le=2.0)
    max_tokens: int = Field(default=1024, ge=1)
    timeout: float = Field(default=120.0, ge=1.0)

    require_citations: bool = Field(
        default=True,
        description="Require citations in answers",
    )
    citation_format: str = Field(
        default="[{index}]",
        description="Citation format template",
    )
    abstain_on_low_confidence: bool = Field(
        default=True,
        description="Abstain when confidence is low",
    )
    confidence_threshold: float = Field(
        default=0.6,
        ge=0.0,
        le=1.0,
        description="Minimum confidence threshold",
    )
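

# Example (illustrative sketch): overriding a few defaults. The values below are
# assumptions for demonstration, not required settings.
#
#     config = GeneratorConfig(
#         llm_provider="ollama",
#         ollama_model="llama3.2:3b",
#         confidence_threshold=0.7,
#     )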


class Citation(BaseModel):
    """A citation reference."""

    index: int
    chunk_id: str
    page: Optional[int] = None
    text_snippet: str
    confidence: float


class GeneratedAnswer(BaseModel):
    """Generated answer with citations."""

    answer: str
    citations: List[Citation]
    confidence: float
    abstained: bool = False
    abstain_reason: Optional[str] = None

    num_chunks_used: int
    query: str
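

# GeneratedAnswer is a plain Pydantic model, so callers can serialize it (e.g. via
# model_dump() / model_dump_json() on Pydantic v2). Example shape, with illustrative
# values:
#
#     {
#         "answer": "The warranty period is 12 months [1].",
#         "citations": [{"index": 1, "chunk_id": "chunk-001", "page": 3, ...}],
#         "confidence": 0.82,
#         "abstained": false,
#         "abstain_reason": null,
#         "num_chunks_used": 4,
#         "query": "How long is the warranty period?"
#     }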


class GroundedGenerator:
    """
    Generates grounded answers with citations.

    Features:
    - Uses retrieved chunks as context
    - Generates answers with inline citations
    - Confidence-based abstention
    - Support for Ollama and OpenAI
    """

    SYSTEM_PROMPT = """You are a precise document question-answering assistant.
Your task is to answer questions based ONLY on the provided context from documents.

Rules:
1. Only use information from the provided context
2. Cite your sources using [N] notation where N is the chunk number
3. If the context doesn't contain enough information, say "I cannot answer this based on the available context"
4. Be precise and concise
5. If information is uncertain or partial, indicate this clearly

Context format: Each chunk is numbered [1], [2], etc. with page numbers and content.
"""

    def __init__(
        self,
        config: Optional[GeneratorConfig] = None,
        retriever: Optional[DocumentRetriever] = None,
    ):
        """
        Initialize generator.

        Args:
            config: Generator configuration
            retriever: Document retriever instance
        """
        self.config = config or GeneratorConfig()
        self._retriever = retriever

    @property
    def retriever(self) -> DocumentRetriever:
        """Get retriever (lazy initialization)."""
        if self._retriever is None:
            self._retriever = get_document_retriever()
        return self._retriever

    def generate(
        self,
        query: str,
        chunks: List[RetrievedChunk],
        additional_context: Optional[str] = None,
    ) -> GeneratedAnswer:
        """
        Generate an answer from retrieved chunks.

        Args:
            query: User question
            chunks: Retrieved context chunks
            additional_context: Optional additional context

        Returns:
            GeneratedAnswer with citations
        """
        # Abstain up front if the retrieved chunks are not similar enough to the query.
        if self.config.abstain_on_low_confidence and chunks:
            avg_confidence = sum(c.similarity for c in chunks) / len(chunks)
            if avg_confidence < self.config.confidence_threshold:
                return GeneratedAnswer(
                    answer="I cannot provide a confident answer based on the available context.",
                    citations=[],
                    confidence=avg_confidence,
                    abstained=True,
                    abstain_reason=f"Average confidence ({avg_confidence:.2f}) below threshold ({self.config.confidence_threshold})",
                    num_chunks_used=len(chunks),
                    query=query,
                )

        # Build the numbered context block and the final prompt.
        context = self._build_context(chunks, additional_context)
        prompt = self._build_prompt(query, context)

        # Dispatch to the configured LLM provider.
        if self.config.llm_provider == "ollama":
            raw_answer = self._generate_ollama(prompt)
        elif self.config.llm_provider == "openai":
            raw_answer = self._generate_openai(prompt)
        else:
            raise ValueError(f"Unknown LLM provider: {self.config.llm_provider}")

        # Map [N] markers in the answer back to the chunks they reference.
        citations = self._extract_citations(raw_answer, chunks)

        # Overall confidence: prefer the cited chunks, fall back to all retrieved chunks.
        if citations:
            confidence = sum(c.confidence for c in citations) / len(citations)
        elif chunks:
            confidence = sum(c.similarity for c in chunks) / len(chunks)
        else:
            confidence = 0.0

        return GeneratedAnswer(
            answer=raw_answer,
            citations=citations,
            confidence=confidence,
            abstained=False,
            num_chunks_used=len(chunks),
            query=query,
        )

    def answer_question(
        self,
        query: str,
        top_k: int = 5,
        filters: Optional[Dict[str, Any]] = None,
    ) -> GeneratedAnswer:
        """
        Retrieve context and generate answer.

        Args:
            query: User question
            top_k: Number of chunks to retrieve
            filters: Optional retrieval filters

        Returns:
            GeneratedAnswer with citations
        """
        chunks = self.retriever.retrieve(query, top_k=top_k, filters=filters)

        if not chunks:
            return GeneratedAnswer(
                answer="I could not find any relevant information in the documents to answer this question.",
                citations=[],
                confidence=0.0,
                abstained=True,
                abstain_reason="No relevant chunks found",
                num_chunks_used=0,
                query=query,
            )

        return self.generate(query, chunks)
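
    # Typical call path (illustrative sketch; assumes documents have already been
    # indexed so the retriever can return chunks):
    #
    #     generator = GroundedGenerator()
    #     result = generator.answer_question("What is the notice period?", top_k=5)
    #     if result.abstained:
    #         print(result.abstain_reason)
    #     else:
    #         print(result.answer)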

    def _build_context(
        self,
        chunks: List[RetrievedChunk],
        additional_context: Optional[str] = None,
    ) -> str:
        """Build context string from chunks."""
        parts = []

        if additional_context:
            parts.append(f"Additional context:\n{additional_context}\n")

        parts.append("Document excerpts:")

        for i, chunk in enumerate(chunks, 1):
            header = f"\n[{i}]"
            if chunk.page is not None:
                # Pages are stored zero-based; display them one-based.
                header += f" (Page {chunk.page + 1}"
                if chunk.chunk_type:
                    header += f", {chunk.chunk_type}"
                header += ")"

            parts.append(f"{header}:")
            parts.append(chunk.text)

        return "\n".join(parts)
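
    # The assembled context looks roughly like this (chunk text and chunk_type
    # values are illustrative):
    #
    #     Document excerpts:
    #
    #     [1] (Page 3, paragraph):
    #     <chunk text>
    #
    #     [2] (Page 7, table):
    #     <chunk text>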

    def _build_prompt(self, query: str, context: str) -> str:
        """Build the full prompt."""
        return f"""Based on the following context, answer the question.

{context}

Question: {query}

Answer (cite sources using [N] notation):"""

    def _generate_ollama(self, prompt: str) -> str:
        """Generate using Ollama."""
        if not HTTPX_AVAILABLE:
            raise ImportError("httpx required for Ollama")

        with httpx.Client(timeout=self.config.timeout) as client:
            response = client.post(
                f"{self.config.ollama_base_url}/api/generate",
                json={
                    "model": self.config.ollama_model,
                    "prompt": prompt,
                    "system": self.SYSTEM_PROMPT,
                    "stream": False,
                    "options": {
                        "temperature": self.config.temperature,
                        "num_predict": self.config.max_tokens,
                    },
                },
            )
            response.raise_for_status()
            result = response.json()

        return result.get("response", "").strip()

    def _generate_openai(self, prompt: str) -> str:
        """Generate using OpenAI."""
        try:
            import openai
        except ImportError:
            raise ImportError("openai package required")

        client = openai.OpenAI(api_key=self.config.openai_api_key)

        response = client.chat.completions.create(
            model=self.config.openai_model,
            messages=[
                {"role": "system", "content": self.SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
        )

        return response.choices[0].message.content.strip()

    def _extract_citations(
        self,
        answer: str,
        chunks: List[RetrievedChunk],
    ) -> List[Citation]:
        """Extract citations from answer text."""
        citations = []
        seen_indices = set()

        # Find every [N] marker in the generated answer.
        pattern = r'\[(\d+)\]'
        matches = re.findall(pattern, answer)

        for match in matches:
            index = int(match)
            if index in seen_indices:
                continue
            # Ignore markers that do not correspond to a provided chunk.
            if index < 1 or index > len(chunks):
                continue

            seen_indices.add(index)
            chunk = chunks[index - 1]

            citation = Citation(
                index=index,
                chunk_id=chunk.chunk_id,
                page=chunk.page,
                text_snippet=chunk.text[:150] + ("..." if len(chunk.text) > 150 else ""),
                confidence=chunk.similarity,
            )
            citations.append(citation)

        return sorted(citations, key=lambda c: c.index)
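
    # Note: only [N] markers that map to a supplied chunk become citations; a marker
    # such as [7] is ignored when fewer than seven chunks were provided, and repeated
    # markers are deduplicated.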


_grounded_generator: Optional[GroundedGenerator] = None


def get_grounded_generator(
    config: Optional[GeneratorConfig] = None,
    retriever: Optional[DocumentRetriever] = None,
) -> GroundedGenerator:
    """
    Get or create singleton grounded generator.

    Note: config and retriever are only applied when the singleton is first
    created; subsequent calls return the existing instance unchanged.

    Args:
        config: Generator configuration
        retriever: Optional retriever instance

    Returns:
        GroundedGenerator instance
    """
    global _grounded_generator

    if _grounded_generator is None:
        _grounded_generator = GroundedGenerator(
            config=config,
            retriever=retriever,
        )

    return _grounded_generator


def reset_grounded_generator():
    """Reset the global generator instance."""
    global _grounded_generator
    _grounded_generator = None
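

# End-to-end usage (illustrative sketch; assumes documents have been indexed for the
# retriever and, with the default config, that an Ollama server is reachable at
# http://localhost:11434):
#
#     generator = get_grounded_generator(GeneratorConfig(llm_provider="ollama"))
#     answer = generator.answer_question("What are the payment terms?", top_k=5)
#     for citation in answer.citations:
#         print(citation.index, citation.page, citation.text_snippet)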