File size: 5,725 Bytes
d02bacd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""Zero-dependency TF-IDF retriever over the company-memory corpus.

Design choices
--------------
* Pure Python (re + math) so it runs inside OpenEnv/HF Space containers
  without pulling sentence-transformers / faiss / langchain.
* Each markdown file is split into paragraph-level chunks, indexed with
  token-frequency + inverse-document-frequency, and queried with cosine
  over sparse TF-IDF vectors.
* The retriever is the *single source of grounding truth* consumed by
  both the specialists (at rollout time) and the grader (at scoring
  time), so the reward becomes verifiable instead of fuzzy.
"""
from __future__ import annotations

import math
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Sequence, Tuple

_WORD_RE = re.compile(r"[a-zA-Z][a-zA-Z0-9_-]{1,}")


def _tokenize(text: str) -> List[str]:
    return [tok.lower() for tok in _WORD_RE.findall(text or "")]


@dataclass
class MemoryHit:
    """One retrieved chunk: where it came from, its text, and its match score."""

    source: str  # chunk identifier of the form "<relative/path>#chunk<i>"
    snippet: str  # flattened (and possibly truncated) chunk text
    score: float  # cosine similarity against the query, rounded by the retriever

    def as_citation(self) -> str:
        """Stable string form used in ExpertReport.citations and briefs."""
        return "memory:" + self.source


class Retriever:
    """Lightweight TF-IDF retriever over a directory of markdown files.

    Every ``*.md`` file under ``corpus_dir`` is split on blank lines into
    paragraph-level chunks; each chunk is indexed as one document whose
    identifier is ``<relative/path>#chunk<i>``.  Queries are scored with
    cosine similarity over sparse TF-IDF vectors using the smoothed idf
    ``log((N + 1) / (df + 1)) + 1`` (same formula as sklearn's
    ``smooth_idf``), computed in one place by :meth:`_idf`.
    """

    def __init__(self, corpus_dir: Path) -> None:
        """Load all markdown chunks under ``corpus_dir`` and build the index."""
        self.corpus_dir = Path(corpus_dir)
        # (source_id, chunk_text) pairs, stable across runs (sorted paths).
        self._docs: List[Tuple[str, str]] = []
        # token -> dense integer id
        self._vocab: Dict[str, int] = {}
        # per-document sparse term-frequency vectors, keyed by token id
        self._tf: List[Dict[int, float]] = []
        # token id -> number of documents containing the token
        self._df: Dict[int, int] = {}
        # cached L2 norm of each document's tf-idf vector
        self._norms: List[float] = []
        self._load()
        self._build_index()

    # -- indexing ----------------------------------------------------------

    def _load(self) -> None:
        """Split every markdown file into paragraph chunks and record them."""
        for path in sorted(self.corpus_dir.rglob("*.md")):
            # Normalise Windows separators so source ids are platform-stable.
            rel = str(path.relative_to(self.corpus_dir)).replace("\\", "/")
            text = path.read_text(encoding="utf-8")
            chunks = [chunk.strip() for chunk in re.split(r"\n\s*\n", text) if chunk.strip()]
            for i, chunk in enumerate(chunks):
                self._docs.append((f"{rel}#chunk{i}", chunk))

    def _idf(self, idx: int) -> float:
        """Smoothed inverse document frequency for token id ``idx``.

        The +1 smoothing keeps the value finite and positive even for
        tokens present in every document.  Requires ``_num_docs`` to be
        set, i.e. callable only during/after :meth:`_build_index`.
        """
        return math.log((self._num_docs + 1) / (self._df.get(idx, 1) + 1)) + 1.0

    def _build_index(self) -> None:
        """Compute term/document frequencies and cache per-doc vector norms."""
        for _, text in self._docs:
            tf_doc: Dict[int, float] = {}
            for tok in _tokenize(text):
                if tok not in self._vocab:
                    self._vocab[tok] = len(self._vocab)
                idx = self._vocab[tok]
                tf_doc[idx] = tf_doc.get(idx, 0.0) + 1.0
            self._tf.append(tf_doc)
            for idx in tf_doc:
                self._df[idx] = self._df.get(idx, 0) + 1
        self._num_docs = len(self._docs)
        # Cache L2 norms of tf-idf vectors for cosine similarity.
        self._norms = []
        for tf_doc in self._tf:
            sq = 0.0
            for idx, tf in tf_doc.items():
                sq += (tf * self._idf(idx)) ** 2
            # `or 1.0` guards against division by zero for empty vectors.
            self._norms.append(math.sqrt(sq) or 1.0)

    # -- public API --------------------------------------------------------

    def query(self, text: str, k: int = 3) -> List[MemoryHit]:
        """Return the top-``k`` chunks most similar to ``text``.

        Tokens absent from the corpus vocabulary are ignored; chunks with
        zero overlap are skipped entirely.  Returns an empty list for an
        empty query, an empty corpus, or a query with no known tokens.
        """
        if not text or not self._docs:
            return []
        tokens = _tokenize(text)
        if not tokens:
            return []
        q_tf: Dict[int, float] = {}
        for tok in tokens:
            if tok in self._vocab:
                idx = self._vocab[tok]
                q_tf[idx] = q_tf.get(idx, 0.0) + 1.0
        if not q_tf:
            return []
        # Hoist each query token's idf out of the per-document loop: idf
        # depends only on the token, not on the candidate document.
        # idx -> (query tf*idf weight, idf); also accumulate the query norm.
        q_vec: Dict[int, Tuple[float, float]] = {}
        q_sq = 0.0
        for idx, tf in q_tf.items():
            idf = self._idf(idx)
            q_vec[idx] = (tf * idf, idf)
            q_sq += (tf * idf) ** 2
        q_norm = math.sqrt(q_sq) or 1.0

        scores: List[Tuple[float, int]] = []
        for doc_idx, tf_doc in enumerate(self._tf):
            dot = 0.0
            for idx, (q_weight, idf) in q_vec.items():
                tfd = tf_doc.get(idx)
                if tfd is None:
                    continue
                dot += q_weight * (tfd * idf)
            if dot <= 0.0:
                continue
            cosine = dot / (q_norm * self._norms[doc_idx])
            scores.append((cosine, doc_idx))
        scores.sort(reverse=True)
        hits: List[MemoryHit] = []
        for score, doc_idx in scores[: max(k, 0)]:
            source, chunk = self._docs[doc_idx]
            # Flatten the chunk to one line and cap the snippet at 280 chars.
            snippet = chunk.replace("\n", " ").strip()
            if len(snippet) > 280:
                snippet = snippet[:277] + "..."
            hits.append(MemoryHit(source=source, snippet=snippet, score=round(float(score), 4)))
        return hits

    def sources(self) -> List[str]:
        """All chunk identifiers in index order."""
        return [src for src, _ in self._docs]

    def has_source(self, source: str) -> bool:
        """True if ``source`` exactly matches one of the indexed chunk ids."""
        return any(src == source for src, _ in self._docs)

    def contains_any(self, text: str, sources: Sequence[str]) -> bool:
        """Return True if any of ``sources`` appears verbatim in ``text``.

        Used by the grader to verify citations are grounded in the corpus
        rather than hallucinated strings.
        """
        if not text:
            return False
        return any(src and src in text for src in sources)

    def count_grounded_citations(self, citations: Iterable[str]) -> int:
        """Count entries of the form ``memory:<source>`` whose source exists.

        Non-string entries and strings without the ``memory:`` prefix are
        ignored rather than raising, since citations arrive from model
        output and may be malformed.
        """
        known = set(self.sources())
        n = 0
        for citation in citations or []:
            if not isinstance(citation, str) or not citation.startswith("memory:"):
                continue
            if citation[len("memory:"):] in known:
                n += 1
        return n


# Public API: the hit record and the retriever; helpers stay module-private.
__all__ = ["MemoryHit", "Retriever"]