svincoff committed
Commit 121a325 · 1 Parent(s): 4c4b1fc

baseline compare

.gitignore CHANGED
@@ -35,4 +35,7 @@ dpacman/combine.log
 dpacman/loss_sim.py
 dpacman/loss_temp.py
 dpacman/peak_examples/
-dpacman/__pycache__/
+dpacman/__pycache__/
+log.log
+log2.log
+dpacman/delay.log
configs/model/baseline.yaml ADDED
@@ -0,0 +1,10 @@
+_target_: dpacman.classifier.baseline.BaselineBindPredictor
+
+lr: 1e-4
+alpha: 20
+gamma: 20
+weight_decay: 0.01
+
+glm_input_dim: 256
+compressed_dim: 256
+hidden_dim: 128
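For reference, a minimal usage sketch (not part of this commit) of how a Hydra-style config like the one added above is typically turned into the model object. The `_target_` path and hyperparameter names come from the YAML; the surrounding script, and the assumption that hydra-core/omegaconf are installed, are mine.

# Hypothetical usage sketch for configs/model/baseline.yaml
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/model/baseline.yaml")  # holds _target_ plus the hyperparameters above
model = instantiate(cfg)  # -> dpacman.classifier.baseline.BaselineBindPredictor(lr=1e-4, alpha=20, ...)
print(type(model).__name__, model.hparams.lr, model.hparams.hidden_dim)

Because BaselineBindPredictor calls save_hyperparameters(), the YAML values are then visible under model.hparams.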
configs/model/pooling/truncatedsvd.yaml DELETED
@@ -1,7 +0,0 @@
-n_components: 2
-algorithm: randomized
-n_iter: 5
-n_oversamples: 10
-poewr_iteration_normalizer: auto
-random_state: 42
-tol: 0
 
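For context, the keys of the deleted pooling config line up with the constructor arguments of scikit-learn's TruncatedSVD (including the misspelled power_iteration_normalizer). A rough reconstruction of how such a config would have been consumed, assuming a recent scikit-learn, is:

# Hypothetical sketch of the pooling the removed YAML parameterized
from sklearn.decomposition import TruncatedSVD

svd = TruncatedSVD(
    n_components=2,
    algorithm="randomized",
    n_iter=5,
    n_oversamples=10,
    power_iteration_normalizer="auto",  # spelled "poewr_..." in the deleted file
    random_state=42,
    tol=0.0,
)
# svd.fit_transform(X) reduces an (n_samples, n_features) embedding matrix to 2 components.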
dpacman/classifier/{model_w_rca.py → baseline.py} RENAMED
@@ -1,177 +1,17 @@
1
  """
2
- Lightning Module for the binding model.
3
  """
4
 
5
- import torch
6
- from torch import nn
7
  from lightning import LightningModule
8
- from dpacman.utils.models import set_seed
9
- from .loss import calculate_loss
10
-
11
- set_seed()
12
-
13
-
14
- class LocalCNN(nn.Module):
15
- def __init__(self, dim: int = 256, kernel_size: int = 3):
16
- super().__init__()
17
- padding = kernel_size // 2
18
- self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding)
19
- self.act = nn.GELU()
20
- self.ln = nn.LayerNorm(dim)
21
-
22
- def forward(self, x: torch.Tensor):
23
- # x: (batch, L, dim)
24
- out = self.conv(x.transpose(1, 2)) # → (batch, dim, L)
25
- out = self.act(out)
26
- out = out.transpose(1, 2) # → (batch, L, dim)
27
- return self.ln(out + x) # residual
28
-
29
-
30
- # class CrossModalBlock(nn.Module):
31
- # def __init__(self, dim: int = 256, heads: int = 8):
32
- # super().__init__()
33
- # # self-attention for both sides
34
- # self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True)
35
- # self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True)
36
- # self.ln_b1 = nn.LayerNorm(dim)
37
- # self.ln_g1 = nn.LayerNorm(dim)
38
-
39
- # self.ffn_b = nn.Sequential(
40
- # nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
41
- # )
42
- # self.ffn_g = nn.Sequential(
43
- # nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
44
- # )
45
- # self.ln_b2 = nn.LayerNorm(dim)
46
- # self.ln_g2 = nn.LayerNorm(dim)
47
-
48
- # # cross attention (binder queries, glm keys/values)
49
- # # so the NDA path is updated by the transcriptoin factors
50
- # self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
51
- # self.ln_c1 = nn.LayerNorm(dim)
52
- # self.ffn_c = nn.Sequential(
53
- # nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
54
- # )
55
- # self.ln_c2 = nn.LayerNorm(dim)
56
-
57
- # def forward(self, binder: torch.Tensor, glm: torch.Tensor):
58
- # """
59
- # binder: (batch, Lb, dim)
60
- # glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
61
- # returns: updated binder representation (batch, Lb, dim)
62
- # """
63
- # # binder: self-attn + ffn
64
- # b = binder
65
- # b_sa, _ = self.sa_binder(b, b, b)
66
- # b = self.ln_b1(b + b_sa)
67
- # b_ff = self.ffn_b(b)
68
- # b = self.ln_b2(b + b_ff)
69
-
70
- # # glm: self-attn + ffn
71
- # g = glm
72
- # g_sa, _ = self.sa_glm(g, g, g)
73
- # g = self.ln_g1(g + g_sa)
74
- # g_ff = self.ffn_g(g)
75
- # g = self.ln_g2(g + g_ff)
76
-
77
- # # cross-attention: glm queries binder and glm embeddings are updated
78
- # g_to_b_ca, _ = self.cross_attn(g, b, b)
79
- # g = self.ln_c1(g + g_to_b_ca)
80
- # g_ff = self.ffn_c(g)
81
- # g = self.ln_c2(g + g_ff)
82
- # return g # (batch, Lb, dim)
83
-
84
- class CrossModalBlock(nn.Module):
85
- def __init__(self, dim: int = 256, heads: int = 8, dropout: float = 0.0):
86
- super().__init__()
87
- # 1) self-attn on each stream
88
- self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
89
- self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
90
- self.ln_b1 = nn.LayerNorm(dim)
91
- self.ln_g1 = nn.LayerNorm(dim)
92
- self.ffn_b = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
93
- self.ffn_g = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
94
- self.ln_b2 = nn.LayerNorm(dim)
95
- self.ln_g2 = nn.LayerNorm(dim)
96
-
97
- # 2) reciprocal cross-attn: g<-b and b<-g
98
- # DNA/GLM updated by attending to Binder
99
- self.cross_g2b = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
100
- self.ln_g_ca1 = nn.LayerNorm(dim)
101
- self.ffn_g_ca = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
102
- self.ln_g_ca2 = nn.LayerNorm(dim)
103
-
104
- # Binder updated by attending to DNA/GLM
105
- self.cross_b2g = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
106
- self.ln_b_ca1 = nn.LayerNorm(dim)
107
- self.ffn_b_ca = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
108
- self.ln_b_ca2 = nn.LayerNorm(dim)
109
-
110
- def forward(
111
- self,
112
- binder: torch.Tensor, # (B, Lb, D)
113
- glm: torch.Tensor, # (B, Lg, D)
114
- binder_mask: torch.Tensor | None = None, # (B, Lb) True = keep
115
- glm_mask: torch.Tensor | None = None, # (B, Lg) True = keep
116
- ):
117
- # 1) self-attn+FFN on each stream
118
- b, g = binder, glm
119
-
120
- b_sa, _ = self.sa_binder(b, b, b, key_padding_mask=None)
121
- b = self.ln_b1(b + b_sa)
122
- b = self.ln_b2(b + self.ffn_b(b))
123
-
124
- g_sa, _ = self.sa_glm(g, g, g, key_padding_mask=None)
125
- g = self.ln_g1(g + g_sa)
126
- g = self.ln_g2(g + self.ffn_g(g))
127
-
128
- # 2a) DNA/GLM updated by attending to Binder (Q=g, K=b, V=b)
129
- g_ca, _ = self.cross_g2b(
130
- g, b, b,
131
- # torch MultiheadAttention expects key_padding_mask=True for PADs;
132
- # invert if your mask is True=keep:
133
- # key_padding_mask=(~binder_mask.bool()) if binder_mask is not None else None
134
- )
135
- g = self.ln_g_ca1(g + g_ca)
136
- g = self.ln_g_ca2(g + self.ffn_g_ca(g))
137
-
138
- # 2b) Binder updated by attending to DNA/GLM (Q=b, K=g, V=g)
139
- b_ca, _ = self.cross_b2g(
140
- b, g, g,
141
- # key_padding_mask=(~glm_mask.bool()) if glm_mask is not None else None
142
- )
143
- b = self.ln_b_ca1(b + b_ca)
144
- b = self.ln_b_ca2(b + self.ffn_b_ca(b))
145
-
146
- return b, g
147
-
148
-
149
 
150
- class DimCompressor(nn.Module):
151
  """
152
- Learnable per-token compressor: maps any in_dim >= out_dim to out_dim (default 256).
153
- If in_dim == out_dim, behaves as identity.
154
  """
155
-
156
- def __init__(self, in_dim: int, out_dim: int = 256):
157
- super().__init__()
158
- if in_dim == out_dim:
159
- self.net = nn.Identity()
160
- else:
161
- hidden = max(out_dim * 2, (in_dim + out_dim) // 2)
162
- self.net = nn.Sequential(
163
- nn.LayerNorm(in_dim),
164
- nn.Linear(in_dim, hidden),
165
- nn.GELU(),
166
- nn.Linear(hidden, out_dim),
167
- )
168
-
169
- def forward(self, x: torch.Tensor) -> torch.Tensor:
170
- # x: (B, L, in_dim)
171
- return self.net(x)
172
-
173
-
174
- class BindPredictor(LightningModule):
175
  def __init__(
176
  self,
177
  # input_dim: int = 256, # OLD: single input dim
@@ -179,77 +19,64 @@ class BindPredictor(LightningModule):
179
  glm_input_dim: int = 256, # NEW: DNA/GLM original dim (e.g., 256)
180
  compressed_dim: int = 256, # NEW: learnable compressed dim
181
  hidden_dim: int = 256,
182
- heads: int = 8,
183
- num_layers: int = 4,
184
  lr: float = 1e-4,
185
  alpha: float = 20,
186
  gamma: float = 20,
187
- use_local_cnn_on_glm: bool = True,
188
  weight_decay: float = 0.01,
189
  ):
190
  # Init
191
- super(BindPredictor, self).__init__()
192
  self.save_hyperparameters()
193
 
194
  # Learnable compressor for binder -> 256, then project to hidden
195
  self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim)
196
- self.proj_binder = nn.Linear(compressed_dim, hidden_dim)
197
-
198
- # GLM side stays 256 -> hidden
199
- self.proj_glm = nn.Linear(glm_input_dim, hidden_dim)
200
-
201
- self.use_local_cnn = use_local_cnn_on_glm
202
- self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity()
203
-
204
- self.layers = nn.ModuleList(
205
- [CrossModalBlock(hidden_dim, heads) for _ in range(num_layers)]
206
  )
207
-
208
- self.ln_out = nn.LayerNorm(hidden_dim)
209
- # self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid()) # OLD: returned probabilities
210
- self.head = nn.Linear(hidden_dim, 1) # NEW: return logits (safe for AMP)
211
-
212
- def forward(self, binder_emb, glm_emb):
213
  """
214
  binder_emb: (B, Lb, binder_input_dim)
215
  glm_emb: (B, Lg, glm_input_dim)
216
  Returns per-nucleotide logits for the GLM sequence: (B, Lg)
217
  """
218
- # Binder: learnable compression → 256 → hidden
219
- b = self.binder_compress(binder_emb) # (B, Lb, 256)
220
- b = self.proj_binder(b) # (B, Lb, hidden_dim)
221
-
222
- # GLM: project → hidden, add local CNN context
223
- g = self.proj_glm(glm_emb) # (B, Lg, hidden_dim)
224
- if self.use_local_cnn:
225
- g = self.local_cnn(g)
226
-
227
- # Cross-modal blocks: update binder states using GLM
228
- for layer in self.layers:
229
- b, g = layer(b, g) # (B, Lb, hidden_dim)
230
-
231
- # Predict per-nucleotide logits on the GLM tokens:
232
- # return self.head(g).squeeze(-1) # OLD: probabilities (with Sigmoid in head)
233
- return self.head(g).squeeze(
234
  -1
235
- ) # NEW: logits (apply sigmoid only in loss/metrics)
236
-
 
237
  # ----- Lightning hooks -----
238
  def training_step(self, batch, batch_idx):
239
  """
240
  Training step taken by PyTorch-Lightning trainer. Uses batch returned by data collator.
241
Collator returns a dictionary with:
242
  "binder_emb" # [B, Lb_max, Db]
243
- "binder_mask" # [B, Lb_max]
244
  "glm_emb" # [B, Lg_max, Dg]
245
- "glm_mask" # [B, Lg_max]
246
  "labels" # [B, Lg_max]
247
  "ID"
248
  "tr_sequence"
249
  "dna_sequence"
250
  }
251
  """
252
- logits = self.forward(batch["binder_emb"], batch["glm_emb"])
253
  loss = calculate_loss(
254
  logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
255
  )
@@ -261,10 +88,30 @@ class BindPredictor(LightningModule):
261
  prog_bar=True,
262
  batch_size=logits.size(0),
263
  )
264
  return loss
265
 
266
  def validation_step(self, batch, batch_idx):
267
- logits = self.forward(batch["binder_emb"], batch["glm_emb"])
268
  loss = calculate_loss(
269
  logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
270
  )
@@ -276,17 +123,65 @@ class BindPredictor(LightningModule):
276
  prog_bar=True,
277
  batch_size=logits.size(0),
278
  )
279
  return loss
280
 
281
  def test_step(self, batch, batch_idx):
282
- logits = self.forward(batch["binder_emb"], batch["glm_emb"])
283
  loss = calculate_loss(
284
  logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
285
  )
286
  self.log(
287
  "test/loss", loss, on_step=False, on_epoch=True, batch_size=logits.size(0)
288
  )
289
  return loss
290
 
291
  def on_train_epoch_end(self):
292
  if False:
@@ -320,4 +215,4 @@ class BindPredictor(LightningModule):
320
  return {
321
  "optimizer": opt,
322
  "lr_scheduler": {"scheduler": sch, "interval": "epoch"},
323
- }
 
1
  """
2
+ Code for baseline model to compare the classifier to
3
  """
4
 
 
 
5
  from lightning import LightningModule
6
+ import torch
7
+ import torch.nn as nn
8
+ from .loss import calculate_loss, auprc_zeros_vs_ones_from_logits, auroc_zeros_vs_ones_from_logits
9
+ from .model import DimCompressor
10
 
11
+ class BaselineBindPredictor(LightningModule):
12
  """
13
+ Baseline predictor: simple MLP that just concatenates the embeddings and outputs per-token predictions.
 
14
  """
15
  def __init__(
16
  self,
17
  # input_dim: int = 256, # OLD: single input dim
 
19
  glm_input_dim: int = 256, # NEW: DNA/GLM original dim (e.g., 256)
20
  compressed_dim: int = 256, # NEW: learnable compressed dim
21
  hidden_dim: int = 256,
 
 
22
  lr: float = 1e-4,
23
  alpha: float = 20,
24
  gamma: float = 20,
25
+ dropout: float = 0,
26
  weight_decay: float = 0.01,
27
  ):
28
  # Init
29
+ super(BaselineBindPredictor, self).__init__()
30
  self.save_hyperparameters()
31
 
32
  # Learnable compressor for binder -> 256, then project to hidden
33
  self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim)
34
+
35
+ self.mlp = torch.nn.Sequential(
36
+ torch.nn.Linear(compressed_dim, hidden_dim),
37
+ torch.nn.ReLU(),
38
+ torch.nn.Linear(hidden_dim, 1),
39
+ torch.nn.ReLU(),
40
  )
41
+
42
+ def forward(self, binder_emb, glm_emb, binder_mask, glm_mask):
43
  """
44
  binder_emb: (B, Lb, binder_input_dim)
45
  glm_emb: (B, Lg, glm_input_dim)
46
  Returns per-nucleotide logits for the GLM sequence: (B, Lg)
47
  """
48
+ # Binder: learnable compression → glm_input_dim
49
+ b = self.binder_compress(binder_emb) # (B, Lb, glm_input_dim)
50
+
51
+ # Concatenate target and binder. Concatenate on the length dimension
52
+ lg = glm_emb.shape[1]
53
+ concat_embeddings = torch.concat((glm_emb,b), dim=1) # (B, Lb + Lg, glm_input_dim)
54
+
55
+ # Run concatenated embeddings through MLP
56
+ logits = self.mlp(concat_embeddings) # (B, Lb + Lg, 1)
57
+
58
+ # Get only the DNA logits.
59
+ logits = logits[:,0:lg,:].squeeze(
60
  -1
61
+ )
62
+ return logits
63
+
64
  # ----- Lightning hooks -----
65
  def training_step(self, batch, batch_idx):
66
  """
67
  Training step taken by PyTorch-Lightning trainer. Uses batch returned by data collator.
68
  Colator returns a dictionary with:
69
  "binder_emb" # [B, Lb_max, Db]
70
+ "binder_kpm" # [B, Lb_max]
71
  "glm_emb" # [B, Lg_max, Dg]
72
+ "glm_kpm" # [B, Lg_max]
73
  "labels" # [B, Lg_max]
74
  "ID"
75
  "tr_sequence"
76
  "dna_sequence"
77
  }
78
  """
79
+ logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
80
  loss = calculate_loss(
81
  logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
82
  )
 
88
  prog_bar=True,
89
  batch_size=logits.size(0),
90
  )
91
+
92
+ # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
93
+ ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
94
+ logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
95
+ )
96
+ auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
97
+ logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
98
+ )
99
+ # per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
100
+ self.log("train/auprc_0v1",
101
+ ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
102
+ on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
103
+ self.log("train/auroc_0v1",
104
+ auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
105
+ on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
106
+
107
+ # (optional) also log class counts so you can sanity-check balance
108
+ self.log("train/n_pos_0v1", float(n_pos), on_step=False, on_epoch=True, sync_dist=True)
109
+ self.log("train/n_neg_0v1", float(n_neg), on_step=False, on_epoch=True, sync_dist=True)
110
+
111
  return loss
112
 
113
  def validation_step(self, batch, batch_idx):
114
+ logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
115
  loss = calculate_loss(
116
  logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
117
  )
 
123
  prog_bar=True,
124
  batch_size=logits.size(0),
125
  )
126
+
127
+ # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
128
+ ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
129
+ logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
130
+ )
131
+ auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
132
+ logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
133
+ )
134
+ # per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
135
+ self.log("val/auprc_0v1",
136
+ ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
137
+ on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
138
+ self.log("val/auroc_0v1",
139
+ auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
140
+ on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
141
  return loss
142
 
143
  def test_step(self, batch, batch_idx):
144
+ logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
145
  loss = calculate_loss(
146
  logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
147
  )
148
  self.log(
149
  "test/loss", loss, on_step=False, on_epoch=True, batch_size=logits.size(0)
150
  )
151
+
152
+ # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
153
+ ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
154
+ logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
155
+ )
156
+ auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
157
+ logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
158
+ )
159
+ # per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
160
+ self.log("test/auprc_0v1",
161
+ ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
162
+ on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
163
+ self.log("test/auroc_0v1",
164
+ auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
165
+ on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
166
  return loss
167
+
168
+ def on_before_optimizer_step(self, optimizer):
169
+ # Compute global L2 norm of all parameter gradients (ignores None grads)
170
+ grads = []
171
+ for p in self.parameters():
172
+ if p.grad is not None:
173
+ # .detach() avoids autograd tracking; .float() avoids fp16 overflow in norms
174
+ grads.append(p.grad.detach().float().norm(2))
175
+ if grads:
176
+ total_norm = torch.norm(torch.stack(grads), p=2)
177
+ self.log("train/grad_norm", total_norm, on_step=True, prog_bar=False, logger=True)
178
+
179
+ def on_after_backward(self):
180
+ grads = [p.grad.detach().float().norm(2)
181
+ for p in self.parameters() if p.grad is not None]
182
+ if grads:
183
+ total_norm = torch.norm(torch.stack(grads), p=2)
184
+ self.log("train/grad_norm_back", total_norm, on_step=True, prog_bar=False)
185
 
186
  def on_train_epoch_end(self):
187
  if False:
 
215
  return {
216
  "optimizer": opt,
217
  "lr_scheduler": {"scheduler": sch, "interval": "epoch"},
218
+ }
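For intuition, a self-contained sketch (toy sizes and tensors are my assumptions, not the committed module) of the baseline's concatenate-then-MLP flow: binder and GLM tokens are stacked along the length axis, scored per token, and only the GLM positions are returned.

# Hypothetical shape walk-through of BaselineBindPredictor.forward
import torch
import torch.nn as nn

B, Lb, Lg, D, H = 2, 7, 11, 256, 128          # assumed toy sizes
glm_emb = torch.randn(B, Lg, D)               # DNA/GLM tokens
binder_emb = torch.randn(B, Lb, D)            # binder tokens, already compressed to D here

mlp = nn.Sequential(nn.Linear(D, H), nn.ReLU(), nn.Linear(H, 1), nn.ReLU())

concat = torch.concat((glm_emb, binder_emb), dim=1)  # (B, Lg + Lb, D)
scores = mlp(concat)                                 # (B, Lg + Lb, 1), one score per token
logits = scores[:, :Lg, :].squeeze(-1)               # keep only the GLM positions -> (B, Lg)
assert logits.shape == (B, Lg)
# Note: the trailing ReLU in the committed mlp clamps these "logits" to be >= 0.

The binder_mask/glm_mask arguments are accepted by the committed forward but not referenced inside it; padding is handled downstream in the loss and metric helpers.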
dpacman/classifier/loss.py CHANGED
@@ -4,6 +4,9 @@ Define loss functions needed for training the model — padding safe (-1 sentine
4
 
5
  import torch
6
  import torch.nn.functional as F
 
 
 
7
 
8
  def _expand_like(mask: torch.Tensor, like: torch.Tensor):
9
  # Make mask broadcastable to logits/targets (handles (B,L) vs (B,L,1))
@@ -62,59 +65,89 @@ def calculate_loss(
62
 
63
  return alpha * bce_nonpeak + gamma * mse_peak
64
 
65
- import torch
66
-
67
  @torch.no_grad()
68
- def auprc_zeros_vs_ones_from_logits(
69
  logits: torch.Tensor, # (B, L)
70
  labels: torch.Tensor, # (B, L)
71
- glm_kpm: torch.Tensor | None, # (B, L) True=PAD; pass None if not available
72
  pos_thresh: float = 0.99,
73
- ) -> tuple[torch.Tensor, int, int]:
74
  """
75
- Returns (ap, n_pos, n_neg). AP is Average Precision (area under PR).
76
- Uses only positions with labels == 0.0 or > pos_thresh. Ignores PADs via glm_kpm.
77
- Computation stays on the same device as logits.
 
 
 
78
  """
79
- probs = torch.sigmoid(logits)
80
 
81
- # Valid positions: not padded
82
- if glm_kpm is not None:
83
- valid = ~glm_kpm
84
- else:
85
- valid = torch.ones_like(labels, dtype=torch.bool, device=labels.device)
86
 
87
- # Keep only exact zeros and near-ones
88
- pos = labels > pos_thresh
89
- neg = labels == 0.0
90
- keep = valid & (pos | neg)
91
92
  if keep.sum() == 0:
93
- return torch.tensor(float('nan'), device=logits.device), 0, 0
 
94
 
95
- y = pos[keep].to(probs.dtype) # 1 for >0.99, 0 for 0.0
96
- s = probs[keep].to(probs.dtype)
97
 
98
- n = y.numel()
99
  n_pos = int(y.sum().item())
100
- n_neg = n - n_pos
101
- if n_pos == 0: # no positives → AP = 0 by convention
102
- return torch.tensor(0.0, device=logits.device), 0, n_neg
103
-
104
- # Sort by score descending
105
- order = torch.argsort(s, descending=True)
106
- y_sorted = y[order]
107
-
108
- # CumTP and precision/recall
109
- tp = torch.cumsum(y_sorted, dim=0)
110
- ranks = torch.arange(1, n + 1, device=logits.device, dtype=probs.dtype)
111
- precision = tp / ranks
112
- recall = tp / n_pos
113
-
114
- # AP = sum( precision * Δrecall )
115
- recall_prev = torch.cat([torch.zeros(1, device=logits.device, dtype=probs.dtype), recall[:-1]])
116
- ap = (precision * (recall - recall_prev)).sum()
117
- return ap, n_pos, n_neg
118
 
119
  def accuracy_percentage(
120
  logits,
 
4
 
5
  import torch
6
  import torch.nn.functional as F
7
+ from torchmetrics.functional.classification import (
8
+ auroc, average_precision, roc, precision_recall_curve
9
+ )
10
 
11
  def _expand_like(mask: torch.Tensor, like: torch.Tensor):
12
  # Make mask broadcastable to logits/targets (handles (B,L) vs (B,L,1))
 
65
 
66
  return alpha * bce_nonpeak + gamma * mse_peak
67
 
 
 
68
  @torch.no_grad()
69
+ def auroc_zeros_vs_ones_from_logits(
70
  logits: torch.Tensor, # (B, L)
71
  labels: torch.Tensor, # (B, L)
72
+ glm_kpm: torch.Tensor | None = None, # (B, L) True=PAD
73
  pos_thresh: float = 0.99,
74
+ ):
75
  """
76
+ Returns:
77
+ auc: scalar tensor (AUROC)
78
+ n_pos, n_neg: ints
79
+ tpr, fpr: tensors of shape (T,)
80
+ thresholds: tensor of shape (T,)
81
+ tp, fp: integer counts per threshold (shape (T,))
82
  """
83
+ device = logits.device
84
+ valid = ~glm_kpm if glm_kpm is not None else torch.ones_like(labels, dtype=torch.bool, device=device)
85
+ keep = valid & ((labels > pos_thresh) | (labels == 0.0))
86
+ if keep.sum() == 0:
87
+ return (torch.tensor(float('nan'), device=device), 0, 0,
88
+ torch.empty(0, device=device), torch.empty(0, device=device),
89
+ torch.empty(0, device=device), torch.empty(0, device=device), torch.empty(0, device=device))
90
+
91
+ y = (labels[keep] > pos_thresh).to(torch.int)
92
+ s = logits[keep]
93
+
94
+ n_pos = int(y.sum().item())
95
+ n_neg = y.numel() - n_pos
96
+ if n_pos == 0 or n_neg == 0:
97
+ return (torch.tensor(float('nan'), device=device), n_pos, n_neg,
98
+ torch.empty(0, device=device), torch.empty(0, device=device),
99
+ torch.empty(0, device=device), torch.empty(0, device=device), torch.empty(0, device=device))
100
+
101
+ # Full ROC curve
102
+ fpr, tpr, thresholds = roc(s, y, task="binary")
103
+ # AUROC (TM handles logits)
104
+ auc = auroc(s, y, task="binary")
105
 
106
+ # Convert rates to counts (round to nearest to avoid float off-by-one)
107
+ tp = (tpr * n_pos).round().to(torch.long)
108
+ fp = (fpr * n_neg).round().to(torch.long)
 
 
109
 
110
+ return auc.to(device), n_pos, n_neg, tpr.to(device), fpr.to(device), thresholds.to(device), tp.to(device), fp.to(device)
 
 
 
111
 
112
+
113
+ @torch.no_grad()
114
+ def auprc_zeros_vs_ones_from_logits(
115
+ logits: torch.Tensor, # (B, L)
116
+ labels: torch.Tensor, # (B, L)
117
+ glm_kpm: torch.Tensor | None = None, # (B, L) True=PAD
118
+ pos_thresh: float = 0.99,
119
+ ):
120
+ """
121
+ Returns:
122
+ ap: scalar tensor (Average Precision / AUPRC)
123
+ n_pos, n_neg: ints
124
+ precision: (T,)
125
+ recall: (T,)
126
+ thresholds: (T,)
127
+ """
128
+ device = logits.device
129
+ valid = ~glm_kpm if glm_kpm is not None else torch.ones_like(labels, dtype=torch.bool, device=device)
130
+ keep = valid & ((labels > pos_thresh) | (labels == 0.0))
131
  if keep.sum() == 0:
132
+ return (torch.tensor(float('nan'), device=device), 0, 0,
133
+ torch.empty(0, device=device), torch.empty(0, device=device), torch.empty(0, device=device))
134
 
135
+ y = (labels[keep] > pos_thresh).to(torch.int)
136
+ s = logits[keep]
137
 
 
138
  n_pos = int(y.sum().item())
139
+ n_neg = y.numel() - n_pos
140
+ if n_pos == 0:
141
+ # By convention, AP=0 when there are no positives
142
+ return (torch.tensor(0.0, device=device), 0, n_neg,
143
+ torch.empty(0, device=device), torch.empty(0, device=device), torch.empty(0, device=device))
144
+
145
+ # Full PR curve
146
+ precision, recall, thresholds = precision_recall_curve(s, y, task="binary")
147
+ # Average Precision / AUPRC
148
+ ap = average_precision(s, y, task="binary")
149
+
150
+ return ap.to(device), n_pos, n_neg, precision.to(device), recall.to(device), thresholds.to(device)
151
 
152
  def accuracy_percentage(
153
  logits,
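As a sanity check on the new torchmetrics-based helpers, a small self-contained sketch (toy tensors and values are mine) of the same zeros-vs-ones filtering followed by the binary curve/score calls used above:

# Hypothetical toy run of the 0-vs-1 metric path
import torch
from torchmetrics.functional.classification import (
    auroc, average_precision, roc, precision_recall_curve
)

logits = torch.tensor([[2.0, -1.0, 0.5, -3.0]])
labels = torch.tensor([[1.0, 0.0, 0.995, 0.4]])        # 0.4 is neither 0 nor >0.99, so it is dropped
glm_kpm = torch.tensor([[False, False, False, False]])  # True would mark padded positions

keep = (~glm_kpm) & ((labels > 0.99) | (labels == 0.0))
y = (labels[keep] > 0.99).to(torch.int)   # binary targets
s = logits[keep]                          # raw logits; torchmetrics applies sigmoid for task="binary"

ap = average_precision(s, y, task="binary")
auc = auroc(s, y, task="binary")
precision, recall, pr_thresholds = precision_recall_curve(s, y, task="binary")
fpr, tpr, roc_thresholds = roc(s, y, task="binary")
print(float(ap), float(auc))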
dpacman/classifier/model.py CHANGED
@@ -6,7 +6,7 @@ import torch
6
  from torch import nn
7
  from lightning import LightningModule
8
  from dpacman.utils.models import set_seed
9
- from .loss import calculate_loss, auprc_zeros_vs_ones_from_logits
10
 
11
  set_seed()
12
 
@@ -211,9 +211,9 @@ class BindPredictor(LightningModule):
211
  Training step taken by PyTorch-Lightning trainer. Uses batch returned by data collator.
212
  Colator returns a dictionary with:
213
  "binder_emb" # [B, Lb_max, Db]
214
- "binder_mask" # [B, Lb_max]
215
  "glm_emb" # [B, Lg_max, Dg]
216
- "glm_mask" # [B, Lg_max]
217
  "labels" # [B, Lg_max]
218
  "ID"
219
  "tr_sequence"
@@ -233,14 +233,20 @@ class BindPredictor(LightningModule):
233
  batch_size=logits.size(0),
234
  )
235
 
236
- # ---- AUPRC on labels in {0, >0.99} only ----
237
- ap, n_pos, n_neg = auprc_zeros_vs_ones_from_logits(
 
 
 
238
  logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
239
  )
240
  # per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
241
  self.log("train/auprc_0v1",
242
  ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
243
  on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
 
 
 
244
  # (optional) also log class counts so you can sanity-check balance
245
  self.log("train/n_pos_0v1", float(n_pos), on_step=False, on_epoch=True, sync_dist=True)
246
  self.log("train/n_neg_0v1", float(n_neg), on_step=False, on_epoch=True, sync_dist=True)
@@ -260,6 +266,22 @@ class BindPredictor(LightningModule):
260
  prog_bar=True,
261
  batch_size=logits.size(0),
262
  )
263
  return loss
264
 
265
  def test_step(self, batch, batch_idx):
@@ -270,6 +292,21 @@ class BindPredictor(LightningModule):
270
  self.log(
271
  "test/loss", loss, on_step=False, on_epoch=True, batch_size=logits.size(0)
272
  )
273
  return loss
274
 
275
  def on_before_optimizer_step(self, optimizer):
 
6
  from torch import nn
7
  from lightning import LightningModule
8
  from dpacman.utils.models import set_seed
9
+ from .loss import calculate_loss, auprc_zeros_vs_ones_from_logits, auroc_zeros_vs_ones_from_logits
10
 
11
  set_seed()
12
 
 
211
  Training step taken by PyTorch-Lightning trainer. Uses batch returned by data collator.
212
Collator returns a dictionary with:
213
  "binder_emb" # [B, Lb_max, Db]
214
+ "binder_kpm" # [B, Lb_max]
215
  "glm_emb" # [B, Lg_max, Dg]
216
+ "glm_kpm" # [B, Lg_max]
217
  "labels" # [B, Lg_max]
218
  "ID"
219
  "tr_sequence"
 
233
  batch_size=logits.size(0),
234
  )
235
 
236
+ # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
237
+ ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
238
+ logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
239
+ )
240
+ auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
241
  logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
242
  )
243
  # per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
244
  self.log("train/auprc_0v1",
245
  ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
246
  on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
247
+ self.log("train/auroc_0v1",
248
+ auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
249
+ on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
250
  # (optional) also log class counts so you can sanity-check balance
251
  self.log("train/n_pos_0v1", float(n_pos), on_step=False, on_epoch=True, sync_dist=True)
252
  self.log("train/n_neg_0v1", float(n_neg), on_step=False, on_epoch=True, sync_dist=True)
 
266
  prog_bar=True,
267
  batch_size=logits.size(0),
268
  )
269
+
270
+ # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
271
+ ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
272
+ logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
273
+ )
274
+ auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
275
+ logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
276
+ )
277
+ # per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
278
+ self.log("val/auprc_0v1",
279
+ ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
280
+ on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
281
+ self.log("val/auroc_0v1",
282
+ auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
283
+ on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
284
+
285
  return loss
286
 
287
  def test_step(self, batch, batch_idx):
 
292
  self.log(
293
  "test/loss", loss, on_step=False, on_epoch=True, batch_size=logits.size(0)
294
  )
295
+
296
+ # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
297
+ ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
298
+ logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
299
+ )
300
+ auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
301
+ logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
302
+ )
303
+ # per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
304
+ self.log("test/auprc_0v1",
305
+ ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
306
+ on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
307
+ self.log("test/auroc_0v1",
308
+ auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
309
+ on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
310
  return loss
311
 
312
  def on_before_optimizer_step(self, optimizer):
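The on_before_optimizer_step / on_after_backward hooks added in baseline.py (and kept in model.py's trailing context) log a global gradient norm. A standalone sketch of that computation (plain PyTorch; the model and loss below are placeholders, not project code):

# Hypothetical standalone version of the gradient-norm logging
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

grads = [p.grad.detach().float().norm(2) for p in model.parameters() if p.grad is not None]
total_norm = torch.norm(torch.stack(grads), p=2) if grads else torch.tensor(0.0)
print(float(total_norm))  # inside the LightningModule this value goes to self.log("train/grad_norm", ...)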
dpacman/classifier/model_tmp/__init__.py DELETED
File without changes
dpacman/classifier/model_tmp/clustering_data.py DELETED
@@ -1,475 +0,0 @@
1
- #!/usr/bin/env python3
2
- import argparse
3
- import numpy as np
4
- import pandas as pd
5
- from pathlib import Path
6
- import random
7
- import sys
8
- import subprocess
9
- from collections import defaultdict
10
-
11
- # ─────────────────────────────────────────────────────────────────────────
12
- # Original helpers (kept; some lightly edited/commented where needed)
13
- # ─────────────────────────────────────────────────────────────────────────
14
-
15
-
16
- def read_ids_file(p):
17
- p = Path(p)
18
- if not p.exists():
19
- raise FileNotFoundError(f"IDs file not found: {p}")
20
- return [line.strip() for line in p.open() if line.strip()]
21
-
22
-
23
- def split_embeddings(emb_path, ids_path, out_dir, prefix):
24
- out_dir = Path(out_dir)
25
- out_dir.mkdir(parents=True, exist_ok=True)
26
-
27
- if not Path(emb_path).exists():
28
- raise FileNotFoundError(f"Embedding file not found: {emb_path}")
29
- if not Path(ids_path).exists():
30
- raise FileNotFoundError(f"IDs file not found: {ids_path}")
31
-
32
- if emb_path.endswith(".npz"):
33
- data = np.load(emb_path, allow_pickle=True)
34
- if "embeddings" in data:
35
- emb = data["embeddings"]
36
- else:
37
- raise ValueError(f"{emb_path} missing 'embeddings' key")
38
- else:
39
- emb = np.load(emb_path)
40
-
41
- ids = read_ids_file(ids_path)
42
- if len(ids) != emb.shape[0]:
43
- print(
44
- f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}",
45
- file=sys.stderr,
46
- )
47
-
48
- mapping = {}
49
- for i, ident in enumerate(ids):
50
- if i >= emb.shape[0]:
51
- print(
52
- f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr
53
- )
54
- continue
55
- arr = emb[i]
56
- out_file = out_dir / f"{prefix}_{ident}.npy"
57
- np.save(out_file, arr)
58
- mapping[ident] = str(out_file)
59
- return mapping
60
-
61
-
62
- def extract_symbol_from_tf_id(full_id: str) -> str:
63
- """
64
- Given a TF embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN',
65
- return the gene symbol uppercase (e.g., 'ZBTB5').
66
- """
67
- if "|" in full_id:
68
- try:
69
- # format sp|Accession|SYMBOL_HUMAN
70
- genepart = full_id.split("|")[2]
71
- except IndexError:
72
- genepart = full_id
73
- else:
74
- genepart = full_id
75
- symbol = genepart.split("_")[0]
76
- return symbol.upper()
77
-
78
-
79
- def build_tf_symbol_map(tf_map):
80
- """
81
- Build mapping gene_symbol -> list of embedding paths.
82
- """
83
- symbol_map = {}
84
- for full_id, path in tf_map.items():
85
- symbol = extract_symbol_from_tf_id(full_id)
86
- symbol_map.setdefault(symbol, []).append(path)
87
- return symbol_map
88
-
89
-
90
- def tf_key_from_path(path: str) -> str:
91
- """
92
- Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract normalized symbol 'ZBTB5'.
93
- """
94
- stem = Path(path).stem # e.g., tf_sp|O15062|ZBTB5_HUMAN
95
- # remove leading prefix if present (tf_)
96
- if "_" in stem:
97
- _, rest = stem.split("_", 1)
98
- else:
99
- rest = stem
100
- return extract_symbol_from_tf_id(rest)
101
-
102
-
103
- def dna_key_from_path(path: str) -> str:
104
- """
105
- Given .../dna_peak42.npy -> 'peak42'
106
- """
107
- stem = Path(path).stem
108
- if "_" in stem:
109
- _, rest = stem.split("_", 1)
110
- else:
111
- rest = stem
112
- return rest
113
-
114
-
115
- # ─────────────────────────────────────────────────────────────────────────
116
- # New helpers for MMseqs clustering & cluster-level splitting
117
- # ─────────────────────────────────────────────────────────────────────────
118
-
119
-
120
- def write_dna_fasta(df: pd.DataFrame, out_fasta: Path) -> None:
121
- """
122
- Write unique DNA sequences to FASTA using dna_id as header.
123
- Requires df with columns: dna_id, dna_sequence
124
- """
125
- uniq = df[["dna_id", "dna_sequence"]].drop_duplicates()
126
- with open(out_fasta, "w") as f:
127
- for _, row in uniq.iterrows():
128
- did = row["dna_id"]
129
- seq = str(row["dna_sequence"]).upper().replace(" ", "").replace("\n", "")
130
- f.write(f">{did}\n{seq}\n")
131
-
132
-
133
- def run_mmseqs_easy_cluster(
134
- mmseqs_bin: str,
135
- fasta: Path,
136
- out_prefix: Path,
137
- tmp_dir: Path,
138
- min_seq_id: float,
139
- coverage: float,
140
- cov_mode: int,
141
- ) -> Path:
142
- """
143
- Runs mmseqs easy-cluster on nucleotide sequences.
144
- Returns the path to a clusters TSV file (creating it if the default one isn't present).
145
- """
146
- tmp_dir.mkdir(parents=True, exist_ok=True)
147
- out_prefix.parent.mkdir(parents=True, exist_ok=True)
148
-
149
- cmd = [
150
- mmseqs_bin,
151
- "easy-cluster",
152
- str(fasta),
153
- str(out_prefix),
154
- str(tmp_dir),
155
- "--min-seq-id",
156
- str(min_seq_id),
157
- "-c",
158
- str(coverage),
159
- "--cov-mode",
160
- str(cov_mode),
161
- # You can add performance flags here if needed, e.g.:
162
- # "--threads", "8"
163
- ]
164
- print("[i] Running:", " ".join(cmd), flush=True)
165
- subprocess.run(cmd, check=True)
166
-
167
- # MMseqs easy-cluster typically writes <out_prefix>_cluster.tsv
168
- default_tsv = Path(str(out_prefix) + "_cluster.tsv")
169
- if default_tsv.exists():
170
- print(f"[i] Found cluster TSV: {default_tsv}")
171
- return default_tsv
172
-
173
- # Fallback: try createtsv if default is missing
174
- # This requires the internal DBs. easy-cluster creates DBs alongside out_prefix.
175
- # We'll try to locate them and emit a TSV.
176
- in_db = Path(str(out_prefix) + "_query")
177
- cl_db = Path(str(out_prefix) + "_cluster")
178
- out_tsv = Path(str(out_prefix) + "_fallback_cluster.tsv")
179
- if in_db.exists() and cl_db.exists():
180
- cmd2 = [
181
- mmseqs_bin,
182
- "createtsv",
183
- str(in_db),
184
- str(in_db),
185
- str(cl_db),
186
- str(out_tsv),
187
- ]
188
- print("[i] Creating TSV via createtsv:", " ".join(cmd2), flush=True)
189
- subprocess.run(cmd2, check=True)
190
- if out_tsv.exists():
191
- return out_tsv
192
-
193
- raise FileNotFoundError(
194
- "Could not locate clusters TSV from mmseqs. "
195
- "Expected {default_tsv} or createtsv fallback."
196
- )
197
-
198
-
199
- def parse_mmseqs_clusters(tsv_path: Path) -> dict:
200
- """
201
- Parse MMseqs cluster TSV (rep \t member). Returns dna_id -> cluster_rep_id
202
- """
203
- mapping = {}
204
- with open(tsv_path) as f:
205
- for line in f:
206
- parts = line.rstrip("\n").split("\t")
207
- if len(parts) < 2:
208
- continue
209
- rep, member = parts[0], parts[1]
210
- mapping[member] = rep
211
- # Some TSVs include rep->rep; if not, ensure rep is mapped to itself:
212
- if rep not in mapping:
213
- mapping[rep] = rep
214
- return mapping
215
-
216
-
217
- def assign_clusters_to_splits(
218
- cluster_rep_to_members: dict, val_frac: float, test_frac: float, seed: int = 42
219
- ):
220
- """
221
- cluster_rep_to_members: dict[rep] = [members...]
222
- Returns: dict with keys 'train','val','test' mapping to sets of dna_id.
223
- Ensures all members of a cluster go to the same split.
224
- """
225
- rng = random.Random(seed)
226
- reps = list(cluster_rep_to_members.keys())
227
- rng.shuffle(reps)
228
-
229
- # Greedy-ish fill by total member counts to match desired fractions.
230
- total = sum(len(cluster_rep_to_members[r]) for r in reps)
231
- target_val = int(round(total * val_frac))
232
- target_test = int(round(total * test_frac))
233
- cur_val = cur_test = 0
234
-
235
- val_ids, test_ids, train_ids = set(), set(), set()
236
- for rep in reps:
237
- members = cluster_rep_to_members[rep]
238
- c = len(members)
239
- # Fill val first, then test, then train
240
- if cur_val + c <= target_val:
241
- val_ids.update(members)
242
- cur_val += c
243
- elif cur_test + c <= target_test:
244
- test_ids.update(members)
245
- cur_test += c
246
- else:
247
- train_ids.update(members)
248
-
249
- return {"train": train_ids, "val": val_ids, "test": test_ids}
250
-
251
-
252
- # ─────────────────────────────────────────────────────────────────────────
253
- # Main
254
- # ─────────────────────────────────────────────────────────────────────────
255
-
256
-
257
- def main():
258
- parser = argparse.ArgumentParser(
259
- description="Build TF-DNA pair lists with MMseqs clustering on DNA to prevent split leakage."
260
- )
261
- parser.add_argument(
262
- "--final_csv", required=True, help="final.csv with TF_id and dna_sequence"
263
- )
264
- parser.add_argument(
265
- "--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)"
266
- )
267
- parser.add_argument(
268
- "--dna_ids", required=True, help="IDs file for DNA embeddings (peak*.ids)"
269
- )
270
- parser.add_argument(
271
- "--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)"
272
- )
273
- parser.add_argument(
274
- "--tf_ids", required=True, help="IDs file for TF embeddings (sp|... ids)"
275
- )
276
- parser.add_argument("--out_dir", required=True, help="Output directory")
277
- parser.add_argument("--seed", type=int, default=42)
278
-
279
- # NEW: MMseqs options & split fractions
280
- parser.add_argument("--mmseqs_bin", default="mmseqs", help="Path to mmseqs binary")
281
- parser.add_argument(
282
- "--min_seq_id", type=float, default=0.9, help="MMseqs --min-seq-id"
283
- )
284
- parser.add_argument(
285
- "--cov", type=float, default=0.8, help="MMseqs -c coverage fraction"
286
- )
287
- parser.add_argument(
288
- "--cov_mode",
289
- type=int,
290
- default=1,
291
- help="MMseqs --cov-mode (1 = coverage of target)",
292
- )
293
- parser.add_argument("--val_frac", type=float, default=0.10)
294
- parser.add_argument("--test_frac", type=float, default=0.10)
295
- parser.add_argument(
296
- "--tmp_dir", default=None, help="MMseqs tmp dir (defaults to out_dir/tmp)"
297
- )
298
- args = parser.parse_args()
299
-
300
- random.seed(args.seed)
301
- out_dir = Path(args.out_dir)
302
- out_dir.mkdir(parents=True, exist_ok=True)
303
-
304
- # Load final.csv
305
- df = pd.read_csv(args.final_csv, dtype=str)
306
- if "TF_id" not in df.columns or "dna_sequence" not in df.columns:
307
- raise RuntimeError("final.csv must have columns TF_id and dna_sequence")
308
-
309
- # Assign dna_id (unique per dna_sequence)
310
- unique_seqs = df["dna_sequence"].drop_duplicates().tolist()
311
- seq_to_id = {seq: f"peak{i}" for i, seq in enumerate(unique_seqs)}
312
- df["dna_id"] = df["dna_sequence"].map(seq_to_id)
313
- enriched_csv = out_dir / "final_with_dna_id.csv"
314
- df.to_csv(enriched_csv, index=False)
315
- print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}")
316
-
317
- # Split embeddings into per-item files (unchanged)
318
- print(
319
- f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}"
320
- )
321
- dna_map = split_embeddings(
322
- args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna"
323
- )
324
- print(
325
- f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})"
326
- )
327
- print(
328
- f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}"
329
- )
330
- tf_map = split_embeddings(
331
- args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf"
332
- )
333
- print(
334
- f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})"
335
- )
336
-
337
- # Build gene-symbol normalized map
338
- tf_symbol_map = build_tf_symbol_map(tf_map)
339
- print(f"[i] TF symbol map keys (sample): {list(tf_symbol_map.keys())[:30]}")
340
-
341
- # Diagnostic overlaps
342
- norm_tf_in_final = set(t.split("_seq")[0].upper() for t in df["TF_id"].unique())
343
- available_tf_symbols = set(tf_symbol_map.keys())
344
- intersect_tf = norm_tf_in_final & available_tf_symbols
345
- print(f"[i] Unique normalized TF symbols in final.csv: {len(norm_tf_in_final)}")
346
- print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}")
347
- print(f"[i] Intersection count: {len(intersect_tf)}")
348
- if len(intersect_tf) == 0:
349
- print(
350
- "[ERROR] No overlap between normalized TF_id and TF embedding symbols.",
351
- file=sys.stderr,
352
- )
353
- print(
354
- "Sample normalized TFs from final.csv:",
355
- sorted(list(norm_tf_in_final))[:30],
356
- file=sys.stderr,
357
- )
358
- print(
359
- "Sample available TF symbols:",
360
- sorted(list(available_tf_symbols))[:30],
361
- file=sys.stderr,
362
- )
363
- sys.exit(1)
364
-
365
- dna_ids_final = set(df["dna_id"].unique())
366
- available_dna_ids = set(dna_map.keys())
367
- intersect_dna = dna_ids_final & available_dna_ids
368
- print(
369
- f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. Intersection: {len(intersect_dna)}"
370
- )
371
- if len(intersect_dna) == 0:
372
- print("[ERROR] No overlap on DNA ids.", file=sys.stderr)
373
- sys.exit(1)
374
-
375
- # ── NEW: MMseqs clustering on DNA sequences ───────────────────────────
376
- fasta_path = out_dir / "dna_unique.fasta"
377
- write_dna_fasta(df, fasta_path)
378
- print(
379
- f"[i] Wrote FASTA with {df['dna_id'].nunique()} unique sequences → {fasta_path}"
380
- )
381
-
382
- tmp_dir = Path(args.tmp_dir) if args.tmp_dir else (out_dir / "mmseqs_tmp")
383
- cluster_prefix = out_dir / "mmseqs_dna_clusters"
384
- clusters_tsv = run_mmseqs_easy_cluster(
385
- mmseqs_bin=args.mmseqs_bin,
386
- fasta=fasta_path,
387
- out_prefix=cluster_prefix,
388
- tmp_dir=tmp_dir,
389
- min_seq_id=args.min_seq_id,
390
- coverage=args.cov,
391
- cov_mode=args.cov_mode,
392
- )
393
-
394
- # Parse clusters
395
- member_to_rep = parse_mmseqs_clusters(clusters_tsv) # dna_id -> rep_id
396
- # Build rep -> members list
397
- rep_to_members = defaultdict(list)
398
- for member, rep in member_to_rep.items():
399
- rep_to_members[rep].append(member)
400
-
401
- print(f"[i] Parsed {len(rep_to_members)} clusters from {clusters_tsv}")
402
- clusters_table = []
403
- for rep, members in rep_to_members.items():
404
- for m in members:
405
- clusters_table.append((m, rep))
406
- clusters_df = pd.DataFrame(clusters_table, columns=["dna_id", "cluster_id"])
407
- clusters_df.to_csv(out_dir / "clusters.tsv", sep="\t", index=False)
408
- print(f"[i] Wrote clusters mapping → {out_dir / 'clusters.tsv'}")
409
-
410
- # Attach cluster_id back to final df
411
- df = df.merge(clusters_df, on="dna_id", how="left")
412
- df.to_csv(out_dir / "final_with_dna_id_and_cluster.csv", index=False)
413
- print(f"[i] Wrote {out_dir / 'final_with_dna_id_and_cluster.csv'}")
414
-
415
- # Assign entire clusters to splits
416
- splits = assign_clusters_to_splits(
417
- rep_to_members, val_frac=args.val_frac, test_frac=args.test_frac, seed=args.seed
418
- )
419
- for k in ["train", "val", "test"]:
420
- print(f"[i] {k}: {len(splits[k])} dna_ids")
421
-
422
- # ── Build positive pairs only, per split (NO negatives) ───────────────
423
- positives_by_split = {"train": [], "val": [], "test": []}
424
- # Build a quick dna_id -> embedding path map
425
- dnaid_to_path = {did: path for did, path in dna_map.items()}
426
-
427
- pos_count = 0
428
- for _, row in df.iterrows():
429
- tf_raw = row["TF_id"]
430
- tf_symbol = tf_raw.split("_seq")[0].upper()
431
- dnaid = row["dna_id"]
432
- if (tf_symbol not in tf_symbol_map) or (dnaid not in dnaid_to_path):
433
- continue
434
- tf_embedding_path = tf_symbol_map[tf_symbol][0] # first embedding per symbol
435
-
436
- # decide split by dna_id cluster assignment
437
- if dnaid in splits["train"]:
438
- positives_by_split["train"].append(
439
- (tf_embedding_path, dnaid_to_path[dnaid], 1)
440
- )
441
- elif dnaid in splits["val"]:
442
- positives_by_split["val"].append(
443
- (tf_embedding_path, dnaid_to_path[dnaid], 1)
444
- )
445
- elif dnaid in splits["test"]:
446
- positives_by_split["test"].append(
447
- (tf_embedding_path, dnaid_to_path[dnaid], 1)
448
- )
449
- pos_count += 1
450
-
451
- print(
452
- f"[i] Constructed positives across splits (rows in final.csv iterated: {len(df)})"
453
- )
454
- for k in ["train", "val", "test"]:
455
- print(f"[i] positives[{k}] = {len(positives_by_split[k])}")
456
-
457
- # # OLD: negatives (kept commented)
458
- # negatives = []
459
- # print(f"[i] Sampled {len(negatives)} negatives (neg_per_positive not used)")
460
-
461
- # Emit split-specific pair lists
462
- for split in ["train", "val", "test"]:
463
- out_tsv = out_dir / f"pair_list_{split}.tsv"
464
- with open(out_tsv, "w") as f:
465
- for binder_path, glm_path, label in positives_by_split[
466
- split
467
- ]: # + negatives if you add later
468
- f.write(f"{binder_path}\t{glm_path}\t{label}\n")
469
- print(f"[i] Wrote {len(positives_by_split[split])} examples to {out_tsv}")
470
-
471
- print("✅ Done. Cluster-aware splits ready.")
472
-
473
-
474
- if __name__ == "__main__":
475
- main()
 
 
dpacman/classifier/model_tmp/compress_embeddings.py DELETED
@@ -1,62 +0,0 @@
1
- # compress_embeddings.py
2
- # USAGE: python compress_embeddings.py --input_glob "/path/to/esm_embeddings/*.npy" --output_dir "/path/to/compressed_embeddings" --esm_dim 1280 --out_dim 256
3
- # --------------
4
- import os
5
- import glob
6
- import numpy as np
7
- import torch
8
- from torch import nn
9
-
10
-
11
- class EmbeddingCompressor(nn.Module):
12
- def __init__(self, input_dim: int = 1280, output_dim: int = 256):
13
- super().__init__()
14
- self.fc = nn.Linear(input_dim, output_dim)
15
-
16
- def forward(self, x: torch.Tensor) -> torch.Tensor:
17
- """
18
- x: (batch, L, input_dim) or (L, input_dim)
19
- returns: (batch, output_dim) or (output_dim,)
20
- """
21
- if x.dim() == 2:
22
- # single example: mean over tokens
23
- x = x.mean(dim=0, keepdim=True) # → (1, input_dim)
24
- else:
25
- # batch: mean over tokens
26
- x = x.mean(dim=1) # → (batch, input_dim)
27
- return self.fc(x) # → (batch, output_dim)
28
-
29
-
30
- def compress_file(in_path: str, out_path: str, model: EmbeddingCompressor):
31
- arr = np.load(in_path) # shape (L, D) or (batch, L, D)
32
- tensor = torch.from_numpy(arr).float()
33
- with torch.no_grad():
34
- compressed = model(tensor) # → (batch, 256)
35
- out = compressed.cpu().numpy()
36
- np.save(out_path, out)
37
- print(f"Saved {out_path}")
38
-
39
-
40
- if __name__ == "__main__":
41
- import argparse
42
-
43
- parser = argparse.ArgumentParser(description="Compress ESM embeddings to 256­d")
44
- parser.add_argument(
45
- "--input_glob",
46
- type=str,
47
- required=True,
48
- help="Glob for your .npy ESM embeddings (e.g. data/esm_*.npy)",
49
- )
50
- parser.add_argument("--output_dir", type=str, required=True)
51
- parser.add_argument("--esm_dim", type=int, default=1280)
52
- parser.add_argument("--out_dim", type=int, default=256)
53
- args = parser.parse_args()
54
-
55
- os.makedirs(args.output_dir, exist_ok=True)
56
- compressor = EmbeddingCompressor(args.esm_dim, args.out_dim)
57
- compressor.eval()
58
-
59
- for fn in glob.glob(args.input_glob):
60
- base = os.path.basename(fn).replace(".npy", "_256.npy")
61
- out_path = os.path.join(args.output_dir, base)
62
- compress_file(fn, out_path, compressor)
 
dpacman/classifier/model_tmp/compute_embeddings.py DELETED
@@ -1,612 +0,0 @@
1
- """
2
- Plug-and-play embedding extraction for:
3
- • Chromosome sequences (from raw UCSC JSON)
4
- • TF sequences (transcription_factors.fasta)
5
-
6
- Usage example (DNA + protein in one go):
7
- module load miniconda/24.7.1
8
- conda activate dpacman
9
- python dpacman/data/compute_embeddings.py \
10
- --genome-json-dir ../data_files/raw/genomes/hg38 \
11
- --tf-fasta ../data_files/processed/tfclust/hg38_tf/transcription_factors.fasta \
12
- --chrom-model caduceus \
13
- --tf-model esm-dbp \
14
- --out-dir ../data_files/processed/tfclust/hg38_tf/embeddings \
15
- --device cuda
16
- """
17
-
18
- import os
19
- import re
20
- import argparse
21
- import json
22
- import numpy as np
23
- from pathlib import Path
24
- import torch
25
- from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, pipeline
26
- import esm
27
- from Bio import SeqIO
28
- import time
29
-
30
- # ---- model wrappers ----
31
-
32
-
33
- class CaduceusEmbedder:
34
- def __init__(self, device, chunk_size=131_072, overlap=0):
35
- """
36
- device: 'cpu' or 'cuda'
37
- chunk_size: max bases (and thus tokens) to send in one forward pass
38
- overlap: how many bases each window overlaps the previous; 0 = no overlap
39
- """
40
- model_name = "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16"
41
- self.tokenizer = AutoTokenizer.from_pretrained(
42
- model_name, trust_remote_code=True
43
- )
44
- self.model = (
45
- AutoModel.from_pretrained(model_name, trust_remote_code=True)
46
- .to(device)
47
- .eval()
48
- )
49
- self.device = device
50
- self.chunk_size = chunk_size
51
- self.step = chunk_size - overlap
52
-
53
- def embed(self, seqs):
54
- """
55
- seqs: List[str] of DNA sequences (each <= chunk_size for this test)
56
- returns: np.ndarray of shape (N, L, D), raw per‐token embeddings
57
- """
58
- # outputs = []
59
- # for seq in seqs:
60
- # # --- new: raw per‐token embeddings in one shot ---
61
- # toks = self.tokenizer(
62
- # seq,
63
- # return_tensors="pt",
64
- # padding=False,
65
- # truncation=True,
66
- # max_length=self.chunk_size
67
- # ).to(self.device)
68
- # with torch.no_grad():
69
- # out = self.model(**toks).last_hidden_state # (1, L, D)
70
- # outputs.append(out.cpu().numpy()[0]) # (L, D)
71
-
72
- # return np.stack(outputs, axis=0) # (N, L, D)
73
- outputs = []
74
- for seq in seqs:
75
- toks = self.tokenizer(
76
- seq,
77
- return_tensors="pt",
78
- padding=False,
79
- truncation=True,
80
- max_length=self.chunk_size,
81
- ).to(self.device)
82
- with torch.no_grad():
83
- out = self.model(**toks).last_hidden_state # (1, L, D)
84
- outputs.append(out.cpu().numpy()[0]) # (L, D)
85
- return outputs # list of variable-length (L_i, D) arrays
86
-
87
- def benchmark(self, lengths=None):
88
- """
89
- Time embedding on single-sequence of various lengths.
90
- By default tests [5K,10K,50K,100K,chunk_size].
91
- """
92
- tests = lengths or [5_000, 10_000, 50_000, 100_000, self.chunk_size]
93
- print(f"→ Benchmarking Caduceus on device={self.device}")
94
- for sz in tests:
95
- seq = "A" * sz
96
- # Warm-up
97
- _ = self.embed([seq])
98
- if self.device != "cpu":
99
- torch.cuda.synchronize()
100
- t0 = time.perf_counter()
101
- _ = self.embed([seq])
102
- if self.device != "cpu":
103
- torch.cuda.synchronize()
104
- t1 = time.perf_counter()
105
- print(f" length={sz:6,d} time={(t1-t0)*1000:7.1f} ms")
106
-
107
-
108
- class SegmentNTEmbedder:
109
- def __init__(self, device):
110
- self.tokenizer = AutoTokenizer.from_pretrained(
111
- "InstaDeepAI/segment_nt", trust_remote_code=True
112
- )
113
- self.model = (
114
- AutoModel.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True)
115
- .to(device)
116
- .eval()
117
- )
118
- self.device = device
119
-
120
- def _adjust_length(self, input_ids):
121
- bs, L = input_ids.shape
122
- excl = L - 1
123
- remainder = (excl) % 4
124
- if remainder != 0:
125
- pad_needed = 4 - remainder
126
- pad_tensor = torch.full(
127
- (bs, pad_needed),
128
- self.tokenizer.pad_token_id,
129
- dtype=input_ids.dtype,
130
- device=input_ids.device,
131
- )
132
- input_ids = torch.cat([input_ids, pad_tensor], dim=1)
133
- return input_ids
134
-
135
- def embed(self, seqs, batch_size=16):
136
- """
137
- seqs: List[str]
138
- Returns: np.ndarray of shape (N, D)
139
- """
140
- all_embeddings = []
141
- for i in range(0, len(seqs), batch_size):
142
- batch_seqs = seqs[i : i + batch_size]
143
- encoded = self.tokenizer.batch_encode_plus(
144
- batch_seqs,
145
- return_tensors="pt",
146
- padding=True,
147
- truncation=True,
148
- )
149
- input_ids = encoded["input_ids"].to(self.device) # (B, L)
150
- attention_mask = input_ids != self.tokenizer.pad_token_id
151
-
152
- input_ids = self._adjust_length(input_ids)
153
- attention_mask = input_ids != self.tokenizer.pad_token_id
154
-
155
- with torch.no_grad():
156
- outs = self.model(
157
- input_ids,
158
- attention_mask=attention_mask,
159
- output_hidden_states=True,
160
- return_dict=True,
161
- )
162
- if hasattr(outs, "hidden_states") and outs.hidden_states is not None:
163
- last_hidden = outs.hidden_states[-1] # (B, L, D)
164
- else:
165
- last_hidden = outs.last_hidden_state # fallback
166
-
167
- # Exclude CLS token if present (assume first token) and pool
168
- pooled = last_hidden[:, 1:, :].mean(dim=1) # (B, D)
169
- all_embeddings.append(pooled.cpu().numpy())
170
-
171
- # release fragmentation
172
- torch.cuda.empty_cache()
173
-
174
- return np.vstack(all_embeddings) # (N, D)
175
-
176
-
177
- class DNABertEmbedder:
178
- def __init__(self, device):
179
- self.tokenizer = AutoTokenizer.from_pretrained(
180
- "zhihan1996/DNA_bert_6", trust_remote_code=True
181
- )
182
- self.model = AutoModel.from_pretrained(
183
- "zhihan1996/DNA_bert_6", trust_remote_code=True
184
- ).to(device)
185
- self.device = device
186
-
187
- def embed(self, seqs):
188
- embs = []
189
- for s in seqs:
190
- tokens = self.tokenizer(s, return_tensors="pt", padding=True)[
191
- "input_ids"
192
- ].to(self.device)
193
- with torch.no_grad():
194
- out = self.model(tokens).last_hidden_state.mean(1)
195
- embs.append(out.cpu().numpy())
196
- return np.vstack(embs)
197
-
198
-
199
- class NucleotideTransformerEmbedder:
200
- def __init__(self, device):
201
- # HF “feature-extraction” returns a list of (L, D) arrays for each input
202
- # device: “cpu” or “cuda”
203
- self.pipe = pipeline(
204
- "feature-extraction",
205
- model="InstaDeepAI/nucleotide-transformer-500m-1000g",
206
- device=(
207
- -1 if device == "cpu" else 0
208
- ),  # HF uses -1 for CPU, 0 for GPU
209
- )
210
-
211
- def embed(self, seqs):
212
- """
213
- seqs: List[str] of raw DNA sequences
214
- returns: (N, D) array, one D-dim vector per sequence
215
- """
216
- all_embeddings = self.pipe(seqs, truncation=True, padding=True)
217
- # all_embeddings is a List of shape (L, D) arrays
218
- pooled = [np.mean(x, axis=0) for x in all_embeddings]
219
- return np.vstack(pooled)
220
-
221
-
222
- # class ESMEmbedder:
223
- # def __init__(self, device):
224
- # self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
225
- # self.batch_converter = self.alphabet.get_batch_converter()
226
- # self.model.to(device).eval()
227
- # self.device = device
228
-
229
- # def embed(self, seqs):
230
- # batch = [(str(i), seq) for i, seq in enumerate(seqs)]
231
- # _, _, toks = self.batch_converter(batch)
232
- # toks = toks.to(self.device)
233
- # with torch.no_grad():
234
- # results = self.model(toks, repr_layers=[33], return_contacts=False)
235
- # reps = results["representations"][33]
236
- # return reps[:, 1:-1].mean(1).cpu().numpy()
237
-
238
-
239
- class ESMEmbedder:
240
- def __init__(self, device, model_name="esm2_t33_650M_UR50D"):
241
- # Try to load the specified ESM-2 model; fallback to esm1b if missing
242
- self.device = device
243
- try:
244
- self.model, self.alphabet = getattr(esm.pretrained, model_name)()
245
- self.is_esm2 = model_name.lower().startswith("esm2")
246
- except AttributeError:
247
- # fallback to ESM-1b
248
- self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
249
- self.is_esm2 = False
250
- self.batch_converter = self.alphabet.get_batch_converter()
251
- self.model.to(device).eval()
252
- # determine max length: esm2 models vary; use default 1024 for esm1b
253
- self.max_len = (
254
- 4096 if self.is_esm2 else 1024
255
- ) # adjust if your esm2 variant has explicit limit
256
- # for chunking: reserve 2 tokens if model uses BOS/EOS
257
- self.chunk_size = self.max_len - 2
258
- self.overlap = self.chunk_size // 4 # 25% overlap to smooth boundaries
259
-
260
- def _chunk_sequence(self, seq):
261
- """
262
- Return list of possibly overlapping chunks of seq, each <= chunk_size.
263
- """
264
- if len(seq) <= self.chunk_size:
265
- return [seq]
266
- step = self.chunk_size - self.overlap
267
- chunks = []
268
- for i in range(0, len(seq), step):
269
- chunk = seq[i : i + self.chunk_size]
270
- if not chunk:
271
- break
272
- chunks.append(chunk)
273
- return chunks
274
-
275
- def embed(self, seqs):
276
- """
277
- seqs: List[str] of protein sequences.
278
- Returns: np.ndarray of shape (N, D) pooled per-sequence embeddings.
279
- """
280
- all_embeddings = []
281
- for i, seq in enumerate(seqs):
282
- chunks = self._chunk_sequence(seq)
283
- chunk_vecs = []
284
- # process chunks in batch if small number, else sequentially
285
- for chunk in chunks:
286
- batch = [(str(i), chunk)]
287
- _, _, toks = self.batch_converter(batch)
288
- toks = toks.to(self.device)
289
- with torch.no_grad():
290
- results = self.model(toks, repr_layers=[33], return_contacts=False)
291
- reps = results["representations"][33] # (1, L, D)
292
- # remove BOS/EOS if present: take 1:-1 if length permits
293
- if reps.size(1) > 2:
294
- rep = reps[:, 1:-1].mean(1) # (1, D)
295
- else:
296
- rep = reps.mean(1) # fallback
297
- chunk_vecs.append(rep.squeeze(0)) # (D,)
298
- if len(chunk_vecs) == 1:
299
- seq_vec = chunk_vecs[0]
300
- else:
301
- # average chunk vectors
302
- stacked = torch.stack(chunk_vecs, dim=0) # (num_chunks, D)
303
- seq_vec = stacked.mean(0)
304
- all_embeddings.append(seq_vec.cpu().numpy())
305
- return np.vstack(all_embeddings) # (N, D)
306
-
307
-
308
- # class ESMDBPEmbedder:
309
- # def __init__(self, device):
310
- # base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
311
- # model_path = (
312
- # Path(__file__).resolve().parent.parent
313
- # / "pretrained" / "ESM-DBP" / "ESM-DBP.model"
314
- # )
315
- # checkpoint = torch.load(model_path, map_location="cpu")
316
- # clean_sd = {}
317
- # for k, v in checkpoint.items():
318
- # clean_sd[k.replace("module.", "")] = v
319
- # result = base_model.load_state_dict(clean_sd, strict=False)
320
- # if result.missing_keys:
321
- # print(f"[ESMDBP] missing keys: {result.missing_keys}")
322
- # if result.unexpected_keys:
323
- # print(f"[ESMDBP] unexpected keys: {result.unexpected_keys}")
324
-
325
- # self.model = base_model.to(device).eval()
326
- # self.alphabet = alphabet
327
- # self.batch_converter = alphabet.get_batch_converter()
328
- # self.device = device
329
-
330
- # def embed(self, seqs):
331
- # batch = [(str(i), seq) for i, seq in enumerate(seqs)]
332
- # _, _, toks = self.batch_converter(batch)
333
- # toks = toks.to(self.device)
334
- # with torch.no_grad():
335
- # out = self.model(toks, repr_layers=[33], return_contacts=False)
336
- # reps = out["representations"][33]
337
- # # skip start/end tokens
338
- # return reps[:, 1:-1].mean(1).cpu().numpy()
339
-
340
-
341
- class ESMDBPEmbedder:
342
- def __init__(self, device):
343
- base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
344
- model_path = (
345
- Path(__file__).resolve().parent.parent
346
- / "pretrained"
347
- / "ESM-DBP"
348
- / "ESM-DBP.model"
349
- )
350
- checkpoint = torch.load(model_path, map_location="cpu")
351
- clean_sd = {}
352
- for k, v in checkpoint.items():
353
- clean_sd[k.replace("module.", "")] = v
354
- result = base_model.load_state_dict(clean_sd, strict=False)
355
- if result.missing_keys:
356
- print(f"[ESMDBP] missing keys: {result.missing_keys}")
357
- if result.unexpected_keys:
358
- print(f"[ESMDBP] unexpected keys: {result.unexpected_keys}")
359
-
360
- self.model = base_model.to(device).eval()
361
- self.alphabet = alphabet
362
- self.batch_converter = alphabet.get_batch_converter()
363
- self.device = device
364
- self.max_len = 1024 # same limit as esm1b
365
- self.chunk_size = self.max_len - 2
366
- self.overlap = self.chunk_size // 4
367
-
368
- def _chunk_sequence(self, seq):
369
- if len(seq) <= self.chunk_size:
370
- return [seq]
371
- step = self.chunk_size - self.overlap
372
- chunks = []
373
- for i in range(0, len(seq), step):
374
- chunk = seq[i : i + self.chunk_size]
375
- if not chunk:
376
- break
377
- chunks.append(chunk)
378
- return chunks
379
-
380
- def embed(self, seqs):
381
- all_embeddings = []
382
- for i, seq in enumerate(seqs):
383
- chunks = self._chunk_sequence(seq)
384
- chunk_vecs = []
385
- for chunk in chunks:
386
- batch = [(str(i), chunk)]
387
- _, _, toks = self.batch_converter(batch)
388
- toks = toks.to(self.device)
389
- with torch.no_grad():
390
- out = self.model(toks, repr_layers=[33], return_contacts=False)
391
- reps = out["representations"][33]
392
- if reps.size(1) > 2:
393
- rep = reps[:, 1:-1].mean(1)
394
- else:
395
- rep = reps.mean(1)
396
- chunk_vecs.append(rep.squeeze(0))
397
- if len(chunk_vecs) == 1:
398
- seq_vec = chunk_vecs[0]
399
- else:
400
- stacked = torch.stack(chunk_vecs, dim=0)
401
- seq_vec = stacked.mean(0)
402
- all_embeddings.append(seq_vec.cpu().numpy())
403
- return np.vstack(all_embeddings)
404
-
405
-
406
- class GPNEmbedder:
407
- def __init__(self, device):
408
- model_name = "songlab/gpn-msa-sapiens"
409
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
410
- self.model = AutoModelForMaskedLM.from_pretrained(model_name)
411
- self.model.to(device)
412
- self.model.eval()
413
- self.device = device
414
-
415
- def embed(self, seqs):
416
- inputs = self.tokenizer(
417
- seqs, return_tensors="pt", padding=True, truncation=True
418
- ).to(self.device)
419
-
420
- with torch.no_grad():
421
- last_hidden = self.model(**inputs, output_hidden_states=True).hidden_states[-1]  # masked-LM outputs expose hidden_states, not last_hidden_state
422
- return last_hidden.mean(dim=1).cpu().numpy()
423
-
424
-
425
- class ProGenEmbedder:
426
- def __init__(self, device):
427
- model_name = "jinyuan22/ProGen2-base"
428
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
429
- self.model = AutoModel.from_pretrained(model_name).to(device).eval()
430
- self.device = device
431
-
432
- def embed(self, seqs):
433
- inputs = self.tokenizer(
434
- seqs, return_tensors="pt", padding=True, truncation=True
435
- ).to(self.device)
436
- with torch.no_grad():
437
- last_hidden = self.model(**inputs).last_hidden_state
438
- return last_hidden.mean(dim=1).cpu().numpy()
439
-
440
-
441
- # ---- main pipeline ----
442
-
443
-
444
- def get_embedder(name, device, for_dna=True):
445
- name = name.lower()
446
- if for_dna:
447
- if name == "caduceus":
448
- return CaduceusEmbedder(device)
449
- if name == "dnabert":
450
- return DNABertEmbedder(device)
451
- if name == "nucleotide":
452
- return NucleotideTransformerEmbedder(device)
453
- if name == "gpn":
454
- return GPNEmbedder(device)
455
- if name == "segmentnt":
456
- return SegmentNTEmbedder(device)
457
- else:
458
- if name in ("esm",):
459
- return ESMEmbedder(device)
460
- if name in ("esm-dbp", "esm_dbp"):
461
- return ESMDBPEmbedder(device)
462
- if name == "progen":
463
- return ProGenEmbedder(device)
464
- raise ValueError(f"Unknown model {name} (for_dna={for_dna})")
465
-
466
-
467
- def pad_token_embeddings(list_of_arrays, pad_value=0.0):
468
- """
469
- list_of_arrays: list of (L_i, D) numpy arrays
470
- Returns:
471
- padded: (N, L_max, D) array
472
- mask: (N, L_max) boolean array where True = real token, False = padding
473
- """
474
- N = len(list_of_arrays)
475
- D = list_of_arrays[0].shape[1]
476
- L_max = max(arr.shape[0] for arr in list_of_arrays)
477
- padded = np.full((N, L_max, D), pad_value, dtype=list_of_arrays[0].dtype)
478
- mask = np.zeros((N, L_max), dtype=bool)
479
- for i, arr in enumerate(list_of_arrays):
480
- L = arr.shape[0]
481
- padded[i, :L] = arr
482
- mask[i, :L] = True
483
- return padded, mask
484
-
485
-
486
- def embed_and_save(seqs, ids, embedder, out_path):
487
- embs = embedder.embed(seqs)
488
-
489
- # Decide whether we got variable-length per-token outputs (list of (L, D))
490
- is_variable_token = (
491
- isinstance(embs, (list, tuple))
492
- and len(embs) > 0
493
- and hasattr(embs[0], "shape")
494
- and embs[0].ndim == 2
495
- )
496
-
497
- if is_variable_token:
498
- # pad to (N, L_max, D) + mask
499
- padded, mask = pad_token_embeddings(embs)
500
- # Save both embeddings and mask together in an .npz for convenience
501
- np.savez_compressed(
502
- out_path.with_suffix(".caduceus.npz"),
503
- embeddings=padded,
504
- mask=mask,
505
- ids=np.array(ids, dtype=object),
506
- )
507
- else:
508
- # fixed shape output, e.g., pooled (N, D)
509
- array = np.vstack(embs) if isinstance(embs, list) else embs
510
- np.save(out_path, array)
511
- with open(out_path.with_suffix(".ids"), "w") as f:
512
- f.write("\n".join(ids))
513
-
514
-
515
- if __name__ == "__main__":
516
-
517
- p = argparse.ArgumentParser()
518
- p.add_argument(
519
- "--peak-fasta",
520
- default="binding_peaks_unique.fa",
521
- help="FASTA of deduplicated binding peak sequences; if present this is used for DNA embedding instead of genome JSONs",
522
- )
523
- p.add_argument(
524
- "--genome-json-dir",
525
- default=None,
526
- help="(fallback) directory of UCSC JSONs for full chromosome embedding if peak FASTA is missing or you explicitly want chromosomes",
527
- )
528
- p.add_argument(
529
- "--skip-dna",
530
- action="store_true",
531
- help="if set, skip the chromosome embedding step",
532
- ) # if glm embeddings successful but not plm embeddings
533
- p.add_argument("--tf-fasta", required=True, help="input TF FASTA file")
534
- p.add_argument("--chrom-model", default="caduceus")
535
- p.add_argument("--tf-model", default="esm-dbp")
536
- p.add_argument(
537
- "--out-dir", default="data_files/processed/tfclust/hg38_tf/embeddings"
538
- )
539
- p.add_argument("--device", default="cpu")
540
- args = p.parse_args()
541
-
542
- os.makedirs(args.out_dir, exist_ok=True)
543
- device = args.device
544
-
545
- if not args.skip_dna:
546
- peak_fasta = Path(args.peak_fasta)
547
- if peak_fasta.exists():
548
- # Load peak sequences from FASTA
549
- from Bio import SeqIO
550
-
551
- peak_seqs = []
552
- peak_ids = []
553
- for rec in SeqIO.parse(peak_fasta, "fasta"):
554
- peak_ids.append(rec.id)
555
- peak_seqs.append(str(rec.seq))
556
- print(
557
- f"Embedding {len(peak_seqs)} binding peak sequences from {peak_fasta}",
558
- flush=True,
559
- )
560
- dna_embedder = get_embedder(args.chrom_model, device, for_dna=True)
561
- out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy"
562
- embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks)
563
- elif args.genome_json_dir:
564
- # Legacy: load full chromosomes from JSONs (chr1–22, X, Y, M)
565
- genome_dir = Path(args.genome_json_dir)
566
- chrom_seqs, chrom_ids = [], []
567
- primary_pattern = re.compile(
568
- r"^hg38_chr(?:[1-9]|1[0-9]|2[0-2]|X|Y|M)\.json$"
569
- )
570
- for j in sorted(genome_dir.iterdir()):
571
- if not primary_pattern.match(j.name):
572
- continue
573
- data = json.loads(j.read_text())
574
- seq = data.get("dna") or data.get("sequence")
575
- chrom = data.get("chrom") or j.stem.split("_")[-1]
576
- chrom_seqs.append(seq)
577
- chrom_ids.append(chrom)
578
- cutoff = CaduceusEmbedder(device).chunk_size
579
- long_chroms = [
580
- (chrom, len(seq))
581
- for chrom, seq in zip(chrom_ids, chrom_seqs)
582
- if len(seq) > cutoff
583
- ]
584
- if long_chroms:
585
- print(
586
- "⚠️ Chromosomes exceeding Caduceus max tokens ({}):".format(cutoff)
587
- )
588
- for chrom, L in long_chroms:
589
- print(f" {chrom}: {L} bases")
590
- else:
591
- print("All chromosomes ≤ Caduceus limit ({}).".format(cutoff))
592
-
593
- chrom_embedder = get_embedder(args.chrom_model, device, for_dna=True)
594
- out_chrom = Path(args.out_dir) / f"chrom_{args.chrom_model}.npy"
595
- embed_and_save(chrom_seqs, chrom_ids, chrom_embedder, out_chrom)
596
- else:
597
- raise ValueError(
598
- "No input for DNA embedding: provide a peak FASTA (default binding_peaks_unique.fa) or set --genome-json-dir for chromosome JSONs."
599
- )
600
-
601
- # Load TF sequences
602
- tf_seqs, tf_ids = [], []
603
- for record in SeqIO.parse(args.tf_fasta, "fasta"):
604
- tf_ids.append(record.id)
605
- tf_seqs.append(str(record.seq))
606
-
607
- # embed and save
608
- tf_embedder = get_embedder(args.tf_model, device, for_dna=False)
609
- out_tf = Path(args.out_dir) / f"tf_{args.tf_model}.npy"
610
- embed_and_save(tf_seqs, tf_ids, tf_embedder, out_tf)
611
-
612
- print("Done.")
dpacman/classifier/model_tmp/extract_tf_symbols.py DELETED
@@ -1,30 +0,0 @@
1
- #!/usr/bin/env python3
2
- import pandas as pd
3
- from pathlib import Path
4
-
5
- FINAL_CSV = Path("/home/a03-akrishna/DPACMAN/data_files/processed/final.csv")
6
- OUT_SYMBOLS = Path("tf_symbols.txt")
7
-
8
-
9
- def normalize_tf(tf_id: str) -> str:
10
- return tf_id.split("_seq")[0].upper()
11
-
12
-
13
- def main():
14
- df = pd.read_csv(FINAL_CSV, dtype=str)
15
- if "TF_id" not in df.columns:
16
- raise RuntimeError("final.csv missing TF_id column")
17
- tf_raw = df["TF_id"].dropna().unique().tolist()
18
- normalized = sorted({normalize_tf(t) for t in tf_raw})
19
- print(f"Unique raw TF_id count: {len(tf_raw)}")
20
- print(f"Unique normalized TF symbols: {len(normalized)}")
21
- with open(OUT_SYMBOLS, "w") as f:
22
- for s in normalized:
23
- f.write(s + "\n")
24
- print(f"Wrote normalized TF symbols to {OUT_SYMBOLS}")
25
- # Optional: show sample
26
- print("Sample symbols:", normalized[:50])
27
-
28
-
29
- if __name__ == "__main__":
30
- main()
dpacman/classifier/model_tmp/make_pair_list.py DELETED
@@ -1,282 +0,0 @@
1
- #!/usr/bin/env python3
2
- import argparse
3
- import numpy as np
4
- import pandas as pd
5
- from pathlib import Path
6
- import random
7
- import sys
8
-
9
-
10
- def read_ids_file(p):
11
- p = Path(p)
12
- if not p.exists():
13
- raise FileNotFoundError(f"IDs file not found: {p}")
14
- return [line.strip() for line in p.open() if line.strip()]
15
-
16
-
17
- def split_embeddings(emb_path, ids_path, out_dir, prefix):
18
- out_dir = Path(out_dir)
19
- out_dir.mkdir(parents=True, exist_ok=True)
20
-
21
- if not Path(emb_path).exists():
22
- raise FileNotFoundError(f"Embedding file not found: {emb_path}")
23
- if not Path(ids_path).exists():
24
- raise FileNotFoundError(f"IDs file not found: {ids_path}")
25
-
26
- if emb_path.endswith(".npz"):
27
- data = np.load(emb_path, allow_pickle=True)
28
- if "embeddings" in data:
29
- emb = data["embeddings"]
30
- else:
31
- raise ValueError(f"{emb_path} missing 'embeddings' key")
32
- else:
33
- emb = np.load(emb_path)
34
-
35
- ids = read_ids_file(ids_path)
36
- if len(ids) != emb.shape[0]:
37
- print(
38
- f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}",
39
- file=sys.stderr,
40
- )
41
-
42
- mapping = {}
43
- for i, ident in enumerate(ids):
44
- if i >= emb.shape[0]:
45
- print(
46
- f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr
47
- )
48
- continue
49
- arr = emb[i]
50
- out_file = out_dir / f"{prefix}_{ident}.npy"
51
- np.save(out_file, arr)
52
- mapping[ident] = str(out_file)
53
- return mapping
54
-
55
-
56
- def extract_symbol_from_tf_id(full_id: str) -> str:
57
- """
58
- Given a TF embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN',
59
- return the gene symbol uppercase (e.g., 'ZBTB5').
60
- """
61
- if "|" in full_id:
62
- try:
63
- # format sp|Accession|SYMBOL_HUMAN
64
- genepart = full_id.split("|")[2]
65
- except IndexError:
66
- genepart = full_id
67
- else:
68
- genepart = full_id
69
- symbol = genepart.split("_")[0]
70
- return symbol.upper()
71
-
72
-
73
- def build_tf_symbol_map(tf_map):
74
- """
75
- Build mapping gene_symbol -> list of embedding paths.
76
- """
77
- symbol_map = {}
78
- for full_id, path in tf_map.items():
79
- symbol = extract_symbol_from_tf_id(full_id)
80
- symbol_map.setdefault(symbol, []).append(path)
81
- return symbol_map
82
-
83
-
84
- def tf_key_from_path(path: str) -> str:
85
- """
86
- Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract normalized symbol 'ZBTB5'.
87
- """
88
- stem = Path(path).stem # e.g., tf_sp|O15062|ZBTB5_HUMAN
89
- # remove leading prefix if present (tf_)
90
- if "_" in stem:
91
- _, rest = stem.split("_", 1)
92
- else:
93
- rest = stem
94
- return extract_symbol_from_tf_id(rest)
95
-
96
-
97
- def dna_key_from_path(path: str) -> str:
98
- """
99
- Given .../dna_peak42.npy -> 'peak42'
100
- """
101
- stem = Path(path).stem
102
- if "_" in stem:
103
- _, rest = stem.split("_", 1)
104
- else:
105
- rest = stem
106
- return rest
107
-
108
-
109
- def main():
110
- parser = argparse.ArgumentParser(
111
- description="Build TF-DNA pair list from final.csv with gene-symbol normalization for TFs."
112
- )
113
- parser.add_argument(
114
- "--final_csv", required=True, help="final.csv with TF_id and dna_sequence"
115
- )
116
- parser.add_argument(
117
- "--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)"
118
- )
119
- parser.add_argument(
120
- "--dna_ids", required=True, help="IDs file for DNA embeddings (e.g., peak*.ids)"
121
- )
122
- parser.add_argument(
123
- "--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)"
124
- )
125
- parser.add_argument(
126
- "--tf_ids",
127
- required=True,
128
- help="IDs file for TF embeddings (e.g., sp|...|... ids)",
129
- )
130
- parser.add_argument("--out_dir", required=True, help="Output directory")
131
- parser.add_argument(
132
- "--neg_per_positive",
133
- type=int,
134
- default=2,
135
- help="Negatives per positive (half same-TF, half same-DNA)",
136
- )
137
- parser.add_argument("--seed", type=int, default=42)
138
- args = parser.parse_args()
139
-
140
- random.seed(args.seed)
141
- out_dir = Path(args.out_dir)
142
- out_dir.mkdir(parents=True, exist_ok=True)
143
-
144
- # Load final.csv
145
- df = pd.read_csv(args.final_csv, dtype=str)
146
- if "TF_id" not in df.columns or "dna_sequence" not in df.columns:
147
- raise RuntimeError("final.csv must have columns TF_id and dna_sequence")
148
-
149
- # Assign dna_id (unique per dna_sequence)
150
- unique_seqs = df["dna_sequence"].drop_duplicates().tolist()
151
- seq_to_id = {seq: f"peak{i}" for i, seq in enumerate(unique_seqs)}
152
- df["dna_id"] = df["dna_sequence"].map(seq_to_id)
153
- enriched_csv = out_dir / "final_with_dna_id.csv"
154
- df.to_csv(enriched_csv, index=False)
155
- print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}")
156
-
157
- # Split embeddings into per-item files
158
- print(
159
- f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}"
160
- )
161
- dna_map = split_embeddings(
162
- args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna"
163
- )
164
- print(
165
- f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})"
166
- )
167
- print(
168
- f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}"
169
- )
170
- tf_map = split_embeddings(
171
- args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf"
172
- )
173
- print(
174
- f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})"
175
- )
176
-
177
- # Build gene-symbol normalized map
178
- tf_symbol_map = build_tf_symbol_map(tf_map)
179
- print(f"[i] TF symbol map keys (sample): {list(tf_symbol_map.keys())[:30]}")
180
-
181
- # Diagnostic overlaps
182
- norm_tf_in_final = set(t.split("_seq")[0].upper() for t in df["TF_id"].unique())
183
- available_tf_symbols = set(tf_symbol_map.keys())
184
- intersect_tf = norm_tf_in_final & available_tf_symbols
185
- print(f"[i] Unique normalized TF symbols in final.csv: {len(norm_tf_in_final)}")
186
- print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}")
187
- print(f"[i] Intersection count: {len(intersect_tf)}")
188
- if len(intersect_tf) == 0:
189
- print(
190
- "[ERROR] No overlap between normalized TF_id and TF embedding symbols.",
191
- file=sys.stderr,
192
- )
193
- print(
194
- "Sample normalized TFs from final.csv:",
195
- sorted(list(norm_tf_in_final))[:30],
196
- file=sys.stderr,
197
- )
198
- print(
199
- "Sample available TF symbols:",
200
- sorted(list(available_tf_symbols))[:30],
201
- file=sys.stderr,
202
- )
203
- sys.exit(1)
204
-
205
- dna_ids_final = set(df["dna_id"].unique())
206
- available_dna_ids = set(dna_map.keys())
207
- intersect_dna = dna_ids_final & available_dna_ids
208
- print(
209
- f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. Intersection: {len(intersect_dna)}"
210
- )
211
- if len(intersect_dna) == 0:
212
- print("[ERROR] No overlap on DNA ids.", file=sys.stderr)
213
- sys.exit(1)
214
-
215
- # Build positive pairs
216
- positives = []
217
- for _, row in df.iterrows():
218
- tf_raw = row["TF_id"]
219
- tf_symbol = tf_raw.split("_seq")[0].upper()
220
- dnaid = row["dna_id"]
221
- if tf_symbol not in tf_symbol_map:
222
- continue
223
- if dnaid not in dna_map:
224
- continue
225
- # pick the first embedding for that symbol
226
- tf_embedding_path = tf_symbol_map[tf_symbol][0]
227
- positives.append((tf_embedding_path, dna_map[dnaid], 1))
228
- print(f"[i] Constructed {len(positives)} positive pairs after TF symbol resolution")
229
-
230
- if len(positives) == 0:
231
- print(
232
- "[ERROR] No positive pairs could be constructed; aborting.", file=sys.stderr
233
- )
234
- sys.exit(1)
235
-
236
- # Build negative samples
237
- all_tf_symbols = sorted(tf_symbol_map.keys())
238
- all_dnaids = sorted(dna_map.keys())
239
- positive_set = set()
240
- for tf_path, dna_path, _ in positives:
241
- tf_key = tf_key_from_path(tf_path)
242
- dna_key = dna_key_from_path(dna_path)
243
- positive_set.add((tf_key, dna_key))
244
-
245
- negatives = []
246
- half = args.neg_per_positive // 2
247
- for tf_path, dna_path, _ in positives:
248
- tf_key = tf_key_from_path(tf_path)
249
- dna_key = dna_key_from_path(dna_path)
250
- # same TF, different DNA
251
- for _ in range(half):
252
- candidate_dna = random.choice(all_dnaids)
253
- if candidate_dna == dna_key or (tf_key, candidate_dna) in positive_set:
254
- continue
255
- negatives.append((tf_path, dna_map[candidate_dna], 0))
256
- # same DNA, different TF
257
- for _ in range(half):
258
- candidate_tf_symbol = random.choice(all_tf_symbols)
259
- if (
260
- candidate_tf_symbol == tf_key
261
- or (candidate_tf_symbol, dna_key) in positive_set
262
- ):
263
- continue
264
- # pick its first embedding
265
- candidate_tf_path = tf_symbol_map[candidate_tf_symbol][0]
266
- negatives.append((candidate_tf_path, dna_path, 0))  # keep the current pair's DNA (dnaid was a stale variable from the loop above)
267
-
268
- print(
269
- f"[i] Sampled {len(negatives)} negatives (neg_per_positive={args.neg_per_positive})"
270
- )
271
-
272
- # Write pair list
273
- pair_list_path = out_dir / "pair_list.tsv"
274
- with open(pair_list_path, "w") as f:
275
- for binder_path, glm_path, label in positives + negatives:
276
- # binder=TF, glm=DNA
277
- f.write(f"{binder_path}\t{glm_path}\t{label}\n")
278
- print(f"[i] Wrote {len(positives)+len(negatives)} examples to {pair_list_path}")
279
-
280
-
281
- if __name__ == "__main__":
282
- main()
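A condensed restatement (not an import from the deleted module) of the gene-symbol normalization this script leaned on, with a couple of worked examples:

```python
def extract_symbol(full_id: str) -> str:
    # 'sp|O15062|ZBTB5_HUMAN' -> 'ZBTB5'; 'ZBTB5_HUMAN' -> 'ZBTB5'; 'zbtb5_seq3' -> 'ZBTB5'
    part = full_id.split("|")[2] if full_id.count("|") >= 2 else full_id
    return part.split("_")[0].upper()

assert extract_symbol("sp|O15062|ZBTB5_HUMAN") == "ZBTB5"
assert extract_symbol("ZBTB5_HUMAN") == "ZBTB5"
assert extract_symbol("zbtb5_seq3") == "ZBTB5"
```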
dpacman/classifier/model_tmp/make_peak_fasta.py DELETED
@@ -1,15 +0,0 @@
1
- import pandas as pd
2
- from pathlib import Path
3
-
4
- df = pd.read_csv(
5
- "/home/a03-akrishna/DPACMAN/data_files/processed/final.csv", dtype=str
6
- ) # adjust path if needed
7
- # get unique sequences
8
- uniq = df[["dna_sequence"]].drop_duplicates().reset_index(drop=True)
9
- # make headers: e.g., peak0, peak1, ...
10
- out_fa = Path("binding_peaks_unique.fa")
11
- with open(out_fa, "w") as f:
12
- for i, seq in enumerate(uniq["dna_sequence"]):
13
- header = f">peak{i}"
14
- f.write(f"{header}\n{seq}\n")
15
- print(f"Wrote {len(uniq)} unique binding sequences to {out_fa}")
dpacman/classifier/model_tmp/model.py DELETED
@@ -1,111 +0,0 @@
1
- import torch
2
- from torch import nn
3
-
4
-
5
- class LocalCNN(nn.Module):
6
- def __init__(self, dim: int = 256, kernel_size: int = 3):
7
- super().__init__()
8
- padding = kernel_size // 2
9
- self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding)
10
- self.act = nn.GELU()
11
- self.ln = nn.LayerNorm(dim)
12
-
13
- def forward(self, x: torch.Tensor):
14
- # x: (batch, L, dim)
15
- out = self.conv(x.transpose(1, 2)) # → (batch, dim, L)
16
- out = self.act(out)
17
- out = out.transpose(1, 2) # → (batch, L, dim)
18
- return self.ln(out + x) # residual
19
-
20
-
21
- class CrossModalBlock(nn.Module):
22
- def __init__(self, dim: int = 256, heads: int = 8):
23
- super().__init__()
24
- # self-attention for both sides
25
- self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True)
26
- self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True)
27
- self.ln_b1 = nn.LayerNorm(dim)
28
- self.ln_g1 = nn.LayerNorm(dim)
29
-
30
- self.ffn_b = nn.Sequential(
31
- nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
32
- )
33
- self.ffn_g = nn.Sequential(
34
- nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
35
- )
36
- self.ln_b2 = nn.LayerNorm(dim)
37
- self.ln_g2 = nn.LayerNorm(dim)
38
-
39
- # cross attention (binder queries, glm keys/values)
40
- self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
41
- self.ln_c1 = nn.LayerNorm(dim)
42
- self.ffn_c = nn.Sequential(
43
- nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
44
- )
45
- self.ln_c2 = nn.LayerNorm(dim)
46
-
47
- def forward(self, binder: torch.Tensor, glm: torch.Tensor):
48
- """
49
- binder: (batch, Lb, dim)
50
- glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
51
- returns: updated binder representation (batch, Lb, dim)
52
- """
53
- # binder self-attn + ffn
54
- b = binder
55
- b_sa, _ = self.sa_binder(b, b, b)
56
- b = self.ln_b1(b + b_sa)
57
- b_ff = self.ffn_b(b)
58
- b = self.ln_b2(b + b_ff)
59
-
60
- # glm self-attn + ffn
61
- g = glm
62
- g_sa, _ = self.sa_glm(g, g, g)
63
- g = self.ln_g1(g + g_sa)
64
- g_ff = self.ffn_g(g)
65
- g = self.ln_g2(g + g_ff)
66
-
67
- # cross-attention: binder queries glm
68
- c_sa, _ = self.cross_attn(b, g, g)
69
- c = self.ln_c1(b + c_sa)
70
- c_ff = self.ffn_c(c)
71
- c = self.ln_c2(c + c_ff)
72
- return c # (batch, Lb, dim)
73
-
74
-
75
- class BindPredictor(nn.Module):
76
- def __init__(
77
- self,
78
- input_dim: int = 256,
79
- hidden_dim: int = 256,
80
- heads: int = 8,
81
- num_layers: int = 4,
82
- use_local_cnn_on_glm: bool = True,
83
- ):
84
- super().__init__()
85
- self.proj_binder = nn.Linear(input_dim, hidden_dim)
86
- self.proj_glm = nn.Linear(input_dim, hidden_dim)
87
- self.use_local_cnn = use_local_cnn_on_glm
88
- self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity()
89
-
90
- self.layers = nn.ModuleList(
91
- [CrossModalBlock(hidden_dim, heads) for _ in range(num_layers)]
92
- )
93
-
94
- self.ln_out = nn.LayerNorm(hidden_dim)
95
- self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())
96
-
97
- def forward(self, binder_emb, glm_emb):
98
- """
99
- binder_emb, glm_emb: (batch, L, input_dim)
100
- """
101
- b = self.proj_binder(binder_emb) # (B, Lb, hidden_dim)
102
- g = self.proj_glm(glm_emb) # (B, Lg, hidden_dim)
103
- if self.use_local_cnn:
104
- g = self.local_cnn(g) # local context injected
105
-
106
- for layer in self.layers:
107
- b = layer(b, g) # update binder with cross-modal info
108
-
109
- pooled = b.mean(dim=1) # (B, hidden_dim)
110
- out = self.ln_out(pooled)
111
- return self.head(out).squeeze(-1) # (B,)
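A minimal shape check for the `BindPredictor` deleted above, assuming the module is still importable locally as `model`; the batch and sequence lengths are arbitrary toy values.

```python
import torch
from model import BindPredictor  # the module deleted above, if restored locally

# Toy batch: 2 pairs, binder length 64, DNA length 128, 256-d input embeddings.
net = BindPredictor(input_dim=256, hidden_dim=256, heads=8, num_layers=2)
binder = torch.randn(2, 64, 256)
glm = torch.randn(2, 128, 256)

with torch.no_grad():
    scores = net(binder, glm)  # per-pair sigmoid scores

print(scores.shape)  # torch.Size([2])
```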
dpacman/classifier/model_tmp/prep_splits.py DELETED
@@ -1,157 +0,0 @@
1
- import numpy as np
2
- import pandas as pd
3
- import sys
4
- import json
5
- from sklearn.decomposition import TruncatedSVD
6
- from sklearn.model_selection import train_test_split
7
- from collections import Counter
8
-
9
-
10
- def parse_pair_list(pair_list_path):
11
- binder_paths, glm_paths, labels = [], [], []
12
- with open(pair_list_path) as f:
13
- for lineno, line in enumerate(f, start=1):
14
- if not line.strip():
15
- continue
16
- parts = line.strip().split()
17
- if len(parts) != 3:
18
- print(
19
- f"[WARN] skipping malformed line {lineno}: {line.strip()}",
20
- file=sys.stderr,
21
- )
22
- continue
23
- b, g, l = parts
24
- try:
25
- lab = int(l)
26
- except ValueError:
27
- print(f"[WARN] invalid label on line {lineno}: {l}", file=sys.stderr)
28
- continue
29
- binder_paths.append(b)
30
- glm_paths.append(g)
31
- labels.append(lab)
32
- return binder_paths, glm_paths, labels
33
-
34
-
35
- def build_tf_compressed_cache(binder_paths, target_dim=256):
36
- """
37
- Load all unique TF (binder) embeddings, fit reduction if needed, and return dict mapping path->(L, target_dim) array.
38
- """
39
- unique_paths = sorted(set(binder_paths))
40
- print(
41
- f"[i] Found {len(unique_paths)} unique TF embedding files to compress.",
42
- flush=True,
43
- )
44
- # Load all embeddings to determine dimensionality
45
- samples = []
46
- for p in unique_paths:
47
- arr = np.load(p)
48
- samples.append(arr)
49
- # Determine if reduction needed: assume all have same embedding width
50
- first = samples[0]
51
- orig_dim = first.shape[1] if first.ndim == 2 else 1
52
- reduction_needed = orig_dim != target_dim
53
- tf_cache = {}
54
-
55
- if reduction_needed:
56
- # Build matrix to fit SVD: we need a 2D matrix per embedding; if lengths vary we can't directly stack.
57
- # We'll do reduction per sequence individually using TruncatedSVD on concatenated flattened features:
58
- # Simplest: for variable lengths, reduce each embedding separately with a learned linear projection.
59
- # Here we fit a single TruncatedSVD on the concatenation of all sequence tokens (flattened) by padding/truncating to a fixed length.
60
- # To avoid complexity, use PCA-like linear projection learned via SVD on mean-pooled vectors:
61
- pooled = []
62
- for arr in samples:
63
- if arr.ndim == 2:
64
- pooled.append(arr.mean(axis=0)) # (orig_dim,)
65
- else:
66
- pooled.append(arr) # degenerate
67
- pooled_mat = np.stack(pooled, axis=0) # (N, orig_dim)
68
- print(
69
- f"[i] Fitting TruncatedSVD on TF pooled embeddings: {pooled_mat.shape} -> {target_dim}",
70
- flush=True,
71
- )
72
- svd = TruncatedSVD(n_components=target_dim, random_state=42)
73
- reduced_pooled = svd.fit_transform(pooled_mat) # (N, target_dim)
74
-
75
- # For each original embedding, project token-level vectors by multiplying token vector with svd.components_.T
76
- # svd.components_: (target_dim, orig_dim) so projection matrix is (orig_dim, target_dim)
77
- proj_mat = svd.components_.T # (orig_dim, target_dim)
78
- for i, p in enumerate(unique_paths):
79
- arr = samples[i] # shape (L, orig_dim)
80
- if arr.ndim == 1:
81
- arr2 = arr @ proj_mat # (target_dim,)
82
- else:
83
- # project each token: (L, orig_dim) @ (orig_dim, target_dim) -> (L, target_dim)
84
- arr2 = arr @ proj_mat
85
- tf_cache[p] = arr2 # reduced per-token representation
86
- print("[i] Completed compression of TF embeddings.", flush=True)
87
- else:
88
- # already correct dim: just cache originals
89
- print(
90
- f"[i] TF embeddings already {target_dim}-dimensional; skipping reduction.",
91
- flush=True,
92
- )
93
- for i, p in enumerate(unique_paths):
94
- arr = samples[i]
95
- tf_cache[p] = arr
96
- return tf_cache
97
-
98
-
99
- def main():
100
- # df = pd.read_csv("../data_files/processed/fimo/ananya_aug4_2025_final.csv")
101
-
102
- binder_paths, glm_paths, labels = parse_pair_list(
103
- "../data_files/processed/fimo/ananya_aug4_2025_pair_list.tsv"
104
- )
105
-
106
- if len(labels) == 0:
107
- print("[ERROR] No valid pairs parsed. Exiting.", file=sys.stderr)
108
- sys.exit(1)
109
-
110
- label_counts = Counter(labels)
111
- print(
112
- f"[i] Total examples parsed: {len(labels)}. Label distribution: {label_counts}",
113
- flush=True,
114
- )
115
-
116
- # build compressed TF cache (reduces to 256 if needed)
117
- # tf_compressed_cache = build_tf_compressed_cache(binder_paths, target_dim=256)
118
-
119
- # Combine all data into one structure for easy splitting
120
- data = list(zip(binder_paths, glm_paths, labels))
121
-
122
- # First split: train vs temp (val+test)
123
- train_data, temp_data = train_test_split(data, test_size=0.2, random_state=42)
124
-
125
- # Second split: val vs test (50% of 20% → 10% each)
126
- val_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=42)
127
-
128
- # Unpack for dataset construction
129
- def unpack(data):
130
- binders, glms, labels = zip(*data)
131
- return list(binders), list(glms), list(labels)
132
-
133
- def save_split(binder_paths, glm_paths, labels, out_path):
134
- df = pd.DataFrame(
135
- {
136
- "binder_path": binder_paths,
137
- "glm_path": glm_paths,
138
- "label": labels,
139
- }
140
- )
141
- df.to_csv(out_path, index=False)
142
-
143
- # Unpack data for saving
144
- train_binders, train_glms, train_labels = unpack(train_data)
145
- val_binders, val_glms, val_labels = unpack(val_data)
146
- test_binders, test_glms, test_labels = unpack(test_data)
147
-
148
- # Save each split
149
- save_split(
150
- train_binders, train_glms, train_labels, "../data_files/splits/train.csv"
151
- )
152
- save_split(val_binders, val_glms, val_labels, "../data_files/splits/val.csv")
153
- save_split(test_binders, test_glms, test_labels, "../data_files/splits/test.csv")
154
-
155
-
156
- if __name__ == "__main__":
157
- main()
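`build_tf_compressed_cache` above fits a TruncatedSVD on mean-pooled TF vectors and then reuses the learned components to project every token-level embedding. A scaled-down sketch of that projection (dimensions shrunk for illustration; the script targets 256 output dimensions and much wider ESM-style inputs):

```python
import numpy as np
from sklearn.decomposition import TruncatedSVD

rng = np.random.default_rng(0)
# 40 variable-length "TF embeddings", each (L_i, 64); the real embeddings are much wider.
seqs = [rng.normal(size=(int(L), 64)).astype(np.float32) for L in rng.integers(4, 12, size=40)]

pooled = np.stack([s.mean(axis=0) for s in seqs])          # (40, 64) mean-pooled vectors
svd = TruncatedSVD(n_components=16, random_state=42).fit(pooled)
proj = svd.components_.T                                    # (64, 16) projection matrix

reduced = [s @ proj for s in seqs]                           # each sequence becomes (L_i, 16)
print(reduced[0].shape)
```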
dpacman/classifier/model_tmp/train.py DELETED
@@ -1,217 +0,0 @@
1
- #!/usr/bin/env python3
2
- import argparse
3
- import numpy as np
4
- import torch
5
- from torch import nn
6
- from model import BindPredictor
7
- from pathlib import Path
8
- from collections import Counter
9
- from sklearn.metrics import roc_auc_score, average_precision_score
10
- from sklearn.decomposition import TruncatedSVD
11
- import sys
12
-
13
- from dpacman.utils.models import set_seed
14
-
15
-
16
- def build_tf_compressed_cache(binder_paths, target_dim=256):
17
- """
18
- Load all unique TF (binder) embeddings, fit reduction if needed, and return dict mapping path->(L, target_dim) array.
19
- """
20
- unique_paths = sorted(set(binder_paths))
21
- print(
22
- f"[i] Found {len(unique_paths)} unique TF embedding files to compress.",
23
- flush=True,
24
- )
25
- # Load all embeddings to determine dimensionality
26
- samples = []
27
- for p in unique_paths:
28
- arr = np.load(p)
29
- samples.append(arr)
30
- # Determine if reduction needed: assume all have same embedding width
31
- first = samples[0]
32
- orig_dim = first.shape[1] if first.ndim == 2 else 1
33
- reduction_needed = orig_dim != target_dim
34
- tf_cache = {}
35
-
36
- if reduction_needed:
37
- # Build matrix to fit SVD: we need a 2D matrix per embedding; if lengths vary we can't directly stack.
38
- # We'll do reduction per sequence individually using TruncatedSVD on concatenated flattened features:
39
- # Simplest: for variable lengths, reduce each embedding separately with a learned linear projection.
40
- # Here we fit a single TruncatedSVD on the concatenation of all sequence tokens (flattened) by padding/truncating to a fixed length.
41
- # To avoid complexity, use PCA-like linear projection learned via SVD on mean-pooled vectors:
42
- pooled = []
43
- for arr in samples:
44
- if arr.ndim == 2:
45
- pooled.append(arr.mean(axis=0)) # (orig_dim,)
46
- else:
47
- pooled.append(arr) # degenerate
48
- pooled_mat = np.stack(pooled, axis=0) # (N, orig_dim)
49
- print(
50
- f"[i] Fitting TruncatedSVD on TF pooled embeddings: {pooled_mat.shape} -> {target_dim}",
51
- flush=True,
52
- )
53
- svd = TruncatedSVD(n_components=target_dim, random_state=42)
54
- reduced_pooled = svd.fit_transform(pooled_mat) # (N, target_dim)
55
-
56
- # For each original embedding, project token-level vectors by multiplying token vector with svd.components_.T
57
- # svd.components_: (target_dim, orig_dim) so projection matrix is (orig_dim, target_dim)
58
- proj_mat = svd.components_.T # (orig_dim, target_dim)
59
- for i, p in enumerate(unique_paths):
60
- arr = samples[i] # shape (L, orig_dim)
61
- if arr.ndim == 1:
62
- arr2 = arr @ proj_mat # (target_dim,)
63
- else:
64
- # project each token: (L, orig_dim) @ (orig_dim, target_dim) -> (L, target_dim)
65
- arr2 = arr @ proj_mat
66
- tf_cache[p] = arr2 # reduced per-token representation
67
- print("[i] Completed compression of TF embeddings.", flush=True)
68
- else:
69
- # already correct dim: just cache originals
70
- print(
71
- f"[i] TF embeddings already {target_dim}-dimensional; skipping reduction.",
72
- flush=True,
73
- )
74
- for i, p in enumerate(unique_paths):
75
- arr = samples[i]
76
- tf_cache[p] = arr
77
- return tf_cache
78
-
79
-
80
- def evaluate(model, dl, device):
81
- model.eval()
82
- all_labels = []
83
- all_preds = []
84
- with torch.no_grad():
85
- for b, g, y in dl:
86
- b = b.to(device)
87
- g = g.to(device)
88
- y = y.to(device)
89
- pred = model(b, g)
90
- all_labels.append(y.cpu())
91
- all_preds.append(pred.cpu())
92
- if not all_labels:
93
- return 0.0, 0.0
94
- y_true = torch.cat(all_labels).numpy()
95
- y_score = torch.cat(all_preds).numpy()
96
- try:
97
- auc = roc_auc_score(y_true, y_score)
98
- except Exception:
99
- auc = 0.0
100
- try:
101
- ap = average_precision_score(y_true, y_score)
102
- except Exception:
103
- ap = 0.0
104
- return auc, ap
105
-
106
-
107
- def unpack(data):
108
- binders, glms, labels = zip(*data)
109
- return list(binders), list(glms), list(labels)
110
-
111
-
112
- # ---- main ------------------------------------------------------------
113
- def main():
114
- parser = argparse.ArgumentParser()
115
- parser.add_argument("--pair_list", type=str, required=True)
116
-
117
- parser.add_argument("--out_dir", type=str, required=True)
118
- parser.add_argument("--epochs", type=int, default=10)
119
- parser.add_argument("--batch_size", type=int, default=32)
120
- parser.add_argument("--lr", type=float, default=1e-4)
121
- parser.add_argument("--device", type=str, default="cuda")
122
- parser.add_argument("--seed", type=int, default=42)
123
- args = parser.parse_args()
124
-
125
- #
126
- print("DEBUG: starting training script with in-line TF compression", flush=True)
127
- device = torch.device(args.device if torch.cuda.is_available() else "cpu")
128
- set_seed(args.seed)  # set seed for reproducibility
- binder_paths, glm_paths, labels = parse_pair_list(args.pair_list)
129
-
130
- if len(labels) == 0:
131
- print("[ERROR] No valid pairs parsed. Exiting.", file=sys.stderr)
132
- sys.exit(1)
133
-
134
- label_counts = Counter(labels)
135
- print(
136
- f"[i] Total examples parsed: {len(labels)}. Label distribution: {label_counts}",
137
- flush=True,
138
- )
139
-
140
- # build compressed TF cache (reduces to 256 if needed)
141
- tf_compressed_cache = build_tf_compressed_cache(binder_paths, target_dim=256)
142
-
143
- # load training data splits
144
-
145
- train_ds = PairDataset(None, tf_compressed_cache=tf_compressed_cache)
146
- val_ds = PairDataset(*subset(val_i), tf_compressed_cache=tf_compressed_cache)
147
- test_ds = PairDataset(*subset(test_i), tf_compressed_cache=tf_compressed_cache)
148
-
149
- print(
150
- f"[i] Train/Val/Test sizes: {len(train_ds)}/{len(val_ds)}/{len(test_ds)}",
151
- flush=True,
152
- )
153
- if len(train_ds) == 0 or len(val_ds) == 0:
154
- print(
155
- "[ERROR] Train or validation split is empty; cannot proceed.",
156
- file=sys.stderr,
157
- )
158
- sys.exit(1)
159
-
160
- train_dl = DataLoader(
161
- train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn
162
- )
163
- val_dl = DataLoader(
164
- val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn
165
- )
166
- test_dl = DataLoader(
167
- test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn
168
- )
169
-
170
- model = BindPredictor(
171
- input_dim=256, hidden_dim=256, heads=8, num_layers=3, use_local_cnn_on_glm=True
172
- )
173
- model = model.to(device)
174
- optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-3)
175
- loss_fn = nn.BCELoss()
176
-
177
- best_val = -float("inf")
178
- os_out = Path(args.out_dir)
179
- os_out.mkdir(exist_ok=True, parents=True)
180
-
181
- for epoch in range(1, args.epochs + 1):
182
- print(f"[Epoch {epoch}] starting...", flush=True)
183
- model.train()
184
- running_loss = 0.0
185
- for b, g, y in train_dl:
186
- b = b.to(device)
187
- g = g.to(device)
188
- y = y.to(device)
189
- pred = model(b, g)
190
- loss = loss_fn(pred, y)
191
- optimizer.zero_grad()
192
- loss.backward()
193
- optimizer.step()
194
- running_loss += loss.item() * b.size(0)
195
- train_loss = running_loss / len(train_ds)
196
- val_auc, val_ap = evaluate(model, val_dl, device)
197
- print(
198
- f"[Epoch {epoch}] train_loss={train_loss:.4f} val_auc={val_auc:.4f} val_ap={val_ap:.4f}",
199
- flush=True,
200
- )
201
-
202
- if val_auc > best_val:
203
- best_val = val_auc
204
- torch.save(model.state_dict(), os_out / "best_model.pt")
205
- print(
206
- f"[Epoch {epoch}] Saved new best model with val_auc={val_auc:.4f}",
207
- flush=True,
208
- )
209
-
210
- torch.save(model.state_dict(), os_out / "last_model.pt")
211
- test_auc, test_ap = evaluate(model, test_dl, device)
212
- print(f"FINAL TEST: AUC={test_auc:.4f} AP={test_ap:.4f}", flush=True)
213
- print(f"[i] Models written to {os_out}/best_model.pt and last_model.pt", flush=True)
214
-
215
-
216
- if __name__ == "__main__":
217
- main()
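A small sketch of reloading the `best_model.pt` checkpoint this loop writes (path and hyperparameters are illustrative and must match those used at training time):

```python
import torch
from model import BindPredictor  # deleted module above, if restored locally

net = BindPredictor(input_dim=256, hidden_dim=256, heads=8, num_layers=3, use_local_cnn_on_glm=True)
state = torch.load("out_dir/best_model.pt", map_location="cpu")
net.load_state_dict(state)
net.eval()
```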
dpacman/classifier/old_train.py DELETED
@@ -1,486 +0,0 @@
1
- import argparse, random, sys
2
- from pathlib import Path
3
-
4
- import numpy as np
5
- import pandas as pd
6
- import torch
7
- from torch import nn
8
- from torch.utils.data import Dataset, DataLoader, Sampler
9
-
10
- # from sklearn.random_projection import GaussianRandomProjection # OLD (kept): projection was removed earlier
11
- import matplotlib.pyplot as plt
12
-
13
- import torch.amp as amp
14
- from torch.nn import functional as F
15
- from model import BindPredictor
16
-
17
-
18
- # ─────────────── utilities ────────────────────────────────────────────────
19
- def parse_pair_list(path):
20
- binders, glms = [], []
21
- with open(path) as f:
22
- for ln, line in enumerate(f, 1):
23
- parts = line.strip().split()
24
- if len(parts) < 2:
25
- continue
26
- b, g = parts[0], parts[1]
27
- binders.append(b)
28
- glms.append(g)
29
- return binders, glms
30
-
31
-
32
- class ListBatchSampler(Sampler):
33
- def __init__(self, batches):
34
- self.batches = batches
35
-
36
- def __iter__(self):
37
- return iter(self.batches)
38
-
39
- def __len__(self):
40
- return len(self.batches)
41
-
42
-
43
- def make_buckets(idxs, glm_paths, batch_size, n_buckets=10, seed=42):
44
- rng = random.Random(seed)
45
- lengths = [(i, np.load(glm_paths[i]).shape[0]) for i in idxs]
46
- lengths.sort(key=lambda x: x[1])
47
- size = max(1, int(np.ceil(len(lengths) / n_buckets)))
48
- buckets = [lengths[i : i + size] for i in range(0, len(lengths), size)]
49
- batches = []
50
- for bucket in buckets:
51
- ids = [i for i, _ in bucket]
52
- rng.shuffle(ids)
53
- for i in range(0, len(ids), batch_size):
54
- batches.append(ids[i : i + batch_size])
55
- rng.shuffle(batches)
56
- return batches
57
-
58
-
59
- def dna_key_from_path(path: str) -> str:
60
- """.../dna_peak42.npy -> 'peak42'"""
61
- stem = Path(path).stem
62
- if "_" in stem:
63
- _, rest = stem.split("_", 1)
64
- else:
65
- rest = stem
66
- return rest
67
-
68
-
69
- def build_tf_cache(tf_paths, target_dim=256):
70
- """
71
- Load raw TF embeddings without projecting; compression is learnable in the model.
72
- """
73
- unique = sorted(set(tf_paths))
74
- print(
75
- f"[i] (Learnable) Preparing {len(unique)} TF files; target {target_dim}d inside the model",
76
- flush=True,
77
- )
78
-
79
- pools, raw = [], []
80
- for p in unique:
81
- arr = np.load(p) # (L, D) or (D,)
82
- raw.append(arr)
83
- pools.append(arr.mean(axis=0) if arr.ndim == 2 else arr)
84
- M = np.stack(pools, 0)
85
- orig_dim = M.shape[1]
86
- print(f"[i] Pooled shape → {M.shape} (orig_dim={orig_dim})", flush=True)
87
-
88
- cache = {}
89
- for i, p in enumerate(unique):
90
- arr = raw[i]
91
- # OLD: projection here (removed)
92
- cache[p] = arr
93
- print("[i] TF cache ready (raw); compression will be learned.", flush=True)
94
- return cache
95
-
96
-
97
- # ─────────────── Dataset & Collation ─────────────────────────────────────
98
- class PairDataset(Dataset):
99
- def __init__(self, tf_paths, dna_paths, final_df, tf_cache):
100
- self.tf_paths, self.dna_paths = tf_paths, dna_paths
101
- self.tf_cache = tf_cache
102
- self.targets = {}
103
- for _, row in final_df.iterrows():
104
- dna_id = row["dna_id"]
105
- vec = np.array(
106
- list(map(float, row["score_sig_r2"].split(","))), dtype=np.float32
107
- )
108
- self.targets[dna_id] = vec
109
-
110
- def __len__(self):
111
- return len(self.tf_paths)
112
-
113
- def __getitem__(self, i):
114
- b = self.tf_cache[self.tf_paths[i]] # (L_b, D_b) or (D_b,)
115
- if b.ndim == 1:
116
- b = b[None, :]
117
- g = np.load(self.dna_paths[i]) # (L_g, 256) or (256,)
118
- if g.ndim == 1:
119
- g = g[None, :]
120
-
121
- stem = Path(self.dna_paths[i]).stem
122
- dna_id = stem.replace("dna_", "")
123
- t = self.targets.get(dna_id, np.zeros(g.shape[0], dtype=np.float32))
124
-
125
- return (
126
- torch.from_numpy(b).float(),
127
- torch.from_numpy(g).float(),
128
- torch.from_numpy(t).float(),
129
- )
130
-
131
-
132
- def collate_fn(batch):
133
- Bs = [b.shape[0] for b, _, _ in batch]
134
- Gs = [g.shape[0] for _, g, _ in batch]
135
- maxB, maxG = max(Bs), max(Gs)
136
-
137
- def pad_seq(x, L):
138
- if x.shape[0] < L:
139
- pad = torch.zeros(
140
- (L - x.shape[0], x.shape[1]), dtype=x.dtype, device=x.device
141
- )
142
- return torch.cat([x, pad], dim=0)
143
- return x
144
-
145
- def pad_t(y, L):
146
- if y.shape[0] < L:
147
- pad = torch.zeros((L - y.shape[0],), dtype=y.dtype, device=y.device)
148
- return torch.cat([y, pad], dim=0)
149
- return y
150
-
151
- b_stack = torch.stack([pad_seq(b, maxB) for b, _, _ in batch])
152
- g_stack = torch.stack([pad_seq(g, maxG) for _, g, _ in batch])
153
- t_stack = torch.stack([pad_t(t, maxG) for *_, t in batch])
154
- return b_stack, g_stack, t_stack
155
-
156
-
157
- # ──���──────────── losses, metrics ─────────────────────────────────────────
158
- def combined_loss_components(logits, targets, peak_thresh=0.5, eps=1e-8):
159
- probs = torch.sigmoid(logits)
160
- labels = (targets >= peak_thresh).float()
161
- non_peak_mask = (labels == 0).float()
162
- peak_mask = (labels == 1).float()
163
-
164
- bce_all = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
165
- bce_non = bce_all * non_peak_mask
166
- bce_non = bce_non.sum() / (non_peak_mask.sum() + eps)
167
-
168
- mse_peaks = F.mse_loss(probs * peak_mask, targets * peak_mask, reduction="sum") / (
169
- peak_mask.sum() + eps
170
- )
171
- mse_global = F.mse_loss(probs, targets, reduction="mean")
172
-
173
- t_dist = targets + eps
174
- p_dist = probs + eps
175
- t_dist = t_dist / t_dist.sum(dim=1, keepdim=True)
176
- p_dist = p_dist / p_dist.sum(dim=1, keepdim=True)
177
- kl = (
178
- (t_dist * (t_dist.clamp(min=eps).log() - p_dist.clamp(min=eps).log()))
179
- .sum(dim=1)
180
- .mean()
181
- )
182
-
183
- return bce_non, kl, mse_global, probs
184
-
185
-
186
- def accuracy_percentage(logits, targets, peak_thresh=0.5):
187
- probs = torch.sigmoid(logits)
188
- preds_bin = (probs >= 0.5).float()
189
- labels = (targets >= peak_thresh).float()
190
- correct = (preds_bin == labels).float().sum()
191
- total = torch.numel(labels)
192
- return (correct / max(1, total)).item() * 100.0
193
-
194
-
195
- def evaluate(model, dl, device, alpha, beta, gamma, peak_thresh, eps=1e-8):
196
- model.eval()
197
- tot_loss, tot_acc = 0.0, 0.0
198
- n_batches = 0
199
- with torch.no_grad():
200
- for b, g, t in dl:
201
- b, g, t = b.to(device), g.to(device), t.to(device)
202
- logits = model(b, g)
203
- bce_non, kl, mse_global, _ = combined_loss_components(
204
- logits, t, peak_thresh=peak_thresh, eps=eps
205
- )
206
- loss = alpha * bce_non + beta * kl + gamma * mse_global
207
- acc = accuracy_percentage(logits, t, peak_thresh=peak_thresh)
208
- tot_loss += loss.item()
209
- tot_acc += acc
210
- n_batches += 1
211
- if n_batches == 0:
212
- return float("nan"), float("nan")
213
- return tot_loss / n_batches, tot_acc / n_batches
214
-
215
-
216
- # ─────────────── cluster-aware splitting ──────────────────────────────────
217
- def assign_clusters_to_splits(
218
- cluster_to_indices, val_frac=0.10, test_frac=0.10, seed=42
219
- ):
220
- """
221
- cluster_to_indices: dict[cluster_id] -> list of example indices (from pair_list) in that cluster
222
- We greedily pack whole clusters into val/test until hitting targets (#examples), rest to train.
223
- """
224
- rng = random.Random(seed)
225
- clusters = list(cluster_to_indices.items())
226
- rng.shuffle(clusters)
227
-
228
- total = sum(len(ixs) for _, ixs in clusters)
229
- target_val = int(round(total * val_frac))
230
- target_test = int(round(total * test_frac))
231
- cur_val = cur_test = 0
232
-
233
- tr_ix, va_ix, te_ix = [], [], []
234
- for cid, ixs in clusters:
235
- c = len(ixs)
236
- if cur_val + c <= target_val:
237
- va_ix.extend(ixs)
238
- cur_val += c
239
- elif cur_test + c <= target_test:
240
- te_ix.extend(ixs)
241
- cur_test += c
242
- else:
243
- tr_ix.extend(ixs)
244
- return tr_ix, va_ix, te_ix
245
-
246
-
247
- # ─────────────── train & main ────────────────────────────────────────────
248
- def main():
249
- p = argparse.ArgumentParser()
250
- p.add_argument("--pair_list", required=True)
251
- p.add_argument("--final_csv", required=True)
252
- p.add_argument("--out_dir", required=True)
253
- p.add_argument("--epochs", type=int, default=10)
254
- p.add_argument("--batch_size", type=int, default=16)
255
- p.add_argument("--accum_steps", type=int, default=4)
256
- p.add_argument("--lr", type=float, default=1e-4)
257
- p.add_argument("--device", default="cuda")
258
- p.add_argument("--seed", type=int, default=42)
259
- p.add_argument("--alpha", type=float, default=1)
260
- p.add_argument("--beta", type=float, default=0)
261
- p.add_argument("--gamma", type=float, default=1)
262
- p.add_argument("--peak_thresh", type=float, default=0.5)
263
- # NEW: fractions for cluster-aware split (used only if cluster_id present)
264
- p.add_argument("--val_frac", type=float, default=0.10)
265
- p.add_argument("--test_frac", type=float, default=0.10)
266
- args = p.parse_args()
267
-
268
- random.seed(args.seed)
269
- np.random.seed(args.seed)
270
- torch.manual_seed(args.seed)
271
- device = torch.device(args.device if torch.cuda.is_available() else "cpu")
272
-
273
- # 1) load pair list & final.csv (now may include cluster_id)
274
- tf_paths, dna_paths = parse_pair_list(args.pair_list)
275
- final_df = pd.read_csv(args.final_csv, dtype=str)
276
- print(f"[i] Loaded {len(tf_paths)} pairs", flush=True)
277
-
278
- tf_cache = build_tf_cache(tf_paths, target_dim=256)
279
-
280
- # detect binder/DNA dims
281
- sample_tf = tf_cache[tf_paths[0]]
282
- binder_input_dim = sample_tf.shape[1] if sample_tf.ndim == 2 else sample_tf.shape[0]
283
- glm_input_dim = 256
284
-
285
- # 2) cluster-aware split if possible
286
- use_cluster_split = "cluster_id" in final_df.columns
287
- if use_cluster_split:
288
- print(
289
- "[i] Cluster column detected in final_csv; performing cluster-aware split.",
290
- flush=True,
291
- )
292
- # build dna_id -> cluster_id map
293
- cid_map = (
294
- final_df[["dna_id", "cluster_id"]]
295
- .dropna()
296
- .drop_duplicates()
297
- .set_index("dna_id")["cluster_id"]
298
- .to_dict()
299
- )
300
-
301
- # map each example (by index) to its dna_id and cluster
302
- example_dna_ids = [dna_key_from_path(p) for p in dna_paths]
303
- example_clusters = []
304
- missing = 0
305
- for did in example_dna_ids:
306
- if did in cid_map:
307
- example_clusters.append(cid_map[did])
308
- else:
309
- # fallback: treat singleton cluster
310
- example_clusters.append(f"singleton::{did}")
311
- missing += 1
312
- if missing:
313
- print(
314
- f"[WARN] {missing} dna_ids from pair_list not found in cluster map; treating as singleton clusters.",
315
- flush=True,
316
- )
317
-
318
- # build cluster -> indices
319
- cluster_to_indices = {}
320
- for i, cid in enumerate(example_clusters):
321
- cluster_to_indices.setdefault(cid, []).append(i)
322
-
323
- tr_idx, va_idx, te_idx = assign_clusters_to_splits(
324
- cluster_to_indices,
325
- val_frac=args.val_frac,
326
- test_frac=args.test_frac,
327
- seed=args.seed,
328
- )
329
- print(
330
- f"[i] Cluster split sizes (examples): train={len(tr_idx)} val={len(va_idx)} test={len(te_idx)}",
331
- flush=True,
332
- )
333
-
334
- # helper to subset paths
335
- def subset_by_indices(ixs):
336
- return [tf_paths[i] for i in ixs], [dna_paths[i] for i in ixs]
337
-
338
- tr_t, tr_d = subset_by_indices(tr_idx)
339
- va_t, va_d = subset_by_indices(va_idx)
340
- te_t, te_d = subset_by_indices(te_idx)
341
-
342
- else:
343
- print(
344
- "[i] No cluster_id in final_csv; using random 80/10/10 split (OLD behavior).",
345
- flush=True,
346
- )
347
- # OLD random split (kept, now under else)
348
- N = len(tf_paths)
349
- idxs = list(range(N))
350
- random.shuffle(idxs)
351
- n_tr = int(0.8 * N)
352
- n_va = int(0.1 * N)
353
- tr, va, te = idxs[:n_tr], idxs[n_tr : n_tr + n_va], idxs[n_tr + n_va :]
354
-
355
- def subset(idxs_):
356
- return [tf_paths[i] for i in idxs_], [dna_paths[i] for i in idxs_]
357
-
358
- tr_t, tr_d = subset(tr)
359
- va_t, va_d = subset(va)
360
- te_t, te_d = subset(te)
361
-
362
- # 3) bucketed samplers (unchanged, but now use the cluster-aware subsets when available)
363
- tr_bs = make_buckets(
364
- list(range(len(tr_t))), tr_d, args.batch_size, n_buckets=10, seed=args.seed
365
- )
366
- va_bs = make_buckets(
367
- list(range(len(va_t))), va_d, args.batch_size, n_buckets=5, seed=args.seed + 1
368
- )
369
- te_bs = make_buckets(
370
- list(range(len(te_t))), te_d, args.batch_size, n_buckets=5, seed=args.seed + 2
371
- )
372
-
373
- tr_dl = DataLoader(
374
- PairDataset(tr_t, tr_d, final_df, tf_cache),
375
- batch_sampler=ListBatchSampler(tr_bs),
376
- collate_fn=collate_fn,
377
- )
378
- va_dl = DataLoader(
379
- PairDataset(va_t, va_d, final_df, tf_cache),
380
- batch_sampler=ListBatchSampler(va_bs),
381
- collate_fn=collate_fn,
382
- )
383
- te_dl = DataLoader(
384
- PairDataset(te_t, te_d, final_df, tf_cache),
385
- batch_sampler=ListBatchSampler(te_bs),
386
- collate_fn=collate_fn,
387
- )
388
-
389
- # 4) model, optimizer, scaler
390
- model = BindPredictor(
391
- binder_input_dim=binder_input_dim,
392
- glm_input_dim=glm_input_dim,
393
- compressed_dim=256,
394
- hidden_dim=256,
395
- heads=8,
396
- num_layers=4,
397
- use_local_cnn_on_glm=True,
398
- ).to(device)
399
- optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
400
- scaler = amp.GradScaler("cuda")
401
-
402
- history, best_val = {"train": [], "val": []}, float("inf")
403
- od = Path(args.out_dir)
404
- od.mkdir(exist_ok=True, parents=True)
405
-
406
- for ep in range(1, args.epochs + 1):
407
- print(f"┌─[Epoch {ep}]────────────────────────", flush=True)
408
- model.train()
409
- optimizer.zero_grad()
410
- acc_loss_sum, acc_acc_sum, n_train_batches = 0.0, 0.0, 0
411
-
412
- for i, (b, g, t) in enumerate(tr_dl):
413
- b, g, t = b.to(device), g.to(device), t.to(device)
414
- with amp.autocast("cuda"):
415
- logits = model(b, g)
416
- bce_non, kl, mse_global, probs = combined_loss_components(
417
- logits, t, peak_thresh=args.peak_thresh
418
- )
419
- loss = args.alpha * bce_non + args.beta * kl + args.gamma * mse_global
420
- loss = loss / args.accum_steps
421
-
422
- scaler.scale(loss).backward()
423
-
424
- if (i + 1) % args.accum_steps == 0:
425
- scaler.step(optimizer)
426
- scaler.update()
427
- optimizer.zero_grad()
428
-
429
- with torch.no_grad():
430
- acc_loss_sum += loss.item() * args.accum_steps
431
- acc_acc_sum += accuracy_percentage(
432
- logits, t, peak_thresh=args.peak_thresh
433
- )
434
- n_train_batches += 1
435
-
436
- del b, g, t, logits, probs, loss, bce_non, kl, mse_global
437
- torch.cuda.empty_cache()
438
-
439
- # finalize if leftovers
440
- if n_train_batches % args.accum_steps != 0:
441
- scaler.step(optimizer)
442
- scaler.update()
443
- optimizer.zero_grad()
444
-
445
- train_loss = acc_loss_sum / max(1, n_train_batches)
446
- train_acc = acc_acc_sum / max(1, n_train_batches)
447
-
448
- val_loss, val_acc = evaluate(
449
- model,
450
- va_dl,
451
- device,
452
- alpha=args.alpha,
453
- beta=args.beta,
454
- gamma=args.gamma,
455
- peak_thresh=args.peak_thresh,
456
- )
457
- print(
458
- f"[Epoch {ep}] train_loss={train_loss:.4f} train_acc={train_acc:.2f}% "
459
- f"val_loss={val_loss:.4f} val_acc={val_acc:.2f}%",
460
- flush=True,
461
- )
462
-
463
- history["train"].append(train_loss)
464
- history["val"].append(val_loss)
465
- if val_loss < best_val:
466
- best_val = val_loss
467
- torch.save(model.state_dict(), od / "best_model.pt")
468
- print(
469
- f" Saved new best_model.pt (val_loss={val_loss:.4f}, val_acc={val_acc:.2f}%)",
470
- flush=True,
471
- )
472
-
473
- torch.save(model.state_dict(), od / "last_model.pt")
474
-
475
- fig, ax = plt.subplots()
476
- ax.plot(history["train"], label="train")
477
- ax.plot(history["val"], label="val")
478
- ax.set_xlabel("epoch")
479
- ax.set_ylabel("combined loss")
480
- ax.legend()
481
- fig.savefig(od / "loss_curve.png")
482
- print(f"✅ Done → outputs in {od}", flush=True)
483
-
484
-
485
- if __name__ == "__main__":
486
- main()
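
Editor's note: the trainer removed above splits examples by whole DNA clusters rather than at random, so near-duplicate sequences cannot leak between train, validation, and test. A minimal, self-contained sketch of that greedy packing idea on toy data (illustrative only; the function name and cluster ids are made up and are not part of this commit):

import random

def greedy_cluster_split(cluster_to_indices, val_frac=0.2, test_frac=0.2, seed=0):
    # Shuffle clusters, then send each *whole* cluster to val or test until the
    # example-count targets are met; everything left over goes to train.
    rng = random.Random(seed)
    clusters = list(cluster_to_indices.items())
    rng.shuffle(clusters)
    total = sum(len(ix) for _, ix in clusters)
    target_val, target_test = round(total * val_frac), round(total * test_frac)
    train, val, test = [], [], []
    for _, ix in clusters:
        if len(val) + len(ix) <= target_val:
            val.extend(ix)
        elif len(test) + len(ix) <= target_test:
            test.extend(ix)
        else:
            train.extend(ix)
    return train, val, test

# Three toy clusters of related DNA examples: sizes 2, 2, and 6.
splits = greedy_cluster_split({"c1": [0, 1], "c2": [2, 3], "c3": [4, 5, 6, 7, 8, 9]})
print([len(s) for s in splits])  # [6, 2, 2] -- each cluster lands entirely in one split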
dpacman/classifier/torch_model.py DELETED
@@ -1,157 +0,0 @@
1
- import torch
2
- from torch import nn
3
-
4
-
5
- class LocalCNN(nn.Module):
6
- def __init__(self, dim: int = 256, kernel_size: int = 3):
7
- super().__init__()
8
- padding = kernel_size // 2
9
- self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding)
10
- self.act = nn.GELU()
11
- self.ln = nn.LayerNorm(dim)
12
-
13
- def forward(self, x: torch.Tensor):
14
- # x: (batch, L, dim)
15
- out = self.conv(x.transpose(1, 2)) # → (batch, dim, L)
16
- out = self.act(out)
17
- out = out.transpose(1, 2) # → (batch, L, dim)
18
- return self.ln(out + x) # residual
19
-
20
-
21
- class CrossModalBlock(nn.Module):
22
- def __init__(self, dim: int = 256, heads: int = 8):
23
- super().__init__()
24
- # self-attention for both sides
25
- self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True)
26
- self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True)
27
- self.ln_b1 = nn.LayerNorm(dim)
28
- self.ln_g1 = nn.LayerNorm(dim)
29
-
30
- self.ffn_b = nn.Sequential(
31
- nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
32
- )
33
- self.ffn_g = nn.Sequential(
34
- nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
35
- )
36
- self.ln_b2 = nn.LayerNorm(dim)
37
- self.ln_g2 = nn.LayerNorm(dim)
38
-
39
- # cross attention (binder queries, glm keys/values)
40
- self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
41
- self.ln_c1 = nn.LayerNorm(dim)
42
- self.ffn_c = nn.Sequential(
43
- nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
44
- )
45
- self.ln_c2 = nn.LayerNorm(dim)
46
-
47
- def forward(self, binder: torch.Tensor, glm: torch.Tensor):
48
- """
49
- binder: (batch, Lb, dim)
50
- glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
51
- returns: updated binder representation (batch, Lb, dim)
52
- """
53
- # binder self-attn + ffn
54
- b = binder
55
- b_sa, _ = self.sa_binder(b, b, b)
56
- b = self.ln_b1(b + b_sa)
57
- b_ff = self.ffn_b(b)
58
- b = self.ln_b2(b + b_ff)
59
-
60
- # glm self-attn + ffn
61
- g = glm
62
- g_sa, _ = self.sa_glm(g, g, g)
63
- g = self.ln_g1(g + g_sa)
64
- g_ff = self.ffn_g(g)
65
- g = self.ln_g2(g + g_ff)
66
-
67
- # cross-attention: binder queries glm
68
- c_sa, _ = self.cross_attn(b, g, g)
69
- c = self.ln_c1(b + c_sa)
70
- c_ff = self.ffn_c(c)
71
- c = self.ln_c2(c + c_ff)
72
- return c # (batch, Lb, dim)
73
-
74
-
75
- class DimCompressor(nn.Module):
76
- """
77
- Learnable per-token compressor: maps any in_dim >= out_dim to out_dim (default 256).
78
- If in_dim == out_dim, behaves as identity.
79
- """
80
-
81
- def __init__(self, in_dim: int, out_dim: int = 256):
82
- super().__init__()
83
- if in_dim == out_dim:
84
- self.net = nn.Identity()
85
- else:
86
- hidden = max(out_dim * 2, (in_dim + out_dim) // 2)
87
- self.net = nn.Sequential(
88
- nn.LayerNorm(in_dim),
89
- nn.Linear(in_dim, hidden),
90
- nn.GELU(),
91
- nn.Linear(hidden, out_dim),
92
- )
93
-
94
- def forward(self, x: torch.Tensor) -> torch.Tensor:
95
- # x: (B, L, in_dim)
96
- return self.net(x)
97
-
98
-
99
- class BindPredictor(nn.Module):
100
- def __init__(
101
- self,
102
- # input_dim: int = 256, # OLD: single input dim
103
- binder_input_dim: int = 1280, # NEW: TF (binder) original dim (e.g., 1280)
104
- glm_input_dim: int = 256, # NEW: DNA/GLM original dim (e.g., 256)
105
- compressed_dim: int = 256, # NEW: learnable compressed dim
106
- hidden_dim: int = 256,
107
- heads: int = 8,
108
- num_layers: int = 4,
109
- use_local_cnn_on_glm: bool = True,
110
- ):
111
- super().__init__()
112
- # OLD:
113
- # self.proj_binder = nn.Linear(input_dim, hidden_dim)
114
- # self.proj_glm = nn.Linear(input_dim, hidden_dim)
115
-
116
- # NEW: learnable compressor for binder → 256, then project to hidden
117
- self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim)
118
- self.proj_binder = nn.Linear(compressed_dim, hidden_dim)
119
-
120
- # GLM side stays 256 → hidden
121
- self.proj_glm = nn.Linear(glm_input_dim, hidden_dim)
122
-
123
- self.use_local_cnn = use_local_cnn_on_glm
124
- self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity()
125
-
126
- self.layers = nn.ModuleList(
127
- [CrossModalBlock(hidden_dim, heads) for _ in range(num_layers)]
128
- )
129
-
130
- self.ln_out = nn.LayerNorm(hidden_dim)
131
- # self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid()) # OLD: returned probabilities
132
- self.head = nn.Linear(hidden_dim, 1) # NEW: return logits (safe for AMP)
133
-
134
- def forward(self, binder_emb, glm_emb):
135
- """
136
- binder_emb: (B, Lb, binder_input_dim)
137
- glm_emb: (B, Lg, glm_input_dim)
138
- Returns per-nucleotide logits for the GLM sequence: (B, Lg)
139
- """
140
- # Binder: learnable compression → 256 → hidden
141
- b = self.binder_compress(binder_emb) # (B, Lb, 256)
142
- b = self.proj_binder(b) # (B, Lb, hidden_dim)
143
-
144
- # GLM: project → hidden, add local CNN context
145
- g = self.proj_glm(glm_emb) # (B, Lg, hidden_dim)
146
- if self.use_local_cnn:
147
- g = self.local_cnn(g)
148
-
149
- # Cross-modal blocks: update binder states using GLM
150
- for layer in self.layers:
151
- b = layer(b, g) # (B, Lb, hidden_dim)
152
-
153
- # Predict per-nucleotide logits on the GLM tokens:
154
- # return self.head(g).squeeze(-1) # OLD: probabilities (with Sigmoid in head)
155
- return self.head(g).squeeze(
156
- -1
157
- ) # NEW: logits (apply sigmoid only in loss/metrics)
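
Editor's note: for orientation, the BindPredictor removed here consumed a per-residue TF embedding (width binder_input_dim, default 1280) and a per-nucleotide DNA/GLM embedding (width glm_input_dim, default 256) and returned one logit per nucleotide. A minimal shape check, assuming the class as defined above is in scope; the batch size, sequence lengths, and layer count below are arbitrary choices, not values from the commit:

import torch

model = BindPredictor(binder_input_dim=1280, glm_input_dim=256,
                      compressed_dim=256, hidden_dim=256, heads=8, num_layers=2)
binder = torch.randn(2, 35, 1280)   # (batch, TF length, binder width)
dna    = torch.randn(2, 500, 256)   # (batch, DNA length, GLM width)
with torch.no_grad():
    logits = model(binder, dna)
print(logits.shape)                  # torch.Size([2, 500]) -- one logit per nucleotide
probs = torch.sigmoid(logits)        # head returns logits; sigmoid belongs in loss/metrics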
dpacman/classifier/train.py DELETED
@@ -1,220 +0,0 @@
1
- #!/usr/bin/env python3
2
- import argparse
3
- import numpy as np
4
- import torch
5
- from torch import nn
6
- from model import BindPredictor
7
- from pathlib import Path
8
- from collections import Counter
9
- from sklearn.metrics import roc_auc_score, average_precision_score
10
- from sklearn.decomposition import TruncatedSVD
11
- import sys
12
-
13
- from dpacman.utils.models import set_seed
14
-
15
-
16
- def build_tf_compressed_cache(binder_paths, target_dim=256):
17
- """
18
- Load all unique TF (binder) embeddings, fit reduction if needed, and return dict mapping path->(L, target_dim) array.
19
- """
20
- unique_paths = sorted(set(binder_paths))
21
- print(
22
- f"[i] Found {len(unique_paths)} unique TF embedding files to compress.",
23
- flush=True,
24
- )
25
- # Load all embeddings to determine dimensionality
26
- samples = []
27
- for p in unique_paths:
28
- arr = np.load(p)
29
- samples.append(arr)
30
- # Determine if reduction needed: assume all have same embedding width
31
- first = samples[0]
32
- orig_dim = first.shape[1] if first.ndim == 2 else 1
33
- reduction_needed = orig_dim != target_dim
34
- tf_cache = {}
35
-
36
- if reduction_needed:
37
- # Build matrix to fit SVD: we need a 2D matrix per embedding; if lengths vary we can't directly stack.
38
- # We'll do reduction per sequence individually using TruncatedSVD on concatenated flattened features:
39
- # Simplest: for variable lengths, reduce each embedding separately with a learned linear projection.
40
- # Here we fit a single TruncatedSVD on the concatenation of all sequence tokens (flattened) by padding/truncating to a fixed length.
41
- # To avoid complexity, use PCA-like linear projection learned via SVD on mean-pooled vectors:
42
- pooled = []
43
- for arr in samples:
44
- if arr.ndim == 2:
45
- pooled.append(arr.mean(axis=0)) # (orig_dim,)
46
- else:
47
- pooled.append(arr) # degenerate
48
- pooled_mat = np.stack(pooled, axis=0) # (N, orig_dim)
49
- print(
50
- f"[i] Fitting TruncatedSVD on TF pooled embeddings: {pooled_mat.shape} -> {target_dim}",
51
- flush=True,
52
- )
53
- svd = TruncatedSVD(n_components=target_dim, random_state=42)
54
- reduced_pooled = svd.fit_transform(pooled_mat) # (N, target_dim)
55
-
56
- # For each original embedding, project token-level vectors by multiplying token vector with svd.components_.T
57
- # svd.components_: (target_dim, orig_dim) so projection matrix is (orig_dim, target_dim)
58
- proj_mat = svd.components_.T # (orig_dim, target_dim)
59
- for i, p in enumerate(unique_paths):
60
- arr = samples[i] # shape (L, orig_dim)
61
- if arr.ndim == 1:
62
- arr2 = arr @ proj_mat # (target_dim,)
63
- else:
64
- # project each token: (L, orig_dim) @ (orig_dim, target_dim) -> (L, target_dim)
65
- arr2 = arr @ proj_mat
66
- tf_cache[p] = arr2 # reduced per-token representation
67
- print("[i] Completed compression of TF embeddings.", flush=True)
68
- else:
69
- # already correct dim: just cache originals
70
- print(
71
- f"[i] TF embeddings already {target_dim}-dimensional; skipping reduction.",
72
- flush=True,
73
- )
74
- for i, p in enumerate(unique_paths):
75
- arr = samples[i]
76
- tf_cache[p] = arr
77
- return tf_cache
78
-
79
-
80
- def evaluate(model, dl, device):
81
- model.eval()
82
- all_labels = []
83
- all_preds = []
84
- with torch.no_grad():
85
- for b, g, y in dl:
86
- b = b.to(device)
87
- g = g.to(device)
88
- y = y.to(device)
89
- pred = model(b, g)
90
- all_labels.append(y.cpu())
91
- all_preds.append(pred.cpu())
92
- if not all_labels:
93
- return 0.0, 0.0
94
- y_true = torch.cat(all_labels).numpy()
95
- y_score = torch.cat(all_preds).numpy()
96
- try:
97
- auc = roc_auc_score(y_true, y_score)
98
- except Exception:
99
- auc = 0.0
100
- try:
101
- ap = average_precision_score(y_true, y_score)
102
- except Exception:
103
- ap = 0.0
104
- return auc, ap
105
-
106
-
107
- def unpack(data):
108
- binders, glms, labels = zip(*data)
109
- return list(binders), list(glms), list(labels)
110
-
111
-
112
- # ---- main ------------------------------------------------------------
113
- def main(cfg):
114
- """
115
- Main method, used to train the model.
116
- """
117
- # Set seed for reproducibility
118
- set_seed(cfg.seed)
119
-
120
- parser.add_argument("--out_dir", type=str, required=True)
121
- parser.add_argument("--epochs", type=int, default=10)
122
- parser.add_argument("--batch_size", type=int, default=32)
123
- parser.add_argument("--lr", type=float, default=1e-4)
124
- parser.add_argument("--device", type=str, default="cuda")
125
- parser.add_argument("--seed", type=int, default=42)
126
- args = parser.parse_args()
127
-
128
- #
129
- print("DEBUG: starting training script with in-line TF compression", flush=True)
130
- device = torch.device(args.device if torch.cuda.is_available() else "cpu")
131
- binder_paths, glm_paths, labels = parse_pair_list(cfg.pair_list)
132
-
133
- if len(labels) == 0:
134
- print("[ERROR] No valid pairs parsed. Exiting.", file=sys.stderr)
135
- sys.exit(1)
136
-
137
- label_counts = Counter(labels)
138
- print(
139
- f"[i] Total examples parsed: {len(labels)}. Label distribution: {label_counts}",
140
- flush=True,
141
- )
142
-
143
- # build compressed TF cache (reduces to 256 if needed)
144
- tf_compressed_cache = build_tf_compressed_cache(binder_paths, target_dim=256)
145
-
146
- # load training data
147
-
148
- train_ds = PairDataset(None, tf_compressed_cache=tf_compressed_cache)
149
- val_ds = PairDataset(*subset(val_i), tf_compressed_cache=tf_compressed_cache)
150
- test_ds = PairDataset(*subset(test_i), tf_compressed_cache=tf_compressed_cache)
151
-
152
- print(
153
- f"[i] Train/Val/Test sizes: {len(train_ds)}/{len(val_ds)}/{len(test_ds)}",
154
- flush=True,
155
- )
156
- if len(train_ds) == 0 or len(val_ds) == 0:
157
- print(
158
- "[ERROR] Train or validation split is empty; cannot proceed.",
159
- file=sys.stderr,
160
- )
161
- sys.exit(1)
162
-
163
- train_dl = DataLoader(
164
- train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn
165
- )
166
- val_dl = DataLoader(
167
- val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn
168
- )
169
- test_dl = DataLoader(
170
- test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn
171
- )
172
-
173
- model = BindPredictor(
174
- input_dim=256, hidden_dim=256, heads=8, num_layers=3, use_local_cnn_on_glm=True
175
- )
176
- model = model.to(device)
177
- optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-3)
178
- loss_fn = nn.BCELoss()
179
-
180
- best_val = -float("inf")
181
- os_out = Path(args.out_dir)
182
- os_out.mkdir(exist_ok=True, parents=True)
183
-
184
- for epoch in range(1, args.epochs + 1):
185
- print(f"[Epoch {epoch}] starting...", flush=True)
186
- model.train()
187
- running_loss = 0.0
188
- for b, g, y in train_dl:
189
- b = b.to(device)
190
- g = g.to(device)
191
- y = y.to(device)
192
- pred = model(b, g)
193
- loss = loss_fn(pred, y)
194
- optimizer.zero_grad()
195
- loss.backward()
196
- optimizer.step()
197
- running_loss += loss.item() * b.size(0)
198
- train_loss = running_loss / len(train_ds)
199
- val_auc, val_ap = evaluate(model, val_dl, device)
200
- print(
201
- f"[Epoch {epoch}] train_loss={train_loss:.4f} val_auc={val_auc:.4f} val_ap={val_ap:.4f}",
202
- flush=True,
203
- )
204
-
205
- if val_auc > best_val:
206
- best_val = val_auc
207
- torch.save(model.state_dict(), os_out / "best_model.pt")
208
- print(
209
- f"[Epoch {epoch}] Saved new best model with val_auc={val_auc:.4f}",
210
- flush=True,
211
- )
212
-
213
- torch.save(model.state_dict(), os_out / "last_model.pt")
214
- test_auc, test_ap = evaluate(model, test_dl, device)
215
- print(f"FINAL TEST: AUC={test_auc:.4f} AP={test_ap:.4f}", flush=True)
216
- print(f"[i] Models written to {os_out}/best_model.pt and last_model.pt", flush=True)
217
-
218
-
219
- if __name__ == "__main__":
220
- main()
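
Editor's note: the deleted script above compresses TF embeddings of arbitrary length by fitting TruncatedSVD on mean-pooled vectors (one row per TF) and then reusing the fitted components to project every token. A minimal sketch of that trick on random data (illustrative only; the real pipeline reduced 1280 to 256 dimensions, the smaller sizes here just keep the toy example fast):

import numpy as np
from sklearn.decomposition import TruncatedSVD

rng = np.random.default_rng(0)
# Forty fake TF embeddings with different lengths but a shared 64-dim width.
embeddings = [rng.standard_normal((int(rng.integers(20, 60)), 64)) for _ in range(40)]

# 1) Fit SVD on the mean-pooled (N, width) matrix, one row per TF.
pooled = np.stack([e.mean(axis=0) for e in embeddings])    # (40, 64)
svd = TruncatedSVD(n_components=16, random_state=42).fit(pooled)

# 2) Reuse the fitted components to project each token: (L, 64) @ (64, 16) -> (L, 16).
proj = svd.components_.T
compressed = [e @ proj for e in embeddings]
print(compressed[0].shape)                                  # (L, 16)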
dpacman/scripts/delay_run.sh CHANGED
@@ -4,7 +4,7 @@ set -euo pipefail
4
  # Usage: ./stagger.sh <first_script.sh> <second_script.sh>
5
  # Optional: override waits via env vars WAIT1 / WAIT2 (seconds). Defaults: 3 hours each.
6
 
7
- WAIT1=${WAIT1:-10800} # 3 hours in seconds
8
  WAIT2=${WAIT2:-10800}
9
 
10
  SCRIPT1="${1:?usage: $0 <first_script.sh> <second_script.sh>}"
 
4
  # Usage: ./stagger.sh <first_script.sh> <second_script.sh>
5
  # Optional: override waits via env vars WAIT1 / WAIT2 (seconds). Defaults: 3 hours each.
6
 
7
+ WAIT1=${WAIT1:-3600} # 1 hour in seconds
8
  WAIT2=${WAIT2:-10800}
9
 
10
  SCRIPT1="${1:?usage: $0 <first_script.sh> <second_script.sh>}"
dpacman/scripts/run_train.sh CHANGED
@@ -22,7 +22,7 @@ CUDA_VISIBLE_DEVICES=0,1 nohup python -u -m scripts.train \
22
  +trainer.gradient_clip_algorithm="norm" \
23
  hydra.run.dir="${run_dir}" \
24
  trainer.devices=2 \
25
- trainer.max_epochs=10 \
26
  data_module.train_file="data_files/processed/splits/by_dna/train.csv" \
27
  data_module.val_file="data_files/processed/splits/by_dna/val.csv" \
28
  data_module.test_file="data_files/processed/splits/by_dna/test.csv" \
@@ -31,8 +31,8 @@ CUDA_VISIBLE_DEVICES=0,1 nohup python -u -m scripts.train \
31
  data_module.batch_size=16 \
32
  model.glm_input_dim=256 \
33
  model.compressed_dim=256 \
34
- model.hidden_dim=256 \
35
- model.lr=5e-6 \
36
  > "${run_dir}/run.log" 2>&1 &
37
 
38
  echo $! > "${run_dir}/pid.txt"
 
22
  +trainer.gradient_clip_algorithm="norm" \
23
  hydra.run.dir="${run_dir}" \
24
  trainer.devices=2 \
25
+ trainer.max_epochs=20 \
26
  data_module.train_file="data_files/processed/splits/by_dna/train.csv" \
27
  data_module.val_file="data_files/processed/splits/by_dna/val.csv" \
28
  data_module.test_file="data_files/processed/splits/by_dna/test.csv" \
 
31
  data_module.batch_size=16 \
32
  model.glm_input_dim=256 \
33
  model.compressed_dim=256 \
34
+ model.hidden_dim=128 \
35
+ model.lr=1e-5 \
36
  > "${run_dir}/run.log" 2>&1 &
37
 
38
  echo $! > "${run_dir}/pid.txt"
dpacman/scripts/run_train_baseline.sh ADDED
@@ -0,0 +1,39 @@
1
+ #!/bin/bash
2
+
3
+ # Manually specify values used in the config
4
+ main_task="train"
5
+ model_type="baseline"
6
+ timestamp=$(date "+%Y-%m-%d_%H-%M-%S")
7
+
8
+ run_dir="$HOME/DPACMAN/logs/${main_task}/${model_type}/runs/${timestamp}"
9
+ mkdir -p "$run_dir"
10
+
11
+ if [ -z "$WANDB_API_KEY" ]; then
12
+ read -s -p "Enter your WANDB API key: " wandb_key
13
+ echo
14
+ export WANDB_API_KEY="$wandb_key"
15
+ fi
16
+
17
+ CUDA_VISIBLE_DEVICES=0,1 nohup python -u -m scripts.train \
18
+ +trainer.strategy=ddp \
19
+ +trainer.use_distributed_sampler="false" \
20
+ +trainer.detect_anomaly="false" \
21
+ +trainer.gradient_clip_val=0.5 \
22
+ +trainer.gradient_clip_algorithm="norm" \
23
+ hydra.run.dir="${run_dir}" \
24
+ trainer.devices=2 \
25
+ trainer.max_epochs=10 \
26
+ data_module.train_file="data_files/processed/splits/by_dna/train.csv" \
27
+ data_module.val_file="data_files/processed/splits/by_dna/val.csv" \
28
+ data_module.test_file="data_files/processed/splits/by_dna/test.csv" \
29
+ data_module.tr_shelf_path="data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf" \
30
+ data_module.dna_shelf_path="data_files/processed/embeddings/fimo_hits_only/peaks_caduceus.shelf" \
31
+ data_module.batch_size=16 \
32
+ model=baseline \
33
+ model.glm_input_dim=256 \
34
+ model.compressed_dim=256 \
35
+ model.hidden_dim=128 \
36
+ model.lr=1e-5 \
37
+ > "${run_dir}/run.log" 2>&1 &
38
+
39
+ echo $! > "${run_dir}/pid.txt"