| | import torch |
| | import torch.nn as nn |
| | from transformers import PreTrainedTokenizer |
| | from transformers.tokenization_utils_base import BatchEncoding |
| | from transformers import AutoTokenizer, AutoModel |
| | from rdkit import Chem |
| | from rdkit.Chem import Descriptors, AllChem, MACCSkeys |
| | from rdkit.ML.Descriptors import MoleculeDescriptors |
| | from rdkit import RDLogger |
| | from rdkit.Chem import Draw |
| | import joblib |
| | import numpy as np |
| | import os |
| | from huggingface_hub import snapshot_download |
| | import warnings |
| | from sklearn.exceptions import InconsistentVersionWarning |
| | from torchvision import models, transforms |
| | from PIL import Image |
| | warnings.filterwarnings("ignore", category=InconsistentVersionWarning) |
| | RDLogger.DisableLog('rdApp.*') |
| |
|
class BBBTokenizer(PreTrainedTokenizer):
    """Multimodal tokenizer for blood-brain-barrier (BBB) prediction models.

    Converts SMILES strings into three feature modalities, returned as a
    ``BatchEncoding`` with keys ``"tab"``, ``"img"`` and ``"txt"``:

    * ``tab`` -- RDKit 2D descriptors concatenated with MACCS keys.
    * ``img`` -- ResNet-50 (ImageNet weights, head removed) embedding of the
      rendered 2D structure image.
    * ``txt`` -- mean-pooled last-hidden-state of ChemBERTa over the SMILES.

    Each modality is normalized by a task-specific scikit-learn transformer
    downloaded from the ``SaeedLab/TITAN-BBB`` Hugging Face Hub repository
    the first time a given task ('classification' or 'regression') is used.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Calculator over the full RDKit 2D descriptor list.
        self.calc = MoleculeDescriptors.MolecularDescriptorCalculator(
            [name for name, _ in Descriptors.descList]
        )

        # ChemBERTa encoder for the text (SMILES) modality; eval mode since
        # it is used for inference-only feature extraction.
        self.tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-100M-MLM')
        self.chemberta = AutoModel.from_pretrained('DeepChem/ChemBERTa-100M-MLM').eval()

        # ResNet-50 minus its final FC layer -> pooled image features.
        self.resnet50_backbone = models.resnet50(weights="IMAGENET1K_V1")
        self.resnet = nn.Sequential(*list(self.resnet50_backbone.children())[:-1]).eval()
        self.img_preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # Standard ImageNet normalization, matching the pretrained weights.
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            )
        ])

        # Task-specific normalizers; loaded lazily on the first encode call.
        self.feature_transformer_tab = None
        self.feature_transformer_img = None
        self.feature_transformer_txt = None
        self.task = None

    def _load_feature_transformers(self, task):
        """Download and load the three per-modality normalizers for ``task``.

        Sets ``self.task`` only after all three transformers load successfully,
        so a failed download does not leave the tokenizer half-configured.

        Raises:
            ValueError: if ``task`` is neither 'classification' nor 'regression'.
        """
        if task == 'classification':
            prefix = 'cls'
        elif task == 'regression':
            prefix = 'reg'
        else:
            raise ValueError('task not defined')

        paths = {}
        for modality in ('tabular', 'image', 'text'):
            fname = f"normalize_{prefix}_{modality}.joblib"
            model_dir = snapshot_download("SaeedLab/TITAN-BBB", allow_patterns=[fname])
            paths[modality] = os.path.join(model_dir, fname)

        self.feature_transformer_tab = joblib.load(paths['tabular'])
        self.feature_transformer_img = joblib.load(paths['image'])
        self.feature_transformer_txt = joblib.load(paths['text'])
        self.task = task

    def generate_tab_features(self, smiles):
        """Return the normalized tabular feature vector for one SMILES string.

        Invalid SMILES yield an all-zero vector sized to the transformer's
        expected input width, so batches still stack cleanly.
        """
        mol = Chem.MolFromSmiles(smiles)

        if mol is None:
            # BUG FIX: the original returned a 0-d tensor containing the
            # feature *count* itself; a zero vector of that length is intended.
            # NOTE(review): assumes the transformer's output width equals
            # n_features_in_ (true for scalers) -- confirm if a
            # dimensionality-reducing transformer is ever used.
            return torch.zeros(self.feature_transformer_tab.n_features_in_, dtype=torch.float32)

        rdkit_2d = np.array(self.calc.CalcDescriptors(mol))
        # nan_to_num already maps +/-inf to 0.0, so no separate isinf pass.
        rdkit_2d = np.nan_to_num(rdkit_2d, nan=0.0, posinf=0.0, neginf=0.0)
        maccs = np.array(list(MACCSkeys.GenMACCSKeys(mol).ToBitString()), dtype=int)
        tab_input = np.concatenate([rdkit_2d, maccs])
        tab_input = self.feature_transformer_tab.transform(tab_input.reshape(1, -1))[0]
        # Guard against extreme post-scaling values from degenerate descriptors.
        tab_input = np.clip(tab_input, -1e5, 1e5)
        return torch.tensor(tab_input, dtype=torch.float32)

    def generate_img_features(self, smiles):
        """Return the normalized ResNet-50 embedding of the molecule image.

        Invalid SMILES are rendered as a black placeholder image rather than
        failing, keeping batch encoding robust.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            img = Image.new("RGB", (300, 300), color=(0, 0, 0))
        else:
            img = Draw.MolToImage(mol, size=(300, 300))
        img = self.img_preprocess(img)
        with torch.no_grad():
            # (1, 2048, 1, 1) -> (1, 2048): drop the pooled spatial dims.
            img_input = self.resnet(img.unsqueeze(0)).squeeze(-1).squeeze(-1)
        img_input = self.feature_transformer_img.transform(img_input.reshape(1, -1))[0]
        return torch.tensor(img_input, dtype=torch.float32)

    def generate_txt_features(self, smiles):
        """Return the normalized mean-pooled ChemBERTa embedding of the SMILES."""
        encoded = self.tokenizer(smiles, return_tensors="pt")
        with torch.no_grad():
            outputs = self.chemberta(**encoded)
        # Mean over the sequence dimension of the last hidden state.
        hidden_states = outputs.last_hidden_state[0].mean(axis=0).numpy()
        txt_input = self.feature_transformer_txt.transform(hidden_states.reshape(1, -1))[0]
        return torch.tensor(txt_input, dtype=torch.float32)

    def _batch_encode_plus(
        self,
        batch_smiles: list[str],
        task: str = 'classification',
        return_tensors: str = "pt",
        **kwargs
    ):
        """Encode a batch of SMILES strings into the three modality tensors.

        Args:
            batch_smiles: SMILES strings to encode.
            task: 'classification' or 'regression'; selects which normalizers
                are downloaded/applied.
            return_tensors: tensor type passed through to ``BatchEncoding``.

        Returns:
            ``BatchEncoding`` with stacked ``"tab"``, ``"img"`` and ``"txt"``
            tensors (one row per input SMILES).

        Raises:
            ValueError: if ``task`` is not recognized.
        """
        # Lazily (re)load the normalizers whenever the requested task changes;
        # self.task is None initially, so `!=` also covers the first call.
        if self.task != task:
            self._load_feature_transformers(task)

        tab = torch.stack([self.generate_tab_features(s) for s in batch_smiles])
        img = torch.stack([self.generate_img_features(s) for s in batch_smiles])
        txt = torch.stack([self.generate_txt_features(s) for s in batch_smiles])

        return BatchEncoding({"tab": tab, "img": img, "txt": txt}, tensor_type=return_tensors)

    def encode(self,
               batch_smiles: list[str],
               task: str = 'classification',
               return_tensors: str = "pt",
               **kwargs):
        """Alias for :meth:`_batch_encode_plus`; see that method for details."""
        return self._batch_encode_plus(batch_smiles, task, return_tensors, **kwargs)

    def __call__(self,
                 batch_smiles: list[str],
                 task: str = 'classification',
                 return_tensors: str = "pt",
                 **kwargs):
        """Alias for :meth:`_batch_encode_plus`; see that method for details."""
        return self._batch_encode_plus(batch_smiles, task, return_tensors, **kwargs)

    def _tokenize(self, text, **kwargs):
        # This tokenizer does not use subword tokenization; required override
        # of the PreTrainedTokenizer interface, intentionally a no-op.
        return []

    def save_vocabulary(self, save_directory, filename_prefix=None):
        # No vocabulary files are written; return an empty tuple of paths as
        # the PreTrainedTokenizer contract expects.
        return ()

    def get_vocab(self):
        # Minimal special-token vocabulary to satisfy the HF tokenizer API.
        return {"<pad>": 0, "<bos>": 1, "<eos>": 2, "<unk>": 3, "<mask>": 4}

    @property
    def vocab_size(self):
        """Number of entries in the (minimal) vocabulary."""
        return len(self.get_vocab())