"""Embed protein (transcriptional regulator) sequences and save the embeddings.

Reads a JSON mapping of ``{seq_id: sequence}`` from ``cfg.data_task.input_file``,
embeds the sequences with the model named by ``cfg.data_task.prot_model`` and
writes the result with ``embed_and_save``.
"""

import json
import os
from pathlib import Path

import pandas as pd
import rootutils
import torch
from omegaconf import DictConfig, OmegaConf

from dpacman.data_tasks.embeddings import get_embedder
from dpacman.utils import pylogger

from .utils import embed_and_save

root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
logger = pylogger.RankedLogger(__name__, rank_zero_only=True)

# Maximum number of sequences embedded when ``cfg.data_task.debug`` is set.
_DEBUG_SAMPLE_SIZE = 5


def _select_device(requested: str) -> str:
    """Return ``"cuda"`` when ``requested == "gpu"`` and CUDA is available, else ``"cpu"``."""
    if requested != "gpu":
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    # Previously this fallback was silent; make it visible in the logs.
    logger.warning("GPU requested but CUDA is not available; falling back to CPU")
    return "cpu"


def _load_sequences(input_file: Path) -> pd.DataFrame:
    """Load a ``{seq_id: sequence}`` JSON mapping into a DataFrame.

    Args:
        input_file: path to the ``.json`` file to read.

    Returns:
        DataFrame with columns ``seq_id`` and ``sequence``, in file order.

    Raises:
        ValueError: if ``input_file`` does not have a ``.json`` suffix.
    """
    if input_file.suffix != ".json":
        # Only JSON input is supported; fail loudly instead of continuing
        # with no data loaded.
        raise ValueError(f"Unsupported input file type (expected .json): {input_file}")
    with open(input_file, "r", encoding="utf-8") as f:
        records = json.load(f)
    # Build the two columns explicitly so an empty mapping still yields a
    # correctly-named (empty) frame.
    return pd.DataFrame(
        {"seq_id": list(records.keys()), "sequence": list(records.values())}
    )


def main(cfg: DictConfig):
    """Embed the protein sequences in ``cfg.data_task.input_file`` and save them.

    Args:
        cfg: config whose ``data_task`` node provides ``prot_model``,
            ``input_file`` and ``out_dir`` (both paths relative to the project
            root), ``device`` (``"gpu"`` requests CUDA), ``debug``,
            ``batch_size`` and ``save_as_shelf``.

    Raises:
        ValueError: if the input file is not a ``.json`` file.
    """
    task = cfg.data_task
    logger.info(
        f"Making embeddings using {task.prot_model} for protein sequences at {task.input_file}"
    )

    # make out dir if necessary
    out_dir = Path(root) / task.out_dir
    os.makedirs(out_dir, exist_ok=True)

    device = _select_device(task.device)
    logger.info(f"Using device: {device}")

    df = _load_sequences(Path(root) / task.input_file)
    if task.debug:
        # Cap at the available number of rows; df.sample raises if n > len(df).
        n_debug = min(_DEBUG_SAMPLE_SIZE, len(df))
        logger.info(f"DEBUG MODE. Only embedding {n_debug} sequences")
        df = df.sample(n=n_debug, random_state=42).reset_index(drop=True)

    # Crucial: sort by sequence length so padding supports pre-batching.
    # A stable sort keeps equal-length sequences in input order, so the
    # resulting order is reproducible.
    df["sequence_length"] = df["sequence"].str.len()
    df = df.sort_values(
        by="sequence_length", ascending=True, kind="stable"
    ).reset_index(drop=True)

    tr_seqs = df["sequence"].tolist()
    tr_ids = df["seq_id"].tolist()
    logger.info(
        f"Embedding {len(tr_ids)} transcriptional regulators (TRs) from processed remap data"
    )

    # Get the protein (not DNA) embedder.
    prot_embedder = get_embedder(task.prot_model, device, for_dna=False)
    logger.info(f"Device of embedding model: {prot_embedder.device}")

    # Build the file name directly: str.replace on the full path would also
    # rewrite any ".pkl" occurring elsewhere in the path.
    debug_tag = "_debug" if task.debug else ""
    out_trs = str(out_dir / f"trs_{task.prot_model}{debug_tag}.pkl")

    embed_and_save(
        tr_seqs,
        tr_ids,
        prot_embedder,
        out_trs,
        batch_size=task.batch_size,
        save_as_shelf=task.save_as_shelf,
    )
    logger.info(f"Finished embedding protein sequences. Saved to: {out_trs}")


if __name__ == "__main__":
    # main() requires a config. Build one from dotted command-line overrides,
    # e.g. data_task.prot_model=... data_task.input_file=...
    main(OmegaConf.from_cli())