| | """ |
| | Evaluation and WandB visualization for diffusion models on The Well. |
| | |
| | Produces: |
| | - Single-step comparison images: Condition | Ground Truth | Prediction |
| | - Multi-step rollout videos: GT trajectory vs Predicted trajectory (side-by-side) |
| | - Per-step MSE metrics for rollout quality analysis |
| | """ |
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | import logging |
| |
|
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def _get_colormap(name="RdBu_r"):
    """Return the matplotlib colormap object for *name*.

    matplotlib is imported lazily so it is only required when visualization
    actually runs; the Agg backend avoids needing a display.
    """
    import matplotlib
    matplotlib.use("Agg")
    # matplotlib.cm.get_cmap() was deprecated in 3.7 and removed in 3.9;
    # prefer the colormap registry, falling back for older installations.
    try:
        return matplotlib.colormaps[name]
    except AttributeError:
        import matplotlib.cm as cm
        return cm.get_cmap(name)
| |
|
# Colormap objects keyed by name, so _get_colormap's import/lookup runs once per name.
_CMAP_CACHE = {}
| |
|
def apply_colormap(field_01, cmap_name="RdBu_r"):
    """Map a [H, W] float array in [0, 1] to a [H, W, 3] uint8 RGB image."""
    cmap = _CMAP_CACHE.get(cmap_name)
    if cmap is None:
        cmap = _get_colormap(cmap_name)
        _CMAP_CACHE[cmap_name] = cmap
    rgba = cmap(np.clip(field_01, 0, 1))
    # Drop the alpha channel and scale to 8-bit.
    return (255 * rgba[..., :3]).astype(np.uint8)
| |
|
| |
|
def normalize_for_vis(f, vmin=None, vmax=None):
    """Robustly rescale *f* into [0, 1].

    When vmin/vmax are not supplied, the 2nd/98th percentiles of *f* are
    used so outliers do not wash out the color range.

    Returns:
        (normalized_array, vmin, vmax) so callers can reuse the range.
    """
    lo = np.percentile(f, 2) if vmin is None else vmin
    hi = np.percentile(f, 98) if vmax is None else vmax
    span = max(hi - lo, 1e-8)  # guard against a constant field
    return np.clip((f - lo) / span, 0, 1), lo, hi
| |
|
| |
|
| | |
| | |
| | |
| |
|
def _comparison_image(cond, gt, pred, cmap="RdBu_r"):
    """Render Cond | GT | Pred side by side as one [H, 3W+4, 3] uint8 image.

    A percentile range pooled over all three fields is used so the three
    panels share one color scale and are directly comparable.
    """
    pooled = np.concatenate([np.ravel(cond), np.ravel(gt), np.ravel(pred)])
    lo = np.percentile(pooled, 2)
    hi = np.percentile(pooled, 98)

    def panel(field):
        scaled, _, _ = normalize_for_vis(field, lo, hi)
        return apply_colormap(scaled, cmap)

    # 2-pixel light-gray divider between panels.
    divider = np.full((cond.shape[0], 2, 3), 200, dtype=np.uint8)
    return np.concatenate(
        [panel(cond), divider, panel(gt), divider, panel(pred)], axis=1
    )
| |
|
| |
|
@torch.no_grad()
def single_step_eval(model, val_loader, device, n_batches=4, ddim_steps=50):
    """Compute validation MSE and build comparison images.

    Args:
        model: diffusion model exposing `sample_ddim(cond, shape, steps)`.
        val_loader: validation DataLoader.
        device: torch device for sampling.
        n_batches: number of validation batches to evaluate.
        ddim_steps: DDIM denoising steps per prediction.

    Returns:
        metrics: dict {'val/mse': float}
        comparisons: list of (image_array, caption_string)
    """
    from data_pipeline import prepare_batch

    model.eval()
    mse_sum = 0.0
    seen = 0
    kept = None  # (cond, target, pred) slices from the first batch, for images

    for batch_idx, batch in enumerate(val_loader):
        if batch_idx >= n_batches:
            break
        cond, target = prepare_batch(batch, device)
        pred = model.sample_ddim(cond, shape=target.shape, steps=ddim_steps)

        bsz = target.shape[0]
        mse_sum += F.mse_loss(pred, target).item() * bsz
        seen += bsz

        if batch_idx == 0:
            kept = (cond[:4].cpu(), target[:4].cpu(), pred[:4].cpu())

    comparisons = []
    if kept is not None:
        xc, xt, xp = kept
        for sample in range(xc.shape[0]):
            # Visualize at most 4 channels per sample.
            for ch in range(min(xc.shape[1], 4)):
                image = _comparison_image(
                    xc[sample, ch].numpy(),
                    xt[sample, ch].numpy(),
                    xp[sample, ch].numpy(),
                )
                comparisons.append((image, f"sample{sample}_ch{ch}"))

    model.train()  # restore training mode for the caller's loop
    return {"val/mse": mse_sum / max(seen, 1)}, comparisons
| |
|
| |
|
| | |
| | |
| | |
| |
|
@torch.no_grad()
def rollout_eval(
    model, rollout_loader, device,
    n_rollout=20, ddim_steps=50, channel=0, cmap="RdBu_r",
):
    """Autoregressive rollout with GT comparison video.

    Creates side-by-side video: Ground Truth | Prediction
    and computes per-step MSE.

    Args:
        model: GaussianDiffusion instance.
        rollout_loader: DataLoader with n_steps_output >= n_rollout.
        device: torch device.
        n_rollout: autoregressive prediction steps.
        ddim_steps: DDIM denoising steps per prediction.
        channel: which field channel to visualize.
        cmap: matplotlib colormap.

    Returns:
        video: [T, 3, H, W_combined] uint8 for wandb.Video.
        per_step_mse: list[float] of length n_rollout.
    """
    model.eval()
    # A single batch is fetched; only its first sample is rolled out.
    batch = next(iter(rollout_loader))

    # NOTE(review): the permutes below imply fields are [B, T, H, W, C]
    # (channels-last) — confirm against the loader's output format.
    inp = batch["input_fields"][:1]
    out = batch["output_fields"][:1]

    T_out = out.shape[1]
    n_steps = min(n_rollout, T_out)  # cannot roll out past available GT
    C = inp.shape[-1]

    # First input frame, channels-first on the device: [1, C, H, W].
    x_cond = inp[:, 0].permute(0, 3, 1, 2).float().to(device)

    # Ground-truth frames kept on CPU for metrics and visualization.
    gt_frames = [out[:, t].permute(0, 3, 1, 2).float() for t in range(n_steps)]

    pred_frames = []
    per_step_mse = []
    cond = x_cond

    for t in range(n_steps):
        # Deterministic DDIM sampling (eta=0) conditioned on the previous frame.
        pred = model.sample_ddim(cond, shape=cond.shape, steps=ddim_steps, eta=0.0)
        pred_cpu = pred.cpu()
        pred_frames.append(pred_cpu)

        mse_t = F.mse_loss(pred_cpu, gt_frames[t]).item()
        per_step_mse.append(mse_t)

        # Feed the prediction back in: errors compound over the rollout.
        cond = pred
        if (t + 1) % 5 == 0:
            logger.info(f" rollout step {t+1}/{n_steps}, mse={mse_t:.6f}")

    ch = min(channel, C - 1)  # clamp the requested channel to what exists

    # Pool values from the initial condition, all GT frames, and all predicted
    # frames so both video panels share one percentile-based color range.
    all_vals = [x_cond[0, ch].cpu().numpy().flat]
    for t in range(n_steps):
        all_vals.append(gt_frames[t][0, ch].numpy().flat)
        all_vals.append(pred_frames[t][0, ch].numpy().flat)
    all_vals = np.concatenate(list(all_vals))
    vmin, vmax = np.percentile(all_vals, 2), np.percentile(all_vals, 98)

    def to_rgb(field_2d):
        # [H, W] float -> [H, W, 3] uint8 using the shared range above.
        n, _, _ = normalize_for_vis(field_2d, vmin, vmax)
        return apply_colormap(n, cmap)

    H, W = x_cond.shape[2], x_cond.shape[3]
    sep = np.full((H, 4, 3), 200, dtype=np.uint8)  # light-gray vertical divider

    def _label_frame(gt_rgb, pred_rgb):
        """Concatenate with separator."""
        return np.concatenate([gt_rgb, sep, pred_rgb], axis=1)

    frames = []

    # Frame 0 shows the shared initial condition on both panels.
    init_rgb = to_rgb(x_cond[0, ch].cpu().numpy())
    frames.append(_label_frame(init_rgb, init_rgb).transpose(2, 0, 1))

    # Subsequent frames: GT on the left, prediction on the right
    # (transpose to [3, H, W] channel-first as wandb.Video expects).
    for t in range(n_steps):
        gt_rgb = to_rgb(gt_frames[t][0, ch].numpy())
        pred_rgb = to_rgb(pred_frames[t][0, ch].numpy())
        frames.append(_label_frame(gt_rgb, pred_rgb).transpose(2, 0, 1))

    video = np.stack(frames).astype(np.uint8)

    model.train()  # restore training mode for the caller's loop
    return video, per_step_mse
| |
|
| |
|
| | |
| | |
| | |
| |
|
def run_evaluation(
    model, val_loader, rollout_loader, device,
    global_step, wandb_run=None,
    n_val_batches=4, n_rollout=20, ddim_steps=50,
):
    """Run full evaluation: single-step metrics + rollout video.

    Logs everything to WandB if wandb_run is provided.

    Args:
        model: GaussianDiffusion instance.
        val_loader: single-step validation DataLoader.
        rollout_loader: DataLoader used for the autoregressive rollout.
        device: torch device.
        global_step: training step used as the WandB log step.
        wandb_run: active wandb run, or None to skip logging.
        n_val_batches: number of validation batches for the single-step MSE.
        n_rollout: autoregressive prediction steps.
        ddim_steps: DDIM denoising steps per prediction.

    Returns:
        dict of all metrics.
    """
    logger.info("Running single-step evaluation...")
    metrics, comparisons = single_step_eval(
        model, val_loader, device, n_batches=n_val_batches, ddim_steps=ddim_steps
    )
    logger.info(f" val/mse = {metrics['val/mse']:.6f}")

    logger.info(f"Running {n_rollout}-step rollout evaluation...")
    video, rollout_mse = rollout_eval(
        model, rollout_loader, device, n_rollout=n_rollout, ddim_steps=ddim_steps
    )

    # Guard against an empty rollout (e.g. the loader yields zero output
    # steps), which previously crashed on rollout_mse[0] / rollout_mse[-1].
    if rollout_mse:
        logger.info(
            f" rollout MSE (step 1/last): {rollout_mse[0]:.6f} / {rollout_mse[-1]:.6f}"
        )
        metrics["val/rollout_mse_mean"] = float(np.mean(rollout_mse))
        metrics["val/rollout_mse_final"] = rollout_mse[-1]
        for t, m in enumerate(rollout_mse):
            metrics[f"val/rollout_mse_step{t}"] = m
    else:
        logger.warning("Rollout produced no steps; skipping rollout metrics.")

    if wandb_run is not None:
        import wandb

        wandb_run.log(metrics, step=global_step)

        # Cap at 8 comparison images to keep the run page light.
        for img, caption in comparisons[:8]:
            wandb_run.log(
                {f"eval/{caption}": wandb.Image(img, caption="Cond | GT | Pred")},
                step=global_step,
            )

        wandb_run.log(
            {"eval/rollout_video": wandb.Video(video, fps=4, format="mp4",
                                               caption="Left=GT Right=Prediction")},
            step=global_step,
        )

        if rollout_mse:
            table = wandb.Table(
                columns=["step", "mse"],
                data=[[t, m] for t, m in enumerate(rollout_mse)],
            )
            wandb_run.log(
                {"eval/rollout_mse_curve": wandb.plot.line(
                    table, "step", "mse", title="Rollout MSE vs Step"
                )},
                step=global_step,
            )

    return metrics
| |
|