File size: 4,902 Bytes
b5cb5bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import json
import time
from collections import OrderedDict
from dataclasses import dataclass
from hashlib import sha256
from threading import Lock
from typing import Any, Dict, Optional

try:
    import redis.asyncio as redis_async  # type: ignore[import-not-found]
except Exception:  # pragma: no cover - optional dependency
    redis_async = None  # type: ignore[assignment]


@dataclass
class _CacheRecord:
    """A single entry in the local (in-process) cache."""

    # Cached payload; the owning cache JSON-normalizes values before storing.
    value: Any
    # Absolute expiry deadline in seconds on the owning cache's clock;
    # the entry is considered stale once now >= expires_at.
    expires_at: float


class DeterministicResponseCache:
    """TTL + LRU response cache with optional Redis backing.

    - The local in-process cache (an ``OrderedDict`` used as an LRU) is always
      consulted first and is bounded by ``max_entries``.
    - Redis is optional and strictly fail-open: any Redis error is logged via
      the injected logger and otherwise treated as a miss / no-op.
    - Values are JSON-roundtripped on ``set`` so cached payloads stay
      serializable and detached from caller-held references.
    - ``clear`` empties only the local cache; Redis entries are left to
      expire via their own TTLs.
    - Local expiry uses a monotonic clock so wall-clock adjustments (NTP,
      DST) cannot prematurely expire or immortalize entries.
    """

    def __init__(
        self,
        *,
        enabled: bool,
        max_entries: int,
        redis_url: Optional[str] = None,
        redis_prefix: str = "mathpulse:det-cache:",
        logger: Any = None,
    ) -> None:
        """Create the cache.

        Args:
            enabled: Master switch; when False, ``get``/``set`` are no-ops.
            max_entries: Upper bound for the local LRU (clamped to >= 1).
            redis_url: Optional Redis connection URL; ignored when the
                optional redis dependency is not installed.
            redis_prefix: Namespace prefix applied to every Redis key.
            logger: Optional logger; only its ``warning`` method is used.
        """
        self.enabled = bool(enabled)
        self.max_entries = max(1, int(max_entries))
        self.redis_prefix = redis_prefix
        self.logger = logger

        self._lock = Lock()
        self._local: OrderedDict[str, _CacheRecord] = OrderedDict()

        self._redis = None
        if self.enabled and redis_url and redis_async is not None:
            try:
                self._redis = redis_async.from_url(redis_url, encoding="utf-8", decode_responses=True)
            except Exception as err:
                # Fail-open: a bad URL or unreachable server must not break callers.
                self._warn(f"Redis cache disabled: failed to initialize client: {err}")
                self._redis = None

    def build_cache_key(self, namespace: str, payload: Dict[str, Any]) -> str:
        """Return a deterministic ``"<namespace>:<sha256-hex>"`` key for *payload*.

        The payload is serialized with sorted keys and compact separators so
        logically-equal dicts always produce the same digest; non-JSON values
        are stringified via ``default=str``.
        """
        canonical = json.dumps(payload, sort_keys=True, separators=(",", ":"), default=str, ensure_ascii=True)
        digest = sha256(canonical.encode("utf-8")).hexdigest()
        return f"{namespace}:{digest}"

    async def get(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, or ``None`` on a miss.

        Checks the local LRU first; on a local miss, falls back to Redis and
        re-populates the local cache using Redis's remaining TTL. A cached
        JSON ``null`` is indistinguishable from a miss. Local hits return the
        cached object itself — callers must not mutate it.
        """
        if not self.enabled:
            return None

        local_hit = self._get_local(key)
        if local_hit is not None:
            return local_hit

        if self._redis is None:
            return None

        redis_key = self._redis_key(key)
        try:
            raw = await self._redis.get(redis_key)
            if raw is None:
                return None
            decoded = json.loads(raw)

            # Re-read the TTL so the local copy expires no later than the
            # Redis entry. Non-positive TTL (key vanished, or persistent key)
            # skips local caching entirely.
            ttl_seconds = await self._redis.ttl(redis_key)
            if isinstance(ttl_seconds, int) and ttl_seconds > 0:
                self._set_local(key, decoded, ttl_seconds)
            return decoded
        except Exception as err:
            self._warn(f"Redis cache get failed for {key}: {err}")
            return None

    async def set(self, key: str, value: Any, ttl_seconds: int) -> None:
        """Store *value* under *key* for *ttl_seconds*; no-op when disabled or ttl <= 0."""
        if not self.enabled:
            return

        ttl = int(ttl_seconds)
        if ttl <= 0:
            return

        normalized_value = self._normalize(value)
        self._set_local(key, normalized_value, ttl)

        if self._redis is None:
            return

        redis_key = self._redis_key(key)
        try:
            await self._redis.set(redis_key, json.dumps(normalized_value, separators=(",", ":"), default=str), ex=ttl)
        except Exception as err:
            # Fail-open: a Redis write failure leaves the local entry intact.
            self._warn(f"Redis cache set failed for {key}: {err}")

    async def clear(self) -> None:
        """Drop all local entries. Redis keys are left to expire via their TTLs."""
        with self._lock:
            self._local.clear()

    def _normalize(self, value: Any) -> Any:
        """JSON-roundtrip *value* so it is serializable and detached from the caller."""
        return json.loads(json.dumps(value, default=str))

    def _redis_key(self, key: str) -> str:
        """Prefix *key* with the configured Redis namespace."""
        return f"{self.redis_prefix}{key}"

    def _now(self) -> float:
        """Clock used for local expiry.

        Monotonic by design: expiry deadlines are internal, so a clock that
        never goes backwards is strictly safer than wall time here.
        """
        return time.monotonic()

    def _get_local(self, key: str) -> Optional[Any]:
        """Return a fresh local value (marking it most-recently-used), else ``None``."""
        now = self._now()
        with self._lock:
            self._prune_locked(now)
            record = self._local.get(key)
            if record is None:
                return None
            if record.expires_at <= now:
                # Expired between prune and lookup resolution; drop lazily.
                self._local.pop(key, None)
                return None
            self._local.move_to_end(key, last=True)
            return record.value

    def _set_local(self, key: str, value: Any, ttl_seconds: int) -> None:
        """Insert or refresh a local entry, evicting LRU entries beyond ``max_entries``."""
        # Sample the clock once so prune and expiry share the same instant.
        now = self._now()
        expires_at = now + ttl_seconds
        with self._lock:
            self._prune_locked(now)
            self._local[key] = _CacheRecord(value=value, expires_at=expires_at)
            self._local.move_to_end(key, last=True)
            while len(self._local) > self.max_entries:
                # popitem(last=False) removes the least recently used entry.
                self._local.popitem(last=False)

    def _prune_locked(self, now: float) -> None:
        """Remove expired entries. Caller must already hold ``self._lock``."""
        expired_keys = [cache_key for cache_key, record in self._local.items() if record.expires_at <= now]
        for cache_key in expired_keys:
            self._local.pop(cache_key, None)

    def _warn(self, message: str) -> None:
        """Emit a warning through the injected logger, if one was provided."""
        if self.logger is not None:
            self.logger.warning(message)