File size: 7,200 Bytes
29bfc1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
"""
src/services/cache.py — Phase 3: Upstash Redis wrapper

Provides a thin async layer over Upstash Redis (REST API, so no socket
connection required — works fine on HF free tier which blocks raw TCP to
external hosts).

Falls back gracefully to a local in-memory dict if UPSTASH_REDIS_URL or
UPSTASH_REDIS_TOKEN is not set, so the rest of the codebase can import and
call CacheService without any conditional guards.

Usage:
    from src.services.cache import cache
    await cache.set("key", "value", ttl=3600)
    val = await cache.get("key")          # returns str | None
    await cache.delete("key")
    await cache.lpush("list_key", "item")
    items = await cache.lrange("list_key", 0, -1)
"""

import json
import os
import time
from typing import Any, Optional

import aiohttp

# Upstash REST credentials, read once from the environment at import time.
# A missing variable resolves to the empty string.
UPSTASH_REDIS_URL = os.environ.get("UPSTASH_REDIS_URL", "")
UPSTASH_REDIS_TOKEN = os.environ.get("UPSTASH_REDIS_TOKEN", "")

# Process-local fallback store used when Upstash is not configured.
# Maps key → (value, expires_at); an expires_at of 0 means "no expiry".
_mem_store: dict[str, tuple[Any, float]] = dict()


class CacheService:
    """
    Async Redis cache backed by the Upstash REST API.

    When UPSTASH_REDIS_URL or UPSTASH_REDIS_TOKEN is empty, every method
    operates on the module-level ``_mem_store`` dict instead, so callers
    never need to branch on whether Redis is available. The in-memory
    fallback is process-local and honours key expiry lazily (expired
    entries are purged the next time they are touched).
    """

    def __init__(self):
        self._enabled = bool(UPSTASH_REDIS_URL and UPSTASH_REDIS_TOKEN)
        self._base_url = UPSTASH_REDIS_URL.rstrip("/") if self._enabled else ""
        self._headers = (
            {"Authorization": f"Bearer {UPSTASH_REDIS_TOKEN}"}
            if self._enabled
            else {}
        )
        if not self._enabled:
            print("[Cache] Upstash not configured — using in-memory fallback")

    # ── In-memory fallback helpers ────────────────────────────────────
    @staticmethod
    def _mem_entry(key: str) -> Optional[tuple[Any, float]]:
        """Return the live ``(value, expires_at)`` entry for key, or None.

        An entry whose deadline has passed is removed from ``_mem_store``
        and reported as missing, mirroring how Redis hides expired keys.
        """
        entry = _mem_store.get(key)
        if entry is None:
            return None
        expires_at = entry[1]
        if expires_at and time.time() > expires_at:
            _mem_store.pop(key, None)
            return None
        return entry

    @classmethod
    def _mem_list(cls, key: str) -> tuple[list, float]:
        """Return ``(decoded_list, expires_at)`` for an in-memory list key.

        A missing or expired key yields an empty list with no expiry.
        """
        entry = cls._mem_entry(key)
        if entry is None:
            return [], 0
        raw, expires_at = entry
        return json.loads(raw), expires_at

    # ── Internal REST call ────────────────────────────────────────────
    async def _cmd(self, *args) -> Any:
        """Execute a Redis command via the Upstash REST API.

        The command is sent as a JSON array in a POST body rather than as
        URL path segments, so keys and values containing ``/``, ``?``,
        ``#`` or spaces are transmitted verbatim instead of corrupting
        the request URL.

        Returns:
            The ``result`` field of the Upstash JSON response.

        Raises:
            RuntimeError: if the response carries an ``error`` field.
        """
        payload = [str(a) for a in args]
        async with aiohttp.ClientSession() as session:
            async with session.post(
                self._base_url, headers=self._headers, json=payload
            ) as resp:
                data = await resp.json()
                if "error" in data:
                    raise RuntimeError(f"Upstash error: {data['error']}")
                return data.get("result")

    # ── Public API ────────────────────────────────────────────────────
    async def get(self, key: str) -> Optional[str]:
        """Return the value stored under key, or None if absent/expired."""
        if not self._enabled:
            entry = self._mem_entry(key)
            return None if entry is None else entry[0]

        result = await self._cmd("GET", key)
        return result  # str or None

    async def set(self, key: str, value: Any, ttl: int = 0) -> bool:
        """
        Store value under key. If ttl > 0, key expires after that many seconds.
        Value is JSON-serialised if not already a str.
        """
        if not isinstance(value, str):
            value = json.dumps(value)

        if not self._enabled:
            exp = time.time() + ttl if ttl else 0
            _mem_store[key] = (value, exp)
            return True

        if ttl:
            await self._cmd("SET", key, value, "EX", ttl)
        else:
            await self._cmd("SET", key, value)
        return True

    async def get_json(self, key: str) -> Optional[Any]:
        """Return the JSON-decoded value for key.

        Returns None when the key is missing, and the raw stored value when
        it is not valid JSON.
        """
        raw = await self.get(key)
        if raw is None:
            return None
        try:
            return json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return raw

    async def set_json(self, key: str, value: Any, ttl: int = 0) -> bool:
        """JSON-encode value (including plain strings) and store it under key."""
        return await self.set(key, json.dumps(value), ttl=ttl)

    async def delete(self, key: str) -> bool:
        """Remove key if present. Always returns True."""
        if not self._enabled:
            _mem_store.pop(key, None)
            return True
        await self._cmd("DEL", key)
        return True

    async def exists(self, key: str) -> bool:
        """Return True if key currently holds a (non-expired) value."""
        if not self._enabled:
            return self._mem_entry(key) is not None
        result = await self._cmd("EXISTS", key)
        return bool(result)

    async def incr(self, key: str) -> int:
        """Increment the integer stored at key by one and return the new value.

        A missing or expired key starts from 0 with no expiry, so the first
        increment returns 1. A live key keeps its existing expiry.
        """
        if not self._enabled:
            entry = self._mem_entry(key)
            current, expires_at = entry if entry is not None else ("0", 0)
            new_val = int(current) + 1
            _mem_store[key] = (str(new_val), expires_at)
            return new_val
        result = await self._cmd("INCR", key)
        return int(result)

    async def expire(self, key: str, ttl: int) -> bool:
        """Make key expire ttl seconds from now. Always returns True.

        In the in-memory fallback a missing or already-expired key is left
        untouched rather than being brought back to life.
        """
        if not self._enabled:
            entry = self._mem_entry(key)
            if entry is not None:
                _mem_store[key] = (entry[0], time.time() + ttl)
            return True
        await self._cmd("EXPIRE", key, ttl)
        return True

    # ── List ops (used for job queue) ─────────────────────────────────
    async def lpush(self, key: str, *values: str) -> int:
        """Push values to the LEFT of a list (queue head).

        Values are pushed one after another, so the last argument ends up
        at the head. Returns the list length after the push; with no values
        the list is left unchanged and its current length is returned.
        """
        if not values:
            return await self.llen(key)
        if not self._enabled:
            lst, expires_at = self._mem_list(key)
            for v in values:
                lst.insert(0, v)
            # Keep whatever expiry the list already had.
            _mem_store[key] = (json.dumps(lst), expires_at)
            return len(lst)
        # A single variadic LPUSH pushes left-to-right, matching the
        # per-value loop, and replies with the new list length.
        result = await self._cmd("LPUSH", key, *values)
        return int(result or 0)

    async def rpop(self, key: str) -> Optional[str]:
        """Pop one value from the RIGHT of a list (queue tail = oldest item)."""
        if not self._enabled:
            lst, expires_at = self._mem_list(key)
            if not lst:
                return None
            val = lst.pop()
            _mem_store[key] = (json.dumps(lst), expires_at)
            return val
        return await self._cmd("RPOP", key)

    async def llen(self, key: str) -> int:
        """Return the number of items in the list at key (0 if missing)."""
        if not self._enabled:
            lst, _ = self._mem_list(key)
            return len(lst)
        result = await self._cmd("LLEN", key)
        return int(result or 0)

    async def lrange(self, key: str, start: int, stop: int) -> list[str]:
        """Return list items from start to stop, both inclusive.

        A stop of -1 means "through the last element".
        """
        if not self._enabled:
            lst, _ = self._mem_list(key)
            # Redis' stop index is inclusive; Python slices are exclusive.
            end = None if stop == -1 else stop + 1
            return lst[start:end]
        result = await self._cmd("LRANGE", key, start, stop)
        return result or []

    # ── Rate limiting helper ──────────────────────────────────────────
    async def rate_limit_check(self, key: str, max_calls: int, window_secs: int) -> bool:
        """
        Returns True if the caller is within the rate limit, False if exceeded.
        Uses a simple counter with TTL: the first call in a window creates the
        counter and arms its expiry, later calls only increment it.
        """
        count = await self.incr(key)
        if count == 1:
            await self.expire(key, window_secs)
        return count <= max_calls


# Module-level singleton — import this everywhere.
# Constructed once at import time, so whether it talks to Upstash or uses the
# in-memory fallback is decided by the UPSTASH_REDIS_URL / UPSTASH_REDIS_TOKEN
# values read when this module is first loaded.
cache = CacheService()