# test/modules/proj_head.py
# Author: jaewooo — initial upload (commit de15dc5, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
class ProjectionHead(nn.Module):
    """Bias-free linear projection followed by a pointwise nonlinearity.

    Maps ``in_dim``-dimensional features to ``out_dim`` dimensions and
    squashes them with the chosen activation (case-insensitive name).
    The projection weight is re-initialized from N(0, 0.02^2).
    """

    def __init__(self, in_dim: int = 512, out_dim: int = 3008, activation: str = "tanh"):
        super().__init__()
        # Linear is constructed first so RNG consumption matches the
        # default-initialized layer before the explicit re-init below.
        self.linear = nn.Linear(in_dim, out_dim, bias=False)
        dispatch = {
            "tanh": nn.Tanh,
            "relu": nn.ReLU,
            "gelu": nn.GELU,
            "sigmoid": nn.Sigmoid,
        }
        act_cls = dispatch.get(activation.lower())
        if act_cls is None:
            raise ValueError(f"Unsupported activation: {activation}")
        self.act = act_cls()
        nn.init.normal_(self.linear.weight, mean=0.0, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and apply the activation elementwise."""
        projected = self.linear(x)
        return self.act(projected)
def binarize_sign(x: torch.Tensor) -> torch.Tensor:
    """Map a real-valued tensor to {0, 1} codes.

    A strictly positive entry becomes 1; zero and negative entries become 0.
    The result has dtype int64 and lives on the same device as ``x``.
    """
    positive = torch.gt(x, 0)
    return positive.long()
def pack_bits_64(bits: torch.Tensor, dim_features: int) -> torch.Tensor:
"""Pack {0,1} bits to int64 words of length 64.
Args:
bits: Tensor (N, D) with 0/1 int64 entries
dim_features: D, must be divisible by 64
Returns:
Tensor (N, D//64) int64 on same device.
"""
assert dim_features % 64 == 0, "proj_dim must be divisible by 64 for bit-packing"
if bits.dtype != torch.int64:
bits = bits.to(torch.int64)
N = bits.size(0)
words = dim_features // 64
bits = bits.view(N, words, 64)
shifts = torch.arange(64, device=bits.device, dtype=torch.int64)
packed = (bits << shifts).sum(-1)
return packed.contiguous()
def _popcount64(x: torch.Tensor) -> torch.Tensor:
"""Compute population count for each int64 element using bit hacks.
Returns same shape tensor with counts in int64.
"""
# Constants must be int64 on the same device
m1 = torch.tensor(0x5555555555555555, dtype=torch.int64, device=x.device)
m2 = torch.tensor(0x3333333333333333, dtype=torch.int64, device=x.device)
m4 = torch.tensor(0x0F0F0F0F0F0F0F0F, dtype=torch.int64, device=x.device)
x = x - ((x >> 1) & m1)
x = (x & m2) + ((x >> 2) & m2)
x = (x + (x >> 4)) & m4
x = x + (x >> 8)
x = x + (x >> 16)
x = x + (x >> 32)
return x & torch.tensor(0x7F, dtype=torch.int64, device=x.device)
def hamming_distance_packed(a_words: torch.Tensor, b_words: torch.Tensor, block: int = 1024) -> torch.Tensor:
    """All-pairs Hamming distances between two sets of bit-packed codes.

    Args:
        a_words: (Na, W) int64 packed codes.
        b_words: (Nb, W) int64 packed codes (same word count W).
        block: number of rows of ``a_words`` processed per iteration, which
            caps the (block, Nb, W) intermediate XOR tensor's memory.

    Returns:
        (Na, Nb) int64 tensor where entry (i, j) is the number of
        differing bits between code i of A and code j of B.
    """
    assert a_words.dtype == torch.int64 and b_words.dtype == torch.int64
    n_a, width = a_words.shape
    n_b, width_b = b_words.shape
    assert width == width_b
    distances = torch.empty((n_a, n_b), dtype=torch.int64, device=a_words.device)
    b_expanded = b_words.unsqueeze(0)  # (1, Nb, W), broadcast against each A chunk
    for start in range(0, n_a, block):
        stop = min(n_a, start + block)
        chunk = a_words[start:stop].unsqueeze(1)       # (bs, 1, W)
        differing = _popcount64(chunk ^ b_expanded)    # (bs, Nb, W) set-bit counts
        distances[start:stop] = differing.sum(-1)      # fold words -> (bs, Nb)
    return distances