# fix: Resolving backend — commit 29bfc1f (AdarshDRC)
"""
src/services/cache.py — Phase 3: Upstash Redis wrapper
Provides a thin async layer over Upstash Redis (REST API, so no socket
connection required — works fine on HF free tier which blocks raw TCP to
external hosts).
Falls back gracefully to a local in-memory dict if UPSTASH_REDIS_URL is not
set, so the rest of the codebase can import and call CacheService without
any conditional guards.
Usage:
from src.services.cache import cache
await cache.set("key", "value", ttl=3600)
val = await cache.get("key") # returns str | None
await cache.delete("key")
await cache.lpush("list_key", "item")
items = await cache.lrange("list_key", 0, -1)
"""
import json
import os
import time
from typing import Any, Optional
from urllib.parse import quote

import aiohttp
UPSTASH_REDIS_URL = os.getenv("UPSTASH_REDIS_URL", "")
UPSTASH_REDIS_TOKEN = os.getenv("UPSTASH_REDIS_TOKEN", "")
# Fallback in-memory store used when Upstash is not configured.
_mem_store: dict[str, tuple[Any, float]] = {} # key → (value, expires_at or 0)
class CacheService:
    """
    Async Redis cache backed by the Upstash REST API.

    Falls back to a module-level in-memory dict (``_mem_store``) when
    Upstash is not configured, so the rest of the codebase can call
    this service without any availability guards.
    """

    def __init__(self):
        # Enabled only when both the REST URL and the auth token are set.
        self._enabled = bool(UPSTASH_REDIS_URL and UPSTASH_REDIS_TOKEN)
        self._base_url = UPSTASH_REDIS_URL.rstrip("/") if self._enabled else ""
        self._headers = (
            {"Authorization": f"Bearer {UPSTASH_REDIS_TOKEN}"}
            if self._enabled
            else {}
        )
        if not self._enabled:
            print("[Cache] Upstash not configured — using in-memory fallback")

    # ── In-memory fallback helpers ────────────────────────────────────

    @staticmethod
    def _mem_get(key: str) -> Optional[str]:
        """Read *key* from the fallback store, evicting it if expired."""
        entry = _mem_store.get(key)
        if entry is None:
            return None
        val, exp = entry
        if exp and time.time() > exp:
            # Lazy expiry: drop the stale entry on first access.
            _mem_store.pop(key, None)
            return None
        return val

    @staticmethod
    def _mem_list(key: str) -> list:
        """Decode the JSON-encoded list stored under *key* ([] if absent)."""
        return json.loads(_mem_store.get(key, ("[]", 0))[0])

    # ── Internal REST call ────────────────────────────────────────────

    async def _cmd(self, *args) -> Any:
        """
        Execute a Redis command via the Upstash REST path-style API.

        Each argument is URL-encoded (no characters left "safe") so that
        values containing '/', '?', '#', spaces, etc. survive the
        path-segment encoding instead of corrupting the request URL.

        Raises:
            RuntimeError: if Upstash returns an error payload.
        """
        path = "/".join(quote(str(a), safe="") for a in args)
        url = f"{self._base_url}/{path}"
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=self._headers) as resp:
                data = await resp.json()
        if "error" in data:
            raise RuntimeError(f"Upstash error: {data['error']}")
        return data.get("result")

    # ── Public API ────────────────────────────────────────────────────

    async def get(self, key: str) -> Optional[str]:
        """Return the string value for *key*, or None if absent/expired."""
        if not self._enabled:
            return self._mem_get(key)
        return await self._cmd("GET", key)

    async def set(self, key: str, value: Any, ttl: int = 0) -> bool:
        """
        Store *value* under *key*. If ttl > 0, the key expires after that
        many seconds. Non-str values are JSON-serialised first.
        """
        if not isinstance(value, str):
            value = json.dumps(value)
        if not self._enabled:
            exp = time.time() + ttl if ttl else 0
            _mem_store[key] = (value, exp)
            return True
        if ttl:
            await self._cmd("SET", key, value, "EX", ttl)
        else:
            await self._cmd("SET", key, value)
        return True

    async def get_json(self, key: str) -> Optional[Any]:
        """GET *key* and JSON-decode it; returns the raw string if not JSON."""
        raw = await self.get(key)
        if raw is None:
            return None
        try:
            return json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return raw

    async def set_json(self, key: str, value: Any, ttl: int = 0) -> bool:
        """JSON-encode *value* and store it under *key* (see ``set``)."""
        return await self.set(key, json.dumps(value), ttl=ttl)

    async def delete(self, key: str) -> bool:
        """Remove *key*; returns True whether or not the key existed."""
        if not self._enabled:
            _mem_store.pop(key, None)
            return True
        await self._cmd("DEL", key)
        return True

    async def exists(self, key: str) -> bool:
        """Return True if *key* holds a live (non-expired) value."""
        if not self._enabled:
            return self._mem_get(key) is not None
        result = await self._cmd("EXISTS", key)
        return bool(result)

    async def incr(self, key: str) -> int:
        """
        Atomically increment the integer counter at *key* (creating it at 1).

        The fallback path honours expiry: an expired counter restarts at 1
        instead of continuing from its stale value — this keeps
        ``rate_limit_check`` correct in fallback mode.
        """
        if not self._enabled:
            raw = self._mem_get(key)  # evicts the entry if expired
            # Preserve the existing expiry only when the key is still live.
            exp = _mem_store[key][1] if raw is not None else 0
            new_val = int(raw or 0) + 1
            _mem_store[key] = (str(new_val), exp)
            return new_val
        result = await self._cmd("INCR", key)
        return int(result)

    async def expire(self, key: str, ttl: int) -> bool:
        """Set *key* to expire *ttl* seconds from now (no-op if absent)."""
        if not self._enabled:
            if key in _mem_store:
                val, _ = _mem_store[key]
                _mem_store[key] = (val, time.time() + ttl)
            return True
        await self._cmd("EXPIRE", key, ttl)
        return True

    # ── List ops (used for job queue) ─────────────────────────────────

    async def lpush(self, key: str, *values: str) -> int:
        """Push values to the LEFT of a list; returns the new list length."""
        if not self._enabled:
            lst = self._mem_list(key)
            for v in values:
                lst.insert(0, v)
            _mem_store[key] = (json.dumps(lst), 0)
            return len(lst)
        # Single round-trip: LPUSH accepts multiple values and returns the
        # resulting length, matching the fallback's return value.
        result = await self._cmd("LPUSH", key, *values)
        return int(result or 0)

    async def rpop(self, key: str) -> Optional[str]:
        """Pop one value from the RIGHT of a list (queue tail = oldest item)."""
        if not self._enabled:
            lst = self._mem_list(key)
            if not lst:
                return None
            val = lst.pop()
            _mem_store[key] = (json.dumps(lst), 0)
            return val
        return await self._cmd("RPOP", key)

    async def llen(self, key: str) -> int:
        """Return the length of the list at *key* (0 if absent)."""
        if not self._enabled:
            return len(self._mem_list(key))
        result = await self._cmd("LLEN", key)
        return int(result or 0)

    async def lrange(self, key: str, start: int, stop: int) -> list[str]:
        """Return list items from *start* to *stop* inclusive (Redis semantics)."""
        if not self._enabled:
            lst = self._mem_list(key)
            # Redis's stop is inclusive; -1 means "through the end".
            end = None if stop == -1 else stop + 1
            return lst[start:end]
        result = await self._cmd("LRANGE", key, start, stop)
        return result or []

    # ── Rate limiting helper ──────────────────────────────────────────

    async def rate_limit_check(self, key: str, max_calls: int, window_secs: int) -> bool:
        """
        Return True if the caller is within the rate limit, False if exceeded.

        Simple fixed-window counter: the first INCR in a window sets the TTL.
        """
        count = await self.incr(key)
        if count == 1:
            await self.expire(key, window_secs)
        return count <= max_calls
# Module-level singleton — import this everywhere instead of constructing
# CacheService directly, so all callers share one backend configuration.
cache = CacheService()