| """ |
| Micro-World Visualization: Understanding Residual Connections |
| |
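This script creates intuitive visualizations (saved under plots_micro/) explaining:
1. Signal flow through layers (forward pass)
2. Gradient flow through layers (backward pass)
3. The "gradient highway" effect of residual connections
4. The chain-rule multiplication effect behind vanishing gradients
5. Layer-by-layer transformation of a single input vector
6. What each model learned (loss curves, loss reduction, final gradients)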
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import matplotlib.patches as mpatches |
| from matplotlib.patches import FancyArrowPatch, FancyBboxPatch |
| import json |
| import os |
|
|
| |
# Fix both global RNGs so randomly initialised tensors/arrays are reproducible.
torch.manual_seed(42)
np.random.seed(42)

# Pre-computed experiment results shared by every plotting function below.
# NOTE(review): assumes results_fair.json exists in the working directory and
# holds 'plain_mlp' / 'res_mlp' entries produced by a separate training run.
# JSON is UTF-8 by specification, so decode it explicitly instead of relying
# on the platform's locale-dependent default encoding.
with open('results_fair.json', 'r', encoding='utf-8') as f:
    results = json.load(f)

# Output directory for every figure this script writes.
os.makedirs('plots_micro', exist_ok=True)
|
|
| |
| |
| |
def _signal_panel(ax, signal, *, title, title_color, cmap, edgecolor,
                  note, note_y, ylim_top):
    """Draw one per-layer activation-std bar chart on *ax* (bar 0 = network input)."""
    layers = range(len(signal))
    ax.set_title(title, fontsize=14, fontweight='bold', color=title_color)

    # Colour ramp darkens with depth so the eye can follow the signal.
    colors = cmap(np.linspace(0.3, 0.9, len(signal)))
    ax.bar(layers, signal, color=colors, edgecolor=edgecolor, linewidth=1.5)

    ax.set_xlabel(f'Layer (0=Input, 1-{len(signal) - 1}=Hidden)', fontsize=12)
    ax.set_ylabel('Signal Strength (Activation Std)', fontsize=12)
    ax.set_ylim(0, ylim_top)

    # Callout three quarters of the way along the x axis
    # (layer 15 for the 20-layer networks this script was written for).
    ax.annotate(note, xy=(0.75 * (len(signal) - 1), note_y), fontsize=12,
                color=edgecolor, ha='center', fontweight='bold')
    ax.axhline(y=0.1, color='gray', linestyle='--', alpha=0.5, label='Healthy threshold')
    # Without a legend the threshold line's label would never be displayed.
    ax.legend(loc='upper right')


def plot_signal_flow():
    """Save plots_micro/1_signal_flow.png: activation std per layer, PlainMLP vs ResMLP.

    Reads ``results['plain_mlp']['activation_stds']`` and
    ``results['res_mlp']['activation_stds']`` from the module-level ``results``
    dict. An assumed input std is prepended so bar 0 represents the input.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 8))

    plain_stds = results['plain_mlp']['activation_stds']
    res_stds = results['res_mlp']['activation_stds']

    # 0.577 ~= 1/sqrt(3), the std of U(-1, 1).
    # NOTE(review): assumes the training inputs were uniform on [-1, 1] — confirm.
    input_std = 0.577
    plain_signal = [input_std] + list(plain_stds)
    res_signal = [input_std] + list(res_stds)

    # Both panels share one y range so they stay directly comparable. The
    # original 0.7 cap is kept when the data fits and is grown otherwise, so
    # tall bars are not silently clipped.
    peak = max(plain_signal + res_signal)
    ylim_top = 0.7 if peak <= 0.7 else 1.1 * peak

    _signal_panel(axes[0], plain_signal,
                  title='PlainMLP: Signal DIES\n(No Residual Connection)',
                  title_color='#c0392b', cmap=plt.cm.Reds, edgecolor='darkred',
                  note='Signal\ncollapses!', note_y=0.02, ylim_top=ylim_top)
    _signal_panel(axes[1], res_signal,
                  title='ResMLP: Signal PRESERVED\n(With Residual Connection)',
                  title_color='#2980b9', cmap=plt.cm.Blues, edgecolor='darkblue',
                  note='Signal stays\nhealthy!', note_y=0.25, ylim_top=ylim_top)

    plt.tight_layout()
    plt.savefig('plots_micro/1_signal_flow.png', dpi=150, bbox_inches='tight')
    plt.close(fig)
    print("[Plot 1] Signal flow visualization saved")
|
|
|
|
| |
| |
| |
def _gradient_panel(ax, grads, *, title, title_color, cmap, edgecolor,
                    first_tag, first_text_y, ylim):
    """Draw one per-layer gradient-magnitude bar chart on a log y axis.

    *grads* holds one magnitude per layer, first layer first. A zero magnitude
    cannot be placed on a log axis, so annotation arrows for zero values are
    pinned to the bottom of *ylim* instead.
    """
    n_layers = len(grads)
    layers = range(1, n_layers + 1)
    floor = ylim[0]

    ax.set_title(title, fontsize=14, fontweight='bold', color=title_color)
    # Darkest bar at layer 1, fading towards the last layer.
    ax.bar(layers, grads, color=cmap(np.linspace(0.9, 0.3, n_layers)),
           edgecolor=edgecolor, linewidth=1)
    ax.set_yscale('log')
    ax.set_xlabel(f'Layer (1=First, {n_layers}=Last)', fontsize=12)
    ax.set_ylabel('Gradient Magnitude (log scale)', fontsize=12)
    ax.set_ylim(*ylim)

    ax.annotate(f'Layer {n_layers}:\n{grads[-1]:.1e}', xy=(n_layers, max(grads[-1], floor)),
                xytext=(n_layers - 3, 1e-4), fontsize=10, color=edgecolor,
                arrowprops=dict(arrowstyle='->', color=edgecolor))
    ax.annotate(f'Layer 1:\n{grads[0]:.1e}\n{first_tag}', xy=(1, max(grads[0], floor)),
                xytext=(4, first_text_y), fontsize=10, color=edgecolor, fontweight='bold',
                arrowprops=dict(arrowstyle='->', color=edgecolor))


def plot_gradient_flow():
    """Save plots_micro/2_gradient_flow.png: per-layer gradient magnitude, both models.

    Reads ``results[...]['gradient_norms']`` (one value per layer) from the
    module-level ``results`` dict. The layer count is taken from the data
    rather than assumed to be 20.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 8))

    plain_grads = results['plain_mlp']['gradient_norms']
    res_grads = results['res_mlp']['gradient_norms']

    # Shared log-scale range: keep the original (1e-20, 1e-1) when every
    # positive value fits, otherwise widen so no bar is clipped.
    low, high = 1e-20, 1e-1
    positive = [g for g in list(plain_grads) + list(res_grads) if g > 0]
    if positive:
        if min(positive) < low:
            low = min(positive) / 10
        if max(positive) > high:
            high = max(positive) * 10

    _gradient_panel(axes[0], plain_grads,
                    title='PlainMLP: Gradients VANISH\n(Backward Pass)',
                    title_color='#c0392b', cmap=plt.cm.Reds, edgecolor='darkred',
                    first_tag='(DEAD!)', first_text_y=1e-15, ylim=(low, high))
    _gradient_panel(axes[1], res_grads,
                    title='ResMLP: Gradients FLOW\n(Backward Pass)',
                    title_color='#2980b9', cmap=plt.cm.Blues, edgecolor='darkblue',
                    first_tag='(Healthy!)', first_text_y=1e-4, ylim=(low, high))

    plt.tight_layout()
    plt.savefig('plots_micro/2_gradient_flow.png', dpi=150, bbox_inches='tight')
    plt.close(fig)
    print("[Plot 2] Gradient flow visualization saved")
|
|
|
|
| |
| |
| |
def _draw_layer_boxes(ax, facecolor, edgecolor, count=6, spacing=1.8):
    """Draw *count* rounded boxes labelled L1..L{count} spanning y=1..2.

    Returns the left x coordinate of every box, in layer order.
    """
    lefts = [1 + i * spacing for i in range(count)]
    for idx, left in enumerate(lefts, start=1):
        ax.add_patch(FancyBboxPatch((left, 1), 1.2, 1, boxstyle="round,pad=0.05",
                                    facecolor=facecolor, edgecolor=edgecolor, linewidth=2))
        ax.text(left + 0.6, 1.5, f'L{idx}', ha='center', va='center', fontsize=11,
                color='white', fontweight='bold')
    return lefts


def plot_highway_concept():
    """Save plots_micro/3_highway_concept.png: a schematic of backward gradient flow.

    Backpropagation starts at the loss (drawn on the right) and travels towards
    L1 (left), so every gradient arrow points right-to-left.
    Top panel (plain stack): the gradient must cross every layer and thins
    with each hop away from the loss.
    Bottom panel (residual stack): a green identity path carries the gradient
    from the loss straight back to every layer.
    """
    fig, axes = plt.subplots(2, 1, figsize=(14, 10))

    # ---- PlainMLP ----
    ax = axes[0]
    ax.set_xlim(0, 12)
    ax.set_ylim(0, 3)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title('PlainMLP: Gradient Must Pass Through EVERY Layer\n(Like a winding mountain road)',
                 fontsize=14, fontweight='bold', color='#c0392b', pad=20)

    lefts = _draw_layer_boxes(ax, '#e74c3c', 'darkred')
    n_gaps = len(lefts) - 1
    for i, left in enumerate(lefts[:-1]):
        # Hops already taken from the loss: 0 for the gap next to the loss,
        # growing towards L1, so arrows get thinner/fainter away from the loss.
        hops = n_gaps - 1 - i
        ax.annotate('', xy=(left + 1.2, 1.5), xytext=(left + 1.8, 1.5),
                    arrowprops=dict(arrowstyle='->', color='darkred',
                                    lw=3 * (0.5 ** hops),
                                    alpha=max(0.2, 1 - hops * 0.18)))

    ax.text(0.3, 1.5, 'Gradient\n←', fontsize=10, ha='center', va='center', color='darkred')
    ax.text(11.5, 1.5, '← Loss', fontsize=10, ha='center', va='center', color='darkred')
    ax.annotate('Gradient shrinks\nat each layer!', xy=(8, 0.5), fontsize=11,
                color='darkred', style='italic')

    # ---- ResMLP ----
    ax = axes[1]
    ax.set_xlim(0, 12)
    ax.set_ylim(0, 3.5)
    ax.set_aspect('equal')
    ax.axis('off')
    ax.set_title('ResMLP: Gradient Has a Direct HIGHWAY\n(Skip connections = express lane)',
                 fontsize=14, fontweight='bold', color='#2980b9', pad=20)

    # Identity path across the top; its arrowhead sits at L1 because it
    # carries the gradient backward from the loss.
    ax.plot([1, 11], [2.8, 2.8], color='#27ae60', linewidth=6, alpha=0.8)
    ax.annotate('', xy=(1, 2.8), xytext=(1.5, 2.8),
                arrowprops=dict(arrowstyle='->', color='#27ae60', lw=3))
    ax.text(6, 3.2, '✓ GRADIENT HIGHWAY (Identity Path)', ha='center', fontsize=12,
            color='#27ae60', fontweight='bold')

    lefts = _draw_layer_boxes(ax, '#3498db', 'darkblue')
    for i, left in enumerate(lefts):
        if i < len(lefts) - 1:
            ax.annotate('', xy=(left + 1.2, 1.5), xytext=(left + 1.8, 1.5),
                        arrowprops=dict(arrowstyle='->', color='darkblue', lw=2))
        # Vertical tap from every layer up to the highway (its skip connection).
        ax.plot([left + 0.6, left + 0.6], [2, 2.8], color='#27ae60', linewidth=2, alpha=0.5)

    ax.text(0.3, 1.5, 'Gradient\n←', fontsize=10, ha='center', va='center', color='darkblue')
    ax.text(11.5, 1.5, '← Loss', fontsize=10, ha='center', va='center', color='darkblue')
    ax.annotate('Gradient flows on highway\neven if layers block it!', xy=(8, 0.3),
                fontsize=11, color='#27ae60', style='italic')

    plt.tight_layout()
    plt.savefig('plots_micro/3_highway_concept.png', dpi=150, bbox_inches='tight')
    plt.close(fig)
    print("[Plot 3] Highway concept visualization saved")
|
|
|
|
| |
| |
| |
def plot_chain_rule():
    """Save plots_micro/4_chain_rule.png: why products of per-layer gradients vanish.

    Left panel: a toy model of the cumulative gradient scale after k backward
    steps. The plain network multiplies by a constant 0.7 at every step; the
    residual network multiplies by (1 + 0.05 * 0.9**k). These are illustrative
    constants, not values read from ``results``.
    Right panel: the corresponding chain-rule formulas.
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 7))

    num_layers = 20

    # Plain stack: each backward step multiplies by the same factor < 1.
    plain_layer_grad = 0.7
    plain_cumulative = [1.0]
    for _ in range(num_layers):
        plain_cumulative.append(plain_cumulative[-1] * plain_layer_grad)

    # Residual stack: each step multiplies by (1 + eps_k) with a decaying eps_k,
    # so the running product stays at or slightly above 1.
    res_layer_contrib = 0.05
    res_cumulative = [1.0]
    for k in range(num_layers):
        res_cumulative.append(res_cumulative[-1] * (1.0 + res_layer_contrib * (0.9 ** k)))

    steps = range(num_layers + 1)

    # ---- Left: cumulative scale on a log axis ----
    ax = axes[0]
    ax.semilogy(steps, plain_cumulative, 'o-', color='#e74c3c', linewidth=2,
                markersize=8, label='PlainMLP: 0.7 × 0.7 × 0.7 × ...')
    ax.semilogy(steps, res_cumulative, 's-', color='#3498db', linewidth=2,
                markersize=8, label='ResMLP: (1+ε) × (1+ε) × ...')

    ax.set_xlabel('Layers Traversed (backward from loss)', fontsize=12)
    ax.set_ylabel('Cumulative Gradient Scale (log)', fontsize=12)
    ax.set_title('Chain Rule: Why Gradients Vanish\n(Multiplication Effect)', fontsize=14, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(1e-8, 10)

    ax.annotate(f'After {num_layers} layers:\n{plain_cumulative[-1]:.1e}',
                xy=(num_layers, plain_cumulative[-1]), xytext=(num_layers - 5, 1e-6),
                fontsize=10, color='#c0392b',
                arrowprops=dict(arrowstyle='->', color='#c0392b'))
    ax.annotate(f'After {num_layers} layers:\n{res_cumulative[-1]:.2f}',
                xy=(num_layers, res_cumulative[-1]), xytext=(num_layers - 5, 3),
                fontsize=10, color='#2980b9',
                arrowprops=dict(arrowstyle='->', color='#2980b9'))

    # ---- Right: the math ----
    ax = axes[1]
    ax.axis('off')
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 10)

    ax.text(5, 9, 'The Math Behind It', fontsize=16, fontweight='bold',
            ha='center', va='center')

    # With 20 layers mapping x_1 -> x_2 -> ... -> x_21 there are exactly 20
    # factors dx_{i+1}/dx_i (i = 1..20). The outer derivative is therefore
    # taken w.r.t. the final output x_21, not x_20.
    ax.text(5, 7.5, 'PlainMLP Gradient:', fontsize=13, fontweight='bold',
            ha='center', color='#c0392b')
    ax.text(5, 6.5, r'$\frac{\partial L}{\partial x_1} = \frac{\partial L}{\partial x_{21}} \times \prod_{i=1}^{20} \frac{\partial x_{i+1}}{\partial x_i}$',
            fontsize=14, ha='center', color='#c0392b')
    ax.text(5, 5.5, '= (small) × (small) × ... × (small) = TINY!',
            fontsize=11, ha='center', color='#c0392b', style='italic')

    # Residual layers compute x_{i+1} = x_i + f_i(x_i), so each factor becomes
    # (1 + df_i/dx_i).
    ax.text(5, 4, 'ResMLP Gradient:', fontsize=13, fontweight='bold',
            ha='center', color='#2980b9')
    ax.text(5, 3, r'$\frac{\partial L}{\partial x_1} = \frac{\partial L}{\partial x_{21}} \times \prod_{i=1}^{20} (1 + \frac{\partial f_i}{\partial x_i})$',
            fontsize=14, ha='center', color='#2980b9')
    ax.text(5, 2, '= (1+ε) × (1+ε) × ... = PRESERVED!',
            fontsize=11, ha='center', color='#2980b9', style='italic')

    box = FancyBboxPatch((1, 0.3), 8, 1.2, boxstyle="round,pad=0.1",
                         facecolor='#f9e79f', edgecolor='#f39c12', linewidth=2)
    ax.add_patch(box)
    # The leading emoji was dropped: matplotlib's default font (DejaVu Sans)
    # typically has no glyph for it and renders an empty box instead.
    ax.text(5, 0.9, 'Key Insight: The "+x" in residual adds a "1" to each gradient term,\n'
            'preventing the product from shrinking to zero!',
            fontsize=11, ha='center', va='center', fontweight='bold')

    plt.tight_layout()
    plt.savefig('plots_micro/4_chain_rule.png', dpi=150, bbox_inches='tight')
    plt.close(fig)
    print("[Plot 4] Chain rule visualization saved")
|
|
|
|
| |
| |
| |
class _DemoMLP(nn.Module):
    """Untrained stack of ReLU(Linear) layers used only by plot_layer_transformation.

    With ``residual=True`` each layer computes ``x + ReLU(W x + b)``; otherwise
    it computes ``ReLU(W x + b)``. Weights are Kaiming-normal initialised, then
    scaled by ``1/sqrt(num_layers)``. Biases start at zero.
    """

    def __init__(self, dim, num_layers, residual):
        super().__init__()
        self.residual = residual
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            layer = nn.Linear(dim, dim)
            nn.init.kaiming_normal_(layer.weight)
            layer.weight.data *= 1.0 / np.sqrt(num_layers)
            nn.init.zeros_(layer.bias)
            self.layers.append(layer)
        self.activation = nn.ReLU()

    def forward_with_intermediates(self, x):
        """Return ``[x, h_1, ..., h_L]``: the input followed by every layer's output."""
        intermediates = [x.clone()]
        for layer in self.layers:
            update = self.activation(layer(x))
            x = x + update if self.residual else update
            intermediates.append(x.clone())
        return intermediates


def _plot_vector_path(ax, points, *, color, line_style, label):
    """Plot the trajectory of (dim 1, dim 2) across layers; ★ marks start, X marks end."""
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    ax.plot(xs, ys, line_style, color=color, linewidth=1.5, markersize=4,
            alpha=0.7, label=label)
    ax.scatter(xs[0], ys[0], s=100, color=color, marker='*', zorder=5)
    ax.scatter(xs[-1], ys[-1], s=100, color=color, marker='X', zorder=5)


def _plot_activation_heatmap(ax, intermediates, *, cmap, title, title_color):
    """Heatmap of the first 32 coordinates (rows) at every layer (columns)."""
    acts = np.array([t[0, :32].numpy() for t in intermediates])
    im = ax.imshow(acts.T, aspect='auto', cmap=cmap, interpolation='nearest')
    ax.set_xlabel('Layer', fontsize=12)
    ax.set_ylabel('Dimension (first 32)', fontsize=12)
    ax.set_title(title, fontsize=13, fontweight='bold', color=title_color)
    plt.colorbar(im, ax=ax, label='Activation Value')


def plot_layer_transformation():
    """Save plots_micro/5_layer_transformation.png.

    A single random vector is pushed through freshly initialised (untrained)
    20-layer plain and residual MLPs. The figure shows its L2 norm per layer,
    the path of its first two coordinates, and heatmaps of its first 32
    coordinates at every layer.
    """
    dim = 64
    num_layers = 20
    # Construction order matters for reproducibility: both models and the
    # input vector draw from the global torch RNG seeded at module load.
    plain = _DemoMLP(dim, num_layers, residual=False)
    res = _DemoMLP(dim, num_layers, residual=True)

    x = torch.randn(1, dim) * 0.5

    # Inference only: no autograd graph is needed for these intermediates.
    with torch.no_grad():
        plain_ints = plain.forward_with_intermediates(x)
        res_ints = res.forward_with_intermediates(x)

    plain_norms = [t.norm().item() for t in plain_ints]
    res_norms = [t.norm().item() for t in res_ints]
    plain_2d = [t[0, :2].numpy() for t in plain_ints]
    res_2d = [t[0, :2].numpy() for t in res_ints]

    fig, axes = plt.subplots(2, 2, figsize=(14, 12))

    # ---- Top-left: vector magnitude per layer ----
    ax = axes[0, 0]
    ax.plot(range(len(plain_norms)), plain_norms, 'o-', color='#e74c3c', linewidth=2,
            markersize=6, label='PlainMLP')
    ax.plot(range(len(res_norms)), res_norms, 's-', color='#3498db', linewidth=2,
            markersize=6, label='ResMLP')
    ax.set_xlabel('Layer (0=Input)', fontsize=12)
    ax.set_ylabel('Vector Magnitude (L2 norm)', fontsize=12)
    ax.set_title('Signal Magnitude Through Network', fontsize=13, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # ---- Top-right: 2-D projection of the vector's path ----
    ax = axes[0, 1]
    _plot_vector_path(ax, plain_2d, color='#e74c3c', line_style='o-', label='PlainMLP path')
    _plot_vector_path(ax, res_2d, color='#3498db', line_style='s-', label='ResMLP path')
    ax.set_xlabel('Dimension 1', fontsize=12)
    ax.set_ylabel('Dimension 2', fontsize=12)
    ax.set_title('2D Projection of Vector Path\n(★=start, ✕=end)', fontsize=13, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.axhline(y=0, color='gray', linestyle='-', alpha=0.3)
    ax.axvline(x=0, color='gray', linestyle='-', alpha=0.3)

    # ---- Bottom row: activation heatmaps ----
    _plot_activation_heatmap(axes[1, 0], plain_ints, cmap='Reds',
                             title='PlainMLP: Activations Die Out', title_color='#c0392b')
    _plot_activation_heatmap(axes[1, 1], res_ints, cmap='Blues',
                             title='ResMLP: Activations Stay Alive', title_color='#2980b9')

    plt.tight_layout()
    plt.savefig('plots_micro/5_layer_transformation.png', dpi=150, bbox_inches='tight')
    plt.close(fig)
    print("[Plot 5] Layer transformation visualization saved")
|
|
|
|
| |
| |
| |
def _loss_curve_panel(ax, plain_losses, res_losses):
    """Log-scale training-loss curves, shading the first and last 10% of steps."""
    ax.plot(range(len(plain_losses)), plain_losses, color='#e74c3c', linewidth=2, label='PlainMLP')
    ax.plot(range(len(res_losses)), res_losses, color='#3498db', linewidth=2, label='ResMLP')
    ax.set_xlabel('Training Steps', fontsize=12)
    ax.set_ylabel('MSE Loss', fontsize=12)
    ax.set_title('Learning Progress', fontsize=13, fontweight='bold')
    ax.set_yscale('log')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Window width comes from the run length (50 of 500 steps originally).
    # Labels are positioned in axes-fraction height (x stays in data units),
    # so they remain visible whatever range the log-scaled losses cover.
    n_steps = max(len(plain_losses), len(res_losses))
    window = n_steps / 10
    label_transform = ax.get_xaxis_transform()
    ax.axvspan(0, window, alpha=0.1, color='gray')
    ax.text(window / 2, 0.97, 'Early\nTraining', ha='center', va='top', fontsize=9,
            color='gray', transform=label_transform)
    ax.axvspan(n_steps - window, n_steps, alpha=0.1, color='green')
    ax.text(n_steps - window / 2, 0.97, 'Final', ha='center', va='top', fontsize=9,
            color='gray', transform=label_transform)


def _loss_reduction_panel(ax, plain_losses, res_losses):
    """Bar chart of % loss reduction (first vs last recorded loss) for each model."""
    reductions = [(1 - losses[-1] / losses[0]) * 100 for losses in (plain_losses, res_losses)]
    ax.bar(['PlainMLP', 'ResMLP'], reductions, color=['#e74c3c', '#3498db'],
           edgecolor='black', linewidth=2)
    ax.set_ylabel('Loss Reduction (%)', fontsize=12)
    ax.set_title('How Much Did Each Model Learn?', fontsize=13, fontweight='bold')
    # Extend below zero if a model's loss went *up*, instead of hiding its bar.
    lowest = min(reductions)
    ax.set_ylim(0 if lowest >= 0 else lowest - 10, 110)

    verdicts = ('FAILED\nTO LEARN', 'LEARNED\nSUCCESSFULLY')
    for pos, (value, verdict) in enumerate(zip(reductions, verdicts)):
        # Percentage label just above the bar top (or above zero for negative bars).
        ax.text(pos, max(value, 0) + 3, f'{value:.1f}%', ha='center', fontsize=14, fontweight='bold')
        ax.text(pos, value / 2, verdict, ha='center', va='center',
                fontsize=11, color='white', fontweight='bold')


def _gradient_comparison_panel(ax, plain_grads, res_grads):
    """Side-by-side per-layer gradient-magnitude bars for both models (log scale)."""
    width = 0.35
    ax.bar([layer - width / 2 for layer in range(1, len(plain_grads) + 1)], plain_grads, width,
           label='PlainMLP', color='#e74c3c', alpha=0.8)
    ax.bar([layer + width / 2 for layer in range(1, len(res_grads) + 1)], res_grads, width,
           label='ResMLP', color='#3498db', alpha=0.8)
    ax.set_xlabel('Layer', fontsize=12)
    ax.set_ylabel('Gradient Magnitude', fontsize=12)
    ax.set_title('Final Gradient Distribution by Layer', fontsize=13, fontweight='bold')
    ax.set_yscale('log')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')


def _summary_panel(ax, plain_final, res_final, plain_grad_first, res_grad_first):
    """Text-only panel: a result card per model plus a 'why residuals work' box."""
    ax.axis('off')
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 10)

    # Emoji were dropped/replaced in these strings: matplotlib's default font
    # (DejaVu Sans) typically has no glyph for them and renders empty boxes.
    ax.text(5, 9.5, 'Summary: Why Residuals Work', fontsize=16, fontweight='bold', ha='center')

    ax.add_patch(FancyBboxPatch((0.5, 5), 4, 3.5, boxstyle="round,pad=0.1",
                                facecolor='#fadbd8', edgecolor='#c0392b', linewidth=2))
    ax.text(2.5, 8, 'PlainMLP ✗', fontsize=13, fontweight='bold', ha='center', color='#c0392b')
    ax.text(2.5, 7, f'• Loss: {plain_final:.3f}', fontsize=11, ha='center')
    ax.text(2.5, 6.3, f'• Gradient L1: {plain_grad_first:.1e}', fontsize=11, ha='center')
    ax.text(2.5, 5.6, '• Status: UNTRAINABLE', fontsize=11, ha='center', color='#c0392b')

    ax.add_patch(FancyBboxPatch((5.5, 5), 4, 3.5, boxstyle="round,pad=0.1",
                                facecolor='#d4e6f1', edgecolor='#2980b9', linewidth=2))
    ax.text(7.5, 8, 'ResMLP ✓', fontsize=13, fontweight='bold', ha='center', color='#2980b9')
    ax.text(7.5, 7, f'• Loss: {res_final:.3f}', fontsize=11, ha='center')
    ax.text(7.5, 6.3, f'• Gradient L1: {res_grad_first:.1e}', fontsize=11, ha='center')
    ax.text(7.5, 5.6, '• Status: TRAINED', fontsize=11, ha='center', color='#2980b9')

    ax.add_patch(FancyBboxPatch((1, 0.5), 8, 3.5, boxstyle="round,pad=0.1",
                                facecolor='#fef9e7', edgecolor='#f39c12', linewidth=2))
    ax.text(5, 3.5, 'The Residual Connection:', fontsize=13, fontweight='bold', ha='center')
    ax.text(5, 2.6, '1. Creates a "gradient highway" for backpropagation', fontsize=11, ha='center')
    ax.text(5, 1.9, '2. Preserves signal magnitude through forward pass', fontsize=11, ha='center')
    ax.text(5, 1.2, '3. Allows training of very deep networks', fontsize=11, ha='center')


def plot_learning_comparison():
    """Save plots_micro/6_learning_comparison.png: a 2x2 dashboard comparing the models.

    Reads ``loss_history`` and ``gradient_norms`` for 'plain_mlp' and 'res_mlp'
    from the module-level ``results`` dict. Step and layer counts are derived
    from those lists rather than assumed to be 500 and 20.
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 12))

    plain_losses = results['plain_mlp']['loss_history']
    res_losses = results['res_mlp']['loss_history']
    plain_grads = results['plain_mlp']['gradient_norms']
    res_grads = results['res_mlp']['gradient_norms']

    _loss_curve_panel(axes[0, 0], plain_losses, res_losses)
    _loss_reduction_panel(axes[0, 1], plain_losses, res_losses)
    _gradient_comparison_panel(axes[1, 0], plain_grads, res_grads)
    _summary_panel(axes[1, 1], plain_losses[-1], res_losses[-1], plain_grads[0], res_grads[0])

    plt.tight_layout()
    plt.savefig('plots_micro/6_learning_comparison.png', dpi=150, bbox_inches='tight')
    plt.close(fig)
    print("[Plot 6] Learning comparison visualization saved")
|
|
|
|
| |
| |
| |
if __name__ == "__main__":
    # Script entry point: render every figure in order, framed by banners.
    banner = "=" * 60
    print(banner)
    print("Creating Micro-World Visualizations")
    print(banner)

    # Each function draws one figure, saves it under plots_micro/ and prints a status line.
    for make_plot in (
        plot_signal_flow,
        plot_gradient_flow,
        plot_highway_concept,
        plot_chain_rule,
        plot_layer_transformation,
        plot_learning_comparison,
    ):
        make_plot()

    print("\n" + banner)
    print("All visualizations saved to plots_micro/")
    print(banner)
|
|