svincoff committed
Commit
80b6a2c
·
1 Parent(s): fad71ca

temp changes for gpu switch

Files changed (40)
  1. .gitignore +9 -1
  2. configs/data_modules/pair.yaml +9 -0
  3. configs/data_task/cluster/remap.yaml +44 -0
  4. configs/data_task/fimo/post_fimo.yaml +9 -4
  5. configs/data_task/split/remap.yaml +23 -0
  6. configs/models/classifier.yaml +11 -0
  7. configs/models/pooling/truncatedsvd.yaml +7 -0
  8. configs/train.yaml +8 -0
  9. dpacman/classifier/model/loss.py +34 -0
  10. dpacman/classifier/model/train.py +0 -79
  11. dpacman/classifier/model_tmp/__init__.py +0 -0
  12. dpacman/classifier/model_tmp/clustering_data.py +383 -0
  13. dpacman/classifier/model_tmp/compress_embeddings.py +54 -0
  14. dpacman/{data_tasks/embeddings → classifier/model_tmp}/compute_embeddings.py +342 -141
  15. dpacman/classifier/model_tmp/extract_tf_symbols.py +27 -0
  16. dpacman/classifier/model_tmp/make_pair_list.py +220 -0
  17. dpacman/classifier/model_tmp/make_peak_fasta.py +13 -0
  18. dpacman/classifier/model_tmp/model.py +103 -0
  19. dpacman/classifier/model_tmp/prep_splits.py +133 -0
  20. dpacman/classifier/model_tmp/train.py +180 -0
  21. dpacman/data_modules/__init__.py +0 -0
  22. dpacman/data_modules/pair.py +342 -0
  23. dpacman/data_tasks/cluster/__init__.py +0 -0
  24. dpacman/data_tasks/cluster/remap.py +144 -0
  25. dpacman/data_tasks/embeddings/embedders.py +560 -0
  26. dpacman/data_tasks/fimo/__init__.py +0 -0
  27. dpacman/data_tasks/fimo/post_fimo.py +526 -124
  28. dpacman/data_tasks/fimo/run_fimo.py +16 -7
  29. dpacman/data_tasks/split/__init__.py +0 -0
  30. dpacman/data_tasks/split/remap.py +512 -0
  31. dpacman/scripts/preprocess.py +20 -4
  32. dpacman/scripts/run_cluster.sh +19 -0
  33. dpacman/scripts/run_fimo.sh +3 -1
  34. dpacman/scripts/run_split.sh +21 -0
  35. dpacman/utils/README.md +1 -0
  36. dpacman/utils/__init__.py +0 -0
  37. dpacman/utils/clustering.py +144 -0
  38. dpacman/utils/models.py +19 -0
  39. dpacman/utils/splitting.py +0 -0
  40. environment.yaml +39 -36
.gitignore CHANGED
@@ -17,4 +17,12 @@ dpacman/scripts/__pycache__/
17
  dpacman/temp.py
18
  dpacman/temp2.py
19
  logs/
20
- tree.txt
17
  dpacman/temp.py
18
  dpacman/temp2.py
19
  logs/
20
+ tree.txt
21
+ dpacman/utils/ubuntu_font/
22
+ dpacman/tf_clusters.tsv
23
+ dpacman/tf_family_tree.nwk
24
+ dpacman/idmap_filt.csv
25
+ dpacman/temp3.py
26
+ dpacman/temp4.py
27
+ dpacman/temp.ipynb
28
+ dpacman/nohup.out
configs/data_modules/pair.yaml ADDED
@@ -0,0 +1,9 @@
1
+
2
+ train_file: data_files/splits/train.csv
3
+ val_file: data_files/splits/val.csv
4
+ test_file: data_files/splits/test.csv
5
+
6
+ batch_size: 32
7
+ num_workers: 8
8
+
9
+ maximize_num_workers: False
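A minimal sketch of how a data-module config like this is typically consumed, assuming the split CSVs above exist. The real loader added in this commit is dpacman/data_modules/pair.py; the helper below is illustrative only.

import pandas as pd
import yaml
from torch.utils.data import DataLoader

def build_loaders(cfg_path="configs/data_modules/pair.yaml"):
    # Hypothetical helper; the actual data module lives in dpacman/data_modules/pair.py.
    with open(cfg_path) as f:
        cfg = yaml.safe_load(f)
    loaders = {}
    for split in ("train", "val", "test"):
        rows = pd.read_csv(cfg[f"{split}_file"]).to_dict("records")
        loaders[split] = DataLoader(
            rows,
            batch_size=cfg["batch_size"],
            num_workers=cfg["num_workers"],
            shuffle=(split == "train"),
        )
    return loaders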
configs/data_task/cluster/remap.yaml ADDED
@@ -0,0 +1,44 @@
1
+ name: remap
2
+ type: cluster
3
+
4
+ max_protein_length: 1998
5
+
6
+ cluster_dna_full: true
7
+ cluster_dna_peaks: true
8
+ cluster_protein: false
9
+
10
+ dna_full:
11
+ input_map_path: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/maps/dna_seqid_to_dna_sequence.json
12
+ fasta_path: dpacman/data_files/processed/mmseqs/inputs/fimo_hits_only/dna_full.fasta
13
+ output_dir: dpacman/data_files/processed/mmseqs/outputs/fimo_hits_only/dna_full
14
+ mmseqs:
15
+ min_seq_id: 0.3
16
+ c: 0.8
17
+ cov_mode: 0
18
+ cluster_mode: 0
19
+
20
+ dna_peaks:
21
+ input_map_path: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/maps/peak_seqid_to_peak_sequence.json
22
+ fasta_path: dpacman/data_files/processed/mmseqs/inputs/fimo_hits_only/dna_peaks.fasta
23
+ output_dir: dpacman/data_files/processed/mmseqs/outputs/fimo_hits_only/dna_peaks
24
+ mmseqs:
25
+ min_seq_id: 0.3
26
+ c: 0.8
27
+ cov_mode: 0
28
+ cluster_mode: 0
29
+
30
+ protein:
31
+ input_map_path: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/maps/tr_seqid_to_tr_sequence.json
32
+ fasta_path: dpacman/data_files/processed/mmseqs/inputs/fimo_hits_only/protein.fasta
33
+ output_dir: dpacman/data_files/processed/mmseqs/outputs/fimo_hits_only/protein
34
+ mmseqs:
35
+ min_seq_id: 0.3
36
+ c: 0.8
37
+ cov_mode: 0
38
+ cluster_mode: 0
39
+
40
+ input_data_path: /vast/projects/pranam/lab/sophie/DPACMAN/dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/remap2022_crm_fimo_output_q_processed_seed0.parquet
41
+ path_to_mmseqs: /vast/projects/pranam/lab/shared/mmseqs
42
+
43
+ out_dir: na
44
+ final_csv: na
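Each mmseqs block above maps one-to-one onto mmseqs easy-cluster flags (the same flags run_mmseqs_easy_cluster uses later in this commit). A hedged sketch of that translation, assuming path_to_mmseqs points at the mmseqs binary; the real driver is dpacman/data_tasks/cluster/remap.py (not reproduced here).

import subprocess

def easy_cluster(mmseqs_bin, fasta_path, out_prefix, tmp_dir, params):
    # params is one of the mmseqs: blocks above (min_seq_id, c, cov_mode, cluster_mode).
    cmd = [
        mmseqs_bin, "easy-cluster", fasta_path, out_prefix, tmp_dir,
        "--min-seq-id", str(params["min_seq_id"]),
        "-c", str(params["c"]),
        "--cov-mode", str(params["cov_mode"]),
        "--cluster-mode", str(params["cluster_mode"]),
    ]
    subprocess.run(cmd, check=True)  # writes <out_prefix>_cluster.tsv among its outputs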
configs/data_task/fimo/post_fimo.yaml CHANGED
@@ -2,11 +2,16 @@ name: post_fimo
2
  type: fimo
3
 
4
  fimo_out_dir: dpacman/data_files/processed/fimo/fimo_out_q
5
- unprocessed_output_csv: dpacman/data_files/processed/fimo/remap2022_crm_fimo_output_q_unprocessed.csv
6
- processed_output_csv: dpacman/data_files/processed/fimo/remap2022_crm_fimo_output_q_processed.csv
7
  json_dir: dpacman/data_files/raw/genomes/hg38
8
- idmap_path: dpacman/data_files/raw/remap/idmapping_reviewed_true_2025_08_11.tsv
 
9
 
10
- jaspar_boost: 100
 
 
 
 
 
11
 
12
  debug: false
 
2
  type: fimo
3
 
4
  fimo_out_dir: dpacman/data_files/processed/fimo/fimo_out_q
5
+ processed_output_csv: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/remap2022_crm_fimo_output_q_processed.csv
 
6
  json_dir: dpacman/data_files/raw/genomes/hg38
7
+ idmap_path: dpacman/data_files/raw/remap/idmapping_reviewed_true_2025_08_15.tsv
8
+ remap_path: /vast/projects/pranam/lab/sophie/DPACMAN/dpacman/data_files/processed/fimo/remap2022_crm_fimo_input.csv
9
 
10
+ jaspar_boost: 333
11
+
12
+ keep_fimo_only: true
13
+
14
+ seeds: [0]
15
+ max_protein_len: 1998
16
 
17
  debug: false
configs/data_task/split/remap.yaml ADDED
@@ -0,0 +1,23 @@
1
+ name: remap
2
+ type: split
3
+
4
+ max_protein_length: 1998
5
+
6
+ cluster_output_paths:
7
+ dna: dpacman/data_files/processed/mmseqs/outputs/fimo_hits_only/dna_full/mmseqs_cluster.tsv
8
+ protein: dpacman/data_files/processed/mmseqs/outputs/fimo_hits_only/protein/mmseqs_cluster.tsv
9
+
10
+ input_data_path: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/remap2022_crm_fimo_output_q_processed_seed0.parquet
11
+ split_out_dir: dpacman/data_files/processed/splits
12
+
13
+ split_by: both # protein, dna, or both
14
+
15
+ test_ratio: 0.10
16
+ val_ratio: 0.10
17
+ train_ratio: 0.80
18
+
19
+ require_nonempty: true
20
+ ratio_tolerance: null
21
+ bigM: null
22
+
23
+ seed: 0
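The three ratios are expected to sum to 1.0, and the cluster_output_paths point at MMseqs cluster TSVs in the two-column representative<TAB>member format parsed elsewhere in this commit. A small illustrative check/loader (not the splitter itself, which lives in dpacman/data_tasks/split/remap.py):

def load_member_to_cluster(tsv_path):
    # MMseqs easy-cluster TSV: "representative\tmember" per line.
    mapping = {}
    with open(tsv_path) as f:
        for line in f:
            parts = line.rstrip("\n").split("\t")
            if len(parts) >= 2:
                rep, member = parts[0], parts[1]
                mapping[member] = rep
                mapping.setdefault(rep, rep)
    return mapping

assert abs((0.80 + 0.10 + 0.10) - 1.0) < 1e-9  # train/val/test ratios above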
configs/models/classifier.yaml ADDED
@@ -0,0 +1,11 @@
1
+ name: classifier
2
+ type: train
3
+
4
+ params:
5
+ epochs: 10
6
+ batch_size: 32
7
+ lr: 1e-4
8
+ seed: 42
9
+
10
+ out_dir: null
11
+ pair_list: null
configs/models/pooling/truncatedsvd.yaml ADDED
@@ -0,0 +1,7 @@
1
+ n_components: 2
2
+ algorithm: randomized
3
+ n_iter: 5
4
+ n_oversamples: 10
5
+ power_iteration_normalizer: auto
6
+ random_state: 42
7
+ tol: 0
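These keys mirror the keyword arguments of scikit-learn's TruncatedSVD (note the power_iteration_normalizer spelling corrected above), so the config can be splatted straight into the estimator. A minimal sketch, assuming the pooled input is a 2-D (n_samples, n_features) embedding matrix:

import yaml
from sklearn.decomposition import TruncatedSVD

with open("configs/models/pooling/truncatedsvd.yaml") as f:
    svd_cfg = yaml.safe_load(f)

svd = TruncatedSVD(**svd_cfg)      # n_components, algorithm, n_iter, n_oversamples, ...
# reduced = svd.fit_transform(X)   # X: (n_samples, n_features)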
configs/train.yaml CHANGED
@@ -0,0 +1,8 @@
1
+ defaults:
2
+ - _self_
3
+ - paths: default
4
+ - hydra: default # ← tells Hydra to use the logging/output config
5
+ - trainer: gpu
6
+ - data_task: model/classifier
7
+
8
+ task_name: train/${data_task.type}
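For context, a Hydra defaults list like this is normally composed by an entrypoint decorated with hydra.main; a hedged sketch follows (the actual training entrypoint is not part of this diff).

import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(version_base=None, config_path="configs", config_name="train")
def main(cfg: DictConfig) -> None:
    # cfg contains the composed paths/hydra/trainer/data_task groups.
    print(OmegaConf.to_yaml(cfg))

if __name__ == "__main__":
    main()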
dpacman/classifier/model/loss.py ADDED
@@ -0,0 +1,34 @@
1
+ """
2
+ Define loss functions needed for training the model
3
+ """
4
+ import torch
5
+ from torch.nn import functional as F
6
+
7
+ def combined_loss_components(logits, targets, peak_thresh=0.5, eps=1e-8):
8
+ probs = torch.sigmoid(logits)
9
+ labels = (targets >= peak_thresh).float()
10
+ non_peak_mask = (labels == 0).float()
11
+ peak_mask = (labels == 1).float()
12
+
13
+ bce_all = F.binary_cross_entropy_with_logits(logits, labels, reduction='none')
14
+ bce_non = (bce_all * non_peak_mask)
15
+ bce_non = bce_non.sum() / (non_peak_mask.sum() + eps)
16
+
17
+ mse_peaks = F.mse_loss(probs * peak_mask, targets * peak_mask, reduction='sum') \
18
+ / (peak_mask.sum() + eps)
19
+
20
+ t_dist = (targets + eps)
21
+ p_dist = (probs + eps)
22
+ t_dist = t_dist / t_dist.sum(dim=1, keepdim=True)
23
+ p_dist = p_dist / p_dist.sum(dim=1, keepdim=True)
24
+ kl = (t_dist * (t_dist.clamp(min=eps).log() - p_dist.clamp(min=eps).log())).sum(dim=1).mean()
25
+
26
+ return bce_non, kl, mse_peaks, probs
27
+
28
+ def accuracy_percentage(logits, targets, peak_thresh=0.5):
29
+ probs = torch.sigmoid(logits)
30
+ preds_bin = (probs >= 0.5).float()
31
+ labels = (targets >= peak_thresh).float()
32
+ correct = (preds_bin == labels).float().sum()
33
+ total = torch.numel(labels)
34
+ return (correct / max(1, total)).item() * 100.0
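For reference, evaluate() in train.py consumes these pieces with alpha/beta/gamma weights; a minimal sketch of how the components are typically combined into one scalar (the exact weighting used in training is defined in train.py, not here):

def total_loss(logits, targets, alpha=1.0, beta=1.0, gamma=1.0, peak_thresh=0.5):
    # Weighted sum of the non-peak BCE, distribution KL, and peak MSE terms.
    bce_non, kl, mse_peaks, _probs = combined_loss_components(
        logits, targets, peak_thresh=peak_thresh
    )
    return alpha * bce_non + beta * kl + gamma * mse_peaks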
dpacman/classifier/model/train.py CHANGED
@@ -77,85 +77,6 @@ def build_tf_cache(tf_paths, target_dim=256):
77
  print("[i] TF cache ready (raw); compression will be learned.", flush=True)
78
  return cache
79
 
80
- # ─────────────── Dataset & Collation ─────────────────────────────────────
81
- class PairDataset(Dataset):
82
- def __init__(self, tf_paths, dna_paths, final_df, tf_cache):
83
- self.tf_paths, self.dna_paths = tf_paths, dna_paths
84
- self.tf_cache = tf_cache
85
- self.targets = {}
86
- for _, row in final_df.iterrows():
87
- dna_id = row["dna_id"]
88
- vec = np.array(list(map(float, row["score_sig_r2"].split(","))), dtype=np.float32)
89
- self.targets[dna_id] = vec
90
-
91
- def __len__(self): return len(self.tf_paths)
92
-
93
- def __getitem__(self, i):
94
- b = self.tf_cache[self.tf_paths[i]] # (L_b, D_b) or (D_b,)
95
- if b.ndim==1: b = b[None,:]
96
- g = np.load(self.dna_paths[i]) # (L_g, 256) or (256,)
97
- if g.ndim==1: g = g[None,:]
98
-
99
- stem = Path(self.dna_paths[i]).stem
100
- dna_id = stem.replace("dna_","")
101
- t = self.targets.get(dna_id, np.zeros(g.shape[0],dtype=np.float32))
102
-
103
- return torch.from_numpy(b).float(), \
104
- torch.from_numpy(g).float(), \
105
- torch.from_numpy(t).float()
106
-
107
- def collate_fn(batch):
108
- Bs = [b.shape[0] for b,_,_ in batch]
109
- Gs = [g.shape[0] for _,g,_ in batch]
110
- maxB, maxG = max(Bs), max(Gs)
111
-
112
- def pad_seq(x, L):
113
- if x.shape[0] < L:
114
- pad = torch.zeros((L-x.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
115
- return torch.cat([x, pad], dim=0)
116
- return x
117
-
118
- def pad_t(y, L):
119
- if y.shape[0] < L:
120
- pad = torch.zeros((L-y.shape[0],), dtype=y.dtype, device=y.device)
121
- return torch.cat([y, pad], dim=0)
122
- return y
123
-
124
- b_stack = torch.stack([pad_seq(b, maxB) for b,_,_ in batch])
125
- g_stack = torch.stack([pad_seq(g, maxG) for _,g,_ in batch])
126
- t_stack = torch.stack([pad_t(t, maxG) for *_,t in batch])
127
- return b_stack, g_stack, t_stack
128
-
129
- # ─────────────── losses, metrics ─────────────────────────────────────────
130
- def combined_loss_components(logits, targets, peak_thresh=0.5, eps=1e-8):
131
- probs = torch.sigmoid(logits)
132
- labels = (targets >= peak_thresh).float()
133
- non_peak_mask = (labels == 0).float()
134
- peak_mask = (labels == 1).float()
135
-
136
- bce_all = F.binary_cross_entropy_with_logits(logits, labels, reduction='none')
137
- bce_non = (bce_all * non_peak_mask)
138
- bce_non = bce_non.sum() / (non_peak_mask.sum() + eps)
139
-
140
- mse_peaks = F.mse_loss(probs * peak_mask, targets * peak_mask, reduction='sum') \
141
- / (peak_mask.sum() + eps)
142
-
143
- t_dist = (targets + eps)
144
- p_dist = (probs + eps)
145
- t_dist = t_dist / t_dist.sum(dim=1, keepdim=True)
146
- p_dist = p_dist / p_dist.sum(dim=1, keepdim=True)
147
- kl = (t_dist * (t_dist.clamp(min=eps).log() - p_dist.clamp(min=eps).log())).sum(dim=1).mean()
148
-
149
- return bce_non, kl, mse_peaks, probs
150
-
151
- def accuracy_percentage(logits, targets, peak_thresh=0.5):
152
- probs = torch.sigmoid(logits)
153
- preds_bin = (probs >= 0.5).float()
154
- labels = (targets >= peak_thresh).float()
155
- correct = (preds_bin == labels).float().sum()
156
- total = torch.numel(labels)
157
- return (correct / max(1, total)).item() * 100.0
158
-
159
  def evaluate(model, dl, device, alpha, beta, gamma, peak_thresh, eps=1e-8):
160
  model.eval()
161
  tot_loss, tot_acc = 0.0, 0.0
 
77
  print("[i] TF cache ready (raw); compression will be learned.", flush=True)
78
  return cache
79
 
80
  def evaluate(model, dl, device, alpha, beta, gamma, peak_thresh, eps=1e-8):
81
  model.eval()
82
  tot_loss, tot_acc = 0.0, 0.0
dpacman/classifier/model_tmp/__init__.py ADDED
File without changes
dpacman/classifier/model_tmp/clustering_data.py ADDED
@@ -0,0 +1,383 @@
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import numpy as np
4
+ import pandas as pd
5
+ from pathlib import Path
6
+ import random
7
+ import sys
8
+ import subprocess
9
+ from collections import defaultdict
10
+
11
+ # ─────────────────────────────────────────────────────────────────────────
12
+ # Original helpers (kept; some lightly edited/commented where needed)
13
+ # ─────────────────────────────────────────────────────────────────────────
14
+
15
+ def read_ids_file(p):
16
+ p = Path(p)
17
+ if not p.exists():
18
+ raise FileNotFoundError(f"IDs file not found: {p}")
19
+ return [line.strip() for line in p.open() if line.strip()]
20
+
21
+ def split_embeddings(emb_path, ids_path, out_dir, prefix):
22
+ out_dir = Path(out_dir)
23
+ out_dir.mkdir(parents=True, exist_ok=True)
24
+
25
+ if not Path(emb_path).exists():
26
+ raise FileNotFoundError(f"Embedding file not found: {emb_path}")
27
+ if not Path(ids_path).exists():
28
+ raise FileNotFoundError(f"IDs file not found: {ids_path}")
29
+
30
+ if emb_path.endswith(".npz"):
31
+ data = np.load(emb_path, allow_pickle=True)
32
+ if "embeddings" in data:
33
+ emb = data["embeddings"]
34
+ else:
35
+ raise ValueError(f"{emb_path} missing 'embeddings' key")
36
+ else:
37
+ emb = np.load(emb_path)
38
+
39
+ ids = read_ids_file(ids_path)
40
+ if len(ids) != emb.shape[0]:
41
+ print(f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}", file=sys.stderr)
42
+
43
+ mapping = {}
44
+ for i, ident in enumerate(ids):
45
+ if i >= emb.shape[0]:
46
+ print(f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr)
47
+ continue
48
+ arr = emb[i]
49
+ out_file = out_dir / f"{prefix}_{ident}.npy"
50
+ np.save(out_file, arr)
51
+ mapping[ident] = str(out_file)
52
+ return mapping
53
+
54
+ def extract_symbol_from_tf_id(full_id: str) -> str:
55
+ """
56
+ Given a TF embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN',
57
+ return the gene symbol uppercase (e.g., 'ZBTB5').
58
+ """
59
+ if "|" in full_id:
60
+ try:
61
+ # format sp|Accession|SYMBOL_HUMAN
62
+ genepart = full_id.split("|")[2]
63
+ except IndexError:
64
+ genepart = full_id
65
+ else:
66
+ genepart = full_id
67
+ symbol = genepart.split("_")[0]
68
+ return symbol.upper()
69
+
70
+ def build_tf_symbol_map(tf_map):
71
+ """
72
+ Build mapping gene_symbol -> list of embedding paths.
73
+ """
74
+ symbol_map = {}
75
+ for full_id, path in tf_map.items():
76
+ symbol = extract_symbol_from_tf_id(full_id)
77
+ symbol_map.setdefault(symbol, []).append(path)
78
+ return symbol_map
79
+
80
+ def tf_key_from_path(path: str) -> str:
81
+ """
82
+ Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract normalized symbol 'ZBTB5'.
83
+ """
84
+ stem = Path(path).stem # e.g., tf_sp|O15062|ZBTB5_HUMAN
85
+ # remove leading prefix if present (tf_)
86
+ if "_" in stem:
87
+ _, rest = stem.split("_", 1)
88
+ else:
89
+ rest = stem
90
+ return extract_symbol_from_tf_id(rest)
91
+
92
+ def dna_key_from_path(path: str) -> str:
93
+ """
94
+ Given .../dna_peak42.npy -> 'peak42'
95
+ """
96
+ stem = Path(path).stem
97
+ if "_" in stem:
98
+ _, rest = stem.split("_", 1)
99
+ else:
100
+ rest = stem
101
+ return rest
102
+
103
+ # ─────────────────────────────────────────────────────────────────────────
104
+ # New helpers for MMseqs clustering & cluster-level splitting
105
+ # ─────────────────────────────────────────────────────────────────────────
106
+
107
+ def write_dna_fasta(df: pd.DataFrame, out_fasta: Path) -> None:
108
+ """
109
+ Write unique DNA sequences to FASTA using dna_id as header.
110
+ Requires df with columns: dna_id, dna_sequence
111
+ """
112
+ uniq = df[["dna_id", "dna_sequence"]].drop_duplicates()
113
+ with open(out_fasta, "w") as f:
114
+ for _, row in uniq.iterrows():
115
+ did = row["dna_id"]
116
+ seq = str(row["dna_sequence"]).upper().replace(" ", "").replace("\n", "")
117
+ f.write(f">{did}\n{seq}\n")
118
+
119
+ def run_mmseqs_easy_cluster(
120
+ mmseqs_bin: str,
121
+ fasta: Path,
122
+ out_prefix: Path,
123
+ tmp_dir: Path,
124
+ min_seq_id: float,
125
+ coverage: float,
126
+ cov_mode: int,
127
+ ) -> Path:
128
+ """
129
+ Runs mmseqs easy-cluster on nucleotide sequences.
130
+ Returns the path to a clusters TSV file (creating it if the default one isn't present).
131
+ """
132
+ tmp_dir.mkdir(parents=True, exist_ok=True)
133
+ out_prefix.parent.mkdir(parents=True, exist_ok=True)
134
+
135
+ cmd = [
136
+ mmseqs_bin, "easy-cluster",
137
+ str(fasta), str(out_prefix), str(tmp_dir),
138
+ "--min-seq-id", str(min_seq_id),
139
+ "-c", str(coverage),
140
+ "--cov-mode", str(cov_mode),
141
+ # You can add performance flags here if needed, e.g.:
142
+ # "--threads", "8"
143
+ ]
144
+ print("[i] Running:", " ".join(cmd), flush=True)
145
+ subprocess.run(cmd, check=True)
146
+
147
+ # MMseqs easy-cluster typically writes <out_prefix>_cluster.tsv
148
+ default_tsv = Path(str(out_prefix) + "_cluster.tsv")
149
+ if default_tsv.exists():
150
+ print(f"[i] Found cluster TSV: {default_tsv}")
151
+ return default_tsv
152
+
153
+ # Fallback: try createtsv if default is missing
154
+ # This requires the internal DBs. easy-cluster creates DBs alongside out_prefix.
155
+ # We'll try to locate them and emit a TSV.
156
+ in_db = Path(str(out_prefix) + "_query")
157
+ cl_db = Path(str(out_prefix) + "_cluster")
158
+ out_tsv = Path(str(out_prefix) + "_fallback_cluster.tsv")
159
+ if in_db.exists() and cl_db.exists():
160
+ cmd2 = [mmseqs_bin, "createtsv", str(in_db), str(in_db), str(cl_db), str(out_tsv)]
161
+ print("[i] Creating TSV via createtsv:", " ".join(cmd2), flush=True)
162
+ subprocess.run(cmd2, check=True)
163
+ if out_tsv.exists():
164
+ return out_tsv
165
+
166
+ raise FileNotFoundError("Could not locate clusters TSV from mmseqs. "
167
+ "Expected {default_tsv} or createtsv fallback.")
168
+
169
+ def parse_mmseqs_clusters(tsv_path: Path) -> dict:
170
+ """
171
+ Parse MMseqs cluster TSV (rep \t member). Returns dna_id -> cluster_rep_id
172
+ """
173
+ mapping = {}
174
+ with open(tsv_path) as f:
175
+ for line in f:
176
+ parts = line.rstrip("\n").split("\t")
177
+ if len(parts) < 2:
178
+ continue
179
+ rep, member = parts[0], parts[1]
180
+ mapping[member] = rep
181
+ # Some TSVs include rep->rep; if not, ensure rep is mapped to itself:
182
+ if rep not in mapping:
183
+ mapping[rep] = rep
184
+ return mapping
185
+
186
+ def assign_clusters_to_splits(cluster_rep_to_members: dict,
187
+ val_frac: float,
188
+ test_frac: float,
189
+ seed: int = 42):
190
+ """
191
+ cluster_rep_to_members: dict[rep] = [members...]
192
+ Returns: dict with keys 'train','val','test' mapping to sets of dna_id.
193
+ Ensures all members of a cluster go to the same split.
194
+ """
195
+ rng = random.Random(seed)
196
+ reps = list(cluster_rep_to_members.keys())
197
+ rng.shuffle(reps)
198
+
199
+ # Greedy-ish fill by total member counts to match desired fractions.
200
+ total = sum(len(cluster_rep_to_members[r]) for r in reps)
201
+ target_val = int(round(total * val_frac))
202
+ target_test = int(round(total * test_frac))
203
+ cur_val = cur_test = 0
204
+
205
+ val_ids, test_ids, train_ids = set(), set(), set()
206
+ for rep in reps:
207
+ members = cluster_rep_to_members[rep]
208
+ c = len(members)
209
+ # Fill val first, then test, then train
210
+ if cur_val + c <= target_val:
211
+ val_ids.update(members); cur_val += c
212
+ elif cur_test + c <= target_test:
213
+ test_ids.update(members); cur_test += c
214
+ else:
215
+ train_ids.update(members)
216
+
217
+ return {"train": train_ids, "val": val_ids, "test": test_ids}
218
+
219
+ # ─────────────────────────────────────────────────────────────────────────
220
+ # Main
221
+ # ─────────────────────────────────────────────────────────────────────────
222
+
223
+ def main():
224
+ parser = argparse.ArgumentParser(
225
+ description="Build TF-DNA pair lists with MMseqs clustering on DNA to prevent split leakage."
226
+ )
227
+ parser.add_argument("--final_csv", required=True, help="final.csv with TF_id and dna_sequence")
228
+ parser.add_argument("--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)")
229
+ parser.add_argument("--dna_ids", required=True, help="IDs file for DNA embeddings (peak*.ids)")
230
+ parser.add_argument("--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)")
231
+ parser.add_argument("--tf_ids", required=True, help="IDs file for TF embeddings (sp|... ids)")
232
+ parser.add_argument("--out_dir", required=True, help="Output directory")
233
+ parser.add_argument("--seed", type=int, default=42)
234
+
235
+ # NEW: MMseqs options & split fractions
236
+ parser.add_argument("--mmseqs_bin", default="mmseqs", help="Path to mmseqs binary")
237
+ parser.add_argument("--min_seq_id", type=float, default=0.9, help="MMseqs --min-seq-id")
238
+ parser.add_argument("--cov", type=float, default=0.8, help="MMseqs -c coverage fraction")
239
+ parser.add_argument("--cov_mode", type=int, default=1, help="MMseqs --cov-mode (1 = coverage of target)")
240
+ parser.add_argument("--val_frac", type=float, default=0.10)
241
+ parser.add_argument("--test_frac", type=float, default=0.10)
242
+ parser.add_argument("--tmp_dir", default=None, help="MMseqs tmp dir (defaults to out_dir/tmp)")
243
+ args = parser.parse_args()
244
+
245
+ random.seed(args.seed)
246
+ out_dir = Path(args.out_dir)
247
+ out_dir.mkdir(parents=True, exist_ok=True)
248
+
249
+ # Load final.csv
250
+ df = pd.read_csv(args.final_csv, dtype=str)
251
+ if "TF_id" not in df.columns or "dna_sequence" not in df.columns:
252
+ raise RuntimeError("final.csv must have columns TF_id and dna_sequence")
253
+
254
+ # Assign dna_id (unique per dna_sequence)
255
+ unique_seqs = df["dna_sequence"].drop_duplicates().tolist()
256
+ seq_to_id = {seq: f"peak{i}" for i, seq in enumerate(unique_seqs)}
257
+ df["dna_id"] = df["dna_sequence"].map(seq_to_id)
258
+ enriched_csv = out_dir / "final_with_dna_id.csv"
259
+ df.to_csv(enriched_csv, index=False)
260
+ print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}")
261
+
262
+ # Split embeddings into per-item files (unchanged)
263
+ print(f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}")
264
+ dna_map = split_embeddings(args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna")
265
+ print(f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})")
266
+ print(f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}")
267
+ tf_map = split_embeddings(args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf")
268
+ print(f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})")
269
+
270
+ # Build gene-symbol normalized map
271
+ tf_symbol_map = build_tf_symbol_map(tf_map)
272
+ print(f"[i] TF symbol map keys (sample): {list(tf_symbol_map.keys())[:30]}")
273
+
274
+ # Diagnostic overlaps
275
+ norm_tf_in_final = set(t.split("_seq")[0].upper() for t in df["TF_id"].unique())
276
+ available_tf_symbols = set(tf_symbol_map.keys())
277
+ intersect_tf = norm_tf_in_final & available_tf_symbols
278
+ print(f"[i] Unique normalized TF symbols in final.csv: {len(norm_tf_in_final)}")
279
+ print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}")
280
+ print(f"[i] Intersection count: {len(intersect_tf)}")
281
+ if len(intersect_tf) == 0:
282
+ print("[ERROR] No overlap between normalized TF_id and TF embedding symbols.", file=sys.stderr)
283
+ print("Sample normalized TFs from final.csv:", sorted(list(norm_tf_in_final))[:30], file=sys.stderr)
284
+ print("Sample available TF symbols:", sorted(list(available_tf_symbols))[:30], file=sys.stderr)
285
+ sys.exit(1)
286
+
287
+ dna_ids_final = set(df["dna_id"].unique())
288
+ available_dna_ids = set(dna_map.keys())
289
+ intersect_dna = dna_ids_final & available_dna_ids
290
+ print(f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. Intersection: {len(intersect_dna)}")
291
+ if len(intersect_dna) == 0:
292
+ print("[ERROR] No overlap on DNA ids.", file=sys.stderr)
293
+ sys.exit(1)
294
+
295
+ # ── NEW: MMseqs clustering on DNA sequences ───────────────────────────
296
+ fasta_path = out_dir / "dna_unique.fasta"
297
+ write_dna_fasta(df, fasta_path)
298
+ print(f"[i] Wrote FASTA with {df['dna_id'].nunique()} unique sequences → {fasta_path}")
299
+
300
+ tmp_dir = Path(args.tmp_dir) if args.tmp_dir else (out_dir / "mmseqs_tmp")
301
+ cluster_prefix = out_dir / "mmseqs_dna_clusters"
302
+ clusters_tsv = run_mmseqs_easy_cluster(
303
+ mmseqs_bin=args.mmseqs_bin,
304
+ fasta=fasta_path,
305
+ out_prefix=cluster_prefix,
306
+ tmp_dir=tmp_dir,
307
+ min_seq_id=args.min_seq_id,
308
+ coverage=args.cov,
309
+ cov_mode=args.cov_mode,
310
+ )
311
+
312
+ # Parse clusters
313
+ member_to_rep = parse_mmseqs_clusters(clusters_tsv) # dna_id -> rep_id
314
+ # Build rep -> members list
315
+ rep_to_members = defaultdict(list)
316
+ for member, rep in member_to_rep.items():
317
+ rep_to_members[rep].append(member)
318
+
319
+ print(f"[i] Parsed {len(rep_to_members)} clusters from {clusters_tsv}")
320
+ clusters_table = []
321
+ for rep, members in rep_to_members.items():
322
+ for m in members:
323
+ clusters_table.append((m, rep))
324
+ clusters_df = pd.DataFrame(clusters_table, columns=["dna_id", "cluster_id"])
325
+ clusters_df.to_csv(out_dir / "clusters.tsv", sep="\t", index=False)
326
+ print(f"[i] Wrote clusters mapping → {out_dir / 'clusters.tsv'}")
327
+
328
+ # Attach cluster_id back to final df
329
+ df = df.merge(clusters_df, on="dna_id", how="left")
330
+ df.to_csv(out_dir / "final_with_dna_id_and_cluster.csv", index=False)
331
+ print(f"[i] Wrote {out_dir / 'final_with_dna_id_and_cluster.csv'}")
332
+
333
+ # Assign entire clusters to splits
334
+ splits = assign_clusters_to_splits(rep_to_members,
335
+ val_frac=args.val_frac,
336
+ test_frac=args.test_frac,
337
+ seed=args.seed)
338
+ for k in ["train", "val", "test"]:
339
+ print(f"[i] {k}: {len(splits[k])} dna_ids")
340
+
341
+ # ── Build positive pairs only, per split (NO negatives) ───────────────
342
+ positives_by_split = {"train": [], "val": [], "test": []}
343
+ # Build a quick dna_id -> embedding path map
344
+ dnaid_to_path = {did: path for did, path in dna_map.items()}
345
+
346
+ pos_count = 0
347
+ for _, row in df.iterrows():
348
+ tf_raw = row["TF_id"]
349
+ tf_symbol = tf_raw.split("_seq")[0].upper()
350
+ dnaid = row["dna_id"]
351
+ if (tf_symbol not in tf_symbol_map) or (dnaid not in dnaid_to_path):
352
+ continue
353
+ tf_embedding_path = tf_symbol_map[tf_symbol][0] # first embedding per symbol
354
+
355
+ # decide split by dna_id cluster assignment
356
+ if dnaid in splits["train"]:
357
+ positives_by_split["train"].append((tf_embedding_path, dnaid_to_path[dnaid], 1))
358
+ elif dnaid in splits["val"]:
359
+ positives_by_split["val"].append((tf_embedding_path, dnaid_to_path[dnaid], 1))
360
+ elif dnaid in splits["test"]:
361
+ positives_by_split["test"].append((tf_embedding_path, dnaid_to_path[dnaid], 1))
362
+ pos_count += 1
363
+
364
+ print(f"[i] Constructed positives across splits (rows in final.csv iterated: {len(df)})")
365
+ for k in ["train", "val", "test"]:
366
+ print(f"[i] positives[{k}] = {len(positives_by_split[k])}")
367
+
368
+ # # OLD: negatives (kept commented)
369
+ # negatives = []
370
+ # print(f"[i] Sampled {len(negatives)} negatives (neg_per_positive not used)")
371
+
372
+ # Emit split-specific pair lists
373
+ for split in ["train", "val", "test"]:
374
+ out_tsv = out_dir / f"pair_list_{split}.tsv"
375
+ with open(out_tsv, "w") as f:
376
+ for binder_path, glm_path, label in positives_by_split[split]: # + negatives if you add later
377
+ f.write(f"{binder_path}\t{glm_path}\t{label}\n")
378
+ print(f"[i] Wrote {len(positives_by_split[split])} examples to {out_tsv}")
379
+
380
+ print("✅ Done. Cluster-aware splits ready.")
381
+
382
+ if __name__ == "__main__":
383
+ main()
dpacman/classifier/model_tmp/compress_embeddings.py ADDED
@@ -0,0 +1,54 @@
1
+ # compress_embeddings.py
2
+ # USAGE: python compress_embeddings.py --input_glob "/path/to/esm_embeddings/*.npy" --output_dir "/path/to/compressed_embeddings" --esm_dim 1280 --out_dim 256
3
+ # --------------
4
+ import os
5
+ import glob
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+
10
+ class EmbeddingCompressor(nn.Module):
11
+ def __init__(self, input_dim: int = 1280, output_dim: int = 256):
12
+ super().__init__()
13
+ self.fc = nn.Linear(input_dim, output_dim)
14
+
15
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
16
+ """
17
+ x: (batch, L, input_dim) or (L, input_dim)
18
+ returns: (batch, output_dim) or (1, output_dim)
19
+ """
20
+ if x.dim() == 2:
21
+ # single example: mean over tokens
22
+ x = x.mean(dim=0, keepdim=True) # → (1, input_dim)
23
+ else:
24
+ # batch: mean over tokens
25
+ x = x.mean(dim=1) # → (batch, input_dim)
26
+ return self.fc(x) # → (batch, output_dim)
27
+
28
+ def compress_file(in_path: str, out_path: str, model: EmbeddingCompressor):
29
+ arr = np.load(in_path) # shape (L, D) or (batch, L, D)
30
+ tensor = torch.from_numpy(arr).float()
31
+ with torch.no_grad():
32
+ compressed = model(tensor) # → (batch, 256)
33
+ out = compressed.cpu().numpy()
34
+ np.save(out_path, out)
35
+ print(f"Saved {out_path}")
36
+
37
+ if __name__ == "__main__":
38
+ import argparse
39
+ parser = argparse.ArgumentParser(description="Compress ESM embeddings to 256-d")
40
+ parser.add_argument("--input_glob", type=str, required=True,
41
+ help="Glob for your .npy ESM embeddings (e.g. data/esm_*.npy)")
42
+ parser.add_argument("--output_dir", type=str, required=True)
43
+ parser.add_argument("--esm_dim", type=int, default=1280)
44
+ parser.add_argument("--out_dim", type=int, default=256)
45
+ args = parser.parse_args()
46
+
47
+ os.makedirs(args.output_dir, exist_ok=True)
48
+ compressor = EmbeddingCompressor(args.esm_dim, args.out_dim)
49
+ compressor.eval()
50
+
51
+ for fn in glob.glob(args.input_glob):
52
+ base = os.path.basename(fn).replace(".npy", "_256.npy")
53
+ out_path = os.path.join(args.output_dir, base)
54
+ compress_file(fn, out_path, compressor)
dpacman/{data_tasks/embeddings → classifier/model_tmp}/compute_embeddings.py RENAMED
@@ -14,7 +14,6 @@ Usage example (DNA + protein in one go):
14
  --out-dir ../data_files/processed/tfclust/hg38_tf/embeddings \
15
  --device cuda
16
  """
17
-
18
  import os
19
  import re
20
  import argparse
@@ -29,7 +28,6 @@ import time
29
 
30
  # ---- model wrappers ----
31
 
32
-
33
  class CaduceusEmbedder:
34
  def __init__(self, device, chunk_size=131_072, overlap=0):
35
  """
@@ -41,54 +39,47 @@ class CaduceusEmbedder:
41
  self.tokenizer = AutoTokenizer.from_pretrained(
42
  model_name, trust_remote_code=True
43
  )
44
- self.model = (
45
- AutoModel.from_pretrained(model_name, trust_remote_code=True)
46
- .to(device)
47
- .eval()
48
- )
49
- self.device = device
50
  self.chunk_size = chunk_size
51
- self.step = chunk_size - overlap
52
 
53
  def embed(self, seqs):
54
  """
55
  seqs: List[str] of DNA sequences (each <= chunk_size for this test)
56
  returns: np.ndarray of shape (N, L, D), raw per‐token embeddings
57
  """
 
 
 
 
 
58
  outputs = []
59
  for seq in seqs:
60
- # --- old windowing + mean-pooling logic, now commented out ---
61
- # window_vecs = []
62
- # for i in range(0, len(seq), self.step):
63
- # chunk = seq[i : i + self.chunk_size]
64
- # if not chunk:
65
- # break
66
- # toks = self.tokenizer(
67
- # chunk,
68
- # return_tensors="pt",
69
- # padding=False,
70
- # truncation=True,
71
- # max_length=self.chunk_size
72
- # ).to(self.device)
73
- # with torch.no_grad():
74
- # out = self.model(**toks).last_hidden_state
75
- # window_vecs.append(out.mean(dim=1).squeeze(0).cpu())
76
- # seq_emb = torch.stack(window_vecs, dim=0).mean(dim=0).numpy()
77
- # outputs.append(seq_emb)
78
-
79
- # --- new: raw per‐token embeddings in one shot ---
80
  toks = self.tokenizer(
81
  seq,
82
  return_tensors="pt",
83
  padding=False,
84
  truncation=True,
85
- max_length=self.chunk_size,
86
  ).to(self.device)
87
  with torch.no_grad():
88
  out = self.model(**toks).last_hidden_state # (1, L, D)
89
- outputs.append(out.cpu().numpy()[0]) # (L, D)
 
90
 
91
- return np.stack(outputs, axis=0) # (N, L, D)
92
 
93
  def benchmark(self, lengths=None):
94
  """
@@ -110,29 +101,79 @@ class CaduceusEmbedder:
110
  t1 = time.perf_counter()
111
  print(f" length={sz:6,d} time={(t1-t0)*1000:7.1f} ms")
112
 
 
 
 
 
113
 
114
  class DNABertEmbedder:
115
  def __init__(self, device):
116
- self.tokenizer = AutoTokenizer.from_pretrained(
117
- "zhihan1996/DNA_bert_6", trust_remote_code=True
118
- )
119
- self.model = AutoModel.from_pretrained(
120
- "zhihan1996/DNA_bert_6", trust_remote_code=True
121
- ).to(device)
122
- self.device = device
123
 
124
  def embed(self, seqs):
125
  embs = []
126
  for s in seqs:
127
- tokens = self.tokenizer(s, return_tensors="pt", padding=True)[
128
- "input_ids"
129
- ].to(self.device)
130
  with torch.no_grad():
131
  out = self.model(tokens).last_hidden_state.mean(1)
132
  embs.append(out.cpu().numpy())
133
  return np.vstack(embs)
134
 
135
-
136
  class NucleotideTransformerEmbedder:
137
  def __init__(self, device):
138
  # HF “feature-extraction” returns a list of (L, D) arrays for each input
@@ -140,9 +181,7 @@ class NucleotideTransformerEmbedder:
140
  self.pipe = pipeline(
141
  "feature-extraction",
142
  model="InstaDeepAI/nucleotide-transformer-500m-1000g",
143
- device=(
144
- -1 if device == "cpu" else 0
145
- ), # HF uses -1 for CPU, 0 for GPU
146
  )
147
 
148
  def embed(self, seqs):
@@ -152,35 +191,131 @@ class NucleotideTransformerEmbedder:
152
  """
153
  all_embeddings = self.pipe(seqs, truncation=True, padding=True)
154
  # all_embeddings is a List of shape (L, D) arrays
155
- pooled = [np.mean(x, axis=0) for x in all_embeddings]
156
- return np.vstack(pooled)
 
 
 
 
157
 
158
 
159
  class ESMEmbedder:
160
- def __init__(self, device):
161
- self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
 
 
 
 
 
 
 
 
162
  self.batch_converter = self.alphabet.get_batch_converter()
163
  self.model.to(device).eval()
164
- self.device = device
 
 
 
 
165
 
166
- def embed(self, seqs):
167
- batch = [(str(i), seq) for i, seq in enumerate(seqs)]
168
- _, _, toks = self.batch_converter(batch)
169
- toks = toks.to(self.device)
170
- with torch.no_grad():
171
- results = self.model(toks, repr_layers=[33], return_contacts=False)
172
- reps = results["representations"][33]
173
- return reps[:, 1:-1].mean(1).cpu().numpy()
 
 
 
 
 
 
174
 
 
 
 
 
 
175
 
176
  class ESMDBPEmbedder:
177
  def __init__(self, device):
178
  base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
179
  model_path = (
180
  Path(__file__).resolve().parent.parent
181
- / "pretrained"
182
- / "ESM-DBP"
183
- / "ESM-DBP.model"
184
  )
185
  checkpoint = torch.load(model_path, map_location="cpu")
186
  clean_sd = {}
@@ -196,17 +331,46 @@ class ESMDBPEmbedder:
196
  self.alphabet = alphabet
197
  self.batch_converter = alphabet.get_batch_converter()
198
  self.device = device
 
 
 
 
199
 
200
  def embed(self, seqs):
201
- batch = [(str(i), seq) for i, seq in enumerate(seqs)]
202
- _, _, toks = self.batch_converter(batch)
203
- toks = toks.to(self.device)
204
- with torch.no_grad():
205
- out = self.model(toks, repr_layers=[33], return_contacts=False)
206
- reps = out["representations"][33]
207
- # skip start/end tokens
208
- return reps[:, 1:-1].mean(1).cpu().numpy()
209
-
 
 
 
 
210
 
211
  class GPNEmbedder:
212
  def __init__(self, device):
@@ -219,14 +383,16 @@ class GPNEmbedder:
219
 
220
  def embed(self, seqs):
221
  inputs = self.tokenizer(
222
- seqs, return_tensors="pt", padding=True, truncation=True
 
 
 
223
  ).to(self.device)
224
 
225
  with torch.no_grad():
226
  last_hidden = self.model(**inputs).last_hidden_state
227
  return last_hidden.mean(dim=1).cpu().numpy()
228
 
229
-
230
  class ProGenEmbedder:
231
  def __init__(self, device):
232
  model_name = "jinyuan22/ProGen2-base"
@@ -236,102 +402,137 @@ class ProGenEmbedder:
236
 
237
  def embed(self, seqs):
238
  inputs = self.tokenizer(
239
- seqs, return_tensors="pt", padding=True, truncation=True
 
 
 
240
  ).to(self.device)
241
  with torch.no_grad():
242
  last_hidden = self.model(**inputs).last_hidden_state
243
  return last_hidden.mean(dim=1).cpu().numpy()
244
 
245
-
246
  # ---- main pipeline ----
247
 
248
-
249
  def get_embedder(name, device, for_dna=True):
250
  name = name.lower()
251
  if for_dna:
252
- if name == "caduceus":
253
- return CaduceusEmbedder(device)
254
- if name == "dnabert":
255
- return DNABertEmbedder(device)
256
- if name == "nucleotide":
257
- return NucleotideTransformerEmbedder(device)
258
- if name == "gpn":
259
- return GPNEmbedder(device)
260
  else:
261
- if name in ("esm",):
262
- return ESMEmbedder(device)
263
- if name in ("esm-dbp", "esm_dbp"):
264
- return ESMDBPEmbedder(device)
265
- if name == "progen":
266
- return ProGenEmbedder(device)
267
  raise ValueError(f"Unknown model {name} (for_dna={for_dna})")
268
 
269
 
 
 
 
 
270
  def embed_and_save(seqs, ids, embedder, out_path):
271
  embs = embedder.embed(seqs)
272
- np.save(out_path, embs)
273
- with open(out_path.with_suffix(".ids"), "w") as f:
274
- f.write("\n".join(ids))
 
 
 
 
275
 
276
 
277
- if __name__ == "__main__":
278
 
279
  p = argparse.ArgumentParser()
280
- p.add_argument(
281
- "--genome-json-dir",
282
- default="data_files/raw/genomes/hg38",
283
- help="dir of UCSC JSONs",
284
- )
285
- p.add_argument(
286
- "--skip-dna",
287
- action="store_true",
288
- help="if set, skip the chromosome embedding step",
289
- ) # if glm embeddings successful but not plm embeddings
290
- p.add_argument("--tf-fasta", required=True, help="input TF FASTA file")
291
- p.add_argument("--chrom-model", default="caduceus")
292
- p.add_argument("--tf-model", default="esm-dbp")
293
- p.add_argument(
294
- "--out-dir", default="data_files/processed/tfclust/hg38_tf/embeddings"
295
- )
296
- p.add_argument("--device", default="cpu")
297
  args = p.parse_args()
298
 
299
  os.makedirs(args.out_dir, exist_ok=True)
300
  device = args.device
301
 
302
  if not args.skip_dna:
303
- # Load only primary chromosome JSONs (chr1–22, X, Y, M)
304
- genome_dir = Path(args.genome_json_dir)
305
- chrom_seqs, chrom_ids = [], []
306
- primary_pattern = re.compile(r"^hg38_chr(?:[1-9]|1[0-9]|2[0-2]|X|Y|M)\.json$")
307
- for j in sorted(genome_dir.iterdir()):
308
- if not primary_pattern.match(j.name):
309
- continue
310
- data = json.loads(j.read_text())
311
- seq = data.get("dna") or data.get("sequence")
312
- chrom = data.get("chrom") or j.stem.split("_")[-1]
313
- chrom_seqs.append(seq)
314
- chrom_ids.append(chrom)
315
- ########################
316
- cutoff = CaduceusEmbedder(device).chunk_size
317
- long_chroms = [
318
- (chrom, len(seq))
319
- for chrom, seq in zip(chrom_ids, chrom_seqs)
320
- if len(seq) > cutoff
321
- ]
322
- if long_chroms:
323
- print("⚠️ Chromosomes exceeding Caduceus max tokens ({}):".format(cutoff))
324
- for chrom, L in long_chroms:
325
- print(f" {chrom}: {L} bases")
 
 
 
 
326
  else:
327
- print("All chromosomes Caduceus limit ({}).".format(cutoff))
328
 
329
- ####################
330
- chrom_embedder = get_embedder(args.chrom_model, device, for_dna=True)
331
- out_chrom = Path(args.out_dir) / f"chrom_{args.chrom_model}.npy"
332
- embed_and_save(chrom_seqs, chrom_ids, chrom_embedder, out_chrom)
333
 
334
- # Load TF sequences
335
  tf_seqs, tf_ids = [], []
336
  for record in SeqIO.parse(args.tf_fasta, "fasta"):
337
  tf_ids.append(record.id)
@@ -342,4 +543,4 @@ if __name__ == "__main__":
342
  out_tf = Path(args.out_dir) / f"tf_{args.tf_model}.npy"
343
  embed_and_save(tf_seqs, tf_ids, tf_embedder, out_tf)
344
 
345
- print("Done.")
 
14
  --out-dir ../data_files/processed/tfclust/hg38_tf/embeddings \
15
  --device cuda
16
  """
 
17
  import os
18
  import re
19
  import argparse
 
28
 
29
  # ---- model wrappers ----
30
 
 
31
  class CaduceusEmbedder:
32
  def __init__(self, device, chunk_size=131_072, overlap=0):
33
  """
 
39
  self.tokenizer = AutoTokenizer.from_pretrained(
40
  model_name, trust_remote_code=True
41
  )
42
+ self.model = AutoModel.from_pretrained(
43
+ model_name, trust_remote_code=True
44
+ ).to(device).eval()
45
+ self.device = device
 
 
46
  self.chunk_size = chunk_size
47
+ self.step = chunk_size - overlap
48
 
49
  def embed(self, seqs):
50
  """
51
  seqs: List[str] of DNA sequences (each <= chunk_size for this test)
52
  returns: np.ndarray of shape (N, L, D), raw per‐token embeddings
53
  """
54
+ # outputs = []
55
+ # for seq in seqs:
56
+ # # --- new: raw per‐token embeddings in one shot ---
57
+ # toks = self.tokenizer(
58
+ # seq,
59
+ # return_tensors="pt",
60
+ # padding=False,
61
+ # truncation=True,
62
+ # max_length=self.chunk_size
63
+ # ).to(self.device)
64
+ # with torch.no_grad():
65
+ # out = self.model(**toks).last_hidden_state # (1, L, D)
66
+ # outputs.append(out.cpu().numpy()[0]) # (L, D)
67
+
68
+ # return np.stack(outputs, axis=0) # (N, L, D)
69
  outputs = []
70
  for seq in seqs:
 
 
 
 
 
 
71
  toks = self.tokenizer(
72
  seq,
73
  return_tensors="pt",
74
  padding=False,
75
  truncation=True,
76
+ max_length=self.chunk_size
77
  ).to(self.device)
78
  with torch.no_grad():
79
  out = self.model(**toks).last_hidden_state # (1, L, D)
80
+ outputs.append(out.cpu().numpy()[0]) # (L, D)
81
+ return outputs # list of variable-length (L_i, D) arrays
82
 
 
83
 
84
  def benchmark(self, lengths=None):
85
  """
 
101
  t1 = time.perf_counter()
102
  print(f" length={sz:6,d} time={(t1-t0)*1000:7.1f} ms")
103
 
104
+ class SegmentNTEmbedder:
105
+ def __init__(self, device):
106
+ self.tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True)
107
+ self.model = AutoModel.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True).to(device).eval()
108
+ self.device = device
109
+
110
+ def _adjust_length(self, input_ids):
111
+ bs, L = input_ids.shape
112
+ excl = L - 1
113
+ remainder = (excl) % 4
114
+ if remainder != 0:
115
+ pad_needed = 4 - remainder
116
+ pad_tensor = torch.full((bs, pad_needed), self.tokenizer.pad_token_id, dtype=input_ids.dtype, device=input_ids.device)
117
+ input_ids = torch.cat([input_ids, pad_tensor], dim=1)
118
+ return input_ids
119
+
120
+ def embed(self, seqs, batch_size=16):
121
+ """
122
+ seqs: List[str]
123
+ Returns: np.ndarray of shape (N, D)
124
+ """
125
+ all_embeddings = []
126
+ for i in range(0, len(seqs), batch_size):
127
+ batch_seqs = seqs[i : i + batch_size]
128
+ encoded = self.tokenizer.batch_encode_plus(
129
+ batch_seqs,
130
+ return_tensors="pt",
131
+ padding=True,
132
+ truncation=True,
133
+ )
134
+ input_ids = encoded["input_ids"].to(self.device) # (B, L)
135
+ attention_mask = input_ids != self.tokenizer.pad_token_id
136
+
137
+ input_ids = self._adjust_length(input_ids)
138
+ attention_mask = (input_ids != self.tokenizer.pad_token_id)
139
+
140
+ with torch.no_grad():
141
+ outs = self.model(
142
+ input_ids,
143
+ attention_mask=attention_mask,
144
+ output_hidden_states=True,
145
+ return_dict=True,
146
+ )
147
+ if hasattr(outs, "hidden_states") and outs.hidden_states is not None:
148
+ last_hidden = outs.hidden_states[-1] # (B, L, D)
149
+ else:
150
+ last_hidden = outs.last_hidden_state # fallback
151
+
152
+ # Exclude CLS token if present (assume first token) and pool
153
+ pooled = last_hidden[:, 1:, :].mean(dim=1) # (B, D)
154
+ all_embeddings.append(pooled.cpu().numpy())
155
+
156
+ # release fragmentation
157
+ torch.cuda.empty_cache()
158
+
159
+ return np.vstack(all_embeddings) # (N, D)
160
+
161
 
162
  class DNABertEmbedder:
163
  def __init__(self, device):
164
+ self.tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True)
165
+ self.model = AutoModel.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True).to(device)
166
+ self.device = device
 
 
 
 
167
 
168
  def embed(self, seqs):
169
  embs = []
170
  for s in seqs:
171
+ tokens = self.tokenizer(s, return_tensors="pt", padding=True)["input_ids"].to(self.device)
 
 
172
  with torch.no_grad():
173
  out = self.model(tokens).last_hidden_state.mean(1)
174
  embs.append(out.cpu().numpy())
175
  return np.vstack(embs)
176
 
 
177
  class NucleotideTransformerEmbedder:
178
  def __init__(self, device):
179
  # HF “feature-extraction” returns a list of (L, D) arrays for each input
 
181
  self.pipe = pipeline(
182
  "feature-extraction",
183
  model="InstaDeepAI/nucleotide-transformer-500m-1000g",
184
+ device= -1 if device=="cpu" else 0 # HF uses -1 for CPU, 0 for GPU
 
 
185
  )
186
 
187
  def embed(self, seqs):
 
191
  """
192
  all_embeddings = self.pipe(seqs, truncation=True, padding=True)
193
  # all_embeddings is a List of shape (L, D) arrays
194
+ pooled = [ np.mean(x, axis=0) for x in all_embeddings ]
195
+ return np.vstack(pooled)
196
+
197
+ # class ESMEmbedder:
198
+ # def __init__(self, device):
199
+ # self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
200
+ # self.batch_converter = self.alphabet.get_batch_converter()
201
+ # self.model.to(device).eval()
202
+ # self.device = device
203
+
204
+ # def embed(self, seqs):
205
+ # batch = [(str(i), seq) for i, seq in enumerate(seqs)]
206
+ # _, _, toks = self.batch_converter(batch)
207
+ # toks = toks.to(self.device)
208
+ # with torch.no_grad():
209
+ # results = self.model(toks, repr_layers=[33], return_contacts=False)
210
+ # reps = results["representations"][33]
211
+ # return reps[:, 1:-1].mean(1).cpu().numpy()
212
 
213
 
214
  class ESMEmbedder:
215
+ def __init__(self, device, model_name="esm2_t33_650M_UR50D"):
216
+ # Try to load the specified ESM-2 model; fallback to esm1b if missing
217
+ self.device = device
218
+ try:
219
+ self.model, self.alphabet = getattr(esm.pretrained, model_name)()
220
+ self.is_esm2 = model_name.lower().startswith("esm2")
221
+ except AttributeError:
222
+ # fallback to ESM-1b
223
+ self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
224
+ self.is_esm2 = False
225
  self.batch_converter = self.alphabet.get_batch_converter()
226
  self.model.to(device).eval()
227
+ # determine max length: esm2 models vary; use default 1024 for esm1b
228
+ self.max_len = 4096 if self.is_esm2 else 1024 # adjust if your esm2 variant has explicit limit
229
+ # for chunking: reserve 2 tokens if model uses BOS/EOS
230
+ self.chunk_size = self.max_len - 2
231
+ self.overlap = self.chunk_size // 4 # 25% overlap to smooth boundaries
232
 
233
+ def _chunk_sequence(self, seq):
234
+ """
235
+ Return list of possibly overlapping chunks of seq, each <= chunk_size.
236
+ """
237
+ if len(seq) <= self.chunk_size:
238
+ return [seq]
239
+ step = self.chunk_size - self.overlap
240
+ chunks = []
241
+ for i in range(0, len(seq), step):
242
+ chunk = seq[i : i + self.chunk_size]
243
+ if not chunk:
244
+ break
245
+ chunks.append(chunk)
246
+ return chunks
247
 
248
+ def embed(self, seqs):
249
+ """
250
+ seqs: List[str] of protein sequences.
251
+ Returns: np.ndarray of shape (N, D) pooled per-sequence embeddings.
252
+ """
253
+ all_embeddings = []
254
+ for i, seq in enumerate(seqs):
255
+ chunks = self._chunk_sequence(seq)
256
+ chunk_vecs = []
257
+ # process chunks in batch if small number, else sequentially
258
+ for chunk in chunks:
259
+ batch = [(str(i), chunk)]
260
+ _, _, toks = self.batch_converter(batch)
261
+ toks = toks.to(self.device)
262
+ with torch.no_grad():
263
+ results = self.model(toks, repr_layers=[33], return_contacts=False)
264
+ reps = results["representations"][33] # (1, L, D)
265
+ # remove BOS/EOS if present: take 1:-1 if length permits
266
+ if reps.size(1) > 2:
267
+ rep = reps[:, 1:-1].mean(1) # (1, D)
268
+ else:
269
+ rep = reps.mean(1) # fallback
270
+ chunk_vecs.append(rep.squeeze(0)) # (D,)
271
+ if len(chunk_vecs) == 1:
272
+ seq_vec = chunk_vecs[0]
273
+ else:
274
+ # average chunk vectors
275
+ stacked = torch.stack(chunk_vecs, dim=0) # (num_chunks, D)
276
+ seq_vec = stacked.mean(0)
277
+ all_embeddings.append(seq_vec.cpu().numpy())
278
+ return np.vstack(all_embeddings) # (N, D)
279
+
280
+
281
+ # class ESMDBPEmbedder:
282
+ # def __init__(self, device):
283
+ # base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
284
+ # model_path = (
285
+ # Path(__file__).resolve().parent.parent
286
+ # / "pretrained" / "ESM-DBP" / "ESM-DBP.model"
287
+ # )
288
+ # checkpoint = torch.load(model_path, map_location="cpu")
289
+ # clean_sd = {}
290
+ # for k, v in checkpoint.items():
291
+ # clean_sd[k.replace("module.", "")] = v
292
+ # result = base_model.load_state_dict(clean_sd, strict=False)
293
+ # if result.missing_keys:
294
+ # print(f"[ESMDBP] missing keys: {result.missing_keys}")
295
+ # if result.unexpected_keys:
296
+ # print(f"[ESMDBP] unexpected keys: {result.unexpected_keys}")
297
+
298
+ # self.model = base_model.to(device).eval()
299
+ # self.alphabet = alphabet
300
+ # self.batch_converter = alphabet.get_batch_converter()
301
+ # self.device = device
302
+
303
+ # def embed(self, seqs):
304
+ # batch = [(str(i), seq) for i, seq in enumerate(seqs)]
305
+ # _, _, toks = self.batch_converter(batch)
306
+ # toks = toks.to(self.device)
307
+ # with torch.no_grad():
308
+ # out = self.model(toks, repr_layers=[33], return_contacts=False)
309
+ # reps = out["representations"][33]
310
+ # # skip start/end tokens
311
+ # return reps[:, 1:-1].mean(1).cpu().numpy()
312
 
313
  class ESMDBPEmbedder:
314
  def __init__(self, device):
315
  base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
316
  model_path = (
317
  Path(__file__).resolve().parent.parent
318
+ / "pretrained" / "ESM-DBP" / "ESM-DBP.model"
 
 
319
  )
320
  checkpoint = torch.load(model_path, map_location="cpu")
321
  clean_sd = {}
 
331
  self.alphabet = alphabet
332
  self.batch_converter = alphabet.get_batch_converter()
333
  self.device = device
334
+ self.max_len = 1024 # same limit as esm1b
335
+ self.chunk_size = self.max_len - 2
336
+ self.overlap = self.chunk_size // 4
337
+
338
+ def _chunk_sequence(self, seq):
339
+ if len(seq) <= self.chunk_size:
340
+ return [seq]
341
+ step = self.chunk_size - self.overlap
342
+ chunks = []
343
+ for i in range(0, len(seq), step):
344
+ chunk = seq[i : i + self.chunk_size]
345
+ if not chunk:
346
+ break
347
+ chunks.append(chunk)
348
+ return chunks
349
 
350
  def embed(self, seqs):
351
+ all_embeddings = []
352
+ for i, seq in enumerate(seqs):
353
+ chunks = self._chunk_sequence(seq)
354
+ chunk_vecs = []
355
+ for chunk in chunks:
356
+ batch = [(str(i), chunk)]
357
+ _, _, toks = self.batch_converter(batch)
358
+ toks = toks.to(self.device)
359
+ with torch.no_grad():
360
+ out = self.model(toks, repr_layers=[33], return_contacts=False)
361
+ reps = out["representations"][33]
362
+ if reps.size(1) > 2:
363
+ rep = reps[:, 1:-1].mean(1)
364
+ else:
365
+ rep = reps.mean(1)
366
+ chunk_vecs.append(rep.squeeze(0))
367
+ if len(chunk_vecs) == 1:
368
+ seq_vec = chunk_vecs[0]
369
+ else:
370
+ stacked = torch.stack(chunk_vecs, dim=0)
371
+ seq_vec = stacked.mean(0)
372
+ all_embeddings.append(seq_vec.cpu().numpy())
373
+ return np.vstack(all_embeddings)
374
 
375
  class GPNEmbedder:
376
  def __init__(self, device):
 
383
 
384
  def embed(self, seqs):
385
  inputs = self.tokenizer(
386
+ seqs,
387
+ return_tensors="pt",
388
+ padding=True,
389
+ truncation=True
390
  ).to(self.device)
391
 
392
  with torch.no_grad():
393
  last_hidden = self.model(**inputs).last_hidden_state
394
  return last_hidden.mean(dim=1).cpu().numpy()
395
 
 
396
  class ProGenEmbedder:
397
  def __init__(self, device):
398
  model_name = "jinyuan22/ProGen2-base"
 
402
 
403
  def embed(self, seqs):
404
  inputs = self.tokenizer(
405
+ seqs,
406
+ return_tensors="pt",
407
+ padding=True,
408
+ truncation=True
409
  ).to(self.device)
410
  with torch.no_grad():
411
  last_hidden = self.model(**inputs).last_hidden_state
412
  return last_hidden.mean(dim=1).cpu().numpy()
413
 
 
414
  # ---- main pipeline ----
415
 
 
416
  def get_embedder(name, device, for_dna=True):
417
  name = name.lower()
418
  if for_dna:
419
+ if name=="caduceus": return CaduceusEmbedder(device)
420
+ if name=="dnabert": return DNABertEmbedder(device)
421
+ if name=="nucleotide": return NucleotideTransformerEmbedder(device)
422
+ if name=="gpn": return GPNEmbedder(device)
423
+ if name=="segmentnt": return SegmentNTEmbedder(device)
 
 
 
424
  else:
425
+ if name in ("esm",): return ESMEmbedder(device)
426
+ if name in ("esm-dbp","esm_dbp"): return ESMDBPEmbedder(device)
427
+ if name=="progen": return ProGenEmbedder(device)
 
 
 
428
  raise ValueError(f"Unknown model {name} (for_dna={for_dna})")
429
 
430
 
431
+ def pad_token_embeddings(list_of_arrays, pad_value=0.0):
432
+ """
433
+ list_of_arrays: list of (L_i, D) numpy arrays
434
+ Returns:
435
+ padded: (N, L_max, D) array
436
+ mask: (N, L_max) boolean array where True = real token, False = padding
437
+ """
438
+ N = len(list_of_arrays)
439
+ D = list_of_arrays[0].shape[1]
440
+ L_max = max(arr.shape[0] for arr in list_of_arrays)
441
+ padded = np.full((N, L_max, D), pad_value, dtype=list_of_arrays[0].dtype)
442
+ mask = np.zeros((N, L_max), dtype=bool)
443
+ for i, arr in enumerate(list_of_arrays):
444
+ L = arr.shape[0]
445
+ padded[i, :L] = arr
446
+ mask[i, :L] = True
447
+ return padded, mask
448
+
449
  def embed_and_save(seqs, ids, embedder, out_path):
450
  embs = embedder.embed(seqs)
451
+
452
+ # Decide whether we got variable-length per-token outputs (list of (L, D))
453
+ is_variable_token = isinstance(embs, (list, tuple)) and len(embs) > 0 and hasattr(embs[0], "shape") and embs[0].ndim == 2
454
+
455
+ if is_variable_token:
456
+ # pad to (N, L_max, D) + mask
457
+ padded, mask = pad_token_embeddings(embs)
458
+ # Save both embeddings and mask together in an .npz for convenience
459
+ np.savez_compressed(out_path.with_suffix(".caduceus.npz"),
460
+ embeddings=padded,
461
+ mask=mask,
462
+ ids=np.array(ids, dtype=object))
463
+ else:
464
+ # fixed shape output, e.g., pooled (N, D)
465
+ array = np.vstack(embs) if isinstance(embs, list) else embs
466
+ np.save(out_path, array)
467
+ with open(out_path.with_suffix(".ids"), "w") as f:
468
+ f.write("\n".join(ids))
469
 
470
 
471
+ if __name__=="__main__":
472
 
473
  p = argparse.ArgumentParser()
474
+ p.add_argument("--peak-fasta", default="binding_peaks_unique.fa", help="FASTA of deduplicated binding peak sequences; if present this is used for DNA embedding instead of genome JSONs")
475
+ p.add_argument("--genome-json-dir", default=None, help="(fallback) directory of UCSC JSONs for full chromosome embedding if peak FASTA is missing or you explicitly want chromosomes")
476
+ p.add_argument("--skip-dna", action="store_true", help="if set, skip the chromosome embedding step") #if glm embeddings successful but not plm embeddings
477
+ p.add_argument("--tf-fasta", required=True, help="input TF FASTA file")
478
+ p.add_argument("--chrom-model", default="caduceus")
479
+ p.add_argument("--tf-model", default="esm-dbp")
480
+ p.add_argument("--out-dir", default="data_files/processed/tfclust/hg38_tf/embeddings")
481
+ p.add_argument("--device", default="cpu")
482
  args = p.parse_args()
483
 
484
  os.makedirs(args.out_dir, exist_ok=True)
485
  device = args.device
486
 
487
  if not args.skip_dna:
488
+ peak_fasta = Path(args.peak_fasta)
489
+ if peak_fasta.exists():
490
+ # Load peak sequences from FASTA
491
+ from Bio import SeqIO
492
+
493
+ peak_seqs = []
494
+ peak_ids = []
495
+ for rec in SeqIO.parse(peak_fasta, "fasta"):
496
+ peak_ids.append(rec.id)
497
+ peak_seqs.append(str(rec.seq))
498
+ print(f"Embedding {len(peak_seqs)} binding peak sequences from {peak_fasta}", flush=True)
499
+ dna_embedder = get_embedder(args.chrom_model, device, for_dna=True)
500
+ out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy"
501
+ embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks)
502
+ elif args.genome_json_dir:
503
+ # Legacy: load full chromosomes from JSONs (chr1–22, X, Y, M)
504
+ genome_dir = Path(args.genome_json_dir)
505
+ chrom_seqs, chrom_ids = [], []
506
+ primary_pattern = re.compile(r"^hg38_chr(?:[1-9]|1[0-9]|2[0-2]|X|Y|M)\.json$")
507
+ for j in sorted(genome_dir.iterdir()):
508
+ if not primary_pattern.match(j.name):
509
+ continue
510
+ data = json.loads(j.read_text())
511
+ seq = data.get("dna") or data.get("sequence")
512
+ chrom = data.get("chrom") or j.stem.split("_")[-1]
513
+ chrom_seqs.append(seq)
514
+ chrom_ids.append(chrom)
515
+ cutoff = CaduceusEmbedder(device).chunk_size
516
+ long_chroms = [
517
+ (chrom, len(seq))
518
+ for chrom, seq in zip(chrom_ids, chrom_seqs)
519
+ if len(seq) > cutoff
520
+ ]
521
+ if long_chroms:
522
+ print("⚠️ Chromosomes exceeding Caduceus max tokens ({}):".format(cutoff))
523
+ for chrom, L in long_chroms:
524
+ print(f" {chrom}: {L} bases")
525
+ else:
526
+ print("All chromosomes ≤ Caduceus limit ({}).".format(cutoff))
527
+
528
+ chrom_embedder = get_embedder(args.chrom_model, device, for_dna=True)
529
+ out_chrom = Path(args.out_dir) / f"chrom_{args.chrom_model}.npy"
530
+ embed_and_save(chrom_seqs, chrom_ids, chrom_embedder, out_chrom)
531
  else:
532
+ raise ValueError("No input for DNA embedding: provide a peak FASTA (default binding_peaks_unique.fa) or set --genome-json-dir for chromosome JSONs.")
533
534
 
535
+ # Load TF sequences
536
  tf_seqs, tf_ids = [], []
537
  for record in SeqIO.parse(args.tf_fasta, "fasta"):
538
  tf_ids.append(record.id)
 
543
  out_tf = Path(args.out_dir) / f"tf_{args.tf_model}.npy"
544
  embed_and_save(tf_seqs, tf_ids, tf_embedder, out_tf)
545
 
546
+ print("Done.")
dpacman/classifier/model_tmp/extract_tf_symbols.py ADDED
@@ -0,0 +1,27 @@
1
+ #!/usr/bin/env python3
2
+ import pandas as pd
3
+ from pathlib import Path
4
+
5
+ FINAL_CSV = Path("/home/a03-akrishna/DPACMAN/data_files/processed/final.csv")
6
+ OUT_SYMBOLS = Path("tf_symbols.txt")
7
+
8
+ def normalize_tf(tf_id: str) -> str:
9
+ return tf_id.split("_seq")[0].upper()
10
+
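+ # e.g. normalize_tf("Ctcf_seq2") -> "CTCF"; anything after the first "_seq" is dropped
+ # and the symbol is upper-cased (illustrative value, not a real row from final.csv).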
11
+ def main():
12
+ df = pd.read_csv(FINAL_CSV, dtype=str)
13
+ if "TF_id" not in df.columns:
14
+ raise RuntimeError("final.csv missing TF_id column")
15
+ tf_raw = df["TF_id"].dropna().unique().tolist()
16
+ normalized = sorted({normalize_tf(t) for t in tf_raw})
17
+ print(f"Unique raw TF_id count: {len(tf_raw)}")
18
+ print(f"Unique normalized TF symbols: {len(normalized)}")
19
+ with open(OUT_SYMBOLS, "w") as f:
20
+ for s in normalized:
21
+ f.write(s + "\n")
22
+ print(f"Wrote normalized TF symbols to {OUT_SYMBOLS}")
23
+ # Optional: show sample
24
+ print("Sample symbols:", normalized[:50])
25
+
26
+ if __name__ == "__main__":
27
+ main()
dpacman/classifier/model_tmp/make_pair_list.py ADDED
@@ -0,0 +1,220 @@
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import numpy as np
4
+ import pandas as pd
5
+ from pathlib import Path
6
+ import random
7
+ import sys
8
+
9
+ def read_ids_file(p):
10
+ p = Path(p)
11
+ if not p.exists():
12
+ raise FileNotFoundError(f"IDs file not found: {p}")
13
+ return [line.strip() for line in p.open() if line.strip()]
14
+
15
+ def split_embeddings(emb_path, ids_path, out_dir, prefix):
16
+ out_dir = Path(out_dir)
17
+ out_dir.mkdir(parents=True, exist_ok=True)
18
+
19
+ if not Path(emb_path).exists():
20
+ raise FileNotFoundError(f"Embedding file not found: {emb_path}")
21
+ if not Path(ids_path).exists():
22
+ raise FileNotFoundError(f"IDs file not found: {ids_path}")
23
+
24
+ if emb_path.endswith(".npz"):
25
+ data = np.load(emb_path, allow_pickle=True)
26
+ if "embeddings" in data:
27
+ emb = data["embeddings"]
28
+ else:
29
+ raise ValueError(f"{emb_path} missing 'embeddings' key")
30
+ else:
31
+ emb = np.load(emb_path)
32
+
33
+ ids = read_ids_file(ids_path)
34
+ if len(ids) != emb.shape[0]:
35
+ print(f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}", file=sys.stderr)
36
+
37
+ mapping = {}
38
+ for i, ident in enumerate(ids):
39
+ if i >= emb.shape[0]:
40
+ print(f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr)
41
+ continue
42
+ arr = emb[i]
43
+ out_file = out_dir / f"{prefix}_{ident}.npy"
44
+ np.save(out_file, arr)
45
+ mapping[ident] = str(out_file)
46
+ return mapping
47
+
48
+ def extract_symbol_from_tf_id(full_id: str) -> str:
49
+ """
50
+ Given a TF embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN',
51
+ return the gene symbol uppercase (e.g., 'ZBTB5').
52
+ """
53
+ if "|" in full_id:
54
+ try:
55
+ # format sp|Accession|SYMBOL_HUMAN
56
+ genepart = full_id.split("|")[2]
57
+ except IndexError:
58
+ genepart = full_id
59
+ else:
60
+ genepart = full_id
61
+ symbol = genepart.split("_")[0]
62
+ return symbol.upper()
63
+
64
+ def build_tf_symbol_map(tf_map):
65
+ """
66
+ Build mapping gene_symbol -> list of embedding paths.
67
+ """
68
+ symbol_map = {}
69
+ for full_id, path in tf_map.items():
70
+ symbol = extract_symbol_from_tf_id(full_id)
71
+ symbol_map.setdefault(symbol, []).append(path)
72
+ return symbol_map
73
+
74
+ def tf_key_from_path(path: str) -> str:
75
+ """
76
+ Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract normalized symbol 'ZBTB5'.
77
+ """
78
+ stem = Path(path).stem # e.g., tf_sp|O15062|ZBTB5_HUMAN
79
+ # remove leading prefix if present (tf_)
80
+ if "_" in stem:
81
+ _, rest = stem.split("_", 1)
82
+ else:
83
+ rest = stem
84
+ return extract_symbol_from_tf_id(rest)
85
+
86
+ def dna_key_from_path(path: str) -> str:
87
+ """
88
+ Given .../dna_peak42.npy -> 'peak42'
89
+ """
90
+ stem = Path(path).stem
91
+ if "_" in stem:
92
+ _, rest = stem.split("_", 1)
93
+ else:
94
+ rest = stem
95
+ return rest
96
+
97
+ def main():
98
+ parser = argparse.ArgumentParser(
99
+ description="Build TF-DNA pair list from final.csv with gene-symbol normalization for TFs."
100
+ )
101
+ parser.add_argument("--final_csv", required=True, help="final.csv with TF_id and dna_sequence")
102
+ parser.add_argument("--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)")
103
+ parser.add_argument("--dna_ids", required=True, help="IDs file for DNA embeddings (e.g., peak*.ids)")
104
+ parser.add_argument("--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)")
105
+ parser.add_argument("--tf_ids", required=True, help="IDs file for TF embeddings (e.g., sp|...|... ids)")
106
+ parser.add_argument("--out_dir", required=True, help="Output directory")
107
+ parser.add_argument("--neg_per_positive", type=int, default=2, help="Negatives per positive (half same-TF, half same-DNA)")
108
+ parser.add_argument("--seed", type=int, default=42)
109
+ args = parser.parse_args()
110
+
111
+ random.seed(args.seed)
112
+ out_dir = Path(args.out_dir)
113
+ out_dir.mkdir(parents=True, exist_ok=True)
114
+
115
+ # Load final.csv
116
+ df = pd.read_csv(args.final_csv, dtype=str)
117
+ if "TF_id" not in df.columns or "dna_sequence" not in df.columns:
118
+ raise RuntimeError("final.csv must have columns TF_id and dna_sequence")
119
+
120
+ # Assign dna_id (unique per dna_sequence)
121
+ unique_seqs = df["dna_sequence"].drop_duplicates().tolist()
122
+ seq_to_id = {seq: f"peak{i}" for i, seq in enumerate(unique_seqs)}
123
+ df["dna_id"] = df["dna_sequence"].map(seq_to_id)
124
+ enriched_csv = out_dir / "final_with_dna_id.csv"
125
+ df.to_csv(enriched_csv, index=False)
126
+ print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}")
127
+
128
+ # Split embeddings into per-item files
129
+ print(f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}")
130
+ dna_map = split_embeddings(args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna")
131
+ print(f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})")
132
+ print(f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}")
133
+ tf_map = split_embeddings(args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf")
134
+ print(f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})")
135
+
136
+ # Build gene-symbol normalized map
137
+ tf_symbol_map = build_tf_symbol_map(tf_map)
138
+ print(f"[i] TF symbol map keys (sample): {list(tf_symbol_map.keys())[:30]}")
139
+
140
+ # Diagnostic overlaps
141
+ norm_tf_in_final = set(t.split("_seq")[0].upper() for t in df["TF_id"].unique())
142
+ available_tf_symbols = set(tf_symbol_map.keys())
143
+ intersect_tf = norm_tf_in_final & available_tf_symbols
144
+ print(f"[i] Unique normalized TF symbols in final.csv: {len(norm_tf_in_final)}")
145
+ print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}")
146
+ print(f"[i] Intersection count: {len(intersect_tf)}")
147
+ if len(intersect_tf) == 0:
148
+ print("[ERROR] No overlap between normalized TF_id and TF embedding symbols.", file=sys.stderr)
149
+ print("Sample normalized TFs from final.csv:", sorted(list(norm_tf_in_final))[:30], file=sys.stderr)
150
+ print("Sample available TF symbols:", sorted(list(available_tf_symbols))[:30], file=sys.stderr)
151
+ sys.exit(1)
152
+
153
+ dna_ids_final = set(df["dna_id"].unique())
154
+ available_dna_ids = set(dna_map.keys())
155
+ intersect_dna = dna_ids_final & available_dna_ids
156
+ print(f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. Intersection: {len(intersect_dna)}")
157
+ if len(intersect_dna) == 0:
158
+ print("[ERROR] No overlap on DNA ids.", file=sys.stderr)
159
+ sys.exit(1)
160
+
161
+ # Build positive pairs
162
+ positives = []
163
+ for _, row in df.iterrows():
164
+ tf_raw = row["TF_id"]
165
+ tf_symbol = tf_raw.split("_seq")[0].upper()
166
+ dnaid = row["dna_id"]
167
+ if tf_symbol not in tf_symbol_map:
168
+ continue
169
+ if dnaid not in dna_map:
170
+ continue
171
+ # pick the first embedding for that symbol
172
+ tf_embedding_path = tf_symbol_map[tf_symbol][0]
173
+ positives.append((tf_embedding_path, dna_map[dnaid], 1))
174
+ print(f"[i] Constructed {len(positives)} positive pairs after TF symbol resolution")
175
+
176
+ if len(positives) == 0:
177
+ print("[ERROR] No positive pairs could be constructed; aborting.", file=sys.stderr)
178
+ sys.exit(1)
179
+
180
+ # Build negative samples
181
+ all_tf_symbols = sorted(tf_symbol_map.keys())
182
+ all_dnaids = sorted(dna_map.keys())
183
+ positive_set = set()
184
+ for tf_path, dna_path, _ in positives:
185
+ tf_key = tf_key_from_path(tf_path)
186
+ dna_key = dna_key_from_path(dna_path)
187
+ positive_set.add((tf_key, dna_key))
188
+
189
+ negatives = []
190
+ half = args.neg_per_positive // 2
191
+ for tf_path, dna_path, _ in positives:
192
+ tf_key = tf_key_from_path(tf_path)
193
+ dna_key = dna_key_from_path(dna_path)
194
+ # same TF, different DNA
195
+ for _ in range(half):
196
+ candidate_dna = random.choice(all_dnaids)
197
+ if candidate_dna == dna_key or (tf_key, candidate_dna) in positive_set:
198
+ continue
199
+ negatives.append((tf_path, dna_map[candidate_dna], 0))
200
+ # same DNA, different TF
201
+ for _ in range(half):
202
+ candidate_tf_symbol = random.choice(all_tf_symbols)
203
+ if candidate_tf_symbol == tf_key or (candidate_tf_symbol, dna_key) in positive_set:
204
+ continue
205
+ # pick its first embedding
206
+ candidate_tf_path = tf_symbol_map[candidate_tf_symbol][0]
207
+ negatives.append((candidate_tf_path, dna_path, 0)) # pair the new TF with this positive's DNA embedding
208
+
209
+ print(f"[i] Sampled {len(negatives)} negatives (neg_per_positive={args.neg_per_positive})")
210
+
211
+ # Write pair list
212
+ pair_list_path = out_dir / "pair_list.tsv"
213
+ with open(pair_list_path, "w") as f:
214
+ for binder_path, glm_path, label in positives + negatives:
215
+ # binder=TF, glm=DNA
216
+ f.write(f"{binder_path}\t{glm_path}\t{label}\n")
217
+ print(f"[i] Wrote {len(positives)+len(negatives)} examples to {pair_list_path}")
218
+
219
+ if __name__ == "__main__":
220
+ main()
dpacman/classifier/model_tmp/make_peak_fasta.py ADDED
@@ -0,0 +1,13 @@
1
+ import pandas as pd
2
+ from pathlib import Path
3
+
4
+ df = pd.read_csv("/home/a03-akrishna/DPACMAN/data_files/processed/final.csv", dtype=str) # adjust path if needed
5
+ # get unique sequences
6
+ uniq = df[["dna_sequence"]].drop_duplicates().reset_index(drop=True)
7
+ # make headers: e.g., peak0, peak1, ...
8
+ out_fa = Path("binding_peaks_unique.fa")
9
+ with open(out_fa, "w") as f:
10
+ for i, seq in enumerate(uniq["dna_sequence"]):
11
+ header = f">peak{i}"
12
+ f.write(f"{header}\n{seq}\n")
13
+ print(f"Wrote {len(uniq)} unique binding sequences to {out_fa}")
dpacman/classifier/model_tmp/model.py ADDED
@@ -0,0 +1,103 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ class LocalCNN(nn.Module):
5
+ def __init__(self, dim: int = 256, kernel_size: int = 3):
6
+ super().__init__()
7
+ padding = kernel_size // 2
8
+ self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding)
9
+ self.act = nn.GELU()
10
+ self.ln = nn.LayerNorm(dim)
11
+
12
+ def forward(self, x: torch.Tensor):
13
+ # x: (batch, L, dim)
14
+ out = self.conv(x.transpose(1, 2)) # → (batch, dim, L)
15
+ out = self.act(out)
16
+ out = out.transpose(1, 2) # → (batch, L, dim)
17
+ return self.ln(out + x) # residual
18
+
19
+ class CrossModalBlock(nn.Module):
20
+ def __init__(self, dim: int = 256, heads: int = 8):
21
+ super().__init__()
22
+ # self-attention for both sides
23
+ self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True)
24
+ self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True)
25
+ self.ln_b1 = nn.LayerNorm(dim)
26
+ self.ln_g1 = nn.LayerNorm(dim)
27
+
28
+ self.ffn_b = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
29
+ self.ffn_g = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
30
+ self.ln_b2 = nn.LayerNorm(dim)
31
+ self.ln_g2 = nn.LayerNorm(dim)
32
+
33
+ # cross attention (binder queries, glm keys/values)
34
+ self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
35
+ self.ln_c1 = nn.LayerNorm(dim)
36
+ self.ffn_c = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
37
+ self.ln_c2 = nn.LayerNorm(dim)
38
+
39
+ def forward(self, binder: torch.Tensor, glm: torch.Tensor):
40
+ """
41
+ binder: (batch, Lb, dim)
42
+ glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
43
+ returns: updated binder representation (batch, Lb, dim)
44
+ """
45
+ # binder self-attn + ffn
46
+ b = binder
47
+ b_sa, _ = self.sa_binder(b, b, b)
48
+ b = self.ln_b1(b + b_sa)
49
+ b_ff = self.ffn_b(b)
50
+ b = self.ln_b2(b + b_ff)
51
+
52
+ # glm self-attn + ffn
53
+ g = glm
54
+ g_sa, _ = self.sa_glm(g, g, g)
55
+ g = self.ln_g1(g + g_sa)
56
+ g_ff = self.ffn_g(g)
57
+ g = self.ln_g2(g + g_ff)
58
+
59
+ # cross-attention: binder queries glm
60
+ c_sa, _ = self.cross_attn(b, g, g)
61
+ c = self.ln_c1(b + c_sa)
62
+ c_ff = self.ffn_c(c)
63
+ c = self.ln_c2(c + c_ff)
64
+ return c # (batch, Lb, dim)
65
+
66
+ class BindPredictor(nn.Module):
67
+ def __init__(self,
68
+ input_dim: int = 256,
69
+ hidden_dim: int = 256,
70
+ heads: int = 8,
71
+ num_layers: int = 4,
72
+ use_local_cnn_on_glm: bool = True):
73
+ super().__init__()
74
+ self.proj_binder = nn.Linear(input_dim, hidden_dim)
75
+ self.proj_glm = nn.Linear(input_dim, hidden_dim)
76
+ self.use_local_cnn = use_local_cnn_on_glm
77
+ self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity()
78
+
79
+ self.layers = nn.ModuleList([
80
+ CrossModalBlock(hidden_dim, heads) for _ in range(num_layers)
81
+ ])
82
+
83
+ self.ln_out = nn.LayerNorm(hidden_dim)
84
+ self.head = nn.Sequential(
85
+ nn.Linear(hidden_dim, 1),
86
+ nn.Sigmoid()
87
+ )
88
+
89
+ def forward(self, binder_emb, glm_emb):
90
+ """
91
+ binder_emb, glm_emb: (batch, L, input_dim)
92
+ """
93
+ b = self.proj_binder(binder_emb) # (B, Lb, hidden_dim)
94
+ g = self.proj_glm(glm_emb) # (B, Lg, hidden_dim)
95
+ if self.use_local_cnn:
96
+ g = self.local_cnn(g) # local context injected
97
+
98
+ for layer in self.layers:
99
+ b = layer(b, g) # update binder with cross-modal info
100
+
101
+ pooled = b.mean(dim=1) # (B, hidden_dim)
102
+ out = self.ln_out(pooled)
103
+ return self.head(out).squeeze(-1) # (B,)
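+
+ # Shape smoke test (illustrative only; the dims below are assumptions, not project defaults):
+ # model = BindPredictor(input_dim=256, hidden_dim=256, heads=8, num_layers=2)
+ # b, g = torch.randn(2, 12, 256), torch.randn(2, 48, 256) # (batch, L_binder, D), (batch, L_glm, D)
+ # model(b, g).shape # -> torch.Size([2]), one binding probability per pair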
dpacman/classifier/model_tmp/prep_splits.py ADDED
@@ -0,0 +1,133 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import sys
4
+ import json
5
+ from sklearn.decomposition import TruncatedSVD
6
+ from sklearn.model_selection import train_test_split
7
+ from collections import Counter
8
+
9
+ def parse_pair_list(pair_list_path):
10
+ binder_paths, glm_paths, labels = [], [], []
11
+ with open(pair_list_path) as f:
12
+ for lineno, line in enumerate(f, start=1):
13
+ if not line.strip():
14
+ continue
15
+ parts = line.strip().split()
16
+ if len(parts) != 3:
17
+ print(f"[WARN] skipping malformed line {lineno}: {line.strip()}", file=sys.stderr)
18
+ continue
19
+ b, g, l = parts
20
+ try:
21
+ lab = int(l)
22
+ except ValueError:
23
+ print(f"[WARN] invalid label on line {lineno}: {l}", file=sys.stderr)
24
+ continue
25
+ binder_paths.append(b)
26
+ glm_paths.append(g)
27
+ labels.append(lab)
28
+ return binder_paths, glm_paths, labels
29
+
30
+ def build_tf_compressed_cache(binder_paths, target_dim=256):
31
+ """
32
+ Load all unique TF (binder) embeddings, fit reduction if needed, and return dict mapping path->(L, target_dim) array.
33
+ """
34
+ unique_paths = sorted(set(binder_paths))
35
+ print(f"[i] Found {len(unique_paths)} unique TF embedding files to compress.", flush=True)
36
+ # Load all embeddings to determine dimensionality
37
+ samples = []
38
+ for p in unique_paths:
39
+ arr = np.load(p)
40
+ samples.append(arr)
41
+ # Determine if reduction needed: assume all have same embedding width
42
+ first = samples[0]
43
+ orig_dim = first.shape[1] if first.ndim == 2 else 1
44
+ reduction_needed = (orig_dim != target_dim)
45
+ tf_cache = {}
46
+
47
+ if reduction_needed:
48
+ # Build matrix to fit SVD: we need a 2D matrix per embedding; if lengths vary we can't directly stack.
49
+ # We'll do reduction per sequence individually using TruncatedSVD on concatenated flattened features:
50
+ # Simplest: for variable lengths, reduce each embedding separately with a learned linear projection.
51
+ # Here we fit a single TruncatedSVD on the concatenation of all sequence tokens (flattened) by padding/truncating to a fixed length.
52
+ # To avoid complexity, use PCA-like linear projection learned via SVD on mean-pooled vectors:
53
+ pooled = []
54
+ for arr in samples:
55
+ if arr.ndim == 2:
56
+ pooled.append(arr.mean(axis=0)) # (orig_dim,)
57
+ else:
58
+ pooled.append(arr) # degenerate
59
+ pooled_mat = np.stack(pooled, axis=0) # (N, orig_dim)
60
+ print(f"[i] Fitting TruncatedSVD on TF pooled embeddings: {pooled_mat.shape} -> {target_dim}", flush=True)
61
+ svd = TruncatedSVD(n_components=target_dim, random_state=42)
62
+ reduced_pooled = svd.fit_transform(pooled_mat) # (N, target_dim)
63
+
64
+ # For each original embedding, project token-level vectors by multiplying token vector with svd.components_.T
65
+ # svd.components_: (target_dim, orig_dim) so projection matrix is (orig_dim, target_dim)
66
+ proj_mat = svd.components_.T # (orig_dim, target_dim)
67
+ for i, p in enumerate(unique_paths):
68
+ arr = samples[i] # shape (L, orig_dim)
69
+ if arr.ndim == 1:
70
+ arr2 = arr @ proj_mat # (target_dim,)
71
+ else:
72
+ # project each token: (L, orig_dim) @ (orig_dim, target_dim) -> (L, target_dim)
73
+ arr2 = arr @ proj_mat
74
+ tf_cache[p] = arr2 # reduced per-token representation
75
+ print("[i] Completed compression of TF embeddings.", flush=True)
76
+ else:
77
+ # already correct dim: just cache originals
78
+ print(f"[i] TF embeddings already {target_dim}-dimensional; skipping reduction.", flush=True)
79
+ for i, p in enumerate(unique_paths):
80
+ arr = samples[i]
81
+ tf_cache[p] = arr
82
+ return tf_cache
83
+
84
+ def main():
85
+ #df = pd.read_csv("../data_files/processed/fimo/ananya_aug4_2025_final.csv")
86
+
87
+ binder_paths, glm_paths, labels = parse_pair_list("../data_files/processed/fimo/ananya_aug4_2025_pair_list.tsv")
88
+
89
+ if len(labels) == 0:
90
+ print("[ERROR] No valid pairs parsed. Exiting.", file=sys.stderr)
91
+ sys.exit(1)
92
+
93
+ label_counts = Counter(labels)
94
+ print(f"[i] Total examples parsed: {len(labels)}. Label distribution: {label_counts}", flush=True)
95
+
96
+ # build compressed TF cache (reduces to 256 if needed)
97
+ #tf_compressed_cache = build_tf_compressed_cache(binder_paths, target_dim=256)
98
+
99
+ # Combine all data into one structure for easy splitting
100
+ data = list(zip(binder_paths, glm_paths, labels))
101
+
102
+ # First split: train vs temp (val+test)
103
+ train_data, temp_data = train_test_split(data, test_size=0.2, random_state=42)
104
+
105
+ # Second split: val vs test (50% of 20% → 10% each)
106
+ val_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=42)
107
+
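+ # e.g. 10,000 pairs split into 8,000 / 1,000 / 1,000: the first call holds out 20%,
+ # and the second halves that hold-out into validation and test.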
108
+ # Unpack for dataset construction
109
+ def unpack(data):
110
+ binders, glms, labels = zip(*data)
111
+ return list(binders), list(glms), list(labels)
112
+
113
+
114
+ def save_split(binder_paths, glm_paths, labels, out_path):
115
+ df = pd.DataFrame({
116
+ "binder_path": binder_paths,
117
+ "glm_path": glm_paths,
118
+ "label": labels,
119
+ })
120
+ df.to_csv(out_path, index=False)
121
+
122
+ # Unpack data for saving
123
+ train_binders, train_glms, train_labels = unpack(train_data)
124
+ val_binders, val_glms, val_labels = unpack(val_data)
125
+ test_binders, test_glms, test_labels = unpack(test_data)
126
+
127
+ # Save each split
128
+ save_split(train_binders, train_glms, train_labels, "../data_files/splits/train.csv")
129
+ save_split(val_binders, val_glms, val_labels, "../data_files/splits/val.csv")
130
+ save_split(test_binders, test_glms, test_labels, "../data_files/splits/test.csv")
131
+
132
+ if __name__=="__main__":
133
+ main()
dpacman/classifier/model_tmp/train.py ADDED
@@ -0,0 +1,180 @@
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from model import BindPredictor
7
+ from pathlib import Path
8
+ from collections import Counter
9
+ from sklearn.metrics import roc_auc_score, average_precision_score
10
+ from sklearn.decomposition import TruncatedSVD
11
+ import sys
12
+
13
+ from torch.utils.data import DataLoader
+ from dpacman.utils.models import set_seed
+ from prep_splits import parse_pair_list # pair-list parser lives alongside this temp script
14
+
15
+ def build_tf_compressed_cache(binder_paths, target_dim=256):
16
+ """
17
+ Load all unique TF (binder) embeddings, fit reduction if needed, and return dict mapping path->(L, target_dim) array.
18
+ """
19
+ unique_paths = sorted(set(binder_paths))
20
+ print(f"[i] Found {len(unique_paths)} unique TF embedding files to compress.", flush=True)
21
+ # Load all embeddings to determine dimensionality
22
+ samples = []
23
+ for p in unique_paths:
24
+ arr = np.load(p)
25
+ samples.append(arr)
26
+ # Determine if reduction needed: assume all have same embedding width
27
+ first = samples[0]
28
+ orig_dim = first.shape[1] if first.ndim == 2 else 1
29
+ reduction_needed = (orig_dim != target_dim)
30
+ tf_cache = {}
31
+
32
+ if reduction_needed:
33
+ # Build matrix to fit SVD: we need a 2D matrix per embedding; if lengths vary we can't directly stack.
34
+ # We'll do reduction per sequence individually using TruncatedSVD on concatenated flattened features:
35
+ # Simplest: for variable lengths, reduce each embedding separately with a learned linear projection.
36
+ # Here we fit a single TruncatedSVD on the concatenation of all sequence tokens (flattened) by padding/truncating to a fixed length.
37
+ # To avoid complexity, use PCA-like linear projection learned via SVD on mean-pooled vectors:
38
+ pooled = []
39
+ for arr in samples:
40
+ if arr.ndim == 2:
41
+ pooled.append(arr.mean(axis=0)) # (orig_dim,)
42
+ else:
43
+ pooled.append(arr) # degenerate
44
+ pooled_mat = np.stack(pooled, axis=0) # (N, orig_dim)
45
+ print(f"[i] Fitting TruncatedSVD on TF pooled embeddings: {pooled_mat.shape} -> {target_dim}", flush=True)
46
+ svd = TruncatedSVD(n_components=target_dim, random_state=42)
47
+ reduced_pooled = svd.fit_transform(pooled_mat) # (N, target_dim)
48
+
49
+ # For each original embedding, project token-level vectors by multiplying token vector with svd.components_.T
50
+ # svd.components_: (target_dim, orig_dim) so projection matrix is (orig_dim, target_dim)
51
+ proj_mat = svd.components_.T # (orig_dim, target_dim)
52
+ for i, p in enumerate(unique_paths):
53
+ arr = samples[i] # shape (L, orig_dim)
54
+ if arr.ndim == 1:
55
+ arr2 = arr @ proj_mat # (target_dim,)
56
+ else:
57
+ # project each token: (L, orig_dim) @ (orig_dim, target_dim) -> (L, target_dim)
58
+ arr2 = arr @ proj_mat
59
+ tf_cache[p] = arr2 # reduced per-token representation
60
+ print("[i] Completed compression of TF embeddings.", flush=True)
61
+ else:
62
+ # already correct dim: just cache originals
63
+ print(f"[i] TF embeddings already {target_dim}-dimensional; skipping reduction.", flush=True)
64
+ for i, p in enumerate(unique_paths):
65
+ arr = samples[i]
66
+ tf_cache[p] = arr
67
+ return tf_cache
68
+
69
+ def evaluate(model, dl, device):
70
+ model.eval()
71
+ all_labels = []
72
+ all_preds = []
73
+ with torch.no_grad():
74
+ for b, g, y in dl:
75
+ b = b.to(device)
76
+ g = g.to(device)
77
+ y = y.to(device)
78
+ pred = model(b, g)
79
+ all_labels.append(y.cpu())
80
+ all_preds.append(pred.cpu())
81
+ if not all_labels:
82
+ return 0.0, 0.0
83
+ y_true = torch.cat(all_labels).numpy()
84
+ y_score = torch.cat(all_preds).numpy()
85
+ try:
86
+ auc = roc_auc_score(y_true, y_score)
87
+ except Exception:
88
+ auc = 0.0
89
+ try:
90
+ ap = average_precision_score(y_true, y_score)
91
+ except Exception:
92
+ ap = 0.0
93
+ return auc, ap
94
+
95
+ def unpack(data):
96
+ binders, glms, labels = zip(*data)
97
+ return list(binders), list(glms), list(labels)
98
+
99
+ # ---- main ------------------------------------------------------------
100
+ def main():
103
+
104
+ parser.add_argument("--out_dir", type=str, required=True)
105
+ parser.add_argument("--epochs", type=int, default=10)
106
+ parser.add_argument("--batch_size", type=int, default=32)
107
+ parser.add_argument("--lr", type=float, default=1e-4)
108
+ parser.add_argument("--device", type=str, default="cuda")
109
+ parser.add_argument("--seed", type=int, default=42)
110
+ args = parser.parse_args()
111
+
112
+ #
113
+ print("DEBUG: starting training script with in-line TF compression", flush=True)
114
+ device = torch.device(args.device if torch.cuda.is_available() else "cpu")
115
+ binder_paths, glm_paths, labels = parse_pair_list(args.pair_list)
116
+
117
+ if len(labels) == 0:
118
+ print("[ERROR] No valid pairs parsed. Exiting.", file=sys.stderr)
119
+ sys.exit(1)
120
+
121
+ label_counts = Counter(labels)
122
+ print(f"[i] Total examples parsed: {len(labels)}. Label distribution: {label_counts}", flush=True)
123
+
124
+ # build compressed TF cache (reduces to 256 if needed)
125
+ tf_compressed_cache = build_tf_compressed_cache(binder_paths, target_dim=256)
126
+
127
+ # load training data aloiaushasfoiuhasfoiuafasdfoihuaaasdfoiuhasfaaoiufhasfoasasfoiuh
128
+
129
+ train_ds = PairDataset(None, tf_compressed_cache=tf_compressed_cache)
130
+ val_ds = PairDataset(*subset(val_i), tf_compressed_cache=tf_compressed_cache)
131
+ test_ds = PairDataset(*subset(test_i), tf_compressed_cache=tf_compressed_cache)
132
+
133
+ print(f"[i] Train/Val/Test sizes: {len(train_ds)}/{len(val_ds)}/{len(test_ds)}", flush=True)
134
+ if len(train_ds) == 0 or len(val_ds) == 0:
135
+ print("[ERROR] Train or validation split is empty; cannot proceed.", file=sys.stderr)
136
+ sys.exit(1)
137
+
138
+ train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
139
+ val_dl = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
140
+ test_dl = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
141
+
142
+ model = BindPredictor(input_dim=256, hidden_dim=256, heads=8, num_layers=3, use_local_cnn_on_glm=True)
143
+ model = model.to(device)
144
+ optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-3)
145
+ loss_fn = nn.BCELoss()
146
+
147
+ best_val = -float("inf")
148
+ os_out = Path(args.out_dir)
149
+ os_out.mkdir(exist_ok=True, parents=True)
150
+
151
+ for epoch in range(1, args.epochs + 1):
152
+ print(f"[Epoch {epoch}] starting...", flush=True)
153
+ model.train()
154
+ running_loss = 0.0
155
+ for b, g, y in train_dl:
156
+ b = b.to(device)
157
+ g = g.to(device)
158
+ y = y.to(device)
159
+ pred = model(b, g)
160
+ loss = loss_fn(pred, y)
161
+ optimizer.zero_grad()
162
+ loss.backward()
163
+ optimizer.step()
164
+ running_loss += loss.item() * b.size(0)
165
+ train_loss = running_loss / len(train_ds)
166
+ val_auc, val_ap = evaluate(model, val_dl, device)
167
+ print(f"[Epoch {epoch}] train_loss={train_loss:.4f} val_auc={val_auc:.4f} val_ap={val_ap:.4f}", flush=True)
168
+
169
+ if val_auc > best_val:
170
+ best_val = val_auc
171
+ torch.save(model.state_dict(), os_out / "best_model.pt")
172
+ print(f"[Epoch {epoch}] Saved new best model with val_auc={val_auc:.4f}", flush=True)
173
+
174
+ torch.save(model.state_dict(), os_out / "last_model.pt")
175
+ test_auc, test_ap = evaluate(model, test_dl, device)
176
+ print(f"FINAL TEST: AUC={test_auc:.4f} AP={test_ap:.4f}", flush=True)
177
+ print(f"[i] Models written to {os_out}/best_model.pt and last_model.pt", flush=True)
178
+
179
+ if __name__ == "__main__":
180
+ main()
dpacman/data_modules/__init__.py ADDED
File without changes
dpacman/data_modules/pair.py ADDED
@@ -0,0 +1,342 @@
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from torch.utils.data import Dataset, DataLoader
7
+ from lightning import LightningDataModule
8
+ from model import BindPredictor
9
+ from pathlib import Path
10
+ from collections import Counter
11
+ from sklearn.metrics import roc_auc_score, average_precision_score
12
+ from sklearn.decomposition import TruncatedSVD
13
+ from multiprocessing import cpu_count
14
+ from functools import partial
15
+ import random
16
+ import sys
17
+ import pandas as pd
18
+
19
+ from dpacman.utils import RankedLogger
20
+
21
+ logger = RankedLogger(__name__, rank_zero_only=True)
22
+
23
+
24
+ # ---- dataset ---------------------------------------------------------
25
+ class PairDataset(Dataset):
26
+ def __init__(self, tf_paths, dna_paths, final_df, tf_cache):
27
+ """
28
+ tf_cache: dict mapping binder_path -> compressed (256-d) tensor/array
29
+ """
30
+ self.tf_paths, self.dna_paths = tf_paths, dna_paths
31
+ self.tf_cache = tf_cache
32
+ self.targets = {}
33
+ for _, row in final_df.iterrows():
34
+ dna_id = row["dna_id"]
35
+ vec = np.array(list(map(float, row["score_sig_r2"].split(","))), dtype=np.float32)
36
+ self.targets[dna_id] = vec
37
+
38
+ def __len__(self):
39
+ return len(self.tf_paths)
40
+
41
+ def __getitem__(self, i):
42
+ b = self.tf_cache[self.tf_paths[i]] # (L_b, D_b) or (D_b,)
43
+ if b.ndim==1: b = b[None,:]
44
+ g = np.load(self.dna_paths[i]) # (L_g, 256) or (256,)
45
+ if g.ndim==1: g = g[None,:]
46
+
47
+ stem = Path(self.dna_paths[i]).stem
48
+ dna_id = stem.replace("dna_","")
49
+ t = self.targets.get(dna_id, np.zeros(g.shape[0],dtype=np.float32))
50
+
51
+ return torch.from_numpy(b).float(), \
52
+ torch.from_numpy(g).float(), \
53
+ torch.from_numpy(t).float()
54
+
55
+
56
+ class PairDataModule(LightningDataModule):
57
+ def __init__(
58
+ self,
59
+ train_file: str = "data_files/splits/train.csv",
60
+ val_file: str = "data_files/splits/val.csv",
61
+ test_file: str = "data_files/splits/test.csv",
62
+ tokenizer_path="facebook/esm2_t33_650M_UR50D",
63
+ batch_size: int = 1,
64
+ num_workers=8,
65
+ maximize_num_workers=False,
66
+ debug_run: bool = False,
67
+ pin_memory: bool = False,
68
+ ):
69
+ super().__init__()
70
+ self.save_hyperparameters()
71
+ self.debug_run = debug_run
72
+
73
+ # Initialize the data files
74
+ self.train_data_file = train_file
75
+ self.val_data_file = val_file
76
+ self.test_data_file = test_file
77
+
78
+ # Initialize hyperparameters like batch size
79
+ self.batch_size = batch_size
80
+ self.num_workers = cpu_count() if maximize_num_workers else min(num_workers, cpu_count())
81
+
82
+ logger.info(f"num_workers={self.num_wokers}")
83
+ logger.info("Initialized BinderDecoyDataModule constants")
84
+
85
+ def load_and_unpack(self, file_path, lim=None):
86
+ """
87
+ Load and unpack an input csv whose columns are binder_path,glm_path,label
88
+ """
89
+ df = pd.read_csv(file_path)
90
+ if lim is not None:
91
+ df = df[:lim].reset_index(drop=True)
92
+
93
+ binder_paths = df["binder_path"].tolist()
94
+ glm_paths = df["glm_path"].tolist()
95
+ labels = df["label"].tolist()
96
+
97
+ return binder_paths, glm_paths, labels
98
+
99
+ def setup(self, stage):
100
+ lim = 5 if self.debug_run else None
101
+ if stage=="train":
102
+ binder_paths, glm_paths, labels = self.load_and_unpack(self.train_data_file, lim=lim)
103
+ self.train_dataset = PairDataset(binder_paths, glm_paths, labels)
104
+ elif stage=="val":
105
+ binder_paths, glm_paths, labels = self.load_and_unpack(self.val_data_file, lim=lim)
106
+ self.val_dataset = PairDataset(binder_paths, glm_paths, labels)
107
+ elif stage=="test":
108
+ binder_paths, glm_paths, labels = self.load_and_unpack(self.test_data_file, lim=lim)
109
+ self.test_dataset = PairDataset(binder_paths, glm_paths, labels)
110
+ else:
111
+ raise RuntimeError(f"Stage {stage} is not defined. Must be train, val, or test.")
112
+
113
+ def train_dataloader(self):
114
+ return DataLoader(
115
+ self.train_dataset,
116
+ batch_size=self.batch_size,
117
+ collate_fn=collate_fn,
118
+ num_workers=self.num_workers,
119
+ pin_memory=self.hparams.pin_memory,
120
+ shuffle=True,
121
+ )
122
+
123
+ def val_dataloader(self):
124
+ return DataLoader(
125
+ self.val_dataset,
126
+ batch_size=self.batch_size,
127
+ collate_fn=collate_fn,
128
+ num_workers=self.num_workers,
129
+ pin_memory=self.hparams.pin_memory,
130
+ shuffle=False,
131
+ )
132
+
133
+ def test_dataloader(self):
134
+ return DataLoader(
135
+ self.test_dataset,
136
+ batch_size=self.batch_size,
137
+ collate_fn=collate_fn,
138
+ num_workers=self.num_workers,
139
+ pin_memory=self.hparams.pin_memory,
140
+ shuffle=False,
141
+ )
142
+
143
+
144
+ def collate_fn(batch):
145
+ Bs = [b.shape[0] for b,_,_ in batch]
146
+ Gs = [g.shape[0] for _,g,_ in batch]
147
+ maxB, maxG = max(Bs), max(Gs)
148
+
149
+ def pad_seq(x, L):
150
+ if x.shape[0] < L:
151
+ pad = torch.zeros((L-x.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
152
+ return torch.cat([x, pad], dim=0)
153
+ return x
154
+
155
+ def pad_t(y, L):
156
+ if y.shape[0] < L:
157
+ pad = torch.zeros((L-y.shape[0],), dtype=y.dtype, device=y.device)
158
+ return torch.cat([y, pad], dim=0)
159
+ return y
160
+
161
+ b_stack = torch.stack([pad_seq(b, maxB) for b,_,_ in batch])
162
+ g_stack = torch.stack([pad_seq(g, maxG) for _,g,_ in batch])
163
+ t_stack = torch.stack([pad_t(t, maxG) for *_,t in batch])
164
+ return b_stack, g_stack, t_stack
165
+
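+ # Collation sketch (assumed shapes): binder tensors (Lb_i, D) and DNA tensors (Lg_i, D)
+ # are zero-padded to the batch maxima, so items with Lb in {5, 9} and Lg in {20, 32}
+ # collate to b:(2, 9, D), g:(2, 32, D), t:(2, 32).
+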
166
+ def build_tf_compressed_cache(binder_paths, target_dim=256):
167
+ """
168
+ Load all unique TF (binder) embeddings, fit reduction if needed, and return dict mapping path->(L, target_dim) array.
169
+ """
170
+ unique_paths = sorted(set(binder_paths))
171
+ logger.info(f"[i] Found {len(unique_paths)} unique TF embedding files to compress.", flush=True)
172
+ # Load all embeddings to determine dimensionality
173
+ samples = []
174
+ for p in unique_paths:
175
+ arr = np.load(p)
176
+ samples.append(arr)
177
+ # Determine if reduction needed: assume all have same embedding width
178
+ first = samples[0]
179
+ orig_dim = first.shape[1] if first.ndim == 2 else 1
180
+ reduction_needed = (orig_dim != target_dim)
181
+ tf_cache = {}
182
+
183
+ if reduction_needed:
184
+ # Build matrix to fit SVD: we need a 2D matrix per embedding; if lengths vary we can't directly stack.
185
+ # We'll do reduction per sequence individually using TruncatedSVD on concatenated flattened features:
186
+ # Simplest: for variable lengths, reduce each embedding separately with a learned linear projection.
187
+ # Here we fit a single TruncatedSVD on the concatenation of all sequence tokens (flattened) by padding/truncating to a fixed length.
188
+ # To avoid complexity, use PCA-like linear projection learned via SVD on mean-pooled vectors:
189
+ pooled = []
190
+ for arr in samples:
191
+ if arr.ndim == 2:
192
+ pooled.append(arr.mean(axis=0)) # (orig_dim,)
193
+ else:
194
+ pooled.append(arr) # degenerate
195
+ pooled_mat = np.stack(pooled, axis=0) # (N, orig_dim)
196
+ logger.info(f"[i] Fitting TruncatedSVD on TF pooled embeddings: {pooled_mat.shape} -> {target_dim}", flush=True)
197
+ svd = TruncatedSVD(n_components=target_dim, random_state=42)
198
+ reduced_pooled = svd.fit_transform(pooled_mat) # (N, target_dim)
199
+
200
+ # For each original embedding, project token-level vectors by multiplying token vector with svd.components_.T
201
+ # svd.components_: (target_dim, orig_dim) so projection matrix is (orig_dim, target_dim)
202
+ proj_mat = svd.components_.T # (orig_dim, target_dim)
203
+ for i, p in enumerate(unique_paths):
204
+ arr = samples[i] # shape (L, orig_dim)
205
+ if arr.ndim == 1:
206
+ arr2 = arr @ proj_mat # (target_dim,)
207
+ else:
208
+ # project each token: (L, orig_dim) @ (orig_dim, target_dim) -> (L, target_dim)
209
+ arr2 = arr @ proj_mat
210
+ tf_cache[p] = arr2 # reduced per-token representation
211
+ logger.info("[i] Completed compression of TF embeddings.", flush=True)
212
+ else:
213
+ # already correct dim: just cache originals
214
+ logger.info(f"[i] TF embeddings already {target_dim}-dimensional; skipping reduction.", flush=True)
215
+ for i, p in enumerate(unique_paths):
216
+ arr = samples[i]
217
+ tf_cache[p] = arr
218
+ return tf_cache
219
+
220
+ def evaluate(model, dl, device):
221
+ model.eval()
222
+ all_labels = []
223
+ all_preds = []
224
+ with torch.no_grad():
225
+ for b, g, y in dl:
226
+ b = b.to(device)
227
+ g = g.to(device)
228
+ y = y.to(device)
229
+ pred = model(b, g)
230
+ all_labels.append(y.cpu())
231
+ all_preds.append(pred.cpu())
232
+ if not all_labels:
233
+ return 0.0, 0.0
234
+ y_true = torch.cat(all_labels).numpy()
235
+ y_score = torch.cat(all_preds).numpy()
236
+ try:
237
+ auc = roc_auc_score(y_true, y_score)
238
+ except Exception:
239
+ auc = 0.0
240
+ try:
241
+ ap = average_precision_score(y_true, y_score)
242
+ except Exception:
243
+ ap = 0.0
244
+ return auc, ap
245
+
246
+ # ---- main ------------------------------------------------------------
247
+ def main():
248
+ parser = argparse.ArgumentParser()
249
+ parser.add_argument("--pair_list", type=str, required=True,
250
+ help="TSV: binder_path glm_path label")
251
+ parser.add_argument("--out_dir", type=str, required=True)
252
+ parser.add_argument("--epochs", type=int, default=10)
253
+ parser.add_argument("--batch_size", type=int, default=32)
254
+ parser.add_argument("--lr", type=float, default=1e-4)
255
+ parser.add_argument("--device", type=str, default="cuda")
256
+ parser.add_argument("--seed", type=int, default=42)
257
+ args = parser.parse_args()
258
+
259
+ # reproducibility
260
+ random.seed(args.seed)
261
+ np.random.seed(args.seed)
262
+ torch.manual_seed(args.seed)
263
+
264
+ logger.info("DEBUG: starting training script with in-line TF compression", flush=True)
265
+ logger.info(f"[i] pair_list: {args.pair_list}", flush=True)
266
+ logger.info(f"[i] output dir: {args.out_dir}", flush=True)
267
+ device = torch.device(args.device if torch.cuda.is_available() else "cpu")
268
+ binder_paths, glm_paths, labels = parse_pair_list(args.pair_list)
269
+
270
+ if len(labels) == 0:
271
+ logger.info("[ERROR] No valid pairs parsed. Exiting.", file=sys.stderr)
272
+ sys.exit(1)
273
+
274
+ label_counts = Counter(labels)
275
+ logger.info(f"[i] Total examples parsed: {len(labels)}. Label distribution: {label_counts}", flush=True)
276
+
277
+ # build compressed TF cache (reduces to 256 if needed)
278
+ tf_compressed_cache = build_tf_compressed_cache(binder_paths, target_dim=256)
279
+
280
+ # simple split: 80/10/10
281
+ n = len(labels)
282
+ idxs = np.arange(n)
283
+ np.random.shuffle(idxs)
284
+ train_i = idxs[: int(0.8 * n)]
285
+ val_i = idxs[int(0.8 * n): int(0.9 * n)]
286
+ test_i = idxs[int(0.9 * n):]
287
+
288
+ def subset(idxs):
289
+ return [binder_paths[i] for i in idxs], [glm_paths[i] for i in idxs], [labels[i] for i in idxs]
290
+
291
+ train_ds = PairDataset(*subset(train_i), tf_compressed_cache=tf_compressed_cache)
292
+ val_ds = PairDataset(*subset(val_i), tf_compressed_cache=tf_compressed_cache)
293
+ test_ds = PairDataset(*subset(test_i), tf_compressed_cache=tf_compressed_cache)
294
+
295
+ logger.info(f"[i] Train/Val/Test sizes: {len(train_ds)}/{len(val_ds)}/{len(test_ds)}", flush=True)
296
+ if len(train_ds) == 0 or len(val_ds) == 0:
297
+ logger.info("[ERROR] Train or validation split is empty; cannot proceed.", file=sys.stderr)
298
+ sys.exit(1)
299
+
300
+ train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
301
+ val_dl = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
302
+ test_dl = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
303
+
304
+ model = BindPredictor(input_dim=256, hidden_dim=256, heads=8, num_layers=3, use_local_cnn_on_glm=True)
305
+ model = model.to(device)
306
+ optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-3)
307
+ loss_fn = nn.BCELoss()
308
+
309
+ best_val = -float("inf")
310
+ os_out = Path(args.out_dir)
311
+ os_out.mkdir(exist_ok=True, parents=True)
312
+
313
+ for epoch in range(1, args.epochs + 1):
314
+ logger.info(f"[Epoch {epoch}] starting...", flush=True)
315
+ model.train()
316
+ running_loss = 0.0
317
+ for b, g, y in train_dl:
318
+ b = b.to(device)
319
+ g = g.to(device)
320
+ y = y.to(device)
321
+ pred = model(b, g)
322
+ loss = loss_fn(pred, y)
323
+ optimizer.zero_grad()
324
+ loss.backward()
325
+ optimizer.step()
326
+ running_loss += loss.item() * b.size(0)
327
+ train_loss = running_loss / len(train_ds)
328
+ val_auc, val_ap = evaluate(model, val_dl, device)
329
+ logger.info(f"[Epoch {epoch}] train_loss={train_loss:.4f} val_auc={val_auc:.4f} val_ap={val_ap:.4f}", flush=True)
330
+
331
+ if val_auc > best_val:
332
+ best_val = val_auc
333
+ torch.save(model.state_dict(), os_out / "best_model.pt")
334
+ logger.info(f"[Epoch {epoch}] Saved new best model with val_auc={val_auc:.4f}", flush=True)
335
+
336
+ torch.save(model.state_dict(), os_out / "last_model.pt")
337
+ test_auc, test_ap = evaluate(model, test_dl, device)
338
+ logger.info(f"FINAL TEST: AUC={test_auc:.4f} AP={test_ap:.4f}", flush=True)
339
+ logger.info(f"[i] Models written to {os_out}/best_model.pt and last_model.pt", flush=True)
340
+
341
+ if __name__ == "__main__":
342
+ main()
dpacman/data_tasks/cluster/__init__.py ADDED
File without changes
dpacman/data_tasks/cluster/remap.py ADDED
@@ -0,0 +1,144 @@
1
+ """
2
+ Holds Python methods for clustering Remap DNA sequences.
3
+ """
4
+ import argparse
5
+ import numpy as np
6
+ import pandas as pd
7
+ from pathlib import Path
8
+ import random
9
+ import sys
10
+ import subprocess
11
+ from collections import defaultdict
12
+ import rootutils
13
+ import logging
14
+ import os
15
+ import json
16
+ from omegaconf import DictConfig
17
+ from hydra.core.hydra_config import HydraConfig
18
+
19
+ from dpacman.utils.clustering import make_fasta, process_fasta, analyze_clustering_result, run_mmseqs_clustering, cluster_summary
20
+
21
+ root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
22
+ logger = logging.getLogger(__name__)
23
+
24
+ def cluster_molecules(fasta_dict, fasta_path, mmseqs_params: DictConfig, output_dir="", path_to_mmseqs="../softwares/mmseqs", moltype="dna", use_gpu=True):
25
+ """
26
+ Args:
27
+ - fasta_dict: dictionary object where the keys are sequence IDs, and the values are sequences
28
+ - fasta_path: str or Path to where the output fasta should be saved
29
+ - mmseqs_params: DictConfig of mmseqs hparams
30
+ - type: molecule type, "dna" or "protein"
31
+ """
32
+
33
+ # make the fasta
34
+ logger.info(f"Making fasta at: {fasta_path}")
35
+ fasta_path = str(make_fasta(fasta_dict, fasta_path))
36
+
37
+ # prepare directories
38
+ output_dir = str(Path(root) / output_dir)
39
+ path_to_mmseqs = str(Path(root) / path_to_mmseqs)
40
+
41
+ # run mmseqs
42
+ dbtype=1
43
+ if moltype=="dna": dbtype=2
44
+ run_mmseqs_clustering(fasta_path,
45
+ output_dir,
46
+ min_seq_id=mmseqs_params.min_seq_id,
47
+ c=mmseqs_params.c,
48
+ cov_mode=mmseqs_params.cov_mode,
49
+ cluster_mode=mmseqs_params.cluster_mode,
50
+ dbtype=dbtype,
51
+ path_to_mmseqs=path_to_mmseqs)
52
+
53
+ tsv_path = [x for x in os.listdir(output_dir) if x.endswith(".tsv")][0]
54
+ clusters = analyze_clustering_result(
55
+ fasta_path, Path(output_dir) / tsv_path
56
+ )
57
+ logger.info(f"Made clusters DataFrame:\n{clusters.head()}")
58
+ cluster_summary(clusters)
59
+
60
+ def read_input_data(input_path):
61
+ """
62
+ Read the data from the input path.
63
+ It may be a csv or parquet
64
+ """
65
+ input_path = Path(root) / input_path
66
+ df = None
67
+ if str(input_path).endswith(".parquet"):
68
+ df = pd.read_parquet(input_path, engine='pyarrow')
69
+ elif str(input_path).endswith(".csv"):
70
+ df = pd.read_csv(input_path)
71
+ elif str(input_path).endswith(".tsv") or str(input_path).endswith(".txt"):
72
+ df = pd.read_csv(input_path, sep="\t")
73
+ else:
74
+ raise Exception(f"Cannot read input data from {input_path}: invalid file type")
75
+ return df
76
+
77
+ def main(cfg: DictConfig):
78
+ """
79
+ Run clustering on Remap protein AND DNA sequences.
80
+ Get clusters for each.
81
+ """
82
+ # Load input CSV
83
+ # columns: Index(['ID', 'tr_seqid', 'dna_seqid', 'tr_name', 'peak_id', 'chipscore', 'total_jaspar_hits', 'dna_sequence', 'tr_sequence', 'scores']
84
+ df = read_input_data(cfg.data_task.input_data_path)
85
+
86
+ # Separate configs
87
+ dna_full_cfg = cfg.data_task.dna_full
88
+ dna_peaks_cfg = cfg.data_task.dna_peaks
89
+ protein_cfg = cfg.data_task.protein
90
+ logger.info(f"Clustering DNA full: {cfg.data_task.cluster_dna_full}. Clustering DNA peaks: {cfg.data_task.cluster_dna_peaks}. Clustering protein: {cfg.data_task.cluster_protein}.")
91
+
92
+ # Make fastas
93
+ dna_full_fasta_path = Path(root) / dna_full_cfg.fasta_path
94
+ dna_peaks_fasta_path = Path(root) / dna_peaks_cfg.fasta_path
95
+ protein_fasta_path = Path(root) / protein_cfg.fasta_path
96
+ os.makedirs(dna_full_fasta_path.parent, exist_ok=True)
97
+ os.makedirs(dna_peaks_fasta_path.parent, exist_ok=True)
98
+ os.makedirs(protein_fasta_path.parent, exist_ok=True)
99
+
100
+ # Make dictionary needed for input to the fasta methods
101
+ with open(Path(root) / dna_full_cfg.input_map_path, "r") as f:
102
+ dna_full_fasta_dict = json.load(f)
103
+
104
+ with open(Path(root) / dna_peaks_cfg.input_map_path, "r") as f:
105
+ dna_peaks_fasta_dict = json.load(f)
106
+
107
+ with open(Path(root) / protein_cfg.input_map_path, "r") as f:
108
+ protein_fasta_dict = json.load(f)
109
+
110
+ logger.info(f"Loaded DNA seq dict from: {dna_full_cfg.input_map_path}. Size: {len(dna_full_fasta_dict)}")
111
+ logger.info(f"Loaded DNA peaks dict from: {dna_peaks_cfg.input_map_path}. Size: {len(dna_peaks_fasta_dict)}")
112
+ logger.info(f"Loaded TR (protein) seq dict from: {protein_cfg.input_map_path}. Size: {len(protein_fasta_dict)}")
113
+
114
+ # Build hash-sets once (drop NaNs to avoid weird matches)
115
+ dna_ids = set(df["dna_seqid"].dropna())
116
+ peak_ids = set(df["peak_seqid"].dropna())
117
+ tr_ids = set(df["tr_seqid"].dropna())
118
+
119
+ # Iterate only the intersection (fast when allowed << dict size)
120
+ dna_full_fasta_dict = {k: dna_full_fasta_dict[k] for k in (dna_full_fasta_dict.keys() & dna_ids)}
121
+ dna_peaks_fasta_dict = {k: dna_peaks_fasta_dict[k] for k in (dna_peaks_fasta_dict.keys() & peak_ids)}
122
+ protein_fasta_dict = {k: protein_fasta_dict[k] for k in (protein_fasta_dict.keys() & tr_ids)}
123
+
124
+ logger.info(f"Filtered dictionaries to only sequences in the filtered training data.")
125
+ logger.info(f"Total DNA sequences: {len(dna_full_fasta_dict)}. Total peak sequences: {len(dna_peaks_fasta_dict)}. Total protein sequences: {len(protein_fasta_dict)}")
126
+
127
+ if cfg.data_task.cluster_dna_full:
128
+ logger.info(f"Clustering DNA full sequences, with context")
129
+ cluster_molecules(dna_full_fasta_dict, dna_full_fasta_path,
130
+ mmseqs_params=dna_full_cfg.mmseqs, output_dir=dna_full_cfg.output_dir, path_to_mmseqs=cfg.data_task.path_to_mmseqs, moltype="dna")
131
+
132
+ if cfg.data_task.cluster_dna_peaks:
133
+ logger.info(f"Clustering DNA peak sequences")
134
+ cluster_molecules(dna_peaks_fasta_dict, dna_peaks_fasta_path,
135
+ mmseqs_params=dna_peaks_cfg.mmseqs, output_dir=dna_peaks_cfg.output_dir, path_to_mmseqs=cfg.data_task.path_to_mmseqs, moltype="dna")
136
+
137
+ if cfg.data_task.cluster_protein:
138
+ logger.info("Clustering protein sequences.")
139
+ cluster_molecules(protein_fasta_dict, protein_fasta_path, mmseqs_params=protein_cfg.mmseqs, output_dir=protein_cfg.output_dir, path_to_mmseqs=cfg.data_task.path_to_mmseqs, moltype="protein")
140
+
141
+ logger.info("Clustering pipeline complete")
142
+
143
+ if __name__ == "__main__":
144
+ main()
dpacman/data_tasks/embeddings/embedders.py ADDED
@@ -0,0 +1,560 @@
1
+ """
2
+ Plug-and-play embedding extraction for:
3
+ • Chromosome sequences (from raw UCSC JSON)
4
+ • TF sequences (transcription_factors.fasta)
5
+
6
+ Usage example (DNA + protein in one go):
7
+ module load miniconda/24.7.1
8
+ conda activate dpacman
9
+ python dpacman/data/compute_embeddings.py \
10
+ --genome-json-dir ../data_files/raw/genomes/hg38 \
11
+ --tf-fasta ../data_files/processed/tfclust/hg38_tf/transcription_factors.fasta \
12
+ --chrom-model caduceus \
13
+ --tf-model esm-dbp \
14
+ --out-dir ../data_files/processed/tfclust/hg38_tf/embeddings \
15
+ --device cuda
16
+ """
17
+ import os
18
+ import re
19
+ import argparse
20
+ import json
21
+ import numpy as np
22
+ from pathlib import Path
23
+ import torch
24
+ from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, pipeline
25
+ import esm
26
+ from Bio import SeqIO
27
+ import time
28
+ import pandas as pd
29
+ from tqdm.auto import tqdm
30
+ import logging, math
31
+
32
+ # ---- model wrappers ----
33
+
34
+ class CaduceusEmbedder:
35
+ def __init__(self, device, chunk_size=131_072, overlap=0):
36
+ """
37
+ device: 'cpu' or 'cuda'
38
+ chunk_size: max bases (and thus tokens) to send in one forward pass
39
+ overlap: how many bases each window overlaps the previous; 0 = no overlap
40
+ """
41
+ model_name = "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16"
42
+ self.tokenizer = AutoTokenizer.from_pretrained(
43
+ model_name, trust_remote_code=True
44
+ )
45
+ self.model = AutoModel.from_pretrained(
46
+ model_name, trust_remote_code=True
47
+ ).to(device).eval()
48
+ self.device = device
49
+ self.chunk_size = chunk_size
50
+ self.step = chunk_size - overlap
51
+
52
+ def embed(self, seqs):
53
+ """
54
+ seqs: List[str] of DNA sequences (each <= chunk_size for this test)
55
+ returns: list of per-token embedding arrays, one (L_i, D) array per input sequence
56
+ """
57
+ # outputs = []
58
+ # for seq in seqs:
59
+ # # --- new: raw per‐token embeddings in one shot ---
60
+ # toks = self.tokenizer(
61
+ # seq,
62
+ # return_tensors="pt",
63
+ # padding=False,
64
+ # truncation=True,
65
+ # max_length=self.chunk_size
66
+ # ).to(self.device)
67
+ # with torch.no_grad():
68
+ # out = self.model(**toks).last_hidden_state # (1, L, D)
69
+ # outputs.append(out.cpu().numpy()[0]) # (L, D)
70
+
71
+ # return np.stack(outputs, axis=0) # (N, L, D)
72
+ outputs = []
73
+ for seq in tqdm(seqs, total=len(seqs), desc="DNA: Caduceus", dynamic_ncols=True):
74
+ toks = self.tokenizer(
75
+ seq,
76
+ return_tensors="pt",
77
+ padding=False,
78
+ truncation=True,
79
+ max_length=self.chunk_size
80
+ ).to(self.device)
81
+ with torch.no_grad():
82
+ out = self.model(**toks).last_hidden_state # (1, L, D)
83
+ outputs.append(out.cpu().numpy()[0]) # (L, D)
84
+ return outputs # list of variable-length (L_i, D) arrays
85
+
86
+
87
+ def benchmark(self, lengths=None):
88
+ """
89
+ Time embedding on single-sequence of various lengths.
90
+ By default tests [5K,10K,50K,100K,chunk_size].
91
+ """
92
+ tests = lengths or [5_000, 10_000, 50_000, 100_000, self.chunk_size]
93
+ print(f"→ Benchmarking Caduceus on device={self.device}")
94
+ for sz in tests:
95
+ seq = "A" * sz
96
+ # Warm-up
97
+ _ = self.embed([seq])
98
+ if self.device != "cpu":
99
+ torch.cuda.synchronize()
100
+ t0 = time.perf_counter()
101
+ _ = self.embed([seq])
102
+ if self.device != "cpu":
103
+ torch.cuda.synchronize()
104
+ t1 = time.perf_counter()
105
+ print(f" length={sz:6,d} time={(t1-t0)*1000:7.1f} ms")
106
+
107
+ class SegmentNTEmbedder:
108
+ def __init__(self, device):
109
+ self.tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True)
110
+ self.model = AutoModel.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True).to(device).eval()
111
+ self.device = device
112
+
113
+ def _adjust_length(self, input_ids):
114
+ bs, L = input_ids.shape
115
+ excl = L - 1
116
+ remainder = (excl) % 4
117
+ if remainder != 0:
118
+ pad_needed = 4 - remainder
119
+ pad_tensor = torch.full((bs, pad_needed), self.tokenizer.pad_token_id, dtype=input_ids.dtype, device=input_ids.device)
120
+ input_ids = torch.cat([input_ids, pad_tensor], dim=1)
121
+ return input_ids
122
+
123
+ def embed(self, seqs, batch_size=16):
124
+ """
125
+ seqs: List[str]
126
+ Returns: np.ndarray of shape (N, D)
127
+ """
128
+ all_embeddings = []
129
+ for i in range(0, len(seqs), batch_size):
130
+ batch_seqs = seqs[i : i + batch_size]
131
+ encoded = self.tokenizer.batch_encode_plus(
132
+ batch_seqs,
133
+ return_tensors="pt",
134
+ padding=True,
135
+ truncation=True,
136
+ )
137
+ input_ids = encoded["input_ids"].to(self.device) # (B, L)
138
+ attention_mask = input_ids != self.tokenizer.pad_token_id
139
+
140
+ input_ids = self._adjust_length(input_ids)
141
+ attention_mask = (input_ids != self.tokenizer.pad_token_id)
142
+
143
+ with torch.no_grad():
144
+ outs = self.model(
145
+ input_ids,
146
+ attention_mask=attention_mask,
147
+ output_hidden_states=True,
148
+ return_dict=True,
149
+ )
150
+ if hasattr(outs, "hidden_states") and outs.hidden_states is not None:
151
+ last_hidden = outs.hidden_states[-1] # (B, L, D)
152
+ else:
153
+ last_hidden = outs.last_hidden_state # fallback
154
+
155
+ # Exclude CLS token if present (assume first token) and pool
156
+ pooled = last_hidden[:, 1:, :].mean(dim=1) # (B, D)
157
+ all_embeddings.append(pooled.cpu().numpy())
158
+
159
+ # release fragmentation
160
+ torch.cuda.empty_cache()
161
+
162
+ return np.vstack(all_embeddings) # (N, D)
163
+
164
+
165
+ class DNABertEmbedder:
166
+ def __init__(self, device):
167
+ self.tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True)
168
+ self.model = AutoModel.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True).to(device)
169
+ self.device = device
170
+
171
+ def embed(self, seqs):
172
+ embs = []
173
+ for s in seqs:
174
+ tokens = self.tokenizer(s, return_tensors="pt", padding=True)["input_ids"].to(self.device)
175
+ with torch.no_grad():
176
+ out = self.model(tokens).last_hidden_state.mean(1)
177
+ embs.append(out.cpu().numpy())
178
+ return np.vstack(embs)
179
+
180
+ class NucleotideTransformerEmbedder:
181
+ def __init__(self, device):
182
+ # HF “feature-extraction” returns a list of (L, D) arrays for each input
183
+ # device: “cpu” or “cuda”
184
+ self.pipe = pipeline(
185
+ "feature-extraction",
186
+ model="InstaDeepAI/nucleotide-transformer-500m-1000g",
187
+ device= -1 if device=="cpu" else 0  # HF uses -1 for CPU, 0 for GPU
188
+ )
189
+
190
+ def embed(self, seqs):
191
+ """
192
+ seqs: List[str] of raw DNA sequences
193
+ returns: (N, D) array, one D-dim vector per sequence
194
+ """
195
+ all_embeddings = self.pipe(seqs, truncation=True, padding=True)
196
+ # all_embeddings is a List of shape (L, D) arrays
197
+ pooled = [ np.mean(x, axis=0) for x in all_embeddings ]
198
+ return np.vstack(pooled)
199
+
200
+ # class ESMEmbedder:
201
+ # def __init__(self, device):
202
+ # self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
203
+ # self.batch_converter = self.alphabet.get_batch_converter()
204
+ # self.model.to(device).eval()
205
+ # self.device = device
206
+
207
+ # def embed(self, seqs):
208
+ # batch = [(str(i), seq) for i, seq in enumerate(seqs)]
209
+ # _, _, toks = self.batch_converter(batch)
210
+ # toks = toks.to(self.device)
211
+ # with torch.no_grad():
212
+ # results = self.model(toks, repr_layers=[33], return_contacts=False)
213
+ # reps = results["representations"][33]
214
+ # return reps[:, 1:-1].mean(1).cpu().numpy()
215
+
216
+
217
+ class ESMEmbedder:
218
+ def __init__(self, device, model_name="esm2_t33_650M_UR50D"):
219
+ # Try to load the specified ESM-2 model; fallback to esm1b if missing
220
+ self.device = device
221
+ try:
222
+ self.model, self.alphabet = getattr(esm.pretrained, model_name)()
223
+ self.is_esm2 = model_name.lower().startswith("esm2")
224
+ except AttributeError:
225
+ # fallback to ESM-1b
226
+ self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
227
+ self.is_esm2 = False
228
+ self.batch_converter = self.alphabet.get_batch_converter()
229
+ self.model.to(device).eval()
230
+ # determine max length: esm2 models vary; use default 1024 for esm1b
231
+ self.max_len = 4096 if self.is_esm2 else 1024 # adjust if your esm2 variant has explicit limit
232
+ # for chunking: reserve 2 tokens if model uses BOS/EOS
233
+ self.chunk_size = self.max_len - 2
234
+ self.overlap = self.chunk_size // 4 # 25% overlap to smooth boundaries
235
+
236
+ def _chunk_sequence(self, seq):
237
+ """
238
+ Return list of possibly overlapping chunks of seq, each <= chunk_size.
239
+ """
240
+ if len(seq) <= self.chunk_size:
241
+ return [seq]
242
+ step = self.chunk_size - self.overlap
243
+ chunks = []
244
+ for i in range(0, len(seq), step):
245
+ chunk = seq[i : i + self.chunk_size]
246
+ if not chunk:
247
+ break
248
+ chunks.append(chunk)
249
+ return chunks
250
+
251
+ def embed(self, seqs):
252
+ """
253
+ seqs: List[str] of protein sequences.
254
+ Returns: np.ndarray of shape (N, D) pooled per-sequence embeddings.
255
+ """
256
+ all_embeddings = []
257
+ for i, seq in enumerate(seqs):
258
+ chunks = self._chunk_sequence(seq)
259
+ chunk_vecs = []
260
+ # process chunks in batch if small number, else sequentially
261
+ for chunk in chunks:
262
+ batch = [(str(i), chunk)]
263
+ _, _, toks = self.batch_converter(batch)
264
+ toks = toks.to(self.device)
265
+ with torch.no_grad():
266
+ results = self.model(toks, repr_layers=[33], return_contacts=False)
267
+ reps = results["representations"][33] # (1, L, D)
268
+ # remove BOS/EOS if present: take 1:-1 if length permits
269
+ if reps.size(1) > 2:
270
+ rep = reps[:, 1:-1].mean(1) # (1, D)
271
+ else:
272
+ rep = reps.mean(1) # fallback
273
+ chunk_vecs.append(rep.squeeze(0)) # (D,)
274
+ if len(chunk_vecs) == 1:
275
+ seq_vec = chunk_vecs[0]
276
+ else:
277
+ # average chunk vectors
278
+ stacked = torch.stack(chunk_vecs, dim=0) # (num_chunks, D)
279
+ seq_vec = stacked.mean(0)
280
+ all_embeddings.append(seq_vec.cpu().numpy())
281
+ return np.vstack(all_embeddings) # (N, D)
282
+
283
+
284
+ # class ESMDBPEmbedder:
285
+ # def __init__(self, device):
286
+ # base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
287
+ # model_path = (
288
+ # Path(__file__).resolve().parent.parent
289
+ # / "pretrained" / "ESM-DBP" / "ESM-DBP.model"
290
+ # )
291
+ # checkpoint = torch.load(model_path, map_location="cpu")
292
+ # clean_sd = {}
293
+ # for k, v in checkpoint.items():
294
+ # clean_sd[k.replace("module.", "")] = v
295
+ # result = base_model.load_state_dict(clean_sd, strict=False)
296
+ # if result.missing_keys:
297
+ # print(f"[ESMDBP] missing keys: {result.missing_keys}")
298
+ # if result.unexpected_keys:
299
+ # print(f"[ESMDBP] unexpected keys: {result.unexpected_keys}")
300
+
301
+ # self.model = base_model.to(device).eval()
302
+ # self.alphabet = alphabet
303
+ # self.batch_converter = alphabet.get_batch_converter()
304
+ # self.device = device
305
+
306
+ # def embed(self, seqs):
307
+ # batch = [(str(i), seq) for i, seq in enumerate(seqs)]
308
+ # _, _, toks = self.batch_converter(batch)
309
+ # toks = toks.to(self.device)
310
+ # with torch.no_grad():
311
+ # out = self.model(toks, repr_layers=[33], return_contacts=False)
312
+ # reps = out["representations"][33]
313
+ # # skip start/end tokens
314
+ # return reps[:, 1:-1].mean(1).cpu().numpy()
315
+
316
+ class ESMDBPEmbedder:
317
+ def __init__(self, device):
318
+ base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
319
+ model_path = (
320
+ Path(__file__).resolve().parent.parent
321
+ / "pretrained" / "ESM-DBP" / "ESM-DBP.model"
322
+ )
323
+ checkpoint = torch.load(model_path, map_location="cpu")
324
+ clean_sd = {}
325
+ for k, v in checkpoint.items():
326
+ clean_sd[k.replace("module.", "")] = v
327
+ result = base_model.load_state_dict(clean_sd, strict=False)
328
+ if result.missing_keys:
329
+ print(f"[ESMDBP] missing keys: {result.missing_keys}")
330
+ if result.unexpected_keys:
331
+ print(f"[ESMDBP] unexpected keys: {result.unexpected_keys}")
332
+
333
+ self.model = base_model.to(device).eval()
334
+ self.alphabet = alphabet
335
+ self.batch_converter = alphabet.get_batch_converter()
336
+ self.device = device
337
+ self.max_len = 1024 # same limit as esm1b
338
+ self.chunk_size = self.max_len - 2
339
+ self.overlap = self.chunk_size // 4
340
+
341
+ def _chunk_sequence(self, seq):
342
+ if len(seq) <= self.chunk_size:
343
+ return [seq]
344
+ step = self.chunk_size - self.overlap
345
+ chunks = []
346
+ for i in range(0, len(seq), step):
347
+ chunk = seq[i : i + self.chunk_size]
348
+ if not chunk:
349
+ break
350
+ chunks.append(chunk)
351
+ return chunks
352
+
353
+ def embed(self, seqs):
354
+ all_embeddings = []
355
+ for i, seq in enumerate(seqs):
356
+ chunks = self._chunk_sequence(seq)
357
+ chunk_vecs = []
358
+ for chunk in chunks:
359
+ batch = [(str(i), chunk)]
360
+ _, _, toks = self.batch_converter(batch)
361
+ toks = toks.to(self.device)
362
+ with torch.no_grad():
363
+ out = self.model(toks, repr_layers=[33], return_contacts=False)
364
+ reps = out["representations"][33]
365
+ if reps.size(1) > 2:
366
+ rep = reps[:, 1:-1].mean(1)
367
+ else:
368
+ rep = reps.mean(1)
369
+ chunk_vecs.append(rep.squeeze(0))
370
+ if len(chunk_vecs) == 1:
371
+ seq_vec = chunk_vecs[0]
372
+ else:
373
+ stacked = torch.stack(chunk_vecs, dim=0)
374
+ seq_vec = stacked.mean(0)
375
+ all_embeddings.append(seq_vec.cpu().numpy())
376
+ return np.vstack(all_embeddings)
377
+
378
+ class GPNEmbedder:
379
+ def __init__(self, device):
380
+ model_name = "songlab/gpn-msa-sapiens"
381
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
382
+ self.model = AutoModelForMaskedLM.from_pretrained(model_name)
383
+ self.model.to(device)
384
+ self.model.eval()
385
+ self.device = device
386
+
387
+ def embed(self, seqs):
388
+ inputs = self.tokenizer(
389
+ seqs,
390
+ return_tensors="pt",
391
+ padding=True,
392
+ truncation=True
393
+ ).to(self.device)
394
+
395
+ with torch.no_grad():
396
+ last_hidden = self.model(**inputs).last_hidden_state
397
+ return last_hidden.mean(dim=1).cpu().numpy()
398
+
399
+ class ProGenEmbedder:
400
+ def __init__(self, device):
401
+ model_name = "jinyuan22/ProGen2-base"
402
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
403
+ self.model = AutoModel.from_pretrained(model_name).to(device).eval()
404
+ self.device = device
405
+
406
+ def embed(self, seqs):
407
+ inputs = self.tokenizer(
408
+ seqs,
409
+ return_tensors="pt",
410
+ padding=True,
411
+ truncation=True
412
+ ).to(self.device)
413
+ with torch.no_grad():
414
+ last_hidden = self.model(**inputs).last_hidden_state
415
+ return last_hidden.mean(dim=1).cpu().numpy()
416
+
417
+ # ---- main pipeline ----
418
+
419
+ def get_embedder(name, device, for_dna=True):
420
+ name = name.lower()
421
+ if for_dna:
422
+ if name=="caduceus": return CaduceusEmbedder(device)
423
+ if name=="dnabert": return DNABertEmbedder(device)
424
+ if name=="nucleotide": return NucleotideTransformerEmbedder(device)
425
+ if name=="gpn": return GPNEmbedder(device)
426
+ if name=="segmentnt": return SegmentNTEmbedder(device)
427
+ else:
428
+ if name in ("esm",): return ESMEmbedder(device)
429
+ if name in ("esm-dbp","esm_dbp"): return ESMDBPEmbedder(device)
430
+ if name=="progen": return ProGenEmbedder(device)
431
+ raise ValueError(f"Unknown model {name} (for_dna={for_dna})")
432
+
433
+
434
+ def pad_token_embeddings(list_of_arrays, pad_value=0.0):
435
+ """
436
+ list_of_arrays: list of (L_i, D) numpy arrays
437
+ Returns:
438
+ padded: (N, L_max, D) array
439
+ mask: (N, L_max) boolean array where True = real token, False = padding
440
+ """
441
+ N = len(list_of_arrays)
442
+ D = list_of_arrays[0].shape[1]
443
+ L_max = max(arr.shape[0] for arr in list_of_arrays)
444
+ padded = np.full((N, L_max, D), pad_value, dtype=list_of_arrays[0].dtype)
445
+ mask = np.zeros((N, L_max), dtype=bool)
446
+ for i, arr in enumerate(list_of_arrays):
447
+ L = arr.shape[0]
448
+ padded[i, :L] = arr
449
+ mask[i, :L] = True
450
+ return padded, mask
451
+
452
+ def embed_and_save(seqs, ids, embedder, out_path):
453
+ embs = embedder.embed(seqs)
454
+
455
+ # Decide whether we got variable-length per-token outputs (list of (L, D))
456
+ is_variable_token = isinstance(embs, (list, tuple)) and len(embs) > 0 and hasattr(embs[0], "shape") and embs[0].ndim == 2
457
+
458
+ if is_variable_token:
459
+ # pad to (N, L_max, D) + mask
460
+ padded, mask = pad_token_embeddings(embs)
461
+ # Save both embeddings and mask together in an .npz for convenience
462
+ np.savez_compressed(out_path.with_suffix(".caduceus.npz"),
463
+ embeddings=padded,
464
+ mask=mask,
465
+ ids=np.array(ids, dtype=object))
466
+ else:
467
+ # fixed shape output, e.g., pooled (N, D)
468
+ array = np.vstack(embs) if isinstance(embs, list) else embs
469
+ np.save(out_path, array)
470
+ with open(out_path.with_suffix(".ids"), "w") as f:
471
+ f.write("\n".join(ids))
472
+
473
+
474
+ if __name__=="__main__":
475
+
476
+ p = argparse.ArgumentParser()
477
+ #p.add_argument("--peak_fasta", default="binding_peaks_unique.fa", help="FASTA of deduplicated binding peak sequences; if present this is used for DNA embedding instead of genome JSONs")
478
+ p.add_argument("--genome-json-dir", default=None, help="(fallback) directory of UCSC JSONs for full chromosome embedding if peak FASTA is missing or you explicitly want chromosomes")
479
+ p.add_argument("--skip-dna", action="store_true", help="if set, skip the chromosome embedding step") #if glm embeddings successful but not plm embeddings
480
+ p.add_argument("--tf-fasta", required=True, help="input TF FASTA file")
481
+ p.add_argument("--chrom-model", default="caduceus")
482
+ p.add_argument("--tf-model", default="esm-dbp")
483
+ p.add_argument("--out-dir", default="dpacman/model/embeddings")
484
+ p.add_argument("--device", default="cpu")
485
+ args = p.parse_args()
486
+
487
+ os.makedirs(args.out_dir, exist_ok=True)
488
+ device = args.device
489
+ print(device)
490
+
491
+ if not args.skip_dna:
492
+ if args.genome_json_dir is None:
493
+ dna_df = pd.read_parquet('/home/a03-akrishna/DPACMAN/dpacman/model/remap2022_crm_fimo_output_q_processed.parquet', engine='pyarrow')
494
+ #df.to_csv('/home/a03-akrishna/DPACMAN/dpacman/model/remap2022_crm_fimo_output_q_processed.csv', index=False)
495
+ peak_seqs = dna_df["dna_sequence"]
496
+ peak_ids = dna_df["ID"]
497
+ print(f"Embedding {len(peak_seqs)} binding peak sequences from processed remap data", flush=True)
498
+ dna_embedder = get_embedder(args.chrom_model, device, for_dna=True)
499
+ out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy"
500
+ embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks)
501
+
502
+ # peak_fasta = Path(args.peak_fasta)
503
+ # if peak_fasta.exists():
504
+ # # Load peak sequences from FASTA
505
+ # from Bio import SeqIO
506
+
507
+ # peak_seqs = []
508
+ # peak_ids = []
509
+ # for rec in SeqIO.parse(peak_fasta, "fasta"):
510
+ # peak_ids.append(rec.id)
511
+ # peak_seqs.append(str(rec.seq))
512
+ # print(f"Embedding {len(peak_seqs)} binding peak sequences from {peak_fasta}", flush=True)
513
+ # dna_embedder = get_embedder(args.chrom_model, device, for_dna=True)
514
+ # out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy"
515
+ # embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks)
516
+ elif args.genome_json_dir:
517
+ # Legacy: load full chromosomes from JSONs (chr1–22, X, Y, M)
518
+ genome_dir = Path(args.genome_json_dir)
519
+ chrom_seqs, chrom_ids = [], []
520
+ primary_pattern = re.compile(r"^hg38_chr(?:[1-9]|1[0-9]|2[0-2]|X|Y|M)\.json$")
521
+ for j in sorted(genome_dir.iterdir()):
522
+ if not primary_pattern.match(j.name):
523
+ continue
524
+ data = json.loads(j.read_text())
525
+ seq = data.get("dna") or data.get("sequence")
526
+ chrom = data.get("chrom") or j.stem.split("_")[-1]
527
+ chrom_seqs.append(seq)
528
+ chrom_ids.append(chrom)
529
+ cutoff = CaduceusEmbedder(device).chunk_size
530
+ long_chroms = [
531
+ (chrom, len(seq))
532
+ for chrom, seq in zip(chrom_ids, chrom_seqs)
533
+ if len(seq) > cutoff
534
+ ]
535
+ if long_chroms:
536
+ print("⚠️ Chromosomes exceeding Caduceus max tokens ({}):".format(cutoff))
537
+ for chrom, L in long_chroms:
538
+ print(f" {chrom}: {L} bases")
539
+ else:
540
+ print("All chromosomes ≤ Caduceus limit ({}).".format(cutoff))
541
+
542
+ chrom_embedder = get_embedder(args.chrom_model, device, for_dna=True)
543
+ out_chrom = Path(args.out_dir) / f"chrom_{args.chrom_model}.npy"
544
+ embed_and_save(chrom_seqs, chrom_ids, chrom_embedder, out_chrom)
545
+ else:
546
+ raise ValueError("No input for DNA embedding: provide a peak FASTA (default binding_peaks_unique.fa) or set --genome-json-dir for chromosome JSONs.")
547
+
548
+
549
+ #Load TF sequences
550
+ tf_seqs, tf_ids = [], []
551
+ for record in SeqIO.parse(args.tf_fasta, "fasta"):
552
+ tf_ids.append(record.id)
553
+ tf_seqs.append(str(record.seq))
554
+
555
+ # embed and save
556
+ tf_embedder = get_embedder(args.tf_model, device, for_dna=False)
557
+ out_tf = Path(args.out_dir) / f"tf_{args.tf_model}.npy"
558
+ embed_and_save(tf_seqs, tf_ids, tf_embedder, out_tf)
559
+
560
+ print("Done.")
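Note on consuming the files written by embed_and_save above: pooled models produce a plain .npy plus a sibling .ids file, while the variable-length Caduceus output is padded by pad_token_embeddings and stored as a compressed .npz holding embeddings, mask, and ids. A minimal loading sketch (the file name peaks_caduceus.caduceus.npz is an assumption derived from the out_peaks path and the with_suffix call above):

import numpy as np

# Load the padded per-token embeddings saved by embed_and_save (assumed path).
data = np.load("embeddings/peaks_caduceus.caduceus.npz", allow_pickle=True)
emb, mask, ids = data["embeddings"], data["mask"], data["ids"]  # (N, L_max, D), (N, L_max), (N,)

# Masked mean-pool each sequence to one vector, ignoring the zero padding.
lengths = mask.sum(axis=1, keepdims=True)               # real-token count per sequence
pooled = (emb * mask[..., None]).sum(axis=1) / lengths  # (N, D)
print(pooled.shape, ids[:3])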
dpacman/data_tasks/fimo/__init__.py ADDED
File without changes
dpacman/data_tasks/fimo/post_fimo.py CHANGED
@@ -8,6 +8,7 @@ import multiprocessing as mp
8
  import numpy as np
9
  import pandas as pd
10
  import math
 
11
  import rootutils
12
  import polars as pl
13
  from omegaconf import DictConfig
@@ -37,10 +38,18 @@ def format_sig(sig_vals, decimals=4, atol=0.0, rtol=1e-5):
37
  return ",".join(out.tolist())
38
 
39
  def _safe_process(task):
 
 
 
 
 
 
 
40
  try:
41
- return ("ok", _process_one_chrom_folder(task))
 
42
  except Exception as e:
43
- return ("err", (task[0], repr(e), traceback.format_exc()))
44
 
45
  def discover_chrom_folders(fimo_out_dir: Path) -> list[str]:
46
  return sorted(
@@ -51,6 +60,10 @@ def discover_chrom_folders(fimo_out_dir: Path) -> list[str]:
51
  def _process_one_row(row, dna: str, jaspar_boost: int = 100) -> dict:
52
  # row order: TR, chrom, cstart, cend, peak_s, peak_e, chipscore, jaspar
53
  trname, chrom, cstart, cend, peak_s, peak_e, chipscore, jaspar = row
 
 
 
 
54
 
55
  seq = dna[cstart:cend]
56
  L = len(seq)
@@ -62,7 +75,7 @@ def _process_one_row(row, dna: str, jaspar_boost: int = 100) -> dict:
62
  peak_seq = ""
63
  if ps < L and pe > 0:
64
  scores[max(ps, 0):min(pe, L)] = chipscore
65
- peak_seq = dna[max(ps, 0):min(pe, L)]
66
 
67
  # JASPAR hits (+jaspar_boost)
68
  # only run if the peak is not np.nan
@@ -92,7 +105,7 @@ def _process_one_row(row, dna: str, jaspar_boost: int = 100) -> dict:
92
 
93
  def _process_one_chrom_folder(task) -> pd.DataFrame:
94
  """Runs inside a worker process. Reads one chrom’s final.csv, loads DNA once, builds records."""
95
- chrom_folder, fimo_out_dir_str, json_dir, jaspar_boost, output_parts_folder = task
96
 
97
  # make unique logger for this process
98
  log_dir = Path(HydraConfig.get().run.dir) / "logs"
@@ -120,6 +133,13 @@ def _process_one_chrom_folder(task) -> pd.DataFrame:
120
 
121
  if df.empty:
122
  return pd.DataFrame()
 
 
 
 
 
 
 
123
 
124
  # Normalize dtypes up-front
125
  df["#chrom"] = df["#chrom"].astype(str)
@@ -176,170 +196,342 @@ def _process_one_chrom_folder(task) -> pd.DataFrame:
176
 
177
  return savepath
178
 
179
- def build_dataset_fast_mp(fimo_out_dir: Path, json_dir: str, debug: bool, max_workers: int | None, jaspar_boost: int = 100, output_parts_folder: str = None) -> pd.DataFrame:
180
  """
181
  Multiprocessing to build final dataset across chromosomes
182
  """
183
  chrom_folders = discover_chrom_folders(fimo_out_dir)
184
  if not chrom_folders:
185
  logger.warning(f"No chrom* folders with final.csv under {fimo_out_dir}")
186
- return pd.DataFrame()
187
 
188
  if debug:
189
- chrom_folders = ["chromY"] if "chromY" in chrom_folders else chrom_folders[:1]
 
190
  logger.info(f"DEBUG MODE: considering {chrom_folders[0]} only")
191
 
192
- tasks = [(cf, str(fimo_out_dir), json_dir, jaspar_boost, output_parts_folder) for cf in chrom_folders]
 
193
 
194
- # serial path (debug/deterministic)
195
- if max_workers is not None and max_workers <= 1 or len(tasks) == 1:
196
- parts = []
197
- errs = []
 
 
 
 
 
 
 
 
 
 
 
198
  for t in tasks:
199
  status, payload = _safe_process(t)
200
- if status == "ok" and isinstance(payload, pd.DataFrame) and not payload.empty:
201
- parts.append(payload)
202
- else:
203
- errs.append(payload)
204
- if errs:
205
- for chrom, msg, tb in errs:
206
- logger.error("Worker error for %s: %s\n%s", chrom, msg, tb)
207
- # raise after serial run if you want hard failure
208
- return pd.concat(parts, ignore_index=True) if parts else pd.DataFrame()
209
-
210
- # parallel path
211
- procs = min(max_workers or mp.cpu_count(), len(tasks))
212
- logger.info(f"Using {procs} parallel workers for {len(tasks)} chrom folders")
213
-
214
- paths, errs = [], []
215
- with mp.Pool(processes=procs, maxtasksperchild=10) as pool:
216
- for status, payload in pool.imap_unordered(_safe_process, tasks, chunksize=1):
217
- if status == "ok" and isinstance(payload, pd.DataFrame) and not payload.empty:
218
- paths.append(payload)
219
- else:
220
- errs.append(payload)
221
-
222
  if errs:
223
  for chrom, msg, tb in errs:
224
  logger.error("Worker error for %s: %s\n%s", chrom, msg, tb)
225
- # optional: raise RuntimeError("One or more workers failed, see logs.")
 
 
 
 
 
 
 
 
226
 
227
- paths = [p for p in parts if os.path.exists(p)]
228
- return paths
 
 
 
 
 
 
 
 
229
 
230
  def combine_processed_with_polars(
231
  paths_to_processed_dfs: list[str],
232
  idmap_path: str, # TSV with columns: From, Entry, Sequence
233
  out_path: str, # e.g., "processed_out.parquet" or ".csv"
 
 
 
234
  ):
235
  if not paths_to_processed_dfs:
236
  logger.info("No records produced; nothing to write.")
237
  return
238
 
239
- # 1) Scan all CSVs lazily (no full read)
240
- lfs = [pl.scan_csv(p, infer_schema_length=0) for p in paths_to_processed_dfs]
241
- lf = pl.concat(lfs, how="vertical")
242
- logger.info(f"Scanned CSVs")
243
-
244
- # 2) Drop duplicate occurrences of tr_name and peak_sequence, because these are the same peak
245
- lf = lf.unique(subset=["tr_name", "peak_sequence"], keep="first", maintain_order=True)
246
- logger.info(f"Dropped duplicate examples of tr_name + peak_sequence")
 
 
 
 
 
 
 
 
 
 
 
247
 
248
- # 3) Join small idmap (read eagerly; it’s tiny)
249
  idmap = (
250
  pl.read_csv(idmap_path, separator="\t", columns=["From", "Entry", "Sequence"])
251
- .rename({"From": "tr_name", "Entry": "tr_uniprot", "Sequence": "tr_sequence"})
252
  )
253
- lf = lf.join(idmap.lazy(), on="tr_name", how="left")
254
- logger.info(f"Merged in UniProt IDs and TR sequences from UniProt ID mappping")
 
 
 
 
 
 
255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  # 4) Per-chromosome unique peak index and peak_id
257
  # (dense rank over peak_sequence per chrom; if you require "first-appearance" order,
258
  # see the note below for an alternate approach.)
259
- lf = lf.with_columns([
 
 
 
 
 
 
 
 
260
  pl.col("peak_sequence").fill_null("").alias("peak_sequence"),
261
  pl.col("chrom").cast(pl.Utf8),
262
  ])
263
- lf = lf.with_columns(
264
  pl.col("peak_sequence")
265
- .rank(method="dense") # 1,2,3,... per group
266
- .over("chrom")
267
- .cast(pl.Int64)
268
- .alias("chrom_peak_idx")
269
  )
270
- lf = lf.with_columns(
271
- pl.format("chrom{}_peak{}", pl.col("chrom"), pl.col("chrom_peak_idx")).alias("peak_id")
272
  )
273
- logger.info(f"Assigned unique peak_ids per chromosome based on peak_sequence")
274
 
275
  # 5) Build stable IDs for dna_sequence and tr_sequence based on first appearance
276
  # (do this by creating small maps with unique(..., maintain_order=True) and joining)
277
- dna_map = (
278
- lf.select("dna_sequence")
279
- .unique(maintain_order=True)
280
- .with_row_index("dna_idx", offset=1)
281
- .with_columns(pl.format("dnaseq{}", pl.col("dna_idx")).alias("dna_seqid"))
282
- .select("dna_sequence", "dna_seqid")
 
 
 
 
 
 
 
283
  )
284
- logger.info(f"Assigned dna_sequence IDs")
285
- tr_map = (
286
- lf.select("tr_sequence")
287
- .unique(maintain_order=True)
288
- .with_row_index("tr_idx", offset=1)
289
- .with_columns(pl.format("trseq{}", pl.col("tr_idx")).alias("tr_seqid"))
290
- .select("tr_sequence", "tr_seqid")
291
- )
292
- logger.info(f"Assigned tr_sequence IDs")
293
- lf = lf.join(dna_map, on="dna_sequence", how="left").join(tr_map, on="tr_sequence", how="left")
294
- logger.info(f"Applied dna_sequence and tr_sequence IDs to main table")
295
-
296
- # 6) Final ID and column selection
297
- lf = lf.with_columns(
298
- (pl.col("tr_seqid") + pl.lit("_") + pl.col("dna_seqid")).alias("ID")
299
- )
300
- cols = [
301
- "ID", "tr_name", "peak_id", "chipscore", "total_jaspar_hits",
302
- "dna_sequence", "tr_sequence", "scores"
303
- ]
304
- lf_out = lf.select(cols)
305
- #n_rows = lf_out.select(pl.len().alias("rows")).collect(streaming=True)["rows"][0]
306
- logger.info(f"Selected final columns")
307
-
308
- # 7) Write streaming to disk
309
- out_path = str(out_path)
310
- Path(out_path).parent.mkdir(parents=True, exist_ok=True)
311
 
312
- if out_path.lower().endswith(".parquet"):
313
- lf_out.sink_parquet(out_path, compression="zstd", statistics=True, row_group_size=128_000)
314
- logger.info(f"Wrote parquet file to {out_path}")
315
- elif out_path.lower().endswith(".csv"):
316
- # NOTE: collect(streaming=True) still returns an in-memory DataFrame;
317
- # prefer Parquet for very large outputs.
318
- lf_out.collect(streaming=True).write_csv(out_path)
319
- logger.info(f"Wrote csv file to {out_path}")
320
- else:
321
- # default to Parquet if no/unknown extension
322
- lf_out.sink_parquet(out_path + ".parquet", compression="zstd", statistics=True)
323
- logger.info(f"Wrote parquet file to {out_path}")
324
-
325
- # 8) (Optional) small summary: unique peaks per chrom
326
- peaks_per_chrom = (
327
- lf.select("chrom", "peak_sequence")
328
- .unique()
329
- .group_by("chrom")
330
- .len()
331
- .collect(streaming=True)
332
- .sort("chrom")
333
  )
334
- logger.info(f"Summary per chromosome:\n{peaks_per_chrom}")
 
 
 
 
 
 
 
 
335
 
336
- logger.info("Schema:")
337
- for name, dtype in lf_out.schema.items():
338
- logger.info(f" {name}: {dtype}")
339
 
340
- # Quick preview of a few rows (safe)
341
- logger.info("\nHead(5):")
342
- logger.info(lf_out.head(5).collect()) # or: lf.limit(5).collect()
 
 
 
 
 
 
 
 
343
 
344
  # Save the FIRST 1000 rows to CSV (streaming-friendly)
345
  df_first = lf_out.limit(1000).collect(streaming=True)
@@ -347,6 +539,197 @@ def combine_processed_with_polars(
347
  df_first.write_csv(example_out_path)
348
  logger.info(f"Wrote first 1000 rows to {example_out_path} as an example")
349
 
350
  def main(cfg: DictConfig):
351
  debug = bool(cfg.data_task.debug)
352
  json_dir = cfg.data_task.json_dir
@@ -357,24 +740,43 @@ def main(cfg: DictConfig):
357
 
358
  logger.info(f"Debug: {debug}")
359
  logger.info(f"Reading per-chrom final.csv under: {fimo_out_dir}")
 
 
 
 
360
 
361
- if False:
 
362
  paths_to_processed_dfs = build_dataset_fast_mp(
363
  fimo_out_dir=fimo_out_dir,
364
  json_dir=json_dir,
365
  debug=debug,
366
  max_workers=max_workers,
367
  jaspar_boost=cfg.data_task.jaspar_boost,
368
- output_parts_folder=output_parts_folder
 
369
  )
370
  else:
371
  paths_to_processed_dfs = [output_parts_folder/x for x in os.listdir(output_parts_folder)] if output_parts_folder.exists() else []
372
-
 
 
 
 
 
 
 
 
 
 
 
373
  logger.info(f"Combining {len(paths_to_processed_dfs)} processed parts with Polars")
374
  combine_processed_with_polars(
375
  paths_to_processed_dfs=paths_to_processed_dfs,
376
- idmap_path=Path(root) / cfg.data_task.idmap_path,
377
- out_path=str(processed_output_csv).replace(".csv", ".parquet")
 
 
378
  )
379
 
380
  # Delete the folder that had the temporary DFs, don't need these
 
8
  import numpy as np
9
  import pandas as pd
10
  import math
11
+ import json
12
  import rootutils
13
  import polars as pl
14
  from omegaconf import DictConfig
 
38
  return ",".join(out.tolist())
39
 
40
  def _safe_process(task):
41
+ """
42
+ Returns:
43
+ ("ok", <path-to-output>) on success
44
+ ("err", (chrom, msg, traceback)) on failure
45
+ """
46
+ import traceback as tb
47
+ chrom = task[0]
48
  try:
49
+ out_path = _process_one_chrom_folder(task) # MUST return a path (str/Path)
50
+ return ("ok", str(out_path))
51
  except Exception as e:
52
+ return ("err", (chrom, repr(e), tb.format_exc()))
53
 
54
  def discover_chrom_folders(fimo_out_dir: Path) -> list[str]:
55
  return sorted(
 
60
  def _process_one_row(row, dna: str, jaspar_boost: int = 100) -> dict:
61
  # row order: TR, chrom, cstart, cend, peak_s, peak_e, chipscore, jaspar
62
  trname, chrom, cstart, cend, peak_s, peak_e, chipscore, jaspar = row
63
+
64
+ # very few chipscores are > 1000. standardize by setting >1000 to a max score
65
+ if chipscore>=1000:
66
+ chipscore=1000
67
 
68
  seq = dna[cstart:cend]
69
  L = len(seq)
 
75
  peak_seq = ""
76
  if ps < L and pe > 0:
77
  scores[max(ps, 0):min(pe, L)] = chipscore
78
+ peak_seq = seq[max(ps, 0):min(pe, L)]
79
 
80
  # JASPAR hits (+jaspar_boost)
81
  # only run if the peak is not np.nan
 
105
 
106
  def _process_one_chrom_folder(task) -> pd.DataFrame:
107
  """Runs inside a worker process. Reads one chrom’s final.csv, loads DNA once, builds records."""
108
+ chrom_folder, fimo_out_dir_str, json_dir, jaspar_boost, output_parts_folder, keep_fimo_only = task
109
 
110
  # make unique logger for this process
111
  log_dir = Path(HydraConfig.get().run.dir) / "logs"
 
133
 
134
  if df.empty:
135
  return pd.DataFrame()
136
+
137
+ if keep_fimo_only:
138
+ logger.info(f"keep_fimo_only=True. Starting with {len(df)} rows.")
139
+ df = df.loc[
140
+ ~df["jaspar"].isna()
141
+ ].reset_index(drop=True)
142
+ logger.info(f"After keeping fimo hits only: {len(df)} rows remain.")
143
 
144
  # Normalize dtypes up-front
145
  df["#chrom"] = df["#chrom"].astype(str)
 
196
 
197
  return savepath
198
 
199
+ def build_dataset_fast_mp(fimo_out_dir: Path, json_dir: str, debug: bool, max_workers: int | None, jaspar_boost: int = 100, output_parts_folder: str = None, keep_fimo_only: bool=False) -> pd.DataFrame:
200
  """
201
  Multiprocessing to build final dataset across chromosomes
202
  """
203
  chrom_folders = discover_chrom_folders(fimo_out_dir)
204
  if not chrom_folders:
205
  logger.warning(f"No chrom* folders with final.csv under {fimo_out_dir}")
206
+ return []
207
 
208
  if debug:
209
+ # keep chromY if present, otherwise just the first
210
+ chrom_folders = [c for c in chrom_folders if c == "chromY"] or chrom_folders[:1]
211
  logger.info(f"DEBUG MODE: considering {chrom_folders[0]} only")
212
 
213
+ tasks = [(cf, str(fimo_out_dir), json_dir, jaspar_boost, output_parts_folder, keep_fimo_only)
214
+ for cf in chrom_folders]
215
 
216
+ def _collect(status, payload, good_paths, errs):
217
+ if status == "ok":
218
+ p = Path(payload)
219
+ if p.exists():
220
+ good_paths.append(p)
221
+ else:
222
+ errs.append(("?", f"output missing: {p}", ""))
223
+ else:
224
+ chrom, msg, tb = payload
225
+ errs.append((chrom, msg, tb))
226
+
227
+ # Serial path (debug/deterministic or single task)
228
+ if (max_workers is not None and max_workers <= 1) or len(tasks) == 1:
229
+ good_paths: list[Path] = []
230
+ errs: list[tuple[str, str, str]] = []
231
  for t in tasks:
232
  status, payload = _safe_process(t)
233
+ _collect(status, payload, good_paths, errs)
234
+ else:
235
+ # Parallel path
236
+ procs = min(max_workers or mp.cpu_count(), len(tasks))
237
+ logger.info(f"Using {procs} parallel workers for {len(tasks)} chrom folders")
238
+
239
+ good_paths: list[Path] = []
240
+ errs: list[tuple[str, str, str]] = []
241
+ with mp.Pool(processes=procs, maxtasksperchild=10) as pool:
242
+ for status, payload in pool.imap_unordered(_safe_process, tasks, chunksize=1):
243
+ _collect(status, payload, good_paths, errs)
244
+
 
 
 
 
 
 
 
 
 
 
245
  if errs:
246
  for chrom, msg, tb in errs:
247
  logger.error("Worker error for %s: %s\n%s", chrom, msg, tb)
248
+ # Optionally: raise RuntimeError("One or more workers failed; see logs.")
249
+
250
+ return [str(p) for p in good_paths]
251
+
252
+ def dedup_trname_peakseq_weighted(lf: pl.LazyFrame, seed: int = 42, outdir: str | None = None) -> pl.LazyFrame:
253
+ """
254
+ Remove duplicate pairings of TR + peak sequence, but keep the distribution of chromosomes as best as possible.
255
+ Use a seed so the results are reproducible
256
+ """
257
+ # Normalize key dtypes
258
+ lf = lf.with_columns([
259
+ pl.col("chrom").cast(pl.Utf8),
260
+ pl.col("tr_name").cast(pl.Utf8),
261
+ pl.col("peak_sequence").fill_null("<NULL>").cast(pl.Utf8),
262
+ ])
263
+
264
+ # --- BEFORE: counts/ratios (materialize tiny table)
265
+ pre_df = (
266
+ lf.group_by("chrom").len()
267
+ .with_columns((pl.col("len") / pl.col("len").sum()).alias("pre_ratio"))
268
+ .sort("chrom")
269
+ .collect()
270
+ )
271
+
272
+ # Expected #groups (must equal result rows)
273
+ exp_groups = lf.select(["tr_name","peak_sequence"]).unique().collect().height
274
+ logger.info(f"Expected groups: {exp_groups}")
275
+
276
+ # Tiny weights table back to lazy
277
+ pre_lf = pre_df.lazy().select(["chrom", "pre_ratio"]).rename({"pre_ratio": "w"})
278
+
279
+ # Weighted random score = log(w) + Gumbel(0,1); tie-break by a stable hash
280
+ TWO64 = 18446744073709551616.0
281
+ eps = 1e-12
282
+
283
+ h_expr = pl.concat_str(
284
+ [pl.lit(f"seed:{seed}"), pl.col("tr_name"), pl.col("peak_sequence"), pl.col("chrom")],
285
+ separator="|",
286
+ ).hash().cast(pl.UInt64)
287
+
288
+ u_expr = (h_expr.cast(pl.Float64) + 1.0) / pl.lit(TWO64)
289
+ u_expr = pl.when(u_expr < eps).then(eps).when(u_expr > 1 - eps).then(1 - eps).otherwise(u_expr)
290
 
291
+ logw_expr = pl.when(pl.col("w").is_null() | (pl.col("w") <= 0)).then(eps).otherwise(pl.col("w")).log()
292
+ gumbel_expr = -(-u_expr.log()).log()
293
+ score_expr = (logw_expr + gumbel_expr).alias("_score")
294
+ hash_expr = h_expr.alias("_h")
295
+
296
+ # Attach weights & scores, globally sort, then unique on the keys (keep first)
297
+ lf_sorted = (
298
+ lf.join(pre_lf, on="chrom", how="left")
299
+ .with_columns([score_expr, hash_expr])
300
+ .sort(["_score","_h"], descending=[True, False])
301
+ )
302
+
303
+ lf_sel = lf_sorted.unique(subset=["tr_name","peak_sequence"], keep="first") \
304
+ .drop(["w","_score","_h"])
305
+
306
+ # --- AFTER: counts/ratios + save
307
+ post_df = (
308
+ lf_sel.group_by("chrom").len()
309
+ .with_columns((pl.col("len") / pl.col("len").sum()).alias("post_ratio"))
310
+ .sort("chrom")
311
+ .collect(streaming=True)
312
+ )
313
+ compare_df = (pre_df.select(["chrom","len","pre_ratio"]).rename({"len":"pre_n"})
314
+ .join(post_df.select(["chrom","len","post_ratio"]).rename({"len":"post_n"}),
315
+ on="chrom", how="full")
316
+ .fill_null(0)
317
+ .with_columns((pl.col("post_ratio") - pl.col("pre_ratio")).abs().alias("abs_delta"),
318
+ (100*(pl.col("post_ratio") - pl.col("pre_ratio"))/pl.col("pre_ratio")).abs().alias("pcnt_delta"))
319
+ .sort("chrom")
320
+ ).to_pandas().drop(columns=["chrom_right"])
321
+ # --- Sanity: must keep exactly one per group
322
+ got_rows = lf_sel.select(pl.len()).collect()["len"][0]
323
+ if got_rows != exp_groups:
324
+ # optional: raise or just log
325
+ logger.warning(f"Dedup cardinality mismatch: expected {exp_groups}, got {got_rows}")
326
+
327
+ return lf_sel, compare_df
328
+
329
+ def write_map(lf: pl.LazyFrame, out_path: str, key: str, val: str, outname: str):
330
+ """
331
+ Write the ID maps we created, spanning all the data. Will be called for:
332
+ - tr_seqid to tr_sequence
333
+ - peak_seqid to peak_sequence
334
+ - dna_seqid to dna_sequence
335
+ """
336
+ maps_dir = Path(out_path).parent / "maps"
337
+ maps_dir.mkdir(parents=True, exist_ok=True)
338
+
339
+ df = (
340
+ lf.select([pl.col(key), pl.col(val)])
341
+ .unique()
342
+ .collect(streaming=True)
343
+ )
344
+ mapping = dict(zip(df[key].to_list(), df[val].to_list()))
345
+ with open(maps_dir / outname, "w") as f:
346
+ json.dump(mapping, f, indent=2)
347
+
348
 
349
  def combine_processed_with_polars(
350
  paths_to_processed_dfs: list[str],
351
  idmap_path: str, # TSV with columns: From, Entry, Sequence
352
  out_path: str, # e.g., "processed_out.parquet" or ".csv"
353
+ max_protein_len: int = None,
354
+ check_violations: bool = False,
355
+ seeds: list = [0]
356
  ):
357
  if not paths_to_processed_dfs:
358
  logger.info("No records produced; nothing to write.")
359
  return
360
 
361
+ # 1) Scan each CSV and normalize dtypes BEFORE concat
362
+ lfs = []
363
+ for p in paths_to_processed_dfs:
364
+ lf_i = (
365
+ pl.scan_csv(p) # don't use infer_schema_length=0 here
366
+ .with_columns([
367
+ pl.col("chrom").cast(pl.Utf8),
368
+ pl.col("tr_name").cast(pl.Utf8),
369
+ pl.col("dna_sequence").cast(pl.Utf8),
370
+ pl.col("peak_sequence").cast(pl.Utf8),
371
+ pl.col("scores").cast(pl.Utf8),
372
+ pl.col("chipscore").cast(pl.Float64), # robust (int -> float OK)
373
+ pl.col("total_jaspar_hits").cast(pl.Int64),
374
+ ])
375
+ )
376
+ lfs.append(lf_i)
377
+
378
+ # 2) Now concat; schemas match
379
+ lf_og = pl.concat(lfs, how="vertical")
380
 
381
+ # Read idmap to get list of unmapped TRs, those with no sequence
382
  idmap = (
383
  pl.read_csv(idmap_path, separator="\t", columns=["From", "Entry", "Sequence"])
384
+ .rename({"From": "tr_name", "Entry": "tr_uniprot", "Sequence": "tr_sequence"})
385
  )
386
+ idmap = idmap.with_columns(
387
+ pl.col("tr_sequence").map_elements(lambda x: len(x), return_dtype=pl.Int64).alias("tr_len")
388
+ )
389
+ if max_protein_len is not None:
390
+ idmap = idmap.filter(
391
+ pl.col("tr_len")<=max_protein_len
392
+ )
393
+ logger.info(f"Filtered valid TRs to only those with len <= {max_protein_len}")
394
 
395
+ success_trs = list(idmap["tr_name"].unique())
396
+ logger.info(f"Total valid TRs: {len(success_trs)}")
397
+
398
+ # Filter lf to only have "success-TRs"
399
+ lf_og = lf_og.filter(pl.col("tr_name").is_in(success_trs))
400
+
401
+ # 2) We COULD drop duplicate occurrences of tr_name and peak_sequence, because these are the same peak. But here we're showing them in different contexts
402
+ # Instead, let's drop duplicate occurrences of the same tr_name and dna_sequence, because these are duplicate datapoints.
403
+ lf_out = None # the last one will be used to save the example file
404
+ out_path = str(out_path)
405
+ Path(out_path).parent.mkdir(parents=True, exist_ok=True)
406
+
407
+ lf_og = lf_og.join(idmap.lazy(), on="tr_name", how="left")
408
+ logger.info(f"Merged in UniProt IDs and TR sequences from UniProt ID mappping")
409
+
410
  # 4) Per-chromosome unique peak index and peak_id
411
  # (dense rank over peak_sequence per chrom; if you require "first-appearance" order,
412
  # see the note below for an alternate approach.)
413
+ # Ensure types
414
+ lf_og = lf_og.with_columns([
415
+ pl.col("dna_sequence").cast(pl.Utf8),
416
+ pl.col("tr_sequence").cast(pl.Utf8),
417
+ pl.col("peak_sequence").cast(pl.Utf8)
418
+ ])
419
+
420
+ # set chrom+ peak sequence IDs
421
+ lf_og = lf_og.with_columns([
422
  pl.col("peak_sequence").fill_null("").alias("peak_sequence"),
423
  pl.col("chrom").cast(pl.Utf8),
424
  ])
425
+ lf_og = lf_og.with_columns(
426
  pl.col("peak_sequence")
427
+ .rank(method="dense") # 1,2,3,... per group
428
+ .over("chrom")
429
+ .cast(pl.Int64)
430
+ .alias("chrom_peak_idx")
431
  )
432
+ lf_og = lf_og.with_columns(
433
+ pl.format("chr{}_peak{}", pl.col("chrom"), pl.col("chrom_peak_idx")).alias("chrpeak_id")
434
  )
435
+ logger.info(f"Assigned unique chrpeak_ids per chromosome based on peak_sequence")
436
 
437
  # 5) Build stable IDs for dna_sequence and tr_sequence based on first appearance
438
  # (do this by creating small maps with unique(..., maintain_order=True) and joining)
439
+ # Sequence-based IDs without any joins
440
+ lf_og = (
441
+ lf_og.with_columns([
442
+ pl.col("dna_sequence").rank(method="dense").cast(pl.Int64).alias("dna_idx"),
443
+ pl.col("tr_sequence").rank(method="dense").cast(pl.Int64).alias("tr_idx"),
444
+ pl.col("peak_sequence").rank(method="dense").cast(pl.Int64).alias("peak_idx"),
445
+ ])
446
+ .with_columns([
447
+ pl.format("dnaseq{}", pl.col("dna_idx")).alias("dna_seqid"),
448
+ pl.format("trseq{}", pl.col("tr_idx")).alias("tr_seqid"),
449
+ pl.format("peakseq{}", pl.col("peak_idx")).alias("peak_seqid"),
450
+ ])
451
+ .drop(["dna_idx", "tr_idx", "peak_idx"])
452
  )
453
+ logger.info(f"Assigned unique dna IDs, transcriptional regulator IDs, and peak IDs")
 
 
 
 
 
 
454
 
455
+ # Final ID (will never be None now)
456
+ lf_og = lf_og.with_columns(
457
+ pl.concat_str([pl.col("tr_seqid"), pl.lit("_"), pl.col("dna_seqid")], ignore_nulls=False)
458
+ .alias("ID")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
459
  )
460
+
461
+ # Write the maps
462
+ # call it for each mapping
463
+ write_map(lf_og, out_path=out_path, val="tr_sequence", key="tr_seqid", outname="tr_seqid_to_tr_sequence.json")
464
+ write_map(lf_og, out_path=out_path, val="peak_sequence", key="peak_seqid", outname="peak_seqid_to_peak_sequence.json")
465
+ write_map(lf_og, out_path=out_path, val="dna_sequence", key="dna_seqid", outname="dna_seqid_to_dna_sequence.json")
466
+
467
+ for seed in seeds:
468
+ # edit out path to include seed
469
+ if "." in out_path:
470
+ out_path_full = out_path[0:out_path.rindex(".")] + f"_seed{seed}" + out_path[out_path.rindex("."):]
471
+ else:
472
+ out_path_full = out_path + f"_seed{seed}.parquet"
473
+
474
+ lf, compare_df = dedup_trname_peakseq_weighted(lf_og, seed=seed)
475
+ logger.info(f"Dropped duplicate examples of tr_name + peak_sequence. Maintained chrom distribution with weighted random sampling (seed={seed}).")
476
+
477
+ # Save comparison df. Annotate with debug if it's a debug run
478
+ compare_df_path = str(Path(out_path).parent/f"chrom_ratio_compare_seed{seed}.csv")
479
+ if "debug" in out_path: compare_df_path = compare_df_path.replace(".csv", "_debug.csv")
480
+ compare_df.to_csv(compare_df_path, index=False)
481
+
482
+ # 3) Join small idmap (read eagerly; it’s tiny)
483
+ lf = lf.join(idmap.lazy(), on="tr_name", how="left")
484
+ logger.info(f"Merged in UniProt IDs and TR sequences from UniProt ID mappping")
485
 
 
 
 
486
 
487
+ logger.info(f"Applied dna_sequence and tr_sequence IDs to main table")
488
+
489
+ # Each sequence maps to exactly one id
490
+ viol1 = (
491
+ lf.select("dna_sequence", "dna_seqid").unique()
492
+ .group_by("dna_sequence").agg(pl.n_unique("dna_seqid").alias("n_ids"))
493
+ .filter(pl.col("n_ids") > 1)
494
+ .collect()
495
+ )
496
+ # Each id maps to exactly one sequence
497
+ viol2 = (
498
+ lf.select("dna_sequence", "dna_seqid").unique()
499
+ .group_by("dna_seqid").agg(pl.n_unique("dna_sequence").alias("n_seqs"))
500
+ .filter(pl.col("n_seqs") > 1)
501
+ .collect()
502
+ )
503
+ logger.info("viol1 rows (seq→>1 id): %d; viol2 rows (id→>1 seq): %d", viol1.height, viol2.height)
504
+
505
+ # No NULLs
506
+ nulls = lf.select([
507
+ pl.col("dna_seqid").is_null().sum().alias("null_dna_seqid"),
508
+ pl.col("tr_seqid").is_null().sum().alias("null_tr_seqid"),
509
+ pl.col("ID").is_null().sum().alias("null_ID"),
510
+ ]).collect()
511
+ logger.info("NULL counts:\n%s", nulls)
512
+
513
+ # 6) Final column selection
514
+ cols = [
515
+ "ID", "tr_seqid", "dna_seqid", "peak_seqid", "chrpeak_id", "tr_name", "chipscore", "total_jaspar_hits",
516
+ "dna_sequence", "tr_sequence", "scores"
517
+ ]
518
+ lf_out = lf.select(cols)
519
+ #n_rows = lf_out.select(pl.len().alias("rows")).collect(streaming=True)["rows"][0]
520
+ logger.info(f"Selected final columns")
521
+
522
+ # 7) Write streaming to disk
523
+ if out_path_full.lower().endswith(".parquet"):
524
+ lf_out.sink_parquet(out_path_full, compression="zstd", statistics=True, row_group_size=128_000)
525
+ logger.info(f"Wrote parquet file to {out_path_full}")
526
+ elif out_path_full.lower().endswith(".csv"):
527
+ # NOTE: collect(streaming=True) still returns an in-memory DataFrame;
528
+ # prefer Parquet for very large outputs.
529
+ lf_out.collect(streaming=True).write_csv(out_path_full)
530
+ logger.info(f"Wrote csv file to {out_path_full}")
531
+ else:
532
+ # default to Parquet if no/unknown extension
533
+ lf_out.sink_parquet(out_path_full + ".parquet", compression="zstd", statistics=True)
534
+ logger.info(f"Wrote parquet file to {out_path_full}")
535
 
536
  # Save the FIRST 1000 rows to CSV (streaming-friendly)
537
  df_first = lf_out.limit(1000).collect(streaming=True)
 
539
  df_first.write_csv(example_out_path)
540
  logger.info(f"Wrote first 1000 rows to {example_out_path} as an example")
541
 
542
+ # FIMO check
543
+ def get_reverse_complement(s):
544
+ """
545
+ Returns 5' to 3' sequence of the reverse complement
546
+ """
547
+ chars = list(s)
548
+ recon = []
549
+ rev_map = {
550
+ "a": "t",
551
+ "c": "g",
552
+ "t": "a",
553
+ "g": "c",
554
+ "A":"T",
555
+ "C": "G",
556
+ "T": "A",
557
+ "G": "C",
558
+ "n": "n",
559
+ "N": "N"
560
+ }
561
+ for c in chars:
562
+ recon += [rev_map[c]]
563
+
564
+ recon = "".join(recon)
565
+ return recon[::-1]
566
+
567
+ def extract_jaspar_motifs(row, reverse_complement=False):
568
+ s = row["scores"]
569
+ s = [int(x) for x in s.split(",")]
570
+ n_motifs = row["total_jaspar_hits"]
571
+ if n_motifs==0:
572
+ return ""
573
+ chipscore = row["chipscore"]
574
+ dna_seq = row["dna_sequence"]
575
+ if reverse_complement:
576
+ dna_seq = row["dna_sequence_rc"]
577
+ jaspar_indices = [i for i in list(range(len(s))) if s[i]>chipscore]
578
+
579
+ pred_motif = ""
580
+ for i in list(range(jaspar_indices[0], jaspar_indices[-1]+1)):
581
+ if not(i in jaspar_indices):
582
+ pred_motif += "-"
583
+ else:
584
+ pred_motif += dna_seq[i]
585
+
586
+ return pred_motif
587
+
588
+ def clean_idmap(idmap_path):
589
+ """
590
+ The raw ID Map from UniProt returned multiple results.
591
+ We went to ReMap and wrote down what the right mappings are in these cases.
592
+ """
593
+
594
+ manual_map = {
595
+ "BACH1": "O14867",
596
+ "BAP1": "Q92560",
597
+ "BDP1": "A6H8Y1",
598
+ "BRF1": "Q92994",
599
+ "CUX1": "Q13948",
600
+ "DDX21": "Q9NR30",
601
+ "ERG": "P11308",
602
+ "HBP1": "O60381",
603
+ "KLF14": "Q8TD94",
604
+ "MED1": "Q15648",
605
+ "MED25": "Q71SY5",
606
+ "MGA": "Q8IWI9",
607
+ "NRF1": "Q16656",
608
+ "PAF1": "Q8N7H5",
609
+ "PDX1": "P52945",
610
+ "RBP2": "P50120",
611
+ "RLF": "Q13129",
612
+ "SP1": "P08047",
613
+ "SPIN1": "Q9Y657",
614
+ "STAG1": "Q8WVM7",
615
+ "TAF15": "Q92804",
616
+ "TCF3": "P15923",
617
+ "ZFP36": "P26651",
618
+ "EVI1": "Q03112",
619
+ "MCM2": "P49736"
620
+ }
621
+ idmap = pd.read_csv(idmap_path, sep="\t")
622
+ idmap["Remap_Entry"] = idmap.apply(lambda row: row["Entry"] if not(row["From"] in manual_map) else manual_map[row["From"]], axis=1)
623
+ idmap_remapped = idmap.loc[
624
+ idmap["Entry"]==idmap["Remap_Entry"]
625
+ ].reset_index(drop=True).drop(columns=["Remap_Entry"])
626
+
627
+ assert len(idmap_remapped)==len(idmap_remapped["From"].unique())
628
+ logger.info(f"Total transcriptional regulators successfully mapped in UniProt: {len(idmap_remapped)}")
629
+
630
+ clean_idmap_path = Path(root)/"dpacman/data_files/processed/remap/idmapping_reviewed_true_processed_2025_08_11.tsv"
631
+ idmap_remapped.to_csv(clean_idmap_path, sep="\t")
632
+ return clean_idmap_path
633
+
634
+ def debug_fimo_check(path_to_chrom_fimo, path_to_processed_chrom, chrom="Y", json_dir=""):
635
+ """
636
+ Make sure we are properly extracting fimo sequences.
637
+ """
638
+ processed = pd.read_csv(path_to_processed_chrom)
639
+ processed["pred_motif_string"] = processed.apply(lambda row: extract_jaspar_motifs(row), axis=1)
640
+ processed["dna_sequence_rc"] = processed["dna_sequence"].apply(lambda x: get_reverse_complement(x))
641
+ processed["pred_motif_string_rc"] = processed.apply(lambda row: extract_jaspar_motifs(row, reverse_complement=True), axis=1)
642
+ processed_trs = processed["tr_name"].unique().tolist()
643
+
644
+ fimo = pd.read_csv(path_to_chrom_fimo)
645
+ fimo["input_tr"] = fimo["sequence_name"].str.split("_",expand=True)[2]
646
+ fimo_valid = fimo.loc[
647
+ (fimo["motif_alt_id"]==fimo["input_tr"]) &
648
+ (fimo["motif_alt_id"].isin(processed_trs))
649
+ ].reset_index(drop=True)
650
+ logger.info(f"Total valid FIMO matches: {len(fimo_valid)}")
651
+ logger.info(f"Total transcriptional regulators being considered: {len(processed_trs)}")
652
+
653
+ # Load DNA
654
+ cache_debug = load_chrom_dna(chrom, {}, json_dir=json_dir)
655
+
656
+ # Randomly select a positive and negative row to test
657
+ pos_row = fimo_valid.loc[fimo_valid["strand"]=="+"].sample(n=1, random_state=44)
658
+ neg_row = fimo_valid.loc[fimo_valid["strand"]=="-"].sample(n=1, random_state=44)
659
+
660
+ # Iterate through the rows
661
+ for row in [pos_row, neg_row]:
662
+ indices = [int(x) for x in row["sequence_name"].item().split("_")[-2::]]
663
+ chipseq_start, chipseq_end = indices
664
+ strand = row['strand'].item()
665
+ logger.info(f"ChIPseq start: {chipseq_start}, ChIPseq end: {chipseq_end}")
666
+ logger.info(f"Strand: {strand}")
667
+
668
+ motif_start = chipseq_start + int(row["start"].item()) -1
669
+ motif_end = chipseq_start + int(row["stop"].item())
670
+ motif = row['matched_sequence'].item()
671
+ full_seq = cache_debug[chipseq_start:chipseq_end].upper()
672
+ logger.info(f"Full sequence: {full_seq}")
673
+ logger.info(f"Full sequence reverse complement: {get_reverse_complement(full_seq)}")
674
+ our_motif = cache_debug[motif_start:motif_end].upper()
675
+
676
+ if strand=="+":
677
+ logger.info(f"True motif found by FIMO: {motif}")
678
+ logger.info(f"Extracted motif on our end: {our_motif}")
679
+ logger.info(f"Correct extraction: {motif==our_motif}")
680
+
681
+ matching_rows = processed.loc[
682
+ (processed["dna_sequence"].str.contains(full_seq)) &
683
+ (processed["pred_motif_string"].str.contains(our_motif))
684
+ ]
685
+ if strand=="-":
686
+ our_motif_rc = get_reverse_complement(our_motif)
687
+
688
+ logger.info(f"True motif found by FIMO: {motif}")
689
+ logger.info(f"Extracted motif on our end: {our_motif_rc}")
690
+ logger.info(f"Correct extraction: {motif==our_motif_rc}")
691
+ logger.info(f"Motif that will appear in the forward sequence: {our_motif}")
692
+ matching_rows = processed.loc[
693
+ (processed["dna_sequence"].str.contains(full_seq)) &
694
+ (processed["pred_motif_string"].str.contains(our_motif))
695
+ ]
696
+
697
+ # Now find if there are rows with the same TR, and the same DNA sequence and motif
698
+ matching_row_trs = sorted(matching_rows["tr_name"].unique().tolist())
699
+ expected_tr = row['motif_alt_id'].item()
700
+ logger.info(f"TR from selected row: {expected_tr}")
701
+ logger.info(f"TRs with same motif: {','.join(matching_row_trs)}")
702
+ logger.info(f"Expected TR in list: {expected_tr in matching_row_trs}")
703
+
704
+ def debug_remap_check(remap_path, path_to_processed_chrom, chrom="Y", json_dir=""):
705
+ """
706
+ For debugging mode: pick a random row from processed remap. make sure the sequence matches the one we're getting here.
707
+ """
708
+ remap = pd.read_csv(remap_path)
709
+ remap["ChIPStart"] = remap["ChIPStart"].astype(int)
710
+ remap["ChIPEnd"] = remap["ChIPEnd"].astype(int)
711
+
712
+ row = remap.loc[remap["#chrom"]=="Y"].sample(n=1, random_state=42)
713
+
714
+ start, end = row["ChIPStart"].item(), row["ChIPEnd"].item()
715
+
716
+ cache_debug = load_chrom_dna(chrom, {}, json_dir=json_dir)
717
+ test_seq = cache_debug[start:end].upper()
718
+ logger.info(f"Randomly sampled sequence ({len(test_seq)} nucleotides), chrY {start}:{end}\n\tsequence: {test_seq}")
719
+ should_find = remap.loc[
720
+ (remap["ChIPStart"]==start) &
721
+ (remap["ChIPEnd"]==end)
722
+ ]["TR"].unique().tolist()
723
+ logger.info(f"Expect to find {len(should_find)} TRs: {', '.join(sorted(should_find))}")
724
+
725
+ processed = pd.read_csv(path_to_processed_chrom)
726
+ did_find = processed.loc[
727
+ processed["peak_sequence"]==test_seq
728
+ ]["tr_name"].unique().tolist()
729
+ logger.info(f"Looked up same sequence in processed chrY file.\nFound TRs: {', '.join(sorted(did_find))}")
730
+ logger.info(f"found==expected: {did_find==should_find}")
731
+
732
+
733
  def main(cfg: DictConfig):
734
  debug = bool(cfg.data_task.debug)
735
  json_dir = cfg.data_task.json_dir
 
740
 
741
  logger.info(f"Debug: {debug}")
742
  logger.info(f"Reading per-chrom final.csv under: {fimo_out_dir}")
743
+
744
+ # process the idmap
745
+ idmap_path=Path(root) / cfg.data_task.idmap_path
746
+ clean_idmap_path = clean_idmap(idmap_path)
747
 
748
+ # If we don't have temp files to process
749
+ if not(os.path.exists(output_parts_folder)) or (os.path.exists(output_parts_folder) and len(os.listdir(output_parts_folder))<24):
750
  paths_to_processed_dfs = build_dataset_fast_mp(
751
  fimo_out_dir=fimo_out_dir,
752
  json_dir=json_dir,
753
  debug=debug,
754
  max_workers=max_workers,
755
  jaspar_boost=cfg.data_task.jaspar_boost,
756
+ output_parts_folder=output_parts_folder,
757
+ keep_fimo_only=cfg.data_task.keep_fimo_only
758
  )
759
  else:
760
  paths_to_processed_dfs = [output_parts_folder/x for x in os.listdir(output_parts_folder)] if output_parts_folder.exists() else []
761
+
762
+ # Debug methods: (1) make sure our peak sequences correspond to remap, (2) make sure our FIMO sequences correspond to FIMO results
763
+ out_path = str(processed_output_csv).replace(".csv", ".parquet")
764
+ if debug:
765
+ debug_remap_check(remap_path=Path(root) / cfg.data_task.remap_path,
766
+ path_to_processed_chrom=Path(output_parts_folder)/"chromY_processed.csv",
767
+ chrom="Y", json_dir=json_dir)
768
+ debug_fimo_check(path_to_chrom_fimo=Path(root) / cfg.data_task.fimo_out_dir / "chromY" / "fimo_annotations.csv",
769
+ path_to_processed_chrom=Path(output_parts_folder)/"chromY_processed.csv",
770
+ chrom="Y", json_dir=json_dir)
771
+ out_path = out_path.replace(".parquet", "_debug.parquet")
772
+
773
  logger.info(f"Combining {len(paths_to_processed_dfs)} processed parts with Polars")
774
  combine_processed_with_polars(
775
  paths_to_processed_dfs=paths_to_processed_dfs,
776
+ idmap_path=clean_idmap_path,
777
+ out_path=out_path,
778
+ max_protein_len=cfg.data_task.max_protein_len,
779
+ seeds=cfg.data_task.seeds
780
  )
781
 
782
  # Delete the folder that had the temporary DFs, don't need these
dpacman/data_tasks/fimo/run_fimo.py CHANGED
@@ -152,17 +152,26 @@ def run_fimo_chunk(cfg):
152
  def annotate_with_fimo(df, fdf):
153
  df = df.reset_index().rename(columns={"index":"idx"})
154
  df["sequence_name"] = df["idx"].astype(str) + "_chr" + df["#chrom"] + "_" + df["TR"] + "_" + df["contextStart"].astype(str) + "_" + df["contextEnd"].astype(str) #construt it the same way as headers
155
- fdf = fdf.merge(df[["sequence_name", "contextStart"]], on="sequence_name", how="left")
156
- fdf["genomic_start"] = fdf["contextStart"] + fdf["start"] - 1
157
- fdf["genomic_end"] = fdf["contextStart"] + fdf["stop"]
158
- fdf["coord"] = (
159
- fdf["genomic_start"].astype(str) + "-" + fdf["genomic_end"].astype(str)
 
 
 
 
 
 
 
 
 
160
  )
161
- agg = fdf.groupby("sequence_name")["coord"].agg(lambda hits: ",".join(hits))
 
162
  df["jaspar"] = df["sequence_name"].map(agg).fillna("")
163
  return df
164
 
165
-
166
  def main(cfg: DictConfig):
167
  """
168
  Main method for running FIMO analysis, searching JASPAR motifs against ChIP-seq peaks
 
152
  def annotate_with_fimo(df, fdf):
153
  df = df.reset_index().rename(columns={"index":"idx"})
154
  df["sequence_name"] = df["idx"].astype(str) + "_chr" + df["#chrom"] + "_" + df["TR"] + "_" + df["contextStart"].astype(str) + "_" + df["contextEnd"].astype(str) #construt it the same way as headers
155
+
156
+    # Crucial: filter the FIMO results (fdf) down to rows where the TF whose motif was found matches the TF recorded for that peak.
157
+ fdf["input_tr"] = fdf["sequence_name"].str.split("_",expand=True)[2]
158
+ true_matches = fdf.loc[
159
+ fdf["motif_alt_id"]==fdf["input_tr"]
160
+ ].reset_index(drop=True)
161
+ logger.info(f"Length of full returned FIMO results: {len(fdf)}")
162
+ logger.info(f"Length of true matches, where the FIMO tr and the input tr match: {len(true_matches)}")
163
+
164
+ true_matches = true_matches.merge(df[["sequence_name", "contextStart"]], on="sequence_name", how="left")
165
+ true_matches["genomic_start"] = true_matches["contextStart"] + true_matches["start"] - 1
166
+ true_matches["genomic_end"] = true_matches["contextStart"] + true_matches["stop"]
167
+ true_matches["coord"] = (
168
+ true_matches["genomic_start"].astype(str) + "-" + true_matches["genomic_end"].astype(str)
169
  )
170
+
171
+ agg = true_matches.groupby("sequence_name")["coord"].agg(lambda hits: ",".join(hits))
172
  df["jaspar"] = df["sequence_name"].map(agg).fillna("")
173
  return df
174
 
 
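A minimal, self-contained sketch (hypothetical values, not taken from the pipeline) of the two steps the updated annotate_with_fimo performs: keep only hits whose motif_alt_id matches the TR encoded in sequence_name, then convert FIMO's 1-based within-window coordinates to genomic coordinates via contextStart:

    import pandas as pd

    # one hypothetical context window starting at genomic position 1000 for TR FOXA1
    fdf = pd.DataFrame({
        "sequence_name": ["0_chr1_FOXA1_1000_1200", "0_chr1_FOXA1_1000_1200"],
        "motif_alt_id":  ["FOXA1", "GATA3"],   # second hit comes from a different TF's motif
        "start": [11, 51],                      # 1-based offsets within the window
        "stop":  [20, 60],
    })

    # step 1: keep only hits whose motif TF matches the TR named in the FASTA header
    fdf["input_tr"] = fdf["sequence_name"].str.split("_", expand=True)[2]
    true_matches = fdf[fdf["motif_alt_id"] == fdf["input_tr"]].copy()

    # step 2: map back to genomic coordinates: contextStart + start - 1 .. contextStart + stop
    context_start = 1000
    true_matches["genomic_start"] = context_start + true_matches["start"] - 1   # 1010
    true_matches["genomic_end"] = context_start + true_matches["stop"]          # 1020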
175
  def main(cfg: DictConfig):
176
  """
177
  Main method for running FIMO analysis, searching JASPAR motifs against ChIP-seq peaks
dpacman/data_tasks/split/__init__.py ADDED
File without changes
dpacman/data_tasks/split/remap.py ADDED
@@ -0,0 +1,512 @@
1
+ from collections import Counter, defaultdict
2
+ from ortools.linear_solver import pywraplp
3
+ import random
4
+ import logging
5
+ from omegaconf import DictConfig
6
+ import rootutils
7
+ import pandas as pd
8
+ from pathlib import Path
9
+ import os
10
+ import numpy as np
11
+ from sklearn.model_selection import train_test_split
12
+
13
+ root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
14
+ logger = logging.getLogger(__name__)
15
+
16
+ def split_bipartite_fast(
17
+ dna_clusters,
18
+ split_names=("train","val","test"),
19
+ ratios=(0.8,0.1,0.1),
20
+ ):
21
+ # use sklearn
22
+ test_size_1 = 0.2
23
+ test_size_2 = 0.5
24
+ logger.info(f"\tPerforming first split: all clusters -> train clusters ({round(1-test_size_1,3)}) and other ({test_size_1})")
25
+ X = dna_clusters
26
+ y = [0]*len(dna_clusters)
27
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size_1, random_state=0)
28
+ logger.info(f"\tPerforming second split: other -> val clusters ({round(1-test_size_2,3)}) and test clusters ({test_size_2})")
29
+ X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=test_size_2, random_state=0)
30
+
31
+ dna_assign = {}
32
+ for x in X_train:
33
+ dna_assign[x] = "train"
34
+ for x in X_val:
35
+ dna_assign[x] = "val"
36
+ for x in X_test:
37
+ dna_assign[x] = "test"
38
+
39
+ kept_by_split = {'train': len(X_train), 'val': len(X_val), 'test': len(X_test)}
40
+ return dna_assign, kept_by_split
41
+
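For reference (not part of the code), the hard-coded fractions above reproduce the default 80/10/10 ratios: the first split keeps 1 - 0.2 = 0.80 of the DNA clusters for train, and the second split halves the held-out 0.2 into 0.2 * 0.5 = 0.10 for val and 0.10 for test.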
42
+ def split_bipartite_with_ratios_and_leaky(
43
+ edges,
44
+ split_names=("train","val","test"),
45
+ ratios=(0.8, 0.1, 0.1),
46
+ require_nonempty=False,
47
+ ratio_tolerance=None, # None = soft ratios only; 0.0 = exact band (use with care)
48
+ bigM=None,
49
+ shuffle_within_pair=False,
50
+ seed=0,
51
+ test_edges_must=None, # NEW: list of (tf,dna) with duplicates OR dict {(tf,dna): count}
52
+ ):
53
+ """
54
+ edges: list of (tf_cluster_id, dna_cluster_id). Duplicates allowed (-> weights).
55
+ test_edges_must: None, list of pairs, or dict {(tf,dna): required_count}.
56
+ - If a pair appears with required_count > 0, at least that many examples MUST be kept in TEST.
57
+ - This implicitly pins both clusters of that pair to TEST (cluster exclusivity).
58
+
59
+ Returns:
60
+ tf_assign: {tf_cluster -> split}
61
+ dna_assign: {dna_cluster -> split}
62
+ kept_by_split: {split -> kept_count} (train/val/test only)
63
+ total_kept: int
64
+ split_to_indices: {split -> [input indices]} including 'leaky_test'
65
+ split_to_edges: {split -> [(tf,dna), ...]} including 'leaky_test'
66
+ """
67
+ # Aggregate counts per pair
68
+ w = Counter(edges)
69
+ tfs = {t for (t, _) in w}
70
+ dnas = {d for (_, d) in w}
71
+ S = list(split_names)
72
+ rs = dict(zip(S, ratios))
73
+ N = sum(w.values())
74
+ if bigM is None:
75
+ bigM = 1000 * max(1, N)
76
+
77
+ # Index original edges so we can return a per-example split
78
+ pair_to_indices = defaultdict(list)
79
+ for idx, (c, d) in enumerate(edges):
80
+ pair_to_indices[(c, d)].append(idx)
81
+
82
+ if shuffle_within_pair:
83
+ rng = random.Random(seed)
84
+ for key in pair_to_indices:
85
+ rng.shuffle(pair_to_indices[key])
86
+
87
+ # Parse required test edges
88
+ req_test = Counter()
89
+ if test_edges_must:
90
+ if isinstance(test_edges_must, dict):
91
+ for k, v in test_edges_must.items():
92
+ if not isinstance(k, tuple) or len(k) != 2:
93
+ raise ValueError("test_edges_must dict keys must be (tf_cluster, dna_cluster)")
94
+ if v < 0:
95
+ raise ValueError("required_count must be non-negative")
96
+ if v:
97
+ req_test[k] += int(v)
98
+ else:
99
+ # assume iterable of pairs
100
+ req_test = Counter(test_edges_must)
101
+ # Validate against available counts
102
+ for pair, req in req_test.items():
103
+ if pair not in w:
104
+ raise ValueError(f"Required test pair {pair} not present in edges.")
105
+ if req > w[pair]:
106
+ raise ValueError(
107
+ f"Required count {req} for {pair} exceeds available {w[pair]}."
108
+ )
109
+
110
+ # Build solver
111
+ solver = pywraplp.Solver.CreateSolver("CBC")
112
+ if solver is None:
113
+ raise RuntimeError("Could not create CBC solver.")
114
+
115
+ # Binary cluster assignments
116
+ x = {(c,s): solver.BoolVar(f"x[{c},{s}]") for c in tfs for s in S}
117
+ y = {(d,s): solver.BoolVar(f"y[{d},{s}]") for d in dnas for s in S}
118
+
119
+ # Each cluster in exactly one split
120
+ for c in tfs:
121
+ solver.Add(sum(x[c,s] for s in S) == 1)
122
+ for d in dnas:
123
+ solver.Add(sum(y[d,s] for s in S) == 1)
124
+
125
+ # Integer kept counts per pair and split (allow partial within-pair)
126
+ k = {((c,d),s): solver.IntVar(0, w[(c,d)], f"k[{c},{d},{s}]") for (c,d) in w for s in S}
127
+
128
+ # Only keep in split s if both endpoint clusters are assigned to s
129
+ for (c,d), wt in w.items():
130
+ for s in S:
131
+ solver.Add(k[((c,d),s)] <= wt * x[c,s])
132
+ solver.Add(k[((c,d),s)] <= wt * y[d,s])
133
+
134
+ # Enforce minimum kept counts in TEST for required pairs
135
+ for (c,d), req in req_test.items():
136
+ solver.Add(k[((c,d), "test")] >= req)
137
+
138
+ # Optional: ensure each split has at least one cluster (feasibility depends on counts)
139
+ if require_nonempty:
140
+ for s in S:
141
+ solver.Add(
142
+ sum(x[c,s] for c in tfs) + sum(y[d,s] for d in dnas) >= 1
143
+ )
144
+
145
+ # Kept counts per split and total
146
+ K = {s: solver.IntVar(0, N, f"K[{s}]") for s in S}
147
+ for s in S:
148
+ solver.Add(K[s] == sum(k[((c,d),s)] for (c,d) in w))
149
+ T = solver.IntVar(0, N, "T")
150
+ solver.Add(T == sum(K[s] for s in S))
151
+
152
+ # Ratio deviation: K_s - r_s * T = d+ - d-
153
+ dpos = {s: solver.NumVar(0, solver.infinity(), f"dpos[{s}]") for s in S}
154
+ dneg = {s: solver.NumVar(0, solver.infinity(), f"dneg[{s}]") for s in S}
155
+ for s in S:
156
+ solver.Add(K[s] - rs[s]*T == dpos[s] - dneg[s])
157
+
158
+ # Optional hard band around target ratios
159
+ if ratio_tolerance is not None:
160
+ eps = float(ratio_tolerance)
161
+ for s in S:
162
+ solver.Add(K[s] >= (rs[s] - eps) * T)
163
+ solver.Add(K[s] <= (rs[s] + eps) * T)
164
+
165
+ # Objective: maximize T then minimize total deviation
166
+ obj = solver.Objective()
167
+ obj.SetMaximization()
168
+ obj.SetCoefficient(T, float(bigM))
169
+ for s in S:
170
+ obj.SetCoefficient(dpos[s], -1.0)
171
+ obj.SetCoefficient(dneg[s], -1.0)
172
+
173
+ status = solver.Solve()
174
+ if status not in (pywraplp.Solver.OPTIMAL, pywraplp.Solver.FEASIBLE):
175
+ raise RuntimeError("No feasible solution (check ratio_tolerance vs. required test edges).")
176
+
177
+ # Read cluster assignments
178
+ tf_assign = {c: next(s for s in S if x[c,s].solution_value() > 0.5) for c in tfs}
179
+ dna_assign = {d: next(s for s in S if y[d,s].solution_value() > 0.5) for d in dnas}
180
+
181
+ # Kept counts per split
182
+ kept_by_split = {s: int(round(K[s].solution_value())) for s in S}
183
+ total_kept = int(round(T.solution_value()))
184
+
185
+ # ---- Build per-example split assignment (including 'leaky_test') ----
186
+ split_to_indices = {s: [] for s in S}
187
+ remaining_indices = {pair: list(pair_to_indices[pair]) for pair in pair_to_indices}
188
+
189
+ # Allocate the kept examples per split (train/val/test)
190
+ for (c,d), wt in w.items():
191
+ for s in S:
192
+ cnt = int(round(k[((c,d),s)].solution_value()))
193
+ if cnt > 0:
194
+ take = remaining_indices[(c,d)][:cnt]
195
+ split_to_indices[s].extend(take)
196
+ remaining_indices[(c,d)] = remaining_indices[(c,d)][cnt:]
197
+
198
+ # Everything left becomes leaky_test
199
+ leaky_indices = []
200
+ for pair, idxs in remaining_indices.items():
201
+ if idxs:
202
+ leaky_indices.extend(idxs)
203
+
204
+ split_to_indices["leaky_test"] = leaky_indices
205
+ split_to_edges = {s: [edges[i] for i in split_to_indices[s]] for s in split_to_indices}
206
+
207
+ return tf_assign, dna_assign, kept_by_split, total_kept, split_to_indices, split_to_edges
208
+
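A small usage sketch of the ILP splitter above on a hypothetical toy edge list (in the pipeline, edges of this shape are built by make_edges further down); it assumes ortools' CBC solver is available:

    edges = [
        ("tfA", "dna1"), ("tfA", "dna1"), ("tfA", "dna2"),   # duplicates act as weights
        ("tfB", "dna3"), ("tfC", "dna4"), ("tfC", "dna5"),
        ("tfD", "dna6"), ("tfE", "dna7"), ("tfF", "dna8"),
    ]
    tf_assign, dna_assign, kept, total, idx_by_split, edges_by_split = \
        split_bipartite_with_ratios_and_leaky(edges, ratios=(0.8, 0.1, 0.1))
    print(kept, total)                   # kept examples per split and overall
    print(idx_by_split["leaky_test"])    # examples that could not be placed without leakage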
209
+ from collections import Counter, defaultdict
210
+ import random
211
+
212
+ class DSU:
213
+ def __init__(self): self.p = {}
214
+ def find(self, x):
215
+ if x not in self.p: self.p[x] = x
216
+ while self.p[x] != x:
217
+ self.p[x] = self.p[self.p[x]]
218
+ x = self.p[x]
219
+ return x
220
+ def union(self, a,b):
221
+ ra, rb = self.find(a), self.find(b)
222
+ if ra != rb: self.p[rb] = ra
223
+
224
+ def split_bipartite_by_components(
225
+ edges,
226
+ split_names=("train","val","test"),
227
+ ratios=(0.8,0.1,0.1),
228
+ seed=0,
229
+ require_nonempty=False,
230
+ test_edges_must=None, # None, list[(tf,dna)], or dict{(tf,dna): count}
231
+ ):
232
+ """
233
+ Guarantees exclusivity: each TF cluster and DNA cluster appears in at most one split.
234
+ Strategy: find connected components in the TF–DNA bipartite graph and assign components wholesale.
235
+ """
236
+ rng = random.Random(seed)
237
+ w = Counter(edges) # multiplicities per pair
238
+ if not w: raise ValueError("No edges.")
239
+
240
+ # 1) Build components with Union-Find (prefix to keep TF/DNA namespaces disjoint)
241
+ dsu = DSU()
242
+ for (tf, dna) in w:
243
+ dsu.union(("T", tf), ("D", dna))
244
+ comp_pairs = defaultdict(list)
245
+ comp_weight = defaultdict(int)
246
+ for (tf, dna), cnt in w.items():
247
+ root = dsu.find(("T", tf)) # component id = root of TF endpoint
248
+ comp_pairs[root].append((tf, dna))
249
+ comp_weight[root] += cnt
250
+
251
+ comps = list(comp_pairs.keys())
252
+ C = len(comps)
253
+ S = list(split_names)
254
+ rs = dict(zip(S, ratios))
255
+ N = sum(comp_weight[c] for c in comps)
256
+ target = {s: int(round(rs[s] * N)) for s in S}
257
+
258
+ # 2) Pin components that contain required TEST pairs
259
+ pinned = {} # comp_root -> pinned_split ("test")
260
+ if test_edges_must:
261
+ req = Counter(test_edges_must) if not isinstance(test_edges_must, dict) else Counter(test_edges_must)
262
+ # Map each required pair to its component, ensure feasibility
263
+ for (tf, dna), r in req.items():
264
+ if (tf, dna) not in w:
265
+ raise ValueError(f"Required pair {(tf,dna)} not present.")
266
+ if r > w[(tf, dna)]:
267
+ raise ValueError(f"Required count {r} for {(tf,dna)} exceeds available {w[(tf,dna)]}.")
268
+ comp = dsu.find(("T", tf))
269
+ if comp in pinned and pinned[comp] != "test":
270
+ raise ValueError(f"Component conflict: already pinned to {pinned[comp]}, but {(tf,dna)} demands test.")
271
+ pinned[comp] = "test"
272
+ # NOTE: pinning a pair pins the WHOLE component to test (to keep exclusivity).
273
+ # If you only want some edges kept in test and discard the rest, handle below when materializing.
274
+
275
+ # 3) Assign components greedily by deficit
276
+ kept_by_split = {s: 0 for s in S}
277
+ comp_assign = {} # comp_root -> split
278
+
279
+ # First assign pinned comps
280
+ for comp, split in pinned.items():
281
+ comp_assign[comp] = split
282
+ kept_by_split[split] += comp_weight[comp]
283
+
284
+ # Sort remaining components by descending weight
285
+ remaining = [c for c in comps if c not in comp_assign]
286
+ remaining.sort(key=lambda c: comp_weight[c], reverse=True)
287
+
288
+ # Ensure nonempty splits if requested (seed with largest remaining comps)
289
+ if require_nonempty:
290
+ seeds = remaining[:min(len(S), len(remaining))]
291
+ for comp, s in zip(seeds, S):
292
+ comp_assign[comp] = s
293
+ kept_by_split[s] += comp_weight[comp]
294
+ remaining = [c for c in remaining if c not in comp_assign]
295
+
296
+ for comp in remaining:
297
+ # choose split with largest deficit (target - current)
298
+ deficits = {s: target[s] - kept_by_split[s] for s in S}
299
+ best = max(deficits, key=lambda s: deficits[s])
300
+ comp_assign[comp] = best
301
+ kept_by_split[best] += comp_weight[comp]
302
+
303
+ total_kept = sum(kept_by_split.values())
304
+
305
+ # 4) Materialize per-example indices (and verify exclusivity)
306
+ pair_to_indices = defaultdict(list)
307
+ for idx, pair in enumerate(edges):
308
+ pair_to_indices[pair].append(idx)
309
+
310
+ split_to_indices = {s: [] for s in S}
311
+ for comp, s in comp_assign.items():
312
+ for pair in comp_pairs[comp]:
313
+ split_to_indices[s].extend(pair_to_indices[pair])
314
+
315
+ # Optional: if you pinned a comp due to a small 'must-test' count but
316
+ # want to *discard* the rest instead of keeping them in test, uncomment:
317
+ # for comp, s in comp_assign.items():
318
+ # if s == "test" and test_edges_must:
319
+ # # Keep only the required counts; dump extras to 'leaky_test'
320
+ # ...
321
+ # (Left out for clarity; default is: keep the whole component in its split.)
322
+
323
+ # 5) Build edge lists and simple cluster assignments
324
+ split_to_edges = {s: [edges[i] for i in split_to_indices[s]] for s in split_to_indices}
325
+ tf_assign, dna_assign = {}, {}
326
+ for comp, s in comp_assign.items():
327
+ for (tf, dna) in comp_pairs[comp]:
328
+ tf_assign[tf] = s
329
+ dna_assign[dna] = s
330
+
331
+ # 6) Safety check: no DNA/TF appears in multiple splits
332
+ tf_in_split = defaultdict(set)
333
+ dna_in_split = defaultdict(set)
334
+ for s, elist in split_to_edges.items():
335
+ for tf, dna in elist:
336
+ tf_in_split[tf].add(s)
337
+ dna_in_split[dna].add(s)
338
+ dup_tf = {tf: ss for tf, ss in tf_in_split.items() if len(ss) > 1}
339
+ dup_dna = {dn: ss for dn, ss in dna_in_split.items() if len(ss) > 1}
340
+ assert not dup_tf and not dup_dna, f"Exclusivity violated: {dup_tf} {dup_dna}"
341
+
342
+ return tf_assign, dna_assign, kept_by_split, total_kept, split_to_indices, split_to_edges
343
+
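A comparable sketch for the component-based splitter, again on hypothetical toy edges; because whole connected components are assigned, any TF and DNA cluster joined by an edge always land in the same split:

    edges = [
        ("tfA", "dna1"), ("tfA", "dna2"),        # one component
        ("tfB", "dna3"),
        ("tfC", "dna4"), ("tfD", "dna4"),        # tfC and tfD share dna4
        ("tfE", "dna5"), ("tfF", "dna6"),
    ]
    tf_assign, dna_assign, kept, total, idx_by_split, edges_by_split = \
        split_bipartite_by_components(edges, ratios=(0.8, 0.1, 0.1), seed=0)
    # tfC and tfD must share a split because they are connected through dna4
    assert tf_assign["tfC"] == tf_assign["tfD"] == dna_assign["dna4"]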
344
+ def print_split_ratios(kept_by_split):
345
+ total = sum(kept_by_split.values())
346
+ train_pcnt = 100 * kept_by_split["train"] / total
347
+ val_pcnt = 100 * kept_by_split["val"] / total
348
+ test_pcnt = 100 * kept_by_split["test"] / total
349
+ logger.info(f"Cluster distribution - Train: {train_pcnt:.2f}%, Val: {val_pcnt:.2f}%, Test: {test_pcnt:.2f}%")
350
+
351
+ def make_edges(processed_fimo_path: str, protein_cluster_path: str, dna_cluster_path: str):
352
+ """
353
+    Make edges for input to the splitting algorithm. Each edge is a (tr_cluster_rep, dna_cluster_rep) tuple, where each cluster rep is a sequence ID.
354
+ """
355
+    # Read cluster data
356
+ protein_clusters = pd.read_csv(protein_cluster_path, header=None,sep="\t")
357
+ protein_clusters.columns=["tr_cluster_rep","tr_seqid"]
358
+
359
+ dna_clusters = pd.read_csv(dna_cluster_path, header=None,sep="\t")
360
+ dna_clusters.columns=["dna_cluster_rep","dna_seqid"]
361
+
362
+ # Read datapoints
363
+ edges = pd.read_parquet(processed_fimo_path)
364
+ edges = pd.merge(edges, dna_clusters, on="dna_seqid",how="left")
365
+ edges = pd.merge(edges, protein_clusters, on="tr_seqid",how="left")
366
+ edges["edge"] = edges.apply(lambda row: (row["tr_cluster_rep"], row["dna_cluster_rep"]), axis=1)
367
+
368
+ logger.info(f"Total unique edges: {len(edges['edge'].unique().tolist())}")
369
+ dup_edges = edges.loc[edges.duplicated("edge")]["edge"].unique().tolist()
370
+ logger.info(f"Total edges with >1 datapoint: {len(dup_edges)}")
371
+ logger.info(f"Total datapoints belonging to a duplicate edge: {len(edges.loc[edges['edge'].isin(dup_edges)])}")
372
+ return edges
373
+
374
+ def check_validity(train, val, test, split_by="both"):
375
+ """
376
+ Rigorous check for no overlap
377
+ Columns = ["ID","dna_sequence","tr_sequence","tr_cluster_rep","dna_cluster_rep", "scores","split"]
378
+ """
379
+ train_ids = set(train["ID"].unique().tolist())
380
+ val_ids = set(val["ID"].unique().tolist())
381
+ test_ids = set(test["ID"].unique().tolist())
382
+
383
+ assert len(train_ids.intersection(val_ids))==0
384
+ assert len(train_ids.intersection(test_ids))==0
385
+ assert len(val_ids.intersection(test_ids))==0
386
+ logger.info(f"Pass! No overlap in IDs")
387
+
388
+ if split_by!="dna":
389
+ train_tr_seqs = set(train["tr_sequence"].unique().tolist())
390
+ val_tr_seqs = set(val["tr_sequence"].unique().tolist())
391
+ test_tr_seqs = set(test["tr_sequence"].unique().tolist())
392
+
393
+ assert len(train_tr_seqs.intersection(val_tr_seqs))==0
394
+ assert len(train_tr_seqs.intersection(test_tr_seqs))==0
395
+ assert len(val_tr_seqs.intersection(test_tr_seqs))==0
396
+ logger.info(f"Pass! No overlap in TR sequences")
397
+
398
+ train_tr_reps = set(train["tr_cluster_rep"].unique().tolist())
399
+ val_tr_reps = set(val["tr_cluster_rep"].unique().tolist())
400
+ test_tr_reps = set(test["tr_cluster_rep"].unique().tolist())
401
+
402
+ assert len(train_tr_reps.intersection(val_tr_reps))==0
403
+ assert len(train_tr_reps.intersection(test_tr_reps))==0
404
+ assert len(val_tr_reps.intersection(test_tr_reps))==0
405
+ logger.info(f"Pass! No overlap in TR cluster reps")
406
+
407
+ if split_by!="protein":
408
+ train_dna_seqs = set(train["dna_sequence"].unique().tolist())
409
+ val_dna_seqs = set(val["dna_sequence"].unique().tolist())
410
+ test_dna_seqs = set(test["dna_sequence"].unique().tolist())
411
+
412
+ assert len(train_dna_seqs.intersection(val_dna_seqs))==0
413
+ assert len(train_dna_seqs.intersection(test_dna_seqs))==0
414
+ assert len(val_dna_seqs.intersection(test_dna_seqs))==0
415
+ logger.info(f"Pass! No overlap in DNA sequences")
416
+
417
+ train_dna_reps = set(train["dna_cluster_rep"].unique().tolist())
418
+ val_dna_reps = set(val["dna_cluster_rep"].unique().tolist())
419
+ test_dna_reps = set(test["dna_cluster_rep"].unique().tolist())
420
+
421
+ assert len(train_dna_reps.intersection(val_dna_reps))==0
422
+ assert len(train_dna_reps.intersection(test_dna_reps))==0
423
+ assert len(val_dna_reps.intersection(test_dna_reps))==0
424
+ logger.info(f"Pass! No overlap in DNA cluster reps")
425
+
426
+ def main(cfg: DictConfig):
427
+ """
428
+ Take a set of DNA clusters + protein clusters, and create the best possible splits into train/val/test.
429
+ """
430
+ # construct edges from training data
431
+ edge_df = make_edges(processed_fimo_path=Path(root) / cfg.data_task.input_data_path,
432
+ protein_cluster_path=Path(root) / cfg.data_task.cluster_output_paths.protein,
433
+ dna_cluster_path=Path(root) / cfg.data_task.cluster_output_paths.dna)
434
+ edges = edge_df["edge"].unique().tolist()
435
+
436
+ # figure out if we actually even have a conflict
437
+ total_proteins = len(edge_df["tr_seqid"].unique().tolist())
438
+ total_protein_clusters = len(edge_df["tr_cluster_rep"].unique().tolist())
439
+
440
+ no_protein_overlap = (total_proteins)==(total_protein_clusters)
441
+ logger.info(f"All proteins are in their own clusters: {no_protein_overlap}")
442
+
443
+ if cfg.data_task.split_by=="dna":
444
+        logger.info(f"Easy split: splitting by DNA clusters only (split_by=dna).")
445
+ dna_clusters = edge_df["dna_cluster_rep"].unique().tolist()
446
+ results = split_bipartite_fast(
447
+ dna_clusters,
448
+ split_names=("train","val","test"),
449
+ ratios=(cfg.data_task.train_ratio, cfg.data_task.val_ratio, cfg.data_task.test_ratio),
450
+ )
451
+ dna_assign, kept_by_split = results
452
+
453
+ edge_df["split"] = edge_df["dna_seqid"].map(dna_assign)
454
+ else:
455
+ results = split_bipartite_by_components(
456
+ edges,
457
+ split_names=("train","val","test"),
458
+ ratios=(cfg.data_task.train_ratio, cfg.data_task.val_ratio, cfg.data_task.test_ratio),
459
+ require_nonempty=cfg.data_task.require_nonempty,
460
+ seed=cfg.data_task.seed,
461
+ test_edges_must=None,
462
+ )
463
+
464
+ tf_assign, dna_assign, kept_by_split, total_kept, split_to_indices, split_to_edges = results
465
+
466
+ # Map each sample to its split
467
+ print(tf_assign)
468
+ print(dna_assign)
469
+ edge_df["tr_split"] = edge_df["tr_cluster_rep"].map(tf_assign)
470
+ edge_df["dna_split"] = edge_df["dna_cluster_rep"].map(dna_assign)
471
+ edge_df["same_split"] = edge_df["tr_split"]==edge_df["dna_split"] # should always be true if easy cluster
472
+ edge_df["split"] = edge_df["tr_split"]
473
+ print(edge_df)
474
+ edge_df["split"] = np.where(
475
+ edge_df["same_split"],
476
+ edge_df["split"], # keep existing split if same_split == True
477
+ "leak" # otherwise leak
478
+ )
479
+ print(edge_df)
480
+
481
+ # Print ratios: hopefully close to desired (e.g. 80/10/10)
482
+ print_split_ratios(kept_by_split)
483
+
484
+ # Make train, val, test sets
485
+ # make sure no ID is duplicate
486
+ assert len(edge_df["ID"].unique())==len(edge_df)
487
+ split_cols = ["ID","dna_sequence","tr_sequence","tr_cluster_rep","dna_cluster_rep", "scores","split"]
488
+ train = edge_df.loc[
489
+ edge_df["split"]=="train"
490
+ ].reset_index(drop=True)[split_cols]
491
+ val = edge_df.loc[
492
+ edge_df["split"]=="val"
493
+ ].reset_index(drop=True)[split_cols]
494
+ test = edge_df.loc[
495
+ edge_df["split"]=="test"
496
+ ].reset_index(drop=True)[split_cols]
497
+
498
+ # ensure there is no overlap
499
+ check_validity(train, val, test, split_by=cfg.data_task.split_by)
500
+
501
+ logger.info(f"Length of train dataset: {len(train)} ({100*len(train)/sum([len(train),len(val),len(test)]):.2f}%)")
502
+ logger.info(f"Length of val dataset: {len(val)} ({100*len(val)/sum([len(train),len(val),len(test)]):.2f}%)")
503
+ logger.info(f"Length of test dataset: {len(test)} ({100*len(test)/sum([len(train),len(val),len(test)]):.2f}%)")
504
+
505
+ # create the output dir
506
+ split_out_dir = Path(root)/cfg.data_task.split_out_dir
507
+ os.makedirs(split_out_dir, exist_ok=True)
508
+ split_final_cols = ["ID","dna_sequence","tr_sequence","scores","split"]
509
+ train[split_final_cols].to_csv(split_out_dir/"train.csv", index=False)
510
+ val[split_final_cols].to_csv(split_out_dir/"val.csv", index=False)
511
+ test[split_final_cols].to_csv(split_out_dir/"test.csv", index=False)
512
+ logger.info(f"Saved all splits to {split_out_dir}")
dpacman/scripts/preprocess.py CHANGED
@@ -1,11 +1,9 @@
1
  import rootutils
2
-
3
- root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
4
-
5
  import hydra
6
  from omegaconf import DictConfig
7
  import logging
8
 
 
9
  logger = logging.getLogger(__name__)
10
 
11
  # import your processing entry points here
@@ -15,7 +13,8 @@ from dpacman.data_tasks.clean.remap import main as clean_remap_main
15
  from dpacman.data_tasks.fimo.pre_fimo import main as pre_fimo_main
16
  from dpacman.data_tasks.fimo.run_fimo import main as run_fimo_main
17
  from dpacman.data_tasks.fimo.post_fimo import main as post_fimo_main
18
-
 
19
 
20
  @hydra.main(
21
  config_path=str(root / "configs"), config_name="preprocess", version_base="1.3"
@@ -26,6 +25,7 @@ def main(cfg: DictConfig):
26
 
27
  logger.info(f"Running {task_type} task: {task_name}")
28
 
 
29
  if task_type == "download":
30
  if task_name == "genome":
31
  download_genome_main(cfg)
@@ -34,12 +34,14 @@ def main(cfg: DictConfig):
34
  else:
35
  raise ValueError(f"No download pipeline defined for: {task_name}")
36
 
 
37
  elif task_type == "clean":
38
  if task_name == "remap":
39
  clean_remap_main(cfg)
40
  else:
41
  raise ValueError(f"No clean pipeline defined for: {task_name}")
42
 
 
43
  elif task_type == "fimo":
44
  if task_name == "pre_fimo":
45
  pre_fimo_main(cfg)
@@ -50,6 +52,20 @@ def main(cfg: DictConfig):
50
  else:
51
  raise ValueError(f"No clean pipeline defined for: {task_name}")
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  else:
54
  raise ValueError(f"Unknown task type: {task_type}")
55
 
 
1
  import rootutils
 
 
 
2
  import hydra
3
  from omegaconf import DictConfig
4
  import logging
5
 
6
+ root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
7
  logger = logging.getLogger(__name__)
8
 
9
  # import your processing entry points here
 
13
  from dpacman.data_tasks.fimo.pre_fimo import main as pre_fimo_main
14
  from dpacman.data_tasks.fimo.run_fimo import main as run_fimo_main
15
  from dpacman.data_tasks.fimo.post_fimo import main as post_fimo_main
16
+ from dpacman.data_tasks.cluster.remap import main as cluster_remap_main
17
+ from dpacman.data_tasks.split.remap import main as split_remap_main
18
 
19
  @hydra.main(
20
  config_path=str(root / "configs"), config_name="preprocess", version_base="1.3"
 
25
 
26
  logger.info(f"Running {task_type} task: {task_name}")
27
 
28
+ # Download
29
  if task_type == "download":
30
  if task_name == "genome":
31
  download_genome_main(cfg)
 
34
  else:
35
  raise ValueError(f"No download pipeline defined for: {task_name}")
36
 
37
+ # Clean
38
  elif task_type == "clean":
39
  if task_name == "remap":
40
  clean_remap_main(cfg)
41
  else:
42
  raise ValueError(f"No clean pipeline defined for: {task_name}")
43
 
44
+ # Fimo
45
  elif task_type == "fimo":
46
  if task_name == "pre_fimo":
47
  pre_fimo_main(cfg)
 
52
  else:
53
  raise ValueError(f"No clean pipeline defined for: {task_name}")
54
 
55
+ # Cluster
56
+ elif task_type == "cluster":
57
+ if task_name == "remap":
58
+ cluster_remap_main(cfg)
59
+ else:
60
+            raise ValueError(f"No cluster pipeline defined for: {task_name}")
61
+
62
+ elif task_type == "split":
63
+ if task_name == "remap":
64
+ split_remap_main(cfg)
65
+ else:
66
+            raise ValueError(f"No split pipeline defined for: {task_name}")
67
+
68
+ # Unknown - error
69
  else:
70
  raise ValueError(f"Unknown task type: {task_type}")
71
 
dpacman/scripts/run_cluster.sh ADDED
@@ -0,0 +1,19 @@
1
+ #!/bin/bash
2
+
3
+ # Manually specify values used in the config
4
+ main_task="preprocess"
5
+ data_task_type="cluster"
6
+ timestamp=$(date "+%Y-%m-%d_%H-%M-%S")
7
+
8
+ run_dir="/vast/projects/pranam/lab/sophie/DPACMAN/logs/${main_task}/${data_task_type}/runs/${timestamp}"
9
+ mkdir -p "$run_dir"
10
+
11
+ nohup python -u -m scripts.preprocess \
12
+ hydra.run.dir="${run_dir}" \
13
+ data_task=${data_task_type}/remap \
14
+ data_task.cluster_dna_full="false" \
15
+ data_task.cluster_dna_peaks="false" \
16
+ data_task.cluster_protein="true" \
17
+ > "${run_dir}/run.log" 2>&1 &
18
+
19
+ echo $! > "${run_dir}/pid.txt"
dpacman/scripts/run_fimo.sh CHANGED
@@ -5,13 +5,15 @@ main_task="preprocess"
5
  data_task_type="fimo"
6
  timestamp=$(date "+%Y-%m-%d_%H-%M-%S")
7
 
8
- run_dir="$HOME/DPACMAN/logs/${main_task}/${data_task_type}/runs/${timestamp}"
9
  mkdir -p "$run_dir"
10
 
11
  nohup python -u -m scripts.preprocess \
12
  hydra.run.dir="${run_dir}" \
13
  data_task=${data_task_type}/post_fimo \
14
  data_task.debug="false" \
 
 
15
  > "${run_dir}/run.log" 2>&1 &
16
 
17
  echo $! > "${run_dir}/pid.txt"
 
5
  data_task_type="fimo"
6
  timestamp=$(date "+%Y-%m-%d_%H-%M-%S")
7
 
8
+ run_dir="/vast/projects/pranam/lab/sophie/DPACMAN/logs/${main_task}/${data_task_type}/runs/${timestamp}"
9
  mkdir -p "$run_dir"
10
 
11
  nohup python -u -m scripts.preprocess \
12
  hydra.run.dir="${run_dir}" \
13
  data_task=${data_task_type}/post_fimo \
14
  data_task.debug="false" \
15
+ data_task.keep_fimo_only="false" \
16
+ data_task.processed_output_csv="dpacman/data_files/processed/fimo/post_fimo/all_peaks/remap2022_crm_fimo_output_q_processed.csv" \
17
  > "${run_dir}/run.log" 2>&1 &
18
 
19
  echo $! > "${run_dir}/pid.txt"
dpacman/scripts/run_split.sh ADDED
@@ -0,0 +1,21 @@
1
+ #!/bin/bash
2
+
3
+ # Manually specify values used in the config
4
+ main_task="preprocess"
5
+ data_task_type="split"
6
+ timestamp=$(date "+%Y-%m-%d_%H-%M-%S")
7
+
8
+ run_dir="/vast/projects/pranam/lab/sophie/DPACMAN/logs/${main_task}/${data_task_type}/runs/${timestamp}"
9
+ mkdir -p "$run_dir"
10
+
11
+ nohup python -u -m scripts.preprocess \
12
+ hydra.run.dir="${run_dir}" \
13
+ data_task="${data_task_type}/remap" \
14
+ data_task.split_by=dna \
15
+ data_task.train_ratio=0.8 \
16
+ data_task.val_ratio=0.1 \
17
+ data_task.test_ratio=0.1 \
18
+ data_task.split_out_dir=dpacman/data_files/processed/splits/by_dna \
19
+ > "${run_dir}/run.log" 2>&1 &
20
+
21
+ echo $! > "${run_dir}/pid.txt"
dpacman/utils/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ Code in this folder was taken from the olsdavis/gdd repository, which was in turn based on DPLM (Diffusion Protein Language Model).
dpacman/utils/__init__.py ADDED
File without changes
dpacman/utils/clustering.py ADDED
1
+ import pandas as pd
2
+ import os
3
+ import subprocess
4
+ import sys
5
+ from Bio import SeqIO
6
+ import shutil
7
+
8
+ import rootutils
9
+ import logging
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ def ensure_mmseqs_in_path(mmseqs_dir):
14
+ """
15
+ Checks if MMseqs2 is in the PATH. If it's not, add it. MMseqs2 will not run if this is not done correctly.
16
+
17
+ Args:
18
+ mmseqs_dir (str): Directory containing MMseqs2 binaries
19
+ """
20
+ mmseqs_bin = os.path.join(mmseqs_dir, 'mmseqs')
21
+
22
+ # Check if mmseqs is already in PATH
23
+ if shutil.which('mmseqs') is None:
24
+ # Export the MMseqs2 directory to PATH
25
+ os.environ['PATH'] = f"{mmseqs_dir}:{os.environ['PATH']}"
26
+ logger.info(f"\tAdded {mmseqs_dir} to PATH")
27
+
28
+ def process_fasta(fasta_path):
29
+ fasta_sequences = SeqIO.parse(open(fasta_path),'fasta')
30
+ d = {}
31
+ for fasta in fasta_sequences:
32
+ id, sequence = fasta.id, str(fasta.seq)
33
+
34
+ d[id] = sequence
35
+
36
+ return d
37
+
38
+ def analyze_clustering_result(input_fasta: str, tsv_path: str):
39
+ """
40
+ Args:
41
+ input_fasta (str): path to input fasta file
42
+ """
43
+
44
+ # Process input fasta
45
+ input_d = process_fasta(input_fasta)
46
+
47
+ # Process clusters.tsv
48
+ clusters = pd.read_csv(f'{tsv_path}',sep='\t',header=None)
49
+ clusters = clusters.rename(columns={
50
+ 0: 'representative seq_id',
51
+ 1: 'member seq_id'
52
+ })
53
+
54
+ clusters['representative seq'] = clusters['representative seq_id'].apply(lambda seq_id: input_d[seq_id])
55
+ clusters['member seq'] = clusters['member seq_id'].apply(lambda seq_id: input_d[seq_id])
56
+
57
+ # Sort them so that splitting results are reproducible
58
+ clusters = clusters.sort_values(by=['representative seq_id','member seq_id'],ascending=True).reset_index(drop=True)
59
+
60
+ return clusters
61
+
62
+ def make_fasta(sequences: dict, fasta_path: str):
63
+ """
64
+ Makes a fasta file from sequences, where the key is the header and the value is the sequence.
65
+
66
+ Args:
67
+ sequences (dict): A dictionary where the key is the header and the value is the sequence.
68
+
69
+ Returns:
70
+ str: The path to the fasta file.
71
+ """
72
+ with open(fasta_path, 'w') as f:
73
+ for header, sequence in sequences.items():
74
+ f.write(f'>{header}\n{sequence}\n')
75
+
76
+ return fasta_path
77
+
78
+ def run_mmseqs_clustering(input_fasta, output_dir, min_seq_id=0.3, c=0.8, cov_mode=0, cluster_mode=0, path_to_mmseqs='fuson_plm/mmseqs', dbtype=1):
79
+ """
80
+    Runs MMseqs2 clustering via the easy-cluster module.
81
+
82
+ Args:
83
+ input_fasta (str): path to input fasta file, formatted >header\nsequence\n>header\nsequence....
84
+ output_dir (str): path to output dir for clustering results
85
+        c (float): number [0,1] representing -c in cluster command
86
+ c (float): nunber [0,1] representing -c in cluster command
87
+ cov_mode (int): number 0, 1, 2, or 3 representing --cov-mode in cluster command
88
+ cluster_mode (int): number 0, 1, or 2 representing --cluster-mode in cluster command
89
+
90
+ """
91
+ # Get mmseqs dir
92
+ logger.info("\nRunning MMSeqs clustering...")
93
+ mmseqs_dir = os.path.join(path_to_mmseqs[0:path_to_mmseqs.index('/mmseqs')], 'mmseqs/bin')
94
+ logger.info(f"Running mmseqs clustering from {mmseqs_dir}")
95
+
96
+ # Ensure MMseqs2 is in the PATH
97
+ ensure_mmseqs_in_path(mmseqs_dir)
98
+
99
+ # Define paths for MMseqs2
100
+ mmseqs_bin = "mmseqs" # Ensure this is in your PATH or provide the full path to mmseqs binary
101
+
102
+ # Create the output directory
103
+ os.makedirs(output_dir, exist_ok=True)
104
+
105
+ # Run MMseqs2 easy-cluster
106
+ cmd_easy_cluster = [
107
+ mmseqs_bin, "easy-cluster", input_fasta, os.path.join(output_dir, "mmseqs"), output_dir,
108
+ "--min-seq-id", str(min_seq_id),
109
+ "-c", str(c),
110
+ "--cov-mode", str(cov_mode),
111
+ "--cluster-mode", str(cluster_mode),
112
+ "--dbtype", str(dbtype)
113
+ ]
114
+
115
+ # Write the command to a log file
116
+ logger.info("\n\tCommand entered to MMSeqs2:")
117
+ logger.info("\t" + " ".join(cmd_easy_cluster) + "\n")
118
+
119
+ subprocess.run(cmd_easy_cluster, check=True)
120
+
121
+ logger.info(f"Clustering completed. Results are in {output_dir}")
122
+
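A hedged end-to-end sketch of how these helpers are meant to chain together; the paths, sequences, and MMseqs2 install location below are hypothetical, and the TSV name relies on easy-cluster's default <prefix>_cluster.tsv output:

    sequences = {
        "seqA": "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
        "seqB": "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVA",
    }
    fasta = make_fasta(sequences, "proteins.fasta")
    run_mmseqs_clustering(
        input_fasta=fasta,
        output_dir="mmseqs_out",
        min_seq_id=0.3, c=0.8, cov_mode=0, cluster_mode=0,
        path_to_mmseqs="tools/mmseqs",   # directory that contains mmseqs/bin
        dbtype=1,                        # 1 = amino acid sequences
    )
    clusters = analyze_clustering_result("proteins.fasta", "mmseqs_out/mmseqs_cluster.tsv")
    cluster_summary(clusters)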
123
+ def cluster_summary(clusters: pd.DataFrame):
124
+ """
125
+ Summarizes how many clusters were formed, how big they are, etc ...
126
+ """
127
+ grouped_clusters = clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id':'member count'})
128
+ assert len(grouped_clusters) == len(clusters['representative seq_id'].unique()) # make sure number of cluster reps = # grouped clusters
129
+
130
+ total_seqs = sum(grouped_clusters['member count'])
131
+ logger.info(f"Created {len(grouped_clusters)} clusters of {total_seqs} sequences")
132
+ logger.info(f"\t{len(grouped_clusters.loc[grouped_clusters['member count']==1])} clusters of size 1")
133
+ csize1_seqs = sum(grouped_clusters[grouped_clusters['member count']==1]['member count'])
134
+ logger.info(f"\t\tsequences: {csize1_seqs} ({round(100*csize1_seqs/total_seqs, 2)}%)")
135
+
136
+ logger.info(f"\t{len(grouped_clusters.loc[grouped_clusters['member count']>1])} clusters of size > 1")
137
+ csizeg1_seqs = sum(grouped_clusters[grouped_clusters['member count']>1]['member count'])
138
+ logger.info(f"\t\tsequences: {csizeg1_seqs} ({round(100*csizeg1_seqs/total_seqs, 2)}%)")
139
+ logger.info(f"\tlargest cluster: {max(grouped_clusters['member count'])}")
140
+
141
+ logger.info("\nCluster size breakdown below...")
142
+
143
+ value_counts = grouped_clusters['member count'].value_counts().reset_index().rename(columns={'member count':'cluster size (n_members)','count': 'n_clusters'})
144
+ logger.info(value_counts.sort_values(by='cluster size (n_members)',ascending=True).to_string(index=False))
dpacman/utils/models.py ADDED
@@ -0,0 +1,19 @@
1
+ """
2
+ Model-related utilities, such as setting seed
3
+ """
4
+ import torch
5
+ import numpy as np
6
+ import random
7
+ import os
8
+
9
+ def set_seed(seed: int = 42) -> None:
10
+ np.random.seed(seed)
11
+ random.seed(seed)
12
+ torch.manual_seed(seed)
13
+ torch.cuda.manual_seed(seed)
14
+ # When running on the CuDNN backend, two further options must be set
15
+ torch.backends.cudnn.deterministic = True
16
+ torch.backends.cudnn.benchmark = False
17
+ # Set a fixed value for the hash seed
18
+ os.environ["PYTHONHASHSEED"] = str(seed)
19
+ print(f"Random seed set as {seed}")
dpacman/utils/splitting.py ADDED
File without changes
environment.yaml CHANGED
@@ -1,42 +1,45 @@
1
- # reasons you might want to use `environment.yaml` instead of `requirements.txt`:
2
- # - pip installs packages in a loop, without ensuring dependencies across all packages
3
- # are fulfilled simultaneously, but conda achieves proper dependency control across
4
- # all packages
5
- # - conda allows for installing packages without requiring certain compilers or
6
- # libraries to be available in the system, since it installs precompiled binaries
7
- # in case of errors look here: https://pytorch.org/get-started/previous-versions/
8
-
9
- name: dpacman
10
-
11
  channels:
12
- - bioconda
13
  - conda-forge
14
- - defaults
15
-
16
- # it is strongly recommended to specify versions of packages installed through conda
17
- # to avoid situation when version-unspecified packages install their latest major
18
- # versions which can sometimes break things
19
-
20
- # current approach below keeps the dependencies in the same major versions across all
21
- # users, but allows for different minor and patch versions of packages where backwards
22
- # compatibility is usually guaranteed
23
-
24
  dependencies:
25
  - python=3.10
26
- - dask[complete]
27
  - pip>=23
28
- - ghostscript=9.18
 
 
 
 
 
 
 
 
 
 
 
29
  - pip:
30
- - rootutils==1.0.7
31
- - polars==1.32.2
32
- - hydra-core==1.3.2 # Hydra for config management
33
- - hydra-colorlog==1.2.0 # Allow colorful logging in Hydra
34
- - omegaconf==2.3.0 # Required by hydra-core
35
- - pandas==2.2.3
36
- - lxml==5.3.0
37
- - pymex==0.9.31
38
- - gitpython==3.1.44
39
- - black==25.1.0 # code formatter
40
- - tqdm==4.67.1
41
- - matplotlib==3.10.3
42
- - -e .
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: dnabind
 
 
 
 
 
 
 
 
 
2
  channels:
 
3
  - conda-forge
4
+ - bioconda
 
 
 
 
 
 
 
 
 
5
  dependencies:
6
  - python=3.10
 
7
  - pip>=23
8
+ # dask "complete" equivalent via conda packages
9
+ - dask
10
+ - distributed
11
+ - rich=13.*
12
+ # ghostscript 9.18 is very old; keep only if you truly need that exact version.
13
+ # If not, drop the pin or use the conda-forge version (>=10).
14
+ - ghostscript
15
+ - lxml=5.3.0
16
+ - pandas=2.2.3
17
+ - gitpython=3.1.*
18
+ - tqdm=4.67.*
19
+ - matplotlib=3.10.*
20
  - pip:
21
+ # Pull GPU wheels (CUDA 12.8) from PyTorch's cu128 index; fall back to PyPI for others
22
+ - --index-url https://download.pytorch.org/whl/cu128
23
+ - --extra-index-url https://pypi.org/simple
24
+
25
+ # PyTorch + CUDA 12.8
26
+ - torch==2.7.1
27
+ - torchvision==0.22.1
28
+ # - torchaudio==2.7.1 # optional, if you need it
29
+
30
+ # Lightning (classic)
31
+ - pytorch-lightning
32
+
33
+ # Your pinned Python libs
34
+ - rootutils==1.0.7
35
+ - polars==1.32.2
36
+ - hydra-core==1.3.2
37
+ - hydra-colorlog==1.2.0
38
+ - omegaconf==2.3.0
39
+ - pymex==0.9.31
40
+ - transformers==4.51.2
41
+ - scikit-learn==1.7.1
42
+ - biopython==1.85
43
+ - ortools==9.14.6206
44
+ # Your package in editable mode
45
+ - -e .
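Assuming a standard conda installation, this file can be used directly with "conda env create -f environment.yaml" followed by "conda activate dnabind" (the name: field above); the pip section then pulls the CUDA 12.8 PyTorch wheels from the configured index.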