Nekochu commited on
Commit
551acb3
·
1 Parent(s): 405d2b1

fix: diffusers dep, layer-by-layer FP8->FP32 cast, LoRA merge in FP32, INT8 quant

Browse files
Files changed (2) hide show
  1. Dockerfile +1 -1
  2. app.py +82 -58
Dockerfile CHANGED
@@ -3,7 +3,7 @@ FROM python:3.11-slim
3
  RUN apt-get update && apt-get install -y --no-install-recommends git && rm -rf /var/lib/apt/lists/*
4
 
5
  RUN pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu
6
- RUN pip install --no-cache-dir "gradio[mcp]" Pillow huggingface-hub safetensors einops numpy tqdm
7
 
8
  WORKDIR /app
9
  COPY . .
 
3
  RUN apt-get update && apt-get install -y --no-install-recommends git && rm -rf /var/lib/apt/lists/*
4
 
5
  RUN pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu
6
+ RUN pip install --no-cache-dir "gradio[mcp]" Pillow huggingface-hub safetensors einops numpy tqdm "diffusers[torch]" transformers
7
 
8
  WORKDIR /app
9
  COPY . .
app.py CHANGED
@@ -1,4 +1,4 @@
1
- """FE2E: Depth + Normal estimation from a single image (CPU, FP32 with dynamic INT8)"""
2
  from __future__ import annotations
3
 
4
  import gc
@@ -6,7 +6,6 @@ import os
6
  import sys
7
  import time
8
 
9
- import numpy as np
10
  import torch
11
 
12
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
@@ -22,58 +21,89 @@ class Args:
22
  norm_type = "ln"
23
 
24
 
25
- def _download_models():
26
- from huggingface_hub import hf_hub_download
27
  import shutil
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  token = os.environ.get("HF_TOKEN")
30
- files = {
31
- "dit": ("rkfg/Step1X-Edit-FP8", "step1x-edit-i1258-FP8.safetensors"),
32
- "vae": ("exander/FE2E", "pretrain/vae.safetensors"),
33
- "lora": ("exander/FE2E", "LDRN.safetensors"),
34
- }
35
- paths = {}
36
- for key, (repo, filename) in files.items():
37
- basename = os.path.basename(filename)
38
- dest = os.path.join(MODELS_DIR, basename)
39
- if not os.path.exists(dest):
40
- print(f"[init] Downloading {repo}/{filename}...")
41
- src = hf_hub_download(repo, filename, token=token)
42
- shutil.copy2(src, dest)
43
- print(f"[init] {basename}: {os.path.getsize(dest)/1024/1024:.0f} MB")
44
- paths[key] = dest
45
- return paths
46
-
47
-
48
- def _load_generator(paths):
49
- """Load model: FP8 weights cast to FP32 for CPU, with LoRA merged."""
50
- from infer.inference import ImageGenerator
51
 
52
  args = Args()
53
- print("[init] Loading model (FP8 -> FP32 on CPU)...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  t0 = time.time()
 
 
 
 
 
 
55
 
56
- generator = ImageGenerator(
57
- dit_path=paths["dit"],
58
- ae_path=paths["vae"],
59
- quantized=True,
60
- offload=False,
61
- lora=paths["lora"],
62
- device="cpu",
63
- args=args,
64
- )
65
- # FP8 tensors can't compute on CPU, cast to FP32
66
- generator.dit = generator.dit.float()
67
- generator.ae = generator.ae.float()
68
-
69
- # Dynamic INT8 quantization for linear layers (biggest speedup on CPU)
70
- generator.dit = torch.quantization.quantize_dynamic(
71
- generator.dit, {torch.nn.Linear}, dtype=torch.qint8
72
- )
73
 
74
- elapsed = time.time() - t0
75
- print(f"[init] Model loaded + quantized in {elapsed:.0f}s")
 
 
 
 
 
 
 
 
 
76
  gc.collect()
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  return generator
78
 
79
 
@@ -84,7 +114,7 @@ def generate(image):
84
  """Estimate depth and surface normals from a single image.
85
 
86
  Args:
87
- image: Input image (PIL Image or filepath).
88
 
89
  Returns:
90
  tuple: (depth_map, normal_map, status_message)
@@ -95,8 +125,7 @@ def generate(image):
95
 
96
  global GENERATOR
97
  if GENERATOR is None:
98
- paths = _download_models()
99
- GENERATOR = _load_generator(paths)
100
 
101
  if image is None:
102
  raise gr.Error("Please upload an image.")
@@ -107,10 +136,10 @@ def generate(image):
107
  image = Image.fromarray(image).convert("RGB")
108
 
109
  args = Args()
110
- print(f"[gen] Input: {image.size}, starting inference...")
111
  t0 = time.time()
112
 
113
- with torch.inference_mode():
114
  images, Lpred, Rpred = GENERATOR.generate_image(
115
  prompt="",
116
  negative_prompt="",
@@ -125,14 +154,11 @@ def generate(image):
125
 
126
  elapsed = time.time() - t0
127
 
128
- # Normal map from model output
129
  normal_map = images[0] if images else None
130
-
131
- # Depth map from Lpred
132
  Lpred_img = Lpred[0].clamp(0, 1).cpu()
133
  depth_map = F.to_pil_image(Lpred_img)
134
 
135
- status = f"Generated in {elapsed:.1f}s ({image.size[0]}x{image.size[1]}, single denoise)"
136
  print(f"[gen] {status}")
137
  return depth_map, normal_map, status
138
 
@@ -144,7 +170,6 @@ def main():
144
  infer = sub.add_parser("infer")
145
  infer.add_argument("-i", "--input", required=True)
146
  infer.add_argument("-o", "--output-dir", default=".")
147
-
148
  args = parser.parse_args()
149
 
150
  if args.command == "infer":
@@ -160,9 +185,8 @@ def main():
160
 
161
  with gr.Blocks(title="FE2E: Depth + Normal (CPU)") as demo:
162
  gr.Markdown(
163
- "**[FE2E](https://github.com/AMAP-ML/FE2E)** Depth + Normal estimation from a single image. "
164
- "Single denoise step via Step1X-Edit DiT + LDRN LoRA. "
165
- "CPU inference with dynamic INT8 quantization."
166
  )
167
  with gr.Row():
168
  with gr.Column():
 
1
+ """FE2E: Depth + Normal estimation from a single image (CPU)"""
2
  from __future__ import annotations
3
 
4
  import gc
 
6
  import sys
7
  import time
8
 
 
9
  import torch
10
 
11
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
 
21
  norm_type = "ln"
22
 
23
 
24
+ def _download(repo, filename, token=None):
 
25
  import shutil
26
+ from huggingface_hub import hf_hub_download
27
+ basename = os.path.basename(filename)
28
+ dest = os.path.join(MODELS_DIR, basename)
29
+ if not os.path.exists(dest):
30
+ print(f"[init] Downloading {repo}/{filename}...")
31
+ src = hf_hub_download(repo, filename, token=token)
32
+ shutil.copy2(src, dest)
33
+ print(f"[init] {basename}: {os.path.getsize(dest)/1024/1024:.0f} MB")
34
+ return dest
35
+
36
+
37
+ def _load_generator():
38
+ """Load FP8 model, cast to FP32, merge LoRA, quantize INT8."""
39
+ from safetensors.torch import load_file
40
+ from modules.model_edit import Step1XParams, Step1XEdit
41
+ from modules.autoencoder import AutoEncoder
42
+ from infer.inference import ImageGenerator, load_state_dict, equip_dit_with_lora_sd_scripts
43
+ import numpy as np
44
 
45
  token = os.environ.get("HF_TOKEN")
46
+ dit_path = _download("rkfg/Step1X-Edit-FP8", "step1x-edit-i1258-FP8.safetensors", token)
47
+ vae_path = _download("exander/FE2E", "pretrain/vae.safetensors", token)
48
+ lora_path = _download("exander/FE2E", "LDRN.safetensors", token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  args = Args()
51
+
52
+ print("[init] Building model on meta device...")
53
+ with torch.device("meta"):
54
+ ae = AutoEncoder(
55
+ resolution=256, in_channels=3, ch=128, out_ch=3,
56
+ ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16,
57
+ scale_factor=0.3611, shift_factor=0.1159,
58
+ )
59
+ step1x_params = Step1XParams(
60
+ in_channels=64, out_channels=64, vec_in_dim=768, context_in_dim=4096,
61
+ hidden_size=3072, mlp_ratio=4.0, num_heads=24, depth=19,
62
+ depth_single_blocks=38, axes_dim=[16, 56, 56], theta=10_000, qkv_bias=True,
63
+ )
64
+ dit = Step1XEdit(step1x_params)
65
+
66
+ # Load weights as FP8, then cast to FP32 layer by layer to avoid 46 GB peak
67
+ print("[init] Loading FP8 weights and casting to FP32 (layer by layer)...")
68
  t0 = time.time()
69
+ fp8_sd = load_file(dit_path, device="cpu")
70
+ dit_sd = {}
71
+ for k, v in fp8_sd.items():
72
+ dit_sd[k] = v.float()
73
+ del fp8_sd
74
+ gc.collect()
75
 
76
+ dit = dit.to(dtype=torch.float32)
77
+ missing, unexpected = dit.load_state_dict(dit_sd, strict=False, assign=True)
78
+ del dit_sd
79
+ gc.collect()
80
+ print(f"[init] DiT loaded in {time.time()-t0:.0f}s (missing={len(missing)}, unexpected={len(unexpected)})")
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
+ # Load VAE
83
+ ae = load_state_dict(ae, vae_path, "cpu")
84
+ ae = ae.float()
85
+
86
+ # Merge LoRA in FP32 (full precision merge)
87
+ print("[init] Merging LoRA...")
88
+ equip_dit_with_lora_sd_scripts(ae, [None], dit, lora_path, device="cpu")
89
+
90
+ # Dynamic INT8 quantization
91
+ print("[init] Applying dynamic INT8 quantization...")
92
+ dit = torch.quantization.quantize_dynamic(dit, {torch.nn.Linear}, dtype=torch.qint8)
93
  gc.collect()
94
+
95
+ # Build generator wrapper
96
+ generator = ImageGenerator.__new__(ImageGenerator)
97
+ generator.device = torch.device("cpu")
98
+ generator.args = args
99
+ generator.ae = ae
100
+ generator.dit = dit
101
+ generator.llm_encoder = None
102
+ generator.quantized = False
103
+ generator.offload = False
104
+ generator.lora_module = None
105
+
106
+ print("[init] Model ready for inference")
107
  return generator
108
 
109
 
 
114
  """Estimate depth and surface normals from a single image.
115
 
116
  Args:
117
+ image: Input image (PIL Image).
118
 
119
  Returns:
120
  tuple: (depth_map, normal_map, status_message)
 
125
 
126
  global GENERATOR
127
  if GENERATOR is None:
128
+ GENERATOR = _load_generator()
 
129
 
130
  if image is None:
131
  raise gr.Error("Please upload an image.")
 
136
  image = Image.fromarray(image).convert("RGB")
137
 
138
  args = Args()
139
+ print(f"[gen] Input: {image.size}")
140
  t0 = time.time()
141
 
142
+ with torch.inference_mode(), torch.no_grad():
143
  images, Lpred, Rpred = GENERATOR.generate_image(
144
  prompt="",
145
  negative_prompt="",
 
154
 
155
  elapsed = time.time() - t0
156
 
 
157
  normal_map = images[0] if images else None
 
 
158
  Lpred_img = Lpred[0].clamp(0, 1).cpu()
159
  depth_map = F.to_pil_image(Lpred_img)
160
 
161
+ status = f"Generated in {elapsed:.1f}s ({image.size[0]}x{image.size[1]}, single denoise, INT8)"
162
  print(f"[gen] {status}")
163
  return depth_map, normal_map, status
164
 
 
170
  infer = sub.add_parser("infer")
171
  infer.add_argument("-i", "--input", required=True)
172
  infer.add_argument("-o", "--output-dir", default=".")
 
173
  args = parser.parse_args()
174
 
175
  if args.command == "infer":
 
185
 
186
  with gr.Blocks(title="FE2E: Depth + Normal (CPU)") as demo:
187
  gr.Markdown(
188
+ "**[FE2E](https://github.com/AMAP-ML/FE2E)** Depth + Normal from a single image (CVPR 2026). "
189
+ "Single denoise step, Step1X-Edit DiT + LDRN LoRA, dynamic INT8 on CPU."
 
190
  )
191
  with gr.Row():
192
  with gr.Column():