| """ |
| prep_taxon.py — Universal taxon data preparation pipeline for ProtFunc |
| ======================================================================= |
| Given a FASTA file for ANY taxonomic group, this script produces a |
| standard-format parquet ready for eval_generalization.py or fine-tuning. |
| |
| Steps |
| ----- |
| 1. Parse FASTA (filters by sequence length: 30–2500 aa) |
| 2. Fetch UniProt GO-MF annotations (experimental evidence, parallel) |
| 3. Compute sequence features (11 physicochemical + proxy features) |
| 4. Embed sequences with ESM-2 (t6, 8M params, 320-dim) |
| 5. Optionally fetch AlphaFold pLDDT/PAE via AFDB API (--fetch_af flag) |
| 6. Save parquet with same schema as mammal_embeddings_v3.parquet |
| |
| Output format |
| ------------- |
| Columns: Protein_ID, Label_Indices, Dim_0..Dim_319, |
| f_seq_len, f_mean_hydro, f_net_charge, f_uversky_disorder, |
| f_idr_frac_proxy, f_lowcomp_proxy, f_tm_frac_proxy, f_tm_any_proxy, |
| f_signal_peptide_proxy, f_cf_helix_mean, f_cf_sheet_mean, |
| f_afdb_has_model, |
| f_plddt_mean, ..., f_pae_frac_gt20, |
| f_seqfeat_present, f_af_present |
| |
| Usage |
| ----- |
| # Minimal: ESM embeddings + GO labels only |
| python scripts/prep_taxon.py \\ |
| --fasta "Important Files/mammal_subset.fasta" \\ |
| --taxon_name mammals \\ |
| --mlb "Important Files/mlb_public_v1.pkl" \\ |
| --out artifacts/mammals_prepped.parquet |
| |
| # Full: include AlphaFold structural features |
| python scripts/prep_taxon.py \\ |
| --fasta mygroup.fasta \\ |
| --taxon_name fungi \\ |
| --mlb "Important Files/mlb_public_v1.pkl" \\ |
| --out artifacts/fungi_prepped.parquet \\ |
| --fetch_af |
| |
| # Skip GO fetch if you already have labels (provide label CSV: uniprot_id,go_ids) |
| python scripts/prep_taxon.py \\ |
| --fasta mygroup.fasta \\ |
| --taxon_name archaea \\ |
| --mlb "Important Files/mlb_public_v1.pkl" \\ |
| --labels_csv my_labels.csv \\ |
| --out artifacts/archaea_prepped.parquet |
| |
| # Keep all sequences (not just experimentally labeled) for inference-only |
| python scripts/prep_taxon.py \\ |
| --fasta mygroup.fasta \\ |
| --taxon_name plants \\ |
| --mlb "Important Files/mlb_public_v1.pkl" \\ |
| --out artifacts/plants_prepped.parquet \\ |
| --keep_unlabeled |
| """ |
|
|
| import argparse |
| import json |
| import math |
| import re |
| import threading |
| import time |
| import warnings |
| from concurrent.futures import ThreadPoolExecutor, as_completed |
| from pathlib import Path |
|
|
| import joblib |
| import numpy as np |
| import pandas as pd |
| import requests |
| import torch |
|
|
# Silence every Python warning for the lifetime of the process.
warnings.filterwarnings("ignore")

# --- Pipeline-wide settings -------------------------------------------------
# Embedding width; the module docstring lists output columns Dim_0..Dim_319.
ESM_DIM = 320
# Default number of worker threads for the parallel API fetches.
API_THREADS = 20
# Default `timeout` value handed to requests.get by fetch_go_mf.
API_TIMEOUT = 12
# GO evidence codes that fetch_go_mf accepts by default.
EXP_CODES = set("IDA IMP IPI IGI IEP EXP HDA HMP HGI HEP TAS IC".split())
# Default inclusive sequence-length window (overridable via --min_len/--max_len).
MIN_LEN = 30
MAX_LEN = 2500

# --- Per-residue lookup tables used by seq_features -------------------------
# Hydropathy value per amino acid (the numbers match the Kyte-Doolittle scale).
_KD = {
    "A": 1.8, "C": 2.5, "D": -3.5, "E": -3.5,
    "F": 2.8, "G": -0.4, "H": -3.2, "I": 4.5,
    "K": -3.9, "L": 3.8, "M": 1.9, "N": -3.5,
    "P": -1.6, "Q": -3.5, "R": -4.5, "S": -0.8,
    "T": -0.7, "V": 4.2, "W": -0.9, "Y": -1.3,
}
# Signed unit charge of the residues counted by the net-charge feature.
_CHARGE = {"K": 1, "R": 1, "D": -1, "E": -1}
# Residues counted by the f_idr_frac_proxy feature.
_POLAR = {"D", "E", "H", "K", "N", "Q", "R", "S", "T"}
# Residues counted by the transmembrane / signal-peptide proxy features.
_TM_HYDRO = set("ILVFMAW")
|
|
|
|
| |
|
|
def parse_fasta(path):
    """Lazily yield ``(header, sequence)`` pairs from the FASTA file at *path*.

    A record starts at a line beginning with ``>``; the header is that line
    without the ``>``.  All following lines, stripped of surrounding
    whitespace, are concatenated into the sequence.  Text appearing before
    the first header is discarded.
    """
    current = None
    chunks = []
    with open(path) as handle:
        for raw in handle:
            text = raw.strip()
            if not text.startswith(">"):
                chunks.append(text)
                continue
            # A new header closes the record collected so far (if any).
            if current is not None:
                yield current, "".join(chunks)
            current, chunks = text[1:], []
    # Emit the final record, which has no following header to close it.
    if current is not None:
        yield current, "".join(chunks)
|
|
|
|
# ">sp|ACCESSION|ENTRY_NAME ..." / ">tr|ACCESSION|ENTRY_NAME ..." headers.
# An optional "-N" isoform suffix is matched but left out of the capture.
_DB_ACCESSION_RE = re.compile(r"(?:sp|tr)\|([A-Z0-9]+)(?:-\d+)?\|")
# A bare accession as the first token, e.g. ">P12345 description".
_BARE_ACCESSION_RE = re.compile(r"([A-Z][A-Z0-9]{5,9})\b")


def extract_uniprot_id(header: str) -> str:
    """Extract a UniProt accession from various FASTA header formats.

    Tried in order:

    1. ``sp|ACC|NAME`` / ``tr|ACC|NAME`` anywhere in the header.  An isoform
       suffix (``ACC-2``) is dropped, which matches what the bare-accession
       branch already does for ``ACC-2 ...`` headers.
    2. A bare accession (6-10 upper-case alphanumerics starting with a
       letter) as the first whitespace-separated token.
    3. Otherwise the first token is returned unchanged.

    Returns ``""`` for an empty or whitespace-only header instead of raising
    ``IndexError`` (``parse_fasta`` yields such a header for a bare ``>``).
    """
    tokens = header.split()
    if not tokens:
        return ""

    hit = _DB_ACCESSION_RE.search(header)
    if hit:
        return hit.group(1)

    hit = _BARE_ACCESSION_RE.match(tokens[0])
    if hit:
        return hit.group(1)
    return tokens[0]
|
|
|
|
| |
|
|
def fetch_go_mf(uniprot_id: str, exp_codes=EXP_CODES, timeout=API_TIMEOUT,
                retries: int = 2) -> list:
    """Return GO molecular-function term IDs for one UniProt accession.

    Queries ``rest.uniprot.org`` for the entry's ``go_f`` field and keeps
    every GO cross-reference whose ``GoTerm`` property starts with ``F:``
    and whose ``GoEvidenceType`` code (the part before ``:``) is in
    *exp_codes*.

    Args:
        uniprot_id: Accession to look up; an empty value returns ``[]``.
        exp_codes: Collection of accepted evidence codes.
        timeout: Passed to ``requests.get``.
        retries: Extra attempts made after a throttled (429), 5xx or
            network-failed request, with a short exponential back-off.
            Without this, a transient failure is indistinguishable from
            "protein has no annotations".

    Returns:
        List of GO IDs.  Best-effort: any failure that survives the retries
        (or an unexpected payload) yields ``[]`` rather than an exception.
    """
    if not uniprot_id:
        return []

    url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.json?fields=go_f"
    transient = (429, 500, 502, 503, 504)
    attempts = max(int(retries), 0) + 1

    try:
        r = None
        for attempt in range(attempts):
            try:
                r = requests.get(url, timeout=timeout)
            except requests.RequestException:
                r = None
            if r is not None and r.status_code not in transient:
                break
            if attempt < attempts - 1:
                time.sleep(0.5 * 2 ** attempt)

        if r is None or r.status_code != 200:
            return []

        terms = []
        for ref in r.json().get("uniProtKBCrossReferences", []):
            if ref.get("database") != "GO":
                continue
            props = {p.get("key"): p.get("value") for p in ref.get("properties", [])}
            aspect = props.get("GoTerm", "")
            evcode = props.get("GoEvidenceType", "").split(":")[0]
            if aspect.startswith("F:") and evcode in exp_codes:
                terms.append(ref.get("id", ""))
        return terms
    except Exception:
        # Deliberate best-effort: a malformed payload counts as "no labels".
        return []
|
|
|
|
def fetch_all_go(entries: list, threads=API_THREADS) -> dict:
    """Fetch GO-MF labels for all entries in parallel.

    Args:
        entries: List of ``(header, seq)`` tuples; only the header is used.
        threads: Worker-thread count (values below 1 are treated as 1).

    Returns:
        ``{header: [GO IDs]}`` with one key for every header in *entries*.
        A lookup that raises is recorded as ``[]`` and counted in a final
        message instead of being dropped silently.
    """
    print(f" Fetching GO annotations ({threads} threads)...")

    def fetch_one(hdr):
        return fetch_go_mf(extract_uniprot_id(hdr))

    # Pre-fill so that a failed lookup still leaves its header in the result.
    go_labels = {hdr: [] for hdr, _ in entries}
    failed = 0

    with ThreadPoolExecutor(max_workers=max(1, int(threads))) as ex:
        futures = {ex.submit(fetch_one, hdr): hdr for hdr, _ in entries}
        for done, fut in enumerate(as_completed(futures), start=1):
            try:
                # Results are stored from this thread only, so no lock is needed.
                go_labels[futures[fut]] = fut.result()
            except Exception:
                failed += 1
            if done % 100 == 0:
                print(f" {done}/{len(entries)} done")

    if failed:
        print(f" {failed} GO lookups failed and were treated as unlabeled")
    return go_labels
|
|
|
|
| |
|
|
def seq_features(seq: str) -> dict:
    """Compute 11 physicochemical/proxy features from a protein sequence.

    The keys are the same as the first 11 entries of ``SUPP_COLS`` in
    ``main``.  The sequence is upper-cased first.  Every value is a plain
    Python ``float`` and every ratio is zero-safe: an empty sequence gives
    ``0.0`` for all features.

    Args:
        seq: Amino-acid sequence (one-letter codes).

    Returns:
        Dict mapping feature name to value.
    """
    s = seq.upper()
    n = len(s)
    denom = max(n, 1)  # keeps every ratio defined for an empty sequence

    def frac(residues, text=s, size=denom):
        # str.count scans in C, so a handful of counts beats a Python-level
        # pass over every residue.
        return sum(text.count(r) for r in residues) / size

    # Log-compressed length.
    f_seq_len = math.log1p(n)

    # Mean hydropathy from the _KD table; unknown residues contribute 0.
    f_mean_hydro = float(np.mean([_KD.get(c, 0.0) for c in s])) if n else 0.0

    # Net charge per residue (+1 for K/R, -1 for D/E).
    f_net_charge = sum(q * s.count(c) for c, q in _CHARGE.items()) / denom

    # Disorder proxy: charged fraction minus hydrophobic fraction, floored at 0.
    charged_frac = frac("DEKR")
    hydro_frac = frac(c for c, v in _KD.items() if v > 0)
    f_uversky_disorder = max(0.0, charged_frac - hydro_frac)

    # Fraction of residues in the _POLAR set.
    f_idr_frac_proxy = frac(_POLAR)

    # Low-complexity proxy: longest homopolymer run relative to length.
    # A non-empty sequence without repeats has a longest run of 1; an empty
    # sequence has none (0), so it does not score 1.0.
    runs = [len(m.group()) for m in re.finditer(r"(.)\1+", s)]
    f_lowcomp_proxy = max(runs, default=min(n, 1)) / denom

    # Transmembrane proxies: fraction of _TM_HYDRO residues, and a flag for
    # sequences where that fraction exceeds 0.55.
    f_tm_frac_proxy = frac(_TM_HYDRO)
    f_tm_any_proxy = float(f_tm_frac_proxy > 0.55)

    # Signal-peptide proxy: more than 40% _TM_HYDRO residues in the first 30.
    nt = s[:30]
    f_signal_peptide_proxy = float(frac(_TM_HYDRO, nt, max(len(nt), 1)) > 0.40)

    # Secondary-structure proxies: fractions of A/E/L/M and of C/F/I/T/V/Y.
    # NOTE(review): the "cf" prefix presumably refers to Chou-Fasman
    # propensities -- confirm.
    f_cf_helix_mean = frac("AELM")
    f_cf_sheet_mean = frac("CFITVY")

    return {
        "f_seq_len": f_seq_len,
        "f_mean_hydro": f_mean_hydro,
        "f_net_charge": f_net_charge,
        "f_uversky_disorder": f_uversky_disorder,
        "f_idr_frac_proxy": f_idr_frac_proxy,
        "f_lowcomp_proxy": f_lowcomp_proxy,
        "f_tm_frac_proxy": f_tm_frac_proxy,
        "f_tm_any_proxy": f_tm_any_proxy,
        "f_signal_peptide_proxy": f_signal_peptide_proxy,
        "f_cf_helix_mean": f_cf_helix_mean,
        "f_cf_sheet_mean": f_cf_sheet_mean,
    }
|
|
|
|
| |
|
|
def _afdb_get_json(url: str, timeout):
    """Return the decoded JSON body of *url*, or None on any failure."""
    try:
        resp = requests.get(url, timeout=timeout)
        if resp.status_code != 200:
            return None
        return resp.json()
    except Exception:
        # Best-effort: an unreachable or unreadable file means "no data".
        return None


def _afdb_plddt_scores(payload):
    """Extract per-residue pLDDT values from a decoded confidence document.

    Accepts a bare list of scores or a mapping carrying them under
    ``confidenceScore``.  Returns a 1-D float32 array, or None when no
    usable scores are present.
    """
    if isinstance(payload, dict):
        payload = payload.get("confidenceScore")
    if not payload:
        return None
    try:
        scores = np.asarray(payload, dtype=np.float32).ravel()
    except (TypeError, ValueError):
        return None
    return scores if scores.size else None


def fetch_af_features(uniprot_id: str, timeout=20) -> dict | None:
    """Fetch pLDDT and PAE summary statistics from the AlphaFold DB API.

    Args:
        uniprot_id: Accession passed to ``/api/prediction/{uniprot_id}``.
        timeout: Passed to every ``requests.get`` call.

    Returns:
        Dict of ``f_afdb_has_model``, ``f_plddt_*``, ``f_distbin_0..9``
        (10-bin histogram of pLDDT over 0-100, as fractions),
        ``f_seqfeat_present`` and ``f_af_present``, plus ``f_pae_*`` when a
        PAE document could be read.  ``None`` if the entry is missing, no
        pLDDT scores could be obtained, or anything else fails.  A failed
        PAE download alone does not discard the pLDDT features.
    """
    if not uniprot_id:
        return None

    base = f"https://alphafold.ebi.ac.uk/api/prediction/{uniprot_id}"
    try:
        data = _afdb_get_json(base, timeout)
        if not data:
            return None
        entry = data[0]

        # NOTE(review): the field names and the "-model_"/"-confidence_" file
        # naming below reflect my reading of the public AFDB API and were not
        # verified here -- confirm against the current API documentation.
        # Several candidates are tried because the original lookup
        # (confidenceUrl / plddt_url / cifUrl with ".cif" -> ".json") is kept
        # only as a fallback.
        cif_url = entry.get("cifUrl") or ""
        derived = (
            cif_url.replace("-model_", "-confidence_").replace(".cif", ".json")
            if "-model_" in cif_url else None
        )
        candidates = [
            entry.get("plddtDocUrl"),
            entry.get("confidenceUrl"),
            entry.get("plddt_url"),
            derived,
            cif_url.replace(".cif", ".json"),
        ]

        p = None
        # dict.fromkeys de-duplicates while keeping the priority order.
        for url in dict.fromkeys(u for u in candidates if u):
            p = _afdb_plddt_scores(_afdb_get_json(url, timeout))
            if p is not None:
                break
        if p is None:
            return None

        pae = None
        pae_json_url = entry.get("paeDocUrl") or (
            (entry.get("paeImageUrl") or "").replace("-pae.png", ".json")
        )
        if pae_json_url and pae_json_url.endswith(".json"):
            raw = _afdb_get_json(pae_json_url, timeout)
            if isinstance(raw, list) and raw:
                try:
                    pae = np.array(raw[0].get("predicted_aligned_error", []),
                                   dtype=np.float32).ravel()
                except (AttributeError, TypeError, ValueError):
                    pae = None

        feat = {
            "f_afdb_has_model": 1.0,
            "f_plddt_mean": float(np.mean(p)),
            "f_plddt_std": float(np.std(p)),
            "f_plddt_q10": float(np.percentile(p, 10)),
            "f_plddt_q50": float(np.percentile(p, 50)),
            "f_plddt_q90": float(np.percentile(p, 90)),
            "f_plddt_frac_gt90": float((p > 90).mean()),
            "f_plddt_frac_gt70": float((p > 70).mean()),
            "f_plddt_frac_lt50": float((p < 50).mean()),
        }

        # Ten equal-width pLDDT bins over [0, 100], stored as fractions.
        hist, _ = np.histogram(p, bins=np.linspace(0, 100, 11))
        for i, v in enumerate(hist):
            feat[f"f_distbin_{i}"] = float(v) / max(p.size, 1)

        if pae is not None and pae.size > 0:
            feat.update({
                "f_pae_mean": float(np.mean(pae)),
                "f_pae_median": float(np.median(pae)),
                "f_pae_p90": float(np.percentile(pae, 90)),
                "f_pae_p95": float(np.percentile(pae, 95)),
                "f_pae_frac_lt5": float((pae < 5).mean()),
                "f_pae_frac_lt10": float((pae < 10).mean()),
                "f_pae_frac_gt20": float((pae > 20).mean()),
            })

        feat["f_seqfeat_present"] = 1.0
        feat["f_af_present"] = 1.0
        return feat

    except Exception:
        # Deliberate best-effort: structural features are optional.
        return None
|
|
|
|
| |
|
|
def embed_sequences(entries: list, device: torch.device,
                    token_budget: int = 6000) -> dict:
    """Embed sequences with ESM-2 (esm2_t6_8M_UR50D), mean-pooled per protein.

    Args:
        entries: List of ``(header, seq)`` tuples.
        device: Torch device the model and token batches are moved to.
        token_budget: Upper bound on the summed ``len(seq) + 2`` of one
            batch; a batch is run before adding a sequence that would
            exceed it (a single over-budget sequence still runs alone).

    Returns:
        ``{header: np.ndarray of shape (320,)}`` -- the layer-6
        representation averaged over the residue positions.
    """
    import esm as esm_lib
    print(" Loading ESM-2 (esm2_t6_8M_UR50D)...")
    model, alphabet = esm_lib.pretrained.esm2_t6_8M_UR50D()
    model = model.to(device).eval()
    to_tokens = alphabet.get_batch_converter()

    embeddings = {}

    def run_batch(batch):
        # batch: list of (header, seq) pairs.
        _, _, tokens = to_tokens(batch)
        with torch.no_grad():
            output = model(tokens.to(device), repr_layers=[6])
        hidden = output["representations"][6]
        for row, (name, residues) in enumerate(batch):
            # Position 0 is skipped and only the next len(seq) positions are
            # averaged, which leaves out anything appended after the residues.
            pooled = hidden[row, 1:len(residues) + 1].mean(0)
            embeddings[name] = pooled.cpu().numpy()

    pending, used = [], 0
    for name, residues in entries:
        cost = len(residues) + 2
        if pending and used + cost > token_budget:
            run_batch(pending)
            pending, used = [], 0
        pending.append((name, residues))
        used += cost
    if pending:
        run_batch(pending)

    return embeddings
|
|
|
|
| |
|
|
# Columns produced by seq_features, in output order.
_SEQ_FEATURE_COLS = [
    "f_seq_len", "f_mean_hydro", "f_net_charge", "f_uversky_disorder",
    "f_idr_frac_proxy", "f_lowcomp_proxy", "f_tm_frac_proxy", "f_tm_any_proxy",
    "f_signal_peptide_proxy", "f_cf_helix_mean", "f_cf_sheet_mean",
]

# Columns filled from fetch_af_features (0.0 when no model is available),
# followed by the two presence flags.
_STRUCT_FEATURE_COLS = [
    "f_afdb_has_model",
    "f_plddt_mean", "f_plddt_std", "f_plddt_q10", "f_plddt_q50", "f_plddt_q90",
    "f_plddt_frac_gt90", "f_plddt_frac_gt70", "f_plddt_frac_lt50",
    "f_distbin_0", "f_distbin_1", "f_distbin_2", "f_distbin_3", "f_distbin_4",
    "f_distbin_5", "f_distbin_6", "f_distbin_7", "f_distbin_8", "f_distbin_9",
    "f_pae_mean", "f_pae_median", "f_pae_p90", "f_pae_p95",
    "f_pae_frac_lt5", "f_pae_frac_lt10", "f_pae_frac_gt20",
    "f_seqfeat_present", "f_af_present",
]


def _resolve_device(choice: str) -> torch.device:
    """Map the --device value to a torch.device ('auto': mps, then cuda, then cpu)."""
    if choice != "auto":
        return torch.device(choice)
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _labels_from_csv(csv_path: str, entries: list) -> dict:
    """Build ``{header: [GO IDs]}`` from a CSV with columns uniprot_id, go_ids."""
    table = pd.read_csv(csv_path)
    go_cells = table["go_ids"] if "go_ids" in table.columns else [""] * len(table)
    by_accession = {}
    for uid, cell in zip(table["uniprot_id"], go_cells):
        # A blank cell is read as NaN; treat it as "no labels", not as "nan".
        text = "" if pd.isna(cell) else str(cell)
        by_accession[str(uid)] = [g.strip() for g in text.split(";") if g.strip()]
    return {h: by_accession.get(extract_uniprot_id(h), []) for h, _ in entries}


def _fetch_all_af(entries: list, threads: int) -> dict:
    """Fetch AlphaFold features for every entry in parallel.

    Returns ``{header: feature dict or None}``; a lookup that raises is
    recorded as None.
    """
    def fetch_one(header):
        return fetch_af_features(extract_uniprot_id(header))

    af_feats = {}
    with ThreadPoolExecutor(max_workers=threads) as pool:
        pending = {pool.submit(fetch_one, h): h for h, _ in entries}
        for done, fut in enumerate(as_completed(pending), start=1):
            header = pending[fut]
            try:
                af_feats[header] = fut.result()
            except Exception:
                # Best-effort: never overwrite an earlier good result.
                af_feats.setdefault(header, None)
            if done % 50 == 0:
                found = sum(1 for v in af_feats.values() if v is not None)
                print(f" {done}/{len(entries)} fetched ({found} with AF models)")
    return af_feats


def _build_rows(entries, embeddings, label_index_map, seq_feats, af_feats, supp_cols):
    """Assemble one output row (dict) per entry that has an embedding."""
    rows = []
    for h, _ in entries:
        emb = embeddings.get(h)
        if emb is None:
            continue
        row = {"Protein_ID": h, "Label_Indices": label_index_map.get(h, [])}
        row.update({f"Dim_{i}": float(v) for i, v in enumerate(emb)})

        sf = seq_feats.get(h, {})
        merged = {k: sf.get(k, 0.0) for k in _SEQ_FEATURE_COLS}
        merged["f_seqfeat_present"] = 1.0
        aff = af_feats.get(h)
        if aff is not None:
            merged.update(aff)

        # Anything not supplied above (no AF model, or no PAE) is written as 0.0.
        row.update({k: float(merged.get(k, 0.0)) for k in supp_cols})
        rows.append(row)
    return rows


def main():
    """Command-line entry point: FASTA in, feature/embedding parquet out.

    Steps: parse and length-filter the FASTA, obtain GO-MF labels (UniProt
    API or --labels_csv), map them to MultiLabelBinarizer class indices,
    compute sequence features, embed with ESM-2, optionally fetch AlphaFold
    features (--fetch_af), then write the parquet and print a summary.
    """
    ap = argparse.ArgumentParser(description="ProtFunc universal taxon data prep")
    ap.add_argument("--fasta", required=True)
    ap.add_argument("--taxon_name", required=True, help="e.g. 'mammals', 'fungi'")
    ap.add_argument("--mlb", required=True, help="Path to mlb_public_v1.pkl")
    ap.add_argument("--out", required=True, help="Output .parquet path")
    ap.add_argument("--labels_csv", default=None,
                    help="CSV with columns uniprot_id,go_ids (semicolon-sep). "
                         "Skips UniProt API fetch if provided.")
    ap.add_argument("--fetch_af", action="store_true",
                    help="Fetch AlphaFold pLDDT/PAE features from AFDB")
    ap.add_argument("--keep_unlabeled", action="store_true",
                    help="Include proteins with 0 GO labels (for inference-only use)")
    ap.add_argument("--max_len", type=int, default=MAX_LEN)
    ap.add_argument("--min_len", type=int, default=MIN_LEN)
    ap.add_argument("--threads", type=int, default=API_THREADS)
    ap.add_argument("--device", default="auto")
    args = ap.parse_args()

    # ThreadPoolExecutor rejects max_workers <= 0.
    threads = max(1, args.threads)

    device = _resolve_device(args.device)
    print(f"Device: {device}")
    print(f"Taxon: {args.taxon_name}")

    # [1/5] Parse and length-filter.
    print(f"\n[1/5] Parsing FASTA: {args.fasta}")
    all_entries = [
        (h, s) for h, s in parse_fasta(args.fasta)
        if args.min_len <= len(s) <= args.max_len
    ]
    print(f" {len(all_entries)} sequences ({args.min_len}-{args.max_len} aa)")
    if not all_entries:
        print(" No sequences in range - exiting.")
        return

    # joblib.load unpickles arbitrary objects: only pass a trusted --mlb file.
    mlb = joblib.load(args.mlb)
    go2idx = {g: i for i, g in enumerate(mlb.classes_)}
    print(f" MLB loaded: {len(mlb.classes_)} GO-MF terms")

    # [2/5] GO labels, from CSV if given, otherwise from the UniProt API.
    print("\n[2/5] GO annotations...")
    if args.labels_csv:
        print(f" Loading from CSV: {args.labels_csv}")
        go_labels = _labels_from_csv(args.labels_csv, all_entries)
    else:
        go_labels = fetch_all_go(all_entries, threads=threads)

    # Keep only terms the binarizer knows, as class indices.
    label_index_map = {
        h: [go2idx[t] for t in go_labels.get(h, []) if t in go2idx]
        for h, _ in all_entries
    }
    n_labeled = sum(1 for v in label_index_map.values() if v)
    print(f" {n_labeled}/{len(all_entries)} sequences have >=1 GO-MF label")

    if args.keep_unlabeled:
        entries = all_entries
        print(f" Keeping all {len(entries)} sequences")
    else:
        entries = [(h, s) for h, s in all_entries if label_index_map.get(h)]
        print(f" Keeping {len(entries)} labeled sequences (use --keep_unlabeled to keep all)")
    if not entries:
        print(" No sequences to embed - exiting.")
        return

    # [3/5] Sequence features.
    print("\n[3/5] Computing sequence features...")
    seq_feats = {h: seq_features(s) for h, s in entries}
    print(f" Done for {len(seq_feats)} sequences")

    # [4/5] Embeddings.
    print("\n[4/5] ESM-2 embeddings...")
    t0 = time.time()
    embeddings = embed_sequences(entries, device)
    print(f" Embedded {len(embeddings)} sequences in {time.time()-t0:.1f}s")

    # [4b/5] Optional AlphaFold features.
    af_feats = {}
    if args.fetch_af:
        print(f"\n[4b/5] Fetching AlphaFold features (parallel, {threads} threads)...")
        af_feats = _fetch_all_af(entries, threads)
        found = sum(1 for v in af_feats.values() if v is not None)
        print(f" AlphaFold coverage: {found}/{len(entries)} ({100*found/max(len(entries),1):.1f}%)")

    # [5/5] Assemble and save.
    print("\n[5/5] Assembling output parquet...")
    SUPP_COLS = _SEQ_FEATURE_COLS + _STRUCT_FEATURE_COLS
    rows = _build_rows(entries, embeddings, label_index_map, seq_feats, af_feats, SUPP_COLS)
    if not rows:
        print(" No rows assembled - exiting.")
        return

    df_out = pd.DataFrame(rows)
    out_path = Path(args.out)
    # parents=True so a nested output directory is created in one go.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    df_out.to_parquet(out_path, index=False)
    print(f" Saved {len(df_out)} proteins -> {out_path}")

    n_with_labels = int((df_out["Label_Indices"].apply(len) > 0).sum())
    n_with_af = int((df_out["f_af_present"] > 0).sum()) if "f_af_present" in df_out else 0
    print(f"\nSummary for '{args.taxon_name}':")
    print(f" Total proteins: {len(df_out)}")
    print(f" With GO-MF labels: {n_with_labels}")
    print(f" With AF structure: {n_with_af}")
    print(f" Output: {out_path}")
    print("\nNext step:")
    print(" python scripts/eval_generalization.py \\")
    print(" --checkpoint artifacts/protfunc_v3_fixed.pth \\")
    print(" --thresholds artifacts/protfunc_v3_fixed_thresholds.json \\")
    print(" --mlb 'Important Files/mlb_public_v1.pkl' \\")
    print(f" --taxon_parquet {out_path} \\")
    print(f" --taxon_name {args.taxon_name}")
|
|
|
|
# Script entry point: run the pipeline only when executed directly.
if __name__ == "__main__":
    main()
|
|