| import os |
| from rdkit import Chem |
| from rdkit.Chem import Draw, AllChem |
| from rdkit.Geometry import Point3D |
| from rdkit import RDLogger |
| import numpy as np |
| import rdkit.Chem |
|
|
class MolecularVisualization:
    """Build RDKit molecules from decoded graphs and render them as images."""

    def __init__(self, atom_decoder):
        """Store the label-to-atom lookup.

        atom_decoder: indexable lookup from an integer node label to whatever
            ``Chem.Atom`` accepts for that label (element symbol or atomic
            number).
        """
        self.atom_decoder = atom_decoder

    @staticmethod
    def _bond_type(code):
        """Map an adjacency-matrix entry to an RDKit bond type (None = no bond)."""
        # Plain equality tests (rather than a dict lookup) keep this working
        # for scalar types that do not hash by value, e.g. 0-d tensors.
        if code == 1:
            return Chem.rdchem.BondType.SINGLE
        if code == 2:
            return Chem.rdchem.BondType.DOUBLE
        if code == 3:
            return Chem.rdchem.BondType.TRIPLE
        if code == 4:
            return Chem.rdchem.BondType.AROMATIC
        return None

    def _build_mol(self, node_list, adjacency_matrix):
        """Build one molecule and remember which node became which atom.

        Returns a ``(mol, node_to_idx)`` pair where ``node_to_idx`` maps the
        position of a node in ``node_list`` to its RDKit atom index. On
        failure the pair is ``(None, {})``.
        """
        mol = Chem.RWMol()

        # Nodes labelled -1 are masked out and get no atom, so RDKit atom
        # indices can differ from node positions; keep the translation table.
        node_to_idx = {}
        for node, label in enumerate(node_list):
            if label == -1:
                continue
            atom = Chem.Atom(self.atom_decoder[int(label)])
            node_to_idx[node] = mol.AddAtom(atom)

        for ix, row in enumerate(adjacency_matrix):
            if ix not in node_to_idx:
                # Row belongs to a masked node: it has no atom to bond.
                continue
            for iy, bond in enumerate(row):
                # Only the strict upper triangle is read, so every undirected
                # edge is added once and self-loops are ignored.
                if iy <= ix:
                    continue
                # An edge pointing at a masked node cannot be materialised.
                if iy not in node_to_idx:
                    continue
                bond_type = self._bond_type(bond)
                if bond_type is None:
                    continue
                mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)

        try:
            return mol.GetMol(), node_to_idx
        except Chem.KekulizeException:
            print("Can't kekulize molecule")
            return None, {}

    def mol_from_graphs(self, node_list, adjacency_matrix):
        """Convert a single graph to an RDKit molecule.

        node_list: integer node labels of one graph (length n); a label of -1
            marks a node that is skipped.
        adjacency_matrix: n x n bond codes of the same graph, where 1, 2, 3
            and 4 mean single, double, triple and aromatic; any other value
            means no bond. Only entries above the diagonal are read.

        Returns the molecule, or None if RDKit raises a KekulizeException
        while finalising it. Bonds that touch a skipped node are ignored.
        """
        mol, _ = self._build_mol(node_list, adjacency_matrix)
        return mol

    def visualize_chain(self, nodes_list, adjacency_matrix):
        """Render every graph of a chain with a shared 2D layout.

        nodes_list: sequence of per-frame node label lists (frames x n).
        adjacency_matrix: sequence of per-frame bond-code matrices
            (frames x n x n).

        The last frame is laid out with RDKit's 2D coordinate generation and
        each of its atom positions is reused, node by node, in all other
        frames. Nodes that have no atom in the last frame keep the position
        RDKit computed for their own frame.

        Returns one 300x300 image per frame that could be built, each with
        the legend "Frame <index>"; an empty chain yields an empty list.
        Note that this silences RDKit's ``rdApp`` loggers process-wide.
        """
        RDLogger.DisableLog('rdApp.*')

        frames = [
            self._build_mol(nodes_list[i], adjacency_matrix[i])
            for i in range(len(nodes_list))
        ]
        if not frames:
            return []

        # Reference layout, keyed by node position rather than RDKit atom
        # index so frames that mask different nodes still line up.
        reference = {}
        final_molecule, final_mapping = frames[-1]
        if final_molecule is not None and final_mapping:
            AllChem.Compute2DCoords(final_molecule)
            final_conf = final_molecule.GetConformer()
            for node, atom_idx in final_mapping.items():
                position = final_conf.GetAtomPosition(atom_idx)
                reference[node] = (position.x, position.y, position.z)

        mol_images = []
        for frame, (mol, mapping) in enumerate(frames):
            if mol is None:
                # The frame could not be built; keep numbering of the others.
                continue
            if mapping:
                # Compute2DCoords creates the conformer that is then edited;
                # an atom-less frame has nothing to lay out.
                AllChem.Compute2DCoords(mol)
                conf = mol.GetConformer()
                for node, atom_idx in mapping.items():
                    if node not in reference:
                        continue
                    x, y, z = reference[node]
                    conf.SetAtomPosition(atom_idx, Point3D(x, y, z))
            img = Draw.MolToImage(mol, size=(300, 300), legend=f"Frame {frame}")
            mol_images.append(img)

        return mol_images