# NOTE: removed non-Python web-page scrape artifacts (hosting-page header,
# commit-hash gutter, and line-number gutter) that made this file unparsable.
import os
import json
import subprocess
import warnings
import sys
import shutil
def sync_protein_metadata(jsonl_path, dict_path):
    """
    Automated metadata sanitization to prevent KeyError (e.g., 'seq_chain_C').

    Prunes chain IDs from the chain-ID dictionary that have no matching
    'seq_chain_X' entry in the parsed-PDB JSONL, then overwrites the
    dictionary file in place for ProteinMPNN to consume.

    Args:
        jsonl_path: Path to the parsed-PDB JSONL (one JSON object per line,
            each with a 'name' key and zero or more 'seq_chain_X' keys).
        dict_path: Path to the chain-ID dict JSON mapping each entry name to
            [redesign_chains, fixed_chains].

    Returns:
        None. Silently returns if either file is missing (best-effort step).
    """
    if not os.path.exists(jsonl_path) or not os.path.exists(dict_path):
        return

    # 1. Identify chains that actually have proteogenic sequence data.
    prefix = 'seq_chain_'
    valid_chains_map = {}
    with open(jsonl_path, 'r') as f:
        for line in f:
            if not line.strip():
                # Tolerate stray blank lines in the JSONL.
                continue
            entry = json.loads(line)
            # Slice off the prefix instead of split('_')[-1] so chain IDs
            # that themselves contain underscores are preserved intact.
            valid_chains_map[entry['name']] = {
                k[len(prefix):] for k in entry if k.startswith(prefix)
            }

    # 2. Clean the chain ID dictionary.
    with open(dict_path, 'r') as f:
        chain_id_dict = json.load(f)
    for pdb_name, configs in chain_id_dict.items():
        if pdb_name not in valid_chains_map:
            continue
        valid = valid_chains_map[pdb_name]
        # configs[0] = redesign list, configs[1] = fixed list
        original_chains = set(configs[0] + configs[1])
        chain_id_dict[pdb_name] = [
            [c for c in configs[0] if c in valid],
            [c for c in configs[1] if c in valid],
        ]
        # Diagnostic feedback
        removed = original_chains - valid
        if removed:
            print(f"🧹 Sanitizer: Pruned non-protein chains from metadata: {removed}")

    # 3. Overwrite with cleaned metadata for ProteinMPNN.
    with open(dict_path, 'w') as f:
        json.dump(chain_id_dict, f)
def run_broteinshake_generator(pdb_path, fixed_chains, variable_chains, num_seqs=20, temp=0.1):
    """
    Run ProteinMPNN sequence design for one PDB structure.

    Clones the ProteinMPNN repository next to this project if absent, then
    drives either a single-chain or a multi-chain design run and writes the
    generated sequences to ./generated/<pdb_name>/.

    Args:
        pdb_path: Path to the input PDB file.
        fixed_chains: Iterable of chain IDs to keep fixed (e.g. "AB" or
            ["A", "B"]). Empty/falsy triggers the single-chain code path.
        variable_chains: Iterable of chain IDs to redesign.
        num_seqs: Sequences to sample per target (default 20).
        temp: Sampling temperature (default 0.1).

    Raises:
        subprocess.CalledProcessError: if git, the parser, or ProteinMPNN
            exits non-zero (all calls use check=True).
    """
    # 1. Setup identifiers and directories.
    # split('.')[0] (not splitext) intentionally truncates at the FIRST dot,
    # matching the name the parse helper emits for multi-dot filenames.
    pdb_name = os.path.basename(pdb_path).split('.')[0]
    output_dir = f"./generated/{pdb_name}"
    os.makedirs(output_dir, exist_ok=True)

    script_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(script_dir)
    proteinmpnn_dir = os.path.join(project_root, "ProteinMPNN")
    if not os.path.exists(proteinmpnn_dir):
        print("ProteinMPNN not found, cloning repository...")
        subprocess.run(["git", "clone", "https://github.com/dauparas/ProteinMPNN.git"], cwd=project_root, check=True)
    mpnn_script = os.path.join(proteinmpnn_dir, "protein_mpnn_run.py")

    # Use the current interpreter (not whatever "python" resolves to on PATH);
    # -W ignore suppresses warnings in the child. Commands are built as argv
    # lists so paths with spaces/shell metacharacters cannot break or inject.
    python_cmd = [sys.executable, "-W", "ignore"]

    # 2. Handle Single vs Multi-Chain Logic.
    if not fixed_chains:
        chain_to_design = variable_chains[0] if variable_chains else "A"
        mpnn_cmd = python_cmd + [
            mpnn_script,
            "--pdb_path", pdb_path,
            "--pdb_path_chains", chain_to_design,
            "--out_folder", output_dir,
            "--num_seq_per_target", str(num_seqs),
            "--sampling_temp", str(temp),
            "--seed", "42",
            "--batch_size", "1",
        ]
        print(f"🚀 Designing {pdb_name} (Single-chain: {chain_to_design})...")
    else:
        # Multi-chain setup
        pdb_dir = os.path.dirname(os.path.abspath(pdb_path)) or "."
        jsonl_path = os.path.join(output_dir, "parsed_pdbs.jsonl")
        parse_script = os.path.join(proteinmpnn_dir, "helper_scripts", "parse_multiple_chains.py")

        # Step A: Parse PDB to JSONL.
        # NOTE(review): --input_path parses EVERY .pdb in pdb_dir, yet only
        # the first JSONL record is kept below — confirm the directory holds
        # a single structure, or filter records by pdb_name.
        subprocess.run(
            python_cmd + [parse_script, f"--input_path={pdb_dir}/", f"--output_path={jsonl_path}"],
            check=True,
        )

        # Step B: Create initial Chain Dictionary.
        pdb_name_clones = f"{pdb_name}_clones"
        # Rename the parsed record so the JSONL name matches the dict key.
        with open(jsonl_path, 'r') as f:
            jsonl_data = json.loads(f.readline())
        jsonl_data['name'] = pdb_name_clones
        with open(jsonl_path, 'w') as f:
            f.write(json.dumps(jsonl_data) + '\n')

        chain_id_json = os.path.join(output_dir, "chain_id_dict.json")
        # [redesign list, fixed list]; list(...) also accepts plain strings
        # like "AB" by splitting them into single-letter chain IDs.
        chain_id_dict = {pdb_name_clones: [list(variable_chains), list(fixed_chains)]}
        with open(chain_id_json, 'w') as f:
            json.dump(chain_id_dict, f)

        # Step C: AUTOMATED CLEANING - prunes ghost chains like 'C'.
        sync_protein_metadata(jsonl_path, chain_id_json)

        # Step D: Final Execution Command.
        mpnn_cmd = python_cmd + [
            mpnn_script,
            "--jsonl_path", jsonl_path,
            "--chain_id_jsonl", chain_id_json,
            "--out_folder", output_dir,
            "--num_seq_per_target", str(num_seqs),
            "--sampling_temp", str(temp),
            "--seed", "42",
        ]
        print(f"🚀 Designing {pdb_name}... (Fixed: {fixed_chains} | Redesign: {variable_chains})")

    # 3. Execute with suppressed warnings (belt-and-braces with -W ignore).
    env = os.environ.copy()
    env['PYTHONWARNINGS'] = 'ignore'
    subprocess.run(mpnn_cmd, check=True, env=env, stderr=subprocess.DEVNULL)
    print(f"✅ Success! Design complete for {pdb_name}.")
if __name__ == "__main__":
    # CLI entry point: positional pdb_path, fixed_chains, variable_chains,
    # then optional num_seqs and sampling temperature.
    if len(sys.argv) < 4:
        print("Usage: python scripts/generator.py <pdb_path> <fixed_chains> <variable_chains> [num_seqs] [temp]")
        sys.exit(1)

    num_seqs = int(sys.argv[4]) if len(sys.argv) > 4 else 20
    sampling_temp = float(sys.argv[5]) if len(sys.argv) > 5 else 0.1
    run_broteinshake_generator(sys.argv[1], sys.argv[2], sys.argv[3], num_seqs, sampling_temp)