# svincoff's picture
# added dropout and overfit prevention
# 9da03b7
from collections import Counter, defaultdict
from ortools.linear_solver import pywraplp
import random
from omegaconf import DictConfig
import pandas as pd
from pathlib import Path
import os
import numpy as np
from sklearn.model_selection import train_test_split
from dpacman.data_tasks.fimo.post_fimo import get_reverse_complement
import json
import rootutils
from dpacman.utils import pylogger
# Resolve the project root starting from this file, using the ".project-root"
# marker file as the indicator; pythonpath=True requests that the root also be
# put on the Python path (see the rootutils documentation for exact semantics).
root = rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# Module-level logger shared by every function in this file.
# NOTE(review): rank_zero_only=True presumably limits output to the rank-0
# process in multi-process runs - confirm against dpacman.utils.pylogger.
logger = pylogger.RankedLogger(__name__, rank_zero_only=True)
def split_with_predefined_test(
    full_df = pd.DataFrame(),
    split_names=("train", "val", "test"),
    test_trs=None,
    test_dnas=None,
    ratios=(0.8, 0.1, 0.1),
):
    """
    Split into train and val around a predefined test set.

    The test set is made of the rows whose ``tr_seqid`` is in ``test_trs`` and/or
    whose ``dna_seqid`` is in ``test_dnas`` (when both are given a row must match
    both filters). Every row sharing a TR cluster or a DNA cluster with the test
    set is excluded from train and val. The remaining DNA clusters are then split
    between train and val so that train is as close as possible to its target
    share of (test + remaining) rows.

    Args:
        full_df: DataFrame with at least the columns ``ID``, ``tr_seqid``,
            ``dna_seqid``, ``tr_cluster_rep`` and ``dna_cluster_rep``.
        split_names: Unused; kept for interface compatibility.
        test_trs: Optional collection of ``tr_seqid`` values reserved for test.
        test_dnas: Optional collection of ``dna_seqid`` values reserved for test.
        ratios: Target (train, val, test) proportions. Only the train share is
            needed here because the test set is predefined.

    Returns:
        A tuple ``(splits, kept_by_split)`` where ``splits`` maps "train", "val",
        "test" and "leaky_test" to DataFrames ("leaky_test" holds the rows that
        are neither in test nor in the leak-free remainder) and ``kept_by_split``
        maps "train", "val" and "test" to their number of DNA clusters.

    Raises:
        ValueError: If ``ratios`` is not three positive numbers, or if no rows
            are left for train/val after removing the test clusters.
    """
    if len(ratios) != 3 or any(r <= 0 for r in ratios):
        raise ValueError(f"ratios must be three positive numbers, got {ratios}")

    test = full_df.copy(deep=True)
    if test_trs is not None:
        test = test.loc[test["tr_seqid"].isin(test_trs)].reset_index(drop=True)
    if test_dnas is not None:
        test = test.loc[test["dna_seqid"].isin(test_dnas)].reset_index(drop=True)

    # Anything sharing a TR or DNA cluster with the test set would leak into it.
    tr_clusters_to_exclude = test["tr_cluster_rep"].unique().tolist()
    dna_clusters_to_exclude = test["dna_cluster_rep"].unique().tolist()
    remaining = full_df.loc[
        (~full_df["tr_cluster_rep"].isin(tr_clusters_to_exclude))
        & (~full_df["dna_cluster_rep"].isin(dna_clusters_to_exclude))
    ].reset_index(drop=True)

    test_ids = test["ID"].unique().tolist()
    remaining_ids = remaining["ID"].unique().tolist()
    remaining_clusters = remaining["dna_cluster_rep"].unique().tolist()

    # Rows that are neither test rows nor leak-free: they share a cluster with test.
    # Copied so the caller can add columns without touching full_df.
    lost_rows = full_df.loc[
        (~full_df["ID"].isin(test_ids)) & (~full_df["ID"].isin(remaining_ids))
    ].copy()

    if remaining.empty:
        # Without this guard the ratio computation below divides by zero.
        raise ValueError(
            "No rows are left for train/val after excluding the test clusters; "
            "check test_trs/test_dnas."
        )

    logger.info(f"Rows in test: {len(test)}")
    logger.info(f"Rows to be split between train and val: {len(remaining)}")
    total_rows = len(test) + len(remaining)
    logger.info(f"Total rows: {total_rows}. Test percentage: {100*len(test)/total_rows:.2f}%")
    logger.info(f"Lost rows: {len(lost_rows)}")

    # Fraction of the remaining rows that train needs in order to reach its
    # target share of (test + remaining).
    train_fraction = ratios[0] / sum(ratios)
    train_ratio_from_remaining = round((train_fraction * total_rows) / len(remaining), 2)
    if not 0 < train_ratio_from_remaining < 1:
        # The test set is too large for train to reach its global target; keep the
        # train:val proportion instead of handing sklearn an invalid test_size.
        train_ratio_from_remaining = round(ratios[0] / (ratios[0] + ratios[1]), 2)
        logger.info(
            f"\tTest set too large to reach the train target; using train:val proportion ({train_ratio_from_remaining}) on the remaining rows"
        )

    # use sklearn
    test_size_1 = 1 - train_ratio_from_remaining
    logger.info(
        f"\tPerforming first split: non-test clusters -> train clusters ({round(1-test_size_1,3)}) and val ({test_size_1})"
    )
    X = remaining_clusters
    y = [0] * len(remaining_clusters)
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=test_size_1, random_state=0
    )

    # Copies, so that callers adding a column do not write into a slice of `remaining`.
    train = remaining.loc[remaining["dna_cluster_rep"].isin(X_train)].copy()
    val = remaining.loc[remaining["dna_cluster_rep"].isin(X_val)].copy()

    kept_by_split = {
        "train": len(X_train),
        "val": len(X_val),
        "test": len(test["dna_cluster_rep"].unique()),
    }
    splits = {
        "train": train,
        "val": val,
        "test": test,
        "leaky_test": lost_rows,
    }
    return splits, kept_by_split
def split_bipartite_fast(
    dna_clusters,
    split_names=("train", "val", "test"),
    ratios=(0.8, 0.1, 0.1),
):
    """
    Randomly assign DNA clusters to train/val/test with two sklearn splits.

    The first split separates train from a held-out pool; the second divides
    that pool into val and test. Both use ``random_state=0``.

    Args:
        dna_clusters: List of DNA cluster identifiers to distribute.
        split_names: Unused; kept for interface compatibility.
        ratios: Target (train, val, test) proportions. They are normalised by
            their sum; the default reproduces the 0.2 / 0.5 split sizes that
            were previously hard-coded. ``None`` entries fall back to the default.

    Returns:
        A tuple ``(dna_assign, kept_by_split)``: ``dna_assign`` maps each cluster
        to "train", "val" or "test"; ``kept_by_split`` maps each split name to
        its number of clusters.

    Raises:
        ValueError: If ``ratios`` is not three positive numbers.
    """
    default_ratios = (0.8, 0.1, 0.1)
    if ratios is None or any(r is None for r in ratios):
        # Unset config values: keep the historical 80/10/10 behaviour.
        ratios = default_ratios
    if len(ratios) != 3 or any(r <= 0 for r in ratios):
        raise ValueError(f"ratios must be three positive numbers, got {ratios}")
    train_ratio, val_ratio, test_ratio = (float(r) for r in ratios)

    # use sklearn
    held_out = val_ratio + test_ratio
    # Rounded to strip float noise so that e.g. 80/10/10 yields exactly 0.2 and 0.5.
    test_size_1 = round(held_out / (train_ratio + val_ratio + test_ratio), 6)
    test_size_2 = round(test_ratio / held_out, 6)

    logger.info(
        f"\tPerforming first split: all clusters -> train clusters ({round(1-test_size_1,3)}) and other ({test_size_1})"
    )
    X = dna_clusters
    y = [0] * len(dna_clusters)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size_1, random_state=0
    )

    logger.info(
        f"\tPerforming second split: other -> val clusters ({round(1-test_size_2,3)}) and test clusters ({test_size_2})"
    )
    X_val, X_test, y_val, y_test = train_test_split(
        X_test, y_test, test_size=test_size_2, random_state=0
    )

    # Later updates win, matching the original train -> val -> test assignment order.
    dna_assign = {}
    dna_assign.update((cluster, "train") for cluster in X_train)
    dna_assign.update((cluster, "val") for cluster in X_val)
    dna_assign.update((cluster, "test") for cluster in X_test)

    kept_by_split = {"train": len(X_train), "val": len(X_val), "test": len(X_test)}
    return dna_assign, kept_by_split
# construct new labels
def convert_scores(scores, mode=1):
    """
    Turn a comma-separated string of integer scores into a peak-only label string.

    Positions equal to the maximum score are treated as peaks; every other
    position becomes 0. With ``mode == 1`` peaks are written as 1, with any
    other mode they keep the maximum score.

    Returns the converted values joined by commas.
    """
    values = [int(token) for token in scores.split(",")]
    peak = max(values)
    # Value written at peak positions depends on the requested mode.
    peak_label = 1 if mode == 1 else peak
    converted = [peak_label if value >= peak else 0 for value in values]
    # Sanity check: as many labelled peaks as there are maxima in the input.
    assert values.count(peak) == converted.count(peak_label)
    return ",".join(map(str, converted))
def print_split_ratios(kept_by_split):
    """
    Log the percentage of clusters that ended up in train, val and test.

    ``kept_by_split`` maps split names to cluster counts; the denominator is the
    sum of all its values, and the "train", "val" and "test" keys are reported.
    """
    denominator = sum(kept_by_split.values())
    train_share, val_share, test_share = (
        100 * kept_by_split[name] / denominator for name in ("train", "val", "test")
    )
    logger.info(
        f"Cluster distribution - Train: {train_share:.2f}%, Val: {val_share:.2f}%, Test: {test_share:.2f}%"
    )
def make_edges(
    processed_fimo_path: str, protein_cluster_path: str, dna_cluster_path: str
):
    """
    Build the edge table used as input to the splitting step.

    Each datapoint is annotated with the representative of its protein (TR)
    cluster and of its DNA cluster, and gets an ``edge`` column holding the
    tuple ``(tr_cluster_rep, dna_cluster_rep)``.

    Args:
        processed_fimo_path: Parquet file with the datapoints; it must contain
            the ``dna_seqid`` and ``tr_seqid`` columns used for merging.
        protein_cluster_path: Headerless tab-separated file read as the columns
            (tr_cluster_rep, tr_seqid).
        dna_cluster_path: Headerless tab-separated file read as the columns
            (dna_cluster_rep, dna_seqid).

    Returns:
        The datapoint DataFrame with ``dna_cluster_rep``, ``tr_cluster_rep`` and
        ``edge`` columns added. Left merges are used, so datapoints without a
        cluster entry keep a missing cluster rep.
    """
    # Read cluster data
    protein_clusters = pd.read_csv(protein_cluster_path, header=None, sep="\t")
    protein_clusters.columns = ["tr_cluster_rep", "tr_seqid"]
    dna_clusters = pd.read_csv(dna_cluster_path, header=None, sep="\t")
    dna_clusters.columns = ["dna_cluster_rep", "dna_seqid"]

    # Read datapoints
    edges = pd.read_parquet(processed_fimo_path)
    edges = pd.merge(edges, dna_clusters, on="dna_seqid", how="left")
    edges = pd.merge(edges, protein_clusters, on="tr_seqid", how="left")

    # Column-wise zip instead of a row-wise apply: same tuples, no per-row Series
    # construction, and it also works when the table is empty.
    edges["edge"] = list(zip(edges["tr_cluster_rep"], edges["dna_cluster_rep"]))

    unique_edge_count = len(edges["edge"].unique().tolist())
    logger.info(f"Total unique edges: {unique_edge_count}")

    # Edges that back more than one datapoint.
    dup_edges = edges.loc[edges.duplicated("edge")]["edge"].unique().tolist()
    logger.info(f"Total edges with >1 datapoint: {len(dup_edges)}")

    datapoints_on_dup_edges = len(edges.loc[edges["edge"].isin(dup_edges)])
    logger.info(
        f"Total datapoints belonging to a duplicate edge: {datapoints_on_dup_edges}"
    )
    return edges
def check_validity(train, val, test, split_by="both"):
    """
    Verify that train, val and test do not overlap.

    Expected columns: ["ID","dna_sequence","tr_sequence","tr_cluster_rep","dna_cluster_rep", "scores","split"]

    IDs must always be disjoint. Pairwise overlaps of TR/DNA sequences and
    cluster reps are always logged; they are asserted to be empty for TRs unless
    ``split_by == "dna"`` and for DNAs unless ``split_by == "protein"``.

    Raises AssertionError on the first violated requirement.
    """
    frames = (train, val, test)
    pair_labels = ("Train-Val", "Train-Test", "Val-Test")

    def collect(column):
        # Unique values of `column`, one set per split in (train, val, test) order.
        return [set(frame[column].unique().tolist()) for frame in frames]

    def overlap_sizes(groups):
        # Intersection sizes in the same order as `pair_labels`.
        first, second, third = groups
        return [
            len(first.intersection(second)),
            len(first.intersection(third)),
            len(second.intersection(third)),
        ]

    def report(kind, groups):
        for label, size in zip(pair_labels, overlap_sizes(groups)):
            logger.info(f"{label} {kind} intersection: {size}")

    def require_disjoint(groups):
        for size in overlap_sizes(groups):
            assert size == 0

    require_disjoint(collect("ID"))
    logger.info("Pass! No overlap in IDs")

    # TR overlap is only reported here; it is enforced further down when relevant.
    tr_seq_sets = collect("tr_sequence")
    tr_rep_sets = collect("tr_cluster_rep")
    report("TR", tr_seq_sets)
    report("TR Cluster Rep", tr_rep_sets)

    # Same for DNA overlap.
    dna_seq_sets = collect("dna_sequence")
    dna_rep_sets = collect("dna_cluster_rep")
    report("DNA", dna_seq_sets)
    report("DNA Cluster Rep", dna_rep_sets)

    if split_by != "dna":
        require_disjoint(tr_seq_sets)
        logger.info("Pass! No overlap in TR sequences")
        require_disjoint(tr_rep_sets)
        logger.info("Pass! No overlap in TR cluster reps")

    if split_by != "protein":
        require_disjoint(dna_seq_sets)
        logger.info("Pass! No overlap in DNA sequences")
        require_disjoint(dna_rep_sets)
        logger.info("Pass! No overlap in DNA cluster reps")
def augment_rc(df):
    """
    Append a reverse-complement copy of every row, doubling the dataset.

    For the copied rows the DNA sequence is replaced by its reverse complement,
    "_rc" is appended to the ID, and the comma-separated scores are reversed so
    they stay aligned with the flipped sequence.
    columns = ["ID","dna_sequence","tr_sequence","tr_cluster_rep","dna_cluster_rep", "scores","split"]

    Returns the original rows followed by the reverse-complement rows, with a
    fresh index.
    """

    def flip_scores(score_string):
        tokens = score_string.split(",")
        tokens.reverse()
        return ",".join(tokens)

    mirrored = df.copy(deep=True)
    mirrored["dna_sequence"] = mirrored["dna_sequence"].apply(get_reverse_complement)
    mirrored["ID"] = mirrored["ID"] + "_rc"
    mirrored["scores"] = mirrored["scores"].apply(flip_scores)

    stacked = pd.concat([df, mirrored])
    return stacked.reset_index(drop=True)
def _log_split_sizes(train, val, test, leaky_test, n_edges):
    """Log the row count and share of every split; return the combined row count."""
    total = sum([len(train), len(val), len(test), len(leaky_test)])
    logger.info(f"Length of train dataset: {len(train)} ({100*len(train)/total:.2f}%)")
    logger.info(f"Length of val dataset: {len(val)} ({100*len(val)/total:.2f}%)")
    logger.info(f"Length of test dataset: {len(test)} ({100*len(test)/total:.2f}%)")
    logger.info(f"Length of leaky_test dataset: {len(leaky_test)} ({100*len(leaky_test)/total:.2f}%)")
    logger.info(f"Total sequences = {total}. Same as edges size? {total==n_edges}")
    return total


def main(cfg: DictConfig):
    """
    Take a set of DNA clusters + protein clusters, and create the best possible splits into train/val/test.

    Reads from ``cfg.data_task``: ``input_data_path``, ``cluster_output_paths``
    (``protein`` and ``dna``), ``split_by``, ``test_trs``, ``test_dnas``,
    ``train_ratio``, ``val_ratio``, ``test_ratio``, ``augment_rc``,
    ``dna_map_path`` and ``split_out_dir``. All paths are resolved against the
    project root.

    Writes ``train.csv``, ``val.csv``, ``test.csv`` and ``leaky_test.csv`` to the
    split output directory and, when reverse-complement augmentation is enabled,
    a ``*_with_rc.json`` mapping of DNA sequence id to DNA sequence.

    Raises:
        ValueError: If ``cfg.data_task.split_by`` is not "dna", the only mode
            implemented here.
    """
    # Only DNA-based splitting is implemented below; any other mode used to
    # fall through and crash later on unbound split variables.
    if cfg.data_task.split_by != "dna":
        raise ValueError(
            f"Unsupported split_by={cfg.data_task.split_by!r}; only 'dna' is implemented."
        )

    # construct edges from training data
    edge_df = make_edges(
        processed_fimo_path=Path(root) / cfg.data_task.input_data_path,
        protein_cluster_path=Path(root) / cfg.data_task.cluster_output_paths.protein,
        dna_cluster_path=Path(root) / cfg.data_task.cluster_output_paths.dna,
    )

    # figure out if we actually even have a conflict
    total_proteins = len(edge_df["tr_seqid"].unique().tolist())
    total_protein_clusters = len(edge_df["tr_cluster_rep"].unique().tolist())
    no_protein_overlap = (total_proteins) == (total_protein_clusters)
    logger.info(f"All proteins are in their own clusters: {no_protein_overlap}")

    if cfg.data_task.test_trs or cfg.data_task.test_dnas:
        logger.info("Splitting with predefined trs/dnas reserved for test set")
        splits, kept_by_split = split_with_predefined_test(
            full_df=edge_df,
            split_names=("train", "val", "test"),
            test_trs=cfg.data_task.test_trs if cfg.data_task.test_trs else None,
            test_dnas=cfg.data_task.test_dnas if cfg.data_task.test_dnas else None,
            ratios=(0.8, 0.1, 0.1),
        )
        # Copy before labelling so the new column is never written into a
        # slice of another DataFrame.
        train = splits["train"].copy()
        train["split"] = "train"
        val = splits["val"].copy()
        val["split"] = "val"
        test = splits["test"].copy()
        test["split"] = "test"
        leaky_test = splits["leaky_test"].copy()
        leaky_test["split"] = "leaky_test"
    else:
        logger.info("Easy split: all proteins are in their own clusters.")
        dna_clusters = edge_df["dna_cluster_rep"].unique().tolist()
        dna_assign, kept_by_split = split_bipartite_fast(
            dna_clusters,
            split_names=("train", "val", "test"),
            ratios=(
                cfg.data_task.train_ratio,
                cfg.data_task.val_ratio,
                cfg.data_task.test_ratio,
            ),
        )
        # assign datapoints to cluster by their DNA cluster rep
        edge_df["split"] = edge_df["dna_cluster_rep"].map(dna_assign)
        train = edge_df.loc[edge_df["split"] == "train"].reset_index(drop=True)
        val = edge_df.loc[edge_df["split"] == "val"].reset_index(drop=True)
        test = edge_df.loc[edge_df["split"] == "test"].reset_index(drop=True)
        # No leaky rows in this mode; keep an empty frame with the same columns.
        leaky_test = pd.DataFrame(columns=edge_df.columns)

    # Print ratios: hopefully close to desired (e.g. 80/10/10)
    print_split_ratios(kept_by_split)

    # Make train, val, test sets
    # make sure no ID is duplicate
    assert len(edge_df["ID"].unique()) == len(edge_df)
    split_cols = [
        "ID",
        "dna_sequence",
        "tr_sequence",
        "tr_cluster_rep",
        "dna_cluster_rep",
        "scores",
        "split",
    ]
    train = train[split_cols].copy()
    val = val[split_cols].copy()
    test = test[split_cols].copy()
    leaky_test = leaky_test[split_cols].copy()

    # ensure there is no overlap
    check_validity(train, val, test, split_by=cfg.data_task.split_by)
    _log_split_sizes(train, val, test, leaky_test, len(edge_df))

    # Number of distinct DNA sequences before augmentation, for the map-size log below.
    og_unique_dna = pd.concat([train, val, test, leaky_test])
    og_unique_dna = len(og_unique_dna["dna_sequence"].unique())

    ## Now do RC data augmentation if asked
    if cfg.data_task.augment_rc:
        train = augment_rc(train)
        val = augment_rc(val)
        test = augment_rc(test)
        leaky_test = augment_rc(leaky_test)
        logger.info("Added reverse complement sequences to train, val, and test (and leaky test)")
        check_validity(train, val, test, split_by=cfg.data_task.split_by)
        _log_split_sizes(train, val, test, leaky_test, len(edge_df))

        # since we've added all these new DNA sequences, we do need a new mapping of seq id to dna sequence
        all_data = pd.concat([train, val, test, leaky_test])
        # The DNA sequence id is taken to be everything after the first "_" of the ID.
        all_data["dna_seqid"] = all_data["ID"].str.split("_", n=1, expand=True)[1]
        dna_dict = dict(zip(all_data["dna_seqid"], all_data["dna_sequence"]))
        assert len(dna_dict) == len(all_data.drop_duplicates(["dna_sequence"]))
        new_map_path = str(Path(root) / cfg.data_task.dna_map_path).replace(
            ".json", "_with_rc.json"
        )
        with open(new_map_path, "w") as f:
            json.dump(dna_dict, f, indent=2)
        logger.info(
            f"Saved DNA maps with reverse complements (len {len(dna_dict)}=2*original map of len {og_unique_dna}=={len(dna_dict)==2*og_unique_dna}) to {new_map_path}"
        )

    # create the output dir
    split_out_dir = Path(root) / cfg.data_task.split_out_dir
    os.makedirs(split_out_dir, exist_ok=True)

    outputs = {"train": train, "val": val, "test": test, "leaky_test": leaky_test}

    # add binary_scores to allow other training modes
    # (the column name is kept exactly as before, since it ends up in the CSV header)
    for frame in outputs.values():
        frame["fimo_binary_sores"] = frame["scores"].apply(lambda x: convert_scores(x, mode=1))

    # select final cols and save
    split_final_cols = ["ID", "dna_sequence", "tr_sequence", "scores", "fimo_binary_sores", "split"]
    for name, frame in outputs.items():
        frame[split_final_cols].to_csv(split_out_dir / f"{name}.csv", index=False)
    logger.info(f"Saved all splits to {split_out_dir}")