""" Code for baseline model to compare the classifier to """ from lightning import LightningModule import torch import torch.nn as nn from .loss import calculate_loss, auprc_zeros_vs_ones_from_logits, auroc_zeros_vs_ones_from_logits from .model import DimCompressor class BaselineBindPredictor(LightningModule): """ Baseline predictor: simple MLP that just concatenates the embeddings and outputs per-token predictions. """ def __init__( self, # input_dim: int = 256, # OLD: single input dim binder_input_dim: int = 1280, # NEW: TF (binder) original dim (e.g., 1280) glm_input_dim: int = 256, # NEW: DNA/GLM original dim (e.g., 256) compressed_dim: int = 256, # NEW: learnable compressed dim hidden_dim: int = 256, lr: float = 1e-4, alpha: float = 20, gamma: float = 20, dropout: float = 0, weight_decay: float = 0.01, loss_type: str = "mixed" ): # Init super(BaselineBindPredictor, self).__init__() self.save_hyperparameters() # Learnable compressor for binder -> 256, then project to hidden self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim) self.mlp = torch.nn.Sequential( torch.nn.Linear(compressed_dim, hidden_dim), torch.nn.ReLU(), torch.nn.Linear(hidden_dim, 1), torch.nn.ReLU(), ) def forward(self, binder_emb, glm_emb, binder_mask, glm_mask): """ binder_emb: (B, Lb, binder_input_dim) glm_emb: (B, Lg, glm_input_dim) Returns per-nucleotide logits for the GLM sequence: (B, Lg) """ # Binder: learnable compression → glm_input_dim b = self.binder_compress(binder_emb) # (B, Lb, glm_input_dim) # Concatenate target and binder. Concatenate on the length dimension lg = glm_emb.shape[1] concat_embeddings = torch.concat((glm_emb,b), dim=1) # (B, Lb + Lg, glm_input_dim) # Run concatenated embeddings through MLP logits = self.mlp(concat_embeddings) # (B, Lb + Lg, 1) # Get only the DNA logits. logits = logits[:,0:lg,:].squeeze( -1 ) return logits # ----- Lightning hooks ----- def training_step(self, batch, batch_idx): """ Training step taken by PyTorch-Lightning trainer. Uses batch returned by data collator. 
    # ----- Lightning hooks -----
    def training_step(self, batch, batch_idx):
        """
        Training step taken by the PyTorch Lightning trainer.
        Uses the batch returned by the data collator.

        The collator returns a dictionary with:
            "binder_emb"    # [B, Lb_max, Db]
            "binder_kpm"    # [B, Lb_max]
            "glm_emb"       # [B, Lg_max, Dg]
            "glm_kpm"       # [B, Lg_max]
            "labels"        # [B, Lg_max]
            "ID"
            "tr_sequence"
            "dna_sequence"
        """
        logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
        loss = calculate_loss(
            logits,
            batch["labels"],
            batch["binder_kpm"],
            batch["glm_kpm"],
            alpha=self.hparams.alpha,
            gamma=self.hparams.gamma,
            loss_type=self.hparams.loss_type,
        )
        self.log(
            "train/loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            batch_size=logits.size(0),
        )

        # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
        ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
        auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )

        # per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
        self.log("train/auprc_0v1", ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
        self.log("train/auroc_0v1", auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))

        # (optional) also log class counts so you can sanity-check balance
        self.log("train/n_pos_0v1", float(n_pos), on_step=False, on_epoch=True, sync_dist=True)
        self.log("train/n_neg_0v1", float(n_neg), on_step=False, on_epoch=True, sync_dist=True)

        return loss

    def validation_step(self, batch, batch_idx):
        logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
        loss = calculate_loss(
            logits,
            batch["labels"],
            batch["binder_kpm"],
            batch["glm_kpm"],
            alpha=self.hparams.alpha,
            gamma=self.hparams.gamma,
            loss_type=self.hparams.loss_type,
        )
        self.log(
            "val/loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            batch_size=logits.size(0),
        )

        # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
        ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
        auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )

        # per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
        self.log("val/auprc_0v1", ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
        self.log("val/auroc_0v1", auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))

        return loss
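    # NOTE: the metric helpers above (from .loss) are unpacked at every call site as
    #   auprc_zeros_vs_ones_from_logits -> (ap, n_pos, n_neg, precision, recall, thresholds)
    #   auroc_zeros_vs_ones_from_logits -> (auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp)
    # "Zeros vs ones" is inferred from the call sites: positions labelled 0 count as
    # negatives, labels > pos_thresh (0.99) as positives, and intermediate labels are
    # presumably excluded. The same pattern repeats in test_step below.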
    def test_step(self, batch, batch_idx):
        logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
        loss = calculate_loss(
            logits,
            batch["labels"],
            batch["binder_kpm"],
            batch["glm_kpm"],
            alpha=self.hparams.alpha,
            gamma=self.hparams.gamma,
            loss_type=self.hparams.loss_type,
        )
        self.log(
            "test/loss",
            loss,
            on_step=False,
            on_epoch=True,
            batch_size=logits.size(0),
        )

        # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
        ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
        auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )

        # per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
        self.log("test/auprc_0v1", ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
        self.log("test/auroc_0v1", auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))

        return loss

    def on_before_optimizer_step(self, optimizer):
        # Compute global L2 norm of all parameter gradients (ignores None grads)
        grads = []
        for p in self.parameters():
            if p.grad is not None:
                # .detach() avoids autograd tracking; .float() avoids fp16 overflow in norms
                grads.append(p.grad.detach().float().norm(2))
        if grads:
            total_norm = torch.norm(torch.stack(grads), p=2)
            self.log("train/grad_norm", total_norm, on_step=True, prog_bar=False, logger=True)

    def on_after_backward(self):
        # Same global grad norm, measured right after backward (before any
        # gradient clipping or optimizer-side scaling)
        grads = [p.grad.detach().float().norm(2) for p in self.parameters() if p.grad is not None]
        if grads:
            total_norm = torch.norm(torch.stack(grads), p=2)
            self.log("train/grad_norm_back", total_norm, on_step=True, prog_bar=False)

    def configure_optimizers(self):
        # AdamW + cosine as a sensible default
        opt = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        # Scheduler is optional; drop it if you prefer a fixed LR
        # (guard T_max against max_epochs being None or -1)
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(
            opt, T_max=max(self.trainer.max_epochs or 1, 1)
        )
        return {
            "optimizer": opt,
            "lr_scheduler": {"scheduler": sch, "interval": "epoch"},
        }
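
if __name__ == "__main__":
    # Minimal smoke test for the forward pass (a sketch, not part of training).
    # Assumes DimCompressor projects the feature dim as described above, and that
    # this module is run inside its package (e.g. `python -m <package>.<module>`)
    # so the relative imports resolve. Shapes use the constructor defaults.
    model = BaselineBindPredictor()
    B, Lb, Lg = 2, 12, 100
    binder_emb = torch.randn(B, Lb, 1280)
    glm_emb = torch.randn(B, Lg, 256)
    binder_mask = torch.ones(B, Lb, dtype=torch.bool)
    glm_mask = torch.ones(B, Lg, dtype=torch.bool)
    logits = model(binder_emb, glm_emb, binder_mask, glm_mask)
    assert logits.shape == (B, Lg), logits.shape
    print("forward() OK, logits shape:", tuple(logits.shape))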