recent changes
Browse files

- configs/data_task/clean/remap.yaml  +1  -1
- dpacman/classifier/loss.py  +144  -32
- dpacman/data_modules/pair.py  +1  -1
- dpacman/data_tasks/embeddings/embedders.py  +31  -32
- dpacman/scripts/run_embeddings.sh  +4  -3
- dpacman/scripts/run_train.sh  +7  -0
- h100_env.yaml  +1  -0
- h100_env2.yaml  +57  -0
configs/data_task/clean/remap.yaml
CHANGED

@@ -1,5 +1,5 @@
 name: remap
-
+task_type: clean
 
 nr_raw_path: dpacman/data_files/raw/remap/remap2022_nr_macs2_hg38_v1_0.bed
 nr_processed_dir: dpacman/data_files/processed/remap
dpacman/classifier/loss.py
CHANGED

@@ -1,58 +1,170 @@
 """
-Define loss functions needed for training the model
+Define loss functions needed for training the model — padding safe (-1 sentinel)
 """
 
 import torch
+import torch.nn.functional as F
 
+def _expand_like(mask: torch.Tensor, like: torch.Tensor):
+    # Make mask broadcastable to logits/targets (handles (B,L) vs (B,L,1))
+    while mask.dim() < like.dim():
+        mask = mask.unsqueeze(-1)
+    return mask.expand_as(like)
 
-def bce_loss_masked(logits, targets, nonpeak_mask, pos_weight=None):
+def bce_loss_masked(logits, targets, nonpeak_mask, pos_weight=None, eps=1e-8):
     """
-    Compute
+    Compute masked BCE with logits over non-peak positions only.
+    Expects nonpeak_mask already broadcastable to logits.
     """
+    # Clamp targets into [0,1] to be safe, even if pads slip through earlier
+    t = targets.clamp(0.0, 1.0)
     loss = F.binary_cross_entropy_with_logits(
-        logits,
+        logits, t, reduction="none", pos_weight=pos_weight
     )
+    m = _expand_like(nonpeak_mask, loss).to(loss.dtype)
+    denom = m.sum().clamp_min(eps)
+    return (loss * m).sum() / denom
 
 def mse_peaks_only(logits, targets, peak_mask, eps=1e-8):
     """
-    Calculate MSE on peaks only.
+    Calculate MSE on peaks only (on probabilities), masking everything else.
     """
     probs = torch.sigmoid(logits)
+    per_elem = F.mse_loss(probs, targets, reduction="none")
+    m = _expand_like(peak_mask, per_elem).to(per_elem.dtype)
+    denom = m.sum().clamp_min(eps)
+    return (per_elem * m).sum() / denom
 
-def calculate_loss(
+def calculate_loss(
+    logits,
+    targets,
+    eps: float = 1e-8,
+    alpha: float = 1.0,
+    gamma: float = 1.0,
+    pos_weight=None,
+    pad_value: float = -1.0,
+):
     """
-    Combine masked-BCE +
+    Combine masked-BCE (non-peak) + masked-MSE on probs (peak), ignoring padding.
+    Assumes targets == -1 are pads; non-peak = 0; peak > 0.
     """
-    # Anything outside a peak will have a label equal to 0.
-    nonpeak_mask = (targets == 0).float()
-    peak_mask = (targets > 0).float()
+    valid = (targets != pad_value)
+
+    # Peak / non-peak masks that exclude pads
+    nonpeak_mask = valid & (targets == 0)
+    peak_mask = valid & (targets > 0)
+
+    # For safety, zero-out targets at pad positions so they never feed into BCE/MSE
+    targets_safe = torch.where(valid, targets, torch.zeros_like(targets))
+
+    bce_nonpeak = bce_loss_masked(logits, targets_safe, nonpeak_mask, pos_weight=pos_weight, eps=eps)
+    mse_peak = mse_peaks_only(logits, targets_safe, peak_mask, eps=eps)
+
+    return alpha * bce_nonpeak + gamma * mse_peak
 
-def accuracy_percentage(
+def accuracy_percentage(
+    logits,
+    targets,
+    peak_thresh: float = 0.5,
+    eps: float = 1e-8,
+    pad_value: float = -1.0,
+):
     """
-    Compute accuracy
+    Compute accuracy for predicting high-confidence peaks (prob >= 0.5), ignoring padding.
     """
+    valid = (targets != pad_value)
     probs = torch.sigmoid(logits)
-    preds_bin = (probs >= 0.5)
-    labels
+    preds_bin = (probs >= 0.5)
+    labels = (targets >= peak_thresh)
+
+    v = _expand_like(valid, preds_bin)
+    correct = ((preds_bin == labels) & v).to(torch.float32).sum()
+    total = v.to(torch.float32).sum().clamp_min(eps)
+    return (correct / total).item() * 100.0
+
+if __name__ == "__main__":
+    import torch
+
+    torch.manual_seed(0)
+    PAD = -1.0
+
+    def make_targets_BL(B=2, L=8, pad_positions=(6, 7)):
+        """Create (B,L) targets: 0=non-peak, >0=peak, -1=pad."""
+        t = torch.zeros(B, L)
+        # sprinkle a few peaks (values in [0.6, 1.0])
+        t[:, 1] = torch.rand(B) * 0.4 + 0.6
+        t[:, 3] = torch.rand(B) * 0.4 + 0.6
+        # pads
+        for p in pad_positions:
+            t[:, p] = PAD
+        return t
+
+    def make_targets_BLC(B=2, L=8, C=3, pad_positions=(6, 7)):
+        """
+        Create (B,L,C) targets by broadcasting a (B,L) base across channels
+        (so masking needs to expand correctly).
+        """
+        base = make_targets_BL(B, L, pad_positions)  # (B,L)
+        t = base.unsqueeze(-1).expand(-1, -1, C).clone()
+        # Make channel 1 slightly different to show per-channel variety
+        t[..., 1] = torch.where(t[..., 1] > 0, (t[..., 1] * 0.85).clamp(0, 1), t[..., 1])
+        return t
+
+    def mask_stats(name, logits, targets, pad_value=PAD):
+        valid = (targets != pad_value)
+        nonpeak_mask = valid & (targets == 0)
+        peak_mask = valid & (targets > 0)
+
+        m_nonpeak = _expand_like(nonpeak_mask, logits)
+        m_peak = _expand_like(peak_mask, logits)
+
+        print(f"\n[{name}]")
+        print(f"  logits.shape  = {tuple(logits.shape)}")
+        print(f"  targets.shape = {tuple(targets.shape)}")
+        # Previews (first batch)
+        if targets.dim() == 2:  # (B,L)
+            print(f"  targets[0,:] preview: {targets[0]}")
+        else:  # (B,L,C)
+            print(f"  targets[0,:,0] ch0 preview: {targets[0,:,0]}")
+            print(f"  targets[0,:,1] ch1 preview: {targets[0,:,1]}")
+        # Mask counts after EXPANSION (these define denominators)
+        print(f"  #non-peak elems used = {m_nonpeak.sum().item():.0f}")
+        print(f"  #peak elems used     = {m_peak.sum().item():.0f}")
+
+    # =========================
+    # Case A: (B, L)
+    # =========================
+    B, L = 2, 8
+    logits_BL = torch.randn(B, L)       # raw scores
+    targets_BL = make_targets_BL(B, L)  # 0, >0, and -1 pads
+
+    mask_stats("BL", logits_BL, targets_BL, pad_value=PAD)
+
+    loss_BL = calculate_loss(
+        logits_BL, targets_BL, pad_value=PAD, alpha=1.0, gamma=1.0
+    )
+    acc_BL = accuracy_percentage(
+        logits_BL, targets_BL, pad_value=PAD, peak_thresh=0.5
+    )
+
+    print(f"  loss_BL = {loss_BL.item():.6f}")
+    print(f"  acc_BL  = {acc_BL:.2f}%")
+
+    # =========================
+    # Case B: (B, L, C)
+    # =========================
+    B, L, C = 2, 8, 3
+    logits_BLC = torch.randn(B, L, C)        # raw scores with channels
+    targets_BLC = make_targets_BLC(B, L, C)  # broadcasted targets + tweaks
+
+    mask_stats("BLC", logits_BLC, targets_BLC, pad_value=PAD)
+
+    loss_BLC = calculate_loss(
+        logits_BLC, targets_BLC, pad_value=PAD, alpha=1.0, gamma=1.0
+    )
+    acc_BLC = accuracy_percentage(
+        logits_BLC, targets_BLC, pad_value=PAD, peak_thresh=0.5
+    )
+
+    print(f"  loss_BLC = {loss_BLC.item():.6f}")
+    print(f"  acc_BLC  = {acc_BLC:.2f}%")
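Because both terms mask pads out of the numerator and the denominator alike, appending pad columns to a batch must leave the combined loss unchanged. A minimal sanity check of that property (not part of the commit; the import path is an assumption):

import torch
from dpacman.classifier.loss import calculate_loss  # import path assumed

torch.manual_seed(0)
B, L, PAD = 2, 8, -1.0
logits = torch.randn(B, L)
targets = torch.zeros(B, L)
targets[:, 1] = 0.9  # one peak column; everything else is non-peak

loss_plain = calculate_loss(logits, targets, pad_value=PAD)

# Append two pad columns; the masks should exclude them entirely.
logits_pad = torch.cat([logits, torch.randn(B, 2)], dim=1)
targets_pad = torch.cat([targets, torch.full((B, 2), PAD)], dim=1)
loss_padded = calculate_loss(logits_pad, targets_pad, pad_value=PAD)

assert torch.allclose(loss_plain, loss_padded)  # pads contribute nothing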
dpacman/data_modules/pair.py
CHANGED

@@ -314,7 +314,7 @@ class ShelfCollator:
         tr_key: str = "tr_sequence",
         dna_key: str = "dna_sequence",
         dtype: torch.dtype = torch.float32,
-        pad_value: float =
+        pad_value: float = -1.0,
     ):
         self.tr_path = tr_shelf_path
         self.dna_path = dna_shelf_path
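The new default matches the -1.0 sentinel that loss.py now treats as padding. The hunk doesn't show where ShelfCollator applies it; a plausible sketch of the pattern using torch's pad_sequence (illustrative only, not the collator's actual code):

import torch
from torch.nn.utils.rnn import pad_sequence

# Variable-length target tracks from two examples (made-up values)
tracks = [torch.tensor([0.0, 0.9, 0.0]), torch.tensor([0.7, 0.0])]
batch = pad_sequence(tracks, batch_first=True, padding_value=-1.0)  # (2, 3)
# batch[1, 2] is now -1.0, so calculate_loss(..., pad_value=-1.0) ignores it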
dpacman/data_tasks/embeddings/embedders.py
CHANGED

@@ -26,6 +26,8 @@ from sklearn.preprocessing import OneHotEncoder
 import math
 import rootutils
 from dpacman.utils import pylogger
+from tqdm import trange
+from tqdm.contrib.logging import logging_redirect_tqdm
 
 root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 logger = pylogger.RankedLogger(__name__, rank_zero_only=True)

@@ -44,7 +46,7 @@ class CaduceusEmbedder:
             model_name, trust_remote_code=True
         )
         self.model = (
+            AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True)
             .to(device)
             .eval()
         )

@@ -52,42 +54,39 @@ class CaduceusEmbedder:
         self.chunk_size = chunk_size
         self.step = chunk_size - overlap
 
-    def embed(self, seqs, batch_size=1):
+    def embed(self, seqs, batch_size=1, pooling=False):
         """
         seqs: List[str] of DNA sequences (each <= chunk_size for this test)
-        returns: np.ndarray of shape (N, L, D), raw per-token embeddings
+        returns: dict mapping each sequence to its (L, D) per-token embedding
         """
-        # toks = self.tokenizer(
-        #     seq,
-        #     return_tensors="pt",
-        #     padding=False,
-        #     truncation=True,
-        #     max_length=self.chunk_size
-        # ).to(self.device)
-        # with torch.no_grad():
-        #     out = self.model(**toks).last_hidden_state  # (1, L, D)
-        #     outputs.append(out.cpu().numpy()[0])  # (L, D)
-
-        # return np.stack(outputs, axis=0)  # (N, L, D)
-        outputs = []
-        for seq in tqdm(
-            seqs, total=len(seqs), desc="DNA: Caduceus", dynamic_ncols=True
-        ):
-            toks = self.tokenizer(
-                seq,
-                return_tensors="pt",
-                padding=False,
-                truncation=True,
-                max_length=self.chunk_size,
-            ).to(self.device)
-            with torch.no_grad():
-                out = self.model(**toks).last_hidden_state  # (1, L, D)
-            outputs.append(out.cpu().numpy()[0])  # (L, D)
-        return outputs  # list of variable-length (L_i, D) arrays
+        n = len(seqs)
+        if n == 0:
+            return {}
+
+        # Log the longest raw input (the tokenizer truncates to chunk_size)
+        max_len = max(len(s) for s in seqs)
+        logger.info(f"Max length (will be padded/truncated to tokenizer setting): {max_len:,}")
+
+        outputs = {}  # seq -> embedding
+
+        with logging_redirect_tqdm():
+            for i in range(0, n, batch_size):
+                batch_seqs = seqs[i : i + batch_size]
+                logger.info(f"Embedding batch {i // batch_size + 1} of {math.ceil(n / batch_size)}")
+
+                for seq in tqdm(batch_seqs, total=len(batch_seqs), desc="DNA: Caduceus", dynamic_ncols=True):
+                    toks = self.tokenizer(  # tokenizer truncates to max_length=chunk_size
+                        seq,
+                        return_tensors="pt",
+                        padding=False,
+                        truncation=True,
+                        max_length=self.chunk_size,
+                    ).to(self.device)
+                    with torch.no_grad():
+                        out = self.model(**toks).last_hidden_state  # (1, L+1, D)
+                    outputs[seq] = out.cpu().numpy().squeeze(0)[0:-1, :]  # (L, D), drop trailing special token
+        return outputs  # dict of variable-length (L_i, D) arrays
 
     def benchmark(self, lengths=None):
         """
         Time embedding on single-sequence of various lengths.
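embed() still truncates each sequence to a single chunk, while the constructor stores step = chunk_size - overlap, which suggests overlapping windows are intended for longer inputs. A sketch of how such window offsets could be computed (window_starts is a hypothetical helper, not in the commit):

def window_starts(seq_len: int, chunk_size: int, overlap: int) -> list[int]:
    """Start offsets of overlapping chunks with stride chunk_size - overlap (hypothetical)."""
    step = chunk_size - overlap
    if seq_len <= chunk_size:
        return [0]
    starts = list(range(0, seq_len - chunk_size + 1, step))
    if starts[-1] + chunk_size < seq_len:  # add a final window so the tail is covered
        starts.append(seq_len - chunk_size)
    return starts

print(window_starts(10, 4, 1))  # [0, 3, 6]
print(window_starts(11, 4, 1))  # [0, 3, 6, 7]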
dpacman/scripts/run_embeddings.sh
CHANGED

@@ -8,10 +8,11 @@ timestamp=$(date "+%Y-%m-%d_%H-%M-%S")
 run_dir="$HOME/DPACMAN/logs/${main_task}/${data_task_type}/runs/${timestamp}"
 mkdir -p "$run_dir"
 
-nohup python -u -m scripts.preprocess \
+nohup python -s -u -m scripts.preprocess \
     hydra.run.dir="${run_dir}" \
-    data_task="${data_task_type}/
-    data_task.
+    data_task="${data_task_type}/dna" \
+    data_task.chrom_model="caduceus" \
+    data_task.debug="true" \
     > "${run_dir}/run.log" 2>&1 &
 
 echo $! > "${run_dir}/pid.txt"
dpacman/scripts/run_train.sh
CHANGED

@@ -16,6 +16,13 @@ fi
 
 nohup python -u -m scripts.train \
     hydra.run.dir="${run_dir}" \
+    data_module.train_file="data_files/processed/splits/by_dna/train.csv" \
+    data_module.val_file="data_files/processed/splits/by_dna/val.csv" \
+    data_module.test_file="data_files/processed/splits/by_dna/test.csv" \
+    data_module.tr_shelf_path="data_files/processed/embeddings/fimo_hits_only/trs_esm.shelf" \
+    data_module.dna_shelf_path="data_files/processed/embeddings/fimo_hits_only/peaks_caduceus.shelf" \
+    model.glm_input_dim=256 \
+    model.compressed_dim=1029 \
     > "${run_dir}/run.log" 2>&1 &
 
 echo $! > "${run_dir}/pid.txt"
h100_env.yaml
CHANGED

@@ -43,6 +43,7 @@ dependencies:
   - tqdm==4.67.1
   - matplotlib==3.10.3
   - transformers==4.55.2
+  - huggingface_hub==0.34.4
   - biopython==1.85
   - ortools==9.14.6206
   - fair-esm==2.0.0
h100_env2.yaml
ADDED

@@ -0,0 +1,57 @@
+name: dnabind3
+channels:
+  - conda-forge
+  - defaults
+
+dependencies:
+  - python=3.10
+  - pip>=24
+  # compiled / heavy libs via conda-forge
+  - numpy>=2.0,<3.0
+  - scikit-learn>=1.5,<1.7
+  - pandas>=2.2,<2.3
+  - matplotlib>=3.8,<3.11
+  - lxml>=5.2,<6
+  - lightning=2.5.1
+  - torchmetrics>=1.3
+  - dask
+  - distributed
+  - dask-ml
+  # toolchain for JIT/building CUDA extensions (mamba-ssm, Triton kernels)
+  - cuda-toolkit=12.4
+  - cmake
+  - ninja
+
+  - pip:
+    # Force CUDA wheels and keep them from being overwritten by CPU builds
+    - --index-url=https://download.pytorch.org/whl/cu124
+    - torch==2.6.0+cu124
+
+    # HF stack + hard deps used at runtime
+    - transformers==4.53.0
+    - tokenizers>=0.21,<0.22
+    - safetensors>=0.4.3
+    - huggingface-hub==0.34.4
+    - regex
+
+    # Your libs
+    - rootutils==1.0.7
+    - hydra-core==1.3.2
+    - hydra-colorlog==1.2.0
+    - omegaconf==2.3.0
+    - pymex==0.9.31
+    - gitpython==3.1.44
+    - black==25.1.0
+    - tqdm==4.67.1
+    - biopython==1.85
+    - ortools==9.14.6206
+    - fair-esm==2.0.0
+    - rich==14.1.0
+    - wandb==0.21.1
+
+    # Mamba + Triton (for CUDA kernels)
+    - mamba-ssm==2.2.4
+    - triton>=3.0,<3.5
+
+    # your package
+    - -e .