""" Embedding Adapters for RAG Subsystem Provides: - Abstract EmbeddingAdapter interface - Ollama embeddings (local, default) - OpenAI embeddings (optional, feature-flagged) """ from abc import ABC, abstractmethod from typing import List, Optional, Union from pydantic import BaseModel, Field from loguru import logger import hashlib import json from pathlib import Path try: import httpx HTTPX_AVAILABLE = True except ImportError: HTTPX_AVAILABLE = False try: import openai OPENAI_AVAILABLE = True except ImportError: OPENAI_AVAILABLE = False class EmbeddingConfig(BaseModel): """Configuration for embedding adapters.""" # Adapter selection adapter_type: str = Field( default="ollama", description="Embedding adapter type: ollama, openai" ) # Ollama settings ollama_base_url: str = Field( default="http://localhost:11434", description="Ollama API base URL" ) ollama_model: str = Field( default="nomic-embed-text", description="Ollama embedding model (nomic-embed-text, mxbai-embed-large)" ) # OpenAI settings (feature-flagged) openai_enabled: bool = Field( default=False, description="Enable OpenAI embeddings" ) openai_model: str = Field( default="text-embedding-3-small", description="OpenAI embedding model" ) openai_api_key: Optional[str] = Field( default=None, description="OpenAI API key (or use OPENAI_API_KEY env var)" ) # Common settings batch_size: int = Field(default=32, ge=1, description="Batch size for embedding") timeout: float = Field(default=60.0, ge=1.0, description="Request timeout in seconds") # Caching enable_cache: bool = Field(default=True, description="Enable embedding cache") cache_directory: str = Field( default="./data/embedding_cache", description="Cache directory for embeddings" ) class EmbeddingAdapter(ABC): """Abstract interface for embedding adapters.""" @abstractmethod def embed_text(self, text: str) -> List[float]: """ Embed a single text. Args: text: Text to embed Returns: Embedding vector """ pass @abstractmethod def embed_batch(self, texts: List[str]) -> List[List[float]]: """ Embed multiple texts. 


class EmbeddingAdapter(ABC):
    """Abstract interface for embedding adapters."""

    @abstractmethod
    def embed_text(self, text: str) -> List[float]:
        """
        Embed a single text.

        Args:
            text: Text to embed

        Returns:
            Embedding vector
        """
        pass

    @abstractmethod
    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Embed multiple texts.

        Args:
            texts: List of texts to embed

        Returns:
            List of embedding vectors
        """
        pass

    @property
    @abstractmethod
    def embedding_dimension(self) -> int:
        """Return embedding dimension."""
        pass

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Return model name."""
        pass


class EmbeddingCache:
    """Simple file-based embedding cache."""

    def __init__(self, cache_dir: str, model_name: str):
        """Initialize cache."""
        self.cache_dir = Path(cache_dir) / model_name.replace("/", "_")
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self._memory_cache: dict = {}

    def _hash_text(self, text: str) -> str:
        """Generate cache key from text."""
        return hashlib.sha256(text.encode()).hexdigest()[:32]

    def get(self, text: str) -> Optional[List[float]]:
        """Get cached embedding."""
        key = self._hash_text(text)

        # Check memory cache first
        if key in self._memory_cache:
            return self._memory_cache[key]

        # Check file cache
        cache_file = self.cache_dir / f"{key}.json"
        if cache_file.exists():
            try:
                with open(cache_file, "r") as f:
                    embedding = json.load(f)
                self._memory_cache[key] = embedding
                return embedding
            except (OSError, json.JSONDecodeError):
                # Treat unreadable or corrupted cache files as cache misses
                pass

        return None

    def put(self, text: str, embedding: List[float]):
        """Cache embedding."""
        key = self._hash_text(text)

        # Memory cache
        self._memory_cache[key] = embedding

        # File cache
        cache_file = self.cache_dir / f"{key}.json"
        try:
            with open(cache_file, "w") as f:
                json.dump(embedding, f)
        except Exception as e:
            logger.warning(f"Failed to cache embedding: {e}")
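
# Cache usage sketch (illustrative; the adapters below wire this up automatically
# when config.enable_cache is True). Cache keys are the first 32 hex chars of the
# SHA-256 of the text; one JSON file per embedding under cache_dir/<model>/.
#
#   cache = EmbeddingCache("./data/embedding_cache", "nomic-embed-text")
#   vec = cache.get("hello world")
#   if vec is None:
#       vec = adapter.embed_text("hello world")  # any EmbeddingAdapter
#       cache.put("hello world", vec)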


class OllamaEmbedding(EmbeddingAdapter):
    """
    Ollama embedding adapter for local embeddings.

    Supported models:
    - nomic-embed-text (768 dimensions, recommended)
    - mxbai-embed-large (1024 dimensions)
    - all-minilm (384 dimensions)
    """

    # Known embedding dimensions
    MODEL_DIMENSIONS = {
        "nomic-embed-text": 768,
        "mxbai-embed-large": 1024,
        "all-minilm": 384,
        "snowflake-arctic-embed": 1024,
    }

    def __init__(self, config: Optional[EmbeddingConfig] = None):
        """Initialize Ollama embedding adapter."""
        if not HTTPX_AVAILABLE:
            raise ImportError(
                "httpx is required for Ollama. Install with: pip install httpx"
            )

        self.config = config or EmbeddingConfig()
        self._base_url = self.config.ollama_base_url.rstrip("/")
        self._model = self.config.ollama_model
        self._dimension: Optional[int] = self.MODEL_DIMENSIONS.get(self._model)

        # Initialize cache if enabled
        self._cache: Optional[EmbeddingCache] = None
        if self.config.enable_cache:
            self._cache = EmbeddingCache(self.config.cache_directory, self._model)

        logger.info(f"OllamaEmbedding initialized: {self._model}")

    def embed_text(self, text: str) -> List[float]:
        """Embed a single text."""
        # Check cache
        if self._cache:
            cached = self._cache.get(text)
            if cached is not None:
                return cached

        # Call Ollama API
        with httpx.Client(timeout=self.config.timeout) as client:
            response = client.post(
                f"{self._base_url}/api/embeddings",
                json={
                    "model": self._model,
                    "prompt": text,
                },
            )
            response.raise_for_status()
            result = response.json()

        embedding = result["embedding"]

        # Update dimension if not known
        if self._dimension is None:
            self._dimension = len(embedding)

        # Cache result
        if self._cache:
            self._cache.put(text, embedding)

        return embedding

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed multiple texts."""
        # Each text is embedded with its own request (one prompt per call to
        # embed_text); batch_size only controls how the loop is chunked.
        embeddings = []
        for i in range(0, len(texts), self.config.batch_size):
            batch = texts[i:i + self.config.batch_size]
            for text in batch:
                embedding = self.embed_text(text)
                embeddings.append(embedding)
        return embeddings

    @property
    def embedding_dimension(self) -> int:
        """Return embedding dimension."""
        if self._dimension is None:
            # Probe with a test embedding
            test_embedding = self.embed_text("test")
            self._dimension = len(test_embedding)
        return self._dimension

    @property
    def model_name(self) -> str:
        """Return model name."""
        return f"ollama/{self._model}"
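
# Usage sketch (assumes a local Ollama server at the default
# http://localhost:11434 with the embedding model already pulled,
# e.g. `ollama pull nomic-embed-text`):
#
#   adapter = OllamaEmbedding(EmbeddingConfig())
#   vec = adapter.embed_text("hello world")
#   assert len(vec) == adapter.embedding_dimension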


class OpenAIEmbedding(EmbeddingAdapter):
    """
    OpenAI embedding adapter (feature-flagged).

    Supported models:
    - text-embedding-3-small (1536 dimensions)
    - text-embedding-3-large (3072 dimensions)
    - text-embedding-ada-002 (1536 dimensions, legacy)
    """

    MODEL_DIMENSIONS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }

    def __init__(self, config: Optional[EmbeddingConfig] = None):
        """Initialize OpenAI embedding adapter."""
        if not OPENAI_AVAILABLE:
            raise ImportError(
                "openai is required. Install with: pip install openai"
            )

        self.config = config or EmbeddingConfig()

        if not self.config.openai_enabled:
            raise ValueError("OpenAI embeddings not enabled in config")

        self._model = self.config.openai_model
        self._dimension = self.MODEL_DIMENSIONS.get(self._model, 1536)

        # Initialize OpenAI client (falls back to the OPENAI_API_KEY env var)
        api_key = self.config.openai_api_key
        self._client = openai.OpenAI(api_key=api_key) if api_key else openai.OpenAI()

        # Initialize cache if enabled
        self._cache: Optional[EmbeddingCache] = None
        if self.config.enable_cache:
            self._cache = EmbeddingCache(self.config.cache_directory, self._model)

        logger.info(f"OpenAIEmbedding initialized: {self._model}")

    def embed_text(self, text: str) -> List[float]:
        """Embed a single text."""
        # Check cache
        if self._cache:
            cached = self._cache.get(text)
            if cached is not None:
                return cached

        # Call OpenAI API
        response = self._client.embeddings.create(
            model=self._model,
            input=text,
        )
        embedding = response.data[0].embedding

        # Cache result
        if self._cache:
            self._cache.put(text, embedding)

        return embedding

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed multiple texts."""
        embeddings = []

        for i in range(0, len(texts), self.config.batch_size):
            batch = texts[i:i + self.config.batch_size]

            # Check cache for batch
            to_embed = []
            cached_indices = {}

            for j, text in enumerate(batch):
                if self._cache:
                    cached = self._cache.get(text)
                    if cached is not None:
                        cached_indices[j] = cached
                        continue
                to_embed.append((j, text))

            # Embed uncached texts
            if to_embed:
                indices, texts_to_embed = zip(*to_embed)
                response = self._client.embeddings.create(
                    model=self._model,
                    input=list(texts_to_embed),
                )
                for idx, (j, text) in enumerate(to_embed):
                    embedding = response.data[idx].embedding
                    cached_indices[j] = embedding
                    if self._cache:
                        self._cache.put(text, embedding)

            # Reconstruct batch order
            for j in range(len(batch)):
                embeddings.append(cached_indices[j])

        return embeddings

    @property
    def embedding_dimension(self) -> int:
        """Return embedding dimension."""
        return self._dimension

    @property
    def model_name(self) -> str:
        """Return model name."""
        return f"openai/{self._model}"


# Factory function
_embedding_adapter: Optional[EmbeddingAdapter] = None


def get_embedding_adapter(
    config: Optional[EmbeddingConfig] = None,
) -> EmbeddingAdapter:
    """
    Get or create singleton embedding adapter.

    Args:
        config: Embedding configuration

    Returns:
        EmbeddingAdapter instance
    """
    global _embedding_adapter

    if _embedding_adapter is None:
        config = config or EmbeddingConfig()

        if config.adapter_type == "openai" and config.openai_enabled:
            _embedding_adapter = OpenAIEmbedding(config)
        else:
            # Default to Ollama
            _embedding_adapter = OllamaEmbedding(config)

    return _embedding_adapter


def reset_embedding_adapter():
    """Reset the global embedding adapter instance."""
    global _embedding_adapter
    _embedding_adapter = None
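

if __name__ == "__main__":
    # Minimal smoke test, not part of the library surface. It assumes a local
    # Ollama server at the config default (http://localhost:11434) with the
    # default embedding model pulled; swap in adapter_type="openai" and
    # openai_enabled=True to exercise the OpenAI path instead.
    demo_config = EmbeddingConfig()
    adapter = get_embedding_adapter(demo_config)

    vector = adapter.embed_text("hello world")
    logger.info(f"{adapter.model_name}: dimension={len(vector)}")

    batch = adapter.embed_batch(["first text", "second text"])
    logger.info(f"Embedded batch of {len(batch)} texts")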