added dropout and overfit prevention
- .gitignore +3 -1
- configs/data_module/pair.yaml +1 -0
- configs/data_task/split/remap.yaml +3 -1
- configs/model/classifier.yaml +1 -0
- dpacman/benchmark/README.md +1 -0
- dpacman/benchmark/__init__.py +0 -0
- dpacman/classifier/model.py +43 -24
- dpacman/classifier/old_model.py +374 -0
- dpacman/data_modules/pair.py +12 -8
- dpacman/data_tasks/split/complex_remap.py +736 -0
- dpacman/data_tasks/split/remap.py +169 -429
- dpacman/data_tasks/split/remap_handpick.py +266 -0
- dpacman/find_wandb_run_name.py +67 -0
- dpacman/make_splits.ipynb +0 -0
- dpacman/manual_scan_chroms.ipynb +147 -0
- dpacman/scripts/delay_run.sh +2 -2
- dpacman/scripts/run_eval.sh +5 -3
- dpacman/scripts/run_split.sh +3 -2
- dpacman/scripts/run_train.sh +5 -3
- dpacman/scripts/run_train_baseline.sh +1 -0
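The model changes below all apply the same regularization pattern: each sublayer's output passes through nn.Dropout before the residual add and layer norm, and attention-weight dropout is passed into nn.MultiheadAttention. A minimal standalone sketch of that pattern (names and sizes are illustrative, not taken from the repo):

import torch
from torch import nn

class ResidualSublayer(nn.Module):
    # Post-norm residual wrapper: y = LayerNorm(x + Dropout(sublayer(x))).
    # This is the shape of every ln_*/do_* pair added to CrossModalBlock below.
    def __init__(self, dim: int = 256, dropout: float = 0.1):
        super().__init__()
        self.ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(),
                                 nn.Dropout(dropout), nn.Linear(dim * 4, dim))
        self.do = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.ln(x + self.do(self.ffn(x)))

x = torch.randn(2, 10, 256)
print(ResidualSublayer()(x).shape)  # torch.Size([2, 10, 256])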
.gitignore
CHANGED

@@ -40,4 +40,6 @@ log.log
 log2.log
 dpacman/delay.log
 dpacman/view_profiles.ipynb
-dpacman/find_wandb_run_dirs.py
+dpacman/find_wandb_run_dirs.py
+dpacman/delay_binary.log
+dpacman/delay_mix.log
configs/data_module/pair.yaml
CHANGED

@@ -6,6 +6,7 @@ test_file: data_files/processed/splits/by_dna/babytest.csv
 
 target_col: dna_sequence
 score_col: scores
+norm_value: 1333
 
 tr_shelf_path: data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf
 dna_shelf_path: data_files/processed/embeddings/fimo_hits_only/baby_peaks_segmentnt_pernuc_with_onehot.shelf
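norm_value is the divisor PairDataset uses to scale raw integer per-position scores into [0, 1] (see the dpacman/data_modules/pair.py diff below). A standalone sketch of the transform; round_to=3 is an illustrative value, PairDataset reads its own setting:

norm_value = 1333  # from this config

def normalize_scores(raw: str, round_to: int = 3) -> list:
    # "1333,600,0" -> [1.0, 0.45, 0.0]
    return [round(int(y) / norm_value, round_to) for y in raw.split(",")]

print(normalize_scores("1333,600,0"))  # [1.0, 0.45, 0.0]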
configs/data_task/split/remap.yaml
CHANGED

@@ -12,7 +12,9 @@ split_out_dir: dpacman/data_files/processed/splits
 
 dna_map_path: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/maps/dna_seqid_to_dna_sequence.json
 
-split_by:
+split_by: dna # protein, dna, or both
+test_trs: ["trseq23","trseq26","trseq17"]
+test_dnas: null
 augment_rc: true
 
 test_ratio: 0.10
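split_by now selects which axis is held out, and test_trs pins specific TR sequence IDs into the test set (test_dnas is the DNA-side equivalent, unused here). The filtering step mirrors split_with_predefined_test in the complex_remap.py listing further down; a toy sketch with an illustrative dataframe:

import pandas as pd

test_trs = ["trseq23", "trseq26", "trseq17"]  # from this config
df = pd.DataFrame({"tr_seqid": ["trseq23", "trseq01"], "dna_seqid": ["dna1", "dna2"]})
test = df.loc[df["tr_seqid"].isin(test_trs)]  # rows pinned to the test split
print(test["tr_seqid"].tolist())  # ['trseq23']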
configs/model/classifier.yaml
CHANGED

@@ -4,6 +4,7 @@ lr: 1e-4
 alpha: 20
 gamma: 20
 weight_decay: 0.01
+dropout: 0.1
 
 glm_input_dim: 1029
 compressed_dim: 1029
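The new dropout key matches the dropout argument added to BindPredictor in the model.py diff below. A sketch of the wiring, assuming the YAML is loaded with OmegaConf and its keys map one-to-one onto the constructor (the repo's actual entry point may instantiate differently, e.g. via Hydra):

from omegaconf import OmegaConf
from dpacman.classifier.model import BindPredictor

cfg = OmegaConf.load("configs/model/classifier.yaml")
model = BindPredictor(**OmegaConf.to_container(cfg))
print(model.hparams.dropout)  # 0.1, reaching LocalCNN and every CrossModalBlock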
dpacman/benchmark/README.md
ADDED

@@ -0,0 +1 @@
+This folder is for benchmarking the trained classifier.

dpacman/benchmark/__init__.py
ADDED

File without changes
dpacman/classifier/model.py
CHANGED

@@ -11,82 +11,96 @@ from .loss import calculate_loss, auprc_zeros_vs_ones_from_logits, auroc_zeros_vs_ones_from_logits
 set_seed()
 
 class LocalCNN(nn.Module):
-    def __init__(self, dim: int = 256, kernel_size: int = 3):
+    def __init__(self, dim: int = 256, kernel_size: int = 3, dropout=0.1):
         super().__init__()
         padding = kernel_size // 2
         self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding)
         self.act = nn.GELU()
         self.ln = nn.LayerNorm(dim)
+
+        self.dropout = nn.Dropout(dropout)
 
     def forward(self, x: torch.Tensor):
         # x: (batch, L, dim)
         out = self.conv(x.transpose(1, 2))  # → (batch, dim, L)
         out = self.act(out)
+        out = self.dropout(out)  # dropout before the layer norm
         out = out.transpose(1, 2)  # → (batch, L, dim)
         return self.ln(out + x)  # residual
 
 
 class CrossModalBlock(nn.Module):
-    def __init__(self, dim: int = 256, heads: int = 8, dropout: float = 0.0):
+    def __init__(self, dim: int = 256, heads: int = 8, dropout: float = 0.1):
         super().__init__()
         # self-attention for both sides
-        self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True)
-        self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True)
+        self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
+        self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
+        self.do_sa_b = nn.Dropout(dropout)
+        self.do_sa_g = nn.Dropout(dropout)
         # first layer norms
         self.ln_b1 = nn.LayerNorm(dim)
         self.ln_g1 = nn.LayerNorm(dim)
         # first feed forward networks
         self.ffn_b1 = nn.Sequential(
-            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
+            nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim * 4, dim)
         )
         self.ffn_g1 = nn.Sequential(
-            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
+            nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim * 4, dim)
        )
+        self.do_ffn_b1 = nn.Dropout(dropout)
+        self.do_ffn_g1 = nn.Dropout(dropout)
+
         self.ln_b2 = nn.LayerNorm(dim)
         self.ln_g2 = nn.LayerNorm(dim)
 
         # 2) reciprocal cross-attn: g<-b and b<-g
         # DNA/GLM updated by attending to Binder
         self.cross_g2b_1_RCA = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
+        self.do_rca_g = nn.Dropout(dropout)
         self.ln_g3_RCA = nn.LayerNorm(dim)
-        self.ffn_g2_RCA = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
+        self.ffn_g2_RCA = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim*4, dim))
+        self.do_ffn_g2 = nn.Dropout(dropout)
         self.ln_g4_RCA = nn.LayerNorm(dim)
 
         # Binder updated by attending to DNA/GLM
         self.cross_b2g_1_RCA = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
+        self.do_rca_b = nn.Dropout(dropout)
         self.ln_b3_RCA = nn.LayerNorm(dim)
-        self.ffn_b2_RCA = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
+        self.ffn_b2_RCA = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim*4, dim))
+        self.do_ffn_b2 = nn.Dropout(dropout)
         self.ln_b4_RCA = nn.LayerNorm(dim)
 
         # cross attention (binder queries, glm keys/values)
         # so the DNA path is updated by the transcription factors
         self.cross_g2b_2 = nn.MultiheadAttention(dim, heads, batch_first=True)
+        self.do_g2b2 = nn.Dropout(dropout)
         self.ln_g5 = nn.LayerNorm(dim)
         self.ffn_g3 = nn.Sequential(
-            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
+            nn.Linear(dim, dim * 4), nn.GELU(), nn.Dropout(dropout), nn.Linear(dim * 4, dim)
         )
+        self.do_ffn_g3 = nn.Dropout(dropout)
         self.ln_g6 = nn.LayerNorm(dim)
 
     def forward(self, binder: torch.Tensor, glm: torch.Tensor, binder_kpm_mask=None, glm_kpm_mask=None):
         """
         binder: (batch, Lb, dim)
         glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
-        returns: updated binder representation (batch, Lb, dim)
+        returns: updated binder representation (batch, Lb, dim) and gLM representation
         """
         # 1) Self-attention and feed-forward networks for binder and DNA
         # binder: self-attn + ffn
         b = binder
         b_sa, _ = self.sa_binder(b, b, b, key_padding_mask=binder_kpm_mask)
-        b = self.ln_b1(b + b_sa)
+        b = self.ln_b1(b + self.do_sa_b(b_sa))
         b_ff = self.ffn_b1(b)
-        b = self.ln_b2(b + b_ff)
+        b = self.ln_b2(b + self.do_ffn_b1(b_ff))
 
         # glm: self-attn + ffn
         g = glm
         g_sa, _ = self.sa_glm(g, g, g, key_padding_mask=glm_kpm_mask)
-        g = self.ln_g1(g + g_sa)
+        g = self.ln_g1(g + self.do_sa_g(g_sa))
         g_ff = self.ffn_g1(g)
-        g = self.ln_g2(g + g_ff)
+        g = self.ln_g2(g + self.do_ffn_g1(g_ff))
 
         # 2a) Reciprocal Cross-Attention:
         # DNA updated by attending to Binder (Q=g, K=b, V=b)

@@ -97,22 +111,22 @@ class CrossModalBlock(nn.Module):
             # invert if your mask is True=keep:
             # key_padding_mask=(~binder_mask.bool()) if binder_mask is not None else None
         )
-        g = self.ln_g3_RCA(g + g_ca)
-        g = self.ln_g4_RCA(g + self.ffn_g2_RCA(g))
+        g = self.ln_g3_RCA(g + self.do_rca_g(g_ca))
+        g = self.ln_g4_RCA(g + self.do_ffn_g2(self.ffn_g2_RCA(g)))
 
         # 2b) Binder updated by attending to DNA/GLM (Q=b, K=g, V=g)
         b_ca, _ = self.cross_b2g_1_RCA(
             b, g, g, key_padding_mask=glm_kpm_mask
             # key_padding_mask=(~glm_mask.bool()) if glm_mask is not None else None
         )
-        b = self.ln_b3_RCA(b + b_ca)
-        b = self.ln_b4_RCA(b + self.ffn_b2_RCA(b))
+        b = self.ln_b3_RCA(b + self.do_rca_b(b_ca))
+        b = self.ln_b4_RCA(b + self.do_ffn_b2(self.ffn_b2_RCA(b)))
 
         # cross-attention: glm queries binder and glm embeddings are updated
         g_to_b_ca, _ = self.cross_g2b_2(g, b, b, key_padding_mask=binder_kpm_mask)
-        g = self.ln_g5(g + g_to_b_ca)
+        g = self.ln_g5(g + self.do_g2b2(g_to_b_ca))
         g_ff = self.ffn_g3(g)
-        g = self.ln_g6(g + g_ff)
+        g = self.ln_g6(g + self.do_ffn_g3(g_ff))
         return b, g  # (batch, Lb, dim)
 
 class DimCompressor(nn.Module):

@@ -164,12 +178,15 @@ class BindPredictor(LightningModule):
         # Learnable compressor for binder -> 256, then project to hidden
         self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim)
         self.proj_binder = nn.Linear(compressed_dim, hidden_dim)
+        self.dropout_b1 = nn.Dropout(dropout)
+        self.act = nn.GELU()
 
         # GLM side stays 256 -> hidden
         self.proj_glm = nn.Linear(glm_input_dim, hidden_dim)
-
+        self.dropout_g1 = nn.Dropout(dropout)
+
         self.use_local_cnn = use_local_cnn_on_glm
-        self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity()
+        self.local_cnn = LocalCNN(hidden_dim, dropout=self.hparams.dropout) if use_local_cnn_on_glm else nn.Identity()
 
         self.layers = nn.ModuleList(
             [CrossModalBlock(hidden_dim, heads, self.hparams.dropout) for _ in range(num_layers)]

@@ -188,9 +205,11 @@ class BindPredictor(LightningModule):
         # Binder: learnable compression → 256 → hidden
         b = self.binder_compress(binder_emb)  # (B, Lb, 256)
         b = self.proj_binder(b)  # (B, Lb, hidden_dim)
+        b = self.dropout_b1(self.act(b))
 
         # GLM: project → hidden, add local CNN context
         g = self.proj_glm(glm_emb)  # (B, Lg, hidden_dim)
+        g = self.dropout_g1(self.act(g))
         if self.use_local_cnn:
             g = self.local_cnn(g)
 
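All of the new regularization is ordinary nn.Dropout, so it is active only in training mode. A quick sketch that checks this on a forward pass; batch and length sizes are illustrative, and the input dims follow configs/model/classifier.yaml above:

import torch
from dpacman.classifier.model import BindPredictor

model = BindPredictor(binder_input_dim=1280, glm_input_dim=1029,
                      compressed_dim=1029, dropout=0.1)
binder = torch.randn(2, 50, 1280)
glm = torch.randn(2, 300, 1029)
b_kpm = torch.zeros(2, 50, dtype=torch.bool)   # False = keep (not padding)
g_kpm = torch.zeros(2, 300, dtype=torch.bool)

model.train()
print(torch.allclose(model(binder, glm, b_kpm, g_kpm),
                     model(binder, glm, b_kpm, g_kpm)))  # False: dropout is stochastic

model.eval()
with torch.no_grad():
    print(torch.allclose(model(binder, glm, b_kpm, g_kpm),
                         model(binder, glm, b_kpm, g_kpm)))  # True: dropout disabled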
dpacman/classifier/old_model.py
ADDED

@@ -0,0 +1,374 @@
"""
Lightning Module for the binding model.
"""

import torch
from torch import nn
from lightning import LightningModule
from dpacman.utils.models import set_seed
from .loss import calculate_loss, auprc_zeros_vs_ones_from_logits, auroc_zeros_vs_ones_from_logits

set_seed()

class LocalCNN(nn.Module):
    def __init__(self, dim: int = 256, kernel_size: int = 3):
        super().__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding)
        self.act = nn.GELU()
        self.ln = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor):
        # x: (batch, L, dim)
        out = self.conv(x.transpose(1, 2))  # → (batch, dim, L)
        out = self.act(out)
        out = out.transpose(1, 2)  # → (batch, L, dim)
        return self.ln(out + x)  # residual


class CrossModalBlock(nn.Module):
    def __init__(self, dim: int = 256, heads: int = 8, dropout: float = 0.0):
        super().__init__()
        # self-attention for both sides
        self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True)
        # first layer norms
        self.ln_b1 = nn.LayerNorm(dim)
        self.ln_g1 = nn.LayerNorm(dim)
        # first feed forward networks
        self.ffn_b1 = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
        )
        self.ffn_g1 = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
        )
        self.ln_b2 = nn.LayerNorm(dim)
        self.ln_g2 = nn.LayerNorm(dim)

        # 2) reciprocal cross-attn: g<-b and b<-g
        # DNA/GLM updated by attending to Binder
        self.cross_g2b_1_RCA = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
        self.ln_g3_RCA = nn.LayerNorm(dim)
        self.ffn_g2_RCA = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
        self.ln_g4_RCA = nn.LayerNorm(dim)

        # Binder updated by attending to DNA/GLM
        self.cross_b2g_1_RCA = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
        self.ln_b3_RCA = nn.LayerNorm(dim)
        self.ffn_b2_RCA = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
        self.ln_b4_RCA = nn.LayerNorm(dim)

        # cross attention (binder queries, glm keys/values)
        # so the DNA path is updated by the transcription factors
        self.cross_g2b_2 = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ln_g5 = nn.LayerNorm(dim)
        self.ffn_g3 = nn.Sequential(
            nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
        )
        self.ln_g6 = nn.LayerNorm(dim)

    def forward(self, binder: torch.Tensor, glm: torch.Tensor, binder_kpm_mask=None, glm_kpm_mask=None):
        """
        binder: (batch, Lb, dim)
        glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
        returns: updated binder representation (batch, Lb, dim)
        """
        # 1) Self-attention and feed-forward networks for binder and DNA
        # binder: self-attn + ffn
        b = binder
        b_sa, _ = self.sa_binder(b, b, b, key_padding_mask=binder_kpm_mask)
        b = self.ln_b1(b + b_sa)
        b_ff = self.ffn_b1(b)
        b = self.ln_b2(b + b_ff)

        # glm: self-attn + ffn
        g = glm
        g_sa, _ = self.sa_glm(g, g, g, key_padding_mask=glm_kpm_mask)
        g = self.ln_g1(g + g_sa)
        g_ff = self.ffn_g1(g)
        g = self.ln_g2(g + g_ff)

        # 2a) Reciprocal Cross-Attention:
        # DNA updated by attending to Binder (Q=g, K=b, V=b)
        # Binder updated by attending to DNA (Q=b, K=g, V=g)
        g_ca, _ = self.cross_g2b_1_RCA(
            g, b, b, key_padding_mask=binder_kpm_mask
            # torch MultiheadAttention expects key_padding_mask=True for PADs;
            # invert if your mask is True=keep:
            # key_padding_mask=(~binder_mask.bool()) if binder_mask is not None else None
        )
        g = self.ln_g3_RCA(g + g_ca)
        g = self.ln_g4_RCA(g + self.ffn_g2_RCA(g))

        # 2b) Binder updated by attending to DNA/GLM (Q=b, K=g, V=g)
        b_ca, _ = self.cross_b2g_1_RCA(
            b, g, g, key_padding_mask=glm_kpm_mask
            # key_padding_mask=(~glm_mask.bool()) if glm_mask is not None else None
        )
        b = self.ln_b3_RCA(b + b_ca)
        b = self.ln_b4_RCA(b + self.ffn_b2_RCA(b))

        # cross-attention: glm queries binder and glm embeddings are updated
        g_to_b_ca, _ = self.cross_g2b_2(g, b, b, key_padding_mask=binder_kpm_mask)
        g = self.ln_g5(g + g_to_b_ca)
        g_ff = self.ffn_g3(g)
        g = self.ln_g6(g + g_ff)
        return b, g  # (batch, Lb, dim)

class DimCompressor(nn.Module):
    """
    Learnable per-token compressor: maps any in_dim >= out_dim to out_dim (default 256).
    If in_dim == out_dim, behaves as identity.
    """

    def __init__(self, in_dim: int, out_dim: int = 256):
        super().__init__()
        if in_dim == out_dim:
            self.net = nn.Identity()
        else:
            hidden = max(out_dim * 2, (in_dim + out_dim) // 2)
            self.net = nn.Sequential(
                nn.LayerNorm(in_dim),
                nn.Linear(in_dim, hidden),
                nn.GELU(),
                nn.Linear(hidden, out_dim),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, L, in_dim)
        return self.net(x)


class BindPredictor(LightningModule):
    def __init__(
        self,
        # input_dim: int = 256,  # OLD: single input dim
        binder_input_dim: int = 1280,  # NEW: TF (binder) original dim (e.g., 1280)
        glm_input_dim: int = 256,  # NEW: DNA/GLM original dim (e.g., 256)
        compressed_dim: int = 256,  # NEW: learnable compressed dim
        hidden_dim: int = 256,
        heads: int = 8,
        num_layers: int = 4,
        lr: float = 1e-4,
        alpha: float = 20,
        gamma: float = 20,
        dropout: float = 0,
        use_local_cnn_on_glm: bool = True,
        weight_decay: float = 0.01,
        loss_type = "mixed"
    ):
        # Init
        super(BindPredictor, self).__init__()
        self.save_hyperparameters()

        # Learnable compressor for binder -> 256, then project to hidden
        self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim)
        self.proj_binder = nn.Linear(compressed_dim, hidden_dim)

        # GLM side stays 256 -> hidden
        self.proj_glm = nn.Linear(glm_input_dim, hidden_dim)

        self.use_local_cnn = use_local_cnn_on_glm
        self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity()

        self.layers = nn.ModuleList(
            [CrossModalBlock(hidden_dim, heads, self.hparams.dropout) for _ in range(num_layers)]
        )

        #self.ln_out = nn.LayerNorm(hidden_dim)
        # self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())  # OLD: returned probabilities
        self.head = nn.Linear(hidden_dim, 1)  # NEW: return logits (safe for AMP)

    def forward(self, binder_emb, glm_emb, binder_mask, glm_mask):
        """
        binder_emb: (B, Lb, binder_input_dim)
        glm_emb: (B, Lg, glm_input_dim)
        Returns per-nucleotide logits for the GLM sequence: (B, Lg)
        """
        # Binder: learnable compression → 256 → hidden
        b = self.binder_compress(binder_emb)  # (B, Lb, 256)
        b = self.proj_binder(b)  # (B, Lb, hidden_dim)

        # GLM: project → hidden, add local CNN context
        g = self.proj_glm(glm_emb)  # (B, Lg, hidden_dim)
        if self.use_local_cnn:
            g = self.local_cnn(g)

        # Cross-modal blocks: update binder states using GLM
        for layer in self.layers:
            b, g = layer(b, g, binder_mask, glm_mask)  # (B, Lb, hidden_dim)

        # Predict per-nucleotide logits on the GLM tokens:
        # return self.head(g).squeeze(-1)  # OLD: probabilities (with Sigmoid in head)
        logits = self.head(g).squeeze(-1)
        return logits

    # ----- Lightning hooks -----
    def training_step(self, batch, batch_idx):
        """
        Training step taken by the PyTorch Lightning trainer. Uses the batch returned by the data collator.
        The collator returns a dictionary with:
            "binder_emb"   # [B, Lb_max, Db]
            "binder_kpm"   # [B, Lb_max]
            "glm_emb"      # [B, Lg_max, Dg]
            "glm_kpm"      # [B, Lg_max]
            "labels"       # [B, Lg_max]
            "ID"
            "tr_sequence"
            "dna_sequence"
        """
        logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
        loss = calculate_loss(
            logits, batch["labels"], batch["binder_kpm"], batch["glm_kpm"], alpha=self.hparams.alpha, gamma=self.hparams.gamma, loss_type=self.hparams.loss_type
        )
        self.log(
            "train/loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            batch_size=logits.size(0),
        )

        # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
        ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
        auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
        # per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
        self.log("train/auprc_0v1",
                 ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
        self.log("train/auroc_0v1",
                 auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
        # (optional) also log class counts so you can sanity-check balance
        self.log("train/n_pos_0v1", float(n_pos), on_step=False, on_epoch=True, sync_dist=True)
        self.log("train/n_neg_0v1", float(n_neg), on_step=False, on_epoch=True, sync_dist=True)

        return loss

    def validation_step(self, batch, batch_idx):
        logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
        loss = calculate_loss(
            logits, batch["labels"], batch["binder_kpm"], batch["glm_kpm"], alpha=self.hparams.alpha, gamma=self.hparams.gamma, loss_type=self.hparams.loss_type
        )
        self.log(
            "val/loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            batch_size=logits.size(0),
        )

        # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
        ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
        auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
        # per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
        self.log("val/auprc_0v1",
                 ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
        self.log("val/auroc_0v1",
                 auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))

        return loss

    def test_step(self, batch, batch_idx):
        logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
        loss = calculate_loss(
            logits, batch["labels"], batch["binder_kpm"], batch["glm_kpm"], alpha=self.hparams.alpha, gamma=self.hparams.gamma, loss_type=self.hparams.loss_type
        )
        self.log(
            "test/loss", loss, on_step=False, on_epoch=True, batch_size=logits.size(0)
        )

        # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
        ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
        auc, n_pos, n_neg, tpr, fpr, thresholds, tp, fp = auroc_zeros_vs_ones_from_logits(
            logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
        )
        # per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
        self.log("test/auprc_0v1",
                 ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
        self.log("test/auroc_0v1",
                 auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))

        return loss

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        logits = self.forward(batch["binder_emb"], batch["glm_emb"],
                              batch["binder_kpm"], batch["glm_kpm"]).squeeze(-1)  # (B,L)
        valid = ~batch["glm_kpm"]  # (B,L)
        return {
            "ids": batch["ID"],  # list[str]
            "logits": logits.detach().cpu(),  # (B,Lmax) padded
            "valid": valid.detach().cpu(),  # (B,Lmax) booleans
            "labels": batch["labels"].detach().cpu(),  # (B,Lmax) padded
        }

    def on_before_optimizer_step(self, optimizer):
        # Compute global L2 norm of all parameter gradients (ignores None grads)
        grads = []
        for p in self.parameters():
            if p.grad is not None:
                # .detach() avoids autograd tracking; .float() avoids fp16 overflow in norms
                grads.append(p.grad.detach().float().norm(2))
        if grads:
            total_norm = torch.norm(torch.stack(grads), p=2)
            self.log("train/grad_norm", total_norm, on_step=True, prog_bar=False, logger=True)

    def on_after_backward(self):
        grads = [p.grad.detach().float().norm(2)
                 for p in self.parameters() if p.grad is not None]
        if grads:
            total_norm = torch.norm(torch.stack(grads), p=2)
            self.log("train/grad_norm_back", total_norm, on_step=True, prog_bar=False)

    def on_train_epoch_end(self):
        if False:
            if self.train_auc.compute() is not None:
                self.log("train/auroc", self.train_auc.compute(), prog_bar=True)
            self.train_auc.reset()

    def on_validation_epoch_end(self):
        if False:
            if self.val_auc.compute() is not None:
                self.log("val/auroc", self.val_auc.compute(), prog_bar=True)
            self.val_auc.reset()

    def on_test_epoch_end(self):
        if False:
            if self.test_auc.compute() is not None:
                self.log("test/auroc", self.test_auc.compute(), prog_bar=True)
            self.test_auc.reset()

    def configure_optimizers(self):
        # AdamW + cosine as a sensible default
        opt = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )
        # Scheduler optional; comment out if you prefer fixed LR
        sch = torch.optim.lr_scheduler.CosineAnnealingLR(
            opt, T_max=max(self.trainer.max_epochs, 1)
        )
        return {
            "optimizer": opt,
            "lr_scheduler": {"scheduler": sch, "interval": "epoch"},
        }
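old_model.py is kept as a reference for the pre-dropout architecture. One quick way to see the difference is to count nn.Dropout modules in each block (a sketch, assuming both modules import cleanly side by side; nn.MultiheadAttention applies its dropout functionally, so the old block reports zero):

from torch import nn
from dpacman.classifier import model as new_model, old_model

def n_dropouts(m: nn.Module) -> int:
    return sum(1 for sub in m.modules() if isinstance(sub, nn.Dropout))

print(n_dropouts(old_model.CrossModalBlock()))             # 0
print(n_dropouts(new_model.CrossModalBlock(dropout=0.1)))  # 15 in the version above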
dpacman/data_modules/pair.py
CHANGED

@@ -180,16 +180,16 @@ class PairDataset(Dataset):
         """
         if self.score_col not in dataset.columns:
             logger.info(f"Scores not provided. Adding placeholder scores where all positions are considered binding")
-            dataset[self.score_col] = dataset[
+            dataset[self.score_col] = dataset[self.target_col].str.len()
             dataset[self.score_col] = dataset[self.score_col].apply(lambda x: ",".join([str(self.norm_value)]*x))
             self.fake_scores=True
         # split string into list of strings
         dataset[self.score_col] = dataset[self.score_col].apply(lambda x: x.split(","))
+        dataset["copycol"] = dataset[self.score_col]
         # turn list of strings into list of normalized, rounded floats
         dataset[self.score_col] = dataset[self.score_col].apply(
             lambda x: [round(int(y) / self.norm_value, self.round_to) for y in x]
         )
-
         # convert to records for ease of loading
         dataset = dataset.to_dict(orient="records")
         return dataset

@@ -222,11 +222,13 @@ class PairDataModule(LightningDataModule):
         shuffle_train_batch_order: bool = True,
         score_col: str = "scores",
         target_col: str = "dna_sequence",
-        binder_col: str = "tr_sequence"
+        binder_col: str = "tr_sequence",
+        norm_value: int = 1333
     ):
         super().__init__()
         self.save_hyperparameters()
         self.debug_run = debug_run
+        self.norm_value = norm_value
 
         # Initialize the data files
         self.train_data_file = train_file

@@ -267,7 +269,7 @@
             df = pd.read_csv(file_path)
             if lim is not None:
                 df = df[:lim].reset_index(drop=True)
-            return df
+            return df
         except:
             raise Exception(f"{file_path} is not a valid file")

@@ -278,7 +280,7 @@
         if stage in (None, "fit"):
             if not hasattr(self, "train_dataset"):
                 train_df = self.load_file(self.train_data_file, lim=lim)
-                self.train_dataset = PairDataset(train_df, score_col = self.score_col, target_col = self.target_col, binder_col = self.binder_col)
+                self.train_dataset = PairDataset(train_df, norm_value = self.norm_value, score_col = self.score_col, target_col = self.target_col, binder_col = self.binder_col)
                 self.train_batches = make_length_batches(
                     dataset_records=self.train_dataset.dataset,
                     tr_shelf_path=str(self.hparams.tr_shelf_path),

@@ -294,7 +296,7 @@
 
             if not hasattr(self, "val_dataset"):
                 val_df = self.load_file(self.val_data_file, lim=lim)
-                self.val_dataset = PairDataset(val_df, score_col = self.score_col, target_col = self.target_col, binder_col = self.binder_col)
+                self.val_dataset = PairDataset(val_df, norm_value = self.norm_value, score_col = self.score_col, target_col = self.target_col, binder_col = self.binder_col)
                 self.val_batches = make_length_batches(
                     dataset_records=self.val_dataset.dataset,
                     tr_shelf_path=str(self.hparams.tr_shelf_path),

@@ -309,7 +311,7 @@
         if stage in (None, "validate"):
             if not hasattr(self, "val_dataset"):
                 val_df = self.load_file(self.val_data_file, lim=lim)
-                self.val_dataset = PairDataset(val_df, score_col = self.score_col, target_col = self.target_col, binder_col = self.binder_col)
+                self.val_dataset = PairDataset(val_df, norm_value = self.norm_value, score_col = self.score_col, target_col = self.target_col, binder_col = self.binder_col)
                 self.val_batches = make_length_batches(
                     dataset_records=self.val_dataset.dataset,
                     tr_shelf_path=str(self.hparams.tr_shelf_path),

@@ -324,7 +326,7 @@
         if stage in (None, "test"):
             if not hasattr(self, "test_dataset"):
                 test_df = self.load_file(self.test_data_file, lim=lim)
-                self.test_dataset = PairDataset(test_df, score_col = self.score_col, target_col = self.target_col, binder_col = self.binder_col)
+                self.test_dataset = PairDataset(test_df, norm_value = self.norm_value, score_col = self.score_col, target_col = self.target_col, binder_col = self.binder_col)
                 self.test_batches = make_length_batches(
                     dataset_records=self.test_dataset.dataset,
                     tr_shelf_path=str(self.hparams.tr_shelf_path),

@@ -623,6 +625,8 @@ def main():
         debug_run=args.debug_run,
         shuffle_train_batch_order=args.shuffle_train_batch_order,
         pin_memory=False,
+        score_col="binary_scores",
+        norm_value=1
     )
 
     # ---- Train ----
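The main() change pairs score_col="binary_scores" with norm_value=1, so 0/1 labels pass through the normalizer unchanged (binary_scores is presumably the column written by convert_scores in complex_remap.py below). The placeholder path builds an all-binding score string that normalizes to 1.0 everywhere; a toy pandas sketch of that path, outside the class:

import pandas as pd

norm_value = 1333
df = pd.DataFrame({"dna_sequence": ["ACGT", "ACG"]})

# Placeholder path: every position gets norm_value, i.e. label 1.0 after scaling
df["scores"] = df["dna_sequence"].str.len()
df["scores"] = df["scores"].apply(lambda n: ",".join([str(norm_value)] * n))
df["scores"] = df["scores"].apply(lambda s: [round(int(y) / norm_value, 3) for y in s.split(",")])
print(df["scores"].tolist())  # [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]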
dpacman/data_tasks/split/complex_remap.py
ADDED

@@ -0,0 +1,736 @@
from collections import Counter, defaultdict
from ortools.linear_solver import pywraplp
import random
from omegaconf import DictConfig
import pandas as pd
from pathlib import Path
import os
import numpy as np
from sklearn.model_selection import train_test_split
from dpacman.data_tasks.fimo.post_fimo import get_reverse_complement
import json
import rootutils
from dpacman.utils import pylogger

root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
logger = pylogger.RankedLogger(__name__, rank_zero_only=True)

def split_with_predefined_test(
    full_df = pd.DataFrame(),
    split_names=("train", "val", "test"),
    test_trs=None,
    test_dnas=None,
    ratios=(0.8, 0.1, 0.1),
):
    """
    Method for splitting into train and val with a predefined test set.
    The proteins in the test set, and the DNA clusters of the DNAs they're associated with, must be excluded from train and val.
    The remaining rows for train and val are split to preserve 80/10/10 as best as possible.
    """
    test = full_df.copy(deep=True)
    if test_trs is not None:
        test = test.loc[test["tr_seqid"].isin(test_trs)].reset_index(drop=True)
    if test_dnas is not None:
        test = test.loc[test["dna_seqid"].isin(test_dnas)].reset_index(drop=True)

    tr_clusters_to_exclude = test["tr_cluster_rep"].unique().tolist()
    dna_clusters_to_exclude = test["dna_cluster_rep"].unique().tolist()

    remaining = full_df.loc[
        (~full_df["tr_cluster_rep"].isin(tr_clusters_to_exclude)) &
        (~full_df["dna_cluster_rep"].isin(dna_clusters_to_exclude))
    ].reset_index(drop=True)

    test_ids = test["ID"].unique().tolist()
    remaining_ids = remaining["ID"].unique().tolist()
    remaining_clusters = remaining["dna_cluster_rep"].unique().tolist()
    lost_rows = full_df.loc[
        (~full_df["ID"].isin(test_ids)) &
        (~full_df["ID"].isin(remaining_ids))
    ]

    logger.info(f"Rows in test: {len(test)}")
    logger.info(f"Rows to be split between train and val: {len(remaining)}")
    total_rows = len(test) + len(remaining)
    logger.info(f"Total rows: {total_rows}. Test percentage: {100*len(test)/total_rows:.2f}%")
    logger.info(f"Lost rows: {len(lost_rows)}")

    train_ratio_from_remaining = round((0.8*total_rows)/len(remaining), 2)
    # use sklearn
    test_size_1 = 1 - train_ratio_from_remaining
    logger.info(
        f"\tPerforming first split: non-test clusters -> train clusters ({round(1-test_size_1,3)}) and val ({test_size_1})"
    )
    X = remaining_clusters
    y = [0] * len(remaining_clusters)
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=test_size_1, random_state=0
    )

    train = remaining.loc[remaining["dna_cluster_rep"].isin(X_train)]
    val = remaining.loc[remaining["dna_cluster_rep"].isin(X_val)]
    leaky_test = lost_rows

    splits = {
        "train": train,
        "val": val,
        "test": test,
        "leaky_test": leaky_test
    }
    return splits

def split_bipartite_fast(
    dna_clusters,
    split_names=("train", "val", "test"),
    ratios=(0.8, 0.1, 0.1),
):
    # use sklearn
    test_size_1 = 0.2
    test_size_2 = 0.5
    logger.info(
        f"\tPerforming first split: all clusters -> train clusters ({round(1-test_size_1,3)}) and other ({test_size_1})"
    )
    X = dna_clusters
    y = [0] * len(dna_clusters)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size_1, random_state=0
    )
    logger.info(
        f"\tPerforming second split: other -> val clusters ({round(1-test_size_2,3)}) and test clusters ({test_size_2})"
    )
    X_val, X_test, y_val, y_test = train_test_split(
        X_test, y_test, test_size=test_size_2, random_state=0
    )

    dna_assign = {}
    for x in X_train:
        dna_assign[x] = "train"
    for x in X_val:
        dna_assign[x] = "val"
    for x in X_test:
        dna_assign[x] = "test"

    kept_by_split = {"train": len(X_train), "val": len(X_val), "test": len(X_test)}
    return dna_assign, kept_by_split

def convert_scores(scores):
    svec = [int(x) for x in scores.split(",")]
    max_score = max(svec)
    binary_svec = [0 if x < max_score else 1 for x in svec]
    assert(svec.count(max_score) == binary_svec.count(1))
    binary_svec = ",".join([str(x) for x in binary_svec])
    return binary_svec

def split_bipartite_with_ratios_and_leaky(
    edges,
    split_names=("train", "val", "test"),
    ratios=(0.8, 0.1, 0.1),
    require_nonempty=False,
    ratio_tolerance=None,  # None = soft ratios only; 0.0 = exact band (use with care)
    bigM=None,
    shuffle_within_pair=False,
    seed=0,
    test_edges_must=None,  # NEW: list of (tf,dna) with duplicates OR dict {(tf,dna): count}
):
    """
    edges: list of (tf_cluster_id, dna_cluster_id). Duplicates allowed (-> weights).
    test_edges_must: None, list of pairs, or dict {(tf,dna): required_count}.
      - If a pair appears with required_count > 0, at least that many examples MUST be kept in TEST.
      - This implicitly pins both clusters of that pair to TEST (cluster exclusivity).

    Returns:
      tf_assign: {tf_cluster -> split}
      dna_assign: {dna_cluster -> split}
      kept_by_split: {split -> kept_count} (train/val/test only)
      total_kept: int
      split_to_indices: {split -> [input indices]} including 'leaky_test'
      split_to_edges: {split -> [(tf,dna), ...]} including 'leaky_test'
    """
    # Aggregate counts per pair
    w = Counter(edges)
    tfs = {t for (t, _) in w}
    dnas = {d for (_, d) in w}
    S = list(split_names)
    rs = dict(zip(S, ratios))
    N = sum(w.values())
    if bigM is None:
        bigM = 1000 * max(1, N)

    # Index original edges so we can return a per-example split
    pair_to_indices = defaultdict(list)
    for idx, (c, d) in enumerate(edges):
        pair_to_indices[(c, d)].append(idx)

    if shuffle_within_pair:
        rng = random.Random(seed)
        for key in pair_to_indices:
            rng.shuffle(pair_to_indices[key])

    # Parse required test edges
    req_test = Counter()
    if test_edges_must:
        if isinstance(test_edges_must, dict):
            for k, v in test_edges_must.items():
                if not isinstance(k, tuple) or len(k) != 2:
                    raise ValueError(
                        "test_edges_must dict keys must be (tf_cluster, dna_cluster)"
                    )
                if v < 0:
                    raise ValueError("required_count must be non-negative")
                if v:
                    req_test[k] += int(v)
        else:
            # assume iterable of pairs
            req_test = Counter(test_edges_must)
    # Validate against available counts
    for pair, req in req_test.items():
        if pair not in w:
            raise ValueError(f"Required test pair {pair} not present in edges.")
        if req > w[pair]:
            raise ValueError(
                f"Required count {req} for {pair} exceeds available {w[pair]}."
            )

    # Build solver
    solver = pywraplp.Solver.CreateSolver("CBC")
    if solver is None:
        raise RuntimeError("Could not create CBC solver.")

    # Binary cluster assignments
    x = {(c, s): solver.BoolVar(f"x[{c},{s}]") for c in tfs for s in S}
    y = {(d, s): solver.BoolVar(f"y[{d},{s}]") for d in dnas for s in S}

    # Each cluster in exactly one split
    for c in tfs:
        solver.Add(sum(x[c, s] for s in S) == 1)
    for d in dnas:
        solver.Add(sum(y[d, s] for s in S) == 1)

    # Integer kept counts per pair and split (allow partial within-pair)
    k = {
        ((c, d), s): solver.IntVar(0, w[(c, d)], f"k[{c},{d},{s}]")
        for (c, d) in w
        for s in S
    }

    # Only keep in split s if both endpoint clusters are assigned to s
    for (c, d), wt in w.items():
        for s in S:
            solver.Add(k[((c, d), s)] <= wt * x[c, s])
            solver.Add(k[((c, d), s)] <= wt * y[d, s])

    # Enforce minimum kept counts in TEST for required pairs
    for (c, d), req in req_test.items():
        solver.Add(k[((c, d), "test")] >= req)

    # Optional: ensure each split has at least one cluster (feasibility depends on counts)
    if require_nonempty:
        for s in S:
            solver.Add(sum(x[c, s] for c in tfs) + sum(y[d, s] for d in dnas) >= 1)

    # Kept counts per split and total
    K = {s: solver.IntVar(0, N, f"K[{s}]") for s in S}
    for s in S:
        solver.Add(K[s] == sum(k[((c, d), s)] for (c, d) in w))
    T = solver.IntVar(0, N, "T")
    solver.Add(T == sum(K[s] for s in S))

    # Ratio deviation: K_s - r_s * T = d+ - d-
    dpos = {s: solver.NumVar(0, solver.infinity(), f"dpos[{s}]") for s in S}
    dneg = {s: solver.NumVar(0, solver.infinity(), f"dneg[{s}]") for s in S}
    for s in S:
        solver.Add(K[s] - rs[s] * T == dpos[s] - dneg[s])

    # Optional hard band around target ratios
    if ratio_tolerance is not None:
        eps = float(ratio_tolerance)
        for s in S:
            solver.Add(K[s] >= (rs[s] - eps) * T)
            solver.Add(K[s] <= (rs[s] + eps) * T)

    # Objective: maximize T then minimize total deviation
    obj = solver.Objective()
    obj.SetMaximization()
    obj.SetCoefficient(T, float(bigM))
    for s in S:
        obj.SetCoefficient(dpos[s], -1.0)
        obj.SetCoefficient(dneg[s], -1.0)

    status = solver.Solve()
    if status not in (pywraplp.Solver.OPTIMAL, pywraplp.Solver.FEASIBLE):
        raise RuntimeError(
            "No feasible solution (check ratio_tolerance vs. required test edges)."
        )

    # Read cluster assignments
    tf_assign = {c: next(s for s in S if x[c, s].solution_value() > 0.5) for c in tfs}
    dna_assign = {d: next(s for s in S if y[d, s].solution_value() > 0.5) for d in dnas}

    # Kept counts per split
    kept_by_split = {s: int(round(K[s].solution_value())) for s in S}
    total_kept = int(round(T.solution_value()))

    # ---- Build per-example split assignment (including 'leaky_test') ----
    split_to_indices = {s: [] for s in S}
    remaining_indices = {pair: list(pair_to_indices[pair]) for pair in pair_to_indices}

    # Allocate the kept examples per split (train/val/test)
    for (c, d), wt in w.items():
        for s in S:
            cnt = int(round(k[((c, d), s)].solution_value()))
            if cnt > 0:
                take = remaining_indices[(c, d)][:cnt]
                split_to_indices[s].extend(take)
                remaining_indices[(c, d)] = remaining_indices[(c, d)][cnt:]

(diff view truncated here; the remainder of complex_remap.py, through line 736, is not shown)
|
| 286 |
+
|
| 287 |
+
# Everything left becomes leaky_test
|
| 288 |
+
leaky_indices = []
|
| 289 |
+
for pair, idxs in remaining_indices.items():
|
| 290 |
+
if idxs:
|
| 291 |
+
leaky_indices.extend(idxs)
|
| 292 |
+
|
| 293 |
+
split_to_indices["leaky_test"] = leaky_indices
|
| 294 |
+
split_to_edges = {
|
| 295 |
+
s: [edges[i] for i in split_to_indices[s]] for s in split_to_indices
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
return (
|
| 299 |
+
tf_assign,
|
| 300 |
+
dna_assign,
|
| 301 |
+
kept_by_split,
|
| 302 |
+
total_kept,
|
| 303 |
+
split_to_indices,
|
| 304 |
+
split_to_edges,
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
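For orientation, a minimal sketch (not part of the commit) of driving the ILP splitter above with toy cluster pairs. The tf/dna names are invented; it assumes ortools is installed and that it runs in this module's namespace, where Counter, defaultdict, random, and pywraplp are already imported.

# Toy edges: duplicates of a pair become its weight in the ILP.
edges = [
    ("tfA", "dna1"), ("tfA", "dna1"),   # weight-2 pair
    ("tfB", "dna2"),
    ("tfC", "dna3"), ("tfC", "dna4"),
]
(tf_assign, dna_assign, kept_by_split, total_kept,
 split_to_indices, split_to_edges) = split_bipartite_with_ratios_and_leaky(
    edges,
    ratios=(0.6, 0.2, 0.2),
    test_edges_must={("tfB", "dna2"): 1},  # force >= 1 example of this pair into test
)
print(kept_by_split)                   # kept counts per split
print(split_to_indices["leaky_test"])  # examples the solver could not keep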
+class DSU:
+    def __init__(self):
+        self.p = {}
+
+    def find(self, x):
+        if x not in self.p:
+            self.p[x] = x
+        while self.p[x] != x:
+            self.p[x] = self.p[self.p[x]]
+            x = self.p[x]
+        return x
+
+    def union(self, a, b):
+        ra, rb = self.find(a), self.find(b)
+        if ra != rb:
+            self.p[rb] = ra
+
+
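A quick illustration of the union-find behavior (toy labels, not from the commit): unioning a TF node and a DNA node merges their components, which is exactly how the component splitter below groups clusters.

d = DSU()
d.union(("T", "tfA"), ("D", "dna1"))
d.union(("T", "tfB"), ("D", "dna1"))   # tfB shares dna1 with tfA
assert d.find(("T", "tfA")) == d.find(("T", "tfB"))  # same component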
+def split_bipartite_by_components(
+    edges,
+    split_names=("train", "val", "test"),
+    ratios=(0.8, 0.1, 0.1),
+    seed=0,
+    require_nonempty=False,
+    test_edges_must=None,  # None, list[(tf,dna)], or dict{(tf,dna): count}
+):
+    """
+    Guarantees exclusivity: each TF cluster and DNA cluster appears in at most one split.
+    Strategy: find connected components in the TF–DNA bipartite graph and assign components wholesale.
+    """
+    rng = random.Random(seed)
+    w = Counter(edges)  # multiplicities per pair
+    if not w:
+        raise ValueError("No edges.")
+
+    # 1) Build components with Union-Find (prefix to keep TF/DNA namespaces disjoint)
+    dsu = DSU()
+    for tf, dna in w:
+        dsu.union(("T", tf), ("D", dna))
+    comp_pairs = defaultdict(list)
+    comp_weight = defaultdict(int)
+    for (tf, dna), cnt in w.items():
+        root = dsu.find(("T", tf))  # component id = root of TF endpoint
+        comp_pairs[root].append((tf, dna))
+        comp_weight[root] += cnt
+
+    comps = list(comp_pairs.keys())
+    C = len(comps)
+    S = list(split_names)
+    rs = dict(zip(S, ratios))
+    N = sum(comp_weight[c] for c in comps)
+    target = {s: int(round(rs[s] * N)) for s in S}
+
+    # 2) Pin components that contain required TEST pairs
+    pinned = {}  # comp_root -> pinned_split ("test")
+    if test_edges_must:
+        # Counter handles both an iterable of pairs and a {pair: count} dict
+        req = Counter(test_edges_must)
+        # Map each required pair to its component, ensure feasibility
+        for (tf, dna), r in req.items():
+            if (tf, dna) not in w:
+                raise ValueError(f"Required pair {(tf, dna)} not present.")
+            if r > w[(tf, dna)]:
+                raise ValueError(
+                    f"Required count {r} for {(tf, dna)} exceeds available {w[(tf, dna)]}."
+                )
+            comp = dsu.find(("T", tf))
+            if comp in pinned and pinned[comp] != "test":
+                raise ValueError(
+                    f"Component conflict: already pinned to {pinned[comp]}, but {(tf, dna)} demands test."
+                )
+            pinned[comp] = "test"
+            # NOTE: pinning a pair pins the WHOLE component to test (to keep exclusivity).
+            # If you only want some edges kept in test and discard the rest, handle below when materializing.
+
+    # 3) Assign components greedily by deficit
+    kept_by_split = {s: 0 for s in S}
+    comp_assign = {}  # comp_root -> split
+
+    # First assign pinned comps
+    for comp, split in pinned.items():
+        comp_assign[comp] = split
+        kept_by_split[split] += comp_weight[comp]
+
+    # Sort remaining components by descending weight
+    remaining = [c for c in comps if c not in comp_assign]
+    remaining.sort(key=lambda c: comp_weight[c], reverse=True)
+
+    # Ensure nonempty splits if requested (seed with largest remaining comps)
+    if require_nonempty:
+        seeds = remaining[: min(len(S), len(remaining))]
+        for comp, s in zip(seeds, S):
+            comp_assign[comp] = s
+            kept_by_split[s] += comp_weight[comp]
+        remaining = [c for c in remaining if c not in comp_assign]
+
+    for comp in remaining:
+        # choose split with largest deficit (target - current)
+        deficits = {s: target[s] - kept_by_split[s] for s in S}
+        best = max(deficits, key=lambda s: deficits[s])
+        comp_assign[comp] = best
+        kept_by_split[best] += comp_weight[comp]
+
+    total_kept = sum(kept_by_split.values())
+
+    # 4) Materialize per-example indices (and verify exclusivity)
+    pair_to_indices = defaultdict(list)
+    for idx, pair in enumerate(edges):
+        pair_to_indices[pair].append(idx)
+
+    split_to_indices = {s: [] for s in S}
+    for comp, s in comp_assign.items():
+        for pair in comp_pairs[comp]:
+            split_to_indices[s].extend(pair_to_indices[pair])
+
+    # Optional: if you pinned a comp due to a small 'must-test' count but
+    # want to *discard* the rest instead of keeping them in test, uncomment:
+    # for comp, s in comp_assign.items():
+    #     if s == "test" and test_edges_must:
+    #         # Keep only the required counts; dump extras to 'leaky_test'
+    #         ...
+    # (Left out for clarity; default is: keep the whole component in its split.)
+
+    # 5) Build edge lists and simple cluster assignments
+    split_to_edges = {
+        s: [edges[i] for i in split_to_indices[s]] for s in split_to_indices
+    }
+    tf_assign, dna_assign = {}, {}
+    for comp, s in comp_assign.items():
+        for tf, dna in comp_pairs[comp]:
+            tf_assign[tf] = s
+            dna_assign[dna] = s
+
+    # 6) Safety check: no DNA/TF appears in multiple splits
+    tf_in_split = defaultdict(set)
+    dna_in_split = defaultdict(set)
+    for s, elist in split_to_edges.items():
+        for tf, dna in elist:
+            tf_in_split[tf].add(s)
+            dna_in_split[dna].add(s)
+    dup_tf = {tf: ss for tf, ss in tf_in_split.items() if len(ss) > 1}
+    dup_dna = {dn: ss for dn, ss in dna_in_split.items() if len(ss) > 1}
+    assert not dup_tf and not dup_dna, f"Exclusivity violated: {dup_tf} {dup_dna}"
+
+    return (
+        tf_assign,
+        dna_assign,
+        kept_by_split,
+        total_kept,
+        split_to_indices,
+        split_to_edges,
+    )
+
+
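A matching sketch for the component-based splitter, again with made-up names: clusters connected through a shared DNA land in one split wholesale, which is what guarantees exclusivity.

edges = [
    ("tfA", "dna1"), ("tfA", "dna2"), ("tfB", "dna2"),  # one connected component
    ("tfC", "dna3"),                                     # a second component
]
tf_assign, dna_assign, kept, total, idxs, split_edges = split_bipartite_by_components(
    edges, ratios=(0.5, 0.25, 0.25), seed=0
)
# tfA and tfB share dna2, so they are assigned to the same split as a unit.
assert tf_assign["tfA"] == tf_assign["tfB"]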
+def print_split_ratios(kept_by_split):
+    total = sum(kept_by_split.values())
+    train_pcnt = 100 * kept_by_split["train"] / total
+    val_pcnt = 100 * kept_by_split["val"] / total
+    test_pcnt = 100 * kept_by_split["test"] / total
+    logger.info(
+        f"Cluster distribution - Train: {train_pcnt:.2f}%, Val: {val_pcnt:.2f}%, Test: {test_pcnt:.2f}%"
+    )
+
+
+def make_edges(
+    processed_fimo_path: str, protein_cluster_path: str, dna_cluster_path: str
+):
+    """
+    Make edges for input to the splitting algorithm. Each edge is a
+    (tr_cluster_rep, dna_cluster_rep) tuple, where the cluster rep is a sequence ID.
+    """
+    # Read cluster data
+    protein_clusters = pd.read_csv(protein_cluster_path, header=None, sep="\t")
+    protein_clusters.columns = ["tr_cluster_rep", "tr_seqid"]
+
+    dna_clusters = pd.read_csv(dna_cluster_path, header=None, sep="\t")
+    dna_clusters.columns = ["dna_cluster_rep", "dna_seqid"]
+
+    # Read datapoints
+    edges = pd.read_parquet(processed_fimo_path)
+    edges = pd.merge(edges, dna_clusters, on="dna_seqid", how="left")
+    edges = pd.merge(edges, protein_clusters, on="tr_seqid", how="left")
+    edges["edge"] = edges.apply(
+        lambda row: (row["tr_cluster_rep"], row["dna_cluster_rep"]), axis=1
+    )
+
+    logger.info(f"Total unique edges: {len(edges['edge'].unique().tolist())}")
+    dup_edges = edges.loc[edges.duplicated("edge")]["edge"].unique().tolist()
+    logger.info(f"Total edges with >1 datapoint: {len(dup_edges)}")
+    logger.info(
+        f"Total datapoints belonging to a duplicate edge: {len(edges.loc[edges['edge'].isin(dup_edges)])}"
+    )
+    return edges
+
+
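The two cluster files read above appear to be header-less, two-column TSVs in the usual representative/member layout (as produced by, e.g., MMseqs2-style clustering). A hypothetical protein_cluster_path file, where column 1 becomes tr_cluster_rep and column 2 tr_seqid, might look like:

trseq1	trseq1
trseq1	trseq4
trseq2	trseq2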
+def check_validity(train, val, test, split_by="both"):
+    """
+    Rigorous check for no overlap.
+    Columns = ["ID","dna_sequence","tr_sequence","tr_cluster_rep","dna_cluster_rep", "scores","split"]
+    """
+    train_ids = set(train["ID"].unique().tolist())
+    val_ids = set(val["ID"].unique().tolist())
+    test_ids = set(test["ID"].unique().tolist())
+
+    assert len(train_ids.intersection(val_ids)) == 0
+    assert len(train_ids.intersection(test_ids)) == 0
+    assert len(val_ids.intersection(test_ids)) == 0
+    logger.info(f"Pass! No overlap in IDs")
+
+    if split_by != "dna":
+        train_tr_seqs = set(train["tr_sequence"].unique().tolist())
+        val_tr_seqs = set(val["tr_sequence"].unique().tolist())
+        test_tr_seqs = set(test["tr_sequence"].unique().tolist())
+
+        assert len(train_tr_seqs.intersection(val_tr_seqs)) == 0
+        assert len(train_tr_seqs.intersection(test_tr_seqs)) == 0
+        assert len(val_tr_seqs.intersection(test_tr_seqs)) == 0
+        logger.info(f"Pass! No overlap in TR sequences")
+
+        train_tr_reps = set(train["tr_cluster_rep"].unique().tolist())
+        val_tr_reps = set(val["tr_cluster_rep"].unique().tolist())
+        test_tr_reps = set(test["tr_cluster_rep"].unique().tolist())
+
+        assert len(train_tr_reps.intersection(val_tr_reps)) == 0
+        assert len(train_tr_reps.intersection(test_tr_reps)) == 0
+        assert len(val_tr_reps.intersection(test_tr_reps)) == 0
+        logger.info(f"Pass! No overlap in TR cluster reps")
+
+    if split_by != "protein":
+        train_dna_seqs = set(train["dna_sequence"].unique().tolist())
+        val_dna_seqs = set(val["dna_sequence"].unique().tolist())
+        test_dna_seqs = set(test["dna_sequence"].unique().tolist())
+
+        assert len(train_dna_seqs.intersection(val_dna_seqs)) == 0
+        assert len(train_dna_seqs.intersection(test_dna_seqs)) == 0
+        assert len(val_dna_seqs.intersection(test_dna_seqs)) == 0
+        logger.info(f"Pass! No overlap in DNA sequences")
+
+        train_dna_reps = set(train["dna_cluster_rep"].unique().tolist())
+        val_dna_reps = set(val["dna_cluster_rep"].unique().tolist())
+        test_dna_reps = set(test["dna_cluster_rep"].unique().tolist())
+
+        assert len(train_dna_reps.intersection(val_dna_reps)) == 0
+        assert len(train_dna_reps.intersection(test_dna_reps)) == 0
+        assert len(val_dna_reps.intersection(test_dna_reps)) == 0
+        logger.info(f"Pass! No overlap in DNA cluster reps")
+
+
+def augment_rc(df):
+    """
+    Get the reverse complement and add it as a datapoint, effectively doubling the dataset.
+    Also flip the orientation of the scores.
+
+    columns = ["ID","dna_sequence","tr_sequence","tr_cluster_rep","dna_cluster_rep", "scores","split"]
+    """
+    df_rc = df.copy(deep=True)
+
+    df_rc["dna_sequence"] = df_rc["dna_sequence"].apply(
+        lambda x: get_reverse_complement(x)
+    )
+    df_rc["ID"] = df_rc["ID"] + "_rc"
+    df_rc["scores"] = df_rc["scores"].apply(lambda s: ",".join(s.split(",")[::-1]))
+
+    final_df = pd.concat([df, df_rc]).reset_index(drop=True)
+
+    return final_df
+
+
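get_reverse_complement is imported from elsewhere in dpacman and is not part of this diff; a minimal, hypothetical stand-in consistent with how it is used here would be:

def get_reverse_complement(seq: str) -> str:
    # Hypothetical stand-in, not the repo's implementation:
    # complement each base, then reverse the sequence.
    complement = {"A": "T", "T": "A", "C": "G", "G": "C", "N": "N"}
    return "".join(complement[base] for base in reversed(seq.upper()))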
+def main(cfg: DictConfig):
+    """
+    Take a set of DNA clusters + protein clusters, and create the best possible splits into train/val/test.
+    """
+    # construct edges from training data
+    edge_df = make_edges(
+        processed_fimo_path=Path(root) / cfg.data_task.input_data_path,
+        protein_cluster_path=Path(root) / cfg.data_task.cluster_output_paths.protein,
+        dna_cluster_path=Path(root) / cfg.data_task.cluster_output_paths.dna,
+    )
+    edges = edge_df["edge"].unique().tolist()
+
+    # figure out if we actually even have a conflict
+    total_proteins = len(edge_df["tr_seqid"].unique().tolist())
+    total_protein_clusters = len(edge_df["tr_cluster_rep"].unique().tolist())
+
+    no_protein_overlap = total_proteins == total_protein_clusters
+    logger.info(f"All proteins are in their own clusters: {no_protein_overlap}")
+
+    if cfg.data_task.split_by == "dna":
+        if cfg.data_task.p_exclude:
+            return
+        else:
+            logger.info(f"Easy split: all proteins are in their own clusters.")
+            dna_clusters = edge_df["dna_cluster_rep"].unique().tolist()
+            results = split_bipartite_fast(
+                dna_clusters,
+                split_names=("train", "val", "test"),
+                ratios=(
+                    cfg.data_task.train_ratio,
+                    cfg.data_task.val_ratio,
+                    cfg.data_task.test_ratio,
+                ),
+            )
+            dna_assign, kept_by_split = results
+
+            # assign datapoints to cluster by their DNA cluster rep
+            edge_df["split"] = edge_df["dna_cluster_rep"].map(dna_assign)
+    else:
+        results = split_bipartite_by_components(
+            edges,
+            split_names=("train", "val", "test"),
+            ratios=(
+                cfg.data_task.train_ratio,
+                cfg.data_task.val_ratio,
+                cfg.data_task.test_ratio,
+            ),
+            require_nonempty=cfg.data_task.require_nonempty,
+            seed=cfg.data_task.seed,
+            test_edges_must=None,
+        )
+
+        (
+            tf_assign,
+            dna_assign,
+            kept_by_split,
+            total_kept,
+            split_to_indices,
+            split_to_edges,
+        ) = results
+
+        # Map each sample to its split
+        print(tf_assign)
+        print(dna_assign)
+        edge_df["tr_split"] = edge_df["tr_cluster_rep"].map(tf_assign)
+        edge_df["dna_split"] = edge_df["dna_cluster_rep"].map(dna_assign)
+        edge_df["same_split"] = (
+            edge_df["tr_split"] == edge_df["dna_split"]
+        )  # should always be true if easy cluster
+        edge_df["split"] = edge_df["tr_split"]
+        print(edge_df)
+        edge_df["split"] = np.where(
+            edge_df["same_split"],
+            edge_df["split"],  # keep existing split if same_split == True
+            "leak",  # otherwise leak
+        )
+        print(edge_df)
+
+    # Print ratios: hopefully close to desired (e.g. 80/10/10)
+    print_split_ratios(kept_by_split)
+
+    # Make train, val, test sets
+    # make sure no ID is duplicate
+    assert len(edge_df["ID"].unique()) == len(edge_df)
+    split_cols = [
+        "ID",
+        "dna_sequence",
+        "tr_sequence",
+        "tr_cluster_rep",
+        "dna_cluster_rep",
+        "scores",
+        "split",
+    ]
+    train = edge_df.loc[edge_df["split"] == "train"].reset_index(drop=True)[split_cols]
+    val = edge_df.loc[edge_df["split"] == "val"].reset_index(drop=True)[split_cols]
+    test = edge_df.loc[edge_df["split"] == "test"].reset_index(drop=True)[split_cols]
+
+    # ensure there is no overlap
+    check_validity(train, val, test, split_by=cfg.data_task.split_by)
+
+    total = sum([len(train), len(val), len(test)])
+    logger.info(f"Length of train dataset: {len(train)} ({100*len(train)/total:.2f}%)")
+    logger.info(f"Length of val dataset: {len(val)} ({100*len(val)/total:.2f}%)")
+    logger.info(f"Length of test dataset: {len(test)} ({100*len(test)/total:.2f}%)")
+    logger.info(f"Total sequences = {total}. Same as edges size? {total==len(edge_df)}")
+
+    og_unique_dna = pd.concat([train, val, test])
+    og_unique_dna = len(og_unique_dna["dna_sequence"].unique())
+
+    ## Now do RC data augmentation if asked
+    if cfg.data_task.augment_rc:
+        train = augment_rc(train)
+        val = augment_rc(val)
+        test = augment_rc(test)
+
+        logger.info(f"Added reverse complement sequences to train, val, and test.")
+
+        check_validity(train, val, test, split_by=cfg.data_task.split_by)
+
+        total = sum([len(train), len(val), len(test)])
+        logger.info(
+            f"Length of train dataset: {len(train)} ({100*len(train)/total:.2f}%)"
+        )
+        logger.info(f"Length of val dataset: {len(val)} ({100*len(val)/total:.2f}%)")
+        logger.info(f"Length of test dataset: {len(test)} ({100*len(test)/total:.2f}%)")
+        logger.info(
+            f"Total sequences = {total}. Same as edges size? {total==len(edge_df)}"
+        )
+
+        # since we've added all these new DNA sequences, we do need a new mapping of seq id to dna sequence
+        all_data = pd.concat([train, val, test])
+        all_data["dna_seqid"] = all_data["ID"].str.split("_", n=1, expand=True)[1]
+        dna_dict = dict(zip(all_data["dna_seqid"], all_data["dna_sequence"]))
+        assert len(dna_dict) == len(all_data.drop_duplicates(["dna_sequence"]))
+        new_map_path = str(Path(root) / cfg.data_task.dna_map_path).replace(
+            ".json", "_with_rc.json"
+        )
+
+        with open(new_map_path, "w") as f:
+            json.dump(dna_dict, f, indent=2)
+        logger.info(
+            f"Saved DNA maps with reverse complements (len {len(dna_dict)}=2*original map of len {og_unique_dna}=={len(dna_dict)==2*og_unique_dna}) to {new_map_path}"
+        )
+
+    # create the output dir
+    split_out_dir = Path(root) / cfg.data_task.split_out_dir
+    os.makedirs(split_out_dir, exist_ok=True)
+
+    # add binary_scores to allow other training modes
+    train["fimo_binary_sores"] = train["scores"].apply(lambda x: convert_scores(x))
+    val["fimo_binary_sores"] = val["scores"].apply(lambda x: convert_scores(x))
+    test["fimo_binary_sores"] = test["scores"].apply(lambda x: convert_scores(x))
+
+    # select final cols and save
+    split_final_cols = ["ID", "dna_sequence", "tr_sequence", "scores", "fimo_binary_sores", "split"]
+    train[split_final_cols].to_csv(split_out_dir / "train.csv", index=False)
+    val[split_final_cols].to_csv(split_out_dir / "val.csv", index=False)
+    test[split_final_cols].to_csv(split_out_dir / "test.csv", index=False)
+    logger.info(f"Saved all splits to {split_out_dir}")
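main expects a Hydra DictConfig; a hypothetical minimal config covering the cfg.data_task fields it reads (all paths are placeholders, so this is illustrative rather than runnable end to end):

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "data_task": {
        "input_data_path": "path/to/processed_fimo.parquet",       # placeholder
        "cluster_output_paths": {
            "protein": "path/to/protein_clusters.tsv",             # placeholder
            "dna": "path/to/dna_clusters.tsv",                     # placeholder
        },
        "dna_map_path": "path/to/dna_seqid_to_dna_sequence.json",  # placeholder
        "split_out_dir": "path/to/splits",                         # placeholder
        "split_by": "both",
        "train_ratio": 0.8, "val_ratio": 0.1, "test_ratio": 0.1,
        "require_nonempty": True,
        "seed": 0,
        "augment_rc": True,
    }
})
main(cfg)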
dpacman/data_tasks/split/remap.py
CHANGED
@@ -15,6 +15,74 @@ from dpacman.utils import pylogger
 root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 logger = pylogger.RankedLogger(__name__, rank_zero_only=True)
 
+def split_with_predefined_test(
+    full_df=pd.DataFrame(),
+    split_names=("train", "val", "test"),
+    test_trs=None,
+    test_dnas=None,
+    ratios=(0.8, 0.1, 0.1),
+):
+    """
+    Method for splitting into train and val with a predefined test set.
+    The proteins in the test set, and the DNA clusters of the DNAs they're associated with, must be excluded from train and val.
+    The remaining rows for train and val are split to preserve 80/10/10 as best as possible.
+    """
+    test = full_df.copy(deep=True)
+    if test_trs is not None:
+        test = test.loc[test["tr_seqid"].isin(test_trs)].reset_index(drop=True)
+    if test_dnas is not None:
+        test = test.loc[test["dna_seqid"].isin(test_dnas)].reset_index(drop=True)
+
+    tr_clusters_to_exclude = test["tr_cluster_rep"].unique().tolist()
+    dna_clusters_to_exclude = test["dna_cluster_rep"].unique().tolist()
+
+    remaining = full_df.loc[
+        (~full_df["tr_cluster_rep"].isin(tr_clusters_to_exclude)) &
+        (~full_df["dna_cluster_rep"].isin(dna_clusters_to_exclude))
+    ].reset_index(drop=True)
+
+    test_ids = test["ID"].unique().tolist()
+    remaining_ids = remaining["ID"].unique().tolist()
+    remaining_clusters = remaining["dna_cluster_rep"].unique().tolist()
+    lost_rows = full_df.loc[
+        (~full_df["ID"].isin(test_ids)) &
+        (~full_df["ID"].isin(remaining_ids))
+    ]
+
+    logger.info(f"Rows in test: {len(test)}")
+    logger.info(f"Rows to be split between train and val: {len(remaining)}")
+    total_rows = len(test) + len(remaining)
+    logger.info(f"Total rows: {total_rows}. Test percentage: {100*len(test)/total_rows:.2f}%")
+    logger.info(f"Lost rows: {len(lost_rows)}")
+
+    train_ratio_from_remaining = round((0.8 * total_rows) / len(remaining), 2)
+    # use sklearn
+    test_size_1 = 1 - train_ratio_from_remaining
+    logger.info(
+        f"\tPerforming first split: non-test clusters -> train clusters ({round(1-test_size_1,3)}) and val ({test_size_1})"
+    )
+    X = remaining_clusters
+    y = [0] * len(remaining_clusters)
+    X_train, X_val, y_train, y_val = train_test_split(
+        X, y, test_size=test_size_1, random_state=0
+    )
+
+    train = remaining.loc[remaining["dna_cluster_rep"].isin(X_train)]
+    val = remaining.loc[remaining["dna_cluster_rep"].isin(X_val)]
+    leaky_test = lost_rows
+
+    kept_by_split = {
+        "train": len(X_train),
+        "val": len(X_val),
+        "test": len(test["dna_cluster_rep"].unique()),
+    }
+    splits = {
+        "train": train,
+        "val": val,
+        "test": test,
+        "leaky_test": leaky_test,
+    }
+    return splits, kept_by_split
 
 def split_bipartite_fast(
     dna_clusters,
@@ -50,354 +118,22 @@ def split_bipartite_fast(
     kept_by_split = {"train": len(X_train), "val": len(X_val), "test": len(X_test)}
     return dna_assign, kept_by_split
 
-
+# construct new labels
+def convert_scores(scores, mode=1):
+    """
+    Two modes: 1 means FIMO peaks get 1. 0 means FIMO peaks get their max score.
+    """
     svec = [int(x) for x in scores.split(",")]
     max_score = max(svec)
-
-
+    if mode == 1:
+        binary_svec = [0 if x < max_score else 1 for x in svec]
+        assert svec.count(max_score) == binary_svec.count(1)
+    else:
+        binary_svec = [0 if x < max_score else max_score for x in svec]
+        assert svec.count(max_score) == binary_svec.count(max_score)
     binary_svec = ",".join([str(x) for x in binary_svec])
     return binary_svec
-
-
-[deleted here: split_bipartite_with_ratios_and_leaky, the DSU helper class, and
- split_bipartite_by_components; the removed bodies are identical to the code
- added in dpacman/data_tasks/split/complex_remap.py above]
 
 def print_split_ratios(kept_by_split):
     total = sum(kept_by_split.values())
     train_pcnt = 100 * kept_by_split["train"] / total
@@ -452,39 +188,57 @@ def check_validity(train, val, test, split_by="both"):
     assert len(val_ids.intersection(test_ids)) == 0
     logger.info(f"Pass! No overlap in IDs")
 
-    if split_by != "dna":
-        train_tr_seqs = set(train["tr_sequence"].unique().tolist())
-        val_tr_seqs = set(val["tr_sequence"].unique().tolist())
-        test_tr_seqs = set(test["tr_sequence"].unique().tolist())
+    # Investigate TR intersection. No assertions unless we are explicitly splitting on this.
+    train_tr_seqs = set(train["tr_sequence"].unique().tolist())
+    val_tr_seqs = set(val["tr_sequence"].unique().tolist())
+    test_tr_seqs = set(test["tr_sequence"].unique().tolist())
+
+    train_tr_reps = set(train["tr_cluster_rep"].unique().tolist())
+    val_tr_reps = set(val["tr_cluster_rep"].unique().tolist())
+    test_tr_reps = set(test["tr_cluster_rep"].unique().tolist())
+
+    logger.info(f"Train-Val TR intersection: {len(train_tr_seqs.intersection(val_tr_seqs))}")
+    logger.info(f"Train-Test TR intersection: {len(train_tr_seqs.intersection(test_tr_seqs))}")
+    logger.info(f"Val-Test TR intersection: {len(val_tr_seqs.intersection(test_tr_seqs))}")
+
+    logger.info(f"Train-Val TR Cluster Rep intersection: {len(train_tr_reps.intersection(val_tr_reps))}")
+    logger.info(f"Train-Test TR Cluster Rep intersection: {len(train_tr_reps.intersection(test_tr_reps))}")
+    logger.info(f"Val-Test TR Cluster Rep intersection: {len(val_tr_reps.intersection(test_tr_reps))}")
+
+    # Investigate DNA intersection. No assertions unless we are explicitly splitting on this.
+    train_dna_seqs = set(train["dna_sequence"].unique().tolist())
+    val_dna_seqs = set(val["dna_sequence"].unique().tolist())
+    test_dna_seqs = set(test["dna_sequence"].unique().tolist())
+
+    train_dna_reps = set(train["dna_cluster_rep"].unique().tolist())
+    val_dna_reps = set(val["dna_cluster_rep"].unique().tolist())
+    test_dna_reps = set(test["dna_cluster_rep"].unique().tolist())
+
+    logger.info(f"Train-Val DNA intersection: {len(train_dna_seqs.intersection(val_dna_seqs))}")
+    logger.info(f"Train-Test DNA intersection: {len(train_dna_seqs.intersection(test_dna_seqs))}")
+    logger.info(f"Val-Test DNA intersection: {len(val_dna_seqs.intersection(test_dna_seqs))}")
+
+    logger.info(f"Train-Val DNA Cluster Rep intersection: {len(train_dna_reps.intersection(val_dna_reps))}")
+    logger.info(f"Train-Test DNA Cluster Rep intersection: {len(train_dna_reps.intersection(test_dna_reps))}")
+    logger.info(f"Val-Test DNA Cluster Rep intersection: {len(val_dna_reps.intersection(test_dna_reps))}")
 
+    if split_by != "dna":
         assert len(train_tr_seqs.intersection(val_tr_seqs)) == 0
         assert len(train_tr_seqs.intersection(test_tr_seqs)) == 0
         assert len(val_tr_seqs.intersection(test_tr_seqs)) == 0
         logger.info(f"Pass! No overlap in TR sequences")
 
-        train_tr_reps = set(train["tr_cluster_rep"].unique().tolist())
-        val_tr_reps = set(val["tr_cluster_rep"].unique().tolist())
-        test_tr_reps = set(test["tr_cluster_rep"].unique().tolist())
-
         assert len(train_tr_reps.intersection(val_tr_reps)) == 0
         assert len(train_tr_reps.intersection(test_tr_reps)) == 0
         assert len(val_tr_reps.intersection(test_tr_reps)) == 0
         logger.info(f"Pass! No overlap in TR cluster reps")
 
     if split_by != "protein":
-        train_dna_seqs = set(train["dna_sequence"].unique().tolist())
-        val_dna_seqs = set(val["dna_sequence"].unique().tolist())
-        test_dna_seqs = set(test["dna_sequence"].unique().tolist())
-
         assert len(train_dna_seqs.intersection(val_dna_seqs)) == 0
         assert len(train_dna_seqs.intersection(test_dna_seqs)) == 0
         assert len(val_dna_seqs.intersection(test_dna_seqs)) == 0
         logger.info(f"Pass! No overlap in DNA sequences")
 
-        train_dna_reps = set(train["dna_cluster_rep"].unique().tolist())
-        val_dna_reps = set(val["dna_cluster_rep"].unique().tolist())
-        test_dna_reps = set(test["dna_cluster_rep"].unique().tolist())
-
         assert len(train_dna_reps.intersection(val_dna_reps)) == 0
         assert len(train_dna_reps.intersection(test_dna_reps)) == 0
         assert len(val_dna_reps.intersection(test_dna_reps)) == 0
@@ -531,8 +285,23 @@ def main(cfg: DictConfig):
     logger.info(f"All proteins are in their own clusters: {no_protein_overlap}")
 
     if cfg.data_task.split_by == "dna":
-        if cfg.data_task.p_exclude:
-            return
+        if cfg.data_task.test_trs or cfg.data_task.test_dnas:
+            logger.info(f"Splitting with predefined trs/dnas reserved for test set")
+            splits, kept_by_split = split_with_predefined_test(
+                full_df=edge_df,
+                split_names=("train", "val", "test"),
+                test_trs=cfg.data_task.test_trs if cfg.data_task.test_trs else None,
+                test_dnas=cfg.data_task.test_dnas if cfg.data_task.test_dnas else None,
+                ratios=(0.8, 0.1, 0.1),
+            )
+            train = splits["train"]
+            train["split"] = ["train"] * len(train)
+            val = splits["val"]
+            val["split"] = ["val"] * len(val)
+            test = splits["test"]
+            test["split"] = ["test"] * len(test)
+            leaky_test = splits["leaky_test"]
+            leaky_test["split"] = ["leaky_test"] * len(leaky_test)
         else:
             logger.info(f"Easy split: all proteins are in their own clusters.")
             dna_clusters = edge_df["dna_cluster_rep"].unique().tolist()
@@ -549,75 +318,42 @@ def main(cfg: DictConfig):
 
             # assign datapoints to cluster by their DNA cluster rep
             edge_df["split"] = edge_df["dna_cluster_rep"].map(dna_assign)
-    else:
-        [deleted here: the split_by != "dna" branch, i.e. the call to
-         split_bipartite_by_components and the unpacking of its results;
-         identical to the corresponding block in complex_remap.py above]
-        print(tf_assign)
-        print(dna_assign)
-        edge_df["tr_split"] = edge_df["tr_cluster_rep"].map(tf_assign)
-        edge_df["dna_split"] = edge_df["dna_cluster_rep"].map(dna_assign)
-        edge_df["same_split"] = (
-            edge_df["tr_split"] == edge_df["dna_split"]
-        )  # should always be true if easy cluster
-        edge_df["split"] = edge_df["tr_split"]
-        print(edge_df)
-        edge_df["split"] = np.where(
-            edge_df["same_split"],
-            edge_df["split"],  # keep existing split if same_split == True
-            "leak",  # otherwise leak
-        )
-        print(edge_df)
+            train = edge_df.loc[edge_df["split"] == "train"].reset_index(drop=True)
+            val = edge_df.loc[edge_df["split"] == "val"].reset_index(drop=True)
+            test = edge_df.loc[edge_df["split"] == "test"].reset_index(drop=True)
+            leaky_test = pd.DataFrame(columns=edge_df.columns)
 
     # Print ratios: hopefully close to desired (e.g. 80/10/10)
     print_split_ratios(kept_by_split)
 
     # Make train, val, test sets
     # make sure no ID is duplicate
    assert len(edge_df["ID"].unique()) == len(edge_df)
     split_cols = [
         "ID",
         "dna_sequence",
         "tr_sequence",
         "tr_cluster_rep",
         "dna_cluster_rep",
         "scores",
         "split",
     ]
-    train = edge_df.loc[edge_df["split"] == "train"].reset_index(drop=True)[split_cols]
-    val = edge_df.loc[edge_df["split"] == "val"].reset_index(drop=True)[split_cols]
-    test = edge_df.loc[edge_df["split"] == "test"].reset_index(drop=True)[split_cols]
+    train = train[split_cols]
+    val = val[split_cols]
+    test = test[split_cols]
+    leaky_test = leaky_test[split_cols]
 
     # ensure there is no overlap
     check_validity(train, val, test, split_by=cfg.data_task.split_by)
 
-    total = sum([len(train), len(val), len(test)])
+    total = sum([len(train), len(val), len(test), len(leaky_test)])
     logger.info(f"Length of train dataset: {len(train)} ({100*len(train)/total:.2f}%)")
     logger.info(f"Length of val dataset: {len(val)} ({100*len(val)/total:.2f}%)")
     logger.info(f"Length of test dataset: {len(test)} ({100*len(test)/total:.2f}%)")
+    logger.info(f"Length of leaky_test dataset: {len(leaky_test)} ({100*len(leaky_test)/total:.2f}%)")
     logger.info(f"Total sequences = {total}. Same as edges size? {total==len(edge_df)}")
 
-    og_unique_dna = pd.concat([train, val, test])
+    og_unique_dna = pd.concat([train, val, test, leaky_test])
     og_unique_dna = len(og_unique_dna["dna_sequence"].unique())
 
     ## Now do RC data augmentation if asked
@@ -625,23 +361,25 @@ def main(cfg: DictConfig):
         train = augment_rc(train)
         val = augment_rc(val)
         test = augment_rc(test)
+[added lines not recoverable from the rendered diff]
 
-        logger.info(f"Added reverse complement sequences to train, val, and test.")
 
         check_validity(train, val, test, split_by=cfg.data_task.split_by)
 
-        total = sum([len(train), len(val), len(test)])
         logger.info(
             f"Length of train dataset: {len(train)} ({100*len(train)/total:.2f}%)"
         )
         logger.info(f"Length of val dataset: {len(val)} ({100*len(val)/total:.2f}%)")
         logger.info(f"Length of test dataset: {len(test)} ({100*len(test)/total:.2f}%)")
+[added lines not recoverable from the rendered diff]
         logger.info(
            f"Total sequences = {total}. Same as edges size? {total==len(edge_df)}"
         )
 
-        # since we've added all these new DNA sequences, we do need a new mapping of seq id to dna sequence
-        all_data = pd.concat([train, val, test])
+[added lines not recoverable from the rendered diff]
        all_data["dna_seqid"] = all_data["ID"].str.split("_", n=1, expand=True)[1]
         dna_dict = dict(zip(all_data["dna_seqid"], all_data["dna_sequence"]))
         assert len(dna_dict) == len(all_data.drop_duplicates(["dna_sequence"]))
@@ -660,13 +398,15 @@ def main(cfg: DictConfig):
     os.makedirs(split_out_dir, exist_ok=True)
 
     # add binary_scores to allow other training modes
-    train["fimo_binary_sores"] = train["scores"].apply(lambda x: convert_scores(x))
-    val["fimo_binary_sores"] = val["scores"].apply(lambda x: convert_scores(x))
-    test["fimo_binary_sores"] = test["scores"].apply(lambda x: convert_scores(x))
+[added lines not recoverable from the rendered diff]
 
     # select final cols and save
     split_final_cols = ["ID", "dna_sequence", "tr_sequence", "scores", "fimo_binary_sores", "split"]
     train[split_final_cols].to_csv(split_out_dir / "train.csv", index=False)
     val[split_final_cols].to_csv(split_out_dir / "val.csv", index=False)
     test[split_final_cols].to_csv(split_out_dir / "test.csv", index=False)
+[added lines not recoverable from the rendered diff]
     logger.info(f"Saved all splits to {split_out_dir}")
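A minimal sketch (not from the commit) of the new predefined-test split; the frame and IDs are invented, and it assumes remap.py's own imports (pandas, sklearn's train_test_split, the module logger) are in scope. Note that train_ratio_from_remaining must stay below 1, i.e. the pinned test rows should be a small fraction of the data, and that the ratios argument is accepted but the 80/10/10 target is currently hard-coded via the 0.8 factor.

rows = [("trseq23", "dna0")] + [(f"trseq{i}", f"dna{i}") for i in range(1, 12)]
df = pd.DataFrame({
    "ID": [f"{t}_{d}" for t, d in rows],
    "tr_seqid": [t for t, _ in rows],
    "dna_seqid": [d for _, d in rows],
    "tr_cluster_rep": [t for t, _ in rows],   # toy data: every sequence is its own cluster
    "dna_cluster_rep": [d for _, d in rows],
})
splits, kept_by_split = split_with_predefined_test(df, test_trs=["trseq23"])
# splits["test"] holds the pinned rows; their TR and DNA clusters are excluded
# from train/val, and anything excluded but not in test lands in splits["leaky_test"].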
| 15 |
root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
| 16 |
logger = pylogger.RankedLogger(__name__, rank_zero_only=True)
|
| 17 |
|
| 18 |
+
def split_with_predefined_test(
|
| 19 |
+
full_df = pd.DataFrame(),
|
| 20 |
+
split_names=("train", "val", "test"),
|
| 21 |
+
test_trs=None,
|
| 22 |
+
test_dnas=None,
|
| 23 |
+
ratios=(0.8, 0.1, 0.1),
|
| 24 |
+
):
|
| 25 |
+
"""
|
| 26 |
+
Method for splitting into train and val with a predefined test set.
|
| 27 |
+
The proteins in the test set, and the DNA clusters of the DNAs they're associated with, must be excluded from train and val.
|
| 28 |
+
The remaining rows for train and val are split to preserve 80/10/10 as best as possible.
|
| 29 |
+
"""
|
| 30 |
+
test = full_df.copy(deep=True)
|
| 31 |
+
if test_trs is not None:
|
| 32 |
+
test = test.loc[test["tr_seqid"].isin(test_trs)].reset_index(drop=True)
|
| 33 |
+
if test_dnas is not None:
|
| 34 |
+
test = test.loc[test["dna_seqid"].isin(test_dnas)].reset_index(drop=True)
|
| 35 |
+
|
| 36 |
+
tr_clusters_to_exclude = test["tr_cluster_rep"].unique().tolist()
|
| 37 |
+
dna_clusters_to_exclude = test["dna_cluster_rep"].unique().tolist()
|
| 38 |
+
|
| 39 |
+
remaining = full_df.loc[
|
| 40 |
+
(~full_df["tr_cluster_rep"].isin(tr_clusters_to_exclude)) &
|
| 41 |
+
(~full_df["dna_cluster_rep"].isin(dna_clusters_to_exclude))
|
| 42 |
+
].reset_index(drop=True)
|
| 43 |
+
|
| 44 |
+
test_ids = test["ID"].unique().tolist()
|
| 45 |
+
remaining_ids = remaining["ID"].unique().tolist()
|
| 46 |
+
remaining_clusters = remaining["dna_cluster_rep"].unique().tolist()
|
| 47 |
+
lost_rows = full_df.loc[
|
| 48 |
+
(~full_df["ID"].isin(test_ids)) &
|
| 49 |
+
(~full_df["ID"].isin(remaining_ids))
|
| 50 |
+
]
|
| 51 |
+
|
| 52 |
+
logger.info(f"Rows in test: {len(test)}")
|
| 53 |
+
logger.info(f"Rows to be split between train and val: {len(remaining)}")
|
| 54 |
+
total_rows = len(test) + len(remaining)
|
| 55 |
+
logger.info(f"Total rows: {total_rows}. Test percentage: {100*len(test)/total_rows:.2f}%")
|
| 56 |
+
logger.info(f"Lost rows: {len(lost_rows)}")
|
| 57 |
+
|
| 58 |
+
train_ratio_from_remaining = round((0.8*total_rows)/len(remaining), 2)
|
| 59 |
+
# use sklearn
|
| 60 |
+
test_size_1 = 1 - train_ratio_from_remaining
|
| 61 |
+
logger.info(
|
| 62 |
+
f"\tPerforming first split: non-test clusters -> train clusters ({round(1-test_size_1,3)}) and val ({test_size_1})"
|
| 63 |
+
)
|
| 64 |
+
X = remaining_clusters
|
| 65 |
+
y = [0] * len(remaining_clusters)
|
| 66 |
+
X_train, X_val, y_train, y_val = train_test_split(
|
| 67 |
+
X, y, test_size=test_size_1, random_state=0
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
train = remaining.loc[remaining["dna_cluster_rep"].isin(X_train)]
|
| 71 |
+
val = remaining.loc[remaining["dna_cluster_rep"].isin(X_val)]
|
| 72 |
+
leaky_test = lost_rows
|
| 73 |
+
|
| 74 |
+
kept_by_split = {
|
| 75 |
+
"train": len(X_train),
|
| 76 |
+
"val": len(X_val),
|
| 77 |
+
"test": len(test["dna_cluster_rep"].unique())
|
| 78 |
+
}
|
| 79 |
+
splits = {
|
| 80 |
+
"train": train,
|
| 81 |
+
"val": val,
|
| 82 |
+
"test": test,
|
| 83 |
+
"leaky_test": leaky_test
|
| 84 |
+
}
|
| 85 |
+
return splits, kept_by_split
|
| 86 |
|
| 87 |
def split_bipartite_fast(
|
| 88 |
dna_clusters,
|
|
|
|
| 118 |
kept_by_split = {"train": len(X_train), "val": len(X_val), "test": len(X_test)}
|
| 119 |
return dna_assign, kept_by_split
|
| 120 |
|
| 121 |
+
# construct new labels
|
| 122 |
+
def convert_scores(scores, mode=1):
|
| 123 |
+
"""
|
| 124 |
+
Two modes: 1 means FIMO peaks get 1. 0 means FIMO peaks get their max score
|
| 125 |
+
"""
|
| 126 |
svec = [int(x) for x in scores.split(",")]
|
| 127 |
max_score = max(svec)
|
| 128 |
+
if mode ==1:
|
| 129 |
+
binary_svec = [0 if x<max_score else 1 for x in svec]
|
| 130 |
+
assert(svec.count(max_score)==binary_svec.count(1))
|
| 131 |
+
else:
|
| 132 |
+
binary_svec = [0 if x<max_score else max_score for x in svec]
|
| 133 |
+
assert(svec.count(max_score)==binary_svec.count(max_score))
|
| 134 |
binary_svec = ",".join([str(x) for x in binary_svec])
|
| 135 |
return binary_svec
|
| 136 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
def print_split_ratios(kept_by_split):
|
| 138 |
total = sum(kept_by_split.values())
|
| 139 |
| 187 | train_pcnt = 100 * kept_by_split["train"] / total
| 188 | assert len(val_ids.intersection(test_ids)) == 0
| 189 | logger.info(f"Pass! No overlap in IDs")
| 190 |
| 191 | + # Investigate TR intersection. No assertions unless we are explicitly splitting on this.
| 192 | + train_tr_seqs = set(train["tr_sequence"].unique().tolist())
| 193 | + val_tr_seqs = set(val["tr_sequence"].unique().tolist())
| 194 | + test_tr_seqs = set(test["tr_sequence"].unique().tolist())
| 195 | +
| 196 | + train_tr_reps = set(train["tr_cluster_rep"].unique().tolist())
| 197 | + val_tr_reps = set(val["tr_cluster_rep"].unique().tolist())
| 198 | + test_tr_reps = set(test["tr_cluster_rep"].unique().tolist())
| 199 | +
| 200 | + logger.info(f"Train-Val TR intersection: {len(train_tr_seqs.intersection(val_tr_seqs))}")
| 201 | + logger.info(f"Train-Test TR intersection: {len(train_tr_seqs.intersection(test_tr_seqs))}")
| 202 | + logger.info(f"Val-Test TR intersection: {len(val_tr_seqs.intersection(test_tr_seqs))}")
| 203 | +
| 204 | + logger.info(f"Train-Val TR Cluster Rep intersection: {len(train_tr_reps.intersection(val_tr_reps))}")
| 205 | + logger.info(f"Train-Test TR Cluster Rep intersection: {len(train_tr_reps.intersection(test_tr_reps))}")
| 206 | + logger.info(f"Val-Test TR Cluster Rep intersection: {len(val_tr_reps.intersection(test_tr_reps))}")
| 207 | +
| 208 | + # Investigate DNA intersection. No assertions unless we are explicitly splitting on this.
| 209 | + train_dna_seqs = set(train["dna_sequence"].unique().tolist())
| 210 | + val_dna_seqs = set(val["dna_sequence"].unique().tolist())
| 211 | + test_dna_seqs = set(test["dna_sequence"].unique().tolist())
| 212 | +
| 213 | + train_dna_reps = set(train["dna_cluster_rep"].unique().tolist())
| 214 | + val_dna_reps = set(val["dna_cluster_rep"].unique().tolist())
| 215 | + test_dna_reps = set(test["dna_cluster_rep"].unique().tolist())
| 216 | +
| 217 | + logger.info(f"Train-Val DNA intersection: {len(train_dna_seqs.intersection(val_dna_seqs))}")
| 218 | + logger.info(f"Train-Test DNA intersection: {len(train_dna_seqs.intersection(test_dna_seqs))}")
| 219 | + logger.info(f"Val-Test DNA intersection: {len(val_dna_seqs.intersection(test_dna_seqs))}")
| 220 | +
| 221 | + logger.info(f"Train-Val DNA Cluster Rep intersection: {len(train_dna_reps.intersection(val_dna_reps))}")
| 222 | + logger.info(f"Train-Test DNA Cluster Rep intersection: {len(train_dna_reps.intersection(test_dna_reps))}")
| 223 | + logger.info(f"Val-Test DNA Cluster Rep intersection: {len(val_dna_reps.intersection(test_dna_reps))}")
| 224 |
| 225 | + if split_by != "dna":
| 226 |     assert len(train_tr_seqs.intersection(val_tr_seqs)) == 0
| 227 |     assert len(train_tr_seqs.intersection(test_tr_seqs)) == 0
| 228 |     assert len(val_tr_seqs.intersection(test_tr_seqs)) == 0
| 229 |     logger.info(f"Pass! No overlap in TR sequences")
| 230 |
| 231 |     assert len(train_tr_reps.intersection(val_tr_reps)) == 0
| 232 |     assert len(train_tr_reps.intersection(test_tr_reps)) == 0
| 233 |     assert len(val_tr_reps.intersection(test_tr_reps)) == 0
| 234 |     logger.info(f"Pass! No overlap in TR cluster reps")
| 235 |
| 236 | if split_by != "protein":
| 237 |     assert len(train_dna_seqs.intersection(val_dna_seqs)) == 0
| 238 |     assert len(train_dna_seqs.intersection(test_dna_seqs)) == 0
| 239 |     assert len(val_dna_seqs.intersection(test_dna_seqs)) == 0
| 240 |     logger.info(f"Pass! No overlap in DNA sequences")
| 241 |
| 242 |     assert len(train_dna_reps.intersection(val_dna_reps)) == 0
| 243 |     assert len(train_dna_reps.intersection(test_dna_reps)) == 0
| 244 |     assert len(val_dna_reps.intersection(test_dna_reps)) == 0
| 285 | logger.info(f"All proteins are in their own clusters: {no_protein_overlap}")
| 286 |
| 287 | if cfg.data_task.split_by == "dna":
| 288 | +     if cfg.data_task.test_trs or cfg.data_task.test_dnas:
| 289 | +         logger.info(f"Splitting with predefined trs/dnas reserved for test set")
| 290 | +         splits, kept_by_split = split_with_predefined_test(
| 291 | +             full_df=edge_df,
| 292 | +             split_names=("train", "val", "test"),
| 293 | +             test_trs=cfg.data_task.test_trs if cfg.data_task.test_trs else None,
| 294 | +             test_dnas=cfg.data_task.test_dnas if cfg.data_task.test_dnas else None,
| 295 | +             ratios=(0.8, 0.1, 0.1),
| 296 | +         )
| 297 | +         train = splits["train"]
| 298 | +         train["split"] = ["train"] * len(train)
| 299 | +         val = splits["val"]
| 300 | +         val["split"] = ["val"] * len(val)
| 301 | +         test = splits["test"]
| 302 | +         test["split"] = ["test"] * len(test)
| 303 | +         leaky_test = splits["leaky_test"]
| 304 | +         leaky_test["split"] = ["leaky_test"] * len(leaky_test)
| 305 |     else:
| 306 |         logger.info(f"Easy split: all proteins are in their own clusters.")
| 307 |         dna_clusters = edge_df["dna_cluster_rep"].unique().tolist()
| 318 |
| 319 |         # assign datapoints to cluster by their DNA cluster rep
| 320 |         edge_df["split"] = edge_df["dna_cluster_rep"].map(dna_assign)
| 321 | +         train = edge_df.loc[edge_df["split"] == "train"].reset_index(drop=True)
| 322 | +         val = edge_df.loc[edge_df["split"] == "val"].reset_index(drop=True)
| 323 | +         test = edge_df.loc[edge_df["split"] == "test"].reset_index(drop=True)
| 324 | +         leaky_test = pd.DataFrame(columns=edge_df.columns)
| 325 | +
| 326 | +     # Print ratios: hopefully close to desired (e.g. 80/10/10)
| 327 | +     print_split_ratios(kept_by_split)
| 328 | +
| 329 | +     # Make train, val, test sets
| 330 | +     # make sure no ID is duplicate
| 331 | +     assert len(edge_df["ID"].unique()) == len(edge_df)
| 332 | +     split_cols = [
| 333 | +         "ID",
| 334 | +         "dna_sequence",
| 335 | +         "tr_sequence",
| 336 | +         "tr_cluster_rep",
| 337 | +         "dna_cluster_rep",
| 338 | +         "scores",
| 339 | +         "split",
| 340 | +     ]
| 341 | +     train = train[split_cols]
| 342 | +     val = val[split_cols]
| 343 | +     test = test[split_cols]
| 344 | +     leaky_test = leaky_test[split_cols]
| 345 |
| 346 | # ensure there is no overlap
| 347 | check_validity(train, val, test, split_by=cfg.data_task.split_by)
| 348 |
| 349 | + total = sum([len(train), len(val), len(test), len(leaky_test)])
| 350 | logger.info(f"Length of train dataset: {len(train)} ({100*len(train)/total:.2f}%)")
| 351 | logger.info(f"Length of val dataset: {len(val)} ({100*len(val)/total:.2f}%)")
| 352 | logger.info(f"Length of test dataset: {len(test)} ({100*len(test)/total:.2f}%)")
| 353 | + logger.info(f"Length of leaky_test dataset: {len(leaky_test)} ({100*len(leaky_test)/total:.2f}%)")
| 354 | logger.info(f"Total sequences = {total}. Same as edges size? {total==len(edge_df)}")
| 355 |
| 356 | + og_unique_dna = pd.concat([train, val, test, leaky_test])
| 357 | og_unique_dna = len(og_unique_dna["dna_sequence"].unique())
| 358 |
| 359 | ## Now do RC data augmentation if asked
| 361 | train = augment_rc(train)
| 362 | val = augment_rc(val)
| 363 | test = augment_rc(test)
| 364 | + leaky_test = augment_rc(leaky_test)
| 365 |
| 366 | + logger.info(f"Added reverse complement sequences to train, val, and test (and leaky_test)")
| 367 |
| 368 | check_validity(train, val, test, split_by=cfg.data_task.split_by)
| 369 |
| 370 | + total = sum([len(train), len(val), len(test), len(leaky_test)])
| 371 | logger.info(
| 372 |     f"Length of train dataset: {len(train)} ({100*len(train)/total:.2f}%)"
| 373 | )
| 374 | logger.info(f"Length of val dataset: {len(val)} ({100*len(val)/total:.2f}%)")
| 375 | logger.info(f"Length of test dataset: {len(test)} ({100*len(test)/total:.2f}%)")
| 376 | + logger.info(f"Length of leaky_test dataset: {len(leaky_test)} ({100*len(leaky_test)/total:.2f}%)")
| 377 | logger.info(
| 378 |     f"Total sequences = {total}. Same as edges size? {total==len(edge_df)}"
| 379 | )
| 380 |
| 381 | + # since we've added all these new DNA sequences, we do need a new mapping of seq id to dna sequence
| 382 | + all_data = pd.concat([train, val, test, leaky_test])
| 383 | all_data["dna_seqid"] = all_data["ID"].str.split("_", n=1, expand=True)[1]
| 384 | dna_dict = dict(zip(all_data["dna_seqid"], all_data["dna_sequence"]))
| 385 | assert len(dna_dict) == len(all_data.drop_duplicates(["dna_sequence"]))
| 398 | os.makedirs(split_out_dir, exist_ok=True)
| 399 |
| 400 | # add binary_scores to allow other training modes
| 401 | + train["fimo_binary_scores"] = train["scores"].apply(lambda x: convert_scores(x, mode=1))
| 402 | + val["fimo_binary_scores"] = val["scores"].apply(lambda x: convert_scores(x, mode=1))
| 403 | + test["fimo_binary_scores"] = test["scores"].apply(lambda x: convert_scores(x, mode=1))
| 404 | + leaky_test["fimo_binary_scores"] = leaky_test["scores"].apply(lambda x: convert_scores(x, mode=1))
| 405 |
| 406 | # select final cols and save
| 407 | split_final_cols = ["ID", "dna_sequence", "tr_sequence", "scores", "fimo_binary_scores", "split"]
| 408 | train[split_final_cols].to_csv(split_out_dir / "train.csv", index=False)
| 409 | val[split_final_cols].to_csv(split_out_dir / "val.csv", index=False)
| 410 | test[split_final_cols].to_csv(split_out_dir / "test.csv", index=False)
| 411 | + leaky_test[split_final_cols].to_csv(split_out_dir / "leaky_test.csv", index=False)
| 412 | logger.info(f"Saved all splits to {split_out_dir}")
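Note: check_validity above boils down to pairwise set disjointness on IDs, sequences, and cluster reps. A minimal standalone sketch of the pattern (the toy dataframes are invented for illustration, not the project's real splits):

```python
# Toy frames invented for illustration; real splits carry many more columns.
import pandas as pd

train = pd.DataFrame({"ID": ["tr1_dna1", "tr1_dna2"], "dna_cluster_rep": ["c1", "c2"]})
test = pd.DataFrame({"ID": ["tr2_dna3"], "dna_cluster_rep": ["c3"]})

def assert_disjoint(a: pd.DataFrame, b: pd.DataFrame, col: str) -> None:
    """Fail loudly if any value of `col` appears in both frames."""
    overlap = set(a[col]) & set(b[col])
    assert not overlap, f"{col} leaks across splits: {overlap}"

assert_disjoint(train, test, "ID")               # always enforced
assert_disjoint(train, test, "dna_cluster_rep")  # enforced when split_by != "protein"
print("Pass! No overlap")
```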
dpacman/data_tasks/split/remap_handpick.py
ADDED
|
@@ -0,0 +1,266 @@
"""
Not neat, but this is what I did to make exclusive splits. Saving here for now.
"""

## Full pipeline
import pandas as pd

protein_clusters = pd.read_csv("/home/a03-svincoff/DPACMAN/dpacman/data_files/processed/mmseqs/outputs/fimo_hits_only/protein/mmseqs_cluster.tsv", sep="\t", header=None)
protein_clusters.columns = ["tr_cluster_rep", "tr_cluster_member"]
protein_clusters.head()

dna_clusters = pd.read_csv("/home/a03-svincoff/DPACMAN/dpacman/data_files/processed/mmseqs/outputs/fimo_hits_only/dna_full/mmseqs_cluster.tsv", sep="\t", header=None)
dna_clusters.columns = ["dna_cluster_rep", "dna_cluster_member"]
dna_clusters.head()

all_data = pd.read_parquet("/home/a03-svincoff/DPACMAN/dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/remap2022_crm_fimo_output_q_processed_seed0.parquet")
all_data

protein_cluster_map = dict(zip(protein_clusters["tr_cluster_member"], protein_clusters["tr_cluster_rep"]))
dna_cluster_map = dict(zip(dna_clusters["dna_cluster_member"], dna_clusters["dna_cluster_rep"]))
print(len(protein_cluster_map))
print(len(dna_cluster_map))
all_data["tr_cluster_rep"] = all_data["tr_seqid"].map(protein_cluster_map)
all_data["dna_cluster_rep"] = all_data["dna_seqid"].map(dna_cluster_map)
print(len(all_data[all_data["tr_cluster_rep"].isna()]))
print(len(all_data[all_data["dna_cluster_rep"].isna()]))
all_data.head()

### handpick test
handpicked_test_trs = ["trseq23", "trseq26", "trseq17"]
handpicked_test = all_data.loc[
    all_data["tr_cluster_rep"].isin(handpicked_test_trs)
].reset_index(drop=True)

off_limits_dna_clusters = handpicked_test["dna_cluster_rep"].unique().tolist()
remaining = all_data.loc[
    (~all_data["tr_cluster_rep"].isin(handpicked_test_trs)) &
    (~all_data["dna_cluster_rep"].isin(off_limits_dna_clusters))
].reset_index(drop=True)

test_ids = handpicked_test["ID"].unique().tolist()
remaining_ids = remaining["ID"].unique().tolist()
lost_rows = all_data.loc[
    (~all_data["ID"].isin(test_ids)) &
    (~all_data["ID"].isin(remaining_ids))
]
print(f"Rows in test: {len(handpicked_test)}")
print(f"Rows to be split between train and val: {len(remaining)}")
total_rows = len(handpicked_test) + len(remaining)
print(f"Total rows: {total_rows}. Test percentage: {100*len(handpicked_test)/total_rows:.2f}%")
print(f"Lost rows: {len(lost_rows)}")

### handpick val
handpicked_val_trs = ["trseq9", "trseq5", "trseq28"]

handpicked_val = remaining.loc[
    remaining["tr_cluster_rep"].isin(handpicked_val_trs)
].reset_index(drop=True)

off_limits_dna_clusters = handpicked_val["dna_cluster_rep"].unique().tolist()
train_remain = remaining.loc[
    (~remaining["tr_cluster_rep"].isin(handpicked_val_trs)) &
    (~remaining["dna_cluster_rep"].isin(off_limits_dna_clusters))
].reset_index(drop=True)

val_ids = handpicked_val["ID"].unique().tolist()
train_remain_ids = train_remain["ID"].unique().tolist()
lost_rows = all_data.loc[
    (~all_data["ID"].isin(test_ids)) &
    (~all_data["ID"].isin(val_ids)) &
    (~all_data["ID"].isin(train_remain_ids))
]
print(f"Rows in val: {len(handpicked_val)}")
print(f"Rows left for train: {len(train_remain)}")
total_rows = len(handpicked_val) + len(train_remain)
print(f"Total rows: {total_rows}. Val percentage: {100*len(handpicked_val)/total_rows:.2f}%")
print(f"Lost rows: {len(lost_rows)}")

train_exclusive = all_data.loc[
    all_data["ID"].isin(train_remain_ids)
].reset_index(drop=True)

val_exclusive = all_data.loc[
    all_data["ID"].isin(val_ids)
].reset_index(drop=True)

test_exclusive = all_data.loc[
    all_data["ID"].isin(test_ids)
].reset_index(drop=True)

leaky_test = all_data.loc[
    ~(all_data["ID"].isin(train_exclusive["ID"].tolist())) &
    ~(all_data["ID"].isin(val_exclusive["ID"].tolist())) &
    ~(all_data["ID"].isin(test_exclusive["ID"].tolist()))
].reset_index(drop=True)

print(f"Original total: {len(all_data)}")
retained_total = len(train_exclusive) + len(val_exclusive) + len(test_exclusive)
print(f"New, exclusive total: {retained_total}")
print(f"Lost rows: {len(all_data)-retained_total}")
print(f"Length train: {len(train_exclusive)}/{retained_total} ({100*len(train_exclusive)/retained_total:.2f}%)")
print(f"Length val: {len(val_exclusive)}/{retained_total} ({100*len(val_exclusive)/retained_total:.2f}%)")
print(f"Length test: {len(test_exclusive)}/{retained_total} ({100*len(test_exclusive)/retained_total:.2f}%)")

def check_validity(train_exclusive, val_exclusive, test_exclusive):
    train_exclusive_ids = set(train_exclusive["ID"].unique().tolist())
    val_exclusive_ids = set(val_exclusive["ID"].unique().tolist())
    test_exclusive_ids = set(test_exclusive["ID"].unique().tolist())

    assert len(train_exclusive_ids.intersection(val_exclusive_ids)) == 0
    assert len(train_exclusive_ids.intersection(test_exclusive_ids)) == 0
    assert len(val_exclusive_ids.intersection(test_exclusive_ids)) == 0
    print(f"Pass! No overlap in IDs")

    # Investigate TR intersection. No assertions unless we are explicitly splitting on this.
    train_exclusive_tr_seqs = set(train_exclusive["tr_sequence"].unique().tolist())
    val_exclusive_tr_seqs = set(val_exclusive["tr_sequence"].unique().tolist())
    test_exclusive_tr_seqs = set(test_exclusive["tr_sequence"].unique().tolist())

    train_exclusive_tr_reps = set(train_exclusive["tr_cluster_rep"].unique().tolist())
    val_exclusive_tr_reps = set(val_exclusive["tr_cluster_rep"].unique().tolist())
    test_exclusive_tr_reps = set(test_exclusive["tr_cluster_rep"].unique().tolist())

    print(f"Train-Val TR intersection: {len(train_exclusive_tr_seqs.intersection(val_exclusive_tr_seqs))}")
    print(f"Train-Test TR intersection: {len(train_exclusive_tr_seqs.intersection(test_exclusive_tr_seqs))}")
    print(f"Val-Test TR intersection: {len(val_exclusive_tr_seqs.intersection(test_exclusive_tr_seqs))}")

    print(f"Train-Val TR Cluster Rep intersection: {len(train_exclusive_tr_reps.intersection(val_exclusive_tr_reps))}")
    print(f"Train-Test TR Cluster Rep intersection: {len(train_exclusive_tr_reps.intersection(test_exclusive_tr_reps))}")
    print(f"Val-Test TR Cluster Rep intersection: {len(val_exclusive_tr_reps.intersection(test_exclusive_tr_reps))}")

    # Investigate DNA intersection. No assertions unless we are explicitly splitting on this.
    train_exclusive_dna_seqs = set(train_exclusive["dna_sequence"].unique().tolist())
    val_exclusive_dna_seqs = set(val_exclusive["dna_sequence"].unique().tolist())
    test_exclusive_dna_seqs = set(test_exclusive["dna_sequence"].unique().tolist())

    train_exclusive_dna_reps = set(train_exclusive["dna_cluster_rep"].unique().tolist())
    val_exclusive_dna_reps = set(val_exclusive["dna_cluster_rep"].unique().tolist())
    test_exclusive_dna_reps = set(test_exclusive["dna_cluster_rep"].unique().tolist())

    print(f"Train-Val DNA intersection: {len(train_exclusive_dna_seqs.intersection(val_exclusive_dna_seqs))}")
    print(f"Train-Test DNA intersection: {len(train_exclusive_dna_seqs.intersection(test_exclusive_dna_seqs))}")
    print(f"Val-Test DNA intersection: {len(val_exclusive_dna_seqs.intersection(test_exclusive_dna_seqs))}")

    print(f"Train-Val DNA Cluster Rep intersection: {len(train_exclusive_dna_reps.intersection(val_exclusive_dna_reps))}")
    print(f"Train-Test DNA Cluster Rep intersection: {len(train_exclusive_dna_reps.intersection(test_exclusive_dna_reps))}")
    print(f"Val-Test DNA Cluster Rep intersection: {len(val_exclusive_dna_reps.intersection(test_exclusive_dna_reps))}")

def get_reverse_complement(s):
    """
    Returns 5' to 3' sequence of the reverse complement
    """
    chars = list(s)
    recon = []
    rev_map = {
        "a": "t",
        "c": "g",
        "t": "a",
        "g": "c",
        "A": "T",
        "C": "G",
        "T": "A",
        "G": "C",
        "n": "n",
        "N": "N",
    }
    for c in chars:
        recon += [rev_map[c]]

    recon = "".join(recon)
    return recon[::-1]

# now make reverse complements
def augment_rc(df):
    """
    Get the reverse complement and add it as a datapoint, effectively doubling the dataset.
    Also flip the orientation of the scores.

    columns = ["ID","dna_sequence","tr_sequence","tr_cluster_rep","dna_cluster_rep", "scores","split"]
    """
    df_rc = df.copy(deep=True)

    df_rc["dna_sequence"] = df_rc["dna_sequence"].apply(
        lambda x: get_reverse_complement(x)
    )
    df_rc["ID"] = df_rc["ID"] + "_rc"
    df_rc["scores"] = df_rc["scores"].apply(lambda s: ",".join(s.split(",")[::-1]))

    final_df = pd.concat([df, df_rc]).reset_index(drop=True)

    return final_df

def convert_scores(scores, mode=1):
    """
    Two modes: 1 means FIMO peaks get 1. 0 means FIMO peaks get their max score.
    """
    svec = [int(x) for x in scores.split(",")]
    max_score = max(svec)
    if mode == 1:
        binary_svec = [0 if x < max_score else 1 for x in svec]
        assert svec.count(max_score) == binary_svec.count(1)
    else:
        binary_svec = [0 if x < max_score else max_score for x in svec]
        assert svec.count(max_score) == binary_svec.count(max_score)
    binary_svec = ",".join([str(x) for x in binary_svec])
    return binary_svec

check_validity(train_exclusive, val_exclusive, test_exclusive)

train_exclusive = augment_rc(train_exclusive)
val_exclusive = augment_rc(val_exclusive)
test_exclusive = augment_rc(test_exclusive)
leaky_test = augment_rc(leaky_test)

print(f"Added reverse complement sequences to train_exclusive, val_exclusive, and test_exclusive (and leaky_test)")

check_validity(train_exclusive, val_exclusive, test_exclusive)

total = sum([len(train_exclusive), len(val_exclusive), len(test_exclusive), len(leaky_test)])
print(
    f"Length of train_exclusive dataset: {len(train_exclusive)} ({100*len(train_exclusive)/total:.2f}%)"
)
print(f"Length of val_exclusive dataset: {len(val_exclusive)} ({100*len(val_exclusive)/total:.2f}%)")
print(f"Length of test_exclusive dataset: {len(test_exclusive)} ({100*len(test_exclusive)/total:.2f}%)")
print(f"Length of leaky_test dataset: {len(leaky_test)} ({100*len(leaky_test)/total:.2f}%)")
print(
    f"Total sequences = {total}. Same as edges size*2? {total==len(all_data)*2}"
)

# since we've added all these new DNA sequences, we do need a new mapping of seq id to dna sequence
all_data = pd.concat([train_exclusive, val_exclusive, test_exclusive, leaky_test])
all_data["dna_seqid"] = all_data["ID"].str.split("_", n=1, expand=True)[1]
dna_dict = dict(zip(all_data["dna_seqid"], all_data["dna_sequence"]))
assert len(dna_dict) == len(all_data.drop_duplicates(["dna_sequence"]))

# create the output dir
import os
from pathlib import Path
split_out_dir = Path("/home/a03-svincoff/DPACMAN/dpacman/data_files/processed/splits/handpicked_val_test")
os.makedirs(split_out_dir, exist_ok=True)

# add binary_scores to allow other training modes
train_exclusive["binary_scores"] = train_exclusive["scores"].apply(lambda x: convert_scores(x, mode=1))
val_exclusive["binary_scores"] = val_exclusive["scores"].apply(lambda x: convert_scores(x, mode=1))
test_exclusive["binary_scores"] = test_exclusive["scores"].apply(lambda x: convert_scores(x, mode=1))
leaky_test["binary_scores"] = leaky_test["scores"].apply(lambda x: convert_scores(x, mode=1))

train_exclusive["split"] = ["train"] * len(train_exclusive)
val_exclusive["split"] = ["val"] * len(val_exclusive)
test_exclusive["split"] = ["test"] * len(test_exclusive)
leaky_test["split"] = ["leakytest"] * len(leaky_test)

# select final cols and save
split_final_cols = ["ID", "dna_sequence", "tr_sequence", "scores", "binary_scores", "split"]
train_exclusive[split_final_cols].to_csv(split_out_dir / "train.csv", index=False)
val_exclusive[split_final_cols].to_csv(split_out_dir / "val.csv", index=False)
test_exclusive[split_final_cols].to_csv(split_out_dir / "test.csv", index=False)
leaky_test[split_final_cols].to_csv(split_out_dir / "leakytest.csv", index=False)
print(f"Saved all splits to {split_out_dir}")

# make baby versions too
train_exclusive[split_final_cols].sample(400, random_state=42).to_csv(split_out_dir / "babytrain.csv", index=False)
val_exclusive[split_final_cols].sample(50, random_state=42).to_csv(split_out_dir / "babyval.csv", index=False)
test_exclusive[split_final_cols].sample(50, random_state=42).to_csv(split_out_dir / "babytest.csv", index=False)
leaky_test[split_final_cols].sample(50, random_state=42).to_csv(split_out_dir / "babyleakytest.csv", index=False)
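To make the helper semantics above concrete, here is what they return on a made-up row (assuming get_reverse_complement, convert_scores, and augment_rc from this file are in scope; all values are illustrative):

```python
# Illustrative values only; real rows come from the FIMO-processed parquet.
scores = "0,3,7,7,2"
print(convert_scores(scores, mode=1))   # -> "0,0,1,1,0": positions at the row max get 1
print(convert_scores(scores, mode=0))   # -> "0,0,7,7,0": positions at the row max keep 7
print(get_reverse_complement("ACGTn"))  # -> "nACGT": complement each base, then reverse
# augment_rc doubles a frame: the copied rows get the reverse-complemented
# dna_sequence, an "_rc" suffix on the ID, and the score vector reversed
# ("0,0,1,1,0" -> "0,1,1,0,0") so labels still line up with the flipped strand.
```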
dpacman/find_wandb_run_name.py
ADDED
|
@@ -0,0 +1,67 @@
#!/usr/bin/env python3
import os
import argparse
from typing import Optional

SENTINEL = "View project at"
SKIP_DIRS = {".git", ".hg", ".svn", "__pycache__", ".mamba", ".conda"}

def extract_run_name(log_path: str) -> Optional[str]:
    """
    Return the run name if we find a line that:
    - starts with 'wandb:' (after leading whitespace is stripped)
    - whose second-to-last word is 'run' and the last word is the run name
    and this occurs before the first line containing SENTINEL.
    Otherwise return None.
    """
    try:
        with open(log_path, "r", errors="ignore") as f:
            for raw in f:
                if SENTINEL in raw:
                    return None
                line = raw.strip()
                if not line.startswith("wandb: "):
                    continue
                toks = line.split()
                if len(toks) >= 2 and toks[-2] == "run":
                    return toks[-1]  # the run name (last token)
    except OSError:
        return None
    return None

def list_runs(root: str, followlinks: bool = False, do_rename=False):
    for dirpath, dirs, files in os.walk(root, followlinks=followlinks):
        # prune junk dirs
        dirs[:] = [d for d in dirs if d not in SKIP_DIRS]
        if "run.log" in files:
            log_path = os.path.join(dirpath, "run.log")
            name = extract_run_name(log_path)
            new_dir_path = dirpath
            if name:
                if name not in dirpath:
                    new_dir_path = f"{dirpath}-{name}"
            print(f"{name if name else '<unknown>'}\t{new_dir_path}")

            if do_rename and new_dir_path != dirpath:
                parent = os.path.dirname(dirpath)
                # resolve absolute path for safety
                abs_old = os.path.abspath(dirpath)
                abs_new = os.path.abspath(new_dir_path)

                if os.path.exists(abs_new):
                    print(f"⚠️ Target {abs_new} already exists, skipping rename.")
                else:
                    print(f"Renaming {abs_old} → {abs_new}")
                    os.rename(abs_old, abs_new)

def main():
    ap = argparse.ArgumentParser(
        description="List W&B run names by parsing lines that start with 'wandb:' and end with 'run <name>'."
    )
    ap.add_argument("root", help="Root directory to search")
    ap.add_argument("--followlinks", action="store_true", help="Follow symlinks while walking")
    args = ap.parse_args()
    list_runs(args.root, followlinks=args.followlinks)

if __name__ == "__main__":
    main()
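As a sanity check of the parsing rule, here is the core of extract_run_name applied inline to a hypothetical run.log excerpt (the log lines are invented for the example; the run-name format matches directories like 2025-08-28_04-37-58-stoic-snowball-99 referenced in run_eval.sh below):

```python
# Hypothetical log excerpt; real run.log files are written by the train/eval scripts.
example = """wandb: Currently logged in as: someone
wandb: Syncing run stoic-snowball-99
wandb: View project at https://wandb.ai/...
"""

for raw in example.splitlines():
    if "View project at" in raw:  # SENTINEL: stop scanning once the header ends
        break
    line = raw.strip()
    toks = line.split()
    if line.startswith("wandb: ") and len(toks) >= 2 and toks[-2] == "run":
        print(toks[-1])  # -> stoic-snowball-99
```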
dpacman/make_splits.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
dpacman/manual_scan_chroms.ipynb
ADDED
|
@@ -0,0 +1,147 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f6c01484",
   "metadata": {},
   "source": [
    "Temporary notebook for manually scanning chromosomes for sequences of interest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0608f91e",
   "metadata": {},
   "outputs": [],
   "source": [
    "seq_of_interest = \"GCAGATCTGCACATC\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3c245151",
   "metadata": {},
   "outputs": [],
   "source": [
    "genome_dir = \"/home/a03-svincoff/DPACMAN/dpacman/data_files/raw/genomes/hg38\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "682098b6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['chr12', 'chr5', 'chr17', 'chr2', 'chr21', 'chr1', 'chrM', 'chr22', 'chr20', 'chr16', 'chr9', 'chr8', 'chr19', 'chr7', 'chr11', 'chr3', 'chr4', 'chr14', 'chr15', 'chr18', 'chrY', 'chr6', 'chrX', 'chr13', 'chr10'])\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import os\n",
    "chrom_cache = {}\n",
    "for chrom_file in os.listdir(genome_dir):\n",
    "    chrom = chrom_file.split(\"hg38_\")[1].split(\".json\")[0]\n",
    "    with open(f\"{genome_dir}/{chrom_file}\", \"r\") as f:\n",
    "        chrom_cache[chrom] = json.load(f)\n",
    "\n",
    "print(chrom_cache.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd6cca79",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing sequence A: TAGCAGGATGTGT\n",
      "Testing sequence B: GCAGATCTGCACATC\n",
      "Testing sequence C: CGACACCTGACGCG\n",
      "Testing sequence D: CGCTATCCAGAGCG\n",
      "Testing sequence E: CGCGATGCTTCTCG\n",
      "Testing sequence F: CGGCTGGATTACCG\n",
      "Testing sequence G: CGAGAACATAGTCG\n",
      "Testing sequence H: CGGGGAAACGCCCG\n",
      "Testing sequence I: CGCCCAAAGCCGCG\n",
      "Testing sequence J: CGGAGGTAATGACG\n",
      "Testing sequence K: CGCACCGACTCACG\n",
      "Testing sequence L: CGGCCCTTTGCGCG\n",
      "Testing sequence M: CGCCGTTAGTGTCG\n"
     ]
    }
   ],
   "source": [
    "baker_sequences = {\n",
    "    \"A\": \"TAGCAGGATGTGT\",\n",
    "    \"B\": \"GCAGATCTGCACATC\",\n",
    "    \"C\": \"CGACACCTGACGCG\",\n",
    "    \"D\": \"CGCTATCCAGAGCG\",\n",
    "    \"E\": \"CGCGATGCTTCTCG\",\n",
    "    \"F\": \"CGGCTGGATTACCG\",\n",
    "    \"G\": \"CGAGAACATAGTCG\",\n",
    "    \"H\": \"CGGGGAAACGCCCG\",\n",
    "    \"I\": \"CGCCCAAAGCCGCG\",\n",
    "    \"J\": \"CGGAGGTAATGACG\",\n",
    "    \"K\": \"CGCACCGACTCACG\",\n",
    "    \"L\": \"CGGCCCTTTGCGCG\",\n",
    "    \"M\": \"CGCCGTTAGTGTCG\"\n",
    "}\n",
    "sorted_chroms = list(chrom_cache.keys())\n",
    "sorted_chroms = sorted(sorted_chroms, key=lambda x: int(x.split(\"chr\")[1]) if x.split(\"chr\")[1] not in [\"M\", \"X\", \"Y\"] else 0)\n",
    "\n",
    "for seq_letter, seq in baker_sequences.items():\n",
    "    print(f\"Testing sequence {seq_letter}: {seq}\")\n",
    "    for chrom in sorted_chroms:\n",
    "        chrom_dna = chrom_cache[chrom][\"dna\"].upper()\n",
    "        match_chroms = []\n",
    "        try:\n",
    "            print(f\"\\tChrom {chrom} index of sequence {seq_letter} ({seq}): {chrom_dna.index(seq)}\")\n",
    "            match_chroms += [chrom]\n",
    "        except ValueError:\n",
    "            print(f\"\\tChrom {chrom} does not have sequence {seq_letter} ({seq})\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99078ac6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# [m.start() for m in re.finditer('GCAGATCTGCACATC', chr1_dna)]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dnabind2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
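The commented-out re.finditer line in the last cell originally had its arguments reversed (the pattern comes first); corrected above, and extended here into a small sketch that reports every match position instead of only the first, reusing the chrom_cache built in the notebook:

```python
import re

def scan_all(chrom_cache: dict, seq: str) -> dict:
    """Map each chromosome to the start positions of every exact match of seq."""
    hits = {}
    for chrom, payload in chrom_cache.items():
        dna = payload["dna"].upper()
        # The pattern argument comes first; re.escape is a no-op for plain
        # ACGT strings but guards against regex metacharacters in general.
        starts = [m.start() for m in re.finditer(re.escape(seq), dna)]
        if starts:
            hits[chrom] = starts
    return hits

# e.g. scan_all(chrom_cache, "GCAGATCTGCACATC")
```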
dpacman/scripts/delay_run.sh
CHANGED
@@ -1,10 +1,10 @@
| 1 | #!/usr/bin/env bash
| 2 | set -euo pipefail
| 3 |
| 4 | - # Usage: ./
| 5 | # Optional: override waits via env vars WAIT1 / WAIT2 (seconds). Defaults: 3 hours each.
| 6 |
| 7 | - WAIT1=${WAIT1:-
| 8 | WAIT2=${WAIT2:-10800}
| 9 |
| 10 | SCRIPT1="${1:?usage: $0 <first_script.sh> <second_script.sh>}"

| 1 | #!/usr/bin/env bash
| 2 | set -euo pipefail
| 3 |
| 4 | + # Usage: nohup bash scripts/delay_run.sh scripts/run_train.sh scripts/run_train_2.sh > delay.log 2>&1 &
| 5 | # Optional: override waits via env vars WAIT1 / WAIT2 (seconds). Defaults: 3 hours each.
| 6 |
| 7 | + WAIT1=${WAIT1:-10800} # 3 hours in seconds
| 8 | WAIT2=${WAIT2:-10800}
| 9 |
| 10 | SCRIPT1="${1:?usage: $0 <first_script.sh> <second_script.sh>}"
dpacman/scripts/run_eval.sh
CHANGED
@@ -14,7 +14,7 @@ if [ -z "$WANDB_API_KEY" ]; then
| 14 | export WANDB_API_KEY="$wandb_key"
| 15 | fi
| 16 |
| 17 | - CUDA_VISIBLE_DEVICES=
| 18 | hydra.run.dir="${run_dir}" \
| 19 | data_module.test_file="data_files/processed/splits/by_dna/test.csv" \
| 20 | data_module.tr_shelf_path="data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf" \
@@ -23,8 +23,10 @@ CUDA_VISIBLE_DEVICES=3 nohup python -u -m scripts.eval \
| 23 | model.glm_input_dim=256 \
| 24 | model.compressed_dim=256 \
| 25 | model.hidden_dim=256 \
| 26 | -
| 27 | -
| 28 | > "${run_dir}/run.log" 2>&1 &
| 29 |
| 30 | echo $! > "${run_dir}/pid.txt"

| 14 | export WANDB_API_KEY="$wandb_key"
| 15 | fi
| 16 |
| 17 | + CUDA_VISIBLE_DEVICES=2 nohup python -u -m scripts.eval \
| 18 | hydra.run.dir="${run_dir}" \
| 19 | data_module.test_file="data_files/processed/splits/by_dna/test.csv" \
| 20 | data_module.tr_shelf_path="data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf" \
| 23 | model.glm_input_dim=256 \
| 24 | model.compressed_dim=256 \
| 25 | model.hidden_dim=256 \
| 26 | + data_module.score_col="binary_scores" \
| 27 | + data_module.norm_value=1 \
| 28 | + model.loss_type="binary" \
| 29 | + ckpt_path="/home/a03-svincoff/DPACMAN/logs/train/classifier/runs/2025-08-28_04-37-58-stoic-snowball-99/checkpoints/epoch_009.ckpt" \
| 30 | > "${run_dir}/run.log" 2>&1 &
| 31 |
| 32 | echo $! > "${run_dir}/pid.txt"
dpacman/scripts/run_split.sh
CHANGED
@@ -10,13 +10,14 @@ mkdir -p "$run_dir"
| 10 |
| 11 | nohup python -u -m scripts.preprocess \
| 12 | hydra.run.dir="${run_dir}" \
| 13 | - +data_task.p_exclude="true" \
| 14 | data_task="${data_task_type}/remap" \
| 15 | data_task.split_by=dna \
| 16 | data_task.train_ratio=0.8 \
| 17 | data_task.val_ratio=0.1 \
| 18 | data_task.test_ratio=0.1 \
| 19 | - data_task.split_out_dir=dpacman/data_files/processed/splits/
| 20 | > "${run_dir}/run.log" 2>&1 &
| 21 |
| 22 | echo $! > "${run_dir}/pid.txt"

| 10 |
| 11 | nohup python -u -m scripts.preprocess \
| 12 | hydra.run.dir="${run_dir}" \
| 13 | data_task="${data_task_type}/remap" \
| 14 | + data_task.test_trs=["trseq23","trseq26","trseq17"] \
| 15 | + data_task.test_dnas=null \
| 16 | data_task.split_by=dna \
| 17 | data_task.train_ratio=0.8 \
| 18 | data_task.val_ratio=0.1 \
| 19 | data_task.test_ratio=0.1 \
| 20 | + data_task.split_out_dir=dpacman/data_files/processed/splits/handpicked_test \
| 21 | > "${run_dir}/run.log" 2>&1 &
| 22 |
| 23 | echo $! > "${run_dir}/pid.txt"
dpacman/scripts/run_train.sh
CHANGED
@@ -23,17 +23,19 @@ CUDA_VISIBLE_DEVICES=0,1 nohup python -u -m scripts.train \
| 23 | hydra.run.dir="${run_dir}" \
| 24 | trainer.devices=2 \
| 25 | trainer.max_epochs=10 \
| 26 | - data_module.train_file="data_files/processed/splits/
| 27 | - data_module.val_file="data_files/processed/splits/
| 28 | - data_module.test_file="data_files/processed/splits/
| 29 | data_module.tr_shelf_path="data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf" \
| 30 | data_module.dna_shelf_path="data_files/processed/embeddings/fimo_hits_only/peaks_caduceus.shelf" \
| 31 | data_module.batch_size=16 \
| 32 | data_module.score_col="binary_scores" \
| 33 | model.loss_type="binary" \
| 34 | model.glm_input_dim=256 \
| 35 | model.compressed_dim=256 \
| 36 | model.hidden_dim=256 \
| 37 | model.lr=1e-5 \
| 38 | > "${run_dir}/run.log" 2>&1 &
| 39 |

| 23 | hydra.run.dir="${run_dir}" \
| 24 | trainer.devices=2 \
| 25 | trainer.max_epochs=10 \
| 26 | + data_module.train_file="/home/a03-svincoff/DPACMAN/dpacman/data_files/processed/splits/handpicked_val_test_cropTR4/train.csv" \
| 27 | + data_module.val_file="/home/a03-svincoff/DPACMAN/dpacman/data_files/processed/splits/handpicked_val_test_cropTR4/val.csv" \
| 28 | + data_module.test_file="/home/a03-svincoff/DPACMAN/dpacman/data_files/processed/splits/handpicked_val_test_cropTR4/test.csv" \
| 29 | data_module.tr_shelf_path="data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf" \
| 30 | data_module.dna_shelf_path="data_files/processed/embeddings/fimo_hits_only/peaks_caduceus.shelf" \
| 31 | data_module.batch_size=16 \
| 32 | data_module.score_col="binary_scores" \
| 33 | + data_module.norm_value=1 \
| 34 | model.loss_type="binary" \
| 35 | model.glm_input_dim=256 \
| 36 | model.compressed_dim=256 \
| 37 | model.hidden_dim=256 \
| 38 | + model.dropout=0.2 \
| 39 | model.lr=1e-5 \
| 40 | > "${run_dir}/run.log" 2>&1 &
| 41 |
CHANGED
|
@@ -31,6 +31,7 @@ CUDA_VISIBLE_DEVICES=2,3 nohup python -u -m scripts.train \
|
|
| 31 |
data_module.batch_size=16 \
|
| 32 |
data_module.score_col="binary_scores" \
|
| 33 |
model.loss_type="binary" \
|
|
|
|
| 34 |
model=baseline \
|
| 35 |
model.glm_input_dim=256 \
|
| 36 |
model.compressed_dim=256 \
|
|
|
|
| 31 |
data_module.batch_size=16 \
|
| 32 |
data_module.score_col="binary_scores" \
|
| 33 |
model.loss_type="binary" \
|
| 34 |
+
data_module.norm_value=1 \
|
| 35 |
model=baseline \
|
| 36 |
model.glm_input_dim=256 \
|
| 37 |
model.compressed_dim=256 \
|