# RFdiffusion3 / utils / pipelines.py
# Author: gabboud
# Commit message: integrate download zip logic inside generation_with_input_config, remove gr.update
# Commit: 4a67952
from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
import gradio as gr
from lightning.fabric import seed_everything
import time
import os
import spaces
import subprocess
import gzip
from utils.handle_files import *
import sys
import shutil
from time import perf_counter
import random
def get_duration(input_file, pdb_file, num_batches, num_designs_per_batch, extra_args, max_duration):
    """Report the GPU time budget requested for a generation call.

    Used as the dynamic ``duration`` callback of ``spaces.GPU``: it receives the
    same arguments as ``generation_with_input_config`` and simply hands back the
    user-selected ``max_duration``. All other arguments are accepted only so the
    signature matches the decorated function; they are ignored.
    """
    budget = max_duration
    return budget
def _build_rfd3_command(config_path, out_dir, num_batches, num_designs_per_batch, extra_args):
    """Return the ``rfd3 design`` invocation as an argv list (executed without a shell).

    Raises ValueError if ``extra_args`` has unbalanced quotes.
    """
    import shlex

    argv = [
        "rfd3",
        "design",
        f"inputs={config_path}",
        f"out_dir={out_dir}",
        f"n_batches={num_batches}",
        f"diffusion_batch_size={num_designs_per_batch}",
    ]
    if extra_args:
        # shlex.split honours shell-style quoting but never interprets shell syntax
        # (";", "|", "$(...)"), so user-supplied text cannot inject extra commands.
        argv.extend(shlex.split(extra_args))
    return argv


def _collect_designs(directory):
    """Convert every ``*_<batch>_model_<design>.cif.gz`` in *directory* to PDB and describe it.

    Files ending in ``.cif.gz`` that do not follow this naming pattern are skipped.
    The returned list is sorted by (batch, design).
    """
    suffix = ".cif.gz"
    results = []
    for file_name in sorted(os.listdir(directory)):
        if not file_name.endswith(suffix):
            continue
        # Strip only the known suffix so dots elsewhere in the name are preserved.
        terms = file_name[: -len(suffix)].split("_")
        if "model" not in terms:
            continue
        model_index = terms.index("model")
        # "model" needs a batch number before it and a design number after it.
        if not 0 < model_index < len(terms) - 1:
            continue
        try:
            batch = int(terms[model_index - 1])
            design = int(terms[model_index + 1])
        except ValueError:
            continue
        cif_path = os.path.join(directory, file_name)
        pdb_path = mcif_gz_to_pdb(cif_path)
        results.append({"batch": batch, "design": design, "cif_path": cif_path, "pdb_path": pdb_path})
    results.sort(key=lambda entry: (entry["batch"], entry["design"]))
    return results


@spaces.GPU(duration=get_duration)
def generation_with_input_config(input_file, pdb_file, num_batches, num_designs_per_batch, extra_args, max_duration):
    """
    Run ``rfd3 design`` with an uploaded config file and an optional scaffold/target PDB.

    Generated structures are written to a fresh session directory under
    ``./outputs/generation_with_input_config``. Each ``.cif.gz`` design is converted
    to PDB, and the whole directory is zipped for download.

    Parameters:
    ----------
    input_file: str or None,
        Filesystem path of the uploaded job configuration (.yaml or .json).
    pdb_file: str or None,
        Filesystem path of the uploaded scaffold/target PDB, or None for no conditioning.
    num_batches: int,
        Value passed to rfd3 as ``n_batches``.
    num_designs_per_batch: int,
        Value passed to rfd3 as ``diffusion_batch_size``.
    extra_args: str or None,
        Additional whitespace-separated ``key=value`` overrides appended to the command.
        Shell-style quoting is honoured, but no shell is involved.
    max_duration: int,
        GPU time budget. This function does not use it; ``get_duration`` reads it for ``spaces.GPU``.

    Returns:
    -------
    status: str,
        Human-readable status for the textbox. On failure it contains the error
        (rfd3 stderr, or stdout when stderr is empty).
    results: list of dicts or None,
        One dict per design with keys "batch", "design", "cif_path" and "pdb_path",
        sorted by (batch, design). None on failure.
    zip_path: str or None,
        Value returned by ``download_results_as_zip`` for the output directory. None on failure.
    """
    if input_file is None:
        return "Please ensure you have uploaded a configuration file: .yaml or .json", None, None
    elif pdb_file is None:
        status_update = f"Running generation for {num_batches} batches of {num_designs_per_batch}\n job configuration uploaded from file {os.path.basename(input_file)}\n no scaffold/target provided"
    else:
        status_update = f"Running generation for {num_batches} batches of {num_designs_per_batch}\n job configuration uploaded from file {os.path.basename(input_file)}\n scaffold/target provided from file {os.path.basename(pdb_file)}"

    session_hash = random.getrandbits(128)
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    directory = f"./outputs/generation_with_input_config/session_{session_hash}_{time_stamp}"
    os.makedirs(directory, exist_ok=False)

    shared_dir = None
    try:
        if pdb_file is not None:
            # HF Spaces stores every upload in its own temp directory. Copy the config and
            # the PDB side by side so paths relative to the config resolve.
            shared_dir = os.path.join("uploads", f"{time_stamp}_{session_hash}")
            os.makedirs(shared_dir)
            config_path = os.path.join(shared_dir, os.path.basename(input_file))
            shutil.copy2(input_file, config_path)
            shutil.copy2(pdb_file, os.path.join(shared_dir, os.path.basename(pdb_file)))
        else:
            config_path = input_file

        try:
            command = _build_rfd3_command(config_path, directory, num_batches, num_designs_per_batch, extra_args)
        except ValueError as e:
            return f"Generation failed:\n could not parse extra arguments: {e}", None, None

        status_update += f"\nRunning command: {' '.join(command)}."
        start = perf_counter()
        subprocess.run(command, check=True, text=True, capture_output=True)
        status_update += f"\nGeneration successful! Command took {perf_counter() - start:.2f} seconds to run."

        results = _collect_designs(directory)
        zip_path = download_results_as_zip(directory)
        return status_update, results, zip_path
    except subprocess.CalledProcessError as e:
        return f"Generation failed:\n{e.stderr or e.stdout}", None, None
    except FileNotFoundError as e:
        # Without a shell, a missing `rfd3` executable raises here instead of exiting with code 127.
        return f"Generation failed:\n{e}", None, None
    finally:
        # The copied inputs are only needed while rfd3 runs, so don't let uploads/ grow forever.
        if shared_dir is not None:
            shutil.rmtree(shared_dir, ignore_errors=True)
#def generation_with_input_config_factory(max_duration):
#
# @spaces.GPU(duration=max_duration)
# def generation_with_correct_time_limit(input_file, pdb_file, num_batches, num_designs_per_batch, extra_args):
# return generation_with_input_config_impl(input_file, pdb_file, num_batches, num_designs_per_batch, extra_args)
#
# return generation_with_correct_time_limit