| import os |
| import hydra |
| import lightning as L |
| import numpy as np |
| import omegaconf |
| import pandas as pd |
| import rdkit |
| import rich.syntax |
| import rich.tree |
| import torch |
| from tqdm.auto import tqdm |
| import pdb |
| import torch.nn.functional as F |
| import dataloader |
| import diffusion |
| from models.bindevaluator import BindEvaluator |
| from transformers import AutoTokenizer, EsmModel |
| from faesm.esm import FAEsmForMaskedLM |
| import torch.nn as nn |
| from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel |
| import numpy as np |
|
|
# Preferred compute device; falls back to CPU when CUDA is unavailable.
# NOTE(review): defined but not referenced in this file — the code below
# hard-codes 'cuda'; consider using DEVICE consistently.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")




# Silence RDKit's per-molecule error logging (noisy during batch work).
rdkit.rdBase.DisableLog('rdApp.error')


# Custom OmegaConf resolvers usable inside the Hydra configs, e.g.
# ${cwd:}, ${device_count:}, ${eval:'2**8'}, ${div_up:10,3},
# ${if_then_else:cond,a,b}.
omegaconf.OmegaConf.register_new_resolver(
  'cwd', os.getcwd)
omegaconf.OmegaConf.register_new_resolver(
  'device_count', torch.cuda.device_count)
# WARNING: 'eval' executes arbitrary Python from config values — only safe
# because configs are trusted, local files. Never feed untrusted configs.
omegaconf.OmegaConf.register_new_resolver(
  'eval', eval)
omegaconf.OmegaConf.register_new_resolver(
  'div_up', lambda x, y: (x + y - 1) // y)  # ceiling division
omegaconf.OmegaConf.register_new_resolver(
  'if_then_else',
  lambda condition, x, y: x if condition else y
)
|
|
def _print_config(
    config: omegaconf.DictConfig,
    resolve: bool = True) -> None:
    """Pretty-print a Hydra config as a Rich tree.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        resolve (bool): Whether to resolve reference fields of DictConfig.
    """
    tree_style = 'dim'
    config_tree = rich.tree.Tree(
        'CONFIG', style=tree_style, guide_style=tree_style)

    # One branch per top-level config section, rendered as YAML.
    for key in config.keys():
        branch = config_tree.add(key, style=tree_style, guide_style=tree_style)
        section = config.get(key)
        if isinstance(section, omegaconf.DictConfig):
            rendered = omegaconf.OmegaConf.to_yaml(section, resolve=resolve)
        else:
            rendered = str(section)
        branch.add(rich.syntax.Syntax(rendered, 'yaml'))

    rich.print(config_tree)
|
|
def parse_motif(motif: str) -> torch.Tensor:
    """Parse a motif specification string into a tensor of residue indices.

    Accepts a comma-separated list of single indices and inclusive
    ``start-end`` ranges, e.g. ``"3, 7-9, 12"`` -> ``tensor([3, 7, 8, 9, 12])``.

    Args:
        motif: Comma-separated indices and/or 'start-end' ranges.

    Returns:
        1-D ``torch.Tensor`` of integer indices (ranges are inclusive).
        Note: the original annotation said ``list``, but the function has
        always returned a tensor.

    Raises:
        ValueError: If a part is not an integer or a valid 'start-end' range.
    """
    result = []
    for part in motif.split(','):
        part = part.strip()
        if '-' in part:
            start, end = map(int, part.split('-'))
            result.extend(range(start, end + 1))  # inclusive upper bound
        else:
            result.append(int(part))
    return torch.tensor(result)
|
|
|
|
@hydra.main(version_base=None, config_path='./configs',
            config_name='config')
def main(config: omegaconf.DictConfig) -> None:
    """Generate guided binder samples and score them against a reference.

    Pipeline:
      1. Load the pretrained diffusion model and the BindEvaluator
         classifier used for guidance.
      2. Sample ``config.sampling.num_sample_batches`` batches of sequences
         conditioned on ``config.eval.target_sequence`` and
         ``config.eval.target_motifs``.
      3. Embed each sample with ESM-2 (650M, fp16) and compute the cosine
         similarity between its mean embedding and the mean embedding of
         ``config.sampling.original_binder``.
      4. Write (sample, similarity) rows to 'il2_alpha_guidance.csv'.
    """
    L.seed_everything(config.seed)
    # Required by cuBLAS for deterministic kernels when
    # torch.use_deterministic_algorithms(True) is enabled.
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False

    print(f"Checkpoint: {config.eval.checkpoint_path}")

    tokenizer = dataloader.get_tokenizer(config)
    target_sequence = tokenizer(
        config.eval.target_sequence, return_tensors='pt')['input_ids']

    pretrained = diffusion.Diffusion.load_from_checkpoint(
        config.eval.checkpoint_path,
        tokenizer=tokenizer,
        config=config, logger=False)
    pretrained.eval()
    pretrained = pretrained.to('cuda')

    # Guidance classifier; hyperparameters must match the checkpoint's
    # training configuration — TODO confirm against training config.
    bindevaluator = BindEvaluator.load_from_checkpoint(
        config.guidance.classifier_checkpoint_path,
        n_layers=8,
        d_model=128,
        d_hidden=128,
        n_head=8,
        d_k=64,
        d_v=128,
        d_inner=64)
    bindevaluator = bindevaluator.to('cuda')

    # ESM-2 650M in fp16, used only for embedding-based similarity scoring.
    esm = FAEsmForMaskedLM.from_pretrained(
        "facebook/esm2_t33_650M_UR50D").to("cuda").eval().to(torch.float16)

    # Reference embedding of the original binder (mean over sequence
    # positions). no_grad avoids building a needless autograd graph.
    original_binder = config.sampling.original_binder
    with torch.no_grad():
        original_binder_input = esm.tokenizer(
            original_binder, return_tensors="pt")
        original_binder_input = {
            k: v.to('cuda') for k, v in original_binder_input.items()}
        original_binder_outputs = esm(**original_binder_input)
        original_binder_embedding = original_binder_outputs['last_hidden_state']
        original_binder_embedding_avg = torch.mean(
            original_binder_embedding, dim=1)

    samples = []
    # Hoisted out of the loop: the motif spec does not change per batch.
    target_motifs = parse_motif(config.eval.target_motifs)
    for _ in tqdm(
            range(config.sampling.num_sample_batches),
            desc='Gen. batches', leave=False):
        sample = pretrained.sample(
            target_sequence=target_sequence,
            target_motifs=target_motifs,
            classifier_model=bindevaluator)
        sample_decoded = pretrained.tokenizer.batch_decode(sample)
        # Drop tokenizer spaces and the 5-char special tokens at both ends
        # (presumably '<cls>'/'<eos>' — TODO confirm for this tokenizer).
        samples_processed = [seq.replace(' ', '')[5:-5]
                             for seq in sample_decoded]
        print('sample: ', samples_processed)
        samples.extend(samples_processed)

    # FIX: accumulate scores in a list, not a dict keyed by sequence.
    # Duplicate sampled sequences previously collapsed to one dict entry,
    # making the two DataFrame columns below unequal in length.
    samples_similarity = []
    with torch.no_grad():
        for seq in tqdm(samples, desc='Computing similarities'):
            seq_input = esm.tokenizer(seq, return_tensors="pt")
            seq_input = {k: v.to('cuda') for k, v in seq_input.items()}
            seq_output = esm(**seq_input)
            seq_embedding = seq_output['last_hidden_state']
            seq_embedding_avg = torch.mean(seq_embedding, dim=1)
            similarity_score = F.cosine_similarity(
                seq_embedding_avg, original_binder_embedding_avg)
            samples_similarity.append(similarity_score.item())

    outputs_csv = pd.DataFrame({
        'samples': samples,
        'samples_similarity': samples_similarity,
    })
    print("outputs_csv", outputs_csv)
    outputs_csv.to_csv('il2_alpha_guidance.csv', index=False)
|
|
|
|
|
|
# Hydra entry point: parses CLI overrides, composes the config, and runs main().
if __name__ == '__main__':
  main()