Spaces:
Running on Zero
Running on Zero
| from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine | |
| import gradio as gr | |
| from lightning.fabric import seed_everything | |
| import time | |
| import os | |
| import spaces | |
| import subprocess | |
| import gzip | |
| from utils.handle_files import * | |
| import sys | |
| import shutil | |
| from time import perf_counter | |
| import random | |
def get_duration(input_file, pdb_file, num_batches, num_designs_per_batch, extra_args, max_duration):
    """Return the time budget requested for a generation call.

    The parameter list mirrors ``generation_with_input_config`` so the same
    positional arguments can be forwarded to both. Presumably this is meant as a
    ``spaces.GPU(duration=...)`` callable — confirm against the decorator site.
    Only ``max_duration`` is consulted; it is handed back unchanged.
    """
    # Everything except the budget is accepted purely for signature compatibility.
    del input_file, pdb_file, num_batches, num_designs_per_batch, extra_args
    budget = max_duration
    return budget
def _build_design_command(config_path, out_dir, num_batches, num_designs_per_batch, extra_tokens):
    """Return the ``rfd3 design`` invocation as an argv list (executed without a shell)."""
    return [
        "rfd3",
        "design",
        f"inputs={config_path}",
        f"out_dir={out_dir}",
        f"n_batches={num_batches}",
        f"diffusion_batch_size={num_designs_per_batch}",
        *extra_tokens,
    ]


def _parse_batch_and_design(file_name):
    """Return ``(batch, design)`` parsed from a ``..._<batch>_model_<design>...`` file name, or None."""
    name = file_name.split(".")[0]  # file name without extension(s)
    terms = name.split("_")
    try:
        model_index = terms.index("model")
        if model_index == 0:
            # No batch token precedes "model"; terms[-1] would be silently wrong.
            return None
        return int(terms[model_index - 1]), int(terms[model_index + 1])
    except (ValueError, IndexError):
        return None


def _collect_results(directory):
    """Convert each parseable ``.cif.gz`` design in *directory* to PDB and describe it, sorted by (batch, design)."""
    results = []
    for file_name in sorted(os.listdir(directory)):
        if not file_name.endswith(".cif.gz"):
            continue
        parsed = _parse_batch_and_design(file_name)
        if parsed is None:
            # Not named like a design output; skip it rather than abort the whole request.
            continue
        batch, design = parsed
        cif_path = os.path.join(directory, file_name)
        # mcif_gz_to_pdb comes from `utils.handle_files` (star import); presumably returns the new .pdb path.
        pdb_path = mcif_gz_to_pdb(cif_path)
        results.append({"batch": batch, "design": design, "cif_path": cif_path, "pdb_path": pdb_path})
    results.sort(key=lambda r: (r["batch"], r["design"]))
    return results


def generation_with_input_config(input_file, pdb_file, num_batches, num_designs_per_batch, extra_args, max_duration):
    """
    Run ``rfd3 design`` with an uploaded config file, optionally conditioned on an uploaded structure.

    Generated structures are written to a fresh, uniquely named directory under
    ``./outputs/generation_with_input_config``. Each ``.cif.gz`` output is also
    converted to PDB, and the directory is packaged for download.

    Parameters:
    ----------
    input_file: str or None,
        Filesystem path of the uploaded config file (.yaml or .json). It is used directly
        as a path (os.path.basename / shutil.copy2), so the Gradio component presumably
        delivers a file path.
    pdb_file: str or None,
        Filesystem path of an uploaded scaffold/target structure, or None for no conditioning.
    num_batches:
        Forwarded to rfd3 as ``n_batches=``.
    num_designs_per_batch:
        Forwarded to rfd3 as ``diffusion_batch_size=``.
    extra_args: str or None,
        Additional rfd3 command-line arguments. They are split with POSIX shell quoting rules
        (shlex) but never interpreted by a shell.
    max_duration:
        Not used here. It is accepted so the signature matches ``get_duration``.

    Returns:
    -------
    status: str,
        Human-readable status for the UI textbox. On failure this carries the error
        (rfd3 stderr, an OS error, or an argument-parsing error).
    results: list of dicts or None,
        One dict per generated design with keys "batch", "design", "cif_path" and "pdb_path",
        sorted by (batch, design). None on failure.
    zip_path: or None,
        Whatever ``download_results_as_zip(directory)`` returns (presumably the path of an
        archive of the output directory). None on failure.
    """
    import shlex  # local import keeps this block self-contained

    if input_file is None:
        return "Please ensure you have uploaded a configuration file: .yaml or .json", None, None

    # Validate user-supplied extra arguments before creating any directories.
    try:
        extra_tokens = shlex.split(extra_args) if extra_args else []
    except ValueError as e:
        return f"Could not parse extra arguments: {e}", None, None

    if pdb_file is None:
        status_update = f"Running generation for {num_batches} batches of {num_designs_per_batch}\n job configuration uploaded from file {os.path.basename(input_file)}\n no scaffold/target provided"
    else:
        status_update = f"Running generation for {num_batches} batches of {num_designs_per_batch}\n job configuration uploaded from file {os.path.basename(input_file)}\n scaffold/target provided from file {os.path.basename(pdb_file)}"

    session_hash = random.getrandbits(128)
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")
    directory = f"./outputs/generation_with_input_config/session_{session_hash}_{time_stamp}"
    os.makedirs(directory, exist_ok=False)

    shared_dir = None
    try:
        if pdb_file is not None:
            # Uploading files to a HF Space stores each one in a separate temp directory,
            # so copy the config and structure side by side before running rfd3.
            upload_dir = os.path.join("uploads", f"{time_stamp}_{session_hash}")
            os.makedirs(upload_dir)
            shared_dir = upload_dir
            config_path = os.path.join(shared_dir, os.path.basename(input_file))
            shutil.copy2(input_file, config_path)
            shutil.copy2(pdb_file, os.path.join(shared_dir, os.path.basename(pdb_file)))
        else:
            config_path = input_file

        # argv list + no shell: user-supplied text cannot inject shell syntax,
        # and paths containing spaces are passed through intact.
        command = _build_design_command(config_path, directory, num_batches, num_designs_per_batch, extra_tokens)
        status_update += f"\nRunning command: {shlex.join(command)}."
        start = perf_counter()
        subprocess.run(command, check=True, text=True, capture_output=True)
        status_update += f"\nGeneration successful! Command took {perf_counter() - start:.2f} seconds to run."
    except subprocess.CalledProcessError as e:
        return f"Generation failed:\n{e.stderr}", None, None
    except OSError as e:
        # e.g. rfd3 executable not found, or an uploaded file could not be copied.
        return f"Generation failed:\n{e}", None, None
    finally:
        # The co-located upload copies are only needed while rfd3 runs.
        if shared_dir is not None:
            shutil.rmtree(shared_dir, ignore_errors=True)

    results = _collect_results(directory)
    zip_path = download_results_as_zip(directory)
    return status_update, results, zip_path
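# NOTE(review): this disabled factory appears to be an earlier alternative to the
# get_duration callback above: it fixed the GPU time limit when the wrapper was built.
# The draft called an undefined `generation_with_input_config_impl` and dropped
# `max_duration`; the call is corrected below in case it is ever revived.
# def generation_with_input_config_factory(max_duration):
#
#     @spaces.GPU(duration=max_duration)
#     def generation_with_correct_time_limit(input_file, pdb_file, num_batches, num_designs_per_batch, extra_args):
#         return generation_with_input_config(input_file, pdb_file, num_batches, num_designs_per_batch, extra_args, max_duration)
#
#     return generation_with_correct_time_limit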