# flow-matching/src/training.py
# (file header: uploaded via huggingface_hub, revision 4edc9aa)
"""
Training script for 2-stage fMRI encoding with Flow Matching.
Stage 1: Train MultiSubjectConvLinearEncoder (Mean Anchor)
Stage 2: Train Conditional Flow Matching (Neural Vector Field) per subject.
"""
import argparse
import json
import math
import sys
import time
from pathlib import Path
from typing import Dict, Any, Optional
import numpy as np
import torch
import torch.nn as nn
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader
from timm.utils import AverageMeter, random_seed
from .visualize import plot_loss_curve
from .data import (
Algonauts2025Dataset,
load_algonauts2025_friends_fmri,
load_algonauts2025_movie10_fmri,
load_sharded_features,
episode_filter,
)
from .stage1.medarc_architecture import MultiSubjectConvLinearEncoder
from .stage2.CFM import CFM
from .metric import pearsonr_score
# DEFAULT_DATA_DIR = ROOT.parent / "algonauts2025/datasets" # Adjust based on workspace
# Fallback dataset root, used when cfg.datasets_root is unset.
DEFAULT_DATA_DIR = Path("/raid/lttung05/fmri_encoder/data")
# Default Algonauts 2025 subject IDs; overridable via cfg.subjects.
SUBJECTS = (1, 2, 3, 5)
def load_features(cfg: DictConfig, model: str, layer: str) -> dict[str, np.ndarray]:
    """Load sharded features for both series and merge them into one dict.

    Reads the "friends" and "movie10" feature shards from
    ``<datasets_root>/features`` for the given model/layer and returns a
    single episode-keyed mapping.
    """
    feature_dir = Path(cfg.datasets_root or DEFAULT_DATA_DIR) / "features"
    merged: dict[str, np.ndarray] = {}
    for series in ("friends", "movie10"):
        merged.update(
            load_sharded_features(feature_dir, model=model, layer=layer, series=series)
        )
    return merged
def pool_features(features: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Average-pool 3D feature arrays over their middle axis.

    2D arrays are passed through unchanged; any other rank fails the assert.
    """
    pooled: dict[str, np.ndarray] = {}
    for name, arr in features.items():
        assert arr.ndim in {2, 3}
        pooled[name] = arr.mean(axis=1) if arr.ndim == 3 else arr
    return pooled
def make_data_loaders(cfg: DictConfig) -> dict[str, DataLoader]:
    """Build one DataLoader per split defined in ``cfg.datasets``.

    Loads fMRI responses for all configured subjects, loads (and optionally
    average-pools) every stimulus feature listed in ``cfg.include_features``,
    then filters episodes per split and wraps each split in a DataLoader.
    The "train" split uses ``cfg.batch_size``; all other splits use batch
    size 1.
    """
    print("loading fmri data")
    data_dir = Path(cfg.datasets_root or DEFAULT_DATA_DIR)
    subjects = cfg.get("subjects", SUBJECTS)
    friends_fmri = load_algonauts2025_friends_fmri(
        data_dir / "algonauts_2025.competitors", subjects=subjects
    )
    movie10_fmri = load_algonauts2025_movie10_fmri(
        data_dir / "algonauts_2025.competitors", subjects=subjects
    )
    # Merge both series into one episode-keyed dict.
    all_fmri = {**friends_fmri, **movie10_fmri}
    all_episodes = list(all_fmri)
    all_features = []
    for feat_name in cfg.include_features:
        # feat_name is "<model>/<layer>"; resolve the concrete names via config.
        model, layer = feat_name.split("/")
        feat_cfg = cfg.features[model]
        model_name = feat_cfg.model
        layer_name = feat_cfg.layers[layer]
        print(f"loading features {feat_name} ({model_name}/{layer_name})")
        features = load_features(cfg, model_name, layer_name)
        if cfg.stage1.model.global_pool == "avg":
            # Collapse 3D features to 2D by averaging the middle axis.
            features = pool_features(features)
        all_features.append(features)
    data_loaders = {}
    for ds_name, ds_cfg in cfg.datasets.items():
        print(f"loading dataset: {ds_name}\n\n{OmegaConf.to_yaml(ds_cfg)}")
        # Copy so pop() below does not mutate the shared config object.
        ds_cfg = ds_cfg.copy()
        filter_cfg = ds_cfg.pop("filter")
        filter_fn = episode_filter(**filter_cfg)
        ds_episodes = list(filter(filter_fn, all_episodes))
        # print(f"episodes: {ds_name}:\n\n{ds_episodes}")
        dataset = Algonauts2025Dataset(
            episode_list=ds_episodes,
            fmri_data=all_fmri,
            feat_data=all_features,
            **ds_cfg,
        )
        # Evaluation splits use batch size 1 (the eval code asserts this).
        batch_size = cfg.batch_size if ds_name == "train" else 1
        loader = DataLoader(dataset, batch_size=batch_size)
        data_loaders[ds_name] = loader
    return data_loaders
def train_one_epoch_condition(
    *,
    epoch: int,
    model: torch.nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
):
    """Run one Stage-1 epoch: fit the multi-subject encoder with MSE.

    Args:
        epoch: Epoch index, used only for log lines.
        model: Encoder mapping a list of feature tensors to (B, S, T, V) fMRI.
        train_loader: Yields dicts with "features" (list of tensors) and "fmri".
        optimizer: Optimizer over ``model``'s parameters.
        device: Target device; CUDA enables memory stats and sync for timing.

    Returns:
        The epoch-average MSE loss (sample-weighted).

    Raises:
        RuntimeError: if the loss becomes NaN/Inf.
    """
    model.train()
    use_cuda = device.type == "cuda"
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
    # Build the criterion once instead of re-instantiating nn.MSELoss() per batch.
    criterion = nn.MSELoss()
    loss_m = AverageMeter()
    data_time_m = AverageMeter()
    step_time_m = AverageMeter()
    end = time.monotonic()
    for batch_idx, batch in enumerate(train_loader):
        feats = [f.to(device) for f in batch["features"]]
        fmri = batch["fmri"].to(device)  # (B, S, T, V) = [16, 4, 64, 1000]
        batch_size = fmri.size(0)
        data_time = time.monotonic() - end
        pred = model(feats)  # (B, S, T, V)
        loss = criterion(pred, fmri)
        loss_item = loss.item()
        # Fail fast: abort before backward so no NaN gradients are applied.
        if math.isnan(loss_item) or math.isinf(loss_item):
            raise RuntimeError(
                f"NaN/Inf loss encountered on step {batch_idx + 1}; exiting"
            )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if use_cuda:
            # Sync so step_time measures actual GPU work, not just kernel launch.
            torch.cuda.synchronize()
        step_time = time.monotonic() - end
        loss_m.update(loss_item, batch_size)
        data_time_m.update(data_time, batch_size)
        step_time_m.update(step_time, batch_size)
        if (batch_idx + 1) % 20 == 0:
            tput = batch_size / step_time_m.avg
            if use_cuda:
                alloc_mem_gb = torch.cuda.max_memory_allocated() / 1e9
                res_mem_gb = torch.cuda.max_memory_reserved() / 1e9
            else:
                alloc_mem_gb = res_mem_gb = 0.0
            print(
                f"Stage 1 Train: {epoch:>3d} [{batch_idx:>3d}]"
                f" Loss: {loss_m.val:#.3g} ({loss_m.avg:#.3g})"
                f" Time: {data_time_m.avg:.3f},{step_time_m.avg:.3f} {tput:.0f}/s"
                f" Mem: {alloc_mem_gb:.2f},{res_mem_gb:.2f} GB"
            )
        end = time.monotonic()
    return loss_m.avg
def train_one_epoch_flow_matching(
    *,
    epoch: int,
    stage1_model: torch.nn.Module,
    stage2_models: nn.ModuleDict,  # subject_id -> CFM
    train_loader: DataLoader,
    optimizers: Dict[str, torch.optim.Optimizer],
    device: torch.device,
    subjects: list,
):
    """Run one Stage-2 epoch: train one CFM vector field per subject.

    The frozen Stage-1 encoder provides the mean anchor (condition); each
    subject's CFM is optimized independently on its own slice of the batch.

    Returns:
        The epoch-average loss (mean over subjects, batch-weighted).

    Raises:
        RuntimeError: if any subject's CFM loss is NaN/Inf. Checked *before*
        backward/step, so a bad loss never corrupts weights (the original
        only checked after all subjects had already been updated).
    """
    stage1_model.eval()
    for model in stage2_models.values():
        model.train()
    use_cuda = device.type == "cuda"
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
    loss_m = AverageMeter()
    data_time_m = AverageMeter()
    step_time_m = AverageMeter()
    end = time.monotonic()
    for batch_idx, batch in enumerate(train_loader):
        feats = [f.to(device) for f in batch["features"]]
        fmri = batch["fmri"].to(device)  # (B, S, T, V)
        batch_size = fmri.size(0)
        data_time = time.monotonic() - end
        # Get Mean Anchor from Stage 1 (frozen — no gradients needed).
        with torch.no_grad():
            mu_anchor = stage1_model(feats)  # (B, S, T, V)
        batch_loss = 0.0
        # Train each per-subject vector field on its own slice.
        for i, sub in enumerate(subjects):
            sub_key = str(sub)
            cfm = stage2_models[sub_key]
            optimizer = optimizers[sub_key]
            # CFM expects (B, C, T): transpose (B, T, V) -> (B, V, T).
            x1 = fmri[:, i].transpose(1, 2)
            mu = mu_anchor[:, i].transpose(1, 2)
            # CFM loss: x1 = target sample, mu = condition.
            loss, _ = cfm.compute_loss(x1, mu)
            sub_loss = loss.item()
            # Fail fast before the optimizer step.
            if math.isnan(sub_loss) or math.isinf(sub_loss):
                raise RuntimeError(
                    f"NaN/Inf loss encountered on step {batch_idx + 1}; exiting"
                )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_loss += sub_loss
        loss_item = batch_loss / len(subjects)
        if use_cuda:
            torch.cuda.synchronize()
        step_time = time.monotonic() - end
        # Use batch_size consistently (original mixed batch_size and fmri.size(0)).
        loss_m.update(loss_item, batch_size)
        data_time_m.update(data_time, batch_size)
        step_time_m.update(step_time, batch_size)
        if (batch_idx + 1) % 20 == 0:
            tput = batch_size / step_time_m.avg
            if use_cuda:
                alloc_mem_gb = torch.cuda.max_memory_allocated() / 1e9
                res_mem_gb = torch.cuda.max_memory_reserved() / 1e9
            else:
                alloc_mem_gb = res_mem_gb = 0.0
            print(
                f"Stage 2 Train: {epoch:>3d} [{batch_idx:>3d}]"
                f" Loss: {loss_m.val:#.3g} ({loss_m.avg:#.3g})"
                f" Time: {data_time_m.avg:.3f},{step_time_m.avg:.3f} {tput:.0f}/s"
                f" Mem: {alloc_mem_gb:.2f},{res_mem_gb:.2f} GB"
            )
        end = time.monotonic()
    return loss_m.avg
@torch.no_grad()
def evaluate_stage1(
    *,
    epoch: int,
    model: torch.nn.Module,
    val_loader: DataLoader,
    device: torch.device,
    subjects: list,
    ds_name: str = "val",
):
    """Evaluate the Stage-1 encoder on a validation split.

    Accumulates predictions/targets across the whole loader, then computes a
    per-voxel Pearson correlation map and mean accuracy for each subject plus
    the subject average.

    Returns:
        (acc, metrics): subject-averaged accuracy, and a dict with per-subject
        accuracy maps/means plus "accmap_avg" and "acc_avg".
    """
    model.eval()
    loss_m = AverageMeter()
    samples = []
    outputs = []
    for batch in val_loader:
        feats = [f.to(device) for f in batch["features"]]
        fmri = batch["fmri"].to(device)
        batch_size = fmri.size(0)
        pred = model(feats)
        loss = nn.MSELoss()(pred, fmri)
        loss_m.update(loss.item(), batch_size)
        N, S, L, C = fmri.shape
        # Bug fix: the original `assert N, S == (1, 4)` asserted only N's
        # truthiness (`S == (1, 4)` was parsed as the assert *message*).
        # Validation loaders use batch size 1 and one slot per subject.
        assert N == 1 and S == len(subjects), f"unexpected fmri shape {fmri.shape}"
        # (B, S, T, V) -> (S, B*T, V): stack time across batches per subject.
        outputs.append(pred.cpu().numpy().swapaxes(0, 1).reshape((S, N * L, C)))
        samples.append(fmri.cpu().numpy().swapaxes(0, 1).reshape((S, N * L, C)))
    outputs = np.concatenate(outputs, axis=1)
    samples = np.concatenate(samples, axis=1)
    metrics = {}
    # Encoding accuracy metrics (per-voxel Pearson r).
    dim = samples.shape[-1]
    acc = 0.0
    acc_map = np.zeros(dim)
    for ii, sub in enumerate(subjects):
        y_true = samples[ii].reshape(-1, dim)
        y_pred = outputs[ii].reshape(-1, dim)
        metrics[f"accmap_sub-{sub}"] = acc_map_i = pearsonr_score(y_true, y_pred)
        metrics[f"acc_sub-{sub}"] = acc_i = np.mean(acc_map_i)
        acc_map += acc_map_i / len(subjects)
        acc += acc_i / len(subjects)
    metrics["accmap_avg"] = acc_map
    metrics["acc_avg"] = acc
    accs_fmt = ",".join(
        f"{val:.3f}" for key, val in metrics.items() if key.startswith("acc_sub-")
    )
    print(
        f"Evaluate Stage 1 ({ds_name}): {epoch:>3d}"
        f" Loss: {loss_m.avg:#.3g}"
        f" Acc: {accs_fmt} ({acc:.3f})"
    )
    return acc, metrics
@torch.no_grad()
def evaluate_stage2(
    *,
    epoch: int,
    stage1_model: torch.nn.Module,
    stage2_models: nn.ModuleDict,
    val_loader: DataLoader,
    device: torch.device,
    subjects: list,
    ds_name: str = "val",
    n_timesteps: int = 10,
):
    """Evaluate the full 2-stage pipeline on a validation split.

    Runs the frozen Stage-1 encoder to obtain each subject's mean anchor,
    samples from the subject's CFM with ``n_timesteps`` integration steps,
    and scores the result with per-voxel Pearson correlation.

    Returns:
        (acc, metrics): subject-averaged accuracy, and a dict with per-subject
        accuracy maps/means plus "accmap_avg" and "acc_avg".
    """
    stage1_model.eval()
    for model in stage2_models.values():
        model.eval()
    samples = []
    outputs = []
    for batch in val_loader:
        feats = [f.to(device) for f in batch["features"]]
        fmri = batch["fmri"].to(device)
        mu_anchor = stage1_model(feats)
        batch_preds = []
        for i, sub in enumerate(subjects):
            sub_key = str(sub)
            cfm = stage2_models[sub_key]
            # CFM expects (B, C, T): transpose (B, T, V) -> (B, V, T).
            mu = mu_anchor[:, i].transpose(1, 2)
            # Predict
            pred = cfm(mu, n_timesteps=n_timesteps)  # (B, V, T)
            pred = pred.transpose(1, 2).unsqueeze(1)  # (B, 1, T, V)
            batch_preds.append(pred)
        pred_combined = torch.cat(batch_preds, dim=1)  # (B, S, T, V)
        N, S, L, C = fmri.shape
        # Bug fix: the original `assert N, S == (1, 4)` asserted only N's
        # truthiness (`S == (1, 4)` was parsed as the assert *message*).
        assert N == 1 and S == len(subjects), f"unexpected fmri shape {fmri.shape}"
        outputs.append(
            pred_combined.cpu().numpy().swapaxes(0, 1).reshape((S, N * L, C))
        )
        samples.append(fmri.cpu().numpy().swapaxes(0, 1).reshape((S, N * L, C)))
    outputs = np.concatenate(outputs, axis=1)
    samples = np.concatenate(samples, axis=1)
    metrics = {}
    dim = samples.shape[-1]
    acc = 0.0
    acc_map = np.zeros(dim)
    for ii, sub in enumerate(subjects):
        y_true = samples[ii].reshape(-1, dim)
        y_pred = outputs[ii].reshape(-1, dim)
        metrics[f"accmap_sub-{sub}"] = acc_map_i = pearsonr_score(y_true, y_pred)
        metrics[f"acc_sub-{sub}"] = acc_i = np.mean(acc_map_i)
        acc_map += acc_map_i / len(subjects)
        acc += acc_i / len(subjects)
    metrics["accmap_avg"] = acc_map
    metrics["acc_avg"] = acc
    accs_fmt = ",".join(
        f"{val:.3f}" for key, val in metrics.items() if key.startswith("acc_sub-")
    )
    print(f"Evaluate Stage 2 ({ds_name}): {epoch:>3d}" f" Acc: {accs_fmt} ({acc:.3f})")
    return acc, metrics
def main():
    """Entry point: run the full 2-stage training pipeline from a YAML config.

    Stage 1 trains the multi-subject encoder (mean anchor) and keeps the best
    checkpoint by validation Pearson r; Stage 2 trains one CFM vector field
    per subject on top of the frozen Stage-1 predictions, checkpointing every
    5 epochs, then evaluates and plots loss curves.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cfg-path", type=str, default="config.yml")
    args = parser.parse_args()
    cfg = OmegaConf.load(args.cfg_path)
    print("Config loaded:\n", OmegaConf.to_yaml(cfg))
    out_dir = Path(cfg.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Snapshot the config next to the outputs for reproducibility.
    OmegaConf.save(cfg, out_dir / "config.yaml")
    random_seed(cfg.seed)
    device = torch.device(cfg.device)
    # --- Data Loading ---
    data_loaders = make_data_loaders(cfg)
    train_loader = data_loaders["train"]
    # Every split except "train" is treated as a validation split.
    val_loaders = data_loaders.copy()
    val_loaders.pop("train")
    # --- Model Setup: Stage 1 ---
    print("Creating Stage 1 Model (Encoder)...")
    # Get feat dims from first batch
    sample_batch = next(iter(train_loader))
    feat_dims = [f.shape[-1] for f in sample_batch["features"]]
    subjects_list = cfg.get("subjects", SUBJECTS)
    stage1_model = MultiSubjectConvLinearEncoder(
        num_subjects=len(subjects_list),
        feat_dims=feat_dims,
        # hidden_model=hidden_model,
        **cfg.stage1.model,
    ).to(device)
    optimizer1 = torch.optim.AdamW(
        stage1_model.parameters(),
        lr=cfg.stage1.lr,
        weight_decay=cfg.stage1.weight_decay,
    )
    # --- Training Loop: Stage 1 ---
    print("--- Starting Stage 1 Training (Mean Anchor) ---")
    best_score_s1 = -1.0
    stage1_train_losses = []
    stage1_val_accs = []
    for epoch in range(cfg.stage1.epochs):
        train_loss = train_one_epoch_condition(
            epoch=epoch,
            model=stage1_model,
            train_loader=train_loader,
            optimizer=optimizer1,
            device=device,
        )
        stage1_train_losses.append(train_loss)
        # Validation
        val_acc = None
        for name, loader in val_loaders.items():
            acc, _ = evaluate_stage1(
                epoch=epoch,
                model=stage1_model,
                val_loader=loader,
                device=device,
                subjects=subjects_list,
                ds_name=name,
            )
            # Only the split named by cfg.val_set_name drives model selection.
            if name == cfg.val_set_name:
                val_acc = acc
        stage1_val_accs.append(val_acc if val_acc is not None else 0.0)
        if val_acc is not None and val_acc > best_score_s1:
            best_score_s1 = val_acc
            torch.save(stage1_model.state_dict(), out_dir / "stage1_best.pt")
            print("Saved best Stage 1 model.")
    plot_loss_curve(
        stage1_train_losses,
        stage1_val_accs,
        out_dir,
        filename="stage1_loss_curve.png",
        prefix="Stage 1",
    )
    print(f"Stage 1 Training Complete. Best model at Pearson's r {best_score_s1}")
    # Reload best stage 1 model
    stage1_model.load_state_dict(torch.load(out_dir / "stage1_best.pt"))
    stage1_model.eval()
    # --- Model Setup: Stage 2 ---
    print("Creating Stage 2 Models (Flow Matching)...")
    stage2_models = nn.ModuleDict()
    optimizers2 = {}
    # Determine target dim from data (V parameter)
    target_dim = sample_batch["fmri"].shape[-1]
    cfm_params = cfg.stage2.cfm
    velocity_net_params = cfg.stage2.velocity_net
    source_ve_params = cfg.stage2.source_ve
    transport_params = cfg.stage2.transport
    for sub in subjects_list:
        sub_key = str(sub)
        # Create one CFM per subject for "neural vector field per subject"
        cfm_model = CFM(
            feat_dim=target_dim,
            cfm_params=cfm_params,
            velocity_net_params=velocity_net_params,
            source_ve_params=source_ve_params,
            transport_params=transport_params,
        ).to(device)
        stage2_models[sub_key] = cfm_model
        # Each subject's vector field gets its own independent optimizer.
        optimizers2[sub_key] = torch.optim.AdamW(
            cfm_model.parameters(),
            lr=cfg.stage2.lr,
            weight_decay=cfg.stage2.weight_decay,
        )
    # --- Training Loop: Stage 2 ---
    print("--- Starting Stage 2 Training (Vector Fields) ---")
    stage2_train_losses = []
    for epoch in range(cfg.stage2.epochs):
        train_loss = train_one_epoch_flow_matching(
            epoch=epoch,
            stage1_model=stage1_model,
            stage2_models=stage2_models,
            train_loader=train_loader,
            optimizers=optimizers2,
            device=device,
            subjects=subjects_list,
        )
        stage2_train_losses.append(train_loss)
        # Checkpointing
        if epoch % 5 == 0 or epoch == cfg.stage2.epochs - 1:
            ckpt_path = out_dir / f"stage2_epoch_{epoch}.pt"
            torch.save(stage2_models.state_dict(), ckpt_path)
            print(f"Saved Stage 2 checkpoint to {ckpt_path}")
    # --- Add Evaluation
    print("Evaluating final Stage 2 model...")
    for name, loader in val_loaders.items():
        evaluate_stage2(
            epoch=cfg.stage2.epochs,
            stage1_model=stage1_model,
            stage2_models=stage2_models,
            val_loader=loader,
            device=device,
            subjects=subjects_list,
            ds_name=name,
            n_timesteps=cfg.stage2.get("n_timesteps", 25),
        )
    plot_loss_curve(
        stage2_train_losses,
        out_path=out_dir,
        filename="stage2_loss_curve.png",
        prefix="Stage 2",
    )
    print("Done! All training complete.")


if __name__ == "__main__":
    main()