# NOTE: the three lines below are web-page residue (a Hugging Face commit
# header) accidentally pasted above the module docstring; commented out so
# the file is valid Python.
# O96a's picture
# Add app.py - fix NO_APP_FILE error
# b44db72 verified
"""
CoT Spatial Reasoning Demo
Based on: "Chain-of-Thought Degrades Visual Spatial Reasoning" (arXiv:2604.16060)
This demo explores how Chain-of-Thought prompting affects spatial reasoning
capabilities in multimodal models.
"""
import gradio as gr
from PIL import Image, ImageDraw
import random
def create_spatial_grid_puzzle():
    """Render a deterministic 3x3 grid of colored shapes.

    Returns:
        tuple: (PIL.Image.Image, list[dict]) — the 400x400 puzzle image and,
        for each cell in row-major order, a dict with 1-based "row"/"col",
        the "shape" name ("circle"/"square"/"triangle"), and its hex "color".
    """
    canvas = Image.new('RGB', (400, 400), color='white')
    pen = ImageDraw.Draw(canvas)
    # Fixed palette: one color per cell, row-major.
    palette = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A', '#98D8C8', '#F7DC6F', '#DDA0DD', '#F0E68C', '#FFB6C1']
    cells = []
    for idx in range(9):
        row, col = divmod(idx, 3)
        x0, y0 = 50 + col * 100, 50 + row * 100
        fill = palette[idx]
        # The shape cycles with (row + col) % 3, so the layout is fixed.
        variant = (row + col) % 3
        if variant == 0:
            pen.ellipse([x0, y0, x0 + 60, y0 + 60], fill=fill, outline='black', width=2)
            label = "circle"
        elif variant == 1:
            pen.rectangle([x0, y0, x0 + 60, y0 + 60], fill=fill, outline='black', width=2)
            label = "square"
        else:
            pen.polygon([(x0 + 30, y0), (x0 + 60, y0 + 60), (x0, y0 + 60)], fill=fill, outline='black', width=2)
            label = "triangle"
        cells.append({
            "row": row + 1,
            "col": col + 1,
            "shape": label,
            "color": fill
        })
    return canvas, cells
def direct_answer(puzzle_type):
    """Answer a puzzle immediately, with no intermediate reasoning steps.

    Args:
        puzzle_type: "Center Shape", "Corner Colors", or any other value
            (treated as pattern recognition).

    Returns:
        tuple: (puzzle image, question string, markdown answer string).
    """
    img, cells = create_spatial_grid_puzzle()
    if puzzle_type == "Center Shape":
        # Row-major index 4 is the middle of the 3x3 grid.
        question = "What shape is in the center (row 2, column 2)?"
        answer = cells[4]["shape"]
    elif puzzle_type == "Corner Colors":
        question = "How many corners contain circles?"
        corner_names = [cells[k]["shape"] for k in (0, 2, 6, 8)]
        answer = str(corner_names.count("circle"))
    else:  # Pattern Recognition
        question = "What shape appears most frequently?"
        names = [cell["shape"] for cell in cells]
        # First-occurrence order is preserved, so ties resolve the same
        # way as a dict built by incremental counting.
        tally = {name: names.count(name) for name in names}
        answer = max(tally, key=tally.get)
    return img, question, f"**Direct Answer:** {answer}"
def cot_answer(puzzle_type):
    """Simulate Chain-of-Thought reasoning over the spatial puzzle.

    Args:
        puzzle_type: "Center Shape", "Corner Colors", or any other value
            (treated as pattern recognition).

    Returns:
        tuple: (PIL image of the puzzle, question string, markdown string
        containing the numbered reasoning trace and the final answer).
    """
    img, shapes = create_spatial_grid_puzzle()
    if puzzle_type == "Center Shape":
        target = shapes[4]  # row-major index 4 == (row 2, col 2)
        question = "What shape is in the center (row 2, column 2)?"
        cot = f"""**CoT Reasoning:**
1. The grid is 3x3, so center is at position (2,2)
2. Let me trace the grid:
- Row 1: {shapes[0]['shape']}, {shapes[1]['shape']}, {shapes[2]['shape']}
- Row 2: {shapes[3]['shape']}, [CENTER], {shapes[5]['shape']}
- Row 3: {shapes[6]['shape']}, {shapes[7]['shape']}, {shapes[8]['shape']}
3. The center shape is a {target['shape']}
**Answer:** {target['shape']}"""
    elif puzzle_type == "Corner Colors":
        corners = [shapes[0], shapes[2], shapes[6], shapes[8]]
        question = "How many corners contain circles?"
        corner_shapes = [s['shape'] for s in corners]
        circles = corner_shapes.count("circle")
        cot = f"""**CoT Reasoning:**
1. Corners are positions: (1,1), (1,3), (3,1), (3,3)
2. Corner shapes: {', '.join(corner_shapes)}
3. Count circles: {circles}
**Answer:** {circles}"""
    else:  # Pattern Recognition
        # BUG FIX: the original never assigned `question` on this branch,
        # so the final `return` raised UnboundLocalError for the
        # "Pattern Recognition" puzzle. Mirror direct_answer()'s question.
        question = "What shape appears most frequently?"
        counts = {}
        for s in shapes:
            counts[s["shape"]] = counts.get(s["shape"], 0) + 1
        # max() returns the first key with the max count (insertion order),
        # matching direct_answer()'s tie-breaking.
        most_common = max(counts, key=counts.get)
        cot = f"""**CoT Reasoning:**
1. Count all shapes in grid:
- Circles: {counts.get('circle', 0)}
- Squares: {counts.get('square', 0)}
- Triangles: {counts.get('triangle', 0)}
2. Most common: {most_common}
**Answer:** {most_common}"""
    return img, question, cot
def compare_both(puzzle_type):
    """Run direct and CoT answering on the same puzzle and merge the output.

    Args:
        puzzle_type: Puzzle name forwarded to both answering strategies.

    Returns:
        tuple: (puzzle image, markdown comparison of both responses).
    """
    # Both calls regenerate the same deterministic puzzle; the second
    # image and question are redundant, so they are discarded.
    img, question, direct_md = direct_answer(puzzle_type)
    _, _, cot_md = cot_answer(puzzle_type)
    comparison = f"""## {puzzle_type}
**Question:** {question}
---
{direct_md}
---
{cot_md}
---
**Key Insight:** CoT adds reasoning steps but may introduce errors through over-analysis of spatial relationships."""
    return img, comparison
# ---- Gradio interface ------------------------------------------------------
# `demo` stays module-level so hosting platforms can discover it by name.
with gr.Blocks(title="CoT Spatial Reasoning") as demo:
    gr.Markdown("""
# 📉 CoT Spatial Reasoning
Exploring how Chain-of-Thought affects spatial reasoning capabilities.
Based on: *"Chain-of-Thought Degrades Visual Spatial Reasoning Capabilities of Multimodal LLMs"* (arXiv:2604.16060)
""")
    with gr.Tab("Live Comparison"):
        with gr.Row():
            puzzle_choice = gr.Dropdown(
                label="Select Puzzle Type",
                choices=["Center Shape", "Corner Colors", "Pattern Recognition"],
                value="Center Shape"
            )
        with gr.Row():
            with gr.Column():
                puzzle_view = gr.Image(type="pil", label="Spatial Puzzle")
            with gr.Column():
                analysis_md = gr.Markdown(label="Comparison")
        compare_btn = gr.Button("Run Comparison", variant="primary")
        # One click renders the puzzle and the side-by-side analysis.
        compare_btn.click(
            fn=compare_both,
            inputs=[puzzle_choice],
            outputs=[puzzle_view, analysis_md]
        )
    with gr.Tab("Paper Findings"):
        gr.Markdown("""
## Key Findings
The paper demonstrates that Chain-of-Thought prompting can **degrade** spatial reasoning performance:
1. **Shortcut Learning**: Models learn to follow textual patterns rather than analyze visual space
2. **Over-verbalization**: Converting visual tasks to language introduces errors
3. **Task-dependent**: Effect varies by spatial reasoning type
**Recommendation**: Use direct visual processing for pure spatial tasks.
""")

if __name__ == "__main__":
    demo.launch()