# File size: 8,315 Bytes
# a9dc537
"""
Caching utilities for SPARKNET
Provides TTL-bounded caching for LLM responses and LRU caching for embeddings
Following FAANG best practices for performance optimization
"""
import hashlib
import json
from typing import Any, Optional, Dict, Callable
from functools import wraps
from datetime import datetime, timedelta
from cachetools import TTLCache, LRUCache
from loguru import logger
class LLMResponseCache:
    """
    Cache for LLM responses to reduce API calls and latency.

    Features:
    - TTL-based expiration
    - LRU eviction policy
    - Content-based hashing
    - Statistics tracking

    Example:
        cache = LLMResponseCache(maxsize=1000, ttl=3600)

        # Check cache
        cached = cache.get(prompt, model)
        if cached is not None:
            return cached

        # Store result
        cache.set(prompt, model, response)
    """

    def __init__(
        self,
        maxsize: int = 1000,
        ttl: int = 3600,  # 1 hour default
        enabled: bool = True,
    ):
        """
        Initialize LLM response cache.

        Args:
            maxsize: Maximum number of cached responses
            ttl: Time-to-live in seconds
            enabled: Whether caching is enabled
        """
        self.maxsize = maxsize
        self.ttl = ttl
        self.enabled = enabled
        self._cache: TTLCache = TTLCache(maxsize=maxsize, ttl=ttl)

        # Statistics (only updated by get() while the cache is enabled)
        self._hits = 0
        self._misses = 0

        logger.info(f"Initialized LLMResponseCache (maxsize={maxsize}, ttl={ttl}s)")

    def _hash_key(self, prompt: str, model: str, **kwargs) -> str:
        """
        Generate cache key from prompt and parameters.

        Args:
            prompt: The prompt sent to the LLM
            model: Model identifier
            **kwargs: Additional parameters that distinguish the request

        Returns:
            Hex SHA-256 digest of the JSON-encoded request description
        """
        key_data = {
            "prompt": prompt,
            "model": model,
            **kwargs,
        }
        # sort_keys makes the key independent of keyword order. default=repr
        # lets values that json cannot encode (bytes, sets, paths, ...) take
        # part in the key instead of raising TypeError and failing the call.
        key_str = json.dumps(key_data, sort_keys=True, default=repr)
        return hashlib.sha256(key_str.encode()).hexdigest()

    def get(self, prompt: str, model: str, **kwargs) -> Optional[str]:
        """
        Get cached response if available.

        Args:
            prompt: The prompt sent to the LLM
            model: Model identifier
            **kwargs: Additional parameters

        Returns:
            Cached response or None (always None while the cache is disabled)
        """
        if not self.enabled:
            return None

        key = self._hash_key(prompt, model, **kwargs)
        result = self._cache.get(key)

        if result is not None:
            self._hits += 1
            logger.debug(f"Cache HIT for model={model}")
        else:
            self._misses += 1

        return result

    def set(self, prompt: str, model: str, response: str, **kwargs):
        """
        Store response in cache.

        Does nothing while the cache is disabled.

        Args:
            prompt: The prompt sent to the LLM
            model: Model identifier
            response: The LLM response
            **kwargs: Additional parameters
        """
        if not self.enabled:
            return

        key = self._hash_key(prompt, model, **kwargs)
        self._cache[key] = response
        logger.debug(f"Cached response for model={model}")

    def invalidate(self, prompt: str, model: str, **kwargs):
        """Invalidate a specific cache entry; a missing entry is ignored."""
        key = self._hash_key(prompt, model, **kwargs)
        self._cache.pop(key, None)

    def clear(self):
        """Clear all cached entries (hit/miss counters are left untouched)."""
        self._cache.clear()
        logger.info("LLM response cache cleared")

    @property
    def stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dict with hit/miss counters, their total, the hit rate formatted
            as a percentage string, current size, maxsize and enabled flag
        """
        total = self._hits + self._misses
        # Guard against division by zero before any lookup has happened.
        hit_rate = (self._hits / total * 100) if total > 0 else 0

        return {
            "hits": self._hits,
            "misses": self._misses,
            "total": total,
            "hit_rate": f"{hit_rate:.1f}%",
            "size": len(self._cache),
            "maxsize": self.maxsize,
            "enabled": self.enabled,
        }
class EmbeddingCache:
    """
    Cache for text embeddings to avoid recomputation.

    Uses LRU policy with configurable size.
    Embeddings are stored as lists of floats.
    """

    def __init__(self, maxsize: int = 10000, enabled: bool = True):
        """
        Initialize embedding cache.

        Args:
            maxsize: Maximum number of cached embeddings
            enabled: Whether caching is enabled
        """
        self.maxsize = maxsize
        self.enabled = enabled
        self._cache: LRUCache = LRUCache(maxsize=maxsize)

        # Statistics (only updated by get() while the cache is enabled)
        self._hits = 0
        self._misses = 0

        logger.info(f"Initialized EmbeddingCache (maxsize={maxsize})")

    def _hash_key(self, text: str, model: str) -> str:
        """
        Generate cache key from text and model.

        Args:
            text: The text that was embedded
            model: Model identifier

        Returns:
            Hex SHA-256 digest of the JSON-encoded (model, text) pair
        """
        # A plain "model:text" join is ambiguous when the model name contains
        # a colon: ("a:b", "c") and ("a", "b:c") would share one key. JSON
        # quoting keeps the two fields separate.
        key_str = json.dumps([model, text])
        return hashlib.sha256(key_str.encode()).hexdigest()

    def get(self, text: str, model: str) -> Optional[list]:
        """
        Get cached embedding if available.

        Returns:
            Cached embedding or None (always None while the cache is disabled)
        """
        if not self.enabled:
            return None

        key = self._hash_key(text, model)
        result = self._cache.get(key)

        if result is not None:
            self._hits += 1
        else:
            self._misses += 1

        return result

    def set(self, text: str, model: str, embedding: list):
        """Store embedding in cache; does nothing while the cache is disabled."""
        if not self.enabled:
            return

        key = self._hash_key(text, model)
        self._cache[key] = embedding

    def get_batch(self, texts: list, model: str) -> tuple:
        """
        Get cached embeddings for a batch of texts.

        Args:
            texts: Texts to look up, in caller order
            model: Model identifier

        Returns:
            Tuple of (cached_results, uncached_indices), where cached_results
            maps the index of each cached text to its embedding and
            uncached_indices lists the indices that were not found
        """
        results = {}
        uncached = []

        for i, text in enumerate(texts):
            cached = self.get(text, model)
            if cached is not None:
                results[i] = cached
            else:
                uncached.append(i)

        return results, uncached

    def set_batch(self, texts: list, model: str, embeddings: list):
        """
        Store batch of embeddings.

        Texts and embeddings are paired positionally; pairing stops at the
        shorter of the two sequences.
        """
        for text, embedding in zip(texts, embeddings):
            self.set(text, model, embedding)

    @property
    def stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dict with hit/miss counters, their total, the hit rate formatted
            as a percentage string, current size, maxsize and enabled flag
        """
        total = self._hits + self._misses
        # Guard against division by zero before any lookup has happened.
        hit_rate = (self._hits / total * 100) if total > 0 else 0

        return {
            "hits": self._hits,
            "misses": self._misses,
            "total": total,
            "hit_rate": f"{hit_rate:.1f}%",
            "size": len(self._cache),
            "maxsize": self.maxsize,
            "enabled": self.enabled,
        }
def cached_llm_call(cache: LLMResponseCache):
    """
    Decorator for caching LLM function calls.

    The decorated function must accept ``(prompt, model, **kwargs)``; the
    same arguments are used to build the cache key. Both plain and
    ``async def`` functions are supported.

    Args:
        cache: Cache used to look up and store responses

    Returns:
        A decorator that wraps the target function with cache lookup/store

    Example:
        @cached_llm_call(llm_cache)
        async def generate_response(prompt: str, model: str) -> str:
            ...
    """
    # inspect.iscoroutinefunction is the supported check;
    # asyncio.iscoroutinefunction is deprecated since Python 3.14.
    import inspect

    def decorator(func: Callable) -> Callable:
        if inspect.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(prompt: str, model: str, **kwargs):
                # Check cache
                cached = cache.get(prompt, model, **kwargs)
                if cached is not None:
                    return cached

                # Call function
                result = await func(prompt, model, **kwargs)

                # Cache result. None is skipped: the hit check above treats
                # None as a miss, so storing it would only occupy a slot.
                if result is not None:
                    cache.set(prompt, model, result, **kwargs)
                return result

            return async_wrapper

        @wraps(func)
        def sync_wrapper(prompt: str, model: str, **kwargs):
            # Check cache
            cached = cache.get(prompt, model, **kwargs)
            if cached is not None:
                return cached

            # Call function
            result = func(prompt, model, **kwargs)

            # Cache result (None is skipped for the same reason as above)
            if result is not None:
                cache.set(prompt, model, result, **kwargs)
            return result

        return sync_wrapper

    return decorator
# Global cache instances.
# Both slots start as None; get_llm_cache() / get_embedding_cache() assign
# them on first use and return the same object afterwards.
_llm_cache: Optional[LLMResponseCache] = None
_embedding_cache: Optional[EmbeddingCache] = None
def get_llm_cache() -> LLMResponseCache:
    """Return the shared LLM response cache, building it on first access."""
    global _llm_cache

    # Fast path: the instance already exists.
    if _llm_cache is not None:
        return _llm_cache

    _llm_cache = LLMResponseCache()
    return _llm_cache
def get_embedding_cache() -> EmbeddingCache:
    """Return the shared embedding cache, building it on first access."""
    global _embedding_cache

    # Fast path: the instance already exists.
    if _embedding_cache is not None:
        return _embedding_cache

    _embedding_cache = EmbeddingCache()
    return _embedding_cache
|