import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, roc_auc_score
from scipy import stats
from scipy.stats import ttest_rel
import pandas as pd


def plot_conf_matrix_mlm_vs_nomlm(cms_mlm, cms_nomlm, m_type, only_agg=True,
                                  suptitle="Confusion Matrix Comparison"):
    """Plot confusion matrices for the MLM-pretrained vs. non-pretrained model.

    Args:
        cms_mlm: per-fold confusion matrices for the MLM model; the last
            element is assumed to be the aggregated matrix.
        cms_nomlm: same structure for the model without MLM pre-training.
        m_type: tag embedded in the output PDF filename.
        only_agg: if True, plot only the aggregated (last) matrices;
            otherwise plot one row of matrices per fold.
        suptitle: overall figure title.
    """
    labels = ['Dead-end', 'Reprogramming']
    if only_agg:
        # Plot only the aggregated confusion matrices (last one in each list)
        cms_mlm_agg = cms_mlm[-1]
        cms_nomlm_agg = cms_nomlm[-1]
        f = plt.figure(figsize=(12, 5))
        plt.suptitle(suptitle, fontsize=16)

        # Aggregated MLM matrix (left)
        plt.subplot(1, 2, 1)
        sns.heatmap(cms_mlm_agg, annot=True, cmap='Blues', fmt='g',
                    xticklabels=labels, yticklabels=labels)
        plt.xlabel('Predicted')
        plt.ylabel('Actual')
        plt.title('Confusion Matrix - MLM (Aggregated)')

        # Aggregated No-MLM matrix (right)
        plt.subplot(1, 2, 2)
        sns.heatmap(cms_nomlm_agg, annot=True, cmap='Blues', fmt='g',
                    xticklabels=labels, yticklabels=labels)
        plt.xlabel('Predicted')
        plt.ylabel('Actual')
        plt.title('Confusion Matrix - No MLM (Aggregated)')

        # Apply the layout BEFORE saving so the exported PDF matches the
        # on-screen figure (previously savefig ran before tight_layout).
        plt.tight_layout()
        f.savefig(f'./figures/confusion_matrices_{m_type}.pdf', bbox_inches='tight')
        plt.show()
    else:
        # Plot confusion matrices for each fold, one fold per row.
        n_folds = len(cms_mlm)
        # Figure height scales with the number of folds.
        f = plt.figure(figsize=(15, 2 * n_folds))
        plt.suptitle(suptitle, fontsize=16)
        for i in range(n_folds):
            # First column: MLM
            plt.subplot(n_folds, 2, i * 2 + 1)
            sns.heatmap(cms_mlm[i], annot=True, cmap='Blues', fmt='g',
                        xticklabels=labels, yticklabels=labels)
            plt.xlabel('Predicted')
            plt.ylabel('Actual')
            plt.title(f'Confusion Matrix - MLM (Fold {i+1})')

            # Second column: No MLM
            plt.subplot(n_folds, 2, i * 2 + 2)
            sns.heatmap(cms_nomlm[i], annot=True, cmap='Blues', fmt='g',
                        xticklabels=labels, yticklabels=labels)
            plt.xlabel('Predicted')
            plt.ylabel('Actual')
            plt.title(f'Confusion Matrix - No MLM (Fold {i+1})')

        # Layout first (rect leaves room for the suptitle), then save.
        plt.tight_layout(rect=[0, 0, 1, 0.96])
        f.savefig(f'./figures/confusion_matrices_folds_{m_type}.pdf', bbox_inches='tight')
        plt.show()


def plot_training_vs_validation_losses(train_losses, val_losses, title="Losses"):
    """Plot train and validation loss curves side by side.

    Args:
        train_losses: per-epoch training losses.
        val_losses: per-epoch validation losses (same length as train_losses).
        title: figure suptitle.
    """
    epochs = len(train_losses)
    f = plt.figure(figsize=(10, 3))
    plt.suptitle(title)

    plt.subplot(1, 2, 1)
    plt.plot(range(1, epochs + 1), train_losses)
    plt.xlabel('Epoch')
    plt.ylabel('Train Loss')
    plt.title('Train Loss')

    plt.subplot(1, 2, 2)
    plt.plot(range(1, epochs + 1), val_losses)
    plt.xlabel('Epoch')
    plt.ylabel('Validation Loss')
    plt.title('Validation Loss')

    # Layout before saving so the PDF reflects the adjusted spacing.
    plt.tight_layout()
    f.savefig('./figures/losses.pdf', bbox_inches='tight')
    plt.show()


def plot_roc_auc_curve(val_preds, val_labels, m_type, aggregate=False):
    """Plot ROC curves, either aggregated across folds or one per fold.

    Args:
        val_preds: list (per fold) of predicted scores.
        val_labels: list (per fold) of true binary labels.
        m_type: tag embedded in the output PDF filename.
        aggregate: if True, pool all folds into a single ROC curve.
    """
    if aggregate:
        # Pool predictions/labels from all folds into one flat array each.
        all_labels = np.concatenate(val_labels).ravel()
        all_preds = np.concatenate(val_preds).ravel()
        auc = roc_auc_score(all_labels, all_preds)
        fpr, tpr, _ = roc_curve(all_labels, all_preds)
        f = plt.figure()
        plt.plot(fpr, tpr, label=f'Aggregated AUC: {auc:.4f}')
        plt.plot([0, 1], [0, 1], linestyle='--')  # chance-level reference
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('ROC Curve (Aggregated)')
        plt.legend()
        f.savefig(f'./figures/roc_curve_{m_type}.pdf', bbox_inches='tight')
        plt.show()
    else:
        # One ROC curve per fold on a shared axis.
        f = plt.figure()
        for i, (labels, preds) in enumerate(zip(val_labels, val_preds), 1):
            auc = roc_auc_score(labels, preds)
            fpr, tpr, _ = roc_curve(labels, preds)
            plt.plot(fpr, tpr, label=f'Fold {i} AUC: {auc:.4f}')
        # Draw the chance-level diagonal once, not once per fold
        # (it was previously redrawn inside the loop).
        plt.plot([0, 1], [0, 1], linestyle='--')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('ROC Curve (Each Fold)')
        plt.legend()
        f.savefig(f'./figures/roc_curve_{m_type}.pdf', bbox_inches='tight')
        plt.show()


def plot_auc_boxplot_comparison(fold_results1, fold_results2, title="AUC Comparison"):
    """Plot AUC box comparison between two models.

    Args:
        fold_results1: per-fold result dicts (with 'train_auc' and
            'best_val_auc' keys) for the model with MLM pre-training.
        fold_results2: same structure for the model without MLM pre-training.
        title: figure suptitle.
    """
    train_auc_scores_mlm = [fold['train_auc'] for fold in fold_results1]
    train_auc_scores_nomlm = [fold['train_auc'] for fold in fold_results2]
    val_auc_scores_mlm = [fold['best_val_auc'] for fold in fold_results1]
    val_auc_scores_nomlm = [fold['best_val_auc'] for fold in fold_results2]

    # Paired t-test: folds are matched between the two models.
    train_p_value = ttest_rel(train_auc_scores_mlm, train_auc_scores_nomlm).pvalue
    val_p_value = ttest_rel(val_auc_scores_mlm, val_auc_scores_nomlm).pvalue

    df_train = pd.DataFrame({
        'Fold': [f'Fold {i+1}' for i in range(len(val_auc_scores_mlm))],
        'with MLM': train_auc_scores_mlm,
        'without MLM': train_auc_scores_nomlm,
    })
    df_valid = pd.DataFrame({
        'Fold': [f'Fold {i+1}' for i in range(len(val_auc_scores_mlm))],
        'with MLM': val_auc_scores_mlm,
        'without MLM': val_auc_scores_nomlm
    })

    f = plt.figure(figsize=(12, 8))
    plt.suptitle(title)

    plt.subplot(1, 2, 1)
    sns.boxplot(data=df_train, palette=["#1f77b4", "#ff7f0e"])  # Custom colors
    plt.title(f'Train AUC Comparison (p-value = {train_p_value:.4f})')
    plt.ylabel('AUC')
    plt.ylim(0.5, 1)

    plt.subplot(1, 2, 2)
    sns.boxplot(data=df_valid, palette=["#2ca02c", "#d62728"])  # Custom colors
    plt.title(f'Validation AUC Comparison (p-value = {val_p_value:.4f})')
    plt.ylabel('AUC')
    plt.ylim(0.5, 1)

    # Layout before saving so the PDF reflects the adjusted spacing.
    plt.tight_layout()
    f.savefig('./figures/auc_comparison.pdf', bbox_inches='tight')
    plt.show()


def plot_loss_comparison_mlm_vs_nomlm(fold_results1, fold_results2, title="Loss Comparison"):
    """Plot loss comparison between two models.

    Overlays train/validation loss curves of every fold for both the
    MLM-pretrained model (solid markers) and the model without
    pre-training (dashed markers).

    Args:
        fold_results1: per-fold result dicts (with 'fold' and
            'metrics' -> 'train_loss'/'val_loss') for the MLM model.
        fold_results2: same structure for the model without MLM,
            aligned fold-by-fold with fold_results1.
        title: plot title.
    """
    f = plt.figure(figsize=(12, 8))
    for i, fold in enumerate(fold_results1):
        train_losses_mlm = fold['metrics']['train_loss']
        val_losses_mlm = fold['metrics']['val_loss']
        train_losses_nomlm = fold_results2[i]['metrics']['train_loss']
        val_losses_nomlm = fold_results2[i]['metrics']['val_loss']
        epochs = range(1, len(train_losses_mlm) + 1)
        plt.plot(epochs, train_losses_mlm, 'o-',
                 label=f'Train Loss w/ Pre-Training - Fold {fold["fold"]}', alpha=0.5)
        plt.plot(epochs, val_losses_mlm, 'x-',
                 label=f'Validation Loss w/ Pre-Training - Fold {fold["fold"]}', alpha=0.5)
        plt.plot(epochs, train_losses_nomlm, 'o--',
                 label=f'Train Loss w/o Pre-Training - Fold {fold["fold"]}', alpha=0.5)
        plt.plot(epochs, val_losses_nomlm, 'x--',
                 label=f'Validation Loss w/o Pre-Training - Fold {fold["fold"]}', alpha=0.5)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title(title)
    # Legend anchored outside the axes to the right.
    plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1))
    f.savefig('./figures/loss_comparison.pdf', bbox_inches='tight')
    plt.show()


def plot_fold_losses(fold_results, title="Losses"):
    """Plot loss for each fold.

    Args:
        fold_results: per-fold result dicts (with 'fold' and
            'metrics' -> 'train_loss'/'val_loss').
        title: plot title.
    """
    f = plt.figure(figsize=(12, 8))
    for i, fold in enumerate(fold_results):
        train_losses = fold['metrics']['train_loss']
        val_losses = fold['metrics']['val_loss']
        epochs = range(1, len(train_losses) + 1)
        plt.plot(epochs, train_losses, 'o-',
                 label=f'Train Loss - Fold {fold["fold"]}', alpha=0.5)
        plt.plot(epochs, val_losses, 'x-',
                 label=f'Validation Loss - Fold {fold["fold"]}', alpha=0.5)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title(title)
    # Legend anchored outside the axes to the right.
    plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1))
    f.savefig('./figures/fold_losses.pdf', bbox_inches='tight')
    plt.show()


def plot_data_distribution(adata_RNA, adata_ATAC, adata_Flux, title="Data Distribution"):
    """Plot value histograms for the RNA, ATAC, and fluxomic modalities.

    Args:
        adata_RNA: AnnData-like object with a sparse .X matrix (RNA counts).
        adata_ATAC: AnnData-like object with a sparse .X matrix (ATAC).
        adata_Flux: DataFrame-like object with a .values array (flux values).
        title: figure suptitle.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    plt.suptitle(title)

    # RNA: densify the sparse matrix and flatten to a 1-D value vector.
    data = adata_RNA.X.toarray().flatten()
    sns.histplot(data, bins=100, ax=axes[0], color='skyblue')
    var, mean = data.var(), data.mean()
    axes[0].set_title(f'RNA Distribution, var:{var:.2f}, mean:{mean:.2f}')
    axes[0].set_xlabel('Expression level')
    axes[0].set_ylabel('Frequency')

    data = adata_ATAC.X.toarray().flatten()
    sns.histplot(data, bins=100, ax=axes[1], color='lightgreen')
    var, mean = data.var(), data.mean()
    axes[1].set_title(f'ATAC Distribution, var:{var:.3f}, mean:{mean:.2f}')
    axes[1].set_xlabel('Accessibility level')
    axes[1].set_ylabel('Frequency')

    data = adata_Flux.values.flatten()
    sns.histplot(data, bins=100, ax=axes[2], color='salmon')
    var, mean = data.var(), data.mean()
    axes[2].set_title(f'Fluxomic Distribution, var:{var:.5f}, mean:{mean:.2f}')
    axes[2].set_xlabel('Flux value')
    axes[2].set_ylabel('Frequency')

    # Layout before saving so the PDF reflects the adjusted spacing.
    plt.tight_layout()
    fig.savefig('./figures/data_distribution.pdf', bbox_inches='tight')
    plt.show()


def plot_att_weights(all_attention, dead_end_attention, reprogramming_attention,
                     feature_names=None, print_top_features=False, top_n=5,
                     scale_weights=False, fix_scale=False, use_mean_contribution=False):
    """Plot average attention weights for all / dead-end / reprogramming samples.

    The feature axis is split into three modalities at hard-coded boundaries
    (RNA: [0, 945), ATAC: [945, 945+884), Flux: the rest), and per-modality
    sum/mean annotations are added below each heatmap.

    Args:
        all_attention: (samples, features) attention weights, all samples.
        dead_end_attention: same, dead-end samples only.
        reprogramming_attention: same, reprogramming samples only.
        feature_names: feature labels, required if print_top_features is True.
        print_top_features: if True, print the top_n highest-attention features.
        top_n: number of top features to print.
        scale_weights: if True, min-max scale each averaged vector to [0, 1].
        fix_scale: if True (and not scale_weights), share the color scale of
            the all-samples heatmap across all three panels.
        use_mean_contribution: metric used for the per-modality annotation;
            see _add_modality_labels for accepted values.

    Returns:
        The matplotlib Figure.
    """
    print(all_attention.shape, "all_attention.shape")
    print(dead_end_attention.shape, "dead_end_attention.shape")
    print(reprogramming_attention.shape, "reprogramming_attention.shape")

    def minmax_scale(arr):
        """Min-max scale arr to [0, 1]; return zeros for a constant array."""
        arr = np.asarray(arr)
        min_val = arr.min()
        max_val = arr.max()
        if max_val - min_val == 0:
            return np.zeros_like(arr)  # avoid divide by zero
        return (arr - min_val) / (max_val - min_val)

    # Average attention weights across samples.
    avg_all_attention = all_attention.mean(axis=0)
    avg_dead_end_attention = dead_end_attention.mean(axis=0)
    avg_reprogramming_attention = reprogramming_attention.mean(axis=0)

    # Keep the unscaled versions for the modality-contribution annotations,
    # so scaling only affects the heatmap colors.
    avg_all_attention_orig = avg_all_attention.copy() if hasattr(avg_all_attention, 'copy') else np.array(avg_all_attention)
    avg_dead_end_attention_orig = avg_dead_end_attention.copy() if hasattr(avg_dead_end_attention, 'copy') else np.array(avg_dead_end_attention)
    avg_reprogramming_attention_orig = avg_reprogramming_attention.copy() if hasattr(avg_reprogramming_attention, 'copy') else np.array(avg_reprogramming_attention)

    if scale_weights:
        avg_all_attention = minmax_scale(avg_all_attention)
        avg_dead_end_attention = minmax_scale(avg_dead_end_attention)
        avg_reprogramming_attention = minmax_scale(avg_reprogramming_attention)
        vmin, vmax = 0.0, 1.0
    elif fix_scale:
        # Fix the color scale of all three heatmaps to the same range.
        vmin, vmax = avg_all_attention.min(), avg_all_attention.max()
    else:
        vmin, vmax = None, None

    # Visualize average attention weights.
    f = plt.figure(figsize=(15, 3))
    # Modality boundaries on the feature axis (RNA | ATAC | Flux).
    divider1 = 945
    divider2 = 945 + 884

    def add_modality_labels(ax, attention_weights, attention_weights_orig, use_mean=False):
        """Annotate ax with a per-modality summary metric.

        use_mean selects the metric: False/'sum' (sum), True/'mean' (mean),
        'median', 'trimmed_mean' (15% trimmed from each tail), or
        'active_mean' (mean over features above the 25th percentile).
        """
        rna_weights = attention_weights_orig[:divider1]
        atac_weights = attention_weights_orig[divider1:divider2]
        flux_weights = attention_weights_orig[divider2:]

        if use_mean is False or use_mean == 'sum':
            # Sum of all attention weights (original behavior).
            rna_metric = rna_weights.sum()
            atac_metric = atac_weights.sum()
            flux_metric = flux_weights.sum()
        elif use_mean is True or use_mean == 'mean':
            # Mean attention per feature.
            rna_metric = rna_weights.mean()
            atac_metric = atac_weights.mean()
            flux_metric = flux_weights.mean()
        elif use_mean == 'median':
            # Median attention per feature (robust to zeros and outliers).
            rna_metric = np.median(rna_weights)
            atac_metric = np.median(atac_weights)
            flux_metric = np.median(flux_weights)
        elif use_mean == 'trimmed_mean':
            # Trimmed mean: removes 15% from each tail.
            rna_metric = stats.trim_mean(rna_weights, proportiontocut=0.15)
            atac_metric = stats.trim_mean(atac_weights, proportiontocut=0.15)
            flux_metric = stats.trim_mean(flux_weights, proportiontocut=0.15)
        elif use_mean == 'active_mean':
            # Mean of only "active" features (attention > threshold);
            # the bottom 25% of all weights are considered inactive.
            threshold = np.percentile(attention_weights_orig, 25)
            rna_active = rna_weights[rna_weights > threshold]
            atac_active = atac_weights[atac_weights > threshold]
            flux_active = flux_weights[flux_weights > threshold]
            rna_metric = rna_active.mean() if len(rna_active) > 0 else 0
            atac_metric = atac_active.mean() if len(atac_active) > 0 else 0
            flux_metric = flux_active.mean() if len(flux_active) > 0 else 0
        else:
            raise ValueError(f"Invalid use_mean value: {use_mean}")

        # Center x-position of each modality segment for text placement.
        n_rna = divider1
        n_atac = divider2 - divider1
        n_flux = len(attention_weights) - divider2
        rna_center = n_rna / 2
        atac_center = divider1 + n_atac / 2
        flux_center = divider2 + n_flux / 2

        # NOTE(review): these divide the selected metric by the feature count,
        # so the "Mean" label is only a true mean when use_mean selects 'sum'.
        rna_metric_mean = rna_metric / n_rna
        atac_metric_mean = atac_metric / n_atac
        flux_metric_mean = flux_metric / n_flux

        ax.text(rna_center, -0.3, f'Sum: {rna_metric:.3f}\nMean: {rna_metric_mean:.3f}',
                ha='center', va='bottom', fontsize=10, fontweight='bold')
        ax.text(atac_center, -0.3, f'Sum: {atac_metric:.3f}\nMean: {atac_metric_mean:.3f}',
                ha='center', va='bottom', fontsize=10, fontweight='bold')
        ax.text(flux_center, -0.3, f'Sum: {flux_metric:.3f}\nMean: {flux_metric_mean:.3f}',
                ha='center', va='bottom', fontsize=10, fontweight='bold')

    plt.subplot(1, 3, 1)
    ax1 = plt.gca()
    sns.heatmap(avg_all_attention.reshape(1, -1), cmap='viridis',
                yticklabels=['All'], vmin=vmin, vmax=vmax, ax=ax1)
    plt.title('Avg Att. W. (All Samples)')
    plt.xlabel('Features')
    plt.xticks([])
    plt.axvline(x=divider1, color='red', linestyle='--', linewidth=2)
    plt.axvline(x=divider2, color='red', linestyle='--', linewidth=2)
    add_modality_labels(ax1, avg_all_attention, avg_all_attention_orig,
                        use_mean=use_mean_contribution)

    plt.subplot(1, 3, 2)
    ax2 = plt.gca()
    sns.heatmap(avg_dead_end_attention.reshape(1, -1), cmap='viridis',
                yticklabels=['Dead-end'], vmin=vmin, vmax=vmax, ax=ax2)
    plt.title('Avg Att. W. (Dead-end Samples)')
    plt.xlabel('Features')
    plt.xticks([])
    plt.axvline(x=divider1, color='red', linestyle='--', linewidth=2)
    plt.axvline(x=divider2, color='red', linestyle='--', linewidth=2)
    add_modality_labels(ax2, avg_dead_end_attention, avg_dead_end_attention_orig,
                        use_mean=use_mean_contribution)

    plt.subplot(1, 3, 3)
    ax3 = plt.gca()
    sns.heatmap(avg_reprogramming_attention.reshape(1, -1), cmap='viridis',
                yticklabels=['Reprogramming'], vmin=vmin, vmax=vmax, ax=ax3)
    plt.title('Avg Att. W. (Reprogramming Samples)')
    plt.xlabel('Features')
    plt.xticks([])
    plt.axvline(x=divider1, color='red', linestyle='--', linewidth=2)
    plt.axvline(x=divider2, color='red', linestyle='--', linewidth=2)
    add_modality_labels(ax3, avg_reprogramming_attention, avg_reprogramming_attention_orig,
                        use_mean=use_mean_contribution)

    # f.savefig('./figures/attention_weights.pdf', bbox_inches='tight')
    plt.tight_layout()
    plt.show()

    if print_top_features:
        def get_top_features(attention_weights, feature_names, top_n=top_n):
            """Return the top_n (name, weight) pairs by mean attention."""
            # Convert torch tensors to numpy when a .numpy() method exists.
            avg_attention = attention_weights.mean(axis=0).numpy() if hasattr(attention_weights, 'numpy') else attention_weights.mean(axis=0)
            print(avg_attention.shape, len(feature_names))
            top_indices = avg_attention.argsort()[-top_n:][::-1]
            print(top_indices)
            return [(feature_names[i], avg_attention[i]) for i in top_indices]

        top_all = get_top_features(all_attention, feature_names)
        top_dead_end = get_top_features(dead_end_attention, feature_names)
        top_reprogramming = get_top_features(reprogramming_attention, feature_names)

        print(f"Top {top_n} attended features (All samples):")
        for feature, weight in top_all:
            print(f"{feature}: {weight:.4f}", end=", ")
        print(f"\nTop {top_n} attended features (Dead-end samples):")
        for feature, weight in top_dead_end:
            print(f"{feature}: {weight:.4f}", end=", ")
        print(f"\nTop {top_n} attended features (Reprogramming samples):")
        for feature, weight in top_reprogramming:
            print(f"{feature}: {weight:.4f}", end=", ")

    return f


def plot_att_weights_distribution(
    all_attention, dead_end_attention, reprogramming_attention,
    feature_names=None, plot_type='violin', top_n=5, print_means=False
):
    """Plot the per-modality distribution of attention weights.

    One violin/box panel per condition (All / Dead-end / Reprogramming),
    with RNA, ATAC, and Flux features split at the same hard-coded
    boundaries used by plot_att_weights.

    Args:
        all_attention: (samples, features) attention weights, all samples.
        dead_end_attention: same, dead-end samples only.
        reprogramming_attention: same, reprogramming samples only.
        feature_names: unused here; kept for signature compatibility.
        plot_type: 'violin' or 'box'.
        top_n: unused here; kept for signature compatibility.
        print_means: if True, print per-modality mean/std values.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if plot_type is not 'violin' or 'box'.
    """
    # Modality boundaries on the feature axis (RNA ends / ATAC ends, Flux begins).
    divider1 = 945
    divider2 = 945 + 884

    def prepare_modality_data(attention_weights, condition_name):
        """Extract attention weights by modality"""
        rna_weights = attention_weights[:, :divider1].flatten()
        atac_weights = attention_weights[:, divider1:divider2].flatten()
        flux_weights = attention_weights[:, divider2:].flatten()
        return {
            'RNA': rna_weights,
            'ATAC': atac_weights,
            'Flux': flux_weights,
            'condition': condition_name,
        }

    all_data = prepare_modality_data(all_attention, 'All')
    de_data = prepare_modality_data(dead_end_attention, 'Dead-end')
    re_data = prepare_modality_data(reprogramming_attention, 'Reprogramming')

    if plot_type in ['violin', 'box']:
        # Build a long-format DataFrame for seaborn.
        data_list = []
        for condition_data in [all_data, de_data, re_data]:
            condition = condition_data['condition']
            for modality in ['RNA', 'ATAC', 'Flux']:
                weights = condition_data[modality]
                for weight in weights:
                    data_list.append({
                        'Condition': condition,
                        'Modality': modality,
                        'Attention Weight': weight
                    })
        df = pd.DataFrame(data_list)

        # One subplot per condition.
        f, axes = plt.subplots(1, 3, figsize=(18, 5))
        conditions = ['All', 'Dead-end', 'Reprogramming']
        colors = ['#1f77b4', '#ff7f0e', '#2ca02c']  # RNA, ATAC, Flux colors

        if print_means:
            print("Mean attention weight values per modality and per condition:")

        for idx, (ax, condition) in enumerate(zip(axes, conditions)):
            condition_df = df[df['Condition'] == condition]
            if plot_type == 'violin':
                sns.violinplot(data=condition_df, x='Modality', y='Attention Weight',
                               palette=colors, ax=ax)
            else:  # box
                sns.boxplot(data=condition_df, x='Modality', y='Attention Weight',
                            palette=colors, ax=ax)
            ax.set_title(f'{condition} Samples', fontsize=12, fontweight='bold')
            ax.set_xlabel('Modality', fontsize=11)
            ax.set_ylabel('Attention Weight', fontsize=11)
            ax.grid(axis='y', alpha=0.3)
            # Overlay a dashed red mean line on each modality's column.
            for i, modality in enumerate(['RNA', 'ATAC', 'Flux']):
                mod_data = condition_df[condition_df['Modality'] == modality]['Attention Weight']
                mean_val = mod_data.mean()
                std_val = mod_data.std()
                ax.hlines(mean_val, i - 0.4, i + 0.4, colors='red', linestyles='--',
                          linewidth=2, alpha=0.7, label='Mean' if i == 0 else '')
                if print_means:
                    print(f"{condition} - {modality}: mean={mean_val:.8f}, std={std_val:.8f}")
            if idx == 0:
                ax.legend()
    else:
        # Only 'violin' and 'box' are implemented (the old message also
        # advertised an unimplemented 'hist' option).
        raise ValueError(f"plot_type must be 'violin' or 'box', got '{plot_type}'")

    plt.tight_layout()
    plt.show()
    return f


def plot_att_heads(all_attention_heads, dead_end_attention_heads,
                   reprogramming_attention_heads, stacked=False):
    """Plot per-head attention weights for the three sample groups.

    Args:
        all_attention_heads: (samples, heads, features) attention weights.
        dead_end_attention_heads: same, dead-end samples only.
        reprogramming_attention_heads: same, reprogramming samples only.
        stacked: if True, stack all heads in one heatmap per group;
            otherwise plot one heatmap row per head.
    """
    # The second dimension is assumed to be the number of heads.
    n_heads = all_attention_heads.shape[1]

    if stacked:
        # One heatmap per group with all heads stacked as rows.
        f = plt.figure(figsize=(15, 10))

        plt.subplot(1, 3, 1)
        stacked_all_attention = all_attention_heads.mean(axis=0).reshape(n_heads, -1)
        sns.heatmap(stacked_all_attention, cmap='viridis',
                    yticklabels=[f'Head {i+1}' for i in range(n_heads)])
        plt.title('Stacked Attention Weights (All Samples)')
        plt.xlabel('Features')
        plt.ylabel('Heads')
        plt.xticks(rotation=90)

        plt.subplot(1, 3, 2)
        stacked_dead_end_attention = dead_end_attention_heads.mean(axis=0).reshape(n_heads, -1)
        sns.heatmap(stacked_dead_end_attention, cmap='viridis',
                    yticklabels=[f'Head {i+1}' for i in range(n_heads)])
        plt.title('Stacked Attention Weights (Dead-end Samples)')
        plt.xlabel('Features')
        plt.ylabel('Heads')
        plt.xticks(rotation=90)

        plt.subplot(1, 3, 3)
        stacked_reprogramming_attention = reprogramming_attention_heads.mean(axis=0).reshape(n_heads, -1)
        sns.heatmap(stacked_reprogramming_attention, cmap='viridis',
                    yticklabels=[f'Head {i+1}' for i in range(n_heads)])
        plt.title('Stacked Attention Weights (Reprogramming Samples)')
        plt.xlabel('Features')
        plt.ylabel('Heads')
        plt.xticks(rotation=90)

        # Layout before saving so the PDF reflects the adjusted spacing.
        plt.tight_layout()
        f.savefig('./figures/attention_heads_stacked.pdf', bbox_inches='tight')
        plt.show()
    else:
        # One row of three heatmaps per head.
        f = plt.figure(figsize=(15, 15))
        for head in range(n_heads):
            # Column 1: all samples.
            plt.subplot(n_heads, 3, 3 * head + 1)
            sns.heatmap(all_attention_heads[:, head, :].mean(axis=0).reshape(1, -1),
                        cmap='viridis', yticklabels=[f'Head {head+1}'])
            plt.title(f'Head {head+1} Attention (All Samples)')
            plt.xlabel('Features')
            plt.xticks(rotation=90)

            # Column 2: dead-end samples.
            plt.subplot(n_heads, 3, 3 * head + 2)
            sns.heatmap(dead_end_attention_heads[:, head, :].mean(axis=0).reshape(1, -1),
                        cmap='viridis', yticklabels=[f'Head {head+1}'])
            plt.title(f'Head {head+1} Attention (Dead-end Samples)')
            plt.xlabel('Features')
            plt.xticks(rotation=90)

            # Column 3: reprogramming samples.
            plt.subplot(n_heads, 3, 3 * head + 3)
            sns.heatmap(reprogramming_attention_heads[:, head, :].mean(axis=0).reshape(1, -1),
                        cmap='viridis', yticklabels=[f'Head {head+1}'])
            plt.title(f'Head {head+1} Attention (Reprogramming Samples)')
            plt.xlabel('Features')
            plt.xticks(rotation=90)

        # Layout before saving so the PDF reflects the adjusted spacing.
        plt.tight_layout()
        f.savefig('./figures/attention_heads.pdf', bbox_inches='tight')
        plt.show()