svincoff committed on
Commit
9da03b7
·
1 Parent(s): 7b33404

added dropout and overfit prevention

.gitignore CHANGED
@@ -40,4 +40,6 @@ log.log
40
  log2.log
41
  dpacman/delay.log
42
  dpacman/view_profiles.ipynb
43
- dpacman/find_wandb_run_dirs.py
 
 
 
40
  log2.log
41
  dpacman/delay.log
42
  dpacman/view_profiles.ipynb
43
+ dpacman/find_wandb_run_dirs.py
44
+ dpacman/delay_binary.log
45
+ dpacman/delay_mix.log
configs/data_module/pair.yaml CHANGED
@@ -6,6 +6,7 @@ test_file: data_files/processed/splits/by_dna/babytest.csv
6
 
7
  target_col: dna_sequence
8
  score_col: scores
 
9
 
10
  tr_shelf_path: data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf
11
  dna_shelf_path: data_files/processed/embeddings/fimo_hits_only/baby_peaks_segmentnt_pernuc_with_onehot.shelf
 
6
 
7
  target_col: dna_sequence
8
  score_col: scores
9
+ norm_value: 1333
10
 
11
  tr_shelf_path: data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf
12
  dna_shelf_path: data_files/processed/embeddings/fimo_hits_only/baby_peaks_segmentnt_pernuc_with_onehot.shelf
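The new norm_value entry sets the divisor used to rescale raw per-nucleotide scores into [0, 1] before they become labels. A minimal sketch of that normalization, assuming the comma-separated score format used by PairDataset later in this commit (the 1333 maximum and round_to=3 are illustrative):

def normalize_scores(score_str: str, norm_value: int = 1333, round_to: int = 3) -> list:
    # Split the comma-separated string and divide each raw score by norm_value.
    return [round(int(s) / norm_value, round_to) for s in score_str.split(",")]

print(normalize_scores("0,666,1333"))  # -> [0.0, 0.5, 1.0]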
configs/data_task/split/remap.yaml CHANGED
@@ -12,7 +12,9 @@ split_out_dir: dpacman/data_files/processed/splits
12
 
13
  dna_map_path: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/maps/dna_seqid_to_dna_sequence.json
14
 
15
- split_by: both # protein, dna, or both
 
 
16
  augment_rc: true
17
 
18
  test_ratio: 0.10
 
12
 
13
  dna_map_path: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/maps/dna_seqid_to_dna_sequence.json
14
 
15
+ split_by: dna # protein, dna, or both
16
+ test_trs: ["trseq23","trseq26","trseq17"]
17
+ test_dnas: null
18
  augment_rc: true
19
 
20
  test_ratio: 0.10
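With split_by: dna and a predefined test_trs list, the named TR sequence IDs are pinned to the test set and any DNA cluster they touch is withheld from train and val. A simplified sketch of that exclusion, mirroring split_with_predefined_test added later in this commit (the toy frame and IDs are made up; the real function also excludes the matching TR clusters):

import pandas as pd

# Hypothetical rows with the columns the splitter expects.
df = pd.DataFrame({
    "ID": ["a", "b", "c", "d"],
    "tr_seqid": ["trseq23", "trseq26", "trseq01", "trseq02"],
    "dna_cluster_rep": ["d1", "d1", "d2", "d3"],
})

test_trs = ["trseq23", "trseq26", "trseq17"]           # values from the config above
test = df[df["tr_seqid"].isin(test_trs)]
excluded = test["dna_cluster_rep"].unique()
train_val = df[~df["dna_cluster_rep"].isin(excluded)]  # drop DNA clusters seen in test
print(len(test), len(train_val))                       # 2 2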
configs/model/classifier.yaml CHANGED
@@ -4,6 +4,7 @@ lr: 1e-4
4
  alpha: 20
5
  gamma: 20
6
  weight_decay: 0.01
 
7
 
8
  glm_input_dim: 1029
9
  compressed_dim: 1029
 
4
  alpha: 20
5
  gamma: 20
6
  weight_decay: 0.01
7
+ dropout: 0.1
8
 
9
  glm_input_dim: 1029
10
  compressed_dim: 1029
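The new dropout entry is a model hyperparameter consumed by BindPredictor, which passes it through to LocalCNN and every CrossModalBlock (see the model.py diff below). A minimal sketch of reading the value from this config, assuming OmegaConf-style loading rather than the project's actual Hydra wiring:

from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/model/classifier.yaml")   # path is illustrative
# model = BindPredictor(lr=cfg.lr, weight_decay=cfg.weight_decay, dropout=cfg.dropout, ...)
print(cfg.dropout)  # 0.1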
dpacman/benchmark/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ This folder is for benchmarking the trained classifier.
dpacman/benchmark/__init__.py ADDED
File without changes
dpacman/classifier/model.py CHANGED
@@ -11,82 +11,96 @@ from .loss import calculate_loss, auprc_zeros_vs_ones_from_logits, auroc_zeros_v
11
  set_seed()
12
 
13
  class LocalCNN(nn.Module):
14
- def __init__(self, dim: int = 256, kernel_size: int = 3):
15
  super().__init__()
16
  padding = kernel_size // 2
17
  self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding)
18
  self.act = nn.GELU()
19
  self.ln = nn.LayerNorm(dim)
 
 
20
 
21
  def forward(self, x: torch.Tensor):
22
  # x: (batch, L, dim)
23
  out = self.conv(x.transpose(1, 2)) # → (batch, dim, L)
24
  out = self.act(out)
 
25
  out = out.transpose(1, 2) # → (batch, L, dim)
26
  return self.ln(out + x) # residual
27
 
28
 
29
  class CrossModalBlock(nn.Module):
30
- def __init__(self, dim: int = 256, heads: int = 8, dropout: float = 0.0):
31
  super().__init__()
32
  # self-attention for both sides
33
- self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True)
34
- self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True)
 
 
35
  # first layer norms
36
  self.ln_b1 = nn.LayerNorm(dim)
37
  self.ln_g1 = nn.LayerNorm(dim)
38
  # first feed forward networks
39
  self.ffn_b1 = nn.Sequential(
40
- nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
41
  )
42
  self.ffn_g1 = nn.Sequential(
43
- nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
44
  )
 
 
 
45
  self.ln_b2 = nn.LayerNorm(dim)
46
  self.ln_g2 = nn.LayerNorm(dim)
47
 
48
  # 2) reciprocal cross-attn: g<-b and b<-g
49
  # DNA/GLM updated by attending to Binder
50
  self.cross_g2b_1_RCA = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
 
51
  self.ln_g3_RCA = nn.LayerNorm(dim)
52
- self.ffn_g2_RCA = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
 
53
  self.ln_g4_RCA = nn.LayerNorm(dim)
54
 
55
  # Binder updated by attending to DNA/GLM
56
  self.cross_b2g_1_RCA = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
 
57
  self.ln_b3_RCA = nn.LayerNorm(dim)
58
- self.ffn_b2_RCA = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
 
59
  self.ln_b4_RCA = nn.LayerNorm(dim)
60
 
61
  # cross attention (binder queries, glm keys/values)
62
- # so the NDA path is updated by the transcriptoin factors
63
  self.cross_g2b_2 = nn.MultiheadAttention(dim, heads, batch_first=True)
 
64
  self.ln_g5 = nn.LayerNorm(dim)
65
  self.ffn_g3 = nn.Sequential(
66
- nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
67
  )
 
68
  self.ln_g6 = nn.LayerNorm(dim)
69
 
70
  def forward(self, binder: torch.Tensor, glm: torch.Tensor, binder_kpm_mask=None, glm_kpm_mask=None):
71
  """
72
  binder: (batch, Lb, dim)
73
  glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
74
- returns: updated binder representation (batch, Lb, dim)
75
  """
76
- # 1) Self-attentino and feed-forward networks for binder and DNA
77
  # binder: self-attn + ffn
78
  b = binder
79
  b_sa, _ = self.sa_binder(b, b, b, key_padding_mask=binder_kpm_mask)
80
- b = self.ln_b1(b + b_sa)
81
  b_ff = self.ffn_b1(b)
82
- b = self.ln_b2(b + b_ff)
83
 
84
  # glm: self-attn + ffn
85
  g = glm
86
  g_sa, _ = self.sa_glm(g, g, g, key_padding_mask=glm_kpm_mask)
87
- g = self.ln_g1(g + g_sa)
88
  g_ff = self.ffn_g1(g)
89
- g = self.ln_g2(g + g_ff)
90
 
91
  # 2a) Reciprocal Cross-Attention:
92
  # DNA updated by attending to Binder (Q=g, K=b, V=b)
@@ -97,22 +111,22 @@ class CrossModalBlock(nn.Module):
97
  # invert if your mask is True=keep:
98
  # key_padding_mask=(~binder_mask.bool()) if binder_mask is not None else None
99
  )
100
- g = self.ln_g3_RCA(g + g_ca)
101
- g = self.ln_g4_RCA(g + self.ffn_g2_RCA(g))
102
 
103
  # 2b) Binder updated by attending to DNA/GLM (Q=b, K=g, V=g)
104
  b_ca, _ = self.cross_b2g_1_RCA(
105
  b, g, g, key_padding_mask=glm_kpm_mask
106
  # key_padding_mask=(~glm_mask.bool()) if glm_mask is not None else None
107
  )
108
- b = self.ln_b3_RCA(b + b_ca)
109
- b = self.ln_b4_RCA(b + self.ffn_b2_RCA(b))
110
 
111
  # cross-attention: glm queries binder and glm embeddings are updated
112
  g_to_b_ca, _ = self.cross_g2b_2(g, b, b, key_padding_mask=binder_kpm_mask)
113
- g = self.ln_g5(g + g_to_b_ca)
114
  g_ff = self.ffn_g3(g)
115
- g = self.ln_g6(g + g_ff)
116
  return b, g # (batch, Lb, dim)
117
 
118
  class DimCompressor(nn.Module):
@@ -164,12 +178,15 @@ class BindPredictor(LightningModule):
164
  # Learnable compressor for binder -> 256, then project to hidden
165
  self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim)
166
  self.proj_binder = nn.Linear(compressed_dim, hidden_dim)
 
 
167
 
168
  # GLM side stays 256 -> hidden
169
  self.proj_glm = nn.Linear(glm_input_dim, hidden_dim)
170
-
 
171
  self.use_local_cnn = use_local_cnn_on_glm
172
- self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity()
173
 
174
  self.layers = nn.ModuleList(
175
  [CrossModalBlock(hidden_dim, heads, self.hparams.dropout) for _ in range(num_layers)]
@@ -188,9 +205,11 @@ class BindPredictor(LightningModule):
188
  # Binder: learnable compression → 256 → hidden
189
  b = self.binder_compress(binder_emb) # (B, Lb, 256)
190
  b = self.proj_binder(b) # (B, Lb, hidden_dim)
 
191
 
192
  # GLM: project → hidden, add local CNN context
193
  g = self.proj_glm(glm_emb) # (B, Lg, hidden_dim)
 
194
  if self.use_local_cnn:
195
  g = self.local_cnn(g)
196
 
 
11
  set_seed()
12
 
13
  class LocalCNN(nn.Module):
14
+ def __init__(self, dim: int = 256, kernel_size: int = 3, dropout=0.1):
15
  super().__init__()
16
  padding = kernel_size // 2
17
  self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding)
18
  self.act = nn.GELU()
19
  self.ln = nn.LayerNorm(dim)
20
+
21
+ self.dropout = nn.Dropout(dropout)
22
 
23
  def forward(self, x: torch.Tensor):
24
  # x: (batch, L, dim)
25
  out = self.conv(x.transpose(1, 2)) # → (batch, dim, L)
26
  out = self.act(out)
27
+ out = self.dropout(out) # dropout before the layer norm
28
  out = out.transpose(1, 2) # → (batch, L, dim)
29
  return self.ln(out + x) # residual
30
 
31
 
32
  class CrossModalBlock(nn.Module):
33
+ def __init__(self, dim: int = 256, heads: int = 8, dropout: float = 0.1):
34
  super().__init__()
35
  # self-attention for both sides
36
+ self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
37
+ self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
38
+ self.do_sa_b = nn.Dropout(dropout)
39
+ self.do_sa_g = nn.Dropout(dropout)
40
  # first layer norms
41
  self.ln_b1 = nn.LayerNorm(dim)
42
  self.ln_g1 = nn.LayerNorm(dim)
43
  # first feed forward networks
44
  self.ffn_b1 = nn.Sequential(
45
+ nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim * 4, dim)
46
  )
47
  self.ffn_g1 = nn.Sequential(
48
+ nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim * 4, dim)
49
  )
50
+ self.do_ffn_b1 = nn.Dropout(dropout)
51
+ self.do_ffn_g1 = nn.Dropout(dropout)
52
+
53
  self.ln_b2 = nn.LayerNorm(dim)
54
  self.ln_g2 = nn.LayerNorm(dim)
55
 
56
  # 2) reciprocal cross-attn: g<-b and b<-g
57
  # DNA/GLM updated by attending to Binder
58
  self.cross_g2b_1_RCA = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
59
+ self.do_rca_g = nn.Dropout(dropout)
60
  self.ln_g3_RCA = nn.LayerNorm(dim)
61
+ self.ffn_g2_RCA = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim*4, dim))
62
+ self.do_ffn_g2 = nn.Dropout(dropout)
63
  self.ln_g4_RCA = nn.LayerNorm(dim)
64
 
65
  # Binder updated by attending to DNA/GLM
66
  self.cross_b2g_1_RCA = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
67
+ self.do_rca_b = nn.Dropout(dropout)
68
  self.ln_b3_RCA = nn.LayerNorm(dim)
69
+ self.ffn_b2_RCA = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim*4, dim))
70
+ self.do_ffn_b2 = nn.Dropout(dropout)
71
  self.ln_b4_RCA = nn.LayerNorm(dim)
72
 
73
  # cross attention (binder queries, glm keys/values)
74
+ # so the DNA path is updated by the transcription factors
75
  self.cross_g2b_2 = nn.MultiheadAttention(dim, heads, batch_first=True)
76
+ self.do_g2b2 = nn.Dropout(dropout)
77
  self.ln_g5 = nn.LayerNorm(dim)
78
  self.ffn_g3 = nn.Sequential(
79
+ nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim * 4, dim)
80
  )
81
+ self.do_ffn_g3 = nn.Dropout(dropout)
82
  self.ln_g6 = nn.LayerNorm(dim)
83
 
84
  def forward(self, binder: torch.Tensor, glm: torch.Tensor, binder_kpm_mask=None, glm_kpm_mask=None):
85
  """
86
  binder: (batch, Lb, dim)
87
  glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
88
+ returns: updated binder representation (batch, Lb, dim) and gLM representation
89
  """
90
+ # 1) Self-attention and feed-forward networks for binder and DNA
91
  # binder: self-attn + ffn
92
  b = binder
93
  b_sa, _ = self.sa_binder(b, b, b, key_padding_mask=binder_kpm_mask)
94
+ b = self.ln_b1(b + self.do_sa_b(b_sa))
95
  b_ff = self.ffn_b1(b)
96
+ b = self.ln_b2(b + self.do_ffn_b1(b_ff))
97
 
98
  # glm: self-attn + ffn
99
  g = glm
100
  g_sa, _ = self.sa_glm(g, g, g, key_padding_mask=glm_kpm_mask)
101
+ g = self.ln_g1(g + self.do_sa_g(g_sa))
102
  g_ff = self.ffn_g1(g)
103
+ g = self.ln_g2(g + self.do_ffn_g1(g_ff))
104
 
105
  # 2a) Reciprocal Cross-Attention:
106
  # DNA updated by attending to Binder (Q=g, K=b, V=b)
 
111
  # invert if your mask is True=keep:
112
  # key_padding_mask=(~binder_mask.bool()) if binder_mask is not None else None
113
  )
114
+ g = self.ln_g3_RCA(g + self.do_rca_g(g_ca))
115
+ g = self.ln_g4_RCA(g + self.do_ffn_g2(self.ffn_g2_RCA(g)))
116
 
117
  # 2b) Binder updated by attending to DNA/GLM (Q=b, K=g, V=g)
118
  b_ca, _ = self.cross_b2g_1_RCA(
119
  b, g, g, key_padding_mask=glm_kpm_mask
120
  # key_padding_mask=(~glm_mask.bool()) if glm_mask is not None else None
121
  )
122
+ b = self.ln_b3_RCA(b + self.do_rca_b(b_ca))
123
+ b = self.ln_b4_RCA(b + self.do_ffn_b2(self.ffn_b2_RCA(b)))
124
 
125
  # cross-attention: glm queries binder and glm embeddings are updated
126
  g_to_b_ca, _ = self.cross_g2b_2(g, b, b, key_padding_mask=binder_kpm_mask)
127
+ g = self.ln_g5(g + self.do_g2b2(g_to_b_ca))
128
  g_ff = self.ffn_g3(g)
129
+ g = self.ln_g6(g + self.do_ffn_g3(g_ff))
130
  return b, g # (batch, Lb, dim)
131
 
132
  class DimCompressor(nn.Module):
 
178
  # Learnable compressor for binder -> 256, then project to hidden
179
  self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim)
180
  self.proj_binder = nn.Linear(compressed_dim, hidden_dim)
181
+ self.dropout_b1 = nn.Dropout(dropout)
182
+ self.act = nn.GELU()
183
 
184
  # GLM side stays 256 -> hidden
185
  self.proj_glm = nn.Linear(glm_input_dim, hidden_dim)
186
+ self.dropout_g1 = nn.Dropout(dropout)
187
+
188
  self.use_local_cnn = use_local_cnn_on_glm
189
+ self.local_cnn = LocalCNN(hidden_dim, dropout=self.hparams.dropout) if use_local_cnn_on_glm else nn.Identity()
190
 
191
  self.layers = nn.ModuleList(
192
  [CrossModalBlock(hidden_dim, heads, self.hparams.dropout) for _ in range(num_layers)]
 
205
  # Binder: learnable compression → 256 → hidden
206
  b = self.binder_compress(binder_emb) # (B, Lb, 256)
207
  b = self.proj_binder(b) # (B, Lb, hidden_dim)
208
+ b = self.dropout_b1(self.act(b))
209
 
210
  # GLM: project → hidden, add local CNN context
211
  g = self.proj_glm(glm_emb) # (B, Lg, hidden_dim)
212
+ g = self.dropout_g1(self.act(g))
213
  if self.use_local_cnn:
214
  g = self.local_cnn(g)
215
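Every sub-layer in the updated CrossModalBlock now follows the same overfit-prevention pattern: dropout is applied to the sub-layer output before the residual add and LayerNorm, with an extra Dropout inside each feed-forward network after the GELU. A standalone sketch of that pattern (dimensions and names are illustrative, not the project's API):

import torch
from torch import nn

class ResidualDropoutFFN(nn.Module):
    """Post-norm residual wrapper: x -> LayerNorm(x + Dropout(FFN(x)))."""
    def __init__(self, dim: int = 256, dropout: float = 0.1):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(),
            nn.Dropout(dropout),                 # dropout inside the FFN
            nn.Linear(dim * 4, dim),
        )
        self.drop = nn.Dropout(dropout)          # dropout on the sub-layer output
        self.ln = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.ln(x + self.drop(self.ffn(x)))

x = torch.randn(2, 10, 256)
print(ResidualDropoutFFN()(x).shape)  # torch.Size([2, 10, 256])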
 
dpacman/classifier/old_model.py ADDED
@@ -0,0 +1,374 @@
1
+ """
2
+ Lightning Module for the binding model.
3
+ """
4
+
5
+ import torch
6
+ from torch import nn
7
+ from lightning import LightningModule
8
+ from dpacman.utils.models import set_seed
9
+ from .loss import calculate_loss, auprc_zeros_vs_ones_from_logits, auroc_zeros_vs_ones_from_logits
10
+
11
+ set_seed()
12
+
13
+ class LocalCNN(nn.Module):
14
+ def __init__(self, dim: int = 256, kernel_size: int = 3):
15
+ super().__init__()
16
+ padding = kernel_size // 2
17
+ self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding)
18
+ self.act = nn.GELU()
19
+ self.ln = nn.LayerNorm(dim)
20
+
21
+ def forward(self, x: torch.Tensor):
22
+ # x: (batch, L, dim)
23
+ out = self.conv(x.transpose(1, 2)) # → (batch, dim, L)
24
+ out = self.act(out)
25
+ out = out.transpose(1, 2) # → (batch, L, dim)
26
+ return self.ln(out + x) # residual
27
+
28
+
29
+ class CrossModalBlock(nn.Module):
30
+ def __init__(self, dim: int = 256, heads: int = 8, dropout: float = 0.0):
31
+ super().__init__()
32
+ # self-attention for both sides
33
+ self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True)
34
+ self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True)
35
+ # first layer norms
36
+ self.ln_b1 = nn.LayerNorm(dim)
37
+ self.ln_g1 = nn.LayerNorm(dim)
38
+ # first feed forward networks
39
+ self.ffn_b1 = nn.Sequential(
40
+ nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
41
+ )
42
+ self.ffn_g1 = nn.Sequential(
43
+ nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
44
+ )
45
+ self.ln_b2 = nn.LayerNorm(dim)
46
+ self.ln_g2 = nn.LayerNorm(dim)
47
+
48
+ # 2) reciprocal cross-attn: g<-b and b<-g
49
+ # DNA/GLM updated by attending to Binder
50
+ self.cross_g2b_1_RCA = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
51
+ self.ln_g3_RCA = nn.LayerNorm(dim)
52
+ self.ffn_g2_RCA = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
53
+ self.ln_g4_RCA = nn.LayerNorm(dim)
54
+
55
+ # Binder updated by attending to DNA/GLM
56
+ self.cross_b2g_1_RCA = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
57
+ self.ln_b3_RCA = nn.LayerNorm(dim)
58
+ self.ffn_b2_RCA = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
59
+ self.ln_b4_RCA = nn.LayerNorm(dim)
60
+
61
+ # cross attention (binder queries, glm keys/values)
62
+ # so the NDA path is updated by the transcriptoin factors
63
+ self.cross_g2b_2 = nn.MultiheadAttention(dim, heads, batch_first=True)
64
+ self.ln_g5 = nn.LayerNorm(dim)
65
+ self.ffn_g3 = nn.Sequential(
66
+ nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
67
+ )
68
+ self.ln_g6 = nn.LayerNorm(dim)
69
+
70
+ def forward(self, binder: torch.Tensor, glm: torch.Tensor, binder_kpm_mask=None, glm_kpm_mask=None):
71
+ """
72
+ binder: (batch, Lb, dim)
73
+ glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
74
+ returns: updated binder representation (batch, Lb, dim)
75
+ """
76
+ # 1) Self-attentino and feed-forward networks for binder and DNA
77
+ # binder: self-attn + ffn
78
+ b = binder
79
+ b_sa, _ = self.sa_binder(b, b, b, key_padding_mask=binder_kpm_mask)
80
+ b = self.ln_b1(b + b_sa)
81
+ b_ff = self.ffn_b1(b)
82
+ b = self.ln_b2(b + b_ff)
83
+
84
+ # glm: self-attn + ffn
85
+ g = glm
86
+ g_sa, _ = self.sa_glm(g, g, g, key_padding_mask=glm_kpm_mask)
87
+ g = self.ln_g1(g + g_sa)
88
+ g_ff = self.ffn_g1(g)
89
+ g = self.ln_g2(g + g_ff)
90
+
91
+ # 2a) Reciprocal Cross-Attention:
92
+ # DNA updated by attending to Binder (Q=g, K=b, V=b)
93
+ # Binder updated by attending to DNA (Q=b, K=g, V=g)
94
+ g_ca, _ = self.cross_g2b_1_RCA(
95
+ g, b, b, key_padding_mask=binder_kpm_mask
96
+ # torch MultiheadAttention expects key_padding_mask=True for PADs;
97
+ # invert if your mask is True=keep:
98
+ # key_padding_mask=(~binder_mask.bool()) if binder_mask is not None else None
99
+ )
100
+ g = self.ln_g3_RCA(g + g_ca)
101
+ g = self.ln_g4_RCA(g + self.ffn_g2_RCA(g))
102
+
103
+ # 2b) Binder updated by attending to DNA/GLM (Q=b, K=g, V=g)
104
+ b_ca, _ = self.cross_b2g_1_RCA(
105
+ b, g, g, key_padding_mask=glm_kpm_mask
106
+ # key_padding_mask=(~glm_mask.bool()) if glm_mask is not None else None
107
+ )
108
+ b = self.ln_b3_RCA(b + b_ca)
109
+ b = self.ln_b4_RCA(b + self.ffn_b2_RCA(b))
110
+
111
+ # cross-attention: glm queries binder and glm embeddings are updated
112
+ g_to_b_ca, _ = self.cross_g2b_2(g, b, b, key_padding_mask=binder_kpm_mask)
113
+ g = self.ln_g5(g + g_to_b_ca)
114
+ g_ff = self.ffn_g3(g)
115
+ g = self.ln_g6(g + g_ff)
116
+ return b, g # (batch, Lb, dim)
117
+
118
+ class DimCompressor(nn.Module):
119
+ """
120
+ Learnable per-token compressor: maps any in_dim >= out_dim to out_dim (default 256).
121
+ If in_dim == out_dim, behaves as identity.
122
+ """
123
+
124
+ def __init__(self, in_dim: int, out_dim: int = 256):
125
+ super().__init__()
126
+ if in_dim == out_dim:
127
+ self.net = nn.Identity()
128
+ else:
129
+ hidden = max(out_dim * 2, (in_dim + out_dim) // 2)
130
+ self.net = nn.Sequential(
131
+ nn.LayerNorm(in_dim),
132
+ nn.Linear(in_dim, hidden),
133
+ nn.GELU(),
134
+ nn.Linear(hidden, out_dim),
135
+ )
136
+
137
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
138
+ # x: (B, L, in_dim)
139
+ return self.net(x)
140
+
141
+
142
+ class BindPredictor(LightningModule):
143
+ def __init__(
144
+ self,
145
+ # input_dim: int = 256, # OLD: single input dim
146
+ binder_input_dim: int = 1280, # NEW: TF (binder) original dim (e.g., 1280)
147
+ glm_input_dim: int = 256, # NEW: DNA/GLM original dim (e.g., 256)
148
+ compressed_dim: int = 256, # NEW: learnable compressed dim
149
+ hidden_dim: int = 256,
150
+ heads: int = 8,
151
+ num_layers: int = 4,
152
+ lr: float = 1e-4,
153
+ alpha: float = 20,
154
+ gamma: float = 20,
155
+ dropout: float = 0,
156
+ use_local_cnn_on_glm: bool = True,
157
+ weight_decay: float = 0.01,
158
+ loss_type = "mixed"
159
+ ):
160
+ # Init
161
+ super(BindPredictor, self).__init__()
162
+ self.save_hyperparameters()
163
+
164
+ # Learnable compressor for binder -> 256, then project to hidden
165
+ self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim)
166
+ self.proj_binder = nn.Linear(compressed_dim, hidden_dim)
167
+
168
+ # GLM side stays 256 -> hidden
169
+ self.proj_glm = nn.Linear(glm_input_dim, hidden_dim)
170
+
171
+ self.use_local_cnn = use_local_cnn_on_glm
172
+ self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity()
173
+
174
+ self.layers = nn.ModuleList(
175
+ [CrossModalBlock(hidden_dim, heads, self.hparams.dropout) for _ in range(num_layers)]
176
+ )
177
+
178
+ #self.ln_out = nn.LayerNorm(hidden_dim)
179
+ # self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid()) # OLD: returned probabilities
180
+ self.head = nn.Linear(hidden_dim, 1) # NEW: return logits (safe for AMP)
181
+
182
+ def forward(self, binder_emb, glm_emb, binder_mask, glm_mask):
183
+ """
184
+ binder_emb: (B, Lb, binder_input_dim)
185
+ glm_emb: (B, Lg, glm_input_dim)
186
+ Returns per-nucleotide logits for the GLM sequence: (B, Lg)
187
+ """
188
+ # Binder: learnable compression → 256 → hidden
189
+ b = self.binder_compress(binder_emb) # (B, Lb, 256)
190
+ b = self.proj_binder(b) # (B, Lb, hidden_dim)
191
+
192
+ # GLM: project → hidden, add local CNN context
193
+ g = self.proj_glm(glm_emb) # (B, Lg, hidden_dim)
194
+ if self.use_local_cnn:
195
+ g = self.local_cnn(g)
196
+
197
+ # Cross-modal blocks: update binder states using GLM
198
+ for layer in self.layers:
199
+ b, g = layer(b, g, binder_mask, glm_mask) # (B, Lb, hidden_dim)
200
+
201
+ # Predict per-nucleotide logits on the GLM tokens:
202
+ # return self.head(g).squeeze(-1) # OLD: probabilities (with Sigmoid in head)
203
+ logits = self.head(g).squeeze(
204
+ -1
205
+ )
206
+ return logits
207
+
208
+ # ----- Lightning hooks -----
209
+ def training_step(self, batch, batch_idx):
210
+ """
211
+ Training step taken by PyTorch-Lightning trainer. Uses batch returned by data collator.
212
+ Colator returns a dictionary with:
213
+ "binder_emb" # [B, Lb_max, Db]
214
+ "binder_kpm" # [B, Lb_max]
215
+ "glm_emb" # [B, Lg_max, Dg]
216
+ "glm_kpm" # [B, Lg_max]
217
+ "labels" # [B, Lg_max]
218
+ "ID"
219
+ "tr_sequence"
220
+ "dna_sequence"
221
+ }
222
+ """
223
+ logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
224
+ loss = calculate_loss(
225
+ logits, batch["labels"], batch["binder_kpm"], batch["glm_kpm"], alpha=self.hparams.alpha, gamma=self.hparams.gamma, loss_type=self.hparams.loss_type
226
+ )
227
+ self.log(
228
+ "train/loss",
229
+ loss,
230
+ on_step=True,
231
+ on_epoch=True,
232
+ prog_bar=True,
233
+ batch_size=logits.size(0),
234
+ )
235
+
236
+ # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
237
+ ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
238
+ logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
239
+ )
240
+ auc, n_pos, n_neg, tpr, fpr, thresolds, tp, fp = auroc_zeros_vs_ones_from_logits(
241
+ logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
242
+ )
243
+ # per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
244
+ self.log("train/auprc_0v1",
245
+ ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
246
+ on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
247
+ self.log("train/auroc_0v1",
248
+ auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
249
+ on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
250
+ # (optional) also log class counts so you can sanity-check balance
251
+ self.log("train/n_pos_0v1", float(n_pos), on_step=False, on_epoch=True, sync_dist=True)
252
+ self.log("train/n_neg_0v1", float(n_neg), on_step=False, on_epoch=True, sync_dist=True)
253
+
254
+ return loss
255
+
256
+ def validation_step(self, batch, batch_idx):
257
+ logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
258
+ loss = calculate_loss(
259
+ logits, batch["labels"], batch["binder_kpm"], batch["glm_kpm"], alpha=self.hparams.alpha, gamma=self.hparams.gamma, loss_type=self.hparams.loss_type
260
+ )
261
+ self.log(
262
+ "val/loss",
263
+ loss,
264
+ on_step=False,
265
+ on_epoch=True,
266
+ prog_bar=True,
267
+ batch_size=logits.size(0),
268
+ )
269
+
270
+ # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
271
+ ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
272
+ logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
273
+ )
274
+ auc, n_pos, n_neg, tpr, fpr, thresolds, tp, fp = auroc_zeros_vs_ones_from_logits(
275
+ logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
276
+ )
277
+ # per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
278
+ self.log("val/auprc_0v1",
279
+ ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
280
+ on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
281
+ self.log("val/auroc_0v1",
282
+ auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
283
+ on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
284
+
285
+ return loss
286
+
287
+ def test_step(self, batch, batch_idx):
288
+ logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
289
+ loss = calculate_loss(
290
+ logits, batch["labels"], batch["binder_kpm"], batch["glm_kpm"], alpha=self.hparams.alpha, gamma=self.hparams.gamma, loss_type=self.hparams.loss_type
291
+ )
292
+ self.log(
293
+ "test/loss", loss, on_step=False, on_epoch=True, batch_size=logits.size(0)
294
+ )
295
+
296
+ # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
297
+ ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
298
+ logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
299
+ )
300
+ auc, n_pos, n_neg, tpr, fpr, thresolds, tp, fp = auroc_zeros_vs_ones_from_logits(
301
+ logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
302
+ )
303
+ # per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
304
+ self.log("test/auprc_0v1",
305
+ ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
306
+ on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
307
+ self.log("test/auroc_0v1",
308
+ auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
309
+ on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
310
+
311
+ return loss
312
+
313
+ def predict_step(self, batch, batch_idx, dataloader_idx=0):
314
+ logits = self.forward(batch["binder_emb"], batch["glm_emb"],
315
+ batch["binder_kpm"], batch["glm_kpm"]).squeeze(-1) # (B,L)
316
+ valid = ~batch["glm_kpm"] # (B,L)
317
+ return {
318
+ "ids": batch["ID"], # list[str]
319
+ "logits": logits.detach().cpu(), # (B,Lmax) padded
320
+ "valid": valid.detach().cpu(), # (B,Lmax) booleans
321
+ "labels": batch["labels"].detach().cpu(), # (B,Lmax) padded
322
+ }
323
+
324
+ def on_before_optimizer_step(self, optimizer):
325
+ # Compute global L2 norm of all parameter gradients (ignores None grads)
326
+ grads = []
327
+ for p in self.parameters():
328
+ if p.grad is not None:
329
+ # .detach() avoids autograd tracking; .float() avoids fp16 overflow in norms
330
+ grads.append(p.grad.detach().float().norm(2))
331
+ if grads:
332
+ total_norm = torch.norm(torch.stack(grads), p=2)
333
+ self.log("train/grad_norm", total_norm, on_step=True, prog_bar=False, logger=True)
334
+
335
+ def on_after_backward(self):
336
+ grads = [p.grad.detach().float().norm(2)
337
+ for p in self.parameters() if p.grad is not None]
338
+ if grads:
339
+ total_norm = torch.norm(torch.stack(grads), p=2)
340
+ self.log("train/grad_norm_back", total_norm, on_step=True, prog_bar=False)
341
+
342
+ def on_train_epoch_end(self):
343
+ if False:
344
+ if self.train_auc.compute() is not None:
345
+ self.log("train/auroc", self.train_auc.compute(), prog_bar=True)
346
+ self.train_auc.reset()
347
+
348
+ def on_validation_epoch_end(self):
349
+ if False:
350
+ if self.val_auc.compute() is not None:
351
+ self.log("val/auroc", self.val_auc.compute(), prog_bar=True)
352
+ self.val_auc.reset()
353
+
354
+ def on_test_epoch_end(self):
355
+ if False:
356
+ if self.test_auc.compute() is not None:
357
+ self.log("test/auroc", self.test_auc.compute(), prog_bar=True)
358
+ self.test_auc.reset()
359
+
360
+ def configure_optimizers(self):
361
+ # AdamW + cosine as a sensible default
362
+ opt = torch.optim.AdamW(
363
+ self.parameters(),
364
+ lr=self.hparams.lr,
365
+ weight_decay=self.hparams.weight_decay,
366
+ )
367
+ # Scheduler optional—comment out if you prefer fixed LR
368
+ sch = torch.optim.lr_scheduler.CosineAnnealingLR(
369
+ opt, T_max=max(self.trainer.max_epochs, 1)
370
+ )
371
+ return {
372
+ "optimizer": opt,
373
+ "lr_scheduler": {"scheduler": sch, "interval": "epoch"},
374
+ }
dpacman/data_modules/pair.py CHANGED
@@ -180,16 +180,16 @@ class PairDataset(Dataset):
180
  """
181
  if self.score_col not in dataset.columns:
182
  logger.info(f"Scores not provided. Adding placeholder scores where all positions are considered binding")
183
- dataset[self.score_col] = dataset["dna_sequence"].str.len()
184
  dataset[self.score_col] = dataset[self.score_col].apply(lambda x: ",".join([str(self.norm_value)]*x))
185
  self.fake_scores=True
186
  # split string into list of strings
187
  dataset[self.score_col] = dataset[self.score_col].apply(lambda x: x.split(","))
 
188
  # turn list of strings into list of normalized, rounded floats
189
  dataset[self.score_col] = dataset[self.score_col].apply(
190
  lambda x: [round(int(y) / self.norm_value, self.round_to) for y in x]
191
  )
192
-
193
  # convert to records for ease of loading
194
  dataset = dataset.to_dict(orient="records")
195
  return dataset
@@ -222,11 +222,13 @@ class PairDataModule(LightningDataModule):
222
  shuffle_train_batch_order: bool = True,
223
  score_col: str = "scores",
224
  target_col: str = "dna_sequence",
225
- binder_col: str = "tr_sequence"
 
226
  ):
227
  super().__init__()
228
  self.save_hyperparameters()
229
  self.debug_run = debug_run
 
230
 
231
  # Initialize the data files
232
  self.train_data_file = train_file
@@ -267,7 +269,7 @@ class PairDataModule(LightningDataModule):
267
  df = pd.read_csv(file_path)
268
  if lim is not None:
269
  df = df[:lim].reset_index(drop=True)
270
- return df[["ID", "dna_sequence", "tr_sequence", "scores"]]
271
  except:
272
  raise Exception(f"{file_path} is not a valid file")
273
 
@@ -278,7 +280,7 @@ class PairDataModule(LightningDataModule):
278
  if stage in (None, "fit"):
279
  if not hasattr(self, "train_dataset"):
280
  train_df = self.load_file(self.train_data_file, lim=lim)
281
- self.train_dataset = PairDataset(train_df, score_col = self.score_col, target_col = self.target_col, binder_col = self.binder_col)
282
  self.train_batches = make_length_batches(
283
  dataset_records=self.train_dataset.dataset,
284
  tr_shelf_path=str(self.hparams.tr_shelf_path),
@@ -294,7 +296,7 @@ class PairDataModule(LightningDataModule):
294
 
295
  if not hasattr(self, "val_dataset"):
296
  val_df = self.load_file(self.val_data_file, lim=lim)
297
- self.val_dataset = PairDataset(val_df, score_col = self.score_col, target_col = self.target_col, binder_col = self.binder_col)
298
  self.val_batches = make_length_batches(
299
  dataset_records=self.val_dataset.dataset,
300
  tr_shelf_path=str(self.hparams.tr_shelf_path),
@@ -309,7 +311,7 @@ class PairDataModule(LightningDataModule):
309
  if stage in (None, "validate"):
310
  if not hasattr(self, "val_dataset"):
311
  val_df = self.load_file(self.val_data_file, lim=lim)
312
- self.val_dataset = PairDataset(val_df, score_col = self.score_col, target_col = self.target_col, binder_col = self.binder_col)
313
  self.val_batches = make_length_batches(
314
  dataset_records=self.val_dataset.dataset,
315
  tr_shelf_path=str(self.hparams.tr_shelf_path),
@@ -324,7 +326,7 @@ class PairDataModule(LightningDataModule):
324
  if stage in (None, "test"):
325
  if not hasattr(self, "test_dataset"):
326
  test_df = self.load_file(self.test_data_file, lim=lim)
327
- self.test_dataset = PairDataset(test_df, score_col = self.score_col, target_col = self.target_col, binder_col = self.binder_col)
328
  self.test_batches = make_length_batches(
329
  dataset_records=self.test_dataset.dataset,
330
  tr_shelf_path=str(self.hparams.tr_shelf_path),
@@ -623,6 +625,8 @@ def main():
623
  debug_run=args.debug_run,
624
  shuffle_train_batch_order=args.shuffle_train_batch_order,
625
  pin_memory=False,
 
 
626
  )
627
 
628
  # ---- Train ----
 
180
  """
181
  if self.score_col not in dataset.columns:
182
  logger.info(f"Scores not provided. Adding placeholder scores where all positions are considered binding")
183
+ dataset[self.score_col] = dataset[self.target_col].str.len()
184
  dataset[self.score_col] = dataset[self.score_col].apply(lambda x: ",".join([str(self.norm_value)]*x))
185
  self.fake_scores=True
186
  # split string into list of strings
187
  dataset[self.score_col] = dataset[self.score_col].apply(lambda x: x.split(","))
188
+ dataset["copycol"] = dataset[self.score_col]
189
  # turn list of strings into list of normalized, rounded floats
190
  dataset[self.score_col] = dataset[self.score_col].apply(
191
  lambda x: [round(int(y) / self.norm_value, self.round_to) for y in x]
192
  )
 
193
  # convert to records for ease of loading
194
  dataset = dataset.to_dict(orient="records")
195
  return dataset
 
222
  shuffle_train_batch_order: bool = True,
223
  score_col: str = "scores",
224
  target_col: str = "dna_sequence",
225
+ binder_col: str = "tr_sequence",
226
+ norm_value: int = 1333
227
  ):
228
  super().__init__()
229
  self.save_hyperparameters()
230
  self.debug_run = debug_run
231
+ self.norm_value = norm_value
232
 
233
  # Initialize the data files
234
  self.train_data_file = train_file
 
269
  df = pd.read_csv(file_path)
270
  if lim is not None:
271
  df = df[:lim].reset_index(drop=True)
272
+ return df
273
  except:
274
  raise Exception(f"{file_path} is not a valid file")
275
 
 
280
  if stage in (None, "fit"):
281
  if not hasattr(self, "train_dataset"):
282
  train_df = self.load_file(self.train_data_file, lim=lim)
283
+ self.train_dataset = PairDataset(train_df, norm_value = self.norm_value, score_col = self.score_col, target_col = self.target_col, binder_col = self.binder_col)
284
  self.train_batches = make_length_batches(
285
  dataset_records=self.train_dataset.dataset,
286
  tr_shelf_path=str(self.hparams.tr_shelf_path),
 
296
 
297
  if not hasattr(self, "val_dataset"):
298
  val_df = self.load_file(self.val_data_file, lim=lim)
299
+ self.val_dataset = PairDataset(val_df, norm_value = self.norm_value, score_col = self.score_col, target_col = self.target_col, binder_col = self.binder_col)
300
  self.val_batches = make_length_batches(
301
  dataset_records=self.val_dataset.dataset,
302
  tr_shelf_path=str(self.hparams.tr_shelf_path),
 
311
  if stage in (None, "validate"):
312
  if not hasattr(self, "val_dataset"):
313
  val_df = self.load_file(self.val_data_file, lim=lim)
314
+ self.val_dataset = PairDataset(val_df, norm_value = self.norm_value, score_col = self.score_col, target_col = self.target_col, binder_col = self.binder_col)
315
  self.val_batches = make_length_batches(
316
  dataset_records=self.val_dataset.dataset,
317
  tr_shelf_path=str(self.hparams.tr_shelf_path),
 
326
  if stage in (None, "test"):
327
  if not hasattr(self, "test_dataset"):
328
  test_df = self.load_file(self.test_data_file, lim=lim)
329
+ self.test_dataset = PairDataset(test_df, norm_value = self.norm_value, score_col = self.score_col, target_col = self.target_col, binder_col = self.binder_col)
330
  self.test_batches = make_length_batches(
331
  dataset_records=self.test_dataset.dataset,
332
  tr_shelf_path=str(self.hparams.tr_shelf_path),
 
625
  debug_run=args.debug_run,
626
  shuffle_train_batch_order=args.shuffle_train_batch_order,
627
  pin_memory=False,
628
+ score_col="binary_scores",
629
+ norm_value=1
630
  )
631
 
632
  # ---- Train ----
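The main() change points the data module at a binary score column with norm_value=1, so the same division step leaves 0/1 labels untouched, whereas the continuous path divides by the dataset maximum (1333 in pair.yaml). A hedged sketch of the two label modes; the raw string is made up, and the binarization mirrors convert_scores from complex_remap.py added below:

raw = "0,500,1333,1333"                     # hypothetical per-nucleotide scores

# Continuous labels: divide by the dataset maximum (norm_value=1333).
continuous = [round(int(s) / 1333, 3) for s in raw.split(",")]

# Binary labels: keep only the max-scoring positions, then use norm_value=1.
max_s = max(int(s) for s in raw.split(","))
binary = [1 if int(s) == max_s else 0 for s in raw.split(",")]

print(continuous)  # [0.0, 0.375, 1.0, 1.0]
print(binary)      # [0, 0, 1, 1]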
dpacman/data_tasks/split/complex_remap.py ADDED
@@ -0,0 +1,736 @@
1
+ from collections import Counter, defaultdict
2
+ from ortools.linear_solver import pywraplp
3
+ import random
4
+ from omegaconf import DictConfig
5
+ import pandas as pd
6
+ from pathlib import Path
7
+ import os
8
+ import numpy as np
9
+ from sklearn.model_selection import train_test_split
10
+ from dpacman.data_tasks.fimo.post_fimo import get_reverse_complement
11
+ import json
12
+ import rootutils
13
+ from dpacman.utils import pylogger
14
+
15
+ root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
16
+ logger = pylogger.RankedLogger(__name__, rank_zero_only=True)
17
+
18
+ def split_with_predefined_test(
19
+ full_df = pd.DataFrame(),
20
+ split_names=("train", "val", "test"),
21
+ test_trs=None,
22
+ test_dnas=None,
23
+ ratios=(0.8, 0.1, 0.1),
24
+ ):
25
+ """
26
+ Method for splitting into train and val with a predefined test set.
27
+ The proteins in the test set, and the DNA clusters of the DNAs they're associated with, must be excluded from train and val.
28
+ The remaining rows for train and val are split to preserve 80/10/10 as best as possible.
29
+ """
31
+ test = full_df.copy(deep=True)
32
+ if test_trs is not None:
33
+ test = test.loc[test["tr_seqid"].isin(test_trs)].reset_index(drop=True)
34
+ if test_dnas is not None:
35
+ test = test.loc[test["dna_seqid"].isin(test_dnas)].reset_index(drop=True)
36
+
37
+ tr_clusters_to_exclude = test["tr_cluster_rep"].unique().tolist()
38
+ dna_clusters_to_exclude = test["dna_cluster_rep"].unique().tolist()
39
+
40
+ remaining = full_df.loc[
41
+ (~full_df["tr_cluster_rep"].isin(tr_clusters_to_exclude)) &
42
+ (~full_df["dna_cluster_rep"].isin(dna_clusters_to_exclude))
43
+ ].reset_index(drop=True)
44
+
45
+ test_ids = test["ID"].unique().tolist()
46
+ remaining_ids = remaining["ID"].unique().tolist()
47
+ remaining_clusters = remaining["dna_cluster_rep"].unique().tolist()
48
+ lost_rows = full_df.loc[
49
+ (~full_df["ID"].isin(test_ids)) &
50
+ (~full_df["ID"].isin(remaining_ids))
51
+ ]
52
+
53
+ logger.info(f"Rows in test: {len(test)}")
54
+ logger.info(f"Rows to be split between train and val: {len(remaining)}")
55
+ total_rows = len(test) + len(remaining)
56
+ logger.info(f"Total rows: {total_rows}. Test percentage: {100*len(test)/total_rows:.2f}%")
57
+ logger.info(f"Lost rows: {len(lost_rows)}")
58
+
59
+ train_ratio_from_remaining = round((0.8*total_rows)/len(remaining), 2)
60
+ # use sklearn
61
+ test_size_1 = 1 - train_ratio_from_remaining
62
+ logger.info(
63
+ f"\tPerforming first split: non-test clusters -> train clusters ({round(1-test_size_1,3)}) and val ({test_size_1})"
64
+ )
65
+ X = remaining_clusters
66
+ y = [0] * len(remaining_clusters)
67
+ X_train, X_val, y_train, y_val = train_test_split(
68
+ X, y, test_size=test_size_1, random_state=0
69
+ )
70
+
71
+ train = remaining.loc[remaining["dna_cluster_rep"].isin(X_train)]
72
+ val = remaining.loc[remaining["dna_cluster_rep"].isin(X_val)]
73
+ leaky_test = lost_rows
74
+
75
+ splits = {
76
+ "train": train,
77
+ "val": val,
78
+ "test": test,
79
+ "leaky_test": leaky_test
80
+ }
81
+ return splits
82
+
83
+ def split_bipartite_fast(
84
+ dna_clusters,
85
+ split_names=("train", "val", "test"),
86
+ ratios=(0.8, 0.1, 0.1),
87
+ ):
88
+ # use sklearn
89
+ test_size_1 = 0.2
90
+ test_size_2 = 0.5
91
+ logger.info(
92
+ f"\tPerforming first split: all clusters -> train clusters ({round(1-test_size_1,3)}) and other ({test_size_1})"
93
+ )
94
+ X = dna_clusters
95
+ y = [0] * len(dna_clusters)
96
+ X_train, X_test, y_train, y_test = train_test_split(
97
+ X, y, test_size=test_size_1, random_state=0
98
+ )
99
+ logger.info(
100
+ f"\tPerforming second split: other -> val clusters ({round(1-test_size_2,3)}) and test clusters ({test_size_2})"
101
+ )
102
+ X_val, X_test, y_val, y_test = train_test_split(
103
+ X_test, y_test, test_size=test_size_2, random_state=0
104
+ )
105
+
106
+ dna_assign = {}
107
+ for x in X_train:
108
+ dna_assign[x] = "train"
109
+ for x in X_val:
110
+ dna_assign[x] = "val"
111
+ for x in X_test:
112
+ dna_assign[x] = "test"
113
+
114
+ kept_by_split = {"train": len(X_train), "val": len(X_val), "test": len(X_test)}
115
+ return dna_assign, kept_by_split
116
+
117
+ def convert_scores(scores):
118
+ svec = [int(x) for x in scores.split(",")]
119
+ max_score = max(svec)
120
+ binary_svec = [0 if x<max_score else 1 for x in svec]
121
+ assert(svec.count(max_score)==binary_svec.count(1))
122
+ binary_svec = ",".join([str(x) for x in binary_svec])
123
+ return binary_svec
124
+
125
+ def split_bipartite_with_ratios_and_leaky(
126
+ edges,
127
+ split_names=("train", "val", "test"),
128
+ ratios=(0.8, 0.1, 0.1),
129
+ require_nonempty=False,
130
+ ratio_tolerance=None, # None = soft ratios only; 0.0 = exact band (use with care)
131
+ bigM=None,
132
+ shuffle_within_pair=False,
133
+ seed=0,
134
+ test_edges_must=None, # NEW: list of (tf,dna) with duplicates OR dict {(tf,dna): count}
135
+ ):
136
+ """
137
+ edges: list of (tf_cluster_id, dna_cluster_id). Duplicates allowed (-> weights).
138
+ test_edges_must: None, list of pairs, or dict {(tf,dna): required_count}.
139
+ - If a pair appears with required_count > 0, at least that many examples MUST be kept in TEST.
140
+ - This implicitly pins both clusters of that pair to TEST (cluster exclusivity).
141
+
142
+ Returns:
143
+ tf_assign: {tf_cluster -> split}
144
+ dna_assign: {dna_cluster -> split}
145
+ kept_by_split: {split -> kept_count} (train/val/test only)
146
+ total_kept: int
147
+ split_to_indices: {split -> [input indices]} including 'leaky_test'
148
+ split_to_edges: {split -> [(tf,dna), ...]} including 'leaky_test'
149
+ """
150
+ # Aggregate counts per pair
151
+ w = Counter(edges)
152
+ tfs = {t for (t, _) in w}
153
+ dnas = {d for (_, d) in w}
154
+ S = list(split_names)
155
+ rs = dict(zip(S, ratios))
156
+ N = sum(w.values())
157
+ if bigM is None:
158
+ bigM = 1000 * max(1, N)
159
+
160
+ # Index original edges so we can return a per-example split
161
+ pair_to_indices = defaultdict(list)
162
+ for idx, (c, d) in enumerate(edges):
163
+ pair_to_indices[(c, d)].append(idx)
164
+
165
+ if shuffle_within_pair:
166
+ rng = random.Random(seed)
167
+ for key in pair_to_indices:
168
+ rng.shuffle(pair_to_indices[key])
169
+
170
+ # Parse required test edges
171
+ req_test = Counter()
172
+ if test_edges_must:
173
+ if isinstance(test_edges_must, dict):
174
+ for k, v in test_edges_must.items():
175
+ if not isinstance(k, tuple) or len(k) != 2:
176
+ raise ValueError(
177
+ "test_edges_must dict keys must be (tf_cluster, dna_cluster)"
178
+ )
179
+ if v < 0:
180
+ raise ValueError("required_count must be non-negative")
181
+ if v:
182
+ req_test[k] += int(v)
183
+ else:
184
+ # assume iterable of pairs
185
+ req_test = Counter(test_edges_must)
186
+ # Validate against available counts
187
+ for pair, req in req_test.items():
188
+ if pair not in w:
189
+ raise ValueError(f"Required test pair {pair} not present in edges.")
190
+ if req > w[pair]:
191
+ raise ValueError(
192
+ f"Required count {req} for {pair} exceeds available {w[pair]}."
193
+ )
194
+
195
+ # Build solver
196
+ solver = pywraplp.Solver.CreateSolver("CBC")
197
+ if solver is None:
198
+ raise RuntimeError("Could not create CBC solver.")
199
+
200
+ # Binary cluster assignments
201
+ x = {(c, s): solver.BoolVar(f"x[{c},{s}]") for c in tfs for s in S}
202
+ y = {(d, s): solver.BoolVar(f"y[{d},{s}]") for d in dnas for s in S}
203
+
204
+ # Each cluster in exactly one split
205
+ for c in tfs:
206
+ solver.Add(sum(x[c, s] for s in S) == 1)
207
+ for d in dnas:
208
+ solver.Add(sum(y[d, s] for s in S) == 1)
209
+
210
+ # Integer kept counts per pair and split (allow partial within-pair)
211
+ k = {
212
+ ((c, d), s): solver.IntVar(0, w[(c, d)], f"k[{c},{d},{s}]")
213
+ for (c, d) in w
214
+ for s in S
215
+ }
216
+
217
+ # Only keep in split s if both endpoint clusters are assigned to s
218
+ for (c, d), wt in w.items():
219
+ for s in S:
220
+ solver.Add(k[((c, d), s)] <= wt * x[c, s])
221
+ solver.Add(k[((c, d), s)] <= wt * y[d, s])
222
+
223
+ # Enforce minimum kept counts in TEST for required pairs
224
+ for (c, d), req in req_test.items():
225
+ solver.Add(k[((c, d), "test")] >= req)
226
+
227
+ # Optional: ensure each split has at least one cluster (feasibility depends on counts)
228
+ if require_nonempty:
229
+ for s in S:
230
+ solver.Add(sum(x[c, s] for c in tfs) + sum(y[d, s] for d in dnas) >= 1)
231
+
232
+ # Kept counts per split and total
233
+ K = {s: solver.IntVar(0, N, f"K[{s}]") for s in S}
234
+ for s in S:
235
+ solver.Add(K[s] == sum(k[((c, d), s)] for (c, d) in w))
236
+ T = solver.IntVar(0, N, "T")
237
+ solver.Add(T == sum(K[s] for s in S))
238
+
239
+ # Ratio deviation: K_s - r_s * T = d+ - d-
240
+ dpos = {s: solver.NumVar(0, solver.infinity(), f"dpos[{s}]") for s in S}
241
+ dneg = {s: solver.NumVar(0, solver.infinity(), f"dneg[{s}]") for s in S}
242
+ for s in S:
243
+ solver.Add(K[s] - rs[s] * T == dpos[s] - dneg[s])
244
+
245
+ # Optional hard band around target ratios
246
+ if ratio_tolerance is not None:
247
+ eps = float(ratio_tolerance)
248
+ for s in S:
249
+ solver.Add(K[s] >= (rs[s] - eps) * T)
250
+ solver.Add(K[s] <= (rs[s] + eps) * T)
251
+
252
+ # Objective: maximize T then minimize total deviation
253
+ obj = solver.Objective()
254
+ obj.SetMaximization()
255
+ obj.SetCoefficient(T, float(bigM))
256
+ for s in S:
257
+ obj.SetCoefficient(dpos[s], -1.0)
258
+ obj.SetCoefficient(dneg[s], -1.0)
259
+
260
+ status = solver.Solve()
261
+ if status not in (pywraplp.Solver.OPTIMAL, pywraplp.Solver.FEASIBLE):
262
+ raise RuntimeError(
263
+ "No feasible solution (check ratio_tolerance vs. required test edges)."
264
+ )
265
+
266
+ # Read cluster assignments
267
+ tf_assign = {c: next(s for s in S if x[c, s].solution_value() > 0.5) for c in tfs}
268
+ dna_assign = {d: next(s for s in S if y[d, s].solution_value() > 0.5) for d in dnas}
269
+
270
+ # Kept counts per split
271
+ kept_by_split = {s: int(round(K[s].solution_value())) for s in S}
272
+ total_kept = int(round(T.solution_value()))
273
+
274
+ # ---- Build per-example split assignment (including 'leaky_test') ----
275
+ split_to_indices = {s: [] for s in S}
276
+ remaining_indices = {pair: list(pair_to_indices[pair]) for pair in pair_to_indices}
277
+
278
+ # Allocate the kept examples per split (train/val/test)
279
+ for (c, d), wt in w.items():
280
+ for s in S:
281
+ cnt = int(round(k[((c, d), s)].solution_value()))
282
+ if cnt > 0:
283
+ take = remaining_indices[(c, d)][:cnt]
284
+ split_to_indices[s].extend(take)
285
+ remaining_indices[(c, d)] = remaining_indices[(c, d)][cnt:]
286
+
287
+ # Everything left becomes leaky_test
288
+ leaky_indices = []
289
+ for pair, idxs in remaining_indices.items():
290
+ if idxs:
291
+ leaky_indices.extend(idxs)
292
+
293
+ split_to_indices["leaky_test"] = leaky_indices
294
+ split_to_edges = {
295
+ s: [edges[i] for i in split_to_indices[s]] for s in split_to_indices
296
+ }
297
+
298
+ return (
299
+ tf_assign,
300
+ dna_assign,
301
+ kept_by_split,
302
+ total_kept,
303
+ split_to_indices,
304
+ split_to_edges,
305
+ )
306
+
307
+
308
+ class DSU:
309
+ def __init__(self):
310
+ self.p = {}
311
+
312
+ def find(self, x):
313
+ if x not in self.p:
314
+ self.p[x] = x
315
+ while self.p[x] != x:
316
+ self.p[x] = self.p[self.p[x]]
317
+ x = self.p[x]
318
+ return x
319
+
320
+ def union(self, a, b):
321
+ ra, rb = self.find(a), self.find(b)
322
+ if ra != rb:
323
+ self.p[rb] = ra
324
+
325
+
326
+ def split_bipartite_by_components(
327
+ edges,
328
+ split_names=("train", "val", "test"),
329
+ ratios=(0.8, 0.1, 0.1),
330
+ seed=0,
331
+ require_nonempty=False,
332
+ test_edges_must=None, # None, list[(tf,dna)], or dict{(tf,dna): count}
333
+ ):
334
+ """
335
+ Guarantees exclusivity: each TF cluster and DNA cluster appears in at most one split.
336
+ Strategy: find connected components in the TF–DNA bipartite graph and assign components wholesale.
337
+ """
338
+ rng = random.Random(seed)
339
+ w = Counter(edges) # multiplicities per pair
340
+ if not w:
341
+ raise ValueError("No edges.")
342
+
343
+ # 1) Build components with Union-Find (prefix to keep TF/DNA namespaces disjoint)
344
+ dsu = DSU()
345
+ for tf, dna in w:
346
+ dsu.union(("T", tf), ("D", dna))
347
+ comp_pairs = defaultdict(list)
348
+ comp_weight = defaultdict(int)
349
+ for (tf, dna), cnt in w.items():
350
+ root = dsu.find(("T", tf)) # component id = root of TF endpoint
351
+ comp_pairs[root].append((tf, dna))
352
+ comp_weight[root] += cnt
353
+
354
+ comps = list(comp_pairs.keys())
355
+ C = len(comps)
356
+ S = list(split_names)
357
+ rs = dict(zip(S, ratios))
358
+ N = sum(comp_weight[c] for c in comps)
359
+ target = {s: int(round(rs[s] * N)) for s in S}
360
+
361
+ # 2) Pin components that contain required TEST pairs
362
+ pinned = {} # comp_root -> pinned_split ("test")
363
+ if test_edges_must:
364
+ req = (
365
+ Counter(test_edges_must)
366
+ if not isinstance(test_edges_must, dict)
367
+ else Counter(test_edges_must)
368
+ )
369
+ # Map each required pair to its component, ensure feasibility
370
+ for (tf, dna), r in req.items():
371
+ if (tf, dna) not in w:
372
+ raise ValueError(f"Required pair {(tf,dna)} not present.")
373
+ if r > w[(tf, dna)]:
374
+ raise ValueError(
375
+ f"Required count {r} for {(tf,dna)} exceeds available {w[(tf,dna)]}."
376
+ )
377
+ comp = dsu.find(("T", tf))
378
+ if comp in pinned and pinned[comp] != "test":
379
+ raise ValueError(
380
+ f"Component conflict: already pinned to {pinned[comp]}, but {(tf,dna)} demands test."
381
+ )
382
+ pinned[comp] = "test"
383
+ # NOTE: pinning a pair pins the WHOLE component to test (to keep exclusivity).
384
+ # If you only want some edges kept in test and discard the rest, handle below when materializing.
385
+
386
+ # 3) Assign components greedily by deficit
387
+ kept_by_split = {s: 0 for s in S}
388
+ comp_assign = {} # comp_root -> split
389
+
390
+ # First assign pinned comps
391
+ for comp, split in pinned.items():
392
+ comp_assign[comp] = split
393
+ kept_by_split[split] += comp_weight[comp]
394
+
395
+ # Sort remaining components by descending weight
396
+ remaining = [c for c in comps if c not in comp_assign]
397
+ remaining.sort(key=lambda c: comp_weight[c], reverse=True)
398
+
399
+ # Ensure nonempty splits if requested (seed with largest remaining comps)
400
+ if require_nonempty:
401
+ seeds = remaining[: min(len(S), len(remaining))]
402
+ for comp, s in zip(seeds, S):
403
+ comp_assign[comp] = s
404
+ kept_by_split[s] += comp_weight[comp]
405
+ remaining = [c for c in remaining if c not in comp_assign]
406
+
407
+ for comp in remaining:
408
+ # choose split with largest deficit (target - current)
409
+ deficits = {s: target[s] - kept_by_split[s] for s in S}
410
+ best = max(deficits, key=lambda s: deficits[s])
411
+ comp_assign[comp] = best
412
+ kept_by_split[best] += comp_weight[comp]
413
+
414
+ total_kept = sum(kept_by_split.values())
415
+
416
+ # 4) Materialize per-example indices (and verify exclusivity)
417
+ pair_to_indices = defaultdict(list)
418
+ for idx, pair in enumerate(edges):
419
+ pair_to_indices[pair].append(idx)
420
+
421
+ split_to_indices = {s: [] for s in S}
422
+ for comp, s in comp_assign.items():
423
+ for pair in comp_pairs[comp]:
424
+ split_to_indices[s].extend(pair_to_indices[pair])
425
+
426
+ # Optional: if you pinned a comp due to a small 'must-test' count but
427
+ # want to *discard* the rest instead of keeping them in test, uncomment:
428
+ # for comp, s in comp_assign.items():
429
+ # if s == "test" and test_edges_must:
430
+ # # Keep only the required counts; dump extras to 'leaky_test'
431
+ # ...
432
+ # (Left out for clarity; default is: keep the whole component in its split.)
433
+
434
+ # 5) Build edge lists and simple cluster assignments
435
+ split_to_edges = {
436
+ s: [edges[i] for i in split_to_indices[s]] for s in split_to_indices
437
+ }
438
+ tf_assign, dna_assign = {}, {}
439
+ for comp, s in comp_assign.items():
440
+ for tf, dna in comp_pairs[comp]:
441
+ tf_assign[tf] = s
442
+ dna_assign[dna] = s
443
+
444
+ # 6) Safety check: no DNA/TF appears in multiple splits
445
+ tf_in_split = defaultdict(set)
446
+ dna_in_split = defaultdict(set)
447
+ for s, elist in split_to_edges.items():
448
+ for tf, dna in elist:
449
+ tf_in_split[tf].add(s)
450
+ dna_in_split[dna].add(s)
451
+ dup_tf = {tf: ss for tf, ss in tf_in_split.items() if len(ss) > 1}
452
+ dup_dna = {dn: ss for dn, ss in dna_in_split.items() if len(ss) > 1}
453
+ assert not dup_tf and not dup_dna, f"Exclusivity violated: {dup_tf} {dup_dna}"
454
+
455
+ return (
456
+ tf_assign,
457
+ dna_assign,
458
+ kept_by_split,
459
+ total_kept,
460
+ split_to_indices,
461
+ split_to_edges,
462
+ )
463
+
464
+
465
+ def print_split_ratios(kept_by_split):
466
+ total = sum(kept_by_split.values())
467
+ train_pcnt = 100 * kept_by_split["train"] / total
468
+ val_pcnt = 100 * kept_by_split["val"] / total
469
+ test_pcnt = 100 * kept_by_split["test"] / total
470
+ logger.info(
471
+ f"Cluster distribution - Train: {train_pcnt:.2f}%, Val: {val_pcnt:.2f}%, Test: {test_pcnt:.2f}%"
472
+ )
473
+
474
+
475
+ def make_edges(
476
+ processed_fimo_path: str, protein_cluster_path: str, dna_cluster_path: str
477
+ ):
478
+ """
479
+ Make edges for input to the splitting algorithm. Edges consist of: (tr_cluster_rep)_(dna_cluster_rep) where the cluster rep is the sequence ID
480
+ """
481
+ # Read cluser data
482
+ protein_clusters = pd.read_csv(protein_cluster_path, header=None, sep="\t")
483
+ protein_clusters.columns = ["tr_cluster_rep", "tr_seqid"]
484
+
485
+ dna_clusters = pd.read_csv(dna_cluster_path, header=None, sep="\t")
486
+ dna_clusters.columns = ["dna_cluster_rep", "dna_seqid"]
487
+
488
+ # Read datapoints
489
+ edges = pd.read_parquet(processed_fimo_path)
490
+ edges = pd.merge(edges, dna_clusters, on="dna_seqid", how="left")
491
+ edges = pd.merge(edges, protein_clusters, on="tr_seqid", how="left")
492
+ edges["edge"] = edges.apply(
493
+ lambda row: (row["tr_cluster_rep"], row["dna_cluster_rep"]), axis=1
494
+ )
495
+
496
+ logger.info(f"Total unique edges: {len(edges['edge'].unique().tolist())}")
497
+ dup_edges = edges.loc[edges.duplicated("edge")]["edge"].unique().tolist()
498
+ logger.info(f"Total edges with >1 datapoint: {len(dup_edges)}")
499
+ logger.info(
500
+ f"Total datapoints belonging to a duplicate edge: {len(edges.loc[edges['edge'].isin(dup_edges)])}"
501
+ )
502
+ return edges
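As a hedged illustration of the edge construction (the sequence IDs below are made up, and pandas is assumed to be imported as pd, as elsewhere in this module):

# two hypothetical datapoints that fall into the same protein and DNA clusters
demo = pd.DataFrame({
    "tr_seqid": ["trseqA", "trseqB"],
    "dna_seqid": ["dnaseqX", "dnaseqY"],
    "tr_cluster_rep": ["trseqA", "trseqA"],        # same protein cluster rep
    "dna_cluster_rep": ["dnaseqX", "dnaseqX"],     # same DNA cluster rep
})
demo["edge"] = demo.apply(lambda r: (r["tr_cluster_rep"], r["dna_cluster_rep"]), axis=1)
# one unique edge, two datapoints -> this pair is what the duplicate-edge logging above counts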
503
+
504
+
505
+ def check_validity(train, val, test, split_by="both"):
506
+ """
507
+ Rigorous check for no overlap
508
+ Columns = ["ID","dna_sequence","tr_sequence","tr_cluster_rep","dna_cluster_rep", "scores","split"]
509
+ """
510
+ train_ids = set(train["ID"].unique().tolist())
511
+ val_ids = set(val["ID"].unique().tolist())
512
+ test_ids = set(test["ID"].unique().tolist())
513
+
514
+ assert len(train_ids.intersection(val_ids)) == 0
515
+ assert len(train_ids.intersection(test_ids)) == 0
516
+ assert len(val_ids.intersection(test_ids)) == 0
517
+ logger.info(f"Pass! No overlap in IDs")
518
+
519
+ if split_by != "dna":
520
+ train_tr_seqs = set(train["tr_sequence"].unique().tolist())
521
+ val_tr_seqs = set(val["tr_sequence"].unique().tolist())
522
+ test_tr_seqs = set(test["tr_sequence"].unique().tolist())
523
+
524
+ assert len(train_tr_seqs.intersection(val_tr_seqs)) == 0
525
+ assert len(train_tr_seqs.intersection(test_tr_seqs)) == 0
526
+ assert len(val_tr_seqs.intersection(test_tr_seqs)) == 0
527
+ logger.info(f"Pass! No overlap in TR sequences")
528
+
529
+ train_tr_reps = set(train["tr_cluster_rep"].unique().tolist())
530
+ val_tr_reps = set(val["tr_cluster_rep"].unique().tolist())
531
+ test_tr_reps = set(test["tr_cluster_rep"].unique().tolist())
532
+
533
+ assert len(train_tr_reps.intersection(val_tr_reps)) == 0
534
+ assert len(train_tr_reps.intersection(test_tr_reps)) == 0
535
+ assert len(val_tr_reps.intersection(test_tr_reps)) == 0
536
+ logger.info(f"Pass! No overlap in TR cluster reps")
537
+
538
+ if split_by != "protein":
539
+ train_dna_seqs = set(train["dna_sequence"].unique().tolist())
540
+ val_dna_seqs = set(val["dna_sequence"].unique().tolist())
541
+ test_dna_seqs = set(test["dna_sequence"].unique().tolist())
542
+
543
+ assert len(train_dna_seqs.intersection(val_dna_seqs)) == 0
544
+ assert len(train_dna_seqs.intersection(test_dna_seqs)) == 0
545
+ assert len(val_dna_seqs.intersection(test_dna_seqs)) == 0
546
+ logger.info(f"Pass! No overlap in DNA sequences")
547
+
548
+ train_dna_reps = set(train["dna_cluster_rep"].unique().tolist())
549
+ val_dna_reps = set(val["dna_cluster_rep"].unique().tolist())
550
+ test_dna_reps = set(test["dna_cluster_rep"].unique().tolist())
551
+
552
+ assert len(train_dna_reps.intersection(val_dna_reps)) == 0
553
+ assert len(train_dna_reps.intersection(test_dna_reps)) == 0
554
+ assert len(val_dna_reps.intersection(test_dna_reps)) == 0
555
+ logger.info(f"Pass! No overlap in DNA cluster reps")
556
+
557
+
558
+ def augment_rc(df):
559
+ """
560
+ Get the reverse complement and add it as a datapoint, effectively doubling the dataset.
561
+ Also flip the orientation of the scores
562
+
563
+ columns = ["ID","dna_sequence","tr_sequence","tr_cluster_rep","dna_cluster_rep", "scores","split"]
564
+ """
565
+ df_rc = df.copy(deep=True)
566
+
567
+ df_rc["dna_sequence"] = df_rc["dna_sequence"].apply(
568
+ lambda x: get_reverse_complement(x)
569
+ )
570
+ df_rc["ID"] = df_rc["ID"] + "_rc"
571
+ df_rc["scores"] = df_rc["scores"].apply(lambda s: ",".join(s.split(",")[::-1]))
572
+
573
+ final_df = pd.concat([df, df_rc]).reset_index(drop=True)
574
+
575
+ return final_df
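A minimal sketch of the augmentation on one made-up row (hypothetical 4-bp sequence and scores), showing why the score string is reversed along with the sequence:

# original row, scores listed 5'->3' per base
#   {"ID": "tr1_dna1",    "dna_sequence": "AACG", "scores": "0,0,5,5"}
# appended reverse-complement row: sequence reverse-complemented, scores reversed to stay per-base aligned
#   {"ID": "tr1_dna1_rc", "dna_sequence": "CGTT", "scores": "5,5,0,0"}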
576
+
577
+
578
+ def main(cfg: DictConfig):
579
+ """
580
+ Take a set of DNA clusters + protein clusters, and create the best possible splits into train/val/test.
581
+ """
582
+ # construct edges from training data
583
+ edge_df = make_edges(
584
+ processed_fimo_path=Path(root) / cfg.data_task.input_data_path,
585
+ protein_cluster_path=Path(root) / cfg.data_task.cluster_output_paths.protein,
586
+ dna_cluster_path=Path(root) / cfg.data_task.cluster_output_paths.dna,
587
+ )
588
+ edges = edge_df["edge"].unique().tolist()
589
+
590
+ # figure out if we actually even have a conflict
591
+ total_proteins = len(edge_df["tr_seqid"].unique().tolist())
592
+ total_protein_clusters = len(edge_df["tr_cluster_rep"].unique().tolist())
593
+
594
+ no_protein_overlap = (total_proteins) == (total_protein_clusters)
595
+ logger.info(f"All proteins are in their own clusters: {no_protein_overlap}")
596
+
597
+ if cfg.data_task.split_by == "dna":
598
+ if cfg.data_task.p_exclude:
599
+ return
600
+ else:
601
+ logger.info(f"Easy split: all proteins are in their own clusters.")
602
+ dna_clusters = edge_df["dna_cluster_rep"].unique().tolist()
603
+ results = split_bipartite_fast(
604
+ dna_clusters,
605
+ split_names=("train", "val", "test"),
606
+ ratios=(
607
+ cfg.data_task.train_ratio,
608
+ cfg.data_task.val_ratio,
609
+ cfg.data_task.test_ratio,
610
+ ),
611
+ )
612
+ dna_assign, kept_by_split = results
613
+
614
+ # assign datapoints to cluster by their DNA cluster rep
615
+ edge_df["split"] = edge_df["dna_cluster_rep"].map(dna_assign)
616
+ else:
617
+ results = split_bipartite_by_components(
618
+ edges,
619
+ split_names=("train", "val", "test"),
620
+ ratios=(
621
+ cfg.data_task.train_ratio,
622
+ cfg.data_task.val_ratio,
623
+ cfg.data_task.test_ratio,
624
+ ),
625
+ require_nonempty=cfg.data_task.require_nonempty,
626
+ seed=cfg.data_task.seed,
627
+ test_edges_must=None,
628
+ )
629
+
630
+ (
631
+ tf_assign,
632
+ dna_assign,
633
+ kept_by_split,
634
+ total_kept,
635
+ split_to_indices,
636
+ split_to_edges,
637
+ ) = results
638
+
639
+ # Map each sample to its split
640
+ print(tf_assign)
641
+ print(dna_assign)
642
+ edge_df["tr_split"] = edge_df["tr_cluster_rep"].map(tf_assign)
643
+ edge_df["dna_split"] = edge_df["dna_cluster_rep"].map(dna_assign)
644
+ edge_df["same_split"] = (
645
+ edge_df["tr_split"] == edge_df["dna_split"]
646
+ ) # should always be true if easy cluster
647
+ edge_df["split"] = edge_df["tr_split"]
648
+ print(edge_df)
649
+ edge_df["split"] = np.where(
650
+ edge_df["same_split"],
651
+ edge_df["split"], # keep existing split if same_split == True
652
+ "leak", # otherwise leak
653
+ )
654
+ print(edge_df)
655
+
656
+ # Print ratios: hopefully close to desired (e.g. 80/10/10)
657
+ print_split_ratios(kept_by_split)
658
+
659
+ # Make train, val, test sets
660
+ # make sure no ID is duplicate
661
+ assert len(edge_df["ID"].unique()) == len(edge_df)
662
+ split_cols = [
663
+ "ID",
664
+ "dna_sequence",
665
+ "tr_sequence",
666
+ "tr_cluster_rep",
667
+ "dna_cluster_rep",
668
+ "scores",
669
+ "split",
670
+ ]
671
+ train = edge_df.loc[edge_df["split"] == "train"].reset_index(drop=True)[split_cols]
672
+ val = edge_df.loc[edge_df["split"] == "val"].reset_index(drop=True)[split_cols]
673
+ test = edge_df.loc[edge_df["split"] == "test"].reset_index(drop=True)[split_cols]
674
+
675
+ # ensure there is no overlap
676
+ check_validity(train, val, test, split_by=cfg.data_task.split_by)
677
+
678
+ total = sum([len(train), len(val), len(test)])
679
+ logger.info(f"Length of train dataset: {len(train)} ({100*len(train)/total:.2f}%)")
680
+ logger.info(f"Length of val dataset: {len(val)} ({100*len(val)/total:.2f}%)")
681
+ logger.info(f"Length of test dataset: {len(test)} ({100*len(test)/total:.2f}%)")
682
+ logger.info(f"Total sequences = {total}. Same as edges size? {total==len(edge_df)}")
683
+
684
+ og_unique_dna = pd.concat([train, val, test])
685
+ og_unique_dna = len(og_unique_dna["dna_sequence"].unique())
686
+
687
+ ## Now do RC data augmentation if asked
688
+ if cfg.data_task.augment_rc:
689
+ train = augment_rc(train)
690
+ val = augment_rc(val)
691
+ test = augment_rc(test)
692
+
693
+ logger.info(f"Added reverse complement sequences to train, val, and test.")
694
+
695
+ check_validity(train, val, test, split_by=cfg.data_task.split_by)
696
+
697
+ total = sum([len(train), len(val), len(test)])
698
+ logger.info(
699
+ f"Length of train dataset: {len(train)} ({100*len(train)/total:.2f}%)"
700
+ )
701
+ logger.info(f"Length of val dataset: {len(val)} ({100*len(val)/total:.2f}%)")
702
+ logger.info(f"Length of test dataset: {len(test)} ({100*len(test)/total:.2f}%)")
703
+ logger.info(
704
+ f"Total sequences = {total}. Same as edges size? {total==len(edge_df)}"
705
+ )
706
+
707
+ # since we've added all these new DNA sequences, we do need a new mapping of seq id to dna sequence
708
+ all_data = pd.concat([train, val, test])
709
+ all_data["dna_seqid"] = all_data["ID"].str.split("_", n=1, expand=True)[1]
710
+ dna_dict = dict(zip(all_data["dna_seqid"], all_data["dna_sequence"]))
711
+ assert len(dna_dict) == len(all_data.drop_duplicates(["dna_sequence"]))
712
+ new_map_path = str(Path(root) / cfg.data_task.dna_map_path).replace(
713
+ ".json", "_with_rc.json"
714
+ )
715
+
716
+ with open(new_map_path, "w") as f:
717
+ json.dump(dna_dict, f, indent=2)
718
+ logger.info(
719
+ f"Saved DNA maps with reverse complements (len {len(dna_dict)}=2*original map of len {og_unique_dna}=={len(dna_dict)==2*og_unique_dna}) to {new_map_path}"
720
+ )
721
+
722
+ # create the output dir
723
+ split_out_dir = Path(root) / cfg.data_task.split_out_dir
724
+ os.makedirs(split_out_dir, exist_ok=True)
725
+
726
+ # add binary_scores to allow other training modes
727
+ train["fimo_binary_sores"] = train["scores"].apply(lambda x: convert_scores(x))
728
+ val["fimo_binary_sores"] = val["scores"].apply(lambda x: convert_scores(x))
729
+ test["fimo_binary_sores"] = test["scores"].apply(lambda x: convert_scores(x))
730
+
731
+ # select final cols and save
732
+ split_final_cols = ["ID", "dna_sequence", "tr_sequence", "scores", "fimo_binary_sores", "split"]
733
+ train[split_final_cols].to_csv(split_out_dir / "train.csv", index=False)
734
+ val[split_final_cols].to_csv(split_out_dir / "val.csv", index=False)
735
+ test[split_final_cols].to_csv(split_out_dir / "test.csv", index=False)
736
+ logger.info(f"Saved all splits to {split_out_dir}")
dpacman/data_tasks/split/remap.py CHANGED
@@ -15,6 +15,74 @@ from dpacman.utils import pylogger
15
  root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
16
  logger = pylogger.RankedLogger(__name__, rank_zero_only=True)
17
18
 
19
  def split_bipartite_fast(
20
  dna_clusters,
@@ -50,354 +118,22 @@ def split_bipartite_fast(
50
  kept_by_split = {"train": len(X_train), "val": len(X_val), "test": len(X_test)}
51
  return dna_assign, kept_by_split
52
 
53
- def convert_scores(scores):
 
 
 
 
54
  svec = [int(x) for x in scores.split(",")]
55
  max_score = max(svec)
56
- binary_svec = [0 if x<max_score else 1 for x in svec]
57
- assert(svec.count(max_score)==binary_svec.count(1))
 
 
 
 
58
  binary_svec = ",".join([str(x) for x in binary_svec])
59
  return binary_svec
60
-
61
- def split_bipartite_with_ratios_and_leaky(
62
- edges,
63
- split_names=("train", "val", "test"),
64
- ratios=(0.8, 0.1, 0.1),
65
- require_nonempty=False,
66
- ratio_tolerance=None, # None = soft ratios only; 0.0 = exact band (use with care)
67
- bigM=None,
68
- shuffle_within_pair=False,
69
- seed=0,
70
- test_edges_must=None, # NEW: list of (tf,dna) with duplicates OR dict {(tf,dna): count}
71
- ):
72
- """
73
- edges: list of (tf_cluster_id, dna_cluster_id). Duplicates allowed (-> weights).
74
- test_edges_must: None, list of pairs, or dict {(tf,dna): required_count}.
75
- - If a pair appears with required_count > 0, at least that many examples MUST be kept in TEST.
76
- - This implicitly pins both clusters of that pair to TEST (cluster exclusivity).
77
-
78
- Returns:
79
- tf_assign: {tf_cluster -> split}
80
- dna_assign: {dna_cluster -> split}
81
- kept_by_split: {split -> kept_count} (train/val/test only)
82
- total_kept: int
83
- split_to_indices: {split -> [input indices]} including 'leaky_test'
84
- split_to_edges: {split -> [(tf,dna), ...]} including 'leaky_test'
85
- """
86
- # Aggregate counts per pair
87
- w = Counter(edges)
88
- tfs = {t for (t, _) in w}
89
- dnas = {d for (_, d) in w}
90
- S = list(split_names)
91
- rs = dict(zip(S, ratios))
92
- N = sum(w.values())
93
- if bigM is None:
94
- bigM = 1000 * max(1, N)
95
-
96
- # Index original edges so we can return a per-example split
97
- pair_to_indices = defaultdict(list)
98
- for idx, (c, d) in enumerate(edges):
99
- pair_to_indices[(c, d)].append(idx)
100
-
101
- if shuffle_within_pair:
102
- rng = random.Random(seed)
103
- for key in pair_to_indices:
104
- rng.shuffle(pair_to_indices[key])
105
-
106
- # Parse required test edges
107
- req_test = Counter()
108
- if test_edges_must:
109
- if isinstance(test_edges_must, dict):
110
- for k, v in test_edges_must.items():
111
- if not isinstance(k, tuple) or len(k) != 2:
112
- raise ValueError(
113
- "test_edges_must dict keys must be (tf_cluster, dna_cluster)"
114
- )
115
- if v < 0:
116
- raise ValueError("required_count must be non-negative")
117
- if v:
118
- req_test[k] += int(v)
119
- else:
120
- # assume iterable of pairs
121
- req_test = Counter(test_edges_must)
122
- # Validate against available counts
123
- for pair, req in req_test.items():
124
- if pair not in w:
125
- raise ValueError(f"Required test pair {pair} not present in edges.")
126
- if req > w[pair]:
127
- raise ValueError(
128
- f"Required count {req} for {pair} exceeds available {w[pair]}."
129
- )
130
-
131
- # Build solver
132
- solver = pywraplp.Solver.CreateSolver("CBC")
133
- if solver is None:
134
- raise RuntimeError("Could not create CBC solver.")
135
-
136
- # Binary cluster assignments
137
- x = {(c, s): solver.BoolVar(f"x[{c},{s}]") for c in tfs for s in S}
138
- y = {(d, s): solver.BoolVar(f"y[{d},{s}]") for d in dnas for s in S}
139
-
140
- # Each cluster in exactly one split
141
- for c in tfs:
142
- solver.Add(sum(x[c, s] for s in S) == 1)
143
- for d in dnas:
144
- solver.Add(sum(y[d, s] for s in S) == 1)
145
-
146
- # Integer kept counts per pair and split (allow partial within-pair)
147
- k = {
148
- ((c, d), s): solver.IntVar(0, w[(c, d)], f"k[{c},{d},{s}]")
149
- for (c, d) in w
150
- for s in S
151
- }
152
-
153
- # Only keep in split s if both endpoint clusters are assigned to s
154
- for (c, d), wt in w.items():
155
- for s in S:
156
- solver.Add(k[((c, d), s)] <= wt * x[c, s])
157
- solver.Add(k[((c, d), s)] <= wt * y[d, s])
158
-
159
- # Enforce minimum kept counts in TEST for required pairs
160
- for (c, d), req in req_test.items():
161
- solver.Add(k[((c, d), "test")] >= req)
162
-
163
- # Optional: ensure each split has at least one cluster (feasibility depends on counts)
164
- if require_nonempty:
165
- for s in S:
166
- solver.Add(sum(x[c, s] for c in tfs) + sum(y[d, s] for d in dnas) >= 1)
167
-
168
- # Kept counts per split and total
169
- K = {s: solver.IntVar(0, N, f"K[{s}]") for s in S}
170
- for s in S:
171
- solver.Add(K[s] == sum(k[((c, d), s)] for (c, d) in w))
172
- T = solver.IntVar(0, N, "T")
173
- solver.Add(T == sum(K[s] for s in S))
174
-
175
- # Ratio deviation: K_s - r_s * T = d+ - d-
176
- dpos = {s: solver.NumVar(0, solver.infinity(), f"dpos[{s}]") for s in S}
177
- dneg = {s: solver.NumVar(0, solver.infinity(), f"dneg[{s}]") for s in S}
178
- for s in S:
179
- solver.Add(K[s] - rs[s] * T == dpos[s] - dneg[s])
180
-
181
- # Optional hard band around target ratios
182
- if ratio_tolerance is not None:
183
- eps = float(ratio_tolerance)
184
- for s in S:
185
- solver.Add(K[s] >= (rs[s] - eps) * T)
186
- solver.Add(K[s] <= (rs[s] + eps) * T)
187
-
188
- # Objective: maximize T then minimize total deviation
189
- obj = solver.Objective()
190
- obj.SetMaximization()
191
- obj.SetCoefficient(T, float(bigM))
192
- for s in S:
193
- obj.SetCoefficient(dpos[s], -1.0)
194
- obj.SetCoefficient(dneg[s], -1.0)
195
-
196
- status = solver.Solve()
197
- if status not in (pywraplp.Solver.OPTIMAL, pywraplp.Solver.FEASIBLE):
198
- raise RuntimeError(
199
- "No feasible solution (check ratio_tolerance vs. required test edges)."
200
- )
201
-
202
- # Read cluster assignments
203
- tf_assign = {c: next(s for s in S if x[c, s].solution_value() > 0.5) for c in tfs}
204
- dna_assign = {d: next(s for s in S if y[d, s].solution_value() > 0.5) for d in dnas}
205
-
206
- # Kept counts per split
207
- kept_by_split = {s: int(round(K[s].solution_value())) for s in S}
208
- total_kept = int(round(T.solution_value()))
209
-
210
- # ---- Build per-example split assignment (including 'leaky_test') ----
211
- split_to_indices = {s: [] for s in S}
212
- remaining_indices = {pair: list(pair_to_indices[pair]) for pair in pair_to_indices}
213
-
214
- # Allocate the kept examples per split (train/val/test)
215
- for (c, d), wt in w.items():
216
- for s in S:
217
- cnt = int(round(k[((c, d), s)].solution_value()))
218
- if cnt > 0:
219
- take = remaining_indices[(c, d)][:cnt]
220
- split_to_indices[s].extend(take)
221
- remaining_indices[(c, d)] = remaining_indices[(c, d)][cnt:]
222
-
223
- # Everything left becomes leaky_test
224
- leaky_indices = []
225
- for pair, idxs in remaining_indices.items():
226
- if idxs:
227
- leaky_indices.extend(idxs)
228
-
229
- split_to_indices["leaky_test"] = leaky_indices
230
- split_to_edges = {
231
- s: [edges[i] for i in split_to_indices[s]] for s in split_to_indices
232
- }
233
-
234
- return (
235
- tf_assign,
236
- dna_assign,
237
- kept_by_split,
238
- total_kept,
239
- split_to_indices,
240
- split_to_edges,
241
- )
242
-
243
-
244
- class DSU:
245
- def __init__(self):
246
- self.p = {}
247
-
248
- def find(self, x):
249
- if x not in self.p:
250
- self.p[x] = x
251
- while self.p[x] != x:
252
- self.p[x] = self.p[self.p[x]]
253
- x = self.p[x]
254
- return x
255
-
256
- def union(self, a, b):
257
- ra, rb = self.find(a), self.find(b)
258
- if ra != rb:
259
- self.p[rb] = ra
260
-
261
-
262
- def split_bipartite_by_components(
263
- edges,
264
- split_names=("train", "val", "test"),
265
- ratios=(0.8, 0.1, 0.1),
266
- seed=0,
267
- require_nonempty=False,
268
- test_edges_must=None, # None, list[(tf,dna)], or dict{(tf,dna): count}
269
- ):
270
- """
271
- Guarantees exclusivity: each TF cluster and DNA cluster appears in at most one split.
272
- Strategy: find connected components in the TF–DNA bipartite graph and assign components wholesale.
273
- """
274
- rng = random.Random(seed)
275
- w = Counter(edges) # multiplicities per pair
276
- if not w:
277
- raise ValueError("No edges.")
278
-
279
- # 1) Build components with Union-Find (prefix to keep TF/DNA namespaces disjoint)
280
- dsu = DSU()
281
- for tf, dna in w:
282
- dsu.union(("T", tf), ("D", dna))
283
- comp_pairs = defaultdict(list)
284
- comp_weight = defaultdict(int)
285
- for (tf, dna), cnt in w.items():
286
- root = dsu.find(("T", tf)) # component id = root of TF endpoint
287
- comp_pairs[root].append((tf, dna))
288
- comp_weight[root] += cnt
289
-
290
- comps = list(comp_pairs.keys())
291
- C = len(comps)
292
- S = list(split_names)
293
- rs = dict(zip(S, ratios))
294
- N = sum(comp_weight[c] for c in comps)
295
- target = {s: int(round(rs[s] * N)) for s in S}
296
-
297
- # 2) Pin components that contain required TEST pairs
298
- pinned = {} # comp_root -> pinned_split ("test")
299
- if test_edges_must:
300
- req = (
301
- Counter(test_edges_must)
302
- if not isinstance(test_edges_must, dict)
303
- else Counter(test_edges_must)
304
- )
305
- # Map each required pair to its component, ensure feasibility
306
- for (tf, dna), r in req.items():
307
- if (tf, dna) not in w:
308
- raise ValueError(f"Required pair {(tf,dna)} not present.")
309
- if r > w[(tf, dna)]:
310
- raise ValueError(
311
- f"Required count {r} for {(tf,dna)} exceeds available {w[(tf,dna)]}."
312
- )
313
- comp = dsu.find(("T", tf))
314
- if comp in pinned and pinned[comp] != "test":
315
- raise ValueError(
316
- f"Component conflict: already pinned to {pinned[comp]}, but {(tf,dna)} demands test."
317
- )
318
- pinned[comp] = "test"
319
- # NOTE: pinning a pair pins the WHOLE component to test (to keep exclusivity).
320
- # If you only want some edges kept in test and discard the rest, handle below when materializing.
321
-
322
- # 3) Assign components greedily by deficit
323
- kept_by_split = {s: 0 for s in S}
324
- comp_assign = {} # comp_root -> split
325
-
326
- # First assign pinned comps
327
- for comp, split in pinned.items():
328
- comp_assign[comp] = split
329
- kept_by_split[split] += comp_weight[comp]
330
-
331
- # Sort remaining components by descending weight
332
- remaining = [c for c in comps if c not in comp_assign]
333
- remaining.sort(key=lambda c: comp_weight[c], reverse=True)
334
-
335
- # Ensure nonempty splits if requested (seed with largest remaining comps)
336
- if require_nonempty:
337
- seeds = remaining[: min(len(S), len(remaining))]
338
- for comp, s in zip(seeds, S):
339
- comp_assign[comp] = s
340
- kept_by_split[s] += comp_weight[comp]
341
- remaining = [c for c in remaining if c not in comp_assign]
342
-
343
- for comp in remaining:
344
- # choose split with largest deficit (target - current)
345
- deficits = {s: target[s] - kept_by_split[s] for s in S}
346
- best = max(deficits, key=lambda s: deficits[s])
347
- comp_assign[comp] = best
348
- kept_by_split[best] += comp_weight[comp]
349
-
350
- total_kept = sum(kept_by_split.values())
351
-
352
- # 4) Materialize per-example indices (and verify exclusivity)
353
- pair_to_indices = defaultdict(list)
354
- for idx, pair in enumerate(edges):
355
- pair_to_indices[pair].append(idx)
356
-
357
- split_to_indices = {s: [] for s in S}
358
- for comp, s in comp_assign.items():
359
- for pair in comp_pairs[comp]:
360
- split_to_indices[s].extend(pair_to_indices[pair])
361
-
362
- # Optional: if you pinned a comp due to a small 'must-test' count but
363
- # want to *discard* the rest instead of keeping them in test, uncomment:
364
- # for comp, s in comp_assign.items():
365
- # if s == "test" and test_edges_must:
366
- # # Keep only the required counts; dump extras to 'leaky_test'
367
- # ...
368
- # (Left out for clarity; default is: keep the whole component in its split.)
369
-
370
- # 5) Build edge lists and simple cluster assignments
371
- split_to_edges = {
372
- s: [edges[i] for i in split_to_indices[s]] for s in split_to_indices
373
- }
374
- tf_assign, dna_assign = {}, {}
375
- for comp, s in comp_assign.items():
376
- for tf, dna in comp_pairs[comp]:
377
- tf_assign[tf] = s
378
- dna_assign[dna] = s
379
-
380
- # 6) Safety check: no DNA/TF appears in multiple splits
381
- tf_in_split = defaultdict(set)
382
- dna_in_split = defaultdict(set)
383
- for s, elist in split_to_edges.items():
384
- for tf, dna in elist:
385
- tf_in_split[tf].add(s)
386
- dna_in_split[dna].add(s)
387
- dup_tf = {tf: ss for tf, ss in tf_in_split.items() if len(ss) > 1}
388
- dup_dna = {dn: ss for dn, ss in dna_in_split.items() if len(ss) > 1}
389
- assert not dup_tf and not dup_dna, f"Exclusivity violated: {dup_tf} {dup_dna}"
390
-
391
- return (
392
- tf_assign,
393
- dna_assign,
394
- kept_by_split,
395
- total_kept,
396
- split_to_indices,
397
- split_to_edges,
398
- )
399
-
400
-
401
  def print_split_ratios(kept_by_split):
402
  total = sum(kept_by_split.values())
403
  train_pcnt = 100 * kept_by_split["train"] / total
@@ -452,39 +188,57 @@ def check_validity(train, val, test, split_by="both"):
452
  assert len(val_ids.intersection(test_ids)) == 0
453
  logger.info(f"Pass! No overlap in IDs")
454
 
455
- if split_by != "dna":
456
- train_tr_seqs = set(train["tr_sequence"].unique().tolist())
457
- val_tr_seqs = set(val["tr_sequence"].unique().tolist())
458
- test_tr_seqs = set(test["tr_sequence"].unique().tolist())
 
 
459
 
 
460
  assert len(train_tr_seqs.intersection(val_tr_seqs)) == 0
461
  assert len(train_tr_seqs.intersection(test_tr_seqs)) == 0
462
  assert len(val_tr_seqs.intersection(test_tr_seqs)) == 0
463
  logger.info(f"Pass! No overlap in TR sequences")
464
 
465
- train_tr_reps = set(train["tr_cluster_rep"].unique().tolist())
466
- val_tr_reps = set(val["tr_cluster_rep"].unique().tolist())
467
- test_tr_reps = set(test["tr_cluster_rep"].unique().tolist())
468
-
469
  assert len(train_tr_reps.intersection(val_tr_reps)) == 0
470
  assert len(train_tr_reps.intersection(test_tr_reps)) == 0
471
  assert len(val_tr_reps.intersection(test_tr_reps)) == 0
472
  logger.info(f"Pass! No overlap in TR cluster reps")
473
 
474
  if split_by != "protein":
475
- train_dna_seqs = set(train["dna_sequence"].unique().tolist())
476
- val_dna_seqs = set(val["dna_sequence"].unique().tolist())
477
- test_dna_seqs = set(test["dna_sequence"].unique().tolist())
478
-
479
  assert len(train_dna_seqs.intersection(val_dna_seqs)) == 0
480
  assert len(train_dna_seqs.intersection(test_dna_seqs)) == 0
481
  assert len(val_dna_seqs.intersection(test_dna_seqs)) == 0
482
  logger.info(f"Pass! No overlap in DNA sequences")
483
 
484
- train_dna_reps = set(train["dna_cluster_rep"].unique().tolist())
485
- val_dna_reps = set(val["dna_cluster_rep"].unique().tolist())
486
- test_dna_reps = set(test["dna_cluster_rep"].unique().tolist())
487
-
488
  assert len(train_dna_reps.intersection(val_dna_reps)) == 0
489
  assert len(train_dna_reps.intersection(test_dna_reps)) == 0
490
  assert len(val_dna_reps.intersection(test_dna_reps)) == 0
@@ -531,8 +285,23 @@ def main(cfg: DictConfig):
531
  logger.info(f"All proteins are in their own clusters: {no_protein_overlap}")
532
 
533
  if cfg.data_task.split_by == "dna":
534
- if cfg.data_task.p_exclude:
535
- return
 
536
  else:
537
  logger.info(f"Easy split: all proteins are in their own clusters.")
538
  dna_clusters = edge_df["dna_cluster_rep"].unique().tolist()
@@ -549,75 +318,42 @@ def main(cfg: DictConfig):
549
 
550
  # assign datapoints to cluster by their DNA cluster rep
551
  edge_df["split"] = edge_df["dna_cluster_rep"].map(dna_assign)
552
- else:
553
- results = split_bipartite_by_components(
554
- edges,
555
- split_names=("train", "val", "test"),
556
- ratios=(
557
- cfg.data_task.train_ratio,
558
- cfg.data_task.val_ratio,
559
- cfg.data_task.test_ratio,
560
- ),
561
- require_nonempty=cfg.data_task.require_nonempty,
562
- seed=cfg.data_task.seed,
563
- test_edges_must=None,
564
- )
565
-
566
- (
567
- tf_assign,
568
- dna_assign,
569
- kept_by_split,
570
- total_kept,
571
- split_to_indices,
572
- split_to_edges,
573
- ) = results
574
-
575
- # Map each sample to its split
576
- print(tf_assign)
577
- print(dna_assign)
578
- edge_df["tr_split"] = edge_df["tr_cluster_rep"].map(tf_assign)
579
- edge_df["dna_split"] = edge_df["dna_cluster_rep"].map(dna_assign)
580
- edge_df["same_split"] = (
581
- edge_df["tr_split"] == edge_df["dna_split"]
582
- ) # should always be true if easy cluster
583
- edge_df["split"] = edge_df["tr_split"]
584
- print(edge_df)
585
- edge_df["split"] = np.where(
586
- edge_df["same_split"],
587
- edge_df["split"], # keep existing split if same_split == True
588
- "leak", # otherwise leak
589
- )
590
- print(edge_df)
591
-
592
- # Print ratios: hopefully close to desired (e.g. 80/10/10)
593
- print_split_ratios(kept_by_split)
594
-
595
- # Make train, val, test sets
596
- # make sure no ID is duplicate
597
- assert len(edge_df["ID"].unique()) == len(edge_df)
598
- split_cols = [
599
- "ID",
600
- "dna_sequence",
601
- "tr_sequence",
602
- "tr_cluster_rep",
603
- "dna_cluster_rep",
604
- "scores",
605
- "split",
606
- ]
607
- train = edge_df.loc[edge_df["split"] == "train"].reset_index(drop=True)[split_cols]
608
- val = edge_df.loc[edge_df["split"] == "val"].reset_index(drop=True)[split_cols]
609
- test = edge_df.loc[edge_df["split"] == "test"].reset_index(drop=True)[split_cols]
610
 
611
  # ensure there is no overlap
612
  check_validity(train, val, test, split_by=cfg.data_task.split_by)
613
 
614
- total = sum([len(train), len(val), len(test)])
615
  logger.info(f"Length of train dataset: {len(train)} ({100*len(train)/total:.2f}%)")
616
  logger.info(f"Length of val dataset: {len(val)} ({100*len(val)/total:.2f}%)")
617
  logger.info(f"Length of test dataset: {len(test)} ({100*len(test)/total:.2f}%)")
 
618
  logger.info(f"Total sequences = {total}. Same as edges size? {total==len(edge_df)}")
619
 
620
- og_unique_dna = pd.concat([train, val, test])
621
  og_unique_dna = len(og_unique_dna["dna_sequence"].unique())
622
 
623
  ## Now do RC data augmentation if asked
@@ -625,23 +361,25 @@ def main(cfg: DictConfig):
625
  train = augment_rc(train)
626
  val = augment_rc(val)
627
  test = augment_rc(test)
 
628
 
629
- logger.info(f"Added reverse complement sequences to train, val, and test.")
630
 
631
  check_validity(train, val, test, split_by=cfg.data_task.split_by)
632
 
633
- total = sum([len(train), len(val), len(test)])
634
  logger.info(
635
  f"Length of train dataset: {len(train)} ({100*len(train)/total:.2f}%)"
636
  )
637
  logger.info(f"Length of val dataset: {len(val)} ({100*len(val)/total:.2f}%)")
638
  logger.info(f"Length of test dataset: {len(test)} ({100*len(test)/total:.2f}%)")
 
639
  logger.info(
640
  f"Total sequences = {total}. Same as edges size? {total==len(edge_df)}"
641
  )
642
 
643
- # since we've added all these new DNA sequences, we do need a new apping of seq id to dna sequence
644
- all_data = pd.concat([train, val, test])
645
  all_data["dna_seqid"] = all_data["ID"].str.split("_", n=1, expand=True)[1]
646
  dna_dict = dict(zip(all_data["dna_seqid"], all_data["dna_sequence"]))
647
  assert len(dna_dict) == len(all_data.drop_duplicates(["dna_sequence"]))
@@ -660,13 +398,15 @@ def main(cfg: DictConfig):
660
  os.makedirs(split_out_dir, exist_ok=True)
661
 
662
  # add binary_scores to allow other training modes
663
- train["fimo_binary_sores"] = train["scores"].apply(lambda x: convert_scores(x))
664
- val["fimo_binary_sores"] = val["scores"].apply(lambda x: convert_scores(x))
665
- test["fimo_binary_sores"] = test["scores"].apply(lambda x: convert_scores(x))
 
666
 
667
  # slect final cols and save
668
  split_final_cols = ["ID", "dna_sequence", "tr_sequence", "scores", "fimo_binary_sores", "split"]
669
  train[split_final_cols].to_csv(split_out_dir / "train.csv", index=False)
670
  val[split_final_cols].to_csv(split_out_dir / "val.csv", index=False)
671
  test[split_final_cols].to_csv(split_out_dir / "test.csv", index=False)
 
672
  logger.info(f"Saved all splits to {split_out_dir}")
 
15
  root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
16
  logger = pylogger.RankedLogger(__name__, rank_zero_only=True)
17
 
18
+ def split_with_predefined_test(
19
+ full_df = pd.DataFrame(),
20
+ split_names=("train", "val", "test"),
21
+ test_trs=None,
22
+ test_dnas=None,
23
+ ratios=(0.8, 0.1, 0.1),
24
+ ):
25
+ """
26
+ Method for splitting into train and val with a predefined test set.
27
+ The proteins in the test set, and the DNA clusters of the DNAs they're associated with, must be excluded from train and val.
28
+ The remaining rows for train and val are split to preserve 80/10/10 as best as possible.
29
+ """
30
+ test = full_df.copy(deep=True)
31
+ if test_trs is not None:
32
+ test = test.loc[test["tr_seqid"].isin(test_trs)].reset_index(drop=True)
33
+ if test_dnas is not None:
34
+ test = test.loc[test["dna_seqid"].isin(test_dnas)].reset_index(drop=True)
35
+
36
+ tr_clusters_to_exclude = test["tr_cluster_rep"].unique().tolist()
37
+ dna_clusters_to_exclude = test["dna_cluster_rep"].unique().tolist()
38
+
39
+ remaining = full_df.loc[
40
+ (~full_df["tr_cluster_rep"].isin(tr_clusters_to_exclude)) &
41
+ (~full_df["dna_cluster_rep"].isin(dna_clusters_to_exclude))
42
+ ].reset_index(drop=True)
43
+
44
+ test_ids = test["ID"].unique().tolist()
45
+ remaining_ids = remaining["ID"].unique().tolist()
46
+ remaining_clusters = remaining["dna_cluster_rep"].unique().tolist()
47
+ lost_rows = full_df.loc[
48
+ (~full_df["ID"].isin(test_ids)) &
49
+ (~full_df["ID"].isin(remaining_ids))
50
+ ]
51
+
52
+ logger.info(f"Rows in test: {len(test)}")
53
+ logger.info(f"Rows to be split between train and val: {len(remaining)}")
54
+ total_rows = len(test) + len(remaining)
55
+ logger.info(f"Total rows: {total_rows}. Test percentage: {100*len(test)/total_rows:.2f}%")
56
+ logger.info(f"Lost rows: {len(lost_rows)}")
57
+
58
+ train_ratio_from_remaining = round((0.8*total_rows)/len(remaining), 2)
59
+ # use sklearn
60
+ test_size_1 = 1 - train_ratio_from_remaining
61
+ logger.info(
62
+ f"\tPerforming first split: non-test clusters -> train clusters ({round(1-test_size_1,3)}) and val ({test_size_1})"
63
+ )
64
+ X = remaining_clusters
65
+ y = [0] * len(remaining_clusters)
66
+ X_train, X_val, y_train, y_val = train_test_split(
67
+ X, y, test_size=test_size_1, random_state=0
68
+ )
69
+
70
+ train = remaining.loc[remaining["dna_cluster_rep"].isin(X_train)]
71
+ val = remaining.loc[remaining["dna_cluster_rep"].isin(X_val)]
72
+ leaky_test = lost_rows
73
+
74
+ kept_by_split = {
75
+ "train": len(X_train),
76
+ "val": len(X_val),
77
+ "test": len(test["dna_cluster_rep"].unique())
78
+ }
79
+ splits = {
80
+ "train": train,
81
+ "val": val,
82
+ "test": test,
83
+ "leaky_test": leaky_test
84
+ }
85
+ return splits, kept_by_split
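A hedged usage sketch of the function above; the reserved IDs are placeholders, and in the pipeline they come from the data_task config rather than being hard-coded:

# hypothetical call: reserve specific TR clusters for test, split the rest into train/val
splits, kept_by_split = split_with_predefined_test(
    full_df=edge_df,                   # needs ID, tr_seqid, dna_seqid and *_cluster_rep columns
    test_trs=["trseqX", "trseqY"],     # placeholder IDs
    test_dnas=None,
    ratios=(0.8, 0.1, 0.1),
)
train, val, test = splits["train"], splits["val"], splits["test"]
leaky_test = splits["leaky_test"]      # rows dropped from train/val to keep test exclusive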
86
 
87
  def split_bipartite_fast(
88
  dna_clusters,
 
118
  kept_by_split = {"train": len(X_train), "val": len(X_val), "test": len(X_test)}
119
  return dna_assign, kept_by_split
120
 
121
+ # construct new labels
122
+ def convert_scores(scores, mode=1):
123
+ """
124
+ Two modes: mode=1 turns FIMO peak positions into 1; any other mode keeps the max score at those positions
125
+ """
126
  svec = [int(x) for x in scores.split(",")]
127
  max_score = max(svec)
128
+ if mode == 1:
129
+ binary_svec = [0 if x<max_score else 1 for x in svec]
130
+ assert(svec.count(max_score)==binary_svec.count(1))
131
+ else:
132
+ binary_svec = [0 if x<max_score else max_score for x in svec]
133
+ assert(svec.count(max_score)==binary_svec.count(max_score))
134
  binary_svec = ",".join([str(x) for x in binary_svec])
135
  return binary_svec
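A quick worked example of the two modes on a made-up score string:

# "2,7,7,3" has its max (7) at the middle two positions
convert_scores("2,7,7,3", mode=1)   # -> "0,1,1,0"  (peak positions become 1)
convert_scores("2,7,7,3", mode=0)   # -> "0,7,7,0"  (peak positions keep the max score)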
136
+
137
  def print_split_ratios(kept_by_split):
138
  total = sum(kept_by_split.values())
139
  train_pcnt = 100 * kept_by_split["train"] / total
 
188
  assert len(val_ids.intersection(test_ids)) == 0
189
  logger.info(f"Pass! No overlap in IDs")
190
 
191
+ # Investigate TR intersection. No assertions unless we are explicitly splitting on this.
192
+ train_tr_seqs = set(train["tr_sequence"].unique().tolist())
193
+ val_tr_seqs = set(val["tr_sequence"].unique().tolist())
194
+ test_tr_seqs = set(test["tr_sequence"].unique().tolist())
195
+
196
+ train_tr_reps = set(train["tr_cluster_rep"].unique().tolist())
197
+ val_tr_reps = set(val["tr_cluster_rep"].unique().tolist())
198
+ test_tr_reps = set(test["tr_cluster_rep"].unique().tolist())
199
+
200
+ logger.info(f"Train-Val TR intersection: {len(train_tr_seqs.intersection(val_tr_seqs))}")
201
+ logger.info(f"Train-Test TR intersection: {len(train_tr_seqs.intersection(test_tr_seqs))}")
202
+ logger.info(f"Val-Test TR intersection: {len(val_tr_seqs.intersection(test_tr_seqs))}")
203
+
204
+ logger.info(f"Train-Val TR Cluster Rep intersection: {len(train_tr_reps.intersection(val_tr_reps))}")
205
+ logger.info(f"Train-Test TR Cluster Rep intersection: {len(train_tr_reps.intersection(test_tr_reps))}")
206
+ logger.info(f"Val-Test TR Cluster Rep intersection: {len(val_tr_reps.intersection(test_tr_reps))}")
207
+
208
+ # Investigate DNA intersection. No assertions unless we are explicitly splitting on this.
209
+ train_dna_seqs = set(train["dna_sequence"].unique().tolist())
210
+ val_dna_seqs = set(val["dna_sequence"].unique().tolist())
211
+ test_dna_seqs = set(test["dna_sequence"].unique().tolist())
212
+
213
+ train_dna_reps = set(train["dna_cluster_rep"].unique().tolist())
214
+ val_dna_reps = set(val["dna_cluster_rep"].unique().tolist())
215
+ test_dna_reps = set(test["dna_cluster_rep"].unique().tolist())
216
+
217
+ logger.info(f"Train-Val DNA intersection: {len(train_dna_seqs.intersection(val_dna_seqs))}")
218
+ logger.info(f"Train-Test DNA intersection: {len(train_dna_seqs.intersection(test_dna_seqs))}")
219
+ logger.info(f"Val-Test DNA intersection: {len(val_dna_seqs.intersection(test_dna_seqs))}")
220
+
221
+ logger.info(f"Train-Val DNA Cluster Rep intersection: {len(train_dna_reps.intersection(val_dna_reps))}")
222
+ logger.info(f"Train-Test DNA Cluster Rep intersection: {len(train_dna_reps.intersection(test_dna_reps))}")
223
+ logger.info(f"Val-Test DNA Cluster Rep intersection: {len(val_dna_reps.intersection(test_dna_reps))}")
224
 
225
+ if split_by != "dna":
226
  assert len(train_tr_seqs.intersection(val_tr_seqs)) == 0
227
  assert len(train_tr_seqs.intersection(test_tr_seqs)) == 0
228
  assert len(val_tr_seqs.intersection(test_tr_seqs)) == 0
229
  logger.info(f"Pass! No overlap in TR sequences")
230
 
 
 
 
 
231
  assert len(train_tr_reps.intersection(val_tr_reps)) == 0
232
  assert len(train_tr_reps.intersection(test_tr_reps)) == 0
233
  assert len(val_tr_reps.intersection(test_tr_reps)) == 0
234
  logger.info(f"Pass! No overlap in TR cluster reps")
235
 
236
  if split_by != "protein":
 
 
 
 
237
  assert len(train_dna_seqs.intersection(val_dna_seqs)) == 0
238
  assert len(train_dna_seqs.intersection(test_dna_seqs)) == 0
239
  assert len(val_dna_seqs.intersection(test_dna_seqs)) == 0
240
  logger.info(f"Pass! No overlap in DNA sequences")
241
 
 
 
 
 
242
  assert len(train_dna_reps.intersection(val_dna_reps)) == 0
243
  assert len(train_dna_reps.intersection(test_dna_reps)) == 0
244
  assert len(val_dna_reps.intersection(test_dna_reps)) == 0
 
285
  logger.info(f"All proteins are in their own clusters: {no_protein_overlap}")
286
 
287
  if cfg.data_task.split_by == "dna":
288
+ if cfg.data_task.test_trs or cfg.data_task.test_dnas:
289
+ logger.info(f"Splitting with predefined trs/dnas reserved for test set")
290
+ splits, kept_by_split = split_with_predefined_test(
291
+ full_df=edge_df,
292
+ split_names=("train", "val", "test"),
293
+ test_trs=cfg.data_task.test_trs if cfg.data_task.test_trs else None,
294
+ test_dnas=cfg.data_task.test_dnas if cfg.data_task.test_dnas else None,
295
+ ratios=(0.8, 0.1, 0.1),
296
+ )
297
+ train = splits["train"]
298
+ train["split"]=["train"]*len(train)
299
+ val = splits["val"]
300
+ val["split"]=["val"]*len(val)
301
+ test = splits["test"]
302
+ test["split"]=["test"]*len(test)
303
+ leaky_test = splits["leaky_test"]
304
+ leaky_test["split"]=["leaky_test"]*len(leaky_test)
305
  else:
306
  logger.info(f"Easy split: all proteins are in their own clusters.")
307
  dna_clusters = edge_df["dna_cluster_rep"].unique().tolist()
 
318
 
319
  # assign datapoints to cluster by their DNA cluster rep
320
  edge_df["split"] = edge_df["dna_cluster_rep"].map(dna_assign)
321
+ train = edge_df.loc[edge_df["split"] == "train"].reset_index(drop=True)
322
+ val = edge_df.loc[edge_df["split"] == "val"].reset_index(drop=True)
323
+ test = edge_df.loc[edge_df["split"] == "test"].reset_index(drop=True)
324
+ leaky_test = pd.DataFrame(columns=edge_df.columns)
325
+
326
+ # Print ratios: hopefully close to desired (e.g. 80/10/10)
327
+ print_split_ratios(kept_by_split)
328
+
329
+ # Make train, val, test sets
330
+ # make sure no ID is duplicate
331
+ assert len(edge_df["ID"].unique()) == len(edge_df)
332
+ split_cols = [
333
+ "ID",
334
+ "dna_sequence",
335
+ "tr_sequence",
336
+ "tr_cluster_rep",
337
+ "dna_cluster_rep",
338
+ "scores",
339
+ "split",
340
+ ]
341
+ train = train[split_cols]
342
+ val = val[split_cols]
343
+ test = test[split_cols]
344
+ leaky_test = leaky_test[split_cols]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
 
346
  # ensure there is no overlap
347
  check_validity(train, val, test, split_by=cfg.data_task.split_by)
348
 
349
+ total = sum([len(train), len(val), len(test), len(leaky_test)])
350
  logger.info(f"Length of train dataset: {len(train)} ({100*len(train)/total:.2f}%)")
351
  logger.info(f"Length of val dataset: {len(val)} ({100*len(val)/total:.2f}%)")
352
  logger.info(f"Length of test dataset: {len(test)} ({100*len(test)/total:.2f}%)")
353
+ logger.info(f"Length of leaky_test dataset: {len(leaky_test)} ({100*len(leaky_test)/total:.2f}%)")
354
  logger.info(f"Total sequences = {total}. Same as edges size? {total==len(edge_df)}")
355
 
356
+ og_unique_dna = pd.concat([train, val, test, leaky_test])
357
  og_unique_dna = len(og_unique_dna["dna_sequence"].unique())
358
 
359
  ## Now do RC data augmentation if asked
 
361
  train = augment_rc(train)
362
  val = augment_rc(val)
363
  test = augment_rc(test)
364
+ leaky_test = augment_rc(leaky_test)
365
 
366
+ logger.info(f"Added reverse complement sequences to train, val, and test (and leaky test)")
367
 
368
  check_validity(train, val, test, split_by=cfg.data_task.split_by)
369
 
370
+ total = sum([len(train), len(val), len(test), len(leaky_test)])
371
  logger.info(
372
  f"Length of train dataset: {len(train)} ({100*len(train)/total:.2f}%)"
373
  )
374
  logger.info(f"Length of val dataset: {len(val)} ({100*len(val)/total:.2f}%)")
375
  logger.info(f"Length of test dataset: {len(test)} ({100*len(test)/total:.2f}%)")
376
+ logger.info(f"Length of leaky_test dataset: {len(leaky_test)} ({100*len(leaky_test)/total:.2f}%)")
377
  logger.info(
378
  f"Total sequences = {total}. Same as edges size? {total==len(edge_df)}"
379
  )
380
 
381
+ # since we've added all these new DNA sequences, we do need a new mapping of seq id to dna sequence
382
+ all_data = pd.concat([train, val, test, leaky_test])
383
  all_data["dna_seqid"] = all_data["ID"].str.split("_", n=1, expand=True)[1]
384
  dna_dict = dict(zip(all_data["dna_seqid"], all_data["dna_sequence"]))
385
  assert len(dna_dict) == len(all_data.drop_duplicates(["dna_sequence"]))
 
398
  os.makedirs(split_out_dir, exist_ok=True)
399
 
400
  # add binary_scores to allow other training modes
401
+ train["fimo_binary_sores"] = train["scores"].apply(lambda x: convert_scores(x, mode=1))
402
+ val["fimo_binary_sores"] = val["scores"].apply(lambda x: convert_scores(x, mode=1))
403
+ test["fimo_binary_sores"] = test["scores"].apply(lambda x: convert_scores(x, mode=1))
404
+ leaky_test["fimo_binary_sores"] = leaky_test["scores"].apply(lambda x: convert_scores(x, mode=1))
405
 
406
  # slect final cols and save
407
  split_final_cols = ["ID", "dna_sequence", "tr_sequence", "scores", "fimo_binary_sores", "split"]
408
  train[split_final_cols].to_csv(split_out_dir / "train.csv", index=False)
409
  val[split_final_cols].to_csv(split_out_dir / "val.csv", index=False)
410
  test[split_final_cols].to_csv(split_out_dir / "test.csv", index=False)
411
+ leaky_test[split_final_cols].to_csv(split_out_dir / "leaky_test.csv", index=False)
412
  logger.info(f"Saved all splits to {split_out_dir}")
dpacman/data_tasks/split/remap_handpick.py ADDED
@@ -0,0 +1,266 @@
1
+ """
2
+ Not neat, but this is what I did to make exclusive splits. Saving it here for now.
3
+ """
4
+
5
+ ## Full pipeline
6
+ import pandas as pd
7
+ protein_clusters = pd.read_csv("/home/a03-svincoff/DPACMAN/dpacman/data_files/processed/mmseqs/outputs/fimo_hits_only/protein/mmseqs_cluster.tsv", sep="\t", header=None)
8
+ protein_clusters.columns=["tr_cluster_rep","tr_cluster_member"]
9
+ protein_clusters.head()
10
+
11
+ dna_clusters = pd.read_csv("/home/a03-svincoff/DPACMAN/dpacman/data_files/processed/mmseqs/outputs/fimo_hits_only/dna_full/mmseqs_cluster.tsv", sep="\t", header=None)
12
+ dna_clusters.columns=["dna_cluster_rep","dna_cluster_member"]
13
+ dna_clusters.head()
14
+
15
+ all_data = pd.read_parquet("/home/a03-svincoff/DPACMAN/dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/remap2022_crm_fimo_output_q_processed_seed0.parquet")
16
+ all_data
17
+
18
+ protein_cluster_map = dict(zip(protein_clusters["tr_cluster_member"],protein_clusters["tr_cluster_rep"]))
19
+ dna_cluster_map = dict(zip(dna_clusters["dna_cluster_member"],dna_clusters["dna_cluster_rep"]))
20
+ print(len(protein_cluster_map))
21
+ print(len(dna_cluster_map))
22
+ all_data["tr_cluster_rep"] = all_data["tr_seqid"].map(protein_cluster_map)
23
+ all_data["dna_cluster_rep"] = all_data["dna_seqid"].map(dna_cluster_map)
24
+ print(len(all_data[all_data["tr_cluster_rep"].isna()]))
25
+ print(len(all_data[all_data["dna_cluster_rep"].isna()]))
26
+ all_data.head()
27
+
28
+
29
+ ### handpick test
30
+ handpicked_test_trs = ["trseq23","trseq26","trseq17"]
31
+ handpicked_test = all_data.loc[
32
+ all_data["tr_cluster_rep"].isin(handpicked_test_trs)
33
+ ].reset_index(drop=True)
34
+
35
+ off_limits_dna_clusters = handpicked_test["dna_cluster_rep"].unique().tolist()
36
+ remaining = all_data.loc[
37
+ (~all_data["tr_cluster_rep"].isin(handpicked_test_trs)) &
38
+ (~all_data["dna_cluster_rep"].isin(off_limits_dna_clusters))
39
+ ].reset_index(drop=True)
40
+
41
+ test_ids = handpicked_test["ID"].unique().tolist()
42
+ remaining_ids = remaining["ID"].unique().tolist()
43
+ lost_rows = all_data.loc[
44
+ (~all_data["ID"].isin(test_ids)) &
45
+ (~all_data["ID"].isin(remaining_ids))
46
+ ]
47
+ print(f"Rows in test: {len(handpicked_test)}")
48
+ print(f"Rows to be split between train and val: {len(remaining)}")
49
+ total_rows = len(handpicked_test) + len(remaining)
50
+ print(f"Total rows: {total_rows}. Test percentage: {100*len(handpicked_test)/total_rows:.2f}%")
51
+ print(f"Lost rows: {len(lost_rows)}")
52
+
53
+ ### handpick val
54
+ handpicked_val_trs = ["trseq9", "trseq5", "trseq28"]
55
+
56
+ handpicked_val = remaining.loc[
57
+ remaining["tr_cluster_rep"].isin(handpicked_val_trs)
58
+ ].reset_index(drop=True)
59
+
60
+ off_limits_dna_clusters = handpicked_val["dna_cluster_rep"].unique().tolist()
61
+ train_remain = remaining.loc[
62
+ (~remaining["tr_cluster_rep"].isin(handpicked_val_trs)) &
63
+ (~remaining["dna_cluster_rep"].isin(off_limits_dna_clusters))
64
+ ].reset_index(drop=True)
65
+
66
+ val_ids = handpicked_val["ID"].unique().tolist()
67
+ train_remain_ids = train_remain["ID"].unique().tolist()
68
+ lost_rows = all_data.loc[
69
+ (~all_data["ID"].isin(test_ids)) &
70
+ (~all_data["ID"].isin(val_ids)) &
71
+ (~all_data["ID"].isin(train_remain_ids))
72
+ ]
73
+ print(f"Rows in val: {len(handpicked_val)}")
74
+ print(f"Rows left for train: {len(train_remain)}")
75
+ total_rows = len(handpicked_val) + len(train_remain)
76
+ print(f"Total rows: {total_rows}. Test percentage: {100*len(handpicked_val)/total_rows:.2f}%")
77
+ print(f"Lost rows: {len(lost_rows)}")
78
+
79
+ train_exclusive = all_data.loc[
80
+ all_data["ID"].isin(train_remain_ids)
81
+ ].reset_index(drop=True)
82
+
83
+ val_exclusive = all_data.loc[
84
+ all_data["ID"].isin(val_ids)
85
+ ].reset_index(drop=True)
86
+
87
+ test_exclusive = all_data.loc[
88
+ all_data["ID"].isin(test_ids)
89
+ ].reset_index(drop=True)
90
+
91
+ leaky_test = all_data.loc[
92
+ ~(all_data["ID"].isin(train_exclusive["ID"].tolist())) &
93
+ ~(all_data["ID"].isin(val_exclusive["ID"].tolist())) &
94
+ ~(all_data["ID"].isin(test_exclusive["ID"].tolist()))
95
+ ].reset_index(drop=True)
96
+
97
+ print(f"Original total: {len(all_data)}")
98
+ retained_total = len(train_exclusive)+len(val_exclusive)+len(test_exclusive)
99
+ print(f"New, exclusive total: {retained_total}")
100
+ print(f"Lost rows: {len(all_data)-retained_total}")
101
+ print(f"Length train: {len(train_exclusive)}/{retained_total} ({100*len(train_exclusive)/retained_total:.2f}%)")
102
+ print(f"Length val: {len(val_exclusive)}/{retained_total} ({100*len(val_exclusive)/retained_total:.2f}%)")
103
+ print(f"Length test: {len(test_exclusive)}/{retained_total} ({100*len(test_exclusive)/retained_total:.2f}%)")
104
+
105
+ def check_validity(train_exclusive, val_exclusive, test_exclusive):
106
+ train_exclusive_ids = set(train_exclusive["ID"].unique().tolist())
107
+ val_exclusive_ids = set(val_exclusive["ID"].unique().tolist())
108
+ test_exclusive_ids = set(test_exclusive["ID"].unique().tolist())
109
+
110
+ assert len(train_exclusive_ids.intersection(val_exclusive_ids)) == 0
111
+ assert len(train_exclusive_ids.intersection(test_exclusive_ids)) == 0
112
+ assert len(val_exclusive_ids.intersection(test_exclusive_ids)) == 0
113
+ print(f"Pass! No overlap in IDs")
114
+
115
+ # Investigate TR intersection. No assertions unless we are explicitly splitting on this.
116
+ train_exclusive_tr_seqs = set(train_exclusive["tr_sequence"].unique().tolist())
117
+ val_exclusive_tr_seqs = set(val_exclusive["tr_sequence"].unique().tolist())
118
+ test_exclusive_tr_seqs = set(test_exclusive["tr_sequence"].unique().tolist())
119
+
120
+ train_exclusive_tr_reps = set(train_exclusive["tr_cluster_rep"].unique().tolist())
121
+ val_exclusive_tr_reps = set(val_exclusive["tr_cluster_rep"].unique().tolist())
122
+ test_exclusive_tr_reps = set(test_exclusive["tr_cluster_rep"].unique().tolist())
123
+
124
+ print(f"Train-Val TR intersection: {len(train_exclusive_tr_seqs.intersection(val_exclusive_tr_seqs))}")
125
+ print(f"Train-Test TR intersection: {len(train_exclusive_tr_seqs.intersection(test_exclusive_tr_seqs))}")
126
+ print(f"Val-Test TR intersection: {len(val_exclusive_tr_seqs.intersection(test_exclusive_tr_seqs))}")
127
+
128
+ print(f"Train-Val TR Cluster Rep intersection: {len(train_exclusive_tr_reps.intersection(val_exclusive_tr_reps))}")
129
+ print(f"Train-Test TR Cluster Rep intersection: {len(train_exclusive_tr_reps.intersection(test_exclusive_tr_reps))}")
130
+ print(f"Val-Test TR Cluster Rep intersection: {len(val_exclusive_tr_reps.intersection(test_exclusive_tr_reps))}")
131
+
132
+ # Investigate DNA intersection. No assertions unless we are explicitly splitting on this.
133
+ train_exclusive_dna_seqs = set(train_exclusive["dna_sequence"].unique().tolist())
134
+ val_exclusive_dna_seqs = set(val_exclusive["dna_sequence"].unique().tolist())
135
+ test_exclusive_dna_seqs = set(test_exclusive["dna_sequence"].unique().tolist())
136
+
137
+ train_exclusive_dna_reps = set(train_exclusive["dna_cluster_rep"].unique().tolist())
138
+ val_exclusive_dna_reps = set(val_exclusive["dna_cluster_rep"].unique().tolist())
139
+ test_exclusive_dna_reps = set(test_exclusive["dna_cluster_rep"].unique().tolist())
140
+
141
+ print(f"Train-Val DNA intersection: {len(train_exclusive_dna_seqs.intersection(val_exclusive_dna_seqs))}")
142
+ print(f"Train-Test DNA intersection: {len(train_exclusive_dna_seqs.intersection(test_exclusive_dna_seqs))}")
143
+ print(f"Val-Test DNA intersection: {len(val_exclusive_dna_seqs.intersection(test_exclusive_dna_seqs))}")
144
+
145
+ print(f"Train-Val DNA Cluster Rep intersection: {len(train_exclusive_dna_reps.intersection(val_exclusive_dna_reps))}")
146
+ print(f"Train-Test DNA Cluster Rep intersection: {len(train_exclusive_dna_reps.intersection(test_exclusive_dna_reps))}")
147
+ print(f"Val-Test DNA Cluster Rep intersection: {len(val_exclusive_dna_reps.intersection(test_exclusive_dna_reps))}")
148
+
149
+ def get_reverse_complement(s):
150
+ """
151
+ Returns 5' to 3' sequence of the reverse complement
152
+ """
153
+ chars = list(s)
154
+ recon = []
155
+ rev_map = {
156
+ "a": "t",
157
+ "c": "g",
158
+ "t": "a",
159
+ "g": "c",
160
+ "A": "T",
161
+ "C": "G",
162
+ "T": "A",
163
+ "G": "C",
164
+ "n": "n",
165
+ "N": "N",
166
+ }
167
+ for c in chars:
168
+ recon += [rev_map[c]]
169
+
170
+ recon = "".join(recon)
171
+ return recon[::-1]
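Sanity check on a short made-up sequence:

# complement of "ACGTn" is "TGCAn"; reversing it gives the 5'->3' reverse complement
assert get_reverse_complement("ACGTn") == "nACGT"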
172
+
173
+ # now make reverse complements
174
+ def augment_rc(df):
175
+ """
176
+ Get the reverse complement and add it as a datapoint, effectively doubling the dataset.
177
+ Also flip the orientation of the scores
178
+
179
+ columns = ["ID","dna_sequence","tr_sequence","tr_cluster_rep","dna_cluster_rep", "scores","split"]
180
+ """
181
+ df_rc = df.copy(deep=True)
182
+
183
+ df_rc["dna_sequence"] = df_rc["dna_sequence"].apply(
184
+ lambda x: get_reverse_complement(x)
185
+ )
186
+ df_rc["ID"] = df_rc["ID"] + "_rc"
187
+ df_rc["scores"] = df_rc["scores"].apply(lambda s: ",".join(s.split(",")[::-1]))
188
+
189
+ final_df = pd.concat([df, df_rc]).reset_index(drop=True)
190
+
191
+ return final_df
192
+
193
+ def convert_scores(scores, mode=1):
194
+ """
195
+ Two modes: mode=1 turns FIMO peak positions into 1; any other mode keeps the max score at those positions
196
+ """
197
+ svec = [int(x) for x in scores.split(",")]
198
+ max_score = max(svec)
199
+ if mode == 1:
200
+ binary_svec = [0 if x<max_score else 1 for x in svec]
201
+ assert(svec.count(max_score)==binary_svec.count(1))
202
+ else:
203
+ binary_svec = [0 if x<max_score else max_score for x in svec]
204
+ assert(svec.count(max_score)==binary_svec.count(max_score))
205
+ binary_svec = ",".join([str(x) for x in binary_svec])
206
+ return binary_svec
207
+
208
+ check_validity(train_exclusive, val_exclusive, test_exclusive)
209
+
210
+ train_exclusive = augment_rc(train_exclusive)
211
+ val_exclusive = augment_rc(val_exclusive)
212
+ test_exclusive = augment_rc(test_exclusive)
213
+ leaky_test = augment_rc(leaky_test)
214
+
215
+ print(f"Added reverse complement sequences to train_exclusive, val_exclusive, and test_exclusive (and leaky test_exclusive)")
216
+
217
+ check_validity(train_exclusive, val_exclusive, test_exclusive)
218
+
219
+ total = sum([len(train_exclusive), len(val_exclusive), len(test_exclusive), len(leaky_test)])
220
+ print(
221
+ f"Length of train_exclusive dataset: {len(train_exclusive)} ({100*len(train_exclusive)/total:.2f}%)"
222
+ )
223
+ print(f"Length of val_exclusive dataset: {len(val_exclusive)} ({100*len(val_exclusive)/total:.2f}%)")
224
+ print(f"Length of test_exclusive dataset: {len(test_exclusive)} ({100*len(test_exclusive)/total:.2f}%)")
225
+ print(f"Length of leaky_test dataset: {len(leaky_test)} ({100*len(leaky_test)/total:.2f}%)")
226
+ print(
227
+ f"Total sequences = {total}. Same as edges size*2? {total==len(all_data)*2}"
228
+ )
229
+
230
+ # since we've added all these new DNA sequences, we do need a new mapping of seq id to dna sequence
231
+ all_data = pd.concat([train_exclusive, val_exclusive, test_exclusive, leaky_test])
232
+ all_data["dna_seqid"] = all_data["ID"].str.split("_", n=1, expand=True)[1]
233
+ dna_dict = dict(zip(all_data["dna_seqid"], all_data["dna_sequence"]))
234
+ assert len(dna_dict) == len(all_data.drop_duplicates(["dna_sequence"]))
235
+
236
+ # create the output dir
237
+ import os
238
+ from pathlib import Path
239
+ split_out_dir = Path("/home/a03-svincoff/DPACMAN/dpacman/data_files/processed/splits/handpicked_val_test")
240
+ os.makedirs(split_out_dir, exist_ok=True)
241
+
242
+ # add binary_scores to allow other training modes
243
+ train_exclusive["binary_sores"] = train_exclusive["scores"].apply(lambda x: convert_scores(x, mode=1))
244
+ val_exclusive["binary_sores"] = val_exclusive["scores"].apply(lambda x: convert_scores(x, mode=1))
245
+ test_exclusive["binary_sores"] = test_exclusive["scores"].apply(lambda x: convert_scores(x, mode=1))
246
+ leaky_test["binary_sores"] = leaky_test["scores"].apply(lambda x: convert_scores(x, mode=1))
247
+
248
+ train_exclusive["split"] = ["train"]*len(train_exclusive)
249
+ val_exclusive["split"] = ["val"]*len(val_exclusive)
250
+ test_exclusive["split"] = ["test"]*len(test_exclusive)
251
+ leaky_test["split"] = ["leakytest"]*len(leaky_test)
252
+
253
+ # select final cols and save
254
+ split_final_cols = ["ID", "dna_sequence", "tr_sequence", "scores", "binary_sores", "split"]
255
+ train_exclusive[split_final_cols].to_csv(split_out_dir / "train.csv", index=False)
256
+ val_exclusive[split_final_cols].to_csv(split_out_dir / "val.csv", index=False)
257
+ test_exclusive[split_final_cols].to_csv(split_out_dir / "test.csv", index=False)
258
+ leaky_test[split_final_cols].to_csv(split_out_dir / "leakytest.csv", index=False)
259
+ print(f"Saved all splits to {split_out_dir}")
260
+
261
+ # make baby versions too
262
+ train_exclusive[split_final_cols].sample(400, random_state=42).to_csv(split_out_dir / "babytrain.csv", index=False)
263
+ val_exclusive[split_final_cols].sample(50, random_state=42).to_csv(split_out_dir / "babyval.csv", index=False)
264
+ test_exclusive[split_final_cols].sample(50, random_state=42).to_csv(split_out_dir / "babytest.csv", index=False)
265
+ leaky_test[split_final_cols].sample(50, random_state=42).to_csv(split_out_dir / "babyleakytest.csv", index=False)
266
+
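A minimal sketch of the binarization rule implemented by the convert_scores helper above, applied to a made-up score string rather than project data:

    svec = [int(x) for x in "3,7,7,2".split(",")]
    max_score = max(svec)
    print([0 if x < max_score else 1 for x in svec])          # mode == 1 -> [0, 1, 1, 0]
    print([0 if x < max_score else max_score for x in svec])  # any other mode -> [0, 7, 7, 0]

Only the positions carrying the peak score stay positive, which is what the binary score column and binary loss mode in the run scripts consume downstream.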
dpacman/find_wandb_run_name.py ADDED
@@ -0,0 +1,67 @@
1
+ #!/usr/bin/env python3
2
+ import os
3
+ import argparse
4
+ from typing import Optional
5
+
6
+ SENTINEL = "View project at"
7
+ SKIP_DIRS = {".git", ".hg", ".svn", "__pycache__", ".mamba", ".conda"}
8
+
9
+ def extract_run_name(log_path: str) -> Optional[str]:
10
+ """
11
+ Return the run name if we find a line that:
12
+ - starts with 'wandb:' (after leading whitespace is stripped)
13
+ - whose second-to-last word is 'run' and the last word is the run name
14
+ and this occurs before the first line containing SENTINEL.
15
+ Otherwise return None.
16
+ """
17
+ try:
18
+ with open(log_path, "r", errors="ignore") as f:
19
+ for raw in f:
20
+ if SENTINEL in raw:
21
+ return None
22
+ line = raw.strip()
23
+ if not line.startswith("wandb: "):
24
+ continue
25
+ toks = line.split()
26
+ if len(toks) >= 2 and toks[-2] == "run":
27
+ return toks[-1] # the run name (last token)
28
+ except OSError:
29
+ return None
30
+ return None
31
+
32
+ def list_runs(root: str, followlinks: bool = False, do_rename=False):
33
+ for dirpath, dirs, files in os.walk(root, followlinks=followlinks):
34
+ # prune junk dirs
35
+ dirs[:] = [d for d in dirs if d not in SKIP_DIRS]
36
+ if "run.log" in files:
37
+ log_path = os.path.join(dirpath, "run.log")
38
+ name = extract_run_name(log_path)
39
+ new_dir_path = dirpath
40
+ if name:
41
+ if name not in dirpath:
42
+ new_dir_path = f"{dirpath}-{name}"
43
+ print(f"{name if name else '<unknown>'}\t{new_dir_path}")
44
+
45
+ if do_rename and new_dir_path != dirpath:
46
+ parent = os.path.dirname(dirpath)
47
+ # resolve absolute path for safety
48
+ abs_old = os.path.abspath(dirpath)
49
+ abs_new = os.path.abspath(new_dir_path)
50
+
51
+ if os.path.exists(abs_new):
52
+ print(f"⚠️ Target {abs_new} already exists, skipping rename.")
53
+ else:
54
+ print(f"Renaming {abs_old} → {abs_new}")
55
+ os.rename(abs_old, abs_new)
56
+
57
+ def main():
58
+ ap = argparse.ArgumentParser(
59
+ description="List W&B run names by parsing lines that start with 'wandb:' and end with 'run <name>'."
60
+ )
61
+ ap.add_argument("root", help="Root directory to search")
62
+ ap.add_argument("--followlinks", action="store_true", help="Follow symlinks while walking")
63
+ args = ap.parse_args()
64
+ list_runs(args.root, followlinks=args.followlinks)
65
+
66
+ if __name__ == "__main__":
67
+ main()
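A minimal sketch of the log-line rule extract_run_name applies, run on a made-up line in the format the parser expects (not an excerpt from an actual run.log):

    line = "wandb: Syncing run stoic-snowball-99"
    toks = line.strip().split()
    if line.startswith("wandb: ") and len(toks) >= 2 and toks[-2] == "run":
        print(toks[-1])  # -> stoic-snowball-99

list_runs appends this same name to the run directory when its do_rename flag is set.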
dpacman/make_splits.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
dpacman/manual_scan_chroms.ipynb ADDED
@@ -0,0 +1,147 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "f6c01484",
6
+ "metadata": {},
7
+ "source": [
8
+ "Temporary notebook for manually scanning chromosomes for sequences of interest"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 1,
14
+ "id": "0608f91e",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "seq_of_interest = \"GCAGATCTGCACATC\""
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 2,
24
+ "id": "3c245151",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "genome_dir = \"/home/a03-svincoff/DPACMAN/dpacman/data_files/raw/genomes/hg38\""
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": 7,
34
+ "id": "682098b6",
35
+ "metadata": {},
36
+ "outputs": [
37
+ {
38
+ "name": "stdout",
39
+ "output_type": "stream",
40
+ "text": [
41
+ "dict_keys(['chr12', 'chr5', 'chr17', 'chr2', 'chr21', 'chr1', 'chrM', 'chr22', 'chr20', 'chr16', 'chr9', 'chr8', 'chr19', 'chr7', 'chr11', 'chr3', 'chr4', 'chr14', 'chr15', 'chr18', 'chrY', 'chr6', 'chrX', 'chr13', 'chr10'])\n"
42
+ ]
43
+ }
44
+ ],
45
+ "source": [
46
+ "import json\n",
47
+ "import os\n",
48
+ "chrom_cache = {}\n",
49
+ "for chrom_file in os.listdir(genome_dir):\n",
50
+ " chrom = chrom_file.split(\"hg38_\")[1].split(\".json\")[0]\n",
51
+ " with open(f\"{genome_dir}/{chrom_file}\", \"r\") as f:\n",
52
+ " chrom_cache[chrom] = json.load(f)\n",
53
+ "\n",
54
+ "print(chrom_cache.keys())"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": null,
60
+ "id": "fd6cca79",
61
+ "metadata": {},
62
+ "outputs": [
63
+ {
64
+ "name": "stdout",
65
+ "output_type": "stream",
66
+ "text": [
67
+ "Testing sequence A: TAGCAGGATGTGT\n",
68
+ "Testing sequence B: GCAGATCTGCACATC\n",
69
+ "Testing sequence C: CGACACCTGACGCG\n",
70
+ "Testing sequence D: CGCTATCCAGAGCG\n",
71
+ "Testing sequence E: CGCGATGCTTCTCG\n",
72
+ "Testing sequence F: CGGCTGGATTACCG\n",
73
+ "Testing sequence G: CGAGAACATAGTCG\n",
74
+ "Testing sequence H: CGGGGAAACGCCCG\n",
75
+ "Testing sequence I: CGCCCAAAGCCGCG\n",
76
+ "Testing sequence J: CGGAGGTAATGACG\n",
77
+ "Testing sequence K: CGCACCGACTCACG\n",
78
+ "Testing sequence L: CGGCCCTTTGCGCG\n",
79
+ "Testing sequence M: CGCCGTTAGTGTCG\n"
80
+ ]
81
+ }
82
+ ],
83
+ "source": [
84
+ "baker_sequences = {\n",
85
+ " \"A\": \"TAGCAGGATGTGT\",\n",
86
+ " \"B\": \"GCAGATCTGCACATC\",\n",
87
+ " \"C\": \"CGACACCTGACGCG\",\n",
88
+ " \"D\": \"CGCTATCCAGAGCG\",\n",
89
+ " \"E\": \"CGCGATGCTTCTCG\",\n",
90
+ " \"F\": \"CGGCTGGATTACCG\",\n",
91
+ " \"G\": \"CGAGAACATAGTCG\",\n",
92
+ " \"H\": \"CGGGGAAACGCCCG\",\n",
93
+ " \"I\": \"CGCCCAAAGCCGCG\",\n",
94
+ " \"J\": \"CGGAGGTAATGACG\",\n",
95
+ " \"K\": \"CGCACCGACTCACG\",\n",
96
+ " \"L\": \"CGGCCCTTTGCGCG\",\n",
97
+ " \"M\": \"CGCCGTTAGTGTCG\"\n",
98
+ "}\n",
99
+ "sorted_chroms = list(chrom_cache.keys())\n",
100
+ "sorted_chroms = sorted(sorted_chroms, key = lambda x: int(x.split(\"chr\")[1]) if x.split(\"chr\")[1] not in [\"M\",\"X\",\"Y\"] else 0)\n",
101
+ "\n",
102
+ "for seq_letter, seq in baker_sequences.items():\n",
103
+ " print(f\"Testing sequence {seq_letter}: {seq}\")\n",
104
+ " for chrom in sorted_chroms:\n",
105
+ " chrom_dna = chrom_cache[chrom][\"dna\"].upper()\n",
106
+ " match_chroms=[]\n",
107
+ " try:\n",
108
+ " print(f\"\\tChrom {chrom} index of sequence {seq_letter} ({seq}): {chrom.index(seq)}\")\n",
109
+ " match_chroms+=[chrom]\n",
110
+ " except:\n",
111
+ " match_chroms = match_chroms\n",
112
+ " print(f\"\\tChrom {chrom} does not have sequence {seq_letter}({seq})\")"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": null,
118
+ "id": "99078ac6",
119
+ "metadata": {},
120
+ "outputs": [],
121
+ "source": [
122
+ "#[m.start() for m in re.finditer(chr1_dna, 'GCAGATCTGCACATC')]"
123
+ ]
124
+ }
125
+ ],
126
+ "metadata": {
127
+ "kernelspec": {
128
+ "display_name": "dnabind2",
129
+ "language": "python",
130
+ "name": "python3"
131
+ },
132
+ "language_info": {
133
+ "codemirror_mode": {
134
+ "name": "ipython",
135
+ "version": 3
136
+ },
137
+ "file_extension": ".py",
138
+ "mimetype": "text/x-python",
139
+ "name": "python",
140
+ "nbconvert_exporter": "python",
141
+ "pygments_lexer": "ipython3",
142
+ "version": "3.10.14"
143
+ }
144
+ },
145
+ "nbformat": 4,
146
+ "nbformat_minor": 5
147
+ }
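A minimal sketch of the finditer idea the last cell leaves commented out, on a toy chromosome string (re.finditer takes the pattern first, then the string to search; the sequence below is made up, not hg38 data):

    import re
    chrom_dna = "AAGCAGATCTGCACATCTTGCAGATCTGCACATC"
    positions = [m.start() for m in re.finditer("GCAGATCTGCACATC", chrom_dna)]
    print(positions)  # -> [2, 19]

Unlike str.index, this reports every occurrence rather than only the first.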
dpacman/scripts/delay_run.sh CHANGED
@@ -1,10 +1,10 @@
1
  #!/usr/bin/env bash
2
  set -euo pipefail
3
 
4
- # Usage: ./stagger.sh <first_script.sh> <second_script.sh>
5
  # Optional: override waits via env vars WAIT1 / WAIT2 (seconds). Defaults: 3 hours each.
6
 
7
- WAIT1=${WAIT1:-3600} # 3 hours in seconds
8
  WAIT2=${WAIT2:-10800}
9
 
10
  SCRIPT1="${1:?usage: $0 <first_script.sh> <second_script.sh>}"
 
1
  #!/usr/bin/env bash
2
  set -euo pipefail
3
 
4
+ # Usage: nohup bash scripts/delay_run.sh scripts/run_train.sh scripts/run_train_2.sh > delay.log 2>&1 &
5
  # Optional: override waits via env vars WAIT1 / WAIT2 (seconds). Defaults: 3 hours each.
6
 
7
+ WAIT1=${WAIT1:-10800} # 3 hours in seconds
8
  WAIT2=${WAIT2:-10800}
9
 
10
  SCRIPT1="${1:?usage: $0 <first_script.sh> <second_script.sh>}"
dpacman/scripts/run_eval.sh CHANGED
@@ -14,7 +14,7 @@ if [ -z "$WANDB_API_KEY" ]; then
14
  export WANDB_API_KEY="$wandb_key"
15
  fi
16
 
17
- CUDA_VISIBLE_DEVICES=3 nohup python -u -m scripts.eval \
18
  hydra.run.dir="${run_dir}" \
19
  data_module.test_file="data_files/processed/splits/by_dna/test.csv" \
20
  data_module.tr_shelf_path="data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf" \
@@ -23,8 +23,10 @@ CUDA_VISIBLE_DEVICES=3 nohup python -u -m scripts.eval \
23
  model.glm_input_dim=256 \
24
  model.compressed_dim=256 \
25
  model.hidden_dim=256 \
26
- ckpt_path="/home/a03-svincoff/DPACMAN/logs/train/classifier/runs/2025-08-27_18-52-25/checkpoints/epoch_009.ckpt" \
27
- model.lr=1e-5 \
 
 
28
  > "${run_dir}/run.log" 2>&1 &
29
 
30
  echo $! > "${run_dir}/pid.txt"
 
14
  export WANDB_API_KEY="$wandb_key"
15
  fi
16
 
17
+ CUDA_VISIBLE_DEVICES=2 nohup python -u -m scripts.eval \
18
  hydra.run.dir="${run_dir}" \
19
  data_module.test_file="data_files/processed/splits/by_dna/test.csv" \
20
  data_module.tr_shelf_path="data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf" \
 
23
  model.glm_input_dim=256 \
24
  model.compressed_dim=256 \
25
  model.hidden_dim=256 \
26
+ data_module.score_col="binary_scores" \
27
+ data_module.norm_value=1 \
28
+ model.loss_type="binary" \
29
+ ckpt_path="/home/a03-svincoff/DPACMAN/logs/train/classifier/runs/2025-08-28_04-37-58-stoic-snowball-99/checkpoints/epoch_009.ckpt" \
30
  > "${run_dir}/run.log" 2>&1 &
31
 
32
  echo $! > "${run_dir}/pid.txt"
dpacman/scripts/run_split.sh CHANGED
@@ -10,13 +10,14 @@ mkdir -p "$run_dir"
10
 
11
  nohup python -u -m scripts.preprocess \
12
  hydra.run.dir="${run_dir}" \
13
- +data_task.p_exclude="true" \
14
  data_task="${data_task_type}/remap" \
 
 
15
  data_task.split_by=dna \
16
  data_task.train_ratio=0.8 \
17
  data_task.val_ratio=0.1 \
18
  data_task.test_ratio=0.1 \
19
- data_task.split_out_dir=dpacman/data_files/processed/splits/by_both \
20
  > "${run_dir}/run.log" 2>&1 &
21
 
22
  echo $! > "${run_dir}/pid.txt"
 
10
 
11
  nohup python -u -m scripts.preprocess \
12
  hydra.run.dir="${run_dir}" \
 
13
  data_task="${data_task_type}/remap" \
14
+ data_task.test_trs=["trseq23","trseq26","trseq17"] \
15
+ data_task.test_dnas=null \
16
  data_task.split_by=dna \
17
  data_task.train_ratio=0.8 \
18
  data_task.val_ratio=0.1 \
19
  data_task.test_ratio=0.1 \
20
+ data_task.split_out_dir=dpacman/data_files/processed/splits/handpicked_test \
21
  > "${run_dir}/run.log" 2>&1 &
22
 
23
  echo $! > "${run_dir}/pid.txt"
dpacman/scripts/run_train.sh CHANGED
@@ -23,17 +23,19 @@ CUDA_VISIBLE_DEVICES=0,1 nohup python -u -m scripts.train \
23
  hydra.run.dir="${run_dir}" \
24
  trainer.devices=2 \
25
  trainer.max_epochs=10 \
26
- data_module.train_file="data_files/processed/splits/by_dna/train.csv" \
27
- data_module.val_file="data_files/processed/splits/by_dna/val.csv" \
28
- data_module.test_file="data_files/processed/splits/by_dna/test.csv" \
29
  data_module.tr_shelf_path="data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf" \
30
  data_module.dna_shelf_path="data_files/processed/embeddings/fimo_hits_only/peaks_caduceus.shelf" \
31
  data_module.batch_size=16 \
32
  data_module.score_col="binary_scores" \
 
33
  model.loss_type="binary" \
34
  model.glm_input_dim=256 \
35
  model.compressed_dim=256 \
36
  model.hidden_dim=256 \
 
37
  model.lr=1e-5 \
38
  > "${run_dir}/run.log" 2>&1 &
39
 
 
23
  hydra.run.dir="${run_dir}" \
24
  trainer.devices=2 \
25
  trainer.max_epochs=10 \
26
+ data_module.train_file="/home/a03-svincoff/DPACMAN/dpacman/data_files/processed/splits/handpicked_val_test_cropTR4/train.csv" \
27
+ data_module.val_file="/home/a03-svincoff/DPACMAN/dpacman/data_files/processed/splits/handpicked_val_test_cropTR4/val.csv" \
28
+ data_module.test_file="/home/a03-svincoff/DPACMAN/dpacman/data_files/processed/splits/handpicked_val_test_cropTR4/test.csv" \
29
  data_module.tr_shelf_path="data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf" \
30
  data_module.dna_shelf_path="data_files/processed/embeddings/fimo_hits_only/peaks_caduceus.shelf" \
31
  data_module.batch_size=16 \
32
  data_module.score_col="binary_scores" \
33
+ data_module.norm_value=1 \
34
  model.loss_type="binary" \
35
  model.glm_input_dim=256 \
36
  model.compressed_dim=256 \
37
  model.hidden_dim=256 \
38
+ model.dropout=0.2 \
39
  model.lr=1e-5 \
40
  > "${run_dir}/run.log" 2>&1 &
41
 
dpacman/scripts/run_train_baseline.sh CHANGED
@@ -31,6 +31,7 @@ CUDA_VISIBLE_DEVICES=2,3 nohup python -u -m scripts.train \
31
  data_module.batch_size=16 \
32
  data_module.score_col="binary_scores" \
33
  model.loss_type="binary" \
 
34
  model=baseline \
35
  model.glm_input_dim=256 \
36
  model.compressed_dim=256 \
 
31
  data_module.batch_size=16 \
32
  data_module.score_col="binary_scores" \
33
  model.loss_type="binary" \
34
+ data_module.norm_value=1 \
35
  model=baseline \
36
  model.glm_input_dim=256 \
37
  model.compressed_dim=256 \