svincoff committed on
Commit c237769 · 1 Parent(s): 29899b4

recent changes

configs/data_task/clean/remap.yaml CHANGED
@@ -1,5 +1,5 @@
 name: remap
-type: clean
+task_type: clean
 
 nr_raw_path: dpacman/data_files/raw/remap/remap2022_nr_macs2_hg38_v1_0.bed
 nr_processed_dir: dpacman/data_files/processed/remap
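
One plausible motivation for the `type` to `task_type` rename: once the config is unpacked into Python keyword arguments, `type` would shadow the builtin. A minimal sketch of that consumption pattern (the `build_task` function is hypothetical, not part of this repo):

# Hypothetical consumer of the remap config; only `name` and `task_type`
# come from the real file above.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"name": "remap", "task_type": "clean"})

def build_task(name: str, task_type: str, **kwargs):
    # With `task_type` instead of `type`, the builtin type() stays usable here.
    print(f"building a {task_type!r} task named {name!r}")

build_task(**cfg)
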
dpacman/classifier/loss.py CHANGED
@@ -1,58 +1,170 @@
 """
-Define loss functions needed for training the model
+Define loss functions needed for training the model — padding safe (-1 sentinel)
 """
 
 import torch
-from torch.nn import functional as F
+import torch.nn.functional as F
 
 
-def bce_loss_masked(logits, targets, nonpeak_mask, pos_weight=None):
+def _expand_like(mask: torch.Tensor, like: torch.Tensor):
+    # Make mask broadcastable to logits/targets (handles (B,L) vs (B,L,1))
+    while mask.dim() < like.dim():
+        mask = mask.unsqueeze(-1)
+    return mask.expand_as(like)
+
+
+def bce_loss_masked(logits, targets, nonpeak_mask, pos_weight=None, eps=1e-8):
     """
-    Compute the masked Binary Cross Entropy, only on certain positions.
-    We will only compute BCE on positions whre nonpeak_mask == 1.0; the mask represents non-peak positions
+    Compute masked BCE with logits over non-peak positions only.
+    Expects nonpeak_mask already broadcastable to logits.
     """
+    # Clamp targets into [0,1] to be safe, even if pads slip through earlier
+    t = targets.clamp(0.0, 1.0)
     loss = F.binary_cross_entropy_with_logits(
-        logits, targets, reduction="none", pos_weight=pos_weight
+        logits, t, reduction="none", pos_weight=pos_weight
     )
-    denom = nonpeak_mask.sum().clamp_min(1.0)
-    return (loss * nonpeak_mask).sum() / denom
+    m = _expand_like(nonpeak_mask, loss).to(loss.dtype)
+    denom = m.sum().clamp_min(eps)
+    return (loss * m).sum() / denom
 
 
 def mse_peaks_only(logits, targets, peak_mask, eps=1e-8):
     """
-    Calculate MSE on peaks only.
+    Calculate MSE on peaks only (on probabilities), masking everything else.
     """
     probs = torch.sigmoid(logits)
-    mse_peaks = F.mse_loss(probs * peak_mask, targets * peak_mask, reduction="sum") / (
-        peak_mask.sum() + eps
-    )
-    return mse_peaks
+    per_elem = F.mse_loss(probs, targets, reduction="none")
+    m = _expand_like(peak_mask, per_elem).to(per_elem.dtype)
+    denom = m.sum().clamp_min(eps)
+    return (per_elem * m).sum() / denom
 
 
-def calculate_loss(logits, targets, eps=1e-8, alpha=1.0, gamma=1.0):
+def calculate_loss(
+    logits,
+    targets,
+    eps: float = 1e-8,
+    alpha: float = 1.0,
+    gamma: float = 1.0,
+    pos_weight=None,
+    pad_value: float = -1.0,
+):
     """
-    Combine masked-BCE + global-MSE to get a loss vlaue
+    Combine masked-BCE (non-peak) + masked-MSE on probs (peak), ignoring padding.
+    Assumes targets == -1 are pads; non-peak = 0; peak > 0.
     """
-    # Calculate peak and non-peak masks.
-    # Anything outside a peak will have a label equal to 0.
-    nonpeak_mask = (targets == 0).float()
-    peak_mask = (targets > 0).float()
+    valid = (targets != pad_value)
 
-    bce_nonpeak = bce_loss_masked(logits, targets, nonpeak_mask)
-    mse_peak = mse_peaks_only(logits, targets, peak_mask, eps=eps)
+    # Peak / non-peak masks that exclude pads
+    nonpeak_mask = valid & (targets == 0)
+    peak_mask = valid & (targets > 0)
 
-    loss = alpha * bce_nonpeak + gamma * mse_peak
+    # For safety, zero-out targets at pad positions so they never feed into BCE/MSE
+    targets_safe = torch.where(valid, targets, torch.zeros_like(targets))
 
-    return loss
+    bce_nonpeak = bce_loss_masked(logits, targets_safe, nonpeak_mask, pos_weight=pos_weight, eps=eps)
+    mse_peak = mse_peaks_only(logits, targets_safe, peak_mask, eps=eps)
+
+    return alpha * bce_nonpeak + gamma * mse_peak
 
 
-def accuracy_percentage(logits, targets, peak_thresh=0.5):
+def accuracy_percentage(
+    logits,
+    targets,
+    peak_thresh: float = 0.5,
+    eps: float = 1e-8,
+    pad_value: float = -1.0,
+):
     """
-    Compute accuracy in predicting high-confidence peaks (probability > 0.5)
+    Compute accuracy for predicting high-confidence peaks (prob >= 0.5), ignoring padding.
     """
+    valid = (targets != pad_value)
     probs = torch.sigmoid(logits)
-    preds_bin = (probs >= 0.5).float()
-    labels = (targets >= peak_thresh).float()
-    correct = (preds_bin == labels).float().sum()
-    total = torch.numel(labels)
-    return (correct / max(1, total)).item() * 100.0
+    preds_bin = (probs >= 0.5)
+    labels = (targets >= peak_thresh)
+
+    v = _expand_like(valid, preds_bin)
+    correct = ((preds_bin == labels) & v).to(torch.float32).sum()
+    total = v.to(torch.float32).sum().clamp_min(eps)
+    return (correct / total).item() * 100.0
+
+
+if __name__ == "__main__":
+    import torch
+
+    torch.manual_seed(0)
+    PAD = -1.0
+
+    def make_targets_BL(B=2, L=8, pad_positions=(6, 7)):
+        """Create (B,L) targets: 0=non-peak, >0=peak, -1=pad."""
+        t = torch.zeros(B, L)
+        # sprinkle a few peaks (values in [0.6, 1.0])
+        t[:, 1] = torch.rand(B) * 0.4 + 0.6
+        t[:, 3] = torch.rand(B) * 0.4 + 0.6
+        # pads
+        for p in pad_positions:
+            t[:, p] = PAD
+        return t
+
+    def make_targets_BLC(B=2, L=8, C=3, pad_positions=(6, 7)):
+        """
+        Create (B,L,C) targets by broadcasting a (B,L) base across channels
+        (so masking needs to expand correctly).
+        """
+        base = make_targets_BL(B, L, pad_positions)  # (B,L)
+        t = base.unsqueeze(-1).expand(-1, -1, C).clone()
+        # Make channel 1 slightly different to show per-channel variety
+        t[..., 1] = torch.where(t[..., 1] > 0, (t[..., 1] * 0.85).clamp(0, 1), t[..., 1])
+        return t
+
+    def mask_stats(name, logits, targets, pad_value=PAD):
+        valid = (targets != pad_value)
+        nonpeak_mask = valid & (targets == 0)
+        peak_mask = valid & (targets > 0)
+
+        m_nonpeak = _expand_like(nonpeak_mask, logits)
+        m_peak = _expand_like(peak_mask, logits)
+
+        print(f"\n[{name}]")
+        print(f"  logits.shape  = {tuple(logits.shape)}")
+        print(f"  targets.shape = {tuple(targets.shape)}")
+        # Previews (first batch)
+        if targets.dim() == 2:  # (B,L)
+            print(f"  targets[0,:] preview: {targets[0]}")
+        else:  # (B,L,C)
+            print(f"  targets[0,:,0] ch0 preview: {targets[0,:,0]}")
+            print(f"  targets[0,:,1] ch1 preview: {targets[0,:,1]}")
+        # Mask counts after EXPANSION (these define denominators)
+        print(f"  #non-peak elems used = {m_nonpeak.sum().item():.0f}")
+        print(f"  #peak elems used     = {m_peak.sum().item():.0f}")
+
+    # =========================
+    # Case A: (B, L)
+    # =========================
+    B, L = 2, 8
+    logits_BL = torch.randn(B, L)       # raw scores
+    targets_BL = make_targets_BL(B, L)  # 0, >0, and -1 pads
+
+    mask_stats("BL", logits_BL, targets_BL, pad_value=PAD)
+
+    loss_BL = calculate_loss(
+        logits_BL, targets_BL, pad_value=PAD, alpha=1.0, gamma=1.0
+    )
+    acc_BL = accuracy_percentage(
+        logits_BL, targets_BL, pad_value=PAD, peak_thresh=0.5
+    )
+
+    print(f"  loss_BL = {loss_BL.item():.6f}")
+    print(f"  acc_BL  = {acc_BL:.2f}%")
+
+    # =========================
+    # Case B: (B, L, C)
+    # =========================
+    B, L, C = 2, 8, 3
+    logits_BLC = torch.randn(B, L, C)        # raw scores with channels
+    targets_BLC = make_targets_BLC(B, L, C)  # broadcasted targets + tweaks
+
+    mask_stats("BLC", logits_BLC, targets_BLC, pad_value=PAD)
+
+    loss_BLC = calculate_loss(
+        logits_BLC, targets_BLC, pad_value=PAD, alpha=1.0, gamma=1.0
+    )
+    acc_BLC = accuracy_percentage(
+        logits_BLC, targets_BLC, pad_value=PAD, peak_thresh=0.5
+    )
+
+    print(f"  loss_BLC = {loss_BLC.item():.6f}")
+    print(f"  acc_BLC  = {acc_BLC:.2f}%")
dpacman/data_modules/pair.py CHANGED
@@ -314,7 +314,7 @@ class ShelfCollator:
         tr_key: str = "tr_sequence",
         dna_key: str = "dna_sequence",
         dtype: torch.dtype = torch.float32,
-        pad_value: float = 0.0,
+        pad_value: float = -1.0,
     ):
         self.tr_path = tr_shelf_path
         self.dna_path = dna_shelf_path
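
The switch to `pad_value=-1.0` keeps padding distinguishable from the non-peak label 0, which the masked losses above depend on. A standalone illustration (not the collator's actual code, which lives in the rest of this class):

# Why 0.0 was a bad pad value for these labels: pads would look like non-peaks.
import torch
from torch.nn.utils.rnn import pad_sequence

targets = [torch.tensor([0.0, 0.9, 0.0]), torch.tensor([0.8, 0.0])]
padded = pad_sequence(targets, batch_first=True, padding_value=-1.0)
print(padded)          # tail positions show up as -1.0, not fake non-peaks
print(padded != -1.0)  # the validity mask calculate_loss derives from targets
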
dpacman/data_tasks/embeddings/embedders.py CHANGED
@@ -26,6 +26,8 @@ from sklearn.preprocessing import OneHotEncoder
 import math
 import rootutils
 from dpacman.utils import pylogger
+from tqdm import trange
+from tqdm.contrib.logging import logging_redirect_tqdm
 
 root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 logger = pylogger.RankedLogger(__name__, rank_zero_only=True)
@@ -44,7 +46,7 @@ class CaduceusEmbedder:
             model_name, trust_remote_code=True
         )
         self.model = (
-            AutoModel.from_pretrained(model_name, trust_remote_code=True)
+            AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
             .to(device)
             .eval()
         )
@@ -52,42 +54,39 @@ class CaduceusEmbedder:
         self.chunk_size = chunk_size
         self.step = chunk_size - overlap
 
-    def embed(self, seqs, batch_size=1):
+    def embed(self, seqs, batch_size=1, pooling=False):  # pooling is reserved; unused for now
         """
         seqs: List[str] of DNA sequences (each <= chunk_size for this test)
-        returns: np.ndarray of shape (N, L, D), raw per-token embeddings
+        returns: dict mapping each sequence to its (L, D) array of raw per-token embeddings
         """
-        # outputs = []
-        # for seq in seqs:
-        #     # --- new: raw per-token embeddings in one shot ---
-        #     toks = self.tokenizer(
-        #         seq,
-        #         return_tensors="pt",
-        #         padding=False,
-        #         truncation=True,
-        #         max_length=self.chunk_size
-        #     ).to(self.device)
-        #     with torch.no_grad():
-        #         out = self.model(**toks).last_hidden_state  # (1, L, D)
-        #     outputs.append(out.cpu().numpy()[0])  # (L, D)
-
-        #     return np.stack(outputs, axis=0)  # (N, L, D)
-        outputs = []
-        for seq in tqdm(
-            seqs, total=len(seqs), desc="DNA: Caduceus", dynamic_ncols=True
-        ):
-            toks = self.tokenizer(
-                seq,
-                return_tensors="pt",
-                padding=False,
-                truncation=True,
-                max_length=self.chunk_size,
-            ).to(self.device)
-            with torch.no_grad():
-                out = self.model(**toks).last_hidden_state  # (1, L, D)
-            outputs.append(out.cpu().numpy()[0])  # (L, D)
-        return outputs  # list of variable-length (L_i, D) arrays
+        n = len(seqs)
+        if n == 0:
+            return {}
+
+        # Report the longest raw sequence; each one is still truncated to chunk_size tokens
+        max_len = max(len(s) for s in seqs)
+        logger.info(f"Max sequence length: {max_len:,} (truncated to max_length={self.chunk_size})")
+
+        outputs = {}  # seq -> embedding
+
+        with logging_redirect_tqdm():
+            for i in range(0, n, batch_size):
+                batch_seqs = seqs[i : i + batch_size]
+                logger.info(f"Embedding batch {i // batch_size + 1}/{(n + batch_size - 1) // batch_size}")
+
+                for seq in tqdm(batch_seqs, total=len(batch_seqs), desc="DNA: Caduceus", dynamic_ncols=True):
+                    toks = self.tokenizer(
+                        seq,
+                        return_tensors="pt",
+                        padding=False,
+                        truncation=True,
+                        max_length=self.chunk_size,
+                    ).to(self.device)
+                    with torch.no_grad():
+                        out = self.model(**toks).last_hidden_state  # (1, L+1, D)
+                    # Drop the trailing special token to keep (L, D)
+                    outputs[seq] = out.cpu().numpy().squeeze(0)[0:-1, :]
+        return outputs  # dict of variable-length (L_i, D) arrays keyed by sequence
 
     def benchmark(self, lengths=None):
         """
         Time embedding on single-sequence of various lengths.
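
A hypothetical call sketch for the updated `embed` (the checkpoint name and constructor arguments are assumptions inferred from the attributes used above; the real defaults live elsewhere in this file):

# Assumed constructor signature (tokenizer/model/device/chunk_size/overlap);
# the Caduceus checkpoint name is illustrative, not taken from this commit.
embedder = CaduceusEmbedder(
    model_name="kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16",
    device="cuda",
    chunk_size=131072,
    overlap=0,
)
embs = embedder.embed(["ACGT" * 100, "GGCC" * 50], batch_size=1)
for seq, emb in embs.items():
    print(len(seq), emb.shape)  # each value is a per-token (L_i, D) array
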
dpacman/scripts/run_embeddings.sh CHANGED
@@ -8,10 +8,11 @@ timestamp=$(date "+%Y-%m-%d_%H-%M-%S")
 run_dir="$HOME/DPACMAN/logs/${main_task}/${data_task_type}/runs/${timestamp}"
 mkdir -p "$run_dir"
 
-nohup python -u -m scripts.preprocess \
+nohup python -s -u -m scripts.preprocess \
     hydra.run.dir="${run_dir}" \
-    data_task="${data_task_type}/protein" \
-    data_task.debug="false" \
+    data_task="${data_task_type}/dna" \
+    data_task.chrom_model="caduceus" \
+    data_task.debug="true" \
     > "${run_dir}/run.log" 2>&1 &
 
 echo $! > "${run_dir}/pid.txt"
dpacman/scripts/run_train.sh CHANGED
@@ -16,6 +16,13 @@ fi
 
 nohup python -u -m scripts.train \
     hydra.run.dir="${run_dir}" \
+    data_module.train_file="data_files/processed/splits/by_dna/train.csv" \
+    data_module.val_file="data_files/processed/splits/by_dna/val.csv" \
+    data_module.test_file="data_files/processed/splits/by_dna/test.csv" \
+    data_module.tr_shelf_path="data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf" \
+    data_module.dna_shelf_path="data_files/processed/embeddings/fimo_hits_only/peaks_caduceus.shelf" \
+    model.glm_input_dim=256 \
+    model.compressed_dim=1029 \
     > "${run_dir}/run.log" 2>&1 &
 
 echo $! > "${run_dir}/pid.txt"
h100_env.yaml CHANGED
@@ -43,6 +43,7 @@ dependencies:
   - tqdm==4.67.1
   - matplotlib==3.10.3
   - transformers==4.55.2
+  - huggingface_hub==0.34.4
   - biopython==1.85
   - ortools==9.14.6206
   - fair-esm==2.0.0
h100_env2.yaml ADDED
@@ -0,0 +1,57 @@
+name: dnabind3
+channels:
+  - conda-forge
+  - defaults
+
+dependencies:
+  - python=3.10
+  - pip>=24
+  # compiled / heavy libs via conda-forge
+  - numpy>=2.0,<3.0
+  - scikit-learn>=1.5,<1.7
+  - pandas>=2.2,<2.3
+  - matplotlib>=3.8,<3.11
+  - lxml>=5.2,<6
+  - lightning=2.5.1
+  - torchmetrics>=1.3
+  - dask
+  - distributed
+  - dask-ml
+  # toolchain for JIT/building CUDA extensions (mamba-ssm, Triton kernels)
+  - cuda-toolkit=12.4
+  - cmake
+  - ninja
+
+  - pip:
+      # Force CUDA wheels and keep them from being overwritten by CPU builds
+      - --index-url=https://download.pytorch.org/whl/cu124
+      - torch==2.6.0+cu124
+
+      # HF stack + hard deps used at runtime
+      - transformers==4.53.0
+      - tokenizers>=0.21,<0.22
+      - safetensors>=0.4.3
+      - huggingface-hub==0.34.4
+      - regex
+
+      # Your libs
+      - rootutils==1.0.7
+      - hydra-core==1.3.2
+      - hydra-colorlog==1.2.0
+      - omegaconf==2.3.0
+      - pymex==0.9.31
+      - gitpython==3.1.44
+      - black==25.1.0
+      - tqdm==4.67.1
+      - biopython==1.85
+      - ortools==9.14.6206
+      - fair-esm==2.0.0
+      - rich==14.1.0
+      - wandb==0.21.1
+
+      # Mamba + Triton (for CUDA kernels)
+      - mamba-ssm==2.2.4
+      - triton>=3.0,<3.5
+
+      # your package
+      - -e .
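
A quick post-install sanity check, as a sketch (assuming the environment is created with `conda env create -f h100_env2.yaml`); it confirms the cu124 torch wheel was kept and that the mamba-ssm CUDA extension imports:

# Sanity check for the dnabind3 environment (assumed workflow, not part of the commit).
import torch
print(torch.__version__)          # expect 2.6.0+cu124
print(torch.cuda.is_available())  # expect True on an H100 node

import mamba_ssm                  # raises if the CUDA kernels failed to build
print(mamba_ssm.__version__)      # expect 2.2.4
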