File size: 4,929 Bytes
4edc9aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent))

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from types import SimpleNamespace
import numpy as np
# Import CFM from src instead of defining locally
from src.stage2.CFM import CFM

# =============================================================================
# Experiment: Gaussian  ->  8-Gaussians
# =============================================================================

# Fix both the NumPy and the torch RNG so data generation and training repeat.
np.random.seed(42)
torch.manual_seed(42)

# ---- GPU setup ------
# Prefer CUDA when present; all tensors below are created on this device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# ---- synthetic target: 8 isotropic Gaussians placed on a circle -------------
n_samples = 4000
scale = 4.0  # radius of the circle the 8 mode centres sit on
centers = np.array(
    [
        (scale * np.cos(theta), scale * np.sin(theta))
        for theta in np.linspace(0, 2 * np.pi, 8, endpoint=False)
    ]
)
# One mode index per sample, then N(0, 0.4^2) noise around that mode's centre.
assignments = np.random.randint(0, 8, size=n_samples)
gaussians_x = centers[assignments] + 0.4 * np.random.randn(n_samples, 2)

# Standardize each coordinate (zero mean, unit sample std) before training.
target_tensor = torch.tensor(gaussians_x, dtype=torch.float32, device=device)
goal_dist = target_tensor.sub(target_tensor.mean(dim=0)).div(target_tensor.std(dim=0))

# ---- build model ------------------------------------------------------------
# Flow-matching settings forwarded to CFM (sigma_min and the ODE solver name).
cfm_params = SimpleNamespace(sigma_min=1e-4, solver="euler")
# Estimator-network configuration forwarded unchanged to CFM.
decoder_params = {"hidden_dim": 256, "time_emb_dim": 128, "cond_dim": 0}
model = CFM(
    feat_dim=2,
    cfm_params=cfm_params,
    decoder_params=decoder_params,
).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# ---- training loop ----------------------------------------------------------
epochs, batch_size = 3000, 512
losses = []

# Cluster labels live on the model's device once, so each step only gathers
# from them instead of fancy-indexing the NumPy array with a torch tensor and
# copying a fresh label tensor host->device every iteration.
label_tensor = torch.as_tensor(assignments, dtype=torch.long, device=device)

model.train()
for epoch in range(epochs):
    # Indices are still drawn with the CPU generator, so the sampled minibatch
    # sequence is unchanged; indexing a device tensor with them is supported.
    idx = torch.randint(0, n_samples, (batch_size,))
    x1 = goal_dist[idx].unsqueeze(-1)  # (B, 2, 1)

    # Conditional -> cluster embedding conditioning
    labels = label_tensor[idx]
    mu = model.label_emb(labels).unsqueeze(-1)

    loss, loss_dict = model.compute_loss(x1, mu)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    if (epoch + 1) % 1000 == 0:
        # Reuse the already-synced scalar rather than calling .item() again.
        print(
            f"Epoch {epoch+1:5d}  loss={losses[-1]:.5f} | "
            f"FM={loss_dict['fm']:.5f} | "
            f"Var={loss_dict['var']:.5f} | "
            f"Align={loss_dict['align']:.5f}"
        )

# ---- inference  -------------------------------------------------------------
model.eval()
n_eval = 1000
n_clusters = len(centers)
# Cycle through the cluster ids so the clusters are as balanced as possible.
# The previous repeat_interleave(n_eval // 8 + 1)[:n_eval] gave clusters 0-6
# 126 samples each and left cluster 7 with only 118 (n_eval=1000).
eval_labels = torch.arange(n_eval, device=device) % n_clusters
mu_eval = model.label_emb(eval_labels).unsqueeze(-1).detach()
steps = 100
t_span = torch.linspace(0, 1, steps + 1, device=device)

# Record the state every `steps // 5` Euler steps (t = 0.2, 0.4, ..., 1.0).
# The t = 0 state is stored before the loop, so step 0 is not listed here.
snap_every = max(1, steps // 5)
snap_at = set(range(snap_every, steps + 1, snap_every))

trajectories = []
with torch.no_grad():
    x = torch.randn(mu_eval.size(), device=device)
    trajectories.append(x.squeeze(-1).cpu().numpy().copy())

    t = t_span[0]
    dt = t_span[1] - t_span[0]

    # Explicit Euler integration of dx/dt = estimator(x, mu, t) over t_span.
    for step in range(1, len(t_span)):
        t_batch = t.expand(x.size(0))
        dphi_dt = model.estimator(x, mu_eval, t_batch)
        x = x + dt * dphi_dt
        t = t + dt
        if step < len(t_span) - 1:
            # Re-derive dt from the grid to avoid accumulating float drift in t.
            dt = t_span[step + 1] - t
        if step in snap_at:
            trajectories.append(x.squeeze(-1).cpu().numpy().copy())

# ---- plot  ------------------------------------------------------------------
fig, axes = plt.subplots(1, 7, figsize=(21, 3))
# No manual y offset: fig.tight_layout() below reserves room for the suptitle
# (matplotlib >= 3.4 is already required by the spines[[...]] call further down).
fig.suptitle(
    "OT-CFM: Gaussian → 8 Gaussians  (conditional on cluster label)",
    fontsize=13,
)

times = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, "target"]
colors = ["#636EFA", "#7A89FB", "#9BA4FC", "#BCBFFD", "#DDDAFE", "#EF553B", "#00CC96"]


def _style_panel(ax, title):
    """Apply the shared square limits and title, and hide the axes frame."""
    ax.set_xlim(-3.5, 3.5)
    ax.set_ylim(-3.5, 3.5)
    ax.set_title(title, fontsize=10)
    ax.axis("off")  # also hides tick labels / axis labels, so none are set


# Trajectory snapshots go in every panel but the last; the last is reserved for
# the ground-truth target (previously this relied on zip() truncating silently).
for ax, traj, t_val, c in zip(axes[:-1], trajectories, times, colors):
    ax.scatter(traj[:, 0], traj[:, 1], s=4, alpha=0.6, color=c, linewidths=0)
    _style_panel(ax, f"t = {t_val}")

# last panel: overlay ground-truth
gt = goal_dist[:1000].cpu().numpy()
axes[-1].scatter(gt[:, 0], gt[:, 1], s=4, alpha=0.3, color=colors[-1], linewidths=0)
_style_panel(axes[-1], times[-1])

# loss curve panel
window = 50
smoothed = np.convolve(losses, np.ones(window) / window, mode="valid")
# smoothed[i] averages losses[i : i + window], i.e. the window ending at epoch
# i + window (1-based), so plot against those epochs instead of 0-based indices.
fig2, ax2 = plt.subplots(figsize=(7, 3))
ax2.plot(
    np.arange(window, window + len(smoothed)),
    smoothed,
    linewidth=1.2,
    color="#636EFA",
)
ax2.set_xlabel("Epoch")
ax2.set_ylabel("MSE Loss")
ax2.set_title(f"CFM Training Loss ({window}-epoch moving avg)")
ax2.spines[["top", "right"]].set_visible(False)

# plt.tight_layout() would only affect the current figure (fig2); lay out both.
fig.tight_layout()
fig2.tight_layout()
fig.savefig("cfm_trajectories_imported.png", dpi=130, bbox_inches="tight")
fig2.savefig("cfm_loss_imported.png", dpi=130, bbox_inches="tight")
print("Saved cfm_trajectories_imported.png and cfm_loss_imported.png")