Spaces:
Sleeping
Sleeping
# evaluation.py - Evaluation System (WITH SAFETY CAPS)
import json
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np
@dataclass
class Question:
    """Represents a single evaluation question.

    Attributes:
        query: The natural-language question posed to the system.
        query_type: One of "content_retrieval", "version_inquiry",
            or "change_retrieval".
        expected_answer: Reference answer text used for content scoring.
        expected_version: Version the answer should come from, or "all"
            for version-inquiry questions spanning every version.
        domain: High-level domain (e.g. "Software", "Healthcare").
        topic: Specific topic within the domain (e.g. "Node.js Assert").
        expected_keywords: Keywords the answer should contain; None means
            no extra keyword check beyond the expected answer's words.
    """
    query: str
    query_type: str  # content_retrieval, version_inquiry, change_retrieval
    expected_answer: str
    expected_version: str
    domain: str
    topic: str
    expected_keywords: Optional[List[str]] = None
class VersionQADataset:
    """Dataset of Question objects for evaluating version-aware QA."""

    def __init__(self, questions: List["Question"]):
        # Questions are stored as-is; no copying or validation is performed.
        self.questions = questions

    @classmethod
    def create_mini_versionqa(cls) -> 'VersionQADataset':
        """Create the Mini-VersionQA dataset as specified.

        Returns a small, fixed benchmark of 13 questions covering four
        domains (Software, Healthcare, Finance, Industrial) and the three
        query types (content_retrieval, version_inquiry, change_retrieval).
        """
        questions = [
            # Software - Node.js Assert
            Question(
                query="What is the assert module in Node.js v20.0?",
                query_type="content_retrieval",
                expected_answer="assert module provides testing functions",
                expected_version="v20.0",
                domain="Software",
                topic="Node.js Assert",
                expected_keywords=["assert", "testing", "module"]
            ),
            Question(
                query="List all versions of the assert module",
                query_type="version_inquiry",
                expected_answer="v20.0, v21.0, v23.0",
                expected_version="all",
                domain="Software",
                topic="Node.js Assert",
                expected_keywords=["v20.0", "v21.0", "v23.0"]
            ),
            Question(
                query="When was the strict mode added to assert?",
                query_type="change_retrieval",
                expected_answer="v21.0",
                expected_version="v21.0",
                domain="Software",
                topic="Node.js Assert",
                expected_keywords=["strict", "mode", "v21.0"]
            ),
            # Software - Bootstrap
            Question(
                query="What are the grid classes in Bootstrap v5.2?",
                query_type="content_retrieval",
                expected_answer="col-*, row classes for responsive grid",
                expected_version="v5.2",
                domain="Software",
                topic="Bootstrap",
                expected_keywords=["grid", "col", "row"]
            ),
            Question(
                query="What changed in Bootstrap from v5.2 to v5.3?",
                query_type="change_retrieval",
                expected_answer="new utility classes and improvements",
                expected_version="v5.3",
                domain="Software",
                topic="Bootstrap",
                expected_keywords=["utility", "classes", "v5.3"]
            ),
            # Software - Spark
            Question(
                query="How does DataFrame work in Spark v3.0?",
                query_type="content_retrieval",
                expected_answer="distributed collection of data organized into named columns",
                expected_version="v3.0",
                domain="Software",
                topic="Spark",
                expected_keywords=["dataframe", "distributed", "columns"]
            ),
            Question(
                query="What was removed in Spark v3.5?",
                query_type="change_retrieval",
                expected_answer="deprecated APIs and legacy features",
                expected_version="v3.5",
                domain="Software",
                topic="Spark",
                expected_keywords=["removed", "deprecated", "v3.5"]
            ),
            # Healthcare
            Question(
                query="What are the treatment guidelines in v1.0?",
                query_type="content_retrieval",
                expected_answer="standard treatment protocols for patient care",
                expected_version="v1.0",
                domain="Healthcare",
                topic="Clinical Guidelines",
                expected_keywords=["treatment", "protocols", "guidelines"]
            ),
            Question(
                query="What changed in clinical guidelines from v1.0 to v2.0?",
                query_type="change_retrieval",
                expected_answer="updated treatment protocols and new recommendations",
                expected_version="v2.0",
                domain="Healthcare",
                topic="Clinical Guidelines",
                expected_keywords=["updated", "protocols", "v2.0"]
            ),
            # Finance
            Question(
                query="What are the compliance requirements in FY2023?",
                query_type="content_retrieval",
                expected_answer="regulatory compliance requirements for financial reporting",
                expected_version="FY2023",
                domain="Finance",
                topic="Compliance Reports",
                expected_keywords=["compliance", "requirements", "regulatory"]
            ),
            Question(
                query="What regulations changed from FY2023 to FY2024?",
                query_type="change_retrieval",
                expected_answer="new regulatory requirements and updated compliance standards",
                expected_version="FY2024",
                domain="Finance",
                topic="Compliance Reports",
                expected_keywords=["regulations", "changed", "FY2024"]
            ),
            # Industrial
            Question(
                query="What is the startup procedure in Rev. 1.0?",
                query_type="content_retrieval",
                expected_answer="machine startup steps and initialization procedures",
                expected_version="Rev. 1.0",
                domain="Industrial",
                topic="Machine Operation",
                expected_keywords=["startup", "procedure", "machine"]
            ),
            Question(
                query="What safety features were added in Rev. 2.0?",
                query_type="change_retrieval",
                expected_answer="enhanced safety features and emergency protocols",
                expected_version="Rev. 2.0",
                domain="Industrial",
                topic="Machine Operation",
                expected_keywords=["safety", "features", "Rev. 2.0"]
            ),
        ]
        return cls(questions)

    @classmethod
    def from_dict(cls, data: List[Dict]) -> 'VersionQADataset':
        """Load dataset from a list of question dicts (inverse of to_dict).

        Required keys per item: query, query_type, expected_answer,
        expected_version, domain, topic. Missing 'expected_keywords'
        defaults to an empty list.
        """
        questions = []
        for q in data:
            questions.append(Question(
                query=q['query'],
                query_type=q['query_type'],
                expected_answer=q['expected_answer'],
                expected_version=q['expected_version'],
                domain=q['domain'],
                topic=q['topic'],
                expected_keywords=q.get('expected_keywords', [])
            ))
        return cls(questions)

    def to_dict(self) -> List[Dict]:
        """Convert dataset to a JSON-serializable list of dicts."""
        return [
            {
                'query': q.query,
                'query_type': q.query_type,
                'expected_answer': q.expected_answer,
                'expected_version': q.expected_version,
                'domain': q.domain,
                'topic': q.topic,
                'expected_keywords': q.expected_keywords
            }
            for q in self.questions
        ]
class Evaluator:
    """Evaluates a VersionRAG system against a baseline RAG system.

    Both systems must expose .query(); the VersionRAG system must also
    expose .version_inquiry() and .change_retrieval(). Answers are
    expected to be dicts with 'answer' (str) and 'sources' (list of dict)
    keys — TODO confirm against the RAG implementations.
    """

    def __init__(self, version_rag, baseline_rag):
        self.version_rag = version_rag
        self.baseline_rag = baseline_rag

    def evaluate(self, dataset: "VersionQADataset") -> Dict:
        """Run the full evaluation on every question in the dataset.

        Returns a dict with per-system metrics ('versionrag', 'baseline'),
        the question count, and the metric deltas under 'improvement'.
        Errors from either system are caught and scored as empty answers.
        """
        versionrag_results = []
        baseline_results = []
        for question in dataset.questions:
            # Evaluate VersionRAG, routing by query type.
            start_time = time.time()
            try:
                if question.query_type == "content_retrieval":
                    vrag_answer = self.version_rag.query(
                        query=question.query,
                        # "all" means no specific version — don't filter.
                        version_filter=question.expected_version if question.expected_version != "all" else None
                    )
                elif question.query_type == "version_inquiry":
                    vrag_answer = self.version_rag.version_inquiry(question.query)
                else:  # change_retrieval
                    vrag_answer = self.version_rag.change_retrieval(question.query)
                vrag_latency = time.time() - start_time
            except Exception as e:
                print(f"VersionRAG error on '{question.query}': {e}")
                vrag_answer = {'answer': '', 'sources': []}
                vrag_latency = 0  # failed queries are excluded from latency

            # Evaluate Baseline (single query path, no version routing).
            start_time = time.time()
            try:
                baseline_answer = self.baseline_rag.query(question.query)
                baseline_latency = time.time() - start_time
            except Exception as e:
                print(f"Baseline error on '{question.query}': {e}")
                baseline_answer = {'answer': '', 'sources': []}
                baseline_latency = 0  # failed queries are excluded from latency

            # Score both answers with the same rubric.
            vrag_score = self._score_answer(
                vrag_answer.get('answer', ''),
                question.expected_answer,
                vrag_answer.get('sources', []),
                question.expected_version,
                question.expected_keywords
            )
            baseline_score = self._score_answer(
                baseline_answer.get('answer', ''),
                question.expected_answer,
                baseline_answer.get('sources', []),
                question.expected_version,
                question.expected_keywords
            )
            versionrag_results.append({
                'question': question,
                'score': vrag_score,
                'latency': vrag_latency,
                'answer': vrag_answer.get('answer', '')
            })
            baseline_results.append({
                'question': question,
                'score': baseline_score,
                'latency': baseline_latency,
                'answer': baseline_answer.get('answer', '')
            })

        # Aggregate per-system metrics and compute the deltas.
        versionrag_metrics = self._compute_metrics(versionrag_results)
        baseline_metrics = self._compute_metrics(baseline_results)
        return {
            'versionrag': versionrag_metrics,
            'baseline': baseline_metrics,
            'questions': len(dataset.questions),
            'improvement': {
                'accuracy': versionrag_metrics['accuracy'] - baseline_metrics['accuracy'],
                'vsa': versionrag_metrics['vsa'] - baseline_metrics['vsa'],
                'hit_at_5': versionrag_metrics['hit_at_5'] - baseline_metrics['hit_at_5']
            }
        }

    def _score_answer(self, answer: str, expected: str, sources: List[Dict],
                      expected_version: str, expected_keywords: List[str] = None) -> Dict:
        """Score one answer for content, keywords, and version awareness.

        Returns a dict of 'content_score', 'version_score', 'keyword_score',
        and 'total_score' (weighted 0.4/0.4/0.2), each capped at 1.0.
        An empty answer scores zero everywhere.
        """
        if not answer:
            return {
                'content_score': 0.0,
                'version_score': 0.0,
                'keyword_score': 0.0,
                'total_score': 0.0
            }
        # Keyword-based scoring: fraction of expected keywords (the
        # expected answer's words plus any explicit keywords) present
        # among the answer's words.
        expected_keywords_set = set(expected.lower().split())
        if expected_keywords:
            expected_keywords_set.update([k.lower() for k in expected_keywords])
        answer_keywords = set(answer.lower().split())
        overlap = len(expected_keywords_set & answer_keywords)
        keyword_score = min(overlap / max(len(expected_keywords_set), 1), 1.0)
        # Content scoring: simple word overlap as a semantic-similarity proxy.
        answer_words = answer.lower().split()
        expected_words = expected.lower().split()
        common_words = set(answer_words) & set(expected_words)
        if len(expected_words) > 0:
            content_score = len(common_words) / len(expected_words)
        else:
            content_score = 0.0
        # Boost content score if the answer is substantial and keyword-rich.
        if len(answer) > 20 and keyword_score > 0.3:
            content_score = min(content_score * 1.2, 1.0)
        # Version-awareness score from the cited sources.
        version_score = self._compute_version_score(sources, expected_version)
        # Combined score with SAFETY CAP ✅
        total_score = min((content_score * 0.4 + version_score * 0.4 + keyword_score * 0.2), 1.0)
        return {
            'content_score': min(content_score, 1.0),
            'version_score': min(version_score, 1.0),
            'keyword_score': min(keyword_score, 1.0),
            'total_score': total_score
        }

    def _compute_version_score(self, sources: List[Dict], expected_version: str) -> float:
        """Compute the version-awareness score for a list of sources.

        For expected_version == "all" (version inquiry): fraction of three
        distinct versions found in the sources, capped at 1.0. Otherwise:
        1.0 if any source's version contains the expected version string,
        else 0.0. Non-dict sources are ignored.
        """
        if expected_version == "all":
            versions_in_sources = set()
            for source in sources:
                if isinstance(source, dict):
                    version = source.get('version', 'N/A')
                    if version != 'N/A':
                        versions_in_sources.add(version)
            # More distinct versions is better; 3 or more scores 1.0.
            return min(len(versions_in_sources) / 3.0, 1.0)
        else:
            # Substring match so e.g. "v5.2" matches "Bootstrap v5.2".
            for source in sources:
                if isinstance(source, dict):
                    version = source.get('version', '')
                    if expected_version in str(version):
                        return 1.0
            return 0.0

    def _compute_metrics(self, results: List[Dict]) -> Dict:
        """Compute aggregate evaluation metrics with SAFETY CAPS ✅.

        Expects each result to carry 'question', 'score', and 'latency'.
        Percent metrics (accuracy, hit@5, vsa, by_type) are capped at 100;
        MRR is capped at 1.0. An empty result list yields all zeros.
        """
        if not results:
            return {
                'accuracy': 0.0,
                'hit_at_5': 0.0,
                'mrr': 0.0,
                'vsa': 0.0,
                'avg_latency': 0.0,
                'by_type': {
                    'content_retrieval': 0.0,
                    'version_inquiry': 0.0,
                    'change_retrieval': 0.0
                }
            }
        # Overall metrics
        total_scores = [r['score']['total_score'] for r in results]
        content_scores = [r['score']['content_score'] for r in results]
        version_scores = [r['score']['version_score'] for r in results]
        latencies = [r['latency'] for r in results]
        # Hit@5: a question counts as a hit if its total score exceeds 0.5.
        hits = [1 if score > 0.5 else 0 for score in total_scores]
        # MRR proxy: map score bands to assumed ranks
        # (>0.7 → rank 1, >0.5 → rank 2, >0.3 → rank 3, else rank 5).
        reciprocal_ranks = []
        for score in total_scores:
            if score > 0.7:
                reciprocal_ranks.append(1.0)
            elif score > 0.5:
                reciprocal_ranks.append(1/2)
            elif score > 0.3:
                reciprocal_ranks.append(1/3)
            else:
                reciprocal_ranks.append(1/5)
        # Group total scores by query type. setdefault keeps unknown
        # types from raising KeyError (they are simply not reported).
        by_type = {
            'content_retrieval': [],
            'version_inquiry': [],
            'change_retrieval': []
        }
        for result in results:
            qtype = result['question'].query_type
            by_type.setdefault(qtype, []).append(result['score']['total_score'])
        # Return metrics with SAFETY CAPS ✅
        return {
            'accuracy': min(np.mean(total_scores) * 100, 100.0),
            'hit_at_5': min(np.mean(hits) * 100, 100.0),
            'mrr': min(np.mean(reciprocal_ranks), 1.0),
            'vsa': min(np.mean(version_scores) * 100, 100.0),  # Version-Sensitive Accuracy
            'avg_latency': np.mean(latencies) if latencies else 0,
            'by_type': {
                'content_retrieval': min(np.mean(by_type['content_retrieval']) * 100, 100.0) if by_type['content_retrieval'] else 0,
                'version_inquiry': min(np.mean(by_type['version_inquiry']) * 100, 100.0) if by_type['version_inquiry'] else 0,
                'change_retrieval': min(np.mean(by_type['change_retrieval']) * 100, 100.0) if by_type['change_retrieval'] else 0
            }
        }