import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, RandomSampler
from tqdm import tqdm


print("Loading sequence / phyloP data...")
|
|
dna_path = "/home/n5huang/dna_token/SparseAE/chr1_dna.txt"
phy_path = "/home/n5huang/dna_token/SparseAE/chr1_phyloP_norm.npy"

with open(dna_path) as f:
    sequence = f.read().strip()

phy_norm = np.load(phy_path)

assert len(sequence) == len(phy_norm), "DNA and phyloP length mismatch!"
chrom_len = len(sequence)
print(f"Chromosome 1 length: {chrom_len:,} bp")
|
|
print("Encoding DNA to one-hot (with N handling)...")

# A/C/G/T map to 0-3; any other character (e.g. 'N') falls back to index 4
mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}

dna_int = np.fromiter((mapping.get(b, 4) for b in sequence), dtype=np.int8)
num_N = np.sum(dna_int == 4)
print(f"Number of N bases: {num_N:,}")
|
|
# 5-row identity: rows 0-3 are one-hot A/C/G/T, row 4 is the placeholder for N
temp_onehot = np.eye(5, dtype=np.float32)[dna_int]

# Drop the N column so that N bases become all-zero rows in the 4-channel encoding
dna_onehot = temp_onehot[:, :4]
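
# Optional sanity check: after slicing, an 'N' position should be an all-zero
# row while A/C/G/T positions stay one-hot.
if num_N > 0:
    first_N = int(np.argmax(dna_int == 4))
    assert not dna_onehot[first_N].any(), "N base should encode to an all-zero row"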
|
|
print("Preparing combined tensor...")

max_abs_phy = np.max(np.abs(phy_norm))
if max_abs_phy > 1.1:
    print(f"WARNING: phy_norm max abs={max_abs_phy:.3f} > 1.1; "
          f"data may not be normalized as expected.")

phy_norm = phy_norm.astype(np.float32)
phy_col = phy_norm.reshape(-1, 1)

# Columns 0-3: one-hot DNA; column 4: normalized phyloP score
combined_np = np.concatenate([dna_onehot, phy_col], axis=1)
combined_tensor = torch.from_numpy(combined_np)

print(f"Master tensor shape: {combined_tensor.shape}")
|
|
L = 50  # window length in bp

class ChunkedChr1Dataset(Dataset):
    def __init__(self, combined, L=50):
        self.combined = combined
        self.L = L
        self.N = combined.shape[0] - L

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        # Slice one L-bp window (a view; nothing is copied until collation)
        window = self.combined[idx : idx + self.L]
        dna = window[:, :4]
        phy = window[:, 4]
        return dna, phy, idx


dataset = ChunkedChr1Dataset(combined_tensor, L=L)
print(f"Dataset length (#windows): {len(dataset):,}")
|
|
BATCH_SIZE = 1024
SAMPLES_PER_EPOCH = 5_000_000

# Draw window start positions uniformly at random (with replacement) instead of
# scanning every overlapping window on the chromosome
sampler = RandomSampler(
    dataset,
    replacement=True,
    num_samples=SAMPLES_PER_EPOCH
)

loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    drop_last=False,
    num_workers=0,
    pin_memory=True
)

print("DataLoader ready.")
|
|
INPUT_DIM = L * 5      # 4 one-hot channels + 1 phyloP channel per position
LATENT_DIM = 2048
HIDDEN_DIM = 1024

class SparseAE(nn.Module):
    def __init__(self, input_dim=INPUT_DIM, latent_dim=LATENT_DIM, hidden_dim=HIDDEN_DIM):
        super().__init__()

        # Encoder: flattened window -> hidden -> wide, non-negative latent code
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
            nn.ReLU()
        )

        # Shared decoder hidden layer
        self.dec_hidden = nn.Linear(latent_dim, hidden_dim)

        # Separate heads for DNA reconstruction (L x 4) and phyloP reconstruction (L)
        self.dec_dna = nn.Linear(hidden_dim, L * 4)
        self.dec_phy = nn.Linear(hidden_dim, L * 1)

    def forward(self, dna, phy):
        B = dna.size(0)

        # Flatten and concatenate the two modalities into one input vector
        x = torch.cat(
            [dna.reshape(B, -1), phy.reshape(B, -1)],
            dim=1
        )

        h = self.encoder(x)
        dec = F.relu(self.dec_hidden(h))

        recon_dna = self.dec_dna(dec).reshape(B, L, 4)
        # tanh keeps the phyloP reconstruction in [-1, 1], matching the normalized track
        recon_phy = torch.tanh(self.dec_phy(dec)).reshape(B, L)

        return recon_dna, recon_phy, h
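
# Optional smoke test: count parameters and run a tiny forward pass on zeros to
# confirm output shapes before loading the trained weights.
_probe = SparseAE()
print(f"SparseAE parameters: {sum(p.numel() for p in _probe.parameters()):,}")
with torch.no_grad():
    _rd, _rp, _rh = _probe(torch.zeros(2, L, 4), torch.zeros(2, L))
print(f"Smoke test: recon_dna {tuple(_rd.shape)}, recon_phy {tuple(_rp.shape)}, h {tuple(_rh.shape)}")
del _probe, _rd, _rp, _rh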
|
|
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

model = SparseAE().to(device)
model.load_state_dict(torch.load("sparse_ae_50bp_epoch3.pt", map_location=device))
model.eval()
print("Model loaded.")
|
|
|
|
print("Extracting tokens...")

all_token_ids = np.zeros(len(dataset), dtype=np.int32)
# Track which window start positions the random sampler actually draws
sampled_mask = np.zeros(len(dataset), dtype=bool)

with torch.no_grad():
    for dna_batch, phy_batch, idx_batch in tqdm(loader):
        dna_batch = dna_batch.to(device).float()
        phy_batch = phy_batch.to(device).float()

        _, _, h = model(dna_batch, phy_batch)
        h_cpu = h.cpu().numpy()

        # Token ID = index of the most active latent unit for each window
        token_ids = np.argmax(h_cpu, axis=1)

        # Store each token at its genomic start position (idx) so that
        # per-position lookups downstream line up with the chromosome
        idx_np = idx_batch.numpy()
        all_token_ids[idx_np] = token_ids
        sampled_mask[idx_np] = True

print("Token extraction complete.")
|
|
| np.save("token_ids.npy", all_token_ids) |
| |
|
|
|
|
|
|
| |
| hist = np.bincount(all_token_ids, minlength=LATENT_DIM) |
| np.save("token_hist.npy", hist) |
|
|
| print("Top tokens:") |
| |
| top_tokens = np.argsort(hist)[:20] |
|
|
| for t in top_tokens: |
| print(f"Token {t}: count={hist[t]}") |
|
|
|
|
print("\nBuilding PWM + average PhyloP for top tokens...")

pwm_sum = {t: np.zeros((L, 4), dtype=np.float32) for t in range(LATENT_DIM)}
phy_sum = {t: np.zeros(L, dtype=np.float32) for t in range(LATENT_DIM)}
counts = {t: 0 for t in range(LATENT_DIM)}

print("Accumulating statistics over sampled windows...")

sampled_positions = np.flatnonzero(sampled_mask)

for i in tqdm(sampled_positions):
    t = all_token_ids[i]

    # 50 bp window starting at position i: 4 one-hot columns + 1 phyloP column
    window = combined_np[i : i + L]

    pwm_sum[t] += window[:, :4]
    phy_sum[t] += window[:, 4]
    counts[t] += 1
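
# Optional alternative (a sketch, not used above): the same accumulation can be
# done with NumPy fancy indexing and np.add.at over chunks of sampled positions,
# avoiding the per-window Python loop. `accumulate_vectorized` is a hypothetical
# helper; it assumes the arrays defined above (combined_np, all_token_ids,
# sampled_positions).
def accumulate_vectorized(positions, token_ids, data, L, n_tokens, chunk=100_000):
    pwm = np.zeros((n_tokens, L, 4), dtype=np.float64)
    phy = np.zeros((n_tokens, L), dtype=np.float64)
    cnt = np.zeros(n_tokens, dtype=np.int64)
    offsets = np.arange(L)
    for start in range(0, len(positions), chunk):
        pos = positions[start:start + chunk]
        toks = token_ids[pos]
        win = data[pos[:, None] + offsets[None, :]]   # gather (chunk, L, 5) windows
        np.add.at(pwm, toks, win[:, :, :4])           # unbuffered scatter-add per token
        np.add.at(phy, toks, win[:, :, 4])
        np.add.at(cnt, toks, 1)
    return pwm, phy, cnt

# Example usage (equivalent to the loop above):
# pwm_arr, phy_arr, cnt_arr = accumulate_vectorized(
#     sampled_positions, all_token_ids, combined_np, L, LATENT_DIM)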
|
|
|
|
|
|
print("Saving profiles...")

for t in range(LATENT_DIM):
    if counts[t] == 0:
        continue

    # Normalize sums into a position weight matrix and a mean phyloP profile
    pwm = pwm_sum[t] / counts[t]
    avg_phy = phy_sum[t] / counts[t]

    np.save(f"token{t}_pwm.npy", pwm)
    np.save(f"token{t}_phy.npy", avg_phy)
|
|
|
|
avg_phylop_per_token = np.zeros(LATENT_DIM)
count_per_token = np.zeros(LATENT_DIM)

for t in range(LATENT_DIM):
    if counts[t] > 0:
        avg_phylop_per_token[t] = (phy_sum[t] / counts[t]).mean()
        count_per_token[t] = counts[t]
    else:
        avg_phylop_per_token[t] = -999  # sentinel: token never observed
        count_per_token[t] = 0

# Most conserved tokens (unused tokens sink to the bottom via the sentinel)
tokens_by_phylop = np.argsort(avg_phylop_per_token)[::-1]
top_phy_tokens = tokens_by_phylop[:20]

# Rarest tokens among those that actually occur
used = np.flatnonzero(count_per_token > 0)
rare_tokens = used[np.argsort(count_per_token[used])[:20]]

print("Top 20 conserved tokens:", top_phy_tokens)
print("Top 20 rarest tokens:", rare_tokens)
|
|
|
|
|
|
| print("\n=== Extraction Completed Successfully ===") |
|
|
|
|