| import pandas as pd |
| import numpy as np |
| import pickle |
| from sklearn.manifold import TSNE |
| import matplotlib.font_manager as fm |
| from matplotlib.font_manager import FontProperties |
| import matplotlib.pyplot as plt |
| import matplotlib.gridspec as gridspec |
| import matplotlib.patches as patches |
| import seaborn as sns |
| import umap |
| import os |
|
|
| from fuson_plm.benchmarking.embed import embed_dataset_for_benchmark |
| import fuson_plm.benchmarking.embedding_exploration.config as config |
| from fuson_plm.utils.visualizing import set_font |
| from fuson_plm.utils.constants import TCGA_CODES, FODB_CODES, VALID_AAS, DELIMITERS |
| from fuson_plm.utils.logging import get_local_time, open_logfile, log_update, print_configpy |
|
|
|
|
def get_dimred_embeddings(embeddings, dimred_type="umap"):
    """
    Reduces embeddings to two dimensions with the requested method.

    Parameters:
    - embeddings: Collection of embedding vectors, forwarded unchanged to the
      selected reducer.
    - dimred_type (str): "umap" (default) or "tsne".

    Returns:
    - The 2D coordinates returned by get_umap_embeddings / get_tsne_embeddings.

    Raises:
    - ValueError: If dimred_type is not a supported method.
    """
    if dimred_type == "umap":
        return get_umap_embeddings(embeddings)
    if dimred_type == "tsne":
        return get_tsne_embeddings(embeddings)

    # An unknown type used to fall through and return None, which only surfaced
    # later as empty/NaN coordinate columns. Fail loudly at the source instead.
    raise ValueError(
        f"Unsupported dimred_type {dimred_type!r}; expected 'umap' or 'tsne'"
    )
|
|
def get_tsne_embeddings(embeddings, perplexity=5, random_state=42):
    """
    Projects embeddings to 2D with t-SNE.

    Parameters:
    - embeddings: Collection of embedding vectors; converted with np.asarray.
    - perplexity: t-SNE perplexity. Defaults to 5, the value that was
      previously hard-coded.
      NOTE(review): scikit-learn is understood to require the perplexity to be
      smaller than the number of samples - confirm for very small datasets.
    - random_state: Seed passed to TSNE so repeated runs give the same layout.
      Defaults to 42, the value that was previously hard-coded.

    Returns:
    - The output of TSNE.fit_transform (two components per input row).
    """
    embeddings = np.asarray(embeddings)
    tsne = TSNE(n_components=2, random_state=random_state, perplexity=perplexity)
    return tsne.fit_transform(embeddings)
|
|
def get_umap_embeddings(embeddings, n_neighbors=15, metric='euclidean', random_state=42):
    """
    Projects embeddings to 2D with UMAP.

    Parameters:
    - embeddings: Collection of embedding vectors; converted with np.asarray.
    - n_neighbors: Neighborhood size passed to umap.UMAP. Defaults to 15, the
      value that was previously hard-coded.
    - metric: Distance metric passed to umap.UMAP. Defaults to 'euclidean',
      the value that was previously hard-coded.
    - random_state: Seed passed to umap.UMAP. Defaults to 42, the value that
      was previously hard-coded.

    Returns:
    - The output of UMAP.fit_transform (two components per input row).
    """
    embeddings = np.asarray(embeddings)
    umap_model = umap.UMAP(
        n_components=2,
        random_state=random_state,
        n_neighbors=n_neighbors,
        metric=metric,
    )
    return umap_model.fit_transform(embeddings)
|
|
def plot_half_filled_circle(ax, x, y, left_color, right_color, size=100):
    """
    Plots a circle filled in halves with specified colors.

    Parameters:
    - ax: Matplotlib axis to draw on.
    - x, y: Coordinates of the marker.
    - left_color: Color of the left half.
    - right_color: Color of the right half.
    - size: Size of the marker; the wedge radius is sqrt(size) / 100.
    """
    radius = (size ** 0.5) / 100

    # Use facecolor/edgecolor rather than color=...: per the matplotlib Patch
    # documentation, `color` sets BOTH face and edge, so combining it with
    # ec="black" silently discarded the black outline.
    # Wedge(center, r, theta1, theta2): 90 -> 270 is the left half,
    # 270 -> 90 the right half.
    left_half = patches.Wedge(
        (x, y), radius, 90, 270, facecolor=left_color, edgecolor="black"
    )
    right_half = patches.Wedge(
        (x, y), radius, 270, 90, facecolor=right_color, edgecolor="black"
    )

    ax.add_patch(left_half)
    ax.add_patch(right_half)
|
|
def plot_umap_scatter_tftf_kk(df, filename="umap.png", dimred_type=None):
    """
    Plots a 2D scatterplot of reduced embedding coordinates colored by 'fusion_type'.
    Only for TF::TF and Kinase::Kinase fusions.

    Parameters:
    - df (pd.DataFrame): DataFrame containing '<dimred_type>1', '<dimred_type>2'
      (e.g. 'umap1', 'umap2') and 'fusion_type' columns. Every 'fusion_type'
      value must be "TF::TF" or "Kinase::Kinase".
    - filename (str): Path the figure is saved to.
    - dimred_type (str or None): "umap" or "tsne". When None (default) it is
      inferred from the columns present in df, preferring "umap".

    Raises:
    - KeyError: If a coordinate column is missing or a 'fusion_type' value is
      not one of the two supported categories.
    """
    set_font()

    colors = {
        "TF": "pink",
        "Kinase": "orange",
    }
    marker_colors = {
        "TF::TF": colors["TF"],
        "Kinase::Kinase": colors["Kinase"],
    }
    axis_names = {"umap": "UMAP", "tsne": "t-SNE"}

    # Callers in this file pass only the '<dimred_type>1/2' columns, so reading
    # 'umap1'/'umap2' unconditionally failed for t-SNE data. Infer the prefix
    # from the frame when it is not given explicitly.
    if dimred_type is None:
        dimred_type = next(
            (name for name in axis_names if f"{name}1" in df.columns), "umap"
        )
    x_col, y_col = f"{dimred_type}1", f"{dimred_type}2"
    axis_label = axis_names.get(dimred_type, dimred_type)

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.set_xlim(df[x_col].min() - 1, df[x_col].max() + 1)
    ax.set_ylim(df[y_col].min() - 1, df[y_col].max() + 1)

    # One scatter call for all points instead of one call (and one artist) per
    # row; the list lookup keeps the KeyError on unexpected fusion types.
    point_colors = [marker_colors[fusion_type] for fusion_type in df["fusion_type"]]
    ax.scatter(
        df[x_col], df[y_col],
        c=point_colors, s=15, edgecolors="black", linewidth=0.5,
    )

    # Build the legend from the same mapping used for the points so the two
    # cannot drift apart.
    legend_elements = [
        patches.Patch(facecolor=color, edgecolor="black", label=label)
        for label, color in marker_colors.items()
    ]
    ax.legend(handles=legend_elements, title="Fusion Type", fontsize=16, title_fontsize=16)

    ax.set_xlabel(f"{axis_label} 1", fontsize=20)
    ax.set_ylabel(f"{axis_label} 2", fontsize=20)
    ax.set_title("FusOn-pLM-embedded Transcription Factor and Kinase Fusions", fontsize=20)
    fig.tight_layout()

    # Save before show(): on interactive backends the figure may be gone once
    # the window is closed.
    fig.savefig(filename, dpi=300)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
| |
def plot_umap_scatter_half_filled(df, filename="umap.png"):
    """
    Draws every row of df as a two-tone circle on a UMAP scatterplot.

    The left half of each circle is colored by the head gene's category and the
    right half by the tail gene's category, as encoded in 'fusion_type'
    ("<head category>::<tail category>").

    Parameters:
    - df (pd.DataFrame): DataFrame containing 'umap1', 'umap2' and 'fusion_type'
      columns. Categories must be "TF", "Kinase" or "Other".
    - filename (str): Path the figure is saved to.
    """
    colors = {"TF": "pink", "Kinase": "orange", "Other": "grey"}

    # Every head/tail combination of the three categories (nine keys in total).
    marker_colors = {
        f"{head}::{tail}": {"left": colors[head], "right": colors[tail]}
        for head in colors
        for tail in colors
    }

    fig, ax = plt.subplots(figsize=(10, 8))
    # Pad the data range by one unit on every side.
    ax.set_xlim(df["umap1"].min() - 1, df["umap1"].max() + 1)
    ax.set_ylim(df["umap2"].min() - 1, df["umap2"].max() + 1)

    for fusion_type, x, y in zip(df["fusion_type"], df["umap1"], df["umap2"]):
        halves = marker_colors[fusion_type]
        plot_half_filled_circle(ax, x, y, halves["left"], halves["right"], size=100)

    legend_elements = [
        patches.Patch(facecolor=color, edgecolor="black", label=category)
        for category, color in colors.items()
    ]
    ax.legend(handles=legend_elements, title="Type")

    plt.xlabel("UMAP 1")
    plt.ylabel("UMAP 2")
    plt.title("UMAP Scatter Plot")
    plt.tight_layout()

    plt.savefig(filename, dpi=300)
    plt.show()
|
|
def get_gene_type(gene, d):
    """
    Classifies a gene as 'Kinase', 'TF' or 'Other'.

    Parameters:
    - gene: Gene symbol to look up.
    - d (dict): Maps gene symbol -> 'kinase' or 'tf'.

    Returns:
    - 'Kinase' if d[gene] == 'kinase', 'TF' if d[gene] == 'tf', otherwise
      'Other'. 'Other' covers both genes missing from d and genes whose
      annotation is neither of the two expected values, so the function never
      returns None (which would turn the derived 'fusion_type' string into NaN).
    """
    annotation = d.get(gene)
    if annotation == 'kinase':
        return 'Kinase'
    if annotation == 'tf':
        return 'TF'
    return 'Other'
| |
def get_tf_and_kinase_fusions_dataset():
    """
    Builds the table of TF::TF and Kinase::Kinase fusion sequences.

    Reads gene annotations ('Gene', 'Kinase or TF' columns) from
    data/salokas_2020_tableS3.csv and fusions ('fusiongenes', 'aa_seq' columns)
    from ../../data/blast/fuson_ht_db.csv, labels the head and tail gene of each
    fusion with get_gene_type, and keeps the fusions whose head and tail are
    both TFs or both kinases.

    Returns:
    - pd.DataFrame with columns 'sequence', 'ID' (the fusion gene name),
      'fusion_type' ("TF::TF" or "Kinase::Kinase") and 'type' (always
      'fusion'). TF::TF rows come first, followed by Kinase::Kinase rows.
    """
    tf_kinase_parts = pd.read_csv("data/salokas_2020_tableS3.csv")
    print(tf_kinase_parts)
    ht_tf_kinase_dict = dict(zip(tf_kinase_parts['Gene'], tf_kinase_parts['Kinase or TF']))

    fuson_ht_db = pd.read_csv("../../data/blast/fuson_ht_db.csv")
    # 'fusiongenes' is "<head>::<tail>".
    fuson_ht_db[['hg', 'tg']] = fuson_ht_db['fusiongenes'].str.split("::", expand=True)

    fuson_ht_db['hg_type'] = fuson_ht_db['hg'].apply(lambda x: get_gene_type(x, ht_tf_kinase_dict))
    fuson_ht_db['tg_type'] = fuson_ht_db['tg'].apply(lambda x: get_gene_type(x, ht_tf_kinase_dict))
    fuson_ht_db['fusion_type'] = fuson_ht_db['hg_type'] + '::' + fuson_ht_db['tg_type']
    fuson_ht_db['type'] = 'fusion'

    # The category list used to be derived from value_counts().reset_index()['index']
    # and then immediately overwritten; that dead line raises KeyError on
    # pandas >= 2.0 (the column is no longer called 'index'), so only the fixed
    # list is kept.
    categories = ["TF::TF", "Kinase::Kinase"]
    print(categories)

    # Concatenate per category (rather than one isin filter) so rows stay
    # grouped in the order of `categories`.
    plot_df = pd.concat(
        [fuson_ht_db.loc[fuson_ht_db['fusion_type'] == cat] for cat in categories],
        axis=0,
    ).reset_index(drop=True)

    print(plot_df['fusion_type'].value_counts())

    plot_df = plot_df[['aa_seq', 'fusiongenes', 'fusion_type', 'type']].rename(
        columns={'aa_seq': 'sequence', 'fusiongenes': 'ID'}
    )

    return plot_df
|
|
def make_tf_and_kinase_fusions_plot(seqs_with_embeddings, savedir='', dimred_type='umap'):
    """
    Saves the source data and the scatterplot for the TF/kinase fusion embeddings.

    Parameters:
    - seqs_with_embeddings (pd.DataFrame): Must contain '<dimred_type>1',
      '<dimred_type>2', 'sequence', 'fusion_type' and 'ID' columns.
    - savedir (str): Output directory; created if missing. An empty string
      (default) writes into the current working directory.
    - dimred_type (str): Prefix of the coordinate columns and of the output
      file names, e.g. 'umap' or 'tsne'.

    Writes:
    - <savedir>/<dimred_type>_tf_and_kinase_fusions_source_data.csv
    - <savedir>/<dimred_type>_tf_and_kinase_fusions_visualization.png
    """
    fuson_db = pd.read_csv("../../data/fuson_db.csv")
    seq_id_dict = dict(zip(fuson_db['aa_seq'], fuson_db['seq_id']))

    columns = [f'{dimred_type}1', f'{dimred_type}2', 'sequence', 'fusion_type', 'ID']
    # .copy() so that adding 'seq_id' does not write into a slice of the
    # caller's frame (SettingWithCopyWarning).
    data = seqs_with_embeddings[columns].copy()
    data['seq_id'] = data['sequence'].map(seq_id_dict)

    # os.makedirs('') raises, and f"{''}/file" would point at the filesystem
    # root, so the empty default is treated as "current directory".
    if savedir:
        os.makedirs(savedir, exist_ok=True)

    source_data_path = os.path.join(
        savedir, f"{dimred_type}_tf_and_kinase_fusions_source_data.csv"
    )
    figure_path = os.path.join(
        savedir, f"{dimred_type}_tf_and_kinase_fusions_visualization.png"
    )
    data.to_csv(source_data_path, index=False)
    plot_umap_scatter_tftf_kk(data, filename=figure_path)
| |
def tf_and_kinase_fusions_plot(dimred_types, output_dir):
    """
    Makes the embeddings, THEN calls the plot, for TF::TF and Kinase::Kinase fusions.

    Builds the dataset with get_tf_and_kinase_fusions_dataset, writes it to
    data/tf_and_kinase_fusions.csv, embeds it with embed_dataset_for_benchmark,
    and for every returned embedding file and every requested reduction method
    saves a plot and its source data.

    Parameters:
    - dimred_types (list of str): Reduction methods to run, e.g. ["umap", "tsne"].
    - output_dir (str): Root directory; results go to
      <output_dir>/<dimred_type>_plots/<embedding sub-path>/tf_and_kinase.

    Raises:
    - Exception: If an embedding file cannot be read (the original error is
      chained as the cause).
    """
    plot_df = get_tf_and_kinase_fusions_dataset()
    plot_df.to_csv("data/tf_and_kinase_fusions.csv", index=False)

    input_fname = 'tf_and_kinase'
    all_embedding_paths = embed_dataset_for_benchmark(
        fuson_ckpts=config.FUSON_PLM_CKPT,
        input_data_path='data/tf_and_kinase_fusions.csv', input_fname=input_fname,
        average=True, seq_col='sequence',
        benchmark_fusonplm=True,
        benchmark_esm=False,
        benchmark_fo_puncta_ml=False,
        overwrite=config.PERMISSION_TO_OVERWRITE)

    log_update("\nEmbedding sequences")

    for embedding_path, _details in all_embedding_paths.items():
        log_update(f"\tBenchmarking embeddings at: {embedding_path}")
        try:
            # NOTE: pickle executes arbitrary code on load; only read embedding
            # files produced by this pipeline.
            with open(embedding_path, "rb") as f:
                embeddings = pickle.load(f)
        except Exception as err:
            # Narrower than a bare `except:` (Ctrl-C still propagates) and
            # chained so the real reason is not lost.
            raise Exception(f"Cannot read embeddings from {embedding_path}") from err

        # The pickle is consumed as a mapping of sequence -> embedding.
        seqs_with_embeddings = pd.DataFrame(
            list(embeddings.items()), columns=['sequence', 'embedding']
        )
        seqs_with_embeddings = pd.merge(seqs_with_embeddings, plot_df, on='sequence', how='inner')

        # Sub-path of the embedding file below 'embeddings/', without the file
        # name; it does not depend on the reduction method.
        intermediate = '/'.join(embedding_path.split('embeddings/')[1].split('/')[0:-1])

        for dimred_type in dimred_types:
            dimred_embeddings = np.asarray(
                get_dimred_embeddings(
                    seqs_with_embeddings['embedding'].tolist(), dimred_type=dimred_type
                )
            )
            # Positional assignment: row i of the reduction belongs to row i of
            # the frame.
            seqs_with_embeddings[f'{dimred_type}1'] = dimred_embeddings[:, 0]
            seqs_with_embeddings[f'{dimred_type}2'] = dimred_embeddings[:, 1]

            cur_output_dir = f"{output_dir}/{dimred_type}_plots/{intermediate}/{input_fname}"
            os.makedirs(cur_output_dir, exist_ok=True)
            make_tf_and_kinase_fusions_plot(
                seqs_with_embeddings, savedir=cur_output_dir, dimred_type=dimred_type
            )
| |
def make_fusion_v_parts_favorites_plot(seqs_with_embeddings, savedir=None, dimred_type='umap',
                                       xlim=(11, 16), ylim=(10, 22)):
    """
    Make plots showing that favorite fusions (e.g. PAX3::FOXO1, EWS::FLI1,
    SS18::SSX1, EML4::ALK) are embedded distinctly from their heads and tails.

    One panel is drawn per fusion listed in data/top_genes.csv (rows whose
    'gene' contains "::"), on a 2x3 grid; each panel shows the fusion together
    with its head and tail gene. Only the first six fusions are drawn; unused
    panels are hidden.

    Parameters:
    - seqs_with_embeddings (pd.DataFrame): Must contain 'sequence',
      '<dimred_type>1' and '<dimred_type>2'. It is merged on 'sequence' with
      data/top_genes.csv, which must provide 'gene'; an 'id' column must be
      present after the merge because it is written to the source-data file.
    - savedir (str or None): Output directory (must already exist). None
      writes into the current working directory.
    - dimred_type (str): 'umap' or 'tsne'.
    - xlim, ylim (tuple or None): (min, max) axis limits shared by all panels.
      The defaults are the values that were previously hard-coded. Pass None
      to derive the limits from the data (rounded outwards to integers).

    Writes:
    - <savedir>/<dimred_type>_favorites_visualization.png
    - <savedir>/<dimred_type>_favorites_source_data.csv

    Raises:
    - KeyError: If dimred_type is not 'umap' or 'tsne', or a required column
      is missing.
    """
    set_font()

    label_transform = {
        'tsne': 't-SNE',
        'umap': 'UMAP'
    }
    axis_label = label_transform[dimred_type]
    x_col, y_col = f'{dimred_type}1', f'{dimred_type}2'
    if savedir is None:
        # f'{None}/...' would point at a directory literally named "None".
        savedir = '.'

    data = pd.read_csv("data/top_genes.csv")
    # pd.merge returns a new frame, so the columns added below do not leak
    # into the caller's DataFrame.
    seqs_with_embeddings = pd.merge(seqs_with_embeddings, data, on="sequence")

    # Fusions are the rows whose gene name is "<head>::<tail>".
    is_fusion = seqs_with_embeddings["gene"].str.contains("::", regex=False, na=False)
    fusions = seqs_with_embeddings.loc[is_fusion, ['gene']].reset_index(drop=True)
    parts = fusions['gene'].str.split("::", expand=True)
    fusions['head'] = parts[0]
    fusions['tail'] = parts[1]

    # Label every row; tails are assigned last, so a gene that is both a head
    # and a tail ends up as 't_embeddings'.
    seqs_with_embeddings["Type"] = ""
    seqs_with_embeddings.loc[is_fusion, "Type"] = "fusion_embeddings"
    seqs_with_embeddings.loc[
        seqs_with_embeddings["gene"].isin(fusions['head']), "Type"
    ] = "h_embeddings"
    seqs_with_embeddings.loc[
        seqs_with_embeddings["gene"].isin(fusions['tail']), "Type"
    ] = "t_embeddings"

    colors = {
        'fusion_embeddings': '#cf9dfa',
        'h_embeddings': '#eb8888',
        't_embeddings': '#5fa3e3',
    }
    markers = {
        'fusion_embeddings': 'o',
        'h_embeddings': '^',
        't_embeddings': 'v'
    }
    label_map = {
        'fusion_embeddings': 'Fusion',
        'h_embeddings': 'Head',
        't_embeddings': 'Tail',
    }

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    axes = axes.flatten()

    # Same limits and integer ticks on every panel so panels are comparable.
    if xlim is None:
        xlim = (np.floor(seqs_with_embeddings[x_col].min()),
                np.ceil(seqs_with_embeddings[x_col].max()))
    if ylim is None:
        ylim = (np.floor(seqs_with_embeddings[y_col].min()),
                np.ceil(seqs_with_embeddings[y_col].max()))
    x_min, x_max = xlim
    y_min, y_max = ylim
    x_ticks = np.arange(x_min, x_max + 1, 1)
    y_ticks = np.arange(y_min, y_max + 1, 1)

    for i, ax in enumerate(axes):
        if i >= len(fusions):
            # Fewer fusions than panels: hide the spare axes instead of
            # failing on a missing row.
            ax.set_visible(False)
            continue

        fgene_name = fusions.loc[i, 'gene']
        hgene = fusions.loc[i, 'head']
        tgene = fusions.loc[i, 'tail']

        panel_data = seqs_with_embeddings[
            seqs_with_embeddings['gene'].isin([fgene_name, hgene, tgene])
        ]

        for emb_type in panel_data['Type'].unique():
            subset = panel_data[panel_data['Type'] == emb_type]
            ax.scatter(
                subset[x_col], subset[y_col],
                label=label_map[emb_type], color=colors[emb_type],
                marker=markers[emb_type], s=120, zorder=3,
            )

        ax.set_title(f'{fgene_name}', fontsize=44)
        ax.set_xlabel(f'{axis_label} 1', fontsize=44)
        ax.set_ylabel(f'{axis_label} 2', fontsize=44)
        # zorder=1 keeps the grid behind the markers (zorder=3).
        ax.grid(True, which='both', linestyle='--', linewidth=0.5, color='gray', zorder=1)

        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        ax.set_xticks(x_ticks)
        ax.set_yticks(y_ticks)

        ax.set_xticklabels(ax.get_xticks(), rotation=45, ha='right')
        ax.tick_params(axis='both', labelsize=24)

        # One legend is enough: all panels share the same three categories.
        if i == 0:
            ax.legend(fontsize=24, markerscale=2, loc='best')

    fig.tight_layout()

    # Save BEFORE show(): with an interactive backend the figure can be gone
    # once the window is closed, which produced a blank image.
    fig.savefig(os.path.join(savedir, f'{dimred_type}_favorites_visualization.png'), dpi=300)
    plt.show()
    plt.close(fig)

    seq_to_id_dict = pd.read_csv("../../data/fuson_db.csv")
    seq_to_id_dict = dict(zip(seq_to_id_dict['aa_seq'], seq_to_id_dict['seq_id']))
    seqs_with_embeddings['seq_id'] = seqs_with_embeddings['sequence'].map(seq_to_id_dict)
    # Export the coordinates of the method that was actually plotted (these
    # columns used to be hard-coded to 'umap1'/'umap2').
    seqs_with_embeddings[[x_col, y_col, 'sequence', 'Type', 'gene', 'id', 'seq_id']].to_csv(
        os.path.join(savedir, f"{dimred_type}_favorites_source_data.csv"), index=False
    )
| |
def fusion_v_parts_favorites(dimred_types, output_dir):
    """
    Makes the embeddings, THEN calls the plot, for the favorite fusions in
    data/top_genes.csv.

    Embeds data/top_genes.csv with embed_dataset_for_benchmark, and for every
    returned embedding file and every requested reduction method calls
    make_fusion_v_parts_favorites_plot.

    Parameters:
    - dimred_types (list of str): Reduction methods to run, e.g. ["umap", "tsne"].
    - output_dir (str): Root directory; results go to
      <output_dir>/<dimred_type>_plots/<embedding sub-path>/favorites.

    Raises:
    - Exception: If an embedding file cannot be read (the original error is
      chained as the cause).
    """
    input_fname = 'favorites'
    all_embedding_paths = embed_dataset_for_benchmark(
        fuson_ckpts=config.FUSON_PLM_CKPT,
        input_data_path='data/top_genes.csv', input_fname=input_fname,
        average=True, seq_col='sequence',
        benchmark_fusonplm=True,
        benchmark_esm=False,
        benchmark_fo_puncta_ml=False,
        overwrite=config.PERMISSION_TO_OVERWRITE)

    log_update("\nEmbedding sequences")

    for embedding_path, _details in all_embedding_paths.items():
        log_update(f"\tBenchmarking embeddings at: {embedding_path}")
        try:
            # NOTE: pickle executes arbitrary code on load; only read embedding
            # files produced by this pipeline.
            with open(embedding_path, "rb") as f:
                embeddings = pickle.load(f)
        except Exception as err:
            # Narrower than a bare `except:` (Ctrl-C still propagates) and
            # chained so the real reason is not lost.
            raise Exception(f"Cannot read embeddings from {embedding_path}") from err

        # The pickle is consumed as a mapping of sequence -> embedding.
        seqs_with_embeddings = pd.DataFrame(
            list(embeddings.items()), columns=['sequence', 'embedding']
        )

        # Sub-path of the embedding file below 'embeddings/', without the file
        # name; it does not depend on the reduction method.
        intermediate = '/'.join(embedding_path.split('embeddings/')[1].split('/')[0:-1])

        for dimred_type in dimred_types:
            dimred_embeddings = np.asarray(
                get_dimred_embeddings(
                    seqs_with_embeddings['embedding'].tolist(), dimred_type=dimred_type
                )
            )
            # Positional assignment: row i of the reduction belongs to row i of
            # the frame.
            seqs_with_embeddings[f'{dimred_type}1'] = dimred_embeddings[:, 0]
            seqs_with_embeddings[f'{dimred_type}2'] = dimred_embeddings[:, 1]

            cur_output_dir = f"{output_dir}/{dimred_type}_plots/{intermediate}/{input_fname}"
            os.makedirs(cur_output_dir, exist_ok=True)
            make_fusion_v_parts_favorites_plot(
                seqs_with_embeddings, savedir=cur_output_dir, dimred_type=dimred_type
            )
| |
def main():
    """
    Entry point: creates a timestamped results directory and runs both analyses.

    The reduction methods to run are selected by config.PLOT_UMAP and
    config.PLOT_TSNE; a '<method>_plots' sub-directory is created for each
    enabled method. The config is printed and both plotting routines run while
    the log file is open.
    """
    os.makedirs('results', exist_ok=True)
    output_dir = f'results/{get_local_time()}'
    os.makedirs(output_dir, exist_ok=True)

    # (config flag, method name) in the order the methods are run.
    dimred_types = []
    for flag_name, method in (("PLOT_UMAP", "umap"), ("PLOT_TSNE", "tsne")):
        if getattr(config, flag_name):
            dimred_types.append(method)
            os.makedirs(f"{output_dir}/{method}_plots", exist_ok=True)

    log_path = f'{output_dir}/embedding_exploration_log.txt'
    with open_logfile(log_path):
        print_configpy(config)

        # Fusion vs. head/tail panels for the favorite fusions.
        fusion_v_parts_favorites(dimred_types, output_dir)

        # TF::TF vs. Kinase::Kinase scatterplot.
        tf_and_kinase_fusions_plot(dimred_types, output_dir)


if __name__ == "__main__":
    main()