import json
import os
from pathlib import Path

import pandas as pd
import rootutils
from omegaconf import DictConfig
from sklearn.model_selection import train_test_split

from dpacman.data_tasks.fimo.post_fimo import get_reverse_complement
from dpacman.utils import pylogger
|
|
root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
logger = pylogger.RankedLogger(__name__, rank_zero_only=True)
|
|
def split_with_predefined_test(
    full_df: pd.DataFrame,
    split_names=("train", "val", "test"),
    test_trs=None,
    test_dnas=None,
    ratios=(0.8, 0.1, 0.1),
):
    """
    Split into train and val around a predefined test set.

    The proteins in the test set, and the DNA clusters of the DNAs they are
    associated with, are excluded from train and val. The remaining rows are
    split between train and val so that the overall split stays as close to
    `ratios` as possible. Rows that are neither in the test set nor safely
    splittable (they share a cluster with a test row) are returned separately
    as "leaky_test".
    """
    test = full_df.copy(deep=True)
    if test_trs is not None:
        test = test.loc[test["tr_seqid"].isin(test_trs)].reset_index(drop=True)
    if test_dnas is not None:
        test = test.loc[test["dna_seqid"].isin(test_dnas)].reset_index(drop=True)

    tr_clusters_to_exclude = test["tr_cluster_rep"].unique().tolist()
    dna_clusters_to_exclude = test["dna_cluster_rep"].unique().tolist()

    remaining = full_df.loc[
        (~full_df["tr_cluster_rep"].isin(tr_clusters_to_exclude))
        & (~full_df["dna_cluster_rep"].isin(dna_clusters_to_exclude))
    ].reset_index(drop=True)

    test_ids = test["ID"].unique().tolist()
    remaining_ids = remaining["ID"].unique().tolist()
    remaining_clusters = remaining["dna_cluster_rep"].unique().tolist()
    # Rows that share a cluster with the test set but are not in it: they cannot
    # go to train/val without leakage, so they end up in "leaky_test".
    lost_rows = full_df.loc[
        (~full_df["ID"].isin(test_ids)) & (~full_df["ID"].isin(remaining_ids))
    ]

    logger.info(f"Rows in test: {len(test)}")
    logger.info(f"Rows to be split between train and val: {len(remaining)}")
    total_rows = len(test) + len(remaining)
    logger.info(f"Total rows: {total_rows}. Test percentage: {100*len(test)/total_rows:.2f}%")
    logger.info(f"Lost rows: {len(lost_rows)}")

    # Rescale the target train ratio to the remaining (non-test) rows so that
    # train still ends up at roughly ratios[0] of the full dataset.
    train_ratio_from_remaining = round((ratios[0] * total_rows) / len(remaining), 2)

    test_size_1 = 1 - train_ratio_from_remaining
    logger.info(
        f"\tPerforming first split: non-test clusters -> train clusters ({round(1-test_size_1,3)}) and val ({test_size_1})"
    )
    X_train, X_val = train_test_split(
        remaining_clusters, test_size=test_size_1, random_state=0
    )

    train = remaining.loc[remaining["dna_cluster_rep"].isin(X_train)]
    val = remaining.loc[remaining["dna_cluster_rep"].isin(X_val)]
    leaky_test = lost_rows

    kept_by_split = {
        "train": len(X_train),
        "val": len(X_val),
        "test": len(test["dna_cluster_rep"].unique()),
    }
    splits = {
        "train": train,
        "val": val,
        "test": test,
        "leaky_test": leaky_test,
    }
    return splits, kept_by_split
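
# Worked example of the ratio rescaling above (illustrative numbers, not from
# real data): with total_rows = 1000 and a predefined test of 100 rows,
# remaining = 900, so train_ratio_from_remaining = 0.8 * 1000 / 900 ~= 0.89 and
# test_size_1 ~= 0.11, leaving train ~= 800 rows and val ~= 100 rows, i.e.
# close to the requested 80/10/10 over the full dataset.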
|
|
def split_bipartite_fast(
    dna_clusters,
    split_names=("train", "val", "test"),
    ratios=(0.8, 0.1, 0.1),
):
    """
    Assign DNA clusters to train/val/test with two successive random splits.
    """
    # First split holds out the val+test fraction; the second divides that
    # holdout between val and test.
    test_size_1 = ratios[1] + ratios[2]
    test_size_2 = ratios[2] / (ratios[1] + ratios[2])
    logger.info(
        f"\tPerforming first split: all clusters -> train clusters ({round(1-test_size_1,3)}) and other ({test_size_1})"
    )
    X_train, X_other = train_test_split(
        dna_clusters, test_size=test_size_1, random_state=0
    )
    logger.info(
        f"\tPerforming second split: other -> val clusters ({round(1-test_size_2,3)}) and test clusters ({test_size_2})"
    )
    X_val, X_test = train_test_split(
        X_other, test_size=test_size_2, random_state=0
    )

    dna_assign = {}
    for x in X_train:
        dna_assign[x] = "train"
    for x in X_val:
        dna_assign[x] = "val"
    for x in X_test:
        dna_assign[x] = "test"

    kept_by_split = {"train": len(X_train), "val": len(X_val), "test": len(X_test)}
    return dna_assign, kept_by_split
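
# Example of the two-stage sizing above: ratios = (0.8, 0.1, 0.1) gives
# test_size_1 = 0.2 (hold out 20% of clusters) and test_size_2 = 0.5 (split
# that holdout evenly), reproducing an 80/10/10 cluster-level split.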
|
|
def convert_scores(scores, mode=1):
    """
    Binarize a comma-separated score vector. Two modes: mode 1 sets FIMO peak
    positions (positions holding the max score) to 1; any other mode (the
    documented mode 0) keeps the max score at peak positions. Non-peak
    positions become 0 either way.
    """
    svec = [int(x) for x in scores.split(",")]
    max_score = max(svec)
    if mode == 1:
        binary_svec = [0 if x < max_score else 1 for x in svec]
        assert svec.count(max_score) == binary_svec.count(1)
    else:
        binary_svec = [0 if x < max_score else max_score for x in svec]
        assert svec.count(max_score) == binary_svec.count(max_score)
    return ",".join(str(x) for x in binary_svec)
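
# Example: convert_scores("1,3,3,2", mode=1) returns "0,1,1,0", while
# convert_scores("1,3,3,2", mode=0) returns "0,3,3,0" -- peaks keep their
# max score instead of being flattened to 1.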


def print_split_ratios(kept_by_split):
    total = sum(kept_by_split.values())
    train_pcnt = 100 * kept_by_split["train"] / total
    val_pcnt = 100 * kept_by_split["val"] / total
    test_pcnt = 100 * kept_by_split["test"] / total
    logger.info(
        f"Cluster distribution - Train: {train_pcnt:.2f}%, Val: {val_pcnt:.2f}%, Test: {test_pcnt:.2f}%"
    )
|
|
|
|
def make_edges(
    processed_fimo_path: str, protein_cluster_path: str, dna_cluster_path: str
):
    """
    Make edges for input to the splitting algorithm. An edge is the tuple
    (tr_cluster_rep, dna_cluster_rep), where each cluster rep is a sequence ID.
    """
    protein_clusters = pd.read_csv(protein_cluster_path, header=None, sep="\t")
    protein_clusters.columns = ["tr_cluster_rep", "tr_seqid"]

    dna_clusters = pd.read_csv(dna_cluster_path, header=None, sep="\t")
    dna_clusters.columns = ["dna_cluster_rep", "dna_seqid"]

    edges = pd.read_parquet(processed_fimo_path)
    edges = pd.merge(edges, dna_clusters, on="dna_seqid", how="left")
    edges = pd.merge(edges, protein_clusters, on="tr_seqid", how="left")
    # Vectorized tuple construction; equivalent to a row-wise apply but faster.
    edges["edge"] = list(zip(edges["tr_cluster_rep"], edges["dna_cluster_rep"]))

    logger.info(f"Total unique edges: {edges['edge'].nunique()}")
    dup_edges = edges.loc[edges.duplicated("edge")]["edge"].unique().tolist()
    logger.info(f"Total edges with >1 datapoint: {len(dup_edges)}")
    logger.info(
        f"Total datapoints belonging to a duplicate edge: {len(edges.loc[edges['edge'].isin(dup_edges)])}"
    )
    return edges
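
# The cluster files are expected to be two-column, tab-separated rep/member
# tables with no header (the layout produced by common clustering tools such
# as MMseqs2's TSV export -- an assumption; any rep<TAB>member file works):
#   repA    seq1
#   repA    seq2
#   repB    seq3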
|
|
|
|
def check_validity(train, val, test, split_by="both"):
    """
    Rigorous check for no overlap between splits.
    Columns = ["ID","dna_sequence","tr_sequence","tr_cluster_rep","dna_cluster_rep","scores","split"]
    With split_by="dna" only DNA-side disjointness is enforced; with
    split_by="protein" only protein-side; otherwise both sides are enforced.
    """
    train_ids = set(train["ID"].unique().tolist())
    val_ids = set(val["ID"].unique().tolist())
    test_ids = set(test["ID"].unique().tolist())

    assert len(train_ids.intersection(val_ids)) == 0
    assert len(train_ids.intersection(test_ids)) == 0
    assert len(val_ids.intersection(test_ids)) == 0
    logger.info("Pass! No overlap in IDs")

    train_tr_seqs = set(train["tr_sequence"].unique().tolist())
    val_tr_seqs = set(val["tr_sequence"].unique().tolist())
    test_tr_seqs = set(test["tr_sequence"].unique().tolist())

    train_tr_reps = set(train["tr_cluster_rep"].unique().tolist())
    val_tr_reps = set(val["tr_cluster_rep"].unique().tolist())
    test_tr_reps = set(test["tr_cluster_rep"].unique().tolist())

    logger.info(f"Train-Val TR intersection: {len(train_tr_seqs.intersection(val_tr_seqs))}")
    logger.info(f"Train-Test TR intersection: {len(train_tr_seqs.intersection(test_tr_seqs))}")
    logger.info(f"Val-Test TR intersection: {len(val_tr_seqs.intersection(test_tr_seqs))}")

    logger.info(f"Train-Val TR Cluster Rep intersection: {len(train_tr_reps.intersection(val_tr_reps))}")
    logger.info(f"Train-Test TR Cluster Rep intersection: {len(train_tr_reps.intersection(test_tr_reps))}")
    logger.info(f"Val-Test TR Cluster Rep intersection: {len(val_tr_reps.intersection(test_tr_reps))}")

    train_dna_seqs = set(train["dna_sequence"].unique().tolist())
    val_dna_seqs = set(val["dna_sequence"].unique().tolist())
    test_dna_seqs = set(test["dna_sequence"].unique().tolist())

    train_dna_reps = set(train["dna_cluster_rep"].unique().tolist())
    val_dna_reps = set(val["dna_cluster_rep"].unique().tolist())
    test_dna_reps = set(test["dna_cluster_rep"].unique().tolist())

    logger.info(f"Train-Val DNA intersection: {len(train_dna_seqs.intersection(val_dna_seqs))}")
    logger.info(f"Train-Test DNA intersection: {len(train_dna_seqs.intersection(test_dna_seqs))}")
    logger.info(f"Val-Test DNA intersection: {len(val_dna_seqs.intersection(test_dna_seqs))}")

    logger.info(f"Train-Val DNA Cluster Rep intersection: {len(train_dna_reps.intersection(val_dna_reps))}")
    logger.info(f"Train-Test DNA Cluster Rep intersection: {len(train_dna_reps.intersection(test_dna_reps))}")
    logger.info(f"Val-Test DNA Cluster Rep intersection: {len(val_dna_reps.intersection(test_dna_reps))}")

    # When splitting by DNA clusters, protein overlap across splits is allowed.
    if split_by != "dna":
        assert len(train_tr_seqs.intersection(val_tr_seqs)) == 0
        assert len(train_tr_seqs.intersection(test_tr_seqs)) == 0
        assert len(val_tr_seqs.intersection(test_tr_seqs)) == 0
        logger.info("Pass! No overlap in TR sequences")

        assert len(train_tr_reps.intersection(val_tr_reps)) == 0
        assert len(train_tr_reps.intersection(test_tr_reps)) == 0
        assert len(val_tr_reps.intersection(test_tr_reps)) == 0
        logger.info("Pass! No overlap in TR cluster reps")

    # When splitting by protein clusters, DNA overlap across splits is allowed.
    if split_by != "protein":
        assert len(train_dna_seqs.intersection(val_dna_seqs)) == 0
        assert len(train_dna_seqs.intersection(test_dna_seqs)) == 0
        assert len(val_dna_seqs.intersection(test_dna_seqs)) == 0
        logger.info("Pass! No overlap in DNA sequences")

        assert len(train_dna_reps.intersection(val_dna_reps)) == 0
        assert len(train_dna_reps.intersection(test_dna_reps)) == 0
        assert len(val_dna_reps.intersection(test_dna_reps)) == 0
        logger.info("Pass! No overlap in DNA cluster reps")
|
|
|
|
def augment_rc(df):
    """
    Add each datapoint's reverse complement as a new datapoint, effectively
    doubling the dataset. The score vector is flipped to match the new
    orientation.

    columns = ["ID","dna_sequence","tr_sequence","tr_cluster_rep","dna_cluster_rep","scores","split"]
    """
    df_rc = df.copy(deep=True)

    df_rc["dna_sequence"] = df_rc["dna_sequence"].apply(get_reverse_complement)
    df_rc["ID"] = df_rc["ID"] + "_rc"
    df_rc["scores"] = df_rc["scores"].apply(lambda s: ",".join(s.split(",")[::-1]))

    final_df = pd.concat([df, df_rc]).reset_index(drop=True)

    return final_df
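
# Example (assuming get_reverse_complement implements standard base pairing):
# a row with ID "tr1_dna7", dna_sequence "AACG" and scores "1,2,3,4" gains a
# counterpart with ID "tr1_dna7_rc", dna_sequence "CGTT" and scores "4,3,2,1".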
|
|
|
|
def main(cfg: DictConfig):
    """
    Take a set of DNA clusters + protein clusters and create the best possible
    splits into train/val/test.
    """
    edge_df = make_edges(
        processed_fimo_path=Path(root) / cfg.data_task.input_data_path,
        protein_cluster_path=Path(root) / cfg.data_task.cluster_output_paths.protein,
        dna_cluster_path=Path(root) / cfg.data_task.cluster_output_paths.dna,
    )

    total_proteins = len(edge_df["tr_seqid"].unique().tolist())
    total_protein_clusters = len(edge_df["tr_cluster_rep"].unique().tolist())

    no_protein_overlap = total_proteins == total_protein_clusters
    logger.info(f"All proteins are in their own clusters: {no_protein_overlap}")
|
|
    if cfg.data_task.split_by == "dna":
        if cfg.data_task.test_trs or cfg.data_task.test_dnas:
            logger.info("Splitting with predefined trs/dnas reserved for test set")
            splits, kept_by_split = split_with_predefined_test(
                full_df=edge_df,
                split_names=("train", "val", "test"),
                test_trs=cfg.data_task.test_trs if cfg.data_task.test_trs else None,
                test_dnas=cfg.data_task.test_dnas if cfg.data_task.test_dnas else None,
                ratios=(
                    cfg.data_task.train_ratio,
                    cfg.data_task.val_ratio,
                    cfg.data_task.test_ratio,
                ),
            )
            # Copy before adding the split column to avoid mutating slices of
            # the underlying dataframe (SettingWithCopyWarning).
            train = splits["train"].copy()
            train["split"] = "train"
            val = splits["val"].copy()
            val["split"] = "val"
            test = splits["test"].copy()
            test["split"] = "test"
            leaky_test = splits["leaky_test"].copy()
            leaky_test["split"] = "leaky_test"
        else:
            logger.info("Easy split: all proteins are in their own clusters.")
            dna_clusters = edge_df["dna_cluster_rep"].unique().tolist()
            dna_assign, kept_by_split = split_bipartite_fast(
                dna_clusters,
                split_names=("train", "val", "test"),
                ratios=(
                    cfg.data_task.train_ratio,
                    cfg.data_task.val_ratio,
                    cfg.data_task.test_ratio,
                ),
            )

            edge_df["split"] = edge_df["dna_cluster_rep"].map(dna_assign)
            train = edge_df.loc[edge_df["split"] == "train"].reset_index(drop=True)
            val = edge_df.loc[edge_df["split"] == "val"].reset_index(drop=True)
            test = edge_df.loc[edge_df["split"] == "test"].reset_index(drop=True)
            leaky_test = pd.DataFrame(columns=edge_df.columns)
    else:
        # Only DNA-based splitting is implemented below; fail loudly instead of
        # crashing later on undefined variables.
        raise NotImplementedError(f"split_by={cfg.data_task.split_by!r} is not supported")
|
|
    print_split_ratios(kept_by_split)

    # Every row must have a unique ID before the splits are written out.
    assert len(edge_df["ID"].unique()) == len(edge_df)
    split_cols = [
        "ID",
        "dna_sequence",
        "tr_sequence",
        "tr_cluster_rep",
        "dna_cluster_rep",
        "scores",
        "split",
    ]
    train = train[split_cols]
    val = val[split_cols]
    test = test[split_cols]
    leaky_test = leaky_test[split_cols]

    check_validity(train, val, test, split_by=cfg.data_task.split_by)

    total = sum([len(train), len(val), len(test), len(leaky_test)])
    logger.info(f"Length of train dataset: {len(train)} ({100*len(train)/total:.2f}%)")
    logger.info(f"Length of val dataset: {len(val)} ({100*len(val)/total:.2f}%)")
    logger.info(f"Length of test dataset: {len(test)} ({100*len(test)/total:.2f}%)")
    logger.info(f"Length of leaky_test dataset: {len(leaky_test)} ({100*len(leaky_test)/total:.2f}%)")
    logger.info(f"Total sequences = {total}. Same as edges size? {total==len(edge_df)}")

    # Count unique DNA sequences before augmentation, for the sanity check below.
    og_unique_dna = pd.concat([train, val, test, leaky_test])
    og_unique_dna = len(og_unique_dna["dna_sequence"].unique())
|
|
    if cfg.data_task.augment_rc:
        train = augment_rc(train)
        val = augment_rc(val)
        test = augment_rc(test)
        leaky_test = augment_rc(leaky_test)

        logger.info("Added reverse complement sequences to train, val, and test (and leaky test)")

        check_validity(train, val, test, split_by=cfg.data_task.split_by)

        total = sum([len(train), len(val), len(test), len(leaky_test)])
        logger.info(f"Length of train dataset: {len(train)} ({100*len(train)/total:.2f}%)")
        logger.info(f"Length of val dataset: {len(val)} ({100*len(val)/total:.2f}%)")
        logger.info(f"Length of test dataset: {len(test)} ({100*len(test)/total:.2f}%)")
        logger.info(f"Length of leaky_test dataset: {len(leaky_test)} ({100*len(leaky_test)/total:.2f}%)")
        logger.info(f"Total sequences = {total}. Same as 2x edges size? {total == 2*len(edge_df)}")

        # Rebuild the dna_seqid -> dna_sequence map so it includes the reverse
        # complements, and save it next to the original map.
        all_data = pd.concat([train, val, test, leaky_test])
        all_data["dna_seqid"] = all_data["ID"].str.split("_", n=1, expand=True)[1]
        dna_dict = dict(zip(all_data["dna_seqid"], all_data["dna_sequence"]))
        assert len(dna_dict) == len(all_data.drop_duplicates(["dna_sequence"]))
        new_map_path = str(Path(root) / cfg.data_task.dna_map_path).replace(
            ".json", "_with_rc.json"
        )

        with open(new_map_path, "w") as f:
            json.dump(dna_dict, f, indent=2)
        logger.info(
            f"Saved DNA maps with reverse complements (len {len(dna_dict)}=2*original map of len {og_unique_dna}=={len(dna_dict)==2*og_unique_dna}) to {new_map_path}"
        )
|
|
    split_out_dir = Path(root) / cfg.data_task.split_out_dir
    os.makedirs(split_out_dir, exist_ok=True)

    # Binarized FIMO scores: peak positions -> 1, everything else -> 0.
    train["fimo_binary_scores"] = train["scores"].apply(lambda x: convert_scores(x, mode=1))
    val["fimo_binary_scores"] = val["scores"].apply(lambda x: convert_scores(x, mode=1))
    test["fimo_binary_scores"] = test["scores"].apply(lambda x: convert_scores(x, mode=1))
    leaky_test["fimo_binary_scores"] = leaky_test["scores"].apply(lambda x: convert_scores(x, mode=1))

    split_final_cols = ["ID", "dna_sequence", "tr_sequence", "scores", "fimo_binary_scores", "split"]
    train[split_final_cols].to_csv(split_out_dir / "train.csv", index=False)
    val[split_final_cols].to_csv(split_out_dir / "val.csv", index=False)
    test[split_final_cols].to_csv(split_out_dir / "test.csv", index=False)
    leaky_test[split_final_cols].to_csv(split_out_dir / "leaky_test.csv", index=False)
    logger.info(f"Saved all splits to {split_out_dir}")
|
|