"""
Baseline model used as a point of comparison for the main classifier.
"""

from lightning import LightningModule
import torch
import torch.nn as nn
from .loss import calculate_loss, auprc_zeros_vs_ones_from_logits, auroc_zeros_vs_ones_from_logits
from .model import DimCompressor

class BaselineBindPredictor(LightningModule):
    """
    Baseline predictor: compresses the binder embeddings, concatenates them with the
    GLM embeddings along the length dimension, and applies a small position-wise MLP
    to produce per-token logits.
    """
    def __init__(
        self,
        binder_input_dim: int = 1280,  # TF (binder) embedding dim, e.g. 1280
        glm_input_dim: int = 256,      # DNA/GLM embedding dim, e.g. 256
        compressed_dim: int = 256,     # learnable compressed binder dim; must equal glm_input_dim (see forward)
        hidden_dim: int = 256,
        lr: float = 1e-4,
        alpha: float = 20,
        gamma: float = 20,
        dropout: float = 0,
        weight_decay: float = 0.01,
        loss_type: str = "mixed"
    ):
        super().__init__()
        self.save_hyperparameters()

        # Learnable compressor: binder embeddings -> compressed_dim (the GLM embedding dim)
        self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim)
        
        # Position-wise MLP head. The final Linear outputs one unbounded logit per token;
        # no activation follows it, since the loss and metric helpers expect raw logits.
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(compressed_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, 1),
        )
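        # nn.Linear operates on the last dimension, so this MLP is applied independently
        # at every sequence position of a (B, L, compressed_dim) tensor, giving (B, L, 1).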
        
    def forward(self, binder_emb, glm_emb, binder_mask, glm_mask):
        """
        binder_emb: (B, Lb, binder_input_dim)
        glm_emb:    (B, Lg, glm_input_dim)
        binder_mask / glm_mask are accepted for interface compatibility but are not
        used by this baseline.
        Returns per-nucleotide logits for the GLM sequence: (B, Lg)
        """
        # Binder: learnable compression -> compressed_dim (== glm_input_dim)
        b = self.binder_compress(binder_emb)  # (B, Lb, compressed_dim)

        # Concatenate target (GLM) and binder embeddings along the length dimension
        lg = glm_emb.shape[1]
        concat_embeddings = torch.concat((glm_emb, b), dim=1)  # (B, Lg + Lb, compressed_dim)

        # Run the concatenated embeddings through the position-wise MLP
        logits = self.mlp(concat_embeddings)  # (B, Lg + Lb, 1)

        # Keep only the DNA (GLM) positions, which come first in the concatenation
        logits = logits[:, 0:lg, :].squeeze(-1)  # (B, Lg)
        return logits
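
    # Shape walk-through (illustrative sizes; assumes DimCompressor maps the last dim
    # binder_input_dim -> compressed_dim):
    #   binder_emb (2, 35, 1280) --binder_compress--> (2, 35, 256)
    #   glm_emb    (2, 500, 256) --concat on dim=1--> (2, 535, 256)
    #   --mlp--> (2, 535, 1) --keep first Lg=500 positions, squeeze(-1)--> logits (2, 500)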
        
    # ----- Lightning hooks -----
    def training_step(self, batch, batch_idx):
        """
        Training step run by the PyTorch Lightning trainer. Uses the batch returned by
        the data collator, a dictionary with:
            "binder_emb"    # [B, Lb_max, Db]
            "binder_kpm"    # [B, Lb_max]
            "glm_emb"       # [B, Lg_max, Dg]
            "glm_kpm"       # [B, Lg_max]
            "labels"        # [B, Lg_max]
            "ID"
            "tr_sequence"
            "dna_sequence"
        """
        logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
        loss = calculate_loss(
            logits, batch["labels"], batch["binder_kpm"], batch["glm_kpm"], alpha=self.hparams.alpha, gamma=self.hparams.gamma, loss_type=self.hparams.loss_type
        )
        self.log(
            "train/loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            batch_size=logits.size(0),
        )
        
        # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
        ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
        auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
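        # ap / auc can come back non-finite (e.g. when a batch contains only one of the
        # two classes), so fall back to 0.0 when logging.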
        # per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
        self.log("train/auprc_0v1",
                ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
                on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
        self.log("train/auroc_0v1",
                auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
                on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
        
        # (optional) also log class counts so you can sanity-check balance
        self.log("train/n_pos_0v1", float(n_pos), on_step=False, on_epoch=True, sync_dist=True)
        self.log("train/n_neg_0v1", float(n_neg), on_step=False, on_epoch=True, sync_dist=True)

        return loss

    def validation_step(self, batch, batch_idx):
        logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
        loss = calculate_loss(
            logits, batch["labels"], batch["binder_kpm"], batch["glm_kpm"], alpha=self.hparams.alpha, gamma=self.hparams.gamma, loss_type=self.hparams.loss_type
        )
        self.log(
            "val/loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            batch_size=logits.size(0),
        )
        
        # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
        ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
        auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
        # per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
        self.log("val/auprc_0v1",
                ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
                on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
        self.log("val/auroc_0v1",
                auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
                on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
        return loss

    def test_step(self, batch, batch_idx):
        logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
        loss = calculate_loss(
            logits, batch["labels"], batch["binder_kpm"], batch["glm_kpm"], alpha=self.hparams.alpha, gamma=self.hparams.gamma, loss_type=self.hparams.loss_type
        )
        self.log(
            "test/loss", loss, on_step=False, on_epoch=True, batch_size=logits.size(0)
        )
        
        # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
        ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
        auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
        # per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
        self.log("test/auprc_0v1",
                ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
                on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
        self.log("test/auroc_0v1",
                auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
                on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
        return loss
    
    def on_before_optimizer_step(self, optimizer):
        # Compute global L2 norm of all parameter gradients (ignores None grads)
        grads = []
        for p in self.parameters():
            if p.grad is not None:
                # .detach() avoids autograd tracking; .float() avoids fp16 overflow in norms
                grads.append(p.grad.detach().float().norm(2))
        if grads:
            total_norm = torch.norm(torch.stack(grads), p=2)
            self.log("train/grad_norm", total_norm, on_step=True, prog_bar=False, logger=True)
    
    def on_after_backward(self):
        grads = [p.grad.detach().float().norm(2)
                for p in self.parameters() if p.grad is not None]
        if grads:
            total_norm = torch.norm(torch.stack(grads), p=2)
            self.log("train/grad_norm_back", total_norm, on_step=True, prog_bar=False)

    # Epoch-end hooks are intentionally no-ops: this baseline does not maintain
    # torchmetrics objects (self.train_auc / self.val_auc / self.test_auc), so there is
    # nothing to compute or reset at the end of an epoch.
    def on_train_epoch_end(self):
        pass

    def on_validation_epoch_end(self):
        pass

    def on_test_epoch_end(self):
        pass

    def configure_optimizers(self):
        # AdamW + cosine as a sensible default
        opt = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        # Scheduler is optional; remove it if you prefer a fixed LR
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(
            opt, T_max=max(self.trainer.max_epochs or 1, 1)
        )
        return {
            "optimizer": opt,
            "lr_scheduler": {"scheduler": sch, "interval": "epoch"},
        }
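

# Minimal usage sketch (illustrative; `BindDataModule` is a hypothetical datamodule whose
# collator yields the batch dict documented in training_step):
#
#   from lightning import Trainer
#   model = BaselineBindPredictor(binder_input_dim=1280, glm_input_dim=256)
#   trainer = Trainer(max_epochs=10)
#   trainer.fit(model, datamodule=BindDataModule())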