import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, RandomSampler


# ----- Load chromosome 1 sequence and per-base phyloP scores -----
print("Loading sequence / phyloP data...")

dna_path = "/home/n5huang/dna_token/SparseAE/chr1_dna.txt"
phy_path = "/home/n5huang/dna_token/SparseAE/chr1_phyloP_norm.npy"

# Upper-case so any repeat-masked lowercase bases map like their
# uppercase counterparts instead of falling through to N.
with open(dna_path) as f:
    sequence = f.read().strip().upper()

phy_norm = np.load(phy_path)

# Every base must have exactly one phyloP value.
assert len(sequence) == len(phy_norm), "DNA and phyloP length mismatch!"
chrom_len = len(sequence)
print(f"Chromosome 1 length: {chrom_len:,} bp")


# ----- Encode DNA as one-hot, mapping N and other symbols to index 4 -----
print("Encoding DNA to one-hot (with N handling)...")

mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}

# Any character not in the mapping (N, IUPAC ambiguity codes) gets index 4.
dna_int = np.fromiter((mapping.get(b, 4) for b in sequence),
                      dtype=np.int8, count=chrom_len)
num_N = np.sum(dna_int == 4)
print(f"Number of N bases: {num_N:,}")

# One-hot over 5 symbols, then drop the N column: N positions become
# all-zero rows, which the training loop masks out later.
temp_onehot = np.eye(5, dtype=np.float32)[dna_int]
dna_onehot = temp_onehot[:, :4]
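
# Sanity check on the encoding above: every N position should now be an
# all-zero row once the fifth column is dropped.
if num_N > 0:
    assert not dna_onehot[dna_int == 4].any(), "N rows should be all-zero"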


# ----- Stack one-hot DNA and phyloP into one (chrom_len, 5) tensor -----
print("Preparing combined tensor...")

# phyloP is expected to be pre-normalized to roughly [-1, 1]; the decoder
# squashes its phyloP head with tanh to match that range.
max_abs_phy = np.max(np.abs(phy_norm))
if max_abs_phy > 1.1:
    print(f"WARNING: phy_norm max abs={max_abs_phy:.3f} > 1.1; "
          f"data may not be normalized as expected.")

phy_norm = phy_norm.astype(np.float32)
phy_col = phy_norm.reshape(-1, 1)

# Columns 0-3: one-hot base; column 4: phyloP score.
combined_np = np.concatenate([dna_onehot, phy_col], axis=1)
combined_tensor = torch.from_numpy(combined_np)

print(f"Master tensor shape: {combined_tensor.shape}")


# ----- Dataset of overlapping L-bp windows over the combined tensor -----
L = 50

class ChunkedChr1Dataset(Dataset):
    """Index i yields the window [i, i + L) split into (dna, phy)."""

    def __init__(self, combined, L=50):
        self.combined = combined
        self.L = L
        # +1 so the final window, ending exactly at the chromosome end,
        # is included.
        self.N = combined.shape[0] - L + 1

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        window = self.combined[idx : idx + self.L]
        dna = window[:, :4]   # (L, 4) one-hot bases
        phy = window[:, 4]    # (L,) phyloP scores
        return dna, phy

dataset = ChunkedChr1Dataset(combined_tensor, L=L)
print(f"Dataset length (#windows): {len(dataset):,}")


# ----- Sampler / DataLoader: an "epoch" is a fixed random sample budget -----
BATCH_SIZE = 1024
SAMPLES_PER_EPOCH = 5_000_000

# Sampling with replacement draws SAMPLES_PER_EPOCH windows uniformly at
# random per epoch instead of iterating all ~250M overlapping windows.
sampler = RandomSampler(
    dataset,
    replacement=True,
    num_samples=SAMPLES_PER_EPOCH
)

loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    drop_last=True,   # keep every batch exactly BATCH_SIZE
    num_workers=0,
    pin_memory=True
)

print("DataLoader ready.")


# ----- Sparse autoencoder over flattened (L x 5) windows -----
INPUT_DIM = L * 5
LATENT_DIM = 2048
HIDDEN_DIM = 1024

class SparseAE(nn.Module):
    def __init__(self, input_dim=INPUT_DIM, latent_dim=LATENT_DIM,
                 hidden_dim=HIDDEN_DIM, seq_len=L):
        super().__init__()
        self.seq_len = seq_len

        # Encoder: (B, L*5) -> (B, latent_dim), with a ReLU latent code
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
            nn.ReLU()
        )

        # Shared decoder trunk, then a separate head per modality
        self.dec_hidden = nn.Linear(latent_dim, hidden_dim)
        self.dec_dna = nn.Linear(hidden_dim, seq_len * 4)   # base logits
        self.dec_phy = nn.Linear(hidden_dim, seq_len * 1)   # phyloP values

    def forward(self, dna, phy):
        B = dna.size(0)

        # Flatten both modalities and concatenate into one input vector
        x = torch.cat(
            [dna.reshape(B, -1), phy.reshape(B, -1)],
            dim=1
        )

        h = self.encoder(x)
        dec = F.relu(self.dec_hidden(h))

        # DNA head returns raw logits (cross-entropy is applied in the
        # training loop); phyloP head uses tanh to match [-1, 1] targets.
        recon_dna = self.dec_dna(dec).reshape(B, self.seq_len, 4)
        recon_phy = torch.tanh(self.dec_phy(dec)).reshape(B, self.seq_len)

        return recon_dna, recon_phy, h
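
# A minimal shape check on random inputs (a throwaway CPU model; the random
# tensors here are placeholders, not real one-hot data).
with torch.no_grad():
    _m = SparseAE()
    _d, _p, _h = _m(torch.rand(2, L, 4), torch.rand(2, L))
    assert _d.shape == (2, L, 4) and _p.shape == (2, L)
    assert _h.shape == (2, LATENT_DIM)
del _m, _d, _p, _h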


# ----- Model, optimizer, and training hyperparameters -----
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on device: {device}")

model = SparseAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# The L1 sparsity weight anneals from 0.02 down to 0.005 across epochs;
# the KL sparsity weight warms up per epoch via beta_kl_schedule.
lambda_l1_start = 0.02
lambda_l1_end = 0.005
phy_weight = 10.0     # phyloP MSE weighted 10x relative to DNA CE
num_epochs = 5
PRINT_EVERY = 1000
beta_kl_schedule = [0.0, 0.01, 0.02, 0.05, 0.1]


# ----- Training loop -----
print("Starting training...")

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    total_dna = 0.0
    total_phy = 0.0
    total_active = 0.0
    total_samples = 0  # samples actually seen (drop_last trims the tail)

    batch_count = 0
    for dna_batch, phy_batch in loader:
        batch_count += 1

        dna_batch = dna_batch.to(device, non_blocking=True).float()
        phy_batch = phy_batch.to(device, non_blocking=True).float()

        optimizer.zero_grad()

        recon_dna, recon_phy, h = model(dna_batch, phy_batch)

        # Positions whose one-hot row sums to zero are N bases; exclude
        # them from both reconstruction losses.
        mask = (dna_batch.sum(dim=-1) > 0).float()   # (B, L)
        mask_total = mask.sum()

        # DNA loss: per-position cross-entropy over the four base classes.
        # F.cross_entropy wants (B, C, L) logits and (B, L) class targets.
        true_dna_cls = dna_batch.argmax(dim=-1)
        dna_logits = recon_dna.permute(0, 2, 1)
        loss_dna_raw = F.cross_entropy(dna_logits, true_dna_cls, reduction='none')

        if mask_total > 0:
            loss_dna = (loss_dna_raw * mask).sum() / mask_total
        else:
            loss_dna = torch.tensor(0.0, device=device)

        # phyloP loss: masked per-position MSE.
        loss_phy_raw = F.mse_loss(recon_phy, phy_batch, reduction='none')

        if mask_total > 0:
            loss_phy = (loss_phy_raw * mask).sum() / mask_total
        else:
            loss_phy = torch.tensor(0.0, device=device)

        # KL sparsity penalty: treat each unit's batch-mean activation as a
        # Bernoulli rate and pull it toward rho. h comes from a ReLU, so
        # rho_hat is clamped into (0, 1) before the Bernoulli KL applies.
        rho = 0.02
        eps = 1e-12
        rho_hat = torch.mean(h, dim=0)
        rho_hat = torch.clamp(rho_hat, min=1e-6, max=1 - 1e-6)
        kl_per_unit = (
            rho * torch.log((rho + eps) / (rho_hat + eps)) +
            (1 - rho) * torch.log(((1 - rho) + eps) / ((1 - rho_hat) + eps))
        )
        beta_kl = beta_kl_schedule[min(epoch, len(beta_kl_schedule) - 1)]
        loss_kl = beta_kl * kl_per_unit.sum()

        # Linearly anneal the L1 weight from lambda_l1_start to lambda_l1_end
        # (guarding against division by zero when num_epochs == 1).
        lambda_l1 = (
            lambda_l1_start
            + (lambda_l1_end - lambda_l1_start) * (epoch / max(num_epochs - 1, 1))
        )
        loss_l1 = lambda_l1 * torch.mean(torch.abs(h))

        # Total objective: DNA CE + weighted phyloP MSE + both sparsity terms.
        loss = loss_dna + phy_weight * loss_phy + loss_l1 + loss_kl

        loss.backward()
        optimizer.step()

        # Accumulate per-sample sums for epoch averages.
        B = dna_batch.size(0)
        total_loss += loss.item() * B
        total_dna += loss_dna.item() * B
        total_phy += loss_phy.item() * B
        total_samples += B

        # "Active" = latent units above a small threshold, batch-averaged.
        active_count = (h > 0.01).float().sum(dim=1).mean().item()
        total_active += active_count * B

        if batch_count % PRINT_EVERY == 0:
            print(
                f"Epoch {epoch+1} | Batch {batch_count} | "
                f"Loss={loss.item():.4f} | DNA_CE={loss_dna.item():.4f} | "
                f"Phy_MSE={loss_phy.item():.5f} | Active={active_count:.1f}"
            )

    # Average over the samples actually seen; with drop_last=True this is
    # slightly below SAMPLES_PER_EPOCH, so don't divide by the nominal count.
    N = total_samples
    avg_loss = total_loss / N
    avg_dna = total_dna / N
    avg_phy = total_phy / N
    avg_active = total_active / N

    print(f"\n=== Epoch {epoch+1}/{num_epochs} COMPLETE ===")
    print(
        f"Avg Loss={avg_loss:.4f} | Avg DNA_CE={avg_dna:.4f} | "
        f"Avg Phy_MSE={avg_phy:.5f} | "
        f"Avg Active Neurons={avg_active:.1f} / {LATENT_DIM} "
        f"({100.0 * avg_active / LATENT_DIM:.1f}%)"
    )

    # Save a checkpoint at the end of every epoch.
    ckpt_path = f"sparse_ae_50bp_epoch{epoch+1}.pt"
    torch.save(model.state_dict(), ckpt_path)
    print(f"Saved checkpoint to {ckpt_path}\n")