Spaces:
Sleeping
Sleeping
| import spaces | |
| import numpy as np | |
| import torch | |
| import gradio as gr | |
| from PIL import Image, ImageDraw | |
| from transformers import ( | |
| AutoImageProcessor, | |
| Mask2FormerForUniversalSegmentation, | |
| AutoModelForDepthEstimation, | |
| ) | |
| from diffusers import ( | |
| StableDiffusionXLControlNetInpaintPipeline, | |
| ControlNetModel, | |
| AutoencoderKL, | |
| ) | |
# ─── Segmentation + Depth models ─────────────────
# Semantic-segmentation model; its per-pixel class ids are used to find the floor.
seg_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-large-ade-semantic")
seg_model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-large-ade-semantic"
)

FLOOR_KEYWORDS = {'floor', 'flooring', 'rug', 'carpet', 'mat'}


def _label_words(label):
    """Return the set of lowercase words in a class label such as "rug, carpet"."""
    return set(label.lower().replace(",", " ").replace(";", " ").split())


# Collect the id of every class whose label contains a floor keyword as a
# whole word. Plain substring matching is too greedy for short keywords:
# "mat" is also a substring of unrelated words such as "automatic" or
# "animate", which would mark those classes as floor.
id2label = seg_model.config.id2label
FLOOR_IDS = set()
for idx, label in id2label.items():
    if FLOOR_KEYWORDS & _label_words(label):
        FLOOR_IDS.add(int(idx))
        print(f"Floor class: {idx} = {label}")
if not FLOOR_IDS:
    # Hard-coded fallback when no label matched.
    # NOTE(review): presumably the floor and rug class ids of this checkpoint - confirm.
    FLOOR_IDS = {3, 28}

# Monocular depth model, loaded in half precision.
depth_processor = AutoImageProcessor.from_pretrained("depth-anything/Depth-Anything-V2-Large-hf")
depth_model = AutoModelForDepthEstimation.from_pretrained(
    "depth-anything/Depth-Anything-V2-Large-hf", torch_dtype=torch.float16
)
# ─── SDXL + ControlNet Tile for AI rendering ─────
print("Loading ControlNet Tile + SDXL inpainting pipeline...")

# Every diffusion component is loaded in half precision (float16).
controlnet = ControlNetModel.from_pretrained("xinsir/controlnet-tile-sdxl-1.0", torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

# The inpainting pipeline is assembled around the ControlNet and VAE loaded above.
inpaint_pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    torch_dtype=torch.float16,
    variant="fp16",
    controlnet=controlnet,
    vae=vae,
)
# Device placement is delegated to diffusers' model CPU offload mode.
inpaint_pipe.enable_model_cpu_offload()

print("Pipeline loaded.")
def predict(image):
    """Compute a floor mask and a depth map for a room photo.

    Args:
        image: PIL image of the room. Any mode is accepted; it is converted
            to RGB before being handed to the models.

    Returns:
        A ``(mask_img, depth_img)`` tuple of PIL images resized to the
        input's original dimensions. ``mask_img`` is 255 where the predicted
        class id is in ``FLOOR_IDS`` and 0 elsewhere. ``depth_img`` is the
        predicted depth min-max normalised to the 0-255 range (all zeros if
        the prediction is constant).

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        raise gr.Error("No image provided")

    # Uploads may be RGBA / L / P; normalise to three channels for the processors.
    image = image.convert("RGB")
    orig_w, orig_h = image.size

    # Downscale so the longest side is at most max_size (never upscale).
    # Clamp to 1 px so extreme aspect ratios cannot truncate a side to 0.
    max_size = 1024
    scale = min(1.0, max_size / max(orig_w, orig_h))
    proc_w = max(1, int(orig_w * scale))
    proc_h = max(1, int(orig_h * scale))
    image_resized = image.resize((proc_w, proc_h), Image.LANCZOS)

    # ── Semantic segmentation ──
    seg_device = seg_model.device
    seg_inputs = seg_processor(images=image_resized, return_tensors="pt")
    seg_inputs = {k: v.to(seg_device) for k, v in seg_inputs.items()}
    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        seg_outputs = seg_model(**seg_inputs)
    seg_result = seg_processor.post_process_semantic_segmentation(
        seg_outputs, target_sizes=[(proc_h, proc_w)]
    )[0]
    seg_map = seg_result.cpu().numpy()

    unique_classes = np.unique(seg_map)
    print(f"Detected classes: {[(int(c), id2label.get(int(c), '?')) for c in unique_classes]}")

    # 255 for every pixel whose class id is a floor class, 0 otherwise.
    floor_mask = np.where(np.isin(seg_map, sorted(FLOOR_IDS)), 255, 0).astype(np.uint8)
    mask_img = Image.fromarray(floor_mask).resize((orig_w, orig_h), Image.NEAREST)

    # ── Depth estimation ──
    # Use the depth model's own device and dtype rather than the segmentation
    # model's device and a hard-coded float16.
    depth_device = depth_model.device
    depth_dtype = depth_model.dtype
    depth_inputs = depth_processor(images=image_resized, return_tensors="pt")
    depth_inputs = {
        k: v.to(depth_device, dtype=depth_dtype) if v.is_floating_point() else v.to(depth_device)
        for k, v in depth_inputs.items()
    }
    with torch.no_grad():
        depth_outputs = depth_model(**depth_inputs)
    # Cast to float32 before normalising so the arithmetic is not done in half precision.
    depth_map = depth_outputs.predicted_depth.squeeze().float().cpu().numpy()

    depth_min, depth_max = depth_map.min(), depth_map.max()
    if depth_max - depth_min > 0:
        depth_norm = ((depth_map - depth_min) / (depth_max - depth_min) * 255).astype(np.uint8)
    else:
        # Constant prediction: avoid dividing by zero.
        depth_norm = np.zeros_like(depth_map, dtype=np.uint8)
    depth_img = Image.fromarray(depth_norm).resize((orig_w, orig_h), Image.BILINEAR)

    return mask_img, depth_img
def create_tiled_control_image(tile_texture, width, height):
    """Build a width x height RGB image by repeating ``tile_texture``.

    The texture is pasted row by row starting at the top-left corner, with
    paste positions stepping by the texture's own width and height.
    """
    tile_w, tile_h = tile_texture.size
    canvas = Image.new("RGB", (width, height))
    offsets = [
        (left, top)
        for top in range(0, height, tile_h)
        for left in range(0, width, tile_w)
    ]
    for offset in offsets:
        canvas.paste(tile_texture, offset)
    return canvas
def render_ai(room_image, tile_texture):
    """Repaint the floor of a room photo with a tile texture.

    The floor mask comes from ``predict``; the tile texture is repeated into
    a control image and both are passed to the SDXL ControlNet inpainting
    pipeline at 1024x1024. The result is resized back to the room photo's
    original dimensions.

    Args:
        room_image: PIL image of the room.
        tile_texture: PIL image of a single tile.

    Returns:
        PIL image with the same size as ``room_image``.

    Raises:
        gr.Error: If either input is ``None``.
    """
    if room_image is None or tile_texture is None:
        raise gr.Error("Room image and tile texture are required")

    # Uploads may carry an alpha channel or a palette; work in plain RGB.
    room_image = room_image.convert("RGB")
    tile_texture = tile_texture.convert("RGB")

    # Step 1: Get floor mask.
    # `predict` only has a `__wrapped__` attribute when it is decorated with
    # a functools.wraps-style wrapper; fall back to calling it directly so an
    # undecorated `predict` does not raise AttributeError.
    segment = getattr(predict, "__wrapped__", predict)
    mask_img, _ = segment(room_image)

    # Resize everything to 1024x1024 for SDXL
    size = 1024
    room_resized = room_image.resize((size, size), Image.LANCZOS)
    mask_resized = mask_img.resize((size, size), Image.NEAREST)

    # Step 2: Create tiled control image from tile texture
    tile_size = max(64, size // 8)
    tile_resized = tile_texture.resize((tile_size, tile_size), Image.LANCZOS)
    control_image = create_tiled_control_image(tile_resized, size, size)

    # Fixed seed for reproducible renders; use CUDA when present and fall
    # back to CPU instead of failing on machines without a GPU.
    generator_device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.Generator(device=generator_device).manual_seed(42)

    # Step 3: Run SDXL inpainting with ControlNet Tile
    result = inpaint_pipe(
        prompt="ceramic tile floor, tiled floor with repeating pattern, interior design photo, photorealistic",
        negative_prompt="blurry, distorted, low quality, watermark, text",
        image=room_resized,
        mask_image=mask_resized,
        control_image=control_image,
        num_inference_steps=25,
        guidance_scale=7.0,
        controlnet_conditioning_scale=0.9,
        strength=0.95,
        generator=generator,
    ).images[0]

    # Resize back to original dimensions
    return result.resize(room_image.size, Image.LANCZOS)
# ─── Gradio UI ───────────────────────────────────
# NOTE(review): the nesting below was reconstructed from flattened source;
# confirm which widgets belong inside each gr.Row.
with gr.Blocks() as demo:
    gr.Markdown("# Tile Visualizer API")

    # Tab 1: floor mask + depth map for a room photo.
    with gr.Tab("Segmentation"):
        with gr.Row():
            seg_input = gr.Image(type="pil", label="Room photo")
        with gr.Row():
            mask_output = gr.Image(type="pil", label="Floor mask")
            depth_output = gr.Image(type="pil", label="Depth map")
        seg_btn = gr.Button("Segment")
        seg_btn.click(
            fn=predict,
            inputs=seg_input,
            outputs=[mask_output, depth_output],
        )

    # Tab 2: repaint the floor with a tile texture.
    with gr.Tab("AI Render"):
        with gr.Row():
            render_room = gr.Image(type="pil", label="Room photo")
            render_tile = gr.Image(type="pil", label="Tile texture")
        render_output = gr.Image(type="pil", label="Result")
        render_btn = gr.Button("Render")
        render_btn.click(
            fn=render_ai,
            inputs=[render_room, render_tile],
            outputs=render_output,
        )

# Attach a CORS middleware that allows every origin, method and header.
app = demo.app
from starlette.middleware.cors import CORSMiddleware

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

if __name__ == "__main__":
    demo.launch(ssr_mode=False)