Spaces:
Paused
Paused
| """FE2E: Depth + Normal estimation from a single image (CPU, pre-quantized INT8)""" | |
| from __future__ import annotations | |
| import gc | |
| import os | |
| import sys | |
| import time | |
| import torch | |
# Put this script's own directory at the front of the import path so that
# packages living next to it (the ``infer`` package is imported later in this
# file) resolve no matter which working directory the app is started from.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, _APP_DIR)

# Directory holding the local copies of the model checkpoints; it is created
# on import if it does not exist yet.
MODELS_DIR = "/tmp/fe2e_models"
os.makedirs(MODELS_DIR, exist_ok=True)
class Args:
    """Fixed inference settings passed along to the generator.

    All values are plain class attributes, so ``Args()`` simply yields an
    object exposing these four settings.
    """

    # Path of ``latent/no_info.npz`` resolved relative to this script's
    # directory rather than the current working directory.
    empty_prompt_cache = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "latent",
        "no_info.npz",
    )
    norm_type = "ln"
    prompt_type = "empty"
    single_denoise = True
def _download(repo, filename, token=None):
    """Return the path of a local copy of ``filename`` inside ``MODELS_DIR``.

    If ``MODELS_DIR/<basename of filename>`` already exists it is reused
    as-is. Otherwise the file is fetched with ``hf_hub_download`` and copied
    into ``MODELS_DIR``.

    Args:
        repo: Hugging Face Hub repository id to download from.
        filename: Path of the file inside the repository.
        token: Optional access token forwarded to ``hf_hub_download``.

    Returns:
        The path of the file inside ``MODELS_DIR``.

    Raises:
        Whatever ``hf_hub_download`` or the file copy raises; in that case no
        file is left at the destination path.
    """
    basename = os.path.basename(filename)
    dest = os.path.join(MODELS_DIR, basename)
    if not os.path.exists(dest):
        # Imported only when a download is actually required, so a warm cache
        # does not need these modules.
        import shutil

        from huggingface_hub import hf_hub_download

        print(f"[init] Downloading {repo}/{filename}...")
        src = hf_hub_download(repo, filename, token=token)

        # Copy to a temporary sibling and rename it into place afterwards.
        # Copying straight to ``dest`` would leave a truncated file behind if
        # the process were interrupted, and the existence check above would
        # then treat that broken file as a valid cache entry on every later
        # start.
        partial = dest + ".part"
        try:
            shutil.copy2(src, partial)
            os.replace(partial, dest)
        finally:
            # After a successful rename the temporary file is gone; this only
            # cleans up after a failed copy.
            if os.path.exists(partial):
                os.remove(partial)
    print(f"[init] {basename}: {os.path.getsize(dest)/1024/1024:.0f} MB")
    return dest
def _load_generator():
    """Fetch both checkpoints and assemble an ``ImageGenerator`` on the CPU.

    The DiT and VAE files are obtained through ``_download`` and loaded with
    ``torch.load`` as complete pickled objects. The generator instance is
    created with ``__new__`` (so its ``__init__`` is not run) and the
    attributes it needs are assigned by hand.

    Returns:
        The populated ``ImageGenerator`` instance.
    """
    from infer.inference import ImageGenerator

    hf_token = os.environ.get("HF_TOKEN")
    repo_id = "WeReCooking2/FE2E-INT8"
    dit_file = _download(repo_id, "dit_int8_full.pt", hf_token)
    vae_file = _download(repo_id, "vae_full.pt", hf_token)
    settings = Args()

    print("[init] Loading pre-quantized INT8 DiT (full model, mmap)...")
    started = time.time()
    # NOTE(review): ``weights_only=False`` unpickles arbitrary objects, which
    # executes code embedded in the checkpoint. Only acceptable while the
    # source repository is fully trusted.
    dit_model = torch.load(
        dit_file,
        map_location="cpu",
        weights_only=False,
        mmap=True,
    )
    gc.collect()
    print(f"[init] DiT loaded in {time.time()-started:.0f}s")

    print("[init] Loading VAE (full model)...")
    vae_model = torch.load(
        vae_file,
        map_location="cpu",
        weights_only=False,
        mmap=True,
    )
    gc.collect()

    # Bypass ImageGenerator.__init__ and set the attributes directly, in the
    # order listed here.
    generator = ImageGenerator.__new__(ImageGenerator)
    attributes = {
        "device": torch.device("cpu"),
        "args": settings,
        "ae": vae_model,
        "dit": dit_model,
        "llm_encoder": None,
        "quantized": False,
        "offload": False,
        "lora_module": None,
    }
    for attr_name, attr_value in attributes.items():
        setattr(generator, attr_name, attr_value)

    print(f"[init] Ready. Total load: {time.time()-started:.0f}s")
    return generator
# The model is loaded eagerly while the module is imported, so the load cost
# is paid once at startup; ``generate`` then reads the module-level
# ``GENERATOR`` on every request.
print("[init] Loading model at startup (not lazy)...")
GENERATOR = _load_generator()
print("[init] Model ready, starting Gradio...")
def generate(image):
    """Estimate a depth map and a surface-normal map for one input image.

    Args:
        image: A PIL image, a file path (``str``), or any array accepted by
            ``PIL.Image.fromarray``. ``None`` is rejected.

    Returns:
        A ``(depth_map, normal_map, status)`` tuple: two PIL images resized
        to the input image's size and a human-readable status string.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    # Import everything up front so that a missing dependency fails before
    # the (long) inference run rather than after it.
    import gradio as gr
    import matplotlib.cm as cm
    import numpy as np
    from PIL import Image

    if image is None:
        raise gr.Error("Please upload an image.")
    if isinstance(image, str):
        image = Image.open(image)
    elif not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Normalise every input to RGB. Previously only the path and array
    # branches were converted, so a PIL image handed over by the UI was
    # passed to the generator in whatever mode it had (RGBA, L, P, ...).
    image = image.convert("RGB")

    args = Args()
    print(f"[gen] Input: {image.size}")
    t0 = time.time()
    with torch.inference_mode():
        images, Lpred, Rpred = GENERATOR.generate_image(
            prompt="",
            negative_prompt="",
            ref_images=image,
            num_samples=1,
            num_steps=1,
            cfg_guidance=6.0,
            seed=42,
            show_progress=True,
            args=args,
        )
    elapsed = time.time() - t0

    # Normal map: move the first axis last (presumably channels-first input;
    # confirm against ``generate_image``), scale each pixel vector to unit
    # length, then map the [-1, 1] range to 8-bit RGB.
    normal_np = Rpred[0].cpu().float().numpy().transpose(1, 2, 0)
    normal_norm = np.linalg.norm(normal_np, axis=-1, keepdims=True)
    normal_norm[normal_norm < 1e-12] = 1e-12  # avoid division by zero
    normal_np = normal_np / normal_norm
    normal_rgb = (((normal_np + 1) * 0.5) * 255).clip(0, 255).astype(np.uint8)
    normal_map = Image.fromarray(normal_rgb)
    normal_map = normal_map.resize(image.size)

    # Depth map: average over the first axis, min-max normalise to [0, 1]
    # (the epsilon guards a constant prediction), then colour with "turbo".
    depth_np = Lpred[0].cpu().float().mean(dim=0).numpy()
    depth_np = (depth_np - depth_np.min()) / (depth_np.max() - depth_np.min() + 1e-8)
    depth_colored = (cm.turbo(depth_np)[:, :, :3] * 255).astype(np.uint8)
    depth_map = Image.fromarray(depth_colored)
    depth_map = depth_map.resize(image.size)

    status = f"Generated in {elapsed:.1f}s ({image.size[0]}x{image.size[1]}, single denoise, INT8)"
    print(f"[gen] {status}")
    return depth_map, normal_map, status
import gradio as gr

# Build the UI: one row with the input image and both result images, and a
# second row with the run button and a read-only status box.
with gr.Blocks(title="FE2E: Depth + Normal (CPU)") as demo:
    gr.Markdown(
        "**[FE2E](https://github.com/AMAP-ML/FE2E)** Depth + Normal from a single image (CVPR 2026). "
        "Takes ~29 min for 768x1024, 1 step, Step1X-Edit DiT + LDRN LoRA (pre-merged), INT8 quantized on CPU."
    )

    with gr.Row(equal_height=True):
        input_img = gr.Image(label="Input", type="pil", height=256)
        depth_out = gr.Image(label="Depth", type="pil", height=256)
        normal_out = gr.Image(label="Normal", type="pil", height=256)

    with gr.Row():
        run_btn = gr.Button(
            "Estimate Depth + Normal",
            variant="primary",
            size="lg",
        )
        status_out = gr.Textbox(label="Status", interactive=False, scale=2)

    # Button click runs ``generate``; at most one run at a time.
    run_btn.click(
        fn=generate,
        inputs=[input_img],
        outputs=[depth_out, normal_out, status_out],
        api_name="generate",
        concurrency_limit=1,
    )

    # Clickable example input; results are not pre-computed.
    gr.Examples(
        label="Examples",
        examples=["assets/example.jpg"],
        fn=generate,
        inputs=[input_img],
        outputs=[depth_out, normal_out, status_out],
        cache_examples=False,
    )

demo.queue(default_concurrency_limit=1)

if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
        "mcp_server": True,
        "ssr_mode": False,
    }
    demo.launch(**launch_options)