|
|
import binascii |
|
|
import glob |
|
|
import os |
|
|
import pickle |
|
|
from collections import defaultdict |
|
|
from multiprocessing import Pool |
|
|
import random |
|
|
import copy |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
import torch |
|
|
from rdkit import Chem |
|
|
from rdkit.Chem import MolFromSmiles, AddHs |
|
|
from torch_geometric.data import Dataset, HeteroData |
|
|
from torch_geometric.transforms import BaseTransform |
|
|
from tqdm import tqdm |
|
|
from rdkit.Chem import RemoveAllHs |
|
|
|
|
|
from datasets.process_mols import read_molecule, get_lig_graph_with_matching, generate_conformer, moad_extract_receptor_structure |
|
|
from utils.diffusion_utils import modify_conformer, set_time |
|
|
from utils.utils import read_strings_from_txt, crop_beyond |
|
|
from utils import so3, torus |
|
|
|
|
|
|
|
|
class NoiseTransform(BaseTransform):
    """Perturb a complex graph's ligand pose with translation / rotation / torsion noise.

    A diffusion time is sampled (``get_time``), mapped to noise scales by ``t_to_sigma`` and the ligand
    conformer is perturbed with ``modify_conformer``. In the default mode the resulting score targets
    (``tr_score``, ``rot_score``, ``tor_score``) are stored on the graph. With ``time_independent=True``
    the graph instead gets a binary label ``y`` (and ``atom_y``) saying whether the perturbed heavy-atom
    RMSD to the original pose is below ``rmsd_cutoff``.
    """

    def __init__(self, t_to_sigma, no_torsion, all_atom, alpha=1, beta=1,
                 include_miscellaneous_atoms=False, crop_beyond_cutoff=None, time_independent=False, rmsd_cutoff=0,
                 minimum_t=0, sampling_mixing_coeff=0):
        # noise schedule: maps (t_tr, t_rot, t_tor) -> (tr_sigma, rot_sigma, tor_sigma)
        self.t_to_sigma = t_to_sigma
        self.no_torsion = no_torsion
        self.all_atom = all_atom
        self.include_miscellaneous_atoms = include_miscellaneous_atoms
        # parameters of the Beta(alpha, beta) distribution times are drawn from
        self.alpha = alpha
        self.beta = beta
        # lower bound of the rescaled time range, and probability of sampling below it instead
        self.minimum_t = minimum_t
        self.mixing_coeff = sampling_mixing_coeff
        self.crop_beyond_cutoff = crop_beyond_cutoff
        self.time_independent = time_independent
        self.rmsd_cutoff = rmsd_cutoff

    def __call__(self, data):
        """Sample a time and return ``data`` perturbed at that time."""
        return self.apply_noise(data, *self.get_time())

    def _sample_beta(self):
        """Draw one sample from Beta(alpha, beta)."""
        return np.random.beta(self.alpha, self.beta)

    def get_time(self):
        """Sample a diffusion time and return it as ``(t_tr, t_rot, t_tor, t)`` (all four equal).

        In time-independent mode ``t`` is a raw Beta sample. Otherwise it is rescaled to
        ``[minimum_t, 1]``; when ``mixing_coeff`` is non-zero, with probability ``mixing_coeff`` a sample
        rescaled to ``[0, minimum_t]`` is used instead.
        """
        if self.time_independent:
            t = self._sample_beta()
        elif self.mixing_coeff == 0:
            t = self.minimum_t + self._sample_beta() * (1 - self.minimum_t)
        else:
            # draw order (binomial, low, high) kept fixed so the RNG stream is reproducible
            pick_low = np.random.binomial(1, self.mixing_coeff)
            low_t = self._sample_beta() * self.minimum_t
            high_t = self.minimum_t + self._sample_beta() * (1 - self.minimum_t)
            t = pick_low * low_t + (1 - pick_low) * high_t
        return t, t, t, t

    def _label_by_rmsd(self, data, orig_complex_graph):
        """Set ``data.y``/``data.atom_y`` to whether the heavy-atom RMSD to the original pose is < rmsd_cutoff."""
        orig_ligand = orig_complex_graph['ligand']
        center = orig_complex_graph.original_center.cpu().numpy()
        if self.no_torsion:
            # without torsion noise the unperturbed pose itself is the reference (in uncentered coordinates)
            orig_ligand.orig_pos = orig_ligand.pos.cpu().numpy() + center

        # atoms whose first feature is 0 are excluded from the RMSD (presumably hydrogens, as the name says)
        heavy_mask = torch.not_equal(data['ligand'].x[:, 0], 0).cpu().numpy()
        if isinstance(orig_ligand.orig_pos, list):
            orig_ligand.orig_pos = orig_ligand.orig_pos[0]

        delta = data['ligand'].pos.cpu().numpy()[heavy_mask] - (orig_ligand.orig_pos[heavy_mask] - center)
        rmsd = np.sqrt((delta ** 2).sum(axis=1).mean(axis=0))
        data.y = torch.tensor(rmsd < self.rmsd_cutoff).float().unsqueeze(0)
        data.atom_y = data.y
        return data

    def apply_noise(self, data, t_tr, t_rot, t_tor, t, tr_update = None, rot_update=None, torsion_updates=None):
        """Perturb ``data``'s ligand at the given times and attach training targets.

        ``tr_update`` / ``rot_update`` / ``torsion_updates`` may be supplied to override the random draws;
        ``rot_update`` is converted with ``torch.from_numpy`` and so is expected to be a numpy array.
        Returns the (mutated) ``data``.
        """
        # a list of candidate ligand positions (e.g. several conformers): use one at random
        if not torch.is_tensor(data['ligand'].pos):
            data['ligand'].pos = random.choice(data['ligand'].pos)

        # keep the unperturbed graph around; it is the RMSD reference in time-independent mode
        orig_complex_graph = copy.deepcopy(data) if self.time_independent else None

        tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(t_tr, t_rot, t_tor)

        time_args = (0, 0, 0, 0) if self.time_independent else (t, t_tr, t_rot, t_tor)
        set_time(data, *time_args, 1, self.all_atom, device=None, include_miscellaneous_atoms=self.include_miscellaneous_atoms)

        # draw order (translation, rotation, torsion) kept fixed for RNG reproducibility
        if tr_update is None:
            tr_update = torch.normal(mean=0, std=tr_sigma, size=(1, 3))
        if rot_update is None:
            rot_update = so3.sample_vec(eps=rot_sigma)
        if torsion_updates is None:
            torsion_updates = np.random.normal(loc=0.0, scale=tor_sigma, size=data['ligand'].edge_mask.sum())
        if self.no_torsion:
            torsion_updates = None

        # best effort: a failed perturbation is reported but does not abort the sample
        try:
            modify_conformer(data, tr_update, torch.from_numpy(rot_update).float(), torsion_updates)
        except Exception as e:
            print("failed modify conformer")
            print(e)

        if self.time_independent:
            return self._label_by_rmsd(data, orig_complex_graph)

        data.tr_score = -tr_update / tr_sigma ** 2
        data.rot_score = torch.from_numpy(so3.score_vec(vec=rot_update, eps=rot_sigma)).float().unsqueeze(0)
        if self.no_torsion:
            data.tor_score = None
            data.tor_sigma_edge = None
        else:
            data.tor_score = torch.from_numpy(torus.score(torsion_updates, tor_sigma)).float()
            data.tor_sigma_edge = np.ones(data['ligand'].edge_mask.sum()) * tor_sigma

        # a single-atom ligand has no meaningful orientation, so its rotation target is zeroed
        if data['ligand'].pos.shape[0] == 1:
            data.rot_score = data.rot_score * 0

        if self.crop_beyond_cutoff is not None:
            crop_beyond(data, tr_sigma * 3 + self.crop_beyond_cutoff, self.all_atom)
            # cropping changes the node sets, so the time features are set again
            set_time(data, t, t_tr, t_rot, t_tor, 1, self.all_atom, device=None, include_miscellaneous_atoms=self.include_miscellaneous_atoms)
        return data
|
|
|
|
|
|
|
|
class PDBBind(Dataset):
    """PyTorch Geometric dataset of protein-ligand complexes stored as ``HeteroData`` graphs.

    On construction the dataset looks for preprocessed graphs in a cache directory whose name encodes the
    featurization options. If any cache chunk is missing it is built: by ``preprocessing`` (from the complex
    names in ``split_path``), or by ``inference_preprocessing`` when both ``protein_path_list`` and
    ``ligand_descriptions`` are given. All graphs (and RDKit ligands) are then loaded into memory.

    Complexes are expected under ``root/<name>/`` as ``<name>_<ligand_file>.sdf`` (or ``.mol2``) and
    ``<name>_<protein_file>.pdb``.
    """

    # Number of complexes per cache chunk file (heterographs{i}.pkl / rdkit_ligands{i}.pkl).
    _CHUNK_SIZE = 1000

    def __init__(self, root, transform=None, cache_path='data/cache', split_path='data/', limit_complexes=0, chain_cutoff=10,
                 receptor_radius=30, num_workers=1, c_alpha_max_neighbors=None, popsize=15, maxiter=15,
                 matching=True, keep_original=False, max_lig_size=None, remove_hs=False, num_conformers=1, all_atoms=False,
                 atom_radius=5, atom_max_neighbors=None, esm_embeddings_path=None, require_ligand=False,
                 include_miscellaneous_atoms=False,
                 protein_path_list=None, ligand_descriptions=None, keep_local_structures=False,
                 protein_file="protein_processed", ligand_file="ligand",
                 knn_only_graph=False, matching_tries=1, dataset='PDBBind'):

        super(PDBBind, self).__init__(root, transform)
        self.pdbbind_dir = root
        self.include_miscellaneous_atoms = include_miscellaneous_atoms
        self.max_lig_size = max_lig_size
        self.split_path = split_path
        self.limit_complexes = limit_complexes
        self.chain_cutoff = chain_cutoff
        self.receptor_radius = receptor_radius
        self.num_workers = num_workers
        self.c_alpha_max_neighbors = c_alpha_max_neighbors
        self.remove_hs = remove_hs
        self.esm_embeddings_path = esm_embeddings_path
        self.use_old_wrong_embedding_order = False
        self.require_ligand = require_ligand
        self.protein_path_list = protein_path_list
        self.ligand_descriptions = ligand_descriptions
        self.keep_local_structures = keep_local_structures
        self.protein_file = protein_file
        self.fixed_knn_radius_graph = True
        self.knn_only_graph = knn_only_graph
        self.matching_tries = matching_tries
        self.ligand_file = ligand_file
        self.dataset = dataset
        # all-atom graphs are only supported together with kNN-only graphs
        assert knn_only_graph or (not all_atoms)
        self.all_atoms = all_atoms
        if matching or protein_path_list is not None and ligand_descriptions is not None:
            cache_path += '_torsion'
        if all_atoms:
            cache_path += '_allatoms'
        # The cache directory name encodes every option that affects featurization, so changing an option
        # never silently reuses graphs built with different settings.
        self.full_cache_path = os.path.join(cache_path, f'{dataset}3_limit{self.limit_complexes}'
                                                        f'_INDEX{os.path.splitext(os.path.basename(self.split_path))[0]}'
                                                        f'_maxLigSize{self.max_lig_size}_H{int(not self.remove_hs)}'
                                                        f'_recRad{self.receptor_radius}_recMax{self.c_alpha_max_neighbors}'
                                                        f'_chainCutoff{self.chain_cutoff if self.chain_cutoff is None else int(self.chain_cutoff)}'
                                            + (''if not all_atoms else f'_atomRad{atom_radius}_atomMax{atom_max_neighbors}')
                                            + (''if not matching or num_conformers == 1 else f'_confs{num_conformers}')
                                            + ('' if self.esm_embeddings_path is None else f'_esmEmbeddings')
                                            + '_full'
                                            + ('' if not keep_local_structures else f'_keptLocalStruct')
                                            + ('' if protein_path_list is None or ligand_descriptions is None else str(binascii.crc32(''.join(ligand_descriptions + protein_path_list).encode())))
                                            + ('' if protein_file == "protein_processed" else '_' + protein_file)
                                            + ('' if not self.fixed_knn_radius_graph else (f'_fixedKNN' if not self.knn_only_graph else '_fixedKNNonly'))
                                            + ('' if not self.include_miscellaneous_atoms else '_miscAtoms')
                                            + ('' if self.use_old_wrong_embedding_order else '_chainOrd')
                                            + ('' if self.matching_tries == 1 else f'_tries{matching_tries}'))
        self.popsize, self.maxiter = popsize, maxiter
        self.matching, self.keep_original = matching, keep_original
        self.num_conformers = num_conformers

        self.atom_radius, self.atom_max_neighbors = atom_radius, atom_max_neighbors
        if not self.check_all_complexes():
            os.makedirs(self.full_cache_path, exist_ok=True)
            if protein_path_list is None or ligand_descriptions is None:
                self.preprocessing()
            else:
                self.inference_preprocessing()

        self.complex_graphs, self.rdkit_ligands = self.collect_all_complexes()
        print_statistics(self.complex_graphs)
        list_names = [graph['name'] for graph in self.complex_graphs]
        with open(os.path.join(self.full_cache_path, f'pdbbind_{os.path.splitext(os.path.basename(self.split_path))[0][:3]}_names.txt'), 'w') as f:
            f.write('\n'.join(list_names))

    def len(self):
        """Number of loaded complexes."""
        return len(self.complex_graphs)

    def get(self, idx):
        """Return a deep copy of complex ``idx`` with bookkeeping attributes removed.

        When ``require_ligand`` is set, the RDKit ligand (hydrogens removed) is attached as ``mol``.
        """
        complex_graph = copy.deepcopy(self.complex_graphs[idx])
        if self.require_ligand:
            complex_graph.mol = RemoveAllHs(copy.deepcopy(self.rdkit_ligands[idx]))

        for a in ['random_coords', 'coords', 'seq', 'sequence', 'mask', 'rmsd_matching', 'cluster', 'orig_seq', 'to_keep', 'chain_ids']:
            if hasattr(complex_graph, a):
                delattr(complex_graph, a)
            if hasattr(complex_graph['receptor'], a):
                delattr(complex_graph['receptor'], a)

        return complex_graph

    def _is_inference(self):
        """True when the dataset is built from explicit protein paths and ligand descriptions."""
        return self.protein_path_list is not None and self.ligand_descriptions is not None

    def _read_complex_names(self):
        """Complex names from ``split_path``, truncated to ``limit_complexes`` when that is non-zero."""
        complex_names = read_strings_from_txt(self.split_path)
        if self.limit_complexes is not None and self.limit_complexes != 0:
            complex_names = complex_names[:self.limit_complexes]
        return complex_names

    def _num_chunks(self):
        """Number of cache chunks, computed the same way the matching preprocessing method writes them."""
        if self._is_inference():
            return len(self.protein_path_list) // self._CHUNK_SIZE + 1
        return len(self._read_complex_names()) // self._CHUNK_SIZE + 1

    def _chunk_path(self, prefix, i):
        """Path of cache chunk ``i`` for ``prefix`` ('heterographs' or 'rdkit_ligands')."""
        return os.path.join(self.full_cache_path, f"{prefix}{i}.pkl")

    def _process_chunk(self, i, args, num_chunks):
        """Run ``get_complex`` over the argument tuples ``args`` and pickle the results as chunk ``i``.

        A worker pool is used when ``num_workers > 1`` and is always terminated, even if featurization raises.
        ``rdkit_ligands{i}.pkl`` is written before ``heterographs{i}.pkl`` because the existence of the
        latter is what marks the chunk as finished.
        """
        complex_graphs, rdkit_ligands = [], []
        pool = Pool(self.num_workers, maxtasksperchild=1) if self.num_workers > 1 else None
        try:
            map_fn = pool.imap_unordered if pool is not None else map
            with tqdm(total=len(args), desc=f'loading complexes {i}/{num_chunks}') as pbar:
                for graphs, ligands in map_fn(self.get_complex, args):
                    complex_graphs.extend(graphs)
                    rdkit_ligands.extend(ligands)
                    pbar.update()
        finally:
            if pool is not None:
                pool.terminate()  # what Pool.__exit__ does

        with open(self._chunk_path('rdkit_ligands', i), 'wb') as f:
            pickle.dump(rdkit_ligands, f)
        with open(self._chunk_path('heterographs', i), 'wb') as f:
            pickle.dump(complex_graphs, f)

    def _load_split_esm_embeddings(self, complex_names_all):
        """Per-complex lists of chain embeddings, ordered by chain index, from the ESM embeddings file.

        Keys in the file are parsed as ``<complex name>_chain_<int index>``; complexes without any entry
        get an empty list.
        """
        id_to_embeddings = torch.load(self.esm_embeddings_path)
        wanted = set(complex_names_all)  # O(1) membership instead of scanning the name list per key
        chain_embeddings_dictlist = defaultdict(list)
        chain_indices_dictlist = defaultdict(list)
        for key, embedding in id_to_embeddings.items():
            key_name = key.split('_chain_')[0]
            if key_name in wanted:
                chain_embeddings_dictlist[key_name].append(embedding)
                chain_indices_dictlist[key_name].append(int(key.split('_chain_')[1]))

        lm_embeddings_chains_all = []
        for name in complex_names_all:
            complex_chains_embeddings = chain_embeddings_dictlist[name]
            chain_reorder_idx = np.argsort(chain_indices_dictlist[name])
            lm_embeddings_chains_all.append([complex_chains_embeddings[j] for j in chain_reorder_idx])
        return lm_embeddings_chains_all

    def preprocessing(self):
        """Featurize the complexes listed in ``split_path`` and write any missing cache chunks."""
        print(f'Processing complexes from [{self.split_path}] and saving it to [{self.full_cache_path}]')

        complex_names_all = self._read_complex_names()
        print(f'Loading {len(complex_names_all)} complexes.')

        if self.esm_embeddings_path is not None:
            lm_embeddings_chains_all = self._load_split_esm_embeddings(complex_names_all)
        else:
            lm_embeddings_chains_all = [None] * len(complex_names_all)

        num_chunks = len(complex_names_all) // self._CHUNK_SIZE + 1
        # chunks are processed in random order; existing ones are skipped, so interrupted runs can resume
        list_indices = list(range(num_chunks))
        random.shuffle(list_indices)
        for i in list_indices:
            if os.path.exists(self._chunk_path('heterographs', i)):
                continue
            chunk = slice(self._CHUNK_SIZE * i, self._CHUNK_SIZE * (i + 1))
            complex_names = complex_names_all[chunk]
            lm_embeddings_chains = lm_embeddings_chains_all[chunk]
            no_ligands = [None] * len(complex_names)
            args = list(zip(complex_names, lm_embeddings_chains, no_ligands, no_ligands))
            self._process_chunk(i, args, num_chunks)

    def inference_preprocessing(self):
        """Build ligands from ``ligand_descriptions`` and write any missing cache chunks.

        Each description is tried as a SMILES string first and otherwise read as a molecule file.

        Raises:
            ValueError: if a description is neither a valid SMILES string nor a readable molecule file.
        """
        ligands_list = []
        print('Reading molecules and generating local structures with RDKit')
        for ligand_description in tqdm(self.ligand_descriptions):
            mol = MolFromSmiles(ligand_description)
            if mol is not None:
                mol = AddHs(mol)
                generate_conformer(mol)
            else:
                mol = read_molecule(ligand_description, remove_hs=False, sanitize=True)
                if mol is None:
                    raise ValueError(f'Could not read ligand {ligand_description!r} as SMILES or as a molecule file')
                if not self.keep_local_structures:
                    mol.RemoveAllConformers()
                    mol = AddHs(mol)
                    generate_conformer(mol)
            ligands_list.append(mol)

        if self.esm_embeddings_path is not None:
            print('Reading language model embeddings.')
            if not os.path.exists(self.esm_embeddings_path): raise Exception('ESM embeddings path does not exist: ',self.esm_embeddings_path)
            lm_embeddings_chains_all = []
            for protein_path in self.protein_path_list:
                embeddings_paths = sorted(glob.glob(os.path.join(self.esm_embeddings_path, os.path.basename(protein_path)) + '*'))
                lm_embeddings_chains_all.append([torch.load(path)['representations'][33] for path in embeddings_paths])
        else:
            lm_embeddings_chains_all = [None] * len(self.protein_path_list)

        print('Generating graphs for ligands and proteins')

        num_chunks = len(self.protein_path_list) // self._CHUNK_SIZE + 1
        list_indices = list(range(num_chunks))
        random.shuffle(list_indices)
        for i in list_indices:
            if os.path.exists(self._chunk_path('heterographs', i)):
                continue
            chunk = slice(self._CHUNK_SIZE * i, self._CHUNK_SIZE * (i + 1))
            args = list(zip(self.protein_path_list[chunk], lm_embeddings_chains_all[chunk],
                            ligands_list[chunk], self.ligand_descriptions[chunk]))
            # the progress bar now shows the real number of chunks instead of always "/1"
            self._process_chunk(i, args, num_chunks)

    def check_all_complexes(self):
        """True if an aggregate ``heterographs.pkl`` or every ``heterographs{i}.pkl`` chunk exists in the cache."""
        if os.path.exists(os.path.join(self.full_cache_path, "heterographs.pkl")):
            return True
        return all(os.path.exists(self._chunk_path('heterographs', i)) for i in range(self._num_chunks()))

    def collect_all_complexes(self):
        """Load ``(complex_graphs, rdkit_ligands)`` from the cache.

        The aggregate ``heterographs.pkl`` is preferred when present; in that case the ligands are loaded only
        if ``require_ligand`` is set (otherwise ``None``). Otherwise all chunk files are concatenated.
        """
        print('Collecting all complexes from cache', self.full_cache_path)
        # NOTE: pickle.load executes arbitrary code; the cache must only contain files written by this class.
        aggregate_path = os.path.join(self.full_cache_path, "heterographs.pkl")
        if os.path.exists(aggregate_path):
            with open(aggregate_path, 'rb') as f:
                complex_graphs = pickle.load(f)
            if self.require_ligand:
                with open(os.path.join(self.full_cache_path, "rdkit_ligands.pkl"), 'rb') as f:
                    rdkit_ligands = pickle.load(f)
            else:
                rdkit_ligands = None
            return complex_graphs, rdkit_ligands

        num_chunks = self._num_chunks()
        complex_graphs_all = []
        for i in range(num_chunks):
            with open(self._chunk_path('heterographs', i), 'rb') as f:
                print(i)
                complex_graphs_all.extend(pickle.load(f))

        rdkit_ligands_all = []
        for i in range(num_chunks):
            with open(self._chunk_path('rdkit_ligands', i), 'rb') as f:
                rdkit_ligands_all.extend(pickle.load(f))

        return complex_graphs_all, rdkit_ligands_all

    def get_complex(self, par):
        """Featurize one complex; returns ``([graph], [rdkit_ligand])``, or ``([], [])`` if it is skipped.

        ``par`` is ``(name, lm_embedding_chains, ligand, ligand_description)``. Receptor, atom and ligand
        positions are centered on the receptor mean position, which is stored as ``original_center``.
        """
        name, lm_embedding_chains, ligand, ligand_description = par
        if not os.path.exists(os.path.join(self.pdbbind_dir, name)) and ligand is None:
            print("Folder not found", name)
            return [], []

        try:
            # NOTE(review): the pre-built `ligand` passed by inference_preprocessing is not used here; the
            # ligand and receptor are always read from `pdbbind_dir/name`. Confirm whether inference mode is
            # expected to work through this method.
            lig = read_mol(self.pdbbind_dir, name, suffix=self.ligand_file, remove_hs=False)
            if self.max_lig_size is not None and lig.GetNumHeavyAtoms() > self.max_lig_size:
                print(f'Ligand with {lig.GetNumHeavyAtoms()} heavy atoms is larger than max_lig_size {self.max_lig_size}. Not including {name} in preprocessed data.')
                return [], []

            complex_graph = HeteroData()
            complex_graph['name'] = name
            get_lig_graph_with_matching(lig, complex_graph, self.popsize, self.maxiter, self.matching, self.keep_original,
                                        self.num_conformers, remove_hs=self.remove_hs, tries=self.matching_tries)

            moad_extract_receptor_structure(path=os.path.join(self.pdbbind_dir, name, f'{name}_{self.protein_file}.pdb'),
                                            complex_graph=complex_graph,
                                            neighbor_cutoff=self.receptor_radius,
                                            max_neighbors=self.c_alpha_max_neighbors,
                                            lm_embeddings=lm_embedding_chains,
                                            knn_only_graph=self.knn_only_graph,
                                            all_atoms=self.all_atoms,
                                            atom_cutoff=self.atom_radius,
                                            atom_max_neighbors=self.atom_max_neighbors)

        except Exception as e:
            print(f'Skipping {name} because of the error:')
            print(e)
            return [], []

        if self.dataset == 'posebusters':
            # collect every alternative heavy-atom pose from <name>_ligands.sdf as reference positions
            other_positions = []
            all_mol_file = os.path.join(self.pdbbind_dir, name, f'{name}_ligands.sdf')
            supplier = Chem.SDMolSupplier(all_mol_file, sanitize=False, removeHs=False)
            for mol in supplier:
                if mol is None:
                    # SDMolSupplier yields None for records it cannot parse; SanitizeMol(None) would raise
                    continue
                Chem.SanitizeMol(mol)
                all_mol = RemoveAllHs(mol)
                for conf in all_mol.GetConformers():
                    other_positions.append(conf.GetPositions())

            print(f'Found {len(other_positions)} alternative poses for {name}')
            complex_graph['ligand'].orig_pos = np.asarray(other_positions)

        protein_center = torch.mean(complex_graph['receptor'].pos, dim=0, keepdim=True)
        complex_graph['receptor'].pos -= protein_center
        if self.all_atoms:
            complex_graph['atom'].pos -= protein_center

        if (not self.matching) or self.num_conformers == 1:
            complex_graph['ligand'].pos -= protein_center
        else:
            # several conformers are stored as an iterable of position tensors; each is shifted in place
            for p in complex_graph['ligand'].pos:
                p -= protein_center

        complex_graph.original_center = protein_center
        complex_graph['receptor_name'] = name
        return [complex_graph], [lig]
|
|
|
|
|
|
|
|
def print_statistics(complex_graphs):
    """Print the number of complexes and mean / std / max of per-complex geometry statistics.

    Per complex this gathers the protein radius (max receptor position norm), the ligand radius (max atom
    distance from the ligand centroid), the ligand-centroid norm, the matching RMSD (0 when absent), the
    ``random_coords`` flag (-1 when absent) and, for complexes with ``random_coords`` set, their matching
    RMSD (a single -1 if there are none). An empty input prints only the count, because mean/std/max are
    undefined on empty arrays (``np.max`` would raise).
    """
    names = ['radius protein', 'radius molecule', 'distance protein-mol', 'rmsd matching', 'random coordinates', 'random rmsd matching']
    statistics = tuple([] for _ in names)

    for complex_graph in complex_graphs:
        # with several conformers the ligand positions are a list; the first one is used
        lig_pos = complex_graph['ligand'].pos if torch.is_tensor(complex_graph['ligand'].pos) else complex_graph['ligand'].pos[0]
        radius_protein = torch.max(torch.linalg.vector_norm(complex_graph['receptor'].pos, dim=1))
        molecule_center = torch.mean(lig_pos, dim=0)
        radius_molecule = torch.max(
            torch.linalg.vector_norm(lig_pos - molecule_center.unsqueeze(0), dim=1))
        distance_center = torch.linalg.vector_norm(molecule_center)
        statistics[0].append(radius_protein)
        statistics[1].append(radius_molecule)
        statistics[2].append(distance_center)

        has_rmsd = "rmsd_matching" in complex_graph
        has_random = "random_coords" in complex_graph
        statistics[3].append(complex_graph.rmsd_matching if has_rmsd else 0)
        statistics[4].append(int(complex_graph.random_coords) if has_random else -1)
        if has_random and complex_graph.random_coords and has_rmsd:
            statistics[5].append(complex_graph.rmsd_matching)

    if len(statistics[5]) == 0:
        statistics[5].append(-1)
    print('Number of complexes: ', len(complex_graphs))
    if len(complex_graphs) == 0:
        return

    for label, values in zip(names, statistics):
        array = np.asarray(values)
        print(f"{label}: mean {np.mean(array)}, std {np.std(array)}, max {np.max(array)}")
|
|
|
|
|
|
|
|
def read_mol(pdbbind_dir, name, suffix='ligand', remove_hs=False):
    """Read ``pdbbind_dir/name/<name>_<suffix>``, trying the ``.sdf`` file first and then ``.mol2``.

    Returns the first non-``None`` result of ``read_molecule``, or the ``.mol2`` attempt's result
    (presumably ``None``) when neither file could be read.
    """
    stem = os.path.join(pdbbind_dir, name, f'{name}_{suffix}')
    for extension in ('.sdf', '.mol2'):
        lig = read_molecule(stem + extension, remove_hs=remove_hs, sanitize=True)
        if lig is not None:
            break
    return lig
|
|
|
|
|
|
|
|
def read_mols(pdbbind_dir, name, remove_hs=False):
    """Read every ``.sdf`` ligand in ``pdbbind_dir/name`` whose filename does not contain 'rdkit'.

    If an ``.sdf`` file cannot be read and a ``.mol2`` file with the same stem exists, that file is tried
    instead. Molecules that still fail to load are skipped. Order follows ``os.listdir``.
    """
    folder = os.path.join(pdbbind_dir, name)
    ligs = []
    for filename in os.listdir(folder):
        if not filename.endswith(".sdf") or 'rdkit' in filename:
            continue
        lig = read_molecule(os.path.join(folder, filename), remove_hs=remove_hs, sanitize=True)
        fallback = os.path.join(folder, filename[:-4] + ".mol2")
        if lig is None and os.path.exists(fallback):
            print('Using the .sdf file failed. We found a .mol2 file instead and are trying to use that.')
            lig = read_molecule(fallback, remove_hs=remove_hs, sanitize=True)
        if lig is not None:
            ligs.append(lig)
    return ligs