Spaces:
Running
Running
| """ | |
| src/services/cache.py — Phase 3: Upstash Redis wrapper | |
| Provides a thin async layer over Upstash Redis (REST API over HTTPS, so no | |
| direct Redis TCP connection is required — works on the Hugging Face free tier, | |
| which blocks raw TCP to external hosts). | |
| Falls back gracefully to a process-local in-memory dict if UPSTASH_REDIS_URL or | |
| UPSTASH_REDIS_TOKEN is not set, so the rest of the codebase can import and call | |
| CacheService without any conditional guards. | |
| Usage: | |
| from src.services.cache import cache | |
| await cache.set("key", "value", ttl=3600) | |
| val = await cache.get("key") # returns str | None | |
| await cache.delete("key") | |
| await cache.lpush("list_key", "item") | |
| items = await cache.lrange("list_key", 0, -1) | |
| """ | |
| import json | |
| import os | |
| import time | |
| from typing import Any, Optional | |
| import aiohttp | |
# Upstash REST credentials, read once at import time. Both must be non-empty
# for CacheService to talk to Upstash; otherwise it uses the fallback below.
UPSTASH_REDIS_URL = os.environ.get("UPSTASH_REDIS_URL", "")
UPSTASH_REDIS_TOKEN = os.environ.get("UPSTASH_REDIS_TOKEN", "")

# Process-local fallback store used when Upstash is not configured.
# Maps key -> (value, expires_at); expires_at == 0 means "never expires".
_mem_store: dict[str, tuple[Any, float]] = {}
class CacheService:
    """
    Async Redis cache backed by the Upstash REST API.

    Falls back to the process-local ``_mem_store`` dict when either
    UPSTASH_REDIS_URL or UPSTASH_REDIS_TOKEN is unset. The fallback is not
    shared between processes and is lost on restart.

    Fallback entries are ``(value, expires_at)`` tuples where
    ``expires_at == 0`` means "no expiry"; list keys hold a JSON-encoded list.
    Expired fallback entries are treated as absent by every operation.
    """

    def __init__(self):
        # Enabled only when both module-level credentials are non-empty.
        self._enabled = bool(UPSTASH_REDIS_URL and UPSTASH_REDIS_TOKEN)
        self._base_url = UPSTASH_REDIS_URL.rstrip("/") if self._enabled else ""
        self._headers = (
            {"Authorization": f"Bearer {UPSTASH_REDIS_TOKEN}"}
            if self._enabled
            else {}
        )
        if not self._enabled:
            print("[Cache] Upstash not configured — using in-memory fallback")

    # ── Internal REST call ────────────────────────────────────────────
    async def _cmd(self, *args) -> Any:
        """
        Execute one Redis command via the Upstash REST API.

        The command is POSTed as a JSON array (e.g. ["SET", "k", "v"]) instead
        of being spliced into the URL path, so keys and values containing "/",
        "?", "#", "%", spaces or JSON text reach Redis unchanged.

        Returns:
            The ``result`` field of the Upstash response.
        Raises:
            RuntimeError: Upstash reported an error, or the response body was
                not a JSON object.
        """
        payload = [str(a) for a in args]
        async with aiohttp.ClientSession() as session:
            async with session.post(
                self._base_url, json=payload, headers=self._headers
            ) as resp:
                try:
                    # content_type=None: parse the body even when the server
                    # does not label it application/json.
                    data = await resp.json(content_type=None)
                except ValueError as exc:
                    raise RuntimeError(
                        f"Upstash returned a non-JSON response (HTTP {resp.status})"
                    ) from exc
        if not isinstance(data, dict):
            raise RuntimeError(f"Upstash returned an unexpected response: {data!r}")
        if "error" in data:
            raise RuntimeError(f"Upstash error: {data['error']}")
        return data.get("result")

    # ── In-memory fallback helpers ────────────────────────────────────
    @staticmethod
    def _mem_live(key: str) -> Optional[tuple[Any, float]]:
        """Return the fallback entry for key, or None (purging it) if missing/expired."""
        entry = _mem_store.get(key)
        if entry is None:
            return None
        expires_at = entry[1]
        if expires_at and time.time() > expires_at:
            _mem_store.pop(key, None)
            return None
        return entry

    @classmethod
    def _mem_list(cls, key: str) -> tuple[list, float]:
        """Return (decoded list, expires_at) for a fallback list key; ([], 0) if absent."""
        entry = cls._mem_live(key)
        if entry is None:
            return [], 0
        return json.loads(entry[0]), entry[1]

    # ── Public API ────────────────────────────────────────────────────
    async def get(self, key: str) -> Optional[str]:
        """Return the string stored under key, or None if it is missing/expired."""
        if not self._enabled:
            entry = self._mem_live(key)
            return None if entry is None else entry[0]
        return await self._cmd("GET", key)  # str or None

    async def set(self, key: str, value: Any, ttl: int = 0) -> bool:
        """
        Store value under key. If ttl > 0, key expires after that many seconds;
        otherwise (0, negative or None) the key is stored without expiry on
        both backends.
        Value is JSON-serialised if not already a str. Returns True.
        """
        if not isinstance(value, str):
            value = json.dumps(value)
        expiring = bool(ttl) and ttl > 0
        if not self._enabled:
            _mem_store[key] = (value, time.time() + ttl if expiring else 0)
            return True
        if expiring:
            await self._cmd("SET", key, value, "EX", ttl)
        else:
            await self._cmd("SET", key, value)
        return True

    async def get_json(self, key: str) -> Optional[Any]:
        """Return the JSON-decoded value at key, the raw string if it is not JSON, or None."""
        raw = await self.get(key)
        if raw is None:
            return None
        try:
            return json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return raw

    async def set_json(self, key: str, value: Any, ttl: int = 0) -> bool:
        """JSON-encode value (even if it is already a str) and store it under key."""
        return await self.set(key, json.dumps(value), ttl=ttl)

    async def delete(self, key: str) -> bool:
        """Remove key if present. Returns True."""
        if not self._enabled:
            _mem_store.pop(key, None)
            return True
        await self._cmd("DEL", key)
        return True

    async def exists(self, key: str) -> bool:
        """Return True if key is present (and, in the fallback, not expired)."""
        if not self._enabled:
            return self._mem_live(key) is not None
        result = await self._cmd("EXISTS", key)
        return bool(result)

    async def incr(self, key: str) -> int:
        """
        Increment the integer stored at key by 1 and return the new value.

        A missing or expired key counts as 0, so a rate-limit window that has
        expired starts again from 1. An existing TTL on a live key is kept.
        """
        if not self._enabled:
            entry = self._mem_live(key)
            current, expires_at = entry if entry is not None else ("0", 0)
            new_val = int(current) + 1
            _mem_store[key] = (str(new_val), expires_at)
            return new_val
        return int(await self._cmd("INCR", key))

    async def expire(self, key: str, ttl: int) -> bool:
        """
        Make key expire ttl seconds from now. Returns True.

        In the fallback, a missing or already-expired key is left absent
        rather than being revived with a fresh TTL.
        """
        if not self._enabled:
            entry = self._mem_live(key)
            if entry is not None:
                _mem_store[key] = (entry[0], time.time() + ttl)
            return True
        await self._cmd("EXPIRE", key, ttl)
        return True

    # ── List ops (used for job queue) ─────────────────────────────────
    async def lpush(self, key: str, *values: str) -> int:
        """
        Push values to the LEFT of a list (queue head), one after another, so
        the last argument ends up at the head. Returns the new list length on
        both backends. An existing TTL on the key is preserved.
        """
        if not self._enabled:
            lst, expires_at = self._mem_list(key)
            if values:
                # Same order as pushing each value to index 0 in turn.
                lst[:0] = reversed(values)
                _mem_store[key] = (json.dumps(lst), expires_at)
            return len(lst)
        if not values:
            # Redis rejects LPUSH with no values; report the current length.
            return await self.llen(key)
        return int(await self._cmd("LPUSH", key, *values))

    async def rpop(self, key: str) -> Optional[str]:
        """Pop one value from the RIGHT of a list (queue tail = oldest item)."""
        if not self._enabled:
            lst, expires_at = self._mem_list(key)
            if not lst:
                return None
            val = lst.pop()
            if lst:
                _mem_store[key] = (json.dumps(lst), expires_at)
            else:
                # Redis deletes a list key once its last element is removed.
                _mem_store.pop(key, None)
            return val
        return await self._cmd("RPOP", key)

    async def llen(self, key: str) -> int:
        """Return the length of the list at key (0 if missing)."""
        if not self._enabled:
            return len(self._mem_list(key)[0])
        result = await self._cmd("LLEN", key)
        return int(result or 0)

    async def lrange(self, key: str, start: int, stop: int) -> list[str]:
        """Return list elements start..stop inclusive (Redis LRANGE semantics)."""
        if not self._enabled:
            lst = self._mem_list(key)[0]
            # Redis `stop` is inclusive; -1 means "through the last element",
            # which a Python slice spells as an open end (-1 + 1 would be 0).
            end = None if stop == -1 else stop + 1
            return lst[start:end]
        result = await self._cmd("LRANGE", key, start, stop)
        return result or []

    # ── Rate limiting helper ──────────────────────────────────────────
    async def rate_limit_check(self, key: str, max_calls: int, window_secs: int) -> bool:
        """
        Returns True if the caller is within the rate limit, False if exceeded.

        Fixed-window counter: the first call in a window creates the counter
        and gives it a window_secs TTL; once it expires, counting restarts.
        NOTE: INCR and EXPIRE are separate calls, so a failure between them
        can leave an Upstash counter without a TTL.
        """
        count = await self.incr(key)
        if count == 1:
            await self.expire(key, window_secs)
        return count <= max_calls
# Shared module-level instance — import `cache` everywhere instead of
# constructing CacheService directly.
cache: CacheService = CacheService()