ananya updates
dpacman/classifier/model/clustering_data.py
ADDED
@@ -0,0 +1,383 @@
#!/usr/bin/env python3
import argparse
import numpy as np
import pandas as pd
from pathlib import Path
import random
import sys
import subprocess
from collections import defaultdict

# ─────────────────────────────────────────────────────────────────────────
# Original helpers (kept; some lightly edited/commented where needed)
# ─────────────────────────────────────────────────────────────────────────

def read_ids_file(p):
    p = Path(p)
    if not p.exists():
        raise FileNotFoundError(f"IDs file not found: {p}")
    return [line.strip() for line in p.open() if line.strip()]

def split_embeddings(emb_path, ids_path, out_dir, prefix):
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    if not Path(emb_path).exists():
        raise FileNotFoundError(f"Embedding file not found: {emb_path}")
    if not Path(ids_path).exists():
        raise FileNotFoundError(f"IDs file not found: {ids_path}")

    if emb_path.endswith(".npz"):
        data = np.load(emb_path, allow_pickle=True)
        if "embeddings" in data:
            emb = data["embeddings"]
        else:
            raise ValueError(f"{emb_path} missing 'embeddings' key")
    else:
        emb = np.load(emb_path)

    ids = read_ids_file(ids_path)
    if len(ids) != emb.shape[0]:
        print(f"[WARN] length mismatch: {len(ids)} ids vs {emb.shape[0]} embeddings in {emb_path}", file=sys.stderr)

    mapping = {}
    for i, ident in enumerate(ids):
        if i >= emb.shape[0]:
            print(f"[WARN] skipping {ident}: no embedding at index {i}", file=sys.stderr)
            continue
        arr = emb[i]
        out_file = out_dir / f"{prefix}_{ident}.npy"
        np.save(out_file, arr)
        mapping[ident] = str(out_file)
    return mapping

def extract_symbol_from_tf_id(full_id: str) -> str:
    """
    Given a TF embedding ID like 'sp|O15062|ZBTB5_HUMAN' or 'ZBTB5_HUMAN',
    return the gene symbol uppercase (e.g., 'ZBTB5').
    """
    if "|" in full_id:
        try:
            # format sp|Accession|SYMBOL_HUMAN
            genepart = full_id.split("|")[2]
        except IndexError:
            genepart = full_id
    else:
        genepart = full_id
    symbol = genepart.split("_")[0]
    return symbol.upper()

def build_tf_symbol_map(tf_map):
    """
    Build mapping gene_symbol -> list of embedding paths.
    """
    symbol_map = {}
    for full_id, path in tf_map.items():
        symbol = extract_symbol_from_tf_id(full_id)
        symbol_map.setdefault(symbol, []).append(path)
    return symbol_map

def tf_key_from_path(path: str) -> str:
    """
    Given a path like .../tf_sp|O15062|ZBTB5_HUMAN.npy, extract the normalized symbol 'ZBTB5'.
    """
    stem = Path(path).stem  # e.g., tf_sp|O15062|ZBTB5_HUMAN
    # remove leading prefix if present (tf_)
    if "_" in stem:
        _, rest = stem.split("_", 1)
    else:
        rest = stem
    return extract_symbol_from_tf_id(rest)

def dna_key_from_path(path: str) -> str:
    """
    Given .../dna_peak42.npy -> 'peak42'
    """
    stem = Path(path).stem
    if "_" in stem:
        _, rest = stem.split("_", 1)
    else:
        rest = stem
    return rest

# ─────────────────────────────────────────────────────────────────────────
# New helpers for MMseqs clustering & cluster-level splitting
# ─────────────────────────────────────────────────────────────────────────

def write_dna_fasta(df: pd.DataFrame, out_fasta: Path) -> None:
    """
    Write unique DNA sequences to FASTA using dna_id as header.
    Requires df with columns: dna_id, dna_sequence
    """
    uniq = df[["dna_id", "dna_sequence"]].drop_duplicates()
    with open(out_fasta, "w") as f:
        for _, row in uniq.iterrows():
            did = row["dna_id"]
            seq = str(row["dna_sequence"]).upper().replace(" ", "").replace("\n", "")
            f.write(f">{did}\n{seq}\n")

def run_mmseqs_easy_cluster(
    mmseqs_bin: str,
    fasta: Path,
    out_prefix: Path,
    tmp_dir: Path,
    min_seq_id: float,
    coverage: float,
    cov_mode: int,
) -> Path:
    """
    Runs mmseqs easy-cluster on nucleotide sequences.
    Returns the path to a clusters TSV file (creating it if the default one isn't present).
    """
    tmp_dir.mkdir(parents=True, exist_ok=True)
    out_prefix.parent.mkdir(parents=True, exist_ok=True)

    cmd = [
        mmseqs_bin, "easy-cluster",
        str(fasta), str(out_prefix), str(tmp_dir),
        "--min-seq-id", str(min_seq_id),
        "-c", str(coverage),
        "--cov-mode", str(cov_mode),
        # You can add performance flags here if needed, e.g.:
        # "--threads", "8"
    ]
    print("[i] Running:", " ".join(cmd), flush=True)
    subprocess.run(cmd, check=True)

    # MMseqs easy-cluster typically writes <out_prefix>_cluster.tsv
    default_tsv = Path(str(out_prefix) + "_cluster.tsv")
    if default_tsv.exists():
        print(f"[i] Found cluster TSV: {default_tsv}")
        return default_tsv

    # Fallback: try createtsv if the default is missing.
    # This requires the internal DBs; easy-cluster creates DBs alongside out_prefix.
    # We'll try to locate them and emit a TSV.
    in_db = Path(str(out_prefix) + "_query")
    cl_db = Path(str(out_prefix) + "_cluster")
    out_tsv = Path(str(out_prefix) + "_fallback_cluster.tsv")
    if in_db.exists() and cl_db.exists():
        cmd2 = [mmseqs_bin, "createtsv", str(in_db), str(in_db), str(cl_db), str(out_tsv)]
        print("[i] Creating TSV via createtsv:", " ".join(cmd2), flush=True)
        subprocess.run(cmd2, check=True)
        if out_tsv.exists():
            return out_tsv

    raise FileNotFoundError(f"Could not locate clusters TSV from mmseqs. "
                            f"Expected {default_tsv} or createtsv fallback.")

def parse_mmseqs_clusters(tsv_path: Path) -> dict:
    """
    Parse MMseqs cluster TSV (rep \t member). Returns dna_id -> cluster_rep_id
    """
    mapping = {}
    with open(tsv_path) as f:
        for line in f:
            parts = line.rstrip("\n").split("\t")
            if len(parts) < 2:
                continue
            rep, member = parts[0], parts[1]
            mapping[member] = rep
            # Some TSVs include rep->rep; if not, ensure rep is mapped to itself:
            if rep not in mapping:
                mapping[rep] = rep
    return mapping

def assign_clusters_to_splits(cluster_rep_to_members: dict,
                              val_frac: float,
                              test_frac: float,
                              seed: int = 42):
    """
    cluster_rep_to_members: dict[rep] = [members...]
    Returns: dict with keys 'train','val','test' mapping to sets of dna_id.
    Ensures all members of a cluster go to the same split.
    """
    rng = random.Random(seed)
    reps = list(cluster_rep_to_members.keys())
    rng.shuffle(reps)

    # Greedy-ish fill by total member counts to match the desired fractions.
    total = sum(len(cluster_rep_to_members[r]) for r in reps)
    target_val = int(round(total * val_frac))
    target_test = int(round(total * test_frac))
    cur_val = cur_test = 0

    val_ids, test_ids, train_ids = set(), set(), set()
    for rep in reps:
        members = cluster_rep_to_members[rep]
        c = len(members)
        # Fill val first, then test, then train
        if cur_val + c <= target_val:
            val_ids.update(members); cur_val += c
        elif cur_test + c <= target_test:
            test_ids.update(members); cur_test += c
        else:
            train_ids.update(members)

    return {"train": train_ids, "val": val_ids, "test": test_ids}

# ─────────────────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────────────────

def main():
    parser = argparse.ArgumentParser(
        description="Build TF-DNA pair lists with MMseqs clustering on DNA to prevent split leakage."
    )
    parser.add_argument("--final_csv", required=True, help="final.csv with TF_id and dna_sequence")
    parser.add_argument("--dna_embed_npz", required=True, help="DNA embedding file (.npy or .npz)")
    parser.add_argument("--dna_ids", required=True, help="IDs file for DNA embeddings (peak*.ids)")
    parser.add_argument("--tf_embed_npy", required=True, help="TF embedding file (.npy or .npz)")
    parser.add_argument("--tf_ids", required=True, help="IDs file for TF embeddings (sp|... ids)")
    parser.add_argument("--out_dir", required=True, help="Output directory")
    parser.add_argument("--seed", type=int, default=42)

    # NEW: MMseqs options & split fractions
    parser.add_argument("--mmseqs_bin", default="mmseqs", help="Path to mmseqs binary")
    parser.add_argument("--min_seq_id", type=float, default=0.9, help="MMseqs --min-seq-id")
    parser.add_argument("--cov", type=float, default=0.8, help="MMseqs -c coverage fraction")
    parser.add_argument("--cov_mode", type=int, default=1, help="MMseqs --cov-mode (1 = coverage of target)")
    parser.add_argument("--val_frac", type=float, default=0.10)
    parser.add_argument("--test_frac", type=float, default=0.10)
    parser.add_argument("--tmp_dir", default=None, help="MMseqs tmp dir (defaults to out_dir/tmp)")
    args = parser.parse_args()

    random.seed(args.seed)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Load final.csv
    df = pd.read_csv(args.final_csv, dtype=str)
    if "TF_id" not in df.columns or "dna_sequence" not in df.columns:
        raise RuntimeError("final.csv must have columns TF_id and dna_sequence")

    # Assign dna_id (unique per dna_sequence)
    unique_seqs = df["dna_sequence"].drop_duplicates().tolist()
    seq_to_id = {seq: f"peak{i}" for i, seq in enumerate(unique_seqs)}
    df["dna_id"] = df["dna_sequence"].map(seq_to_id)
    enriched_csv = out_dir / "final_with_dna_id.csv"
    df.to_csv(enriched_csv, index=False)
    print(f"[i] Wrote augmented final.csv with dna_id to {enriched_csv}")

    # Split embeddings into per-item files (unchanged)
    print(f"[i] Splitting DNA embeddings from {args.dna_embed_npz} with ids {args.dna_ids}")
    dna_map = split_embeddings(args.dna_embed_npz, args.dna_ids, out_dir / "dna_single", "dna")
    print(f"[i] DNA embeddings available: {len(dna_map)} (sample: {list(dna_map.keys())[:10]})")
    print(f"[i] Splitting TF embeddings from {args.tf_embed_npy} with ids {args.tf_ids}")
    tf_map = split_embeddings(args.tf_embed_npy, args.tf_ids, out_dir / "tf_single", "tf")
    print(f"[i] TF embeddings available: {len(tf_map)} (sample: {list(tf_map.keys())[:10]})")

    # Build gene-symbol normalized map
    tf_symbol_map = build_tf_symbol_map(tf_map)
    print(f"[i] TF symbol map keys (sample): {list(tf_symbol_map.keys())[:30]}")

    # Diagnostic overlaps
    norm_tf_in_final = set(t.split("_seq")[0].upper() for t in df["TF_id"].unique())
    available_tf_symbols = set(tf_symbol_map.keys())
    intersect_tf = norm_tf_in_final & available_tf_symbols
    print(f"[i] Unique normalized TF symbols in final.csv: {len(norm_tf_in_final)}")
    print(f"[i] Available TF embedding symbols: {len(available_tf_symbols)}")
    print(f"[i] Intersection count: {len(intersect_tf)}")
    if len(intersect_tf) == 0:
        print("[ERROR] No overlap between normalized TF_id and TF embedding symbols.", file=sys.stderr)
        print("Sample normalized TFs from final.csv:", sorted(list(norm_tf_in_final))[:30], file=sys.stderr)
        print("Sample available TF symbols:", sorted(list(available_tf_symbols))[:30], file=sys.stderr)
        sys.exit(1)

    dna_ids_final = set(df["dna_id"].unique())
    available_dna_ids = set(dna_map.keys())
    intersect_dna = dna_ids_final & available_dna_ids
    print(f"[i] Unique dna_id in final.csv: {len(dna_ids_final)}. Available DNA ids: {len(available_dna_ids)}. Intersection: {len(intersect_dna)}")
    if len(intersect_dna) == 0:
        print("[ERROR] No overlap on DNA ids.", file=sys.stderr)
        sys.exit(1)

    # ── NEW: MMseqs clustering on DNA sequences ──────────────────────────
    fasta_path = out_dir / "dna_unique.fasta"
    write_dna_fasta(df, fasta_path)
    print(f"[i] Wrote FASTA with {df['dna_id'].nunique()} unique sequences → {fasta_path}")

    tmp_dir = Path(args.tmp_dir) if args.tmp_dir else (out_dir / "mmseqs_tmp")
    cluster_prefix = out_dir / "mmseqs_dna_clusters"
    clusters_tsv = run_mmseqs_easy_cluster(
        mmseqs_bin=args.mmseqs_bin,
        fasta=fasta_path,
        out_prefix=cluster_prefix,
        tmp_dir=tmp_dir,
        min_seq_id=args.min_seq_id,
        coverage=args.cov,
        cov_mode=args.cov_mode,
    )

    # Parse clusters
    member_to_rep = parse_mmseqs_clusters(clusters_tsv)  # dna_id -> rep_id
    # Build rep -> members list
    rep_to_members = defaultdict(list)
    for member, rep in member_to_rep.items():
        rep_to_members[rep].append(member)

    print(f"[i] Parsed {len(rep_to_members)} clusters from {clusters_tsv}")
    clusters_table = []
    for rep, members in rep_to_members.items():
        for m in members:
            clusters_table.append((m, rep))
    clusters_df = pd.DataFrame(clusters_table, columns=["dna_id", "cluster_id"])
    clusters_df.to_csv(out_dir / "clusters.tsv", sep="\t", index=False)
    print(f"[i] Wrote clusters mapping → {out_dir / 'clusters.tsv'}")

    # Attach cluster_id back to final df
    df = df.merge(clusters_df, on="dna_id", how="left")
    df.to_csv(out_dir / "final_with_dna_id_and_cluster.csv", index=False)
    print(f"[i] Wrote {out_dir / 'final_with_dna_id_and_cluster.csv'}")

    # Assign entire clusters to splits
    splits = assign_clusters_to_splits(rep_to_members,
                                       val_frac=args.val_frac,
                                       test_frac=args.test_frac,
                                       seed=args.seed)
    for k in ["train", "val", "test"]:
        print(f"[i] {k}: {len(splits[k])} dna_ids")

    # ── Build positive pairs only, per split (NO negatives) ──────────────
    positives_by_split = {"train": [], "val": [], "test": []}
    # Build a quick dna_id -> embedding path map
    dnaid_to_path = {did: path for did, path in dna_map.items()}

    pos_count = 0
    for _, row in df.iterrows():
        tf_raw = row["TF_id"]
        tf_symbol = tf_raw.split("_seq")[0].upper()
        dnaid = row["dna_id"]
        if (tf_symbol not in tf_symbol_map) or (dnaid not in dnaid_to_path):
            continue
        tf_embedding_path = tf_symbol_map[tf_symbol][0]  # first embedding per symbol

        # decide split by dna_id cluster assignment
        if dnaid in splits["train"]:
            positives_by_split["train"].append((tf_embedding_path, dnaid_to_path[dnaid], 1))
        elif dnaid in splits["val"]:
            positives_by_split["val"].append((tf_embedding_path, dnaid_to_path[dnaid], 1))
        elif dnaid in splits["test"]:
            positives_by_split["test"].append((tf_embedding_path, dnaid_to_path[dnaid], 1))
        pos_count += 1

    print(f"[i] Constructed positives across splits (rows in final.csv iterated: {len(df)})")
    for k in ["train", "val", "test"]:
        print(f"[i] positives[{k}] = {len(positives_by_split[k])}")

    # # OLD: negatives (kept commented)
    # negatives = []
    # print(f"[i] Sampled {len(negatives)} negatives (neg_per_positive not used)")

    # Emit split-specific pair lists
    for split in ["train", "val", "test"]:
        out_tsv = out_dir / f"pair_list_{split}.tsv"
        with open(out_tsv, "w") as f:
            for binder_path, glm_path, label in positives_by_split[split]:  # + negatives if you add later
                f.write(f"{binder_path}\t{glm_path}\t{label}\n")
        print(f"[i] Wrote {len(positives_by_split[split])} examples to {out_tsv}")

    print("✅ Done. Cluster-aware splits ready.")

if __name__ == "__main__":
    main()
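For reference, a minimal sketch of how the cluster-aware split helper behaves on toy data. The cluster contents below are made up for illustration, and the import assumes the script is importable as a module from the working directory:

# Toy illustration of assign_clusters_to_splits (hypothetical cluster membership).
from clustering_data import assign_clusters_to_splits

toy_clusters = {
    "repA": ["peak0", "peak1", "peak2"],  # all three members land in the same split
    "repB": ["peak3"],
    "repC": ["peak4", "peak5"],
}
splits = assign_clusters_to_splits(toy_clusters, val_frac=0.2, test_frac=0.2, seed=0)
print({k: sorted(v) for k, v in splits.items()})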
dpacman/classifier/model/compute_embeddings.py
CHANGED
@@ -25,6 +25,9 @@ from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, pipeline
 import esm
 from Bio import SeqIO
 import time
+import pandas as pd
+from tqdm.auto import tqdm
+import logging, math

 # ---- model wrappers ----

@@ -67,7 +70,7 @@ class CaduceusEmbedder:
         # return np.stack(outputs, axis=0)  # (N, L, D)
         outputs = []
-        for seq in seqs:
+        for seq in tqdm(seqs, total=len(seqs), desc="DNA: Caduceus", dynamic_ncols=True):
             toks = self.tokenizer(
                 seq,
                 return_tensors="pt",

@@ -471,34 +474,45 @@ def embed_and_save(seqs, ids, embedder, out_path):
 if __name__=="__main__":

     p = argparse.ArgumentParser()
-    p.add_argument("--peak_fasta", default="binding_peaks_unique.fa", help="FASTA of deduplicated binding peak sequences; if present this is used for DNA embedding instead of genome JSONs")
+    #p.add_argument("--peak_fasta", default="binding_peaks_unique.fa", help="FASTA of deduplicated binding peak sequences; if present this is used for DNA embedding instead of genome JSONs")
     p.add_argument("--genome-json-dir", default=None, help="(fallback) directory of UCSC JSONs for full chromosome embedding if peak FASTA is missing or you explicitly want chromosomes")
     p.add_argument("--skip-dna", action="store_true", help="if set, skip the chromosome embedding step")  # if glm embeddings successful but not plm embeddings
     p.add_argument("--tf-fasta", required=True, help="input TF FASTA file")
     p.add_argument("--chrom-model", default="caduceus")
     p.add_argument("--tf-model", default="esm-dbp")
-    p.add_argument("--out-dir", default="
+    p.add_argument("--out-dir", default="dpacman/model/embeddings")
     p.add_argument("--device", default="cpu")
     args = p.parse_args()

     os.makedirs(args.out_dir, exist_ok=True)
     device = args.device
+    print(device)

     if not args.skip_dna:
-        peak_seqs = []
-        peak_ids = []
-        for rec in SeqIO.parse(peak_fasta, "fasta"):
-            peak_ids.append(rec.id)
-            peak_seqs.append(str(rec.seq))
-        print(f"Embedding {len(peak_seqs)} binding peak sequences from {peak_fasta}", flush=True)
+        if args.genome_json_dir == None:
+            dna_df = pd.read_parquet('/home/a03-akrishna/DPACMAN/dpacman/model/remap2022_crm_fimo_output_q_processed.parquet', engine='pyarrow')
+            #df.to_csv('/home/a03-akrishna/DPACMAN/dpacman/model/remap2022_crm_fimo_output_q_processed.csv', index=False)
+            peak_seqs = dna_df["dna_sequence"]
+            peak_ids = dna_df["ID"]
+            print(f"Embedding {len(peak_seqs)} binding peak sequences from processed remap data", flush=True)
         dna_embedder = get_embedder(args.chrom_model, device, for_dna=True)
         out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy"
         embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks)
+
+        # peak_fasta = Path(args.peak_fasta)
+        # if peak_fasta.exists():
+        #     # Load peak sequences from FASTA
+        #     from Bio import SeqIO
+
+        #     peak_seqs = []
+        #     peak_ids = []
+        #     for rec in SeqIO.parse(peak_fasta, "fasta"):
+        #         peak_ids.append(rec.id)
+        #         peak_seqs.append(str(rec.seq))
+        #     print(f"Embedding {len(peak_seqs)} binding peak sequences from {peak_fasta}", flush=True)
+        #     dna_embedder = get_embedder(args.chrom_model, device, for_dna=True)
+        #     out_peaks = Path(args.out_dir) / f"peaks_{args.chrom_model}.npy"
+        #     embed_and_save(peak_seqs, peak_ids, dna_embedder, out_peaks)
     elif args.genome_json_dir:
         # Legacy: load full chromosomes from JSONs (chr1–22, X, Y, M)
         genome_dir = Path(args.genome_json_dir)
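As a sanity check, a small sketch of the parquet-driven input path added above. It assumes the processed ReMap parquet exposes the `ID` and `dna_sequence` columns shown in the diff; the file name below is a placeholder, not the exact cluster path:

import pandas as pd

# Load the processed ReMap/FIMO table and pull out the two columns the embedding step needs.
dna_df = pd.read_parquet("remap2022_crm_fimo_output_q_processed.parquet", engine="pyarrow")
peak_ids = dna_df["ID"].tolist()             # one identifier per peak
peak_seqs = dna_df["dna_sequence"].tolist()  # nucleotide strings to embed
print(f"{len(peak_seqs)} peak sequences ready for embedding")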
dpacman/classifier/model/model.py
CHANGED
@@ -39,7 +39,7 @@ class CrossModalBlock(nn.Module):
     def forward(self, binder: torch.Tensor, glm: torch.Tensor):
         """
         binder: (batch, Lb, dim)
-        glm:
+        glm: (batch, Lg, dim) -- has passed through its local CNN beforehand
         returns: updated binder representation (batch, Lb, dim)
         """
         # binder self-attn + ffn

@@ -63,16 +63,50 @@ class CrossModalBlock(nn.Module):
         c = self.ln_c2(c + c_ff)
         return c  # (batch, Lb, dim)

+class DimCompressor(nn.Module):
+    """
+    Learnable per-token compressor: maps any in_dim >= out_dim to out_dim (default 256).
+    If in_dim == out_dim, behaves as identity.
+    """
+    def __init__(self, in_dim: int, out_dim: int = 256):
+        super().__init__()
+        if in_dim == out_dim:
+            self.net = nn.Identity()
+        else:
+            hidden = max(out_dim * 2, (in_dim + out_dim) // 2)
+            self.net = nn.Sequential(
+                nn.LayerNorm(in_dim),
+                nn.Linear(in_dim, hidden),
+                nn.GELU(),
+                nn.Linear(hidden, out_dim),
+            )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (B, L, in_dim)
+        return self.net(x)
+
 class BindPredictor(nn.Module):
     def __init__(self,
-                 input_dim: int = 256,
+                 # input_dim: int = 256,         # OLD: single input dim
+                 binder_input_dim: int = 1280,   # NEW: TF (binder) original dim (e.g., 1280)
+                 glm_input_dim: int = 256,       # NEW: DNA/GLM original dim (e.g., 256)
+                 compressed_dim: int = 256,      # NEW: learnable compressed dim
                  hidden_dim: int = 256,
                  heads: int = 8,
                  num_layers: int = 4,
                  use_local_cnn_on_glm: bool = True):
         super().__init__()
-        self.proj_binder = nn.Linear(input_dim, hidden_dim)
-        self.proj_glm = nn.Linear(input_dim, hidden_dim)
+        # OLD:
+        # self.proj_binder = nn.Linear(input_dim, hidden_dim)
+        # self.proj_glm = nn.Linear(input_dim, hidden_dim)
+
+        # NEW: learnable compressor for binder → 256, then project to hidden
+        self.binder_compress = DimCompressor(binder_input_dim, out_dim=compressed_dim)
+        self.proj_binder = nn.Linear(compressed_dim, hidden_dim)
+
+        # GLM side stays 256 → hidden
+        self.proj_glm = nn.Linear(glm_input_dim, hidden_dim)
+
         self.use_local_cnn = use_local_cnn_on_glm
         self.local_cnn = LocalCNN(hidden_dim) if use_local_cnn_on_glm else nn.Identity()

@@ -81,23 +115,28 @@ class BindPredictor(nn.Module):
         ])

         self.ln_out = nn.LayerNorm(hidden_dim)
-        self.head = nn.Sequential(
-            nn.Linear(hidden_dim, 1),
-            nn.Sigmoid()
-        )
+        # self.head = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Sigmoid())  # OLD: returned probabilities
+        self.head = nn.Linear(hidden_dim, 1)  # NEW: return logits (safe for AMP)

     def forward(self, binder_emb, glm_emb):
         """
-        binder_emb
+        binder_emb: (B, Lb, binder_input_dim)
+        glm_emb:    (B, Lg, glm_input_dim)
+        Returns per-nucleotide logits for the GLM sequence: (B, Lg)
         """
+        # Binder: learnable compression → 256 → hidden
+        b = self.binder_compress(binder_emb)   # (B, Lb, 256)
+        b = self.proj_binder(b)                # (B, Lb, hidden_dim)
+
+        # GLM: project → hidden, add local CNN context
+        g = self.proj_glm(glm_emb)             # (B, Lg, hidden_dim)
         if self.use_local_cnn:
             g = self.local_cnn(g)

+        # Cross-modal blocks: update binder states using GLM
         for layer in self.layers:
-            b = layer(b, g)
+            b = layer(b, g)  # (B, Lb, hidden_dim)

-        return self.head(g).squeeze(-1)
+        # Predict per-nucleotide logits on the GLM tokens:
+        # return self.head(g).squeeze(-1)  # OLD: probabilities (with Sigmoid in head)
+        return self.head(g).squeeze(-1)    # NEW: logits (apply sigmoid only in loss/metrics)
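A quick shape check for the updated BindPredictor interface, using random tensors with dimensions matching the defaults above. This is a sketch only: it assumes LocalCNN and CrossModalBlock remain defined in model.py as before and that LocalCNN preserves sequence length:

import torch
from model import BindPredictor

model = BindPredictor(binder_input_dim=1280, glm_input_dim=256,
                      compressed_dim=256, hidden_dim=256,
                      heads=8, num_layers=4, use_local_cnn_on_glm=True)
binder = torch.randn(2, 7, 1280)   # (B, Lb, binder_input_dim), e.g. per-residue TF embedding
glm = torch.randn(2, 300, 256)     # (B, Lg, glm_input_dim), e.g. per-nucleotide DNA embedding
logits = model(binder, glm)        # (B, Lg) per-nucleotide logits
probs = torch.sigmoid(logits)      # sigmoid is applied outside the model now that the head returns logits
print(logits.shape, probs.min().item(), probs.max().item())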
dpacman/classifier/model/train.py
CHANGED
@@ -1,262 +1,383 @@
# ── removed (old script; lines truncated in this view are elided as "...") ──
-import ...
-from torch.utils.data import Dataset, DataLoader
-from collections import Counter
-from sklearn.metrics import roc_auc_score, average_precision_score
-from sklearn.decomposition import TruncatedSVD
-import random
-import sys
-...
-        """
-        tf_compressed_cache: dict mapping binder_path -> compressed (256-d) tensor/array
-        """
-        assert len(binder_paths) == len(glm_paths) == len(labels)
-        self.binder_paths = binder_paths
-        self.glm_paths = glm_paths
-        self.labels = labels
-        self.tf_cache = tf_compressed_cache  # already reduced to 256
-
-    def __len__(self):
-        return len(self.labels)
-
-    def __getitem__(self, idx):
-        # binder = TF embedding (possibly reduced)
-        b = self.tf_cache[self.binder_paths[idx]]  # numpy array shape (L, 256) or (256,)
-        g = np.load(self.glm_paths[idx])            # glm (DNA) embedding
-        if b.ndim == 1:
-            b = b[None, :]
-        if g.ndim == 1:
-            g = g[None, :]
-        b_tensor = torch.from_numpy(b).float()
-        g_tensor = torch.from_numpy(g).float()
-        y = torch.tensor(self.labels[idx]).float()
-        return b_tensor, g_tensor, y
-...
-    max_g = max(glm_lens)
-    def pad_seq(seq, target_len):
-        L, D = seq.shape
-        if L < target_len:
-            pad = torch.zeros((target_len - L, D), dtype=seq.dtype, device=seq.device)
-            return torch.cat([seq, pad], dim=0)
-        return seq
-    b_padded = torch.stack([pad_seq(b, max_b) for b in binders])  # (B, Lb, D)
-    g_padded = torch.stack([pad_seq(g, max_g) for g in glms])     # (B, Lg, D)
-    y = torch.stack(labels)
-    return b_padded, g_padded, y
-
-# ---- utilities -------------------------------------------------------
-def parse_pair_list(pair_list_path):
-    binder_paths, glm_paths, labels = [], [], []
-    with open(pair_list_path) as f:
-        for lineno, line in enumerate(f, start=1):
-            if not line.strip():
-                continue
-            parts = line.strip().split()
-            if len(parts) ...
-...
-    print("DEBUG: starting training script with in-line TF compression", flush=True)
-    print(f"[i] pair_list: {args.pair_list}", flush=True)
-    print(f"[i] output dir: {args.out_dir}", flush=True)
-    binder_paths, glm_paths, labels = parse_pair_list(args.pair_list)
-...
-    test_dl = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate_fn)
-    loss_fn = nn.BCELoss()
-...

# ── new script ──
import argparse, random, sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, Sampler
# from sklearn.random_projection import GaussianRandomProjection  # OLD (kept): projection was removed earlier
import matplotlib.pyplot as plt

import torch.amp as amp
from torch.nn import functional as F
from model import BindPredictor

# ─────────────── utilities ───────────────────────────────────────────────
def parse_pair_list(path):
    binders, glms = [], []
    with open(path) as f:
        for ln, line in enumerate(f, 1):
            parts = line.strip().split()
            if len(parts) < 2: continue
            b, g = parts[0], parts[1]
            binders.append(b); glms.append(g)
    return binders, glms

class ListBatchSampler(Sampler):
    def __init__(self, batches): self.batches = batches
    def __iter__(self): return iter(self.batches)
    def __len__(self): return len(self.batches)

def make_buckets(idxs, glm_paths, batch_size, n_buckets=10, seed=42):
    rng = random.Random(seed)
    lengths = [(i, np.load(glm_paths[i]).shape[0]) for i in idxs]
    lengths.sort(key=lambda x: x[1])
    size = max(1, int(np.ceil(len(lengths)/n_buckets)))
    buckets = [lengths[i:i+size] for i in range(0, len(lengths), size)]
    batches = []
    for bucket in buckets:
        ids = [i for i, _ in bucket]
        rng.shuffle(ids)
        for i in range(0, len(ids), batch_size):
            batches.append(ids[i:i+batch_size])
    rng.shuffle(batches)
    return batches

def dna_key_from_path(path: str) -> str:
    """.../dna_peak42.npy -> 'peak42'"""
    stem = Path(path).stem
    if "_" in stem:
        _, rest = stem.split("_", 1)
    else:
        rest = stem
    return rest

def build_tf_cache(tf_paths, target_dim=256):
    """
    Load raw TF embeddings without projecting; compression is learnable in the model.
    """
    unique = sorted(set(tf_paths))
    print(f"[i] (Learnable) Preparing {len(unique)} TF files; target {target_dim}d inside the model", flush=True)

    pools, raw = [], []
    for p in unique:
        arr = np.load(p)  # (L, D) or (D,)
        raw.append(arr)
        pools.append(arr.mean(axis=0) if arr.ndim == 2 else arr)
    M = np.stack(pools, 0)
    orig_dim = M.shape[1]
    print(f"[i] Pooled shape → {M.shape} (orig_dim={orig_dim})", flush=True)

    cache = {}
    for i, p in enumerate(unique):
        arr = raw[i]
        # OLD: projection here (removed)
        cache[p] = arr
    print("[i] TF cache ready (raw); compression will be learned.", flush=True)
    return cache

# ─────────────── Dataset & Collation ─────────────────────────────────────
class PairDataset(Dataset):
    def __init__(self, tf_paths, dna_paths, final_df, tf_cache):
        self.tf_paths, self.dna_paths = tf_paths, dna_paths
        self.tf_cache = tf_cache
        self.targets = {}
        for _, row in final_df.iterrows():
            dna_id = row["dna_id"]
            vec = np.array(list(map(float, row["score_sig_r2"].split(","))), dtype=np.float32)
            self.targets[dna_id] = vec

    def __len__(self): return len(self.tf_paths)

    def __getitem__(self, i):
        b = self.tf_cache[self.tf_paths[i]]   # (L_b, D_b) or (D_b,)
        if b.ndim == 1: b = b[None, :]
        g = np.load(self.dna_paths[i])        # (L_g, 256) or (256,)
        if g.ndim == 1: g = g[None, :]

        stem = Path(self.dna_paths[i]).stem
        dna_id = stem.replace("dna_", "")
        t = self.targets.get(dna_id, np.zeros(g.shape[0], dtype=np.float32))

        return torch.from_numpy(b).float(), \
               torch.from_numpy(g).float(), \
               torch.from_numpy(t).float()

def collate_fn(batch):
    Bs = [b.shape[0] for b, _, _ in batch]
    Gs = [g.shape[0] for _, g, _ in batch]
    maxB, maxG = max(Bs), max(Gs)

    def pad_seq(x, L):
        if x.shape[0] < L:
            pad = torch.zeros((L-x.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
            return torch.cat([x, pad], dim=0)
        return x

    def pad_t(y, L):
        if y.shape[0] < L:
            pad = torch.zeros((L-y.shape[0],), dtype=y.dtype, device=y.device)
            return torch.cat([y, pad], dim=0)
        return y

    b_stack = torch.stack([pad_seq(b, maxB) for b, _, _ in batch])
    g_stack = torch.stack([pad_seq(g, maxG) for _, g, _ in batch])
    t_stack = torch.stack([pad_t(t, maxG) for *_, t in batch])
    return b_stack, g_stack, t_stack

# ─────────────── losses, metrics ─────────────────────────────────────────
def combined_loss_components(logits, targets, peak_thresh=0.5, eps=1e-8):
    probs = torch.sigmoid(logits)
    labels = (targets >= peak_thresh).float()
    non_peak_mask = (labels == 0).float()
    peak_mask = (labels == 1).float()

    bce_all = F.binary_cross_entropy_with_logits(logits, labels, reduction='none')
    bce_non = (bce_all * non_peak_mask)
    bce_non = bce_non.sum() / (non_peak_mask.sum() + eps)

    mse_peaks = F.mse_loss(probs * peak_mask, targets * peak_mask, reduction='sum') \
                / (peak_mask.sum() + eps)

    t_dist = (targets + eps)
    p_dist = (probs + eps)
    t_dist = t_dist / t_dist.sum(dim=1, keepdim=True)
    p_dist = p_dist / p_dist.sum(dim=1, keepdim=True)
    kl = (t_dist * (t_dist.clamp(min=eps).log() - p_dist.clamp(min=eps).log())).sum(dim=1).mean()

    return bce_non, kl, mse_peaks, probs

def accuracy_percentage(logits, targets, peak_thresh=0.5):
    probs = torch.sigmoid(logits)
    preds_bin = (probs >= 0.5).float()
    labels = (targets >= peak_thresh).float()
    correct = (preds_bin == labels).float().sum()
    total = torch.numel(labels)
    return (correct / max(1, total)).item() * 100.0

def evaluate(model, dl, device, alpha, beta, gamma, peak_thresh, eps=1e-8):
    model.eval()
    tot_loss, tot_acc = 0.0, 0.0
    n_batches = 0
    with torch.no_grad():
        for b, g, t in dl:
            b, g, t = b.to(device), g.to(device), t.to(device)
            logits = model(b, g)
            bce_non, kl, mse_peaks, _ = combined_loss_components(logits, t, peak_thresh=peak_thresh, eps=eps)
            loss = alpha*bce_non + beta*kl + gamma*mse_peaks
            acc = accuracy_percentage(logits, t, peak_thresh=peak_thresh)
            tot_loss += loss.item(); tot_acc += acc; n_batches += 1
    if n_batches == 0: return float("nan"), float("nan")
    return tot_loss / n_batches, tot_acc / n_batches

# ─────────────── cluster-aware splitting ─────────────────────────────────
def assign_clusters_to_splits(cluster_to_indices, val_frac=0.10, test_frac=0.10, seed=42):
    """
    cluster_to_indices: dict[cluster_id] -> list of example indices (from pair_list) in that cluster
    We greedily pack whole clusters into val/test until hitting targets (#examples), rest to train.
    """
    rng = random.Random(seed)
    clusters = list(cluster_to_indices.items())
    rng.shuffle(clusters)

    total = sum(len(ixs) for _, ixs in clusters)
    target_val = int(round(total * val_frac))
    target_test = int(round(total * test_frac))
    cur_val = cur_test = 0

    tr_ix, va_ix, te_ix = [], [], []
    for cid, ixs in clusters:
        c = len(ixs)
        if cur_val + c <= target_val:
            va_ix.extend(ixs); cur_val += c
        elif cur_test + c <= target_test:
            te_ix.extend(ixs); cur_test += c
        else:
            tr_ix.extend(ixs)
    return tr_ix, va_ix, te_ix

# ─────────────── train & main ────────────────────────────────────────────
def main():
    p = argparse.ArgumentParser()
    p.add_argument("--pair_list", required=True)
    p.add_argument("--final_csv", required=True)
    p.add_argument("--out_dir", required=True)
    p.add_argument("--epochs", type=int, default=10)
    p.add_argument("--batch_size", type=int, default=16)
    p.add_argument("--accum_steps", type=int, default=4)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--device", default="cuda")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--alpha", type=float, default=0.5)
    p.add_argument("--beta", type=float, default=0.6)
    p.add_argument("--gamma", type=float, default=0.6)
    p.add_argument("--peak_thresh", type=float, default=0.5)
    # NEW: fractions for cluster-aware split (used only if cluster_id present)
    p.add_argument("--val_frac", type=float, default=0.10)
    p.add_argument("--test_frac", type=float, default=0.10)
    args = p.parse_args()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    # 1) load pair list & final.csv (now may include cluster_id)
    tf_paths, dna_paths = parse_pair_list(args.pair_list)
    final_df = pd.read_csv(args.final_csv, dtype=str)
    print(f"[i] Loaded {len(tf_paths)} pairs", flush=True)

    tf_cache = build_tf_cache(tf_paths, target_dim=256)

    # detect binder/DNA dims
    sample_tf = tf_cache[tf_paths[0]]
    binder_input_dim = sample_tf.shape[1] if sample_tf.ndim == 2 else sample_tf.shape[0]
    glm_input_dim = 256

    # 2) cluster-aware split if possible
    use_cluster_split = ("cluster_id" in final_df.columns)
    if use_cluster_split:
        print("[i] Cluster column detected in final_csv; performing cluster-aware split.", flush=True)
        # build dna_id -> cluster_id map
        cid_map = (final_df[["dna_id", "cluster_id"]].dropna().drop_duplicates()
                   .set_index("dna_id")["cluster_id"].to_dict())

        # map each example (by index) to its dna_id and cluster
        example_dna_ids = [dna_key_from_path(p) for p in dna_paths]
        example_clusters = []
        missing = 0
        for did in example_dna_ids:
            if did in cid_map:
                example_clusters.append(cid_map[did])
            else:
                # fallback: treat singleton cluster
                example_clusters.append(f"singleton::{did}")
                missing += 1
        if missing:
            print(f"[WARN] {missing} dna_ids from pair_list not found in cluster map; treating as singleton clusters.", flush=True)

        # build cluster -> indices
        cluster_to_indices = {}
        for i, cid in enumerate(example_clusters):
            cluster_to_indices.setdefault(cid, []).append(i)

        tr_idx, va_idx, te_idx = assign_clusters_to_splits(
            cluster_to_indices,
            val_frac=args.val_frac, test_frac=args.test_frac, seed=args.seed
        )
        print(f"[i] Cluster split sizes (examples): train={len(tr_idx)} val={len(va_idx)} test={len(te_idx)}", flush=True)

        # helper to subset paths
        def subset_by_indices(ixs):
            return [tf_paths[i] for i in ixs], [dna_paths[i] for i in ixs]

        tr_t, tr_d = subset_by_indices(tr_idx)
        va_t, va_d = subset_by_indices(va_idx)
        te_t, te_d = subset_by_indices(te_idx)

    else:
        print("[i] No cluster_id in final_csv; using random 80/10/10 split (OLD behavior).", flush=True)
        # OLD random split (kept, now under else)
        N = len(tf_paths)
        idxs = list(range(N)); random.shuffle(idxs)
        n_tr = int(0.8*N); n_va = int(0.1*N)
        tr, va, te = idxs[:n_tr], idxs[n_tr:n_tr+n_va], idxs[n_tr+n_va:]

        def subset(idxs_):
            return [tf_paths[i] for i in idxs_], [dna_paths[i] for i in idxs_]

        tr_t, tr_d = subset(tr)
        va_t, va_d = subset(va)
        te_t, te_d = subset(te)

    # 3) bucketed samplers (unchanged, but now use the cluster-aware subsets when available)
    tr_bs = make_buckets(list(range(len(tr_t))), tr_d, args.batch_size, n_buckets=10, seed=args.seed)
    va_bs = make_buckets(list(range(len(va_t))), va_d, args.batch_size, n_buckets=5, seed=args.seed+1)
    te_bs = make_buckets(list(range(len(te_t))), te_d, args.batch_size, n_buckets=5, seed=args.seed+2)

    tr_dl = DataLoader(PairDataset(tr_t, tr_d, final_df, tf_cache),
                       batch_sampler=ListBatchSampler(tr_bs),
                       collate_fn=collate_fn)
    va_dl = DataLoader(PairDataset(va_t, va_d, final_df, tf_cache),
                       batch_sampler=ListBatchSampler(va_bs),
                       collate_fn=collate_fn)
    te_dl = DataLoader(PairDataset(te_t, te_d, final_df, tf_cache),
                       batch_sampler=ListBatchSampler(te_bs),
                       collate_fn=collate_fn)

    # 4) model, optimizer, scaler
    model = BindPredictor(binder_input_dim=binder_input_dim,
                          glm_input_dim=glm_input_dim,
                          compressed_dim=256,
                          hidden_dim=256,
                          heads=8, num_layers=4,
                          use_local_cnn_on_glm=True).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scaler = amp.GradScaler('cuda')

    history, best_val = {"train": [], "val": []}, float("inf")
    od = Path(args.out_dir); od.mkdir(exist_ok=True, parents=True)

    for ep in range(1, args.epochs+1):
        print(f"──[Epoch {ep}]────────────────────────", flush=True)
        model.train()
        optimizer.zero_grad()
        acc_loss_sum, acc_acc_sum, n_train_batches = 0.0, 0.0, 0

        for i, (b, g, t) in enumerate(tr_dl):
            b, g, t = b.to(device), g.to(device), t.to(device)
            with amp.autocast('cuda'):
                logits = model(b, g)
                bce_non, kl, mse_peaks, probs = combined_loss_components(
                    logits, t, peak_thresh=args.peak_thresh
                )
                loss = args.alpha*bce_non + args.beta*kl + args.gamma*mse_peaks
                loss = loss / args.accum_steps

            scaler.scale(loss).backward()

            if (i + 1) % args.accum_steps == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            with torch.no_grad():
                acc_loss_sum += (loss.item() * args.accum_steps)
                acc_acc_sum += accuracy_percentage(logits, t, peak_thresh=args.peak_thresh)
                n_train_batches += 1

            del b, g, t, logits, probs, loss, bce_non, kl, mse_peaks
            torch.cuda.empty_cache()

        # finalize if leftovers
        if n_train_batches % args.accum_steps != 0:
            scaler.step(optimizer); scaler.update(); optimizer.zero_grad()

        train_loss = acc_loss_sum / max(1, n_train_batches)
        train_acc = acc_acc_sum / max(1, n_train_batches)

        val_loss, val_acc = evaluate(model, va_dl, device,
                                     alpha=args.alpha, beta=args.beta, gamma=args.gamma,
                                     peak_thresh=args.peak_thresh)
        print(f"[Epoch {ep}] train_loss={train_loss:.4f} train_acc={train_acc:.2f}% "
              f"val_loss={val_loss:.4f} val_acc={val_acc:.2f}%", flush=True)

        history["train"].append(train_loss)
        history["val"].append(val_loss)
        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), od/"best_model.pt")
            print(f"  Saved new best_model.pt (val_loss={val_loss:.4f}, val_acc={val_acc:.2f}%)", flush=True)

    torch.save(model.state_dict(), od/"last_model.pt")

    fig, ax = plt.subplots()
    ax.plot(history["train"], label="train")
    ax.plot(history["val"], label="val")
    ax.set_xlabel("epoch"); ax.set_ylabel("combined loss"); ax.legend()
    fig.savefig(od/"loss_curve.png")
    print(f"✅ Done → outputs in {od}", flush=True)

if __name__ == "__main__":
    main()
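For reference, a minimal sketch of the combined loss on toy tensors. The values are arbitrary and the weights mirror the --alpha/--beta/--gamma defaults; importing from train assumes the script is on the module path (the __main__ guard keeps the import side-effect free):

import torch
from train import combined_loss_components

logits = torch.randn(2, 10)   # (B, Lg) per-nucleotide logits from BindPredictor
targets = torch.rand(2, 10)   # (B, Lg) per-nucleotide signal in [0, 1]
bce_non, kl, mse_peaks, probs = combined_loss_components(logits, targets, peak_thresh=0.5)
loss = 0.5 * bce_non + 0.6 * kl + 0.6 * mse_peaks
print(float(bce_non), float(kl), float(mse_peaks), float(loss))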