| import os |
| import pickle |
|
|
| import numpy as np |
| import torch |
| from pathlib import Path |
| from tqdm import tqdm |
| import pandas as pd |
| from boltz.data.mol import load_molecules |
| from boltz.data import const |
| from boltz.data.parse.mmcif_with_constraints import parse_mmcif |
| from multiprocessing import Pool |
|
|
|
|
def compute_torsion_angles(coords, torsion_index):
    """Return signed torsion (dihedral) angles, in radians, for atom quadruples.

    Column ``t`` of ``torsion_index`` names atoms ``(i, j, k, l)``. The angle is
    measured about the ``j``-``k`` bond, between the plane through ``i, j, k``
    and the plane through ``j, k, l``.

    Args:
        coords: Coordinate array indexed as ``coords[..., atom, :]`` with x/y/z
            on the last axis; any leading axes are carried through.
        torsion_index: Integer array with four rows (``torsion_index[0..3]``),
            one column per torsion.

    Returns:
        Angles in ``[-pi, pi]`` with shape ``coords.shape[:-2] + (num_torsions,)``.
        The sign is that of ``r_kj . (n_ijk x n_jkl)``. Coincident ``j``/``k``
        gives NaN; collinear ``i, j, k`` (or ``j, k, l``) gives 0.
    """
    r_ij = coords[..., torsion_index[0], :] - coords[..., torsion_index[1], :]
    r_kj = coords[..., torsion_index[2], :] - coords[..., torsion_index[1], :]
    r_kl = coords[..., torsion_index[2], :] - coords[..., torsion_index[3], :]
    n_ijk = np.cross(r_ij, r_kj, axis=-1)
    n_jkl = np.cross(r_kj, r_kl, axis=-1)
    r_kj_norm = np.linalg.norm(r_kj, axis=-1)

    # |n_ijk| |n_jkl| cos(phi)
    cos_term = np.sum(n_ijk * n_jkl, axis=-1)
    # |n_ijk| |n_jkl| sin(phi): both normals are perpendicular to r_kj, so their
    # cross product lies along r_kj; projecting onto the unit bond vector gives
    # the signed sine component.
    sin_term = np.sum(np.cross(n_ijk, n_jkl, axis=-1) * r_kj, axis=-1) / r_kj_norm

    # arctan2 keeps the sign and is exact at 0 and +/-pi. A sign(.) * arccos(.)
    # formulation returns 0 for a perfectly trans torsion, because the sign
    # term vanishes there.
    return np.arctan2(sin_term, cos_term)
|
|
|
|
def check_ligand_distance_geometry(
    structure, constraints, bond_buffer=0.25, angle_buffer=0.25, clash_buffer=0.2
):
    """Count RDKit distance-bound violations over the constrained atom pairs.

    Each row of ``constraints.rdkit_bounds_constraints`` is one atom pair with a
    lower/upper distance bound and ``is_bond`` / ``is_angle`` flags.

    * Bond and angle pairs violate when their distance is ``<= lower * (1 - buffer)``
      or ``>= upper * (1 + buffer)``, using ``bond_buffer`` / ``angle_buffer``.
    * All remaining pairs violate (an internal clash) when their distance is
      ``<= lower * (1 - clash_buffer)``.

    Args:
        structure: Parsed structure exposing ``coords["coords"]`` and ``chains``.
        constraints: Object exposing ``rdkit_bounds_constraints`` with fields
            ``atom_idxs``, ``is_bond``, ``is_angle``, ``upper_bound`` and
            ``lower_bound``.
        bond_buffer: Fractional tolerance for bond-length bounds.
        angle_buffer: Fractional tolerance for angle (1-3 distance) bounds.
        clash_buffer: Fractional tolerance for the non-neighbour lower bound.

    Returns:
        Dict with the number of ligand (``NONPOLYMER``) chains and the
        violation/pair counts for each category.
    """
    coords = structure.coords["coords"]
    bounds = constraints.rdkit_bounds_constraints
    # astype() returns a fresh array by default, so no extra .copy() is needed.
    pair_index = bounds["atom_idxs"].astype(np.int64).T
    bond_mask = bounds["is_bond"].astype(bool)
    angle_mask = bounds["is_angle"].astype(bool)
    upper_bounds = bounds["upper_bound"].astype(np.float32)
    lower_bounds = bounds["lower_bound"].astype(np.float32)
    # Pairs flagged as neither bond nor angle; only their lower bound is checked.
    other_mask = ~bond_mask & ~angle_mask

    dists = np.linalg.norm(coords[pair_index[0]] - coords[pair_index[1]], axis=-1)

    def outside_bounds(mask, buffer):
        # True where the distance falls outside the loosened [lower, upper] window.
        selected = dists[mask]
        too_short = selected <= lower_bounds[mask] * (1.0 - buffer)
        too_long = selected >= upper_bounds[mask] * (1.0 + buffer)
        return too_short | too_long

    bond_length_violations = outside_bounds(bond_mask, bond_buffer)
    bond_angle_violations = outside_bounds(angle_mask, angle_buffer)
    internal_clash_violations = dists[other_mask] <= lower_bounds[other_mask] * (
        1.0 - clash_buffer
    )

    num_ligands = sum(
        int(const.chain_types[chain["mol_type"]] == "NONPOLYMER")
        for chain in structure.chains
    )
    return {
        "num_ligands": num_ligands,
        "num_bond_length_violations": bond_length_violations.sum(),
        "num_bonds": bond_mask.sum(),
        "num_bond_angle_violations": bond_angle_violations.sum(),
        "num_angles": angle_mask.sum(),
        "num_internal_clash_violations": internal_clash_violations.sum(),
        "num_non_neighbors": other_mask.sum(),
    }
|
|
|
|
def check_ligand_stereochemistry(structure, constraints):
    """Count chiral-centre and stereo-bond orientation mismatches.

    Only constraints whose ``is_reference`` flag is set are evaluated.

    * A chiral centre is predicted ``is_r`` when the signed torsion over its
      ``atom_idxs`` (via ``compute_torsion_angles``) is positive.
    * A stereo bond is predicted ``is_e`` when the absolute torsion over its
      ``atom_idxs`` exceeds pi/2.

    Args:
        structure: Parsed structure exposing ``coords["coords"]``.
        constraints: Object exposing ``chiral_atom_constraints`` (fields
            ``atom_idxs``, ``is_r``, ``is_reference``) and
            ``stereo_bond_constraints`` (fields ``atom_idxs``, ``is_e``,
            ``is_reference``).

    Returns:
        Dict with violation counts and the number of evaluated chiral atoms /
        stereo bonds.
    """
    coords = structure.coords["coords"]

    chiral = constraints.chiral_atom_constraints
    # Cast explicitly: an integer-typed flag would otherwise act as fancy
    # indexing instead of a mask.
    chiral_ref_mask = np.asarray(chiral["is_reference"]).astype(bool)
    chiral_atom_index = np.asarray(chiral["atom_idxs"]).astype(np.int64).T
    chiral_atom_index = chiral_atom_index[:, chiral_ref_mask]
    true_chiral_atom_orientations = np.asarray(chiral["is_r"])[chiral_ref_mask]
    pred_chiral_atom_orientations = (
        compute_torsion_angles(coords, chiral_atom_index) > 0
    )
    chiral_atom_violations = (
        pred_chiral_atom_orientations != true_chiral_atom_orientations
    )

    stereo = constraints.stereo_bond_constraints
    stereo_ref_mask = np.asarray(stereo["is_reference"]).astype(bool)
    stereo_bond_index = np.asarray(stereo["atom_idxs"]).astype(np.int64).T
    stereo_bond_index = stereo_bond_index[:, stereo_ref_mask]
    true_stereo_bond_orientations = np.asarray(stereo["is_e"])[stereo_ref_mask]
    # |torsion| > pi/2 means the substituents sit on opposite sides (trans-like).
    pred_stereo_bond_orientations = (
        np.abs(compute_torsion_angles(coords, stereo_bond_index)) > np.pi / 2
    )
    stereo_bond_violations = (
        pred_stereo_bond_orientations != true_stereo_bond_orientations
    )

    return {
        "num_chiral_atom_violations": chiral_atom_violations.sum(),
        "num_chiral_atoms": chiral_atom_index.shape[1],
        "num_stereo_bond_violations": stereo_bond_violations.sum(),
        "num_stereo_bonds": stereo_bond_index.shape[1],
    }
|
|
|
|
def _planarity_violations(coords, atom_index, buffer):
    """Flag atom groups having any atom at least ``buffer`` from their best-fit plane.

    ``atom_index`` holds one row of atom indices per group; the result is a
    boolean array with one entry per row.
    """
    atom_index = np.asarray(atom_index)
    if atom_index.shape[0] == 0:
        # Nothing to check; also avoids running an SVD on an empty stack.
        return np.zeros(0, dtype=bool)
    group_coords = coords[atom_index, :]
    centered = group_coords - group_coords.mean(axis=-2, keepdims=True)
    # numpy returns singular values in descending order, so the last row of vh is
    # the direction of least variance, i.e. the normal of the least-squares plane.
    normals = np.linalg.svd(centered)[2][..., -1, :, None]
    plane_dists = np.abs(centered @ normals).squeeze(axis=-1)
    return np.any(plane_dists >= buffer, axis=-1)


def check_ligand_flatness(structure, constraints, buffer=0.25):
    """Count planar 5-rings, 6-rings and double-bond groups that are not flat.

    A group is a violation when any of its atoms lies ``>= buffer`` from the
    group's least-squares plane. The same rule applies to all three categories.

    Args:
        structure: Parsed structure exposing ``coords["coords"]``.
        constraints: Object exposing ``planar_ring_5_constraints``,
            ``planar_ring_6_constraints`` and ``planar_bond_constraints``, each
            with an ``atom_idxs`` field holding one row per group.
        buffer: Maximum allowed out-of-plane distance.

    Returns:
        Dict with violation and group counts per category.
    """
    coords = structure.coords["coords"]
    ring_5_violations = _planarity_violations(
        coords, constraints.planar_ring_5_constraints["atom_idxs"], buffer
    )
    ring_6_violations = _planarity_violations(
        coords, constraints.planar_ring_6_constraints["atom_idxs"], buffer
    )
    bond_violations = _planarity_violations(
        coords, constraints.planar_bond_constraints["atom_idxs"], buffer
    )
    return {
        "num_planar_5_ring_violations": ring_5_violations.sum(),
        "num_planar_5_rings": ring_5_violations.shape[0],
        "num_planar_6_ring_violations": ring_6_violations.sum(),
        "num_planar_6_rings": ring_6_violations.shape[0],
        "num_planar_double_bond_violations": bond_violations.sum(),
        "num_planar_double_bonds": bond_violations.shape[0],
    }
|
|
|
|
def _clash_label(chain_type):
    """Return the lowercase result-key label for a chain type ("nonpolymer" -> "ligand")."""
    label = chain_type.lower()
    return "ligand" if label == "nonpolymer" else label


def _atom_vdw_radii(structure, molecules):
    """Return a float32 vdW radius per atom, built residue by residue.

    Each residue's atoms are matched by name to the atoms of
    ``molecules[res["name"]]``; the lookup table is cached per residue name.
    Assumes residues tile ``structure.atoms`` contiguously and in order, since
    callers slice the result with chain ``atom_idx`` ranges. TODO confirm.
    """
    ref_atoms_by_res_name = {}
    radii = []
    for res in structure.residues:
        res_name = res["name"]
        atom_name_to_ref = ref_atoms_by_res_name.get(res_name)
        if atom_name_to_ref is None:
            atom_name_to_ref = {
                a.GetProp("name"): a for a in molecules[res_name].GetAtoms()
            }
            ref_atoms_by_res_name[res_name] = atom_name_to_ref
        res_atoms = structure.atoms[res["atom_idx"] : res["atom_idx"] + res["atom_num"]]
        radii.extend(
            const.vdw_radii[atom_name_to_ref[a["name"]].GetAtomicNum() - 1]
            for a in res_atoms
        )
    return np.array(radii, dtype=np.float32)


def check_steric_clash(structure, molecules, buffer=0.25):
    """Count inter-chain steric clashes, bucketed by chain-type pair.

    Two chains clash when any atom pair is closer than ``(1 - buffer)`` times
    the sum of the atoms' van der Waals radii. A chain pair is skipped when
    either chain has exactly one atom or when ``structure.bonds`` links the two
    chains.

    Args:
        structure: Parsed structure exposing ``coords["coords"]``, ``atoms``,
            ``residues``, ``chains`` and ``bonds``.
        molecules: Mapping from residue name to a molecule whose atoms carry a
            ``"name"`` property (this script passes the CCD dictionary).
        buffer: Fractional tolerance applied to the vdW radius sum.

    Returns:
        Dict of counters. Chains with the same entity id and atom count go under
        ``num_chain_{pairs,clashes}_sym_<type>``; all other pairs go under
        ``num_chain_{pairs,clashes}_asym_<type_i>_<type_j>``. "nonpolymer" is
        reported as "ligand".
    """
    labels = [_clash_label(chain_type) for chain_type in const.chain_types]
    result = {}
    for label_i in labels:
        result[f"num_chain_pairs_sym_{label_i}"] = 0
        result[f"num_chain_clashes_sym_{label_i}"] = 0
        for label_j in labels:
            result[f"num_chain_pairs_asym_{label_i}_{label_j}"] = 0
            result[f"num_chain_clashes_asym_{label_i}_{label_j}"] = 0

    # Chain pairs joined by an inter-chain bond, stored as sorted index tuples.
    connected_chains = set()
    for bond in structure.bonds:
        if bond["chain_1"] != bond["chain_2"]:
            connected_chains.add(tuple(sorted((bond["chain_1"], bond["chain_2"]))))

    vdw_radii = _atom_vdw_radii(structure, molecules)
    all_coords = structure.coords["coords"]
    chains = structure.chains

    for i, chain_i in enumerate(chains):
        if chain_i["atom_num"] == 1:
            continue
        span_i = slice(chain_i["atom_idx"], chain_i["atom_idx"] + chain_i["atom_num"])
        coords_i = all_coords[span_i]
        radii_i = vdw_radii[span_i]
        label_i = _clash_label(const.chain_types[chain_i["mol_type"]])

        # Each unordered pair is visited once (j > i).
        for j in range(i + 1, len(chains)):
            chain_j = chains[j]
            if chain_j["atom_num"] == 1 or (i, j) in connected_chains:
                continue
            span_j = slice(
                chain_j["atom_idx"], chain_j["atom_idx"] + chain_j["atom_num"]
            )
            coords_j = all_coords[span_j]
            dists = np.linalg.norm(coords_i[:, None, :] - coords_j[None, :, :], axis=-1)
            radii_sum = radii_i[:, None] + vdw_radii[span_j][None, :]
            is_clashing = np.any(dists < radii_sum * (1.00 - buffer))

            label_j = _clash_label(const.chain_types[chain_j["mol_type"]])
            is_symmetric = (
                chain_i["entity_id"] == chain_j["entity_id"]
                and chain_i["atom_num"] == chain_j["atom_num"]
            )
            key = f"sym_{label_i}" if is_symmetric else f"asym_{label_i}_{label_j}"
            result["num_chain_pairs_" + key] += 1
            result["num_chain_clashes_" + key] += int(is_clashing)
    return result
|
|
|
|
# Boltz cache: the pickled CCD dictionary (ccd.pkl) and the molecule directory.
cache_dir = Path("/data/rbg/users/jwohlwend/boltz-cache")
ccd_path = cache_dir / "ccd.pkl"
moldir = cache_dir / "mols"
# NOTE: unpickling can execute arbitrary code; only point this at a trusted cache.
with ccd_path.open("rb") as file:
    ccd = pickle.load(file)

# Prediction roots, one per tool; each holds one sub-directory per PDB id.
boltz1_dir = Path(
    "/data/rbg/shared/projects/foldeverything/boltz_results_final/outputs/test/boltz/predictions"
)
boltz1x_dir = Path(
    "/data/scratch/getzn/boltz_private/boltz_1x_test_results_final_new/full_predictions"
)
chai_dir = Path(
    "/data/rbg/shared/projects/foldeverything/boltz_results_final/outputs/test/chai"
)
af3_dir = Path(
    "/data/rbg/shared/projects/foldeverything/boltz_results_final/outputs/test/af3"
)

boltz1_pdb_ids = set(os.listdir(boltz1_dir))
boltz1x_pdb_ids = set(os.listdir(boltz1x_dir))
chai_pdb_ids = set(os.listdir(chai_dir))
af3_pdb_ids = set(os.listdir(af3_dir))
# process_fn resolves AF3 outputs via pdb_id.lower(), so AF3 entries are matched
# case-insensitively here. When every directory already uses the same case, this
# yields the same set as a plain intersection.
_af3_pdb_ids_lower = {pdb_id.lower() for pdb_id in af3_pdb_ids}
common_pdb_ids = {
    pdb_id
    for pdb_id in boltz1_pdb_ids & boltz1x_pdb_ids & chai_pdb_ids
    if pdb_id.lower() in _af3_pdb_ids_lower
}

tools = ["boltz1", "boltz1x", "chai", "af3"]
# Predictions evaluated per (tool, pdb_id): model indices 0 .. num_samples - 1.
num_samples = 5
|
|
|
|
def process_fn(key):
    """Parse one predicted structure and run every physical-validity check on it.

    Args:
        key: ``(tool, pdb_id, model_idx)`` identifying one prediction; ``tool``
            must be one of ``"boltz1"``, ``"boltz1x"``, ``"chai"``, ``"af3"``.

    Returns:
        Flat dict with ``tool``, ``pdb_id`` and ``model_idx``, plus every counter
        from the distance-geometry, stereochemistry, flatness and steric-clash
        checks.

    Raises:
        ValueError: If ``tool`` is not a recognised tool name.
    """
    tool, pdb_id, model_idx = key
    # Each tool writes its predictions with a different directory layout.
    if tool == "boltz1":
        cif_path = boltz1_dir / pdb_id / f"{pdb_id}_model_{model_idx}.cif"
    elif tool == "boltz1x":
        cif_path = boltz1x_dir / pdb_id / f"{pdb_id}_model_{model_idx}.cif"
    elif tool == "chai":
        cif_path = chai_dir / pdb_id / f"pred.model_idx_{model_idx}.cif"
    elif tool == "af3":
        # AF3 outputs are looked up under the lowercased PDB id.
        cif_path = af3_dir / pdb_id.lower() / f"seed-1_sample-{model_idx}" / "model.cif"
    else:
        # Fail with a clear message instead of an UnboundLocalError on cif_path.
        raise ValueError(f"Unknown tool {tool!r}; expected one of {tools}")

    parsed_structure = parse_mmcif(
        cif_path,
        ccd,
        moldir,
    )
    structure = parsed_structure.data
    constraints = parsed_structure.residue_constraints

    record = {
        "tool": tool,
        "pdb_id": pdb_id,
        "model_idx": model_idx,
    }
    record.update(check_ligand_distance_geometry(structure, constraints))
    record.update(check_ligand_stereochemistry(structure, constraints))
    record.update(check_ligand_flatness(structure, constraints))
    record.update(check_steric_clash(structure, molecules=ccd))
    return record
|
|
|
|
# Guarded so that worker processes re-importing this module (spawn/forkserver
# start methods) do not re-launch the whole pipeline.
if __name__ == "__main__":
    # One job per (tool, structure, sample). Sorting makes the job order, and
    # therefore the smoke-test job below, reproducible across runs.
    pdb_ids = sorted(common_pdb_ids)
    keys = [
        (tool, pdb_id, model_idx)
        for tool in tools
        for pdb_id in pdb_ids
        for model_idx in range(num_samples)
    ]
    if not keys:
        raise SystemExit("No PDB ids are shared by all tool output directories.")

    # Run one job in the main process first, so a broken parser/check fails fast
    # with a plain traceback before the worker pool starts.
    process_fn(keys[0])

    with Pool(48) as pool:
        records = list(tqdm(pool.imap_unordered(process_fn, keys), total=len(keys)))
    df = pd.DataFrame.from_records(records)

    # Per-model summaries over the per-category counters.
    df["num_chain_clashes_all"] = df[
        [col for col in df.columns if "chain_clash" in col]
    ].sum(axis=1)
    df["num_pairs_all"] = df[[col for col in df.columns if "chain_pair" in col]].sum(
        axis=1
    )
    df["clash_free"] = df["num_chain_clashes_all"] == 0
    # A ligand is valid only if every "*violation*" counter is zero.
    df["valid_ligand"] = (
        df[[col for col in df.columns if "violation" in col]].sum(axis=1) == 0
    )
    df["valid"] = df["clash_free"] & df["valid_ligand"]

    df.to_csv("physical_checks_test.csv")
|
|