"""
train_v3_fixed.py — ProtFunc v3 (corrected training procedure)
================================================================

Fixes the core methodological error in train_v2.py:

    WRONG   (train_v2.py):  propagate GO labels DURING TRAINING
    CORRECT (this file):    train on original labels; propagate ONLY during evaluation

Why the original approach was wrong:
    Propagating labels before training causes the model to directly predict
    broad ancestor terms ("binding", "catalytic activity") for nearly every
    protein. This inflates val micro-Fmax from ~0.88 to ~0.97 — an
    apples-to-oranges comparison that looks like a massive improvement but is
    mostly the model learning trivially predictable parent terms. Threshold
    calibration on propagated ground truth then lets those broad terms fire at
    inference, producing hundreds of predictions per protein.

The correct approach (CAFA-style):
    1. Train on the experimental annotations as-is (no ancestor expansion).
    2. Evaluate with CAFA-style propagation: predictions AND ground truth are
       both propagated upward through the GO hierarchy. This is fair because
       a protein that performs "ATP hydrolysis" also implicitly performs the
       ancestor functions of that term.

Warm-start:
    Instead of training from scratch, the model is initialised from the first
    existing checkpoint among (in order) ablation_C_ESM_seq_AF.pth,
    protfunc_v3_fixed.pth and improved_res.pth (Fmax=0.8846). Tensors with
    matching shapes are copied; fc_in is extended with small random columns
    (or truncated) when the input width differs. This typically saves ~15-20
    training epochs.

Ablation study — research question:
    "How much predictive gain comes from AlphaFold structural features
    (pLDDT, PAE)?"

    Run with --ablation to train three models in sequence:
        A: ESM only                (320 dim) — baseline
        B: ESM + sequence features (331 dim) — adds composition/physicochemical
        C: ESM + all features      (360 dim) — adds pLDDT, PAE, AF confidence

    Each model is evaluated on the test split:
        1. All proteins
        2. AF-covered proteins only (f_af_present > 0)
    This isolates the structural feature contribution.

Outputs (under artifacts/):
    checkpoints/protfunc_v3_fixed.pth             — trained model; with --ablation,
                                                     a copy of the variant with the
                                                     best validation micro-Fmax
    thresholds/protfunc_v3_fixed_thresholds.json  — per-label thresholds (same rule)
    logs/protfunc_v3_fixed_log.json               — training log (single-run mode;
                                                     ablation variants write
                                                     logs/ablation_<label>_log.json)
    ablation_results.json                         — only with --ablation
"""

import argparse
import ast
import json
import math
import os
import re
import shutil
import threading
import time
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler

# `requests` (UniProt lookups) and `joblib` (label binarizer) are imported
# lazily where they are used — like `esm` — so the model / metric helpers in
# this module can be imported without those optional dependencies.

warnings.filterwarnings("ignore")

# ── Paths ─────────────────────────────────────────────────────────────────────
BASE = Path(__file__).parent.parent
ART = BASE / "artifacts"
IMPORTANT = BASE / "Important Files"

DATA_BASE = IMPORTANT / "merged_full_struct.parquet"
DATA_SUPP = IMPORTANT / "merged_full_struct_with_features.parquet"
MLB_PATH = IMPORTANT / "mlb_public_v1.pkl"
SPLITS_NPZ = ART / "splits" / "splits_n250000_seed42.npz"
MAMMAL_FASTA = IMPORTANT / "mammal_subset.fasta"
MAMMAL_EMB = ART / "generalization" / "mammal_embeddings_v3.parquet"
OBO_PATH = BASE / "go-basic.obo"

# Prefer best ablation checkpoint (same in_dim=360); fall back to improved_res.pth.
# NOTE(review): in --ablation mode every variant (including the ESM-only
# baseline A) warm-starts from this checkpoint, which may itself have been
# trained with the structural features — this can confound the ablation.
_WARMSTART_CANDIDATES = [
    ART / "checkpoints" / "ablation_C_ESM_seq_AF.pth",
    ART / "checkpoints" / "protfunc_v3_fixed.pth",
    ART / "checkpoints" / "improved_res.pth",
]
WARMSTART = next((p for p in _WARMSTART_CANDIDATES if p.exists()), _WARMSTART_CANDIDATES[-1])

CKPT_OUT = ART / "checkpoints" / "protfunc_v3_fixed.pth"
THRESH_OUT = ART / "thresholds" / "protfunc_v3_fixed_thresholds.json"
LOG_OUT = ART / "logs" / "protfunc_v3_fixed_log.json"

# ── Supplemented feature columns (39 total) ───────────────────────────────────
SUPP_COLS = [
    "f_seq_len", "f_mean_hydro", "f_net_charge", "f_uversky_disorder",
    "f_idr_frac_proxy", "f_lowcomp_proxy", "f_tm_frac_proxy", "f_tm_any_proxy",
    "f_signal_peptide_proxy", "f_cf_helix_mean", "f_cf_sheet_mean",
    "f_afdb_has_model", "f_plddt_mean", "f_plddt_std", "f_plddt_q10",
    "f_plddt_q50", "f_plddt_q90", "f_plddt_frac_gt90", "f_plddt_frac_gt70",
    "f_plddt_frac_lt50",
    "f_distbin_0", "f_distbin_1", "f_distbin_2", "f_distbin_3", "f_distbin_4",
    "f_distbin_5", "f_distbin_6", "f_distbin_7", "f_distbin_8", "f_distbin_9",
    "f_pae_mean", "f_pae_median", "f_pae_p90", "f_pae_p95", "f_pae_frac_lt5",
    "f_pae_frac_lt10", "f_pae_frac_gt20",
    "f_seqfeat_present", "f_af_present",
]

# The first 11 SUPP_COLS, in the same order: the "sequence-only" slice (model B).
SEQ_ONLY_COLS = [
    "f_seq_len", "f_mean_hydro", "f_net_charge", "f_uversky_disorder",
    "f_idr_frac_proxy", "f_lowcomp_proxy", "f_tm_frac_proxy", "f_tm_any_proxy",
    "f_signal_peptide_proxy", "f_cf_helix_mean", "f_cf_sheet_mean",
]

# ── Config ────────────────────────────────────────────────────────────────────
SEED = 42
OUT_DIM = 8124
ESM_DIM = 320
SUPP_DIM = len(SUPP_COLS)  # 39
BATCH = 512
PW_CLIP = (1.0, 100.0)
MIN_SUPPORT = 10
FALLBACK_THRESH = 0.50
API_TIMEOUT = 12
API_THREADS = 20
EXP_CODES = {"IDA", "IMP", "IPI", "IGI", "IEP", "EXP", "HDA", "HMP", "HGI", "HEP", "TAS", "IC"}

np.random.seed(SEED)
torch.manual_seed(SEED)


# ─────────────────────────────────────────────────────────────────────────────
# Architecture (same as improved / train_v2)
# ─────────────────────────────────────────────────────────────────────────────
class ResBlock(nn.Module):
    """Pre-activation residual block: x + (BN-ReLU-Dropout-Linear) x 2."""

    def __init__(self, dim: int, dropout: float):
        super().__init__()
        self.net = nn.Sequential(
            nn.BatchNorm1d(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, dim),
            nn.BatchNorm1d(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, dim),
        )

    def forward(self, x):
        return x + self.net(x)


class ImprovedResidualMLP(nn.Module):
    """Linear input projection, ``n_blocks`` ResBlocks, then a BN-ReLU-Dropout-Linear head.

    Outputs raw logits of width ``out_dim`` (one per GO term).
    """

    def __init__(self, in_dim, out_dim=OUT_DIM, hidden=2048, n_blocks=4, dropout=0.2):
        super().__init__()
        self.fc_in = nn.Linear(in_dim, hidden)
        self.blocks = nn.ModuleList([ResBlock(hidden, dropout) for _ in range(n_blocks)])
        self.fc_out = nn.Sequential(
            nn.BatchNorm1d(hidden), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x):
        h = self.fc_in(x)
        for blk in self.blocks:
            h = blk(h)
        return self.fc_out(h)


# ─────────────────────────────────────────────────────────────────────────────
# Warm-start: copy weights from an existing checkpoint → new model
# ─────────────────────────────────────────────────────────────────────────────
def warm_start(model: ImprovedResidualMLP, ckpt_path: Path, in_dim_new: int):
    """
    Initialise ``model`` in place from a saved checkpoint, adapting fc_in.

    The checkpoint may be a raw state_dict or a dict with a "model" entry.
    Tensors present in both with identical shapes are copied. ``fc_in.weight``
    is adapted when the input width differs:
      * wider model   — the checkpoint's ``old_in`` columns are copied and the
        remaining columns get N(0, 0.01²) values, so the new features start
        near zero and don't disrupt the learned representation;
      * narrower model — the first ``in_dim_new`` checkpoint columns are kept.
    Keys missing from the model, or with other shape mismatches, are skipped.
    If ``ckpt_path`` does not exist the model is left untouched.

    ``in_dim_new`` is informational only; the target width is read from the
    model's own ``fc_in.weight``.
    """
    if not ckpt_path.exists():
        print(f"  Warm-start skipped: {ckpt_path} not found (training from scratch)")
        return

    raw = torch.load(ckpt_path, map_location="cpu")
    src = raw["model"] if isinstance(raw, dict) and "model" in raw else raw
    dst = model.state_dict()

    loaded, skipped = 0, 0
    for k, v in src.items():
        if k not in dst:
            skipped += 1
            continue
        if k == "fc_in.weight":
            old_in = v.shape[1]
            new_in = dst[k].shape[1]
            if new_in == old_in:
                dst[k] = v
            elif new_in > old_in:
                # Extend: copy known dims, small random values for new dims.
                new_w = torch.zeros_like(dst[k])
                new_w[:, :old_in] = v
                new_w[:, old_in:] = torch.randn(v.shape[0], new_in - old_in) * 0.01
                dst[k] = new_w
            else:
                # Truncate: take the first new_in columns of the checkpoint weight.
                dst[k] = v[:, :new_in].clone()
        elif dst[k].shape == v.shape:
            dst[k] = v
        else:
            skipped += 1
            continue
        loaded += 1

    model.load_state_dict(dst)
    print(f"  Warm-start from {ckpt_path.name}: {loaded} tensors loaded, {skipped} skipped")


# ─────────────────────────────────────────────────────────────────────────────
# Dataset
# ─────────────────────────────────────────────────────────────────────────────
class EmbeddingDataset(Dataset):
    """Row view over shared (X, Y) arrays; Y rows are cast to float32 per item."""

    def __init__(self, X, Y, indices):
        self.X = X
        self.Y = Y
        self.idx = indices.astype(np.int64)

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, k):
        i = int(self.idx[k])
        return self.X[i], self.Y[i].astype(np.float32)


# ─────────────────────────────────────────────────────────────────────────────
# Loss
# ─────────────────────────────────────────────────────────────────────────────
class SmoothBCEWithLogitsLoss(nn.Module):
    """BCE-with-logits on symmetrically smoothed targets (y → y(1-s) + (1-y)s)."""

    def __init__(self, pos_weight, smoothing=0.05):
        super().__init__()
        self.smooth = smoothing
        self.bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction="mean")

    def forward(self, logits, targets):
        t = targets * (1 - self.smooth) + (1 - targets) * self.smooth
        return self.bce(logits, t)


def lr_lambda(warmup, total):
    """LR multiplier: linear warm-up for ``warmup`` epochs, then cosine decay to 0 at ``total``."""
    def fn(ep):
        if ep < warmup:
            return (ep + 1) / warmup
        p = (ep - warmup) / max(total - warmup, 1)
        return 0.5 * (1 + math.cos(math.pi * p))
    return fn


# ─────────────────────────────────────────────────────────────────────────────
# Evaluation: micro-Fmax (streaming histogram over ORIGINAL ground truth)
# NOTE: no label propagation applied here — this measures on specific terms
# ─────────────────────────────────────────────────────────────────────────────
@torch.no_grad()
def eval_micro_fmax(model, loader, device, step=0.02):
    """
    Micro-averaged Fmax over all (protein, label) pairs, using a histogram of
    predicted probabilities so memory stays constant in the dataset size.

    Returns a dict with "micro_fmax", "t_star" (bin lower edge of the best
    threshold), "precision" and "recall" at that threshold.
    """
    model.eval()
    edges = np.arange(0.0, 1.0 + step, step)
    nbins = len(edges)
    hp = np.zeros(nbins, np.int64)  # histogram of probabilities for positive pairs
    hn = np.zeros(nbins, np.int64)  # ... and for negative pairs
    n_pos = 0

    for Xb, Yb in loader:
        p = torch.sigmoid(model(Xb.to(device))).cpu().numpy().ravel()
        y = Yb.numpy().ravel() > 0.5
        n_pos += int(y.sum())
        bi = np.minimum(np.floor(p / step + 1e-9).astype(np.int64), nbins - 1)
        if y.any():
            hp += np.bincount(bi[y], minlength=nbins)
        if (~y).any():
            hn += np.bincount(bi[~y], minlength=nbins)

    # Reverse cumulative sums: counts of pairs with probability >= edges[b].
    cum_tp = np.cumsum(hp[::-1])[::-1].astype(float)
    cum_fp = np.cumsum(hn[::-1])[::-1].astype(float)
    pred = cum_tp + cum_fp
    prec = np.where(pred > 0, cum_tp / pred, 0.0)
    rec = cum_tp / max(n_pos, 1)
    denom = prec + rec
    f1 = np.where(denom > 0, 2 * prec * rec / denom, 0.0)
    b = int(np.argmax(f1))
    return {"micro_fmax": float(f1[b]), "t_star": float(edges[b]),
            "precision": float(prec[b]), "recall": float(rec[b])}


# ─────────────────────────────────────────────────────────────────────────────
# Evaluation: CAFA-style Fmax (propagate predictions AND ground truth upward)
# This is the meaningful comparison metric — matches CAFA competition scoring.
# ─────────────────────────────────────────────────────────────────────────────
def _ancestor_index(mlb_classes, go_parents):
    """For each class, return the column indices of its transitive ancestors within ``mlb_classes``."""
    go2idx = {g: i for i, g in enumerate(mlb_classes)}
    anc_idx = []
    for gid in mlb_classes:
        visited, stack = set(), list(go_parents.get(gid, set()))
        while stack:
            p = stack.pop()
            if p not in visited:
                visited.add(p)
                stack.extend(go_parents.get(p, set()))
        anc_idx.append(np.array(sorted(go2idx[a] for a in visited if a in go2idx), dtype=np.int64))
    return anc_idx


def _propagate_up(mat, anc_idx):
    """Return a copy of binary ``mat`` (rows × classes) with every set class's ancestors also set."""
    out = mat.copy()
    for j, aidx in enumerate(anc_idx):
        if aidx.size == 0:
            continue
        # anc_idx is already the transitive closure, so reading the
        # unpropagated matrix is sufficient.
        rows = np.flatnonzero(mat[:, j] > 0)
        if rows.size:
            out[np.ix_(rows, aidx)] = 1.0
    return out


@torch.no_grad()
def eval_cafa_fmax(model, loader, device, mlb_classes, go_parents, step=0.05):
    """
    Protein-centric CAFA Fmax.

    Both the thresholded predictions and the ground truth are propagated to
    all GO ancestors (restricted to ``mlb_classes``). Precision is averaged
    over proteins with at least one prediction, recall over all proteins that
    have at least one annotation.

    Returns {"cafa_fmax", "t_star"}; both are NaN when ``go_parents`` is empty
    or no protein in ``loader`` has a label.
    """
    if not go_parents:
        return {"cafa_fmax": float("nan"), "t_star": float("nan")}

    model.eval()
    anc_idx = _ancestor_index(mlb_classes, go_parents)

    all_probs, all_true = [], []
    for Xb, Yb in loader:
        all_probs.append(torch.sigmoid(model(Xb.to(device))).cpu().numpy())
        all_true.append(Yb.numpy())
    if not all_probs:
        return {"cafa_fmax": float("nan"), "t_star": float("nan")}

    probs = np.concatenate(all_probs, axis=0).astype(np.float32)
    true = np.concatenate(all_true, axis=0).astype(np.float32)

    has_label = true.sum(axis=1) > 0
    probs = probs[has_label]
    true = true[has_label]
    if len(probs) == 0:
        return {"cafa_fmax": float("nan"), "t_star": float("nan")}

    # Ground truth is propagated once; every remaining row has >= 1 label.
    true = _propagate_up(true, anc_idx)
    rp_per = true.sum(axis=1)

    thresholds = np.arange(0.05, 0.96, step)
    best_f1, best_t = -1.0, 0.5
    for t in thresholds:
        prop = _propagate_up((probs >= t).astype(np.float32), anc_idx)
        tp_per = (prop * true).sum(axis=1)
        pp_per = prop.sum(axis=1)

        has_pred = pp_per > 0
        if not has_pred.any():
            continue
        avg_prec = (tp_per[has_pred] / pp_per[has_pred]).mean()
        avg_rec = (tp_per / rp_per).mean()
        denom = avg_prec + avg_rec
        f1 = (2 * avg_prec * avg_rec / denom) if denom > 0 else 0.0
        if f1 > best_f1:
            best_f1, best_t = f1, float(t)

    if best_f1 < 0:
        # No threshold produced any prediction: Fmax is 0 by definition.
        best_f1 = 0.0
    return {"cafa_fmax": round(float(best_f1), 4), "t_star": round(best_t, 3)}


# ─────────────────────────────────────────────────────────────────────────────
# Per-label threshold calibration (on ORIGINAL non-propagated val labels)
# ─────────────────────────────────────────────────────────────────────────────
@torch.no_grad()
def compute_per_label_thresholds(model, loader, device, support_tr,
                                 min_support=MIN_SUPPORT,
                                 fallback=FALLBACK_THRESH):
    """
    Pick, per label, the threshold in {0.10, 0.15, ..., 0.95} maximising F1 on
    ``loader`` (ties go to the lowest threshold).

    A label keeps ``fallback`` when its training support is below
    ``min_support``, or when no threshold achieves F1 > 0. That covers labels
    with no positives in the calibration set: their F1 is 0 everywhere, and
    picking the lowest step would make them fire very aggressively.

    Returns a float32 array with one threshold per model output.
    """
    model.eval()
    all_p, all_y = [], []
    for Xb, Yb in loader:
        all_p.append(torch.sigmoid(model(Xb.to(device))).cpu())
        all_y.append(Yb)
    probs = torch.cat(all_p).numpy().astype(np.float32)
    true = torch.cat(all_y).numpy() > 0.5

    steps = np.arange(0.10, 0.96, 0.05, dtype=np.float32)
    n_labels = probs.shape[1]
    thr = np.full(n_labels, fallback, dtype=np.float32)

    for j in range(n_labels):
        if int(support_tr[j]) < min_support:
            continue
        pos = true[:, j]
        n_pos = int(pos.sum())
        if n_pos == 0:
            continue
        pred = probs[:, j, None] >= steps[None, :]            # (N, n_steps)
        tp = (pred & pos[:, None]).sum(axis=0).astype(np.float64)
        fp = pred.sum(axis=0) - tp
        fn = n_pos - tp
        d = 2 * tp + fp + fn
        f1 = np.divide(2 * tp, d, out=np.zeros_like(tp), where=d > 0)
        b = int(np.argmax(f1))
        if f1[b] > 0:
            thr[j] = steps[b]
    return thr


# ─────────────────────────────────────────────────────────────────────────────
# GO hierarchy: parse OBO (identical to train_v2.py)
# ─────────────────────────────────────────────────────────────────────────────
def load_go_parents(obo_path: Path) -> dict:
    """
    Parse an OBO file into {MF term: set of direct MF parents}.

    Parents come from ``is_a`` and ``relationship: part_of`` lines; obsolete
    terms are dropped; only molecular_function terms and parents are kept.
    Returns {} (CAFA eval disabled) when the file is missing.
    """
    if not obo_path.exists():
        print(f"  WARNING: {obo_path} not found — CAFA eval disabled")
        return {}

    ns_map, par_map = {}, {}
    cur_id, cur_ns, cur_par, in_term = None, None, set(), False

    def flush():
        nonlocal cur_id, cur_ns, cur_par
        if cur_id and cur_ns:
            ns_map[cur_id] = cur_ns
            par_map[cur_id] = cur_par
        cur_id, cur_ns, cur_par = None, None, set()

    with open(obo_path, encoding="utf-8") as fh:
        for raw in fh:
            line = raw.strip()
            if line == "[Term]":
                flush()
                in_term = True
                continue
            if line.startswith("["):
                # Any other stanza ([Typedef], ...) ends the current term.
                flush()
                in_term = False
                continue
            if not in_term:
                continue
            if line.startswith("id:"):
                cur_id = line.split("id:", 1)[1].strip().split()[0]
            elif line.startswith("namespace:"):
                cur_ns = line.split("namespace:", 1)[1].strip()
            elif line.startswith("is_obsolete:") and "true" in line:
                cur_id = None
            elif line.startswith("is_a:"):
                cur_par.add(line.split("is_a:", 1)[1].strip().split()[0])
            elif line.startswith("relationship:"):
                pts = line.split("relationship:", 1)[1].strip().split()
                if len(pts) >= 2 and pts[0] == "part_of":
                    cur_par.add(pts[1])
    flush()

    mf = {g for g, n in ns_map.items() if n == "molecular_function"}
    return {g: (par_map[g] & mf) for g in mf}


# ─────────────────────────────────────────────────────────────────────────────
# Label parsing helper
# ─────────────────────────────────────────────────────────────────────────────
def parse_labels(x):
    """
    Normalise a Label_Indices cell to a list of ints.

    Accepts lists, tuples, numpy arrays, or a string holding a Python literal
    (list/tuple/scalar). Returns [] for None, "nan", empty or unparsable input.
    """
    if x is None:
        return []
    if isinstance(x, (list, tuple, np.ndarray)):
        return [int(v) for v in x]
    if isinstance(x, str):
        s = x.strip()
        if not s or s.lower() == "nan":
            return []
        try:
            v = ast.literal_eval(s)
            return [int(i) for i in (v if isinstance(v, (list, tuple)) else [v])]
        except (ValueError, SyntaxError, TypeError):
            return []
    return []


def select_feature_matrix(X_all_full: np.ndarray, feature_level: str) -> np.ndarray:
    """
    Match the feature slices used in HPO:
        esm_only -> 320
        esm_seq  -> 331
        esm_all  -> 360
    """
    if feature_level == "esm_only":
        return X_all_full[:, :ESM_DIM]
    if feature_level == "esm_seq":
        seq_end = ESM_DIM + len(SEQ_ONLY_COLS)
        return X_all_full[:, :seq_end]
    if feature_level == "esm_all":
        return X_all_full
    raise ValueError(f"Unknown feature_level: {feature_level}")


# ─────────────────────────────────────────────────────────────────────────────
# Mammal data helpers (identical to train_v2.py)
# ─────────────────────────────────────────────────────────────────────────────
def parse_fasta(path):
    """Yield (header_id, sequence) pairs; header_id is the first token after '>'."""
    header, seq = None, []
    with open(path, encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if line.startswith(">"):
                if header is not None:
                    yield header, "".join(seq)
                header = line[1:].split()[0]
                seq = []
            else:
                seq.append(line)
    if header is not None:
        yield header, "".join(seq)


def fetch_go_mf(uniprot_id, exp_codes, timeout):
    """
    Return the GO molecular-function IDs annotated on a UniProt entry with an
    evidence code in ``exp_codes``. Best-effort: any HTTP / parsing failure
    yields [].
    """
    import requests  # lazy: only needed when building the mammal dataset

    url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.json?fields=go_f"
    try:
        r = requests.get(url, timeout=timeout)
        if r.status_code != 200:
            return []
        refs = r.json().get("uniProtKBCrossReferences", [])
        terms = []
        for ref in refs:
            if ref.get("database") != "GO":
                continue
            go_id = ref.get("id", "")
            props = {p.get("key"): p.get("value") for p in ref.get("properties", [])}
            aspect = props.get("GoTerm", "")
            evidence = props.get("GoEvidenceType", "")
            if aspect.startswith("F:") and evidence.split(":")[0] in exp_codes:
                terms.append(go_id)
        return terms
    except Exception:
        # Deliberate best-effort: a failed lookup just leaves the protein unlabeled.
        return []


def build_mammal_dataset(fasta_path, mlb, device, max_seq_len=2500):
    """
    Build a DataFrame of mammal proteins with experimental GO-MF labels.

    Sequences of length 30..max_seq_len are annotated via UniProt, proteins
    without any label known to ``mlb`` are dropped, and the rest are embedded
    with ESM-2 (t6, 8M; mean over residue tokens). Columns: Dim_0..Dim_319,
    Label_Indices, f_af_present (always 0.0).
    """
    print("  Loading ESM-2 for mammal embedding...")
    import esm as esm_lib
    esm_model, alphabet = esm_lib.pretrained.esm2_t6_8M_UR50D()
    esm_model = esm_model.to(device).eval()
    bc = alphabet.get_batch_converter()

    entries = [(h, s) for h, s in parse_fasta(fasta_path) if 30 <= len(s) <= max_seq_len]
    print(f"  Parsed {len(entries)} mammal sequences")

    print(f"  Fetching GO annotations from UniProt ({API_THREADS} threads)...")
    go_labels, lock = {}, threading.Lock()

    def fetch_one(hdr_seq):
        hdr, _ = hdr_seq
        uid = hdr.split("|")[1] if "|" in hdr else hdr.split()[0]
        terms = fetch_go_mf(uid, EXP_CODES, API_TIMEOUT)
        with lock:
            go_labels[hdr] = terms

    with ThreadPoolExecutor(max_workers=API_THREADS) as ex:
        futures = [ex.submit(fetch_one, e) for e in entries]
        for i, fut in enumerate(as_completed(futures)):
            fut.result()  # surface unexpected worker errors instead of silently dropping them
            if (i + 1) % 100 == 0:
                print(f"    {i+1}/{len(entries)} annotations fetched")

    go2idx = {g: i for i, g in enumerate(mlb.classes_)}
    labeled = [(h, s, [go2idx[t] for t in go_labels.get(h, []) if t in go2idx])
               for h, s in entries]
    labeled = [(h, s, l) for h, s, l in labeled if len(l) > 0]
    print(f"  {len(labeled)} mammal sequences with ≥1 GO-MF label (experimental evidence)")

    print(f"  Embedding {len(labeled)} sequences with ESM-2...")
    TOKEN_BUDGET = 6000
    rows = []
    batch_buf = []
    total_tok = 0

    def flush_batch(buf):
        batch_data = [(h, s) for h, s, _ in buf]
        _, _, toks = bc(batch_data)
        with torch.no_grad():
            rep = esm_model(toks.to(device), repr_layers=[6])["representations"][6]
        for k, (h, s, labs) in enumerate(buf):
            # Token 0 is BOS; residues occupy positions 1..len(s).
            emb = rep[k, 1:len(s)+1].mean(0).cpu().numpy()
            row = {f"Dim_{i}": float(emb[i]) for i in range(ESM_DIM)}
            row["Label_Indices"] = labs
            row["f_af_present"] = 0.0
            rows.append(row)

    for h, s, labs in labeled:
        ntok = len(s) + 2
        if total_tok + ntok > TOKEN_BUDGET and batch_buf:
            flush_batch(batch_buf)
            batch_buf, total_tok = [], 0
        batch_buf.append((h, s, labs))
        total_tok += ntok
    if batch_buf:
        flush_batch(batch_buf)

    return pd.DataFrame(rows)


# ─────────────────────────────────────────────────────────────────────────────
# Core training function
# ─────────────────────────────────────────────────────────────────────────────
def run_training(args, X_all, Y_all, train_idx, val_idx, test_idx,
                 support_tr, mu, sd, go_parents, mlb, in_dim,
                 feature_label, ckpt_out, thresh_out, log_out, device,
                 eval_subsets=None):
    """
    Train one model variant, then evaluate the best checkpoint.

    Args:
        args: parsed CLI namespace (batch, num_workers, hidden, blocks, dropout,
            lr, weight_decay, label_smooth, epochs, patience are read).
        X_all, Y_all: feature / label matrices addressed by the *_idx arrays.
        train_idx, val_idx, test_idx: row indices of the three splits.
        support_tr: per-label positive counts over the training rows.
        mu, sd: supplemental-feature normalisation stats (saved in the checkpoint).
        go_parents: GO-MF parent map for CAFA evaluation ({} disables it).
        mlb: fitted label binarizer; ``mlb.classes_`` are the GO IDs.
        in_dim: input width of ``X_all``.
        feature_label: tag written to the log and checkpoint.
        ckpt_out, thresh_out, log_out: output paths (parent dirs are created).
        device: torch device.
        eval_subsets: optional {name: row indices}; each subset is scored with
            micro-Fmax on the best checkpoint, in addition to the full test split.

    Returns:
        dict with best val micro-Fmax, test micro / CAFA Fmax, per-subset test
        micro-Fmax ("test_subset_micro_fmax") and the output paths.
    """
    for out_path in (ckpt_out, thresh_out, log_out):
        Path(out_path).parent.mkdir(parents=True, exist_ok=True)

    ds_tr = EmbeddingDataset(X_all, Y_all, train_idx)
    ds_va = EmbeddingDataset(X_all, Y_all, val_idx)
    ds_te = EmbeddingDataset(X_all, Y_all, test_idx)

    # Oversample proteins carrying rare labels: weight = mean of 1/sqrt(support)
    # over the protein's labels (1.0 for unlabeled rows).
    inv_sqrt = 1.0 / np.sqrt(np.maximum(support_tr, 1.0))
    w_train = np.array([
        float(np.mean(inv_sqrt[Y_all[i].astype(bool)])) if Y_all[i].any() else 1.0
        for i in train_idx
    ], dtype=np.float32)
    sampler = WeightedRandomSampler(torch.as_tensor(w_train), len(w_train), replacement=True)

    ld_tr = DataLoader(ds_tr, args.batch, sampler=sampler, num_workers=args.num_workers)
    ld_va = DataLoader(ds_va, args.batch, shuffle=False, num_workers=args.num_workers)
    ld_te = DataLoader(ds_te, args.batch, shuffle=False, num_workers=args.num_workers)

    # Per-label positive weight = negatives / positives, clipped to PW_CLIP.
    pw_np = np.clip((len(train_idx) - support_tr) / np.maximum(support_tr, 1.0), *PW_CLIP)
    pos_weight = torch.tensor(pw_np, dtype=torch.float32).to(device)

    model = ImprovedResidualMLP(
        in_dim=in_dim, out_dim=OUT_DIM, hidden=args.hidden,
        n_blocks=args.blocks, dropout=args.dropout,
    ).to(device)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n  [{feature_label}] Model: {n_params:,} params (in_dim={in_dim})")

    # Warm-start from the first existing WARMSTART candidate.
    warm_start(model, WARMSTART, in_dim)

    criterion = SmoothBCEWithLogitsLoss(pos_weight, smoothing=args.label_smooth)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr,
                                  weight_decay=args.weight_decay, eps=1e-7)
    warmup = max(1, int(args.epochs * 0.08))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda(warmup, args.epochs))

    print(f"  [{feature_label}] Training (epochs={args.epochs}, patience={args.patience})")
    best_fmax, no_improve = -1.0, 0
    log = []

    for ep in range(1, args.epochs + 1):
        t0 = time.time()
        model.train()
        running = 0.0
        for Xb, Yb in ld_tr:
            Xb, Yb = Xb.to(device), Yb.to(device)
            optimizer.zero_grad()
            with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
                loss = criterion(model(Xb), Yb)
            # Backward runs outside the autocast region, as the AMP docs recommend.
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            running += loss.item()
        scheduler.step()
        train_loss = running / len(ld_tr)

        mi = eval_micro_fmax(model, ld_va, device)
        fmax = mi["micro_fmax"]

        cafa_fmax = None
        if go_parents and (ep % 5 == 0 or ep == args.epochs):
            ci = eval_cafa_fmax(model, ld_va, device, mlb.classes_, go_parents)
            cafa_fmax = ci["cafa_fmax"]

        elapsed = time.time() - t0
        lr_now = scheduler.get_last_lr()[0]
        entry = {
            "feature_label": feature_label,
            "epoch": ep,
            "train_loss": round(train_loss, 6),
            "val_micro_fmax": round(fmax, 4),
            "val_t_star": round(mi["t_star"], 3),
            "val_prec": round(mi["precision"], 4),
            "val_rec": round(mi["recall"], 4),
            "val_cafa_fmax": round(cafa_fmax, 4) if cafa_fmax is not None else None,
            "lr": round(lr_now, 7),
            "elapsed_s": round(elapsed, 1),
        }
        log.append(entry)
        with open(log_out, "w") as f:
            json.dump(log, f, indent=2)

        cafa_str = f"  CAFA={cafa_fmax:.4f}" if cafa_fmax is not None else ""
        print(
            f"  Ep {ep:3d}/{args.epochs} | loss={train_loss:.4f} | "
            f"micro-fmax={fmax:.4f} @t={mi['t_star']:.2f} "
            f"P={mi['precision']:.3f} R={mi['recall']:.3f}{cafa_str} | "
            f"lr={lr_now:.2e} | {elapsed:.0f}s"
        )

        if fmax > best_fmax:
            best_fmax = fmax
            no_improve = 0
            torch.save({
                "model": model.state_dict(),
                "epoch": ep,
                "val_fmax": fmax,
                "feature_label": feature_label,
                "in_dim": in_dim,
                "hidden": args.hidden,
                "n_blocks": args.blocks,
                "supp_mu": mu.tolist(),
                "supp_sd": sd.tolist(),
                "supp_cols": SUPP_COLS,
            }, ckpt_out)
            print(f"    ✓ Best micro-fmax={fmax:.4f} — saved")
        else:
            no_improve += 1
            if no_improve >= args.patience:
                print(f"\n  Early stopping at epoch {ep}")
                break

    # Test evaluation on the best checkpoint (always written at epoch 1 at least).
    print(f"\n  [{feature_label}] Test evaluation...")
    ckpt = torch.load(ckpt_out, map_location=device)
    model.load_state_dict(ckpt["model"])
    test_micro = eval_micro_fmax(model, ld_te, device)
    test_cafa = eval_cafa_fmax(model, ld_te, device, mlb.classes_, go_parents) if go_parents else {}
    print(f"    Test micro-Fmax={test_micro['micro_fmax']:.4f}  "
          f"P={test_micro['precision']:.4f}  R={test_micro['recall']:.4f}")
    if test_cafa:
        print(f"    Test CAFA-Fmax={test_cafa.get('cafa_fmax', 'N/A')}")

    subset_fmax = {}
    for name, idx in (eval_subsets or {}).items():
        idx = np.asarray(idx)
        if idx.size == 0:
            subset_fmax[name] = None
            print(f"    Test micro-Fmax [{name}]: no rows — skipped")
            continue
        ld_sub = DataLoader(EmbeddingDataset(X_all, Y_all, idx), args.batch,
                            shuffle=False, num_workers=args.num_workers)
        sub = eval_micro_fmax(model, ld_sub, device)
        subset_fmax[name] = sub["micro_fmax"]
        print(f"    Test micro-Fmax [{name}, n={idx.size}]={sub['micro_fmax']:.4f}")

    # Per-label thresholds on val set (non-propagated labels)
    print(f"  [{feature_label}] Computing per-label thresholds...")
    thr = compute_per_label_thresholds(model, ld_va, device, support_tr)
    with open(thresh_out, "w") as f:
        json.dump({str(i): float(thr[i]) for i in range(len(thr))}, f)

    log.append({"test_micro": test_micro, "test_cafa": test_cafa, "test_subsets": subset_fmax})
    with open(log_out, "w") as f:
        json.dump(log, f, indent=2)

    return {
        "feature_label": feature_label,
        "in_dim": in_dim,
        "best_val_fmax": best_fmax,
        "test_micro_fmax": test_micro["micro_fmax"],
        "test_cafa_fmax": test_cafa.get("cafa_fmax"),
        "test_subset_micro_fmax": subset_fmax,
        "ckpt": str(ckpt_out),
        "thresh": str(thresh_out),
    }


def _fmt_metric(value):
    """Format an optional float metric for the summary table ("N/A" when missing or NaN)."""
    if value is None or (isinstance(value, float) and math.isnan(value)):
        return "N/A"
    return f"{value:.4f}"


def _copy_if_different(src, dst):
    """Copy file ``src`` to ``dst`` unless both already refer to the same file."""
    src, dst = Path(src), Path(dst)
    if src.resolve() != dst.resolve():
        shutil.copyfile(src, dst)


# ─────────────────────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────────────────────
def train(args):
    """Load data, build features/labels, and run the single or ablation training."""
    import joblib  # lazy: only needed to load the label binarizer

    device = torch.device(
        "mps" if torch.backends.mps.is_available()
        else "cuda" if torch.cuda.is_available()
        else "cpu"
    )
    print(f"Device: {device}")

    # Allow MPS to use up to 95% of unified memory (default is conservative)
    if device.type == "mps":
        torch.mps.set_per_process_memory_fraction(0.95)

    # --start_from implies ablation mode
    if args.start_from != "A":
        args.ablation = True

    ART.mkdir(parents=True, exist_ok=True)
    ckpt_out = Path(args.checkpoint_out) if args.checkpoint_out else CKPT_OUT
    thresh_out = Path(args.threshold_out) if args.threshold_out else THRESH_OUT
    log_out = Path(args.log_out) if args.log_out else LOG_OUT
    for out_path in (ckpt_out, thresh_out, log_out):
        out_path.parent.mkdir(parents=True, exist_ok=True)

    # Resolve num_workers default based on device
    if args.num_workers is None:
        args.num_workers = 0 if device.type == "mps" else 4
    print(f"DataLoader num_workers: {args.num_workers}")

    # ── Load insect base dataset ───────────────────────────────────────────────
    print("\n[1/6] Loading insect dataset...")
    df_base = pd.read_parquet(DATA_BASE)
    df_supp = pd.read_parquet(DATA_SUPP)
    # The two tables are combined positionally, so they must be row-aligned.
    if len(df_supp) != len(df_base):
        raise ValueError(
            f"{DATA_SUPP.name} has {len(df_supp)} rows but {DATA_BASE.name} has {len(df_base)}"
        )
    emb_cols = [c for c in df_base.columns if c.startswith("Dim_")]
    if len(emb_cols) != ESM_DIM:
        raise ValueError(f"Expected {ESM_DIM} Dim_* columns in {DATA_BASE.name}, found {len(emb_cols)}")

    mlb = joblib.load(MLB_PATH)
    if len(mlb.classes_) != OUT_DIM:
        raise ValueError(f"Expected {OUT_DIM} classes in {MLB_PATH.name}, found {len(mlb.classes_)}")

    # ── GO hierarchy (for CAFA eval only — NOT used to inflate training labels) ─
    print("[1b/6] Loading GO hierarchy (for CAFA-style evaluation only)...")
    go_parents = load_go_parents(OBO_PATH)
    print(f"  GO parents loaded: {len(go_parents)} MF terms")

    # ── Build insect label matrix (NO propagation — original annotations) ─────
    print("[2/6] Building insect label matrix (original annotations, no propagation)...")
    label_lists = [parse_labels(x) for x in df_base["Label_Indices"]]
    Y_insect = np.zeros((len(df_base), OUT_DIM), dtype=np.uint8)  # uint8: 2GB vs 8GB float32
    for r, labs in enumerate(label_lists):
        for j in labs:
            if 0 <= j < OUT_DIM:
                Y_insect[r, j] = 1
    print(f"  Insect label matrix: {int(Y_insect.sum()):,} positives across {len(Y_insect):,} proteins")
    print(f"  (No GO propagation applied — model will learn specific terms as annotated)")

    # ── Supplemented features ─────────────────────────────────────────────────
    S_raw = df_supp[SUPP_COLS].to_numpy(np.float32)
    # A missing AF flag is treated as "no AlphaFold model".
    m_flag = np.nan_to_num(df_supp["f_af_present"].to_numpy(np.float32), nan=0.0).reshape(-1, 1)
    X_base = df_base[emb_cols].to_numpy(np.float32)

    # ── Splits ─────────────────────────────────────────────────────────────────
    splits = np.load(SPLITS_NPZ)
    train_idx = splits["train_idx"]
    val_idx = splits["val_idx"]
    test_idx = splits["test_idx"]
    print(f"  Insect splits — train:{len(train_idx)}  val:{len(val_idx)}  test:{len(test_idx)}")

    # Normalise using train stats. NaNs are imputed with the train mean (→ 0
    # after z-scoring), matching the mammal path; otherwise they would reach
    # the model as NaN inputs.
    S_tr = S_raw[train_idx]
    mu = np.nanmean(S_tr, axis=0)
    mu = np.where(np.isnan(mu), 0.0, mu).astype(np.float32)  # all-NaN column → 0
    sd_raw = np.nanstd(S_tr, axis=0)
    sd = np.where(sd_raw > 0, sd_raw, 1.0)
    S_filled = np.where(np.isnan(S_raw), mu, S_raw)
    S_z = (S_filled - mu) / (sd + 1e-12)
    X_insect_full = np.concatenate([X_base, S_z, m_flag], axis=1).astype(np.float32)
    print(f"  Insect X shape (full features): {X_insect_full.shape}")

    # Test rows with an AlphaFold model, for the structural-feature comparison.
    af_test_idx = test_idx[m_flag[test_idx, 0] > 0]
    eval_subsets = {"af_covered": af_test_idx}
    print(f"  AF-covered test proteins: {len(af_test_idx)}/{len(test_idx)}")

    # ── Mammal data ────────────────────────────────────────────────────────────
    X_mammal, Y_mammal = None, None
    if not args.skip_mammal and MAMMAL_FASTA.exists():
        print("\n[3/6] Processing mammal data...")
        if MAMMAL_EMB.exists():
            print("  Loading cached mammal embeddings...")
            df_m = pd.read_parquet(MAMMAL_EMB)
        else:
            print("  Computing mammal embeddings (first run)...")
            df_m = build_mammal_dataset(MAMMAL_FASTA, mlb, device)
            MAMMAL_EMB.parent.mkdir(parents=True, exist_ok=True)
            df_m.to_parquet(MAMMAL_EMB, index=False)

        m_emb_cols = [f"Dim_{i}" for i in range(ESM_DIM)]
        X_m_base = df_m[m_emb_cols].to_numpy(np.float32)

        # Build mammal label matrix — NO propagation (same as insect)
        Y_m = np.zeros((len(df_m), OUT_DIM), dtype=np.uint8)  # uint8 to match insect
        for r, labs in enumerate(df_m["Label_Indices"].tolist()):
            for j in parse_labels(labs):
                if 0 <= j < OUT_DIM:
                    Y_m[r, j] = 1
        print(f"  Mammal: {int(Y_m.sum()):,} positives (original annotations, no propagation)")

        # NOTE(review): SUPP_COLS absent from the mammal table stay at raw 0.0
        # (not the train mean) before z-scoring — confirm this is intended.
        S_m = np.zeros((len(df_m), SUPP_DIM), dtype=np.float32)
        for ci, col in enumerate(SUPP_COLS):
            if col in df_m.columns:
                v = df_m[col].to_numpy(np.float32)
                S_m[:, ci] = np.where(np.isnan(v), mu[ci], v)
        S_m_z = (S_m - mu) / (sd + 1e-12)
        if "f_af_present" in df_m:
            m_m = df_m["f_af_present"].to_numpy(np.float32).reshape(-1, 1)
        else:
            m_m = np.zeros((len(df_m), 1), np.float32)
        X_mammal = np.concatenate([X_m_base, S_m_z, m_m], axis=1).astype(np.float32)
        Y_mammal = Y_m
        print(f"  Mammal X shape: {X_mammal.shape}")
    else:
        print("\n[3/6] Skipping mammal data")

    # ── Merge insect + mammal ─────────────────────────────────────────────────
    print("\n[4/6] Merging datasets...")
    if X_mammal is not None:
        X_all_full = np.concatenate([X_insect_full, X_mammal], axis=0)
        Y_all = np.concatenate([Y_insect, Y_mammal], axis=0)
        # Mammal rows are appended after the insect rows and used for training only.
        mammal_idx = np.arange(len(X_insect_full), len(X_all_full))
        train_idx_combined = np.concatenate([train_idx, mammal_idx])
        print(f"  Combined: {len(X_all_full)} rows (insect+mammal train: {len(train_idx_combined)})")
    else:
        X_all_full = X_insect_full
        Y_all = Y_insect
        train_idx_combined = train_idx

    support_tr = Y_all[train_idx_combined].sum(0).astype(np.float32)
    print(f"  Training positives per label (median): {np.median(support_tr[support_tr>0]):.0f}")

    # ── Ablation or single run ─────────────────────────────────────────────────
    if args.ablation:
        print("\n[5/6] Ablation study: ESM-only vs ESM+seq vs ESM+all (pLDDT/PAE)")
        ablation_results = []

        # Feature slices
        # Model A: ESM only (320)
        X_a = X_all_full[:, :ESM_DIM]
        # Model B: ESM + sequence features (320 + 11 = 331)
        seq_end = ESM_DIM + len(SEQ_ONLY_COLS)
        X_b = X_all_full[:, :seq_end]
        # Model C: full features including pLDDT/PAE (360)
        X_c = X_all_full

        variants = [
            ("A_ESM_only", X_a, ESM_DIM),
            ("B_ESM_seq", X_b, seq_end),
            ("C_ESM_seq_AF", X_c, X_c.shape[1]),
        ]

        for label, X_variant, in_dim in variants:
            if label[0] < args.start_from:
                print(f"  Skipping {label} (--start_from {args.start_from})")
                continue
            ckpt_v = ART / "checkpoints" / f"ablation_{label}.pth"
            thresh_v = ART / "thresholds" / f"ablation_{label}_thresholds.json"
            log_v = ART / "logs" / f"ablation_{label}_log.json"
            result = run_training(
                args, X_variant, Y_all, train_idx_combined, val_idx, test_idx,
                support_tr, mu, sd, go_parents, mlb, in_dim,
                label, ckpt_v, thresh_v, log_v, device,
                eval_subsets=eval_subsets,
            )
            ablation_results.append(result)

        # Summary
        ablation_out = ART / "ablation_results.json"
        with open(ablation_out, "w") as f:
            json.dump(ablation_results, f, indent=2)

        print("\n" + "=" * 75)
        print("ABLATION RESULTS")
        print("=" * 75)
        print(f"{'Model':<20} {'Val Fmax':>10} {'Test Fmax':>10} {'Test AF':>10} {'Test CAFA':>10}")
        print("-" * 75)
        for r in ablation_results:
            af_fmax = r.get("test_subset_micro_fmax", {}).get("af_covered")
            print(f"{r['feature_label']:<20} {r['best_val_fmax']:>10.4f} "
                  f"{r['test_micro_fmax']:>10.4f} {_fmt_metric(af_fmax):>10} "
                  f"{_fmt_metric(r.get('test_cafa_fmax')):>10}")
        print("=" * 75)
        print(f"\nDetailed results → {ablation_out}")

        # Deploy the variant with the best VALIDATION score (selecting on the
        # test split would leak test information into model choice), and copy
        # it to the main output paths printed below.
        best = max(ablation_results, key=lambda r: r["best_val_fmax"])
        _copy_if_different(best["ckpt"], ckpt_out)
        _copy_if_different(best["thresh"], thresh_out)
        print(f"\nBest model (by val micro-Fmax): {best['feature_label']} "
              f"(val Fmax={best['best_val_fmax']:.4f}, test Fmax={best['test_micro_fmax']:.4f})")
        print(f"  Copied to {ckpt_out} and {thresh_out}")

    else:
        # Single run: explicit feature slice chosen by pipeline / user
        X_single = select_feature_matrix(X_all_full, args.feature_level)
        in_dim = X_single.shape[1]
        print(f"\n[5/6] Training single model ({args.feature_level}, in_dim={in_dim})...")
        run_training(
            args, X_single, Y_all, train_idx_combined, val_idx, test_idx,
            support_tr, mu, sd, go_parents, mlb, in_dim,
            args.feature_label, ckpt_out, thresh_out, log_out, device,
            eval_subsets=eval_subsets,
        )

    print(f"\n[6/6] Done.")
    print(f"  Main checkpoint → {ckpt_out}")
    print(f"  Thresholds      → {thresh_out}")
    print(f"\nTo deploy to HuggingFace:")
    print(f"  huggingface-cli upload Sbhat2026/protfunc-models {ckpt_out} protfunc_v3_fixed.pth")
    print(f"  huggingface-cli upload Sbhat2026/protfunc-models {thresh_out} protfunc_v3_fixed_thresholds.json")


if __name__ == "__main__":
    p = argparse.ArgumentParser(description="ProtFunc v3 (corrected label handling)")
    p.add_argument("--epochs", type=int, default=50,
                   help="Max training epochs (warm-start needs fewer than from-scratch)")
    p.add_argument("--hidden", type=int, default=2048)
    p.add_argument("--blocks", type=int, default=4)
    p.add_argument("--dropout", type=float, default=0.20)
    p.add_argument("--lr", type=float, default=5e-5,
                   help="Lower LR than from-scratch (warm-start already near optimum)")
    p.add_argument("--batch", type=int, default=512)
    p.add_argument("--patience", type=int, default=10)
    p.add_argument("--label_smooth", type=float, default=0.05)
    p.add_argument("--weight_decay", type=float, default=5e-4)
    p.add_argument("--skip-mammal", action="store_true")
    p.add_argument("--feature_level", type=str, default="esm_all",
                   choices=["esm_only", "esm_seq", "esm_all"],
                   help="Single-run feature slice (ignored when --ablation is used)")
    p.add_argument("--feature_label", type=str, default="v3_fixed",
                   help="Single-run feature label stored in checkpoint/log output")
    p.add_argument("--checkpoint_out", type=str, default="",
                   help="Optional explicit checkpoint output path")
    p.add_argument("--threshold_out", type=str, default="",
                   help="Optional explicit threshold output path")
    p.add_argument("--log_out", type=str, default="",
                   help="Optional explicit training log output path")
    p.add_argument("--ablation", action="store_true",
                   help="Run ESM-only / ESM+seq / ESM+all ablation study")
    p.add_argument("--start_from", type=str, default="A", choices=["A", "B", "C"],
                   help="Skip ablation variants before this model (implies --ablation)")
    p.add_argument("--num_workers", type=int, default=None,
                   help="DataLoader num_workers (default: 0 on MPS, 4 otherwise)")
    train(p.parse_args())