"""Visualization utilities for FateFormerExplorer (interpretation/visualization.py).

Upstream page header (author: kaveh, commit message "init", ef814bf) converted
into a module docstring so the file parses as Python.
"""
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, roc_auc_score
from scipy import stats
from scipy.stats import ttest_rel
import pandas as pd
def plot_conf_matrix_mlm_vs_nomlm(cms_mlm, cms_nomlm, m_type, only_agg=True, suptitle="Confusion Matrix Comparison"):
    """Plot MLM vs. no-MLM confusion matrices side by side.

    Parameters
    ----------
    cms_mlm, cms_nomlm : list
        Per-fold confusion matrices; the last entry is assumed to be the
        aggregate over all folds.
    m_type : str
        Model-type tag embedded in the output file name.
    only_agg : bool
        If True, plot only the aggregated matrices; otherwise one row per fold.
    suptitle : str
        Figure-level title.
    """
    labels = ['Dead-end', 'Reprogramming']

    def _draw(cm, title):
        # Shared styling for a single confusion-matrix panel.
        sns.heatmap(cm, annot=True, cmap='Blues', fmt='g',
                    xticklabels=labels, yticklabels=labels)
        plt.xlabel('Predicted')
        plt.ylabel('Actual')
        plt.title(title)

    if only_agg:
        # Aggregated confusion matrices are stored last in each list.
        f = plt.figure(figsize=(12, 5))
        plt.suptitle(suptitle, fontsize=16)
        plt.subplot(1, 2, 1)
        _draw(cms_mlm[-1], 'Confusion Matrix - MLM (Aggregated)')
        plt.subplot(1, 2, 2)
        _draw(cms_nomlm[-1], 'Confusion Matrix - No MLM (Aggregated)')
        # Apply the layout BEFORE saving so the PDF reflects the final
        # geometry (the original saved first and missed the layout pass).
        plt.tight_layout()
        f.savefig(f'./figures/confusion_matrices_{m_type}.pdf', bbox_inches='tight')
        plt.show()
    else:
        # One row per fold: MLM in the first column, no-MLM in the second.
        n_folds = len(cms_mlm)
        f = plt.figure(figsize=(15, 2 * n_folds))  # scale height with fold count
        plt.suptitle(suptitle, fontsize=16)
        for i in range(n_folds):
            plt.subplot(n_folds, 2, i * 2 + 1)  # first column (MLM)
            _draw(cms_mlm[i], f'Confusion Matrix - MLM (Fold {i+1})')
            plt.subplot(n_folds, 2, i * 2 + 2)  # second column (No MLM)
            _draw(cms_nomlm[i], f'Confusion Matrix - No MLM (Fold {i+1})')
        # Leave headroom for the suptitle, then save the laid-out figure.
        plt.tight_layout(rect=[0, 0, 1, 0.96])
        f.savefig(f'./figures/confusion_matrices_folds_{m_type}.pdf', bbox_inches='tight')
        plt.show()
def plot_training_vs_validation_losses(train_losses, val_losses, title="Losses"):
    """Plot training and validation loss curves in two side-by-side panels.

    Parameters
    ----------
    train_losses, val_losses : sequence of float
        Per-epoch loss values; assumed to have the same length.
    title : str
        Figure-level title.
    """
    epochs = len(train_losses)
    f = plt.figure(figsize=(10, 3))
    plt.suptitle(title)
    plt.subplot(1, 2, 1)
    plt.plot(range(1, epochs + 1), train_losses)
    plt.xlabel('Epoch')
    plt.ylabel('Train Loss')
    plt.title('Train Loss')
    plt.subplot(1, 2, 2)
    plt.plot(range(1, epochs + 1), val_losses)
    plt.xlabel('Epoch')
    plt.ylabel('Validation Loss')
    plt.title('Validation Loss')
    # Apply the layout BEFORE saving so the PDF reflects the final geometry
    # (the original saved first and so missed the tight_layout pass).
    plt.tight_layout()
    f.savefig('./figures/losses.pdf', bbox_inches='tight')
    plt.show()
def plot_roc_auc_curve(val_preds, val_labels, m_type, aggregate=False):
    """Plot ROC curves, either pooled over all folds or one curve per fold.

    Parameters
    ----------
    val_preds, val_labels : list of array-like
        Per-fold predicted scores and ground-truth binary labels.
    m_type : str
        Model-type tag embedded in the output file name.
    aggregate : bool
        If True, concatenate all folds into a single ROC curve.
    """
    f = plt.figure()
    if aggregate:
        # Pool every fold into one flat prediction/label vector.
        all_labels = np.concatenate(val_labels).ravel()
        all_preds = np.concatenate(val_preds).ravel()
        auc = roc_auc_score(all_labels, all_preds)
        fpr, tpr, _ = roc_curve(all_labels, all_preds)
        plt.plot(fpr, tpr, label=f'Aggregated AUC: {auc:.4f}')
        plt.title('ROC Curve (Aggregated)')
        out_path = f'./figures/roc_curve_{m_type}.pdf'
    else:
        # One ROC curve per fold.
        for i, (labels, preds) in enumerate(zip(val_labels, val_preds), 1):
            auc = roc_auc_score(labels, preds)
            fpr, tpr, _ = roc_curve(labels, preds)
            plt.plot(fpr, tpr, label=f'Fold {i} AUC: {auc:.4f}')
        plt.title('ROC Curve (Each Fold)')
        # Distinct file name so the per-fold plot no longer overwrites the
        # aggregated one (mirrors the confusion-matrix `_folds` naming).
        out_path = f'./figures/roc_curve_folds_{m_type}.pdf'
    plt.plot([0, 1], [0, 1], linestyle='--')  # chance-level diagonal
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.legend()
    f.savefig(out_path, bbox_inches='tight')
    plt.show()
def plot_auc_boxplot_comparison(fold_results1, fold_results2, title="AUC Comparison"):
    """Plot AUC box comparison between two models (with vs. without MLM).

    Parameters
    ----------
    fold_results1, fold_results2 : list of dict
        Per-fold result dicts with 'train_auc' and 'best_val_auc' keys;
        fold_results1 is the MLM model, fold_results2 the no-MLM model.
        Folds are assumed to be matched one-to-one between the two lists.
    title : str
        Figure-level title.
    """
    train_auc_scores_mlm = [fold['train_auc'] for fold in fold_results1]
    train_auc_scores_nomlm = [fold['train_auc'] for fold in fold_results2]
    val_auc_scores_mlm = [fold['best_val_auc'] for fold in fold_results1]
    val_auc_scores_nomlm = [fold['best_val_auc'] for fold in fold_results2]
    # Paired t-test: the same folds are evaluated under both models.
    train_p_value = ttest_rel(train_auc_scores_mlm, train_auc_scores_nomlm).pvalue
    val_p_value = ttest_rel(val_auc_scores_mlm, val_auc_scores_nomlm).pvalue
    # Only numeric columns go into the wide-form boxplot frames; the original
    # also included a string 'Fold' column, which wide-form seaborn handles
    # inconsistently across versions (spurious category or error).
    df_train = pd.DataFrame({
        'with MLM': train_auc_scores_mlm,
        'without MLM': train_auc_scores_nomlm,
    })
    df_valid = pd.DataFrame({
        'with MLM': val_auc_scores_mlm,
        'without MLM': val_auc_scores_nomlm,
    })
    f = plt.figure(figsize=(12, 8))
    plt.suptitle(title)
    plt.subplot(1, 2, 1)
    sns.boxplot(data=df_train, palette=["#1f77b4", "#ff7f0e"])  # custom colors
    plt.title(f'Train AUC Comparison (p-value = {train_p_value:.4f})')
    plt.ylabel('AUC')
    plt.ylim(0.5, 1)
    plt.subplot(1, 2, 2)
    sns.boxplot(data=df_valid, palette=["#2ca02c", "#d62728"])  # custom colors
    plt.title(f'Validation AUC Comparison (p-value = {val_p_value:.4f})')
    plt.ylabel('AUC')
    plt.ylim(0.5, 1)
    # Layout before saving so the PDF includes the tight_layout adjustments.
    plt.tight_layout()
    f.savefig('./figures/auc_comparison.pdf', bbox_inches='tight')
    plt.show()
def plot_loss_comparison_mlm_vs_nomlm(fold_results1, fold_results2, title="Loss Comparison"):
    """Plot per-fold train/validation loss curves of two models on one axis.

    Solid lines show the model with MLM pre-training (fold_results1);
    dashed lines show the model without it (fold_results2). Folds are
    paired positionally between the two result lists.
    """
    fig = plt.figure(figsize=(12, 8))
    for mlm_fold, nomlm_fold in zip(fold_results1, fold_results2):
        fold_id = mlm_fold["fold"]
        mlm_train = mlm_fold['metrics']['train_loss']
        mlm_val = mlm_fold['metrics']['val_loss']
        nomlm_train = nomlm_fold['metrics']['train_loss']
        nomlm_val = nomlm_fold['metrics']['val_loss']
        xs = range(1, len(mlm_train) + 1)
        plt.plot(xs, mlm_train, 'o-', label=f'Train Loss w/ Pre-Training - Fold {fold_id}', alpha=0.5)
        plt.plot(xs, mlm_val, 'x-', label=f'Validation Loss w/ Pre-Training - Fold {fold_id}', alpha=0.5)
        plt.plot(xs, nomlm_train, 'o--', label=f'Train Loss w/o Pre-Training - Fold {fold_id}', alpha=0.5)
        plt.plot(xs, nomlm_val, 'x--', label=f'Validation Loss w/o Pre-Training - Fold {fold_id}', alpha=0.5)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title(title)
    # Legend pushed outside the axes; saved with bbox_inches='tight' so it fits.
    plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1))
    fig.savefig('./figures/loss_comparison.pdf', bbox_inches='tight')
    plt.show()
def plot_fold_losses(fold_results, title="Losses"):
    """Plot train/validation loss curves for every fold on a single axis.

    Parameters
    ----------
    fold_results : list of dict
        Per-fold dicts with a 'fold' id and a 'metrics' dict containing
        per-epoch 'train_loss' and 'val_loss' lists.
    title : str
        Plot title.
    """
    f = plt.figure(figsize=(12, 8))
    # The enumerate index in the original was unused; iterate directly.
    for fold in fold_results:
        train_losses = fold['metrics']['train_loss']
        val_losses = fold['metrics']['val_loss']
        epochs = range(1, len(train_losses) + 1)
        plt.plot(epochs, train_losses, 'o-', label=f'Train Loss - Fold {fold["fold"]}', alpha=0.5)
        plt.plot(epochs, val_losses, 'x-', label=f'Validation Loss - Fold {fold["fold"]}', alpha=0.5)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title(title)
    # Legend sits outside the axes; bbox_inches='tight' keeps it in the PDF.
    plt.legend(loc='upper right', bbox_to_anchor=(1.3, 1))
    f.savefig('./figures/fold_losses.pdf', bbox_inches='tight')
    plt.show()
def plot_data_distribution(adata_RNA, adata_ATAC, adata_Flux, title="Data Distribution"):
    """Plot value histograms of the three modalities in one row of panels.

    Parameters
    ----------
    adata_RNA, adata_ATAC : AnnData-like
        Objects whose `.X` is a sparse matrix (converted via `.toarray()`).
    adata_Flux : DataFrame-like
        Object exposing dense values through `.values`.
    title : str
        Figure-level title.
    """
    def _panel(ax, values, color, name, xlabel, var_fmt):
        # One histogram panel annotated with variance and mean.
        sns.histplot(values, bins=100, ax=ax, color=color)
        ax.set_title(f'{name} Distribution, var:{values.var():{var_fmt}}, mean:{values.mean():.2f}')
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Frequency')

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    plt.suptitle(title)
    # var_fmt preserves the original per-modality variance precision.
    _panel(axes[0], adata_RNA.X.toarray().flatten(), 'skyblue', 'RNA', 'Expression level', '.2f')
    _panel(axes[1], adata_ATAC.X.toarray().flatten(), 'lightgreen', 'ATAC', 'Accessibility level', '.3f')
    _panel(axes[2], adata_Flux.values.flatten(), 'salmon', 'Fluxomic', 'Flux value', '.5f')
    # Layout before saving so the PDF includes the tight_layout adjustments.
    plt.tight_layout()
    fig.savefig('./figures/data_distribution.pdf', bbox_inches='tight')
    plt.show()
def plot_att_weights(all_attention, dead_end_attention, reprogramming_attention,
                     feature_names=None, print_top_features=False, top_n=5, scale_weights=False, fix_scale=False,
                     use_mean_contribution=False):
    """Plot sample-averaged attention weights as one-row heatmaps for three
    sample groups (all / dead-end / reprogramming).

    Each panel is annotated with per-modality Sum/Mean values computed by the
    nested ``add_modality_labels`` helper; red dashed vertical lines mark the
    modality boundaries along the feature axis.

    Parameters
    ----------
    all_attention, dead_end_attention, reprogramming_attention : array
        Per-sample attention weights with samples on axis 0 and features on
        axis 1 (RNA, then ATAC, then Flux — see ``divider1``/``divider2``).
    feature_names : sequence of str, optional
        Names aligned with the feature axis; only used when
        ``print_top_features`` is True.
    print_top_features : bool
        If True, also print the ``top_n`` highest-attention features per group.
    top_n : int
        Number of top features to print.
    scale_weights : bool
        If True, min-max scale each averaged vector to [0, 1] and plot with a
        fixed [0, 1] color scale.
    fix_scale : bool
        If True (and ``scale_weights`` is False), share the color range of the
        all-samples vector across all three heatmaps.
    use_mean_contribution : bool or str
        Aggregation used by ``add_modality_labels``: False/'sum', True/'mean',
        'median', 'trimmed_mean', or 'active_mean'.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # Debug output: raw input shapes.
    print(all_attention.shape, "all_attention.shape")
    print(dead_end_attention.shape, "dead_end_attention.shape")
    print(reprogramming_attention.shape, "reprogramming_attention.shape")

    def minmax_scale(arr):
        # Linearly scale arr to [0, 1]; a constant vector maps to all zeros.
        arr = np.asarray(arr)
        min_val = arr.min()
        max_val = arr.max()
        if max_val - min_val == 0:
            return np.zeros_like(arr)  # avoid divide by zero
        return (arr - min_val) / (max_val - min_val)

    avg_all_attention = all_attention.mean(axis=0)  # Average attention weights across samples
    avg_dead_end_attention = dead_end_attention.mean(axis=0)
    avg_reprogramming_attention = reprogramming_attention.mean(axis=0)
    # Store original unscaled versions for modality contribution calculation
    # (the hasattr check tolerates inputs without a .copy() method).
    avg_all_attention_orig = avg_all_attention.copy() if hasattr(avg_all_attention, 'copy') else np.array(avg_all_attention)
    avg_dead_end_attention_orig = avg_dead_end_attention.copy() if hasattr(avg_dead_end_attention, 'copy') else np.array(avg_dead_end_attention)
    avg_reprogramming_attention_orig = avg_reprogramming_attention.copy() if hasattr(avg_reprogramming_attention, 'copy') else np.array(avg_reprogramming_attention)
    if scale_weights:
        avg_all_attention = minmax_scale(avg_all_attention)
        avg_dead_end_attention = minmax_scale(avg_dead_end_attention)
        avg_reprogramming_attention = minmax_scale(avg_reprogramming_attention)
        vmin, vmax = 0.0, 1.0
    elif fix_scale:  # fix scale of all attention weights to the same range
        vmin, vmax = avg_all_attention.min(), avg_all_attention.max()
    else:
        vmin, vmax = None, None  # let each heatmap autoscale independently
    # Visualize average attention weights
    f = plt.figure(figsize=(15, 3))
    # Feature-axis boundaries: [0, divider1) RNA, [divider1, divider2) ATAC,
    # [divider2, end) Flux — presumably the dataset's per-modality feature
    # counts (945 / 884 / rest); TODO confirm against the data pipeline.
    divider1 = 945
    divider2 = 945 + 884

    def add_modality_labels(ax, attention_weights, attention_weights_orig, use_mean=False):
        # Annotate one panel with a per-modality contribution metric, computed
        # on the UNSCALED weights (attention_weights_orig).
        rna_weights = attention_weights_orig[:divider1]
        atac_weights = attention_weights_orig[divider1:divider2]
        flux_weights = attention_weights_orig[divider2:]
        # Calculate metric based on method
        if use_mean is False or use_mean == 'sum':
            # Sum of all attention weights (original behavior)
            rna_metric = rna_weights.sum()
            atac_metric = atac_weights.sum()
            flux_metric = flux_weights.sum()
        elif use_mean is True or use_mean == 'mean':
            # Mean attention per feature
            rna_metric = rna_weights.mean()
            atac_metric = atac_weights.mean()
            flux_metric = flux_weights.mean()
        elif use_mean == 'median':
            # Median attention per feature (robust to zeros and outliers)
            rna_metric = np.median(rna_weights)
            atac_metric = np.median(atac_weights)
            flux_metric = np.median(flux_weights)
        elif use_mean == 'trimmed_mean':
            # Trimmed mean: removes 15% from each tail of each distribution
            rna_metric = stats.trim_mean(rna_weights, proportiontocut=0.15)
            atac_metric = stats.trim_mean(atac_weights, proportiontocut=0.15)
            flux_metric = stats.trim_mean(flux_weights, proportiontocut=0.15)
        elif use_mean == 'active_mean':
            # Mean of only "active" features (attention > threshold); the
            # threshold is the 25th percentile of the WHOLE vector, so the
            # bottom quartile overall is considered inactive.
            threshold = np.percentile(attention_weights_orig, 25)
            rna_active = rna_weights[rna_weights > threshold]
            atac_active = atac_weights[atac_weights > threshold]
            flux_active = flux_weights[flux_weights > threshold]
            rna_metric = rna_active.mean() if len(rna_active) > 0 else 0
            atac_metric = atac_active.mean() if len(atac_active) > 0 else 0
            flux_metric = flux_active.mean() if len(flux_active) > 0 else 0
        else:
            raise ValueError(f"Invalid use_mean value: {use_mean}")
        # # Normalize to percentages
        # print(rna_metric, atac_metric, flux_metric, "rna_metric, atac_metric, flux_metric")
        # total_metric = rna_metric + atac_metric + flux_metric
        # rna_pct = (rna_metric / total_metric * 100) if total_metric > 0 else 0
        # atac_pct = (atac_metric / total_metric * 100) if total_metric > 0 else 0
        # flux_pct = (flux_metric / total_metric * 100) if total_metric > 0 else 0
        # Calculate center positions for each modality
        n_rna = divider1
        n_atac = divider2 - divider1
        n_flux = len(attention_weights) - divider2
        rna_center = n_rna / 2
        atac_center = divider1 + n_atac / 2
        flux_center = divider2 + n_flux / 2
        # NOTE(review): the on-plot labels always read "Sum:"/"Mean:", even
        # when use_mean selects median/trimmed/active aggregation — in those
        # cases "Sum" is really that metric and "Mean" is metric/feature-count.
        rna_metric_mean = rna_metric / n_rna
        atac_metric_mean = atac_metric / n_atac
        flux_metric_mean = flux_metric / n_flux
        ax.text(rna_center, -0.3, f'Sum: {rna_metric:.3f}\nMean: {rna_metric_mean:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold')
        ax.text(atac_center, -0.3, f'Sum: {atac_metric:.3f}\nMean: {atac_metric_mean:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold')
        ax.text(flux_center, -0.3, f'Sum: {flux_metric:.3f}\nMean: {flux_metric_mean:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold')

    # Panel 1: all samples.
    plt.subplot(1, 3, 1)
    ax1 = plt.gca()
    sns.heatmap(avg_all_attention.reshape(1, -1), cmap='viridis', yticklabels=['All'], vmin=vmin, vmax=vmax, ax=ax1)
    plt.title('Avg Att. W. (All Samples)')
    plt.xlabel('Features')
    plt.xticks([])
    # Red dashed lines mark the RNA/ATAC and ATAC/Flux boundaries.
    plt.axvline(x=divider1, color='red', linestyle='--', linewidth=2)
    plt.axvline(x=divider2, color='red', linestyle='--', linewidth=2)
    add_modality_labels(ax1, avg_all_attention, avg_all_attention_orig, use_mean=use_mean_contribution)
    # Panel 2: dead-end samples.
    plt.subplot(1, 3, 2)
    ax2 = plt.gca()
    sns.heatmap(avg_dead_end_attention.reshape(1, -1), cmap='viridis', yticklabels=['Dead-end'], vmin=vmin, vmax=vmax, ax=ax2)
    plt.title('Avg Att. W. (Dead-end Samples)')
    plt.xlabel('Features')
    plt.xticks([])
    plt.axvline(x=divider1, color='red', linestyle='--', linewidth=2)
    plt.axvline(x=divider2, color='red', linestyle='--', linewidth=2)
    add_modality_labels(ax2, avg_dead_end_attention, avg_dead_end_attention_orig, use_mean=use_mean_contribution)
    # Panel 3: reprogramming samples.
    plt.subplot(1, 3, 3)
    ax3 = plt.gca()
    sns.heatmap(avg_reprogramming_attention.reshape(1, -1), cmap='viridis', yticklabels=['Reprogramming'], vmin=vmin, vmax=vmax, ax=ax3)
    plt.title('Avg Att. W. (Reprogramming Samples)')
    plt.xlabel('Features')
    plt.xticks([])
    plt.axvline(x=divider1, color='red', linestyle='--', linewidth=2)
    plt.axvline(x=divider2, color='red', linestyle='--', linewidth=2)
    add_modality_labels(ax3, avg_reprogramming_attention, avg_reprogramming_attention_orig, use_mean=use_mean_contribution)
    # f.savefig('./figures/attention_weights.pdf', bbox_inches='tight')
    plt.tight_layout()
    plt.show()
    if print_top_features:
        def get_top_features(attention_weights, feature_names, top_n=top_n):
            # Average over samples; .numpy() converts torch tensors when present.
            avg_attention = attention_weights.mean(axis=0).numpy() if hasattr(attention_weights, 'numpy') else attention_weights.mean(axis=0)
            print(avg_attention.shape, len(feature_names))  # debug output
            # Indices of the top_n largest averaged weights, descending.
            top_indices = avg_attention.argsort()[-top_n:][::-1]
            print(top_indices)  # debug output
            return [(feature_names[i], avg_attention[i]) for i in top_indices]
        top_all = get_top_features(all_attention, feature_names)
        top_dead_end = get_top_features(dead_end_attention, feature_names)
        top_reprogramming = get_top_features(reprogramming_attention, feature_names)
        print(f"Top {top_n} attended features (All samples):")
        for feature, weight in top_all:
            print(f"{feature}: {weight:.4f}", end=", ")
        print(f"\nTop {top_n} attended features (Dead-end samples):")
        for feature, weight in top_dead_end:
            print(f"{feature}: {weight:.4f}", end=", ")
        print(f"\nTop {top_n} attended features (Reprogramming samples):")
        for feature, weight in top_reprogramming:
            print(f"{feature}: {weight:.4f}", end=", ")
    return f
def plot_att_weights_distribution(
    all_attention, dead_end_attention, reprogramming_attention,
    feature_names=None, plot_type='violin', top_n=5, print_means=False
):
    """Plot the distribution of per-feature attention weights, grouped by
    modality (RNA / ATAC / Flux) and by sample condition.

    Parameters
    ----------
    all_attention, dead_end_attention, reprogramming_attention : array
        (samples, features) attention matrices; the feature axis is assumed
        to be ordered RNA, ATAC, Flux (see divider1/divider2 below).
    feature_names : sequence of str, optional
        Unused here; kept for interface compatibility with plot_att_weights.
    plot_type : str
        'violin' or 'box'.
    top_n : int
        Unused here; kept for interface compatibility.
    print_means : bool
        If True, print the per-modality mean/std for each condition.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If plot_type is not 'violin' or 'box'.
    """
    # Feature-axis boundaries: [0, divider1) RNA, [divider1, divider2) ATAC,
    # [divider2, end) Flux. (The original also had a dead 944-based pair of
    # assignments that was immediately overwritten; removed.)
    divider1 = 945
    divider2 = 945 + 884

    # Validate up front; the original message falsely advertised 'hist'.
    if plot_type not in ('violin', 'box'):
        raise ValueError(f"plot_type must be 'violin' or 'box', got '{plot_type}'")

    def prepare_modality_data(attention_weights, condition_name):
        """Split a (samples, features) matrix into flattened per-modality weights."""
        return {
            'RNA': attention_weights[:, :divider1].flatten(),
            'ATAC': attention_weights[:, divider1:divider2].flatten(),
            'Flux': attention_weights[:, divider2:].flatten(),
            'condition': condition_name,
        }

    all_data = prepare_modality_data(all_attention, 'All')
    de_data = prepare_modality_data(dead_end_attention, 'Dead-end')
    re_data = prepare_modality_data(reprogramming_attention, 'Reprogramming')

    # Build the long-form DataFrame in bulk; the original appended one dict
    # per individual weight, which is very slow for large matrices.
    frames = []
    for condition_data in [all_data, de_data, re_data]:
        for modality in ['RNA', 'ATAC', 'Flux']:
            frames.append(pd.DataFrame({
                'Condition': condition_data['condition'],
                'Modality': modality,
                'Attention Weight': condition_data[modality],
            }))
    df = pd.concat(frames, ignore_index=True)

    # One subplot per condition.
    f, axes = plt.subplots(1, 3, figsize=(18, 5))
    conditions = ['All', 'Dead-end', 'Reprogramming']
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c']  # RNA, ATAC, Flux colors
    if print_means:
        print("Mean attention weight values per modality and per condition:")
    for idx, (ax, condition) in enumerate(zip(axes, conditions)):
        condition_df = df[df['Condition'] == condition]
        if plot_type == 'violin':
            sns.violinplot(data=condition_df, x='Modality', y='Attention Weight',
                           palette=colors, ax=ax)
        else:  # box
            sns.boxplot(data=condition_df, x='Modality', y='Attention Weight',
                        palette=colors, ax=ax)
        ax.set_title(f'{condition} Samples', fontsize=12, fontweight='bold')
        ax.set_xlabel('Modality', fontsize=11)
        ax.set_ylabel('Attention Weight', fontsize=11)
        ax.grid(axis='y', alpha=0.3)
        for i, modality in enumerate(['RNA', 'ATAC', 'Flux']):
            mod_data = condition_df[condition_df['Modality'] == modality]['Attention Weight']
            mean_val = mod_data.mean()
            std_val = mod_data.std()
            # Dashed red segment marking each distribution's mean.
            ax.hlines(mean_val, i - 0.4, i + 0.4, colors='red', linestyles='--',
                      linewidth=2, alpha=0.7, label='Mean' if i == 0 else '')
            if print_means:
                print(f"{condition} - {modality}: mean={mean_val:.8f}, std={std_val:.8f}")
        if idx == 0:
            ax.legend()
    plt.tight_layout()
    plt.show()
    return f
def plot_att_heads(all_attention_heads, dead_end_attention_heads, reprogramming_attention_heads, stacked=False):
    """Plot per-head attention weights for the three sample groups.

    Parameters
    ----------
    all_attention_heads, dead_end_attention_heads, reprogramming_attention_heads : array
        Per-sample, per-head attention weights; dimension 1 is assumed to be
        the head axis.
    stacked : bool
        If True, show one heatmap per group with all heads stacked as rows;
        otherwise one row of three heatmaps per head.
    """
    n_heads = all_attention_heads.shape[1]  # dim 1 assumed to be heads
    # (data, label) triples shared by both layouts; avoids the original's
    # three copy-pasted panel blocks.
    groups = [
        (all_attention_heads, 'All Samples'),
        (dead_end_attention_heads, 'Dead-end Samples'),
        (reprogramming_attention_heads, 'Reprogramming Samples'),
    ]
    if stacked:
        # One panel per group: heads stacked as heatmap rows.
        f = plt.figure(figsize=(15, 10))
        for col, (heads, name) in enumerate(groups, start=1):
            plt.subplot(1, 3, col)
            stacked_attention = heads.mean(axis=0).reshape(n_heads, -1)  # average over samples
            sns.heatmap(stacked_attention, cmap='viridis',
                        yticklabels=[f'Head {i+1}' for i in range(n_heads)])
            plt.title(f'Stacked Attention Weights ({name})')
            plt.xlabel('Features')
            plt.ylabel('Heads')
            plt.xticks(rotation=90)
        # Layout BEFORE saving so the PDF includes the tight_layout pass
        # (the original saved first and missed it).
        plt.tight_layout()
        f.savefig('./figures/attention_heads_stacked.pdf', bbox_inches='tight')
        plt.show()
    else:
        # Grid: one row per head, one column per group.
        f = plt.figure(figsize=(15, 15))
        for head in range(n_heads):
            for col, (heads, name) in enumerate(groups, start=1):
                plt.subplot(n_heads, 3, 3 * head + col)
                sns.heatmap(heads[:, head, :].mean(axis=0).reshape(1, -1),
                            cmap='viridis', yticklabels=[f'Head {head+1}'])
                plt.title(f'Head {head+1} Attention ({name})')
                plt.xlabel('Features')
                plt.xticks(rotation=90)
        plt.tight_layout()
        f.savefig('./figures/attention_heads.pdf', bbox_inches='tight')
        plt.show()