| | import datetime |
| | import io |
| | import os |
| | import shutil |
| | import subprocess |
| | import tempfile |
| | import uuid |
| |
|
| | import logging |
| | import zipfile |
| | from typing import List, Dict |
| |
|
| | import requests |
| |
|
# Upstream DiffDock repository URL.
PROJECT_URL = "https://github.com/gcorso/DiffDock"

# Names of the extra positional arguments accepted by run_cli_command, in the
# exact order they must be passed; each is forwarded to inference.py as a CLI
# option of the same name.
ARG_ORDER = ["samples_per_complex"]

# Directory containing this module, and its parent directory (used as the
# default working directory for running inference.py).
APP_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_DIR = os.path.dirname(APP_DIR)

# Scratch directory next to this module; created eagerly at import time.
TEMP_DIR = os.path.join(APP_DIR, ".tmp")
os.makedirs(TEMP_DIR, exist_ok=True)
| |
|
| |
|
def set_env_variables():
    """Fill in the environment variables this module relies on.

    * ``DiffDockDir`` defaults to the absolute PROJECT_DIR; a ValueError is
      raised if that directory does not exist.
    * ``LOG_LEVEL`` defaults to ``"INFO"``.

    Values that are already present in the environment are left untouched.

    Raises:
        ValueError: if DiffDockDir is unset and PROJECT_DIR does not exist.
    """
    if "DiffDockDir" not in os.environ:
        candidate = os.path.abspath(PROJECT_DIR)
        if not os.path.exists(candidate):
            raise ValueError(f"DiffDockDir {candidate} not found")
        os.environ["DiffDockDir"] = candidate

    os.environ.setdefault("LOG_LEVEL", "INFO")
| |
|
| |
|
def configure_logging(level=None):
    """Configure root logging with a single stream handler.

    Args:
        level: A ``logging`` level (e.g. ``logging.DEBUG``). When omitted, the
            level is taken from the ``LOG_LEVEL`` environment variable
            (default ``"INFO"``). The variable is matched case-insensitively
            (``"debug"`` works) and may also be a numeric level such as
            ``"10"``. An unrecognised value falls back to INFO and a warning
            is logged, instead of crashing at start-up.

    Note:
        ``logging.basicConfig`` does nothing if the root logger already has
        handlers, so only the first configuration takes effect.
    """
    invalid_level_name = None
    if level is None:
        level_name = os.environ.get("LOG_LEVEL", "INFO").strip()
        if level_name.isdigit():
            level = int(level_name)
        else:
            resolved = getattr(logging, level_name.upper(), None)
            # A bare getattr would also "resolve" unrelated module attributes
            # (e.g. BASIC_FORMAT); only integer level constants are valid.
            if isinstance(resolved, int) and not isinstance(resolved, bool):
                level = resolved
            else:
                invalid_level_name = level_name
                level = logging.INFO

    logging.basicConfig(
        level=level,
        format="[%(asctime)s] [%(filename)s:%(lineno)d] %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S %Z",
        handlers=[
            logging.StreamHandler(),
        ],
    )

    # Warn only after logging is configured, so the message is actually emitted.
    if invalid_level_name is not None:
        logging.warning(
            f"Unrecognized LOG_LEVEL {invalid_level_name!r}; defaulting to INFO"
        )
| |
|
| |
|
def kwargs_to_cli_args(**kwargs) -> List[str]:
    """
    Turn keyword arguments into a list of ``--name[=value]`` CLI arguments.

    * ``True`` booleans become a bare ``--name`` flag; ``False`` is dropped.
    * ``None`` and values whose ``str()`` is empty are dropped.
    * Everything else becomes ``--name=value``.

    Argument order follows the keyword order.
    """

    def _render(flag, val):
        # Returns the CLI token for one option, or None to omit it.
        if isinstance(val, bool):
            return f"--{flag}" if val else None
        if val is None or str(val) == "":
            return None
        return f"--{flag}={val}"

    rendered = (_render(name, val) for name, val in kwargs.items())
    return [token for token in rendered if token is not None]
| |
|
| |
|
def read_file_lines(fi_path: str, skip_remarks=True):
    """Return the text of *fi_path* as a single string.

    Args:
        fi_path: Path of the text file to read.
        skip_remarks: When true, drop every line that starts with ``REMARK``
            (case-insensitive); all other lines are kept verbatim, newlines
            included.
    """
    with open(fi_path, "r") as handle:
        kept = [
            line
            for line in handle.readlines()
            if not (skip_remarks and line.upper().startswith("REMARK"))
        ]
    return "".join(kept)
| |
|
| |
|
def _log_device_info(work_dir):
    """Run ``utils/print_device.py`` in *work_dir* and log its output at DEBUG level."""
    result = subprocess.run(
        ["python3", "utils/print_device.py"],
        cwd=work_dir,
        check=False,
        text=True,
        capture_output=True,
        env=os.environ,
    )
    logging.debug(f"Device check output:\n{result.stdout}")
    if result.returncode != 0:
        logging.warning(
            f"Device check exited with return code {result.returncode}:\n{result.stderr}"
        )


def _run_inference(command, work_dir, out_dir):
    """Run *command* in *work_dir* and write its output to ``out_dir/output.log``."""
    result = subprocess.run(
        command,
        cwd=work_dir,
        check=False,
        text=True,
        capture_output=True,
    )
    logging.debug(f"Command output:\n{result.stdout}")
    full_output = f"Standard out:\n{result.stdout}"
    if result.stderr:
        # Skip lines containing "%|" (presumably progress-bar output) so the
        # log keeps only meaningful stderr text.
        stderr_text = "\n".join(
            line for line in result.stderr.split("\n") if "%|" not in line
        )
        logging.error(f"Command error:\n{stderr_text}")
        full_output += f"\nStandard error:\n{stderr_text}"
    if result.returncode != 0:
        # check=False means a failing run does not raise; surface the failure
        # explicitly instead of silently returning an (almost) empty result.
        logging.error(f"Inference command exited with return code {result.returncode}")
        full_output += f"\nReturn code: {result.returncode}"

    with open(os.path.join(out_dir, "output.log"), "w") as log_file:
        log_file.write(full_output)


def _write_artificial_output(protein_path, ligand, out_dir):
    """Put placeholder results copied from the inputs into *out_dir* (skip mode)."""
    logging.debug("Skipping command execution")
    # Written inside out_dir (the directory that gets zipped) so the returned
    # archive actually contains the placeholder results.
    artificial_output_dir = os.path.join(out_dir, "artificial_output")
    os.makedirs(artificial_output_dir, exist_ok=True)
    shutil.copy(protein_path, os.path.join(artificial_output_dir, "protein.pdb"))
    shutil.copy(ligand, os.path.join(artificial_output_dir, "rank1.sdf"))
    shutil.copy(ligand, os.path.join(artificial_output_dir, "rank1_confidence-0.10.sdf"))


def _copy_protein_to_single_subdir(out_dir, protein_path):
    """If *out_dir* has exactly one sub-directory, copy the protein file into it."""
    sub_dirs = [os.path.join(out_dir, name) for name in os.listdir(out_dir)]
    sub_dirs = [path for path in sub_dirs if os.path.isdir(path)]
    logging.debug(f"Output Subdirectories: {sub_dirs}")
    if len(sub_dirs) == 1:
        trg_protein_path = os.path.join(sub_dirs[0], os.path.basename(protein_path))
        shutil.copy(protein_path, trg_protein_path)


def _zip_output_dir(out_dir):
    """Zip *out_dir* to a uniquely named archive and return the archive path."""
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    uuid_tag = str(uuid.uuid4())[0:8]
    unique_filename = f"diffdock_output_{timestamp}_{uuid_tag}"
    # NOTE: "tmp" is relative to the current working directory (not TEMP_DIR);
    # shutil.make_archive creates it if needed.
    zip_base_name = os.path.join("tmp", unique_filename)

    logging.debug(f"About to zip directory '{out_dir}' to {unique_filename}")
    full_zip_path = shutil.make_archive(zip_base_name, "zip", out_dir)
    logging.debug(f"Directory '{out_dir}' zipped to '{full_zip_path}'")
    return full_zip_path


def run_cli_command(
    protein_path: str,
    ligand: str,
    config_path: str,
    *args,
    work_dir=None,
):
    """Run DiffDock's ``inference.py`` and return the path of a zip of its outputs.

    Args:
        protein_path: Path of the protein file; passed as ``--protein_path``
            and also copied into the output directory.
        ligand: Ligand input, passed as ``--ligand``.
        config_path: Inference config, passed as ``--config``.
        *args: Exactly one value per name in ``ARG_ORDER``, in that order;
            each is passed as ``--<name>=<value>``.
        work_dir: Directory to run the scripts from. Defaults to the
            ``DiffDockDir`` environment variable, or PROJECT_DIR.

    Environment:
        INFERENCE_LOG_LEVEL / LOG_LEVEL: forwarded as ``--loglevel``
            (default ``WARNING``).
        __SKIP_RUNNING: when ``"true"`` (case-insensitive), inference is not
            run; placeholder results copied from the inputs are zipped instead.

    Returns:
        Path of the created ``.zip`` archive. A failing inference run does not
        raise: its return code is logged and recorded in ``output.log`` inside
        the archive.

    Raises:
        ValueError: if the number of positional ``args`` does not match
            ``ARG_ORDER``.
    """
    if work_dir is None:
        work_dir = os.environ.get("DiffDockDir", PROJECT_DIR)

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(args) != len(ARG_ORDER):
        raise ValueError(f"Expected {len(ARG_ORDER)} arguments, got {len(args)}")

    inference_log_level = os.environ.get(
        "INFERENCE_LOG_LEVEL", os.environ.get("LOG_LEVEL", "WARNING")
    )

    all_arg_dict = {
        "protein_path": protein_path,
        "ligand": ligand,
        "config": config_path,
        "no_final_step_noise": True,
        "loglevel": inference_log_level,
    }
    all_arg_dict.update(zip(ARG_ORDER, args))

    _log_device_info(work_dir)

    command = ["python3", "inference.py"] + kwargs_to_cli_args(**all_arg_dict)

    with tempfile.TemporaryDirectory() as temp_dir:
        command.append(f"--out_dir={temp_dir}")

        command_str = " ".join(command)
        logging.info(f"Executing command: {command_str}")

        skip_running = os.environ.get("__SKIP_RUNNING", "false").lower() == "true"
        if skip_running:
            _write_artificial_output(protein_path, ligand, temp_dir)
        else:
            _run_inference(command, work_dir, temp_dir)

        _copy_protein_to_single_subdir(temp_dir, protein_path)
        # Archive before the with-block exits and deletes temp_dir.
        full_zip_path = _zip_output_dir(temp_dir)

    return full_zip_path
| |
|
| |
|
def parse_ligand_filename(filename: str) -> Dict:
    """
    Parse rank and confidence from a ligand ``.sdf`` filename.

    Recognised basenames are ``rank<N>.sdf`` and ``rank<N>_confidence<X>.sdf``;
    any directory prefix is ignored and only the first two ``_``-separated
    tokens are inspected.

    Returns:
        ``{"filename": <basename without .sdf>, "rank": int, "confidence": float or None}``,
        or ``{}`` when *filename* is not an ``.sdf`` file or its name does not
        follow that pattern. Returning ``{}`` means one stray file cannot abort
        parsing of the whole result set.
    """
    suffix = ".sdf"
    if not filename.endswith(suffix):
        return {}

    # Strip only the trailing suffix; str.replace would remove every ".sdf".
    basename = os.path.basename(filename)[: -len(suffix)]
    tokens = basename.split("_")
    try:
        rank = int(tokens[0].replace("rank", ""))
        confidence = (
            float(tokens[1].replace("confidence", "")) if len(tokens) > 1 else None
        )
    except ValueError:
        logging.warning(f"Could not parse rank/confidence from ligand file name {filename!r}")
        return {}

    return {"filename": basename, "rank": rank, "confidence": confidence}
| |
|
| |
|
def process_zip_file(zip_path: str):
    """Read the ``.pdb`` and ``.sdf`` members of a results archive.

    Args:
        zip_path: Path to the zip archive.

    Returns:
        A ``(pdb_files, sdf_files)`` tuple. ``pdb_files`` is a list of
        ``{"path", "content"}`` dicts. ``sdf_files`` is a list of the dicts
        returned by ``parse_ligand_filename``, each with ``"content"`` and
        ``"path"`` added, sorted by ``"rank"`` (entries without a rank go last).
        Member contents are decoded as UTF-8; directory entries and other file
        types are ignored.
    """
    pdb_file = []
    sdf_files = []
    # Pass the path rather than an open file object: ZipFile does not close a
    # file object it was given, so the original code leaked a file handle.
    with zipfile.ZipFile(zip_path, "r") as archive:
        for filename in archive.namelist():
            # Directory entries end with "/" and have no content.
            if filename.endswith("/"):
                continue

            if filename.endswith(".pdb"):
                content = archive.read(filename).decode("utf-8")
                pdb_file.append({"path": filename, "content": content})
            elif filename.endswith(".sdf"):
                info = parse_ligand_filename(filename)
                info["content"] = archive.read(filename).decode("utf-8")
                info["path"] = filename
                sdf_files.append(info)

    sdf_files = sorted(sdf_files, key=lambda x: x.get("rank", 1_000))

    return pdb_file, sdf_files
| |
|
| |
|
def download_pdb(pdb_code: str, work_dir: str, *, timeout: float = 60):
    """Download ``<pdb_code>.pdb`` from RCSB into *work_dir*, unless it already exists.

    Args:
        pdb_code: PDB identifier, used in both the URL and the local filename.
        work_dir: Directory to store the file in.
        timeout: Seconds to wait for the server before giving up (keyword-only).
            Without a timeout, ``requests.get`` can block indefinitely.

    Returns:
        Local path of the PDB file, or ``None`` if the download failed (non-200
        response or a network error). Failures are logged, not raised.
    """
    pdb_url = f"https://files.rcsb.org/download/{pdb_code}.pdb"
    pdb_path = os.path.join(work_dir, f"{pdb_code}.pdb")

    if os.path.exists(pdb_path):
        logging.info(f"PDB file for {pdb_code} already exists at {pdb_path}")
        return pdb_path

    logging.debug(f"Downloading PDB file for {pdb_code} from {pdb_url}")
    try:
        response = requests.get(pdb_url, allow_redirects=True, timeout=timeout)
    except requests.RequestException as err:
        # Report network errors/timeouts the same way as an HTTP failure.
        logging.error(f"Failed to download PDB file for {pdb_code} from {pdb_url}: {err}")
        return None

    if response.status_code != 200:
        logging.error(f"Failed to download PDB file for {pdb_code} from {pdb_url}")
        return None

    with open(pdb_path, "w") as pdb_file:
        pdb_file.write(response.text)
    return pdb_path
| |
|
| |
|
def test_run_cli():
    """End-to-end smoke test: run inference on the bundled 3dpf example.

    Requires the DiffDock checkout at PROJECT_DIR (``data/3dpf`` inputs and
    ``inference.py``) and ``default_inference_args.yaml`` next to this module.
    """
    set_env_variables()
    configure_logging()

    work_dir = os.path.abspath(PROJECT_DIR)
    os.environ["DiffDockDir"] = work_dir
    protein_path = os.path.join(work_dir, "data", "3dpf", "3dpf_protein.pdb")
    ligand = os.path.join(work_dir, "data", "3dpf", "3dpf_ligand.sdf")
    config_file = os.path.join(APP_DIR, "default_inference_args.yaml")

    # run_cli_command requires exactly one positional value per ARG_ORDER
    # entry; building them from ARG_ORDER keeps this test in sync with it.
    # (Previously four values were passed, which always failed the arity check.)
    arg_values = {"samples_per_complex": 10}
    extra_args = [arg_values[name] for name in ARG_ORDER]

    zip_path = run_cli_command(
        protein_path,
        ligand,
        config_file,
        *extra_args,
    )
    assert os.path.exists(zip_path), f"Expected output archive at {zip_path}"
| |
|