Spaces:
Paused
Paused
File size: 5,233 Bytes
563ef1b 405d2b1 563ef1b 405d2b1 563ef1b 405d2b1 563ef1b 551acb3 563ef1b 551acb3 563ef1b 551acb3 563ef1b 405d2b1 563ef1b 405d2b1 563ef1b 405d2b1 563ef1b 551acb3 405d2b1 551acb3 563ef1b 551acb3 563ef1b 405d2b1 563ef1b 405d2b1 563ef1b 405d2b1 563ef1b 405d2b1 a3ba5ed 563ef1b 405d2b1 6ee4bac d65d5b5 563ef1b d65d5b5 563ef1b d65d5b5 563ef1b d65d5b5 563ef1b 405d2b1 563ef1b 405d2b1 284342e 405d2b1 284342e 563ef1b 284342e d65d5b5 284342e d65d5b5 284342e 405d2b1 d65d5b5 284342e 405d2b1 284342e 405d2b1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 | """FE2E: Depth + Normal estimation from a single image (CPU, pre-quantized INT8)"""
from __future__ import annotations
import gc
import os
import sys
import time
import torch
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
MODELS_DIR = "/tmp/fe2e_models"
os.makedirs(MODELS_DIR, exist_ok=True)
class Args:
    """Static inference settings passed to the FE2E image generator.

    Everything is a class-level attribute, so ``Args()`` instances are
    interchangeable and hold no per-instance state.
    """

    # Prompt handling mode requested from the generator.
    prompt_type: str = "empty"
    # Flag requesting the single-denoise path.
    single_denoise: bool = True
    # Cached empty-prompt file, stored in "latent/" next to this source file.
    empty_prompt_cache: str = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "latent",
        "no_info.npz",
    )
    # Normalization variant identifier ("ln") forwarded to the generator.
    norm_type: str = "ln"
def _download(repo, filename, token=None):
    """Make ``filename`` from Hub repo ``repo`` available inside MODELS_DIR.

    The file is stored under its basename. When a file of that name is
    already present in MODELS_DIR the download is skipped and the existing
    file is reused.

    Args:
        repo: Hugging Face Hub repository id, e.g. ``"user/name"``.
        filename: Path of the file inside the repository.
        token: Optional Hub access token forwarded to ``hf_hub_download``.

    Returns:
        Path of the local copy inside MODELS_DIR.
    """
    import shutil
    import tempfile

    from huggingface_hub import hf_hub_download

    basename = os.path.basename(filename)
    dest = os.path.join(MODELS_DIR, basename)
    if not os.path.exists(dest):
        print(f"[init] Downloading {repo}/{filename}...")
        src = hf_hub_download(repo, filename, token=token)
        # Copy into a temporary sibling file and rename it into place, so
        # that an interrupted copy never leaves a truncated ``dest`` which
        # the existence check above would accept as a valid cached file.
        fd, tmp_path = tempfile.mkstemp(
            dir=MODELS_DIR, prefix=basename + ".", suffix=".part"
        )
        os.close(fd)
        try:
            shutil.copy2(src, tmp_path)
            os.replace(tmp_path, dest)
        finally:
            # After a successful rename the temp path is gone; anything
            # still there is a partial copy from a failed attempt.
            if os.path.exists(tmp_path):
                os.unlink(tmp_path)
    print(f"[init] {basename}: {os.path.getsize(dest)/1024/1024:.0f} MB")
    return dest
def _load_generator():
    """Assemble a CPU ``ImageGenerator`` from the pre-quantized checkpoints.

    Both checkpoint files are fetched through ``_download`` (using the
    ``HF_TOKEN`` environment variable when set) and loaded with
    ``torch.load``. The generator object is created with ``__new__``, so its
    ``__init__`` does not run; the attributes are filled in by hand instead.

    Returns:
        The populated ``ImageGenerator`` instance.
    """
    from infer.inference import ImageGenerator

    hf_token = os.environ.get("HF_TOKEN")
    dit_file = _download("WeReCooking2/FE2E-INT8", "dit_int8_full.pt", hf_token)
    vae_file = _download("WeReCooking2/FE2E-INT8", "vae_full.pt", hf_token)
    settings = Args()

    # NOTE(review): weights_only=False unpickles whole Python objects, which
    # can execute arbitrary code; only acceptable while the repo is trusted.
    load_options = {"map_location": "cpu", "weights_only": False, "mmap": True}

    print("[init] Loading pre-quantized INT8 DiT (full model, mmap)...")
    started = time.time()
    dit_model = torch.load(dit_file, **load_options)
    gc.collect()
    print(f"[init] DiT loaded in {time.time()-started:.0f}s")

    print("[init] Loading VAE (full model)...")
    vae_model = torch.load(vae_file, **load_options)
    gc.collect()

    # Skip ImageGenerator.__init__ and set the attributes directly.
    generator = ImageGenerator.__new__(ImageGenerator)
    manual_state = (
        ("device", torch.device("cpu")),
        ("args", settings),
        ("ae", vae_model),
        ("dit", dit_model),
        ("llm_encoder", None),
        ("quantized", False),
        ("offload", False),
        ("lora_module", None),
    )
    for attr_name, attr_value in manual_state:
        setattr(generator, attr_name, attr_value)

    # The total is measured from the start of the DiT load (downloads excluded).
    print(f"[init] Ready. Total load: {time.time()-started:.0f}s")
    return generator
# Build the generator eagerly while the module is being imported, before the
# Gradio UI is defined. GENERATOR is the module-level instance that
# generate() uses for every request.
print("[init] Loading model at startup (not lazy)...")
GENERATOR = _load_generator()
print("[init] Model ready, starting Gradio...")
def generate(image):
    """Estimate a depth map and a surface-normal map for a single image.

    Args:
        image: A PIL image, a file path string, or an array accepted by
            ``PIL.Image.fromarray``. Every input is normalized to RGB mode
            before inference.

    Returns:
        A ``(depth_map, normal_map, status)`` tuple: two PIL images resized
        to the input size and a human-readable status string.

    Raises:
        gr.Error: If ``image`` is None.
    """
    import gradio as gr
    from PIL import Image

    if image is None:
        raise gr.Error("Please upload an image.")
    if isinstance(image, str):
        # Close the file handle once the pixels have been copied by convert().
        with Image.open(image) as opened:
            image = opened.convert("RGB")
    elif not isinstance(image, Image.Image):
        image = Image.fromarray(image).convert("RGB")
    elif image.mode != "RGB":
        # PIL inputs were previously passed through unchanged; bring them in
        # line with the path and array branches, which always yield RGB.
        image = image.convert("RGB")

    args = Args()
    print(f"[gen] Input: {image.size}")
    t0 = time.time()
    with torch.inference_mode():
        _images, depth_pred, normal_pred = GENERATOR.generate_image(
            prompt="",
            negative_prompt="",
            ref_images=image,
            num_samples=1,
            num_steps=1,
            cfg_guidance=6.0,
            seed=42,
            show_progress=True,
            args=args,
        )
    elapsed = time.time() - t0

    normal_map = _normal_to_image(normal_pred[0], image.size)
    depth_map = _depth_to_image(depth_pred[0], image.size)

    status = f"Generated in {elapsed:.1f}s ({image.size[0]}x{image.size[1]}, single denoise, INT8)"
    print(f"[gen] {status}")
    return depth_map, normal_map, status


def _normal_to_image(normal, size):
    """Encode one predicted normal tensor as a PIL image of ``size``.

    ``normal`` is moved to channel-last layout, each pixel vector is scaled
    to unit length, and the result is mapped to 8-bit colour values.
    Presumably ``normal`` is a (3, H, W) tensor with components in [-1, 1] —
    TODO confirm against the generator.
    """
    import numpy as np
    from PIL import Image

    vectors = normal.cpu().float().numpy().transpose(1, 2, 0)
    lengths = np.linalg.norm(vectors, axis=-1, keepdims=True)
    # Floor the length so zero vectors do not cause a division by zero.
    lengths[lengths < 1e-12] = 1e-12
    vectors = vectors / lengths
    rgb = (((vectors + 1) * 0.5) * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(rgb).resize(size)


def _depth_to_image(depth, size):
    """Render one predicted depth tensor as a turbo-coloured PIL image.

    The tensor is averaged over its first axis, min-max normalized to
    [0, 1], and passed through matplotlib's ``turbo`` colormap.
    """
    import matplotlib.cm as cm
    import numpy as np
    from PIL import Image

    values = depth.cpu().float().mean(dim=0).numpy()
    # The epsilon keeps a constant prediction from dividing by zero.
    values = (values - values.min()) / (values.max() - values.min() + 1e-8)
    # Drop the alpha channel returned by the colormap.
    colored = (cm.turbo(values)[:, :, :3] * 255).astype(np.uint8)
    return Image.fromarray(colored).resize(size)
import gradio as gr
# Gradio front end: one input image, two output images and a status line.
with gr.Blocks(title="FE2E: Depth + Normal (CPU)") as demo:
    intro_text = (
        "**[FE2E](https://github.com/AMAP-ML/FE2E)** Depth + Normal from a single image (CVPR 2026). "
        "Takes ~29 min for 768x1024, 1 step, Step1X-Edit DiT + LDRN LoRA (pre-merged), INT8 quantized on CPU."
    )
    gr.Markdown(intro_text)

    # Input and both results share one row at a fixed preview height.
    preview_height = 256
    with gr.Row(equal_height=True):
        input_img = gr.Image(label="Input", type="pil", height=preview_height)
        depth_out = gr.Image(label="Depth", type="pil", height=preview_height)
        normal_out = gr.Image(label="Normal", type="pil", height=preview_height)

    with gr.Row():
        run_btn = gr.Button("Estimate Depth + Normal", variant="primary", size="lg")
        status_out = gr.Textbox(label="Status", interactive=False, scale=2)

    # The button and the example gallery feed the same three components.
    result_components = (depth_out, normal_out, status_out)

    run_btn.click(
        fn=generate,
        inputs=[input_img],
        outputs=list(result_components),
        concurrency_limit=1,
        api_name="generate",
    )

    gr.Examples(
        examples=["assets/example.jpg"],
        inputs=[input_img],
        outputs=list(result_components),
        fn=generate,
        cache_examples=False,
        label="Examples",
    )

# Queue requests with a default concurrency limit of one.
demo.queue(default_concurrency_limit=1)

if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
        "mcp_server": True,
        "ssr_mode": False,
    }
    demo.launch(**launch_options)
|