# moPPIt-v2 / sample_emb_guidance.py
# Author: Kseniia-Kholina
# Sampling script with ESM embedding guidance (commit c262d88).
import os
import hydra
import lightning as L
import numpy as np
import omegaconf
import pandas as pd
import rdkit
import rich.syntax
import rich.tree
import torch
from tqdm.auto import tqdm
import pdb
import torch.nn.functional as F
import dataloader
import diffusion
from models.bindevaluator import BindEvaluator
from transformers import AutoTokenizer, EsmModel
from faesm.esm import FAEsmForMaskedLM
import torch.nn as nn
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel
import numpy as np
# Prefer GPU when available; note the sampling code below assumes CUDA anyway.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Legacy PepMLM scoring path, currently disabled.
# SECURITY NOTE(review): the original comment embedded a real HF access token;
# it has been redacted here — never commit credentials, load them from the env.
# PEPMLM_NAME = "ChatterjeeLab/PepMLM-650M"
# PEPMLM_TOKEN = "<HF_ACCESS_TOKEN>"  # place your access token here (e.g. os.environ["HF_TOKEN"])
# PEPMLM_MODEL = AutoModelForMaskedLM.from_pretrained(PEPMLM_NAME, token=PEPMLM_TOKEN)
# pepmlm_tokenizer = AutoTokenizer.from_pretrained(PEPMLM_NAME, token=PEPMLM_TOKEN)
# pepmlm = PEPMLM_MODEL.to(DEVICE)
# Silence RDKit's per-molecule parse error spam.
rdkit.rdBase.DisableLog('rdApp.error')
# OmegaConf resolvers referenced by the Hydra configs under ./configs.
omegaconf.OmegaConf.register_new_resolver(
    'cwd', os.getcwd)
omegaconf.OmegaConf.register_new_resolver(
    'device_count', torch.cuda.device_count)
# NOTE(review): `eval` as a resolver executes arbitrary config strings —
# acceptable only because configs are trusted, repo-local files.
omegaconf.OmegaConf.register_new_resolver(
    'eval', eval)
# Ceiling division: div_up(x, y) == ceil(x / y) for positive ints.
omegaconf.OmegaConf.register_new_resolver(
    'div_up', lambda x, y: (x + y - 1) // y)
omegaconf.OmegaConf.register_new_resolver(
    'if_then_else',
    lambda condition, x, y: x if condition else y
)
def _print_config(
        config: omegaconf.DictConfig,
        resolve: bool = True) -> None:
    """Pretty-print a Hydra config as a Rich tree, one branch per top key.

    Args:
        config: Configuration composed by Hydra.
        resolve: Whether to resolve interpolation fields before printing.
    """
    dim = 'dim'
    root = rich.tree.Tree('CONFIG', style=dim, guide_style=dim)
    for key in config.keys():
        section = config.get(key)
        # Nested configs render as YAML; scalars fall back to str().
        if isinstance(section, omegaconf.DictConfig):
            rendered = omegaconf.OmegaConf.to_yaml(section, resolve=resolve)
        else:
            rendered = str(section)
        branch = root.add(key, style=dim, guide_style=dim)
        branch.add(rich.syntax.Syntax(rendered, 'yaml'))
    rich.print(root)
def parse_motif(motif: str) -> torch.Tensor:
    """Parse a motif specification into a 1-D tensor of residue indices.

    Accepts a comma-separated mix of single integers and inclusive
    "start-end" ranges, e.g. ``"3, 7-9, 12"`` -> ``tensor([3, 7, 8, 9, 12])``.

    Args:
        motif: Comma-separated indices and/or inclusive ranges.

    Returns:
        A 1-D ``torch.Tensor`` (int64) of positions in the order given.
        (Fix: the original annotation claimed ``list``, but the function
        has always returned a tensor.)
    """
    indices: list[int] = []
    for part in motif.split(','):
        part = part.strip()
        if not part:
            # Tolerate trailing or doubled commas instead of crashing on int('').
            continue
        if '-' in part:
            start, end = map(int, part.split('-'))
            indices.extend(range(start, end + 1))
        else:
            indices.append(int(part))
    # Explicit dtype keeps the empty-input case int64 like the non-empty case.
    return torch.tensor(indices, dtype=torch.long)
@hydra.main(version_base=None, config_path='./configs',
            config_name='config')
def main(config: omegaconf.DictConfig) -> None:
    """Sample peptide binders with ESM-embedding guidance and score them.

    Loads a pretrained diffusion model and a BindEvaluator classifier,
    generates binder candidates conditioned on the target sequence and
    motifs, then scores every candidate by cosine similarity between its
    mean-pooled ESM-2 embedding and that of a reference binder. Results
    are written to 'il2_alpha_guidance.csv'.

    Args:
        config: Hydra-composed configuration (seed, eval.*, guidance.*,
            sampling.* fields are read below).
    """
    # Reproducibility: fixed seed plus deterministic CUDA kernels.
    L.seed_everything(config.seed)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.benchmark = False
    # _print_config(config, resolve=True)
    print(f"Checkpoint: {config.eval.checkpoint_path}")

    tokenizer = dataloader.get_tokenizer(config)
    target_sequence = tokenizer(config.eval.target_sequence, return_tensors='pt')['input_ids']

    # Pretrained diffusion sampler (inference only).
    pretrained = diffusion.Diffusion.load_from_checkpoint(
        config.eval.checkpoint_path,
        tokenizer=tokenizer,
        config=config, logger=False)
    pretrained.eval()
    pretrained = pretrained.to('cuda')

    # Classifier used for guidance during sampling; hyperparameters must
    # match those used when the checkpoint was trained.
    bindevaluator = BindEvaluator.load_from_checkpoint(
        config.guidance.classifier_checkpoint_path,
        n_layers=8,
        d_model=128,
        d_hidden=128,
        n_head=8,
        d_k=64,
        d_v=128,
        d_inner=64)
    bindevaluator = bindevaluator.to('cuda')

    # ESM-2 650M via FAESM (flash attention); fp16 to cut memory.
    # Might later swap for a bigger/smaller model.
    esm = FAEsmForMaskedLM.from_pretrained("facebook/esm2_t33_650M_UR50D").to("cuda").eval().to(torch.float16)

    samples = []
    # Reference binder: mean-pool its ESM embedding into one vector.
    original_binder = config.sampling.original_binder
    original_binder_input = esm.tokenizer(original_binder, return_tensors="pt")
    original_binder_input = {k: v.to('cuda') for k, v in original_binder_input.items()}
    original_binder_outputs = esm(**original_binder_input)
    original_binder_embedding = original_binder_outputs['last_hidden_state']
    original_binder_embedding_avg = torch.mean(original_binder_embedding, dim=1)

    for _ in tqdm(
            range(config.sampling.num_sample_batches),
            desc='Gen. batches', leave=False):
        sample = pretrained.sample(
            target_sequence=target_sequence,
            target_motifs=parse_motif(config.eval.target_motifs),
            classifier_model=bindevaluator
        )
        sample_decoded = pretrained.tokenizer.batch_decode(sample)
        # Drop spaces and the leading/trailing 5-char special tokens
        # (presumably [CLS]/[EOS] renderings — TODO confirm with tokenizer).
        samples_processed = [seq.replace(' ', '')[5:-5] for seq in sample_decoded]
        print('sample: ', samples_processed)
        samples.extend(samples_processed)

    # One similarity score per unique sequence (dict deduplicates).
    samples_similarity = {}
    with torch.no_grad():
        for seq in tqdm(samples, desc='Computing similarities'):
            seq_input = esm.tokenizer(seq, return_tensors="pt")
            seq_input = {k: v.to('cuda') for k, v in seq_input.items()}
            seq_output = esm(**seq_input)
            seq_embedding = seq_output['last_hidden_state']
            seq_embedding_avg = torch.mean(seq_embedding, dim=1)
            similarity_score = F.cosine_similarity(seq_embedding_avg, original_binder_embedding_avg)
            samples_similarity[seq] = similarity_score.item()

    # BUG FIX: `samples` can contain duplicate sequences while the dict keeps
    # one entry per unique sequence; zipping list(samples) against
    # dict.values() misaligns rows and raises a length-mismatch ValueError
    # in pandas whenever a duplicate occurs. Look up each sample's score.
    outputs_csv = pd.DataFrame({
        'samples': samples,
        'samples_similarity': [samples_similarity[seq] for seq in samples]
    })
    print("outputs_csv", outputs_csv)
    outputs_csv.to_csv('il2_alpha_guidance.csv', index=False)
# Script entry point: Hydra parses CLI overrides and injects the composed config.
if __name__ == '__main__':
    main()