from collections import Counter, defaultdict from ortools.linear_solver import pywraplp import random from omegaconf import DictConfig import pandas as pd from pathlib import Path import os import numpy as np from sklearn.model_selection import train_test_split from dpacman.data_tasks.fimo.post_fimo import get_reverse_complement import json import rootutils from dpacman.utils import pylogger root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) logger = pylogger.RankedLogger(__name__, rank_zero_only=True) def split_with_predefined_test( full_df = pd.DataFrame(), split_names=("train", "val", "test"), test_trs=None, test_dnas=None, ratios=(0.8, 0.1, 0.1), ): """ Method for splitting into train and val with a predefined test set. The proteins in the test set, and the DNA clusters of the DNAs they're associated with, must be excluded from train and val. The remaining rows for train and val are split to preserve 80/10/10 as best as possible. """ test = full_df.copy(deep=True) if test_trs is not None: test = test.loc[test["tr_seqid"].isin(test_trs)].reset_index(drop=True) if test_dnas is not None: test = test.loc[test["dna_seqid"].isin(test_dnas)].reset_index(drop=True) tr_clusters_to_exclude = test["tr_cluster_rep"].unique().tolist() dna_clusters_to_exclude = test["dna_cluster_rep"].unique().tolist() remaining = full_df.loc[ (~full_df["tr_cluster_rep"].isin(tr_clusters_to_exclude)) & (~full_df["dna_cluster_rep"].isin(dna_clusters_to_exclude)) ].reset_index(drop=True) test_ids = test["ID"].unique().tolist() remaining_ids = remaining["ID"].unique().tolist() remaining_clusters = remaining["dna_cluster_rep"].unique().tolist() lost_rows = full_df.loc[ (~full_df["ID"].isin(test_ids)) & (~full_df["ID"].isin(remaining_ids)) ] logger.info(f"Rows in test: {len(test)}") logger.info(f"Rows to be split between train and val: {len(remaining)}") total_rows = len(test) + len(remaining) logger.info(f"Total rows: {total_rows}. 
Test percentage: {100*len(test)/total_rows:.2f}%") logger.info(f"Lost rows: {len(lost_rows)}") train_ratio_from_remaining = round((0.8*total_rows)/len(remaining), 2) # use sklearn test_size_1 = 1 - train_ratio_from_remaining logger.info( f"\tPerforming first split: non-test clusters -> train clusters ({round(1-test_size_1,3)}) and val ({test_size_1})" ) X = remaining_clusters y = [0] * len(remaining_clusters) X_train, X_val, y_train, y_val = train_test_split( X, y, test_size=test_size_1, random_state=0 ) train = remaining.loc[remaining["dna_cluster_rep"].isin(X_train)] val = remaining.loc[remaining["dna_cluster_rep"].isin(X_val)] leaky_test = lost_rows kept_by_split = { "train": len(X_train), "val": len(X_val), "test": len(test["dna_cluster_rep"].unique()) } splits = { "train": train, "val": val, "test": test, "leaky_test": leaky_test } return splits, kept_by_split def split_bipartite_fast( dna_clusters, split_names=("train", "val", "test"), ratios=(0.8, 0.1, 0.1), ): # use sklearn test_size_1 = 0.2 test_size_2 = 0.5 logger.info( f"\tPerforming first split: all clusters -> train clusters ({round(1-test_size_1,3)}) and other ({test_size_1})" ) X = dna_clusters y = [0] * len(dna_clusters) X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=test_size_1, random_state=0 ) logger.info( f"\tPerforming second split: other -> val clusters ({round(1-test_size_2,3)}) and test clusters ({test_size_2})" ) X_val, X_test, y_val, y_test = train_test_split( X_test, y_test, test_size=test_size_2, random_state=0 ) dna_assign = {} for x in X_train: dna_assign[x] = "train" for x in X_val: dna_assign[x] = "val" for x in X_test: dna_assign[x] = "test" kept_by_split = {"train": len(X_train), "val": len(X_val), "test": len(X_test)} return dna_assign, kept_by_split # construct new labels def convert_scores(scores, mode=1): """ Two modes: 1 means FIMO peaks get 1. 
0 means FIMO peaks get their max score """ svec = [int(x) for x in scores.split(",")] max_score = max(svec) if mode ==1: binary_svec = [0 if x1 datapoint: {len(dup_edges)}") logger.info( f"Total datapoints belonging to a duplicate edge: {len(edges.loc[edges['edge'].isin(dup_edges)])}" ) return edges
# NOTE(review): the line above is carried over verbatim from the source. It
# appears truncated — the body of convert_scores and most of a make_edges()
# helper are missing between "[0 if x" and "1 datapoint". Recover the lost
# code from version control before running this module.


def check_validity(train, val, test, split_by="both"):
    """
    Rigorous check for no overlap between splits.

    Expected columns: ["ID", "dna_sequence", "tr_sequence", "tr_cluster_rep",
    "dna_cluster_rep", "scores", "split"].

    Always asserts that row IDs are disjoint. TR-side (protein) disjointness
    is additionally asserted unless split_by == "dna"; DNA-side disjointness
    unless split_by == "protein". Intersection sizes are logged in all cases.
    """
    # Row IDs must be unique across splits no matter what we split by.
    train_ids = set(train["ID"].unique().tolist())
    val_ids = set(val["ID"].unique().tolist())
    test_ids = set(test["ID"].unique().tolist())
    assert len(train_ids.intersection(val_ids)) == 0
    assert len(train_ids.intersection(test_ids)) == 0
    assert len(val_ids.intersection(test_ids)) == 0
    logger.info(f"Pass! No overlap in IDs")
    # Investigate TR intersection. No assertions unless we are explicitly splitting on this.
    train_tr_seqs = set(train["tr_sequence"].unique().tolist())
    val_tr_seqs = set(val["tr_sequence"].unique().tolist())
    test_tr_seqs = set(test["tr_sequence"].unique().tolist())
    train_tr_reps = set(train["tr_cluster_rep"].unique().tolist())
    val_tr_reps = set(val["tr_cluster_rep"].unique().tolist())
    test_tr_reps = set(test["tr_cluster_rep"].unique().tolist())
    logger.info(f"Train-Val TR intersection: {len(train_tr_seqs.intersection(val_tr_seqs))}")
    logger.info(f"Train-Test TR intersection: {len(train_tr_seqs.intersection(test_tr_seqs))}")
    logger.info(f"Val-Test TR intersection: {len(val_tr_seqs.intersection(test_tr_seqs))}")
    logger.info(f"Train-Val TR Cluster Rep intersection: {len(train_tr_reps.intersection(val_tr_reps))}")
    logger.info(f"Train-Test TR Cluster Rep intersection: {len(train_tr_reps.intersection(test_tr_reps))}")
    logger.info(f"Val-Test TR Cluster Rep intersection: {len(val_tr_reps.intersection(test_tr_reps))}")
    # Investigate DNA intersection. No assertions unless we are explicitly splitting on this.
    train_dna_seqs = set(train["dna_sequence"].unique().tolist())
    val_dna_seqs = set(val["dna_sequence"].unique().tolist())
    test_dna_seqs = set(test["dna_sequence"].unique().tolist())
    train_dna_reps = set(train["dna_cluster_rep"].unique().tolist())
    val_dna_reps = set(val["dna_cluster_rep"].unique().tolist())
    test_dna_reps = set(test["dna_cluster_rep"].unique().tolist())
    logger.info(f"Train-Val DNA intersection: {len(train_dna_seqs.intersection(val_dna_seqs))}")
    logger.info(f"Train-Test DNA intersection: {len(train_dna_seqs.intersection(test_dna_seqs))}")
    logger.info(f"Val-Test DNA intersection: {len(val_dna_seqs.intersection(test_dna_seqs))}")
    logger.info(f"Train-Val DNA Cluster Rep intersection: {len(train_dna_reps.intersection(val_dna_reps))}")
    logger.info(f"Train-Test DNA Cluster Rep intersection: {len(train_dna_reps.intersection(test_dna_reps))}")
    logger.info(f"Val-Test DNA Cluster Rep intersection: {len(val_dna_reps.intersection(test_dna_reps))}")
    if split_by != "dna":
        # Split involves proteins: TR sequences and clusters must be disjoint.
        assert len(train_tr_seqs.intersection(val_tr_seqs)) == 0
        assert len(train_tr_seqs.intersection(test_tr_seqs)) == 0
        assert len(val_tr_seqs.intersection(test_tr_seqs)) == 0
        logger.info(f"Pass! No overlap in TR sequences")
        assert len(train_tr_reps.intersection(val_tr_reps)) == 0
        assert len(train_tr_reps.intersection(test_tr_reps)) == 0
        assert len(val_tr_reps.intersection(test_tr_reps)) == 0
        logger.info(f"Pass! No overlap in TR cluster reps")
    if split_by != "protein":
        # Split involves DNA: DNA sequences and clusters must be disjoint.
        assert len(train_dna_seqs.intersection(val_dna_seqs)) == 0
        assert len(train_dna_seqs.intersection(test_dna_seqs)) == 0
        assert len(val_dna_seqs.intersection(test_dna_seqs)) == 0
        logger.info(f"Pass! No overlap in DNA sequences")
        assert len(train_dna_reps.intersection(val_dna_reps)) == 0
        assert len(train_dna_reps.intersection(test_dna_reps)) == 0
        assert len(val_dna_reps.intersection(test_dna_reps)) == 0
        logger.info(f"Pass! 
No overlap in DNA cluster reps") def augment_rc(df): """ Get the reverse complement and add it as a datapoint, effectively doubling the dataset. Also flip the orientation of the scores columns = ["ID","dna_sequence","tr_sequence","tr_cluster_rep","dna_cluster_rep", "scores","split"] """ df_rc = df.copy(deep=True) df_rc["dna_sequence"] = df_rc["dna_sequence"].apply( lambda x: get_reverse_complement(x) ) df_rc["ID"] = df_rc["ID"] + "_rc" df_rc["scores"] = df_rc["scores"].apply(lambda s: ",".join(s.split(",")[::-1])) final_df = pd.concat([df, df_rc]).reset_index(drop=True) return final_df def main(cfg: DictConfig): """ Take a set of DNA clusters + protein clusters, and create the best possible splits into train/val/test. """ # construct edges from training data edge_df = make_edges( processed_fimo_path=Path(root) / cfg.data_task.input_data_path, protein_cluster_path=Path(root) / cfg.data_task.cluster_output_paths.protein, dna_cluster_path=Path(root) / cfg.data_task.cluster_output_paths.dna, ) edges = edge_df["edge"].unique().tolist() # figure out if we actually even have a conflict total_proteins = len(edge_df["tr_seqid"].unique().tolist()) total_protein_clusters = len(edge_df["tr_cluster_rep"].unique().tolist()) no_protein_overlap = (total_proteins) == (total_protein_clusters) logger.info(f"All proteins are in their own clusters: {no_protein_overlap}") if cfg.data_task.split_by == "dna": if cfg.data_task.test_trs or cfg.data_task.test_dnas: logger.info(f"Splitting with predefined trs/dnas reserved for test set") splits, kept_by_split = split_with_predefined_test( full_df=edge_df, split_names=("train", "val", "test"), test_trs=cfg.data_task.test_trs if cfg.data_task.test_trs else None, test_dnas=cfg.data_task.test_dnas if cfg.data_task.test_dnas else None, ratios=(0.8, 0.1, 0.1), ) train = splits["train"] train["split"]=["train"]*len(train) val = splits["val"] val["split"]=["val"]*len(val) test = splits["test"] test["split"]=["test"]*len(test) leaky_test = 
splits["leaky_test"] leaky_test["split"]=["leaky_test"]*len(leaky_test) else: logger.info(f"Easy split: all proteins are in their own clusters.") dna_clusters = edge_df["dna_cluster_rep"].unique().tolist() results = split_bipartite_fast( dna_clusters, split_names=("train", "val", "test"), ratios=( cfg.data_task.train_ratio, cfg.data_task.val_ratio, cfg.data_task.test_ratio, ), ) dna_assign, kept_by_split = results # assign datapoints to cluster by their DNA cluster rep edge_df["split"] = edge_df["dna_cluster_rep"].map(dna_assign) train = edge_df.loc[edge_df["split"] == "train"].reset_index(drop=True) val = edge_df.loc[edge_df["split"] == "val"].reset_index(drop=True) test = edge_df.loc[edge_df["split"] == "test"].reset_index(drop=True) leaky_test = pd.DataFrame(columns=edge_df.columns) # Print ratios: hopefully close to desired (e.g. 80/10/10) print_split_ratios(kept_by_split) # Make train, val, test sets # make sure no ID is duplicate assert len(edge_df["ID"].unique()) == len(edge_df) split_cols = [ "ID", "dna_sequence", "tr_sequence", "tr_cluster_rep", "dna_cluster_rep", "scores", "split", ] train = train[split_cols] val = val[split_cols] test = test[split_cols] leaky_test = leaky_test[split_cols] # ensure there is no overlap check_validity(train, val, test, split_by=cfg.data_task.split_by) total = sum([len(train), len(val), len(test), len(leaky_test)]) logger.info(f"Length of train dataset: {len(train)} ({100*len(train)/total:.2f}%)") logger.info(f"Length of val dataset: {len(val)} ({100*len(val)/total:.2f}%)") logger.info(f"Length of test dataset: {len(test)} ({100*len(test)/total:.2f}%)") logger.info(f"Length of leaky_test dataset: {len(leaky_test)} ({100*len(leaky_test)/total:.2f}%)") logger.info(f"Total sequences = {total}. Same as edges size? 
{total==len(edge_df)}") og_unique_dna = pd.concat([train, val, test, leaky_test]) og_unique_dna = len(og_unique_dna["dna_sequence"].unique()) ## Now do RC data augmentation if asked if cfg.data_task.augment_rc: train = augment_rc(train) val = augment_rc(val) test = augment_rc(test) leaky_test = augment_rc(leaky_test) logger.info(f"Added reverse complement sequences to train, val, and test (and leaky test)") check_validity(train, val, test, split_by=cfg.data_task.split_by) total = sum([len(train), len(val), len(test), len(leaky_test)]) logger.info( f"Length of train dataset: {len(train)} ({100*len(train)/total:.2f}%)" ) logger.info(f"Length of val dataset: {len(val)} ({100*len(val)/total:.2f}%)") logger.info(f"Length of test dataset: {len(test)} ({100*len(test)/total:.2f}%)") logger.info(f"Length of leaky_test dataset: {len(leaky_test)} ({100*len(leaky_test)/total:.2f}%)") logger.info( f"Total sequences = {total}. Same as edges size? {total==len(edge_df)}" ) # since we've added all these new DNA sequences, we do need a new mapping of seq id to dna sequence all_data = pd.concat([train, val, test, leaky_test]) all_data["dna_seqid"] = all_data["ID"].str.split("_", n=1, expand=True)[1] dna_dict = dict(zip(all_data["dna_seqid"], all_data["dna_sequence"])) assert len(dna_dict) == len(all_data.drop_duplicates(["dna_sequence"])) new_map_path = str(Path(root) / cfg.data_task.dna_map_path).replace( ".json", "_with_rc.json" ) with open(new_map_path, "w") as f: json.dump(dna_dict, f, indent=2) logger.info( f"Saved DNA maps with reverse complements (len {len(dna_dict)}=2*original map of len {og_unique_dna}=={len(dna_dict)==2*og_unique_dna}) to {new_map_path}" ) # create the output dir split_out_dir = Path(root) / cfg.data_task.split_out_dir os.makedirs(split_out_dir, exist_ok=True) # add binary_scores to allow other training modes train["fimo_binary_sores"] = train["scores"].apply(lambda x: convert_scores(x, mode=1)) val["fimo_binary_sores"] = val["scores"].apply(lambda x: 
convert_scores(x, mode=1)) test["fimo_binary_sores"] = test["scores"].apply(lambda x: convert_scores(x, mode=1)) leaky_test["fimo_binary_sores"] = leaky_test["scores"].apply(lambda x: convert_scores(x, mode=1)) # slect final cols and save split_final_cols = ["ID", "dna_sequence", "tr_sequence", "scores", "fimo_binary_sores", "split"] train[split_final_cols].to_csv(split_out_dir / "train.csv", index=False) val[split_final_cols].to_csv(split_out_dir / "val.csv", index=False) test[split_final_cols].to_csv(split_out_dir / "test.csv", index=False) leaky_test[split_final_cols].to_csv(split_out_dir / "leaky_test.csv", index=False) logger.info(f"Saved all splits to {split_out_dir}")