svincoff committed on
Commit 3b9cde0 · 1 Parent(s): f936712

added model files

dpacman/classifier/model/__init__.py ADDED
File without changes
dpacman/classifier/model/compress_embeddings.py ADDED
@@ -0,0 +1,54 @@
# compress_embeddings.py
# USAGE: python compress_embeddings.py --input_glob "/path/to/esm_embeddings/*.npy" --output_dir "/path/to/compressed_embeddings" --esm_dim 1280 --out_dim 256
# --------------
import os
import glob
import numpy as np
import torch
from torch import nn

class EmbeddingCompressor(nn.Module):
    def __init__(self, input_dim: int = 1280, output_dim: int = 256):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (batch, L, input_dim) or (L, input_dim)
        returns: (batch, output_dim) or (1, output_dim)
        """
        if x.dim() == 2:
            # single example: mean over tokens
            x = x.mean(dim=0, keepdim=True)  # → (1, input_dim)
        else:
            # batch: mean over tokens
            x = x.mean(dim=1)  # → (batch, input_dim)
        return self.fc(x)  # → (batch, output_dim)

def compress_file(in_path: str, out_path: str, model: EmbeddingCompressor):
    arr = np.load(in_path)  # shape (L, D) or (batch, L, D)
    tensor = torch.from_numpy(arr).float()
    with torch.no_grad():
        compressed = model(tensor)  # → (batch, out_dim)
    out = compressed.cpu().numpy()
    np.save(out_path, out)
    print(f"Saved {out_path}")

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Compress ESM embeddings to 256-d")
    parser.add_argument("--input_glob", type=str, required=True,
                        help="Glob for your .npy ESM embeddings (e.g. data/esm_*.npy)")
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--esm_dim", type=int, default=1280)
    parser.add_argument("--out_dim", type=int, default=256)
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    compressor = EmbeddingCompressor(args.esm_dim, args.out_dim)
    compressor.eval()

    for fn in glob.glob(args.input_glob):
        base = os.path.basename(fn).replace(".npy", "_256.npy")
        out_path = os.path.join(args.output_dir, base)
        compress_file(fn, out_path, compressor)
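For a sense of the shapes involved, here is a minimal in-memory sketch of the pooling-plus-projection that EmbeddingCompressor.forward applies. The random input below is illustrative only, and the linear layer is randomly initialized, exactly as in the script (no training step is defined):

import numpy as np
import torch
from torch import nn

# Stand-in layer mirroring EmbeddingCompressor above: 1280-d ESM tokens -> 256-d vector
fc = nn.Linear(1280, 256)

tokens = torch.from_numpy(np.random.rand(350, 1280)).float()  # one protein, 350 tokens (made-up data)
with torch.no_grad():
    pooled = tokens.mean(dim=0, keepdim=True)  # (1, 1280): mean over tokens
    compressed = fc(pooled)                    # (1, 256)
print(compressed.shape)  # torch.Size([1, 256])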
dpacman/classifier/model/compute_embeddings.py ADDED
@@ -0,0 +1,546 @@
"""
Plug-and-play embedding extraction for:
  • Chromosome sequences (from raw UCSC JSON)
  • TF sequences (transcription_factors.fasta)

Usage example (DNA + protein in one go):
    module load miniconda/24.7.1
    conda activate dpacman
    python dpacman/data/compute_embeddings.py \
        --genome-json-dir ../data_files/raw/genomes/hg38 \
        --tf-fasta ../data_files/processed/tfclust/hg38_tf/transcription_factors.fasta \
        --chrom-model caduceus \
        --tf-model esm-dbp \
        --out-dir ../data_files/processed/tfclust/hg38_tf/embeddings \
        --device cuda
"""
import os
import re
import argparse
import json
import numpy as np
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, pipeline
import esm
from Bio import SeqIO
import time

# ---- model wrappers ----
class CaduceusEmbedder:
    def __init__(self, device, chunk_size=131_072, overlap=0):
        """
        device: 'cpu' or 'cuda'
        chunk_size: max bases (and thus tokens) to send in one forward pass
        overlap: how many bases each window overlaps the previous; 0 = no overlap
        """
        model_name = "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16"
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(
            model_name, trust_remote_code=True
        ).to(device).eval()
        self.device = device
        self.chunk_size = chunk_size
        self.step = chunk_size - overlap

    def embed(self, seqs):
        """
        seqs: List[str] of DNA sequences (each <= chunk_size for this test)
        returns: list of per-sequence (L_i, D) arrays of raw per-token embeddings
        """
        # outputs = []
        # for seq in seqs:
        #     # --- new: raw per-token embeddings in one shot ---
        #     toks = self.tokenizer(
        #         seq,
        #         return_tensors="pt",
        #         padding=False,
        #         truncation=True,
        #         max_length=self.chunk_size
        #     ).to(self.device)
        #     with torch.no_grad():
        #         out = self.model(**toks).last_hidden_state  # (1, L, D)
        #     outputs.append(out.cpu().numpy()[0])  # (L, D)

        # return np.stack(outputs, axis=0)  # (N, L, D)
        outputs = []
        for seq in seqs:
            toks = self.tokenizer(
                seq,
                return_tensors="pt",
                padding=False,
                truncation=True,
                max_length=self.chunk_size
            ).to(self.device)
            with torch.no_grad():
                out = self.model(**toks).last_hidden_state  # (1, L, D)
            outputs.append(out.cpu().numpy()[0])  # (L, D)
        return outputs  # list of variable-length (L_i, D) arrays

    def benchmark(self, lengths=None):
        """
        Time embedding on single sequences of various lengths.
        By default tests [5K, 10K, 50K, 100K, chunk_size].
        """
        tests = lengths or [5_000, 10_000, 50_000, 100_000, self.chunk_size]
        print(f"→ Benchmarking Caduceus on device={self.device}")
        for sz in tests:
            seq = "A" * sz
            # Warm-up
            _ = self.embed([seq])
            if self.device != "cpu":
                torch.cuda.synchronize()
            t0 = time.perf_counter()
            _ = self.embed([seq])
            if self.device != "cpu":
                torch.cuda.synchronize()
            t1 = time.perf_counter()
            print(f"  length={sz:6,d}  time={(t1-t0)*1000:7.1f} ms")
class SegmentNTEmbedder:
    def __init__(self, device):
        self.tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True)
        self.model = AutoModel.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True).to(device).eval()
        self.device = device

    def _adjust_length(self, input_ids):
        bs, L = input_ids.shape
        excl = L - 1
        remainder = excl % 4
        if remainder != 0:
            pad_needed = 4 - remainder
            pad_tensor = torch.full((bs, pad_needed), self.tokenizer.pad_token_id, dtype=input_ids.dtype, device=input_ids.device)
            input_ids = torch.cat([input_ids, pad_tensor], dim=1)
        return input_ids

    def embed(self, seqs, batch_size=16):
        """
        seqs: List[str]
        Returns: np.ndarray of shape (N, D)
        """
        all_embeddings = []
        for i in range(0, len(seqs), batch_size):
            batch_seqs = seqs[i : i + batch_size]
            encoded = self.tokenizer.batch_encode_plus(
                batch_seqs,
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
            input_ids = encoded["input_ids"].to(self.device)  # (B, L)
            attention_mask = input_ids != self.tokenizer.pad_token_id

            input_ids = self._adjust_length(input_ids)
            attention_mask = (input_ids != self.tokenizer.pad_token_id)

            with torch.no_grad():
                outs = self.model(
                    input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    return_dict=True,
                )
            if hasattr(outs, "hidden_states") and outs.hidden_states is not None:
                last_hidden = outs.hidden_states[-1]  # (B, L, D)
            else:
                last_hidden = outs.last_hidden_state  # fallback

            # Exclude CLS token if present (assume first token) and pool
            pooled = last_hidden[:, 1:, :].mean(dim=1)  # (B, D)
            all_embeddings.append(pooled.cpu().numpy())

            # release fragmentation
            torch.cuda.empty_cache()

        return np.vstack(all_embeddings)  # (N, D)
class DNABertEmbedder:
    def __init__(self, device):
        self.tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True)
        self.model = AutoModel.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True).to(device).eval()
        self.device = device

    def embed(self, seqs):
        # NOTE: DNA_bert_6 was trained on space-separated 6-mer tokens; raw sequences are
        # passed through unchanged here, and inputs are truncated to the 512-token BERT limit.
        embs = []
        for s in seqs:
            tokens = self.tokenizer(s, return_tensors="pt", padding=True, truncation=True, max_length=512)["input_ids"].to(self.device)
            with torch.no_grad():
                out = self.model(tokens).last_hidden_state.mean(1)
            embs.append(out.cpu().numpy())
        return np.vstack(embs)
class NucleotideTransformerEmbedder:
    def __init__(self, device):
        # HF “feature-extraction” returns a list of (L, D) arrays for each input
        # device: “cpu” or “cuda”
        self.pipe = pipeline(
            "feature-extraction",
            model="InstaDeepAI/nucleotide-transformer-500m-1000g",
            device=-1 if device == "cpu" else 0  # HF uses -1 for CPU, 0 for GPU
        )

    def embed(self, seqs):
        """
        seqs: List[str] of raw DNA sequences
        returns: (N, D) array, one D-dim vector per sequence
        """
        all_embeddings = self.pipe(seqs, truncation=True, padding=True)
        # all_embeddings is a List of shape (L, D) arrays
        pooled = [np.mean(x, axis=0) for x in all_embeddings]
        return np.vstack(pooled)
# class ESMEmbedder:
#     def __init__(self, device):
#         self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
#         self.batch_converter = self.alphabet.get_batch_converter()
#         self.model.to(device).eval()
#         self.device = device

#     def embed(self, seqs):
#         batch = [(str(i), seq) for i, seq in enumerate(seqs)]
#         _, _, toks = self.batch_converter(batch)
#         toks = toks.to(self.device)
#         with torch.no_grad():
#             results = self.model(toks, repr_layers=[33], return_contacts=False)
#         reps = results["representations"][33]
#         return reps[:, 1:-1].mean(1).cpu().numpy()


class ESMEmbedder:
    def __init__(self, device, model_name="esm2_t33_650M_UR50D"):
        # Try to load the specified ESM-2 model; fallback to esm1b if missing
        self.device = device
        try:
            self.model, self.alphabet = getattr(esm.pretrained, model_name)()
            self.is_esm2 = model_name.lower().startswith("esm2")
        except AttributeError:
            # fallback to ESM-1b
            self.model, self.alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
            self.is_esm2 = False
        self.batch_converter = self.alphabet.get_batch_converter()
        self.model.to(device).eval()
        # determine max length: esm2 models vary; use default 1024 for esm1b
        self.max_len = 4096 if self.is_esm2 else 1024  # adjust if your esm2 variant has an explicit limit
        # for chunking: reserve 2 tokens if model uses BOS/EOS
        self.chunk_size = self.max_len - 2
        self.overlap = self.chunk_size // 4  # 25% overlap to smooth boundaries

    def _chunk_sequence(self, seq):
        """
        Return list of possibly overlapping chunks of seq, each <= chunk_size.
        """
        if len(seq) <= self.chunk_size:
            return [seq]
        step = self.chunk_size - self.overlap
        chunks = []
        for i in range(0, len(seq), step):
            chunk = seq[i : i + self.chunk_size]
            if not chunk:
                break
            chunks.append(chunk)
        return chunks

    def embed(self, seqs):
        """
        seqs: List[str] of protein sequences.
        Returns: np.ndarray of shape (N, D) pooled per-sequence embeddings.
        """
        all_embeddings = []
        for i, seq in enumerate(seqs):
            chunks = self._chunk_sequence(seq)
            chunk_vecs = []
            # process chunks in batch if small number, else sequentially
            for chunk in chunks:
                batch = [(str(i), chunk)]
                _, _, toks = self.batch_converter(batch)
                toks = toks.to(self.device)
                with torch.no_grad():
                    results = self.model(toks, repr_layers=[33], return_contacts=False)
                reps = results["representations"][33]  # (1, L, D)
                # remove BOS/EOS if present: take 1:-1 if length permits
                if reps.size(1) > 2:
                    rep = reps[:, 1:-1].mean(1)  # (1, D)
                else:
                    rep = reps.mean(1)  # fallback
                chunk_vecs.append(rep.squeeze(0))  # (D,)
            if len(chunk_vecs) == 1:
                seq_vec = chunk_vecs[0]
            else:
                # average chunk vectors
                stacked = torch.stack(chunk_vecs, dim=0)  # (num_chunks, D)
                seq_vec = stacked.mean(0)
            all_embeddings.append(seq_vec.cpu().numpy())
        return np.vstack(all_embeddings)  # (N, D)
# class ESMDBPEmbedder:
#     def __init__(self, device):
#         base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
#         model_path = (
#             Path(__file__).resolve().parent.parent
#             / "pretrained" / "ESM-DBP" / "ESM-DBP.model"
#         )
#         checkpoint = torch.load(model_path, map_location="cpu")
#         clean_sd = {}
#         for k, v in checkpoint.items():
#             clean_sd[k.replace("module.", "")] = v
#         result = base_model.load_state_dict(clean_sd, strict=False)
#         if result.missing_keys:
#             print(f"[ESMDBP] missing keys: {result.missing_keys}")
#         if result.unexpected_keys:
#             print(f"[ESMDBP] unexpected keys: {result.unexpected_keys}")

#         self.model = base_model.to(device).eval()
#         self.alphabet = alphabet
#         self.batch_converter = alphabet.get_batch_converter()
#         self.device = device

#     def embed(self, seqs):
#         batch = [(str(i), seq) for i, seq in enumerate(seqs)]
#         _, _, toks = self.batch_converter(batch)
#         toks = toks.to(self.device)
#         with torch.no_grad():
#             out = self.model(toks, repr_layers=[33], return_contacts=False)
#         reps = out["representations"][33]
#         # skip start/end tokens
#         return reps[:, 1:-1].mean(1).cpu().numpy()

class ESMDBPEmbedder:
    def __init__(self, device):
        base_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
        model_path = (
            Path(__file__).resolve().parent.parent
            / "pretrained" / "ESM-DBP" / "ESM-DBP.model"
        )
        checkpoint = torch.load(model_path, map_location="cpu")
        clean_sd = {}
        for k, v in checkpoint.items():
            clean_sd[k.replace("module.", "")] = v
        result = base_model.load_state_dict(clean_sd, strict=False)
        if result.missing_keys:
            print(f"[ESMDBP] missing keys: {result.missing_keys}")
        if result.unexpected_keys:
            print(f"[ESMDBP] unexpected keys: {result.unexpected_keys}")

        self.model = base_model.to(device).eval()
        self.alphabet = alphabet
        self.batch_converter = alphabet.get_batch_converter()
        self.device = device
        self.max_len = 1024  # same limit as esm1b
        self.chunk_size = self.max_len - 2
        self.overlap = self.chunk_size // 4

    def _chunk_sequence(self, seq):
        if len(seq) <= self.chunk_size:
            return [seq]
        step = self.chunk_size - self.overlap
        chunks = []
        for i in range(0, len(seq), step):
            chunk = seq[i : i + self.chunk_size]
            if not chunk:
                break
            chunks.append(chunk)
        return chunks

    def embed(self, seqs):
        all_embeddings = []
        for i, seq in enumerate(seqs):
            chunks = self._chunk_sequence(seq)
            chunk_vecs = []
            for chunk in chunks:
                batch = [(str(i), chunk)]
                _, _, toks = self.batch_converter(batch)
                toks = toks.to(self.device)
                with torch.no_grad():
                    out = self.model(toks, repr_layers=[33], return_contacts=False)
                reps = out["representations"][33]
                if reps.size(1) > 2:
                    rep = reps[:, 1:-1].mean(1)
                else:
                    rep = reps.mean(1)
                chunk_vecs.append(rep.squeeze(0))
            if len(chunk_vecs) == 1:
                seq_vec = chunk_vecs[0]
            else:
                stacked = torch.stack(chunk_vecs, dim=0)
                seq_vec = stacked.mean(0)
            all_embeddings.append(seq_vec.cpu().numpy())
        return np.vstack(all_embeddings)
class GPNEmbedder:
    def __init__(self, device):
        model_name = "songlab/gpn-msa-sapiens"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(model_name)
        self.model.to(device)
        self.model.eval()
        self.device = device

    def embed(self, seqs):
        inputs = self.tokenizer(
            seqs,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(self.device)

        with torch.no_grad():
            # masked-LM heads do not expose last_hidden_state directly;
            # request hidden states and take the final layer instead
            outputs = self.model(**inputs, output_hidden_states=True)
            last_hidden = outputs.hidden_states[-1]
        return last_hidden.mean(dim=1).cpu().numpy()
class ProGenEmbedder:
    def __init__(self, device):
        model_name = "jinyuan22/ProGen2-base"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device).eval()
        self.device = device

    def embed(self, seqs):
        inputs = self.tokenizer(
            seqs,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(self.device)
        with torch.no_grad():
            last_hidden = self.model(**inputs).last_hidden_state
        return last_hidden.mean(dim=1).cpu().numpy()

# ---- main pipeline ----
def get_embedder(name, device, for_dna=True):
    name = name.lower()
    if for_dna:
        if name == "caduceus": return CaduceusEmbedder(device)
        if name == "dnabert": return DNABertEmbedder(device)
        if name == "nucleotide": return NucleotideTransformerEmbedder(device)
        if name == "gpn": return GPNEmbedder(device)
        if name == "segmentnt": return SegmentNTEmbedder(device)
    else:
        if name in ("esm",): return ESMEmbedder(device)
        if name in ("esm-dbp", "esm_dbp"): return ESMDBPEmbedder(device)
        if name == "progen": return ProGenEmbedder(device)
    raise ValueError(f"Unknown model {name} (for_dna={for_dna})")


def pad_token_embeddings(list_of_arrays, pad_value=0.0):
    """
    list_of_arrays: list of (L_i, D) numpy arrays
    Returns:
        padded: (N, L_max, D) array
        mask: (N, L_max) boolean array where True = real token, False = padding
    """
    N = len(list_of_arrays)
    D = list_of_arrays[0].shape[1]
    L_max = max(arr.shape[0] for arr in list_of_arrays)
    padded = np.full((N, L_max, D), pad_value, dtype=list_of_arrays[0].dtype)
    mask = np.zeros((N, L_max), dtype=bool)
    for i, arr in enumerate(list_of_arrays):
        L = arr.shape[0]
        padded[i, :L] = arr
        mask[i, :L] = True
    return padded, mask

def embed_and_save(seqs, ids, embedder, out_path):
    embs = embedder.embed(seqs)

    # Decide whether we got variable-length per-token outputs (list of (L, D))
    is_variable_token = isinstance(embs, (list, tuple)) and len(embs) > 0 and hasattr(embs[0], "shape") and embs[0].ndim == 2

    if is_variable_token:
        # pad to (N, L_max, D) + mask
        padded, mask = pad_token_embeddings(embs)
        # Save both embeddings and mask together in an .npz for convenience
        np.savez_compressed(out_path.with_suffix(".caduceus.npz"),
                            embeddings=padded,
                            mask=mask,
                            ids=np.array(ids, dtype=object))
    else:
        # fixed shape output, e.g., pooled (N, D)
        array = np.vstack(embs) if isinstance(embs, list) else embs
        np.save(out_path, array)
    # always write the ids alongside, in the same order as the embeddings
    with open(out_path.with_suffix(".ids"), "w") as f:
        f.write("\n".join(ids))
if __name__ == "__main__":

    p = argparse.ArgumentParser()
    p.add_argument("--peak-fasta", default="binding_peaks_unique.fa", help="FASTA of deduplicated binding peak sequences; if present this is used for DNA embedding instead of genome JSONs")
    p.add_argument("--genome-json-dir", default=None, help="(fallback) directory of UCSC JSONs for full chromosome embedding if peak FASTA is missing or you explicitly want chromosomes")
    p.add_argument("--skip-dna", action="store_true", help="if set, skip the chromosome embedding step")  # if glm embeddings successful but not plm embeddings
    p.add_argument("--tf-fasta", required=True, help="input TF FASTA file")
    p.add_argument("--chrom-model", default="caduceus")
    p.add_argument("--tf-model", default="esm-dbp")
    p.add_argument("--out-dir", default="data_files/processed/tfclust/hg38_tf/embeddings")
    p.add_argument("--device", default="cpu")
    args = p.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)
    device = args.device

    if not args.skip_dna:
        peak_fasta = Path(args.peak_fasta)
        if peak_fasta.exists():
            # Load peak sequences from FASTA
            from Bio import SeqIO

            peak_seqs = []
            peak_ids = []
            for rec in SeqIO.parse(peak_fasta, "fasta"):
                peak_ids.append(rec.id)
                peak_seqs.append(str(rec.seq))
            print(f"Embedding {len(peak_seqs)} binding peak sequences from {peak_fasta}", flush=True)
            dna_embedder = get_embedder(args.chrom_model, device, for_dna=True)
            out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy"
            embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks)
        elif args.genome_json_dir:
            # Legacy: load full chromosomes from JSONs (chr1–22, X, Y, M)
            genome_dir = Path(args.genome_json_dir)
            chrom_seqs, chrom_ids = [], []
            primary_pattern = re.compile(r"^hg38_chr(?:[1-9]|1[0-9]|2[0-2]|X|Y|M)\.json$")
            for j in sorted(genome_dir.iterdir()):
                if not primary_pattern.match(j.name):
                    continue
                data = json.loads(j.read_text())
                seq = data.get("dna") or data.get("sequence")
                chrom = data.get("chrom") or j.stem.split("_")[-1]
                chrom_seqs.append(seq)
                chrom_ids.append(chrom)
            cutoff = CaduceusEmbedder(device).chunk_size
            long_chroms = [
                (chrom, len(seq))
                for chrom, seq in zip(chrom_ids, chrom_seqs)
                if len(seq) > cutoff
            ]
            if long_chroms:
                print("⚠️ Chromosomes exceeding Caduceus max tokens ({}):".format(cutoff))
                for chrom, L in long_chroms:
                    print(f"  {chrom}: {L} bases")
            else:
                print("All chromosomes ≤ Caduceus limit ({}).".format(cutoff))

            chrom_embedder = get_embedder(args.chrom_model, device, for_dna=True)
            out_chrom = Path(args.out_dir) / f"chrom_{args.chrom_model}.npy"
            embed_and_save(chrom_seqs, chrom_ids, chrom_embedder, out_chrom)
        else:
            raise ValueError("No input for DNA embedding: provide a peak FASTA (default binding_peaks_unique.fa) or set --genome-json-dir for chromosome JSONs.")

    # Load TF sequences
    tf_seqs, tf_ids = [], []
    for record in SeqIO.parse(args.tf_fasta, "fasta"):
        tf_ids.append(record.id)
        tf_seqs.append(str(record.seq))

    # embed and save
    tf_embedder = get_embedder(args.tf_model, device, for_dna=False)
    out_tf = Path(args.out_dir) / f"tf_{args.tf_model}.npy"
    embed_and_save(tf_seqs, tf_ids, tf_embedder, out_tf)

    print("Done.")
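Because every wrapper above exposes the same embed(seqs) interface, embed_and_save only has to distinguish two output shapes. Below is a model-free sketch of the two file formats it writes, using dummy arrays in place of real embedder output; the file names, sizes, and example IDs are illustrative only:

import numpy as np
from pathlib import Path

def pad_token_embeddings(list_of_arrays, pad_value=0.0):
    # same padding logic as in the script, repeated here so the sketch runs standalone
    N, D = len(list_of_arrays), list_of_arrays[0].shape[1]
    L_max = max(arr.shape[0] for arr in list_of_arrays)
    padded = np.full((N, L_max, D), pad_value, dtype=list_of_arrays[0].dtype)
    mask = np.zeros((N, L_max), dtype=bool)
    for i, arr in enumerate(list_of_arrays):
        padded[i, :arr.shape[0]] = arr
        mask[i, :arr.shape[0]] = True
    return padded, mask

# Variable-length per-token output (Caduceus-style) -> .npz holding embeddings + mask + ids
token_embs = [np.random.rand(120, 256), np.random.rand(80, 256)]
padded, mask = pad_token_embeddings(token_embs)
np.savez_compressed("peaks_demo.caduceus.npz", embeddings=padded, mask=mask,
                    ids=np.array(["peak0", "peak1"], dtype=object))

# Pooled fixed-size output (ESM-style) -> .npy plus a sidecar .ids file, one id per row
pooled = np.random.rand(2, 1280)
np.save("tf_demo.npy", pooled)
Path("tf_demo.ids").write_text("sp|O15062|ZBTB5_HUMAN\nsp|P01106|MYC_HUMAN\n")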
dpacman/classifier/model/extract_tf_symbols.py ADDED
@@ -0,0 +1,27 @@
#!/usr/bin/env python3
import pandas as pd
from pathlib import Path

FINAL_CSV = Path("/home/a03-akrishna/DPACMAN/data_files/processed/final.csv")
OUT_SYMBOLS = Path("tf_symbols.txt")

def normalize_tf(tf_id: str) -> str:
    return tf_id.split("_seq")[0].upper()

def main():
    df = pd.read_csv(FINAL_CSV, dtype=str)
    if "TF_id" not in df.columns:
        raise RuntimeError("final.csv missing TF_id column")
    tf_raw = df["TF_id"].dropna().unique().tolist()
    normalized = sorted({normalize_tf(t) for t in tf_raw})
    print(f"Unique raw TF_id count: {len(tf_raw)}")
    print(f"Unique normalized TF symbols: {len(normalized)}")
    with open(OUT_SYMBOLS, "w") as f:
        for s in normalized:
            f.write(s + "\n")
    print(f"Wrote normalized TF symbols to {OUT_SYMBOLS}")
    # Optional: show sample
    print("Sample symbols:", normalized[:50])

if __name__ == "__main__":
    main()
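The normalization rule is simply "drop any _seq suffix and uppercase"; a short sketch with made-up TF_id values:

def normalize_tf(tf_id: str) -> str:
    # same rule as in the script above
    return tf_id.split("_seq")[0].upper()

for raw in ["ZBTB5_seq1", "ctcf_seq12", "GATA1"]:  # hypothetical TF_id values
    print(raw, "->", normalize_tf(raw))
# ZBTB5_seq1 -> ZBTB5
# ctcf_seq12 -> CTCF
# GATA1 -> GATA1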
dpacman/classifier/model/make_pair_list.py ADDED
@@ -0,0 +1,220 @@
#!/usr/bin/env python3
import argparse
import numpy as np
import pandas as pd
from pathlib import Path
import random
import sys

def read_ids_file(p):
    p = Path(p)
    if not p.exists():
        raise FileNotFoundError(f"IDs file not found: {p}")
    return [line.strip() for line in p.open() if line.strip()]

def split_embeddings(emb_path, ids_path, out_dir, prefix):
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    if not Path(emb_path).exists():
        raise FileNotFoundError(f"Embedding file not found: {emb_path}")
    if not Path(ids_path).exists():
        raise FileNotFoundError(f"IDs file not found: {ids_path}")

    if emb_path.endswith(".npz"):
        data = np.load(emb_path, allow_pickle=True)
        if "embeddings" in data:
            emb = data["embeddings"]
        else:
            raise ValueError(f"{emb_path} missing 'embeddings' key")
    else:
        emb = np.load(emb_path)

    ids = read_ids_file(ids_path)
    if len(ids) != emb.shape[0]:
        print(f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}", file=sys.stderr)

    mapping = {}
    for i, ident in enumerate(ids):
        if i >= emb.shape[0]:
            print(f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr)
            continue
        arr = emb[i]
        out_file = out_dir / f"{prefix}_{ident}.npy"
        np.save(out_file, arr)
        mapping[ident] = str(out_file)
    return mapping

def extract_symbol_from_tf_id(full_id: str) -> str:
    """
    Given a TF embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN',
    return the gene symbol uppercase (e.g., 'ZBTB5').
    """
    if "|" in full_id:
        try:
            # format sp|Accession|SYMBOL_HUMAN
            genepart = full_id.split("|")[2]
        except IndexError:
            genepart = full_id
    else:
        genepart = full_id
    symbol = genepart.split("_")[0]
    return symbol.upper()

def build_tf_symbol_map(tf_map):
    """
    Build mapping gene_symbol -> list of embedding paths.
    """
    symbol_map = {}
    for full_id, path in tf_map.items():
        symbol = extract_symbol_from_tf_id(full_id)
        symbol_map.setdefault(symbol, []).append(path)
    return symbol_map

def tf_key_from_path(path: str) -> str:
    """
    Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract normalized symbol 'ZBTB5'.
    """
    stem = Path(path).stem  # e.g., tf_sp|O15062|ZBTB5_HUMAN
    # remove leading prefix if present (tf_)
    if "_" in stem:
        _, rest = stem.split("_", 1)
    else:
        rest = stem
    return extract_symbol_from_tf_id(rest)

def dna_key_from_path(path: str) -> str:
    """
    Given .../dna_peak42.npy -> 'peak42'
    """
    stem = Path(path).stem
    if "_" in stem:
        _, rest = stem.split("_", 1)
    else:
        rest = stem
    return rest
def main():
    parser = argparse.ArgumentParser(
        description="Build TF-DNA pair list from final.csv with gene-symbol normalization for TFs."
    )
    parser.add_argument("--final_csv", required=True, help="final.csv with TF_id and dna_sequence")
    parser.add_argument("--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)")
    parser.add_argument("--dna_ids", required=True, help="IDs file for DNA embeddings (e.g., peak*.ids)")
    parser.add_argument("--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)")
    parser.add_argument("--tf_ids", required=True, help="IDs file for TF embeddings (e.g., sp|...|... ids)")
    parser.add_argument("--out_dir", required=True, help="Output directory")
    parser.add_argument("--neg_per_positive", type=int, default=2, help="Negatives per positive (half same-TF, half same-DNA)")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    random.seed(args.seed)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Load final.csv
    df = pd.read_csv(args.final_csv, dtype=str)
    if "TF_id" not in df.columns or "dna_sequence" not in df.columns:
        raise RuntimeError("final.csv must have columns TF_id and dna_sequence")

    # Assign dna_id (unique per dna_sequence)
    unique_seqs = df["dna_sequence"].drop_duplicates().tolist()
    seq_to_id = {seq: f"peak{i}" for i, seq in enumerate(unique_seqs)}
    df["dna_id"] = df["dna_sequence"].map(seq_to_id)
    enriched_csv = out_dir / "final_with_dna_id.csv"
    df.to_csv(enriched_csv, index=False)
    print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}")

    # Split embeddings into per-item files
    print(f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}")
    dna_map = split_embeddings(args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna")
    print(f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})")
    print(f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}")
    tf_map = split_embeddings(args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf")
    print(f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})")

    # Build gene-symbol normalized map
    tf_symbol_map = build_tf_symbol_map(tf_map)
    print(f"[i] TF symbol map keys (sample): {list(tf_symbol_map.keys())[:30]}")

    # Diagnostic overlaps
    norm_tf_in_final = set(t.split("_seq")[0].upper() for t in df["TF_id"].unique())
    available_tf_symbols = set(tf_symbol_map.keys())
    intersect_tf = norm_tf_in_final & available_tf_symbols
    print(f"[i] Unique normalized TF symbols in final.csv: {len(norm_tf_in_final)}")
    print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}")
    print(f"[i] Intersection count: {len(intersect_tf)}")
    if len(intersect_tf) == 0:
        print("[ERROR] No overlap between normalized TF_id and TF embedding symbols.", file=sys.stderr)
        print("Sample normalized TFs from final.csv:", sorted(list(norm_tf_in_final))[:30], file=sys.stderr)
        print("Sample available TF symbols:", sorted(list(available_tf_symbols))[:30], file=sys.stderr)
        sys.exit(1)

    dna_ids_final = set(df["dna_id"].unique())
    available_dna_ids = set(dna_map.keys())
    intersect_dna = dna_ids_final & available_dna_ids
    print(f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. Intersection: {len(intersect_dna)}")
    if len(intersect_dna) == 0:
        print("[ERROR] No overlap on DNA ids.", file=sys.stderr)
        sys.exit(1)

    # Build positive pairs
    positives = []
    for _, row in df.iterrows():
        tf_raw = row["TF_id"]
        tf_symbol = tf_raw.split("_seq")[0].upper()
        dnaid = row["dna_id"]
        if tf_symbol not in tf_symbol_map:
            continue
        if dnaid not in dna_map:
            continue
        # pick the first embedding for that symbol
        tf_embedding_path = tf_symbol_map[tf_symbol][0]
        positives.append((tf_embedding_path, dna_map[dnaid], 1))
    print(f"[i] Constructed {len(positives)} positive pairs after TF symbol resolution")

    if len(positives) == 0:
        print("[ERROR] No positive pairs could be constructed; aborting.", file=sys.stderr)
        sys.exit(1)

    # Build negative samples
    all_tf_symbols = sorted(tf_symbol_map.keys())
    all_dnaids = sorted(dna_map.keys())
    positive_set = set()
    for tf_path, dna_path, _ in positives:
        tf_key = tf_key_from_path(tf_path)
        dna_key = dna_key_from_path(dna_path)
        positive_set.add((tf_key, dna_key))

    negatives = []
    half = args.neg_per_positive // 2
    for tf_path, dna_path, _ in positives:
        tf_key = tf_key_from_path(tf_path)
        dna_key = dna_key_from_path(dna_path)
        # same TF, different DNA
        for _ in range(half):
            candidate_dna = random.choice(all_dnaids)
            if candidate_dna == dna_key or (tf_key, candidate_dna) in positive_set:
                continue
            negatives.append((tf_path, dna_map[candidate_dna], 0))
        # same DNA, different TF
        for _ in range(half):
            candidate_tf_symbol = random.choice(all_tf_symbols)
            if candidate_tf_symbol == tf_key or (candidate_tf_symbol, dna_key) in positive_set:
                continue
            # pick its first embedding and pair it with the current positive's DNA
            candidate_tf_path = tf_symbol_map[candidate_tf_symbol][0]
            negatives.append((candidate_tf_path, dna_path, 0))

    print(f"[i] Sampled {len(negatives)} negatives (neg_per_positive={args.neg_per_positive})")

    # Write pair list
    pair_list_path = out_dir / "pair_list.tsv"
    with open(pair_list_path, "w") as f:
        for binder_path, glm_path, label in positives + negatives:
            # binder=TF, glm=DNA
            f.write(f"{binder_path}\t{glm_path}\t{label}\n")
    print(f"[i] Wrote {len(positives)+len(negatives)} examples to {pair_list_path}")

if __name__ == "__main__":
    main()
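The matching between final.csv and the embedding files hinges on two normalizations: TF paths are reduced to gene symbols and DNA paths to peak ids, and each output line of pair_list.tsv is a tab-separated (tf_path, dna_path, label) triple. A small sketch of that normalization follows; the paths below are hypothetical:

from pathlib import Path

def extract_symbol_from_tf_id(full_id: str) -> str:
    # same parsing rule as in the script: sp|Accession|SYMBOL_HUMAN -> SYMBOL
    genepart = full_id.split("|")[2] if full_id.count("|") >= 2 else full_id
    return genepart.split("_")[0].upper()

print(extract_symbol_from_tf_id("sp|O15062|ZBTB5_HUMAN"))  # ZBTB5
print(extract_symbol_from_tf_id("ZBTB5_HUMAN"))            # ZBTB5

# tf_key_from_path / dna_key_from_path strip the leading "tf_" / "dna_" prefix from the file stem
print(Path("out/tf_single/tf_sp|O15062|ZBTB5_HUMAN.npy").stem.split("_", 1)[1])  # sp|O15062|ZBTB5_HUMAN
print(Path("out/dna_single/dna_peak42.npy").stem.split("_", 1)[1])               # peak42

# One line of pair_list.tsv: <tf_embedding_path> <tab> <dna_embedding_path> <tab> <label>
print("\t".join(["out/tf_single/tf_sp|O15062|ZBTB5_HUMAN.npy", "out/dna_single/dna_peak42.npy", "1"]))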
dpacman/classifier/model/make_peak_fasta.py ADDED
@@ -0,0 +1,13 @@
import pandas as pd
from pathlib import Path

df = pd.read_csv("/home/a03-akrishna/DPACMAN/data_files/processed/final.csv", dtype=str)  # adjust path if needed
# get unique sequences
uniq = df[["dna_sequence"]].drop_duplicates().reset_index(drop=True)
# make headers: e.g., peak0, peak1, ...
out_fa = Path("binding_peaks_unique.fa")
with open(out_fa, "w") as f:
    for i, seq in enumerate(uniq["dna_sequence"]):
        header = f">peak{i}"
        f.write(f"{header}\n{seq}\n")
print(f"Wrote {len(uniq)} unique binding sequences to {out_fa}")
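A toy run of the same logic on an in-memory frame, to show the FASTA layout it produces (the sequences are made up):

import pandas as pd

df = pd.DataFrame({"dna_sequence": ["ACGTACGT", "TTGACCAA", "ACGTACGT"]})  # stand-in for final.csv
uniq = df[["dna_sequence"]].drop_duplicates().reset_index(drop=True)
with open("binding_peaks_demo.fa", "w") as f:
    for i, seq in enumerate(uniq["dna_sequence"]):
        f.write(f">peak{i}\n{seq}\n")
# binding_peaks_demo.fa now contains:
# >peak0
# ACGTACGT
# >peak1
# TTGACCAA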
dpacman/classifier/model/model.py ADDED
@@ -0,0 +1,103 @@
import torch
from torch import nn

class LocalCNN(nn.Module):
    def __init__(self, dim: int = 256, kernel_size: int = 3):
        super().__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=padding)
        self.act = nn.GELU()
        self.ln = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor):
        # x: (batch, L, dim)
        out = self.conv(x.transpose(1, 2))  # → (batch, dim, L)
        out = self.act(out)
        out = out.transpose(1, 2)  # → (batch, L, dim)
        return self.ln(out + x)  # residual

class CrossModalBlock(nn.Module):
    def __init__(self, dim: int = 256, heads: int = 8):
        super().__init__()
        # self-attention for both sides
        self.sa_binder = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.sa_glm = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ln_b1 = nn.LayerNorm(dim)
        self.ln_g1 = nn.LayerNorm(dim)

        self.ffn_b = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
        self.ffn_g = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
        self.ln_b2 = nn.LayerNorm(dim)
        self.ln_g2 = nn.LayerNorm(dim)

        # cross attention (binder queries, glm keys/values)
        self.cross_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ln_c1 = nn.LayerNorm(dim)
        self.ffn_c = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
        self.ln_c2 = nn.LayerNorm(dim)

    def forward(self, binder: torch.Tensor, glm: torch.Tensor):
        """
        binder: (batch, Lb, dim)
        glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
        returns: updated binder representation (batch, Lb, dim)
        """
        # binder self-attn + ffn
        b = binder
        b_sa, _ = self.sa_binder(b, b, b)
        b = self.ln_b1(b + b_sa)
        b_ff = self.ffn_b(b)
        b = self.ln_b2(b + b_ff)

        # glm self-attn + ffn
        g = glm
        g_sa, _ = self.sa_glm(g, g, g)
        g = self.ln_g1(g + g_sa)
        g_ff = self.ffn_g(g)
        g = self.ln_g2(g + g_ff)

        # cross-attention: binder queries glm
        c_sa, _ = self.cross_attn(b, g, g)
        c = self.ln_c1(b + c_sa)
        c_ff = self.ffn_c(c)
        c = self.ln_c2(c + c_ff)
        return c  # (batch, Lb, dim)

class BindPredictor(nn.Module):
    def __init__(self,
                 input_dim: int = 256,
                 hidden_dim: int = 256,
                 heads: int = 8,
                 num_layers: int = 4,
                 use_local_cnn_on_glm: bool = True):
        super().__init__()
        self.proj_binder = nn.Linear(input_dim, hidden_dim)
        self.proj_glm = nn.Linear(input_dim, hidden_dim)
        self.use_local_cnn = use_local_cnn_on_glm
        self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity()

        self.layers = nn.ModuleList([
            CrossModalBlock(hidden_dim, heads) for _ in range(num_layers)
        ])

        self.ln_out = nn.LayerNorm(hidden_dim)
        self.head = nn.Sequential(
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, binder_emb, glm_emb):
        """
        binder_emb, glm_emb: (batch, L, input_dim)
        """
        b = self.proj_binder(binder_emb)  # (B, Lb, hidden_dim)
        g = self.proj_glm(glm_emb)  # (B, Lg, hidden_dim)
        if self.use_local_cnn:
            g = self.local_cnn(g)  # local context injected

        for layer in self.layers:
            b = layer(b, g)  # update binder with cross-modal info

        pooled = b.mean(dim=1)  # (B, hidden_dim)
        out = self.ln_out(pooled)
        return self.head(out).squeeze(-1)  # (B,)
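A quick shape check of BindPredictor on random tensors (no trained weights, purely illustrative; assumes this file is importable as model.py):

import torch
from model import BindPredictor

net = BindPredictor(input_dim=256, hidden_dim=256, heads=8, num_layers=2)
binder = torch.randn(4, 300, 256)  # e.g. 4 TF embeddings, 300 tokens each (made-up data)
glm = torch.randn(4, 500, 256)     # 4 DNA peak embeddings, 500 tokens each
with torch.no_grad():
    scores = net(binder, glm)
print(scores.shape)  # torch.Size([4]); each entry is a sigmoid output in [0, 1]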
dpacman/classifier/model/train.py ADDED
@@ -0,0 +1,262 @@
#!/usr/bin/env python3
import argparse
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from model import BindPredictor
from pathlib import Path
from collections import Counter
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.decomposition import TruncatedSVD
import random
import sys

# ---- dataset ---------------------------------------------------------
class PairDataset(Dataset):
    def __init__(self, binder_paths, glm_paths, labels, tf_compressed_cache):
        """
        tf_compressed_cache: dict mapping binder_path -> compressed (256-d) tensor/array
        """
        assert len(binder_paths) == len(glm_paths) == len(labels)
        self.binder_paths = binder_paths
        self.glm_paths = glm_paths
        self.labels = labels
        self.tf_cache = tf_compressed_cache  # already reduced to 256

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # binder = TF embedding (possibly reduced)
        b = self.tf_cache[self.binder_paths[idx]]  # numpy array shape (L, 256) or (256,)
        g = np.load(self.glm_paths[idx])  # glm (DNA) embedding

        if b.ndim == 1:
            b = b[None, :]
        if g.ndim == 1:
            g = g[None, :]

        b_tensor = torch.from_numpy(b).float()
        g_tensor = torch.from_numpy(g).float()
        y = torch.tensor(self.labels[idx]).float()
        return b_tensor, g_tensor, y

def collate_fn(batch):
    binders, glms, labels = zip(*batch)
    binder_lens = [b.shape[0] for b in binders]
    glm_lens = [g.shape[0] for g in glms]
    max_b = max(binder_lens)
    max_g = max(glm_lens)

    def pad_seq(seq, target_len):
        L, D = seq.shape
        if L < target_len:
            pad = torch.zeros((target_len - L, D), dtype=seq.dtype, device=seq.device)
            return torch.cat([seq, pad], dim=0)
        return seq

    b_padded = torch.stack([pad_seq(b, max_b) for b in binders])  # (B, Lb, D)
    g_padded = torch.stack([pad_seq(g, max_g) for g in glms])  # (B, Lg, D)
    y = torch.stack(labels)
    return b_padded, g_padded, y

# ---- utilities -------------------------------------------------------
def parse_pair_list(pair_list_path):
    binder_paths, glm_paths, labels = [], [], []
    with open(pair_list_path) as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            parts = line.strip().split()
            if len(parts) != 3:
                print(f"[WARN] skipping malformed line {lineno}: {line.strip()}", file=sys.stderr)
                continue
            b, g, l = parts
            try:
                lab = int(l)
            except ValueError:
                print(f"[WARN] invalid label on line {lineno}: {l}", file=sys.stderr)
                continue
            binder_paths.append(b)
            glm_paths.append(g)
            labels.append(lab)
    return binder_paths, glm_paths, labels

def build_tf_compressed_cache(binder_paths, target_dim=256):
    """
    Load all unique TF (binder) embeddings, fit reduction if needed, and return dict mapping path->(L, target_dim) array.
    """
    unique_paths = sorted(set(binder_paths))
    print(f"[i] Found {len(unique_paths)} unique TF embedding files to compress.", flush=True)
    # Load all embeddings to determine dimensionality
    samples = []
    for p in unique_paths:
        arr = np.load(p)
        samples.append(arr)
    # Determine if reduction needed: assume all have same embedding width
    first = samples[0]
    orig_dim = first.shape[1] if first.ndim == 2 else 1
    reduction_needed = (orig_dim != target_dim)
    tf_cache = {}

    if reduction_needed:
        # Build matrix to fit SVD: we need a 2D matrix per embedding; if lengths vary we can't directly stack.
        # We'll do reduction per sequence individually using TruncatedSVD on concatenated flattened features:
        # Simplest: for variable lengths, reduce each embedding separately with a learned linear projection.
        # Here we fit a single TruncatedSVD on the concatenation of all sequence tokens (flattened) by padding/truncating to a fixed length.
        # To avoid complexity, use PCA-like linear projection learned via SVD on mean-pooled vectors:
        pooled = []
        for arr in samples:
            if arr.ndim == 2:
                pooled.append(arr.mean(axis=0))  # (orig_dim,)
            else:
                pooled.append(arr)  # degenerate
        pooled_mat = np.stack(pooled, axis=0)  # (N, orig_dim)
        print(f"[i] Fitting TruncatedSVD on TF pooled embeddings: {pooled_mat.shape} -> {target_dim}", flush=True)
        svd = TruncatedSVD(n_components=target_dim, random_state=42)
        reduced_pooled = svd.fit_transform(pooled_mat)  # (N, target_dim)

        # For each original embedding, project token-level vectors by multiplying token vector with svd.components_.T
        # svd.components_: (target_dim, orig_dim) so projection matrix is (orig_dim, target_dim)
        proj_mat = svd.components_.T  # (orig_dim, target_dim)
        for i, p in enumerate(unique_paths):
            arr = samples[i]  # shape (L, orig_dim)
            if arr.ndim == 1:
                arr2 = arr @ proj_mat  # (target_dim,)
            else:
                # project each token: (L, orig_dim) @ (orig_dim, target_dim) -> (L, target_dim)
                arr2 = arr @ proj_mat
            tf_cache[p] = arr2  # reduced per-token representation
        print("[i] Completed compression of TF embeddings.", flush=True)
    else:
        # already correct dim: just cache originals
        print(f"[i] TF embeddings already {target_dim}-dimensional; skipping reduction.", flush=True)
        for i, p in enumerate(unique_paths):
            arr = samples[i]
            tf_cache[p] = arr
    return tf_cache

def evaluate(model, dl, device):
    model.eval()
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for b, g, y in dl:
            b = b.to(device)
            g = g.to(device)
            y = y.to(device)
            pred = model(b, g)
            all_labels.append(y.cpu())
            all_preds.append(pred.cpu())
    if not all_labels:
        return 0.0, 0.0
    y_true = torch.cat(all_labels).numpy()
    y_score = torch.cat(all_preds).numpy()
    try:
        auc = roc_auc_score(y_true, y_score)
    except Exception:
        auc = 0.0
    try:
        ap = average_precision_score(y_true, y_score)
    except Exception:
        ap = 0.0
    return auc, ap

# ---- main ------------------------------------------------------------
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--pair_list", type=str, required=True,
                        help="TSV: binder_path glm_path label")
    parser.add_argument("--out_dir", type=str, required=True)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    print("DEBUG: starting training script with in-line TF compression", flush=True)
    print(f"[i] pair_list: {args.pair_list}", flush=True)
    print(f"[i] output dir: {args.out_dir}", flush=True)
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    binder_paths, glm_paths, labels = parse_pair_list(args.pair_list)

    if len(labels) == 0:
        print("[ERROR] No valid pairs parsed. Exiting.", file=sys.stderr)
        sys.exit(1)

    label_counts = Counter(labels)
    print(f"[i] Total examples parsed: {len(labels)}. Label distribution: {label_counts}", flush=True)

    # build compressed TF cache (reduces to 256 if needed)
    tf_compressed_cache = build_tf_compressed_cache(binder_paths, target_dim=256)

    # simple split: 80/10/10
    n = len(labels)
    idxs = np.arange(n)
    np.random.shuffle(idxs)
    train_i = idxs[: int(0.8 * n)]
    val_i = idxs[int(0.8 * n): int(0.9 * n)]
    test_i = idxs[int(0.9 * n):]

    def subset(idxs):
        return [binder_paths[i] for i in idxs], [glm_paths[i] for i in idxs], [labels[i] for i in idxs]

    train_ds = PairDataset(*subset(train_i), tf_compressed_cache=tf_compressed_cache)
    val_ds = PairDataset(*subset(val_i), tf_compressed_cache=tf_compressed_cache)
    test_ds = PairDataset(*subset(test_i), tf_compressed_cache=tf_compressed_cache)

    print(f"[i] Train/Val/Test sizes: {len(train_ds)}/{len(val_ds)}/{len(test_ds)}", flush=True)
    if len(train_ds) == 0 or len(val_ds) == 0:
        print("[ERROR] Train or validation split is empty; cannot proceed.", file=sys.stderr)
        sys.exit(1)

    train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
    val_dl = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
    test_dl = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)

    model = BindPredictor(input_dim=256, hidden_dim=256, heads=8, num_layers=3, use_local_cnn_on_glm=True)
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-3)
    loss_fn = nn.BCELoss()

    best_val = -float("inf")
    os_out = Path(args.out_dir)
    os_out.mkdir(exist_ok=True, parents=True)

    for epoch in range(1, args.epochs + 1):
        print(f"[Epoch {epoch}] starting...", flush=True)
        model.train()
        running_loss = 0.0
        for b, g, y in train_dl:
            b = b.to(device)
            g = g.to(device)
            y = y.to(device)
            pred = model(b, g)
            loss = loss_fn(pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * b.size(0)
        train_loss = running_loss / len(train_ds)
        val_auc, val_ap = evaluate(model, val_dl, device)
        print(f"[Epoch {epoch}] train_loss={train_loss:.4f} val_auc={val_auc:.4f} val_ap={val_ap:.4f}", flush=True)

        if val_auc > best_val:
            best_val = val_auc
            torch.save(model.state_dict(), os_out / "best_model.pt")
            print(f"[Epoch {epoch}] Saved new best model with val_auc={val_auc:.4f}", flush=True)

    torch.save(model.state_dict(), os_out / "last_model.pt")
    test_auc, test_ap = evaluate(model, test_dl, device)
    print(f"FINAL TEST: AUC={test_auc:.4f} AP={test_ap:.4f}", flush=True)
    print(f"[i] Models written to {os_out}/best_model.pt and last_model.pt", flush=True)

if __name__ == "__main__":
    main()
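The only non-obvious step in build_tf_compressed_cache is that the TruncatedSVD is fitted on mean-pooled vectors but its components are then applied to every token. A tiny sketch with dummy data (2 components are used here only because the toy set has 3 samples, whereas the script uses 256):

import numpy as np
from sklearn.decomposition import TruncatedSVD

# made-up per-token TF embeddings of varying lengths, width 1280
embs = [np.random.rand(50, 1280), np.random.rand(75, 1280), np.random.rand(60, 1280)]

pooled = np.stack([e.mean(axis=0) for e in embs])  # (3, 1280): one pooled vector per TF
svd = TruncatedSVD(n_components=2, random_state=42).fit(pooled)
proj = svd.components_.T                           # (1280, 2): shared projection matrix

reduced = [e @ proj for e in embs]                 # projection applied token-by-token
print([r.shape for r in reduced])  # [(50, 2), (75, 2), (60, 2)]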