svincoff committed
Commit 1d43edf · 1 Parent(s): 41553a7
configs/logger/wandb.yaml CHANGED
@@ -8,6 +8,7 @@ wandb:
   id: null # pass correct id to resume experiment!
   anonymous: null # enable anonymous logging
   project: "dnabind"
+  entity: "sophia-vincoff-team"
   log_model: False # upload lightning ckpts
   prefix: "" # a string to put at the beginning of metric keys
   # entity: "" # set to name of your wandb team
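
These keys line up one-to-one with the arguments of Lightning's WandbLogger, which forwards extra keys such as entity on to wandb.init(). A minimal, illustrative sketch of how the resulting config is consumed, assuming a plain WandbLogger instantiation rather than the project's Hydra wiring:

from lightning.pytorch.loggers import WandbLogger

# Illustrative only: the repo builds this through Hydra, but the fields map directly.
logger = WandbLogger(
    project="dnabind",
    entity="sophia-vincoff-team",  # key added by this commit; extra kwargs go to wandb.init()
    id=None,                       # pass an existing run id to resume that run
    anonymous=None,
    log_model=False,               # do not upload Lightning checkpoints to W&B
    prefix="",                     # string prepended to metric keys
)
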
dpacman/classifier/loss.py CHANGED
@@ -62,6 +62,60 @@ def calculate_loss(
 
     return alpha * bce_nonpeak + gamma * mse_peak
 
+import torch
+
+@torch.no_grad()
+def auprc_zeros_vs_ones_from_logits(
+    logits: torch.Tensor,          # (B, L)
+    labels: torch.Tensor,          # (B, L)
+    glm_kpm: torch.Tensor | None,  # (B, L) True=PAD; pass None if not available
+    pos_thresh: float = 0.99,
+) -> tuple[torch.Tensor, int, int]:
+    """
+    Returns (ap, n_pos, n_neg). AP is Average Precision (area under PR).
+    Uses only positions with labels == 0.0 or > pos_thresh. Ignores PADs via glm_kpm.
+    Computation stays on the same device as logits.
+    """
+    probs = torch.sigmoid(logits)
+
+    # Valid positions: not padded
+    if glm_kpm is not None:
+        valid = ~glm_kpm
+    else:
+        valid = torch.ones_like(labels, dtype=torch.bool, device=labels.device)
+
+    # Keep only exact zeros and near-ones
+    pos = labels > pos_thresh
+    neg = labels == 0.0
+    keep = valid & (pos | neg)
+
+    if keep.sum() == 0:
+        return torch.tensor(float('nan'), device=logits.device), 0, 0
+
+    y = pos[keep].to(probs.dtype)  # 1 for >0.99, 0 for 0.0
+    s = probs[keep].to(probs.dtype)
+
+    n = y.numel()
+    n_pos = int(y.sum().item())
+    n_neg = n - n_pos
+    if n_pos == 0:  # no positives → AP = 0 by convention
+        return torch.tensor(0.0, device=logits.device), 0, n_neg
+
+    # Sort by score descending
+    order = torch.argsort(s, descending=True)
+    y_sorted = y[order]
+
+    # CumTP and precision/recall
+    tp = torch.cumsum(y_sorted, dim=0)
+    ranks = torch.arange(1, n + 1, device=logits.device, dtype=probs.dtype)
+    precision = tp / ranks
+    recall = tp / n_pos
+
+    # AP = sum( precision * Δrecall )
+    recall_prev = torch.cat([torch.zeros(1, device=logits.device, dtype=probs.dtype), recall[:-1]])
+    ap = (precision * (recall - recall_prev)).sum()
+    return ap, n_pos, n_neg
+
 def accuracy_percentage(
     logits,
     targets,
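
One way to build confidence in the hand-rolled AP above is to compare it with scikit-learn on a toy batch. A sketch, assuming scikit-learn is installed and that the import path matches the file above:

import torch
from sklearn.metrics import average_precision_score
from dpacman.classifier.loss import auprc_zeros_vs_ones_from_logits

torch.manual_seed(0)
B, L = 4, 32
logits = torch.randn(B, L)
labels = torch.zeros(B, L)
labels[:, :5] = 1.0                              # a few "peak" positions per sequence
glm_kpm = torch.zeros(B, L, dtype=torch.bool)
glm_kpm[:, -4:] = True                           # pretend the last 4 positions are PAD

ap, n_pos, n_neg = auprc_zeros_vs_ones_from_logits(logits, labels, glm_kpm)

# Rebuild the same 0-vs-1 subset by hand and score it with scikit-learn.
keep = (~glm_kpm) & ((labels == 0.0) | (labels > 0.99))
y_true = (labels[keep] > 0.99).int().numpy()
y_score = torch.sigmoid(logits)[keep].numpy()
print(float(ap), average_precision_score(y_true, y_score))  # should agree to ~1e-6

Since the function uses the standard step-wise AP formula, the two values should match up to float32 rounding and score ties.
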
dpacman/classifier/model.py CHANGED
@@ -6,7 +6,7 @@ import torch
 from torch import nn
 from lightning import LightningModule
 from dpacman.utils.models import set_seed
-from .loss import calculate_loss
+from .loss import calculate_loss, auprc_zeros_vs_ones_from_logits
 
 set_seed()
 
@@ -174,7 +174,7 @@ class BindPredictor(LightningModule):
             [CrossModalBlock(hidden_dim, heads) for _ in range(num_layers)]
         )
 
-        self.ln_out = nn.LayerNorm(hidden_dim)
+        #self.ln_out = nn.LayerNorm(hidden_dim)
         # self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid()) # OLD: returned probabilities
         self.head = nn.Linear(hidden_dim, 1) # NEW: return logits (safe for AMP)
 
@@ -231,6 +231,20 @@ class BindPredictor(LightningModule):
             prog_bar=True,
             batch_size=logits.size(0),
         )
+
+        # ---- AUPRC on labels in {0, >0.99} only ----
+        if False:
+            ap, n_pos, n_neg = auprc_zeros_vs_ones_from_logits(
+                logits.detach(), batch["labels"], batch.get("glm_kpm"), pos_thresh=0.99
+            )
+            # per-batch AP (epoch-mean is a decent summary); sync across GPUs if using DDP
+            self.log("train/auprc_0v1",
+                     ap if torch.isfinite(ap) else torch.tensor(0.0, device=ap.device),
+                     on_step=False, on_epoch=True, prog_bar=True, sync_dist=True, batch_size=logits.size(0))
+            # (optional) also log class counts so you can sanity-check balance
+            self.log("train/n_pos_0v1", float(n_pos), on_step=False, on_epoch=True, sync_dist=True)
+            self.log("train/n_neg_0v1", float(n_neg), on_step=False, on_epoch=True, sync_dist=True)
+
         return loss
 
     def validation_step(self, batch, batch_idx):
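
The logging block above is committed disabled (`if False:`), and its comment notes that the epoch mean of per-batch AP is only a summary. The small self-contained sketch below, reusing only the function added in loss.py, shows why: averaging AP over batches generally differs from AP computed over the pooled epoch, so the two should not be expected to match exactly.

import torch
from dpacman.classifier.loss import auprc_zeros_vs_ones_from_logits

torch.manual_seed(0)
# Eight toy "batches" of (logits, labels) with labels in {0.0, 1.0}.
batches = [(torch.randn(2, 16), (torch.rand(2, 16) > 0.7).float()) for _ in range(8)]

per_batch = [auprc_zeros_vs_ones_from_logits(lg, lb, None)[0] for lg, lb in batches]
mean_of_batch_aps = torch.stack(per_batch).mean()

pooled_logits = torch.cat([lg.reshape(1, -1) for lg, _ in batches], dim=1)
pooled_labels = torch.cat([lb.reshape(1, -1) for _, lb in batches], dim=1)
pooled_ap, _, _ = auprc_zeros_vs_ones_from_logits(pooled_logits, pooled_labels, None)

print(float(mean_of_batch_aps), float(pooled_ap))  # close, but generally not equal
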
dpacman/scripts/run_train.sh CHANGED
@@ -17,11 +17,12 @@ fi
 CUDA_VISIBLE_DEVICES=0,1 nohup python -u -m scripts.train \
     +trainer.strategy=ddp \
     +trainer.use_distributed_sampler="false"\
+    +trainer.detect_anomaly="false"\
     hydra.run.dir="${run_dir}" \
     trainer.devices=2 \
-    data_module.train_file="data_files/processed/splits/by_dna/train.csv" \
-    data_module.val_file="data_files/processed/splits/by_dna/val.csv" \
-    data_module.test_file="data_files/processed/splits/by_dna/test.csv" \
+    data_module.train_file="data_files/processed/splits/by_dna/babytrain.csv" \
+    data_module.val_file="data_files/processed/splits/by_dna/babyval.csv" \
+    data_module.test_file="data_files/processed/splits/by_dna/babytest.csv" \
     data_module.tr_shelf_path="data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf" \
     data_module.dna_shelf_path="data_files/processed/embeddings/fimo_hits_only/peaks_caduceus.shelf" \
     model.glm_input_dim=256 \
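
The new `+trainer.detect_anomaly="false"` line is a Hydra override onto Lightning's Trainer(detect_anomaly=...), which in turn toggles torch.autograd anomaly detection. A minimal sketch of the equivalent knobs in plain Python (illustrative only; the repo constructs its Trainer through Hydra):

import torch
from lightning import Trainer

# detect_anomaly=True makes backward() report the forward op that produced a
# NaN/Inf gradient, at a noticeable speed cost, so it is normally left off.
trainer = Trainer(detect_anomaly=False)

# Underlying PyTorch switch that the flag controls:
torch.autograd.set_detect_anomaly(False)
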