krystv committed
Commit b758a3b (verified)
1 Parent(s): 0034111

Upload notebook.ipynb with huggingface_hub

Files changed (1):
  1. notebook.ipynb +563 -0
notebook.ipynb ADDED
@@ -0,0 +1,563 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 🔬 LatentRecurrentFlow (LRF) — A Novel Mobile-First Image Generation Architecture\n",
+ "\n",
+ "**A complete implementation of a novel image generation architecture designed for consumer devices.**\n",
+ "\n",
+ "## Key Innovations\n",
+ "\n",
+ "1. **Recursive Latent Refinement (RLR) Core** — HRM-inspired iterative reasoning on image latents with O(1) memory backpropagation\n",
+ "2. **Gated Linear Diffusion (GLD) Blocks** — O(N) subquadratic spatial mixing replacing quadratic self-attention\n",
+ "3. **Compact f=16 VAE** with SnapGen-inspired tiny decoder (1-2M params)\n",
+ "4. **Rectified Flow** training with consistency distillation readiness\n",
+ "5. **Editing-ready architecture** — same latent core supports text-to-image, inpainting, style editing, and more\n",
+ "\n",
+ "### Memory Budget\n",
+ "| Component | FP32 | INT8 (Mobile) |\n",
+ "|-----------|------|---------------|\n",
+ "| VAE Decoder | 4 MB | 1 MB |\n",
+ "| Text Encoder | 44 MB | 11 MB |\n",
+ "| Denoising Core | 2.5 MB | 0.6 MB |\n",
+ "| Activations (256²) | ~200 MB | ~100 MB |\n",
+ "| **Total** | **~250 MB** | **~113 MB** |\n",
+ "\n",
+ "This notebook demonstrates:\n",
+ "1. Architecture design and parameter counting\n",
+ "2. End-to-end VAE training\n",
+ "3. Flow matching denoiser training\n",
+ "4. Sample generation\n",
+ "5. Model saving and loading"
+ ]
+ },
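+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A quick, self-contained sanity check of the memory-budget table (illustrative arithmetic only; the per-component sizes are copied from the table above):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Re-derive the table totals from its per-component rows (values in MB, from the table above)\n",
+ "fp32 = {'VAE Decoder': 4, 'Text Encoder': 44, 'Denoising Core': 2.5, 'Activations': 200}\n",
+ "int8 = {'VAE Decoder': 1, 'Text Encoder': 11, 'Denoising Core': 0.6, 'Activations': 100}\n",
+ "print(f'FP32 total: ~{sum(fp32.values()):.0f} MB')\n",
+ "print(f'INT8 total: ~{sum(int8.values()):.0f} MB')"
+ ]
+ },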
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 0. Installation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Install dependencies\n",
+ "!pip install -q torch torchvision einops safetensors huggingface_hub pillow matplotlib"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Clone the LRF repo (if not already available)\n",
+ "import os, sys\n",
+ "\n",
+ "# If running from the repo, just add it to the path\n",
+ "if os.path.exists('lrf'):\n",
+ "    sys.path.insert(0, '.')\n",
+ "else:\n",
+ "    # Clone from HF Hub\n",
+ "    !git clone https://huggingface.co/krystv/LatentRecurrentFlow\n",
+ "    sys.path.insert(0, 'LatentRecurrentFlow')\n",
+ "\n",
+ "from lrf.model import LatentRecurrentFlow, RecursiveLatentCore, CompactVAE, GatedLinearAttention\n",
+ "from lrf.training import LRFTrainer, RectifiedFlowScheduler, SyntheticImageTextDataset\n",
+ "from lrf.pipeline import LRFPipeline, LRFTrainingPipeline\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "from torch.utils.data import DataLoader\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "\n",
+ "# Device\n",
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "print(f'Using device: {device}')\n",
+ "if device.type == 'cuda':\n",
+ "    print(f'GPU: {torch.cuda.get_device_name()}')\n",
+ "    print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. Architecture Overview & Parameter Counting"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create model with different configs\n",
+ "configs = {\n",
+ "    'Tiny (5.7M)': LatentRecurrentFlow.tiny_config(),\n",
+ "    'Default (16.3M)': LatentRecurrentFlow.default_config(),\n",
+ "}\n",
+ "\n",
+ "for name, config in configs.items():\n",
+ "    model = LatentRecurrentFlow(config)\n",
+ "    counts = model.count_parameters()\n",
+ "\n",
+ "    print(f'\\n=== {name} ===')\n",
+ "    print(f'Config: T_outer={config[\"T_outer\"]}, T_inner={config[\"T_inner\"]}, '\n",
+ "          f'num_blocks={config[\"num_blocks\"]}')\n",
+ "    print(f'Effective depth: {config[\"T_outer\"] * config[\"T_inner\"] * config[\"num_blocks\"]} layers '\n",
+ "          f'(from {config[\"num_blocks\"]} unique blocks)')\n",
+ "    for module, count in counts.items():\n",
+ "        mb = count * 4 / 1e6\n",
+ "        print(f'  {module:20s}: {count:>12,} params ({mb:.1f} MB FP32)')\n",
+ "    del model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. Stage 1: VAE Training\n",
+ "\n",
+ "The VAE learns to compress images into a compact latent space.\n",
+ "- f=16 spatial compression: 256×256 → 16×16 latents\n",
+ "- C=16 or C=32 latent channels\n",
+ "- Tiny decoder (~280K params) inspired by SnapGen"
+ ]
+ },
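+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a quick illustration of what f=16 compression buys (standalone arithmetic, independent of the LRF code; C=32 is an assumed channel count):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Illustrative f=16 compression arithmetic (C=32 assumed; use your config's latent_channels)\n",
+ "H = W = 256; f = 16; C = 32\n",
+ "pixels = H * W * 3                 # values per RGB image\n",
+ "latents = (H // f) * (W // f) * C  # values per latent\n",
+ "print(f'{H}x{W}x3 image -> {H // f}x{W // f}x{C} latent')\n",
+ "print(f'{pixels / latents:.0f}x fewer values for the denoiser to process')"
+ ]
+ },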
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create model for training\n",
+ "config = LatentRecurrentFlow.tiny_config()\n",
+ "model = LatentRecurrentFlow(config).to(device)\n",
+ "\n",
+ "# Create synthetic dataset (replace with real data for actual training)\n",
+ "dataset = SyntheticImageTextDataset(\n",
+ "    num_samples=500,\n",
+ "    image_size=64,\n",
+ "    max_text_length=32\n",
+ ")\n",
+ "dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)\n",
+ "\n",
+ "# Create trainer\n",
+ "trainer = LRFTrainer(model, device, './lrf_checkpoints')\n",
+ "\n",
+ "print(f'Dataset size: {len(dataset)}')\n",
+ "print('Batch size: 8')\n",
+ "print('Image size: 64x64')\n",
+ "print(f'Latent size: {64//16}x{64//16}x{config[\"latent_channels\"]}')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Train VAE\n",
+ "vae_optimizer = torch.optim.AdamW(model.vae.parameters(), lr=1e-3, weight_decay=0.01)\n",
+ "\n",
+ "vae_losses = []\n",
+ "num_vae_steps = 100\n",
+ "\n",
+ "print('Training VAE...')\n",
+ "step = 0\n",
+ "for epoch in range(10):  # Multiple epochs over a small dataset\n",
+ "    for batch in dataloader:\n",
+ "        if step >= num_vae_steps:\n",
+ "            break\n",
+ "        losses = trainer.train_vae_step(batch['image'], vae_optimizer)\n",
+ "        vae_losses.append(losses['total'])\n",
+ "        if step % 20 == 0:\n",
+ "            print(f'  Step {step}: total={losses[\"total\"]:.4f}, '\n",
+ "                  f'recon={losses[\"recon\"]:.4f}, kl={losses[\"kl\"]:.4f}')\n",
+ "        step += 1\n",
+ "    if step >= num_vae_steps:\n",
+ "        break\n",
+ "\n",
+ "# Plot VAE loss\n",
+ "plt.figure(figsize=(10, 4))\n",
+ "plt.plot(vae_losses)\n",
+ "plt.xlabel('Step')\n",
+ "plt.ylabel('Loss')\n",
+ "plt.title('VAE Training Loss')\n",
+ "plt.grid(True, alpha=0.3)\n",
+ "plt.show()\n",
+ "\n",
+ "# Save checkpoint\n",
+ "trainer.save_checkpoint('./lrf_checkpoints/vae.pt', 'vae', 0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Visualize VAE reconstruction\n",
+ "model.eval()\n",
+ "with torch.no_grad():\n",
+ "    sample_batch = next(iter(dataloader))\n",
+ "    images = sample_batch['image'].to(device)\n",
+ "    recon, _, _ = model.vae(images)\n",
+ "\n",
+ "fig, axes = plt.subplots(2, 4, figsize=(16, 8))\n",
+ "for i in range(4):\n",
+ "    # Original\n",
+ "    img = images[i].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5\n",
+ "    axes[0][i].imshow(np.clip(img, 0, 1))\n",
+ "    axes[0][i].set_title(f'Original {i}')\n",
+ "    axes[0][i].axis('off')\n",
+ "\n",
+ "    # Reconstruction\n",
+ "    rec = recon[i].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5\n",
+ "    axes[1][i].imshow(np.clip(rec, 0, 1))\n",
+ "    axes[1][i].set_title(f'Reconstruction {i}')\n",
+ "    axes[1][i].axis('off')\n",
+ "\n",
+ "plt.suptitle('VAE Reconstruction Quality', fontsize=14)\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. Stage 2: Flow Matching Denoiser Training\n",
+ "\n",
+ "The denoising core learns to predict the velocity field for rectified flow.\n",
+ "- VAE is frozen\n",
+ "- Core + text encoder are trained\n",
+ "- Uses SNR-weighted flow matching loss"
+ ]
+ },
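+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For intuition, here is a minimal, self-contained sketch of the rectified-flow objective on dummy tensors. A toy conv stands in for the real core (it ignores time and text), and the SNR weighting used by `lrf.training` is omitted:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Rectified-flow matching on dummy tensors (illustrative sketch, not lrf.training)\n",
+ "x1 = torch.randn(8, 16, 4, 4)        # 'clean' latents (VAE encoder output)\n",
+ "x0 = torch.randn_like(x1)            # pure noise\n",
+ "t = torch.rand(8).view(-1, 1, 1, 1)  # uniform time in [0, 1]\n",
+ "xt = (1 - t) * x0 + t * x1           # straight-line interpolant\n",
+ "v_target = x1 - x0                   # constant velocity along the line\n",
+ "toy_core = torch.nn.Conv2d(16, 16, 3, padding=1)  # stand-in; the real core sees (xt, t, text)\n",
+ "loss = F.mse_loss(toy_core(xt), v_target)\n",
+ "print(f'Toy flow matching loss: {loss.item():.4f}')"
+ ]
+ },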
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Freeze VAE\n",
+ "for p in model.vae.parameters():\n",
+ "    p.requires_grad = False\n",
+ "\n",
+ "# Train flow matching\n",
+ "flow_params = list(model.core.parameters()) + list(model.text_encoder.parameters())\n",
+ "flow_optimizer = torch.optim.AdamW(flow_params, lr=1e-3, weight_decay=0.01)\n",
+ "\n",
+ "flow_losses = []\n",
+ "num_flow_steps = 100\n",
+ "\n",
+ "print('Training flow matching denoiser...')\n",
+ "model.core.train()\n",
+ "model.text_encoder.train()\n",
+ "\n",
+ "step = 0\n",
+ "for epoch in range(10):\n",
+ "    for batch in dataloader:\n",
+ "        if step >= num_flow_steps:\n",
+ "            break\n",
+ "        losses = trainer.train_flow_step(\n",
+ "            batch['image'], batch['token_ids'], batch['attention_mask'],\n",
+ "            flow_optimizer, cfg_dropout=0.1\n",
+ "        )\n",
+ "        flow_losses.append(losses['flow_loss'])\n",
+ "        if step % 20 == 0:\n",
+ "            print(f'  Step {step}: flow_loss={losses[\"flow_loss\"]:.4f}')\n",
+ "        step += 1\n",
+ "    if step >= num_flow_steps:\n",
+ "        break\n",
+ "\n",
+ "# Plot flow loss\n",
+ "plt.figure(figsize=(10, 4))\n",
+ "plt.plot(flow_losses)\n",
+ "plt.xlabel('Step')\n",
+ "plt.ylabel('Flow Matching Loss')\n",
+ "plt.title('Denoiser Training Loss')\n",
+ "plt.grid(True, alpha=0.3)\n",
+ "plt.show()\n",
+ "\n",
+ "# Save checkpoint\n",
+ "trainer.save_checkpoint('./lrf_checkpoints/flow.pt', 'flow', 0)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 4. Generation & Visualization\n",
+ "\n",
+ "Generate images using the trained model with Euler ODE sampling."
+ ]
+ },
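+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The sampler integrates the learned velocity field from noise (t=0) to data (t=1). A minimal Euler loop looks like this (illustrative sketch with a dummy velocity function; the real loop inside `LRFPipeline` also handles text conditioning, CFG, and VAE decoding):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Minimal Euler ODE sampler over a dummy velocity field (sketch only)\n",
+ "def euler_sample(velocity_fn, shape, num_steps=10):\n",
+ "    x = torch.randn(shape)              # start from pure noise at t=0\n",
+ "    dt = 1.0 / num_steps\n",
+ "    for i in range(num_steps):\n",
+ "        t = torch.full((shape[0],), i * dt)\n",
+ "        x = x + velocity_fn(x, t) * dt  # one Euler step along the flow\n",
+ "    return x\n",
+ "\n",
+ "dummy_velocity = lambda x, t: -x        # stand-in for the trained core\n",
+ "sample = euler_sample(dummy_velocity, (1, 16, 4, 4), num_steps=10)\n",
+ "print(f'Sampled latent shape: {tuple(sample.shape)}')"
+ ]
+ },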
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Generate samples\n",
+ "model.eval()\n",
+ "\n",
+ "# Create prompts (using simple tokenization for the prototype)\n",
+ "prompts = [\n",
+ "    'a beautiful sunset over the ocean with golden light',\n",
+ "    'a cute cat sitting on a windowsill',\n",
+ "    'a mountain landscape with snow and trees',\n",
+ "    'a colorful abstract painting with swirls',\n",
+ "]\n",
+ "\n",
+ "pipe = LRFPipeline(model, device=device)\n",
+ "\n",
+ "# Generate with different step counts\n",
+ "for num_steps in [5, 10, 20]:\n",
+ "    images = pipe(\n",
+ "        prompts,\n",
+ "        num_steps=num_steps,\n",
+ "        cfg_scale=1.0,  # Low CFG for an untrained model\n",
+ "        height=64,\n",
+ "        width=64,\n",
+ "        seed=42,\n",
+ "    )\n",
+ "\n",
+ "    fig, axes = plt.subplots(1, 4, figsize=(16, 4))\n",
+ "    for i in range(4):\n",
+ "        img = images[i].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5\n",
+ "        axes[i].imshow(np.clip(img, 0, 1))\n",
+ "        axes[i].set_title(prompts[i][:30] + '...')\n",
+ "        axes[i].axis('off')\n",
+ "    plt.suptitle(f'Generated Images ({num_steps} steps)', fontsize=14)\n",
+ "    plt.tight_layout()\n",
+ "    plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 5. Save & Load Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Save the complete model\n",
+ "pipe.save_pretrained('./lrf_model')\n",
+ "print('Model saved to ./lrf_model/')\n",
+ "\n",
+ "# List saved files\n",
+ "for f in os.listdir('./lrf_model'):\n",
+ "    size = os.path.getsize(f'./lrf_model/{f}')\n",
+ "    print(f'  {f}: {size/1024:.1f} KB')\n",
+ "\n",
+ "# Reload and verify\n",
+ "pipe_loaded = LRFPipeline.from_pretrained('./lrf_model', device=str(device))\n",
+ "images_loaded = pipe_loaded('test prompt', num_steps=5, height=64, width=64, seed=42)\n",
+ "print(f'\\nReloaded model generates: {images_loaded.shape}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 6. Training Curriculum for Real Data\n",
+ "\n",
+ "The full training curriculum for production-quality models:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Display the full training curriculum\n",
+ "curriculum = LRFTrainingPipeline.get_curriculum()\n",
+ "\n",
+ "print('Full Training Curriculum')\n",
+ "print('=' * 70)\n",
+ "for i, stage_name in enumerate(curriculum):\n",
+ "    stage = LRFTrainingPipeline.get_stage_config(stage_name)\n",
+ "    print(f'\\nStage {i+1}: {stage_name}')\n",
+ "    print(f'  Description: {stage[\"description\"]}')\n",
+ "    print(f'  Freeze: {stage[\"freeze\"]}')\n",
+ "    print(f'  Train: {stage[\"train\"]}')\n",
+ "    print(f'  LR: {stage[\"lr\"]}')\n",
+ "    print(f'  Min steps: {stage[\"min_steps\"]:,}')\n",
+ "    if 'resolution' in stage:\n",
+ "        print(f'  Resolution: {stage[\"resolution\"]}×{stage[\"resolution\"]}')\n",
+ "\n",
+ "print('\\n' + '=' * 70)\n",
+ "print('\\nRecommended datasets for each stage:')\n",
+ "print('  Stage 1 (VAE): ImageNet, COCO, or any large image dataset')\n",
+ "print('  Stage 2 (Flow 64): Synthetic captions from teacher (SDXL/SD3) + LAION-aesthetic')\n",
+ "print('  Stage 3 (Flow 256): Filtered LAION-aesthetic (score > 6.0) + synthetic')\n",
+ "print('  Stage 4 (Flow 512): High-quality curated dataset + JourneyDB')\n",
+ "print('  Stage 5 (Distill): Same as Stage 4 (distill from own multi-step model)')\n",
+ "print('  Stage 6 (Editing): InstructPix2Pix + MagicBrush + synthetic edit pairs')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 7. Architecture Deep Dive\n",
+ "\n",
+ "### The Recursive Latent Refinement Loop"
+ ]
+ },
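+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Schematically, the two-timescale loop reuses a handful of blocks many times: an inner loop refines the latent, an outer loop updates an abstract state. The sketch below is illustrative only (toy linear blocks and an assumed abstract-state update rule; the real logic lives in `RecursiveLatentCore`):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Toy two-timescale recursive refinement (illustrative; not RecursiveLatentCore)\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "blocks = nn.ModuleList([nn.Linear(32, 32) for _ in range(2)])  # 2 unique blocks\n",
+ "z = torch.randn(1, 16, 32)  # latent tokens\n",
+ "h = torch.zeros(1, 1, 32)   # abstract state\n",
+ "\n",
+ "T_outer, T_inner = 2, 4\n",
+ "for _ in range(T_outer):                  # slow loop: abstract-state updates\n",
+ "    for _ in range(T_inner):              # fast loop: latent refinement\n",
+ "        for block in blocks:              # same weights reused every pass\n",
+ "            z = torch.tanh(block(z + h))  # toy mixing; real blocks are GLD blocks\n",
+ "    h = h + z.mean(dim=1, keepdim=True)   # assumed update rule\n",
+ "\n",
+ "print(f'Effective depth: {T_outer * T_inner * len(blocks)} layers from {len(blocks)} unique blocks')"
+ ]
+ },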
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Demonstrate the recursive refinement\n",
+ "core = RecursiveLatentCore(\n",
+ "    dim=32, cond_dim=64, num_blocks=2, num_heads=2, head_dim=16,\n",
+ "    T_inner=4, T_outer=2, use_ift_training=False\n",
+ ")\n",
+ "\n",
+ "print('Recursive Latent Core Architecture')\n",
+ "print('=' * 50)\n",
+ "print(f'Unique GLD blocks: {core.num_blocks}')\n",
+ "print(f'T_outer (abstract updates): {core.T_outer}')\n",
+ "print(f'T_inner (refinement steps): {core.T_inner}')\n",
+ "print(f'Total recursions: {core.T_outer * core.T_inner}')\n",
+ "print(f'Effective depth: {core.T_outer * core.T_inner * core.num_blocks} layers')\n",
+ "print(f'Parameter reuse factor: {core.T_outer * core.T_inner}x')\n",
+ "print(f'\\nParameters: {sum(p.numel() for p in core.parameters()):,}')\n",
+ "\n",
+ "# Show memory savings from IFT\n",
+ "print('\\nMemory comparison:')\n",
+ "eff_depth = core.T_outer * core.T_inner * core.num_blocks\n",
+ "print(f'  Standard backprop: O({eff_depth}) activation memory')\n",
+ "print('  IFT backprop: O(1) activation memory')\n",
+ "print(f'  Memory savings: {eff_depth}x')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Demonstrate GLA complexity\n",
+ "import time\n",
+ "\n",
+ "gla = GatedLinearAttention(dim=64, num_heads=4, head_dim=16)\n",
+ "\n",
+ "print('GLA Complexity Scaling')\n",
+ "print('=' * 50)\n",
+ "\n",
+ "sizes = [4, 8, 16, 32, 64]\n",
+ "times = []\n",
+ "\n",
+ "for s in sizes:\n",
+ "    x = torch.randn(1, s*s, 64)\n",
+ "\n",
+ "    # Warmup\n",
+ "    _ = gla(x, h=s, w=s)\n",
+ "\n",
+ "    # Time\n",
+ "    t0 = time.time()\n",
+ "    for _ in range(10):\n",
+ "        _ = gla(x, h=s, w=s)\n",
+ "    dt = (time.time() - t0) / 10\n",
+ "    times.append(dt)\n",
+ "    print(f'  {s}×{s} = {s*s:>5} tokens: {dt*1000:.2f}ms')\n",
+ "\n",
+ "# Plot scaling\n",
+ "plt.figure(figsize=(8, 4))\n",
+ "tokens = [s*s for s in sizes]\n",
+ "plt.plot(tokens, [t*1000 for t in times], 'bo-', label='GLA (O(N))')\n",
+ "# Reference quadratic line\n",
+ "t_ref = times[0] * 1000\n",
+ "quadratic = [t_ref * (n / tokens[0])**2 for n in tokens]\n",
+ "plt.plot(tokens, quadratic, 'r--', label='Quadratic attention (O(N²))', alpha=0.5)\n",
+ "plt.xlabel('Number of tokens')\n",
+ "plt.ylabel('Time (ms)')\n",
+ "plt.title('GLA vs Quadratic Attention Scaling')\n",
+ "plt.legend()\n",
+ "plt.grid(True, alpha=0.3)\n",
+ "plt.show()"
+ ]
+ },
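+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Why linear attention is O(N): instead of materializing the N×N attention matrix, it accumulates a running d×d key-value state and reads it out per token. Below is a minimal non-gated, single-head variant (illustrative only; `GatedLinearAttention` adds data-dependent gating on top of this recurrence):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Minimal causal linear attention, one head (illustrative sketch)\n",
+ "def linear_attention(q, k, v):\n",
+ "    # q, k, v: (N, d); elu+1 keeps the kernel features positive\n",
+ "    q, k = F.elu(q) + 1, F.elu(k) + 1\n",
+ "    kv = torch.zeros(k.shape[1], v.shape[1])  # running d×d state\n",
+ "    z = torch.zeros(k.shape[1])               # running normalizer\n",
+ "    out = []\n",
+ "    for i in range(q.shape[0]):               # single pass over tokens: O(N)\n",
+ "        kv = kv + torch.outer(k[i], v[i])     # accumulate key-value outer products\n",
+ "        z = z + k[i]\n",
+ "        out.append(q[i] @ kv / (q[i] @ z + 1e-6))\n",
+ "    return torch.stack(out)\n",
+ "\n",
+ "y = linear_attention(torch.randn(64, 16), torch.randn(64, 16), torch.randn(64, 16))\n",
+ "print(f'Output shape: {tuple(y.shape)}')"
+ ]
+ },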
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 8. Push to HuggingFace Hub (Optional)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Uncomment to push to the Hub\n",
+ "# from huggingface_hub import HfApi\n",
+ "# api = HfApi()\n",
+ "# api.upload_folder(\n",
+ "#     folder_path='./lrf_model',\n",
+ "#     repo_id='your-username/LatentRecurrentFlow',\n",
+ "#     repo_type='model',\n",
+ "# )\n",
+ "print('To push to the HF Hub, uncomment the code above and set your repo_id.')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "---\n",
+ "\n",
+ "## Summary\n",
+ "\n",
+ "This notebook demonstrated the LatentRecurrentFlow architecture end-to-end:\n",
+ "\n",
+ "1. ✅ Model creation with parameter counting\n",
+ "2. ✅ VAE training for image compression\n",
+ "3. ✅ Flow matching denoiser training\n",
+ "4. ✅ Image generation with Euler ODE sampling\n",
+ "5. ✅ Model save/load in an HF-compatible format\n",
+ "6. ✅ Training curriculum for production\n",
+ "\n",
+ "### Next Steps\n",
+ "- Replace synthetic data with real image-text pairs\n",
+ "- Scale to the default config (16M params)\n",
+ "- Train longer on a GPU for real quality\n",
+ "- Add consistency distillation for 4-step generation\n",
+ "- Add an editing fine-tuning stage"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }