asdf98 committed on
Commit
4f8b3e5
·
verified ·
1 Parent(s): 9bfb518

Add run_demo.py

Browse files
Files changed (1) hide show
  1. run_demo.py +201 -0
run_demo.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ MicroForge End-to-End Demo Script
4
+ Runs the full notebook content as pure Python (no Jupyter magic).
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import time
10
+ import os
11
+ import sys
12
+
# Ensure we can import microforge: put /app at the front of the module
# search path before any `microforge` import is attempted.
sys.path.insert(0, '/app')

print("=" * 70)
print("🔨 MicroForge: End-to-End Architecture Demo")
print("=" * 70)

# Report which accelerator PyTorch can see. Informational only: the pipeline
# further down in this script is constructed with device='cpu' explicitly.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Device: {device}')
22
+
# ── 1. Import all modules ──
# These imports intentionally come after the sys.path adjustment earlier in
# the script, so they cannot be hoisted to the top of the file.
from microforge.vae import MicroForgeVAE
from microforge.backbone import MicroForgeBackbone
from microforge.planner import RecurrentLatentPlanner
from microforge.pipeline import MicroForgePipeline, SimpleTextEncoder
from microforge.training import MicroForgeTrainer, FlowMatchingScheduler, MicroForgeLoss

print("✓ All modules imported")
31
+
# ── 2. Test VAE configs ──
# For each named config: build the VAE, count its parameters, and run one
# forward pass on a random 256x256 RGB batch to report the latent shape.
print("\n── VAE Configurations ──")
for config in ['tiny', 'small', 'base']:
    vae = MicroForgeVAE(config=config)
    params = sum(p.numel() for p in vae.parameters())
    x = torch.randn(1, 3, 256, 256)
    # Shape probe only, so skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        x_recon, mu, logvar = vae(x)
    # params*2 bytes: two bytes per parameter when stored as fp16.
    print(f" {config:>5}: {params:>12,} params | latent {mu.shape} | {params*2/1e6:.0f} MB fp16")
    # Drop the model before the next (larger) config is instantiated.
    del vae
41
+
# ── 3. Test Backbone configs ──
# For each named config: build the backbone, count parameters, and time a
# single forward pass on random inputs.
print("\n── Backbone Configurations ──")
for config_name in ['tiny', 'small', 'base']:
    # The 'tiny' config is paired with 16 latent channels, the others with 32.
    lc = 16 if config_name == 'tiny' else 32
    bb = MicroForgeBackbone(latent_channels=lc, config=config_name)
    params = sum(p.numel() for p in bb.parameters())

    # Build every input up front so tensor allocation is not part of the
    # measured interval.
    z = torch.randn(1, lc, 8, 8)
    timestep = torch.rand(1)
    text_tokens = torch.randn(1, 10, 768)
    pooled_text = torch.randn(1, 768)

    # perf_counter is the monotonic, high-resolution clock meant for measuring
    # short intervals; no_grad keeps autograd overhead out of the timing.
    with torch.no_grad():
        t0 = time.perf_counter()
        bb(z, timestep, text_tokens, pooled_text)
        ms = (time.perf_counter() - t0) * 1000

    print(f" {config_name:>5}: {params:>12,} params | {ms:.0f}ms | {params*2/1e6:.0f} MB fp16")
    del bb
54
+
# ── 4. Planner test ──
# Build a planner, report its size, then run three planner updates on random
# inputs and print the norms of the resulting plan and output tensors.
print("\n── Recurrent Latent Planner ──")
planner = RecurrentLatentPlanner(num_plan_tokens=32, dim=384, text_dim=768, latent_channels=32)
params = sum(p.numel() for p in planner.parameters())
print(f" Params: {params:,} | Plan state: {planner.get_plan_size_bytes()} bytes")

text_pooled = torch.randn(1, 768)
# Nothing here is trained, so run the whole exercise without autograd.
with torch.no_grad():
    plan = planner.initialize_plan(text_pooled, 1)
    for step in range(3):
        img = torch.randn(1, 64, 32)
        t_emb = torch.randn(1, 384)
        plan, out = planner(img, plan, t_emb)
        # NOTE(review): the plan is re-initialised from the previous plan after
        # every update; presumably this is how state is carried between steps.
        # Confirm against RecurrentLatentPlanner.initialize_plan.
        plan = planner.initialize_plan(text_pooled, 1, prev_plan=plan)
        print(f" Step {step}: plan_norm={plan.norm():.2f}, out_norm={out.norm():.2f}")
del planner
70
+
# ── 5. Full Pipeline ──
# Assemble the 'tiny' VAE, backbone, planner and a small text encoder into one
# pipeline, then print its parameter breakdown and a memory estimate.
print("\n── Full Pipeline Assembly ──")
vae = MicroForgeVAE(config='tiny')
backbone = MicroForgeBackbone(latent_channels=16, config='tiny')
planner = RecurrentLatentPlanner(num_plan_tokens=16, dim=256, text_dim=768, latent_channels=16)
text_enc = SimpleTextEncoder(vocab_size=8192, embed_dim=768, num_layers=2)

# The pipeline is pinned to the CPU regardless of the device reported earlier.
pipeline = MicroForgePipeline(vae, backbone, text_enc, planner, device='cpu')
params = pipeline.count_parameters()
print(f" Total params: {params['total']:,}")
# Every entry other than the aggregate 'total' is printed as its own line.
for component, count in params.items():
    if component != 'total':
        print(f" {component}: {count:,}")

mem = pipeline.get_memory_estimate(512, 512)
print(f" Est. inference @512px: {mem['estimated_inference_mb']:.0f} MB")
87
+
# ── 6. Text2Img ──
# Generate one 128x128 image from ten random token ids and report how long
# the call took plus the value range of the result.
print("\n── Text-to-Image Generation ──")
# Token ids are drawn from [0, 8192), matching the text encoder's vocab_size.
tokens = torch.randint(0, 8192, (1, 10))
# perf_counter is monotonic and high-resolution, unlike the wall clock.
t0 = time.perf_counter()
images = pipeline.text2img(tokens, height=128, width=128, num_steps=4, cfg_scale=1.0, seed=42)
ms = (time.perf_counter() - t0) * 1000
print(f" Generated {images.shape} in {ms:.0f}ms")
print(f" Range: [{images.min():.2f}, {images.max():.2f}]")
96
+
# ── 7. Training Demo ──
# Fresh 'tiny' modules used only by the two training stages below and by the
# checkpoint written at the end of the script.
print("\n── Training Pipeline Demo ──")
vae_train = MicroForgeVAE(config='tiny')
bb_train = MicroForgeBackbone(latent_channels=16, config='tiny')
pl_train = RecurrentLatentPlanner(num_plan_tokens=16, dim=256, text_dim=768, latent_channels=16)

# VAE training: 20 optimisation steps on random tensors standing in for images.
print(" Stage 1: VAE Training")
vae_train.train()
vae_opt = torch.optim.AdamW(vae_train.parameters(), lr=1e-4)
loss_fn = MicroForgeLoss(lambda_kl=1e-6)
for i in range(20):
    imgs = torch.randn(4, 3, 128, 128) * 0.5
    x_recon, mu, logvar = vae_train(imgs)
    losses = loss_fn.vae_loss(x_recon, imgs, mu, logvar)
    vae_opt.zero_grad()
    losses['total'].backward()
    # Clip the global gradient norm to 2.0 before the optimiser step.
    torch.nn.utils.clip_grad_norm_(vae_train.parameters(), 2.0)
    vae_opt.step()
    # Log the reconstruction loss every fifth step.
    if i % 5 == 0:
        print(f" Step {i:3d}: recon={losses['recon'].item():.4f}")
118
+
# Backbone training: 20 trainer steps on random images and random text
# embeddings, with the stage-1 VAE switched to eval mode first.
print(" Stage 2: Backbone Flow Matching")
vae_train.eval()
trainer = MicroForgeTrainer(vae_train, bb_train, pl_train, lr=1e-4, use_ema=True)
for i in range(20):
    imgs = torch.randn(2, 3, 128, 128) * 0.5
    text_emb = torch.randn(2, 10, 768)
    text_pooled = torch.randn(2, 768)
    losses = trainer.train_step(imgs, text_emb, text_pooled)
    # Log the flow loss every fifth step.
    if i % 5 == 0:
        print(f" Step {i:3d}: flow={losses['flow']:.2f}")
130
+
# ── 8. Editing pathway ──
# Feed the same backbone an 8x8 latent and an 8x16 latent and print the
# input and output shapes for each.
print("\n── Editing Pathway Test ──")
bb = MicroForgeBackbone(latent_channels=16, config='tiny')
z_gen = torch.randn(1, 16, 8, 8)
# NOTE(review): the editing latent is twice as wide as the generation latent;
# presumably two latents are laid side by side for editing. Confirm against
# the backbone's documentation.
z_edit = torch.randn(1, 16, 8, 16)
t = torch.rand(1)
te = torch.randn(1, 5, 768)
tp = torch.randn(1, 768)

# Only the output shapes are inspected, so autograd is not needed.
with torch.no_grad():
    v_gen = bb(z_gen, t, te, tp)
    v_edit = bb(z_edit, t, te, tp)
print(f" Generation: {z_gen.shape} -> {v_gen.shape}")
print(f" Editing: {z_edit.shape} -> {v_edit.shape}")
144
+
# ── 9. Staged freeze/thaw ──
# Separate 'tiny' modules whose requires_grad flags are toggled below to show
# how many parameters are trainable in each stage.
print("\n── Staged Training Config ──")
vae_s = MicroForgeVAE(config='tiny')
bb_s = MicroForgeBackbone(latent_channels=16, config='tiny')
pl_s = RecurrentLatentPlanner(num_plan_tokens=16, dim=256, text_dim=768, latent_channels=16)
150
+
def count_t(m: nn.Module) -> int:
    """Return the number of trainable scalar parameters in ``m``.

    Sums ``numel()`` over the parameters of ``m`` whose ``requires_grad``
    flag is set; frozen parameters contribute nothing.
    """
    return sum(p.numel() for p in m.parameters() if p.requires_grad)
def freeze(m: nn.Module) -> None:
    """Turn off gradient tracking for every parameter of ``m``, in place."""
    for p in m.parameters():
        p.requires_grad_(False)
def unfreeze(m: nn.Module) -> None:
    """Turn on gradient tracking for every parameter of ``m``, in place."""
    for p in m.parameters():
        p.requires_grad_(True)
156
+
# Stage 1: only the VAE is trainable.
freeze(bb_s)
freeze(pl_s)
unfreeze(vae_s)
print(f" Stage 1 (VAE only): {count_t(vae_s):,} trainable")

# Stage 2: the VAE is frozen; backbone and planner are trainable.
freeze(vae_s)
unfreeze(bb_s)
unfreeze(pl_s)
print(f" Stage 2 (Backbone+Plan): {count_t(bb_s)+count_t(pl_s):,} trainable")

# Joint stage: everything is trainable. The label jumps from 2 to 5; the
# intermediate stages are not exercised by this script.
unfreeze(vae_s)
print(f" Stage 5 (Joint): {count_t(vae_s)+count_t(bb_s)+count_t(pl_s):,} trainable")
165
+
# ── 10. Architecture comparison ──
# Static reference table: every figure below is a hard-coded literal, none of
# it is measured by this script. Columns: model, params, VRAM, complexity.
print("\n── Architecture Comparison ──")
comparison = [
    ('SD-v1.5', '860M', '~3.4 GB', 'O(N²)'),
    ('SDXL', '2.6B', '~6.5 GB', 'O(N²)'),
    ('SANA-Sprint', '600M+2B', '~5.5 GB', 'O(N)'),
    ('SnapGen', '380M+2B', '~4 GB', 'O(N²)'),
    ('DreamLite', '389M+2B', '~4 GB', 'O(N²)'),
    ('MicroForge-tiny', '28M+text', '~0.2 GB', 'O(N)'),
    ('MicroForge-small', '114M+text', '~0.6 GB', 'O(N)'),
]
print(f" {'Model':>18} | {'Params':>12} | {'VRAM':>10} | {'Complexity':>10}")
print(" " + "-" * 60)
# Unpack each row by name instead of indexing so the columns are self-describing.
for model, n_params, vram, complexity in comparison:
    print(f" {model:>18} | {n_params:>12} | {vram:>10} | {complexity:>10}")
181
+
# ── 11. Save checkpoint ──
# Persist the weights of the modules from the training demo together with the
# hyper-parameters needed to rebuild them, then report the file size.
print("\n── Save Checkpoint ──")
# Single source of truth for the output location (previously spelled out twice).
ckpt_path = '/app/checkpoints/microforge_tiny_demo.pt'
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
ckpt = {
    'vae': vae_train.state_dict(),
    'backbone': bb_train.state_dict(),
    'planner': pl_train.state_dict(),
    # These values mirror the constructor arguments used for the *_train
    # modules in the training demo; keep them in sync by hand.
    'config': {
        'vae_config': 'tiny', 'backbone_config': 'tiny',
        'latent_channels': 16, 'plan_tokens': 16, 'plan_dim': 256,
    },
    'version': '0.1.0',
}
torch.save(ckpt, ckpt_path)
# Size is reported in decimal megabytes (bytes / 1e6).
size = os.path.getsize(ckpt_path) / 1e6
print(f" Saved: {size:.1f} MB")
198
+
# Closing banner. Reaching this point only means no step above raised; the
# script contains no explicit assertions.
print("\n" + "=" * 70)
print("✅ MicroForge End-to-End Demo Complete — All Tests Passed")
print("=" * 70)