Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """Benchmark ResearchMind RAG retrieval and optional full chat latency.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import statistics | |
| import sys | |
| import time | |
| from pathlib import Path | |
| from researchmind.config import get_config | |
| from researchmind.embeddings import embed_texts | |
| from researchmind.ingest import IngestPipeline | |
| from researchmind.retrieve import retrieve | |
def _load_sessions() -> list[tuple[str, str]]:
    """Return ``(session id, topic)`` pairs for every session in the store.

    The store is taken from a freshly constructed ``IngestPipeline``.
    Sessions whose topic is falsy (empty or ``None``) are labelled
    ``"Untitled"``.
    """
    pipeline = IngestPipeline()
    pairs: list[tuple[str, str]] = []
    for session in pipeline.store.list_sessions():
        pairs.append((session.id, session.topic or "Untitled"))
    return pairs
def _summarize_timings(timings: list[float]) -> dict[str, float]:
    """Reduce per-run retrieve timings (ms) to the reported summary stats."""
    if not timings:
        # Nothing was measured (runs <= 0): report zeros instead of letting
        # statistics.mean()/min()/max() raise on an empty sequence.
        return {
            "retrieve_ms_cold": 0.0,
            "retrieve_ms_mean": 0.0,
            "retrieve_ms_stdev": 0.0,
            "retrieve_ms_min": 0.0,
            "retrieve_ms_max": 0.0,
        }
    # The first run is reported separately as "cold"; the remaining runs are
    # the "warm" sample. With a single run, that run serves as both.
    warm = timings[1:] or timings
    return {
        "retrieve_ms_cold": round(timings[0], 1),
        "retrieve_ms_mean": round(statistics.mean(warm), 1),
        # stdev needs at least two data points.
        "retrieve_ms_stdev": round(statistics.stdev(warm), 1) if len(warm) > 1 else 0.0,
        "retrieve_ms_min": round(min(warm), 1),
        "retrieve_ms_max": round(max(warm), 1),
    }


def benchmark_retrieve(
    question: str,
    *,
    session_id: str,
    runs: int,
) -> dict[str, object]:
    """Time ``retrieve`` for *question* and return a JSON-friendly report.

    Args:
        question: Query text passed to ``retrieve`` on every run.
        session_id: Session to scope the store lookups to. An empty string
            is forwarded as ``None``.
        runs: Number of timed ``retrieve`` calls. When ``runs <= 0`` no call
            is made and every ``retrieve_ms_*`` statistic is reported as 0.0.

    Returns:
        A dict with the question, session id, number of chunks returned by
        ``store.get_chunks_with_embeddings`` for the session, the chunk count
        of the last retrieve call, selected config values, the duration of
        one ``embed_texts`` call, and the retrieve timing statistics in
        milliseconds (first run as "cold", remaining runs as "warm").
    """
    cfg = get_config()
    store = IngestPipeline().store
    scope = session_id or None
    chunks_in_scope = store.get_chunks_with_embeddings(session_id=scope)

    timings: list[float] = []
    retrieved = 0
    for _ in range(runs):
        started = time.perf_counter()
        chunks = retrieve(question, store, config=cfg, session_id=scope)
        timings.append((time.perf_counter() - started) * 1000)
        retrieved = len(chunks)

    # Timed after the retrieve loop. NOTE(review): presumably the embedder is
    # already loaded by then, hence "warm" -- confirm against retrieve().
    embed_started = time.perf_counter()
    embed_texts(["warmup query"], model_name=cfg.embed_model)
    embed_warm_ms = (time.perf_counter() - embed_started) * 1000

    return {
        "question": question,
        "session_id": session_id,
        "chunks_in_scope": len(chunks_in_scope),
        "retrieved_chunks": retrieved,
        "top_k": cfg.top_k,
        "max_context_chunks": cfg.max_context_chunks,
        "embed_model": cfg.embed_model,
        "embedder_warm_ms": round(embed_warm_ms, 1),
        **_summarize_timings(timings),
    }
def benchmark_chat(
    question: str,
    *,
    session_id: str,
    model_key: str | None,
) -> dict[str, object]:
    """Run one full chat for *question* and report its wall-clock latency.

    Args:
        question: Question passed to ``AgentRunner.run_researchmind_chat``.
        session_id: Session id forwarded to the chat run.
        model_key: Model preset to use; when falsy, the key returned by
            ``get_active_model_key()`` is used instead.

    Returns:
        ``{"error": ..., "model": ...}`` when ``ensure_model_loaded`` reports
        a problem. Otherwise a dict with the model key, question, session id,
        total duration in milliseconds, citation count, the first 240
        characters of the answer, the name/label/duration of every trace
        entry whose ``type`` is ``"step"``, and the trace path.
    """
    # Imported inside the function rather than at module level.
    # NOTE(review): presumably so retrieve-only runs do not import the
    # agent/model stack -- confirm.
    from agent.runner import AgentRunner
    from gradio_space.model_loading import ensure_model_loaded, get_active_model_key
    from inference.factory import get_backend

    active_key = model_key or get_active_model_key()
    failure = ensure_model_loaded(active_key)
    if failure:
        return {"error": failure, "model": active_key}

    llm_backend = get_backend(active_key)
    agent = AgentRunner()

    t0 = time.perf_counter()
    outcome = agent.run_researchmind_chat(
        question=question,
        session_id=session_id,
        model_key=active_key,
        backend=llm_backend,
        doc_ids=None,
    )
    elapsed_ms = (time.perf_counter() - t0) * 1000

    # The per-step durations are read back from the trace file of the run.
    trace_text = Path(outcome.trace_path).read_text(encoding="utf-8")
    trace = json.loads(trace_text)

    step_rows = []
    for entry in trace.get("steps", []):
        if entry.get("type") != "step":
            continue
        step_rows.append(
            {
                "name": entry.get("name"),
                "label": entry.get("label"),
                "duration_ms": entry.get("duration_ms"),
            }
        )

    return {
        "model": active_key,
        "question": question,
        "session_id": session_id,
        "total_ms": round(elapsed_ms, 1),
        "citations": len(outcome.citations),
        "answer_preview": outcome.answer[:240],
        "steps": step_rows,
        "trace_path": outcome.trace_path,
    }
def main() -> int:
    """Parse CLI arguments, run the benchmarks and print JSON reports.

    The retrieve benchmark always runs; the full chat benchmark runs only
    with ``--full-chat``. When ``--session-id`` is omitted or blank, the
    first session returned by ``_load_sessions()`` is used.

    Returns:
        Process exit status: 0 on success; 1 when no session is available
        or when the chat benchmark reports an ``"error"``.
    """
    parser = argparse.ArgumentParser(description="Benchmark ResearchMind RAG chat")
    parser.add_argument(
        "--question",
        default="how we can finetune model",
        help="Question to benchmark",
    )
    parser.add_argument("--session-id", default="", help="Research session id")
    parser.add_argument("--runs", type=int, default=5, help="Retrieve benchmark repetitions")
    parser.add_argument(
        "--full-chat",
        action="store_true",
        help="Also run one full RAG chat (loads local LLM)",
    )
    parser.add_argument("--model-key", default="", help="Override ACTIVE_MODEL preset")
    args = parser.parse_args()

    session_id = args.session_id.strip()
    if not session_id:
        # Only list sessions when a fallback is actually needed.
        sessions = _load_sessions()
        session_id = sessions[0][0] if sessions else ""
    if not session_id:
        # Diagnostics go to stderr so stdout carries only the JSON reports.
        print("No indexed session found. Ingest sources first.", file=sys.stderr)
        return 1

    retrieve_report = benchmark_retrieve(
        args.question,
        session_id=session_id,
        runs=max(1, args.runs),  # always take at least one measurement
    )
    print(json.dumps({"retrieve": retrieve_report}, indent=2))

    if args.full_chat:
        chat_report = benchmark_chat(
            args.question,
            session_id=session_id,
            model_key=args.model_key or None,
        )
        print(json.dumps({"chat": chat_report}, indent=2))
        if "error" in chat_report:
            # The model failed to load: the report is still printed above,
            # but the failure is signalled through the exit status.
            return 1
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())