BroteinShake / scripts /extract.py
42Cummer's picture
Upload 129 files
32c275c verified
def _iter_fasta_records(lines):
    """Yield ``(header, sequence)`` pairs from an iterable of FASTA lines.

    A record starts at a line beginning with ``>``; every following non-blank
    line up to the next header is concatenated into the sequence, so wrapped
    (multi-line) sequences are reassembled. Lines that appear before the first
    header are ignored.
    """
    header = None
    chunks = []
    for raw in lines:
        line = raw.strip()
        if not line:
            continue
        if line.startswith(">"):
            if header is not None:
                yield header, "".join(chunks)
            header = line
            chunks = []
        elif header is not None:
            chunks.append(line)
    if header is not None:
        yield header, "".join(chunks)


def _parse_score(header):
    """Return the float value of the exact ``score`` field, or ``None``.

    The header is treated as comma-separated ``key=value`` fields, e.g.
    ``score=0.7647``. Only a field whose key is exactly ``score`` is accepted,
    so a field such as ``global_score`` is never mistaken for it. ``None`` is
    returned when the field is missing or its value is not a number.
    """
    for field in header.lstrip(">").split(","):
        key, sep, value = field.partition("=")
        if sep and key.strip() == "score":
            try:
                return float(value)
            except ValueError:
                return None
    return None


def extract_best_design(fasta_file: str, output_file: str) -> None:
    """Write the lowest-scoring sampled design from a FASTA file to a new file.

    Records whose header does not contain the text ``sample`` are skipped
    (this is how the original native sequence entry is excluded). Among the
    remaining records, the one with the smallest ``score`` value is written to
    ``output_file`` as a two-line FASTA entry (header, then sequence), and a
    success message is printed. Records with an unparsable score or an empty
    sequence are ignored.

    Args:
        fasta_file: Path of the input FASTA file.
        output_file: Path of the FASTA file to create; missing parent
            directories are created.

    Raises:
        FileNotFoundError: If ``fasta_file`` does not exist.
        ValueError: If no record with a usable score and sequence is found.
    """
    import os

    if not os.path.exists(fasta_file):
        raise FileNotFoundError(f"Input FASTA file not found: {fasta_file}")

    best_score = float('inf')
    best_header = ""
    best_seq = ""

    with open(fasta_file, 'r', encoding="utf-8") as f:
        for header, sequence in _iter_fasta_records(f):
            # Skip the original native sequence (it carries no "sample" field)
            # and any header that has no sequence attached.
            if "sample" not in header or not sequence:
                continue
            score = _parse_score(header)
            # Strict "<" keeps the first record on ties; a NaN score never wins.
            if score is not None and score < best_score:
                best_score = score
                best_header = header
                best_seq = sequence

    if not best_seq:
        raise ValueError(f"No valid designs found in {fasta_file}")

    # A bare filename has an empty dirname; fall back to the current directory.
    os.makedirs(os.path.dirname(output_file) or '.', exist_ok=True)
    with open(output_file, 'w', encoding="utf-8") as f:
        f.write(f"{best_header}\n{best_seq}\n")
    print(f"✅ Success! Best design (score={best_score:.4f}) saved to {output_file}")
if __name__ == "__main__":
    # Command-line entry point: read the input/output FASTA paths from the
    # arguments (both have defaults) and run the extraction once.
    import argparse

    cli = argparse.ArgumentParser(
        description="Extract best design from ProteinMPNN output",
    )
    cli.add_argument(
        "--input_fa",
        type=str,
        default="generated/3kas/seqs/3kas_clones.fa",
        help="Input FASTA file path (relative to project root)",
    )
    cli.add_argument(
        "--output_fa",
        type=str,
        default="generated/shuttle/best_shuttle.fa",
        help="Output FASTA file path (relative to project root)",
    )
    options = cli.parse_args()
    extract_best_design(options.input_fa, options.output_fa)