# %%
# =============================================================================
# CFM 2-D Toy Experiment — self-contained, no src path hacks
# =============================================================================
# Architecture contract:
#   Decoder.forward(x, mu, t)
#     x   : (B, feat_dim, L)
#     mu  : (B, feat_dim, L)
#     t   : (B,)   <-- one scalar per sample; (B,1) and (B,1,1) are also accepted
#     => out : (B, feat_dim, L)
#
#   CFM.compute_loss(x1, mu)
#     x1  : (B, feat_dim, L)
#     mu  : (B, feat_dim, L)
#   Inside compute_loss, t is sampled as (B, 1, 1) so that it broadcasts in the
#   interpolation, and is reshaped to (B,) before being handed to the estimator.
#   Decoder.forward flattens t as well, so both (B,) and (B,1,1) work.
# =============================================================================
from abc import ABC, abstractmethod
from types import SimpleNamespace
from typing import List, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# ── helpers ──────────────────────────────────────────────────────────────────
def sinusoidal_pos_emb(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Return sinusoidal embeddings of ``t``.

    Args:
        t: Tensor of shape ``(B,)``.
        dim: Width of the embedding to produce.

    Returns:
        Tensor of shape ``(B, dim)``: ``dim // 2`` sine columns followed by
        ``dim // 2`` cosine columns, plus one zero column when ``dim`` is odd.
    """
    half = dim // 2
    # Guard the denominator: with dim in {2, 3} ``half - 1`` is 0, which would
    # turn the frequency table into NaN.
    step = np.log(10000) / max(half - 1, 1)
    freqs = torch.exp(
        -torch.arange(half, device=t.device, dtype=torch.float32) * step
    )
    args = t[:, None] * freqs[None]
    emb = torch.cat([args.sin(), args.cos()], dim=-1)
    if dim % 2 == 1:
        # Keep the promised output width for odd ``dim``.
        emb = F.pad(emb, (0, 1))
    return emb


class SinusoidalPosEmb(nn.Module):
    """Module wrapper around :func:`sinusoidal_pos_emb`."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Embed ``t`` given as ``(B,)``, ``(B,1)`` or ``(B,1,1)`` into ``(B, dim)``."""
        # reshape (not view) so non-contiguous inputs are accepted too
        t = t.reshape(t.shape[0])
        return sinusoidal_pos_emb(t, self.dim)


# ── MLP block ─────────────────────────────────────────────────────────────────
# %%
class MLP(nn.Module):
    """Point-wise MLP built from kernel-size-1 ``Conv1d`` layers.

    The time embedding is projected to ``hidden_c`` and added to the
    activations after the first layer, broadcast over the last axis.
    """

    def __init__(self, in_c, hidden_c, out_c, time_emb_dim):
        super().__init__()
        self.time_net = nn.Sequential(nn.Linear(time_emb_dim, hidden_c), nn.Mish())
        self.net1 = nn.Sequential(nn.Conv1d(in_c, hidden_c, 1), nn.ReLU())
        self.net2 = nn.Sequential(nn.Conv1d(hidden_c, hidden_c, 1), nn.ReLU())
        self.net3 = nn.Sequential(nn.Conv1d(hidden_c, hidden_c, 1), nn.ReLU())
        self.out = nn.Conv1d(hidden_c, out_c, 1)

    def forward(self, x, time_emb):
        """Map ``x`` (B, in_c, L) and ``time_emb`` (B, time_emb_dim) to (B, out_c, L)."""
        h = self.net1(x)
        # (B, hidden_c) -> (B, hidden_c, 1) so it broadcasts over L
        h = h + self.time_net(time_emb).unsqueeze(-1)
        h = self.net2(h)
        h = self.net3(h)
        return self.out(h)


# Alternative Linear-based MLP, kept for reference:
#
# class MLP(nn.Module):
#     def __init__(self, in_c: int, hidden_c: int, out_c: int, time_emb_dim: int):
#         super().__init__()
#         self.time_net = nn.Sequential(nn.Linear(time_emb_dim, hidden_c), nn.Mish())
#         self.net1 = nn.Sequential(nn.Linear(in_c, hidden_c), nn.ReLU())
#         self.net2 = nn.Linear(hidden_c, out_c)
#
#     def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
#         # x        : (B, in_c, L)
#         # time_emb : (B, time_emb_dim)
#         x_t = x.transpose(1, 2)  # (B, L, in_c) for Linear
#         out = self.net1(x_t)  # (B, L, hidden_c)
#         out = out + self.time_net(time_emb).unsqueeze(1)  # broadcast over L
#         out = self.net2(out)  # (B, L, out_c)
#         return out.transpose(1, 2)  # (B, out_c, L)


# %%
# ── Decoder ───────────────────────────────────────────────────────────────────
class Decoder(nn.Module):
    """
    Lightweight MLP velocity estimator for toy 2-D flow-matching.

    Tensor contract
    ---------------
    forward(x, mu, t) -> vel
        x   : (B, feat_dim, L)
        mu  : (B, feat_dim, L)
        t   : (B,) | (B,1) | (B,1,1)   # all accepted
        vel : (B, feat_dim, L)

    ``cond_dim`` (constructor) and ``cond`` (forward) are accepted but are not
    used by this implementation.
    """

    def __init__(
        self,
        in_c: int = 2,
        hidden_dim: int = 128,
        out_c: int = 2,
        time_emb_dim: int = 64,
        cond_dim: int = 0,
    ):
        super().__init__()
        self.time_emb = SinusoidalPosEmb(time_emb_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, time_emb_dim),
        )
        # concat(x, mu) along channel dim -> 2*feat_dim channels
        self.net = MLP(
            in_c=in_c * 2, hidden_c=hidden_dim, out_c=out_c, time_emb_dim=time_emb_dim
        )
        self._init_weights()

    def _init_weights(self):
        """Re-initialise every ``nn.Linear``: N(0, 0.02) weights, zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0.0, 0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(
        self,
        x: torch.Tensor,
        mu: torch.Tensor,
        t: torch.Tensor,
        cond=None,
    ) -> torch.Tensor:
        """Predict the velocity at state ``x`` and time ``t`` conditioned on ``mu``."""
        # normalise t to (B,) regardless of input shape
        t_flat = t.reshape(x.shape[0])  # (B,)
        t_emb = self.time_mlp(self.time_emb(t_flat))  # (B, time_emb_dim)

        # concat along channel axis (B, 2*feat_dim, L)
        xmu = torch.cat([x, mu], dim=1)
        return self.net(xmu, t_emb)  # (B, feat_dim, L)


# ── SourceGenerator ───────────────────────────────────────────────────────────
class SourceGenerator(nn.Module):
    """Predict a per-element mean and log-variance of the source distribution."""

    def __init__(self, feat_dim: int, hidden_dim: int = 64):
        super().__init__()
        # Outputs 2 * feat_dim to hold both mean and log_var
        self.net = nn.Sequential(
            nn.Conv1d(feat_dim, hidden_dim, 1),
            nn.Mish(),
            nn.Conv1d(hidden_dim, feat_dim * 2, 1),
        )

    def forward(self, mu: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mean, logvar)``, each (B, feat_dim, L), for ``mu`` (B, feat_dim, L)."""
        out = self.net(mu)  # (B, 2*feat_dim, L)
        mean_c, logvar_c = out.chunk(2, dim=1)  # each (B, feat_dim, L)
        return mean_c, logvar_c


# ── BASECFM ───────────────────────────────────────────────────────────────────
class BASECFM(nn.Module, ABC):
    """Flow-matching base: source sampling, Euler integration and training loss.

    Subclasses must assign ``self.estimator`` (called as ``estimator(x, mu, t)``)
    and ``self.src_gen`` (called as ``src_gen(mu) -> (mean, logvar)``).
    """

    def __init__(self, feat_dim: int, cfm_params):
        super().__init__()
        self.feat_dim = feat_dim
        self.sigma_min = cfm_params.sigma_min
        self.estimator: Optional[nn.Module] = None
        self.src_gen: Optional[nn.Module] = None

    def sample_source(
        self,
        mu: torch.Tensor,
        temperature: float = 1.0,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Draw a reparameterised sample from the learned source distribution.

        Args:
            mu: Conditioning tensor, (B, feat_dim, L).
            temperature: Multiplier applied to the predicted standard deviation.

        Returns:
            ``(z, mean, logvar)`` with ``z = mean + temperature * exp(0.5 * logvar) * eps``
            and ``eps`` standard normal; all three have the shape of ``mean``.
        """
        mean_c, logvar_c = self.src_gen(mu)
        eps = torch.randn_like(mean_c)
        z = mean_c + temperature * torch.exp(0.5 * logvar_c) * eps
        return z, mean_c, logvar_c

    # ---- inference -----------------------------------------------------------
    @torch.inference_mode()
    def forward(
        self,
        mu: torch.Tensor,  # (B, feat_dim, L)
        n_timesteps: int,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """Generate samples by integrating the learned flow from t=0 to t=1.

        Args:
            mu: Conditioning tensor, (B, feat_dim, L).
            n_timesteps: Number of Euler steps; must be at least 1.
            temperature: Scales the standard deviation of the source sample.

        Returns:
            Tensor of shape (B, feat_dim, L).

        Raises:
            ValueError: If ``n_timesteps`` is smaller than 1.
        """
        if n_timesteps < 1:
            raise ValueError(f"n_timesteps must be >= 1, got {n_timesteps}")
        # src_gen returns (mean, logvar), not a sample — draw z from it.
        z, _, _ = self.sample_source(mu, temperature)
        t_span = torch.linspace(
            0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype
        )
        return self.solve_euler(z, t_span, mu)

    def solve_euler(
        self,
        x: torch.Tensor,  # (B, feat_dim, L)
        t_span: torch.Tensor,  # (n_steps+1,)
        mu: torch.Tensor,  # (B, feat_dim, L)
    ) -> torch.Tensor:
        """Integrate ``dx/dt = estimator(x, mu, t)`` with fixed-grid Euler steps.

        Each consecutive pair ``(t_span[i], t_span[i+1])`` yields one step that
        evaluates the estimator at the left endpoint. ``x`` is returned
        unchanged when ``t_span`` has fewer than two entries.
        """
        batch = x.shape[0]
        for t_cur, t_next in zip(t_span[:-1], t_span[1:]):
            t_batch = t_cur.expand(batch)  # 0-dim -> (B,), same device as t_span
            x = x + (t_next - t_cur) * self.estimator(x, mu, t_batch)
        return x

    # ---- training ------------------------------------------------------------
    def compute_loss(
        self,
        x1: torch.Tensor,  # (B, feat_dim, L)
        mu: torch.Tensor,  # (B, feat_dim, L)
        lambda_var: float = 1,  # Hyperparameters from the paper
        lambda_align: float = 0,
    ) -> tuple:
        """Compute the training objective for one batch.

        Args:
            x1: Target samples, (B, feat_dim, L).
            mu: Conditioning tensor, (B, feat_dim, L).
            lambda_var: Weight of the variance regularisation term.
            lambda_align: Weight of the cosine alignment term.

        Returns:
            ``(loss_total, loss_dict)`` where ``loss_total`` is a scalar tensor
            and ``loss_dict`` holds the float values under ``"fm"``, ``"var"``
            and ``"align"`` for logging.
        """
        B = x1.shape[0]

        # t sampled per sample, broadcast-ready for interpolation
        t = torch.rand(B, 1, 1, device=mu.device, dtype=mu.dtype)  # (B,1,1)

        # Learned source sample (replaces z = torch.randn_like(mu))
        z, mean_c, logvar_c = self.sample_source(mu)  # each (B, C, L)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1  # interpolant
        u = x1 - (1 - self.sigma_min) * z  # target velocity

        # estimator expects t as (B,)
        pred = self.estimator(y, mu, t.reshape(B))

        # Standard flow-matching loss
        loss_fm = F.mse_loss(pred, u)

        # Variance regularisation loss [Eq. 9 in paper]
        # D_KL( N(mu_c, sigma_c^2) || N(mu_c, I) ) = 0.5 * (sigma^2 - log(sigma^2) - 1)
        loss_var = 0.5 * (torch.exp(logvar_c) - logvar_c - 1).mean()

        # Cosine alignment loss [Eq. 10 in paper]
        sim = F.cosine_similarity(z.flatten(1), x1.flatten(1), dim=1)
        loss_align = (1.0 - sim).mean()

        # Total loss [Eq. 11 in paper]
        loss_total = loss_fm + lambda_var * loss_var + lambda_align * loss_align

        # Return total loss, and a dictionary for logging
        loss_dict = {
            "fm": loss_fm.item(),
            "var": loss_var.item(),
            "align": loss_align.item(),
        }
        return loss_total, loss_dict


class CFM(BASECFM):
    """Concrete model: ``Decoder`` estimator, label embedding and ``SourceGenerator``."""

    def __init__(
        self, feat_dim: int, cfm_params, decoder_params: dict, num_classes: int = 8
    ):
        super().__init__(feat_dim=feat_dim, cfm_params=cfm_params)
        # Creation order is kept fixed so seeded initialisation is reproducible.
        self.estimator = Decoder(in_c=feat_dim, out_c=feat_dim, **decoder_params)
        self.label_emb = nn.Embedding(num_classes, feat_dim)
        self.src_gen = SourceGenerator(feat_dim=feat_dim)


# %%
# =============================================================================
# Experiment: Gaussian -> 8-Gaussians
# =============================================================================
np.random.seed(42)
torch.manual_seed(42)

# ---- GPU setup ------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

n_samples = 4000
scale = 4.0
centers = np.array(
    [
        (np.cos(t) * scale, np.sin(t) * scale)
        for t in np.linspace(0, 2 * np.pi, 8, endpoint=False)
    ]
)
assignments = np.random.randint(0, 8, size=n_samples)
gaussians_x = centers[assignments] + np.random.randn(n_samples, 2) * 0.4

target_tensor = torch.tensor(gaussians_x, dtype=torch.float32, device=device)
goal_dist = (target_tensor - target_tensor.mean(0)) / target_tensor.std(0)
# Cluster labels live on the device once instead of being rebuilt every step.
all_labels = torch.as_tensor(assignments, dtype=torch.long, device=device)

# ---- build model ------------------------------------------------------------
cfm_params = SimpleNamespace(sigma_min=1e-4, solver="euler")
decoder_params = dict(hidden_dim=256, time_emb_dim=128, cond_dim=0)

model = CFM(feat_dim=2, cfm_params=cfm_params, decoder_params=decoder_params).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# ---- training loop ----------------------------------------------------------
epochs, batch_size = 3000, 512
losses = []

model.train()
for epoch in range(epochs):
    # Drawn on the CPU generator (seeded above), then moved to the data's device.
    idx = torch.randint(0, n_samples, (batch_size,)).to(device)
    x1 = goal_dist[idx].unsqueeze(-1)  # (B, 2, 1)

    # Conditional -> cluster embedding conditioning
    labels = all_labels[idx]
    mu = model.label_emb(labels).unsqueeze(-1)  # (B, 2, 1)

    loss, loss_dict = model.compute_loss(x1, mu)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    if (epoch + 1) % 1000 == 0:
        print(
            f"Epoch {epoch+1:5d}  loss={loss.item():.5f} | "
            f"FM={loss_dict['fm']:.5f} | "
            f"Var={loss_dict['var']:.5f} | "
            f"Align={loss_dict['align']:.5f}"
        )

# %%
# ---- inference -------------------------------------------------------------
model.eval()
n_eval = 1000
# Each label is repeated n_eval // 8 + 1 times and the result is cut to n_eval,
# so the last class ends up with slightly fewer samples.
eval_labels = torch.arange(8, device=device).repeat_interleave(n_eval // 8 + 1)[
    :n_eval
]  # TODO: investigate
mu_eval = model.label_emb(eval_labels).unsqueeze(-1).detach()

steps = 100
t_span = torch.linspace(0, 1, steps + 1, device=device)
snap_at = {0, 20, 40, 60, 80, 100}

trajectories = []
with torch.no_grad():
    # Start from the learned source distribution — the same one compute_loss
    # trains the flow on — rather than from an unconditional N(0, I).
    x, _, _ = model.sample_source(mu_eval)
    trajectories.append(x.squeeze(-1).cpu().numpy().copy())  # t = 0 snapshot

    for step, (t_cur, t_next) in enumerate(zip(t_span[:-1], t_span[1:]), start=1):
        t_batch = t_cur.expand(n_eval)
        x = x + (t_next - t_cur) * model.estimator(x, mu_eval, t_batch)
        if step in snap_at:
            trajectories.append(x.squeeze(-1).cpu().numpy().copy())

print(x.max(), " -- ", x.min())

# ---- plot ------------------------------------------------------------------
fig, axes = plt.subplots(1, 7, figsize=(21, 3))
fig.suptitle(
    "OT-CFM: Gaussian → 8 Gaussians  (conditional on cluster label)",
    fontsize=13,
    y=1.04,
)

times = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, "target"]
colors = ["#636EFA", "#7A89FB", "#9BA4FC", "#BCBFFD", "#DDDAFE", "#EF553B", "#00CC96"]

# zip stops at the six recorded snapshots; the seventh axis is filled below.
for ax, traj, label, c in zip(axes, trajectories, times, colors):
    ax.scatter(traj[:, 0], traj[:, 1], s=4, alpha=0.6, color=c, linewidths=0)
    ax.set_xlim(-3.5, 3.5)
    ax.set_ylim(-3.5, 3.5)
    ax.set_xlabel("X", fontsize=9)
    ax.set_ylabel("Y", fontsize=9)
    ax.set_title(f"t = {label}" if isinstance(label, float) else label, fontsize=10)
    ax.axis("off")

# last panel: overlay ground-truth
gt = goal_dist[:1000].cpu().numpy()
axes[-1].scatter(gt[:, 0], gt[:, 1], s=4, alpha=0.3, color="#00CC96", linewidths=0)
axes[-1].set_xlim(-3.5, 3.5)
axes[-1].set_ylim(-3.5, 3.5)
axes[-1].set_xlabel("X", fontsize=9)
axes[-1].set_ylabel("Y", fontsize=9)
axes[-1].set_title("target", fontsize=10)
axes[-1].axis("off")

# loss curve panel
fig2, ax2 = plt.subplots(figsize=(7, 3))
ax2.plot(
    np.convolve(losses, np.ones(50) / 50, mode="valid"), linewidth=1.2, color="#636EFA"
)
ax2.set_xlabel("Epoch")
ax2.set_ylabel("MSE Loss")
ax2.set_title("CFM Training Loss (50-epoch moving avg)")
ax2.spines[["top", "right"]].set_visible(False)

plt.tight_layout()
fig.savefig("cfm_trajectories.png", dpi=130, bbox_inches="tight")
fig2.savefig("cfm_loss.png", dpi=130, bbox_inches="tight")
print("Saved cfm_trajectories.png and cfm_loss.png")

# %%
from torchinfo import summary

print(summary(model))

# %%
print(goal_dist.max(), goal_dist.min())
# %%