bug fixes

Files changed:
- .gitignore +2 -1
- configs/model/classifier.yaml +2 -1
- dpacman.egg-info/SOURCES.txt +39 -1
- dpacman.egg-info/top_level.txt +1 -1
- dpacman/classifier/model.py +62 -27
- dpacman/classifier/model_w_rca.py +323 -0
- dpacman/data_modules/pair.py +74 -10
- dpacman/scripts/run_train.sh +6 -2
.gitignore CHANGED

@@ -34,4 +34,5 @@ dpacman/combine_shards.py
 dpacman/combine.log
 dpacman/loss_sim.py
 dpacman/loss_temp.py
-dpacman/peak_examples/
+dpacman/peak_examples/
+dpacman/__pycache__/
configs/model/classifier.yaml CHANGED

@@ -6,4 +6,5 @@ gamma: 20
 weight_decay: 0.01
 
 glm_input_dim: 1029
-compressed_dim: 1029
+compressed_dim: 1029
+hidden_dim: 256
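
Note: the new hidden_dim key (and the re-added compressed_dim) are assumed to map one-to-one onto BindPredictor's constructor arguments in dpacman/classifier/model.py. A minimal sketch of the equivalent direct instantiation with the values above; the import path and the assumption that Hydra forwards these keys verbatim are not shown in this commit:

    # Sketch only: assumes the Hydra config is passed straight to BindPredictor.__init__
    from dpacman.classifier.model import BindPredictor

    model = BindPredictor(
        glm_input_dim=1029,   # existing key
        compressed_dim=1029,  # key re-added by this commit
        hidden_dim=256,       # key added by this commit
        weight_decay=0.01,
    )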
dpacman.egg-info/SOURCES.txt CHANGED

@@ -1,6 +1,44 @@
 README.md
 setup.py
+dpacman/__init__.py
+dpacman/combine_shards.py
+dpacman/loss_sim.py
+dpacman/loss_temp.py
+dpacman/temp.py
+dpacman/temp2.py
 dpacman.egg-info/PKG-INFO
 dpacman.egg-info/SOURCES.txt
 dpacman.egg-info/dependency_links.txt
-dpacman.egg-info/top_level.txt
+dpacman.egg-info/top_level.txt
+dpacman/classifier/__init__.py
+dpacman/classifier/loss.py
+dpacman/classifier/model.py
+dpacman/classifier/old_train.py
+dpacman/classifier/torch_model.py
+dpacman/classifier/train.py
+dpacman/classifier/model_tmp/__init__.py
+dpacman/classifier/model_tmp/clustering_data.py
+dpacman/classifier/model_tmp/compress_embeddings.py
+dpacman/classifier/model_tmp/compute_embeddings.py
+dpacman/classifier/model_tmp/extract_tf_symbols.py
+dpacman/classifier/model_tmp/make_pair_list.py
+dpacman/classifier/model_tmp/make_peak_fasta.py
+dpacman/classifier/model_tmp/model.py
+dpacman/classifier/model_tmp/prep_splits.py
+dpacman/classifier/model_tmp/train.py
+dpacman/data_modules/__init__.py
+dpacman/data_modules/pair.py
+dpacman/scripts/__init__.py
+dpacman/scripts/eval.py
+dpacman/scripts/preprocess.py
+dpacman/scripts/train.py
+dpacman/utils/__init__.py
+dpacman/utils/clustering.py
+dpacman/utils/instantiators.py
+dpacman/utils/logging_utils.py
+dpacman/utils/models.py
+dpacman/utils/plotting_utils.py
+dpacman/utils/pylogger.py
+dpacman/utils/rich_utils.py
+dpacman/utils/splitting.py
+dpacman/utils/utils.py
dpacman.egg-info/top_level.txt CHANGED

@@ -1 +1 @@
-
+dpacman
dpacman/classifier/model.py CHANGED

@@ -28,59 +28,93 @@ class LocalCNN(nn.Module):
 
 
 class CrossModalBlock(nn.Module):
-    def __init__(self, dim: int = 256, heads: int = 8):
+    def __init__(self, dim: int = 256, heads: int = 8, dropout: float = 0.0):
         super().__init__()
         # self-attention for both sides
         self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True)
         self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True)
+        # first layer norms
         self.ln_b1 = nn.LayerNorm(dim)
         self.ln_g1 = nn.LayerNorm(dim)
-
-        self.ffn_b = nn.Sequential(
+        # first feed forward networks
+        self.ffn_b1 = nn.Sequential(
             nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
         )
-        self.ffn_g = nn.Sequential(
+        self.ffn_g1 = nn.Sequential(
             nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
         )
         self.ln_b2 = nn.LayerNorm(dim)
         self.ln_g2 = nn.LayerNorm(dim)
+
+        # 2) reciprocal cross-attn: g<-b and b<-g
+        # DNA/GLM updated by attending to Binder
+        self.cross_g2b_1_RCA = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
+        self.ln_g3_RCA = nn.LayerNorm(dim)
+        self.ffn_g2_RCA = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
+        self.ln_g4_RCA = nn.LayerNorm(dim)
+
+        # Binder updated by attending to DNA/GLM
+        self.cross_b2g_1_RCA = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
+        self.ln_b3_RCA = nn.LayerNorm(dim)
+        self.ffn_b2_RCA = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
+        self.ln_b4_RCA = nn.LayerNorm(dim)
 
         # cross attention (binder queries, glm keys/values)
         # so the NDA path is updated by the transcriptoin factors
-        self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
-        self.ln_c1 = nn.LayerNorm(dim)
-        self.ffn_c = nn.Sequential(
+        self.cross_g2b_2 = nn.MultiheadAttention(dim, heads, batch_first=True)
+        self.ln_g5 = nn.LayerNorm(dim)
+        self.ffn_g3 = nn.Sequential(
             nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
         )
-        self.ln_c2 = nn.LayerNorm(dim)
+        self.ln_g6 = nn.LayerNorm(dim)
 
-    def forward(self, binder: torch.Tensor, glm: torch.Tensor):
+    def forward(self, binder: torch.Tensor, glm: torch.Tensor, binder_kpm_mask=None, glm_kpm_mask=None):
         """
         binder: (batch, Lb, dim)
         glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
         returns: updated binder representation (batch, Lb, dim)
         """
+        # 1) Self-attentino and feed-forward networks for binder and DNA
         # binder: self-attn + ffn
         b = binder
-        b_sa, _ = self.sa_binder(b, b, b)
+        b_sa, _ = self.sa_binder(b, b, b, key_padding_mask=binder_kpm_mask)
         b = self.ln_b1(b + b_sa)
-        b_ff = self.ffn_b(b)
+        b_ff = self.ffn_b1(b)
        b = self.ln_b2(b + b_ff)
 
         # glm: self-attn + ffn
         g = glm
-        g_sa, _ = self.sa_glm(g, g, g)
+        g_sa, _ = self.sa_glm(g, g, g, key_padding_mask=glm_kpm_mask)
         g = self.ln_g1(g + g_sa)
-        g_ff = self.ffn_g(g)
+        g_ff = self.ffn_g1(g)
         g = self.ln_g2(g + g_ff)
+
+        # 2a) Reciprocal Cross-Attention:
+        # DNA updated by attending to Binder (Q=g, K=b, V=b)
+        # Binder updated by attending to DNA (Q=b, K=g, V=g)
+        g_ca, _ = self.cross_g2b_1_RCA(
+            g, b, b, key_padding_mask=binder_kpm_mask
+            # torch MultiheadAttention expects key_padding_mask=True for PADs;
+            # invert if your mask is True=keep:
+            # key_padding_mask=(~binder_mask.bool()) if binder_mask is not None else None
+        )
+        g = self.ln_g3_RCA(g + g_ca)
+        g = self.ln_g4_RCA(g + self.ffn_g2_RCA(g))
 
-        # cross-attention: glm queries binder and glm embeddings are updated
-        g_to_b_ca, _ = self.cross_attn(g, b, b)
-        g = self.ln_c1(g + g_to_b_ca)
-        g_ff = self.ffn_c(g)
-        g = self.ln_c2(g + g_ff)
-        return g  # (batch, Lb, dim)
+        # 2b) Binder updated by attending to DNA/GLM (Q=b, K=g, V=g)
+        b_ca, _ = self.cross_b2g_1_RCA(
+            b, g, g, key_padding_mask=glm_kpm_mask
+            # key_padding_mask=(~glm_mask.bool()) if glm_mask is not None else None
+        )
+        b = self.ln_b3_RCA(b + b_ca)
+        b = self.ln_b4_RCA(b + self.ffn_b2_RCA(b))
 
+        # cross-attention: glm queries binder and glm embeddings are updated
+        g_to_b_ca, _ = self.cross_g2b_2(g, b, b, key_padding_mask=binder_kpm_mask)
+        g = self.ln_g5(g + g_to_b_ca)
+        g_ff = self.ffn_g3(g)
+        g = self.ln_g6(g + g_ff)
+        return b, g  # (batch, Lb, dim)
 
 
 class DimCompressor(nn.Module):
     """
@@ -144,7 +178,7 @@ class BindPredictor(LightningModule):
         # self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())  # OLD: returned probabilities
         self.head = nn.Linear(hidden_dim, 1)  # NEW: return logits (safe for AMP)
 
-    def forward(self, binder_emb, glm_emb):
+    def forward(self, binder_emb, glm_emb, binder_mask, glm_mask):
         """
         binder_emb: (B, Lb, binder_input_dim)
         glm_emb: (B, Lg, glm_input_dim)
@@ -161,14 +195,15 @@ class BindPredictor(LightningModule):
 
         # Cross-modal blocks: update binder states using GLM
         for layer in self.layers:
-            g = layer(b, g)  # (B, Lb, hidden_dim)
+            b, g = layer(b, g, binder_mask, glm_mask)  # (B, Lb, hidden_dim)
 
         # Predict per-nucleotide logits on the GLM tokens:
         # return self.head(g).squeeze(-1)  # OLD: probabilities (with Sigmoid in head)
-        return self.head(g).squeeze(
+        logits = self.head(g).squeeze(
             -1
-        )  # NEW: logits (apply sigmoid only in loss/metrics)
-
+        )
+        return logits
+
     # ----- Lightning hooks -----
     def training_step(self, batch, batch_idx):
         """
@@ -184,7 +219,7 @@ class BindPredictor(LightningModule):
             "dna_sequence"
         }
         """
-        logits = self.forward(batch["binder_emb"], batch["glm_emb"])
+        logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
         loss = calculate_loss(
             logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
         )
@@ -199,7 +234,7 @@ class BindPredictor(LightningModule):
         return loss
 
     def validation_step(self, batch, batch_idx):
-        logits = self.forward(batch["binder_emb"], batch["glm_emb"])
+        logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
         loss = calculate_loss(
             logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
         )
@@ -214,7 +249,7 @@ class BindPredictor(LightningModule):
         return loss
 
     def test_step(self, batch, batch_idx):
-        logits = self.forward(batch["binder_emb"], batch["glm_emb"])
+        logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
        loss = calculate_loss(
             logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
         )
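
Note on the new *_kpm arguments: torch.nn.MultiheadAttention treats key_padding_mask entries that are True as padding to ignore, which is the opposite of the collator's binder_mask/glm_mask (True = keep). A minimal standalone sketch of that convention, unrelated to the repo's data:

    # Sketch of the key_padding_mask convention assumed above (True = PAD).
    import torch
    from torch import nn

    mha = nn.MultiheadAttention(embed_dim=256, num_heads=8, batch_first=True)
    x = torch.randn(2, 10, 256)                    # (B, L, D)
    keep = torch.ones(2, 10, dtype=torch.bool)     # True = real token
    keep[:, 7:] = False                            # last three positions are padding
    out, _ = mha(x, x, x, key_padding_mask=~keep)  # MHA wants True = PAD, hence the inversion

This is why pair.py (below) builds binder_kpm = ~binder_mask and glm_kpm = ~glm_mask before handing the batch to the model.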
dpacman/classifier/model_w_rca.py ADDED

@@ -0,0 +1,323 @@
+"""
+Lightning Module for the binding model.
+"""
+
+import torch
+from torch import nn
+from lightning import LightningModule
+from dpacman.utils.models import set_seed
+from .loss import calculate_loss
+
+set_seed()
+
+
+class LocalCNN(nn.Module):
+    def __init__(self, dim: int = 256, kernel_size: int = 3):
+        super().__init__()
+        padding = kernel_size // 2
+        self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding)
+        self.act = nn.GELU()
+        self.ln = nn.LayerNorm(dim)
+
+    def forward(self, x: torch.Tensor):
+        # x: (batch, L, dim)
+        out = self.conv(x.transpose(1, 2))  # → (batch, dim, L)
+        out = self.act(out)
+        out = out.transpose(1, 2)  # → (batch, L, dim)
+        return self.ln(out + x)  # residual
+
+
+# class CrossModalBlock(nn.Module):
+#     def __init__(self, dim: int = 256, heads: int = 8):
+#         super().__init__()
+#         # self-attention for both sides
+#         self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True)
+#         self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True)
+#         self.ln_b1 = nn.LayerNorm(dim)
+#         self.ln_g1 = nn.LayerNorm(dim)
+
+#         self.ffn_b = nn.Sequential(
+#             nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
+#         )
+#         self.ffn_g = nn.Sequential(
+#             nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
+#         )
+#         self.ln_b2 = nn.LayerNorm(dim)
+#         self.ln_g2 = nn.LayerNorm(dim)
+
+#         # cross attention (binder queries, glm keys/values)
+#         # so the NDA path is updated by the transcriptoin factors
+#         self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
+#         self.ln_c1 = nn.LayerNorm(dim)
+#         self.ffn_c = nn.Sequential(
+#             nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
+#         )
+#         self.ln_c2 = nn.LayerNorm(dim)
+
+#     def forward(self, binder: torch.Tensor, glm: torch.Tensor):
+#         """
+#         binder: (batch, Lb, dim)
+#         glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
+#         returns: updated binder representation (batch, Lb, dim)
+#         """
+#         # binder: self-attn + ffn
+#         b = binder
+#         b_sa, _ = self.sa_binder(b, b, b)
+#         b = self.ln_b1(b + b_sa)
+#         b_ff = self.ffn_b(b)
+#         b = self.ln_b2(b + b_ff)
+
+#         # glm: self-attn + ffn
+#         g = glm
+#         g_sa, _ = self.sa_glm(g, g, g)
+#         g = self.ln_g1(g + g_sa)
+#         g_ff = self.ffn_g(g)
+#         g = self.ln_g2(g + g_ff)
+
+#         # cross-attention: glm queries binder and glm embeddings are updated
+#         g_to_b_ca, _ = self.cross_attn(g, b, b)
+#         g = self.ln_c1(g + g_to_b_ca)
+#         g_ff = self.ffn_c(g)
+#         g = self.ln_c2(g + g_ff)
+#         return g  # (batch, Lb, dim)
+
+class CrossModalBlock(nn.Module):
+    def __init__(self, dim: int = 256, heads: int = 8, dropout: float = 0.0):
+        super().__init__()
+        # 1) self-attn on each stream
+        self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
+        self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
+        self.ln_b1 = nn.LayerNorm(dim)
+        self.ln_g1 = nn.LayerNorm(dim)
+        self.ffn_b = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
+        self.ffn_g = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
+        self.ln_b2 = nn.LayerNorm(dim)
+        self.ln_g2 = nn.LayerNorm(dim)
+
+        # 2) reciprocal cross-attn: g<-b and b<-g
+        # DNA/GLM updated by attending to Binder
+        self.cross_g2b = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
+        self.ln_g_ca1 = nn.LayerNorm(dim)
+        self.ffn_g_ca = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
+        self.ln_g_ca2 = nn.LayerNorm(dim)
+
+        # Binder updated by attending to DNA/GLM
+        self.cross_b2g = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
+        self.ln_b_ca1 = nn.LayerNorm(dim)
+        self.ffn_b_ca = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
+        self.ln_b_ca2 = nn.LayerNorm(dim)
+
+    def forward(
+        self,
+        binder: torch.Tensor,  # (B, Lb, D)
+        glm: torch.Tensor,  # (B, Lg, D)
+        binder_mask: torch.Tensor | None = None,  # (B, Lb) True = keep
+        glm_mask: torch.Tensor | None = None,  # (B, Lg) True = keep
+    ):
+        # 1) self-attn+FFN on each stream
+        b, g = binder, glm
+
+        b_sa, _ = self.sa_binder(b, b, b, key_padding_mask=None)
+        b = self.ln_b1(b + b_sa)
+        b = self.ln_b2(b + self.ffn_b(b))
+
+        g_sa, _ = self.sa_glm(g, g, g, key_padding_mask=None)
+        g = self.ln_g1(g + g_sa)
+        g = self.ln_g2(g + self.ffn_g(g))
+
+        # 2a) DNA/GLM updated by attending to Binder (Q=g, K=b, V=b)
+        g_ca, _ = self.cross_g2b(
+            g, b, b,
+            # torch MultiheadAttention expects key_padding_mask=True for PADs;
+            # invert if your mask is True=keep:
+            # key_padding_mask=(~binder_mask.bool()) if binder_mask is not None else None
+        )
+        g = self.ln_g_ca1(g + g_ca)
+        g = self.ln_g_ca2(g + self.ffn_g_ca(g))
+
+        # 2b) Binder updated by attending to DNA/GLM (Q=b, K=g, V=g)
+        b_ca, _ = self.cross_b2g(
+            b, g, g,
+            # key_padding_mask=(~glm_mask.bool()) if glm_mask is not None else None
+        )
+        b = self.ln_b_ca1(b + b_ca)
+        b = self.ln_b_ca2(b + self.ffn_b_ca(b))
+
+        return b, g
+
+
+
+class DimCompressor(nn.Module):
+    """
+    Learnable per-token compressor: maps any in_dim >= out_dim to out_dim (default 256).
+    If in_dim == out_dim, behaves as identity.
+    """
+
+    def __init__(self, in_dim: int, out_dim: int = 256):
+        super().__init__()
+        if in_dim == out_dim:
+            self.net = nn.Identity()
+        else:
+            hidden = max(out_dim * 2, (in_dim + out_dim) // 2)
+            self.net = nn.Sequential(
+                nn.LayerNorm(in_dim),
+                nn.Linear(in_dim, hidden),
+                nn.GELU(),
+                nn.Linear(hidden, out_dim),
+            )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (B, L, in_dim)
+        return self.net(x)
+
+
+class BindPredictor(LightningModule):
+    def __init__(
+        self,
+        # input_dim: int = 256,  # OLD: single input dim
+        binder_input_dim: int = 1280,  # NEW: TF (binder) original dim (e.g., 1280)
+        glm_input_dim: int = 256,  # NEW: DNA/GLM original dim (e.g., 256)
+        compressed_dim: int = 256,  # NEW: learnable compressed dim
+        hidden_dim: int = 256,
+        heads: int = 8,
+        num_layers: int = 4,
+        lr: float = 1e-4,
+        alpha: float = 20,
+        gamma: float = 20,
+        use_local_cnn_on_glm: bool = True,
+        weight_decay: float = 0.01,
+    ):
+        # Init
+        super(BindPredictor, self).__init__()
+        self.save_hyperparameters()
+
+        # Learnable compressor for binder -> 256, then project to hidden
+        self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim)
+        self.proj_binder = nn.Linear(compressed_dim, hidden_dim)
+
+        # GLM side stays 256 -> hidden
+        self.proj_glm = nn.Linear(glm_input_dim, hidden_dim)
+
+        self.use_local_cnn = use_local_cnn_on_glm
+        self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity()
+
+        self.layers = nn.ModuleList(
+            [CrossModalBlock(hidden_dim, heads) for _ in range(num_layers)]
+        )
+
+        self.ln_out = nn.LayerNorm(hidden_dim)
+        # self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())  # OLD: returned probabilities
+        self.head = nn.Linear(hidden_dim, 1)  # NEW: return logits (safe for AMP)
+
+    def forward(self, binder_emb, glm_emb):
+        """
+        binder_emb: (B, Lb, binder_input_dim)
+        glm_emb: (B, Lg, glm_input_dim)
+        Returns per-nucleotide logits for the GLM sequence: (B, Lg)
+        """
+        # Binder: learnable compression → 256 → hidden
+        b = self.binder_compress(binder_emb)  # (B, Lb, 256)
+        b = self.proj_binder(b)  # (B, Lb, hidden_dim)
+
+        # GLM: project → hidden, add local CNN context
+        g = self.proj_glm(glm_emb)  # (B, Lg, hidden_dim)
+        if self.use_local_cnn:
+            g = self.local_cnn(g)
+
+        # Cross-modal blocks: update binder states using GLM
+        for layer in self.layers:
+            b, g = layer(b, g)  # (B, Lb, hidden_dim)
+
+        # Predict per-nucleotide logits on the GLM tokens:
+        # return self.head(g).squeeze(-1)  # OLD: probabilities (with Sigmoid in head)
+        return self.head(g).squeeze(
+            -1
+        )  # NEW: logits (apply sigmoid only in loss/metrics)
+
+    # ----- Lightning hooks -----
+    def training_step(self, batch, batch_idx):
+        """
+        Training step taken by PyTorch-Lightning trainer. Uses batch returned by data collator.
+        Colator returns a dictionary with:
+            "binder_emb"  # [B, Lb_max, Db]
+            "binder_mask"  # [B, Lb_max]
+            "glm_emb"  # [B, Lg_max, Dg]
+            "glm_mask"  # [B, Lg_max]
+            "labels"  # [B, Lg_max]
+            "ID"
+            "tr_sequence"
+            "dna_sequence"
+        }
+        """
+        logits = self.forward(batch["binder_emb"], batch["glm_emb"])
+        loss = calculate_loss(
+            logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
+        )
+        self.log(
+            "train/loss",
+            loss,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            batch_size=logits.size(0),
+        )
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        logits = self.forward(batch["binder_emb"], batch["glm_emb"])
+        loss = calculate_loss(
+            logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
+        )
+        self.log(
+            "val/loss",
+            loss,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=True,
+            batch_size=logits.size(0),
+        )
+        return loss
+
+    def test_step(self, batch, batch_idx):
+        logits = self.forward(batch["binder_emb"], batch["glm_emb"])
+        loss = calculate_loss(
+            logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
+        )
+        self.log(
+            "test/loss", loss, on_step=False, on_epoch=True, batch_size=logits.size(0)
+        )
+        return loss
+
+    def on_train_epoch_end(self):
+        if False:
+            if self.train_auc.compute() is not None:
+                self.log("train/auroc", self.train_auc.compute(), prog_bar=True)
+                self.train_auc.reset()
+
+    def on_validation_epoch_end(self):
+        if False:
+            if self.val_auc.compute() is not None:
+                self.log("val/auroc", self.val_auc.compute(), prog_bar=True)
+                self.val_auc.reset()
+
+    def on_test_epoch_end(self):
+        if False:
+            if self.test_auc.compute() is not None:
+                self.log("test/auroc", self.test_auc.compute(), prog_bar=True)
+                self.test_auc.reset()
+
+    def configure_optimizers(self):
+        # AdamW + cosine as a sensible default
+        opt = torch.optim.AdamW(
+            self.parameters(),
+            lr=self.hparams.lr,
+            weight_decay=self.hparams.weight_decay,
+        )
+        # Scheduler optional—comment out if you prefer fixed LR
+        sch = torch.optim.lr_scheduler.CosineAnnealingLR(
+            opt, T_max=max(self.trainer.max_epochs, 1)
+        )
+        return {
+            "optimizer": opt,
+            "lr_scheduler": {"scheduler": sch, "interval": "epoch"},
+        }
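
Note: unlike the old block (kept above as comments), the RCA block updates and returns both streams. A rough shape check, assuming the package is importable; the tensors are random and only illustrate the expected shapes:

    # Smoke test sketch for the reciprocal cross-attention block.
    import torch
    from dpacman.classifier.model_w_rca import CrossModalBlock

    block = CrossModalBlock(dim=256, heads=8)
    binder = torch.randn(2, 12, 256)   # (B, Lb, D)
    glm = torch.randn(2, 300, 256)     # (B, Lg, D)
    b, g = block(binder, glm)          # both streams come back updated
    assert b.shape == binder.shape and g.shape == glm.shape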
dpacman/data_modules/pair.py CHANGED

@@ -2,7 +2,8 @@
 import argparse
 import numpy as np
 import torch
-from torch.utils.data import Dataset, DataLoader, Sampler
+from torch.utils.data import Dataset, DataLoader, Sampler, BatchSampler
+import torch.distributed as dist
 from lightning import LightningDataModule
 from pathlib import Path
 from multiprocessing import cpu_count
@@ -14,11 +15,66 @@ from typing import List, Iterable, Sequence
 import sys
 import rootutils
 import logging
+import math
 from dpacman.utils import pylogger
 
 root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 logger = pylogger.RankedLogger(__name__, rank_zero_only=True)
 
+class PreBatchedDistributedBatchSampler(BatchSampler):
+    """
+    Accepts a precomputed list of batches (list[list[int]]) and shards them across DDP ranks.
+    - shuffle_batch_order: shuffle order of batches each epoch (deterministic via set_epoch)
+    - drop_last: drop remainder so each rank gets same #steps
+    """
+    def __init__(self, batches, shuffle_batch_order=False, drop_last=False, seed: int = 0):
+        # super expects attributes batch_size, drop_last, sampler – but we don't need them.
+        # We only need to subclass BatchSampler to satisfy Lightning's check.
+        self.batches = [list(b) for b in batches]
+        self.shuffle = shuffle_batch_order
+        self.drop_last = drop_last
+        self.seed = int(seed)
+        self.epoch = 0
+
+        if dist.is_available() and dist.is_initialized():
+            self.world_size = dist.get_world_size()
+            self.rank = dist.get_rank()
+        else:
+            self.world_size = 1
+            self.rank = 0
+
+    def __iter__(self):
+        n_batches = len(self.batches)
+        order = list(range(n_batches))
+
+        if self.shuffle:
+            g = torch.Generator()
+            g.manual_seed(self.seed + self.epoch)
+            order = torch.randperm(n_batches, generator=g).tolist()
+
+        # make divisible across ranks
+        if self.drop_last:
+            total = (len(order) // self.world_size) * self.world_size
+            order = order[:total]
+        else:
+            pad = (-len(order)) % self.world_size
+            if pad:
+                order = order + order[:pad]
+
+        # shard by rank
+        for i in order[self.rank::self.world_size]:
+            yield self.batches[i]
+
+    def __len__(self):
+        n = len(self.batches)
+        if self.drop_last:
+            return (n // self.world_size)
+        return math.ceil(n / self.world_size)
+
+    # Lightning will call this if present via its epoch hooks
+    def set_epoch(self, epoch: int):
+        self.epoch = int(epoch)
+
 class PreBatchedSampler(Sampler[List[int]]):
     """
     Yields precomputed batches of indices, e.g. [[3,7,9], [0,1,2], ...].
@@ -211,9 +267,11 @@ class PairDataModule(LightningDataModule):
                 batch_size=self.batch_size,
                 drop_last=self.drop_last,
             )
-            self.train_batch_sampler = PreBatchedSampler(
+            self.train_batch_sampler = PreBatchedDistributedBatchSampler(
                 self.train_batches,
                 shuffle_batch_order=self.shuffle_batch_order,
+                drop_last=self.drop_last,
+                seed=0,
             )
 
         if not hasattr(self, "val_dataset"):
@@ -225,8 +283,8 @@ class PairDataModule(LightningDataModule):
                 batch_size=self.batch_size,
                 drop_last=False,
             )
-            self.val_batch_sampler = PreBatchedSampler(
-                self.val_batches, shuffle_batch_order=False
+            self.val_batch_sampler = PreBatchedDistributedBatchSampler(
+                self.val_batches, shuffle_batch_order=False, drop_last=False, seed=0
             )
 
         # VALIDATE called standalone: ensure val is built
@@ -240,10 +298,10 @@ class PairDataModule(LightningDataModule):
                 batch_size=self.batch_size,
                 drop_last=False,
            )
-            self.val_batch_sampler = PreBatchedSampler(
-                self.test_batches, shuffle_batch_order=False
+            self.val_batch_sampler = PreBatchedDistributedBatchSampler(
+                self.test_batches, shuffle_batch_order=False, drop_last=False, seed=0
             )
-
+
         # TEST phase
         if stage in (None, "test"):
             if not hasattr(self, "test_dataset"):
@@ -393,19 +451,23 @@ class ShelfCollator:
             1
         )  # [B, Lg_max]
 
+        # True = PAD (what MHA expects)
+        binder_kpm = ~binder_mask
+        glm_kpm = ~glm_mask
+
         # 3) Collate labels for DNA and pad
         labels_list = [torch.tensor(s, dtype=torch.float32) for s in scores_list]
         labels = pad_sequence(
-            labels_list, batch_first=True, padding_value=
+            labels_list, batch_first=True, padding_value=self.pad_value
         )  # [B, Lg_max]
-        # (Optional) ensure labels are zeroed beyond mask:
-        labels = labels * glm_mask.to(labels.dtype)
 
         return {
            "binder_emb": binder_emb,  # [B, Lb_max, Db]
             "binder_mask": binder_mask,  # [B, Lb_max]
+            "binder_kpm": binder_kpm.bool(),  # True = PAD ← pass to MHA
             "glm_emb": glm_emb,  # [B, Lg_max, Dg]
             "glm_mask": glm_mask,  # [B, Lg_max]
+            "glm_kpm": glm_kpm.bool(),  # True = PAD ← pass to MHA
             "labels": labels,  # [B, Lg_max]
             "ID": ids,
             "tr_sequence": tr_seqs,
@@ -451,9 +513,11 @@ def _peek_batches(dl, n_batches: int = 2, tag: str = "train"):
 
         logger.info(f"\n[{tag}] batch {i+1}")
         logger.info(f"  binder_emb: {tuple(be.shape)} dtype={be.dtype}")
+        logger.info(f"  binder_emb: {tuple(bm.shape)} dtype={bm.dtype}")
         logger.info(f"  binder_mask true count: {bm.sum().item()} / {bm.numel()}")
         logger.info(f"  glm_emb: {tuple(ge.shape)} dtype={ge.dtype}")
         logger.info(f"  glm_mask true count: {gm.sum().item()} / {gm.numel()}")
+        logger.info(f"  glm_mask: {tuple(gm.shape)} dtype={gm.dtype}")
         logger.info(
             f"  labels: {tuple(y.shape)} min={y.min().item():.4f} max={y.max().item():.4f}"
         )
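
Note: PreBatchedDistributedBatchSampler only shards the precomputed batch lists; it does not build them. A standalone sketch of the sharding behaviour on a single process (world_size falls back to 1; the index lists are made up for illustration):

    from dpacman.data_modules.pair import PreBatchedDistributedBatchSampler

    batches = [[0, 1, 2], [3, 4], [5, 6, 7], [8, 9]]
    sampler = PreBatchedDistributedBatchSampler(batches, shuffle_batch_order=True, seed=0)
    sampler.set_epoch(0)      # deterministic reshuffle per epoch
    print(list(sampler))      # this rank's share of the batches; all four on one process
    print(len(sampler))       # ceil(4 / world_size) -> 4 here

Each DataLoader is presumably handed this object via batch_sampler=..., with Lightning's own distributed-sampler wrapping disabled so the shards are not wrapped a second time (see run_train.sh below).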
dpacman/scripts/run_train.sh CHANGED

@@ -14,16 +14,20 @@ if [ -z "$WANDB_API_KEY" ]; then
     export WANDB_API_KEY="$wandb_key"
 fi
 
-nohup python -u -m scripts.train \
+CUDA_VISIBLE_DEVICES=0,1 nohup python -u -m scripts.train \
+    +trainer.strategy=ddp \
+    +trainer.use_distributed_sampler="false" \
     hydra.run.dir="${run_dir}" \
+    trainer.devices=2 \
     data_module.train_file="data_files/processed/splits/by_dna/train.csv" \
     data_module.val_file="data_files/processed/splits/by_dna/val.csv" \
     data_module.test_file="data_files/processed/splits/by_dna/test.csv" \
     data_module.tr_shelf_path="data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf" \
     data_module.dna_shelf_path="data_files/processed/embeddings/fimo_hits_only/peaks_caduceus.shelf" \
     model.glm_input_dim=256 \
+    model.compressed_dim=256 \
+    model.hidden_dim=256 \
     model.lr=1e-5 \
-    model.compressed_dim=1029 \
     > "${run_dir}/run.log" 2>&1 &
 
 echo $! > "${run_dir}/pid.txt"
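
Note: the two new trainer overrides pair with the sampler change above; use_distributed_sampler=false keeps Lightning from re-wrapping the custom batch sampler, since PreBatchedDistributedBatchSampler already shards batches across the two DDP ranks. A sketch of the equivalent Trainer construction (the repo's train script builds this via Hydra, not shown here):

    from lightning import Trainer

    trainer = Trainer(strategy="ddp", devices=2, use_distributed_sampler=False)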