"""FE2E: Depth + Normal estimation from a single image (CPU, pre-quantized INT8)""" from __future__ import annotations import gc import os import sys import time import torch sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) MODELS_DIR = "/tmp/fe2e_models" os.makedirs(MODELS_DIR, exist_ok=True) class Args: prompt_type = "empty" single_denoise = True empty_prompt_cache = os.path.join(os.path.dirname(os.path.abspath(__file__)), "latent", "no_info.npz") norm_type = "ln" def _download(repo, filename, token=None): import shutil from huggingface_hub import hf_hub_download basename = os.path.basename(filename) dest = os.path.join(MODELS_DIR, basename) if not os.path.exists(dest): print(f"[init] Downloading {repo}/{filename}...") src = hf_hub_download(repo, filename, token=token) shutil.copy2(src, dest) print(f"[init] {basename}: {os.path.getsize(dest)/1024/1024:.0f} MB") return dest def _load_generator(): from infer.inference import ImageGenerator token = os.environ.get("HF_TOKEN") dit_path = _download("WeReCooking2/FE2E-INT8", "dit_int8_full.pt", token) vae_path = _download("WeReCooking2/FE2E-INT8", "vae_full.pt", token) args = Args() print("[init] Loading pre-quantized INT8 DiT (full model, mmap)...") t0 = time.time() dit = torch.load(dit_path, map_location="cpu", weights_only=False, mmap=True) gc.collect() print(f"[init] DiT loaded in {time.time()-t0:.0f}s") print("[init] Loading VAE (full model)...") ae = torch.load(vae_path, map_location="cpu", weights_only=False, mmap=True) gc.collect() generator = ImageGenerator.__new__(ImageGenerator) generator.device = torch.device("cpu") generator.args = args generator.ae = ae generator.dit = dit generator.llm_encoder = None generator.quantized = False generator.offload = False generator.lora_module = None print(f"[init] Ready. Total load: {time.time()-t0:.0f}s") return generator print("[init] Loading model at startup (not lazy)...") GENERATOR = _load_generator() print("[init] Model ready, starting Gradio...") def generate(image): import gradio as gr from PIL import Image if image is None: raise gr.Error("Please upload an image.") if isinstance(image, str): image = Image.open(image).convert("RGB") elif not isinstance(image, Image.Image): image = Image.fromarray(image).convert("RGB") args = Args() print(f"[gen] Input: {image.size}") t0 = time.time() with torch.inference_mode(): images, Lpred, Rpred = GENERATOR.generate_image( prompt="", negative_prompt="", ref_images=image, num_samples=1, num_steps=1, cfg_guidance=6.0, seed=42, show_progress=True, args=args, ) elapsed = time.time() - t0 import numpy as np import matplotlib.cm as cm normal_np = Rpred[0].cpu().float().numpy().transpose(1, 2, 0) normal_norm = np.linalg.norm(normal_np, axis=-1, keepdims=True) normal_norm[normal_norm < 1e-12] = 1e-12 normal_np = normal_np / normal_norm normal_rgb = (((normal_np + 1) * 0.5) * 255).clip(0, 255).astype(np.uint8) normal_map = Image.fromarray(normal_rgb) normal_map = normal_map.resize(image.size) depth_np = Lpred[0].cpu().float().mean(dim=0).numpy() depth_np = (depth_np - depth_np.min()) / (depth_np.max() - depth_np.min() + 1e-8) depth_colored = (cm.turbo(depth_np)[:, :, :3] * 255).astype(np.uint8) depth_map = Image.fromarray(depth_colored) depth_map = depth_map.resize(image.size) status = f"Generated in {elapsed:.1f}s ({image.size[0]}x{image.size[1]}, single denoise, INT8)" print(f"[gen] {status}") return depth_map, normal_map, status import gradio as gr with gr.Blocks(title="FE2E: Depth + Normal (CPU)") as demo: gr.Markdown( "**[FE2E](https://github.com/AMAP-ML/FE2E)** Depth + Normal from a single image (CVPR 2026). " "Takes ~29 min for 768x1024, 1 step, Step1X-Edit DiT + LDRN LoRA (pre-merged), INT8 quantized on CPU." ) with gr.Row(equal_height=True): input_img = gr.Image(label="Input", type="pil", height=256) depth_out = gr.Image(label="Depth", type="pil", height=256) normal_out = gr.Image(label="Normal", type="pil", height=256) with gr.Row(): run_btn = gr.Button("Estimate Depth + Normal", variant="primary", size="lg") status_out = gr.Textbox(label="Status", interactive=False, scale=2) run_btn.click( fn=generate, inputs=[input_img], outputs=[depth_out, normal_out, status_out], concurrency_limit=1, api_name="generate", ) gr.Examples( examples=["assets/example.jpg"], inputs=[input_img], outputs=[depth_out, normal_out, status_out], fn=generate, cache_examples=False, label="Examples", ) demo.queue(default_concurrency_limit=1) if __name__ == "__main__": demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True, mcp_server=True, ssr_mode=False)