# SparseAE/run_train.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, RandomSampler
# =====================
# 1. SETUP & DATA LOADING
# =====================
print("Loading sequence / phyloP data...")
dna_path = "/home/n5huang/dna_token/SparseAE/chr1_dna.txt"
phy_path = "/home/n5huang/dna_token/SparseAE/chr1_phyloP_norm.npy"
with open(dna_path) as f:
sequence = f.read().strip()
phy_norm = np.load(phy_path)
assert len(sequence) == len(phy_norm), "DNA and phyloP length mismatch!"
chrom_len = len(sequence)
print(f"Chromosome 1 length: {chrom_len:,} bp")
# =====================
# 2. DNA ENCODING (HANDLE 'N')
# =====================
print("Encoding DNA to one-hot (with N handling)...")
mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
# Map bases to ints, using 4 as "N/unknown"
dna_int = np.fromiter((mapping.get(b, 4) for b in sequence), dtype=np.int8, count=len(sequence))
num_N = np.sum(dna_int == 4)
print(f"Number of N bases: {num_N:,}")
# One-hot with an extra row for N
# 0=A,1=C,2=G,3=T,4=[0,0,0,0,1]
temp_onehot = np.eye(5, dtype=np.float32)[dna_int]
# Slice to first 4 columns: N -> [0,0,0,0]
dna_onehot = temp_onehot[:, :4] # shape (chrom_len, 4)
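# Optional sanity check (a minimal sketch): on a random subsample, verify that
# 'N' positions became all-zero rows and every valid base is exactly one-hot.
_idx = np.random.default_rng(0).integers(0, chrom_len, size=100_000)
_sub_int, _sub_hot = dna_int[_idx], dna_onehot[_idx]
assert not _sub_hot[_sub_int == 4].any(), "N rows should be all-zero"
assert np.all(_sub_hot[_sub_int < 4].sum(axis=1) == 1), "bases should be one-hot"
del _idx, _sub_int, _sub_hot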
# =====================
# 3. PHYLOP CHECK + COMBINE
# =====================
print("Preparing combined tensor...")
# Assume phy_norm is already in [-1,1]; warn if not.
max_abs_phy = np.max(np.abs(phy_norm))
if max_abs_phy > 1.1:
print(f"WARNING: phy_norm max abs={max_abs_phy:.3f} > 1.1; "
f"data may not be normalized as expected.")
phy_norm = phy_norm.astype(np.float32)
phy_col = phy_norm.reshape(-1, 1) # (chrom_len, 1)
combined_np = np.concatenate([dna_onehot, phy_col], axis=1) # (chrom_len, 5)
combined_tensor = torch.from_numpy(combined_np) # CPU tensor
print(f"Master tensor shape: {combined_tensor.shape}")
# =====================
# 4. DATASET: CHUNKED WINDOWING
# =====================
L = 50  # window length in bp
class ChunkedChr1Dataset(Dataset):
def __init__(self, combined, L=50):
self.combined = combined
self.L = L
        self.N = combined.shape[0] - L + 1  # number of valid start positions
def __len__(self):
return self.N
def __getitem__(self, idx):
# window: (L, 5)
window = self.combined[idx : idx + self.L]
dna = window[:, :4] # (L, 4)
phy = window[:, 4] # (L,)
return dna, phy
dataset = ChunkedChr1Dataset(combined_tensor, L=L)
print(f"Dataset length (#windows): {len(dataset):,}")
# =====================
# 5. DATALOADER WITH RANDOM SAMPLER
# =====================
BATCH_SIZE = 1024
SAMPLES_PER_EPOCH = 5_000_000 # number of windows per epoch (tunable)
sampler = RandomSampler(
dataset,
replacement=True,
num_samples=SAMPLES_PER_EPOCH
)
loader = DataLoader(
dataset,
batch_size=BATCH_SIZE,
sampler=sampler,
drop_last=True,
    num_workers=0,   # single-process loading; workers could duplicate the large in-memory tensor
pin_memory=True
)
print("DataLoader ready.")
# =====================
# 6. MODEL: SPARSE AUTOENCODER
# =====================
INPUT_DIM = L * 5  # 5 channels per position: 4 one-hot DNA + 1 phyloP
LATENT_DIM = 2048
HIDDEN_DIM = 1024
class SparseAE(nn.Module):
def __init__(self, input_dim=INPUT_DIM, latent_dim=LATENT_DIM, hidden_dim=HIDDEN_DIM):
super().__init__()
# Encoder
self.encoder = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, latent_dim),
nn.ReLU() # ReLU helps sparsity with L1
)
# Decoder shared
self.dec_hidden = nn.Linear(latent_dim, hidden_dim)
# Decoder heads
self.dec_dna = nn.Linear(hidden_dim, L * 4)
self.dec_phy = nn.Linear(hidden_dim, L * 1)
def forward(self, dna, phy):
B = dna.size(0)
x = torch.cat(
[dna.reshape(B, -1), phy.reshape(B, -1)],
dim=1
) # (B, INPUT_DIM)
h = self.encoder(x)
        dec = F.relu(self.dec_hidden(h))
        recon_dna = self.dec_dna(dec).reshape(B, L, 4)           # logits for CE loss, (B, L, 4)
        recon_phy = torch.tanh(self.dec_phy(dec)).reshape(B, L)  # tanh matches phyloP range [-1, 1], (B, L)
return recon_dna, recon_phy, h
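# Forward-pass smoke test (a sketch with dummy inputs): confirms the head
# shapes that the masked losses below rely on.
_m = SparseAE()
_rd, _rp, _h = _m(torch.zeros(2, L, 4), torch.zeros(2, L))
assert _rd.shape == (2, L, 4) and _rp.shape == (2, L) and _h.shape == (2, LATENT_DIM)
del _m, _rd, _rp, _h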
# =====================
# 7. TRAINING LOOP
# =====================
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Training on device: {device}")
model = SparseAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# lambda_l1 is annealed linearly per epoch (see training loop below)
lambda_l1_start = 0.02
lambda_l1_end = 0.005
phy_weight = 10.0
num_epochs = 5
PRINT_EVERY = 1000 # batches
beta_kl_schedule = [0.0, 0.01, 0.02, 0.05, 0.1] # per epoch
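# Preview of the per-epoch schedules (illustrative): lambda_l1 anneals
# linearly from lambda_l1_start to lambda_l1_end, while beta_kl warms up
# along beta_kl_schedule.
for _e in range(num_epochs):
    _lam = lambda_l1_start + (lambda_l1_end - lambda_l1_start) * (_e / max(num_epochs - 1, 1))
    _beta = beta_kl_schedule[min(_e, len(beta_kl_schedule) - 1)]
    print(f"  epoch {_e+1}: lambda_l1={_lam:.4f}, beta_kl={_beta}")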
print("Starting training...")
for epoch in range(num_epochs):
model.train()
total_loss = 0.0
total_dna = 0.0
total_phy = 0.0
total_active = 0.0
batch_count = 0
for dna_batch, phy_batch in loader:
batch_count += 1
dna_batch = dna_batch.to(device, non_blocking=True).float() # (B, L, 4)
phy_batch = phy_batch.to(device, non_blocking=True).float() # (B, L)
optimizer.zero_grad()
recon_dna, recon_phy, h = model(dna_batch, phy_batch)
# Mask positions that are 'N' (all-zero one-hot)
mask = dna_batch.sum(dim=-1) > 0 # (B, L), True where valid base
# --- DNA loss (masked CE) ---
true_dna_cls = dna_batch.argmax(dim=-1) # (B, L)
dna_logits = recon_dna.permute(0, 2, 1) # (B, 4, L)
loss_dna_raw = F.cross_entropy(dna_logits, true_dna_cls, reduction='none') # (B, L)
if mask.sum() > 0:
loss_dna = (loss_dna_raw * mask).sum() / mask.sum()
else:
loss_dna = torch.tensor(0.0, device=device)
# --- PhyloP loss (masked MSE) ---
loss_phy_raw = F.mse_loss(recon_phy, phy_batch, reduction='none') # (B, L)
if mask.sum() > 0:
loss_phy = (loss_phy_raw * mask).sum() / mask.sum()
else:
loss_phy = torch.tensor(0.0, device=device)
# --- KL sparsity penalty ---
rho = 0.02 # target sparsity
eps = 1e-12
rho_hat = torch.mean(h, dim=0)
rho_hat = torch.clamp(rho_hat, min=1e-6, max=1-1e-6)
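        # Bernoulli KL per unit: KL(rho || rho_hat) =
        #   rho*log(rho/rho_hat) + (1-rho)*log((1-rho)/(1-rho_hat)).
        # h is an unbounded ReLU activation, so the clamp above forces rho_hat
        # into (0, 1) before treating it as a Bernoulli mean.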
kl_per_unit = (
rho * torch.log((rho + eps) / (rho_hat + eps)) +
(1 - rho) * torch.log(((1 - rho) + eps) / ((1 - rho_hat) + eps))
)
        beta_kl = beta_kl_schedule[min(epoch, len(beta_kl_schedule) - 1)]
        loss_kl = beta_kl * kl_per_unit.sum()
        # Linear anneal of the L1 weight across epochs (guard avoids /0 when num_epochs == 1)
        lambda_l1 = (
            lambda_l1_start
            + (lambda_l1_end - lambda_l1_start) * (epoch / max(num_epochs - 1, 1))
        )
# --- L1 sparsity on latent ---
loss_l1 = lambda_l1 * torch.mean(torch.abs(h))
# Total loss
loss = loss_dna + phy_weight * loss_phy + loss_l1 + loss_kl
loss.backward()
optimizer.step()
# Logging accumulators
B = dna_batch.size(0)
total_loss += loss.item() * B
total_dna += loss_dna.item() * B
total_phy += loss_phy.item() * B
# approximate number of active neurons (h > threshold)
active_count = (h > 0.01).float().sum(dim=1).mean().item()
total_active += active_count * B
if batch_count % PRINT_EVERY == 0:
print(
f"Epoch {epoch+1} | Batch {batch_count} | "
f"Loss={loss.item():.4f} | DNA_CE={loss_dna.item():.4f} | "
f"Phy_MSE={loss_phy.item():.5f} | Active={active_count:.1f}"
)
    # Epoch summary (drop_last=True may discard a partial final batch,
    # so count the samples actually processed rather than SAMPLES_PER_EPOCH)
    N = batch_count * BATCH_SIZE
avg_loss = total_loss / N
avg_dna = total_dna / N
avg_phy = total_phy / N
avg_active = total_active / N
print(f"\n=== Epoch {epoch+1}/{num_epochs} COMPLETE ===")
print(
f"Avg Loss={avg_loss:.4f} | Avg DNA_CE={avg_dna:.4f} | "
f"Avg Phy_MSE={avg_phy:.5f} | "
f"Avg Active Neurons={avg_active:.1f} / {LATENT_DIM} "
f"({100.0 * avg_active / LATENT_DIM:.1f}%)"
)
# Save checkpoint
ckpt_path = f"sparse_ae_50bp_epoch{epoch+1}.pt"
torch.save(model.state_dict(), ckpt_path)
print(f"Saved checkpoint to {ckpt_path}\n")