FE2E-CPU / app.py
Nekochu's picture
Revert to PyTorch INT8 (ONNX export produces NaN)
563ef1b verified
"""FE2E: Depth + Normal estimation from a single image (CPU, pre-quantized INT8)"""
from __future__ import annotations
import gc
import os
import sys
import time
import torch
# Put this file's directory first on the import path so sibling packages
# (e.g. ``infer``) resolve no matter which working directory launched the app.
sys.path[:0] = [os.path.dirname(os.path.abspath(__file__))]
# Flat local cache that ``_download`` copies Hub checkpoints into.
MODELS_DIR = "/tmp/fe2e_models"
os.makedirs(MODELS_DIR, exist_ok=True)
class Args:
    """Fixed inference options passed as ``args`` to ``generate_image``.

    Every option is a class attribute, so ``Args()`` instances carry no
    per-instance state and all share these values.
    """

    # Use the "empty" prompt mode; ``generate`` also passes prompt="".
    prompt_type = "empty"
    # Single-denoise inference (``generate`` requests num_steps=1).
    single_denoise = True
    # Precomputed latent for the empty prompt, shipped next to this file
    # under ``latent/no_info.npz``.
    empty_prompt_cache = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "latent",
        "no_info.npz",
    )
    # Normalization type identifier consumed by the inference code
    # (meaning of "ln" is defined there — presumably LayerNorm; confirm).
    norm_type = "ln"
def _download(repo, filename, token=None):
    """Fetch ``filename`` from Hub repo ``repo`` into ``MODELS_DIR``.

    The file is stored flat under its basename. If that path already exists,
    it is reused without contacting the Hub or verifying its contents.

    NOTE: because only the basename is kept, two different files with the
    same basename (from different repos or subfolders) would collide.

    Args:
        repo: Hugging Face Hub repository id.
        filename: path of the file inside ``repo``.
        token: optional access token forwarded to ``hf_hub_download``.

    Returns:
        The local path ``os.path.join(MODELS_DIR, basename)``.
    """
    import shutil
    import tempfile
    from huggingface_hub import hf_hub_download

    basename = os.path.basename(filename)
    dest = os.path.join(MODELS_DIR, basename)
    if not os.path.exists(dest):
        print(f"[init] Downloading {repo}/{filename}...")
        src = hf_hub_download(repo, filename, token=token)
        # Copy into a temp file in the same directory, then atomically rename
        # it into place. An interrupted copy therefore never leaves a
        # truncated ``dest``, which the existence check above would otherwise
        # accept as complete on every later start.
        fd, tmp = tempfile.mkstemp(dir=MODELS_DIR, prefix=f".{basename}.", suffix=".part")
        os.close(fd)
        try:
            shutil.copy2(src, tmp)
            os.replace(tmp, dest)
        finally:
            # After a successful replace ``tmp`` is gone; otherwise drop the
            # partial copy.
            if os.path.exists(tmp):
                os.remove(tmp)
    print(f"[init] {basename}: {os.path.getsize(dest)/1024/1024:.0f} MB")
    return dest
def _load_generator():
    """Assemble a CPU ``ImageGenerator`` from the pre-exported full models.

    Downloads (or reuses) the INT8 DiT and VAE checkpoints via ``_download``.
    Both are loaded with ``torch.load`` (memory-mapped, on CPU) and attached
    to an ``ImageGenerator`` created with ``__new__``, so its ``__init__``
    never runs. Presumably this skips the constructor's own model building;
    confirm against ``infer.inference``.

    Returns:
        The populated ``ImageGenerator`` instance.
    """
    from infer.inference import ImageGenerator

    hf_token = os.environ.get("HF_TOKEN")
    repo_id = "WeReCooking2/FE2E-INT8"
    dit_file = _download(repo_id, "dit_int8_full.pt", hf_token)
    vae_file = _download(repo_id, "vae_full.pt", hf_token)
    options = Args()

    def load_full(path):
        # SECURITY: weights_only=False unpickles arbitrary Python objects,
        # which can execute code. This is acceptable only because the
        # checkpoint repo above is trusted; never point it at untrusted files.
        module = torch.load(path, map_location="cpu", weights_only=False, mmap=True)
        gc.collect()
        return module

    print("[init] Loading pre-quantized INT8 DiT (full model, mmap)...")
    start = time.time()
    dit = load_full(dit_file)
    print(f"[init] DiT loaded in {time.time()-start:.0f}s")
    print("[init] Loading VAE (full model)...")
    vae = load_full(vae_file)

    generator = ImageGenerator.__new__(ImageGenerator)
    # Populate, in order, the attributes ``__init__`` would normally set.
    # ``quantized=False`` presumably stops the generator from quantizing the
    # already-INT8 DiT again — confirm in ``infer.inference``.
    state = {
        "device": torch.device("cpu"),
        "args": options,
        "ae": vae,
        "dit": dit,
        "llm_encoder": None,
        "quantized": False,
        "offload": False,
        "lora_module": None,
    }
    for attr_name, attr_value in state.items():
        setattr(generator, attr_name, attr_value)
    print(f"[init] Ready. Total load: {time.time()-start:.0f}s")
    return generator
# Build the generator eagerly at import time, before the UI is created, so
# the model is already resident when ``generate`` reads ``GENERATOR``.
print("[init] Loading model at startup (not lazy)...")
GENERATOR = _load_generator()
print("[init] Model ready, starting Gradio...")
def _normal_to_image(normal_pred, size):
    """Render a predicted normal tensor as an RGB image resized to ``size``.

    Args:
        normal_pred: tensor indexed as (channel, H, W); it is transposed to
            HWC below (presumably 3 channels x/y/z — TODO confirm).
        size: ``(width, height)`` target, as given by ``PIL.Image.size``.

    Returns:
        ``PIL.Image.Image`` with each unit normal mapped from [-1, 1] to
        [0, 255].
    """
    import numpy as np
    from PIL import Image

    normal_np = normal_pred.cpu().float().numpy().transpose(1, 2, 0)
    # NaN/inf survive the normalization and clip below and make the uint8
    # cast undefined; zero them so those pixels render as neutral grey.
    normal_np = np.nan_to_num(normal_np, nan=0.0, posinf=0.0, neginf=0.0)
    normal_norm = np.linalg.norm(normal_np, axis=-1, keepdims=True)
    # Floor the length so zero vectors do not divide by zero.
    normal_norm[normal_norm < 1e-12] = 1e-12
    normal_np = normal_np / normal_norm
    normal_rgb = (((normal_np + 1) * 0.5) * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(normal_rgb).resize(size)


def _depth_to_image(depth_pred, size):
    """Render a predicted depth tensor as a turbo-colormapped RGB image.

    Args:
        depth_pred: tensor indexed as (channel, H, W); channels are averaged
            into one map.
        size: ``(width, height)`` target, as given by ``PIL.Image.size``.

    Returns:
        ``PIL.Image.Image`` of the min-max normalized depth under ``turbo``.
    """
    import matplotlib.cm as cm
    import numpy as np
    from PIL import Image

    depth_np = depth_pred.cpu().float().mean(dim=0).numpy()
    # Replace non-finite values with the smallest finite depth. Otherwise they
    # poison the min/max normalization and produce undefined uint8 values.
    finite = np.isfinite(depth_np)
    fill = depth_np[finite].min() if finite.any() else 0.0
    depth_np = np.where(finite, depth_np, fill)
    depth_np = (depth_np - depth_np.min()) / (depth_np.max() - depth_np.min() + 1e-8)
    depth_colored = (cm.turbo(depth_np)[:, :, :3] * 255).astype(np.uint8)
    return Image.fromarray(depth_colored).resize(size)


def generate(image):
    """Estimate a depth map and a normal map for a single image.

    Args:
        image: the input — a PIL image, a file path, or an array accepted by
            ``PIL.Image.fromarray``.

    Returns:
        ``(depth_map, normal_map, status)``: two PIL images resized to the
        input size, plus a human-readable timing string.

    Raises:
        gr.Error: if ``image`` is None.
    """
    import gradio as gr
    from PIL import Image

    if image is None:
        raise gr.Error("Please upload an image.")
    # Normalize every accepted input form to a 3-channel RGB PIL image, so
    # that RGBA, grayscale or palette inputs do not reach the model unconverted.
    if isinstance(image, str):
        # The context manager closes the file once the pixels are copied.
        with Image.open(image) as opened:
            image = opened.convert("RGB")
    elif isinstance(image, Image.Image):
        image = image.convert("RGB")
    else:
        image = Image.fromarray(image).convert("RGB")
    args = Args()
    print(f"[gen] Input: {image.size}")
    t0 = time.time()
    with torch.inference_mode():
        _, Lpred, Rpred = GENERATOR.generate_image(
            prompt="",
            negative_prompt="",
            ref_images=image,
            num_samples=1,
            num_steps=1,
            cfg_guidance=6.0,
            seed=42,
            show_progress=True,
            args=args,
        )
    elapsed = time.time() - t0
    # First sample of each output: Rpred is rendered as normals, Lpred as depth.
    normal_map = _normal_to_image(Rpred[0], image.size)
    depth_map = _depth_to_image(Lpred[0], image.size)
    status = f"Generated in {elapsed:.1f}s ({image.size[0]}x{image.size[1]}, single denoise, INT8)"
    print(f"[gen] {status}")
    return depth_map, normal_map, status
import gradio as gr

# UI: an input image, depth and normal outputs side by side, a run button and
# a status line. Components render in creation order, so keep this order.
with gr.Blocks(title="FE2E: Depth + Normal (CPU)") as demo:
    gr.Markdown(
        "**[FE2E](https://github.com/AMAP-ML/FE2E)** Depth + Normal from a single image (CVPR 2026). "
        "Takes ~29 min for 768x1024, 1 step, Step1X-Edit DiT + LDRN LoRA (pre-merged), INT8 quantized on CPU."
    )
    with gr.Row(equal_height=True):
        input_img = gr.Image(label="Input", type="pil", height=256)
        depth_out = gr.Image(label="Depth", type="pil", height=256)
        normal_out = gr.Image(label="Normal", type="pil", height=256)
    with gr.Row():
        run_btn = gr.Button("Estimate Depth + Normal", variant="primary", size="lg")
        status_out = gr.Textbox(label="Status", interactive=False, scale=2)
    # One run at a time: every request shares the single GENERATOR loaded at
    # startup. Also exposed as the named API endpoint "generate".
    run_btn.click(
        fn=generate,
        inputs=[input_img],
        outputs=[depth_out, normal_out, status_out],
        concurrency_limit=1,
        api_name="generate",
    )
    # cache_examples=False: examples are not precomputed at startup, which
    # would mean running a full (very slow, per the text above) generation.
    gr.Examples(
        examples=["assets/example.jpg"],
        inputs=[input_img],
        outputs=[depth_out, normal_out, status_out],
        fn=generate,
        cache_examples=False,
        label="Examples",
    )
# Enable the request queue with a default limit of one concurrent event.
demo.queue(default_concurrency_limit=1)
if __name__ == "__main__":
    # Listen on all interfaces at port 7860 and show errors in the UI. The app
    # is also exposed as an MCP server, and server-side rendering is disabled.
    launch_kwargs = dict(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        mcp_server=True,
        ssr_mode=False,
    )
    demo.launch(**launch_kwargs)