File size: 3,953 Bytes
b62e029
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee15d5b
 
 
b62e029
 
 
 
 
ee15d5b
b62e029
ee15d5b
 
 
 
b62e029
ee15d5b
 
 
b62e029
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# models/embedder.py

from typing import Any, Dict, List

import torch
from FlagEmbedding import BGEM3FlagModel
from pydantic import BaseModel

from core.exceptions import ModelLoadError
from core.logger import setup_logger

# Module-level logger for this file, created through the project's setup_logger helper.
logger = setup_logger("embedder")

# Return type of TextEmbedder.encode_query (a pydantic BaseModel).
class EmbedderResult(BaseModel):
    """Dense and sparse embeddings for one text, in the layout used for Qdrant hybrid search."""

    dense_vector: List[float]  # dense embedding values
    sparse_indices: List[int]  # vocabulary token IDs of the sparse entries; parallel to sparse_values
    sparse_values: List[float]  # weight of each entry in sparse_indices, in the same order

class TextEmbedder:
    """
    Converts the input text into Dense Vectors and Sparse Vectors (Lexical Weights) using the BGE-M3 model.

    The sparse output of ``encode_query`` is expressed as parallel ``sparse_indices`` /
    ``sparse_values`` lists keyed by tokenizer vocabulary ID (unique indices), ready for
    Qdrant hybrid search.
    """
    def __init__(self, model_name: str = "BAAI/bge-m3", use_fp16: bool = False):
        """
        Load the BGE-M3 model and run a warm-up inference.

        Args:
            model_name: Model ID or local path passed to ``BGEM3FlagModel``.
            use_fp16: Request half precision; only honoured when the detected device is CUDA.

        Raises:
            ModelLoadError: If loading the model or the warm-up inference fails.
        """
        self.model_name = model_name
        self.device = self._get_device()

        try:
            logger.info(f"⏳ Loading Embedder Model: {self.model_name} on {self.device}")
            # NOTE(review): self.device only drives the fp16 decision and the log line above; it is
            # not passed to BGEM3FlagModel, which picks its own device. Confirm the two agree.
            self.model = BGEM3FlagModel(
                self.model_name,
                use_fp16=(use_fp16 and self.device.startswith("cuda")),
            )
            self._warmup()
            logger.info("✅ Embedder Model loaded successfully.")
        except Exception as e:
            logger.critical(f"❌ Failed to load Embedder Model: {e}", exc_info=True)
            # Chain the original error so the root cause stays in the traceback.
            raise ModelLoadError(f"Embedder initialization failed: {e}") from e

    def _get_device(self) -> str:
        """Return the preferred torch device name: ``"cuda"``, then ``"mps"``, otherwise ``"cpu"``."""
        if torch.cuda.is_available():
            return "cuda"
        # torch.backends.mps does not exist on older torch builds; treat that as "not available".
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"  # Apple Silicon
        return "cpu"

    def _warmup(self):
        """Run one throwaway query so the first real request does not pay the initialization cost."""
        logger.info("Warming up embedder model with a dummy input.")
        self.encode_query("Hello world")

    @staticmethod
    def _lexical_weights_to_sparse(
        lexical_weights: Dict[str, float], tokenizer: Any = None
    ) -> Dict[int, float]:
        """
        Convert BGE-M3 lexical weights into a ``{token_id: weight}`` mapping with unique IDs.

        FlagEmbedding's ``BGEM3FlagModel`` keys ``lexical_weights`` by the *stringified vocabulary
        ID* (e.g. ``"6661"``), not by the token text, so each key is parsed with ``int()``.
        Feeding such a key to ``tokenizer.convert_tokens_to_ids`` would look it up as a literal
        token (usually resolving to the unknown-token ID) and corrupt the sparse vector.

        Keys that are not integers are treated as token strings and resolved through
        ``tokenizer`` when one is given; keys resolving to ``None`` or the unknown token are
        dropped. If several keys map to the same ID, the highest weight is kept.

        Args:
            lexical_weights: Mapping produced by ``BGEM3FlagModel.encode(..., return_sparse=True)``.
            tokenizer: Optional tokenizer used only for non-numeric keys.

        Returns:
            Dict mapping integer token ID to its (float) weight, in first-seen key order.
        """
        unk_id = getattr(tokenizer, "unk_token_id", None)
        sparse: Dict[int, float] = {}
        for key, weight in lexical_weights.items():
            try:
                token_id = int(key)
            except (TypeError, ValueError):
                if tokenizer is None:
                    continue
                token_id = tokenizer.convert_tokens_to_ids(key)
                if token_id is None or token_id == unk_id:
                    continue
            # Keep the higher weight on ID collisions so Qdrant receives unique indices.
            sparse[token_id] = max(sparse.get(token_id, 0.0), float(weight))
        return sparse

    def encode_query(self, text: str) -> EmbedderResult:
        """
        Converts a single query text into Qdrant hybrid search format.

        Args:
            text: The query string.

        Returns:
            EmbedderResult holding the dense vector and parallel sparse index/value lists.

        Raises:
            RuntimeError: If model inference or the sparse conversion fails.
        """
        try:
            # 1. Model inference: dense vector + sparse lexical weights (no ColBERT vectors).
            output = self.model.encode(
                text,
                return_dense=True,
                return_sparse=True,
                return_colbert_vecs=False,
            )

            dense_vec = output['dense_vecs'].tolist()

            # 2. Lexical weights -> unique {token_id: weight}, then split into Qdrant's lists.
            sparse = self._lexical_weights_to_sparse(
                output['lexical_weights'], getattr(self.model, "tokenizer", None)
            )

            return EmbedderResult(
                dense_vector=dense_vec,
                sparse_indices=list(sparse.keys()),
                sparse_values=list(sparse.values()),
            )

        except Exception as e:
            logger.error(f"Failed to encode query '{text}': {e}", exc_info=True)
            raise RuntimeError(f"Embedding generation failed: {e}") from e

    # options for batch encoding if needed in the future
    def encode_documents(
        self, texts: List[str], batch_size: int = 12, max_length: int = 8192
    ) -> Dict[str, Any]:
        """
        Batch-encode documents and return ``BGEM3FlagModel.encode``'s raw output dict.

        Unlike ``encode_query``, the sparse output is not converted here: ``lexical_weights``
        is returned exactly as the model produced it. Each entry can be turned into Qdrant
        indices/values with ``_lexical_weights_to_sparse``.

        Args:
            texts: Documents to encode.
            batch_size: Number of documents per forward pass.
            max_length: Token limit per document passed to the model (default 8192, BGE-M3's max).
        """
        return self.model.encode(
            texts,
            batch_size=batch_size,
            max_length=max_length,
            return_dense=True,
            return_sparse=True,
            return_colbert_vecs=False,
        )