svincoff committed on
Commit 41553a7 · 1 Parent(s): e42f54a

bug fixes

.gitignore CHANGED
@@ -34,4 +34,5 @@ dpacman/combine_shards.py
  dpacman/combine.log
  dpacman/loss_sim.py
  dpacman/loss_temp.py
- dpacman/peak_examples/
+ dpacman/peak_examples/
+ dpacman/__pycache__/
configs/model/classifier.yaml CHANGED
@@ -6,4 +6,5 @@ gamma: 20
  weight_decay: 0.01

  glm_input_dim: 1029
- compressed_dim: 1029
+ compressed_dim: 1029
+ hidden_dim: 256
dpacman.egg-info/SOURCES.txt CHANGED
@@ -1,6 +1,44 @@
  README.md
  setup.py
+ dpacman/__init__.py
+ dpacman/combine_shards.py
+ dpacman/loss_sim.py
+ dpacman/loss_temp.py
+ dpacman/temp.py
+ dpacman/temp2.py
  dpacman.egg-info/PKG-INFO
  dpacman.egg-info/SOURCES.txt
  dpacman.egg-info/dependency_links.txt
- dpacman.egg-info/top_level.txt
+ dpacman.egg-info/top_level.txt
+ dpacman/classifier/__init__.py
+ dpacman/classifier/loss.py
+ dpacman/classifier/model.py
+ dpacman/classifier/old_train.py
+ dpacman/classifier/torch_model.py
+ dpacman/classifier/train.py
+ dpacman/classifier/model_tmp/__init__.py
+ dpacman/classifier/model_tmp/clustering_data.py
+ dpacman/classifier/model_tmp/compress_embeddings.py
+ dpacman/classifier/model_tmp/compute_embeddings.py
+ dpacman/classifier/model_tmp/extract_tf_symbols.py
+ dpacman/classifier/model_tmp/make_pair_list.py
+ dpacman/classifier/model_tmp/make_peak_fasta.py
+ dpacman/classifier/model_tmp/model.py
+ dpacman/classifier/model_tmp/prep_splits.py
+ dpacman/classifier/model_tmp/train.py
+ dpacman/data_modules/__init__.py
+ dpacman/data_modules/pair.py
+ dpacman/scripts/__init__.py
+ dpacman/scripts/eval.py
+ dpacman/scripts/preprocess.py
+ dpacman/scripts/train.py
+ dpacman/utils/__init__.py
+ dpacman/utils/clustering.py
+ dpacman/utils/instantiators.py
+ dpacman/utils/logging_utils.py
+ dpacman/utils/models.py
+ dpacman/utils/plotting_utils.py
+ dpacman/utils/pylogger.py
+ dpacman/utils/rich_utils.py
+ dpacman/utils/splitting.py
+ dpacman/utils/utils.py
dpacman.egg-info/top_level.txt CHANGED
@@ -1 +1 @@
-
+ dpacman
dpacman/classifier/model.py CHANGED
@@ -28,59 +28,93 @@ class LocalCNN(nn.Module):


  class CrossModalBlock(nn.Module):
-     def __init__(self, dim: int = 256, heads: int = 8):
+     def __init__(self, dim: int = 256, heads: int = 8, dropout: float = 0.0):
          super().__init__()
          # self-attention for both sides
          self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True)
          self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True)
+         # first layer norms
          self.ln_b1 = nn.LayerNorm(dim)
          self.ln_g1 = nn.LayerNorm(dim)
-
-         self.ffn_b = nn.Sequential(
+         # first feed forward networks
+         self.ffn_b1 = nn.Sequential(
              nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
          )
-         self.ffn_g = nn.Sequential(
+         self.ffn_g1 = nn.Sequential(
              nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
          )
          self.ln_b2 = nn.LayerNorm(dim)
          self.ln_g2 = nn.LayerNorm(dim)
+
+         # 2) reciprocal cross-attn: g<-b and b<-g
+         # DNA/GLM updated by attending to Binder
+         self.cross_g2b_1_RCA = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
+         self.ln_g3_RCA = nn.LayerNorm(dim)
+         self.ffn_g2_RCA = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
+         self.ln_g4_RCA = nn.LayerNorm(dim)
+
+         # Binder updated by attending to DNA/GLM
+         self.cross_b2g_1_RCA = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
+         self.ln_b3_RCA = nn.LayerNorm(dim)
+         self.ffn_b2_RCA = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
+         self.ln_b4_RCA = nn.LayerNorm(dim)

          # cross attention (binder queries, glm keys/values)
          # so the DNA path is updated by the transcription factors
-         self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
-         self.ln_c1 = nn.LayerNorm(dim)
-         self.ffn_c = nn.Sequential(
+         self.cross_g2b_2 = nn.MultiheadAttention(dim, heads, batch_first=True)
+         self.ln_g5 = nn.LayerNorm(dim)
+         self.ffn_g3 = nn.Sequential(
              nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
          )
-         self.ln_c2 = nn.LayerNorm(dim)
+         self.ln_g6 = nn.LayerNorm(dim)

-     def forward(self, binder: torch.Tensor, glm: torch.Tensor):
+     def forward(self, binder: torch.Tensor, glm: torch.Tensor, binder_kpm_mask=None, glm_kpm_mask=None):
          """
          binder: (batch, Lb, dim)
          glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
          returns: updated binder representation (batch, Lb, dim)
          """
+         # 1) Self-attention and feed-forward networks for binder and DNA
          # binder: self-attn + ffn
          b = binder
-         b_sa, _ = self.sa_binder(b, b, b)
+         b_sa, _ = self.sa_binder(b, b, b, key_padding_mask=binder_kpm_mask)
          b = self.ln_b1(b + b_sa)
-         b_ff = self.ffn_b(b)
+         b_ff = self.ffn_b1(b)
          b = self.ln_b2(b + b_ff)

          # glm: self-attn + ffn
          g = glm
-         g_sa, _ = self.sa_glm(g, g, g)
+         g_sa, _ = self.sa_glm(g, g, g, key_padding_mask=glm_kpm_mask)
          g = self.ln_g1(g + g_sa)
-         g_ff = self.ffn_g(g)
+         g_ff = self.ffn_g1(g)
          g = self.ln_g2(g + g_ff)
+
+         # 2a) Reciprocal Cross-Attention:
+         # DNA updated by attending to Binder (Q=g, K=b, V=b)
+         # Binder updated by attending to DNA (Q=b, K=g, V=g)
+         g_ca, _ = self.cross_g2b_1_RCA(
+             g, b, b, key_padding_mask=binder_kpm_mask
+             # torch MultiheadAttention expects key_padding_mask=True for PADs;
+             # invert if your mask is True=keep:
+             # key_padding_mask=(~binder_mask.bool()) if binder_mask is not None else None
+         )
+         g = self.ln_g3_RCA(g + g_ca)
+         g = self.ln_g4_RCA(g + self.ffn_g2_RCA(g))

+         # 2b) Binder updated by attending to DNA/GLM (Q=b, K=g, V=g)
+         b_ca, _ = self.cross_b2g_1_RCA(
+             b, g, g, key_padding_mask=glm_kpm_mask
+             # key_padding_mask=(~glm_mask.bool()) if glm_mask is not None else None
+         )
+         b = self.ln_b3_RCA(b + b_ca)
+         b = self.ln_b4_RCA(b + self.ffn_b2_RCA(b))

-         # cross-attention: glm queries binder and glm embeddings are updated
-         g_to_b_ca, _ = self.cross_attn(g, b, b)
-         g = self.ln_c1(g + g_to_b_ca)
-         g_ff = self.ffn_c(g)
-         g = self.ln_c2(g + g_ff)
-         return g  # (batch, Lb, dim)
+         # cross-attention: glm queries binder and glm embeddings are updated
+         g_to_b_ca, _ = self.cross_g2b_2(g, b, b, key_padding_mask=binder_kpm_mask)
+         g = self.ln_g5(g + g_to_b_ca)
+         g_ff = self.ffn_g3(g)
+         g = self.ln_g6(g + g_ff)
+         return b, g  # (batch, Lb, dim)

-
  class DimCompressor(nn.Module):
      """
@@ -144,7 +178,7 @@ class BindPredictor(LightningModule):
          # self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())  # OLD: returned probabilities
          self.head = nn.Linear(hidden_dim, 1)  # NEW: return logits (safe for AMP)

-     def forward(self, binder_emb, glm_emb):
+     def forward(self, binder_emb, glm_emb, binder_mask, glm_mask):
          """
          binder_emb: (B, Lb, binder_input_dim)
          glm_emb: (B, Lg, glm_input_dim)
@@ -161,14 +195,15 @@ class BindPredictor(LightningModule):

          # Cross-modal blocks: update binder states using GLM
          for layer in self.layers:
-             g = layer(b, g)  # (B, Lb, hidden_dim)
+             b, g = layer(b, g, binder_mask, glm_mask)  # (B, Lb, hidden_dim)

          # Predict per-nucleotide logits on the GLM tokens:
          # return self.head(g).squeeze(-1)  # OLD: probabilities (with Sigmoid in head)
-         return self.head(g).squeeze(
+         logits = self.head(g).squeeze(
              -1
-         )  # NEW: logits (apply sigmoid only in loss/metrics)
-
+         )
+         return logits
+
      # ----- Lightning hooks -----
      def training_step(self, batch, batch_idx):
          """
@@ -184,7 +219,7 @@ class BindPredictor(LightningModule):
              "dna_sequence"
          }
          """
-         logits = self.forward(batch["binder_emb"], batch["glm_emb"])
+         logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
          loss = calculate_loss(
              logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
          )
@@ -199,7 +234,7 @@ class BindPredictor(LightningModule):
          return loss

      def validation_step(self, batch, batch_idx):
-         logits = self.forward(batch["binder_emb"], batch["glm_emb"])
+         logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
          loss = calculate_loss(
              logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
          )
@@ -214,7 +249,7 @@ class BindPredictor(LightningModule):
          return loss

      def test_step(self, batch, batch_idx):
-         logits = self.forward(batch["binder_emb"], batch["glm_emb"])
+         logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
          loss = calculate_loss(
              logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
          )
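
Note on the masks threaded through CrossModalBlock above: nn.MultiheadAttention treats key_padding_mask=True as padding to ignore, which is why the collator change further down builds binder_kpm/glm_kpm by inverting the True=keep masks. A minimal sanity check of that convention is sketched here; the tensor sizes and variable names are illustrative only, not part of the repo.

import torch
from torch import nn

# Toy check: True entries in key_padding_mask are ignored by attention.
mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
q = torch.randn(2, 5, 16)                       # queries      (B, Lq, D)
kv = torch.randn(2, 7, 16)                      # keys/values  (B, Lk, D)
keep = torch.ones(2, 7, dtype=torch.bool)       # True = keep (the collator's mask convention)
keep[:, 5:] = False                             # last two key positions are padding
out, attn = mha(q, kv, kv, key_padding_mask=~keep)  # invert so that True = PAD
print(out.shape)                                # torch.Size([2, 5, 16])
print(attn[..., 5:].sum().item())               # 0.0 -- padded keys receive no attention weight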
dpacman/classifier/model_w_rca.py ADDED
@@ -0,0 +1,323 @@
+ """
+ Lightning Module for the binding model.
+ """
+
+ import torch
+ from torch import nn
+ from lightning import LightningModule
+ from dpacman.utils.models import set_seed
+ from .loss import calculate_loss
+
+ set_seed()
+
+
+ class LocalCNN(nn.Module):
+     def __init__(self, dim: int = 256, kernel_size: int = 3):
+         super().__init__()
+         padding = kernel_size // 2
+         self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding)
+         self.act = nn.GELU()
+         self.ln = nn.LayerNorm(dim)
+
+     def forward(self, x: torch.Tensor):
+         # x: (batch, L, dim)
+         out = self.conv(x.transpose(1, 2))  # → (batch, dim, L)
+         out = self.act(out)
+         out = out.transpose(1, 2)  # → (batch, L, dim)
+         return self.ln(out + x)  # residual
+
+
+ # class CrossModalBlock(nn.Module):
+ #     def __init__(self, dim: int = 256, heads: int = 8):
+ #         super().__init__()
+ #         # self-attention for both sides
+ #         self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True)
+ #         self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True)
+ #         self.ln_b1 = nn.LayerNorm(dim)
+ #         self.ln_g1 = nn.LayerNorm(dim)
+
+ #         self.ffn_b = nn.Sequential(
+ #             nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
+ #         )
+ #         self.ffn_g = nn.Sequential(
+ #             nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
+ #         )
+ #         self.ln_b2 = nn.LayerNorm(dim)
+ #         self.ln_g2 = nn.LayerNorm(dim)
+
+ #         # cross attention (binder queries, glm keys/values)
+ #         # so the DNA path is updated by the transcription factors
+ #         self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
+ #         self.ln_c1 = nn.LayerNorm(dim)
+ #         self.ffn_c = nn.Sequential(
+ #             nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
+ #         )
+ #         self.ln_c2 = nn.LayerNorm(dim)
+
+ #     def forward(self, binder: torch.Tensor, glm: torch.Tensor):
+ #         """
+ #         binder: (batch, Lb, dim)
+ #         glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
+ #         returns: updated binder representation (batch, Lb, dim)
+ #         """
+ #         # binder: self-attn + ffn
+ #         b = binder
+ #         b_sa, _ = self.sa_binder(b, b, b)
+ #         b = self.ln_b1(b + b_sa)
+ #         b_ff = self.ffn_b(b)
+ #         b = self.ln_b2(b + b_ff)
+
+ #         # glm: self-attn + ffn
+ #         g = glm
+ #         g_sa, _ = self.sa_glm(g, g, g)
+ #         g = self.ln_g1(g + g_sa)
+ #         g_ff = self.ffn_g(g)
+ #         g = self.ln_g2(g + g_ff)
+
+ #         # cross-attention: glm queries binder and glm embeddings are updated
+ #         g_to_b_ca, _ = self.cross_attn(g, b, b)
+ #         g = self.ln_c1(g + g_to_b_ca)
+ #         g_ff = self.ffn_c(g)
+ #         g = self.ln_c2(g + g_ff)
+ #         return g  # (batch, Lb, dim)
+
+ class CrossModalBlock(nn.Module):
+     def __init__(self, dim: int = 256, heads: int = 8, dropout: float = 0.0):
+         super().__init__()
+         # 1) self-attn on each stream
+         self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
+         self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
+         self.ln_b1 = nn.LayerNorm(dim)
+         self.ln_g1 = nn.LayerNorm(dim)
+         self.ffn_b = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
+         self.ffn_g = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
+         self.ln_b2 = nn.LayerNorm(dim)
+         self.ln_g2 = nn.LayerNorm(dim)
+
+         # 2) reciprocal cross-attn: g<-b and b<-g
+         # DNA/GLM updated by attending to Binder
+         self.cross_g2b = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
+         self.ln_g_ca1 = nn.LayerNorm(dim)
+         self.ffn_g_ca = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
+         self.ln_g_ca2 = nn.LayerNorm(dim)
+
+         # Binder updated by attending to DNA/GLM
+         self.cross_b2g = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
+         self.ln_b_ca1 = nn.LayerNorm(dim)
+         self.ffn_b_ca = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
+         self.ln_b_ca2 = nn.LayerNorm(dim)
+
+     def forward(
+         self,
+         binder: torch.Tensor,  # (B, Lb, D)
+         glm: torch.Tensor,  # (B, Lg, D)
+         binder_mask: torch.Tensor | None = None,  # (B, Lb) True = keep
+         glm_mask: torch.Tensor | None = None,  # (B, Lg) True = keep
+     ):
+         # 1) self-attn+FFN on each stream
+         b, g = binder, glm
+
+         b_sa, _ = self.sa_binder(b, b, b, key_padding_mask=None)
+         b = self.ln_b1(b + b_sa)
+         b = self.ln_b2(b + self.ffn_b(b))
+
+         g_sa, _ = self.sa_glm(g, g, g, key_padding_mask=None)
+         g = self.ln_g1(g + g_sa)
+         g = self.ln_g2(g + self.ffn_g(g))
+
+         # 2a) DNA/GLM updated by attending to Binder (Q=g, K=b, V=b)
+         g_ca, _ = self.cross_g2b(
+             g, b, b,
+             # torch MultiheadAttention expects key_padding_mask=True for PADs;
+             # invert if your mask is True=keep:
+             # key_padding_mask=(~binder_mask.bool()) if binder_mask is not None else None
+         )
+         g = self.ln_g_ca1(g + g_ca)
+         g = self.ln_g_ca2(g + self.ffn_g_ca(g))
+
+         # 2b) Binder updated by attending to DNA/GLM (Q=b, K=g, V=g)
+         b_ca, _ = self.cross_b2g(
+             b, g, g,
+             # key_padding_mask=(~glm_mask.bool()) if glm_mask is not None else None
+         )
+         b = self.ln_b_ca1(b + b_ca)
+         b = self.ln_b_ca2(b + self.ffn_b_ca(b))
+
+         return b, g
+
+
+
+ class DimCompressor(nn.Module):
+     """
+     Learnable per-token compressor: maps any in_dim >= out_dim to out_dim (default 256).
+     If in_dim == out_dim, behaves as identity.
+     """
+
+     def __init__(self, in_dim: int, out_dim: int = 256):
+         super().__init__()
+         if in_dim == out_dim:
+             self.net = nn.Identity()
+         else:
+             hidden = max(out_dim * 2, (in_dim + out_dim) // 2)
+             self.net = nn.Sequential(
+                 nn.LayerNorm(in_dim),
+                 nn.Linear(in_dim, hidden),
+                 nn.GELU(),
+                 nn.Linear(hidden, out_dim),
+             )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x: (B, L, in_dim)
+         return self.net(x)
+
+
+ class BindPredictor(LightningModule):
+     def __init__(
+         self,
+         # input_dim: int = 256,  # OLD: single input dim
+         binder_input_dim: int = 1280,  # NEW: TF (binder) original dim (e.g., 1280)
+         glm_input_dim: int = 256,  # NEW: DNA/GLM original dim (e.g., 256)
+         compressed_dim: int = 256,  # NEW: learnable compressed dim
+         hidden_dim: int = 256,
+         heads: int = 8,
+         num_layers: int = 4,
+         lr: float = 1e-4,
+         alpha: float = 20,
+         gamma: float = 20,
+         use_local_cnn_on_glm: bool = True,
+         weight_decay: float = 0.01,
+     ):
+         # Init
+         super(BindPredictor, self).__init__()
+         self.save_hyperparameters()
+
+         # Learnable compressor for binder -> 256, then project to hidden
+         self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim)
+         self.proj_binder = nn.Linear(compressed_dim, hidden_dim)
+
+         # GLM side stays 256 -> hidden
+         self.proj_glm = nn.Linear(glm_input_dim, hidden_dim)
+
+         self.use_local_cnn = use_local_cnn_on_glm
+         self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity()
+
+         self.layers = nn.ModuleList(
+             [CrossModalBlock(hidden_dim, heads) for _ in range(num_layers)]
+         )
+
+         self.ln_out = nn.LayerNorm(hidden_dim)
+         # self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())  # OLD: returned probabilities
+         self.head = nn.Linear(hidden_dim, 1)  # NEW: return logits (safe for AMP)
+
+     def forward(self, binder_emb, glm_emb):
+         """
+         binder_emb: (B, Lb, binder_input_dim)
+         glm_emb: (B, Lg, glm_input_dim)
+         Returns per-nucleotide logits for the GLM sequence: (B, Lg)
+         """
+         # Binder: learnable compression → 256 → hidden
+         b = self.binder_compress(binder_emb)  # (B, Lb, 256)
+         b = self.proj_binder(b)  # (B, Lb, hidden_dim)
+
+         # GLM: project → hidden, add local CNN context
+         g = self.proj_glm(glm_emb)  # (B, Lg, hidden_dim)
+         if self.use_local_cnn:
+             g = self.local_cnn(g)
+
+         # Cross-modal blocks: update binder states using GLM
+         for layer in self.layers:
+             b, g = layer(b, g)  # (B, Lb, hidden_dim)
+
+         # Predict per-nucleotide logits on the GLM tokens:
+         # return self.head(g).squeeze(-1)  # OLD: probabilities (with Sigmoid in head)
+         return self.head(g).squeeze(
+             -1
+         )  # NEW: logits (apply sigmoid only in loss/metrics)
+
+     # ----- Lightning hooks -----
+     def training_step(self, batch, batch_idx):
+         """
+         Training step taken by PyTorch-Lightning trainer. Uses batch returned by data collator.
+         Collator returns a dictionary with:
+             "binder_emb"  # [B, Lb_max, Db]
+             "binder_mask"  # [B, Lb_max]
+             "glm_emb"  # [B, Lg_max, Dg]
+             "glm_mask"  # [B, Lg_max]
+             "labels"  # [B, Lg_max]
+             "ID"
+             "tr_sequence"
+             "dna_sequence"
+         }
+         """
+         logits = self.forward(batch["binder_emb"], batch["glm_emb"])
+         loss = calculate_loss(
+             logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
+         )
+         self.log(
+             "train/loss",
+             loss,
+             on_step=True,
+             on_epoch=True,
+             prog_bar=True,
+             batch_size=logits.size(0),
+         )
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         logits = self.forward(batch["binder_emb"], batch["glm_emb"])
+         loss = calculate_loss(
+             logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
+         )
+         self.log(
+             "val/loss",
+             loss,
+             on_step=False,
+             on_epoch=True,
+             prog_bar=True,
+             batch_size=logits.size(0),
+         )
+         return loss
+
+     def test_step(self, batch, batch_idx):
+         logits = self.forward(batch["binder_emb"], batch["glm_emb"])
+         loss = calculate_loss(
+             logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
+         )
+         self.log(
+             "test/loss", loss, on_step=False, on_epoch=True, batch_size=logits.size(0)
+         )
+         return loss
+
+     def on_train_epoch_end(self):
+         if False:
+             if self.train_auc.compute() is not None:
+                 self.log("train/auroc", self.train_auc.compute(), prog_bar=True)
+             self.train_auc.reset()
+
+     def on_validation_epoch_end(self):
+         if False:
+             if self.val_auc.compute() is not None:
+                 self.log("val/auroc", self.val_auc.compute(), prog_bar=True)
+             self.val_auc.reset()
+
+     def on_test_epoch_end(self):
+         if False:
+             if self.test_auc.compute() is not None:
+                 self.log("test/auroc", self.test_auc.compute(), prog_bar=True)
+             self.test_auc.reset()
+
+     def configure_optimizers(self):
+         # AdamW + cosine as a sensible default
+         opt = torch.optim.AdamW(
+             self.parameters(),
+             lr=self.hparams.lr,
+             weight_decay=self.hparams.weight_decay,
+         )
+         # Scheduler optional—comment out if you prefer fixed LR
+         sch = torch.optim.lr_scheduler.CosineAnnealingLR(
+             opt, T_max=max(self.trainer.max_epochs, 1)
+         )
+         return {
+             "optimizer": opt,
+             "lr_scheduler": {"scheduler": sch, "interval": "epoch"},
+         }
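
For orientation, a quick shape check of the RCA variant added in this file, assuming the package is installed so that dpacman.classifier.model_w_rca is importable; the batch and sequence sizes below are made up.

import torch
from dpacman.classifier.model_w_rca import BindPredictor

# Hypothetical sizes; binder_input_dim / glm_input_dim follow the defaults above.
model = BindPredictor(binder_input_dim=1280, glm_input_dim=256,
                      compressed_dim=256, hidden_dim=256, heads=8, num_layers=2)
binder = torch.randn(2, 100, 1280)   # (B, Lb, Db) TF / binder embeddings
glm = torch.randn(2, 500, 256)       # (B, Lg, Dg) DNA / GLM embeddings
logits = model(binder, glm)          # per-nucleotide logits on the GLM tokens
print(logits.shape)                  # torch.Size([2, 500])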
dpacman/data_modules/pair.py CHANGED
@@ -2,7 +2,8 @@
  import argparse
  import numpy as np
  import torch
- from torch.utils.data import Dataset, DataLoader, Sampler
+ from torch.utils.data import Dataset, DataLoader, Sampler, BatchSampler
+ import torch.distributed as dist
  from lightning import LightningDataModule
  from pathlib import Path
  from multiprocessing import cpu_count
@@ -14,11 +15,66 @@ from typing import List, Iterable, Sequence
  import sys
  import rootutils
  import logging
+ import math
  from dpacman.utils import pylogger

  root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
  logger = pylogger.RankedLogger(__name__, rank_zero_only=True)

+ class PreBatchedDistributedBatchSampler(BatchSampler):
+     """
+     Accepts a precomputed list of batches (list[list[int]]) and shards them across DDP ranks.
+     - shuffle_batch_order: shuffle order of batches each epoch (deterministic via set_epoch)
+     - drop_last: drop remainder so each rank gets same #steps
+     """
+     def __init__(self, batches, shuffle_batch_order=False, drop_last=False, seed: int = 0):
+         # super expects attributes batch_size, drop_last, sampler – but we don't need them.
+         # We only need to subclass BatchSampler to satisfy Lightning's check.
+         self.batches = [list(b) for b in batches]
+         self.shuffle = shuffle_batch_order
+         self.drop_last = drop_last
+         self.seed = int(seed)
+         self.epoch = 0
+
+         if dist.is_available() and dist.is_initialized():
+             self.world_size = dist.get_world_size()
+             self.rank = dist.get_rank()
+         else:
+             self.world_size = 1
+             self.rank = 0
+
+     def __iter__(self):
+         n_batches = len(self.batches)
+         order = list(range(n_batches))
+
+         if self.shuffle:
+             g = torch.Generator()
+             g.manual_seed(self.seed + self.epoch)
+             order = torch.randperm(n_batches, generator=g).tolist()
+
+         # make divisible across ranks
+         if self.drop_last:
+             total = (len(order) // self.world_size) * self.world_size
+             order = order[:total]
+         else:
+             pad = (-len(order)) % self.world_size
+             if pad:
+                 order = order + order[:pad]
+
+         # shard by rank
+         for i in order[self.rank::self.world_size]:
+             yield self.batches[i]
+
+     def __len__(self):
+         n = len(self.batches)
+         if self.drop_last:
+             return (n // self.world_size)
+         return math.ceil(n / self.world_size)
+
+     # Lightning will call this if present via its epoch hooks
+     def set_epoch(self, epoch: int):
+         self.epoch = int(epoch)
+
  class PreBatchedSampler(Sampler[List[int]]):
      """
      Yields precomputed batches of indices, e.g. [[3,7,9], [0,1,2], ...].
@@ -211,9 +267,11 @@ class PairDataModule(LightningDataModule):
              batch_size=self.batch_size,
              drop_last=self.drop_last,
          )
-         self.train_batch_sampler = PreBatchedSampler(
+         self.train_batch_sampler = PreBatchedDistributedBatchSampler(
              self.train_batches,
              shuffle_batch_order=self.shuffle_batch_order,
+             drop_last=self.drop_last,
+             seed=0,
          )

          if not hasattr(self, "val_dataset"):
@@ -225,8 +283,8 @@ class PairDataModule(LightningDataModule):
              batch_size=self.batch_size,
              drop_last=False,
          )
-         self.val_batch_sampler = PreBatchedSampler(
-             self.val_batches, shuffle_batch_order=False
+         self.val_batch_sampler = PreBatchedDistributedBatchSampler(
+             self.val_batches, shuffle_batch_order=False, drop_last=False, seed=0
          )

          # VALIDATE called standalone: ensure val is built
@@ -240,10 +298,10 @@ class PairDataModule(LightningDataModule):
              batch_size=self.batch_size,
              drop_last=False,
          )
-         self.val_batch_sampler = PreBatchedSampler(
-             self.val_batches, shuffle_batch_order=False
+         self.val_batch_sampler = PreBatchedDistributedBatchSampler(
+             self.test_batches, shuffle_batch_order=False, drop_last=False, seed=0
          )
-
+
          # TEST phase
          if stage in (None, "test"):
              if not hasattr(self, "test_dataset"):
@@ -393,19 +451,23 @@ class ShelfCollator:
              1
          )  # [B, Lg_max]

+         # True = PAD (what MHA expects)
+         binder_kpm = ~binder_mask
+         glm_kpm = ~glm_mask
+
          # 3) Collate labels for DNA and pad
          labels_list = [torch.tensor(s, dtype=torch.float32) for s in scores_list]
          labels = pad_sequence(
-             labels_list, batch_first=True, padding_value=0.0
+             labels_list, batch_first=True, padding_value=self.pad_value
          )  # [B, Lg_max]
-         # (Optional) ensure labels are zeroed beyond mask:
-         labels = labels * glm_mask.to(labels.dtype)

          return {
              "binder_emb": binder_emb,  # [B, Lb_max, Db]
              "binder_mask": binder_mask,  # [B, Lb_max]
+             "binder_kpm": binder_kpm.bool(),  # True = PAD  ← pass to MHA
              "glm_emb": glm_emb,  # [B, Lg_max, Dg]
              "glm_mask": glm_mask,  # [B, Lg_max]
+             "glm_kpm": glm_kpm.bool(),  # True = PAD  ← pass to MHA
              "labels": labels,  # [B, Lg_max]
              "ID": ids,
              "tr_sequence": tr_seqs,
@@ -451,9 +513,11 @@ def _peek_batches(dl, n_batches: int = 2, tag: str = "train"):

          logger.info(f"\n[{tag}] batch {i+1}")
          logger.info(f" binder_emb: {tuple(be.shape)} dtype={be.dtype}")
+         logger.info(f" binder_mask: {tuple(bm.shape)} dtype={bm.dtype}")
          logger.info(f" binder_mask true count: {bm.sum().item()} / {bm.numel()}")
          logger.info(f" glm_emb: {tuple(ge.shape)} dtype={ge.dtype}")
          logger.info(f" glm_mask true count: {gm.sum().item()} / {gm.numel()}")
+         logger.info(f" glm_mask: {tuple(gm.shape)} dtype={gm.dtype}")
          logger.info(
              f" labels: {tuple(y.shape)} min={y.min().item():.4f} max={y.max().item():.4f}"
          )
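
A small single-process sketch of how the new PreBatchedDistributedBatchSampler behaves, assuming the dpacman package is importable; without torch.distributed initialized it falls back to world_size=1, rank=0, so one rank yields every precomputed batch. The index lists below are made up.

from dpacman.data_modules.pair import PreBatchedDistributedBatchSampler

batches = [[0, 1, 2], [3, 4], [5, 6, 7], [8]]      # precomputed index batches
sampler = PreBatchedDistributedBatchSampler(batches, shuffle_batch_order=True, seed=0)
sampler.set_epoch(0)                               # makes the shuffle deterministic per epoch
print(list(sampler))                               # same batches, order permuted
print(len(sampler))                                # 4 == ceil(n_batches / world_size)

Under DDP, each rank would take every world_size-th batch from the (optionally padded or truncated) order, so all ranks step through the same number of batches per epoch.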
dpacman/scripts/run_train.sh CHANGED
@@ -14,16 +14,20 @@ if [ -z "$WANDB_API_KEY" ]; then
      export WANDB_API_KEY="$wandb_key"
  fi

- nohup python -u -m scripts.train \
+ CUDA_VISIBLE_DEVICES=0,1 nohup python -u -m scripts.train \
+     +trainer.strategy=ddp \
+     +trainer.use_distributed_sampler="false" \
      hydra.run.dir="${run_dir}" \
+     trainer.devices=2 \
      data_module.train_file="data_files/processed/splits/by_dna/train.csv" \
      data_module.val_file="data_files/processed/splits/by_dna/val.csv" \
      data_module.test_file="data_files/processed/splits/by_dna/test.csv" \
      data_module.tr_shelf_path="data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf" \
      data_module.dna_shelf_path="data_files/processed/embeddings/fimo_hits_only/peaks_caduceus.shelf" \
      model.glm_input_dim=256 \
+     model.compressed_dim=256 \
+     model.hidden_dim=256 \
      model.lr=1e-5 \
-     model.compressed_dim=1029 \
      > "${run_dir}/run.log" 2>&1 &

  echo $! > "${run_dir}/pid.txt"
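
The overrides added to the launch command correspond roughly to the Lightning Trainer settings sketched below (this assumes a Lightning 2.x Trainer built from the Hydra config; the sketch is illustrative, not how the repo actually constructs its Trainer). Setting use_distributed_sampler to False keeps Lightning from re-wrapping the dataloaders and replacing the custom PreBatchedDistributedBatchSampler.

from lightning import Trainer

# Sketch only: the actual Trainer is instantiated from the Hydra config.
trainer = Trainer(
    devices=2,                      # trainer.devices=2
    strategy="ddp",                 # +trainer.strategy=ddp
    use_distributed_sampler=False,  # +trainer.use_distributed_sampler="false"
)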