import pickle
import torch
from transformers import EsmModel, AutoTokenizer
from transformers import T5Tokenizer, T5EncoderModel
import logging
from fuson_plm.utils.logging import log_update


def redump_pickle_dictionary(pickle_path):
    """
    Loads a pickle file containing incrementally appended dictionary entries and redumps them as a
    single dictionary at the same location. This allows a clean reset for a pickle built with 'ab+'.
    """
    entries = {}

    with open(pickle_path, 'rb') as f:
        while True:
            try:
                entry = pickle.load(f)
                entries.update(entry)
            except EOFError:
                break
            except Exception as e:
                print(f"An error occurred: {e}")
                break

    with open(pickle_path, 'wb') as f:
        pickle.dump(entries, f)

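
# Illustrative sketch (not part of the original module): redump_pickle_dictionary exists because the
# embedding functions below append one single-entry dict per sequence with pickle.dump in 'ab+' mode;
# redumping collapses those appended records into a single dictionary. The file path and toy values
# here are assumptions for demonstration only.
def _example_redump_usage(path="embeddings_demo.pkl"):
    with open(path, 'ab+') as f:
        pickle.dump({"MKT": [0.1, 0.2]}, f)   # appended record 1
        pickle.dump({"GAV": [0.3, 0.4]}, f)   # appended record 2
    redump_pickle_dictionary(path)            # file now holds one merged dict
    with open(path, 'rb') as f:
        return pickle.load(f)                 # {'MKT': [0.1, 0.2], 'GAV': [0.3, 0.4]}
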
def load_esm2_type(esm_type, device=None):
    """
    Loads the specified ESM-2 model and tokenizer (e.g. esm2_t33_650M_UR50D).
    """
    # Suppress the transformers weight-initialization warnings emitted when loading the encoder
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

    model = EsmModel.from_pretrained(f"facebook/{esm_type}")
    tokenizer = AutoTokenizer.from_pretrained(f"facebook/{esm_type}")

    model.to(device)
    model.eval()

    return model, tokenizer, device

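
# Illustrative sketch (assumed usage, not part of the original module): load one ESM-2 checkpoint and
# inspect its embedding size. "esm2_t33_650M_UR50D" is just an example identifier; any
# facebook/esm2_* checkpoint name should work with load_esm2_type.
def _example_load_esm2():
    model, tokenizer, device = load_esm2_type("esm2_t33_650M_UR50D")
    return model.config.hidden_size   # 1280 for the t33_650M checkpoint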
|
def load_prott5():
    """
    Loads the ProtT5 encoder (Rostlab/prot_t5_xl_half_uniref50-enc) and its tokenizer.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)
    model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc")
    # The checkpoint is stored in half precision; cast to float32 when running on CPU
    if device == torch.device('cpu'):
        model.to(torch.float32)
    model.to(device)
    model.eval()
    return model, tokenizer, device

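# Illustrative sketch (assumed usage, not part of the original module): the ProtT5 checkpoint is
# encoder-only and stored in half precision, so load_prott5 keeps it in fp16 on GPU and casts to
# fp32 on CPU. Its per-residue embeddings are 1024-dimensional.
def _example_load_prott5():
    model, tokenizer, device = load_prott5()
    return model.config.d_model   # 1024 for prot_t5_xl_half_uniref50-enc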
|
def get_esm_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=None):
    """
    Compute ESM-2 embeddings for a list of sequences.

    Args:
        model: ESM-2 model (e.g. from load_esm2_type)
        tokenizer: matching ESM-2 tokenizer
        sequences: list of amino acid sequences to embed
        device: torch device the model is on
        average: if True, per-residue embeddings are mean-pooled into a single vector per sequence
        print_updates: if True, log progress after each sequence
        savepath: if not None, the embeddings will be saved to this path. It must be a pickle (.pkl)
        save_at_end: if True, dump all embeddings once at the end; otherwise append after each sequence
        max_length: maximum tokenized length; defaults to the longest sequence plus 2 (CLS and EOS)
    """
    # Make sure the save path ends in .pkl
    if savepath is not None:
        if not savepath.endswith('.pkl'): savepath += '.pkl'

    # Leave room for the CLS and EOS tokens by default
    max_seq_len = max([len(s) for s in sequences])
    if max_length is None: max_length = max_seq_len + 2

    embedding_dict = {}

    for i in range(len(sequences)):
        sequence = sequences[i]

        with torch.no_grad():
            inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            outputs = model(**inputs)
            embedding = outputs.last_hidden_state

            # Remove the batch dimension
            embedding = embedding.squeeze(0)
            # Remove the CLS and EOS token embeddings
            embedding = embedding[1:-1, :]

            # Convert to a numpy array on CPU
            embedding = embedding.cpu().numpy()

            # Mean-pool over residues if requested
            if average:
                embedding = embedding.mean(0)

            embedding_dict[sequence] = embedding

            # Incrementally append this entry to the pickle unless saving everything at the end
            if savepath is not None and not save_at_end:
                with open(savepath, 'ab+') as f:
                    d = {sequence: embedding}
                    pickle.dump(d, f)

        if print_updates: log_update(f"sequence {i+1}: {sequence[0:10]}...")

    if savepath is not None:
        if save_at_end:
            # Dump the full dictionary in one shot
            with open(savepath, 'wb') as f:
                pickle.dump(embedding_dict, f)
        else:
            # Consolidate the incrementally appended entries into one dictionary
            redump_pickle_dictionary(savepath)

    return embedding_dict

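# Illustrative sketch (assumed usage, not part of the original module): embed two toy sequences with
# ESM-2, mean-pooling each into a single vector, and append results to a pickle as they are computed.
# The sequences and save path are placeholders.
def _example_esm_embedding_run():
    model, tokenizer, device = load_esm2_type("esm2_t33_650M_UR50D")
    sequences = ["MKTAYIAKQR", "GAVLIMCFWY"]
    embeddings = get_esm_embeddings(model, tokenizer, sequences, device,
                                    average=True, print_updates=True,
                                    savepath="esm_embeddings_demo.pkl")
    # each value is a 1280-dim numpy vector keyed by its raw sequence
    return embeddings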
|
def get_prott5_embeddings(model, tokenizer, sequences, device, average=True, print_updates=False, savepath=None, save_at_end=False, max_length=None):
    """
    Compute ProtT5 embeddings for a list of sequences. Arguments mirror get_esm_embeddings.
    """
    # Make sure the save path ends in .pkl
    if savepath is not None:
        if not savepath.endswith('.pkl'): savepath += '.pkl'

    # Leave room for special tokens by default
    max_seq_len = max([len(s) for s in sequences])
    if max_length is None: max_length = max_seq_len + 2

    # ProtT5 expects amino acids separated by spaces (e.g. "M K T ...")
    spaced_sequences = [' '.join(list(seq)) for seq in sequences]

    embedding_dict = {}

    for i in range(0, len(spaced_sequences)):
        spaced_sequence = spaced_sequences[i]
        seq = spaced_sequence.replace(" ", "")

        with torch.no_grad():
            inputs = tokenizer(spaced_sequence, return_tensors="pt", add_special_tokens=True, truncation=True, max_length=max_length)
            inputs = {k: v.to(device) for k, v in inputs.items()}

            embedding_repr = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])

            # Remove the batch dimension and the trailing </s> token embedding
            seq_length = len(seq)
            embedding = embedding_repr.last_hidden_state.squeeze(0)
            embedding = embedding[0:-1]
            embedding = embedding.cpu().numpy()
            embedding_log = f"\tembedding shape: {embedding.shape}"

            # Sanity checks: ProtT5 hidden size is 1024, and one embedding per residue should remain
            assert embedding.shape[1] == 1024
            assert embedding.shape[0] == seq_length

            # Mean-pool over residues if requested
            if average:
                dim_before = embedding.shape
                embedding = embedding.mean(0)
                embedding_log = f"\tembedding shape before avg: {dim_before}\tafter avg: {embedding.shape}"

            embedding_dict[seq] = embedding

            # Incrementally append this entry to the pickle unless saving everything at the end
            if savepath is not None and not save_at_end:
                with open(savepath, 'ab+') as f:
                    d = {seq: embedding}
                    pickle.dump(d, f)

        if print_updates: log_update(f"sequence {i+1}: {seq[0:10]}...{embedding_log}\t seq len: {seq_length}")

    if savepath is not None:
        if save_at_end:
            # Dump the full dictionary in one shot
            with open(savepath, 'wb') as f:
                pickle.dump(embedding_dict, f)
        else:
            # Consolidate the incrementally appended entries into one dictionary
            redump_pickle_dictionary(savepath)

    return embedding_dict
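
# Illustrative sketch (assumed usage, not part of the original module): the ProtT5 pipeline mirrors
# the ESM-2 one, except sequences are space-separated before tokenization (handled internally) and
# the resulting vectors are 1024-dimensional. The sequences and save path are placeholders.
def _example_prott5_embedding_run():
    model, tokenizer, device = load_prott5()
    sequences = ["MKTAYIAKQR", "GAVLIMCFWY"]
    embeddings = get_prott5_embeddings(model, tokenizer, sequences, device,
                                       average=True, print_updates=True,
                                       savepath="prott5_embeddings_demo.pkl")
    return embeddings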