| import os |
| from itertools import combinations |
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from datasets import Audio, load_dataset |
| from safetensors.torch import save_file |
| from tqdm import tqdm |
| from transformers import AutoFeatureExtractor, WhisperModel |
|
|
| from .config import * |
|
|
model_ids = ENABLED_MODELS


def _pick_torch_device():
    """Return the first available backend: CUDA, then Apple MPS, otherwise CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


# Speech corpus shared by every model comparison; loaded once at import time.
dataset = load_dataset("JacobLinCool/cv161-en-zh-subset-200", split="train")
if MAX_SAMPLES is not None:
    # Optional cap on the number of clips (useful for quick test runs).
    dataset = dataset.select(range(min(MAX_SAMPLES, len(dataset))))
    print(f"Limited dataset to {len(dataset)} samples for testing")

# Cast the audio column to 16 kHz; the extractor below is handed this rate.
dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))

device = _pick_torch_device()
print(f"Using device: {device}")
|
|
|
|
def extract_layer_reps_generator(model_id, batch_size=4):
    """
    Stream per-layer Whisper encoder hidden states for every sample in ``dataset``.

    Samples are fed to the encoder ``batch_size`` at a time (one forward pass per
    batch). Each sample's representations are then yielded individually and copied
    to CPU, so only one batch of activations is resident on ``device`` at a time
    even when the caller keeps every yielded value.

    Args:
        model_id: model id passed to ``WhisperModel.from_pretrained`` and
            ``AutoFeatureExtractor.from_pretrained``.
        batch_size: number of samples encoded per forward pass (must be >= 1).

    Yields:
        ``(sample_idx, layer_reps)`` pairs in dataset order. ``layer_reps`` holds one
        CPU tensor per entry of the encoder's ``hidden_states``. Each entry is that
        sample's slice of the batch, presumably ``(frames, hidden_dim)`` — the
        layout the probe code relies on. It is cast to ``HALF_PRECISION_DTYPE``
        when ``USE_HALF_PRECISION`` is set.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.

    The model is released in a ``finally`` clause, so closing the generator early
    (``gen.close()``) frees it just as exhausting it does.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    model = WhisperModel.from_pretrained(model_id).to(device)
    feat_ext = AutoFeatureExtractor.from_pretrained(model_id)
    model.eval()

    try:
        for start in tqdm(
            range(0, len(dataset), batch_size),
            desc=f"Processing {model_id} in batches",
        ):
            stop = min(start + batch_size, len(dataset))
            audios = [sample["audio"] for sample in dataset.select(range(start, stop))]
            # The audio column was cast to a single sampling rate, so the first
            # clip's rate applies to the whole batch.
            sr = audios[0]["sampling_rate"]

            inputs = feat_ext(
                [audio["array"] for audio in audios],
                sampling_rate=sr,
                return_tensors="pt",
            ).input_features.to(device)
            with torch.no_grad():
                outputs = model.encoder(
                    inputs, return_dict=True, output_hidden_states=True
                )

            for offset in range(stop - start):
                layer_reps_for_sample = []
                for hs in outputs.hidden_states:
                    layer_rep = hs[offset]
                    if USE_HALF_PRECISION:
                        # Cast before the transfer so less data crosses to the host.
                        layer_rep = layer_rep.to(HALF_PRECISION_DTYPE)
                    # Move to host memory: callers accumulate reps for the whole
                    # dataset, which would otherwise all stay on the accelerator.
                    layer_reps_for_sample.append(layer_rep.cpu())
                yield start + offset, layer_reps_for_sample

            del outputs, inputs
            if AGGRESSIVE_CLEANUP and torch.cuda.is_available():
                torch.cuda.empty_cache()
    finally:
        # Runs on exhaustion, close(), or garbage collection of the generator.
        del model, feat_ext
        if AGGRESSIVE_CLEANUP and torch.cuda.is_available():
            torch.cuda.empty_cache()
|
|
|
|
def _collect_layer_reps(model_id, batch_size):
    """Exhaust ``extract_layer_reps_generator`` into a ``{sample_idx: layer_reps}`` dict."""
    # Fully consuming the generator lets it release its model before we move on.
    return dict(extract_layer_reps_generator(model_id, batch_size=batch_size))


def _stack_layer(reps_by_sample, sample_ids, layer):
    """Concatenate one layer across samples and return a ``(1, dim, total_frames)`` tensor on ``device``."""
    # Assumes each rep is (frames, dim); samples are joined along the frame axis.
    frames = torch.cat([reps_by_sample[idx][layer] for idx in sample_ids], dim=0)
    # (total_frames, dim) -> (1, dim, total_frames): Conv1d expects channels first.
    return frames.to(device).T.unsqueeze(0)


def _fit_linear_probe(X, Y, n_steps, lr, desc):
    """
    Train a bias-free 1x1 ``nn.Conv1d`` mapping ``X`` to ``Y`` with full-batch Adam.

    Returns ``(weight, mse)``: the learned weight (detached, on CPU) and the MSE
    of the trained probe on ``(X, Y)``.
    """
    # Build the probe in the data's own dtype. A probe pinned to
    # HALF_PRECISION_DTYPE fails with a dtype mismatch whenever the reps were
    # not cast (USE_HALF_PRECISION off).
    # NOTE(review): Adam on float16 parameters can be numerically fragile; bfloat16
    # or float32 may be safer.
    probe = nn.Conv1d(
        in_channels=X.shape[1], out_channels=Y.shape[1], kernel_size=1, bias=False
    ).to(device=X.device, dtype=X.dtype)
    optimizer = torch.optim.Adam(probe.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    probe.train()
    for _ in tqdm(range(n_steps), desc=desc, leave=False):
        optimizer.zero_grad()
        loss = loss_fn(probe(X), Y)
        loss.backward()
        optimizer.step()

    # Score after the final update: the last training loss predates it. This
    # also gives a defined MSE when n_steps == 0.
    with torch.no_grad():
        final_mse = loss_fn(probe(X), Y).item()
    return probe.weight.detach().cpu(), final_mse


def compute_linear_mse_matrix_temporal_memory_efficient(
    model_a_id, model_b_id, n_steps=200, lr=1e-3, batch_size=4
):
    """
    Fit a linear probe from every layer of model A to every layer of model B.

    For each layer pair ``(i, j)``, a bias-free 1x1 ``nn.Conv1d`` (a per-frame
    linear map ``dim_a -> dim_b``) is trained with Adam. It predicts model B's
    layer-``j`` representations from model A's layer-``i`` representations, with
    all samples' frames concatenated along the time axis. The probe is trained
    and scored on the same data.

    Args:
        model_a_id: source model id passed to ``extract_layer_reps_generator``.
        model_b_id: target model id passed to ``extract_layer_reps_generator``.
        n_steps: number of full-batch Adam steps per probe.
        lr: Adam learning rate.
        batch_size: encoder batch size used while extracting representations.

    Returns:
        ``(mse_mat, trained_probes)``:
        - ``mse_mat`` is a ``(layers_a, layers_b)`` NumPy array holding each
          trained probe's MSE.
        - ``trained_probes`` maps ``"layer_{i}_to_{j}"`` to the probe's weight
          tensor (on CPU, shape ``(dim_b, dim_a, 1)``).

    Raises:
        ValueError: if no samples were extracted, or if the two models produced
            representations for different sample indices.
    """
    print(f"Computing alignment between {model_a_id} and {model_b_id}...")

    # Load each model exactly once; both generators are exhausted here, so
    # neither model lingers in memory during probe training.
    reps_a_all = _collect_layer_reps(model_a_id, batch_size)
    reps_b_all = _collect_layer_reps(model_b_id, batch_size)

    sample_ids = sorted(reps_a_all)
    if not sample_ids:
        raise ValueError("No samples were extracted; the dataset is empty")
    if sorted(reps_b_all) != sample_ids:
        raise ValueError(
            f"{model_a_id} and {model_b_id} produced representations for different samples"
        )
    layers_a = len(reps_a_all[sample_ids[0]])
    layers_b = len(reps_b_all[sample_ids[0]])

    mse_mat = np.zeros((layers_a, layers_b))
    trained_probes = {}

    with tqdm(total=layers_a * layers_b, desc="Comparing layer pairs") as pbar:
        for i in range(layers_a):
            # The source features depend only on i, so build them once per row.
            X = _stack_layer(reps_a_all, sample_ids, i)
            for j in range(layers_b):
                Y = _stack_layer(reps_b_all, sample_ids, j)
                weight, final_mse = _fit_linear_probe(
                    X, Y, n_steps, lr, desc=f"Training probe {i}->{j}"
                )
                mse_mat[i, j] = final_mse
                trained_probes[f"layer_{i}_to_{j}"] = weight

                del Y
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                pbar.update(1)
                pbar.set_postfix(
                    {"layer_a": i, "layer_b": j, "mse": f"{final_mse:.4f}"}
                )
            del X

    return mse_mat, trained_probes
|
|
|
|
if __name__ == "__main__":
    # Echo the run configuration (values come from the star-imported config).
    print("Memory optimization settings:")
    print(f" Batch size: {BATCH_SIZE}")
    print(f" Training steps: {TRAINING_STEPS}")
    if USE_HALF_PRECISION:
        dtype_name = "bfloat16" if HALF_PRECISION_DTYPE == torch.bfloat16 else "float16"
        print(f" Half precision: {USE_HALF_PRECISION} ({dtype_name})")
    else:
        print(f" Half precision: {USE_HALF_PRECISION}")
    print(f" Aggressive cleanup: {AGGRESSIVE_CLEANUP}")
    print(f" Models: {list(model_ids.keys())}")
    print(f" Dataset size: {len(dataset)} samples")

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Every unordered pair of enabled models is compared once (A -> B probes only).
    model_names = list(model_ids.keys())
    all_pairs = list(combinations(model_names, 2))

    print(
        f"\nProcessing {len(all_pairs)} model pairs with memory-efficient approach..."
    )

    for pair_idx, (model_a, model_b) in enumerate(all_pairs):
        print(
            f"\n[{pair_idx + 1}/{len(all_pairs)}] Computing temporal linear MSE for whisper-{model_a} vs whisper-{model_b}..."
        )

        mse_mat_temporal, trained_probes = (
            compute_linear_mse_matrix_temporal_memory_efficient(
                model_ids[model_a],
                model_ids[model_b],
                n_steps=TRAINING_STEPS,
                lr=LEARNING_RATE,
                batch_size=BATCH_SIZE,
            )
        )

        # Persist the probe weights; safetensors metadata values must be strings.
        layers_a, layers_b = mse_mat_temporal.shape
        model_save_path = os.path.join(
            OUTPUT_DIR, f"{model_a}-to-{model_b}-probes.safetensors"
        )
        save_file(
            trained_probes,
            model_save_path,
            {
                "from_model": str(model_a),
                "to_model": str(model_b),
                "from_layers": str(layers_a),
                "to_layers": str(layers_b),
            },
        )
        print(f"Saved trained probes to: {model_save_path}")

        if SAVE_PLOTS:
            # Plot -log10(MSE): higher values mean lower MSE (better alignment).
            # eps keeps log10 finite when an MSE is exactly zero.
            eps = 1e-10
            log_mse_mat = -np.log10(mse_mat_temporal + eps)

            plt.figure(figsize=(8, 6))
            plt.imshow(log_mse_mat, aspect="auto", origin="lower")
            plt.colorbar(label="-log10(MSE)")
            plt.title(
                f"Temporal Linear MSE (log scale): whisper-{model_a} vs whisper-{model_b}"
            )
            plt.xlabel(f"whisper-{model_b} layers")
            plt.ylabel(f"whisper-{model_a} layers")
            plt.tight_layout()

            plot_save_path = os.path.join(
                OUTPUT_DIR, f"{model_a}-vs-{model_b}-temporal-linear-mse-log.png"
            )
            plt.savefig(plot_save_path, dpi=PLOT_DPI)
            plt.close()
            print(f"Saved plot to: {plot_save_path}")

    # Plots are only produced when SAVE_PLOTS is enabled; report what was written.
    n_plots = len(all_pairs) if SAVE_PLOTS else 0
    print(f"\nAll experiments complete! Results saved to '{OUTPUT_DIR}' directory")
    print(
        f"Generated {n_plots} visualization plots and {len(all_pairs)} trained probe models"
    )
|
|