baseline compare
Browse files- .gitignore +4 -1
- configs/model/baseline.yaml +10 -0
- configs/model/pooling/truncatedsvd.yaml +0 -7
- dpacman/classifier/{model_w_rca.py → baseline.py} +106 -211
- dpacman/classifier/loss.py +73 -40
- dpacman/classifier/model.py +42 -5
- dpacman/classifier/model_tmp/__init__.py +0 -0
- dpacman/classifier/model_tmp/clustering_data.py +0 -475
- dpacman/classifier/model_tmp/compress_embeddings.py +0 -62
- dpacman/classifier/model_tmp/compute_embeddings.py +0 -612
- dpacman/classifier/model_tmp/extract_tf_symbols.py +0 -30
- dpacman/classifier/model_tmp/make_pair_list.py +0 -282
- dpacman/classifier/model_tmp/make_peak_fasta.py +0 -15
- dpacman/classifier/model_tmp/model.py +0 -111
- dpacman/classifier/model_tmp/prep_splits.py +0 -157
- dpacman/classifier/model_tmp/train.py +0 -217
- dpacman/classifier/old_train.py +0 -486
- dpacman/classifier/torch_model.py +0 -157
- dpacman/classifier/train.py +0 -220
- dpacman/scripts/delay_run.sh +1 -1
- dpacman/scripts/run_train.sh +3 -3
- dpacman/scripts/run_train_baseline.sh +39 -0
.gitignore
CHANGED
|
@@ -35,4 +35,7 @@ dpacman/combine.log
|
|
| 35 |
dpacman/loss_sim.py
|
| 36 |
dpacman/loss_temp.py
|
| 37 |
dpacman/peak_examples/
|
| 38 |
-
dpacman/__pycache__/
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
dpacman/loss_sim.py
|
| 36 |
dpacman/loss_temp.py
|
| 37 |
dpacman/peak_examples/
|
| 38 |
+
dpacman/__pycache__/
|
| 39 |
+
log.log
|
| 40 |
+
log2.log
|
| 41 |
+
dpacman/delay.log
|
configs/model/baseline.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: dpacman.classifier.baseline.BaselineBindPredictor
|
| 2 |
+
|
| 3 |
+
lr: 1e-4
|
| 4 |
+
alpha: 20
|
| 5 |
+
gamma: 20
|
| 6 |
+
weight_decay: 0.01
|
| 7 |
+
|
| 8 |
+
glm_input_dim: 256
|
| 9 |
+
compressed_dim: 256
|
| 10 |
+
hidden_dim: 128
|
configs/model/pooling/truncatedsvd.yaml
DELETED
|
@@ -1,7 +0,0 @@
|
|
| 1 |
-
n_components: 2
|
| 2 |
-
algorithm: randomized
|
| 3 |
-
n_iter: 5
|
| 4 |
-
n_oversamples: 10
|
| 5 |
-
poewr_iteration_normalizer: auto
|
| 6 |
-
random_state: 42
|
| 7 |
-
tol: 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dpacman/classifier/{model_w_rca.py → baseline.py}
RENAMED
|
@@ -1,177 +1,17 @@
|
|
| 1 |
"""
|
| 2 |
-
|
| 3 |
"""
|
| 4 |
|
| 5 |
-
import torch
|
| 6 |
-
from torch import nn
|
| 7 |
from lightning import LightningModule
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
class LocalCNN(nn.Module):
|
| 15 |
-
def __init__(self, dim: int = 256, kernel_size: int = 3):
|
| 16 |
-
super().__init__()
|
| 17 |
-
padding = kernel_size // 2
|
| 18 |
-
self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding)
|
| 19 |
-
self.act = nn.GELU()
|
| 20 |
-
self.ln = nn.LayerNorm(dim)
|
| 21 |
-
|
| 22 |
-
def forward(self, x: torch.Tensor):
|
| 23 |
-
# x: (batch, L, dim)
|
| 24 |
-
out = self.conv(x.transpose(1, 2)) # → (batch, dim, L)
|
| 25 |
-
out = self.act(out)
|
| 26 |
-
out = out.transpose(1, 2) # → (batch, L, dim)
|
| 27 |
-
return self.ln(out + x) # residual
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
# class CrossModalBlock(nn.Module):
|
| 31 |
-
# def __init__(self, dim: int = 256, heads: int = 8):
|
| 32 |
-
# super().__init__()
|
| 33 |
-
# # self-attention for both sides
|
| 34 |
-
# self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True)
|
| 35 |
-
# self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True)
|
| 36 |
-
# self.ln_b1 = nn.LayerNorm(dim)
|
| 37 |
-
# self.ln_g1 = nn.LayerNorm(dim)
|
| 38 |
-
|
| 39 |
-
# self.ffn_b = nn.Sequential(
|
| 40 |
-
# nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
|
| 41 |
-
# )
|
| 42 |
-
# self.ffn_g = nn.Sequential(
|
| 43 |
-
# nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
|
| 44 |
-
# )
|
| 45 |
-
# self.ln_b2 = nn.LayerNorm(dim)
|
| 46 |
-
# self.ln_g2 = nn.LayerNorm(dim)
|
| 47 |
-
|
| 48 |
-
# # cross attention (binder queries, glm keys/values)
|
| 49 |
-
# # so the NDA path is updated by the transcriptoin factors
|
| 50 |
-
# self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
|
| 51 |
-
# self.ln_c1 = nn.LayerNorm(dim)
|
| 52 |
-
# self.ffn_c = nn.Sequential(
|
| 53 |
-
# nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
|
| 54 |
-
# )
|
| 55 |
-
# self.ln_c2 = nn.LayerNorm(dim)
|
| 56 |
-
|
| 57 |
-
# def forward(self, binder: torch.Tensor, glm: torch.Tensor):
|
| 58 |
-
# """
|
| 59 |
-
# binder: (batch, Lb, dim)
|
| 60 |
-
# glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
|
| 61 |
-
# returns: updated binder representation (batch, Lb, dim)
|
| 62 |
-
# """
|
| 63 |
-
# # binder: self-attn + ffn
|
| 64 |
-
# b = binder
|
| 65 |
-
# b_sa, _ = self.sa_binder(b, b, b)
|
| 66 |
-
# b = self.ln_b1(b + b_sa)
|
| 67 |
-
# b_ff = self.ffn_b(b)
|
| 68 |
-
# b = self.ln_b2(b + b_ff)
|
| 69 |
-
|
| 70 |
-
# # glm: self-attn + ffn
|
| 71 |
-
# g = glm
|
| 72 |
-
# g_sa, _ = self.sa_glm(g, g, g)
|
| 73 |
-
# g = self.ln_g1(g + g_sa)
|
| 74 |
-
# g_ff = self.ffn_g(g)
|
| 75 |
-
# g = self.ln_g2(g + g_ff)
|
| 76 |
-
|
| 77 |
-
# # cross-attention: glm queries binder and glm embeddings are updated
|
| 78 |
-
# g_to_b_ca, _ = self.cross_attn(g, b, b)
|
| 79 |
-
# g = self.ln_c1(g + g_to_b_ca)
|
| 80 |
-
# g_ff = self.ffn_c(g)
|
| 81 |
-
# g = self.ln_c2(g + g_ff)
|
| 82 |
-
# return g # (batch, Lb, dim)
|
| 83 |
-
|
| 84 |
-
class CrossModalBlock(nn.Module):
|
| 85 |
-
def __init__(self, dim: int = 256, heads: int = 8, dropout: float = 0.0):
|
| 86 |
-
super().__init__()
|
| 87 |
-
# 1) self-attn on each stream
|
| 88 |
-
self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
|
| 89 |
-
self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
|
| 90 |
-
self.ln_b1 = nn.LayerNorm(dim)
|
| 91 |
-
self.ln_g1 = nn.LayerNorm(dim)
|
| 92 |
-
self.ffn_b = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
|
| 93 |
-
self.ffn_g = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
|
| 94 |
-
self.ln_b2 = nn.LayerNorm(dim)
|
| 95 |
-
self.ln_g2 = nn.LayerNorm(dim)
|
| 96 |
-
|
| 97 |
-
# 2) reciprocal cross-attn: g<-b and b<-g
|
| 98 |
-
# DNA/GLM updated by attending to Binder
|
| 99 |
-
self.cross_g2b = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
|
| 100 |
-
self.ln_g_ca1 = nn.LayerNorm(dim)
|
| 101 |
-
self.ffn_g_ca = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
|
| 102 |
-
self.ln_g_ca2 = nn.LayerNorm(dim)
|
| 103 |
-
|
| 104 |
-
# Binder updated by attending to DNA/GLM
|
| 105 |
-
self.cross_b2g = nn.MultiheadAttention(dim, heads, batch_first=True, dropout=dropout)
|
| 106 |
-
self.ln_b_ca1 = nn.LayerNorm(dim)
|
| 107 |
-
self.ffn_b_ca = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim))
|
| 108 |
-
self.ln_b_ca2 = nn.LayerNorm(dim)
|
| 109 |
-
|
| 110 |
-
def forward(
|
| 111 |
-
self,
|
| 112 |
-
binder: torch.Tensor, # (B, Lb, D)
|
| 113 |
-
glm: torch.Tensor, # (B, Lg, D)
|
| 114 |
-
binder_mask: torch.Tensor | None = None, # (B, Lb) True = keep
|
| 115 |
-
glm_mask: torch.Tensor | None = None, # (B, Lg) True = keep
|
| 116 |
-
):
|
| 117 |
-
# 1) self-attn+FFN on each stream
|
| 118 |
-
b, g = binder, glm
|
| 119 |
-
|
| 120 |
-
b_sa, _ = self.sa_binder(b, b, b, key_padding_mask=None)
|
| 121 |
-
b = self.ln_b1(b + b_sa)
|
| 122 |
-
b = self.ln_b2(b + self.ffn_b(b))
|
| 123 |
-
|
| 124 |
-
g_sa, _ = self.sa_glm(g, g, g, key_padding_mask=None)
|
| 125 |
-
g = self.ln_g1(g + g_sa)
|
| 126 |
-
g = self.ln_g2(g + self.ffn_g(g))
|
| 127 |
-
|
| 128 |
-
# 2a) DNA/GLM updated by attending to Binder (Q=g, K=b, V=b)
|
| 129 |
-
g_ca, _ = self.cross_g2b(
|
| 130 |
-
g, b, b,
|
| 131 |
-
# torch MultiheadAttention expects key_padding_mask=True for PADs;
|
| 132 |
-
# invert if your mask is True=keep:
|
| 133 |
-
# key_padding_mask=(~binder_mask.bool()) if binder_mask is not None else None
|
| 134 |
-
)
|
| 135 |
-
g = self.ln_g_ca1(g + g_ca)
|
| 136 |
-
g = self.ln_g_ca2(g + self.ffn_g_ca(g))
|
| 137 |
-
|
| 138 |
-
# 2b) Binder updated by attending to DNA/GLM (Q=b, K=g, V=g)
|
| 139 |
-
b_ca, _ = self.cross_b2g(
|
| 140 |
-
b, g, g,
|
| 141 |
-
# key_padding_mask=(~glm_mask.bool()) if glm_mask is not None else None
|
| 142 |
-
)
|
| 143 |
-
b = self.ln_b_ca1(b + b_ca)
|
| 144 |
-
b = self.ln_b_ca2(b + self.ffn_b_ca(b))
|
| 145 |
-
|
| 146 |
-
return b, g
|
| 147 |
-
|
| 148 |
-
|
| 149 |
|
| 150 |
-
class
|
| 151 |
"""
|
| 152 |
-
|
| 153 |
-
If in_dim == out_dim, behaves as identity.
|
| 154 |
"""
|
| 155 |
-
|
| 156 |
-
def __init__(self, in_dim: int, out_dim: int = 256):
|
| 157 |
-
super().__init__()
|
| 158 |
-
if in_dim == out_dim:
|
| 159 |
-
self.net = nn.Identity()
|
| 160 |
-
else:
|
| 161 |
-
hidden = max(out_dim * 2, (in_dim + out_dim) // 2)
|
| 162 |
-
self.net = nn.Sequential(
|
| 163 |
-
nn.LayerNorm(in_dim),
|
| 164 |
-
nn.Linear(in_dim, hidden),
|
| 165 |
-
nn.GELU(),
|
| 166 |
-
nn.Linear(hidden, out_dim),
|
| 167 |
-
)
|
| 168 |
-
|
| 169 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 170 |
-
# x: (B, L, in_dim)
|
| 171 |
-
return self.net(x)
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
class BindPredictor(LightningModule):
|
| 175 |
def __init__(
|
| 176 |
self,
|
| 177 |
# input_dim: int = 256, # OLD: single input dim
|
|
@@ -179,77 +19,64 @@ class BindPredictor(LightningModule):
|
|
| 179 |
glm_input_dim: int = 256, # NEW: DNA/GLM original dim (e.g., 256)
|
| 180 |
compressed_dim: int = 256, # NEW: learnable compressed dim
|
| 181 |
hidden_dim: int = 256,
|
| 182 |
-
heads: int = 8,
|
| 183 |
-
num_layers: int = 4,
|
| 184 |
lr: float = 1e-4,
|
| 185 |
alpha: float = 20,
|
| 186 |
gamma: float = 20,
|
| 187 |
-
|
| 188 |
weight_decay: float = 0.01,
|
| 189 |
):
|
| 190 |
# Init
|
| 191 |
-
super(
|
| 192 |
self.save_hyperparameters()
|
| 193 |
|
| 194 |
# Learnable compressor for binder -> 256, then project to hidden
|
| 195 |
self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim)
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity()
|
| 203 |
-
|
| 204 |
-
self.layers = nn.ModuleList(
|
| 205 |
-
[CrossModalBlock(hidden_dim, heads) for _ in range(num_layers)]
|
| 206 |
)
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
# self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid()) # OLD: returned probabilities
|
| 210 |
-
self.head = nn.Linear(hidden_dim, 1) # NEW: return logits (safe for AMP)
|
| 211 |
-
|
| 212 |
-
def forward(self, binder_emb, glm_emb):
|
| 213 |
"""
|
| 214 |
binder_emb: (B, Lb, binder_input_dim)
|
| 215 |
glm_emb: (B, Lg, glm_input_dim)
|
| 216 |
Returns per-nucleotide logits for the GLM sequence: (B, Lg)
|
| 217 |
"""
|
| 218 |
-
# Binder: learnable compression →
|
| 219 |
-
b = self.binder_compress(binder_emb) # (B, Lb,
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
# Predict per-nucleotide logits on the GLM tokens:
|
| 232 |
-
# return self.head(g).squeeze(-1) # OLD: probabilities (with Sigmoid in head)
|
| 233 |
-
return self.head(g).squeeze(
|
| 234 |
-1
|
| 235 |
-
)
|
| 236 |
-
|
|
|
|
| 237 |
# ----- Lightning hooks -----
|
| 238 |
def training_step(self, batch, batch_idx):
|
| 239 |
"""
|
| 240 |
Training step taken by PyTorch-Lightning trainer. Uses batch returned by data collator.
|
| 241 |
Colator returns a dictionary with:
|
| 242 |
"binder_emb" # [B, Lb_max, Db]
|
| 243 |
-
"
|
| 244 |
"glm_emb" # [B, Lg_max, Dg]
|
| 245 |
-
"
|
| 246 |
"labels" # [B, Lg_max]
|
| 247 |
"ID"
|
| 248 |
"tr_sequence"
|
| 249 |
"dna_sequence"
|
| 250 |
}
|
| 251 |
"""
|
| 252 |
-
logits = self.forward(batch["binder_emb"], batch["glm_emb"])
|
| 253 |
loss = calculate_loss(
|
| 254 |
logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
|
| 255 |
)
|
|
@@ -261,10 +88,30 @@ class BindPredictor(LightningModule):
|
|
| 261 |
prog_bar=True,
|
| 262 |
batch_size=logits.size(0),
|
| 263 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
return loss
|
| 265 |
|
| 266 |
def validation_step(self, batch, batch_idx):
|
| 267 |
-
logits = self.forward(batch["binder_emb"], batch["glm_emb"])
|
| 268 |
loss = calculate_loss(
|
| 269 |
logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
|
| 270 |
)
|
|
@@ -276,17 +123,65 @@ class BindPredictor(LightningModule):
|
|
| 276 |
prog_bar=True,
|
| 277 |
batch_size=logits.size(0),
|
| 278 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
return loss
|
| 280 |
|
| 281 |
def test_step(self, batch, batch_idx):
|
| 282 |
-
logits = self.forward(batch["binder_emb"], batch["glm_emb"])
|
| 283 |
loss = calculate_loss(
|
| 284 |
logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
|
| 285 |
)
|
| 286 |
self.log(
|
| 287 |
"test/loss", loss, on_step=False, on_epoch=True, batch_size=logits.size(0)
|
| 288 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
return loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
|
| 291 |
def on_train_epoch_end(self):
|
| 292 |
if False:
|
|
@@ -320,4 +215,4 @@ class BindPredictor(LightningModule):
|
|
| 320 |
return {
|
| 321 |
"optimizer": opt,
|
| 322 |
"lr_scheduler": {"scheduler": sch, "interval": "epoch"},
|
| 323 |
-
}
|
|
|
|
| 1 |
"""
|
| 2 |
+
Code for baseline model to compare the classifier to
|
| 3 |
"""
|
| 4 |
|
|
|
|
|
|
|
| 5 |
from lightning import LightningModule
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from .loss import calculate_loss, auprc_zeros_vs_ones_from_logits, auroc_zeros_vs_ones_from_logits
|
| 9 |
+
from .model import DimCompressor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
+
class BaselineBindPredictor(LightningModule):
|
| 12 |
"""
|
| 13 |
+
Baseline predictor: simple MLP that just concatenates the embeddings and outputs per-token predictions.
|
|
|
|
| 14 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
def __init__(
|
| 16 |
self,
|
| 17 |
# input_dim: int = 256, # OLD: single input dim
|
|
|
|
| 19 |
glm_input_dim: int = 256, # NEW: DNA/GLM original dim (e.g., 256)
|
| 20 |
compressed_dim: int = 256, # NEW: learnable compressed dim
|
| 21 |
hidden_dim: int = 256,
|
|
|
|
|
|
|
| 22 |
lr: float = 1e-4,
|
| 23 |
alpha: float = 20,
|
| 24 |
gamma: float = 20,
|
| 25 |
+
dropout: float = 0,
|
| 26 |
weight_decay: float = 0.01,
|
| 27 |
):
|
| 28 |
# Init
|
| 29 |
+
super(BaselineBindPredictor, self).__init__()
|
| 30 |
self.save_hyperparameters()
|
| 31 |
|
| 32 |
# Learnable compressor for binder -> 256, then project to hidden
|
| 33 |
self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim)
|
| 34 |
+
|
| 35 |
+
self.mlp = torch.nn.Sequential(
|
| 36 |
+
torch.nn.Linear(compressed_dim, hidden_dim),
|
| 37 |
+
torch.nn.ReLU(),
|
| 38 |
+
torch.nn.Linear(hidden_dim, 1),
|
| 39 |
+
torch.nn.ReLU(),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
)
|
| 41 |
+
|
| 42 |
+
def forward(self, binder_emb, glm_emb, binder_mask, glm_mask):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
"""
|
| 44 |
binder_emb: (B, Lb, binder_input_dim)
|
| 45 |
glm_emb: (B, Lg, glm_input_dim)
|
| 46 |
Returns per-nucleotide logits for the GLM sequence: (B, Lg)
|
| 47 |
"""
|
| 48 |
+
# Binder: learnable compression → glm_input_dim
|
| 49 |
+
b = self.binder_compress(binder_emb) # (B, Lb, glm_input_dim)
|
| 50 |
+
|
| 51 |
+
# Concatenate target and binder. Concatenate on the length dimension
|
| 52 |
+
lg = glm_emb.shape[1]
|
| 53 |
+
concat_embeddings = torch.concat((glm_emb,b), dim=1) # (B, Lb + Lg, glm_input_dim)
|
| 54 |
+
|
| 55 |
+
# Run concatenated embeddings through MLP
|
| 56 |
+
logits = self.mlp(concat_embeddings) # (B, Lb + Lg, 1)
|
| 57 |
+
|
| 58 |
+
# Get only the DNA logits.
|
| 59 |
+
logits = logits[:,0:lg,:].squeeze(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
-1
|
| 61 |
+
)
|
| 62 |
+
return logits
|
| 63 |
+
|
| 64 |
# ----- Lightning hooks -----
|
| 65 |
def training_step(self, batch, batch_idx):
|
| 66 |
"""
|
| 67 |
Training step taken by PyTorch-Lightning trainer. Uses batch returned by data collator.
|
| 68 |
Colator returns a dictionary with:
|
| 69 |
"binder_emb" # [B, Lb_max, Db]
|
| 70 |
+
"binder_kpm" # [B, Lb_max]
|
| 71 |
"glm_emb" # [B, Lg_max, Dg]
|
| 72 |
+
"glm_kpm" # [B, Lg_max]
|
| 73 |
"labels" # [B, Lg_max]
|
| 74 |
"ID"
|
| 75 |
"tr_sequence"
|
| 76 |
"dna_sequence"
|
| 77 |
}
|
| 78 |
"""
|
| 79 |
+
logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
|
| 80 |
loss = calculate_loss(
|
| 81 |
logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
|
| 82 |
)
|
|
|
|
| 88 |
prog_bar=True,
|
| 89 |
batch_size=logits.size(0),
|
| 90 |
)
|
| 91 |
+
|
| 92 |
+
# ---- AUPRC and AUROC on labels in {0, >0.99} only ----
|
| 93 |
+
ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
|
| 94 |
+
logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
|
| 95 |
+
)
|
| 96 |
+
auc, n_pos, n_neg, tpr, fpr, thresolds, tp, fp = auroc_zeros_vs_ones_from_logits(
|
| 97 |
+
logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
|
| 98 |
+
)
|
| 99 |
+
# per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
|
| 100 |
+
self.log("train/auprc_0v1",
|
| 101 |
+
ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
|
| 102 |
+
on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
|
| 103 |
+
self.log("train/auroc_0v1",
|
| 104 |
+
auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
|
| 105 |
+
on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
|
| 106 |
+
|
| 107 |
+
# (optional) also log class counts so you can sanity-check balance
|
| 108 |
+
self.log("train/n_pos_0v1", float(n_pos), on_step=False, on_epoch=True, sync_dist=True)
|
| 109 |
+
self.log("train/n_neg_0v1", float(n_neg), on_step=False, on_epoch=True, sync_dist=True)
|
| 110 |
+
|
| 111 |
return loss
|
| 112 |
|
| 113 |
def validation_step(self, batch, batch_idx):
|
| 114 |
+
logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
|
| 115 |
loss = calculate_loss(
|
| 116 |
logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
|
| 117 |
)
|
|
|
|
| 123 |
prog_bar=True,
|
| 124 |
batch_size=logits.size(0),
|
| 125 |
)
|
| 126 |
+
|
| 127 |
+
# ---- AUPRC and AUROC on labels in {0, >0.99} only ----
|
| 128 |
+
ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
|
| 129 |
+
logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
|
| 130 |
+
)
|
| 131 |
+
auc, n_pos, n_neg, tpr, fpr, thresolds, tp, fp = auroc_zeros_vs_ones_from_logits(
|
| 132 |
+
logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
|
| 133 |
+
)
|
| 134 |
+
# per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
|
| 135 |
+
self.log("val/auprc_0v1",
|
| 136 |
+
ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
|
| 137 |
+
on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
|
| 138 |
+
self.log("val/auroc_0v1",
|
| 139 |
+
auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
|
| 140 |
+
on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
|
| 141 |
return loss
|
| 142 |
|
| 143 |
def test_step(self, batch, batch_idx):
|
| 144 |
+
logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
|
| 145 |
loss = calculate_loss(
|
| 146 |
logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
|
| 147 |
)
|
| 148 |
self.log(
|
| 149 |
"test/loss", loss, on_step=False, on_epoch=True, batch_size=logits.size(0)
|
| 150 |
)
|
| 151 |
+
|
| 152 |
+
# ---- AUPRC and AUROC on labels in {0, >0.99} only ----
|
| 153 |
+
ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
|
| 154 |
+
logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
|
| 155 |
+
)
|
| 156 |
+
auc, n_pos, n_neg, tpr, fpr, thresolds, tp, fp = auroc_zeros_vs_ones_from_logits(
|
| 157 |
+
logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
|
| 158 |
+
)
|
| 159 |
+
# per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
|
| 160 |
+
self.log("test/auprc_0v1",
|
| 161 |
+
ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
|
| 162 |
+
on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
|
| 163 |
+
self.log("test/auroc_0v1",
|
| 164 |
+
auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
|
| 165 |
+
on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
|
| 166 |
return loss
|
| 167 |
+
|
| 168 |
+
def on_before_optimizer_step(self, optimizer):
|
| 169 |
+
# Compute global L2 norm of all parameter gradients (ignores None grads)
|
| 170 |
+
grads = []
|
| 171 |
+
for p in self.parameters():
|
| 172 |
+
if p.grad is not None:
|
| 173 |
+
# .detach() avoids autograd tracking; .float() avoids fp16 overflow in norms
|
| 174 |
+
grads.append(p.grad.detach().float().norm(2))
|
| 175 |
+
if grads:
|
| 176 |
+
total_norm = torch.norm(torch.stack(grads), p=2)
|
| 177 |
+
self.log("train/grad_norm", total_norm, on_step=True, prog_bar=False, logger=True)
|
| 178 |
+
|
| 179 |
+
def on_after_backward(self):
|
| 180 |
+
grads = [p.grad.detach().float().norm(2)
|
| 181 |
+
for p in self.parameters() if p.grad is not None]
|
| 182 |
+
if grads:
|
| 183 |
+
total_norm = torch.norm(torch.stack(grads), p=2)
|
| 184 |
+
self.log("train/grad_norm_back", total_norm, on_step=True, prog_bar=False)
|
| 185 |
|
| 186 |
def on_train_epoch_end(self):
|
| 187 |
if False:
|
|
|
|
| 215 |
return {
|
| 216 |
"optimizer": opt,
|
| 217 |
"lr_scheduler": {"scheduler": sch, "interval": "epoch"},
|
| 218 |
+
}
|
dpacman/classifier/loss.py
CHANGED
|
@@ -4,6 +4,9 @@ Define loss functions needed for training the model — padding safe (-1 sentine
|
|
| 4 |
|
| 5 |
import torch
|
| 6 |
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
def _expand_like(mask: torch.Tensor, like: torch.Tensor):
|
| 9 |
# Make mask broadcastable to logits/targets (handles (B,L) vs (B,L,1))
|
|
@@ -62,59 +65,89 @@ def calculate_loss(
|
|
| 62 |
|
| 63 |
return alpha * bce_nonpeak + gamma * mse_peak
|
| 64 |
|
| 65 |
-
import torch
|
| 66 |
-
|
| 67 |
@torch.no_grad()
|
| 68 |
-
def
|
| 69 |
logits: torch.Tensor, # (B, L)
|
| 70 |
labels: torch.Tensor, # (B, L)
|
| 71 |
-
glm_kpm: torch.Tensor | None,
|
| 72 |
pos_thresh: float = 0.99,
|
| 73 |
-
)
|
| 74 |
"""
|
| 75 |
-
Returns
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
| 78 |
"""
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
-
#
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
else:
|
| 85 |
-
valid = torch.ones_like(labels, dtype=torch.bool, device=labels.device)
|
| 86 |
|
| 87 |
-
|
| 88 |
-
pos = labels > pos_thresh
|
| 89 |
-
neg = labels == 0.0
|
| 90 |
-
keep = valid & (pos | neg)
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
if keep.sum() == 0:
|
| 93 |
-
return torch.tensor(float('nan'), device=
|
|
|
|
| 94 |
|
| 95 |
-
y =
|
| 96 |
-
s =
|
| 97 |
|
| 98 |
-
n = y.numel()
|
| 99 |
n_pos = int(y.sum().item())
|
| 100 |
-
n_neg =
|
| 101 |
-
if n_pos == 0:
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
#
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
recall = tp / n_pos
|
| 113 |
-
|
| 114 |
-
# AP = sum( precision * Δrecall )
|
| 115 |
-
recall_prev = torch.cat([torch.zeros(1, device=logits.device, dtype=probs.dtype), recall[:-1]])
|
| 116 |
-
ap = (precision * (recall - recall_prev)).sum()
|
| 117 |
-
return ap, n_pos, n_neg
|
| 118 |
|
| 119 |
def accuracy_percentage(
|
| 120 |
logits,
|
|
|
|
| 4 |
|
| 5 |
import torch
|
| 6 |
import torch.nn.functional as F
|
| 7 |
+
from torchmetrics.functional.classification import (
|
| 8 |
+
auroc, average_precision, roc, precision_recall_curve
|
| 9 |
+
)
|
| 10 |
|
| 11 |
def _expand_like(mask: torch.Tensor, like: torch.Tensor):
|
| 12 |
# Make mask broadcastable to logits/targets (handles (B,L) vs (B,L,1))
|
|
|
|
| 65 |
|
| 66 |
return alpha * bce_nonpeak + gamma * mse_peak
|
| 67 |
|
|
|
|
|
|
|
| 68 |
@torch.no_grad()
|
| 69 |
+
def auroc_zeros_vs_ones_from_logits(
|
| 70 |
logits: torch.Tensor, # (B, L)
|
| 71 |
labels: torch.Tensor, # (B, L)
|
| 72 |
+
glm_kpm: torch.Tensor | None = None, # (B, L) True=PAD
|
| 73 |
pos_thresh: float = 0.99,
|
| 74 |
+
):
|
| 75 |
"""
|
| 76 |
+
Returns:
|
| 77 |
+
auc: scalar tensor (AUROC)
|
| 78 |
+
n_pos, n_neg: ints
|
| 79 |
+
tpr, fpr: tensors of shape (T,)
|
| 80 |
+
thresholds: tensor of shape (T,)
|
| 81 |
+
tp, fp: integer counts per threshold (shape (T,))
|
| 82 |
"""
|
| 83 |
+
device = logits.device
|
| 84 |
+
valid = ~glm_kpm if glm_kpm is not None else torch.ones_like(labels, dtype=torch.bool, device=device)
|
| 85 |
+
keep = valid & ((labels > pos_thresh) | (labels == 0.0))
|
| 86 |
+
if keep.sum() == 0:
|
| 87 |
+
return (torch.tensor(float('nan'), device=device), 0, 0,
|
| 88 |
+
torch.empty(0, device=device), torch.empty(0, device=device),
|
| 89 |
+
torch.empty(0, device=device), torch.empty(0, device=device), torch.empty(0, device=device))
|
| 90 |
+
|
| 91 |
+
y = (labels[keep] > pos_thresh).to(torch.int)
|
| 92 |
+
s = logits[keep]
|
| 93 |
+
|
| 94 |
+
n_pos = int(y.sum().item())
|
| 95 |
+
n_neg = y.numel() - n_pos
|
| 96 |
+
if n_pos == 0 or n_neg == 0:
|
| 97 |
+
return (torch.tensor(float('nan'), device=device), n_pos, n_neg,
|
| 98 |
+
torch.empty(0, device=device), torch.empty(0, device=device),
|
| 99 |
+
torch.empty(0, device=device), torch.empty(0, device=device), torch.empty(0, device=device))
|
| 100 |
+
|
| 101 |
+
# Full ROC curve
|
| 102 |
+
fpr, tpr, thresholds = roc(s, y, task="binary")
|
| 103 |
+
# AUROC (TM handles logits)
|
| 104 |
+
auc = auroc(s, y, task="binary")
|
| 105 |
|
| 106 |
+
# Convert rates to counts (round to nearest to avoid float off-by-one)
|
| 107 |
+
tp = (tpr * n_pos).round().to(torch.long)
|
| 108 |
+
fp = (fpr * n_neg).round().to(torch.long)
|
|
|
|
|
|
|
| 109 |
|
| 110 |
+
return auc.to(device), n_pos, n_neg, tpr.to(device), fpr.to(device), thresholds.to(device), tp.to(device), fp.to(device)
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
+
|
| 113 |
+
@torch.no_grad()
|
| 114 |
+
def auprc_zeros_vs_ones_from_logits(
|
| 115 |
+
logits: torch.Tensor, # (B, L)
|
| 116 |
+
labels: torch.Tensor, # (B, L)
|
| 117 |
+
glm_kpm: torch.Tensor | None = None, # (B, L) True=PAD
|
| 118 |
+
pos_thresh: float = 0.99,
|
| 119 |
+
):
|
| 120 |
+
"""
|
| 121 |
+
Returns:
|
| 122 |
+
ap: scalar tensor (Average Precision / AUPRC)
|
| 123 |
+
n_pos, n_neg: ints
|
| 124 |
+
precision: (T,)
|
| 125 |
+
recall: (T,)
|
| 126 |
+
thresholds: (T,)
|
| 127 |
+
"""
|
| 128 |
+
device = logits.device
|
| 129 |
+
valid = ~glm_kpm if glm_kpm is not None else torch.ones_like(labels, dtype=torch.bool, device=device)
|
| 130 |
+
keep = valid & ((labels > pos_thresh) | (labels == 0.0))
|
| 131 |
if keep.sum() == 0:
|
| 132 |
+
return (torch.tensor(float('nan'), device=device), 0, 0,
|
| 133 |
+
torch.empty(0, device=device), torch.empty(0, device=device), torch.empty(0, device=device))
|
| 134 |
|
| 135 |
+
y = (labels[keep] > pos_thresh).to(torch.int)
|
| 136 |
+
s = logits[keep]
|
| 137 |
|
|
|
|
| 138 |
n_pos = int(y.sum().item())
|
| 139 |
+
n_neg = y.numel() - n_pos
|
| 140 |
+
if n_pos == 0:
|
| 141 |
+
# By convention, AP=0 when there are no positives
|
| 142 |
+
return (torch.tensor(0.0, device=device), 0, n_neg,
|
| 143 |
+
torch.empty(0, device=device), torch.empty(0, device=device), torch.empty(0, device=device))
|
| 144 |
+
|
| 145 |
+
# Full PR curve
|
| 146 |
+
precision, recall, thresholds = precision_recall_curve(s, y, task="binary")
|
| 147 |
+
# Average Precision / AUPRC
|
| 148 |
+
ap = average_precision(s, y, task="binary")
|
| 149 |
+
|
| 150 |
+
return ap.to(device), n_pos, n_neg, precision.to(device), recall.to(device), thresholds.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
def accuracy_percentage(
|
| 153 |
logits,
|
dpacman/classifier/model.py
CHANGED
|
@@ -6,7 +6,7 @@ import torch
|
|
| 6 |
from torch import nn
|
| 7 |
from lightning import LightningModule
|
| 8 |
from dpacman.utils.models import set_seed
|
| 9 |
-
from .loss import calculate_loss, auprc_zeros_vs_ones_from_logits
|
| 10 |
|
| 11 |
set_seed()
|
| 12 |
|
|
@@ -211,9 +211,9 @@ class BindPredictor(LightningModule):
|
|
| 211 |
Training step taken by PyTorch-Lightning trainer. Uses batch returned by data collator.
|
| 212 |
Colator returns a dictionary with:
|
| 213 |
"binder_emb" # [B, Lb_max, Db]
|
| 214 |
-
"
|
| 215 |
"glm_emb" # [B, Lg_max, Dg]
|
| 216 |
-
"
|
| 217 |
"labels" # [B, Lg_max]
|
| 218 |
"ID"
|
| 219 |
"tr_sequence"
|
|
@@ -233,14 +233,20 @@ class BindPredictor(LightningModule):
|
|
| 233 |
batch_size=logits.size(0),
|
| 234 |
)
|
| 235 |
|
| 236 |
-
# ---- AUPRC on labels in {0, >0.99} only ----
|
| 237 |
-
ap, n_pos, n_neg = auprc_zeros_vs_ones_from_logits(
|
|
|
|
|
|
|
|
|
|
| 238 |
logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
|
| 239 |
)
|
| 240 |
# per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
|
| 241 |
self.log("train/auprc_0v1",
|
| 242 |
ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
|
| 243 |
on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
|
|
|
|
|
|
|
|
|
|
| 244 |
# (optional) also log class counts so you can sanity-check balance
|
| 245 |
self.log("train/n_pos_0v1", float(n_pos), on_step=False, on_epoch=True, sync_dist=True)
|
| 246 |
self.log("train/n_neg_0v1", float(n_neg), on_step=False, on_epoch=True, sync_dist=True)
|
|
@@ -260,6 +266,22 @@ class BindPredictor(LightningModule):
|
|
| 260 |
prog_bar=True,
|
| 261 |
batch_size=logits.size(0),
|
| 262 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
return loss
|
| 264 |
|
| 265 |
def test_step(self, batch, batch_idx):
|
|
@@ -270,6 +292,21 @@ class BindPredictor(LightningModule):
|
|
| 270 |
self.log(
|
| 271 |
"test/loss", loss, on_step=False, on_epoch=True, batch_size=logits.size(0)
|
| 272 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
return loss
|
| 274 |
|
| 275 |
def on_before_optimizer_step(self, optimizer):
|
|
|
|
| 6 |
from torch import nn
|
| 7 |
from lightning import LightningModule
|
| 8 |
from dpacman.utils.models import set_seed
|
| 9 |
+
from .loss import calculate_loss, auprc_zeros_vs_ones_from_logits, auroc_zeros_vs_ones_from_logits
|
| 10 |
|
| 11 |
set_seed()
|
| 12 |
|
|
|
|
| 211 |
Training step taken by PyTorch-Lightning trainer. Uses batch returned by data collator.
|
| 212 |
Colator returns a dictionary with:
|
| 213 |
"binder_emb" # [B, Lb_max, Db]
|
| 214 |
+
"binder_kpm" # [B, Lb_max]
|
| 215 |
"glm_emb" # [B, Lg_max, Dg]
|
| 216 |
+
"glm_kpm" # [B, Lg_max]
|
| 217 |
"labels" # [B, Lg_max]
|
| 218 |
"ID"
|
| 219 |
"tr_sequence"
|
|
|
|
| 233 |
batch_size=logits.size(0),
|
| 234 |
)
|
| 235 |
|
| 236 |
+
# ---- AUPRC and AUROC on labels in {0, >0.99} only ----
|
| 237 |
+
ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
|
| 238 |
+
logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
|
| 239 |
+
)
|
| 240 |
+
auc, n_pos, n_neg, tpr, fpr, thresolds, tp, fp = auroc_zeros_vs_ones_from_logits(
|
| 241 |
logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
|
| 242 |
)
|
| 243 |
# per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
|
| 244 |
self.log("train/auprc_0v1",
|
| 245 |
ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
|
| 246 |
on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
|
| 247 |
+
self.log("train/auroc_0v1",
|
| 248 |
+
auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
|
| 249 |
+
on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
|
| 250 |
# (optional) also log class counts so you can sanity-check balance
|
| 251 |
self.log("train/n_pos_0v1", float(n_pos), on_step=False, on_epoch=True, sync_dist=True)
|
| 252 |
self.log("train/n_neg_0v1", float(n_neg), on_step=False, on_epoch=True, sync_dist=True)
|
|
|
|
| 266 |
prog_bar=True,
|
| 267 |
batch_size=logits.size(0),
|
| 268 |
)
|
| 269 |
+
|
| 270 |
+
# ---- AUPRC and AUROC on labels in {0, >0.99} only ----
|
| 271 |
+
ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
|
| 272 |
+
logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
|
| 273 |
+
)
|
| 274 |
+
auc, n_pos, n_neg, tpr, fpr, thresolds, tp, fp = auroc_zeros_vs_ones_from_logits(
|
| 275 |
+
logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
|
| 276 |
+
)
|
| 277 |
+
# per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
|
| 278 |
+
self.log("val/auprc_0v1",
|
| 279 |
+
ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
|
| 280 |
+
on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
|
| 281 |
+
self.log("val/auroc_0v1",
|
| 282 |
+
auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
|
| 283 |
+
on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
|
| 284 |
+
|
| 285 |
return loss
|
| 286 |
|
| 287 |
def test_step(self, batch, batch_idx):
|
|
|
|
| 292 |
self.log(
|
| 293 |
"test/loss", loss, on_step=False, on_epoch=True, batch_size=logits.size(0)
|
| 294 |
)
|
| 295 |
+
|
| 296 |
+
# ---- AUPRC and AUROC on labels in {0, >0.99} only ----
|
| 297 |
+
ap, n_pos, n_neg, precision, recall, thresholds = auprc_zeros_vs_ones_from_logits(
|
| 298 |
+
logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
|
| 299 |
+
)
|
| 300 |
+
auc, n_pos, n_neg, tpr, fpr, thresolds, tp, fp = auroc_zeros_vs_ones_from_logits(
|
| 301 |
+
logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
|
| 302 |
+
)
|
| 303 |
+
# per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
|
| 304 |
+
self.log("test/auprc_0v1",
|
| 305 |
+
ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
|
| 306 |
+
on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
|
| 307 |
+
self.log("test/auroc_0v1",
|
| 308 |
+
auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
|
| 309 |
+
on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
|
| 310 |
return loss
|
| 311 |
|
| 312 |
def on_before_optimizer_step(self, optimizer):
|
dpacman/classifier/model_tmp/__init__.py
DELETED
|
File without changes
|
dpacman/classifier/model_tmp/clustering_data.py
DELETED
|
@@ -1,475 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
import argparse
|
| 3 |
-
import numpy as np
|
| 4 |
-
import pandas as pd
|
| 5 |
-
from pathlib import Path
|
| 6 |
-
import random
|
| 7 |
-
import sys
|
| 8 |
-
import subprocess
|
| 9 |
-
from collections import defaultdict
|
| 10 |
-
|
| 11 |
-
# ─────────────────────────────────────────────────────────────────────────
|
| 12 |
-
# Original helpers (kept; some lightly edited/commented where needed)
|
| 13 |
-
# ─────────────────────────────────────────────────────────────────────────
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def read_ids_file(p):
|
| 17 |
-
p = Path(p)
|
| 18 |
-
if not p.exists():
|
| 19 |
-
raise FileNotFoundError(f"IDs file not found: {p}")
|
| 20 |
-
return [line.strip() for line in p.open() if line.strip()]
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def split_embeddings(emb_path, ids_path, out_dir, prefix):
|
| 24 |
-
out_dir = Path(out_dir)
|
| 25 |
-
out_dir.mkdir(parents=True, exist_ok=True)
|
| 26 |
-
|
| 27 |
-
if not Path(emb_path).exists():
|
| 28 |
-
raise FileNotFoundError(f"Embedding file not found: {emb_path}")
|
| 29 |
-
if not Path(ids_path).exists():
|
| 30 |
-
raise FileNotFoundError(f"IDs file not found: {ids_path}")
|
| 31 |
-
|
| 32 |
-
if emb_path.endswith(".npz"):
|
| 33 |
-
data = np.load(emb_path, allow_pickle=True)
|
| 34 |
-
if "embeddings" in data:
|
| 35 |
-
emb = data["embeddings"]
|
| 36 |
-
else:
|
| 37 |
-
raise ValueError(f"{emb_path} missing 'embeddings' key")
|
| 38 |
-
else:
|
| 39 |
-
emb = np.load(emb_path)
|
| 40 |
-
|
| 41 |
-
ids = read_ids_file(ids_path)
|
| 42 |
-
if len(ids) != emb.shape[0]:
|
| 43 |
-
print(
|
| 44 |
-
f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}",
|
| 45 |
-
file=sys.stderr,
|
| 46 |
-
)
|
| 47 |
-
|
| 48 |
-
mapping = {}
|
| 49 |
-
for i, ident in enumerate(ids):
|
| 50 |
-
if i >= emb.shape[0]:
|
| 51 |
-
print(
|
| 52 |
-
f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr
|
| 53 |
-
)
|
| 54 |
-
continue
|
| 55 |
-
arr = emb[i]
|
| 56 |
-
out_file = out_dir / f"{prefix}_{ident}.npy"
|
| 57 |
-
np.save(out_file, arr)
|
| 58 |
-
mapping[ident] = str(out_file)
|
| 59 |
-
return mapping
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
def extract_symbol_from_tf_id(full_id: str) -> str:
|
| 63 |
-
"""
|
| 64 |
-
Given a TF embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN',
|
| 65 |
-
return the gene symbol uppercase (e.g., 'ZBTB5').
|
| 66 |
-
"""
|
| 67 |
-
if "|" in full_id:
|
| 68 |
-
try:
|
| 69 |
-
# format sp|Accession|SYMBOL_HUMAN
|
| 70 |
-
genepart = full_id.split("|")[2]
|
| 71 |
-
except IndexError:
|
| 72 |
-
genepart = full_id
|
| 73 |
-
else:
|
| 74 |
-
genepart = full_id
|
| 75 |
-
symbol = genepart.split("_")[0]
|
| 76 |
-
return symbol.upper()
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
def build_tf_symbol_map(tf_map):
|
| 80 |
-
"""
|
| 81 |
-
Build mapping gene_symbol -> list of embedding paths.
|
| 82 |
-
"""
|
| 83 |
-
symbol_map = {}
|
| 84 |
-
for full_id, path in tf_map.items():
|
| 85 |
-
symbol = extract_symbol_from_tf_id(full_id)
|
| 86 |
-
symbol_map.setdefault(symbol, []).append(path)
|
| 87 |
-
return symbol_map
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
def tf_key_from_path(path: str) -> str:
|
| 91 |
-
"""
|
| 92 |
-
Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract normalized symbol 'ZBTB5'.
|
| 93 |
-
"""
|
| 94 |
-
stem = Path(path).stem # e.g., tf_sp|O15062|ZBTB5_HUMAN
|
| 95 |
-
# remove leading prefix if present (tf_)
|
| 96 |
-
if "_" in stem:
|
| 97 |
-
_, rest = stem.split("_", 1)
|
| 98 |
-
else:
|
| 99 |
-
rest = stem
|
| 100 |
-
return extract_symbol_from_tf_id(rest)
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
def dna_key_from_path(path: str) -> str:
|
| 104 |
-
"""
|
| 105 |
-
Given .../dna_peak42.npy -> 'peak42'
|
| 106 |
-
"""
|
| 107 |
-
stem = Path(path).stem
|
| 108 |
-
if "_" in stem:
|
| 109 |
-
_, rest = stem.split("_", 1)
|
| 110 |
-
else:
|
| 111 |
-
rest = stem
|
| 112 |
-
return rest
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
# ─────────────────────────────────────────────────────────────────────────
|
| 116 |
-
# New helpers for MMseqs clustering & cluster-level splitting
|
| 117 |
-
# ─────────────────────────────────────────────────────────────────────────
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
def write_dna_fasta(df: pd.DataFrame, out_fasta: Path) -> None:
|
| 121 |
-
"""
|
| 122 |
-
Write unique DNA sequences to FASTA using dna_id as header.
|
| 123 |
-
Requires df with columns: dna_id, dna_sequence
|
| 124 |
-
"""
|
| 125 |
-
uniq = df[["dna_id", "dna_sequence"]].drop_duplicates()
|
| 126 |
-
with open(out_fasta, "w") as f:
|
| 127 |
-
for _, row in uniq.iterrows():
|
| 128 |
-
did = row["dna_id"]
|
| 129 |
-
seq = str(row["dna_sequence"]).upper().replace(" ", "").replace("\n", "")
|
| 130 |
-
f.write(f">{did}\n{seq}\n")
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
def run_mmseqs_easy_cluster(
|
| 134 |
-
mmseqs_bin: str,
|
| 135 |
-
fasta: Path,
|
| 136 |
-
out_prefix: Path,
|
| 137 |
-
tmp_dir: Path,
|
| 138 |
-
min_seq_id: float,
|
| 139 |
-
coverage: float,
|
| 140 |
-
cov_mode: int,
|
| 141 |
-
) -> Path:
|
| 142 |
-
"""
|
| 143 |
-
Runs mmseqs easy-cluster on nucleotide sequences.
|
| 144 |
-
Returns the path to a clusters TSV file (creating it if the default one isn't present).
|
| 145 |
-
"""
|
| 146 |
-
tmp_dir.mkdir(parents=True, exist_ok=True)
|
| 147 |
-
out_prefix.parent.mkdir(parents=True, exist_ok=True)
|
| 148 |
-
|
| 149 |
-
cmd = [
|
| 150 |
-
mmseqs_bin,
|
| 151 |
-
"easy-cluster",
|
| 152 |
-
str(fasta),
|
| 153 |
-
str(out_prefix),
|
| 154 |
-
str(tmp_dir),
|
| 155 |
-
"--min-seq-id",
|
| 156 |
-
str(min_seq_id),
|
| 157 |
-
"-c",
|
| 158 |
-
str(coverage),
|
| 159 |
-
"--cov-mode",
|
| 160 |
-
str(cov_mode),
|
| 161 |
-
# You can add performance flags here if needed, e.g.:
|
| 162 |
-
# "--threads", "8"
|
| 163 |
-
]
|
| 164 |
-
print("[i] Running:", " ".join(cmd), flush=True)
|
| 165 |
-
subprocess.run(cmd, check=True)
|
| 166 |
-
|
| 167 |
-
# MMseqs easy-cluster typically writes <out_prefix>_cluster.tsv
|
| 168 |
-
default_tsv = Path(str(out_prefix) + "_cluster.tsv")
|
| 169 |
-
if default_tsv.exists():
|
| 170 |
-
print(f"[i] Found cluster TSV: {default_tsv}")
|
| 171 |
-
return default_tsv
|
| 172 |
-
|
| 173 |
-
# Fallback: try createtsv if default is missing
|
| 174 |
-
# This requires the internal DBs. easy-cluster creates DBs alongside out_prefix.
|
| 175 |
-
# We'll try to locate them and emit a TSV.
|
| 176 |
-
in_db = Path(str(out_prefix) + "_query")
|
| 177 |
-
cl_db = Path(str(out_prefix) + "_cluster")
|
| 178 |
-
out_tsv = Path(str(out_prefix) + "_fallback_cluster.tsv")
|
| 179 |
-
if in_db.exists() and cl_db.exists():
|
| 180 |
-
cmd2 = [
|
| 181 |
-
mmseqs_bin,
|
| 182 |
-
"createtsv",
|
| 183 |
-
str(in_db),
|
| 184 |
-
str(in_db),
|
| 185 |
-
str(cl_db),
|
| 186 |
-
str(out_tsv),
|
| 187 |
-
]
|
| 188 |
-
print("[i] Creating TSV via createtsv:", " ".join(cmd2), flush=True)
|
| 189 |
-
subprocess.run(cmd2, check=True)
|
| 190 |
-
if out_tsv.exists():
|
| 191 |
-
return out_tsv
|
| 192 |
-
|
| 193 |
-
raise FileNotFoundError(
|
| 194 |
-
"Could not locate clusters TSV from mmseqs. "
|
| 195 |
-
"Expected {default_tsv} or createtsv fallback."
|
| 196 |
-
)
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
def parse_mmseqs_clusters(tsv_path: Path) -> dict:
|
| 200 |
-
"""
|
| 201 |
-
Parse MMseqs cluster TSV (rep \t member). Returns dna_id -> cluster_rep_id
|
| 202 |
-
"""
|
| 203 |
-
mapping = {}
|
| 204 |
-
with open(tsv_path) as f:
|
| 205 |
-
for line in f:
|
| 206 |
-
parts = line.rstrip("\n").split("\t")
|
| 207 |
-
if len(parts) < 2:
|
| 208 |
-
continue
|
| 209 |
-
rep, member = parts[0], parts[1]
|
| 210 |
-
mapping[member] = rep
|
| 211 |
-
# Some TSVs include rep->rep; if not, ensure rep is mapped to itself:
|
| 212 |
-
if rep not in mapping:
|
| 213 |
-
mapping[rep] = rep
|
| 214 |
-
return mapping
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
def assign_clusters_to_splits(
|
| 218 |
-
cluster_rep_to_members: dict, val_frac: float, test_frac: float, seed: int = 42
|
| 219 |
-
):
|
| 220 |
-
"""
|
| 221 |
-
cluster_rep_to_members: dict[rep] = [members...]
|
| 222 |
-
Returns: dict with keys 'train','val','test' mapping to sets of dna_id.
|
| 223 |
-
Ensures all members of a cluster go to the same split.
|
| 224 |
-
"""
|
| 225 |
-
rng = random.Random(seed)
|
| 226 |
-
reps = list(cluster_rep_to_members.keys())
|
| 227 |
-
rng.shuffle(reps)
|
| 228 |
-
|
| 229 |
-
# Greedy-ish fill by total member counts to match desired fractions.
|
| 230 |
-
total = sum(len(cluster_rep_to_members[r]) for r in reps)
|
| 231 |
-
target_val = int(round(total * val_frac))
|
| 232 |
-
target_test = int(round(total * test_frac))
|
| 233 |
-
cur_val = cur_test = 0
|
| 234 |
-
|
| 235 |
-
val_ids, test_ids, train_ids = set(), set(), set()
|
| 236 |
-
for rep in reps:
|
| 237 |
-
members = cluster_rep_to_members[rep]
|
| 238 |
-
c = len(members)
|
| 239 |
-
# Fill val first, then test, then train
|
| 240 |
-
if cur_val + c <= target_val:
|
| 241 |
-
val_ids.update(members)
|
| 242 |
-
cur_val += c
|
| 243 |
-
elif cur_test + c <= target_test:
|
| 244 |
-
test_ids.update(members)
|
| 245 |
-
cur_test += c
|
| 246 |
-
else:
|
| 247 |
-
train_ids.update(members)
|
| 248 |
-
|
| 249 |
-
return {"train": train_ids, "val": val_ids, "test": test_ids}
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
# ─────────────────────────────────────────────────────────────────────────
|
| 253 |
-
# Main
|
| 254 |
-
# ─────────────────────────────────────────────────────────────────────────
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
def main():
|
| 258 |
-
parser = argparse.ArgumentParser(
|
| 259 |
-
description="Build TF-DNA pair lists with MMseqs clustering on DNA to prevent split leakage."
|
| 260 |
-
)
|
| 261 |
-
parser.add_argument(
|
| 262 |
-
"--final_csv", required=True, help="final.csv with TF_id and dna_sequence"
|
| 263 |
-
)
|
| 264 |
-
parser.add_argument(
|
| 265 |
-
"--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)"
|
| 266 |
-
)
|
| 267 |
-
parser.add_argument(
|
| 268 |
-
"--dna_ids", required=True, help="IDs file for DNA embeddings (peak*.ids)"
|
| 269 |
-
)
|
| 270 |
-
parser.add_argument(
|
| 271 |
-
"--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)"
|
| 272 |
-
)
|
| 273 |
-
parser.add_argument(
|
| 274 |
-
"--tf_ids", required=True, help="IDs file for TF embeddings (sp|... ids)"
|
| 275 |
-
)
|
| 276 |
-
parser.add_argument("--out_dir", required=True, help="Output directory")
|
| 277 |
-
parser.add_argument("--seed", type=int, default=42)
|
| 278 |
-
|
| 279 |
-
# NEW: MMseqs options & split fractions
|
| 280 |
-
parser.add_argument("--mmseqs_bin", default="mmseqs", help="Path to mmseqs binary")
|
| 281 |
-
parser.add_argument(
|
| 282 |
-
"--min_seq_id", type=float, default=0.9, help="MMseqs --min-seq-id"
|
| 283 |
-
)
|
| 284 |
-
parser.add_argument(
|
| 285 |
-
"--cov", type=float, default=0.8, help="MMseqs -c coverage fraction"
|
| 286 |
-
)
|
| 287 |
-
parser.add_argument(
|
| 288 |
-
"--cov_mode",
|
| 289 |
-
type=int,
|
| 290 |
-
default=1,
|
| 291 |
-
help="MMseqs --cov-mode (1 = coverage of target)",
|
| 292 |
-
)
|
| 293 |
-
parser.add_argument("--val_frac", type=float, default=0.10)
|
| 294 |
-
parser.add_argument("--test_frac", type=float, default=0.10)
|
| 295 |
-
parser.add_argument(
|
| 296 |
-
"--tmp_dir", default=None, help="MMseqs tmp dir (defaults to out_dir/tmp)"
|
| 297 |
-
)
|
| 298 |
-
args = parser.parse_args()
|
| 299 |
-
|
| 300 |
-
random.seed(args.seed)
|
| 301 |
-
out_dir = Path(args.out_dir)
|
| 302 |
-
out_dir.mkdir(parents=True, exist_ok=True)
|
| 303 |
-
|
| 304 |
-
# Load final.csv
|
| 305 |
-
df = pd.read_csv(args.final_csv, dtype=str)
|
| 306 |
-
if "TF_id" not in df.columns or "dna_sequence" not in df.columns:
|
| 307 |
-
raise RuntimeError("final.csv must have columns TF_id and dna_sequence")
|
| 308 |
-
|
| 309 |
-
# Assign dna_id (unique per dna_sequence)
|
| 310 |
-
unique_seqs = df["dna_sequence"].drop_duplicates().tolist()
|
| 311 |
-
seq_to_id = {seq: f"peak{i}" for i, seq in enumerate(unique_seqs)}
|
| 312 |
-
df["dna_id"] = df["dna_sequence"].map(seq_to_id)
|
| 313 |
-
enriched_csv = out_dir / "final_with_dna_id.csv"
|
| 314 |
-
df.to_csv(enriched_csv, index=False)
|
| 315 |
-
print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}")
|
| 316 |
-
|
| 317 |
-
# Split embeddings into per-item files (unchanged)
|
| 318 |
-
print(
|
| 319 |
-
f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}"
|
| 320 |
-
)
|
| 321 |
-
dna_map = split_embeddings(
|
| 322 |
-
args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna"
|
| 323 |
-
)
|
| 324 |
-
print(
|
| 325 |
-
f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})"
|
| 326 |
-
)
|
| 327 |
-
print(
|
| 328 |
-
f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}"
|
| 329 |
-
)
|
| 330 |
-
tf_map = split_embeddings(
|
| 331 |
-
args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf"
|
| 332 |
-
)
|
| 333 |
-
print(
|
| 334 |
-
f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})"
|
| 335 |
-
)
|
| 336 |
-
|
| 337 |
-
# Build gene-symbol normalized map
|
| 338 |
-
tf_symbol_map = build_tf_symbol_map(tf_map)
|
| 339 |
-
print(f"[i] TF symbol map keys (sample): {list(tf_symbol_map.keys())[:30]}")
|
| 340 |
-
|
| 341 |
-
# Diagnostic overlaps
|
| 342 |
-
norm_tf_in_final = set(t.split("_seq")[0].upper() for t in df["TF_id"].unique())
|
| 343 |
-
available_tf_symbols = set(tf_symbol_map.keys())
|
| 344 |
-
intersect_tf = norm_tf_in_final & available_tf_symbols
|
| 345 |
-
print(f"[i] Unique normalized TF symbols in final.csv: {len(norm_tf_in_final)}")
|
| 346 |
-
print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}")
|
| 347 |
-
print(f"[i] Intersection count: {len(intersect_tf)}")
|
| 348 |
-
if len(intersect_tf) == 0:
|
| 349 |
-
print(
|
| 350 |
-
"[ERROR] No overlap between normalized TF_id and TF embedding symbols.",
|
| 351 |
-
file=sys.stderr,
|
| 352 |
-
)
|
| 353 |
-
print(
|
| 354 |
-
"Sample normalized TFs from final.csv:",
|
| 355 |
-
sorted(list(norm_tf_in_final))[:30],
|
| 356 |
-
file=sys.stderr,
|
| 357 |
-
)
|
| 358 |
-
print(
|
| 359 |
-
"Sample available TF symbols:",
|
| 360 |
-
sorted(list(available_tf_symbols))[:30],
|
| 361 |
-
file=sys.stderr,
|
| 362 |
-
)
|
| 363 |
-
sys.exit(1)
|
| 364 |
-
|
| 365 |
-
dna_ids_final = set(df["dna_id"].unique())
|
| 366 |
-
available_dna_ids = set(dna_map.keys())
|
| 367 |
-
intersect_dna = dna_ids_final & available_dna_ids
|
| 368 |
-
print(
|
| 369 |
-
f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. Intersection: {len(intersect_dna)}"
|
| 370 |
-
)
|
| 371 |
-
if len(intersect_dna) == 0:
|
| 372 |
-
print("[ERROR] No overlap on DNA ids.", file=sys.stderr)
|
| 373 |
-
sys.exit(1)
|
| 374 |
-
|
| 375 |
-
# ── NEW: MMseqs clustering on DNA sequences ───────────────────────────
|
| 376 |
-
fasta_path = out_dir / "dna_unique.fasta"
|
| 377 |
-
write_dna_fasta(df, fasta_path)
|
| 378 |
-
print(
|
| 379 |
-
f"[i] Wrote FASTA with {df['dna_id'].nunique()} unique sequences → {fasta_path}"
|
| 380 |
-
)
|
| 381 |
-
|
| 382 |
-
tmp_dir = Path(args.tmp_dir) if args.tmp_dir else (out_dir / "mmseqs_tmp")
|
| 383 |
-
cluster_prefix = out_dir / "mmseqs_dna_clusters"
|
| 384 |
-
clusters_tsv = run_mmseqs_easy_cluster(
|
| 385 |
-
mmseqs_bin=args.mmseqs_bin,
|
| 386 |
-
fasta=fasta_path,
|
| 387 |
-
out_prefix=cluster_prefix,
|
| 388 |
-
tmp_dir=tmp_dir,
|
| 389 |
-
min_seq_id=args.min_seq_id,
|
| 390 |
-
coverage=args.cov,
|
| 391 |
-
cov_mode=args.cov_mode,
|
| 392 |
-
)
|
| 393 |
-
|
| 394 |
-
# Parse clusters
|
| 395 |
-
member_to_rep = parse_mmseqs_clusters(clusters_tsv) # dna_id -> rep_id
|
| 396 |
-
# Build rep -> members list
|
| 397 |
-
rep_to_members = defaultdict(list)
|
| 398 |
-
for member, rep in member_to_rep.items():
|
| 399 |
-
rep_to_members[rep].append(member)
|
| 400 |
-
|
| 401 |
-
print(f"[i] Parsed {len(rep_to_members)} clusters from {clusters_tsv}")
|
| 402 |
-
clusters_table = []
|
| 403 |
-
for rep, members in rep_to_members.items():
|
| 404 |
-
for m in members:
|
| 405 |
-
clusters_table.append((m, rep))
|
| 406 |
-
clusters_df = pd.DataFrame(clusters_table, columns=["dna_id", "cluster_id"])
|
| 407 |
-
clusters_df.to_csv(out_dir / "clusters.tsv", sep="\t", index=False)
|
| 408 |
-
print(f"[i] Wrote clusters mapping → {out_dir / 'clusters.tsv'}")
|
| 409 |
-
|
| 410 |
-
# Attach cluster_id back to final df
|
| 411 |
-
df = df.merge(clusters_df, on="dna_id", how="left")
|
| 412 |
-
df.to_csv(out_dir / "final_with_dna_id_and_cluster.csv", index=False)
|
| 413 |
-
print(f"[i] Wrote {out_dir / 'final_with_dna_id_and_cluster.csv'}")
|
| 414 |
-
|
| 415 |
-
# Assign entire clusters to splits
|
| 416 |
-
splits = assign_clusters_to_splits(
|
| 417 |
-
rep_to_members, val_frac=args.val_frac, test_frac=args.test_frac, seed=args.seed
|
| 418 |
-
)
|
| 419 |
-
for k in ["train", "val", "test"]:
|
| 420 |
-
print(f"[i] {k}: {len(splits[k])} dna_ids")
|
| 421 |
-
|
| 422 |
-
# ── Build positive pairs only, per split (NO negatives) ───────────────
|
| 423 |
-
positives_by_split = {"train": [], "val": [], "test": []}
|
| 424 |
-
# Build a quick dna_id -> embedding path map
|
| 425 |
-
dnaid_to_path = {did: path for did, path in dna_map.items()}
|
| 426 |
-
|
| 427 |
-
pos_count = 0
|
| 428 |
-
for _, row in df.iterrows():
|
| 429 |
-
tf_raw = row["TF_id"]
|
| 430 |
-
tf_symbol = tf_raw.split("_seq")[0].upper()
|
| 431 |
-
dnaid = row["dna_id"]
|
| 432 |
-
if (tf_symbol not in tf_symbol_map) or (dnaid not in dnaid_to_path):
|
| 433 |
-
continue
|
| 434 |
-
tf_embedding_path = tf_symbol_map[tf_symbol][0] # first embedding per symbol
|
| 435 |
-
|
| 436 |
-
# decide split by dna_id cluster assignment
|
| 437 |
-
if dnaid in splits["train"]:
|
| 438 |
-
positives_by_split["train"].append(
|
| 439 |
-
(tf_embedding_path, dnaid_to_path[dnaid], 1)
|
| 440 |
-
)
|
| 441 |
-
elif dnaid in splits["val"]:
|
| 442 |
-
positives_by_split["val"].append(
|
| 443 |
-
(tf_embedding_path, dnaid_to_path[dnaid], 1)
|
| 444 |
-
)
|
| 445 |
-
elif dnaid in splits["test"]:
|
| 446 |
-
positives_by_split["test"].append(
|
| 447 |
-
(tf_embedding_path, dnaid_to_path[dnaid], 1)
|
| 448 |
-
)
|
| 449 |
-
pos_count += 1
|
| 450 |
-
|
| 451 |
-
print(
|
| 452 |
-
f"[i] Constructed positives across splits (rows in final.csv iterated: {len(df)})"
|
| 453 |
-
)
|
| 454 |
-
for k in ["train", "val", "test"]:
|
| 455 |
-
print(f"[i] positives[{k}] = {len(positives_by_split[k])}")
|
| 456 |
-
|
| 457 |
-
# # OLD: negatives (kept commented)
|
| 458 |
-
# negatives = []
|
| 459 |
-
# print(f"[i] Sampled {len(negatives)} negatives (neg_per_positive not used)")
|
| 460 |
-
|
| 461 |
-
# Emit split-specific pair lists
|
| 462 |
-
for split in ["train", "val", "test"]:
|
| 463 |
-
out_tsv = out_dir / f"pair_list_{split}.tsv"
|
| 464 |
-
with open(out_tsv, "w") as f:
|
| 465 |
-
for binder_path, glm_path, label in positives_by_split[
|
| 466 |
-
split
|
| 467 |
-
]: # + negatives if you add later
|
| 468 |
-
f.write(f"{binder_path}\t{glm_path}\t{label}\n")
|
| 469 |
-
print(f"[i] Wrote {len(positives_by_split[split])} examples to {out_tsv}")
|
| 470 |
-
|
| 471 |
-
print("✅ Done. Cluster-aware splits ready.")
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
if __name__ == "__main__":
|
| 475 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dpacman/classifier/model_tmp/compress_embeddings.py
DELETED
|
@@ -1,62 +0,0 @@
|
|
| 1 |
-
# compress_embeddings.py
|
| 2 |
-
# USAGE: python compress_embeddings.py --input_glob "/path/to/esm_embeddings/*.npy" --output_dir "/path/to/compressed_embeddings" --esm_dim 1280 --out_dim 256
|
| 3 |
-
# --------------
|
| 4 |
-
import os
|
| 5 |
-
import glob
|
| 6 |
-
import numpy as np
|
| 7 |
-
import torch
|
| 8 |
-
from torch import nn
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
class EmbeddingCompressor(nn.Module):
|
| 12 |
-
def __init__(self, input_dim: int = 1280, output_dim: int = 256):
|
| 13 |
-
super().__init__()
|
| 14 |
-
self.fc = nn.Linear(input_dim, output_dim)
|
| 15 |
-
|
| 16 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 17 |
-
"""
|
| 18 |
-
x: (batch, L, input_dim) or (L, input_dim)
|
| 19 |
-
returns: (batch, output_dim) or (output_dim,)
|
| 20 |
-
"""
|
| 21 |
-
if x.dim() == 2:
|
| 22 |
-
# single example: mean over tokens
|
| 23 |
-
x = x.mean(dim=0, keepdim=True) # → (1, input_dim)
|
| 24 |
-
else:
|
| 25 |
-
# batch: mean over tokens
|
| 26 |
-
x = x.mean(dim=1) # → (batch, input_dim)
|
| 27 |
-
return self.fc(x) # → (batch, output_dim)
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def compress_file(in_path: str, out_path: str, model: EmbeddingCompressor):
|
| 31 |
-
arr = np.load(in_path) # shape (L, D) or (batch, L, D)
|
| 32 |
-
tensor = torch.from_numpy(arr).float()
|
| 33 |
-
with torch.no_grad():
|
| 34 |
-
compressed = model(tensor) # → (batch, 256)
|
| 35 |
-
out = compressed.cpu().numpy()
|
| 36 |
-
np.save(out_path, out)
|
| 37 |
-
print(f"Saved {out_path}")
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
if __name__ == "__main__":
|
| 41 |
-
import argparse
|
| 42 |
-
|
| 43 |
-
parser = argparse.ArgumentParser(description="Compress ESM embeddings to 256d")
|
| 44 |
-
parser.add_argument(
|
| 45 |
-
"--input_glob",
|
| 46 |
-
type=str,
|
| 47 |
-
required=True,
|
| 48 |
-
help="Glob for your .npy ESM embeddings (e.g. data/esm_*.npy)",
|
| 49 |
-
)
|
| 50 |
-
parser.add_argument("--output_dir", type=str, required=True)
|
| 51 |
-
parser.add_argument("--esm_dim", type=int, default=1280)
|
| 52 |
-
parser.add_argument("--out_dim", type=int, default=256)
|
| 53 |
-
args = parser.parse_args()
|
| 54 |
-
|
| 55 |
-
os.makedirs(args.output_dir, exist_ok=True)
|
| 56 |
-
compressor = EmbeddingCompressor(args.esm_dim, args.out_dim)
|
| 57 |
-
compressor.eval()
|
| 58 |
-
|
| 59 |
-
for fn in glob.glob(args.input_glob):
|
| 60 |
-
base = os.path.basename(fn).replace(".npy", "_256.npy")
|
| 61 |
-
out_path = os.path.join(args.output_dir, base)
|
| 62 |
-
compress_file(fn, out_path, compressor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dpacman/classifier/model_tmp/compute_embeddings.py
DELETED
|
@@ -1,612 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Plug-and-play embedding extraction for:
|
| 3 |
-
• Chromosome sequences (from raw UCSC JSON)
|
| 4 |
-
• TF sequences (transcription_factors.fasta)
|
| 5 |
-
|
| 6 |
-
Usage example (DNA + protein in one go):
|
| 7 |
-
module load miniconda/24.7.1
|
| 8 |
-
conda activate dpacman
|
| 9 |
-
python dpacman/data/compute_embeddings.py \
|
| 10 |
-
--genome-json-dir ../data_files/raw/genomes/hg38 \
|
| 11 |
-
--tf-fasta ../data_files/processed/tfclust/hg38_tf/transcription_factors.fasta \
|
| 12 |
-
--chrom-model caduceus \
|
| 13 |
-
--tf-model esm-dbp \
|
| 14 |
-
--out-dir ../data_files/processed/tfclust/hg38_tf/embeddings \
|
| 15 |
-
--device cuda
|
| 16 |
-
"""
|
| 17 |
-
|
| 18 |
-
import os
|
| 19 |
-
import re
|
| 20 |
-
import argparse
|
| 21 |
-
import json
|
| 22 |
-
import numpy as np
|
| 23 |
-
from pathlib import Path
|
| 24 |
-
import torch
|
| 25 |
-
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, pipeline
|
| 26 |
-
import esm
|
| 27 |
-
from Bio import SeqIO
|
| 28 |
-
import time
|
| 29 |
-
|
| 30 |
-
# ---- model wrappers ----
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
class CaduceusEmbedder:
|
| 34 |
-
def __init__(self, device, chunk_size=131_072, overlap=0):
|
| 35 |
-
"""
|
| 36 |
-
device: 'cpu' or 'cuda'
|
| 37 |
-
chunk_size: max bases (and thus tokens) to send in one forward pass
|
| 38 |
-
overlap: how many bases each window overlaps the previous; 0 = no overlap
|
| 39 |
-
"""
|
| 40 |
-
model_name = "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16"
|
| 41 |
-
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 42 |
-
model_name, trust_remote_code=True
|
| 43 |
-
)
|
| 44 |
-
self.model = (
|
| 45 |
-
AutoModel.from_pretrained(model_name, trust_remote_code=True)
|
| 46 |
-
.to(device)
|
| 47 |
-
.eval()
|
| 48 |
-
)
|
| 49 |
-
self.device = device
|
| 50 |
-
self.chunk_size = chunk_size
|
| 51 |
-
self.step = chunk_size - overlap
|
| 52 |
-
|
| 53 |
-
def embed(self, seqs):
|
| 54 |
-
"""
|
| 55 |
-
seqs: List[str] of DNA sequences (each <= chunk_size for this test)
|
| 56 |
-
returns: np.ndarray of shape (N, L, D), raw per‐token embeddings
|
| 57 |
-
"""
|
| 58 |
-
# outputs = []
|
| 59 |
-
# for seq in seqs:
|
| 60 |
-
# # --- new: raw per‐token embeddings in one shot ---
|
| 61 |
-
# toks = self.tokenizer(
|
| 62 |
-
# seq,
|
| 63 |
-
# return_tensors="pt",
|
| 64 |
-
# padding=False,
|
| 65 |
-
# truncation=True,
|
| 66 |
-
# max_length=self.chunk_size
|
| 67 |
-
# ).to(self.device)
|
| 68 |
-
# with torch.no_grad():
|
| 69 |
-
# out = self.model(**toks).last_hidden_state # (1, L, D)
|
| 70 |
-
# outputs.append(out.cpu().numpy()[0]) # (L, D)
|
| 71 |
-
|
| 72 |
-
# return np.stack(outputs, axis=0) # (N, L, D)
|
| 73 |
-
outputs = []
|
| 74 |
-
for seq in seqs:
|
| 75 |
-
toks = self.tokenizer(
|
| 76 |
-
seq,
|
| 77 |
-
return_tensors="pt",
|
| 78 |
-
padding=False,
|
| 79 |
-
truncation=True,
|
| 80 |
-
max_length=self.chunk_size,
|
| 81 |
-
).to(self.device)
|
| 82 |
-
with torch.no_grad():
|
| 83 |
-
out = self.model(**toks).last_hidden_state # (1, L, D)
|
| 84 |
-
outputs.append(out.cpu().numpy()[0]) # (L, D)
|
| 85 |
-
return outputs # list of variable-length (L_i, D) arrays
|
| 86 |
-
|
| 87 |
-
def benchmark(self, lengths=None):
|
| 88 |
-
"""
|
| 89 |
-
Time embedding on single-sequence of various lengths.
|
| 90 |
-
By default tests [5K,10K,50K,100K,chunk_size].
|
| 91 |
-
"""
|
| 92 |
-
tests = lengths or [5_000, 10_000, 50_000, 100_000, self.chunk_size]
|
| 93 |
-
print(f"→ Benchmarking Caduceus on device={self.device}")
|
| 94 |
-
for sz in tests:
|
| 95 |
-
seq = "A" * sz
|
| 96 |
-
# Warm-up
|
| 97 |
-
_ = self.embed([seq])
|
| 98 |
-
if self.device != "cpu":
|
| 99 |
-
torch.cuda.synchronize()
|
| 100 |
-
t0 = time.perf_counter()
|
| 101 |
-
_ = self.embed([seq])
|
| 102 |
-
if self.device != "cpu":
|
| 103 |
-
torch.cuda.synchronize()
|
| 104 |
-
t1 = time.perf_counter()
|
| 105 |
-
print(f" length={sz:6,d} time={(t1-t0)*1000:7.1f} ms")
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
class SegmentNTEmbedder:
|
| 109 |
-
def __init__(self, device):
|
| 110 |
-
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 111 |
-
"InstaDeepAI/segment_nt", trust_remote_code=True
|
| 112 |
-
)
|
| 113 |
-
self.model = (
|
| 114 |
-
AutoModel.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True)
|
| 115 |
-
.to(device)
|
| 116 |
-
.eval()
|
| 117 |
-
)
|
| 118 |
-
self.device = device
|
| 119 |
-
|
| 120 |
-
def _adjust_length(self, input_ids):
|
| 121 |
-
bs, L = input_ids.shape
|
| 122 |
-
excl = L - 1
|
| 123 |
-
remainder = (excl) % 4
|
| 124 |
-
if remainder != 0:
|
| 125 |
-
pad_needed = 4 - remainder
|
| 126 |
-
pad_tensor = torch.full(
|
| 127 |
-
(bs, pad_needed),
|
| 128 |
-
self.tokenizer.pad_token_id,
|
| 129 |
-
dtype=input_ids.dtype,
|
| 130 |
-
device=input_ids.device,
|
| 131 |
-
)
|
| 132 |
-
input_ids = torch.cat([input_ids, pad_tensor], dim=1)
|
| 133 |
-
return input_ids
|
| 134 |
-
|
| 135 |
-
def embed(self, seqs, batch_size=16):
|
| 136 |
-
"""
|
| 137 |
-
seqs: List[str]
|
| 138 |
-
Returns: np.ndarray of shape (N, D)
|
| 139 |
-
"""
|
| 140 |
-
all_embeddings = []
|
| 141 |
-
for i in range(0, len(seqs), batch_size):
|
| 142 |
-
batch_seqs = seqs[i : i + batch_size]
|
| 143 |
-
encoded = self.tokenizer.batch_encode_plus(
|
| 144 |
-
batch_seqs,
|
| 145 |
-
return_tensors="pt",
|
| 146 |
-
padding=True,
|
| 147 |
-
truncation=True,
|
| 148 |
-
)
|
| 149 |
-
input_ids = encoded["input_ids"].to(self.device) # (B, L)
|
| 150 |
-
attention_mask = input_ids != self.tokenizer.pad_token_id
|
| 151 |
-
|
| 152 |
-
input_ids = self._adjust_length(input_ids)
|
| 153 |
-
attention_mask = input_ids != self.tokenizer.pad_token_id
|
| 154 |
-
|
| 155 |
-
with torch.no_grad():
|
| 156 |
-
outs = self.model(
|
| 157 |
-
input_ids,
|
| 158 |
-
attention_mask=attention_mask,
|
| 159 |
-
output_hidden_states=True,
|
| 160 |
-
return_dict=True,
|
| 161 |
-
)
|
| 162 |
-
if hasattr(outs, "hidden_states") and outs.hidden_states is not None:
|
| 163 |
-
last_hidden = outs.hidden_states[-1] # (B, L, D)
|
| 164 |
-
else:
|
| 165 |
-
last_hidden = outs.last_hidden_state # fallback
|
| 166 |
-
|
| 167 |
-
# Exclude CLS token if present (assume first token) and pool
|
| 168 |
-
pooled = last_hidden[:, 1:, :].mean(dim=1) # (B, D)
|
| 169 |
-
all_embeddings.append(pooled.cpu().numpy())
|
| 170 |
-
|
| 171 |
-
# release fragmentation
|
| 172 |
-
torch.cuda.empty_cache()
|
| 173 |
-
|
| 174 |
-
return np.vstack(all_embeddings) # (N, D)
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
class DNABertEmbedder:
|
| 178 |
-
def __init__(self, device):
|
| 179 |
-
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 180 |
-
"zhihan1996/DNA_bert_6", trust_remote_code=True
|
| 181 |
-
)
|
| 182 |
-
self.model = AutoModel.from_pretrained(
|
| 183 |
-
"zhihan1996/DNA_bert_6", trust_remote_code=True
|
| 184 |
-
).to(device)
|
| 185 |
-
self.device = device
|
| 186 |
-
|
| 187 |
-
def embed(self, seqs):
|
| 188 |
-
embs = []
|
| 189 |
-
for s in seqs:
|
| 190 |
-
tokens = self.tokenizer(s, return_tensors="pt", padding=True)[
|
| 191 |
-
"input_ids"
|
| 192 |
-
].to(self.device)
|
| 193 |
-
with torch.no_grad():
|
| 194 |
-
out = self.model(tokens).last_hidden_state.mean(1)
|
| 195 |
-
embs.append(out.cpu().numpy())
|
| 196 |
-
return np.vstack(embs)
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
class NucleotideTransformerEmbedder:
|
| 200 |
-
def __init__(self, device):
|
| 201 |
-
# HF “feature-extraction” returns a list of (L, D) arrays for each input
|
| 202 |
-
# device: “cpu” or “cuda”
|
| 203 |
-
self.pipe = pipeline(
|
| 204 |
-
"feature-extraction",
|
| 205 |
-
model="InstaDeepAI/nucleotide-transformer-500m-1000g",
|
| 206 |
-
device=(
|
| 207 |
-
-1 if device == "cpu" else 0
|
| 208 |
-
), # HF uses -1 for CPU, 0 for GPU #:contentReference[oaicite:0]{index=0}
|
| 209 |
-
)
|
| 210 |
-
|
| 211 |
-
def embed(self, seqs):
|
| 212 |
-
"""
|
| 213 |
-
seqs: List[str] of raw DNA sequences
|
| 214 |
-
returns: (N, D) array, one D-dim vector per sequence
|
| 215 |
-
"""
|
| 216 |
-
all_embeddings = self.pipe(seqs, truncation=True, padding=True)
|
| 217 |
-
# all_embeddings is a List of shape (L, D) arrays
|
| 218 |
-
pooled = [np.mean(x, axis=0) for x in all_embeddings]
|
| 219 |
-
return np.vstack(pooled)
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
# class ESMEmbedder:
|
| 223 |
-
# def __init__(self, device):
|
| 224 |
-
# self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
|
| 225 |
-
# self.batch_converter = self.alphabet.get_batch_converter()
|
| 226 |
-
# self.model.to(device).eval()
|
| 227 |
-
# self.device = device
|
| 228 |
-
|
| 229 |
-
# def embed(self, seqs):
|
| 230 |
-
# batch = [(str(i), seq) for i, seq in enumerate(seqs)]
|
| 231 |
-
# _, _, toks = self.batch_converter(batch)
|
| 232 |
-
# toks = toks.to(self.device)
|
| 233 |
-
# with torch.no_grad():
|
| 234 |
-
# results = self.model(toks, repr_layers=[33], return_contacts=False)
|
| 235 |
-
# reps = results["representations"][33]
|
| 236 |
-
# return reps[:, 1:-1].mean(1).cpu().numpy()
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
class ESMEmbedder:
|
| 240 |
-
def __init__(self, device, model_name="esm2_t33_650M_UR50D"):
|
| 241 |
-
# Try to load the specified ESM-2 model; fallback to esm1b if missing
|
| 242 |
-
self.device = device
|
| 243 |
-
try:
|
| 244 |
-
self.model, self.alphabet = getattr(esm.pretrained, model_name)()
|
| 245 |
-
self.is_esm2 = model_name.lower().startswith("esm2")
|
| 246 |
-
except AttributeError:
|
| 247 |
-
# fallback to ESM-1b
|
| 248 |
-
self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
|
| 249 |
-
self.is_esm2 = False
|
| 250 |
-
self.batch_converter = self.alphabet.get_batch_converter()
|
| 251 |
-
self.model.to(device).eval()
|
| 252 |
-
# determine max length: esm2 models vary; use default 1024 for esm1b
|
| 253 |
-
self.max_len = (
|
| 254 |
-
4096 if self.is_esm2 else 1024
|
| 255 |
-
) # adjust if your esm2 variant has explicit limit
|
| 256 |
-
# for chunking: reserve 2 tokens if model uses BOS/EOS
|
| 257 |
-
self.chunk_size = self.max_len - 2
|
| 258 |
-
self.overlap = self.chunk_size // 4 # 25% overlap to smooth boundaries
|
| 259 |
-
|
| 260 |
-
def _chunk_sequence(self, seq):
|
| 261 |
-
"""
|
| 262 |
-
Return list of possibly overlapping chunks of seq, each <= chunk_size.
|
| 263 |
-
"""
|
| 264 |
-
if len(seq) <= self.chunk_size:
|
| 265 |
-
return [seq]
|
| 266 |
-
step = self.chunk_size - self.overlap
|
| 267 |
-
chunks = []
|
| 268 |
-
for i in range(0, len(seq), step):
|
| 269 |
-
chunk = seq[i : i + self.chunk_size]
|
| 270 |
-
if not chunk:
|
| 271 |
-
break
|
| 272 |
-
chunks.append(chunk)
|
| 273 |
-
return chunks
|
| 274 |
-
|
| 275 |
-
def embed(self, seqs):
|
| 276 |
-
"""
|
| 277 |
-
seqs: List[str] of protein sequences.
|
| 278 |
-
Returns: np.ndarray of shape (N, D) pooled per-sequence embeddings.
|
| 279 |
-
"""
|
| 280 |
-
all_embeddings = []
|
| 281 |
-
for i, seq in enumerate(seqs):
|
| 282 |
-
chunks = self._chunk_sequence(seq)
|
| 283 |
-
chunk_vecs = []
|
| 284 |
-
# process chunks in batch if small number, else sequentially
|
| 285 |
-
for chunk in chunks:
|
| 286 |
-
batch = [(str(i), chunk)]
|
| 287 |
-
_, _, toks = self.batch_converter(batch)
|
| 288 |
-
toks = toks.to(self.device)
|
| 289 |
-
with torch.no_grad():
|
| 290 |
-
results = self.model(toks, repr_layers=[33], return_contacts=False)
|
| 291 |
-
reps = results["representations"][33] # (1, L, D)
|
| 292 |
-
# remove BOS/EOS if present: take 1:-1 if length permits
|
| 293 |
-
if reps.size(1) > 2:
|
| 294 |
-
rep = reps[:, 1:-1].mean(1) # (1, D)
|
| 295 |
-
else:
|
| 296 |
-
rep = reps.mean(1) # fallback
|
| 297 |
-
chunk_vecs.append(rep.squeeze(0)) # (D,)
|
| 298 |
-
if len(chunk_vecs) == 1:
|
| 299 |
-
seq_vec = chunk_vecs[0]
|
| 300 |
-
else:
|
| 301 |
-
# average chunk vectors
|
| 302 |
-
stacked = torch.stack(chunk_vecs, dim=0) # (num_chunks, D)
|
| 303 |
-
seq_vec = stacked.mean(0)
|
| 304 |
-
all_embeddings.append(seq_vec.cpu().numpy())
|
| 305 |
-
return np.vstack(all_embeddings) # (N, D)
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
# class ESMDBPEmbedder:
|
| 309 |
-
# def __init__(self, device):
|
| 310 |
-
# base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
|
| 311 |
-
# model_path = (
|
| 312 |
-
# Path(__file__).resolve().parent.parent
|
| 313 |
-
# / "pretrained" / "ESM-DBP" / "ESM-DBP.model"
|
| 314 |
-
# )
|
| 315 |
-
# checkpoint = torch.load(model_path, map_location="cpu")
|
| 316 |
-
# clean_sd = {}
|
| 317 |
-
# for k, v in checkpoint.items():
|
| 318 |
-
# clean_sd[k.replace("module.", "")] = v
|
| 319 |
-
# result = base_model.load_state_dict(clean_sd, strict=False)
|
| 320 |
-
# if result.missing_keys:
|
| 321 |
-
# print(f"[ESMDBP] missing keys: {result.missing_keys}")
|
| 322 |
-
# if result.unexpected_keys:
|
| 323 |
-
# print(f"[ESMDBP] unexpected keys: {result.unexpected_keys}")
|
| 324 |
-
|
| 325 |
-
# self.model = base_model.to(device).eval()
|
| 326 |
-
# self.alphabet = alphabet
|
| 327 |
-
# self.batch_converter = alphabet.get_batch_converter()
|
| 328 |
-
# self.device = device
|
| 329 |
-
|
| 330 |
-
# def embed(self, seqs):
|
| 331 |
-
# batch = [(str(i), seq) for i, seq in enumerate(seqs)]
|
| 332 |
-
# _, _, toks = self.batch_converter(batch)
|
| 333 |
-
# toks = toks.to(self.device)
|
| 334 |
-
# with torch.no_grad():
|
| 335 |
-
# out = self.model(toks, repr_layers=[33], return_contacts=False)
|
| 336 |
-
# reps = out["representations"][33]
|
| 337 |
-
# # skip start/end tokens
|
| 338 |
-
# return reps[:, 1:-1].mean(1).cpu().numpy()
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
class ESMDBPEmbedder:
|
| 342 |
-
def __init__(self, device):
|
| 343 |
-
base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
|
| 344 |
-
model_path = (
|
| 345 |
-
Path(__file__).resolve().parent.parent
|
| 346 |
-
/ "pretrained"
|
| 347 |
-
/ "ESM-DBP"
|
| 348 |
-
/ "ESM-DBP.model"
|
| 349 |
-
)
|
| 350 |
-
checkpoint = torch.load(model_path, map_location="cpu")
|
| 351 |
-
clean_sd = {}
|
| 352 |
-
for k, v in checkpoint.items():
|
| 353 |
-
clean_sd[k.replace("module.", "")] = v
|
| 354 |
-
result = base_model.load_state_dict(clean_sd, strict=False)
|
| 355 |
-
if result.missing_keys:
|
| 356 |
-
print(f"[ESMDBP] missing keys: {result.missing_keys}")
|
| 357 |
-
if result.unexpected_keys:
|
| 358 |
-
print(f"[ESMDBP] unexpected keys: {result.unexpected_keys}")
|
| 359 |
-
|
| 360 |
-
self.model = base_model.to(device).eval()
|
| 361 |
-
self.alphabet = alphabet
|
| 362 |
-
self.batch_converter = alphabet.get_batch_converter()
|
| 363 |
-
self.device = device
|
| 364 |
-
self.max_len = 1024 # same limit as esm1b
|
| 365 |
-
self.chunk_size = self.max_len - 2
|
| 366 |
-
self.overlap = self.chunk_size // 4
|
| 367 |
-
|
| 368 |
-
def _chunk_sequence(self, seq):
|
| 369 |
-
if len(seq) <= self.chunk_size:
|
| 370 |
-
return [seq]
|
| 371 |
-
step = self.chunk_size - self.overlap
|
| 372 |
-
chunks = []
|
| 373 |
-
for i in range(0, len(seq), step):
|
| 374 |
-
chunk = seq[i : i + self.chunk_size]
|
| 375 |
-
if not chunk:
|
| 376 |
-
break
|
| 377 |
-
chunks.append(chunk)
|
| 378 |
-
return chunks
|
| 379 |
-
|
| 380 |
-
def embed(self, seqs):
|
| 381 |
-
all_embeddings = []
|
| 382 |
-
for i, seq in enumerate(seqs):
|
| 383 |
-
chunks = self._chunk_sequence(seq)
|
| 384 |
-
chunk_vecs = []
|
| 385 |
-
for chunk in chunks:
|
| 386 |
-
batch = [(str(i), chunk)]
|
| 387 |
-
_, _, toks = self.batch_converter(batch)
|
| 388 |
-
toks = toks.to(self.device)
|
| 389 |
-
with torch.no_grad():
|
| 390 |
-
out = self.model(toks, repr_layers=[33], return_contacts=False)
|
| 391 |
-
reps = out["representations"][33]
|
| 392 |
-
if reps.size(1) > 2:
|
| 393 |
-
rep = reps[:, 1:-1].mean(1)
|
| 394 |
-
else:
|
| 395 |
-
rep = reps.mean(1)
|
| 396 |
-
chunk_vecs.append(rep.squeeze(0))
|
| 397 |
-
if len(chunk_vecs) == 1:
|
| 398 |
-
seq_vec = chunk_vecs[0]
|
| 399 |
-
else:
|
| 400 |
-
stacked = torch.stack(chunk_vecs, dim=0)
|
| 401 |
-
seq_vec = stacked.mean(0)
|
| 402 |
-
all_embeddings.append(seq_vec.cpu().numpy())
|
| 403 |
-
return np.vstack(all_embeddings)
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
class GPNEmbedder:
|
| 407 |
-
def __init__(self, device):
|
| 408 |
-
model_name = "songlab/gpn-msa-sapiens"
|
| 409 |
-
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 410 |
-
self.model = AutoModelForMaskedLM.from_pretrained(model_name)
|
| 411 |
-
self.model.to(device)
|
| 412 |
-
self.model.eval()
|
| 413 |
-
self.device = device
|
| 414 |
-
|
| 415 |
-
def embed(self, seqs):
|
| 416 |
-
inputs = self.tokenizer(
|
| 417 |
-
seqs, return_tensors="pt", padding=True, truncation=True
|
| 418 |
-
).to(self.device)
|
| 419 |
-
|
| 420 |
-
with torch.no_grad():
|
| 421 |
-
last_hidden = self.model(**inputs).last_hidden_state
|
| 422 |
-
return last_hidden.mean(dim=1).cpu().numpy()
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
class ProGenEmbedder:
|
| 426 |
-
def __init__(self, device):
|
| 427 |
-
model_name = "jinyuan22/ProGen2-base"
|
| 428 |
-
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 429 |
-
self.model = AutoModel.from_pretrained(model_name).to(device).eval()
|
| 430 |
-
self.device = device
|
| 431 |
-
|
| 432 |
-
def embed(self, seqs):
|
| 433 |
-
inputs = self.tokenizer(
|
| 434 |
-
seqs, return_tensors="pt", padding=True, truncation=True
|
| 435 |
-
).to(self.device)
|
| 436 |
-
with torch.no_grad():
|
| 437 |
-
last_hidden = self.model(**inputs).last_hidden_state
|
| 438 |
-
return last_hidden.mean(dim=1).cpu().numpy()
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
# ---- main pipeline ----
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
def get_embedder(name, device, for_dna=True):
|
| 445 |
-
name = name.lower()
|
| 446 |
-
if for_dna:
|
| 447 |
-
if name == "caduceus":
|
| 448 |
-
return CaduceusEmbedder(device)
|
| 449 |
-
if name == "dnabert":
|
| 450 |
-
return DNABertEmbedder(device)
|
| 451 |
-
if name == "nucleotide":
|
| 452 |
-
return NucleotideTransformerEmbedder(device)
|
| 453 |
-
if name == "gpn":
|
| 454 |
-
return GPNEmbedder(device)
|
| 455 |
-
if name == "segmentnt":
|
| 456 |
-
return SegmentNTEmbedder(device)
|
| 457 |
-
else:
|
| 458 |
-
if name in ("esm",):
|
| 459 |
-
return ESMEmbedder(device)
|
| 460 |
-
if name in ("esm-dbp", "esm_dbp"):
|
| 461 |
-
return ESMDBPEmbedder(device)
|
| 462 |
-
if name == "progen":
|
| 463 |
-
return ProGenEmbedder(device)
|
| 464 |
-
raise ValueError(f"Unknown model {name} (for_dna={for_dna})")
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
def pad_token_embeddings(list_of_arrays, pad_value=0.0):
|
| 468 |
-
"""
|
| 469 |
-
list_of_arrays: list of (L_i, D) numpy arrays
|
| 470 |
-
Returns:
|
| 471 |
-
padded: (N, L_max, D) array
|
| 472 |
-
mask: (N, L_max) boolean array where True = real token, False = padding
|
| 473 |
-
"""
|
| 474 |
-
N = len(list_of_arrays)
|
| 475 |
-
D = list_of_arrays[0].shape[1]
|
| 476 |
-
L_max = max(arr.shape[0] for arr in list_of_arrays)
|
| 477 |
-
padded = np.full((N, L_max, D), pad_value, dtype=list_of_arrays[0].dtype)
|
| 478 |
-
mask = np.zeros((N, L_max), dtype=bool)
|
| 479 |
-
for i, arr in enumerate(list_of_arrays):
|
| 480 |
-
L = arr.shape[0]
|
| 481 |
-
padded[i, :L] = arr
|
| 482 |
-
mask[i, :L] = True
|
| 483 |
-
return padded, mask
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
def embed_and_save(seqs, ids, embedder, out_path):
|
| 487 |
-
embs = embedder.embed(seqs)
|
| 488 |
-
|
| 489 |
-
# Decide whether we got variable-length per-token outputs (list of (L, D))
|
| 490 |
-
is_variable_token = (
|
| 491 |
-
isinstance(embs, (list, tuple))
|
| 492 |
-
and len(embs) > 0
|
| 493 |
-
and hasattr(embs[0], "shape")
|
| 494 |
-
and embs[0].ndim == 2
|
| 495 |
-
)
|
| 496 |
-
|
| 497 |
-
if is_variable_token:
|
| 498 |
-
# pad to (N, L_max, D) + mask
|
| 499 |
-
padded, mask = pad_token_embeddings(embs)
|
| 500 |
-
# Save both embeddings and mask together in an .npz for convenience
|
| 501 |
-
np.savez_compressed(
|
| 502 |
-
out_path.with_suffix(".caduceus.npz"),
|
| 503 |
-
embeddings=padded,
|
| 504 |
-
mask=mask,
|
| 505 |
-
ids=np.array(ids, dtype=object),
|
| 506 |
-
)
|
| 507 |
-
else:
|
| 508 |
-
# fixed shape output, e.g., pooled (N, D)
|
| 509 |
-
array = np.vstack(embs) if isinstance(embs, list) else embs
|
| 510 |
-
np.save(out_path, array)
|
| 511 |
-
with open(out_path.with_suffix(".ids"), "w") as f:
|
| 512 |
-
f.write("\n".join(ids))
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
if __name__ == "__main__":
|
| 516 |
-
|
| 517 |
-
p = argparse.ArgumentParser()
|
| 518 |
-
p.add_argument(
|
| 519 |
-
"--peak-fasta",
|
| 520 |
-
default="binding_peaks_unique.fa",
|
| 521 |
-
help="FASTA of deduplicated binding peak sequences; if present this is used for DNA embedding instead of genome JSONs",
|
| 522 |
-
)
|
| 523 |
-
p.add_argument(
|
| 524 |
-
"--genome-json-dir",
|
| 525 |
-
default=None,
|
| 526 |
-
help="(fallback) directory of UCSC JSONs for full chromosome embedding if peak FASTA is missing or you explicitly want chromosomes",
|
| 527 |
-
)
|
| 528 |
-
p.add_argument(
|
| 529 |
-
"--skip-dna",
|
| 530 |
-
action="store_true",
|
| 531 |
-
help="if set, skip the chromosome embedding step",
|
| 532 |
-
) # if glm embeddings successful but not plm embeddings
|
| 533 |
-
p.add_argument("--tf-fasta", required=True, help="input TF FASTA file")
|
| 534 |
-
p.add_argument("--chrom-model", default="caduceus")
|
| 535 |
-
p.add_argument("--tf-model", default="esm-dbp")
|
| 536 |
-
p.add_argument(
|
| 537 |
-
"--out-dir", default="data_files/processed/tfclust/hg38_tf/embeddings"
|
| 538 |
-
)
|
| 539 |
-
p.add_argument("--device", default="cpu")
|
| 540 |
-
args = p.parse_args()
|
| 541 |
-
|
| 542 |
-
os.makedirs(args.out_dir, exist_ok=True)
|
| 543 |
-
device = args.device
|
| 544 |
-
|
| 545 |
-
if not args.skip_dna:
|
| 546 |
-
peak_fasta = Path(args.peak_fasta)
|
| 547 |
-
if peak_fasta.exists():
|
| 548 |
-
# Load peak sequences from FASTA
|
| 549 |
-
from Bio import SeqIO
|
| 550 |
-
|
| 551 |
-
peak_seqs = []
|
| 552 |
-
peak_ids = []
|
| 553 |
-
for rec in SeqIO.parse(peak_fasta, "fasta"):
|
| 554 |
-
peak_ids.append(rec.id)
|
| 555 |
-
peak_seqs.append(str(rec.seq))
|
| 556 |
-
print(
|
| 557 |
-
f"Embedding {len(peak_seqs)} binding peak sequences from {peak_fasta}",
|
| 558 |
-
flush=True,
|
| 559 |
-
)
|
| 560 |
-
dna_embedder = get_embedder(args.chrom_model, device, for_dna=True)
|
| 561 |
-
out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy"
|
| 562 |
-
embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks)
|
| 563 |
-
elif args.genome_json_dir:
|
| 564 |
-
# Legacy: load full chromosomes from JSONs (chr1–22, X, Y, M)
|
| 565 |
-
genome_dir = Path(args.genome_json_dir)
|
| 566 |
-
chrom_seqs, chrom_ids = [], []
|
| 567 |
-
primary_pattern = re.compile(
|
| 568 |
-
r"^hg38_chr(?:[1-9]|1[0-9]|2[0-2]|X|Y|M)\.json$"
|
| 569 |
-
)
|
| 570 |
-
for j in sorted(genome_dir.iterdir()):
|
| 571 |
-
if not primary_pattern.match(j.name):
|
| 572 |
-
continue
|
| 573 |
-
data = json.loads(j.read_text())
|
| 574 |
-
seq = data.get("dna") or data.get("sequence")
|
| 575 |
-
chrom = data.get("chrom") or j.stem.split("_")[-1]
|
| 576 |
-
chrom_seqs.append(seq)
|
| 577 |
-
chrom_ids.append(chrom)
|
| 578 |
-
cutoff = CaduceusEmbedder(device).chunk_size
|
| 579 |
-
long_chroms = [
|
| 580 |
-
(chrom, len(seq))
|
| 581 |
-
for chrom, seq in zip(chrom_ids, chrom_seqs)
|
| 582 |
-
if len(seq) > cutoff
|
| 583 |
-
]
|
| 584 |
-
if long_chroms:
|
| 585 |
-
print(
|
| 586 |
-
"⚠️ Chromosomes exceeding Caduceus max tokens ({}):".format(cutoff)
|
| 587 |
-
)
|
| 588 |
-
for chrom, L in long_chroms:
|
| 589 |
-
print(f" {chrom}: {L} bases")
|
| 590 |
-
else:
|
| 591 |
-
print("All chromosomes ≤ Caduceus limit ({}).".format(cutoff))
|
| 592 |
-
|
| 593 |
-
chrom_embedder = get_embedder(args.chrom_model, device, for_dna=True)
|
| 594 |
-
out_chrom = Path(args.out_dir) / f"chrom_{args.chrom_model}.npy"
|
| 595 |
-
embed_and_save(chrom_seqs, chrom_ids, chrom_embedder, out_chrom)
|
| 596 |
-
else:
|
| 597 |
-
raise ValueError(
|
| 598 |
-
"No input for DNA embedding: provide a peak FASTA (default binding_peaks_unique.fa) or set --genome-json-dir for chromosome JSONs."
|
| 599 |
-
)
|
| 600 |
-
|
| 601 |
-
# Load TF sequences
|
| 602 |
-
tf_seqs, tf_ids = [], []
|
| 603 |
-
for record in SeqIO.parse(args.tf_fasta, "fasta"):
|
| 604 |
-
tf_ids.append(record.id)
|
| 605 |
-
tf_seqs.append(str(record.seq))
|
| 606 |
-
|
| 607 |
-
# embed and save
|
| 608 |
-
tf_embedder = get_embedder(args.tf_model, device, for_dna=False)
|
| 609 |
-
out_tf = Path(args.out_dir) / f"tf_{args.tf_model}.npy"
|
| 610 |
-
embed_and_save(tf_seqs, tf_ids, tf_embedder, out_tf)
|
| 611 |
-
|
| 612 |
-
print("Done.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dpacman/classifier/model_tmp/extract_tf_symbols.py
DELETED
|
@@ -1,30 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
import pandas as pd
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
|
| 5 |
-
FINAL_CSV = Path("/home/a03-akrishna/DPACMAN/data_files/processed/final.csv")
|
| 6 |
-
OUT_SYMBOLS = Path("tf_symbols.txt")
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
def normalize_tf(tf_id: str) -> str:
|
| 10 |
-
return tf_id.split("_seq")[0].upper()
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def main():
|
| 14 |
-
df = pd.read_csv(FINAL_CSV, dtype=str)
|
| 15 |
-
if "TF_id" not in df.columns:
|
| 16 |
-
raise RuntimeError("final.csv missing TF_id column")
|
| 17 |
-
tf_raw = df["TF_id"].dropna().unique().tolist()
|
| 18 |
-
normalized = sorted({normalize_tf(t) for t in tf_raw})
|
| 19 |
-
print(f"Unique raw TF_id count: {len(tf_raw)}")
|
| 20 |
-
print(f"Unique normalized TF symbols: {len(normalized)}")
|
| 21 |
-
with open(OUT_SYMBOLS, "w") as f:
|
| 22 |
-
for s in normalized:
|
| 23 |
-
f.write(s + "\n")
|
| 24 |
-
print(f"Wrote normalized TF symbols to {OUT_SYMBOLS}")
|
| 25 |
-
# Optional: show sample
|
| 26 |
-
print("Sample symbols:", normalized[:50])
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
if __name__ == "__main__":
|
| 30 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dpacman/classifier/model_tmp/make_pair_list.py
DELETED
|
@@ -1,282 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
import argparse
|
| 3 |
-
import numpy as np
|
| 4 |
-
import pandas as pd
|
| 5 |
-
from pathlib import Path
|
| 6 |
-
import random
|
| 7 |
-
import sys
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
def read_ids_file(p):
|
| 11 |
-
p = Path(p)
|
| 12 |
-
if not p.exists():
|
| 13 |
-
raise FileNotFoundError(f"IDs file not found: {p}")
|
| 14 |
-
return [line.strip() for line in p.open() if line.strip()]
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
def split_embeddings(emb_path, ids_path, out_dir, prefix):
|
| 18 |
-
out_dir = Path(out_dir)
|
| 19 |
-
out_dir.mkdir(parents=True, exist_ok=True)
|
| 20 |
-
|
| 21 |
-
if not Path(emb_path).exists():
|
| 22 |
-
raise FileNotFoundError(f"Embedding file not found: {emb_path}")
|
| 23 |
-
if not Path(ids_path).exists():
|
| 24 |
-
raise FileNotFoundError(f"IDs file not found: {ids_path}")
|
| 25 |
-
|
| 26 |
-
if emb_path.endswith(".npz"):
|
| 27 |
-
data = np.load(emb_path, allow_pickle=True)
|
| 28 |
-
if "embeddings" in data:
|
| 29 |
-
emb = data["embeddings"]
|
| 30 |
-
else:
|
| 31 |
-
raise ValueError(f"{emb_path} missing 'embeddings' key")
|
| 32 |
-
else:
|
| 33 |
-
emb = np.load(emb_path)
|
| 34 |
-
|
| 35 |
-
ids = read_ids_file(ids_path)
|
| 36 |
-
if len(ids) != emb.shape[0]:
|
| 37 |
-
print(
|
| 38 |
-
f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}",
|
| 39 |
-
file=sys.stderr,
|
| 40 |
-
)
|
| 41 |
-
|
| 42 |
-
mapping = {}
|
| 43 |
-
for i, ident in enumerate(ids):
|
| 44 |
-
if i >= emb.shape[0]:
|
| 45 |
-
print(
|
| 46 |
-
f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr
|
| 47 |
-
)
|
| 48 |
-
continue
|
| 49 |
-
arr = emb[i]
|
| 50 |
-
out_file = out_dir / f"{prefix}_{ident}.npy"
|
| 51 |
-
np.save(out_file, arr)
|
| 52 |
-
mapping[ident] = str(out_file)
|
| 53 |
-
return mapping
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
def extract_symbol_from_tf_id(full_id: str) -> str:
|
| 57 |
-
"""
|
| 58 |
-
Given a TF embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN',
|
| 59 |
-
return the gene symbol uppercase (e.g., 'ZBTB5').
|
| 60 |
-
"""
|
| 61 |
-
if "|" in full_id:
|
| 62 |
-
try:
|
| 63 |
-
# format sp|Accession|SYMBOL_HUMAN
|
| 64 |
-
genepart = full_id.split("|")[2]
|
| 65 |
-
except IndexError:
|
| 66 |
-
genepart = full_id
|
| 67 |
-
else:
|
| 68 |
-
genepart = full_id
|
| 69 |
-
symbol = genepart.split("_")[0]
|
| 70 |
-
return symbol.upper()
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
def build_tf_symbol_map(tf_map):
|
| 74 |
-
"""
|
| 75 |
-
Build mapping gene_symbol -> list of embedding paths.
|
| 76 |
-
"""
|
| 77 |
-
symbol_map = {}
|
| 78 |
-
for full_id, path in tf_map.items():
|
| 79 |
-
symbol = extract_symbol_from_tf_id(full_id)
|
| 80 |
-
symbol_map.setdefault(symbol, []).append(path)
|
| 81 |
-
return symbol_map
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
def tf_key_from_path(path: str) -> str:
|
| 85 |
-
"""
|
| 86 |
-
Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract normalized symbol 'ZBTB5'.
|
| 87 |
-
"""
|
| 88 |
-
stem = Path(path).stem # e.g., tf_sp|O15062|ZBTB5_HUMAN
|
| 89 |
-
# remove leading prefix if present (tf_)
|
| 90 |
-
if "_" in stem:
|
| 91 |
-
_, rest = stem.split("_", 1)
|
| 92 |
-
else:
|
| 93 |
-
rest = stem
|
| 94 |
-
return extract_symbol_from_tf_id(rest)
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
def dna_key_from_path(path: str) -> str:
|
| 98 |
-
"""
|
| 99 |
-
Given .../dna_peak42.npy -> 'peak42'
|
| 100 |
-
"""
|
| 101 |
-
stem = Path(path).stem
|
| 102 |
-
if "_" in stem:
|
| 103 |
-
_, rest = stem.split("_", 1)
|
| 104 |
-
else:
|
| 105 |
-
rest = stem
|
| 106 |
-
return rest
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
def main():
|
| 110 |
-
parser = argparse.ArgumentParser(
|
| 111 |
-
description="Build TF-DNA pair list from final.csv with gene-symbol normalization for TFs."
|
| 112 |
-
)
|
| 113 |
-
parser.add_argument(
|
| 114 |
-
"--final_csv", required=True, help="final.csv with TF_id and dna_sequence"
|
| 115 |
-
)
|
| 116 |
-
parser.add_argument(
|
| 117 |
-
"--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)"
|
| 118 |
-
)
|
| 119 |
-
parser.add_argument(
|
| 120 |
-
"--dna_ids", required=True, help="IDs file for DNA embeddings (e.g., peak*.ids)"
|
| 121 |
-
)
|
| 122 |
-
parser.add_argument(
|
| 123 |
-
"--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)"
|
| 124 |
-
)
|
| 125 |
-
parser.add_argument(
|
| 126 |
-
"--tf_ids",
|
| 127 |
-
required=True,
|
| 128 |
-
help="IDs file for TF embeddings (e.g., sp|...|... ids)",
|
| 129 |
-
)
|
| 130 |
-
parser.add_argument("--out_dir", required=True, help="Output directory")
|
| 131 |
-
parser.add_argument(
|
| 132 |
-
"--neg_per_positive",
|
| 133 |
-
type=int,
|
| 134 |
-
default=2,
|
| 135 |
-
help="Negatives per positive (half same-TF, half same-DNA)",
|
| 136 |
-
)
|
| 137 |
-
parser.add_argument("--seed", type=int, default=42)
|
| 138 |
-
args = parser.parse_args()
|
| 139 |
-
|
| 140 |
-
random.seed(args.seed)
|
| 141 |
-
out_dir = Path(args.out_dir)
|
| 142 |
-
out_dir.mkdir(parents=True, exist_ok=True)
|
| 143 |
-
|
| 144 |
-
# Load final.csv
|
| 145 |
-
df = pd.read_csv(args.final_csv, dtype=str)
|
| 146 |
-
if "TF_id" not in df.columns or "dna_sequence" not in df.columns:
|
| 147 |
-
raise RuntimeError("final.csv must have columns TF_id and dna_sequence")
|
| 148 |
-
|
| 149 |
-
# Assign dna_id (unique per dna_sequence)
|
| 150 |
-
unique_seqs = df["dna_sequence"].drop_duplicates().tolist()
|
| 151 |
-
seq_to_id = {seq: f"peak{i}" for i, seq in enumerate(unique_seqs)}
|
| 152 |
-
df["dna_id"] = df["dna_sequence"].map(seq_to_id)
|
| 153 |
-
enriched_csv = out_dir / "final_with_dna_id.csv"
|
| 154 |
-
df.to_csv(enriched_csv, index=False)
|
| 155 |
-
print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}")
|
| 156 |
-
|
| 157 |
-
# Split embeddings into per-item files
|
| 158 |
-
print(
|
| 159 |
-
f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}"
|
| 160 |
-
)
|
| 161 |
-
dna_map = split_embeddings(
|
| 162 |
-
args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna"
|
| 163 |
-
)
|
| 164 |
-
print(
|
| 165 |
-
f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})"
|
| 166 |
-
)
|
| 167 |
-
print(
|
| 168 |
-
f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}"
|
| 169 |
-
)
|
| 170 |
-
tf_map = split_embeddings(
|
| 171 |
-
args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf"
|
| 172 |
-
)
|
| 173 |
-
print(
|
| 174 |
-
f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})"
|
| 175 |
-
)
|
| 176 |
-
|
| 177 |
-
# Build gene-symbol normalized map
|
| 178 |
-
tf_symbol_map = build_tf_symbol_map(tf_map)
|
| 179 |
-
print(f"[i] TF symbol map keys (sample): {list(tf_symbol_map.keys())[:30]}")
|
| 180 |
-
|
| 181 |
-
# Diagnostic overlaps
|
| 182 |
-
norm_tf_in_final = set(t.split("_seq")[0].upper() for t in df["TF_id"].unique())
|
| 183 |
-
available_tf_symbols = set(tf_symbol_map.keys())
|
| 184 |
-
intersect_tf = norm_tf_in_final & available_tf_symbols
|
| 185 |
-
print(f"[i] Unique normalized TF symbols in final.csv: {len(norm_tf_in_final)}")
|
| 186 |
-
print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}")
|
| 187 |
-
print(f"[i] Intersection count: {len(intersect_tf)}")
|
| 188 |
-
if len(intersect_tf) == 0:
|
| 189 |
-
print(
|
| 190 |
-
"[ERROR] No overlap between normalized TF_id and TF embedding symbols.",
|
| 191 |
-
file=sys.stderr,
|
| 192 |
-
)
|
| 193 |
-
print(
|
| 194 |
-
"Sample normalized TFs from final.csv:",
|
| 195 |
-
sorted(list(norm_tf_in_final))[:30],
|
| 196 |
-
file=sys.stderr,
|
| 197 |
-
)
|
| 198 |
-
print(
|
| 199 |
-
"Sample available TF symbols:",
|
| 200 |
-
sorted(list(available_tf_symbols))[:30],
|
| 201 |
-
file=sys.stderr,
|
| 202 |
-
)
|
| 203 |
-
sys.exit(1)
|
| 204 |
-
|
| 205 |
-
dna_ids_final = set(df["dna_id"].unique())
|
| 206 |
-
available_dna_ids = set(dna_map.keys())
|
| 207 |
-
intersect_dna = dna_ids_final & available_dna_ids
|
| 208 |
-
print(
|
| 209 |
-
f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. Intersection: {len(intersect_dna)}"
|
| 210 |
-
)
|
| 211 |
-
if len(intersect_dna) == 0:
|
| 212 |
-
print("[ERROR] No overlap on DNA ids.", file=sys.stderr)
|
| 213 |
-
sys.exit(1)
|
| 214 |
-
|
| 215 |
-
# Build positive pairs
|
| 216 |
-
positives = []
|
| 217 |
-
for _, row in df.iterrows():
|
| 218 |
-
tf_raw = row["TF_id"]
|
| 219 |
-
tf_symbol = tf_raw.split("_seq")[0].upper()
|
| 220 |
-
dnaid = row["dna_id"]
|
| 221 |
-
if tf_symbol not in tf_symbol_map:
|
| 222 |
-
continue
|
| 223 |
-
if dnaid not in dna_map:
|
| 224 |
-
continue
|
| 225 |
-
# pick the first embedding for that symbol
|
| 226 |
-
tf_embedding_path = tf_symbol_map[tf_symbol][0]
|
| 227 |
-
positives.append((tf_embedding_path, dna_map[dnaid], 1))
|
| 228 |
-
print(f"[i] Constructed {len(positives)} positive pairs after TF symbol resolution")
|
| 229 |
-
|
| 230 |
-
if len(positives) == 0:
|
| 231 |
-
print(
|
| 232 |
-
"[ERROR] No positive pairs could be constructed; aborting.", file=sys.stderr
|
| 233 |
-
)
|
| 234 |
-
sys.exit(1)
|
| 235 |
-
|
| 236 |
-
# Build negative samples
|
| 237 |
-
all_tf_symbols = sorted(tf_symbol_map.keys())
|
| 238 |
-
all_dnaids = sorted(dna_map.keys())
|
| 239 |
-
positive_set = set()
|
| 240 |
-
for tf_path, dna_path, _ in positives:
|
| 241 |
-
tf_key = tf_key_from_path(tf_path)
|
| 242 |
-
dna_key = dna_key_from_path(dna_path)
|
| 243 |
-
positive_set.add((tf_key, dna_key))
|
| 244 |
-
|
| 245 |
-
negatives = []
|
| 246 |
-
half = args.neg_per_positive // 2
|
| 247 |
-
for tf_path, dna_path, _ in positives:
|
| 248 |
-
tf_key = tf_key_from_path(tf_path)
|
| 249 |
-
dna_key = dna_key_from_path(dna_path)
|
| 250 |
-
# same TF, different DNA
|
| 251 |
-
for _ in range(half):
|
| 252 |
-
candidate_dna = random.choice(all_dnaids)
|
| 253 |
-
if candidate_dna == dna_key or (tf_key, candidate_dna) in positive_set:
|
| 254 |
-
continue
|
| 255 |
-
negatives.append((tf_path, dna_map[candidate_dna], 0))
|
| 256 |
-
# same DNA, different TF
|
| 257 |
-
for _ in range(half):
|
| 258 |
-
candidate_tf_symbol = random.choice(all_tf_symbols)
|
| 259 |
-
if (
|
| 260 |
-
candidate_tf_symbol == tf_key
|
| 261 |
-
or (candidate_tf_symbol, dna_key) in positive_set
|
| 262 |
-
):
|
| 263 |
-
continue
|
| 264 |
-
# pick its first embedding
|
| 265 |
-
candidate_tf_path = tf_symbol_map[candidate_tf_symbol][0]
|
| 266 |
-
negatives.append((candidate_tf_path, dna_map[dnaid], 0))
|
| 267 |
-
|
| 268 |
-
print(
|
| 269 |
-
f"[i] Sampled {len(negatives)} negatives (neg_per_positive={args.neg_per_positive})"
|
| 270 |
-
)
|
| 271 |
-
|
| 272 |
-
# Write pair list
|
| 273 |
-
pair_list_path = out_dir / "pair_list.tsv"
|
| 274 |
-
with open(pair_list_path, "w") as f:
|
| 275 |
-
for binder_path, glm_path, label in positives + negatives:
|
| 276 |
-
# binder=TF, glm=DNA
|
| 277 |
-
f.write(f"{binder_path}\t{glm_path}\t{label}\n")
|
| 278 |
-
print(f"[i] Wrote {len(positives)+len(negatives)} examples to {pair_list_path}")
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
if __name__ == "__main__":
|
| 282 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dpacman/classifier/model_tmp/make_peak_fasta.py
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
import pandas as pd
|
| 2 |
-
from pathlib import Path
|
| 3 |
-
|
| 4 |
-
df = pd.read_csv(
|
| 5 |
-
"/home/a03-akrishna/DPACMAN/data_files/processed/final.csv", dtype=str
|
| 6 |
-
) # adjust path if needed
|
| 7 |
-
# get unique sequences
|
| 8 |
-
uniq = df[["dna_sequence"]].drop_duplicates().reset_index(drop=True)
|
| 9 |
-
# make headers: e.g., peak0, peak1, ...
|
| 10 |
-
out_fa = Path("binding_peaks_unique.fa")
|
| 11 |
-
with open(out_fa, "w") as f:
|
| 12 |
-
for i, seq in enumerate(uniq["dna_sequence"]):
|
| 13 |
-
header = f">peak{i}"
|
| 14 |
-
f.write(f"{header}\n{seq}\n")
|
| 15 |
-
print(f"Wrote {len(uniq)} unique binding sequences to {out_fa}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dpacman/classifier/model_tmp/model.py
DELETED
|
@@ -1,111 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
from torch import nn
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
class LocalCNN(nn.Module):
|
| 6 |
-
def __init__(self, dim: int = 256, kernel_size: int = 3):
|
| 7 |
-
super().__init__()
|
| 8 |
-
padding = kernel_size // 2
|
| 9 |
-
self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding)
|
| 10 |
-
self.act = nn.GELU()
|
| 11 |
-
self.ln = nn.LayerNorm(dim)
|
| 12 |
-
|
| 13 |
-
def forward(self, x: torch.Tensor):
|
| 14 |
-
# x: (batch, L, dim)
|
| 15 |
-
out = self.conv(x.transpose(1, 2)) # → (batch, dim, L)
|
| 16 |
-
out = self.act(out)
|
| 17 |
-
out = out.transpose(1, 2) # → (batch, L, dim)
|
| 18 |
-
return self.ln(out + x) # residual
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
class CrossModalBlock(nn.Module):
|
| 22 |
-
def __init__(self, dim: int = 256, heads: int = 8):
|
| 23 |
-
super().__init__()
|
| 24 |
-
# self-attention for both sides
|
| 25 |
-
self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True)
|
| 26 |
-
self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True)
|
| 27 |
-
self.ln_b1 = nn.LayerNorm(dim)
|
| 28 |
-
self.ln_g1 = nn.LayerNorm(dim)
|
| 29 |
-
|
| 30 |
-
self.ffn_b = nn.Sequential(
|
| 31 |
-
nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
|
| 32 |
-
)
|
| 33 |
-
self.ffn_g = nn.Sequential(
|
| 34 |
-
nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
|
| 35 |
-
)
|
| 36 |
-
self.ln_b2 = nn.LayerNorm(dim)
|
| 37 |
-
self.ln_g2 = nn.LayerNorm(dim)
|
| 38 |
-
|
| 39 |
-
# cross attention (binder queries, glm keys/values)
|
| 40 |
-
self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
|
| 41 |
-
self.ln_c1 = nn.LayerNorm(dim)
|
| 42 |
-
self.ffn_c = nn.Sequential(
|
| 43 |
-
nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
|
| 44 |
-
)
|
| 45 |
-
self.ln_c2 = nn.LayerNorm(dim)
|
| 46 |
-
|
| 47 |
-
def forward(self, binder: torch.Tensor, glm: torch.Tensor):
|
| 48 |
-
"""
|
| 49 |
-
binder: (batch, Lb, dim)
|
| 50 |
-
glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
|
| 51 |
-
returns: updated binder representation (batch, Lb, dim)
|
| 52 |
-
"""
|
| 53 |
-
# binder self-attn + ffn
|
| 54 |
-
b = binder
|
| 55 |
-
b_sa, _ = self.sa_binder(b, b, b)
|
| 56 |
-
b = self.ln_b1(b + b_sa)
|
| 57 |
-
b_ff = self.ffn_b(b)
|
| 58 |
-
b = self.ln_b2(b + b_ff)
|
| 59 |
-
|
| 60 |
-
# glm self-attn + ffn
|
| 61 |
-
g = glm
|
| 62 |
-
g_sa, _ = self.sa_glm(g, g, g)
|
| 63 |
-
g = self.ln_g1(g + g_sa)
|
| 64 |
-
g_ff = self.ffn_g(g)
|
| 65 |
-
g = self.ln_g2(g + g_ff)
|
| 66 |
-
|
| 67 |
-
# cross-attention: binder queries glm
|
| 68 |
-
c_sa, _ = self.cross_attn(b, g, g)
|
| 69 |
-
c = self.ln_c1(b + c_sa)
|
| 70 |
-
c_ff = self.ffn_c(c)
|
| 71 |
-
c = self.ln_c2(c + c_ff)
|
| 72 |
-
return c # (batch, Lb, dim)
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
class BindPredictor(nn.Module):
|
| 76 |
-
def __init__(
|
| 77 |
-
self,
|
| 78 |
-
input_dim: int = 256,
|
| 79 |
-
hidden_dim: int = 256,
|
| 80 |
-
heads: int = 8,
|
| 81 |
-
num_layers: int = 4,
|
| 82 |
-
use_local_cnn_on_glm: bool = True,
|
| 83 |
-
):
|
| 84 |
-
super().__init__()
|
| 85 |
-
self.proj_binder = nn.Linear(input_dim, hidden_dim)
|
| 86 |
-
self.proj_glm = nn.Linear(input_dim, hidden_dim)
|
| 87 |
-
self.use_local_cnn = use_local_cnn_on_glm
|
| 88 |
-
self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity()
|
| 89 |
-
|
| 90 |
-
self.layers = nn.ModuleList(
|
| 91 |
-
[CrossModalBlock(hidden_dim, heads) for _ in range(num_layers)]
|
| 92 |
-
)
|
| 93 |
-
|
| 94 |
-
self.ln_out = nn.LayerNorm(hidden_dim)
|
| 95 |
-
self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())
|
| 96 |
-
|
| 97 |
-
def forward(self, binder_emb, glm_emb):
|
| 98 |
-
"""
|
| 99 |
-
binder_emb, glm_emb: (batch, L, input_dim)
|
| 100 |
-
"""
|
| 101 |
-
b = self.proj_binder(binder_emb) # (B, Lb, hidden_dim)
|
| 102 |
-
g = self.proj_glm(glm_emb) # (B, Lg, hidden_dim)
|
| 103 |
-
if self.use_local_cnn:
|
| 104 |
-
g = self.local_cnn(g) # local context injected
|
| 105 |
-
|
| 106 |
-
for layer in self.layers:
|
| 107 |
-
b = layer(b, g) # update binder with cross-modal info
|
| 108 |
-
|
| 109 |
-
pooled = b.mean(dim=1) # (B, hidden_dim)
|
| 110 |
-
out = self.ln_out(pooled)
|
| 111 |
-
return self.head(out).squeeze(-1) # (B,)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dpacman/classifier/model_tmp/prep_splits.py
DELETED
|
@@ -1,157 +0,0 @@
|
|
| 1 |
-
import numpy as np
|
| 2 |
-
import pandas as pd
|
| 3 |
-
import sys
|
| 4 |
-
import json
|
| 5 |
-
from sklearn.decomposition import TruncatedSVD
|
| 6 |
-
from sklearn.model_selection import train_test_split
|
| 7 |
-
from collections import Counter
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
def parse_pair_list(pair_list_path):
|
| 11 |
-
binder_paths, glm_paths, labels = [], [], []
|
| 12 |
-
with open(pair_list_path) as f:
|
| 13 |
-
for lineno, line in enumerate(f, start=1):
|
| 14 |
-
if not line.strip():
|
| 15 |
-
continue
|
| 16 |
-
parts = line.strip().split()
|
| 17 |
-
if len(parts) != 3:
|
| 18 |
-
print(
|
| 19 |
-
f"[WARN] skipping malformed line {lineno}: {line.strip()}",
|
| 20 |
-
file=sys.stderr,
|
| 21 |
-
)
|
| 22 |
-
continue
|
| 23 |
-
b, g, l = parts
|
| 24 |
-
try:
|
| 25 |
-
lab = int(l)
|
| 26 |
-
except ValueError:
|
| 27 |
-
print(f"[WARN] invalid label on line {lineno}: {l}", file=sys.stderr)
|
| 28 |
-
continue
|
| 29 |
-
binder_paths.append(b)
|
| 30 |
-
glm_paths.append(g)
|
| 31 |
-
labels.append(lab)
|
| 32 |
-
return binder_paths, glm_paths, labels
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def build_tf_compressed_cache(binder_paths, target_dim=256):
|
| 36 |
-
"""
|
| 37 |
-
Load all unique TF (binder) embeddings, fit reduction if needed, and return dict mapping path->(L, target_dim) array.
|
| 38 |
-
"""
|
| 39 |
-
unique_paths = sorted(set(binder_paths))
|
| 40 |
-
print(
|
| 41 |
-
f"[i] Found {len(unique_paths)} unique TF embedding files to compress.",
|
| 42 |
-
flush=True,
|
| 43 |
-
)
|
| 44 |
-
# Load all embeddings to determine dimensionality
|
| 45 |
-
samples = []
|
| 46 |
-
for p in unique_paths:
|
| 47 |
-
arr = np.load(p)
|
| 48 |
-
samples.append(arr)
|
| 49 |
-
# Determine if reduction needed: assume all have same embedding width
|
| 50 |
-
first = samples[0]
|
| 51 |
-
orig_dim = first.shape[1] if first.ndim == 2 else 1
|
| 52 |
-
reduction_needed = orig_dim != target_dim
|
| 53 |
-
tf_cache = {}
|
| 54 |
-
|
| 55 |
-
if reduction_needed:
|
| 56 |
-
# Build matrix to fit SVD: we need a 2D matrix per embedding; if lengths vary we can't directly stack.
|
| 57 |
-
# We'll do reduction per sequence individually using TruncatedSVD on concatenated flattened features:
|
| 58 |
-
# Simplest: for variable lengths, reduce each embedding separately with a learned linear projection.
|
| 59 |
-
# Here we fit a single TruncatedSVD on the concatenation of all sequence tokens (flattened) by padding/truncating to a fixed length.
|
| 60 |
-
# To avoid complexity, use PCA-like linear projection learned via SVD on mean-pooled vectors:
|
| 61 |
-
pooled = []
|
| 62 |
-
for arr in samples:
|
| 63 |
-
if arr.ndim == 2:
|
| 64 |
-
pooled.append(arr.mean(axis=0)) # (orig_dim,)
|
| 65 |
-
else:
|
| 66 |
-
pooled.append(arr) # degenerate
|
| 67 |
-
pooled_mat = np.stack(pooled, axis=0) # (N, orig_dim)
|
| 68 |
-
print(
|
| 69 |
-
f"[i] Fitting TruncatedSVD on TF pooled embeddings: {pooled_mat.shape} -> {target_dim}",
|
| 70 |
-
flush=True,
|
| 71 |
-
)
|
| 72 |
-
svd = TruncatedSVD(n_components=target_dim, random_state=42)
|
| 73 |
-
reduced_pooled = svd.fit_transform(pooled_mat) # (N, target_dim)
|
| 74 |
-
|
| 75 |
-
# For each original embedding, project token-level vectors by multiplying token vector with svd.components_.T
|
| 76 |
-
# svd.components_: (target_dim, orig_dim) so projection matrix is (orig_dim, target_dim)
|
| 77 |
-
proj_mat = svd.components_.T # (orig_dim, target_dim)
|
| 78 |
-
for i, p in enumerate(unique_paths):
|
| 79 |
-
arr = samples[i] # shape (L, orig_dim)
|
| 80 |
-
if arr.ndim == 1:
|
| 81 |
-
arr2 = arr @ proj_mat # (target_dim,)
|
| 82 |
-
else:
|
| 83 |
-
# project each token: (L, orig_dim) @ (orig_dim, target_dim) -> (L, target_dim)
|
| 84 |
-
arr2 = arr @ proj_mat
|
| 85 |
-
tf_cache[p] = arr2 # reduced per-token representation
|
| 86 |
-
print("[i] Completed compression of TF embeddings.", flush=True)
|
| 87 |
-
else:
|
| 88 |
-
# already correct dim: just cache originals
|
| 89 |
-
print(
|
| 90 |
-
f"[i] TF embeddings already {target_dim}-dimensional; skipping reduction.",
|
| 91 |
-
flush=True,
|
| 92 |
-
)
|
| 93 |
-
for i, p in enumerate(unique_paths):
|
| 94 |
-
arr = samples[i]
|
| 95 |
-
tf_cache[p] = arr
|
| 96 |
-
return tf_cache
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
def main():
|
| 100 |
-
# df = pd.read_csv("../data_files/processed/fimo/ananya_aug4_2025_final.csv")
|
| 101 |
-
|
| 102 |
-
binder_paths, glm_paths, labels = parse_pair_list(
|
| 103 |
-
"../data_files/processed/fimo/ananya_aug4_2025_pair_list.tsv"
|
| 104 |
-
)
|
| 105 |
-
|
| 106 |
-
if len(labels) == 0:
|
| 107 |
-
print("[ERROR] No valid pairs parsed. Exiting.", file=sys.stderr)
|
| 108 |
-
sys.exit(1)
|
| 109 |
-
|
| 110 |
-
label_counts = Counter(labels)
|
| 111 |
-
print(
|
| 112 |
-
f"[i] Total examples parsed: {len(labels)}. Label distribution: {label_counts}",
|
| 113 |
-
flush=True,
|
| 114 |
-
)
|
| 115 |
-
|
| 116 |
-
# build compressed TF cache (reduces to 256 if needed)
|
| 117 |
-
# tf_compressed_cache = build_tf_compressed_cache(binder_paths, target_dim=256)
|
| 118 |
-
|
| 119 |
-
# Combine all data into one structure for easy splitting
|
| 120 |
-
data = list(zip(binder_paths, glm_paths, labels))
|
| 121 |
-
|
| 122 |
-
# First split: train vs temp (val+test)
|
| 123 |
-
train_data, temp_data = train_test_split(data, test_size=0.2, random_state=42)
|
| 124 |
-
|
| 125 |
-
# Second split: val vs test (50% of 20% → 10% each)
|
| 126 |
-
val_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=42)
|
| 127 |
-
|
| 128 |
-
# Unpack for dataset construction
|
| 129 |
-
def unpack(data):
|
| 130 |
-
binders, glms, labels = zip(*data)
|
| 131 |
-
return list(binders), list(glms), list(labels)
|
| 132 |
-
|
| 133 |
-
def save_split(binder_paths, glm_paths, labels, out_path):
|
| 134 |
-
df = pd.DataFrame(
|
| 135 |
-
{
|
| 136 |
-
"binder_path": binder_paths,
|
| 137 |
-
"glm_path": glm_paths,
|
| 138 |
-
"label": labels,
|
| 139 |
-
}
|
| 140 |
-
)
|
| 141 |
-
df.to_csv(out_path, index=False)
|
| 142 |
-
|
| 143 |
-
# Unpack data for saving
|
| 144 |
-
train_binders, train_glms, train_labels = unpack(train_data)
|
| 145 |
-
val_binders, val_glms, val_labels = unpack(val_data)
|
| 146 |
-
test_binders, test_glms, test_labels = unpack(test_data)
|
| 147 |
-
|
| 148 |
-
# Save each split
|
| 149 |
-
save_split(
|
| 150 |
-
train_binders, train_glms, train_labels, "../data_files/splits/train.csv"
|
| 151 |
-
)
|
| 152 |
-
save_split(val_binders, val_glms, val_labels, "../data_files/splits/val.csv")
|
| 153 |
-
save_split(test_binders, test_glms, test_labels, "../data_files/splits/test.csv")
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
if __name__ == "__main__":
|
| 157 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dpacman/classifier/model_tmp/train.py
DELETED
|
@@ -1,217 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
import argparse
|
| 3 |
-
import numpy as np
|
| 4 |
-
import torch
|
| 5 |
-
from torch import nn
|
| 6 |
-
from model import BindPredictor
|
| 7 |
-
from pathlib import Path
|
| 8 |
-
from collections import Counter
|
| 9 |
-
from sklearn.metrics import roc_auc_score, average_precision_score
|
| 10 |
-
from sklearn.decomposition import TruncatedSVD
|
| 11 |
-
import sys
|
| 12 |
-
|
| 13 |
-
from dpacman.utils.models import set_seed
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def build_tf_compressed_cache(binder_paths, target_dim=256):
|
| 17 |
-
"""
|
| 18 |
-
Load all unique TF (binder) embeddings, fit reduction if needed, and return dict mapping path->(L, target_dim) array.
|
| 19 |
-
"""
|
| 20 |
-
unique_paths = sorted(set(binder_paths))
|
| 21 |
-
print(
|
| 22 |
-
f"[i] Found {len(unique_paths)} unique TF embedding files to compress.",
|
| 23 |
-
flush=True,
|
| 24 |
-
)
|
| 25 |
-
# Load all embeddings to determine dimensionality
|
| 26 |
-
samples = []
|
| 27 |
-
for p in unique_paths:
|
| 28 |
-
arr = np.load(p)
|
| 29 |
-
samples.append(arr)
|
| 30 |
-
# Determine if reduction needed: assume all have same embedding width
|
| 31 |
-
first = samples[0]
|
| 32 |
-
orig_dim = first.shape[1] if first.ndim == 2 else 1
|
| 33 |
-
reduction_needed = orig_dim != target_dim
|
| 34 |
-
tf_cache = {}
|
| 35 |
-
|
| 36 |
-
if reduction_needed:
|
| 37 |
-
# Build matrix to fit SVD: we need a 2D matrix per embedding; if lengths vary we can't directly stack.
|
| 38 |
-
# We'll do reduction per sequence individually using TruncatedSVD on concatenated flattened features:
|
| 39 |
-
# Simplest: for variable lengths, reduce each embedding separately with a learned linear projection.
|
| 40 |
-
# Here we fit a single TruncatedSVD on the concatenation of all sequence tokens (flattened) by padding/truncating to a fixed length.
|
| 41 |
-
# To avoid complexity, use PCA-like linear projection learned via SVD on mean-pooled vectors:
|
| 42 |
-
pooled = []
|
| 43 |
-
for arr in samples:
|
| 44 |
-
if arr.ndim == 2:
|
| 45 |
-
pooled.append(arr.mean(axis=0)) # (orig_dim,)
|
| 46 |
-
else:
|
| 47 |
-
pooled.append(arr) # degenerate
|
| 48 |
-
pooled_mat = np.stack(pooled, axis=0) # (N, orig_dim)
|
| 49 |
-
print(
|
| 50 |
-
f"[i] Fitting TruncatedSVD on TF pooled embeddings: {pooled_mat.shape} -> {target_dim}",
|
| 51 |
-
flush=True,
|
| 52 |
-
)
|
| 53 |
-
svd = TruncatedSVD(n_components=target_dim, random_state=42)
|
| 54 |
-
reduced_pooled = svd.fit_transform(pooled_mat) # (N, target_dim)
|
| 55 |
-
|
| 56 |
-
# For each original embedding, project token-level vectors by multiplying token vector with svd.components_.T
|
| 57 |
-
# svd.components_: (target_dim, orig_dim) so projection matrix is (orig_dim, target_dim)
|
| 58 |
-
proj_mat = svd.components_.T # (orig_dim, target_dim)
|
| 59 |
-
for i, p in enumerate(unique_paths):
|
| 60 |
-
arr = samples[i] # shape (L, orig_dim)
|
| 61 |
-
if arr.ndim == 1:
|
| 62 |
-
arr2 = arr @ proj_mat # (target_dim,)
|
| 63 |
-
else:
|
| 64 |
-
# project each token: (L, orig_dim) @ (orig_dim, target_dim) -> (L, target_dim)
|
| 65 |
-
arr2 = arr @ proj_mat
|
| 66 |
-
tf_cache[p] = arr2 # reduced per-token representation
|
| 67 |
-
print("[i] Completed compression of TF embeddings.", flush=True)
|
| 68 |
-
else:
|
| 69 |
-
# already correct dim: just cache originals
|
| 70 |
-
print(
|
| 71 |
-
f"[i] TF embeddings already {target_dim}-dimensional; skipping reduction.",
|
| 72 |
-
flush=True,
|
| 73 |
-
)
|
| 74 |
-
for i, p in enumerate(unique_paths):
|
| 75 |
-
arr = samples[i]
|
| 76 |
-
tf_cache[p] = arr
|
| 77 |
-
return tf_cache
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
def evaluate(model, dl, device):
|
| 81 |
-
model.eval()
|
| 82 |
-
all_labels = []
|
| 83 |
-
all_preds = []
|
| 84 |
-
with torch.no_grad():
|
| 85 |
-
for b, g, y in dl:
|
| 86 |
-
b = b.to(device)
|
| 87 |
-
g = g.to(device)
|
| 88 |
-
y = y.to(device)
|
| 89 |
-
pred = model(b, g)
|
| 90 |
-
all_labels.append(y.cpu())
|
| 91 |
-
all_preds.append(pred.cpu())
|
| 92 |
-
if not all_labels:
|
| 93 |
-
return 0.0, 0.0
|
| 94 |
-
y_true = torch.cat(all_labels).numpy()
|
| 95 |
-
y_score = torch.cat(all_preds).numpy()
|
| 96 |
-
try:
|
| 97 |
-
auc = roc_auc_score(y_true, y_score)
|
| 98 |
-
except Exception:
|
| 99 |
-
auc = 0.0
|
| 100 |
-
try:
|
| 101 |
-
ap = average_precision_score(y_true, y_score)
|
| 102 |
-
except Exception:
|
| 103 |
-
ap = 0.0
|
| 104 |
-
return auc, ap
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
def unpack(data):
|
| 108 |
-
binders, glms, labels = zip(*data)
|
| 109 |
-
return list(binders), list(glms), list(labels)
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
# ---- main ------------------------------------------------------------
|
| 113 |
-
def main(cfg):
|
| 114 |
-
# Set seed for reproducibility
|
| 115 |
-
set_seed(cfg.seed)
|
| 116 |
-
|
| 117 |
-
parser.add_argument("--out_dir", type=str, required=True)
|
| 118 |
-
parser.add_argument("--epochs", type=int, default=10)
|
| 119 |
-
parser.add_argument("--batch_size", type=int, default=32)
|
| 120 |
-
parser.add_argument("--lr", type=float, default=1e-4)
|
| 121 |
-
parser.add_argument("--device", type=str, default="cuda")
|
| 122 |
-
parser.add_argument("--seed", type=int, default=42)
|
| 123 |
-
args = parser.parse_args()
|
| 124 |
-
|
| 125 |
-
#
|
| 126 |
-
print("DEBUG: starting training script with in-line TF compression", flush=True)
|
| 127 |
-
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
|
| 128 |
-
binder_paths, glm_paths, labels = parse_pair_list(cfg.pair_list)
|
| 129 |
-
|
| 130 |
-
if len(labels) == 0:
|
| 131 |
-
print("[ERROR] No valid pairs parsed. Exiting.", file=sys.stderr)
|
| 132 |
-
sys.exit(1)
|
| 133 |
-
|
| 134 |
-
label_counts = Counter(labels)
|
| 135 |
-
print(
|
| 136 |
-
f"[i] Total examples parsed: {len(labels)}. Label distribution: {label_counts}",
|
| 137 |
-
flush=True,
|
| 138 |
-
)
|
| 139 |
-
|
| 140 |
-
# build compressed TF cache (reduces to 256 if needed)
|
| 141 |
-
tf_compressed_cache = build_tf_compressed_cache(binder_paths, target_dim=256)
|
| 142 |
-
|
| 143 |
-
# load training data aloiaushasfoiuhasfoiuafasdfoihuaaasdfoiuhasfaaoiufhasfoasasfoiuh
|
| 144 |
-
|
| 145 |
-
train_ds = PairDataset(None, tf_compressed_cache=tf_compressed_cache)
|
| 146 |
-
val_ds = PairDataset(*subset(val_i), tf_compressed_cache=tf_compressed_cache)
|
| 147 |
-
test_ds = PairDataset(*subset(test_i), tf_compressed_cache=tf_compressed_cache)
|
| 148 |
-
|
| 149 |
-
print(
|
| 150 |
-
f"[i] Train/Val/Test sizes: {len(train_ds)}/{len(val_ds)}/{len(test_ds)}",
|
| 151 |
-
flush=True,
|
| 152 |
-
)
|
| 153 |
-
if len(train_ds) == 0 or len(val_ds) == 0:
|
| 154 |
-
print(
|
| 155 |
-
"[ERROR] Train or validation split is empty; cannot proceed.",
|
| 156 |
-
file=sys.stderr,
|
| 157 |
-
)
|
| 158 |
-
sys.exit(1)
|
| 159 |
-
|
| 160 |
-
train_dl = DataLoader(
|
| 161 |
-
train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn
|
| 162 |
-
)
|
| 163 |
-
val_dl = DataLoader(
|
| 164 |
-
val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn
|
| 165 |
-
)
|
| 166 |
-
test_dl = DataLoader(
|
| 167 |
-
test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn
|
| 168 |
-
)
|
| 169 |
-
|
| 170 |
-
model = BindPredictor(
|
| 171 |
-
input_dim=256, hidden_dim=256, heads=8, num_layers=3, use_local_cnn_on_glm=True
|
| 172 |
-
)
|
| 173 |
-
model = model.to(device)
|
| 174 |
-
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-3)
|
| 175 |
-
loss_fn = nn.BCELoss()
|
| 176 |
-
|
| 177 |
-
best_val = -float("inf")
|
| 178 |
-
os_out = Path(args.out_dir)
|
| 179 |
-
os_out.mkdir(exist_ok=True, parents=True)
|
| 180 |
-
|
| 181 |
-
for epoch in range(1, args.epochs + 1):
|
| 182 |
-
print(f"[Epoch {epoch}] starting...", flush=True)
|
| 183 |
-
model.train()
|
| 184 |
-
running_loss = 0.0
|
| 185 |
-
for b, g, y in train_dl:
|
| 186 |
-
b = b.to(device)
|
| 187 |
-
g = g.to(device)
|
| 188 |
-
y = y.to(device)
|
| 189 |
-
pred = model(b, g)
|
| 190 |
-
loss = loss_fn(pred, y)
|
| 191 |
-
optimizer.zero_grad()
|
| 192 |
-
loss.backward()
|
| 193 |
-
optimizer.step()
|
| 194 |
-
running_loss += loss.item() * b.size(0)
|
| 195 |
-
train_loss = running_loss / len(train_ds)
|
| 196 |
-
val_auc, val_ap = evaluate(model, val_dl, device)
|
| 197 |
-
print(
|
| 198 |
-
f"[Epoch {epoch}] train_loss={train_loss:.4f} val_auc={val_auc:.4f} val_ap={val_ap:.4f}",
|
| 199 |
-
flush=True,
|
| 200 |
-
)
|
| 201 |
-
|
| 202 |
-
if val_auc > best_val:
|
| 203 |
-
best_val = val_auc
|
| 204 |
-
torch.save(model.state_dict(), os_out / "best_model.pt")
|
| 205 |
-
print(
|
| 206 |
-
f"[Epoch {epoch}] Saved new best model with val_auc={val_auc:.4f}",
|
| 207 |
-
flush=True,
|
| 208 |
-
)
|
| 209 |
-
|
| 210 |
-
torch.save(model.state_dict(), os_out / "last_model.pt")
|
| 211 |
-
test_auc, test_ap = evaluate(model, test_dl, device)
|
| 212 |
-
print(f"FINAL TEST: AUC={test_auc:.4f} AP={test_ap:.4f}", flush=True)
|
| 213 |
-
print(f"[i] Models written to {os_out}/best_model.pt and last_model.pt", flush=True)
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
if __name__ == "__main__":
|
| 217 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dpacman/classifier/old_train.py
DELETED
|
@@ -1,486 +0,0 @@
|
|
| 1 |
-
import argparse, random, sys
|
| 2 |
-
from pathlib import Path
|
| 3 |
-
|
| 4 |
-
import numpy as np
|
| 5 |
-
import pandas as pd
|
| 6 |
-
import torch
|
| 7 |
-
from torch import nn
|
| 8 |
-
from torch.utils.data import Dataset, DataLoader, Sampler
|
| 9 |
-
|
| 10 |
-
# from sklearn.random_projection import GaussianRandomProjection # OLD (kept): projection was removed earlier
|
| 11 |
-
import matplotlib.pyplot as plt
|
| 12 |
-
|
| 13 |
-
import torch.amp as amp
|
| 14 |
-
from torch.nn import functional as F
|
| 15 |
-
from model import BindPredictor
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
# ─────────────── utilities ────────────────────────────────────────────────
|
| 19 |
-
def parse_pair_list(path):
|
| 20 |
-
binders, glms = [], []
|
| 21 |
-
with open(path) as f:
|
| 22 |
-
for ln, line in enumerate(f, 1):
|
| 23 |
-
parts = line.strip().split()
|
| 24 |
-
if len(parts) < 2:
|
| 25 |
-
continue
|
| 26 |
-
b, g = parts[0], parts[1]
|
| 27 |
-
binders.append(b)
|
| 28 |
-
glms.append(g)
|
| 29 |
-
return binders, glms
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
class ListBatchSampler(Sampler):
|
| 33 |
-
def __init__(self, batches):
|
| 34 |
-
self.batches = batches
|
| 35 |
-
|
| 36 |
-
def __iter__(self):
|
| 37 |
-
return iter(self.batches)
|
| 38 |
-
|
| 39 |
-
def __len__(self):
|
| 40 |
-
return len(self.batches)
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
def make_buckets(idxs, glm_paths, batch_size, n_buckets=10, seed=42):
|
| 44 |
-
rng = random.Random(seed)
|
| 45 |
-
lengths = [(i, np.load(glm_paths[i]).shape[0]) for i in idxs]
|
| 46 |
-
lengths.sort(key=lambda x: x[1])
|
| 47 |
-
size = max(1, int(np.ceil(len(lengths) / n_buckets)))
|
| 48 |
-
buckets = [lengths[i : i + size] for i in range(0, len(lengths), size)]
|
| 49 |
-
batches = []
|
| 50 |
-
for bucket in buckets:
|
| 51 |
-
ids = [i for i, _ in bucket]
|
| 52 |
-
rng.shuffle(ids)
|
| 53 |
-
for i in range(0, len(ids), batch_size):
|
| 54 |
-
batches.append(ids[i : i + batch_size])
|
| 55 |
-
rng.shuffle(batches)
|
| 56 |
-
return batches
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
def dna_key_from_path(path: str) -> str:
|
| 60 |
-
""".../dna_peak42.npy -> 'peak42'"""
|
| 61 |
-
stem = Path(path).stem
|
| 62 |
-
if "_" in stem:
|
| 63 |
-
_, rest = stem.split("_", 1)
|
| 64 |
-
else:
|
| 65 |
-
rest = stem
|
| 66 |
-
return rest
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
def build_tf_cache(tf_paths, target_dim=256):
|
| 70 |
-
"""
|
| 71 |
-
Load raw TF embeddings without projecting; compression is learnable in the model.
|
| 72 |
-
"""
|
| 73 |
-
unique = sorted(set(tf_paths))
|
| 74 |
-
print(
|
| 75 |
-
f"[i] (Learnable) Preparing {len(unique)} TF files; target {target_dim}d inside the model",
|
| 76 |
-
flush=True,
|
| 77 |
-
)
|
| 78 |
-
|
| 79 |
-
pools, raw = [], []
|
| 80 |
-
for p in unique:
|
| 81 |
-
arr = np.load(p) # (L, D) or (D,)
|
| 82 |
-
raw.append(arr)
|
| 83 |
-
pools.append(arr.mean(axis=0) if arr.ndim == 2 else arr)
|
| 84 |
-
M = np.stack(pools, 0)
|
| 85 |
-
orig_dim = M.shape[1]
|
| 86 |
-
print(f"[i] Pooled shape → {M.shape} (orig_dim={orig_dim})", flush=True)
|
| 87 |
-
|
| 88 |
-
cache = {}
|
| 89 |
-
for i, p in enumerate(unique):
|
| 90 |
-
arr = raw[i]
|
| 91 |
-
# OLD: projection here (removed)
|
| 92 |
-
cache[p] = arr
|
| 93 |
-
print("[i] TF cache ready (raw); compression will be learned.", flush=True)
|
| 94 |
-
return cache
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
# ─────────────── Dataset & Collation ─────────────────────────────────────
|
| 98 |
-
class PairDataset(Dataset):
|
| 99 |
-
def __init__(self, tf_paths, dna_paths, final_df, tf_cache):
|
| 100 |
-
self.tf_paths, self.dna_paths = tf_paths, dna_paths
|
| 101 |
-
self.tf_cache = tf_cache
|
| 102 |
-
self.targets = {}
|
| 103 |
-
for _, row in final_df.iterrows():
|
| 104 |
-
dna_id = row["dna_id"]
|
| 105 |
-
vec = np.array(
|
| 106 |
-
list(map(float, row["score_sig_r2"].split(","))), dtype=np.float32
|
| 107 |
-
)
|
| 108 |
-
self.targets[dna_id] = vec
|
| 109 |
-
|
| 110 |
-
def __len__(self):
|
| 111 |
-
return len(self.tf_paths)
|
| 112 |
-
|
| 113 |
-
def __getitem__(self, i):
|
| 114 |
-
b = self.tf_cache[self.tf_paths[i]] # (L_b, D_b) or (D_b,)
|
| 115 |
-
if b.ndim == 1:
|
| 116 |
-
b = b[None, :]
|
| 117 |
-
g = np.load(self.dna_paths[i]) # (L_g, 256) or (256,)
|
| 118 |
-
if g.ndim == 1:
|
| 119 |
-
g = g[None, :]
|
| 120 |
-
|
| 121 |
-
stem = Path(self.dna_paths[i]).stem
|
| 122 |
-
dna_id = stem.replace("dna_", "")
|
| 123 |
-
t = self.targets.get(dna_id, np.zeros(g.shape[0], dtype=np.float32))
|
| 124 |
-
|
| 125 |
-
return (
|
| 126 |
-
torch.from_numpy(b).float(),
|
| 127 |
-
torch.from_numpy(g).float(),
|
| 128 |
-
torch.from_numpy(t).float(),
|
| 129 |
-
)
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
def collate_fn(batch):
|
| 133 |
-
Bs = [b.shape[0] for b, _, _ in batch]
|
| 134 |
-
Gs = [g.shape[0] for _, g, _ in batch]
|
| 135 |
-
maxB, maxG = max(Bs), max(Gs)
|
| 136 |
-
|
| 137 |
-
def pad_seq(x, L):
|
| 138 |
-
if x.shape[0] < L:
|
| 139 |
-
pad = torch.zeros(
|
| 140 |
-
(L - x.shape[0], x.shape[1]), dtype=x.dtype, device=x.device
|
| 141 |
-
)
|
| 142 |
-
return torch.cat([x, pad], dim=0)
|
| 143 |
-
return x
|
| 144 |
-
|
| 145 |
-
def pad_t(y, L):
|
| 146 |
-
if y.shape[0] < L:
|
| 147 |
-
pad = torch.zeros((L - y.shape[0],), dtype=y.dtype, device=y.device)
|
| 148 |
-
return torch.cat([y, pad], dim=0)
|
| 149 |
-
return y
|
| 150 |
-
|
| 151 |
-
b_stack = torch.stack([pad_seq(b, maxB) for b, _, _ in batch])
|
| 152 |
-
g_stack = torch.stack([pad_seq(g, maxG) for _, g, _ in batch])
|
| 153 |
-
t_stack = torch.stack([pad_t(t, maxG) for *_, t in batch])
|
| 154 |
-
return b_stack, g_stack, t_stack
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
# ──���──────────── losses, metrics ─────────────────────────────────────────
|
| 158 |
-
def combined_loss_components(logits, targets, peak_thresh=0.5, eps=1e-8):
|
| 159 |
-
probs = torch.sigmoid(logits)
|
| 160 |
-
labels = (targets >= peak_thresh).float()
|
| 161 |
-
non_peak_mask = (labels == 0).float()
|
| 162 |
-
peak_mask = (labels == 1).float()
|
| 163 |
-
|
| 164 |
-
bce_all = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
|
| 165 |
-
bce_non = bce_all * non_peak_mask
|
| 166 |
-
bce_non = bce_non.sum() / (non_peak_mask.sum() + eps)
|
| 167 |
-
|
| 168 |
-
mse_peaks = F.mse_loss(probs * peak_mask, targets * peak_mask, reduction="sum") / (
|
| 169 |
-
peak_mask.sum() + eps
|
| 170 |
-
)
|
| 171 |
-
mse_global = F.mse_loss(probs, targets, reduction="mean")
|
| 172 |
-
|
| 173 |
-
t_dist = targets + eps
|
| 174 |
-
p_dist = probs + eps
|
| 175 |
-
t_dist = t_dist / t_dist.sum(dim=1, keepdim=True)
|
| 176 |
-
p_dist = p_dist / p_dist.sum(dim=1, keepdim=True)
|
| 177 |
-
kl = (
|
| 178 |
-
(t_dist * (t_dist.clamp(min=eps).log() - p_dist.clamp(min=eps).log()))
|
| 179 |
-
.sum(dim=1)
|
| 180 |
-
.mean()
|
| 181 |
-
)
|
| 182 |
-
|
| 183 |
-
return bce_non, kl, mse_global, probs
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
def accuracy_percentage(logits, targets, peak_thresh=0.5):
|
| 187 |
-
probs = torch.sigmoid(logits)
|
| 188 |
-
preds_bin = (probs >= 0.5).float()
|
| 189 |
-
labels = (targets >= peak_thresh).float()
|
| 190 |
-
correct = (preds_bin == labels).float().sum()
|
| 191 |
-
total = torch.numel(labels)
|
| 192 |
-
return (correct / max(1, total)).item() * 100.0
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
def evaluate(model, dl, device, alpha, beta, gamma, peak_thresh, eps=1e-8):
|
| 196 |
-
model.eval()
|
| 197 |
-
tot_loss, tot_acc = 0.0, 0.0
|
| 198 |
-
n_batches = 0
|
| 199 |
-
with torch.no_grad():
|
| 200 |
-
for b, g, t in dl:
|
| 201 |
-
b, g, t = b.to(device), g.to(device), t.to(device)
|
| 202 |
-
logits = model(b, g)
|
| 203 |
-
bce_non, kl, mse_global, _ = combined_loss_components(
|
| 204 |
-
logits, t, peak_thresh=peak_thresh, eps=eps
|
| 205 |
-
)
|
| 206 |
-
loss = alpha * bce_non + beta * kl + gamma * mse_global
|
| 207 |
-
acc = accuracy_percentage(logits, t, peak_thresh=peak_thresh)
|
| 208 |
-
tot_loss += loss.item()
|
| 209 |
-
tot_acc += acc
|
| 210 |
-
n_batches += 1
|
| 211 |
-
if n_batches == 0:
|
| 212 |
-
return float("nan"), float("nan")
|
| 213 |
-
return tot_loss / n_batches, tot_acc / n_batches
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
# ─────────────── cluster-aware splitting ──────────────────────────────────
|
| 217 |
-
def assign_clusters_to_splits(
|
| 218 |
-
cluster_to_indices, val_frac=0.10, test_frac=0.10, seed=42
|
| 219 |
-
):
|
| 220 |
-
"""
|
| 221 |
-
cluster_to_indices: dict[cluster_id] -> list of example indices (from pair_list) in that cluster
|
| 222 |
-
We greedily pack whole clusters into val/test until hitting targets (#examples), rest to train.
|
| 223 |
-
"""
|
| 224 |
-
rng = random.Random(seed)
|
| 225 |
-
clusters = list(cluster_to_indices.items())
|
| 226 |
-
rng.shuffle(clusters)
|
| 227 |
-
|
| 228 |
-
total = sum(len(ixs) for _, ixs in clusters)
|
| 229 |
-
target_val = int(round(total * val_frac))
|
| 230 |
-
target_test = int(round(total * test_frac))
|
| 231 |
-
cur_val = cur_test = 0
|
| 232 |
-
|
| 233 |
-
tr_ix, va_ix, te_ix = [], [], []
|
| 234 |
-
for cid, ixs in clusters:
|
| 235 |
-
c = len(ixs)
|
| 236 |
-
if cur_val + c <= target_val:
|
| 237 |
-
va_ix.extend(ixs)
|
| 238 |
-
cur_val += c
|
| 239 |
-
elif cur_test + c <= target_test:
|
| 240 |
-
te_ix.extend(ixs)
|
| 241 |
-
cur_test += c
|
| 242 |
-
else:
|
| 243 |
-
tr_ix.extend(ixs)
|
| 244 |
-
return tr_ix, va_ix, te_ix
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
# ─────────────── train & main ────────────────────────────────────────────
|
| 248 |
-
def main():
|
| 249 |
-
p = argparse.ArgumentParser()
|
| 250 |
-
p.add_argument("--pair_list", required=True)
|
| 251 |
-
p.add_argument("--final_csv", required=True)
|
| 252 |
-
p.add_argument("--out_dir", required=True)
|
| 253 |
-
p.add_argument("--epochs", type=int, default=10)
|
| 254 |
-
p.add_argument("--batch_size", type=int, default=16)
|
| 255 |
-
p.add_argument("--accum_steps", type=int, default=4)
|
| 256 |
-
p.add_argument("--lr", type=float, default=1e-4)
|
| 257 |
-
p.add_argument("--device", default="cuda")
|
| 258 |
-
p.add_argument("--seed", type=int, default=42)
|
| 259 |
-
p.add_argument("--alpha", type=float, default=1)
|
| 260 |
-
p.add_argument("--beta", type=float, default=0)
|
| 261 |
-
p.add_argument("--gamma", type=float, default=1)
|
| 262 |
-
p.add_argument("--peak_thresh", type=float, default=0.5)
|
| 263 |
-
# NEW: fractions for cluster-aware split (used only if cluster_id present)
|
| 264 |
-
p.add_argument("--val_frac", type=float, default=0.10)
|
| 265 |
-
p.add_argument("--test_frac", type=float, default=0.10)
|
| 266 |
-
args = p.parse_args()
|
| 267 |
-
|
| 268 |
-
random.seed(args.seed)
|
| 269 |
-
np.random.seed(args.seed)
|
| 270 |
-
torch.manual_seed(args.seed)
|
| 271 |
-
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
|
| 272 |
-
|
| 273 |
-
# 1) load pair list & final.csv (now may include cluster_id)
|
| 274 |
-
tf_paths, dna_paths = parse_pair_list(args.pair_list)
|
| 275 |
-
final_df = pd.read_csv(args.final_csv, dtype=str)
|
| 276 |
-
print(f"[i] Loaded {len(tf_paths)} pairs", flush=True)
|
| 277 |
-
|
| 278 |
-
tf_cache = build_tf_cache(tf_paths, target_dim=256)
|
| 279 |
-
|
| 280 |
-
# detect binder/DNA dims
|
| 281 |
-
sample_tf = tf_cache[tf_paths[0]]
|
| 282 |
-
binder_input_dim = sample_tf.shape[1] if sample_tf.ndim == 2 else sample_tf.shape[0]
|
| 283 |
-
glm_input_dim = 256
|
| 284 |
-
|
| 285 |
-
# 2) cluster-aware split if possible
|
| 286 |
-
use_cluster_split = "cluster_id" in final_df.columns
|
| 287 |
-
if use_cluster_split:
|
| 288 |
-
print(
|
| 289 |
-
"[i] Cluster column detected in final_csv; performing cluster-aware split.",
|
| 290 |
-
flush=True,
|
| 291 |
-
)
|
| 292 |
-
# build dna_id -> cluster_id map
|
| 293 |
-
cid_map = (
|
| 294 |
-
final_df[["dna_id", "cluster_id"]]
|
| 295 |
-
.dropna()
|
| 296 |
-
.drop_duplicates()
|
| 297 |
-
.set_index("dna_id")["cluster_id"]
|
| 298 |
-
.to_dict()
|
| 299 |
-
)
|
| 300 |
-
|
| 301 |
-
# map each example (by index) to its dna_id and cluster
|
| 302 |
-
example_dna_ids = [dna_key_from_path(p) for p in dna_paths]
|
| 303 |
-
example_clusters = []
|
| 304 |
-
missing = 0
|
| 305 |
-
for did in example_dna_ids:
|
| 306 |
-
if did in cid_map:
|
| 307 |
-
example_clusters.append(cid_map[did])
|
| 308 |
-
else:
|
| 309 |
-
# fallback: treat singleton cluster
|
| 310 |
-
example_clusters.append(f"singleton::{did}")
|
| 311 |
-
missing += 1
|
| 312 |
-
if missing:
|
| 313 |
-
print(
|
| 314 |
-
f"[WARN] {missing} dna_ids from pair_list not found in cluster map; treating as singleton clusters.",
|
| 315 |
-
flush=True,
|
| 316 |
-
)
|
| 317 |
-
|
| 318 |
-
# build cluster -> indices
|
| 319 |
-
cluster_to_indices = {}
|
| 320 |
-
for i, cid in enumerate(example_clusters):
|
| 321 |
-
cluster_to_indices.setdefault(cid, []).append(i)
|
| 322 |
-
|
| 323 |
-
tr_idx, va_idx, te_idx = assign_clusters_to_splits(
|
| 324 |
-
cluster_to_indices,
|
| 325 |
-
val_frac=args.val_frac,
|
| 326 |
-
test_frac=args.test_frac,
|
| 327 |
-
seed=args.seed,
|
| 328 |
-
)
|
| 329 |
-
print(
|
| 330 |
-
f"[i] Cluster split sizes (examples): train={len(tr_idx)} val={len(va_idx)} test={len(te_idx)}",
|
| 331 |
-
flush=True,
|
| 332 |
-
)
|
| 333 |
-
|
| 334 |
-
# helper to subset paths
|
| 335 |
-
def subset_by_indices(ixs):
|
| 336 |
-
return [tf_paths[i] for i in ixs], [dna_paths[i] for i in ixs]
|
| 337 |
-
|
| 338 |
-
tr_t, tr_d = subset_by_indices(tr_idx)
|
| 339 |
-
va_t, va_d = subset_by_indices(va_idx)
|
| 340 |
-
te_t, te_d = subset_by_indices(te_idx)
|
| 341 |
-
|
| 342 |
-
else:
|
| 343 |
-
print(
|
| 344 |
-
"[i] No cluster_id in final_csv; using random 80/10/10 split (OLD behavior).",
|
| 345 |
-
flush=True,
|
| 346 |
-
)
|
| 347 |
-
# OLD random split (kept, now under else)
|
| 348 |
-
N = len(tf_paths)
|
| 349 |
-
idxs = list(range(N))
|
| 350 |
-
random.shuffle(idxs)
|
| 351 |
-
n_tr = int(0.8 * N)
|
| 352 |
-
n_va = int(0.1 * N)
|
| 353 |
-
tr, va, te = idxs[:n_tr], idxs[n_tr : n_tr + n_va], idxs[n_tr + n_va :]
|
| 354 |
-
|
| 355 |
-
def subset(idxs_):
|
| 356 |
-
return [tf_paths[i] for i in idxs_], [dna_paths[i] for i in idxs_]
|
| 357 |
-
|
| 358 |
-
tr_t, tr_d = subset(tr)
|
| 359 |
-
va_t, va_d = subset(va)
|
| 360 |
-
te_t, te_d = subset(te)
|
| 361 |
-
|
| 362 |
-
# 3) bucketed samplers (unchanged, but now use the cluster-aware subsets when available)
|
| 363 |
-
tr_bs = make_buckets(
|
| 364 |
-
list(range(len(tr_t))), tr_d, args.batch_size, n_buckets=10, seed=args.seed
|
| 365 |
-
)
|
| 366 |
-
va_bs = make_buckets(
|
| 367 |
-
list(range(len(va_t))), va_d, args.batch_size, n_buckets=5, seed=args.seed + 1
|
| 368 |
-
)
|
| 369 |
-
te_bs = make_buckets(
|
| 370 |
-
list(range(len(te_t))), te_d, args.batch_size, n_buckets=5, seed=args.seed + 2
|
| 371 |
-
)
|
| 372 |
-
|
| 373 |
-
tr_dl = DataLoader(
|
| 374 |
-
PairDataset(tr_t, tr_d, final_df, tf_cache),
|
| 375 |
-
batch_sampler=ListBatchSampler(tr_bs),
|
| 376 |
-
collate_fn=collate_fn,
|
| 377 |
-
)
|
| 378 |
-
va_dl = DataLoader(
|
| 379 |
-
PairDataset(va_t, va_d, final_df, tf_cache),
|
| 380 |
-
batch_sampler=ListBatchSampler(va_bs),
|
| 381 |
-
collate_fn=collate_fn,
|
| 382 |
-
)
|
| 383 |
-
te_dl = DataLoader(
|
| 384 |
-
PairDataset(te_t, te_d, final_df, tf_cache),
|
| 385 |
-
batch_sampler=ListBatchSampler(te_bs),
|
| 386 |
-
collate_fn=collate_fn,
|
| 387 |
-
)
|
| 388 |
-
|
| 389 |
-
# 4) model, optimizer, scaler
|
| 390 |
-
model = BindPredictor(
|
| 391 |
-
binder_input_dim=binder_input_dim,
|
| 392 |
-
glm_input_dim=glm_input_dim,
|
| 393 |
-
compressed_dim=256,
|
| 394 |
-
hidden_dim=256,
|
| 395 |
-
heads=8,
|
| 396 |
-
num_layers=4,
|
| 397 |
-
use_local_cnn_on_glm=True,
|
| 398 |
-
).to(device)
|
| 399 |
-
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
|
| 400 |
-
scaler = amp.GradScaler("cuda")
|
| 401 |
-
|
| 402 |
-
history, best_val = {"train": [], "val": []}, float("inf")
|
| 403 |
-
od = Path(args.out_dir)
|
| 404 |
-
od.mkdir(exist_ok=True, parents=True)
|
| 405 |
-
|
| 406 |
-
for ep in range(1, args.epochs + 1):
|
| 407 |
-
print(f"┌─[Epoch {ep}]────────────────────────", flush=True)
|
| 408 |
-
model.train()
|
| 409 |
-
optimizer.zero_grad()
|
| 410 |
-
acc_loss_sum, acc_acc_sum, n_train_batches = 0.0, 0.0, 0
|
| 411 |
-
|
| 412 |
-
for i, (b, g, t) in enumerate(tr_dl):
|
| 413 |
-
b, g, t = b.to(device), g.to(device), t.to(device)
|
| 414 |
-
with amp.autocast("cuda"):
|
| 415 |
-
logits = model(b, g)
|
| 416 |
-
bce_non, kl, mse_global, probs = combined_loss_components(
|
| 417 |
-
logits, t, peak_thresh=args.peak_thresh
|
| 418 |
-
)
|
| 419 |
-
loss = args.alpha * bce_non + args.beta * kl + args.gamma * mse_global
|
| 420 |
-
loss = loss / args.accum_steps
|
| 421 |
-
|
| 422 |
-
scaler.scale(loss).backward()
|
| 423 |
-
|
| 424 |
-
if (i + 1) % args.accum_steps == 0:
|
| 425 |
-
scaler.step(optimizer)
|
| 426 |
-
scaler.update()
|
| 427 |
-
optimizer.zero_grad()
|
| 428 |
-
|
| 429 |
-
with torch.no_grad():
|
| 430 |
-
acc_loss_sum += loss.item() * args.accum_steps
|
| 431 |
-
acc_acc_sum += accuracy_percentage(
|
| 432 |
-
logits, t, peak_thresh=args.peak_thresh
|
| 433 |
-
)
|
| 434 |
-
n_train_batches += 1
|
| 435 |
-
|
| 436 |
-
del b, g, t, logits, probs, loss, bce_non, kl, mse_global
|
| 437 |
-
torch.cuda.empty_cache()
|
| 438 |
-
|
| 439 |
-
# finalize if leftovers
|
| 440 |
-
if n_train_batches % args.accum_steps != 0:
|
| 441 |
-
scaler.step(optimizer)
|
| 442 |
-
scaler.update()
|
| 443 |
-
optimizer.zero_grad()
|
| 444 |
-
|
| 445 |
-
train_loss = acc_loss_sum / max(1, n_train_batches)
|
| 446 |
-
train_acc = acc_acc_sum / max(1, n_train_batches)
|
| 447 |
-
|
| 448 |
-
val_loss, val_acc = evaluate(
|
| 449 |
-
model,
|
| 450 |
-
va_dl,
|
| 451 |
-
device,
|
| 452 |
-
alpha=args.alpha,
|
| 453 |
-
beta=args.beta,
|
| 454 |
-
gamma=args.gamma,
|
| 455 |
-
peak_thresh=args.peak_thresh,
|
| 456 |
-
)
|
| 457 |
-
print(
|
| 458 |
-
f"[Epoch {ep}] train_loss={train_loss:.4f} train_acc={train_acc:.2f}% "
|
| 459 |
-
f"val_loss={val_loss:.4f} val_acc={val_acc:.2f}%",
|
| 460 |
-
flush=True,
|
| 461 |
-
)
|
| 462 |
-
|
| 463 |
-
history["train"].append(train_loss)
|
| 464 |
-
history["val"].append(val_loss)
|
| 465 |
-
if val_loss < best_val:
|
| 466 |
-
best_val = val_loss
|
| 467 |
-
torch.save(model.state_dict(), od / "best_model.pt")
|
| 468 |
-
print(
|
| 469 |
-
f" Saved new best_model.pt (val_loss={val_loss:.4f}, val_acc={val_acc:.2f}%)",
|
| 470 |
-
flush=True,
|
| 471 |
-
)
|
| 472 |
-
|
| 473 |
-
torch.save(model.state_dict(), od / "last_model.pt")
|
| 474 |
-
|
| 475 |
-
fig, ax = plt.subplots()
|
| 476 |
-
ax.plot(history["train"], label="train")
|
| 477 |
-
ax.plot(history["val"], label="val")
|
| 478 |
-
ax.set_xlabel("epoch")
|
| 479 |
-
ax.set_ylabel("combined loss")
|
| 480 |
-
ax.legend()
|
| 481 |
-
fig.savefig(od / "loss_curve.png")
|
| 482 |
-
print(f"✅ Done → outputs in {od}", flush=True)
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
if __name__ == "__main__":
|
| 486 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dpacman/classifier/torch_model.py
DELETED
|
@@ -1,157 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
from torch import nn
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
class LocalCNN(nn.Module):
|
| 6 |
-
def __init__(self, dim: int = 256, kernel_size: int = 3):
|
| 7 |
-
super().__init__()
|
| 8 |
-
padding = kernel_size // 2
|
| 9 |
-
self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding)
|
| 10 |
-
self.act = nn.GELU()
|
| 11 |
-
self.ln = nn.LayerNorm(dim)
|
| 12 |
-
|
| 13 |
-
def forward(self, x: torch.Tensor):
|
| 14 |
-
# x: (batch, L, dim)
|
| 15 |
-
out = self.conv(x.transpose(1, 2)) # → (batch, dim, L)
|
| 16 |
-
out = self.act(out)
|
| 17 |
-
out = out.transpose(1, 2) # → (batch, L, dim)
|
| 18 |
-
return self.ln(out + x) # residual
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
class CrossModalBlock(nn.Module):
|
| 22 |
-
def __init__(self, dim: int = 256, heads: int = 8):
|
| 23 |
-
super().__init__()
|
| 24 |
-
# self-attention for both sides
|
| 25 |
-
self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True)
|
| 26 |
-
self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True)
|
| 27 |
-
self.ln_b1 = nn.LayerNorm(dim)
|
| 28 |
-
self.ln_g1 = nn.LayerNorm(dim)
|
| 29 |
-
|
| 30 |
-
self.ffn_b = nn.Sequential(
|
| 31 |
-
nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
|
| 32 |
-
)
|
| 33 |
-
self.ffn_g = nn.Sequential(
|
| 34 |
-
nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
|
| 35 |
-
)
|
| 36 |
-
self.ln_b2 = nn.LayerNorm(dim)
|
| 37 |
-
self.ln_g2 = nn.LayerNorm(dim)
|
| 38 |
-
|
| 39 |
-
# cross attention (binder queries, glm keys/values)
|
| 40 |
-
self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
|
| 41 |
-
self.ln_c1 = nn.LayerNorm(dim)
|
| 42 |
-
self.ffn_c = nn.Sequential(
|
| 43 |
-
nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
|
| 44 |
-
)
|
| 45 |
-
self.ln_c2 = nn.LayerNorm(dim)
|
| 46 |
-
|
| 47 |
-
def forward(self, binder: torch.Tensor, glm: torch.Tensor):
|
| 48 |
-
"""
|
| 49 |
-
binder: (batch, Lb, dim)
|
| 50 |
-
glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
|
| 51 |
-
returns: updated binder representation (batch, Lb, dim)
|
| 52 |
-
"""
|
| 53 |
-
# binder self-attn + ffn
|
| 54 |
-
b = binder
|
| 55 |
-
b_sa, _ = self.sa_binder(b, b, b)
|
| 56 |
-
b = self.ln_b1(b + b_sa)
|
| 57 |
-
b_ff = self.ffn_b(b)
|
| 58 |
-
b = self.ln_b2(b + b_ff)
|
| 59 |
-
|
| 60 |
-
# glm self-attn + ffn
|
| 61 |
-
g = glm
|
| 62 |
-
g_sa, _ = self.sa_glm(g, g, g)
|
| 63 |
-
g = self.ln_g1(g + g_sa)
|
| 64 |
-
g_ff = self.ffn_g(g)
|
| 65 |
-
g = self.ln_g2(g + g_ff)
|
| 66 |
-
|
| 67 |
-
# cross-attention: binder queries glm
|
| 68 |
-
c_sa, _ = self.cross_attn(b, g, g)
|
| 69 |
-
c = self.ln_c1(b + c_sa)
|
| 70 |
-
c_ff = self.ffn_c(c)
|
| 71 |
-
c = self.ln_c2(c + c_ff)
|
| 72 |
-
return c # (batch, Lb, dim)
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
class DimCompressor(nn.Module):
|
| 76 |
-
"""
|
| 77 |
-
Learnable per-token compressor: maps any in_dim >= out_dim to out_dim (default 256).
|
| 78 |
-
If in_dim == out_dim, behaves as identity.
|
| 79 |
-
"""
|
| 80 |
-
|
| 81 |
-
def __init__(self, in_dim: int, out_dim: int = 256):
|
| 82 |
-
super().__init__()
|
| 83 |
-
if in_dim == out_dim:
|
| 84 |
-
self.net = nn.Identity()
|
| 85 |
-
else:
|
| 86 |
-
hidden = max(out_dim * 2, (in_dim + out_dim) // 2)
|
| 87 |
-
self.net = nn.Sequential(
|
| 88 |
-
nn.LayerNorm(in_dim),
|
| 89 |
-
nn.Linear(in_dim, hidden),
|
| 90 |
-
nn.GELU(),
|
| 91 |
-
nn.Linear(hidden, out_dim),
|
| 92 |
-
)
|
| 93 |
-
|
| 94 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 95 |
-
# x: (B, L, in_dim)
|
| 96 |
-
return self.net(x)
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
class BindPredictor(nn.Module):
|
| 100 |
-
def __init__(
|
| 101 |
-
self,
|
| 102 |
-
# input_dim: int = 256, # OLD: single input dim
|
| 103 |
-
binder_input_dim: int = 1280, # NEW: TF (binder) original dim (e.g., 1280)
|
| 104 |
-
glm_input_dim: int = 256, # NEW: DNA/GLM original dim (e.g., 256)
|
| 105 |
-
compressed_dim: int = 256, # NEW: learnable compressed dim
|
| 106 |
-
hidden_dim: int = 256,
|
| 107 |
-
heads: int = 8,
|
| 108 |
-
num_layers: int = 4,
|
| 109 |
-
use_local_cnn_on_glm: bool = True,
|
| 110 |
-
):
|
| 111 |
-
super().__init__()
|
| 112 |
-
# OLD:
|
| 113 |
-
# self.proj_binder = nn.Linear(input_dim, hidden_dim)
|
| 114 |
-
# self.proj_glm = nn.Linear(input_dim, hidden_dim)
|
| 115 |
-
|
| 116 |
-
# NEW: learnable compressor for binder → 256, then project to hidden
|
| 117 |
-
self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim)
|
| 118 |
-
self.proj_binder = nn.Linear(compressed_dim, hidden_dim)
|
| 119 |
-
|
| 120 |
-
# GLM side stays 256 → hidden
|
| 121 |
-
self.proj_glm = nn.Linear(glm_input_dim, hidden_dim)
|
| 122 |
-
|
| 123 |
-
self.use_local_cnn = use_local_cnn_on_glm
|
| 124 |
-
self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity()
|
| 125 |
-
|
| 126 |
-
self.layers = nn.ModuleList(
|
| 127 |
-
[CrossModalBlock(hidden_dim, heads) for _ in range(num_layers)]
|
| 128 |
-
)
|
| 129 |
-
|
| 130 |
-
self.ln_out = nn.LayerNorm(hidden_dim)
|
| 131 |
-
# self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid()) # OLD: returned probabilities
|
| 132 |
-
self.head = nn.Linear(hidden_dim, 1) # NEW: return logits (safe for AMP)
|
| 133 |
-
|
| 134 |
-
def forward(self, binder_emb, glm_emb):
|
| 135 |
-
"""
|
| 136 |
-
binder_emb: (B, Lb, binder_input_dim)
|
| 137 |
-
glm_emb: (B, Lg, glm_input_dim)
|
| 138 |
-
Returns per-nucleotide logits for the GLM sequence: (B, Lg)
|
| 139 |
-
"""
|
| 140 |
-
# Binder: learnable compression → 256 → hidden
|
| 141 |
-
b = self.binder_compress(binder_emb) # (B, Lb, 256)
|
| 142 |
-
b = self.proj_binder(b) # (B, Lb, hidden_dim)
|
| 143 |
-
|
| 144 |
-
# GLM: project → hidden, add local CNN context
|
| 145 |
-
g = self.proj_glm(glm_emb) # (B, Lg, hidden_dim)
|
| 146 |
-
if self.use_local_cnn:
|
| 147 |
-
g = self.local_cnn(g)
|
| 148 |
-
|
| 149 |
-
# Cross-modal blocks: update binder states using GLM
|
| 150 |
-
for layer in self.layers:
|
| 151 |
-
b = layer(b, g) # (B, Lb, hidden_dim)
|
| 152 |
-
|
| 153 |
-
# Predict per-nucleotide logits on the GLM tokens:
|
| 154 |
-
# return self.head(g).squeeze(-1) # OLD: probabilities (with Sigmoid in head)
|
| 155 |
-
return self.head(g).squeeze(
|
| 156 |
-
-1
|
| 157 |
-
) # NEW: logits (apply sigmoid only in loss/metrics)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dpacman/classifier/train.py
DELETED
|
@@ -1,220 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
import argparse
|
| 3 |
-
import numpy as np
|
| 4 |
-
import torch
|
| 5 |
-
from torch import nn
|
| 6 |
-
from model import BindPredictor
|
| 7 |
-
from pathlib import Path
|
| 8 |
-
from collections import Counter
|
| 9 |
-
from sklearn.metrics import roc_auc_score, average_precision_score
|
| 10 |
-
from sklearn.decomposition import TruncatedSVD
|
| 11 |
-
import sys
|
| 12 |
-
|
| 13 |
-
from dpacman.utils.models import set_seed
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def build_tf_compressed_cache(binder_paths, target_dim=256):
|
| 17 |
-
"""
|
| 18 |
-
Load all unique TF (binder) embeddings, fit reduction if needed, and return dict mapping path->(L, target_dim) array.
|
| 19 |
-
"""
|
| 20 |
-
unique_paths = sorted(set(binder_paths))
|
| 21 |
-
print(
|
| 22 |
-
f"[i] Found {len(unique_paths)} unique TF embedding files to compress.",
|
| 23 |
-
flush=True,
|
| 24 |
-
)
|
| 25 |
-
# Load all embeddings to determine dimensionality
|
| 26 |
-
samples = []
|
| 27 |
-
for p in unique_paths:
|
| 28 |
-
arr = np.load(p)
|
| 29 |
-
samples.append(arr)
|
| 30 |
-
# Determine if reduction needed: assume all have same embedding width
|
| 31 |
-
first = samples[0]
|
| 32 |
-
orig_dim = first.shape[1] if first.ndim == 2 else 1
|
| 33 |
-
reduction_needed = orig_dim != target_dim
|
| 34 |
-
tf_cache = {}
|
| 35 |
-
|
| 36 |
-
if reduction_needed:
|
| 37 |
-
# Build matrix to fit SVD: we need a 2D matrix per embedding; if lengths vary we can't directly stack.
|
| 38 |
-
# We'll do reduction per sequence individually using TruncatedSVD on concatenated flattened features:
|
| 39 |
-
# Simplest: for variable lengths, reduce each embedding separately with a learned linear projection.
|
| 40 |
-
# Here we fit a single TruncatedSVD on the concatenation of all sequence tokens (flattened) by padding/truncating to a fixed length.
|
| 41 |
-
# To avoid complexity, use PCA-like linear projection learned via SVD on mean-pooled vectors:
|
| 42 |
-
pooled = []
|
| 43 |
-
for arr in samples:
|
| 44 |
-
if arr.ndim == 2:
|
| 45 |
-
pooled.append(arr.mean(axis=0)) # (orig_dim,)
|
| 46 |
-
else:
|
| 47 |
-
pooled.append(arr) # degenerate
|
| 48 |
-
pooled_mat = np.stack(pooled, axis=0) # (N, orig_dim)
|
| 49 |
-
print(
|
| 50 |
-
f"[i] Fitting TruncatedSVD on TF pooled embeddings: {pooled_mat.shape} -> {target_dim}",
|
| 51 |
-
flush=True,
|
| 52 |
-
)
|
| 53 |
-
svd = TruncatedSVD(n_components=target_dim, random_state=42)
|
| 54 |
-
reduced_pooled = svd.fit_transform(pooled_mat) # (N, target_dim)
|
| 55 |
-
|
| 56 |
-
# For each original embedding, project token-level vectors by multiplying token vector with svd.components_.T
|
| 57 |
-
# svd.components_: (target_dim, orig_dim) so projection matrix is (orig_dim, target_dim)
|
| 58 |
-
proj_mat = svd.components_.T # (orig_dim, target_dim)
|
| 59 |
-
for i, p in enumerate(unique_paths):
|
| 60 |
-
arr = samples[i] # shape (L, orig_dim)
|
| 61 |
-
if arr.ndim == 1:
|
| 62 |
-
arr2 = arr @ proj_mat # (target_dim,)
|
| 63 |
-
else:
|
| 64 |
-
# project each token: (L, orig_dim) @ (orig_dim, target_dim) -> (L, target_dim)
|
| 65 |
-
arr2 = arr @ proj_mat
|
| 66 |
-
tf_cache[p] = arr2 # reduced per-token representation
|
| 67 |
-
print("[i] Completed compression of TF embeddings.", flush=True)
|
| 68 |
-
else:
|
| 69 |
-
# already correct dim: just cache originals
|
| 70 |
-
print(
|
| 71 |
-
f"[i] TF embeddings already {target_dim}-dimensional; skipping reduction.",
|
| 72 |
-
flush=True,
|
| 73 |
-
)
|
| 74 |
-
for i, p in enumerate(unique_paths):
|
| 75 |
-
arr = samples[i]
|
| 76 |
-
tf_cache[p] = arr
|
| 77 |
-
return tf_cache
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
def evaluate(model, dl, device):
|
| 81 |
-
model.eval()
|
| 82 |
-
all_labels = []
|
| 83 |
-
all_preds = []
|
| 84 |
-
with torch.no_grad():
|
| 85 |
-
for b, g, y in dl:
|
| 86 |
-
b = b.to(device)
|
| 87 |
-
g = g.to(device)
|
| 88 |
-
y = y.to(device)
|
| 89 |
-
pred = model(b, g)
|
| 90 |
-
all_labels.append(y.cpu())
|
| 91 |
-
all_preds.append(pred.cpu())
|
| 92 |
-
if not all_labels:
|
| 93 |
-
return 0.0, 0.0
|
| 94 |
-
y_true = torch.cat(all_labels).numpy()
|
| 95 |
-
y_score = torch.cat(all_preds).numpy()
|
| 96 |
-
try:
|
| 97 |
-
auc = roc_auc_score(y_true, y_score)
|
| 98 |
-
except Exception:
|
| 99 |
-
auc = 0.0
|
| 100 |
-
try:
|
| 101 |
-
ap = average_precision_score(y_true, y_score)
|
| 102 |
-
except Exception:
|
| 103 |
-
ap = 0.0
|
| 104 |
-
return auc, ap
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
def unpack(data):
|
| 108 |
-
binders, glms, labels = zip(*data)
|
| 109 |
-
return list(binders), list(glms), list(labels)
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
# ---- main ------------------------------------------------------------
|
| 113 |
-
def main(cfg):
|
| 114 |
-
"""
|
| 115 |
-
Main method, used to train the model.
|
| 116 |
-
"""
|
| 117 |
-
# Set seed for reproducibility
|
| 118 |
-
set_seed(cfg.seed)
|
| 119 |
-
|
| 120 |
-
parser.add_argument("--out_dir", type=str, required=True)
|
| 121 |
-
parser.add_argument("--epochs", type=int, default=10)
|
| 122 |
-
parser.add_argument("--batch_size", type=int, default=32)
|
| 123 |
-
parser.add_argument("--lr", type=float, default=1e-4)
|
| 124 |
-
parser.add_argument("--device", type=str, default="cuda")
|
| 125 |
-
parser.add_argument("--seed", type=int, default=42)
|
| 126 |
-
args = parser.parse_args()
|
| 127 |
-
|
| 128 |
-
#
|
| 129 |
-
print("DEBUG: starting training script with in-line TF compression", flush=True)
|
| 130 |
-
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
|
| 131 |
-
binder_paths, glm_paths, labels = parse_pair_list(cfg.pair_list)
|
| 132 |
-
|
| 133 |
-
if len(labels) == 0:
|
| 134 |
-
print("[ERROR] No valid pairs parsed. Exiting.", file=sys.stderr)
|
| 135 |
-
sys.exit(1)
|
| 136 |
-
|
| 137 |
-
label_counts = Counter(labels)
|
| 138 |
-
print(
|
| 139 |
-
f"[i] Total examples parsed: {len(labels)}. Label distribution: {label_counts}",
|
| 140 |
-
flush=True,
|
| 141 |
-
)
|
| 142 |
-
|
| 143 |
-
# build compressed TF cache (reduces to 256 if needed)
|
| 144 |
-
tf_compressed_cache = build_tf_compressed_cache(binder_paths, target_dim=256)
|
| 145 |
-
|
| 146 |
-
# load training data aloiaushasfoiuhasfoiuafasdfoihuaaasdfoiuhasfaaoiufhasfoasasfoiuh
|
| 147 |
-
|
| 148 |
-
train_ds = PairDataset(None, tf_compressed_cache=tf_compressed_cache)
|
| 149 |
-
val_ds = PairDataset(*subset(val_i), tf_compressed_cache=tf_compressed_cache)
|
| 150 |
-
test_ds = PairDataset(*subset(test_i), tf_compressed_cache=tf_compressed_cache)
|
| 151 |
-
|
| 152 |
-
print(
|
| 153 |
-
f"[i] Train/Val/Test sizes: {len(train_ds)}/{len(val_ds)}/{len(test_ds)}",
|
| 154 |
-
flush=True,
|
| 155 |
-
)
|
| 156 |
-
if len(train_ds) == 0 or len(val_ds) == 0:
|
| 157 |
-
print(
|
| 158 |
-
"[ERROR] Train or validation split is empty; cannot proceed.",
|
| 159 |
-
file=sys.stderr,
|
| 160 |
-
)
|
| 161 |
-
sys.exit(1)
|
| 162 |
-
|
| 163 |
-
train_dl = DataLoader(
|
| 164 |
-
train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn
|
| 165 |
-
)
|
| 166 |
-
val_dl = DataLoader(
|
| 167 |
-
val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn
|
| 168 |
-
)
|
| 169 |
-
test_dl = DataLoader(
|
| 170 |
-
test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn
|
| 171 |
-
)
|
| 172 |
-
|
| 173 |
-
model = BindPredictor(
|
| 174 |
-
input_dim=256, hidden_dim=256, heads=8, num_layers=3, use_local_cnn_on_glm=True
|
| 175 |
-
)
|
| 176 |
-
model = model.to(device)
|
| 177 |
-
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-3)
|
| 178 |
-
loss_fn = nn.BCELoss()
|
| 179 |
-
|
| 180 |
-
best_val = -float("inf")
|
| 181 |
-
os_out = Path(args.out_dir)
|
| 182 |
-
os_out.mkdir(exist_ok=True, parents=True)
|
| 183 |
-
|
| 184 |
-
for epoch in range(1, args.epochs + 1):
|
| 185 |
-
print(f"[Epoch {epoch}] starting...", flush=True)
|
| 186 |
-
model.train()
|
| 187 |
-
running_loss = 0.0
|
| 188 |
-
for b, g, y in train_dl:
|
| 189 |
-
b = b.to(device)
|
| 190 |
-
g = g.to(device)
|
| 191 |
-
y = y.to(device)
|
| 192 |
-
pred = model(b, g)
|
| 193 |
-
loss = loss_fn(pred, y)
|
| 194 |
-
optimizer.zero_grad()
|
| 195 |
-
loss.backward()
|
| 196 |
-
optimizer.step()
|
| 197 |
-
running_loss += loss.item() * b.size(0)
|
| 198 |
-
train_loss = running_loss / len(train_ds)
|
| 199 |
-
val_auc, val_ap = evaluate(model, val_dl, device)
|
| 200 |
-
print(
|
| 201 |
-
f"[Epoch {epoch}] train_loss={train_loss:.4f} val_auc={val_auc:.4f} val_ap={val_ap:.4f}",
|
| 202 |
-
flush=True,
|
| 203 |
-
)
|
| 204 |
-
|
| 205 |
-
if val_auc > best_val:
|
| 206 |
-
best_val = val_auc
|
| 207 |
-
torch.save(model.state_dict(), os_out / "best_model.pt")
|
| 208 |
-
print(
|
| 209 |
-
f"[Epoch {epoch}] Saved new best model with val_auc={val_auc:.4f}",
|
| 210 |
-
flush=True,
|
| 211 |
-
)
|
| 212 |
-
|
| 213 |
-
torch.save(model.state_dict(), os_out / "last_model.pt")
|
| 214 |
-
test_auc, test_ap = evaluate(model, test_dl, device)
|
| 215 |
-
print(f"FINAL TEST: AUC={test_auc:.4f} AP={test_ap:.4f}", flush=True)
|
| 216 |
-
print(f"[i] Models written to {os_out}/best_model.pt and last_model.pt", flush=True)
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
if __name__ == "__main__":
|
| 220 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dpacman/scripts/delay_run.sh
CHANGED
|
@@ -4,7 +4,7 @@ set -euo pipefail
|
|
| 4 |
# Usage: ./stagger.sh <first_script.sh> <second_script.sh>
|
| 5 |
# Optional: override waits via env vars WAIT1 / WAIT2 (seconds). Defaults: 3 hours each.
|
| 6 |
|
| 7 |
-
WAIT1=${WAIT1:-
|
| 8 |
WAIT2=${WAIT2:-10800}
|
| 9 |
|
| 10 |
SCRIPT1="${1:?usage: $0 <first_script.sh> <second_script.sh>}"
|
|
|
|
| 4 |
# Usage: ./stagger.sh <first_script.sh> <second_script.sh>
|
| 5 |
# Optional: override waits via env vars WAIT1 / WAIT2 (seconds). Defaults: 3 hours each.
|
| 6 |
|
| 7 |
+
WAIT1=${WAIT1:-3600} # 3 hours in seconds
|
| 8 |
WAIT2=${WAIT2:-10800}
|
| 9 |
|
| 10 |
SCRIPT1="${1:?usage: $0 <first_script.sh> <second_script.sh>}"
|
dpacman/scripts/run_train.sh
CHANGED
|
@@ -22,7 +22,7 @@ CUDA_VISIBLE_DEVICES=0,1 nohup python -u -m scripts.train \
|
|
| 22 |
+trainer.gradient_clip_algorithm="norm" \
|
| 23 |
hydra.run.dir="${run_dir}" \
|
| 24 |
trainer.devices=2 \
|
| 25 |
-
trainer.max_epochs=
|
| 26 |
data_module.train_file="data_files/processed/splits/by_dna/train.csv" \
|
| 27 |
data_module.val_file="data_files/processed/splits/by_dna/val.csv" \
|
| 28 |
data_module.test_file="data_files/processed/splits/by_dna/test.csv" \
|
|
@@ -31,8 +31,8 @@ CUDA_VISIBLE_DEVICES=0,1 nohup python -u -m scripts.train \
|
|
| 31 |
data_module.batch_size=16 \
|
| 32 |
model.glm_input_dim=256 \
|
| 33 |
model.compressed_dim=256 \
|
| 34 |
-
model.hidden_dim=
|
| 35 |
-
model.lr=
|
| 36 |
> "${run_dir}/run.log" 2>&1 &
|
| 37 |
|
| 38 |
echo $! > "${run_dir}/pid.txt"
|
|
|
|
| 22 |
+trainer.gradient_clip_algorithm="norm" \
|
| 23 |
hydra.run.dir="${run_dir}" \
|
| 24 |
trainer.devices=2 \
|
| 25 |
+
trainer.max_epochs=20 \
|
| 26 |
data_module.train_file="data_files/processed/splits/by_dna/train.csv" \
|
| 27 |
data_module.val_file="data_files/processed/splits/by_dna/val.csv" \
|
| 28 |
data_module.test_file="data_files/processed/splits/by_dna/test.csv" \
|
|
|
|
| 31 |
data_module.batch_size=16 \
|
| 32 |
model.glm_input_dim=256 \
|
| 33 |
model.compressed_dim=256 \
|
| 34 |
+
model.hidden_dim=128 \
|
| 35 |
+
model.lr=1e-5 \
|
| 36 |
> "${run_dir}/run.log" 2>&1 &
|
| 37 |
|
| 38 |
echo $! > "${run_dir}/pid.txt"
|
dpacman/scripts/run_train_baseline.sh
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# Manually specify values used in the config
|
| 4 |
+
main_task="train"
|
| 5 |
+
model_type="baseline"
|
| 6 |
+
timestamp=$(date "+%Y-%m-%d_%H-%M-%S")
|
| 7 |
+
|
| 8 |
+
run_dir="$HOME/DPACMAN/logs/${main_task}/${model_type}/runs/${timestamp}"
|
| 9 |
+
mkdir -p "$run_dir"
|
| 10 |
+
|
| 11 |
+
if [ -z "$WANDB_API_KEY" ]; then
|
| 12 |
+
read -s -p "Enter your WANDB API key: " wandb_key
|
| 13 |
+
echo
|
| 14 |
+
export WANDB_API_KEY="$wandb_key"
|
| 15 |
+
fi
|
| 16 |
+
|
| 17 |
+
CUDA_VISIBLE_DEVICES=0,1 nohup python -u -m scripts.train \
|
| 18 |
+
+trainer.strategy=ddp \
|
| 19 |
+
+trainer.use_distributed_sampler="false" \
|
| 20 |
+
+trainer.detect_anomaly="false" \
|
| 21 |
+
+trainer.gradient_clip_val=0.5 \
|
| 22 |
+
+trainer.gradient_clip_algorithm="norm" \
|
| 23 |
+
hydra.run.dir="${run_dir}" \
|
| 24 |
+
trainer.devices=2 \
|
| 25 |
+
trainer.max_epochs=10 \
|
| 26 |
+
data_module.train_file="data_files/processed/splits/by_dna/train.csv" \
|
| 27 |
+
data_module.val_file="data_files/processed/splits/by_dna/val.csv" \
|
| 28 |
+
data_module.test_file="data_files/processed/splits/by_dna/test.csv" \
|
| 29 |
+
data_module.tr_shelf_path="data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf" \
|
| 30 |
+
data_module.dna_shelf_path="data_files/processed/embeddings/fimo_hits_only/peaks_caduceus.shelf" \
|
| 31 |
+
data_module.batch_size=16 \
|
| 32 |
+
model=baseline \
|
| 33 |
+
model.glm_input_dim=256 \
|
| 34 |
+
model.compressed_dim=256 \
|
| 35 |
+
model.hidden_dim=128 \
|
| 36 |
+
model.lr=1e-5 \
|
| 37 |
+
> "${run_dir}/run.log" 2>&1 &
|
| 38 |
+
|
| 39 |
+
echo $! > "${run_dir}/pid.txt"
|