import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# =====================
# 1. SETUP & DATA LOADING
# =====================
print("Loading sequence / phyloP data...")

dna_path = "/home/n5huang/dna_token/SparseAE/chr1_dna.txt"
phy_path = "/home/n5huang/dna_token/SparseAE/chr1_phyloP_norm.npy"

with open(dna_path) as f:
    sequence = f.read().strip()

phy_norm = np.load(phy_path)

assert len(sequence) == len(phy_norm), "DNA and phyloP length mismatch!"
chrom_len = len(sequence)
print(f"Chromosome 1 length: {chrom_len:,} bp")

# =====================
# 2. DNA ENCODING (HANDLE 'N')
# =====================
print("Encoding DNA to one-hot (with N handling)...")

mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}

# Map bases to ints, using 4 as "N/unknown"
dna_int = np.fromiter((mapping.get(b, 4) for b in sequence), dtype=np.int8)
num_N = np.sum(dna_int == 4)
print(f"Number of N bases: {num_N:,}")

# One-hot with an extra column for N: 0=A, 1=C, 2=G, 3=T, 4=N -> [0,0,0,0,1]
temp_onehot = np.eye(5, dtype=np.float32)[dna_int]
# Slice to the first 4 columns so N becomes the all-zero vector [0,0,0,0]
dna_onehot = temp_onehot[:, :4]  # shape (chrom_len, 4)

# =====================
# 3. PHYLOP CHECK + COMBINE
# =====================
print("Preparing combined tensor...")

# Assume phy_norm is already in [-1, 1]; warn if not.
max_abs_phy = np.max(np.abs(phy_norm))
if max_abs_phy > 1.1:
    print(f"WARNING: phy_norm max abs={max_abs_phy:.3f} > 1.1; "
          f"data may not be normalized as expected.")

phy_norm = phy_norm.astype(np.float32)
phy_col = phy_norm.reshape(-1, 1)  # (chrom_len, 1)

combined_np = np.concatenate([dna_onehot, phy_col], axis=1)  # (chrom_len, 5)
combined_tensor = torch.from_numpy(combined_np)  # CPU tensor
print(f"Master tensor shape: {combined_tensor.shape}")

# =====================
# 4. DATASET: CHUNKED WINDOWING
# =====================
L = 50

class ChunkedChr1Dataset(Dataset):
    def __init__(self, combined, L=50):
        self.combined = combined
        self.L = L
        self.N = combined.shape[0] - L  # number of valid start positions

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        # window: (L, 5)
        window = self.combined[idx : idx + self.L]
        dna = window[:, :4]  # (L, 4)
        phy = window[:, 4]   # (L,)
        return dna, phy, idx

dataset = ChunkedChr1Dataset(combined_tensor, L=L)
print(f"Dataset length (#windows): {len(dataset):,}")

# =====================
# 5. DATALOADER (SEQUENTIAL)
# =====================
BATCH_SIZE = 1024

# For extraction the loader must cover every window exactly once. A random
# sampler with replacement (a training-time leftover) would both subsample
# and scramble the order, breaking the mapping back to genome coordinates
# that the downstream accumulation relies on.
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,    # <--- MUST BE FALSE for mapping back to genome
    drop_last=False,  # <--- Process every last window
    num_workers=0,    # safer on a large in-memory dataset
    pin_memory=True
)
print("DataLoader ready.")
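# ---------------------------------------------------------------------------
# Optional sanity check -- a minimal sketch, not part of the original
# pipeline. It re-encodes one window straight from `sequence` and compares it
# with what the dataset returns, confirming that window index == genome start
# coordinate before committing to the full extraction pass.
# ---------------------------------------------------------------------------
def check_window(idx):
    dna, phy, i = dataset[idx]
    assert i == idx and dna.shape == (L, 4) and phy.shape == (L,)
    dna_np = dna.numpy()
    # All-zero one-hot rows are N (or any unmapped base); the rest decode
    # via argmax back to the same 0-3 integer code used during encoding.
    got = np.where(dna_np.sum(axis=1) == 0, 4,
                   dna_np.argmax(axis=1)).astype(np.int8)
    expected = np.fromiter((mapping.get(b, 4) for b in sequence[idx : idx + L]),
                           dtype=np.int8)
    assert np.array_equal(got, expected), f"window {idx} does not round-trip"

check_window(0)
check_window(len(dataset) - 1)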
# =====================
# 6. MODEL: SPARSE AUTOENCODER
# =====================
INPUT_DIM = L * 5  # 4 DNA + 1 phyloP
LATENT_DIM = 2048
HIDDEN_DIM = 1024

class SparseAE(nn.Module):
    def __init__(self, input_dim=INPUT_DIM, latent_dim=LATENT_DIM,
                 hidden_dim=HIDDEN_DIM):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
            nn.ReLU()  # ReLU helps sparsity with L1
        )
        # Shared decoder hidden layer
        self.dec_hidden = nn.Linear(latent_dim, hidden_dim)
        # Decoder heads
        self.dec_dna = nn.Linear(hidden_dim, L * 4)
        self.dec_phy = nn.Linear(hidden_dim, L * 1)

    def forward(self, dna, phy):
        B = dna.size(0)
        x = torch.cat(
            [dna.reshape(B, -1), phy.reshape(B, -1)], dim=1
        )  # (B, INPUT_DIM)
        h = self.encoder(x)
        dec = F.relu(self.dec_hidden(h))
        recon_dna = self.dec_dna(dec).reshape(B, L, 4)           # (B, L, 4)
        recon_phy = torch.tanh(self.dec_phy(dec)).reshape(B, L)  # (B, L)
        return recon_dna, recon_phy, h

########################################
# 7. LOAD CHECKPOINT
########################################
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

model = SparseAE().to(device)
model.load_state_dict(torch.load("sparse_ae_50bp_epoch3.pt",
                                 map_location=device))
model.eval()
print("Model loaded.")

########################################
# 8. TOKEN EXTRACTION
########################################
print("Extracting tokens...")

all_token_ids = np.zeros(len(dataset), dtype=np.int32)
# h_values = np.zeros((len(dataset), LATENT_DIM), dtype=np.float32)

with torch.no_grad():
    for dna_batch, phy_batch, idx_batch in tqdm(loader):
        dna_batch = dna_batch.to(device).float()
        phy_batch = phy_batch.to(device).float()

        _, _, h = model(dna_batch, phy_batch)
        h_cpu = h.cpu().numpy()

        # Token = argmax latent unit. Write by window index so the mapping
        # back to genome coordinates is correct regardless of batch order.
        token_ids = np.argmax(h_cpu, axis=1)
        all_token_ids[idx_batch.numpy()] = token_ids
        # h_values[idx_batch.numpy()] = h_cpu

print("Token extraction complete.")
np.save("token_ids.npy", all_token_ids)
# np.save("latent_h.npy", h_values)

# Histogram
hist = np.bincount(all_token_ids, minlength=LATENT_DIM)
np.save("token_hist.npy", hist)

print("Top tokens:")
top_tokens = np.argsort(hist)[::-1][:20]  # descending: most frequent first
for t in top_tokens:
    print(f"Token {t}: count={hist[t]}")

########################################
# 9. MOTIF SUMMARY PER TOKEN
########################################
print("\nBuilding PWM + average PhyloP per token...")

# Initialize accumulators for ALL tokens
pwm_sum = {t: np.zeros((L, 4), dtype=np.float32) for t in range(LATENT_DIM)}
phy_sum = {t: np.zeros(L, dtype=np.float32) for t in range(LATENT_DIM)}
counts = {t: 0 for t in range(LATENT_DIM)}

print("Accumulating statistics (this may take 15-30 mins)...")
# Every entry in all_token_ids is already a valid window start (the dataset
# excludes the last L positions), so no extra trimming is needed here.
for i in tqdm(range(len(all_token_ids))):
    t = all_token_ids[i]
    window = combined_np[i : i + L]
    pwm_sum[t] += window[:, :4]
    phy_sum[t] += window[:, 4]
    counts[t] += 1

########################################
# 9A. Save per-token PWMs & PhyloP profiles
########################################
print("Saving profiles...")
for t in range(LATENT_DIM):
    if counts[t] == 0:
        continue
    pwm = pwm_sum[t] / counts[t]
    avg_phy = phy_sum[t] / counts[t]
    np.save(f"token{t}_pwm.npy", pwm)
    np.save(f"token{t}_phy.npy", avg_phy)
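# ---------------------------------------------------------------------------
# Optional helper -- a sketch, not part of the original script. It turns a
# saved per-token PWM into a consensus string plus per-position information
# content, a quick way to eyeball what a token responds to. Rows are
# renormalized because windows containing N contribute all-zero one-hots,
# so raw PWM rows need not sum to exactly 1.
# ---------------------------------------------------------------------------
def summarize_pwm(token_id):
    pwm = np.load(f"token{token_id}_pwm.npy")  # (L, 4)
    probs = pwm / np.clip(pwm.sum(axis=1, keepdims=True), 1e-8, None)
    consensus = "".join("ACGT"[j] for j in probs.argmax(axis=1))
    # Information content per position: 2 - H(p) bits (uniform background).
    entropy = -np.sum(probs * np.log2(np.clip(probs, 1e-8, 1.0)), axis=1)
    ic = 2.0 - entropy
    return consensus, ic

# Example: consensus for the 5 most frequent tokens.
for t in np.argsort(hist)[::-1][:5]:
    cons, ic = summarize_pwm(t)
    print(f"Token {t}: consensus={cons}  mean IC={ic.mean():.2f} bits")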
########################################
# 9B. Rank tokens by PhyloP and rarity
########################################
avg_phylop_per_token = np.zeros(LATENT_DIM)
count_per_token = np.zeros(LATENT_DIM)

for t in range(LATENT_DIM):
    if counts[t] > 0:
        avg_phylop_per_token[t] = (phy_sum[t] / counts[t]).mean()
        count_per_token[t] = counts[t]
    else:
        avg_phylop_per_token[t] = -999  # sentinel: token never fired
        count_per_token[t] = 0

# Rank by PhyloP (high to low)
tokens_by_phylop = np.argsort(avg_phylop_per_token)[::-1]
top_phy_tokens = tokens_by_phylop[:20]

# Rank by rarity (low to high), ignoring tokens that never fired at all
used = count_per_token > 0
rare_tokens = np.argsort(np.where(used, count_per_token, np.inf))[:20]

print("Top 20 conserved tokens:", top_phy_tokens)
print("Top 20 rarest tokens:", rare_tokens)

print("\n=== Extraction Completed Successfully ===")
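# ---------------------------------------------------------------------------
# Optional follow-up -- a sketch under an assumed threshold, not part of the
# original analysis. A high mean PhyloP backed by only a handful of windows
# is easy to get by chance, so this reranks conserved tokens after requiring
# a minimum number of supporting windows. MIN_COUNT is a hypothetical value;
# tune it against the token histogram.
# ---------------------------------------------------------------------------
MIN_COUNT = 1000  # assumed support threshold
supported = np.where(count_per_token >= MIN_COUNT)[0]
supported = supported[np.argsort(avg_phylop_per_token[supported])[::-1]]

print(f"\nConserved tokens with >= {MIN_COUNT:,} supporting windows:")
for t in supported[:10]:
    print(f"Token {t}: mean phyloP={avg_phylop_per_token[t]:.3f}, "
          f"count={int(count_per_token[t]):,}")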