# SPARKNET — src/rag/embeddings.py
# Initial commit: SPARKNET framework (d520909)
"""
Embedding Adapters for RAG Subsystem
Provides:
- Abstract EmbeddingAdapter interface
- Ollama embeddings (local, default)
- OpenAI embeddings (optional, feature-flagged)
"""
from abc import ABC, abstractmethod
from typing import List, Optional, Union
from pydantic import BaseModel, Field
from loguru import logger
import hashlib
import json
from pathlib import Path
# Optional dependency: httpx is required by the Ollama adapter (local default).
try:
    import httpx
    HTTPX_AVAILABLE = True
except ImportError:
    HTTPX_AVAILABLE = False
# Optional dependency: openai is required by the feature-flagged OpenAI adapter.
try:
    import openai
    OPENAI_AVAILABLE = True
except ImportError:
    OPENAI_AVAILABLE = False
class EmbeddingConfig(BaseModel):
    """Configuration for embedding adapters.

    Groups adapter selection (Ollama by default, OpenAI behind a feature
    flag), per-backend connection settings, shared batching/timeout knobs,
    and the optional on-disk embedding cache.
    """

    # -- Adapter selection ---------------------------------------------------
    adapter_type: str = Field(default="ollama", description="Embedding adapter type: ollama, openai")

    # -- Ollama backend (local, default) -------------------------------------
    ollama_base_url: str = Field(default="http://localhost:11434", description="Ollama API base URL")
    ollama_model: str = Field(default="nomic-embed-text", description="Ollama embedding model (nomic-embed-text, mxbai-embed-large)")

    # -- OpenAI backend (optional, feature-flagged) --------------------------
    openai_enabled: bool = Field(default=False, description="Enable OpenAI embeddings")
    openai_model: str = Field(default="text-embedding-3-small", description="OpenAI embedding model")
    openai_api_key: Optional[str] = Field(default=None, description="OpenAI API key (or use OPENAI_API_KEY env var)")

    # -- Shared request settings ---------------------------------------------
    batch_size: int = Field(default=32, ge=1, description="Batch size for embedding")
    timeout: float = Field(default=60.0, ge=1.0, description="Request timeout in seconds")

    # -- Caching -------------------------------------------------------------
    enable_cache: bool = Field(default=True, description="Enable embedding cache")
    cache_directory: str = Field(default="./data/embedding_cache", description="Cache directory for embeddings")
class EmbeddingAdapter(ABC):
    """Abstract interface shared by all embedding backends.

    Concrete adapters must provide single-text and batch embedding plus
    metadata (vector dimension and model identifier).
    """

    @abstractmethod
    def embed_text(self, text: str) -> List[float]:
        """Embed a single text and return its embedding vector."""
        ...

    @abstractmethod
    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed each text in *texts*, returning one vector per input, in order."""
        ...

    @property
    @abstractmethod
    def embedding_dimension(self) -> int:
        """Dimensionality of the vectors this adapter produces."""
        ...

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Identifier of the underlying embedding model."""
        ...
class EmbeddingCache:
    """Simple two-level (memory + file) embedding cache.

    Embeddings are keyed by a truncated SHA-256 hash of the input text and
    stored as JSON files under ``cache_dir/<model_name>/``, with an
    in-process dict in front to avoid re-reading files.
    """

    def __init__(self, cache_dir: str, model_name: str):
        """Initialize the cache directory for *model_name* under *cache_dir*."""
        # "/" in model names (e.g. "org/model") would otherwise nest directories.
        self.cache_dir = Path(cache_dir) / model_name.replace("/", "_")
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self._memory_cache: dict = {}

    def _hash_text(self, text: str) -> str:
        """Generate a stable cache key from text (truncated SHA-256 hex digest)."""
        return hashlib.sha256(text.encode()).hexdigest()[:32]

    def get(self, text: str) -> Optional[List[float]]:
        """Return the cached embedding for *text*, or None on a miss."""
        key = self._hash_text(text)
        # Check memory cache first (cheapest).
        if key in self._memory_cache:
            return self._memory_cache[key]
        # Fall back to the file cache; promote hits into the memory cache.
        cache_file = self.cache_dir / f"{key}.json"
        if cache_file.exists():
            try:
                with open(cache_file, "r") as f:
                    embedding = json.load(f)
                self._memory_cache[key] = embedding
                return embedding
            except (OSError, json.JSONDecodeError) as e:
                # A corrupt or unreadable cache file is treated as a miss
                # (previously a bare `except:` silently swallowed everything).
                logger.warning(f"Failed to read cached embedding: {e}")
        return None

    def put(self, text: str, embedding: List[float]):
        """Cache *embedding* for *text* in memory and (best-effort) on disk."""
        key = self._hash_text(text)
        # Memory cache
        self._memory_cache[key] = embedding
        # File cache is best-effort: a write failure only costs persistence.
        cache_file = self.cache_dir / f"{key}.json"
        try:
            with open(cache_file, "w") as f:
                json.dump(embedding, f)
        except Exception as e:
            logger.warning(f"Failed to cache embedding: {e}")
class OllamaEmbedding(EmbeddingAdapter):
    """Local embedding adapter backed by an Ollama server.

    Known models and their output dimensions:
      - nomic-embed-text: 768 (recommended)
      - mxbai-embed-large: 1024
      - all-minilm: 384
      - snowflake-arctic-embed: 1024
    Unknown models have their dimension discovered lazily from a response.
    """

    # Dimensions for models recognized up front; anything else is probed.
    MODEL_DIMENSIONS = {
        "nomic-embed-text": 768,
        "mxbai-embed-large": 1024,
        "all-minilm": 384,
        "snowflake-arctic-embed": 1024,
    }

    def __init__(self, config: Optional[EmbeddingConfig] = None):
        """Set up the adapter; requires the httpx package."""
        if not HTTPX_AVAILABLE:
            raise ImportError("httpx is required for Ollama. Install with: pip install httpx")
        self.config = config or EmbeddingConfig()
        self._base_url = self.config.ollama_base_url.rstrip("/")
        self._model = self.config.ollama_model
        # None until known; filled from MODEL_DIMENSIONS or the first response.
        self._dimension: Optional[int] = self.MODEL_DIMENSIONS.get(self._model)
        # Optional cache keyed by model name.
        self._cache: Optional[EmbeddingCache] = (
            EmbeddingCache(self.config.cache_directory, self._model)
            if self.config.enable_cache
            else None
        )
        logger.info(f"OllamaEmbedding initialized: {self._model}")

    def embed_text(self, text: str) -> List[float]:
        """Embed one text via the Ollama API, consulting the cache first."""
        if self._cache is not None:
            hit = self._cache.get(text)
            if hit is not None:
                return hit
        # Call Ollama's embeddings endpoint with a fresh short-lived client.
        payload = {
            "model": self._model,
            "prompt": text,
        }
        with httpx.Client(timeout=self.config.timeout) as client:
            response = client.post(f"{self._base_url}/api/embeddings", json=payload)
            response.raise_for_status()
            vector = response.json()["embedding"]
        # Record the dimension the first time we see a real vector.
        if self._dimension is None:
            self._dimension = len(vector)
        if self._cache is not None:
            self._cache.put(text, vector)
        return vector

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed texts sequentially; the endpoint accepts one prompt at a time."""
        out: List[List[float]] = []
        step = self.config.batch_size
        # Walk the input in batch-size chunks; each item is its own API call.
        for start in range(0, len(texts), step):
            out.extend(self.embed_text(t) for t in texts[start:start + step])
        return out

    @property
    def embedding_dimension(self) -> int:
        """Dimension of the embedding vectors, probing the server if unknown."""
        if self._dimension is None:
            # Embed a throwaway string just to learn the vector length.
            self._dimension = len(self.embed_text("test"))
        return self._dimension

    @property
    def model_name(self) -> str:
        """Namespaced model identifier."""
        return f"ollama/{self._model}"
class OpenAIEmbedding(EmbeddingAdapter):
    """
    OpenAI embedding adapter (feature-flagged).

    Supports models:
    - text-embedding-3-small (1536 dimensions)
    - text-embedding-3-large (3072 dimensions)
    - text-embedding-ada-002 (1536 dimensions, legacy)
    """

    # Known output dimensions; unknown models fall back to 1536.
    MODEL_DIMENSIONS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }

    def __init__(self, config: Optional[EmbeddingConfig] = None):
        """Initialize OpenAI embedding adapter.

        Raises:
            ImportError: if the openai package is not installed.
            ValueError: if ``config.openai_enabled`` is False (feature flag).
        """
        if not OPENAI_AVAILABLE:
            raise ImportError("openai is required. Install with: pip install openai")
        self.config = config or EmbeddingConfig()
        if not self.config.openai_enabled:
            raise ValueError("OpenAI embeddings not enabled in config")
        self._model = self.config.openai_model
        self._dimension = self.MODEL_DIMENSIONS.get(self._model, 1536)
        # Initialize OpenAI client; without an explicit key the SDK reads
        # the OPENAI_API_KEY environment variable.
        api_key = self.config.openai_api_key
        self._client = openai.OpenAI(api_key=api_key) if api_key else openai.OpenAI()
        # Initialize cache if enabled
        self._cache: Optional[EmbeddingCache] = None
        if self.config.enable_cache:
            self._cache = EmbeddingCache(self.config.cache_directory, self._model)
        logger.info(f"OpenAIEmbedding initialized: {self._model}")

    def embed_text(self, text: str) -> List[float]:
        """Embed a single text, consulting the cache before calling the API."""
        # Check cache
        if self._cache:
            cached = self._cache.get(text)
            if cached is not None:
                return cached
        # Call OpenAI API
        response = self._client.embeddings.create(
            model=self._model,
            input=text,
        )
        embedding = response.data[0].embedding
        # Cache result
        if self._cache:
            self._cache.put(text, embedding)
        return embedding

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed multiple texts, batching uncached items into single API calls.

        Cached texts are served locally; only misses are sent to the API,
        and the original input order is preserved in the result.
        """
        embeddings = []
        for i in range(0, len(texts), self.config.batch_size):
            batch = texts[i:i + self.config.batch_size]
            # Split the batch into cache hits and texts still to embed.
            to_embed = []  # (position-in-batch, text) pairs for cache misses
            resolved = {}  # position-in-batch -> embedding
            for j, text in enumerate(batch):
                if self._cache:
                    cached = self._cache.get(text)
                    if cached is not None:
                        resolved[j] = cached
                        continue
                to_embed.append((j, text))
            # One API call covers every miss in this batch; response order
            # matches input order per the OpenAI embeddings API.
            if to_embed:
                response = self._client.embeddings.create(
                    model=self._model,
                    input=[t for _, t in to_embed],
                )
                for idx, (j, text) in enumerate(to_embed):
                    embedding = response.data[idx].embedding
                    resolved[j] = embedding
                    if self._cache:
                        self._cache.put(text, embedding)
            # Reconstruct the batch in its original order.
            for j in range(len(batch)):
                embeddings.append(resolved[j])
        return embeddings

    @property
    def embedding_dimension(self) -> int:
        """Return embedding dimension (known statically per model)."""
        return self._dimension

    @property
    def model_name(self) -> str:
        """Namespaced model identifier."""
        return f"openai/{self._model}"
# Factory function
_embedding_adapter: Optional[EmbeddingAdapter] = None
def get_embedding_adapter(
    config: Optional[EmbeddingConfig] = None,
) -> EmbeddingAdapter:
    """Return the process-wide embedding adapter, creating it on first use.

    Note: *config* is only consulted when the singleton is first built;
    subsequent calls return the existing adapter unchanged.

    Args:
        config: Embedding configuration

    Returns:
        EmbeddingAdapter instance
    """
    global _embedding_adapter
    if _embedding_adapter is not None:
        return _embedding_adapter
    cfg = config or EmbeddingConfig()
    if cfg.adapter_type == "openai" and cfg.openai_enabled:
        _embedding_adapter = OpenAIEmbedding(cfg)
    else:
        # Ollama is the default local backend.
        _embedding_adapter = OllamaEmbedding(cfg)
    return _embedding_adapter
def reset_embedding_adapter() -> None:
    """Clear the cached singleton so the next get_embedding_adapter() rebuilds it."""
    global _embedding_adapter
    _embedding_adapter = None