| """ |
train_v3_fixed.py: ProtFunc v3 (corrected training procedure)
| ================================================================ |
| Fixes the core methodological error in train_v2.py: |
| |
| WRONG (train_v2.py): propagate GO labels DURING TRAINING |
| CORRECT (this file): train on original labels; propagate ONLY during evaluation |
| |
| Why the original approach was wrong: |
| Propagating labels before training causes the model to directly predict broad |
| ancestor terms ("binding", "catalytic activity") for nearly every protein. |
This inflates val micro-Fmax from ~0.88 to ~0.97, an apples-to-oranges
| comparison that looks like a massive improvement but is mostly just the model |
| learning trivially predictable parent terms. Threshold calibration on propagated |
| ground truth then lets those broad terms fire at inference, producing hundreds |
| of predictions per protein. |
| |
| The correct approach (used by CAFA competitors): |
| 1. Train on experimental annotations as-is (specific terms only) |
| 2. Evaluate using CAFA-style propagation (predictions + ground-truth both |
   propagated upward); this is fair because a protein that does "ATP
| hydrolysis" also implicitly performs "binding" |
| |
Warm-start:
  Instead of rebuilding from scratch, we load the best available checkpoint
  (a previous ablation_C / v3_fixed run if one exists, else improved_res.pth,
  Fmax=0.8846) as the starting point. For the supplemented (360-dim) model, we
  copy all weights and extend fc_in with small random values for the new
  feature dims. This typically saves ~15-20 training epochs.
| |
| Ablation study for research question: |
| "How much predictive gain from AlphaFold structural features (pLDDT, PAE)?" |
| Run with --ablation to train three models in sequence: |
    A: ESM only (320 dim) - baseline
    B: ESM + sequence features (331 dim) - adds composition/physicochemical
    C: ESM + all features (360 dim) - adds pLDDT, PAE, AF confidence
| |
| Each model is evaluated on: |
| 1. All proteins |
| 2. AF-covered proteins only (where pLDDT/PAE are non-zero) |
| This isolates the structural feature contribution. |
| |
Outputs:
  artifacts/checkpoints/protfunc_v3_fixed.pth - best model (model C by default)
  artifacts/thresholds/protfunc_v3_fixed_thresholds.json
  artifacts/logs/protfunc_v3_fixed_log.json
  artifacts/ablation_results.json - written only when --ablation is used
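
Typical invocations (all flags are defined in the argparse block at the bottom
of this file):
  python train_v3_fixed.py --ablation               # full A/B/C ablation study
  python train_v3_fixed.py --start_from C           # resume the ablation at model C
  python train_v3_fixed.py --feature_level esm_all  # single 360-dim run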
| """ |
|
|
import ast, json, math, time, argparse, warnings, requests, threading
| from pathlib import Path |
| from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
| import numpy as np |
| import pandas as pd |
| import joblib |
| import torch |
| import torch.nn as nn |
| from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler |
|
|
| warnings.filterwarnings("ignore") |
|
|
| |
| BASE = Path(__file__).parent.parent |
| ART = BASE / "artifacts" |
| IMPORTANT = BASE / "Important Files" |
| ART.mkdir(exist_ok=True) |
|
|
| DATA_BASE = IMPORTANT / "merged_full_struct.parquet" |
| DATA_SUPP = IMPORTANT / "merged_full_struct_with_features.parquet" |
| MLB_PATH = IMPORTANT / "mlb_public_v1.pkl" |
| SPLITS_NPZ = ART / "splits" / "splits_n250000_seed42.npz" |
| MAMMAL_FASTA = IMPORTANT / "mammal_subset.fasta" |
| MAMMAL_EMB = ART / "generalization" / "mammal_embeddings_v3.parquet" |
| OBO_PATH = BASE / "go-basic.obo" |
| |
| _WARMSTART_CANDIDATES = [ |
| ART / "checkpoints" / "ablation_C_ESM_seq_AF.pth", |
| ART / "checkpoints" / "protfunc_v3_fixed.pth", |
| ART / "checkpoints" / "improved_res.pth", |
| ] |
| WARMSTART = next((p for p in _WARMSTART_CANDIDATES if p.exists()), _WARMSTART_CANDIDATES[-1]) |
|
|
| CKPT_OUT = ART / "checkpoints" / "protfunc_v3_fixed.pth" |
| THRESH_OUT = ART / "thresholds" / "protfunc_v3_fixed_thresholds.json" |
| LOG_OUT = ART / "logs" / "protfunc_v3_fixed_log.json" |
|
|
| |
| SUPP_COLS = [ |
| "f_seq_len", "f_mean_hydro", "f_net_charge", "f_uversky_disorder", |
| "f_idr_frac_proxy", "f_lowcomp_proxy", "f_tm_frac_proxy", "f_tm_any_proxy", |
| "f_signal_peptide_proxy", "f_cf_helix_mean", "f_cf_sheet_mean", |
| "f_afdb_has_model", |
| "f_plddt_mean", "f_plddt_std", "f_plddt_q10", "f_plddt_q50", "f_plddt_q90", |
| "f_plddt_frac_gt90", "f_plddt_frac_gt70", "f_plddt_frac_lt50", |
| "f_distbin_0", "f_distbin_1", "f_distbin_2", "f_distbin_3", "f_distbin_4", |
| "f_distbin_5", "f_distbin_6", "f_distbin_7", "f_distbin_8", "f_distbin_9", |
| "f_pae_mean", "f_pae_median", "f_pae_p90", "f_pae_p95", |
| "f_pae_frac_lt5", "f_pae_frac_lt10", "f_pae_frac_gt20", |
| "f_seqfeat_present", "f_af_present", |
| ] |
| SEQ_ONLY_COLS = [ |
| "f_seq_len", "f_mean_hydro", "f_net_charge", "f_uversky_disorder", |
| "f_idr_frac_proxy", "f_lowcomp_proxy", "f_tm_frac_proxy", "f_tm_any_proxy", |
| "f_signal_peptide_proxy", "f_cf_helix_mean", "f_cf_sheet_mean", |
| ] |
|
|
| |
| SEED = 42 |
| OUT_DIM = 8124 |
| ESM_DIM = 320 |
| SUPP_DIM = len(SUPP_COLS) |
| BATCH = 512 |
| PW_CLIP = (1.0, 100.0) |
| MIN_SUPPORT = 10 |
| FALLBACK_THRESH = 0.50 |
| API_TIMEOUT = 12 |
| API_THREADS = 20 |
| EXP_CODES = {"IDA","IMP","IPI","IGI","IEP","EXP","HDA","HMP","HGI","HEP","TAS","IC"} |
|
|
| np.random.seed(SEED) |
| torch.manual_seed(SEED) |
|
|
|
|
| |
| |
| |
|
|
| class ResBlock(nn.Module): |
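    """Pre-activation residual block: x + f(x), with f = (BN -> ReLU -> Dropout -> Linear) applied twice."""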
| def __init__(self, dim: int, dropout: float): |
| super().__init__() |
| self.net = nn.Sequential( |
| nn.BatchNorm1d(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, dim), |
| nn.BatchNorm1d(dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(dim, dim), |
| ) |
| def forward(self, x): |
| return x + self.net(x) |
|
|
|
|
| class ImprovedResidualMLP(nn.Module): |
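    """
    fc_in projects the feature vector to `hidden`; `n_blocks` ResBlocks refine it;
    fc_out emits `out_dim` logits (one per GO-MF label). Sigmoid is applied by the
    loss / evaluation code, not here.
    """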
| def __init__(self, in_dim, out_dim=OUT_DIM, hidden=2048, n_blocks=4, dropout=0.2): |
| super().__init__() |
| self.fc_in = nn.Linear(in_dim, hidden) |
| self.blocks = nn.ModuleList([ResBlock(hidden, dropout) for _ in range(n_blocks)]) |
| self.fc_out = nn.Sequential( |
| nn.BatchNorm1d(hidden), nn.ReLU(), nn.Dropout(dropout), |
| nn.Linear(hidden, out_dim), |
| ) |
| def forward(self, x): |
| h = self.fc_in(x) |
| for blk in self.blocks: |
| h = blk(h) |
| return self.fc_out(h) |
|
|
|
|
| |
| |
| |
|
|
| def warm_start(model: ImprovedResidualMLP, ckpt_path: Path, in_dim_new: int): |
| """ |
| Load improved_res.pth weights into model, extending fc_in if needed. |
| For in_dim_new > 320: copies the first 320 columns of fc_in.weight and |
| initialises the remaining (in_dim_new - 320) columns with small random |
| values so the new features start near-zero and don't disrupt representations. |
| All other layers (blocks, fc_out) are copied exactly. |
| """ |
| if not ckpt_path.exists(): |
| print(f" Warm-start skipped: {ckpt_path} not found (training from scratch)") |
| return |
| raw = torch.load(ckpt_path, map_location="cpu") |
| src = raw["model"] if isinstance(raw, dict) and "model" in raw else raw |
| dst = model.state_dict() |
|
|
| loaded, skipped = 0, 0 |
| for k, v in src.items(): |
| if k not in dst: |
| skipped += 1 |
| continue |
| if k == "fc_in.weight": |
| old_in = v.shape[1] |
| new_in = dst[k].shape[1] |
| if new_in == old_in: |
| dst[k] = v |
| elif new_in > old_in: |
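                # grow: keep the pretrained columns, give the new feature
                # columns a small random init so they start near-silent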
| |
| new_w = torch.zeros_like(dst[k]) |
| new_w[:, :old_in] = v |
| new_w[:, old_in:] = torch.randn(v.shape[0], new_in - old_in) * 0.01 |
| dst[k] = new_w |
| else: |
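                # shrink: keep only the first new_in columns (e.g. warm-starting
                # a 320-dim variant from a wider checkpoint)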
| |
| dst[k] = v[:, :new_in].clone() |
| else: |
| if dst[k].shape == v.shape: |
| dst[k] = v |
| else: |
| skipped += 1 |
| continue |
| loaded += 1 |
|
|
| model.load_state_dict(dst) |
| print(f" Warm-start from {ckpt_path.name}: {loaded} tensors loaded, {skipped} skipped") |
|
|
|
|
| |
| |
| |
|
|
| class EmbeddingDataset(Dataset): |
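    """Thin index-based view over a preassembled feature matrix X and multi-hot label matrix Y."""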
| def __init__(self, X, Y, indices): |
| self.X = X |
| self.Y = Y |
| self.idx = indices.astype(np.int64) |
| def __len__(self): |
| return len(self.idx) |
| def __getitem__(self, k): |
| i = int(self.idx[k]) |
| return self.X[i], self.Y[i].astype(np.float32) |
|
|
|
|
| |
| |
| |
|
|
| class SmoothBCEWithLogitsLoss(nn.Module): |
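    """BCEWithLogits on label-smoothed targets: t' = t*(1 - s) + (1 - t)*s."""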
| def __init__(self, pos_weight, smoothing=0.05): |
| super().__init__() |
| self.smooth = smoothing |
| self.bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction="mean") |
| def forward(self, logits, targets): |
| t = targets * (1 - self.smooth) + (1 - targets) * self.smooth |
| return self.bce(logits, t) |
|
|
|
|
| def lr_lambda(warmup, total): |
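    """LambdaLR factory: linear warmup for `warmup` epochs, then cosine decay over the remainder."""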
| def fn(ep): |
| if ep < warmup: |
| return (ep + 1) / warmup |
| p = (ep - warmup) / max(total - warmup, 1) |
| return 0.5 * (1 + math.cos(math.pi * p)) |
| return fn |
|
|
|
|
| |
| |
| |
| |
|
|
| @torch.no_grad() |
| def eval_micro_fmax(model, loader, device, step=0.02): |
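    """
    Micro-averaged Fmax without materialising the full (N x OUT_DIM) probability
    matrix: every probability is binned into a histogram (separately for positives
    and negatives); cumulative sums from the top bin then give TP/FP at every
    threshold edge, from which precision/recall/F1 are swept in a single pass.
    """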
| model.eval() |
| edges = np.arange(0.0, 1.0 + step, step) |
| nbins = len(edges) |
| hp = np.zeros(nbins, np.int64) |
| hn = np.zeros(nbins, np.int64) |
    n_pos = 0  # total ground-truth positives seen (TP + FN at any threshold)
    for Xb, Yb in loader:
        p = torch.sigmoid(model(Xb.to(device))).cpu().numpy().ravel()
        y = Yb.numpy().ravel() > 0.5
        n_pos += int(y.sum())
        bi = np.minimum(np.floor(p / step + 1e-9).astype(np.int64), nbins - 1)
        if y.any(): hp += np.bincount(bi[y], minlength=nbins)
        if (~y).any(): hn += np.bincount(bi[~y], minlength=nbins)
    # suffix sums: cum_tp[b] / cum_fp[b] = positives / negatives scoring >= edges[b]
    cum_tp = np.cumsum(hp[::-1])[::-1].astype(float)
    cum_fp = np.cumsum(hn[::-1])[::-1].astype(float)
    pred = cum_tp + cum_fp
    prec = np.where(pred > 0, cum_tp / pred, 0.0)
    rec = cum_tp / max(n_pos, 1)
| denom = prec + rec |
| f1 = np.where(denom > 0, 2 * prec * rec / denom, 0.0) |
| b = int(np.argmax(f1)) |
| return {"micro_fmax": float(f1[b]), "t_star": float(edges[b]), |
| "precision": float(prec[b]), "recall": float(rec[b])} |
|
|
|
|
| |
| |
| |
| |
|
|
| @torch.no_grad() |
| def eval_cafa_fmax(model, loader, device, mlb_classes, go_parents, step=0.05): |
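    """
    Protein-centric CAFA-style Fmax. Ground truth and thresholded predictions are
    both propagated to their GO ancestors (restricted to the model's label space),
    then per-protein precision (averaged over proteins with >=1 prediction) and
    recall (averaged over all labelled proteins) are swept over thresholds.
    """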
| if not go_parents: |
| return {"cafa_fmax": float("nan"), "t_star": float("nan")} |
| go2idx = {g: i for i, g in enumerate(mlb_classes)} |
| anc_map = {} |
| for gid in mlb_classes: |
| parents = go_parents.get(gid, set()) |
| visited, stack = set(), list(parents) |
| while stack: |
| p = stack.pop() |
| if p not in visited: |
| visited.add(p) |
| stack.extend(go_parents.get(p, set())) |
| anc_map[gid] = {p for p in visited if p in go2idx} |
|
|
| all_probs, all_true = [], [] |
| for Xb, Yb in loader: |
| all_probs.append(torch.sigmoid(model(Xb.to(device))).cpu().numpy()) |
| all_true.append(Yb.numpy()) |
| probs = np.concatenate(all_probs, axis=0).astype(np.float32) |
| true = np.concatenate(all_true, axis=0).astype(np.float32) |
|
|
| has_label = true.sum(axis=1) > 0 |
| probs = probs[has_label] |
| true = true[has_label] |
|
|
| N = len(probs) |
| if N == 0: |
| return {"cafa_fmax": float("nan"), "t_star": float("nan")} |
|
|
    anc_idx = [
        np.array([go2idx[a] for a in anc_map.get(g, set())], dtype=np.int64)
        for g in mlb_classes
    ]

    # Propagate the ground truth upward once (threshold-independent) so that
    # predictions and ground truth are compared in the same propagated space,
    # per the protocol described in the module docstring.
    true_prop = true.copy()
    for j, aidx in enumerate(anc_idx):
        mask = true[:, j] > 0
        if len(aidx) and mask.any():
            true_prop[np.ix_(np.where(mask)[0], aidx)] = 1.0
    true = true_prop

    thresholds = np.arange(0.05, 0.96, step)
    best_f1, best_t = -1.0, 0.5
|
|
| for t in thresholds: |
| pred_bin = (probs >= t).astype(np.float32) |
| prop = pred_bin.copy() |
| for j, aidx in enumerate(anc_idx): |
| if len(aidx) == 0: |
| continue |
| mask = pred_bin[:, j] > 0 |
| if mask.any(): |
| prop[np.ix_(np.where(mask)[0], aidx)] = 1.0 |
|
|
| tp_per = (prop * true).sum(axis=1) |
| pp_per = prop.sum(axis=1) |
| rp_per = true.sum(axis=1) |
| prec_per = np.where(pp_per > 0, tp_per / pp_per, 0.0) |
| rec_per = np.where(rp_per > 0, tp_per / rp_per, 0.0) |
| has_pred = pp_per > 0 |
| if has_pred.sum() == 0: |
| continue |
| avg_prec = prec_per[has_pred].mean() |
| avg_rec = rec_per.mean() |
| denom = avg_prec + avg_rec |
| f1 = (2 * avg_prec * avg_rec / denom) if denom > 0 else 0.0 |
| if f1 > best_f1: |
| best_f1, best_t = f1, float(t) |
|
|
| return {"cafa_fmax": round(float(best_f1), 4), "t_star": round(best_t, 3)} |
|
|
|
|
| |
| |
| |
|
|
| @torch.no_grad() |
| def compute_per_label_thresholds(model, loader, device, support_tr, |
| min_support=MIN_SUPPORT, fallback=FALLBACK_THRESH): |
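    """
    Per-label decision thresholds calibrated on the validation loader: for each
    label with at least `min_support` training positives, sweep t over
    [0.10, 0.95] and keep the F1-maximising value; rarer labels keep `fallback`.
    """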
| model.eval() |
| all_p, all_y = [], [] |
| for Xb, Yb in loader: |
| all_p.append(torch.sigmoid(model(Xb.to(device))).cpu()) |
| all_y.append(Yb) |
| probs = torch.cat(all_p).numpy().astype(np.float32) |
| true = torch.cat(all_y).numpy().astype(np.float32) |
|
|
| steps = np.arange(0.10, 0.96, 0.05, dtype=np.float32) |
| thr = np.full(OUT_DIM, fallback, dtype=np.float32) |
| for j in range(OUT_DIM): |
| if int(support_tr[j]) < min_support: |
| continue |
| pj, tj = probs[:, j], true[:, j] |
| best_f1, best_t = -1.0, fallback |
| for t in steps: |
| pred = (pj >= t).astype(np.float32) |
| tp = float((pred * tj).sum()) |
| fp = float((pred * (1 - tj)).sum()) |
| fn = float(((1 - pred) * tj).sum()) |
| d = 2 * tp + fp + fn |
| f1 = (2 * tp / d) if d > 0 else 0.0 |
| if f1 > best_f1: |
| best_f1, best_t = f1, float(t) |
| thr[j] = best_t |
| return thr |
|
|
|
|
| |
| |
| |
|
|
| def load_go_parents(obo_path: Path) -> dict: |
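    """
    Minimal go-basic.obo parser. Collects is_a and part_of parents per term,
    drops obsolete terms, and returns {GO id -> set of parents} restricted to
    the molecular_function namespace.
    """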
| if not obo_path.exists(): |
| print(f" WARNING: {obo_path} not found β CAFA eval disabled") |
| return {} |
| ns_map, par_map = {}, {} |
| cur_id, cur_ns, cur_par, in_term = None, None, set(), False |
|
|
| def flush(): |
| nonlocal cur_id, cur_ns, cur_par |
| if cur_id and cur_ns: |
| ns_map[cur_id] = cur_ns |
| par_map[cur_id] = cur_par |
| cur_id, cur_ns, cur_par = None, None, set() |
|
|
| with open(obo_path, encoding="utf-8") as fh: |
| for raw in fh: |
| line = raw.strip() |
| if line == "[Term]": |
| flush(); in_term = True; continue |
| if line.startswith("[") and line != "[Term]": |
| flush(); in_term = False; continue |
| if not in_term: continue |
| if line.startswith("id:"): |
| cur_id = line.split("id:", 1)[1].strip().split()[0] |
| elif line.startswith("namespace:"): |
| cur_ns = line.split("namespace:", 1)[1].strip() |
| elif line.startswith("is_obsolete:") and "true" in line: |
| cur_id = None |
| elif line.startswith("is_a:"): |
| cur_par.add(line.split("is_a:", 1)[1].strip().split()[0]) |
| elif line.startswith("relationship:"): |
| pts = line.split("relationship:", 1)[1].strip().split() |
| if len(pts) >= 2 and pts[0] == "part_of": |
| cur_par.add(pts[1]) |
| flush() |
| mf = {g for g, n in ns_map.items() if n == "molecular_function"} |
| return {g: (par_map[g] & mf) for g in mf} |
|
|
|
|
| |
| |
| |
|
|
| def parse_labels(x): |
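    """Coerce a Label_Indices cell (list, ndarray, stringified list, or NaN) to a list of ints."""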
| if x is None: return [] |
| if isinstance(x, (list, np.ndarray)): return [int(v) for v in x] |
| if isinstance(x, str): |
| s = x.strip() |
| if not s or s.lower() == "nan": return [] |
| try: |
| v = ast.literal_eval(s) |
| return [int(i) for i in (v if isinstance(v, (list, tuple)) else [v])] |
| except Exception: |
| return [] |
| return [] |
|
|
|
|
| def select_feature_matrix(X_all_full: np.ndarray, feature_level: str) -> np.ndarray: |
| """ |
| Match the feature slices used in HPO: |
| esm_only -> 320 |
| esm_seq -> 331 |
| esm_all -> 360 |
| """ |
| if feature_level == "esm_only": |
| return X_all_full[:, :ESM_DIM] |
| if feature_level == "esm_seq": |
| seq_end = ESM_DIM + len(SEQ_ONLY_COLS) |
| return X_all_full[:, :seq_end] |
| if feature_level == "esm_all": |
| return X_all_full |
| raise ValueError(f"Unknown feature_level: {feature_level}") |
|
|
|
|
| |
| |
| |
|
|
| def parse_fasta(path): |
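    """Yield (id, sequence) pairs; the id is the first whitespace-delimited token of the header."""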
| header, seq = None, [] |
| with open(path) as fh: |
| for line in fh: |
| line = line.strip() |
| if line.startswith(">"): |
| if header is not None: |
| yield header, "".join(seq) |
| header = line[1:].split()[0] |
| seq = [] |
| else: |
| seq.append(line) |
| if header is not None: |
| yield header, "".join(seq) |
|
|
|
|
| def fetch_go_mf(uniprot_id, exp_codes, timeout): |
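    """Fetch GO molecular-function terms with experimental evidence codes for one UniProt accession via the UniProt REST API; returns [] on any failure."""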
| url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.json?fields=go_f" |
| try: |
| r = requests.get(url, timeout=timeout) |
| if r.status_code != 200: |
| return [] |
| refs = r.json().get("uniProtKBCrossReferences", []) |
| terms = [] |
| for ref in refs: |
| if ref.get("database") != "GO": |
| continue |
| go_id = ref.get("id", "") |
| props = {p.get("key"): p.get("value") for p in ref.get("properties", [])} |
| aspect = props.get("GoTerm", "") |
| evidence = props.get("GoEvidenceType", "") |
| if aspect.startswith("F:") and evidence.split(":")[0] in exp_codes: |
| terms.append(go_id) |
| return terms |
| except Exception: |
| return [] |
|
|
|
|
| def build_mammal_dataset(fasta_path, mlb, device, max_seq_len=2500): |
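    """
    Build the mammal generalization set: parse the FASTA, fetch experimental
    GO-MF labels from UniProt, and mean-pool per-residue ESM-2 (t6-8M, layer 6)
    representations under a token budget. Returns a DataFrame of Dim_* columns,
    Label_Indices, and f_af_present=0 (no AlphaFold features for this set).
    """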
| print(" Loading ESM-2 for mammal embedding...") |
| import esm as esm_lib |
| esm_model, alphabet = esm_lib.pretrained.esm2_t6_8M_UR50D() |
| esm_model = esm_model.to(device).eval() |
| bc = alphabet.get_batch_converter() |
|
|
| entries = [(h, s) for h, s in parse_fasta(fasta_path) if 30 <= len(s) <= max_seq_len] |
| print(f" Parsed {len(entries)} mammal sequences") |
|
|
| print(f" Fetching GO annotations from UniProt ({API_THREADS} threads)...") |
| go_labels, lock = {}, threading.Lock() |
|
|
| def fetch_one(hdr_seq): |
| hdr, _ = hdr_seq |
| uid = hdr.split("|")[1] if "|" in hdr else hdr.split()[0] |
| terms = fetch_go_mf(uid, EXP_CODES, API_TIMEOUT) |
| with lock: |
| go_labels[hdr] = terms |
|
|
| with ThreadPoolExecutor(max_workers=API_THREADS) as ex: |
| futures = {ex.submit(fetch_one, e): e for e in entries} |
        for i, fut in enumerate(as_completed(futures)):
            fut.result()  # re-raise any worker exception instead of dropping it silently
            if (i + 1) % 100 == 0:
                print(f"    {i+1}/{len(entries)} annotations fetched")
|
|
| go2idx = {g: i for i, g in enumerate(mlb.classes_)} |
| labeled = [(h, s, [go2idx[t] for t in go_labels.get(h, []) if t in go2idx]) |
| for h, s in entries] |
| labeled = [(h, s, l) for h, s, l in labeled if len(l) > 0] |
| print(f" {len(labeled)} mammal sequences with β₯1 GO-MF label (experimental evidence)") |
|
|
| print(f" Embedding {len(labeled)} sequences with ESM-2...") |
| TOKEN_BUDGET = 6000 |
| rows = [] |
| batch_buf = [] |
| total_tok = 0 |
|
|
| def flush_batch(buf): |
        batch_data = [(h, s) for h, s, _ in buf]
| _, _, toks = bc(batch_data) |
| with torch.no_grad(): |
| rep = esm_model(toks.to(device), repr_layers=[6])["representations"][6] |
| for k, (h, s, labs) in enumerate(buf): |
| emb = rep[k, 1:len(s)+1].mean(0).cpu().numpy() |
| row = {f"Dim_{i}": float(emb[i]) for i in range(ESM_DIM)} |
| row["Label_Indices"] = labs |
| row["f_af_present"] = 0.0 |
| rows.append(row) |
|
|
| for h, s, labs in labeled: |
| ntok = len(s) + 2 |
| if total_tok + ntok > TOKEN_BUDGET and batch_buf: |
| flush_batch(batch_buf) |
| batch_buf, total_tok = [], 0 |
| batch_buf.append((h, s, labs)) |
| total_tok += ntok |
| if batch_buf: |
| flush_batch(batch_buf) |
|
|
| return pd.DataFrame(rows) |
|
|
|
|
| |
| |
| |
|
|
| def run_training(args, X_all, Y_all, train_idx, val_idx, test_idx, |
| support_tr, mu, sd, go_parents, mlb, in_dim, |
| feature_label, ckpt_out, thresh_out, log_out, |
| device): |
| """ |
| Train one model variant. Returns dict with best val Fmax, test metrics, etc. |
| """ |
| ds_tr = EmbeddingDataset(X_all, Y_all, train_idx) |
| ds_va = EmbeddingDataset(X_all, Y_all, val_idx) |
| ds_te = EmbeddingDataset(X_all, Y_all, test_idx) |
|
|
    # oversample proteins carrying rare labels: each protein's weight is the
    # mean over its positive labels of 1/sqrt(train support)
    inv_sqrt = 1.0 / np.sqrt(np.maximum(support_tr, 1.0))
    w_train = np.array([
        float(np.mean(inv_sqrt[Y_all[i].astype(bool)])) if Y_all[i].any() else 1.0
        for i in train_idx
    ], dtype=np.float32)
    sampler = WeightedRandomSampler(torch.as_tensor(w_train), len(w_train), replacement=True)
|
|
| ld_tr = DataLoader(ds_tr, args.batch, sampler=sampler, num_workers=args.num_workers) |
| ld_va = DataLoader(ds_va, args.batch, shuffle=False, num_workers=args.num_workers) |
| ld_te = DataLoader(ds_te, args.batch, shuffle=False, num_workers=args.num_workers) |
|
|
    # per-label BCE pos_weight ~= (#negatives / #positives), clipped to PW_CLIP
    pw_np = np.clip((len(train_idx) - support_tr) / np.maximum(support_tr, 1.0), *PW_CLIP)
    pos_weight = torch.tensor(pw_np, dtype=torch.float32).to(device)
|
|
| model = ImprovedResidualMLP( |
| in_dim=in_dim, out_dim=OUT_DIM, |
| hidden=args.hidden, n_blocks=args.blocks, dropout=args.dropout, |
| ).to(device) |
| n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
| print(f"\n [{feature_label}] Model: {n_params:,} params (in_dim={in_dim})") |
|
|
| |
| warm_start(model, WARMSTART, in_dim) |
|
|
| criterion = SmoothBCEWithLogitsLoss(pos_weight, smoothing=args.label_smooth) |
| optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, |
| weight_decay=args.weight_decay, eps=1e-7) |
| warmup = max(1, int(args.epochs * 0.08)) |
| scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda(warmup, args.epochs)) |
|
|
| print(f" [{feature_label}] Training (epochs={args.epochs}, patience={args.patience})") |
| best_fmax, no_improve = -1.0, 0 |
| log = [] |
|
|
| for ep in range(1, args.epochs + 1): |
| t0 = time.time() |
| model.train() |
| running = 0.0 |
| for Xb, Yb in ld_tr: |
| Xb, Yb = Xb.to(device), Yb.to(device) |
| optimizer.zero_grad() |
| with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16): |
| loss = criterion(model(Xb), Yb) |
| loss.backward() |
| nn.utils.clip_grad_norm_(model.parameters(), 1.0) |
| optimizer.step() |
| running += loss.item() |
| scheduler.step() |
|
|
| train_loss = running / len(ld_tr) |
| mi = eval_micro_fmax(model, ld_va, device) |
| fmax = mi["micro_fmax"] |
|
|
| cafa_fmax = None |
| if go_parents and (ep % 5 == 0 or ep == args.epochs): |
| ci = eval_cafa_fmax(model, ld_va, device, mlb.classes_, go_parents) |
| cafa_fmax = ci["cafa_fmax"] |
|
|
| elapsed = time.time() - t0 |
| lr_now = scheduler.get_last_lr()[0] |
| entry = { |
| "feature_label": feature_label, "epoch": ep, |
| "train_loss": round(train_loss, 6), "val_micro_fmax": round(fmax, 4), |
| "val_t_star": round(mi["t_star"], 3), |
| "val_prec": round(mi["precision"], 4), "val_rec": round(mi["recall"], 4), |
| "val_cafa_fmax": round(cafa_fmax, 4) if cafa_fmax is not None else None, |
| "lr": round(lr_now, 7), "elapsed_s": round(elapsed, 1), |
| } |
| log.append(entry) |
| with open(log_out, "w") as f: |
| json.dump(log, f, indent=2) |
|
|
| cafa_str = f" CAFA={cafa_fmax:.4f}" if cafa_fmax is not None else "" |
| print( |
| f" Ep {ep:3d}/{args.epochs} | loss={train_loss:.4f} | " |
| f"micro-fmax={fmax:.4f} @t={mi['t_star']:.2f} " |
| f"P={mi['precision']:.3f} R={mi['recall']:.3f}{cafa_str} | " |
| f"lr={lr_now:.2e} | {elapsed:.0f}s" |
| ) |
|
|
| if fmax > best_fmax: |
| best_fmax = fmax |
| no_improve = 0 |
| torch.save({ |
| "model": model.state_dict(), |
| "epoch": ep, |
| "val_fmax": fmax, |
| "feature_label": feature_label, |
| "in_dim": in_dim, |
| "hidden": args.hidden, |
| "n_blocks": args.blocks, |
| "supp_mu": mu.tolist(), |
| "supp_sd": sd.tolist(), |
| "supp_cols": SUPP_COLS, |
| }, ckpt_out) |
| print(f" β Best micro-fmax={fmax:.4f} β saved") |
| else: |
| no_improve += 1 |
| if no_improve >= args.patience: |
| print(f"\n Early stopping at epoch {ep}") |
| break |
|
|
| |
| print(f"\n [{feature_label}] Test evaluation...") |
| ckpt = torch.load(ckpt_out, map_location=device) |
| model.load_state_dict(ckpt["model"]) |
| test_micro = eval_micro_fmax(model, ld_te, device) |
| test_cafa = eval_cafa_fmax(model, ld_te, device, mlb.classes_, go_parents) if go_parents else {} |
| print(f" Test micro-Fmax={test_micro['micro_fmax']:.4f} " |
| f"P={test_micro['precision']:.4f} R={test_micro['recall']:.4f}") |
| if test_cafa: |
| print(f" Test CAFA-Fmax={test_cafa.get('cafa_fmax', 'N/A')}") |
|
|
| |
| print(f" [{feature_label}] Computing per-label thresholds...") |
| thr = compute_per_label_thresholds(model, ld_va, device, support_tr) |
| with open(thresh_out, "w") as f: |
| json.dump({str(i): float(thr[i]) for i in range(OUT_DIM)}, f) |
|
|
| log.append({"test_micro": test_micro, "test_cafa": test_cafa}) |
| with open(log_out, "w") as f: |
| json.dump(log, f, indent=2) |
|
|
| return { |
| "feature_label": feature_label, |
| "in_dim": in_dim, |
| "best_val_fmax": best_fmax, |
| "test_micro_fmax": test_micro["micro_fmax"], |
| "test_cafa_fmax": test_cafa.get("cafa_fmax"), |
| "ckpt": str(ckpt_out), |
| "thresh": str(thresh_out), |
| } |
|
|
|
|
| |
| |
| |
|
|
| def train(args): |
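    """
    End-to-end pipeline: load insect data and splits, build the label matrix
    WITHOUT GO propagation, z-score supplementary features on train statistics,
    optionally fold the mammal set into the training split, then run either the
    A/B/C ablation or a single model at --feature_level.
    """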
| device = torch.device( |
| "mps" if torch.backends.mps.is_available() else |
| "cuda" if torch.cuda.is_available() else "cpu" |
| ) |
| print(f"Device: {device}") |
|
|
| |
| if device.type == "mps": |
| torch.mps.set_per_process_memory_fraction(0.95) |
|
|
| |
| if args.start_from != "A": |
| args.ablation = True |
|
|
| ckpt_out = Path(args.checkpoint_out) if args.checkpoint_out else CKPT_OUT |
| thresh_out = Path(args.threshold_out) if args.threshold_out else THRESH_OUT |
| log_out = Path(args.log_out) if args.log_out else LOG_OUT |
| ckpt_out.parent.mkdir(parents=True, exist_ok=True) |
| thresh_out.parent.mkdir(parents=True, exist_ok=True) |
| log_out.parent.mkdir(parents=True, exist_ok=True) |
|
|
| |
| if args.num_workers is None: |
| args.num_workers = 0 if device.type == "mps" else 4 |
| print(f"DataLoader num_workers: {args.num_workers}") |
|
|
| |
| print("\n[1/6] Loading insect dataset...") |
| df_base = pd.read_parquet(DATA_BASE) |
| df_supp = pd.read_parquet(DATA_SUPP) |
| emb_cols = [c for c in df_base.columns if c.startswith("Dim_")] |
| assert len(emb_cols) == ESM_DIM |
|
|
| mlb = joblib.load(MLB_PATH) |
| assert len(mlb.classes_) == OUT_DIM |
|
|
| |
| print("[1b/6] Loading GO hierarchy (for CAFA-style evaluation only)...") |
| go_parents = load_go_parents(OBO_PATH) |
| print(f" GO parents loaded: {len(go_parents)} MF terms") |
|
|
| |
| print("[2/6] Building insect label matrix (original annotations, no propagation)...") |
| label_lists = [parse_labels(x) for x in df_base["Label_Indices"]] |
| Y_insect = np.zeros((len(df_base), OUT_DIM), dtype=np.uint8) |
| for r, labs in enumerate(label_lists): |
| for j in labs: |
| if 0 <= j < OUT_DIM: |
| Y_insect[r, j] = 1 |
| print(f" Insect label matrix: {int(Y_insect.sum()):,} positives across {len(Y_insect):,} proteins") |
| print(f" (No GO propagation applied β model will learn specific terms as annotated)") |
|
|
| |
    # supplementary features: z-scored block (SUPP_DIM = 39 cols, includes
    # f_af_present) plus the raw f_af_present flag appended separately,
    # giving 320 + 39 + 1 = 360 dims total
    S_raw = df_supp[SUPP_COLS].to_numpy(np.float32)
    m_flag = df_supp["f_af_present"].to_numpy(np.float32).reshape(-1, 1)
    X_base = df_base[emb_cols].to_numpy(np.float32)
|
|
| |
| splits = np.load(SPLITS_NPZ) |
| train_idx = splits["train_idx"] |
| val_idx = splits["val_idx"] |
| test_idx = splits["test_idx"] |
| print(f" Insect splits β train:{len(train_idx)} val:{len(val_idx)} test:{len(test_idx)}") |
|
|
| |
| S_tr = S_raw[train_idx] |
| mu = np.nanmean(S_tr, axis=0) |
| sd = np.where(np.nanstd(S_tr, axis=0) > 0, np.nanstd(S_tr, axis=0), 1.0) |
    S_z = (S_raw - mu) / (sd + 1e-12)
    S_z = np.where(np.isnan(S_z), 0.0, S_z)  # mean-impute: any NaN feature z-scores to 0
| X_insect_full = np.concatenate([X_base, S_z, m_flag], axis=1).astype(np.float32) |
| print(f" Insect X shape (full features): {X_insect_full.shape}") |
|
|
| |
| X_mammal, Y_mammal = None, None |
| if not args.skip_mammal and MAMMAL_FASTA.exists(): |
| print("\n[3/6] Processing mammal data...") |
| if MAMMAL_EMB.exists(): |
| print(" Loading cached mammal embeddings...") |
| df_m = pd.read_parquet(MAMMAL_EMB) |
| else: |
| print(" Computing mammal embeddings (first run)...") |
| df_m = build_mammal_dataset(MAMMAL_FASTA, mlb, device) |
| df_m.to_parquet(MAMMAL_EMB, index=False) |
|
|
| m_emb_cols = [f"Dim_{i}" for i in range(ESM_DIM)] |
| X_m_base = df_m[m_emb_cols].to_numpy(np.float32) |
|
|
| |
| Y_m = np.zeros((len(df_m), OUT_DIM), dtype=np.uint8) |
| for r, labs in enumerate(df_m["Label_Indices"].tolist()): |
| for j in parse_labels(labs): |
| if 0 <= j < OUT_DIM: |
| Y_m[r, j] = 1 |
| print(f" Mammal: {int(Y_m.sum()):,} positives (original annotations, no propagation)") |
|
|
        # df_m only carries ESM dims, labels, and f_af_present; initialise every
        # supplementary column at the insect train mean so missing features
        # z-score to 0 (neutral) below, rather than to -mu/sd
        S_m = np.tile(mu.astype(np.float32), (len(df_m), 1))
        for ci, col in enumerate(SUPP_COLS):
            if col in df_m.columns:
                v = df_m[col].to_numpy(np.float32)
                S_m[:, ci] = np.where(np.isnan(v), mu[ci], v)
| S_m_z = (S_m - mu) / (sd + 1e-12) |
| m_m = df_m["f_af_present"].to_numpy(np.float32).reshape(-1, 1) if "f_af_present" in df_m else np.zeros((len(df_m), 1), np.float32) |
| X_mammal = np.concatenate([X_m_base, S_m_z, m_m], axis=1).astype(np.float32) |
| Y_mammal = Y_m |
| print(f" Mammal X shape: {X_mammal.shape}") |
| else: |
| print("\n[3/6] Skipping mammal data") |
|
|
| |
| print("\n[4/6] Merging datasets...") |
| if X_mammal is not None: |
| X_all_full = np.concatenate([X_insect_full, X_mammal], axis=0) |
| Y_all = np.concatenate([Y_insect, Y_mammal], axis=0) |
| mammal_idx = np.arange(len(X_insect_full), len(X_all_full)) |
| train_idx_combined = np.concatenate([train_idx, mammal_idx]) |
| print(f" Combined: {len(X_all_full)} rows (insect+mammal train: {len(train_idx_combined)})") |
| else: |
| X_all_full = X_insect_full |
| Y_all = Y_insect |
| train_idx_combined = train_idx |
|
|
| support_tr = Y_all[train_idx_combined].sum(0).astype(np.float32) |
| print(f" Training positives per label (median): {np.median(support_tr[support_tr>0]):.0f}") |
|
|
| |
| if args.ablation: |
| print("\n[5/6] Ablation study: ESM-only vs ESM+seq vs ESM+all (pLDDT/PAE)") |
| ablation_results = [] |
|
|
| |
| |
        # reuse the single-run slicing helper so the two code paths cannot drift
        X_a = select_feature_matrix(X_all_full, "esm_only")
        seq_end = ESM_DIM + len(SEQ_ONLY_COLS)
        X_b = select_feature_matrix(X_all_full, "esm_seq")
        X_c = select_feature_matrix(X_all_full, "esm_all")
|
|
| variants = [ |
| ("A_ESM_only", X_a, ESM_DIM), |
| ("B_ESM_seq", X_b, seq_end), |
| ("C_ESM_seq_AF", X_c, X_c.shape[1]), |
| ] |
|
|
| for label, X_variant, in_dim in variants: |
| if label[0] < args.start_from: |
| print(f" Skipping {label} (--start_from {args.start_from})") |
| continue |
| ckpt_v = ART / "checkpoints" / f"ablation_{label}.pth" |
| thresh_v = ART / "thresholds" / f"ablation_{label}_thresholds.json" |
| log_v = ART / "logs" / f"ablation_{label}_log.json" |
| result = run_training( |
| args, X_variant, Y_all, train_idx_combined, val_idx, test_idx, |
| support_tr, mu, sd, go_parents, mlb, in_dim, |
| label, ckpt_v, thresh_v, log_v, device, |
| ) |
| ablation_results.append(result) |
|
|
| |
| ablation_out = ART / "ablation_results.json" |
| with open(ablation_out, "w") as f: |
| json.dump(ablation_results, f, indent=2) |
|
|
| print("\n" + "=" * 75) |
| print("ABLATION RESULTS") |
| print("=" * 75) |
| print(f"{'Model':<20} {'Val Fmax':>10} {'Test Fmax':>10} {'Test CAFA':>10}") |
| print("-" * 60) |
| for r in ablation_results: |
| cafa = f"{r['test_cafa_fmax']:.4f}" if r.get("test_cafa_fmax") else "N/A" |
| print(f"{r['feature_label']:<20} {r['best_val_fmax']:>10.4f} {r['test_micro_fmax']:>10.4f} {cafa:>10}") |
| print("=" * 75) |
| print(f"\nDetailed results β {ablation_out}") |
|
|
| |
| best = max(ablation_results, key=lambda r: r["test_micro_fmax"]) |
| print(f"\nBest model: {best['feature_label']} (test Fmax={best['test_micro_fmax']:.4f})") |
|
|
| else: |
| |
| X_single = select_feature_matrix(X_all_full, args.feature_level) |
| in_dim = X_single.shape[1] |
| print(f"\n[5/6] Training single model ({args.feature_level}, in_dim={in_dim})...") |
| run_training( |
| args, X_single, Y_all, train_idx_combined, val_idx, test_idx, |
| support_tr, mu, sd, go_parents, mlb, in_dim, |
| args.feature_label, ckpt_out, thresh_out, log_out, device, |
| ) |
|
|
| print(f"\n[6/6] Done.") |
| print(f" Main checkpoint β {ckpt_out}") |
| print(f" Thresholds β {thresh_out}") |
| print(f"\nTo deploy to HuggingFace:") |
| print(f" huggingface-cli upload Sbhat2026/protfunc-models {ckpt_out} protfunc_v3_fixed.pth") |
| print(f" huggingface-cli upload Sbhat2026/protfunc-models {thresh_out} protfunc_v3_fixed_thresholds.json") |
|
|
|
|
| if __name__ == "__main__": |
| p = argparse.ArgumentParser(description="ProtFunc v3 (corrected label handling)") |
| p.add_argument("--epochs", type=int, default=50, |
| help="Max training epochs (warm-start needs fewer than from-scratch)") |
| p.add_argument("--hidden", type=int, default=2048) |
| p.add_argument("--blocks", type=int, default=4) |
| p.add_argument("--dropout", type=float, default=0.20) |
| p.add_argument("--lr", type=float, default=5e-5, |
| help="Lower LR than from-scratch (warm-start already near optimum)") |
| p.add_argument("--batch", type=int, default=512) |
| p.add_argument("--patience", type=int, default=10) |
| p.add_argument("--label_smooth", type=float, default=0.05) |
| p.add_argument("--weight_decay", type=float, default=5e-4) |
| p.add_argument("--skip-mammal", action="store_true") |
| p.add_argument("--feature_level", type=str, default="esm_all", |
| choices=["esm_only", "esm_seq", "esm_all"], |
| help="Single-run feature slice (ignored when --ablation is used)") |
| p.add_argument("--feature_label", type=str, default="v3_fixed", |
| help="Single-run feature label stored in checkpoint/log output") |
| p.add_argument("--checkpoint_out", type=str, default="", |
| help="Optional explicit checkpoint output path") |
| p.add_argument("--threshold_out", type=str, default="", |
| help="Optional explicit threshold output path") |
| p.add_argument("--log_out", type=str, default="", |
| help="Optional explicit training log output path") |
| p.add_argument("--ablation", action="store_true", |
| help="Run ESM-only / ESM+seq / ESM+all ablation study") |
| p.add_argument("--start_from", type=str, default="A", choices=["A", "B", "C"], |
| help="Skip ablation variants before this model (implies --ablation)") |
| p.add_argument("--num_workers", type=int, default=None, |
| help="DataLoader num_workers (default: 0 on MPS, 4 otherwise)") |
| train(p.parse_args()) |
|
|