# Utils: plotting setup (Ubuntu fonts), ESM-2 embedding extraction, and one-hot encodings
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
from matplotlib.font_manager import FontProperties
import pickle
import torch
from transformers import EsmModel, AutoTokenizer
import logging
import os
import numpy as np


def set_font():
    # Load and register the Ubuntu fonts; they are expected to live next to this script
    utils_dir = os.path.dirname(os.path.abspath(__file__))
    font_dir = os.path.join(utils_dir, "ubuntu_font")  # adjust as needed

    # Paths for regular, bold, italic, and bold-italic fonts
    regular_font_path = os.path.join(font_dir, "Ubuntu-Regular.ttf")
    bold_font_path = os.path.join(font_dir, "Ubuntu-Bold.ttf")
    italic_font_path = os.path.join(font_dir, "Ubuntu-Italic.ttf")
    bold_italic_font_path = os.path.join(font_dir, "Ubuntu-BoldItalic.ttf")

    # Load the font properties
    regular_font = FontProperties(fname=regular_font_path)
    bold_font = FontProperties(fname=bold_font_path)
    italic_font = FontProperties(fname=italic_font_path)
    bold_italic_font = FontProperties(fname=bold_italic_font_path)

    # Add the fonts to the font manager
    fm.fontManager.addfont(regular_font_path)
    fm.fontManager.addfont(bold_font_path)
    fm.fontManager.addfont(italic_font_path)
    fm.fontManager.addfont(bold_italic_font_path)

    # Set the font family globally to Ubuntu, including math text
    plt.rcParams["font.family"] = regular_font.get_name()
    plt.rcParams["mathtext.fontset"] = "custom"
    plt.rcParams["mathtext.rm"] = regular_font.get_name()
    plt.rcParams["mathtext.it"] = italic_font.get_name()
    plt.rcParams["mathtext.bf"] = bold_font.get_name()


def redump_pickle_dictionary(pickle_path):
    """
    Loads a pickled dictionary built incrementally (with 'ab+') and redumps it in place
    as a single dictionary, so it can later be read back with one pickle.load call.
    """
    entries = {}
    # Load the appended entries one by one
    with open(pickle_path, "rb") as f:
        while True:
            try:
                entry = pickle.load(f)
                entries.update(entry)
            except EOFError:
                break  # End of file reached
            except Exception as e:
                print(f"An error occurred: {e}")
                break
    # Redump everything as a single dictionary
    with open(pickle_path, "wb") as f:
        pickle.dump(entries, f)


def load_esm2_type(esm_type, device=None):
    """
    Loads the specified ESM-2 model version (e.g. esm2_t33_650M_UR50D) and its tokenizer.
    """
    # Suppress warnings about newly initialized 'esm.pooler.dense.bias' /
    # 'esm.pooler.dense.weight' layers - the pooler is not used to extract embeddings
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = EsmModel.from_pretrained(f"facebook/{esm_type}")
    tokenizer = AutoTokenizer.from_pretrained(f"facebook/{esm_type}")
    model.to(device)
    model.eval()  # disables dropout for deterministic results

    return model, tokenizer, device
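

# Example (illustrative sketch): register the Ubuntu fonts with matplotlib and load an
# ESM-2 model before plotting or embedding work. The checkpoint name "esm2_t6_8M_UR50D"
# is an assumed placeholder; any ESM-2 variant hosted under the HuggingFace "facebook/"
# namespace should work the same way.
def _example_setup(esm_type="esm2_t6_8M_UR50D"):
    set_font()  # sets Ubuntu as the global matplotlib font
    model, tokenizer, device = load_esm2_type(esm_type)
    return model, tokenizer, device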


def get_esm_embeddings(
    model,
    tokenizer,
    sequences,
    device,
    average=True,
    print_updates=False,
    savepath=None,
    save_at_end=False,
    max_length=None,
):
    """
    Compute ESM embeddings for a list of protein sequences.

    Args:
        model: ESM-2 model returned by load_esm2_type
        tokenizer: matching tokenizer
        sequences: list of protein sequences (strings)
        device: torch device the model lives on
        average: if True, the per-residue embeddings are mean-pooled over the length dimension
        print_updates: if True, print a short progress message per sequence
        savepath: if not None, the embeddings are saved to this path; it must be a pickle (.pkl)
        save_at_end: if True, dump the whole dictionary once at the end instead of appending per sequence
        max_length: maximum tokenized length; defaults to the longest sequence plus BOS and EOS

    Returns:
        dict mapping each sequence to its embedding
    """
    # Correct the save path to a pickle extension if necessary
    if savepath is not None and not savepath.endswith(".pkl"):
        savepath += ".pkl"

    # If no max length was passed, use the longest sequence in the dataset
    max_seq_len = max(len(s) for s in sequences)
    if max_length is None:
        max_length = max_seq_len + 2  # +2 for BOS, EOS

    # Dictionary to store the ESM embeddings
    embedding_dict = {}

    # Iterate through the sequences
    for i, sequence in enumerate(sequences):
        # Get the embeddings
        with torch.no_grad():
            inputs = tokenizer(
                sequence,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = model(**inputs)
            embedding = outputs.last_hidden_state

        # Remove the batch dimension
        embedding = embedding.squeeze(0)
        # Remove the BOS and EOS tokens
        embedding = embedding[1:-1, :]
        # Convert to a numpy array
        embedding = embedding.cpu().numpy()

        # Mean-pool over the length dimension (if requested)
        if average:
            embedding = embedding.mean(0)

        # Add to dictionary
        embedding_dict[sequence] = embedding

        # Append this embedding to the pickle as we go (if requested)
        if savepath is not None and not save_at_end:
            with open(savepath, "ab+") as f:
                pickle.dump({sequence: embedding}, f)

        # Print update (if requested)
        if print_updates:
            print(f"sequence {i+1}: {sequence[0:10]}...")

    # Dump all at once at the end (if requested)
    if savepath is not None:
        # Saving for the first time: just dump the whole dictionary
        if save_at_end:
            with open(savepath, "wb") as f:
                pickle.dump(embedding_dict, f)
        # We have been appending all along and made it here without crashing:
        # redump the pickle so it can be loaded as a single dictionary
        else:
            redump_pickle_dictionary(savepath)

    # Return the dictionary
    return embedding_dict
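

# Example (illustrative sketch): compute mean-pooled ESM-2 embeddings for a couple of toy
# sequences and append them to a pickle as they are computed. The checkpoint name, the
# sequences, and the "esm_embeddings.pkl" path are assumed placeholders.
def _example_esm_embeddings():
    model, tokenizer, device = load_esm2_type("esm2_t6_8M_UR50D")
    sequences = ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", "MVLSPADKTNVKAAWGKVGAHAGEYGAEALERM"]
    return get_esm_embeddings(
        model,
        tokenizer,
        sequences,
        device,
        average=True,  # one mean-pooled vector per sequence
        print_updates=True,
        savepath="esm_embeddings.pkl",  # appended per sequence, redumped at the end
    )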


def one_hot_encode_sequence(seq, add=True, max_len=None):
    """
    One-hot encode a single protein sequence, padding with zeros or truncating to max_len.

    Parameters:
        seq: protein sequence (string of amino acids)
        add: if True, sum the one-hot matrix over the length dimension (amino-acid counts)
        max_len: desired fixed length; defaults to len(seq)

    Returns:
        2D numpy array of shape (max_len, 20), or a length-20 vector if add is True
    """
    AA_ORDER = list("ACDEFGHIKLMNPQRSTVWY")
    AA_TO_IDX = {aa: i for i, aa in enumerate(AA_ORDER)}

    if max_len is None:
        max_len = len(seq)

    one_hot = np.zeros((max_len, len(AA_ORDER)), dtype=np.float32)
    for i, aa in enumerate(seq[:max_len]):
        if aa in AA_TO_IDX:
            one_hot[i, AA_TO_IDX[aa]] = 1.0

    # Sum across the length dimension (if requested)
    if add:
        one_hot = np.sum(one_hot, axis=0)

    return one_hot


def get_one_hot_embeddings(
    sequences,
    print_updates=False,
    savepath=None,
    add=True,
    save_at_end=False,
    max_length=None,
):
    """
    Compute one-hot embeddings for a list of protein sequences.
    Mirrors get_esm_embeddings: optionally saves to a pickle, per sequence or all at once.
    """
    # Correct the save path to a pickle extension if necessary
    if savepath is not None and not savepath.endswith(".pkl"):
        savepath += ".pkl"

    # If no max length was passed, use the longest sequence in the dataset
    # (one-hot encodings have no BOS/EOS tokens)
    if max_length is None:
        max_length = max(len(s) for s in sequences)

    # Dictionary to store the one-hot embeddings
    embedding_dict = {}

    # Iterate through the sequences
    for i, sequence in enumerate(sequences):
        embedding = one_hot_encode_sequence(sequence, add=add, max_len=max_length)

        # Add to dictionary
        embedding_dict[sequence] = embedding

        # Append this embedding to the pickle as we go (if requested)
        if savepath is not None and not save_at_end:
            with open(savepath, "ab+") as f:
                pickle.dump({sequence: embedding}, f)

        # Print update (if requested)
        if print_updates:
            print(f"sequence {i+1}: {sequence[0:10]}...")

    # Dump all at once at the end (if requested)
    if savepath is not None:
        # Saving for the first time: just dump the whole dictionary
        if save_at_end:
            with open(savepath, "wb") as f:
                pickle.dump(embedding_dict, f)
        # We have been appending all along: redump the pickle so it loads as a single dictionary
        else:
            redump_pickle_dictionary(savepath)

    # Return the dictionary
    return embedding_dict
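

# Example (illustrative sketch): one-hot encode toy sequences, both as zero-padded
# per-residue matrices and as summed amino-acid count vectors. The sequences below
# are assumed placeholders.
def _example_one_hot():
    sequences = ["MKTAYIAKQR", "MVLS"]
    # add=False: one (max_length x 20) matrix per sequence, zero-padded to the longest sequence
    matrices = get_one_hot_embeddings(sequences, add=False)
    # add=True: one length-20 amino-acid count vector per sequence
    counts = get_one_hot_embeddings(sequences, add=True)
    return matrices, counts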