"""Zero-dependency TF-IDF retriever over the company-memory corpus.
Design choices
--------------
* Pure Python (re + math) so it runs inside OpenEnv/HF Space containers
without pulling sentence-transformers / faiss / langchain.
* Each markdown file is split into paragraph-level chunks, indexed with
token-frequency + inverse-document-frequency, and queried with cosine
over sparse TF-IDF vectors.
* The retriever is the *single source of grounding truth* consumed by
both the specialists (at rollout time) and the grader (at scoring
time), so the reward becomes verifiable instead of fuzzy.
"""
from __future__ import annotations
import math
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Sequence, Tuple
_WORD_RE = re.compile(r"[a-zA-Z][a-zA-Z0-9_-]{1,}")
def _tokenize(text: str) -> List[str]:
return [tok.lower() for tok in _WORD_RE.findall(text or "")]
@dataclass
class MemoryHit:
    """One retrieved chunk: its origin, display text, and similarity score."""

    # Chunk identifier, e.g. "<relative-path>#chunk<i>".
    source: str
    # Flattened (newline-free) chunk text, possibly truncated.
    snippet: str
    # Rounded cosine similarity of the query against this chunk.
    score: float

    def as_citation(self) -> str:
        """Stable string form used in ExpertReport.citations and briefs."""
        return "memory:" + self.source
class Retriever:
    """Lightweight TF-IDF retriever over a directory of markdown files.

    Every ``*.md`` file under ``corpus_dir`` is split into paragraph-level
    chunks; each chunk is indexed as a sparse TF-IDF vector and queried
    with cosine similarity.
    """

    def __init__(self, corpus_dir: Path) -> None:
        self.corpus_dir = Path(corpus_dir)
        self._docs: List[Tuple[str, str]] = []   # (source id, chunk text)
        self._vocab: Dict[str, int] = {}         # token -> vocabulary index
        self._tf: List[Dict[int, float]] = []    # per-chunk sparse term counts
        self._df: Dict[int, int] = {}            # vocab index -> document frequency
        self._norms: List[float] = []            # per-chunk L2 norm of tf-idf vector
        self._num_docs = 0                       # set for real in _build_index
        self._load()
        self._build_index()

    # -- indexing ----------------------------------------------------------
    def _load(self) -> None:
        """Read every markdown file and split it into paragraph chunks."""
        for path in sorted(self.corpus_dir.rglob("*.md")):
            # Normalise Windows path separators so source ids are stable
            # across platforms.
            rel = str(path.relative_to(self.corpus_dir)).replace("\\", "/")
            text = path.read_text(encoding="utf-8")
            # Paragraphs are runs of text separated by blank lines.
            chunks = [chunk.strip() for chunk in re.split(r"\n\s*\n", text) if chunk.strip()]
            for i, chunk in enumerate(chunks):
                self._docs.append((f"{rel}#chunk{i}", chunk))

    def _idf(self, idx: int) -> float:
        """Smoothed inverse document frequency for vocabulary index ``idx``.

        The ``+1`` terms keep the value finite and strictly positive even
        for tokens that appear in every document.  Factored out so the
        indexing and query paths cannot drift apart.
        """
        return math.log((self._num_docs + 1) / (self._df.get(idx, 1) + 1)) + 1.0

    def _build_index(self) -> None:
        """Build sparse term-frequency vectors and document frequencies."""
        for _, text in self._docs:
            tf_doc: Dict[int, float] = {}
            for tok in _tokenize(text):
                if tok not in self._vocab:
                    self._vocab[tok] = len(self._vocab)
                idx = self._vocab[tok]
                tf_doc[idx] = tf_doc.get(idx, 0.0) + 1.0
            self._tf.append(tf_doc)
            for idx in tf_doc:
                self._df[idx] = self._df.get(idx, 0) + 1
        self._num_docs = len(self._docs)
        # Cache L2 norms of the tf-idf vectors for cosine similarity;
        # `or 1.0` guards against division by zero for empty chunks.
        self._norms = []
        for tf_doc in self._tf:
            sq = sum((tf * self._idf(idx)) ** 2 for idx, tf in tf_doc.items())
            self._norms.append(math.sqrt(sq) or 1.0)

    # -- public API --------------------------------------------------------
    def query(self, text: str, k: int = 3) -> List[MemoryHit]:
        """Return the top-``k`` chunks for ``text`` by TF-IDF cosine.

        Returns an empty list when ``text`` is empty, the corpus is empty,
        or no query token appears in the vocabulary.
        """
        if not text or not self._docs:
            return []
        tokens = _tokenize(text)
        if not tokens:
            return []
        # Sparse term counts of the query, restricted to known vocabulary.
        q_tf: Dict[int, float] = {}
        for tok in tokens:
            if tok in self._vocab:
                idx = self._vocab[tok]
                q_tf[idx] = q_tf.get(idx, 0.0) + 1.0
        if not q_tf:
            return []
        q_norm = math.sqrt(sum((tf * self._idf(idx)) ** 2 for idx, tf in q_tf.items())) or 1.0
        scores: List[Tuple[float, int]] = []
        for doc_idx, tf_doc in enumerate(self._tf):
            dot = 0.0
            for idx, qt in q_tf.items():
                tfd = tf_doc.get(idx)
                if tfd is None:
                    continue
                idf = self._idf(idx)
                dot += (qt * idf) * (tfd * idf)
            if dot <= 0.0:
                continue
            scores.append((dot / (q_norm * self._norms[doc_idx]), doc_idx))
        scores.sort(reverse=True)
        hits: List[MemoryHit] = []
        for score, doc_idx in scores[: max(k, 0)]:
            source, chunk = self._docs[doc_idx]
            snippet = chunk.replace("\n", " ").strip()
            if len(snippet) > 280:
                snippet = snippet[:277] + "..."
            hits.append(MemoryHit(source=source, snippet=snippet, score=round(float(score), 4)))
        return hits

    def sources(self) -> List[str]:
        """All indexed chunk source ids, in indexing order."""
        return [src for src, _ in self._docs]

    def has_source(self, source: str) -> bool:
        """True if ``source`` exactly matches an indexed chunk id."""
        return any(src == source for src, _ in self._docs)

    def contains_any(self, text: str, sources: Sequence[str]) -> bool:
        """Return True if any of ``sources`` appears verbatim in ``text``.

        Used by the grader to verify citations are grounded in the corpus
        rather than hallucinated strings.
        """
        if not text:
            return False
        return any(src and src in text for src in sources)

    def count_grounded_citations(self, citations: Iterable[str]) -> int:
        """Count ``memory:<source>`` citations whose source is indexed.

        Non-string entries and strings without the ``memory:`` prefix are
        ignored; ``None`` is treated as an empty iterable.
        """
        known = set(self.sources())
        n = 0
        for citation in citations or []:
            if not isinstance(citation, str) or not citation.startswith("memory:"):
                continue
            if citation[len("memory:"):] in known:
                n += 1
        return n
__all__ = ["MemoryHit", "Retriever"]