| """ |
| Plug-and-play embedding extraction for: |
| • Chromosome sequences (from raw UCSC JSON) |
| • TF sequences (transcription_factors.fasta) |
| |
| Usage example (DNA + protein in one go): |
| module load miniconda/24.7.1 |
| conda activate dpacman |
| python dpacman/data/compute_embeddings.py \ |
| --genome-json-dir ../data_files/raw/genomes/hg38 \ |
| --tf-fasta ../data_files/processed/tfclust/hg38_tf/transcription_factors.fasta \ |
| --chrom-model caduceus \ |
| --tf-model esm-dbp \ |
| --out-dir ../data_files/processed/tfclust/hg38_tf/embeddings \ |
| --device cuda |
| """ |
|
|
| import numpy as np |
| from pathlib import Path |
| import torch |
| from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, pipeline |
| import time |
| import esm |
| from tqdm.auto import tqdm |
| from sklearn.preprocessing import OneHotEncoder |
| import math |
| import rootutils |
| from dpacman.utils import pylogger |
| from tqdm import trange |
| from tqdm.contrib.logging import logging_redirect_tqdm |
|
|
| root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) |
| logger = pylogger.RankedLogger(__name__, rank_zero_only=True) |
| |
|
|
|
|
| class CaduceusEmbedder: |
| def __init__(self, device, chunk_size=131_072, overlap=0): |
| """ |
| device: 'cpu' or 'cuda' |
| chunk_size: max bases (and thus tokens) to send in one forward pass |
| overlap: how many bases each window overlaps the previous; 0 = no overlap |
| """ |
| model_name = "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16" |
| self.tokenizer = AutoTokenizer.from_pretrained( |
| model_name, trust_remote_code=True |
| ) |
| self.model = ( |
| AutoModelForMaskedLM.from_pretrained(model_name, trust_remote_code=True) |
| .to(device) |
| .eval() |
| ) |
| self.device = device |
| self.chunk_size = chunk_size |
| self.step = chunk_size - overlap |
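        # stride for sliding-window chunking (embed() below currently truncates each
        # sequence to a single chunk_size window instead of sliding)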
|
|
| def embed(self, seqs, batch_size=1, pooling=False): |
| """ |
| seqs: List[str] of DNA sequences (each <= chunk_size for this test) |
| returns: np.ndarray of shape (N, L, D), raw per‐token embeddings |
| """ |
| n = len(seqs) |
| if n == 0: |
| return {} |
|
|
| |
        max_len = max(len(s) for s in seqs)
        logger.info(
            f"Max sequence length: {max_len:,} (truncated to {self.chunk_size:,} tokens)"
        )
| |
| outputs = {} |
|
|
| with logging_redirect_tqdm(): |
            for i in range(0, n, batch_size):
                batch_seqs = seqs[i : i + batch_size]
                logger.info(
                    f"Embedding batch {i // batch_size + 1}/{math.ceil(n / batch_size)}"
                )

                for seq in tqdm(
                    batch_seqs,
                    total=len(batch_seqs),
                    desc="DNA: Caduceus",
                    dynamic_ncols=True,
                ):
| toks = self.tokenizer( |
| seq, |
| return_tensors="pt", |
| padding=False, |
| truncation=True, |
| max_length=self.chunk_size |
| ).to(self.device) |
                    with torch.no_grad():
                        out = self.model(**toks).last_hidden_state
                    # squeeze the batch dim and drop the trailing special-token position
                    outputs[seq] = out.cpu().numpy().squeeze(0)[:-1, :]
| return outputs |
| |
| def benchmark(self, lengths=None): |
| """ |
| Time embedding on single-sequence of various lengths. |
| By default tests [5K,10K,50K,100K,chunk_size]. |
| """ |
| tests = lengths or [5_000, 10_000, 50_000, 100_000, self.chunk_size] |
| print(f"→ Benchmarking Caduceus on device={self.device}") |
| for sz in tests: |
| seq = "A" * sz |
| |
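            # warm-up pass so one-time CUDA/model initialisation is excluded from the timing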
| _ = self.embed([seq]) |
| if self.device != "cpu": |
| torch.cuda.synchronize() |
| t0 = time.perf_counter() |
| _ = self.embed([seq]) |
| if self.device != "cpu": |
| torch.cuda.synchronize() |
| t1 = time.perf_counter() |
| print(f" length={sz:6,d} time={(t1-t0)*1000:7.1f} ms") |
|
|
|
|
| class SegmentNTEmbedder: |
| def __init__(self, device): |
| self.tokenizer = AutoTokenizer.from_pretrained( |
| "InstaDeepAI/segment_nt", trust_remote_code=True |
| ) |
| self.model = ( |
| AutoModel.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True) |
| .to(device) |
| .eval() |
| ) |
| self.device = device |
|
|
| def _adjust_length(self, input_ids): |
| """ |
| Pads the length so it's divisible by 4; this is needed to get through the BPNet |
| """ |
| bs, L = input_ids.shape |
| excl = L - 1 |
| remainder = (excl) % 4 |
| if remainder != 0: |
| pad_needed = 4 - remainder |
| pad_tensor = torch.full( |
| (bs, pad_needed), |
| self.tokenizer.pad_token_id, |
| dtype=input_ids.dtype, |
| device=input_ids.device, |
| ) |
| input_ids = torch.cat([input_ids, pad_tensor], dim=1) |
| return input_ids |
|
|
| def embed(self, seqs, batch_size=1, log_every_pct=5, pooling=False): |
| """ |
| seqs: List[str] |
| Returns: Dict[str, np.ndarray] |
| - pooling=True -> {seq: (D,)} |
| - pooling=False -> {seq: (L-1, D)} (excludes CLS, retains padding/truncation) |
| """ |
| n = len(seqs) |
| if n == 0: |
| return {} |
|
|
| |
| steps = list(range(log_every_pct, 101, log_every_pct)) |
| checkpoints = [max(1, math.ceil(n * p / 100)) for p in steps] |
| ck_idx = 0 |
| processed = 0 |
|
|
| |
        try:
            max_len = max(len(s) for s in seqs)
            msg = (
                f"Max sequence length: {max_len:,} "
                f"(padded/truncated to the tokenizer max_length)"
            )
            (logger.info if logger is not None else print)(msg)
        except Exception:
            pass
|
|
| out = {} |
|
|
| for i in range(0, n, batch_size): |
| batch_seqs = seqs[i : i + batch_size] |
|
|
| encoded = self.tokenizer.batch_encode_plus( |
| batch_seqs, |
| return_tensors="pt", |
| padding=True, |
| truncation=True, |
| max_length=1998, |
| ) |
|
|
| orig_len = encoded["input_ids"].shape[1] |
|
|
| input_ids = encoded["input_ids"].to(self.device) |
| logger.info(f"input_ids.shape: {input_ids.shape}") |
| |
| input_ids = self._adjust_length(input_ids) |
| logger.info(f"after adjusting length: input_ids.shape: {input_ids.shape}") |
| attention_mask = input_ids != self.tokenizer.pad_token_id |
|
|
| with torch.no_grad(): |
| outs = self.model( |
| input_ids, |
| attention_mask=attention_mask, |
| output_hidden_states=True, |
| return_dict=True, |
| ) |
|
|
| last_hidden = ( |
| outs.hidden_states[-1] |
| if getattr(outs, "hidden_states", None) is not None |
| else outs.last_hidden_state |
| ) |
| logger.info(f"last_hidden.shape: {last_hidden.shape}") |
|
|
| |
            # drop the CLS token and any positions beyond the original (pre-pad) length
            last_hidden = last_hidden[:, 1:orig_len, :]
| logger.info( |
| f"after cutting first position: last_hidden.shape: {last_hidden.shape}" |
| ) |
|
|
| if pooling: |
| |
| pooled = last_hidden.mean(dim=1) |
| pooled_np = pooled.detach().cpu().numpy() |
| for j, s in enumerate(batch_seqs): |
| out[s] = pooled_np[j] |
| else: |
| |
| emb_np = last_hidden.detach().cpu().numpy() |
| for j, s in enumerate(batch_seqs): |
| out[s] = emb_np[j] |
|
|
| processed += len(batch_seqs) |
|
|
| |
| while ck_idx < len(checkpoints) and processed >= checkpoints[ck_idx]: |
| pct = steps[ck_idx] |
| msg = f"[embed] {processed}/{n} ({pct}%)" |
| try: |
| (logger.info if logger is not None else print)(msg) |
| except Exception: |
| print(msg, flush=True) |
| ck_idx += 1 |
|
|
| |
| try: |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
| except Exception: |
| pass |
|
|
| return out |
|
|
|
|
| class DNABertEmbedder: |
| def __init__(self, device): |
| self.tokenizer = AutoTokenizer.from_pretrained( |
| "zhihan1996/DNA_bert_6", trust_remote_code=True |
| ) |
| self.model = AutoModel.from_pretrained( |
| "zhihan1996/DNA_bert_6", trust_remote_code=True |
| ).to(device) |
| self.device = device |
|
|
    def embed(self, seqs, batch_size=1):
        """
        seqs: List[str]
        Returns: (N, D) array of mean-pooled last-layer embeddings.
        NOTE: DNA_bert_6 expects space-separated 6-mer input, so raw sequences
        may tokenize largely to [UNK].
        """
        embs = []
        for s in seqs:
            tokens = self.tokenizer(s, return_tensors="pt", padding=True)[
                "input_ids"
            ].to(self.device)
| with torch.no_grad(): |
| out = self.model(tokens).last_hidden_state.mean(1) |
| embs.append(out.cpu().numpy()) |
| return np.vstack(embs) |
|
|
|
|
| class OneHotEmbedder: |
| """ |
| Simple one-hot encoder as a baseline |
| """ |
|
|
| def __init__(self, device=None): |
| self.nucleotides = [list("ACTGN")] |
| self.model = OneHotEncoder(categories=self.nucleotides, dtype=int) |
|
|
    def embed(self, seqs, batch_size=1):
        """
        seqs: List[str]
        Returns: Dict[str, np.ndarray]; each value is an (L, 5) one-hot matrix
        over the alphabet A, C, T, G, N.
        """
        out = {}
        for s in seqs:
            # one column per base, one-hot encoded over the fixed A/C/T/G/N alphabet
            tokens = np.array(list(s)).reshape(-1, 1)
            embedding = self.model.fit_transform(tokens).toarray()
| out[s] = embedding |
| return out |
|
|
|
|
| class NucleotideTransformerEmbedder: |
| def __init__(self, device): |
| |
| |
| self.pipe = pipeline( |
| "feature-extraction", |
| model="InstaDeepAI/nucleotide-transformer-500m-1000g", |
| device=( |
| -1 if device == "cpu" else 0 |
| ), |
| ) |
|
|
| def embed(self, seqs, batch_size=1): |
| """ |
| seqs: List[str] of raw DNA sequences |
| returns: (N, D) array, one D-dim vector per sequence |
| """ |
        all_embeddings = self.pipe(seqs, truncation=True, padding=True)

        # the feature-extraction pipeline returns a nested list per sequence,
        # typically shaped (1, L, D); drop the batch dim, then mean-pool over tokens
        pooled = [np.asarray(x).squeeze(0).mean(axis=0) for x in all_embeddings]
        return np.vstack(pooled)
|
|
|
|
| class ESMEmbedder: |
| def __init__(self, device, model_name="esm2_t33_650M_UR50D"): |
| |
| self.device = device |
| try: |
| self.model, self.alphabet = getattr(esm.pretrained, model_name)() |
| self.is_esm2 = model_name.lower().startswith("esm2") |
| except AttributeError: |
| |
| self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S() |
| self.is_esm2 = False |
| self.batch_converter = self.alphabet.get_batch_converter() |
| self.model.to(device).eval() |
| |
        # ESM-1b is limited to 1024 positions; use a longer window for ESM-2
        self.max_len = 4096 if self.is_esm2 else 1024
        # reserve two positions for BOS/EOS and overlap adjacent windows by 25%
        self.chunk_size = self.max_len - 2
        self.overlap = self.chunk_size // 4
|
|
| def _chunk_sequence(self, seq): |
| """ |
| Return list of possibly overlapping chunks of seq, each <= chunk_size. |
| """ |
| if len(seq) <= self.chunk_size: |
| return [seq] |
| logger.info(f"Calling chunk sequence") |
| step = self.chunk_size - self.overlap |
| chunks = [] |
| for i in range(0, len(seq), step): |
| chunk = seq[i : i + self.chunk_size] |
| if not chunk: |
| break |
| chunks.append(chunk) |
| return chunks |
|
|
| def embed(self, seqs, batch_size=1, avg=False): |
| """ |
| seqs: List[str] of protein sequences. |
| Returns: np.ndarray of: shape (N, D) pooled per-sequence embeddings if avg true; shape (N, L, D) otherwise |
| """ |
| all_embeddings = {} |
| for i, seq in enumerate(seqs): |
| chunks = self._chunk_sequence(seq) |
| chunk_vecs = [] |
| |
| for chunk in chunks: |
| batch = [(str(i), chunk)] |
| _, _, toks = self.batch_converter(batch) |
| toks = toks.to(self.device) |
| with torch.no_grad(): |
| results = self.model(toks, repr_layers=[33], return_contacts=False) |
| reps = results["representations"][33] |
| |
                # drop the BOS/EOS positions when present so only residues remain
                rep = reps[:, 1:-1] if reps.size(1) > 2 else reps
                if avg:
                    rep = rep.mean(1)
                chunk_vecs.append(rep.squeeze(0))
| |
            if len(chunk_vecs) == 1:
                seq_vec = chunk_vecs[0]
            else:
                # average across chunks; torch.stack requires equal shapes, so
                # per-residue chunks (avg=False) must all be the same length
                stacked = torch.stack(chunk_vecs, dim=0)
                seq_vec = stacked.mean(0)
| all_embeddings[seq] = seq_vec.cpu().numpy() |
| return all_embeddings |
|
|
|
|
| class ESMDBPEmbedder: |
| def __init__(self, device): |
| base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S() |
| model_path = ( |
| Path(__file__).resolve().parent.parent |
| / "pretrained" |
| / "ESM-DBP" |
| / "ESM-DBP.model" |
| ) |
| checkpoint = torch.load(model_path, map_location="cpu") |
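        # checkpoints saved from torch.nn.DataParallel prefix parameter names with
        # "module."; strip it so the keys line up with the plain ESM-1b module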
| clean_sd = {} |
| for k, v in checkpoint.items(): |
| clean_sd[k.replace("module.", "")] = v |
| result = base_model.load_state_dict(clean_sd, strict=False) |
| if result.missing_keys: |
| print(f"[ESMDBP] missing keys: {result.missing_keys}") |
| if result.unexpected_keys: |
| print(f"[ESMDBP] unexpected keys: {result.unexpected_keys}") |
|
|
| self.model = base_model.to(device).eval() |
| self.alphabet = alphabet |
| self.batch_converter = alphabet.get_batch_converter() |
| self.device = device |
| self.max_len = 1024 |
| self.chunk_size = self.max_len - 2 |
| self.overlap = self.chunk_size // 4 |
|
|
| def _chunk_sequence(self, seq): |
| if len(seq) <= self.chunk_size: |
| return [seq] |
| step = self.chunk_size - self.overlap |
| chunks = [] |
| for i in range(0, len(seq), step): |
| chunk = seq[i : i + self.chunk_size] |
| if not chunk: |
| break |
| chunks.append(chunk) |
| return chunks |
|
|
| def embed(self, seqs, batch_size=1): |
| all_embeddings = [] |
| for i, seq in enumerate(seqs): |
| chunks = self._chunk_sequence(seq) |
| chunk_vecs = [] |
| for chunk in chunks: |
| batch = [(str(i), chunk)] |
| _, _, toks = self.batch_converter(batch) |
| toks = toks.to(self.device) |
| with torch.no_grad(): |
| out = self.model(toks, repr_layers=[33], return_contacts=False) |
| reps = out["representations"][33] |
| if reps.size(1) > 2: |
| rep = reps[:, 1:-1].mean(1) |
| else: |
| rep = reps.mean(1) |
| chunk_vecs.append(rep.squeeze(0)) |
| if len(chunk_vecs) == 1: |
| seq_vec = chunk_vecs[0] |
| else: |
| stacked = torch.stack(chunk_vecs, dim=0) |
| seq_vec = stacked.mean(0) |
| all_embeddings.append(seq_vec.cpu().numpy()) |
| return np.vstack(all_embeddings) |
|
|
|
|
| class GPNEmbedder: |
| def __init__(self, device): |
| model_name = "songlab/gpn-msa-sapiens" |
| self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
| self.model = AutoModelForMaskedLM.from_pretrained(model_name) |
| self.model.to(device) |
| self.model.eval() |
| self.device = device |
|
|
    def embed(self, seqs, batch_size=1):
        """
        seqs: List[str]
        Returns: (N, D) array of mean-pooled final-layer embeddings.
        """
        inputs = self.tokenizer(
            seqs, return_tensors="pt", padding=True, truncation=True
        ).to(self.device)

        with torch.no_grad():
            # masked-LM outputs expose no last_hidden_state attribute; request the
            # hidden states and take the final layer instead
            outputs = self.model(**inputs, output_hidden_states=True)
            last_hidden = outputs.hidden_states[-1]
        return last_hidden.mean(dim=1).cpu().numpy()
|
|
|
|
| class ProGenEmbedder: |
    def __init__(self, device):
        model_name = "jinyuan22/ProGen2-base"
        # ProGen2 uses custom modeling code on the Hub (assumed to need trust_remote_code)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device).eval()
        # fall back to EOS as pad token so padding=True works for batched inputs
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.device = device
|
|
| def embed(self, seqs, batch_size=1): |
| inputs = self.tokenizer( |
| seqs, return_tensors="pt", padding=True, truncation=True |
| ).to(self.device) |
| with torch.no_grad(): |
| last_hidden = self.model(**inputs).last_hidden_state |
| return last_hidden.mean(dim=1).cpu().numpy() |
|
|