svincoff committed on
Commit
7b33404
·
1 Parent(s): 121a325

eval mode, fixed, full binary mode

Browse files
.gitignore CHANGED
@@ -38,4 +38,6 @@ dpacman/peak_examples/
38
  dpacman/__pycache__/
39
  log.log
40
  log2.log
41
- dpacman/delay.log
 
 
 
38
  dpacman/__pycache__/
39
  log.log
40
  log2.log
41
+ dpacman/delay.log
42
+ dpacman/view_profiles.ipynb
43
+ dpacman/find_wandb_run_dirs.py
configs/data_module/pair.yaml CHANGED
@@ -4,6 +4,9 @@ train_file: data_files/processed/splits/by_dna/babytrain.csv
4
  val_file: data_files/processed/splits/by_dna/babyval.csv
5
  test_file: data_files/processed/splits/by_dna/babytest.csv
6
 
 
 
 
7
  tr_shelf_path: data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf
8
  dna_shelf_path: data_files/processed/embeddings/fimo_hits_only/baby_peaks_segmentnt_pernuc_with_onehot.shelf
9
 
 
4
  val_file: data_files/processed/splits/by_dna/babyval.csv
5
  test_file: data_files/processed/splits/by_dna/babytest.csv
6
 
7
+ target_col: dna_sequence
8
+ score_col: scores
9
+
10
  tr_shelf_path: data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf
11
  dna_shelf_path: data_files/processed/embeddings/fimo_hits_only/baby_peaks_segmentnt_pernuc_with_onehot.shelf
12
 
configs/eval.yaml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - paths: default
3
+ - hydra: default # ← tells Hydra to use the logging/output config
4
+ - data_module: pair
5
+ - model: classifier
6
+ - trainer: gpu
7
+ - extras: default
8
+ - logger: wandb
9
+ - callbacks: default
10
+ - _self_
11
+
12
+ # experiment configs allow for version control of specific hyperparameters
13
+ # e.g. best hyperparameters for given model and datamodule
14
+ - experiment: null
15
+
16
+ # config for hyperparameter optimization
17
+ - hparams_search: null
18
+
19
+ # debugging config (enable through command line, e.g. `python train.py debug=default`)
20
+ - debug: null
21
+
22
+ task_name: eval/${model}
23
+
24
+ # tags to help you identify your experiments
25
+ # you can overwrite this in experiment configs
26
+ # overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
27
+ tags: ["dev"]
28
+
29
+ # evaluate on test set, using best model weights achieved during training
30
+ # lightning chooses best weights based on the metric specified in checkpoint callback
31
+ test: True
32
+
33
+ # path to the trained checkpoint whose weights are loaded for evaluation
34
+ ckpt_path: /home/a03-svincoff/DPACMAN/logs/train/classifier/runs/2025-08-25_18-08-13/checkpoints/epoch_009.ckpt
35
+
36
+ # seed for random number generators in pytorch, numpy and python.random
37
+ seed: 42
38
+
39
+ data_module:
40
+ train_file: null
41
+ val_file: null
42
+ test_file: data_files/processed/splits/by_dna/test.csv
configs/model/baseline.yaml CHANGED
@@ -7,4 +7,6 @@ weight_decay: 0.01
7
 
8
  glm_input_dim: 256
9
  compressed_dim: 256
10
- hidden_dim: 128
 
 
 
7
 
8
  glm_input_dim: 256
9
  compressed_dim: 256
10
+ hidden_dim: 128
11
+
12
+ loss_type: mixed
configs/model/classifier.yaml CHANGED
@@ -7,4 +7,6 @@ weight_decay: 0.01
7
 
8
  glm_input_dim: 1029
9
  compressed_dim: 1029
10
- hidden_dim: 256
 
 
 
7
 
8
  glm_input_dim: 1029
9
  compressed_dim: 1029
10
+ hidden_dim: 256
11
+
12
+ loss_type: mixed
dpacman/classifier/baseline.py CHANGED
@@ -24,6 +24,7 @@ class BaselineBindPredictor(LightningModule):
24
  gamma: float = 20,
25
  dropout: float = 0,
26
  weight_decay: float = 0.01,
 
27
  ):
28
  # Init
29
  super(BaselineBindPredictor, self).__init__()
@@ -78,7 +79,7 @@ class BaselineBindPredictor(LightningModule):
78
  """
79
  logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
80
  loss = calculate_loss(
81
- logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
82
  )
83
  self.log(
84
  "train/loss",
@@ -113,7 +114,7 @@ class BaselineBindPredictor(LightningModule):
113
  def validation_step(self, batch, batch_idx):
114
  logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
115
  loss = calculate_loss(
116
- logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
117
  )
118
  self.log(
119
  "val/loss",
@@ -143,7 +144,7 @@ class BaselineBindPredictor(LightningModule):
143
  def test_step(self, batch, batch_idx):
144
  logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
145
  loss = calculate_loss(
146
- logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
147
  )
148
  self.log(
149
  "test/loss", loss, on_step=False, on_epoch=True, batch_size=logits.size(0)
 
24
  gamma: float = 20,
25
  dropout: float = 0,
26
  weight_decay: float = 0.01,
27
+ loss_type: str = "mixed"
28
  ):
29
  # Init
30
  super(BaselineBindPredictor, self).__init__()
 
79
  """
80
  logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
81
  loss = calculate_loss(
82
+ logits, batch["labels"], batch["binder_kpm"], batch["glm_kpm"], alpha=self.hparams.alpha, gamma=self.hparams.gamma, loss_type=self.hparams.loss_type
83
  )
84
  self.log(
85
  "train/loss",
 
114
  def validation_step(self, batch, batch_idx):
115
  logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
116
  loss = calculate_loss(
117
+ logits, batch["labels"], batch["binder_kpm"], batch["glm_kpm"], alpha=self.hparams.alpha, gamma=self.hparams.gamma, loss_type=self.hparams.loss_type
118
  )
119
  self.log(
120
  "val/loss",
 
144
  def test_step(self, batch, batch_idx):
145
  logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
146
  loss = calculate_loss(
147
+ logits, batch["labels"], batch["binder_kpm"], batch["glm_kpm"], alpha=self.hparams.alpha, gamma=self.hparams.gamma, loss_type=self.hparams.loss_type
148
  )
149
  self.log(
150
  "test/loss", loss, on_step=False, on_epoch=True, batch_size=logits.size(0)
dpacman/classifier/loss.py CHANGED
@@ -7,6 +7,12 @@ import torch.nn.functional as F
7
  from torchmetrics.functional.classification import (
8
  auroc, average_precision, roc, precision_recall_curve
9
  )
 
 
 
 
 
 
10
 
11
  def _expand_like(mask: torch.Tensor, like: torch.Tensor):
12
  # Make mask broadcastable to logits/targets (handles (B,L) vs (B,L,1))
@@ -14,7 +20,7 @@ def _expand_like(mask: torch.Tensor, like: torch.Tensor):
14
  mask = mask.unsqueeze(-1)
15
  return mask.expand_as(like)
16
 
17
- def bce_loss_masked(logits, targets, nonpeak_mask, pos_weight=None, eps=1e-8):
18
  """
19
  Compute masked BCE with logits over non-peak positions only.
20
  Expects nonpeak_mask already broadcastable to logits.
@@ -24,7 +30,7 @@ def bce_loss_masked(logits, targets, nonpeak_mask, pos_weight=None, eps=1e-8):
24
  loss = F.binary_cross_entropy_with_logits(
25
  logits, t, reduction="none", pos_weight=pos_weight
26
  )
27
- m = _expand_like(nonpeak_mask, loss).to(loss.dtype)
28
  denom = m.sum().clamp_min(eps)
29
  return (loss * m).sum() / denom
30
 
@@ -41,17 +47,32 @@ def mse_peaks_only(logits, targets, peak_mask, eps=1e-8):
41
  def calculate_loss(
42
  logits,
43
  targets,
 
 
44
  eps: float = 1e-8,
45
  alpha: float = 1.0,
46
  gamma: float = 1.0,
47
  pos_weight=None,
48
  pad_value: float = -1.0,
 
49
  ):
50
  """
51
  Combine masked-BCE (non-peak) + masked-MSE on probs (peak), ignoring padding.
52
  Assumes targets == -1 are pads; non-peak = 0; peak > 0.
 
 
 
 
 
 
53
  """
 
 
54
  valid = (targets != pad_value)
 
 
 
 
55
 
56
  # Peak / non-peak masks that exclude pads
57
  nonpeak_mask = valid & (targets == 0)
@@ -60,10 +81,18 @@ def calculate_loss(
60
  # For safety, zero-out targets at pad positions so they never feed into BCE/MSE
61
  targets_safe = torch.where(valid, targets, torch.zeros_like(targets))
62
 
63
- bce_nonpeak = bce_loss_masked(logits, targets_safe, nonpeak_mask, pos_weight=pos_weight, eps=eps)
64
- mse_peak = mse_peaks_only(logits, targets_safe, peak_mask, eps=eps)
65
-
66
- return alpha * bce_nonpeak + gamma * mse_peak
 
 
 
 
 
 
 
 
67
 
68
  @torch.no_grad()
69
  def auroc_zeros_vs_ones_from_logits(
@@ -81,6 +110,7 @@ def auroc_zeros_vs_ones_from_logits(
81
  tp, fp: integer counts per threshold (shape (T,))
82
  """
83
  device = logits.device
 
84
  valid = ~glm_kpm if glm_kpm is not None else torch.ones_like(labels, dtype=torch.bool, device=device)
85
  keep = valid & ((labels > pos_thresh) | (labels == 0.0))
86
  if keep.sum() == 0:
@@ -126,6 +156,7 @@ def auprc_zeros_vs_ones_from_logits(
126
  thresholds: (T,)
127
  """
128
  device = logits.device
 
129
  valid = ~glm_kpm if glm_kpm is not None else torch.ones_like(labels, dtype=torch.bool, device=device)
130
  keep = valid & ((labels > pos_thresh) | (labels == 0.0))
131
  if keep.sum() == 0:
 
7
  from torchmetrics.functional.classification import (
8
  auroc, average_precision, roc, precision_recall_curve
9
  )
10
+ import rootutils
11
+ from dpacman.utils import pylogger
12
+
13
+ root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
14
+ logger = pylogger.RankedLogger(__name__, rank_zero_only=True)
15
+
16
 
17
  def _expand_like(mask: torch.Tensor, like: torch.Tensor):
18
  # Make mask broadcastable to logits/targets (handles (B,L) vs (B,L,1))
 
20
  mask = mask.unsqueeze(-1)
21
  return mask.expand_as(like)
22
 
23
+ def bce_loss_masked(logits, targets, mask, pos_weight=None, eps=1e-8):
24
  """
25
  Compute masked BCE with logits over non-peak positions only.
26
  Expects nonpeak_mask already broadcastable to logits.
 
30
  loss = F.binary_cross_entropy_with_logits(
31
  logits, t, reduction="none", pos_weight=pos_weight
32
  )
33
+ m = _expand_like(mask, loss).to(loss.dtype)
34
  denom = m.sum().clamp_min(eps)
35
  return (loss * m).sum() / denom
36
 
 
47
  def calculate_loss(
48
  logits,
49
  targets,
50
+ binder_kpm,
51
+ glm_kpm,
52
  eps: float = 1e-8,
53
  alpha: float = 1.0,
54
  gamma: float = 1.0,
55
  pos_weight=None,
56
  pad_value: float = -1.0,
57
+ loss_type="mixed"
58
  ):
59
  """
60
  Combine masked-BCE (non-peak) + masked-MSE on probs (peak), ignoring padding.
61
  Assumes targets == -1 are pads; non-peak = 0; peak > 0.
62
+
63
+ binder_kpm is 1 at PAD positions, 0 elsewhere
64
+ glm_kpm is 1 at PAD positions, 0 elsewhere
65
+
66
+ if loss_type is mixed, we're doing binary cross entropy off the peaks and MSE on the peaks.
67
+ if loss_type is binary, we're doing binary cross entropy everywhere because the labels are binary.
68
  """
69
+ # calculate validity in two ways; these should be the same.
70
+ # targets are padded to -1 where there is not really a DNA sequence there
71
  valid = (targets != pad_value)
72
+ if glm_kpm is not None:
73
+ nvalid = torch.sum(valid).item()
74
+ nvalid_2 = torch.sum(~glm_kpm).item()
75
+ assert nvalid==nvalid_2
76
 
77
  # Peak / non-peak masks that exclude pads
78
  nonpeak_mask = valid & (targets == 0)
 
81
  # For safety, zero-out targets at pad positions so they never feed into BCE/MSE
82
  targets_safe = torch.where(valid, targets, torch.zeros_like(targets))
83
 
84
+ if loss_type=="mixed":
85
+ bce_nonpeak = bce_loss_masked(logits, targets_safe, nonpeak_mask, pos_weight=pos_weight, eps=eps)
86
+ mse_peak = mse_peaks_only(logits, targets_safe, peak_mask, eps=eps)
87
+ return alpha * bce_nonpeak + gamma * mse_peak
88
+ else:
89
+ # we're expecting all binary labels. make sure.
90
+ all_binary = ((targets_safe==1) | (targets_safe==0)).all().item()
91
+ if not(all_binary):
92
+ logger.info(f"WARNING: expecting all binary labels for loss_type={loss_type}. Did not get all binary labels.")
93
+ # bce over all valid positions
94
+ bce_all = bce_loss_masked(logits, targets_safe, valid, pos_weight=pos_weight, eps=eps)
95
+ return alpha*bce_all
96
 
97
  @torch.no_grad()
98
  def auroc_zeros_vs_ones_from_logits(
 
110
  tp, fp: integer counts per threshold (shape (T,))
111
  """
112
  device = logits.device
113
+ # glm_kpm is 1 where there's a pad, so ~glm_kpm is valid positions
114
  valid = ~glm_kpm if glm_kpm is not None else torch.ones_like(labels, dtype=torch.bool, device=device)
115
  keep = valid & ((labels > pos_thresh) | (labels == 0.0))
116
  if keep.sum() == 0:
 
156
  thresholds: (T,)
157
  """
158
  device = logits.device
159
+ # glm_kpm is 1 where there's a pad, so ~glm_kpm is valid
160
  valid = ~glm_kpm if glm_kpm is not None else torch.ones_like(labels, dtype=torch.bool, device=device)
161
  keep = valid & ((labels > pos_thresh) | (labels == 0.0))
162
  if keep.sum() == 0:
dpacman/classifier/model.py CHANGED
@@ -10,7 +10,6 @@ from .loss import calculate_loss, auprc_zeros_vs_ones_from_logits, auroc_zeros_v
10
 
11
  set_seed()
12
 
13
-
14
  class LocalCNN(nn.Module):
15
  def __init__(self, dim: int = 256, kernel_size: int = 3):
16
  super().__init__()
@@ -156,6 +155,7 @@ class BindPredictor(LightningModule):
156
  dropout: float = 0,
157
  use_local_cnn_on_glm: bool = True,
158
  weight_decay: float = 0.01,
 
159
  ):
160
  # Init
161
  super(BindPredictor, self).__init__()
@@ -222,7 +222,7 @@ class BindPredictor(LightningModule):
222
  """
223
  logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
224
  loss = calculate_loss(
225
- logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
226
  )
227
  self.log(
228
  "train/loss",
@@ -256,7 +256,7 @@ class BindPredictor(LightningModule):
256
  def validation_step(self, batch, batch_idx):
257
  logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
258
  loss = calculate_loss(
259
- logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
260
  )
261
  self.log(
262
  "val/loss",
@@ -287,7 +287,7 @@ class BindPredictor(LightningModule):
287
  def test_step(self, batch, batch_idx):
288
  logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
289
  loss = calculate_loss(
290
- logits, batch["labels"], alpha=self.hparams.alpha, gamma=self.hparams.gamma
291
  )
292
  self.log(
293
  "test/loss", loss, on_step=False, on_epoch=True, batch_size=logits.size(0)
@@ -307,8 +307,20 @@ class BindPredictor(LightningModule):
307
  self.log("test/auroc_0v1",
308
  auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
309
  on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
 
310
  return loss
311
-
 
 
 
 
 
 
 
 
 
 
 
312
  def on_before_optimizer_step(self, optimizer):
313
  # Compute global L2 norm of all parameter gradients (ignores None grads)
314
  grads = []
 
10
 
11
  set_seed()
12
 
 
13
  class LocalCNN(nn.Module):
14
  def __init__(self, dim: int = 256, kernel_size: int = 3):
15
  super().__init__()
 
155
  dropout: float = 0,
156
  use_local_cnn_on_glm: bool = True,
157
  weight_decay: float = 0.01,
158
+ loss_type = "mixed"
159
  ):
160
  # Init
161
  super(BindPredictor, self).__init__()
 
222
  """
223
  logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
224
  loss = calculate_loss(
225
+ logits, batch["labels"], batch["binder_kpm"], batch["glm_kpm"], alpha=self.hparams.alpha, gamma=self.hparams.gamma, loss_type=self.hparams.loss_type
226
  )
227
  self.log(
228
  "train/loss",
 
256
  def validation_step(self, batch, batch_idx):
257
  logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
258
  loss = calculate_loss(
259
+ logits, batch["labels"], batch["binder_kpm"], batch["glm_kpm"], alpha=self.hparams.alpha, gamma=self.hparams.gamma, loss_type=self.hparams.loss_type
260
  )
261
  self.log(
262
  "val/loss",
 
287
  def test_step(self, batch, batch_idx):
288
  logits = self.forward(batch["binder_emb"], batch["glm_emb"], batch["binder_kpm"], batch["glm_kpm"])
289
  loss = calculate_loss(
290
+ logits, batch["labels"], batch["binder_kpm"], batch["glm_kpm"], alpha=self.hparams.alpha, gamma=self.hparams.gamma, loss_type=self.hparams.loss_type
291
  )
292
  self.log(
293
  "test/loss", loss, on_step=False, on_epoch=True, batch_size=logits.size(0)
 
307
  self.log("test/auroc_0v1",
308
  auc if torch.isfinite(auc) else torch.tensor(0.0, device=auc.device),
309
  on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
310
+
311
  return loss
312
+
313
+ def predict_step(self, batch, batch_idx, dataloader_idx=0):
314
+ logits = self.forward(batch["binder_emb"], batch["glm_emb"],
315
+ batch["binder_kpm"], batch["glm_kpm"]).squeeze(-1) # (B,L)
316
+ valid = ~batch["glm_kpm"] # (B,L)
317
+ return {
318
+ "ids": batch["ID"], # list[str]
319
+ "logits": logits.detach().cpu(), # (B,Lmax) padded
320
+ "valid": valid.detach().cpu(), # (B,Lmax) booleans
321
+ "labels": batch["labels"].detach().cpu(), # (B,Lmax) padded
322
+ }
323
+
324
  def on_before_optimizer_step(self, optimizer):
325
  # Compute global L2 norm of all parameter gradients (ignores None grads)
326
  grads = []
dpacman/data_modules/pair.py CHANGED
@@ -157,7 +157,7 @@ def make_length_batches(
157
  # ---- dataset ---------------------------------------------------------
158
  class PairDataset(Dataset):
159
  def __init__(
160
- self, dataset: pd.DataFrame, norm_value: int = 1333, round_to: int = 4
161
  ):
162
  """
163
  Args:
@@ -165,21 +165,29 @@ class PairDataset(Dataset):
165
  - norm_value: max score, which we'll use to divide all the integer scores in "scores"
166
  - round_to: how many decimal places for the numerical score values
167
  """
168
- self.dataset = self._load_and_normalize(dataset, norm_value, round_to)
169
- self.norm_value = (
170
- norm_value # what to divide everything in labels by to make it a float
171
- )
172
-
173
- def _load_and_normalize(self, dataset, norm_value: int, round_to: int):
 
 
 
174
  """
175
  Labels come in looking like "0,0,0,100,100,133,133,100,100,0,0,"
176
  This method turns the labels from strings into floats out to 4 decimal places
177
  """
 
 
 
 
 
178
  # split string into list of strings
179
- dataset["scores"] = dataset["scores"].apply(lambda x: x.split(","))
180
  # turn list of strings into list of normalized, rounded floats
181
- dataset["scores"] = dataset["scores"].apply(
182
- lambda x: [round(int(y) / norm_value, round_to) for y in x]
183
  )
184
 
185
  # convert to records for ease of loading
@@ -212,6 +220,9 @@ class PairDataModule(LightningDataModule):
212
  debug_run: bool = False,
213
  pin_memory: bool = False,
214
  shuffle_train_batch_order: bool = True,
 
 
 
215
  ):
216
  super().__init__()
217
  self.save_hyperparameters()
@@ -221,6 +232,9 @@ class PairDataModule(LightningDataModule):
221
  self.train_data_file = train_file
222
  self.val_data_file = val_file
223
  self.test_data_file = test_file
 
 
 
224
 
225
  # Initialize hyperparameters like batch size
226
  self.batch_size = batch_size
@@ -232,10 +246,12 @@ class PairDataModule(LightningDataModule):
232
  self.collate = ShelfCollator(
233
  tr_shelf_path=str(tr_shelf_path),
234
  dna_shelf_path=str(dna_shelf_path),
235
- tr_key="tr_sequence",
236
- dna_key="dna_sequence",
237
  dtype=torch.float32,
238
- pad_value=0.0,
 
 
239
  )
240
  self.drop_last = False # or True, your choice
241
  self.shuffle_batch_order = shuffle_train_batch_order # False keep batches deterministic per epoch; set True if you want to shuffle batch order
@@ -247,11 +263,13 @@ class PairDataModule(LightningDataModule):
247
  """
248
  Load and unpack an input csv whose columns are binder_path,glm_path,label
249
  """
250
- df = pd.read_csv(file_path)
251
- if lim is not None:
252
- df = df[:lim].reset_index(drop=True)
253
-
254
- return df[["ID", "dna_sequence", "tr_sequence", "scores"]]
 
 
255
 
256
  def setup(self, stage: str | None = None):
257
  lim = 5 if self.debug_run else None
@@ -260,7 +278,7 @@ class PairDataModule(LightningDataModule):
260
  if stage in (None, "fit"):
261
  if not hasattr(self, "train_dataset"):
262
  train_df = self.load_file(self.train_data_file, lim=lim)
263
- self.train_dataset = PairDataset(train_df)
264
  self.train_batches = make_length_batches(
265
  dataset_records=self.train_dataset.dataset,
266
  tr_shelf_path=str(self.hparams.tr_shelf_path),
@@ -276,7 +294,7 @@ class PairDataModule(LightningDataModule):
276
 
277
  if not hasattr(self, "val_dataset"):
278
  val_df = self.load_file(self.val_data_file, lim=lim)
279
- self.val_dataset = PairDataset(val_df)
280
  self.val_batches = make_length_batches(
281
  dataset_records=self.val_dataset.dataset,
282
  tr_shelf_path=str(self.hparams.tr_shelf_path),
@@ -291,7 +309,7 @@ class PairDataModule(LightningDataModule):
291
  if stage in (None, "validate"):
292
  if not hasattr(self, "val_dataset"):
293
  val_df = self.load_file(self.val_data_file, lim=lim)
294
- self.val_dataset = PairDataset(val_df)
295
  self.val_batches = make_length_batches(
296
  dataset_records=self.val_dataset.dataset,
297
  tr_shelf_path=str(self.hparams.tr_shelf_path),
@@ -306,7 +324,7 @@ class PairDataModule(LightningDataModule):
306
  if stage in (None, "test"):
307
  if not hasattr(self, "test_dataset"):
308
  test_df = self.load_file(self.test_data_file, lim=lim)
309
- self.test_dataset = PairDataset(test_df)
310
  self.test_batches = make_length_batches(
311
  dataset_records=self.test_dataset.dataset,
312
  tr_shelf_path=str(self.hparams.tr_shelf_path),
@@ -346,6 +364,17 @@ class PairDataModule(LightningDataModule):
346
  persistent_workers=(self.num_workers > 0),
347
  pin_memory=self.hparams.pin_memory,
348
  )
 
 
 
 
 
 
 
 
 
 
 
349
 
350
 
351
  class ShelfCollator:
@@ -373,13 +402,17 @@ class ShelfCollator:
373
  dna_key: str = "dna_sequence",
374
  dtype: torch.dtype = torch.float32,
375
  pad_value: float = -1.0,
 
 
376
  ):
377
  self.tr_path = tr_shelf_path
378
  self.dna_path = dna_shelf_path
 
379
  self.tr_key = tr_key
380
  self.dna_key = dna_key
381
  self.dtype = dtype
382
  self.pad_value = pad_value
 
383
 
384
  # opened lazily per worker:
385
  self._tr_db = None
@@ -400,7 +433,7 @@ class ShelfCollator:
400
  ids = [b.get("ID", None) for b in batch]
401
  tr_seqs = [b[self.tr_key] for b in batch]
402
  dna_seqs = [b[self.dna_key] for b in batch]
403
- scores_list = [b["scores"] for b in batch]
404
 
405
  # 1) Fetch embeddings lazily from shelves
406
  binder_list = []
@@ -438,10 +471,10 @@ class ShelfCollator:
438
  glm_emb = pad_sequence(
439
  glm_list, batch_first=True, padding_value=self.pad_value
440
  ) # [B, Lg_max, Dg]
441
-
442
  binder_lens = torch.as_tensor(binder_lens, dtype=torch.int64)
443
  glm_lens = torch.as_tensor(glm_lens, dtype=torch.int64)
444
-
445
  binder_mask = torch.arange(binder_emb.size(1)).unsqueeze(
446
  0
447
  ) < binder_lens.unsqueeze(
@@ -460,6 +493,24 @@ class ShelfCollator:
460
  labels = pad_sequence(
461
  labels_list, batch_first=True, padding_value=self.pad_value
462
  ) # [B, Lg_max]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
463
 
464
  return {
465
  "binder_emb": binder_emb, # [B, Lb_max, Db]
@@ -474,32 +525,6 @@ class ShelfCollator:
474
  "dna_sequence": dna_seqs,
475
  }
476
 
477
-
478
- def collate_fn(batch, tr_shelf_path, dna_shelf_path):
479
- Bs = [b.shape[0] for b, _, _ in batch]
480
- Gs = [g.shape[0] for _, g, _ in batch]
481
- maxB, maxG = max(Bs), max(Gs)
482
-
483
- def pad_seq(x, L):
484
- if x.shape[0] < L:
485
- pad = torch.zeros(
486
- (L - x.shape[0], x.shape[1]), dtype=x.dtype, device=x.device
487
- )
488
- return torch.cat([x, pad], dim=0)
489
- return x
490
-
491
- def pad_t(y, L):
492
- if y.shape[0] < L:
493
- pad = torch.zeros((L - y.shape[0],), dtype=y.dtype, device=y.device)
494
- return torch.cat([y, pad], dim=0)
495
- return y
496
-
497
- b_stack = torch.stack([pad_seq(b, maxB) for b, _, _ in batch])
498
- g_stack = torch.stack([pad_seq(g, maxG) for _, g, _ in batch])
499
- t_stack = torch.stack([pad_t(t, maxG) for *_, t in batch])
500
- return b_stack, g_stack, t_stack
501
-
502
-
503
  # ------------------------ Helpers for main method debugging only ------------------------------------------#
504
  def _peek_batches(dl, n_batches: int = 2, tag: str = "train"):
505
  logger.info(f"\n=== Peek {n_batches} batch(es) from {tag} loader ===")
@@ -519,13 +544,13 @@ def _peek_batches(dl, n_batches: int = 2, tag: str = "train"):
519
  logger.info(f" glm_mask true count: {gm.sum().item()} / {gm.numel()}")
520
  logger.info(f" glm_mask: {tuple(gm.shape)} dtype={gm.dtype}")
521
  logger.info(
522
- f" labels: {tuple(y.shape)} min={y.min().item():.4f} max={y.max().item():.4f}"
523
  )
524
  logger.info(f" IDs (first 5): {ids[:5]}")
 
525
  if i + 1 >= n_batches:
526
  break
527
 
528
-
529
  def _warn_on_paths(args):
530
  import os
531
 
@@ -577,7 +602,7 @@ def main():
577
  parser.add_argument("--batch_size", type=int, default=4)
578
  parser.add_argument("--num_workers", type=int, default=4)
579
  parser.add_argument(
580
- "--debug_run", action="store_true", help="limit dataset to a few rows"
581
  )
582
  parser.add_argument(
583
  "--n_batches", type=int, default=2, help="how many batches to print per split"
 
157
  # ---- dataset ---------------------------------------------------------
158
  class PairDataset(Dataset):
159
  def __init__(
160
+ self, dataset: pd.DataFrame, norm_value: int = 1333, round_to: int = 4, score_col="scores", target_col="dna_sequence", binder_col="tr_sequence"
161
  ):
162
  """
163
  Args:
 
165
  - norm_value: max score, which we'll use to divide all the integer scores in "scores"
166
  - round_to: how many decimal places for the numerical score values
167
  """
168
+ self.fake_scores=False
169
+ self.score_col = score_col
170
+ self.target_col = target_col
171
+ self.binder_col = binder_col
172
+ self.norm_value = norm_value
173
+ self.round_to = round_to
174
+ self.dataset = self._load_and_normalize(dataset)
175
+
176
+ def _load_and_normalize(self, dataset):
177
  """
178
  Labels come in looking like "0,0,0,100,100,133,133,100,100,0,0,"
179
  This method turns the labels from strings into floats out to 4 decimal places
180
  """
181
+ if self.score_col not in dataset.columns:
182
+ logger.info(f"Scores not provided. Adding placeholder scores where all positions are considered binding")
183
+ dataset[self.score_col] = dataset["dna_sequence"].str.len()
184
+ dataset[self.score_col] = dataset[self.score_col].apply(lambda x: ",".join([str(self.norm_value)]*x))
185
+ self.fake_scores=True
186
  # split string into list of strings
187
+ dataset[self.score_col] = dataset[self.score_col].apply(lambda x: x.split(","))
188
  # turn list of strings into list of normalized, rounded floats
189
+ dataset[self.score_col] = dataset[self.score_col].apply(
190
+ lambda x: [round(int(y) / self.norm_value, self.round_to) for y in x]
191
  )
192
 
193
  # convert to records for ease of loading
 
220
  debug_run: bool = False,
221
  pin_memory: bool = False,
222
  shuffle_train_batch_order: bool = True,
223
+ score_col: str = "scores",
224
+ target_col: str = "dna_sequence",
225
+ binder_col: str = "tr_sequence"
226
  ):
227
  super().__init__()
228
  self.save_hyperparameters()
 
232
  self.train_data_file = train_file
233
  self.val_data_file = val_file
234
  self.test_data_file = test_file
235
+ self.target_col = target_col
236
+ self.binder_col = binder_col
237
+ self.score_col = score_col
238
 
239
  # Initialize hyperparameters like batch size
240
  self.batch_size = batch_size
 
246
  self.collate = ShelfCollator(
247
  tr_shelf_path=str(tr_shelf_path),
248
  dna_shelf_path=str(dna_shelf_path),
249
+ tr_key=self.binder_col,
250
+ dna_key=self.target_col,
251
  dtype=torch.float32,
252
+ pad_value=-1.0,
253
+ debug_run =self.debug_run,
254
+ score_col = self.score_col
255
  )
256
  self.drop_last = False # or True, your choice
257
  self.shuffle_batch_order = shuffle_train_batch_order # False keep batches deterministic per epoch; set True if you want to shuffle batch order
 
263
  """
264
  Load and unpack an input csv whose columns are binder_path,glm_path,label
265
  """
266
+ try:
267
+ df = pd.read_csv(file_path)
268
+ if lim is not None:
269
+ df = df[:lim].reset_index(drop=True)
270
+ return df[["ID", "dna_sequence", "tr_sequence", "scores"]]
271
+ except:
272
+ raise Exception(f"{file_path} is not a valid file")
273
 
274
  def setup(self, stage: str | None = None):
275
  lim = 5 if self.debug_run else None
 
278
  if stage in (None, "fit"):
279
  if not hasattr(self, "train_dataset"):
280
  train_df = self.load_file(self.train_data_file, lim=lim)
281
+ self.train_dataset = PairDataset(train_df, score_col = self.score_col, target_col = self.target_col, binder_col = self.binder_col)
282
  self.train_batches = make_length_batches(
283
  dataset_records=self.train_dataset.dataset,
284
  tr_shelf_path=str(self.hparams.tr_shelf_path),
 
294
 
295
  if not hasattr(self, "val_dataset"):
296
  val_df = self.load_file(self.val_data_file, lim=lim)
297
+ self.val_dataset = PairDataset(val_df, score_col = self.score_col, target_col = self.target_col, binder_col = self.binder_col)
298
  self.val_batches = make_length_batches(
299
  dataset_records=self.val_dataset.dataset,
300
  tr_shelf_path=str(self.hparams.tr_shelf_path),
 
309
  if stage in (None, "validate"):
310
  if not hasattr(self, "val_dataset"):
311
  val_df = self.load_file(self.val_data_file, lim=lim)
312
+ self.val_dataset = PairDataset(val_df, score_col = self.score_col, target_col = self.target_col, binder_col = self.binder_col)
313
  self.val_batches = make_length_batches(
314
  dataset_records=self.val_dataset.dataset,
315
  tr_shelf_path=str(self.hparams.tr_shelf_path),
 
324
  if stage in (None, "test"):
325
  if not hasattr(self, "test_dataset"):
326
  test_df = self.load_file(self.test_data_file, lim=lim)
327
+ self.test_dataset = PairDataset(test_df, score_col = self.score_col, target_col = self.target_col, binder_col = self.binder_col)
328
  self.test_batches = make_length_batches(
329
  dataset_records=self.test_dataset.dataset,
330
  tr_shelf_path=str(self.hparams.tr_shelf_path),
 
364
  persistent_workers=(self.num_workers > 0),
365
  pin_memory=self.hparams.pin_memory,
366
  )
367
+
368
+ def predict_dataloader(self):
369
+ # Same as test
370
+ return DataLoader(
371
+ self.test_dataset,
372
+ batch_sampler=self.test_batch_sampler,
373
+ collate_fn=self.collate,
374
+ num_workers=self.num_workers,
375
+ persistent_workers=(self.num_workers > 0),
376
+ pin_memory=self.hparams.pin_memory,
377
+ )
378
 
379
 
380
  class ShelfCollator:
 
402
  dna_key: str = "dna_sequence",
403
  dtype: torch.dtype = torch.float32,
404
  pad_value: float = -1.0,
405
+ debug_run: bool = False,
406
+ score_col = "scores"
407
  ):
408
  self.tr_path = tr_shelf_path
409
  self.dna_path = dna_shelf_path
410
+ self.score_col = score_col
411
  self.tr_key = tr_key
412
  self.dna_key = dna_key
413
  self.dtype = dtype
414
  self.pad_value = pad_value
415
+ self.debug_run = debug_run
416
 
417
  # opened lazily per worker:
418
  self._tr_db = None
 
433
  ids = [b.get("ID", None) for b in batch]
434
  tr_seqs = [b[self.tr_key] for b in batch]
435
  dna_seqs = [b[self.dna_key] for b in batch]
436
+ scores_list = [b[self.score_col] for b in batch]
437
 
438
  # 1) Fetch embeddings lazily from shelves
439
  binder_list = []
 
471
  glm_emb = pad_sequence(
472
  glm_list, batch_first=True, padding_value=self.pad_value
473
  ) # [B, Lg_max, Dg]
474
+
475
  binder_lens = torch.as_tensor(binder_lens, dtype=torch.int64)
476
  glm_lens = torch.as_tensor(glm_lens, dtype=torch.int64)
477
+
478
  binder_mask = torch.arange(binder_emb.size(1)).unsqueeze(
479
  0
480
  ) < binder_lens.unsqueeze(
 
493
  labels = pad_sequence(
494
  labels_list, batch_first=True, padding_value=self.pad_value
495
  ) # [B, Lg_max]
496
+
497
+ if self.debug_run:
498
+ max_binder_len = max(binder_lens)
499
+ max_glm_len = max(glm_lens)
500
+ binder_expected_false = sum(max_binder_len-binder_lens).item()
501
+ binder_expected_true = sum(binder_lens)
502
+ binder_expected_total = binder_expected_true + binder_expected_false
503
+ glm_expected_false = sum(max_glm_len-glm_lens).item()
504
+ glm_expected_true = sum(glm_lens).item()
505
+ glm_expected_total = glm_expected_true + glm_expected_false
506
+ labels_neg1 = sum(sum(labels==-1)).item()
507
+ expected_labels_neg1 = glm_expected_false
508
+
509
+ logger.info(f" Max binder length: {max_binder_len}, original lengths: {binder_lens}, ultimate dimensions: {binder_emb.shape}")
510
+ logger.info(f" Binder expect: true/total = {binder_expected_true}/{binder_expected_total}")
511
+ logger.info(f" Max DNA length: {max_glm_len}, original lengths: {glm_lens}, ultimate dimensions: {glm_emb.shape}")
512
+ logger.info(f" DNA expect: true/total = {glm_expected_true}/{glm_expected_total}")
513
+ logger.info(f" Labels expect -1: -1/total = {expected_labels_neg1}/{glm_expected_total}. True: {labels_neg1}/{labels.numel()}")
514
 
515
  return {
516
  "binder_emb": binder_emb, # [B, Lb_max, Db]
 
525
  "dna_sequence": dna_seqs,
526
  }
527
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
528
  # ------------------------ Helpers for main method debugging only ------------------------------------------#
529
  def _peek_batches(dl, n_batches: int = 2, tag: str = "train"):
530
  logger.info(f"\n=== Peek {n_batches} batch(es) from {tag} loader ===")
 
544
  logger.info(f" glm_mask true count: {gm.sum().item()} / {gm.numel()}")
545
  logger.info(f" glm_mask: {tuple(gm.shape)} dtype={gm.dtype}")
546
  logger.info(
547
+ f" labels: {tuple(y.shape)} min={y.min().item():.4f} max={y.max().item():.4f}, total -1 = {sum(sum(y==-1)).item()}"
548
  )
549
  logger.info(f" IDs (first 5): {ids[:5]}")
550
+ # should make sure that the number of labels that are -1 equals the number of padding tokens
551
  if i + 1 >= n_batches:
552
  break
553
 
 
554
  def _warn_on_paths(args):
555
  import os
556
 
 
602
  parser.add_argument("--batch_size", type=int, default=4)
603
  parser.add_argument("--num_workers", type=int, default=4)
604
  parser.add_argument(
605
+ "--debug_run", default=True, action="store_true", help="limit dataset to a few rows"
606
  )
607
  parser.add_argument(
608
  "--n_batches", type=int, default=2, help="how many batches to print per split"
dpacman/scripts/eval.py CHANGED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Script for using the model just for inference.
3
+ """
4
+ from typing import Any, Dict, List, Optional, Tuple
5
+
6
+ import hydra
7
+ from hydra.core.hydra_config import HydraConfig
8
+ import torch
9
+ import rootutils
10
+ import lightning as L
11
+ from lightning import Callback, LightningDataModule, LightningModule, Trainer
12
+ from lightning.pytorch.loggers import Logger
13
+ from omegaconf import DictConfig
14
+ from pathlib import Path
15
+ import pandas as pd
16
+ from dpacman.classifier.loss import calculate_loss, auprc_zeros_vs_ones_from_logits, auroc_zeros_vs_ones_from_logits
17
+ import pickle
18
+
19
+ root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
20
+
21
+ from dpacman.utils import (
22
+ RankedLogger,
23
+ extras,
24
+ get_metric_value,
25
+ instantiate_callbacks,
26
+ instantiate_loggers,
27
+ log_hyperparameters,
28
+ task_wrapper,
29
+ )
30
+
31
+ log = RankedLogger(__name__, rank_zero_only=True)
32
+
33
+
34
def h100_settings():
    """Turn on TensorFloat-32 for float32 matmuls and cuDNN operations.

    TF32 gives a large speedup for a small accuracy tradeoff on GPUs that
    support it.
    """
    # Global float32 matmul precision ("medium" would be faster still).
    torch.set_float32_matmul_precision("high")

    # Per-backend toggles used by older PyTorch releases (optional on newer ones).
    for backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
        backend.allow_tf32 = True
41
+
42
def flatten_preds(pred_batches):
    """Flatten batched predictions into one record per example, dropping padding.

    Each element of ``pred_batches`` is a dict with the keys:
        "ids":    list of example IDs (one per row of the batch)
        "logits": CPU tensor, padded per batch (described upstream as (B, Lmax))
        "valid":  boolean tensor of the same leading shape marking real positions
        "labels": tensor aligned with ``logits``

    Positions are selected with the ``valid`` mask itself rather than by
    slicing the first ``valid.sum()`` entries. For the usual right-padded
    batches the result is identical, but the mask stays correct even when the
    valid positions are not a contiguous prefix.

    :param pred_batches: iterable of batch dicts (``None`` is treated as empty).
    :return: list of dicts with keys "ID", "logits" and "labels"; the latter two
        are numpy arrays holding only the valid positions of that example.
    """
    records = []
    for batch in pred_batches or []:
        ids, logits, valid, labels = (
            batch["ids"],
            batch["logits"],
            batch["valid"],
            batch["labels"],
        )
        for row, id_ in enumerate(ids):
            keep = valid[row].bool()  # strip padding / invalid positions
            records.append(
                {
                    "ID": id_,
                    "logits": logits[row][keep].numpy(),
                    "labels": labels[row][keep].numpy(),
                }
            )
    return records
58
+
59
@task_wrapper
def predict(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Evaluate a checkpointed model and save per-example predictions.

    Instantiates the datamodule, model, callbacks, loggers and trainer from
    ``cfg``, optionally runs ``trainer.test`` (when ``cfg.test`` is truthy),
    runs ``trainer.predict`` with the checkpoint at ``cfg.ckpt_path``, and
    writes ``predictions.pkl`` and ``summary.csv`` to the Hydra run directory.
    Per-example loss / AUPRC / AUROC are computed only when the test dataset
    reports real (non-fake) scores.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
        The metrics dict is empty when no checkpoint path was provided.
    """
    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    log.info(f"Instantiating datamodule <{cfg.data_module._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data_module)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer, callbacks=callbacks, logger=logger
    )

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    # Resolve the checkpoint up front: both the test and the predict stage need
    # it, so it must be bound regardless of whether `cfg.test` is set.
    ckpt_path = cfg.get("ckpt_path")
    if not ckpt_path:
        log.warning("No ckpt path was passed! Cannot continue")
        # Keep the (metrics, objects) contract so callers can still unpack.
        return {}, object_dict

    if cfg.get("test"):
        log.info("Starting testing!")
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

    pred_batches = trainer.predict(
        model, datamodule=datamodule, ckpt_path=ckpt_path, return_predictions=True
    )
    out = flatten_preds(pred_batches)

    # Persist the raw predictions immediately so they survive a failure in the
    # metric computation below.
    output_dir = Path(HydraConfig.get().run.dir)
    predictions_path = output_dir / "predictions.pkl"
    with open(predictions_path, "wb") as f:
        pickle.dump(out, f)

    # Recalculate loss, AUPRC and AUROC per example - only if the user actually
    # passed scores; otherwise there are no labels worth scoring against.
    if not datamodule.test_dataset.fake_scores:
        for record in out:
            logits = torch.tensor(record["logits"])
            labels = torch.tensor(record["labels"])
            # All-False boolean mask passed as the third positional argument.
            # NOTE(review): presumably marks positions to ignore (padding is
            # already stripped) - confirm against dpacman.classifier.loss.
            no_mask = torch.zeros(record["labels"].shape, dtype=torch.bool)

            loss = calculate_loss(
                logits, labels, None, None, alpha=cfg.model.alpha, gamma=cfg.model.gamma
            )
            # ---- AUPRC and AUROC on labels in {0, >0.99} only ----
            ap, n_pos, n_neg, precision, recall, ap_thresholds = auprc_zeros_vs_ones_from_logits(
                logits, labels, no_mask, pos_thresh=0.99
            )
            # n_pos / n_neg are taken from this second call, as before.
            auc, n_pos, n_neg, tpr, fpr, auc_thresholds, _tp, _fp = auroc_zeros_vs_ones_from_logits(
                logits, labels, no_mask, pos_thresh=0.99
            )
            record["loss"] = loss.item() if loss.numel() > 0 else None
            record["auprc"] = ap.item() if ap.numel() > 0 else None
            record["auroc"] = auc.item() if auc.numel() > 0 else None
            record["n_pos"] = n_pos
            record["n_neg"] = n_neg
            record["precision"] = precision.numpy() if precision.numel() > 0 else None
            record["recall"] = recall.numpy() if recall.numel() > 0 else None
            record["auprc_thresholds"] = ap_thresholds.numpy() if ap_thresholds.numel() > 0 else None
            record["auc_thresholds"] = auc_thresholds.numpy() if auc_thresholds.numel() > 0 else None
            record["tpr"] = tpr
            record["fpr"] = fpr

        # Re-write the pickle so the per-example curves computed above are kept;
        # previously they were attached after the dump and therefore lost.
        with open(predictions_path, "wb") as f:
            pickle.dump(out, f)

    # Summary CSV (no big arrays inside)
    summary_rows = [
        {
            "ID": record["ID"],
            "loss": record.get("loss"),
            "auprc": record.get("auprc"),
            "auroc": record.get("auroc"),
            "n_pos": record.get("n_pos"),
            "n_neg": record.get("n_neg"),
        }
        for record in out
    ]
    summary_path = output_dir / "summary.csv"
    pd.DataFrame(summary_rows).to_csv(summary_path, index=False)
    log.info(f"Saved eval/predict results to {summary_path}")

    # Metrics logged by the trainer during the test stage (if it ran).
    metric_dict = dict(trainer.callback_metrics)

    return metric_dict, object_dict
168
+
169
+
170
@hydra.main(
    version_base="1.3", config_path=str(root / "configs"), config_name="eval.yaml"
)
def main(cfg: DictConfig) -> Optional[float]:
    """Main entry point for evaluation.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Whatever ``get_metric_value`` resolves for ``cfg.optimized_metric``
        (presumably ``None`` when no metric name is configured - confirm in
        ``dpacman.utils``); used by Hydra-based hyperparameter optimization.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    extras(cfg)

    h100_settings()  # enable TF32 settings for faster inference on H100s

    # run evaluation / prediction with the checkpointed model
    metric_dict, _ = predict(cfg)

    # safely retrieve metric value for hydra-based hyperparameter optimization
    metric_value = get_metric_value(
        metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
    )

    # return optimized metric
    return metric_value
194
+
195
+
196
+ if __name__ == "__main__":
197
+ main()
dpacman/scripts/run_eval.sh ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Launch a background evaluation run of the classifier and record its PID.
#
# Optional environment overrides (defaults match the previously hard-coded values):
#   DPACMAN_EVAL_GPUS  value for CUDA_VISIBLE_DEVICES (default: 3)
#   DPACMAN_EVAL_CKPT  checkpoint to evaluate

# Manually specify values used in the config
main_task="eval"
model_type="classifier"
timestamp=$(date "+%Y-%m-%d_%H-%M-%S")

gpu_ids="${DPACMAN_EVAL_GPUS:-3}"
ckpt_path="${DPACMAN_EVAL_CKPT:-/home/a03-svincoff/DPACMAN/logs/train/classifier/runs/2025-08-27_18-52-25/checkpoints/epoch_009.ckpt}"

# Fail fast here: once the job is detached with nohup, a missing checkpoint
# would only show up inside run.log.
if [ ! -f "$ckpt_path" ]; then
  echo "Checkpoint not found: $ckpt_path" >&2
  exit 1
fi

run_dir="$HOME/DPACMAN/logs/${main_task}/${model_type}/runs/${timestamp}"
mkdir -p "$run_dir"

if [ -z "$WANDB_API_KEY" ]; then
  read -s -p "Enter your WANDB API key: " wandb_key
  echo
  export WANDB_API_KEY="$wandb_key"
fi

CUDA_VISIBLE_DEVICES="${gpu_ids}" nohup python -u -m scripts.eval \
  hydra.run.dir="${run_dir}" \
  data_module.test_file="data_files/processed/splits/by_dna/test.csv" \
  data_module.tr_shelf_path="data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf" \
  data_module.dna_shelf_path="data_files/processed/embeddings/fimo_hits_only/peaks_caduceus.shelf" \
  data_module.batch_size=16 \
  model.glm_input_dim=256 \
  model.compressed_dim=256 \
  model.hidden_dim=256 \
  ckpt_path="${ckpt_path}" \
  model.lr=1e-5 \
  > "${run_dir}/run.log" 2>&1 &

# Record the PID of the background job so it can be monitored or killed later.
echo $! > "${run_dir}/pid.txt"
dpacman/scripts/run_train.sh CHANGED
@@ -22,16 +22,18 @@ CUDA_VISIBLE_DEVICES=0,1 nohup python -u -m scripts.train \
22
  +trainer.gradient_clip_algorithm="norm" \
23
  hydra.run.dir="${run_dir}" \
24
  trainer.devices=2 \
25
- trainer.max_epochs=20 \
26
  data_module.train_file="data_files/processed/splits/by_dna/train.csv" \
27
  data_module.val_file="data_files/processed/splits/by_dna/val.csv" \
28
  data_module.test_file="data_files/processed/splits/by_dna/test.csv" \
29
  data_module.tr_shelf_path="data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf" \
30
  data_module.dna_shelf_path="data_files/processed/embeddings/fimo_hits_only/peaks_caduceus.shelf" \
31
  data_module.batch_size=16 \
 
 
32
  model.glm_input_dim=256 \
33
  model.compressed_dim=256 \
34
- model.hidden_dim=128 \
35
  model.lr=1e-5 \
36
  > "${run_dir}/run.log" 2>&1 &
37
 
 
22
  +trainer.gradient_clip_algorithm="norm" \
23
  hydra.run.dir="${run_dir}" \
24
  trainer.devices=2 \
25
+ trainer.max_epochs=10 \
26
  data_module.train_file="data_files/processed/splits/by_dna/train.csv" \
27
  data_module.val_file="data_files/processed/splits/by_dna/val.csv" \
28
  data_module.test_file="data_files/processed/splits/by_dna/test.csv" \
29
  data_module.tr_shelf_path="data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf" \
30
  data_module.dna_shelf_path="data_files/processed/embeddings/fimo_hits_only/peaks_caduceus.shelf" \
31
  data_module.batch_size=16 \
32
+ data_module.score_col="binary_scores" \
33
+ model.loss_type="binary" \
34
  model.glm_input_dim=256 \
35
  model.compressed_dim=256 \
36
+ model.hidden_dim=256 \
37
  model.lr=1e-5 \
38
  > "${run_dir}/run.log" 2>&1 &
39
 
dpacman/scripts/run_train_baseline.sh CHANGED
@@ -14,7 +14,7 @@ if [ -z "$WANDB_API_KEY" ]; then
14
  export WANDB_API_KEY="$wandb_key"
15
  fi
16
 
17
- CUDA_VISIBLE_DEVICES=0,1 nohup python -u -m scripts.train \
18
  +trainer.strategy=ddp \
19
  +trainer.use_distributed_sampler="false" \
20
  +trainer.detect_anomaly="false" \
@@ -29,6 +29,8 @@ CUDA_VISIBLE_DEVICES=0,1 nohup python -u -m scripts.train \
29
  data_module.tr_shelf_path="data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf" \
30
  data_module.dna_shelf_path="data_files/processed/embeddings/fimo_hits_only/peaks_caduceus.shelf" \
31
  data_module.batch_size=16 \
 
 
32
  model=baseline \
33
  model.glm_input_dim=256 \
34
  model.compressed_dim=256 \
 
14
  export WANDB_API_KEY="$wandb_key"
15
  fi
16
 
17
+ CUDA_VISIBLE_DEVICES=2,3 nohup python -u -m scripts.train \
18
  +trainer.strategy=ddp \
19
  +trainer.use_distributed_sampler="false" \
20
  +trainer.detect_anomaly="false" \
 
29
  data_module.tr_shelf_path="data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf" \
30
  data_module.dna_shelf_path="data_files/processed/embeddings/fimo_hits_only/peaks_caduceus.shelf" \
31
  data_module.batch_size=16 \
32
+ data_module.score_col="binary_scores" \
33
+ model.loss_type="binary" \
34
  model=baseline \
35
  model.glm_input_dim=256 \
36
  model.compressed_dim=256 \