| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| |
|
class ProjectionHead(nn.Module):
    """Bias-free linear projection followed by a pointwise nonlinearity.

    Maps features of size ``in_dim`` to ``out_dim`` with a linear layer
    (no bias) whose weights are initialized from N(0, 0.02).

    Args:
        in_dim: Input feature dimension.
        out_dim: Output (projection) dimension.
        activation: One of "tanh", "relu", "gelu", "sigmoid"
            (case-insensitive).

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    # Supported activation names mapped to their module constructors.
    # A dict lookup replaces the if/elif chain and keeps the supported
    # set in a single place.
    _ACTIVATIONS = {
        "tanh": nn.Tanh,
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "sigmoid": nn.Sigmoid,
    }

    def __init__(self, in_dim: int = 512, out_dim: int = 3008, activation: str = "tanh"):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim, bias=False)
        try:
            self.act = self._ACTIVATIONS[activation.lower()]()
        except KeyError:
            # `from None` hides the uninformative KeyError chain.
            raise ValueError(f"Unsupported activation: {activation}") from None
        nn.init.normal_(self.linear.weight, mean=0.0, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear projection, then the activation."""
        return self.act(self.linear(x))
| |
|
| |
|
def binarize_sign(x: torch.Tensor) -> torch.Tensor:
    """Map a real-valued tensor to {0, 1} bits: 1 where x > 0, else 0.

    Accepts input on any device; the result stays on the same device
    and is returned as int64, ready for bit-packing.
    """
    positive = torch.gt(x, 0)
    return positive.long()
| |
|
| |
|
def pack_bits_64(bits: torch.Tensor, dim_features: int) -> torch.Tensor:
    """Pack {0,1} bits into int64 words of 64 bits each.

    Bit j of each word holds element j of the corresponding 64-element
    slice (little-endian within the word). Bit 63 overflows into the
    int64 sign bit; this is intentional — XOR/popcount Hamming distance
    is unaffected by the sign interpretation.

    Args:
        bits: Tensor (N, D) with 0/1 entries (converted to int64 if needed).
        dim_features: D; must be divisible by 64.

    Returns:
        Tensor (N, D//64) int64 on the same device as ``bits``.

    Raises:
        ValueError: If ``dim_features`` is not divisible by 64.
    """
    # Explicit raise instead of assert: validation must survive `python -O`.
    if dim_features % 64 != 0:
        raise ValueError("proj_dim must be divisible by 64 for bit-packing")
    if bits.dtype != torch.int64:
        bits = bits.to(torch.int64)
    n = bits.size(0)
    words = dim_features // 64
    # reshape (not view) so non-contiguous inputs are handled too.
    bits = bits.reshape(n, words, 64)
    shifts = torch.arange(64, device=bits.device, dtype=torch.int64)
    # Shift each bit to its slot, then combine by sum (slots are disjoint,
    # so sum is equivalent to bitwise OR here).
    packed = (bits << shifts).sum(-1)
    return packed.contiguous()
| |
|
| |
|
| | def _popcount64(x: torch.Tensor) -> torch.Tensor: |
| | """Compute population count for each int64 element using bit hacks. |
| | Returns same shape tensor with counts in int64. |
| | """ |
| | |
| | m1 = torch.tensor(0x5555555555555555, dtype=torch.int64, device=x.device) |
| | m2 = torch.tensor(0x3333333333333333, dtype=torch.int64, device=x.device) |
| | m4 = torch.tensor(0x0F0F0F0F0F0F0F0F, dtype=torch.int64, device=x.device) |
| | x = x - ((x >> 1) & m1) |
| | x = (x & m2) + ((x >> 2) & m2) |
| | x = (x + (x >> 4)) & m4 |
| | x = x + (x >> 8) |
| | x = x + (x >> 16) |
| | x = x + (x >> 32) |
| | return x & torch.tensor(0x7F, dtype=torch.int64, device=x.device) |
| |
|
| |
|
def hamming_distance_packed(a_words: torch.Tensor, b_words: torch.Tensor, block: int = 1024) -> torch.Tensor:
    """Pairwise Hamming distances between two sets of packed codes.

    XORs every row of ``a_words`` against every row of ``b_words`` and
    popcounts the result. A is processed in chunks of ``block`` rows so
    the (block, Nb, W) intermediate bounds peak memory.

    Args:
        a_words: (Na, W) int64 packed codes.
        b_words: (Nb, W) int64 packed codes.
        block: Number of A rows per chunk (memory/speed trade-off).

    Returns:
        (Na, Nb) int64 distance matrix on the inputs' device.

    Raises:
        TypeError: If either input is not int64.
        ValueError: If the word counts (last dims) differ.
    """
    # Explicit raises instead of asserts: checks must survive `python -O`.
    if a_words.dtype != torch.int64 or b_words.dtype != torch.int64:
        raise TypeError("packed codes must be int64")
    na, w = a_words.shape
    nb, wb = b_words.shape
    if w != wb:
        raise ValueError(f"word-count mismatch: {w} != {wb}")
    out = torch.empty((na, nb), dtype=torch.int64, device=a_words.device)
    for start in range(0, na, block):
        end = min(na, start + block)
        chunk = a_words[start:end]
        # (blk, 1, W) ^ (1, Nb, W) -> (blk, Nb, W); sum popcounts over words.
        xor = chunk.unsqueeze(1) ^ b_words.unsqueeze(0)
        out[start:end] = _popcount64(xor).sum(-1)
    return out
| |
|
| |
|