File size: 5,750 Bytes
4b4ff9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""Disk-backed cache wrappers for LLMProvider / EmbeddingProvider.



Per `docs/02_architecture_v2.md §6.5`: each unique (provider, model,

prompt) goes to the upstream API exactly once. Repeat calls hit a local

`diskcache.Cache` and return in microseconds with zero quota burn.



This buys two things that matter for portfolio-grade ablations:



1.  **Determinism.** Mistral codestral at temperature=0 is *near*

    deterministic but not exactly so — config E showed +4pp over C at

    n=50 with literally identical execution paths and repair fired

    0/50. With cache, the second run reads the same response bytes.

2.  **Free re-runs.** Bumping `schema_top_k` or `fk_hops` and rerunning

    config C only pays the API for the prompts that actually changed.



Cache key for generate:

    sha256(provider.name | provider.model | system | prompt | temperature | max_tokens)



Cache key for embed (per text, not per batch — so reordering inputs hits

the same entries):

    sha256(provider.name | provider.embed_model | text)



Cached values are pydantic-serialised dicts; `latency_ms` on a hit is

reset to 0.0 so eval reports don't accidentally average cache hits with

live API latency.

"""

from __future__ import annotations

import hashlib
import json
from pathlib import Path
from typing import Any

import diskcache

from nl_sql.llm.providers.base import (
    EmbeddingProvider,
    EmbedRequest,
    EmbedResponse,
    GenerateRequest,
    GenerateResponse,
    LLMProvider,
)


def _hash_key(parts: list[Any]) -> str:
    raw = json.dumps(parts, sort_keys=True, ensure_ascii=False, default=str)
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()


def _open_cache(root: Path | str, *, size_limit_gb: int) -> diskcache.Cache:
    """Open a diskcache rooted at `root`, creating the directory if needed.

    Args:
        root: Directory holding the cache files; missing parents are created.
        size_limit_gb: Size limit in GiB, converted to bytes for diskcache.

    Returns:
        An open `diskcache.Cache`.
    """
    Path(root).mkdir(parents=True, exist_ok=True)
    limit_bytes = size_limit_gb * 1024**3
    return diskcache.Cache(directory=str(root), size_limit=limit_bytes)


class CachingLLMProvider:
    """Wrap an `LLMProvider` so repeat `generate()` calls are served from disk.

    The wrapper copies `name` and `model` from the wrapped provider so
    downstream code that reads `getattr(provider, "model", "?")` (e.g. eval
    reports) keeps working.

    Instances hold an open `diskcache.Cache`. Call `close()` when done, or use
    the wrapper as a context manager (`with CachingLLMProvider(...) as p:`),
    which closes the cache on exit.
    """

    def __init__(
        self,
        inner: LLMProvider,
        *,
        cache_dir: Path | str,
        size_limit_gb: int = 4,
    ) -> None:
        """Wrap `inner`, storing generations under `<cache_dir>/gen`.

        Args:
            inner: Provider that performs the real upstream call on a miss.
            cache_dir: Root cache directory; a `gen` subdirectory is used.
            size_limit_gb: Size limit for the on-disk cache, in GiB.
        """
        self._inner = inner
        self.name = inner.name
        self.model = inner.model
        self._cache = _open_cache(Path(cache_dir) / "gen", size_limit_gb=size_limit_gb)

    def generate(self, req: GenerateRequest) -> GenerateResponse:
        """Return the response for `req`, calling upstream only on a cache miss.

        On a hit, the stored response is rebuilt with `latency_ms` set to 0.0,
        since no API call was made. On a miss, the upstream response is stored
        via `model_dump()` and returned unchanged.
        """
        key = self._key_for(req)
        hit = self._cache.get(key)
        if hit is not None:
            data = dict(hit)
            data["latency_ms"] = 0.0  # honest signal: this didn't hit the wire
            return GenerateResponse(**data)

        resp = self._inner.generate(req)
        self._cache.set(key, resp.model_dump())
        return resp

    def _key_for(self, req: GenerateRequest) -> str:
        """Build the cache key from the wrapped provider identity and `req`."""
        return _hash_key(
            [
                # Version tag: changing it makes every earlier entry unreachable.
                "gen.v1",
                self._inner.name,
                self._inner.model,
                req.system or "",  # None and "" are treated as the same system prompt
                req.prompt,
                req.temperature,
                req.max_tokens,
            ]
        )

    def close(self) -> None:
        """Close the underlying disk cache."""
        self._cache.close()

    def __enter__(self) -> CachingLLMProvider:
        """Return self so the wrapper can be used in a `with` statement."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the cache on exit; exceptions from the body propagate."""
        self.close()


class CachingEmbeddingProvider:
    """Wrap an `EmbeddingProvider` so per-text embeddings are cached.

    Batched calls are split into per-text cache lookups. Only the missing
    texts are forwarded to the upstream provider, in a single batch, with
    repeated texts collapsed so each distinct text is embedded at most once
    per call. Re-indexing the same schema chunks is therefore free, and
    partial overlaps (e.g. one new column added) only pay for the delta.

    Instances hold an open `diskcache.Cache`. Call `close()` when done, or use
    the wrapper as a context manager, which closes the cache on exit.
    """

    def __init__(
        self,
        inner: EmbeddingProvider,
        *,
        cache_dir: Path | str,
        size_limit_gb: int = 4,
    ) -> None:
        """Wrap `inner`, storing embeddings under `<cache_dir>/embed`.

        Args:
            inner: Provider that computes embeddings for cache misses.
            cache_dir: Root cache directory; an `embed` subdirectory is used.
            size_limit_gb: Size limit for the on-disk cache, in GiB.
        """
        self._inner = inner
        self.name = inner.name
        self.embed_model = inner.embed_model
        self._cache = _open_cache(Path(cache_dir) / "embed", size_limit_gb=size_limit_gb)

    def embed(self, req: EmbedRequest) -> EmbedResponse:
        """Embed `req.texts`, serving cached vectors and fetching only misses.

        Vectors are returned in the same order as `req.texts`. When every text
        is cached, no upstream call is made and `model` is the wrapped
        provider's `embed_model`. Otherwise `model` comes from the upstream
        response.

        Raises:
            RuntimeError: if upstream returns a different number of vectors
                than the number of distinct texts it was asked to embed.
        """
        keys = [self._key_for(text) for text in req.texts]
        cached: list[list[float] | None] = [self._cache.get(k) for k in keys]

        # Distinct missing keys, in first-seen order, mapped to their text.
        # Repeated texts share a key, so each is sent upstream only once.
        pending: dict[str, str] = {}
        for key, text, vec in zip(keys, req.texts, cached):
            if vec is None:
                pending.setdefault(key, text)

        if not pending:
            vectors = [v for v in cached if v is not None]
            return EmbedResponse(vectors=vectors, model=self._inner.embed_model)

        fresh = self._inner.embed(EmbedRequest(texts=list(pending.values())))
        if len(fresh.vectors) != len(pending):
            raise RuntimeError(
                "embed batch length mismatch: "
                f"requested {len(pending)}, got {len(fresh.vectors)}"
            )

        filled: dict[str, list[float]] = {}
        for key, vec in zip(pending, fresh.vectors, strict=True):
            filled[key] = list(vec)
            self._cache.set(key, filled[key])

        # Copy per position so duplicate texts don't share one mutable list.
        vectors = [
            v if v is not None else list(filled[k]) for k, v in zip(keys, cached)
        ]
        return EmbedResponse(vectors=vectors, model=fresh.model)

    def _key_for(self, text: str) -> str:
        """Build the per-text cache key from the wrapped provider identity."""
        return _hash_key(
            [
                # Version tag: changing it makes every earlier entry unreachable.
                "embed.v1",
                self._inner.name,
                self._inner.embed_model,
                text,
            ]
        )

    def close(self) -> None:
        """Close the underlying disk cache."""
        self._cache.close()

    def __enter__(self) -> CachingEmbeddingProvider:
        """Return self so the wrapper can be used in a `with` statement."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the cache on exit; exceptions from the body propagate."""
        self.close()


def wrap_with_cache(
    provider: LLMProvider,
    *,
    cache_dir: Path | str,
    size_limit_gb: int = 4,
) -> CachingLLMProvider:
    """Return `provider` wrapped in a `CachingLLMProvider`.

    Convenience for the common case where only generation (not embedding)
    needs caching; `cache_dir` and `size_limit_gb` are forwarded unchanged.
    """
    wrapped = CachingLLMProvider(
        provider,
        cache_dir=cache_dir,
        size_limit_gb=size_limit_gb,
    )
    return wrapped