import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, RandomSampler

# =====================
# 1. SETUP & DATA LOADING
# =====================
print("Loading sequence / phyloP data...")
dna_path = "/home/n5huang/dna_token/SparseAE/chr1_dna.txt"
phy_path = "/home/n5huang/dna_token/SparseAE/chr1_phyloP_norm.npy"

with open(dna_path) as f:
    sequence = f.read().strip()
phy_norm = np.load(phy_path)

assert len(sequence) == len(phy_norm), "DNA and phyloP length mismatch!"
chrom_len = len(sequence)
print(f"Chromosome 1 length: {chrom_len:,} bp")

# =====================
# 2. DNA ENCODING (HANDLE 'N')
# =====================
print("Encoding DNA to one-hot (with N handling)...")
mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}

# Map bases to ints, using 4 as "N/unknown"
dna_int = np.fromiter((mapping.get(b, 4) for b in sequence), dtype=np.int8)
num_N = np.sum(dna_int == 4)
print(f"Number of N bases: {num_N:,}")

# One-hot with an extra row for N:
# 0=A, 1=C, 2=G, 3=T, 4=[0,0,0,0,1]
temp_onehot = np.eye(5, dtype=np.float32)[dna_int]
# Slice to first 4 columns: N -> [0,0,0,0]
dna_onehot = temp_onehot[:, :4]  # shape (chrom_len, 4)

# =====================
# 3. PHYLOP CHECK + COMBINE
# =====================
print("Preparing combined tensor...")
# Assume phy_norm is already in [-1, 1]; warn if not.
max_abs_phy = np.max(np.abs(phy_norm))
if max_abs_phy > 1.1:
    print(f"WARNING: phy_norm max abs={max_abs_phy:.3f} > 1.1; "
          f"data may not be normalized as expected.")

phy_norm = phy_norm.astype(np.float32)
phy_col = phy_norm.reshape(-1, 1)  # (chrom_len, 1)

combined_np = np.concatenate([dna_onehot, phy_col], axis=1)  # (chrom_len, 5)
combined_tensor = torch.from_numpy(combined_np)  # CPU tensor
print(f"Master tensor shape: {combined_tensor.shape}")

# =====================
# 4. DATASET: CHUNKED WINDOWING
# =====================
L = 50

class ChunkedChr1Dataset(Dataset):
    def __init__(self, combined, L=50):
        self.combined = combined
        self.L = L
        self.N = combined.shape[0] - L  # number of valid start positions

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        # window: (L, 5) slice of the master tensor
        window = self.combined[idx : idx + self.L]
        dna = window[:, :4]  # (L, 4)
        phy = window[:, 4]   # (L,)
        return dna, phy

dataset = ChunkedChr1Dataset(combined_tensor, L=L)
print(f"Dataset length (#windows): {len(dataset):,}")

# =====================
# 5. DATALOADER WITH RANDOM SAMPLER
# =====================
BATCH_SIZE = 1024
SAMPLES_PER_EPOCH = 5_000_000  # number of windows per epoch (tunable)

sampler = RandomSampler(
    dataset,
    replacement=True,
    num_samples=SAMPLES_PER_EPOCH
)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    drop_last=True,
    num_workers=0,   # safer on large in-memory dataset
    pin_memory=True
)
print("DataLoader ready.")
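# --- Optional sanity check (not part of the original pipeline): pull a single
# batch and confirm shapes/dtypes before building the model. A minimal sketch,
# assuming one call to next(iter(loader)) is cheap enough to run at startup.
_dna_chk, _phy_chk = next(iter(loader))
assert _dna_chk.shape == (BATCH_SIZE, L, 4), f"unexpected DNA batch shape: {_dna_chk.shape}"
assert _phy_chk.shape == (BATCH_SIZE, L), f"unexpected phyloP batch shape: {_phy_chk.shape}"
print(f"Sanity batch OK: dna={tuple(_dna_chk.shape)}, phy={tuple(_phy_chk.shape)}")
del _dna_chk, _phy_chk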
# =====================
# 6. MODEL: SPARSE AUTOENCODER
# =====================
INPUT_DIM = L * 5  # 4 DNA + 1 phyloP per position
LATENT_DIM = 2048
HIDDEN_DIM = 1024

class SparseAE(nn.Module):
    def __init__(self, input_dim=INPUT_DIM, latent_dim=LATENT_DIM,
                 hidden_dim=HIDDEN_DIM):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
            nn.ReLU()  # ReLU helps sparsity with L1
        )
        # Shared decoder layer
        self.dec_hidden = nn.Linear(latent_dim, hidden_dim)
        # Decoder heads
        self.dec_dna = nn.Linear(hidden_dim, L * 4)
        self.dec_phy = nn.Linear(hidden_dim, L * 1)

    def forward(self, dna, phy):
        B = dna.size(0)
        x = torch.cat(
            [dna.reshape(B, -1), phy.reshape(B, -1)],
            dim=1
        )  # (B, INPUT_DIM)
        h = self.encoder(x)
        dec = F.relu(self.dec_hidden(h))
        recon_dna = self.dec_dna(dec).reshape(B, L, 4)           # (B, L, 4) logits
        recon_phy = torch.tanh(self.dec_phy(dec)).reshape(B, L)  # (B, L) in [-1, 1]
        return recon_dna, recon_phy, h
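# --- Optional smoke test (not part of the original script): run random tensors
# through an untrained model to confirm output shapes before training starts.
# A minimal sketch; the dummy batch size of 2 is arbitrary.
_m = SparseAE()
_d = torch.rand(2, L, 4)
_p = torch.rand(2, L) * 2 - 1  # phyloP-like values in [-1, 1]
_rd, _rp, _h = _m(_d, _p)
assert _rd.shape == (2, L, 4) and _rp.shape == (2, L) and _h.shape == (2, LATENT_DIM)
del _m, _d, _p, _rd, _rp, _h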
# =====================
# 7. TRAINING LOOP
# =====================
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on device: {device}")

model = SparseAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# lambda_l1 = 0.01  # (previous fixed value: slightly stronger sparsity)
lambda_l1_start = 0.02  # L1 weight annealed linearly from start to end over epochs
lambda_l1_end = 0.005
phy_weight = 10.0       # weight on the phyloP reconstruction term
num_epochs = 5
PRINT_EVERY = 1000  # batches
beta_kl_schedule = [0.0, 0.01, 0.02, 0.05, 0.1]  # KL weight per epoch

print("Starting training...")
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    total_dna = 0.0
    total_phy = 0.0
    total_active = 0.0
    batch_count = 0

    for dna_batch, phy_batch in loader:
        batch_count += 1
        dna_batch = dna_batch.to(device, non_blocking=True).float()  # (B, L, 4)
        phy_batch = phy_batch.to(device, non_blocking=True).float()  # (B, L)

        optimizer.zero_grad()
        recon_dna, recon_phy, h = model(dna_batch, phy_batch)

        # Mask positions that are 'N' (all-zero one-hot)
        mask = dna_batch.sum(dim=-1) > 0  # (B, L), True where valid base

        # --- DNA loss (masked CE) ---
        true_dna_cls = dna_batch.argmax(dim=-1)  # (B, L); N positions masked below
        dna_logits = recon_dna.permute(0, 2, 1)  # (B, 4, L), as cross_entropy expects
        loss_dna_raw = F.cross_entropy(dna_logits, true_dna_cls, reduction='none')  # (B, L)
        if mask.sum() > 0:
            loss_dna = (loss_dna_raw * mask).sum() / mask.sum()
        else:
            loss_dna = torch.tensor(0.0, device=device)

        # --- PhyloP loss (masked MSE) ---
        loss_phy_raw = F.mse_loss(recon_phy, phy_batch, reduction='none')  # (B, L)
        if mask.sum() > 0:
            loss_phy = (loss_phy_raw * mask).sum() / mask.sum()
        else:
            loss_phy = torch.tensor(0.0, device=device)

        # --- KL sparsity penalty ---
        # h comes from a ReLU, so its batch mean is not guaranteed to lie in
        # (0, 1); the clamp forces it into a valid range for this
        # Bernoulli-style KL term.
        rho = 0.02  # target sparsity
        eps = 1e-12
        rho_hat = torch.mean(h, dim=0)
        rho_hat = torch.clamp(rho_hat, min=1e-6, max=1 - 1e-6)
        kl_per_unit = (
            rho * torch.log((rho + eps) / (rho_hat + eps))
            + (1 - rho) * torch.log(((1 - rho) + eps) / ((1 - rho_hat) + eps))
        )
        beta_kl = beta_kl_schedule[min(epoch, len(beta_kl_schedule) - 1)]
        # loss_kl = 1 * kl_per_unit.sum()  # (previous fixed β = 1 weight)
        loss_kl = beta_kl * kl_per_unit.sum()

        # --- L1 sparsity on latent (linearly annealed weight) ---
        lambda_l1 = (
            lambda_l1_start
            + (lambda_l1_end - lambda_l1_start) * (epoch / max(num_epochs - 1, 1))
        )
        loss_l1 = lambda_l1 * torch.mean(torch.abs(h))

        # Total loss
        loss = loss_dna + phy_weight * loss_phy + loss_l1 + loss_kl
        loss.backward()
        optimizer.step()

        # Logging accumulators
        B = dna_batch.size(0)
        total_loss += loss.item() * B
        total_dna += loss_dna.item() * B
        total_phy += loss_phy.item() * B
        # approximate number of active neurons (h > threshold)
        active_count = (h > 0.01).float().sum(dim=1).mean().item()
        total_active += active_count * B

        if batch_count % PRINT_EVERY == 0:
            print(
                f"Epoch {epoch+1} | Batch {batch_count} | "
                f"Loss={loss.item():.4f} | DNA_CE={loss_dna.item():.4f} | "
                f"Phy_MSE={loss_phy.item():.5f} | Active={active_count:.1f}"
            )

    # Epoch summary
    # drop_last=True discards the final partial batch, so the effective sample
    # count is batch_count * BATCH_SIZE rather than SAMPLES_PER_EPOCH.
    N = batch_count * BATCH_SIZE
    avg_loss = total_loss / N
    avg_dna = total_dna / N
    avg_phy = total_phy / N
    avg_active = total_active / N
    print(f"\n=== Epoch {epoch+1}/{num_epochs} COMPLETE ===")
    print(
        f"Avg Loss={avg_loss:.4f} | Avg DNA_CE={avg_dna:.4f} | "
        f"Avg Phy_MSE={avg_phy:.5f} | "
        f"Avg Active Neurons={avg_active:.1f} / {LATENT_DIM} "
        f"({100.0 * avg_active / LATENT_DIM:.1f}%)"
    )

    # Save checkpoint
    ckpt_path = f"sparse_ae_50bp_epoch{epoch+1}.pt"
    torch.save(model.state_dict(), ckpt_path)
    print(f"Saved checkpoint to {ckpt_path}\n")
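# =====================
# 8. (OPTIONAL) LATENT EXTRACTION SKETCH
# =====================
# A hedged example of using the trained encoder downstream: reload the final
# checkpoint and encode one batch of windows into sparse latent codes. This
# section is not part of the original script, and the checkpoint name assumes
# the loop above completed all num_epochs epochs.
model = SparseAE().to(device)
model.load_state_dict(
    torch.load(f"sparse_ae_50bp_epoch{num_epochs}.pt", map_location=device)
)
model.eval()
with torch.no_grad():
    dna_b, phy_b = next(iter(loader))
    _, _, latents = model(dna_b.to(device).float(), phy_b.to(device).float())
print(f"Latent codes: {tuple(latents.shape)}")  # (BATCH_SIZE, LATENT_DIM)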