"""Physical-validity checks (ligand geometry, stereochemistry, planarity,
inter-chain clashes) for predicted structures from several folding tools.

Run as a script: every (tool, pdb_id, model_idx) prediction is parsed and
checked in a process pool, and the per-model counts are written to
``physical_checks_test.csv``.
"""

import os
import pickle
from itertools import combinations
from multiprocessing import Pool
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from boltz.data import const
from boltz.data.mol import load_molecules
from boltz.data.parse.mmcif_with_constraints import parse_mmcif


def compute_torsion_angles(coords, torsion_index):
    """Return signed dihedral angles (radians) for quadruples of atoms.

    Parameters
    ----------
    coords : np.ndarray
        Atom coordinates; the last axis holds x/y/z and the second-to-last
        axis indexes atoms.
    torsion_index : np.ndarray
        Integer array whose first axis has length 4; rows 0..3 hold the atom
        indices i, j, k, l of each torsion.

    Returns
    -------
    np.ndarray
        One angle in (-pi, pi) per torsion. The sign comes from the
        orientation of ``n_ijk x n_jkl`` relative to the j->k vector.
        Degenerate (collinear) quadruples give NaN (0/0 in the cosine).
    """
    r_ij = coords[..., torsion_index[0], :] - coords[..., torsion_index[1], :]
    r_kj = coords[..., torsion_index[2], :] - coords[..., torsion_index[1], :]
    r_kl = coords[..., torsion_index[2], :] - coords[..., torsion_index[3], :]

    # Normals of the (i, j, k) and (j, k, l) planes.
    n_ijk = np.cross(r_ij, r_kj, axis=-1)
    n_jkl = np.cross(r_kj, r_kl, axis=-1)
    n_ijk_norm = np.linalg.norm(n_ijk, axis=-1)
    n_jkl_norm = np.linalg.norm(n_jkl, axis=-1)

    sign_phi = np.sign(
        r_kj[..., None, :] @ np.cross(n_ijk, n_jkl, axis=-1)[..., None]
    ).squeeze(axis=(-1, -2))
    # Clip strictly inside [-1, 1] so round-off never pushes arccos out of domain.
    cos_phi = np.clip(
        (n_ijk[..., None, :] @ n_jkl[..., None]).squeeze(axis=(-1, -2))
        / (n_ijk_norm * n_jkl_norm),
        -1 + 1e-8,
        1 - 1e-8,
    )
    return sign_phi * np.arccos(cos_phi)


def check_ligand_distance_geometry(
    structure, constraints, bond_buffer=0.25, angle_buffer=0.25, clash_buffer=0.2
):
    """Count RDKit distance-bound violations among constrained atom pairs.

    Each pair in ``constraints.rdkit_bounds_constraints`` is a bond, an angle
    (per its ``is_bond`` / ``is_angle`` flags) or a non-neighbor pair.

    - A bond or angle violates when its distance is at or below
      ``lower * (1 - buffer)`` or at or above ``upper * (1 + buffer)``.
    - A non-neighbor pair is only checked against the lower bound, scaled by
      ``1 - clash_buffer``.

    Returns a dict with violation counts and pair counts per class, plus
    ``num_ligands``, the number of NONPOLYMER chains in ``structure``.
    """
    coords = structure.coords["coords"]
    bounds = constraints.rdkit_bounds_constraints

    # astype() copies, so the constraint arrays are never mutated.
    pair_index = bounds["atom_idxs"].astype(np.int64).T
    bond_mask = bounds["is_bond"].astype(bool)
    angle_mask = bounds["is_angle"].astype(bool)
    upper_bounds = bounds["upper_bound"].astype(np.float32)
    lower_bounds = bounds["lower_bound"].astype(np.float32)
    non_neighbor_mask = ~(bond_mask | angle_mask)

    dists = np.linalg.norm(coords[pair_index[0]] - coords[pair_index[1]], axis=-1)

    def _out_of_bounds(mask, buffer):
        # True where the pair is too short OR too long for its class.
        return (dists[mask] <= lower_bounds[mask] * (1.0 - buffer)) | (
            dists[mask] >= upper_bounds[mask] * (1.0 + buffer)
        )

    bond_length_violations = _out_of_bounds(bond_mask, bond_buffer)
    bond_angle_violations = _out_of_bounds(angle_mask, angle_buffer)
    internal_clash_violations = dists[non_neighbor_mask] <= lower_bounds[
        non_neighbor_mask
    ] * (1.0 - clash_buffer)

    num_ligands = sum(
        int(const.chain_types[chain["mol_type"]] == "NONPOLYMER")
        for chain in structure.chains
    )

    return {
        "num_ligands": num_ligands,
        "num_bond_length_violations": bond_length_violations.sum(),
        "num_bonds": bond_mask.sum(),
        "num_bond_angle_violations": bond_angle_violations.sum(),
        "num_angles": angle_mask.sum(),
        "num_internal_clash_violations": internal_clash_violations.sum(),
        "num_non_neighbors": non_neighbor_mask.sum(),
    }


def check_ligand_stereochemistry(structure, constraints):
    """Count chirality and double-bond stereo violations.

    Only constraints flagged ``is_reference`` are checked.

    - A chiral centre is predicted R when its torsion angle is positive, and
      compared with ``is_r``.
    - A stereo bond is predicted E when its absolute torsion exceeds pi/2,
      and compared with ``is_e``.
    """
    coords = structure.coords["coords"]

    chiral = constraints.chiral_atom_constraints
    chiral_ref_mask = chiral["is_reference"]
    chiral_atom_index = chiral["atom_idxs"].T[:, chiral_ref_mask]
    pred_is_r = compute_torsion_angles(coords, chiral_atom_index) > 0
    chiral_atom_violations = pred_is_r != chiral["is_r"][chiral_ref_mask]

    stereo = constraints.stereo_bond_constraints
    stereo_ref_mask = stereo["is_reference"]
    stereo_bond_index = stereo["atom_idxs"].T[:, stereo_ref_mask]
    pred_is_e = np.abs(compute_torsion_angles(coords, stereo_bond_index)) > np.pi / 2
    stereo_bond_violations = pred_is_e != stereo["is_e"][stereo_ref_mask]

    return {
        "num_chiral_atom_violations": chiral_atom_violations.sum(),
        "num_chiral_atoms": chiral_atom_index.shape[1],
        "num_stereo_bond_violations": stereo_bond_violations.sum(),
        "num_stereo_bonds": stereo_bond_index.shape[1],
    }


def _planarity_violations(coords, atom_idxs, buffer):
    """Flag atom groups with any atom at least ``buffer`` from their best-fit plane.

    ``atom_idxs`` has one row per group. The plane passes through the group
    centroid, and its normal is the last right-singular vector from SVD.
    Returns one bool per group; an empty index gives an empty result.
    """
    atom_idxs = np.asarray(atom_idxs)
    if len(atom_idxs) == 0:
        return np.zeros(0, dtype=bool)
    group_coords = coords[atom_idxs, :]
    centered = group_coords - group_coords.mean(axis=-2, keepdims=True)
    normals = np.linalg.svd(centered)[2][..., -1, :, None]
    plane_dists = np.abs((centered @ normals).squeeze(axis=-1))
    return np.any(plane_dists >= buffer, axis=-1)


def check_ligand_flatness(structure, constraints, buffer=0.25):
    """Count non-planar 5-rings, 6-rings and double-bond groups.

    Every group is tested the same way: it is a violation when any of its
    atoms lies at least ``buffer`` (same units as the coordinates) from the
    group's best-fit plane. The 5-ring check previously used
    ``np.all(dists <= buffer)``, which counted *flat* rings as violations.
    """
    coords = structure.coords["coords"]
    ring_5_violations = _planarity_violations(
        coords, constraints.planar_ring_5_constraints["atom_idxs"], buffer
    )
    ring_6_violations = _planarity_violations(
        coords, constraints.planar_ring_6_constraints["atom_idxs"], buffer
    )
    bond_violations = _planarity_violations(
        coords, constraints.planar_bond_constraints["atom_idxs"], buffer
    )
    return {
        "num_planar_5_ring_violations": ring_5_violations.sum(),
        "num_planar_5_rings": ring_5_violations.shape[0],
        "num_planar_6_ring_violations": ring_6_violations.sum(),
        "num_planar_6_rings": ring_6_violations.shape[0],
        "num_planar_double_bond_violations": bond_violations.sum(),
        "num_planar_double_bonds": bond_violations.shape[0],
    }


def _chain_type_label(type_name):
    """Lower-case a chain type name, reporting NONPOLYMER as "ligand"."""
    label = type_name.lower()
    return "ligand" if label == "nonpolymer" else label


def check_steric_clash(structure, molecules, buffer=0.25):
    """Count inter-chain steric clashes, grouped by chain-type pair.

    ``molecules`` maps a residue name to a molecule whose atoms carry a
    ``"name"`` property (this script passes the CCD dict).

    Skipped pairs:
    - pairs where either chain has a single atom;
    - pairs joined by an inter-chain bond.

    A chain pair clashes when any inter-chain atom distance is below
    ``(1 - buffer)`` times the sum of the two van der Waals radii.

    Pair labels:
    - ``sym_<type>`` for pairs with the same ``entity_id`` and atom count;
    - ``asym_<type_i>_<type_j>`` otherwise, with i the earlier chain.

    Returns ``num_chain_pairs_*`` / ``num_chain_clashes_*`` counts for every
    label, initialised to 0.
    """
    labels = [_chain_type_label(type_name) for type_name in const.chain_types]
    result = {}
    for label_i in labels:
        result[f"num_chain_pairs_sym_{label_i}"] = 0
        result[f"num_chain_clashes_sym_{label_i}"] = 0
        for label_j in labels:
            result[f"num_chain_pairs_asym_{label_i}_{label_j}"] = 0
            result[f"num_chain_clashes_asym_{label_i}_{label_j}"] = 0

    # Covalently linked chain pairs (sorted index tuples) are not clash-checked.
    connected_chains = {
        tuple(sorted((bond["chain_1"], bond["chain_2"])))
        for bond in structure.bonds
        if bond["chain_1"] != bond["chain_2"]
    }

    # Per-atom vdW radii (const.vdw_radii is indexed by atomic number - 1).
    # Assumes residues list atoms contiguously, in the same order as
    # structure.atoms / coords - TODO confirm.
    vdw_radii = []
    for res in structure.residues:
        mol = molecules[res["name"]]
        atom_name_to_ref = {a.GetProp("name"): a for a in mol.GetAtoms()}
        start = res["atom_idx"]
        for atom in structure.atoms[start : start + res["atom_num"]]:
            atomic_num = atom_name_to_ref[atom["name"]].GetAtomicNum()
            vdw_radii.append(const.vdw_radii[atomic_num - 1])
    vdw_radii = np.array(vdw_radii, dtype=np.float32)

    coords = structure.coords["coords"]
    # combinations() yields each unordered pair once with i < j.
    for (i, chain_i), (j, chain_j) in combinations(enumerate(structure.chains), 2):
        if (
            chain_i["atom_num"] == 1
            or chain_j["atom_num"] == 1
            or (i, j) in connected_chains
        ):
            continue

        span_i = slice(chain_i["atom_idx"], chain_i["atom_idx"] + chain_i["atom_num"])
        span_j = slice(chain_j["atom_idx"], chain_j["atom_idx"] + chain_j["atom_num"])
        dists = np.linalg.norm(
            coords[span_i][:, None, :] - coords[span_j][None, :, :], axis=-1
        )
        radii_sum = vdw_radii[span_i][:, None] + vdw_radii[span_j][None, :]
        is_clashing = np.any(dists < radii_sum * (1.00 - buffer))

        label_i = _chain_type_label(const.chain_types[chain_i["mol_type"]])
        label_j = _chain_type_label(const.chain_types[chain_j["mol_type"]])
        is_symmetric = (
            chain_i["entity_id"] == chain_j["entity_id"]
            and chain_i["atom_num"] == chain_j["atom_num"]
        )
        key = "sym_" + label_i if is_symmetric else "asym_" + label_i + "_" + label_j
        result["num_chain_pairs_" + key] += 1
        result["num_chain_clashes_" + key] += int(is_clashing)

    return result


# Configuration stays at module level so that pool workers, which may
# re-import this module under spawn/forkserver, see the globals process_fn uses.
cache_dir = Path("/data/rbg/users/jwohlwend/boltz-cache")
ccd_path = cache_dir / "ccd.pkl"
moldir = cache_dir / "mols"

# NOTE: unpickling can execute arbitrary code; only load a trusted cache.
with ccd_path.open("rb") as file:
    ccd = pickle.load(file)

boltz1_dir = Path(
    "/data/rbg/shared/projects/foldeverything/boltz_results_final/outputs/test/boltz/predictions"
)
boltz1x_dir = Path(
    "/data/scratch/getzn/boltz_private/boltz_1x_test_results_final_new/full_predictions"
)
chai_dir = Path(
    "/data/rbg/shared/projects/foldeverything/boltz_results_final/outputs/test/chai"
)
af3_dir = Path(
    "/data/rbg/shared/projects/foldeverything/boltz_results_final/outputs/test/af3"
)

tools = ["boltz1", "boltz1x", "chai", "af3"]
num_samples = 5
num_workers = 48


def process_fn(key):
    """Parse one predicted structure and run every physical-validity check.

    ``key`` is a ``(tool, pdb_id, model_idx)`` tuple. The returned record
    holds those three fields plus every count from the ``check_*`` functions.

    Raises
    ------
    ValueError
        If ``tool`` is not one of the known prediction sources.
    """
    tool, pdb_id, model_idx = key
    if tool == "boltz1":
        cif_path = boltz1_dir / pdb_id / f"{pdb_id}_model_{model_idx}.cif"
    elif tool == "boltz1x":
        cif_path = boltz1x_dir / pdb_id / f"{pdb_id}_model_{model_idx}.cif"
    elif tool == "chai":
        cif_path = chai_dir / pdb_id / f"pred.model_idx_{model_idx}.cif"
    elif tool == "af3":
        cif_path = af3_dir / pdb_id.lower() / f"seed-1_sample-{model_idx}" / "model.cif"
    else:
        raise ValueError(f"Unknown tool: {tool!r}")

    parsed_structure = parse_mmcif(
        cif_path,
        ccd,
        moldir,
    )
    structure = parsed_structure.data
    constraints = parsed_structure.residue_constraints

    record = {
        "tool": tool,
        "pdb_id": pdb_id,
        "model_idx": model_idx,
    }
    record.update(check_ligand_distance_geometry(structure, constraints))
    record.update(check_ligand_stereochemistry(structure, constraints))
    record.update(check_ligand_flatness(structure, constraints))
    record.update(check_steric_clash(structure, molecules=ccd))
    return record


if __name__ == "__main__":
    # NOTE(review): the intersection is case-sensitive, yet process_fn opens
    # AF3 outputs under pdb_id.lower(). This assumes every tool directory uses
    # the same case for PDB ids - confirm.
    boltz1_pdb_ids = set(os.listdir(boltz1_dir))
    boltz1x_pdb_ids = set(os.listdir(boltz1x_dir))
    chai_pdb_ids = set(os.listdir(chai_dir))
    af3_pdb_ids = set(os.listdir(af3_dir))
    common_pdb_ids = boltz1_pdb_ids & boltz1x_pdb_ids & chai_pdb_ids & af3_pdb_ids

    keys = [
        (tool, pdb_id, model_idx)
        for tool in tools
        for pdb_id in common_pdb_ids
        for model_idx in range(num_samples)
    ]

    # Run one key in-process first: any failure stops the script with a plain
    # traceback before the worker pool is started.
    if keys:
        process_fn(keys[0])

    records = []
    with Pool(num_workers) as p:
        with tqdm(total=len(keys)) as pbar:
            for record in p.imap_unordered(process_fn, keys):
                records.append(record)
                pbar.update(1)

    df = pd.DataFrame.from_records(records)
    # Select the source columns before adding the aggregate columns below.
    clash_cols = [col for col in df.columns if "chain_clash" in col]
    pair_cols = [col for col in df.columns if "chain_pair" in col]
    violation_cols = [col for col in df.columns if "violation" in col]

    df["num_chain_clashes_all"] = df[clash_cols].sum(axis=1)
    df["num_pairs_all"] = df[pair_cols].sum(axis=1)
    df["clash_free"] = df["num_chain_clashes_all"] == 0
    df["valid_ligand"] = df[violation_cols].sum(axis=1) == 0
    df["valid"] = df["clash_free"] & df["valid_ligand"]
    df.to_csv("physical_checks_test.csv")