| |
| from fuson_plm.utils.embedding import get_esm_embeddings, load_esm2_type, redump_pickle_dictionary, load_prott5, get_prott5_embeddings |
| from fuson_plm.utils.logging import log_update, open_logfile, print_configpy |
| from fuson_plm.utils.data_cleaning import find_invalid_chars |
| from fuson_plm.utils.constants import VALID_AAS |
| from fuson_plm.training.model import FusOnpLM |
| from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel |
| import logging |
| import torch |
| import pickle |
| import os |
| import pandas as pd |
| import numpy as np |
|
|
def validate_sequence_col(df, seq_col):
    """Validate the sequence column of ``df`` without modifying ``df``.

    Checks performed, in order:
      1. ``seq_col`` must be one of the dataframe's columns.
      2. No sequence may contain characters reported by
         ``find_invalid_chars(sequence, VALID_AAS)``.
      3. Duplicate sequences are tolerated but trigger a logged warning.

    Args:
        df: pandas DataFrame holding the sequences.
        seq_col: Name of the column that holds the sequences.

    Returns:
        None.

    Raises:
        ValueError: If the column is missing or any invalid character is
            found. ``ValueError`` is an ``Exception`` subclass, so callers
            that catch ``Exception`` are unaffected.
    """
    if seq_col not in df.columns:
        raise ValueError("Error: provided sequence column does not exist in the input dataframe")

    sequences = df[seq_col]

    # Collect the invalid characters straight from the column. No scratch
    # column is written onto ``df``, so the caller's dataframe is left
    # untouched. NOTE(review): find_invalid_chars presumably returns an
    # iterable (set) of offending characters per sequence -- confirm.
    all_invalid_chars = set().union(
        *(find_invalid_chars(seq, VALID_AAS) for seq in sequences)
    )
    if all_invalid_chars:
        raise ValueError(f"Error: invalid characters {all_invalid_chars} found in the sequence column")

    # Duplicates are not an error here; just let the user know.
    if sequences.duplicated().any():
        log_update("\tWARNING: input data has duplicate sequences")
|
|
def load_fuson_model(ckpt_path):
    """Load the model and tokenizer stored at ``ckpt_path`` for inference.

    Args:
        ckpt_path: Passed unchanged to ``AutoModel.from_pretrained`` and
            ``AutoTokenizer.from_pretrained``.

    Returns:
        tuple: ``(model, tokenizer, device)``. The model has been moved to
        ``device`` and switched to eval mode.
    """
    # Let only ERROR-level (and above) records through from the
    # transformers model-loading logger.
    modeling_logger = logging.getLogger("transformers.modeling_utils")
    modeling_logger.setLevel(logging.ERROR)

    # Use CUDA whenever torch reports it as available, otherwise the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # Model first, then tokenizer, both from the same checkpoint location.
    model = AutoModel.from_pretrained(ckpt_path)
    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)

    # Inference setup: place the weights on the device, enable eval mode.
    model.to(device)
    model.eval()

    return model, tokenizer, device
|
|
def get_fuson_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=2000):
    """Embed sequences one at a time and optionally pickle the results.

    Every sequence is tokenized and passed through ``model`` on its own,
    under ``torch.no_grad()``. The first and last rows of
    ``last_hidden_state`` are dropped (presumably the BOS/EOS special
    tokens -- confirm against the tokenizer) and the rest is converted to
    a numpy array.

    Args:
        model: Callable that accepts the tokenizer output as keyword
            arguments and returns an object with ``last_hidden_state``.
        tokenizer: Callable returning a mapping of tensors for a sequence.
        sequences: Sized iterable of sequence strings (list, Series, ...).
        device: Device the tokenized tensors are moved to.
        average: If True, store the mean over the sequence axis instead of
            the per-position matrix.
        print_updates: If True, report progress through ``log_update``.
        savepath: Pickle path, or None to skip saving. ``'.pkl'`` is
            appended when the path does not already end with it.
        save_at_end: If True, write one dictionary after all sequences are
            done. If False, append a one-entry dictionary per sequence and
            finish by calling ``redump_pickle_dictionary(savepath)``.
        max_length: Tokenizer truncation length. ``None`` means the longest
            sequence length plus 2.

    Returns:
        dict: Maps each sequence to its numpy embedding. Returned even when
        ``savepath`` is None, so in-memory use is possible.
    """
    if savepath is not None and not savepath.endswith('.pkl'):
        savepath += '.pkl'

    if print_updates:
        log_update(f"Dataset contains {len(sequences)} sequences.")

    # Only scan the data when no explicit limit was given; ``default=0``
    # keeps an empty input from raising. The +2 leaves room for the two
    # positions that are stripped again after the forward pass.
    if max_length is None:
        max_length = max((len(s) for s in sequences), default=0) + 2

    save_incrementally = savepath is not None and not save_at_end

    embedding_dict = {}
    # enumerate (rather than positional indexing) also works for containers
    # whose index is not 0..n-1, e.g. a filtered pandas Series.
    for i, sequence in enumerate(sequences):
        with torch.no_grad():
            inputs = tokenizer(sequence, return_tensors="pt", padding=True,
                               truncation=True, max_length=max_length)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            outputs = model(**inputs)

            # Drop the leading batch axis, then the first and last token rows.
            embedding = outputs.last_hidden_state.squeeze(0)[1:-1, :]
            embedding = embedding.cpu().numpy()

        if average:
            embedding = embedding.mean(0)

        embedding_dict[sequence] = embedding

        # Appending per sequence keeps finished work on disk if a later
        # sequence fails. NOTE(review): append mode also keeps whatever an
        # earlier run left at this path -- confirm that is intended.
        if save_incrementally:
            with open(savepath, 'ab+') as f:
                pickle.dump({sequence: embedding}, f)

        if print_updates:
            log_update(f"sequence {i+1}: {sequence[0:10]}...")

    if savepath is not None:
        if save_at_end:
            with open(savepath, 'wb') as f:
                pickle.dump(embedding_dict, f)
        elif embedding_dict:
            # Records were appended one by one; hand the file to the
            # project helper for its final re-dump. Skipped when nothing
            # was appended, because the file may not even exist then.
            redump_pickle_dictionary(savepath)

    return embedding_dict
|
|
def embed_dataset(path_to_file, path_to_output, seq_col='aa_seq', model_type='fuson_plm', fuson_ckpt_path = None, average=True, overwrite=True, print_updates=False,max_length=2000):
    """Embed the unique sequences of a CSV file with the chosen model.

    Args:
        path_to_file: CSV file to read with ``pandas.read_csv``.
        path_to_output: Path passed to the embedding routine as ``savepath``.
        seq_col: Column of the CSV that holds the sequences.
        model_type: One of ``'fuson_plm'``, ``'esm2_t33_650M_UR50D'`` or
            ``'prot_t5_xl_half_uniref50_enc'``.
        fuson_ckpt_path: Checkpoint location; required for ``'fuson_plm'``.
        average: Forwarded to the embedding routine.
        overwrite: If False and ``path_to_output`` already exists, log a
            message and return without doing any work.
        print_updates: Forwarded to the embedding routine.
        max_length: Forwarded to the embedding routine.

    Returns:
        None.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values
            (previously such a call silently produced nothing).
        Exception: If the FusOn-pLM checkpoint path is missing, or if
            loading the model or generating embeddings fails. The original
            error is chained as ``__cause__``.
    """
    if os.path.exists(path_to_output):
        if overwrite:
            # NOTE(review): the FusOn-pLM writer opens this file in append
            # mode, so "overwritten" may in practice mean "merged" --
            # confirm against redump_pickle_dictionary.
            log_update(f"WARNING: these embeddings may already exist at {path_to_output} and will be overwritten")
        else:
            log_update(f"WARNING: these embeddings may already exist at {path_to_output}. Skipping.")
            return None

    # Check the arguments before any file is read, so mistakes fail fast.
    supported_model_types = ('fuson_plm', 'esm2_t33_650M_UR50D', 'prot_t5_xl_half_uniref50_enc')
    if model_type not in supported_model_types:
        raise ValueError(
            f"Unsupported model_type '{model_type}'; expected one of {supported_model_types}"
        )
    # The explicit None test avoids the TypeError os.path.exists(None) raises.
    if model_type == 'fuson_plm' and (
        fuson_ckpt_path is None or not os.path.exists(fuson_ckpt_path)
    ):
        raise Exception("FusOn-pLM ckpt path does not exist")

    dataset = pd.read_csv(path_to_file)
    validate_sequence_col(dataset, seq_col)
    # Each distinct sequence is embedded once.
    sequences = dataset[seq_col].unique().tolist()

    # Choose the loader, the embedding routine and the error texts.
    if model_type == 'fuson_plm':
        load_model = lambda: load_fuson_model(fuson_ckpt_path)
        make_embeddings = get_fuson_embeddings
        load_error = f"Could not load FusOn-pLM from {fuson_ckpt_path}"
        embed_error = "Could not generate FusOn-pLM embeddings"
    elif model_type == 'esm2_t33_650M_UR50D':
        load_model = lambda: load_esm2_type(model_type)
        make_embeddings = get_esm_embeddings
        load_error = f"Could not load {model_type}"
        embed_error = f"Could not generate {model_type} embeddings"
    else:
        load_model = load_prott5
        make_embeddings = get_prott5_embeddings
        load_error = f"Could not load {model_type}"
        embed_error = f"Could not generate {model_type} embeddings"

    # ``except Exception`` (not a bare except) lets KeyboardInterrupt and
    # SystemExit through; ``from err`` keeps the root cause in the traceback.
    try:
        model, tokenizer, device = load_model()
    except Exception as err:
        raise Exception(load_error) from err

    try:
        make_embeddings(model, tokenizer, sequences, device, average=average,
                        print_updates=print_updates, savepath=path_to_output,
                        save_at_end=False, max_length=max_length)
    except Exception as err:
        raise Exception(embed_error) from err
| |
| |
def _iter_fuson_benchmark_jobs(fuson_ckpts):
    """Yield ``(model_name, epoch, ckpt_path, output_dir)`` per FusOn-pLM run.

    Creates the per-model output directory as a side effect. Raises
    ``Exception`` when ``fuson_ckpts`` is neither a dict nor ``"FusOn-pLM"``.
    """
    # isinstance also accepts dict subclasses (OrderedDict, defaultdict, ...).
    if isinstance(fuson_ckpts, dict):
        for model_name, epoch_list in fuson_ckpts.items():
            os.makedirs(f'embeddings/fuson_plm/{model_name}', exist_ok=True)
            for epoch in epoch_list:
                yield (
                    model_name,
                    epoch,
                    f'../../training/checkpoints/{model_name}/checkpoint_epoch_{epoch}',
                    f'embeddings/fuson_plm/{model_name}/epoch{epoch}',
                )
    elif fuson_ckpts == "FusOn-pLM":
        model_name = "best"
        os.makedirs(f'embeddings/fuson_plm/{model_name}', exist_ok=True)
        # This run has no epoch; its checkpoint path is three levels up.
        yield (model_name, None, "../../..", f'embeddings/fuson_plm/{model_name}')
    else:
        raise Exception(
            'Error. fuson_ckpts should be a dict of {model_name: [epochs]} or the string "FusOn-pLM"'
        )


def embed_dataset_for_benchmark(fuson_ckpts=None, input_data_path=None, input_fname=None, average=True, seq_col='seq', benchmark_fusonplm=False, benchmark_esm=False, benchmark_fo_puncta_ml=False, benchmark_prott5=False, overwrite=False,max_length=None):
    """Create benchmark embeddings with each requested model.

    All outputs go under ``embeddings/``, relative to the current working
    directory. Files are named
    ``{input_fname}_{average|2D}_embeddings.pkl``.

    Args:
        fuson_ckpts: Either a dict mapping model name to a list of epochs
            (checkpoints are looked up under
            ``../../training/checkpoints/{model}/checkpoint_epoch_{epoch}``)
            or the string ``"FusOn-pLM"``. Used only when
            ``benchmark_fusonplm`` is True.
        input_data_path: CSV file passed to ``embed_dataset``.
        input_fname: Prefix for the output file names.
        average: Forwarded to ``embed_dataset``; also picks the
            ``average``/``2D`` tag in the file name.
        seq_col: Sequence column name, forwarded to ``embed_dataset``.
        benchmark_fusonplm: Embed with the FusOn-pLM checkpoint(s).
        benchmark_esm: Embed with ``esm2_t33_650M_UR50D``.
        benchmark_fo_puncta_ml: Register ``FOdb_physicochemical_embeddings.pkl``
            in the result; nothing is computed for it here.
        benchmark_prott5: Embed with ``prot_t5_xl_half_uniref50_enc``.
        overwrite: Forwarded to ``embed_dataset``.
        max_length: Forwarded to ``embed_dataset``.

    Returns:
        dict: Maps each embedding file path to a dict with the keys
        ``'model_type'``, ``'model'`` and ``'epoch'``.

    Raises:
        Exception: If ``fuson_ckpts`` has an unsupported form or a
            checkpoint path does not exist.
    """
    os.makedirs('embeddings', exist_ok=True)

    emb_type_tag = 'average' if average else '2D'
    all_embedding_paths = dict()

    # Keyword arguments that are the same for every embed_dataset call.
    shared_kwargs = dict(seq_col=seq_col, average=average, overwrite=overwrite,
                         print_updates=True, max_length=max_length)

    if benchmark_fusonplm:
        os.makedirs('embeddings/fuson_plm', exist_ok=True)
        log_update("\nMaking Fuson-PLM embeddings")

        model_type = 'fuson_plm'
        jobs = _iter_fuson_benchmark_jobs(fuson_ckpts)
        for model_name, epoch, fuson_ckpt_path, embedding_output_dir in jobs:
            if not os.path.exists(fuson_ckpt_path):
                raise Exception(f"Error. Cannot find ckpt path: {fuson_ckpt_path}")

            embedding_output_path = f'{embedding_output_dir}/{input_fname}_{emb_type_tag}_embeddings.pkl'
            os.makedirs(embedding_output_dir, exist_ok=True)

            # Register the path before embedding, as the original did.
            all_embedding_paths[embedding_output_path] = {
                'model_type': model_type,
                'model': model_name,
                'epoch': epoch,
            }

            log_update(f"\tUsing ckpt {fuson_ckpt_path} and saving results to {embedding_output_path}...")
            embed_dataset(input_data_path, embedding_output_path,
                          model_type=model_type, fuson_ckpt_path=fuson_ckpt_path,
                          **shared_kwargs)

    # Baselines that need no checkpoint: (enabled flag, model id, log name).
    baselines = (
        (benchmark_esm, 'esm2_t33_650M_UR50D', 'ESM-2-650M'),
        (benchmark_prott5, 'prot_t5_xl_half_uniref50_enc', 'ProtT5-XL-UniRef50'),
    )
    for enabled, model_type, display_name in baselines:
        if not enabled:
            continue
        os.makedirs(f'embeddings/{model_type}', exist_ok=True)
        embedding_output_path = f'embeddings/{model_type}/{input_fname}_{emb_type_tag}_embeddings.pkl'

        all_embedding_paths[embedding_output_path] = {
            'model_type': model_type,
            'model': model_type,
            'epoch': np.nan,
        }

        log_update(f"\nMaking {display_name} embeddings for {input_data_path} and saving results to {embedding_output_path}...")
        embed_dataset(input_data_path, embedding_output_path,
                      model_type=model_type, fuson_ckpt_path=None,
                      **shared_kwargs)

    if benchmark_fo_puncta_ml:
        # Only registered here; this function does not create the file.
        all_embedding_paths['FOdb_physicochemical_embeddings.pkl'] = {
            'model_type': 'fo_puncta_ml',
            'model': 'fo_puncta_ml',
            'epoch': np.nan,
        }

    return all_embedding_paths