XciD's picture
XciD HF Staff
feat: add AI render mode with SDXL + ControlNet Tile inpainting
9a8a023 unverified
import spaces
import numpy as np
import torch
import gradio as gr
from PIL import Image, ImageDraw
from transformers import (
AutoImageProcessor,
Mask2FormerForUniversalSegmentation,
AutoModelForDepthEstimation,
)
from diffusers import (
StableDiffusionXLControlNetInpaintPipeline,
ControlNetModel,
AutoencoderKL,
)
# ─── Segmentation + Depth models ─────────────────
# Semantic-segmentation model used by predict() to locate floor pixels.
seg_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-large-ade-semantic")
seg_model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-large-ade-semantic"
)

# Class IDs treated as "floor" are discovered from the model's own label map
# by looking for any of these keywords in each label.
FLOOR_KEYWORDS = {'floor', 'flooring', 'rug', 'carpet', 'mat'}
FLOOR_IDS = set()
id2label = seg_model.config.id2label
for idx, label in id2label.items():
    # Compare whole words, not substrings: a plain `kw in label` test lets a
    # short keyword such as "mat" match unrelated labels that merely contain
    # it (e.g. "automatic"), which would mark those objects as floor.
    label_words = "".join(ch if ch.isalnum() else " " for ch in label.lower()).split()
    if FLOOR_KEYWORDS.intersection(label_words):
        FLOOR_IDS.add(int(idx))
        print(f"Floor class: {idx} = {label}")
if not FLOOR_IDS:
    # Hard-coded fallback used when no label matched any keyword.
    # NOTE(review): presumably the floor / rug class IDs of this checkpoint's
    # label set — confirm against the model config.
    FLOOR_IDS = {3, 28}

# Monocular depth model; weights are loaded in half precision.
depth_processor = AutoImageProcessor.from_pretrained("depth-anything/Depth-Anything-V2-Large-hf")
depth_model = AutoModelForDepthEstimation.from_pretrained(
    "depth-anything/Depth-Anything-V2-Large-hf", torch_dtype=torch.float16
)

# ─── SDXL + ControlNet Tile for AI rendering ─────
print("Loading ControlNet Tile + SDXL inpainting pipeline...")
controlnet = ControlNetModel.from_pretrained(
    "xinsir/controlnet-tile-sdxl-1.0",
    torch_dtype=torch.float16,
)
# Separate fp16 VAE checkpoint passed to the pipeline in place of its default.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)
inpaint_pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
)
# Device placement of the pipeline's sub-models is delegated to diffusers'
# CPU-offload mechanism instead of an explicit `.to("cuda")`.
inpaint_pipe.enable_model_cpu_offload()
print("Pipeline loaded.")
@spaces.GPU(duration=60)
@torch.inference_mode()
def predict(image):
    """Segment the floor of a room photo and estimate a depth map.

    The image is downscaled (never upscaled) so that its longest side is at
    most 1024 px before both models run; both outputs are resized back to the
    input resolution.

    Args:
        image: PIL image of the room.

    Returns:
        A ``(mask_img, depth_img)`` tuple of single-channel PIL images with
        the same size as ``image``. ``mask_img`` is 255 where the predicted
        class is in ``FLOOR_IDS`` and 0 elsewhere. ``depth_img`` is the
        predicted depth min-max normalised to 0-255 (all zeros when the
        prediction is constant).

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        raise gr.Error("No image provided")

    orig_w, orig_h = image.size
    max_size = 1024
    scale = min(1.0, max_size / max(orig_w, orig_h))
    # Clamp to 1 px: with an extreme aspect ratio the short side would
    # otherwise truncate to 0 and produce an empty working image.
    proc_w = max(1, int(orig_w * scale))
    proc_h = max(1, int(orig_h * scale))
    image_resized = image.resize((proc_w, proc_h), Image.LANCZOS)

    # --- Semantic segmentation -> binary floor mask ---
    seg_device = seg_model.device
    seg_inputs = seg_processor(images=image_resized, return_tensors="pt")
    seg_inputs = {k: v.to(seg_device) for k, v in seg_inputs.items()}
    seg_outputs = seg_model(**seg_inputs)
    seg_result = seg_processor.post_process_semantic_segmentation(
        seg_outputs, target_sizes=[(proc_h, proc_w)]
    )[0]
    seg_map = seg_result.cpu().numpy()

    unique_classes = np.unique(seg_map)
    print(f"Detected classes: {[(int(c), id2label.get(c, '?')) for c in unique_classes]}")

    # One vectorised membership test instead of one pass per floor class.
    floor_mask = np.where(np.isin(seg_map, list(FLOOR_IDS)), 255, 0).astype(np.uint8)
    # NEAREST keeps the mask strictly binary when scaling back up.
    mask_img = Image.fromarray(floor_mask).resize((orig_w, orig_h), Image.NEAREST)

    # --- Depth estimation ---
    # Follow the depth model's own device and dtype rather than borrowing the
    # segmentation model's device and assuming fp16 weights.
    depth_device = depth_model.device
    depth_dtype = depth_model.dtype
    depth_inputs = depth_processor(images=image_resized, return_tensors="pt")
    depth_inputs = {
        k: (
            v.to(depth_device, dtype=depth_dtype)
            if v.is_floating_point()
            else v.to(depth_device)
        )
        for k, v in depth_inputs.items()
    }
    depth_outputs = depth_model(**depth_inputs)
    # Promote to float32 before normalising so the min-max arithmetic is not
    # carried out in half precision.
    depth_map = depth_outputs.predicted_depth.squeeze().float().cpu().numpy()

    depth_min, depth_max = depth_map.min(), depth_map.max()
    depth_range = depth_max - depth_min
    if depth_range > 0:
        depth_norm = ((depth_map - depth_min) / depth_range * 255).astype(np.uint8)
    else:
        # Constant prediction: avoid dividing by zero.
        depth_norm = np.zeros_like(depth_map, dtype=np.uint8)
    depth_img = Image.fromarray(depth_norm).resize((orig_w, orig_h), Image.BILINEAR)

    return mask_img, depth_img
def create_tiled_control_image(tile_texture, width, height):
    """Build a ``width`` x ``height`` RGB canvas covered with copies of a texture.

    Copies of ``tile_texture`` are pasted on a regular grid that starts at the
    top-left corner and advances by the texture's own width and height,
    row by row.
    """
    step_x, step_y = tile_texture.size
    canvas = Image.new("RGB", (width, height))
    # Top-left corner of every tile, listed row by row.
    corners = [
        (left, top)
        for top in range(0, height, step_y)
        for left in range(0, width, step_x)
    ]
    for corner in corners:
        canvas.paste(tile_texture, corner)
    return canvas
@spaces.GPU(duration=120)
@torch.inference_mode()
def render_ai(room_image, tile_texture):
    """Re-render the floor of a room photo with the given tile texture.

    The floor mask comes from ``predict``; the room, mask and a tiled copy of
    the texture are then passed to ``inpaint_pipe`` at 1024x1024. The
    generated image is scaled back to the input size and composited onto the
    original photo through the floor mask, so pixels outside the floor are
    returned untouched instead of having been resized twice and regenerated.

    Args:
        room_image: PIL image of the room.
        tile_texture: PIL image of a single tile / texture sample.

    Returns:
        An RGB PIL image with the same size as ``room_image``.

    Raises:
        gr.Error: If either input is ``None``, or if no floor pixels were
            detected in ``room_image``.
    """
    if room_image is None or tile_texture is None:
        raise gr.Error("Room image and tile texture are required")

    # Work on an RGB copy so the final composite has matching modes.
    room_rgb = room_image.convert("RGB")

    # Step 1: Get floor mask
    # Call the undecorated body of predict(); this function is already
    # GPU-decorated itself.
    # NOTE(review): relies on the decorators exposing `__wrapped__`; falls
    # back to the decorated callable if they do not.
    run_predict = getattr(predict, "__wrapped__", predict)
    mask_img, _ = run_predict(room_rgb)
    if mask_img.getbbox() is None:
        # Nothing to repaint: fail fast rather than spend a full diffusion
        # run on an empty mask.
        raise gr.Error("No floor detected in the room image")

    # Resize everything to 1024x1024 for SDXL
    size = 1024
    room_resized = room_rgb.resize((size, size), Image.LANCZOS)
    mask_resized = mask_img.resize((size, size), Image.NEAREST)

    # Step 2: Create tiled control image from tile texture
    tile_size = max(64, size // 8)
    tile_resized = tile_texture.resize((tile_size, tile_size), Image.LANCZOS)
    control_image = create_tiled_control_image(tile_resized, size, size)

    # Step 3: Run SDXL inpainting with ControlNet Tile
    # Fixed seed for repeatable output; use CUDA when present and fall back
    # to CPU instead of failing on a hard-coded device.
    generator_device = "cuda" if torch.cuda.is_available() else "cpu"
    result = inpaint_pipe(
        prompt="ceramic tile floor, tiled floor with repeating pattern, interior design photo, photorealistic",
        negative_prompt="blurry, distorted, low quality, watermark, text",
        image=room_resized,
        mask_image=mask_resized,
        control_image=control_image,
        num_inference_steps=25,
        guidance_scale=7.0,
        controlnet_conditioning_scale=0.9,
        strength=0.95,
        generator=torch.Generator(device=generator_device).manual_seed(42),
    ).images[0]

    # Resize back to original dimensions
    result = result.resize(room_rgb.size, Image.LANCZOS)
    # Keep generated pixels only where the mask marks floor; everything else
    # is taken from the original photo.
    return Image.composite(result, room_rgb, mask_img)
# ─── Gradio UI ───────────────────────────────────
# Two tabs: "Segmentation" runs predict() and "AI Render" runs render_ai().
# Components appear in the page in the order they are created below.
with gr.Blocks() as demo:
    gr.Markdown("# Tile Visualizer API")
    with gr.Tab("Segmentation"):
        with gr.Row():
            seg_input = gr.Image(type="pil", label="Room photo")
        with gr.Row():
            mask_output = gr.Image(type="pil", label="Floor mask")
            depth_output = gr.Image(type="pil", label="Depth map")
        seg_btn = gr.Button("Segment")
        # predict() returns (mask, depth), mapped onto the two output images.
        seg_btn.click(fn=predict, inputs=seg_input, outputs=[mask_output, depth_output])
    with gr.Tab("AI Render"):
        with gr.Row():
            render_room = gr.Image(type="pil", label="Room photo")
            render_tile = gr.Image(type="pil", label="Tile texture")
        render_output = gr.Image(type="pil", label="Result")
        render_btn = gr.Button("Render")
        render_btn.click(fn=render_ai, inputs=[render_room, render_tile], outputs=render_output)
# Expose the app object behind the Blocks instance so middleware can be
# attached to it.
# NOTE(review): confirm that demo.launch() serves this same app instance;
# if it builds a new one, the middleware added below is not in effect.
app = demo.app
# Imported here, next to its only use, rather than at the top of the file.
from starlette.middleware.cors import CORSMiddleware
# SECURITY: any origin, method and header is allowed. Restrict
# `allow_origins` if the endpoints should only be called from known sites.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
# Start the server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch(ssr_mode=False)