# Source: SPARKNET / src/utils/cache_manager.py
# Author: MHamdan — "Initial commit: SPARKNET framework" (commit d520909)
"""
SPARKNET Cache Manager
Redis-based caching for RAG queries and embeddings.
"""
import functools
import hashlib
import json
import os
import time
from datetime import timedelta
from typing import Optional, Any, List, Dict

from loguru import logger
# Redis client (lazy loaded)
_redis_client = None
def get_redis_client():
"""Get or create Redis client."""
global _redis_client
if _redis_client is None:
try:
import redis
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379")
_redis_client = redis.from_url(redis_url, decode_responses=True)
# Test connection
_redis_client.ping()
logger.info(f"Redis connected: {redis_url}")
except Exception as e:
logger.warning(f"Redis not available: {e}. Using in-memory cache.")
_redis_client = None
return _redis_client
class CacheManager:
"""
Unified cache manager supporting Redis and in-memory fallback.
"""
def __init__(self, prefix: str = "sparknet", default_ttl: int = 3600):
"""
Initialize cache manager.
Args:
prefix: Key prefix for namespacing
default_ttl: Default TTL in seconds (1 hour)
"""
self.prefix = prefix
self.default_ttl = default_ttl
self._memory_cache: Dict[str, Dict[str, Any]] = {}
self._redis = get_redis_client()
def _make_key(self, key: str) -> str:
"""Create namespaced cache key."""
return f"{self.prefix}:{key}"
def _hash_key(self, *args, **kwargs) -> str:
"""Create hash key from arguments."""
content = json.dumps({"args": args, "kwargs": kwargs}, sort_keys=True)
return hashlib.md5(content.encode()).hexdigest()
def get(self, key: str) -> Optional[Any]:
"""
Get value from cache.
Args:
key: Cache key
Returns:
Cached value or None
"""
full_key = self._make_key(key)
# Try Redis first
if self._redis:
try:
value = self._redis.get(full_key)
if value:
return json.loads(value)
except Exception as e:
logger.warning(f"Redis get failed: {e}")
# Fallback to memory cache
if full_key in self._memory_cache:
entry = self._memory_cache[full_key]
import time
if entry.get("expires_at", 0) > time.time():
return entry.get("value")
else:
del self._memory_cache[full_key]
return None
def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
"""
Set value in cache.
Args:
key: Cache key
value: Value to cache
ttl: Time-to-live in seconds (default: self.default_ttl)
Returns:
True if successful
"""
full_key = self._make_key(key)
ttl = ttl or self.default_ttl
# Try Redis first
if self._redis:
try:
self._redis.setex(full_key, ttl, json.dumps(value))
return True
except Exception as e:
logger.warning(f"Redis set failed: {e}")
# Fallback to memory cache
import time
self._memory_cache[full_key] = {
"value": value,
"expires_at": time.time() + ttl
}
# Limit memory cache size
if len(self._memory_cache) > 10000:
self._cleanup_memory_cache()
return True
def delete(self, key: str) -> bool:
"""Delete a cache entry."""
full_key = self._make_key(key)
if self._redis:
try:
self._redis.delete(full_key)
except Exception as e:
logger.warning(f"Redis delete failed: {e}")
if full_key in self._memory_cache:
del self._memory_cache[full_key]
return True
def clear_prefix(self, prefix: str) -> int:
"""Clear all keys matching a prefix."""
pattern = self._make_key(f"{prefix}:*")
count = 0
if self._redis:
try:
keys = self._redis.keys(pattern)
if keys:
count = self._redis.delete(*keys)
except Exception as e:
logger.warning(f"Redis clear failed: {e}")
# Clear from memory cache
to_delete = [k for k in self._memory_cache if k.startswith(self._make_key(prefix))]
for k in to_delete:
del self._memory_cache[k]
count += 1
return count
def _cleanup_memory_cache(self):
"""Remove expired entries from memory cache."""
import time
now = time.time()
expired = [
k for k, v in self._memory_cache.items()
if v.get("expires_at", 0) < now
]
for k in expired:
del self._memory_cache[k]
# If still too large, remove oldest entries
if len(self._memory_cache) > 10000:
sorted_keys = sorted(
self._memory_cache.keys(),
key=lambda k: self._memory_cache[k].get("expires_at", 0)
)
for k in sorted_keys[:len(sorted_keys) // 2]:
del self._memory_cache[k]
class QueryCache(CacheManager):
"""
Specialized cache for RAG queries.
"""
def __init__(self, ttl: int = 3600):
super().__init__(prefix="sparknet:query", default_ttl=ttl)
def get_query_key(self, query: str, doc_ids: Optional[List[str]] = None) -> str:
"""Generate cache key for a query."""
doc_str = ",".join(sorted(doc_ids)) if doc_ids else "all"
content = f"{query.lower().strip()}:{doc_str}"
return hashlib.md5(content.encode()).hexdigest()
def get_query_response(self, query: str, doc_ids: Optional[List[str]] = None) -> Optional[Dict]:
"""Get cached query response."""
key = self.get_query_key(query, doc_ids)
return self.get(key)
def cache_query_response(
self,
query: str,
response: Dict,
doc_ids: Optional[List[str]] = None,
ttl: Optional[int] = None
) -> bool:
"""Cache a query response."""
key = self.get_query_key(query, doc_ids)
return self.set(key, response, ttl)
class EmbeddingCache(CacheManager):
"""
Specialized cache for embeddings.
"""
def __init__(self, ttl: int = 86400): # 24 hours
super().__init__(prefix="sparknet:embed", default_ttl=ttl)
def get_embedding_key(self, text: str, model: str = "default") -> str:
"""Generate cache key for embedding."""
content = f"{model}:{text}"
return hashlib.md5(content.encode()).hexdigest()
def get_embedding(self, text: str, model: str = "default") -> Optional[List[float]]:
"""Get cached embedding."""
key = self.get_embedding_key(text, model)
return self.get(key)
def cache_embedding(
self,
text: str,
embedding: List[float],
model: str = "default"
) -> bool:
"""Cache an embedding."""
key = self.get_embedding_key(text, model)
return self.set(key, embedding)
# Global cache instances
_query_cache: Optional[QueryCache] = None
_embedding_cache: Optional[EmbeddingCache] = None
def get_query_cache() -> QueryCache:
"""Get or create query cache instance."""
global _query_cache
if _query_cache is None:
_query_cache = QueryCache()
return _query_cache
def get_embedding_cache() -> EmbeddingCache:
"""Get or create embedding cache instance."""
global _embedding_cache
if _embedding_cache is None:
_embedding_cache = EmbeddingCache()
return _embedding_cache
# Decorator for caching function results
def cached(prefix: str = "func", ttl: int = 3600):
"""
Decorator to cache function results.
Usage:
@cached(prefix="my_func", ttl=600)
def expensive_function(arg1, arg2):
...
"""
def decorator(func):
cache = CacheManager(prefix=f"sparknet:{prefix}", default_ttl=ttl)
def wrapper(*args, **kwargs):
# Create cache key from function name and arguments
key = f"{func.__name__}:{cache._hash_key(*args, **kwargs)}"
# Try to get from cache
result = cache.get(key)
if result is not None:
return result
# Execute function and cache result
result = func(*args, **kwargs)
cache.set(key, result)
return result
return wrapper
return decorator