"""
=============================================================================
CLOSING THE AI-BRAIN LOOP
Using TRIBE v2 to identify architectural gaps between AI models and the brain
=============================================================================

Methodology:
Phase 1 — Load TRIBE v2, run inference, capture per-layer AI features
Phase 2 — Load the Destrieux cortical parcellation on the fsaverage5 surface
Phase 3 — Layer-wise encoding analysis: which AI layers predict which brain regions
Phase 4 — Modality ablation: which encoder drives which brain area
Phase 5 — RSA: representational similarity between AI layers and brain ROIs
Phase 6 — Divergence mapping: where the brain does something AI can't capture
Phase 7 — Visualizations and report, including architectural implications:
          what's missing in current AI

Output: /home/azureuser/loop_results/
"""
|
|
| import os |
| import sys |
| import logging |
| import warnings |
| import time |
|
|
| import numpy as np |
| import torch |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| from matplotlib.gridspec import GridSpec |
| import pandas as pd |
| from pathlib import Path |
| from scipy import stats |
| from scipy.spatial.distance import pdist, squareform |
| from einops import rearrange |
|
|
# Suppress all Python warnings for the rest of the run so the progress log stays readable.
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(message)s", datefmt="%H:%M:%S")
log = logging.getLogger(__name__)

# Every CSV, PNG and text report produced below is written into this directory.
# parents=True so the script also works when the parent directory does not exist yet.
OUT = Path("/home/azureuser/loop_results")
OUT.mkdir(parents=True, exist_ok=True)
|
|
| |
| |
| |
|
|
# =============================================================================
# PHASE 1 — load TRIBE v2 and prepare the stimulus data loader
# =============================================================================
log.info("PHASE 1: Loading TRIBE v2 and running inference")

from tribev2 import TribeModel

# Local paths: pretrained-weight cache and the stimulus video fed to the model.
CACHE = "/home/azureuser/cache"
VIDEO = "/home/azureuser/test_stimulus.mp4"

model = TribeModel.from_pretrained("facebook/tribev2", cache_folder=CACHE)
# The wrapped torch module: hooks are attached to it and it is called directly below.
fmri_model = model._model
device = fmri_model.device
log.info(f"Model on {device}, n_outputs={fmri_model.n_outputs}")

# Record the relevant architecture settings in the log.
model_cfg = fmri_model.config
log.info("Model structure:")
log.info(f" Projectors: {list(fmri_model.projectors)}")
log.info(f" Hidden dim: {model_cfg.hidden}")
log.info(f" Layer aggregation: {model_cfg.layer_aggregation}")
log.info(f" Extractor aggregation: {model_cfg.extractor_aggregation}")
if hasattr(fmri_model, "encoder"):
    log.info(f" Encoder: {type(fmri_model.encoder).__name__}")
if hasattr(fmri_model, "low_rank_head"):
    log.info(f" Low-rank head: {fmri_model.low_rank_head}")

# Turn the video into the model's events table, then into a single loader over all of it.
log.info(f"Processing video: {VIDEO}")
events = model.get_events_dataframe(video_path=VIDEO)
event_types = events["type"].unique().tolist()
log.info(f"Events: {len(events)} rows, types: {event_types}")

loaders = model.data.get_loaders(events=events, split_to_build="all")
loader = loaders["all"]
|
|
| |
| |
| |
|
|
# -----------------------------------------------------------------------------
# Instrumentation: accumulators plus forward hooks on projectors / encoder
# -----------------------------------------------------------------------------
all_features = {}       # modality -> list of per-batch raw input feature arrays
all_projected = {}      # modality -> list of per-batch projector outputs
all_post_encoder = []   # per-batch outputs of fmri_model.encoder (when it exists)
all_brain_preds = []    # per-batch predictions, rearranged to (batch*time, vertices)

proj_captures = {}      # scratch space written by the projector hooks during one forward pass
proj_hooks = []         # every registered hook handle, so they can all be removed later


def make_proj_hook(name):
    """Build a forward hook that stores projector *name*'s output as a NumPy array.

    The array is written to ``proj_captures[name]`` (overwriting any previous value).
    """

    def hook(mod, inp, out):
        proj_captures[name] = out.detach().cpu().numpy()

    return hook


for mod_name, proj in fmri_model.projectors.items():
    handle = proj.register_forward_hook(make_proj_hook(mod_name))
    proj_hooks.append(handle)

# Single-slot box the encoder hook writes into (a list so the hook can mutate it).
encoder_capture = [None]


def enc_hook(mod, inp, out):
    """Store the encoder output of the latest forward pass in ``encoder_capture[0]``."""
    # NOTE(review): assumes the encoder returns a single tensor, not a tuple — confirm.
    encoder_capture[0] = out.detach().cpu().numpy()


if getattr(fmri_model, "encoder", None) is not None:
    proj_hooks.append(fmri_model.encoder.register_forward_hook(enc_hook))
|
|
# -----------------------------------------------------------------------------
# Instrumented inference pass over the whole stimulus
# -----------------------------------------------------------------------------
log.info("Running inference with hooks...")
t0 = time.time()

with torch.inference_mode():
    for batch in loader:
        batch = batch.to(device)

        # 1) Raw input features of each modality present in this batch. Later
        #    phases unpack them as (batch, layer, dim, time), so 3-D features
        #    get a singleton layer axis inserted at position 1.
        for mod_name in fmri_model.projectors.keys():
            if mod_name not in batch.data:
                continue
            feat = batch.data[mod_name].detach().cpu().numpy()
            if feat.ndim == 3:
                feat = np.expand_dims(feat, axis=1)
            all_features.setdefault(mod_name, []).append(feat)

        # 2) Forward pass (fires the hooks); predictions go from
        #    (batch, vertices, time) to (batch*time, vertices).
        y_pred = fmri_model(batch).detach().cpu().numpy()
        y_pred = rearrange(y_pred, 'b v t -> (b t) v')
        all_brain_preds.append(y_pred)

        # 3) Move what the projector hooks captured into the accumulators,
        #    then reset the scratch dict for the next batch.
        for mod_name, captured in proj_captures.items():
            all_projected.setdefault(mod_name, []).append(captured)
        proj_captures.clear()

        # 4) Same for the encoder hook's single slot.
        if encoder_capture[0] is not None:
            all_post_encoder.append(encoder_capture[0])
            encoder_capture[0] = None

# Detach every hook so later forward passes (ablation phase) run uninstrumented.
for handle in proj_hooks:
    handle.remove()

elapsed = time.time() - t0
log.info(f"Inference done in {elapsed:.1f}s")
|
|
| |
# -----------------------------------------------------------------------------
# Stitch the per-batch chunks together along the first axis
# -----------------------------------------------------------------------------
brain_preds = np.concatenate(all_brain_preds, axis=0)
log.info(f"Brain predictions: {brain_preds.shape}")

# Values are replaced in place; the key set does not change during iteration.
for mod, chunks in all_features.items():
    stacked = np.concatenate(chunks, axis=0)
    all_features[mod] = stacked
    log.info(f"Raw features [{mod}]: {stacked.shape}")

for mod, chunks in all_projected.items():
    stacked = np.concatenate(chunks, axis=0)
    all_projected[mod] = stacked
    log.info(f"Projected features [{mod}]: {stacked.shape}")

# post_encoder only exists when the encoder hook captured something.
if all_post_encoder:
    post_encoder = np.concatenate(all_post_encoder, axis=0)
    log.info(f"Post-encoder features: {post_encoder.shape}")
|
|
| |
| |
| |
|
|
log.info("PHASE 2: Loading brain parcellation (Destrieux atlas, fsaverage5)")

from nilearn import datasets

# fsaverage5 mesh files; the pial surfaces are used by the surface plot later.
fsaverage5 = datasets.fetch_surf_fsaverage("fsaverage5")

# The fsaverage Bunch holds meshes/morphometry only, no parcellation. The
# Destrieux *surface* atlas provides per-vertex label maps on fsaverage5 plus
# the matching label names ("Unknown", "Medial_wall", ...). The volumetric
# fetch_atlas_destrieux_2009 uses a different label table and must not be
# mixed with the surface maps.
destrieux = datasets.fetch_atlas_surf_destrieux()
labels_lh = np.asarray(destrieux["map_left"]).astype(np.int64)
labels_rh = np.asarray(destrieux["map_right"]).astype(np.int64)

# Vertices per hemisphere (10242 on fsaverage5). Predictions are assumed to be
# ordered left hemisphere first, then right.
N_VERT = len(labels_lh)

# Both hemispheres use the same label IDs. Offset the right hemisphere so that
# homologous regions stay separate: IDs < RH_LABEL_OFFSET are left hemisphere.
RH_LABEL_OFFSET = 100
if (min(labels_lh.min(), labels_rh.min()) < 0
        or max(labels_lh.max(), labels_rh.max()) >= RH_LABEL_OFFSET):
    raise ValueError(
        f"Destrieux label IDs must lie in [0, {RH_LABEL_OFFSET}) for the hemisphere offset"
    )
all_labels = np.concatenate([labels_lh, labels_rh + RH_LABEL_OFFSET])

# Label index -> name (some nilearn releases return bytes).
label_names_raw = destrieux["labels"]
label_names = {}
for i, name in enumerate(label_names_raw):
    if isinstance(name, bytes):
        name = name.decode("utf-8")
    label_names[i] = name

# Keep regions with at least 5 vertices and skip non-cortical labels. Names
# carry a hemisphere prefix so left/right homologues can be told apart.
regions = {}
for lid in np.unique(all_labels):
    lid = int(lid)
    mask = all_labels == lid
    n = mask.sum()
    if n < 5:
        continue
    hemi = "RH" if lid >= RH_LABEL_OFFSET else "LH"
    base_id = lid - RH_LABEL_OFFSET if hemi == "RH" else lid
    base_name = label_names.get(base_id, f"region_{base_id}")
    if base_name in ("Unknown", "Medial_wall", "Background"):
        continue
    regions[lid] = {"name": f"{hemi}_{base_name}", "mask": mask, "n_vertices": int(n)}

log.info(f"Found {len(regions)} usable brain regions")
|
|
| |
| |
| |
|
|
log.info("PHASE 3: Layer-wise encoding analysis")

# brain_preds is (time, vertices) after the earlier rearrange + concatenate.
T_total, V = brain_preds.shape

# Region-level signal: mean predicted response over each region's vertices.
region_timeseries = {
    lid: brain_preds[:, rinfo["mask"]].mean(axis=1)
    for lid, rinfo in regions.items()
}

# Layer-level signal: for each modality and layer, average the feature
# dimension, giving one scalar per time point. Flattening (batch, time) keeps
# the same ordering as the predictions; it is truncated to T_total rows.
layer_timeseries = {}
for mod, feats in all_features.items():
    _, n_layers, _, _ = feats.shape
    flat = rearrange(feats, 'b l d t -> (b t) l d')[:T_total]
    layer_timeseries[mod] = {
        layer: flat[:, layer, :].mean(axis=1) for layer in range(n_layers)
    }
|
log.info("Computing layer-brain correlations...")


def _layer_sort_key(key):
    """Sort key for "<modality>_L<idx>" strings: modality first, then numeric layer.

    A plain ``sorted`` would order "L10" before "L2", scrambling layer depth in
    every plot that uses these keys.
    """
    mod_part, sep, layer_part = key.rpartition("_L")
    if sep and layer_part.isdigit():
        return (mod_part, int(layer_part))
    return (key, -1)


# Pearson r between every layer's timeseries and every region's timeseries,
# over their common length. Constant series get r = 0 instead of NaN.
layer_brain_corr = {}
for mod, per_layer in layer_timeseries.items():
    for l, layer_ts in per_layer.items():
        n_t = min(len(layer_ts), T_total)
        lt = layer_ts[:n_t]
        layer_is_flat = np.std(lt) < 1e-10  # loop-invariant over regions
        corr = {}
        for lid in regions:
            brain_ts = region_timeseries[lid][:n_t]
            if layer_is_flat or np.std(brain_ts) < 1e-10:
                corr[lid] = 0.0
            else:
                corr[lid] = float(stats.pearsonr(lt, brain_ts)[0])
        layer_brain_corr[f"{mod}_L{l}"] = corr

# Rows: AI layers in (modality, layer index) order; columns: region IDs.
all_layer_keys = sorted(layer_brain_corr, key=_layer_sort_key)
all_region_ids = sorted(regions)
region_names_list = [regions[lid]["name"] for lid in all_region_ids]

corr_matrix = np.zeros((len(all_layer_keys), len(all_region_ids)))
for i, lk in enumerate(all_layer_keys):
    for j, lid in enumerate(all_region_ids):
        corr_matrix[i, j] = layer_brain_corr[lk].get(lid, 0)

log.info(f"Correlation matrix: {corr_matrix.shape} (AI layers x brain regions)")
|
|
| |
| |
| |
|
|
log.info("PHASE 4: Modality ablation analysis")

modalities = list(fmri_model.projectors.keys())
log.info(f"Ablating modalities: {modalities}")

# "full" = intact predictions; one extra entry per modality whose input was zeroed.
ablation_preds = {"full": brain_preds}

with torch.inference_mode():
    for mod_to_ablate in modalities:
        log.info(f" Ablating: {mod_to_ablate}")
        preds_list = []
        loader = model.data.get_loaders(events=events, split_to_build="all")["all"]
        for batch in loader:
            batch = batch.to(device)
            original = None
            if mod_to_ablate in batch.data:
                # Replace this modality's input with zeros for the forward pass.
                original = batch.data[mod_to_ablate].clone()
                batch.data[mod_to_ablate] = torch.zeros_like(original)
            y = fmri_model(batch).detach().cpu().numpy()
            y = rearrange(y, 'b v t -> (b t) v')
            preds_list.append(y)
            if original is not None:
                batch.data[mod_to_ablate] = original
        ablation_preds[mod_to_ablate] = np.concatenate(preds_list, axis=0)
        log.info(f" shape: {ablation_preds[mod_to_ablate].shape}")

# Importance of a modality for a region = mean squared change of the region's
# predictions when that modality is zeroed, normalised to sum to ~1 per region.
region_mod_importance = {}
for lid, rinfo in regions.items():
    mask = rinfo["mask"]
    full = ablation_preds["full"][:, mask]
    imp = {
        m: float(np.mean((full - ablation_preds[m][:, mask]) ** 2))
        for m in modalities
    }
    total = sum(imp.values()) + 1e-12
    region_mod_importance[lid] = {k: v / total for k, v in imp.items()}

log.info("Modality ablation done")
|
|
| |
| |
| |
|
|
log.info("PHASE 5: Representational Similarity Analysis")

# RSA works on time segments: predictions and features are averaged over
# non-overlapping windows of SEGMENT_SIZE time points, and each window becomes
# one row/column of a representational dissimilarity matrix (RDM).
SEGMENT_SIZE = 2
n_segments = divmod(T_total, SEGMENT_SIZE)[0]
|
|
| |
def build_segments(timeseries, n_segments, seg_size):
    """Average consecutive, non-overlapping windows of *timeseries*.

    Window ``k`` covers rows ``[k * seg_size, (k + 1) * seg_size)``. Windows
    that would run past the end of *timeseries* are dropped, so at most
    ``n_segments`` rows are returned.

    Args:
        timeseries: NumPy array whose first axis is time.
        n_segments: number of windows to attempt.
        seg_size: number of time points per window.

    Returns:
        ``np.array`` of window means: one scalar per window for 1-D input,
        one vector (mean over axis 0) per window otherwise.
    """
    n_time = len(timeseries)
    windows = [
        timeseries[k * seg_size:(k + 1) * seg_size]
        for k in range(n_segments)
        if (k + 1) * seg_size <= n_time
    ]
    if timeseries.ndim > 1:
        means = [w.mean(axis=0) for w in windows]
    else:
        means = [w.mean() for w in windows]
    return np.array(means)
|
|
| |
log.info("Building brain RDMs per region...")
brain_rdms = {}
for lid, rinfo in regions.items():
    # One vertex pattern per time segment (segment-averaged predictions),
    # built with the shared build_segments helper instead of a copy of it.
    seg_data = build_segments(brain_preds[:, rinfo["mask"]], n_segments, SEGMENT_SIZE)
    if len(seg_data) < 3:
        continue  # too few segments for a meaningful RDM
    if seg_data.std() < 1e-10:
        # Completely flat region: correlation distance is undefined, use zeros.
        brain_rdms[lid] = np.zeros((len(seg_data), len(seg_data)))
    else:
        brain_rdms[lid] = squareform(pdist(seg_data, metric="correlation"))
|
|
| |
log.info("Building AI feature RDMs per layer...")
ai_rdms = {}
for mod, feats in all_features.items():
    _, n_layers, _, _ = feats.shape
    # (batch, layer, dim, time) -> (batch*time, layer, dim), truncated to the
    # number of predicted time points.
    feats_flat = rearrange(feats, 'b l d t -> (b t) l d')[:T_total]

    for l in range(n_layers):
        # One feature vector per time segment, via the shared helper.
        seg_data = build_segments(feats_flat[:, l, :], n_segments, SEGMENT_SIZE)
        if len(seg_data) < 3:
            continue  # too few segments for a meaningful RDM
        key = f"{mod}_L{l}"
        if seg_data.std() < 1e-10:
            # Completely flat layer: correlation distance is undefined, use zeros.
            ai_rdms[key] = np.zeros((len(seg_data), len(seg_data)))
        else:
            ai_rdms[key] = squareform(pdist(seg_data, metric="correlation"))
|
|
| |
log.info("Computing RSA (Spearman correlation between RDMs)...")


def _rdm_key_order(key):
    """Order "<modality>_L<idx>" keys by modality, then by numeric layer index."""
    mod_part, sep, layer_part = key.rpartition("_L")
    if sep and layer_part.isdigit():
        return (mod_part, int(layer_part))
    return (key, -1)


ai_rdm_keys = sorted(ai_rdms, key=_rdm_key_order)
brain_rdm_keys = sorted(brain_rdms)
rsa_matrix = np.zeros((len(ai_rdm_keys), len(brain_rdm_keys)))

for i, ak in enumerate(ai_rdm_keys):
    ai_rdm = ai_rdms[ak]
    for j, bk in enumerate(brain_rdm_keys):
        brain_rdm = brain_rdms[bk]
        # Compare the RDMs over their common segments. Crop the *square*
        # matrices: truncating condensed vectors would pair up different
        # segment pairs whenever the two RDMs differ in size.
        k = min(len(ai_rdm), len(brain_rdm))
        upper = np.triu_indices(k, 1)
        ai_vec = ai_rdm[:k, :k][upper]
        brain_vec = brain_rdm[:k, :k][upper]
        # Correlation distance is NaN for constant segment patterns; drop those pairs.
        valid = np.isfinite(ai_vec) & np.isfinite(brain_vec)
        if valid.sum() < 3:
            continue
        rho, _ = stats.spearmanr(ai_vec[valid], brain_vec[valid])
        rsa_matrix[i, j] = 0.0 if np.isnan(rho) else rho

log.info(f"RSA matrix: {rsa_matrix.shape}")
|
|
| |
| |
| |
|
|
log.info("PHASE 6: Identifying AI-brain divergences")

# One row per region with a brain RDM. A region scores as divergent when it is
# dynamic (high temporal variance), poorly matched by every AI layer (low best
# RSA) and driven by several modalities at once (high modality entropy).
results = []
for idx_j, lid in enumerate(brain_rdm_keys):
    rinfo = regions[lid]
    mask = rinfo["mask"]
    imp = region_mod_importance.get(lid, {})

    # Best representational alignment: max |Spearman rho| over AI layers.
    abs_rsa = np.abs(rsa_matrix[:, idx_j])
    if abs_rsa.size:
        best_rsa_idx = int(np.argmax(abs_rsa))
        best_rsa = abs_rsa[best_rsa_idx]
        best_rsa_layer = ai_rdm_keys[best_rsa_idx]
    else:
        best_rsa, best_rsa_layer = 0, "none"

    # Best encoding alignment: max |Pearson r| over AI layers. Guarded because
    # corr_matrix has zero rows when no modality produced features.
    if lid in all_region_ids and corr_matrix.shape[0] > 0:
        abs_corr = np.abs(corr_matrix[:, all_region_ids.index(lid)])
        best_enc_idx = int(np.argmax(abs_corr))
        best_encoding = abs_corr[best_enc_idx]
        best_enc_layer = all_layer_keys[best_enc_idx]
    else:
        best_encoding, best_enc_layer = 0.0, "none"

    # Shannon entropy (bits) of the normalised modality-importance distribution:
    # high when several modalities contribute similarly.
    probs = np.array(list(imp.values())) + 1e-10
    probs = probs / probs.sum()
    entropy = -np.sum(probs * np.log2(probs))

    # Variance over time of the region-averaged predicted response.
    temporal_var = float(brain_preds[:, mask].mean(axis=1).var())

    divergence = temporal_var * (1 - best_rsa) * entropy

    # Hemisphere from the vertex mask (first N_VERT vertices = left hemisphere)
    # rather than from the numeric label ID, which need not encode hemisphere.
    in_lh = bool(mask[:N_VERT].any())
    in_rh = bool(mask[N_VERT:].any())
    hemisphere = "bilateral" if (in_lh and in_rh) else ("LH" if in_lh else "RH")

    results.append({
        "region_id": lid,
        "region": rinfo["name"],
        "hemisphere": hemisphere,
        "n_vertices": rinfo["n_vertices"],
        "temporal_variance": temporal_var,
        "best_rsa_alignment": best_rsa,
        "best_rsa_layer": best_rsa_layer,
        "best_encoding_corr": best_encoding,
        "best_encoding_layer": best_enc_layer,
        "modality_entropy": entropy,
        "divergence_score": divergence,
        **{f"imp_{k}": v for k, v in imp.items()},
    })

df = pd.DataFrame(results)
df = df.sort_values("divergence_score", ascending=False)
df.to_csv(OUT / "full_analysis.csv", index=False)
|
|
| |
| |
| |
|
|
log.info("PHASE 7: Generating visualizations and report")

# Plot 1: Pearson r between every AI layer (rows) and every brain region (columns).
fig, ax = plt.subplots(figsize=(20, max(8, len(all_layer_keys) * 0.3)))
im = ax.imshow(corr_matrix, aspect="auto", cmap="RdBu_r", vmin=-1, vmax=1)
ax.set_yticks(range(len(all_layer_keys)))
ax.set_yticklabels(all_layer_keys, fontsize=6)
ax.set_xticks(range(len(region_names_list)))
ax.set_xticklabels(region_names_list, fontsize=4, rotation=90)
ax.set_title("Layer-wise Encoding: AI Layer ↔ Brain Region Correlation", fontsize=14)
ax.set_ylabel("AI Encoder Layer")
ax.set_xlabel("Brain Region (Destrieux)")
plt.colorbar(im, ax=ax, label="Pearson r")
fig.tight_layout()
fig.savefig(OUT / "01_layer_brain_correlation.png", dpi=200)
plt.close(fig)
log.info("Saved 01_layer_brain_correlation.png")
|
|
| |
# Plot 2: RSA (Spearman rho between RDMs), AI layers x brain regions with an RDM.
fig, ax = plt.subplots(figsize=(20, max(8, len(ai_rdm_keys) * 0.3)))
im = ax.imshow(rsa_matrix, aspect="auto", cmap="RdBu_r", vmin=-0.5, vmax=0.5)
ax.set_yticks(range(len(ai_rdm_keys)))
ax.set_yticklabels(ai_rdm_keys, fontsize=6)
brain_region_names_rsa = [regions[lid]["name"] for lid in brain_rdm_keys]
ax.set_xticks(range(len(brain_region_names_rsa)))
ax.set_xticklabels(brain_region_names_rsa, fontsize=4, rotation=90)
ax.set_title("RSA: AI Layer ↔ Brain Region Representational Similarity", fontsize=14)
ax.set_ylabel("AI Encoder Layer")
ax.set_xlabel("Brain Region (Destrieux)")
plt.colorbar(im, ax=ax, label="Spearman ρ")
fig.tight_layout()
fig.savefig(OUT / "02_rsa_heatmap.png", dpi=200)
plt.close(fig)
log.info("Saved 02_rsa_heatmap.png")
|
|
| |
# Plot 3: x = best RSA alignment, y = temporal variance. Dynamic regions that no
# AI layer matches well (low x, high y) sit in the TOP-LEFT of this plot.
fig, ax = plt.subplots(figsize=(12, 10))
sc = ax.scatter(
    df["best_rsa_alignment"],
    df["temporal_variance"],
    s=df["n_vertices"] * 0.5,
    c=df["divergence_score"],
    cmap="YlOrRd",
    alpha=0.7,
    edgecolors="k",
    linewidths=0.3,
)
# Label the 12 most divergent regions (df is sorted by divergence_score).
for _, row in df.head(12).iterrows():
    ax.annotate(
        row["region"],
        (row["best_rsa_alignment"], row["temporal_variance"]),
        fontsize=6, alpha=0.9,
        arrowprops=dict(arrowstyle="-", alpha=0.3),
        textcoords="offset points", xytext=(5, 5),
    )
ax.set_xlabel("Best RSA Alignment (max |Spearman ρ| across AI layers)", fontsize=11)
ax.set_ylabel("Temporal Variance (brain dynamics)", fontsize=11)
ax.set_title("AI-Brain Divergence Map\nTop-left = active brain regions poorly captured by AI", fontsize=13)
fig.colorbar(sc, ax=ax, label="Divergence Score")
fig.tight_layout()
fig.savefig(OUT / "03_divergence_scatter.png", dpi=200)
plt.close(fig)
log.info("Saved 03_divergence_scatter.png")
|
|
| |
# Plot 4: stacked horizontal bars of relative modality importance for the 30
# regions with the largest temporal variance (only if importance columns exist).
imp_cols = [c for c in df.columns if c.startswith("imp_")]
if imp_cols:
    top_30 = df.nlargest(30, "temporal_variance")
    y_pos = range(len(top_30))
    fig, ax = plt.subplots(figsize=(14, 8))
    left = np.zeros(len(top_30))
    palette = plt.cm.Set2(np.linspace(0, 1, len(imp_cols)))
    for col, color in zip(imp_cols, palette):
        widths = top_30[col].values
        ax.barh(y_pos, widths, left=left, color=color,
                label=col.replace("imp_", "").upper())
        left += widths  # next modality stacks to the right of this one
    ax.set_yticks(y_pos)
    ax.set_yticklabels(top_30["region"].values, fontsize=7)
    ax.set_xlabel("Relative Modality Importance")
    ax.set_title("Modality Contribution per Brain Region (top 30 by dynamics)", fontsize=13)
    ax.legend(loc="lower right")
    fig.tight_layout()
    fig.savefig(OUT / "04_modality_importance.png", dpi=200)
    plt.close(fig)
    log.info("Saved 04_modality_importance.png")
|
|
| |
# Plot 5: divergence scores painted on the fsaverage5 pial surface. This is
# best effort: any failure is logged (with traceback) and the script goes on.
fig = None
try:
    from nilearn.plotting import plot_surf_stat_map

    # Every vertex of a region gets that region's divergence score; vertices
    # outside the analysed regions stay 0. Left hemisphere = first N_VERT.
    vertex_divergence = np.zeros(N_VERT * 2)
    for _, row in df.iterrows():
        lid = row["region_id"]
        if lid in regions:
            vertex_divergence[regions[lid]["mask"]] = row["divergence_score"]

    fig = plt.figure(figsize=(16, 12))
    views = [
        ("left", "lateral"), ("left", "medial"),
        ("right", "lateral"), ("right", "medial"),
    ]
    for idx, (hemi, view) in enumerate(views, start=1):
        ax = fig.add_subplot(2, 2, idx, projection="3d")
        data = vertex_divergence[:N_VERT] if hemi == "left" else vertex_divergence[N_VERT:]
        plot_surf_stat_map(
            fsaverage5[f"pial_{hemi}"],
            data,
            hemi=hemi,
            view=view,
            cmap="YlOrRd",
            title=f"Divergence ({hemi} {view})",
            figure=fig,
            axes=ax,
        )
    fig.suptitle("Brain Surface: AI-Brain Divergence Scores", fontsize=14, y=1.02)
    fig.tight_layout()
    fig.savefig(OUT / "05_brain_surface_divergence.png", dpi=200, bbox_inches="tight")
    log.info("Saved 05_brain_surface_divergence.png")
except Exception as e:
    log.warning(f"Brain surface plot failed: {e}", exc_info=True)
finally:
    # Close the figure on success AND failure so a failed plot does not leak it.
    if fig is not None:
        plt.close(fig)
|
|
| |
# Plot 6: one heatmap per modality, rows ordered by numeric layer index.
# The order of all_layer_keys is not relied upon: a lexicographic sort would put
# "L10" before "L2" and mislabel the "Layer i" rows.
for mod in all_features:
    mod_layers = []  # (layer index, row in corr_matrix)
    for row_idx, key in enumerate(all_layer_keys):
        key_mod, sep, layer_str = key.rpartition("_L")
        # Exact modality match (a prefix test could mix e.g. "text" and "text_x").
        if sep and key_mod == mod and layer_str.isdigit():
            mod_layers.append((int(layer_str), row_idx))
    if not mod_layers:
        continue
    mod_layers.sort()
    layer_ids = [layer for layer, _ in mod_layers]
    mod_corr = corr_matrix[[row for _, row in mod_layers], :]

    fig, ax = plt.subplots(figsize=(14, 6))
    im = ax.imshow(mod_corr, aspect="auto", cmap="RdBu_r", vmin=-1, vmax=1)
    ax.set_yticks(range(len(layer_ids)))
    ax.set_yticklabels([f"Layer {i}" for i in layer_ids], fontsize=8)
    ax.set_xticks(range(len(region_names_list)))
    ax.set_xticklabels(region_names_list, fontsize=4, rotation=90)
    ax.set_title(f"{mod.upper()} Encoder: Layer-wise Brain Alignment", fontsize=13)
    ax.set_ylabel("Encoder Layer (early → late)")
    ax.set_xlabel("Brain Region")
    plt.colorbar(im, ax=ax, label="Pearson r")
    fig.tight_layout()
    fig.savefig(OUT / f"06_{mod}_layer_alignment.png", dpi=200)
    plt.close(fig)
    log.info(f"Saved 06_{mod}_layer_alignment.png")
|
|
| |
| |
| |
|
|
# =============================================================================
# Text report: printed to stdout and written to OUT / "report.txt"
# =============================================================================
report = []
report.append("=" * 100)
report.append("CLOSING THE AI-BRAIN LOOP: Analysis Report")
report.append(f"Generated: {pd.Timestamp.now()}")
report.append("=" * 100)

report.append("\n\n--- DATASET ---")
report.append(f"Stimulus: {VIDEO}")
report.append(f"Total time points: {T_total}")
report.append(f"Brain vertices: {V} (fsaverage5)")
report.append(f"Brain regions analyzed: {len(regions)} (Destrieux atlas)")
report.append(f"AI modalities: {modalities}")
for mod, feats in all_features.items():
    report.append(f" {mod}: {feats.shape[1]} layers, {feats.shape[2]}-dim features")

# df is already sorted by divergence_score (descending).
report.append("\n\n--- TOP 15 DIVERGENCE REGIONS ---")
report.append("(Brain regions with high dynamics but poor AI alignment)")
report.append("")
cols = ["region", "temporal_variance", "best_rsa_alignment", "best_rsa_layer",
        "modality_entropy", "divergence_score"]
cols_present = [c for c in cols if c in df.columns]
report.append(df[cols_present].head(15).to_string(index=False, float_format="%.4f"))

report.append("\n\n--- TOP 15 WELL-ALIGNED REGIONS ---")
report.append("(Brain regions where AI encoders match brain representations well)")
well_aligned = df.nlargest(15, "best_rsa_alignment")
report.append(well_aligned[cols_present].to_string(index=False, float_format="%.4f"))

report.append("\n\n--- MODALITY DOMINANCE ---")
for mod in modalities:
    col = f"imp_{mod}"
    if col not in df.columns:
        continue
    report.append(f"\n{mod.upper()}-dominated regions (top 5):")
    top = df.nlargest(5, col)
    for _, row in top.iterrows():
        report.append(f" {row['region']:45s} importance={row[col]:.4f} rsa={row['best_rsa_alignment']:.4f}")

# Static interpretation guide (not computed from the data above).
report.append("\n\n--- ARCHITECTURAL IMPLICATIONS ---")
report.append("""
Based on the divergence analysis, here are the gaps in current AI architectures
and proposed solutions:

1. HIGH-ENTROPY DIVERGENCE REGIONS (multiple modalities contribute equally,
but overall alignment is poor):
→ The brain performs CROSS-MODAL INTEGRATION that the concatenation-based
fusion in TRIBE v2 (and most multimodal AI) doesn't capture.
→ Proposed fix: EARLY FUSION with cross-attention between modality streams
at intermediate layers, not just late concatenation.

2. HIGH TEMPORAL VARIANCE + LOW ALIGNMENT:
→ The brain has strong TEMPORAL DYNAMICS (prediction, memory, feedback loops)
that feedforward AI encoders miss entirely.
→ Proposed fix: Add RECURRENT connections or PREDICTIVE CODING layers that
generate top-down predictions and propagate prediction errors.

3. REGIONS WHERE NO SINGLE LAYER ALIGNS WELL:
→ The brain's computation in these areas may involve representations that
DON'T EXIST in any layer of V-JEPA2, LLaMA, or Wav2Vec-BERT.
→ Proposed fix: Train a NEW encoder objective that explicitly optimizes for
brain alignment in these gap regions (brain-guided contrastive learning).

4. LAYER DEPTH PATTERNS:
→ If early AI layers align with sensory cortex and late layers align with
association cortex, this confirms the HIERARCHICAL CORRESPONDENCE between
DNNs and the cortical hierarchy.
→ Where this breaks (e.g., late layers don't align with prefrontal cortex),
it suggests the model lacks EXECUTIVE/ABSTRACT processing.
""")

report.append("\n--- FILES ---")
report.append(f"Full CSV: {OUT / 'full_analysis.csv'}")
report.append(f"Plots: {OUT / '*.png'}")
report.append("=" * 100)

report_text = "\n".join(report)
print(report_text)

# Explicit UTF-8: the report contains non-ASCII characters (arrows), which would
# fail to encode under a non-UTF-8 default locale.
with open(OUT / "report.txt", "w", encoding="utf-8") as f:
    f.write(report_text)

log.info(f"All results saved to {OUT}")
log.info("Done.")
|
|