Nekochu commited on
Commit
563ef1b
·
verified ·
1 Parent(s): 10d0786

Revert to PyTorch INT8 (ONNX export produces NaN)

Browse files
Files changed (1) hide show
  1. app.py +63 -165
app.py CHANGED
@@ -1,90 +1,80 @@
1
- """FE2E: Depth + Normal estimation from a single image (CPU, ONNX INT8 DiT + PyTorch VAE)"""
2
  from __future__ import annotations
3
 
4
  import gc
5
- import math
6
  import os
7
- import shutil
8
  import sys
9
  import time
10
 
11
- import numpy as np
12
  import torch
13
- from einops import rearrange, repeat
14
- from PIL import Image
15
- from torchvision.transforms import functional as TF
16
- import torch.nn.functional as Func
17
 
18
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
19
 
20
  MODELS_DIR = "/tmp/fe2e_models"
21
- ONNX_DIR = os.path.join(MODELS_DIR, "onnx")
22
- os.makedirs(ONNX_DIR, exist_ok=True)
23
 
24
- EMPTY_PROMPT_CACHE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "latent", "no_info.npz")
25
 
 
 
 
 
 
26
 
27
- def _download(repo, filename, dest_dir, token=None):
 
 
28
  from huggingface_hub import hf_hub_download
29
  basename = os.path.basename(filename)
30
- dest = os.path.join(dest_dir, basename)
31
  if not os.path.exists(dest):
32
  print(f"[init] Downloading {repo}/{filename}...")
33
  src = hf_hub_download(repo, filename, token=token)
34
  shutil.copy2(src, dest)
35
- size_mb = os.path.getsize(dest) / 1024 / 1024
36
- print(f"[init] {basename}: {size_mb:.0f} MB")
37
  return dest
38
 
39
 
40
- def _load_all():
41
- import onnxruntime as ort
42
 
43
  token = os.environ.get("HF_TOKEN")
44
- repo = "WeReCooking2/FE2E-INT8"
 
45
 
46
- print("[init] Downloading ONNX INT8 DiT...")
47
- onnx_path = _download(repo, "onnx/dit_int8.onnx", ONNX_DIR, token)
48
- _download(repo, "onnx/dit_int8.onnx.data", ONNX_DIR, token)
 
 
 
49
 
50
- print("[init] Downloading VAE...")
51
- vae_path = _download(repo, "vae_full.pt", MODELS_DIR, token)
52
 
53
- print("[init] Creating ONNX Runtime session (mmap + low memory)...")
54
- t0 = time.time()
55
- opts = ort.SessionOptions()
56
- opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
57
- opts.inter_op_num_threads = 1
58
- opts.intra_op_num_threads = 2
59
- opts.enable_mem_pattern = True
60
- opts.enable_mem_reuse = True
61
- opts.add_session_config_entry("session.disable_prepacking", "1")
62
- dit_session = ort.InferenceSession(onnx_path, opts, providers=["CPUExecutionProvider"])
63
- print(f"[init] DiT session ready in {time.time() - t0:.0f}s")
64
-
65
- print("[init] Loading VAE...")
66
- vae = torch.load(vae_path, map_location="cpu", weights_only=False, mmap=True)
67
- vae.eval()
68
  gc.collect()
69
 
70
- print("[init] Loading empty prompt cache...")
71
- data = np.load(EMPTY_PROMPT_CACHE, allow_pickle=True)
72
- embeds = torch.from_numpy(data["embeds"]).unsqueeze(0)
73
- masks = torch.from_numpy(data["masks"]).unsqueeze(0)
 
 
 
 
 
74
 
75
- print("[init] Ready.")
76
- return dit_session, vae, embeds, masks
77
 
78
 
79
- print("[init] Loading models at startup...")
80
- DIT_SESSION, VAE, PROMPT_EMBEDS, PROMPT_MASKS = _load_all()
81
- print("[init] All models loaded, starting Gradio...")
82
 
83
 
84
  def generate(image):
85
- """Run depth + normal estimation on a single image."""
86
  import gradio as gr
87
- import matplotlib.cm as cm
88
 
89
  if image is None:
90
  raise gr.Error("Please upload an image.")
@@ -94,135 +84,43 @@ def generate(image):
94
  elif not isinstance(image, Image.Image):
95
  image = Image.fromarray(image).convert("RGB")
96
 
97
- orig_size = image.size
98
- print(f"[gen] Input: {orig_size}")
99
  t0 = time.time()
100
 
101
- # --- Resize to 1024x768 (matches ONNX model's fixed shape) ---
102
- image_resized = image.resize((1024, 768))
103
- img_tensor = TF.to_tensor(image_resized).unsqueeze(0) # [1, 3, 768, 1024]
104
- height, width = 768, 1024
105
-
106
- # --- VAE encode ---
107
- with torch.inference_mode():
108
- ref_latent = VAE.encode(img_tensor * 2 - 1) # [1, 16, 96, 128]
109
-
110
- # --- Prepare DiT inputs ---
111
- h_lat, w_lat = ref_latent.shape[2], ref_latent.shape[3] # 96, 128
112
- h_half, w_half = h_lat // 2, w_lat // 2 # 48, 64
113
- n_patches = h_half * w_half # 3072
114
-
115
- # Noise (zeros for single denoise)
116
- noise = torch.zeros(1, 16, h_lat, w_lat, dtype=torch.float32)
117
-
118
- # Rearrange to patches
119
- noise_patches = rearrange(noise, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
120
- ref_patches = rearrange(ref_latent, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
121
-
122
- # Concatenate noise + reference along sequence dim
123
- img = torch.cat([noise_patches, ref_patches], dim=1) # [1, 6144, 64]
124
- # Duplicate for CFG (conditional + unconditional)
125
- img = torch.cat([img, img], dim=0) # [2, 6144, 64]
126
-
127
- # Position IDs for noise patches
128
- img_ids_noise = torch.zeros(h_half, w_half, 3)
129
- img_ids_noise[..., 1] = torch.arange(h_half)[:, None].float()
130
- img_ids_noise[..., 2] = torch.arange(w_half)[None, :].float()
131
- img_ids_noise = rearrange(img_ids_noise, "h w c -> 1 (h w) c")
132
-
133
- # Position IDs for reference patches (same layout)
134
- img_ids_ref = torch.zeros(h_half, w_half, 3)
135
- img_ids_ref[..., 1] = torch.arange(h_half)[:, None].float()
136
- img_ids_ref[..., 2] = torch.arange(w_half)[None, :].float()
137
- img_ids_ref = rearrange(img_ids_ref, "h w c -> 1 (h w) c")
138
-
139
- img_ids = torch.cat([img_ids_noise, img_ids_ref], dim=1) # [1, 6144, 3]
140
- img_ids = img_ids.repeat(2, 1, 1) # [2, 6144, 3]
141
-
142
- # Text embeddings from cache
143
- txt = torch.cat([PROMPT_EMBEDS, PROMPT_EMBEDS], dim=0) # [2, 640, 3584]
144
- mask = torch.cat([PROMPT_MASKS, PROMPT_MASKS], dim=0) # [2, 640]
145
- txt_ids = torch.zeros(2, txt.shape[1], 3) # [2, 640, 3]
146
-
147
- # Timestep = 1.0 (single denoise step)
148
- t_vec = torch.full((2,), 1.0, dtype=torch.float32)
149
-
150
- # --- Run DiT via ONNX Runtime ---
151
- print("[gen] Running ONNX INT8 DiT inference...")
152
- t_dit = time.time()
153
-
154
- feeds = {
155
- "img": img.numpy().astype(np.float16),
156
- "img_ids": img_ids.numpy().astype(np.float16),
157
- "txt_ids": txt_ids.numpy().astype(np.float16),
158
- "timesteps": t_vec.numpy().astype(np.float16),
159
- "llm_embedding": txt.numpy().astype(np.float16),
160
- "t_vec": t_vec.numpy().astype(np.float16),
161
- "mask": mask.numpy().astype(np.float16),
162
- }
163
-
164
- pred_np = DIT_SESSION.run(None, feeds)[0] # [2, 6144, 64] float16
165
- pred = torch.from_numpy(pred_np).float()
166
-
167
- dit_elapsed = time.time() - t_dit
168
- print(f"[gen] DiT inference: {dit_elapsed:.1f}s")
169
-
170
- # --- CFG ---
171
- cfg_guidance = 6.0
172
- cond = pred[:1]
173
- uncond = pred[1:]
174
- pred_cfg = uncond + cfg_guidance * (cond - uncond) # [1, 6144, 64]
175
-
176
- # --- Apply single denoise step: img + (0 - 1) * pred = img - pred ---
177
- img_out = img[:1].float() + (0.0 - 1.0) * pred_cfg # [1, 6144, 64]
178
-
179
- # Split: first half is output (depth/normal latent), second half was reference
180
- output_patches = img_out[:, :n_patches] # [1, 3072, 64]
181
- pred1_ref = cond[:, n_patches:] # [1, 3072, 64] (reference prediction)
182
-
183
- # --- Unpack patches back to spatial ---
184
- depth_latent = rearrange(
185
- output_patches, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
186
- h=h_half, w=w_half, ph=2, pw=2
187
- ) # [1, 16, 96, 128]
188
-
189
- normal_latent = rearrange(
190
- pred1_ref, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
191
- h=h_half, w=w_half, ph=2, pw=2
192
- ) # [1, 16, 96, 128]
193
-
194
- # --- Unpack as the original code does (depth, normal from 2-panel layout) ---
195
- # Actually, looking at the original code more carefully:
196
- # Lpred, Rpred = unpack_latents(pred, h//16, w//16)
197
- # which splits the patches into two panels
198
- # But in single_denoise mode, the code uses denoise() not double_denoise()
199
- # denoise() returns: img[:, :seq_len//2], pred1[:, seq_len//2:]
200
- # So depth = first half of updated image, normal = second half of cond prediction
201
-
202
- # --- VAE decode ---
203
  with torch.inference_mode():
204
- depth_decoded = VAE.decode(depth_latent)
205
- normal_decoded = VAE.decode(normal_latent)
206
-
207
- depth_decoded = depth_decoded.clamp(-1, 1).mul(0.5).add(0.5)
 
 
 
 
 
 
 
208
 
209
  elapsed = time.time() - t0
210
 
211
- # --- Normal map visualization ---
212
- normal_np = normal_decoded[0].cpu().float().numpy().transpose(1, 2, 0) # (H, W, 3)
 
 
213
  normal_norm = np.linalg.norm(normal_np, axis=-1, keepdims=True)
214
  normal_norm[normal_norm < 1e-12] = 1e-12
215
  normal_np = normal_np / normal_norm
216
  normal_rgb = (((normal_np + 1) * 0.5) * 255).clip(0, 255).astype(np.uint8)
217
- normal_map = Image.fromarray(normal_rgb).resize(orig_size)
 
218
 
219
- # --- Depth map visualization (turbo colormap) ---
220
- depth_np = depth_decoded[0].cpu().float().mean(dim=0).numpy()
221
  depth_np = (depth_np - depth_np.min()) / (depth_np.max() - depth_np.min() + 1e-8)
222
  depth_colored = (cm.turbo(depth_np)[:, :, :3] * 255).astype(np.uint8)
223
- depth_map = Image.fromarray(depth_colored).resize(orig_size)
 
224
 
225
- status = f"Generated in {elapsed:.1f}s (DiT: {dit_elapsed:.1f}s, 1024x768, ONNX INT8)"
226
  print(f"[gen] {status}")
227
  return depth_map, normal_map, status
228
 
@@ -232,7 +130,7 @@ import gradio as gr
232
  with gr.Blocks(title="FE2E: Depth + Normal (CPU)") as demo:
233
  gr.Markdown(
234
  "**[FE2E](https://github.com/AMAP-ML/FE2E)** Depth + Normal from a single image (CVPR 2026). "
235
- "Step1X-Edit DiT + LDRN LoRA (pre-merged), ONNX INT8 quantized on CPU."
236
  )
237
  with gr.Row(equal_height=True):
238
  input_img = gr.Image(label="Input", type="pil", height=256)
 
1
+ """FE2E: Depth + Normal estimation from a single image (CPU, pre-quantized INT8)"""
2
  from __future__ import annotations
3
 
4
  import gc
 
5
  import os
 
6
  import sys
7
  import time
8
 
 
9
  import torch
 
 
 
 
10
 
11
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
12
 
13
  MODELS_DIR = "/tmp/fe2e_models"
14
+ os.makedirs(MODELS_DIR, exist_ok=True)
 
15
 
 
16
 
17
+ class Args:
18
+ prompt_type = "empty"
19
+ single_denoise = True
20
+ empty_prompt_cache = os.path.join(os.path.dirname(os.path.abspath(__file__)), "latent", "no_info.npz")
21
+ norm_type = "ln"
22
 
23
+
24
+ def _download(repo, filename, token=None):
25
+ import shutil
26
  from huggingface_hub import hf_hub_download
27
  basename = os.path.basename(filename)
28
+ dest = os.path.join(MODELS_DIR, basename)
29
  if not os.path.exists(dest):
30
  print(f"[init] Downloading {repo}/{filename}...")
31
  src = hf_hub_download(repo, filename, token=token)
32
  shutil.copy2(src, dest)
33
+ print(f"[init] {basename}: {os.path.getsize(dest)/1024/1024:.0f} MB")
 
34
  return dest
35
 
36
 
37
+ def _load_generator():
38
+ from infer.inference import ImageGenerator
39
 
40
  token = os.environ.get("HF_TOKEN")
41
+ dit_path = _download("WeReCooking2/FE2E-INT8", "dit_int8_full.pt", token)
42
+ vae_path = _download("WeReCooking2/FE2E-INT8", "vae_full.pt", token)
43
 
44
+ args = Args()
45
+ print("[init] Loading pre-quantized INT8 DiT (full model, mmap)...")
46
+ t0 = time.time()
47
+ dit = torch.load(dit_path, map_location="cpu", weights_only=False, mmap=True)
48
+ gc.collect()
49
+ print(f"[init] DiT loaded in {time.time()-t0:.0f}s")
50
 
51
+ print("[init] Loading VAE (full model)...")
52
+ ae = torch.load(vae_path, map_location="cpu", weights_only=False, mmap=True)
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  gc.collect()
55
 
56
+ generator = ImageGenerator.__new__(ImageGenerator)
57
+ generator.device = torch.device("cpu")
58
+ generator.args = args
59
+ generator.ae = ae
60
+ generator.dit = dit
61
+ generator.llm_encoder = None
62
+ generator.quantized = False
63
+ generator.offload = False
64
+ generator.lora_module = None
65
 
66
+ print(f"[init] Ready. Total load: {time.time()-t0:.0f}s")
67
+ return generator
68
 
69
 
70
+ print("[init] Loading model at startup (not lazy)...")
71
+ GENERATOR = _load_generator()
72
+ print("[init] Model ready, starting Gradio...")
73
 
74
 
75
  def generate(image):
 
76
  import gradio as gr
77
+ from PIL import Image
78
 
79
  if image is None:
80
  raise gr.Error("Please upload an image.")
 
84
  elif not isinstance(image, Image.Image):
85
  image = Image.fromarray(image).convert("RGB")
86
 
87
+ args = Args()
88
+ print(f"[gen] Input: {image.size}")
89
  t0 = time.time()
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  with torch.inference_mode():
92
+ images, Lpred, Rpred = GENERATOR.generate_image(
93
+ prompt="",
94
+ negative_prompt="",
95
+ ref_images=image,
96
+ num_samples=1,
97
+ num_steps=1,
98
+ cfg_guidance=6.0,
99
+ seed=42,
100
+ show_progress=True,
101
+ args=args,
102
+ )
103
 
104
  elapsed = time.time() - t0
105
 
106
+ import numpy as np
107
+ import matplotlib.cm as cm
108
+
109
+ normal_np = Rpred[0].cpu().float().numpy().transpose(1, 2, 0)
110
  normal_norm = np.linalg.norm(normal_np, axis=-1, keepdims=True)
111
  normal_norm[normal_norm < 1e-12] = 1e-12
112
  normal_np = normal_np / normal_norm
113
  normal_rgb = (((normal_np + 1) * 0.5) * 255).clip(0, 255).astype(np.uint8)
114
+ normal_map = Image.fromarray(normal_rgb)
115
+ normal_map = normal_map.resize(image.size)
116
 
117
+ depth_np = Lpred[0].cpu().float().mean(dim=0).numpy()
 
118
  depth_np = (depth_np - depth_np.min()) / (depth_np.max() - depth_np.min() + 1e-8)
119
  depth_colored = (cm.turbo(depth_np)[:, :, :3] * 255).astype(np.uint8)
120
+ depth_map = Image.fromarray(depth_colored)
121
+ depth_map = depth_map.resize(image.size)
122
 
123
+ status = f"Generated in {elapsed:.1f}s ({image.size[0]}x{image.size[1]}, single denoise, INT8)"
124
  print(f"[gen] {status}")
125
  return depth_map, normal_map, status
126
 
 
130
  with gr.Blocks(title="FE2E: Depth + Normal (CPU)") as demo:
131
  gr.Markdown(
132
  "**[FE2E](https://github.com/AMAP-ML/FE2E)** Depth + Normal from a single image (CVPR 2026). "
133
+ "Takes ~29 min for 768x1024, 1 step, Step1X-Edit DiT + LDRN LoRA (pre-merged), INT8 quantized on CPU."
134
  )
135
  with gr.Row(equal_height=True):
136
  input_img = gr.Image(label="Input", type="pil", height=256)