temp changes for gpu switch
- .gitignore +9 -1
- configs/data_modules/pair.yaml +9 -0
- configs/data_task/cluster/remap.yaml +44 -0
- configs/data_task/fimo/post_fimo.yaml +9 -4
- configs/data_task/split/remap.yaml +23 -0
- configs/models/classifier.yaml +11 -0
- configs/models/pooling/truncatedsvd.yaml +7 -0
- configs/train.yaml +8 -0
- dpacman/classifier/model/loss.py +34 -0
- dpacman/classifier/model/train.py +0 -79
- dpacman/classifier/model_tmp/__init__.py +0 -0
- dpacman/classifier/model_tmp/clustering_data.py +383 -0
- dpacman/classifier/model_tmp/compress_embeddings.py +54 -0
- dpacman/{data_tasks/embeddings → classifier/model_tmp}/compute_embeddings.py +342 -141
- dpacman/classifier/model_tmp/extract_tf_symbols.py +27 -0
- dpacman/classifier/model_tmp/make_pair_list.py +220 -0
- dpacman/classifier/model_tmp/make_peak_fasta.py +13 -0
- dpacman/classifier/model_tmp/model.py +103 -0
- dpacman/classifier/model_tmp/prep_splits.py +133 -0
- dpacman/classifier/model_tmp/train.py +180 -0
- dpacman/data_modules/__init__.py +0 -0
- dpacman/data_modules/pair.py +342 -0
- dpacman/data_tasks/cluster/__init__.py +0 -0
- dpacman/data_tasks/cluster/remap.py +144 -0
- dpacman/data_tasks/embeddings/embedders.py +560 -0
- dpacman/data_tasks/fimo/__init__.py +0 -0
- dpacman/data_tasks/fimo/post_fimo.py +526 -124
- dpacman/data_tasks/fimo/run_fimo.py +16 -7
- dpacman/data_tasks/split/__init__.py +0 -0
- dpacman/data_tasks/split/remap.py +512 -0
- dpacman/scripts/preprocess.py +20 -4
- dpacman/scripts/run_cluster.sh +19 -0
- dpacman/scripts/run_fimo.sh +3 -1
- dpacman/scripts/run_split.sh +21 -0
- dpacman/utils/README.md +1 -0
- dpacman/utils/__init__.py +0 -0
- dpacman/utils/clustering.py +144 -0
- dpacman/utils/models.py +19 -0
- dpacman/utils/splitting.py +0 -0
- environment.yaml +39 -36
.gitignore
CHANGED
@@ -17,4 +17,12 @@ dpacman/scripts/__pycache__/
 dpacman/temp.py
 dpacman/temp2.py
 logs/
-tree.txt
+tree.txt
+dpacman/utils/ubuntu_font/
+dpacman/tf_clusters.tsv
+dpacman/tf_family_tree.nwk
+dpacman/idmap_filt.csv
+dpacman/temp3.py
+dpacman/temp4.py
+dpacman/temp.ipynb
+dpacman/nohup.out
configs/data_modules/pair.yaml
ADDED
@@ -0,0 +1,9 @@

train_file: data_files/splits/train.csv
val_file: data_files/splits/val.csv
test_file: data_files/splits/test.csv

batch_size: 32
num_workers: 8

maximize_num_workers: False
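For reference, a minimal sketch of how these fields could be consumed when building PyTorch DataLoaders. The PairCSVDataset wrapper below is hypothetical (the project's real dataset lives in dpacman/data_modules/pair.py, whose API is not shown here); only the mapping from config keys to loader arguments is the point.

# Hypothetical illustration: map pair.yaml keys onto DataLoader arguments.
import pandas as pd
from torch.utils.data import DataLoader, Dataset

class PairCSVDataset(Dataset):  # placeholder, not the project's actual class
    def __init__(self, csv_path):
        self.df = pd.read_csv(csv_path)
    def __len__(self):
        return len(self.df)
    def __getitem__(self, i):
        return self.df.iloc[i].to_dict()

def build_loaders(cfg):
    train = PairCSVDataset(cfg["train_file"])
    val = PairCSVDataset(cfg["val_file"])
    kwargs = dict(batch_size=cfg["batch_size"], num_workers=cfg["num_workers"])
    return DataLoader(train, shuffle=True, **kwargs), DataLoader(val, shuffle=False, **kwargs)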
configs/data_task/cluster/remap.yaml
ADDED
@@ -0,0 +1,44 @@
name: remap
type: cluster

max_protein_length: 1998

cluster_dna_full: true
cluster_dna_peaks: true
cluster_protein: false

dna_full:
  input_map_path: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/maps/dna_seqid_to_dna_sequence.json
  fasta_path: dpacman/data_files/processed/mmseqs/inputs/fimo_hits_only/dna_full.fasta
  output_dir: dpacman/data_files/processed/mmseqs/outputs/fimo_hits_only/dna_full
  mmseqs:
    min_seq_id: 0.3
    c: 0.8
    cov_mode: 0
    cluster_mode: 0

dna_peaks:
  input_map_path: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/maps/peak_seqid_to_peak_sequence.json
  fasta_path: dpacman/data_files/processed/mmseqs/inputs/fimo_hits_only/dna_peaks.fasta
  output_dir: dpacman/data_files/processed/mmseqs/outputs/fimo_hits_only/dna_peaks
  mmseqs:
    min_seq_id: 0.3
    c: 0.8
    cov_mode: 0
    cluster_mode: 0

protein:
  input_map_path: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/maps/tr_seqid_to_tr_sequence.json
  fasta_path: dpacman/data_files/processed/mmseqs/inputs/fimo_hits_only/protein.fasta
  output_dir: dpacman/data_files/processed/mmseqs/outputs/fimo_hits_only/protein
  mmseqs:
    min_seq_id: 0.3
    c: 0.8
    cov_mode: 0
    cluster_mode: 0

input_data_path: /vast/projects/pranam/lab/sophie/DPACMAN/dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/remap2022_crm_fimo_output_q_processed_seed0.parquet
path_to_mmseqs: /vast/projects/pranam/lab/shared/mmseqs

out_dir: na
final_csv: na
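As a rough guide, each mmseqs block above corresponds to an `mmseqs easy-cluster` invocation. A hedged sketch of that mapping follows; the helper name and the tmp-directory layout are assumptions, not the project's actual runner in dpacman/data_tasks/cluster/remap.py.

import subprocess
from pathlib import Path

def run_block(mmseqs_bin, block):
    """Assumed mapping of one config block (fasta_path/output_dir/mmseqs) to easy-cluster flags."""
    out_dir = Path(block["output_dir"])
    out_dir.mkdir(parents=True, exist_ok=True)
    m = block["mmseqs"]
    subprocess.run([
        mmseqs_bin, "easy-cluster",
        block["fasta_path"], str(out_dir / "mmseqs"), str(out_dir / "tmp"),
        "--min-seq-id", str(m["min_seq_id"]),
        "-c", str(m["c"]),
        "--cov-mode", str(m["cov_mode"]),
        "--cluster-mode", str(m["cluster_mode"]),
    ], check=True)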
configs/data_task/fimo/post_fimo.yaml
CHANGED
@@ -2,11 +2,16 @@ name: post_fimo
 type: fimo
 
 fimo_out_dir: dpacman/data_files/processed/fimo/fimo_out_q
-
-processed_output_csv: dpacman/data_files/processed/fimo/remap2022_crm_fimo_output_q_processed.csv
+processed_output_csv: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/remap2022_crm_fimo_output_q_processed.csv
 json_dir: dpacman/data_files/raw/genomes/hg38
-idmap_path: dpacman/data_files/raw/remap/
+idmap_path: dpacman/data_files/raw/remap/idmapping_reviewed_true_2025_08_15.tsv
+remap_path: /vast/projects/pranam/lab/sophie/DPACMAN/dpacman/data_files/processed/fimo/remap2022_crm_fimo_input.csv
 
-jaspar_boost:
+jaspar_boost: 333
+
+keep_fimo_only: true
+
+seeds: [0]
+max_protein_len: 1998
 
 debug: false
configs/data_task/split/remap.yaml
ADDED
@@ -0,0 +1,23 @@
name: remap
type: split

max_protein_length: 1998

cluster_output_paths:
  dna: dpacman/data_files/processed/mmseqs/outputs/fimo_hits_only/dna_full/mmseqs_cluster.tsv
  protein: dpacman/data_files/processed/mmseqs/outputs/fimo_hits_only/protein/mmseqs_cluster.tsv

input_data_path: dpacman/data_files/processed/fimo/post_fimo/fimo_hits_only/remap2022_crm_fimo_output_q_processed_seed0.parquet
split_out_dir: dpacman/data_files/processed/splits

split_by: both  # protein, dna, or both

test_ratio: 0.10
val_ratio: 0.10
train_ratio: 0.80

require_nonempty: true
ratio_tolerance: null
bigM: null

seed: 0
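For orientation, a minimal sketch of what a cluster-aware split with these ratios looks like: whole mmseqs clusters are assigned to one split so near-identical sequences never straddle train and test. The helper below is illustrative only, not the project's actual splitter in dpacman/data_tasks/split/remap.py.

import random
from collections import defaultdict

def cluster_split(cluster_tsv, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=0):
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-9
    rep_to_members = defaultdict(list)
    with open(cluster_tsv) as f:                 # mmseqs_cluster.tsv rows: "<rep>\t<member>"
        for line in f:
            rep, member = line.rstrip("\n").split("\t")[:2]
            rep_to_members[rep].append(member)
    reps = sorted(rep_to_members)
    random.Random(seed).shuffle(reps)
    total = sum(len(v) for v in rep_to_members.values())
    splits, filled = {"train": [], "val": [], "test": []}, 0
    for rep in reps:                             # greedy fill: train, then val, then test
        frac = filled / total
        key = "train" if frac < train_ratio else ("val" if frac < train_ratio + val_ratio else "test")
        splits[key].extend(rep_to_members[rep])
        filled += len(rep_to_members[rep])
    return splits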
configs/models/classifier.yaml
ADDED
@@ -0,0 +1,11 @@
name: classifier
type: train

params:
  epochs: 10
  batch_size: 32
  lr: 1e-4
  seed: 42

out_dir: null
pair_list: null
configs/models/pooling/truncatedsvd.yaml
ADDED
@@ -0,0 +1,7 @@
n_components: 2
algorithm: randomized
n_iter: 5
n_oversamples: 10
power_iteration_normalizer: auto
random_state: 42
tol: 0
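These keys mirror the constructor arguments of sklearn.decomposition.TruncatedSVD, so the config can be splatted straight into it. A minimal sketch, assuming the YAML is loaded as a flat dict:

import yaml
from sklearn.decomposition import TruncatedSVD

with open("configs/models/pooling/truncatedsvd.yaml") as f:
    cfg = yaml.safe_load(f)

svd = TruncatedSVD(**cfg)                      # reduce pooled embeddings to cfg["n_components"] dims
# reduced = svd.fit_transform(embeddings)      # embeddings: (n_samples, n_features)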
configs/train.yaml
CHANGED
@@ -0,0 +1,8 @@
+defaults:
+  - _self_
+  - paths: default
+  - hydra: default  # ← tells Hydra to use the logging/output config
+  - trainer: gpu
+  - data_task: model/classifier
+
+task_name: train/${data_task.type}
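A minimal sketch of an entry point that would pick up this config via Hydra; the script location, config_path, and anything beyond the keys shown above are assumptions rather than the repo's actual training script.

import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(version_base=None, config_path="configs", config_name="train")
def main(cfg: DictConfig) -> None:
    # Prints the composed config; task_name resolves from data_task.type at run time.
    print(OmegaConf.to_yaml(cfg))
    # training would be dispatched from here

if __name__ == "__main__":
    main()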
dpacman/classifier/model/loss.py
ADDED
@@ -0,0 +1,34 @@
"""
Define loss functions needed for training the model
"""
import torch
from torch.nn import functional as F

def combined_loss_components(logits, targets, peak_thresh=0.5, eps=1e-8):
    probs = torch.sigmoid(logits)
    labels = (targets >= peak_thresh).float()
    non_peak_mask = (labels == 0).float()
    peak_mask = (labels == 1).float()

    bce_all = F.binary_cross_entropy_with_logits(logits, labels, reduction='none')
    bce_non = (bce_all * non_peak_mask)
    bce_non = bce_non.sum() / (non_peak_mask.sum() + eps)

    mse_peaks = F.mse_loss(probs * peak_mask, targets * peak_mask, reduction='sum') \
                / (peak_mask.sum() + eps)

    t_dist = (targets + eps)
    p_dist = (probs + eps)
    t_dist = t_dist / t_dist.sum(dim=1, keepdim=True)
    p_dist = p_dist / p_dist.sum(dim=1, keepdim=True)
    kl = (t_dist * (t_dist.clamp(min=eps).log() - p_dist.clamp(min=eps).log())).sum(dim=1).mean()

    return bce_non, kl, mse_peaks, probs

def accuracy_percentage(logits, targets, peak_thresh=0.5):
    probs = torch.sigmoid(logits)
    preds_bin = (probs >= 0.5).float()
    labels = (targets >= peak_thresh).float()
    correct = (preds_bin == labels).float().sum()
    total = torch.numel(labels)
    return (correct / max(1, total)).item() * 100.0
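A hedged sketch of how the three components might be combined into one scalar loss, mirroring the alpha/beta/gamma weights that evaluate() in train.py accepts. The import path assumes the package resolves as dpacman.classifier.model.loss; the weighting itself is an assumption, not the project's confirmed formula.

import torch
from dpacman.classifier.model.loss import combined_loss_components

def total_loss(logits, targets, alpha=1.0, beta=1.0, gamma=1.0, peak_thresh=0.5):
    bce_non, kl, mse_peaks, _ = combined_loss_components(logits, targets, peak_thresh)
    return alpha * bce_non + beta * kl + gamma * mse_peaks

logits = torch.randn(4, 10)    # (batch, positions)
targets = torch.rand(4, 10)    # continuous peak scores in [0, 1]
print(total_loss(logits, targets))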
dpacman/classifier/model/train.py
CHANGED
@@ -77,85 +77,6 @@ def build_tf_cache(tf_paths, target_dim=256):
     print("[i] TF cache ready (raw); compression will be learned.", flush=True)
     return cache
 
-# ─────────────── Dataset & Collation ─────────────────────────────────────
-class PairDataset(Dataset):
-    def __init__(self, tf_paths, dna_paths, final_df, tf_cache):
-        self.tf_paths, self.dna_paths = tf_paths, dna_paths
-        self.tf_cache = tf_cache
-        self.targets = {}
-        for _, row in final_df.iterrows():
-            dna_id = row["dna_id"]
-            vec = np.array(list(map(float, row["score_sig_r2"].split(","))), dtype=np.float32)
-            self.targets[dna_id] = vec
-
-    def __len__(self): return len(self.tf_paths)
-
-    def __getitem__(self, i):
-        b = self.tf_cache[self.tf_paths[i]]          # (L_b, D_b) or (D_b,)
-        if b.ndim==1: b = b[None,:]
-        g = np.load(self.dna_paths[i])               # (L_g, 256) or (256,)
-        if g.ndim==1: g = g[None,:]
-
-        stem = Path(self.dna_paths[i]).stem
-        dna_id = stem.replace("dna_","")
-        t = self.targets.get(dna_id, np.zeros(g.shape[0],dtype=np.float32))
-
-        return torch.from_numpy(b).float(), \
-               torch.from_numpy(g).float(), \
-               torch.from_numpy(t).float()
-
-def collate_fn(batch):
-    Bs = [b.shape[0] for b,_,_ in batch]
-    Gs = [g.shape[0] for _,g,_ in batch]
-    maxB, maxG = max(Bs), max(Gs)
-
-    def pad_seq(x, L):
-        if x.shape[0] < L:
-            pad = torch.zeros((L-x.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
-            return torch.cat([x, pad], dim=0)
-        return x
-
-    def pad_t(y, L):
-        if y.shape[0] < L:
-            pad = torch.zeros((L-y.shape[0],), dtype=y.dtype, device=y.device)
-            return torch.cat([y, pad], dim=0)
-        return y
-
-    b_stack = torch.stack([pad_seq(b, maxB) for b,_,_ in batch])
-    g_stack = torch.stack([pad_seq(g, maxG) for _,g,_ in batch])
-    t_stack = torch.stack([pad_t(t, maxG) for *_,t in batch])
-    return b_stack, g_stack, t_stack
-
-# ─────────────── losses, metrics ─────────────────────────────────────────
-def combined_loss_components(logits, targets, peak_thresh=0.5, eps=1e-8):
-    probs = torch.sigmoid(logits)
-    labels = (targets >= peak_thresh).float()
-    non_peak_mask = (labels == 0).float()
-    peak_mask = (labels == 1).float()
-
-    bce_all = F.binary_cross_entropy_with_logits(logits, labels, reduction='none')
-    bce_non = (bce_all * non_peak_mask)
-    bce_non = bce_non.sum() / (non_peak_mask.sum() + eps)
-
-    mse_peaks = F.mse_loss(probs * peak_mask, targets * peak_mask, reduction='sum') \
-                / (peak_mask.sum() + eps)
-
-    t_dist = (targets + eps)
-    p_dist = (probs + eps)
-    t_dist = t_dist / t_dist.sum(dim=1, keepdim=True)
-    p_dist = p_dist / p_dist.sum(dim=1, keepdim=True)
-    kl = (t_dist * (t_dist.clamp(min=eps).log() - p_dist.clamp(min=eps).log())).sum(dim=1).mean()
-
-    return bce_non, kl, mse_peaks, probs
-
-def accuracy_percentage(logits, targets, peak_thresh=0.5):
-    probs = torch.sigmoid(logits)
-    preds_bin = (probs >= 0.5).float()
-    labels = (targets >= peak_thresh).float()
-    correct = (preds_bin == labels).float().sum()
-    total = torch.numel(labels)
-    return (correct / max(1, total)).item() * 100.0
-
 def evaluate(model, dl, device, alpha, beta, gamma, peak_thresh, eps=1e-8):
     model.eval()
     tot_loss, tot_acc = 0.0, 0.0
dpacman/classifier/model_tmp/__init__.py
ADDED
File without changes
dpacman/classifier/model_tmp/clustering_data.py
ADDED
@@ -0,0 +1,383 @@
#!/usr/bin/env python3
import argparse
import numpy as np
import pandas as pd
from pathlib import Path
import random
import sys
import subprocess
from collections import defaultdict

# ─────────────────────────────────────────────────────────────────────────
# Original helpers (kept; some lightly edited/commented where needed)
# ─────────────────────────────────────────────────────────────────────────

def read_ids_file(p):
    p = Path(p)
    if not p.exists():
        raise FileNotFoundError(f"IDs file not found: {p}")
    return [line.strip() for line in p.open() if line.strip()]

def split_embeddings(emb_path, ids_path, out_dir, prefix):
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    if not Path(emb_path).exists():
        raise FileNotFoundError(f"Embedding file not found: {emb_path}")
    if not Path(ids_path).exists():
        raise FileNotFoundError(f"IDs file not found: {ids_path}")

    if emb_path.endswith(".npz"):
        data = np.load(emb_path, allow_pickle=True)
        if "embeddings" in data:
            emb = data["embeddings"]
        else:
            raise ValueError(f"{emb_path} missing 'embeddings' key")
    else:
        emb = np.load(emb_path)

    ids = read_ids_file(ids_path)
    if len(ids) != emb.shape[0]:
        print(f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}", file=sys.stderr)

    mapping = {}
    for i, ident in enumerate(ids):
        if i >= emb.shape[0]:
            print(f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr)
            continue
        arr = emb[i]
        out_file = out_dir / f"{prefix}_{ident}.npy"
        np.save(out_file, arr)
        mapping[ident] = str(out_file)
    return mapping

def extract_symbol_from_tf_id(full_id: str) -> str:
    """
    Given a TF embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN',
    return the gene symbol uppercase (e.g., 'ZBTB5').
    """
    if "|" in full_id:
        try:
            # format sp|Accession|SYMBOL_HUMAN
            genepart = full_id.split("|")[2]
        except IndexError:
            genepart = full_id
    else:
        genepart = full_id
    symbol = genepart.split("_")[0]
    return symbol.upper()

def build_tf_symbol_map(tf_map):
    """
    Build mapping gene_symbol -> list of embedding paths.
    """
    symbol_map = {}
    for full_id, path in tf_map.items():
        symbol = extract_symbol_from_tf_id(full_id)
        symbol_map.setdefault(symbol, []).append(path)
    return symbol_map

def tf_key_from_path(path: str) -> str:
    """
    Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract normalized symbol 'ZBTB5'.
    """
    stem = Path(path).stem  # e.g., tf_sp|O15062|ZBTB5_HUMAN
    # remove leading prefix if present (tf_)
    if "_" in stem:
        _, rest = stem.split("_", 1)
    else:
        rest = stem
    return extract_symbol_from_tf_id(rest)

def dna_key_from_path(path: str) -> str:
    """
    Given .../dna_peak42.npy -> 'peak42'
    """
    stem = Path(path).stem
    if "_" in stem:
        _, rest = stem.split("_", 1)
    else:
        rest = stem
    return rest

# ─────────────────────────────────────────────────────────────────────────
# New helpers for MMseqs clustering & cluster-level splitting
# ─────────────────────────────────────────────────────────────────────────

def write_dna_fasta(df: pd.DataFrame, out_fasta: Path) -> None:
    """
    Write unique DNA sequences to FASTA using dna_id as header.
    Requires df with columns: dna_id, dna_sequence
    """
    uniq = df[["dna_id", "dna_sequence"]].drop_duplicates()
    with open(out_fasta, "w") as f:
        for _, row in uniq.iterrows():
            did = row["dna_id"]
            seq = str(row["dna_sequence"]).upper().replace(" ", "").replace("\n", "")
            f.write(f">{did}\n{seq}\n")

def run_mmseqs_easy_cluster(
    mmseqs_bin: str,
    fasta: Path,
    out_prefix: Path,
    tmp_dir: Path,
    min_seq_id: float,
    coverage: float,
    cov_mode: int,
) -> Path:
    """
    Runs mmseqs easy-cluster on nucleotide sequences.
    Returns the path to a clusters TSV file (creating it if the default one isn't present).
    """
    tmp_dir.mkdir(parents=True, exist_ok=True)
    out_prefix.parent.mkdir(parents=True, exist_ok=True)

    cmd = [
        mmseqs_bin, "easy-cluster",
        str(fasta), str(out_prefix), str(tmp_dir),
        "--min-seq-id", str(min_seq_id),
        "-c", str(coverage),
        "--cov-mode", str(cov_mode),
        # You can add performance flags here if needed, e.g.:
        # "--threads", "8"
    ]
    print("[i] Running:", " ".join(cmd), flush=True)
    subprocess.run(cmd, check=True)

    # MMseqs easy-cluster typically writes <out_prefix>_cluster.tsv
    default_tsv = Path(str(out_prefix) + "_cluster.tsv")
    if default_tsv.exists():
        print(f"[i] Found cluster TSV: {default_tsv}")
        return default_tsv

    # Fallback: try createtsv if default is missing
    # This requires the internal DBs. easy-cluster creates DBs alongside out_prefix.
    # We'll try to locate them and emit a TSV.
    in_db = Path(str(out_prefix) + "_query")
    cl_db = Path(str(out_prefix) + "_cluster")
    out_tsv = Path(str(out_prefix) + "_fallback_cluster.tsv")
    if in_db.exists() and cl_db.exists():
        cmd2 = [mmseqs_bin, "createtsv", str(in_db), str(in_db), str(cl_db), str(out_tsv)]
        print("[i] Creating TSV via createtsv:", " ".join(cmd2), flush=True)
        subprocess.run(cmd2, check=True)
        if out_tsv.exists():
            return out_tsv

    raise FileNotFoundError("Could not locate clusters TSV from mmseqs. "
                            f"Expected {default_tsv} or createtsv fallback.")

def parse_mmseqs_clusters(tsv_path: Path) -> dict:
    """
    Parse MMseqs cluster TSV (rep \t member). Returns dna_id -> cluster_rep_id
    """
    mapping = {}
    with open(tsv_path) as f:
        for line in f:
            parts = line.rstrip("\n").split("\t")
            if len(parts) < 2:
                continue
            rep, member = parts[0], parts[1]
            mapping[member] = rep
            # Some TSVs include rep->rep; if not, ensure rep is mapped to itself:
            if rep not in mapping:
                mapping[rep] = rep
    return mapping

def assign_clusters_to_splits(cluster_rep_to_members: dict,
                              val_frac: float,
                              test_frac: float,
                              seed: int = 42):
    """
    cluster_rep_to_members: dict[rep] = [members...]
    Returns: dict with keys 'train','val','test' mapping to sets of dna_id.
    Ensures all members of a cluster go to the same split.
    """
    rng = random.Random(seed)
    reps = list(cluster_rep_to_members.keys())
    rng.shuffle(reps)

    # Greedy-ish fill by total member counts to match desired fractions.
    total = sum(len(cluster_rep_to_members[r]) for r in reps)
    target_val = int(round(total * val_frac))
    target_test = int(round(total * test_frac))
    cur_val = cur_test = 0

    val_ids, test_ids, train_ids = set(), set(), set()
    for rep in reps:
        members = cluster_rep_to_members[rep]
        c = len(members)
        # Fill val first, then test, then train
        if cur_val + c <= target_val:
            val_ids.update(members); cur_val += c
        elif cur_test + c <= target_test:
            test_ids.update(members); cur_test += c
        else:
            train_ids.update(members)

    return {"train": train_ids, "val": val_ids, "test": test_ids}

# ─────────────────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────────────────

def main():
    parser = argparse.ArgumentParser(
        description="Build TF-DNA pair lists with MMseqs clustering on DNA to prevent split leakage."
    )
    parser.add_argument("--final_csv", required=True, help="final.csv with TF_id and dna_sequence")
    parser.add_argument("--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)")
    parser.add_argument("--dna_ids", required=True, help="IDs file for DNA embeddings (peak*.ids)")
    parser.add_argument("--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)")
    parser.add_argument("--tf_ids", required=True, help="IDs file for TF embeddings (sp|... ids)")
    parser.add_argument("--out_dir", required=True, help="Output directory")
    parser.add_argument("--seed", type=int, default=42)

    # NEW: MMseqs options & split fractions
    parser.add_argument("--mmseqs_bin", default="mmseqs", help="Path to mmseqs binary")
    parser.add_argument("--min_seq_id", type=float, default=0.9, help="MMseqs --min-seq-id")
    parser.add_argument("--cov", type=float, default=0.8, help="MMseqs -c coverage fraction")
    parser.add_argument("--cov_mode", type=int, default=1, help="MMseqs --cov-mode (1 = coverage of target)")
    parser.add_argument("--val_frac", type=float, default=0.10)
    parser.add_argument("--test_frac", type=float, default=0.10)
    parser.add_argument("--tmp_dir", default=None, help="MMseqs tmp dir (defaults to out_dir/tmp)")
    args = parser.parse_args()

    random.seed(args.seed)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Load final.csv
    df = pd.read_csv(args.final_csv, dtype=str)
    if "TF_id" not in df.columns or "dna_sequence" not in df.columns:
        raise RuntimeError("final.csv must have columns TF_id and dna_sequence")

    # Assign dna_id (unique per dna_sequence)
    unique_seqs = df["dna_sequence"].drop_duplicates().tolist()
    seq_to_id = {seq: f"peak{i}" for i, seq in enumerate(unique_seqs)}
    df["dna_id"] = df["dna_sequence"].map(seq_to_id)
    enriched_csv = out_dir / "final_with_dna_id.csv"
    df.to_csv(enriched_csv, index=False)
    print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}")

    # Split embeddings into per-item files (unchanged)
    print(f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}")
    dna_map = split_embeddings(args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna")
    print(f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})")
    print(f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}")
    tf_map = split_embeddings(args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf")
    print(f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})")

    # Build gene-symbol normalized map
    tf_symbol_map = build_tf_symbol_map(tf_map)
    print(f"[i] TF symbol map keys (sample): {list(tf_symbol_map.keys())[:30]}")

    # Diagnostic overlaps
    norm_tf_in_final = set(t.split("_seq")[0].upper() for t in df["TF_id"].unique())
    available_tf_symbols = set(tf_symbol_map.keys())
    intersect_tf = norm_tf_in_final & available_tf_symbols
    print(f"[i] Unique normalized TF symbols in final.csv: {len(norm_tf_in_final)}")
    print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}")
    print(f"[i] Intersection count: {len(intersect_tf)}")
    if len(intersect_tf) == 0:
        print("[ERROR] No overlap between normalized TF_id and TF embedding symbols.", file=sys.stderr)
        print("Sample normalized TFs from final.csv:", sorted(list(norm_tf_in_final))[:30], file=sys.stderr)
        print("Sample available TF symbols:", sorted(list(available_tf_symbols))[:30], file=sys.stderr)
        sys.exit(1)

    dna_ids_final = set(df["dna_id"].unique())
    available_dna_ids = set(dna_map.keys())
    intersect_dna = dna_ids_final & available_dna_ids
    print(f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. Intersection: {len(intersect_dna)}")
    if len(intersect_dna) == 0:
        print("[ERROR] No overlap on DNA ids.", file=sys.stderr)
        sys.exit(1)

    # ── NEW: MMseqs clustering on DNA sequences ───────────────────────────
    fasta_path = out_dir / "dna_unique.fasta"
    write_dna_fasta(df, fasta_path)
    print(f"[i] Wrote FASTA with {df['dna_id'].nunique()} unique sequences → {fasta_path}")

    tmp_dir = Path(args.tmp_dir) if args.tmp_dir else (out_dir / "mmseqs_tmp")
    cluster_prefix = out_dir / "mmseqs_dna_clusters"
    clusters_tsv = run_mmseqs_easy_cluster(
        mmseqs_bin=args.mmseqs_bin,
        fasta=fasta_path,
        out_prefix=cluster_prefix,
        tmp_dir=tmp_dir,
        min_seq_id=args.min_seq_id,
        coverage=args.cov,
        cov_mode=args.cov_mode,
    )

    # Parse clusters
    member_to_rep = parse_mmseqs_clusters(clusters_tsv)  # dna_id -> rep_id
    # Build rep -> members list
    rep_to_members = defaultdict(list)
    for member, rep in member_to_rep.items():
        rep_to_members[rep].append(member)

    print(f"[i] Parsed {len(rep_to_members)} clusters from {clusters_tsv}")
    clusters_table = []
    for rep, members in rep_to_members.items():
        for m in members:
            clusters_table.append((m, rep))
    clusters_df = pd.DataFrame(clusters_table, columns=["dna_id", "cluster_id"])
    clusters_df.to_csv(out_dir / "clusters.tsv", sep="\t", index=False)
    print(f"[i] Wrote clusters mapping → {out_dir / 'clusters.tsv'}")

    # Attach cluster_id back to final df
    df = df.merge(clusters_df, on="dna_id", how="left")
    df.to_csv(out_dir / "final_with_dna_id_and_cluster.csv", index=False)
    print(f"[i] Wrote {out_dir / 'final_with_dna_id_and_cluster.csv'}")

    # Assign entire clusters to splits
    splits = assign_clusters_to_splits(rep_to_members,
                                       val_frac=args.val_frac,
                                       test_frac=args.test_frac,
                                       seed=args.seed)
    for k in ["train", "val", "test"]:
        print(f"[i] {k}: {len(splits[k])} dna_ids")

    # ── Build positive pairs only, per split (NO negatives) ───────────────
    positives_by_split = {"train": [], "val": [], "test": []}
    # Build a quick dna_id -> embedding path map
    dnaid_to_path = {did: path for did, path in dna_map.items()}

    pos_count = 0
    for _, row in df.iterrows():
        tf_raw = row["TF_id"]
        tf_symbol = tf_raw.split("_seq")[0].upper()
        dnaid = row["dna_id"]
        if (tf_symbol not in tf_symbol_map) or (dnaid not in dnaid_to_path):
            continue
        tf_embedding_path = tf_symbol_map[tf_symbol][0]  # first embedding per symbol

        # decide split by dna_id cluster assignment
        if dnaid in splits["train"]:
            positives_by_split["train"].append((tf_embedding_path, dnaid_to_path[dnaid], 1))
        elif dnaid in splits["val"]:
            positives_by_split["val"].append((tf_embedding_path, dnaid_to_path[dnaid], 1))
        elif dnaid in splits["test"]:
            positives_by_split["test"].append((tf_embedding_path, dnaid_to_path[dnaid], 1))
        pos_count += 1

    print(f"[i] Constructed positives across splits (rows in final.csv iterated: {len(df)})")
    for k in ["train", "val", "test"]:
        print(f"[i] positives[{k}] = {len(positives_by_split[k])}")

    # # OLD: negatives (kept commented)
    # negatives = []
    # print(f"[i] Sampled {len(negatives)} negatives (neg_per_positive not used)")

    # Emit split-specific pair lists
    for split in ["train", "val", "test"]:
        out_tsv = out_dir / f"pair_list_{split}.tsv"
        with open(out_tsv, "w") as f:
            for binder_path, glm_path, label in positives_by_split[split]:  # + negatives if you add later
                f.write(f"{binder_path}\t{glm_path}\t{label}\n")
        print(f"[i] Wrote {len(positives_by_split[split])} examples to {out_tsv}")

    print("✅ Done. Cluster-aware splits ready.")

if __name__ == "__main__":
    main()
dpacman/classifier/model_tmp/compress_embeddings.py
ADDED
@@ -0,0 +1,54 @@
# compress_embeddings.py
# USAGE: python compress_embeddings.py --input_glob "/path/to/esm_embeddings/*.npy" --output_dir "/path/to/compressed_embeddings" --esm_dim 1280 --out_dim 256
# --------------
import os
import glob
import numpy as np
import torch
from torch import nn

class EmbeddingCompressor(nn.Module):
    def __init__(self, input_dim: int = 1280, output_dim: int = 256):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch, L, input_dim) or (L, input_dim)
        returns: (batch, output_dim) or (output_dim,)
        """
        if x.dim() == 2:
            # single example: mean over tokens
            x = x.mean(dim=0, keepdim=True)  # → (1, input_dim)
        else:
            # batch: mean over tokens
            x = x.mean(dim=1)  # → (batch, input_dim)
        return self.fc(x)  # → (batch, output_dim)

def compress_file(in_path: str, out_path: str, model: EmbeddingCompressor):
    arr = np.load(in_path)  # shape (L, D) or (batch, L, D)
    tensor = torch.from_numpy(arr).float()
    with torch.no_grad():
        compressed = model(tensor)  # → (batch, 256)
    out = compressed.cpu().numpy()
    np.save(out_path, out)
    print(f"Saved {out_path}")

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Compress ESM embeddings to 256d")
    parser.add_argument("--input_glob", type=str, required=True,
                        help="Glob for your .npy ESM embeddings (e.g. data/esm_*.npy)")
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--esm_dim", type=int, default=1280)
    parser.add_argument("--out_dim", type=int, default=256)
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    compressor = EmbeddingCompressor(args.esm_dim, args.out_dim)
    compressor.eval()

    for fn in glob.glob(args.input_glob):
        base = os.path.basename(fn).replace(".npy", "_256.npy")
        out_path = os.path.join(args.output_dir, base)
        compress_file(fn, out_path, compressor)
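Note that, as written, the compressor's linear projection is randomly initialized on each run, so different runs project embeddings differently. A hedged sketch of persisting one set of weights so every file is compressed with the same projection; the checkpoint file name is illustrative, not part of the repo.

import torch
# Save the projection once so later runs reuse identical weights (path is illustrative).
torch.save(compressor.state_dict(), "embedding_compressor_256.pt")
# Later / elsewhere:
compressor.load_state_dict(torch.load("embedding_compressor_256.pt", map_location="cpu"))
compressor.eval()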
dpacman/{data_tasks/embeddings → classifier/model_tmp}/compute_embeddings.py
RENAMED
@@ -14,7 +14,6 @@ Usage example (DNA + protein in one go):
     --out-dir ../data_files/processed/tfclust/hg38_tf/embeddings \
     --device cuda
 """
-
 import os
 import re
 import argparse
@@ -29,7 +28,6 @@ import time
 
 # ---- model wrappers ----
 
-
 class CaduceusEmbedder:
     def __init__(self, device, chunk_size=131_072, overlap=0):
         """
@@ -41,54 +39,47 @@ class CaduceusEmbedder:
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name, trust_remote_code=True
         )
-        self.model = (
-        )
-        self.device = device
+        self.model = AutoModel.from_pretrained(
+            model_name, trust_remote_code=True
+        ).to(device).eval()
+        self.device = device
         self.chunk_size = chunk_size
-        self.step
+        self.step = chunk_size - overlap
 
     def embed(self, seqs):
         """
         seqs: List[str] of DNA sequences (each <= chunk_size for this test)
         returns: np.ndarray of shape (N, L, D), raw per‐token embeddings
         """
+        # outputs = []
+        # for seq in seqs:
+        #     # --- new: raw per‐token embeddings in one shot ---
+        #     toks = self.tokenizer(
+        #         seq,
+        #         return_tensors="pt",
+        #         padding=False,
+        #         truncation=True,
+        #         max_length=self.chunk_size
+        #     ).to(self.device)
+        #     with torch.no_grad():
+        #         out = self.model(**toks).last_hidden_state  # (1, L, D)
+        #     outputs.append(out.cpu().numpy()[0])  # (L, D)
+        #
+        # return np.stack(outputs, axis=0)  # (N, L, D)
         outputs = []
         for seq in seqs:
-            # --- old windowing + mean-pooling logic, now commented out ---
-            # window_vecs = []
-            # for i in range(0, len(seq), self.step):
-            #     chunk = seq[i : i + self.chunk_size]
-            #     if not chunk:
-            #         break
-            #     toks = self.tokenizer(
-            #         chunk,
-            #         return_tensors="pt",
-            #         padding=False,
-            #         truncation=True,
-            #         max_length=self.chunk_size
-            #     ).to(self.device)
-            #     with torch.no_grad():
-            #         out = self.model(**toks).last_hidden_state
-            #     window_vecs.append(out.mean(dim=1).squeeze(0).cpu())
-            # seq_emb = torch.stack(window_vecs, dim=0).mean(dim=0).numpy()
-            # outputs.append(seq_emb)
-
-            # --- new: raw per‐token embeddings in one shot ---
             toks = self.tokenizer(
                 seq,
                 return_tensors="pt",
                 padding=False,
                 truncation=True,
                 max_length=self.chunk_size
             ).to(self.device)
             with torch.no_grad():
                 out = self.model(**toks).last_hidden_state  # (1, L, D)
-            outputs.append(out.cpu().numpy()[0])
+            outputs.append(out.cpu().numpy()[0])  # (L, D)
+        return outputs  # list of variable-length (L_i, D) arrays
 
-        return np.stack(outputs, axis=0)  # (N, L, D)
 
     def benchmark(self, lengths=None):
         """
@@ -110,29 +101,79 @@ class CaduceusEmbedder:
             t1 = time.perf_counter()
             print(f"  length={sz:6,d}   time={(t1-t0)*1000:7.1f} ms")
 
+class SegmentNTEmbedder:
+    def __init__(self, device):
+        self.tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True)
+        self.model = AutoModel.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True).to(device).eval()
+        self.device = device
+
+    def _adjust_length(self, input_ids):
+        bs, L = input_ids.shape
+        excl = L - 1
+        remainder = (excl) % 4
+        if remainder != 0:
+            pad_needed = 4 - remainder
+            pad_tensor = torch.full((bs, pad_needed), self.tokenizer.pad_token_id, dtype=input_ids.dtype, device=input_ids.device)
+            input_ids = torch.cat([input_ids, pad_tensor], dim=1)
+        return input_ids
+
+    def embed(self, seqs, batch_size=16):
+        """
+        seqs: List[str]
+        Returns: np.ndarray of shape (N, D)
+        """
+        all_embeddings = []
+        for i in range(0, len(seqs), batch_size):
+            batch_seqs = seqs[i : i + batch_size]
+            encoded = self.tokenizer.batch_encode_plus(
+                batch_seqs,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+            )
+            input_ids = encoded["input_ids"].to(self.device)  # (B, L)
+            attention_mask = input_ids != self.tokenizer.pad_token_id
+
+            input_ids = self._adjust_length(input_ids)
+            attention_mask = (input_ids != self.tokenizer.pad_token_id)
+
+            with torch.no_grad():
+                outs = self.model(
+                    input_ids,
+                    attention_mask=attention_mask,
+                    output_hidden_states=True,
+                    return_dict=True,
+                )
+            if hasattr(outs, "hidden_states") and outs.hidden_states is not None:
+                last_hidden = outs.hidden_states[-1]  # (B, L, D)
+            else:
+                last_hidden = outs.last_hidden_state  # fallback
+
+            # Exclude CLS token if present (assume first token) and pool
+            pooled = last_hidden[:, 1:, :].mean(dim=1)  # (B, D)
+            all_embeddings.append(pooled.cpu().numpy())
+
+            # release fragmentation
+            torch.cuda.empty_cache()
+
+        return np.vstack(all_embeddings)  # (N, D)
+
 
 class DNABertEmbedder:
     def __init__(self, device):
-        self.tokenizer = AutoTokenizer.from_pretrained(
-        self.model = AutoModel.from_pretrained(
-            "zhihan1996/DNA_bert_6", trust_remote_code=True
-        ).to(device)
-        self.device = device
+        self.tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True)
+        self.model = AutoModel.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True).to(device)
+        self.device = device
 
     def embed(self, seqs):
         embs = []
         for s in seqs:
-            tokens = self.tokenizer(s, return_tensors="pt", padding=True)[
-                "input_ids"
-            ].to(self.device)
+            tokens = self.tokenizer(s, return_tensors="pt", padding=True)["input_ids"].to(self.device)
             with torch.no_grad():
                 out = self.model(tokens).last_hidden_state.mean(1)
             embs.append(out.cpu().numpy())
         return np.vstack(embs)
 
-
 class NucleotideTransformerEmbedder:
     def __init__(self, device):
         # HF “feature-extraction” returns a list of (L, D) arrays for each input
@@ -140,9 +181,7 @@ class NucleotideTransformerEmbedder:
         self.pipe = pipeline(
             "feature-extraction",
             model="InstaDeepAI/nucleotide-transformer-500m-1000g",
-            device=
-                -1 if device == "cpu" else 0
-            ),  # HF uses -1 for CPU, 0 for GPU
+            device=-1 if device == "cpu" else 0,  # HF uses -1 for CPU, 0 for GPU
         )
 
     def embed(self, seqs):
@@ -152,35 +191,131 @@ class NucleotideTransformerEmbedder:
         """
         all_embeddings = self.pipe(seqs, truncation=True, padding=True)
         # all_embeddings is a List of shape (L, D) arrays
-        pooled = [np.mean(x, axis=0) for x in all_embeddings]
-        return np.vstack(pooled)
+        pooled = [ np.mean(x, axis=0) for x in all_embeddings ]
+        return np.vstack(pooled)
+
+# class ESMEmbedder:
+#     def __init__(self, device):
+#         self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
+#         self.batch_converter = self.alphabet.get_batch_converter()
+#         self.model.to(device).eval()
+#         self.device = device
+
+#     def embed(self, seqs):
+#         batch = [(str(i), seq) for i, seq in enumerate(seqs)]
+#         _, _, toks = self.batch_converter(batch)
+#         toks = toks.to(self.device)
+#         with torch.no_grad():
+#             results = self.model(toks, repr_layers=[33], return_contacts=False)
+#         reps = results["representations"][33]
+#         return reps[:, 1:-1].mean(1).cpu().numpy()
 
 
 class ESMEmbedder:
-    def __init__(self, device):
+    def __init__(self, device, model_name="esm2_t33_650M_UR50D"):
+        # Try to load the specified ESM-2 model; fallback to esm1b if missing
+        self.device = device
+        try:
+            self.model, self.alphabet = getattr(esm.pretrained, model_name)()
+            self.is_esm2 = model_name.lower().startswith("esm2")
+        except AttributeError:
+            # fallback to ESM-1b
+            self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
+            self.is_esm2 = False
         self.batch_converter = self.alphabet.get_batch_converter()
         self.model.to(device).eval()
+        # determine max length: esm2 models vary; use default 1024 for esm1b
+        self.max_len = 4096 if self.is_esm2 else 1024  # adjust if your esm2 variant has explicit limit
+        # for chunking: reserve 2 tokens if model uses BOS/EOS
+        self.chunk_size = self.max_len - 2
+        self.overlap = self.chunk_size // 4  # 25% overlap to smooth boundaries
 
-    def
+    def _chunk_sequence(self, seq):
+        """
+        Return list of possibly overlapping chunks of seq, each <= chunk_size.
+        """
+        if len(seq) <= self.chunk_size:
+            return [seq]
+        step = self.chunk_size - self.overlap
+        chunks = []
+        for i in range(0, len(seq), step):
+            chunk = seq[i : i + self.chunk_size]
+            if not chunk:
+                break
+            chunks.append(chunk)
+        return chunks
 
+    def embed(self, seqs):
+        """
+        seqs: List[str] of protein sequences.
+        Returns: np.ndarray of shape (N, D) pooled per-sequence embeddings.
+        """
+        all_embeddings = []
+        for i, seq in enumerate(seqs):
+            chunks = self._chunk_sequence(seq)
+            chunk_vecs = []
+            # process chunks in batch if small number, else sequentially
+            for chunk in chunks:
+                batch = [(str(i), chunk)]
+                _, _, toks = self.batch_converter(batch)
+                toks = toks.to(self.device)
+                with torch.no_grad():
+                    results = self.model(toks, repr_layers=[33], return_contacts=False)
+                reps = results["representations"][33]  # (1, L, D)
+                # remove BOS/EOS if present: take 1:-1 if length permits
+                if reps.size(1) > 2:
+                    rep = reps[:, 1:-1].mean(1)  # (1, D)
+                else:
+                    rep = reps.mean(1)  # fallback
+                chunk_vecs.append(rep.squeeze(0))  # (D,)
+            if len(chunk_vecs) == 1:
+                seq_vec = chunk_vecs[0]
+            else:
+                # average chunk vectors
+                stacked = torch.stack(chunk_vecs, dim=0)  # (num_chunks, D)
+                seq_vec = stacked.mean(0)
+            all_embeddings.append(seq_vec.cpu().numpy())
+        return np.vstack(all_embeddings)  # (N, D)
+
+
+# class ESMDBPEmbedder:
+#     def __init__(self, device):
+#         base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
+#         model_path = (
+#             Path(__file__).resolve().parent.parent
+#             / "pretrained" / "ESM-DBP" / "ESM-DBP.model"
+#         )
+#         checkpoint = torch.load(model_path, map_location="cpu")
+#         clean_sd = {}
+#         for k, v in checkpoint.items():
+#             clean_sd[k.replace("module.", "")] = v
+#         result = base_model.load_state_dict(clean_sd, strict=False)
+#         if result.missing_keys:
+#             print(f"[ESMDBP] missing keys: {result.missing_keys}")
+#         if result.unexpected_keys:
+#             print(f"[ESMDBP] unexpected keys: {result.unexpected_keys}")

+#         self.model = base_model.to(device).eval()
+#         self.alphabet = alphabet
+#         self.batch_converter = alphabet.get_batch_converter()
+#         self.device = device

+#     def embed(self, seqs):
+#         batch = [(str(i), seq) for i, seq in enumerate(seqs)]
+#         _, _, toks = self.batch_converter(batch)
+#         toks = toks.to(self.device)
+#         with torch.no_grad():
+#             out = self.model(toks, repr_layers=[33], return_contacts=False)
+#         reps = out["representations"][33]
+#         # skip start/end tokens
+#         return reps[:, 1:-1].mean(1).cpu().numpy()
 
 class ESMDBPEmbedder:
     def __init__(self, device):
         base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
         model_path = (
             Path(__file__).resolve().parent.parent
-            / "pretrained"
-            / "ESM-DBP"
-            / "ESM-DBP.model"
+            / "pretrained" / "ESM-DBP" / "ESM-DBP.model"
         )
         checkpoint = torch.load(model_path, map_location="cpu")
         clean_sd = {}
@@ -196,17 +331,46 @@ class ESMDBPEmbedder:
         self.alphabet = alphabet
         self.batch_converter = alphabet.get_batch_converter()
         self.device = device
 
     def embed(self, seqs):
 
 class GPNEmbedder:
     def __init__(self, device):
@@ -219,14 +383,16 @@ class GPNEmbedder:
 
     def embed(self, seqs):
         inputs = self.tokenizer(
-            seqs,
         ).to(self.device)
 
         with torch.no_grad():
            last_hidden = self.model(**inputs).last_hidden_state
         return last_hidden.mean(dim=1).cpu().numpy()
 
-
 class ProGenEmbedder:
     def __init__(self, device):
         model_name = "jinyuan22/ProGen2-base"
@@ -236,102 +402,137 @@ class ProGenEmbedder:
 
     def embed(self, seqs):
         inputs = self.tokenizer(
-            seqs,
         ).to(self.device)
         with torch.no_grad():
             last_hidden = self.model(**inputs).last_hidden_state
         return last_hidden.mean(dim=1).cpu().numpy()
 
-
 # ---- main pipeline ----
 
-
 def get_embedder(name, device, for_dna=True):
     name = name.lower()
     if for_dna:
-        if name
-        if name
-        if name
-            return NucleotideTransformerEmbedder(device)
-        if name == "gpn":
-            return GPNEmbedder(device)
     else:
-        if name in ("esm",):
-        if name
-            return ESMDBPEmbedder(device)
-        if name == "progen":
-            return ProGenEmbedder(device)
     raise ValueError(f"Unknown model {name} (for_dna={for_dna})")
 
 
 def embed_and_save(seqs, ids, embedder, out_path):
     embs = embedder.embed(seqs)
-
 
 
-if __name__
     p = argparse.ArgumentParser()
-    p.add_argument(
-    )
-    p.add_argument(
-        help="if set, skip the chromosome embedding step",
-    )  # if glm embeddings successful but not plm embeddings
-    p.add_argument("--tf-fasta", required=True, help="input TF FASTA file")
-    p.add_argument("--chrom-model", default="caduceus")
-    p.add_argument("--tf-model", default="esm-dbp")
-    p.add_argument(
-        "--out-dir", default="data_files/processed/tfclust/hg38_tf/embeddings"
-    )
-    p.add_argument("--device", default="cpu")
     args = p.parse_args()
 
     os.makedirs(args.out_dir, exist_ok=True)
     device = args.device
 
     if not args.skip_dna:
-        (
     else:
 
-        ####################
-        chrom_embedder = get_embedder(args.chrom_model, device, for_dna=True)
-        out_chrom = Path(args.out_dir) / f"chrom_{args.chrom_model}.npy"
-        embed_and_save(chrom_seqs, chrom_ids, chrom_embedder, out_chrom)
 
-        #
     tf_seqs, tf_ids = [], []
     for record in SeqIO.parse(args.tf_fasta, "fasta"):
         tf_ids.append(record.id)
@@ -342,4 +543,4 @@ if __name__ == "__main__":
     out_tf = Path(args.out_dir) / f"tf_{args.tf_model}.npy"
     embed_and_save(tf_seqs, tf_ids, tf_embedder, out_tf)
 
-    print("Done.")
| 320 |
checkpoint = torch.load(model_path, map_location="cpu")
|
| 321 |
clean_sd = {}
|
|
|
|
| 331 |
self.alphabet = alphabet
|
| 332 |
self.batch_converter = alphabet.get_batch_converter()
|
| 333 |
self.device = device
|
| 334 |
+
self.max_len = 1024 # same limit as esm1b
|
| 335 |
+
self.chunk_size = self.max_len - 2
|
| 336 |
+
self.overlap = self.chunk_size // 4
|
| 337 |
+
|
| 338 |
+
def _chunk_sequence(self, seq):
|
| 339 |
+
if len(seq) <= self.chunk_size:
|
| 340 |
+
return [seq]
|
| 341 |
+
step = self.chunk_size - self.overlap
|
| 342 |
+
chunks = []
|
| 343 |
+
for i in range(0, len(seq), step):
|
| 344 |
+
chunk = seq[i : i + self.chunk_size]
|
| 345 |
+
if not chunk:
|
| 346 |
+
break
|
| 347 |
+
chunks.append(chunk)
|
| 348 |
+
return chunks
|
| 349 |
|
| 350 |
def embed(self, seqs):
|
| 351 |
+
all_embeddings = []
|
| 352 |
+
for i, seq in enumerate(seqs):
|
| 353 |
+
chunks = self._chunk_sequence(seq)
|
| 354 |
+
chunk_vecs = []
|
| 355 |
+
for chunk in chunks:
|
| 356 |
+
batch = [(str(i), chunk)]
|
| 357 |
+
_, _, toks = self.batch_converter(batch)
|
| 358 |
+
toks = toks.to(self.device)
|
| 359 |
+
with torch.no_grad():
|
| 360 |
+
out = self.model(toks, repr_layers=[33], return_contacts=False)
|
| 361 |
+
reps = out["representations"][33]
|
| 362 |
+
if reps.size(1) > 2:
|
| 363 |
+
rep = reps[:, 1:-1].mean(1)
|
| 364 |
+
else:
|
| 365 |
+
rep = reps.mean(1)
|
| 366 |
+
chunk_vecs.append(rep.squeeze(0))
|
| 367 |
+
if len(chunk_vecs) == 1:
|
| 368 |
+
seq_vec = chunk_vecs[0]
|
| 369 |
+
else:
|
| 370 |
+
stacked = torch.stack(chunk_vecs, dim=0)
|
| 371 |
+
seq_vec = stacked.mean(0)
|
| 372 |
+
all_embeddings.append(seq_vec.cpu().numpy())
|
| 373 |
+
return np.vstack(all_embeddings)
|
| 374 |
|
| 375 |
class GPNEmbedder:
|
| 376 |
def __init__(self, device):
|
|
|
|
| 383 |
|
| 384 |
def embed(self, seqs):
|
| 385 |
inputs = self.tokenizer(
|
| 386 |
+
seqs,
|
| 387 |
+
return_tensors="pt",
|
| 388 |
+
padding=True,
|
| 389 |
+
truncation=True
|
| 390 |
).to(self.device)
|
| 391 |
|
| 392 |
with torch.no_grad():
|
| 393 |
last_hidden = self.model(**inputs).last_hidden_state
|
| 394 |
return last_hidden.mean(dim=1).cpu().numpy()
|
| 395 |
|
|
|
|
| 396 |
class ProGenEmbedder:
|
| 397 |
def __init__(self, device):
|
| 398 |
model_name = "jinyuan22/ProGen2-base"
|
|
|
|
| 402 |
|
| 403 |
def embed(self, seqs):
|
| 404 |
inputs = self.tokenizer(
|
| 405 |
+
seqs,
|
| 406 |
+
return_tensors="pt",
|
| 407 |
+
padding=True,
|
| 408 |
+
truncation=True
|
| 409 |
).to(self.device)
|
| 410 |
with torch.no_grad():
|
| 411 |
last_hidden = self.model(**inputs).last_hidden_state
|
| 412 |
return last_hidden.mean(dim=1).cpu().numpy()
|
| 413 |
|
|
|
|
| 414 |
# ---- main pipeline ----
|
| 415 |
|
|
|
|
| 416 |
def get_embedder(name, device, for_dna=True):
|
| 417 |
name = name.lower()
|
| 418 |
if for_dna:
|
| 419 |
+
if name=="caduceus": return CaduceusEmbedder(device)
|
| 420 |
+
if name=="dnabert": return DNABertEmbedder(device)
|
| 421 |
+
if name=="nucleotide": return NucleotideTransformerEmbedder(device)
|
| 422 |
+
if name=="gpn": return GPNEmbedder(device)
|
| 423 |
+
if name=="segmentnt": return SegmentNTEmbedder(device)
|
|
|
|
|
|
|
|
|
|
| 424 |
else:
|
| 425 |
+
if name in ("esm",): return ESMEmbedder(device)
|
| 426 |
+
if name in ("esm-dbp","esm_dbp"): return ESMDBPEmbedder(device)
|
| 427 |
+
if name=="progen": return ProGenEmbedder(device)
|
|
|
|
|
|
|
|
|
|
| 428 |
raise ValueError(f"Unknown model {name} (for_dna={for_dna})")


def pad_token_embeddings(list_of_arrays, pad_value=0.0):
    """
    list_of_arrays: list of (L_i, D) numpy arrays
    Returns:
        padded: (N, L_max, D) array
        mask: (N, L_max) boolean array where True = real token, False = padding
    """
    N = len(list_of_arrays)
    D = list_of_arrays[0].shape[1]
    L_max = max(arr.shape[0] for arr in list_of_arrays)
    padded = np.full((N, L_max, D), pad_value, dtype=list_of_arrays[0].dtype)
    mask = np.zeros((N, L_max), dtype=bool)
    for i, arr in enumerate(list_of_arrays):
        L = arr.shape[0]
        padded[i, :L] = arr
        mask[i, :L] = True
    return padded, mask
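# Padding contract sketch (illustrative only): two ragged token matrices of width D=4
# become one (N, L_max, D) array plus a boolean mask that marks the real tokens.
#
# a = np.ones((3, 4), dtype=np.float32)
# b = np.ones((5, 4), dtype=np.float32)
# padded, mask = pad_token_embeddings([a, b])
# padded.shape == (2, 5, 4); mask[0].sum() == 3; mask[1].sum() == 5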

def embed_and_save(seqs, ids, embedder, out_path):
    embs = embedder.embed(seqs)

    # Decide whether we got variable-length per-token outputs (list of (L, D))
    is_variable_token = isinstance(embs, (list, tuple)) and len(embs) > 0 and hasattr(embs[0], "shape") and embs[0].ndim == 2

    if is_variable_token:
        # pad to (N, L_max, D) + mask
        padded, mask = pad_token_embeddings(embs)
        # Save both embeddings and mask together in an .npz for convenience
        np.savez_compressed(out_path.with_suffix(".caduceus.npz"),
                            embeddings=padded,
                            mask=mask,
                            ids=np.array(ids, dtype=object))
    else:
        # fixed-shape output, e.g. pooled (N, D)
        array = np.vstack(embs) if isinstance(embs, list) else embs
        np.save(out_path, array)
        with open(out_path.with_suffix(".ids"), "w") as f:
            f.write("\n".join(ids))


if __name__ == "__main__":

    p = argparse.ArgumentParser()
    p.add_argument("--peak-fasta", default="binding_peaks_unique.fa", help="FASTA of deduplicated binding peak sequences; if present this is used for DNA embedding instead of genome JSONs")
    p.add_argument("--genome-json-dir", default=None, help="(fallback) directory of UCSC JSONs for full chromosome embedding if peak FASTA is missing or you explicitly want chromosomes")
    p.add_argument("--skip-dna", action="store_true", help="if set, skip the chromosome embedding step")  # if glm embeddings successful but not plm embeddings
    p.add_argument("--tf-fasta", required=True, help="input TF FASTA file")
    p.add_argument("--chrom-model", default="caduceus")
    p.add_argument("--tf-model", default="esm-dbp")
    p.add_argument("--out-dir", default="data_files/processed/tfclust/hg38_tf/embeddings")
    p.add_argument("--device", default="cpu")
    args = p.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)
    device = args.device

    if not args.skip_dna:
        peak_fasta = Path(args.peak_fasta)
        if peak_fasta.exists():
            # Load peak sequences from FASTA
            from Bio import SeqIO

            peak_seqs = []
            peak_ids = []
            for rec in SeqIO.parse(peak_fasta, "fasta"):
                peak_ids.append(rec.id)
                peak_seqs.append(str(rec.seq))
            print(f"Embedding {len(peak_seqs)} binding peak sequences from {peak_fasta}", flush=True)
            dna_embedder = get_embedder(args.chrom_model, device, for_dna=True)
            out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy"
            embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks)
        elif args.genome_json_dir:
            # Legacy: load full chromosomes from JSONs (chr1–22, X, Y, M)
            genome_dir = Path(args.genome_json_dir)
            chrom_seqs, chrom_ids = [], []
            primary_pattern = re.compile(r"^hg38_chr(?:[1-9]|1[0-9]|2[0-2]|X|Y|M)\.json$")
            for j in sorted(genome_dir.iterdir()):
                if not primary_pattern.match(j.name):
                    continue
                data = json.loads(j.read_text())
                seq = data.get("dna") or data.get("sequence")
                chrom = data.get("chrom") or j.stem.split("_")[-1]
                chrom_seqs.append(seq)
                chrom_ids.append(chrom)
            cutoff = CaduceusEmbedder(device).chunk_size
            long_chroms = [
                (chrom, len(seq))
                for chrom, seq in zip(chrom_ids, chrom_seqs)
                if len(seq) > cutoff
            ]
            if long_chroms:
                print("⚠️ Chromosomes exceeding Caduceus max tokens ({}):".format(cutoff))
                for chrom, L in long_chroms:
                    print(f"  {chrom}: {L} bases")
            else:
                print("All chromosomes ≤ Caduceus limit ({}).".format(cutoff))

            chrom_embedder = get_embedder(args.chrom_model, device, for_dna=True)
            out_chrom = Path(args.out_dir) / f"chrom_{args.chrom_model}.npy"
            embed_and_save(chrom_seqs, chrom_ids, chrom_embedder, out_chrom)
        else:
            raise ValueError("No input for DNA embedding: provide a peak FASTA (default binding_peaks_unique.fa) or set --genome-json-dir for chromosome JSONs.")

    # Load TF sequences
    tf_seqs, tf_ids = [], []
    for record in SeqIO.parse(args.tf_fasta, "fasta"):
        tf_ids.append(record.id)
        # … (lines 539-542 are unchanged context not shown in this diff view)

    out_tf = Path(args.out_dir) / f"tf_{args.tf_model}.npy"
    embed_and_save(tf_seqs, tf_ids, tf_embedder, out_tf)

    print("Done.")
|
dpacman/classifier/model_tmp/extract_tf_symbols.py
ADDED
|
@@ -0,0 +1,27 @@
|
#!/usr/bin/env python3
import pandas as pd
from pathlib import Path

FINAL_CSV = Path("/home/a03-akrishna/DPACMAN/data_files/processed/final.csv")
OUT_SYMBOLS = Path("tf_symbols.txt")

def normalize_tf(tf_id: str) -> str:
    return tf_id.split("_seq")[0].upper()

def main():
    df = pd.read_csv(FINAL_CSV, dtype=str)
    if "TF_id" not in df.columns:
        raise RuntimeError("final.csv missing TF_id column")
    tf_raw = df["TF_id"].dropna().unique().tolist()
    normalized = sorted({normalize_tf(t) for t in tf_raw})
    print(f"Unique raw TF_id count: {len(tf_raw)}")
    print(f"Unique normalized TF symbols: {len(normalized)}")
    with open(OUT_SYMBOLS, "w") as f:
        for s in normalized:
            f.write(s + "\n")
    print(f"Wrote normalized TF symbols to {OUT_SYMBOLS}")
    # Optional: show sample
    print("Sample symbols:", normalized[:50])

if __name__ == "__main__":
    main()
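A quick check of the normalization rule above (hypothetical inputs, assuming the module is importable as shown): the helper strips any `_seq…` suffix and upper-cases the rest.

from extract_tf_symbols import normalize_tf

assert normalize_tf("Zbtb5_seq1") == "ZBTB5"
assert normalize_tf("CTCF") == "CTCF"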
|
dpacman/classifier/model_tmp/make_pair_list.py
ADDED
|
@@ -0,0 +1,220 @@
|
#!/usr/bin/env python3
import argparse
import numpy as np
import pandas as pd
from pathlib import Path
import random
import sys

def read_ids_file(p):
    p = Path(p)
    if not p.exists():
        raise FileNotFoundError(f"IDs file not found: {p}")
    return [line.strip() for line in p.open() if line.strip()]

def split_embeddings(emb_path, ids_path, out_dir, prefix):
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    if not Path(emb_path).exists():
        raise FileNotFoundError(f"Embedding file not found: {emb_path}")
    if not Path(ids_path).exists():
        raise FileNotFoundError(f"IDs file not found: {ids_path}")

    if emb_path.endswith(".npz"):
        data = np.load(emb_path, allow_pickle=True)
        if "embeddings" in data:
            emb = data["embeddings"]
        else:
            raise ValueError(f"{emb_path} missing 'embeddings' key")
    else:
        emb = np.load(emb_path)

    ids = read_ids_file(ids_path)
    if len(ids) != emb.shape[0]:
        print(f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}", file=sys.stderr)

    mapping = {}
    for i, ident in enumerate(ids):
        if i >= emb.shape[0]:
            print(f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr)
            continue
        arr = emb[i]
        out_file = out_dir / f"{prefix}_{ident}.npy"
        np.save(out_file, arr)
        mapping[ident] = str(out_file)
    return mapping

def extract_symbol_from_tf_id(full_id: str) -> str:
    """
    Given a TF embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN',
    return the gene symbol uppercase (e.g., 'ZBTB5').
    """
    if "|" in full_id:
        try:
            # format sp|Accession|SYMBOL_HUMAN
            genepart = full_id.split("|")[2]
        except IndexError:
            genepart = full_id
    else:
        genepart = full_id
    symbol = genepart.split("_")[0]
    return symbol.upper()

def build_tf_symbol_map(tf_map):
    """
    Build mapping gene_symbol -> list of embedding paths.
    """
    symbol_map = {}
    for full_id, path in tf_map.items():
        symbol = extract_symbol_from_tf_id(full_id)
        symbol_map.setdefault(symbol, []).append(path)
    return symbol_map

def tf_key_from_path(path: str) -> str:
    """
    Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract the normalized symbol 'ZBTB5'.
    """
    stem = Path(path).stem  # e.g., tf_sp|O15062|ZBTB5_HUMAN
    # remove the leading prefix if present (tf_)
    if "_" in stem:
        _, rest = stem.split("_", 1)
    else:
        rest = stem
    return extract_symbol_from_tf_id(rest)

def dna_key_from_path(path: str) -> str:
    """
    Given .../dna_peak42.npy -> 'peak42'
    """
    stem = Path(path).stem
    if "_" in stem:
        _, rest = stem.split("_", 1)
    else:
        rest = stem
    return rest

def main():
    parser = argparse.ArgumentParser(
        description="Build TF-DNA pair list from final.csv with gene-symbol normalization for TFs."
    )
    parser.add_argument("--final_csv", required=True, help="final.csv with TF_id and dna_sequence")
    parser.add_argument("--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)")
    parser.add_argument("--dna_ids", required=True, help="IDs file for DNA embeddings (e.g., peak*.ids)")
    parser.add_argument("--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)")
    parser.add_argument("--tf_ids", required=True, help="IDs file for TF embeddings (e.g., sp|...|... ids)")
    parser.add_argument("--out_dir", required=True, help="Output directory")
    parser.add_argument("--neg_per_positive", type=int, default=2, help="Negatives per positive (half same-TF, half same-DNA)")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    random.seed(args.seed)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Load final.csv
    df = pd.read_csv(args.final_csv, dtype=str)
    if "TF_id" not in df.columns or "dna_sequence" not in df.columns:
        raise RuntimeError("final.csv must have columns TF_id and dna_sequence")

    # Assign dna_id (unique per dna_sequence)
    unique_seqs = df["dna_sequence"].drop_duplicates().tolist()
    seq_to_id = {seq: f"peak{i}" for i, seq in enumerate(unique_seqs)}
    df["dna_id"] = df["dna_sequence"].map(seq_to_id)
    enriched_csv = out_dir / "final_with_dna_id.csv"
    df.to_csv(enriched_csv, index=False)
    print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}")

    # Split embeddings into per-item files
    print(f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}")
    dna_map = split_embeddings(args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna")
    print(f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})")
    print(f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}")
    tf_map = split_embeddings(args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf")
    print(f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})")

    # Build gene-symbol normalized map
    tf_symbol_map = build_tf_symbol_map(tf_map)
    print(f"[i] TF symbol map keys (sample): {list(tf_symbol_map.keys())[:30]}")

    # Diagnostic overlaps
    norm_tf_in_final = set(t.split("_seq")[0].upper() for t in df["TF_id"].unique())
    available_tf_symbols = set(tf_symbol_map.keys())
    intersect_tf = norm_tf_in_final & available_tf_symbols
    print(f"[i] Unique normalized TF symbols in final.csv: {len(norm_tf_in_final)}")
    print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}")
    print(f"[i] Intersection count: {len(intersect_tf)}")
    if len(intersect_tf) == 0:
        print("[ERROR] No overlap between normalized TF_id and TF embedding symbols.", file=sys.stderr)
        print("Sample normalized TFs from final.csv:", sorted(list(norm_tf_in_final))[:30], file=sys.stderr)
        print("Sample available TF symbols:", sorted(list(available_tf_symbols))[:30], file=sys.stderr)
        sys.exit(1)

    dna_ids_final = set(df["dna_id"].unique())
    available_dna_ids = set(dna_map.keys())
    intersect_dna = dna_ids_final & available_dna_ids
    print(f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. Intersection: {len(intersect_dna)}")
    if len(intersect_dna) == 0:
        print("[ERROR] No overlap on DNA ids.", file=sys.stderr)
        sys.exit(1)

    # Build positive pairs
    positives = []
    for _, row in df.iterrows():
        tf_raw = row["TF_id"]
        tf_symbol = tf_raw.split("_seq")[0].upper()
        dnaid = row["dna_id"]
        if tf_symbol not in tf_symbol_map:
            continue
        if dnaid not in dna_map:
            continue
        # pick the first embedding for that symbol
        tf_embedding_path = tf_symbol_map[tf_symbol][0]
        positives.append((tf_embedding_path, dna_map[dnaid], 1))
    print(f"[i] Constructed {len(positives)} positive pairs after TF symbol resolution")

    if len(positives) == 0:
        print("[ERROR] No positive pairs could be constructed; aborting.", file=sys.stderr)
        sys.exit(1)

    # Build negative samples
    all_tf_symbols = sorted(tf_symbol_map.keys())
    all_dnaids = sorted(dna_map.keys())
    positive_set = set()
    for tf_path, dna_path, _ in positives:
        tf_key = tf_key_from_path(tf_path)
        dna_key = dna_key_from_path(dna_path)
        positive_set.add((tf_key, dna_key))

    negatives = []
    half = args.neg_per_positive // 2
    for tf_path, dna_path, _ in positives:
        tf_key = tf_key_from_path(tf_path)
        dna_key = dna_key_from_path(dna_path)
        # same TF, different DNA
        for _ in range(half):
            candidate_dna = random.choice(all_dnaids)
            if candidate_dna == dna_key or (tf_key, candidate_dna) in positive_set:
                continue
            negatives.append((tf_path, dna_map[candidate_dna], 0))
        # same DNA, different TF
        for _ in range(half):
            candidate_tf_symbol = random.choice(all_tf_symbols)
            if candidate_tf_symbol == tf_key or (candidate_tf_symbol, dna_key) in positive_set:
                continue
            # pick its first embedding and pair it with the current DNA path
            # (previously this used dna_map[dnaid], a stale variable from the positives loop)
            candidate_tf_path = tf_symbol_map[candidate_tf_symbol][0]
            negatives.append((candidate_tf_path, dna_path, 0))

    print(f"[i] Sampled {len(negatives)} negatives (neg_per_positive={args.neg_per_positive})")

    # Write pair list
    pair_list_path = out_dir / "pair_list.tsv"
    with open(pair_list_path, "w") as f:
        for binder_path, glm_path, label in positives + negatives:
            # binder=TF, glm=DNA
            f.write(f"{binder_path}\t{glm_path}\t{label}\n")
    print(f"[i] Wrote {len(positives)+len(negatives)} examples to {pair_list_path}")

if __name__ == "__main__":
    main()
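For reference, how the ID-normalization helpers above behave on typical inputs (hypothetical examples, assuming the module is importable):

from make_pair_list import extract_symbol_from_tf_id, tf_key_from_path, dna_key_from_path

assert extract_symbol_from_tf_id("sp|O15062|ZBTB5_HUMAN") == "ZBTB5"
assert extract_symbol_from_tf_id("ZBTB5_HUMAN") == "ZBTB5"
assert tf_key_from_path("out/tf_single/tf_sp|O15062|ZBTB5_HUMAN.npy") == "ZBTB5"
assert dna_key_from_path("out/dna_single/dna_peak42.npy") == "peak42"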
|
dpacman/classifier/model_tmp/make_peak_fasta.py
ADDED
|
@@ -0,0 +1,13 @@
|
import pandas as pd
from pathlib import Path

df = pd.read_csv("/home/a03-akrishna/DPACMAN/data_files/processed/final.csv", dtype=str)  # adjust path if needed
# get unique sequences
uniq = df[["dna_sequence"]].drop_duplicates().reset_index(drop=True)
# make headers: e.g., peak0, peak1, ...
out_fa = Path("binding_peaks_unique.fa")
with open(out_fa, "w") as f:
    for i, seq in enumerate(uniq["dna_sequence"]):
        header = f">peak{i}"
        f.write(f"{header}\n{seq}\n")
print(f"Wrote {len(uniq)} unique binding sequences to {out_fa}")
|
dpacman/classifier/model_tmp/model.py
ADDED
|
@@ -0,0 +1,103 @@
|
import torch
from torch import nn

class LocalCNN(nn.Module):
    def __init__(self, dim: int = 256, kernel_size: int = 3):
        super().__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding)
        self.act = nn.GELU()
        self.ln = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor):
        # x: (batch, L, dim)
        out = self.conv(x.transpose(1, 2))  # → (batch, dim, L)
        out = self.act(out)
        out = out.transpose(1, 2)           # → (batch, L, dim)
        return self.ln(out + x)             # residual

class CrossModalBlock(nn.Module):
    def __init__(self, dim: int = 256, heads: int = 8):
        super().__init__()
        # self-attention for both sides
        self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ln_b1 = nn.LayerNorm(dim)
        self.ln_g1 = nn.LayerNorm(dim)

        self.ffn_b = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
        self.ffn_g = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
        self.ln_b2 = nn.LayerNorm(dim)
        self.ln_g2 = nn.LayerNorm(dim)

        # cross attention (binder queries, glm keys/values)
        self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ln_c1 = nn.LayerNorm(dim)
        self.ffn_c = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
        self.ln_c2 = nn.LayerNorm(dim)

    def forward(self, binder: torch.Tensor, glm: torch.Tensor):
        """
        binder: (batch, Lb, dim)
        glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
        returns: updated binder representation (batch, Lb, dim)
        """
        # binder self-attn + ffn
        b = binder
        b_sa, _ = self.sa_binder(b, b, b)
        b = self.ln_b1(b + b_sa)
        b_ff = self.ffn_b(b)
        b = self.ln_b2(b + b_ff)

        # glm self-attn + ffn
        g = glm
        g_sa, _ = self.sa_glm(g, g, g)
        g = self.ln_g1(g + g_sa)
        g_ff = self.ffn_g(g)
        g = self.ln_g2(g + g_ff)

        # cross-attention: binder queries glm
        c_sa, _ = self.cross_attn(b, g, g)
        c = self.ln_c1(b + c_sa)
        c_ff = self.ffn_c(c)
        c = self.ln_c2(c + c_ff)
        return c  # (batch, Lb, dim)

class BindPredictor(nn.Module):
    def __init__(self,
                 input_dim: int = 256,
                 hidden_dim: int = 256,
                 heads: int = 8,
                 num_layers: int = 4,
                 use_local_cnn_on_glm: bool = True):
        super().__init__()
        self.proj_binder = nn.Linear(input_dim, hidden_dim)
        self.proj_glm = nn.Linear(input_dim, hidden_dim)
        self.use_local_cnn = use_local_cnn_on_glm
        self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity()

        self.layers = nn.ModuleList([
            CrossModalBlock(hidden_dim, heads) for _ in range(num_layers)
        ])

        self.ln_out = nn.LayerNorm(hidden_dim)
        self.head = nn.Sequential(
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, binder_emb, glm_emb):
        """
        binder_emb, glm_emb: (batch, L, input_dim)
        """
        b = self.proj_binder(binder_emb)   # (B, Lb, hidden_dim)
        g = self.proj_glm(glm_emb)         # (B, Lg, hidden_dim)
        if self.use_local_cnn:
            g = self.local_cnn(g)          # local context injected

        for layer in self.layers:
            b = layer(b, g)                # update binder with cross-modal info

        pooled = b.mean(dim=1)             # (B, hidden_dim)
        out = self.ln_out(pooled)
        return self.head(out).squeeze(-1)  # (B,)
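A minimal shape check for BindPredictor (random tensors, illustrative only; the batch size and sequence lengths are arbitrary):

import torch
from model import BindPredictor

model = BindPredictor(input_dim=256, hidden_dim=256, heads=8, num_layers=2)
binder = torch.randn(4, 12, 256)   # (B, Lb, input_dim) TF token embeddings
glm = torch.randn(4, 20, 256)      # (B, Lg, input_dim) DNA token embeddings
scores = model(binder, glm)
assert scores.shape == (4,)        # one sigmoid score per TF-DNA pair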
|
dpacman/classifier/model_tmp/prep_splits.py
ADDED
|
@@ -0,0 +1,133 @@
|
import numpy as np
import pandas as pd
import sys
import json
from sklearn.decomposition import TruncatedSVD
from sklearn.model_selection import train_test_split
from collections import Counter

def parse_pair_list(pair_list_path):
    binder_paths, glm_paths, labels = [], [], []
    with open(pair_list_path) as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            parts = line.strip().split()
            if len(parts) != 3:
                print(f"[WARN] skipping malformed line {lineno}: {line.strip()}", file=sys.stderr)
                continue
            b, g, l = parts
            try:
                lab = int(l)
            except ValueError:
                print(f"[WARN] invalid label on line {lineno}: {l}", file=sys.stderr)
                continue
            binder_paths.append(b)
            glm_paths.append(g)
            labels.append(lab)
    return binder_paths, glm_paths, labels

def build_tf_compressed_cache(binder_paths, target_dim=256):
    """
    Load all unique TF (binder) embeddings, fit a reduction if needed, and return a dict
    mapping path -> (L, target_dim) array.
    """
    unique_paths = sorted(set(binder_paths))
    print(f"[i] Found {len(unique_paths)} unique TF embedding files to compress.", flush=True)
    # Load all embeddings to determine dimensionality
    samples = []
    for p in unique_paths:
        arr = np.load(p)
        samples.append(arr)
    # Determine if reduction is needed: assume all have the same embedding width
    first = samples[0]
    orig_dim = first.shape[1] if first.ndim == 2 else 1
    reduction_needed = (orig_dim != target_dim)
    tf_cache = {}

    if reduction_needed:
        # Variable-length token matrices cannot be stacked directly, so fit a single
        # TruncatedSVD on the mean-pooled vectors (a PCA-like linear projection) and
        # then apply the same projection to every token vector.
        pooled = []
        for arr in samples:
            if arr.ndim == 2:
                pooled.append(arr.mean(axis=0))  # (orig_dim,)
            else:
                pooled.append(arr)  # degenerate
        pooled_mat = np.stack(pooled, axis=0)  # (N, orig_dim)
        print(f"[i] Fitting TruncatedSVD on TF pooled embeddings: {pooled_mat.shape} -> {target_dim}", flush=True)
        svd = TruncatedSVD(n_components=target_dim, random_state=42)
        reduced_pooled = svd.fit_transform(pooled_mat)  # (N, target_dim); used only to fit the SVD

        # Project token-level vectors with svd.components_.T
        # svd.components_: (target_dim, orig_dim), so the projection matrix is (orig_dim, target_dim)
        proj_mat = svd.components_.T  # (orig_dim, target_dim)
        for i, p in enumerate(unique_paths):
            arr = samples[i]  # shape (L, orig_dim)
            if arr.ndim == 1:
                arr2 = arr @ proj_mat  # (target_dim,)
            else:
                # project each token: (L, orig_dim) @ (orig_dim, target_dim) -> (L, target_dim)
                arr2 = arr @ proj_mat
            tf_cache[p] = arr2  # reduced per-token representation
        print("[i] Completed compression of TF embeddings.", flush=True)
    else:
        # already the correct dim: just cache the originals
        print(f"[i] TF embeddings already {target_dim}-dimensional; skipping reduction.", flush=True)
        for i, p in enumerate(unique_paths):
            arr = samples[i]
            tf_cache[p] = arr
    return tf_cache

def main():
    # df = pd.read_csv("../data_files/processed/fimo/ananya_aug4_2025_final.csv")

    binder_paths, glm_paths, labels = parse_pair_list("../data_files/processed/fimo/ananya_aug4_2025_pair_list.tsv")

    if len(labels) == 0:
        print("[ERROR] No valid pairs parsed. Exiting.", file=sys.stderr)
        sys.exit(1)

    label_counts = Counter(labels)
    print(f"[i] Total examples parsed: {len(labels)}. Label distribution: {label_counts}", flush=True)

    # build compressed TF cache (reduces to 256 if needed)
    # tf_compressed_cache = build_tf_compressed_cache(binder_paths, target_dim=256)

    # Combine all data into one structure for easy splitting
    data = list(zip(binder_paths, glm_paths, labels))

    # First split: train vs temp (val+test)
    train_data, temp_data = train_test_split(data, test_size=0.2, random_state=42)

    # Second split: val vs test (50% of 20% → 10% each)
    val_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=42)

    # Unpack for dataset construction
    def unpack(data):
        binders, glms, labels = zip(*data)
        return list(binders), list(glms), list(labels)

    def save_split(binder_paths, glm_paths, labels, out_path):
        df = pd.DataFrame({
            "binder_path": binder_paths,
            "glm_path": glm_paths,
            "label": labels,
        })
        df.to_csv(out_path, index=False)

    # Unpack data for saving
    train_binders, train_glms, train_labels = unpack(train_data)
    val_binders, val_glms, val_labels = unpack(val_data)
    test_binders, test_glms, test_labels = unpack(test_data)

    # Save each split
    save_split(train_binders, train_glms, train_labels, "../data_files/splits/train.csv")
    save_split(val_binders, val_glms, val_labels, "../data_files/splits/val.csv")
    save_split(test_binders, test_glms, test_labels, "../data_files/splits/test.csv")

if __name__ == "__main__":
    main()
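The token-level projection in build_tf_compressed_cache is equivalent to TruncatedSVD.transform, since sklearn's transform is a plain matrix product with components_.T; a small check with synthetic shapes (illustrative only):

import numpy as np
from sklearn.decomposition import TruncatedSVD

rng = np.random.default_rng(0)
pooled_mat = rng.normal(size=(300, 512))   # stand-in for mean-pooled TF vectors
tokens = rng.normal(size=(17, 512))        # one variable-length token embedding

svd = TruncatedSVD(n_components=256, random_state=42).fit(pooled_mat)
proj = tokens @ svd.components_.T          # what the script does per token matrix
assert np.allclose(proj, svd.transform(tokens))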
|
dpacman/classifier/model_tmp/train.py
ADDED
|
@@ -0,0 +1,180 @@
|
#!/usr/bin/env python3
import argparse
import numpy as np
import torch
from torch import nn
from model import BindPredictor
from pathlib import Path
from collections import Counter
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.decomposition import TruncatedSVD
import sys

from dpacman.utils.models import set_seed

def build_tf_compressed_cache(binder_paths, target_dim=256):
    """
    Load all unique TF (binder) embeddings, fit a reduction if needed, and return a dict
    mapping path -> (L, target_dim) array.
    """
    unique_paths = sorted(set(binder_paths))
    print(f"[i] Found {len(unique_paths)} unique TF embedding files to compress.", flush=True)
    # Load all embeddings to determine dimensionality
    samples = []
    for p in unique_paths:
        arr = np.load(p)
        samples.append(arr)
    # Determine if reduction is needed: assume all have the same embedding width
    first = samples[0]
    orig_dim = first.shape[1] if first.ndim == 2 else 1
    reduction_needed = (orig_dim != target_dim)
    tf_cache = {}

    if reduction_needed:
        # Fit TruncatedSVD on mean-pooled vectors (a PCA-like linear projection), then
        # apply the same projection to every token vector; this avoids stacking
        # variable-length token matrices directly.
        pooled = []
        for arr in samples:
            if arr.ndim == 2:
                pooled.append(arr.mean(axis=0))  # (orig_dim,)
            else:
                pooled.append(arr)  # degenerate
        pooled_mat = np.stack(pooled, axis=0)  # (N, orig_dim)
        print(f"[i] Fitting TruncatedSVD on TF pooled embeddings: {pooled_mat.shape} -> {target_dim}", flush=True)
        svd = TruncatedSVD(n_components=target_dim, random_state=42)
        reduced_pooled = svd.fit_transform(pooled_mat)  # (N, target_dim); used only to fit the SVD

        # svd.components_: (target_dim, orig_dim), so the projection matrix is (orig_dim, target_dim)
        proj_mat = svd.components_.T  # (orig_dim, target_dim)
        for i, p in enumerate(unique_paths):
            arr = samples[i]  # shape (L, orig_dim)
            if arr.ndim == 1:
                arr2 = arr @ proj_mat  # (target_dim,)
            else:
                # project each token: (L, orig_dim) @ (orig_dim, target_dim) -> (L, target_dim)
                arr2 = arr @ proj_mat
            tf_cache[p] = arr2  # reduced per-token representation
        print("[i] Completed compression of TF embeddings.", flush=True)
    else:
        # already the correct dim: just cache the originals
        print(f"[i] TF embeddings already {target_dim}-dimensional; skipping reduction.", flush=True)
        for i, p in enumerate(unique_paths):
            arr = samples[i]
            tf_cache[p] = arr
    return tf_cache

def evaluate(model, dl, device):
    model.eval()
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for b, g, y in dl:
            b = b.to(device)
            g = g.to(device)
            y = y.to(device)
            pred = model(b, g)
            all_labels.append(y.cpu())
            all_preds.append(pred.cpu())
    if not all_labels:
        return 0.0, 0.0
    y_true = torch.cat(all_labels).numpy()
    y_score = torch.cat(all_preds).numpy()
    try:
        auc = roc_auc_score(y_true, y_score)
    except Exception:
        auc = 0.0
    try:
        ap = average_precision_score(y_true, y_score)
    except Exception:
        ap = 0.0
    return auc, ap

def unpack(data):
    binders, glms, labels = zip(*data)
    return list(binders), list(glms), list(labels)

# ---- main ------------------------------------------------------------
def main(cfg):
    # Set seed for reproducibility
    set_seed(cfg.seed)

    # NOTE: work in progress — `parser` is never constructed here and the script mixes a
    # config object (`cfg`) with argparse `args`; parse_pair_list, PairDataset, subset,
    # collate_fn and DataLoader are expected to come from prep_splits.py / pair.py.
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    print("DEBUG: starting training script with in-line TF compression", flush=True)
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    binder_paths, glm_paths, labels = parse_pair_list(cfg.pair_list)

    if len(labels) == 0:
        print("[ERROR] No valid pairs parsed. Exiting.", file=sys.stderr)
        sys.exit(1)

    label_counts = Counter(labels)
    print(f"[i] Total examples parsed: {len(labels)}. Label distribution: {label_counts}", flush=True)

    # build compressed TF cache (reduces to 256 if needed)
    tf_compressed_cache = build_tf_compressed_cache(binder_paths, target_dim=256)

    # load training data (split handling still to be wired up)
    train_ds = PairDataset(None, tf_compressed_cache=tf_compressed_cache)
    val_ds = PairDataset(*subset(val_i), tf_compressed_cache=tf_compressed_cache)
    test_ds = PairDataset(*subset(test_i), tf_compressed_cache=tf_compressed_cache)

    print(f"[i] Train/Val/Test sizes: {len(train_ds)}/{len(val_ds)}/{len(test_ds)}", flush=True)
    if len(train_ds) == 0 or len(val_ds) == 0:
        print("[ERROR] Train or validation split is empty; cannot proceed.", file=sys.stderr)
        sys.exit(1)

    train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
    val_dl = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
    test_dl = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)

    model = BindPredictor(input_dim=256, hidden_dim=256, heads=8, num_layers=3, use_local_cnn_on_glm=True)
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-3)
    loss_fn = nn.BCELoss()

    best_val = -float("inf")
    os_out = Path(args.out_dir)
    os_out.mkdir(exist_ok=True, parents=True)

    for epoch in range(1, args.epochs + 1):
        print(f"[Epoch {epoch}] starting...", flush=True)
        model.train()
        running_loss = 0.0
        for b, g, y in train_dl:
            b = b.to(device)
            g = g.to(device)
            y = y.to(device)
            pred = model(b, g)
            loss = loss_fn(pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * b.size(0)
        train_loss = running_loss / len(train_ds)
        val_auc, val_ap = evaluate(model, val_dl, device)
        print(f"[Epoch {epoch}] train_loss={train_loss:.4f} val_auc={val_auc:.4f} val_ap={val_ap:.4f}", flush=True)

        if val_auc > best_val:
            best_val = val_auc
            torch.save(model.state_dict(), os_out / "best_model.pt")
            print(f"[Epoch {epoch}] Saved new best model with val_auc={val_auc:.4f}", flush=True)

    torch.save(model.state_dict(), os_out / "last_model.pt")
    test_auc, test_ap = evaluate(model, test_dl, device)
    print(f"FINAL TEST: AUC={test_auc:.4f} AP={test_ap:.4f}", flush=True)
    print(f"[i] Models written to {os_out}/best_model.pt and last_model.pt", flush=True)

if __name__ == "__main__":
    main()
|
dpacman/data_modules/__init__.py
ADDED
|
File without changes
|
dpacman/data_modules/pair.py
ADDED
|
@@ -0,0 +1,342 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import argparse
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.utils.data import Dataset, DataLoader
|
| 7 |
+
from lightning import LightningDataModule
|
| 8 |
+
from model import BindPredictor
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from collections import Counter
|
| 11 |
+
from sklearn.metrics import roc_auc_score, average_precision_score
|
| 12 |
+
from sklearn.decomposition import TruncatedSVD
|
| 13 |
+
from multiprocessing import cpu_count
|
| 14 |
+
from functools import partial
|
| 15 |
+
import random
|
| 16 |
+
import sys
|
| 17 |
+
import pandas as pd
|
| 18 |
+
|
| 19 |
+
from dpacman.utils import RankedLogger
|
| 20 |
+
|
| 21 |
+
logger = RankedLogger(__name__, rank_zero_only=True)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# ---- dataset ---------------------------------------------------------
|
| 25 |
+
class PairDataset(Dataset):
|
| 26 |
+
def __init__(self, tf_paths, dna_paths, final_df, tf_cache):
|
| 27 |
+
"""
|
| 28 |
+
tf_cache: dict mapping binder_path -> compressed (256-d) tensor/array
|
| 29 |
+
"""
|
| 30 |
+
self.tf_paths, self.dna_paths = tf_paths, dna_paths
|
| 31 |
+
self.tf_cache = tf_cache
|
| 32 |
+
self.targets = {}
|
| 33 |
+
for _, row in final_df.iterrows():
|
| 34 |
+
dna_id = row["dna_id"]
|
| 35 |
+
vec = np.array(list(map(float, row["score_sig_r2"].split(","))), dtype=np.float32)
|
| 36 |
+
self.targets[dna_id] = vec
|
| 37 |
+
|
| 38 |
+
def __len__(self):
|
| 39 |
+
return len(self.tf_paths)
|
| 40 |
+
|
| 41 |
+
def __getitem__(self, i):
|
| 42 |
+
b = self.tf_cache[self.tf_paths[i]] # (L_b, D_b) or (D_b,)
|
| 43 |
+
if b.ndim==1: b = b[None,:]
|
| 44 |
+
g = np.load(self.dna_paths[i]) # (L_g, 256) or (256,)
|
| 45 |
+
if g.ndim==1: g = g[None,:]
|
| 46 |
+
|
| 47 |
+
stem = Path(self.dna_paths[i]).stem
|
| 48 |
+
dna_id = stem.replace("dna_","")
|
| 49 |
+
t = self.targets.get(dna_id, np.zeros(g.shape[0],dtype=np.float32))
|
| 50 |
+
|
| 51 |
+
return torch.from_numpy(b).float(), \
|
| 52 |
+
torch.from_numpy(g).float(), \
|
| 53 |
+
torch.from_numpy(t).float()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class PairDataModule(LightningDataModule):
|
| 57 |
+
def __init__(
|
| 58 |
+
self,
|
| 59 |
+
train_file: str = "data_files/splits/train.csv",
|
| 60 |
+
val_file: str = "data_files/splits/val.csv",
|
| 61 |
+
test_file: str = "data_files/splits/test.csv",
|
| 62 |
+
tokenizer_path="facebook/esm2_t33_650M_UR50D",
|
| 63 |
+
batch_size: int = 1,
|
| 64 |
+
num_workers=8,
|
| 65 |
+
maximize_num_workers=False,
|
| 66 |
+
debug_run: bool = False,
|
| 67 |
+
pin_memory: bool = False,
|
| 68 |
+
):
|
| 69 |
+
super().__init__()
|
| 70 |
+
self.save_hyperparameters()
|
| 71 |
+
self.debug_run = debug_run
|
| 72 |
+
|
| 73 |
+
# Initialize the data files
|
| 74 |
+
self.train_data_file = train_file
|
| 75 |
+
self.val_data_file = val_file
|
| 76 |
+
self.test_data_file = test_file
|
| 77 |
+
|
| 78 |
+
# Initialize hyperparameters like batch size
|
| 79 |
+
self.batch_size = batch_size
|
| 80 |
+
self.num_wokers = cpu_count() if maximize_num_workers else min(num_workers, cpu_count())
|
| 81 |
+
|
| 82 |
+
logger.info(f"num_workers={self.num_wokers}")
|
| 83 |
+
logger.info("Initialized BinderDecoyDataModule constants")
|
| 84 |
+
|
| 85 |
+
def load_and_unpack(self, file_path, lim=None):
|
| 86 |
+
"""
|
| 87 |
+
Load and unpack an input csv whose columns are binder_path,glm_path,label
|
| 88 |
+
"""
|
| 89 |
+
df = pd.read_csv(file_path)
|
| 90 |
+
if lim is not None:
|
| 91 |
+
df = df[:lim].reset_index(drop=True)
|
| 92 |
+
|
| 93 |
+
binder_paths = df["binder_path"].tolist()
|
| 94 |
+
glm_paths = df["glm_path"].tolist()
|
| 95 |
+
labels = df["label"].tolist()
|
| 96 |
+
|
| 97 |
+
return binder_paths, glm_paths, labels
|
| 98 |
+
|
| 99 |
+
def setup(self, stage):
|
| 100 |
+
lim = 5 if self.debug_run else None
|
| 101 |
+
if stage=="train":
|
| 102 |
+
binder_paths, glm_paths, labels = self.load_file(self.train_data_file, lim=lim)
|
| 103 |
+
self.train_dataset = PairDataset(binder_paths, glm_paths, labels)
|
| 104 |
+
elif stage=="val":
|
| 105 |
+
binder_paths, glm_paths, labels = self.load_file(self.val_data_file, lim=lim)
|
| 106 |
+
self.val_dataset = PairDataset(binder_paths, glm_paths, labels)
|
| 107 |
+
elif stage=="test":
|
| 108 |
+
binder_paths, glm_paths, labels = self.load_file(self.test_data_file, lim=lim)
|
| 109 |
+
self.test_dataset = PairDataset(binder_paths, glm_paths, labels)
|
| 110 |
+
else:
|
| 111 |
+
raise RuntimeError(f"Stage {stage} is not defined. Must be train, val, or test.")
|
| 112 |
+
|
| 113 |
+
def train_dataloader(self):
|
| 114 |
+
return DataLoader(
|
| 115 |
+
self.train_dataset,
|
| 116 |
+
batch_size=self.batch_size,
|
| 117 |
+
collate_fn=collate_fn,
|
| 118 |
+
num_workers=self.num_wokers,
|
| 119 |
+
pin_memory=self.hparams.pin_memory,
|
| 120 |
+
shuffle=True,
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
def val_dataloader(self):
|
| 124 |
+
return DataLoader(
|
| 125 |
+
self.val_dataset,
|
| 126 |
+
batch_size=self.batch_size,
|
| 127 |
+
collate_fn=collate_fn,
|
| 128 |
+
num_workers=self.num_wokers,
|
| 129 |
+
pin_memory=self.hparams.pin_memory,
|
| 130 |
+
shuffle=False,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
def test_dataloader(self):
|
| 134 |
+
return DataLoader(
|
| 135 |
+
self.test_dataset,
|
| 136 |
+
batch_size=self.batch_size,
|
| 137 |
+
collate_fn=collate_fn,
|
| 138 |
+
num_workers=self.num_wokers,
|
| 139 |
+
pin_memory=self.hparams.pin_memory,
|
| 140 |
+
shuffle=False,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def collate_fn(batch):
|
| 145 |
+
Bs = [b.shape[0] for b,_,_ in batch]
|
| 146 |
+
Gs = [g.shape[0] for _,g,_ in batch]
|
| 147 |
+
maxB, maxG = max(Bs), max(Gs)
|
| 148 |
+
|
| 149 |
+
def pad_seq(x, L):
|
| 150 |
+
if x.shape[0] < L:
|
| 151 |
+
pad = torch.zeros((L-x.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
|
| 152 |
+
return torch.cat([x, pad], dim=0)
|
| 153 |
+
return x
|
| 154 |
+
|
| 155 |
+
def pad_t(y, L):
|
| 156 |
+
if y.shape[0] < L:
|
| 157 |
+
pad = torch.zeros((L-y.shape[0],), dtype=y.dtype, device=y.device)
|
| 158 |
+
return torch.cat([y, pad], dim=0)
|
| 159 |
+
return y
|
| 160 |
+
|
| 161 |
+
b_stack = torch.stack([pad_seq(b, maxB) for b,_,_ in batch])
|
| 162 |
+
g_stack = torch.stack([pad_seq(g, maxG) for _,g,_ in batch])
|
| 163 |
+
t_stack = torch.stack([pad_t(t, maxG) for *_,t in batch])
|
| 164 |
+
return b_stack, g_stack, t_stack
|
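# --- Illustrative sketch (not part of the commit): what collate_fn above produces
# for a batch of variable-length (binder, glm, target) triples. Shapes are toy values.
import torch

demo_batch = [
    (torch.randn(5, 8), torch.randn(12, 8), torch.zeros(12)),
    (torch.randn(3, 8), torch.randn(7, 8), torch.zeros(7)),
]
b_stack, g_stack, t_stack = collate_fn(demo_batch)
assert b_stack.shape == (2, 5, 8)   # binders zero-padded to the longest binder
assert g_stack.shape == (2, 12, 8)  # glm embeddings zero-padded to the longest glm
assert t_stack.shape == (2, 12)     # per-token targets padded alongside the glm axis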
| 165 |
+
|
| 166 |
+
def build_tf_compressed_cache(binder_paths, target_dim=256):
|
| 167 |
+
"""
|
| 168 |
+
Load all unique TF (binder) embeddings, fit reduction if needed, and return dict mapping path->(L, target_dim) array.
|
| 169 |
+
"""
|
| 170 |
+
unique_paths = sorted(set(binder_paths))
|
| 171 |
+
logger.info(f"[i] Found {len(unique_paths)} unique TF embedding files to compress.", flush=True)
|
| 172 |
+
# Load all embeddings to determine dimensionality
|
| 173 |
+
samples = []
|
| 174 |
+
for p in unique_paths:
|
| 175 |
+
arr = np.load(p)
|
| 176 |
+
samples.append(arr)
|
| 177 |
+
# Determine if reduction needed: assume all have same embedding width
|
| 178 |
+
first = samples[0]
|
| 179 |
+
orig_dim = first.shape[1] if first.ndim == 2 else 1
|
| 180 |
+
reduction_needed = (orig_dim != target_dim)
|
| 181 |
+
tf_cache = {}
|
| 182 |
+
|
| 183 |
+
if reduction_needed:
|
| 184 |
+
# Sequences have variable lengths, so the per-token embedding matrices cannot be stacked directly.
|
| 185 |
+
# Instead, fit a single TruncatedSVD on mean-pooled (one vector per sequence) embeddings,
|
| 186 |
+
# then reuse its components as a shared linear projection applied to every token vector.
|
| 187 |
+
# This preserves per-token resolution while mapping orig_dim -> target_dim,
|
| 188 |
+
# i.e. a PCA-like projection learned once from pooled vectors and shared across sequences.
|
| 189 |
+
pooled = []
|
| 190 |
+
for arr in samples:
|
| 191 |
+
if arr.ndim == 2:
|
| 192 |
+
pooled.append(arr.mean(axis=0)) # (orig_dim,)
|
| 193 |
+
else:
|
| 194 |
+
pooled.append(arr) # degenerate
|
| 195 |
+
pooled_mat = np.stack(pooled, axis=0) # (N, orig_dim)
|
| 196 |
+
logger.info(f"[i] Fitting TruncatedSVD on TF pooled embeddings: {pooled_mat.shape} -> {target_dim}", flush=True)
|
| 197 |
+
svd = TruncatedSVD(n_components=target_dim, random_state=42)
|
| 198 |
+
reduced_pooled = svd.fit_transform(pooled_mat) # (N, target_dim)
|
| 199 |
+
|
| 200 |
+
# For each original embedding, project token-level vectors by multiplying token vector with svd.components_.T
|
| 201 |
+
# svd.components_: (target_dim, orig_dim) so projection matrix is (orig_dim, target_dim)
|
| 202 |
+
proj_mat = svd.components_.T # (orig_dim, target_dim)
|
| 203 |
+
for i, p in enumerate(unique_paths):
|
| 204 |
+
arr = samples[i] # shape (L, orig_dim)
|
| 205 |
+
if arr.ndim == 1:
|
| 206 |
+
arr2 = arr @ proj_mat # (target_dim,)
|
| 207 |
+
else:
|
| 208 |
+
# project each token: (L, orig_dim) @ (orig_dim, target_dim) -> (L, target_dim)
|
| 209 |
+
arr2 = arr @ proj_mat
|
| 210 |
+
tf_cache[p] = arr2 # reduced per-token representation
|
| 211 |
+
logger.info("[i] Completed compression of TF embeddings.", flush=True)
|
| 212 |
+
else:
|
| 213 |
+
# already correct dim: just cache originals
|
| 214 |
+
logger.info(f"[i] TF embeddings already {target_dim}-dimensional; skipping reduction.", flush=True)
|
| 215 |
+
for i, p in enumerate(unique_paths):
|
| 216 |
+
arr = samples[i]
|
| 217 |
+
tf_cache[p] = arr
|
| 218 |
+
return tf_cache
|
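# --- Illustrative sketch (not part of the commit): the projection trick used above.
# TruncatedSVD is fit once on mean-pooled vectors; its components are then reused
# as a shared linear map for the variable-length token matrices. Toy dimensions.
import numpy as np
from sklearn.decomposition import TruncatedSVD

rng = np.random.default_rng(0)
token_mats = [rng.normal(size=(L, 32)) for L in (6, 9, 4)]      # (L_i, 32) per sequence
pooled_mat = np.stack([m.mean(axis=0) for m in token_mats])     # (3, 32)

svd = TruncatedSVD(n_components=2, random_state=42).fit(pooled_mat)
proj_mat = svd.components_.T                                    # (32, 2)
reduced = [m @ proj_mat for m in token_mats]                    # each (L_i, 2)
assert reduced[0].shape == (6, 2)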
| 219 |
+
|
| 220 |
+
def evaluate(model, dl, device):
|
| 221 |
+
model.eval()
|
| 222 |
+
all_labels = []
|
| 223 |
+
all_preds = []
|
| 224 |
+
with torch.no_grad():
|
| 225 |
+
for b, g, y in dl:
|
| 226 |
+
b = b.to(device)
|
| 227 |
+
g = g.to(device)
|
| 228 |
+
y = y.to(device)
|
| 229 |
+
pred = model(b, g)
|
| 230 |
+
all_labels.append(y.cpu())
|
| 231 |
+
all_preds.append(pred.cpu())
|
| 232 |
+
if not all_labels:
|
| 233 |
+
return 0.0, 0.0
|
| 234 |
+
y_true = torch.cat(all_labels).numpy()
|
| 235 |
+
y_score = torch.cat(all_preds).numpy()
|
| 236 |
+
try:
|
| 237 |
+
auc = roc_auc_score(y_true, y_score)
|
| 238 |
+
except Exception:
|
| 239 |
+
auc = 0.0
|
| 240 |
+
try:
|
| 241 |
+
ap = average_precision_score(y_true, y_score)
|
| 242 |
+
except Exception:
|
| 243 |
+
ap = 0.0
|
| 244 |
+
return auc, ap
|
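# --- Illustrative sketch (not part of the commit): why evaluate() wraps the metrics
# in try/except -- sklearn raises when y_true contains a single class, e.g. a tiny
# or degenerate validation split.
import numpy as np
from sklearn.metrics import roc_auc_score

y_score = np.array([0.9, 0.2, 0.7, 0.4])
print(roc_auc_score(np.array([1, 0, 1, 1]), y_score))  # fine: both classes present
try:
    roc_auc_score(np.ones(4), y_score)                 # single-class labels
except ValueError as err:
    print(f"single-class batch -> {err}")              # evaluate() falls back to 0.0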
| 245 |
+
|
| 246 |
+
# ---- main ------------------------------------------------------------
|
| 247 |
+
def main():
|
| 248 |
+
parser = argparse.ArgumentParser()
|
| 249 |
+
parser.add_argument("--pair_list", type=str, required=True,
|
| 250 |
+
help="TSV: binder_path glm_path label")
|
| 251 |
+
parser.add_argument("--out_dir", type=str, required=True)
|
| 252 |
+
parser.add_argument("--epochs", type=int, default=10)
|
| 253 |
+
parser.add_argument("--batch_size", type=int, default=32)
|
| 254 |
+
parser.add_argument("--lr", type=float, default=1e-4)
|
| 255 |
+
parser.add_argument("--device", type=str, default="cuda")
|
| 256 |
+
parser.add_argument("--seed", type=int, default=42)
|
| 257 |
+
args = parser.parse_args()
|
| 258 |
+
|
| 259 |
+
# reproducibility
|
| 260 |
+
random.seed(args.seed)
|
| 261 |
+
np.random.seed(args.seed)
|
| 262 |
+
torch.manual_seed(args.seed)
|
| 263 |
+
|
| 264 |
+
logger.info("DEBUG: starting training script with in-line TF compression", flush=True)
|
| 265 |
+
logger.info(f"[i] pair_list: {args.pair_list}", flush=True)
|
| 266 |
+
logger.info(f"[i] output dir: {args.out_dir}", flush=True)
|
| 267 |
+
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
|
| 268 |
+
binder_paths, glm_paths, labels = parse_pair_list(args.pair_list)
|
| 269 |
+
|
| 270 |
+
if len(labels) == 0:
|
| 271 |
+
logger.info("[ERROR] No valid pairs parsed. Exiting.", file=sys.stderr)
|
| 272 |
+
sys.exit(1)
|
| 273 |
+
|
| 274 |
+
label_counts = Counter(labels)
|
| 275 |
+
logger.info(f"[i] Total examples parsed: {len(labels)}. Label distribution: {label_counts}", flush=True)
|
| 276 |
+
|
| 277 |
+
# build compressed TF cache (reduces to 256 if needed)
|
| 278 |
+
tf_compressed_cache = build_tf_compressed_cache(binder_paths, target_dim=256)
|
| 279 |
+
|
| 280 |
+
# simple split: 80/10/10
|
| 281 |
+
n = len(labels)
|
| 282 |
+
idxs = np.arange(n)
|
| 283 |
+
np.random.shuffle(idxs)
|
| 284 |
+
train_i = idxs[: int(0.8 * n)]
|
| 285 |
+
val_i = idxs[int(0.8 * n): int(0.9 * n)]
|
| 286 |
+
test_i = idxs[int(0.9 * n):]
|
| 287 |
+
|
| 288 |
+
def subset(idxs):
|
| 289 |
+
return [binder_paths[i] for i in idxs], [glm_paths[i] for i in idxs], [labels[i] for i in idxs]
|
| 290 |
+
|
| 291 |
+
train_ds = PairDataset(*subset(train_i), tf_compressed_cache=tf_compressed_cache)
|
| 292 |
+
val_ds = PairDataset(*subset(val_i), tf_compressed_cache=tf_compressed_cache)
|
| 293 |
+
test_ds = PairDataset(*subset(test_i), tf_compressed_cache=tf_compressed_cache)
|
| 294 |
+
|
| 295 |
+
logger.info(f"[i] Train/Val/Test sizes: {len(train_ds)}/{len(val_ds)}/{len(test_ds)}", flush=True)
|
| 296 |
+
if len(train_ds) == 0 or len(val_ds) == 0:
|
| 297 |
+
logger.info("[ERROR] Train or validation split is empty; cannot proceed.", file=sys.stderr)
|
| 298 |
+
sys.exit(1)
|
| 299 |
+
|
| 300 |
+
train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
|
| 301 |
+
val_dl = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
|
| 302 |
+
test_dl = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
|
| 303 |
+
|
| 304 |
+
model = BindPredictor(input_dim=256, hidden_dim=256, heads=8, num_layers=3, use_local_cnn_on_glm=True)
|
| 305 |
+
model = model.to(device)
|
| 306 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-3)
|
| 307 |
+
loss_fn = nn.BCELoss()
|
| 308 |
+
|
| 309 |
+
best_val = -float("inf")
|
| 310 |
+
os_out = Path(args.out_dir)
|
| 311 |
+
os_out.mkdir(exist_ok=True, parents=True)
|
| 312 |
+
|
| 313 |
+
for epoch in range(1, args.epochs + 1):
|
| 314 |
+
logger.info(f"[Epoch {epoch}] starting...", flush=True)
|
| 315 |
+
model.train()
|
| 316 |
+
running_loss = 0.0
|
| 317 |
+
for b, g, y in train_dl:
|
| 318 |
+
b = b.to(device)
|
| 319 |
+
g = g.to(device)
|
| 320 |
+
y = y.to(device)
|
| 321 |
+
pred = model(b, g)
|
| 322 |
+
loss = loss_fn(pred, y)
|
| 323 |
+
optimizer.zero_grad()
|
| 324 |
+
loss.backward()
|
| 325 |
+
optimizer.step()
|
| 326 |
+
running_loss += loss.item() * b.size(0)
|
| 327 |
+
train_loss = running_loss / len(train_ds)
|
| 328 |
+
val_auc, val_ap = evaluate(model, val_dl, device)
|
| 329 |
+
logger.info(f"[Epoch {epoch}] train_loss={train_loss:.4f} val_auc={val_auc:.4f} val_ap={val_ap:.4f}", flush=True)
|
| 330 |
+
|
| 331 |
+
if val_auc > best_val:
|
| 332 |
+
best_val = val_auc
|
| 333 |
+
torch.save(model.state_dict(), os_out / "best_model.pt")
|
| 334 |
+
logger.info(f"[Epoch {epoch}] Saved new best model with val_auc={val_auc:.4f}", flush=True)
|
| 335 |
+
|
| 336 |
+
torch.save(model.state_dict(), os_out / "last_model.pt")
|
| 337 |
+
test_auc, test_ap = evaluate(model, test_dl, device)
|
| 338 |
+
logger.info(f"FINAL TEST: AUC={test_auc:.4f} AP={test_ap:.4f}", flush=True)
|
| 339 |
+
logger.info(f"[i] Models written to {os_out}/best_model.pt and last_model.pt", flush=True)
|
| 340 |
+
|
| 341 |
+
if __name__ == "__main__":
|
| 342 |
+
main()
|
dpacman/data_tasks/cluster/__init__.py
ADDED
|
File without changes
|
dpacman/data_tasks/cluster/remap.py
ADDED
|
@@ -0,0 +1,144 @@
|
| 1 |
+
"""
|
| 2 |
+
Holds Python methods for clustering Remap DNA sequences.
|
| 3 |
+
"""
|
| 4 |
+
import argparse
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
import random
|
| 9 |
+
import sys
|
| 10 |
+
import subprocess
|
| 11 |
+
from collections import defaultdict
|
| 12 |
+
import rootutils
|
| 13 |
+
import logging
|
| 14 |
+
import os
|
| 15 |
+
import json
|
| 16 |
+
from omegaconf import DictConfig
|
| 17 |
+
from hydra.core.hydra_config import HydraConfig
|
| 18 |
+
|
| 19 |
+
from dpacman.utils.clustering import make_fasta, process_fasta, analyze_clustering_result, run_mmseqs_clustering, cluster_summary
|
| 20 |
+
|
| 21 |
+
root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
def cluster_molecules(fasta_dict, fasta_path, mmseqs_params: DictConfig, output_dir="", path_to_mmseqs="../softwares/mmseqs", moltype="dna", use_gpu=True):
|
| 25 |
+
"""
|
| 26 |
+
Args:
|
| 27 |
+
- fasta_dict: dictionary object where the keys are sequence IDs, and the values are sequences
|
| 28 |
+
- fasta_path: str or Path to where the output fasta should be saved
|
| 29 |
+
- mmseqs_params: DictConfig of mmseqs hparams
|
| 30 |
+
- moltype: molecule type, "dna" or "protein"
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
# make the fasta
|
| 34 |
+
logger.info(f"Making fasta at: {fasta_path}")
|
| 35 |
+
fasta_path = str(make_fasta(fasta_dict, fasta_path))
|
| 36 |
+
|
| 37 |
+
# prepare directories
|
| 38 |
+
output_dir = str(Path(root) / output_dir)
|
| 39 |
+
path_to_mmseqs = str(Path(root) / path_to_mmseqs)
|
| 40 |
+
|
| 41 |
+
# run mmseqs
|
| 42 |
+
dbtype=1
|
| 43 |
+
if moltype=="dna": dbtype=2
|
| 44 |
+
run_mmseqs_clustering(fasta_path,
|
| 45 |
+
output_dir,
|
| 46 |
+
min_seq_id=mmseqs_params.min_seq_id,
|
| 47 |
+
c=mmseqs_params.c,
|
| 48 |
+
cov_mode=mmseqs_params.cov_mode,
|
| 49 |
+
cluster_mode=mmseqs_params.cluster_mode,
|
| 50 |
+
dbtype=dbtype,
|
| 51 |
+
path_to_mmseqs=path_to_mmseqs)
|
| 52 |
+
|
| 53 |
+
tsv_path = [x for x in os.listdir(output_dir) if x.endswith(".tsv")][0]
|
| 54 |
+
clusters = analyze_clustering_result(
|
| 55 |
+
fasta_path, Path(output_dir) / tsv_path
|
| 56 |
+
)
|
| 57 |
+
logger.info(f"Made clusters DataFrame:\n{clusters.head()}")
|
| 58 |
+
cluster_summary(clusters)
|
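# --- Illustrative sketch (not part of the commit): calling cluster_molecules directly.
# The mmseqs hparam names mirror the keyword arguments forwarded to run_mmseqs_clustering
# above; all paths and threshold values here are placeholders, not project defaults.
from omegaconf import OmegaConf

demo_mmseqs_params = OmegaConf.create(
    {"min_seq_id": 0.8, "c": 0.8, "cov_mode": 0, "cluster_mode": 0}
)
demo_fasta_dict = {"dnaseq1": "ACGTACGTACGT", "dnaseq2": "ACGTACGTACGA"}
# Requires the mmseqs binary on disk, so the call itself is left commented out:
# cluster_molecules(demo_fasta_dict, "tmp/dna_demo.fasta",
#                   mmseqs_params=demo_mmseqs_params,
#                   output_dir="tmp/mmseqs_out",
#                   path_to_mmseqs="../softwares/mmseqs",
#                   moltype="dna")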
| 59 |
+
|
| 60 |
+
def read_input_data(input_path):
|
| 61 |
+
"""
|
| 62 |
+
Read the data from the input path.
|
| 63 |
+
It may be a csv or parquet
|
| 64 |
+
"""
|
| 65 |
+
input_path = Path(root) / input_path
|
| 66 |
+
df = None
|
| 67 |
+
if str(input_path).endswith(".parquet"):
|
| 68 |
+
df = pd.read_parquet(input_path, engine='pyarrow')
|
| 69 |
+
elif str(input_path).endswith(".csv"):
|
| 70 |
+
df = pd.read_csv(input_path)
|
| 71 |
+
elif str(input_path).endswith(".tsv") or str(input_path).endswith(".txt"):
|
| 72 |
+
df = pd.read_csv(input_path, sep="\t")
|
| 73 |
+
else:
|
| 74 |
+
raise Exception(f"Cannot read input data from {input_path}: invalid file type")
|
| 75 |
+
return df
|
| 76 |
+
|
| 77 |
+
def main(cfg: DictConfig):
|
| 78 |
+
"""
|
| 79 |
+
Run clustering on Remap protein AND DNA sequences.
|
| 80 |
+
Get clusters for each.
|
| 81 |
+
"""
|
| 82 |
+
# Load input CSV
|
| 83 |
+
# columns: Index(['ID', 'tr_seqid', 'dna_seqid', 'tr_name', 'peak_id', 'chipscore', 'total_jaspar_hits', 'dna_sequence', 'tr_sequence', 'scores']
|
| 84 |
+
df = read_input_data(cfg.data_task.input_data_path)
|
| 85 |
+
|
| 86 |
+
# Separate configs
|
| 87 |
+
dna_full_cfg = cfg.data_task.dna_full
|
| 88 |
+
dna_peaks_cfg = cfg.data_task.dna_peaks
|
| 89 |
+
protein_cfg = cfg.data_task.protein
|
| 90 |
+
logger.info(f"Clustering DNA full: {cfg.data_task.cluster_dna_full}. Clustering DNA peaks: {cfg.data_task.cluster_dna_peaks}. Clustering protein: {cfg.data_task.cluster_protein}.")
|
| 91 |
+
|
| 92 |
+
# Make fastas
|
| 93 |
+
dna_full_fasta_path = Path(root) / dna_full_cfg.fasta_path
|
| 94 |
+
dna_peaks_fasta_path = Path(root) / dna_peaks_cfg.fasta_path
|
| 95 |
+
protein_fasta_path = Path(root) / protein_cfg.fasta_path
|
| 96 |
+
os.makedirs(dna_full_fasta_path.parent, exist_ok=True)
|
| 97 |
+
os.makedirs(dna_peaks_fasta_path.parent, exist_ok=True)
|
| 98 |
+
os.makedirs(protein_fasta_path.parent, exist_ok=True)
|
| 99 |
+
|
| 100 |
+
# Make the dictionaries needed as input to the fasta methods
|
| 101 |
+
with open(Path(root) / dna_full_cfg.input_map_path, "r") as f:
|
| 102 |
+
dna_full_fasta_dict = json.load(f)
|
| 103 |
+
|
| 104 |
+
with open(Path(root) / dna_peaks_cfg.input_map_path, "r") as f:
|
| 105 |
+
dna_peaks_fasta_dict = json.load(f)
|
| 106 |
+
|
| 107 |
+
with open(Path(root) / protein_cfg.input_map_path, "r") as f:
|
| 108 |
+
protein_fasta_dict = json.load(f)
|
| 109 |
+
|
| 110 |
+
logger.info(f"Loaded DNA seq dict from: {dna_full_cfg.input_map_path}. Size: {len(dna_full_fasta_dict)}")
|
| 111 |
+
logger.info(f"Loaded DNA peaks dict from: {dna_peaks_cfg.input_map_path}. Size: {len(dna_peaks_fasta_dict)}")
|
| 112 |
+
logger.info(f"Loaded TR (protein) seq dict from: {protein_cfg.input_map_path}. Size: {len(protein_fasta_dict)}")
|
| 113 |
+
|
| 114 |
+
# Build hash-sets once (drop NaNs to avoid weird matches)
|
| 115 |
+
dna_ids = set(df["dna_seqid"].dropna())
|
| 116 |
+
peak_ids = set(df["peak_seqid"].dropna())
|
| 117 |
+
tr_ids = set(df["tr_seqid"].dropna())
|
| 118 |
+
|
| 119 |
+
# Iterate only the intersection (fast when allowed << dict size)
|
| 120 |
+
dna_full_fasta_dict = {k: dna_full_fasta_dict[k] for k in (dna_full_fasta_dict.keys() & dna_ids)}
|
| 121 |
+
dna_peaks_fasta_dict = {k: dna_peaks_fasta_dict[k] for k in (dna_peaks_fasta_dict.keys() & peak_ids)}
|
| 122 |
+
protein_fasta_dict = {k: protein_fasta_dict[k] for k in (protein_fasta_dict.keys() & tr_ids)}
|
| 123 |
+
|
| 124 |
+
logger.info(f"Filtered dictionaries to only sequences in the filtered training data.")
|
| 125 |
+
logger.info(f"Total DNA sequences: {len(dna_full_fasta_dict)}. Total peak sequences: {len(dna_peaks_fasta_dict)}. Total protein sequences: {len(protein_fasta_dict)}")
|
| 126 |
+
|
| 127 |
+
if cfg.data_task.cluster_dna_full:
|
| 128 |
+
logger.info(f"Clustering DNA full sequences, with context")
|
| 129 |
+
cluster_molecules(dna_full_fasta_dict, dna_full_fasta_path,
|
| 130 |
+
mmseqs_params=dna_full_cfg.mmseqs, output_dir=dna_full_cfg.output_dir, path_to_mmseqs=cfg.data_task.path_to_mmseqs, moltype="dna")
|
| 131 |
+
|
| 132 |
+
if cfg.data_task.cluster_dna_peaks:
|
| 133 |
+
logger.info(f"Clustering DNA peak sequences")
|
| 134 |
+
cluster_molecules(dna_peaks_fasta_dict, dna_peaks_fasta_path,
|
| 135 |
+
mmseqs_params=dna_peaks_cfg.mmseqs, output_dir=dna_peaks_cfg.output_dir, path_to_mmseqs=cfg.data_task.path_to_mmseqs, moltype="dna")
|
| 136 |
+
|
| 137 |
+
if cfg.data_task.cluster_protein:
|
| 138 |
+
logger.info("Clustering protein sequences.")
|
| 139 |
+
cluster_molecules(protein_fasta_dict, protein_fasta_path, mmseqs_params=protein_cfg.mmseqs, output_dir=protein_cfg.output_dir, path_to_mmseqs=cfg.data_task.path_to_mmseqs, moltype="protein")
|
| 140 |
+
|
| 141 |
+
logger.info("Clustering pipeline complete")
|
| 142 |
+
|
| 143 |
+
if __name__ == "__main__":
|
| 144 |
+
main()
|
dpacman/data_tasks/embeddings/embedders.py
ADDED
|
@@ -0,0 +1,560 @@
|
| 1 |
+
"""
|
| 2 |
+
Plug-and-play embedding extraction for:
|
| 3 |
+
• Chromosome sequences (from raw UCSC JSON)
|
| 4 |
+
• TF sequences (transcription_factors.fasta)
|
| 5 |
+
|
| 6 |
+
Usage example (DNA + protein in one go):
|
| 7 |
+
module load miniconda/24.7.1
|
| 8 |
+
conda activate dpacman
|
| 9 |
+
python dpacman/data/compute_embeddings.py \
|
| 10 |
+
--genome-json-dir ../data_files/raw/genomes/hg38 \
|
| 11 |
+
--tf-fasta ../data_files/processed/tfclust/hg38_tf/transcription_factors.fasta \
|
| 12 |
+
--chrom-model caduceus \
|
| 13 |
+
--tf-model esm-dbp \
|
| 14 |
+
--out-dir ../data_files/processed/tfclust/hg38_tf/embeddings \
|
| 15 |
+
--device cuda
|
| 16 |
+
"""
|
| 17 |
+
import os
|
| 18 |
+
import re
|
| 19 |
+
import argparse
|
| 20 |
+
import json
|
| 21 |
+
import numpy as np
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
import torch
|
| 24 |
+
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, pipeline
|
| 25 |
+
import esm
|
| 26 |
+
from Bio import SeqIO
|
| 27 |
+
import time
|
| 28 |
+
import pandas as pd
|
| 29 |
+
from tqdm.auto import tqdm
|
| 30 |
+
import logging, math
|
| 31 |
+
|
| 32 |
+
# ---- model wrappers ----
|
| 33 |
+
|
| 34 |
+
class CaduceusEmbedder:
|
| 35 |
+
def __init__(self, device, chunk_size=131_072, overlap=0):
|
| 36 |
+
"""
|
| 37 |
+
device: 'cpu' or 'cuda'
|
| 38 |
+
chunk_size: max bases (and thus tokens) to send in one forward pass
|
| 39 |
+
overlap: how many bases each window overlaps the previous; 0 = no overlap
|
| 40 |
+
"""
|
| 41 |
+
model_name = "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16"
|
| 42 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 43 |
+
model_name, trust_remote_code=True
|
| 44 |
+
)
|
| 45 |
+
self.model = AutoModel.from_pretrained(
|
| 46 |
+
model_name, trust_remote_code=True
|
| 47 |
+
).to(device).eval()
|
| 48 |
+
self.device = device
|
| 49 |
+
self.chunk_size = chunk_size
|
| 50 |
+
self.step = chunk_size - overlap
|
| 51 |
+
|
| 52 |
+
def embed(self, seqs):
|
| 53 |
+
"""
|
| 54 |
+
seqs: List[str] of DNA sequences (each <= chunk_size for this test)
|
| 55 |
+
returns: list of N arrays, each (L_i, D), of raw per-token embeddings
|
| 56 |
+
"""
|
| 57 |
+
# outputs = []
|
| 58 |
+
# for seq in seqs:
|
| 59 |
+
# # --- new: raw per‐token embeddings in one shot ---
|
| 60 |
+
# toks = self.tokenizer(
|
| 61 |
+
# seq,
|
| 62 |
+
# return_tensors="pt",
|
| 63 |
+
# padding=False,
|
| 64 |
+
# truncation=True,
|
| 65 |
+
# max_length=self.chunk_size
|
| 66 |
+
# ).to(self.device)
|
| 67 |
+
# with torch.no_grad():
|
| 68 |
+
# out = self.model(**toks).last_hidden_state # (1, L, D)
|
| 69 |
+
# outputs.append(out.cpu().numpy()[0]) # (L, D)
|
| 70 |
+
|
| 71 |
+
# return np.stack(outputs, axis=0) # (N, L, D)
|
| 72 |
+
outputs = []
|
| 73 |
+
for seq in tqdm(seqs, total=len(seqs), desc="DNA: Caduceus", dynamic_ncols=True):
|
| 74 |
+
toks = self.tokenizer(
|
| 75 |
+
seq,
|
| 76 |
+
return_tensors="pt",
|
| 77 |
+
padding=False,
|
| 78 |
+
truncation=True,
|
| 79 |
+
max_length=self.chunk_size
|
| 80 |
+
).to(self.device)
|
| 81 |
+
with torch.no_grad():
|
| 82 |
+
out = self.model(**toks).last_hidden_state # (1, L, D)
|
| 83 |
+
outputs.append(out.cpu().numpy()[0]) # (L, D)
|
| 84 |
+
return outputs # list of variable-length (L_i, D) arrays
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def benchmark(self, lengths=None):
|
| 88 |
+
"""
|
| 89 |
+
Time embedding on single-sequence of various lengths.
|
| 90 |
+
By default tests [5K,10K,50K,100K,chunk_size].
|
| 91 |
+
"""
|
| 92 |
+
tests = lengths or [5_000, 10_000, 50_000, 100_000, self.chunk_size]
|
| 93 |
+
print(f"→ Benchmarking Caduceus on device={self.device}")
|
| 94 |
+
for sz in tests:
|
| 95 |
+
seq = "A" * sz
|
| 96 |
+
# Warm-up
|
| 97 |
+
_ = self.embed([seq])
|
| 98 |
+
if self.device != "cpu":
|
| 99 |
+
torch.cuda.synchronize()
|
| 100 |
+
t0 = time.perf_counter()
|
| 101 |
+
_ = self.embed([seq])
|
| 102 |
+
if self.device != "cpu":
|
| 103 |
+
torch.cuda.synchronize()
|
| 104 |
+
t1 = time.perf_counter()
|
| 105 |
+
print(f" length={sz:6,d} time={(t1-t0)*1000:7.1f} ms")
|
| 106 |
+
|
| 107 |
+
class SegmentNTEmbedder:
|
| 108 |
+
def __init__(self, device):
|
| 109 |
+
self.tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True)
|
| 110 |
+
self.model = AutoModel.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True).to(device).eval()
|
| 111 |
+
self.device = device
|
| 112 |
+
|
| 113 |
+
def _adjust_length(self, input_ids):
|
| 114 |
+
bs, L = input_ids.shape
|
| 115 |
+
excl = L - 1
|
| 116 |
+
remainder = (excl) % 4
|
| 117 |
+
if remainder != 0:
|
| 118 |
+
pad_needed = 4 - remainder
|
| 119 |
+
pad_tensor = torch.full((bs, pad_needed), self.tokenizer.pad_token_id, dtype=input_ids.dtype, device=input_ids.device)
|
| 120 |
+
input_ids = torch.cat([input_ids, pad_tensor], dim=1)
|
| 121 |
+
return input_ids
|
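# --- Illustrative sketch (not part of the commit): the arithmetic behind _adjust_length,
# which pads token ids so the count excluding the leading special token is a multiple of 4.
L = 11                       # tokens including the leading special token
excl = L - 1                 # 10 tokens to segment
remainder = excl % 4         # 2
pad_needed = 4 - remainder   # 2 -> padded length 13, i.e. 12 segmentable tokens
print(pad_needed)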
| 122 |
+
|
| 123 |
+
def embed(self, seqs, batch_size=16):
|
| 124 |
+
"""
|
| 125 |
+
seqs: List[str]
|
| 126 |
+
Returns: np.ndarray of shape (N, D)
|
| 127 |
+
"""
|
| 128 |
+
all_embeddings = []
|
| 129 |
+
for i in range(0, len(seqs), batch_size):
|
| 130 |
+
batch_seqs = seqs[i : i + batch_size]
|
| 131 |
+
encoded = self.tokenizer.batch_encode_plus(
|
| 132 |
+
batch_seqs,
|
| 133 |
+
return_tensors="pt",
|
| 134 |
+
padding=True,
|
| 135 |
+
truncation=True,
|
| 136 |
+
)
|
| 137 |
+
input_ids = encoded["input_ids"].to(self.device) # (B, L)
|
| 138 |
+
attention_mask = input_ids != self.tokenizer.pad_token_id
|
| 139 |
+
|
| 140 |
+
input_ids = self._adjust_length(input_ids)
|
| 141 |
+
attention_mask = (input_ids != self.tokenizer.pad_token_id)
|
| 142 |
+
|
| 143 |
+
with torch.no_grad():
|
| 144 |
+
outs = self.model(
|
| 145 |
+
input_ids,
|
| 146 |
+
attention_mask=attention_mask,
|
| 147 |
+
output_hidden_states=True,
|
| 148 |
+
return_dict=True,
|
| 149 |
+
)
|
| 150 |
+
if hasattr(outs, "hidden_states") and outs.hidden_states is not None:
|
| 151 |
+
last_hidden = outs.hidden_states[-1] # (B, L, D)
|
| 152 |
+
else:
|
| 153 |
+
last_hidden = outs.last_hidden_state # fallback
|
| 154 |
+
|
| 155 |
+
# Exclude CLS token if present (assume first token) and pool
|
| 156 |
+
pooled = last_hidden[:, 1:, :].mean(dim=1) # (B, D)
|
| 157 |
+
all_embeddings.append(pooled.cpu().numpy())
|
| 158 |
+
|
| 159 |
+
# release fragmentation
|
| 160 |
+
torch.cuda.empty_cache()
|
| 161 |
+
|
| 162 |
+
return np.vstack(all_embeddings) # (N, D)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class DNABertEmbedder:
|
| 166 |
+
def __init__(self, device):
|
| 167 |
+
self.tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True)
|
| 168 |
+
self.model = AutoModel.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True).to(device)
|
| 169 |
+
self.device = device
|
| 170 |
+
|
| 171 |
+
def embed(self, seqs):
|
| 172 |
+
embs = []
|
| 173 |
+
for s in seqs:
|
| 174 |
+
tokens = self.tokenizer(s, return_tensors="pt", padding=True)["input_ids"].to(self.device)
|
| 175 |
+
with torch.no_grad():
|
| 176 |
+
out = self.model(tokens).last_hidden_state.mean(1)
|
| 177 |
+
embs.append(out.cpu().numpy())
|
| 178 |
+
return np.vstack(embs)
|
| 179 |
+
|
| 180 |
+
class NucleotideTransformerEmbedder:
|
| 181 |
+
def __init__(self, device):
|
| 182 |
+
# HF “feature-extraction” returns a list of (L, D) arrays for each input
|
| 183 |
+
# device: “cpu” or “cuda”
|
| 184 |
+
self.pipe = pipeline(
|
| 185 |
+
"feature-extraction",
|
| 186 |
+
model="InstaDeepAI/nucleotide-transformer-500m-1000g",
|
| 187 |
+
device= -1 if device=="cpu" else 0  # HF uses -1 for CPU, 0 for GPU
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
def embed(self, seqs):
|
| 191 |
+
"""
|
| 192 |
+
seqs: List[str] of raw DNA sequences
|
| 193 |
+
returns: (N, D) array, one D-dim vector per sequence
|
| 194 |
+
"""
|
| 195 |
+
all_embeddings = self.pipe(seqs, truncation=True, padding=True)
|
| 196 |
+
# all_embeddings is a List of shape (L, D) arrays
|
| 197 |
+
pooled = [ np.mean(x, axis=0) for x in all_embeddings ]
|
| 198 |
+
return np.vstack(pooled)
|
| 199 |
+
|
| 200 |
+
# class ESMEmbedder:
|
| 201 |
+
# def __init__(self, device):
|
| 202 |
+
# self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
|
| 203 |
+
# self.batch_converter = self.alphabet.get_batch_converter()
|
| 204 |
+
# self.model.to(device).eval()
|
| 205 |
+
# self.device = device
|
| 206 |
+
|
| 207 |
+
# def embed(self, seqs):
|
| 208 |
+
# batch = [(str(i), seq) for i, seq in enumerate(seqs)]
|
| 209 |
+
# _, _, toks = self.batch_converter(batch)
|
| 210 |
+
# toks = toks.to(self.device)
|
| 211 |
+
# with torch.no_grad():
|
| 212 |
+
# results = self.model(toks, repr_layers=[33], return_contacts=False)
|
| 213 |
+
# reps = results["representations"][33]
|
| 214 |
+
# return reps[:, 1:-1].mean(1).cpu().numpy()
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class ESMEmbedder:
|
| 218 |
+
def __init__(self, device, model_name="esm2_t33_650M_UR50D"):
|
| 219 |
+
# Try to load the specified ESM-2 model; fallback to esm1b if missing
|
| 220 |
+
self.device = device
|
| 221 |
+
try:
|
| 222 |
+
self.model, self.alphabet = getattr(esm.pretrained, model_name)()
|
| 223 |
+
self.is_esm2 = model_name.lower().startswith("esm2")
|
| 224 |
+
except AttributeError:
|
| 225 |
+
# fallback to ESM-1b
|
| 226 |
+
self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
|
| 227 |
+
self.is_esm2 = False
|
| 228 |
+
self.batch_converter = self.alphabet.get_batch_converter()
|
| 229 |
+
self.model.to(device).eval()
|
| 230 |
+
# determine max length: esm2 models vary; use default 1024 for esm1b
|
| 231 |
+
self.max_len = 4096 if self.is_esm2 else 1024 # adjust if your esm2 variant has explicit limit
|
| 232 |
+
# for chunking: reserve 2 tokens if model uses BOS/EOS
|
| 233 |
+
self.chunk_size = self.max_len - 2
|
| 234 |
+
self.overlap = self.chunk_size // 4 # 25% overlap to smooth boundaries
|
| 235 |
+
|
| 236 |
+
def _chunk_sequence(self, seq):
|
| 237 |
+
"""
|
| 238 |
+
Return list of possibly overlapping chunks of seq, each <= chunk_size.
|
| 239 |
+
"""
|
| 240 |
+
if len(seq) <= self.chunk_size:
|
| 241 |
+
return [seq]
|
| 242 |
+
step = self.chunk_size - self.overlap
|
| 243 |
+
chunks = []
|
| 244 |
+
for i in range(0, len(seq), step):
|
| 245 |
+
chunk = seq[i : i + self.chunk_size]
|
| 246 |
+
if not chunk:
|
| 247 |
+
break
|
| 248 |
+
chunks.append(chunk)
|
| 249 |
+
return chunks
|
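# --- Illustrative sketch (not part of the commit): the sliding-window arithmetic used
# by _chunk_sequence, shown with toy numbers (chunk_size=10, overlap=2, so step=8).
chunk_size, overlap = 10, 2
step = chunk_size - overlap
seq = "M" * 25
chunks = [seq[i:i + chunk_size] for i in range(0, len(seq), step)]
print([len(c) for c in chunks])   # [10, 10, 9, 1] -> trailing windows are shorter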
| 250 |
+
|
| 251 |
+
def embed(self, seqs):
|
| 252 |
+
"""
|
| 253 |
+
seqs: List[str] of protein sequences.
|
| 254 |
+
Returns: np.ndarray of shape (N, D) pooled per-sequence embeddings.
|
| 255 |
+
"""
|
| 256 |
+
all_embeddings = []
|
| 257 |
+
for i, seq in enumerate(seqs):
|
| 258 |
+
chunks = self._chunk_sequence(seq)
|
| 259 |
+
chunk_vecs = []
|
| 260 |
+
# process chunks in batch if small number, else sequentially
|
| 261 |
+
for chunk in chunks:
|
| 262 |
+
batch = [(str(i), chunk)]
|
| 263 |
+
_, _, toks = self.batch_converter(batch)
|
| 264 |
+
toks = toks.to(self.device)
|
| 265 |
+
with torch.no_grad():
|
| 266 |
+
results = self.model(toks, repr_layers=[33], return_contacts=False)
|
| 267 |
+
reps = results["representations"][33] # (1, L, D)
|
| 268 |
+
# remove BOS/EOS if present: take 1:-1 if length permits
|
| 269 |
+
if reps.size(1) > 2:
|
| 270 |
+
rep = reps[:, 1:-1].mean(1) # (1, D)
|
| 271 |
+
else:
|
| 272 |
+
rep = reps.mean(1) # fallback
|
| 273 |
+
chunk_vecs.append(rep.squeeze(0)) # (D,)
|
| 274 |
+
if len(chunk_vecs) == 1:
|
| 275 |
+
seq_vec = chunk_vecs[0]
|
| 276 |
+
else:
|
| 277 |
+
# average chunk vectors
|
| 278 |
+
stacked = torch.stack(chunk_vecs, dim=0) # (num_chunks, D)
|
| 279 |
+
seq_vec = stacked.mean(0)
|
| 280 |
+
all_embeddings.append(seq_vec.cpu().numpy())
|
| 281 |
+
return np.vstack(all_embeddings) # (N, D)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
# class ESMDBPEmbedder:
|
| 285 |
+
# def __init__(self, device):
|
| 286 |
+
# base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
|
| 287 |
+
# model_path = (
|
| 288 |
+
# Path(__file__).resolve().parent.parent
|
| 289 |
+
# / "pretrained" / "ESM-DBP" / "ESM-DBP.model"
|
| 290 |
+
# )
|
| 291 |
+
# checkpoint = torch.load(model_path, map_location="cpu")
|
| 292 |
+
# clean_sd = {}
|
| 293 |
+
# for k, v in checkpoint.items():
|
| 294 |
+
# clean_sd[k.replace("module.", "")] = v
|
| 295 |
+
# result = base_model.load_state_dict(clean_sd, strict=False)
|
| 296 |
+
# if result.missing_keys:
|
| 297 |
+
# print(f"[ESMDBP] missing keys: {result.missing_keys}")
|
| 298 |
+
# if result.unexpected_keys:
|
| 299 |
+
# print(f"[ESMDBP] unexpected keys: {result.unexpected_keys}")
|
| 300 |
+
|
| 301 |
+
# self.model = base_model.to(device).eval()
|
| 302 |
+
# self.alphabet = alphabet
|
| 303 |
+
# self.batch_converter = alphabet.get_batch_converter()
|
| 304 |
+
# self.device = device
|
| 305 |
+
|
| 306 |
+
# def embed(self, seqs):
|
| 307 |
+
# batch = [(str(i), seq) for i, seq in enumerate(seqs)]
|
| 308 |
+
# _, _, toks = self.batch_converter(batch)
|
| 309 |
+
# toks = toks.to(self.device)
|
| 310 |
+
# with torch.no_grad():
|
| 311 |
+
# out = self.model(toks, repr_layers=[33], return_contacts=False)
|
| 312 |
+
# reps = out["representations"][33]
|
| 313 |
+
# # skip start/end tokens
|
| 314 |
+
# return reps[:, 1:-1].mean(1).cpu().numpy()
|
| 315 |
+
|
| 316 |
+
class ESMDBPEmbedder:
|
| 317 |
+
def __init__(self, device):
|
| 318 |
+
base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
|
| 319 |
+
model_path = (
|
| 320 |
+
Path(__file__).resolve().parent.parent
|
| 321 |
+
/ "pretrained" / "ESM-DBP" / "ESM-DBP.model"
|
| 322 |
+
)
|
| 323 |
+
checkpoint = torch.load(model_path, map_location="cpu")
|
| 324 |
+
clean_sd = {}
|
| 325 |
+
for k, v in checkpoint.items():
|
| 326 |
+
clean_sd[k.replace("module.", "")] = v
|
| 327 |
+
result = base_model.load_state_dict(clean_sd, strict=False)
|
| 328 |
+
if result.missing_keys:
|
| 329 |
+
print(f"[ESMDBP] missing keys: {result.missing_keys}")
|
| 330 |
+
if result.unexpected_keys:
|
| 331 |
+
print(f"[ESMDBP] unexpected keys: {result.unexpected_keys}")
|
| 332 |
+
|
| 333 |
+
self.model = base_model.to(device).eval()
|
| 334 |
+
self.alphabet = alphabet
|
| 335 |
+
self.batch_converter = alphabet.get_batch_converter()
|
| 336 |
+
self.device = device
|
| 337 |
+
self.max_len = 1024 # same limit as esm1b
|
| 338 |
+
self.chunk_size = self.max_len - 2
|
| 339 |
+
self.overlap = self.chunk_size // 4
|
| 340 |
+
|
| 341 |
+
def _chunk_sequence(self, seq):
|
| 342 |
+
if len(seq) <= self.chunk_size:
|
| 343 |
+
return [seq]
|
| 344 |
+
step = self.chunk_size - self.overlap
|
| 345 |
+
chunks = []
|
| 346 |
+
for i in range(0, len(seq), step):
|
| 347 |
+
chunk = seq[i : i + self.chunk_size]
|
| 348 |
+
if not chunk:
|
| 349 |
+
break
|
| 350 |
+
chunks.append(chunk)
|
| 351 |
+
return chunks
|
| 352 |
+
|
| 353 |
+
def embed(self, seqs):
|
| 354 |
+
all_embeddings = []
|
| 355 |
+
for i, seq in enumerate(seqs):
|
| 356 |
+
chunks = self._chunk_sequence(seq)
|
| 357 |
+
chunk_vecs = []
|
| 358 |
+
for chunk in chunks:
|
| 359 |
+
batch = [(str(i), chunk)]
|
| 360 |
+
_, _, toks = self.batch_converter(batch)
|
| 361 |
+
toks = toks.to(self.device)
|
| 362 |
+
with torch.no_grad():
|
| 363 |
+
out = self.model(toks, repr_layers=[33], return_contacts=False)
|
| 364 |
+
reps = out["representations"][33]
|
| 365 |
+
if reps.size(1) > 2:
|
| 366 |
+
rep = reps[:, 1:-1].mean(1)
|
| 367 |
+
else:
|
| 368 |
+
rep = reps.mean(1)
|
| 369 |
+
chunk_vecs.append(rep.squeeze(0))
|
| 370 |
+
if len(chunk_vecs) == 1:
|
| 371 |
+
seq_vec = chunk_vecs[0]
|
| 372 |
+
else:
|
| 373 |
+
stacked = torch.stack(chunk_vecs, dim=0)
|
| 374 |
+
seq_vec = stacked.mean(0)
|
| 375 |
+
all_embeddings.append(seq_vec.cpu().numpy())
|
| 376 |
+
return np.vstack(all_embeddings)
|
| 377 |
+
|
| 378 |
+
class GPNEmbedder:
|
| 379 |
+
def __init__(self, device):
|
| 380 |
+
model_name = "songlab/gpn-msa-sapiens"
|
| 381 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 382 |
+
self.model = AutoModelForMaskedLM.from_pretrained(model_name)
|
| 383 |
+
self.model.to(device)
|
| 384 |
+
self.model.eval()
|
| 385 |
+
self.device = device
|
| 386 |
+
|
| 387 |
+
def embed(self, seqs):
|
| 388 |
+
inputs = self.tokenizer(
|
| 389 |
+
seqs,
|
| 390 |
+
return_tensors="pt",
|
| 391 |
+
padding=True,
|
| 392 |
+
truncation=True
|
| 393 |
+
).to(self.device)
|
| 394 |
+
|
| 395 |
+
with torch.no_grad():
|
| 396 |
+
last_hidden = self.model(**inputs).last_hidden_state
|
| 397 |
+
return last_hidden.mean(dim=1).cpu().numpy()
|
| 398 |
+
|
| 399 |
+
class ProGenEmbedder:
|
| 400 |
+
def __init__(self, device):
|
| 401 |
+
model_name = "jinyuan22/ProGen2-base"
|
| 402 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 403 |
+
self.model = AutoModel.from_pretrained(model_name).to(device).eval()
|
| 404 |
+
self.device = device
|
| 405 |
+
|
| 406 |
+
def embed(self, seqs):
|
| 407 |
+
inputs = self.tokenizer(
|
| 408 |
+
seqs,
|
| 409 |
+
return_tensors="pt",
|
| 410 |
+
padding=True,
|
| 411 |
+
truncation=True
|
| 412 |
+
).to(self.device)
|
| 413 |
+
with torch.no_grad():
|
| 414 |
+
last_hidden = self.model(**inputs).last_hidden_state
|
| 415 |
+
return last_hidden.mean(dim=1).cpu().numpy()
|
| 416 |
+
|
| 417 |
+
# ---- main pipeline ----
|
| 418 |
+
|
| 419 |
+
def get_embedder(name, device, for_dna=True):
|
| 420 |
+
name = name.lower()
|
| 421 |
+
if for_dna:
|
| 422 |
+
if name=="caduceus": return CaduceusEmbedder(device)
|
| 423 |
+
if name=="dnabert": return DNABertEmbedder(device)
|
| 424 |
+
if name=="nucleotide": return NucleotideTransformerEmbedder(device)
|
| 425 |
+
if name=="gpn": return GPNEmbedder(device)
|
| 426 |
+
if name=="segmentnt": return SegmentNTEmbedder(device)
|
| 427 |
+
else:
|
| 428 |
+
if name in ("esm",): return ESMEmbedder(device)
|
| 429 |
+
if name in ("esm-dbp","esm_dbp"): return ESMDBPEmbedder(device)
|
| 430 |
+
if name=="progen": return ProGenEmbedder(device)
|
| 431 |
+
raise ValueError(f"Unknown model {name} (for_dna={for_dna})")
|
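# --- Illustrative sketch (not part of the commit): dispatching embedders by name,
# mirroring the __main__ block below. Instantiation downloads model weights, so the
# calls are left commented out; the device string is a placeholder.
# dna_embedder = get_embedder("caduceus", "cuda", for_dna=True)
# tf_embedder = get_embedder("esm-dbp", "cuda", for_dna=False)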
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def pad_token_embeddings(list_of_arrays, pad_value=0.0):
|
| 435 |
+
"""
|
| 436 |
+
list_of_arrays: list of (L_i, D) numpy arrays
|
| 437 |
+
Returns:
|
| 438 |
+
padded: (N, L_max, D) array
|
| 439 |
+
mask: (N, L_max) boolean array where True = real token, False = padding
|
| 440 |
+
"""
|
| 441 |
+
N = len(list_of_arrays)
|
| 442 |
+
D = list_of_arrays[0].shape[1]
|
| 443 |
+
L_max = max(arr.shape[0] for arr in list_of_arrays)
|
| 444 |
+
padded = np.full((N, L_max, D), pad_value, dtype=list_of_arrays[0].dtype)
|
| 445 |
+
mask = np.zeros((N, L_max), dtype=bool)
|
| 446 |
+
for i, arr in enumerate(list_of_arrays):
|
| 447 |
+
L = arr.shape[0]
|
| 448 |
+
padded[i, :L] = arr
|
| 449 |
+
mask[i, :L] = True
|
| 450 |
+
return padded, mask
|
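# --- Illustrative sketch (not part of the commit): pad_token_embeddings on two toy
# per-token arrays of different lengths.
import numpy as np

a = np.ones((3, 4))   # 3 tokens, 4-dim embeddings
b = np.ones((5, 4))   # 5 tokens, 4-dim embeddings
padded, mask = pad_token_embeddings([a, b])
assert padded.shape == (2, 5, 4)              # padded to the longest sequence
assert mask.sum(axis=1).tolist() == [3, 5]    # True marks real (non-padding) tokens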
| 451 |
+
|
| 452 |
+
def embed_and_save(seqs, ids, embedder, out_path):
|
| 453 |
+
embs = embedder.embed(seqs)
|
| 454 |
+
|
| 455 |
+
# Decide whether we got variable-length per-token outputs (list of (L, D))
|
| 456 |
+
is_variable_token = isinstance(embs, (list, tuple)) and len(embs) > 0 and hasattr(embs[0], "shape") and embs[0].ndim == 2
|
| 457 |
+
|
| 458 |
+
if is_variable_token:
|
| 459 |
+
# pad to (N, L_max, D) + mask
|
| 460 |
+
padded, mask = pad_token_embeddings(embs)
|
| 461 |
+
# Save both embeddings and mask together in an .npz for convenience
|
| 462 |
+
np.savez_compressed(out_path.with_suffix(".caduceus.npz"),
|
| 463 |
+
embeddings=padded,
|
| 464 |
+
mask=mask,
|
| 465 |
+
ids=np.array(ids, dtype=object))
|
| 466 |
+
else:
|
| 467 |
+
# fixed shape output, e.g., pooled (N, D)
|
| 468 |
+
array = np.vstack(embs) if isinstance(embs, list) else embs
|
| 469 |
+
np.save(out_path, array)
|
| 470 |
+
with open(out_path.with_suffix(".ids"), "w") as f:
|
| 471 |
+
f.write("\n".join(ids))
|
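# --- Illustrative sketch (not part of the commit): reading back what embed_and_save
# writes. Variable-length outputs land in an .npz whose ids array has object dtype,
# so allow_pickle=True is required; pooled outputs are a plain .npy plus a .ids text
# file. The path below is a placeholder.
import numpy as np
from pathlib import Path

demo_out = Path("embeddings/peaks_caduceus.caduceus.npz")
if demo_out.exists():
    data = np.load(demo_out, allow_pickle=True)
    print(data["embeddings"].shape, data["mask"].shape, len(data["ids"]))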
| 472 |
+
|
| 473 |
+
|
| 474 |
+
if __name__=="__main__":
|
| 475 |
+
|
| 476 |
+
p = argparse.ArgumentParser()
|
| 477 |
+
#p.add_argument("--peak_fasta", default="binding_peaks_unique.fa", help="FASTA of deduplicated binding peak sequences; if present this is used for DNA embedding instead of genome JSONs")
|
| 478 |
+
p.add_argument("--genome-json-dir", default=None, help="(fallback) directory of UCSC JSONs for full chromosome embedding if peak FASTA is missing or you explicitly want chromosomes")
|
| 479 |
+
p.add_argument("--skip-dna", action="store_true", help="if set, skip the chromosome embedding step") #if glm embeddings successful but not plm embeddings
|
| 480 |
+
p.add_argument("--tf-fasta", required=True, help="input TF FASTA file")
|
| 481 |
+
p.add_argument("--chrom-model", default="caduceus")
|
| 482 |
+
p.add_argument("--tf-model", default="esm-dbp")
|
| 483 |
+
p.add_argument("--out-dir", default="dpacman/model/embeddings")
|
| 484 |
+
p.add_argument("--device", default="cpu")
|
| 485 |
+
args = p.parse_args()
|
| 486 |
+
|
| 487 |
+
os.makedirs(args.out_dir, exist_ok=True)
|
| 488 |
+
device = args.device
|
| 489 |
+
print(device)
|
| 490 |
+
|
| 491 |
+
if not args.skip_dna:
|
| 492 |
+
if args.genome_json_dir is None:
|
| 493 |
+
dna_df = pd.read_parquet('/home/a03-akrishna/DPACMAN/dpacman/model/remap2022_crm_fimo_output_q_processed.parquet', engine='pyarrow')
|
| 494 |
+
#df.to_csv('/home/a03-akrishna/DPACMAN/dpacman/model/remap2022_crm_fimo_output_q_processed.csv', index=False)
|
| 495 |
+
peak_seqs = dna_df["dna_sequence"]
|
| 496 |
+
peak_ids = dna_df["ID"]
|
| 497 |
+
print(f"Embedding {len(peak_seqs)} binding peak sequences from processed remap data", flush=True)
|
| 498 |
+
dna_embedder = get_embedder(args.chrom_model, device, for_dna=True)
|
| 499 |
+
out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy"
|
| 500 |
+
embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks)
|
| 501 |
+
|
| 502 |
+
# peak_fasta = Path(args.peak_fasta)
|
| 503 |
+
# if peak_fasta.exists():
|
| 504 |
+
# # Load peak sequences from FASTA
|
| 505 |
+
# from Bio import SeqIO
|
| 506 |
+
|
| 507 |
+
# peak_seqs = []
|
| 508 |
+
# peak_ids = []
|
| 509 |
+
# for rec in SeqIO.parse(peak_fasta, "fasta"):
|
| 510 |
+
# peak_ids.append(rec.id)
|
| 511 |
+
# peak_seqs.append(str(rec.seq))
|
| 512 |
+
# print(f"Embedding {len(peak_seqs)} binding peak sequences from {peak_fasta}", flush=True)
|
| 513 |
+
# dna_embedder = get_embedder(args.chrom_model, device, for_dna=True)
|
| 514 |
+
# out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy"
|
| 515 |
+
# embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks)
|
| 516 |
+
elif args.genome_json_dir:
|
| 517 |
+
# Legacy: load full chromosomes from JSONs (chr1–22, X, Y, M)
|
| 518 |
+
genome_dir = Path(args.genome_json_dir)
|
| 519 |
+
chrom_seqs, chrom_ids = [], []
|
| 520 |
+
primary_pattern = re.compile(r"^hg38_chr(?:[1-9]|1[0-9]|2[0-2]|X|Y|M)\.json$")
|
| 521 |
+
for j in sorted(genome_dir.iterdir()):
|
| 522 |
+
if not primary_pattern.match(j.name):
|
| 523 |
+
continue
|
| 524 |
+
data = json.loads(j.read_text())
|
| 525 |
+
seq = data.get("dna") or data.get("sequence")
|
| 526 |
+
chrom = data.get("chrom") or j.stem.split("_")[-1]
|
| 527 |
+
chrom_seqs.append(seq)
|
| 528 |
+
chrom_ids.append(chrom)
|
| 529 |
+
cutoff = CaduceusEmbedder(device).chunk_size
|
| 530 |
+
long_chroms = [
|
| 531 |
+
(chrom, len(seq))
|
| 532 |
+
for chrom, seq in zip(chrom_ids, chrom_seqs)
|
| 533 |
+
if len(seq) > cutoff
|
| 534 |
+
]
|
| 535 |
+
if long_chroms:
|
| 536 |
+
print("⚠️ Chromosomes exceeding Caduceus max tokens ({}):".format(cutoff))
|
| 537 |
+
for chrom, L in long_chroms:
|
| 538 |
+
print(f" {chrom}: {L} bases")
|
| 539 |
+
else:
|
| 540 |
+
print("All chromosomes ≤ Caduceus limit ({}).".format(cutoff))
|
| 541 |
+
|
| 542 |
+
chrom_embedder = get_embedder(args.chrom_model, device, for_dna=True)
|
| 543 |
+
out_chrom = Path(args.out_dir) / f"chrom_{args.chrom_model}.npy"
|
| 544 |
+
embed_and_save(chrom_seqs, chrom_ids, chrom_embedder, out_chrom)
|
| 545 |
+
else:
|
| 546 |
+
raise ValueError("No input for DNA embedding: provide a peak FASTA (default binding_peaks_unique.fa) or set --genome-json-dir for chromosome JSONs.")
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
#Load TF sequences
|
| 550 |
+
tf_seqs, tf_ids = [], []
|
| 551 |
+
for record in SeqIO.parse(args.tf_fasta, "fasta"):
|
| 552 |
+
tf_ids.append(record.id)
|
| 553 |
+
tf_seqs.append(str(record.seq))
|
| 554 |
+
|
| 555 |
+
# embed and save
|
| 556 |
+
tf_embedder = get_embedder(args.tf_model, device, for_dna=False)
|
| 557 |
+
out_tf = Path(args.out_dir) / f"tf_{args.tf_model}.npy"
|
| 558 |
+
embed_and_save(tf_seqs, tf_ids, tf_embedder, out_tf)
|
| 559 |
+
|
| 560 |
+
print("Done.")
|
dpacman/data_tasks/fimo/__init__.py
ADDED
|
File without changes
|
dpacman/data_tasks/fimo/post_fimo.py
CHANGED
|
@@ -8,6 +8,7 @@ import multiprocessing as mp
|
|
| 8 |
import numpy as np
|
| 9 |
import pandas as pd
|
| 10 |
import math
|
|
|
|
| 11 |
import rootutils
|
| 12 |
import polars as pl
|
| 13 |
from omegaconf import DictConfig
|
|
@@ -37,10 +38,18 @@ def format_sig(sig_vals, decimals=4, atol=0.0, rtol=1e-5):
|
|
| 37 |
return ",".join(out.tolist())
|
| 38 |
|
| 39 |
def _safe_process(task):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
try:
|
| 41 |
-
|
|
|
|
| 42 |
except Exception as e:
|
| 43 |
-
return ("err", (
|
| 44 |
|
| 45 |
def discover_chrom_folders(fimo_out_dir: Path) -> list[str]:
|
| 46 |
return sorted(
|
|
@@ -51,6 +60,10 @@ def discover_chrom_folders(fimo_out_dir: Path) -> list[str]:
|
|
| 51 |
def _process_one_row(row, dna: str, jaspar_boost: int = 100) -> dict:
|
| 52 |
# row order: TR, chrom, cstart, cend, peak_s, peak_e, chipscore, jaspar
|
| 53 |
trname, chrom, cstart, cend, peak_s, peak_e, chipscore, jaspar = row
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
seq = dna[cstart:cend]
|
| 56 |
L = len(seq)
|
|
@@ -62,7 +75,7 @@ def _process_one_row(row, dna: str, jaspar_boost: int = 100) -> dict:
|
|
| 62 |
peak_seq = ""
|
| 63 |
if ps < L and pe > 0:
|
| 64 |
scores[max(ps, 0):min(pe, L)] = chipscore
|
| 65 |
-
peak_seq =
|
| 66 |
|
| 67 |
# JASPAR hits (+jaspar_boost)
|
| 68 |
# only run if the peak is not np.nan
|
|
@@ -92,7 +105,7 @@ def _process_one_row(row, dna: str, jaspar_boost: int = 100) -> dict:
|
|
| 92 |
|
| 93 |
def _process_one_chrom_folder(task) -> pd.DataFrame:
|
| 94 |
"""Runs inside a worker process. Reads one chrom’s final.csv, loads DNA once, builds records."""
|
| 95 |
-
chrom_folder, fimo_out_dir_str, json_dir, jaspar_boost, output_parts_folder = task
|
| 96 |
|
| 97 |
# make unique logger for this process
|
| 98 |
log_dir = Path(HydraConfig.get().run.dir) / "logs"
|
|
@@ -120,6 +133,13 @@ def _process_one_chrom_folder(task) -> pd.DataFrame:
|
|
| 120 |
|
| 121 |
if df.empty:
|
| 122 |
return pd.DataFrame()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
# Normalize dtypes up-front
|
| 125 |
df["#chrom"] = df["#chrom"].astype(str)
|
|
@@ -176,170 +196,342 @@ def _process_one_chrom_folder(task) -> pd.DataFrame:
|
|
| 176 |
|
| 177 |
return savepath
|
| 178 |
|
| 179 |
-
def build_dataset_fast_mp(fimo_out_dir: Path, json_dir: str, debug: bool, max_workers: int | None, jaspar_boost: int = 100, output_parts_folder: str = None) -> pd.DataFrame:
|
| 180 |
"""
|
| 181 |
Multiprocessing to build final dataset across chromosomes
|
| 182 |
"""
|
| 183 |
chrom_folders = discover_chrom_folders(fimo_out_dir)
|
| 184 |
if not chrom_folders:
|
| 185 |
logger.warning(f"No chrom* folders with final.csv under {fimo_out_dir}")
|
| 186 |
-
return
|
| 187 |
|
| 188 |
if debug:
|
| 189 |
-
|
|
|
|
| 190 |
logger.info(f"DEBUG MODE: considering {chrom_folders[0]} only")
|
| 191 |
|
| 192 |
-
tasks = [(cf, str(fimo_out_dir), json_dir, jaspar_boost, output_parts_folder
|
|
|
|
| 193 |
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
for t in tasks:
|
| 199 |
status, payload = _safe_process(t)
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
logger.info(f"Using {procs} parallel workers for {len(tasks)} chrom folders")
|
| 213 |
-
|
| 214 |
-
paths, errs = [], []
|
| 215 |
-
with mp.Pool(processes=procs, maxtasksperchild=10) as pool:
|
| 216 |
-
for status, payload in pool.imap_unordered(_safe_process, tasks, chunksize=1):
|
| 217 |
-
if status == "ok" and isinstance(payload, pd.DataFrame) and not payload.empty:
|
| 218 |
-
paths.append(payload)
|
| 219 |
-
else:
|
| 220 |
-
errs.append(payload)
|
| 221 |
-
|
| 222 |
if errs:
|
| 223 |
for chrom, msg, tb in errs:
|
| 224 |
logger.error("Worker error for %s: %s\n%s", chrom, msg, tb)
|
| 225 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
|
| 227 |
-
|
| 228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
def combine_processed_with_polars(
|
| 231 |
paths_to_processed_dfs: list[str],
|
| 232 |
idmap_path: str, # TSV with columns: From, Entry, Sequence
|
| 233 |
out_path: str, # e.g., "processed_out.parquet" or ".csv"
|
|
|
|
|
|
|
|
|
|
| 234 |
):
|
| 235 |
if not paths_to_processed_dfs:
|
| 236 |
logger.info("No records produced; nothing to write.")
|
| 237 |
return
|
| 238 |
|
| 239 |
-
# 1) Scan
|
| 240 |
-
lfs = [
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
|
| 248 |
-
#
|
| 249 |
idmap = (
|
| 250 |
pl.read_csv(idmap_path, separator="\t", columns=["From", "Entry", "Sequence"])
|
| 251 |
-
|
| 252 |
)
|
| 253 |
-
|
| 254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
# 4) Per-chromosome unique peak index and peak_id
|
| 257 |
# (dense rank over peak_sequence per chrom; if you require "first-appearance" order,
|
| 258 |
# see the note below for an alternate approach.)
|
| 259 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
pl.col("peak_sequence").fill_null("").alias("peak_sequence"),
|
| 261 |
pl.col("chrom").cast(pl.Utf8),
|
| 262 |
])
|
| 263 |
-
|
| 264 |
pl.col("peak_sequence")
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
)
|
| 270 |
-
|
| 271 |
-
pl.format("
|
| 272 |
)
|
| 273 |
-
logger.info(f"Assigned unique
|
| 274 |
|
| 275 |
# 5) Build stable IDs for dna_sequence and tr_sequence based on first appearance
|
| 276 |
# (do this by creating small maps with unique(..., maintain_order=True) and joining)
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
)
|
| 284 |
-
logger.info(f"Assigned
|
| 285 |
-
tr_map = (
|
| 286 |
-
lf.select("tr_sequence")
|
| 287 |
-
.unique(maintain_order=True)
|
| 288 |
-
.with_row_index("tr_idx", offset=1)
|
| 289 |
-
.with_columns(pl.format("trseq{}", pl.col("tr_idx")).alias("tr_seqid"))
|
| 290 |
-
.select("tr_sequence", "tr_seqid")
|
| 291 |
-
)
|
| 292 |
-
logger.info(f"Assigned tr_sequence IDs")
|
| 293 |
-
lf = lf.join(dna_map, on="dna_sequence", how="left").join(tr_map, on="tr_sequence", how="left")
|
| 294 |
-
logger.info(f"Applied dna_sequence and tr_sequence IDs to main table")
|
| 295 |
-
|
| 296 |
-
# 6) Final ID and column selection
|
| 297 |
-
lf = lf.with_columns(
|
| 298 |
-
(pl.col("tr_seqid") + pl.lit("_") + pl.col("dna_seqid")).alias("ID")
|
| 299 |
-
)
|
| 300 |
-
cols = [
|
| 301 |
-
"ID", "tr_name", "peak_id", "chipscore", "total_jaspar_hits",
|
| 302 |
-
"dna_sequence", "tr_sequence", "scores"
|
| 303 |
-
]
|
| 304 |
-
lf_out = lf.select(cols)
|
| 305 |
-
#n_rows = lf_out.select(pl.len().alias("rows")).collect(streaming=True)["rows"][0]
|
| 306 |
-
logger.info(f"Selected final columns")
|
| 307 |
-
|
| 308 |
-
# 7) Write streaming to disk
|
| 309 |
-
out_path = str(out_path)
|
| 310 |
-
Path(out_path).parent.mkdir(parents=True, exist_ok=True)
|
| 311 |
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
# NOTE: collect(streaming=True) still returns an in-memory DataFrame;
|
| 317 |
-
# prefer Parquet for very large outputs.
|
| 318 |
-
lf_out.collect(streaming=True).write_csv(out_path)
|
| 319 |
-
logger.info(f"Wrote csv file to {out_path}")
|
| 320 |
-
else:
|
| 321 |
-
# default to Parquet if no/unknown extension
|
| 322 |
-
lf_out.sink_parquet(out_path + ".parquet", compression="zstd", statistics=True)
|
| 323 |
-
logger.info(f"Wrote parquet file to {out_path}")
|
| 324 |
-
|
| 325 |
-
# 8) (Optional) small summary: unique peaks per chrom
|
| 326 |
-
peaks_per_chrom = (
|
| 327 |
-
lf.select("chrom", "peak_sequence")
|
| 328 |
-
.unique()
|
| 329 |
-
.group_by("chrom")
|
| 330 |
-
.len()
|
| 331 |
-
.collect(streaming=True)
|
| 332 |
-
.sort("chrom")
|
| 333 |
)
|
| 334 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
|
| 336 |
-
logger.info("Schema:")
|
| 337 |
-
for name, dtype in lf_out.schema.items():
|
| 338 |
-
logger.info(f" {name}: {dtype}")
|
| 339 |
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
|
| 344 |
# Save the FIRST 1000 rows to CSV (streaming-friendly)
|
| 345 |
df_first = lf_out.limit(1000).collect(streaming=True)
|
|
@@ -347,6 +539,197 @@ def combine_processed_with_polars(
|
|
| 347 |
df_first.write_csv(example_out_path)
|
| 348 |
logger.info(f"Wrote first 1000 rows to {example_out_path} as an example")
|
| 349 |
|
| 350 |
def main(cfg: DictConfig):
|
| 351 |
debug = bool(cfg.data_task.debug)
|
| 352 |
json_dir = cfg.data_task.json_dir
|
|
@@ -357,24 +740,43 @@ def main(cfg: DictConfig):
|
|
| 357 |
|
| 358 |
logger.info(f"Debug: {debug}")
|
| 359 |
logger.info(f"Reading per-chrom final.csv under: {fimo_out_dir}")
|
|
|
|
|
| 360 |
|
| 361 |
-
|
|
|
|
| 362 |
paths_to_processed_dfs = build_dataset_fast_mp(
|
| 363 |
fimo_out_dir=fimo_out_dir,
|
| 364 |
json_dir=json_dir,
|
| 365 |
debug=debug,
|
| 366 |
max_workers=max_workers,
|
| 367 |
jaspar_boost=cfg.data_task.jaspar_boost,
|
| 368 |
-
output_parts_folder=output_parts_folder
|
|
|
|
| 369 |
)
|
| 370 |
else:
|
| 371 |
paths_to_processed_dfs = [output_parts_folder/x for x in os.listdir(output_parts_folder)] if output_parts_folder.exists() else []
|
| 372 |
-
|
|
|
|
| 373 |
logger.info(f"Combining {len(paths_to_processed_dfs)} processed parts with Polars")
|
| 374 |
combine_processed_with_polars(
|
| 375 |
paths_to_processed_dfs=paths_to_processed_dfs,
|
| 376 |
-
idmap_path=
|
| 377 |
-
out_path=
|
|
|
|
|
|
|
| 378 |
)
|
| 379 |
|
| 380 |
# Delete the folder that had the temporary DFs, don't need these
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
import pandas as pd
|
| 10 |
import math
|
| 11 |
+
import json
|
| 12 |
import rootutils
|
| 13 |
import polars as pl
|
| 14 |
from omegaconf import DictConfig
|
|
|
|
| 38 |
return ",".join(out.tolist())
|
| 39 |
|
| 40 |
def _safe_process(task):
|
| 41 |
+
"""
|
| 42 |
+
Returns:
|
| 43 |
+
("ok", <path-to-output>) on success
|
| 44 |
+
("err", (chrom, msg, traceback)) on failure
|
| 45 |
+
"""
|
| 46 |
+
import traceback as tb
|
| 47 |
+
chrom = task[0]
|
| 48 |
try:
|
| 49 |
+
out_path = _process_one_chrom_folder(task) # MUST return a path (str/Path)
|
| 50 |
+
return ("ok", str(out_path))
|
| 51 |
except Exception as e:
|
| 52 |
+
return ("err", (chrom, repr(e), tb.format_exc()))
|
| 53 |
|
| 54 |
def discover_chrom_folders(fimo_out_dir: Path) -> list[str]:
|
| 55 |
return sorted(
|
|
|
|
| 60 |
def _process_one_row(row, dna: str, jaspar_boost: int = 100) -> dict:
|
| 61 |
# row order: TR, chrom, cstart, cend, peak_s, peak_e, chipscore, jaspar
|
| 62 |
trname, chrom, cstart, cend, peak_s, peak_e, chipscore, jaspar = row
|
| 63 |
+
|
| 64 |
+
# Very few chipscores are > 1000; standardize by capping them at 1000
|
| 65 |
+
if chipscore>=1000:
|
| 66 |
+
chipscore=1000
|
| 67 |
|
| 68 |
seq = dna[cstart:cend]
|
| 69 |
L = len(seq)
|
|
|
|
| 75 |
peak_seq = ""
|
| 76 |
if ps < L and pe > 0:
|
| 77 |
scores[max(ps, 0):min(pe, L)] = chipscore
|
| 78 |
+
peak_seq = seq[max(ps, 0):min(pe, L)]
|
| 79 |
|
| 80 |
# JASPAR hits (+jaspar_boost)
|
| 81 |
# only run if the peak is not np.nan
|
|
|
|
| 105 |
|
| 106 |
def _process_one_chrom_folder(task) -> pd.DataFrame:
|
| 107 |
"""Runs inside a worker process. Reads one chrom’s final.csv, loads DNA once, builds records."""
|
| 108 |
+
chrom_folder, fimo_out_dir_str, json_dir, jaspar_boost, output_parts_folder, keep_fimo_only = task
|
| 109 |
|
| 110 |
# make unique logger for this process
|
| 111 |
log_dir = Path(HydraConfig.get().run.dir) / "logs"
|
|
|
|
| 133 |
|
| 134 |
if df.empty:
|
| 135 |
return pd.DataFrame()
|
| 136 |
+
|
| 137 |
+
if keep_fimo_only:
|
| 138 |
+
logger.info(f"keep_fimo_only=True. Starting with {len(df)} rows.")
|
| 139 |
+
df = df.loc[
|
| 140 |
+
~df["jaspar"].isna()
|
| 141 |
+
].reset_index(drop=True)
|
| 142 |
+
logger.info(f"After keeping fimo hits only: {len(df)} rows remain.")
|
| 143 |
|
| 144 |
# Normalize dtypes up-front
|
| 145 |
df["#chrom"] = df["#chrom"].astype(str)
|
|
|
|
| 196 |
|
| 197 |
return savepath
|
| 198 |
|
| 199 |
+
def build_dataset_fast_mp(fimo_out_dir: Path, json_dir: str, debug: bool, max_workers: int | None, jaspar_boost: int = 100, output_parts_folder: str = None, keep_fimo_only: bool=False) -> pd.DataFrame:
|
| 200 |
"""
|
| 201 |
Multiprocessing to build final dataset across chromosomes
|
| 202 |
"""
|
| 203 |
chrom_folders = discover_chrom_folders(fimo_out_dir)
|
| 204 |
if not chrom_folders:
|
| 205 |
logger.warning(f"No chrom* folders with final.csv under {fimo_out_dir}")
|
| 206 |
+
return []
|
| 207 |
|
| 208 |
if debug:
|
| 209 |
+
# keep chromY if present, otherwise just the first
|
| 210 |
+
chrom_folders = [c for c in chrom_folders if c == "chromY"] or chrom_folders[:1]
|
| 211 |
logger.info(f"DEBUG MODE: considering {chrom_folders[0]} only")
|
| 212 |
|
| 213 |
+
tasks = [(cf, str(fimo_out_dir), json_dir, jaspar_boost, output_parts_folder, keep_fimo_only)
|
| 214 |
+
for cf in chrom_folders]
|
| 215 |
|
| 216 |
+
def _collect(status, payload, good_paths, errs):
|
| 217 |
+
if status == "ok":
|
| 218 |
+
p = Path(payload)
|
| 219 |
+
if p.exists():
|
| 220 |
+
good_paths.append(p)
|
| 221 |
+
else:
|
| 222 |
+
errs.append(("?", f"output missing: {p}", ""))
|
| 223 |
+
else:
|
| 224 |
+
chrom, msg, tb = payload
|
| 225 |
+
errs.append((chrom, msg, tb))
|
| 226 |
+
|
| 227 |
+
# Serial path (debug/deterministic or single task)
|
| 228 |
+
if (max_workers is not None and max_workers <= 1) or len(tasks) == 1:
|
| 229 |
+
good_paths: list[Path] = []
|
| 230 |
+
errs: list[tuple[str, str, str]] = []
|
| 231 |
for t in tasks:
|
| 232 |
status, payload = _safe_process(t)
|
| 233 |
+
_collect(status, payload, good_paths, errs)
|
| 234 |
+
else:
|
| 235 |
+
# Parallel path
|
| 236 |
+
procs = min(max_workers or mp.cpu_count(), len(tasks))
|
| 237 |
+
logger.info(f"Using {procs} parallel workers for {len(tasks)} chrom folders")
|
| 238 |
+
|
| 239 |
+
good_paths: list[Path] = []
|
| 240 |
+
errs: list[tuple[str, str, str]] = []
|
| 241 |
+
with mp.Pool(processes=procs, maxtasksperchild=10) as pool:
|
| 242 |
+
for status, payload in pool.imap_unordered(_safe_process, tasks, chunksize=1):
|
| 243 |
+
_collect(status, payload, good_paths, errs)
|
| 244 |
+
|
|
|
|
|
| 245 |
if errs:
|
| 246 |
for chrom, msg, tb in errs:
|
| 247 |
logger.error("Worker error for %s: %s\n%s", chrom, msg, tb)
|
| 248 |
+
# Optionally: raise RuntimeError("One or more workers failed; see logs.")
|
| 249 |
+
|
| 250 |
+
return [str(p) for p in good_paths]
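The serial and parallel paths above both funnel worker results through the same ("ok" | "err", payload) convention before `_collect` sorts them into good paths and errors. A minimal, self-contained sketch of that multiprocessing pattern (the `work` function and task list below are made up for illustration, not project code):

```python
# Minimal sketch of the imap_unordered + status-tuple pattern used above.
# `work` and `tasks` are hypothetical stand-ins.
import multiprocessing as mp
import traceback

def work(task):
    try:
        return ("ok", task * 2)  # pretend this writes a part file and returns its path
    except Exception as e:
        return ("err", (task, repr(e), traceback.format_exc()))

if __name__ == "__main__":
    tasks = list(range(8))
    good, errs = [], []
    with mp.Pool(processes=4, maxtasksperchild=10) as pool:
        for status, payload in pool.imap_unordered(work, tasks, chunksize=1):
            (good if status == "ok" else errs).append(payload)
    print(sorted(good), errs)
```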
|
| 251 |
+
|
| 252 |
+
def dedup_trname_peakseq_weighted(lf: pl.LazyFrame, seed: int = 42, outdir: str | None = None) -> pl.LazyFrame:
|
| 253 |
+
"""
|
| 254 |
+
Remove duplicate pairings of TR + peak sequence while keeping the chromosome distribution as close to the original as possible.
|
| 255 |
+
Use a seed so the results are reproducible
|
| 256 |
+
"""
|
| 257 |
+
# Normalize key dtypes
|
| 258 |
+
lf = lf.with_columns([
|
| 259 |
+
pl.col("chrom").cast(pl.Utf8),
|
| 260 |
+
pl.col("tr_name").cast(pl.Utf8),
|
| 261 |
+
pl.col("peak_sequence").fill_null("<NULL>").cast(pl.Utf8),
|
| 262 |
+
])
|
| 263 |
+
|
| 264 |
+
# --- BEFORE: counts/ratios (materialize tiny table)
|
| 265 |
+
pre_df = (
|
| 266 |
+
lf.group_by("chrom").len()
|
| 267 |
+
.with_columns((pl.col("len") / pl.col("len").sum()).alias("pre_ratio"))
|
| 268 |
+
.sort("chrom")
|
| 269 |
+
.collect()
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
# Expected #groups (must equal result rows)
|
| 273 |
+
exp_groups = lf.select(["tr_name","peak_sequence"]).unique().collect().height
|
| 274 |
+
logger.info(f"Expected groups: {exp_groups}")
|
| 275 |
+
|
| 276 |
+
# Tiny weights table back to lazy
|
| 277 |
+
pre_lf = pre_df.lazy().select(["chrom", "pre_ratio"]).rename({"pre_ratio": "w"})
|
| 278 |
+
|
| 279 |
+
# Weighted random score = log(w) + Gumbel(0,1); tie-break by a stable hash
|
| 280 |
+
TWO64 = 18446744073709551616.0
|
| 281 |
+
eps = 1e-12
|
| 282 |
+
|
| 283 |
+
h_expr = pl.concat_str(
|
| 284 |
+
[pl.lit(f"seed:{seed}"), pl.col("tr_name"), pl.col("peak_sequence"), pl.col("chrom")],
|
| 285 |
+
separator="|",
|
| 286 |
+
).hash().cast(pl.UInt64)
|
| 287 |
+
|
| 288 |
+
u_expr = (h_expr.cast(pl.Float64) + 1.0) / pl.lit(TWO64)
|
| 289 |
+
u_expr = pl.when(u_expr < eps).then(eps).when(u_expr > 1 - eps).then(1 - eps).otherwise(u_expr)
|
| 290 |
|
| 291 |
+
logw_expr = pl.when(pl.col("w").is_null() | (pl.col("w") <= 0)).then(eps).otherwise(pl.col("w")).log()
|
| 292 |
+
gumbel_expr = -(-u_expr.log()).log()
|
| 293 |
+
score_expr = (logw_expr + gumbel_expr).alias("_score")
|
| 294 |
+
hash_expr = h_expr.alias("_h")
|
| 295 |
+
|
| 296 |
+
# Attach weights & scores, globally sort, then unique on the keys (keep first)
|
| 297 |
+
lf_sorted = (
|
| 298 |
+
lf.join(pre_lf, on="chrom", how="left")
|
| 299 |
+
.with_columns([score_expr, hash_expr])
|
| 300 |
+
.sort(["_score","_h"], descending=[True, False])
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
lf_sel = lf_sorted.unique(subset=["tr_name","peak_sequence"], keep="first") \
|
| 304 |
+
.drop(["w","_score","_h"])
|
| 305 |
+
|
| 306 |
+
# --- AFTER: counts/ratios + save
|
| 307 |
+
post_df = (
|
| 308 |
+
lf_sel.group_by("chrom").len()
|
| 309 |
+
.with_columns((pl.col("len") / pl.col("len").sum()).alias("post_ratio"))
|
| 310 |
+
.sort("chrom")
|
| 311 |
+
.collect(streaming=True)
|
| 312 |
+
)
|
| 313 |
+
compare_df = (pre_df.select(["chrom","len","pre_ratio"]).rename({"len":"pre_n"})
|
| 314 |
+
.join(post_df.select(["chrom","len","post_ratio"]).rename({"len":"post_n"}),
|
| 315 |
+
on="chrom", how="full")
|
| 316 |
+
.fill_null(0)
|
| 317 |
+
.with_columns((pl.col("post_ratio") - pl.col("pre_ratio")).abs().alias("abs_delta"),
|
| 318 |
+
(100*(pl.col("post_ratio") - pl.col("pre_ratio"))/pl.col("pre_ratio")).abs().alias("pcnt_delta"))
|
| 319 |
+
.sort("chrom")
|
| 320 |
+
).to_pandas().drop(columns=["chrom_right"])
|
| 321 |
+
# --- Sanity: must keep exactly one per group
|
| 322 |
+
got_rows = lf_sel.select(pl.len()).collect()["len"][0]
|
| 323 |
+
if got_rows != exp_groups:
|
| 324 |
+
# optional: raise or just log
|
| 325 |
+
logger.warning(f"Dedup cardinality mismatch: expected {exp_groups}, got {got_rows}")
|
| 326 |
+
|
| 327 |
+
return lf_sel, compare_df
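The weighted dedup above relies on the Gumbel-max trick: each candidate row gets a score of log(weight) plus Gumbel noise, and keeping the per-group argmax samples one row with probability proportional to its chromosome weight, deterministically because the noise is derived from a hash. A rough standalone sketch of the idea, using a seeded RNG in place of the hash-derived uniform (made-up names, not project code):

```python
# Gumbel-max weighted sampling: argmax of log(w) - log(-log(u)) picks an item
# with probability proportional to w.
import math
import random
from collections import Counter

def gumbel_pick(candidates, weights, seed):
    rng = random.Random(seed)
    best, best_score = None, float("-inf")
    for cand, w in zip(candidates, weights):
        u = min(max(rng.random(), 1e-12), 1 - 1e-12)
        score = math.log(max(w, 1e-12)) - math.log(-math.log(u))
        if score > best_score:
            best, best_score = cand, score
    return best

# Repeating the draw with different seeds approximates the 3:1 weight ratio.
picks = Counter(gumbel_pick(["chr1", "chr2"], [0.75, 0.25], seed=s) for s in range(10000))
print(picks)  # roughly {'chr1': ~7500, 'chr2': ~2500}
```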
|
| 328 |
+
|
| 329 |
+
def write_map(lf: pl.LazyFrame, out_path: str, key: str, val: str, outname: str):
|
| 330 |
+
"""
|
| 331 |
+
Write the ID maps we created, spanning all the data. Will be called for:
|
| 332 |
+
- tr_seqid to tr_sequence
|
| 333 |
+
- peak_seqid to peak_sequence
|
| 334 |
+
- dna_seqid to dna_sequence
|
| 335 |
+
"""
|
| 336 |
+
maps_dir = Path(out_path).parent / "maps"
|
| 337 |
+
maps_dir.mkdir(parents=True, exist_ok=True)
|
| 338 |
+
|
| 339 |
+
df = (
|
| 340 |
+
lf.select([pl.col(key), pl.col(val)])
|
| 341 |
+
.unique()
|
| 342 |
+
.collect(streaming=True)
|
| 343 |
+
)
|
| 344 |
+
mapping = dict(zip(df[key].to_list(), df[val].to_list()))
|
| 345 |
+
with open(maps_dir / outname, "w") as f:
|
| 346 |
+
json.dump(mapping, f, indent=2)
|
| 347 |
+
|
| 348 |
|
| 349 |
def combine_processed_with_polars(
|
| 350 |
paths_to_processed_dfs: list[str],
|
| 351 |
idmap_path: str, # TSV with columns: From, Entry, Sequence
|
| 352 |
out_path: str, # e.g., "processed_out.parquet" or ".csv"
|
| 353 |
+
max_protein_len: int = None,
|
| 354 |
+
check_violations: bool = False,
|
| 355 |
+
seeds: list = [0]
|
| 356 |
):
|
| 357 |
if not paths_to_processed_dfs:
|
| 358 |
logger.info("No records produced; nothing to write.")
|
| 359 |
return
|
| 360 |
|
| 361 |
+
# 1) Scan each CSV and normalize dtypes BEFORE concat
|
| 362 |
+
lfs = []
|
| 363 |
+
for p in paths_to_processed_dfs:
|
| 364 |
+
lf_i = (
|
| 365 |
+
pl.scan_csv(p) # don't use infer_schema_length=0 here
|
| 366 |
+
.with_columns([
|
| 367 |
+
pl.col("chrom").cast(pl.Utf8),
|
| 368 |
+
pl.col("tr_name").cast(pl.Utf8),
|
| 369 |
+
pl.col("dna_sequence").cast(pl.Utf8),
|
| 370 |
+
pl.col("peak_sequence").cast(pl.Utf8),
|
| 371 |
+
pl.col("scores").cast(pl.Utf8),
|
| 372 |
+
pl.col("chipscore").cast(pl.Float64), # robust (int -> float OK)
|
| 373 |
+
pl.col("total_jaspar_hits").cast(pl.Int64),
|
| 374 |
+
])
|
| 375 |
+
)
|
| 376 |
+
lfs.append(lf_i)
|
| 377 |
+
|
| 378 |
+
# 2) Now concat; schemas match
|
| 379 |
+
lf_og = pl.concat(lfs, how="vertical")
|
| 380 |
|
| 381 |
+
# Read idmap to get list of unmapped TRs, those with no sequence
|
| 382 |
idmap = (
|
| 383 |
pl.read_csv(idmap_path, separator="\t", columns=["From", "Entry", "Sequence"])
|
| 384 |
+
.rename({"From": "tr_name", "Entry": "tr_uniprot", "Sequence": "tr_sequence"})
|
| 385 |
)
|
| 386 |
+
idmap = idmap.with_columns(
|
| 387 |
+
pl.col("tr_sequence").map_elements(lambda x: len(x), return_dtype=pl.Int64).alias("tr_len")
|
| 388 |
+
)
|
| 389 |
+
if max_protein_len is not None:
|
| 390 |
+
idmap = idmap.filter(
|
| 391 |
+
pl.col("tr_len")<=max_protein_len
|
| 392 |
+
)
|
| 393 |
+
logger.info(f"Filtered valid TRs to only those with len <= {max_protein_len}")
|
| 394 |
|
| 395 |
+
success_trs = list(idmap["tr_name"].unique())
|
| 396 |
+
logger.info(f"Total valid TRs: {len(success_trs)}")
|
| 397 |
+
|
| 398 |
+
# Filter lf to only have "success-TRs"
|
| 399 |
+
lf_og = lf_og.filter(pl.col("tr_name").is_in(success_trs))
|
| 400 |
+
|
| 401 |
+
# 2) We COULD drop duplicate occurrences of tr_name and peak_sequence, because these are the same peak. But here we're showing them in different contexts
|
| 402 |
+
# Instead, let's drop duplicate occurrences of the same tr_name and dna_sequence, because these are duplicate datapoints.
|
| 403 |
+
lf_out = None # the last one will be used to save the example file
|
| 404 |
+
out_path = str(out_path)
|
| 405 |
+
Path(out_path).parent.mkdir(parents=True, exist_ok=True)
|
| 406 |
+
|
| 407 |
+
lf_og = lf_og.join(idmap.lazy(), on="tr_name", how="left")
|
| 408 |
+
logger.info(f"Merged in UniProt IDs and TR sequences from UniProt ID mappping")
|
| 409 |
+
|
| 410 |
# 4) Per-chromosome unique peak index and peak_id
|
| 411 |
# (dense rank over peak_sequence per chrom; if you require "first-appearance" order,
|
| 412 |
# see the note below for an alternate approach.)
|
| 413 |
+
# Ensure types
|
| 414 |
+
lf_og = lf_og.with_columns([
|
| 415 |
+
pl.col("dna_sequence").cast(pl.Utf8),
|
| 416 |
+
pl.col("tr_sequence").cast(pl.Utf8),
|
| 417 |
+
pl.col("peak_sequence").cast(pl.Utf8)
|
| 418 |
+
])
|
| 419 |
+
|
| 420 |
+
# set chrom+ peak sequence IDs
|
| 421 |
+
lf_og = lf_og.with_columns([
|
| 422 |
pl.col("peak_sequence").fill_null("").alias("peak_sequence"),
|
| 423 |
pl.col("chrom").cast(pl.Utf8),
|
| 424 |
])
|
| 425 |
+
lf_og = lf_og.with_columns(
|
| 426 |
pl.col("peak_sequence")
|
| 427 |
+
.rank(method="dense") # 1,2,3,... per group
|
| 428 |
+
.over("chrom")
|
| 429 |
+
.cast(pl.Int64)
|
| 430 |
+
.alias("chrom_peak_idx")
|
| 431 |
)
|
| 432 |
+
lf_og = lf_og.with_columns(
|
| 433 |
+
pl.format("chr{}_peak{}", pl.col("chrom"), pl.col("chrom_peak_idx")).alias("chrpeak_id")
|
| 434 |
)
|
| 435 |
+
logger.info(f"Assigned unique chrpeak_ids per chromosome based on peak_sequence")
|
| 436 |
|
| 437 |
# 5) Build stable IDs for dna_sequence and tr_sequence based on first appearance
|
| 438 |
# (do this by creating small maps with unique(..., maintain_order=True) and joining)
|
| 439 |
+
# Sequence-based IDs without any joins
|
| 440 |
+
lf_og = (
|
| 441 |
+
lf_og.with_columns([
|
| 442 |
+
pl.col("dna_sequence").rank(method="dense").cast(pl.Int64).alias("dna_idx"),
|
| 443 |
+
pl.col("tr_sequence").rank(method="dense").cast(pl.Int64).alias("tr_idx"),
|
| 444 |
+
pl.col("peak_sequence").rank(method="dense").cast(pl.Int64).alias("peak_idx"),
|
| 445 |
+
])
|
| 446 |
+
.with_columns([
|
| 447 |
+
pl.format("dnaseq{}", pl.col("dna_idx")).alias("dna_seqid"),
|
| 448 |
+
pl.format("trseq{}", pl.col("tr_idx")).alias("tr_seqid"),
|
| 449 |
+
pl.format("peakseq{}", pl.col("peak_idx")).alias("peak_seqid"),
|
| 450 |
+
])
|
| 451 |
+
.drop(["dna_idx", "tr_idx", "peak_idx"])
|
| 452 |
)
|
| 453 |
+
logger.info(f"Assigned unique dna IDs, transcriptional regulator IDs, and peak IDs")
|
|
|
|
| 454 |
|
| 455 |
+
# Final ID (will never be None now)
|
| 456 |
+
lf_og = lf_og.with_columns(
|
| 457 |
+
pl.concat_str([pl.col("tr_seqid"), pl.lit("_"), pl.col("dna_seqid")], ignore_nulls=False)
|
| 458 |
+
.alias("ID")
|
|
|
|
| 459 |
)
|
| 460 |
+
|
| 461 |
+
# Write the maps
|
| 462 |
+
# call it for each mapping
|
| 463 |
+
write_map(lf_og, out_path=out_path, val="tr_sequence", key="tr_seqid", outname="tr_seqid_to_tr_sequence.json")
|
| 464 |
+
write_map(lf_og, out_path=out_path, val="peak_sequence", key="peak_seqid", outname="peak_seqid_to_peak_sequence.json")
|
| 465 |
+
write_map(lf_og, out_path=out_path, val="dna_sequence", key="dna_seqid", outname="dna_seqid_to_dna_sequence.json")
|
| 466 |
+
|
| 467 |
+
for seed in seeds:
|
| 468 |
+
# edit out path to include seed
|
| 469 |
+
if "." in out_path:
|
| 470 |
+
out_path_full = out_path[0:out_path.rindex(".")] + f"_seed{seed}" + out_path[out_path.rindex("."):]
|
| 471 |
+
else:
|
| 472 |
+
out_path_full = out_path + f"_seed{seed}.parquet"
|
| 473 |
+
|
| 474 |
+
lf, compare_df = dedup_trname_peakseq_weighted(lf_og, seed=seed)
|
| 475 |
+
logger.info(f"Dropped duplicate examples of tr_name + peak_sequence. Maintained chrom distribution with weighted random sampling (seed={seed}).")
|
| 476 |
+
|
| 477 |
+
# Save comparison df. Annotate with debug if it's a debug run
|
| 478 |
+
compare_df_path = str(Path(out_path).parent/f"chrom_ratio_compare_seed{seed}.csv")
|
| 479 |
+
if "debug" in out_path: compare_df_path = compare_df_path.replace(".csv", "_debug.csv")
|
| 480 |
+
compare_df.to_csv(compare_df_path, index=False)
|
| 481 |
+
|
| 482 |
+
# 3) Join small idmap (read eagerly; it’s tiny)
|
| 483 |
+
lf = lf.join(idmap.lazy(), on="tr_name", how="left")
|
| 484 |
+
logger.info(f"Merged in UniProt IDs and TR sequences from UniProt ID mappping")
|
| 485 |
|
|
|
|
|
|
|
|
|
|
| 486 |
|
| 487 |
+
logger.info(f"Applied dna_sequence and tr_sequence IDs to main table")
|
| 488 |
+
|
| 489 |
+
# Each sequence maps to exactly one id
|
| 490 |
+
viol1 = (
|
| 491 |
+
lf.select("dna_sequence", "dna_seqid").unique()
|
| 492 |
+
.group_by("dna_sequence").agg(pl.n_unique("dna_seqid").alias("n_ids"))
|
| 493 |
+
.filter(pl.col("n_ids") > 1)
|
| 494 |
+
.collect()
|
| 495 |
+
)
|
| 496 |
+
# Each id maps to exactly one sequence
|
| 497 |
+
viol2 = (
|
| 498 |
+
lf.select("dna_sequence", "dna_seqid").unique()
|
| 499 |
+
.group_by("dna_seqid").agg(pl.n_unique("dna_sequence").alias("n_seqs"))
|
| 500 |
+
.filter(pl.col("n_seqs") > 1)
|
| 501 |
+
.collect()
|
| 502 |
+
)
|
| 503 |
+
logger.info("viol1 rows (seq→>1 id): %d; viol2 rows (id→>1 seq): %d", viol1.height, viol2.height)
|
| 504 |
+
|
| 505 |
+
# No NULLs
|
| 506 |
+
nulls = lf.select([
|
| 507 |
+
pl.col("dna_seqid").is_null().sum().alias("null_dna_seqid"),
|
| 508 |
+
pl.col("tr_seqid").is_null().sum().alias("null_tr_seqid"),
|
| 509 |
+
pl.col("ID").is_null().sum().alias("null_ID"),
|
| 510 |
+
]).collect()
|
| 511 |
+
logger.info("NULL counts:\n%s", nulls)
|
| 512 |
+
|
| 513 |
+
# 6) Final column selection
|
| 514 |
+
cols = [
|
| 515 |
+
"ID", "tr_seqid", "dna_seqid", "peak_seqid", "chrpeak_id", "tr_name", "chipscore", "total_jaspar_hits",
|
| 516 |
+
"dna_sequence", "tr_sequence", "scores"
|
| 517 |
+
]
|
| 518 |
+
lf_out = lf.select(cols)
|
| 519 |
+
#n_rows = lf_out.select(pl.len().alias("rows")).collect(streaming=True)["rows"][0]
|
| 520 |
+
logger.info(f"Selected final columns")
|
| 521 |
+
|
| 522 |
+
# 7) Write streaming to disk
|
| 523 |
+
if out_path_full.lower().endswith(".parquet"):
|
| 524 |
+
lf_out.sink_parquet(out_path_full, compression="zstd", statistics=True, row_group_size=128_000)
|
| 525 |
+
logger.info(f"Wrote parquet file to {out_path_full}")
|
| 526 |
+
elif out_path_full.lower().endswith(".csv"):
|
| 527 |
+
# NOTE: collect(streaming=True) still returns an in-memory DataFrame;
|
| 528 |
+
# prefer Parquet for very large outputs.
|
| 529 |
+
lf_out.collect(streaming=True).write_csv(out_path_full)
|
| 530 |
+
logger.info(f"Wrote csv file to {out_path_full}")
|
| 531 |
+
else:
|
| 532 |
+
# default to Parquet if no/unknown extension
|
| 533 |
+
lf_out.sink_parquet(out_path_full + ".parquet", compression="zstd", statistics=True)
|
| 534 |
+
logger.info(f"Wrote parquet file to {out_path_full}")
|
| 535 |
|
| 536 |
# Save the FIRST 1000 rows to CSV (streaming-friendly)
|
| 537 |
df_first = lf_out.limit(1000).collect(streaming=True)
|
|
|
|
| 539 |
df_first.write_csv(example_out_path)
|
| 540 |
logger.info(f"Wrote first 1000 rows to {example_out_path} as an example")
|
| 541 |
|
| 542 |
+
# FIMO check
|
| 543 |
+
def get_reverse_complement(s):
|
| 544 |
+
"""
|
| 545 |
+
Returns 5' to 3' sequence of the reverse complement
|
| 546 |
+
"""
|
| 547 |
+
chars = list(s)
|
| 548 |
+
recon = []
|
| 549 |
+
rev_map = {
|
| 550 |
+
"a": "t",
|
| 551 |
+
"c": "g",
|
| 552 |
+
"t": "a",
|
| 553 |
+
"g": "c",
|
| 554 |
+
"A":"T",
|
| 555 |
+
"C": "G",
|
| 556 |
+
"T": "A",
|
| 557 |
+
"G": "C",
|
| 558 |
+
"n": "n",
|
| 559 |
+
"N": "N"
|
| 560 |
+
}
|
| 561 |
+
for c in chars:
|
| 562 |
+
recon += [rev_map[c]]
|
| 563 |
+
|
| 564 |
+
recon = "".join(recon)
|
| 565 |
+
return recon[::-1]
|
| 566 |
+
|
| 567 |
+
def extract_jaspar_motifs(row, reverse_complement=False):
|
| 568 |
+
s = row["scores"]
|
| 569 |
+
s = [int(x) for x in s.split(",")]
|
| 570 |
+
n_motifs = row["total_jaspar_hits"]
|
| 571 |
+
if n_motifs==0:
|
| 572 |
+
return ""
|
| 573 |
+
chipscore = row["chipscore"]
|
| 574 |
+
dna_seq = row["dna_sequence"]
|
| 575 |
+
if reverse_complement:
|
| 576 |
+
dna_seq = row["dna_sequence_rc"]
|
| 577 |
+
jaspar_indices = [i for i in list(range(len(s))) if s[i]>chipscore]
|
| 578 |
+
|
| 579 |
+
pred_motif = ""
|
| 580 |
+
for i in list(range(jaspar_indices[0], jaspar_indices[-1]+1)):
|
| 581 |
+
if not(i in jaspar_indices):
|
| 582 |
+
pred_motif += "-"
|
| 583 |
+
else:
|
| 584 |
+
pred_motif += dna_seq[i]
|
| 585 |
+
|
| 586 |
+
return pred_motif
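Assuming `extract_jaspar_motifs` as defined above is in scope, a worked example with a made-up row shows how positions whose score exceeds the chipscore (the jaspar-boosted positions) are stitched into the predicted motif, with "-" filling interior gaps:

```python
# Made-up row: chipscore 500, boosted scores at positions 2, 3 and 5.
row = {
    "scores": "500,500,600,600,500,600,500",
    "total_jaspar_hits": 1,
    "chipscore": 500,
    "dna_sequence": "ACGTACG",
}
print(extract_jaspar_motifs(row))  # -> "GT-C": bases at 2 and 3, a gap at 4, base at 5
```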
|
| 587 |
+
|
| 588 |
+
def clean_idmap(idmap_path):
|
| 589 |
+
"""
|
| 590 |
+
The raw ID Map from UniProt returned multiple results.
|
| 591 |
+
We went to ReMap and wrote down what the right mappings are in these cases.
|
| 592 |
+
"""
|
| 593 |
+
|
| 594 |
+
manual_map = {
|
| 595 |
+
"BACH1": "O14867",
|
| 596 |
+
"BAP1": "Q92560",
|
| 597 |
+
"BDP1": "A6H8Y1",
|
| 598 |
+
"BRF1": "Q92994",
|
| 599 |
+
"CUX1": "Q13948",
|
| 600 |
+
"DDX21": "Q9NR30",
|
| 601 |
+
"ERG": "P11308",
|
| 602 |
+
"HBP1": "O60381",
|
| 603 |
+
"KLF14": "Q8TD94",
|
| 604 |
+
"MED1": "Q15648",
|
| 605 |
+
"MED25": "Q71SY5",
|
| 606 |
+
"MGA": "Q8IWI9",
|
| 607 |
+
"NRF1": "Q16656",
|
| 608 |
+
"PAF1": "Q8N7H5",
|
| 609 |
+
"PDX1": "P52945",
|
| 610 |
+
"RBP2": "P50120",
|
| 611 |
+
"RLF": "Q13129",
|
| 612 |
+
"SP1": "P08047",
|
| 613 |
+
"SPIN1": "Q9Y657",
|
| 614 |
+
"STAG1": "Q8WVM7",
|
| 615 |
+
"TAF15": "Q92804",
|
| 616 |
+
"TCF3": "P15923",
|
| 617 |
+
"ZFP36": "P26651",
|
| 618 |
+
"EVI1": "Q03112",
|
| 619 |
+
"MCM2": "P49736"
|
| 620 |
+
}
|
| 621 |
+
idmap = pd.read_csv(idmap_path, sep="\t")
|
| 622 |
+
idmap["Remap_Entry"] = idmap.apply(lambda row: row["Entry"] if not(row["From"] in manual_map) else manual_map[row["From"]], axis=1)
|
| 623 |
+
idmap_remapped = idmap.loc[
|
| 624 |
+
idmap["Entry"]==idmap["Remap_Entry"]
|
| 625 |
+
].reset_index(drop=True).drop(columns=["Remap_Entry"])
|
| 626 |
+
|
| 627 |
+
assert len(idmap_remapped)==len(idmap_remapped["From"].unique())
|
| 628 |
+
logger.info(f"Total transcriptional regulators successfully mapped in UniProt: {len(idmap_remapped)}")
|
| 629 |
+
|
| 630 |
+
clean_idmap_path = Path(root)/"dpacman/data_files/processed/remap/idmapping_reviewed_true_processed_2025_08_11.tsv"
|
| 631 |
+
idmap_remapped.to_csv(clean_idmap_path, sep="\t")
|
| 632 |
+
return clean_idmap_path
|
| 633 |
+
|
| 634 |
+
def debug_fimo_check(path_to_chrom_fimo, path_to_processed_chrom, chrom="Y", json_dir=""):
|
| 635 |
+
"""
|
| 636 |
+
Make sure we are properly extracting fimo sequences.
|
| 637 |
+
"""
|
| 638 |
+
processed = pd.read_csv(path_to_processed_chrom)
|
| 639 |
+
processed["pred_motif_string"] = processed.apply(lambda row: extract_jaspar_motifs(row), axis=1)
|
| 640 |
+
processed["dna_sequence_rc"] = processed["dna_sequence"].apply(lambda x: get_reverse_complement(x))
|
| 641 |
+
processed["pred_motif_string_rc"] = processed.apply(lambda row: extract_jaspar_motifs(row, reverse_complement=True), axis=1)
|
| 642 |
+
processed_trs = processed["tr_name"].unique().tolist()
|
| 643 |
+
|
| 644 |
+
fimo = pd.read_csv(path_to_chrom_fimo)
|
| 645 |
+
fimo["input_tr"] = fimo["sequence_name"].str.split("_",expand=True)[2]
|
| 646 |
+
fimo_valid = fimo.loc[
|
| 647 |
+
(fimo["motif_alt_id"]==fimo["input_tr"]) &
|
| 648 |
+
(fimo["motif_alt_id"].isin(processed_trs))
|
| 649 |
+
].reset_index(drop=True)
|
| 650 |
+
logger.info(f"Total valid FIMO matches: {len(fimo_valid)}")
|
| 651 |
+
logger.info(f"Total transcriptional regulators being considered: {len(processed_trs)}")
|
| 652 |
+
|
| 653 |
+
# Load DNA
|
| 654 |
+
cache_debug = load_chrom_dna(chrom, {}, json_dir=json_dir)
|
| 655 |
+
|
| 656 |
+
# Randomly select a positive and negative row to test
|
| 657 |
+
pos_row = fimo_valid.loc[fimo_valid["strand"]=="+"].sample(n=1, random_state=44)
|
| 658 |
+
neg_row = fimo_valid.loc[fimo_valid["strand"]=="-"].sample(n=1, random_state=44)
|
| 659 |
+
|
| 660 |
+
# Iterate through the rows
|
| 661 |
+
for row in [pos_row, neg_row]:
|
| 662 |
+
indices = [int(x) for x in row["sequence_name"].item().split("_")[-2::]]
|
| 663 |
+
chipseq_start, chipseq_end = indices
|
| 664 |
+
strand = row['strand'].item()
|
| 665 |
+
logger.info(f"ChIPseq start: {chipseq_start}, ChIPseq end: {chipseq_end}")
|
| 666 |
+
logger.info(f"Strand: {strand}")
|
| 667 |
+
|
| 668 |
+
motif_start = chipseq_start + int(row["start"].item()) -1
|
| 669 |
+
motif_end = chipseq_start + int(row["stop"].item())
|
| 670 |
+
motif = row['matched_sequence'].item()
|
| 671 |
+
full_seq = cache_debug[chipseq_start:chipseq_end].upper()
|
| 672 |
+
logger.info(f"Full sequence: {full_seq}")
|
| 673 |
+
logger.info(f"Full sequence reverse complement: {get_reverse_complement(full_seq)}")
|
| 674 |
+
our_motif = cache_debug[motif_start:motif_end].upper()
|
| 675 |
+
|
| 676 |
+
if strand=="+":
|
| 677 |
+
logger.info(f"True motif found by FIMO: {motif}")
|
| 678 |
+
logger.info(f"Extracted motif on our end: {our_motif}")
|
| 679 |
+
logger.info(f"Correct extraction: {motif==our_motif}")
|
| 680 |
+
|
| 681 |
+
matching_rows = processed.loc[
|
| 682 |
+
(processed["dna_sequence"].str.contains(full_seq)) &
|
| 683 |
+
(processed["pred_motif_string"].str.contains(our_motif))
|
| 684 |
+
]
|
| 685 |
+
if strand=="-":
|
| 686 |
+
our_motif_rc = get_reverse_complement(our_motif)
|
| 687 |
+
|
| 688 |
+
logger.info(f"True motif found by FIMO: {motif}")
|
| 689 |
+
logger.info(f"Extracted motif on our end: {our_motif_rc}")
|
| 690 |
+
logger.info(f"Correct extraction: {motif==our_motif_rc}")
|
| 691 |
+
logger.info(f"Motif that will appear in the forward sequence: {our_motif}")
|
| 692 |
+
matching_rows = processed.loc[
|
| 693 |
+
(processed["dna_sequence"].str.contains(full_seq)) &
|
| 694 |
+
(processed["pred_motif_string"].str.contains(our_motif))
|
| 695 |
+
]
|
| 696 |
+
|
| 697 |
+
# Now find if there are rows with the same TR, and the same DNA sequence and motif
|
| 698 |
+
matching_row_trs = sorted(matching_rows["tr_name"].unique().tolist())
|
| 699 |
+
expected_tr = row['motif_alt_id'].item()
|
| 700 |
+
logger.info(f"TR from selected row: {expected_tr}")
|
| 701 |
+
logger.info(f"TRs with same motif: {','.join(matching_row_trs)}")
|
| 702 |
+
logger.info(f"Expected TR in list: {expected_tr in matching_row_trs}")
|
| 703 |
+
|
| 704 |
+
def debug_remap_check(remap_path, path_to_processed_chrom, chrom="Y", json_dir=""):
|
| 705 |
+
"""
|
| 706 |
+
For debugging mode: pick a random row from the processed ReMap file and make sure the sequence matches the one we extract here.
|
| 707 |
+
"""
|
| 708 |
+
remap = pd.read_csv(remap_path)
|
| 709 |
+
remap["ChIPStart"] = remap["ChIPStart"].astype(int)
|
| 710 |
+
remap["ChIPEnd"] = remap["ChIPEnd"].astype(int)
|
| 711 |
+
|
| 712 |
+
row = remap.loc[remap["#chrom"]=="Y"].sample(n=1, random_state=42)
|
| 713 |
+
|
| 714 |
+
start, end = row["ChIPStart"].item(), row["ChIPEnd"].item()
|
| 715 |
+
|
| 716 |
+
cache_debug = load_chrom_dna(chrom, {}, json_dir=json_dir)
|
| 717 |
+
test_seq = cache_debug[start:end].upper()
|
| 718 |
+
logger.info(f"Randomly sampled sequence ({len(test_seq)} nucleotides), chrY {start}:{end}\n\tsequence: {test_seq}")
|
| 719 |
+
should_find = remap.loc[
|
| 720 |
+
(remap["ChIPStart"]==start) &
|
| 721 |
+
(remap["ChIPEnd"]==end)
|
| 722 |
+
]["TR"].unique().tolist()
|
| 723 |
+
logger.info(f"Expect to find {len(should_find)} TRs: {', '.join(sorted(should_find))}")
|
| 724 |
+
|
| 725 |
+
processed = pd.read_csv(path_to_processed_chrom)
|
| 726 |
+
did_find = processed.loc[
|
| 727 |
+
processed["peak_sequence"]==test_seq
|
| 728 |
+
]["tr_name"].unique().tolist()
|
| 729 |
+
logger.info(f"Looked up same sequence in processed chrY file.\nFound TRs: {', '.join(sorted(did_find))}")
|
| 730 |
+
logger.info(f"found==expected: {did_find==should_find}")
|
| 731 |
+
|
| 732 |
+
|
| 733 |
def main(cfg: DictConfig):
|
| 734 |
debug = bool(cfg.data_task.debug)
|
| 735 |
json_dir = cfg.data_task.json_dir
|
|
|
|
| 740 |
|
| 741 |
logger.info(f"Debug: {debug}")
|
| 742 |
logger.info(f"Reading per-chrom final.csv under: {fimo_out_dir}")
|
| 743 |
+
|
| 744 |
+
# process the idmap
|
| 745 |
+
idmap_path=Path(root) / cfg.data_task.idmap_path
|
| 746 |
+
clean_idmap_path = clean_idmap(idmap_path)
|
| 747 |
|
| 748 |
+
# If we don't have temp files to process
|
| 749 |
+
if not(os.path.exists(output_parts_folder)) or (os.path.exists(output_parts_folder) and len(os.listdir(output_parts_folder))<24):
|
| 750 |
paths_to_processed_dfs = build_dataset_fast_mp(
|
| 751 |
fimo_out_dir=fimo_out_dir,
|
| 752 |
json_dir=json_dir,
|
| 753 |
debug=debug,
|
| 754 |
max_workers=max_workers,
|
| 755 |
jaspar_boost=cfg.data_task.jaspar_boost,
|
| 756 |
+
output_parts_folder=output_parts_folder,
|
| 757 |
+
keep_fimo_only=cfg.data_task.keep_fimo_only
|
| 758 |
)
|
| 759 |
else:
|
| 760 |
paths_to_processed_dfs = [output_parts_folder/x for x in os.listdir(output_parts_folder)] if output_parts_folder.exists() else []
|
| 761 |
+
|
| 762 |
+
# Debug methods: (1) make sure our peak sequences correspond to remap, (2) make sure our FIMO sequences correspond to FIMO results
|
| 763 |
+
out_path = str(processed_output_csv).replace(".csv", ".parquet")
|
| 764 |
+
if debug:
|
| 765 |
+
debug_remap_check(remap_path=Path(root) / cfg.data_task.remap_path,
|
| 766 |
+
path_to_processed_chrom=Path(output_parts_folder)/"chromY_processed.csv",
|
| 767 |
+
chrom="Y", json_dir=json_dir)
|
| 768 |
+
debug_fimo_check(path_to_chrom_fimo=Path(root) / cfg.data_task.fimo_out_dir / "chromY" / "fimo_annotations.csv",
|
| 769 |
+
path_to_processed_chrom=Path(output_parts_folder)/"chromY_processed.csv",
|
| 770 |
+
chrom="Y", json_dir=json_dir)
|
| 771 |
+
out_path = out_path.replace(".parquet", "_debug.parquet")
|
| 772 |
+
|
| 773 |
logger.info(f"Combining {len(paths_to_processed_dfs)} processed parts with Polars")
|
| 774 |
combine_processed_with_polars(
|
| 775 |
paths_to_processed_dfs=paths_to_processed_dfs,
|
| 776 |
+
idmap_path=clean_idmap_path,
|
| 777 |
+
out_path=out_path,
|
| 778 |
+
max_protein_len=cfg.data_task.max_protein_len,
|
| 779 |
+
seeds=cfg.data_task.seeds
|
| 780 |
)
|
| 781 |
|
| 782 |
# Delete the folder that had the temporary DFs, don't need these
|
dpacman/data_tasks/fimo/run_fimo.py
CHANGED
|
@@ -152,17 +152,26 @@ def run_fimo_chunk(cfg):
|
|
| 152 |
def annotate_with_fimo(df, fdf):
|
| 153 |
df = df.reset_index().rename(columns={"index":"idx"})
|
| 154 |
df["sequence_name"] = df["idx"].astype(str) + "_chr" + df["#chrom"] + "_" + df["TR"] + "_" + df["contextStart"].astype(str) + "_" + df["contextEnd"].astype(str) #construt it the same way as headers
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
fdf["
|
| 158 |
-
|
| 159 |
-
fdf["
|
|
|
|
|
| 160 |
)
|
| 161 |
-
|
|
|
|
| 162 |
df["jaspar"] = df["sequence_name"].map(agg).fillna("")
|
| 163 |
return df
|
| 164 |
|
| 165 |
-
|
| 166 |
def main(cfg: DictConfig):
|
| 167 |
"""
|
| 168 |
Main method for running FIMO analysis, searching JASPAR motifs against ChIP-seq peaks
|
|
|
|
| 152 |
def annotate_with_fimo(df, fdf):
|
| 153 |
df = df.reset_index().rename(columns={"index":"idx"})
|
| 154 |
df["sequence_name"] = df["idx"].astype(str) + "_chr" + df["#chrom"] + "_" + df["TR"] + "_" + df["contextStart"].astype(str) + "_" + df["contextEnd"].astype(str) #construt it the same way as headers
|
| 155 |
+
|
| 156 |
+
# Crucial: filter FDF results to only rows where the TF whose motif was found actually matches the TF that was detected there.
|
| 157 |
+
fdf["input_tr"] = fdf["sequence_name"].str.split("_",expand=True)[2]
|
| 158 |
+
true_matches = fdf.loc[
|
| 159 |
+
fdf["motif_alt_id"]==fdf["input_tr"]
|
| 160 |
+
].reset_index(drop=True)
|
| 161 |
+
logger.info(f"Length of full returned FIMO results: {len(fdf)}")
|
| 162 |
+
logger.info(f"Length of true matches, where the FIMO tr and the input tr match: {len(true_matches)}")
|
| 163 |
+
|
| 164 |
+
true_matches = true_matches.merge(df[["sequence_name", "contextStart"]], on="sequence_name", how="left")
|
| 165 |
+
true_matches["genomic_start"] = true_matches["contextStart"] + true_matches["start"] - 1
|
| 166 |
+
true_matches["genomic_end"] = true_matches["contextStart"] + true_matches["stop"]
|
| 167 |
+
true_matches["coord"] = (
|
| 168 |
+
true_matches["genomic_start"].astype(str) + "-" + true_matches["genomic_end"].astype(str)
|
| 169 |
)
|
| 170 |
+
|
| 171 |
+
agg = true_matches.groupby("sequence_name")["coord"].agg(lambda hits: ",".join(hits))
|
| 172 |
df["jaspar"] = df["sequence_name"].map(agg).fillna("")
|
| 173 |
return df
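The coordinate arithmetic above assumes FIMO reports start/stop as 1-based, inclusive positions within the scanned context sequence, so converting to 0-based, half-open genomic coordinates uses contextStart + start - 1 and contextStart + stop. A small worked check with hypothetical numbers:

```python
# Hypothetical numbers: a context window starting at genomic position 1000 (0-based)
# and a FIMO hit at start=5, stop=12 (1-based, inclusive within the window).
context_start, fimo_start, fimo_stop = 1000, 5, 12
genomic_start = context_start + fimo_start - 1   # 1004 (0-based)
genomic_end = context_start + fimo_stop          # 1012 (half-open)
context = "N" * 4 + "ACGTTGCA" + "N" * 10        # the hit occupies window offsets 4..11
assert context[fimo_start - 1:fimo_stop] == "ACGTTGCA"
assert genomic_end - genomic_start == len("ACGTTGCA")
print(genomic_start, genomic_end)
```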
|
| 174 |
|
|
|
|
| 175 |
def main(cfg: DictConfig):
|
| 176 |
"""
|
| 177 |
Main method for running FIMO analysis, searching JASPAR motifs against ChIP-seq peaks
|
dpacman/data_tasks/split/__init__.py
ADDED
|
File without changes
|
dpacman/data_tasks/split/remap.py
ADDED
|
@@ -0,0 +1,512 @@
|
|
|
|
|
| 1 |
+
from collections import Counter, defaultdict
|
| 2 |
+
from ortools.linear_solver import pywraplp
|
| 3 |
+
import random
|
| 4 |
+
import logging
|
| 5 |
+
from omegaconf import DictConfig
|
| 6 |
+
import rootutils
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import os
|
| 10 |
+
import numpy as np
|
| 11 |
+
from sklearn.model_selection import train_test_split
|
| 12 |
+
|
| 13 |
+
root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
def split_bipartite_fast(
|
| 17 |
+
dna_clusters,
|
| 18 |
+
split_names=("train","val","test"),
|
| 19 |
+
ratios=(0.8,0.1,0.1),
|
| 20 |
+
):
|
| 21 |
+
# use sklearn
|
| 22 |
+
test_size_1 = 0.2
|
| 23 |
+
test_size_2 = 0.5
|
| 24 |
+
logger.info(f"\tPerforming first split: all clusters -> train clusters ({round(1-test_size_1,3)}) and other ({test_size_1})")
|
| 25 |
+
X = dna_clusters
|
| 26 |
+
y = [0]*len(dna_clusters)
|
| 27 |
+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size_1, random_state=0)
|
| 28 |
+
logger.info(f"\tPerforming second split: other -> val clusters ({round(1-test_size_2,3)}) and test clusters ({test_size_2})")
|
| 29 |
+
X_val, X_test, y_val, y_test = train_test_split(X_test, y_test, test_size=test_size_2, random_state=0)
|
| 30 |
+
|
| 31 |
+
dna_assign = {}
|
| 32 |
+
for x in X_train:
|
| 33 |
+
dna_assign[x] = "train"
|
| 34 |
+
for x in X_val:
|
| 35 |
+
dna_assign[x] = "val"
|
| 36 |
+
for x in X_test:
|
| 37 |
+
dna_assign[x] = "test"
|
| 38 |
+
|
| 39 |
+
kept_by_split = {'train': len(X_train), 'val': len(X_val), 'test': len(X_test)}
|
| 40 |
+
return dna_assign, kept_by_split
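The two-stage sklearn split above reaches roughly 80/10/10 by holding out 20% first and then cutting that holdout in half. A quick check with made-up cluster IDs:

```python
# 1000 clusters -> 800 train, then the 200 held-out clusters split 50/50 into val/test.
from sklearn.model_selection import train_test_split

clusters = [f"dnaclust{i}" for i in range(1000)]
train, rest = train_test_split(clusters, test_size=0.2, random_state=0)
val, test = train_test_split(rest, test_size=0.5, random_state=0)
print(len(train), len(val), len(test))  # 800 100 100
```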
|
| 41 |
+
|
| 42 |
+
def split_bipartite_with_ratios_and_leaky(
|
| 43 |
+
edges,
|
| 44 |
+
split_names=("train","val","test"),
|
| 45 |
+
ratios=(0.8, 0.1, 0.1),
|
| 46 |
+
require_nonempty=False,
|
| 47 |
+
ratio_tolerance=None, # None = soft ratios only; 0.0 = exact band (use with care)
|
| 48 |
+
bigM=None,
|
| 49 |
+
shuffle_within_pair=False,
|
| 50 |
+
seed=0,
|
| 51 |
+
test_edges_must=None, # NEW: list of (tf,dna) with duplicates OR dict {(tf,dna): count}
|
| 52 |
+
):
|
| 53 |
+
"""
|
| 54 |
+
edges: list of (tf_cluster_id, dna_cluster_id). Duplicates allowed (-> weights).
|
| 55 |
+
test_edges_must: None, list of pairs, or dict {(tf,dna): required_count}.
|
| 56 |
+
- If a pair appears with required_count > 0, at least that many examples MUST be kept in TEST.
|
| 57 |
+
- This implicitly pins both clusters of that pair to TEST (cluster exclusivity).
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
tf_assign: {tf_cluster -> split}
|
| 61 |
+
dna_assign: {dna_cluster -> split}
|
| 62 |
+
kept_by_split: {split -> kept_count} (train/val/test only)
|
| 63 |
+
total_kept: int
|
| 64 |
+
split_to_indices: {split -> [input indices]} including 'leaky_test'
|
| 65 |
+
split_to_edges: {split -> [(tf,dna), ...]} including 'leaky_test'
|
| 66 |
+
"""
|
| 67 |
+
# Aggregate counts per pair
|
| 68 |
+
w = Counter(edges)
|
| 69 |
+
tfs = {t for (t, _) in w}
|
| 70 |
+
dnas = {d for (_, d) in w}
|
| 71 |
+
S = list(split_names)
|
| 72 |
+
rs = dict(zip(S, ratios))
|
| 73 |
+
N = sum(w.values())
|
| 74 |
+
if bigM is None:
|
| 75 |
+
bigM = 1000 * max(1, N)
|
| 76 |
+
|
| 77 |
+
# Index original edges so we can return a per-example split
|
| 78 |
+
pair_to_indices = defaultdict(list)
|
| 79 |
+
for idx, (c, d) in enumerate(edges):
|
| 80 |
+
pair_to_indices[(c, d)].append(idx)
|
| 81 |
+
|
| 82 |
+
if shuffle_within_pair:
|
| 83 |
+
rng = random.Random(seed)
|
| 84 |
+
for key in pair_to_indices:
|
| 85 |
+
rng.shuffle(pair_to_indices[key])
|
| 86 |
+
|
| 87 |
+
# Parse required test edges
|
| 88 |
+
req_test = Counter()
|
| 89 |
+
if test_edges_must:
|
| 90 |
+
if isinstance(test_edges_must, dict):
|
| 91 |
+
for k, v in test_edges_must.items():
|
| 92 |
+
if not isinstance(k, tuple) or len(k) != 2:
|
| 93 |
+
raise ValueError("test_edges_must dict keys must be (tf_cluster, dna_cluster)")
|
| 94 |
+
if v < 0:
|
| 95 |
+
raise ValueError("required_count must be non-negative")
|
| 96 |
+
if v:
|
| 97 |
+
req_test[k] += int(v)
|
| 98 |
+
else:
|
| 99 |
+
# assume iterable of pairs
|
| 100 |
+
req_test = Counter(test_edges_must)
|
| 101 |
+
# Validate against available counts
|
| 102 |
+
for pair, req in req_test.items():
|
| 103 |
+
if pair not in w:
|
| 104 |
+
raise ValueError(f"Required test pair {pair} not present in edges.")
|
| 105 |
+
if req > w[pair]:
|
| 106 |
+
raise ValueError(
|
| 107 |
+
f"Required count {req} for {pair} exceeds available {w[pair]}."
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
# Build solver
|
| 111 |
+
solver = pywraplp.Solver.CreateSolver("CBC")
|
| 112 |
+
if solver is None:
|
| 113 |
+
raise RuntimeError("Could not create CBC solver.")
|
| 114 |
+
|
| 115 |
+
# Binary cluster assignments
|
| 116 |
+
x = {(c,s): solver.BoolVar(f"x[{c},{s}]") for c in tfs for s in S}
|
| 117 |
+
y = {(d,s): solver.BoolVar(f"y[{d},{s}]") for d in dnas for s in S}
|
| 118 |
+
|
| 119 |
+
# Each cluster in exactly one split
|
| 120 |
+
for c in tfs:
|
| 121 |
+
solver.Add(sum(x[c,s] for s in S) == 1)
|
| 122 |
+
for d in dnas:
|
| 123 |
+
solver.Add(sum(y[d,s] for s in S) == 1)
|
| 124 |
+
|
| 125 |
+
# Integer kept counts per pair and split (allow partial within-pair)
|
| 126 |
+
k = {((c,d),s): solver.IntVar(0, w[(c,d)], f"k[{c},{d},{s}]") for (c,d) in w for s in S}
|
| 127 |
+
|
| 128 |
+
# Only keep in split s if both endpoint clusters are assigned to s
|
| 129 |
+
for (c,d), wt in w.items():
|
| 130 |
+
for s in S:
|
| 131 |
+
solver.Add(k[((c,d),s)] <= wt * x[c,s])
|
| 132 |
+
solver.Add(k[((c,d),s)] <= wt * y[d,s])
|
| 133 |
+
|
| 134 |
+
# Enforce minimum kept counts in TEST for required pairs
|
| 135 |
+
for (c,d), req in req_test.items():
|
| 136 |
+
solver.Add(k[((c,d), "test")] >= req)
|
| 137 |
+
|
| 138 |
+
# Optional: ensure each split has at least one cluster (feasibility depends on counts)
|
| 139 |
+
if require_nonempty:
|
| 140 |
+
for s in S:
|
| 141 |
+
solver.Add(
|
| 142 |
+
sum(x[c,s] for c in tfs) + sum(y[d,s] for d in dnas) >= 1
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
# Kept counts per split and total
|
| 146 |
+
K = {s: solver.IntVar(0, N, f"K[{s}]") for s in S}
|
| 147 |
+
for s in S:
|
| 148 |
+
solver.Add(K[s] == sum(k[((c,d),s)] for (c,d) in w))
|
| 149 |
+
T = solver.IntVar(0, N, "T")
|
| 150 |
+
solver.Add(T == sum(K[s] for s in S))
|
| 151 |
+
|
| 152 |
+
# Ratio deviation: K_s - r_s * T = d+ - d-
|
| 153 |
+
dpos = {s: solver.NumVar(0, solver.infinity(), f"dpos[{s}]") for s in S}
|
| 154 |
+
dneg = {s: solver.NumVar(0, solver.infinity(), f"dneg[{s}]") for s in S}
|
| 155 |
+
for s in S:
|
| 156 |
+
solver.Add(K[s] - rs[s]*T == dpos[s] - dneg[s])
|
| 157 |
+
|
| 158 |
+
# Optional hard band around target ratios
|
| 159 |
+
if ratio_tolerance is not None:
|
| 160 |
+
eps = float(ratio_tolerance)
|
| 161 |
+
for s in S:
|
| 162 |
+
solver.Add(K[s] >= (rs[s] - eps) * T)
|
| 163 |
+
solver.Add(K[s] <= (rs[s] + eps) * T)
|
| 164 |
+
|
| 165 |
+
# Objective: maximize T then minimize total deviation
|
| 166 |
+
obj = solver.Objective()
|
| 167 |
+
obj.SetMaximization()
|
| 168 |
+
obj.SetCoefficient(T, float(bigM))
|
| 169 |
+
for s in S:
|
| 170 |
+
obj.SetCoefficient(dpos[s], -1.0)
|
| 171 |
+
obj.SetCoefficient(dneg[s], -1.0)
|
| 172 |
+
|
| 173 |
+
status = solver.Solve()
|
| 174 |
+
if status not in (pywraplp.Solver.OPTIMAL, pywraplp.Solver.FEASIBLE):
|
| 175 |
+
raise RuntimeError("No feasible solution (check ratio_tolerance vs. required test edges).")
|
| 176 |
+
|
| 177 |
+
# Read cluster assignments
|
| 178 |
+
tf_assign = {c: next(s for s in S if x[c,s].solution_value() > 0.5) for c in tfs}
|
| 179 |
+
dna_assign = {d: next(s for s in S if y[d,s].solution_value() > 0.5) for d in dnas}
|
| 180 |
+
|
| 181 |
+
# Kept counts per split
|
| 182 |
+
kept_by_split = {s: int(round(K[s].solution_value())) for s in S}
|
| 183 |
+
total_kept = int(round(T.solution_value()))
|
| 184 |
+
|
| 185 |
+
# ---- Build per-example split assignment (including 'leaky_test') ----
|
| 186 |
+
split_to_indices = {s: [] for s in S}
|
| 187 |
+
remaining_indices = {pair: list(pair_to_indices[pair]) for pair in pair_to_indices}
|
| 188 |
+
|
| 189 |
+
# Allocate the kept examples per split (train/val/test)
|
| 190 |
+
for (c,d), wt in w.items():
|
| 191 |
+
for s in S:
|
| 192 |
+
cnt = int(round(k[((c,d),s)].solution_value()))
|
| 193 |
+
if cnt > 0:
|
| 194 |
+
take = remaining_indices[(c,d)][:cnt]
|
| 195 |
+
split_to_indices[s].extend(take)
|
| 196 |
+
remaining_indices[(c,d)] = remaining_indices[(c,d)][cnt:]
|
| 197 |
+
|
| 198 |
+
# Everything left becomes leaky_test
|
| 199 |
+
leaky_indices = []
|
| 200 |
+
for pair, idxs in remaining_indices.items():
|
| 201 |
+
if idxs:
|
| 202 |
+
leaky_indices.extend(idxs)
|
| 203 |
+
|
| 204 |
+
split_to_indices["leaky_test"] = leaky_indices
|
| 205 |
+
split_to_edges = {s: [edges[i] for i in split_to_indices[s]] for s in split_to_indices}
|
| 206 |
+
|
| 207 |
+
return tf_assign, dna_assign, kept_by_split, total_kept, split_to_indices, split_to_edges
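A hypothetical toy call to `split_bipartite_with_ratios_and_leaky` (requires ortools; the cluster names are invented): duplicates in the edge list act as example counts, and `test_edges_must` pins at least that many examples of a pair, and therefore both of its clusters, into the test split.

```python
# Toy bipartite graph of (tf_cluster, dna_cluster) edges; duplicates are example counts.
edges = [
    ("tfA", "dna1"), ("tfA", "dna1"), ("tfA", "dna2"),
    ("tfB", "dna3"), ("tfB", "dna4"),
    ("tfC", "dna5"), ("tfC", "dna5"), ("tfC", "dna6"),
    ("tfD", "dna7"),
]
tf_assign, dna_assign, kept, total, idx, split_edges = split_bipartite_with_ratios_and_leaky(
    edges,
    ratios=(0.8, 0.1, 0.1),
    test_edges_must={("tfD", "dna7"): 1},  # force at least one tfD/dna7 example into test
)
print(kept, total)
print(idx["leaky_test"])  # examples dropped from train/val/test to preserve exclusivity
```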
|
| 208 |
+
|
| 209 |
+
from collections import Counter, defaultdict
|
| 210 |
+
import random
|
| 211 |
+
|
| 212 |
+
class DSU:
|
| 213 |
+
def __init__(self): self.p = {}
|
| 214 |
+
def find(self, x):
|
| 215 |
+
if x not in self.p: self.p[x] = x
|
| 216 |
+
while self.p[x] != x:
|
| 217 |
+
self.p[x] = self.p[self.p[x]]
|
| 218 |
+
x = self.p[x]
|
| 219 |
+
return x
|
| 220 |
+
def union(self, a,b):
|
| 221 |
+
ra, rb = self.find(a), self.find(b)
|
| 222 |
+
if ra != rb: self.p[rb] = ra
|
| 223 |
+
|
| 224 |
+
def split_bipartite_by_components(
|
| 225 |
+
edges,
|
| 226 |
+
split_names=("train","val","test"),
|
| 227 |
+
ratios=(0.8,0.1,0.1),
|
| 228 |
+
seed=0,
|
| 229 |
+
require_nonempty=False,
|
| 230 |
+
test_edges_must=None, # None, list[(tf,dna)], or dict{(tf,dna): count}
|
| 231 |
+
):
|
| 232 |
+
"""
|
| 233 |
+
Guarantees exclusivity: each TF cluster and DNA cluster appears in at most one split.
|
| 234 |
+
Strategy: find connected components in the TF–DNA bipartite graph and assign components wholesale.
|
| 235 |
+
"""
|
| 236 |
+
rng = random.Random(seed)
|
| 237 |
+
w = Counter(edges) # multiplicities per pair
|
| 238 |
+
if not w: raise ValueError("No edges.")
|
| 239 |
+
|
| 240 |
+
# 1) Build components with Union-Find (prefix to keep TF/DNA namespaces disjoint)
|
| 241 |
+
dsu = DSU()
|
| 242 |
+
for (tf, dna) in w:
|
| 243 |
+
dsu.union(("T", tf), ("D", dna))
|
| 244 |
+
comp_pairs = defaultdict(list)
|
| 245 |
+
comp_weight = defaultdict(int)
|
| 246 |
+
for (tf, dna), cnt in w.items():
|
| 247 |
+
root = dsu.find(("T", tf)) # component id = root of TF endpoint
|
| 248 |
+
comp_pairs[root].append((tf, dna))
|
| 249 |
+
comp_weight[root] += cnt
|
| 250 |
+
|
| 251 |
+
comps = list(comp_pairs.keys())
|
| 252 |
+
C = len(comps)
|
| 253 |
+
S = list(split_names)
|
| 254 |
+
rs = dict(zip(S, ratios))
|
| 255 |
+
N = sum(comp_weight[c] for c in comps)
|
| 256 |
+
target = {s: int(round(rs[s] * N)) for s in S}
|
| 257 |
+
|
| 258 |
+
# 2) Pin components that contain required TEST pairs
|
| 259 |
+
pinned = {} # comp_root -> pinned_split ("test")
|
| 260 |
+
if test_edges_must:
|
| 261 |
+
req = Counter(test_edges_must) if not isinstance(test_edges_must, dict) else Counter(test_edges_must)
|
| 262 |
+
# Map each required pair to its component, ensure feasibility
|
| 263 |
+
for (tf, dna), r in req.items():
|
| 264 |
+
if (tf, dna) not in w:
|
| 265 |
+
raise ValueError(f"Required pair {(tf,dna)} not present.")
|
| 266 |
+
if r > w[(tf, dna)]:
|
| 267 |
+
raise ValueError(f"Required count {r} for {(tf,dna)} exceeds available {w[(tf,dna)]}.")
|
| 268 |
+
comp = dsu.find(("T", tf))
|
| 269 |
+
if comp in pinned and pinned[comp] != "test":
|
| 270 |
+
raise ValueError(f"Component conflict: already pinned to {pinned[comp]}, but {(tf,dna)} demands test.")
|
| 271 |
+
pinned[comp] = "test"
|
| 272 |
+
# NOTE: pinning a pair pins the WHOLE component to test (to keep exclusivity).
|
| 273 |
+
# If you only want some edges kept in test and discard the rest, handle below when materializing.
|
| 274 |
+
|
| 275 |
+
# 3) Assign components greedily by deficit
|
| 276 |
+
kept_by_split = {s: 0 for s in S}
|
| 277 |
+
comp_assign = {} # comp_root -> split
|
| 278 |
+
|
| 279 |
+
# First assign pinned comps
|
| 280 |
+
for comp, split in pinned.items():
|
| 281 |
+
comp_assign[comp] = split
|
| 282 |
+
kept_by_split[split] += comp_weight[comp]
|
| 283 |
+
|
| 284 |
+
# Sort remaining components by descending weight
|
| 285 |
+
remaining = [c for c in comps if c not in comp_assign]
|
| 286 |
+
remaining.sort(key=lambda c: comp_weight[c], reverse=True)
|
| 287 |
+
|
| 288 |
+
# Ensure nonempty splits if requested (seed with largest remaining comps)
|
| 289 |
+
if require_nonempty:
|
| 290 |
+
seeds = remaining[:min(len(S), len(remaining))]
|
| 291 |
+
for comp, s in zip(seeds, S):
|
| 292 |
+
comp_assign[comp] = s
|
| 293 |
+
kept_by_split[s] += comp_weight[comp]
|
| 294 |
+
remaining = [c for c in remaining if c not in comp_assign]
|
| 295 |
+
|
| 296 |
+
for comp in remaining:
|
| 297 |
+
# choose split with largest deficit (target - current)
|
| 298 |
+
deficits = {s: target[s] - kept_by_split[s] for s in S}
|
| 299 |
+
best = max(deficits, key=lambda s: deficits[s])
|
| 300 |
+
comp_assign[comp] = best
|
| 301 |
+
kept_by_split[best] += comp_weight[comp]
|
| 302 |
+
|
| 303 |
+
total_kept = sum(kept_by_split.values())
|
| 304 |
+
|
| 305 |
+
# 4) Materialize per-example indices (and verify exclusivity)
|
| 306 |
+
pair_to_indices = defaultdict(list)
|
| 307 |
+
for idx, pair in enumerate(edges):
|
| 308 |
+
pair_to_indices[pair].append(idx)
|
| 309 |
+
|
| 310 |
+
split_to_indices = {s: [] for s in S}
|
| 311 |
+
for comp, s in comp_assign.items():
|
| 312 |
+
for pair in comp_pairs[comp]:
|
| 313 |
+
split_to_indices[s].extend(pair_to_indices[pair])
|
| 314 |
+
|
| 315 |
+
# Optional: if you pinned a comp due to a small 'must-test' count but
|
| 316 |
+
# want to *discard* the rest instead of keeping them in test, uncomment:
|
| 317 |
+
# for comp, s in comp_assign.items():
|
| 318 |
+
# if s == "test" and test_edges_must:
|
| 319 |
+
# # Keep only the required counts; dump extras to 'leaky_test'
|
| 320 |
+
# ...
|
| 321 |
+
# (Left out for clarity; default is: keep the whole component in its split.)
|
| 322 |
+
|
| 323 |
+
# 5) Build edge lists and simple cluster assignments
|
| 324 |
+
split_to_edges = {s: [edges[i] for i in split_to_indices[s]] for s in split_to_indices}
|
| 325 |
+
tf_assign, dna_assign = {}, {}
|
| 326 |
+
for comp, s in comp_assign.items():
|
| 327 |
+
for (tf, dna) in comp_pairs[comp]:
|
| 328 |
+
tf_assign[tf] = s
|
| 329 |
+
dna_assign[dna] = s
|
| 330 |
+
|
| 331 |
+
# 6) Safety check: no DNA/TF appears in multiple splits
|
| 332 |
+
tf_in_split = defaultdict(set)
|
| 333 |
+
dna_in_split = defaultdict(set)
|
| 334 |
+
for s, elist in split_to_edges.items():
|
| 335 |
+
for tf, dna in elist:
|
| 336 |
+
tf_in_split[tf].add(s)
|
| 337 |
+
dna_in_split[dna].add(s)
|
| 338 |
+
dup_tf = {tf: ss for tf, ss in tf_in_split.items() if len(ss) > 1}
|
| 339 |
+
dup_dna = {dn: ss for dn, ss in dna_in_split.items() if len(ss) > 1}
|
| 340 |
+
assert not dup_tf and not dup_dna, f"Exclusivity violated: {dup_tf} {dup_dna}"
|
| 341 |
+
|
| 342 |
+
return tf_assign, dna_assign, kept_by_split, total_kept, split_to_indices, split_to_edges
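A hypothetical example for `split_bipartite_by_components` (invented cluster names): because tfA and tfB share dna2, the union-find step places them in one connected component, so the whole component, and therefore both TF clusters, lands in a single split.

```python
# Three components: {tfA, tfB} (3 examples), {tfC} (7 examples), {tfD} (10 examples).
edges = (
    [("tfA", "dna1"), ("tfA", "dna2"), ("tfB", "dna2")]
    + [("tfC", f"dna{i}") for i in range(3, 10)]
    + [("tfD", f"dna{i}") for i in range(10, 20)]
)
tf_assign, dna_assign, kept, total, idx, split_edges = split_bipartite_by_components(
    edges, ratios=(0.8, 0.1, 0.1), seed=0, require_nonempty=True
)
assert tf_assign["tfA"] == tf_assign["tfB"]  # same component, same split
print(kept, total)
```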
|
| 343 |
+
|
| 344 |
+
def print_split_ratios(kept_by_split):
|
| 345 |
+
total = sum(kept_by_split.values())
|
| 346 |
+
train_pcnt = 100 * kept_by_split["train"] / total
|
| 347 |
+
val_pcnt = 100 * kept_by_split["val"] / total
|
| 348 |
+
test_pcnt = 100 * kept_by_split["test"] / total
|
| 349 |
+
logger.info(f"Cluster distribution - Train: {train_pcnt:.2f}%, Val: {val_pcnt:.2f}%, Test: {test_pcnt:.2f}%")
|
| 350 |
+
|
| 351 |
+
def make_edges(processed_fimo_path: str, protein_cluster_path: str, dna_cluster_path: str):
|
| 352 |
+
"""
|
| 353 |
+
Make edges for input to the splitting algorithm. Each edge is a (tr_cluster_rep, dna_cluster_rep) tuple, where the cluster rep is the representative sequence ID.
|
| 354 |
+
"""
|
| 355 |
+
# Read cluster data
|
| 356 |
+
protein_clusters = pd.read_csv(protein_cluster_path, header=None,sep="\t")
|
| 357 |
+
protein_clusters.columns=["tr_cluster_rep","tr_seqid"]
|
| 358 |
+
|
| 359 |
+
dna_clusters = pd.read_csv(dna_cluster_path, header=None,sep="\t")
|
| 360 |
+
dna_clusters.columns=["dna_cluster_rep","dna_seqid"]
|
| 361 |
+
|
| 362 |
+
# Read datapoints
|
| 363 |
+
edges = pd.read_parquet(processed_fimo_path)
|
| 364 |
+
edges = pd.merge(edges, dna_clusters, on="dna_seqid",how="left")
|
| 365 |
+
edges = pd.merge(edges, protein_clusters, on="tr_seqid",how="left")
|
| 366 |
+
edges["edge"] = edges.apply(lambda row: (row["tr_cluster_rep"], row["dna_cluster_rep"]), axis=1)
|
| 367 |
+
|
| 368 |
+
logger.info(f"Total unique edges: {len(edges['edge'].unique().tolist())}")
|
| 369 |
+
dup_edges = edges.loc[edges.duplicated("edge")]["edge"].unique().tolist()
|
| 370 |
+
logger.info(f"Total edges with >1 datapoint: {len(dup_edges)}")
|
| 371 |
+
logger.info(f"Total datapoints belonging to a duplicate edge: {len(edges.loc[edges['edge'].isin(dup_edges)])}")
|
| 372 |
+
return edges
|
| 373 |
+
|
| 374 |
+
def check_validity(train, val, test, split_by="both"):
|
| 375 |
+
"""
|
| 376 |
+
Rigorous check for no overlap
|
| 377 |
+
Columns = ["ID","dna_sequence","tr_sequence","tr_cluster_rep","dna_cluster_rep", "scores","split"]
|
| 378 |
+
"""
|
| 379 |
+
train_ids = set(train["ID"].unique().tolist())
|
| 380 |
+
val_ids = set(val["ID"].unique().tolist())
|
| 381 |
+
test_ids = set(test["ID"].unique().tolist())
|
| 382 |
+
|
| 383 |
+
assert len(train_ids.intersection(val_ids))==0
|
| 384 |
+
assert len(train_ids.intersection(test_ids))==0
|
| 385 |
+
assert len(val_ids.intersection(test_ids))==0
|
| 386 |
+
logger.info(f"Pass! No overlap in IDs")
|
| 387 |
+
|
| 388 |
+
if split_by!="dna":
|
| 389 |
+
train_tr_seqs = set(train["tr_sequence"].unique().tolist())
|
| 390 |
+
val_tr_seqs = set(val["tr_sequence"].unique().tolist())
|
| 391 |
+
test_tr_seqs = set(test["tr_sequence"].unique().tolist())
|
| 392 |
+
|
| 393 |
+
assert len(train_tr_seqs.intersection(val_tr_seqs))==0
|
| 394 |
+
assert len(train_tr_seqs.intersection(test_tr_seqs))==0
|
| 395 |
+
assert len(val_tr_seqs.intersection(test_tr_seqs))==0
|
| 396 |
+
logger.info(f"Pass! No overlap in TR sequences")
|
| 397 |
+
|
| 398 |
+
train_tr_reps = set(train["tr_cluster_rep"].unique().tolist())
|
| 399 |
+
val_tr_reps = set(val["tr_cluster_rep"].unique().tolist())
|
| 400 |
+
test_tr_reps = set(test["tr_cluster_rep"].unique().tolist())
|
| 401 |
+
|
| 402 |
+
assert len(train_tr_reps.intersection(val_tr_reps))==0
|
| 403 |
+
assert len(train_tr_reps.intersection(test_tr_reps))==0
|
| 404 |
+
assert len(val_tr_reps.intersection(test_tr_reps))==0
|
| 405 |
+
logger.info(f"Pass! No overlap in TR cluster reps")
|
| 406 |
+
|
| 407 |
+
if split_by!="protein":
|
| 408 |
+
train_dna_seqs = set(train["dna_sequence"].unique().tolist())
|
| 409 |
+
val_dna_seqs = set(val["dna_sequence"].unique().tolist())
|
| 410 |
+
test_dna_seqs = set(test["dna_sequence"].unique().tolist())
|
| 411 |
+
|
| 412 |
+
assert len(train_dna_seqs.intersection(val_dna_seqs))==0
|
| 413 |
+
assert len(train_dna_seqs.intersection(test_dna_seqs))==0
|
| 414 |
+
assert len(val_dna_seqs.intersection(test_dna_seqs))==0
|
| 415 |
+
logger.info(f"Pass! No overlap in DNA sequences")
|
| 416 |
+
|
| 417 |
+
train_dna_reps = set(train["dna_cluster_rep"].unique().tolist())
|
| 418 |
+
val_dna_reps = set(val["dna_cluster_rep"].unique().tolist())
|
| 419 |
+
test_dna_reps = set(test["dna_cluster_rep"].unique().tolist())
|
| 420 |
+
|
| 421 |
+
assert len(train_dna_reps.intersection(val_dna_reps))==0
|
| 422 |
+
assert len(train_dna_reps.intersection(test_dna_reps))==0
|
| 423 |
+
assert len(val_dna_reps.intersection(test_dna_reps))==0
|
| 424 |
+
logger.info(f"Pass! No overlap in DNA cluster reps")
|
| 425 |
+
|
| 426 |
+
def main(cfg: DictConfig):
|
| 427 |
+
"""
|
| 428 |
+
Take a set of DNA clusters + protein clusters, and create the best possible splits into train/val/test.
|
| 429 |
+
"""
|
| 430 |
+
# construct edges from training data
|
| 431 |
+
edge_df = make_edges(processed_fimo_path=Path(root) / cfg.data_task.input_data_path,
|
| 432 |
+
protein_cluster_path=Path(root) / cfg.data_task.cluster_output_paths.protein,
|
| 433 |
+
dna_cluster_path=Path(root) / cfg.data_task.cluster_output_paths.dna)
|
| 434 |
+
edges = edge_df["edge"].unique().tolist()
|
| 435 |
+
|
| 436 |
+
# figure out if we actually even have a conflict
|
| 437 |
+
total_proteins = len(edge_df["tr_seqid"].unique().tolist())
|
| 438 |
+
total_protein_clusters = len(edge_df["tr_cluster_rep"].unique().tolist())
|
| 439 |
+
|
| 440 |
+
no_protein_overlap = (total_proteins)==(total_protein_clusters)
|
| 441 |
+
logger.info(f"All proteins are in their own clusters: {no_protein_overlap}")
|
| 442 |
+
|
| 443 |
+
if cfg.data_task.split_by=="dna":
|
| 444 |
+
logger.info(f"Easy split: all proteins are in their own clusters.")
|
| 445 |
+
dna_clusters = edge_df["dna_cluster_rep"].unique().tolist()
|
| 446 |
+
results = split_bipartite_fast(
|
| 447 |
+
dna_clusters,
|
| 448 |
+
split_names=("train","val","test"),
|
| 449 |
+
ratios=(cfg.data_task.train_ratio, cfg.data_task.val_ratio, cfg.data_task.test_ratio),
|
| 450 |
+
)
|
| 451 |
+
dna_assign, kept_by_split = results
|
| 452 |
+
|
| 453 |
+
edge_df["split"] = edge_df["dna_seqid"].map(dna_assign)
|
| 454 |
+
else:
|
| 455 |
+
results = split_bipartite_by_components(
|
| 456 |
+
edges,
|
| 457 |
+
split_names=("train","val","test"),
|
| 458 |
+
ratios=(cfg.data_task.train_ratio, cfg.data_task.val_ratio, cfg.data_task.test_ratio),
|
| 459 |
+
require_nonempty=cfg.data_task.require_nonempty,
|
| 460 |
+
seed=cfg.data_task.seed,
|
| 461 |
+
test_edges_must=None,
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
tf_assign, dna_assign, kept_by_split, total_kept, split_to_indices, split_to_edges = results
|
| 465 |
+
|
| 466 |
+
# Map each sample to its split
|
| 467 |
+
print(tf_assign)
|
| 468 |
+
print(dna_assign)
|
| 469 |
+
edge_df["tr_split"] = edge_df["tr_cluster_rep"].map(tf_assign)
|
| 470 |
+
edge_df["dna_split"] = edge_df["dna_cluster_rep"].map(dna_assign)
|
| 471 |
+
edge_df["same_split"] = edge_df["tr_split"]==edge_df["dna_split"] # should always be true if easy cluster
|
| 472 |
+
edge_df["split"] = edge_df["tr_split"]
|
| 473 |
+
print(edge_df)
|
| 474 |
+
edge_df["split"] = np.where(
|
| 475 |
+
edge_df["same_split"],
|
| 476 |
+
edge_df["split"], # keep existing split if same_split == True
|
| 477 |
+
"leak" # otherwise leak
|
| 478 |
+
)
|
| 479 |
+
print(edge_df)
|
| 480 |
+
|
| 481 |
+
# Print ratios: hopefully close to desired (e.g. 80/10/10)
|
| 482 |
+
print_split_ratios(kept_by_split)
|
| 483 |
+
|
| 484 |
+
# Make train, val, test sets
|
| 485 |
+
# make sure no ID is duplicate
|
| 486 |
+
assert len(edge_df["ID"].unique())==len(edge_df)
|
| 487 |
+
split_cols = ["ID","dna_sequence","tr_sequence","tr_cluster_rep","dna_cluster_rep", "scores","split"]
|
| 488 |
+
train = edge_df.loc[
|
| 489 |
+
edge_df["split"]=="train"
|
| 490 |
+
].reset_index(drop=True)[split_cols]
|
| 491 |
+
val = edge_df.loc[
|
| 492 |
+
edge_df["split"]=="val"
|
| 493 |
+
].reset_index(drop=True)[split_cols]
|
| 494 |
+
test = edge_df.loc[
|
| 495 |
+
edge_df["split"]=="test"
|
| 496 |
+
].reset_index(drop=True)[split_cols]
|
| 497 |
+
|
| 498 |
+
# ensure there is no overlap
|
| 499 |
+
check_validity(train, val, test, split_by=cfg.data_task.split_by)
|
| 500 |
+
|
| 501 |
+
logger.info(f"Length of train dataset: {len(train)} ({100*len(train)/sum([len(train),len(val),len(test)]):.2f}%)")
|
| 502 |
+
logger.info(f"Length of val dataset: {len(val)} ({100*len(val)/sum([len(train),len(val),len(test)]):.2f}%)")
|
| 503 |
+
logger.info(f"Length of test dataset: {len(test)} ({100*len(test)/sum([len(train),len(val),len(test)]):.2f}%)")
|
| 504 |
+
|
| 505 |
+
# create the output dir
|
| 506 |
+
split_out_dir = Path(root)/cfg.data_task.split_out_dir
|
| 507 |
+
os.makedirs(split_out_dir, exist_ok=True)
|
| 508 |
+
split_final_cols = ["ID","dna_sequence","tr_sequence","scores","split"]
|
| 509 |
+
train[split_final_cols].to_csv(split_out_dir/"train.csv", index=False)
|
| 510 |
+
val[split_final_cols].to_csv(split_out_dir/"val.csv", index=False)
|
| 511 |
+
test[split_final_cols].to_csv(split_out_dir/"test.csv", index=False)
|
| 512 |
+
logger.info(f"Saved all splits to {split_out_dir}")
|
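The split logic above keeps a datapoint only when its protein-side and DNA-side cluster representatives land in the same split; anything else is tagged "leak" and is dropped when the train/val/test frames are selected, which is what allows the check_validity assertions to pass. Below is a minimal, self-contained sketch of that rule on toy data; the hard-coded cluster-to-split assignments are stand-ins for the outputs of split_bipartite_by_components and are not taken from the repository.

# Toy illustration of the leak-marking rule used in main() above.
import numpy as np
import pandas as pd

edge_df = pd.DataFrame({
    "ID": [0, 1, 2],
    "tr_cluster_rep": ["tfA", "tfA", "tfB"],
    "dna_cluster_rep": ["dna1", "dna2", "dna2"],
})
tf_assign = {"tfA": "train", "tfB": "test"}     # hypothetical protein-cluster assignments
dna_assign = {"dna1": "train", "dna2": "train"} # hypothetical DNA-cluster assignments

edge_df["tr_split"] = edge_df["tr_cluster_rep"].map(tf_assign)
edge_df["dna_split"] = edge_df["dna_cluster_rep"].map(dna_assign)
edge_df["split"] = np.where(edge_df["tr_split"] == edge_df["dna_split"],
                            edge_df["tr_split"], "leak")
print(edge_df[["ID", "split"]])  # IDs 0 and 1 -> train, ID 2 -> leak
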
dpacman/scripts/preprocess.py
CHANGED
@@ -1,11 +1,9 @@
 import rootutils
-
-root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
-
 import hydra
 from omegaconf import DictConfig
 import logging
 
+root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
 logger = logging.getLogger(__name__)
 
 # import your processing entry points here
@@ -15,7 +13,8 @@ from dpacman.data_tasks.clean.remap import main as clean_remap_main
 from dpacman.data_tasks.fimo.pre_fimo import main as pre_fimo_main
 from dpacman.data_tasks.fimo.run_fimo import main as run_fimo_main
 from dpacman.data_tasks.fimo.post_fimo import main as post_fimo_main
-
+from dpacman.data_tasks.cluster.remap import main as cluster_remap_main
+from dpacman.data_tasks.split.remap import main as split_remap_main
 
 @hydra.main(
     config_path=str(root / "configs"), config_name="preprocess", version_base="1.3"
@@ -26,6 +25,7 @@ def main(cfg: DictConfig):
 
     logger.info(f"Running {task_type} task: {task_name}")
 
+    # Download
     if task_type == "download":
         if task_name == "genome":
             download_genome_main(cfg)
@@ -34,12 +34,14 @@ def main(cfg: DictConfig):
         else:
             raise ValueError(f"No download pipeline defined for: {task_name}")
 
+    # Clean
     elif task_type == "clean":
         if task_name == "remap":
             clean_remap_main(cfg)
         else:
             raise ValueError(f"No clean pipeline defined for: {task_name}")
 
+    # Fimo
     elif task_type == "fimo":
         if task_name == "pre_fimo":
             pre_fimo_main(cfg)
@@ -50,6 +52,20 @@ def main(cfg: DictConfig):
         else:
             raise ValueError(f"No clean pipeline defined for: {task_name}")
 
+    # Cluster
+    elif task_type == "cluster":
+        if task_name == "remap":
+            cluster_remap_main(cfg)
+        else:
+            raise ValueError(f"No cluster pipeline defined for: {task_name}")
+
+    elif task_type == "split":
+        if task_name == "remap":
+            split_remap_main(cfg)
+        else:
+            raise ValueError(f"No split pipeline defined for: {task_name}")
+
+    # Unknown - error
     else:
         raise ValueError(f"Unknown task type: {task_type}")
 
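For a quick interactive check of the new cluster/split task types, the same entry points can also be driven programmatically through Hydra's compose API rather than the nohup wrappers below. This is only an illustrative sketch: the relative config_path and the override values are assumptions based on the scripts in this commit, not code from the repository.

# Hypothetical interactive driver; mirrors what scripts/preprocess.py does via @hydra.main.
from hydra import initialize, compose
from dpacman.data_tasks.split.remap import main as split_remap_main

with initialize(config_path="../configs", version_base="1.3"):
    cfg = compose(
        config_name="preprocess",
        overrides=["data_task=split/remap", "data_task.split_by=dna"],
    )
    split_remap_main(cfg)
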
dpacman/scripts/run_cluster.sh
ADDED
@@ -0,0 +1,19 @@
#!/bin/bash

# Manually specify values used in the config
main_task="preprocess"
data_task_type="cluster"
timestamp=$(date "+%Y-%m-%d_%H-%M-%S")

run_dir="/vast/projects/pranam/lab/sophie/DPACMAN/logs/${main_task}/${data_task_type}/runs/${timestamp}"
mkdir -p "$run_dir"

nohup python -u -m scripts.preprocess \
    hydra.run.dir="${run_dir}" \
    data_task=${data_task_type}/remap \
    data_task.cluster_dna_full="false" \
    data_task.cluster_dna_peaks="false" \
    data_task.cluster_protein="true" \
    > "${run_dir}/run.log" 2>&1 &

echo $! > "${run_dir}/pid.txt"
dpacman/scripts/run_fimo.sh
CHANGED
@@ -5,13 +5,15 @@ main_task="preprocess"
 data_task_type="fimo"
 timestamp=$(date "+%Y-%m-%d_%H-%M-%S")
 
-run_dir="
+run_dir="/vast/projects/pranam/lab/sophie/DPACMAN/logs/${main_task}/${data_task_type}/runs/${timestamp}"
 mkdir -p "$run_dir"
 
 nohup python -u -m scripts.preprocess \
     hydra.run.dir="${run_dir}" \
     data_task=${data_task_type}/post_fimo \
     data_task.debug="false" \
+    data_task.keep_fimo_only="false" \
+    data_task.processed_output_csv="dpacman/data_files/processed/fimo/post_fimo/all_peaks/remap2022_crm_fimo_output_q_processed.csv" \
     > "${run_dir}/run.log" 2>&1 &
 
 echo $! > "${run_dir}/pid.txt"
dpacman/scripts/run_split.sh
ADDED
@@ -0,0 +1,21 @@
#!/bin/bash

# Manually specify values used in the config
main_task="preprocess"
data_task_type="split"
timestamp=$(date "+%Y-%m-%d_%H-%M-%S")

run_dir="/vast/projects/pranam/lab/sophie/DPACMAN/logs/${main_task}/${data_task_type}/runs/${timestamp}"
mkdir -p "$run_dir"

nohup python -u -m scripts.preprocess \
    hydra.run.dir="${run_dir}" \
    data_task="${data_task_type}/remap" \
    data_task.split_by=dna \
    data_task.train_ratio=0.8 \
    data_task.val_ratio=0.1 \
    data_task.test_ratio=0.1 \
    data_task.split_out_dir=dpacman/data_files/processed/splits/by_dna \
    > "${run_dir}/run.log" 2>&1 &

echo $! > "${run_dir}/pid.txt"
dpacman/utils/README.md
ADDED
@@ -0,0 +1 @@
Code in this folder was taken from the olsdavis/gdd repository, which was based on DPLM (Diffusion Protein Language Model).

dpacman/utils/__init__.py
ADDED
File without changes
dpacman/utils/clustering.py
ADDED
@@ -0,0 +1,144 @@
import pandas as pd
import os
import subprocess
import sys
from Bio import SeqIO
import shutil

import rootutils
import logging

logger = logging.getLogger(__name__)

def ensure_mmseqs_in_path(mmseqs_dir):
    """
    Checks if MMseqs2 is in the PATH. If it's not, add it. MMseqs2 will not run if this is not done correctly.

    Args:
        mmseqs_dir (str): Directory containing MMseqs2 binaries
    """
    mmseqs_bin = os.path.join(mmseqs_dir, 'mmseqs')

    # Check if mmseqs is already in PATH
    if shutil.which('mmseqs') is None:
        # Export the MMseqs2 directory to PATH
        os.environ['PATH'] = f"{mmseqs_dir}:{os.environ['PATH']}"
        logger.info(f"\tAdded {mmseqs_dir} to PATH")

def process_fasta(fasta_path):
    fasta_sequences = SeqIO.parse(open(fasta_path), 'fasta')
    d = {}
    for fasta in fasta_sequences:
        id, sequence = fasta.id, str(fasta.seq)
        d[id] = sequence

    return d

def analyze_clustering_result(input_fasta: str, tsv_path: str):
    """
    Args:
        input_fasta (str): path to input fasta file
        tsv_path (str): path to the cluster TSV produced by MMseqs2
    """
    # Process input fasta
    input_d = process_fasta(input_fasta)

    # Process clusters.tsv
    clusters = pd.read_csv(f'{tsv_path}', sep='\t', header=None)
    clusters = clusters.rename(columns={
        0: 'representative seq_id',
        1: 'member seq_id'
    })

    clusters['representative seq'] = clusters['representative seq_id'].apply(lambda seq_id: input_d[seq_id])
    clusters['member seq'] = clusters['member seq_id'].apply(lambda seq_id: input_d[seq_id])

    # Sort them so that splitting results are reproducible
    clusters = clusters.sort_values(by=['representative seq_id', 'member seq_id'], ascending=True).reset_index(drop=True)

    return clusters

def make_fasta(sequences: dict, fasta_path: str):
    """
    Makes a fasta file from sequences, where the key is the header and the value is the sequence.

    Args:
        sequences (dict): A dictionary where the key is the header and the value is the sequence.

    Returns:
        str: The path to the fasta file.
    """
    with open(fasta_path, 'w') as f:
        for header, sequence in sequences.items():
            f.write(f'>{header}\n{sequence}\n')

    return fasta_path

def run_mmseqs_clustering(input_fasta, output_dir, min_seq_id=0.3, c=0.8, cov_mode=0, cluster_mode=0, path_to_mmseqs='fuson_plm/mmseqs', dbtype=1):
    """
    Runs MMseqs2 clustering using the easy-cluster module.

    Args:
        input_fasta (str): path to input fasta file, formatted >header\nsequence\n>header\nsequence...
        output_dir (str): path to output dir for clustering results
        min_seq_id (float): number in [0,1] representing --min-seq-id in the cluster command
        c (float): number in [0,1] representing -c in the cluster command
        cov_mode (int): 0, 1, 2, or 3, representing --cov-mode in the cluster command
        cluster_mode (int): 0, 1, or 2, representing --cluster-mode in the cluster command
    """
    # Get mmseqs dir
    logger.info("\nRunning MMseqs clustering...")
    mmseqs_dir = os.path.join(path_to_mmseqs[0:path_to_mmseqs.index('/mmseqs')], 'mmseqs/bin')
    logger.info(f"Running mmseqs clustering from {mmseqs_dir}")

    # Ensure MMseqs2 is in the PATH
    ensure_mmseqs_in_path(mmseqs_dir)

    # Define paths for MMseqs2
    mmseqs_bin = "mmseqs"  # Ensure this is in your PATH or provide the full path to the mmseqs binary

    # Create the output directory
    os.makedirs(output_dir, exist_ok=True)

    # Run MMseqs2 easy-cluster
    cmd_easy_cluster = [
        mmseqs_bin, "easy-cluster", input_fasta, os.path.join(output_dir, "mmseqs"), output_dir,
        "--min-seq-id", str(min_seq_id),
        "-c", str(c),
        "--cov-mode", str(cov_mode),
        "--cluster-mode", str(cluster_mode),
        "--dbtype", str(dbtype)
    ]

    # Write the command to a log file
    logger.info("\n\tCommand entered to MMseqs2:")
    logger.info("\t" + " ".join(cmd_easy_cluster) + "\n")

    subprocess.run(cmd_easy_cluster, check=True)

    logger.info(f"Clustering completed. Results are in {output_dir}")

def cluster_summary(clusters: pd.DataFrame):
    """
    Summarizes how many clusters were formed, how big they are, etc.
    """
    grouped_clusters = clusters.groupby('representative seq_id')['member seq_id'].count().reset_index().rename(columns={'member seq_id': 'member count'})
    assert len(grouped_clusters) == len(clusters['representative seq_id'].unique())  # make sure number of cluster reps = # grouped clusters

    total_seqs = sum(grouped_clusters['member count'])
    logger.info(f"Created {len(grouped_clusters)} clusters of {total_seqs} sequences")
    logger.info(f"\t{len(grouped_clusters.loc[grouped_clusters['member count']==1])} clusters of size 1")
    csize1_seqs = sum(grouped_clusters[grouped_clusters['member count']==1]['member count'])
    logger.info(f"\t\tsequences: {csize1_seqs} ({round(100*csize1_seqs/total_seqs, 2)}%)")

    logger.info(f"\t{len(grouped_clusters.loc[grouped_clusters['member count']>1])} clusters of size > 1")
    csizeg1_seqs = sum(grouped_clusters[grouped_clusters['member count']>1]['member count'])
    logger.info(f"\t\tsequences: {csizeg1_seqs} ({round(100*csizeg1_seqs/total_seqs, 2)}%)")
    logger.info(f"\tlargest cluster: {max(grouped_clusters['member count'])}")

    logger.info("\nCluster size breakdown below...")

    value_counts = grouped_clusters['member count'].value_counts().reset_index().rename(columns={'member count': 'cluster size (n_members)', 'count': 'n_clusters'})
    logger.info(value_counts.sort_values(by='cluster size (n_members)', ascending=True).to_string(index=False))
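As a rough usage sketch, the helpers above are meant to be chained: write a FASTA, cluster it with MMseqs2, then inspect the resulting clusters. The paths, sequences, and parameter values below are made up for illustration; only the function names and signatures come from dpacman/utils/clustering.py above.

# Hypothetical example; assumes an MMseqs2 install whose binaries live under <path>/mmseqs/bin.
from dpacman.utils.clustering import (
    make_fasta, run_mmseqs_clustering, analyze_clustering_result, cluster_summary,
)

seqs = {"tf_1": "MKTAYIAKQR", "tf_2": "MKTAYIAKQQ"}           # toy sequences
fasta = make_fasta(seqs, "proteins.fasta")                     # write >header\nsequence records
run_mmseqs_clustering(fasta, "mmseqs_out",                     # runs `mmseqs easy-cluster`
                      min_seq_id=0.3, c=0.8,
                      path_to_mmseqs="path/to/mmseqs")
clusters = analyze_clustering_result(fasta, "mmseqs_out/mmseqs_cluster.tsv")  # assumed output TSV name
cluster_summary(clusters)                                      # logs the cluster-size breakdown
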
dpacman/utils/models.py
ADDED
@@ -0,0 +1,19 @@
"""
Model-related utilities, such as setting seed
"""
import torch
import numpy as np
import random
import os

def set_seed(seed: int = 42) -> None:
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")
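A typical call pattern (illustrative only, not from the repository) is to seed everything once at the top of a training script, before any model or dataloader is constructed:

# Minimal usage sketch; the import path follows the file location above.
from dpacman.utils.models import set_seed

set_seed(42)  # numpy, random, torch CPU/CUDA, cuDNN flags, PYTHONHASHSEED
# ... build datasets, model, and trainer afterwards so they all see the same seed
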
dpacman/utils/splitting.py
ADDED
File without changes
environment.yaml
CHANGED
@@ -1,42 +1,45 @@
-
-# - pip installs packages in a loop, without ensuring dependencies across all packages
-# are fulfilled simultaneously, but conda achieves proper dependency control across
-# all packages
-# - conda allows for installing packages without requiring certain compilers or
-# libraries to be available in the system, since it installs precompiled binaries
-# in case of errors look here: https://pytorch.org/get-started/previous-versions/
-
-name: dpacman
-
+name: dnabind
 channels:
-  - bioconda
   - conda-forge
+  - bioconda
-
-
-# it is strongly recommended to specify versions of packages installed through conda
-# to avoid situation when version-unspecified packages install their latest major
-# versions which can sometimes break things
-
-# current approach below keeps the dependencies in the same major versions across all
-# users, but allows for different minor and patch versions of packages where backwards
-# compatibility is usually guaranteed
-
 dependencies:
   - python=3.10
-  - dask[complete]
   - pip>=23
-
+  # dask "complete" equivalent via conda packages
+  - dask
+  - distributed
+  - rich=13.*
+  # ghostscript 9.18 is very old; keep only if you truly need that exact version.
+  # If not, drop the pin or use the conda-forge version (>=10).
+  - ghostscript
+  - lxml=5.3.0
+  - pandas=2.2.3
+  - gitpython=3.1.*
+  - tqdm=4.67.*
+  - matplotlib=3.10.*
   - pip:
-      [removed pip entries not shown in this view]
+      # Pull GPU wheels (CUDA 12.8) from PyTorch's cu128 index; fall back to PyPI for others
+      - --index-url https://download.pytorch.org/whl/cu128
+      - --extra-index-url https://pypi.org/simple
+
+      # PyTorch + CUDA 12.8
+      - torch==2.7.1
+      - torchvision==0.22.1
+      # - torchaudio==2.7.1  # optional, if you need it
+
+      # Lightning (classic)
+      - pytorch-lightning
+
+      # Your pinned Python libs
+      - rootutils==1.0.7
+      - polars==1.32.2
+      - hydra-core==1.3.2
+      - hydra-colorlog==1.2.0
+      - omegaconf==2.3.0
+      - pymex==0.9.31
+      - transformers==4.51.2
+      - scikit-learn==1.7.1
+      - biopython==1.85
+      - ortools==9.14.6206
+      # Your package in editable mode
+      - -e .
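Assuming a standard conda or mamba installation, the updated environment would typically be created and activated with the two commands below; the environment name comes from the name: dnabind field above, and a driver capable of the pinned CUDA 12.8 torch wheels is assumed.

conda env create -f environment.yaml
conda activate dnabind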