File size: 5,750 Bytes
4b4ff9e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 | """Disk-backed cache wrappers for LLMProvider / EmbeddingProvider.
Per `docs/02_architecture_v2.md §6.5`: each unique (provider, model,
prompt) goes to the upstream API exactly once. Repeat calls hit a local
`diskcache.Cache` and return in microseconds with zero quota burn.
This buys two things that matter for portfolio-grade ablations:
1. **Determinism.** Mistral codestral at temperature=0 is *near*
deterministic but not exactly so — config E showed +4pp over C at
n=50 with literally identical execution paths and repair fired
0/50. With cache, the second run reads the same response bytes.
2. **Free re-runs.** Bumping `schema_top_k` or `fk_hops` and rerunning
config C only pays the API for the prompts that actually changed.
Cache key for generate:
sha256(provider.name | provider.model | system | prompt | temperature | max_tokens)
Cache key for embed (per text, not per batch — so reordering inputs hits
the same entries):
sha256(provider.name | provider.embed_model | text)
Cached values are pydantic-serialised dicts; `latency_ms` on a hit is
reset to 0.0 so eval reports don't accidentally average cache hits with
live API latency.
"""
from __future__ import annotations
import hashlib
import json
from pathlib import Path
from typing import Any
import diskcache
from nl_sql.llm.providers.base import (
EmbeddingProvider,
EmbedRequest,
EmbedResponse,
GenerateRequest,
GenerateResponse,
LLMProvider,
)
def _hash_key(parts: list[Any]) -> str:
raw = json.dumps(parts, sort_keys=True, ensure_ascii=False, default=str)
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
def _open_cache(root: Path | str, *, size_limit_gb: int) -> diskcache.Cache:
Path(root).mkdir(parents=True, exist_ok=True)
return diskcache.Cache(directory=str(root), size_limit=size_limit_gb * 1024**3)
class CachingLLMProvider:
"""Wrap an `LLMProvider` so repeat `generate()` calls are served from disk.
The wrapper preserves `name` and `model` so downstream code that reads
`getattr(provider, "model", "?")` (e.g. eval reports) keeps working.
"""
def __init__(
self,
inner: LLMProvider,
*,
cache_dir: Path | str,
size_limit_gb: int = 4,
) -> None:
self._inner = inner
self.name = inner.name
self.model = inner.model
self._cache = _open_cache(Path(cache_dir) / "gen", size_limit_gb=size_limit_gb)
def generate(self, req: GenerateRequest) -> GenerateResponse:
key = _hash_key(
[
"gen.v1",
self._inner.name,
self._inner.model,
req.system or "",
req.prompt,
req.temperature,
req.max_tokens,
]
)
hit = self._cache.get(key)
if hit is not None:
data = dict(hit)
data["latency_ms"] = 0.0 # honest signal: this didn't hit the wire
return GenerateResponse(**data)
resp = self._inner.generate(req)
self._cache.set(key, resp.model_dump())
return resp
def close(self) -> None:
self._cache.close()
class CachingEmbeddingProvider:
"""Wrap an `EmbeddingProvider` so per-text embeddings are cached.
Batched calls are split into per-text cache lookups; only the missing
texts are forwarded to the upstream provider in a single batch. This
means re-indexing the same schema chunks is free, and partial overlaps
(e.g. one new column added) only pay for the delta.
"""
def __init__(
self,
inner: EmbeddingProvider,
*,
cache_dir: Path | str,
size_limit_gb: int = 4,
) -> None:
self._inner = inner
self.name = inner.name
self.embed_model = inner.embed_model
self._cache = _open_cache(Path(cache_dir) / "embed", size_limit_gb=size_limit_gb)
def embed(self, req: EmbedRequest) -> EmbedResponse:
keys = [self._key_for(text) for text in req.texts]
cached: list[list[float] | None] = [self._cache.get(k) for k in keys]
missing_idx = [i for i, v in enumerate(cached) if v is None]
if not missing_idx:
vectors = [v for v in cached if v is not None]
return EmbedResponse(vectors=vectors, model=self._inner.embed_model)
missing_texts = [req.texts[i] for i in missing_idx]
fresh = self._inner.embed(EmbedRequest(texts=missing_texts))
if len(fresh.vectors) != len(missing_idx):
raise RuntimeError(
"embed batch length mismatch: "
f"requested {len(missing_idx)}, got {len(fresh.vectors)}"
)
for j, vec in zip(missing_idx, fresh.vectors, strict=True):
self._cache.set(keys[j], list(vec))
cached[j] = list(vec)
vectors = [v for v in cached if v is not None]
return EmbedResponse(vectors=vectors, model=fresh.model)
def _key_for(self, text: str) -> str:
return _hash_key(
[
"embed.v1",
self._inner.name,
self._inner.embed_model,
text,
]
)
def close(self) -> None:
self._cache.close()
def wrap_with_cache(
provider: LLMProvider,
*,
cache_dir: Path | str,
size_limit_gb: int = 4,
) -> CachingLLMProvider:
"""Convenience wrapper for the common case (LLMProvider only)."""
return CachingLLMProvider(provider, cache_dir=cache_dir, size_limit_gb=size_limit_gb)
|