""" CoT Spatial Reasoning Demo Based on: "Chain-of-Thought Degrades Visual Spatial Reasoning" (arXiv:2604.16060) This demo explores how Chain-of-Thought prompting affects spatial reasoning capabilities in multimodal models. """ import gradio as gr from PIL import Image, ImageDraw import random def create_spatial_grid_puzzle(): """Create a spatial reasoning puzzle with grid layout""" img = Image.new('RGB', (400, 400), color='white') draw = ImageDraw.Draw(img) # Draw 3x3 grid colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A', '#98D8C8', '#F7DC6F', '#DDA0DD', '#F0E68C', '#FFB6C1'] shapes = [] for i in range(3): for j in range(3): x, y = 50 + j * 100, 50 + i * 100 color = colors[i * 3 + j] # Draw different shapes if (i + j) % 3 == 0: draw.ellipse([x, y, x+60, y+60], fill=color, outline='black', width=2) shape = "circle" elif (i + j) % 3 == 1: draw.rectangle([x, y, x+60, y+60], fill=color, outline='black', width=2) shape = "square" else: draw.polygon([(x+30, y), (x+60, y+60), (x, y+60)], fill=color, outline='black', width=2) shape = "triangle" shapes.append({ "row": i + 1, "col": j + 1, "shape": shape, "color": color }) return img, shapes def direct_answer(puzzle_type): """Simulate direct answering (no CoT)""" img, shapes = create_spatial_grid_puzzle() if puzzle_type == "Center Shape": target = shapes[4] # Center question = "What shape is in the center (row 2, column 2)?" answer = target["shape"] elif puzzle_type == "Corner Colors": corners = [shapes[0], shapes[2], shapes[6], shapes[8]] question = "How many corners contain circles?" answer = str(sum(1 for s in corners if s["shape"] == "circle")) else: # Pattern Recognition question = "What shape appears most frequently?" counts = {} for s in shapes: counts[s["shape"]] = counts.get(s["shape"], 0) + 1 answer = max(counts, key=counts.get) response = f"**Direct Answer:** {answer}" return img, question, response def cot_answer(puzzle_type): """Simulate Chain-of-Thought reasoning""" img, shapes = create_spatial_grid_puzzle() if puzzle_type == "Center Shape": target = shapes[4] question = "What shape is in the center (row 2, column 2)?" cot = f"""**CoT Reasoning:** 1. The grid is 3x3, so center is at position (2,2) 2. Let me trace the grid: - Row 1: {shapes[0]['shape']}, {shapes[1]['shape']}, {shapes[2]['shape']} - Row 2: {shapes[3]['shape']}, [CENTER], {shapes[5]['shape']} - Row 3: {shapes[6]['shape']}, {shapes[7]['shape']}, {shapes[8]['shape']} 3. The center shape is a {target['shape']} **Answer:** {target['shape']}""" elif puzzle_type == "Corner Colors": corners = [shapes[0], shapes[2], shapes[6], shapes[8]] question = "How many corners contain circles?" corner_shapes = [s['shape'] for s in corners] circles = corner_shapes.count("circle") cot = f"""**CoT Reasoning:** 1. Corners are positions: (1,1), (1,3), (3,1), (3,3) 2. Corner shapes: {', '.join(corner_shapes)} 3. Count circles: {circles} **Answer:** {circles}""" else: # Pattern Recognition counts = {} for s in shapes: counts[s["shape"]] = counts.get(s["shape"], 0) + 1 most_common = max(counts, key=counts.get) cot = f"""**CoT Reasoning:** 1. Count all shapes in grid: - Circles: {counts.get('circle', 0)} - Squares: {counts.get('square', 0)} - Triangles: {counts.get('triangle', 0)} 2. Most common: {most_common} **Answer:** {most_common}""" return img, question, cot def compare_both(puzzle_type): """Compare direct vs CoT side by side""" img1, q1, direct = direct_answer(puzzle_type) img2, q2, cot = cot_answer(puzzle_type) comparison = f"""## {puzzle_type} **Question:** {q1} --- {direct} --- {cot} --- **Key Insight:** CoT adds reasoning steps but may introduce errors through over-analysis of spatial relationships.""" return img1, comparison # Gradio Interface with gr.Blocks(title="CoT Spatial Reasoning") as demo: gr.Markdown(""" # 📉 CoT Spatial Reasoning Exploring how Chain-of-Thought affects spatial reasoning capabilities. Based on: *"Chain-of-Thought Degrades Visual Spatial Reasoning Capabilities of Multimodal LLMs"* (arXiv:2604.16060) """) with gr.Tab("Live Comparison"): with gr.Row(): puzzle_select = gr.Dropdown( choices=["Center Shape", "Corner Colors", "Pattern Recognition"], value="Center Shape", label="Select Puzzle Type" ) with gr.Row(): with gr.Column(): puzzle_image = gr.Image(type="pil", label="Spatial Puzzle") with gr.Column(): comparison_output = gr.Markdown(label="Comparison") run_btn = gr.Button("Run Comparison", variant="primary") run_btn.click( fn=compare_both, inputs=[puzzle_select], outputs=[puzzle_image, comparison_output] ) with gr.Tab("Paper Findings"): gr.Markdown(""" ## Key Findings The paper demonstrates that Chain-of-Thought prompting can **degrade** spatial reasoning performance: 1. **Shortcut Learning**: Models learn to follow textual patterns rather than analyze visual space 2. **Over-verbalization**: Converting visual tasks to language introduces errors 3. **Task-dependent**: Effect varies by spatial reasoning type **Recommendation**: Use direct visual processing for pure spatial tasks. """) if __name__ == "__main__": demo.launch()