import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np


# --- 1. The Neural Vector Field ---
# A small MLP mapping a state-time triple (x, y, t) to a 2-D velocity (vx, vy).
class VectorField(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(3, 64),
            nn.Tanh(),
            nn.Linear(64, 64),
            nn.Tanh(),
            nn.Linear(64, 2),  # output: (vx, vy)
        )

    def forward(self, x, t):
        """Evaluate the velocity field at positions x (B, 2) and times t.

        t may arrive as a 0-d scalar tensor, a 1-D batch, or an already-shaped
        (B, 1) column; it is normalised to (B, 1) before concatenation.
        """
        if t.dim() == 0:
            t = t.expand(x.shape[0], 1)
        elif t.dim() == 1:
            t = t.view(-1, 1)
        return self.net(torch.cat([x, t], dim=1))


# --- 2. Setup Data and Training ---
model = VectorField()
optimizer = optim.Adam(model.parameters(), lr=1e-3)


def sample_data(batch_size):
    """Draw from the target: an even mixture of two Gaussian blobs
    centered at (-2, -2) and (2, 2) with std 0.5."""
    which = torch.randint(0, 2, (batch_size,))
    centers = torch.tensor([[-2., -2.], [2., 2.]])
    return centers[which] + 0.5 * torch.randn(batch_size, 2)


def sample_source(batch_size):
    """Draw from the source: a standard 2-D Gaussian at the origin."""
    return torch.randn(batch_size, 2)


# --- 3. The Flow Matching Training Loop ---
print("Training Flow Matching Model...")
for step in range(2000):
    batch_size = 256

    # Sample path endpoints and a uniform time for each pair.
    x0 = sample_source(batch_size)  # noise endpoint
    x1 = sample_data(batch_size)    # data endpoint
    t = torch.rand(batch_size, 1)   # t ~ U[0, 1]

    # Linear interpolant x_t = (1 - t) x0 + t x1 and its conditional
    # velocity u_t = x1 - x0 (the flow-matching regression target).
    x_t = (1 - t) * x0 + t * x1
    u_t = x1 - x0

    # MSE between the predicted and target velocity fields.
    loss = nn.functional.mse_loss(model(x_t, t), u_t)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 500 == 0:
        print(f"Step {step}: Loss = {loss.item():.4f}")
# --- 4. Inference: Solving the ODE ---
# Generate samples by integrating dx/dt = v(x, t) from t=0 to t=1 with a
# fixed-step Euler solver, starting from source (noise) samples.
print("\nSampling (solving ODE)...")

model.eval()  # no-op for this Tanh MLP, but correct inference practice
with torch.no_grad():
    x = sample_source(1000)  # start from noise at t=0
    n_steps = 100
    dt = 1.0 / n_steps
    # Integer step counter instead of np.arange(0, 1, dt): float-step arange
    # is fragile (accumulated rounding can change the element count).
    # Times visited are i * dt for i = 0 .. n_steps - 1, same as before.
    for i in range(n_steps):
        t_tensor = torch.full((x.shape[0], 1), i * dt)
        velocity = model(x, t_tensor)
        x = x + velocity * dt  # Euler update

# Visualization
final_samples = x.numpy()
plt.figure(figsize=(6, 6))
plt.scatter(final_samples[:, 0], final_samples[:, 1], s=10, alpha=0.6, label="Generated")
plt.title("Flow Matching Output (Approx. Data Dist.)")
plt.legend()  # bug fix: the scatter label was never displayed without this
plt.grid(True)
plt.tight_layout()
plt.savefig("flow_matching_output.png")
plt.close()