| | |
| | !pip install -q datasets safetensors huggingface_hub |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch.utils.data import DataLoader |
| | from datasets import load_dataset |
| | from huggingface_hub import hf_hub_download |
| | from safetensors.torch import load_file as load_safetensors |
| | import matplotlib.pyplot as plt |
| | import numpy as np |
| | import math |
| | import json |
| |
|
# Prefer the GPU when one is visible to this process.
has_gpu = torch.cuda.is_available()
device = 'cuda' if has_gpu else 'cpu'
| |
|
| | |
| | |
| | |
| |
|
class MobiusLensRaw(nn.Module):
    """Wave-interference gating layer with learned input/output "twist" blends.

    Pipeline (see ``forward``): the input is blended with a learned linear
    projection (rotation-like, controlled by ``twist_in_angle``), squashed with
    tanh, run through three periodic Gaussian-comb wave responses L/M/R at two
    fixed per-layer scales, combined into a sigmoid gate, gated
    multiplicatively, and finally blended again by an output twist.

    ``forward_raw`` recomputes the same pipeline but returns the intermediate
    wave values for diagnostic inspection instead of the gated output.

    Assumes the last dimension of the input equals ``dim`` (required by
    ``gate_norm``/``twist_in_proj``) — callers feed channels-last tensors.
    """

    def __init__(self, dim, layer_idx, total_layers, scale_range=(1.0, 9.0)):
        super().__init__()
        self.dim = dim
        # Normalized depth position in [0, 1]; drives per-layer scales/angles.
        self.t = layer_idx / max(total_layers - 1, 1)
        scale_span = scale_range[1] - scale_range[0]
        step = scale_span / max(total_layers, 1)
        # Two adjacent frequencies for this layer; registered as a buffer so
        # they are saved with the state dict but never trained.
        self.register_buffer('scales', torch.tensor([scale_range[0] + self.t * scale_span,
                                                     scale_range[0] + self.t * scale_span + step]))
        # Input twist: x*cos(angle) + proj(x)*sin(angle); angle grows with depth.
        self.twist_in_angle = nn.Parameter(torch.tensor(self.t * math.pi))
        self.twist_in_proj = nn.Linear(dim, dim, bias=False)
        self.omega = nn.Parameter(torch.tensor(math.pi))   # base angular frequency
        self.alpha = nn.Parameter(torch.tensor(1.5))       # wave sharpness (see forward)
        # Per-wave phase/drift pairs, one entry per scale.  Drifts start at
        # +1 (L), 0 (M), -1 (R) so the three waves begin phase-shifted.
        self.phase_l, self.drift_l = nn.Parameter(torch.zeros(2)), nn.Parameter(torch.ones(2))
        self.phase_m, self.drift_m = nn.Parameter(torch.zeros(2)), nn.Parameter(torch.zeros(2))
        self.phase_r, self.drift_r = nn.Parameter(torch.zeros(2)), nn.Parameter(-torch.ones(2))
        # Softmax-normalized mixing weights for L/M/R in the gate.
        self.accum_weights = nn.Parameter(torch.tensor([0.4, 0.2, 0.4]))
        # Sigmoid-squashed blend between a soft-XOR and a soft-AND of L and R.
        self.xor_weight = nn.Parameter(torch.tensor(0.7))
        self.gate_norm = nn.LayerNorm(dim)
        # Output twist starts as the negative of the input twist angle.
        self.twist_out_angle = nn.Parameter(torch.tensor(-self.t * math.pi))
        self.twist_out_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        """Gate ``x`` through the wave pipeline.

        Returns ``(output, gate)`` where both have the same shape as ``x``.
        """
        # Input twist: blend identity with a learned projection.
        cos_t, sin_t = torch.cos(self.twist_in_angle), torch.sin(self.twist_in_angle)
        x = x * cos_t + self.twist_in_proj(x) * sin_t
        x_norm = torch.tanh(x)  # squash to (-1, 1) before the periodic waves
        # Mean |activation| per position, kept broadcastable; acts as a
        # data-dependent drift term inside the wave argument.
        t = x_norm.abs().mean(dim=-1, keepdim=True).unsqueeze(-2)
        x_exp = x_norm.unsqueeze(-2)   # (..., 1, dim): room for the scale axis
        s = self.scales.view(-1, 1)    # (2, 1): one row per scale
        a = self.alpha.abs() + 0.1     # sharpness, floored so it stays positive
        def wave(phase, drift):
            # exp(-a*sin^2(.)) is a periodic bump train; prod over the two
            # scales means BOTH scales must peak simultaneously to pass ~1.
            pos = s * self.omega * (x_exp + drift.view(-1, 1) * t) + phase.view(-1, 1)
            return torch.exp(-a * torch.sin(pos).pow(2)).prod(dim=-2)
        L, M, R = wave(self.phase_l, self.drift_l), wave(self.phase_m, self.drift_m), wave(self.phase_r, self.drift_r)
        w = torch.softmax(self.accum_weights, dim=0)
        xor_w = torch.sigmoid(self.xor_weight)
        # Soft XOR (|L+R-2LR|) blended with soft AND (L*R).
        lr = xor_w * (L + R - 2*L*R).abs() + (1 - xor_w) * L * R
        # LayerNorm recenters the (often tiny) pre-gate values before sigmoid.
        gate = torch.sigmoid(self.gate_norm((w[0]*L + w[1]*M + w[2]*R) * (0.5 + 0.5*lr)))
        x = x * gate
        # Output twist mirrors the input twist with its own angle/projection.
        cos_t, sin_t = torch.cos(self.twist_out_angle), torch.sin(self.twist_out_angle)
        return x * cos_t + self.twist_out_proj(x) * sin_t, gate

    def forward_raw(self, x):
        """Return raw L/M/R values for inspection."""
        # Same computation as forward() up to the gate, but every intermediate
        # is kept and returned instead of producing the gated output.
        cos_t, sin_t = torch.cos(self.twist_in_angle), torch.sin(self.twist_in_angle)
        x_twisted = x * cos_t + self.twist_in_proj(x) * sin_t
        x_norm = torch.tanh(x_twisted)
        t = x_norm.abs().mean(dim=-1, keepdim=True).unsqueeze(-2)
        x_exp = x_norm.unsqueeze(-2)
        s = self.scales.view(-1, 1)
        a = self.alpha.abs() + 0.1

        def wave_detailed(phase, drift):
            # Like wave() in forward(), but exposes sin and exp stages too.
            pos = s * self.omega * (x_exp + drift.view(-1, 1) * t) + phase.view(-1, 1)
            sin_val = torch.sin(pos)
            exp_val = torch.exp(-a * sin_val.pow(2))
            prod_val = exp_val.prod(dim=-2)
            return prod_val, sin_val, exp_val

        L, L_sin, L_exp = wave_detailed(self.phase_l, self.drift_l)
        M, M_sin, M_exp = wave_detailed(self.phase_m, self.drift_m)
        R, R_sin, R_exp = wave_detailed(self.phase_r, self.drift_r)

        w = torch.softmax(self.accum_weights, dim=0)
        xor_w = torch.sigmoid(self.xor_weight)
        xor_comp = (L + R - 2*L*R).abs()
        and_comp = L * R
        lr = xor_w * xor_comp + (1 - xor_w) * and_comp
        gate_pre = (w[0]*L + w[1]*M + w[2]*R) * (0.5 + 0.5*lr)
        gate = torch.sigmoid(self.gate_norm(gate_pre))

        # Tensors stay on-device (still attached unless noted); scalars and
        # small arrays are detached/converted for easy printing.
        return {
            'x_norm': x_norm, 'L': L, 'M': M, 'R': R,
            'L_sin': L_sin, 'L_exp': L_exp,
            'xor_comp': xor_comp, 'and_comp': and_comp,
            'gate_pre': gate_pre, 'gate': gate,
            'omega': self.omega.item(), 'alpha': a.item(),
            'scales': self.scales.cpu().numpy(),
            'weights': w.detach().cpu().numpy(),
            'xor_weight': xor_w.item(),
        }
| |
|
class MobiusBlockRaw(nn.Module):
    """Conv block: depthwise+pointwise conv, Mobius lens gate, channel-thirds
    attenuation mask, and a learned residual blend.

    ``forward_raw`` short-circuits to the lens diagnostics instead of
    producing the block output.
    """

    def __init__(self, channels, layer_idx, total_layers, scale_range=(1.0, 9.0), reduction=0.5):
        super().__init__()
        # Depthwise 3x3 -> pointwise 1x1 -> BN (a depthwise-separable conv).
        self.conv = nn.Sequential(nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False),
                                  nn.Conv2d(channels, channels, 1, bias=False), nn.BatchNorm2d(channels))
        self.lens = MobiusLensRaw(channels, layer_idx, total_layers, scale_range)
        # Attenuate one third of the channels (rotating with layer_idx % 3)
        # by `reduction`; the last third also absorbs the channels%3 remainder.
        third = channels // 3
        which_third = layer_idx % 3
        mask = torch.ones(channels)
        mask[which_third*third : which_third*third + third + (channels%3 if which_third==2 else 0)] = reduction
        self.register_buffer('thirds_mask', mask.view(1, -1, 1, 1))
        # Sigmoid-squashed weight blending identity vs. the processed branch.
        self.residual_weight = nn.Parameter(torch.tensor(0.9))

    def forward(self, x):
        """Process NCHW input; returns a tensor of the same shape."""
        identity = x
        # Lens expects channels-last, so permute NCHW -> NHWC around it.
        h = self.conv(x).permute(0, 2, 3, 1)
        h, gate = self.lens(h)
        h = h.permute(0, 3, 1, 2) * self.thirds_mask
        rw = torch.sigmoid(self.residual_weight)
        # Learned convex-ish blend of skip path and processed path.
        return rw * identity + (1 - rw) * h

    def forward_raw(self, x):
        # Diagnostics only: run the conv, then return the lens internals
        # (dict from MobiusLensRaw.forward_raw) instead of the block output.
        h = self.conv(x).permute(0, 2, 3, 1)
        return self.lens.forward_raw(h)
| |
|
class MobiusNetRaw(nn.Module):
    """MobiusBlock backbone with raw-wave inspection hooks.

    Architecture: stem conv -> ``depths[i]`` MobiusBlockRaw blocks per stage,
    strided-conv downsample between stages -> optional integrator conv ->
    global average pool -> linear head.
    """

    def __init__(self, in_chans=1, num_classes=1000, channels=(64,128,256),
                 depths=(2,2,2), scale_range=(0.5,2.5), use_integrator=True):
        super().__init__()
        total_layers = sum(depths)
        channels = list(channels)
        self.stem = nn.Sequential(nn.Conv2d(in_chans, channels[0], 3, padding=1, bias=False), nn.BatchNorm2d(channels[0]))
        self.stages = nn.ModuleList()
        self.downsamples = nn.ModuleList()
        layer_idx = 0  # global block index, so each lens knows its depth position
        for si, d in enumerate(depths):
            self.stages.append(nn.ModuleList([MobiusBlockRaw(channels[si], layer_idx+i, total_layers, scale_range) for i in range(d)]))
            layer_idx += d
            if si < len(depths)-1:
                self.downsamples.append(nn.Sequential(nn.Conv2d(channels[si], channels[si+1], 3, stride=2, padding=1, bias=False), nn.BatchNorm2d(channels[si+1])))

        self.integrator = nn.Sequential(nn.Conv2d(channels[-1], channels[-1], 3, padding=1, bias=False),
                                        nn.BatchNorm2d(channels[-1]), nn.GELU()) if use_integrator else nn.Identity()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(channels[-1], num_classes)

    def forward(self, x):
        """Full forward pass: (B, in_chans, H, W) -> (B, num_classes) logits.

        Fix: the original class defined pool/head/integrator but no forward,
        so calling the module raised NotImplementedError.
        """
        x = self.stem(x)
        for si, stage in enumerate(self.stages):
            for block in stage:
                x = block(x)
            if si < len(self.downsamples):
                x = self.downsamples[si](x)
        x = self.integrator(x)
        x = self.pool(x).flatten(1)
        return self.head(x)

    def get_block_raw(self, x, target_stage, target_block):
        """Forward to target block and return raw wave data.

        Runs the network normally up to (target_stage, target_block), then
        returns that block's forward_raw() dict. Returns None if the target
        indices are never reached (out of range).
        """
        x = self.stem(x)
        for si, stage in enumerate(self.stages):
            for bi, block in enumerate(stage):
                if si == target_stage and bi == target_block:
                    return block.forward_raw(x)
                x = block(x)
            if si < len(self.downsamples):
                x = self.downsamples[si](x)
        return None
| |
|
| | |
| | |
| | |
| |
|
# --- Fetch the pretrained checkpoint + config from the Hub and build the model ---
_REPO = "AbstractPhil/mobiusnet-distillations"
_RUN = "checkpoints/mobius_tiny_s_imagenet_clip_vit_l14/20260111_000512"

print("Loading model...")
config_path = hf_hub_download(_REPO, _RUN + "/config.json")
with open(config_path) as fh:
    config = json.load(fh)
model_path = hf_hub_download(_REPO, _RUN + "/checkpoints/best_model.safetensors")

cfg = config['model']
model = MobiusNetRaw(
    cfg['in_chans'],
    cfg['num_classes'],
    tuple(cfg['channels']),
    tuple(cfg['depths']),
    tuple(cfg['scale_range']),
    cfg['use_integrator'],
).to(device)
model.load_state_dict(load_safetensors(model_path))
model.eval()
print("✓ Loaded")
| |
|
| | |
| | |
| | |
| |
|
# Stream one validation batch of CLIP features and reshape it into the
# 1x24x32 "image" layout the model was trained on.
ds = load_dataset("AbstractPhil/imagenet-clip-features-orderly", "clip_vit_l14",
                  split="validation", streaming=True).with_format("torch")
first_batch = next(iter(DataLoader(ds, batch_size=16)))
x = first_batch['clip_features'].view(-1, 1, 24, 32).to(device)
| |
|
| | |
| | |
| | |
| |
|
# Walk six (stage, block) positions, print raw L/M/R/gate statistics, and
# build a 6x6 grid: four histograms + two spatial maps per block.
blocks = [(0,0), (0,1), (1,0), (1,1), (2,0), (2,1)]
block_names = ['S0B0', 'S0B1', 'S1B0', 'S1B1', 'S2B0', 'S2B1']

fig, axes = plt.subplots(6, 6, figsize=(24, 24))

for row, ((stage_i, block_i), label) in enumerate(zip(blocks, block_names)):
    with torch.no_grad():
        raw = model.get_block_raw(x, stage_i, block_i)

    print(f"\n{'='*60}")
    print(f"{label}: ω={raw['omega']:.3f}, α={raw['alpha']:.3f}, scales={raw['scales']}")
    print(f" Weights: L={raw['weights'][0]:.3f}, M={raw['weights'][1]:.3f}, R={raw['weights'][2]:.3f}")
    print(f" XOR weight: {raw['xor_weight']:.3f}")

    L, M, R = raw['L'], raw['M'], raw['R']
    gate = raw['gate']

    for tag, vals in (('L', L), ('M', M), ('R', R)):
        print(f" {tag}: min={vals.min():.6f}, max={vals.max():.6f}, mean={vals.mean():.6f}, std={vals.std():.6f}")
    print(f" Gate: min={gate.min():.4f}, max={gate.max():.4f}, mean={gate.mean():.4f}")

    print(f" L_sin range: [{raw['L_sin'].min():.4f}, {raw['L_sin'].max():.4f}]")
    print(f" L_exp range: [{raw['L_exp'].min():.6f}, {raw['L_exp'].max():.6f}]")
    print(f" x_norm range: [{raw['x_norm'].min():.4f}, {raw['x_norm'].max():.4f}]")

    # Column 0 gets the extra σ annotation and the mean marker line.
    axes[row, 0].hist(L.cpu().numpy().flatten(), bins=50, color='red', alpha=0.7, density=True)
    axes[row, 0].set_title(f'{label} L\nμ={L.mean():.4f}, σ={L.std():.4f}', fontsize=10)
    axes[row, 0].axvline(x=L.mean().item(), color='black', linestyle='--')

    hist_specs = [
        (M, 'green', f'{label} M\nμ={M.mean():.4f}'),
        (R, 'blue', f'{label} R\nμ={R.mean():.4f}'),
        (gate, 'purple', f'{label} Gate\nμ={gate.mean():.4f}'),
    ]
    for col, (vals, colour, title) in enumerate(hist_specs, start=1):
        axes[row, col].hist(vals.cpu().numpy().flatten(), bins=50, color=colour, alpha=0.7, density=True)
        axes[row, col].set_title(title, fontsize=10)

    # Spatial maps: channel-averaged L and gate for the first sample.
    L_spatial = L[0].mean(dim=-1).cpu().numpy()
    axes[row, 4].imshow(L_spatial, cmap='hot', aspect='auto')
    axes[row, 4].set_title(f'{label} L spatial\nα={raw["alpha"]:.2f}', fontsize=10)
    axes[row, 4].axis('off')

    gate_spatial = gate[0].mean(dim=-1).cpu().numpy()
    axes[row, 5].imshow(gate_spatial, cmap='viridis', aspect='auto', vmin=0, vmax=1)
    axes[row, 5].set_title(f'{label} Gate spatial', fontsize=10)
    axes[row, 5].axis('off')

plt.suptitle('Raw Wave Diagnostics: L/M/R Distributions', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig("mobius_raw_diagnostics.png", dpi=150, bbox_inches="tight")
plt.show()
| |
|
| | |
| | |
| | |
| |
|
# Banner + narrative explanation of why tiny raw waves still yield ~0.5 gates.
banner = "=" * 70
print("\n" + banner)
print("ANALYSIS: Wave Function Behavior")
print(banner)

print("""
Wave function: exp(-α * sin²(ω * s * (x + drift*t)))

For high α (like 5.12 at S2B1):
- This becomes a VERY narrow peak around sin(...)=0
- i.e., when ω*s*(x+drift*t) = n*π

The prod over 2 scales means BOTH scales must hit a peak simultaneously.
This is extremely rare, so most values → exp(-5.12) ≈ 0.006

BUT: The gate is computed AFTER LayerNorm on gate_pre!
gate = sigmoid(LayerNorm(weighted_sum * (0.5 + 0.5*lr)))

LayerNorm rescales the near-zero values to have mean=0, std=1
Then sigmoid maps that to ~0.5 centered distribution.

This is why gates are ~0.4-0.5 even when raw L/M/R are tiny.
""")
| |
|
| | |
# Deep dive into the last block (stage 2, block 1): show that LayerNorm
# rescues near-zero pre-gate values, then plot per-channel L/M/R means.
with torch.no_grad():
    deep = model.get_block_raw(x, 2, 1)

print(f"\nS2B1 gate_pre: min={deep['gate_pre'].min():.6f}, max={deep['gate_pre'].max():.6f}, mean={deep['gate_pre'].mean():.6f}")
print(f"S2B1 gate: min={deep['gate'].min():.4f}, max={deep['gate'].max():.4f}, mean={deep['gate'].mean():.4f}")

print(f"\nThe information is in relative L/M/R differences across channels:")
# Average each wave over the spatial dims of the first sample -> one value
# per channel.
per_channel = {k: deep[k][0].mean(dim=(0, 1)).cpu().numpy() for k in ('L', 'M', 'R')}

fig2, ax2 = plt.subplots(1, 1, figsize=(14, 4))
channel_idx = np.arange(len(per_channel['L']))
for key, style in (('L', 'r-'), ('M', 'g-'), ('R', 'b-')):
    ax2.plot(channel_idx, per_channel[key], style, alpha=0.7, label=key)
ax2.set_xlabel('Channel')
ax2.set_ylabel('Mean activation')
ax2.set_title('S2B1: L/M/R per channel (the signal is in the variance)')
ax2.legend()
plt.tight_layout()
plt.savefig("mobius_channel_variance.png", dpi=150)
plt.show()

print(f"\nPer-channel variance:")
print(f" L channels std: {per_channel['L'].std():.6f}")
print(f" M channels std: {per_channel['M'].std():.6f}")
print(f" R channels std: {per_channel['R'].std():.6f}")