File size: 4,929 Bytes
4edc9aa | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 | import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent))
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from types import SimpleNamespace
import numpy as np
# Import CFM from src instead of defining locally
from src.stage2.CFM import CFM
# =============================================================================
# Experiment: Gaussian -> 8-Gaussians
# =============================================================================
# Seed NumPy and torch so the synthetic dataset and the training run repeat.
np.random.seed(42)
torch.manual_seed(42)

# ---- GPU setup ------
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Target distribution: 8 isotropic Gaussian blobs (std 0.4) whose centres are
# evenly spaced on a circle of radius `scale`.
n_samples = 4000
scale = 4.0
ring_angles = np.linspace(0, 2 * np.pi, 8, endpoint=False)
centers = np.array([[np.cos(a) * scale, np.sin(a) * scale] for a in ring_angles])

# One random cluster id per sample, then Gaussian jitter around that centre.
assignments = np.random.randint(0, 8, size=n_samples)
jitter = np.random.randn(n_samples, 2) * 0.4
gaussians_x = centers[assignments] + jitter

# Standardise each coordinate (zero mean, unit std) directly on `device`.
target_tensor = torch.tensor(gaussians_x, dtype=torch.float32, device=device)
target_mean = target_tensor.mean(0)
target_std = target_tensor.std(0)
goal_dist = (target_tensor - target_mean) / target_std
# ---- build model ------------------------------------------------------------
# Hyper-parameters forwarded to CFM as-is: sigma_min (presumably the minimum
# noise level of the probability path) and the ODE solver name.
cfm_params = SimpleNamespace(sigma_min=1e-4, solver="euler")
# NOTE(review): cond_dim is 0 even though training feeds model.label_emb(...)
# vectors as the conditioning `mu` — confirm how CFM consumes mu before
# changing this value.
decoder_params = {"hidden_dim": 256, "time_emb_dim": 128, "cond_dim": 0}
model = CFM(feat_dim=2, cfm_params=cfm_params, decoder_params=decoder_params)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# ---- training loop ----------------------------------------------------------
# NOTE: each "epoch" below is one optimizer step on a single random minibatch,
# not a full pass over the dataset.
epochs, batch_size = 3000, 512
losses = []
# Move all cluster labels to the device once. Indexing the NumPy array with a
# torch tensor every step relied on implicit __array__ conversion (which fails
# for CUDA indices) and re-copied the labels host->device on each iteration.
label_tensor = torch.as_tensor(assignments, dtype=torch.long, device=device)
model.train()
for epoch in range(epochs):
    # Draw indices from the CPU generator (same random stream as before), then
    # move them next to the data so all indexing happens on one device.
    idx = torch.randint(0, n_samples, (batch_size,)).to(device)
    x1 = goal_dist[idx].unsqueeze(-1)  # (B, 2, 1)
    # Conditional -> cluster embedding conditioning
    labels = label_tensor[idx]
    mu = model.label_emb(labels).unsqueeze(-1)
    loss, loss_dict = model.compute_loss(x1, mu)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    if (epoch + 1) % 1000 == 0:
        print(
            f"Epoch {epoch+1:5d} loss={loss.item():.5f} | "
            f"FM={loss_dict['fm']:.5f} | "
            f"Var={loss_dict['var']:.5f} | "
            f"Align={loss_dict['align']:.5f}"
        )
# ---- inference -------------------------------------------------------------
model.eval()
n_eval = 1000
n_clusters = 8
# Ceiling division: when n_eval is a multiple of n_clusters every cluster gets
# exactly n_eval / n_clusters samples. (`n_eval // 8 + 1` overshot, so the
# final slice gave clusters 0-6 126 samples each and cluster 7 only 118.)
per_cluster = (n_eval + n_clusters - 1) // n_clusters
eval_labels = torch.arange(n_clusters, device=device).repeat_interleave(
    per_cluster
)[:n_eval]
steps = 100
t_span = torch.linspace(0, 1, steps + 1, device=device)
# Six evenly spaced snapshot steps (0, steps/5, ..., steps). Step 0 is recorded
# before the loop; the loop records the remaining five.
snap_at = {k * steps // 5 for k in range(6)}
trajectories = []
with torch.no_grad():
    # Conditioning vectors per sample; no autograd graph is needed here.
    mu_eval = model.label_emb(eval_labels).unsqueeze(-1)
    # Source samples drawn from a standard normal with the same shape as mu_eval.
    x = torch.randn(mu_eval.size(), device=device)
    trajectories.append(x.squeeze(-1).cpu().numpy().copy())
    t = t_span[0]
    dt = t_span[1] - t_span[0]
    # Explicit Euler integration of dx/dt = estimator(x, mu, t) over t_span.
    for step in range(1, len(t_span)):
        t_batch = t.expand(n_eval)
        dphi_dt = model.estimator(x, mu_eval, t_batch)
        x = x + dt * dphi_dt
        t = t + dt
        if step < len(t_span) - 1:
            dt = t_span[step + 1] - t
        if step in snap_at:
            trajectories.append(x.squeeze(-1).cpu().numpy().copy())
# ---- plot ------------------------------------------------------------------
def _style_panel(ax, title):
    """Apply the shared limits and title to one scatter panel and hide its axes."""
    ax.set_xlim(-3.5, 3.5)
    ax.set_ylim(-3.5, 3.5)
    ax.set_title(title, fontsize=10)
    # Hides ticks, spines and axis labels, so X/Y labels would never be visible.
    ax.axis("off")


fig, axes = plt.subplots(1, 7, figsize=(21, 3))
fig.suptitle(
    "OT-CFM: Gaussian → 8 Gaussians (conditional on cluster label)",
    fontsize=13,
    y=1.04,
)
times = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, "target"]
colors = ["#636EFA", "#7A89FB", "#9BA4FC", "#BCBFFD", "#DDDAFE", "#EF553B", "#00CC96"]
# zip stops at the shortest input: one panel per trajectory snapshot. The
# trailing "target" entries of times/colors are used by the last panel below.
for ax, traj, label, c in zip(axes, trajectories, times, colors):
    ax.scatter(traj[:, 0], traj[:, 1], s=4, alpha=0.6, color=c, linewidths=0)
    _style_panel(ax, f"t = {label}" if isinstance(label, float) else label)
# last panel: overlay ground-truth
gt = goal_dist[:1000].cpu().numpy()
axes[-1].scatter(gt[:, 0], gt[:, 1], s=4, alpha=0.3, color=colors[-1], linewidths=0)
_style_panel(axes[-1], times[-1])
# loss curve panel
# Moving-average window. It is clamped to the run length because for runs
# shorter than 50 steps np.convolve swaps its operands and "valid" returns
# meaningless values.
window = min(50, len(losses))
smoothed = np.convolve(losses, np.ones(window) / window, mode="valid")
# The k-th "valid" average covers steps k+1 .. k+window (1-based), so plot it
# at its last step rather than letting the curve start at 0.
avg_x = np.arange(window, window + len(smoothed))
fig2, ax2 = plt.subplots(figsize=(7, 3))
ax2.plot(avg_x, smoothed, linewidth=1.2, color="#636EFA")
ax2.set_xlabel("Epoch")
# NOTE(review): `losses` holds the first value returned by model.compute_loss,
# which also reports "var" and "align" terms — confirm "MSE" is accurate.
ax2.set_ylabel("MSE Loss")
ax2.set_title(f"CFM Training Loss ({window}-epoch moving avg)")
ax2.spines[["top", "right"]].set_visible(False)
# The former plt.tight_layout() acted on the current figure, i.e. fig2; name it.
fig2.tight_layout()
fig.savefig("cfm_trajectories_imported.png", dpi=130, bbox_inches="tight")
fig2.savefig("cfm_loss_imported.png", dpi=130, bbox_inches="tight")
print("Saved cfm_trajectories_imported.png and cfm_loss_imported.png")
|