"""
Caching utilities for SPARKNET.

Provides TTL-based caching for LLM responses and LRU caching for text
embeddings to reduce API calls, recomputation, and latency.
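
Example (using the module-level singleton accessor defined below):
    cache = get_llm_cache()
    cached = cache.get(prompt, model)
    if cached is None:
        ...  # call the LLM, then cache.set(prompt, model, response)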
"""

import asyncio
import hashlib
import json
from typing import Any, Optional, Dict, Callable
from functools import wraps

from cachetools import TTLCache, LRUCache
from loguru import logger


class LLMResponseCache:
    """
    Cache for LLM responses to reduce API calls and latency.

    Features:
    - TTL-based expiration
    - LRU eviction policy
    - Content-based hashing
    - Statistics tracking

    Example:
        cache = LLMResponseCache(maxsize=1000, ttl=3600)

        # Check cache
        cached = cache.get(prompt, model)
        if cached:
            return cached

        # Store result
        cache.set(prompt, model, response)
    """

    def __init__(
        self,
        maxsize: int = 1000,
        ttl: int = 3600,
        enabled: bool = True,
    ):
        """
        Initialize LLM response cache.

        Args:
            maxsize: Maximum number of cached responses
            ttl: Time-to-live in seconds
            enabled: Whether caching is enabled
        """
        self.maxsize = maxsize
        self.ttl = ttl
        self.enabled = enabled
        self._cache: TTLCache = TTLCache(maxsize=maxsize, ttl=ttl)

        self._hits = 0
        self._misses = 0

        logger.info(f"Initialized LLMResponseCache (maxsize={maxsize}, ttl={ttl}s)")

    def _hash_key(self, prompt: str, model: str, **kwargs) -> str:
        """Generate cache key from prompt and parameters."""
        key_data = {
            "prompt": prompt,
            "model": model,
            **kwargs,
        }
        key_str = json.dumps(key_data, sort_keys=True)
        return hashlib.sha256(key_str.encode()).hexdigest()

    def get(self, prompt: str, model: str, **kwargs) -> Optional[str]:
        """
        Get cached response if available.

        Args:
            prompt: The prompt sent to the LLM
            model: Model identifier
            **kwargs: Additional parameters

        Returns:
            Cached response or None
        """
        if not self.enabled:
            return None

        key = self._hash_key(prompt, model, **kwargs)
        result = self._cache.get(key)

        if result is not None:
            self._hits += 1
            logger.debug(f"Cache HIT for model={model}")
        else:
            self._misses += 1

        return result

    def set(self, prompt: str, model: str, response: str, **kwargs):
        """
        Store response in cache.

        Args:
            prompt: The prompt sent to the LLM
            model: Model identifier
            response: The LLM response
            **kwargs: Additional parameters
        """
        if not self.enabled:
            return

        key = self._hash_key(prompt, model, **kwargs)
        self._cache[key] = response
        logger.debug(f"Cached response for model={model}")

    def invalidate(self, prompt: str, model: str, **kwargs):
        """Invalidate a specific cache entry."""
        key = self._hash_key(prompt, model, **kwargs)
        self._cache.pop(key, None)

    def clear(self):
        """Clear all cached entries."""
        self._cache.clear()
        logger.info("LLM response cache cleared")

    @property
    def stats(self) -> Dict[str, Any]:
        """Get cache statistics."""
        total = self._hits + self._misses
        hit_rate = (self._hits / total * 100) if total > 0 else 0

        return {
            "hits": self._hits,
            "misses": self._misses,
            "total": total,
            "hit_rate": f"{hit_rate:.1f}%",
            "size": len(self._cache),
            "maxsize": self.maxsize,
            "enabled": self.enabled,
        }


class EmbeddingCache:
    """
    Cache for text embeddings to avoid recomputation.

    Uses LRU policy with configurable size.
    Embeddings are stored as lists of floats.
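
    Example (``compute_embedding`` below is a placeholder for your embedding call):
        cache = EmbeddingCache(maxsize=10000)

        embedding = cache.get(text, model)
        if embedding is None:
            embedding = compute_embedding(text, model)
            cache.set(text, model, embedding)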
    """

    def __init__(self, maxsize: int = 10000, enabled: bool = True):
        """
        Initialize embedding cache.

        Args:
            maxsize: Maximum number of cached embeddings
            enabled: Whether caching is enabled
        """
        self.maxsize = maxsize
        self.enabled = enabled
        self._cache: LRUCache = LRUCache(maxsize=maxsize)

        self._hits = 0
        self._misses = 0

        logger.info(f"Initialized EmbeddingCache (maxsize={maxsize})")

    def _hash_key(self, text: str, model: str) -> str:
        """Generate cache key from text and model."""
        key_str = f"{model}:{text}"
        return hashlib.sha256(key_str.encode()).hexdigest()

    def get(self, text: str, model: str) -> Optional[list]:
        """Get cached embedding if available."""
        if not self.enabled:
            return None

        key = self._hash_key(text, model)
        result = self._cache.get(key)

        if result is not None:
            self._hits += 1
        else:
            self._misses += 1

        return result

    def set(self, text: str, model: str, embedding: list):
        """Store embedding in cache."""
        if not self.enabled:
            return

        key = self._hash_key(text, model)
        self._cache[key] = embedding

    def get_batch(self, texts: list, model: str) -> tuple:
        """
        Get cached embeddings for a batch of texts.

        Returns:
            Tuple of (cached_results, uncached_indices): cached_results maps
            each text's index to its cached embedding, and uncached_indices
            lists the indices of texts that still need to be embedded.
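
        Example (``embed_fn`` below is a placeholder for your batch embedding call):
            cached, missing = cache.get_batch(texts, model)
            if missing:
                new_embeddings = embed_fn([texts[i] for i in missing])
                cache.set_batch([texts[i] for i in missing], model, new_embeddings)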
        """
        results = {}
        uncached = []

        for i, text in enumerate(texts):
            cached = self.get(text, model)
            if cached is not None:
                results[i] = cached
            else:
                uncached.append(i)

        return results, uncached

    def set_batch(self, texts: list, model: str, embeddings: list):
        """Store batch of embeddings."""
        for text, embedding in zip(texts, embeddings):
            self.set(text, model, embedding)

    @property
    def stats(self) -> Dict[str, Any]:
        """Get cache statistics."""
        total = self._hits + self._misses
        hit_rate = (self._hits / total * 100) if total > 0 else 0

        return {
            "hits": self._hits,
            "misses": self._misses,
            "hit_rate": f"{hit_rate:.1f}%",
            "size": len(self._cache),
            "maxsize": self.maxsize,
        }


def cached_llm_call(cache: LLMResponseCache):
    """
    Decorator for caching LLM function calls.

    Example:
        @cached_llm_call(llm_cache)
        async def generate_response(prompt: str, model: str) -> str:
            ...
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def async_wrapper(prompt: str, model: str, **kwargs):
            # Return the cached response if one exists.
            cached = cache.get(prompt, model, **kwargs)
            if cached is not None:
                return cached

            # Cache miss: call the wrapped coroutine and store the result.
            result = await func(prompt, model, **kwargs)
            cache.set(prompt, model, result, **kwargs)
            return result

        @wraps(func)
        def sync_wrapper(prompt: str, model: str, **kwargs):
            # Return the cached response if one exists.
            cached = cache.get(prompt, model, **kwargs)
            if cached is not None:
                return cached

            # Cache miss: call the wrapped function and store the result.
            result = func(prompt, model, **kwargs)
            cache.set(prompt, model, result, **kwargs)
            return result

        # Return the wrapper that matches the decorated function.
        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        return sync_wrapper

    return decorator


# Module-level singleton caches, created lazily on first use.
_llm_cache: Optional[LLMResponseCache] = None
_embedding_cache: Optional[EmbeddingCache] = None


def get_llm_cache() -> LLMResponseCache:
    """Get or create the global LLM response cache."""
    global _llm_cache
    if _llm_cache is None:
        _llm_cache = LLMResponseCache()
    return _llm_cache


def get_embedding_cache() -> EmbeddingCache:
    """Get or create the global embedding cache."""
    global _embedding_cache
    if _embedding_cache is None:
        _embedding_cache = EmbeddingCache()
    return _embedding_cache