# svincoff's picture
# training works
# 29899b4
"""
Embed DNA sequences from ReMap peaks.
"""
from .utils import pad_token_embeddings, embed_and_save
from dpacman.data_tasks.embeddings import get_embedder
import os
import torch
import json
import pandas as pd
from pathlib import Path
from omegaconf import DictConfig
import rootutils
from dpacman.utils import pylogger
# Resolve the project root relative to this file using the ".project-root"
# indicator. NOTE(review): `pythonpath=True` presumably also puts the root on
# the import path -- confirm against the rootutils documentation.
root = rootutils.setup_root(
    __file__,
    indicator=".project-root",
    pythonpath=True,
)

# Module-level logger. NOTE(review): `rank_zero_only=True` presumably limits
# output to the rank-zero process in multi-process runs -- confirm in
# dpacman.utils.pylogger.
logger = pylogger.RankedLogger(
    __name__,
    rank_zero_only=True,
)
def _pick_device(requested):
    """Map the configured device name onto a torch device string.

    Returns "cuda" only when "gpu" was requested and CUDA is available;
    every other case resolves to "cpu".
    """
    if requested != "gpu":
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    # Keep the CPU fallback, but report it instead of degrading silently.
    logger.info("GPU requested but CUDA is not available; falling back to cpu")
    return "cpu"


def _load_peak_table(input_file):
    """Load a JSON mapping of sequence ID -> sequence into a two-column DataFrame.

    Raises:
        ValueError: if the file is not a ``.json`` file or holds no sequences.
    """
    if not str(input_file).endswith(".json"):
        # Without this guard the table would never be bound and the caller
        # would crash later with an unrelated-looking UnboundLocalError.
        raise ValueError(
            f"Unsupported input file format (expected .json): {input_file}"
        )

    with open(input_file, "r", encoding="utf-8") as f:
        seq_by_id = json.load(f)

    if not seq_by_id:
        # An empty mapping would otherwise fail on the column rename below
        # with an opaque pandas length-mismatch error.
        raise ValueError(f"No sequences found in {input_file}")

    table = pd.DataFrame.from_dict(seq_by_id, orient="index").reset_index()
    table.columns = ["seq_id", "sequence"]
    return table


def main(cfg: DictConfig):
    """Embed the DNA peak sequences from the configured input file and save them.

    Fields read from ``cfg.data_task``:
        chrom_model: embedding model name; passed to ``get_embedder`` and used
            in the output file name.
        input_file: path to a ``.json`` file mapping sequence IDs to sequences,
            joined onto the project root.
        out_dir: output directory, joined onto the project root; created if
            it does not exist.
        device: ``"gpu"`` requests CUDA (falls back to CPU when unavailable);
            any other value uses CPU.
        debug: when truthy, embed a fixed-seed random sample of at most 5
            sequences and write to a ``_debug`` output file.
        batch_size: forwarded to ``embed_and_save``.

    Raises:
        ValueError: if the input file is not ``.json`` or contains no sequences.
    """
    task_cfg = cfg.data_task
    logger.info(
        f"Making embeddings using {task_cfg.chrom_model} for dna sequences at {task_cfg.input_file}"
    )

    # make out dir if necessary
    out_dir = Path(root) / task_cfg.out_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    # set device
    device = _pick_device(task_cfg.device)
    logger.info(f"Using device: {device}")

    # read the input file
    df = _load_peak_table(Path(root) / task_cfg.input_file)

    if task_cfg.debug:
        logger.info("DEBUG MODE. Only embedding 5 sequences")
        # Cap the sample size: DataFrame.sample raises when asked for more
        # rows than exist (sampling without replacement).
        df = df.sample(n=min(5, len(df)), random_state=42).reset_index(drop=True)

    # crucial: sort by SIZE so that we are padding things in a way that will
    # support pre-batching
    df["sequence_length"] = df["sequence"].str.len()
    df = df.sort_values(by="sequence_length", ascending=True).reset_index(drop=True)

    # turn into list of sequences and IDs
    peak_seqs = df["sequence"].tolist()
    peak_ids = df["seq_id"].tolist()
    logger.info(
        f"Embedding {len(peak_seqs)} binding peak sequences from processed remap data"
    )

    # Get the DNA embedder
    dna_embedder = get_embedder(task_cfg.chrom_model, device, for_dna=True)
    logger.info(f"Device of embedding model: {dna_embedder.device}")

    # Build the debug suffix into the file name directly. A str.replace on the
    # whole path would also rewrite any ".pkl" in a directory or model name.
    name_suffix = "_debug" if task_cfg.debug else ""
    out_peaks = str(out_dir / f"peaks_{task_cfg.chrom_model}{name_suffix}.pkl")

    embed_and_save(
        peak_seqs,
        peak_ids,
        dna_embedder,
        out_peaks,
        batch_size=task_cfg.batch_size,
    )
    logger.info(f"Finished embedding DNA sequences. Saved to: {out_peaks}")
if __name__ == "__main__":
    # `main` requires a config object, so the previous bare `main()` call
    # always failed with a TypeError. Build the config from command-line
    # dotlist arguments instead, e.g.
    #   python -m <package>.<module> data_task.chrom_model=... data_task.input_file=...
    # NOTE(review): if this task is normally launched through a project-level
    # dispatcher, confirm that this standalone entry point is still wanted.
    from omegaconf import OmegaConf

    main(OmegaConf.from_cli())