ananyakrishna committed
Commit e3102c9 · 1 Parent(s): b44075a

preliminary plug + play

Files changed (1)
  1. dpacman/data/compute_embeddings.py +243 -0
dpacman/data/compute_embeddings.py ADDED
@@ -0,0 +1,243 @@
+ """
+ Plug-and-play embedding extraction for:
+   • Chromosome sequences (from raw UCSC JSON)
+   • TF sequences (transcription_factors.fasta)
+
+ Usage example (DNA + protein in one go):
+   module load miniconda/24.7.1
+   conda activate dpacman
+   python dpacman/data/compute_embeddings.py \
+     --genome-json-dir ../data_files/raw/genomes/hg38 \
+     --tf-fasta ../data_files/processed/tfclust/hg38_tf/transcription_factors.fasta \
+     --chrom-model caduceus \
+     --tf-model esm-dbp \
+     --out-dir ../data_files/processed/tfclust/hg38_tf/embeddings \
+     --device cuda
+ """
+ import os
+ import re
+ import argparse
+ import json
+ import numpy as np
+ from pathlib import Path
+ import torch
+ from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, pipeline
+ import esm
+
+ # ---- model wrappers ----
+
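+ # Each wrapper below exposes embed(seqs: list[str]) -> np.ndarray of shape (N, D),
+ # one mean-pooled vector per input sequence.
+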
+ class CaduceusEmbedder:
+     def __init__(self, device, chunk_size=131_072, overlap=0):
+         """
+         device: 'cpu' or 'cuda'
+         chunk_size: max bases (and thus tokens) to send in one forward pass
+         overlap: how many bases each window overlaps the previous; 0 = no overlap
+         """
+         model_name = "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16"
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             model_name, trust_remote_code=True
+         )
+         self.model = AutoModel.from_pretrained(
+             model_name, trust_remote_code=True
+         ).to(device).eval()
+         self.device = device
+         self.chunk_size = chunk_size
+         self.step = chunk_size - overlap
+
+     def embed(self, seqs):
+         all_embs = []
+         for seq in seqs:
+             window_vecs = []
+             # slide windows of up to chunk_size bases
+             for i in range(0, len(seq), self.step):
+                 chunk = seq[i : i + self.chunk_size]
+                 if not chunk:
+                     break
+                 # enforce truncation so tokens <= chunk_size
+                 toks = self.tokenizer(
+                     chunk,
+                     return_tensors="pt",
+                     padding=False,
+                     truncation=True,
+                     max_length=self.chunk_size,
+                 ).to(self.device)
+                 with torch.no_grad():
+                     out = self.model(**toks).last_hidden_state
+                 # mean-pool tokens → (D,)
+                 window_vecs.append(out.mean(dim=1).squeeze(0).cpu())
+             # average over windows → one (D,) vector per full sequence
+             seq_emb = torch.stack(window_vecs, dim=0).mean(dim=0).numpy()
+             all_embs.append(seq_emb)
+         return np.vstack(all_embs)  # shape (N, D)
+
+ class DNABertEmbedder:
+     def __init__(self, device):
+         self.tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True)
+         self.model = AutoModel.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True).to(device).eval()
+         self.device = device
+
+     def embed(self, seqs):
+         embs = []
+         for s in seqs:
+             # DNABERT-6 expects space-separated 6-mers, and BERT caps input at 512 tokens
+             kmers = " ".join(s[i : i + 6] for i in range(len(s) - 5))
+             tokens = self.tokenizer(kmers, return_tensors="pt", truncation=True, max_length=512)["input_ids"].to(self.device)
+             with torch.no_grad():
+                 out = self.model(tokens).last_hidden_state.mean(1)
+             embs.append(out.cpu().numpy())
+         return np.vstack(embs)
+
+ class NucleotideTransformerEmbedder:
+     def __init__(self, device):
+         # HF "feature-extraction" returns token-level arrays for each input
+         # device: "cpu" or "cuda"
+         self.pipe = pipeline(
+             "feature-extraction",
+             model="InstaDeepAI/nucleotide-transformer-500m-1000g",
+             device=-1 if device == "cpu" else 0,  # HF uses -1 for CPU, 0 for GPU
+         )
+
+     def embed(self, seqs):
+         """
+         seqs: List[str] of raw DNA sequences
+         returns: (N, D) array, one D-dim vector per sequence
+         """
+         all_embeddings = self.pipe(seqs, truncation=True, padding=True)
+         # each result is nested as (1, L, D): drop the batch dim, then mean-pool tokens
+         pooled = [np.asarray(x)[0].mean(axis=0) for x in all_embeddings]
+         return np.vstack(pooled)
+
+ class ESMEmbedder:
+     def __init__(self, device):
+         self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
+         self.batch_converter = self.alphabet.get_batch_converter()
+         self.model.to(device).eval()
+         self.device = device
+
+     def embed(self, seqs):
+         # ESM-1b accepts at most 1022 residues; truncate longer proteins
+         batch = [(str(i), seq[:1022]) for i, seq in enumerate(seqs)]
+         _, _, toks = self.batch_converter(batch)
+         toks = toks.to(self.device)
+         with torch.no_grad():
+             results = self.model(toks, repr_layers=[33], return_contacts=False)
+         reps = results["representations"][33]
+         # mean-pool over real residues only (skip BOS/EOS and padding)
+         lens = (toks != self.alphabet.padding_idx).sum(1)
+         embs = [reps[i, 1 : lens[i] - 1].mean(0).cpu().numpy() for i in range(len(seqs))]
+         return np.vstack(embs)
+
+ class ESMDBPEmbedder:
+     def __init__(self, device):
+         # Load a local ESM-DBP checkpoint from the pretrained directory
+         model_path = Path(__file__).resolve().parent.parent / 'pretrained' / 'ESM-DBP' / 'ESM-DBP.model'
+         self.model, self.alphabet = esm.pretrained.load_model_and_alphabet_local(str(model_path))
+         self.batch_converter = self.alphabet.get_batch_converter()
+         self.model.to(device).eval()
+         self.device = device
+
+     def embed(self, seqs):
+         batch = [(str(i), seq) for i, seq in enumerate(seqs)]
+         _, _, toks = self.batch_converter(batch)
+         toks = toks.to(self.device)
+         with torch.no_grad():
+             results = self.model(toks, repr_layers=[33], return_contacts=False)
+         reps = results["representations"][33]
+         # mean-pool over real residues only (skip BOS/EOS and padding)
+         lens = (toks != self.alphabet.padding_idx).sum(1)
+         embs = [reps[i, 1 : lens[i] - 1].mean(0).cpu().numpy() for i in range(len(seqs))]
+         return np.vstack(embs)
+
+ class GPNEmbedder:
+     def __init__(self, device):
+         model_name = "songlab/gpn-msa-sapiens"
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.model = AutoModelForMaskedLM.from_pretrained(model_name)
+         self.model.to(device)
+         self.model.eval()
+         self.device = device
+
+     def embed(self, seqs):
+         inputs = self.tokenizer(
+             seqs,
+             return_tensors="pt",
+             padding=True,
+             truncation=True
+         ).to(self.device)
+
+         with torch.no_grad():
+             # masked-LM outputs expose logits, not last_hidden_state;
+             # request hidden states and pool the final layer
+             out = self.model(**inputs, output_hidden_states=True)
+             last_hidden = out.hidden_states[-1]
+         return last_hidden.mean(dim=1).cpu().numpy()
+
+ class ProGenEmbedder:
+     def __init__(self, device):
+         model_name = "jinyuan22/ProGen2-base"
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+         self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True).to(device).eval()
+         # causal-LM tokenizers often lack a pad token; reuse EOS so padding works
+         if self.tokenizer.pad_token is None:
+             self.tokenizer.pad_token = self.tokenizer.eos_token
+         self.device = device
+
+     def embed(self, seqs):
+         inputs = self.tokenizer(
+             seqs,
+             return_tensors="pt",
+             padding=True,
+             truncation=True
+         ).to(self.device)
+         with torch.no_grad():
+             last_hidden = self.model(**inputs).last_hidden_state
+         return last_hidden.mean(dim=1).cpu().numpy()
+
+ # ---- main pipeline ----
+
+ def get_embedder(name, device, for_dna=True):
+     name = name.lower()
+     if for_dna:
+         if name == "caduceus": return CaduceusEmbedder(device)
+         if name == "dnabert": return DNABertEmbedder(device)
+         if name == "nucleotide": return NucleotideTransformerEmbedder(device)
+         if name == "gpn": return GPNEmbedder(device)
+     else:
+         if name == "esm": return ESMEmbedder(device)
+         if name in ("esm-dbp", "esm_dbp"): return ESMDBPEmbedder(device)
+         if name == "progen": return ProGenEmbedder(device)
+     raise ValueError(f"Unknown model {name} (for_dna={for_dna})")
+
+
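+ # To plug in a new backbone, add a wrapper class with the same embed() contract
+ # and register it under a name in get_embedder above.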
+ def embed_and_save(seqs, ids, embedder, out_path):
+     embs = embedder.embed(seqs)
+     np.save(out_path, embs)
+     with open(out_path.with_suffix(".ids"), "w") as f:
+         f.write("\n".join(ids))
+
+ if __name__ == "__main__":
+     p = argparse.ArgumentParser()
+     p.add_argument("--genome-json-dir", default="data_files/raw/genomes/hg38", help="dir of UCSC JSONs")
+     p.add_argument("--tf-fasta", required=True, help="input TF FASTA file")
+     p.add_argument("--chrom-model", default="caduceus")
+     p.add_argument("--tf-model", default="esm-dbp")
+     p.add_argument("--out-dir", default="data_files/processed/tfclust/hg38_tf/embeddings")
+     p.add_argument("--device", default="cpu")
+     args = p.parse_args()
+
+     os.makedirs(args.out_dir, exist_ok=True)
+     device = args.device
+
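+     # Each chromosome JSON is assumed (from the .get() calls below) to hold,
+     # e.g., {"chrom": "chr1", "dna": "ACGT..."}, with "sequence" accepted in place of "dna".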
+     # Load only primary chromosome JSONs (chr1–22, X, Y, M)
+     genome_dir = Path(args.genome_json_dir)
+     chrom_seqs, chrom_ids = [], []
+     primary_pattern = re.compile(r"^hg38_chr(?:[1-9]|1[0-9]|2[0-2]|X|Y|M)\.json$")
+     for j in sorted(genome_dir.iterdir()):
+         if not primary_pattern.match(j.name):
+             continue
+         data = json.loads(j.read_text())
+         seq = data.get("dna") or data.get("sequence")
+         if not seq:
+             continue  # skip records without a sequence field
+         chrom = data.get("chrom") or j.stem.split("_")[-1]
+         chrom_seqs.append(seq)
+         chrom_ids.append(chrom)
+     chrom_embedder = get_embedder(args.chrom_model, device, for_dna=True)
+     out_chrom = Path(args.out_dir) / f"chrom_{args.chrom_model}.npy"
+     embed_and_save(chrom_seqs, chrom_ids, chrom_embedder, out_chrom)
+
+     # Load TF sequences (handle multi-line FASTA records)
+     tf_seqs, tf_ids = [], []
+     with open(args.tf_fasta) as f:
+         cur = []
+         for line in map(str.strip, f):
+             if line.startswith(">"):
+                 if cur: tf_seqs.append("".join(cur)); cur = []
+                 tf_ids.append(line[1:].split()[0])
+             elif line:
+                 cur.append(line)
+         if cur: tf_seqs.append("".join(cur))
+     tf_embedder = get_embedder(args.tf_model, device, for_dna=False)
+     out_tf = Path(args.out_dir) / f"tf_{args.tf_model}.npy"
+     embed_and_save(tf_seqs, tf_ids, tf_embedder, out_tf)
+
+     print("Done.")
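
The script writes, for each model, an embeddings matrix ("chrom_<model>.npy" / "tf_<model>.npy") plus a matching ".ids" text file with one row ID per line. A minimal sketch of reading these back downstream, assuming the default --out-dir and the model names from the usage example above:

    import numpy as np
    from pathlib import Path

    out_dir = Path("../data_files/processed/tfclust/hg38_tf/embeddings")
    tf_embs = np.load(out_dir / "tf_esm-dbp.npy")  # shape (N, D)
    tf_ids = (out_dir / "tf_esm-dbp.ids").read_text().splitlines()
    assert len(tf_ids) == tf_embs.shape[0]  # row i of tf_embs belongs to tf_ids[i]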