XciD HF Staff committed on
Commit
9a8a023
·
unverified ·
1 Parent(s): 6bc9ac3

feat: add AI render mode with SDXL + ControlNet Tile inpainting

Browse files
Files changed (2) hide show
  1. app.py +96 -19
  2. requirements.txt +2 -0
app.py CHANGED
@@ -2,20 +2,25 @@ import spaces
2
  import numpy as np
3
  import torch
4
  import gradio as gr
5
- from PIL import Image
6
  from transformers import (
7
  AutoImageProcessor,
8
  Mask2FormerForUniversalSegmentation,
9
  AutoModelForDepthEstimation,
10
  )
 
 
 
 
 
 
 
11
 
12
- # Load models on CPU at startup. @spaces.GPU moves them to CUDA automatically.
13
  seg_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-large-ade-semantic")
14
  seg_model = Mask2FormerForUniversalSegmentation.from_pretrained(
15
  "facebook/mask2former-swin-large-ade-semantic"
16
  )
17
 
18
- # Find floor/rug class IDs from model config
19
  FLOOR_KEYWORDS = {'floor', 'flooring', 'rug', 'carpet', 'mat'}
20
  FLOOR_IDS = set()
21
  id2label = seg_model.config.id2label
@@ -31,8 +36,29 @@ depth_model = AutoModelForDepthEstimation.from_pretrained(
31
  "depth-anything/Depth-Anything-V2-Large-hf", torch_dtype=torch.float16
32
  )
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- @spaces.GPU
36
  @torch.inference_mode()
37
  def predict(image):
38
  if image is None:
@@ -46,10 +72,8 @@ def predict(image):
46
 
47
  device = seg_model.device
48
 
49
- # Segmentation (Mask2Former) - keep float32 for numerical stability
50
  seg_inputs = seg_processor(images=image_resized, return_tensors="pt")
51
  seg_inputs = {k: v.to(device) for k, v in seg_inputs.items()}
52
-
53
  seg_outputs = seg_model(**seg_inputs)
54
  seg_result = seg_processor.post_process_semantic_segmentation(
55
  seg_outputs, target_sizes=[(proc_h, proc_w)]
@@ -57,19 +81,15 @@ def predict(image):
57
 
58
  seg_map = seg_result.cpu().numpy()
59
  floor_mask = np.zeros((proc_h, proc_w), dtype=np.uint8)
60
- # Debug: log unique classes found
61
  unique_classes = np.unique(seg_map)
62
  print(f"Detected classes: {[(int(c), id2label.get(c, '?')) for c in unique_classes]}")
63
- print(f"Floor IDs: {FLOOR_IDS}")
64
  for class_id in FLOOR_IDS:
65
  floor_mask[seg_map == class_id] = 255
66
 
67
  mask_img = Image.fromarray(floor_mask).resize((orig_w, orig_h), Image.NEAREST)
68
 
69
- # Depth estimation
70
  depth_inputs = depth_processor(images=image_resized, return_tensors="pt")
71
  depth_inputs = {k: v.to(device, dtype=torch.float16) if v.is_floating_point() else v.to(device) for k, v in depth_inputs.items()}
72
-
73
  depth_outputs = depth_model(**depth_inputs)
74
  depth_map = depth_outputs.predicted_depth.squeeze().cpu().numpy()
75
 
@@ -84,17 +104,74 @@ def predict(image):
84
  return mask_img, depth_img
85
 
86
 
87
- with gr.Blocks() as demo:
88
- gr.Markdown("# Tile Visualizer - Segmentation API")
 
 
 
 
 
 
89
 
90
- with gr.Row():
91
- input_image = gr.Image(type="pil", label="Room photo")
92
- with gr.Row():
93
- mask_output = gr.Image(type="pil", label="Floor mask")
94
- depth_output = gr.Image(type="pil", label="Depth map")
95
 
96
- btn = gr.Button("Process")
97
- btn.click(fn=predict, inputs=input_image, outputs=[mask_output, depth_output])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  app = demo.app
100
 
 
2
  import numpy as np
3
  import torch
4
  import gradio as gr
5
+ from PIL import Image, ImageDraw
6
  from transformers import (
7
  AutoImageProcessor,
8
  Mask2FormerForUniversalSegmentation,
9
  AutoModelForDepthEstimation,
10
  )
11
+ from diffusers import (
12
+ StableDiffusionXLControlNetInpaintPipeline,
13
+ ControlNetModel,
14
+ AutoencoderKL,
15
+ )
16
+
17
+ # ─── Segmentation + Depth models ─────────────────
18
 
 
19
  seg_processor = AutoImageProcessor.from_pretrained("facebook/mask2former-swin-large-ade-semantic")
20
  seg_model = Mask2FormerForUniversalSegmentation.from_pretrained(
21
  "facebook/mask2former-swin-large-ade-semantic"
22
  )
23
 
 
24
  FLOOR_KEYWORDS = {'floor', 'flooring', 'rug', 'carpet', 'mat'}
25
  FLOOR_IDS = set()
26
  id2label = seg_model.config.id2label
 
36
  "depth-anything/Depth-Anything-V2-Large-hf", torch_dtype=torch.float16
37
  )
38
 
39
+ # ─── SDXL + ControlNet Tile for AI rendering ─────
40
+
41
+ print("Loading ControlNet Tile + SDXL inpainting pipeline...")
42
+ controlnet = ControlNetModel.from_pretrained(
43
+ "xinsir/controlnet-tile-sdxl-1.0",
44
+ torch_dtype=torch.float16,
45
+ )
46
+ vae = AutoencoderKL.from_pretrained(
47
+ "madebyollin/sdxl-vae-fp16-fix",
48
+ torch_dtype=torch.float16,
49
+ )
50
+ inpaint_pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
51
+ "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
52
+ controlnet=controlnet,
53
+ vae=vae,
54
+ torch_dtype=torch.float16,
55
+ variant="fp16",
56
+ )
57
+ inpaint_pipe.enable_model_cpu_offload()
58
+ print("Pipeline loaded.")
59
+
60
 
61
+ @spaces.GPU(duration=60)
62
  @torch.inference_mode()
63
  def predict(image):
64
  if image is None:
 
72
 
73
  device = seg_model.device
74
 
 
75
  seg_inputs = seg_processor(images=image_resized, return_tensors="pt")
76
  seg_inputs = {k: v.to(device) for k, v in seg_inputs.items()}
 
77
  seg_outputs = seg_model(**seg_inputs)
78
  seg_result = seg_processor.post_process_semantic_segmentation(
79
  seg_outputs, target_sizes=[(proc_h, proc_w)]
 
81
 
82
  seg_map = seg_result.cpu().numpy()
83
  floor_mask = np.zeros((proc_h, proc_w), dtype=np.uint8)
 
84
  unique_classes = np.unique(seg_map)
85
  print(f"Detected classes: {[(int(c), id2label.get(c, '?')) for c in unique_classes]}")
 
86
  for class_id in FLOOR_IDS:
87
  floor_mask[seg_map == class_id] = 255
88
 
89
  mask_img = Image.fromarray(floor_mask).resize((orig_w, orig_h), Image.NEAREST)
90
 
 
91
  depth_inputs = depth_processor(images=image_resized, return_tensors="pt")
92
  depth_inputs = {k: v.to(device, dtype=torch.float16) if v.is_floating_point() else v.to(device) for k, v in depth_inputs.items()}
 
93
  depth_outputs = depth_model(**depth_inputs)
94
  depth_map = depth_outputs.predicted_depth.squeeze().cpu().numpy()
95
 
 
104
  return mask_img, depth_img
105
 
106
 
107
def create_tiled_control_image(tile_texture, width, height):
    """Build a ``width`` x ``height`` RGB image by repeating ``tile_texture``.

    Copies of the tile are pasted on a regular grid whose origin is the
    top-left corner and whose pitch is the tile's own size, row by row.

    Args:
        tile_texture: PIL image used as the repeating pattern.
        width: Width of the output image in pixels.
        height: Height of the output image in pixels.

    Returns:
        A new RGB PIL image of size ``(width, height)``.
    """
    tile_w, tile_h = tile_texture.size
    canvas = Image.new("RGB", (width, height))

    # Top-left corner of every tile, enumerated row-major (rows outer,
    # columns inner) so the paste order matches a nested y/x loop.
    offsets = [
        (left, top)
        for top in range(0, height, tile_h)
        for left in range(0, width, tile_w)
    ]
    for offset in offsets:
        canvas.paste(tile_texture, offset)

    return canvas
115
 
 
 
 
 
 
116
 
117
@spaces.GPU(duration=120)
@torch.inference_mode()
def render_ai(room_image, tile_texture):
    """Repaint the floor of a room photo with a tile texture using SDXL inpainting.

    Steps:
      1. Run the segmentation function ``predict`` to obtain the floor mask.
      2. Resize the room photo and the mask to the square working resolution.
      3. Build a control image by repeating the tile texture across the canvas.
      4. Run ``inpaint_pipe`` (SDXL inpainting + ControlNet Tile) on the masked
         region and resize the result back to the room photo's original size.

    Args:
        room_image: PIL image of the room.
        tile_texture: PIL image of a single tile.

    Returns:
        A PIL image with the same dimensions as ``room_image``. When the
        segmentation finds no floor pixels, the (RGB-converted) room photo is
        returned as-is, because there is nothing to inpaint.

    Raises:
        gr.Error: If either input image is missing.
    """
    if room_image is None or tile_texture is None:
        raise gr.Error("Room image and tile texture are required")

    # Working parameters, named once instead of being scattered as literals.
    size = 1024  # square working resolution fed to SDXL
    num_inference_steps = 25
    guidance_scale = 7.0
    controlnet_conditioning_scale = 0.9
    strength = 0.95
    seed = 42  # fixed seed: the same inputs reproduce the same render

    # Normalise both inputs to 3-channel RGB so that RGBA / palette / greyscale
    # images (e.g. from direct API callers) do not reach the pipeline with an
    # unexpected channel count.
    room_image = room_image.convert("RGB")
    tile_texture = tile_texture.convert("RGB")
    orig_size = room_image.size

    # Step 1: Get floor mask.
    # Call the undecorated segmentation function when the decorator exposes it
    # via ``__wrapped__``; fall back to ``predict`` itself instead of raising
    # AttributeError when it does not.
    # NOTE(review): confirm nested GPU-decorated calls are acceptable on the
    # target runtime if the fallback path is ever taken.
    segment = getattr(predict, "__wrapped__", predict)
    mask_img, _ = segment(room_image)

    # An all-zero mask means no floor was detected; skip the expensive
    # diffusion pass since no region would be repainted.
    if mask_img.getbbox() is None:
        return room_image

    # Resize everything to size x size for SDXL.
    # NOTE(review): this does not preserve the aspect ratio, so non-square
    # photos are squashed here and stretched back at the end.
    room_resized = room_image.resize((size, size), Image.LANCZOS)
    mask_resized = mask_img.resize((size, size), Image.NEAREST)

    # Step 2: Create tiled control image from tile texture
    tile_size = max(64, size // 8)
    tile_resized = tile_texture.resize((tile_size, tile_size), Image.LANCZOS)
    control_image = create_tiled_control_image(tile_resized, size, size)

    # Step 3: Run SDXL inpainting with ControlNet Tile
    result = inpaint_pipe(
        prompt="ceramic tile floor, tiled floor with repeating pattern, interior design photo, photorealistic",
        negative_prompt="blurry, distorted, low quality, watermark, text",
        image=room_resized,
        mask_image=mask_resized,
        control_image=control_image,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        strength=strength,
        generator=torch.Generator(device="cuda").manual_seed(seed),
    ).images[0]

    # Resize back to original dimensions
    return result.resize(orig_size, Image.LANCZOS)
154
+
155
+
156
+ with gr.Blocks() as demo:
157
+ gr.Markdown("# Tile Visualizer API")
158
+
159
+ with gr.Tab("Segmentation"):
160
+ with gr.Row():
161
+ seg_input = gr.Image(type="pil", label="Room photo")
162
+ with gr.Row():
163
+ mask_output = gr.Image(type="pil", label="Floor mask")
164
+ depth_output = gr.Image(type="pil", label="Depth map")
165
+ seg_btn = gr.Button("Segment")
166
+ seg_btn.click(fn=predict, inputs=seg_input, outputs=[mask_output, depth_output])
167
+
168
+ with gr.Tab("AI Render"):
169
+ with gr.Row():
170
+ render_room = gr.Image(type="pil", label="Room photo")
171
+ render_tile = gr.Image(type="pil", label="Tile texture")
172
+ render_output = gr.Image(type="pil", label="Result")
173
+ render_btn = gr.Button("Render")
174
+ render_btn.click(fn=render_ai, inputs=[render_room, render_tile], outputs=render_output)
175
 
176
  app = demo.app
177
 
requirements.txt CHANGED
@@ -1,8 +1,10 @@
1
  torch
2
  torchvision
3
  transformers
 
4
  Pillow
5
  numpy
6
  gradio
7
  accelerate
8
  scipy
 
 
1
  torch
2
  torchvision
3
  transformers
4
+ diffusers
5
  Pillow
6
  numpy
7
  gradio
8
  accelerate
9
  scipy
10
+ safetensors