directionality_probe / protify /FastPLMs /boltz /scripts /eval /physcialsim_metrics.py
nikraf's picture
Upload folder using huggingface_hub
714cf46 verified
import os
import pickle
import numpy as np
import torch
from pathlib import Path
from tqdm import tqdm
import pandas as pd
from boltz.data.mol import load_molecules
from boltz.data import const
from boltz.data.parse.mmcif_with_constraints import parse_mmcif
from multiprocessing import Pool
def compute_torsion_angles(coords, torsion_index):
r_ij = coords[..., torsion_index[0], :] - coords[..., torsion_index[1], :]
r_kj = coords[..., torsion_index[2], :] - coords[..., torsion_index[1], :]
r_kl = coords[..., torsion_index[2], :] - coords[..., torsion_index[3], :]
n_ijk = np.cross(r_ij, r_kj, axis=-1)
n_jkl = np.cross(r_kj, r_kl, axis=-1)
r_kj_norm = np.linalg.norm(r_kj, axis=-1)
n_ijk_norm = np.linalg.norm(n_ijk, axis=-1)
n_jkl_norm = np.linalg.norm(n_jkl, axis=-1)
sign_phi = np.sign(
r_kj[..., None, :] @ np.cross(n_ijk, n_jkl, axis=-1)[..., None]
).squeeze(axis=(-1, -2))
phi = sign_phi * np.arccos(
np.clip(
(n_ijk[..., None, :] @ n_jkl[..., None]).squeeze(axis=(-1, -2))
/ (n_ijk_norm * n_jkl_norm),
-1 + 1e-8,
1 - 1e-8,
)
)
return phi
def check_ligand_distance_geometry(
structure, constraints, bond_buffer=0.25, angle_buffer=0.25, clash_buffer=0.2
):
coords = structure.coords["coords"]
rdkit_bounds_constraints = constraints.rdkit_bounds_constraints
pair_index = rdkit_bounds_constraints["atom_idxs"].copy().astype(np.int64).T
bond_mask = rdkit_bounds_constraints["is_bond"].copy().astype(bool)
angle_mask = rdkit_bounds_constraints["is_angle"].copy().astype(bool)
upper_bounds = rdkit_bounds_constraints["upper_bound"].copy().astype(np.float32)
lower_bounds = rdkit_bounds_constraints["lower_bound"].copy().astype(np.float32)
dists = np.linalg.norm(coords[pair_index[0]] - coords[pair_index[1]], axis=-1)
bond_length_violations = (
dists[bond_mask] <= lower_bounds[bond_mask] * (1.0 - bond_buffer)
) + (dists[bond_mask] >= upper_bounds[bond_mask] * (1.0 + bond_buffer))
bond_angle_violations = (
dists[angle_mask] <= lower_bounds[angle_mask] * (1.0 - angle_buffer)
) + (dists[angle_mask] >= upper_bounds[angle_mask] * (1.0 + angle_buffer))
internal_clash_violations = dists[~bond_mask * ~angle_mask] <= lower_bounds[
~bond_mask * ~angle_mask
] * (1.0 - clash_buffer)
num_ligands = sum(
[
int(const.chain_types[chain["mol_type"]] == "NONPOLYMER")
for chain in structure.chains
]
)
return {
"num_ligands": num_ligands,
"num_bond_length_violations": bond_length_violations.sum(),
"num_bonds": bond_mask.sum(),
"num_bond_angle_violations": bond_angle_violations.sum(),
"num_angles": angle_mask.sum(),
"num_internal_clash_violations": internal_clash_violations.sum(),
"num_non_neighbors": (~bond_mask * ~angle_mask).sum(),
}
def check_ligand_stereochemistry(structure, constraints):
coords = structure.coords["coords"]
chiral_atom_constraints = constraints.chiral_atom_constraints
stereo_bond_constraints = constraints.stereo_bond_constraints
chiral_atom_index = chiral_atom_constraints["atom_idxs"].T
true_chiral_atom_orientations = chiral_atom_constraints["is_r"]
chiral_atom_ref_mask = chiral_atom_constraints["is_reference"]
chiral_atom_index = chiral_atom_index[:, chiral_atom_ref_mask]
true_chiral_atom_orientations = true_chiral_atom_orientations[chiral_atom_ref_mask]
pred_chiral_atom_orientations = (
compute_torsion_angles(coords, chiral_atom_index) > 0
)
chiral_atom_violations = (
pred_chiral_atom_orientations != true_chiral_atom_orientations
)
stereo_bond_index = stereo_bond_constraints["atom_idxs"].T
true_stereo_bond_orientations = stereo_bond_constraints["is_e"]
stereo_bond_ref_mask = stereo_bond_constraints["is_reference"]
stereo_bond_index = stereo_bond_index[:, stereo_bond_ref_mask]
true_stereo_bond_orientations = true_stereo_bond_orientations[stereo_bond_ref_mask]
pred_stereo_bond_orientations = (
np.abs(compute_torsion_angles(coords, stereo_bond_index)) > np.pi / 2
)
stereo_bond_violations = (
pred_stereo_bond_orientations != true_stereo_bond_orientations
)
return {
"num_chiral_atom_violations": chiral_atom_violations.sum(),
"num_chiral_atoms": chiral_atom_index.shape[1],
"num_stereo_bond_violations": stereo_bond_violations.sum(),
"num_stereo_bonds": stereo_bond_index.shape[1],
}
def check_ligand_flatness(structure, constraints, buffer=0.25):
coords = structure.coords["coords"]
planar_ring_5_index = constraints.planar_ring_5_constraints["atom_idxs"]
ring_5_coords = coords[planar_ring_5_index, :]
centered_ring_5_coords = ring_5_coords - ring_5_coords.mean(axis=-2, keepdims=True)
ring_5_vecs = np.linalg.svd(centered_ring_5_coords)[2][..., -1, :, None]
ring_5_dists = np.abs((centered_ring_5_coords @ ring_5_vecs).squeeze(axis=-1))
ring_5_violations = np.all(ring_5_dists <= buffer, axis=-1)
planar_ring_6_index = constraints.planar_ring_6_constraints["atom_idxs"]
ring_6_coords = coords[planar_ring_6_index, :]
centered_ring_6_coords = ring_6_coords - ring_6_coords.mean(axis=-2, keepdims=True)
ring_6_vecs = np.linalg.svd(centered_ring_6_coords)[2][..., -1, :, None]
ring_6_dists = np.abs((centered_ring_6_coords @ ring_6_vecs)).squeeze(axis=-1)
ring_6_violations = np.any(ring_6_dists >= buffer, axis=-1)
planar_bond_index = constraints.planar_bond_constraints["atom_idxs"]
bond_coords = coords[planar_bond_index, :]
centered_bond_coords = bond_coords - bond_coords.mean(axis=-2, keepdims=True)
bond_vecs = np.linalg.svd(centered_bond_coords)[2][..., -1, :, None]
bond_dists = np.abs((centered_bond_coords @ bond_vecs)).squeeze(axis=-1)
bond_violations = np.any(bond_dists >= buffer, axis=-1)
return {
"num_planar_5_ring_violations": ring_5_violations.sum(),
"num_planar_5_rings": ring_5_violations.shape[0],
"num_planar_6_ring_violations": ring_6_violations.sum(),
"num_planar_6_rings": ring_6_violations.shape[0],
"num_planar_double_bond_violations": bond_violations.sum(),
"num_planar_double_bonds": bond_violations.shape[0],
}
def check_steric_clash(structure, molecules, buffer=0.25):
result = {}
for type_i in const.chain_types:
out_type_i = type_i.lower()
out_type_i = out_type_i if out_type_i != "nonpolymer" else "ligand"
result[f"num_chain_pairs_sym_{out_type_i}"] = 0
result[f"num_chain_clashes_sym_{out_type_i}"] = 0
for type_j in const.chain_types:
out_type_j = type_j.lower()
out_type_j = out_type_j if out_type_j != "nonpolymer" else "ligand"
result[f"num_chain_pairs_asym_{out_type_i}_{out_type_j}"] = 0
result[f"num_chain_clashes_asym_{out_type_i}_{out_type_j}"] = 0
connected_chains = set()
for bond in structure.bonds:
if bond["chain_1"] != bond["chain_2"]:
connected_chains.add(tuple(sorted((bond["chain_1"], bond["chain_2"]))))
vdw_radii = []
for res in structure.residues:
mol = molecules[res["name"]]
token_atoms = structure.atoms[
res["atom_idx"] : res["atom_idx"] + res["atom_num"]
]
atom_name_to_ref = {a.GetProp("name"): a for a in mol.GetAtoms()}
token_atoms_ref = [atom_name_to_ref[a["name"]] for a in token_atoms]
vdw_radii.extend(
[const.vdw_radii[a.GetAtomicNum() - 1] for a in token_atoms_ref]
)
vdw_radii = np.array(vdw_radii, dtype=np.float32)
np.array([a.GetAtomicNum() for a in token_atoms_ref])
for i, chain_i in enumerate(structure.chains):
for j, chain_j in enumerate(structure.chains):
if (
chain_i["atom_num"] == 1
or chain_j["atom_num"] == 1
or j <= i
or (i, j) in connected_chains
):
continue
coords_i = structure.coords["coords"][
chain_i["atom_idx"] : chain_i["atom_idx"] + chain_i["atom_num"]
]
coords_j = structure.coords["coords"][
chain_j["atom_idx"] : chain_j["atom_idx"] + chain_j["atom_num"]
]
dists = np.linalg.norm(coords_i[:, None, :] - coords_j[None, :, :], axis=-1)
radii_i = vdw_radii[
chain_i["atom_idx"] : chain_i["atom_idx"] + chain_i["atom_num"]
]
radii_j = vdw_radii[
chain_j["atom_idx"] : chain_j["atom_idx"] + chain_j["atom_num"]
]
radii_sum = radii_i[:, None] + radii_j[None, :]
is_clashing = np.any(dists < radii_sum * (1.00 - buffer))
type_i = const.chain_types[chain_i["mol_type"]].lower()
type_j = const.chain_types[chain_j["mol_type"]].lower()
type_i = type_i if type_i != "nonpolymer" else "ligand"
type_j = type_j if type_j != "nonpolymer" else "ligand"
is_symmetric = (
chain_i["entity_id"] == chain_j["entity_id"]
and chain_i["atom_num"] == chain_j["atom_num"]
)
if is_symmetric:
key = "sym_" + type_i
else:
key = "asym_" + type_i + "_" + type_j
result["num_chain_pairs_" + key] += 1
result["num_chain_clashes_" + key] += int(is_clashing)
return result
cache_dir = Path("/data/rbg/users/jwohlwend/boltz-cache")
ccd_path = cache_dir / "ccd.pkl"
moldir = cache_dir / "mols"
with ccd_path.open("rb") as file:
ccd = pickle.load(file)
boltz1_dir = Path(
"/data/rbg/shared/projects/foldeverything/boltz_results_final/outputs/test/boltz/predictions"
)
boltz1x_dir = Path(
"/data/scratch/getzn/boltz_private/boltz_1x_test_results_final_new/full_predictions"
)
chai_dir = Path(
"/data/rbg/shared/projects/foldeverything/boltz_results_final/outputs/test/chai"
)
af3_dir = Path(
"/data/rbg/shared/projects/foldeverything/boltz_results_final/outputs/test/af3"
)
boltz1_pdb_ids = set(os.listdir(boltz1_dir))
boltz1x_pdb_ids = set(os.listdir(boltz1x_dir))
chai_pdb_ids = set(os.listdir(chai_dir))
af3_pdb_ids = set([pdb_id for pdb_id in os.listdir(af3_dir)])
common_pdb_ids = boltz1_pdb_ids & boltz1x_pdb_ids & chai_pdb_ids & af3_pdb_ids
tools = ["boltz1", "boltz1x", "chai", "af3"]
num_samples = 5
def process_fn(key):
tool, pdb_id, model_idx = key
if tool == "boltz1":
cif_path = boltz1_dir / pdb_id / f"{pdb_id}_model_{model_idx}.cif"
elif tool == "boltz1x":
cif_path = boltz1x_dir / pdb_id / f"{pdb_id}_model_{model_idx}.cif"
elif tool == "chai":
cif_path = chai_dir / pdb_id / f"pred.model_idx_{model_idx}.cif"
elif tool == "af3":
cif_path = af3_dir / pdb_id.lower() / f"seed-1_sample-{model_idx}" / "model.cif"
parsed_structure = parse_mmcif(
cif_path,
ccd,
moldir,
)
structure = parsed_structure.data
constraints = parsed_structure.residue_constraints
record = {
"tool": tool,
"pdb_id": pdb_id,
"model_idx": model_idx,
}
record.update(check_ligand_distance_geometry(structure, constraints))
record.update(check_ligand_stereochemistry(structure, constraints))
record.update(check_ligand_flatness(structure, constraints))
record.update(check_steric_clash(structure, molecules=ccd))
return record
keys = []
for tool in tools:
for pdb_id in common_pdb_ids:
for model_idx in range(num_samples):
keys.append((tool, pdb_id, model_idx))
process_fn(keys[0])
records = []
with Pool(48) as p:
with tqdm(total=len(keys)) as pbar:
for record in p.imap_unordered(process_fn, keys):
records.append(record)
pbar.update(1)
df = pd.DataFrame.from_records(records)
df["num_chain_clashes_all"] = df[
[key for key in df.columns if "chain_clash" in key]
].sum(axis=1)
df["num_pairs_all"] = df[[key for key in df.columns if "chain_pair" in key]].sum(axis=1)
df["clash_free"] = df["num_chain_clashes_all"] == 0
df["valid_ligand"] = (
df[[key for key in df.columns if "violation" in key]].sum(axis=1) == 0
)
df["valid"] = (df["clash_free"]) & (df["valid_ligand"])
df.to_csv("physical_checks_test.csv")