| | import copy |
| | import warnings |
| | import numpy as np |
| | import torch |
| | from Bio.PDB import PDBParser |
| | from rdkit import Chem |
| | from rdkit.Chem.rdchem import BondType as BT |
| | from rdkit.Chem import AllChem, GetPeriodicTable, RemoveHs |
| | from rdkit.Geometry import Point3D |
| | from torch import cdist |
| | from torch_cluster import knn_graph |
| | import prody as pr |
| |
|
| | import torch.nn.functional as F |
| |
|
| | from datasets.conformer_matching import get_torsion_angles, optimize_rotatable_bonds |
| | from datasets.constants import aa_short2long, atom_order, three_to_one |
| | from datasets.parse_chi import get_chi_angles, get_coords, aa_idx2aa_short, get_onehot_sequence |
| | from utils.torsion import get_transformation_mask |
| | from utils.logging_utils import get_logger |
| |
|
| |
|
# Shared RDKit periodic table, used below to turn element symbols into atomic numbers.
periodic_table = GetPeriodicTable()
# Vocabularies for categorical featurization: a feature value is encoded as its
# position in the matching list. Lists that end in 'misc' reserve that last
# slot for values that are not listed (see safe_index).
allowable_features = {
    'possible_atomic_num_list': [*range(1, 119), 'misc'],
    'possible_chirality_list': [
        'CHI_UNSPECIFIED',
        'CHI_TETRAHEDRAL_CW',
        'CHI_TETRAHEDRAL_CCW',
        'CHI_OTHER',
    ],
    'possible_degree_list': [*range(11), 'misc'],
    'possible_numring_list': [*range(7), 'misc'],
    'possible_implicit_valence_list': [*range(7), 'misc'],
    'possible_formal_charge_list': [*range(-5, 6), 'misc'],
    'possible_numH_list': [*range(9), 'misc'],
    'possible_number_radical_e_list': [*range(5), 'misc'],
    'possible_hybridization_list': ['SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'misc'],
    'possible_is_aromatic_list': [False, True],
    # One boolean vocabulary per ring size, 3 through 8 (each its own list object).
    **{f'possible_is_in_ring{size}_list': [False, True] for size in range(3, 9)},
    'possible_amino_acids': [
        'ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
        'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL',
        'HIP', 'HIE', 'TPO', 'HID', 'LEV', 'MEU', 'PTR', 'GLV', 'CYT', 'SEP',
        'HIZ', 'CYM', 'GLM', 'ASQ', 'TYS', 'CYX', 'GLZ', 'misc',
    ],
    # Atom names truncated/padded to two characters ('*' pads one-letter names).
    'possible_atom_type_2': [
        'C*', 'CA', 'CB', 'CD', 'CE', 'CG', 'CH', 'CZ',
        'N*', 'ND', 'NE', 'NH', 'NZ',
        'O*', 'OD', 'OE', 'OG', 'OH', 'OX',
        'S*', 'SD', 'SG',
        'misc',
    ],
    # Full atom names.
    'possible_atom_type_3': [
        'C', 'CA', 'CB', 'CD', 'CD1', 'CD2', 'CE', 'CE1', 'CE2', 'CE3',
        'CG', 'CG1', 'CG2', 'CH2', 'CZ', 'CZ2', 'CZ3',
        'N', 'ND1', 'ND2', 'NE', 'NE1', 'NE2', 'NH1', 'NH2', 'NZ',
        'O', 'OD1', 'OD2', 'OE1', 'OE2', 'OG', 'OG1', 'OH', 'OXT',
        'SD', 'SG',
        'misc',
    ],
}
# Integer class assigned to each RDKit bond type that is encoded explicitly.
bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}
| |
|
# Each *_feature_dims value is a pair: (vocabulary size of every categorical
# feature, in column order, 0).
# NOTE(review): the trailing 0 is presumably the number of extra
# non-categorical features -- confirm against the code that consumes these.

# Column order matches the rows produced by lig_atom_featurizer.
lig_feature_dims = (
    [
        len(allowable_features[feature_name])
        for feature_name in (
            'possible_atomic_num_list',
            'possible_chirality_list',
            'possible_degree_list',
            'possible_formal_charge_list',
            'possible_implicit_valence_list',
            'possible_numH_list',
            'possible_number_radical_e_list',
            'possible_hybridization_list',
            'possible_is_aromatic_list',
            'possible_numring_list',
            'possible_is_in_ring3_list',
            'possible_is_in_ring4_list',
            'possible_is_in_ring5_list',
            'possible_is_in_ring6_list',
            'possible_is_in_ring7_list',
            'possible_is_in_ring8_list',
        )
    ],
    0,
)

# Residue type, element, two-character atom name, full atom name.
rec_atom_feature_dims = (
    [
        len(allowable_features[feature_name])
        for feature_name in (
            'possible_amino_acids',
            'possible_atomic_num_list',
            'possible_atom_type_2',
            'possible_atom_type_3',
        )
    ],
    0,
)

# Residue type only.
rec_residue_feature_dims = ([len(allowable_features['possible_amino_acids'])], 0)
| |
|
| |
|
def lig_atom_featurizer(mol):
    """Encode every atom of an RDKit molecule as a row of categorical indices.

    Each value is the position of the atom's property in the corresponding
    ``allowable_features`` vocabulary. Column order follows ``lig_feature_dims``:
    atomic number, chirality, total degree, formal charge, implicit valence,
    total H count, radical electrons, hybridization, aromaticity, number of
    rings, then membership flags for ring sizes 3 through 8.

    Args:
        mol: RDKit molecule.

    Returns:
        ``torch.tensor`` built from the per-atom index lists (one row per atom).
    """
    ringinfo = mol.GetRingInfo()
    chirality_vocab = allowable_features['possible_chirality_list']
    atom_features_list = []
    for idx, atom in enumerate(mol.GetAtoms()):
        # RDKit defines more chiral tags than the vocabulary lists (square
        # planar, trigonal bipyramidal, octahedral, and others). Any tag that
        # is not listed is folded into 'CHI_OTHER' instead of raising
        # ValueError from list.index.
        chiral_tag = str(atom.GetChiralTag())
        if chiral_tag not in chirality_vocab:
            chiral_tag = 'CHI_OTHER'

        ring_size_flags = [
            allowable_features[f'possible_is_in_ring{size}_list'].index(ringinfo.IsAtomInRingOfSize(idx, size))
            for size in range(3, 9)
        ]

        atom_features_list.append([
            safe_index(allowable_features['possible_atomic_num_list'], atom.GetAtomicNum()),
            chirality_vocab.index(chiral_tag),
            safe_index(allowable_features['possible_degree_list'], atom.GetTotalDegree()),
            safe_index(allowable_features['possible_formal_charge_list'], atom.GetFormalCharge()),
            safe_index(allowable_features['possible_implicit_valence_list'], atom.GetImplicitValence()),
            safe_index(allowable_features['possible_numH_list'], atom.GetTotalNumHs()),
            safe_index(allowable_features['possible_number_radical_e_list'], atom.GetNumRadicalElectrons()),
            safe_index(allowable_features['possible_hybridization_list'], str(atom.GetHybridization())),
            allowable_features['possible_is_aromatic_list'].index(atom.GetIsAromatic()),
            safe_index(allowable_features['possible_numring_list'], ringinfo.NumAtomRings(idx)),
            *ring_size_flags,
        ])
    return torch.tensor(atom_features_list)
| |
|
| |
|
def safe_index(l, e):
    """Return the index of element e in list l.

    If e is not present, return the last index (``len(l) - 1``), which is the
    'misc' slot in the feature vocabularies of this module.
    """
    try:
        return l.index(e)
    except ValueError:
        # list.index signals "not found" with ValueError; anything else
        # (KeyboardInterrupt, programming errors) is allowed to propagate.
        return len(l) - 1
| |
|
| |
|
def moad_extract_receptor_structure(path, complex_graph, neighbor_cutoff=20, max_neighbors=None, sequences_to_embeddings=None,
                                    knn_only_graph=False, lm_embeddings=None, all_atoms=False, atom_cutoff=None, atom_max_neighbors=None):
    """Parse a receptor PDB file with ProDy and populate the receptor part of ``complex_graph``.

    Residues are grouped into chains by the concatenation of segment name and
    chain identifier. The per-chain sequences and a per-residue chain index are
    stored on ``complex_graph['receptor']``; the rest of the graph is built by
    ``new_extract_receptor_structure``.

    Args:
        path: PDB file passed to ``prody.parsePDB``.
        complex_graph: graph object that is filled in place.
        sequences_to_embeddings: optional mapping from a chain sequence string
            to its embedding. When given, ``lm_embeddings`` is ignored and
            rebuilt chain by chain from this mapping.
        lm_embeddings: embeddings forwarded as-is when
            ``sequences_to_embeddings`` is None.
        neighbor_cutoff, max_neighbors, knn_only_graph, all_atoms, atom_cutoff,
            atom_max_neighbors: forwarded unchanged to
            ``new_extract_receptor_structure``.
    """
    structure = pr.parsePDB(path)
    seq = structure.ca.getSequence()
    coords = get_coords(structure)
    one_hot = get_onehot_sequence(seq)

    # One label per C-alpha: segment name followed by chain identifier.
    chain_letters = structure.ca.getChids()
    segment_names = structure.ca.getSegnames()
    residue_labels = np.asarray([segment + chain for segment, chain in zip(segment_names, chain_letters)])

    lookup_embeddings = sequences_to_embeddings is not None
    if lookup_embeddings:
        lm_embeddings = []

    chain_ids = np.zeros(len(one_hot))
    sequences = []
    for chain_number, label in enumerate(np.unique(residue_labels)):
        in_chain = residue_labels == label
        chain_ids[in_chain] = chain_number

        # Recover the chain's one-letter sequence from its one-hot rows.
        residue_classes = np.argmax(one_hot[in_chain], axis=1)
        chain_sequence = ''.join([aa_idx2aa_short[aa_idx] for aa_idx in residue_classes])
        sequences.append(chain_sequence)
        if lookup_embeddings:
            lm_embeddings.append(sequences_to_embeddings[chain_sequence])

    receptor = complex_graph['receptor']
    receptor.sequence = sequences
    receptor.chain_ids = torch.from_numpy(np.asarray(chain_ids)).long()

    new_extract_receptor_structure(
        seq, coords, complex_graph,
        neighbor_cutoff=neighbor_cutoff, max_neighbors=max_neighbors,
        lm_embeddings=lm_embeddings, knn_only_graph=knn_only_graph, all_atoms=all_atoms,
        atom_cutoff=atom_cutoff, atom_max_neighbors=atom_max_neighbors,
    )
| |
|
| |
|
def _cutoff_edge_index(distances, cutoff, neighbor_cap, isolated_message):
    """Build a distance-cutoff edge index from a square pairwise-distance matrix.

    Every node is connected to all other nodes closer than ``cutoff``. If that
    yields more than the cap (``neighbor_cap``, or 1000 when it is falsy), only
    the closest ones are kept. A node with no neighbour inside the cutoff is
    connected to its single nearest node and ``isolated_message`` is printed.

    Returns:
        Tensor built from ``[neighbour_indices, node_indices]`` (two rows).
    """
    cap = neighbor_cap if neighbor_cap else 1000
    src_list = []
    dst_list = []
    for i in range(len(distances)):
        row = distances[i, :]
        dst = list(np.where(row < cutoff)[0])
        dst.remove(i)  # drop the self-match (distance zero)
        if len(dst) > cap:
            # argsort position 0 is the node itself, so skip it
            dst = list(np.argsort(row))[1: cap + 1]
        if len(dst) == 0:
            dst = list(np.argsort(row))[1:2]
            print(isolated_message)
        assert i not in dst
        src_list.extend([i] * len(dst))
        dst_list.extend(dst)
    return torch.from_numpy(np.asarray([dst_list, src_list]))


def new_extract_receptor_structure(seq, all_coords, complex_graph, neighbor_cutoff=20, max_neighbors=None, lm_embeddings=None,
                                   knn_only_graph=False, all_atoms=False, atom_cutoff=None, atom_max_neighbors=None):
    """Fill ``complex_graph`` with the residue-level (and optionally atom-level) receptor graph.

    Written to ``complex_graph['receptor']``: ``x`` (amino-acid index, with the
    language-model embeddings appended when given), ``pos``, ``side_chain_vecs``
    and the ``rec_contact`` edge index. With ``all_atoms`` the ``'atom'`` node
    type, its ``atom_contact`` edges and the atom-to-residue ``atom_rec_contact``
    edges are written as well.

    Args:
        seq: residue sequence; each element is looked up in ``aa_short2long``.
        all_coords: array indexed as ``[residue, atom_slot, xyz]``. Slot 1 is
            used as the residue position, slots 0 and 2 as relative vectors.
            NOTE(review): presumably N, CA, C in slots 0-2 and NaN for missing
            atoms -- confirm against ``get_coords``.
        complex_graph: graph object that is filled in place.
        neighbor_cutoff: distance cutoff for residue edges.
        max_neighbors: cap on neighbours per residue (1000 if falsy); also the
            ``k`` of the kNN graph (32 if falsy) when ``knn_only_graph`` is set.
        lm_embeddings: optional sequence of arrays, concatenated along axis 0.
        knn_only_graph: use ``knn_graph`` instead of the distance cutoff.
        all_atoms: also build the atom-level graph.
        atom_cutoff: distance cutoff for atom edges.
        atom_max_neighbors: cap on neighbours per atom, or kNN ``k`` (1000 if falsy).

    Raises:
        ValueError: if there are more than 3000 residues.
    """
    # The call is kept exactly as before; the one-hot return value is unused here.
    chi_angles, _ = get_chi_angles(all_coords, seq, return_onehot=True)
    n_rel_pos, c_rel_pos = all_coords[:, 0, :] - all_coords[:, 1, :], all_coords[:, 2, :] - all_coords[:, 1, :]
    side_chain_vecs = torch.from_numpy(np.concatenate([chi_angles / 360, n_rel_pos, c_rel_pos], axis=1))

    # Residue-level graph
    coords = torch.tensor(all_coords[:, 1, :], dtype=torch.float)
    if len(coords) > 3000:
        raise ValueError(f'The receptor is too large {len(coords)}')
    if knn_only_graph:
        edge_index = knn_graph(coords, k=max_neighbors if max_neighbors else 32)
    else:
        edge_index = _cutoff_edge_index(
            cdist(coords, coords), neighbor_cutoff, max_neighbors,
            f'The cutoff {neighbor_cutoff} was too small for one atom such that it had no neighbors. '
            f'So we connected it to the closest other atom')

    res_names_list = [aa_short2long[aa] if aa in aa_short2long else 'misc' for aa in seq]
    feature_list = [[safe_index(allowable_features['possible_amino_acids'], res)] for res in res_names_list]
    node_feat = torch.tensor(feature_list, dtype=torch.float32)

    lm_embeddings = torch.tensor(np.concatenate(lm_embeddings, axis=0)) if lm_embeddings is not None else None
    complex_graph['receptor'].x = torch.cat([node_feat, lm_embeddings], axis=1) if lm_embeddings is not None else node_feat
    complex_graph['receptor'].pos = coords
    complex_graph['receptor'].side_chain_vecs = side_chain_vecs.float()
    complex_graph['receptor', 'rec_contact', 'receptor'].edge_index = edge_index

    if not all_atoms:
        return

    # Atom-level graph: flatten to one row per atom slot and drop slots with NaN coordinates.
    atom_coords = all_coords.reshape(-1, 3)
    atom_coords = torch.from_numpy(atom_coords[~np.any(np.isnan(atom_coords), axis=1)]).float()

    if knn_only_graph:
        atoms_edge_index = knn_graph(atom_coords, k=atom_max_neighbors if atom_max_neighbors else 1000)
    else:
        atoms_edge_index = _cutoff_edge_index(
            cdist(atom_coords, atom_coords), atom_cutoff, atom_max_neighbors,
            f'The atom_cutoff {atom_cutoff} was too small for one atom such that it had no neighbors. '
            f'So we connected it to the closest other atom')

    feats = [get_moad_atom_feats(res, all_coords[i]) for i, res in enumerate(seq)]
    atom_feat = torch.from_numpy(np.concatenate(feats, axis=0)).float()
    # Index of the residue each atom belongs to, repeated once per atom.
    c_alpha_idx = np.concatenate([np.zeros(len(f)) + i for i, f in enumerate(feats)])
    np_array = np.stack([np.arange(len(atom_feat)), c_alpha_idx])
    atom_res_edge_index = torch.from_numpy(np_array).long()
    complex_graph['atom'].x = atom_feat
    complex_graph['atom'].pos = atom_coords
    assert len(complex_graph['atom'].x) == len(complex_graph['atom'].pos)
    complex_graph['atom', 'atom_contact', 'atom'].edge_index = atoms_edge_index
    complex_graph['atom', 'atom_rec_contact', 'receptor'].edge_index = atom_res_edge_index
| |
|
| |
|
def get_moad_atom_feats(res, coords):
    """Build categorical features for the atoms of one residue.

    One row is produced per entry of ``coords`` that contains no NaN. Each row
    holds four indices: residue type, atomic number, two-character atom name
    and full atom name, looked up in ``allowable_features``.

    Args:
        res: one-letter residue code; ``'-'`` produces all-'misc' rows.
        coords: per-atom coordinates of the residue; position ``i`` is matched
            with ``atom_order[res][i]``.

    Returns:
        ``numpy`` array of shape ``(n_atoms_present, 4)``.
    """
    # Residue codes missing from the lookup tables fall back to 'misc' instead
    # of raising KeyError, matching how residue names are resolved in
    # new_extract_receptor_structure.
    res_long = aa_short2long[res] if res in aa_short2long else 'misc'
    res_order = atom_order[res] if res in atom_order else []

    misc_atom = [safe_index(allowable_features['possible_atomic_num_list'], 'misc'),
                 safe_index(allowable_features['possible_atom_type_2'], 'misc'),
                 safe_index(allowable_features['possible_atom_type_3'], 'misc')]

    feats = []
    for i, c in enumerate(coords):
        if np.any(np.isnan(c)):
            # atom slot not present in the structure
            continue
        if res == '-':
            atom_feats = [safe_index(allowable_features['possible_amino_acids'], 'misc'), *misc_atom]
        elif i >= len(res_order):
            # more coordinates than known atom names for this residue
            atom_feats = [safe_index(allowable_features['possible_amino_acids'], res_long), *misc_atom]
        else:
            atom_name = res_order[i]
            try:
                # the element is taken to be the first character of the atom name
                atomic_num = periodic_table.GetAtomicNumber(atom_name[:1])
            except Exception:
                print("element", res_order[i][:1], 'not found')
                atomic_num = -1  # not in the vocabulary, so it maps to 'misc'

            atom_feats = [safe_index(allowable_features['possible_amino_acids'], res_long),
                          safe_index(allowable_features['possible_atomic_num_list'], atomic_num),
                          safe_index(allowable_features['possible_atom_type_2'], (atom_name + '*')[:2]),
                          safe_index(allowable_features['possible_atom_type_3'], atom_name)]
        feats.append(atom_feats)

    # reshape keeps an empty result two-dimensional so callers can concatenate it
    return np.asarray(feats).reshape(-1, 4)
| |
|
| |
|
def get_lig_graph(mol, complex_graph):
    """Write the ligand nodes and bond edges of ``mol`` into ``complex_graph``.

    Sets ``complex_graph['ligand'].x`` to the atom features, stores every bond
    in both directions under the ``lig_bond`` edge type with a one-hot bond-type
    attribute, and sets ``complex_graph['ligand'].pos`` when the molecule has
    at least one conformer.
    """
    atom_feats = lig_atom_featurizer(mol)

    sources, targets, bond_classes = [], [], []
    for bond in mol.GetBonds():
        begin = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        # unspecified bond types share class 0 with single bonds
        if bond.GetBondType() != BT.UNSPECIFIED:
            bond_class = bonds[bond.GetBondType()]
        else:
            bond_class = 0
        # each bond contributes one edge per direction
        sources.extend((begin, end))
        targets.extend((end, begin))
        bond_classes.extend((bond_class, bond_class))

    edge_index = torch.tensor([sources, targets], dtype=torch.long)
    edge_type = torch.tensor(bond_classes, dtype=torch.long)
    edge_attr = F.one_hot(edge_type, num_classes=len(bonds)).to(torch.float)

    complex_graph['ligand'].x = atom_feats
    complex_graph['ligand', 'lig_bond', 'ligand'].edge_index = edge_index
    complex_graph['ligand', 'lig_bond', 'ligand'].edge_attr = edge_attr

    if mol.GetNumConformers() > 0:
        complex_graph['ligand'].pos = torch.from_numpy(mol.GetConformer().GetPositions()).float()
| |
|
| |
|
def generate_conformer(mol):
    """Embed 3D coordinates into ``mol`` in place using RDKit ETKDGv2.

    Embedding is attempted up to three times. If all attempts fail, it is
    repeated once with random starting coordinates and the result is
    MMFF-optimised.

    Returns:
        True if the random-coordinate fallback was used, False otherwise.
    """
    params = AllChem.ETKDGv2()
    conf_id = -1
    for attempt in range(3):
        if attempt > 0:
            get_logger().debug(f'rdkit coords could not be generated. trying again {attempt}.')
        conf_id = AllChem.EmbedMolecule(mol, params)
        if conf_id != -1:
            break

    if conf_id != -1:
        return False

    get_logger().info('rdkit coords could not be generated without using random coords. using random coords now.')
    params.useRandomCoords = True
    AllChem.EmbedMolecule(mol, params)
    AllChem.MMFFOptimizeMolecule(mol, confId=0)
    return True
| |
|
| |
|
def get_lig_graph_with_matching(mol_, complex_graph, popsize, maxiter, matching, keep_original, num_conformers, remove_hs, tries=10, skip_matching=False):
    """Build the ligand graph, optionally from RDKit conformers matched to the input pose.

    Without ``matching`` the graph is built directly from ``mol_`` (hydrogens
    removed if ``remove_hs``) and ``complex_graph.rmsd_matching`` is set to 0.

    With ``matching``, ``num_conformers`` conformers are produced. For each one,
    ``tries`` RDKit conformers are generated, their rotatable bonds are
    optimised against the input pose (unless ``skip_matching`` or there are no
    rotatable bonds), they are aligned to it, and the one with the lowest RMSD
    is kept. The first kept conformer defines the ligand graph and
    ``complex_graph.rmsd_matching``; later ones are appended to
    ``complex_graph['ligand'].pos``, which then becomes a list of tensors.

    In both modes ``edge_mask`` and ``mask_rotate`` from
    ``get_transformation_mask`` are stored on ``complex_graph['ligand']``.

    Args:
        mol_: input RDKit molecule; it is deep-copied before modification in
            the matching branch.
        complex_graph: graph object that is filled in place.
        popsize, maxiter: forwarded to ``optimize_rotatable_bonds``.
        keep_original: also store the input conformer positions as
            ``complex_graph['ligand'].orig_pos``.
    """
    if not matching:
        complex_graph.rmsd_matching = 0
        if remove_hs:
            mol_ = RemoveHs(mol_)
        get_lig_graph(mol_, complex_graph)
    else:
        # Reference pose; hydrogens are stripped only when requested.
        mol_maybe_noh = copy.deepcopy(mol_)
        if remove_hs:
            mol_maybe_noh = RemoveHs(mol_maybe_noh, sanitize=True)
            mol_maybe_noh = AllChem.RemoveAllHs(mol_maybe_noh)
        if keep_original:
            positions = [conf.GetPositions() for conf in mol_maybe_noh.GetConformers()]
            complex_graph['ligand'].orig_pos = np.asarray(positions) if len(positions) > 1 else positions[0]

        # Rotatable bonds are always determined on a hydrogen-free copy.
        heavy_atom_copy = copy.deepcopy(mol_)
        if remove_hs:
            heavy_atom_copy = RemoveHs(heavy_atom_copy, sanitize=True)
        heavy_atom_copy = AllChem.RemoveAllHs(heavy_atom_copy)
        rotatable_bonds = get_torsion_angles(heavy_atom_copy)

        for conformer_number in range(num_conformers):
            candidates, rmsds = [], []
            for _ in range(tries):
                # Fresh conformer, generated with explicit hydrogens.
                mol_rdkit = copy.deepcopy(mol_)
                mol_rdkit.RemoveAllConformers()
                mol_rdkit = AllChem.AddHs(mol_rdkit)
                generate_conformer(mol_rdkit)
                if remove_hs:
                    mol_rdkit = RemoveHs(mol_rdkit, sanitize=True)
                mol_rdkit = AllChem.RemoveAllHs(mol_rdkit)

                mol = AllChem.RemoveAllHs(copy.deepcopy(mol_maybe_noh))
                if rotatable_bonds and not skip_matching:
                    optimize_rotatable_bonds(mol_rdkit, mol, rotatable_bonds, popsize=popsize, maxiter=maxiter)

                # Align the generated conformer onto the reference and keep the aligned copy.
                mol.AddConformer(mol_rdkit.GetConformer())
                rms_list = []
                AllChem.AlignMolConformers(mol, RMSlist=rms_list)
                mol_rdkit.RemoveAllConformers()
                mol_rdkit.AddConformer(mol.GetConformers()[1])

                candidates.append(mol_rdkit)
                rmsds.append(rms_list[0])

            best_mol = candidates[np.argmin(rmsds)]
            if conformer_number == 0:
                complex_graph.rmsd_matching = min(rmsds)
                get_lig_graph(best_mol, complex_graph)
            else:
                # From the second conformer on, pos is a list of tensors.
                if torch.is_tensor(complex_graph['ligand'].pos):
                    complex_graph['ligand'].pos = [complex_graph['ligand'].pos]
                complex_graph['ligand'].pos.append(torch.from_numpy(best_mol.GetConformer().GetPositions()).float())

    edge_mask, mask_rotate = get_transformation_mask(complex_graph)
    complex_graph['ligand'].edge_mask = torch.tensor(edge_mask)
    complex_graph['ligand'].mask_rotate = mask_rotate
| |
|
| |
|
def get_rec_misc_atom_feat(bio_atom=None, atom_name=None, element=None, get_misc_features=False):
    """Return the four categorical feature indices of a receptor atom.

    The indices are residue type, atomic number, two-character atom name and
    full atom name, looked up in ``allowable_features``.

    Args:
        bio_atom: atom object providing ``name``, ``element`` and
            ``get_parent().get_resname()``.
        atom_name: overrides ``bio_atom.name`` when given.
        element: overrides ``bio_atom.element`` when given.
        get_misc_features: if True, return the all-'misc' row and ignore the
            other arguments.

    Returns:
        List of four integer indices.
    """
    if get_misc_features:
        return [safe_index(allowable_features['possible_amino_acids'], 'misc'),
                safe_index(allowable_features['possible_atomic_num_list'], 'misc'),
                safe_index(allowable_features['possible_atom_type_2'], 'misc'),
                safe_index(allowable_features['possible_atom_type_3'], 'misc')]

    if atom_name is None:
        atom_name = bio_atom.name
    if element is None:
        element = bio_atom.element
    if element == 'CD':
        # 'CD' is treated as carbon
        element = 'C'
    assert not element == ''

    try:
        atomic_num = periodic_table.GetAtomicNumber(element.lower().capitalize())
    except Exception:
        # unknown element: -1 is not in the vocabulary, so it maps to 'misc'
        atomic_num = -1

    # NOTE(review): bio_atom is still required for the residue name even when
    # atom_name and element are both supplied.
    return [safe_index(allowable_features['possible_amino_acids'], bio_atom.get_parent().get_resname()),
            safe_index(allowable_features['possible_atomic_num_list'], atomic_num),
            safe_index(allowable_features['possible_atom_type_2'], (atom_name + '*')[:2]),
            safe_index(allowable_features['possible_atom_type_3'], atom_name)]
| |
|
| |
|
def write_mol_with_coords(mol, new_coords, path):
    """Write ``mol`` to an SD file at ``path`` using ``new_coords`` as atom positions.

    The positions of the molecule's conformer are overwritten in place, so
    ``mol`` keeps the new coordinates after the call.

    Args:
        mol: RDKit molecule with a conformer.
        new_coords: array with an ``astype`` method; row ``i`` holds the x, y, z
            of atom ``i``.
        path: output file path.
    """
    w = Chem.SDWriter(path)
    try:
        conf = mol.GetConformer()
        # convert once instead of on every loop iteration
        coords = new_coords.astype(np.double)
        for i in range(mol.GetNumAtoms()):
            x, y, z = coords[i]
            conf.SetAtomPosition(i, Point3D(x, y, z))
        w.write(mol)
    finally:
        # always release the file handle, even if writing fails
        w.close()
| |
|
| |
|
def read_molecule(molecule_file, sanitize=False, calc_charges=False, remove_hs=False):
    """Load a molecule from a .mol2, .sdf, .pdbqt or .pdb file with RDKit.

    The file is always parsed without sanitization and with hydrogens kept;
    the optional steps are applied afterwards.

    Args:
        molecule_file: path to the file; the format is chosen by its extension.
        sanitize: run ``Chem.SanitizeMol`` on the parsed molecule.
        calc_charges: compute Gasteiger charges (this also sanitizes). A failure
            here only emits a warning.
        remove_hs: remove hydrogens after the steps above.

    Returns:
        The RDKit molecule, or None if any post-processing step failed (the
        traceback is logged as a warning). Whatever the RDKit parser returned
        is passed through when no post-processing is requested.

    Raises:
        ValueError: if the file extension is not one of the supported formats.
    """
    if molecule_file.endswith('.mol2'):
        mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)
    elif molecule_file.endswith('.sdf'):
        supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)
        mol = supplier[0]  # only the first record is used
    elif molecule_file.endswith('.pdbqt'):
        with open(molecule_file) as file:
            pdbqt_data = file.readlines()
        # Keep only the first 66 columns of every line so the block parses as PDB.
        pdb_block = ''.join('{}\n'.format(line[:66]) for line in pdbqt_data)
        mol = Chem.MolFromPDBBlock(pdb_block, sanitize=False, removeHs=False)
    elif molecule_file.endswith('.pdb'):
        mol = Chem.MolFromPDBFile(molecule_file, sanitize=False, removeHs=False)
    else:
        raise ValueError('Expect the format of the molecule_file to be '
                         'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))

    try:
        if sanitize or calc_charges:
            Chem.SanitizeMol(mol)

        if calc_charges:
            # Charges are best-effort: a failure is reported but not fatal.
            try:
                AllChem.ComputeGasteigerCharges(mol)
            except Exception:
                warnings.warn('Unable to compute charges for the molecule.')

        if remove_hs:
            mol = Chem.RemoveHs(mol, sanitize=sanitize)

    except Exception:
        import traceback
        msg = traceback.format_exc()
        get_logger().warning(f"Failed to process molecule: {molecule_file}\n{msg}")
        return None

    return mol
| |
|