| | import numpy as np |
| | from tqdm import tqdm |
| | from rdkit import Chem, DataStructs |
| | from rdkit.Chem import Descriptors, Crippen, Lipinski, QED |
| | from analysis.SA_Score.sascorer import calculateScore |
| |
|
| | from analysis.molecule_builder import build_molecule |
| | from copy import deepcopy |
| |
|
| |
|
class CategoricalDistribution:
    """Reference categorical distribution used to score samples via KL divergence.

    The reference probabilities ``self.p`` are built from a histogram of
    counts keyed by category; ``mapping`` assigns each category its position
    in ``self.p``.
    """

    # Added inside the log so that an empty sample category (q_i == 0) does
    # not produce log(0).
    EPS = 1e-10

    def __init__(self, histogram_dict, mapping):
        """
        Args:
            histogram_dict: dict mapping category -> count/weight. Every key
                must be present in ``mapping`` (otherwise KeyError).
            mapping: dict mapping category -> integer index in
                ``[0, len(mapping))``. Categories missing from
                ``histogram_dict`` get probability 0.
        """
        histogram = np.zeros(len(mapping))
        for key, count in histogram_dict.items():
            histogram[mapping[key]] = count

        # Normalize the counts into a probability vector.
        self.p = histogram / histogram.sum()
        self.mapping = deepcopy(mapping)

    def kl_divergence(self, other_sample):
        """Approximate KL(p || q) between the reference p and a sample's histogram q.

        Args:
            other_sample: iterable of integer category indices (i.e. values of
                ``self.mapping``, not its keys); each is used directly as an
                index into the histogram.

        Returns:
            ``-sum_i p_i * log(q_i / p_i + EPS)`` summed over categories with
            ``p_i > 0``. Categories the reference never observes contribute 0
            (convention 0 * log 0 = 0). An empty sample yields NaN because q
            cannot be normalized.
        """
        sample_histogram = np.zeros(len(self.mapping))
        for idx in other_sample:
            sample_histogram[idx] += 1

        q = sample_histogram / sample_histogram.sum()

        # Restrict to the support of p: evaluating p * log(q / p) at p == 0
        # gives 0 * inf (or 0 * nan) == nan and would poison the whole sum.
        support = self.p > 0
        p_supp = self.p[support]
        q_supp = q[support]
        return -np.sum(p_supp * np.log(q_supp / p_supp + self.EPS))
| |
|
| |
|
def rdmol_to_smiles(rdmol):
    """Return the SMILES of ``rdmol`` without stereochemistry or explicit Hs.

    The input molecule is not modified: a copy is made first because
    ``Chem.RemoveStereochemistry`` works in place.
    """
    stripped = Chem.Mol(rdmol)
    Chem.RemoveStereochemistry(stripped)
    return Chem.MolToSmiles(Chem.RemoveHs(stripped))
| |
|
| |
|
class BasicMolecularMetrics(object):
    """Validity, connectivity, uniqueness and novelty of generated molecules."""

    def __init__(self, dataset_info, dataset_smiles_list=None,
                 connectivity_thresh=1.0):
        """
        Args:
            dataset_info: dict that must contain 'atom_decoder'; it is also
                forwarded to ``build_molecule`` in ``evaluate``.
            dataset_smiles_list: optional iterable of reference SMILES used
                for novelty. It is stored as a set for O(1) membership tests.
            connectivity_thresh: minimum fraction of a molecule's atoms that
                its largest fragment must contain for the molecule to count
                as connected (1.0 = fully connected).
        """
        self.atom_decoder = dataset_info['atom_decoder']
        if dataset_smiles_list is not None:
            dataset_smiles_list = set(dataset_smiles_list)
        self.dataset_smiles_list = dataset_smiles_list
        self.dataset_info = dataset_info
        self.connectivity_thresh = connectivity_thresh

    def compute_validity(self, generated):
        """Keep the molecules that survive RDKit sanitization.

        Args:
            generated: list of RDKit molecules (sanitized in place).

        Returns:
            (valid molecules, fraction of ``generated`` that is valid).
            Molecules whose sanitization raises ValueError are dropped.
        """
        if len(generated) < 1:
            return [], 0.0

        valid = []
        for mol in generated:
            try:
                Chem.SanitizeMol(mol)
            except ValueError:
                continue
            valid.append(mol)

        return valid, len(valid) / len(generated)

    def compute_connectivity(self, valid):
        """Consider a molecule connected if its largest fragment contains at
        least ``self.connectivity_thresh`` (default 100%) of all its atoms.

        Args:
            valid: list of sanitized RDKit molecules.

        Returns:
            (largest fragments of the connected molecules,
             fraction of ``valid`` that is connected,
             SMILES of those largest fragments)
        """
        if len(valid) < 1:
            # Always return three values: callers unpack a 3-tuple.
            return [], 0.0, []

        connected = []
        connected_smiles = []
        for mol in valid:
            num_atoms = mol.GetNumAtoms()
            if num_atoms == 0:
                # An atom-less molecule has no fragments; count it as not
                # connected instead of dividing 0 by 0 below.
                continue
            mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True)
            largest_mol = \
                max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
            if largest_mol.GetNumAtoms() / num_atoms >= self.connectivity_thresh:
                smiles = rdmol_to_smiles(largest_mol)
                if smiles is not None:
                    connected_smiles.append(smiles)
                    connected.append(largest_mol)

        return connected, len(connected_smiles) / len(valid), connected_smiles

    def compute_uniqueness(self, connected):
        """Deduplicate SMILES strings.

        Args:
            connected: list of SMILES strings.

        Returns:
            (unique SMILES in first-seen order, fraction of unique entries).
            Uniqueness does not depend on the reference dataset, so it is
            computed whether or not ``dataset_smiles_list`` was given.
        """
        if len(connected) < 1:
            return [], 0.0

        # dict.fromkeys keeps first-seen order, so the result is
        # deterministic (a set's order varies with string hash seeds).
        unique = list(dict.fromkeys(connected))
        return unique, len(unique) / len(connected)

    def compute_novelty(self, unique):
        """SMILES from ``unique`` that are absent from the reference dataset.

        Returns:
            (novel SMILES, fraction of ``unique`` that is novel). Returns
            ``([], 0.0)`` when ``unique`` is empty or no reference SMILES
            list was provided.
        """
        if len(unique) < 1 or self.dataset_smiles_list is None:
            return [], 0.0

        novel = [smiles for smiles in unique
                 if smiles not in self.dataset_smiles_list]
        return novel, len(novel) / len(unique)

    def evaluate_rdmols(self, rdmols):
        """Run the validity -> connectivity -> uniqueness -> novelty cascade.

        Each stage only sees the survivors of the previous one; results are
        printed as percentages.

        Returns:
            ([validity, connectivity, uniqueness, novelty], [valid, connected])
        """
        valid, validity = self.compute_validity(rdmols)
        print(f"Validity over {len(rdmols)} molecules: {validity * 100 :.2f}%")

        connected, connectivity, connected_smiles = \
            self.compute_connectivity(valid)
        print(f"Connectivity over {len(valid)} valid molecules: "
              f"{connectivity * 100 :.2f}%")

        unique, uniqueness = self.compute_uniqueness(connected_smiles)
        print(f"Uniqueness over {len(connected)} connected molecules: "
              f"{uniqueness * 100 :.2f}%")

        _, novelty = self.compute_novelty(unique)
        print(f"Novelty over {len(unique)} unique connected molecules: "
              f"{novelty * 100 :.2f}%")

        return [validity, connectivity, uniqueness, novelty], [valid, connected]

    def evaluate(self, generated):
        """ generated: list of pairs (positions: n x 3, atom_types: n [int])
            the positions and atom types should already be masked.

        Each pair is converted with ``build_molecule(*pair, dataset_info)``
        and the resulting molecules are scored by ``evaluate_rdmols``.
        """
        rdmols = [build_molecule(*graph, self.dataset_info)
                  for graph in generated]
        return self.evaluate_rdmols(rdmols)
| |
|
| |
|
class MoleculeProperties:
    """Per-molecule drug-likeness properties and per-pocket diversity."""

    @staticmethod
    def calculate_qed(rdmol):
        """QED score of ``rdmol`` (RDKit ``QED.qed``)."""
        return QED.qed(rdmol)

    @staticmethod
    def calculate_sa(rdmol):
        """Synthetic-accessibility score rescaled as ``(10 - SA) / 9``, rounded to 2 decimals.

        Assuming the raw SA score lies in [1, 10] (verify against
        ``calculateScore``), this maps it to [0, 1] with higher = easier.
        """
        sa = calculateScore(rdmol)
        return round((10 - sa) / 9, 2)

    @staticmethod
    def calculate_logp(rdmol):
        """Crippen logP of ``rdmol``."""
        return Crippen.MolLogP(rdmol)

    @staticmethod
    def _lipinski_score(mol_wt, num_h_donors, num_h_acceptors, logp,
                        num_rotatable_bonds):
        """Count how many of the five rule-of-five style criteria hold (0-5)."""
        rules = [
            mol_wt < 500,
            num_h_donors <= 5,
            num_h_acceptors <= 10,
            -2 <= logp <= 5,
            num_rotatable_bonds <= 10,
        ]
        return np.sum([int(rule) for rule in rules])

    @staticmethod
    def calculate_lipinski(rdmol):
        """Number of satisfied criteria (0-5): exact MW < 500, H-donors <= 5,
        H-acceptors <= 10, -2 <= Crippen logP <= 5, rotatable bonds <= 10.
        """
        # logP is computed once and bounded on both sides; the previous
        # walrus form bound the *boolean* (logP >= -2) to `logp`, so the
        # upper bound was never checked.
        return MoleculeProperties._lipinski_score(
            Descriptors.ExactMolWt(rdmol),
            Lipinski.NumHDonors(rdmol),
            Lipinski.NumHAcceptors(rdmol),
            Crippen.MolLogP(rdmol),
            Chem.rdMolDescriptors.CalcNumRotatableBonds(rdmol),
        )

    @staticmethod
    def _fingerprint(rdmol):
        """RDKit topological fingerprint used for Tanimoto similarity."""
        return Chem.RDKFingerprint(rdmol)

    @classmethod
    def calculate_diversity(cls, pocket_mols):
        """Mean pairwise Tanimoto distance (1 - similarity) over all unordered pairs.

        Returns 0.0 when fewer than two molecules are given.
        """
        if len(pocket_mols) < 2:
            return 0.0

        # Fingerprint each molecule once instead of twice per pair.
        fps = [cls._fingerprint(mol) for mol in pocket_mols]
        distances = [1 - DataStructs.TanimotoSimilarity(fp_a, fp_b)
                     for i, fp_a in enumerate(fps)
                     for fp_b in fps[i + 1:]]
        return sum(distances) / len(distances)

    @staticmethod
    def similarity(mol_a, mol_b):
        """Tanimoto similarity of the RDKit fingerprints of two molecules."""
        fp1 = MoleculeProperties._fingerprint(mol_a)
        fp2 = MoleculeProperties._fingerprint(mol_b)
        return DataStructs.TanimotoSimilarity(fp1, fp2)

    @staticmethod
    def _report(label, values):
        """Print ``label: mean \\pm std`` for ``values``."""
        print(f"{label}: {np.mean(values):.3f} \\pm {np.std(values):.2f}")

    def evaluate(self, pocket_rdmols):
        """
        Run full evaluation
        Args:
            pocket_rdmols: list of lists, the inner list contains all RDKit
                molecules generated for a pocket
        Returns:
            QED, SA, LogP, Lipinski (per molecule), and Diversity (per pocket)
        """
        for pocket in pocket_rdmols:
            for mol in pocket:
                # Check before sanitizing: SanitizeMol(None) would fail first.
                assert mol is not None, "only evaluate valid molecules"
                Chem.SanitizeMol(mol)

        all_qed = []
        all_sa = []
        all_logp = []
        all_lipinski = []
        per_pocket_diversity = []
        for pocket in tqdm(pocket_rdmols):
            all_qed.append([self.calculate_qed(mol) for mol in pocket])
            all_sa.append([self.calculate_sa(mol) for mol in pocket])
            all_logp.append([self.calculate_logp(mol) for mol in pocket])
            all_lipinski.append([self.calculate_lipinski(mol) for mol in pocket])
            per_pocket_diversity.append(self.calculate_diversity(pocket))

        print(f"{sum([len(p) for p in pocket_rdmols])} molecules from "
              f"{len(pocket_rdmols)} pockets evaluated.")

        # Per-molecule properties are pooled across pockets before summarizing.
        for label, per_pocket in [("QED", all_qed), ("SA", all_sa),
                                  ("LogP", all_logp),
                                  ("Lipinski", all_lipinski)]:
            self._report(label, [x for px in per_pocket for x in px])
        self._report("Diversity", per_pocket_diversity)

        return all_qed, all_sa, all_logp, all_lipinski, per_pocket_diversity

    def evaluate_mean(self, rdmols):
        """
        Run full evaluation and return mean of each property
        Args:
            rdmols: list of RDKit molecules
        Returns:
            QED, SA, LogP, Lipinski, and Diversity (all 0.0 for empty input)
        """
        if len(rdmols) < 1:
            return 0.0, 0.0, 0.0, 0.0, 0.0

        for mol in rdmols:
            # Check before sanitizing: SanitizeMol(None) would fail first.
            assert mol is not None, "only evaluate valid molecules"
            Chem.SanitizeMol(mol)

        qed = np.mean([self.calculate_qed(mol) for mol in rdmols])
        sa = np.mean([self.calculate_sa(mol) for mol in rdmols])
        logp = np.mean([self.calculate_logp(mol) for mol in rdmols])
        lipinski = np.mean([self.calculate_lipinski(mol) for mol in rdmols])
        diversity = self.calculate_diversity(rdmols)

        return qed, sa, logp, lipinski, diversity
| |
|