Spaces:
Sleeping
Sleeping
"""
Semantic cache that stores query/answer pairs and retrieves answers for similar queries using embeddings.
Goes beyond exact-match caching: a query can also hit an entry whose query is semantically similar.
"""
| import numpy as np | |
| from typing import List, Dict, Any, Optional, Tuple | |
| import sqlite3 | |
| import hashlib | |
| import json | |
| import time | |
| from datetime import datetime, timedelta | |
| from pathlib import Path | |
| import faiss | |
| import logging | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| from app.hyper_config import config | |
| from app.ultra_fast_embeddings import get_embedder | |
| # Module-level logger, named after this module, used for all cache diagnostics below. | |
| logger = logging.getLogger(__name__) | |
class CacheStrategy(str, Enum):
    """Lookup modes a cache can be configured with.

    Members subclass ``str``, so each one compares equal to its plain
    string value (for example ``CacheStrategy.EXACT == "exact"``).
    """

    # Match on the exact query only.
    EXACT = "exact"
    # Match on semantic similarity.
    SEMANTIC = "semantic"
    # Use both exact and semantic matching.
    HYBRID = "hybrid"
@dataclass(eq=False)
class CacheEntry:
    """In-memory record describing one cached query/answer pair.

    The field names mirror the columns of the ``cache_entries`` table.

    ``eq=False`` keeps identity-based equality: a generated field-wise
    ``__eq__`` would have to compare ``query_embedding`` arrays, which raises
    for multi-element NumPy arrays.
    """

    # Original query text.
    query: str
    # Hash of the query text, used for exact-match lookup.
    query_hash: str
    # Embedding vector of the query.
    query_embedding: np.ndarray
    # Cached answer text.
    answer: str
    # Chunks that were used to produce the answer.
    chunks_used: List[str]
    # Arbitrary additional metadata.
    metadata: Dict[str, Any]
    # When the entry was first stored.
    created_at: datetime
    # When the entry was last read or refreshed.
    accessed_at: datetime
    # Number of times the entry has been accessed.
    access_count: int
    # Time to live, in seconds.
    ttl_seconds: int
class SemanticCache:
    """
    SQLite-backed answer cache with exact and embedding-based lookup.

    Features:
    - Exact match caching (hash of the query text)
    - Semantic similarity caching (FAISS flat L2 index over query embeddings)
    - Sliding TTL: an entry expires ``ttl_seconds`` after its last access
    - Eviction of the least recently accessed rows above ``max_cache_size``
    - Hit/miss counters exposed through ``get_stats``

    Position ``i`` of ``entry_ids`` always describes FAISS vector ``i``.
    Vectors are never removed from the flat index one by one; a removed entry
    leaves a ``None`` tombstone in ``entry_ids`` until the index is rebuilt.
    """

    def __init__(
        self,
        cache_dir: Optional[Path] = None,
        strategy: CacheStrategy = CacheStrategy.HYBRID,
        similarity_threshold: float = 0.85,
        max_cache_size: int = 10000,
        ttl_hours: int = 24
    ):
        """
        Configure the cache; no database or embedder work happens here.

        Args:
            cache_dir: Directory holding ``semantic_cache.db``; falls back to
                ``config.cache_dir``. Created (with parents) if missing.
            strategy: Lookup mode, as a ``CacheStrategy`` or its string value.
            similarity_threshold: Minimum ``1 / (1 + distance)`` score for a
                semantic hit.
            max_cache_size: Maximum number of rows kept in the database.
            ttl_hours: Default time to live for entries stored without an
                explicit ``ttl_seconds``.
        """
        # Path() accepts both str and Path, so either kind of setting works.
        self.cache_dir = Path(cache_dir or config.cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        # Accept plain strings such as "hybrid" as well as enum members.
        self.strategy = CacheStrategy(strategy)
        self.similarity_threshold = similarity_threshold
        self.max_cache_size = max_cache_size
        self.ttl_hours = ttl_hours

        # Database connection (opened by initialize()).
        self.db_path = self.cache_dir / "semantic_cache.db"
        self.conn = None

        # FAISS index for semantic search.
        self.faiss_index = None
        self.embedding_dim = 384  # Dimension the FAISS index is built with
        # entry_ids[i] is the row id of FAISS vector i, or None once removed.
        self.entry_ids: List[Optional[int]] = []

        # Embedder for semantic caching (loaded by initialize()).
        self.embedder = None

        # Performance metrics.
        self.hits = 0
        self.misses = 0
        self.semantic_hits = 0
        self.exact_hits = 0
        self._initialized = False

    def _semantic_enabled(self) -> bool:
        """Return True when the strategy includes embedding-based lookup."""
        return self.strategy in (CacheStrategy.SEMANTIC, CacheStrategy.HYBRID)

    def initialize(self):
        """Open the database and, for semantic strategies, build the FAISS index."""
        if self._initialized:
            return
        logger.info(f"🚀 Initializing SemanticCache (strategy: {self.strategy.value})")

        self._init_database()

        if self._semantic_enabled():
            self.embedder = get_embedder()
            self._init_faiss_index()
            self._load_cache_entries()

        logger.info(f"✅ SemanticCache initialized with {len(self.entry_ids)} entries")
        self._initialized = True

    def _init_database(self):
        """Open the SQLite connection and create the table and indexes if absent."""
        self.conn = sqlite3.connect(str(self.db_path))
        cursor = self.conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS cache_entries (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                query TEXT NOT NULL,
                query_hash TEXT UNIQUE NOT NULL,
                query_embedding BLOB,
                answer TEXT NOT NULL,
                chunks_used_json TEXT NOT NULL,
                metadata_json TEXT NOT NULL,
                created_at TIMESTAMP NOT NULL,
                accessed_at TIMESTAMP NOT NULL,
                access_count INTEGER DEFAULT 1,
                ttl_seconds INTEGER NOT NULL,
                embedding_hash TEXT
            )
        """)
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_query_hash ON cache_entries(query_hash)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_accessed_at ON cache_entries(accessed_at)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_embedding_hash ON cache_entries(embedding_hash)")
        self.conn.commit()

    def _init_faiss_index(self):
        """Create an empty flat L2 index and reset the position-to-row mapping."""
        self.faiss_index = faiss.IndexFlatL2(self.embedding_dim)
        self.entry_ids = []

    def _load_cache_entries(self):
        """Load stored embeddings (most recently accessed first) into the FAISS index."""
        if not self._semantic_enabled() or self.faiss_index is None:
            return

        cursor = self.conn.cursor()
        # Bound by max_cache_size so every row the cache may hold stays
        # searchable after a restart or an eviction-triggered rebuild.
        cursor.execute("""
            SELECT id, query_embedding FROM cache_entries
            WHERE query_embedding IS NOT NULL
            ORDER BY accessed_at DESC
            LIMIT ?
        """, (self.max_cache_size,))

        expected_bytes = self.embedding_dim * np.dtype(np.float32).itemsize
        skipped = 0
        for entry_id, embedding_blob in cursor.fetchall():
            if not embedding_blob:
                continue
            # A blob of another size cannot be added to an index of this dimension.
            if len(embedding_blob) != expected_bytes:
                skipped += 1
                continue
            embedding = np.frombuffer(embedding_blob, dtype=np.float32)
            self.faiss_index.add(embedding.reshape(1, -1))
            self.entry_ids.append(entry_id)

        if skipped:
            logger.warning(
                f"Skipped {skipped} cached embeddings whose size does not match "
                f"the index dimension ({self.embedding_dim})"
            )
        logger.info(f"Loaded {len(self.entry_ids)} entries into FAISS index")

    def get(self, query: str) -> Optional[Tuple[str, List[str]]]:
        """
        Get cached answer for query.

        The exact lookup runs first (EXACT/HYBRID), then the semantic lookup
        (SEMANTIC/HYBRID). Hit and miss counters are updated accordingly.

        Returns:
            Tuple of (answer, chunks_used) or None if not found
        """
        if not self._initialized:
            self.initialize()

        if self.strategy in (CacheStrategy.EXACT, CacheStrategy.HYBRID):
            result = self._get_exact(self._hash_query(query))
            if result:
                self.exact_hits += 1
                self.hits += 1
                return result

        if self._semantic_enabled():
            result = self._get_semantic(query)
            if result:
                self.semantic_hits += 1
                self.hits += 1
                return result

        self.misses += 1
        return None

    def _get_exact(self, query_hash: str) -> Optional[Tuple[str, List[str]]]:
        """Return the entry stored under ``query_hash``, deleting it if expired."""
        cursor = self.conn.cursor()
        cursor.execute("""
            SELECT answer, chunks_used_json, accessed_at, ttl_seconds
            FROM cache_entries
            WHERE query_hash = ?
            LIMIT 1
        """, (query_hash,))
        row = cursor.fetchone()
        if not row:
            return None

        answer, chunks_used_json, accessed_at_str, ttl_seconds = row

        # TTL is measured from the last access, not from creation.
        accessed_at = datetime.fromisoformat(accessed_at_str)
        if self._is_expired(accessed_at, ttl_seconds):
            self._delete_entry(query_hash)
            return None

        self._update_access_time(query_hash)
        return answer, json.loads(chunks_used_json)

    def _get_semantic(self, query: str) -> Optional[Tuple[str, List[str]]]:
        """Return the closest live entry whose similarity reaches the threshold."""
        if self.embedder is None or self.faiss_index is None or not self.entry_ids:
            return None

        query_embedding = self.embedder.embed_single(query)
        query_embedding = query_embedding.astype(np.float32).reshape(1, -1)
        if query_embedding.shape[1] != self.embedding_dim:
            # Searching with a vector of another dimension would fail inside FAISS.
            logger.warning(
                f"Query embedding has dimension {query_embedding.shape[1]}, "
                f"index expects {self.embedding_dim}; skipping semantic lookup"
            )
            return None

        distances, indices = self.faiss_index.search(query_embedding, 3)  # Top 3

        for distance, idx in zip(distances[0], indices[0]):
            position = int(idx)
            # FAISS pads missing neighbours with -1.
            if position < 0 or position >= len(self.entry_ids):
                continue
            entry_id = self.entry_ids[position]
            if entry_id is None:
                # Tombstone: the entry behind this vector was deleted.
                continue

            similarity = 1.0 / (1.0 + float(distance))  # Convert distance to similarity
            if similarity < self.similarity_threshold:
                # Neighbours come back in ascending distance order, so no
                # later candidate can pass the threshold either.
                break

            cursor = self.conn.cursor()
            cursor.execute("""
                SELECT answer, chunks_used_json, accessed_at, ttl_seconds, query
                FROM cache_entries
                WHERE id = ?
                LIMIT 1
            """, (entry_id,))
            row = cursor.fetchone()
            if not row:
                # Row vanished from the database; stop offering this vector.
                self.entry_ids[position] = None
                continue

            answer, chunks_used_json, accessed_at_str, ttl_seconds, original_query = row

            accessed_at = datetime.fromisoformat(accessed_at_str)
            if self._is_expired(accessed_at, ttl_seconds):
                self._delete_by_id(entry_id)
                continue

            self._update_access_by_id(entry_id)
            chunks_used = json.loads(chunks_used_json)
            logger.debug(f"Semantic cache hit: similarity={similarity:.3f}, "
                         f"original='{original_query[:30]}...', "
                         f"current='{query[:30]}...'")
            return answer, chunks_used

        return None

    def put(
        self,
        query: str,
        answer: str,
        chunks_used: List[str],
        metadata: Optional[Dict[str, Any]] = None,
        ttl_seconds: Optional[int] = None
    ):
        """
        Store query and answer in cache.

        If an entry for the same query already exists it is updated in place
        (answer, chunks, metadata, access time and TTL).

        Args:
            query: The user query
            answer: Generated answer
            chunks_used: List of chunks used for answer
            metadata: Additional metadata
            ttl_seconds: Time to live in seconds; ``None`` selects the
                instance default (``ttl_hours``)
        """
        if not self._initialized:
            self.initialize()

        query_hash = self._hash_query(query)
        # Only None selects the default; an explicit 0 is honoured.
        ttl = self.ttl_hours * 3600 if ttl_seconds is None else ttl_seconds

        embedding_vector = None
        embedding_blob = None
        embedding_hash = None
        if self._semantic_enabled() and self.embedder:
            embedding_result = self.embedder.embed_single(query)
            embedding_vector = embedding_result.astype(np.float32).reshape(1, -1)
            embedding_blob = embedding_vector.tobytes()
            embedding_hash = hashlib.md5(embedding_blob).hexdigest()

        chunks_used_json = json.dumps(chunks_used)
        metadata_json = json.dumps(metadata or {})
        now = datetime.now().isoformat()

        cursor = self.conn.cursor()
        try:
            cursor.execute("""
                INSERT INTO cache_entries (
                    query, query_hash, query_embedding, answer, chunks_used_json,
                    metadata_json, created_at, accessed_at, ttl_seconds, embedding_hash
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """, (
                query, query_hash, embedding_blob, answer, chunks_used_json,
                metadata_json, now, now, ttl, embedding_hash
            ))
        except sqlite3.IntegrityError:
            # query_hash is UNIQUE: the entry already exists, update it.
            self.conn.rollback()
            self._update_entry(query_hash, answer, chunks_used_json, metadata_json, now, ttl)
            return

        entry_id = cursor.lastrowid
        # Commit before touching FAISS so the index never references a row
        # that failed to persist.
        self.conn.commit()

        if embedding_vector is not None and self.faiss_index is not None:
            if embedding_vector.shape[1] == self.embedding_dim:
                self.faiss_index.add(embedding_vector)
                self.entry_ids.append(entry_id)
            else:
                logger.warning(
                    f"Embedding has dimension {embedding_vector.shape[1]}, "
                    f"index expects {self.embedding_dim}; entry is cached for "
                    f"exact lookup only"
                )

        logger.debug(f"Cached query: '{query[:50]}...'")
        self._evict_if_needed()

    def _update_entry(
        self,
        query_hash: str,
        answer: str,
        chunks_used_json: str,
        metadata_json: str,
        timestamp: str,
        ttl_seconds: int
    ):
        """Overwrite the payload of an existing entry and bump its access data."""
        cursor = self.conn.cursor()
        cursor.execute("""
            UPDATE cache_entries
            SET answer = ?, chunks_used_json = ?, metadata_json = ?,
                accessed_at = ?, ttl_seconds = ?, access_count = access_count + 1
            WHERE query_hash = ?
        """, (answer, chunks_used_json, metadata_json, timestamp, ttl_seconds, query_hash))
        self.conn.commit()

    def _update_access_time(self, query_hash: str):
        """Record an access (timestamp and counter) for the entry with this hash."""
        cursor = self.conn.cursor()
        cursor.execute("""
            UPDATE cache_entries
            SET accessed_at = ?, access_count = access_count + 1
            WHERE query_hash = ?
        """, (datetime.now().isoformat(), query_hash))
        self.conn.commit()

    def _update_access_by_id(self, entry_id: int):
        """Record an access (timestamp and counter) for the entry with this row id."""
        cursor = self.conn.cursor()
        cursor.execute("""
            UPDATE cache_entries
            SET accessed_at = ?, access_count = access_count + 1
            WHERE id = ?
        """, (datetime.now().isoformat(), entry_id))
        self.conn.commit()

    def _delete_entry(self, query_hash: str):
        """Delete cache entry by query hash, retiring its FAISS vector as well."""
        cursor = self.conn.cursor()
        cursor.execute("SELECT id FROM cache_entries WHERE query_hash = ?", (query_hash,))
        row = cursor.fetchone()
        if row:
            self._remove_from_faiss(row[0])
        cursor.execute("DELETE FROM cache_entries WHERE query_hash = ?", (query_hash,))
        self.conn.commit()

    def _delete_by_id(self, entry_id: int):
        """Delete cache entry by row id, retiring its FAISS vector as well."""
        self._remove_from_faiss(entry_id)
        cursor = self.conn.cursor()
        cursor.execute("DELETE FROM cache_entries WHERE id = ?", (entry_id,))
        self.conn.commit()

    def _remove_from_faiss(self, entry_id: int):
        """Stop returning ``entry_id`` from semantic search.

        The vector itself stays in the index until the next rebuild. The
        mapping slot is replaced by ``None`` rather than deleted: deleting it
        would shift every later position and make FAISS results point at the
        wrong rows.
        """
        try:
            position = self.entry_ids.index(entry_id)
        except ValueError:
            return
        self.entry_ids[position] = None

    def _evict_if_needed(self):
        """Evict the least recently accessed rows if the cache exceeds max size."""
        cursor = self.conn.cursor()
        cursor.execute("SELECT COUNT(*) FROM cache_entries")
        count = cursor.fetchone()[0]
        if count <= self.max_cache_size:
            return

        cursor.execute("""
            DELETE FROM cache_entries
            WHERE id IN (
                SELECT id FROM cache_entries
                ORDER BY accessed_at ASC
                LIMIT ?
            )
        """, (count - self.max_cache_size,))
        self.conn.commit()

        # Rows were removed behind the index's back; rebuild it from the table.
        if self._semantic_enabled():
            self._rebuild_faiss_index()

    def _rebuild_faiss_index(self):
        """Rebuild FAISS index (and drop all tombstones) from the database."""
        if self.faiss_index is not None:
            self.faiss_index.reset()
        self.entry_ids = []
        self._load_cache_entries()

    def _hash_query(self, query: str) -> str:
        """Return the MD5 hex digest of the query text (lookup key, not security)."""
        return hashlib.md5(query.encode()).hexdigest()

    def _is_expired(self, accessed_at: datetime, ttl_seconds: int) -> bool:
        """Return True when ``ttl_seconds`` have passed since ``accessed_at``."""
        expiry_time = accessed_at + timedelta(seconds=ttl_seconds)
        return datetime.now() > expiry_time

    def clear(self):
        """Clear all cache entries from the database and the FAISS index."""
        if not self._initialized:
            self.initialize()
        cursor = self.conn.cursor()
        cursor.execute("DELETE FROM cache_entries")
        self.conn.commit()
        if self.faiss_index is not None:
            self.faiss_index.reset()
        self.entry_ids = []
        logger.info("Cache cleared")

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics (table counts plus this instance's hit counters)."""
        if not self._initialized:
            self.initialize()
        cursor = self.conn.cursor()

        cursor.execute("SELECT COUNT(*) FROM cache_entries")
        total_entries = cursor.fetchone()[0]

        cursor.execute("SELECT SUM(access_count) FROM cache_entries")
        total_accesses = cursor.fetchone()[0] or 0

        # accessed_at is written with datetime.now().isoformat() (local time),
        # so the cutoff is computed the same way instead of using SQLite's
        # UTC-based datetime('now').
        stale_cutoff = (datetime.now() - timedelta(days=7)).isoformat()
        cursor.execute(
            "SELECT COUNT(*) FROM cache_entries WHERE accessed_at < ?",
            (stale_cutoff,),
        )
        stale_entries = cursor.fetchone()[0]

        lookups = self.hits + self.misses
        hit_rate = self.hits / lookups if lookups > 0 else 0

        return {
            "total_entries": total_entries,
            "total_accesses": total_accesses,
            "stale_entries": stale_entries,
            "hits": self.hits,
            "misses": self.misses,
            "exact_hits": self.exact_hits,
            "semantic_hits": self.semantic_hits,
            "hit_rate": hit_rate,
            "strategy": self.strategy.value,
            "similarity_threshold": self.similarity_threshold,
            # Tombstoned positions no longer count as searchable entries.
            "faiss_entries": sum(1 for entry_id in self.entry_ids if entry_id is not None)
        }

    def close(self):
        """Close the database connection; the cache re-initializes on next use."""
        # getattr: __init__ may have failed before ``conn`` was assigned.
        conn = getattr(self, "conn", None)
        if conn is not None:
            conn.close()
            self.conn = None
        self._initialized = False

    def __del__(self):
        """Cleanup."""
        self.close()
# Process-wide cache instance, created on first request.
_cache_instance = None


def get_semantic_cache() -> SemanticCache:
    """Return the shared SemanticCache, building and initializing it on first call."""
    global _cache_instance
    if _cache_instance is not None:
        return _cache_instance

    settings = {
        "strategy": CacheStrategy.HYBRID,
        "similarity_threshold": 0.85,
        "max_cache_size": 5000,
        "ttl_hours": 24,
    }
    # The global is bound before initialize() runs, as in the original flow.
    _cache_instance = SemanticCache(**settings)
    _cache_instance.initialize()
    return _cache_instance
# Manual smoke test
if __name__ == "__main__":
    import logging
    import tempfile

    logging.basicConfig(level=logging.INFO)
    print("\n🧪 Testing SemanticCache...")

    # Run against a throwaway directory: the script ends with cache.clear(),
    # which would otherwise wipe the real cache in the default location, and
    # pre-existing entries could produce misleading hits.
    with tempfile.TemporaryDirectory() as scratch_dir:
        cache = SemanticCache(
            cache_dir=Path(scratch_dir),
            strategy=CacheStrategy.HYBRID,
            similarity_threshold=0.8,
            max_cache_size=100
        )
        cache.initialize()
        try:
            # Test exact caching
            print("\n📝 Testing exact caching...")
            query1 = "What is machine learning?"
            answer1 = "Machine learning is a subset of AI that enables systems to learn from data."
            chunks1 = ["chunk1", "chunk2"]
            cache.put(query1, answer1, chunks1)
            cached = cache.get(query1)
            if cached:
                print(f" Exact cache HIT: {cached[0][:50]}...")
            else:
                print(" Exact cache MISS")

            # Test semantic caching
            print("\n📝 Testing semantic caching...")
            similar_query = "Can you explain machine learning?"
            cached = cache.get(similar_query)
            if cached:
                print(f" Semantic cache HIT: {cached[0][:50]}...")
            else:
                print(" Semantic cache MISS (might need lower threshold)")

            # Test non-similar query
            print("\n📝 Testing non-similar query...")
            different_query = "What is the capital of France?"
            cached = cache.get(different_query)
            if cached:
                print(f" Unexpected HIT: {cached[0][:50]}...")
            else:
                print(" Expected MISS")

            # Get stats
            stats = cache.get_stats()
            print("\n📊 Cache Statistics:")
            for key, value in stats.items():
                print(f" {key}: {value}")

            # Clear cache
            cache.clear()
            print("\n🧹 Cache cleared")
        finally:
            # Release the SQLite file before the temporary directory is removed.
            if cache.conn is not None:
                cache.conn.close()