added model files
- dpacman/classifier/model/__init__.py +0 -0
- dpacman/classifier/model/compress_embeddings.py +54 -0
- dpacman/classifier/model/compute_embeddings.py +546 -0
- dpacman/classifier/model/extract_tf_symbols.py +27 -0
- dpacman/classifier/model/make_pair_list.py +220 -0
- dpacman/classifier/model/make_peak_fasta.py +13 -0
- dpacman/classifier/model/model.py +103 -0
- dpacman/classifier/model/train.py +262 -0
dpacman/classifier/model/__init__.py
ADDED
File without changes
dpacman/classifier/model/compress_embeddings.py
ADDED
@@ -0,0 +1,54 @@
# compress_embeddings.py
# USAGE: python compress_embeddings.py --input_glob "/path/to/esm_embeddings/*.npy" --output_dir "/path/to/compressed_embeddings" --esm_dim 1280 --out_dim 256
# --------------
import os
import glob
import numpy as np
import torch
from torch import nn

class EmbeddingCompressor(nn.Module):
    def __init__(self, input_dim: int = 1280, output_dim: int = 256):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch, L, input_dim) or (L, input_dim)
        returns: (batch, output_dim); a single (L, input_dim) example yields (1, output_dim)
        """
        if x.dim() == 2:
            # single example: mean over tokens
            x = x.mean(dim=0, keepdim=True)  # → (1, input_dim)
        else:
            # batch: mean over tokens
            x = x.mean(dim=1)  # → (batch, input_dim)
        return self.fc(x)  # → (batch, output_dim)

def compress_file(in_path: str, out_path: str, model: EmbeddingCompressor):
    arr = np.load(in_path)  # shape (L, D) or (batch, L, D)
    tensor = torch.from_numpy(arr).float()
    with torch.no_grad():
        compressed = model(tensor)  # → (batch, 256)
    out = compressed.cpu().numpy()
    np.save(out_path, out)
    print(f"Saved {out_path}")

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Compress ESM embeddings to 256d")
    parser.add_argument("--input_glob", type=str, required=True,
                        help="Glob for your .npy ESM embeddings (e.g. data/esm_*.npy)")
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--esm_dim", type=int, default=1280)
    parser.add_argument("--out_dim", type=int, default=256)
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    compressor = EmbeddingCompressor(args.esm_dim, args.out_dim)
    compressor.eval()

    for fn in glob.glob(args.input_glob):
        base = os.path.basename(fn).replace(".npy", "_256.npy")
        out_path = os.path.join(args.output_dir, base)
        compress_file(fn, out_path, compressor)
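Note that the nn.Linear projection in this script is randomly initialized unless a trained checkpoint is loaded, so the 256-d outputs are an untrained projection of the pooled embedding. A minimal usage sketch (the file name tf_example.npy and the random data are hypothetical, not part of this commit), assuming compress_embeddings.py is importable from the working directory:

# minimal_compress_demo.py — hedged sketch, not part of the commit
import numpy as np
import torch
from compress_embeddings import EmbeddingCompressor, compress_file

# fake ESM-style per-token embedding: 120 tokens x 1280 dims (synthetic data)
np.save("tf_example.npy", np.random.randn(120, 1280).astype(np.float32))

model = EmbeddingCompressor(input_dim=1280, output_dim=256)
model.eval()
compress_file("tf_example.npy", "tf_example_256.npy", model)  # writes a (1, 256) array

# the same mean-pool + linear projection, called directly on a tensor
vec = model(torch.randn(120, 1280))
print(vec.shape)  # torch.Size([1, 256])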
dpacman/classifier/model/compute_embeddings.py
ADDED
@@ -0,0 +1,546 @@
"""
Plug-and-play embedding extraction for:
  • Chromosome sequences (from raw UCSC JSON)
  • TF sequences (transcription_factors.fasta)

Usage example (DNA + protein in one go):
    module load miniconda/24.7.1
    conda activate dpacman
    python dpacman/classifier/model/compute_embeddings.py \
        --genome-json-dir ../data_files/raw/genomes/hg38 \
        --tf-fasta ../data_files/processed/tfclust/hg38_tf/transcription_factors.fasta \
        --chrom-model caduceus \
        --tf-model esm-dbp \
        --out-dir ../data_files/processed/tfclust/hg38_tf/embeddings \
        --device cuda
"""
import os
import re
import argparse
import json
import numpy as np
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, pipeline
import esm
from Bio import SeqIO
import time

# ---- model wrappers ----

class CaduceusEmbedder:
    def __init__(self, device, chunk_size=131_072, overlap=0):
        """
        device: 'cpu' or 'cuda'
        chunk_size: max bases (and thus tokens) to send in one forward pass
        overlap: how many bases each window overlaps the previous; 0 = no overlap
        """
        model_name = "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16"
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(
            model_name, trust_remote_code=True
        ).to(device).eval()
        self.device = device
        self.chunk_size = chunk_size
        self.step = chunk_size - overlap

    def embed(self, seqs):
        """
        seqs: List[str] of DNA sequences (each <= chunk_size for this test)
        returns: list of raw per-token embeddings, one (L_i, D) array per sequence
        """
        # outputs = []
        # for seq in seqs:
        #     # --- new: raw per-token embeddings in one shot ---
        #     toks = self.tokenizer(
        #         seq,
        #         return_tensors="pt",
        #         padding=False,
        #         truncation=True,
        #         max_length=self.chunk_size
        #     ).to(self.device)
        #     with torch.no_grad():
        #         out = self.model(**toks).last_hidden_state  # (1, L, D)
        #     outputs.append(out.cpu().numpy()[0])  # (L, D)

        # return np.stack(outputs, axis=0)  # (N, L, D)
        outputs = []
        for seq in seqs:
            toks = self.tokenizer(
                seq,
                return_tensors="pt",
                padding=False,
                truncation=True,
                max_length=self.chunk_size
            ).to(self.device)
            with torch.no_grad():
                out = self.model(**toks).last_hidden_state  # (1, L, D)
            outputs.append(out.cpu().numpy()[0])  # (L, D)
        return outputs  # list of variable-length (L_i, D) arrays


    def benchmark(self, lengths=None):
        """
        Time embedding on single sequences of various lengths.
        By default tests [5K, 10K, 50K, 100K, chunk_size].
        """
        tests = lengths or [5_000, 10_000, 50_000, 100_000, self.chunk_size]
        print(f"→ Benchmarking Caduceus on device={self.device}")
        for sz in tests:
            seq = "A" * sz
            # Warm-up
            _ = self.embed([seq])
            if self.device != "cpu":
                torch.cuda.synchronize()
            t0 = time.perf_counter()
            _ = self.embed([seq])
            if self.device != "cpu":
                torch.cuda.synchronize()
            t1 = time.perf_counter()
            print(f"  length={sz:6,d}  time={(t1-t0)*1000:7.1f} ms")

class SegmentNTEmbedder:
    def __init__(self, device):
        self.tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True)
        self.model = AutoModel.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True).to(device).eval()
        self.device = device

    def _adjust_length(self, input_ids):
        bs, L = input_ids.shape
        excl = L - 1
        remainder = excl % 4
        if remainder != 0:
            pad_needed = 4 - remainder
            pad_tensor = torch.full((bs, pad_needed), self.tokenizer.pad_token_id, dtype=input_ids.dtype, device=input_ids.device)
            input_ids = torch.cat([input_ids, pad_tensor], dim=1)
        return input_ids

    def embed(self, seqs, batch_size=16):
        """
        seqs: List[str]
        Returns: np.ndarray of shape (N, D)
        """
        all_embeddings = []
        for i in range(0, len(seqs), batch_size):
            batch_seqs = seqs[i : i + batch_size]
            encoded = self.tokenizer.batch_encode_plus(
                batch_seqs,
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
            input_ids = encoded["input_ids"].to(self.device)  # (B, L)
            attention_mask = input_ids != self.tokenizer.pad_token_id

            input_ids = self._adjust_length(input_ids)
            attention_mask = (input_ids != self.tokenizer.pad_token_id)

            with torch.no_grad():
                outs = self.model(
                    input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    return_dict=True,
                )
            if hasattr(outs, "hidden_states") and outs.hidden_states is not None:
                last_hidden = outs.hidden_states[-1]  # (B, L, D)
            else:
                last_hidden = outs.last_hidden_state  # fallback

            # Exclude CLS token if present (assume first token) and pool
            pooled = last_hidden[:, 1:, :].mean(dim=1)  # (B, D)
            all_embeddings.append(pooled.cpu().numpy())

            # release fragmentation
            torch.cuda.empty_cache()

        return np.vstack(all_embeddings)  # (N, D)


class DNABertEmbedder:
    def __init__(self, device):
        self.tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True)
        self.model = AutoModel.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True).to(device)
        self.device = device

    def embed(self, seqs):
        embs = []
        for s in seqs:
            tokens = self.tokenizer(s, return_tensors="pt", padding=True)["input_ids"].to(self.device)
            with torch.no_grad():
                out = self.model(tokens).last_hidden_state.mean(1)
            embs.append(out.cpu().numpy())
        return np.vstack(embs)

class NucleotideTransformerEmbedder:
    def __init__(self, device):
        # HF "feature-extraction" returns a list of (L, D) arrays for each input
        # device: "cpu" or "cuda"
        self.pipe = pipeline(
            "feature-extraction",
            model="InstaDeepAI/nucleotide-transformer-500m-1000g",
            device=-1 if device == "cpu" else 0  # HF uses -1 for CPU, 0 for GPU
        )

    def embed(self, seqs):
        """
        seqs: List[str] of raw DNA sequences
        returns: (N, D) array, one D-dim vector per sequence
        """
        all_embeddings = self.pipe(seqs, truncation=True, padding=True)
        # all_embeddings is a List of shape (L, D) arrays
        pooled = [np.mean(x, axis=0) for x in all_embeddings]
        return np.vstack(pooled)

# class ESMEmbedder:
#     def __init__(self, device):
#         self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
#         self.batch_converter = self.alphabet.get_batch_converter()
#         self.model.to(device).eval()
#         self.device = device

#     def embed(self, seqs):
#         batch = [(str(i), seq) for i, seq in enumerate(seqs)]
#         _, _, toks = self.batch_converter(batch)
#         toks = toks.to(self.device)
#         with torch.no_grad():
#             results = self.model(toks, repr_layers=[33], return_contacts=False)
#         reps = results["representations"][33]
#         return reps[:, 1:-1].mean(1).cpu().numpy()


class ESMEmbedder:
    def __init__(self, device, model_name="esm2_t33_650M_UR50D"):
        # Try to load the specified ESM-2 model; fallback to esm1b if missing
        self.device = device
        try:
            self.model, self.alphabet = getattr(esm.pretrained, model_name)()
            self.is_esm2 = model_name.lower().startswith("esm2")
        except AttributeError:
            # fallback to ESM-1b
            self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
            self.is_esm2 = False
        self.batch_converter = self.alphabet.get_batch_converter()
        self.model.to(device).eval()
        # determine max length: esm2 models vary; use default 1024 for esm1b
        self.max_len = 4096 if self.is_esm2 else 1024  # adjust if your esm2 variant has explicit limit
        # for chunking: reserve 2 tokens if model uses BOS/EOS
        self.chunk_size = self.max_len - 2
        self.overlap = self.chunk_size // 4  # 25% overlap to smooth boundaries

    def _chunk_sequence(self, seq):
        """
        Return list of possibly overlapping chunks of seq, each <= chunk_size.
        """
        if len(seq) <= self.chunk_size:
            return [seq]
        step = self.chunk_size - self.overlap
        chunks = []
        for i in range(0, len(seq), step):
            chunk = seq[i : i + self.chunk_size]
            if not chunk:
                break
            chunks.append(chunk)
        return chunks

    def embed(self, seqs):
        """
        seqs: List[str] of protein sequences.
        Returns: np.ndarray of shape (N, D) pooled per-sequence embeddings.
        """
        all_embeddings = []
        for i, seq in enumerate(seqs):
            chunks = self._chunk_sequence(seq)
            chunk_vecs = []
            # process chunks in batch if small number, else sequentially
            for chunk in chunks:
                batch = [(str(i), chunk)]
                _, _, toks = self.batch_converter(batch)
                toks = toks.to(self.device)
                with torch.no_grad():
                    results = self.model(toks, repr_layers=[33], return_contacts=False)
                reps = results["representations"][33]  # (1, L, D)
                # remove BOS/EOS if present: take 1:-1 if length permits
                if reps.size(1) > 2:
                    rep = reps[:, 1:-1].mean(1)  # (1, D)
                else:
                    rep = reps.mean(1)  # fallback
                chunk_vecs.append(rep.squeeze(0))  # (D,)
            if len(chunk_vecs) == 1:
                seq_vec = chunk_vecs[0]
            else:
                # average chunk vectors
                stacked = torch.stack(chunk_vecs, dim=0)  # (num_chunks, D)
                seq_vec = stacked.mean(0)
            all_embeddings.append(seq_vec.cpu().numpy())
        return np.vstack(all_embeddings)  # (N, D)


# class ESMDBPEmbedder:
#     def __init__(self, device):
#         base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
#         model_path = (
#             Path(__file__).resolve().parent.parent
#             / "pretrained" / "ESM-DBP" / "ESM-DBP.model"
#         )
#         checkpoint = torch.load(model_path, map_location="cpu")
#         clean_sd = {}
#         for k, v in checkpoint.items():
#             clean_sd[k.replace("module.", "")] = v
#         result = base_model.load_state_dict(clean_sd, strict=False)
#         if result.missing_keys:
#             print(f"[ESMDBP] missing keys: {result.missing_keys}")
#         if result.unexpected_keys:
#             print(f"[ESMDBP] unexpected keys: {result.unexpected_keys}")

#         self.model = base_model.to(device).eval()
#         self.alphabet = alphabet
#         self.batch_converter = alphabet.get_batch_converter()
#         self.device = device

#     def embed(self, seqs):
#         batch = [(str(i), seq) for i, seq in enumerate(seqs)]
#         _, _, toks = self.batch_converter(batch)
#         toks = toks.to(self.device)
#         with torch.no_grad():
#             out = self.model(toks, repr_layers=[33], return_contacts=False)
#         reps = out["representations"][33]
#         # skip start/end tokens
#         return reps[:, 1:-1].mean(1).cpu().numpy()

class ESMDBPEmbedder:
    def __init__(self, device):
        base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
        model_path = (
            Path(__file__).resolve().parent.parent
            / "pretrained" / "ESM-DBP" / "ESM-DBP.model"
        )
        checkpoint = torch.load(model_path, map_location="cpu")
        clean_sd = {}
        for k, v in checkpoint.items():
            clean_sd[k.replace("module.", "")] = v
        result = base_model.load_state_dict(clean_sd, strict=False)
        if result.missing_keys:
            print(f"[ESMDBP] missing keys: {result.missing_keys}")
        if result.unexpected_keys:
            print(f"[ESMDBP] unexpected keys: {result.unexpected_keys}")

        self.model = base_model.to(device).eval()
        self.alphabet = alphabet
        self.batch_converter = alphabet.get_batch_converter()
        self.device = device
        self.max_len = 1024  # same limit as esm1b
        self.chunk_size = self.max_len - 2
        self.overlap = self.chunk_size // 4

    def _chunk_sequence(self, seq):
        if len(seq) <= self.chunk_size:
            return [seq]
        step = self.chunk_size - self.overlap
        chunks = []
        for i in range(0, len(seq), step):
            chunk = seq[i : i + self.chunk_size]
            if not chunk:
                break
            chunks.append(chunk)
        return chunks

    def embed(self, seqs):
        all_embeddings = []
        for i, seq in enumerate(seqs):
            chunks = self._chunk_sequence(seq)
            chunk_vecs = []
            for chunk in chunks:
                batch = [(str(i), chunk)]
                _, _, toks = self.batch_converter(batch)
                toks = toks.to(self.device)
                with torch.no_grad():
                    out = self.model(toks, repr_layers=[33], return_contacts=False)
                reps = out["representations"][33]
                if reps.size(1) > 2:
                    rep = reps[:, 1:-1].mean(1)
                else:
                    rep = reps.mean(1)
                chunk_vecs.append(rep.squeeze(0))
            if len(chunk_vecs) == 1:
                seq_vec = chunk_vecs[0]
            else:
                stacked = torch.stack(chunk_vecs, dim=0)
                seq_vec = stacked.mean(0)
            all_embeddings.append(seq_vec.cpu().numpy())
        return np.vstack(all_embeddings)

class GPNEmbedder:
    def __init__(self, device):
        model_name = "songlab/gpn-msa-sapiens"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(model_name)
        self.model.to(device)
        self.model.eval()
        self.device = device

    def embed(self, seqs):
        inputs = self.tokenizer(
            seqs,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(self.device)

        with torch.no_grad():
            # masked-LM outputs expose hidden states only when requested;
            # use the final layer as the per-token embedding
            outputs = self.model(**inputs, output_hidden_states=True)
        last_hidden = outputs.hidden_states[-1]
        return last_hidden.mean(dim=1).cpu().numpy()

class ProGenEmbedder:
    def __init__(self, device):
        model_name = "jinyuan22/ProGen2-base"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device).eval()
        self.device = device

    def embed(self, seqs):
        inputs = self.tokenizer(
            seqs,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(self.device)
        with torch.no_grad():
            last_hidden = self.model(**inputs).last_hidden_state
        return last_hidden.mean(dim=1).cpu().numpy()

# ---- main pipeline ----

def get_embedder(name, device, for_dna=True):
    name = name.lower()
    if for_dna:
        if name == "caduceus": return CaduceusEmbedder(device)
        if name == "dnabert": return DNABertEmbedder(device)
        if name == "nucleotide": return NucleotideTransformerEmbedder(device)
        if name == "gpn": return GPNEmbedder(device)
        if name == "segmentnt": return SegmentNTEmbedder(device)
    else:
        if name in ("esm",): return ESMEmbedder(device)
        if name in ("esm-dbp", "esm_dbp"): return ESMDBPEmbedder(device)
        if name == "progen": return ProGenEmbedder(device)
    raise ValueError(f"Unknown model {name} (for_dna={for_dna})")


def pad_token_embeddings(list_of_arrays, pad_value=0.0):
    """
    list_of_arrays: list of (L_i, D) numpy arrays
    Returns:
        padded: (N, L_max, D) array
        mask: (N, L_max) boolean array where True = real token, False = padding
    """
    N = len(list_of_arrays)
    D = list_of_arrays[0].shape[1]
    L_max = max(arr.shape[0] for arr in list_of_arrays)
    padded = np.full((N, L_max, D), pad_value, dtype=list_of_arrays[0].dtype)
    mask = np.zeros((N, L_max), dtype=bool)
    for i, arr in enumerate(list_of_arrays):
        L = arr.shape[0]
        padded[i, :L] = arr
        mask[i, :L] = True
    return padded, mask

def embed_and_save(seqs, ids, embedder, out_path):
    embs = embedder.embed(seqs)

    # Decide whether we got variable-length per-token outputs (list of (L, D))
    is_variable_token = isinstance(embs, (list, tuple)) and len(embs) > 0 and hasattr(embs[0], "shape") and embs[0].ndim == 2

    if is_variable_token:
        # pad to (N, L_max, D) + mask
        padded, mask = pad_token_embeddings(embs)
        # Save both embeddings and mask together in an .npz for convenience
        np.savez_compressed(out_path.with_suffix(".caduceus.npz"),
                            embeddings=padded,
                            mask=mask,
                            ids=np.array(ids, dtype=object))
    else:
        # fixed shape output, e.g., pooled (N, D)
        array = np.vstack(embs) if isinstance(embs, list) else embs
        np.save(out_path, array)
        with open(out_path.with_suffix(".ids"), "w") as f:
            f.write("\n".join(ids))


if __name__ == "__main__":

    p = argparse.ArgumentParser()
    p.add_argument("--peak-fasta", default="binding_peaks_unique.fa", help="FASTA of deduplicated binding peak sequences; if present this is used for DNA embedding instead of genome JSONs")
    p.add_argument("--genome-json-dir", default=None, help="(fallback) directory of UCSC JSONs for full chromosome embedding if peak FASTA is missing or you explicitly want chromosomes")
    p.add_argument("--skip-dna", action="store_true", help="if set, skip the DNA embedding step")  # if glm embeddings successful but not plm embeddings
    p.add_argument("--tf-fasta", required=True, help="input TF FASTA file")
    p.add_argument("--chrom-model", default="caduceus")
    p.add_argument("--tf-model", default="esm-dbp")
    p.add_argument("--out-dir", default="data_files/processed/tfclust/hg38_tf/embeddings")
    p.add_argument("--device", default="cpu")
    args = p.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)
    device = args.device

    if not args.skip_dna:
        peak_fasta = Path(args.peak_fasta)
        if peak_fasta.exists():
            # Load peak sequences from FASTA
            from Bio import SeqIO

            peak_seqs = []
            peak_ids = []
            for rec in SeqIO.parse(peak_fasta, "fasta"):
                peak_ids.append(rec.id)
                peak_seqs.append(str(rec.seq))
            print(f"Embedding {len(peak_seqs)} binding peak sequences from {peak_fasta}", flush=True)
            dna_embedder = get_embedder(args.chrom_model, device, for_dna=True)
            out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy"
            embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks)
        elif args.genome_json_dir:
            # Legacy: load full chromosomes from JSONs (chr1–22, X, Y, M)
            genome_dir = Path(args.genome_json_dir)
            chrom_seqs, chrom_ids = [], []
            primary_pattern = re.compile(r"^hg38_chr(?:[1-9]|1[0-9]|2[0-2]|X|Y|M)\.json$")
            for j in sorted(genome_dir.iterdir()):
                if not primary_pattern.match(j.name):
                    continue
                data = json.loads(j.read_text())
                seq = data.get("dna") or data.get("sequence")
                chrom = data.get("chrom") or j.stem.split("_")[-1]
                chrom_seqs.append(seq)
                chrom_ids.append(chrom)
            cutoff = CaduceusEmbedder(device).chunk_size
            long_chroms = [
                (chrom, len(seq))
                for chrom, seq in zip(chrom_ids, chrom_seqs)
                if len(seq) > cutoff
            ]
            if long_chroms:
                print("⚠️ Chromosomes exceeding Caduceus max tokens ({}):".format(cutoff))
                for chrom, L in long_chroms:
                    print(f"  {chrom}: {L} bases")
            else:
                print("All chromosomes ≤ Caduceus limit ({}).".format(cutoff))

            chrom_embedder = get_embedder(args.chrom_model, device, for_dna=True)
            out_chrom = Path(args.out_dir) / f"chrom_{args.chrom_model}.npy"
            embed_and_save(chrom_seqs, chrom_ids, chrom_embedder, out_chrom)
        else:
            raise ValueError("No input for DNA embedding: provide a peak FASTA (default binding_peaks_unique.fa) or set --genome-json-dir for chromosome JSONs.")


    # Load TF sequences
    tf_seqs, tf_ids = [], []
    for record in SeqIO.parse(args.tf_fasta, "fasta"):
        tf_ids.append(record.id)
        tf_seqs.append(str(record.seq))

    # embed and save
    tf_embedder = get_embedder(args.tf_model, device, for_dna=False)
    out_tf = Path(args.out_dir) / f"tf_{args.tf_model}.npy"
    embed_and_save(tf_seqs, tf_ids, tf_embedder, out_tf)

    print("Done.")
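embed_and_save writes two different layouts: pooled embedders produce a plain .npy plus a sidecar .ids file, while per-token embedders (Caduceus) produce a padded .npz with an explicit mask. A small sketch of the padded layout that downstream scripts read, using synthetic arrays and a hypothetical file name (not outputs of a real run):

# npz_layout_demo.py — illustrative only; array contents are synthetic
import numpy as np
from compute_embeddings import pad_token_embeddings

per_token = [np.random.randn(5, 256).astype(np.float32),
             np.random.randn(3, 256).astype(np.float32)]
padded, mask = pad_token_embeddings(per_token)
print(padded.shape, mask.shape)   # (2, 5, 256) (2, 5)
print(mask[1])                    # [ True  True  True False False]

np.savez_compressed("peaks_caduceus.caduceus.npz",
                    embeddings=padded, mask=mask,
                    ids=np.array(["peak0", "peak1"], dtype=object))
data = np.load("peaks_caduceus.caduceus.npz", allow_pickle=True)
print(data.files)                 # ['embeddings', 'mask', 'ids']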
dpacman/classifier/model/extract_tf_symbols.py
ADDED
@@ -0,0 +1,27 @@
#!/usr/bin/env python3
import pandas as pd
from pathlib import Path

FINAL_CSV = Path("/home/a03-akrishna/DPACMAN/data_files/processed/final.csv")
OUT_SYMBOLS = Path("tf_symbols.txt")

def normalize_tf(tf_id: str) -> str:
    return tf_id.split("_seq")[0].upper()

def main():
    df = pd.read_csv(FINAL_CSV, dtype=str)
    if "TF_id" not in df.columns:
        raise RuntimeError("final.csv missing TF_id column")
    tf_raw = df["TF_id"].dropna().unique().tolist()
    normalized = sorted({normalize_tf(t) for t in tf_raw})
    print(f"Unique raw TF_id count: {len(tf_raw)}")
    print(f"Unique normalized TF symbols: {len(normalized)}")
    with open(OUT_SYMBOLS, "w") as f:
        for s in normalized:
            f.write(s + "\n")
    print(f"Wrote normalized TF symbols to {OUT_SYMBOLS}")
    # Optional: show sample
    print("Sample symbols:", normalized[:50])

if __name__ == "__main__":
    main()
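The normalization strips any "_seq<N>" suffix and upper-cases the remainder, which is the same rule make_pair_list.py applies to TF_id. A quick check with hypothetical TF_id values (not taken from final.csv):

# behavior of normalize_tf on example inputs
from extract_tf_symbols import normalize_tf
print(normalize_tf("Zbtb5_seq1"))   # ZBTB5
print(normalize_tf("CTCF"))         # CTCF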
dpacman/classifier/model/make_pair_list.py
ADDED
@@ -0,0 +1,220 @@
#!/usr/bin/env python3
import argparse
import numpy as np
import pandas as pd
from pathlib import Path
import random
import sys

def read_ids_file(p):
    p = Path(p)
    if not p.exists():
        raise FileNotFoundError(f"IDs file not found: {p}")
    return [line.strip() for line in p.open() if line.strip()]

def split_embeddings(emb_path, ids_path, out_dir, prefix):
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    if not Path(emb_path).exists():
        raise FileNotFoundError(f"Embedding file not found: {emb_path}")
    if not Path(ids_path).exists():
        raise FileNotFoundError(f"IDs file not found: {ids_path}")

    if emb_path.endswith(".npz"):
        data = np.load(emb_path, allow_pickle=True)
        if "embeddings" in data:
            emb = data["embeddings"]
        else:
            raise ValueError(f"{emb_path} missing 'embeddings' key")
    else:
        emb = np.load(emb_path)

    ids = read_ids_file(ids_path)
    if len(ids) != emb.shape[0]:
        print(f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}", file=sys.stderr)

    mapping = {}
    for i, ident in enumerate(ids):
        if i >= emb.shape[0]:
            print(f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr)
            continue
        arr = emb[i]
        out_file = out_dir / f"{prefix}_{ident}.npy"
        np.save(out_file, arr)
        mapping[ident] = str(out_file)
    return mapping

def extract_symbol_from_tf_id(full_id: str) -> str:
    """
    Given a TF embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN',
    return the gene symbol uppercase (e.g., 'ZBTB5').
    """
    if "|" in full_id:
        try:
            # format sp|Accession|SYMBOL_HUMAN
            genepart = full_id.split("|")[2]
        except IndexError:
            genepart = full_id
    else:
        genepart = full_id
    symbol = genepart.split("_")[0]
    return symbol.upper()

def build_tf_symbol_map(tf_map):
    """
    Build mapping gene_symbol -> list of embedding paths.
    """
    symbol_map = {}
    for full_id, path in tf_map.items():
        symbol = extract_symbol_from_tf_id(full_id)
        symbol_map.setdefault(symbol, []).append(path)
    return symbol_map

def tf_key_from_path(path: str) -> str:
    """
    Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract normalized symbol 'ZBTB5'.
    """
    stem = Path(path).stem  # e.g., tf_sp|O15062|ZBTB5_HUMAN
    # remove leading prefix if present (tf_)
    if "_" in stem:
        _, rest = stem.split("_", 1)
    else:
        rest = stem
    return extract_symbol_from_tf_id(rest)

def dna_key_from_path(path: str) -> str:
    """
    Given .../dna_peak42.npy -> 'peak42'
    """
    stem = Path(path).stem
    if "_" in stem:
        _, rest = stem.split("_", 1)
    else:
        rest = stem
    return rest

def main():
    parser = argparse.ArgumentParser(
        description="Build TF-DNA pair list from final.csv with gene-symbol normalization for TFs."
    )
    parser.add_argument("--final_csv", required=True, help="final.csv with TF_id and dna_sequence")
    parser.add_argument("--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)")
    parser.add_argument("--dna_ids", required=True, help="IDs file for DNA embeddings (e.g., peak*.ids)")
    parser.add_argument("--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)")
    parser.add_argument("--tf_ids", required=True, help="IDs file for TF embeddings (e.g., sp|...|... ids)")
    parser.add_argument("--out_dir", required=True, help="Output directory")
    parser.add_argument("--neg_per_positive", type=int, default=2, help="Negatives per positive (half same-TF, half same-DNA)")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    random.seed(args.seed)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Load final.csv
    df = pd.read_csv(args.final_csv, dtype=str)
    if "TF_id" not in df.columns or "dna_sequence" not in df.columns:
        raise RuntimeError("final.csv must have columns TF_id and dna_sequence")

    # Assign dna_id (unique per dna_sequence)
    unique_seqs = df["dna_sequence"].drop_duplicates().tolist()
    seq_to_id = {seq: f"peak{i}" for i, seq in enumerate(unique_seqs)}
    df["dna_id"] = df["dna_sequence"].map(seq_to_id)
    enriched_csv = out_dir / "final_with_dna_id.csv"
    df.to_csv(enriched_csv, index=False)
    print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}")

    # Split embeddings into per-item files
    print(f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}")
    dna_map = split_embeddings(args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna")
    print(f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})")
    print(f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}")
    tf_map = split_embeddings(args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf")
    print(f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})")

    # Build gene-symbol normalized map
    tf_symbol_map = build_tf_symbol_map(tf_map)
    print(f"[i] TF symbol map keys (sample): {list(tf_symbol_map.keys())[:30]}")

    # Diagnostic overlaps
    norm_tf_in_final = set(t.split("_seq")[0].upper() for t in df["TF_id"].unique())
    available_tf_symbols = set(tf_symbol_map.keys())
    intersect_tf = norm_tf_in_final & available_tf_symbols
    print(f"[i] Unique normalized TF symbols in final.csv: {len(norm_tf_in_final)}")
    print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}")
    print(f"[i] Intersection count: {len(intersect_tf)}")
    if len(intersect_tf) == 0:
        print("[ERROR] No overlap between normalized TF_id and TF embedding symbols.", file=sys.stderr)
        print("Sample normalized TFs from final.csv:", sorted(list(norm_tf_in_final))[:30], file=sys.stderr)
        print("Sample available TF symbols:", sorted(list(available_tf_symbols))[:30], file=sys.stderr)
        sys.exit(1)

    dna_ids_final = set(df["dna_id"].unique())
    available_dna_ids = set(dna_map.keys())
    intersect_dna = dna_ids_final & available_dna_ids
    print(f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. Intersection: {len(intersect_dna)}")
    if len(intersect_dna) == 0:
        print("[ERROR] No overlap on DNA ids.", file=sys.stderr)
        sys.exit(1)

    # Build positive pairs
    positives = []
    for _, row in df.iterrows():
        tf_raw = row["TF_id"]
        tf_symbol = tf_raw.split("_seq")[0].upper()
        dnaid = row["dna_id"]
        if tf_symbol not in tf_symbol_map:
            continue
        if dnaid not in dna_map:
            continue
        # pick the first embedding for that symbol
        tf_embedding_path = tf_symbol_map[tf_symbol][0]
        positives.append((tf_embedding_path, dna_map[dnaid], 1))
    print(f"[i] Constructed {len(positives)} positive pairs after TF symbol resolution")

    if len(positives) == 0:
        print("[ERROR] No positive pairs could be constructed; aborting.", file=sys.stderr)
        sys.exit(1)

    # Build negative samples
    all_tf_symbols = sorted(tf_symbol_map.keys())
    all_dnaids = sorted(dna_map.keys())
    positive_set = set()
    for tf_path, dna_path, _ in positives:
        tf_key = tf_key_from_path(tf_path)
        dna_key = dna_key_from_path(dna_path)
        positive_set.add((tf_key, dna_key))

    negatives = []
    half = args.neg_per_positive // 2
    for tf_path, dna_path, _ in positives:
        tf_key = tf_key_from_path(tf_path)
        dna_key = dna_key_from_path(dna_path)
        # same TF, different DNA
        for _ in range(half):
            candidate_dna = random.choice(all_dnaids)
            if candidate_dna == dna_key or (tf_key, candidate_dna) in positive_set:
                continue
            negatives.append((tf_path, dna_map[candidate_dna], 0))
        # same DNA, different TF
        for _ in range(half):
            candidate_tf_symbol = random.choice(all_tf_symbols)
            if candidate_tf_symbol == tf_key or (candidate_tf_symbol, dna_key) in positive_set:
                continue
            # pick its first embedding and pair it with this positive's DNA embedding
            candidate_tf_path = tf_symbol_map[candidate_tf_symbol][0]
            negatives.append((candidate_tf_path, dna_path, 0))

    print(f"[i] Sampled {len(negatives)} negatives (neg_per_positive={args.neg_per_positive})")

    # Write pair list
    pair_list_path = out_dir / "pair_list.tsv"
    with open(pair_list_path, "w") as f:
        for binder_path, glm_path, label in positives + negatives:
            # binder=TF, glm=DNA
            f.write(f"{binder_path}\t{glm_path}\t{label}\n")
    print(f"[i] Wrote {len(positives)+len(negatives)} examples to {pair_list_path}")

if __name__ == "__main__":
    main()
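For reference, each line of pair_list.tsv is the binder (TF) embedding path, the glm (DNA) embedding path, and a 0/1 label; train.py splits lines on whitespace, so the paths must not contain spaces. A tiny sketch of the expected line shape (the paths below are illustrative, not output of a real run):

# shape of pair_list.tsv consumed by train.py (paths are hypothetical examples)
example_lines = [
    "out/tf_single/tf_sp|O15062|ZBTB5_HUMAN.npy\tout/dna_single/dna_peak42.npy\t1",
    "out/tf_single/tf_sp|O15062|ZBTB5_HUMAN.npy\tout/dna_single/dna_peak7.npy\t0",
]
for line in example_lines:
    binder_path, glm_path, label = line.split("\t")
    assert label in {"0", "1"}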
dpacman/classifier/model/make_peak_fasta.py
ADDED
@@ -0,0 +1,13 @@
import pandas as pd
from pathlib import Path

df = pd.read_csv("/home/a03-akrishna/DPACMAN/data_files/processed/final.csv", dtype=str)  # adjust path if needed
# get unique sequences
uniq = df[["dna_sequence"]].drop_duplicates().reset_index(drop=True)
# make headers: e.g., peak0, peak1, ...
out_fa = Path("binding_peaks_unique.fa")
with open(out_fa, "w") as f:
    for i, seq in enumerate(uniq["dna_sequence"]):
        header = f">peak{i}"
        f.write(f"{header}\n{seq}\n")
print(f"Wrote {len(uniq)} unique binding sequences to {out_fa}")
dpacman/classifier/model/model.py
ADDED
@@ -0,0 +1,103 @@
import torch
from torch import nn

class LocalCNN(nn.Module):
    def __init__(self, dim: int = 256, kernel_size: int = 3):
        super().__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding)
        self.act = nn.GELU()
        self.ln = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor):
        # x: (batch, L, dim)
        out = self.conv(x.transpose(1, 2))  # → (batch, dim, L)
        out = self.act(out)
        out = out.transpose(1, 2)  # → (batch, L, dim)
        return self.ln(out + x)  # residual

class CrossModalBlock(nn.Module):
    def __init__(self, dim: int = 256, heads: int = 8):
        super().__init__()
        # self-attention for both sides
        self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ln_b1 = nn.LayerNorm(dim)
        self.ln_g1 = nn.LayerNorm(dim)

        self.ffn_b = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
        self.ffn_g = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
        self.ln_b2 = nn.LayerNorm(dim)
        self.ln_g2 = nn.LayerNorm(dim)

        # cross attention (binder queries, glm keys/values)
        self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ln_c1 = nn.LayerNorm(dim)
        self.ffn_c = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
        self.ln_c2 = nn.LayerNorm(dim)

    def forward(self, binder: torch.Tensor, glm: torch.Tensor):
        """
        binder: (batch, Lb, dim)
        glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
        returns: updated binder representation (batch, Lb, dim)
        """
        # binder self-attn + ffn
        b = binder
        b_sa, _ = self.sa_binder(b, b, b)
        b = self.ln_b1(b + b_sa)
        b_ff = self.ffn_b(b)
        b = self.ln_b2(b + b_ff)

        # glm self-attn + ffn
        g = glm
        g_sa, _ = self.sa_glm(g, g, g)
        g = self.ln_g1(g + g_sa)
        g_ff = self.ffn_g(g)
        g = self.ln_g2(g + g_ff)

        # cross-attention: binder queries glm
        c_sa, _ = self.cross_attn(b, g, g)
        c = self.ln_c1(b + c_sa)
        c_ff = self.ffn_c(c)
        c = self.ln_c2(c + c_ff)
        return c  # (batch, Lb, dim)

class BindPredictor(nn.Module):
    def __init__(self,
                 input_dim: int = 256,
                 hidden_dim: int = 256,
                 heads: int = 8,
                 num_layers: int = 4,
                 use_local_cnn_on_glm: bool = True):
        super().__init__()
        self.proj_binder = nn.Linear(input_dim, hidden_dim)
        self.proj_glm = nn.Linear(input_dim, hidden_dim)
        self.use_local_cnn = use_local_cnn_on_glm
        self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity()

        self.layers = nn.ModuleList([
            CrossModalBlock(hidden_dim, heads) for _ in range(num_layers)
        ])

        self.ln_out = nn.LayerNorm(hidden_dim)
        self.head = nn.Sequential(
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, binder_emb, glm_emb):
        """
        binder_emb, glm_emb: (batch, L, input_dim)
        """
        b = self.proj_binder(binder_emb)  # (B, Lb, hidden_dim)
        g = self.proj_glm(glm_emb)        # (B, Lg, hidden_dim)
        if self.use_local_cnn:
            g = self.local_cnn(g)  # local context injected

        for layer in self.layers:
            b = layer(b, g)  # update binder with cross-modal info

        pooled = b.mean(dim=1)  # (B, hidden_dim)
        out = self.ln_out(pooled)
        return self.head(out).squeeze(-1)  # (B,)
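A minimal shape check for the cross-attention stack, using dummy tensors (num_layers=2 and the 10/500 token lengths are chosen only to keep the sketch fast; they are not the training configuration):

# shape_check.py — sanity sketch with random data
import torch
from model import BindPredictor

model = BindPredictor(input_dim=256, hidden_dim=256, heads=8, num_layers=2)
binder = torch.randn(4, 10, 256)   # 4 TF embeddings, 10 "tokens" each
glm = torch.randn(4, 500, 256)     # 4 DNA peaks, 500 token embeddings each
with torch.no_grad():
    probs = model(binder, glm)
print(probs.shape)                 # torch.Size([4]) — one binding probability per pair
assert ((probs >= 0) & (probs <= 1)).all()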
dpacman/classifier/model/train.py
ADDED
@@ -0,0 +1,262 @@
#!/usr/bin/env python3
import argparse
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from model import BindPredictor
from pathlib import Path
from collections import Counter
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.decomposition import TruncatedSVD
import random
import sys

# ---- dataset ---------------------------------------------------------
class PairDataset(Dataset):
    def __init__(self, binder_paths, glm_paths, labels, tf_compressed_cache):
        """
        tf_compressed_cache: dict mapping binder_path -> compressed (256-d) tensor/array
        """
        assert len(binder_paths) == len(glm_paths) == len(labels)
        self.binder_paths = binder_paths
        self.glm_paths = glm_paths
        self.labels = labels
        self.tf_cache = tf_compressed_cache  # already reduced to 256

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # binder = TF embedding (possibly reduced)
        b = self.tf_cache[self.binder_paths[idx]]  # numpy array shape (L, 256) or (256,)
        g = np.load(self.glm_paths[idx])           # glm (DNA) embedding

        if b.ndim == 1:
            b = b[None, :]
        if g.ndim == 1:
            g = g[None, :]

        b_tensor = torch.from_numpy(b).float()
        g_tensor = torch.from_numpy(g).float()
        y = torch.tensor(self.labels[idx]).float()
        return b_tensor, g_tensor, y

def collate_fn(batch):
    binders, glms, labels = zip(*batch)
    binder_lens = [b.shape[0] for b in binders]
    glm_lens = [g.shape[0] for g in glms]
    max_b = max(binder_lens)
    max_g = max(glm_lens)

    def pad_seq(seq, target_len):
        L, D = seq.shape
        if L < target_len:
            pad = torch.zeros((target_len - L, D), dtype=seq.dtype, device=seq.device)
            return torch.cat([seq, pad], dim=0)
        return seq

    b_padded = torch.stack([pad_seq(b, max_b) for b in binders])  # (B, Lb, D)
    g_padded = torch.stack([pad_seq(g, max_g) for g in glms])     # (B, Lg, D)
    y = torch.stack(labels)
    return b_padded, g_padded, y

# ---- utilities -------------------------------------------------------
def parse_pair_list(pair_list_path):
    binder_paths, glm_paths, labels = [], [], []
    with open(pair_list_path) as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            parts = line.strip().split()
            if len(parts) != 3:
                print(f"[WARN] skipping malformed line {lineno}: {line.strip()}", file=sys.stderr)
                continue
            b, g, l = parts
            try:
                lab = int(l)
            except ValueError:
                print(f"[WARN] invalid label on line {lineno}: {l}", file=sys.stderr)
                continue
            binder_paths.append(b)
            glm_paths.append(g)
            labels.append(lab)
    return binder_paths, glm_paths, labels

def build_tf_compressed_cache(binder_paths, target_dim=256):
    """
    Load all unique TF (binder) embeddings, fit reduction if needed, and return dict mapping path->(L, target_dim) array.
    """
    unique_paths = sorted(set(binder_paths))
    print(f"[i] Found {len(unique_paths)} unique TF embedding files to compress.", flush=True)
    # Load all embeddings to determine dimensionality
    samples = []
    for p in unique_paths:
        arr = np.load(p)
        samples.append(arr)
    # Determine if reduction needed: assume all have same embedding width
    first = samples[0]
    orig_dim = first.shape[1] if first.ndim == 2 else first.shape[0]
    reduction_needed = (orig_dim != target_dim)
    tf_cache = {}

    if reduction_needed:
        # Build matrix to fit SVD: we need a 2D matrix per embedding; if lengths vary we can't directly stack.
        # We'll do reduction per sequence individually using TruncatedSVD on concatenated flattened features:
        # Simplest: for variable lengths, reduce each embedding separately with a learned linear projection.
        # Here we fit a single TruncatedSVD on the concatenation of all sequence tokens (flattened) by padding/truncating to a fixed length.
        # To avoid complexity, use PCA-like linear projection learned via SVD on mean-pooled vectors:
        pooled = []
        for arr in samples:
            if arr.ndim == 2:
                pooled.append(arr.mean(axis=0))  # (orig_dim,)
            else:
                pooled.append(arr)  # degenerate
        pooled_mat = np.stack(pooled, axis=0)  # (N, orig_dim)
        print(f"[i] Fitting TruncatedSVD on TF pooled embeddings: {pooled_mat.shape} -> {target_dim}", flush=True)
        svd = TruncatedSVD(n_components=target_dim, random_state=42)
        reduced_pooled = svd.fit_transform(pooled_mat)  # (N, target_dim)

        # For each original embedding, project token-level vectors by multiplying token vector with svd.components_.T
        # svd.components_: (target_dim, orig_dim) so projection matrix is (orig_dim, target_dim)
        proj_mat = svd.components_.T  # (orig_dim, target_dim)
        for i, p in enumerate(unique_paths):
            arr = samples[i]  # shape (L, orig_dim)
            if arr.ndim == 1:
                arr2 = arr @ proj_mat  # (target_dim,)
            else:
                # project each token: (L, orig_dim) @ (orig_dim, target_dim) -> (L, target_dim)
                arr2 = arr @ proj_mat
            tf_cache[p] = arr2  # reduced per-token representation
        print("[i] Completed compression of TF embeddings.", flush=True)
    else:
        # already correct dim: just cache originals
        print(f"[i] TF embeddings already {target_dim}-dimensional; skipping reduction.", flush=True)
        for i, p in enumerate(unique_paths):
            arr = samples[i]
            tf_cache[p] = arr
    return tf_cache

def evaluate(model, dl, device):
    model.eval()
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for b, g, y in dl:
            b = b.to(device)
            g = g.to(device)
            y = y.to(device)
            pred = model(b, g)
            all_labels.append(y.cpu())
            all_preds.append(pred.cpu())
    if not all_labels:
        return 0.0, 0.0
    y_true = torch.cat(all_labels).numpy()
    y_score = torch.cat(all_preds).numpy()
    try:
        auc = roc_auc_score(y_true, y_score)
    except Exception:
        auc = 0.0
    try:
        ap = average_precision_score(y_true, y_score)
    except Exception:
        ap = 0.0
    return auc, ap

# ---- main ------------------------------------------------------------
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--pair_list", type=str, required=True,
                        help="TSV: binder_path glm_path label")
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    print("DEBUG: starting training script with in-line TF compression", flush=True)
    print(f"[i] pair_list: {args.pair_list}", flush=True)
    print(f"[i] output dir: {args.out_dir}", flush=True)
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    binder_paths, glm_paths, labels = parse_pair_list(args.pair_list)

    if len(labels) == 0:
        print("[ERROR] No valid pairs parsed. Exiting.", file=sys.stderr)
        sys.exit(1)

    label_counts = Counter(labels)
    print(f"[i] Total examples parsed: {len(labels)}. Label distribution: {label_counts}", flush=True)

    # build compressed TF cache (reduces to 256 if needed)
    tf_compressed_cache = build_tf_compressed_cache(binder_paths, target_dim=256)

    # simple split: 80/10/10
    n = len(labels)
    idxs = np.arange(n)
    np.random.shuffle(idxs)
    train_i = idxs[: int(0.8 * n)]
    val_i = idxs[int(0.8 * n): int(0.9 * n)]
    test_i = idxs[int(0.9 * n):]

    def subset(idxs):
        return [binder_paths[i] for i in idxs], [glm_paths[i] for i in idxs], [labels[i] for i in idxs]

    train_ds = PairDataset(*subset(train_i), tf_compressed_cache=tf_compressed_cache)
    val_ds = PairDataset(*subset(val_i), tf_compressed_cache=tf_compressed_cache)
    test_ds = PairDataset(*subset(test_i), tf_compressed_cache=tf_compressed_cache)

    print(f"[i] Train/Val/Test sizes: {len(train_ds)}/{len(val_ds)}/{len(test_ds)}", flush=True)
    if len(train_ds) == 0 or len(val_ds) == 0:
        print("[ERROR] Train or validation split is empty; cannot proceed.", file=sys.stderr)
        sys.exit(1)

    train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
    val_dl = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
    test_dl = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)

    model = BindPredictor(input_dim=256, hidden_dim=256, heads=8, num_layers=3, use_local_cnn_on_glm=True)
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-3)
    loss_fn = nn.BCELoss()

    best_val = -float("inf")
    os_out = Path(args.out_dir)
    os_out.mkdir(exist_ok=True, parents=True)

    for epoch in range(1, args.epochs + 1):
        print(f"[Epoch {epoch}] starting...", flush=True)
        model.train()
        running_loss = 0.0
        for b, g, y in train_dl:
            b = b.to(device)
            g = g.to(device)
            y = y.to(device)
            pred = model(b, g)
            loss = loss_fn(pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * b.size(0)
        train_loss = running_loss / len(train_ds)
        val_auc, val_ap = evaluate(model, val_dl, device)
        print(f"[Epoch {epoch}] train_loss={train_loss:.4f} val_auc={val_auc:.4f} val_ap={val_ap:.4f}", flush=True)

        if val_auc > best_val:
            best_val = val_auc
            torch.save(model.state_dict(), os_out / "best_model.pt")
            print(f"[Epoch {epoch}] Saved new best model with val_auc={val_auc:.4f}", flush=True)

    torch.save(model.state_dict(), os_out / "last_model.pt")
    test_auc, test_ap = evaluate(model, test_dl, device)
    print(f"FINAL TEST: AUC={test_auc:.4f} AP={test_ap:.4f}", flush=True)
    print(f"[i] Models written to {os_out}/best_model.pt and last_model.pt", flush=True)

if __name__ == "__main__":
    main()
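The TF compression step fits TruncatedSVD on mean-pooled TF vectors and then reuses svd.components_ to project every token vector into the reduced space. A small self-contained sketch of that pooled-fit / token-projection trick with synthetic 1280-d data (n_components=2 here only because the toy set has three samples; the real script uses target_dim=256, which presumes enough distinct TF embeddings to support the fit):

# svd_projection_demo.py — synthetic illustration of the reduction used in build_tf_compressed_cache
import numpy as np
from sklearn.decomposition import TruncatedSVD

rng = np.random.default_rng(0)
token_embs = [rng.normal(size=(L, 1280)).astype(np.float32) for L in (50, 80, 65)]

pooled = np.stack([e.mean(axis=0) for e in token_embs])   # (3, 1280) mean-pooled vectors
svd = TruncatedSVD(n_components=2, random_state=42)
svd.fit(pooled)
proj = svd.components_.T                                   # (1280, 2) projection matrix

reduced = [e @ proj for e in token_embs]                    # each (L_i, 2) per-token projection
print([r.shape for r in reduced])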