svincoff committed
Commit fad71ca · 1 Parent(s): 1927f8f

ananya updates

dpacman/classifier/model/clustering_data.py ADDED
@@ -0,0 +1,383 @@
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import numpy as np
4
+ import pandas as pd
5
+ from pathlib import Path
6
+ import random
7
+ import sys
8
+ import subprocess
9
+ from collections import defaultdict
10
+
11
+ # ─────────────────────────────────────────────────────────────────────────
12
+ # Original helpers (kept; some lightly edited/commented where needed)
13
+ # ─────────────────────────────────────────────────────────────────────────
14
+
15
+ def read_ids_file(p):
16
+ p = Path(p)
17
+ if not p.exists():
18
+ raise FileNotFoundError(f"IDs file not found: {p}")
19
+ return [line.strip() for line in p.open() if line.strip()]
20
+
21
+ def split_embeddings(emb_path, ids_path, out_dir, prefix):
22
+ out_dir = Path(out_dir)
23
+ out_dir.mkdir(parents=True, exist_ok=True)
24
+
25
+ if not Path(emb_path).exists():
26
+ raise FileNotFoundError(f"Embedding file not found: {emb_path}")
27
+ if not Path(ids_path).exists():
28
+ raise FileNotFoundError(f"IDs file not found: {ids_path}")
29
+
30
+ if emb_path.endswith(".npz"):
31
+ data = np.load(emb_path, allow_pickle=True)
32
+ if "embeddings" in data:
33
+ emb = data["embeddings"]
34
+ else:
35
+ raise ValueError(f"{emb_path} missing 'embeddings' key")
36
+ else:
37
+ emb = np.load(emb_path)
38
+
39
+ ids = read_ids_file(ids_path)
40
+ if len(ids) != emb.shape[0]:
41
+ print(f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}", file=sys.stderr)
42
+
43
+ mapping = {}
44
+ for i, ident in enumerate(ids):
45
+ if i >= emb.shape[0]:
46
+ print(f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr)
47
+ continue
48
+ arr = emb[i]
49
+ out_file = out_dir / f"{prefix}_{ident}.npy"
50
+ np.save(out_file, arr)
51
+ mapping[ident] = str(out_file)
52
+ return mapping
53
+
54
+ def extract_symbol_from_tf_id(full_id: str) -> str:
55
+ """
56
+ Given a TF embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN',
57
+ return the gene symbol uppercase (e.g., 'ZBTB5').
58
+ """
59
+ if "|" in full_id:
60
+ try:
61
+ # format sp|Accession|SYMBOL_HUMAN
62
+ genepart = full_id.split("|")[2]
63
+ except IndexError:
64
+ genepart = full_id
65
+ else:
66
+ genepart = full_id
67
+ symbol = genepart.split("_")[0]
68
+ return symbol.upper()
69
+
70
+ def build_tf_symbol_map(tf_map):
71
+ """
72
+ Build mapping gene_symbol -> list of embedding paths.
73
+ """
74
+ symbol_map = {}
75
+ for full_id, path in tf_map.items():
76
+ symbol = extract_symbol_from_tf_id(full_id)
77
+ symbol_map.setdefault(symbol, []).append(path)
78
+ return symbol_map
79
+
80
+ def tf_key_from_path(path: str) -> str:
81
+ """
82
+ Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract normalized symbol 'ZBTB5'.
83
+ """
84
+ stem = Path(path).stem # e.g., tf_sp|O15062|ZBTB5_HUMAN
85
+ # remove leading prefix if present (tf_)
86
+ if "_" in stem:
87
+ _, rest = stem.split("_", 1)
88
+ else:
89
+ rest = stem
90
+ return extract_symbol_from_tf_id(rest)
91
+
92
+ def dna_key_from_path(path: str) -> str:
93
+ """
94
+ Given .../dna_peak42.npy -> 'peak42'
95
+ """
96
+ stem = Path(path).stem
97
+ if "_" in stem:
98
+ _, rest = stem.split("_", 1)
99
+ else:
100
+ rest = stem
101
+ return rest
102
+
103
+ # ─────────────────────────────────────────────────────────────────────────
104
+ # New helpers for MMseqs clustering & cluster-level splitting
105
+ # ─────────────────────────────────────────────────────────────────────────
106
+
107
+ def write_dna_fasta(df: pd.DataFrame, out_fasta: Path) -> None:
108
+ """
109
+ Write unique DNA sequences to FASTA using dna_id as header.
110
+ Requires df with columns: dna_id, dna_sequence
111
+ """
112
+ uniq = df[["dna_id", "dna_sequence"]].drop_duplicates()
113
+ with open(out_fasta, "w") as f:
114
+ for _, row in uniq.iterrows():
115
+ did = row["dna_id"]
116
+ seq = str(row["dna_sequence"]).upper().replace(" ", "").replace("\n", "")
117
+ f.write(f">{did}\n{seq}\n")
118
+
119
+ def run_mmseqs_easy_cluster(
120
+ mmseqs_bin: str,
121
+ fasta: Path,
122
+ out_prefix: Path,
123
+ tmp_dir: Path,
124
+ min_seq_id: float,
125
+ coverage: float,
126
+ cov_mode: int,
127
+ ) -> Path:
128
+ """
129
+ Runs mmseqs easy-cluster on nucleotide sequences.
130
+ Returns the path to a clusters TSV file (creating it if the default one isn't present).
131
+ """
132
+ tmp_dir.mkdir(parents=True, exist_ok=True)
133
+ out_prefix.parent.mkdir(parents=True, exist_ok=True)
134
+
135
+ cmd = [
136
+ mmseqs_bin, "easy-cluster",
137
+ str(fasta), str(out_prefix), str(tmp_dir),
138
+ "--min-seq-id", str(min_seq_id),
139
+ "-c", str(coverage),
140
+ "--cov-mode", str(cov_mode),
141
+ # You can add performance flags here if needed, e.g.:
142
+ # "--threads", "8"
143
+ ]
144
+ print("[i] Running:", " ".join(cmd), flush=True)
145
+ subprocess.run(cmd, check=True)
146
+
147
+ # MMseqs easy-cluster typically writes <out_prefix>_cluster.tsv
148
+ default_tsv = Path(str(out_prefix) + "_cluster.tsv")
149
+ if default_tsv.exists():
150
+ print(f"[i] Found cluster TSV: {default_tsv}")
151
+ return default_tsv
152
+
153
+ # Fallback: try createtsv if default is missing
154
+ # This requires the internal DBs. easy-cluster creates DBs alongside out_prefix.
155
+ # We'll try to locate them and emit a TSV.
156
+ in_db = Path(str(out_prefix) + "_query")
157
+ cl_db = Path(str(out_prefix) + "_cluster")
158
+ out_tsv = Path(str(out_prefix) + "_fallback_cluster.tsv")
159
+ if in_db.exists() and cl_db.exists():
160
+ cmd2 = [mmseqs_bin, "createtsv", str(in_db), str(in_db), str(cl_db), str(out_tsv)]
161
+ print("[i] Creating TSV via createtsv:", " ".join(cmd2), flush=True)
162
+ subprocess.run(cmd2, check=True)
163
+ if out_tsv.exists():
164
+ return out_tsv
165
+
166
 + raise FileNotFoundError("Could not locate clusters TSV from mmseqs. "
168
 + f"Expected {default_tsv} or createtsv fallback.")
168
+
169
+ def parse_mmseqs_clusters(tsv_path: Path) -> dict:
170
+ """
171
+ Parse MMseqs cluster TSV (rep \t member). Returns dna_id -> cluster_rep_id
172
+ """
173
+ mapping = {}
174
+ with open(tsv_path) as f:
175
+ for line in f:
176
+ parts = line.rstrip("\n").split("\t")
177
+ if len(parts) < 2:
178
+ continue
179
+ rep, member = parts[0], parts[1]
180
+ mapping[member] = rep
181
+ # Some TSVs include rep->rep; if not, ensure rep is mapped to itself:
182
+ if rep not in mapping:
183
+ mapping[rep] = rep
184
+ return mapping
185
+
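 + # Illustrative sketch (not part of the pipeline output): given TSV lines of the form
 + #   rep<TAB>member, e.g.
 + #     peak0   peak0
 + #     peak0   peak7
 + #     peak3   peak3
 + # parse_mmseqs_clusters returns {"peak0": "peak0", "peak7": "peak0", "peak3": "peak3"},
 + # i.e. every member (including the representative itself) maps to its cluster representative.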
186
+ def assign_clusters_to_splits(cluster_rep_to_members: dict,
187
+ val_frac: float,
188
+ test_frac: float,
189
+ seed: int = 42):
190
+ """
191
+ cluster_rep_to_members: dict[rep] = [members...]
192
+ Returns: dict with keys 'train','val','test' mapping to sets of dna_id.
193
+ Ensures all members of a cluster go to the same split.
194
+ """
195
+ rng = random.Random(seed)
196
+ reps = list(cluster_rep_to_members.keys())
197
+ rng.shuffle(reps)
198
+
199
+ # Greedy-ish fill by total member counts to match desired fractions.
200
+ total = sum(len(cluster_rep_to_members[r]) for r in reps)
201
+ target_val = int(round(total * val_frac))
202
+ target_test = int(round(total * test_frac))
203
+ cur_val = cur_test = 0
204
+
205
+ val_ids, test_ids, train_ids = set(), set(), set()
206
+ for rep in reps:
207
+ members = cluster_rep_to_members[rep]
208
+ c = len(members)
209
+ # Fill val first, then test, then train
210
+ if cur_val + c <= target_val:
211
+ val_ids.update(members); cur_val += c
212
+ elif cur_test + c <= target_test:
213
+ test_ids.update(members); cur_test += c
214
+ else:
215
+ train_ids.update(members)
216
+
217
+ return {"train": train_ids, "val": val_ids, "test": test_ids}
218
+
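 + # Worked sketch (hypothetical counts): with 100 members total, val_frac=0.10 and test_frac=0.10,
 + # clusters are visited in shuffled order; whole clusters are placed in val until ~10 members are
 + # reached, then in test until ~10 more, and every remaining cluster goes to train. Because the
 + # assignment is per cluster, near-identical sequences (members of one MMseqs cluster) never
 + # straddle train and val/test.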
219
+ # ─────────────────────────────────────────────────────────────────────────
220
+ # Main
221
+ # ─────────────────────────────────────────────────────────────────────────
222
+
223
+ def main():
224
+ parser = argparse.ArgumentParser(
225
+ description="Build TF-DNA pair lists with MMseqs clustering on DNA to prevent split leakage."
226
+ )
227
+ parser.add_argument("--final_csv", required=True, help="final.csv with TF_id and dna_sequence")
228
+ parser.add_argument("--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)")
229
+ parser.add_argument("--dna_ids", required=True, help="IDs file for DNA embeddings (peak*.ids)")
230
+ parser.add_argument("--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)")
231
+ parser.add_argument("--tf_ids", required=True, help="IDs file for TF embeddings (sp|... ids)")
232
+ parser.add_argument("--out_dir", required=True, help="Output directory")
233
+ parser.add_argument("--seed", type=int, default=42)
234
+
235
+ # NEW: MMseqs options & split fractions
236
+ parser.add_argument("--mmseqs_bin", default="mmseqs", help="Path to mmseqs binary")
237
+ parser.add_argument("--min_seq_id", type=float, default=0.9, help="MMseqs --min-seq-id")
238
+ parser.add_argument("--cov", type=float, default=0.8, help="MMseqs -c coverage fraction")
239
+ parser.add_argument("--cov_mode", type=int, default=1, help="MMseqs --cov-mode (1 = coverage of target)")
240
+ parser.add_argument("--val_frac", type=float, default=0.10)
241
+ parser.add_argument("--test_frac", type=float, default=0.10)
242
+ parser.add_argument("--tmp_dir", default=None, help="MMseqs tmp dir (defaults to out_dir/tmp)")
243
+ args = parser.parse_args()
244
+
245
+ random.seed(args.seed)
246
+ out_dir = Path(args.out_dir)
247
+ out_dir.mkdir(parents=True, exist_ok=True)
248
+
249
+ # Load final.csv
250
+ df = pd.read_csv(args.final_csv, dtype=str)
251
+ if "TF_id" not in df.columns or "dna_sequence" not in df.columns:
252
+ raise RuntimeError("final.csv must have columns TF_id and dna_sequence")
253
+
254
+ # Assign dna_id (unique per dna_sequence)
255
+ unique_seqs = df["dna_sequence"].drop_duplicates().tolist()
256
+ seq_to_id = {seq: f"peak{i}" for i, seq in enumerate(unique_seqs)}
257
+ df["dna_id"] = df["dna_sequence"].map(seq_to_id)
258
+ enriched_csv = out_dir / "final_with_dna_id.csv"
259
+ df.to_csv(enriched_csv, index=False)
260
+ print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}")
261
+
262
+ # Split embeddings into per-item files (unchanged)
263
+ print(f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}")
264
+ dna_map = split_embeddings(args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna")
265
+ print(f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})")
266
+ print(f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}")
267
+ tf_map = split_embeddings(args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf")
268
+ print(f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})")
269
+
270
+ # Build gene-symbol normalized map
271
+ tf_symbol_map = build_tf_symbol_map(tf_map)
272
+ print(f"[i] TF symbol map keys (sample): {list(tf_symbol_map.keys())[:30]}")
273
+
274
+ # Diagnostic overlaps
275
+ norm_tf_in_final = set(t.split("_seq")[0].upper() for t in df["TF_id"].unique())
276
+ available_tf_symbols = set(tf_symbol_map.keys())
277
+ intersect_tf = norm_tf_in_final & available_tf_symbols
278
+ print(f"[i] Unique normalized TF symbols in final.csv: {len(norm_tf_in_final)}")
279
+ print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}")
280
+ print(f"[i] Intersection count: {len(intersect_tf)}")
281
+ if len(intersect_tf) == 0:
282
+ print("[ERROR] No overlap between normalized TF_id and TF embedding symbols.", file=sys.stderr)
283
+ print("Sample normalized TFs from final.csv:", sorted(list(norm_tf_in_final))[:30], file=sys.stderr)
284
+ print("Sample available TF symbols:", sorted(list(available_tf_symbols))[:30], file=sys.stderr)
285
+ sys.exit(1)
286
+
287
+ dna_ids_final = set(df["dna_id"].unique())
288
+ available_dna_ids = set(dna_map.keys())
289
+ intersect_dna = dna_ids_final & available_dna_ids
290
+ print(f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. Intersection: {len(intersect_dna)}")
291
+ if len(intersect_dna) == 0:
292
+ print("[ERROR] No overlap on DNA ids.", file=sys.stderr)
293
+ sys.exit(1)
294
+
295
+ # ── NEW: MMseqs clustering on DNA sequences ───────────────────────────
296
+ fasta_path = out_dir / "dna_unique.fasta"
297
+ write_dna_fasta(df, fasta_path)
298
 + print(f"[i] Wrote FASTA with {df['dna_id'].nunique()} unique sequences → {fasta_path}")
299
+
300
+ tmp_dir = Path(args.tmp_dir) if args.tmp_dir else (out_dir / "mmseqs_tmp")
301
+ cluster_prefix = out_dir / "mmseqs_dna_clusters"
302
+ clusters_tsv = run_mmseqs_easy_cluster(
303
+ mmseqs_bin=args.mmseqs_bin,
304
+ fasta=fasta_path,
305
+ out_prefix=cluster_prefix,
306
+ tmp_dir=tmp_dir,
307
+ min_seq_id=args.min_seq_id,
308
+ coverage=args.cov,
309
+ cov_mode=args.cov_mode,
310
+ )
311
+
312
+ # Parse clusters
313
+ member_to_rep = parse_mmseqs_clusters(clusters_tsv) # dna_id -> rep_id
314
+ # Build rep -> members list
315
+ rep_to_members = defaultdict(list)
316
+ for member, rep in member_to_rep.items():
317
+ rep_to_members[rep].append(member)
318
+
319
+ print(f"[i] Parsed {len(rep_to_members)} clusters from {clusters_tsv}")
320
+ clusters_table = []
321
+ for rep, members in rep_to_members.items():
322
+ for m in members:
323
+ clusters_table.append((m, rep))
324
+ clusters_df = pd.DataFrame(clusters_table, columns=["dna_id", "cluster_id"])
325
+ clusters_df.to_csv(out_dir / "clusters.tsv", sep="\t", index=False)
326
 + print(f"[i] Wrote clusters mapping → {out_dir / 'clusters.tsv'}")
327
+
328
+ # Attach cluster_id back to final df
329
+ df = df.merge(clusters_df, on="dna_id", how="left")
330
+ df.to_csv(out_dir / "final_with_dna_id_and_cluster.csv", index=False)
331
+ print(f"[i] Wrote {out_dir / 'final_with_dna_id_and_cluster.csv'}")
332
+
333
+ # Assign entire clusters to splits
334
+ splits = assign_clusters_to_splits(rep_to_members,
335
+ val_frac=args.val_frac,
336
+ test_frac=args.test_frac,
337
+ seed=args.seed)
338
+ for k in ["train", "val", "test"]:
339
+ print(f"[i] {k}: {len(splits[k])} dna_ids")
340
+
341
+ # ── Build positive pairs only, per split (NO negatives) ───────────────
342
+ positives_by_split = {"train": [], "val": [], "test": []}
343
+ # Build a quick dna_id -> embedding path map
344
+ dnaid_to_path = {did: path for did, path in dna_map.items()}
345
+
346
+ pos_count = 0
347
+ for _, row in df.iterrows():
348
+ tf_raw = row["TF_id"]
349
+ tf_symbol = tf_raw.split("_seq")[0].upper()
350
+ dnaid = row["dna_id"]
351
+ if (tf_symbol not in tf_symbol_map) or (dnaid not in dnaid_to_path):
352
+ continue
353
+ tf_embedding_path = tf_symbol_map[tf_symbol][0] # first embedding per symbol
354
+
355
+ # decide split by dna_id cluster assignment
356
+ if dnaid in splits["train"]:
357
+ positives_by_split["train"].append((tf_embedding_path, dnaid_to_path[dnaid], 1))
358
+ elif dnaid in splits["val"]:
359
+ positives_by_split["val"].append((tf_embedding_path, dnaid_to_path[dnaid], 1))
360
+ elif dnaid in splits["test"]:
361
+ positives_by_split["test"].append((tf_embedding_path, dnaid_to_path[dnaid], 1))
362
+ pos_count += 1
363
+
364
+ print(f"[i] Constructed positives across splits (rows in final.csv iterated: {len(df)})")
365
+ for k in ["train", "val", "test"]:
366
+ print(f"[i] positives[{k}] = {len(positives_by_split[k])}")
367
+
368
+ # # OLD: negatives (kept commented)
369
+ # negatives = []
370
+ # print(f"[i] Sampled {len(negatives)} negatives (neg_per_positive not used)")
371
+
372
+ # Emit split-specific pair lists
373
+ for split in ["train", "val", "test"]:
374
+ out_tsv = out_dir / f"pair_list_{split}.tsv"
375
+ with open(out_tsv, "w") as f:
376
+ for binder_path, glm_path, label in positives_by_split[split]: # + negatives if you add later
377
+ f.write(f"{binder_path}\t{glm_path}\t{label}\n")
378
+ print(f"[i] Wrote {len(positives_by_split[split])} examples to {out_tsv}")
379
+
380
 + print("✅ Done. Cluster-aware splits ready.")
381
+
382
+ if __name__ == "__main__":
383
+ main()
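 +
 + # Example invocation (a sketch; the file paths below are placeholders, not files in this repo):
 + #   python clustering_data.py \
 + #     --final_csv final.csv \
 + #     --dna_embed_npz embeddings/peaks_caduceus.npy --dna_ids embeddings/peaks_caduceus.ids \
 + #     --tf_embed_npy embeddings/tf_esm-dbp.npy --tf_ids embeddings/tf_esm-dbp.ids \
 + #     --out_dir pairs_out --min_seq_id 0.9 --cov 0.8 --cov_mode 1
 + # Outputs pair_list_train.tsv, pair_list_val.tsv and pair_list_test.tsv under --out_dir.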
dpacman/classifier/model/compute_embeddings.py CHANGED
@@ -25,6 +25,9 @@ from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, pipelin
25
  import esm
26
  from Bio import SeqIO
27
  import time
 
 
 
28
 
29
  # ---- model wrappers ----
30
 
@@ -67,7 +70,7 @@ class CaduceusEmbedder:
67
 
68
  # return np.stack(outputs, axis=0) # (N, L, D)
69
  outputs = []
70
- for seq in seqs:
71
  toks = self.tokenizer(
72
  seq,
73
  return_tensors="pt",
@@ -471,34 +474,45 @@ def embed_and_save(seqs, ids, embedder, out_path):
471
  if __name__=="__main__":
472
 
473
  p = argparse.ArgumentParser()
474
- p.add_argument("--peak-fasta", default="binding_peaks_unique.fa", help="FASTA of deduplicated binding peak sequences; if present this is used for DNA embedding instead of genome JSONs")
475
  p.add_argument("--genome-json-dir", default=None, help="(fallback) directory of UCSC JSONs for full chromosome embedding if peak FASTA is missing or you explicitly want chromosomes")
476
  p.add_argument("--skip-dna", action="store_true", help="if set, skip the chromosome embedding step") #if glm embeddings successful but not plm embeddings
477
  p.add_argument("--tf-fasta", required=True, help="input TF FASTA file")
478
  p.add_argument("--chrom-model", default="caduceus")
479
  p.add_argument("--tf-model", default="esm-dbp")
480
- p.add_argument("--out-dir", default="data_files/processed/tfclust/hg38_tf/embeddings")
481
  p.add_argument("--device", default="cpu")
482
  args = p.parse_args()
483
 
484
  os.makedirs(args.out_dir, exist_ok=True)
485
  device = args.device
 
486
 
487
  if not args.skip_dna:
488
- peak_fasta = Path(args.peak_fasta)
489
- if peak_fasta.exists():
490
- # Load peak sequences from FASTA
491
- from Bio import SeqIO
492
-
493
- peak_seqs = []
494
- peak_ids = []
495
- for rec in SeqIO.parse(peak_fasta, "fasta"):
496
- peak_ids.append(rec.id)
497
- peak_seqs.append(str(rec.seq))
498
- print(f"Embedding {len(peak_seqs)} binding peak sequences from {peak_fasta}", flush=True)
499
  dna_embedder = get_embedder(args.chrom_model, device, for_dna=True)
500
  out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy"
501
  embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks)
502
  elif args.genome_json_dir:
503
  # Legacy: load full chromosomes from JSONs (chr1–22, X, Y, M)
504
  genome_dir = Path(args.genome_json_dir)
 
25
  import esm
26
  from Bio import SeqIO
27
  import time
28
+ import pandas as pd
29
+ from tqdm.auto import tqdm
30
+ import logging, math
31
 
32
  # ---- model wrappers ----
33
 
 
70
 
71
  # return np.stack(outputs, axis=0) # (N, L, D)
72
  outputs = []
73
+ for seq in tqdm(seqs, total=len(seqs), desc="DNA: Caduceus", dynamic_ncols=True):
74
  toks = self.tokenizer(
75
  seq,
76
  return_tensors="pt",
 
474
  if __name__=="__main__":
475
 
476
  p = argparse.ArgumentParser()
477
+ #p.add_argument("--peak_fasta", default="binding_peaks_unique.fa", help="FASTA of deduplicated binding peak sequences; if present this is used for DNA embedding instead of genome JSONs")
478
  p.add_argument("--genome-json-dir", default=None, help="(fallback) directory of UCSC JSONs for full chromosome embedding if peak FASTA is missing or you explicitly want chromosomes")
479
  p.add_argument("--skip-dna", action="store_true", help="if set, skip the chromosome embedding step") #if glm embeddings successful but not plm embeddings
480
  p.add_argument("--tf-fasta", required=True, help="input TF FASTA file")
481
  p.add_argument("--chrom-model", default="caduceus")
482
  p.add_argument("--tf-model", default="esm-dbp")
483
+ p.add_argument("--out-dir", default="dpacman/model/embeddings")
484
  p.add_argument("--device", default="cpu")
485
  args = p.parse_args()
486
 
487
  os.makedirs(args.out_dir, exist_ok=True)
488
  device = args.device
489
+ print(device)
490
 
491
  if not args.skip_dna:
492
 + if args.genome_json_dir is None:
493
+ dna_df = pd.read_parquet('/home/a03-akrishna/DPACMAN/dpacman/model/remap2022_crm_fimo_output_q_processed.parquet', engine='pyarrow')
494
+ #df.to_csv('/home/a03-akrishna/DPACMAN/dpacman/model/remap2022_crm_fimo_output_q_processed.csv', index=False)
495
+ peak_seqs = dna_df["dna_sequence"]
496
+ peak_ids = dna_df["ID"]
497
+ print(f"Embedding {len(peak_seqs)} binding peak sequences from processed remap data", flush=True)
 
 
 
 
 
498
  dna_embedder = get_embedder(args.chrom_model, device, for_dna=True)
499
  out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy"
500
  embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks)
501
+
502
+ # peak_fasta = Path(args.peak_fasta)
503
+ # if peak_fasta.exists():
504
+ # # Load peak sequences from FASTA
505
+ # from Bio import SeqIO
506
+
507
+ # peak_seqs = []
508
+ # peak_ids = []
509
+ # for rec in SeqIO.parse(peak_fasta, "fasta"):
510
+ # peak_ids.append(rec.id)
511
+ # peak_seqs.append(str(rec.seq))
512
+ # print(f"Embedding {len(peak_seqs)} binding peak sequences from {peak_fasta}", flush=True)
513
+ # dna_embedder = get_embedder(args.chrom_model, device, for_dna=True)
514
+ # out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy"
515
+ # embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks)
516
  elif args.genome_json_dir:
517
  # Legacy: load full chromosomes from JSONs (chr1–22, X, Y, M)
518
  genome_dir = Path(args.genome_json_dir)
dpacman/classifier/model/model.py CHANGED
@@ -39,7 +39,7 @@ class CrossModalBlock(nn.Module):
39
  def forward(self, binder: torch.Tensor, glm: torch.Tensor):
40
  """
41
  binder: (batch, Lb, dim)
42
- glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
43
  returns: updated binder representation (batch, Lb, dim)
44
  """
45
  # binder self-attn + ffn
@@ -63,16 +63,50 @@ class CrossModalBlock(nn.Module):
63
  c = self.ln_c2(c + c_ff)
64
  return c # (batch, Lb, dim)
65
 
66
  class BindPredictor(nn.Module):
67
  def __init__(self,
68
- input_dim: int = 256,
 
 
 
69
  hidden_dim: int = 256,
70
  heads: int = 8,
71
  num_layers: int = 4,
72
  use_local_cnn_on_glm: bool = True):
73
  super().__init__()
74
- self.proj_binder = nn.Linear(input_dim, hidden_dim)
75
- self.proj_glm = nn.Linear(input_dim, hidden_dim)
76
  self.use_local_cnn = use_local_cnn_on_glm
77
  self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity()
78
 
@@ -81,23 +115,28 @@ class BindPredictor(nn.Module):
81
  ])
82
 
83
  self.ln_out = nn.LayerNorm(hidden_dim)
84
- self.head = nn.Sequential(
85
- nn.Linear(hidden_dim, 1),
86
- nn.Sigmoid()
87
- )
88
 
89
  def forward(self, binder_emb, glm_emb):
90
  """
91
- binder_emb, glm_emb: (batch, L, input_dim)
 
 
92
  """
93
- b = self.proj_binder(binder_emb) # (B, Lb, hidden_dim)
94
- g = self.proj_glm(glm_emb) # (B, Lg, hidden_dim)
 
 
 
 
95
  if self.use_local_cnn:
96
- g = self.local_cnn(g) # local context injected
97
 
 
98
  for layer in self.layers:
99
- b = layer(b, g) # update binder with cross-modal info
100
 
101
- pooled = b.mean(dim=1) # (B, hidden_dim)
102
- out = self.ln_out(pooled)
103
- return self.head(out).squeeze(-1) # (B,)
 
39
  def forward(self, binder: torch.Tensor, glm: torch.Tensor):
40
  """
41
  binder: (batch, Lb, dim)
42
+ glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
43
  returns: updated binder representation (batch, Lb, dim)
44
  """
45
  # binder self-attn + ffn
 
63
  c = self.ln_c2(c + c_ff)
64
  return c # (batch, Lb, dim)
65
 
66
+ class DimCompressor(nn.Module):
67
+ """
68
+ Learnable per-token compressor: maps any in_dim >= out_dim to out_dim (default 256).
69
+ If in_dim == out_dim, behaves as identity.
70
+ """
71
+ def __init__(self, in_dim: int, out_dim: int = 256):
72
+ super().__init__()
73
+ if in_dim == out_dim:
74
+ self.net = nn.Identity()
75
+ else:
76
+ hidden = max(out_dim * 2, (in_dim + out_dim) // 2)
77
+ self.net = nn.Sequential(
78
+ nn.LayerNorm(in_dim),
79
+ nn.Linear(in_dim, hidden),
80
+ nn.GELU(),
81
+ nn.Linear(hidden, out_dim),
82
+ )
83
+
84
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
85
+ # x: (B, L, in_dim)
86
+ return self.net(x)
87
+
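 + # Shape sketch (dims assumed from the defaults below): DimCompressor(1280, 256) maps
 + # (B, L, 1280) -> (B, L, 256) via LayerNorm -> Linear -> GELU -> Linear, while
 + # DimCompressor(256, 256) reduces to nn.Identity(), so 256-d inputs pass through unchanged.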
88
  class BindPredictor(nn.Module):
89
  def __init__(self,
90
+ # input_dim: int = 256, # OLD: single input dim
91
+ binder_input_dim: int = 1280, # NEW: TF (binder) original dim (e.g., 1280)
92
+ glm_input_dim: int = 256, # NEW: DNA/GLM original dim (e.g., 256)
93
+ compressed_dim: int = 256, # NEW: learnable compressed dim
94
  hidden_dim: int = 256,
95
  heads: int = 8,
96
  num_layers: int = 4,
97
  use_local_cnn_on_glm: bool = True):
98
  super().__init__()
99
+ # OLD:
100
+ # self.proj_binder = nn.Linear(input_dim, hidden_dim)
101
+ # self.proj_glm = nn.Linear(input_dim, hidden_dim)
102
+
103
 + # NEW: learnable compressor for binder → 256, then project to hidden
104
+ self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim)
105
+ self.proj_binder = nn.Linear(compressed_dim, hidden_dim)
106
+
107
 + # GLM side stays 256 → hidden
108
+ self.proj_glm = nn.Linear(glm_input_dim, hidden_dim)
109
+
110
  self.use_local_cnn = use_local_cnn_on_glm
111
  self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity()
112
 
 
115
  ])
116
 
117
  self.ln_out = nn.LayerNorm(hidden_dim)
118
+ # self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid()) # OLD: returned probabilities
119
+ self.head = nn.Linear(hidden_dim, 1) # NEW: return logits (safe for AMP)
 
 
120
 
121
  def forward(self, binder_emb, glm_emb):
122
  """
123
+ binder_emb: (B, Lb, binder_input_dim)
124
+ glm_emb: (B, Lg, glm_input_dim)
125
+ Returns per-nucleotide logits for the GLM sequence: (B, Lg)
126
  """
127
 + # Binder: learnable compression → 256 → hidden
128
+ b = self.binder_compress(binder_emb) # (B, Lb, 256)
129
+ b = self.proj_binder(b) # (B, Lb, hidden_dim)
130
+
131
 + # GLM: project → hidden, add local CNN context
132
+ g = self.proj_glm(glm_emb) # (B, Lg, hidden_dim)
133
  if self.use_local_cnn:
134
+ g = self.local_cnn(g)
135
 
136
+ # Cross-modal blocks: update binder states using GLM
137
  for layer in self.layers:
138
+ b = layer(b, g) # (B, Lb, hidden_dim)
139
 
140
+ # Predict per-nucleotide logits on the GLM tokens:
141
+ # return self.head(g).squeeze(-1) # OLD: probabilities (with Sigmoid in head)
142
+ return self.head(g).squeeze(-1) # NEW: logits (apply sigmoid only in loss/metrics)
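 +
 + # Minimal shape check (a sketch, using the dims train.py passes in):
 + #   model = BindPredictor(binder_input_dim=1280, glm_input_dim=256, compressed_dim=256,
 + #                         hidden_dim=256, heads=8, num_layers=4, use_local_cnn_on_glm=True)
 + #   logits = model(torch.randn(2, 700, 1280), torch.randn(2, 300, 256))   # -> (2, 300)
 + #   probs  = torch.sigmoid(logits)   # sigmoid is applied only in the loss/metrics, not in the head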
dpacman/classifier/model/train.py CHANGED
@@ -1,262 +1,383 @@
1
- #!/usr/bin/env python3
2
- import argparse
 
3
  import numpy as np
 
4
  import torch
5
  from torch import nn
6
- from torch.utils.data import Dataset, DataLoader
7
- from model import BindPredictor
8
- from pathlib import Path
9
- from collections import Counter
10
- from sklearn.metrics import roc_auc_score, average_precision_score
11
- from sklearn.decomposition import TruncatedSVD
12
- import random
13
- import sys
14
 
15
- # ---- dataset ---------------------------------------------------------
16
- class PairDataset(Dataset):
17
- def __init__(self, binder_paths, glm_paths, labels, tf_compressed_cache):
18
- """
19
- tf_compressed_cache: dict mapping binder_path -> compressed (256-d) tensor/array
20
- """
21
- assert len(binder_paths) == len(glm_paths) == len(labels)
22
- self.binder_paths = binder_paths
23
- self.glm_paths = glm_paths
24
- self.labels = labels
25
- self.tf_cache = tf_compressed_cache # already reduced to 256
26
-
27
- def __len__(self):
28
- return len(self.labels)
29
-
30
- def __getitem__(self, idx):
31
- # binder = TF embedding (possibly reduced)
32
- b = self.tf_cache[self.binder_paths[idx]] # numpy array shape (L, 256) or (256,)
33
- g = np.load(self.glm_paths[idx]) # glm (DNA) embedding
34
-
35
- if b.ndim == 1:
36
- b = b[None, :]
37
- if g.ndim == 1:
38
- g = g[None, :]
39
-
40
- b_tensor = torch.from_numpy(b).float()
41
- g_tensor = torch.from_numpy(g).float()
42
- y = torch.tensor(self.labels[idx]).float()
43
- return b_tensor, g_tensor, y
44
 
45
- def collate_fn(batch):
46
- binders, glms, labels = zip(*batch)
47
- binder_lens = [b.shape[0] for b in binders]
48
- glm_lens = [g.shape[0] for g in glms]
49
- max_b = max(binder_lens)
50
- max_g = max(glm_lens)
51
-
52
- def pad_seq(seq, target_len):
53
- L, D = seq.shape
54
- if L < target_len:
55
- pad = torch.zeros((target_len - L, D), dtype=seq.dtype, device=seq.device)
56
- return torch.cat([seq, pad], dim=0)
57
- return seq
58
-
59
- b_padded = torch.stack([pad_seq(b, max_b) for b in binders]) # (B, Lb, D)
60
- g_padded = torch.stack([pad_seq(g, max_g) for g in glms]) # (B, Lg, D)
61
- y = torch.stack(labels)
62
- return b_padded, g_padded, y
63
-
64
- # ---- utilities -------------------------------------------------------
65
- def parse_pair_list(pair_list_path):
66
- binder_paths, glm_paths, labels = [], [], []
67
- with open(pair_list_path) as f:
68
- for lineno, line in enumerate(f, start=1):
69
- if not line.strip():
70
- continue
71
  parts = line.strip().split()
72
- if len(parts) != 3:
73
- print(f"[WARN] skipping malformed line {lineno}: {line.strip()}", file=sys.stderr)
74
- continue
75
- b, g, l = parts
76
- try:
77
- lab = int(l)
78
- except ValueError:
79
- print(f"[WARN] invalid label on line {lineno}: {l}", file=sys.stderr)
80
- continue
81
- binder_paths.append(b)
82
- glm_paths.append(g)
83
- labels.append(lab)
84
- return binder_paths, glm_paths, labels
85
-
86
- def build_tf_compressed_cache(binder_paths, target_dim=256):
87
  """
88
- Load all unique TF (binder) embeddings, fit reduction if needed, and return dict mapping path->(L, target_dim) array.
89
  """
90
- unique_paths = sorted(set(binder_paths))
91
- print(f"[i] Found {len(unique_paths)} unique TF embedding files to compress.", flush=True)
92
- # Load all embeddings to determine dimensionality
93
- samples = []
94
- for p in unique_paths:
95
- arr = np.load(p)
96
- samples.append(arr)
97
- # Determine if reduction needed: assume all have same embedding width
98
- first = samples[0]
99
- orig_dim = first.shape[1] if first.ndim == 2 else 1
100
- reduction_needed = (orig_dim != target_dim)
101
- tf_cache = {}
102
-
103
- if reduction_needed:
104
- # Build matrix to fit SVD: we need a 2D matrix per embedding; if lengths vary we can't directly stack.
105
- # We'll do reduction per sequence individually using TruncatedSVD on concatenated flattened features:
106
- # Simplest: for variable lengths, reduce each embedding separately with a learned linear projection.
107
- # Here we fit a single TruncatedSVD on the concatenation of all sequence tokens (flattened) by padding/truncating to a fixed length.
108
- # To avoid complexity, use PCA-like linear projection learned via SVD on mean-pooled vectors:
109
- pooled = []
110
- for arr in samples:
111
- if arr.ndim == 2:
112
- pooled.append(arr.mean(axis=0)) # (orig_dim,)
113
- else:
114
- pooled.append(arr) # degenerate
115
- pooled_mat = np.stack(pooled, axis=0) # (N, orig_dim)
116
- print(f"[i] Fitting TruncatedSVD on TF pooled embeddings: {pooled_mat.shape} -> {target_dim}", flush=True)
117
- svd = TruncatedSVD(n_components=target_dim, random_state=42)
118
- reduced_pooled = svd.fit_transform(pooled_mat) # (N, target_dim)
119
-
120
- # For each original embedding, project token-level vectors by multiplying token vector with svd.components_.T
121
- # svd.components_: (target_dim, orig_dim) so projection matrix is (orig_dim, target_dim)
122
- proj_mat = svd.components_.T # (orig_dim, target_dim)
123
- for i, p in enumerate(unique_paths):
124
- arr = samples[i] # shape (L, orig_dim)
125
- if arr.ndim == 1:
126
- arr2 = arr @ proj_mat # (target_dim,)
127
- else:
128
- # project each token: (L, orig_dim) @ (orig_dim, target_dim) -> (L, target_dim)
129
- arr2 = arr @ proj_mat
130
- tf_cache[p] = arr2 # reduced per-token representation
131
- print("[i] Completed compression of TF embeddings.", flush=True)
132
- else:
133
- # already correct dim: just cache originals
134
- print(f"[i] TF embeddings already {target_dim}-dimensional; skipping reduction.", flush=True)
135
- for i, p in enumerate(unique_paths):
136
- arr = samples[i]
137
- tf_cache[p] = arr
138
- return tf_cache
139
-
140
- def evaluate(model, dl, device):
141
  model.eval()
142
- all_labels = []
143
- all_preds = []
144
  with torch.no_grad():
145
- for b, g, y in dl:
146
- b = b.to(device)
147
- g = g.to(device)
148
- y = y.to(device)
149
- pred = model(b, g)
150
- all_labels.append(y.cpu())
151
- all_preds.append(pred.cpu())
152
- if not all_labels:
153
- return 0.0, 0.0
154
- y_true = torch.cat(all_labels).numpy()
155
- y_score = torch.cat(all_preds).numpy()
156
- try:
157
- auc = roc_auc_score(y_true, y_score)
158
- except Exception:
159
- auc = 0.0
160
- try:
161
- ap = average_precision_score(y_true, y_score)
162
- except Exception:
163
- ap = 0.0
164
- return auc, ap
165
-
166
- # ---- main ------------------------------------------------------------
167
  def main():
168
- parser = argparse.ArgumentParser()
169
- parser.add_argument("--pair_list", type=str, required=True,
170
- help="TSV: binder_path glm_path label")
171
- parser.add_argument("--out_dir", type=str, required=True)
172
- parser.add_argument("--epochs", type=int, default=10)
173
- parser.add_argument("--batch_size", type=int, default=32)
174
- parser.add_argument("--lr", type=float, default=1e-4)
175
- parser.add_argument("--device", type=str, default="cuda")
176
- parser.add_argument("--seed", type=int, default=42)
177
- args = parser.parse_args()
178
-
179
180
  random.seed(args.seed)
181
  np.random.seed(args.seed)
182
  torch.manual_seed(args.seed)
183
-
184
- print("DEBUG: starting training script with in-line TF compression", flush=True)
185
- print(f"[i] pair_list: {args.pair_list}", flush=True)
186
- print(f"[i] output dir: {args.out_dir}", flush=True)
187
  device = torch.device(args.device if torch.cuda.is_available() else "cpu")
188
- binder_paths, glm_paths, labels = parse_pair_list(args.pair_list)
189
 
190
- if len(labels) == 0:
191
- print("[ERROR] No valid pairs parsed. Exiting.", file=sys.stderr)
192
- sys.exit(1)
 
 
 
193
 
194
- label_counts = Counter(labels)
195
- print(f"[i] Total examples parsed: {len(labels)}. Label distribution: {label_counts}", flush=True)
 
 
196
 
197
- # build compressed TF cache (reduces to 256 if needed)
198
- tf_compressed_cache = build_tf_compressed_cache(binder_paths, target_dim=256)
 
 
 
 
 
199
 
200
- # simple split: 80/10/10
201
- n = len(labels)
202
- idxs = np.arange(n)
203
- np.random.shuffle(idxs)
204
- train_i = idxs[: int(0.8 * n)]
205
- val_i = idxs[int(0.8 * n): int(0.9 * n)]
206
- test_i = idxs[int(0.9 * n):]
207
 
208
- def subset(idxs):
209
- return [binder_paths[i] for i in idxs], [glm_paths[i] for i in idxs], [labels[i] for i in idxs]
 
 
 
210
 
211
- train_ds = PairDataset(*subset(train_i), tf_compressed_cache=tf_compressed_cache)
212
- val_ds = PairDataset(*subset(val_i), tf_compressed_cache=tf_compressed_cache)
213
- test_ds = PairDataset(*subset(test_i), tf_compressed_cache=tf_compressed_cache)
214
 
215
- print(f"[i] Train/Val/Test sizes: {len(train_ds)}/{len(val_ds)}/{len(test_ds)}", flush=True)
216
- if len(train_ds) == 0 or len(val_ds) == 0:
217
- print("[ERROR] Train or validation split is empty; cannot proceed.", file=sys.stderr)
218
- sys.exit(1)
219
 
220
- train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)
221
- val_dl = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
222
- test_dl = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
223
 
224
- model = BindPredictor(input_dim=256, hidden_dim=256, heads=8, num_layers=3, use_local_cnn_on_glm=True)
225
- model = model.to(device)
226
- optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-3)
227
- loss_fn = nn.BCELoss()
228
 
229
- best_val = -float("inf")
230
- os_out = Path(args.out_dir)
231
- os_out.mkdir(exist_ok=True, parents=True)
 
232
 
233
- for epoch in range(1, args.epochs + 1):
234
- print(f"[Epoch {epoch}] starting...", flush=True)
235
  model.train()
236
- running_loss = 0.0
237
- for b, g, y in train_dl:
238
- b = b.to(device)
239
- g = g.to(device)
240
- y = y.to(device)
241
- pred = model(b, g)
242
- loss = loss_fn(pred, y)
243
- optimizer.zero_grad()
244
- loss.backward()
245
- optimizer.step()
246
- running_loss += loss.item() * b.size(0)
247
- train_loss = running_loss / len(train_ds)
248
- val_auc, val_ap = evaluate(model, val_dl, device)
249
- print(f"[Epoch {epoch}] train_loss={train_loss:.4f} val_auc={val_auc:.4f} val_ap={val_ap:.4f}", flush=True)
250
-
251
- if val_auc > best_val:
252
- best_val = val_auc
253
- torch.save(model.state_dict(), os_out / "best_model.pt")
254
- print(f"[Epoch {epoch}] Saved new best model with val_auc={val_auc:.4f}", flush=True)
255
-
256
- torch.save(model.state_dict(), os_out / "last_model.pt")
257
- test_auc, test_ap = evaluate(model, test_dl, device)
258
- print(f"FINAL TEST: AUC={test_auc:.4f} AP={test_ap:.4f}", flush=True)
259
- print(f"[i] Models written to {os_out}/best_model.pt and last_model.pt", flush=True)
260
-
261
- if __name__ == "__main__":
262
  main()
 
1
+ import argparse, random, sys
2
+ from pathlib import Path
3
+
4
  import numpy as np
5
+ import pandas as pd
6
  import torch
7
  from torch import nn
8
+ from torch.utils.data import Dataset, DataLoader, Sampler
9
+ # from sklearn.random_projection import GaussianRandomProjection # OLD (kept): projection was removed earlier
10
+ import matplotlib.pyplot as plt
 
 
 
 
 
11
 
12
+ import torch.amp as amp
13
+ from torch.nn import functional as F
14
+ from model import BindPredictor
15
 
16
+ # ─────────────── utilities ────────────────────────────────────────────────
17
+ def parse_pair_list(path):
18
+ binders, glms = [], []
19
+ with open(path) as f:
20
+ for ln, line in enumerate(f,1):
21
  parts = line.strip().split()
22
+ if len(parts) < 2: continue
23
+ b,g = parts[0], parts[1]
24
+ binders.append(b); glms.append(g)
25
+ return binders, glms
26
+
27
+ class ListBatchSampler(Sampler):
28
+ def __init__(self, batches): self.batches = batches
29
+ def __iter__(self): return iter(self.batches)
30
+ def __len__(self): return len(self.batches)
31
+
32
+ def make_buckets(idxs, glm_paths, batch_size, n_buckets=10, seed=42):
33
+ rng = random.Random(seed)
34
+ lengths = [(i, np.load(glm_paths[i]).shape[0]) for i in idxs]
35
+ lengths.sort(key=lambda x: x[1])
36
+ size = max(1, int(np.ceil(len(lengths)/n_buckets)))
37
+ buckets = [lengths[i:i+size] for i in range(0,len(lengths),size)]
38
+ batches = []
39
+ for bucket in buckets:
40
+ ids = [i for i,_ in bucket]
41
+ rng.shuffle(ids)
42
+ for i in range(0,len(ids),batch_size):
43
+ batches.append(ids[i:i+batch_size])
44
+ rng.shuffle(batches)
45
+ return batches
46
+
47
+ def dna_key_from_path(path: str) -> str:
48
+ """.../dna_peak42.npy -> 'peak42'"""
49
+ stem = Path(path).stem
50
+ if "_" in stem:
51
+ _, rest = stem.split("_", 1)
52
+ else:
53
+ rest = stem
54
+ return rest
55
+
56
+ def build_tf_cache(tf_paths, target_dim=256):
57
  """
58
+ Load raw TF embeddings without projecting; compression is learnable in the model.
59
  """
60
+ unique = sorted(set(tf_paths))
61
+ print(f"[i] (Learnable) Preparing {len(unique)} TF files; target {target_dim}d inside the model", flush=True)
62
+
63
+ pools, raw = [], []
64
+ for p in unique:
65
+ arr = np.load(p) # (L, D) or (D,)
66
+ raw.append(arr)
67
+ pools.append(arr.mean(axis=0) if arr.ndim==2 else arr)
68
+ M = np.stack(pools,0)
69
+ orig_dim = M.shape[1]
70
 + print(f"[i] Pooled shape → {M.shape} (orig_dim={orig_dim})", flush=True)
71
+
72
+ cache = {}
73
+ for i,p in enumerate(unique):
74
+ arr = raw[i]
75
+ # OLD: projection here (removed)
76
+ cache[p] = arr
77
+ print("[i] TF cache ready (raw); compression will be learned.", flush=True)
78
+ return cache
79
+
80
+ # ─────────────── Dataset & Collation ─────────────────────────────────────
81
+ class PairDataset(Dataset):
82
+ def __init__(self, tf_paths, dna_paths, final_df, tf_cache):
83
+ self.tf_paths, self.dna_paths = tf_paths, dna_paths
84
+ self.tf_cache = tf_cache
85
+ self.targets = {}
86
+ for _, row in final_df.iterrows():
87
+ dna_id = row["dna_id"]
88
+ vec = np.array(list(map(float, row["score_sig_r2"].split(","))), dtype=np.float32)
89
+ self.targets[dna_id] = vec
90
+
91
+ def __len__(self): return len(self.tf_paths)
92
+
93
+ def __getitem__(self, i):
94
+ b = self.tf_cache[self.tf_paths[i]] # (L_b, D_b) or (D_b,)
95
+ if b.ndim==1: b = b[None,:]
96
+ g = np.load(self.dna_paths[i]) # (L_g, 256) or (256,)
97
+ if g.ndim==1: g = g[None,:]
98
+
99
+ stem = Path(self.dna_paths[i]).stem
100
+ dna_id = stem.replace("dna_","")
101
+ t = self.targets.get(dna_id, np.zeros(g.shape[0],dtype=np.float32))
102
+
103
+ return torch.from_numpy(b).float(), \
104
+ torch.from_numpy(g).float(), \
105
+ torch.from_numpy(t).float()
106
+
107
+ def collate_fn(batch):
108
+ Bs = [b.shape[0] for b,_,_ in batch]
109
+ Gs = [g.shape[0] for _,g,_ in batch]
110
+ maxB, maxG = max(Bs), max(Gs)
111
+
112
+ def pad_seq(x, L):
113
+ if x.shape[0] < L:
114
+ pad = torch.zeros((L-x.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
115
+ return torch.cat([x, pad], dim=0)
116
+ return x
117
+
118
+ def pad_t(y, L):
119
+ if y.shape[0] < L:
120
+ pad = torch.zeros((L-y.shape[0],), dtype=y.dtype, device=y.device)
121
+ return torch.cat([y, pad], dim=0)
122
+ return y
123
+
124
+ b_stack = torch.stack([pad_seq(b, maxB) for b,_,_ in batch])
125
+ g_stack = torch.stack([pad_seq(g, maxG) for _,g,_ in batch])
126
+ t_stack = torch.stack([pad_t(t, maxG) for *_,t in batch])
127
+ return b_stack, g_stack, t_stack
128
+
129
+ # ─────────────── losses, metrics ─────────────────────────────────────────
130
+ def combined_loss_components(logits, targets, peak_thresh=0.5, eps=1e-8):
131
+ probs = torch.sigmoid(logits)
132
+ labels = (targets >= peak_thresh).float()
133
+ non_peak_mask = (labels == 0).float()
134
+ peak_mask = (labels == 1).float()
135
+
136
+ bce_all = F.binary_cross_entropy_with_logits(logits, labels, reduction='none')
137
+ bce_non = (bce_all * non_peak_mask)
138
+ bce_non = bce_non.sum() / (non_peak_mask.sum() + eps)
139
+
140
+ mse_peaks = F.mse_loss(probs * peak_mask, targets * peak_mask, reduction='sum') \
141
+ / (peak_mask.sum() + eps)
142
+
143
+ t_dist = (targets + eps)
144
+ p_dist = (probs + eps)
145
+ t_dist = t_dist / t_dist.sum(dim=1, keepdim=True)
146
+ p_dist = p_dist / p_dist.sum(dim=1, keepdim=True)
147
+ kl = (t_dist * (t_dist.clamp(min=eps).log() - p_dist.clamp(min=eps).log())).sum(dim=1).mean()
148
+
149
+ return bce_non, kl, mse_peaks, probs
150
+
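 + # Loss sketch: positions with target >= peak_thresh count as peaks. bce_non penalizes predicted
 + # signal on non-peak positions, mse_peaks fits the continuous scores at peak positions, and kl
 + # matches the normalized per-sequence score distribution; the training loop below combines them
 + # as alpha*bce_non + beta*kl + gamma*mse_peaks.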
151
+ def accuracy_percentage(logits, targets, peak_thresh=0.5):
152
+ probs = torch.sigmoid(logits)
153
+ preds_bin = (probs >= 0.5).float()
154
+ labels = (targets >= peak_thresh).float()
155
+ correct = (preds_bin == labels).float().sum()
156
+ total = torch.numel(labels)
157
+ return (correct / max(1, total)).item() * 100.0
158
+
159
+ def evaluate(model, dl, device, alpha, beta, gamma, peak_thresh, eps=1e-8):
160
  model.eval()
161
+ tot_loss, tot_acc = 0.0, 0.0
162
+ n_batches = 0
163
  with torch.no_grad():
164
+ for b,g,t in dl:
165
+ b,g,t = b.to(device), g.to(device), t.to(device)
166
+ logits = model(b,g)
167
+ bce_non, kl, mse_peaks, _ = combined_loss_components(logits, t, peak_thresh=peak_thresh, eps=eps)
168
+ loss = alpha*bce_non + beta*kl + gamma*mse_peaks
169
+ acc = accuracy_percentage(logits, t, peak_thresh=peak_thresh)
170
+ tot_loss += loss.item(); tot_acc += acc; n_batches += 1
171
+ if n_batches == 0: return float("nan"), float("nan")
172
+ return tot_loss / n_batches, tot_acc / n_batches
173
+
174
+ # ─────────────── cluster-aware splitting ──────────────────────────────────
175
+ def assign_clusters_to_splits(cluster_to_indices, val_frac=0.10, test_frac=0.10, seed=42):
176
+ """
177
+ cluster_to_indices: dict[cluster_id] -> list of example indices (from pair_list) in that cluster
178
+ We greedily pack whole clusters into val/test until hitting targets (#examples), rest to train.
179
+ """
180
+ rng = random.Random(seed)
181
+ clusters = list(cluster_to_indices.items())
182
+ rng.shuffle(clusters)
183
+
184
+ total = sum(len(ixs) for _, ixs in clusters)
185
+ target_val = int(round(total * val_frac))
186
+ target_test = int(round(total * test_frac))
187
+ cur_val = cur_test = 0
188
+
189
+ tr_ix, va_ix, te_ix = [], [], []
190
+ for cid, ixs in clusters:
191
+ c = len(ixs)
192
+ if cur_val + c <= target_val:
193
+ va_ix.extend(ixs); cur_val += c
194
+ elif cur_test + c <= target_test:
195
+ te_ix.extend(ixs); cur_test += c
196
+ else:
197
+ tr_ix.extend(ixs)
198
+ return tr_ix, va_ix, te_ix
199
+
200
+ # ─────────────── train & main ────────────────────────────────────────────
201
  def main():
202
+ p = argparse.ArgumentParser()
203
+ p.add_argument("--pair_list", required=True)
204
+ p.add_argument("--final_csv", required=True)
205
+ p.add_argument("--out_dir", required=True)
206
+ p.add_argument("--epochs", type=int, default=10)
207
+ p.add_argument("--batch_size", type=int, default=16)
208
+ p.add_argument("--accum_steps", type=int, default=4)
209
+ p.add_argument("--lr", type=float, default=1e-4)
210
+ p.add_argument("--device", default="cuda")
211
+ p.add_argument("--seed", type=int, default=42)
212
+ p.add_argument("--alpha", type=float, default=0.5)
213
+ p.add_argument("--beta", type=float, default=0.6)
214
+ p.add_argument("--gamma", type=float, default=0.6)
215
+ p.add_argument("--peak_thresh", type=float, default=0.5)
216
+ # NEW: fractions for cluster-aware split (used only if cluster_id present)
217
+ p.add_argument("--val_frac", type=float, default=0.10)
218
+ p.add_argument("--test_frac", type=float, default=0.10)
219
+ args = p.parse_args()
220
+
221
  random.seed(args.seed)
222
  np.random.seed(args.seed)
223
  torch.manual_seed(args.seed)
 
 
 
 
224
  device = torch.device(args.device if torch.cuda.is_available() else "cpu")
 
225
 
226
+ # 1) load pair list & final.csv (now may include cluster_id)
227
+ tf_paths, dna_paths = parse_pair_list(args.pair_list)
228
+ final_df = pd.read_csv(args.final_csv, dtype=str)
229
+ print(f"[i] Loaded {len(tf_paths)} pairs", flush=True)
230
+
231
+ tf_cache = build_tf_cache(tf_paths, target_dim=256)
232
 
233
+ # detect binder/DNA dims
234
+ sample_tf = tf_cache[tf_paths[0]]
235
+ binder_input_dim = sample_tf.shape[1] if sample_tf.ndim == 2 else sample_tf.shape[0]
236
+ glm_input_dim = 256
237
 
238
+ # 2) cluster-aware split if possible
239
+ use_cluster_split = ("cluster_id" in final_df.columns)
240
+ if use_cluster_split:
241
+ print("[i] Cluster column detected in final_csv; performing cluster-aware split.", flush=True)
242
+ # build dna_id -> cluster_id map
243
+ cid_map = (final_df[["dna_id","cluster_id"]].dropna().drop_duplicates()
244
+ .set_index("dna_id")["cluster_id"].to_dict())
245
 
246
+ # map each example (by index) to its dna_id and cluster
247
+ example_dna_ids = [dna_key_from_path(p) for p in dna_paths]
248
+ example_clusters = []
249
+ missing = 0
250
+ for did in example_dna_ids:
251
+ if did in cid_map:
252
+ example_clusters.append(cid_map[did])
253
+ else:
254
+ # fallback: treat singleton cluster
255
+ example_clusters.append(f"singleton::{did}")
256
+ missing += 1
257
+ if missing:
258
+ print(f"[WARN] {missing} dna_ids from pair_list not found in cluster map; treating as singleton clusters.", flush=True)
259
+
260
+ # build cluster -> indices
261
+ cluster_to_indices = {}
262
+ for i, cid in enumerate(example_clusters):
263
+ cluster_to_indices.setdefault(cid, []).append(i)
264
 
265
+ tr_idx, va_idx, te_idx = assign_clusters_to_splits(
266
+ cluster_to_indices,
267
+ val_frac=args.val_frac, test_frac=args.test_frac, seed=args.seed
268
+ )
269
+ print(f"[i] Cluster split sizes (examples): train={len(tr_idx)} val={len(va_idx)} test={len(te_idx)}", flush=True)
270
 
271
+ # helper to subset paths
272
+ def subset_by_indices(ixs):
273
+ return [tf_paths[i] for i in ixs], [dna_paths[i] for i in ixs]
274
 
275
+ tr_t, tr_d = subset_by_indices(tr_idx)
276
+ va_t, va_d = subset_by_indices(va_idx)
277
+ te_t, te_d = subset_by_indices(te_idx)
278
+
279
+ else:
280
+ print("[i] No cluster_id in final_csv; using random 80/10/10 split (OLD behavior).", flush=True)
281
+ # OLD random split (kept, now under else)
282
+ N = len(tf_paths)
283
+ idxs = list(range(N)); random.shuffle(idxs)
284
+ n_tr = int(0.8*N); n_va = int(0.1*N)
285
+ tr, va, te = idxs[:n_tr], idxs[n_tr:n_tr+n_va], idxs[n_tr+n_va:]
286
 
287
+ def subset(idxs_):
288
+ return [tf_paths[i] for i in idxs_], [dna_paths[i] for i in idxs_]
 
289
 
290
+ tr_t, tr_d = subset(tr)
291
+ va_t, va_d = subset(va)
292
+ te_t, te_d = subset(te)
 
293
 
294
+ # 3) bucketed samplers (unchanged, but now use the cluster-aware subsets when available)
295
+ tr_bs = make_buckets(list(range(len(tr_t))), tr_d, args.batch_size, n_buckets=10, seed=args.seed)
296
+ va_bs = make_buckets(list(range(len(va_t))), va_d, args.batch_size, n_buckets=5, seed=args.seed+1)
297
+ te_bs = make_buckets(list(range(len(te_t))), te_d, args.batch_size, n_buckets=5, seed=args.seed+2)
298
 
299
+ tr_dl = DataLoader(PairDataset(tr_t, tr_d, final_df, tf_cache),
300
+ batch_sampler=ListBatchSampler(tr_bs),
301
+ collate_fn=collate_fn)
302
+ va_dl = DataLoader(PairDataset(va_t, va_d, final_df, tf_cache),
303
+ batch_sampler=ListBatchSampler(va_bs),
304
+ collate_fn=collate_fn)
305
+ te_dl = DataLoader(PairDataset(te_t, te_d, final_df, tf_cache),
306
+ batch_sampler=ListBatchSampler(te_bs),
307
+ collate_fn=collate_fn)
308
+
309
+ # 4) model, optimizer, scaler
310
+ model = BindPredictor(binder_input_dim=binder_input_dim,
311
+ glm_input_dim=glm_input_dim,
312
+ compressed_dim=256,
313
+ hidden_dim=256,
314
+ heads=8, num_layers=4,
315
+ use_local_cnn_on_glm=True).to(device)
316
+ optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
317
+ scaler = amp.GradScaler('cuda')
318
+
319
+ history, best_val = {"train": [], "val": []}, float("inf")
320
+ od = Path(args.out_dir); od.mkdir(exist_ok=True, parents=True)
321
+
322
+ for ep in range(1, args.epochs+1):
323
 + print(f"┌─[Epoch {ep}]────────────────────────", flush=True)
324
  model.train()
325
+ optimizer.zero_grad()
326
+ acc_loss_sum, acc_acc_sum, n_train_batches = 0.0, 0.0, 0
327
+
328
+ for i, (b, g, t) in enumerate(tr_dl):
329
+ b, g, t = b.to(device), g.to(device), t.to(device)
330
+ with amp.autocast('cuda'):
331
+ logits = model(b, g)
332
+ bce_non, kl, mse_peaks, probs = combined_loss_components(
333
+ logits, t, peak_thresh=args.peak_thresh
334
+ )
335
+ loss = args.alpha*bce_non + args.beta*kl + args.gamma*mse_peaks
336
+ loss = loss / args.accum_steps
337
+
338
+ scaler.scale(loss).backward()
339
+
340
+ if (i + 1) % args.accum_steps == 0:
341
+ scaler.step(optimizer)
342
+ scaler.update()
343
+ optimizer.zero_grad()
344
+
345
+ with torch.no_grad():
346
+ acc_loss_sum += (loss.item() * args.accum_steps)
347
+ acc_acc_sum += accuracy_percentage(logits, t, peak_thresh=args.peak_thresh)
348
+ n_train_batches += 1
349
+
350
+ del b, g, t, logits, probs, loss, bce_non, kl, mse_peaks
351
+ torch.cuda.empty_cache()
352
+
353
+ # finalize if leftovers
354
+ if n_train_batches % args.accum_steps != 0:
355
+ scaler.step(optimizer); scaler.update(); optimizer.zero_grad()
356
+
357
+ train_loss = acc_loss_sum / max(1, n_train_batches)
358
+ train_acc = acc_acc_sum / max(1, n_train_batches)
359
+
360
+ val_loss, val_acc = evaluate(model, va_dl, device,
361
+ alpha=args.alpha, beta=args.beta, gamma=args.gamma,
362
+ peak_thresh=args.peak_thresh)
363
+ print(f"[Epoch {ep}] train_loss={train_loss:.4f} train_acc={train_acc:.2f}% "
364
+ f"val_loss={val_loss:.4f} val_acc={val_acc:.2f}%", flush=True)
365
+
366
+ history["train"].append(train_loss)
367
+ history["val"].append(val_loss)
368
+ if val_loss < best_val:
369
+ best_val = val_loss
370
+ torch.save(model.state_dict(), od/"best_model.pt")
371
+ print(f" Saved new best_model.pt (val_loss={val_loss:.4f}, val_acc={val_acc:.2f}%)", flush=True)
372
+
373
+ torch.save(model.state_dict(), od/"last_model.pt")
374
+
375
+ fig, ax = plt.subplots()
376
+ ax.plot(history["train"], label="train")
377
+ ax.plot(history["val"], label="val")
378
+ ax.set_xlabel("epoch"); ax.set_ylabel("combined loss"); ax.legend()
379
+ fig.savefig(od/"loss_curve.png")
380
 + print(f"✅ Done → outputs in {od}", flush=True)
381
+
382
+ if __name__=="__main__":
383
  main()