"""Tile Visualizer: floor segmentation, depth estimation and AI tile rendering.

Exposes two Gradio endpoints:

* ``predict``   - returns a binary floor mask and a normalised depth map.
* ``render_ai`` - inpaints the detected floor area with a tiled texture using
  an SDXL inpainting pipeline guided by a ControlNet.
"""

import re

import spaces
import numpy as np
import torch
import gradio as gr
from PIL import Image, ImageDraw
from transformers import (
    AutoImageProcessor,
    Mask2FormerForUniversalSegmentation,
    AutoModelForDepthEstimation,
)
from diffusers import (
    StableDiffusionXLControlNetInpaintPipeline,
    ControlNetModel,
    AutoencoderKL,
)
from starlette.middleware.cors import CORSMiddleware

# Longest image side (in pixels) fed to the segmentation and depth models.
MAX_INFERENCE_SIZE = 1024

# ─── Segmentation + Depth models ─────────────────
seg_processor = AutoImageProcessor.from_pretrained(
    "facebook/mask2former-swin-large-ade-semantic"
)
seg_model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-large-ade-semantic"
)

FLOOR_KEYWORDS = {'floor', 'flooring', 'rug', 'carpet', 'mat'}


def _is_floor_label(label):
    """Return True when *label* contains a floor keyword as a whole word.

    Whole-word matching is deliberate: a plain substring test lets the
    keyword ``mat`` match unrelated labels such as "automatic washer" or
    "animate being", which would then be painted into the floor mask.
    """
    words = re.findall(r"[a-z]+", label.lower())
    return any(word in FLOOR_KEYWORDS for word in words)


FLOOR_IDS = set()
id2label = seg_model.config.id2label
for idx, label in id2label.items():
    if _is_floor_label(label):
        FLOOR_IDS.add(int(idx))
        print(f"Floor class: {idx} = {label}")
if not FLOOR_IDS:
    # Fallback when no label matched.
    # NOTE(review): presumably the ADE20K ids for "floor" and "rug" - confirm.
    FLOOR_IDS = {3, 28}

depth_processor = AutoImageProcessor.from_pretrained(
    "depth-anything/Depth-Anything-V2-Large-hf"
)
depth_model = AutoModelForDepthEstimation.from_pretrained(
    "depth-anything/Depth-Anything-V2-Large-hf",
    torch_dtype=torch.float16,
)

# ─── SDXL + ControlNet Tile for AI rendering ─────
print("Loading ControlNet Tile + SDXL inpainting pipeline...")
controlnet = ControlNetModel.from_pretrained(
    "xinsir/controlnet-tile-sdxl-1.0",
    torch_dtype=torch.float16,
)
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)
inpaint_pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
)
inpaint_pipe.enable_model_cpu_offload()
print("Pipeline loaded.")


def _resize_for_inference(image, max_size=MAX_INFERENCE_SIZE):
    """Return a copy of *image* whose longest side is at most *max_size*.

    The aspect ratio is preserved and images are never upscaled. Each
    dimension is clamped to at least one pixel so that extremely elongated
    inputs do not produce a zero-sized target, which ``Image.resize`` rejects.
    """
    orig_w, orig_h = image.size
    scale = min(1.0, max_size / max(orig_w, orig_h))
    proc_w = max(1, int(orig_w * scale))
    proc_h = max(1, int(orig_h * scale))
    return image.resize((proc_w, proc_h), Image.LANCZOS)


def _segment_floor(image_resized, output_size):
    """Run semantic segmentation and return the floor mask as a PIL image.

    Args:
        image_resized: PIL image already downscaled for inference.
        output_size: ``(width, height)`` of the returned mask.

    Returns:
        Single-channel PIL image of size *output_size*; pixels whose predicted
        class is in ``FLOOR_IDS`` are 255, all others 0.

    Callers are expected to run this under ``torch.inference_mode()``; the
    helper itself is undecorated so both endpoints can share it.
    """
    proc_w, proc_h = image_resized.size
    device = seg_model.device

    seg_inputs = seg_processor(images=image_resized, return_tensors="pt")
    seg_inputs = {k: v.to(device) for k, v in seg_inputs.items()}
    seg_outputs = seg_model(**seg_inputs)
    seg_result = seg_processor.post_process_semantic_segmentation(
        seg_outputs, target_sizes=[(proc_h, proc_w)]
    )[0]
    seg_map = seg_result.cpu().numpy()

    unique_classes = np.unique(seg_map)
    print(f"Detected classes: {[(int(c), id2label.get(int(c), '?')) for c in unique_classes]}")

    is_floor = np.isin(seg_map, sorted(FLOOR_IDS))
    floor_mask = np.where(is_floor, np.uint8(255), np.uint8(0))
    # NEAREST keeps the mask strictly binary when rescaling.
    return Image.fromarray(floor_mask).resize(output_size, Image.NEAREST)


def _estimate_depth(image_resized, output_size):
    """Run depth estimation and return the result as an 8-bit PIL image.

    Args:
        image_resized: PIL image already downscaled for inference.
        output_size: ``(width, height)`` of the returned depth image.

    Returns:
        Single-channel PIL image of size *output_size* in which the predicted
        depth is min-max normalised to 0..255. A constant prediction yields an
        all-zero image.

    Callers are expected to run this under ``torch.inference_mode()``.
    """
    device = depth_model.device
    dtype = depth_model.dtype

    depth_inputs = depth_processor(images=image_resized, return_tensors="pt")
    # Floating-point inputs must match the model's dtype; others only move.
    depth_inputs = {
        k: v.to(device, dtype=dtype) if v.is_floating_point() else v.to(device)
        for k, v in depth_inputs.items()
    }
    depth_outputs = depth_model(**depth_inputs)

    # Upcast before normalising so the scaling is not done in half precision.
    depth_map = depth_outputs.predicted_depth.squeeze().float().cpu().numpy()
    depth_min, depth_max = depth_map.min(), depth_map.max()
    if depth_max - depth_min > 0:
        depth_norm = (
            (depth_map - depth_min) / (depth_max - depth_min) * 255
        ).astype(np.uint8)
    else:
        depth_norm = np.zeros_like(depth_map, dtype=np.uint8)
    return Image.fromarray(depth_norm).resize(output_size, Image.BILINEAR)


@spaces.GPU(duration=60)
@torch.inference_mode()
def predict(image):
    """Segment the floor and estimate depth for a room photo.

    Args:
        image: PIL image of the room.

    Returns:
        Tuple ``(mask_img, depth_img)`` of single-channel PIL images, both at
        the input image's size.

    Raises:
        gr.Error: If *image* is ``None``.
    """
    if image is None:
        raise gr.Error("No image provided")

    image_resized = _resize_for_inference(image)
    mask_img = _segment_floor(image_resized, image.size)
    depth_img = _estimate_depth(image_resized, image.size)
    return mask_img, depth_img


def create_tiled_control_image(tile_texture, width, height):
    """Tile the texture image to fill width x height."""
    tw, th = tile_texture.size
    control = Image.new("RGB", (width, height))
    for y in range(0, height, th):
        for x in range(0, width, tw):
            control.paste(tile_texture, (x, y))
    return control


@spaces.GPU(duration=120)
@torch.inference_mode()
def render_ai(room_image, tile_texture):
    """Inpaint the floor of *room_image* with a tiled version of *tile_texture*.

    Args:
        room_image: PIL image of the room.
        tile_texture: PIL image of a single tile.

    Returns:
        PIL image at the size of *room_image* containing the pipeline output.

    Raises:
        gr.Error: If either input is ``None``.
    """
    if room_image is None or tile_texture is None:
        raise gr.Error("Room image and tile texture are required")

    # Step 1: Get floor mask. The shared helper is called directly instead of
    # reaching through ``predict``'s decorators, and the depth pass (unused
    # here) is skipped.
    mask_img = _segment_floor(_resize_for_inference(room_image), room_image.size)

    # Resize everything to 1024x1024 for SDXL
    size = 1024
    room_resized = room_image.resize((size, size), Image.LANCZOS)
    mask_resized = mask_img.resize((size, size), Image.NEAREST)

    # Step 2: Create tiled control image from tile texture
    tile_size = max(64, size // 8)
    tile_resized = tile_texture.resize((tile_size, tile_size), Image.LANCZOS)
    control_image = create_tiled_control_image(tile_resized, size, size)

    # Step 3: Run SDXL inpainting with ControlNet Tile
    result = inpaint_pipe(
        prompt="ceramic tile floor, tiled floor with repeating pattern, interior design photo, photorealistic",
        negative_prompt="blurry, distorted, low quality, watermark, text",
        image=room_resized,
        mask_image=mask_resized,
        control_image=control_image,
        num_inference_steps=25,
        guidance_scale=7.0,
        controlnet_conditioning_scale=0.9,
        strength=0.95,
        # Fixed seed so repeated requests use the same initial noise.
        generator=torch.Generator(device="cuda").manual_seed(42),
    ).images[0]

    # Resize back to original dimensions
    result = result.resize((room_image.size[0], room_image.size[1]), Image.LANCZOS)
    return result


with gr.Blocks() as demo:
    gr.Markdown("# Tile Visualizer API")

    with gr.Tab("Segmentation"):
        with gr.Row():
            seg_input = gr.Image(type="pil", label="Room photo")
        with gr.Row():
            mask_output = gr.Image(type="pil", label="Floor mask")
            depth_output = gr.Image(type="pil", label="Depth map")
        seg_btn = gr.Button("Segment")
        seg_btn.click(fn=predict, inputs=seg_input, outputs=[mask_output, depth_output])

    with gr.Tab("AI Render"):
        with gr.Row():
            render_room = gr.Image(type="pil", label="Room photo")
            render_tile = gr.Image(type="pil", label="Tile texture")
        render_output = gr.Image(type="pil", label="Result")
        render_btn = gr.Button("Render")
        render_btn.click(fn=render_ai, inputs=[render_room, render_tile], outputs=render_output)

# Allow cross-origin requests to the underlying ASGI app.
# NOTE(review): depending on the Gradio version, `launch()` may build a fresh
# app object, in which case middleware added here would not apply - confirm.
app = demo.app
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

if __name__ == "__main__":
    demo.launch(ssr_mode=False)