"""Gradio demo app for Extend3D: image in, preview video and GLB mesh out."""

# NOTE(review): `spaces` is deliberately kept as the very first import, as in
# the original file. ZeroGPU setups typically need it imported before any
# CUDA-touching package (torch, the pipeline) - confirm before reordering.
import spaces
from extend3d import Extend3D
from trellis.utils import render_utils, postprocessing_utils
import imageio
import random
import uuid
from pathlib import Path

import numpy as np
import torch
import gradio as gr

MODEL_ID = "microsoft/TRELLIS-image-large"
DEFAULT_OUTPUT_DIR = "./output"
# Upper bound shared by the seed slider and the random seed draw.
MAX_SEED = 2147483647

# ---------------------------------------------------------------------------
# Pipeline loading
# ---------------------------------------------------------------------------

# Loaded once at import time and reused by every request.
PIPELINE: Extend3D = Extend3D.from_pretrained(MODEL_ID).cuda()


# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------

def _render_preview_video(gaussian, mesh, video_path: str, progress) -> None:
    """Render the preview video for one result and write it to ``video_path``.

    The "color" frames of ``gaussian`` and the "normal" frames of ``mesh`` are
    rendered with ``render_utils.render_video``, joined frame by frame along
    axis 1, and saved at 30 fps.
    """
    progress(0, desc="Rendering video...")
    color_frames = render_utils.render_video(gaussian, r=1.6, resolution=1024)["color"]
    progress(0.5, desc="Rendering video...")
    normal_frames = render_utils.render_video(mesh, r=1.6, resolution=1024)["normal"]
    progress(1.0, desc="Rendering video...")

    # axis=1 presumably places color and normal side by side, assuming the
    # frames are (H, W, C) arrays - TODO confirm against render_utils.
    video_frames = [
        np.concatenate([color, normal], axis=1)
        for color, normal in zip(color_frames, normal_frames)
    ]
    imageio.mimsave(video_path, video_frames, fps=30)


def _export_glb(gaussian, mesh, glb_path: str, progress) -> None:
    """Convert the result to GLB via ``postprocessing_utils`` and save it."""
    progress(0, desc="Exporting GLB...")
    glb = postprocessing_utils.to_glb(gaussian, mesh, simplify=0.98, texture_size=1024)
    glb.visual.material.metallicFactor = 0.0
    glb.export(glb_path)


@spaces.GPU(duration=300)
def run_extend3d(
    image_pil,
    seed: int,
    randomize_seed: bool,
    width: int,
    length: int,
    div: int,
    ss_optim: bool,
    ss_iterations: int,
    ss_steps: int,
    ss_rescale_t: float,
    ss_t_noise: float,
    ss_t_start: float,
    ss_cfg_strength: float,
    ss_alpha: float,
    ss_batch_size: int,
    slat_optim: bool,
    slat_steps: int,
    slat_rescale_t: float,
    slat_cfg_strength: float,
    slat_batch_size: int,
    progress=gr.Progress(),
):
    """Run the Extend3D pipeline on one image and export the previews.

    Args:
        image_pil: Input image from the ``gr.Image(type="pil")`` component;
            ``None`` when the user has not supplied one.
        seed: Seed for torch, NumPy and ``random``; ignored when
            ``randomize_seed`` is true.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]`` instead of
            using ``seed``.
        width, length, div: Forwarded unchanged to ``PIPELINE.run``.
        ss_*: "Sparse Structure Settings" values, forwarded unchanged to
            ``PIPELINE.run``.
        slat_*: "SLAT Settings" values, forwarded unchanged to
            ``PIPELINE.run``.
        progress: Gradio progress tracker (the default instance is how Gradio
            is asked to supply one).

    Returns:
        ``(video_path, glb_path, seed)``: paths of the written preview files
        under ``DEFAULT_OUTPUT_DIR`` and the seed that was actually used.

    Raises:
        gr.Error: If no input image was provided.
    """
    # Fail with a readable UI message instead of handing None to the pipeline.
    if image_pil is None:
        raise gr.Error("Please provide an input image before running.")

    # int() is defensive: np.random.seed rejects non-integer values, so a seed
    # arriving as e.g. 42.0 would otherwise fail.
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Arguments are positional and must stay in this exact order.
    output = PIPELINE.run(
        image_pil,
        width,
        length,
        div,
        ss_optim,
        ss_iterations,
        ss_steps,
        ss_rescale_t,
        ss_t_noise,
        ss_t_start,
        ss_cfg_strength,
        ss_alpha,
        ss_batch_size,
        slat_optim,
        slat_steps,
        slat_rescale_t,
        slat_cfg_strength,
        slat_batch_size,
        progress_callback=lambda frac, desc: progress(frac, desc=desc),
    )

    # Only the first entry of each output list is used.
    gaussian = output["gaussian"][0]
    mesh = output["mesh"][0]

    out_dir = Path(DEFAULT_OUTPUT_DIR)
    out_dir.mkdir(parents=True, exist_ok=True)
    # A per-run id keeps output files from different runs apart.
    run_id = uuid.uuid4().hex

    # Render preview video
    video_path = str(out_dir / f"preview_{run_id}.mp4")
    _render_preview_video(gaussian, mesh, video_path, progress)

    # Export GLB mesh
    glb_path = str(out_dir / f"preview_{run_id}.glb")
    _export_glb(gaussian, mesh, glb_path, progress)

    progress(1.0, desc="Done!")
    return video_path, glb_path, seed


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------

css = """
#examples_gallery .gallery-item {
    width: 160px !important;
    height: 160px !important;
    min-width: 160px !important;
}
#examples_gallery img {
    width: 100% !important;
    height: 100% !important;
    object-fit: cover;
}
#examples_gallery .gallery {
    width: 100% !important;
    height: 100% !important;
    object-fit: cover;
    max-width: none !important;
    justify-content: center;
}
"""

# NOTE(review): the nesting of the layout below was reconstructed from a
# whitespace-collapsed source. Component order, arguments and event wiring are
# exact; the placements marked "confirm" are best guesses.
with gr.Blocks(title="Extend3D Demo", css=css) as demo:
    gr.Markdown("# Extend3D: Town-scale 3D Generation")
    gr.Markdown("[Project Page](https://seungwoo-yoon.github.io/extend3d-page/) | [Code](https://github.com/Seungwoo-Yoon/Extend3D) | [Paper](#)")

    with gr.Row():
        # Left column: inputs and settings
        with gr.Column(scale=4, min_width=420):
            gr.Markdown("### Input")
            image_in = gr.Image(label="Input Image", type="pil")
            run_btn = gr.Button("Run", variant="primary")

            with gr.Accordion("Settings", open=True):
                seed = gr.Slider(0, MAX_SEED, value=42, step=1, label="seed")
                randomize_seed = gr.Checkbox(value=True, label="randomize_seed")

                with gr.Row():
                    width = gr.Slider(1, 8, value=2, step=1, label="width")
                    length = gr.Slider(1, 8, value=2, step=1, label="length")
                    div = gr.Slider(1, 8, value=4, step=1, label="div")

                # NOTE(review): confirm the two sub-accordions sit inside
                # "Settings" rather than beside it.
                with gr.Accordion("Sparse Structure Settings", open=False):
                    ss_optim = gr.Checkbox(value=True, label="optimize")
                    with gr.Row():
                        ss_iterations = gr.Slider(1, 10, value=3, step=1, label="iterations")
                        ss_steps = gr.Slider(1, 100, value=25, step=1, label="steps")
                    with gr.Row():
                        ss_rescale_t = gr.Slider(1, 10, value=3.0, step=0.1, label="rescale_t")
                        ss_cfg_strength = gr.Slider(1, 10, value=7.5, step=0.1, label="cfg_strength")
                    with gr.Row():
                        ss_t_noise = gr.Slider(0, 1, value=0.6, step=0.1, label="t_noise")
                        ss_t_start = gr.Slider(0, 1, value=0.8, step=0.1, label="t_start")
                    # NOTE(review): confirm these two are outside the row above.
                    ss_alpha = gr.Slider(1, 10, value=5.0, step=0.1, label="alpha")
                    ss_batch_size = gr.Slider(1, 16, value=1, step=1, label="batch_size")

                with gr.Accordion("SLAT Settings", open=False):
                    slat_optim = gr.Checkbox(value=True, label="optimize")
                    with gr.Row():
                        slat_steps = gr.Slider(1, 100, value=25, step=1, label="steps")
                    with gr.Row():
                        slat_rescale_t = gr.Slider(1, 10, value=3.0, step=0.1, label="rescale_t")
                        slat_cfg_strength = gr.Slider(1, 10, value=3.0, step=0.1, label="cfg_strength")
                    # NOTE(review): confirm this is outside the row above.
                    slat_batch_size = gr.Slider(1, 16, value=1, step=1, label="batch_size")

        # Right column: outputs
        with gr.Column(scale=5, min_width=420):
            gr.Markdown("### Output")
            preview_video = gr.Video(label="3D Preview (Video)", value=None, autoplay=True, loop=True)
            preview_glb = gr.Model3D(label="3D Preview (GLB)", value=None)

    # NOTE(review): confirm the examples gallery belongs below the two columns.
    gr.Examples(
        examples=[
            "assets/examples/0.png",
            "assets/examples/1.png",
            "assets/examples/2.png",
            "assets/examples/3.png",
            "assets/examples/4.png",
            "assets/examples/5.webp",
        ],
        inputs=[image_in],
        label="Examples",
        examples_per_page=6,
        elem_id="examples_gallery",
    )

    # Input order must match the positional parameters of run_extend3d; the
    # seed actually used is written back into the seed slider.
    run_btn.click(
        fn=run_extend3d,
        inputs=[
            image_in,
            seed,
            randomize_seed,
            width,
            length,
            div,
            ss_optim,
            ss_iterations,
            ss_steps,
            ss_rescale_t,
            ss_t_noise,
            ss_t_start,
            ss_cfg_strength,
            ss_alpha,
            ss_batch_size,
            slat_optim,
            slat_steps,
            slat_rescale_t,
            slat_cfg_strength,
            slat_batch_size,
        ],
        outputs=[preview_video, preview_glb, seed],
    )


if __name__ == "__main__":
    demo.launch()