""" SPARKNET Cache Manager Redis-based caching for RAG queries and embeddings. """ from typing import Optional, Any, List, Dict from datetime import timedelta import hashlib import json import os from loguru import logger # Redis client (lazy loaded) _redis_client = None def get_redis_client(): """Get or create Redis client.""" global _redis_client if _redis_client is None: try: import redis redis_url = os.getenv("REDIS_URL", "redis://localhost:6379") _redis_client = redis.from_url(redis_url, decode_responses=True) # Test connection _redis_client.ping() logger.info(f"Redis connected: {redis_url}") except Exception as e: logger.warning(f"Redis not available: {e}. Using in-memory cache.") _redis_client = None return _redis_client class CacheManager: """ Unified cache manager supporting Redis and in-memory fallback. """ def __init__(self, prefix: str = "sparknet", default_ttl: int = 3600): """ Initialize cache manager. Args: prefix: Key prefix for namespacing default_ttl: Default TTL in seconds (1 hour) """ self.prefix = prefix self.default_ttl = default_ttl self._memory_cache: Dict[str, Dict[str, Any]] = {} self._redis = get_redis_client() def _make_key(self, key: str) -> str: """Create namespaced cache key.""" return f"{self.prefix}:{key}" def _hash_key(self, *args, **kwargs) -> str: """Create hash key from arguments.""" content = json.dumps({"args": args, "kwargs": kwargs}, sort_keys=True) return hashlib.md5(content.encode()).hexdigest() def get(self, key: str) -> Optional[Any]: """ Get value from cache. Args: key: Cache key Returns: Cached value or None """ full_key = self._make_key(key) # Try Redis first if self._redis: try: value = self._redis.get(full_key) if value: return json.loads(value) except Exception as e: logger.warning(f"Redis get failed: {e}") # Fallback to memory cache if full_key in self._memory_cache: entry = self._memory_cache[full_key] import time if entry.get("expires_at", 0) > time.time(): return entry.get("value") else: del self._memory_cache[full_key] return None def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool: """ Set value in cache. Args: key: Cache key value: Value to cache ttl: Time-to-live in seconds (default: self.default_ttl) Returns: True if successful """ full_key = self._make_key(key) ttl = ttl or self.default_ttl # Try Redis first if self._redis: try: self._redis.setex(full_key, ttl, json.dumps(value)) return True except Exception as e: logger.warning(f"Redis set failed: {e}") # Fallback to memory cache import time self._memory_cache[full_key] = { "value": value, "expires_at": time.time() + ttl } # Limit memory cache size if len(self._memory_cache) > 10000: self._cleanup_memory_cache() return True def delete(self, key: str) -> bool: """Delete a cache entry.""" full_key = self._make_key(key) if self._redis: try: self._redis.delete(full_key) except Exception as e: logger.warning(f"Redis delete failed: {e}") if full_key in self._memory_cache: del self._memory_cache[full_key] return True def clear_prefix(self, prefix: str) -> int: """Clear all keys matching a prefix.""" pattern = self._make_key(f"{prefix}:*") count = 0 if self._redis: try: keys = self._redis.keys(pattern) if keys: count = self._redis.delete(*keys) except Exception as e: logger.warning(f"Redis clear failed: {e}") # Clear from memory cache to_delete = [k for k in self._memory_cache if k.startswith(self._make_key(prefix))] for k in to_delete: del self._memory_cache[k] count += 1 return count def _cleanup_memory_cache(self): """Remove expired entries from memory cache.""" import time now = time.time() expired = [ k for k, v in self._memory_cache.items() if v.get("expires_at", 0) < now ] for k in expired: del self._memory_cache[k] # If still too large, remove oldest entries if len(self._memory_cache) > 10000: sorted_keys = sorted( self._memory_cache.keys(), key=lambda k: self._memory_cache[k].get("expires_at", 0) ) for k in sorted_keys[:len(sorted_keys) // 2]: del self._memory_cache[k] class QueryCache(CacheManager): """ Specialized cache for RAG queries. """ def __init__(self, ttl: int = 3600): super().__init__(prefix="sparknet:query", default_ttl=ttl) def get_query_key(self, query: str, doc_ids: Optional[List[str]] = None) -> str: """Generate cache key for a query.""" doc_str = ",".join(sorted(doc_ids)) if doc_ids else "all" content = f"{query.lower().strip()}:{doc_str}" return hashlib.md5(content.encode()).hexdigest() def get_query_response(self, query: str, doc_ids: Optional[List[str]] = None) -> Optional[Dict]: """Get cached query response.""" key = self.get_query_key(query, doc_ids) return self.get(key) def cache_query_response( self, query: str, response: Dict, doc_ids: Optional[List[str]] = None, ttl: Optional[int] = None ) -> bool: """Cache a query response.""" key = self.get_query_key(query, doc_ids) return self.set(key, response, ttl) class EmbeddingCache(CacheManager): """ Specialized cache for embeddings. """ def __init__(self, ttl: int = 86400): # 24 hours super().__init__(prefix="sparknet:embed", default_ttl=ttl) def get_embedding_key(self, text: str, model: str = "default") -> str: """Generate cache key for embedding.""" content = f"{model}:{text}" return hashlib.md5(content.encode()).hexdigest() def get_embedding(self, text: str, model: str = "default") -> Optional[List[float]]: """Get cached embedding.""" key = self.get_embedding_key(text, model) return self.get(key) def cache_embedding( self, text: str, embedding: List[float], model: str = "default" ) -> bool: """Cache an embedding.""" key = self.get_embedding_key(text, model) return self.set(key, embedding) # Global cache instances _query_cache: Optional[QueryCache] = None _embedding_cache: Optional[EmbeddingCache] = None def get_query_cache() -> QueryCache: """Get or create query cache instance.""" global _query_cache if _query_cache is None: _query_cache = QueryCache() return _query_cache def get_embedding_cache() -> EmbeddingCache: """Get or create embedding cache instance.""" global _embedding_cache if _embedding_cache is None: _embedding_cache = EmbeddingCache() return _embedding_cache # Decorator for caching function results def cached(prefix: str = "func", ttl: int = 3600): """ Decorator to cache function results. Usage: @cached(prefix="my_func", ttl=600) def expensive_function(arg1, arg2): ... """ def decorator(func): cache = CacheManager(prefix=f"sparknet:{prefix}", default_ttl=ttl) def wrapper(*args, **kwargs): # Create cache key from function name and arguments key = f"{func.__name__}:{cache._hash_key(*args, **kwargs)}" # Try to get from cache result = cache.get(key) if result is not None: return result # Execute function and cache result result = func(*args, **kwargs) cache.set(key, result) return result return wrapper return decorator