| | import pandas as pd |
| | import os |
| | import pickle |
| | from fuson_plm.data.config import SPLIT |
| | from fuson_plm.utils.logging import log_update, open_logfile |
| | from fuson_plm.utils.splitting import split_clusters, check_split_validity |
| | from fuson_plm.utils.visualizing import set_font, visualize_splits |
| | |
def get_benchmark_data(fuson_db_path, clusters):
    """
    Identify benchmark sequences and the clusters that contain them.

    Args:
        fuson_db_path (str): path to the FusOn database CSV; must contain
            'aa_seq', 'seq_id', and 'benchmark' columns (benchmark rows are
            those where 'benchmark' is non-null).
        clusters (pd.DataFrame): clustering output with 'member seq',
            'member seq_id', and 'representative seq_id' columns.

    Returns:
        tuple[list, list]: (benchmark_cluster_reps, benchmark_sequences) where
            benchmark_cluster_reps holds the representative seq_ids of every
            cluster containing at least one benchmark sequence, and
            benchmark_sequences holds the benchmark amino-acid sequences that
            survived clustering.
    """
    fuson_db = pd.read_csv(fuson_db_path)

    # All rows flagged for benchmarking, whether or not they were clustered.
    original_benchmark_sequences = fuson_db.loc[fuson_db['benchmark'].notna()]

    # Benchmark sequences that actually appear in the clustering output
    # (sequences over config.CLUSTERING['max_seq_length'] were excluded upstream).
    # Note: Series.isin accepts a Series directly; no list() conversion needed.
    benchmark_sequences = fuson_db.loc[
        (fuson_db['benchmark'].notna()) &
        (fuson_db['aa_seq'].isin(clusters['member seq']))
    ]['aa_seq'].to_list()

    # seq_ids of every benchmark row, used to locate their clusters.
    benchmark_seq_ids = fuson_db.loc[fuson_db['benchmark'].notna()]['seq_id']

    # Any cluster containing a benchmark member is reserved wholesale for the test set.
    benchmark_cluster_reps = clusters.loc[clusters['member seq_id'].isin(benchmark_seq_ids)]['representative seq_id'].unique().tolist()
    log_update(f"\t{len(benchmark_sequences)}/{len(original_benchmark_sequences)} benchmarking sequences (only those shorter than config.CLUSTERING[\'max_seq_length\']) were grouped into {len(benchmark_cluster_reps)} clusters. These will be reserved for the test set.")

    return benchmark_cluster_reps, benchmark_sequences
| |
|
def get_training_dfs(train, val, test):
    """
    Convert cluster-split dataframes into sequence-only dataframes for ESM finetuning.

    Args:
        train, val, test (pd.DataFrame): cluster splits with columns
            'representative seq_id', 'member seq_id', 'representative seq',
            and 'member seq' (plus any extra columns, which are kept).

    Returns:
        tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: the same three splits
            with clustering bookkeeping columns dropped and 'member seq'
            renamed to 'sequence'.
    """
    log_update('\nMaking dataframes for ESM finetuning...')

    def _to_training_df(df):
        # Drop clustering bookkeeping; keep only the sequence (and any extras).
        return df.drop(
            columns=['representative seq_id', 'member seq_id', 'representative seq']
        ).rename(columns={'member seq': 'sequence'})

    # Apply the identical transformation to all three splits (was copy-pasted 3x).
    return _to_training_df(train), _to_training_df(val), _to_training_df(test)
| | |
def main():
    """
    Split clustered FusOn-DB sequences into train/val/test sets.

    Loads splitting parameters from config.SPLIT, reserves every cluster that
    contains a benchmark sequence for the test set, splits the remaining
    clusters, validates and visualizes the splits, and writes both the
    cluster-level splits and the sequence-only training dataframes to
    ../data/splits/. All progress is logged to splitting_log.txt.
    """
    LOG_PATH = "splitting_log.txt"
    FUSON_DB_PATH = SPLIT.FUSON_DB_PATH
    CLUSTER_OUTPUT_PATH = SPLIT.CLUSTER_OUTPUT_PATH
    RANDOM_STATE_1 = SPLIT.RANDOM_STATE_1
    TEST_SIZE_1 = SPLIT.TEST_SIZE_1
    RANDOM_STATE_2 = SPLIT.RANDOM_STATE_2
    TEST_SIZE_2 = SPLIT.TEST_SIZE_2

    # Use the project-wide plotting font for the split visualizations.
    set_font()

    with open_logfile(LOG_PATH):
        log_update("Loaded data-splitting configurations from config.py")
        SPLIT.print_config(indent='\t')

        # Create the directory the CSVs below are actually written to.
        # (Was os.makedirs("splits"), which made an unused local directory and
        # left the real output path uncreated.)
        os.makedirs("../data/splits", exist_ok=True)

        # Load clustering results; each unique representative defines one cluster.
        clusters = pd.read_csv(CLUSTER_OUTPUT_PATH)
        reps = clusters['representative seq_id'].unique().tolist()
        log_update(f"\nPreparing clusters...\n\tCollected {len(reps)} clusters for splitting")

        # Clusters containing benchmark sequences are forced into the test set.
        benchmark_cluster_reps, benchmark_sequences = get_benchmark_data(FUSON_DB_PATH, clusters)

        # Two-stage split of cluster representatives into train/val/test.
        splits = split_clusters(reps, benchmark_cluster_reps=benchmark_cluster_reps,
                                random_state_1=RANDOM_STATE_1, random_state_2=RANDOM_STATE_2,
                                test_size_1=TEST_SIZE_1, test_size_2=TEST_SIZE_2)
        X_train = splits['X_train']
        X_val = splits['X_val']
        X_test = splits['X_test']

        # Expand representative-level splits back to full cluster membership.
        train_clusters = clusters.loc[clusters['representative seq_id'].isin(X_train)].reset_index(drop=True)
        val_clusters = clusters.loc[clusters['representative seq_id'].isin(X_val)].reset_index(drop=True)
        test_clusters = clusters.loc[clusters['representative seq_id'].isin(X_test)].reset_index(drop=True)

        # Sanity-check: no leakage between splits, benchmarks in test only.
        check_split_validity(train_clusters, val_clusters, test_clusters, benchmark_sequences=benchmark_sequences)

        # Log sequence-length ranges per split.
        min_train_seqlen = min(train_clusters['member seq'].str.len())
        max_train_seqlen = max(train_clusters['member seq'].str.len())
        min_val_seqlen = min(val_clusters['member seq'].str.len())
        max_val_seqlen = max(val_clusters['member seq'].str.len())
        min_test_seqlen = min(test_clusters['member seq'].str.len())
        max_test_seqlen = max(test_clusters['member seq'].str.len())
        log_update(f"\nLength breakdown summary...\n\tTrain: min seq length = {min_train_seqlen}, max seq length = {max_train_seqlen}")
        log_update(f"\tVal: min seq length = {min_val_seqlen}, max seq length = {max_val_seqlen}")
        log_update(f"\tTest: min seq length = {min_test_seqlen}, max seq length = {max_test_seqlen}")

        # Generate split-composition plots.
        visualize_splits(train_clusters, val_clusters, test_clusters, benchmark_cluster_reps)

        # Save cluster-level splits. Log messages now report the real paths
        # (previous messages incorrectly said splitting/ and splits/).
        train_clusters.to_csv("../data/splits/train_cluster_split.csv", index=False)
        val_clusters.to_csv("../data/splits/val_cluster_split.csv", index=False)
        test_clusters.to_csv("../data/splits/test_cluster_split.csv", index=False)
        log_update('\nSaved cluster splits to ../data/splits/train_cluster_split.csv, ../data/splits/val_cluster_split.csv, ../data/splits/test_cluster_split.csv')
        cols = ','.join(list(train_clusters.columns))
        log_update(f'\tColumns: {cols}')

        # Save sequence-only dataframes for ESM finetuning.
        train_df, val_df, test_df = get_training_dfs(train_clusters, val_clusters, test_clusters)
        train_df.to_csv("../data/splits/train_df.csv", index=False)
        val_df.to_csv("../data/splits/val_df.csv", index=False)
        test_df.to_csv("../data/splits/test_df.csv", index=False)
        log_update('\nSaved training dataframes to ../data/splits/train_df.csv, ../data/splits/val_df.csv, ../data/splits/test_df.csv')
        cols = ','.join(list(train_df.columns))
        log_update(f'\tColumns: {cols}')
| |
|
# Script entry point: run the full data-splitting pipeline.
if __name__ == "__main__":
    main()