# backend/app/services/semantic_cache.py
# In-memory semantic cache. Replaces Redis-backed CacheService entirely.
# No external service required — works in any environment including HF Spaces.
#
# Design choices:
# - numpy dot product on L2-normalised vectors = cosine similarity (same as cos_sim)
#   without the overhead of importing sentence_transformers.util in the hot path.
# - asyncio.Lock guards all writes. get() reads without the lock: the event loop
#   runs one coroutine at a time and get() never awaits mid-read, so a write can
#   never interleave with a lookup in progress.
# - Oldest-first eviction (by insertion order via list) instead of LRU, to avoid
#   per-access bookkeeping in the hot path. pop(0) is O(N), but it only runs when
#   the cache is full, which is acceptable at max_size=512.

import asyncio
import time
from typing import Optional

import numpy as np

from app.core.logging import get_logger

logger = get_logger(__name__)


class SemanticCache:
    def __init__(
        self,
        max_size: int = 512,
        ttl_seconds: int = 3600,
        similarity_threshold: float = 0.92,
    ) -> None:
        self._max_size = max_size
        self._ttl = ttl_seconds
        self._threshold = similarity_threshold
        self._lock = asyncio.Lock()
        # Each entry: {"embedding": np.ndarray (384,), "response": str, "inserted_at": float}
        # Ordered by insertion time for oldest-first eviction.
        self._entries: list[dict] = []
        self._hits: int = 0

    async def get(self, query_embedding: np.ndarray) -> Optional[str]:
        """
        Cosine similarity lookup. Returns cached response if best score >= threshold.
        query_embedding must already be L2-normalised (bge-small normalises by default).
        """
        if not self._entries:
            return None

        now = time.monotonic()
        # Filter expired entries, then batch all similarities into one numpy op.
        valid = [e for e in self._entries if now - e["inserted_at"] < self._ttl]
        if not valid:
            return None

        matrix = np.stack([e["embedding"] for e in valid])  # (N, 384)
        scores: np.ndarray = matrix @ query_embedding  # cosine sim, shape (N,)

        best_idx = int(np.argmax(scores))
        best_score = float(scores[best_idx])

        if best_score >= self._threshold:
            self._hits += 1
            logger.debug("Semantic cache hit | score=%.4f", best_score)
            return valid[best_idx]["response"]

        return None

    async def set(self, query_embedding: np.ndarray, response: str) -> None:
        """Store a new entry. Drops expired entries, then evicts oldest if still at capacity."""
        async with self._lock:
            now = time.monotonic()
            # Prune expired entries here, under the lock, so get() can stay
            # lock-free and expired entries don't linger until capacity eviction.
            self._entries = [e for e in self._entries if now - e["inserted_at"] < self._ttl]
            if len(self._entries) >= self._max_size:
                # Evict oldest (index 0 is the oldest insertion).
                self._entries.pop(0)
            self._entries.append({
                "embedding": query_embedding,
                "response": response,
                "inserted_at": now,
            })

    async def stats(self) -> dict:
        return {
            "entries": len(self._entries),
            "hits": self._hits,
            "max_size": self._max_size,
            "ttl_seconds": self._ttl,
            "threshold": self._threshold,
        }
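

# ------------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the service wiring). A real
# caller would pass the L2-normalised 384-dim vector produced by the bge-small
# embedder; a random normalised vector stands in for it here so the sketch has
# no model dependency.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    vec = rng.standard_normal(384)
    vec /= np.linalg.norm(vec)  # normalise so dot product == cosine similarity

    async def _demo() -> None:
        cache = SemanticCache(similarity_threshold=0.92)
        assert await cache.get(vec) is None  # cold cache: miss
        await cache.set(vec, "cached answer")
        # Identical vector scores 1.0 >= threshold, so this is a hit.
        assert await cache.get(vec) == "cached answer"
        print(await cache.stats())

    asyncio.run(_demo())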