# protfunc / scripts/prep_taxon.py
# Committed by Sbhat2026 (7f7a890): "perf: ESM embedding cache + 1500aa limit, add research scripts"
"""
prep_taxon.py — Universal taxon data preparation pipeline for ProtFunc
=======================================================================
Given a FASTA file for ANY taxonomic group, this script produces a
standard-format parquet ready for eval_generalization.py or fine-tuning.
Steps
-----
1. Parse FASTA (filters by sequence length: 30–2500 aa)
2. Fetch UniProt GO-MF annotations (experimental evidence, parallel)
3. Compute sequence features (11 physicochemical + proxy features)
4. Embed sequences with ESM-2 (t6, 8M params, 320-dim)
5. Optionally fetch AlphaFold pLDDT/PAE via AFDB API (--fetch_af flag)
6. Save parquet with same schema as mammal_embeddings_v3.parquet
Output format
-------------
Columns: Protein_ID, Label_Indices, Dim_0..Dim_319,
f_seq_len, f_mean_hydro, f_net_charge, f_uversky_disorder,
f_idr_frac_proxy, f_lowcomp_proxy, f_tm_frac_proxy, f_tm_any_proxy,
f_signal_peptide_proxy, f_cf_helix_mean, f_cf_sheet_mean,
f_afdb_has_model,
f_plddt_mean, ..., f_pae_frac_gt20,
f_seqfeat_present, f_af_present
Usage
-----
# Minimal: ESM embeddings + GO labels only
python scripts/prep_taxon.py \\
--fasta "Important Files/mammal_subset.fasta" \\
--taxon_name mammals \\
--mlb "Important Files/mlb_public_v1.pkl" \\
--out artifacts/mammals_prepped.parquet
# Full: include AlphaFold structural features
python scripts/prep_taxon.py \\
--fasta mygroup.fasta \\
--taxon_name fungi \\
--mlb "Important Files/mlb_public_v1.pkl" \\
--out artifacts/fungi_prepped.parquet \\
--fetch_af
# Skip GO fetch if you already have labels (label CSV with columns uniprot_id,go_ids; go_ids semicolon-separated)
python scripts/prep_taxon.py \\
--fasta mygroup.fasta \\
--taxon_name archaea \\
--mlb "Important Files/mlb_public_v1.pkl" \\
--labels_csv my_labels.csv \\
--out artifacts/archaea_prepped.parquet
# Keep all sequences (not just experimentally labeled) for inference-only
python scripts/prep_taxon.py \\
--fasta mygroup.fasta \\
--taxon_name plants \\
--mlb "Important Files/mlb_public_v1.pkl" \\
--out artifacts/plants_prepped.parquet \\
--keep_unlabeled
"""
import argparse
import json
import math
import re
import threading
import time
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import joblib
import numpy as np
import pandas as pd
import requests
import torch
warnings.filterwarnings("ignore")

# ── Constants ─────────────────────────────────────────────────────────────────
# Embedding width of ESM-2 t6 (8M params), as listed in the module docstring.
ESM_DIM = 320

# Thread-pool size and per-request timeout (seconds) for the UniProt/AFDB calls.
API_THREADS = 20
API_TIMEOUT = 12

# GO evidence codes treated as "experimental". TAS and IC are author/curator
# statements rather than wet-lab evidence; including them mirrors the CAFA
# convention and looks intentional. Confirm before changing, because it alters labels.
EXP_CODES = set("IDA IMP IPI IGI IEP EXP HDA HMP HGI HEP TAS IC".split())

# Default inclusive sequence-length window (residues) applied after FASTA parsing.
MIN_LEN = 30
MAX_LEN = 2500

# Kyte-Doolittle hydropathy score for each of the 20 standard amino acids.
_KD = {
    "A": 1.8,  "C": 2.5,  "D": -3.5, "E": -3.5, "F": 2.8,
    "G": -0.4, "H": -3.2, "I": 4.5,  "K": -3.9, "L": 3.8,
    "M": 1.9,  "N": -3.5, "P": -1.6, "Q": -3.5, "R": -4.5,
    "S": -0.8, "T": -0.7, "V": 4.2,  "W": -0.9, "Y": -1.3,
}

# Formal side-chain charge (His treated as neutral).
_CHARGE = dict(K=1, R=1, D=-1, E=-1)

# Polar/charged residues, used as a disorder-prone proxy set.
_POLAR = {"D", "E", "H", "K", "N", "Q", "R", "S", "T"}

# Strongly hydrophobic residues, used for the TM-helix / signal-peptide proxies.
_TM_HYDRO = set("ILVFMAW")
# ── FASTA parser ─────────────────────────────────────────────────────────────
def parse_fasta(path):
    """
    Yield (header, sequence) pairs from a FASTA file.

    ``header`` is the rest of the '>' line (the line is stripped first).
    Sequence lines are concatenated and normalized for the embedding step:
    all whitespace is removed, residues are upper-cased, and a trailing stop
    marker ('*') is dropped. The ESM alphabet only defines upper-case residue
    tokens and has no '*', so soft-masked or translated FASTA would otherwise
    fail at tokenization time.

    Lines before the first header are ignored. A header with no sequence
    lines yields an empty sequence.
    """
    def _clean(chunks):
        # Join, drop any embedded whitespace, upper-case, strip a terminal '*'.
        return "".join("".join(chunks).split()).upper().rstrip("*")

    header, chunks = None, []
    with open(path, encoding="utf-8") as fh:
        for raw in fh:
            line = raw.strip()
            if line.startswith(">"):
                if header is not None:
                    yield header, _clean(chunks)
                header, chunks = line[1:], []
            elif header is not None:
                chunks.append(line)
    if header is not None:
        yield header, _clean(chunks)
def extract_uniprot_id(header: str) -> str:
    """
    Extract a UniProt accession from a FASTA header.

    Forms are tried in this order:
      * UniProtKB style ``sp|P12345|GENE_HUMAN ...`` or ``tr|A0A123...|...``.
        An isoform suffix (``sp|P12345-2|...``) is dropped and the base
        accession is returned, matching what the plain-accession branch below
        already does for ``P12345-2``.
      * A first whitespace-separated token that starts with an accession-like
        word: an upper-case letter followed by 5-9 upper-case letters/digits.

    Otherwise the first token is returned unchanged, or "" for an empty or
    blank header (e.g. a bare '>' line).
    """
    m = re.search(r"(?:sp|tr)\|([A-Z0-9]+)(?:-\d+)?\|", header)
    if m:
        return m.group(1)
    tokens = header.split()
    if not tokens:
        return ""
    m = re.match(r"([A-Z][A-Z0-9]{5,9})\b", tokens[0])
    if m:
        return m.group(1)
    return tokens[0]
# ── UniProt GO annotation fetch ───────────────────────────────────────────────
def fetch_go_mf(uniprot_id: str, exp_codes=EXP_CODES, timeout=API_TIMEOUT,
                retries: int = 2, backoff: float = 1.0) -> list:
    """
    Fetch experimentally supported GO molecular-function terms for one accession.

    The function queries the UniProt REST entry endpoint (GO-MF fields only).
    It keeps GO cross-references whose ``GoTerm`` starts with "F:" and whose
    evidence code (the part of ``GoEvidenceType`` before ':') is in *exp_codes*.

    Transient failures are retried up to *retries* more times, waiting
    ``backoff * 2**attempt`` seconds between attempts. These are network
    errors, timeouts, and HTTP 429/500/502/503/504 replies. Without retries,
    a burst of rate limiting under many threads would silently turn labelled
    proteins into unlabelled ones. Any other non-200 reply (e.g. 404 for an
    unknown accession) returns immediately.

    Args:
        uniprot_id: UniProtKB accession.
        exp_codes: accepted GO evidence codes.
        timeout: per-request timeout in seconds.
        retries: extra attempts after the first, for transient failures only.
        backoff: base delay in seconds between attempts.

    Returns:
        GO IDs in first-seen order, without duplicates. Returns [] on any
        failure; this is deliberate best-effort behavior, since the caller
        treats [] as "no labels".
    """
    url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.json?fields=go_f"
    retryable_status = {429, 500, 502, 503, 504}
    attempts = max(int(retries), 0) + 1
    resp = None
    for attempt in range(attempts):
        try:
            resp = requests.get(url, timeout=timeout)
        except requests.RequestException:
            resp = None  # connection error / timeout: treat as transient
        if resp is not None and resp.status_code == 200:
            break
        if resp is not None and resp.status_code not in retryable_status:
            return []  # permanent failure; retrying cannot help
        if attempt + 1 < attempts:
            time.sleep(backoff * (2 ** attempt))
    else:
        return []  # every attempt failed transiently

    try:
        refs = resp.json().get("uniProtKBCrossReferences", [])
        terms, seen = [], set()
        for ref in refs:
            if ref.get("database") != "GO":
                continue
            props = {p.get("key"): p.get("value") for p in ref.get("properties", [])}
            aspect = props.get("GoTerm") or ""
            evcode = (props.get("GoEvidenceType") or "").split(":")[0]
            go_id = ref.get("id", "")
            if aspect.startswith("F:") and evcode in exp_codes and go_id not in seen:
                seen.add(go_id)
                terms.append(go_id)
        return terms
    except (ValueError, AttributeError, TypeError):
        # Invalid JSON or an unexpected payload shape: no usable labels.
        return []
def fetch_all_go(entries: list, threads=API_THREADS) -> dict:
    """
    Look up GO-MF annotations for every FASTA entry concurrently.

    Args:
        entries: list of (header, seq) tuples. Only the header is used, to
            derive the UniProt accession.
        threads: worker-thread count for the HTTP requests.

    Returns:
        {header: [GO IDs]}. An entry whose worker raised is simply absent
        from the result.
    """
    print(f" Fetching GO annotations ({threads} threads)...")
    collected = {}
    guard = threading.Lock()
    total = len(entries)

    def _worker(entry):
        header, _ = entry
        found = fetch_go_mf(extract_uniprot_id(header))
        with guard:
            collected[header] = found

    with ThreadPoolExecutor(max_workers=threads) as pool:
        pending = [pool.submit(_worker, entry) for entry in entries]
        for done_count, future in enumerate(as_completed(pending), start=1):
            if done_count % 100 == 0:
                print(f" {done_count}/{total} done")
            try:
                future.result()
            except Exception:
                # Best effort: a failed lookup leaves the header unlabelled.
                pass
    return collected
# ── Sequence feature computation ─────────────────────────────────────────────
def seq_features(seq: str) -> dict:
    """
    Compute the 11 sequence-only features stored as the first 11 supplementary columns.

    The sequence is upper-cased first. Every value is a plain float, and
    degenerate input never raises; an empty sequence yields all zeros.
    Residues outside the tables simply contribute 0 to the hydropathy and
    charge sums.

    Features (fractions are per residue):
      f_seq_len               log1p(length)
      f_mean_hydro            mean Kyte-Doolittle score
      f_net_charge            (#K + #R - #D - #E) / length
      f_uversky_disorder      max(0, charged fraction (DEKR) - fraction with KD > 0)
      f_idr_frac_proxy        fraction of polar/charged residues (DEHKNQRST)
      f_lowcomp_proxy         longest run of one repeated residue / length
      f_tm_frac_proxy         fraction of I/L/V/F/M/A/W
      f_tm_any_proxy          1.0 if f_tm_frac_proxy > 0.55 else 0.0
      f_signal_peptide_proxy  1.0 if > 40% of the first 30 residues are I/L/V/F/M/A/W
      f_cf_helix_mean         fraction of A/E/L/M (simplified Chou-Fasman helix formers)
      f_cf_sheet_mean         fraction of C/F/I/T/V/Y (simplified sheet formers)
    """
    s = seq.upper()
    n = len(s)
    denom = max(n, 1)  # avoids division by zero; fractions are 0.0 when n == 0

    def frac(members) -> float:
        """Fraction of residues of s that belong to *members*."""
        return sum(1 for c in s if c in members) / denom

    # Residues with a positive Kyte-Doolittle score count as "hydrophobic".
    kd_positive = {aa for aa, score in _KD.items() if score > 0}

    f_seq_len = math.log1p(n)
    f_mean_hydro = float(np.mean([_KD.get(c, 0.0) for c in s])) if n else 0.0
    f_net_charge = sum(_CHARGE.get(c, 0) for c in s) / denom

    # Uversky-style disorder proxy: charged fraction minus hydrophobic
    # fraction, clipped at 0.
    f_uversky_disorder = max(0.0, frac("DEKR") - frac(kd_positive))

    f_idr_frac_proxy = frac(_POLAR)

    # Low-complexity proxy: longest single-residue run. With no repeats the
    # longest run is 1; for an empty sequence it is 0.
    longest_run = max(
        (len(m.group()) for m in re.finditer(r"(.)\1+", s)),
        default=min(n, 1),
    )
    f_lowcomp_proxy = longest_run / denom

    f_tm_frac_proxy = frac(_TM_HYDRO)
    f_tm_any_proxy = float(f_tm_frac_proxy > 0.55)

    # Signal-peptide proxy: hydrophobic enrichment in the N-terminal 30 residues.
    nt = s[:30]
    nt_hydro = sum(1 for c in nt if c in _TM_HYDRO) / max(len(nt), 1)
    f_signal_peptide_proxy = float(nt_hydro > 0.40)

    f_cf_helix_mean = frac(set("AELM"))
    f_cf_sheet_mean = frac(set("CFITVY"))

    return {
        "f_seq_len": f_seq_len,
        "f_mean_hydro": f_mean_hydro,
        "f_net_charge": f_net_charge,
        "f_uversky_disorder": f_uversky_disorder,
        "f_idr_frac_proxy": f_idr_frac_proxy,
        "f_lowcomp_proxy": f_lowcomp_proxy,
        "f_tm_frac_proxy": f_tm_frac_proxy,
        "f_tm_any_proxy": f_tm_any_proxy,
        "f_signal_peptide_proxy": f_signal_peptide_proxy,
        "f_cf_helix_mean": f_cf_helix_mean,
        "f_cf_sheet_mean": f_cf_sheet_mean,
    }
# ── AlphaFold feature fetch ───────────────────────────────────────────────────
def _af_get_json(url: str, timeout: float):
    """GET *url* and return the decoded JSON body, or None on a non-200 reply or invalid JSON."""
    r = requests.get(url, timeout=timeout)
    if r.status_code != 200:
        return None
    try:
        return r.json()
    except ValueError:
        return None


def _af_confidence_url(entry: dict) -> str:
    """Return the per-residue pLDDT JSON URL for one AFDB API entry, or "" if none is found."""
    # Explicit URL fields first. These key names come from the previous
    # implementation. NOTE(review): confirm which of them the AFDB API emits today.
    for key in ("plddtDocUrl", "confidenceUrl", "plddt_url"):
        url = entry.get(key)
        if url:
            return url
    # Otherwise derive the URL from the model file name:
    #   .../AF-<acc>-F1-model_v4.cif  ->  .../AF-<acc>-F1-confidence_v4.json
    cif_url = entry.get("cifUrl") or ""
    if "-model_v" in cif_url and cif_url.endswith(".cif"):
        return cif_url[: -len(".cif")].replace("-model_v", "-confidence_v") + ".json"
    return ""


def _af_pae_url(entry: dict) -> str:
    """Return the PAE JSON URL for one AFDB API entry, or "" if none is found."""
    url = entry.get("paeDocUrl") or ""
    if url:
        return url
    # The PAE image and JSON share a base name: ..._v4.png -> ..._v4.json
    img = entry.get("paeImageUrl") or ""
    if img.endswith(".png"):
        return img[: -len(".png")] + ".json"
    return ""


def _af_parse_plddt(payload):
    """Return pLDDT values as a flat float32 array, or None if *payload* has none."""
    # The confidence JSON is an object whose "confidenceScore" list holds the
    # per-residue scores. A bare list is also accepted.
    if isinstance(payload, dict):
        payload = payload.get("confidenceScore")
    if payload is None:
        return None
    vals = np.asarray(payload, dtype=np.float32).ravel()
    return vals if vals.size else None


def _af_parse_pae(payload):
    """Return the flattened PAE matrix as float32, or None if *payload* has none."""
    # The PAE JSON is a one-element list wrapping an object. A bare object is also accepted.
    if isinstance(payload, list):
        payload = payload[0] if payload else None
    if not isinstance(payload, dict):
        return None
    matrix = payload.get("predicted_aligned_error")
    if matrix is None:
        return None
    vals = np.asarray(matrix, dtype=np.float32).ravel()
    return vals if vals.size else None


def fetch_af_features(uniprot_id: str, timeout=20) -> dict | None:
    """
    Fetch pLDDT and PAE summary statistics from AlphaFold DB.

    The function queries the AFDB prediction API for *uniprot_id* and takes
    the first entry. It then downloads that entry's per-residue confidence
    (pLDDT) JSON and, optionally, its PAE JSON.

    Returns:
        A feature dict, or None if there is no model, no pLDDT data, or a
        network/parsing error occurs. The dict contains:
          * f_afdb_has_model, f_seqfeat_present and f_af_present (all 1.0);
          * pLDDT mean/std/q10/q50/q90 and the fractions >90, >70, <50;
          * f_distbin_0..9: fraction of residues in each 10-point pLDDT bin on [0, 100];
          * PAE mean/median/p90/p95 and the fractions <5, <10, >20, but
            only when a PAE matrix could be fetched.
        PAE is optional: a failed PAE download keeps the pLDDT features.
    """
    base = f"https://alphafold.ebi.ac.uk/api/prediction/{uniprot_id}"
    try:
        data = _af_get_json(base, timeout)
        if not data:
            return None
        entry = data[0] if isinstance(data, list) else data
        if not isinstance(entry, dict):
            return None
        conf_url = _af_confidence_url(entry)
        if not conf_url:
            return None
        p = _af_parse_plddt(_af_get_json(conf_url, timeout))
    except (requests.RequestException, ValueError, TypeError):
        return None
    if p is None:
        return None

    pae = None
    pae_url = _af_pae_url(entry)
    if pae_url.endswith(".json"):
        try:
            pae = _af_parse_pae(_af_get_json(pae_url, timeout))
        except (requests.RequestException, ValueError, TypeError):
            pae = None  # optional: keep the pLDDT features

    feat = {
        "f_afdb_has_model": 1.0,
        "f_plddt_mean": float(np.mean(p)),
        "f_plddt_std": float(np.std(p)),
        "f_plddt_q10": float(np.percentile(p, 10)),
        "f_plddt_q50": float(np.percentile(p, 50)),
        "f_plddt_q90": float(np.percentile(p, 90)),
        "f_plddt_frac_gt90": float((p > 90).mean()),
        "f_plddt_frac_gt70": float((p > 70).mean()),
        "f_plddt_frac_lt50": float((p < 50).mean()),
    }
    # pLDDT histogram in 10-point bins: the shape of the confidence distribution.
    hist, _ = np.histogram(p, bins=np.linspace(0, 100, 11))
    for i, count in enumerate(hist):
        feat[f"f_distbin_{i}"] = float(count) / max(len(p), 1)

    if pae is not None:
        feat.update({
            "f_pae_mean": float(np.mean(pae)),
            "f_pae_median": float(np.median(pae)),
            "f_pae_p90": float(np.percentile(pae, 90)),
            "f_pae_p95": float(np.percentile(pae, 95)),
            "f_pae_frac_lt5": float((pae < 5).mean()),
            "f_pae_frac_lt10": float((pae < 10).mean()),
            "f_pae_frac_gt20": float((pae > 20).mean()),
        })
    feat["f_seqfeat_present"] = 1.0
    feat["f_af_present"] = 1.0
    return feat
# ── ESM-2 embedding ───────────────────────────────────────────────────────────
def embed_sequences(entries: list, device: torch.device,
                    token_budget: int = 6000) -> dict:
    """
    Mean-pool ESM-2 (esm2_t6_8M_UR50D) layer-6 representations for each sequence.

    Args:
        entries: list of (header, seq).
        device: torch device to run the model on.
        token_budget: cap on the *padded* size of a batch, i.e.
            rows * (longest sequence in the batch + 2 special tokens).
            A single sequence longer than the budget is still embedded, on its own.

    Returns:
        {header: np.ndarray of shape (320,)}

    Sequences are processed shortest-first. The batch converter pads every
    row to the longest sequence in the batch, so budgeting on unpadded
    lengths (as before) let one long protein plus many short ones allocate
    many times the intended number of tokens. Residues are upper-cased before
    tokenization because the ESM alphabet only has upper-case residue tokens.
    """
    import esm as esm_lib

    print(" Loading ESM-2 (esm2_t6_8M_UR50D)...")
    esm_model, alphabet = esm_lib.pretrained.esm2_t6_8M_UR50D()
    esm_model = esm_model.to(device).eval()
    batch_converter = alphabet.get_batch_converter()
    results = {}

    def flush(batch):
        _, _, toks = batch_converter([(h, s.upper()) for h, s in batch])
        with torch.no_grad():
            rep = esm_model(toks.to(device), repr_layers=[6])["representations"][6]
        for k, (h, s) in enumerate(batch):
            # Position 0 is the prepended start token; residues occupy 1..len(s).
            results[h] = rep[k, 1:len(s) + 1].mean(0).cpu().numpy()

    batch = []
    for h, s in sorted(entries, key=lambda e: len(e[1])):
        ntok = len(s) + 2
        # Ascending order means ntok is the batch's longest row, so the
        # padded size after adding this sequence is (rows + 1) * ntok.
        if batch and (len(batch) + 1) * ntok > token_budget:
            flush(batch)
            batch = []
        batch.append((h, s))
    if batch:
        flush(batch)
    return results
# ── Main ──────────────────────────────────────────────────────────────────────
# Sequence-only feature columns produced by seq_features() for every protein.
_SEQ_FEAT_COLS = [
    "f_seq_len", "f_mean_hydro", "f_net_charge", "f_uversky_disorder",
    "f_idr_frac_proxy", "f_lowcomp_proxy", "f_tm_frac_proxy", "f_tm_any_proxy",
    "f_signal_peptide_proxy", "f_cf_helix_mean", "f_cf_sheet_mean",
]

# Full supplementary-feature schema, in output column order. Any column a
# protein has no value for is written as 0.0, for example AlphaFold features
# when --fetch_af is off or AFDB has no model/PAE.
_SUPP_COLS = _SEQ_FEAT_COLS + [
    "f_afdb_has_model",
    "f_plddt_mean", "f_plddt_std", "f_plddt_q10", "f_plddt_q50", "f_plddt_q90",
    "f_plddt_frac_gt90", "f_plddt_frac_gt70", "f_plddt_frac_lt50",
    "f_distbin_0", "f_distbin_1", "f_distbin_2", "f_distbin_3", "f_distbin_4",
    "f_distbin_5", "f_distbin_6", "f_distbin_7", "f_distbin_8", "f_distbin_9",
    "f_pae_mean", "f_pae_median", "f_pae_p90", "f_pae_p95",
    "f_pae_frac_lt5", "f_pae_frac_lt10", "f_pae_frac_gt20",
    "f_seqfeat_present", "f_af_present",
]


def _resolve_device(name: str) -> torch.device:
    """Map the --device flag to a torch.device; "auto" prefers MPS, then CUDA, then CPU."""
    if name != "auto":
        return torch.device(name)
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def _load_label_csv(path: str) -> dict:
    """Read a labels CSV (uniprot_id, go_ids with ';'-separated GO IDs) into {accession: [GO IDs]}."""
    df_lab = pd.read_csv(path)
    raw_go = df_lab["go_ids"] if "go_ids" in df_lab.columns else [""] * len(df_lab)
    labels = {}
    for uid, raw in zip(df_lab["uniprot_id"].astype(str).str.strip(), raw_go):
        # Empty cells arrive as NaN; treat them as "no labels", not the string "nan".
        text = "" if pd.isna(raw) else str(raw)
        labels[uid] = [g.strip() for g in text.split(";") if g.strip()]
    return labels


def _fetch_af_parallel(entries: list, threads: int) -> dict:
    """
    Fetch AlphaFold features for every (header, seq) entry with a thread pool.

    Results are stored by the calling thread as futures complete. Workers
    never mutate the shared dict, so progress counting cannot race with
    insertions. Returns {header: feature dict or None}; a worker that raised
    maps to None.
    """
    def fetch_one(header):
        return fetch_af_features(extract_uniprot_id(header))

    af_feats, found = {}, 0
    with ThreadPoolExecutor(max_workers=threads) as ex:
        futs = {ex.submit(fetch_one, h): h for h, _ in entries}
        for i, fut in enumerate(as_completed(futs), start=1):
            try:
                feat = fut.result()
            except Exception:
                feat = None  # best effort: a failed fetch just means "no AF model"
            af_feats[futs[fut]] = feat
            if feat is not None:
                found += 1
            if i % 50 == 0:
                print(f" {i}/{len(entries)} fetched ({found} with AF models)")
    return af_feats


def _build_row(header: str, emb, label_idxs: list, seq_feat: dict, af_feat) -> dict:
    """Assemble one output row: Protein_ID, Label_Indices, Dim_* embedding values, then _SUPP_COLS."""
    row = {"Protein_ID": header, "Label_Indices": label_idxs}
    for i, v in enumerate(emb):
        row[f"Dim_{i}"] = float(v)
    merged = {k: seq_feat.get(k, 0.0) for k in _SEQ_FEAT_COLS}
    merged["f_seqfeat_present"] = 1.0
    if af_feat is not None:
        merged.update(af_feat)
    for k in _SUPP_COLS:
        row[k] = float(merged.get(k, 0.0))
    return row


def main():
    """CLI entry point: FASTA -> GO labels + sequence/ESM(/AlphaFold) features -> parquet."""
    ap = argparse.ArgumentParser(description="ProtFunc universal taxon data prep")
    ap.add_argument("--fasta", required=True)
    ap.add_argument("--taxon_name", required=True, help="e.g. 'mammals', 'fungi'")
    ap.add_argument("--mlb", required=True, help="Path to mlb_public_v1.pkl")
    ap.add_argument("--out", required=True, help="Output .parquet path")
    ap.add_argument("--labels_csv", default=None,
                    help="CSV with columns uniprot_id,go_ids (semicolon-sep). "
                         "Skips UniProt API fetch if provided.")
    ap.add_argument("--fetch_af", action="store_true",
                    help="Fetch AlphaFold pLDDT/PAE features from AFDB")
    ap.add_argument("--keep_unlabeled", action="store_true",
                    help="Include proteins with 0 GO labels (for inference-only use)")
    ap.add_argument("--max_len", type=int, default=MAX_LEN)
    ap.add_argument("--min_len", type=int, default=MIN_LEN)
    ap.add_argument("--threads", type=int, default=API_THREADS)
    ap.add_argument("--device", default="auto")
    args = ap.parse_args()
    if args.min_len > args.max_len:
        ap.error(f"--min_len ({args.min_len}) must not exceed --max_len ({args.max_len})")
    if args.threads < 1:
        ap.error("--threads must be at least 1")

    device = _resolve_device(args.device)
    print(f"Device: {device}")
    print(f"Taxon: {args.taxon_name}")

    # ── Parse FASTA ──────────────────────────────────────────────────────────
    print(f"\n[1/5] Parsing FASTA: {args.fasta}")
    all_entries = [
        (h, s) for h, s in parse_fasta(args.fasta)
        if args.min_len <= len(s) <= args.max_len
    ]
    print(f" {len(all_entries)} sequences ({args.min_len}–{args.max_len} aa)")
    if not all_entries:
        print(" No sequences in range — exiting.")
        return

    # ── Load MLB ─────────────────────────────────────────────────────────────
    # joblib.load unpickles the file: only point --mlb at a trusted artifact.
    mlb = joblib.load(args.mlb)
    go2idx = {g: i for i, g in enumerate(mlb.classes_)}
    print(f" MLB loaded: {len(mlb.classes_)} GO-MF terms")

    # ── GO annotations ───────────────────────────────────────────────────────
    print("\n[2/5] GO annotations...")
    if args.labels_csv:
        print(f" Loading from CSV: {args.labels_csv}")
        ext_labels = _load_label_csv(args.labels_csv)
        go_labels = {h: ext_labels.get(extract_uniprot_id(h), []) for h, _ in all_entries}
    else:
        go_labels = fetch_all_go(all_entries, threads=args.threads)

    # GO IDs unknown to the MLB vocabulary are dropped.
    label_index_map = {
        h: [go2idx[t] for t in go_labels.get(h, []) if t in go2idx]
        for h, _ in all_entries
    }
    n_labeled = sum(1 for v in label_index_map.values() if v)
    print(f" {n_labeled}/{len(all_entries)} sequences have ≥1 GO-MF label")

    if args.keep_unlabeled:
        entries = all_entries
        print(f" Keeping all {len(entries)} sequences")
    else:
        entries = [(h, s) for h, s in all_entries if label_index_map.get(h)]
        print(f" Keeping {len(entries)} labeled sequences (use --keep_unlabeled to keep all)")
    if not entries:
        print(" No sequences to embed — exiting.")
        return

    # ── Sequence features ────────────────────────────────────────────────────
    print("\n[3/5] Computing sequence features...")
    seq_feats = {h: seq_features(s) for h, s in entries}
    print(f" Done for {len(seq_feats)} sequences")

    # ── ESM-2 embeddings ─────────────────────────────────────────────────────
    print("\n[4/5] ESM-2 embeddings...")
    t0 = time.perf_counter()
    embeddings = embed_sequences(entries, device)
    print(f" Embedded {len(embeddings)} sequences in {time.perf_counter()-t0:.1f}s")

    # ── AlphaFold features (optional) ────────────────────────────────────────
    af_feats = {}
    if args.fetch_af:
        print(f"\n[4b/5] Fetching AlphaFold features (parallel, {args.threads} threads)...")
        af_feats = _fetch_af_parallel(entries, args.threads)
        found = sum(1 for v in af_feats.values() if v is not None)
        print(f" AlphaFold coverage: {found}/{len(entries)} ({100*found/max(len(entries),1):.1f}%)")

    # ── Assemble parquet ─────────────────────────────────────────────────────
    print("\n[5/5] Assembling output parquet...")
    rows = []
    for h, _ in entries:
        emb = embeddings.get(h)
        if emb is None:
            continue
        rows.append(_build_row(h, emb, label_index_map.get(h, []),
                               seq_feats.get(h, {}), af_feats.get(h)))
    if rows:
        df_out = pd.DataFrame(rows)
    else:
        # Keep the documented schema even when nothing was embedded.
        df_out = pd.DataFrame(columns=["Protein_ID", "Label_Indices"]
                              + [f"Dim_{i}" for i in range(ESM_DIM)] + _SUPP_COLS)

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    df_out.to_parquet(out_path, index=False)
    print(f" Saved {len(df_out)} proteins → {out_path}")

    # ── Summary ──────────────────────────────────────────────────────────────
    n_with_labels = int((df_out["Label_Indices"].apply(len) > 0).sum())
    n_with_af = int((df_out["f_af_present"] > 0).sum()) if "f_af_present" in df_out else 0
    print(f"\nSummary for '{args.taxon_name}':")
    print(f" Total proteins: {len(df_out)}")
    print(f" With GO-MF labels: {n_with_labels}")
    print(f" With AF structure: {n_with_af}")
    print(f" Output: {out_path}")
    print("\nNext step:")
    print(" python scripts/eval_generalization.py \\")
    print(" --checkpoint artifacts/protfunc_v3_fixed.pth \\")
    print(" --thresholds artifacts/protfunc_v3_fixed_thresholds.json \\")
    print(" --mlb 'Important Files/mlb_public_v1.pkl' \\")
    print(f" --taxon_parquet {out_path} \\")
    print(f" --taxon_name {args.taxon_name}")


if __name__ == "__main__":
    main()