"""
Training script for 2-stage fMRI encoding with Flow Matching.

Stage 1: Train MultiSubjectConvLinearEncoder (Mean Anchor)
Stage 2: Train Conditional Flow Matching (Neural Vector Field) per subject.
"""

import argparse
import json
import math
import sys
import time
from pathlib import Path
from typing import Dict, Any, Optional

import numpy as np
import torch
import torch.nn as nn
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader
from timm.utils import AverageMeter, random_seed

from .visualize import plot_loss_curve
from .data import (
    Algonauts2025Dataset,
    load_algonauts2025_friends_fmri,
    load_algonauts2025_movie10_fmri,
    load_sharded_features,
    episode_filter,
)
from .stage1.medarc_architecture import MultiSubjectConvLinearEncoder
from .stage2.CFM import CFM
from .metric import pearsonr_score

# Used when cfg.datasets_root is empty; adjust based on workspace.
DEFAULT_DATA_DIR = Path("/raid/lttung05/fmri_encoder/data")
SUBJECTS = (1, 2, 3, 5)


def load_features(cfg: DictConfig, model: str, layer: str) -> dict[str, np.ndarray]:
    """Load one (model, layer) feature set for both the Friends and Movie10 series.

    Features are read from ``<datasets_root>/features`` (``DEFAULT_DATA_DIR``
    when ``cfg.datasets_root`` is empty). The two per-series dicts are merged;
    on a duplicate key the Movie10 entry wins.
    """
    data_dir = Path(cfg.datasets_root or DEFAULT_DATA_DIR)
    feat_root = data_dir / "features"
    friends_features = load_sharded_features(
        feat_root, model=model, layer=layer, series="friends"
    )
    movie10_features = load_sharded_features(
        feat_root, model=model, layer=layer, series="movie10"
    )
    return {**friends_features, **movie10_features}


def pool_features(features: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Average every 3-D feature array over axis 1; pass 2-D arrays through.

    Raises:
        ValueError: if any array is neither 2-D nor 3-D.
    """
    pooled = {}
    for key, feat in features.items():
        if feat.ndim not in (2, 3):
            raise ValueError(
                f"feature {key!r} has ndim={feat.ndim}; expected 2 or 3"
            )
        if feat.ndim == 3:
            feat = feat.mean(axis=1)
        pooled[key] = feat
    return pooled


def make_data_loaders(cfg: DictConfig) -> dict[str, DataLoader]:
    """Build one DataLoader per entry in ``cfg.datasets``.

    All fMRI (Friends + Movie10) and every feature set listed in
    ``cfg.include_features`` are loaded once. Each dataset then selects its
    episodes through its ``filter`` config. The "train" loader uses
    ``cfg.batch_size``; every other loader uses batch size 1.
    """
    print("loading fmri data")
    data_dir = Path(cfg.datasets_root or DEFAULT_DATA_DIR)
    subjects = cfg.get("subjects", SUBJECTS)
    friends_fmri = load_algonauts2025_friends_fmri(
        data_dir / "algonauts_2025.competitors", subjects=subjects
    )
    movie10_fmri = load_algonauts2025_movie10_fmri(
        data_dir / "algonauts_2025.competitors", subjects=subjects
    )
    all_fmri = {**friends_fmri, **movie10_fmri}
    all_episodes = list(all_fmri)

    all_features = []
    for feat_name in cfg.include_features:
        # Entries are "<model key>/<layer key>", resolved through cfg.features.
        model, layer = feat_name.split("/")
        feat_cfg = cfg.features[model]
        model_name = feat_cfg.model
        layer_name = feat_cfg.layers[layer]
        print(f"loading features {feat_name} ({model_name}/{layer_name})")
        features = load_features(cfg, model_name, layer_name)
        if cfg.stage1.model.global_pool == "avg":
            features = pool_features(features)
        all_features.append(features)

    data_loaders = {}
    for ds_name, ds_cfg in cfg.datasets.items():
        print(f"loading dataset: {ds_name}\n\n{OmegaConf.to_yaml(ds_cfg)}")
        # Copy before popping so the shared config is not mutated.
        ds_cfg = ds_cfg.copy()
        filter_cfg = ds_cfg.pop("filter")
        filter_fn = episode_filter(**filter_cfg)
        ds_episodes = list(filter(filter_fn, all_episodes))

        dataset = Algonauts2025Dataset(
            episode_list=ds_episodes,
            fmri_data=all_fmri,
            feat_data=all_features,
            **ds_cfg,
        )
        batch_size = cfg.batch_size if ds_name == "train" else 1
        data_loaders[ds_name] = DataLoader(dataset, batch_size=batch_size)
    return data_loaders


def train_one_epoch_condition(
    *,
    epoch: int,
    model: torch.nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
):
    """Run one Stage 1 epoch, regressing fMRI directly with an MSE loss.

    Returns:
        The sample-weighted average training loss for the epoch.

    Raises:
        RuntimeError: on a NaN/Inf loss. The check runs before backward, so
            the weights are not updated with a non-finite gradient.
    """
    model.train()
    use_cuda = device.type == "cuda"
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    criterion = nn.MSELoss()
    loss_m = AverageMeter()
    data_time_m = AverageMeter()
    step_time_m = AverageMeter()

    end = time.monotonic()
    for batch_idx, batch in enumerate(train_loader):
        feats = [f.to(device) for f in batch["features"]]
        fmri = batch["fmri"].to(device)  # (B, S, T, V)
        batch_size = fmri.size(0)
        data_time = time.monotonic() - end

        pred = model(feats)  # (B, S, T, V)
        loss = criterion(pred, fmri)
        loss_item = loss.item()
        if not math.isfinite(loss_item):
            raise RuntimeError(
                f"NaN/Inf loss encountered on step {batch_idx + 1}; exiting"
            )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Synchronize so the measured step time includes queued GPU work.
        if use_cuda:
            torch.cuda.synchronize()
        step_time = time.monotonic() - end

        loss_m.update(loss_item, batch_size)
        data_time_m.update(data_time, batch_size)
        step_time_m.update(step_time, batch_size)

        if (batch_idx + 1) % 20 == 0:
            tput = batch_size / step_time_m.avg
            if use_cuda:
                alloc_mem_gb = torch.cuda.max_memory_allocated() / 1e9
                res_mem_gb = torch.cuda.max_memory_reserved() / 1e9
            else:
                alloc_mem_gb = res_mem_gb = 0.0
            print(
                f"Stage 1 Train: {epoch:>3d} [{batch_idx:>3d}]"
                f" Loss: {loss_m.val:#.3g} ({loss_m.avg:#.3g})"
                f" Time: {data_time_m.avg:.3f},{step_time_m.avg:.3f} {tput:.0f}/s"
                f" Mem: {alloc_mem_gb:.2f},{res_mem_gb:.2f} GB"
            )

        end = time.monotonic()

    return loss_m.avg


def train_one_epoch_flow_matching(
    *,
    epoch: int,
    stage1_model: torch.nn.Module,
    stage2_models: nn.ModuleDict,  # subject_id -> CFM
    train_loader: DataLoader,
    optimizers: Dict[str, torch.optim.Optimizer],
    device: torch.device,
    subjects: list,
):
    """Run one Stage 2 epoch, training one CFM per subject.

    The frozen Stage 1 model produces the conditioning mean ``mu``. Subject
    ``subjects[i]`` is read from index ``i`` of the fMRI subject axis.

    Returns:
        The sample-weighted average (over subjects) CFM loss for the epoch.

    Raises:
        RuntimeError: on a NaN/Inf loss for any subject. The check runs before
            that subject's backward/step, so its weights are not corrupted.
    """
    stage1_model.eval()
    for model in stage2_models.values():
        model.train()

    use_cuda = device.type == "cuda"
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    loss_m = AverageMeter()
    data_time_m = AverageMeter()
    step_time_m = AverageMeter()

    end = time.monotonic()
    for batch_idx, batch in enumerate(train_loader):
        feats = [f.to(device) for f in batch["features"]]
        fmri = batch["fmri"].to(device)  # (B, S, T, V)
        batch_size = fmri.size(0)
        data_time = time.monotonic() - end

        # Mean anchor from the frozen Stage 1 model.
        with torch.no_grad():
            mu_anchor = stage1_model(feats)  # (B, S, T, V)

        batch_loss = 0.0
        for i, sub in enumerate(subjects):
            sub_key = str(sub)
            cfm = stage2_models[sub_key]
            optimizer = optimizers[sub_key]

            # CFM expects channels-first (B, C, T): (B, T, V) -> (B, V, T).
            x1 = fmri[:, i].transpose(1, 2)
            mu = mu_anchor[:, i].transpose(1, 2)

            # x1 = target, mu = condition.
            loss, _ = cfm.compute_loss(x1, mu)
            sub_loss = loss.item()
            if not math.isfinite(sub_loss):
                raise RuntimeError(
                    f"NaN/Inf loss encountered on step {batch_idx + 1}"
                    f" (subject {sub}); exiting"
                )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_loss += sub_loss

        loss_item = batch_loss / len(subjects)

        if use_cuda:
            torch.cuda.synchronize()
        step_time = time.monotonic() - end

        loss_m.update(loss_item, batch_size)
        data_time_m.update(data_time, batch_size)
        step_time_m.update(step_time, batch_size)

        if (batch_idx + 1) % 20 == 0:
            tput = batch_size / step_time_m.avg
            if use_cuda:
                alloc_mem_gb = torch.cuda.max_memory_allocated() / 1e9
                res_mem_gb = torch.cuda.max_memory_reserved() / 1e9
            else:
                alloc_mem_gb = res_mem_gb = 0.0
            print(
                f"Stage 2 Train: {epoch:>3d} [{batch_idx:>3d}]"
                f" Loss: {loss_m.val:#.3g} ({loss_m.avg:#.3g})"
                f" Time: {data_time_m.avg:.3f},{step_time_m.avg:.3f} {tput:.0f}/s"
                f" Mem: {alloc_mem_gb:.2f},{res_mem_gb:.2f} GB"
            )

        end = time.monotonic()

    return loss_m.avg


def _to_subject_major(x: torch.Tensor, num_subjects: int) -> np.ndarray:
    """Convert an (N, S, T, V) tensor to an (S, N * T, V) NumPy array.

    Raises:
        ValueError: if the subject axis (dim 1) does not have
            ``num_subjects`` entries; otherwise per-subject metrics would be
            silently misaligned.
    """
    n, s, t, v = x.shape
    if s != num_subjects:
        raise ValueError(
            f"batch has {s} subjects on dim 1 but {num_subjects} were requested"
        )
    return x.cpu().numpy().swapaxes(0, 1).reshape((s, n * t, v))


def _encoding_metrics(
    samples: np.ndarray, outputs: np.ndarray, subjects: list
) -> tuple[float, dict]:
    """Compute per-subject and subject-averaged voxelwise Pearson r.

    Args:
        samples: ground truth of shape (S, total_T, V).
        outputs: predictions of the same shape.
        subjects: subject ids, in the same order as axis 0.

    Returns:
        ``(acc_avg, metrics)``. ``metrics`` holds ``accmap_sub-<id>`` (r per
        voxel), ``acc_sub-<id>`` (mean r), ``accmap_avg`` and ``acc_avg``.
    """
    metrics = {}
    dim = samples.shape[-1]
    acc = 0.0
    acc_map = np.zeros(dim)
    for ii, sub in enumerate(subjects):
        y_true = samples[ii].reshape(-1, dim)
        y_pred = outputs[ii].reshape(-1, dim)
        acc_map_i = pearsonr_score(y_true, y_pred)
        acc_i = np.mean(acc_map_i)
        metrics[f"accmap_sub-{sub}"] = acc_map_i
        metrics[f"acc_sub-{sub}"] = acc_i
        acc_map += acc_map_i / len(subjects)
        acc += acc_i / len(subjects)
    metrics["accmap_avg"] = acc_map
    metrics["acc_avg"] = acc
    return acc, metrics


def _format_subject_accs(metrics: dict) -> str:
    """Join the ``acc_sub-*`` scores as comma-separated 3-decimal values."""
    return ",".join(
        f"{val:.3f}" for key, val in metrics.items() if key.startswith("acc_sub-")
    )


def _concat_or_fail(chunks: list, ds_name: str) -> np.ndarray:
    """Concatenate (S, T_i, V) chunks along time, with a clear empty-loader error."""
    if not chunks:
        raise ValueError(f"validation loader {ds_name!r} yielded no batches")
    return np.concatenate(chunks, axis=1)


@torch.no_grad()
def evaluate_stage1(
    *,
    epoch: int,
    model: torch.nn.Module,
    val_loader: DataLoader,
    device: torch.device,
    subjects: list,
    ds_name: str = "val",
):
    """Evaluate the Stage 1 encoder on one loader.

    Returns:
        ``(acc_avg, metrics)``, as described in ``_encoding_metrics``.
    """
    model.eval()
    criterion = nn.MSELoss()
    loss_m = AverageMeter()
    samples = []
    outputs = []

    for batch in val_loader:
        feats = [f.to(device) for f in batch["features"]]
        fmri = batch["fmri"].to(device)
        pred = model(feats)
        loss = criterion(pred, fmri)
        loss_m.update(loss.item(), fmri.size(0))

        outputs.append(_to_subject_major(pred, len(subjects)))
        samples.append(_to_subject_major(fmri, len(subjects)))

    outputs = _concat_or_fail(outputs, ds_name)
    samples = _concat_or_fail(samples, ds_name)

    acc, metrics = _encoding_metrics(samples, outputs, subjects)
    accs_fmt = _format_subject_accs(metrics)
    print(
        f"Evaluate Stage 1 ({ds_name}): {epoch:>3d}"
        f" Loss: {loss_m.avg:#.3g}"
        f" Acc: {accs_fmt} ({acc:.3f})"
    )
    return acc, metrics


@torch.no_grad()
def evaluate_stage2(
    *,
    epoch: int,
    stage1_model: torch.nn.Module,
    stage2_models: nn.ModuleDict,
    val_loader: DataLoader,
    device: torch.device,
    subjects: list,
    ds_name: str = "val",
    n_timesteps: int = 10,
):
    """Evaluate the full two-stage model on one loader.

    For each subject, the CFM is run on the Stage 1 mean anchor with
    ``n_timesteps`` integration steps, and the result is scored against the
    recorded fMRI.

    Returns:
        ``(acc_avg, metrics)``, as described in ``_encoding_metrics``.
    """
    stage1_model.eval()
    for model in stage2_models.values():
        model.eval()

    samples = []
    outputs = []

    for batch in val_loader:
        feats = [f.to(device) for f in batch["features"]]
        fmri = batch["fmri"].to(device)
        mu_anchor = stage1_model(feats)

        batch_preds = []
        for i, sub in enumerate(subjects):
            cfm = stage2_models[str(sub)]
            mu = mu_anchor[:, i].transpose(1, 2)
            pred = cfm(mu, n_timesteps=n_timesteps)  # (B, V, T)
            batch_preds.append(pred.transpose(1, 2).unsqueeze(1))  # (B, 1, T, V)
        pred_combined = torch.cat(batch_preds, dim=1)  # (B, S, T, V)

        outputs.append(_to_subject_major(pred_combined, len(subjects)))
        samples.append(_to_subject_major(fmri, len(subjects)))

    outputs = _concat_or_fail(outputs, ds_name)
    samples = _concat_or_fail(samples, ds_name)

    acc, metrics = _encoding_metrics(samples, outputs, subjects)
    accs_fmt = _format_subject_accs(metrics)
    print(f"Evaluate Stage 2 ({ds_name}): {epoch:>3d}" f" Acc: {accs_fmt} ({acc:.3f})")
    return acc, metrics


def main():
    """Train Stage 1 (mean-anchor encoder), then Stage 2 (per-subject CFMs).

    The config path comes from ``--cfg-path``. Artifacts (the config copy,
    checkpoints and loss-curve plots) are written to ``cfg.out_dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cfg-path", type=str, default="config.yml")
    args = parser.parse_args()

    cfg = OmegaConf.load(args.cfg_path)
    print("Config loaded:\n", OmegaConf.to_yaml(cfg))

    out_dir = Path(cfg.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    OmegaConf.save(cfg, out_dir / "config.yaml")

    random_seed(cfg.seed)
    device = torch.device(cfg.device)

    # --- Data Loading ---
    data_loaders = make_data_loaders(cfg)
    train_loader = data_loaders["train"]
    val_loaders = {name: ld for name, ld in data_loaders.items() if name != "train"}

    # --- Model Setup: Stage 1 ---
    print("Creating Stage 1 Model (Encoder)...")
    # Infer per-feature input dims (and later the voxel count) from one batch.
    sample_batch = next(iter(train_loader))
    feat_dims = [f.shape[-1] for f in sample_batch["features"]]
    subjects_list = cfg.get("subjects", SUBJECTS)

    stage1_model = MultiSubjectConvLinearEncoder(
        num_subjects=len(subjects_list),
        feat_dims=feat_dims,
        **cfg.stage1.model,
    ).to(device)

    optimizer1 = torch.optim.AdamW(
        stage1_model.parameters(),
        lr=cfg.stage1.lr,
        weight_decay=cfg.stage1.weight_decay,
    )

    # --- Training Loop: Stage 1 ---
    print("--- Starting Stage 1 Training (Mean Anchor) ---")
    best_ckpt = out_dir / "stage1_best.pt"
    best_score_s1 = -1.0
    saved_best = False
    stage1_train_losses = []
    stage1_val_accs = []

    for epoch in range(cfg.stage1.epochs):
        train_loss = train_one_epoch_condition(
            epoch=epoch,
            model=stage1_model,
            train_loader=train_loader,
            optimizer=optimizer1,
            device=device,
        )
        stage1_train_losses.append(train_loss)

        # Only the loader named cfg.val_set_name drives model selection.
        val_acc = None
        for name, loader in val_loaders.items():
            acc, _ = evaluate_stage1(
                epoch=epoch,
                model=stage1_model,
                val_loader=loader,
                device=device,
                subjects=subjects_list,
                ds_name=name,
            )
            if name == cfg.val_set_name:
                val_acc = acc

        stage1_val_accs.append(val_acc if val_acc is not None else 0.0)

        if val_acc is not None and val_acc > best_score_s1:
            best_score_s1 = val_acc
            torch.save(stage1_model.state_dict(), best_ckpt)
            saved_best = True
            print("Saved best Stage 1 model.")

    if not saved_best:
        # No epoch produced a usable validation score (val_set_name missing,
        # NaN accuracy, or zero epochs). Save the final weights so we never
        # reload a stale checkpoint left in out_dir by an earlier run.
        print(
            "WARNING: no Stage 1 checkpoint was selected on validation;"
            " saving final weights as stage1_best.pt"
        )
        torch.save(stage1_model.state_dict(), best_ckpt)

    plot_loss_curve(
        stage1_train_losses,
        stage1_val_accs,
        out_dir,
        filename="stage1_loss_curve.png",
        prefix="Stage 1",
    )
    print(f"Stage 1 Training Complete. Best model at Pearson's r {best_score_s1}")

    # Reload the selected Stage 1 weights; Stage 2 keeps them frozen.
    stage1_model.load_state_dict(torch.load(best_ckpt, map_location=device))
    stage1_model.eval()

    # --- Model Setup: Stage 2 ---
    print("Creating Stage 2 Models (Flow Matching)...")
    stage2_models = nn.ModuleDict()
    optimizers2 = {}

    # Target (voxel) dimension V, taken from the fMRI tensor's last axis.
    target_dim = sample_batch["fmri"].shape[-1]

    cfm_params = cfg.stage2.cfm
    velocity_net_params = cfg.stage2.velocity_net
    source_ve_params = cfg.stage2.source_ve
    transport_params = cfg.stage2.transport

    for sub in subjects_list:
        sub_key = str(sub)
        # One CFM (neural vector field) per subject.
        cfm_model = CFM(
            feat_dim=target_dim,
            cfm_params=cfm_params,
            velocity_net_params=velocity_net_params,
            source_ve_params=source_ve_params,
            transport_params=transport_params,
        ).to(device)
        stage2_models[sub_key] = cfm_model
        optimizers2[sub_key] = torch.optim.AdamW(
            cfm_model.parameters(),
            lr=cfg.stage2.lr,
            weight_decay=cfg.stage2.weight_decay,
        )

    # --- Training Loop: Stage 2 ---
    print("--- Starting Stage 2 Training (Vector Fields) ---")
    stage2_train_losses = []
    for epoch in range(cfg.stage2.epochs):
        train_loss = train_one_epoch_flow_matching(
            epoch=epoch,
            stage1_model=stage1_model,
            stage2_models=stage2_models,
            train_loader=train_loader,
            optimizers=optimizers2,
            device=device,
            subjects=subjects_list,
        )
        stage2_train_losses.append(train_loss)

        # Checkpoint every 5 epochs and after the final epoch.
        if epoch % 5 == 0 or epoch == cfg.stage2.epochs - 1:
            ckpt_path = out_dir / f"stage2_epoch_{epoch}.pt"
            torch.save(stage2_models.state_dict(), ckpt_path)
            print(f"Saved Stage 2 checkpoint to {ckpt_path}")

    print("Evaluating final Stage 2 model...")
    for name, loader in val_loaders.items():
        evaluate_stage2(
            epoch=cfg.stage2.epochs,
            stage1_model=stage1_model,
            stage2_models=stage2_models,
            val_loader=loader,
            device=device,
            subjects=subjects_list,
            ds_name=name,
            n_timesteps=cfg.stage2.get("n_timesteps", 25),
        )

    plot_loss_curve(
        stage2_train_losses,
        out_path=out_dir,
        filename="stage2_loss_curve.png",
        prefix="Stage 2",
    )
    print("Done! All training complete.")


if __name__ == "__main__":
    main()