Qar-Raz commited on
Commit
c7256ee
·
0 Parent(s):

hf-space: deploy branch without frontend/data/results

Browse files
.dockerignore ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python caches
2
+ __pycache__/
3
+ **/__pycache__/
4
+ *.py[cod]
5
+ *.pyo
6
+
7
+ # Virtual environments
8
+ .venv/
9
+ venv/
10
+ ENV/
11
+ env/
12
+
13
+ # Frontend app (deployed separately on Vercel)
14
+ frontend/
15
+
16
+ # Local/runtime cache
17
+ .cache/
18
+
19
+ # Explicit user-requested exclusions
20
+ /EntireBookCleaned.txt
21
+ /startup.txt
22
+
23
+ # Git and editor noise
24
+ .git/
25
+ .gitignore
26
+ .vscode/
27
+ .idea/
28
+
29
+ # OS artifacts
30
+ .DS_Store
31
+ Thumbs.db
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python specific ignores
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # Virtual environments
7
+ .venv/
8
+ venv/
9
+ env/
10
+ ENV/
11
+
12
+ # Environment and local secrets
13
+ .env
14
+ .env.*
15
+ !.env.example
16
+
17
+ # Build and packaging artifacts
18
+ build/
19
+ dist/
20
+ *.egg-info/
21
+ .eggs/
22
+
23
+ # Caches and tooling
24
+ .pytest_cache/
25
+ .mypy_cache/
26
+ .ruff_cache/
27
+ .ipynb_checkpoints/
28
+ .cache/
29
+
30
+ # IDE/editor
31
+ .vscode/
32
+ .idea/
Dockerfile ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.11-slim

# No .pyc files, unbuffered stdout (so Space logs appear live), no pip cache.
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    PIP_NO_CACHE_DIR=1

WORKDIR /app

# Minimal system packages for common Python builds.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first so the dependency layer is cached across code edits.
COPY requirements.txt ./
RUN pip install --upgrade pip && pip install -r requirements.txt

COPY . .

# Fail fast during build if critical runtime folders are missing from context.
# NOTE(review): the commit message says this branch deploys "without
# frontend/data/results" -- if data/ and results/ are truly absent from the
# build context, this check will fail the image build. Confirm they ship,
# or relax the test to backend/ only.
RUN test -d /app/backend && test -d /app/data && test -d /app/results

# Hugging Face Spaces exposes apps on port 7860 by default.
EXPOSE 7860

CMD ["uvicorn", "backend.api:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: NLP RAG
3
+ emoji: 🏢
4
+ colorFrom: gray
5
+ colorTo: green
6
+ sdk: docker
7
+ pinned: false
8
+ license: mit
9
+ short_description: NLP Spring 2026 Project 1
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
backend/api.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from backend.routes.health import router as health_router
from backend.routes.predict import router as predict_router
from backend.routes.predict_stream import router as predict_stream_router
from backend.routes.title import router as title_router
from backend.services.startup import initialize_runtime_state
from backend.state import state

# FastAPI application entry point: configures middleware, mounts all route
# routers, and registers the startup hook that builds the shared runtime state.
# --@Qamar

app = FastAPI(title="RAG-AS3 API", version="0.1.0")

# Wide-open CORS so any frontend origin can call the API.
# NOTE(review): browsers reject credentialed requests when
# allow_origins=["*"] is combined with allow_credentials=True -- pin the
# frontend origin if cookies/auth headers are ever needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.include_router(health_router)
app.include_router(title_router)
app.include_router(predict_router)
app.include_router(predict_stream_router)


# NOTE(review): on_event("startup") is deprecated in recent FastAPI in favor
# of lifespan handlers; it still works, but is worth migrating eventually.
@app.on_event("startup")
def startup_event() -> None:
    """Build the retriever/generator/models once per process (blocking)."""
    initialize_runtime_state(state)
backend/routes/health.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import APIRouter

from backend.state import REQUIRED_STATE_KEYS, state

router = APIRouter()


@router.get("/health")
def health() -> dict[str, str]:
    """Readiness probe: 'ok' once every required runtime-state key is present."""
    missing = [key for key in REQUIRED_STATE_KEYS if key not in state]
    return {"status": "starting" if missing else "ok"}
backend/routes/predict.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import time
from typing import Any

from fastapi import APIRouter, HTTPException

from backend.schemas import PredictRequest, PredictResponse
from backend.services.chunks import build_retrieved_chunks
from backend.services.models import resolve_model
from backend.state import state
from retriever.generator import RAGGenerator
from retriever.retriever import HybridRetriever

router = APIRouter()


@router.post("/predict", response_model=PredictResponse)
def predict(payload: PredictRequest) -> PredictResponse:
    """Run the full RAG pipeline for one query and return the final answer.

    Pipeline: validate -> fetch startup state -> resolve model ->
    hybrid retrieval -> LLM generation -> attach chunk source metadata.

    Raises:
        HTTPException 503: startup state not initialized yet.
        HTTPException 400: empty query.
        HTTPException 404: retrieval produced no context chunks.
    """
    req_start = time.perf_counter()

    # Reject early until initialize_runtime_state() has populated state.
    precheck_start = time.perf_counter()
    if not state:
        raise HTTPException(status_code=503, detail="Service not initialized yet")

    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    precheck_time = time.perf_counter() - precheck_start

    # Long-lived pipeline objects built once at startup (see services.startup).
    state_access_start = time.perf_counter()
    retriever: HybridRetriever = state["retriever"]
    index = state["index"]
    rag_engine: RAGGenerator = state["rag_engine"]
    models: dict[str, Any] = state["models"]
    chunk_lookup: dict[str, dict[str, Any]] = state.get("chunk_lookup", {})
    state_access_time = time.perf_counter() - state_access_start

    # Map the requested name/alias to a model instance (400 on unknown name).
    model_resolve_start = time.perf_counter()
    model_name, model_instance = resolve_model(payload.model, models)
    model_resolve_time = time.perf_counter() - model_resolve_start

    # Retrieval with all caller-tunable knobs passed straight through.
    retrieval_start = time.perf_counter()
    contexts = retriever.search(
        query,
        index,
        chunking_technique=payload.chunking_technique,
        mode=payload.mode,
        rerank_strategy=payload.rerank_strategy,
        use_mmr=payload.use_mmr,
        lambda_param=payload.lambda_param,
        top_k=payload.top_k,
        final_k=payload.final_k,
        verbose=False,
    )
    retrieval_time = time.perf_counter() - retrieval_start

    if not contexts:
        raise HTTPException(status_code=404, detail="No context chunks retrieved for this query")

    # Generate the answer from the retrieved contexts.
    inference_start = time.perf_counter()
    answer = rag_engine.get_answer(model_instance, query, contexts, temperature=payload.temperature)
    inference_time = time.perf_counter() - inference_start

    # Enrich raw context strings with source metadata (title, url, page, ...).
    mapping_start = time.perf_counter()
    retrieved_chunks = build_retrieved_chunks(contexts=contexts, chunk_lookup=chunk_lookup)
    mapping_time = time.perf_counter() - mapping_start

    total_time = time.perf_counter() - req_start

    # One-line per-stage timing record for log-based observability.
    print(
        f"Predict timing | model={model_name} | mode={payload.mode} | "
        f"rerank={payload.rerank_strategy} | use_mmr={payload.use_mmr} | "
        f"lambda={payload.lambda_param:.2f} | temp={payload.temperature:.2f} | "
        f"chunking={payload.chunking_technique} | "
        f"top_k={payload.top_k} | final_k={payload.final_k} | returned={len(contexts)} | "
        f"precheck={precheck_time:.3f}s | "
        f"state_access={state_access_time:.3f}s | model_resolve={model_resolve_time:.3f}s | "
        f"retrieval={retrieval_time:.3f}s | inference={inference_time:.3f}s | "
        f"context_map={mapping_time:.3f}s | total={total_time:.3f}s"
    )

    return PredictResponse(
        model=model_name,
        answer=answer,
        contexts=contexts,
        retrieved_chunks=retrieved_chunks,
    )
87
+
88
+
backend/routes/predict_stream.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import time
from typing import Any

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from backend.schemas import PredictRequest
from backend.services.chunks import build_retrieved_chunks
from backend.services.models import resolve_model
from backend.services.streaming import to_ndjson
from backend.state import state
from retriever.generator import RAGGenerator
from retriever.retriever import HybridRetriever

router = APIRouter()


# Streaming variant of /predict. Validation, state access, and retrieval run
# before the response starts (so they surface as normal HTTP errors);
# generation is then streamed as NDJSON events: status -> token* -> done
# (or a single error event). The router object is mounted in api.py.


@router.post("/predict/stream")
def predict_stream(payload: PredictRequest) -> StreamingResponse:
    """Run retrieval eagerly, then stream the generated answer as NDJSON."""
    req_start = time.perf_counter()
    # Cap on generated tokens for the streaming path (env-tunable).
    stream_max_tokens = int(os.getenv("STREAM_MAX_TOKENS", "400"))

    # Reject early until startup has populated state.
    precheck_start = time.perf_counter()
    if not state:
        raise HTTPException(status_code=503, detail="Service not initialized yet")

    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")
    precheck_time = time.perf_counter() - precheck_start

    # Long-lived pipeline objects built once at startup.
    state_access_start = time.perf_counter()
    retriever: HybridRetriever = state["retriever"]
    index = state["index"]
    rag_engine: RAGGenerator = state["rag_engine"]
    models: dict[str, Any] = state["models"]
    chunk_lookup: dict[str, dict[str, Any]] = state.get("chunk_lookup", {})
    state_access_time = time.perf_counter() - state_access_start

    model_resolve_start = time.perf_counter()
    model_name, model_instance = resolve_model(payload.model, models)
    model_resolve_time = time.perf_counter() - model_resolve_start

    retrieval_start = time.perf_counter()
    contexts = retriever.search(
        query,
        index,
        chunking_technique=payload.chunking_technique,
        mode=payload.mode,
        rerank_strategy=payload.rerank_strategy,
        use_mmr=payload.use_mmr,
        lambda_param=payload.lambda_param,
        top_k=payload.top_k,
        final_k=payload.final_k,
        verbose=False,
    )
    retrieval_time = time.perf_counter() - retrieval_start

    if not contexts:
        raise HTTPException(status_code=404, detail="No context chunks retrieved for this query")

    def stream_events():
        # Generator consumed by StreamingResponse. Once the first byte is
        # sent the HTTP status is fixed, so failures are reported in-band
        # as an "error" event instead of an HTTP error.
        inference_start = time.perf_counter()
        first_token_latency = None
        answer_parts: list[str] = []
        try:
            # Initial status event lets the client show retrieval stats
            # before the first token arrives.
            yield to_ndjson(
                {
                    "type": "status",
                    "stage": "inference_start",
                    "model": model_name,
                    "retrieval_s": round(retrieval_time, 3),
                    "retrieval_debug": {
                        "requested_chunking_technique": payload.chunking_technique,
                        "requested_top_k": payload.top_k,
                        "requested_final_k": payload.final_k,
                        "returned_context_count": len(contexts),
                        "use_mmr": payload.use_mmr,
                        "lambda_param": payload.lambda_param,
                    },
                }
            )

            # Forward tokens as they arrive, tracking time-to-first-token.
            for token in rag_engine.get_answer_stream(
                model_instance,
                query,
                contexts,
                temperature=payload.temperature,
                max_tokens=stream_max_tokens,
            ):
                if first_token_latency is None:
                    first_token_latency = time.perf_counter() - inference_start
                answer_parts.append(token)
                yield to_ndjson({"type": "token", "token": token})

            inference_time = time.perf_counter() - inference_start
            answer = "".join(answer_parts)
            retrieved_chunks = build_retrieved_chunks(contexts=contexts, chunk_lookup=chunk_lookup)

            # Final event carries the full answer plus source metadata.
            yield to_ndjson(
                {
                    "type": "done",
                    "model": model_name,
                    "answer": answer,
                    "contexts": contexts,
                    "retrieved_chunks": retrieved_chunks,
                    "retrieval_debug": {
                        "requested_chunking_technique": payload.chunking_technique,
                        "requested_top_k": payload.top_k,
                        "requested_final_k": payload.final_k,
                        "returned_context_count": len(contexts),
                        "use_mmr": payload.use_mmr,
                        "lambda_param": payload.lambda_param,
                    },
                }
            )

            total_time = time.perf_counter() - req_start
            # Per-stage timing record; first_token=-1 means no token arrived.
            print(
                f"Predict stream timing | model={model_name} | mode={payload.mode} | "
                f"rerank={payload.rerank_strategy} | use_mmr={payload.use_mmr} | "
                f"lambda={payload.lambda_param:.2f} | temp={payload.temperature:.2f} | "
                f"chunking={payload.chunking_technique} | "
                f"top_k={payload.top_k} | final_k={payload.final_k} | returned={len(contexts)} | "
                f"precheck={precheck_time:.3f}s | "
                f"state_access={state_access_time:.3f}s | model_resolve={model_resolve_time:.3f}s | "
                f"retrieval={retrieval_time:.3f}s | first_token={first_token_latency if first_token_latency is not None else -1:.3f}s | "
                f"inference={inference_time:.3f}s | total={total_time:.3f}s | "
                f"max_tokens={stream_max_tokens}"
            )
        except Exception as exc:
            # Best-effort in-band error report (stream is already open).
            total_time = time.perf_counter() - req_start
            print(
                f"Predict stream error | model={model_name} | mode={payload.mode} | "
                f"retrieval={retrieval_time:.3f}s | elapsed={total_time:.3f}s | error={exc}"
            )
            yield to_ndjson({"type": "error", "message": f"Streaming failed: {exc}"})

    # X-Accel-Buffering: no asks reverse proxies not to buffer the stream.
    return StreamingResponse(
        stream_events(),
        media_type="application/x-ndjson",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )
backend/routes/title.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import APIRouter, HTTPException
from huggingface_hub import InferenceClient

from backend.schemas import TitleRequest, TitleResponse
from backend.services.title import parse_title_model_candidates, title_from_hf, title_from_query
from backend.state import state

router = APIRouter()


@router.post("/predict/title", response_model=TitleResponse)
def suggest_title(payload: TitleRequest) -> TitleResponse:
    """Suggest a short chat title for the first user message.

    Tries each configured HF title model in order and falls back to a
    rule-based title derived from the query when no model succeeds.

    Raises:
        HTTPException 400: empty query.
    """
    query = payload.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")

    # Compute the rule-based fallback up front so failure paths are cheap.
    fallback_title = title_from_query(query)

    title_client: InferenceClient | None = state.get("title_client")
    title_model_ids: list[str] = state.get("title_model_ids", parse_title_model_candidates())

    if title_client is not None:
        for title_model_id in title_model_ids:
            try:
                hf_title = title_from_hf(query, title_client, title_model_id)
                if hf_title:
                    return TitleResponse(title=hf_title, source=f"hf:{title_model_id}")
            except Exception as exc:
                err_text = str(exc)
                # Unsupported candidate models are expected; skip quietly.
                if "model_not_supported" in err_text or "not supported by any provider" in err_text:
                    continue
                print(f"Title generation model failed ({title_model_id}): {exc}")
                continue

    print("Title generation fallback triggered: no title model available/successful")
    return TitleResponse(title=fallback_title, source="rule-based")
backend/schemas.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import Any

from pydantic import BaseModel, Field

# Pydantic request/response schemas for the API endpoints.


class PredictRequest(BaseModel):
    """Input for /predict and /predict/stream: query plus retrieval/generation knobs."""

    query: str = Field(..., min_length=1, description="User query text")
    model: str = Field(default="Llama-3-8B", description="Model name key")
    top_k: int = Field(default=10, ge=1, le=20)
    final_k: int = Field(default=3, ge=1, le=8)
    chunking_technique: str = Field(default="all", description="all | fixed | sentence | paragraph | semantic | recursive | page | markdown")
    mode: str = Field(default="hybrid", description="semantic | bm25 | hybrid")
    rerank_strategy: str = Field(default="cross-encoder", description="cross-encoder | rrf | none")
    use_mmr: bool = Field(default=True, description="Whether to apply MMR after reranking")
    lambda_param: float = Field(default=0.5, ge=0.0, le=1.0, description="MMR relevance/diversity tradeoff")
    temperature: float = Field(default=0.1, ge=0.0, le=2.0, description="Generation temperature")


class PredictResponse(BaseModel):
    """Final answer plus the contexts and per-chunk source metadata used."""

    model: str
    answer: str
    contexts: list[str]
    retrieved_chunks: list[dict[str, Any]]


class TitleRequest(BaseModel):
    """Input for /predict/title."""

    query: str = Field(..., min_length=1, description="First user message")


class TitleResponse(BaseModel):
    """Suggested chat title plus its origin ('hf:<model_id>' or 'rule-based')."""

    title: str
    source: str
backend/services/cache.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any
3
+
4
+ from data.vector_db import load_chunks_with_local_cache
5
+
6
+ # Caching helpers for chunk metadata.
7
+ # Note: caching is mainly useful in the dev environment; on Hugging Face
8
+ # Spaces the filesystem is ephemeral, so the cache may not persist.
9
+
10
def get_cache_settings() -> tuple[str, bool]:
    """Resolve the BM25 cache directory and refresh flag from the environment.

    Returns:
        (cache_dir, force_refresh): cache_dir defaults to <repo_root>/.cache
        unless BM25_CACHE_DIR is set; force_refresh is truthy when
        BM25_CACHE_REFRESH is "1"/"true"/"yes" (case-insensitive).
    """
    # __file__ is backend/services/cache.py, so the repository root is three
    # levels up. The previous default (one ".." from this file) resolved to
    # backend/.cache even though the variable was named project_root; the
    # intended top-level .cache/ matches the .dockerignore entry.
    project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    cache_dir = os.getenv("BM25_CACHE_DIR", os.path.join(project_root, ".cache"))
    force_cache_refresh = os.getenv("BM25_CACHE_REFRESH", "0").lower() in {"1", "true", "yes"}
    return cache_dir, force_cache_refresh
15
+
16
+
17
def load_cached_chunks(
    index: Any,
    index_name: str,
    cache_dir: str,
    force_cache_refresh: bool,
    batch_size: int = 100,
) -> tuple[list[dict[str, Any]], str]:
    """Load chunk metadata, preferring the local cache over a full index scan.

    Thin wrapper around data.vector_db.load_chunks_with_local_cache; returns
    (chunks, source) where source describes where the chunks came from
    (cache vs. live index) -- see that function for details.
    """
    return load_chunks_with_local_cache(
        index=index,
        index_name=index_name,
        cache_dir=cache_dir,
        batch_size=batch_size,
        force_refresh=force_cache_refresh,
    )
backend/services/chunks.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+
4
+ # might need to touch this to get the additional metadata for retrieved chunks, like title and url
5
+ # --@Qamar
6
+
7
def build_retrieved_chunks(
    contexts: list[str],
    chunk_lookup: dict[str, dict[str, Any]],
) -> list[dict[str, Any]]:
    """Attach source metadata to each retrieved context string.

    Each context is looked up verbatim in ``chunk_lookup``; contexts with no
    metadata fall back to placeholder values. Ranks are 1-based and follow
    the input order.
    """
    # Keys that get a dedicated field (or are dropped); every other metadata
    # key is passed through under "extra_metadata".
    promoted = {"title", "url", "chunk_index", "text", "technique", "chunking_technique"}

    results: list[dict[str, Any]] = []
    for rank, chunk_text in enumerate(contexts, start=1):
        meta = chunk_lookup.get(chunk_text, {})
        results.append(
            {
                "rank": rank,
                "text": chunk_text,
                "source_title": meta.get("title") or "Untitled",
                "source_url": meta.get("url") or "",
                "chunk_index": meta.get("chunk_index"),
                "page": meta.get("page"),
                "section": meta.get("section"),
                "source_type": meta.get("source_type") or meta.get("source"),
                "image_url": (
                    meta.get("image_url")
                    or meta.get("image")
                    or meta.get("thumbnail_url")
                    or meta.get("media_url")
                ),
                "extra_metadata": {k: v for k, v in meta.items() if k not in promoted},
            }
        )

    return results
backend/services/models.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ from fastapi import HTTPException
4
+
5
+ from models.llama_3_8b import Llama3_8B
6
+ from models.mistral_7b import Mistral_7b
7
+ from models.qwen_2_5 import Qwen2_5
8
+ from models.deepseek_v3 import DeepSeek_V3
9
+ from models.tiny_aya import TinyAya
10
+
11
+ # Model definitions (keys mirror the classes under /models).
12
+ # resolve_model also accepts lowercase short aliases for these keys.
13
+
14
+
15
def build_models(hf_token: str) -> dict[str, Any]:
    """Instantiate every contestant model, keyed by its public API name.

    Args:
        hf_token: Hugging Face token passed to each model wrapper.
    """
    return {
        "Llama-3-8B": Llama3_8B(token=hf_token),
        "Mistral-7B": Mistral_7b(token=hf_token),
        "Qwen-2.5": Qwen2_5(token=hf_token),
        "DeepSeek-V3": DeepSeek_V3(token=hf_token),
        "TinyAya": TinyAya(token=hf_token),
    }
23
+
24
+
25
def resolve_model(name: str, models: dict[str, Any]) -> tuple[str, Any]:
    """Resolve a model name or lowercase alias to (canonical_name, instance).

    Raises:
        HTTPException 400: the name matches no key in ``models``.
    """
    alias_to_key = {
        "llama": "Llama-3-8B",
        "mistral": "Mistral-7B",
        "qwen": "Qwen-2.5",
        "deepseek": "DeepSeek-V3",
        "tinyaya": "TinyAya",
    }
    resolved = alias_to_key.get(name.lower(), name)
    if resolved in models:
        return resolved, models[resolved]
    allowed = ", ".join(models.keys())
    raise HTTPException(status_code=400, detail=f"Unknown model '{name}'. Use one of: {allowed}")
backend/services/startup.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from typing import Any
4
+
5
+ from dotenv import load_dotenv
6
+ from huggingface_hub import InferenceClient
7
+
8
+ from config_loader import cfg
9
+ from data.vector_db import get_index_by_name
10
+ from retriever.generator import RAGGenerator
11
+ from retriever.processor import ChunkProcessor
12
+ from retriever.retriever import HybridRetriever
13
+
14
+ from backend.services.cache import get_cache_settings, load_cached_chunks
15
+ from backend.services.models import build_models
16
+ from backend.services.title import parse_title_model_candidates
17
+
18
+
19
+ # Main entry point for initializing the runtime: builds the pipeline
20
+ # objects (retriever, generator, models) and stores them in the shared
21
+ # state dict for the route handlers.
22
+
23
def initialize_runtime_state(state: dict[str, Any]) -> None:
    """Build every long-lived pipeline object and store it in ``state``.

    Called once from the FastAPI startup hook. Populates the keys: index,
    retriever, rag_engine, models, chunk_lookup, title_model_ids,
    title_client. Each stage is individually timed and summarized at the end.

    Raises:
        RuntimeError: PINECONE_API_KEY or HF_TOKEN missing, or the index
            holds no chunk metadata (indexing has not been run yet).
    """
    startup_start = time.perf_counter()

    # Load .env so local dev picks up HF_TOKEN / PINECONE_API_KEY.
    dotenv_start = time.perf_counter()
    load_dotenv()
    dotenv_time = time.perf_counter() - dotenv_start

    env_start = time.perf_counter()
    hf_token = os.getenv("HF_TOKEN")
    pinecone_api_key = os.getenv("PINECONE_API_KEY")
    env_time = time.perf_counter() - env_start

    if not pinecone_api_key:
        raise RuntimeError("PINECONE_API_KEY not found in environment variables")
    if not hf_token:
        raise RuntimeError("HF_TOKEN not found in environment variables")

    # NOTE(review): index name is hard-coded here rather than derived from
    # config.yaml's vector_db.base_index_name -- confirm this is intentional.
    index_name = "cbt-book-recursive"
    embed_model_name = cfg.processing.get("embedding_model", "all-MiniLM-L6-v2")
    rerank_model_name = os.getenv(
        "RERANK_MODEL_NAME",
        cfg.retrieval.get("rerank_model", "mixedbread-ai/mxbai-rerank-base-v1"),
    )
    cache_dir, force_cache_refresh = get_cache_settings()

    # Connect to the existing Pinecone index.
    index_start = time.perf_counter()
    index = get_index_by_name(api_key=pinecone_api_key, index_name=index_name)
    index_time = time.perf_counter() - index_start

    # Load chunk metadata, preferring the local cache over a full index scan.
    chunks_start = time.perf_counter()
    final_chunks, chunk_source = load_cached_chunks(
        index=index,
        index_name=index_name,
        cache_dir=cache_dir,
        force_cache_refresh=force_cache_refresh,
    )
    chunk_load_time = time.perf_counter() - chunks_start

    if not final_chunks:
        raise RuntimeError("No chunks found in Pinecone metadata. Run indexing once before API mode.")

    # Encoder only; HF embedding loading is skipped to keep startup light.
    processor_start = time.perf_counter()
    proc = ChunkProcessor(model_name=embed_model_name, verbose=False, load_hf_embeddings=False)
    processor_time = time.perf_counter() - processor_start

    retriever_start = time.perf_counter()
    retriever = HybridRetriever(
        final_chunks,
        proc.encoder,
        rerank_model_name=rerank_model_name,
        verbose=False,
    )
    retriever_time = time.perf_counter() - retriever_start

    rag_start = time.perf_counter()
    rag_engine = RAGGenerator()
    rag_time = time.perf_counter() - rag_start

    models_start = time.perf_counter()
    models = build_models(hf_token)
    models_time = time.perf_counter() - models_start

    # Build a text -> metadata map so responses can cite sources.
    # First occurrence wins for duplicate chunk texts; the text itself is
    # dropped from the value to avoid storing it twice.
    state_start = time.perf_counter()
    chunk_lookup: dict[str, dict[str, Any]] = {}
    for chunk in final_chunks:
        metadata = chunk.get("metadata", {})
        text = metadata.get("text")
        if not text or text in chunk_lookup:
            continue
        meta_without_text = {k: v for k, v in metadata.items() if k != "text"}
        meta_without_text["title"] = metadata.get("title", "Untitled")
        meta_without_text["url"] = metadata.get("url", "")
        meta_without_text["chunk_index"] = metadata.get("chunk_index")
        chunk_lookup[text] = meta_without_text

    state["index"] = index
    state["retriever"] = retriever
    state["rag_engine"] = rag_engine
    state["models"] = models
    state["chunk_lookup"] = chunk_lookup
    state["title_model_ids"] = parse_title_model_candidates()
    state["title_client"] = InferenceClient(token=hf_token)
    state_time = time.perf_counter() - state_start

    startup_time = time.perf_counter() - startup_start
    # One-line startup summary with every stage's duration.
    print(
        f"API startup complete | chunks={len(final_chunks)} | "
        f"dotenv={dotenv_time:.3f}s | "
        f"env={env_time:.3f}s | "
        f"index={index_time:.3f}s | "
        f"cache_dir={cache_dir} | "
        f"force_cache_refresh={force_cache_refresh} | "
        f"chunk_source={chunk_source} | "
        f"chunk_load={chunk_load_time:.3f}s | "
        f"processor={processor_time:.3f}s | "
        f"rerank_model={rerank_model_name} | "
        f"retriever={retriever_time:.3f}s | "
        f"rag={rag_time:.3f}s | "
        f"models={models_time:.3f}s | "
        f"state={state_time:.3f}s | "
        f"total={startup_time:.3f}s"
    )
backend/services/streaming.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import Any
3
+
4
+ # NDJSON is needed for streaming responses; this helper converts a dict to one NDJSON line.
5
+
6
def to_ndjson(payload: dict[str, Any]) -> str:
    """Serialize *payload* as one newline-terminated NDJSON record (UTF-8 kept literal)."""
    record = json.dumps(payload, ensure_ascii=False)
    return f"{record}\n"
backend/services/title.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+
4
+ from huggingface_hub import InferenceClient
5
+
6
+ # the functions for resolving and generating titles
7
+ # it tries to query and hf model for title
8
+
9
+ # some shitty fallback logic, if models fail
10
+ # could improve candidate model defining code
11
+
12
def title_from_query(query: str) -> str:
    """Derive a short rule-based chat title from the raw user query.

    Keeps up to six non-stop-word tokens (stripped of -_/+ punctuation),
    title-cases all-lowercase words, and caps the result at 80 characters.
    Returns "New Chat" when nothing usable remains.
    """
    stop_words = {
        "a", "an", "and", "are", "as", "at", "be", "by", "can", "do", "for", "from", "how",
        "i", "in", "is", "it", "me", "my", "of", "on", "or", "please", "show", "tell", "that",
        "the", "this", "to", "we", "what", "when", "where", "which", "why", "with", "you", "your",
    }

    tokens = re.findall(r"[A-Za-z0-9][A-Za-z0-9\-_/+]*", query)
    if not tokens:
        return "New Chat"

    keep: list[str] = []
    for token in tokens:
        stripped = token.strip("-_/+")
        if not stripped or stripped.lower() in stop_words:
            continue
        keep.append(stripped)
        if len(keep) >= 6:
            break

    # Fall back to the first six raw tokens if everything was a stop word.
    selected = keep if keep else tokens[:6]
    cased = [word.capitalize() if word.islower() else word for word in selected]
    title = " ".join(cased).strip()
    return title[:80] if title else "New Chat"
38
+
39
+
40
def clean_title_text(raw: str) -> str:
    """Normalize a model-generated title.

    Strips surrounding quotes/backticks/whitespace, collapses internal
    whitespace, and caps the result at 8 words and 80 characters.
    """
    candidate = (raw or "").strip().replace("\n", " ").replace("\r", " ")
    candidate = re.sub(r"^[\"'`\s]+|[\"'`\s]+$", "", candidate)
    candidate = re.sub(r"\s+", " ", candidate).strip()
    pieces = candidate.split()
    if len(pieces) > 8:
        candidate = " ".join(pieces[:8])
    return candidate[:80]
49
+
50
+
51
+ def title_from_hf(query: str, client: InferenceClient, model_id: str) -> str | None:
52
+ system_prompt = (
53
+ "You generate short chat titles. Return only a title, no punctuation at the end, no quotes."
54
+ )
55
+ user_prompt = (
56
+ "Create a concise 3-7 word title for this user request:\n"
57
+ f"{query}"
58
+ )
59
+
60
+ response = client.chat_completion(
61
+ model=model_id,
62
+ messages=[
63
+ {"role": "system", "content": system_prompt},
64
+ {"role": "user", "content": user_prompt},
65
+ ],
66
+ max_tokens=24,
67
+ temperature=0.3,
68
+ )
69
+ if not response or not response.choices:
70
+ return None
71
+
72
+ raw_title = response.choices[0].message.content or ""
73
+ title = clean_title_text(raw_title)
74
+ if not title or title.lower() == "new chat":
75
+ return None
76
+ return title
77
+
78
+
79
def parse_title_model_candidates() -> list[str]:
    """Read the ordered list of candidate title models from TITLE_MODEL_IDS.

    The env var is a comma-separated list; blank entries are dropped. A
    hard-coded default list (and a single-model last resort) keeps the
    result non-empty.
    """
    raw = os.getenv(
        "TITLE_MODEL_IDS",
        "Qwen/Qwen2.5-1.5B-Instruct,CohereLabs/tiny-aya-global,meta-llama/Meta-Llama-3-8B-Instruct",
    )
    candidates = [entry.strip() for entry in raw.split(",") if entry.strip()]
    return candidates if candidates else ["meta-llama/Meta-Llama-3-8B-Instruct"]
backend/state.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
from typing import Any

# Process-wide mutable runtime state, populated once by
# backend.services.startup.initialize_runtime_state() at app startup and
# read by the route handlers.
state: dict[str, Any] = {}

# Keys that must be present before /health reports "ok".
REQUIRED_STATE_KEYS = ("index", "retriever", "rag_engine", "models")
config.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------
2
+ # RAG CBT QUESTION-ANSWERING SYSTEM CONFIGURATION
3
+ # ------------------------------------------------------------------
4
+
5
+ project:
6
+ name: "cbt-rag-system"
7
+ category: "psychology"
8
+ doc_limit: null # Load all pages from the book
9
+
10
+ processing:
11
+ # Embedding model used for both vector db and evaluator similarity
12
+ embedding_model: "jinaai/jina-embeddings-v2-small-en"
13
+ # Options: sentence, recursive, semantic, fixed
14
+ technique: "recursive"
15
+ # Jina supports 8192 tokens (~32k chars), using 1000 chars for better context
16
+ chunk_size: 1000
17
+ chunk_overlap: 100
18
+
19
+ vector_db:
20
+ base_index_name: "cbt-book"
21
+ dimension: 512 # Jina outputs 512 dimensions
22
+ metric: "cosine"
23
+ batch_size: 50 # Reduced batch size for CPU processing
24
+
25
+ retrieval:
26
+ # Options: hybrid, semantic, bm25
27
+ mode: "hybrid"
28
+ # Options: cross-encoder, rrf
29
+ rerank_strategy: "cross-encoder"
30
+ use_mmr: true
31
+ top_k: 10
32
+ final_k: 5
33
+
34
+ generation:
35
+ temperature: 0.0
36
+ max_new_tokens: 512
37
+ # The model used to Judge the others (OpenRouter)
38
+ judge_model: "stepfun/step-3.5-flash:free"
39
+
40
+ # List of contestants in the tournament
41
+ models:
42
+ - "Llama-3-8B"
43
+ - "Mistral-7B"
44
+ - "Qwen-2.5"
45
+ - "DeepSeek-V3"
46
+ - "TinyAya"
config_loader.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ from pathlib import Path
3
+
4
class RAGConfig:
    """Read-only accessor over the project's YAML configuration.

    The file is parsed once at construction; each top-level section is
    exposed as a property so call sites read `cfg.retrieval['mode']`
    instead of indexing the raw dict.
    """

    def __init__(self, config_path="config.yaml"):
        """Load and parse the YAML file at *config_path*.

        Raises:
            FileNotFoundError: if the file does not exist.
            yaml.YAMLError: if the file is not valid YAML.
        """
        # Explicit encoding so the config parses identically on every
        # platform (the default encoding is locale-dependent).
        with open(config_path, 'r', encoding='utf-8') as f:
            self.data = yaml.safe_load(f)

    @property
    def project(self):
        """The `project` section (name, category, doc_limit)."""
        return self.data['project']

    @property
    def processing(self):
        """The `processing` section (embedding model, chunking parameters)."""
        return self.data['processing']

    @property
    def db(self):
        """The `vector_db` section (base index name, dimension, metric)."""
        return self.data['vector_db']

    @property
    def retrieval(self):
        """The `retrieval` section (mode, reranking, top-k settings)."""
        return self.data['retrieval']

    @property
    def gen(self):
        """The `generation` section (temperature, max tokens, judge model)."""
        return self.data['generation']

    @property
    def model_list(self):
        """The list of contestant model names for the tournament."""
        return self.data['models']
26
+
27
+ cfg = RAGConfig()
main.py ADDED
@@ -0,0 +1,659 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import time
4
+ from datetime import datetime
5
+ from multiprocessing import Pool, cpu_count
6
+ from functools import partial
7
+ from dotenv import load_dotenv
8
+ from config_loader import cfg
9
+
10
+ from data.vector_db import get_pinecone_index, refresh_pinecone_index
11
+ from retriever.retriever import HybridRetriever
12
+ from retriever.generator import RAGGenerator
13
+ from retriever.processor import ChunkProcessor
14
+ from retriever.evaluator import RAGEvaluator
15
+ from data.data_loader import load_cbt_book, get_book_stats
16
+ from data.ingest import ingest_data, CHUNKING_TECHNIQUES
17
+
18
+ # Import model fleet
19
+ from models.llama_3_8b import Llama3_8B
20
+ from models.mistral_7b import Mistral_7b
21
+ from models.qwen_2_5 import Qwen2_5
22
+ from models.deepseek_v3 import DeepSeek_V3
23
+ from models.tiny_aya import TinyAya
24
+
25
# Registry mapping config model names to their wrapper classes.
# Keys must match the entries of the `models` list in config.yaml.
MODEL_MAP = {
    "Llama-3-8B": Llama3_8B,
    "Mistral-7B": Mistral_7b,
    "Qwen-2.5": Qwen2_5,
    "DeepSeek-V3": DeepSeek_V3,
    "TinyAya": TinyAya
}

# Load environment variables (HF_TOKEN, PINECONE_API_KEY, ...) from a .env file.
load_dotenv()
34
+
35
+
36
def run_rag_for_technique(technique_name, query, index, encoder, models, evaluator, rag_engine):
    """Run the RAG pipeline (retrieve -> rerank -> generate -> evaluate) for one chunking technique.

    Args:
        technique_name: Chunking technique name; vectors are filtered by this
            value in Pinecone metadata.
        query: User question to answer.
        index: Pinecone index handle.
        encoder: Embedding model exposing ``encode`` (result supports ``tolist``).
        models: Mapping of model display name -> model wrapper instance.
        evaluator: RAGEvaluator used for faithfulness/relevancy scoring.
        rag_engine: RAGGenerator producing answers from context chunks.

    Returns:
        Dict mapping model name to its answer, scores, claims and context used;
        empty dict if no chunks exist for this technique.

    NOTE(review): relies on the module-level ``_worker_reranker`` being loaded
    by ``init_worker`` in this process; calling it outside a worker pool leaves
    the reranker as None and fails at ``predict`` — confirm intended usage.
    """

    print(f"\n{'='*80}")
    print(f"TECHNIQUE: {technique_name.upper()}")
    print(f"{'='*80}")

    # Embed the query once; Pinecone expects a plain list of floats.
    query_vector = encoder.encode(query).tolist()

    # Query with metadata filter for this technique - get more candidates (25)
    # than the final top-5 so the cross-encoder has something to rerank.
    res = index.query(
        vector=query_vector,
        top_k=25,
        include_metadata=True,
        filter={"technique": {"$eq": technique_name}}
    )

    # Extract context chunks with their source URLs from match metadata.
    all_candidates = []
    chunk_urls = []
    for match in res['matches']:
        all_candidates.append(match['metadata']['text'])
        chunk_urls.append(match['metadata'].get('url', ''))

    print(f"\nRetrieved {len(all_candidates)} candidate chunks for technique '{technique_name}'")

    if not all_candidates:
        print(f"WARNING: No chunks found for technique '{technique_name}'")
        return {}

    # Apply cross-encoder reranking to get top 5.
    # Use the global reranker loaded once per worker (see init_worker).
    global _worker_reranker
    pairs = [[query, chunk] for chunk in all_candidates]
    scores = _worker_reranker.predict(pairs)
    ranked = sorted(zip(all_candidates, chunk_urls, scores), key=lambda x: x[2], reverse=True)
    context_chunks = [chunk for chunk, _, _ in ranked[:5]]
    context_urls = [url for _, url, _ in ranked[:5]]

    print(f"After reranking: {len(context_chunks)} chunks (top 5)")

    # Print the final RAG context being passed to models (only once per technique).
    print(f"\n{'='*80}")
    print(f"📚 FINAL RAG CONTEXT FOR TECHNIQUE '{technique_name.upper()}'")
    print(f"{'='*80}")
    for i, chunk in enumerate(context_chunks, 1):
        print(f"\n[Chunk {i}] ({len(chunk)} chars):")
        print(f"{'─'*60}")
        print(chunk)
        print(f"{'─'*60}")
    print(f"\n{'='*80}")

    # Run the model tournament: every contestant answers from the same context.
    tournament_results = {}

    for name, model_inst in models.items():
        print(f"\n{'-'*60}")
        print(f"Model: {name}")
        print(f"{'-'*60}")
        try:
            # Generation
            answer = rag_engine.get_answer(
                model_inst, query, context_chunks,
                context_urls=context_urls,
                temperature=cfg.gen['temperature']
            )

            print(f"\n{'─'*60}")
            print(f"📝 FULL ANSWER from {name}:")
            print(f"{'─'*60}")
            print(answer)
            print(f"{'─'*60}")

            # Faithfulness Evaluation (strict=False reduces API calls from ~22 to ~3 per eval)
            faith = evaluator.evaluate_faithfulness(answer, context_chunks, strict=False)
            # Relevancy Evaluation
            rel = evaluator.evaluate_relevancy(query, answer)

            tournament_results[name] = {
                "answer": answer,
                "Faithfulness": faith['score'],
                "Relevancy": rel['score'],
                "Claims": faith['details'],
                "context_chunks": context_chunks,
                "context_urls": context_urls
            }

            print(f"\n📊 EVALUATION SCORES:")
            print(f"   Faithfulness: {faith['score']:.1f}%")
            print(f"   Relevancy: {rel['score']:.3f}")
            print(f"   Combined: {faith['score'] + rel['score']:.3f}")

        except Exception as e:
            # One model failing must not abort the tournament; record a
            # zero-score entry carrying the error message instead.
            print(f" Error evaluating {name}: {e}")
            tournament_results[name] = {
                "answer": "",
                "Faithfulness": 0,
                "Relevancy": 0,
                "Claims": [],
                "error": str(e),
                "context_chunks": context_chunks,
                "context_urls": context_urls
            }

    return tournament_results
142
+
143
+
144
def generate_findings_document(all_query_results, queries, output_file="rag_ablation_findings.md"):
    """Generate a detailed markdown report aggregating all techniques across all queries.

    Args:
        all_query_results: Dict mapping query index -> {technique -> {model -> results}}.
        queries: List of all test queries (used for the report header).
        output_file: Path the markdown report is written to.

    Returns:
        The path of the written report file (same as *output_file*).
    """

    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    content = f"""# RAG Ablation Study Findings

*Generated:* {timestamp}

## Overview

This document presents findings from a comparative analysis of 6 different chunking techniques
applied to a Cognitive Behavioral Therapy (CBT) book. Each technique was evaluated using
multiple LLM models with RAG (Retrieval-Augmented Generation) pipeline.

## Test Queries

"""

    for i, query in enumerate(queries, 1):
        content += f"{i}. {query}\n"

    content += """
## Chunking Techniques Evaluated

1. *Fixed* - Fixed-size chunking (1000 chars, 100 overlap)
2. *Sentence* - Sentence-level chunking (NLTK)
3. *Paragraph* - Paragraph-level chunking (\\n\\n boundaries)
4. *Semantic* - Semantic chunking (embedding similarity)
5. *Recursive* - Recursive chunking (hierarchical separators)
6. *Page* - Page-level chunking (--- Page markers)

## Results by Technique (Aggregated Across All Queries)

"""

    # Aggregate per-query results into per-technique / per-model score lists.
    aggregated_results = {}

    for query_idx, query_results in all_query_results.items():
        for technique_name, model_results in query_results.items():
            if technique_name not in aggregated_results:
                aggregated_results[technique_name] = {}

            for model_name, results in model_results.items():
                if model_name not in aggregated_results[technique_name]:
                    # Context is stored once, from the first query seen for
                    # this model — it is displayed per technique, not per query.
                    aggregated_results[technique_name][model_name] = {
                        'Faithfulness': [],
                        'Relevancy': [],
                        'answers': [],
                        'context_chunks': results.get('context_chunks', []),
                        'context_urls': results.get('context_urls', [])
                    }

                aggregated_results[technique_name][model_name]['Faithfulness'].append(results.get('Faithfulness', 0))
                aggregated_results[technique_name][model_name]['Relevancy'].append(results.get('Relevancy', 0))
                aggregated_results[technique_name][model_name]['answers'].append(results.get('answer', ''))

    # One report section per technique: score table, best model, context, answers.
    for technique_name, model_results in aggregated_results.items():
        content += f"### {technique_name.upper()} Chunking\n\n"

        if not model_results:
            content += "No results available for this technique.\n\n"
            continue

        # Create results table with averaged scores.
        content += "| Model | Avg Faithfulness | Avg Relevancy | Avg Combined |\n"
        content += "|-------|------------------|---------------|--------------|\n"

        for model_name, results in model_results.items():
            avg_faith = sum(results['Faithfulness']) / len(results['Faithfulness']) if results['Faithfulness'] else 0
            avg_rel = sum(results['Relevancy']) / len(results['Relevancy']) if results['Relevancy'] else 0
            avg_combined = avg_faith + avg_rel
            content += f"| {model_name} | {avg_faith:.1f}% | {avg_rel:.3f} | {avg_combined:.3f} |\n"

        # Find the best model for this technique by summed average scores.
        if model_results:
            best_model = max(
                model_results.items(),
                key=lambda x: (sum(x[1]['Faithfulness']) / len(x[1]['Faithfulness']) if x[1]['Faithfulness'] else 0) +
                              (sum(x[1]['Relevancy']) / len(x[1]['Relevancy']) if x[1]['Relevancy'] else 0)
            )
            best_name = best_model[0]
            best_faith = sum(best_model[1]['Faithfulness']) / len(best_model[1]['Faithfulness']) if best_model[1]['Faithfulness'] else 0
            best_rel = sum(best_model[1]['Relevancy']) / len(best_model[1]['Relevancy']) if best_model[1]['Relevancy'] else 0

            content += f"\n*Best Model:* {best_name} (Avg Faithfulness: {best_faith:.1f}%, Avg Relevancy: {best_rel:.3f})\n\n"

        # Show context chunks once per technique (not per model) — every model
        # in a technique answered from the same retrieved context.
        context_chunks = None
        context_urls = None
        for model_name, results in model_results.items():
            if results.get('context_chunks'):
                context_chunks = results['context_chunks']
                context_urls = results.get('context_urls', [])
                break

        if context_chunks:
            content += "#### Context Chunks Used\n\n"
            for i, chunk in enumerate(context_chunks, 1):
                url = context_urls[i-1] if context_urls and i-1 < len(context_urls) else ""
                if url:
                    content += f"*Chunk {i}* ([Source]({url})):\n"
                else:
                    content += f"*Chunk {i}*:\n"
                content += f"\n{chunk}\n\n\n"

        # Add detailed RAG results for each model.
        content += "#### Detailed RAG Results\n\n"

        for model_name, results in model_results.items():
            answers = results.get('answers', [])
            avg_faith = sum(results['Faithfulness']) / len(results['Faithfulness']) if results['Faithfulness'] else 0
            avg_rel = sum(results['Relevancy']) / len(results['Relevancy']) if results['Relevancy'] else 0

            content += f"*{model_name}* (Avg Faithfulness: {avg_faith:.1f}%, Avg Relevancy: {avg_rel:.3f})\n\n"

            # Show answers from each query.
            for q_idx, answer in enumerate(answers):
                content += f"📝 *Answer for Query {q_idx + 1}:*\n\n"
                content += f"\n{answer}\n\n\n"

        content += "---\n\n"

    # Add comparative analysis across techniques.
    content += """## Comparative Analysis

### Overall Performance Ranking (Across All Queries)

| Rank | Technique | Avg Faithfulness | Avg Relevancy | Avg Combined |
|------|-----------|------------------|---------------|--------------|
"""

    # Calculate averages for each technique, pooling every model and query.
    technique_averages = {}
    for technique_name, model_results in aggregated_results.items():
        if model_results:
            all_faith = []
            all_rel = []
            for model_name, results in model_results.items():
                all_faith.extend(results['Faithfulness'])
                all_rel.extend(results['Relevancy'])

            avg_faith = sum(all_faith) / len(all_faith) if all_faith else 0
            avg_rel = sum(all_rel) / len(all_rel) if all_rel else 0
            avg_combined = avg_faith + avg_rel
            technique_averages[technique_name] = {
                'faith': avg_faith,
                'rel': avg_rel,
                'combined': avg_combined
            }

    # Sort by combined score, best first.
    sorted_techniques = sorted(
        technique_averages.items(),
        key=lambda x: x[1]['combined'],
        reverse=True
    )

    for rank, (technique_name, averages) in enumerate(sorted_techniques, 1):
        content += f"| {rank} | {technique_name} | {averages['faith']:.1f}% | {averages['rel']:.3f} | {averages['combined']:.3f} |\n"

    content += """
### Key Findings

"""

    if sorted_techniques:
        best_technique = sorted_techniques[0][0]
        worst_technique = sorted_techniques[-1][0]

        content += f"""
1. *Best Performing Technique:* {best_technique}
   - Achieved highest combined score across all models and queries
   - Recommended for production RAG applications

2. *Worst Performing Technique:* {worst_technique}
   - Lower combined scores across models and queries
   - May need optimization or different configuration

3. *Model Consistency:*
   - Analyzed which models perform consistently across techniques
   - Identified technique-specific model preferences

"""

    content += """## Recommendations

Based on the ablation study results:

1. *Primary Recommendation:* Use the best-performing chunking technique for your specific use case
2. *Hybrid Approach:* Consider combining techniques for different types of queries
3. *Model Selection:* Choose models that perform well across multiple techniques
4. *Parameter Tuning:* Optimize chunk sizes and overlaps based on document characteristics

## Technical Details

- *Embedding Model:* Jina embeddings (512 dimensions)
- *Vector Database:* Pinecone (serverless, AWS us-east-1)
- *Judge Model:* Openrouter Free models
- *Retrieval:* Top 5 chunks per technique
- *Evaluation Metrics:* Faithfulness (context grounding), Relevancy (query addressing)

---

This report was automatically generated by the RAG Ablation Study Pipeline.
"""

    # Write the assembled report to disk.
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(content)

    print(f"\nFindings document saved to: {output_file}")
    return output_file
365
+
366
+
367
# Per-process globals for multiprocessing workers. Each worker process gets
# its own copies, populated exactly once by init_worker (the Pool initializer),
# so heavyweight models are not reloaded for every task.
_worker_proc = None
_worker_evaluator = None
_worker_models = None
_worker_rag_engine = None
_worker_reranker = None

def init_worker(model_name, evaluator_config):
    """Pool initializer: load models once per worker process.

    Args:
        model_name: Embedding model name for the ChunkProcessor.
        evaluator_config: Dict with 'judge_model', 'api_key' and 'model_list'
            (plain data only — it must be picklable to cross the process boundary).
    """
    global _worker_proc, _worker_evaluator, _worker_models, _worker_rag_engine, _worker_reranker

    # Imports are local so the worker pays the import cost itself rather than
    # inheriting possibly-unpicklable state from the parent.
    from retriever.processor import ChunkProcessor
    from retriever.evaluator import RAGEvaluator
    from retriever.generator import RAGGenerator
    from sentence_transformers import CrossEncoder
    from models.llama_3_8b import Llama3_8B
    from models.mistral_7b import Mistral_7b
    from models.qwen_2_5 import Qwen2_5
    from models.deepseek_v3 import DeepSeek_V3
    from models.tiny_aya import TinyAya

    # Local registry mirroring the module-level MODEL_MAP.
    MODEL_MAP = {
        "Llama-3-8B": Llama3_8B,
        "Mistral-7B": Mistral_7b,
        "Qwen-2.5": Qwen2_5,
        "DeepSeek-V3": DeepSeek_V3,
        "TinyAya": TinyAya
    }

    # Load embedding model once per worker.
    _worker_proc = ChunkProcessor(model_name=model_name, verbose=False)

    # Initialize evaluator (judge + embedding similarity).
    _worker_evaluator = RAGEvaluator(
        judge_model=evaluator_config['judge_model'],
        embedding_model=_worker_proc.encoder,
        api_key=evaluator_config['api_key']
    )

    # Initialize the contestant models listed in the config.
    hf_token = os.getenv("HF_TOKEN")
    _worker_models = {name: MODEL_MAP[name](token=hf_token) for name in evaluator_config['model_list']}

    # Initialize RAG engine.
    _worker_rag_engine = RAGGenerator()

    # Load cross-encoder reranker once per worker (used by run_rag_for_technique).
    _worker_reranker = CrossEncoder('jinaai/jina-reranker-v1-tiny-en')
415
+
416
+
417
def run_rag_for_technique_wrapper(args):
    """Pool-friendly wrapper: unpack one task tuple and run one technique.

    Args:
        args: Tuple of (technique dict, query string, Pinecone index name,
            Pinecone API key) — a single tuple so it works with Pool.map.

    Returns:
        (technique_name, results_dict); results_dict is empty on failure so one
        failed technique never aborts the whole study.
    """
    global _worker_proc, _worker_evaluator, _worker_models, _worker_rag_engine

    technique, query, index_name, pinecone_key = args
    try:
        # Create a new Pinecone connection inside the worker process —
        # connections are not safely shareable across fork boundaries.
        from data.vector_db import get_index_by_name
        index = get_index_by_name(pinecone_key, index_name)

        return technique['name'], run_rag_for_technique(
            technique_name=technique['name'],
            query=query,
            index=index,
            encoder=_worker_proc.encoder,
            models=_worker_models,
            evaluator=_worker_evaluator,
            rag_engine=_worker_rag_engine
        )
    except Exception as e:
        # Log full traceback but return an empty result instead of raising,
        # so Pool.map still yields an entry for every technique.
        import traceback
        print(f"\n✗ Error processing technique {technique['name']}: {e}")
        print(f"Full traceback:")
        traceback.print_exc()
        return technique['name'], {}
442
+
443
+
444
def main():
    """Run the full RAG ablation study across all 6 chunking techniques.

    Pipeline: verify env vars -> ensure data is ingested -> initialize models
    and evaluator -> run each test query against every technique in parallel
    worker processes -> write a markdown findings report.

    Returns:
        Dict mapping query index -> {technique -> {model -> results}}.

    Raises:
        RuntimeError: if HF_TOKEN, PINECONE_API_KEY or OPENROUTER_API_KEY
            is missing from the environment.
    """
    hf_token = os.getenv("HF_TOKEN")
    pinecone_key = os.getenv("PINECONE_API_KEY")
    openrouter_key = os.getenv("OPENROUTER_API_KEY")

    # Verify environment variables — fail fast before loading any models.
    if not hf_token:
        raise RuntimeError("HF_TOKEN not found in environment variables")
    if not pinecone_key:
        raise RuntimeError("PINECONE_API_KEY not found in environment variables")
    if not openrouter_key:
        raise RuntimeError("OPENROUTER_API_KEY not found in environment variables")

    # Test queries
    test_queries = [
        "What is cognitive behavior therapy and how does it work?",
        "What are the common cognitive distortions in CBT?",
        "How does CBT help with anxiety and depression?"
    ]

    print("=" * 80)
    print("RAG ABLATION STUDY - 6 CHUNKING TECHNIQUES")
    print("=" * 80)
    print(f"\nTest Queries:")
    for i, q in enumerate(test_queries, 1):
        print(f"  {i}. {q}")

    # Step 1: Check if data already exists, skip ingestion if so.
    print("\n" + "=" * 80)
    print("STEP 1: CHECKING/INGESTING DATA WITH ALL 6 TECHNIQUES")
    print("=" * 80)

    # Check if index already has data.
    from data.vector_db import get_index_by_name
    index_name = f"{cfg.db['base_index_name']}-{cfg.processing['technique']}"

    print(f"\nChecking for existing index: {index_name}")

    try:
        # Try to connect to existing index.
        print("Connecting to Pinecone...")
        existing_index = get_index_by_name(pinecone_key, index_name)
        print("Getting index stats...")
        stats = existing_index.describe_index_stats()
        existing_count = stats.get('total_vector_count', 0)

        if existing_count > 0:
            print(f"\n✓ Found existing index with {existing_count} vectors")
            print("Skipping ingestion - using existing data")

            # Initialize processor (this loads the embedding model).
            print("Loading embedding model for retrieval...")
            from retriever.processor import ChunkProcessor
            proc = ChunkProcessor(model_name=cfg.processing['embedding_model'], verbose=False)
            index = existing_index
            all_chunks = []  # Empty since we're using existing data
            final_chunks = []
            print("✓ Processor initialized")
        else:
            print("\nIndex exists but is empty. Running full ingestion...")
            all_chunks, final_chunks, proc, index = ingest_data()
    except Exception as e:
        # Any failure during the check (missing index, auth, network) falls
        # back to a full ingestion rather than aborting the study.
        print(f"\nIndex check failed: {e}")
        print("Running full ingestion...")
        all_chunks, final_chunks, proc, index = ingest_data()

    print(f"\nTechniques to evaluate: {[tech['name'] for tech in CHUNKING_TECHNIQUES]}")

    # Step 2: Initialize components.
    print("\n" + "=" * 80)
    print("STEP 2: INITIALIZING COMPONENTS")
    print("=" * 80)

    # Initialize models.
    print("\nInitializing models...")
    rag_engine = RAGGenerator()
    models = {name: MODEL_MAP[name](token=hf_token) for name in cfg.model_list}

    # Initialize evaluator.
    print("Initializing evaluator...")
    # NOTE(review): this re-check duplicates the guard at the top of main();
    # it can never fire here.
    if not openrouter_key:
        raise RuntimeError("OPENROUTER_API_KEY not found in environment variables")

    evaluator = RAGEvaluator(
        judge_model=cfg.gen['judge_model'],
        embedding_model=proc.encoder,
        api_key=openrouter_key
    )

    # Step 3: Run RAG for all techniques in parallel for all queries.
    print("\n" + "=" * 80)
    print("STEP 3: RUNNING RAG FOR ALL 6 TECHNIQUES (IN PARALLEL)")
    print("=" * 80)

    # One worker per technique, capped at the machine's core count.
    num_processes = min(cpu_count(), len(CHUNKING_TECHNIQUES))
    print(f"\nUsing {num_processes} parallel processes for {len(CHUNKING_TECHNIQUES)} techniques")

    # Plain-data config passed to each worker's initializer (must be picklable).
    evaluator_config = {
        'judge_model': cfg.gen['judge_model'],
        'api_key': openrouter_key,
        'model_list': cfg.model_list
    }

    all_query_results = {}

    for query_idx, query in enumerate(test_queries):
        print(f"\n{'='*80}")
        print(f"PROCESSING QUERY {query_idx + 1}/{len(test_queries)}")
        print(f"Query: {query}")
        print(f"{'='*80}")

        # A fresh pool per query; init_worker loads models once per worker.
        with Pool(
            processes=num_processes,
            initializer=init_worker,
            initargs=(cfg.processing['embedding_model'], evaluator_config)
        ) as pool:
            args_list = [
                (technique, query, index_name, pinecone_key)
                for technique in CHUNKING_TECHNIQUES
            ]
            results_list = pool.map(run_rag_for_technique_wrapper, args_list)

        # Convert results to dictionary and store.
        query_results = {name: results for name, results in results_list}
        all_query_results[query_idx] = query_results

        # Print quick summary for this query.
        print(f"\n{'='*80}")
        print(f"QUERY {query_idx + 1} SUMMARY")
        print(f"{'='*80}")
        print(f"\n{'Technique':<15} {'Avg Faith':>12} {'Avg Rel':>12} {'Best Model':<20}")
        print("-" * 60)

        for technique_name, model_results in query_results.items():
            if model_results:
                avg_faith = sum(r.get('Faithfulness', 0) for r in model_results.values()) / len(model_results)
                avg_rel = sum(r.get('Relevancy', 0) for r in model_results.values()) / len(model_results)

                # Find best model by summed scores.
                best_model = max(
                    model_results.items(),
                    key=lambda x: x[1].get('Faithfulness', 0) + x[1].get('Relevancy', 0)
                )
                best_name = best_model[0]

                print(f"{technique_name:<15} {avg_faith:>11.1f}% {avg_rel:>12.3f} {best_name:<20}")
            else:
                print(f"{technique_name:<15} {'N/A':>12} {'N/A':>12} {'N/A':<20}")

        print("-" * 60)

    # Step 4: Generate findings document from all queries.
    print("\n" + "=" * 80)
    print("STEP 4: GENERATING FINDINGS DOCUMENT")
    print("=" * 80)

    findings_file = generate_findings_document(all_query_results, test_queries)

    # Step 5: Final summary.
    print("\n" + "=" * 80)
    print("ABLATION STUDY COMPLETE - SUMMARY")
    print("=" * 80)

    print(f"\nQueries processed: {len(test_queries)}")
    print(f"Techniques evaluated: {len(CHUNKING_TECHNIQUES)}")
    print(f"Models tested: {len(cfg.model_list)}")
    print(f"\nFindings document: {findings_file}")

    # Print final summary across all queries.
    print("\n" + "-" * 60)
    print(f"{'Technique':<15} {'Avg Faith':>12} {'Avg Rel':>12} {'Best Model':<20}")
    print("-" * 60)

    # Calculate averages across all queries for each technique.
    for tech_config in CHUNKING_TECHNIQUES:
        tech_name = tech_config['name']
        all_faith = []
        all_rel = []
        best_model_name = None
        best_combined = 0

        for query_idx, query_results in all_query_results.items():
            if tech_name in query_results and query_results[tech_name]:
                model_results = query_results[tech_name]
                for model_name, results in model_results.items():
                    faith = results.get('Faithfulness', 0)
                    rel = results.get('Relevancy', 0)
                    combined = faith + rel
                    all_faith.append(faith)
                    all_rel.append(rel)

                    # Track the single best (model, query) combination seen.
                    if combined > best_combined:
                        best_combined = combined
                        best_model_name = model_name

        if all_faith:
            avg_faith = sum(all_faith) / len(all_faith)
            avg_rel = sum(all_rel) / len(all_rel)
            print(f"{tech_name:<15} {avg_faith:>11.1f}% {avg_rel:>12.3f} {best_model_name or 'N/A':<20}")
        else:
            print(f"{tech_name:<15} {'N/A':>12} {'N/A':>12} {'N/A':<20}")

    print("-" * 60)

    print("\n✓ Ablation study complete!")
    print(f"✓ Results saved to: {findings_file}")
    print("\nYou can now analyze the findings document to compare chunking techniques.")

    return all_query_results
656
+
657
+
658
+ if __name__ == "__main__":
659
+ main()
main_easy.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from dotenv import load_dotenv
4
+ from config_loader import cfg
5
+
6
+ # Optimized imports - only what we need for Retrieval and Generation
7
+ from data.vector_db import get_index_by_name, load_chunks_from_pinecone # Using the new helper
8
+ from retriever.retriever import HybridRetriever
9
+ from retriever.generator import RAGGenerator
10
+ from retriever.processor import ChunkProcessor
11
+ from retriever.evaluator import RAGEvaluator
12
+
13
+ # Model Fleet
14
+ from models.llama_3_8b import Llama3_8B
15
+ from models.mistral_7b import Mistral_7b
16
+ from models.qwen_2_5 import Qwen2_5
17
+ from models.deepseek_v3 import DeepSeek_V3
18
+ from models.tiny_aya import TinyAya
19
+
20
# Registry mapping config model names to their wrapper classes.
# Keys must match the entries of the `models` list in config.yaml.
MODEL_MAP = {
    "Llama-3-8B": Llama3_8B,
    "Mistral-7B": Mistral_7b,
    "Qwen-2.5": Qwen2_5,
    "DeepSeek-V3": DeepSeek_V3,
    "TinyAya": TinyAya
}

# Load environment variables (HF_TOKEN, PINECONE_API_KEY, ...) from a .env file.
load_dotenv()
29
+
30
def main():
    """Lightweight single-query RAG tournament against an existing index.

    Unlike main.py this skips ingestion entirely: it connects to an already
    populated Pinecone index, loads the BM25 corpus from its metadata, runs
    hybrid retrieval for one hard-coded query and scores each model's answer.
    """
    hf_token = os.getenv("HF_TOKEN")
    pinecone_key = os.getenv("PINECONE_API_KEY")
    query = "How do transformers handle long sequences?"

    # 1. Connect to Existing Index (No creation, no uploading)
    # We use the slugified name directly or via config.
    index_name = f"{cfg.db['base_index_name']}-{cfg.processing['technique']}"
    index = get_index_by_name(pinecone_key, index_name)

    # 2. Setup Processor (Required for the Encoder/Embedding model)
    proc = ChunkProcessor(model_name=cfg.processing['embedding_model'])

    # 3. Load BM25 Corpus (The "Source of Truth")
    # This replaces the entire data_loader/chunking block.
    # Note: On first run, this hits Pinecone. Use a pickle cache here for 0s delay.
    print("🔄 Loading BM25 context from Pinecone metadata...")
    final_chunks = load_chunks_from_pinecone(index)

    # 4. Retrieval Setup
    retriever = HybridRetriever(final_chunks, proc.encoder)

    print(f"🔎 Searching via {cfg.retrieval['mode']} mode...")
    context_chunks = retriever.search(
        query, index,
        mode=cfg.retrieval['mode'],
        rerank_strategy=cfg.retrieval['rerank_strategy'],
        use_mmr=cfg.retrieval['use_mmr'],
        top_k=cfg.retrieval['top_k'],
        final_k=cfg.retrieval['final_k']
    )

    # 5. Initialization of Contestants
    rag_engine = RAGGenerator()
    models = {name: MODEL_MAP[name](token=hf_token) for name in cfg.model_list}

    # NOTE(review): main.py builds the evaluator with OPENROUTER_API_KEY,
    # while this script uses GROQ_API_KEY — confirm which key RAGEvaluator
    # actually expects here.
    evaluator = RAGEvaluator(
        judge_model=cfg.gen['judge_model'],
        embedding_model=proc.encoder,
        api_key=os.getenv("GROQ_API_KEY")
    )

    tournament_results = {}

    # 6. Tournament Loop — each model answers from the same retrieved context.
    for name, model_inst in models.items():
        print(f"\n🏆 Tournament: {name} is generating...")
        try:
            # Generation
            answer = rag_engine.get_answer(
                model_inst, query, context_chunks,
                temperature=cfg.gen['temperature']
            )

            # Faithfulness Evaluation
            faith = evaluator.evaluate_faithfulness(answer, context_chunks)
            # Relevancy Evaluation
            rel = evaluator.evaluate_relevancy(query, answer)

            tournament_results[name] = {
                "Answer": answer[:100] + "...",  # Preview
                "Faithfulness": faith['score'],
                "Relevancy": rel['score']
            }
            print(f"✅ {name} Score - Faith: {faith['score']} | Rel: {rel['score']}")

        except Exception as e:
            # A failed model is simply omitted from the standings.
            print(f"❌ Error evaluating {name}: {e}")

    print("\n--- Final Tournament Standings ---")
    for name, scores in tournament_results.items():
        print(f"{name}: F={scores['Faithfulness']}, R={scores['Relevancy']}")
102
+
103
+ if __name__ == "__main__":
104
+ main()
models/deepseek_v3.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import InferenceClient
2
+
3
class DeepSeek_V3:
    """Streaming chat wrapper around the hosted DeepSeek-V3 model."""

    def __init__(self, token):
        self.client = InferenceClient(token=token)
        self.model_id = "deepseek-ai/DeepSeek-V3"

    def generate_stream(self, prompt, max_tokens=1500, temperature=0.1):
        """Yield answer fragments as they arrive; on failure yield one error string."""
        try:
            events = self.client.chat_completion(
                model=self.model_id,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=temperature,
                stream=True,
            )
            for event in events:
                if not event.choices:
                    continue
                fragment = event.choices[0].delta.content
                if fragment:
                    yield fragment
        except Exception as exc:
            # Best-effort: surface the failure as text instead of raising.
            yield f" DeepSeek API Busy: {exc}"

    def generate(self, prompt, max_tokens=500, temperature=0.1):
        """Blocking variant: concatenate all streamed fragments into one string."""
        pieces = self.generate_stream(prompt, max_tokens=max_tokens, temperature=temperature)
        return "".join(pieces)
models/llama_3_8b.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import InferenceClient
2
+
3
class Llama3_8B:
    """Streaming chat wrapper around Meta-Llama-3-8B-Instruct on HF Inference."""

    def __init__(self, token):
        self.client = InferenceClient(token=token)
        self.model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

    def generate_stream(self, prompt, max_tokens=1500, temperature=0.1):
        """Yield answer fragments as the API streams them back.

        API errors propagate to the caller (no swallowing here).
        """
        events = self.client.chat_completion(
            model=self.model_id,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=temperature,
            stream=True,
        )
        for event in events:
            if not event.choices:
                continue
            fragment = event.choices[0].delta.content
            if fragment:
                yield fragment

    def generate(self, prompt, max_tokens=500, temperature=0.1):
        """Blocking variant: concatenate all streamed fragments into one string."""
        pieces = self.generate_stream(prompt, max_tokens=max_tokens, temperature=temperature)
        return "".join(pieces)
models/mistral_7b.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import InferenceClient
2
+ import os
3
+
4
class Mistral_7b:
    """Streaming chat wrapper for a Mistral instruct model via the HF router."""

    def __init__(self, token):
        self.client = InferenceClient(api_key=token)
        # Provider-suffixed ids (e.g. :featherless-ai) are not valid HF repo ids.
        # Keep a sane default and allow override via env for experimentation.
        self.model_id = os.getenv("MISTRAL_MODEL_ID", "mistralai/Mistral-7B-Instruct-v0.2")

    def generate_stream(self, prompt, max_tokens=1500, temperature=0.1):
        """Yield content deltas from the chat stream; on failure yield one error string."""
        try:
            response_stream = self.client.chat.completions.create(
                model=self.model_id,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=temperature,
                stream=True,
            )
            for part in response_stream:
                if not part.choices:
                    continue
                delta_text = part.choices[0].delta.content
                if delta_text:
                    yield delta_text
        except Exception as e:
            # Best-effort: report the failure inline instead of raising.
            yield f" Mistral Featherless Error: {e}"

    def generate(self, prompt, max_tokens=500, temperature=0.1):
        """Blocking variant: join the streamed deltas into one string."""
        return "".join(self.generate_stream(prompt, max_tokens=max_tokens, temperature=temperature))
models/qwen_2_5.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import InferenceClient
2
+
3
class Qwen2_5:
    """Streaming chat wrapper around Qwen2.5-72B-Instruct on the HF Inference API."""

    def __init__(self, token):
        self.client = InferenceClient(token=token)
        self.model_id = "Qwen/Qwen2.5-72B-Instruct"

    def generate_stream(self, prompt, max_tokens=1500, temperature=0.1):
        """Yield response fragments as they arrive.

        Fix: the sibling wrappers (DeepSeek_V3, Mistral_7b, TinyAya) all report
        API failures inline as a single yielded error string; this one previously
        let the exception propagate, killing the stream mid-response. Made
        consistent with the rest of the model wrappers.
        """
        try:
            for message in self.client.chat_completion(
                model=self.model_id,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=temperature,
                stream=True,
            ):
                if message.choices:
                    content = message.choices[0].delta.content
                    if content:
                        yield content
        except Exception as e:
            # Best-effort error reporting, mirroring the other wrappers.
            yield f" Qwen API Error: {e}"

    def generate(self, prompt, max_tokens=500, temperature=0.1):
        """Blocking variant: concatenate all streamed fragments."""
        return "".join(self.generate_stream(prompt, max_tokens=max_tokens, temperature=temperature))
models/tiny_aya.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import InferenceClient
2
+
3
class TinyAya:
    """Streaming chat wrapper around CohereLabs/tiny-aya-global on the HF Inference API."""

    def __init__(self, token):
        self.client = InferenceClient(token=token)
        self.model_id = "CohereLabs/tiny-aya-global"

    def generate_stream(self, prompt, max_tokens=1500, temperature=0.1):
        """Yield content deltas; on failure yield a single error string."""
        try:
            events = self.client.chat_completion(
                model=self.model_id,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens,
                temperature=temperature,
                stream=True,
            )
            for event in events:
                if not event.choices:
                    continue
                fragment = event.choices[0].delta.content
                if fragment:
                    yield fragment
        except Exception as e:
            # Deliberate best-effort: surface the failure as output text.
            yield f" TinyAya Error: {e}"

    def generate(self, prompt, max_tokens=500, temperature=0.1):
        """Blocking variant: concatenate all streamed fragments."""
        return "".join(self.generate_stream(prompt, max_tokens=max_tokens, temperature=temperature))
requirements.txt ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiohappyeyeballs==2.6.1
2
+ aiohttp==3.13.3
3
+ aiosignal==1.4.0
4
+ annotated-doc==0.0.4
5
+ annotated-types==0.7.0
6
+ anyio==4.12.1
7
+ arxiv==2.4.1
8
+ attrs==26.1.0
9
+ certifi==2026.2.25
10
+ charset-normalizer==3.4.6
11
+ click==8.3.1
12
+ colorama==0.4.6
13
+ dataclasses-json==0.6.7
14
+ feedparser==6.0.12
15
+ fastapi==0.121.1
16
+ filelock==3.25.2
17
+ frozenlist==1.8.0
18
+ fsspec==2026.2.0
19
+ greenlet==3.3.2
20
+ h11==0.16.0
21
+ hf-xet==1.4.2
22
+ httpcore==1.0.9
23
+ httpx==0.28.1
24
+ httpx-sse==0.4.3
25
+ huggingface_hub==0.36.0
26
+ idna==3.11
27
+ Jinja2==3.1.6
28
+ joblib==1.5.3
29
+ jsonpatch==1.33
30
+ jsonpointer==3.1.1
31
+ langchain-classic==1.0.3
32
+ langchain-community==0.4.1
33
+ langchain-core==1.2.21
34
+ langchain-experimental==0.4.1
35
+ langchain-huggingface==1.2.1
36
+ langchain-text-splitters==1.1.1
37
+ langsmith==0.7.22
38
+ markdown-it-py==4.0.0
39
+ MarkupSafe==3.0.3
40
+ marshmallow==3.26.2
41
+ mdurl==0.1.2
42
+ mpmath==1.3.0
43
+ multidict==6.7.1
44
+ mypy_extensions==1.1.0
45
+ networkx==3.6.1
46
+ nltk==3.9.4
47
+ numpy==2.4.3
48
+ orjson==3.11.7
49
+ packaging==24.2
50
+ pandas==3.0.1
51
+ pinecone==8.1.0
52
+ pinecone-plugin-assistant==3.0.2
53
+ pinecone-plugin-interface==0.0.7
54
+ propcache==0.4.1
55
+ pydantic==2.12.5
56
+ pydantic-settings==2.13.1
57
+ pydantic_core==2.41.5
58
+ Pygments==2.19.2
59
+ PyMuPDF==1.27.2.2
60
+ python-dateutil==2.9.0.post0
61
+ python-dotenv==1.2.2
62
+ PyYAML==6.0.3
63
+ rank-bm25==0.2.2
64
+ regex==2026.2.28
65
+ requests==2.32.5
66
+ requests-toolbelt==1.0.0
67
+ rich==14.3.3
68
+ safetensors==0.7.0
69
+ scikit-learn==1.8.0
70
+ scipy==1.17.1
71
+ sentence-transformers==5.3.0
72
+ setuptools==81.0.0
73
+ sgmllib3k==1.0.0
74
+ shellingham==1.5.4
75
+ six==1.17.0
76
+ SQLAlchemy==2.0.48
77
+ sympy==1.14.0
78
+ tenacity==9.1.4
79
+ threadpoolctl==3.6.0
80
+ tokenizers==0.22.2
81
+ torch==2.11.0
82
+ tqdm==4.67.3
83
+ transformers==4.57.1
84
+ typer==0.24.1
85
+ typing-inspect==0.9.0
86
+ typing-inspection==0.4.2
87
+ typing_extensions==4.15.0
88
+ tzdata==2025.3
89
+ urllib3==2.6.3
90
+ uvicorn==0.38.0
91
+ uuid_utils==0.14.1
92
+ xxhash==3.6.0
93
+ yarl==1.23.0
94
+ zstandard==0.25.0
95
+ groq==1.1.2
96
+ jiter==0.13.0
97
+ openai==2.30.0
retriever/evaluator.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import numpy as np
3
+ from sklearn.metrics.pairwise import cosine_similarity
4
+ from openai import OpenAI
5
+ from concurrent.futures import ThreadPoolExecutor, as_completed
6
+
7
+
8
+ # ------------------------------------------------------------------
9
+ # OpenRouter Judge Wrapper
10
+ # ------------------------------------------------------------------
11
+
12
class GroqJudge:
    """Judge-LLM client backed by OpenRouter's OpenAI-compatible endpoint.

    Exposes .generate(prompt) as expected by RAGEvaluator, walking a list of
    fallback models when the primary model is rate limited or unavailable.
    """

    def __init__(self, api_key: str, model: str = "deepseek/deepseek-v3.2"):
        """
        Args:
            api_key: Your OpenRouter API key (https://openrouter.ai)
            model: Primary OpenRouter model id; fallbacks are tried if it fails.
        """
        self.client = OpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=api_key,
        )
        self.model = model

        # Fallback models in order of preference (OpenRouter free models)
        self.fallback_models = [
            "deepseek/deepseek-v3.2",
            "qwen/qwen3.6-plus-preview:free",
            "stepfun/step-3.5-flash:free",
            "nvidia/nemotron-3-super-120b-a12b:free",
            "z-ai/glm-4.5-air:free",
            "nvidia/nemotron-3-nano-30b-a3b:free",
            "arcee-ai/trinity-mini:free",
            "xiaomi/mimo-v2-flash"
        ]

    def generate(self, prompt: str) -> str:
        """Ask the judge, retrying down the fallback chain on retryable errors."""
        # Primary model first, then every fallback that isn't the primary.
        candidates = [self.model] + [m for m in self.fallback_models if m != self.model]
        last_error = None

        for candidate in candidates:
            try:
                response = self.client.chat.completions.create(
                    model=candidate,
                    messages=[{"role": "user", "content": prompt}],
                )
                content = response.choices[0].message.content
                if content is None:
                    raise ValueError(f"Model {candidate} returned None content")
                return content.strip()
            except Exception as e:
                last_error = e
                reason = str(e).lower()
                # Rate limits / unavailable models: move on to the next candidate.
                if "429" in str(e) or "rate_limit" in reason or "model" in reason:
                    continue
                # Anything else is a hard failure — re-raise immediately.
                raise

        # Every candidate failed with a retryable error.
        raise last_error
69
+
70
+
71
+ # ------------------------------------------------------------------
72
+ # RAG Evaluator
73
+ # ------------------------------------------------------------------
74
+
75
class RAGEvaluator:
    """LLM-as-judge evaluator for RAG outputs: faithfulness + answer relevancy."""

    def __init__(self, judge_model: str, embedding_model, api_key: str, verbose=True):
        """
        Args:
            judge_model: Model name string passed to GroqJudge (an OpenRouter
                model id, e.g. "stepfun/step-3.5-flash:free").
            embedding_model: SentenceTransformer-style encoder used for
                query-similarity checks in relevancy scoring.
            api_key: OpenRouter API key (https://openrouter.ai)
            verbose: If True, prints progress via internal print helpers.
        """
        self.judge = GroqJudge(api_key=api_key, model=judge_model)
        self.encoder = embedding_model
        self.verbose = verbose

    # ------------------------------------------------------------------
    # 1. FAITHFULNESS: Claim Extraction & Verification
    # ------------------------------------------------------------------

    def evaluate_faithfulness(self, answer: str, context_list: list[str], strict: bool = True) -> dict:
        """Score how well each factual claim in *answer* is supported by *context_list*.

        Args:
            answer: The generated answer to fact-check.
            context_list: Retrieved context chunks to verify claims against.
            strict: If True, verifies each claim against chunks individually
                    (more API calls but catches vague batch verdicts).
                    If False, uses a single batched verification call.

        Returns:
            {"score": 0-100 percent of supported claims, "details": [...]}.
        """
        if self.verbose:
            self._print_extraction_header(len(answer), strict=strict)

        # --- Step A: Extraction ---
        extraction_prompt = (
            "Extract a list of independent factual claims from the following answer.\n"
            "Rules:\n"
            "- Each claim must be specific and verifiable — include numbers, names, or concrete details where present\n"
            "- Vague claims like 'the model performs well' or 'this improves results' are NOT acceptable\n"
            "- Do NOT include claims about what the context does or does not contain\n"
            "- Do NOT include introductory text, numbering, or bullet points\n"
            "- Do NOT rephrase or merge claims\n"
            "- One claim per line only\n\n"
            f"Answer: {answer}"
        )
        raw_claims = self.judge.generate(extraction_prompt)

        # Filter out short lines, preamble, and lines ending with ':'
        claims = [
            c.strip() for c in raw_claims.split('\n')
            if len(c.strip()) > 20 and not c.strip().endswith(':')
        ]

        if not claims:
            return {"score": 0, "details": []}

        # --- Step B: Verification ---
        if strict:
            # Per-chunk: a claim counts as supported if at least one chunk
            # explicitly supports it. Claims are verified in parallel.
            def verify_claim_wrapper(args):
                i, claim = args
                return i, self._verify_claim_against_chunks(claim, context_list)

            with ThreadPoolExecutor(max_workers=min(len(claims), 5)) as executor:
                futures = [executor.submit(verify_claim_wrapper, (i, claim)) for i, claim in enumerate(claims)]
                verdicts = dict(future.result() for future in as_completed(futures))
        else:
            # Batch: all chunks joined, strict burden-of-proof prompt.
            combined_context = "\n".join(context_list)
            # Keep the judge prompt within a safe size.
            if len(combined_context) > 6000:
                combined_context = combined_context[:6000]

            claims_formatted = "\n".join([f"{i+1}. {c}" for i, c in enumerate(claims)])

            batch_prompt = (
                f"Context:\n{combined_context}\n\n"
                f"For each claim, respond YES only if the claim is EXPLICITLY and DIRECTLY "
                f"supported by the context above. Respond NO if the claim is inferred, assumed, "
                f"or not clearly stated in the context.\n\n"
                f"Format strictly as:\n"
                f"1: YES\n"
                f"2: NO\n\n"
                f"Claims:\n{claims_formatted}"
            )
            raw_verdicts = self.judge.generate(batch_prompt)

            # Parse "N: YES/NO" lines back into a 0-indexed verdict map.
            verdicts = {}
            for line in raw_verdicts.split('\n'):
                match = re.match(r'(\d+)\s*:\s*(YES|NO)', line.strip().upper())
                if match:
                    verdicts[int(match.group(1)) - 1] = match.group(2) == "YES"

        # --- Step C: Scoring & Details ---
        # Missing verdicts (judge skipped a claim) default to unsupported.
        verified_count = 0
        details = []
        for i, claim in enumerate(claims):
            is_supported = verdicts.get(i, False)
            if is_supported:
                verified_count += 1
            details.append({
                "claim": claim,
                "verdict": "Supported" if is_supported else "Not Supported"
            })

        score = (verified_count / len(claims)) * 100

        if self.verbose:
            self._print_faithfulness_results(claims, details, score)

        return {"score": score, "details": details}

    def _verify_claim_against_chunks(self, claim: str, context_list: list[str]) -> bool:
        """Verify a single claim against each chunk individually. Returns True if any chunk supports it."""
        # FIX: empty context previously crashed with ThreadPoolExecutor(max_workers=0).
        # No retrieved context means nothing can support the claim.
        if not context_list:
            return False

        def verify_single_chunk(chunk):
            prompt = (
                f"Context:\n{chunk}\n\n"
                f"Claim: {claim}\n\n"
                f"Is this claim EXPLICITLY and DIRECTLY stated in the context above? "
                f"Do not infer or assume. Respond with YES or NO only."
            )
            result = self.judge.generate(prompt)
            return "YES" in result.upper()

        # Verify chunks in parallel; short-circuit on the first supporting chunk.
        with ThreadPoolExecutor(max_workers=min(len(context_list), 5)) as executor:
            futures = [executor.submit(verify_single_chunk, chunk) for chunk in context_list]
            for future in as_completed(futures):
                if future.result():
                    return True
        return False

    # ------------------------------------------------------------------
    # 2. RELEVANCY: Alternate Query Generation
    # ------------------------------------------------------------------

    def evaluate_relevancy(self, query: str, answer: str) -> dict:
        """Score how relevant *answer* is to *query* via reverse-question generation.

        Asks the judge what questions the answer addresses, then measures
        embedding similarity between those questions and the original query.

        Returns:
            {"score": mean cosine similarity, "queries": generated questions}.
        """
        if self.verbose:
            self._print_relevancy_header()

        # --- Step A: Generation ---
        # Explicitly ask the judge NOT to rephrase the original query.
        gen_prompt = (
            f"Generate 3 distinct questions that the following answer addresses.\n"
            f"Rules:\n"
            f"- Do NOT rephrase or repeat this question: '{query}'\n"
            f"- Each question must end with a '?'\n"
            f"- One question per line, no numbering or bullet points\n\n"
            f"Answer: {answer}"
        )
        raw_gen = self.judge.generate(gen_prompt)

        # Filter by length rather than just '?' presence.
        gen_queries = [
            q.strip() for q in raw_gen.split('\n')
            if len(q.strip()) > 10
        ][:3]

        if not gen_queries:
            return {"score": 0, "queries": []}

        # --- Step B: Similarity (single batched encode call) ---
        all_vecs = self.encoder.encode([query] + gen_queries)
        original_vec = all_vecs[0:1]
        generated_vecs = all_vecs[1:]

        similarities = cosine_similarity(original_vec, generated_vecs)[0]
        avg_score = float(np.mean(similarities))

        if self.verbose:
            self._print_relevancy_results(query, gen_queries, similarities, avg_score)

        return {"score": avg_score, "queries": gen_queries}

    # ------------------------------------------------------------------
    # 3. DATASET-LEVEL EVALUATION
    # ------------------------------------------------------------------

    def evaluate_dataset(self, test_cases: list[dict], strict: bool = False) -> dict:
        """
        Runs faithfulness + relevancy over a full test set and aggregates results.

        Args:
            test_cases: List of dicts, each with keys:
                - "query": str
                - "answer": str
                - "contexts": List[str]
            strict: If True, passes strict=True to evaluate_faithfulness
                    (per-chunk verification, more API calls, harder to pass)

        Returns:
            {
                "avg_faithfulness": float,
                "avg_relevancy": float,
                "per_query": List[dict]
            }
        """
        # FIX: np.mean([]) on an empty test set returns nan with a RuntimeWarning.
        if not test_cases:
            return {"avg_faithfulness": 0.0, "avg_relevancy": 0.0, "per_query": []}

        faithfulness_scores = []
        relevancy_scores = []
        per_query = []

        for i, case in enumerate(test_cases):
            if self.verbose:
                print(f"\n{'='*60}")
                print(f"Query {i+1}/{len(test_cases)}: {case['query']}")
                print('='*60)

            f_result = self.evaluate_faithfulness(case['answer'], case['contexts'], strict=strict)
            r_result = self.evaluate_relevancy(case['query'], case['answer'])

            faithfulness_scores.append(f_result['score'])
            relevancy_scores.append(r_result['score'])
            per_query.append({
                "query": case['query'],
                "faithfulness": f_result,
                "relevancy": r_result,
            })

        results = {
            "avg_faithfulness": float(np.mean(faithfulness_scores)),
            "avg_relevancy": float(np.mean(relevancy_scores)),
            "per_query": per_query,
        }

        if self.verbose:
            self._print_dataset_summary(results)

        return results

    # ------------------------------------------------------------------
    # 4. PRINT HELPERS
    # ------------------------------------------------------------------

    def _print_extraction_header(self, length, strict=False):
        mode = "strict per-chunk" if strict else "batch"
        print(f"\n[EVAL] Analyzing Faithfulness ({mode})...")
        print(f"   - Extracting claims from answer ({length} chars)")

    def _print_faithfulness_results(self, claims, details, score):
        print(f"   - Verifying {len(claims)} claims against context...")
        for i, detail in enumerate(details):
            # BUG FIX: the old check `"Yes" in verdict` never matched the
            # "Supported"/"Not Supported" verdict strings, so every claim
            # printed as ❌ regardless of the actual verdict.
            status = "✅" if detail['verdict'] == "Supported" else "❌"
            print(f"     {status} Claim {i+1}: {detail['claim'][:75]}...")
        print(f"   🎯 Faithfulness Score: {score:.1f}%")

    def _print_relevancy_header(self):
        print(f"\n[EVAL] Analyzing Relevancy...")
        print(f"   - Generating 3 distinct questions addressed by the answer")

    def _print_relevancy_results(self, query, gen_queries, similarities, avg):
        print(f"   - Comparing to original query: '{query}'")
        for i, (q, sim) in enumerate(zip(gen_queries, similarities)):
            print(f"     Q{i+1}: {q} (Sim: {sim:.2f})")
        print(f"   🎯 Average Relevancy: {avg:.2f}")

    def _print_dataset_summary(self, results):
        print(f"\n{'='*60}")
        print(f" DATASET EVALUATION SUMMARY")
        print(f"{'='*60}")
        print(f" Avg Faithfulness : {results['avg_faithfulness']:.1f}%")
        print(f" Avg Relevancy    : {results['avg_relevancy']:.2f}")
        print(f" Queries Evaluated: {len(results['per_query'])}")
        print(f"{'='*60}")
retriever/generator.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #changed the prompt to output as markdown, plus some formating details
2
+ #also added get answer stream for incremental token rendering on the frontend
3
+ # --@Qamar
4
# Prompt builder + answer orchestration for the CBT RAG pipeline.
# Outputs markdown, with citation/formatting instructions baked into the prompt;
# also exposes a streaming path for incremental token rendering on the frontend.
# --@Qamar
class RAGGenerator:
    def generate_prompt(self, query, retrieved_contexts, context_urls=None):
        """Build the grounded prompt from retrieved chunks (optionally tagged with URLs)."""
        if context_urls:
            labelled = [
                f"[Source {idx+1}] {link}: {chunk}"
                for idx, (chunk, link) in enumerate(zip(retrieved_contexts, context_urls))
            ]
        else:
            labelled = [f"[Source {idx+1}]: {chunk}" for idx, chunk in enumerate(retrieved_contexts)]
        context_text = "\n\n".join(labelled)

        return f"""You are a specialized Cognitive Behavioral Therapy (CBT) assistant. Your task is to provide accurate, clinical, and structured answers based ONLY on the provided textbook excerpts.

INSTRUCTIONS:
1. Use the provided Sources to answer the question.
2. CITATIONS: You must cite the sources used in your answer (e.g., "CBT is based on the cognitive model [Source 1]").
3. FORMAT: Use clear headers and bullet points for complex explanations.
4. GROUNDING: If the sources do not contain the answer, explicitly state: "The provided excerpts from the textbook do not contain information to answer this specific question." Do not use your own internal knowledge.
5. TONE: Maintain a professional, empathetic, and academic tone.

RETRIVED TEXTBOOK CONTEXT:
{context_text}

USER QUESTION: {query}

ACADEMIC ANSWER (WITH CITATIONS):"""

    def get_answer(self, model_instance, query, retrieved_contexts, context_urls=None, **kwargs):
        """Uses a specific model instance to generate the final answer (blocking)."""
        return model_instance.generate(
            self.generate_prompt(query, retrieved_contexts, context_urls), **kwargs
        )

    def get_answer_stream(self, model_instance, query, retrieved_contexts, context_urls=None, **kwargs):
        """Streams model output token-by-token for incremental UI updates."""
        prompt = self.generate_prompt(query, retrieved_contexts, context_urls)

        streamer = getattr(model_instance, "generate_stream", None)
        if streamer is not None:
            for piece in streamer(prompt, **kwargs):
                if piece:
                    yield piece
            return

        # Fallback for model wrappers that only expose sync generation.
        full_answer = model_instance.generate(prompt, **kwargs)
        if full_answer:
            yield full_answer
retriever/processor.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_text_splitters import (
2
+ RecursiveCharacterTextSplitter,
3
+ CharacterTextSplitter,
4
+ SentenceTransformersTokenTextSplitter,
5
+ NLTKTextSplitter
6
+ )
7
+ from langchain_experimental.text_splitter import SemanticChunker
8
+ from langchain_huggingface import HuggingFaceEmbeddings
9
+ from sentence_transformers import SentenceTransformer
10
+ from typing import List, Dict, Any, Optional
11
+ import nltk
12
+ nltk.download('punkt_tab', quiet=True)
13
+ import pandas as pd
14
+ import re
15
+
16
+
17
class MarkdownTextSplitter:
    """
    Header-aware markdown chunker.

    Walks the header hierarchy (#, ##, ###, ####): a section that fits within
    max_chars becomes one chunk; an oversized section is re-split at the next
    deeper header level. Adjacent small sections are packed together as long
    as the combined size stays within the limit.
    """

    def __init__(self, max_chars: int = 4000):
        self.max_chars = max_chars
        self.headers = ["\n# ", "\n## ", "\n### ", "\n#### "]

    def split_text(self, text: str) -> List[str]:
        """Split text using the markdown header hierarchy, starting at h1."""
        return self._split_by_header(text, 0)

    def _split_by_header(self, content: str, header_level: int) -> List[str]:
        """
        Recursive worker: split *content* at headers[header_level],
        descending one level whenever a section is still too large.
        """
        # Small enough already, or no deeper header level left to try.
        if len(content) <= self.max_chars or header_level >= len(self.headers):
            return [content]

        # Lookahead split keeps the header marker attached to its section.
        marker = self.headers[header_level]
        sections = re.split(f'(?={re.escape(marker)})', content)

        # No header of this level present: retry one level deeper.
        if len(sections) == 1:
            return self._split_by_header(content, header_level + 1)

        chunks = []
        pending = ""
        for section in sections:
            if len(section) > self.max_chars:
                # Oversized section: flush what we have, then split it deeper.
                if pending:
                    chunks.append(pending)
                    pending = ""
                chunks.extend(self._split_by_header(section, header_level + 1))
            elif pending and len(pending) + len(section) > self.max_chars:
                # Packing this section would overflow: start a new chunk.
                chunks.append(pending)
                pending = section
            else:
                # Accumulate sections that fit together.
                pending += section

        # Flush the final accumulated chunk.
        if pending:
            chunks.append(pending)

        return chunks
88
+
89
+
90
class ChunkProcessor:
    """Turns raw documents into embedded, vector-store-ready chunks.

    Wraps a SentenceTransformer encoder plus a factory of text-splitting
    strategies (fixed, recursive, character, paragraph, sentence, semantic,
    page, markdown).
    """

    def __init__(self, model_name='all-MiniLM-L6-v2', verbose: bool = True, load_hf_embeddings: bool = False):
        """
        Args:
            model_name: SentenceTransformer model id used for chunk embeddings.
            verbose: Default for progress printing (overridable per process() call).
            load_hf_embeddings: If True, eagerly build the LangChain embeddings
                wrapper (only needed by the "semantic" splitter); otherwise it
                is created lazily on first use via _get_hf_embeddings().
        """
        self.model_name = model_name
        # jinaai/* repos ship custom model code, so they need trust_remote_code=True.
        self._use_remote_code = self._requires_remote_code(model_name)
        st_kwargs = {"trust_remote_code": True} if self._use_remote_code else {}
        self.encoder = SentenceTransformer(model_name, **st_kwargs)
        self.verbose = verbose
        hf_kwargs = {"model_kwargs": {"trust_remote_code": True}} if self._use_remote_code else {}
        self.hf_embeddings = HuggingFaceEmbeddings(model_name=model_name, **hf_kwargs) if load_hf_embeddings else None

    def _requires_remote_code(self, model_name: str) -> bool:
        """Return True when the model repo needs remote code execution (jinaai/* family)."""
        normalized = (model_name or "").strip().lower()
        return normalized.startswith("jinaai/")

    def _get_hf_embeddings(self):
        """Lazily build (and cache) the LangChain embeddings wrapper for SemanticChunker."""
        if self.hf_embeddings is None:
            hf_kwargs = {"model_kwargs": {"trust_remote_code": True}} if self._use_remote_code else {}
            self.hf_embeddings = HuggingFaceEmbeddings(model_name=self.model_name, **hf_kwargs)
        return self.hf_embeddings

    # ------------------------------------------------------------------
    # Splitters
    # ------------------------------------------------------------------

    def get_splitter(self, technique: str, chunk_size: int = 500, chunk_overlap: int = 50, **kwargs):
        """
        Factory method to return different chunking strategies.

        Strategies:
        - "fixed": Character-based, may split mid-sentence
        - "recursive": Recursive character splitting with hierarchical separators
        - "character": Character-based splitting on paragraph boundaries
        - "paragraph": Paragraph-level splitting on \\n\\n boundaries
        - "sentence": Sliding window over NLTK sentences
        - "semantic": Embedding-based semantic chunking
        - "page": Page-level splitting on page markers
        - "markdown": Hierarchical markdown-header splitting with a max char limit

        Raises:
            ValueError: If *technique* is not one of the supported strategies.
        """
        if technique == "fixed":
            return CharacterTextSplitter(
                separator=kwargs.get('separator', ""),
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                length_function=len,
                is_separator_regex=False
            )

        elif technique == "recursive":
            return RecursiveCharacterTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                separators=kwargs.get('separators', ["\n\n", "\n", ". ", "! ", "? ", "; ", ", ", " ", ""]),
                length_function=len,
                keep_separator=kwargs.get('keep_separator', True)
            )

        elif technique == "character":
            return CharacterTextSplitter(
                separator=kwargs.get('separator', "\n\n"),
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                length_function=len,
                is_separator_regex=False
            )

        elif technique == "paragraph":
            # Paragraph-level chunking using paragraph breaks
            return CharacterTextSplitter(
                separator=kwargs.get('separator', "\n\n"),
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                length_function=len,
                is_separator_regex=False
            )

        elif technique == "sentence":
            # sentence-level chunking using NLTK
            return NLTKTextSplitter(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                separator="\n"
            )

        elif technique == "semantic":
            return SemanticChunker(
                self._get_hf_embeddings(),
                breakpoint_threshold_type=kwargs.get('breakpoint_threshold_type', "percentile"),
                # Using 70 because 95 was giving way too big chunks
                breakpoint_threshold_amount=kwargs.get('breakpoint_threshold_amount', 70)
            )

        elif technique == "page":
            # Page-level chunking using page markers
            return CharacterTextSplitter(
                separator=kwargs.get('separator', "--- Page"),
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                length_function=len,
                is_separator_regex=False
            )

        elif technique == "markdown":
            # Markdown header chunking - splits by headers with max char limit
            return MarkdownTextSplitter(max_chars=chunk_size)

        else:
            raise ValueError(f"Technique '{technique}' is not supported. Choose from: fixed, recursive, character, paragraph, sentence, semantic, page, markdown")

    # ------------------------------------------------------------------
    # Processing
    # ------------------------------------------------------------------

    def process(self, df: pd.DataFrame, technique: str = "recursive", chunk_size: int = 500,
                chunk_overlap: int = 50, max_docs: Optional[int] = 5,
                verbose: Optional[bool] = None, **kwargs) -> List[Dict[str, Any]]:
        """
        Processes a DataFrame into vector-ready chunks.

        Args:
            df: DataFrame with columns: id, title, url, full_text
            technique: Chunking strategy to use (see get_splitter)
            chunk_size: Maximum size of each chunk in characters
            chunk_overlap: Overlap between consecutive chunks
            max_docs: Number of documents to process (None for all).
                NOTE: max_docs=0 is falsy and therefore also means "all docs".
            verbose: Override instance verbose setting
            **kwargs: Additional arguments passed to the splitter

        Returns:
            List of chunk dicts shaped for vector upsert:
            {"id", "values" (embedding list), "metadata" {...}}

        Raises:
            ValueError: If required DataFrame columns are missing, or the
                technique is unsupported.
        """
        should_print = verbose if verbose is not None else self.verbose

        required_cols = ['id', 'title', 'url', 'full_text']
        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            raise ValueError(f"DataFrame missing required columns: {missing_cols}")

        splitter = self.get_splitter(technique, chunk_size, chunk_overlap, **kwargs)
        subset_df = df.head(max_docs) if max_docs else df
        processed_chunks = []

        for _, row in subset_df.iterrows():
            if should_print:
                self._print_document_header(row['title'], row['url'], technique, chunk_size, chunk_overlap)

            raw_chunks = splitter.split_text(row['full_text'])

            for i, text in enumerate(raw_chunks):
                # Some splitters return Document objects, others plain strings.
                content = text.page_content if hasattr(text, 'page_content') else text

                if should_print:
                    self._print_chunk(i, content)

                # One embedding per chunk; ids are "<doc-id>-chunk-<index>".
                processed_chunks.append({
                    "id": f"{row['id']}-chunk-{i}",
                    "values": self.encoder.encode(content).tolist(),
                    "metadata": {
                        "title": row['title'],
                        "text": content,
                        "url": row['url'],
                        "chunk_index": i,
                        "technique": technique,
                        "chunk_size": len(content),
                        "total_chunks": len(raw_chunks)
                    }
                })

            if should_print:
                self._print_document_summary(len(raw_chunks))

        if should_print:
            self._print_processing_summary(len(subset_df), processed_chunks)

        return processed_chunks


    # ------------------------------------------------------------------
    # Printing
    # ------------------------------------------------------------------

    def _print_document_header(self, title, url, technique, chunk_size, chunk_overlap):
        """Print a banner for the document about to be chunked."""
        print("\n" + "="*80)
        print(f"DOCUMENT: {title}")
        print(f"URL: {url}")
        print(f"Technique: {technique.upper()} | Chunk Size: {chunk_size} | Overlap: {chunk_overlap}")
        print("-" * 80)

    def _print_chunk(self, index, content):
        """Print one chunk's index, size and full text."""
        print(f"\n[Chunk {index}] ({len(content)} chars):")
        print(f"  {content}")

    def _print_document_summary(self, num_chunks):
        """Print the chunk count for a single processed document."""
        print(f"Total Chunks Generated: {num_chunks}")
        print("="*80)

    def _print_processing_summary(self, num_docs, processed_chunks):
        """Print overall stats after the whole DataFrame has been processed."""
        print(f"\nFinished processing {num_docs} documents into {len(processed_chunks)} chunks.")
        if processed_chunks:
            avg = sum(c['metadata']['chunk_size'] for c in processed_chunks) / len(processed_chunks)
            print(f"Average chunk size: {avg:.0f} chars")
+ print(f"Average chunk size: {avg:.0f} chars")
retriever/retriever.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import time
3
+ import re
4
+ from rank_bm25 import BM25Okapi
5
+ from sklearn.metrics.pairwise import cosine_similarity
6
+ from typing import Optional, List
7
+
8
# MMR now takes the final result count as a parameter (it was previously hard-coded to 3).
# -- @Qamare
10
+
11
+ # Try to import FlashRank for CPU optimization, fallback to sentence-transformers
12
+ try:
13
+ from flashrank import Ranker, RerankRequest
14
+ FLASHRANK_AVAILABLE = True
15
+ except ImportError:
16
+ from sentence_transformers import CrossEncoder
17
+ FLASHRANK_AVAILABLE = False
18
+
19
class HybridRetriever:
    """Hybrid dense + BM25 retriever with optional reranking and MMR filtering."""

    def __init__(self, final_chunks, embed_model, rerank_model_name='jinaai/jina-reranker-v1-tiny-en', verbose: bool = True):
        """Build the retriever over pre-chunked documents.

        :param final_chunks: list of chunk dicts; each must carry
            ``metadata['text']`` (and optionally ``metadata['chunking_technique']``
            for filtered retrieval).
        :param embed_model: sentence-embedding model exposing ``encode``.
        :param rerank_model_name: reranker model id; bare names are namespaced
            under ``cross-encoder/`` by the normalizer.
        :param verbose: default for printing retrieval diagnostics.
        """
        self.final_chunks = final_chunks
        self.embed_model = embed_model
        self.verbose = verbose
        self.rerank_model_name = self._normalize_rerank_model_name(rerank_model_name)

        # Use FlashRank if available (faster on CPU), otherwise fallback to sentence-transformers
        if FLASHRANK_AVAILABLE:
            try:
                self.rerank_model = Ranker(model_name=self.rerank_model_name)
                self.use_flashrank = True
            except Exception:
                # FlashRank may reject this model id; fall back to a
                # sentence-transformers CrossEncoder with the same name.
                from sentence_transformers import CrossEncoder as STCrossEncoder
                self.rerank_model = STCrossEncoder(self.rerank_model_name)
                self.use_flashrank = False
        else:
            self.rerank_model = CrossEncoder(self.rerank_model_name)
            self.use_flashrank = False

        # Better tokenization for BM25 (strips punctuation)
        self.tokenized_corpus = [self._tokenize(chunk['metadata']['text']) for chunk in final_chunks]
        self.bm25 = BM25Okapi(self.tokenized_corpus)
        # Maps chunking-technique name -> indices into final_chunks (for filtered BM25).
        self.technique_to_indices = self._build_chunking_index_map()
44
+ def _normalize_rerank_model_name(self, model_name: str) -> str:
45
+ normalized = (model_name or "").strip()
46
+ if not normalized:
47
+ return "cross-encoder/ms-marco-MiniLM-L-6-v2"
48
+ if "/" in normalized:
49
+ return normalized
50
+ return f"cross-encoder/{normalized}"
51
+
52
+ def _tokenize(self, text: str) -> List[str]:
53
+ """Tokenize text using regex to strip punctuation."""
54
+ return re.findall(r'\w+', text.lower())
55
+
56
# Helper methods: index chunks by their chunking-technique metadata and
# normalize the `chunking_technique` parameter passed to search().
57
+ def _build_chunking_index_map(self) -> dict[str, List[int]]:
58
+ mapping: dict[str, List[int]] = {}
59
+ for idx, chunk in enumerate(self.final_chunks):
60
+ metadata = chunk.get('metadata', {})
61
+ technique = (metadata.get('chunking_technique') or '').strip().lower()
62
+ if not technique:
63
+ continue
64
+ mapping.setdefault(technique, []).append(idx)
65
+ return mapping
66
+
67
+ def _normalize_chunking_technique(self, chunking_technique: Optional[str]) -> Optional[str]:
68
+ if not chunking_technique:
69
+ return None
70
+ normalized = str(chunking_technique).strip().lower()
71
+ if not normalized or normalized in {"all", "any", "*", "none"}:
72
+ return None
73
+ return normalized
74
+
75
+ # ------------------------------------------------------------------
76
+ # Retrieval
77
+ # ------------------------------------------------------------------
78
+
79
+ def _semantic_search(self, query, index, top_k, chunking_technique: Optional[str] = None) -> tuple[np.ndarray, List[str]]:
80
+ query_vector = self.embed_model.encode(query)
81
+ query_kwargs = {
82
+ "vector": query_vector.tolist(),
83
+ "top_k": top_k,
84
+ "include_metadata": True,
85
+ }
86
+ if chunking_technique:
87
+ query_kwargs["filter"] = {"chunking_technique": {"$eq": chunking_technique}}
88
+ res = index.query(**query_kwargs)
89
+ chunks = [match['metadata']['text'] for match in res['matches']]
90
+ return query_vector, chunks
91
+
92
+ def _bm25_search(self, query, top_k, chunking_technique: Optional[str] = None) -> List[str]:
93
+ tokenized_query = self._tokenize(query)
94
+ scores = self.bm25.get_scores(tokenized_query)
95
+
96
+ if chunking_technique:
97
+ candidate_indices = self.technique_to_indices.get(chunking_technique, [])
98
+ if not candidate_indices:
99
+ return []
100
+ top_indices = sorted(candidate_indices, key=lambda i: scores[i], reverse=True)[:top_k]
101
+ else:
102
+ top_indices = np.argsort(scores)[::-1][:top_k]
103
+
104
+ return [self.final_chunks[i]['metadata']['text'] for i in top_indices]
105
+
106
+ # ------------------------------------------------------------------
107
+ # Fusion
108
+ # ------------------------------------------------------------------
109
+
110
+ def _rrf_score(self, semantic_results, bm25_results, k=60) -> List[str]:
111
+ scores = {}
112
+ for rank, chunk in enumerate(semantic_results):
113
+ scores[chunk] = scores.get(chunk, 0) + 1 / (k + rank + 1)
114
+ for rank, chunk in enumerate(bm25_results):
115
+ scores[chunk] = scores.get(chunk, 0) + 1 / (k + rank + 1)
116
+ return [chunk for chunk, _ in sorted(scores.items(), key=lambda x: x[1], reverse=True)]
117
+
118
+ # ------------------------------------------------------------------
119
+ # Reranking
120
+ # ------------------------------------------------------------------
121
+
122
+ def _cross_encoder_rerank(self, query, chunks, final_k) -> List[str]:
123
+ if self.use_flashrank:
124
+ # Use FlashRank for CPU-optimized reranking
125
+ passages = [{"id": i, "text": chunk} for i, chunk in enumerate(chunks)]
126
+ rerank_request = RerankRequest(query=query, passages=passages)
127
+ results = self.rerank_model.rerank(rerank_request)
128
+ ranked_chunks = [res['text'] for res in results]
129
+ return ranked_chunks[:final_k]
130
+ else:
131
+ # Fallback to sentence-transformers CrossEncoder
132
+ pairs = [[query, chunk] for chunk in chunks]
133
+ scores = self.rerank_model.predict(pairs)
134
+ ranked = sorted(zip(chunks, scores), key=lambda x: x[1], reverse=True)
135
+ return [chunk for chunk, _ in ranked[:final_k]]
136
+
137
+ # ------------------------------------------------------------------
138
+ # MMR (applied after reranking as a diversity filter)
139
+ # ------------------------------------------------------------------
140
+
141
+ def _maximal_marginal_relevance(self, query_vector, chunks, lambda_param=0.5, top_k=10) -> List[str]:
142
+ """
143
+ Maximum Marginal Relevance (MMR) for diversity filtering.
144
+
145
+ DIVISION BY ZERO DEBUGGING:
146
+ - This method can cause division by zero in cosine_similarity if vectors are zero
147
+ - We've added multiple safeguards to prevent this
148
+ """
149
+ print(f" [MMR DEBUG] Starting MMR with {len(chunks)} chunks, top_k={top_k}")
150
+
151
+ if not chunks:
152
+ print(f" [MMR DEBUG] No chunks, returning empty list")
153
+ return []
154
+
155
+ # STEP 1: Encode chunks to get embeddings
156
+ print(f" [MMR DEBUG] Encoding {len(chunks)} chunks...")
157
+ try:
158
+ chunk_embeddings = self.embed_model.encode(chunks)
159
+ print(f" [MMR DEBUG] Chunk embeddings shape: {chunk_embeddings.shape}")
160
+ except Exception as e:
161
+ print(f" [MMR DEBUG] ERROR encoding chunks: {e}")
162
+ return chunks[:top_k]
163
+
164
+ # STEP 2: Reshape query vector
165
+ query_embedding = query_vector.reshape(1, -1)
166
+ print(f" [MMR DEBUG] Query embedding shape: {query_embedding.shape}")
167
+
168
+ # STEP 3: Check for zero vectors (POTENTIAL DIVISION BY ZERO SOURCE)
169
+ print(f" [MMR DEBUG] Checking for zero vectors...")
170
+ query_norm = np.linalg.norm(query_embedding)
171
+ chunk_norms = np.linalg.norm(chunk_embeddings, axis=1)
172
+
173
+ print(f" [MMR DEBUG] Query norm: {query_norm}")
174
+ print(f" [MMR DEBUG] Chunk norms min: {chunk_norms.min()}, max: {chunk_norms.max()}")
175
+
176
+ # Check for zero or near-zero vectors
177
+ if query_norm < 1e-10 or np.any(chunk_norms < 1e-10):
178
+ print(f" [MMR DEBUG] WARNING: Zero or near-zero vectors detected!")
179
+ print(f" [MMR DEBUG] Query norm < 1e-10: {query_norm < 1e-10}")
180
+ print(f" [MMR DEBUG] Any chunk norm < 1e-10: {np.any(chunk_norms < 1e-10)}")
181
+ print(f" [MMR DEBUG] Falling back to simple selection without MMR")
182
+ return chunks[:top_k]
183
+
184
+ # STEP 4: Compute relevance scores (POTENTIAL DIVISION BY ZERO SOURCE)
185
+ print(f" [MMR DEBUG] Computing relevance scores with cosine_similarity...")
186
+ try:
187
+ relevance_scores = cosine_similarity(query_embedding, chunk_embeddings)[0]
188
+ print(f" [MMR DEBUG] Relevance scores computed successfully")
189
+ print(f" [MMR DEBUG] Relevance scores shape: {relevance_scores.shape}")
190
+ print(f" [MMR DEBUG] Relevance scores min: {relevance_scores.min()}, max: {relevance_scores.max()}")
191
+ except Exception as e:
192
+ print(f" [MMR DEBUG] ERROR computing relevance scores: {e}")
193
+ print(f" [MMR DEBUG] Falling back to simple selection")
194
+ return chunks[:top_k]
195
+
196
+ # STEP 5: Initialize selection
197
+ selected, unselected = [], list(range(len(chunks)))
198
+
199
+ first = int(np.argmax(relevance_scores))
200
+ selected.append(first)
201
+ unselected.remove(first)
202
+ print(f" [MMR DEBUG] Selected first chunk: index {first}")
203
+
204
+ # STEP 6: Iteratively select chunks using MMR
205
+ print(f" [MMR DEBUG] Starting MMR iteration...")
206
+ iteration = 0
207
+ while len(selected) < min(top_k, len(chunks)):
208
+ iteration += 1
209
+ print(f" [MMR DEBUG] Iteration {iteration}: selected={len(selected)}, unselected={len(unselected)}")
210
+
211
+ # Calculate MMR scores
212
+ mmr_scores = []
213
+ for i in unselected:
214
+ # Compute max similarity to already selected items
215
+ max_sim = -1
216
+ for s in selected:
217
+ try:
218
+ # POTENTIAL DIVISION BY ZERO SOURCE: cosine_similarity
219
+ sim = cosine_similarity(
220
+ chunk_embeddings[i].reshape(1, -1),
221
+ chunk_embeddings[s].reshape(1, -1)
222
+ )[0][0]
223
+ max_sim = max(max_sim, sim)
224
+ except Exception as e:
225
+ print(f" [MMR DEBUG] ERROR computing similarity between chunk {i} and {s}: {e}")
226
+ # If similarity computation fails, use 0
227
+ max_sim = max(max_sim, 0)
228
+
229
+ mmr_score = lambda_param * relevance_scores[i] - (1 - lambda_param) * max_sim
230
+ mmr_scores.append((i, mmr_score))
231
+
232
+ # Select chunk with highest MMR score
233
+ if mmr_scores:
234
+ best, best_score = max(mmr_scores, key=lambda x: x[1])
235
+ selected.append(best)
236
+ unselected.remove(best)
237
+ print(f" [MMR DEBUG] Selected chunk {best} with MMR score {best_score:.4f}")
238
+ else:
239
+ print(f" [MMR DEBUG] No MMR scores computed, breaking")
240
+ break
241
+
242
+ print(f" [MMR DEBUG] MMR complete. Selected {len(selected)} chunks")
243
+ return [chunks[i] for i in selected]
244
+
245
+ # ------------------------------------------------------------------
246
+ # Main search
247
+ # ------------------------------------------------------------------
248
+
249
+ def search(self, query, index, top_k=25, final_k=5, mode="hybrid",
250
+ chunking_technique: Optional[str] = None,
251
+ rerank_strategy="cross-encoder", use_mmr=False, lambda_param=0.5,
252
+ verbose: Optional[bool] = None) -> List[str]:
253
+ """
254
+ :param mode: "semantic", "bm25", or "hybrid"
255
+ :param rerank_strategy: "cross-encoder", "rrf", or "none"
256
+ :param use_mmr: Whether to apply MMR diversity filter after reranking
257
+ :param lambda_param: MMR trade-off between relevance (1.0) and diversity (0.0)
258
+ """
259
+ should_print = verbose if verbose is not None else self.verbose
260
+ requested_technique = self._normalize_chunking_technique(chunking_technique)
261
+ total_start = time.perf_counter()
262
+ semantic_time = 0.0
263
+ bm25_time = 0.0
264
+ rerank_time = 0.0
265
+ mmr_time = 0.0
266
+
267
+ if should_print:
268
+ self._print_search_header(query, mode, rerank_strategy, top_k, final_k)
269
+ if requested_technique:
270
+ print(f"Chunking Filter: {requested_technique}")
271
+
272
+ # 1. Retrieve candidates
273
+ query_vector = None
274
+ semantic_chunks, bm25_chunks = [], []
275
+
276
+ if mode in ["semantic", "hybrid"]:
277
+ semantic_start = time.perf_counter()
278
+ query_vector, semantic_chunks = self._semantic_search(query, index, top_k, requested_technique)
279
+ semantic_time = time.perf_counter() - semantic_start
280
+ if should_print:
281
+ self._print_candidates("Semantic Search", semantic_chunks)
282
+ print(f"Semantic time: {semantic_time:.3f}s")
283
+
284
+ if mode in ["bm25", "hybrid"]:
285
+ bm25_start = time.perf_counter()
286
+ bm25_chunks = self._bm25_search(query, top_k, requested_technique)
287
+ bm25_time = time.perf_counter() - bm25_start
288
+ if should_print:
289
+ self._print_candidates("BM25 Search", bm25_chunks)
290
+ print(f"BM25 time: {bm25_time:.3f}s")
291
+
292
+ # 2. Fuse / rerank
293
+ rerank_start = time.perf_counter()
294
+ if rerank_strategy == "rrf":
295
+ candidates = self._rrf_score(semantic_chunks, bm25_chunks)[:final_k]
296
+ label = "RRF"
297
+ elif rerank_strategy == "cross-encoder":
298
+ combined = list(dict.fromkeys(semantic_chunks + bm25_chunks))
299
+ candidates = self._cross_encoder_rerank(query, combined, final_k)
300
+ label = "Cross-Encoder"
301
+ else: # "none"
302
+ candidates = list(dict.fromkeys(semantic_chunks + bm25_chunks))[:final_k]
303
+ label = "No Reranking"
304
+ rerank_time = time.perf_counter() - rerank_start
305
+
306
+ # 3. MMR diversity filter (applied after reranking)
307
+ if use_mmr and candidates:
308
+ mmr_start = time.perf_counter()
309
+ if query_vector is None:
310
+ query_vector = self.embed_model.encode(query)
311
+ candidates = self._maximal_marginal_relevance(query_vector, candidates,
312
+ lambda_param=lambda_param, top_k=final_k)
313
+ label += " + MMR"
314
+ mmr_time = time.perf_counter() - mmr_start
315
+
316
+ total_time = time.perf_counter() - total_start
317
+
318
+ if should_print:
319
+ self._print_final_results(candidates, label)
320
+ self._print_timing_summary(semantic_time, bm25_time, rerank_time, mmr_time, total_time)
321
+
322
+ return candidates
323
+
324
+ # ------------------------------------------------------------------
325
+ # Printing
326
+ # ------------------------------------------------------------------
327
+
328
+ def _print_search_header(self, query, mode, rerank_strategy, top_k, final_k):
329
+ print("\n" + "="*80)
330
+ print(f" SEARCH QUERY: {query}")
331
+ print(f"Mode: {mode.upper()} | Rerank: {rerank_strategy.upper()}")
332
+ print(f"Top-K: {top_k} | Final-K: {final_k}")
333
+ print("-" * 80)
334
+
335
+ def _print_candidates(self, label, chunks, preview_n=3):
336
+ print(f"{label}: Retrieved {len(chunks)} candidates")
337
+ for i, chunk in enumerate(chunks[:preview_n]):
338
+ preview = chunk[:100] + "..." if len(chunk) > 100 else chunk
339
+ print(f" [{i}] {preview}")
340
+
341
+ def _print_final_results(self, results, strategy_label):
342
+ print(f"\n Final {len(results)} Results ({strategy_label}):")
343
+ for i, chunk in enumerate(results):
344
+ preview = chunk[:150] + "..." if len(chunk) > 150 else chunk
345
+ print(f" [{i+1}] {preview}")
346
+ print("="*80)
347
+
348
+ def _print_timing_summary(self, semantic_time, bm25_time, rerank_time, mmr_time, total_time):
349
+ print(" Retrieval Timing:")
350
+ print(f" Semantic: {semantic_time:.3f}s")
351
+ print(f" BM25: {bm25_time:.3f}s")
352
+ print(f" Rerank/Fusion: {rerank_time:.3f}s")
353
+ print(f" MMR: {mmr_time:.3f}s")
354
+ print(f" Total Retrieval: {total_time:.3f}s")