| from rdkit import Chem, RDLogger |
|
|
| RDLogger.DisableLog("rdApp.*") |
|
|
| import re |
| import random |
| import logging |
| from rdkit import Chem |
| from typing import List, Tuple, Optional |
| random.seed(0) |
| import torch |
|
|
# Integer bond label -> RDKit bond type. Slot 0 stands for "no bond";
# build_molecule_with_partial_charges and correct_mol index this list with
# integer bond orders.
bond_dict = [None] + [
    getattr(Chem.rdchem.BondType, bond_name)
    for bond_name in ("SINGLE", "DOUBLE", "TRIPLE", "AROMATIC")
]

# Valence assumed for each neutral element, keyed by atomic number
# (C, N, O, F, P, S, Cl, Br, I). Used to decide whether an over-valent
# N/O/S atom can be repaired with a +1 formal charge.
ATOM_VALENCY = dict(
    [(6, 4), (7, 3), (8, 2), (9, 1), (15, 3), (16, 2), (17, 1), (35, 1), (53, 1)]
)

logger = logging.getLogger(__name__)
|
|
# Matches one polymerization point: a bare "*" wildcard atom, or any bracket
# atom that contains a wildcard, such as "[*]", "[1*]" or "[*:1]".
_POLYMER_POINT_RE = re.compile(r"\[[^\[\]]*\*[^\[\]]*\]|\*")


def check_polymer(smiles):
    """Check that a (possibly polymeric) SMILES is valid once its ends are capped.

    Every polymerization point (``*`` wildcard atom) is replaced by an explicit
    hydrogen, and the resulting monomer must round-trip through RDKit.

    Args:
        smiles: SMILES string, optionally containing ``*`` wildcard atoms.

    Returns:
        True if *smiles* has no wildcard or the capped monomer is valid;
        False otherwise (a warning is logged).
    """
    if "*" not in smiles:
        return True
    # Replace whole bracket atoms such as "[*]" rather than only the "*":
    # a plain str.replace would produce the invalid "[[H]]".
    monomer = _POLYMER_POINT_RE.sub("[H]", smiles)
    if mol2smiles(get_mol(monomer)) is None:
        logger.warning("Invalid polymerization point in %s", smiles)
        return False
    return True
| |
def _smiles_from_built_mol(mol_init, index):
    """Correct, serialize and reduce one built molecule; None if a stage fails.

    May raise; graph_to_smiles handles exceptions with its fallback path.
    """
    # Prefer joining fragments; otherwise correct valences without joining.
    for connection in (True, False):
        mol_conn, _ = correct_mol(mol_init, connection=connection)
        if mol_conn is not None:
            break
    else:
        logger.warning(f"Failed to correct molecule {index}")
        mol_conn = mol_init

    smiles = mol2smiles(mol_conn)
    if not smiles:
        logger.warning(f"Failed to convert molecule {index} to SMILES, falling back to RDKit MolToSmiles")
        smiles = Chem.MolToSmiles(mol_conn)
    if not smiles:
        logger.warning(f"Failed to generate SMILES for molecule {index}, appending None")
        return None

    mol = get_mol(smiles)
    if mol is None:
        logger.warning(f"Failed to convert SMILES back to molecule for index {index}")
        return None

    # Keep only the largest fragment, measured by atom count.
    mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
    largest_mol = max(mol_frags, key=lambda m: m.GetNumAtoms())
    largest_smiles = mol2smiles(largest_mol)
    # A one-character largest-fragment SMILES is not used; the full SMILES is
    # validated instead.
    if largest_smiles and len(largest_smiles) > 1:
        return largest_smiles if check_polymer(largest_smiles) else None
    return smiles if check_polymer(smiles) else None


def _fallback_smiles(mol_init, index):
    """Best-effort raw SMILES for a graph whose main pipeline raised; None on failure."""
    if mol_init is None:
        logger.error(f"All attempts failed for molecule {index}: molecule could not be built")
        return None
    try:
        fallback_smiles = Chem.MolToSmiles(mol_init)
    except Exception as e2:
        logger.error(f"All attempts failed for molecule {index}: {str(e2)}")
        return None
    if fallback_smiles:
        logger.warning(f"Used RDKit MolToSmiles fallback for molecule {index}")
        return fallback_smiles
    logger.warning(f"RDKit MolToSmiles fallback failed for molecule {index}, appending None")
    return None


def graph_to_smiles(molecule_list: List[Tuple], atom_decoder: list) -> List[Optional[str]]:
    """Convert molecular graphs into SMILES strings.

    Each graph is built into an RDKit molecule and valence-corrected with
    ``correct_mol`` (first joining fragments, then without joining). It is then
    reduced to its largest fragment and validated with ``check_polymer``. If
    any step raises, the unprocessed molecule is serialized directly as a
    fallback.

    Args:
        molecule_list: sequence of ``(atom_types, edge_types)`` pairs. These
            are presumably torch tensors of atom-type indices and a square
            bond-label matrix; confirm against the caller.
        atom_decoder: maps an atom-type index to an element symbol.

    Returns:
        One entry per input graph, in input order: a SMILES string, or None
        when no usable SMILES could be produced.
    """
    smiles_list: List[Optional[str]] = []
    for index, graph in enumerate(molecule_list):
        # Reset per graph so the fallback can never serialize the molecule
        # that was built for a previous graph.
        mol_init = None
        try:
            atom_types, edge_types = graph
            mol_init = build_molecule_with_partial_charges(atom_types, edge_types, atom_decoder)
            smiles = _smiles_from_built_mol(mol_init, index)
        except Exception as e:
            logger.error(f"Error processing molecule {index}: {str(e)}")
            smiles = _fallback_smiles(mol_init, index)
        smiles_list.append(smiles)
    return smiles_list
|
|
def build_molecule_with_partial_charges(
    atom_types, edge_types, atom_decoder, verbose=False
):
    """Build an RDKit molecule from atom types and a bond-label matrix.

    Atoms are added in order. Bonds are then taken from the upper triangle of
    *edge_types*, with self-loops ignored. After each bond, if RDKit reports an
    over-valent N, O or S atom whose valence exceeds ``ATOM_VALENCY`` by
    exactly one, that atom gets a +1 formal charge.

    Args:
        atom_types: iterable of scalar tensors; ``atom.item()`` indexes
            *atom_decoder*.
        edge_types: tensor of integer bond labels indexing ``bond_dict``
            (0 = no bond). Assumed square (n, n); confirm against the caller.
        atom_decoder: maps an atom-type index to an element symbol accepted by
            ``Chem.Atom``.
        verbose: print a trace of every atom and bond added.

    Returns:
        The ``Chem.RWMol`` that was built. Full sanitization is left to the
        caller.
    """
    if verbose:
        print("\nbuilding new molecule")

    mol = Chem.RWMol()
    for atom in atom_types:
        type_idx = atom.item()
        mol.AddAtom(Chem.Atom(atom_decoder[type_idx]))
        if verbose:
            print("Atom added: ", type_idx, atom_decoder[type_idx])

    # Keep only the upper triangle so each undirected bond is added once.
    edge_types = torch.triu(edge_types)
    for begin, end in torch.nonzero(edge_types).tolist():
        if begin == end:
            continue  # self-loops are not chemical bonds
        label = edge_types[begin, end].item()
        mol.AddBond(begin, end, bond_dict[label])
        if verbose:
            print("bond added:", begin, end, label, bond_dict[label])

        flag, atomid_valence = check_valency(mol)
        if verbose:
            print("flag, valence", flag, atomid_valence)
        # Nothing to fix if valid, or if RDKit gave no (atom index, valence) pair.
        if flag or len(atomid_valence) != 2:
            continue
        idx, v = atomid_valence
        an = mol.GetAtomWithIdx(idx).GetAtomicNum()
        if verbose:
            print("atomic num of atom with a large valence", an)
        # N/O/S exactly one bond over its neutral valence: model as a +1 cation.
        if an in (7, 8, 16) and (v - ATOM_VALENCY[an]) == 1:
            mol.GetAtomWithIdx(idx).SetFormalCharge(1)
    return mol
|
|
|
|
def correct_mol(mol, connection=False):
    """Iteratively lower bond orders until RDKit accepts the atom valences.

    Each round, the atom RDKit reports as over-valent has its highest-order
    non-aromatic bond reduced by one order; a single bond is removed outright.
    The input molecule is copied first and is never modified.

    Args:
        mol: RDKit molecule to repair (None is accepted and yields a failure).
        connection: if True, ``connect_fragments`` is applied at the start of
            every round to join disconnected fragments.

    Returns:
        ``(corrected_mol, no_correct)``. *no_correct* is True when the input
        already passed the valence check. *corrected_mol* is None if the
        molecule could not be repaired or joined.
    """
    if mol is None:
        return None, False
    # Edit a private copy. Callers such as graph_to_smiles retry with the same
    # input molecule and must not see bonds removed by an earlier attempt.
    mol = Chem.RWMol(mol)
    flag, _ = check_valency(mol)
    no_correct = bool(flag)

    while True:
        if connection:
            mol = connect_fragments(mol)
            if mol is None:
                return None, no_correct
        flag, atomid_valence = check_valency(mol)
        if flag:
            break
        # Without an (atom index, valence) pair from RDKit there is nothing to fix.
        if not atomid_valence or len(atomid_valence) != 2:
            return None, no_correct
        idx = atomid_valence[0]
        try:
            queue = []
            n_aromatic = 0
            for b in mol.GetAtomWithIdx(idx).GetBonds():
                bond_type = int(b.GetBondType())
                queue.append(
                    (b.GetIdx(), bond_type, b.GetBeginAtomIdx(), b.GetEndAtomIdx())
                )
                if bond_type == 12:  # int(Chem.BondType.AROMATIC)
                    n_aromatic += 1
            queue.sort(key=lambda tup: tup[1], reverse=True)

            # No bonds at all, or only aromatic bonds: no bond order can be lowered.
            if not queue or queue[-1][1] == 12:
                return None, no_correct
            # Aromatic bonds sort first, so queue[n_aromatic] is the
            # highest-order non-aromatic bond.
            _, bond_type, start, end = queue[n_aromatic]
            mol.RemoveBond(start, end)
            if bond_type - 1 >= 1:
                mol.AddBond(start, end, bond_dict[bond_type - 1])
        except Exception as err:
            # Best effort: any RDKit error while editing means "unrepairable".
            logger.debug("correct_mol: giving up on atom %s: %s", idx, err)
            return None, no_correct
    return mol, no_correct
|
|
def check_valid(smiles):
    """Return True if *smiles* loads with ``get_mol`` and re-serializes with
    ``mol2smiles``; False otherwise."""
    # mol2smiles(None) is None, so a failed load also yields False here.
    return mol2smiles(get_mol(smiles)) is not None
|
|
def get_mol(smiles_or_mol):
    """Return an RDKit molecule for a SMILES string, or pass a molecule through.

    A string is parsed and then sanitized. An empty string, a parse failure or
    a sanitization ``ValueError`` yields None. Any non-string argument is
    returned unchanged, without sanitization.
    """
    if not isinstance(smiles_or_mol, str):
        return smiles_or_mol
    if not smiles_or_mol:
        return None
    parsed = Chem.MolFromSmiles(smiles_or_mol)
    if parsed is not None:
        try:
            Chem.SanitizeMol(parsed)
        except ValueError:
            parsed = None
    return parsed
|
|
|
|
def mol2smiles(mol):
    """Sanitize *mol* in place and return its SMILES.

    Returns None if *mol* is None or sanitization raises ``ValueError``.
    """
    if mol is not None:
        try:
            Chem.SanitizeMol(mol)
        except ValueError:
            pass
        else:
            return Chem.MolToSmiles(mol)
    return None
|
|
|
|
def check_valency(mol):
    """Detect valence violations with RDKit's property-only sanitization.

    Returns:
        ``(True, None)`` if *mol* passes.

        ``(False, ints)`` on a ``ValueError``, where *ints* are the integers
        that follow the first ``#`` in the message. RDKit's valence message is
        presumably of the form
        ``"Explicit valence for atom # <idx> <sym>, <val>, ..."``, which gives
        ``[idx, val]``; callers check that the length is 2 before using it.

        ``(False, [])`` if the message has no ``#`` or a non-ValueError was
        raised.
    """
    try:
        # Only the property/valence step of sanitization is run.
        Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
        return True, None
    except ValueError as err:
        message = str(err)
        hash_pos = message.find("#")
        if hash_pos < 0:
            # find() returned -1: slicing from -1 would parse only the last
            # character and produce a bogus "index".
            return False, []
        return False, [int(tok) for tok in re.findall(r"\d+", message[hash_pos:])]
    except Exception:
        # Best effort (e.g. an argument RDKit cannot handle): invalid, cause unknown.
        return False, []
|
|
|
|
| |
def select_atom_with_available_valency(frag):
    """Pick a random atom of *frag* that can take another bond.

    Eligible atoms have atomic number > 1 and ``GetImplicitValence() > 0``.
    The candidate order is shuffled with the module-level ``random`` state.

    Returns:
        The first eligible atom in shuffled order, or None if there is none.
    """
    shuffled = list(frag.GetAtoms())
    random.shuffle(shuffled)
    return next(
        (
            candidate
            for candidate in shuffled
            if candidate.GetAtomicNum() > 1 and candidate.GetImplicitValence() > 0
        ),
        None,
    )
|
|
|
|
def select_atoms_with_available_valency(frag):
    """Return all atoms of *frag* that can take another bond, in atom order.

    An atom qualifies when its atomic number is > 1 and
    ``GetImplicitValence() > 0``.
    """
    def _has_free_valence(atom):
        return atom.GetAtomicNum() > 1 and atom.GetImplicitValence() > 0

    return list(filter(_has_free_valence, frag.GetAtoms()))
|
|
|
|
def try_to_connect_fragments(combined_mol, frag, atom1, atom2):
    """Try to join *frag* onto *combined_mol* with a single bond atom1-atom2.

    Works on copies: neither *combined_mol* nor *frag* is modified.

    Args:
        combined_mol: molecule that *atom1* belongs to.
        frag: fragment to append; *atom2* is one of its atoms.
        atom1: atom of *combined_mol* for the new bond (only its index is used).
        atom2: atom of *frag* for the new bond (only its index is used).

    Returns:
        The joined molecule as a fully sanitized ``Chem.Mol``, or None if
        ``Chem.SanitizeMol`` raises ``Chem.MolSanitizeException``.
    """
    # Editable copies so the caller's molecules stay untouched.
    trial_combined_mol = Chem.RWMol(combined_mol)
    trial_frag = Chem.RWMol(frag)

    # Append every fragment atom, recording the mapping
    # old fragment index -> index in the combined molecule.
    new_indices = {
        atom.GetIdx(): trial_combined_mol.AddAtom(atom)
        for atom in trial_frag.GetAtoms()
    }

    # The linking single bond between the two chosen atoms.
    trial_combined_mol.AddBond(
        atom1.GetIdx(), new_indices[atom2.GetIdx()], Chem.BondType.SINGLE
    )

    # Each linked atom gives up one hydrogen to make room for the new bond.
    # NOTE(review): this runs before the fragment's own bonds are copied below,
    # so GetTotalNumHs on the appended atom depends on the H/valence data that
    # AddAtom copied. Keep this ordering unless that is verified.
    for atom_idx in [atom1.GetIdx(), new_indices[atom2.GetIdx()]]:
        atom = trial_combined_mol.GetAtomWithIdx(atom_idx)
        num_h = atom.GetTotalNumHs()
        atom.SetNumExplicitHs(max(0, num_h - 1))

    # Recreate the fragment's internal bonds using the remapped indices.
    for bond in trial_frag.GetBonds():
        trial_combined_mol.AddBond(
            new_indices[bond.GetBeginAtomIdx()],
            new_indices[bond.GetEndAtomIdx()],
            bond.GetBondType(),
        )

    # Accept the join only if the result passes full RDKit sanitization.
    new_mol = Chem.Mol(trial_combined_mol)
    try:
        Chem.SanitizeMol(new_mol)
        return new_mol
    except Chem.MolSanitizeException:
        return None
|
|
|
|
def _attach_fragment(combined_mol, frag):
    """Return the first successful single-bond join of *frag* onto *combined_mol*, or None."""
    anchors = select_atoms_with_available_valency(combined_mol)
    partners = select_atoms_with_available_valency(frag)
    for anchor in anchors:
        for partner in partners:
            joined = try_to_connect_fragments(combined_mol, frag, anchor, partner)
            if joined is not None:
                return joined
    return None


def connect_fragments(mol):
    """Join all disconnected fragments of *mol* into a single molecule.

    Fragments after the first are attached one at a time. Each is joined by a
    single bond at the first (combined-atom, fragment-atom) pair, scanned in
    atom order, for which ``try_to_connect_fragments`` succeeds.

    Returns:
        *mol* itself if it has fewer than two fragments. Otherwise the joined
        molecule, or None as soon as some fragment cannot be attached.
    """
    pieces = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
    if len(pieces) < 2:
        return mol

    combined_mol = Chem.RWMol(pieces[0])
    for frag in pieces[1:]:
        joined = _attach_fragment(combined_mol, frag)
        if joined is None:
            return None
        combined_mol = joined
    return combined_mol
|
|
|
|
| |
|
|