"""
Embedding Adapters for RAG Subsystem

Provides:
- Abstract EmbeddingAdapter interface
- Ollama embeddings (local, default)
- OpenAI embeddings (optional, feature-flagged)
"""

import hashlib
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional

from loguru import logger
from pydantic import BaseModel, Field

# Optional dependencies: each backend is only usable when its client
# library is installed.
try:
    import httpx

    HTTPX_AVAILABLE = True
except ImportError:
    HTTPX_AVAILABLE = False

try:
    import openai

    OPENAI_AVAILABLE = True
except ImportError:
    OPENAI_AVAILABLE = False

class EmbeddingConfig(BaseModel):
    """Configuration for embedding adapters."""

    adapter_type: str = Field(
        default="ollama",
        description="Embedding adapter type: ollama, openai",
    )

    # Ollama (local, default)
    ollama_base_url: str = Field(
        default="http://localhost:11434",
        description="Ollama API base URL",
    )
    ollama_model: str = Field(
        default="nomic-embed-text",
        description="Ollama embedding model (nomic-embed-text, mxbai-embed-large)",
    )

    # OpenAI (optional, feature-flagged)
    openai_enabled: bool = Field(
        default=False,
        description="Enable OpenAI embeddings",
    )
    openai_model: str = Field(
        default="text-embedding-3-small",
        description="OpenAI embedding model",
    )
    openai_api_key: Optional[str] = Field(
        default=None,
        description="OpenAI API key (or use OPENAI_API_KEY env var)",
    )

    # Request handling
    batch_size: int = Field(default=32, ge=1, description="Batch size for embedding")
    timeout: float = Field(default=60.0, ge=1.0, description="Request timeout in seconds")

    # Caching
    enable_cache: bool = Field(default=True, description="Enable embedding cache")
    cache_directory: str = Field(
        default="./data/embedding_cache",
        description="Cache directory for embeddings",
    )

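# Configuration sketch (hypothetical values -- any model name must already
# be available on the target backend):
#
#     config = EmbeddingConfig(
#         adapter_type="openai",
#         openai_enabled=True,
#         openai_model="text-embedding-3-small",
#     )
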
class EmbeddingAdapter(ABC):
    """Abstract interface for embedding adapters."""

    @abstractmethod
    def embed_text(self, text: str) -> List[float]:
        """
        Embed a single text.

        Args:
            text: Text to embed

        Returns:
            Embedding vector
        """
        pass

    @abstractmethod
    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """
        Embed multiple texts.

        Args:
            texts: List of texts to embed

        Returns:
            List of embedding vectors
        """
        pass

    @property
    @abstractmethod
    def embedding_dimension(self) -> int:
        """Return embedding dimension."""
        pass

    @property
    @abstractmethod
    def model_name(self) -> str:
        """Return model name."""
        pass

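# A minimal sketch of a third adapter implementing this interface, assuming
# a hypothetical `my_model` object with an `encode()` method -- illustration
# only, not part of this module:
#
#     class MyEmbedding(EmbeddingAdapter):
#         def embed_text(self, text: str) -> List[float]:
#             return my_model.encode(text)
#
#         def embed_batch(self, texts: List[str]) -> List[List[float]]:
#             return [self.embed_text(t) for t in texts]
#
#         @property
#         def embedding_dimension(self) -> int:
#             return 384
#
#         @property
#         def model_name(self) -> str:
#             return "custom/my-model"
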
class EmbeddingCache:
    """Simple file-based embedding cache with an in-memory tier."""

    def __init__(self, cache_dir: str, model_name: str):
        """Initialize cache."""
        self.cache_dir = Path(cache_dir) / model_name.replace("/", "_")
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self._memory_cache: dict = {}

    def _hash_text(self, text: str) -> str:
        """Generate cache key from text."""
        return hashlib.sha256(text.encode()).hexdigest()[:32]

    def get(self, text: str) -> Optional[List[float]]:
        """Get cached embedding."""
        key = self._hash_text(text)

        # Check the in-memory tier first
        if key in self._memory_cache:
            return self._memory_cache[key]

        # Fall back to the on-disk tier
        cache_file = self.cache_dir / f"{key}.json"
        if cache_file.exists():
            try:
                with open(cache_file, "r") as f:
                    embedding = json.load(f)
                self._memory_cache[key] = embedding
                return embedding
            except (OSError, json.JSONDecodeError) as e:
                # Treat unreadable or corrupt cache files as misses
                logger.warning(f"Failed to read cached embedding: {e}")

        return None

    def put(self, text: str, embedding: List[float]):
        """Cache embedding."""
        key = self._hash_text(text)

        # Write through to both tiers
        self._memory_cache[key] = embedding

        cache_file = self.cache_dir / f"{key}.json"
        try:
            with open(cache_file, "w") as f:
                json.dump(embedding, f)
        except Exception as e:
            logger.warning(f"Failed to cache embedding: {e}")

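# Cache usage sketch (the adapters below do this internally;
# `compute_embedding` is a hypothetical helper):
#
#     cache = EmbeddingCache("./data/embedding_cache", "nomic-embed-text")
#     vec = cache.get("hello")
#     if vec is None:
#         vec = compute_embedding("hello")
#         cache.put("hello", vec)
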
class OllamaEmbedding(EmbeddingAdapter):
    """
    Ollama embedding adapter for local embeddings.

    Supported models:
    - nomic-embed-text (768 dimensions, recommended)
    - mxbai-embed-large (1024 dimensions)
    - all-minilm (384 dimensions)
    """

    # Known dimensions; unknown models are discovered on first embed
    MODEL_DIMENSIONS = {
        "nomic-embed-text": 768,
        "mxbai-embed-large": 1024,
        "all-minilm": 384,
        "snowflake-arctic-embed": 1024,
    }

    def __init__(self, config: Optional[EmbeddingConfig] = None):
        """Initialize Ollama embedding adapter."""
        if not HTTPX_AVAILABLE:
            raise ImportError("httpx is required for Ollama. Install with: pip install httpx")

        self.config = config or EmbeddingConfig()
        self._base_url = self.config.ollama_base_url.rstrip("/")
        self._model = self.config.ollama_model
        self._dimension: Optional[int] = self.MODEL_DIMENSIONS.get(self._model)

        self._cache: Optional[EmbeddingCache] = None
        if self.config.enable_cache:
            self._cache = EmbeddingCache(self.config.cache_directory, self._model)

        logger.info(f"OllamaEmbedding initialized: {self._model}")

    def embed_text(self, text: str) -> List[float]:
        """Embed a single text."""
        if self._cache:
            cached = self._cache.get(text)
            if cached is not None:
                return cached

        with httpx.Client(timeout=self.config.timeout) as client:
            response = client.post(
                f"{self._base_url}/api/embeddings",
                json={
                    "model": self._model,
                    "prompt": text,
                },
            )
            response.raise_for_status()
            result = response.json()

        embedding = result["embedding"]

        # Record the dimension for models not in MODEL_DIMENSIONS
        if self._dimension is None:
            self._dimension = len(embedding)

        if self._cache:
            self._cache.put(text, embedding)

        return embedding

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed multiple texts."""
        # Ollama's /api/embeddings endpoint accepts one prompt per request,
        # so texts are embedded individually; batch_size only chunks the loop.
        embeddings = []

        for i in range(0, len(texts), self.config.batch_size):
            batch = texts[i:i + self.config.batch_size]

            for text in batch:
                embedding = self.embed_text(text)
                embeddings.append(embedding)

        return embeddings

    @property
    def embedding_dimension(self) -> int:
        """Return embedding dimension."""
        if self._dimension is None:
            # Probe the server once to discover the dimension
            test_embedding = self.embed_text("test")
            self._dimension = len(test_embedding)
        return self._dimension

    @property
    def model_name(self) -> str:
        """Return model name."""
        return f"ollama/{self._model}"

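# Usage sketch -- assumes an Ollama server is reachable at the default base
# URL and the model has been pulled (e.g. `ollama pull nomic-embed-text`):
#
#     adapter = OllamaEmbedding()
#     vec = adapter.embed_text("retrieval-augmented generation")
#     assert len(vec) == adapter.embedding_dimension  # 768 for nomic-embed-text
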
class OpenAIEmbedding(EmbeddingAdapter):
    """
    OpenAI embedding adapter (feature-flagged).

    Supported models:
    - text-embedding-3-small (1536 dimensions)
    - text-embedding-3-large (3072 dimensions)
    - text-embedding-ada-002 (1536 dimensions, legacy)
    """

    MODEL_DIMENSIONS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }

    def __init__(self, config: Optional[EmbeddingConfig] = None):
        """Initialize OpenAI embedding adapter."""
        if not OPENAI_AVAILABLE:
            raise ImportError("openai is required. Install with: pip install openai")

        self.config = config or EmbeddingConfig()

        if not self.config.openai_enabled:
            raise ValueError("OpenAI embeddings not enabled in config")

        self._model = self.config.openai_model
        self._dimension = self.MODEL_DIMENSIONS.get(self._model, 1536)

        # Fall back to the OPENAI_API_KEY env var when no key is configured
        api_key = self.config.openai_api_key
        self._client = openai.OpenAI(api_key=api_key) if api_key else openai.OpenAI()

        self._cache: Optional[EmbeddingCache] = None
        if self.config.enable_cache:
            self._cache = EmbeddingCache(self.config.cache_directory, self._model)

        logger.info(f"OpenAIEmbedding initialized: {self._model}")

    def embed_text(self, text: str) -> List[float]:
        """Embed a single text."""
        if self._cache:
            cached = self._cache.get(text)
            if cached is not None:
                return cached

        response = self._client.embeddings.create(
            model=self._model,
            input=text,
        )

        embedding = response.data[0].embedding

        if self._cache:
            self._cache.put(text, embedding)

        return embedding

    def embed_batch(self, texts: List[str]) -> List[List[float]]:
        """Embed multiple texts."""
        embeddings = []

        for i in range(0, len(texts), self.config.batch_size):
            batch = texts[i:i + self.config.batch_size]

            # Serve cache hits locally; collect misses for a single API call
            to_embed = []
            results_by_index = {}

            for j, text in enumerate(batch):
                if self._cache:
                    cached = self._cache.get(text)
                    if cached is not None:
                        results_by_index[j] = cached
                        continue
                to_embed.append((j, text))

            # Embed all misses for this batch in one request
            if to_embed:
                response = self._client.embeddings.create(
                    model=self._model,
                    input=[text for _, text in to_embed],
                )

                for idx, (j, text) in enumerate(to_embed):
                    embedding = response.data[idx].embedding
                    results_by_index[j] = embedding

                    if self._cache:
                        self._cache.put(text, embedding)

            # Reassemble results in the original batch order
            for j in range(len(batch)):
                embeddings.append(results_by_index[j])

        return embeddings

    @property
    def embedding_dimension(self) -> int:
        """Return embedding dimension."""
        return self._dimension

    @property
    def model_name(self) -> str:
        """Return model name."""
        return f"openai/{self._model}"

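# Usage sketch -- requires OPENAI_API_KEY in the environment (or set in the
# config) and openai_enabled=True:
#
#     adapter = OpenAIEmbedding(EmbeddingConfig(openai_enabled=True))
#     vecs = adapter.embed_batch(["first chunk", "second chunk"])
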
# Module-level singleton
_embedding_adapter: Optional[EmbeddingAdapter] = None


def get_embedding_adapter(
    config: Optional[EmbeddingConfig] = None,
) -> EmbeddingAdapter:
    """
    Get or create singleton embedding adapter.

    Args:
        config: Embedding configuration

    Returns:
        EmbeddingAdapter instance
    """
    global _embedding_adapter

    if _embedding_adapter is None:
        config = config or EmbeddingConfig()

        if config.adapter_type == "openai" and config.openai_enabled:
            _embedding_adapter = OpenAIEmbedding(config)
        else:
            # Default to local Ollama embeddings
            _embedding_adapter = OllamaEmbedding(config)

    return _embedding_adapter


def reset_embedding_adapter():
    """Reset the global embedding adapter instance."""
    global _embedding_adapter
    _embedding_adapter = None

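if __name__ == "__main__":
    # Smoke-test sketch, assuming a local Ollama server with the default
    # model pulled; illustration only, not part of the module API.
    adapter = get_embedding_adapter()
    vector = adapter.embed_text("The quick brown fox jumps over the lazy dog")
    print(f"{adapter.model_name}: dimension={len(vector)}")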