Spaces:
Running
Running
| """ | |
| CoT Spatial Reasoning Demo | |
| Based on: "Chain-of-Thought Degrades Visual Spatial Reasoning" (arXiv:2604.16060) | |
| This demo explores how Chain-of-Thought prompting affects spatial reasoning | |
| capabilities in multimodal models. | |
| """ | |
| import gradio as gr | |
| from PIL import Image, ImageDraw | |
| import random | |
def create_spatial_grid_puzzle():
    """Render the fixed 3x3 shape puzzle and describe what was drawn.

    Cells are laid out 100 px apart starting at (50, 50) on a 400x400 white
    RGB canvas; each cell holds one 60x60 shape with a black 2 px outline.
    The shape in cell (row, col) is chosen by ``(row + col) % 3``:
    0 -> circle, 1 -> square, 2 -> triangle, so the layout never changes.

    Returns:
        ``(img, shapes)`` where ``img`` is the ``PIL.Image`` and ``shapes`` is
        a list of nine dicts in row-major order (index ``(row-1)*3 + (col-1)``)
        with keys ``"row"``, ``"col"`` (1-based), ``"shape"`` and ``"color"``.
    """
    palette = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A', '#98D8C8', '#F7DC6F', '#DDA0DD', '#F0E68C', '#FFB6C1']
    shape_names = ("circle", "square", "triangle")

    canvas = Image.new('RGB', (400, 400), color='white')
    pen = ImageDraw.Draw(canvas)
    cells = []

    # enumerate() walks the palette in row-major order, so divmod recovers
    # (row, col) and palette[index] is the colour for that cell.
    for index, fill in enumerate(palette):
        row, col = divmod(index, 3)
        left, top = 50 + col * 100, 50 + row * 100
        right, bottom = left + 60, top + 60
        kind = shape_names[(row + col) % 3]

        if kind == "circle":
            pen.ellipse([left, top, right, bottom], fill=fill, outline='black', width=2)
        elif kind == "square":
            pen.rectangle([left, top, right, bottom], fill=fill, outline='black', width=2)
        else:
            # Isosceles triangle: apex at top-centre, base along the bottom edge.
            pen.polygon([(left + 30, top), (right, bottom), (left, bottom)], fill=fill, outline='black', width=2)

        cells.append({
            "row": row + 1,
            "col": col + 1,
            "shape": kind,
            "color": fill,
        })

    return canvas, cells
# Question text per puzzle type. Both solvers read from this one table so the
# direct and Chain-of-Thought paths always pose the same question.
# NOTE: "Corner Colors" actually asks about circle *shapes*; the key is kept
# as-is because it is also the label offered by the UI dropdown.
_PUZZLE_QUESTIONS = {
    "Center Shape": "What shape is in the center (row 2, column 2)?",
    "Corner Colors": "How many corners contain circles?",
    "Pattern Recognition": "What shape appears most frequently?",
}

# Positions of the centre and the four corners in the row-major list of nine
# cells returned by create_spatial_grid_puzzle().
_CENTER_INDEX = 4
_CORNER_INDICES = (0, 2, 6, 8)


def _normalize_puzzle_type(puzzle_type):
    """Map *puzzle_type* to a key of ``_PUZZLE_QUESTIONS``.

    Unknown values fall back to "Pattern Recognition", the catch-all branch
    of both solvers.
    """
    return puzzle_type if puzzle_type in _PUZZLE_QUESTIONS else "Pattern Recognition"


def _shape_counts(shapes):
    """Count how many cells hold each shape name, in first-seen order."""
    counts = {}
    for cell in shapes:
        counts[cell["shape"]] = counts.get(cell["shape"], 0) + 1
    return counts


def _most_common_label(counts):
    """Describe the most frequent shape(s) in *counts*.

    Returns the shape name when there is a unique maximum, otherwise every
    tied shape, e.g. ``"tie: circle, square, triangle"``. A bare
    ``max(counts, key=counts.get)`` would silently report the first-seen
    shape, which is wrong for the default grid (three of each shape).
    """
    top = max(counts.values())
    leaders = [name for name, n in counts.items() if n == top]
    if len(leaders) == 1:
        return leaders[0]
    return "tie: " + ", ".join(leaders)


def direct_answer(puzzle_type):
    """Simulate answering a puzzle directly, without Chain-of-Thought.

    Args:
        puzzle_type: "Center Shape", "Corner Colors" or "Pattern Recognition";
            any other value is treated as "Pattern Recognition".

    Returns:
        ``(img, question, response)``: the puzzle image, the question text and
        a Markdown string containing only the final answer.
    """
    img, shapes = create_spatial_grid_puzzle()
    kind = _normalize_puzzle_type(puzzle_type)

    if kind == "Center Shape":
        answer = shapes[_CENTER_INDEX]["shape"]
    elif kind == "Corner Colors":
        answer = str(sum(1 for k in _CORNER_INDICES if shapes[k]["shape"] == "circle"))
    else:  # Pattern Recognition
        answer = _most_common_label(_shape_counts(shapes))

    response = f"**Direct Answer:** {answer}"
    return img, _PUZZLE_QUESTIONS[kind], response


def cot_answer(puzzle_type):
    """Simulate answering a puzzle with step-by-step Chain-of-Thought reasoning.

    Args:
        puzzle_type: "Center Shape", "Corner Colors" or "Pattern Recognition";
            any other value is treated as "Pattern Recognition".

    Returns:
        ``(img, question, cot)``: the puzzle image, the question text (now set
        for every puzzle type, including "Pattern Recognition") and a Markdown
        string with numbered reasoning steps followed by the answer.
    """
    img, shapes = create_spatial_grid_puzzle()
    kind = _normalize_puzzle_type(puzzle_type)

    if kind == "Center Shape":
        names = [cell["shape"] for cell in shapes]
        answer = shapes[_CENTER_INDEX]["shape"]
        # Sub-bullets are indented 3 spaces so Markdown nests them under step 2.
        steps = [
            "1. The grid is 3x3, so center is at position (2,2)",
            "2. Let me trace the grid:",
            f"   - Row 1: {names[0]}, {names[1]}, {names[2]}",
            f"   - Row 2: {names[3]}, [CENTER], {names[5]}",
            f"   - Row 3: {names[6]}, {names[7]}, {names[8]}",
            f"3. The center shape is a {answer}",
        ]
    elif kind == "Corner Colors":
        corner_shapes = [shapes[k]["shape"] for k in _CORNER_INDICES]
        circles = corner_shapes.count("circle")
        answer = str(circles)
        steps = [
            "1. Corners are positions: (1,1), (1,3), (3,1), (3,3)",
            f"2. Corner shapes: {', '.join(corner_shapes)}",
            f"3. Count circles: {circles}",
        ]
    else:  # Pattern Recognition
        counts = _shape_counts(shapes)
        answer = _most_common_label(counts)
        steps = [
            "1. Count all shapes in grid:",
            f"   - Circles: {counts.get('circle', 0)}",
            f"   - Squares: {counts.get('square', 0)}",
            f"   - Triangles: {counts.get('triangle', 0)}",
            f"2. Most common: {answer}",
        ]

    # Blank lines keep header, list and answer as separate Markdown blocks;
    # otherwise "**Answer:**" is folded into the last list item.
    cot = "\n".join(["**CoT Reasoning:**", "", *steps, "", f"**Answer:** {answer}"])
    return img, _PUZZLE_QUESTIONS[kind], cot
def compare_both(puzzle_type):
    """Run the direct and Chain-of-Thought solvers for one puzzle type.

    Args:
        puzzle_type: Puzzle name forwarded unchanged to ``direct_answer`` and
            ``cot_answer``.

    Returns:
        ``(img, comparison)``: the image produced by ``direct_answer`` and a
        Markdown report with the question, both answers and a closing note.
    """
    img, question, direct = direct_answer(puzzle_type)
    # The CoT image and question are not shown: create_spatial_grid_puzzle()
    # draws a fixed layout, so only the reasoning text differs.
    _, _, cot = cot_answer(puzzle_type)

    # Joining with blank lines matters: a "---" line directly under text is a
    # Markdown setext-heading underline, not a horizontal rule, and would turn
    # the question / direct answer into H2 headings.
    sections = [
        f"## {puzzle_type}",
        f"**Question:** {question}",
        "---",
        direct,
        "---",
        cot,
        "---",
        "**Key Insight:** CoT adds reasoning steps but may introduce errors through over-analysis of spatial relationships.",
    ]
    return img, "\n\n".join(sections)
# Markdown for the page header. Blank lines separate the paragraphs so the
# title, description and citation render as distinct blocks.
_INTRO_MD = """
# 📉 CoT Spatial Reasoning

Exploring how Chain-of-Thought affects spatial reasoning capabilities.

Based on: *"Chain-of-Thought Degrades Visual Spatial Reasoning Capabilities of Multimodal LLMs"* (arXiv:2604.16060)
"""

# Markdown for the "Paper Findings" tab. The blank line before
# "**Recommendation**" stops it being absorbed into list item 3 as a lazy
# continuation line.
_FINDINGS_MD = """
## Key Findings

The paper demonstrates that Chain-of-Thought prompting can **degrade** spatial reasoning performance:

1. **Shortcut Learning**: Models learn to follow textual patterns rather than analyze visual space
2. **Over-verbalization**: Converting visual tasks to language introduces errors
3. **Task-dependent**: Effect varies by spatial reasoning type

**Recommendation**: Use direct visual processing for pure spatial tasks.
"""

# Gradio Interface
with gr.Blocks(title="CoT Spatial Reasoning") as demo:
    gr.Markdown(_INTRO_MD)

    with gr.Tab("Live Comparison"):
        with gr.Row():
            puzzle_select = gr.Dropdown(
                choices=["Center Shape", "Corner Colors", "Pattern Recognition"],
                value="Center Shape",
                label="Select Puzzle Type",
            )
        with gr.Row():
            with gr.Column():
                puzzle_image = gr.Image(type="pil", label="Spatial Puzzle")
            with gr.Column():
                comparison_output = gr.Markdown(label="Comparison")
        run_btn = gr.Button("Run Comparison", variant="primary")
        # compare_both returns (image, markdown), matching the two outputs.
        run_btn.click(
            fn=compare_both,
            inputs=[puzzle_select],
            outputs=[puzzle_image, comparison_output],
        )

    with gr.Tab("Paper Findings"):
        gr.Markdown(_FINDINGS_MD)


if __name__ == "__main__":
    demo.launch()