| import pandas as pd |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import Dataset, DataLoader |
| from transformers import BertConfig, BertModel, AutoTokenizer |
| from rdkit import Chem, RDLogger |
| from rdkit.Chem.Scaffolds import MurckoScaffold |
| import copy |
| from tqdm import tqdm |
| import os |
| from sklearn.metrics import roc_auc_score, root_mean_squared_error, mean_absolute_error |
| from itertools import compress |
| from collections import defaultdict |
| from sklearn.metrics.pairwise import cosine_similarity |
# Silence RDKit's per-molecule parser warnings; unparseable SMILES are
# handled explicitly where molecules are constructed below.
RDLogger.DisableLog('rdApp.*')




# Trade a little matmul precision for speed (enables faster kernels such as
# TF32 on supported GPUs — see torch.set_float32_matmul_precision docs).
torch.set_float32_matmul_precision('high')
|
|
| |
class SmilesEnumerator:
    """Generates randomized (non-canonical) SMILES strings for data augmentation."""

    def randomize_smiles(self, smiles):
        """Return a randomly atom-ordered SMILES for *smiles*.

        Falls back to the input string unchanged when RDKit cannot parse it
        (MolFromSmiles returns None) or when serialization raises.
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return smiles
            return Chem.MolToSmiles(mol, doRandom=True, canonical=False)
        except Exception:  # was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit
            return smiles
|
|
|
|
def compute_embedding_similarity(encoder, smiles_list, tokenizer, device, max_len=256):
    """Measure embedding invariance under SMILES enumeration.

    For each SMILES string, embeds the original and one randomized
    (atom-reordered) variant, and returns the per-molecule cosine
    similarity between the two embeddings as a 1-D numpy array.

    Args:
        encoder: module mapping (input_ids, attention_mask) to an embedding tensor.
        smiles_list: iterable of SMILES strings.
        tokenizer: HF-style tokenizer returning input_ids / attention_mask.
        device: torch device for the forward passes.
        max_len: tokenizer truncation/padding length.
    """
    encoder.eval()
    enumerator = SmilesEnumerator()

    embeddings_orig = []
    embeddings_aug = []

    with torch.no_grad():
        for smi in smiles_list:
            # Embed the original string and one randomized variant with the
            # same tokenization settings.
            for bucket, text in ((embeddings_orig, smi),
                                 (embeddings_aug, enumerator.randomize_smiles(smi))):
                encoding = tokenizer(
                    text,
                    truncation=True,
                    padding='max_length',
                    max_length=max_len,
                    return_tensors='pt'
                )
                emb = encoder(
                    encoding.input_ids.to(device),
                    encoding.attention_mask.to(device)
                ).cpu().numpy().flatten()
                bucket.append(emb)

    if not embeddings_orig:  # empty input: nothing to compare
        return np.array([])

    a = np.asarray(embeddings_orig)
    b = np.asarray(embeddings_aug)
    # Vectorized cosine similarity. Replaces n Python-level sklearn calls on
    # 1x1 matrices; the epsilon guards against a zero-norm embedding.
    denom = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1)
    similarities = np.sum(a * b, axis=1) / np.maximum(denom, 1e-12)
    return similarities
|
|
| |
def load_lists_from_url(data):
    """Download a MoleculeNet dataset and return (smiles, labels).

    Args:
        data: dataset key, one of 'bbbp', 'clintox', 'hiv', 'sider', 'esol',
            'freesolv', 'lipophicility' (sic), 'tox21', 'bace', 'qm8'.

    Returns:
        smiles: pd.Series of SMILES strings.
        labels: pd.Series (single-task) or pd.DataFrame (multi-task).

    Raises:
        ValueError: for an unknown dataset key (previously this fell through
            and raised a confusing UnboundLocalError at the return).
    """
    if data == 'bbbp':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/BBBP.csv')
        smiles, labels = df.smiles, df.p_np
    elif data == 'clintox':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/clintox.csv.gz', compression='gzip')
        smiles = df.smiles
        labels = df.drop(['smiles'], axis=1)
    elif data == 'hiv':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/HIV.csv')
        smiles, labels = df.smiles, df.HIV_active
    elif data == 'sider':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/sider.csv.gz', compression='gzip')
        smiles = df.smiles
        labels = df.drop(['smiles'], axis=1)
    elif data == 'esol':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv')
        smiles = df.smiles
        labels = df['ESOL predicted log solubility in mols per litre']
    elif data == 'freesolv':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/SAMPL.csv')
        smiles = df.smiles
        labels = df.calc
    elif data == 'lipophicility':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/Lipophilicity.csv')
        smiles, labels = df.smiles, df['exp']
    elif data == 'tox21':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/tox21.csv.gz', compression='gzip')
        # tox21 has missing labels; keep only fully-labelled rows.
        df = df.dropna(axis=0, how='any').reset_index(drop=True)
        smiles = df.smiles
        labels = df.drop(['mol_id', 'smiles'], axis=1)
    elif data == 'bace':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/bace.csv')
        smiles, labels = df.mol, df.Class
    elif data == 'qm8':
        df = pd.read_csv('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/qm8.csv')
        df = df.dropna(axis=0, how='any').reset_index(drop=True)
        smiles = df.smiles
        labels = df.drop(['smiles', 'E2-PBE0.1', 'E1-PBE0.1', 'f1-PBE0.1', 'f2-PBE0.1'], axis=1)
    else:
        raise ValueError(f"Unknown dataset: {data}")
    return smiles, labels
|
|
| |
class ScaffoldSplitter:
    """Bemis-Murcko scaffold splitter for MoleculeNet-style datasets.

    Molecules are grouped by their Murcko scaffold and whole scaffold groups
    are assigned to train/val/test, so structurally similar molecules never
    leak across splits.
    """
    def __init__(self, data, seed, train_frac=0.8, val_frac=0.1, test_frac=0.1, include_chirality=True):
        # data: dataset key understood by load_lists_from_url (e.g. 'bbbp').
        # seed: RNG seed used to shuffle scaffold groups before assignment.
        # NOTE: train_frac is stored but the split derives train as the
        # remainder after filling val/test to their target sizes.
        self.data = data
        self.seed = seed
        self.include_chirality = include_chirality
        self.train_frac = train_frac
        self.val_frac = val_frac
        self.test_frac = test_frac


    def generate_scaffold(self, smiles):
        """Return the Murcko scaffold SMILES for one molecule (assumes parseable input)."""
        mol = Chem.MolFromSmiles(smiles)
        scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=mol, includeChirality=self.include_chirality)
        return scaffold


    def scaffold_split(self):
        """Compute (train_idx, val_idx, test_idx) index lists over the raw dataset.

        Rows with unparseable SMILES are dropped; for the multi-task sets
        ('tox21', 'sider', 'clintox') rows with any missing label are dropped
        too. Scaffold groups are shuffled with a seeded RNG, then greedily
        packed into val and test up to their target sizes; all remaining
        groups go to train. Returned indices refer to the ORIGINAL row order.
        """
        smiles, labels = load_lists_from_url(self.data)
        # Boolean keep-mask over all rows, initialised all-False.
        non_null = np.ones(len(smiles)) == 0


        if self.data in {'tox21', 'sider', 'clintox'}:
            # Multi-task: require a valid molecule AND a complete label row.
            for i in range(len(smiles)):
                if Chem.MolFromSmiles(smiles[i]) and labels.loc[i].isnull().sum() == 0:
                    non_null[i] = 1
        else:
            for i in range(len(smiles)):
                if Chem.MolFromSmiles(smiles[i]):
                    non_null[i] = 1


        # Keep (original_index, smiles) pairs for the surviving rows.
        smiles_list = list(compress(enumerate(smiles), non_null))
        rng = np.random.RandomState(self.seed)


        # Map scaffold SMILES -> list of original row indices sharing it.
        scaffolds = defaultdict(list)
        for i, sms in smiles_list:
            scaffold = self.generate_scaffold(sms)
            scaffolds[scaffold].append(i)


        scaffold_sets = list(scaffolds.values())
        rng.shuffle(scaffold_sets)
        n_total_val = int(np.floor(self.val_frac * len(smiles_list)))
        n_total_test = int(np.floor(self.test_frac * len(smiles_list)))
        train_idx, val_idx, test_idx = [], [], []


        # Greedy first-fit: a scaffold group goes to val if it still fits,
        # else to test if it fits, otherwise to train (never split a group).
        for scaffold_set in scaffold_sets:
            if len(val_idx) + len(scaffold_set) <= n_total_val:
                val_idx.extend(scaffold_set)
            elif len(test_idx) + len(scaffold_set) <= n_total_test:
                test_idx.extend(scaffold_set)
            else:
                train_idx.extend(scaffold_set)
        return train_idx, val_idx, test_idx
|
|
| |
def random_split_indices(n, seed=42, train_frac=0.8, val_frac=0.1, test_frac=0.1):
    """Randomly partition range(n) into (train, val, test) index lists.

    Seeds numpy's global RNG for reproducibility; test gets whatever remains
    after train/val (test_frac is accepted for symmetry but not used).
    """
    np.random.seed(seed)
    order = np.random.permutation(n).tolist()
    n_train = int(n * train_frac)
    n_val = int(n * val_frac)
    return (order[:n_train],
            order[n_train:n_train + n_val],
            order[n_train + n_val:])
|
|
| |
| class MoleculeDataset(Dataset): |
| def __init__(self, smiles_list, labels, tokenizer, max_len=512): |
| self.smiles_list = smiles_list |
| self.labels = labels |
| self.tokenizer = tokenizer |
| self.max_len = max_len |
|
|
| def __len__(self): |
| return len(self.smiles_list) |
|
|
| def __getitem__(self, idx): |
| smiles = self.smiles_list[idx] |
| label = self.labels.iloc[idx] |
|
|
| encoding = self.tokenizer( |
| smiles, |
| truncation=True, |
| padding='max_length', |
| max_length=self.max_len, |
| return_tensors='pt' |
| ) |
| item = {key: val.squeeze(0) for key, val in encoding.items()} |
| if isinstance(label, pd.Series): |
| label_values = label.values.astype(np.float32) |
| else: |
| label_values = np.array([label], dtype=np.float32) |
| item['labels'] = torch.tensor(label_values, dtype=torch.float) |
| return item |
|
|
| |
def global_ap(x):
    """Global average pooling over dim 1 (trailing dims flattened first)."""
    flattened = x.view(x.size(0), x.size(1), -1)
    return flattened.mean(dim=1)
|
|
class SimSonEncoder(nn.Module):
    """BERT-based SMILES encoder producing a `max_len`-dimensional embedding.

    Applies dropout to the final hidden states, mean-pools over the token
    dimension, then projects to `max_len` features with a linear head.
    """

    def __init__(self, config: BertConfig, max_len: int, dropout: float = 0.1):
        super().__init__()
        self.config = config
        self.max_len = max_len
        self.bert = BertModel(config, add_pooling_layer=False)
        self.linear = nn.Linear(config.hidden_size, max_len)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask=None):
        # Derive a padding mask from the token ids when none is supplied.
        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id)
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        pooled = global_ap(self.dropout(hidden))
        return self.linear(pooled)
|
|
class SimSonClassifier(nn.Module):
    """Task head (classification or regression) on top of a SimSonEncoder."""

    def __init__(self, encoder: SimSonEncoder, num_labels: int, dropout=0.1):
        super(SimSonClassifier, self).__init__()
        self.encoder = encoder
        self.clf = nn.Linear(encoder.max_len, num_labels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids, attention_mask=None):
        """Return raw logits of shape (batch, num_labels)."""
        x = self.encoder(input_ids, attention_mask)
        x = self.relu(self.dropout(x))
        logits = self.clf(x)
        return logits


    def load_encoder_params(self, state_dict_path, map_location=None):
        """Load pretrained encoder weights from *state_dict_path*.

        `map_location` is forwarded to torch.load so checkpoints saved on
        CUDA can be loaded on CPU-only machines; the default (None) keeps
        the previous behavior.
        """
        self.encoder.load_state_dict(torch.load(state_dict_path, map_location=map_location))
        print("Pretrained encoder parameters loaded.")
|
|
| |
def get_criterion(task_type, num_labels):
    """Return the loss module for *task_type*.

    'classification' -> BCEWithLogitsLoss (multi-label, expects raw logits);
    'regression'     -> MSELoss. `num_labels` is accepted for interface
    symmetry; both losses broadcast over the label dimension.

    Raises:
        ValueError: for any other task type.
    """
    if task_type == 'classification':
        return nn.BCEWithLogitsLoss()
    if task_type == 'regression':
        return nn.MSELoss()
    raise ValueError(f"Unknown task type: {task_type}")
|
|
def train_epoch(model, dataloader, optimizer, scheduler, criterion, device):
    """Run one optimization epoch and return the mean batch loss.

    Also steps the LR scheduler once at the end of the epoch. Previously the
    `scheduler` argument was accepted but never stepped anywhere, so the
    StepLR configured by the caller never decayed the learning rate.
    """
    model.train()
    total_loss = 0.0
    for batch in dataloader:
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
        labels = batch['labels'].to(device)
        optimizer.zero_grad()
        outputs = model(**inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    if scheduler is not None:
        scheduler.step()  # epoch-level schedulers (StepLR) step once per epoch
    return total_loss / len(dataloader)
|
|
def eval_epoch(model, dataloader, criterion, device):
    """Return the mean batch loss over *dataloader* with gradients disabled."""
    model.eval()
    running = 0.0
    with torch.no_grad():
        for batch in dataloader:
            labels = batch['labels'].to(device)
            inputs = {key: tensor.to(device) for key, tensor in batch.items() if key != 'labels'}
            running += criterion(model(**inputs), labels).item()
    return running / len(dataloader)
|
|
| def test_model(model, dataloader, device): |
| model.eval() |
| all_preds, all_labels = [], [] |
| with torch.no_grad(): |
| for batch in dataloader: |
| inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'} |
| labels = batch['labels'] |
| outputs = model(**inputs) |
| preds = torch.sigmoid(outputs) |
| all_preds.append(preds.cpu().numpy()) |
| all_labels.append(labels.numpy()) |
| return np.concatenate(all_preds), np.concatenate(all_labels) |
|
|
def calc_val_metrics(model, dataloader, criterion, device, task_type):
    """Return (avg_loss, metric) over *dataloader*.

    The metric is macro ROC-AUC for classification (0.0 when it cannot be
    computed, e.g. only one class present) and None for regression.
    """
    model.eval()
    pred_chunks, label_chunks = [], []
    running_loss = 0.0
    with torch.no_grad():
        for batch in dataloader:
            labels = batch['labels'].to(device)
            inputs = {key: value.to(device) for key, value in batch.items() if key != 'labels'}
            outputs = model(**inputs)
            running_loss += criterion(outputs, labels).item()
            # Classification scores are probabilities; regression uses raw outputs.
            scores = torch.sigmoid(outputs) if task_type == 'classification' else outputs
            pred_chunks.append(scores.cpu().numpy())
            label_chunks.append(labels.cpu().numpy())
    avg_loss = running_loss / len(dataloader)
    if task_type != 'classification':
        return avg_loss, None
    try:
        score = roc_auc_score(np.concatenate(label_chunks), np.concatenate(pred_chunks), average='macro')
    except Exception:
        score = 0.0
    return avg_loss, score
|
|
| |
def main():
    """Fine-tune the SimSon encoder on each configured dataset and report metrics."""
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {DEVICE}")

    DATASETS_TO_RUN = {
        'clintox': {'task_type': 'classification', 'num_labels': 2, 'split': 'random'},
    }
    PATIENCE = 15
    EPOCHS = 50
    LEARNING_RATE = 1e-4
    BATCH_SIZE = 16
    MAX_LEN = 512

    TOKENIZER = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
    ENCODER_CONFIG = BertConfig(
        vocab_size=TOKENIZER.vocab_size,
        hidden_size=768,
        num_hidden_layers=4,
        num_attention_heads=12,
        intermediate_size=2048,
        max_position_embeddings=512
    )

    aggregated_results = {}

    for name, info in DATASETS_TO_RUN.items():
        print(f"\n{'='*20} Processing Dataset: {name.upper()} ({info['split']} split) {'='*20}")
        smiles, labels = load_lists_from_url(name)

        # Choose the split strategy (scaffold split re-loads the data internally).
        if info.get('split', 'scaffold') == 'scaffold':
            splitter = ScaffoldSplitter(data=name, seed=42)
            train_idx, val_idx, test_idx = splitter.scaffold_split()
        elif info['split'] == 'random':
            train_idx, val_idx, test_idx = random_split_indices(len(smiles), seed=42)
        else:
            raise ValueError(f"Unknown split type for {name}: {info['split']}")

        train_smiles = smiles.iloc[train_idx].reset_index(drop=True)
        train_labels = labels.iloc[train_idx].reset_index(drop=True)
        val_smiles = smiles.iloc[val_idx].reset_index(drop=True)
        val_labels = labels.iloc[val_idx].reset_index(drop=True)
        test_smiles = smiles.iloc[test_idx].reset_index(drop=True)
        test_labels = labels.iloc[test_idx].reset_index(drop=True)
        print(f"Data split - Train: {len(train_smiles)}, Val: {len(val_smiles)}, Test: {len(test_smiles)}")

        train_dataset = MoleculeDataset(train_smiles, train_labels, TOKENIZER, MAX_LEN)
        val_dataset = MoleculeDataset(val_smiles, val_labels, TOKENIZER, MAX_LEN)
        test_dataset = MoleculeDataset(test_smiles, test_labels, TOKENIZER, MAX_LEN)

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

        encoder = SimSonEncoder(ENCODER_CONFIG, 512)
        # NOTE(review): compiling before loading weights means the checkpoint's
        # state-dict keys must match the compiled module's naming (torch.compile
        # wraps the module) — confirm the checkpoint was saved accordingly.
        encoder = torch.compile(encoder)
        model = SimSonClassifier(encoder, num_labels=info['num_labels']).to(DEVICE)
        model.load_encoder_params('../simson_checkpoints/checkpoint_best_model.bin')
        criterion = get_criterion(info['task_type'], info['num_labels'])
        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=0.0024)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.59298)

        # Track the best (lowest) validation loss. The original code started
        # from -inf and compared `val_metric <= val_loss`, which broke model
        # selection/early stopping and crashed for regression (metric is None).
        best_val_loss = float('inf')
        best_model_state = None
        current_patience = 0
        for epoch in range(EPOCHS):
            train_loss = train_epoch(model, train_loader, optimizer, scheduler, criterion, DEVICE)
            # Evaluate on the selected device (was hard-coded to 'cuda').
            val_loss, val_metric = calc_val_metrics(model, val_loader, criterion, DEVICE, info['task_type'])
            # val_metric is None for regression; guard the formatting.
            metric_str = f"{val_metric:.4f}" if val_metric is not None else "n/a"
            print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | ROC AUC: {metric_str}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model_state = copy.deepcopy(model.state_dict())
                print(f" -> New best model saved with validation loss: {best_val_loss:.4f}")
                current_patience = 0
            else:
                current_patience += 1
                if current_patience >= PATIENCE:
                    print(f'Early stopping at {PATIENCE} epochs')
                    break

        print("\nTesting with the best model...")
        if best_model_state is not None:
            model.load_state_dict(best_model_state)
        test_loss = eval_epoch(model, test_loader, criterion, DEVICE)
        print(f'Test loss: {test_loss}')
        # NOTE(review): test_model applies a sigmoid to all outputs; for
        # regression datasets the RMSE/MAE below are computed on squashed
        # values — pass task-aware handling when regression sets are enabled.
        test_preds, test_true = test_model(model, test_loader, DEVICE)

        aggregated_results[name] = {
            'best_val_loss': best_val_loss,
            'test_predictions': test_preds,
            'test_labels': test_true
        }
        print(f"Finished testing for {name}.")
        test_smiles_list = list(test_smiles)
        similarities = compute_embedding_similarity(
            model.encoder, test_smiles_list, TOKENIZER, DEVICE, MAX_LEN
        )
        print(f"Similarity score: {similarities.mean():.4f}")
        # Intentionally disabled: rename the dataset key to re-enable saving.
        if name == 'do_not_save':
            torch.save(model.encoder.state_dict(), 'moleculenet_clintox_encoder.bin')

    print(f"\n{'='*20} AGGREGATED RESULTS {'='*20}")
    for name, result in aggregated_results.items():
        if name in ['bbbp', 'tox21', 'sider', 'clintox', 'hiv', 'bace']:
            auc = roc_auc_score(result['test_labels'], result['test_predictions'], average='macro')
            print(f'{name} ROC AUC: {auc}')

        if name in ['lipophicility', 'esol', 'qm8']:
            rmse = root_mean_squared_error(result['test_labels'], result['test_predictions'])
            mae = mean_absolute_error(result['test_labels'], result['test_predictions'])
            print(f'{name} MAE: {mae}')
            print(f'{name} RMSE: {rmse}')

    print("\nScript finished.")
|
|
# Standard script entry-point guard: run the fine-tuning pipeline only when
# executed directly, not when imported as a module.
if __name__ == '__main__':
    main()
|
|