# protfunc / scripts / train_v3_fixed.py
# Author: Sbhat2026
# Commit 7f7a890 — "perf: ESM embedding cache + 1500aa limit, add research scripts"
"""
train_v3_fixed.py — ProtFunc v3 (corrected training procedure)
================================================================
Fixes the core methodological error in train_v2.py:
WRONG (train_v2.py): propagate GO labels DURING TRAINING
CORRECT (this file): train on original labels; propagate ONLY during evaluation
Why the original approach was wrong:
Propagating labels before training causes the model to directly predict broad
ancestor terms ("binding", "catalytic activity") for nearly every protein.
This inflates val micro-Fmax from ~0.88 to ~0.97 — an apples-to-oranges
comparison that looks like a massive improvement but is mostly just the model
learning trivially predictable parent terms. Threshold calibration on propagated
ground truth then lets those broad terms fire at inference, producing hundreds
of predictions per protein.
The correct approach (used by CAFA competitors):
1. Train on experimental annotations as-is (specific terms only)
2. Evaluate using CAFA-style propagation (predictions + ground-truth both
propagated upward) β€” this is fair because a protein that does "ATP
hydrolysis" also implicitly performs "binding"
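Warm-start:
Instead of rebuilding from scratch, we initialise from the first checkpoint that
exists among artifacts/checkpoints/ablation_C_ESM_seq_AF.pth,
artifacts/checkpoints/protfunc_v3_fixed.pth and
artifacts/checkpoints/improved_res.pth (Fmax=0.8846). When the model's input is
wider than the checkpoint's (e.g. the supplemented 360-dim model from the 320-dim
improved_res.pth), we copy all weights and extend fc_in with small random values
for the new feature dims; a narrower model keeps the first fc_in columns.
This typically saves ~15-20 training epochs.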
Ablation study for research question:
"How much predictive gain from AlphaFold structural features (pLDDT, PAE)?"
Run with --ablation to train three models in sequence:
A: ESM only (320 dim) — baseline
B: ESM + sequence features (331 dim) — adds composition/physicochemical
C: ESM + all features (360 dim) — adds pLDDT, PAE, AF confidence
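Each model is evaluated on:
1. All proteins
2. AF-covered proteins only (where pLDDT/PAE are non-zero)
This isolates the structural feature contribution.
(Note: this script currently reports only (1); the AF-covered subset
evaluation (2) is not implemented here.)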
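Outputs (default paths; override with --checkpoint_out / --threshold_out / --log_out):
artifacts/checkpoints/protfunc_v3_fixed.pth — main model checkpoint
artifacts/thresholds/protfunc_v3_fixed_thresholds.json
artifacts/logs/protfunc_v3_fixed_log.json
artifacts/ablation_results.json — if --ablation flag used (per-variant files are
written as artifacts/{checkpoints,thresholds,logs}/ablation_<variant>*)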
"""
import os, re, ast, json, math, time, argparse, warnings, requests, threading
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
import pandas as pd
import joblib
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
# Silence every Python warning for the rest of the process.
# NOTE(review): this also hides numpy RuntimeWarnings (e.g. the 0/0 divisions that
# np.where-based metric code evaluates eagerly) and library deprecation notices —
# confirm that blanket suppression is intended.
warnings.filterwarnings("ignore")
# ── Paths ─────────────────────────────────────────────────────────────────────
# BASE is the parent of the directory that contains this script (the script sits
# in scripts/, so this is presumably the repository root).
BASE = Path(__file__).parent.parent
ART = BASE / "artifacts"
IMPORTANT = BASE / "Important Files"
# Import-time side effect: artifacts/ is created if missing. Output subdirectories
# (checkpoints/, thresholds/, logs/) are created later, when train() runs.
ART.mkdir(exist_ok=True)
# Insect table: ESM embedding columns Dim_* plus Label_Indices.
DATA_BASE = IMPORTANT / "merged_full_struct.parquet"
# Supplemented f_* feature columns; joined to DATA_BASE by row position (assumed aligned).
DATA_SUPP = IMPORTANT / "merged_full_struct_with_features.parquet"
# Fitted label binarizer; its classes_ are the GO IDs of the label columns.
MLB_PATH = IMPORTANT / "mlb_public_v1.pkl"
# Fixed train/val/test row indices (keys train_idx, val_idx, test_idx).
SPLITS_NPZ = ART / "splits" / "splits_n250000_seed42.npz"
MAMMAL_FASTA = IMPORTANT / "mammal_subset.fasta"
# Cache of mammal embeddings + labels, written on the first mammal run.
MAMMAL_EMB = ART / "generalization" / "mammal_embeddings_v3.parquet"
# GO ontology; used only for CAFA-style evaluation, never to alter training labels.
OBO_PATH = BASE / "go-basic.obo"
# Prefer best ablation checkpoint (same in_dim=360); fall back to improved_res.pth
_WARMSTART_CANDIDATES = [
    ART / "checkpoints" / "ablation_C_ESM_seq_AF.pth",
    ART / "checkpoints" / "protfunc_v3_fixed.pth",
    ART / "checkpoints" / "improved_res.pth",
]
# Resolved once, at import time: the first candidate that exists, else the
# improved_res.pth path (warm_start() skips it if that file is also missing).
# NOTE(review): with --ablation, variants A (320-dim) and B (331-dim) will also
# start from a 360-dim ablation/v3 checkpoint when one exists, so the "ESM-only"
# baseline inherits weights trained with AF features — confirm this is intended.
WARMSTART = next((p for p in _WARMSTART_CANDIDATES if p.exists()), _WARMSTART_CANDIDATES[-1])
# Default output locations for a single (non-ablation) run.
CKPT_OUT = ART / "checkpoints" / "protfunc_v3_fixed.pth"
THRESH_OUT = ART / "thresholds" / "protfunc_v3_fixed_thresholds.json"
LOG_OUT = ART / "logs" / "protfunc_v3_fixed_log.json"
# ── Supplemented feature columns (39 total) ───────────────────────────────────
# Sequence-derived proxies (composition / physicochemistry / topology). They are
# the first 11 supplemented columns, which is what allows the "ESM + seq" variant
# to be taken as a prefix slice of the full feature matrix.
SEQ_ONLY_COLS = [
    "f_seq_len", "f_mean_hydro", "f_net_charge", "f_uversky_disorder",
    "f_idr_frac_proxy", "f_lowcomp_proxy", "f_tm_frac_proxy", "f_tm_any_proxy",
    "f_signal_peptide_proxy", "f_cf_helix_mean", "f_cf_sheet_mean",
]
# AlphaFold-derived groups: model availability, pLDDT summary, distance-bin
# histogram and PAE summary.
_AF_MODEL_COLS = ["f_afdb_has_model"]
_PLDDT_COLS = [
    "f_plddt_mean", "f_plddt_std", "f_plddt_q10", "f_plddt_q50", "f_plddt_q90",
    "f_plddt_frac_gt90", "f_plddt_frac_gt70", "f_plddt_frac_lt50",
]
_DISTBIN_COLS = [
    "f_distbin_0", "f_distbin_1", "f_distbin_2", "f_distbin_3", "f_distbin_4",
    "f_distbin_5", "f_distbin_6", "f_distbin_7", "f_distbin_8", "f_distbin_9",
]
_PAE_COLS = [
    "f_pae_mean", "f_pae_median", "f_pae_p90", "f_pae_p95",
    "f_pae_frac_lt5", "f_pae_frac_lt10", "f_pae_frac_gt20",
]
# Presence flags for the two feature families.
_PRESENCE_COLS = ["f_seqfeat_present", "f_af_present"]
# Full supplemented block, in the column order the model's input expects.
SUPP_COLS = (
    SEQ_ONLY_COLS + _AF_MODEL_COLS + _PLDDT_COLS
    + _DISTBIN_COLS + _PAE_COLS + _PRESENCE_COLS
)
# ── Config ────────────────────────────────────────────────────────────────────
SEED = 42
OUT_DIM = 8124               # number of GO-MF label columns
ESM_DIM = 320                # width of the ESM-2 (t6, 8M) embedding
SUPP_DIM = len(SUPP_COLS)    # 39
BATCH = 512
PW_CLIP = (1.0, 100.0)       # clip range for the per-label BCE pos_weight
MIN_SUPPORT = 10             # train positives needed before a label gets its own threshold
FALLBACK_THRESH = 0.50       # threshold used for labels below MIN_SUPPORT
API_TIMEOUT = 12             # seconds per UniProt request
API_THREADS = 20             # concurrent UniProt requests
# GO evidence codes accepted as "experimental" for mammal labels.
EXP_CODES = set(
    "IDA IMP IPI IGI IEP EXP"   # experimental
    " HDA HMP HGI HEP"          # high-throughput experimental
    " TAS IC".split()           # author statement / curator inference
)
torch.manual_seed(SEED)
np.random.seed(SEED)
# ─────────────────────────────────────────────────────────────────────────────
# Architecture (same as improved / train_v2)
# ─────────────────────────────────────────────────────────────────────────────
class ResBlock(nn.Module):
    """
    Pre-activation residual block: ``x + f(x)``.

    ``f`` is two repetitions of BatchNorm1d -> ReLU -> Dropout -> Linear(dim, dim),
    stored as ``self.net`` (an 8-layer Sequential) so checkpoint keys are
    ``net.0`` ... ``net.7``.
    """

    def __init__(self, dim: int, dropout: float):
        super().__init__()
        layers = []
        for _ in range(2):
            # Constructed in the same order as a flat listing, so parameter
            # initialisation consumes the RNG identically.
            layers.extend([
                nn.BatchNorm1d(dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(dim, dim),
            ])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.net(x)
        return x + residual
class ImprovedResidualMLP(nn.Module):
    """
    Residual MLP mapping a feature vector to per-label logits.

    Layout: ``fc_in`` (Linear in_dim -> hidden), ``n_blocks`` ResBlocks in
    ``blocks``, then ``fc_out`` = BatchNorm1d -> ReLU -> Dropout ->
    Linear(hidden -> out_dim). These attribute names are the checkpoint keys
    that warm-starting relies on, so they must not change.
    """

    def __init__(self, in_dim, out_dim=OUT_DIM, hidden=2048, n_blocks=4, dropout=0.2):
        super().__init__()
        # Creation order (fc_in, blocks, head) fixes the RNG stream used for init.
        self.fc_in = nn.Linear(in_dim, hidden)
        self.blocks = nn.ModuleList(ResBlock(hidden, dropout) for _ in range(n_blocks))
        head = [
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, out_dim),
        ]
        self.fc_out = nn.Sequential(*head)

    def forward(self, x):
        out = self.fc_in(x)
        for block in self.blocks:
            out = block(out)
        return self.fc_out(out)
# ─────────────────────────────────────────────────────────────────────────────
# Warm-start: copy weights from a prior checkpoint (e.g. improved_res.pth, in_dim=320) → new model
# ─────────────────────────────────────────────────────────────────────────────
def warm_start(model: "ImprovedResidualMLP", ckpt_path: Path, in_dim_new: int):
    """
    Initialise *model* in place from a checkpoint, adapting fc_in's input width.

    Every checkpoint tensor whose key exists in the model and whose shape
    matches is copied as-is. ``fc_in.weight`` is also accepted when only its
    input (column) dimension differs:
      * wider model: the checkpoint columns are copied and the extra columns
        are filled with N(0, 0.01^2) noise, so new features start near zero;
      * narrower model: the first ``new_in`` checkpoint columns are kept.
    Tensors with any other mismatch (including an fc_in with a different
    hidden size) are skipped instead of crashing ``load_state_dict``.

    Args:
        model: freshly built network; its parameters/buffers are overwritten.
        ckpt_path: checkpoint file, either a raw state_dict or a dict holding
            one under "model". If missing, a message is printed and nothing
            is loaded.
        in_dim_new: input width of *model*. Kept for interface compatibility;
            the width actually used is read from the model's fc_in.weight.
    """
    if not ckpt_path.exists():
        print(f" Warm-start skipped: {ckpt_path} not found (training from scratch)")
        return
    raw = torch.load(ckpt_path, map_location="cpu")
    src = raw["model"] if isinstance(raw, dict) and "model" in raw else raw
    dst = model.state_dict()
    loaded, skipped = 0, 0
    for k, v in src.items():
        if k not in dst:
            skipped += 1
            continue
        target = dst[k]
        # fc_in may differ only in its input dimension; the output (hidden)
        # dimension must match or the tensor cannot be adapted.
        if (k == "fc_in.weight" and v.dim() == 2 and target.dim() == 2
                and v.shape[0] == target.shape[0]):
            old_in, new_in = v.shape[1], target.shape[1]
            if new_in == old_in:
                dst[k] = v
            elif new_in > old_in:
                # Extend: copy known dims, small random for new dims
                new_w = torch.zeros_like(target)
                new_w[:, :old_in] = v
                new_w[:, old_in:] = torch.randn(v.shape[0], new_in - old_in) * 0.01
                dst[k] = new_w
            else:
                # Truncate: take first new_in columns of the checkpoint weight
                dst[k] = v[:, :new_in].clone()
        elif target.shape == v.shape:
            dst[k] = v
        else:
            skipped += 1
            continue
        loaded += 1
    model.load_state_dict(dst)
    print(f" Warm-start from {ckpt_path.name}: {loaded} tensors loaded, {skipped} skipped")
# ─────────────────────────────────────────────────────────────────────────────
# Dataset
# ─────────────────────────────────────────────────────────────────────────────
class EmbeddingDataset(Dataset):
    """
    Row-subset view over a shared feature matrix and label matrix.

    Item ``k`` is row ``indices[k]`` of both X and Y. The feature row is
    returned unchanged; the label row is cast to float32 so it can be used
    directly as a BCE target.
    """

    def __init__(self, X, Y, indices):
        self.X, self.Y = X, Y
        # Normalised to int64 once, so __getitem__ does a plain index lookup.
        self.idx = indices.astype(np.int64)

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, k):
        row = int(self.idx[k])
        features = self.X[row]
        labels = self.Y[row].astype(np.float32)
        return features, labels
# ─────────────────────────────────────────────────────────────────────────────
# Loss
# ─────────────────────────────────────────────────────────────────────────────
class SmoothBCEWithLogitsLoss(nn.Module):
    """
    BCE-with-logits computed against label-smoothed targets.

    Each target y becomes ``y * (1 - s) + (1 - y) * s`` (s = smoothing), so hard
    0/1 labels turn into s / 1 - s. ``pos_weight`` is passed to
    nn.BCEWithLogitsLoss, and the loss is the mean over all elements.
    """

    def __init__(self, pos_weight, smoothing=0.05):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction="mean")
        self.smooth = smoothing

    def forward(self, logits, targets):
        s = self.smooth
        soft_targets = targets * (1 - s) + (1 - targets) * s
        return self.bce(logits, soft_targets)
def lr_lambda(warmup, total):
    """
    Build a LambdaLR multiplier: linear warmup, then cosine decay to zero.

    For epoch index e (as passed by LambdaLR):
      * e < warmup  -> (e + 1) / warmup
      * otherwise   -> 0.5 * (1 + cos(pi * (e - warmup) / max(total - warmup, 1)))
    """
    decay_span = max(total - warmup, 1)

    def schedule(epoch):
        if epoch >= warmup:
            progress = (epoch - warmup) / decay_span
            return 0.5 * (1 + math.cos(math.pi * progress))
        return (epoch + 1) / warmup

    return schedule
# ─────────────────────────────────────────────────────────────────────────────
# Evaluation: micro-Fmax (streaming histogram over ORIGINAL ground truth)
# NOTE: no label propagation applied here — this measures on specific terms
# ─────────────────────────────────────────────────────────────────────────────
@torch.no_grad()
def eval_micro_fmax(model, loader, device, step=0.02):
    """
    Micro-averaged Fmax over all (protein, label) pairs, computed by streaming.

    Sigmoid scores are bucketed into bins of width *step*. Positive and negative
    counts per bin are accumulated batch by batch, so the full score matrix is
    never held in memory. The threshold at bin edge b predicts every pair whose
    bin index is >= b, and the edge with the highest F1 wins (first one on ties).

    Returns:
        dict with micro_fmax, t_star (winning bin edge), precision and recall.
    """
    model.eval()
    edges = np.arange(0.0, 1.0 + step, step)
    n_bins = len(edges)
    pos_hist = np.zeros(n_bins, np.int64)
    neg_hist = np.zeros(n_bins, np.int64)
    n_pos = 0
    for Xb, Yb in loader:
        scores = torch.sigmoid(model(Xb.to(device))).cpu().numpy().ravel()
        is_pos = Yb.numpy().ravel() > 0.5
        n_pos += int(is_pos.sum())
        bins = np.floor(scores / step + 1e-9).astype(np.int64)
        bins = np.minimum(bins, n_bins - 1)
        if is_pos.any():
            pos_hist += np.bincount(bins[is_pos], minlength=n_bins)
        if not is_pos.all():
            neg_hist += np.bincount(bins[~is_pos], minlength=n_bins)
    # Suffix sums: number of pairs scoring at or above each bin edge.
    tp_at = np.flip(np.cumsum(np.flip(pos_hist))).astype(float)
    fp_at = np.flip(np.cumsum(np.flip(neg_hist))).astype(float)
    n_pred = tp_at + fp_at
    precision = np.where(n_pred > 0, tp_at / n_pred, 0.0)
    recall = tp_at / max(n_pos, 1)
    total = precision + recall
    f1 = np.where(total > 0, 2 * precision * recall / total, 0.0)
    best = int(np.argmax(f1))
    return {
        "micro_fmax": float(f1[best]),
        "t_star": float(edges[best]),
        "precision": float(precision[best]),
        "recall": float(recall[best]),
    }
# ─────────────────────────────────────────────────────────────────────────────
# Evaluation: CAFA-style Fmax (propagate predictions AND ground truth UP)
# Protein-centric Fmax as scored in CAFA: predicted scores and experimental
# annotations are both closed upward over the GO hierarchy before scoring.
# ─────────────────────────────────────────────────────────────────────────────
@torch.no_grad()
def eval_cafa_fmax(model, loader, device, mlb_classes, go_parents, step=0.05):
    """
    Protein-centric CAFA-style Fmax over the GO hierarchy.

    Both sides are propagated to ancestors that are themselves label columns
    (the ancestor walk passes through terms outside the label set):
      * predictions: a term counts as predicted at threshold t if it or any of
        its descendants scores >= t (equivalently, a term's score is the max
        over itself and its descendants);
      * ground truth: every ancestor of an annotated term counts as annotated.
    Previously only predictions were propagated, so every correctly implied
    ancestor was scored as a false positive.

    Precision is averaged over proteins with at least one prediction and recall
    over all proteins with at least one annotation. Fmax is the best F1 over
    thresholds 0.05, 0.05 + step, ... (< 0.96).

    Args:
        model: network producing per-label logits.
        loader: iterable of (features, labels) tensor batches.
        device: device the model lives on.
        mlb_classes: GO IDs in label-column order.
        go_parents: {GO ID: set of parent GO IDs}; empty disables the metric.
        step: spacing between candidate thresholds.

    Returns:
        {"cafa_fmax": float, "t_star": float}. Both are NaN when go_parents is
        empty or no protein has an annotation. cafa_fmax is 0.0 (t_star NaN)
        when no threshold yields any prediction.

    NOTE(review): if the MF root term is one of the label columns, propagation
    marks it for every protein; CAFA evaluations typically exclude the root.
    """
    nan_result = {"cafa_fmax": float("nan"), "t_star": float("nan")}
    if not go_parents:
        return nan_result
    model.eval()

    go2idx = {g: i for i, g in enumerate(mlb_classes)}
    anc_idx = []
    for gid in mlb_classes:
        # Transitive closure of parents (iterative DFS).
        visited, stack = set(), list(go_parents.get(gid, set()))
        while stack:
            p = stack.pop()
            if p not in visited:
                visited.add(p)
                stack.extend(go_parents.get(p, set()))
        anc_idx.append(np.array(sorted(go2idx[a] for a in visited if a in go2idx), dtype=np.int64))

    all_probs, all_true = [], []
    for Xb, Yb in loader:
        all_probs.append(torch.sigmoid(model(Xb.to(device))).cpu().numpy())
        all_true.append(Yb.numpy())
    if not all_probs:
        return nan_result
    probs = np.concatenate(all_probs, axis=0).astype(np.float32)
    true = np.concatenate(all_true, axis=0).astype(np.float32)
    has_label = true.sum(axis=1) > 0
    probs = probs[has_label]
    true = true[has_label] > 0.5
    if len(probs) == 0:
        return nan_result

    thresholds = np.arange(0.05, 0.96, step)
    t_min = thresholds[0]

    # Propagate once (instead of once per threshold). Scores below the lowest
    # threshold can never switch a prediction on, so only those rows are pushed.
    prop_scores = probs.copy()
    prop_true = true.copy()
    for j, aidx in enumerate(anc_idx):
        if aidx.size == 0:
            continue
        rows = np.flatnonzero(probs[:, j] >= t_min)
        if rows.size:
            grid = np.ix_(rows, aidx)
            prop_scores[grid] = np.maximum(prop_scores[grid], probs[rows, j][:, None])
        pos = np.flatnonzero(true[:, j])
        if pos.size:
            prop_true[np.ix_(pos, aidx)] = True

    n_true = prop_true.sum(axis=1).astype(np.float64)
    best_f1, best_t = -1.0, float("nan")
    for t in thresholds:
        pred = prop_scores >= t
        n_pred = pred.sum(axis=1).astype(np.float64)
        has_pred = n_pred > 0
        if not has_pred.any():
            continue
        n_hit = np.count_nonzero(pred & prop_true, axis=1).astype(np.float64)
        prec_per = np.divide(n_hit, n_pred, out=np.zeros_like(n_hit), where=has_pred)
        rec_per = np.divide(n_hit, n_true, out=np.zeros_like(n_hit), where=n_true > 0)
        avg_prec = prec_per[has_pred].mean()
        avg_rec = rec_per.mean()
        denom = avg_prec + avg_rec
        f1 = (2 * avg_prec * avg_rec / denom) if denom > 0 else 0.0
        if f1 > best_f1:
            best_f1, best_t = f1, float(t)
    if best_f1 < 0:
        # No threshold produced a single prediction.
        return {"cafa_fmax": 0.0, "t_star": float("nan")}
    return {"cafa_fmax": round(float(best_f1), 4), "t_star": round(best_t, 3)}
# ─────────────────────────────────────────────────────────────────────────────
# Per-label threshold calibration (on ORIGINAL non-propagated val labels)
# ─────────────────────────────────────────────────────────────────────────────
@torch.no_grad()
def compute_per_label_thresholds(model, loader, device, support_tr,
                                 min_support=MIN_SUPPORT, fallback=FALLBACK_THRESH):
    """
    Choose, for each label, the decision threshold that maximises its F1.

    Candidate thresholds are 0.10, 0.15, ..., 0.95. On ties the smallest
    candidate wins. Labels with fewer than *min_support* training positives
    keep *fallback*. The search uses the loader's labels as given (no GO
    propagation).

    The counts for every candidate threshold are computed for many labels at
    once, in column chunks, rather than in a Python loop per label and
    threshold. For binary labels the result is identical to that loop.

    Args:
        model: network producing per-label logits.
        loader: iterable of (features, labels) batches (validation split).
        device: device the model lives on.
        support_tr: number of training positives per label.
        min_support: minimum training positives for a label to be calibrated.
        fallback: threshold for uncalibrated labels, and for every label if the
            loader yields no batches.

    Returns:
        float32 array of OUT_DIM thresholds.
    """
    model.eval()
    thr = np.full(OUT_DIM, fallback, dtype=np.float32)
    all_p, all_y = [], []
    for Xb, Yb in loader:
        all_p.append(torch.sigmoid(model(Xb.to(device))).cpu())
        all_y.append(Yb)
    if not all_p:
        return thr
    probs = torch.cat(all_p).numpy().astype(np.float32)
    true = torch.cat(all_y).numpy().astype(np.float32)
    steps = np.arange(0.10, 0.96, 0.05, dtype=np.float32)

    # astype(int64) truncates exactly like int(), which is how support is compared.
    support = np.asarray(support_tr)[:OUT_DIM].astype(np.int64)
    eligible = np.flatnonzero(support >= min_support)
    chunk = 1024  # labels per pass; bounds the (n_proteins x chunk) temporaries
    for start in range(0, eligible.size, chunk):
        cols = eligible[start:start + chunk]
        p_blk = probs[:, cols]
        y_blk = true[:, cols]
        n_pos = y_blk.sum(axis=0, dtype=np.float64)
        f1 = np.empty((steps.size, cols.size), dtype=np.float64)
        for s, t in enumerate(steps):
            pred = (p_blk >= t).astype(np.float32)
            tp = (pred * y_blk).sum(axis=0, dtype=np.float64)
            n_pred = pred.sum(axis=0, dtype=np.float64)
            # 2*tp + fp + fn == n_pred + n_pos
            denom = n_pred + n_pos
            f1[s] = np.divide(2 * tp, denom, out=np.zeros_like(denom), where=denom > 0)
        # argmax returns the first maximum, i.e. the smallest threshold on ties.
        thr[cols] = steps[np.argmax(f1, axis=0)]
    return thr
# ─────────────────────────────────────────────────────────────────────────────
# GO hierarchy: parse OBO (identical to train_v2.py)
# ─────────────────────────────────────────────────────────────────────────────
def load_go_parents(obo_path: Path) -> dict:
    """
    Parse an OBO ontology into {GO ID: set of parent GO IDs} for molecular_function.

    Parents come from the ``is_a:`` and ``relationship: part_of`` lines of each
    [Term] stanza. Other stanza types are ignored. Obsolete terms and terms
    without a namespace are dropped, and every parent set is restricted to
    molecular_function terms. If the file is missing, a warning is printed and
    {} is returned (which disables CAFA-style evaluation).
    """
    if not obo_path.exists():
        print(f" WARNING: {obo_path} not found β€” CAFA eval disabled")
        return {}

    namespace_of, parents_of = {}, {}
    stanza = {"id": None, "ns": None, "parents": set()}
    in_term = False

    def close_stanza():
        # Record the finished stanza (if it is a live term) and reset state.
        if stanza["id"] and stanza["ns"]:
            namespace_of[stanza["id"]] = stanza["ns"]
            parents_of[stanza["id"]] = stanza["parents"]
        stanza.update(id=None, ns=None, parents=set())

    with open(obo_path, encoding="utf-8") as handle:
        for raw in handle:
            text = raw.strip()
            if text.startswith("["):
                close_stanza()
                in_term = text == "[Term]"
                continue
            if not in_term:
                continue
            key, sep, value = text.partition(":")
            if not sep:
                continue
            if key == "id":
                stanza["id"] = value.strip().split()[0]
            elif key == "namespace":
                stanza["ns"] = value.strip()
            elif key == "is_obsolete":
                if "true" in text:
                    stanza["id"] = None
            elif key == "is_a":
                stanza["parents"].add(value.strip().split()[0])
            elif key == "relationship":
                fields = value.strip().split()
                if len(fields) >= 2 and fields[0] == "part_of":
                    stanza["parents"].add(fields[1])
    close_stanza()

    mf_terms = {g for g, ns in namespace_of.items() if ns == "molecular_function"}
    return {g: parents_of[g] & mf_terms for g in mf_terms}
# ─────────────────────────────────────────────────────────────────────────────
# Label parsing helper
# ─────────────────────────────────────────────────────────────────────────────
def parse_labels(x):
    """
    Normalise a Label_Indices cell into a list of ints.

    Accepted forms: None; a list, tuple or np.ndarray of integers; a single
    integer (Python or NumPy); or a string holding a Python literal such as
    "[1, 2]", "(3,)" or "5". Anything else yields [], including NaN floats,
    "nan", empty strings, unparsable strings and booleans. Tuples and integer
    scalars used to be dropped silently.
    """
    if x is None:
        return []
    if isinstance(x, (list, tuple, np.ndarray)):
        return [int(v) for v in x]
    if isinstance(x, (int, np.integer)) and not isinstance(x, bool):
        return [int(x)]
    if isinstance(x, str):
        s = x.strip()
        if not s or s.lower() == "nan":
            return []
        try:
            v = ast.literal_eval(s)
            return [int(i) for i in (v if isinstance(v, (list, tuple)) else [v])]
        except Exception:
            # Deliberate best-effort: a malformed cell is treated as "no labels".
            return []
    return []
def select_feature_matrix(X_all_full: np.ndarray, feature_level: str) -> np.ndarray:
    """
    Return the column prefix of the full feature matrix for a feature level.

    Matches the feature slices used in HPO:
      esm_only -> first ESM_DIM columns (320)
      esm_seq  -> ESM + the sequence-only supplemented columns (331)
      esm_all  -> the full matrix, returned as-is (360)

    Raises:
        ValueError: for any other feature_level.
    """
    prefix_widths = (
        ("esm_only", ESM_DIM),
        ("esm_seq", ESM_DIM + len(SEQ_ONLY_COLS)),
        ("esm_all", None),  # None: no slicing
    )
    for level, width in prefix_widths:
        if feature_level == level:
            return X_all_full if width is None else X_all_full[:, :width]
    raise ValueError(f"Unknown feature_level: {feature_level}")
# ─────────────────────────────────────────────────────────────────────────────
# Mammal data helpers (identical to train_v2.py)
# ─────────────────────────────────────────────────────────────────────────────
def parse_fasta(path):
    """
    Yield (record_id, sequence) pairs from a FASTA file.

    record_id is the first whitespace-separated token of the header line
    (without '>'). Sequence lines are stripped and concatenated. Lines that
    appear before the first header are ignored.
    """
    current_id = None
    chunks = []
    with open(path) as handle:
        for raw in handle:
            text = raw.strip()
            if not text.startswith(">"):
                chunks.append(text)
                continue
            if current_id is not None:
                yield current_id, "".join(chunks)
            current_id = text[1:].split()[0]
            chunks = []
    if current_id is not None:
        yield current_id, "".join(chunks)
def fetch_go_mf(uniprot_id, exp_codes, timeout):
    """
    Fetch a UniProt entry's GO molecular-function terms with accepted evidence.

    Queries the UniProt REST API for the entry's GO cross-references and keeps
    those whose GoTerm starts with "F:" (molecular function) and whose
    GoEvidenceType code (the part before ':') is in *exp_codes*.

    Returns:
        List of GO IDs. [] on a non-200 response or any error (best-effort).
    """
    url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.json?fields=go_f"

    def accepted(ref):
        props = {p.get("key"): p.get("value") for p in ref.get("properties", [])}
        return (props.get("GoTerm", "").startswith("F:")
                and props.get("GoEvidenceType", "").split(":")[0] in exp_codes)

    try:
        response = requests.get(url, timeout=timeout)
        if response.status_code != 200:
            return []
        cross_refs = response.json().get("uniProtKBCrossReferences", [])
        return [
            ref.get("id", "")
            for ref in cross_refs
            if ref.get("database") == "GO" and accepted(ref)
        ]
    except Exception:
        return []
def build_mammal_dataset(fasta_path, mlb, device, max_seq_len=2500):
    """
    Build a labelled mammal embedding table from a FASTA file.

    Steps:
      1. Keep FASTA records whose sequence length is in [30, max_seq_len].
      2. Fetch GO molecular-function terms with an EXP_CODES evidence code for
         every record from the UniProt REST API, using API_THREADS threads.
      3. Map terms to label indices via mlb.classes_ and drop records with
         none. A failed API call returns [] and so looks exactly like "no
         terms", which means those proteins are dropped too.
      4. Mean-pool ESM-2 (esm2_t6_8M_UR50D, layer 6) residue embeddings per
         sequence, batching sequences up to a total token budget.

    Args:
        fasta_path: FASTA file of mammal sequences.
        mlb: fitted label binarizer; only ``mlb.classes_`` (GO IDs) is read.
        device: torch device for the ESM model.
        max_seq_len: longest sequence (in residues) that is kept.

    Returns:
        DataFrame with columns Dim_0..Dim_{ESM_DIM-1}, Label_Indices (list of
        label column indices) and f_af_present (always 0.0).
    """
    print(" Loading ESM-2 for mammal embedding...")
    # Deferred import: the esm package is only needed when this function runs.
    import esm as esm_lib
    esm_model, alphabet = esm_lib.pretrained.esm2_t6_8M_UR50D()
    esm_model = esm_model.to(device).eval()
    bc = alphabet.get_batch_converter()
    entries = [(h, s) for h, s in parse_fasta(fasta_path) if 30 <= len(s) <= max_seq_len]
    print(f" Parsed {len(entries)} mammal sequences")
    print(f" Fetching GO annotations from UniProt ({API_THREADS} threads)...")
    # go_labels (header -> GO IDs) is written from worker threads; the lock guards it.
    go_labels, lock = {}, threading.Lock()
    def fetch_one(hdr_seq):
        hdr, _ = hdr_seq
        # "db|ACCESSION|NAME" headers -> ACCESSION; otherwise the first token.
        uid = hdr.split("|")[1] if "|" in hdr else hdr.split()[0]
        terms = fetch_go_mf(uid, EXP_CODES, API_TIMEOUT)
        with lock:
            go_labels[hdr] = terms
    with ThreadPoolExecutor(max_workers=API_THREADS) as ex:
        futures = {ex.submit(fetch_one, e): e for e in entries}
        # Used only for progress reporting: fut.result() is never called, so an
        # exception raised inside fetch_one would go unreported.
        for i, fut in enumerate(as_completed(futures)):
            if (i + 1) % 100 == 0:
                print(f" {i+1}/{len(entries)} annotations fetched")
    go2idx = {g: i for i, g in enumerate(mlb.classes_)}
    # Keep only terms that are label columns of the model.
    labeled = [(h, s, [go2idx[t] for t in go_labels.get(h, []) if t in go2idx])
               for h, s in entries]
    labeled = [(h, s, l) for h, s, l in labeled if len(l) > 0]
    print(f" {len(labeled)} mammal sequences with β‰₯1 GO-MF label (experimental evidence)")
    print(f" Embedding {len(labeled)} sequences with ESM-2...")
    # Max total tokens (residues + 2 special tokens per sequence) per ESM batch.
    TOKEN_BUDGET = 6000
    rows = []
    batch_buf = []
    total_tok = 0
    def flush_batch(buf):
        # Embed one batch and append one output row per sequence to `rows`.
        labels_list = [item[2] for item in buf]
        batch_data = [(h, s) for h, s, _ in buf]
        _, _, toks = bc(batch_data)
        with torch.no_grad():
            rep = esm_model(toks.to(device), repr_layers=[6])["representations"][6]
        for k, (h, s, labs) in enumerate(buf):
            # Positions 1..len(s) are the residues: this skips the leading special
            # token (assumed BOS from the batch converter) and any trailing EOS/padding.
            emb = rep[k, 1:len(s)+1].mean(0).cpu().numpy()
            row = {f"Dim_{i}": float(emb[i]) for i in range(ESM_DIM)}
            row["Label_Indices"] = labs
            row["f_af_present"] = 0.0
            rows.append(row)
    # Greedy batching: flush when adding the next sequence would exceed the budget.
    for h, s, labs in labeled:
        ntok = len(s) + 2
        if total_tok + ntok > TOKEN_BUDGET and batch_buf:
            flush_batch(batch_buf)
            batch_buf, total_tok = [], 0
        batch_buf.append((h, s, labs))
        total_tok += ntok
    if batch_buf:
        flush_batch(batch_buf)
    return pd.DataFrame(rows)
# ─────────────────────────────────────────────────────────────────────────────
# Core training function
# ─────────────────────────────────────────────────────────────────────────────
def run_training(args, X_all, Y_all, train_idx, val_idx, test_idx,
                 support_tr, mu, sd, go_parents, mlb, in_dim,
                 feature_label, ckpt_out, thresh_out, log_out,
                 device):
    """
    Train one model variant. Returns dict with best val Fmax, test metrics, etc.

    Trains ImprovedResidualMLP on rows ``train_idx`` of X_all / Y_all, using the
    labels as given (no GO propagation). After each epoch it keeps the
    checkpoint with the best validation micro-Fmax and stops early after
    ``args.patience`` epochs without improvement. It then reloads the best
    checkpoint to score the test split and to calibrate per-label thresholds
    on the validation split.

    Args:
        args: CLI namespace; reads batch, num_workers, hidden, blocks, dropout,
            label_smooth, lr, weight_decay, epochs and patience.
        X_all: feature matrix indexed by the *_idx arrays (width must equal in_dim).
        Y_all: binary label matrix aligned with X_all.
        train_idx, val_idx, test_idx: row indices of each split.
        support_tr: training positives per label. Drives the sampler weights,
            pos_weight and threshold eligibility.
        mu, sd: supplemented-feature normalisation stats, stored in the checkpoint.
        go_parents: MF parent map for CAFA-style Fmax ({} disables it).
        mlb: label binarizer; only ``mlb.classes_`` is read.
        in_dim: model input width.
        feature_label: tag written to the log, the checkpoint and console output.
        ckpt_out, thresh_out, log_out: output paths, overwritten by this call.
        device: torch device for training.

    Returns:
        dict with feature_label, in_dim, best_val_fmax, test_micro_fmax,
        test_cafa_fmax (None without a GO hierarchy), ckpt and thresh paths.
    """
    ds_tr = EmbeddingDataset(X_all, Y_all, train_idx)
    ds_va = EmbeddingDataset(X_all, Y_all, val_idx)
    ds_te = EmbeddingDataset(X_all, Y_all, test_idx)
    # Per-protein sampling weight: mean of 1/sqrt(train support) over the
    # protein's positive labels (1.0 if it has none), so proteins carrying rare
    # labels are drawn more often.
    inv_sqrt = 1.0 / np.sqrt(np.maximum(support_tr, 1.0))
    w_train = np.array([
        float(np.mean(inv_sqrt[Y_all[i].astype(bool)])) if Y_all[i].any() else 1.0
        for i in train_idx
    ], dtype=np.float32)
    sampler = WeightedRandomSampler(torch.as_tensor(w_train), len(w_train), replacement=True)
    ld_tr = DataLoader(ds_tr, args.batch, sampler=sampler, num_workers=args.num_workers)
    ld_va = DataLoader(ds_va, args.batch, shuffle=False, num_workers=args.num_workers)
    ld_te = DataLoader(ds_te, args.batch, shuffle=False, num_workers=args.num_workers)
    # Per-label BCE pos_weight = train negatives / train positives, clipped to PW_CLIP.
    pw_np = np.clip((len(train_idx) - support_tr) / np.maximum(support_tr, 1.0), *PW_CLIP)
    pos_weight = torch.tensor(pw_np, dtype=torch.float32).to(device)
    model = ImprovedResidualMLP(
        in_dim=in_dim, out_dim=OUT_DIM,
        hidden=args.hidden, n_blocks=args.blocks, dropout=args.dropout,
    ).to(device)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\n [{feature_label}] Model: {n_params:,} params (in_dim={in_dim})")
    # Warm-start from WARMSTART: the first existing _WARMSTART_CANDIDATES entry,
    # falling back to improved_res.pth.
    # NOTE(review): in --ablation runs every variant uses this same source, so
    # variants A/B may be initialised from a 360-dim model trained with AF
    # features (fc_in truncated) — confirm this is acceptable for the ablation.
    warm_start(model, WARMSTART, in_dim)
    criterion = SmoothBCEWithLogitsLoss(pos_weight, smoothing=args.label_smooth)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr,
                                  weight_decay=args.weight_decay, eps=1e-7)
    # Linear warmup over ~8% of the epochs (at least 1), then cosine decay;
    # the scheduler is stepped once per epoch.
    warmup = max(1, int(args.epochs * 0.08))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda(warmup, args.epochs))
    print(f" [{feature_label}] Training (epochs={args.epochs}, patience={args.patience})")
    best_fmax, no_improve = -1.0, 0
    log = []
    for ep in range(1, args.epochs + 1):
        t0 = time.time()
        model.train()
        running = 0.0
        for Xb, Yb in ld_tr:
            Xb, Yb = Xb.to(device), Yb.to(device)
            optimizer.zero_grad()
            # Forward pass and loss under bfloat16 autocast.
            with torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16):
                loss = criterion(model(Xb), Yb)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            running += loss.item()
        scheduler.step()
        train_loss = running / len(ld_tr)
        # Model-selection metric: micro-Fmax on the validation split (original labels).
        mi = eval_micro_fmax(model, ld_va, device)
        fmax = mi["micro_fmax"]
        # CAFA-style Fmax is only computed every 5th epoch and on the final epoch.
        cafa_fmax = None
        if go_parents and (ep % 5 == 0 or ep == args.epochs):
            ci = eval_cafa_fmax(model, ld_va, device, mlb.classes_, go_parents)
            cafa_fmax = ci["cafa_fmax"]
        elapsed = time.time() - t0
        # NOTE(review): read after scheduler.step(), so this is the LR scheduled
        # for the *next* epoch rather than the one used above — confirm intended.
        lr_now = scheduler.get_last_lr()[0]
        entry = {
            "feature_label": feature_label, "epoch": ep,
            "train_loss": round(train_loss, 6), "val_micro_fmax": round(fmax, 4),
            "val_t_star": round(mi["t_star"], 3),
            "val_prec": round(mi["precision"], 4), "val_rec": round(mi["recall"], 4),
            "val_cafa_fmax": round(cafa_fmax, 4) if cafa_fmax is not None else None,
            "lr": round(lr_now, 7), "elapsed_s": round(elapsed, 1),
        }
        log.append(entry)
        # The whole log is rewritten every epoch, so an interrupted run still leaves a valid file.
        with open(log_out, "w") as f:
            json.dump(log, f, indent=2)
        cafa_str = f" CAFA={cafa_fmax:.4f}" if cafa_fmax is not None else ""
        print(
            f" Ep {ep:3d}/{args.epochs} | loss={train_loss:.4f} | "
            f"micro-fmax={fmax:.4f} @t={mi['t_star']:.2f} "
            f"P={mi['precision']:.3f} R={mi['recall']:.3f}{cafa_str} | "
            f"lr={lr_now:.2e} | {elapsed:.0f}s"
        )
        if fmax > best_fmax:
            best_fmax = fmax
            no_improve = 0
            # Best-so-far checkpoint, including the metadata needed to rebuild
            # the model and normalise supplemented features at inference time.
            torch.save({
                "model": model.state_dict(),
                "epoch": ep,
                "val_fmax": fmax,
                "feature_label": feature_label,
                "in_dim": in_dim,
                "hidden": args.hidden,
                "n_blocks": args.blocks,
                "supp_mu": mu.tolist(),
                "supp_sd": sd.tolist(),
                "supp_cols": SUPP_COLS,
            }, ckpt_out)
            print(f" βœ“ Best micro-fmax={fmax:.4f} β€” saved")
        else:
            no_improve += 1
            if no_improve >= args.patience:
                print(f"\n Early stopping at epoch {ep}")
                break
    # Test evaluation on the best (reloaded) checkpoint, not the last epoch's weights.
    print(f"\n [{feature_label}] Test evaluation...")
    ckpt = torch.load(ckpt_out, map_location=device)
    model.load_state_dict(ckpt["model"])
    test_micro = eval_micro_fmax(model, ld_te, device)
    test_cafa = eval_cafa_fmax(model, ld_te, device, mlb.classes_, go_parents) if go_parents else {}
    print(f" Test micro-Fmax={test_micro['micro_fmax']:.4f} "
          f"P={test_micro['precision']:.4f} R={test_micro['recall']:.4f}")
    if test_cafa:
        print(f" Test CAFA-Fmax={test_cafa.get('cafa_fmax', 'N/A')}")
    # Per-label thresholds on val set (non-propagated labels)
    print(f" [{feature_label}] Computing per-label thresholds...")
    thr = compute_per_label_thresholds(model, ld_va, device, support_tr)
    with open(thresh_out, "w") as f:
        json.dump({str(i): float(thr[i]) for i in range(OUT_DIM)}, f)
    # The final log entry carries the test metrics.
    log.append({"test_micro": test_micro, "test_cafa": test_cafa})
    with open(log_out, "w") as f:
        json.dump(log, f, indent=2)
    return {
        "feature_label": feature_label,
        "in_dim": in_dim,
        "best_val_fmax": best_fmax,
        "test_micro_fmax": test_micro["micro_fmax"],
        "test_cafa_fmax": test_cafa.get("cafa_fmax"),
        "ckpt": str(ckpt_out),
        "thresh": str(thresh_out),
    }
# ─────────────────────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────────────────────
def train(args):
    """
    Script entry point: load data, build features and labels, then train.

    Pipeline:
      1. Pick the device (MPS > CUDA > CPU) and resolve/create the output paths.
      2. Load the insect embedding table, the supplemented features, the label
         binarizer and the GO hierarchy. The hierarchy is used only for
         CAFA-style evaluation.
      3. Build the non-propagated insect label matrix and the feature matrix
         [ESM | z-scored supplemented features | f_af_present flag]. Missing
         supplemented values are imputed with the train mean (z-score 0).
      4. Optionally append the mammal set to the training split only.
      5. Either run the A/B/C ablation and publish the variant with the best
         *validation* Fmax to the main checkpoint/threshold paths, or train a
         single model on ``--feature_level``.

    Args:
        args: parsed namespace from this script's argparse CLI.
    """
    import shutil  # local: only used to publish the selected ablation checkpoint

    device = torch.device(
        "mps" if torch.backends.mps.is_available() else
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    print(f"Device: {device}")
    # Allow MPS to use up to 95% of unified memory (default is conservative)
    if device.type == "mps":
        torch.mps.set_per_process_memory_fraction(0.95)
    # --start_from implies ablation mode
    if args.start_from != "A":
        args.ablation = True
    ckpt_out = Path(args.checkpoint_out) if args.checkpoint_out else CKPT_OUT
    thresh_out = Path(args.threshold_out) if args.threshold_out else THRESH_OUT
    log_out = Path(args.log_out) if args.log_out else LOG_OUT
    for out_path in (ckpt_out, thresh_out, log_out):
        out_path.parent.mkdir(parents=True, exist_ok=True)
    # Resolve num_workers default based on device
    if args.num_workers is None:
        args.num_workers = 0 if device.type == "mps" else 4
    print(f"DataLoader num_workers: {args.num_workers}")

    # ── Load insect base dataset ──────────────────────────────────────────────
    print("\n[1/6] Loading insect dataset...")
    df_base = pd.read_parquet(DATA_BASE)
    df_supp = pd.read_parquet(DATA_SUPP)
    # NOTE(review): embedding columns keep the parquet's column order, while the
    # mammal branch below uses Dim_0..Dim_{ESM_DIM-1} explicitly — confirm the
    # parquet stores them in numeric order.
    emb_cols = [c for c in df_base.columns if c.startswith("Dim_")]
    if len(emb_cols) != ESM_DIM:
        raise ValueError(
            f"expected {ESM_DIM} Dim_* columns in {DATA_BASE.name}, found {len(emb_cols)}"
        )
    mlb = joblib.load(MLB_PATH)
    if len(mlb.classes_) != OUT_DIM:
        raise ValueError(
            f"label binarizer {MLB_PATH.name} has {len(mlb.classes_)} classes, expected {OUT_DIM}"
        )

    # ── GO hierarchy (for CAFA eval only — NOT used to inflate training labels) ─
    print("[1b/6] Loading GO hierarchy (for CAFA-style evaluation only)...")
    go_parents = load_go_parents(OBO_PATH)
    print(f" GO parents loaded: {len(go_parents)} MF terms")

    # ── Build insect label matrix (NO propagation — original annotations) ─────
    print("[2/6] Building insect label matrix (original annotations, no propagation)...")
    label_lists = [parse_labels(x) for x in df_base["Label_Indices"]]
    Y_insect = np.zeros((len(df_base), OUT_DIM), dtype=np.uint8)  # uint8: 2GB vs 8GB float32
    for r, labs in enumerate(label_lists):
        for j in labs:
            if 0 <= j < OUT_DIM:
                Y_insect[r, j] = 1
    print(f" Insect label matrix: {int(Y_insect.sum()):,} positives across {len(Y_insect):,} proteins")
    print(f" (No GO propagation applied β€” model will learn specific terms as annotated)")

    # ── Supplemented features ─────────────────────────────────────────────────
    # df_supp rows are matched to df_base rows by position (no ID join).
    S_raw = df_supp[SUPP_COLS].to_numpy(np.float32)
    m_flag = df_supp["f_af_present"].to_numpy(np.float32).reshape(-1, 1)
    # A missing presence flag is treated as "absent".
    m_flag = np.where(np.isnan(m_flag), 0.0, m_flag).astype(np.float32)
    X_base = df_base[emb_cols].to_numpy(np.float32)

    # ── Splits ────────────────────────────────────────────────────────────────
    with np.load(SPLITS_NPZ) as splits:
        train_idx = splits["train_idx"]
        val_idx = splits["val_idx"]
        test_idx = splits["test_idx"]
    print(f" Insect splits β€” train:{len(train_idx)} val:{len(val_idx)} test:{len(test_idx)}")

    # Normalise using train stats (NaN-aware)
    S_tr = S_raw[train_idx]
    mu = np.nanmean(S_tr, axis=0)
    sd_tr = np.nanstd(S_tr, axis=0)
    sd = np.where(sd_tr > 0, sd_tr, 1.0)
    S_z = (S_raw - mu) / (sd + 1e-12)
    # Missing values -> z-score 0 (train-mean imputation), as the mammal branch
    # does. Without this, NaNs flow through fc_in and the loss becomes NaN.
    S_z = np.where(np.isnan(S_z), 0.0, S_z)
    X_insect_full = np.concatenate([X_base, S_z, m_flag], axis=1).astype(np.float32)
    print(f" Insect X shape (full features): {X_insect_full.shape}")

    # ── Mammal data ───────────────────────────────────────────────────────────
    X_mammal, Y_mammal = None, None
    if not args.skip_mammal and MAMMAL_FASTA.exists():
        print("\n[3/6] Processing mammal data...")
        if MAMMAL_EMB.exists():
            print(" Loading cached mammal embeddings...")
            df_m = pd.read_parquet(MAMMAL_EMB)
        else:
            print(" Computing mammal embeddings (first run)...")
            df_m = build_mammal_dataset(MAMMAL_FASTA, mlb, device)
            # The cache directory may not exist yet. Create it so the slow,
            # network-bound embedding run is not lost on the write.
            MAMMAL_EMB.parent.mkdir(parents=True, exist_ok=True)
            df_m.to_parquet(MAMMAL_EMB, index=False)
        m_emb_cols = [f"Dim_{i}" for i in range(ESM_DIM)]
        X_m_base = df_m[m_emb_cols].to_numpy(np.float32)
        # Build mammal label matrix — NO propagation (same as insect)
        Y_m = np.zeros((len(df_m), OUT_DIM), dtype=np.uint8)  # uint8 to match insect
        for r, labs in enumerate(df_m["Label_Indices"].tolist()):
            for j in parse_labels(labs):
                if 0 <= j < OUT_DIM:
                    Y_m[r, j] = 1
        print(f" Mammal: {int(Y_m.sum()):,} positives (original annotations, no propagation)")
        # Supplemented features present in df_m are used, with NaN -> train mean.
        # NOTE(review): columns absent from df_m stay at raw 0.0 (not the train
        # mean), i.e. z = -mu/sd — confirm this encoding of "missing" is intended.
        S_m = np.zeros((len(df_m), SUPP_DIM), dtype=np.float32)
        for ci, col in enumerate(SUPP_COLS):
            if col in df_m.columns:
                v = df_m[col].to_numpy(np.float32)
                S_m[:, ci] = np.where(np.isnan(v), mu[ci], v)
        S_m_z = (S_m - mu) / (sd + 1e-12)
        S_m_z = np.where(np.isnan(S_m_z), 0.0, S_m_z)
        m_m = df_m["f_af_present"].to_numpy(np.float32).reshape(-1, 1) if "f_af_present" in df_m else np.zeros((len(df_m), 1), np.float32)
        X_mammal = np.concatenate([X_m_base, S_m_z, m_m], axis=1).astype(np.float32)
        Y_mammal = Y_m
        print(f" Mammal X shape: {X_mammal.shape}")
    else:
        print("\n[3/6] Skipping mammal data")

    # ── Merge insect + mammal (mammal rows join the training split only) ──────
    print("\n[4/6] Merging datasets...")
    if X_mammal is not None:
        X_all_full = np.concatenate([X_insect_full, X_mammal], axis=0)
        Y_all = np.concatenate([Y_insect, Y_mammal], axis=0)
        mammal_idx = np.arange(len(X_insect_full), len(X_all_full))
        train_idx_combined = np.concatenate([train_idx, mammal_idx])
        print(f" Combined: {len(X_all_full)} rows (insect+mammal train: {len(train_idx_combined)})")
    else:
        X_all_full = X_insect_full
        Y_all = Y_insect
        train_idx_combined = train_idx
    support_tr = Y_all[train_idx_combined].sum(0).astype(np.float32)
    print(f" Training positives per label (median): {np.median(support_tr[support_tr>0]):.0f}")

    # ── Ablation or single run ────────────────────────────────────────────────
    if args.ablation:
        print("\n[5/6] Ablation study: ESM-only vs ESM+seq vs ESM+all (pLDDT/PAE)")
        for sub in ("checkpoints", "thresholds", "logs"):
            (ART / sub).mkdir(parents=True, exist_ok=True)
        ablation_results = []
        # A: ESM only (320), B: ESM + sequence features (331),
        # C: ESM + all supplemented features incl. pLDDT/PAE (360).
        variants = [
            ("A_ESM_only", "esm_only"),
            ("B_ESM_seq", "esm_seq"),
            ("C_ESM_seq_AF", "esm_all"),
        ]
        for label, level in variants:
            if label[0] < args.start_from:
                print(f" Skipping {label} (--start_from {args.start_from})")
                continue
            X_variant = select_feature_matrix(X_all_full, level)
            in_dim = X_variant.shape[1]
            ckpt_v = ART / "checkpoints" / f"ablation_{label}.pth"
            thresh_v = ART / "thresholds" / f"ablation_{label}_thresholds.json"
            log_v = ART / "logs" / f"ablation_{label}_log.json"
            result = run_training(
                args, X_variant, Y_all, train_idx_combined, val_idx, test_idx,
                support_tr, mu, sd, go_parents, mlb, in_dim,
                label, ckpt_v, thresh_v, log_v, device,
            )
            ablation_results.append(result)
        # Summary
        ablation_out = ART / "ablation_results.json"
        with open(ablation_out, "w") as f:
            json.dump(ablation_results, f, indent=2)
        print("\n" + "=" * 75)
        print("ABLATION RESULTS")
        print("=" * 75)
        print(f"{'Model':<20} {'Val Fmax':>10} {'Test Fmax':>10} {'Test CAFA':>10}")
        print("-" * 60)
        for r in ablation_results:
            cafa_val = r.get("test_cafa_fmax")
            # 0.0 is a real score; only a missing value is reported as N/A.
            cafa = f"{cafa_val:.4f}" if cafa_val is not None else "N/A"
            print(f"{r['feature_label']:<20} {r['best_val_fmax']:>10.4f} {r['test_micro_fmax']:>10.4f} {cafa:>10}")
        print("=" * 75)
        print(f"\nDetailed results β†’ {ablation_out}")

        def _publish(src, dst):
            # Copy unless src and dst already refer to the same file.
            src, dst = Path(src), Path(dst)
            if src.resolve() != dst.resolve():
                shutil.copyfile(src, dst)

        # The deployed model is the variant with the best *validation* Fmax;
        # choosing on test Fmax would leak the test set into model selection.
        # Copying makes ckpt_out / thresh_out (printed below) hold that model.
        if ablation_results:
            best = max(ablation_results, key=lambda r: r["best_val_fmax"])
            _publish(best["ckpt"], ckpt_out)
            _publish(best["thresh"], thresh_out)
            print(f"\nBest model: {best['feature_label']} "
                  f"(val Fmax={best['best_val_fmax']:.4f}, test Fmax={best['test_micro_fmax']:.4f})")
    else:
        # Single run: explicit feature slice chosen by pipeline / user
        X_single = select_feature_matrix(X_all_full, args.feature_level)
        in_dim = X_single.shape[1]
        print(f"\n[5/6] Training single model ({args.feature_level}, in_dim={in_dim})...")
        run_training(
            args, X_single, Y_all, train_idx_combined, val_idx, test_idx,
            support_tr, mu, sd, go_parents, mlb, in_dim,
            args.feature_label, ckpt_out, thresh_out, log_out, device,
        )
    print(f"\n[6/6] Done.")
    print(f" Main checkpoint β†’ {ckpt_out}")
    print(f" Thresholds β†’ {thresh_out}")
    print(f"\nTo deploy to HuggingFace:")
    print(f" huggingface-cli upload Sbhat2026/protfunc-models {ckpt_out} protfunc_v3_fixed.pth")
    print(f" huggingface-cli upload Sbhat2026/protfunc-models {thresh_out} protfunc_v3_fixed_thresholds.json")
if __name__ == "__main__":
    p = argparse.ArgumentParser(description="ProtFunc v3 (corrected label handling)")
    p.add_argument("--epochs", type=int, default=50,
                   help="Max training epochs (warm-start needs fewer than from-scratch)")
    p.add_argument("--hidden", type=int, default=2048)
    p.add_argument("--blocks", type=int, default=4)
    p.add_argument("--dropout", type=float, default=0.20)
    p.add_argument("--lr", type=float, default=5e-5,
                   help="Lower LR than from-scratch (warm-start already near optimum)")
    # Default taken from the module-level BATCH constant (512).
    p.add_argument("--batch", type=int, default=BATCH)
    p.add_argument("--patience", type=int, default=10)
    p.add_argument("--label_smooth", type=float, default=0.05)
    p.add_argument("--weight_decay", type=float, default=5e-4)
    # Also accept the underscore spelling used by every other multi-word flag.
    # The original hyphenated spelling still works; both set args.skip_mammal.
    p.add_argument("--skip-mammal", "--skip_mammal", dest="skip_mammal", action="store_true",
                   help="Train on insect data only (do not load or embed the mammal FASTA)")
    p.add_argument("--feature_level", type=str, default="esm_all",
                   choices=["esm_only", "esm_seq", "esm_all"],
                   help="Single-run feature slice (ignored when --ablation is used)")
    p.add_argument("--feature_label", type=str, default="v3_fixed",
                   help="Single-run feature label stored in checkpoint/log output")
    p.add_argument("--checkpoint_out", type=str, default="",
                   help="Optional explicit checkpoint output path")
    p.add_argument("--threshold_out", type=str, default="",
                   help="Optional explicit threshold output path")
    p.add_argument("--log_out", type=str, default="",
                   help="Optional explicit training log output path")
    p.add_argument("--ablation", action="store_true",
                   help="Run ESM-only / ESM+seq / ESM+all ablation study")
    p.add_argument("--start_from", type=str, default="A", choices=["A", "B", "C"],
                   help="Skip ablation variants before this model (implies --ablation)")
    p.add_argument("--num_workers", type=int, default=None,
                   help="DataLoader num_workers (default: 0 on MPS, 4 otherwise)")
    train(p.parse_args())