File size: 8,710 Bytes
d520909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
"""
SPARKNET Cache Manager
Redis-based caching for RAG queries and embeddings.
"""

import functools
import hashlib
import json
import os
import time
from datetime import timedelta
from typing import Any, Dict, List, Optional

from loguru import logger

# Redis client (lazy loaded)
_redis_client = None


def _redact_url(url: str) -> str:
    """
    Return *url* with any embedded password replaced by ``***``.

    Used so that credentials carried in REDIS_URL (either in the userinfo
    part, ``redis://user:pw@host``, or as a ``password=`` query parameter)
    never reach the log output.
    """
    import re
    from urllib.parse import urlsplit, urlunsplit

    try:
        parts = urlsplit(url)
        password = parts.password
    except ValueError:
        return "<unparseable REDIS_URL>"

    netloc = parts.netloc
    if password is not None:
        # Split on the LAST '@' (same rule urllib uses) and keep only the username.
        userinfo, _, hostinfo = netloc.rpartition("@")
        netloc = f"{userinfo.split(':', 1)[0]}:***@{hostinfo}"
    query = re.sub(r"(?i)(^|&)(password=)[^&]*", r"\1\2***", parts.query)
    return urlunsplit(parts._replace(netloc=netloc, query=query))


def get_redis_client():
    """
    Get or create the shared Redis client.

    The URL is read from the REDIS_URL environment variable (default
    ``redis://localhost:6379``) and the connection is verified with PING.
    On any failure (missing ``redis`` package, unreachable server, ...) a
    warning is logged and None is returned so callers fall back to their
    in-memory cache. A failed attempt is not remembered: the next call
    tries to connect again.

    Returns:
        A connected Redis client (``decode_responses=True``), or None.
    """
    global _redis_client
    if _redis_client is None:
        try:
            import redis
            redis_url = os.getenv("REDIS_URL", "redis://localhost:6379")
            _redis_client = redis.from_url(redis_url, decode_responses=True)
            # Test connection
            _redis_client.ping()
            # Never log the raw URL: it may contain a password.
            logger.info(f"Redis connected: {_redact_url(redis_url)}")
        except Exception as e:
            logger.warning(f"Redis not available: {e}. Using in-memory cache.")
            _redis_client = None
    return _redis_client


class CacheManager:
    """
    Unified cache manager supporting Redis and in-memory fallback.

    Values written to Redis are JSON-encoded. The in-memory dict is written
    only when Redis is unavailable or a Redis write fails. It is always
    consulted after a Redis miss, so fallback entries stay reachable.
    """

    # Size at which the in-memory fallback cache is pruned (overridable per instance).
    max_memory_entries: int = 10000

    # Number of keys deleted per Redis DELETE call in clear_prefix.
    _CLEAR_BATCH_SIZE = 500

    def __init__(
        self,
        prefix: str = "sparknet",
        default_ttl: int = 3600,
        max_memory_entries: int = 10000,
    ):
        """
        Initialize cache manager.

        Args:
            prefix: Key prefix for namespacing
            default_ttl: Default TTL in seconds (1 hour)
            max_memory_entries: When the in-memory fallback grows beyond this
                many entries, expired entries are dropped and, if still too
                large, the half closest to expiry is evicted.
        """
        self.prefix = prefix
        self.default_ttl = default_ttl
        self.max_memory_entries = max_memory_entries
        self._memory_cache: Dict[str, Dict[str, Any]] = {}
        self._redis = get_redis_client()

    def _make_key(self, key: str) -> str:
        """Create namespaced cache key."""
        return f"{self.prefix}:{key}"

    def _hash_key(self, *args, **kwargs) -> str:
        """
        Create hash key from arguments.

        Raises:
            TypeError/ValueError: if the arguments are not JSON-serializable.
        """
        content = json.dumps({"args": args, "kwargs": kwargs}, sort_keys=True)
        return hashlib.md5(content.encode()).hexdigest()

    def get(self, key: str) -> Optional[Any]:
        """
        Get value from cache.

        Args:
            key: Cache key

        Returns:
            Cached value or None (a stored JSON ``null`` is indistinguishable
            from a miss)
        """
        full_key = self._make_key(key)

        # Try Redis first
        if self._redis:
            try:
                raw = self._redis.get(full_key)
                if raw is not None:
                    return json.loads(raw)
            except Exception as e:
                logger.warning(f"Redis get failed: {e}")

        # Fallback to memory cache
        entry = self._memory_cache.get(full_key)
        if entry is None:
            return None
        if entry.get("expires_at", 0) > time.time():
            return entry.get("value")
        # Lazily evict the expired entry.
        del self._memory_cache[full_key]
        return None

    def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
        """
        Set value in cache.

        Args:
            key: Cache key
            value: Value to cache (must be JSON-serializable to reach Redis;
                otherwise it lands in the in-memory fallback)
            ttl: Time-to-live in seconds (default: self.default_ttl; any
                falsy value, including 0, also selects the default)

        Returns:
            True if successful
        """
        full_key = self._make_key(key)
        ttl = ttl or self.default_ttl

        # Try Redis first
        if self._redis:
            try:
                self._redis.setex(full_key, ttl, json.dumps(value))
            except Exception as e:
                logger.warning(f"Redis set failed: {e}")
            else:
                # Discard any older fallback copy so a later Redis outage
                # cannot serve a stale value for this key.
                self._memory_cache.pop(full_key, None)
                return True

        # Fallback to memory cache
        self._memory_cache[full_key] = {
            "value": value,
            "expires_at": time.time() + ttl,
        }

        # Limit memory cache size
        if len(self._memory_cache) > self.max_memory_entries:
            self._cleanup_memory_cache()

        return True

    def delete(self, key: str) -> bool:
        """Delete a cache entry from Redis (if connected) and from memory."""
        full_key = self._make_key(key)

        if self._redis:
            try:
                self._redis.delete(full_key)
            except Exception as e:
                logger.warning(f"Redis delete failed: {e}")

        self._memory_cache.pop(full_key, None)
        return True

    def clear_prefix(self, prefix: str) -> int:
        """
        Clear all keys under ``{self.prefix}:{prefix}:``.

        Redis keys are found with incremental SCAN (not KEYS, which blocks
        the server while it walks the whole keyspace) and deleted in batches.
        Memory entries are matched with the same ``prefix:`` boundary, so
        clearing "a" does not touch "ab:...".

        Returns:
            Number of entries removed from Redis and memory combined.
        """
        namespace = self._make_key(f"{prefix}:")
        count = 0

        if self._redis:
            try:
                batch: List[str] = []
                for redis_key in self._redis.scan_iter(
                    match=f"{namespace}*", count=self._CLEAR_BATCH_SIZE
                ):
                    batch.append(redis_key)
                    if len(batch) >= self._CLEAR_BATCH_SIZE:
                        count += self._redis.delete(*batch)
                        batch = []
                if batch:
                    count += self._redis.delete(*batch)
            except Exception as e:
                logger.warning(f"Redis clear failed: {e}")

        # Clear from memory cache
        stale = [k for k in self._memory_cache if k.startswith(namespace)]
        for k in stale:
            del self._memory_cache[k]
        count += len(stale)

        return count

    def _cleanup_memory_cache(self):
        """
        Shrink the in-memory fallback cache.

        Drops expired entries first; if the cache is still above
        ``max_memory_entries``, evicts the half with the earliest expiry.
        """
        now = time.time()
        expired = [
            k for k, entry in self._memory_cache.items()
            if entry.get("expires_at", 0) < now
        ]
        for k in expired:
            del self._memory_cache[k]

        # If still too large, remove the entries closest to expiry
        if len(self._memory_cache) > self.max_memory_entries:
            by_expiry = sorted(
                self._memory_cache,
                key=lambda k: self._memory_cache[k].get("expires_at", 0),
            )
            for k in by_expiry[: len(by_expiry) // 2]:
                del self._memory_cache[k]


class QueryCache(CacheManager):
    """
    Cache for RAG query responses.

    A key is derived from the query text (lower-cased, surrounding whitespace
    stripped) together with the sorted list of document IDs the query was
    scoped to; no/empty ``doc_ids`` maps to the scope ``"all"``.
    """

    def __init__(self, ttl: int = 3600):
        super().__init__(prefix="sparknet:query", default_ttl=ttl)

    def get_query_key(self, query: str, doc_ids: Optional[List[str]] = None) -> str:
        """Return the MD5 hex digest identifying *query* within *doc_ids*."""
        scope = "all"
        if doc_ids:
            scope = ",".join(sorted(doc_ids))
        normalized_query = query.lower().strip()
        return hashlib.md5(f"{normalized_query}:{scope}".encode()).hexdigest()

    def get_query_response(self, query: str, doc_ids: Optional[List[str]] = None) -> Optional[Dict]:
        """Look up a previously cached response; None on a miss."""
        return self.get(self.get_query_key(query, doc_ids))

    def cache_query_response(
        self,
        query: str,
        response: Dict,
        doc_ids: Optional[List[str]] = None,
        ttl: Optional[int] = None
    ) -> bool:
        """Store *response* for *query* (scoped to *doc_ids*) with an optional TTL."""
        return self.set(self.get_query_key(query, doc_ids), response, ttl)


class EmbeddingCache(CacheManager):
    """
    Specialized cache for embeddings.

    Entries are keyed by the (model, text) pair and default to a 24-hour TTL.
    """

    def __init__(self, ttl: int = 86400):  # 24 hours
        super().__init__(prefix="sparknet:embed", default_ttl=ttl)

    def get_embedding_key(self, text: str, model: str = "default") -> str:
        """
        Generate cache key for embedding.

        The (model, text) pair is JSON-encoded before hashing. A plain
        ``f"{model}:{text}"`` join is ambiguous when the model name contains
        a colon: ("m:latest", "x") and ("m", "latest:x") would produce the
        same input and therefore share a key.

        Returns:
            MD5 hex digest of the encoded pair.
        """
        content = json.dumps([model, text])
        return hashlib.md5(content.encode()).hexdigest()

    def get_embedding(self, text: str, model: str = "default") -> Optional[List[float]]:
        """Get cached embedding, or None on a miss."""
        key = self.get_embedding_key(text, model)
        return self.get(key)

    def cache_embedding(
        self,
        text: str,
        embedding: List[float],
        model: str = "default",
        ttl: Optional[int] = None,
    ) -> bool:
        """
        Cache an embedding.

        Args:
            text: Text that was embedded
            embedding: Embedding vector to store
            model: Model identifier (part of the cache key)
            ttl: Optional TTL in seconds; defaults to the cache's default_ttl

        Returns:
            The result of ``set`` (True).
        """
        key = self.get_embedding_key(text, model)
        return self.set(key, embedding, ttl)


# Global cache instances (created lazily by the accessors below)
_query_cache: Optional[QueryCache] = None
_embedding_cache: Optional[EmbeddingCache] = None


def get_query_cache() -> QueryCache:
    """Return the module-level QueryCache, constructing it on first use."""
    global _query_cache
    if _query_cache is not None:
        return _query_cache
    _query_cache = QueryCache()
    return _query_cache


def get_embedding_cache() -> EmbeddingCache:
    """Return the module-level EmbeddingCache, constructing it on first use."""
    global _embedding_cache
    if _embedding_cache is not None:
        return _embedding_cache
    _embedding_cache = EmbeddingCache()
    return _embedding_cache


# Decorator for caching function results
def cached(prefix: str = "func", ttl: int = 3600):
    """
    Decorator to cache function results.

    Results are stored through a CacheManager namespaced as
    ``sparknet:{prefix}``. The key combines the function's module and
    qualified name with a hash of its arguments, so same-named functions in
    different modules or classes do not share entries.

    Caveats:
        * ``None`` results are never served from cache (a ``None`` lookup is
          indistinguishable from a miss), so such calls always re-execute.
        * When Redis is active, results round-trip through JSON on cache
          hits (e.g. tuples come back as lists).
        * Calls whose arguments cannot be JSON-encoded (e.g. an arbitrary
          object such as a method's ``self``) bypass the cache and simply
          run the function.

    Usage:
        @cached(prefix="my_func", ttl=600)
        def expensive_function(arg1, arg2):
            ...
    """
    def decorator(func):
        cache = CacheManager(prefix=f"sparknet:{prefix}", default_ttl=ttl)
        func_id = (
            f"{getattr(func, '__module__', '')}."
            f"{getattr(func, '__qualname__', func.__name__)}"
        )

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Create cache key from function identity and arguments
            try:
                key = f"{func_id}:{cache._hash_key(*args, **kwargs)}"
            except (TypeError, ValueError):
                # Arguments are not JSON-serializable: run uncached rather than fail.
                return func(*args, **kwargs)

            # Try to get from cache
            result = cache.get(key)
            if result is not None:
                return result

            # Execute function and cache result
            result = func(*args, **kwargs)
            cache.set(key, result)
            return result

        return wrapper
    return decorator