""" prep_taxon.py — Universal taxon data preparation pipeline for ProtFunc ======================================================================= Given a FASTA file for ANY taxonomic group, this script produces a standard-format parquet ready for eval_generalization.py or fine-tuning. Steps ----- 1. Parse FASTA (filters by sequence length: 30–2500 aa) 2. Fetch UniProt GO-MF annotations (experimental evidence, parallel) 3. Compute sequence features (11 physicochemical + proxy features) 4. Embed sequences with ESM-2 (t6, 8M params, 320-dim) 5. Optionally fetch AlphaFold pLDDT/PAE via AFDB API (--fetch_af flag) 6. Save parquet with same schema as mammal_embeddings_v3.parquet Output format ------------- Columns: Protein_ID, Label_Indices, Dim_0..Dim_319, f_seq_len, f_mean_hydro, f_net_charge, f_uversky_disorder, f_idr_frac_proxy, f_lowcomp_proxy, f_tm_frac_proxy, f_tm_any_proxy, f_signal_peptide_proxy, f_cf_helix_mean, f_cf_sheet_mean, f_afdb_has_model, f_plddt_mean, ..., f_pae_frac_gt20, f_seqfeat_present, f_af_present Usage ----- # Minimal: ESM embeddings + GO labels only python scripts/prep_taxon.py \\ --fasta "Important Files/mammal_subset.fasta" \\ --taxon_name mammals \\ --mlb "Important Files/mlb_public_v1.pkl" \\ --out artifacts/mammals_prepped.parquet # Full: include AlphaFold structural features python scripts/prep_taxon.py \\ --fasta mygroup.fasta \\ --taxon_name fungi \\ --mlb "Important Files/mlb_public_v1.pkl" \\ --out artifacts/fungi_prepped.parquet \\ --fetch_af # Skip GO fetch if you already have labels (provide label CSV: uniprot_id,go_ids) python scripts/prep_taxon.py \\ --fasta mygroup.fasta \\ --taxon_name archaea \\ --mlb "Important Files/mlb_public_v1.pkl" \\ --labels_csv my_labels.csv \\ --out artifacts/archaea_prepped.parquet # Keep all sequences (not just experimentally labeled) for inference-only python scripts/prep_taxon.py \\ --fasta mygroup.fasta \\ --taxon_name plants \\ --mlb "Important Files/mlb_public_v1.pkl" \\ --out artifacts/plants_prepped.parquet \\ --keep_unlabeled """ import argparse import json import math import re import threading import time import warnings from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path import joblib import numpy as np import pandas as pd import requests import torch warnings.filterwarnings("ignore") # ── Constants ───────────────────────────────────────────────────────────────── ESM_DIM = 320 API_THREADS = 20 API_TIMEOUT = 12 EXP_CODES = {"IDA","IMP","IPI","IGI","IEP","EXP","HDA","HMP","HGI","HEP","TAS","IC"} MIN_LEN, MAX_LEN = 30, 2500 # Amino acid hydrophobicity (Kyte-Doolittle) _KD = {"A":1.8,"C":2.5,"D":-3.5,"E":-3.5,"F":2.8,"G":-0.4,"H":-3.2,"I":4.5, "K":-3.9,"L":3.8,"M":1.9,"N":-3.5,"P":-1.6,"Q":-3.5,"R":-4.5,"S":-0.8, "T":-0.7,"V":4.2,"W":-0.9,"Y":-1.3} _CHARGE = {"K":1,"R":1,"D":-1,"E":-1} _POLAR = set("DEHKNQRST") _TM_HYDRO = {"I","L","V","F","M","A","W"} # ── FASTA parser ───────────────────────────────────────────────────────────── def parse_fasta(path): header, seq = None, [] with open(path) as fh: for line in fh: line = line.strip() if line.startswith(">"): if header is not None: yield header, "".join(seq) header = line[1:] seq = [] else: seq.append(line) if header is not None: yield header, "".join(seq) def extract_uniprot_id(header: str) -> str: """Extract UniProt accession from various FASTA header formats.""" # sp|P12345|GENE_HUMAN or tr|A0A123|... m = re.search(r"[st][pr]\|([A-Z0-9]+)\|", header) if m: return m.group(1) # Plain accession (word boundary) m = re.match(r"([A-Z][A-Z0-9]{5,9})\b", header.split()[0]) if m: return m.group(1) return header.split()[0] # ── UniProt GO annotation fetch ─────────────────────────────────────────────── def fetch_go_mf(uniprot_id: str, exp_codes=EXP_CODES, timeout=API_TIMEOUT) -> list: url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.json?fields=go_f" try: r = requests.get(url, timeout=timeout) if r.status_code != 200: return [] refs = r.json().get("uniProtKBCrossReferences", []) terms = [] for ref in refs: if ref.get("database") != "GO": continue go_id = ref.get("id", "") props = {p.get("key"): p.get("value") for p in ref.get("properties", [])} aspect = props.get("GoTerm", "") evcode = props.get("GoEvidenceType", "").split(":")[0] if aspect.startswith("F:") and evcode in exp_codes: terms.append(go_id) return terms except Exception: return [] def fetch_all_go(entries: list, threads=API_THREADS) -> dict: """Parallel GO fetch. entries: list of (header, seq).""" go_labels = {} lock = threading.Lock() print(f" Fetching GO annotations ({threads} threads)...") def fetch_one(hdr_seq): hdr, _ = hdr_seq uid = extract_uniprot_id(hdr) terms = fetch_go_mf(uid) with lock: go_labels[hdr] = terms with ThreadPoolExecutor(max_workers=threads) as ex: futures = {ex.submit(fetch_one, e): e for e in entries} for i, fut in enumerate(as_completed(futures)): if (i + 1) % 100 == 0: print(f" {i+1}/{len(entries)} done") try: fut.result() except Exception: pass return go_labels # ── Sequence feature computation ───────────────────────────────────────────── def seq_features(seq: str) -> dict: """ Compute 11 physicochemical/proxy features — same set as SUPP_COLS[:11]. All features are zero-safe (no hard failures). """ s = seq.upper() n = len(s) aa = {c: s.count(c) / n for c in "ACDEFGHIKLMNPQRSTVWY" if n > 0} # Length (log-scaled) f_seq_len = math.log1p(n) # Mean hydrophobicity f_mean_hydro = np.mean([_KD.get(c, 0.0) for c in s]) if n else 0.0 # Net charge per residue (at pH 7) f_net_charge = sum(_CHARGE.get(c, 0) for c in s) / max(n, 1) # Uversky disorder proxy: fraction charged + polar charged_frac = sum(1 for c in s if c in "DEKR") / max(n, 1) hydro_frac = sum(1 for c in s if _KD.get(c, 0) > 0) / max(n, 1) f_uversky_disorder = max(0.0, charged_frac - hydro_frac) # IDR proxy: fraction of polar/charged (disordered AAs) f_idr_frac_proxy = sum(1 for c in s if c in _POLAR) / max(n, 1) # Low complexity proxy: max single-aa run / length runs = [len(m.group()) for m in re.finditer(r"(.)\1+", s)] or [1] f_lowcomp_proxy = max(runs) / max(n, 1) # TM helix proxy: fraction of strongly hydrophobic residues f_tm_frac_proxy = sum(1 for c in s if c in _TM_HYDRO) / max(n, 1) f_tm_any_proxy = float(f_tm_frac_proxy > 0.55) # Signal peptide proxy: N-terminal hydrophobic stretch (first 30 aa) nt = s[:30] f_signal_peptide_proxy = float( sum(1 for c in nt if c in _TM_HYDRO) / max(len(nt), 1) > 0.40 ) # Chou-Fasman helix/sheet propensity proxies (simplified) helix_aa = set("AELM") sheet_aa = set("CFITVY") f_cf_helix_mean = sum(1 for c in s if c in helix_aa) / max(n, 1) f_cf_sheet_mean = sum(1 for c in s if c in sheet_aa) / max(n, 1) return { "f_seq_len": f_seq_len, "f_mean_hydro": f_mean_hydro, "f_net_charge": f_net_charge, "f_uversky_disorder": f_uversky_disorder, "f_idr_frac_proxy": f_idr_frac_proxy, "f_lowcomp_proxy": f_lowcomp_proxy, "f_tm_frac_proxy": f_tm_frac_proxy, "f_tm_any_proxy": f_tm_any_proxy, "f_signal_peptide_proxy": f_signal_peptide_proxy, "f_cf_helix_mean": f_cf_helix_mean, "f_cf_sheet_mean": f_cf_sheet_mean, } # ── AlphaFold feature fetch ─────────────────────────────────────────────────── def fetch_af_features(uniprot_id: str, timeout=20) -> dict | None: """ Fetch pLDDT and PAE statistics from AlphaFold DB v4. Returns None if not found or fetch fails. """ base = f"https://alphafold.ebi.ac.uk/api/prediction/{uniprot_id}" try: r = requests.get(base, timeout=timeout) if r.status_code != 200: return None data = r.json() if not data: return None entry = data[0] plddt_url = entry.get("plddt_url") or entry.get("cifUrl", "").replace(".cif", ".json") pae_url = entry.get("paeImageUrl", "").replace("-pae.png", ".json") plddt_vals, pae_vals = None, None # Try to get per-residue pLDDT scores_url = entry.get("confidenceUrl") or plddt_url if scores_url: rs = requests.get(scores_url, timeout=timeout) if rs.status_code == 200: try: plddt_vals = np.array(rs.json(), dtype=np.float32) except Exception: pass # Try to get PAE matrix pae_json_url = entry.get("paeDocUrl") or pae_url if pae_json_url and pae_json_url.endswith(".json"): rp = requests.get(pae_json_url, timeout=timeout) if rp.status_code == 200: try: raw = rp.json() if isinstance(raw, list) and raw: pae_vals = np.array(raw[0].get("predicted_aligned_error", []), dtype=np.float32).ravel() except Exception: pass if plddt_vals is None: return None p = plddt_vals feat = { "f_afdb_has_model": 1.0, "f_plddt_mean": float(np.mean(p)), "f_plddt_std": float(np.std(p)), "f_plddt_q10": float(np.percentile(p, 10)), "f_plddt_q50": float(np.percentile(p, 50)), "f_plddt_q90": float(np.percentile(p, 90)), "f_plddt_frac_gt90": float((p > 90).mean()), "f_plddt_frac_gt70": float((p > 70).mean()), "f_plddt_frac_lt50": float((p < 50).mean()), } # Distance bins (pLDDT deciles as proxy for structural confidence distribution) bins = np.linspace(0, 100, 11) hist, _ = np.histogram(p, bins=bins) for i, v in enumerate(hist): feat[f"f_distbin_{i}"] = float(v) / max(len(p), 1) if pae_vals is not None and len(pae_vals) > 0: pae = pae_vals feat.update({ "f_pae_mean": float(np.mean(pae)), "f_pae_median": float(np.median(pae)), "f_pae_p90": float(np.percentile(pae, 90)), "f_pae_p95": float(np.percentile(pae, 95)), "f_pae_frac_lt5": float((pae < 5).mean()), "f_pae_frac_lt10": float((pae < 10).mean()), "f_pae_frac_gt20": float((pae > 20).mean()), }) feat["f_seqfeat_present"] = 1.0 feat["f_af_present"] = 1.0 return feat except Exception: return None # ── ESM-2 embedding ─────────────────────────────────────────────────────────── def embed_sequences(entries: list, device: torch.device, token_budget: int = 6000) -> dict: """ entries: list of (header, seq) Returns {header: np.ndarray of shape (320,)} """ import esm as esm_lib print(" Loading ESM-2 (esm2_t6_8M_UR50D)...") esm_model, alphabet = esm_lib.pretrained.esm2_t6_8M_UR50D() esm_model = esm_model.to(device).eval() bc = alphabet.get_batch_converter() results = {} batch_buf, total_tok = [], 0 def flush(buf): batch_data = [(h, s) for h, s, _ in buf] _, _, toks = bc(batch_data) with torch.no_grad(): rep = esm_model(toks.to(device), repr_layers=[6])["representations"][6] for k, (h, s, _) in enumerate(buf): emb = rep[k, 1:len(s)+1].mean(0).cpu().numpy() results[h] = emb for h, s in entries: ntok = len(s) + 2 if total_tok + ntok > token_budget and batch_buf: flush(batch_buf) batch_buf, total_tok = [], 0 batch_buf.append((h, s, None)) total_tok += ntok if batch_buf: flush(batch_buf) return results # ── Main ────────────────────────────────────────────────────────────────────── def main(): ap = argparse.ArgumentParser(description="ProtFunc universal taxon data prep") ap.add_argument("--fasta", required=True) ap.add_argument("--taxon_name", required=True, help="e.g. 'mammals', 'fungi'") ap.add_argument("--mlb", required=True, help="Path to mlb_public_v1.pkl") ap.add_argument("--out", required=True, help="Output .parquet path") ap.add_argument("--labels_csv", default=None, help="CSV with columns uniprot_id,go_ids (semicolon-sep). " "Skips UniProt API fetch if provided.") ap.add_argument("--fetch_af", action="store_true", help="Fetch AlphaFold pLDDT/PAE features from AFDB") ap.add_argument("--keep_unlabeled", action="store_true", help="Include proteins with 0 GO labels (for inference-only use)") ap.add_argument("--max_len", type=int, default=MAX_LEN) ap.add_argument("--min_len", type=int, default=MIN_LEN) ap.add_argument("--threads", type=int, default=API_THREADS) ap.add_argument("--device", default="auto") args = ap.parse_args() if args.device == "auto": device = torch.device( "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu" ) else: device = torch.device(args.device) print(f"Device: {device}") print(f"Taxon: {args.taxon_name}") # ── Parse FASTA ────────────────────────────────────────────────────────── print(f"\n[1/5] Parsing FASTA: {args.fasta}") all_entries = [ (h, s) for h, s in parse_fasta(args.fasta) if args.min_len <= len(s) <= args.max_len ] print(f" {len(all_entries)} sequences ({args.min_len}–{args.max_len} aa)") if not all_entries: print(" No sequences in range — exiting.") return # ── Load MLB ───────────────────────────────────────────────────────────── mlb = joblib.load(args.mlb) go2idx = {g: i for i, g in enumerate(mlb.classes_)} print(f" MLB loaded: {len(mlb.classes_)} GO-MF terms") # ── GO annotations ──────────────────────────────────────────────────────── print(f"\n[2/5] GO annotations...") if args.labels_csv: print(f" Loading from CSV: {args.labels_csv}") df_lab = pd.read_csv(args.labels_csv) ext_labels = {} for _, row in df_lab.iterrows(): uid = str(row["uniprot_id"]) gids = [g.strip() for g in str(row.get("go_ids", "")).split(";") if g.strip()] ext_labels[uid] = gids go_labels = {h: ext_labels.get(extract_uniprot_id(h), []) for h, _ in all_entries} else: go_labels = fetch_all_go(all_entries, threads=args.threads) # Map to label indices label_index_map = {} for h, _ in all_entries: terms = go_labels.get(h, []) idxs = [go2idx[t] for t in terms if t in go2idx] label_index_map[h] = idxs n_labeled = sum(1 for v in label_index_map.values() if len(v) > 0) print(f" {n_labeled}/{len(all_entries)} sequences have ≥1 GO-MF label") # Filter to labeled unless keep_unlabeled if not args.keep_unlabeled: entries = [(h, s) for h, s in all_entries if label_index_map.get(h)] print(f" Keeping {len(entries)} labeled sequences (use --keep_unlabeled to keep all)") else: entries = all_entries print(f" Keeping all {len(entries)} sequences") if not entries: print(" No sequences to embed — exiting.") return # ── Sequence features ───────────────────────────────────────────────────── print(f"\n[3/5] Computing sequence features...") seq_feats = {h: seq_features(s) for h, s in entries} print(f" Done for {len(seq_feats)} sequences") # ── ESM-2 embeddings ────────────────────────────────────────────────────── print(f"\n[4/5] ESM-2 embeddings...") t0 = time.time() embeddings = embed_sequences(entries, device) print(f" Embedded {len(embeddings)} sequences in {time.time()-t0:.1f}s") # ── AlphaFold features (optional) ───────────────────────────────────────── af_feats = {} if args.fetch_af: print(f"\n[4b/5] Fetching AlphaFold features (parallel, {args.threads} threads)...") af_lock = threading.Lock() def fetch_af_one(hdr_seq): h, _ = hdr_seq uid = extract_uniprot_id(h) feat = fetch_af_features(uid) with af_lock: af_feats[h] = feat with ThreadPoolExecutor(max_workers=args.threads) as ex: futs = {ex.submit(fetch_af_one, e): e for e in entries} for i, fut in enumerate(as_completed(futs)): if (i + 1) % 50 == 0: found = sum(1 for v in af_feats.values() if v is not None) print(f" {i+1}/{len(entries)} fetched ({found} with AF models)") try: fut.result() except Exception: pass found = sum(1 for v in af_feats.values() if v is not None) print(f" AlphaFold coverage: {found}/{len(entries)} ({100*found/max(len(entries),1):.1f}%)") # ── Assemble parquet ────────────────────────────────────────────────────── print(f"\n[5/5] Assembling output parquet...") SUPP_COLS = [ "f_seq_len", "f_mean_hydro", "f_net_charge", "f_uversky_disorder", "f_idr_frac_proxy", "f_lowcomp_proxy", "f_tm_frac_proxy", "f_tm_any_proxy", "f_signal_peptide_proxy", "f_cf_helix_mean", "f_cf_sheet_mean", "f_afdb_has_model", "f_plddt_mean", "f_plddt_std", "f_plddt_q10", "f_plddt_q50", "f_plddt_q90", "f_plddt_frac_gt90", "f_plddt_frac_gt70", "f_plddt_frac_lt50", "f_distbin_0", "f_distbin_1", "f_distbin_2", "f_distbin_3", "f_distbin_4", "f_distbin_5", "f_distbin_6", "f_distbin_7", "f_distbin_8", "f_distbin_9", "f_pae_mean", "f_pae_median", "f_pae_p90", "f_pae_p95", "f_pae_frac_lt5", "f_pae_frac_lt10", "f_pae_frac_gt20", "f_seqfeat_present", "f_af_present", ] rows = [] for h, s in entries: emb = embeddings.get(h) if emb is None: continue row = {"Protein_ID": h, "Label_Indices": label_index_map.get(h, [])} for i, v in enumerate(emb): row[f"Dim_{i}"] = float(v) sf = seq_feats.get(h, {}) aff = af_feats.get(h) if args.fetch_af else None merged_feats = {} # Sequence features (always present) for k in ["f_seq_len","f_mean_hydro","f_net_charge","f_uversky_disorder", "f_idr_frac_proxy","f_lowcomp_proxy","f_tm_frac_proxy","f_tm_any_proxy", "f_signal_peptide_proxy","f_cf_helix_mean","f_cf_sheet_mean"]: merged_feats[k] = sf.get(k, 0.0) merged_feats["f_seqfeat_present"] = 1.0 # AF features (zero if not fetched or not found) if aff is not None: for k, v in aff.items(): merged_feats[k] = v else: for k in SUPP_COLS: if k not in merged_feats: merged_feats[k] = 0.0 for k in SUPP_COLS: row[k] = float(merged_feats.get(k, 0.0)) rows.append(row) df_out = pd.DataFrame(rows) out_path = Path(args.out) out_path.parent.mkdir(exist_ok=True) df_out.to_parquet(out_path, index=False) print(f" Saved {len(df_out)} proteins → {out_path}") # Summary n_with_labels = int((df_out["Label_Indices"].apply(len) > 0).sum()) n_with_af = int((df_out["f_af_present"] > 0).sum()) if "f_af_present" in df_out else 0 print(f"\nSummary for '{args.taxon_name}':") print(f" Total proteins: {len(df_out)}") print(f" With GO-MF labels: {n_with_labels}") print(f" With AF structure: {n_with_af}") print(f" Output: {out_path}") print(f"\nNext step:") print(f" python scripts/eval_generalization.py \\") print(f" --checkpoint artifacts/protfunc_v3_fixed.pth \\") print(f" --thresholds artifacts/protfunc_v3_fixed_thresholds.json \\") print(f" --mlb 'Important Files/mlb_public_v1.pkl' \\") print(f" --taxon_parquet {out_path} \\") print(f" --taxon_name {args.taxon_name}") if __name__ == "__main__": main()