import torch
import torch.nn as nn
import torch.nn.functional as F


class ProjectionHead(nn.Module):
    """Bias-free linear projection followed by a pointwise activation.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output. The default
            3008 = 47 * 64, so default outputs can be bit-packed by
            ``pack_bits_64``.
        activation: Case-insensitive activation name; one of ``"tanh"``,
            ``"relu"``, ``"gelu"``, ``"sigmoid"``.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    # Lower-cased activation name -> module class to instantiate.
    _ACTIVATIONS = {
        "tanh": nn.Tanh,
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "sigmoid": nn.Sigmoid,
    }

    def __init__(self, in_dim: int = 512, out_dim: int = 3008, activation: str = "tanh"):
        super().__init__()
        # Created before validating `activation` to keep the same registration
        # order and RNG consumption as before.
        self.linear = nn.Linear(in_dim, out_dim, bias=False)
        act_cls = self._ACTIVATIONS.get(activation.lower())
        if act_cls is None:
            raise ValueError(f"Unsupported activation: {activation}")
        self.act = act_cls()
        # Replace nn.Linear's default init with a small-std normal init.
        nn.init.normal_(self.linear.weight, mean=0.0, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``act(linear(x))``, projecting the last dimension of ``x``."""
        return self.act(self.linear(x))


def binarize_sign(x: torch.Tensor) -> torch.Tensor:
    """Return {0,1} bits: 1 if x>0 else 0, dtype int64.

    Zero and negative values map to 0. Works on any device; the result is on
    the same device as ``x``.
    """
    return (x > 0).to(torch.int64)


def pack_bits_64(bits: torch.Tensor, dim_features: int) -> torch.Tensor:
    """Pack {0,1} bits into int64 words of 64 bits each.

    Bit ``j`` of word ``w`` in row ``n`` holds ``bits[n, 64 * w + j]``, with the
    least significant bit first. Bit 63 is the int64 sign bit, so packed words
    may be negative; only their bit pattern is meaningful.

    Args:
        bits: Tensor (N, D) with 0/1 entries. It is converted to int64 if
            needed. Entries other than 0/1 would spill into neighbouring bits.
        dim_features: D, must be divisible by 64.

    Returns:
        Tensor (N, D // 64) int64 on the same device as ``bits``.

    Raises:
        AssertionError: If ``dim_features`` is not a multiple of 64.
    """
    assert dim_features % 64 == 0, "proj_dim must be divisible by 64 for bit-packing"
    if bits.dtype != torch.int64:
        bits = bits.to(torch.int64)
    n_rows = bits.size(0)
    words = dim_features // 64
    bits = bits.view(n_rows, words, 64)
    shifts = torch.arange(64, device=bits.device, dtype=torch.int64)
    # Each position contributes a distinct power of two, so summing equals
    # OR-ing them. The 1 << 63 term is INT64_MIN, which is exactly the wanted
    # sign-bit pattern, and the sum cannot overflow.
    packed = (bits << shifts).sum(-1)
    return packed.contiguous()


def _popcount64(x: torch.Tensor) -> torch.Tensor:
    """Count the set bits in each int64 element using the parallel (SWAR) bit hack.

    Negative inputs are handled. PyTorch's ``>>`` on int64 sign-extends, but
    the first two shifts are each followed by a mask whose top bits are clear,
    so the extension bits are discarded. The first subtraction relies on
    two's-complement wraparound for inputs with bit 63 set. After the second
    step every intermediate is non-negative.

    Returns:
        Same-shape int64 tensor of counts in [0, 64].
    """
    # Constants must be int64 on the same device
    m1 = torch.tensor(0x5555555555555555, dtype=torch.int64, device=x.device)
    m2 = torch.tensor(0x3333333333333333, dtype=torch.int64, device=x.device)
    m4 = torch.tensor(0x0F0F0F0F0F0F0F0F, dtype=torch.int64, device=x.device)
    x = x - ((x >> 1) & m1)  # 2-bit field counts
    x = (x & m2) + ((x >> 2) & m2)  # 4-bit field counts
    x = (x + (x >> 4)) & m4  # 8-bit (byte) counts, each <= 8
    # Fold bytes into the low byte; partial sums never exceed 64, so no carries.
    x = x + (x >> 8)
    x = x + (x >> 16)
    x = x + (x >> 32)
    return x & torch.tensor(0x7F, dtype=torch.int64, device=x.device)


def hamming_distance_packed(a_words: torch.Tensor, b_words: torch.Tensor, block: int = 1024) -> torch.Tensor:
    """Compute pairwise Hamming distances between two packed code sets.

    Args:
        a_words: (Na, W) int64 packed codes, e.g. from ``pack_bits_64``.
        b_words: (Nb, W) int64 packed codes with the same W.
        block: Number of rows of ``a_words`` processed per step. Peak
            temporary memory is about ``block * Nb * W`` int64 values.

    Returns:
        dist: (Na, Nb) int64 distances, each in [0, 64 * W], on
        ``a_words``'s device.

    Raises:
        AssertionError: If either input is not int64 or the word counts differ.
        ValueError: If ``block`` is not positive.
    """
    assert a_words.dtype == torch.int64 and b_words.dtype == torch.int64, "packed codes must be int64"
    # A negative block would skip the loop entirely and return the
    # uninitialized `torch.empty` buffer, so reject every non-positive value.
    if block <= 0:
        raise ValueError(f"block must be a positive integer, got {block}")
    n_a, n_words = a_words.shape
    n_b, n_words_b = b_words.shape
    assert n_words == n_words_b, f"word count mismatch: {n_words} vs {n_words_b}"

    out = torch.empty((n_a, n_b), dtype=torch.int64, device=a_words.device)
    for start in range(0, n_a, block):
        stop = min(n_a, start + block)
        # (bs, 1, W) ^ (1, Nb, W) -> (bs, Nb, W) differing-bit masks
        xor = a_words[start:stop].unsqueeze(1) ^ b_words.unsqueeze(0)
        out[start:stop] = _popcount64(xor).sum(-1)  # (bs, Nb)
    return out