# SparseAE/extract.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, RandomSampler
from collections import defaultdict
from tqdm import tqdm
# =====================
# 1. SETUP & DATA LOADING
# =====================
print("Loading sequence / phyloP data...")
dna_path = "/home/n5huang/dna_token/SparseAE/chr1_dna.txt"
phy_path = "/home/n5huang/dna_token/SparseAE/chr1_phyloP_norm.npy"
with open(dna_path) as f:
    sequence = f.read().strip()
phy_norm = np.load(phy_path)
assert len(sequence) == len(phy_norm), "DNA and phyloP length mismatch!"
chrom_len = len(sequence)
print(f"Chromosome 1 length: {chrom_len:,} bp")
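# Optional sanity check (not in the original script): the phyloP track should be
# a 1-D array aligned base-for-base with the sequence.
assert phy_norm.ndim == 1, f"Expected 1-D phyloP track, got shape {phy_norm.shape}"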
# =====================
# 2. DNA ENCODING (HANDLE 'N')
# =====================
print("Encoding DNA to one-hot (with N handling)...")
mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
# Map bases to ints, using 4 as "N/unknown"
dna_int = np.fromiter((mapping.get(b, 4) for b in sequence), dtype=np.int8)
num_N = np.sum(dna_int == 4)
print(f"Number of N bases: {num_N:,}")
# One-hot with an extra column for N: rows 0..3 -> A,C,G,T; row 4 (N) -> [0,0,0,0,1]
temp_onehot = np.eye(5, dtype=np.float32)[dna_int]
# Keep only the first 4 columns, so N collapses to the all-zero vector [0,0,0,0]
dna_onehot = temp_onehot[:, :4]  # shape (chrom_len, 4)
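# Optional sanity check (not in the original script): A/C/G/T positions should be
# one-hot rows summing to 1, N positions all-zero rows summing to 0.
_row_sums = dna_onehot.sum(axis=1)
assert np.all(_row_sums[dna_int != 4] == 1.0)
assert np.all(_row_sums[dna_int == 4] == 0.0)
del _row_sums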
# =====================
# 3. PHYLOP CHECK + COMBINE
# =====================
print("Preparing combined tensor...")
# Assume phy_norm is already in [-1,1]; warn if not.
max_abs_phy = np.max(np.abs(phy_norm))
if max_abs_phy > 1.1:
    print(f"WARNING: phy_norm max abs={max_abs_phy:.3f} > 1.1; "
          f"data may not be normalized as expected.")
phy_norm = phy_norm.astype(np.float32)
phy_col = phy_norm.reshape(-1, 1) # (chrom_len, 1)
combined_np = np.concatenate([dna_onehot, phy_col], axis=1) # (chrom_len, 5)
combined_tensor = torch.from_numpy(combined_np) # CPU tensor
print(f"Master tensor shape: {combined_tensor.shape}")
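# Note (not in the original script): combined_np is float32, so for a ~249 Mbp
# chromosome the master array occupies roughly 249e6 * 5 * 4 bytes ≈ 5 GB of RAM.
print(f"Master array size: {combined_np.nbytes / 1e9:.2f} GB")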
# =====================
# 4. DATASET: CHUNKED WINDOWING
# =====================
L = 50
class ChunkedChr1Dataset(Dataset):
    def __init__(self, combined, L=50):
        self.combined = combined
        self.L = L
        # Number of windows; note this is one fewer than the number of valid start
        # positions (the final window starting at chrom_len - L is dropped).
        self.N = combined.shape[0] - L

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        # window: (L, 5) slice of the combined (one-hot DNA + phyloP) array
        window = self.combined[idx : idx + self.L]
        dna = window[:, :4]   # (L, 4)
        phy = window[:, 4]    # (L,)
        return dna, phy, idx
dataset = ChunkedChr1Dataset(combined_tensor, L=L)
print(f"Dataset length (#windows): {len(dataset):,}")
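# Optional shape check (not in the original script): each item should be an
# (L, 4) one-hot block, an (L,) phyloP vector, and the window's start index.
_d0, _p0, _i0 = dataset[0]
assert _d0.shape == (L, 4) and _p0.shape == (L,) and _i0 == 0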
# =====================
# 5. DATALOADER (SEQUENTIAL FULL PASS)
# =====================
BATCH_SIZE = 1024
SAMPLES_PER_EPOCH = 5_000_000  # windows per epoch (tunable; used only during training)

# A RandomSampler with replacement is what the training script uses:
#   sampler = RandomSampler(dataset, replacement=True, num_samples=SAMPLES_PER_EPOCH)
# For extraction we must instead visit every window once, in genomic order, so that
# the i-th token written below corresponds to the window starting at position i.
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,      # <--- MUST BE FALSE for mapping back to genome
    drop_last=False,    # <--- process every last window
    num_workers=0,      # safer with a large in-memory dataset
    pin_memory=True
)
print("DataLoader ready.")
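# Rough scale note (not in the original script): ~249M windows at batch size 1024
# means roughly 240k batches per full sequential pass over chr1.
print(f"Batches per full pass: {len(loader):,}")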
# =====================
# 6. MODEL: SPARSE AUTOENCODER
# =====================
INPUT_DIM = L * 5 # 4 DNA + 1 phyloP
LATENT_DIM = 2048
HIDDEN_DIM = 1024
class SparseAE(nn.Module):
    def __init__(self, input_dim=INPUT_DIM, latent_dim=LATENT_DIM, hidden_dim=HIDDEN_DIM):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
            nn.ReLU()  # ReLU on the latent helps sparsity with an L1 penalty
        )
        # Shared decoder trunk
        self.dec_hidden = nn.Linear(latent_dim, hidden_dim)
        # Decoder heads
        self.dec_dna = nn.Linear(hidden_dim, L * 4)
        self.dec_phy = nn.Linear(hidden_dim, L * 1)

    def forward(self, dna, phy):
        B = dna.size(0)
        # Flatten and concatenate the two inputs: (B, L*4) + (B, L) -> (B, INPUT_DIM)
        x = torch.cat([dna.reshape(B, -1), phy.reshape(B, -1)], dim=1)
        h = self.encoder(x)
        dec = F.relu(self.dec_hidden(h))
        recon_dna = self.dec_dna(dec).reshape(B, L, 4)           # (B, L, 4)
        recon_phy = torch.tanh(self.dec_phy(dec)).reshape(B, L)  # (B, L), matches [-1, 1] phyloP
        return recon_dna, recon_phy, h
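# Optional smoke test (not in the original script): push a tiny random batch
# through an untrained SparseAE and confirm the output shapes.
_m = SparseAE()
with torch.no_grad():
    _rd, _rp, _h = _m(torch.rand(2, L, 4), torch.rand(2, L))
assert _rd.shape == (2, L, 4) and _rp.shape == (2, L) and _h.shape == (2, LATENT_DIM)
del _m, _rd, _rp, _h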
########################################
# 7. LOAD CHECKPOINT
########################################
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model = SparseAE().to(device)
model.load_state_dict(torch.load("sparse_ae_50bp_epoch3.pt", map_location=device))
model.eval()
print("Model loaded.")
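# Optional (not in the original script): report the size of the loaded model.
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")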
########################################
# 8. TOKEN EXTRACTION
########################################
print("Extracting tokens...")
all_token_ids = np.zeros(len(dataset), dtype=np.int32)
#h_values = np.zeros((len(dataset), LATENT_DIM), dtype=np.float32)
with torch.no_grad():
    offset = 0
    for dna_batch, phy_batch, idx_batch in tqdm(loader):
        dna_batch = dna_batch.to(device).float()
        phy_batch = phy_batch.to(device).float()
        _, _, h = model(dna_batch, phy_batch)
        h_cpu = h.cpu().numpy()
        # Token = index of the most active latent unit (argmax over the sparse code)
        token_ids = np.argmax(h_cpu, axis=1)
        all_token_ids[offset : offset + len(token_ids)] = token_ids
        # h_values[offset : offset + len(token_ids)] = h_cpu
        offset += len(token_ids)
print("Token extraction complete.")
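# Optional consistency check (not in the original script): a full sequential pass
# should have produced exactly one token per window.
assert offset == len(all_token_ids), f"Wrote {offset} tokens for {len(all_token_ids)} windows"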
np.save("token_ids.npy", all_token_ids)
#np.save("latent_h.npy", h_values)
# Histogram
hist = np.bincount(all_token_ids, minlength=LATENT_DIM)
np.save("token_hist.npy", hist)
print("Top tokens (most frequent):")
top_tokens = np.argsort(hist)[::-1][:20]  # descending by count
for t in top_tokens:
    print(f"Token {t}: count={hist[t]}")
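# Optional (not in the original script): how much of chr1 the 20 most frequent
# tokens account for.
print(f"Top-20 token coverage: {hist[top_tokens].sum() / hist.sum():.1%} of windows")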
########################################
# 9. MOTIF SUMMARY FOR TOP TOKENS
########################################
print("\nBuilding PWM + average PhyloP for top tokens...")
# Initialize accumulators for ALL tokens
pwm_sum = {t: np.zeros((L, 4), dtype=np.float32) for t in range(LATENT_DIM)}
phy_sum = {t: np.zeros(L, dtype=np.float32) for t in range(LATENT_DIM)}
counts = {t: 0 for t in range(LATENT_DIM)}
print("Accumulating statistics (this may take 15-30 mins)...")
limit = len(all_token_ids) - L
for i in tqdm(range(limit)):
    t = all_token_ids[i]
    # Always accumulate: this window's one-hot DNA and phyloP go under its token
    window = combined_np[i : i + L]
    pwm_sum[t] += window[:, :4]
    phy_sum[t] += window[:, 4]
    counts[t] += 1
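# Optional vectorized alternative (a sketch, not the original approach; needs
# NumPy >= 1.20): np.add.at accumulates directly into dense per-token arrays,
# avoiding the per-position Python loop at the cost of one strided window view.
# from numpy.lib.stride_tricks import sliding_window_view
# win = sliding_window_view(combined_np, L, axis=0)[:limit]   # (limit, 5, L) view
# pwm_arr = np.zeros((LATENT_DIM, L, 4), dtype=np.float32)
# phy_arr = np.zeros((LATENT_DIM, L), dtype=np.float32)
# np.add.at(pwm_arr, all_token_ids[:limit], win[:, :4, :].transpose(0, 2, 1))
# np.add.at(phy_arr, all_token_ids[:limit], win[:, 4, :])
# count_arr = np.bincount(all_token_ids[:limit], minlength=LATENT_DIM)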
########################################
# 9A. Save per-token PWMs & phylo profiles
########################################
print("Saving profiles...")
for t in range(LATENT_DIM):
    if counts[t] == 0:
        continue
    pwm = pwm_sum[t] / counts[t]      # position weight matrix, shape (L, 4)
    avg_phy = phy_sum[t] / counts[t]  # average phyloP profile, shape (L,)
    np.save(f"token{t}_pwm.npy", pwm)
    np.save(f"token{t}_phy.npy", avg_phy)
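# Optional alternative (not in the original script): writing up to 2 * LATENT_DIM
# small .npy files can be slow on some filesystems; a single compressed archive is
# a drop-in substitute.
# np.savez_compressed(
#     "token_profiles.npz",
#     **{f"token{t}_pwm": pwm_sum[t] / counts[t] for t in range(LATENT_DIM) if counts[t] > 0},
#     **{f"token{t}_phy": phy_sum[t] / counts[t] for t in range(LATENT_DIM) if counts[t] > 0},
# )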
########################################
# 9B. Rank tokens by PhyloP and rarity
########################################
avg_phylop_per_token = np.zeros(LATENT_DIM)
count_per_token = np.zeros(LATENT_DIM)
for t in range(LATENT_DIM):
    if counts[t] > 0:
        avg_phylop_per_token[t] = (phy_sum[t] / counts[t]).mean()
        count_per_token[t] = counts[t]
    else:
        avg_phylop_per_token[t] = -999  # sentinel for tokens that never fire
        count_per_token[t] = 0
# Rank by PhyloP (high to low)
tokens_by_phylop = np.argsort(avg_phylop_per_token)[::-1]
top_phy_tokens = tokens_by_phylop[:20]
# Rank by rarity (low to high), ignoring tokens that never fire
used = np.where(count_per_token > 0)[0]
rare_tokens = used[np.argsort(count_per_token[used])][:20]
print("Top 20 conserved tokens:", top_phy_tokens)
print("Top 20 rarest tokens:", rare_tokens)
print("\n=== Extraction Completed Successfully ===")