flying101 commited on
Commit
a6687d5
·
verified ·
1 Parent(s): c49e452

Upload flowmatching.ipynb

Browse files

Pytorch implementation of flow matching for generative modeling on simple 2d sampling points. Amazing walkthrough video by Outlier

Files changed (1) hide show
  1. flowmatching.ipynb +170 -0
flowmatching.ipynb ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "#creating a simple sample of points\n",
10
+ "import numpy as np\n",
11
+ "import matplotlib.pyplot as plt\n",
12
+ "import math\n",
13
+ "import tqdm\n",
14
+ "import torch\n",
15
+ "from torch import nn\n",
16
+ "from matplotlib.colors import ListedColormap\n",
17
+ "\n",
18
+ "N = 1000 #number of points to sample\n",
19
+ "x_min, x_max = -4, 4\n",
20
+ "y_min, y_max = -4, 4\n",
21
+ "resolution = 100 #resolution of the grid\n",
22
+ "\n",
23
+ "x = np.linspace(x_min, x_max, resolution)\n",
24
+ "y = np.linspace(y_min, y_max, resolution)\n",
25
+ "X, Y = np.meshgrid(x, y)\n",
26
+ "\n",
27
+ "length = 4\n",
28
+ "checkerboard = np.indices((length, length)).sum(axis=0) % 2\n",
29
+ "\n",
30
+ "sampled_points = []\n",
31
+ "while len(sampled_points) < N:\n",
32
+ " x_sample = np.random.uniform(x_min, x_max)\n",
33
+ " y_sample = np.random.uniform(y_min, y_max)\n",
34
+ "\n",
35
+ " i = int((x_sample - x_min) / (x_max - x_min) * length)\n",
36
+ " j = int((y_sample - y_min) / (y_max - y_min) * length)\n",
37
+ "\n",
38
+ " if checkerboard[j, i] == 1:\n",
39
+ " sampled_points.append((x_sample, y_sample))\n",
40
+ "sampled_points = np.array(sampled_points) #sampled points is our x1"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "t = 0.5\n",
50
+ "noise = np.random.randn(N, 2)\n",
51
+ "plt.figure(figsize=(6, 6))\n",
52
+ "plt.scatter(sampled_points[:, 0], sampled_points[:, 1], color=\"red\", marker=\"o\")\n",
53
+ "plt.scatter(noise[:, 0], noise[:, 1], color=\"blue\", marker=\"o\")\n",
54
+ "plt.scatter((1 - t) * noise[:, 0] + t * sampled_points[:, 0], (1 - t) * noise[:, 1] + t * sampled_points[:, 1], color=\"green\", marker=\"o\")\n",
55
+ "plt.show()"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": null,
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "#Model\n",
65
+ "class Block(nn.Module):\n",
66
+ " def __init__(self, channels=512):\n",
67
+ " super().__init__()\n",
68
+ " self.ff = nn.Linear(channels, channels)\n",
69
+ " self.act = nn.ReLU()\n",
70
+ "\n",
71
+ " def forward(self, x):\n",
72
+ " return self.act(self.ff(x))\n",
73
+ "\n",
74
+ "class MLP(nn.Module):\n",
75
+ " def __init__(self, channels_data=2, layers=5, channels=512, channels_t=512):\n",
76
+ " super().__init__()\n",
77
+ " self.channels_t = channels_t\n",
78
+ " self.in_projection = nn.Linear(channels_data, channels)\n",
79
+ " self.t_projection = nn.Linear(channels_t, channels)\n",
80
+ " self.blocks = nn.Sequential(*[\n",
81
+ " Block(channels) for _ in range(layers)\n",
82
+ " ])\n",
83
+ " self.out_projection = nn.Linear(channels, channels_data)\n",
84
+ "\n",
85
+ " def gen_t_embedding(self, t, max_positions=10000):\n",
86
+ " t = t * max_positions\n",
87
+ " half_dim = self.channels_t // 2\n",
88
+ " emb = math.log(max_positions) / (half_dim - 1)\n",
89
+ " emb = torch.arange(half_dim, device=t.device).float().mul(-emb).exp()\n",
90
+ " emb = t[:, None] * emb[None, :]\n",
91
+ " emb = torch.cat([emb.sin(), emb.cos()], dim=1)\n",
92
+ " if self.channels_t % 2 == 1: # zero pad\n",
93
+ " emb = nn.functional.pad(emb, (0, 1), mode='constant')\n",
94
+ " return emb\n",
95
+ "\n",
96
+ " def forward(self, x, t):\n",
97
+ " x = self.in_projection(x)\n",
98
+ " t = self.gen_t_embedding(t)\n",
99
+ " t = self.t_projection(t)\n",
100
+ " x = x + t \n",
101
+ " x = self.blocks(x)\n",
102
+ " x = self.out_projection(x)\n",
103
+ " return x"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": null,
109
+ "metadata": {},
110
+ "outputs": [],
111
+ "source": [
112
+ "model = MLP(layers=5, channels=512)\n",
113
+ "optim = torch.optim.AdamW(model.parameters(), lr=1e-4)\n",
114
+ "\n",
115
+ "data = torch.Tensor(sampled_points)\n",
116
+ "training_steps = 100_000\n",
117
+ "batch_size = 64\n",
118
+ "pbar = tqdm.tqdm(range(training_steps))\n",
119
+ "losses = []\n",
120
+ "for i in pbar:\n",
121
+ " x1 = data[torch.randint(data.size(0), (batch_size,))]\n",
122
+ " x0 = torch.randn_like(x1)\n",
123
+ " target = x1 - x0\n",
124
+ " t = torch.rand(x1.size(0))\n",
125
+ " xt = (1 - t[:, None]) * x0 + t[:, None] * x1\n",
126
+ " pred = model(xt, t) # also add t here\n",
127
+ " loss = ((target - pred)**2).mean()\n",
128
+ " loss.backward()\n",
129
+ " optim.step()\n",
130
+ " optim.zero_grad()\n",
131
+ " pbar.set_postfix(loss=loss.item())\n",
132
+ " losses.append(loss.item())"
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "code",
137
+ "execution_count": null,
138
+ "metadata": {},
139
+ "outputs": [],
140
+ "source": [
141
+ "#Sampling\n",
142
+ "torch.manual_seed(42)\n",
143
+ "model.eval().requires_grad_(False)\n",
144
+ "### from here\n",
145
+ "xt = torch.randn(1000, 2)\n",
146
+ "steps = 1000\n",
147
+ "plot_every = 100\n",
148
+ "for i, t in enumerate(torch.linspace(0, 1, steps), start=1):\n",
149
+ " pred = model(xt, t.expand(xt.size(0)))\n",
150
+ " xt = xt + (1 / steps) * pred\n",
151
+ "## to here, this is the sampling logic, and it in this case its moving random noise points into an organized checkerboard\n",
152
+ "##BUT, this sampling is literally applied anywhere from images to videos, because the goal is to move each noise sample to the specific location and modification\n",
153
+ " if i % plot_every == 0:\n",
154
+ " plt.figure(figsize=(6, 6))\n",
155
+ " plt.scatter(sampled_points[:, 0], sampled_points[:, 1], color=\"red\", marker=\"o\")\n",
156
+ " plt.scatter(xt[:, 0], xt[:, 1], color=\"green\", marker=\"o\")\n",
157
+ " plt.show()\n",
158
+ "model.train().requires_grad_(True)"
159
+ ]
160
+ }
161
+ ],
162
+ "metadata": {
163
+ "language_info": {
164
+ "name": "python"
165
+ },
166
+ "orig_nbformat": 4
167
+ },
168
+ "nbformat": 4,
169
+ "nbformat_minor": 2
170
+ }