|
|
""" |
|
|
SPARKNET Cache Manager |
|
|
Redis-based caching for RAG queries and embeddings. |
|
|
""" |
|
|
|
|
|
from typing import Optional, Any, List, Dict |
|
|
from datetime import timedelta |
|
|
import hashlib |
|
|
import json |
|
|
import os |
|
|
from loguru import logger |
|
|
|
|
|
|
|
|
_redis_client = None |
|
|
|
|
|
|
|
|
def get_redis_client(): |
|
|
"""Get or create Redis client.""" |
|
|
global _redis_client |
|
|
if _redis_client is None: |
|
|
try: |
|
|
import redis |
|
|
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379") |
|
|
_redis_client = redis.from_url(redis_url, decode_responses=True) |
|
|
|
|
|
_redis_client.ping() |
|
|
logger.info(f"Redis connected: {redis_url}") |
|
|
except Exception as e: |
|
|
logger.warning(f"Redis not available: {e}. Using in-memory cache.") |
|
|
_redis_client = None |
|
|
return _redis_client |
|
|
|
|
|
|
|
|
class CacheManager: |
|
|
""" |
|
|
Unified cache manager supporting Redis and in-memory fallback. |
|
|
""" |
|
|
|
|
|
def __init__(self, prefix: str = "sparknet", default_ttl: int = 3600): |
|
|
""" |
|
|
Initialize cache manager. |
|
|
|
|
|
Args: |
|
|
prefix: Key prefix for namespacing |
|
|
default_ttl: Default TTL in seconds (1 hour) |
|
|
""" |
|
|
self.prefix = prefix |
|
|
self.default_ttl = default_ttl |
|
|
self._memory_cache: Dict[str, Dict[str, Any]] = {} |
|
|
self._redis = get_redis_client() |
|
|
|
|
|
def _make_key(self, key: str) -> str: |
|
|
"""Create namespaced cache key.""" |
|
|
return f"{self.prefix}:{key}" |
|
|
|
|
|
def _hash_key(self, *args, **kwargs) -> str: |
|
|
"""Create hash key from arguments.""" |
|
|
content = json.dumps({"args": args, "kwargs": kwargs}, sort_keys=True) |
|
|
return hashlib.md5(content.encode()).hexdigest() |
|
|
|
|
|
def get(self, key: str) -> Optional[Any]: |
|
|
""" |
|
|
Get value from cache. |
|
|
|
|
|
Args: |
|
|
key: Cache key |
|
|
|
|
|
Returns: |
|
|
Cached value or None |
|
|
""" |
|
|
full_key = self._make_key(key) |
|
|
|
|
|
|
|
|
if self._redis: |
|
|
try: |
|
|
value = self._redis.get(full_key) |
|
|
if value: |
|
|
return json.loads(value) |
|
|
except Exception as e: |
|
|
logger.warning(f"Redis get failed: {e}") |
|
|
|
|
|
|
|
|
if full_key in self._memory_cache: |
|
|
entry = self._memory_cache[full_key] |
|
|
import time |
|
|
if entry.get("expires_at", 0) > time.time(): |
|
|
return entry.get("value") |
|
|
else: |
|
|
del self._memory_cache[full_key] |
|
|
|
|
|
return None |
|
|
|
|
|
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool: |
|
|
""" |
|
|
Set value in cache. |
|
|
|
|
|
Args: |
|
|
key: Cache key |
|
|
value: Value to cache |
|
|
ttl: Time-to-live in seconds (default: self.default_ttl) |
|
|
|
|
|
Returns: |
|
|
True if successful |
|
|
""" |
|
|
full_key = self._make_key(key) |
|
|
ttl = ttl or self.default_ttl |
|
|
|
|
|
|
|
|
if self._redis: |
|
|
try: |
|
|
self._redis.setex(full_key, ttl, json.dumps(value)) |
|
|
return True |
|
|
except Exception as e: |
|
|
logger.warning(f"Redis set failed: {e}") |
|
|
|
|
|
|
|
|
import time |
|
|
self._memory_cache[full_key] = { |
|
|
"value": value, |
|
|
"expires_at": time.time() + ttl |
|
|
} |
|
|
|
|
|
|
|
|
if len(self._memory_cache) > 10000: |
|
|
self._cleanup_memory_cache() |
|
|
|
|
|
return True |
|
|
|
|
|
def delete(self, key: str) -> bool: |
|
|
"""Delete a cache entry.""" |
|
|
full_key = self._make_key(key) |
|
|
|
|
|
if self._redis: |
|
|
try: |
|
|
self._redis.delete(full_key) |
|
|
except Exception as e: |
|
|
logger.warning(f"Redis delete failed: {e}") |
|
|
|
|
|
if full_key in self._memory_cache: |
|
|
del self._memory_cache[full_key] |
|
|
|
|
|
return True |
|
|
|
|
|
def clear_prefix(self, prefix: str) -> int: |
|
|
"""Clear all keys matching a prefix.""" |
|
|
pattern = self._make_key(f"{prefix}:*") |
|
|
count = 0 |
|
|
|
|
|
if self._redis: |
|
|
try: |
|
|
keys = self._redis.keys(pattern) |
|
|
if keys: |
|
|
count = self._redis.delete(*keys) |
|
|
except Exception as e: |
|
|
logger.warning(f"Redis clear failed: {e}") |
|
|
|
|
|
|
|
|
to_delete = [k for k in self._memory_cache if k.startswith(self._make_key(prefix))] |
|
|
for k in to_delete: |
|
|
del self._memory_cache[k] |
|
|
count += 1 |
|
|
|
|
|
return count |
|
|
|
|
|
def _cleanup_memory_cache(self): |
|
|
"""Remove expired entries from memory cache.""" |
|
|
import time |
|
|
now = time.time() |
|
|
expired = [ |
|
|
k for k, v in self._memory_cache.items() |
|
|
if v.get("expires_at", 0) < now |
|
|
] |
|
|
for k in expired: |
|
|
del self._memory_cache[k] |
|
|
|
|
|
|
|
|
if len(self._memory_cache) > 10000: |
|
|
sorted_keys = sorted( |
|
|
self._memory_cache.keys(), |
|
|
key=lambda k: self._memory_cache[k].get("expires_at", 0) |
|
|
) |
|
|
for k in sorted_keys[:len(sorted_keys) // 2]: |
|
|
del self._memory_cache[k] |
|
|
|
|
|
|
|
|
class QueryCache(CacheManager): |
|
|
""" |
|
|
Specialized cache for RAG queries. |
|
|
""" |
|
|
|
|
|
def __init__(self, ttl: int = 3600): |
|
|
super().__init__(prefix="sparknet:query", default_ttl=ttl) |
|
|
|
|
|
def get_query_key(self, query: str, doc_ids: Optional[List[str]] = None) -> str: |
|
|
"""Generate cache key for a query.""" |
|
|
doc_str = ",".join(sorted(doc_ids)) if doc_ids else "all" |
|
|
content = f"{query.lower().strip()}:{doc_str}" |
|
|
return hashlib.md5(content.encode()).hexdigest() |
|
|
|
|
|
def get_query_response(self, query: str, doc_ids: Optional[List[str]] = None) -> Optional[Dict]: |
|
|
"""Get cached query response.""" |
|
|
key = self.get_query_key(query, doc_ids) |
|
|
return self.get(key) |
|
|
|
|
|
def cache_query_response( |
|
|
self, |
|
|
query: str, |
|
|
response: Dict, |
|
|
doc_ids: Optional[List[str]] = None, |
|
|
ttl: Optional[int] = None |
|
|
) -> bool: |
|
|
"""Cache a query response.""" |
|
|
key = self.get_query_key(query, doc_ids) |
|
|
return self.set(key, response, ttl) |
|
|
|
|
|
|
|
|
class EmbeddingCache(CacheManager): |
|
|
""" |
|
|
Specialized cache for embeddings. |
|
|
""" |
|
|
|
|
|
def __init__(self, ttl: int = 86400): |
|
|
super().__init__(prefix="sparknet:embed", default_ttl=ttl) |
|
|
|
|
|
def get_embedding_key(self, text: str, model: str = "default") -> str: |
|
|
"""Generate cache key for embedding.""" |
|
|
content = f"{model}:{text}" |
|
|
return hashlib.md5(content.encode()).hexdigest() |
|
|
|
|
|
def get_embedding(self, text: str, model: str = "default") -> Optional[List[float]]: |
|
|
"""Get cached embedding.""" |
|
|
key = self.get_embedding_key(text, model) |
|
|
return self.get(key) |
|
|
|
|
|
def cache_embedding( |
|
|
self, |
|
|
text: str, |
|
|
embedding: List[float], |
|
|
model: str = "default" |
|
|
) -> bool: |
|
|
"""Cache an embedding.""" |
|
|
key = self.get_embedding_key(text, model) |
|
|
return self.set(key, embedding) |
|
|
|
|
|
|
|
|
|
|
|
_query_cache: Optional[QueryCache] = None |
|
|
_embedding_cache: Optional[EmbeddingCache] = None |
|
|
|
|
|
|
|
|
def get_query_cache() -> QueryCache: |
|
|
"""Get or create query cache instance.""" |
|
|
global _query_cache |
|
|
if _query_cache is None: |
|
|
_query_cache = QueryCache() |
|
|
return _query_cache |
|
|
|
|
|
|
|
|
def get_embedding_cache() -> EmbeddingCache: |
|
|
"""Get or create embedding cache instance.""" |
|
|
global _embedding_cache |
|
|
if _embedding_cache is None: |
|
|
_embedding_cache = EmbeddingCache() |
|
|
return _embedding_cache |
|
|
|
|
|
|
|
|
|
|
|
def cached(prefix: str = "func", ttl: int = 3600): |
|
|
""" |
|
|
Decorator to cache function results. |
|
|
|
|
|
Usage: |
|
|
@cached(prefix="my_func", ttl=600) |
|
|
def expensive_function(arg1, arg2): |
|
|
... |
|
|
""" |
|
|
def decorator(func): |
|
|
cache = CacheManager(prefix=f"sparknet:{prefix}", default_ttl=ttl) |
|
|
|
|
|
def wrapper(*args, **kwargs): |
|
|
|
|
|
key = f"{func.__name__}:{cache._hash_key(*args, **kwargs)}" |
|
|
|
|
|
|
|
|
result = cache.get(key) |
|
|
if result is not None: |
|
|
return result |
|
|
|
|
|
|
|
|
result = func(*args, **kwargs) |
|
|
cache.set(key, result) |
|
|
return result |
|
|
|
|
|
return wrapper |
|
|
return decorator |
|
|
|