| import matplotlib.pyplot as plt |
| import matplotlib.patches as mpatches |
| import seaborn as sns |
| import pandas as pd |
| import numpy as np |
| import os |
| import matplotlib.colors as mcolors |
| from fuson_plm.utils.visualizing import set_font |
|
|
def _fo_puncta_literature_table(accuracy, precision, recall, f1_score, auroc, auprc):
    """Return a one-row metrics table for the 'fo_puncta_ml_literature' model.

    The row has Model Type 'fo_puncta_ml', no epoch (NaN) and the supplied
    metric values, in the column order Accuracy, Precision, Recall,
    F1 Score, AUROC, AUPRC.
    """
    return pd.DataFrame(data={
        'Model Type': ['fo_puncta_ml'],
        'Model Name': ['fo_puncta_ml_literature'],
        'Model Epoch': [np.nan],
        'Accuracy': [accuracy],
        'Precision': [precision],
        'Recall': [recall],
        'F1 Score': [f1_score],
        'AUROC': [auroc],
        'AUPRC': [auprc],
    })


# Hard-coded reference metrics for FO-Puncta-ML. The suffixes suggest a
# training-set evaluation at threshold 0.31 and a verification-set evaluation
# at threshold 0.83 -- NOTE(review): confirm the source of these numbers.
fo_puncta_db_training_thresh31 = _fo_puncta_literature_table(
    accuracy=0.81, precision=0.78, recall=0.98,
    f1_score=0.87, auroc=0.88, auprc=0.94,
)

fo_puncta_db_verification_thresh83 = _fo_puncta_literature_table(
    accuracy=0.79, precision=0.81, recall=0.89,
    f1_score=0.85, auroc=0.73, auprc=0.82,
)
|
|
| |
def lengthen_model_name(row):
    """Return the row's model name, with ``_e{epoch}`` appended when relevant.

    Names containing 'esm' or 'puncta' come back unchanged. Every other name
    has the row's 'Model Epoch' value appended as ``_e{epoch}``.

    Args:
        row: mapping (e.g. a DataFrame row) with 'Model Name' and 'Model Epoch'.
    """
    model_name, model_epoch = row['Model Name'], row['Model Epoch']

    keep_unsuffixed = 'esm' in model_name or 'puncta' in model_name
    if keep_unsuffixed:
        return model_name
    return f'{model_name}_e{model_epoch}'
|
|
| |
def shorten_model_name(row):
    """Return a compact display label for a model row.

    Known baseline names map to fixed labels (ESM-2, FO-Puncta-ML, ProtT5).
    Any other name is parsed into ``{prob_type}_{layers}L_{dt}_e{epoch}``:
    - ``prob_type`` is 'snp' or 'uni', depending on whether the name contains
      'snp_' or 'uniform_'.
    - ``layers`` is the '_'-separated token immediately before 'layers'.
    - ``dt`` is everything after the first '-' that follows 'mask'.

    Args:
        row: mapping (e.g. a DataFrame row) with 'Model Name' and 'Model Epoch'.

    Raises:
        ValueError: if a non-baseline name contains neither 'snp_' nor 'uniform_'.
        IndexError: if a non-baseline name has no 'mask' segment followed by '-'.
    """
    name = row['Model Name']
    epoch = row['Model Epoch']

    if 'esm' in name:
        return 'ESM-2-650M'
    if name == 'fo_puncta_ml':
        return 'FO-Puncta-ML'
    if name == 'fo_puncta_ml_literature':
        return 'FO-Puncta-ML Lit'
    if name == "prot_t5_xl_half_uniref50_enc":
        return 'ProtT5-XL-U50'

    if 'snp_' in name:
        prob_type = 'snp'
    elif 'uniform_' in name:
        prob_type = 'uni'
    else:
        # Without this branch prob_type stays unbound and the function dies with
        # an unhelpful UnboundLocalError at the return statement.
        raise ValueError(
            f"Cannot shorten model name {name!r}: expected it to contain 'snp_' or 'uniform_'"
        )

    layers = name.split('layers')[0].split('_')[-1]
    dt = name.split('mask')[1].split('-', 1)[1]

    return f'{prob_type}_{layers}L_{dt}_e{epoch}'
| |
def _enlarge_legend(legend):
    """Set 22pt label text on *legend* and enlarge its bar/marker swatches."""
    for text in legend.get_texts():
        text.set_fontsize(22)
    # ``legend_handles`` exists since Matplotlib 3.7; ``legendHandles`` was
    # deprecated in 3.7 and removed in 3.9, so only fall back on old versions.
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    for handle in handles:
        if isinstance(handle, mpatches.Patch):
            handle.set_height(15)
            handle.set_width(20)
        elif hasattr(handle, '_sizes'):
            handle._sizes = [200]


def make_final_bar(dataframe, title, save_path):
    """Draw a horizontal grouped bar chart (one group per metric) and save it.

    Args:
        dataframe: long-format table with columns 'Name', 'Metric' and 'Value'.
            It is copied, not modified. Only the names 'FOdb', 'ProtT5-XL-U50',
            'ESM-2-650M' and 'FusOn-pLM' are plotted, in that order.
        title: figure title.
        save_path: destination passed to ``savefig``.
    """
    set_font()
    df = dataframe.copy(deep=True)

    # Wide table: rows are metrics, columns are embeddings, in a fixed display order.
    pivot_df = df.pivot(index='Metric', columns='Name', values='Value')
    ordered_columns = [x for x in ['FOdb', 'ProtT5-XL-U50', 'ESM-2-650M', 'FusOn-pLM']
                       if x in pivot_df.columns]
    pivot_df = pivot_df[ordered_columns]

    engineered_embeddings = ['FOdb']

    # Reversed because barh draws row 0 at the bottom; this puts Accuracy on top.
    metric_order = ['Accuracy', 'Precision', 'Recall', 'F1', 'AUROC'][::-1]
    pivot_df = pivot_df.reindex(metric_order)

    fig, ax = plt.subplots(figsize=(8, 6), dpi=300)

    bar_width = 0.2
    indices = np.arange(len(pivot_df))

    color_map = {
        'FOdb': "#E69F00",
        'ESM-2-650M': "#F0E442",
        'FusOn-pLM': "#FF69B4",
        'ProtT5-XL-U50': "#00ccff",
    }

    # (handle, label) pairs collected together so legend labels always match handles.
    engineered_entries = []
    learned_entries = []
    for i, name in enumerate(ordered_columns):
        bars = ax.barh(indices + i * bar_width, pivot_df[name], bar_width,
                       label=name, color=color_map[name])
        target = engineered_entries if name in engineered_embeddings else learned_entries
        target.append((bars[0], name))

    ax.set_xlabel('Value', fontsize=44)
    # Centre each metric's tick on its group of bars for any number of series
    # (equals bar_width * 1.5 when all four embeddings are present).
    ax.set_yticks(indices + bar_width * (len(ordered_columns) - 1) / 2)
    ax.set_xlim([0, 1])
    ax.set_yticklabels(pivot_df.index)
    ax.set_title(title, fontsize=44)
    # tick_params also applies to tick labels created later (e.g. after a relayout),
    # unlike setting the size on the current label objects one by one.
    ax.tick_params(axis='both', labelsize=32)

    legend_specs = [
        (engineered_entries, 0.4, "Engineered Embeddings"),
        (learned_entries, 0.6, "Learned Embeddings"),
    ]
    legends = []
    for entries, y_anchor, legend_title in legend_specs:
        if not entries:
            continue
        handles, labels = zip(*entries[::-1])
        legends.append(fig.legend(
            list(handles),
            list(labels),
            loc='center left',
            bbox_to_anchor=(1, y_anchor),
            title=legend_title,
            title_fontsize=24))

    for legend in legends:
        # NOTE(review): fig.legend already attaches the legend to the figure;
        # re-adding it to the axes is kept as-is because it may influence
        # tight_layout's bounding boxes. Confirm before removing.
        ax.add_artist(legend)
        _enlarge_legend(legend)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure; this function runs once per chart in a batch.
    plt.close(fig)
|
|
def prepare_data_for_bar(results_dir, task, split, thresh=None):
    """Load one results CSV and reshape it into long format for ``make_final_bar``.

    Reads ``{results_dir}/{task}_{split}FOs_results.csv``. When *thresh* is
    given, it reads ``{results_dir}/{task}_{split}FOs_{thresh}thresh_results.csv``
    instead. Only the rows for 'best', 'fo_puncta_ml', 'esm2_t33_650M_UR50D' and
    'prot_t5_xl_half_uniref50_enc' are kept.

    Args:
        results_dir: directory containing the results CSVs.
        task: task prefix of the CSV file name.
        split: split name embedded in the CSV file name.
        thresh: optional threshold embedded in the CSV file name.

    Returns:
        A ``(data, image_save_path)`` tuple.
        - ``data`` has columns Name, Metric and Value. It holds one row per
          (metric, model), grouped by metric in the order Accuracy, Precision,
          Recall, F1, AUROC. Names are mapped to display labels.
        - ``image_save_path`` is ``{results_dir}/figures/<csv stem>_barchart.png``.
    """
    fname = f"{task}_{split}FOs_results.csv"
    if thresh is not None:
        fname = f"{task}_{split}FOs_{thresh}thresh_results.csv"
    image_save_path = results_dir + '/figures/' + fname.split('_results.csv')[0] + '_barchart.png'

    data = pd.read_csv(f"{results_dir}/{fname}")
    data = data.loc[
        data['Model Name'].isin(['best',
                                 'fo_puncta_ml',
                                 'esm2_t33_650M_UR50D',
                                 'prot_t5_xl_half_uniref50_enc'])
    ]

    # Chart metric label -> source CSV column.
    metric_to_column = {
        'Accuracy': 'Accuracy',
        'Precision': 'Precision',
        'Recall': 'Recall',
        'F1': 'F1 Score',
        'AUROC': 'AUROC',
    }
    names = data['Model Name'].tolist()
    # Built from the actual row count: the previous hard-coded 20-entry Metric
    # list broke whenever the filter did not yield exactly four models.
    data = pd.DataFrame(data={
        'Name': names * len(metric_to_column),
        'Metric': [metric for metric in metric_to_column for _ in names],
        'Value': [value
                  for column in metric_to_column.values()
                  for value in data[column].tolist()],
    })

    rename_dict = {'fo_puncta_ml': 'FOdb',
                   'esm2_t33_650M_UR50D': 'ESM-2-650M',
                   'best': 'FusOn-pLM',
                   'prot_t5_xl_half_uniref50_enc': 'ProtT5-XL-U50'}
    data['Name'] = data['Name'].map(rename_dict)
    return data, image_save_path
| |
def make_all_final_bar_charts(results_dir):
    """Render the verification-split bar charts for the formation, nucleus and cytoplasm tasks.

    For each chart, the long-format data is first written with values rounded
    to 3 decimals. It goes to the chart's path with '.png' replaced by
    '_source_data.csv'. The chart itself is then drawn from the unrounded data.
    """
    # (task, threshold used in the results file name, chart title)
    chart_specs = (
        ("formation", 0.83, "Puncta Propensity"),
        ("nucleus", None, "Nucleus Localization"),
        ("cytoplasm", None, "Cytoplasm Localization"),
    )
    for task, threshold, chart_title in chart_specs:
        chart_data, figure_path = prepare_data_for_bar(
            results_dir, task, "verification", thresh=threshold
        )

        source_table = chart_data.copy(deep=True)
        source_table["Value"] = source_table["Value"].round(3)
        source_table.to_csv(figure_path.replace(".png", "_source_data.csv"), index=False)

        make_final_bar(chart_data, chart_title, figure_path)
|
|
def main():
    """Ensure results/final/figures exists, then render every final bar chart."""
    output_root = "results/final"
    figures_dir = f"{output_root}/figures"
    os.makedirs(figures_dir, exist_ok=True)
    make_all_final_bar_charts(output_root)


if __name__ == '__main__':
    main()