Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import json | |
| import math | |
| import re | |
| import shutil | |
| import zipfile | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any, Iterable | |
| import chromadb | |
| import requests | |
| from llama_index.core import StorageContext, VectorStoreIndex | |
| from llama_index.core.node_parser import SentenceSplitter | |
| from llama_index.core.schema import Document | |
| from llama_index.core.schema import NodeWithScore, TextNode | |
| from llama_index.vector_stores.chroma import ChromaVectorStore | |
| from tools.query_knowledge import ( | |
| BM25Retriever, | |
| EMBED_MODEL_NAME, | |
| RERANKER_MODEL_NAME, | |
| CrossEncoderReranker, | |
| configure_model_cache, | |
| resolve_embed_model_name, | |
| ) | |
# Filesystem layout: the project root is two levels above this file, and every
# evaluation artefact (datasets, indexes, reports) lives under "<root>/eval".
PROJECT_ROOT = Path(__file__).resolve().parent.parent
EVAL_DIR = PROJECT_ROOT.joinpath("eval")
DATA_DIR = EVAL_DIR.joinpath("data")
INDEX_DIR = EVAL_DIR.joinpath("indexes")
REPORT_DIR = EVAL_DIR.joinpath("reports")

# Download location of each supported BEIR archive, keyed by BEIR dataset name.
BEIR_URLS = {
    "scifact": "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/scifact.zip",
    "fiqa": "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/fiqa.zip",
}

# Accepted --dataset spellings mapped to the canonical name used for dispatch.
DATASET_ALIASES = {
    "beir/scifact": "scifact",
    "beir/fiqa": "fiqa",
    "open-ragbench": "open_ragbench",
    "open_ragbench": "open_ragbench",
    "t2-ragbench": "t2_ragbench",
    "t2_ragbench": "t2_ragbench",
    "local-options": "local_options",
    "local_options": "local_options",
}
@dataclass
class EvalCorpus:
    """A retrieval benchmark sample: documents, queries and relevance labels.

    Declared as a dataclass so that the generated ``__init__`` accepts the
    four fields positionally or by keyword, which is how the dataset loaders
    in this module construct it.
    """

    # Short identifier of the corpus; used to name the index directory, the
    # vector-store collection and the report files.
    name: str
    # One dict per document, built by the loaders with the keys
    # "doc_id", "title", "text" and "metadata".
    documents: list[dict[str, Any]]
    # One dict per query, built by the loaders with the keys
    # "query_id", "question" and "relevant_doc_ids".
    queries: list[dict[str, Any]]
    # Query id -> set of document ids judged relevant for that query.
    qrels: dict[str, set[str]]
def ensure_dirs() -> None:
    """Create the data, index and report directories when they do not exist."""
    for directory in (DATA_DIR, INDEX_DIR, REPORT_DIR):
        directory.mkdir(parents=True, exist_ok=True)
def download_file(url: str, destination: Path) -> None:
    """Stream the resource at ``url`` into the file ``destination``.

    The body is written to a sibling ``<name>.part`` file and moved onto
    ``destination`` only after the whole response has been received. An
    interrupted or failed transfer therefore never leaves a truncated file at
    ``destination``, which matters because ``prepare_beir_dataset`` skips the
    download whenever the archive path already exists.

    Args:
        url: Address to fetch with an HTTP GET (60 second timeout).
        destination: Final path of the downloaded file; missing parent
            directories are created.

    Raises:
        requests.HTTPError: Propagated from ``raise_for_status()`` when the
            server answers with an error status.
    """
    destination.parent.mkdir(parents=True, exist_ok=True)
    partial_path = destination.with_name(destination.name + ".part")
    try:
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with partial_path.open("wb") as file:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    # Empty chunks are skipped, as before.
                    if chunk:
                        file.write(chunk)
        # Publish the finished download under its final name.
        partial_path.replace(destination)
    except BaseException:
        # Remove the incomplete temp file, then let the original error surface.
        partial_path.unlink(missing_ok=True)
        raise
def read_jsonl(path: Path) -> Iterable[dict[str, Any]]:
    """Lazily yield the decoded JSON value of every non-blank line in ``path``.

    The file is read as UTF-8. Lines are stripped of surrounding whitespace
    and lines that are empty after stripping are skipped.
    """
    with path.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if not stripped:
                continue
            yield json.loads(stripped)
def prepare_beir_dataset(dataset_name: str) -> Path:
    """Return the local directory of a BEIR dataset, fetching it if needed.

    The dataset counts as present when ``corpus.jsonl`` exists in its target
    directory. Otherwise the archive is downloaded (unless the zip file is
    already on disk) and extracted under ``DATA_DIR / "beir"``.

    Raises:
        ValueError: ``dataset_name`` has no entry in ``BEIR_URLS``.
        FileNotFoundError: Extraction did not produce ``corpus.jsonl``.
    """
    ensure_dirs()
    if dataset_name not in BEIR_URLS:
        raise ValueError(f"Unsupported BEIR dataset: {dataset_name}")

    beir_root = DATA_DIR / "beir"
    target_dir = beir_root / dataset_name
    corpus_path = target_dir / "corpus.jsonl"

    if not corpus_path.exists():
        archive_path = DATA_DIR / "downloads" / f"{dataset_name}.zip"
        if not archive_path.exists():
            download_file(BEIR_URLS[dataset_name], archive_path)

        beir_root.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(archive_path) as archive:
            archive.extractall(beir_root)

        if not corpus_path.exists():
            raise FileNotFoundError(f"BEIR extraction did not create {corpus_path}")

    return target_dir
def load_beir_dataset(
    dataset_name: str,
    split: str,
    max_corpus_docs: int | None,
    max_queries: int | None,
) -> EvalCorpus:
    """Load a (possibly sub-sampled) BEIR dataset as an ``EvalCorpus``.

    Args:
        dataset_name: BEIR dataset name, passed to ``prepare_beir_dataset``.
        split: Name of the qrels file to use (``qrels/<split>.tsv``). When it
            is missing, the alphabetically first ``*.tsv`` file is used.
        max_corpus_docs: Optional cap on the number of documents kept.
        max_queries: Optional cap on the number of queries kept.

    Returns:
        A corpus named ``beir_<dataset_name>`` whose qrels only cover the
        selected queries.

    Raises:
        FileNotFoundError: No qrels file exists for the dataset.
        ValueError: The sample ends up with no documents or no queries.
    """
    dataset_dir = prepare_beir_dataset(dataset_name)

    query_text_by_id = {}
    for record in read_jsonl(dataset_dir / "queries.jsonl"):
        query_text_by_id[str(record["_id"])] = record.get("text", "")

    qrels_path = dataset_dir / "qrels" / f"{split}.tsv"
    if not qrels_path.exists():
        fallback_files = sorted((dataset_dir / "qrels").glob("*.tsv"))
        if not fallback_files:
            raise FileNotFoundError(f"No qrels found under {dataset_dir / 'qrels'}")
        qrels_path = fallback_files[0]

    # Keep only positive judgements, grouped by query.
    positive_qrels: dict[str, set[str]] = {}
    with qrels_path.open("r", encoding="utf-8") as handle:
        for record in csv.DictReader(handle, delimiter="\t"):
            query_id = str(record.get("query-id") or record.get("query_id"))
            corpus_id = str(record.get("corpus-id") or record.get("corpus_id"))
            if int(record.get("score", 1)) > 0:
                positive_qrels.setdefault(query_id, set()).add(corpus_id)

    # Pick queries in qrels order; a query is skipped when adding its relevant
    # documents would push the required set past the document cap.
    queries = []
    required_doc_ids = set()
    for query_id, relevant_docs in positive_qrels.items():
        if query_id not in query_text_by_id:
            continue
        if max_corpus_docs and len(required_doc_ids | relevant_docs) > max_corpus_docs:
            continue
        required_doc_ids |= relevant_docs
        queries.append(
            {
                "query_id": query_id,
                "question": query_text_by_id[query_id],
                "relevant_doc_ids": sorted(relevant_docs),
            }
        )
        if max_queries and len(queries) >= max_queries:
            break

    documents = []
    seen_doc_ids = set()
    for record in read_jsonl(dataset_dir / "corpus.jsonl"):
        doc_id = str(record["_id"])
        is_filler = bool(required_doc_ids) and doc_id not in required_doc_ids
        if is_filler and max_corpus_docs:
            # A non-required document is only admitted while there is still
            # room left for every required document not yet read.
            outstanding_required = len(required_doc_ids - seen_doc_ids)
            if len(documents) + outstanding_required >= max_corpus_docs:
                continue

        title = record.get("title") or ""
        body = record.get("text") or ""
        documents.append(
            {
                "doc_id": doc_id,
                "title": title,
                "text": f"{title}\n{body}".strip(),
                "metadata": {"source_dataset": f"beir/{dataset_name}"},
            }
        )
        seen_doc_ids.add(doc_id)

        cap_reached = max_corpus_docs and len(documents) >= max_corpus_docs
        if cap_reached and required_doc_ids.issubset(seen_doc_ids):
            break

    if not (documents and queries):
        raise ValueError(
            f"Dataset beir/{dataset_name} has no evaluable documents/queries. "
            "Increase --max-corpus-docs or use a larger sample."
        )

    return EvalCorpus(
        name=f"beir_{dataset_name}",
        documents=documents,
        queries=queries,
        qrels={entry["query_id"]: set(entry["relevant_doc_ids"]) for entry in queries},
    )
def snapshot_hf_dataset(repo_id: str, local_name: str) -> Path:
    """Return the local snapshot directory of a Hugging Face dataset repo.

    The snapshot lives at ``DATA_DIR / "hf" / local_name``. If that directory
    already exists it is returned as is; otherwise the repository is
    downloaded into it with ``huggingface_hub.snapshot_download``.
    """
    from huggingface_hub import snapshot_download

    ensure_dirs()
    target_dir = DATA_DIR / "hf" / local_name
    if not target_dir.exists():
        snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            local_dir=str(target_dir),
            local_dir_use_symlinks=False,
        )
    return target_dir
def flatten_open_ragbench_section(section: dict[str, Any]) -> str:
    """Join a section's text and its table values into one newline-separated string.

    The section's ``"text"`` comes first, followed by ``str()`` of every value
    in ``"tables"`` when that entry is a dict. Empty pieces are dropped.
    """
    pieces = []
    body = section.get("text") or ""
    if body:
        pieces.append(body)

    tables = section.get("tables") or {}
    if isinstance(tables, dict):
        for table_value in tables.values():
            rendered = str(table_value)
            if rendered:
                pieces.append(rendered)

    return "\n".join(pieces)
def load_open_ragbench(
    max_corpus_docs: int | None,
    max_queries: int | None,
) -> EvalCorpus:
    """Load the ``vectara/open_ragbench`` arXiv sample as an ``EvalCorpus``.

    Queries are taken from ``qrels.json`` in file order (each one points at a
    single ``doc_id``). Papers are read from ``corpus/*.json`` and flattened
    into title, abstract and per-section text.

    Args:
        max_corpus_docs: Optional cap on the number of papers kept.
        max_queries: Optional cap on the number of queries kept.

    Raises:
        FileNotFoundError: Neither expected dataset root exists.
        ValueError: The sample ends up with no documents or no queries.
    """
    dataset_dir = snapshot_hf_dataset("vectara/open_ragbench", "open_ragbench")
    root_candidates = (
        dataset_dir / "pdf" / "arxiv",
        dataset_dir / "official" / "pdf" / "arxiv",
    )
    root = next((candidate for candidate in root_candidates if candidate.exists()), None)
    if root is None:
        raise FileNotFoundError(f"Open RAGBench root not found: {root_candidates[-1]}")

    queries_data = json.loads((root / "queries.json").read_text(encoding="utf-8"))
    qrels_data = json.loads((root / "qrels.json").read_text(encoding="utf-8"))

    # First pass: choose the queries and remember which papers they need.
    selected_query_ids = []
    required_doc_ids = set()
    for raw_query_id, judgement in qrels_data.items():
        target_doc = str(judgement.get("doc_id"))
        if not target_doc or target_doc == "None":
            continue
        selected_query_ids.append(str(raw_query_id))
        required_doc_ids.add(target_doc)
        if max_queries and len(selected_query_ids) >= max_queries:
            break

    # Second pass: read papers, keeping room for required ones under the cap.
    documents = []
    allowed_doc_ids = set()
    for corpus_file in sorted((root / "corpus").glob("*.json")):
        paper = json.loads(corpus_file.read_text(encoding="utf-8"))
        paper_id = str(paper.get("id") or corpus_file.stem)

        if max_corpus_docs and paper_id not in required_doc_ids:
            still_missing = len(required_doc_ids - allowed_doc_ids)
            if len(documents) + still_missing >= max_corpus_docs:
                continue
        allowed_doc_ids.add(paper_id)

        flattened_sections = (
            (section_index, flatten_open_ragbench_section(section))
            for section_index, section in enumerate(paper.get("sections") or [])
        )
        section_texts = [
            f"[section {section_index}]\n{section_text}"
            for section_index, section_text in flattened_sections
            if section_text
        ]
        candidate_parts = [paper.get("title") or "", paper.get("abstract") or "", *section_texts]
        documents.append(
            {
                "doc_id": paper_id,
                "title": paper.get("title") or paper_id,
                "text": "\n\n".join(filter(None, candidate_parts)),
                "metadata": {
                    "source_dataset": "open_ragbench",
                    "categories": ",".join(paper.get("categories") or []),
                },
            }
        )
        if max_corpus_docs and len(documents) >= max_corpus_docs:
            break

    # Third pass: keep only the queries whose paper made it into the sample.
    queries = []
    qrels: dict[str, set[str]] = {}
    for query_id in selected_query_ids:
        target_doc = str(qrels_data[query_id].get("doc_id"))
        if target_doc not in allowed_doc_ids:
            continue
        payload = queries_data.get(query_id) or {}
        if isinstance(payload, dict):
            question = payload.get("query")
        else:
            question = str(payload)
        qrels[str(query_id)] = {target_doc}
        queries.append(
            {
                "query_id": str(query_id),
                "question": question,
                "relevant_doc_ids": [target_doc],
            }
        )
        if max_queries and len(queries) >= max_queries:
            break

    if not (documents and queries):
        raise ValueError("Open RAGBench produced no evaluable sample.")
    return EvalCorpus("open_ragbench", documents, queries, qrels)
def load_t2_ragbench(
    max_corpus_docs: int | None,
    max_queries: int | None,
) -> EvalCorpus:
    """Load the ``G4KMU/t2-ragbench`` dataset as an ``EvalCorpus``.

    Rows come from the snapshot's parquet files, or from its JSONL files when
    there are no parquet files. Each usable row (one with a question and a
    context) contributes a query and, the first time its document id is seen,
    a document made of the context plus any table text.

    Args:
        max_corpus_docs: Optional cap on the number of documents kept; queries
            pointing at dropped documents are removed.
        max_queries: Optional cap on the number of queries kept. File reading
            stops once five times this many rows have been collected.

    Raises:
        FileNotFoundError: The snapshot holds neither parquet nor JSONL files.
        ValueError: The sample ends up with no documents or no queries.
    """
    dataset_dir = snapshot_hf_dataset("G4KMU/t2-ragbench", "t2_ragbench")
    parquet_files = sorted(dataset_dir.rglob("*.parquet"))
    jsonl_files = sorted(dataset_dir.rglob("*.jsonl"))
    if not (parquet_files or jsonl_files):
        raise FileNotFoundError(f"No parquet/jsonl files found in {dataset_dir}")

    row_budget = max_queries * 5 if max_queries else None
    rows: list[dict[str, Any]] = []
    if parquet_files:
        import pandas as pd

        for parquet_file in parquet_files:
            rows.extend(pd.read_parquet(parquet_file).to_dict(orient="records"))
            if row_budget and len(rows) >= row_budget:
                break
    else:
        for jsonl_file in jsonl_files:
            rows.extend(read_jsonl(jsonl_file))
            if row_budget and len(rows) >= row_budget:
                break

    documents_by_id: dict[str, dict[str, Any]] = {}
    queries = []
    qrels: dict[str, set[str]] = {}
    for index, row in enumerate(rows):
        question = first_present(row, ["question", "query", "Question"])
        context = first_present(row, ["context", "evidence", "gold_context", "text", "document"])
        if not question or not context:
            continue

        answer = first_present(row, ["answer", "Answer", "response"])
        table = first_present(row, ["table", "Table", "markdown_table"])
        doc_id = str(
            first_present(row, ["doc_id", "document_id", "filename", "pdf_path", "source"])
            or f"row-{index}"
        )

        if doc_id not in documents_by_id:
            text_parts = [str(context), str(table or "")]
            documents_by_id[doc_id] = {
                "doc_id": doc_id,
                "title": str(first_present(row, ["company", "ticker", "title", "Title"]) or doc_id),
                "text": "\n".join(filter(None, text_parts)),
                "metadata": {"source_dataset": "t2_ragbench", "answer": str(answer or "")},
            }

        query_id = str(first_present(row, ["qid", "query_id", "id"]) or f"q-{index}")
        queries.append(
            {
                "query_id": query_id,
                "question": str(question),
                "relevant_doc_ids": [doc_id],
            }
        )
        qrels[query_id] = {doc_id}
        if max_queries and len(queries) >= max_queries:
            break

    documents = list(documents_by_id.values())
    if max_corpus_docs:
        documents = documents[:max_corpus_docs]

    # Drop queries whose only relevant document fell outside the kept sample.
    kept_doc_ids = {document["doc_id"] for document in documents}
    queries = [entry for entry in queries if entry["relevant_doc_ids"][0] in kept_doc_ids]
    qrels = {entry["query_id"]: set(entry["relevant_doc_ids"]) for entry in queries}

    if not (documents and queries):
        raise ValueError("T2-RAGBench produced no evaluable sample.")
    return EvalCorpus("t2_ragbench", documents, queries, qrels)
def first_present(row: dict[str, Any], keys: list[str]) -> Any:
    """Return the first value in ``row`` under ``keys`` that is neither ``None`` nor ``""``.

    Keys are tried in the given order. ``None`` is returned when no key
    yields such a value.
    """
    candidates = (row.get(key) for key in keys)
    return next(
        (value for value in candidates if value is not None and value != ""),
        None,
    )
def load_local_options_eval(max_queries: int | None) -> EvalCorpus:
    """Build an ``EvalCorpus`` from the local PDFs and hand-written eval cases.

    Cases are read from ``EVAL_DIR / "local_options_eval.jsonl"``. PDFs are
    searched under ``knowledge_base/raw`` and, if none are found there, under
    ``tools/knowledge_base/raw``. A document counts as relevant for a case
    when its page number is in the case's ``expected_pages`` or its text
    contains any of the case's ``expected_keywords``.

    Args:
        max_queries: Optional cap on the number of cases loaded.

    Raises:
        FileNotFoundError: The cases file does not exist.
        ValueError: No documents or no cases were loaded.
    """
    cases_path = EVAL_DIR / "local_options_eval.jsonl"
    if not cases_path.exists():
        raise FileNotFoundError(
            f"Local options eval set not found: {cases_path}. "
            "Create JSONL cases with question, expected_pages, expected_keywords."
        )

    from tools.query_knowledge import load_pdf_file

    raw_dirs = (
        PROJECT_ROOT / "knowledge_base" / "raw",
        PROJECT_ROOT / "tools" / "knowledge_base" / "raw",
    )
    pdf_files = []
    for raw_dir in raw_dirs:
        pdf_files = sorted(raw_dir.rglob("*.pdf"))
        if pdf_files:
            break

    documents = [
        {
            "doc_id": f"{pdf_file.name}:{loaded.metadata.get('page_number')}:{doc_index}",
            "title": loaded.metadata.get("section_path") or pdf_file.name,
            "text": loaded.text,
            "metadata": loaded.metadata,
        }
        for pdf_file in pdf_files
        for doc_index, loaded in enumerate(load_pdf_file(pdf_file))
    ]

    queries = []
    qrels: dict[str, set[str]] = {}
    for case_index, case in enumerate(read_jsonl(cases_path)):
        query_id = str(case.get("id") or f"local-{case_index}")
        expected_pages = set(case.get("expected_pages") or [])
        expected_keywords = case.get("expected_keywords") or []

        relevant_ids = []
        for candidate in documents:
            candidate_meta = candidate.get("metadata") or {}
            on_expected_page = candidate_meta.get("page_number") in expected_pages
            mentions_keyword = any(keyword in candidate["text"] for keyword in expected_keywords)
            if on_expected_page or mentions_keyword:
                relevant_ids.append(candidate["doc_id"])

        queries.append(
            {
                "query_id": query_id,
                "question": case["question"],
                "relevant_doc_ids": relevant_ids,
            }
        )
        qrels[query_id] = set(relevant_ids)
        if max_queries and len(queries) >= max_queries:
            break

    if not (documents and queries):
        raise ValueError("Local options eval set produced no evaluable sample.")
    return EvalCorpus("local_options", documents, queries, qrels)
def load_eval_corpus(args: argparse.Namespace) -> EvalCorpus:
    """Load the corpus selected by ``args.dataset``.

    The dataset name is first normalised through ``DATASET_ALIASES`` and the
    matching loader is then called with the relevant limits from ``args``.

    Raises:
        ValueError: The (normalised) dataset name is not supported.
    """
    dataset = DATASET_ALIASES.get(args.dataset, args.dataset)

    def _load_beir() -> EvalCorpus:
        return load_beir_dataset(dataset, args.split, args.max_corpus_docs, args.max_queries)

    # Loaders are wrapped in callables so only the chosen one reads ``args``.
    loaders = {
        "scifact": _load_beir,
        "fiqa": _load_beir,
        "open_ragbench": lambda: load_open_ragbench(args.max_corpus_docs, args.max_queries),
        "t2_ragbench": lambda: load_t2_ragbench(args.max_corpus_docs, args.max_queries),
        "local_options": lambda: load_local_options_eval(args.max_queries),
    }
    loader = loaders.get(dataset)
    if loader is None:
        raise ValueError(f"Unknown dataset: {args.dataset}")
    return loader()
def collection_safe_name(value: str) -> str:
    """Turn ``value`` into a name made only of letters, digits, ``_`` and ``-``.

    Every run of other characters becomes a single underscore, leading and
    trailing underscores are removed, and ``"default"`` is returned when
    nothing is left.
    """
    collapsed = re.sub(r"[^A-Za-z0-9_-]+", "_", value)
    trimmed = collapsed.strip("_")
    if trimmed:
        return trimmed
    return "default"
def build_index(corpus: EvalCorpus, chunk_size: int, chunk_overlap: int, rebuild: bool) -> VectorStoreIndex:
    """Return a Chroma-backed vector index for ``corpus``, building it if empty.

    The index is persisted under ``INDEX_DIR / corpus.name`` in a collection
    named after the corpus and ``EMBED_MODEL_NAME``. Documents are split with
    ``SentenceSplitter`` and embedded only when the collection holds no
    entries; otherwise the stored vectors are reused.

    Args:
        corpus: Documents to index.
        chunk_size: ``chunk_size`` passed to the sentence splitter.
        chunk_overlap: ``chunk_overlap`` passed to the sentence splitter.
        rebuild: When true, delete the persisted index directory and the
            collection before indexing again.
    """
    configure_model_cache()
    from llama_index.embeddings.huggingface import HuggingFaceEmbedding

    index_path = INDEX_DIR / corpus.name
    if rebuild and index_path.exists():
        shutil.rmtree(index_path)
    index_path.mkdir(parents=True, exist_ok=True)

    client = chromadb.PersistentClient(path=str(index_path))
    collection_name = f"{corpus.name}_{collection_safe_name(EMBED_MODEL_NAME)}_eval"
    if rebuild:
        # Best effort: any failure to delete the collection is ignored.
        try:
            client.delete_collection(collection_name)
        except Exception:
            pass
    collection = client.get_or_create_collection(collection_name)

    vector_store = ChromaVectorStore(chroma_collection=collection)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    embed_model = HuggingFaceEmbedding(
        model_name=resolve_embed_model_name(),
        cache_folder=str(PROJECT_ROOT / "hf_cache" / "sentence_transformers"),
    )

    if collection.count() == 0:
        source_documents = []
        for entry in corpus.documents:
            node_metadata = {
                "doc_id": entry["doc_id"],
                "title": entry.get("title", ""),
                **(entry.get("metadata") or {}),
            }
            source_documents.append(Document(text=entry["text"], metadata=node_metadata))

        splitter = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        # Constructing the index with a storage context writes the embedded
        # nodes into the Chroma collection; the object itself is not kept.
        VectorStoreIndex(
            splitter.get_nodes_from_documents(source_documents),
            storage_context=storage_context,
            embed_model=embed_model,
            show_progress=True,
        )

    return VectorStoreIndex.from_vector_store(vector_store, embed_model=embed_model)
def build_bm25_retriever(corpus: EvalCorpus, chunk_size: int, chunk_overlap: int) -> BM25Retriever:
    """Build a ``BM25Retriever`` over the sentence-split chunks of ``corpus``.

    Each corpus document becomes a ``Document`` carrying its ``doc_id``,
    ``title`` and extra metadata, is split with ``SentenceSplitter`` and the
    resulting chunks are copied into plain ``TextNode`` objects that keep the
    chunk's node id, content and metadata.
    """
    source_documents = []
    for entry in corpus.documents:
        node_metadata = {
            "doc_id": entry["doc_id"],
            "title": entry.get("title", ""),
            **(entry.get("metadata") or {}),
        }
        source_documents.append(Document(text=entry["text"], metadata=node_metadata))

    splitter = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    text_nodes = []
    for chunk in splitter.get_nodes_from_documents(source_documents):
        text_nodes.append(
            TextNode(id_=chunk.node_id, text=chunk.get_content(), metadata=chunk.metadata)
        )
    return BM25Retriever(text_nodes)
def merge_eval_results(
    vector_results: list[NodeWithScore],
    bm25_results: list[NodeWithScore],
    top_k: int,
) -> list[NodeWithScore]:
    """Fuse two ranked result lists by summing reciprocal ranks.

    A node at zero-based position ``rank`` contributes ``1 / (rank + 1)``.
    Nodes found in both lists get the sum of both contributions. The fused
    list is sorted by descending score and cut to ``top_k`` entries.
    """
    # Seed with the vector ranking; a node id repeated inside this list keeps
    # only its last reciprocal-rank score.
    fused: dict[str, NodeWithScore] = {
        hit.node.node_id: NodeWithScore(node=hit.node, score=1.0 / (position + 1))
        for position, hit in enumerate(vector_results)
    }

    for position, hit in enumerate(bm25_results):
        contribution = 1.0 / (position + 1)
        key = hit.node.node_id
        existing = fused.get(key)
        if existing is None:
            fused[key] = NodeWithScore(node=hit.node, score=contribution)
        else:
            existing.score = (existing.score or 0.0) + contribution

    ranked = sorted(
        fused.values(),
        key=lambda item: item.score or float("-inf"),
        reverse=True,
    )
    return ranked[:top_k]
def evaluate_retrieval(
    corpus: EvalCorpus,
    index: VectorStoreIndex,
    top_k: int,
    use_reranker: bool = False,
    use_hybrid: bool = False,
    chunk_size: int = 512,
    chunk_overlap: int = 64,
    reranker_model_name: str = RERANKER_MODEL_NAME,
    reranker_candidates: int = 25,
) -> dict[str, Any]:
    """Run every corpus query through retrieval and compute ranking metrics.

    Retrieved chunks are collapsed to distinct documents (first occurrence
    wins) and the first ``top_k`` documents are scored against the qrels.

    Args:
        corpus: Queries, documents and relevance labels to evaluate.
        index: Vector index used for dense retrieval.
        top_k: Number of distinct documents scored per query.
        use_reranker: Re-order candidates with ``CrossEncoderReranker``.
        use_hybrid: Also retrieve with BM25 and fuse both rankings.
        chunk_size: Splitter chunk size for the BM25 retriever.
        chunk_overlap: Splitter chunk overlap for the BM25 retriever.
        reranker_model_name: Model name handed to the reranker.
        reranker_candidates: Minimum candidate pool size when reranking.

    Returns:
        A dict with ``"dataset"`` (corpus name), ``"metrics"`` (MRR, nDCG and
        hit rates at 1, 3, 5 and ``top_k``) and ``"cases"`` (per-query
        details).
    """
    widened_k = max(top_k * 5, top_k)
    if use_reranker:
        candidate_k = max(reranker_candidates, top_k)
    else:
        candidate_k = widened_k

    vector_retriever = index.as_retriever(similarity_top_k=candidate_k)
    bm25_retriever = None
    if use_hybrid:
        bm25_retriever = build_bm25_retriever(corpus, chunk_size, chunk_overlap)
    reranker = None
    if use_reranker:
        reranker = CrossEncoderReranker(reranker_model_name)

    cases = []
    hit_counts = dict.fromkeys((1, 3, 5, top_k), 0)
    reciprocal_ranks = []
    ndcg_scores = []

    for query in corpus.queries:
        relevant_doc_ids = corpus.qrels.get(query["query_id"], set())
        question = query["question"]

        ranked = vector_retriever.retrieve(question)
        if bm25_retriever:
            lexical = bm25_retriever.retrieve(question, candidate_k)
            ranked = merge_eval_results(ranked, lexical, candidate_k)
        if reranker:
            ranked = reranker.rerank(question, ranked, top_n=widened_k)

        retrieved = []
        seen_doc_ids = set()
        first_hit_rank = None
        dcg = 0.0
        for result in ranked:
            metadata = result.node.metadata
            doc_id = str(metadata.get("doc_id", ""))
            # Several chunks can belong to one document; only the best-ranked
            # chunk of each document counts.
            if doc_id in seen_doc_ids:
                continue
            seen_doc_ids.add(doc_id)

            rank = len(retrieved) + 1
            hit = doc_id in relevant_doc_ids
            if hit:
                if first_hit_rank is None:
                    first_hit_rank = rank
                dcg += 1 / math.log2(rank + 1)
            retrieved.append(
                {
                    "rank": rank,
                    "doc_id": doc_id,
                    "score": result.score,
                    "hit": hit,
                    "title": metadata.get("title", ""),
                }
            )
            if len(retrieved) >= top_k:
                break

        # Ideal DCG: all relevant documents (at most top_k) ranked first.
        ideal_hits = min(len(relevant_doc_ids), top_k)
        idcg = sum(1 / math.log2(position + 1) for position in range(1, ideal_hits + 1))
        ndcg_scores.append(dcg / idcg if idcg else 0.0)
        reciprocal_ranks.append(1 / first_hit_rank if first_hit_rank else 0.0)

        # A hit within the first k results is the same as the first hit
        # having a rank of at most k.
        for cutoff in hit_counts:
            if first_hit_rank is not None and first_hit_rank <= cutoff:
                hit_counts[cutoff] += 1

        cases.append(
            {
                "query_id": query["query_id"],
                "question": question,
                "relevant_doc_ids": sorted(relevant_doc_ids),
                "first_hit_rank": first_hit_rank,
                "retrieved": retrieved,
            }
        )

    total = len(corpus.queries)

    def _mean(total_value: float) -> float:
        return total_value / total if total else 0.0

    metrics = {
        "queries": total,
        "documents": len(corpus.documents),
        "top_k": top_k,
        "mrr": _mean(sum(reciprocal_ranks)),
        "ndcg_at_k": _mean(sum(ndcg_scores)),
        "reranker_enabled": use_reranker,
        "hybrid_enabled": use_hybrid,
    }
    for cutoff in sorted(hit_counts):
        metrics[f"hit_at_{cutoff}"] = _mean(hit_counts[cutoff])

    return {"dataset": corpus.name, "metrics": metrics, "cases": cases}
def write_reports(report: dict[str, Any]) -> tuple[Path, Path]:
    """Write ``report`` as JSON and as a Markdown summary under ``REPORT_DIR``.

    The JSON file holds the full report. The Markdown file lists every metric
    (floats with four decimals) followed by the first ten cases, each with
    its first hit rank and its first five retrieved documents.

    Returns:
        ``(json_path, md_path)`` of the two files written.
    """
    ensure_dirs()
    dataset_name = report["dataset"]
    json_path = REPORT_DIR / f"{dataset_name}_retrieval_eval.json"
    md_path = REPORT_DIR / f"{dataset_name}_retrieval_eval.md"

    serialized = json.dumps(report, ensure_ascii=False, indent=2)
    json_path.write_text(serialized, encoding="utf-8")

    lines = [f"# Retrieval Eval: {dataset_name}", "", "## Metrics", ""]
    for key, value in report["metrics"].items():
        if isinstance(value, float):
            lines.append(f"- `{key}`: {value:.4f}")
        else:
            lines.append(f"- `{key}`: {value}")

    lines += ["", "## Sample Cases", ""]
    for case in report["cases"][:10]:
        lines += [
            f"### {case['query_id']}",
            "",
            case["question"],
            "",
            f"- first_hit_rank: `{case['first_hit_rank']}`",
        ]
        lines.extend(
            f"- rank {item['rank']}: hit={item['hit']} doc_id=`{item['doc_id']}` score={item['score']}"
            for item in case["retrieved"][:5]
        )
        lines.append("")

    md_path.write_text("\n".join(lines), encoding="utf-8")
    return json_path, md_path
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the retrieval evaluation script."""
    parser = argparse.ArgumentParser(description="Run retrieval eval for RAG datasets.")
    parser.add_argument(
        "--dataset",
        required=True,
        help="beir/scifact, beir/fiqa, open-ragbench, t2-ragbench, or local-options",
    )
    parser.add_argument("--split", default="test")

    # Integer options and their defaults (None means "no limit").
    integer_options = (
        ("--top-k", 5),
        ("--chunk-size", 512),
        ("--chunk-overlap", 64),
        ("--max-corpus-docs", None),
        ("--max-queries", None),
    )
    for flag, default in integer_options:
        parser.add_argument(flag, type=int, default=default)

    for switch in ("--rebuild", "--use-hybrid", "--use-reranker"):
        parser.add_argument(switch, action="store_true")

    parser.add_argument("--reranker-model", default=RERANKER_MODEL_NAME)
    parser.add_argument("--reranker-candidates", type=int, default=25)
    return parser.parse_args()
def main() -> None:
    """Load the chosen corpus, index it, evaluate retrieval and write reports.

    The metrics and the paths of the two report files are printed to stdout.
    """
    args = parse_args()
    corpus = load_eval_corpus(args)
    index = build_index(corpus, args.chunk_size, args.chunk_overlap, args.rebuild)

    evaluation_options = {
        "use_reranker": args.use_reranker,
        "use_hybrid": args.use_hybrid,
        "chunk_size": args.chunk_size,
        "chunk_overlap": args.chunk_overlap,
        "reranker_model_name": args.reranker_model,
        "reranker_candidates": args.reranker_candidates,
    }
    report = evaluate_retrieval(corpus, index, args.top_k, **evaluation_options)

    json_path, md_path = write_reports(report)
    print(json.dumps(report["metrics"], ensure_ascii=False, indent=2))
    print(f"JSON report: {json_path}")
    print(f"Markdown report: {md_path}")


if __name__ == "__main__":
    main()