# File size: 5,762 Bytes | c262d88
import os
import hydra
import lightning as L
import numpy as np
import omegaconf
import pandas as pd
import rdkit
import rich.syntax
import rich.tree
import torch
from tqdm.auto import tqdm
import pdb
import torch.nn.functional as F
import dataloader
import diffusion
from models.bindevaluator import BindEvaluator
from transformers import AutoTokenizer, EsmModel
from faesm.esm import FAEsmForMaskedLM
import torch.nn as nn
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel
import numpy as np
# Preferred compute device: CUDA when torch reports it as available, else CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Optional PepMLM setup (currently disabled).
# SECURITY: access tokens must never be committed to source control; read the
# token from the environment instead.
# NOTE(review): a literal token used to be hard-coded on the PEPMLM_TOKEN
# line -- revoke it if it was ever valid.
# PEPMLM_NAME = "ChatterjeeLab/PepMLM-650M"
# PEPMLM_TOKEN = os.environ["HF_TOKEN"]
# PEPMLM_MODEL = AutoModelForMaskedLM.from_pretrained(PEPMLM_NAME, token=PEPMLM_TOKEN)
# pepmlm_tokenizer = AutoTokenizer.from_pretrained(PEPMLM_NAME, token=PEPMLM_TOKEN)
# pepmlm = PEPMLM_MODEL.to(DEVICE)

# Turn off RDKit's 'rdApp.error' log channel.
rdkit.rdBase.DisableLog('rdApp.error')

# Custom OmegaConf resolvers registered for use by the configuration files.
# 'cwd': current working directory at resolution time.
omegaconf.OmegaConf.register_new_resolver(
    'cwd', os.getcwd)
# 'device_count': number of CUDA devices reported by torch.
omegaconf.OmegaConf.register_new_resolver(
    'device_count', torch.cuda.device_count)
# SECURITY: this resolver hands config text to Python's built-in eval().
# Only load configuration files and command-line overrides you trust.
omegaconf.OmegaConf.register_new_resolver(
    'eval', eval)
# 'div_up': (x + y - 1) // y, i.e. ceiling division for positive integers.
omegaconf.OmegaConf.register_new_resolver(
    'div_up', lambda x, y: (x + y - 1) // y)
# 'if_then_else': returns x when condition is truthy, otherwise y.
omegaconf.OmegaConf.register_new_resolver(
    'if_then_else',
    lambda condition, x, y: x if condition else y
)
def _print_config(
        config: omegaconf.DictConfig,
        resolve: bool = True) -> None:
    """Render a DictConfig as a Rich tree and print it.

    Each top-level key of ``config`` becomes one branch. A branch whose value
    is itself a ``DictConfig`` is shown as YAML; any other value is shown via
    ``str()``.

    Args:
        config (DictConfig): Configuration composed by Hydra.
        resolve (bool): Whether to resolve reference fields of DictConfig
            when converting nested sections to YAML.
    """
    dim = 'dim'
    root = rich.tree.Tree('CONFIG', style=dim,
                          guide_style=dim)
    for name in config.keys():
        node = root.add(name, style=dim, guide_style=dim)
        section = config.get(name)
        if isinstance(section, omegaconf.DictConfig):
            # Nested sections are dumped as YAML so they keep their structure.
            text = omegaconf.OmegaConf.to_yaml(
                section, resolve=resolve)
        else:
            text = str(section)
        node.add(rich.syntax.Syntax(text, 'yaml'))
    rich.print(root)
def parse_motif(motif: str) -> torch.Tensor:
    """Convert a motif specification into a tensor of integer positions.

    The specification is a comma-separated list whose items are either a
    single integer (``"7"``) or an inclusive range (``"3-5"``). Whitespace
    around items is ignored and empty items (from an empty string or a stray
    comma) are skipped.

    Args:
        motif: Specification such as ``"1-3, 7, 10-12"``. A non-string value
            (e.g. a bare integer coming from YAML) is converted with
            ``str()`` first.

    Returns:
        A 1-D ``torch.long`` tensor holding the positions in the order they
        were written, with ranges expanded to include both endpoints. The
        tensor is empty when the specification contains no items.

    Raises:
        ValueError: If an item is neither an integer nor a ``start-end`` pair
            of integers.
    """
    positions = []
    for item in str(motif).split(','):
        item = item.strip()
        if not item:
            # Tolerate "" and trailing/double commas instead of failing on int('').
            continue
        if '-' in item:
            start, end = map(int, item.split('-'))
            positions.extend(range(start, end + 1))
        else:
            positions.append(int(item))
    # An explicit dtype keeps the result integral even when `positions` is empty
    # (torch.tensor([]) would otherwise default to a floating-point dtype).
    return torch.tensor(positions, dtype=torch.long)
def _mean_esm_embedding(esm, sequence: str) -> torch.Tensor:
    """Embed one sequence with ``esm`` and average its last hidden state.

    The sequence is tokenized with ``esm.tokenizer``, moved to ``DEVICE``,
    passed through the model, and the ``'last_hidden_state'`` output is
    averaged over dim 1 (presumably the token axis -- TODO confirm).
    Runs under ``torch.no_grad()`` so no autograd graph is kept.
    """
    with torch.no_grad():
        inputs = esm.tokenizer(sequence, return_tensors="pt")
        inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
        hidden = esm(**inputs)['last_hidden_state']
        return torch.mean(hidden, dim=1)


@hydra.main(version_base=None, config_path='./configs',
            config_name='config')
def main(config: omegaconf.DictConfig) -> None:
    """Generate guided samples and score them against a reference binder.

    Steps:
      1. Seed all RNGs and request deterministic torch algorithms.
      2. Load the diffusion model from ``config.eval.checkpoint_path`` and the
         BindEvaluator classifier from
         ``config.guidance.classifier_checkpoint_path``.
      3. Draw ``config.sampling.num_sample_batches`` batches with
         ``pretrained.sample`` and decode them to strings.
      4. Embed every sample and ``config.sampling.original_binder`` with
         ESM-2 (``facebook/esm2_t33_650M_UR50D``) and compute the cosine
         similarity between the mean embeddings.
      5. Write one row per sample to ``il2_alpha_guidance.csv`` in the
         current working directory.

    Args:
        config: Configuration composed by Hydra.
    """
    # Reproducibility
    L.seed_everything(config.seed)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False

    # _print_config(config, resolve=True)
    print(f"Checkpoint: {config.eval.checkpoint_path}")

    tokenizer = dataloader.get_tokenizer(config)
    target_sequence = tokenizer(
        config.eval.target_sequence, return_tensors='pt')['input_ids']

    pretrained = diffusion.Diffusion.load_from_checkpoint(
        config.eval.checkpoint_path,
        tokenizer=tokenizer,
        config=config, logger=False)
    pretrained.eval()
    pretrained = pretrained.to(DEVICE)

    # NOTE(review): unlike `pretrained`, the classifier is never switched to
    # eval mode here -- confirm whether that is intentional.
    bindevaluator = BindEvaluator.load_from_checkpoint(
        config.guidance.classifier_checkpoint_path,
        n_layers=8,
        d_model=128,
        d_hidden=128,
        n_head=8,
        d_k=64,
        d_v=128,
        d_inner=64)
    bindevaluator = bindevaluator.to(DEVICE)

    # ESM with flash attention (FAEsm), run in half precision.
    # Using the 650M checkpoint --> might use a bigger/smaller model.
    esm = FAEsmForMaskedLM.from_pretrained(
        "facebook/esm2_t33_650M_UR50D").to(DEVICE).eval().to(torch.float16)

    # Reference embedding that every generated sample is compared against.
    original_binder = config.sampling.original_binder
    original_binder_embedding_avg = _mean_esm_embedding(esm, original_binder)

    samples = []
    for _ in tqdm(
            range(config.sampling.num_sample_batches),
            desc='Gen. batches', leave=False):
        sample = pretrained.sample(
            target_sequence=target_sequence,
            target_motifs=parse_motif(config.eval.target_motifs),
            classifier_model=bindevaluator
        )
        sample_decoded = pretrained.tokenizer.batch_decode(sample)
        # Remove the spaces between decoded tokens, then drop 5 characters at
        # each end (presumably the special-token text -- TODO confirm).
        samples_processed = [
            seq.replace(' ', '')[5:-5] for seq in sample_decoded]
        print('sample: ', samples_processed)
        samples.extend(samples_processed)

    # Score each distinct sequence once; the dict doubles as a cache so a
    # sequence that was sampled several times is not embedded again.
    samples_similarity = {}
    for seq in tqdm(samples, desc='Computing similarities'):
        if seq in samples_similarity:
            continue
        seq_embedding_avg = _mean_esm_embedding(esm, seq)
        similarity_score = F.cosine_similarity(
            seq_embedding_avg, original_binder_embedding_avg)
        samples_similarity[seq] = similarity_score.item()

    # Look scores up per sample so both columns always have the same length,
    # even when `samples` contains duplicates (dict.values() would be shorter
    # and make the DataFrame constructor raise).
    outputs_csv = pd.DataFrame({
        'samples': list(samples),
        'samples_similarity': [samples_similarity[seq] for seq in samples],
    })
    print("outputs_csv", outputs_csv)
    outputs_csv.to_csv('il2_alpha_guidance.csv', index=False)
# Run the Hydra-decorated entry point only when executed as a script.
if __name__ == '__main__':
    main()