File size: 8,109 Bytes
dbb04e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
"""

Batch Processing and GPU Acceleration (Phase 3.5.5)

===================================================

Implementation of batch operations for HDVs, leveraging PyTorch for GPU acceleration

if available, with fallback to NumPy (CPU).



Designed to scale comfortably from Raspberry Pi (CPU) to dedicated AI rigs (CUDA).

"""

import multiprocessing
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import List, Tuple, Optional

import numpy as np
from loguru import logger

# Try importing torch, handle failure gracefully (CPU only environment)
try:
    import torch
    TORCH_AVAILABLE = True
except ImportError:
    # Sentinel keeps later `torch.` references resolvable at module level;
    # all runtime uses are guarded by TORCH_AVAILABLE / device checks.
    torch = None  # type: ignore[assignment]
    TORCH_AVAILABLE = False

from .binary_hdv import BinaryHDV, TextEncoder, batch_hamming_distance


def _encode_single_worker(args: tuple) -> bytes:
    """Top-level worker for ProcessPoolExecutor (must be importable to pickle).

    `args` is a (text, dimension) pair; the encoded HDV is returned as raw
    bytes so the result crosses the process boundary cheaply.
    """
    text, dim = args
    return TextEncoder(dimension=dim).encode(text).to_bytes()


class BatchProcessor:
    """Handles batched operations for HDV encoding and search.

    Automatically selects the best available backend (CUDA > MPS > CPU).
    Encoding always runs in CPU worker processes; only the distance search
    is offloaded to the GPU when one is available.
    """

    def __init__(self, use_gpu: bool = True, num_workers: Optional[int] = None):
        """
        Args:
            use_gpu: Whether to attempt using GPU acceleration.
            num_workers: Number of CPU workers for encoding (defaults to CPU count).
        """
        self.device = self._detect_device(use_gpu)
        self.num_workers = num_workers or multiprocessing.cpu_count()
        self.popcount_table_gpu = None  # Lazily built by _ensure_gpu_table

        logger.info(f"BatchProcessor initialized on device: {self.device}")

        # Note: TextEncoder is lightweight, so re-initializing it inside each
        # worker process (_encode_single_worker) is fine.

    def _detect_device(self, use_gpu: bool) -> str:
        """Detect the best available compute device: "cuda" > "mps" > "cpu"."""
        if not use_gpu or not TORCH_AVAILABLE:
            return "cpu"

        if torch.cuda.is_available():
            return "cuda"

        # Fix: older torch builds may not expose the MPS backend at all, so
        # probe defensively instead of assuming the attribute exists.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"

        return "cpu"

    def _ensure_gpu_table(self):
        """Initialize the 256-entry popcount lookup table on the GPU if needed."""
        if self.device == "cpu" or self.popcount_table_gpu is not None:
            return

        # Precompute table matching numpy's _build_popcount_table:
        # index is a byte value (0-255), entry is its number of set bits.
        self.popcount_table_gpu = torch.tensor(
            [bin(i).count("1") for i in range(256)],
            dtype=torch.int32,  # int32 suffices (max value is 8)
            device=self.device,
        )

    def encode_batch(self, texts: List[str], dimension: int = 16384) -> List[BinaryHDV]:
        """Encode a batch of texts into BinaryHDVs using parallel CPU processing.

        Encoding is strictly CPU-bound (tokenization + Python loops), so a
        ProcessPoolExecutor is used to bypass the GIL.

        Args:
            texts: Texts to encode; input order is preserved in the result.
            dimension: HDV dimensionality in bits.

        Returns:
            One BinaryHDV per input text. Items whose encoding failed are
            replaced by a zero vector (the failure is logged, not raised).
        """
        if not texts:
            return []

        results: List[Optional[BinaryHDV]] = [None] * len(texts)

        # Parallel execution using a module-level worker (picklable).
        with ProcessPoolExecutor(max_workers=self.num_workers) as executor:
            future_to_idx = {
                executor.submit(_encode_single_worker, (text, dimension)): i
                for i, text in enumerate(texts)
            }

            for future in as_completed(future_to_idx):
                idx = future_to_idx[future]
                try:
                    raw_bytes = future.result()
                    results[idx] = BinaryHDV.from_bytes(raw_bytes, dimension=dimension)
                except Exception as e:
                    # Best-effort: a single bad item must not sink the batch.
                    logger.error(f"Encoding failed for item {idx}: {e}")
                    results[idx] = BinaryHDV.zeros(dimension)

        return results

    def search_batch(
        self,
        queries: List[BinaryHDV],
        targets: List[BinaryHDV],
    ) -> np.ndarray:
        """Compute the Hamming distance matrix between queries and targets.

        Args:
            queries: List of M query vectors.
            targets: List of N target vectors.

        Returns:
            np.ndarray of shape (M, N) with int32 Hamming distances.
        """
        # Fix: return a correctly-shaped (M, N) empty matrix instead of the
        # previous np.array([[]]), whose shape was always (1, 0).
        if not queries or not targets:
            return np.zeros((len(queries), len(targets)), dtype=np.int32)

        # Pack into contiguous uint8 arrays: (M, D//8) and (N, D//8).
        query_arr = np.stack([q.data for q in queries])
        target_arr = np.stack([t.data for t in targets])

        if self.device == "cpu":
            return self._search_cpu(query_arr, target_arr)
        return self._search_gpu(query_arr, target_arr)

    def _search_cpu(self, query_arr: np.ndarray, target_arr: np.ndarray) -> np.ndarray:
        """NumPy-based batch Hamming distance.

        query_arr: (M, B) uint8, target_arr: (N, B) uint8; returns (M, N) int32.
        Iterates over queries rather than broadcasting to (M, N, B), which
        would require M*N*B bytes of RAM at once.
        """
        from .binary_hdv import _build_popcount_table

        popcount_table = _build_popcount_table()

        M = query_arr.shape[0]
        N = target_arr.shape[0]
        dists = np.zeros((M, N), dtype=np.int32)

        for i in range(M):
            # (N, B) XOR (B,) broadcasts to (N, B)
            xor_result = np.bitwise_xor(target_arr, query_arr[i])
            # Per-byte popcount lookup, then sum bytes -> per-vector distance
            dists[i] = popcount_table[xor_result].sum(axis=1)

        return dists

    def _search_gpu(self, query_arr: np.ndarray, target_arr: np.ndarray) -> np.ndarray:
        """PyTorch-based batch Hamming distance (CUDA or MPS).

        Uses the same query-by-query loop as the CPU path: a full (M, N, B)
        XOR tensor would not fit in VRAM for large batches.
        """
        self._ensure_gpu_table()

        # uint8 numpy arrays transfer as ByteTensors.
        q_tensor = torch.from_numpy(query_arr).to(self.device)   # (M, B)
        t_tensor = torch.from_numpy(target_arr).to(self.device)  # (N, B)

        M = q_tensor.shape[0]
        N = t_tensor.shape[0]
        dists = torch.zeros((M, N), dtype=torch.int32, device=self.device)

        for i in range(M):
            # (N, B) XOR (B,) -> (N, B); bitwise ops are supported on ByteTensor.
            xor_result = torch.bitwise_xor(t_tensor, q_tensor[i])
            # Advanced indexing in torch requires int64 indices, hence .long().
            counts = self.popcount_table_gpu[xor_result.long()]
            dists[i] = counts.sum(dim=1)

        return dists.cpu().numpy()