"""Guided peptide sampling script.

Loads a pretrained diffusion model and a binding classifier, samples peptide
sequences conditioned on a target sequence, scores every sample by the cosine
similarity of its mean ESM-2 embedding to a reference ("original") binder, and
writes the results to ``il2_alpha_guidance.csv``.
"""
import os

import hydra
import lightning as L
import numpy as np
import omegaconf
import pandas as pd
import rdkit
import rich.syntax
import rich.tree
import torch
from tqdm.auto import tqdm
import pdb
import torch.nn.functional as F

import dataloader
import diffusion
from models.bindevaluator import BindEvaluator
from transformers import AutoTokenizer, EsmModel
from faesm.esm import FAEsmForMaskedLM
import torch.nn as nn
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Optional PepMLM setup (disabled).
# SECURITY: a literal Hugging Face access token used to be committed here.
# Never hard-code credentials; read them from the environment instead, and
# revoke any token that has already been committed.
# PEPMLM_NAME = "ChatterjeeLab/PepMLM-650M"
# PEPMLM_TOKEN = os.environ["HF_TOKEN"]
# PEPMLM_MODEL = AutoModelForMaskedLM.from_pretrained(PEPMLM_NAME, token=PEPMLM_TOKEN)
# pepmlm_tokenizer = AutoTokenizer.from_pretrained(PEPMLM_NAME, token=PEPMLM_TOKEN)
# pepmlm = PEPMLM_MODEL.to(DEVICE)

rdkit.rdBase.DisableLog('rdApp.error')

# Custom OmegaConf resolvers usable from the Hydra config files.
omegaconf.OmegaConf.register_new_resolver('cwd', os.getcwd)
omegaconf.OmegaConf.register_new_resolver(
    'device_count', torch.cuda.device_count)
# NOTE(security): this resolver runs Python `eval` on strings taken from the
# config. Only use it with trusted configuration files / CLI overrides.
omegaconf.OmegaConf.register_new_resolver('eval', eval)
omegaconf.OmegaConf.register_new_resolver(
    'div_up', lambda x, y: (x + y - 1) // y)
omegaconf.OmegaConf.register_new_resolver(
    'if_then_else',
    lambda condition, x, y: x if condition else y,
)


def _print_config(
        config: omegaconf.DictConfig,
        resolve: bool = True) -> None:
    """Prints content of DictConfig using Rich library and its tree structure.

    Args:
      config (DictConfig): Configuration composed by Hydra.
      resolve (bool): Whether to resolve reference fields of DictConfig.
    """
    style = 'dim'
    tree = rich.tree.Tree('CONFIG', style=style, guide_style=style)

    for field in config.keys():
        branch = tree.add(field, style=style, guide_style=style)

        config_section = config.get(field)
        # Nested sections are rendered as YAML; scalars fall back to str().
        branch_content = str(config_section)
        if isinstance(config_section, omegaconf.DictConfig):
            branch_content = omegaconf.OmegaConf.to_yaml(
                config_section, resolve=resolve)

        branch.add(rich.syntax.Syntax(branch_content, 'yaml'))
    rich.print(tree)


def parse_motif(motif: str) -> torch.Tensor:
    """Parse a motif specification into a tensor of positions.

    The specification is a comma-separated list whose items are either a
    single integer (``"7"``) or an inclusive range (``"3-5"``), e.g.
    ``"1-3, 7"`` -> ``tensor([1, 2, 3, 7])``. Empty items (such as the one
    produced by a trailing comma) are ignored.

    Args:
      motif (str): Motif specification string.

    Returns:
      torch.Tensor: 1-D ``torch.long`` tensor with the listed positions, in
        the order they appear in ``motif``.

    Raises:
      ValueError: If an item is neither an integer nor a ``start-end`` range.
    """
    positions = []
    for part in motif.split(','):
        part = part.strip()
        if not part:
            # Tolerate "1-3," or an entirely empty specification.
            continue
        if '-' in part:
            start, end = map(int, part.split('-'))
            positions.extend(range(start, end + 1))
        else:
            positions.append(int(part))
    # Explicit dtype so that an empty specification still yields an integer
    # tensor (torch.tensor([]) would default to float).
    return torch.tensor(positions, dtype=torch.long)


def _mean_esm_embedding(esm, sequence: str) -> torch.Tensor:
    """Return the ``last_hidden_state`` of ``esm`` for ``sequence``, averaged over dim 1.

    The forward pass runs under ``torch.no_grad()`` because the embedding is
    only used for similarity scoring, never for back-propagation.
    """
    inputs = esm.tokenizer(sequence, return_tensors="pt")
    inputs = {k: v.to('cuda') for k, v in inputs.items()}
    with torch.no_grad():
        hidden = esm(**inputs)['last_hidden_state']
    return torch.mean(hidden, dim=1)


@hydra.main(version_base=None, config_path='./configs',
            config_name='config')
def main(config: omegaconf.DictConfig) -> None:
    """Sample peptides with classifier guidance and score them against a reference binder.

    Reads from ``config``: ``seed``, ``eval.checkpoint_path``,
    ``eval.target_sequence``, ``eval.target_motifs``,
    ``guidance.classifier_checkpoint_path``, ``sampling.original_binder`` and
    ``sampling.num_sample_batches``.

    Writes ``il2_alpha_guidance.csv`` (columns ``samples`` and
    ``samples_similarity``) to the current working directory.
    """
    # Reproducibility
    L.seed_everything(config.seed)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False

    # _print_config(config, resolve=True)
    print(f"Checkpoint: {config.eval.checkpoint_path}")

    tokenizer = dataloader.get_tokenizer(config)
    target_sequence = tokenizer(
        config.eval.target_sequence, return_tensors='pt')['input_ids']

    pretrained = diffusion.Diffusion.load_from_checkpoint(
        config.eval.checkpoint_path,
        tokenizer=tokenizer,
        config=config,
        logger=False)
    pretrained.eval()
    pretrained = pretrained.to('cuda')

    bindevaluator = BindEvaluator.load_from_checkpoint(
        config.guidance.classifier_checkpoint_path,
        n_layers=8,
        d_model=128,
        d_hidden=128,
        n_head=8,
        d_k=64,
        d_v=128,
        d_inner=64)
    bindevaluator = bindevaluator.to('cuda')

    # ESM-2 implementation with flash attention.
    # Using the 650M checkpoint --> might use a bigger/smaller model.
    esm = (FAEsmForMaskedLM
           .from_pretrained("facebook/esm2_t33_650M_UR50D")
           .to("cuda")
           .eval()
           .to(torch.float16))

    # Reference embedding every generated sample is compared against.
    original_binder_embedding_avg = _mean_esm_embedding(
        esm, config.sampling.original_binder)

    samples = []
    for _ in tqdm(
            range(config.sampling.num_sample_batches),
            desc='Gen. batches',
            leave=False):
        sample = pretrained.sample(
            target_sequence=target_sequence,
            target_motifs=parse_motif(config.eval.target_motifs),
            classifier_model=bindevaluator,
        )
        sample_decoded = pretrained.tokenizer.batch_decode(sample)
        # Drop the spaces between decoded tokens, then trim 5 characters from
        # each end -- presumably the "<cls>" / "<eos>" markers; TODO confirm
        # against the tokenizer's special tokens.
        samples_processed = [
            seq.replace(' ', '')[5:-5] for seq in sample_decoded]
        print('sample: ', samples_processed)
        samples.extend(samples_processed)

    # Cache scores per unique sequence, then expand back to one score per
    # sample. Keying the output column directly by sequence would collapse
    # duplicate samples and make the two DataFrame columns differ in length.
    similarity_by_sequence = {}
    for seq in tqdm(samples, desc='Computing similarities'):
        if seq in similarity_by_sequence:
            continue
        seq_embedding_avg = _mean_esm_embedding(esm, seq)
        similarity_score = F.cosine_similarity(
            seq_embedding_avg, original_binder_embedding_avg)
        similarity_by_sequence[seq] = similarity_score.item()

    outputs_csv = pd.DataFrame({
        'samples': list(samples),
        'samples_similarity': [similarity_by_sequence[seq] for seq in samples],
    })
    print("outputs_csv", outputs_csv)
    outputs_csv.to_csv('il2_alpha_guidance.csv', index=False)


if __name__ == '__main__':
    main()