{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "machine_shape": "hm", "gpuType": "A100" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "code", "source": [ "try:\n", " !pip uninstall -qy geometricvocab\n", "except:\n", " pass\n", "\n", "!pip install -q git+https://github.com/AbstractEyes/lattice_vocabulary.git" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "iHaIH_8OI38S", "outputId": "361524ac-8ef7-419f-bb7b-05fea1062443" }, "execution_count": 1, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\u001b[33mWARNING: Skipping geometricvocab as it is not installed.\u001b[0m\u001b[33m\n", "\u001b[0m Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", " Building wheel for geometricvocab (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n" ] } ] }, { "cell_type": "code", "source": [ "#@title True Cayley-Menger KSimplex Linear\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets, transforms\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import math\n", "from typing import List, Tuple, Optional\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print(f\"Device: {device}\")\n", "\n", "# ============================================================================\n", "# TRUE CAYLEY-MENGER KSIMPLEX PRIMITIVES\n", "# ============================================================================\n", "\n", "class SimplexEdgeLayer(nn.Module):\n", " \"\"\"\n", " Routes information along k-simplex edges.\n", " For k-simplex: k+1 vertices, each connected to k others.\n", "\n", " Input: vertex features [N, B, D, V, E]\n", " Output: updated vertex features [N, B, D, V, E]\n", " \"\"\"\n", "\n", " def __init__(self, n: int, num_vertices: int, vertex_dim: int):\n", " super().__init__()\n", " self.n = n\n", " self.num_vertices = num_vertices # V = k+1\n", " self.vertex_dim = vertex_dim # E\n", " self.num_edges = num_vertices - 1 # k edges per vertex\n", "\n", " # Edge message weights: for each vertex, k incoming edge transforms\n", " self.edge_weight = nn.Parameter(torch.empty(n, num_vertices, self.num_edges, vertex_dim, vertex_dim))\n", "\n", " # Routing: which neighbors to attend to\n", " self.route_weight = nn.Parameter(torch.empty(n, num_vertices, self.num_edges, vertex_dim))\n", " self.route_bias = nn.Parameter(torch.zeros(n, num_vertices, self.num_edges))\n", "\n", " # Output projection per vertex\n", " self.out_weight = nn.Parameter(torch.empty(n, num_vertices, vertex_dim, vertex_dim))\n", " self.out_bias = nn.Parameter(torch.zeros(n, num_vertices, vertex_dim))\n", "\n", " self._init_weights()\n", " self._build_neighbor_indices()\n", "\n", " def _init_weights(self):\n", " for i in range(self.n):\n", " for v in range(self.num_vertices):\n", " nn.init.orthogonal_(self.out_weight[i, v])\n", " for e in range(self.num_edges):\n", " nn.init.kaiming_uniform_(self.edge_weight[i, v, e], a=math.sqrt(5))\n", " nn.init.kaiming_uniform_(self.route_weight[i, v, e].unsqueeze(0), a=math.sqrt(5))\n", "\n", " def _build_neighbor_indices(self):\n", " \"\"\"For each vertex, list its k neighbors (all other vertices in simplex).\"\"\"\n", " neighbors = []\n", " for v in range(self.num_vertices):\n", " v_neighbors = [j for j in range(self.num_vertices) if j != v]\n", " neighbors.append(v_neighbors)\n", " self.register_buffer('neighbor_idx', torch.tensor(neighbors)) # [V, k]\n", "\n", " def forward(self, vertices: torch.Tensor) -> torch.Tensor:\n", " \"\"\"\n", " vertices: [N, B, D, V, E]\n", " returns: [N, B, D, V, E]\n", " \"\"\"\n", " N, B, D, V, E = vertices.shape\n", "\n", " # Gather neighbor features for each vertex\n", " # neighbor_idx: [V, k] -> expand to [N, B, D, V, k, E]\n", " idx = self.neighbor_idx.view(1, 1, 1, V, -1, 1).expand(N, B, D, -1, -1, E)\n", " neighbors = vertices.unsqueeze(-2).expand(-1, -1, -1, -1, self.num_edges, -1)\n", " neighbor_feats = torch.gather(\n", " vertices.unsqueeze(3).expand(-1, -1, -1, V, -1, -1),\n", " dim=4,\n", " index=idx\n", " ) # [N, B, D, V, k, E]\n", "\n", " # Compute edge messages\n", " # edge_weight: [N, V, k, E, E]\n", " messages = torch.einsum('nvkeo,nbdvke->nbdvko', self.edge_weight, neighbor_feats)\n", " # messages: [N, B, D, V, k, E]\n", "\n", " # Compute routing weights\n", " route_logits = torch.einsum('nvke,nbdvke->nbdvk', self.route_weight, neighbor_feats)\n", " route_logits = route_logits + self.route_bias.view(N, 1, 1, V, -1)\n", " route_weights = F.softmax(route_logits, dim=-1) # [N, B, D, V, k]\n", "\n", " # Aggregate messages\n", " aggregated = (messages * route_weights.unsqueeze(-1)).sum(dim=-2) # [N, B, D, V, E]\n", "\n", " # Output transform with residual\n", " out = torch.einsum('nveo,nbdve->nbdvo', self.out_weight, aggregated)\n", " out = out + self.out_bias.view(N, 1, 1, V, E)\n", "\n", " return vertices + out # Residual connection\n", "\n", "\n", "class CayleyMengerExit(nn.Module):\n", " \"\"\"\n", " Computes true Cayley-Menger geometry from vertex positions.\n", "\n", " - Projects accumulated features to vertex coordinates\n", " - Computes TRUE Euclidean distances via Gram matrix\n", " - Computes CM determinant (guaranteed valid!)\n", " - Outputs energy based on geometry\n", " \"\"\"\n", "\n", " def __init__(self, n: int, vertex_dim: int, num_layers: int, k: int):\n", " super().__init__()\n", " self.n = n\n", " self.k = k\n", " self.num_vertices = k + 1 # V\n", " self.vertex_dim = vertex_dim # E (embedding dim per vertex)\n", " self.coord_dim = k # Minimum dims to embed k-simplex\n", " self.num_pairs = (self.num_vertices * (self.num_vertices - 1)) // 2\n", "\n", " # Layer accumulation weights (holographic: all layers contribute)\n", " inverse_init = torch.linspace(1.0, 0.1, num_layers).unsqueeze(0).expand(n, -1).clone()\n", " self.inverse_weights = nn.Parameter(inverse_init)\n", "\n", " # Project vertex features → coordinates in R^k\n", " self.to_coords = nn.Parameter(torch.empty(n, self.num_vertices, self.coord_dim, vertex_dim))\n", " self.coord_bias = nn.Parameter(torch.zeros(n, self.num_vertices, self.coord_dim))\n", "\n", " # Energy from geometry: distances + volume → scalar\n", " self.energy_weight = nn.Parameter(torch.empty(n, 1, self.num_pairs + 1))\n", " self.energy_bias = nn.Parameter(torch.zeros(n, 1))\n", "\n", " self._init_weights()\n", " self._register_pair_indices()\n", "\n", " def _init_weights(self):\n", " for i in range(self.n):\n", " for v in range(self.num_vertices):\n", " nn.init.kaiming_uniform_(self.to_coords[i, v], a=math.sqrt(5))\n", " nn.init.kaiming_uniform_(self.energy_weight[i], a=math.sqrt(5))\n", "\n", " def _register_pair_indices(self):\n", " pair_i, pair_j = [], []\n", " for i in range(self.num_vertices):\n", " for j in range(i + 1, self.num_vertices):\n", " pair_i.append(i)\n", " pair_j.append(j)\n", " self.register_buffer('pair_i', torch.tensor(pair_i))\n", " self.register_buffer('pair_j', torch.tensor(pair_j))\n", "\n", " def compute_distances_sq(self, coords: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n", " \"\"\"\n", " Compute true Euclidean squared distances via Gram matrix.\n", " coords: [N, B, D, V, C] - V vertices in C-dimensional space\n", " returns: d_sq_pairs [N, B, D, num_pairs], d_sq_full [N, B, D, V, V]\n", " \"\"\"\n", " # Gram matrix: G_ij = \n", " gram = torch.einsum('nbdvc,nbdwc->nbdvw', coords, coords)\n", "\n", " # Squared norms (diagonal)\n", " norms_sq = torch.diagonal(gram, dim1=-2, dim2=-1) # [N, B, D, V]\n", "\n", " # d²(i,j) = ||v_i||² + ||v_j||² - 2\n", " d_sq_full = norms_sq.unsqueeze(-1) + norms_sq.unsqueeze(-2) - 2 * gram\n", "\n", " # Ensure non-negative (numerical stability)\n", " d_sq_full = F.relu(d_sq_full)\n", "\n", " # Extract upper triangle pairs\n", " d_sq_pairs = d_sq_full[..., self.pair_i, self.pair_j]\n", "\n", " return d_sq_pairs, d_sq_full\n", "\n", " def cayley_menger_volume_sq(self, d_sq_full: torch.Tensor) -> torch.Tensor:\n", " \"\"\"\n", " Compute squared volume from CM determinant.\n", " d_sq_full: [N, B, D, V, V]\n", " returns: [N, B, D]\n", " \"\"\"\n", " N, B, D, V, _ = d_sq_full.shape\n", "\n", " # CM matrix is (V+1) x (V+1)\n", " cm_size = V + 1\n", " cm = torch.zeros(N, B, D, cm_size, cm_size, device=d_sq_full.device, dtype=d_sq_full.dtype)\n", "\n", " # Border: first row/col = 1 (except [0,0] = 0)\n", " cm[..., 0, 1:] = 1.0\n", " cm[..., 1:, 0] = 1.0\n", "\n", " # Interior: squared distances\n", " cm[..., 1:, 1:] = d_sq_full\n", "\n", " # Determinant\n", " det = torch.linalg.det(cm)\n", "\n", " # Volume² = (-1)^(k+1) * det / (2^k * (k!)²)\n", " sign = (-1.0) ** (self.k + 1)\n", " factorial_k = math.factorial(self.k)\n", " prefactor = sign / ((2.0 ** self.k) * (factorial_k ** 2))\n", "\n", " vol_sq = prefactor * det\n", "\n", " return vol_sq\n", "\n", " def forward(self, layer_outputs: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:\n", " \"\"\"\n", " layer_outputs: list of [N, B, D, V, E] tensors\n", " returns: energy, d_sq_pairs, vol_sq, coords\n", " \"\"\"\n", " # Stack layers: [N, B, D, V, E, L]\n", " stacked = torch.stack(layer_outputs, dim=-1)\n", "\n", " # Weighted accumulation (holographic)\n", " weights = F.softmax(self.inverse_weights, dim=1) # [N, L]\n", " weights = weights.view(self.n, 1, 1, 1, 1, -1)\n", " accumulated = (stacked * weights).sum(dim=-1) # [N, B, D, V, E]\n", "\n", " # Project to coordinates in R^k\n", " # to_coords: [N, V, C, E], accumulated: [N, B, D, V, E]\n", " coords = torch.einsum('nvce,nbdve->nbdvc', self.to_coords, accumulated)\n", " coords = coords + self.coord_bias.view(self.n, 1, 1, self.num_vertices, self.coord_dim)\n", " # coords: [N, B, D, V, C]\n", "\n", " # True distances\n", " d_sq_pairs, d_sq_full = self.compute_distances_sq(coords)\n", "\n", " # CM volume (guaranteed valid!)\n", " vol_sq = self.cayley_menger_volume_sq(d_sq_full)\n", "\n", " # Energy from geometry\n", " combined = torch.cat([d_sq_pairs, vol_sq.unsqueeze(-1)], dim=-1)\n", " energy = torch.einsum('nep,nbdp->nbde', self.energy_weight, combined)\n", " energy = energy.squeeze(-1) + self.energy_bias.view(self.n, 1, 1)\n", "\n", " return energy, d_sq_pairs, vol_sq, coords\n", "\n", "\n", "class WideKSimplexLinear(nn.Module):\n", " \"\"\"\n", " True Cayley-Menger KSimplex linear layer.\n", "\n", " Replaces standard linear with simplex geometry:\n", " - Input projected to k+1 vertex features\n", " - Vertices communicate along simplex edges\n", " - CM geometry (distances, volume) drives output\n", " \"\"\"\n", "\n", " def __init__(self, n: int, input_dim: int, output_dim: int, k: int = 4, vertex_dim: int = 8):\n", " super().__init__()\n", " self.n = n\n", " self.input_dim = input_dim\n", " self.output_dim = output_dim\n", " self.k = k\n", " self.num_vertices = k + 1\n", " self.vertex_dim = vertex_dim\n", " self.depth = k + 1 # Layers = vertices (simplex depth)\n", "\n", " # Entry: project input to vertex features\n", " self.entry = nn.Parameter(torch.empty(n, self.num_vertices, vertex_dim, input_dim))\n", " self.entry_bias = nn.Parameter(torch.zeros(n, self.num_vertices, vertex_dim))\n", "\n", " # Simplex edge layers\n", " self.layers = nn.ModuleList([\n", " SimplexEdgeLayer(n, self.num_vertices, vertex_dim)\n", " for _ in range(self.depth)\n", " ])\n", "\n", " # CM exit\n", " self.exit = CayleyMengerExit(n, vertex_dim, self.depth, k)\n", "\n", " # Output projection\n", " self.out_proj = nn.Parameter(torch.empty(n, output_dim, self.exit.num_pairs + 1))\n", " self.out_bias = nn.Parameter(torch.zeros(n, output_dim))\n", "\n", " self._init_weights()\n", "\n", " def _init_weights(self):\n", " for i in range(self.n):\n", " for v in range(self.num_vertices):\n", " nn.init.kaiming_uniform_(self.entry[i, v], a=math.sqrt(5))\n", " nn.init.kaiming_uniform_(self.out_proj[i], a=math.sqrt(5))\n", "\n", " def forward(self, x: torch.Tensor, return_geometry: bool = False):\n", " \"\"\"\n", " x: [N, B, D] - N parallel, B batch, D input dim\n", " returns: [N, B, O] output, optionally geometry\n", " \"\"\"\n", " N, B, D = x.shape\n", "\n", " # Project to vertex features: [N, B, D] → [N, B, 1, V, E]\n", " # Treat D as a batch dimension\n", " h = torch.einsum('nvei,nbd->nbdve', self.entry, x)\n", " h = h + self.entry_bias.view(N, 1, 1, self.num_vertices, self.vertex_dim)\n", " # h: [N, B, D, V, E] but D=input_dim, we want spatial dim\n", "\n", " # Actually, let's treat it correctly:\n", " # x: [N, B, input_dim] → vertices: [N, B, V, E]\n", " h = torch.einsum('nvei,nbi->nbve', self.entry, x)\n", " h = h + self.entry_bias.view(N, 1, self.num_vertices, self.vertex_dim)\n", " # h: [N, B, V, E]\n", "\n", " # Add dummy spatial dim for compatibility (D=1)\n", " h = h.unsqueeze(2) # [N, B, 1, V, E]\n", "\n", " # Process through simplex layers\n", " layer_outputs = []\n", " for layer in self.layers:\n", " h = layer(h)\n", " layer_outputs.append(h)\n", "\n", " # CM exit\n", " energy, d_sq, vol_sq, coords = self.exit(layer_outputs)\n", " # energy: [N, B, 1], d_sq: [N, B, 1, pairs], vol_sq: [N, B, 1]\n", "\n", " # Squeeze spatial dim\n", " energy = energy.squeeze(2) # [N, B]\n", " d_sq = d_sq.squeeze(2) # [N, B, pairs]\n", " vol_sq = vol_sq.squeeze(2) # [N, B]\n", "\n", " # Output from full geometry\n", " geom = torch.cat([d_sq, vol_sq.unsqueeze(-1)], dim=-1) # [N, B, pairs+1]\n", " out = torch.einsum('noi,nbi->nbo', self.out_proj, geom)\n", " out = out + self.out_bias.view(N, 1, -1)\n", "\n", " if return_geometry:\n", " return out, {\n", " 'energy': energy,\n", " 'd_sq': d_sq,\n", " 'vol_sq': vol_sq,\n", " 'coords': coords.squeeze(2), # [N, B, V, C]\n", " }\n", " return out\n", "\n", "\n", "# ============================================================================\n", "# FASHIONMNIST MODEL\n", "# ============================================================================\n", "\n", "class FashionKSimplexNet(nn.Module):\n", " \"\"\"KSimplex classifier for FashionMNIST.\"\"\"\n", "\n", " def __init__(self, k: int = 4, vertex_dim: int = 16, num_classes: int = 10):\n", " super().__init__()\n", " self.k = k\n", "\n", " # Simple conv stem\n", " self.stem = nn.Sequential(\n", " nn.Conv2d(1, 32, 3, padding=1),\n", " nn.BatchNorm2d(32),\n", " nn.GELU(),\n", " nn.Conv2d(32, 64, 3, stride=2, padding=1), # 14x14\n", " nn.BatchNorm2d(64),\n", " nn.GELU(),\n", " nn.Conv2d(64, 128, 3, stride=2, padding=1), # 7x7\n", " nn.BatchNorm2d(128),\n", " nn.GELU(),\n", " )\n", "\n", " self.pool = nn.AdaptiveAvgPool2d(1)\n", " self.flat_dim = 128\n", "\n", " # KSimplex replaces final linear\n", " # n=1 (single parallel), input=128, output=num_classes\n", " self.ksimplex = WideKSimplexLinear(\n", " n=1,\n", " input_dim=self.flat_dim,\n", " output_dim=num_classes,\n", " k=k,\n", " vertex_dim=vertex_dim\n", " )\n", "\n", " def forward(self, x: torch.Tensor, return_geometry: bool = False):\n", " B = x.size(0)\n", "\n", " # Conv features\n", " h = self.stem(x)\n", " h = self.pool(h).flatten(1) # [B, 128]\n", "\n", " # Add N dimension for ksimplex (n=1)\n", " h = h.unsqueeze(0) # [1, B, 128]\n", "\n", " # KSimplex forward\n", " if return_geometry:\n", " out, geom = self.ksimplex(h, return_geometry=True)\n", " return out.squeeze(0), geom # [B, classes]\n", "\n", " out = self.ksimplex(h)\n", " return out.squeeze(0)\n", "\n", "\n", "# ============================================================================\n", "# TRAINING\n", "# ============================================================================\n", "\n", "def train():\n", " # Data\n", " transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.2860,), (0.3530,))\n", " ])\n", "\n", " train_data = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)\n", " test_data = datasets.FashionMNIST('./data', train=False, download=True, transform=transform)\n", "\n", " train_loader = DataLoader(train_data, batch_size=128, shuffle=True, num_workers=2)\n", " test_loader = DataLoader(test_data, batch_size=256, shuffle=False, num_workers=2)\n", "\n", " # Model\n", " model = FashionKSimplexNet(k=2, vertex_dim=16, num_classes=10).to(device)\n", "\n", " total_params = sum(p.numel() for p in model.parameters())\n", " print(f\"Model params: {total_params:,}\")\n", " print(f\"k={model.k}, vertices={model.k+1}, pairs={(model.k+1)*model.k//2}\")\n", "\n", " # Training\n", " optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)\n", " scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)\n", "\n", " best_acc = 0\n", "\n", " print(\"\\nTraining...\")\n", " print(\"=\"*70)\n", "\n", " for epoch in range(30):\n", " # Train\n", " model.train()\n", " train_loss, correct, total = 0, 0, 0\n", "\n", " for images, labels in train_loader:\n", " images, labels = images.to(device), labels.to(device)\n", "\n", " optimizer.zero_grad()\n", " logits = model(images)\n", " loss = F.cross_entropy(logits, labels)\n", " loss.backward()\n", " optimizer.step()\n", "\n", " train_loss += loss.item() * images.size(0)\n", " correct += (logits.argmax(1) == labels).sum().item()\n", " total += images.size(0)\n", "\n", " train_acc = correct / total\n", " train_loss = train_loss / total\n", "\n", " # Eval\n", " model.eval()\n", " correct, total = 0, 0\n", " vol_stats = []\n", "\n", " with torch.no_grad():\n", " for images, labels in test_loader:\n", " images, labels = images.to(device), labels.to(device)\n", " logits, geom = model(images, return_geometry=True)\n", " correct += (logits.argmax(1) == labels).sum().item()\n", " total += images.size(0)\n", " vol_stats.append(geom['vol_sq'].cpu())\n", "\n", " test_acc = correct / total\n", " scheduler.step()\n", "\n", " if test_acc > best_acc:\n", " best_acc = test_acc\n", "\n", " # Volume stats\n", " vol_all = torch.cat(vol_stats, dim=1).squeeze(0) # [B_total]\n", " vol_mean = vol_all.mean().item()\n", " vol_std = vol_all.std().item()\n", " vol_pos = (vol_all > 0).float().mean().item()\n", "\n", " if epoch % 5 == 0 or epoch == 29:\n", " print(f\"Epoch {epoch+1:2d} | Loss: {train_loss:.4f} | Train: {train_acc:.2%} | Test: {test_acc:.2%} | Best: {best_acc:.2%}\")\n", " print(f\" | Vol²: μ={vol_mean:.4f}, σ={vol_std:.4f}, valid={vol_pos:.1%}\")\n", "\n", " print(\"=\"*70)\n", " print(f\"Best Accuracy: {best_acc:.2%}\")\n", "\n", " return model\n", "\n", "\n", "# ============================================================================\n", "# GEOMETRY ANALYSIS\n", "# ============================================================================\n", "\n", "def analyze_geometry(model):\n", " print(\"\\n\" + \"=\"*70)\n", " print(\"GEOMETRY ANALYSIS\")\n", " print(\"=\"*70)\n", "\n", " transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.2860,), (0.3530,))\n", " ])\n", " test_data = datasets.FashionMNIST('./data', train=False, transform=transform)\n", " test_loader = DataLoader(test_data, batch_size=256, shuffle=False)\n", "\n", " model.eval()\n", "\n", " all_d_sq = []\n", " all_vol_sq = []\n", " all_coords = []\n", " all_labels = []\n", "\n", " with torch.no_grad():\n", " for images, labels in test_loader:\n", " images = images.to(device)\n", " _, geom = model(images, return_geometry=True)\n", " all_d_sq.append(geom['d_sq'].cpu())\n", " all_vol_sq.append(geom['vol_sq'].cpu())\n", " all_coords.append(geom['coords'].cpu())\n", " all_labels.append(labels)\n", "\n", " d_sq = torch.cat(all_d_sq, dim=1).squeeze(0) # [B, pairs]\n", " vol_sq = torch.cat(all_vol_sq, dim=1).squeeze(0) # [B]\n", " coords = torch.cat(all_coords, dim=1).squeeze(0) # [B, V, C]\n", " labels = torch.cat(all_labels, dim=0) # [B]\n", "\n", " print(f\"\\nDistances² shape: {list(d_sq.shape)}\")\n", " print(f\"Volume² shape: {list(vol_sq.shape)}\")\n", " print(f\"Coords shape: {list(coords.shape)}\")\n", "\n", " print(f\"\\nDistance² stats:\")\n", " for p in range(d_sq.shape[1]):\n", " print(f\" Pair {p}: μ={d_sq[:, p].mean():.4f}, σ={d_sq[:, p].std():.4f}\")\n", "\n", " print(f\"\\nVolume² stats:\")\n", " print(f\" μ={vol_sq.mean():.4f}, σ={vol_sq.std():.4f}\")\n", " print(f\" min={vol_sq.min():.4f}, max={vol_sq.max():.4f}\")\n", " print(f\" valid (>0): {(vol_sq > 0).float().mean():.1%}\")\n", "\n", " # Per-class volume\n", " print(f\"\\nPer-class Volume²:\")\n", " class_names = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat',\n", " 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Boot']\n", " for c in range(10):\n", " mask = labels == c\n", " vol_c = vol_sq[mask]\n", " print(f\" {class_names[c]:10s}: μ={vol_c.mean():.4f}, σ={vol_c.std():.4f}\")\n", "\n", " # Visualization\n", " fig, axes = plt.subplots(2, 3, figsize=(15, 10))\n", "\n", " # Distance distributions\n", " ax = axes[0, 0]\n", " for p in range(min(5, d_sq.shape[1])):\n", " ax.hist(d_sq[:, p].numpy(), bins=50, alpha=0.5, label=f'Pair {p}')\n", " ax.set_xlabel('Distance²')\n", " ax.set_ylabel('Count')\n", " ax.set_title('Pairwise Distance² Distributions')\n", " ax.legend()\n", "\n", " # Volume distribution\n", " ax = axes[0, 1]\n", " ax.hist(vol_sq.numpy(), bins=50, color='purple', alpha=0.7)\n", " ax.axvline(x=0, color='red', linestyle='--', label='Valid threshold')\n", " ax.set_xlabel('Volume²')\n", " ax.set_ylabel('Count')\n", " ax.set_title('CM Volume² Distribution')\n", " ax.legend()\n", "\n", " # Per-class volume\n", " ax = axes[0, 2]\n", " vol_by_class = [vol_sq[labels == c].mean().item() for c in range(10)]\n", " ax.bar(range(10), vol_by_class, color='steelblue')\n", " ax.set_xticks(range(10))\n", " ax.set_xticklabels([n[:6] for n in class_names], rotation=45, ha='right')\n", " ax.set_ylabel('Mean Volume²')\n", " ax.set_title('Volume² by Class')\n", "\n", " # Vertex coords (first 2 dims of first vertex)\n", " ax = axes[1, 0]\n", " for c in range(10):\n", " mask = labels == c\n", " ax.scatter(coords[mask, 0, 0].numpy(), coords[mask, 0, 1].numpy(),\n", " alpha=0.3, s=5, label=class_names[c])\n", " ax.set_xlabel('Vertex 0, Dim 0')\n", " ax.set_ylabel('Vertex 0, Dim 1')\n", " ax.set_title('Vertex 0 Coordinates by Class')\n", "\n", " # Distance vs volume\n", " ax = axes[1, 1]\n", " mean_d = d_sq.mean(dim=1)\n", " ax.scatter(mean_d.numpy(), vol_sq.numpy(), alpha=0.2, s=5)\n", " ax.set_xlabel('Mean Distance²')\n", " ax.set_ylabel('Volume²')\n", " ax.set_title('Distance vs Volume Relationship')\n", "\n", " # Simplex shape analysis\n", " ax = axes[1, 2]\n", " # Normalized distances (relative to first)\n", " d_normalized = d_sq / (d_sq[:, 0:1] + 1e-6)\n", " for p in range(1, min(5, d_sq.shape[1])):\n", " ax.hist(d_normalized[:, p].numpy(), bins=50, alpha=0.5, label=f'd{p}/d0')\n", " ax.set_xlabel('Normalized Distance')\n", " ax.set_ylabel('Count')\n", " ax.set_title('Simplex Shape (Normalized Distances)')\n", " ax.legend()\n", "\n", " plt.tight_layout()\n", " plt.savefig('ksimplex_geometry.png', dpi=150, bbox_inches='tight')\n", " plt.show()\n", "\n", " return d_sq, vol_sq, coords, labels\n", "\n", "\n", "# ============================================================================\n", "# RUN\n", "# ============================================================================\n", "\n", "if __name__ == \"__main__\":\n", " model = train()\n", " d_sq, vol_sq, coords, labels = analyze_geometry(model)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "cellView": "form", "id": "2bupmS5D7E_T", "outputId": "49122bf9-4662-4f32-f55c-9ad458abc6e8" }, "execution_count": 4, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Device: cuda\n", "Model params: 106,834\n", "k=2, vertices=3, pairs=3\n", "\n", "Training...\n", "======================================================================\n", "Epoch 1 | Loss: 1.3680 | Train: 48.35% | Test: 50.96% | Best: 50.96%\n", " | Vol²: μ=16.1708, σ=24.7744, valid=100.0%\n", "Epoch 6 | Loss: 0.4053 | Train: 85.61% | Test: 83.02% | Best: 83.02%\n", " | Vol²: μ=36.1256, σ=47.2951, valid=100.0%\n", "Epoch 11 | Loss: 0.3009 | Train: 89.18% | Test: 86.09% | Best: 86.37%\n", " | Vol²: μ=38.3405, σ=48.9495, valid=99.9%\n", "Epoch 16 | Loss: 0.2307 | Train: 91.56% | Test: 88.37% | Best: 88.37%\n", " | Vol²: μ=62.7198, σ=94.9429, valid=100.0%\n", "Epoch 21 | Loss: 0.1651 | Train: 93.91% | Test: 88.42% | Best: 88.69%\n", " | Vol²: μ=76.7864, σ=104.4838, valid=99.9%\n", "Epoch 26 | Loss: 0.1153 | Train: 95.79% | Test: 88.45% | Best: 88.73%\n", " | Vol²: μ=88.7963, σ=108.6138, valid=100.0%\n", "Epoch 30 | Loss: 0.0981 | Train: 96.48% | Test: 88.69% | Best: 88.73%\n", " | Vol²: μ=104.2887, σ=127.8587, valid=100.0%\n", "======================================================================\n", "Best Accuracy: 88.73%\n", "\n", "======================================================================\n", "GEOMETRY ANALYSIS\n", "======================================================================\n", "\n", "Distances² shape: [10000, 3]\n", "Volume² shape: [10000]\n", "Coords shape: [10000, 3, 2]\n", "\n", "Distance² stats:\n", " Pair 0: μ=96.5467, σ=107.5725\n", " Pair 1: μ=142.0466, σ=182.9677\n", " Pair 2: μ=185.9473, σ=176.7757\n", "\n", "Volume² stats:\n", " μ=104.2887, σ=127.8587\n", " min=-0.0006, max=930.7391\n", " valid (>0): 100.0%\n", "\n", "Per-class Volume²:\n", " T-shirt : μ=28.1097, σ=19.5074\n", " Trouser : μ=379.4676, σ=151.7682\n", " Pullover : μ=48.7069, σ=35.9273\n", " Dress : μ=68.1964, σ=27.3748\n", " Coat : μ=129.3032, σ=77.7689\n", " Sandal : μ=7.4617, σ=12.0020\n", " Shirt : μ=115.8989, σ=81.0484\n", " Sneaker : μ=7.1859, σ=9.3125\n", " Bag : μ=148.3312, σ=120.7901\n", " Boot : μ=110.2260, σ=56.2804\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "\n" }, "metadata": {} } ] }, { "cell_type": "code", "source": [ "#@title Hierarchical K-Simplex with Depth-Wise Attenuation\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets, transforms\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import math\n", "from typing import List, Tuple, Dict, Optional\n", "from itertools import combinations\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print(f\"Device: {device}\")\n", "\n", "# ============================================================================\n", "# HIERARCHICAL SIMPLEX COMPLEX\n", "# ============================================================================\n", "\n", "class SimplicialLevel(nn.Module):\n", " \"\"\"\n", " Operates on k-simplices at a specific level.\n", "\n", " k=0: vertices (points)\n", " k=1: edges (pairs)\n", " k=2: faces (triangles)\n", " k=3: tetrahedra\n", " ...\n", "\n", " Each level:\n", " - Receives features from (k-1)-simplices (boundary)\n", " - Computes k-simplex features via boundary aggregation\n", " - Computes CM volume at this level\n", " - Outputs attenuation based on geometric validity\n", " \"\"\"\n", "\n", " def __init__(self, k: int, num_vertices: int, feature_dim: int):\n", " super().__init__()\n", " self.k = k\n", " self.num_vertices = num_vertices\n", " self.feature_dim = feature_dim\n", "\n", " # Number of k-simplices in complete simplex on num_vertices\n", " # C(num_vertices, k+1)\n", " self.num_simplices = math.comb(num_vertices, k + 1)\n", "\n", " # Boundary has (k+1) faces, each a (k-1)-simplex\n", " self.boundary_size = k + 1\n", "\n", " # Build simplex indices\n", " self._build_simplex_indices()\n", "\n", " if k == 0:\n", " # Vertices: just project input\n", " self.proj = nn.Linear(feature_dim, feature_dim)\n", " else:\n", " # k > 0: aggregate from boundary (k+1 faces)\n", " self.boundary_agg = nn.Linear(self.boundary_size * feature_dim, feature_dim)\n", " self.boundary_gate = nn.Linear(self.boundary_size * feature_dim, self.boundary_size)\n", "\n", " # Per-simplex transform\n", " self.simplex_transform = nn.Sequential(\n", " nn.Linear(feature_dim, feature_dim),\n", " nn.LayerNorm(feature_dim),\n", " nn.GELU(),\n", " nn.Linear(feature_dim, feature_dim),\n", " )\n", "\n", " # Coordinate projection for CM computation (k-simplex needs k+1 vertices in R^k)\n", " self.to_coords = nn.Linear(feature_dim, k + 1) if k > 0 else None\n", "\n", " # Validity gate: attenuate based on volume\n", " self.validity_scale = nn.Parameter(torch.tensor(1.0))\n", " self.validity_bias = nn.Parameter(torch.tensor(0.0))\n", "\n", " def _build_simplex_indices(self):\n", " \"\"\"Build index tensors for k-simplices and their boundaries.\"\"\"\n", " # All k-simplices (combinations of k+1 vertices)\n", " simplices = list(combinations(range(self.num_vertices), self.k + 1))\n", " self.register_buffer('simplex_idx', torch.tensor(simplices)) # [num_simplices, k+1]\n", "\n", " if self.k > 0:\n", " # Boundary: for each k-simplex, its (k+1) boundary (k-1)-simplices\n", " # Boundary of [v0, v1, ..., vk] = [v1,...,vk], [v0,v2,...,vk], ...\n", " boundary_indices = []\n", " lower_simplices = list(combinations(range(self.num_vertices), self.k))\n", " lower_to_idx = {s: i for i, s in enumerate(lower_simplices)}\n", "\n", " for simplex in simplices:\n", " boundaries = []\n", " for i in range(self.k + 1):\n", " # Remove vertex i to get boundary face\n", " face = tuple(v for j, v in enumerate(simplex) if j != i)\n", " boundaries.append(lower_to_idx[face])\n", " boundary_indices.append(boundaries)\n", "\n", " self.register_buffer('boundary_idx', torch.tensor(boundary_indices)) # [num_simplices, k+1]\n", "\n", " def compute_cm_volume_sq(self, features: torch.Tensor) -> torch.Tensor:\n", " \"\"\"\n", " Compute CM volume² for each k-simplex.\n", " features: [B, num_simplices, feature_dim]\n", " returns: [B, num_simplices]\n", " \"\"\"\n", " if self.k == 0:\n", " # 0-simplices (points) have no volume, return 1 (always valid)\n", " return torch.ones(features.shape[0], features.shape[1], device=features.device)\n", "\n", " B, S, F = features.shape\n", "\n", " # Project to coordinates: [B, S, k+1] (k+1 coords for k-simplex vertices)\n", " coords = self.to_coords(features) # [B, S, k+1]\n", "\n", " # For k-simplex: we need distances between k+1 \"virtual\" vertices\n", " # Treat each coordinate as a vertex position in R^1, stack to get R^(k+1)\n", " # Actually: interpret the k+1 outputs as positions of k+1 vertices in R^1\n", " # Then compute pairwise distances\n", "\n", " # Reshape: each simplex has k+1 vertices, each in R^1\n", " # For proper CM: vertices in R^k minimum\n", " # Simpler: use the k+1 features as vertex positions in R^1\n", "\n", " # Pairwise squared distances via broadcasting\n", " # coords: [B, S, k+1] where k+1 = num vertices\n", " v = coords.unsqueeze(-1) # [B, S, k+1, 1]\n", " d_sq = (v - v.transpose(-1, -2)).pow(2).squeeze(-1).squeeze(-1) # Doesn't work for R^1\n", "\n", " # Better: embed in R^k properly\n", " # Let's do it correctly with feature_dim as embedding space\n", "\n", " # Actually, let's gather vertex features and compute real distances\n", " return self._compute_cm_from_vertices(features)\n", "\n", " def _compute_cm_from_vertices(self, features: torch.Tensor) -> torch.Tensor:\n", " \"\"\"\n", " Compute CM volume from simplex vertex features.\n", " Uses feature vectors as vertex positions.\n", " \"\"\"\n", " B, S, F = features.shape\n", " k = self.k\n", "\n", " if k == 0:\n", " return torch.ones(B, S, device=features.device)\n", "\n", " # For k=1 (edges): volume = distance\n", " # For k=2 (triangles): volume = area\n", " # etc.\n", "\n", " # Use learned projection to k-dimensional space\n", " coords = self.to_coords(features) # [B, S, k+1]\n", "\n", " # Interpret as k+1 points in R^1, compute pairwise distances\n", " # For proper k-volume, we need k+1 points in R^k\n", " # Simplification: use scalar coords, volume = product of differences (degenerate)\n", "\n", " # Better approach: compute Gram determinant\n", " # For vectors v1, ..., vk, volume² = det(G) where G_ij = \n", "\n", " # Coords: [B, S, k+1] - treat as k+1 scalar \"positions\"\n", " # Shift to origin: subtract first vertex\n", " if k >= 1:\n", " shifted = coords[..., 1:] - coords[..., 0:1] # [B, S, k]\n", "\n", " if k == 1:\n", " # Edge: length² = (v1 - v0)²\n", " vol_sq = shifted.pow(2).squeeze(-1) # [B, S]\n", " else:\n", " # Gram matrix: [B, S, k, k]\n", " gram = torch.einsum('bsi,bsj->bsij', shifted, shifted)\n", " vol_sq = torch.linalg.det(gram) # [B, S]\n", "\n", " return vol_sq\n", "\n", " return torch.ones(B, S, device=features.device)\n", "\n", " def forward(self, vertex_features: torch.Tensor, lower_features: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n", " \"\"\"\n", " vertex_features: [B, num_vertices, feature_dim]\n", " lower_features: [B, num_lower_simplices, feature_dim] (k-1 level output)\n", "\n", " returns:\n", " simplex_features: [B, num_simplices, feature_dim]\n", " volume_sq: [B, num_simplices]\n", " attenuation: [B, num_simplices] (0-1 validity gate)\n", " \"\"\"\n", " B = vertex_features.shape[0]\n", "\n", " if self.k == 0:\n", " # Vertices: project input\n", " features = self.proj(vertex_features) # [B, V, F]\n", " vol_sq = torch.ones(B, self.num_simplices, device=features.device)\n", " attn = torch.ones(B, self.num_simplices, device=features.device)\n", " return features, vol_sq, attn\n", "\n", " # k > 0: aggregate from boundary\n", " # Gather boundary (k-1)-simplex features\n", " # boundary_idx: [num_simplices, k+1]\n", " idx = self.boundary_idx.unsqueeze(0).unsqueeze(-1).expand(B, -1, -1, lower_features.shape[-1])\n", " boundary_feats = torch.gather(\n", " lower_features.unsqueeze(1).expand(-1, self.num_simplices, -1, -1),\n", " dim=2,\n", " index=idx\n", " ) # [B, num_simplices, k+1, feature_dim]\n", "\n", " # Flatten boundary\n", " boundary_flat = boundary_feats.flatten(-2) # [B, num_simplices, (k+1)*feature_dim]\n", "\n", " # Gated aggregation\n", " gates = F.softmax(self.boundary_gate(boundary_flat), dim=-1) # [B, S, k+1]\n", " gated_boundary = (boundary_feats * gates.unsqueeze(-1)).sum(dim=2) # [B, S, F]\n", "\n", " # Also use linear combination\n", " agg_features = self.boundary_agg(boundary_flat) # [B, S, F]\n", "\n", " # Combine with residual\n", " features = gated_boundary + agg_features\n", " features = self.simplex_transform(features)\n", "\n", " # Compute volume\n", " vol_sq = self._compute_cm_from_vertices(features) # [B, S]\n", "\n", " # Validity attenuation: sigmoid of scaled log-volume\n", " # Large positive volume -> attn near 1\n", " # Small/negative volume -> attn near 0\n", " log_vol = torch.log(vol_sq.abs() + 1e-8)\n", " attn = torch.sigmoid(self.validity_scale * log_vol + self.validity_bias)\n", "\n", " return features, vol_sq, attn\n", "\n", "\n", "class HierarchicalSimplexComplex(nn.Module):\n", " \"\"\"\n", " Full hierarchical simplex complex from k=0 to k=max_k.\n", "\n", " Each level builds on the previous:\n", " - k=0: vertices\n", " - k=1: edges (from vertex pairs)\n", " - k=2: triangles (from edge triples)\n", " - etc.\n", "\n", " Output combines all levels with learned weighting,\n", " attenuated by geometric validity at each level.\n", " \"\"\"\n", "\n", " def __init__(self, num_vertices: int, feature_dim: int, max_k: int):\n", " super().__init__()\n", " self.num_vertices = num_vertices\n", " self.feature_dim = feature_dim\n", " self.max_k = max_k\n", "\n", " # Levels k=0 to k=max_k\n", " self.levels = nn.ModuleList([\n", " SimplicialLevel(k, num_vertices, feature_dim)\n", " for k in range(max_k + 1)\n", " ])\n", "\n", " # Level importance weights (learned)\n", " self.level_weights = nn.Parameter(torch.ones(max_k + 1) / (max_k + 1))\n", "\n", " # Output aggregation\n", " total_simplices = sum(math.comb(num_vertices, k + 1) for k in range(max_k + 1))\n", " self.output_proj = nn.Linear(total_simplices, feature_dim)\n", "\n", " # Store simplex counts for output gathering\n", " self.simplex_counts = [math.comb(num_vertices, k + 1) for k in range(max_k + 1)]\n", "\n", " def forward(self, vertex_features: torch.Tensor) -> Tuple[torch.Tensor, Dict]:\n", " \"\"\"\n", " vertex_features: [B, num_vertices, feature_dim]\n", "\n", " returns:\n", " output: [B, feature_dim]\n", " diagnostics: dict with per-level info\n", " \"\"\"\n", " B = vertex_features.shape[0]\n", "\n", " all_features = []\n", " all_volumes = []\n", " all_attenuations = []\n", "\n", " lower_features = None\n", "\n", " for k, level in enumerate(self.levels):\n", " features, vol_sq, attn = level(vertex_features, lower_features)\n", "\n", " all_features.append(features)\n", " all_volumes.append(vol_sq)\n", " all_attenuations.append(attn)\n", "\n", " lower_features = features\n", "\n", " # Weighted combination across levels\n", " level_w = F.softmax(self.level_weights, dim=0)\n", "\n", " # Pool each level: [B, num_simplices_k] -> [B, 1] weighted by attenuation\n", " level_outputs = []\n", " for k, (features, attn) in enumerate(zip(all_features, all_attenuations)):\n", " # Attenuation-weighted mean\n", " weighted = (features * attn.unsqueeze(-1)).mean(dim=1) # [B, F]\n", " level_outputs.append(level_w[k] * weighted)\n", "\n", " # Sum across levels\n", " output = sum(level_outputs) # [B, F]\n", "\n", " diagnostics = {\n", " 'features': all_features,\n", " 'volumes': all_volumes,\n", " 'attenuations': all_attenuations,\n", " 'level_weights': level_w.detach(),\n", " }\n", "\n", " return output, diagnostics\n", "\n", "\n", "class DeepSimplicialNetwork(nn.Module):\n", " \"\"\"\n", " Stack multiple hierarchical simplex complexes.\n", " Each complex operates on the output of the previous.\n", " \"\"\"\n", "\n", " def __init__(self, num_vertices: int, feature_dim: int, max_k: int, num_complexes: int = 3):\n", " super().__init__()\n", " self.num_vertices = num_vertices\n", " self.feature_dim = feature_dim\n", " self.max_k = max_k\n", " self.num_complexes = num_complexes\n", "\n", " # Input projection to vertices\n", " self.input_proj = nn.Linear(feature_dim, num_vertices * feature_dim)\n", "\n", " # Stacked complexes\n", " self.complexes = nn.ModuleList([\n", " HierarchicalSimplexComplex(num_vertices, feature_dim, max_k)\n", " for _ in range(num_complexes)\n", " ])\n", "\n", " # Inter-complex transforms\n", " self.inter_transforms = nn.ModuleList([\n", " nn.Sequential(\n", " nn.Linear(feature_dim, feature_dim),\n", " nn.LayerNorm(feature_dim),\n", " nn.GELU(),\n", " )\n", " for _ in range(num_complexes - 1)\n", " ])\n", "\n", " # Vertex reconstruction for next complex\n", " self.vertex_reconstructs = nn.ModuleList([\n", " nn.Linear(feature_dim, num_vertices * feature_dim)\n", " for _ in range(num_complexes - 1)\n", " ])\n", "\n", " def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[Dict]]:\n", " \"\"\"\n", " x: [B, feature_dim]\n", " returns: [B, feature_dim], list of diagnostics\n", " \"\"\"\n", " B = x.shape[0]\n", "\n", " # Project to initial vertices\n", " vertices = self.input_proj(x).view(B, self.num_vertices, self.feature_dim)\n", "\n", " all_diagnostics = []\n", "\n", " for i, complex_layer in enumerate(self.complexes):\n", " output, diag = complex_layer(vertices)\n", " all_diagnostics.append(diag)\n", "\n", " if i < self.num_complexes - 1:\n", " # Transform and reconstruct vertices for next complex\n", " transformed = self.inter_transforms[i](output)\n", " vertices = self.vertex_reconstructs[i](transformed).view(B, self.num_vertices, self.feature_dim)\n", " # Residual\n", " vertices = vertices + self.input_proj(x).view(B, self.num_vertices, self.feature_dim)\n", "\n", " return output, all_diagnostics\n", "\n", "\n", "# ============================================================================\n", "# FASHIONMNIST MODEL\n", "# ============================================================================\n", "\n", "class FashionHierarchicalSimplex(nn.Module):\n", " def __init__(self, num_vertices: int = 6, feature_dim: int = 32,\n", " max_k: int = 4, num_complexes: int = 2, num_classes: int = 10):\n", " super().__init__()\n", "\n", " self.stem = nn.Sequential(\n", " nn.Conv2d(1, 32, 3, padding=1),\n", " nn.BatchNorm2d(32),\n", " nn.GELU(),\n", " nn.Conv2d(32, 64, 3, stride=2, padding=1),\n", " nn.BatchNorm2d(64),\n", " nn.GELU(),\n", " nn.Conv2d(64, 128, 3, stride=2, padding=1),\n", " nn.BatchNorm2d(128),\n", " nn.GELU(),\n", " )\n", "\n", " self.pool = nn.AdaptiveAvgPool2d(1)\n", "\n", " # Project to simplex feature dim\n", " self.pre_simplex = nn.Linear(128, feature_dim)\n", "\n", " # Hierarchical simplex\n", " self.simplex = DeepSimplicialNetwork(\n", " num_vertices=num_vertices,\n", " feature_dim=feature_dim,\n", " max_k=max_k,\n", " num_complexes=num_complexes\n", " )\n", "\n", " # Classification head\n", " self.head = nn.Linear(feature_dim, num_classes)\n", "\n", " self.num_vertices = num_vertices\n", " self.max_k = max_k\n", "\n", " def forward(self, x: torch.Tensor, return_diagnostics: bool = False):\n", " B = x.shape[0]\n", "\n", " h = self.stem(x)\n", " h = self.pool(h).flatten(1)\n", " h = self.pre_simplex(h)\n", "\n", " simplex_out, diagnostics = self.simplex(h)\n", "\n", " logits = self.head(simplex_out)\n", "\n", " if return_diagnostics:\n", " return logits, diagnostics\n", " return logits\n", "\n", "\n", "# ============================================================================\n", "# TRAINING\n", "# ============================================================================\n", "\n", "def train():\n", " transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.2860,), (0.3530,))\n", " ])\n", "\n", " train_data = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)\n", " test_data = datasets.FashionMNIST('./data', train=False, download=True, transform=transform)\n", "\n", " train_loader = DataLoader(train_data, batch_size=128, shuffle=True, num_workers=2)\n", " test_loader = DataLoader(test_data, batch_size=256, shuffle=False, num_workers=2)\n", "\n", " model = FashionHierarchicalSimplex(\n", " num_vertices=6,\n", " feature_dim=32,\n", " max_k=4, # Will use k=0,1,2,3,4\n", " num_complexes=2,\n", " num_classes=10\n", " ).to(device)\n", "\n", " total_params = sum(p.numel() for p in model.parameters())\n", " print(f\"Model params: {total_params:,}\")\n", " print(f\"Vertices: {model.num_vertices}, Max k: {model.max_k}\")\n", " print(f\"Simplices per level: {[math.comb(model.num_vertices, k+1) for k in range(model.max_k+1)]}\")\n", "\n", " optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)\n", " scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)\n", "\n", " best_acc = 0\n", "\n", " print(\"\\nTraining...\")\n", " print(\"=\"*80)\n", "\n", " for epoch in range(30):\n", " model.train()\n", " train_loss, correct, total = 0, 0, 0\n", "\n", " for images, labels in train_loader:\n", " images, labels = images.to(device), labels.to(device)\n", "\n", " optimizer.zero_grad()\n", " logits = model(images)\n", " loss = F.cross_entropy(logits, labels)\n", " loss.backward()\n", " optimizer.step()\n", "\n", " train_loss += loss.item() * images.size(0)\n", " correct += (logits.argmax(1) == labels).sum().item()\n", " total += images.size(0)\n", "\n", " train_acc = correct / total\n", " train_loss = train_loss / total\n", "\n", " model.eval()\n", " correct, total = 0, 0\n", " level_stats = None\n", "\n", " with torch.no_grad():\n", " for images, labels in test_loader:\n", " images, labels = images.to(device), labels.to(device)\n", " logits, diagnostics = model(images, return_diagnostics=True)\n", " correct += (logits.argmax(1) == labels).sum().item()\n", " total += images.size(0)\n", "\n", " # Collect stats from first complex\n", " if level_stats is None:\n", " level_stats = {\n", " 'weights': diagnostics[0]['level_weights'].cpu(),\n", " 'volumes': [v.mean().item() for v in diagnostics[0]['volumes']],\n", " 'attenuations': [a.mean().item() for a in diagnostics[0]['attenuations']],\n", " }\n", "\n", " test_acc = correct / total\n", " scheduler.step()\n", "\n", " if test_acc > best_acc:\n", " best_acc = test_acc\n", "\n", " if epoch % 5 == 0 or epoch == 29:\n", " print(f\"Epoch {epoch+1:2d} | Loss: {train_loss:.4f} | Train: {train_acc:.2%} | Test: {test_acc:.2%} | Best: {best_acc:.2%}\")\n", "\n", " weights_str = \", \".join([f\"k{k}={w:.3f}\" for k, w in enumerate(level_stats['weights'])])\n", " print(f\" | Level weights: [{weights_str}]\")\n", "\n", " vol_str = \", \".join([f\"k{k}={v:.4f}\" for k, v in enumerate(level_stats['volumes'])])\n", " print(f\" | Volumes: [{vol_str}]\")\n", "\n", " attn_str = \", \".join([f\"k{k}={a:.3f}\" for k, a in enumerate(level_stats['attenuations'])])\n", " print(f\" | Attenuation: [{attn_str}]\")\n", "\n", " print(\"=\"*80)\n", " print(f\"Best Accuracy: {best_acc:.2%}\")\n", "\n", " return model\n", "\n", "\n", "# ============================================================================\n", "# ANALYSIS\n", "# ============================================================================\n", "\n", "def analyze(model):\n", " print(\"\\n\" + \"=\"*80)\n", " print(\"HIERARCHICAL SIMPLEX ANALYSIS\")\n", " print(\"=\"*80)\n", "\n", " transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.2860,), (0.3530,))\n", " ])\n", " test_data = datasets.FashionMNIST('./data', train=False, transform=transform)\n", " test_loader = DataLoader(test_data, batch_size=256, shuffle=False)\n", "\n", " model.eval()\n", "\n", " all_diagnostics = []\n", " all_labels = []\n", "\n", " with torch.no_grad():\n", " for images, labels in test_loader:\n", " images = images.to(device)\n", " _, diag = model(images, return_diagnostics=True)\n", " all_diagnostics.append(diag)\n", " all_labels.append(labels)\n", "\n", " labels = torch.cat(all_labels)\n", "\n", " # Aggregate per-level stats\n", " max_k = model.max_k\n", "\n", " fig, axes = plt.subplots(2, 3, figsize=(15, 10))\n", "\n", " # 1. Level weights\n", " ax = axes[0, 0]\n", " weights = all_diagnostics[0][0]['level_weights'].numpy()\n", " ax.bar(range(max_k + 1), weights, color='steelblue')\n", " ax.set_xlabel('k-level')\n", " ax.set_ylabel('Weight')\n", " ax.set_title('Learned Level Weights')\n", " ax.set_xticks(range(max_k + 1))\n", "\n", " # 2. Volume by level\n", " ax = axes[0, 1]\n", " vol_by_level = []\n", " for k in range(max_k + 1):\n", " vols = torch.cat([d[0]['volumes'][k].flatten() for d in all_diagnostics])\n", " vol_by_level.append(vols.cpu())\n", "\n", " ax.boxplot([v.numpy() for v in vol_by_level], labels=[f'k={k}' for k in range(max_k + 1)])\n", " ax.set_xlabel('k-level')\n", " ax.set_ylabel('Volume²')\n", " ax.set_title('Volume² Distribution by Level')\n", " ax.set_yscale('symlog')\n", "\n", " # 3. Attenuation by level\n", " ax = axes[0, 2]\n", " attn_by_level = []\n", " for k in range(max_k + 1):\n", " attns = torch.cat([d[0]['attenuations'][k].flatten() for d in all_diagnostics])\n", " attn_by_level.append(attns.mean().item())\n", "\n", " ax.bar(range(max_k + 1), attn_by_level, color='coral')\n", " ax.set_xlabel('k-level')\n", " ax.set_ylabel('Mean Attenuation')\n", " ax.set_title('Attenuation by Level (1=fully valid)')\n", " ax.set_xticks(range(max_k + 1))\n", " ax.set_ylim(0, 1)\n", "\n", " # 4. Volume by class (for highest k)\n", " ax = axes[1, 0]\n", " class_names = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat',\n", " 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Boot']\n", "\n", " final_vols = vol_by_level[max_k]\n", " vol_by_class = []\n", " for c in range(10):\n", " mask = labels == c\n", " vol_by_class.append(final_vols[mask].mean().item())\n", "\n", " ax.bar(range(10), vol_by_class, color='purple', alpha=0.7)\n", " ax.set_xticks(range(10))\n", " ax.set_xticklabels([n[:6] for n in class_names], rotation=45, ha='right')\n", " ax.set_ylabel(f'Mean Volume² (k={max_k})')\n", " ax.set_title(f'Top-Level Volume by Class')\n", "\n", " # 5. Attenuation correlation with accuracy\n", " ax = axes[1, 1]\n", " # For each sample, get total attenuation\n", " total_attns = []\n", " for d in all_diagnostics:\n", " batch_attn = sum(d[0]['attenuations'][k].mean(dim=1) for k in range(max_k + 1))\n", " total_attns.append(batch_attn.cpu())\n", " total_attns = torch.cat(total_attns)\n", "\n", " ax.hist(total_attns.numpy(), bins=50, color='green', alpha=0.7)\n", " ax.set_xlabel('Total Attenuation')\n", " ax.set_ylabel('Count')\n", " ax.set_title('Distribution of Total Attenuation')\n", "\n", " # 6. Level importance across complexes\n", " ax = axes[1, 2]\n", " num_complexes = len(all_diagnostics[0])\n", " for c_idx in range(num_complexes):\n", " w = all_diagnostics[0][c_idx]['level_weights'].numpy()\n", " ax.plot(range(max_k + 1), w, 'o-', label=f'Complex {c_idx+1}')\n", " ax.set_xlabel('k-level')\n", " ax.set_ylabel('Weight')\n", " ax.set_title('Level Weights Across Complexes')\n", " ax.legend()\n", " ax.set_xticks(range(max_k + 1))\n", "\n", " plt.tight_layout()\n", " plt.savefig('hierarchical_simplex_analysis.png', dpi=150, bbox_inches='tight')\n", " plt.show()\n", "\n", " # Print summary\n", " print(f\"\\nLevel Summary:\")\n", " for k in range(max_k + 1):\n", " num_simp = math.comb(model.num_vertices, k + 1)\n", " vol_mean = vol_by_level[k].mean().item()\n", " vol_std = vol_by_level[k].std().item()\n", " attn = attn_by_level[k]\n", " print(f\" k={k}: {num_simp:3d} simplices | Vol²: μ={vol_mean:8.4f}, σ={vol_std:8.4f} | Attn: {attn:.3f}\")\n", "\n", "\n", "# ============================================================================\n", "# RUN\n", "# ============================================================================\n", "\n", "if __name__ == \"__main__\":\n", " model = train()\n", " analyze(model)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 668 }, "cellView": "form", "id": "pCctl4MWB9XS", "outputId": "d476e980-59ec-4aab-ece0-8bb77ea8214a" }, "execution_count": 5, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Device: cuda\n", "Model params: 172,640\n", "Vertices: 6, Max k: 4\n", "Simplices per level: [6, 15, 20, 15, 6]\n", "\n", "Training...\n", "================================================================================\n", "Epoch 1 | Loss: 1.0471 | Train: 58.45% | Test: 75.72% | Best: 75.72%\n", " | Level weights: [k0=0.193, k1=0.208, k2=0.199, k3=0.199, k4=0.199]\n", " | Volumes: [k0=1.0000, k1=11.4210, k2=0.0000, k3=-0.0000, k4=-0.0000]\n", " | Attenuation: [k0=1.000, k1=0.773, k2=0.000, k3=0.000, k4=0.000]\n", "Epoch 6 | Loss: 0.3745 | Train: 86.54% | Test: 85.42% | Best: 85.42%\n", " | Level weights: [k0=0.194, k1=0.204, k2=0.200, k3=0.200, k4=0.200]\n", " | Volumes: [k0=1.0000, k1=7.1449, k2=-0.0000, k3=-0.0000, k4=0.0000]\n", " | Attenuation: [k0=1.000, k1=0.665, k2=0.000, k3=0.000, k4=0.000]\n", "Epoch 11 | Loss: 0.2896 | Train: 89.61% | Test: 86.59% | Best: 86.59%\n", " | Level weights: [k0=0.202, k1=0.203, k2=0.198, k3=0.198, k4=0.198]\n", " | Volumes: [k0=1.0000, k1=5.1970, k2=0.0000, k3=-0.0000, k4=-0.0000]\n", " | Attenuation: [k0=1.000, k1=0.688, k2=0.000, k3=0.000, k4=0.000]\n" ] }, { "output_type": "error", "ename": "KeyboardInterrupt", "evalue": "", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[0;32m/tmp/ipython-input-382490007.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 666\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 667\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"__main__\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 668\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 669\u001b[0m \u001b[0manalyze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/tmp/ipython-input-382490007.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m()\u001b[0m\n\u001b[1;32m 477\u001b[0m \u001b[0mimages\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mimages\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 478\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 479\u001b[0;31m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 480\u001b[0m \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimages\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 481\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcross_entropy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/_compile.py\u001b[0m in \u001b[0;36minner\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__dynamo_disable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdisable_fn\u001b[0m \u001b[0;31m# type: ignore[attr-defined]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 53\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mdisable_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 54\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minner\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py\u001b[0m in \u001b[0;36m_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1042\u001b[0m \u001b[0m_maybe_set_eval_frame\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_callback_from_stance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallback\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1043\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1044\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1045\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1046\u001b[0m \u001b[0mset_eval_frame\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/optim/optimizer.py\u001b[0m in \u001b[0;36mzero_grad\u001b[0;34m(self, set_to_none)\u001b[0m\n\u001b[1;32m 1033\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1034\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mset_to_none\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1035\u001b[0;31m \u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1036\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1037\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad_fn\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ] }, { "cell_type": "code", "source": [ "#@title True CM KSimplex with Geometric Alignment\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets, transforms\n", "import math\n", "from typing import List, Dict, Tuple\n", "from itertools import combinations\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print(f\"Device: {device}\")\n", "\n", "# ============================================================================\n", "# CAYLEY-MENGER CORE\n", "# ============================================================================\n", "\n", "class CayleyMengerValidator(nn.Module):\n", " \"\"\"\n", " Computes true Cayley-Menger determinant for k-simplex validation.\n", "\n", " For k-simplex with k+1 vertices:\n", " - CM matrix is (k+2) × (k+2)\n", " - det(CM) relates to squared volume\n", " - Valid iff (-1)^(k+1) * det > 0\n", " \"\"\"\n", "\n", " def __init__(self, k: int):\n", " super().__init__()\n", " self.k = k\n", " self.num_vertices = k + 1\n", " self.cm_size = k + 2\n", "\n", " # Pairs for distance computation\n", " pairs = list(combinations(range(self.num_vertices), 2))\n", " self.num_pairs = len(pairs)\n", " self.register_buffer('pair_i', torch.tensor([p[0] for p in pairs]))\n", " self.register_buffer('pair_j', torch.tensor([p[1] for p in pairs]))\n", "\n", " # Volume prefactor: (-1)^(k+1) / (2^k * (k!)^2)\n", " sign = (-1.0) ** (self.k + 1)\n", " factorial_k = math.factorial(self.k)\n", " self.prefactor = sign / ((2.0 ** self.k) * (factorial_k ** 2))\n", "\n", " def pairwise_distances_sq(self, vertices: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n", " \"\"\"\n", " Compute TRUE Euclidean squared distances from vertex positions.\n", "\n", " vertices: [..., num_vertices, embed_dim]\n", " returns: d_sq_pairs [..., num_pairs], d_sq_matrix [..., V, V]\n", " \"\"\"\n", " # Gram matrix: G_ij = \n", " gram = torch.einsum('...ve,...we->...vw', vertices, vertices)\n", "\n", " # Squared norms on diagonal\n", " norms_sq = torch.diagonal(gram, dim1=-2, dim2=-1)\n", "\n", " # d²(i,j) = ||v_i||² + ||v_j||² - 2\n", " d_sq_matrix = norms_sq.unsqueeze(-1) + norms_sq.unsqueeze(-2) - 2 * gram\n", "\n", " # Numerical stability\n", " d_sq_matrix = F.relu(d_sq_matrix)\n", "\n", " # Extract pairs\n", " d_sq_pairs = d_sq_matrix[..., self.pair_i, self.pair_j]\n", "\n", " return d_sq_pairs, d_sq_matrix\n", "\n", " def cayley_menger_det(self, d_sq_matrix: torch.Tensor) -> torch.Tensor:\n", " \"\"\"\n", " Compute CM determinant from squared distance matrix.\n", "\n", " d_sq_matrix: [..., V, V]\n", " returns: [...] determinant\n", " \"\"\"\n", " shape = d_sq_matrix.shape[:-2]\n", " V = d_sq_matrix.shape[-1]\n", "\n", " # Build CM matrix: (V+1) × (V+1)\n", " cm = torch.zeros(*shape, V + 1, V + 1, device=d_sq_matrix.device, dtype=d_sq_matrix.dtype)\n", "\n", " # Border: first row/col = 1 (except [0,0] = 0)\n", " cm[..., 0, 1:] = 1.0\n", " cm[..., 1:, 0] = 1.0\n", "\n", " # Interior: squared distances\n", " cm[..., 1:, 1:] = d_sq_matrix\n", "\n", " # Determinant\n", " det = torch.linalg.det(cm)\n", "\n", " return det\n", "\n", " def volume_squared(self, d_sq_matrix: torch.Tensor) -> torch.Tensor:\n", " \"\"\"Compute squared volume from CM determinant.\"\"\"\n", " det = self.cayley_menger_det(d_sq_matrix)\n", " return self.prefactor * det\n", "\n", " def is_valid(self, vol_sq: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:\n", " \"\"\"Check if simplex is geometrically valid (positive volume²).\"\"\"\n", " return vol_sq > eps\n", "\n", " def forward(self, vertices: torch.Tensor) -> Dict[str, torch.Tensor]:\n", " \"\"\"\n", " Full CM computation.\n", "\n", " vertices: [..., num_vertices, embed_dim]\n", " \"\"\"\n", " d_sq_pairs, d_sq_matrix = self.pairwise_distances_sq(vertices)\n", " det = self.cayley_menger_det(d_sq_matrix)\n", " vol_sq = self.prefactor * det\n", " valid = self.is_valid(vol_sq)\n", "\n", " return {\n", " 'd_sq_pairs': d_sq_pairs,\n", " 'd_sq_matrix': d_sq_matrix,\n", " 'cm_det': det,\n", " 'vol_sq': vol_sq,\n", " 'valid': valid,\n", " }\n", "\n", "\n", "# ============================================================================\n", "# GEOMETRIC SIMPLEX LEVEL\n", "# ============================================================================\n", "\n", "class GeometricSimplexLevel(nn.Module):\n", " \"\"\"\n", " Level k of the simplicial complex.\n", "\n", " Takes features from level k-1, constructs k-simplices,\n", " computes TRUE CM geometry, and outputs with validity attenuation.\n", "\n", " Key: vertices are ACTUAL positions in R^embed_dim, not learned proxies.\n", " \"\"\"\n", "\n", " def __init__(self, k: int, n_in: int, n_out: int, feature_dim: int, embed_dim: int = None):\n", " super().__init__()\n", " self.k = k\n", " self.n_in = n_in\n", " self.n_out = n_out\n", " self.feature_dim = feature_dim\n", " self.embed_dim = embed_dim or max(k + 1, 4) # At least k+1 dims for k-simplex\n", " self.num_vertices = k + 1\n", "\n", " # Cayley-Menger validator\n", " self.cm = CayleyMengerValidator(k)\n", "\n", " # Vertex selection: which inputs contribute to each output simplex\n", " # Each output selects k+1 vertices from n_in inputs\n", " self.vertex_logits = nn.Parameter(torch.randn(n_out, self.num_vertices, n_in) * 0.02)\n", "\n", " # Feature to embedding: project features to geometric coordinates\n", " self.to_embed = nn.Linear(feature_dim, self.embed_dim)\n", "\n", " # Feature transform (operates on aggregated features, not geometry)\n", " self.feature_mlp = nn.Sequential(\n", " nn.Linear(feature_dim, feature_dim * 2),\n", " nn.LayerNorm(feature_dim * 2),\n", " nn.GELU(),\n", " nn.Linear(feature_dim * 2, feature_dim),\n", " )\n", "\n", " # Geometry-conditioned gating\n", " # Input: distances + volume -> gate\n", " gate_input_dim = self.cm.num_pairs + 1 # pairs + volume\n", " self.geom_gate = nn.Sequential(\n", " nn.Linear(gate_input_dim, feature_dim),\n", " nn.Sigmoid(),\n", " )\n", "\n", " # Validity-based attenuation (learned scaling)\n", " self.valid_scale = nn.Parameter(torch.tensor(1.0))\n", " self.valid_bias = nn.Parameter(torch.tensor(0.0))\n", "\n", " # Output projection if sizes differ\n", " if n_in != n_out:\n", " self.out_proj = nn.Linear(n_in * feature_dim, n_out * feature_dim)\n", " else:\n", " self.out_proj = None\n", "\n", " def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict]:\n", " \"\"\"\n", " x: [B, n_in, feature_dim]\n", " returns: [B, n_out, feature_dim], geometry_info\n", " \"\"\"\n", " B, N, F = x.shape\n", "\n", " # === STEP 1: Select vertices for each output simplex ===\n", " # vertex_logits: [n_out, num_vertices, n_in]\n", " # Softmax over inputs for each vertex position\n", " vertex_weights = F.softmax(self.vertex_logits, dim=-1) # [n_out, V, n_in]\n", "\n", " # Gather weighted input features for each simplex's vertices\n", " # x: [B, n_in, F] -> selected: [B, n_out, V, F]\n", " selected = torch.einsum('ovn,bnf->bovf', vertex_weights, x)\n", "\n", " # === STEP 2: Compute geometric embeddings ===\n", " # Project features to embedding space for CM computation\n", " embeddings = self.to_embed(selected) # [B, n_out, V, embed_dim]\n", "\n", " # === STEP 3: TRUE Cayley-Menger computation ===\n", " cm_results = self.cm(embeddings)\n", " d_sq = cm_results['d_sq_pairs'] # [B, n_out, num_pairs]\n", " vol_sq = cm_results['vol_sq'] # [B, n_out]\n", " valid = cm_results['valid'] # [B, n_out]\n", "\n", " # === STEP 4: Geometry-conditioned feature transform ===\n", " # Concatenate geometric info\n", " geom_info = torch.cat([d_sq, vol_sq.unsqueeze(-1)], dim=-1) # [B, n_out, pairs+1]\n", "\n", " # Compute geometry-based gate\n", " geom_gate = self.geom_gate(geom_info) # [B, n_out, F]\n", "\n", " # Aggregate vertex features (mean over vertices in each simplex)\n", " agg_features = selected.mean(dim=2) # [B, n_out, F]\n", "\n", " # Transform with geometric gating\n", " transformed = self.feature_mlp(agg_features) * geom_gate\n", "\n", " # === STEP 5: Validity-based attenuation ===\n", " # Attenuate based on geometric validity (vol_sq > 0)\n", " log_vol = torch.log(vol_sq.abs() + 1e-8)\n", " attenuation = torch.sigmoid(self.valid_scale * log_vol + self.valid_bias)\n", "\n", " # Apply attenuation\n", " out = transformed * attenuation.unsqueeze(-1)\n", "\n", " # === STEP 6: Residual connection ===\n", " if self.out_proj is not None:\n", " residual = self.out_proj(x.flatten(1)).view(B, self.n_out, F)\n", " else:\n", " residual = x\n", "\n", " out = out + 0.1 * residual\n", "\n", " return out, {\n", " 'k': self.k,\n", " 'd_sq': d_sq,\n", " 'vol_sq': vol_sq,\n", " 'valid': valid,\n", " 'attenuation': attenuation,\n", " 'vertex_weights': vertex_weights.detach(),\n", " }\n", "\n", "\n", "# ============================================================================\n", "# GEOMETRIC ALIGNMENT LOSS\n", "# ============================================================================\n", "\n", "class GeometricAlignmentLoss(nn.Module):\n", " \"\"\"\n", " Ensures geometric coherence across levels.\n", "\n", " 1. Volume consistency: higher k should have proportional volumes\n", " 2. Validity maximization: encourage valid simplices\n", " 3. Distance regularity: prevent degenerate configurations\n", " \"\"\"\n", "\n", " def __init__(self, vol_weight: float = 0.1, valid_weight: float = 0.1, reg_weight: float = 0.01):\n", " super().__init__()\n", " self.vol_weight = vol_weight\n", " self.valid_weight = valid_weight\n", " self.reg_weight = reg_weight\n", "\n", " def forward(self, level_infos: List[Dict]) -> Dict[str, torch.Tensor]:\n", " losses = {}\n", "\n", " # 1. Validity loss: encourage vol_sq > 0\n", " validity_loss = 0\n", " for info in level_infos:\n", " vol_sq = info['vol_sq']\n", " # Penalize negative volumes (invalid simplices)\n", " validity_loss = validity_loss + F.relu(-vol_sq).mean()\n", " losses['validity'] = self.valid_weight * validity_loss\n", "\n", " # 2. Volume consistency: volumes should scale appropriately with k\n", " if len(level_infos) > 1:\n", " vol_consistency = 0\n", " for i in range(1, len(level_infos)):\n", " vol_prev = level_infos[i-1]['vol_sq'].abs().mean()\n", " vol_curr = level_infos[i]['vol_sq'].abs().mean()\n", " # Higher k should have smaller volume (more constrained)\n", " # But not too small (degenerate)\n", " ratio = vol_curr / (vol_prev + 1e-8)\n", " # Penalize extreme ratios\n", " vol_consistency = vol_consistency + (ratio - 0.5).pow(2)\n", " losses['consistency'] = self.vol_weight * vol_consistency\n", "\n", " # 3. Distance regularity: prevent all distances from collapsing\n", " reg_loss = 0\n", " for info in level_infos:\n", " d_sq = info['d_sq']\n", " # Penalize very small distances (degenerate)\n", " reg_loss = reg_loss + F.relu(0.1 - d_sq.mean()).pow(2)\n", " # Penalize very large variance (unstable)\n", " reg_loss = reg_loss + d_sq.var() * 0.01\n", " losses['regularity'] = self.reg_weight * reg_loss\n", "\n", " losses['total'] = sum(losses.values())\n", "\n", " return losses\n", "\n", "\n", "# ============================================================================\n", "# FULL COMPLEX WITH GEOMETRIC ALIGNMENT\n", "# ============================================================================\n", "\n", "class GeometricSimplexComplex(nn.Module):\n", " \"\"\"\n", " Full hierarchical complex with TRUE CM geometry and alignment.\n", "\n", " Each level k builds k-simplices from (k-1) outputs.\n", " Complexity grows: n_k > n_{k-1}\n", " \"\"\"\n", "\n", " def __init__(self, n_base: int, feature_dim: int, max_k: int,\n", " growth: float = 1.5, embed_dim: int = 8):\n", " super().__init__()\n", " self.n_base = n_base\n", " self.feature_dim = feature_dim\n", " self.max_k = max_k\n", " self.embed_dim = embed_dim\n", "\n", " # Compute growing sizes\n", " sizes = [n_base]\n", " for k in range(max_k):\n", " increment = max(1, int(growth * (k + 1)))\n", " sizes.append(sizes[-1] + increment)\n", " self.sizes = sizes\n", "\n", " # Input projection\n", " self.proj_in = nn.Linear(feature_dim, n_base * feature_dim)\n", "\n", " # Geometric levels\n", " self.levels = nn.ModuleList()\n", " for k in range(max_k):\n", " self.levels.append(GeometricSimplexLevel(\n", " k=k,\n", " n_in=sizes[k],\n", " n_out=sizes[k+1],\n", " feature_dim=feature_dim,\n", " embed_dim=embed_dim,\n", " ))\n", "\n", " # Geometric alignment loss\n", " self.align_loss = GeometricAlignmentLoss()\n", "\n", " # Level importance\n", " self.level_weights = nn.Parameter(torch.ones(max_k) / max_k)\n", "\n", " def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict]:\n", " \"\"\"x: [B, feature_dim]\"\"\"\n", " B = x.shape[0]\n", "\n", " # Initial vertices\n", " h = self.proj_in(x).view(B, self.n_base, self.feature_dim)\n", "\n", " all_outputs = []\n", " all_infos = []\n", "\n", " for level in self.levels:\n", " h, info = level(h)\n", " all_outputs.append(h)\n", " all_infos.append(info)\n", "\n", " # Weighted combination with validity-aware pooling\n", " w = F.softmax(self.level_weights, dim=0)\n", "\n", " pooled = []\n", " for i, (out, info) in enumerate(zip(all_outputs, all_infos)):\n", " attn = info['attenuation']\n", " # Validity-weighted mean\n", " weighted = (out * attn.unsqueeze(-1)).sum(dim=1) / (attn.sum(dim=1, keepdim=True) + 1e-8)\n", " pooled.append(w[i] * weighted)\n", "\n", " result = sum(pooled)\n", "\n", " # Compute alignment loss\n", " align_losses = self.align_loss(all_infos)\n", "\n", " return result, {\n", " 'sizes': self.sizes,\n", " 'level_infos': all_infos,\n", " 'level_weights': w.detach(),\n", " 'align_losses': align_losses,\n", " }\n", "\n", "\n", "# ============================================================================\n", "# MODEL\n", "# ============================================================================\n", "\n", "class FashionGeometricSimplex(nn.Module):\n", " def __init__(self, n_base=4, feature_dim=32, max_k=5, growth=1.5,\n", " embed_dim=8, n_complex=2):\n", " super().__init__()\n", "\n", " self.stem = nn.Sequential(\n", " nn.Conv2d(1, 32, 3, padding=1),\n", " nn.BatchNorm2d(32),\n", " nn.GELU(),\n", " nn.Conv2d(32, 64, 3, stride=2, padding=1),\n", " nn.BatchNorm2d(64),\n", " nn.GELU(),\n", " nn.Conv2d(64, 128, 3, stride=2, padding=1),\n", " nn.BatchNorm2d(128),\n", " nn.GELU(),\n", " )\n", " self.pool = nn.AdaptiveAvgPool2d(1)\n", " self.proj = nn.Linear(128, feature_dim)\n", "\n", " self.complexes = nn.ModuleList()\n", " self.norms = nn.ModuleList()\n", " for i in range(n_complex):\n", " cpx = GeometricSimplexComplex(n_base, feature_dim, max_k, growth, embed_dim)\n", " self.complexes.append(cpx)\n", " self.norms.append(nn.LayerNorm(feature_dim))\n", " print(f\"Complex {i+1} sizes: {cpx.sizes}\")\n", "\n", " self.head = nn.Linear(feature_dim, 10)\n", "\n", " def forward(self, x, return_diag=False):\n", " h = self.stem(x)\n", " h = self.pool(h).flatten(1)\n", " h = self.proj(h)\n", "\n", " all_diag = []\n", " total_align_loss = 0\n", "\n", " for cpx, norm in zip(self.complexes, self.norms):\n", " out, diag = cpx(h)\n", " all_diag.append(diag)\n", " total_align_loss = total_align_loss + diag['align_losses']['total']\n", " h = norm(h + out)\n", "\n", " logits = self.head(h)\n", "\n", " if return_diag:\n", " return logits, all_diag, total_align_loss\n", " return logits, total_align_loss\n", "\n", "\n", "# ============================================================================\n", "# TRAINING\n", "# ============================================================================\n", "\n", "def train():\n", " transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.2860,), (0.3530,))\n", " ])\n", "\n", " train_data = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)\n", " test_data = datasets.FashionMNIST('./data', train=False, download=True, transform=transform)\n", "\n", " train_loader = DataLoader(train_data, batch_size=128, shuffle=True, num_workers=2)\n", " test_loader = DataLoader(test_data, batch_size=256, shuffle=False, num_workers=2)\n", "\n", " print(\"\\nBuilding model...\")\n", " model = FashionGeometricSimplex(\n", " n_base=4,\n", " feature_dim=32,\n", " max_k=5,\n", " growth=1.5,\n", " embed_dim=8,\n", " n_complex=2,\n", " ).to(device)\n", "\n", " params = sum(p.numel() for p in model.parameters())\n", " print(f\"Parameters: {params:,}\")\n", "\n", " optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)\n", " scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)\n", "\n", " best = 0\n", " print(\"\\nTraining...\")\n", " print(\"=\" * 90)\n", "\n", " for epoch in range(30):\n", " model.train()\n", " loss_sum, align_sum, correct, total = 0, 0, 0, 0\n", "\n", " for imgs, labels in train_loader:\n", " imgs, labels = imgs.to(device), labels.to(device)\n", "\n", " optimizer.zero_grad()\n", " logits, align_loss = model(imgs)\n", "\n", " ce_loss = F.cross_entropy(logits, labels)\n", " total_loss = ce_loss + align_loss\n", "\n", " total_loss.backward()\n", " optimizer.step()\n", "\n", " loss_sum += ce_loss.item() * imgs.size(0)\n", " align_sum += align_loss.item() * imgs.size(0)\n", " correct += (logits.argmax(1) == labels).sum().item()\n", " total += imgs.size(0)\n", "\n", " train_acc = correct / total\n", " train_loss = loss_sum / total\n", " train_align = align_sum / total\n", "\n", " # Eval\n", " model.eval()\n", " correct, total = 0, 0\n", " sample_diag = None\n", "\n", " with torch.no_grad():\n", " for imgs, labels in test_loader:\n", " imgs, labels = imgs.to(device), labels.to(device)\n", " logits, diag, _ = model(imgs, return_diag=True)\n", " correct += (logits.argmax(1) == labels).sum().item()\n", " total += imgs.size(0)\n", " if sample_diag is None:\n", " sample_diag = diag\n", "\n", " test_acc = correct / total\n", " scheduler.step()\n", "\n", " if test_acc > best:\n", " best = test_acc\n", "\n", " if epoch % 5 == 0 or epoch == 29:\n", " print(f\"Ep {epoch+1:2d} | CE {train_loss:.4f} | Align {train_align:.4f} | Train {train_acc:.2%} | Test {test_acc:.2%} | Best {best:.2%}\")\n", "\n", " # Geometry stats\n", " d = sample_diag[0]\n", " for k, info in enumerate(d['level_infos']):\n", " vol = info['vol_sq'].mean().item()\n", " valid_pct = info['valid'].float().mean().item()\n", " attn = info['attenuation'].mean().item()\n", " d_mean = info['d_sq'].mean().item()\n", " print(f\" k={k}: vol²={vol:8.4f} | valid={valid_pct:.1%} | attn={attn:.3f} | d²={d_mean:.4f}\")\n", "\n", " print(\"=\" * 90)\n", " print(f\"Best: {best:.2%}\")\n", " return model\n", "\n", "\n", "if __name__ == \"__main__\":\n", " model = train()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 477 }, "cellView": "form", "id": "llm5QaVACpL8", "outputId": "20adb10e-dd5c-4005-9407-1ec53a89804b" }, "execution_count": 1, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Device: cuda\n", "\n", "Building model...\n", "Complex 1 sizes: [4, 5, 8, 12, 18, 25]\n", "Complex 2 sizes: [4, 5, 8, 12, 18, 25]\n", "Parameters: 1,848,756\n", "\n", "Training...\n", "==========================================================================================\n" ] }, { "output_type": "error", "ename": "AttributeError", "evalue": "'int' object has no attribute 'softmax'", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m/tmp/ipython-input-251580671.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 541\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 542\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"__main__\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 543\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m/tmp/ipython-input-251580671.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m()\u001b[0m\n\u001b[1;32m 486\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 487\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 488\u001b[0;31m \u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0malign_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimgs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 489\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 490\u001b[0m \u001b[0mce_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcross_entropy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1773\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1774\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1775\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1776\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1777\u001b[0m \u001b[0;31m# torchrec tests the code consistency with the following code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1784\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1785\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1786\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1787\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1788\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/tmp/ipython-input-251580671.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, return_diag)\u001b[0m\n\u001b[1;32m 430\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 431\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mcpx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnorm\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcomplexes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorms\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 432\u001b[0;31m \u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiag\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcpx\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 433\u001b[0m \u001b[0mall_diag\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdiag\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 434\u001b[0m \u001b[0mtotal_align_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtotal_align_loss\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mdiag\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'align_losses'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'total'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1773\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1774\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1775\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1776\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1777\u001b[0m \u001b[0;31m# torchrec tests the code consistency with the following code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1784\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1785\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1786\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1787\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1788\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/tmp/ipython-input-251580671.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 362\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mlevel\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlevels\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 363\u001b[0;31m \u001b[0mh\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minfo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlevel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 364\u001b[0m \u001b[0mall_outputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 365\u001b[0m \u001b[0mall_infos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minfo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1773\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1774\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1775\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1776\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1777\u001b[0m \u001b[0;31m# torchrec tests the code consistency with the following code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1784\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1785\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1786\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1787\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1788\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/tmp/ipython-input-251580671.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[0;31m# vertex_logits: [n_out, num_vertices, n_in]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0;31m# Softmax over inputs for each vertex position\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 192\u001b[0;31m \u001b[0mvertex_weights\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msoftmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvertex_logits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# [n_out, V, n_in]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 193\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 194\u001b[0m \u001b[0;31m# Gather weighted input features for each simplex's vertices\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mAttributeError\u001b[0m: 'int' object has no attribute 'softmax'" ] } ] }, { "cell_type": "code", "source": [ "#@title True CM KSimplex - Fixed k=0 Edge Case\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets, transforms\n", "import math\n", "from typing import List, Dict, Tuple\n", "from itertools import combinations\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print(f\"Device: {device}\")\n", "\n", "# ============================================================================\n", "# CAYLEY-MENGER VALIDATOR\n", "# ============================================================================\n", "\n", "class CMValidator(nn.Module):\n", " def __init__(self, simplex_order):\n", " super().__init__()\n", " self._order = simplex_order\n", " self._nv = simplex_order + 1\n", "\n", " # k=0 has no pairs (single point)\n", " if self._nv < 2:\n", " self._npairs = 0\n", " self.register_buffer('_pi', torch.tensor([], dtype=torch.long))\n", " self.register_buffer('_pj', torch.tensor([], dtype=torch.long))\n", " else:\n", " pairs = list(combinations(range(self._nv), 2))\n", " self._npairs = len(pairs)\n", " self.register_buffer('_pi', torch.tensor([p[0] for p in pairs], dtype=torch.long))\n", " self.register_buffer('_pj', torch.tensor([p[1] for p in pairs], dtype=torch.long))\n", "\n", " if self._order > 0:\n", " sign = (-1.0) ** (self._order + 1)\n", " fact = math.factorial(self._order)\n", " self._prefactor = sign / ((2.0 ** self._order) * (fact ** 2))\n", " else:\n", " self._prefactor = 1.0\n", "\n", " def forward(self, verts):\n", " \"\"\"verts: [..., num_vertices, embed_dim]\"\"\"\n", "\n", " # k=0: single point, no geometry\n", " if self._order == 0:\n", " shape = verts.shape[:-2]\n", " vol2 = torch.ones(*shape, device=verts.device, dtype=verts.dtype)\n", " d2_pairs = torch.zeros(*shape, 0, device=verts.device, dtype=verts.dtype)\n", " valid = torch.ones(*shape, device=verts.device, dtype=torch.bool)\n", " return d2_pairs, vol2, valid\n", "\n", " # Gram matrix for distances\n", " gram = torch.einsum('...ve,...we->...vw', verts, verts)\n", " norms = torch.diagonal(gram, dim1=-2, dim2=-1)\n", " d2_mat = norms.unsqueeze(-1) + norms.unsqueeze(-2) - 2 * gram\n", " d2_mat = F.relu(d2_mat) # Numerical stability\n", "\n", " # Extract pairs\n", " d2_pairs = d2_mat[..., self._pi, self._pj]\n", "\n", " # CM determinant\n", " shape = d2_mat.shape[:-2]\n", " V = d2_mat.shape[-1]\n", " cm = torch.zeros(*shape, V+1, V+1, device=d2_mat.device, dtype=d2_mat.dtype)\n", " cm[..., 0, 1:] = 1.0\n", " cm[..., 1:, 0] = 1.0\n", " cm[..., 1:, 1:] = d2_mat\n", " det = torch.linalg.det(cm)\n", "\n", " vol2 = self._prefactor * det\n", " valid = vol2 > 1e-8\n", "\n", " return d2_pairs, vol2, valid\n", "\n", "\n", "# ============================================================================\n", "# GEOMETRIC LEVEL\n", "# ============================================================================\n", "\n", "class GeoLevel(nn.Module):\n", " def __init__(self, order, nin, nout, fdim, edim):\n", " super().__init__()\n", " self._order = order\n", " self._nin = nin\n", " self._nout = nout\n", " self._fdim = fdim\n", " self._edim = edim\n", " self._nv = order + 1\n", "\n", " # CM validator\n", " self._cm = CMValidator(order)\n", "\n", " # Vertex selection: [nout, nv, nin]\n", " _sel = torch.randn(nout, self._nv, nin) * 0.02\n", " self.register_parameter('_W_select', nn.Parameter(_sel))\n", "\n", " # Embedding projection\n", " self._to_embed = nn.Linear(fdim, edim)\n", "\n", " # Feature MLP\n", " self._mlp = nn.Sequential(\n", " nn.Linear(fdim, fdim * 2),\n", " nn.LayerNorm(fdim * 2),\n", " nn.GELU(),\n", " nn.Linear(fdim * 2, fdim),\n", " )\n", "\n", " # Geometry gate - handle k=0 case (no pairs)\n", " gate_in = max(1, self._cm._npairs + 1) # At least 1 input\n", " self._gate = nn.Sequential(\n", " nn.Linear(gate_in, fdim),\n", " nn.Sigmoid(),\n", " )\n", "\n", " # Attenuation params\n", " self.register_parameter('_scale', nn.Parameter(torch.tensor(1.0)))\n", " self.register_parameter('_bias', nn.Parameter(torch.tensor(0.0)))\n", "\n", " # Residual projection\n", " if nin != nout:\n", " self._proj = nn.Linear(nin * fdim, nout * fdim)\n", " else:\n", " self._proj = None\n", "\n", " def forward(self, x):\n", " \"\"\"x: [B, nin, fdim]\"\"\"\n", " B = x.shape[0]\n", "\n", " # Soft vertex selection\n", " sel = F.softmax(self._W_select, dim=-1) # [nout, nv, nin]\n", "\n", " # Select vertices: [B, nout, nv, fdim]\n", " picked = torch.einsum('ovn,bnf->bovf', sel, x)\n", "\n", " # Embed for geometry\n", " emb = self._to_embed(picked) # [B, nout, nv, edim]\n", "\n", " # CM computation\n", " d2, vol2, valid = self._cm(emb)\n", "\n", " # Geometric gating\n", " if self._order == 0:\n", " # k=0: just use volume (which is 1)\n", " geo = vol2.unsqueeze(-1) # [B, nout, 1]\n", " else:\n", " geo = torch.cat([d2, vol2.unsqueeze(-1)], dim=-1)\n", " gate = self._gate(geo)\n", "\n", " # Feature transform\n", " agg = picked.mean(dim=2) # [B, nout, fdim]\n", " out = self._mlp(agg) * gate\n", "\n", " # Validity attenuation\n", " logv = torch.log(vol2.abs() + 1e-8)\n", " attn = torch.sigmoid(self._scale * logv + self._bias)\n", " out = out * attn.unsqueeze(-1)\n", "\n", " # Residual\n", " if self._proj is not None:\n", " res = self._proj(x.flatten(1)).view(B, self._nout, self._fdim)\n", " else:\n", " res = x\n", " out = out + 0.1 * res\n", "\n", " return out, {'d2': d2, 'vol2': vol2, 'valid': valid, 'attn': attn, 'order': self._order}\n", "\n", "\n", "# ============================================================================\n", "# COMPLEX\n", "# ============================================================================\n", "\n", "class GeoComplex(nn.Module):\n", " def __init__(self, nbase, fdim, depth, growth, edim):\n", " super().__init__()\n", " self._nbase = nbase\n", " self._fdim = fdim\n", " self._depth = depth\n", "\n", " # Growing sizes\n", " sizes = [nbase]\n", " for i in range(depth):\n", " sizes.append(sizes[-1] + max(1, int(growth * (i + 1))))\n", " self._sizes = sizes\n", "\n", " # Input projection\n", " self._proj_in = nn.Linear(fdim, nbase * fdim)\n", "\n", " # Build levels - start from k=1 (edges) since k=0 is trivial\n", " level_list = []\n", " for i in range(depth):\n", " lv = GeoLevel(\n", " order=i + 1, # Start at k=1, not k=0\n", " nin=sizes[i],\n", " nout=sizes[i+1],\n", " fdim=fdim,\n", " edim=edim\n", " )\n", " level_list.append(lv)\n", " self._levels = nn.ModuleList(level_list)\n", "\n", " # Level weights\n", " self.register_parameter('_lw', nn.Parameter(torch.ones(depth) / depth))\n", "\n", " def forward(self, x):\n", " \"\"\"x: [B, fdim]\"\"\"\n", " B = x.shape[0]\n", " h = self._proj_in(x).view(B, self._nbase, self._fdim)\n", "\n", " outs, infos = [], []\n", " for lv in self._levels:\n", " h, info = lv(h)\n", " outs.append(h)\n", " infos.append(info)\n", "\n", " # Weighted pool\n", " w = F.softmax(self._lw, dim=0)\n", " pooled = []\n", " for i, (o, inf) in enumerate(zip(outs, infos)):\n", " a = inf['attn']\n", " wm = (o * a.unsqueeze(-1)).sum(1) / (a.sum(1, keepdim=True) + 1e-8)\n", " pooled.append(w[i] * wm)\n", "\n", " return sum(pooled), {'sizes': self._sizes, 'infos': infos, 'weights': w.detach()}\n", "\n", "\n", "# ============================================================================\n", "# MODEL\n", "# ============================================================================\n", "\n", "class FashionGeoSimplex(nn.Module):\n", " def __init__(self, nbase=4, fdim=32, depth=5, growth=1.5, edim=8, ncpx=2):\n", " super().__init__()\n", "\n", " self._stem = nn.Sequential(\n", " nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.GELU(),\n", " nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.BatchNorm2d(64), nn.GELU(),\n", " nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.BatchNorm2d(128), nn.GELU(),\n", " )\n", " self._pool = nn.AdaptiveAvgPool2d(1)\n", " self._proj = nn.Linear(128, fdim)\n", "\n", " cpx_list = []\n", " for i in range(ncpx):\n", " c = GeoComplex(nbase, fdim, depth, growth, edim)\n", " cpx_list.append(c)\n", " print(f\"Complex {i+1} sizes: {c._sizes}\")\n", " self._cpx = nn.ModuleList(cpx_list)\n", " self._norms = nn.ModuleList([nn.LayerNorm(fdim) for _ in range(ncpx)])\n", "\n", " self._head = nn.Linear(fdim, 10)\n", "\n", " def forward(self, x, ret_diag=False):\n", " h = self._stem(x)\n", " h = self._pool(h).flatten(1)\n", " h = self._proj(h)\n", "\n", " diags = []\n", " align = 0\n", " for cpx, norm in zip(self._cpx, self._norms):\n", " out, diag = cpx(h)\n", " diags.append(diag)\n", " for inf in diag['infos']:\n", " align = align + F.relu(-inf['vol2']).mean()\n", " h = norm(h + out)\n", "\n", " logits = self._head(h)\n", "\n", " if ret_diag:\n", " return logits, diags, align\n", " return logits, align\n", "\n", "\n", "# ============================================================================\n", "# TRAIN\n", "# ============================================================================\n", "\n", "def train():\n", " transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.2860,), (0.3530,))\n", " ])\n", "\n", " train_ds = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)\n", " test_ds = datasets.FashionMNIST('./data', train=False, download=True, transform=transform)\n", " train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2)\n", " test_dl = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=2)\n", "\n", " print(\"\\nBuilding model...\")\n", " model = FashionGeoSimplex(\n", " nbase=4, fdim=32, depth=5, growth=1.5, edim=8, ncpx=2\n", " ).to(device)\n", "\n", " print(f\"Params: {sum(p.numel() for p in model.parameters()):,}\")\n", "\n", " opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)\n", " sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=30)\n", "\n", " best = 0\n", " print(\"\\nTraining...\")\n", " print(\"=\" * 100)\n", "\n", " for ep in range(30):\n", " model.train()\n", " lsum, asum, cor, tot = 0, 0, 0, 0\n", "\n", " for img, lab in train_dl:\n", " img, lab = img.to(device), lab.to(device)\n", "\n", " opt.zero_grad()\n", " logits, align = model(img)\n", " ce = F.cross_entropy(logits, lab)\n", " loss = ce + 0.1 * align\n", " loss.backward()\n", " opt.step()\n", "\n", " lsum += ce.item() * img.size(0)\n", " asum += align.item() * img.size(0)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", "\n", " tr_acc = cor / tot\n", " tr_loss = lsum / tot\n", " tr_align = asum / tot\n", "\n", " model.eval()\n", " cor, tot = 0, 0\n", " samp = None\n", " with torch.no_grad():\n", " for img, lab in test_dl:\n", " img, lab = img.to(device), lab.to(device)\n", " logits, diag, _ = model(img, ret_diag=True)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", " if samp is None:\n", " samp = diag\n", "\n", " te_acc = cor / tot\n", " sched.step()\n", " if te_acc > best:\n", " best = te_acc\n", "\n", " if ep % 5 == 0 or ep == 29:\n", " print(f\"Ep {ep+1:2d} | CE {tr_loss:.4f} | Align {tr_align:.4f} | Tr {tr_acc:.2%} | Te {te_acc:.2%} | Best {best:.2%}\")\n", " for inf in samp[0]['infos']:\n", " k = inf['order']\n", " v = inf['vol2'].mean().item()\n", " vld = inf['valid'].float().mean().item()\n", " att = inf['attn'].mean().item()\n", " d = inf['d2'].mean().item() if inf['d2'].numel() > 0 else 0\n", " npairs = inf['d2'].shape[-1] if inf['d2'].numel() > 0 else 0\n", " print(f\" k={k}: pairs={npairs:2d} vol²={v:10.4f} valid={vld:5.1%} attn={att:.3f} d²={d:.4f}\")\n", "\n", " print(\"=\" * 100)\n", " print(f\"Best: {best:.2%}\")\n", " return model\n", "\n", "\n", "if __name__ == \"__main__\":\n", " model = train()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 790 }, "id": "AwfSL9fyGRz2", "outputId": "7c11affd-ea50-4706-819a-e564ec756246" }, "execution_count": 5, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Device: cuda\n", "\n", "Building model...\n", "Complex 1 sizes: [4, 5, 8, 12, 18, 25]\n", "Complex 2 sizes: [4, 5, 8, 12, 18, 25]\n", "Params: 1,851,360\n", "\n", "Training...\n", "====================================================================================================\n", "Ep 1 | CE 0.6076 | Align 0.0000 | Tr 80.14% | Te 84.11% | Best 84.11%\n", " k=1: pairs= 1 vol²= 0.0001 valid=69.9% attn=0.000 d²=0.0001\n", " k=2: pairs= 3 vol²= 0.0000 valid= 0.0% attn=0.000 d²=0.0000\n", " k=3: pairs= 6 vol²= -0.0000 valid= 0.0% attn=0.000 d²=0.0000\n", " k=4: pairs=10 vol²= 0.0000 valid= 0.0% attn=0.000 d²=0.0000\n", " k=5: pairs=15 vol²= 0.0000 valid= 0.0% attn=0.000 d²=0.0000\n", "Ep 6 | CE 0.2593 | Align 0.0000 | Tr 90.64% | Te 88.86% | Best 88.86%\n", " k=1: pairs= 1 vol²= 0.0002 valid=60.2% attn=0.000 d²=0.0002\n", " k=2: pairs= 3 vol²= 0.0000 valid= 0.0% attn=0.000 d²=0.0000\n", " k=3: pairs= 6 vol²= -0.0000 valid= 0.0% attn=0.000 d²=0.0000\n", " k=4: pairs=10 vol²= 0.0000 valid= 0.0% attn=0.000 d²=0.0000\n", " k=5: pairs=15 vol²= 0.0000 valid= 0.0% attn=0.000 d²=0.0000\n", "Ep 11 | CE 0.1813 | Align 0.0000 | Tr 93.45% | Te 89.39% | Best 90.19%\n", " k=1: pairs= 1 vol²= 0.0002 valid=60.2% attn=0.000 d²=0.0002\n", " k=2: pairs= 3 vol²= 0.0000 valid= 0.0% attn=0.000 d²=0.0000\n", " k=3: pairs= 6 vol²= -0.0000 valid= 0.0% attn=0.000 d²=0.0000\n", " k=4: pairs=10 vol²= 0.0000 valid= 0.0% attn=0.000 d²=0.0000\n", " k=5: pairs=15 vol²= 0.0000 valid= 0.0% attn=0.000 d²=0.0000\n" ] }, { "output_type": "error", "ename": "KeyboardInterrupt", "evalue": "", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[0;32m/tmp/ipython-input-2853210191.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 358\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 359\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"__main__\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 360\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m/tmp/ipython-input-2853210191.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m()\u001b[0m\n\u001b[1;32m 312\u001b[0m \u001b[0mce\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcross_entropy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlab\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 313\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mce\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m0.1\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0malign\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 314\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 315\u001b[0m \u001b[0mopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 316\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 623\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 624\u001b[0m )\n\u001b[0;32m--> 625\u001b[0;31m torch.autograd.backward(\n\u001b[0m\u001b[1;32m 626\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 627\u001b[0m )\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 352\u001b[0m \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 353\u001b[0m \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 354\u001b[0;31m _engine_run_backward(\n\u001b[0m\u001b[1;32m 355\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 356\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py\u001b[0m in \u001b[0;36m_engine_run_backward\u001b[0;34m(t_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 839\u001b[0m \u001b[0munregister_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_register_logging_hooks_on_whole_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt_outputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 840\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 841\u001b[0;31m return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m 842\u001b[0m \u001b[0mt_outputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 843\u001b[0m ) # Calls into the C++ engine to run the backward pass\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ] }, { "cell_type": "code", "source": [ "#@title GeoLevel with Uniform Contribution Scale\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets, transforms\n", "import math\n", "from itertools import combinations\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print(f\"Device: {device}\")\n", "\n", "from geovocab2.shapes.factory.simplex_factory import SimplexFactory\n", "\n", "# ============================================================================\n", "# CAYLEY-MENGER VALIDATOR\n", "# ============================================================================\n", "\n", "class CMValidator(nn.Module):\n", " def __init__(self, simplex_order):\n", " super().__init__()\n", " self._order = simplex_order\n", " self._nv = simplex_order + 1\n", "\n", " if self._nv < 2:\n", " self._npairs = 0\n", " self.register_buffer('_pi', torch.tensor([], dtype=torch.long))\n", " self.register_buffer('_pj', torch.tensor([], dtype=torch.long))\n", " else:\n", " pairs = list(combinations(range(self._nv), 2))\n", " self._npairs = len(pairs)\n", " self.register_buffer('_pi', torch.tensor([p[0] for p in pairs], dtype=torch.long))\n", " self.register_buffer('_pj', torch.tensor([p[1] for p in pairs], dtype=torch.long))\n", "\n", " if self._order > 0:\n", " sign = (-1.0) ** (self._order + 1)\n", " fact = math.factorial(self._order)\n", " self._prefactor = sign / ((2.0 ** self._order) * (fact ** 2))\n", " else:\n", " self._prefactor = 1.0\n", "\n", " def forward(self, verts):\n", " if self._order == 0:\n", " shape = verts.shape[:-2]\n", " return (torch.zeros(*shape, 0, device=verts.device),\n", " torch.ones(*shape, device=verts.device),\n", " torch.ones(*shape, dtype=torch.bool, device=verts.device))\n", "\n", " gram = torch.einsum('...ve,...we->...vw', verts, verts)\n", " norms = torch.diagonal(gram, dim1=-2, dim2=-1)\n", " d2_mat = norms.unsqueeze(-1) + norms.unsqueeze(-2) - 2 * gram\n", " d2_mat = F.relu(d2_mat)\n", "\n", " d2_pairs = d2_mat[..., self._pi, self._pj]\n", "\n", " shape = d2_mat.shape[:-2]\n", " V = d2_mat.shape[-1]\n", " cm = torch.zeros(*shape, V+1, V+1, device=d2_mat.device, dtype=d2_mat.dtype)\n", " cm[..., 0, 1:] = 1.0\n", " cm[..., 1:, 0] = 1.0\n", " cm[..., 1:, 1:] = d2_mat\n", " det = torch.linalg.det(cm)\n", "\n", " vol2 = self._prefactor * det\n", " valid = vol2 > 1e-8\n", "\n", " return d2_pairs, vol2, valid\n", "\n", "\n", "# ============================================================================\n", "# GEOMETRIC LEVEL - UNIFORM CONTRIBUTION\n", "# ============================================================================\n", "\n", "class GeoLevel(nn.Module):\n", " # Class-level constant: deform scale safe for highest k\n", " BASE_DEFORM = 0.05 # Small enough for k=5 to stay valid\n", "\n", " def __init__(self, order, nin, nout, fdim, edim):\n", " super().__init__()\n", " self._order = order\n", " self._nin = nin\n", " self._nout = nout\n", " self._fdim = fdim\n", " self._edim = edim\n", " self._nv = order + 1\n", "\n", " self._cm = CMValidator(order)\n", "\n", " # Regular simplex template from factory\n", " factory = SimplexFactory(k=order, embed_dim=edim, method=\"regular\", scale=1.0)\n", " template = factory.build_torch(dtype=torch.float32)\n", " self.register_buffer('_template', template)\n", "\n", " # Compute template's natural volume for normalization\n", " with torch.no_grad():\n", " _, template_vol, _ = self._cm(template.unsqueeze(0))\n", " self._template_vol = template_vol.item()\n", "\n", " # Vertex selection\n", " _sel = torch.randn(nout, self._nv, nin) * 0.1\n", " self.register_parameter('_W_select', nn.Parameter(_sel))\n", "\n", " # Deformation network\n", " self._deform = nn.Sequential(\n", " nn.Linear(fdim, fdim),\n", " nn.LayerNorm(fdim),\n", " nn.GELU(),\n", " nn.Linear(fdim, self._nv * edim),\n", " )\n", "\n", " # Feature MLP\n", " self._mlp = nn.Sequential(\n", " nn.Linear(fdim, fdim * 2),\n", " nn.LayerNorm(fdim * 2),\n", " nn.GELU(),\n", " nn.Linear(fdim * 2, fdim),\n", " )\n", "\n", " # Geometry gate - normalize by template volume for uniform contribution\n", " gate_in = max(1, self._cm._npairs + 1)\n", " self._gate = nn.Sequential(\n", " nn.Linear(gate_in, fdim),\n", " nn.Sigmoid(),\n", " )\n", "\n", " # Volume normalization factor: scale vol² so all k contribute equally\n", " # vol² for regular k-simplex ∝ 1/((k+1)! * 2^k)\n", " # We want to amplify higher-k volumes\n", " self._vol_norm = math.factorial(order + 1) * (2 ** order)\n", "\n", " # Attenuation uses NORMALIZED volume\n", " self.register_parameter('_scale', nn.Parameter(torch.tensor(5.0))) # Higher sensitivity\n", " self.register_parameter('_bias', nn.Parameter(torch.tensor(0.0)))\n", "\n", " # Residual\n", " if nin != nout:\n", " self._proj = nn.Linear(nin * fdim, nout * fdim)\n", " else:\n", " self._proj = None\n", "\n", " def forward(self, x):\n", " B = x.shape[0]\n", "\n", " sel = F.softmax(self._W_select, dim=-1)\n", " picked = torch.einsum('ovn,bnf->bovf', sel, x)\n", " agg = picked.mean(dim=2)\n", "\n", " # Deformation - UNIFORM small scale\n", " deform = self._deform(agg).view(B, self._nout, self._nv, self._edim)\n", "\n", " # Template expansion\n", " template_expanded = self._template.unsqueeze(0).unsqueeze(0).expand(B, self._nout, -1, -1)\n", "\n", " # Apply UNIFORM deformation scale\n", " verts = template_expanded + self.BASE_DEFORM * deform\n", "\n", " # CM computation\n", " d2, vol2, valid = self._cm(verts)\n", "\n", " # NORMALIZE volume for uniform contribution across k\n", " vol2_norm = vol2 * self._vol_norm\n", "\n", " # Geometric gating with normalized volume\n", " if self._cm._npairs == 0:\n", " geo = vol2_norm.unsqueeze(-1)\n", " else:\n", " # Also normalize distances? Keep raw for now\n", " geo = torch.cat([d2, vol2_norm.unsqueeze(-1)], dim=-1)\n", " gate = self._gate(geo)\n", "\n", " # Feature transform\n", " out = self._mlp(agg) * gate\n", "\n", " # Attenuation based on NORMALIZED volume\n", " logv = torch.log(vol2_norm.abs() + 1e-8)\n", " attn = torch.sigmoid(self._scale * logv + self._bias)\n", " out = out * attn.unsqueeze(-1)\n", "\n", " # Residual\n", " if self._proj is not None:\n", " res = self._proj(x.flatten(1)).view(B, self._nout, self._fdim)\n", " else:\n", " res = x\n", " out = out + 0.1 * res\n", "\n", " return out, {\n", " 'd2': d2,\n", " 'vol2': vol2,\n", " 'vol2_norm': vol2_norm,\n", " 'valid': valid,\n", " 'attn': attn,\n", " 'order': self._order,\n", " 'vol_norm_factor': self._vol_norm,\n", " }\n", "\n", "\n", "# ============================================================================\n", "# COMPLEX\n", "# ============================================================================\n", "\n", "class GeoComplex(nn.Module):\n", " def __init__(self, nbase, fdim, depth, growth, edim):\n", " super().__init__()\n", " self._nbase = nbase\n", " self._fdim = fdim\n", " self._depth = depth\n", "\n", " sizes = [nbase]\n", " for i in range(depth):\n", " sizes.append(sizes[-1] + max(1, int(growth * (i + 1))))\n", " self._sizes = sizes\n", "\n", " self._proj_in = nn.Linear(fdim, nbase * fdim)\n", "\n", " level_list = []\n", " for i in range(depth):\n", " lv = GeoLevel(order=i + 1, nin=sizes[i], nout=sizes[i+1], fdim=fdim, edim=edim)\n", " level_list.append(lv)\n", " self._levels = nn.ModuleList(level_list)\n", "\n", " self.register_parameter('_lw', nn.Parameter(torch.ones(depth) / depth))\n", "\n", " def forward(self, x):\n", " B = x.shape[0]\n", " h = self._proj_in(x).view(B, self._nbase, self._fdim)\n", "\n", " outs, infos = [], []\n", " for lv in self._levels:\n", " h, info = lv(h)\n", " outs.append(h)\n", " infos.append(info)\n", "\n", " w = F.softmax(self._lw, dim=0)\n", " pooled = []\n", " for i, (o, inf) in enumerate(zip(outs, infos)):\n", " a = inf['attn']\n", " wm = (o * a.unsqueeze(-1)).sum(1) / (a.sum(1, keepdim=True) + 1e-8)\n", " pooled.append(w[i] * wm)\n", "\n", " return sum(pooled), {'sizes': self._sizes, 'infos': infos, 'weights': w.detach()}\n", "\n", "\n", "# ============================================================================\n", "# MODEL\n", "# ============================================================================\n", "\n", "class FashionGeoSimplex(nn.Module):\n", " def __init__(self, nbase=4, fdim=32, depth=5, growth=1.5, edim=8, ncpx=2):\n", " super().__init__()\n", "\n", " self._stem = nn.Sequential(\n", " nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.GELU(),\n", " nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.BatchNorm2d(64), nn.GELU(),\n", " nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.BatchNorm2d(128), nn.GELU(),\n", " )\n", " self._pool = nn.AdaptiveAvgPool2d(1)\n", " self._proj = nn.Linear(128, fdim)\n", "\n", " cpx_list = []\n", " for i in range(ncpx):\n", " c = GeoComplex(nbase, fdim, depth, growth, edim)\n", " cpx_list.append(c)\n", " print(f\"Complex {i+1} sizes: {c._sizes}\")\n", " self._cpx = nn.ModuleList(cpx_list)\n", " self._norms = nn.ModuleList([nn.LayerNorm(fdim) for _ in range(ncpx)])\n", "\n", " self._head = nn.Linear(fdim, 10)\n", "\n", " def forward(self, x, ret_diag=False):\n", " h = self._stem(x)\n", " h = self._pool(h).flatten(1)\n", " h = self._proj(h)\n", "\n", " diags = []\n", " align = 0\n", " for cpx, norm in zip(self._cpx, self._norms):\n", " out, diag = cpx(h)\n", " diags.append(diag)\n", " for inf in diag['infos']:\n", " align = align + F.relu(-inf['vol2']).mean()\n", " h = norm(h + out)\n", "\n", " logits = self._head(h)\n", "\n", " if ret_diag:\n", " return logits, diags, align\n", " return logits, align\n", "\n", "\n", "# ============================================================================\n", "# TRAIN\n", "# ============================================================================\n", "\n", "def train():\n", " transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.2860,), (0.3530,))\n", " ])\n", "\n", " train_ds = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)\n", " test_ds = datasets.FashionMNIST('./data', train=False, download=True, transform=transform)\n", " train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2)\n", " test_dl = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=2)\n", "\n", " print(\"\\nBuilding model...\")\n", " model = FashionGeoSimplex(\n", " nbase=4, fdim=32, depth=5, growth=1.5, edim=8, ncpx=2\n", " ).to(device)\n", "\n", " print(f\"Params: {sum(p.numel() for p in model.parameters()):,}\")\n", " print(f\"Deform scale: {GeoLevel.BASE_DEFORM}\")\n", "\n", " opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)\n", " sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=30)\n", "\n", " best = 0\n", " print(\"\\nTraining...\")\n", " print(\"=\" * 115)\n", "\n", " for ep in range(30):\n", " model.train()\n", " lsum, asum, cor, tot = 0, 0, 0, 0\n", "\n", " for img, lab in train_dl:\n", " img, lab = img.to(device), lab.to(device)\n", "\n", " opt.zero_grad()\n", " logits, align = model(img)\n", " ce = F.cross_entropy(logits, lab)\n", " loss = ce + 0.1 * align\n", " loss.backward()\n", " opt.step()\n", "\n", " lsum += ce.item() * img.size(0)\n", " asum += align.item() * img.size(0)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", "\n", " tr_acc = cor / tot\n", " tr_loss = lsum / tot\n", " tr_align = asum / tot\n", "\n", " model.eval()\n", " cor, tot = 0, 0\n", " samp = None\n", " with torch.no_grad():\n", " for img, lab in test_dl:\n", " img, lab = img.to(device), lab.to(device)\n", " logits, diag, _ = model(img, ret_diag=True)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", " if samp is None:\n", " samp = diag\n", "\n", " te_acc = cor / tot\n", " sched.step()\n", " if te_acc > best:\n", " best = te_acc\n", "\n", " if ep % 5 == 0 or ep == 29:\n", " print(f\"Ep {ep+1:2d} | CE {tr_loss:.4f} | Align {tr_align:.4f} | Tr {tr_acc:.2%} | Te {te_acc:.2%} | Best {best:.2%}\")\n", " for inf in samp[0]['infos']:\n", " k = inf['order']\n", " v = inf['vol2'].mean().item()\n", " vn = inf['vol2_norm'].mean().item()\n", " vld = inf['valid'].float().mean().item()\n", " att = inf['attn'].mean().item()\n", " d = inf['d2'].mean().item() if inf['d2'].numel() > 0 else 0\n", " nf = inf['vol_norm_factor']\n", " print(f\" k={k}: vol²={v:.2e} norm_vol²={vn:8.4f} (×{nf:6.0f}) valid={vld:5.1%} attn={att:.3f} d²={d:.4f}\")\n", "\n", " print(\"=\" * 115)\n", " print(f\"Best: {best:.2%}\")\n", " return model\n", "\n", "\n", "if __name__ == \"__main__\":\n", " model = train()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "cellView": "form", "id": "0KgB_FEGI_wQ", "outputId": "b57726eb-f7cd-438d-9640-7279aa802eec" }, "execution_count": 2, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Device: cuda\n", "\n", "Building model...\n", "Complex 1 sizes: [4, 5, 8, 12, 18, 25]\n", "Complex 2 sizes: [4, 5, 8, 12, 18, 25]\n", "Params: 1,870,480\n", "Deform scale: 0.05\n", "\n", "Training...\n", "===================================================================================================================\n", "Ep 1 | CE 0.6210 | Align 0.0000 | Tr 79.12% | Te 82.34% | Best 82.34%\n", " k=1: vol²=1.10e+00 norm_vol²= 4.3829 (× 4) valid=100.0% attn=0.999 d²=1.0957\n", " k=2: vol²=1.74e-01 norm_vol²= 4.1699 (× 24) valid=100.0% attn=0.999 d²=0.9628\n", " k=3: vol²=1.22e-02 norm_vol²= 2.3417 (× 192) valid=100.0% attn=0.957 d²=0.9536\n", " k=4: vol²=3.84e-04 norm_vol²= 0.7364 (× 1920) valid=100.0% attn=0.214 d²=0.9126\n", " k=5: vol²=9.74e-06 norm_vol²= 0.2244 (× 23040) valid=100.0% attn=0.002 d²=0.9434\n", "Ep 6 | CE 0.2658 | Align 0.0000 | Tr 90.39% | Te 88.01% | Best 88.69%\n", " k=1: vol²=1.10e+00 norm_vol²= 4.3852 (× 4) valid=100.0% attn=0.997 d²=1.0963\n", " k=2: vol²=1.71e-01 norm_vol²= 4.1001 (× 24) valid=100.0% attn=0.995 d²=0.9512\n", " k=3: vol²=6.77e-03 norm_vol²= 1.3001 (× 192) valid=100.0% attn=0.559 d²=0.7758\n", " k=4: vol²=2.16e-04 norm_vol²= 0.4138 (× 1920) valid=100.0% attn=0.039 d²=0.7908\n", " k=5: vol²=5.65e-06 norm_vol²= 0.1301 (× 23040) valid=100.0% attn=0.000 d²=0.8504\n", "Ep 11 | CE 0.1864 | Align 0.0000 | Tr 93.15% | Te 90.00% | Best 90.00%\n", " k=1: vol²=1.01e+00 norm_vol²= 4.0337 (× 4) valid=100.0% attn=0.988 d²=1.0084\n", " k=2: vol²=1.67e-01 norm_vol²= 4.0047 (× 24) valid=100.0% attn=0.988 d²=0.9339\n", " k=3: vol²=3.95e-03 norm_vol²= 0.7577 (× 192) valid=100.0% attn=0.203 d²=0.6477\n", " k=4: vol²=1.45e-04 norm_vol²= 0.2780 (× 1920) valid=100.0% attn=0.006 d²=0.7200\n", " k=5: vol²=3.54e-06 norm_vol²= 0.0815 (× 23040) valid=100.0% attn=0.000 d²=0.7747\n", "Ep 16 | CE 0.1209 | Align 0.0000 | Tr 95.69% | Te 90.17% | Best 90.23%\n", " k=1: vol²=9.67e-01 norm_vol²= 3.8681 (× 4) valid=100.0% attn=0.974 d²=0.9670\n", " k=2: vol²=1.83e-01 norm_vol²= 4.3922 (× 24) valid=100.0% attn=0.976 d²=0.9606\n", " k=3: vol²=4.44e-03 norm_vol²= 0.8527 (× 192) valid=100.0% attn=0.264 d²=0.6602\n", " k=4: vol²=1.37e-04 norm_vol²= 0.2625 (× 1920) valid=100.0% attn=0.014 d²=0.7010\n", " k=5: vol²=2.42e-06 norm_vol²= 0.0557 (× 23040) valid=100.0% attn=0.000 d²=0.7184\n", "Ep 21 | CE 0.0618 | Align 0.0000 | Tr 97.87% | Te 89.79% | Best 90.23%\n", " k=1: vol²=9.37e-01 norm_vol²= 3.7499 (× 4) valid=100.0% attn=0.972 d²=0.9375\n", " k=2: vol²=1.80e-01 norm_vol²= 4.3153 (× 24) valid=100.0% attn=0.969 d²=0.9480\n", " k=3: vol²=4.56e-03 norm_vol²= 0.8756 (× 192) valid=100.0% attn=0.280 d²=0.6662\n", " k=4: vol²=1.05e-04 norm_vol²= 0.2024 (× 1920) valid=100.0% attn=0.007 d²=0.6569\n", " k=5: vol²=1.81e-06 norm_vol²= 0.0417 (× 23040) valid=100.0% attn=0.000 d²=0.6783\n", "Ep 26 | CE 0.0219 | Align 0.0000 | Tr 99.39% | Te 90.16% | Best 90.39%\n", " k=1: vol²=9.27e-01 norm_vol²= 3.7065 (× 4) valid=100.0% attn=0.970 d²=0.9266\n", " k=2: vol²=1.86e-01 norm_vol²= 4.4618 (× 24) valid=100.0% attn=0.972 d²=0.9652\n", " k=3: vol²=4.70e-03 norm_vol²= 0.9024 (× 192) valid=100.0% attn=0.287 d²=0.6710\n", " k=4: vol²=1.79e-04 norm_vol²= 0.3436 (× 1920) valid=100.0% attn=0.087 d²=0.7025\n", " k=5: vol²=1.96e-06 norm_vol²= 0.0451 (× 23040) valid=100.0% attn=0.000 d²=0.6807\n", "Ep 30 | CE 0.0129 | Align 0.0000 | Tr 99.69% | Te 90.05% | Best 90.39%\n", " k=1: vol²=9.34e-01 norm_vol²= 3.7373 (× 4) valid=100.0% attn=0.971 d²=0.9343\n", " k=2: vol²=1.87e-01 norm_vol²= 4.4798 (× 24) valid=100.0% attn=0.973 d²=0.9686\n", " k=3: vol²=4.66e-03 norm_vol²= 0.8940 (× 192) valid=100.0% attn=0.289 d²=0.6706\n", " k=4: vol²=1.76e-04 norm_vol²= 0.3385 (× 1920) valid=100.0% attn=0.083 d²=0.7028\n", " k=5: vol²=1.92e-06 norm_vol²= 0.0443 (× 23040) valid=100.0% attn=0.000 d²=0.6794\n", "===================================================================================================================\n", "Best: 90.39%\n" ] } ] }, { "cell_type": "code", "source": [ "#@title Validity-Only Attenuation (Valid = Full Contribution)\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets, transforms\n", "import math\n", "from itertools import combinations\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print(f\"Device: {device}\")\n", "\n", "from geovocab2.shapes.factory.simplex_factory import SimplexFactory\n", "\n", "# ============================================================================\n", "# CAYLEY-MENGER VALIDATOR\n", "# ============================================================================\n", "\n", "class CMValidator(nn.Module):\n", " def __init__(self, simplex_order):\n", " super().__init__()\n", " self._order = simplex_order\n", " self._nv = simplex_order + 1\n", "\n", " if self._nv < 2:\n", " self._npairs = 0\n", " self.register_buffer('_pi', torch.tensor([], dtype=torch.long))\n", " self.register_buffer('_pj', torch.tensor([], dtype=torch.long))\n", " else:\n", " pairs = list(combinations(range(self._nv), 2))\n", " self._npairs = len(pairs)\n", " self.register_buffer('_pi', torch.tensor([p[0] for p in pairs], dtype=torch.long))\n", " self.register_buffer('_pj', torch.tensor([p[1] for p in pairs], dtype=torch.long))\n", "\n", " if self._order > 0:\n", " sign = (-1.0) ** (self._order + 1)\n", " fact = math.factorial(self._order)\n", " self._prefactor = sign / ((2.0 ** self._order) * (fact ** 2))\n", " else:\n", " self._prefactor = 1.0\n", "\n", " def forward(self, verts):\n", " if self._order == 0:\n", " shape = verts.shape[:-2]\n", " return (torch.zeros(*shape, 0, device=verts.device),\n", " torch.ones(*shape, device=verts.device),\n", " torch.ones(*shape, dtype=torch.bool, device=verts.device))\n", "\n", " gram = torch.einsum('...ve,...we->...vw', verts, verts)\n", " norms = torch.diagonal(gram, dim1=-2, dim2=-1)\n", " d2_mat = norms.unsqueeze(-1) + norms.unsqueeze(-2) - 2 * gram\n", " d2_mat = F.relu(d2_mat)\n", "\n", " d2_pairs = d2_mat[..., self._pi, self._pj]\n", "\n", " shape = d2_mat.shape[:-2]\n", " V = d2_mat.shape[-1]\n", " cm = torch.zeros(*shape, V+1, V+1, device=d2_mat.device, dtype=d2_mat.dtype)\n", " cm[..., 0, 1:] = 1.0\n", " cm[..., 1:, 0] = 1.0\n", " cm[..., 1:, 1:] = d2_mat\n", " det = torch.linalg.det(cm)\n", "\n", " vol2 = self._prefactor * det\n", " valid = vol2 > 1e-8\n", "\n", " return d2_pairs, vol2, valid\n", "\n", "\n", "# ============================================================================\n", "# GEOMETRIC LEVEL - VALIDITY-ONLY ATTENUATION\n", "# ============================================================================\n", "\n", "class GeoLevel(nn.Module):\n", " BASE_DEFORM = 0.05\n", "\n", " def __init__(self, order, nin, nout, fdim, edim):\n", " super().__init__()\n", " self._order = order\n", " self._nin = nin\n", " self._nout = nout\n", " self._fdim = fdim\n", " self._edim = edim\n", " self._nv = order + 1\n", "\n", " self._cm = CMValidator(order)\n", "\n", " # Regular simplex template\n", " factory = SimplexFactory(k=order, embed_dim=edim, method=\"regular\", scale=1.0)\n", " template = factory.build_torch(dtype=torch.float32)\n", " self.register_buffer('_template', template)\n", "\n", " # Vertex selection\n", " _sel = torch.randn(nout, self._nv, nin) * 0.1\n", " self.register_parameter('_W_select', nn.Parameter(_sel))\n", "\n", " # Deformation network\n", " self._deform = nn.Sequential(\n", " nn.Linear(fdim, fdim),\n", " nn.LayerNorm(fdim),\n", " nn.GELU(),\n", " nn.Linear(fdim, self._nv * edim),\n", " )\n", "\n", " # Feature MLP\n", " self._mlp = nn.Sequential(\n", " nn.Linear(fdim, fdim * 2),\n", " nn.LayerNorm(fdim * 2),\n", " nn.GELU(),\n", " nn.Linear(fdim * 2, fdim),\n", " )\n", "\n", " # Geometry-to-feature gate (NOT attenuation - this is for feature modulation)\n", " gate_in = max(1, self._cm._npairs + 1)\n", " self._gate = nn.Sequential(\n", " nn.Linear(gate_in, fdim),\n", " nn.Sigmoid(),\n", " )\n", "\n", " # Residual\n", " if nin != nout:\n", " self._proj = nn.Linear(nin * fdim, nout * fdim)\n", " else:\n", " self._proj = None\n", "\n", " def forward(self, x):\n", " B = x.shape[0]\n", "\n", " sel = F.softmax(self._W_select, dim=-1)\n", " picked = torch.einsum('ovn,bnf->bovf', sel, x)\n", " agg = picked.mean(dim=2)\n", "\n", " # Deformation\n", " deform = self._deform(agg).view(B, self._nout, self._nv, self._edim)\n", " template_expanded = self._template.unsqueeze(0).unsqueeze(0).expand(B, self._nout, -1, -1)\n", " verts = template_expanded + self.BASE_DEFORM * deform\n", "\n", " # CM computation\n", " d2, vol2, valid = self._cm(verts)\n", "\n", " # Geometric gating (feature modulation based on geometry)\n", " if self._cm._npairs == 0:\n", " geo = torch.ones(B, self._nout, 1, device=x.device)\n", " else:\n", " # Normalize d2 for gating (mean-center, scale)\n", " d2_norm = d2 / (d2.mean(dim=-1, keepdim=True) + 1e-8)\n", " geo = torch.cat([d2_norm, vol2.unsqueeze(-1) / (vol2.mean() + 1e-8)], dim=-1)\n", " gate = self._gate(geo)\n", "\n", " # Feature transform WITH geometry modulation\n", " out = self._mlp(agg) * gate\n", "\n", " # VALIDITY-ONLY ATTENUATION\n", " # Valid (vol² > 0) → attn = 1.0 (full contribution)\n", " # Invalid (vol² ≤ 0) → attn = 0.0 (no contribution)\n", " # Use soft version for gradient flow\n", " attn = torch.sigmoid(vol2 * 1e6) # Sharp sigmoid: valid→1, invalid→0\n", " out = out * attn.unsqueeze(-1)\n", "\n", " # Residual\n", " if self._proj is not None:\n", " res = self._proj(x.flatten(1)).view(B, self._nout, self._fdim)\n", " else:\n", " res = x\n", " out = out + 0.1 * res\n", "\n", " return out, {\n", " 'd2': d2,\n", " 'vol2': vol2,\n", " 'valid': valid,\n", " 'attn': attn,\n", " 'order': self._order,\n", " }\n", "\n", "\n", "# ============================================================================\n", "# COMPLEX\n", "# ============================================================================\n", "\n", "class GeoComplex(nn.Module):\n", " def __init__(self, nbase, fdim, depth, growth, edim):\n", " super().__init__()\n", " self._nbase = nbase\n", " self._fdim = fdim\n", " self._depth = depth\n", "\n", " sizes = [nbase]\n", " for i in range(depth):\n", " sizes.append(sizes[-1] + max(1, int(growth * (i + 1))))\n", " self._sizes = sizes\n", "\n", " self._proj_in = nn.Linear(fdim, nbase * fdim)\n", "\n", " level_list = []\n", " for i in range(depth):\n", " lv = GeoLevel(order=i + 1, nin=sizes[i], nout=sizes[i+1], fdim=fdim, edim=edim)\n", " level_list.append(lv)\n", " self._levels = nn.ModuleList(level_list)\n", "\n", " # EQUAL weights - each k contributes equally\n", " self.register_parameter('_lw', nn.Parameter(torch.ones(depth) / depth))\n", "\n", " def forward(self, x):\n", " B = x.shape[0]\n", " h = self._proj_in(x).view(B, self._nbase, self._fdim)\n", "\n", " outs, infos = [], []\n", " for lv in self._levels:\n", " h, info = lv(h)\n", " outs.append(h)\n", " infos.append(info)\n", "\n", " w = F.softmax(self._lw, dim=0)\n", " pooled = []\n", " for i, (o, inf) in enumerate(zip(outs, infos)):\n", " a = inf['attn']\n", " # Mean over simplices (attn-weighted for validity)\n", " wm = (o * a.unsqueeze(-1)).sum(1) / (a.sum(1, keepdim=True) + 1e-8)\n", " pooled.append(w[i] * wm)\n", "\n", " return sum(pooled), {'sizes': self._sizes, 'infos': infos, 'weights': w.detach()}\n", "\n", "\n", "# ============================================================================\n", "# MODEL\n", "# ============================================================================\n", "\n", "class FashionGeoSimplex(nn.Module):\n", " def __init__(self, nbase=4, fdim=32, depth=5, growth=1.5, edim=8, ncpx=1):\n", " super().__init__()\n", "\n", " self._stem = nn.Sequential(\n", " nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.GELU(),\n", " nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.BatchNorm2d(64), nn.GELU(),\n", " nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.BatchNorm2d(128), nn.GELU(),\n", " )\n", " self._pool = nn.AdaptiveAvgPool2d(1)\n", " self._proj = nn.Linear(128, fdim)\n", "\n", " cpx_list = []\n", " for i in range(ncpx):\n", " c = GeoComplex(nbase, fdim, depth, growth, edim)\n", " cpx_list.append(c)\n", " print(f\"Complex {i+1} sizes: {c._sizes}\")\n", " self._cpx = nn.ModuleList(cpx_list)\n", " self._norms = nn.ModuleList([nn.LayerNorm(fdim) for _ in range(ncpx)])\n", "\n", " self._head = nn.Linear(fdim, 10)\n", "\n", " def forward(self, x, ret_diag=False):\n", " h = self._stem(x)\n", " h = self._pool(h).flatten(1)\n", " h = self._proj(h)\n", "\n", " diags = []\n", " validity_loss = 0\n", " for cpx, norm in zip(self._cpx, self._norms):\n", " out, diag = cpx(h)\n", " diags.append(diag)\n", " # Penalize invalid simplices\n", " for inf in diag['infos']:\n", " validity_loss = validity_loss + F.relu(-inf['vol2']).mean()\n", " h = norm(h + out)\n", "\n", " logits = self._head(h)\n", "\n", " if ret_diag:\n", " return logits, diags, validity_loss\n", " return logits, validity_loss\n", "\n", "\n", "# ============================================================================\n", "# TRAIN\n", "# ============================================================================\n", "\n", "def train():\n", " transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.2860,), (0.3530,))\n", " ])\n", "\n", " train_ds = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)\n", " test_ds = datasets.FashionMNIST('./data', train=False, download=True, transform=transform)\n", " train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2)\n", " test_dl = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=2)\n", "\n", " print(\"\\nBuilding model...\")\n", " model = FashionGeoSimplex(\n", " nbase=4, fdim=32, depth=5, growth=1.5, edim=8, ncpx=1\n", " ).to(device)\n", "\n", " print(f\"Params: {sum(p.numel() for p in model.parameters()):,}\")\n", "\n", " opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)\n", " sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=30)\n", "\n", " best = 0\n", " print(\"\\nTraining...\")\n", " print(\"=\" * 95)\n", "\n", " for ep in range(30):\n", " model.train()\n", " lsum, vsum, cor, tot = 0, 0, 0, 0\n", "\n", " for img, lab in train_dl:\n", " img, lab = img.to(device), lab.to(device)\n", "\n", " opt.zero_grad()\n", " logits, vloss = model(img)\n", " ce = F.cross_entropy(logits, lab)\n", " loss = ce + 0.1 * vloss\n", " loss.backward()\n", " opt.step()\n", "\n", " lsum += ce.item() * img.size(0)\n", " vsum += vloss.item() * img.size(0)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", "\n", " tr_acc = cor / tot\n", " tr_loss = lsum / tot\n", "\n", " model.eval()\n", " cor, tot = 0, 0\n", " samp = None\n", " with torch.no_grad():\n", " for img, lab in test_dl:\n", " img, lab = img.to(device), lab.to(device)\n", " logits, diag, _ = model(img, ret_diag=True)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", " if samp is None:\n", " samp = diag\n", "\n", " te_acc = cor / tot\n", " sched.step()\n", " if te_acc > best:\n", " best = te_acc\n", "\n", " # Get level weights\n", " lw = samp[0]['weights'].cpu().numpy()\n", "\n", " if ep % 5 == 0 or ep == 29:\n", " print(f\"Ep {ep+1:2d} | CE {tr_loss:.4f} | Tr {tr_acc:.2%} | Te {te_acc:.2%} | Best {best:.2%}\")\n", " print(f\" | Level weights: {['%.3f' % w for w in lw]}\")\n", " for inf in samp[0]['infos']:\n", " k = inf['order']\n", " v = inf['vol2'].mean().item()\n", " vld = inf['valid'].float().mean().item()\n", " att = inf['attn'].mean().item()\n", " d = inf['d2'].mean().item() if inf['d2'].numel() > 0 else 0\n", " print(f\" k={k}: vol²={v:.2e} valid={vld:5.1%} attn={att:.3f} d²={d:.4f}\")\n", "\n", " print(\"=\" * 95)\n", " print(f\"Best: {best:.2%}\")\n", " return model\n", "\n", "\n", "if __name__ == \"__main__\":\n", " model = train()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "cellView": "form", "id": "gk0zo8LONzca", "outputId": "f136f1a9-7a2b-4e0b-be80-fce3644d58bd" }, "execution_count": 5, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Device: cuda\n", "\n", "Building model...\n", "Complex 1 sizes: [4, 5, 8, 12, 18, 25]\n", "Params: 984,019\n", "\n", "Training...\n", "===============================================================================================\n", "Ep 1 | CE 0.6183 | Tr 79.44% | Te 82.62% | Best 82.62%\n", " | Level weights: ['0.207', '0.203', '0.199', '0.198', '0.194']\n", " k=1: vol²=8.62e-01 valid=100.0% attn=1.000 d²=0.8620\n", " k=2: vol²=1.86e-01 valid=100.0% attn=1.000 d²=0.9950\n", " k=3: vol²=1.51e-02 valid=100.0% attn=1.000 d²=1.0361\n", " k=4: vol²=6.00e-04 valid=100.0% attn=1.000 d²=1.0339\n", " k=5: vol²=1.14e-05 valid=100.0% attn=1.000 d²=0.9892\n", "Ep 6 | CE 0.2578 | Tr 90.67% | Te 87.74% | Best 87.74%\n", " | Level weights: ['0.210', '0.196', '0.191', '0.207', '0.195']\n", " k=1: vol²=4.75e-01 valid=100.0% attn=1.000 d²=0.4749\n", " k=2: vol²=1.98e-01 valid=100.0% attn=1.000 d²=0.8901\n", " k=3: vol²=1.98e-02 valid=100.0% attn=1.000 d²=1.0984\n", " k=4: vol²=4.95e-04 valid=100.0% attn=1.000 d²=1.0151\n", " k=5: vol²=5.98e-06 valid=100.0% attn=0.997 d²=0.9892\n", "Ep 11 | CE 0.1790 | Tr 93.56% | Te 88.70% | Best 89.66%\n", " | Level weights: ['0.207', '0.184', '0.183', '0.216', '0.210']\n", " k=1: vol²=2.71e-01 valid=100.0% attn=1.000 d²=0.2715\n", " k=2: vol²=1.66e-01 valid=100.0% attn=1.000 d²=0.8062\n", " k=3: vol²=2.53e-02 valid=100.0% attn=1.000 d²=1.0979\n", " k=4: vol²=3.09e-04 valid=100.0% attn=1.000 d²=0.9295\n", " k=5: vol²=5.28e-08 valid=100.0% attn=0.513 d²=0.7772\n", "Ep 16 | CE 0.1146 | Tr 95.83% | Te 89.59% | Best 89.99%\n", " | Level weights: ['0.205', '0.179', '0.177', '0.222', '0.216']\n", " k=1: vol²=2.80e-01 valid=100.0% attn=1.000 d²=0.2799\n", " k=2: vol²=2.98e-01 valid=100.0% attn=1.000 d²=1.0136\n", " k=3: vol²=3.29e-02 valid=100.0% attn=1.000 d²=1.1553\n", " k=4: vol²=1.47e-04 valid=100.0% attn=1.000 d²=0.8277\n", " k=5: vol²=2.97e-08 valid=100.0% attn=0.507 d²=0.8029\n", "Ep 21 | CE 0.0568 | Tr 97.98% | Te 89.74% | Best 89.99%\n", " | Level weights: ['0.210', '0.182', '0.176', '0.220', '0.212']\n", " k=1: vol²=1.83e-01 valid=100.0% attn=1.000 d²=0.1833\n", " k=2: vol²=2.07e-01 valid=100.0% attn=1.000 d²=0.8469\n", " k=3: vol²=2.62e-02 valid=100.0% attn=1.000 d²=1.0473\n", " k=4: vol²=1.55e-04 valid=100.0% attn=1.000 d²=0.7964\n", " k=5: vol²=2.45e-08 valid=100.0% attn=0.506 d²=0.8140\n", "Ep 26 | CE 0.0201 | Tr 99.45% | Te 89.71% | Best 90.05%\n", " | Level weights: ['0.214', '0.185', '0.177', '0.217', '0.207']\n", " k=1: vol²=2.19e-01 valid=100.0% attn=1.000 d²=0.2194\n", " k=2: vol²=2.29e-01 valid=100.0% attn=1.000 d²=0.8758\n", " k=3: vol²=2.19e-02 valid=100.0% attn=1.000 d²=0.9457\n", " k=4: vol²=1.41e-04 valid=100.0% attn=1.000 d²=0.7400\n", " k=5: vol²=2.15e-08 valid=64.1% attn=0.505 d²=0.8143\n", "Ep 30 | CE 0.0115 | Tr 99.77% | Te 89.72% | Best 90.05%\n", " | Level weights: ['0.215', '0.185', '0.177', '0.216', '0.206']\n", " k=1: vol²=2.16e-01 valid=100.0% attn=1.000 d²=0.2163\n", " k=2: vol²=2.25e-01 valid=100.0% attn=1.000 d²=0.8611\n", " k=3: vol²=2.11e-02 valid=100.0% attn=1.000 d²=0.9233\n", " k=4: vol²=1.42e-04 valid=100.0% attn=1.000 d²=0.7251\n", " k=5: vol²=1.99e-08 valid=55.5% attn=0.505 d²=0.8124\n", "===============================================================================================\n", "Best: 90.05%\n" ] } ] }, { "cell_type": "code", "source": [ "#@title True CM KSimplex - Exponential Growth (Proper Structure)\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets, transforms\n", "import math\n", "from itertools import combinations\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print(f\"Device: {device}\")\n", "\n", "from geovocab2.shapes.factory.simplex_factory import SimplexFactory\n", "\n", "# ============================================================================\n", "# CAYLEY-MENGER VALIDATOR\n", "# ============================================================================\n", "\n", "class CMValidator(nn.Module):\n", " def __init__(self, simplex_order):\n", " super().__init__()\n", " self._order = simplex_order\n", " self._nv = simplex_order + 1\n", "\n", " if self._nv < 2:\n", " self._npairs = 0\n", " self.register_buffer('_pi', torch.tensor([], dtype=torch.long))\n", " self.register_buffer('_pj', torch.tensor([], dtype=torch.long))\n", " else:\n", " pairs = list(combinations(range(self._nv), 2))\n", " self._npairs = len(pairs)\n", " self.register_buffer('_pi', torch.tensor([p[0] for p in pairs], dtype=torch.long))\n", " self.register_buffer('_pj', torch.tensor([p[1] for p in pairs], dtype=torch.long))\n", "\n", " if self._order > 0:\n", " sign = (-1.0) ** (self._order + 1)\n", " fact = math.factorial(self._order)\n", " self._prefactor = sign / ((2.0 ** self._order) * (fact ** 2))\n", " else:\n", " self._prefactor = 1.0\n", "\n", " def forward(self, verts):\n", " if self._order == 0:\n", " shape = verts.shape[:-2]\n", " return (torch.zeros(*shape, 0, device=verts.device),\n", " torch.ones(*shape, device=verts.device),\n", " torch.ones(*shape, dtype=torch.bool, device=verts.device))\n", "\n", " gram = torch.einsum('...ve,...we->...vw', verts, verts)\n", " norms = torch.diagonal(gram, dim1=-2, dim2=-1)\n", " d2_mat = norms.unsqueeze(-1) + norms.unsqueeze(-2) - 2 * gram\n", " d2_mat = F.relu(d2_mat)\n", "\n", " d2_pairs = d2_mat[..., self._pi, self._pj]\n", "\n", " shape = d2_mat.shape[:-2]\n", " V = d2_mat.shape[-1]\n", " cm = torch.zeros(*shape, V+1, V+1, device=d2_mat.device, dtype=d2_mat.dtype)\n", " cm[..., 0, 1:] = 1.0\n", " cm[..., 1:, 0] = 1.0\n", " cm[..., 1:, 1:] = d2_mat\n", " det = torch.linalg.det(cm)\n", "\n", " vol2 = self._prefactor * det\n", " valid = vol2 > 1e-8\n", "\n", " return d2_pairs, vol2, valid\n", "\n", "\n", "# ============================================================================\n", "# GEOMETRIC LEVEL\n", "# ============================================================================\n", "\n", "class GeoLevel(nn.Module):\n", " BASE_DEFORM = 0.05\n", "\n", " def __init__(self, order, nin, nout, fdim, edim):\n", " super().__init__()\n", " self._order = order\n", " self._nin = nin\n", " self._nout = nout\n", " self._fdim = fdim\n", " self._edim = edim\n", " self._nv = order + 1\n", "\n", " # Potential simplices from nin inputs\n", " self._potential = math.comb(nin, self._nv)\n", "\n", " self._cm = CMValidator(order)\n", "\n", " # Regular simplex template\n", " factory = SimplexFactory(k=order, embed_dim=edim, method=\"regular\", scale=1.0)\n", " template = factory.build_torch(dtype=torch.float32)\n", " self.register_buffer('_template', template)\n", "\n", " # Vertex selection: [nout, nv, nin]\n", " # Each output simplex soft-selects nv vertices from nin inputs\n", " _sel = torch.randn(nout, self._nv, nin) * 0.1\n", " self.register_parameter('_W_select', nn.Parameter(_sel))\n", "\n", " # Deformation network\n", " self._deform = nn.Sequential(\n", " nn.Linear(fdim, fdim),\n", " nn.LayerNorm(fdim),\n", " nn.GELU(),\n", " nn.Linear(fdim, self._nv * edim),\n", " )\n", "\n", " # Feature MLP\n", " self._mlp = nn.Sequential(\n", " nn.Linear(fdim, fdim * 2),\n", " nn.LayerNorm(fdim * 2),\n", " nn.GELU(),\n", " nn.Linear(fdim * 2, fdim),\n", " )\n", "\n", " # Geometry gate\n", " gate_in = max(1, self._cm._npairs + 1)\n", " self._gate = nn.Sequential(\n", " nn.Linear(gate_in, fdim),\n", " nn.Sigmoid(),\n", " )\n", "\n", " # Residual\n", " if nin != nout:\n", " self._proj = nn.Linear(nin * fdim, nout * fdim)\n", " else:\n", " self._proj = None\n", "\n", " def forward(self, x):\n", " B = x.shape[0]\n", "\n", " sel = F.softmax(self._W_select, dim=-1)\n", " picked = torch.einsum('ovn,bnf->bovf', sel, x)\n", " agg = picked.mean(dim=2)\n", "\n", " # Deformation\n", " deform = self._deform(agg).view(B, self._nout, self._nv, self._edim)\n", " template_expanded = self._template.unsqueeze(0).unsqueeze(0).expand(B, self._nout, -1, -1)\n", " verts = template_expanded + self.BASE_DEFORM * deform\n", "\n", " # CM computation\n", " d2, vol2, valid = self._cm(verts)\n", "\n", " # Geometric gating\n", " if self._cm._npairs == 0:\n", " geo = torch.ones(B, self._nout, 1, device=x.device)\n", " else:\n", " d2_norm = d2 / (d2.mean(dim=-1, keepdim=True) + 1e-8)\n", " geo = torch.cat([d2_norm, vol2.unsqueeze(-1) / (vol2.mean() + 1e-8)], dim=-1)\n", " gate = self._gate(geo)\n", "\n", " out = self._mlp(agg) * gate\n", "\n", " # Validity-only attenuation\n", " attn = torch.sigmoid(vol2 * 1e6)\n", " out = out * attn.unsqueeze(-1)\n", "\n", " # Residual\n", " if self._proj is not None:\n", " res = self._proj(x.flatten(1)).view(B, self._nout, self._fdim)\n", " else:\n", " res = x\n", " out = out + 0.1 * res\n", "\n", " return out, {\n", " 'd2': d2,\n", " 'vol2': vol2,\n", " 'valid': valid,\n", " 'attn': attn,\n", " 'order': self._order,\n", " 'potential': self._potential,\n", " }\n", "\n", "\n", "# ============================================================================\n", "# COMPLEX - EXPONENTIAL GROWTH\n", "# ============================================================================\n", "\n", "class GeoComplex(nn.Module):\n", " def __init__(self, nbase, fdim, depth, edim):\n", " super().__init__()\n", " self._nbase = nbase\n", " self._fdim = fdim\n", " self._depth = depth\n", "\n", " # EXPONENTIAL GROWTH: each level doubles\n", " # Outputs become vertices for next level\n", " sizes = [nbase]\n", " for _ in range(depth):\n", " sizes.append(sizes[-1] * 2)\n", " self._sizes = sizes\n", "\n", " # Compute potential simplices at each level\n", " potentials = []\n", " for i in range(depth):\n", " k = i + 1 # simplex order\n", " nv = k + 1 # vertices needed\n", " nin = sizes[i]\n", " pot = math.comb(nin, nv) if nin >= nv else 0\n", " potentials.append(pot)\n", " self._potentials = potentials\n", "\n", " # Input projection: fdim -> nbase * fdim\n", " self._proj_in = nn.Linear(fdim, nbase * fdim)\n", "\n", " # Build levels\n", " level_list = []\n", " for i in range(depth):\n", " lv = GeoLevel(\n", " order=i + 1,\n", " nin=sizes[i],\n", " nout=sizes[i + 1],\n", " fdim=fdim,\n", " edim=edim\n", " )\n", " level_list.append(lv)\n", " self._levels = nn.ModuleList(level_list)\n", "\n", " # Level weights\n", " self.register_parameter('_lw', nn.Parameter(torch.ones(depth) / depth))\n", "\n", " def forward(self, x):\n", " B = x.shape[0]\n", " h = self._proj_in(x).view(B, self._nbase, self._fdim)\n", "\n", " outs, infos = [], []\n", " for lv in self._levels:\n", " h, info = lv(h)\n", " outs.append(h)\n", " infos.append(info)\n", "\n", " w = F.softmax(self._lw, dim=0)\n", " pooled = []\n", " for i, (o, inf) in enumerate(zip(outs, infos)):\n", " a = inf['attn']\n", " wm = (o * a.unsqueeze(-1)).sum(1) / (a.sum(1, keepdim=True) + 1e-8)\n", " pooled.append(w[i] * wm)\n", "\n", " return sum(pooled), {'sizes': self._sizes, 'potentials': self._potentials, 'infos': infos, 'weights': w.detach()}\n", "\n", "\n", "# ============================================================================\n", "# MODEL\n", "# ============================================================================\n", "\n", "class FashionGeoSimplex(nn.Module):\n", " def __init__(self, nbase=4, fdim=32, depth=5, edim=8, ncpx=2):\n", " super().__init__()\n", "\n", " self._stem = nn.Sequential(\n", " nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.GELU(),\n", " nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.BatchNorm2d(64), nn.GELU(),\n", " nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.BatchNorm2d(128), nn.GELU(),\n", " )\n", " self._pool = nn.AdaptiveAvgPool2d(1)\n", " self._proj = nn.Linear(128, fdim)\n", "\n", " cpx_list = []\n", " for i in range(ncpx):\n", " c = GeoComplex(nbase, fdim, depth, edim)\n", " cpx_list.append(c)\n", " print(f\"Complex {i+1}:\")\n", " print(f\" Sizes: {c._sizes}\")\n", " print(f\" Potentials per level:\")\n", " for k, pot in enumerate(c._potentials):\n", " print(f\" k={k+1}: C({c._sizes[k]},{k+2}) = {pot:,} potential {k+1}-simplices\")\n", " self._cpx = nn.ModuleList(cpx_list)\n", " self._norms = nn.ModuleList([nn.LayerNorm(fdim) for _ in range(ncpx)])\n", "\n", " self._head = nn.Linear(fdim, 10)\n", "\n", " def forward(self, x, ret_diag=False):\n", " h = self._stem(x)\n", " h = self._pool(h).flatten(1)\n", " h = self._proj(h)\n", "\n", " diags = []\n", " validity_loss = 0\n", " for cpx, norm in zip(self._cpx, self._norms):\n", " out, diag = cpx(h)\n", " diags.append(diag)\n", " for inf in diag['infos']:\n", " validity_loss = validity_loss + F.relu(-inf['vol2']).mean()\n", " h = norm(h + out)\n", "\n", " logits = self._head(h)\n", "\n", " if ret_diag:\n", " return logits, diags, validity_loss\n", " return logits, validity_loss\n", "\n", "\n", "# ============================================================================\n", "# TRAIN\n", "# ============================================================================\n", "\n", "def train():\n", " transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.2860,), (0.3530,))\n", " ])\n", "\n", " train_ds = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)\n", " test_ds = datasets.FashionMNIST('./data', train=False, download=True, transform=transform)\n", " train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2)\n", " test_dl = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=2)\n", "\n", " print(\"\\nBuilding model...\")\n", " model = FashionGeoSimplex(\n", " nbase=4,\n", " fdim=32,\n", " depth=5,\n", " edim=8,\n", " ncpx=2\n", " ).to(device)\n", "\n", " print(f\"\\nTotal params: {sum(p.numel() for p in model.parameters()):,}\")\n", "\n", " opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)\n", " sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=30)\n", "\n", " best = 0\n", " print(\"\\nTraining...\")\n", " print(\"=\" * 100)\n", "\n", " for ep in range(30):\n", " model.train()\n", " lsum, vsum, cor, tot = 0, 0, 0, 0\n", "\n", " for img, lab in train_dl:\n", " img, lab = img.to(device), lab.to(device)\n", "\n", " opt.zero_grad()\n", " logits, vloss = model(img)\n", " ce = F.cross_entropy(logits, lab)\n", " loss = ce + 0.1 * vloss\n", " loss.backward()\n", " opt.step()\n", "\n", " lsum += ce.item() * img.size(0)\n", " vsum += vloss.item() * img.size(0)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", "\n", " tr_acc = cor / tot\n", " tr_loss = lsum / tot\n", "\n", " model.eval()\n", " cor, tot = 0, 0\n", " samp = None\n", " with torch.no_grad():\n", " for img, lab in test_dl:\n", " img, lab = img.to(device), lab.to(device)\n", " logits, diag, _ = model(img, ret_diag=True)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", " if samp is None:\n", " samp = diag\n", "\n", " te_acc = cor / tot\n", " sched.step()\n", " if te_acc > best:\n", " best = te_acc\n", "\n", " lw = samp[0]['weights'].cpu().numpy()\n", "\n", " if ep % 5 == 0 or ep == 29:\n", " print(f\"Ep {ep+1:2d} | CE {tr_loss:.4f} | Tr {tr_acc:.2%} | Te {te_acc:.2%} | Best {best:.2%}\")\n", " print(f\" | Weights: {['%.3f' % w for w in lw]}\")\n", " for inf in samp[0]['infos']:\n", " k = inf['order']\n", " v = inf['vol2'].mean().item()\n", " vld = inf['valid'].float().mean().item()\n", " att = inf['attn'].mean().item()\n", " d = inf['d2'].mean().item() if inf['d2'].numel() > 0 else 0\n", " pot = inf['potential']\n", " print(f\" k={k}: potential={pot:>10,} vol²={v:.2e} valid={vld:5.1%} attn={att:.3f} d²={d:.4f}\")\n", "\n", " print(\"=\" * 100)\n", " print(f\"Best: {best:.2%}\")\n", " return model\n", "\n", "\n", "if __name__ == \"__main__\":\n", " model = train()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "cellView": "form", "id": "NwBIvJu0VDlc", "outputId": "87067b96-dc52-4aa5-8115-4a217467cf31" }, "execution_count": 6, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Device: cuda\n", "\n", "Building model...\n", "Complex 1:\n", " Sizes: [4, 8, 16, 32, 64, 128]\n", " Potentials per level:\n", " k=1: C(4,2) = 6 potential 1-simplices\n", " k=2: C(8,3) = 56 potential 2-simplices\n", " k=3: C(16,4) = 1,820 potential 3-simplices\n", " k=4: C(32,5) = 201,376 potential 4-simplices\n", " k=5: C(64,6) = 74,974,368 potential 5-simplices\n", "Complex 2:\n", " Sizes: [4, 8, 16, 32, 64, 128]\n", " Potentials per level:\n", " k=1: C(4,2) = 6 potential 1-simplices\n", " k=2: C(8,3) = 56 potential 2-simplices\n", " k=3: C(16,4) = 1,820 potential 3-simplices\n", " k=4: C(32,5) = 201,376 potential 4-simplices\n", " k=5: C(64,6) = 74,974,368 potential 5-simplices\n", "\n", "Total params: 22,661,428\n", "\n", "Training...\n", "====================================================================================================\n", "Ep 1 | CE 0.6211 | Tr 79.52% | Te 83.99% | Best 83.99%\n", " | Weights: ['0.213', '0.201', '0.201', '0.200', '0.184']\n", " k=1: potential= 6 vol²=8.85e-01 valid=100.0% attn=1.000 d²=0.8848\n", " k=2: potential= 56 vol²=1.95e-01 valid=100.0% attn=1.000 d²=1.0221\n", " k=3: potential= 1,820 vol²=1.29e-02 valid=100.0% attn=1.000 d²=0.9890\n", " k=4: potential= 201,376 vol²=5.25e-04 valid=100.0% attn=1.000 d²=1.0006\n", " k=5: potential=74,974,368 vol²=1.22e-05 valid=100.0% attn=1.000 d²=1.0023\n", "Ep 6 | CE 0.2642 | Tr 90.43% | Te 86.53% | Best 88.26%\n", " | Weights: ['0.209', '0.193', '0.183', '0.238', '0.178']\n", " k=1: potential= 6 vol²=2.19e-01 valid=100.0% attn=1.000 d²=0.2186\n", " k=2: potential= 56 vol²=1.85e-01 valid=100.0% attn=1.000 d²=1.0016\n", " k=3: potential= 1,820 vol²=1.27e-02 valid=100.0% attn=1.000 d²=0.9950\n", " k=4: potential= 201,376 vol²=2.90e-04 valid=100.0% attn=1.000 d²=0.9341\n", " k=5: potential=74,974,368 vol²=2.38e-06 valid=100.0% attn=0.915 d²=0.9012\n", "Ep 11 | CE 0.1861 | Tr 93.23% | Te 89.94% | Best 89.94%\n", " | Weights: ['0.202', '0.180', '0.165', '0.268', '0.185']\n", " k=1: potential= 6 vol²=9.29e-02 valid=100.0% attn=1.000 d²=0.0929\n", " k=2: potential= 56 vol²=1.98e-01 valid=100.0% attn=1.000 d²=1.0434\n", " k=3: potential= 1,820 vol²=1.31e-02 valid=100.0% attn=1.000 d²=1.0180\n", " k=4: potential= 201,376 vol²=1.26e-04 valid=100.0% attn=1.000 d²=1.1049\n", " k=5: potential=74,974,368 vol²=3.15e-08 valid=100.0% attn=0.508 d²=0.6983\n", "Ep 16 | CE 0.1142 | Tr 95.81% | Te 89.96% | Best 90.12%\n", " | Weights: ['0.203', '0.175', '0.157', '0.279', '0.186']\n", " k=1: potential= 6 vol²=8.39e-02 valid=100.0% attn=1.000 d²=0.0839\n", " k=2: potential= 56 vol²=2.04e-01 valid=100.0% attn=1.000 d²=1.0218\n", " k=3: potential= 1,820 vol²=1.30e-02 valid=100.0% attn=1.000 d²=1.0185\n", " k=4: potential= 201,376 vol²=1.72e-05 valid=100.0% attn=1.000 d²=1.1952\n", " k=5: potential=74,974,368 vol²=3.01e-08 valid=100.0% attn=0.508 d²=0.7236\n", "Ep 21 | CE 0.0577 | Tr 98.02% | Te 89.94% | Best 90.20%\n", " | Weights: ['0.205', '0.177', '0.154', '0.279', '0.185']\n", " k=1: potential= 6 vol²=8.29e-02 valid=100.0% attn=1.000 d²=0.0829\n", " k=2: potential= 56 vol²=2.08e-01 valid=100.0% attn=1.000 d²=1.0266\n", " k=3: potential= 1,820 vol²=1.30e-02 valid=100.0% attn=1.000 d²=1.0203\n", " k=4: potential= 201,376 vol²=8.19e-06 valid=100.0% attn=0.979 d²=1.2106\n", " k=5: potential=74,974,368 vol²=3.26e-08 valid=100.0% attn=0.508 d²=0.7480\n", "Ep 26 | CE 0.0210 | Tr 99.36% | Te 90.23% | Best 90.23%\n", " | Weights: ['0.207', '0.182', '0.154', '0.275', '0.182']\n", " k=1: potential= 6 vol²=6.87e-02 valid=100.0% attn=1.000 d²=0.0687\n", " k=2: potential= 56 vol²=1.84e-01 valid=100.0% attn=1.000 d²=0.9339\n", " k=3: potential= 1,820 vol²=1.36e-02 valid=100.0% attn=1.000 d²=1.0209\n", " k=4: potential= 201,376 vol²=1.38e-06 valid=100.0% attn=0.689 d²=1.1003\n", " k=5: potential=74,974,368 vol²=3.84e-08 valid=100.0% attn=0.510 d²=0.7503\n", "Ep 30 | CE 0.0113 | Tr 99.71% | Te 90.19% | Best 90.23%\n", " | Weights: ['0.207', '0.182', '0.154', '0.274', '0.182']\n", " k=1: potential= 6 vol²=6.02e-02 valid=100.0% attn=1.000 d²=0.0602\n", " k=2: potential= 56 vol²=1.92e-01 valid=100.0% attn=1.000 d²=0.9515\n", " k=3: potential= 1,820 vol²=1.37e-02 valid=100.0% attn=1.000 d²=1.0220\n", " k=4: potential= 201,376 vol²=1.98e-06 valid=100.0% attn=0.754 d²=1.0963\n", " k=5: potential=74,974,368 vol²=4.16e-08 valid=100.0% attn=0.510 d²=0.7500\n", "====================================================================================================\n", "Best: 90.23%\n" ] } ] }, { "cell_type": "code", "source": [ "#@title Geometric Patchwork ViT - Full K-Simplex Hierarchy\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets, transforms\n", "import math\n", "from itertools import combinations\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print(f\"Device: {device}\")\n", "\n", "from geovocab2.shapes.factory.simplex_factory import SimplexFactory\n", "\n", "# ============================================================================\n", "# CAYLEY-MENGER VALIDATOR\n", "# ============================================================================\n", "\n", "class CMValidator(nn.Module):\n", " def __init__(self, k):\n", " super().__init__()\n", " self._k = k\n", " self._nv = k + 1\n", "\n", " pairs = list(combinations(range(self._nv), 2))\n", " self._npairs = len(pairs)\n", " self.register_buffer('_pi', torch.tensor([p[0] for p in pairs], dtype=torch.long))\n", " self.register_buffer('_pj', torch.tensor([p[1] for p in pairs], dtype=torch.long))\n", "\n", " sign = (-1.0) ** (k + 1)\n", " fact = math.factorial(k)\n", " self._prefactor = sign / ((2.0 ** k) * (fact ** 2))\n", "\n", " def forward(self, verts):\n", " gram = torch.einsum('...ve,...we->...vw', verts, verts)\n", " norms = torch.diagonal(gram, dim1=-2, dim2=-1)\n", " d2_mat = norms.unsqueeze(-1) + norms.unsqueeze(-2) - 2 * gram\n", " d2_mat = F.relu(d2_mat)\n", "\n", " d2_pairs = d2_mat[..., self._pi, self._pj]\n", "\n", " shape = d2_mat.shape[:-2]\n", " V = d2_mat.shape[-1]\n", " cm = torch.zeros(*shape, V+1, V+1, device=d2_mat.device, dtype=d2_mat.dtype)\n", " cm[..., 0, 1:] = 1.0\n", " cm[..., 1:, 0] = 1.0\n", " cm[..., 1:, 1:] = d2_mat\n", " det = torch.linalg.det(cm)\n", "\n", " vol2 = self._prefactor * det\n", " valid = vol2 > 1e-8\n", "\n", " return d2_pairs, vol2, valid\n", "\n", "\n", "# ============================================================================\n", "# GEOMETRIC LEVEL (k-simplex layer with factory template)\n", "# ============================================================================\n", "\n", "class GeoLevel(nn.Module):\n", " BASE_DEFORM = 0.05\n", "\n", " def __init__(self, order, nin, nout, fdim, edim):\n", " super().__init__()\n", " self._order = order\n", " self._nin = nin\n", " self._nout = nout\n", " self._fdim = fdim\n", " self._edim = edim\n", " self._nv = order + 1\n", "\n", " self._potential = math.comb(nin, self._nv) if nin >= self._nv else 0\n", "\n", " self._cm = CMValidator(order)\n", "\n", " # Factory template - guaranteed valid\n", " factory = SimplexFactory(k=order, embed_dim=edim, method=\"regular\", scale=1.0)\n", " template = factory.build_torch(dtype=torch.float32)\n", " self.register_buffer('_template', template)\n", "\n", " # Vertex selection: soft-select nv vertices from nin inputs\n", " _sel = torch.randn(nout, self._nv, nin) * 0.1\n", " self.register_parameter('_W_select', nn.Parameter(_sel))\n", "\n", " # Deformation network\n", " self._deform = nn.Sequential(\n", " nn.Linear(fdim, fdim),\n", " nn.LayerNorm(fdim),\n", " nn.GELU(),\n", " nn.Linear(fdim, self._nv * edim),\n", " )\n", "\n", " # Feature MLP\n", " self._mlp = nn.Sequential(\n", " nn.Linear(fdim, fdim * 2),\n", " nn.LayerNorm(fdim * 2),\n", " nn.GELU(),\n", " nn.Linear(fdim * 2, fdim),\n", " )\n", "\n", " # Geometry gate\n", " gate_in = self._cm._npairs + 1\n", " self._gate = nn.Sequential(\n", " nn.Linear(gate_in, fdim),\n", " nn.Sigmoid(),\n", " )\n", "\n", " # Residual projection\n", " if nin != nout:\n", " self._proj = nn.Linear(nin * fdim, nout * fdim)\n", " else:\n", " self._proj = None\n", "\n", " def forward(self, x):\n", " \"\"\"x: [B, nin, fdim]\"\"\"\n", " B = x.shape[0]\n", "\n", " sel = F.softmax(self._W_select, dim=-1)\n", " picked = torch.einsum('ovn,bnf->bovf', sel, x)\n", " agg = picked.mean(dim=2)\n", "\n", " # Deformation from template\n", " deform = self._deform(agg).view(B, self._nout, self._nv, self._edim)\n", " template = self._template.unsqueeze(0).unsqueeze(0).expand(B, self._nout, -1, -1)\n", " verts = template + self.BASE_DEFORM * deform\n", "\n", " # CM validation\n", " d2, vol2, valid = self._cm(verts)\n", "\n", " # Geometry gating\n", " d2_norm = d2 / (d2.mean(dim=-1, keepdim=True) + 1e-8)\n", " geo = torch.cat([d2_norm, vol2.unsqueeze(-1) / (vol2.mean() + 1e-8)], dim=-1)\n", " gate = self._gate(geo)\n", "\n", " out = self._mlp(agg) * gate\n", "\n", " # Validity-only attenuation\n", " attn = torch.sigmoid(vol2 * 1e6)\n", " out = out * attn.unsqueeze(-1)\n", "\n", " # Residual\n", " if self._proj is not None:\n", " res = self._proj(x.flatten(1)).view(B, self._nout, self._fdim)\n", " else:\n", " res = x\n", " out = out + 0.1 * res\n", "\n", " return out, {\n", " 'd2': d2, 'vol2': vol2, 'valid': valid, 'attn': attn,\n", " 'order': self._order, 'potential': self._potential,\n", " }\n", "\n", "\n", "# ============================================================================\n", "# GEOMETRIC COMPLEX (k=1→2→3→... with exponential growth)\n", "# ============================================================================\n", "\n", "class GeoComplex(nn.Module):\n", " def __init__(self, nbase, fdim, depth, edim):\n", " super().__init__()\n", " self._nbase = nbase\n", " self._fdim = fdim\n", " self._depth = depth\n", "\n", " # EXPONENTIAL GROWTH\n", " sizes = [nbase]\n", " for _ in range(depth):\n", " sizes.append(sizes[-1] * 2)\n", " self._sizes = sizes\n", "\n", " # Potentials per level\n", " potentials = []\n", " for i in range(depth):\n", " k = i + 1\n", " nv = k + 1\n", " nin = sizes[i]\n", " pot = math.comb(nin, nv) if nin >= nv else 0\n", " potentials.append(pot)\n", " self._potentials = potentials\n", "\n", " # Input projection\n", " self._proj_in = nn.Linear(fdim, nbase * fdim)\n", "\n", " # Build levels k=1, 2, 3, ...\n", " level_list = []\n", " for i in range(depth):\n", " lv = GeoLevel(\n", " order=i + 1,\n", " nin=sizes[i],\n", " nout=sizes[i + 1],\n", " fdim=fdim,\n", " edim=edim\n", " )\n", " level_list.append(lv)\n", " self._levels = nn.ModuleList(level_list)\n", "\n", " # Level weights\n", " self.register_parameter('_lw', nn.Parameter(torch.ones(depth) / depth))\n", "\n", " def forward(self, x):\n", " \"\"\"x: [B, fdim]\"\"\"\n", " B = x.shape[0]\n", " h = self._proj_in(x).view(B, self._nbase, self._fdim)\n", "\n", " outs, infos = [], []\n", " for lv in self._levels:\n", " h, info = lv(h)\n", " outs.append(h)\n", " infos.append(info)\n", "\n", " # Weighted pool across levels\n", " w = F.softmax(self._lw, dim=0)\n", " pooled = []\n", " for i, (o, inf) in enumerate(zip(outs, infos)):\n", " a = inf['attn']\n", " wm = (o * a.unsqueeze(-1)).sum(1) / (a.sum(1, keepdim=True) + 1e-8)\n", " pooled.append(w[i] * wm)\n", "\n", " return sum(pooled), {\n", " 'sizes': self._sizes,\n", " 'potentials': self._potentials,\n", " 'infos': infos,\n", " 'weights': w.detach(),\n", " }\n", "\n", "\n", "# ============================================================================\n", "# GEOMETRIC PATCH EMBEDDING (each patch → k-simplex hierarchy)\n", "# ============================================================================\n", "\n", "class GeoPatchEmbed(nn.Module):\n", " \"\"\"\n", " Each patch processed through full k-simplex hierarchy.\n", " \"\"\"\n", " def __init__(self, img_size=28, patch_size=7, in_chans=1, embed_dim=64,\n", " nbase=4, depth=4, edim=8):\n", " super().__init__()\n", " self.img_size = img_size\n", " self.patch_size = patch_size\n", " self.num_patches = (img_size // patch_size) ** 2\n", " self.patch_dim = patch_size * patch_size * in_chans\n", " self.embed_dim = embed_dim\n", "\n", " # Patch to feature dim\n", " self._patch_proj = nn.Sequential(\n", " nn.Linear(self.patch_dim, embed_dim),\n", " nn.LayerNorm(embed_dim),\n", " nn.GELU(),\n", " )\n", "\n", " # Geometric complex per patch\n", " self._complex = GeoComplex(nbase, embed_dim, depth, edim)\n", "\n", " # CLS token\n", " self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim) * 0.02)\n", "\n", " # Position embedding\n", " self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim) * 0.02)\n", "\n", " def forward(self, x):\n", " \"\"\"x: [B, C, H, W]\"\"\"\n", " B, C, H, W = x.shape\n", "\n", " # Extract patches\n", " patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)\n", " patches = patches.contiguous().view(B, C, -1, self.patch_size * self.patch_size)\n", " patches = patches.permute(0, 2, 1, 3).contiguous().view(B, self.num_patches, -1)\n", "\n", " # Project patches\n", " patch_feats = self._patch_proj(patches) # [B, P, embed_dim]\n", "\n", " # Process each patch through geometric complex\n", " # Reshape for batch processing: [B*P, embed_dim]\n", " flat = patch_feats.view(B * self.num_patches, self.embed_dim)\n", " geo_out, geo_info = self._complex(flat)\n", " geo_out = geo_out.view(B, self.num_patches, self.embed_dim)\n", "\n", " # Add CLS token\n", " cls = self.cls_token.expand(B, -1, -1)\n", " tokens = torch.cat([cls, geo_out], dim=1)\n", "\n", " # Add position embedding\n", " tokens = tokens + self.pos_embed\n", "\n", " return tokens, geo_info\n", "\n", "\n", "# ============================================================================\n", "# 4-SIMPLEX ATTENTION (sequence stream)\n", "# ============================================================================\n", "\n", "class Simplex4Attention(nn.Module):\n", " \"\"\"\n", " 4-simplex attention: each head soft-selects 5 tokens to form a 4-simplex.\n", " CM-validated geometry determines attention.\n", " \"\"\"\n", " def __init__(self, embed_dim, num_heads=4, edim=8):\n", " super().__init__()\n", " self.embed_dim = embed_dim\n", " self.num_heads = num_heads\n", " self.head_dim = embed_dim // num_heads\n", " self.edim = edim\n", "\n", " self._k = 4\n", " self._nv = 5\n", "\n", " self._cm = CMValidator(self._k)\n", "\n", " # Factory template\n", " factory = SimplexFactory(k=self._k, embed_dim=edim, method=\"regular\", scale=1.0)\n", " template = factory.build_torch(dtype=torch.float32)\n", " self.register_buffer('_template', template)\n", "\n", " # Token to vertex selection scores (per head)\n", " self._to_vertex_scores = nn.Linear(embed_dim, num_heads * self._nv)\n", "\n", " # Token to coordinate for deformation\n", " self._to_coord = nn.Linear(embed_dim, edim)\n", "\n", " # Value projection\n", " self._to_v = nn.Linear(embed_dim, embed_dim)\n", "\n", " # Geometry to output gate\n", " geom_dim = self._cm._npairs + 1 # 10 + 1 = 11\n", " self._geo_gate = nn.Sequential(\n", " nn.Linear(geom_dim, num_heads),\n", " nn.Sigmoid(),\n", " )\n", "\n", " # Output projection\n", " self._out = nn.Linear(embed_dim, embed_dim)\n", "\n", " self._deform_scale = 0.05\n", "\n", " def forward(self, x):\n", " \"\"\"x: [B, N, D]\"\"\"\n", " B, N, D = x.shape\n", " H = self.num_heads\n", "\n", " # Vertex selection: [B, N, H*5] -> [B, H, 5, N]\n", " vs = self._to_vertex_scores(x).view(B, N, H, self._nv).permute(0, 2, 3, 1)\n", " vw = F.softmax(vs, dim=-1) # [B, H, 5, N]\n", "\n", " # Coordinates\n", " coords = self._to_coord(x) # [B, N, edim]\n", "\n", " # Select coordinates: [B, H, 5, edim]\n", " sel_coords = torch.einsum('bhvn,bne->bhve', vw, coords)\n", "\n", " # Deform template\n", " template = self._template.unsqueeze(0).unsqueeze(0)\n", " verts = template + self._deform_scale * sel_coords\n", "\n", " # CM validation\n", " d2, vol2, valid = self._cm(verts)\n", "\n", " # Validity per head\n", " head_valid = torch.sigmoid(vol2 * 1e6) # [B, H]\n", "\n", " # Geometry gate\n", " geo = torch.cat([d2, vol2.unsqueeze(-1)], dim=-1)\n", " geo_gate = self._geo_gate(geo).mean(dim=-1) # [B, H]\n", " head_weight = geo_gate * head_valid\n", " head_weight = head_weight / (head_weight.sum(dim=-1, keepdim=True) + 1e-8)\n", "\n", " # Values\n", " v = self._to_v(x).view(B, N, H, self.head_dim).permute(0, 2, 1, 3) # [B, H, N, hd]\n", "\n", " # Attention from vertex weights (average over 5 positions)\n", " attn = vw.mean(dim=2) # [B, H, N]\n", "\n", " # Aggregate\n", " out = torch.einsum('bhn,bhnd->bhd', attn, v) # [B, H, hd]\n", " out = out * head_weight.unsqueeze(-1)\n", " out = out.reshape(B, D)\n", "\n", " # Project and broadcast\n", " out = self._out(out).unsqueeze(1).expand(-1, N, -1)\n", "\n", " return out, {\n", " 'd2': d2, 'vol2': vol2, 'valid': valid,\n", " 'head_weight': head_weight, 'vertex_weights': vw,\n", " }\n", "\n", "\n", "# ============================================================================\n", "# TRANSFORMER BLOCK\n", "# ============================================================================\n", "\n", "class GeoTransformerBlock(nn.Module):\n", " def __init__(self, embed_dim, num_heads, edim, mlp_ratio=4.0):\n", " super().__init__()\n", " self.norm1 = nn.LayerNorm(embed_dim)\n", " self.attn = Simplex4Attention(embed_dim, num_heads, edim)\n", " self.norm2 = nn.LayerNorm(embed_dim)\n", " self.mlp = nn.Sequential(\n", " nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),\n", " nn.GELU(),\n", " nn.Linear(int(embed_dim * mlp_ratio), embed_dim),\n", " )\n", "\n", " def forward(self, x):\n", " attn_out, attn_info = self.attn(self.norm1(x))\n", " x = x + attn_out\n", " x = x + self.mlp(self.norm2(x))\n", " return x, attn_info\n", "\n", "\n", "# ============================================================================\n", "# FULL MODEL\n", "# ============================================================================\n", "\n", "class GeometricPatchworkViT(nn.Module):\n", " def __init__(\n", " self,\n", " img_size=28,\n", " patch_size=7,\n", " in_chans=1,\n", " num_classes=10,\n", " embed_dim=64,\n", " nbase=4,\n", " geo_depth=4, # k=1,2,3,4 hierarchy\n", " attn_depth=4, # transformer blocks\n", " num_heads=4,\n", " edim=8,\n", " ):\n", " super().__init__()\n", "\n", " # Geometric patch embedding with k-simplex hierarchy\n", " self.patch_embed = GeoPatchEmbed(\n", " img_size=img_size,\n", " patch_size=patch_size,\n", " in_chans=in_chans,\n", " embed_dim=embed_dim,\n", " nbase=nbase,\n", " depth=geo_depth,\n", " edim=edim,\n", " )\n", "\n", " # Transformer blocks with 4-simplex attention\n", " self.blocks = nn.ModuleList([\n", " GeoTransformerBlock(embed_dim, num_heads, edim)\n", " for _ in range(attn_depth)\n", " ])\n", "\n", " self.norm = nn.LayerNorm(embed_dim)\n", " self.head = nn.Linear(embed_dim, num_classes)\n", "\n", " # Print architecture\n", " print(f\"\\nGeometricPatchworkViT:\")\n", " print(f\" Patches: {self.patch_embed.num_patches} × {patch_size}×{patch_size}\")\n", " print(f\" K-simplex hierarchy per patch:\")\n", " print(f\" Sizes: {self.patch_embed._complex._sizes}\")\n", " print(f\" Potentials: {self.patch_embed._complex._potentials}\")\n", " print(f\" 4-simplex attention: {attn_depth} blocks × {num_heads} heads\")\n", "\n", " def forward(self, x, ret_diag=False):\n", " # Patch embedding through k-simplex hierarchy\n", " tokens, patch_geo_info = self.patch_embed(x)\n", "\n", " # Transformer with 4-simplex attention\n", " block_infos = []\n", " validity_loss = 0\n", "\n", " for block in self.blocks:\n", " tokens, attn_info = block(tokens)\n", " block_infos.append(attn_info)\n", " validity_loss = validity_loss + F.relu(-attn_info['vol2']).mean()\n", "\n", " # Add patch hierarchy validity loss\n", " for inf in patch_geo_info['infos']:\n", " validity_loss = validity_loss + F.relu(-inf['vol2']).mean()\n", "\n", " # CLS for classification\n", " tokens = self.norm(tokens)\n", " cls = tokens[:, 0]\n", " logits = self.head(cls)\n", "\n", " if ret_diag:\n", " return logits, {'patch': patch_geo_info, 'blocks': block_infos}, validity_loss\n", " return logits, validity_loss\n", "\n", "\n", "# ============================================================================\n", "# TRAIN\n", "# ============================================================================\n", "\n", "def train():\n", " transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.2860,), (0.3530,))\n", " ])\n", "\n", " train_ds = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)\n", " test_ds = datasets.FashionMNIST('./data', train=False, download=True, transform=transform)\n", " train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2)\n", " test_dl = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=2)\n", "\n", " print(\"Building model...\")\n", " model = GeometricPatchworkViT(\n", " img_size=28,\n", " patch_size=2,\n", " in_chans=1,\n", " num_classes=10,\n", " embed_dim=8,\n", " nbase=4,\n", " geo_depth=4, # k=1→2→3→4 hierarchy\n", " attn_depth=4, # 4 transformer blocks\n", " num_heads=4,\n", " edim=8,\n", " ).to(device)\n", "\n", " params = sum(p.numel() for p in model.parameters())\n", " print(f\"Total params: {params:,}\")\n", "\n", " opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)\n", " sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=30)\n", "\n", " best = 0\n", " print(\"\\nTraining...\")\n", " print(\"=\" * 115)\n", "\n", " for ep in range(30):\n", " model.train()\n", " lsum, vsum, cor, tot = 0, 0, 0, 0\n", "\n", " for img, lab in train_dl:\n", " img, lab = img.to(device), lab.to(device)\n", "\n", " opt.zero_grad()\n", " logits, vloss = model(img)\n", " ce = F.cross_entropy(logits, lab)\n", " loss = ce + 0.1 * vloss\n", " loss.backward()\n", " opt.step()\n", "\n", " lsum += ce.item() * img.size(0)\n", " vsum += vloss.item() * img.size(0)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", "\n", " tr_acc = cor / tot\n", " tr_loss = lsum / tot\n", "\n", " model.eval()\n", " cor, tot = 0, 0\n", " samp = None\n", " with torch.no_grad():\n", " for img, lab in test_dl:\n", " img, lab = img.to(device), lab.to(device)\n", " logits, diag, _ = model(img, ret_diag=True)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", " if samp is None:\n", " samp = diag\n", "\n", " te_acc = cor / tot\n", " sched.step()\n", " if te_acc > best:\n", " best = te_acc\n", "\n", " if ep % 5 == 0 or ep == 29:\n", " print(f\"\\nEp {ep+1:2d} | CE {tr_loss:.4f} | Tr {tr_acc:.2%} | Te {te_acc:.2%} | Best {best:.2%}\")\n", "\n", " # Patch hierarchy stats\n", " pw = samp['patch']['weights'].cpu().numpy()\n", " print(f\" Patch k-hierarchy weights: {['%.3f' % w for w in pw]}\")\n", " for inf in samp['patch']['infos']:\n", " k = inf['order']\n", " v = inf['vol2'].mean().item()\n", " vld = inf['valid'].float().mean().item()\n", " att = inf['attn'].mean().item()\n", " d = inf['d2'].mean().item()\n", " pot = inf['potential']\n", " print(f\" k={k}: pot={pot:>8,} vol²={v:.2e} valid={vld:5.1%} attn={att:.3f} d²={d:.4f}\")\n", "\n", " # Attention block stats\n", " for i, bi in enumerate(samp['blocks']):\n", " v = bi['vol2'].mean().item()\n", " vld = bi['valid'].float().mean().item()\n", " hw = bi['head_weight'].mean(dim=0).cpu().numpy()\n", " print(f\" Block {i+1} 4-simplex: vol²={v:.2e} valid={vld:5.1%} heads={['%.2f' % h for h in hw]}\")\n", "\n", " print(\"=\" * 115)\n", " print(f\"Best: {best:.2%}\")\n", " return model\n", "\n", "\n", "if __name__ == \"__main__\":\n", " model = train()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "cellView": "form", "id": "GUzM109PYxIE", "outputId": "a57d14dc-3503-46a5-b01a-c9f70ba0a1c0" }, "execution_count": 8, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Device: cuda\n", "Building model...\n", "\n", "GeometricPatchworkViT:\n", " Patches: 16 × 7×7\n", " K-simplex hierarchy per patch:\n", " Sizes: [4, 8, 16, 32, 64]\n", " Potentials: [6, 56, 1820, 201376]\n", " 4-simplex attention: 4 blocks × 4 heads\n", "Total params: 11,451,118\n", "\n", "Training...\n", "===================================================================================================================\n", "\n", "Ep 1 | CE 0.8111 | Tr 70.79% | Te 80.51% | Best 80.51%\n", " Patch k-hierarchy weights: ['0.287', '0.260', '0.238', '0.215']\n", " k=1: pot= 6 vol²=8.99e-01 valid=100.0% attn=1.000 d²=0.8991\n", " k=2: pot= 56 vol²=1.84e-01 valid=100.0% attn=1.000 d²=0.9894\n", " k=3: pot= 1,820 vol²=1.21e-02 valid=100.0% attn=1.000 d²=0.9959\n", " k=4: pot= 201,376 vol²=5.20e-04 valid=100.0% attn=1.000 d²=1.0052\n", " Block 1 4-simplex: vol²=5.32e-04 valid=100.0% heads=['0.25', '0.25', '0.25', '0.25']\n", " Block 2 4-simplex: vol²=5.84e-04 valid=100.0% heads=['0.25', '0.25', '0.24', '0.26']\n", " Block 3 4-simplex: vol²=5.32e-04 valid=100.0% heads=['0.25', '0.25', '0.25', '0.25']\n", " Block 4 4-simplex: vol²=5.68e-04 valid=100.0% heads=['0.25', '0.26', '0.24', '0.25']\n", "\n", "Ep 6 | CE 0.3254 | Tr 88.00% | Te 86.47% | Best 86.47%\n", " Patch k-hierarchy weights: ['0.341', '0.247', '0.237', '0.174']\n", " k=1: pot= 6 vol²=7.11e-01 valid=100.0% attn=1.000 d²=0.7110\n", " k=2: pot= 56 vol²=1.39e-01 valid=100.0% attn=1.000 d²=0.9200\n", " k=3: pot= 1,820 vol²=8.78e-03 valid=100.0% attn=1.000 d²=0.9587\n", " k=4: pot= 201,376 vol²=3.98e-04 valid=100.0% attn=1.000 d²=0.9743\n", " Block 1 4-simplex: vol²=5.63e-04 valid=100.0% heads=['0.26', '0.24', '0.24', '0.26']\n", " Block 2 4-simplex: vol²=8.67e-04 valid=100.0% heads=['0.23', '0.23', '0.18', '0.36']\n", " Block 3 4-simplex: vol²=5.51e-04 valid=100.0% heads=['0.26', '0.25', '0.24', '0.25']\n", " Block 4 4-simplex: vol²=8.45e-04 valid=100.0% heads=['0.24', '0.34', '0.23', '0.19']\n", "\n", "Ep 11 | CE 0.2659 | Tr 90.07% | Te 88.09% | Best 88.09%\n", " Patch k-hierarchy weights: ['0.387', '0.240', '0.228', '0.146']\n", " k=1: pot= 6 vol²=4.17e-01 valid=100.0% attn=1.000 d²=0.4174\n", " k=2: pot= 56 vol²=9.97e-02 valid=100.0% attn=1.000 d²=0.7037\n", " k=3: pot= 1,820 vol²=8.29e-03 valid=100.0% attn=1.000 d²=0.8660\n", " k=4: pot= 201,376 vol²=4.20e-04 valid=100.0% attn=1.000 d²=0.9698\n", " Block 1 4-simplex: vol²=5.85e-04 valid=100.0% heads=['0.26', '0.23', '0.25', '0.25']\n", " Block 2 4-simplex: vol²=9.18e-04 valid=100.0% heads=['0.22', '0.25', '0.16', '0.37']\n", " Block 3 4-simplex: vol²=6.01e-04 valid=100.0% heads=['0.27', '0.24', '0.22', '0.26']\n", " Block 4 4-simplex: vol²=1.26e-03 valid=100.0% heads=['0.16', '0.44', '0.20', '0.20']\n", "\n", "Ep 16 | CE 0.2147 | Tr 91.95% | Te 88.97% | Best 88.97%\n", " Patch k-hierarchy weights: ['0.417', '0.237', '0.219', '0.127']\n", " k=1: pot= 6 vol²=2.79e-01 valid=100.0% attn=1.000 d²=0.2792\n", " k=2: pot= 56 vol²=5.88e-02 valid=100.0% attn=1.000 d²=0.6106\n", " k=3: pot= 1,820 vol²=7.82e-03 valid=100.0% attn=1.000 d²=0.7911\n", " k=4: pot= 201,376 vol²=2.37e-04 valid=100.0% attn=1.000 d²=0.8056\n", " Block 1 4-simplex: vol²=5.91e-04 valid=100.0% heads=['0.26', '0.23', '0.25', '0.25']\n", " Block 2 4-simplex: vol²=9.90e-04 valid=100.0% heads=['0.21', '0.29', '0.14', '0.37']\n", " Block 3 4-simplex: vol²=6.33e-04 valid=100.0% heads=['0.25', '0.25', '0.25', '0.25']\n", " Block 4 4-simplex: vol²=1.48e-03 valid=100.0% heads=['0.14', '0.46', '0.19', '0.21']\n", "\n", "Ep 21 | CE 0.1570 | Tr 94.20% | Te 90.06% | Best 90.06%\n", " Patch k-hierarchy weights: ['0.421', '0.237', '0.217', '0.125']\n", " k=1: pot= 6 vol²=2.33e-01 valid=100.0% attn=1.000 d²=0.2326\n", " k=2: pot= 56 vol²=4.50e-02 valid=100.0% attn=1.000 d²=0.5084\n", " k=3: pot= 1,820 vol²=5.98e-03 valid=100.0% attn=1.000 d²=0.6563\n", " k=4: pot= 201,376 vol²=2.94e-04 valid=100.0% attn=1.000 d²=0.7931\n", " Block 1 4-simplex: vol²=6.10e-04 valid=100.0% heads=['0.26', '0.23', '0.26', '0.25']\n", " Block 2 4-simplex: vol²=9.67e-04 valid=100.0% heads=['0.21', '0.29', '0.14', '0.36']\n", " Block 3 4-simplex: vol²=6.64e-04 valid=100.0% heads=['0.26', '0.24', '0.24', '0.25']\n", " Block 4 4-simplex: vol²=1.51e-03 valid=100.0% heads=['0.13', '0.46', '0.20', '0.20']\n", "\n", "Ep 26 | CE 0.1012 | Tr 96.44% | Te 89.76% | Best 90.06%\n", " Patch k-hierarchy weights: ['0.413', '0.238', '0.219', '0.130']\n", " k=1: pot= 6 vol²=2.01e-01 valid=100.0% attn=1.000 d²=0.2010\n", " k=2: pot= 56 vol²=3.48e-02 valid=100.0% attn=1.000 d²=0.4850\n", " k=3: pot= 1,820 vol²=5.92e-03 valid=100.0% attn=1.000 d²=0.5973\n", " k=4: pot= 201,376 vol²=2.71e-04 valid=100.0% attn=1.000 d²=0.7089\n", " Block 1 4-simplex: vol²=5.99e-04 valid=100.0% heads=['0.26', '0.23', '0.26', '0.25']\n", " Block 2 4-simplex: vol²=9.32e-04 valid=100.0% heads=['0.22', '0.30', '0.13', '0.35']\n", " Block 3 4-simplex: vol²=7.11e-04 valid=100.0% heads=['0.25', '0.25', '0.26', '0.25']\n", " Block 4 4-simplex: vol²=1.49e-03 valid=100.0% heads=['0.14', '0.43', '0.20', '0.24']\n", "\n", "Ep 30 | CE 0.0801 | Tr 97.38% | Te 89.91% | Best 90.06%\n", " Patch k-hierarchy weights: ['0.411', '0.238', '0.220', '0.131']\n", " k=1: pot= 6 vol²=1.93e-01 valid=100.0% attn=1.000 d²=0.1933\n", " k=2: pot= 56 vol²=3.37e-02 valid=100.0% attn=1.000 d²=0.4832\n", " k=3: pot= 1,820 vol²=5.55e-03 valid=100.0% attn=1.000 d²=0.5836\n", " k=4: pot= 201,376 vol²=2.61e-04 valid=100.0% attn=0.999 d²=0.6983\n", " Block 1 4-simplex: vol²=6.01e-04 valid=100.0% heads=['0.26', '0.23', '0.26', '0.25']\n", " Block 2 4-simplex: vol²=9.30e-04 valid=100.0% heads=['0.22', '0.30', '0.13', '0.35']\n", " Block 3 4-simplex: vol²=7.22e-04 valid=100.0% heads=['0.25', '0.25', '0.25', '0.25']\n", " Block 4 4-simplex: vol²=1.46e-03 valid=100.0% heads=['0.14', '0.42', '0.20', '0.24']\n", "===================================================================================================================\n", "Best: 90.06%\n" ] } ] }, { "cell_type": "code", "source": [ "#@title Geometric Patchwork ViT - Proper Simplex Patch Structure\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets, transforms\n", "import math\n", "from itertools import combinations\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print(f\"Device: {device}\")\n", "\n", "from geovocab2.shapes.factory.simplex_factory import SimplexFactory\n", "\n", "# ============================================================================\n", "# CAYLEY-MENGER VALIDATOR\n", "# ============================================================================\n", "\n", "class CMValidator(nn.Module):\n", " def __init__(self, k):\n", " super().__init__()\n", " self._k = k\n", " self._nv = k + 1\n", "\n", " pairs = list(combinations(range(self._nv), 2))\n", " self._npairs = len(pairs)\n", " self.register_buffer('_pi', torch.tensor([p[0] for p in pairs], dtype=torch.long))\n", " self.register_buffer('_pj', torch.tensor([p[1] for p in pairs], dtype=torch.long))\n", "\n", " sign = (-1.0) ** (k + 1)\n", " fact = math.factorial(k)\n", " self._prefactor = sign / ((2.0 ** k) * (fact ** 2))\n", "\n", " def forward(self, verts):\n", " gram = torch.einsum('...ve,...we->...vw', verts, verts)\n", " norms = torch.diagonal(gram, dim1=-2, dim2=-1)\n", " d2_mat = norms.unsqueeze(-1) + norms.unsqueeze(-2) - 2 * gram\n", " d2_mat = F.relu(d2_mat)\n", "\n", " d2_pairs = d2_mat[..., self._pi, self._pj]\n", "\n", " shape = d2_mat.shape[:-2]\n", " V = d2_mat.shape[-1]\n", " cm = torch.zeros(*shape, V+1, V+1, device=d2_mat.device, dtype=d2_mat.dtype)\n", " cm[..., 0, 1:] = 1.0\n", " cm[..., 1:, 0] = 1.0\n", " cm[..., 1:, 1:] = d2_mat\n", " det = torch.linalg.det(cm)\n", "\n", " vol2 = self._prefactor * det\n", " valid = vol2 > 1e-8\n", "\n", " return d2_pairs, vol2, valid\n", "\n", "\n", "# ============================================================================\n", "# 3-SIMPLEX PATCH: Each patch IS a tetrahedron\n", "# ============================================================================\n", "\n", "class SimplexPatch(nn.Module):\n", " \"\"\"\n", " Each image patch is represented as a valid 3-simplex (tetrahedron).\n", "\n", " Structure per patch:\n", " - 4 vertices in R^edim (the tetrahedron)\n", " - 6 edge distances (d²)\n", " - 1 volume (vol²)\n", " - 4 vertex features in R^fdim\n", "\n", " The tetrahedron IS the patch representation, not a flattened summary.\n", " \"\"\"\n", " def __init__(self, patch_dim, fdim, edim):\n", " super().__init__()\n", " self._k = 3 # tetrahedron\n", " self._nv = 4\n", " self._fdim = fdim\n", " self._edim = edim\n", "\n", " self._cm = CMValidator(self._k)\n", "\n", " # Factory template\n", " factory = SimplexFactory(k=self._k, embed_dim=edim, method=\"regular\", scale=1.0)\n", " template = factory.build_torch(dtype=torch.float32)\n", " self.register_buffer('_template', template) # [4, edim]\n", "\n", " # Patch pixels → 4 vertex features\n", " self._to_vert_feats = nn.Sequential(\n", " nn.Linear(patch_dim, fdim * 2),\n", " nn.LayerNorm(fdim * 2),\n", " nn.GELU(),\n", " nn.Linear(fdim * 2, self._nv * fdim),\n", " )\n", "\n", " # Vertex features → deformation of template\n", " self._to_deform = nn.Linear(fdim, edim)\n", "\n", " self._deform_scale = 0.05\n", "\n", " def forward(self, patches):\n", " \"\"\"\n", " patches: [B, P, patch_dim]\n", "\n", " Returns:\n", " verts: [B, P, 4, edim] - tetrahedron vertices per patch\n", " feats: [B, P, 4, fdim] - features at each vertex\n", " d2: [B, P, 6] - squared edge distances\n", " vol2: [B, P] - squared volume\n", " valid: [B, P] - validity mask\n", " \"\"\"\n", " B, P, _ = patches.shape\n", "\n", " # Patch → 4 vertex features\n", " vert_feats = self._to_vert_feats(patches) # [B, P, 4*fdim]\n", " vert_feats = vert_feats.view(B, P, self._nv, self._fdim) # [B, P, 4, fdim]\n", "\n", " # Vertex features → deformation\n", " deform = self._to_deform(vert_feats) # [B, P, 4, edim]\n", "\n", " # Apply to template\n", " template = self._template.unsqueeze(0).unsqueeze(0) # [1, 1, 4, edim]\n", " verts = template + self._deform_scale * deform # [B, P, 4, edim]\n", "\n", " # CM validation\n", " d2, vol2, valid = self._cm(verts)\n", "\n", " return verts, vert_feats, d2, vol2, valid\n", "\n", "\n", "# ============================================================================\n", "# 4-SIMPLEX ATTENTION: Operates on tetrahedra as vertices\n", "# ============================================================================\n", "\n", "class Simplex4Attention(nn.Module):\n", " \"\"\"\n", " 4-simplex attention over patch tetrahedra.\n", "\n", " Each patch is a tetrahedron. We select 5 patches to form a 4-simplex.\n", " The centroid of each tetrahedron becomes a vertex of the 4-simplex.\n", "\n", " This preserves the geometric structure:\n", " - 16 tetrahedra (patches) as input\n", " - Select 5 tetrahedra per head\n", " - Their centroids form a 4-simplex\n", " - CM validates the 4-simplex\n", " - Attention weights from geometry\n", " \"\"\"\n", " def __init__(self, fdim, edim, num_heads=4):\n", " super().__init__()\n", " self.fdim = fdim\n", " self.edim = edim\n", " self.num_heads = num_heads\n", " self.head_dim = fdim // num_heads\n", "\n", " self._k = 4 # 4-simplex\n", " self._nv = 5 # select 5 patches\n", "\n", " self._cm = CMValidator(self._k)\n", "\n", " # Factory template for 4-simplex\n", " factory = SimplexFactory(k=self._k, embed_dim=edim, method=\"regular\", scale=1.0)\n", " template = factory.build_torch(dtype=torch.float32)\n", " self.register_buffer('_template', template) # [5, edim]\n", "\n", " # Selection scores: which 5 patches to select per head\n", " # Input: patch features (aggregated from vertices)\n", " self._to_select = nn.Linear(fdim, num_heads * self._nv)\n", "\n", " # Value projection (from vertex features)\n", " self._to_v = nn.Linear(fdim, fdim)\n", "\n", " # Geometry modulates output\n", " geom_dim = self._cm._npairs + 1 # 10 + 1\n", " self._geo_gate = nn.Sequential(\n", " nn.Linear(geom_dim, fdim),\n", " nn.Sigmoid(),\n", " )\n", "\n", " # Output projection\n", " self._out = nn.Linear(fdim, fdim)\n", "\n", " self._deform_scale = 0.05\n", "\n", " def forward(self, patch_verts, patch_feats, patch_d2, patch_vol2, patch_valid):\n", " \"\"\"\n", " patch_verts: [B, P, 4, edim] - tetrahedron vertices per patch\n", " patch_feats: [B, P, 4, fdim] - vertex features per patch\n", " patch_d2: [B, P, 6] - edge distances per patch\n", " patch_vol2: [B, P] - volumes per patch\n", " patch_valid: [B, P] - validity per patch\n", "\n", " Returns:\n", " out: [B, P, fdim] - updated patch features\n", " info: dict\n", " \"\"\"\n", " B, P, _, _ = patch_verts.shape\n", " H = self.num_heads\n", "\n", " # Aggregate patch features (mean over 4 vertices)\n", " patch_agg = patch_feats.mean(dim=2) # [B, P, fdim]\n", "\n", " # Centroid of each tetrahedron\n", " centroids = patch_verts.mean(dim=2) # [B, P, edim]\n", "\n", " # Selection scores: [B, P, H * 5]\n", " sel_scores = self._to_select(patch_agg)\n", " sel_scores = sel_scores.view(B, P, H, self._nv).permute(0, 2, 3, 1) # [B, H, 5, P]\n", "\n", " # Soft-select 5 patches per head\n", " sel_weights = F.softmax(sel_scores, dim=-1) # [B, H, 5, P]\n", "\n", " # Gather centroids for selected patches: [B, H, 5, edim]\n", " sel_centroids = torch.einsum('bhvp,bpe->bhve', sel_weights, centroids)\n", "\n", " # Deform 4-simplex template based on selected centroids\n", " template = self._template.unsqueeze(0).unsqueeze(0) # [1, 1, 5, edim]\n", " simplex4_verts = template + self._deform_scale * sel_centroids # [B, H, 5, edim]\n", "\n", " # CM validation of 4-simplices\n", " d2_4, vol2_4, valid_4 = self._cm(simplex4_verts) # d2: [B,H,10], vol2: [B,H]\n", "\n", " # Validity attenuation per head\n", " head_valid = torch.sigmoid(vol2_4 * 1e6) # [B, H]\n", "\n", " # Geometry gate\n", " geo = torch.cat([d2_4, vol2_4.unsqueeze(-1)], dim=-1) # [B, H, 11]\n", "\n", " # Values from vertex features\n", " v = self._to_v(patch_feats) # [B, P, 4, fdim]\n", " v = v.mean(dim=2) # [B, P, fdim] - aggregate per patch\n", " v = v.view(B, P, H, self.head_dim).permute(0, 2, 1, 3) # [B, H, P, head_dim]\n", "\n", " # Attention weights from selection (average over 5 positions)\n", " attn = sel_weights.mean(dim=2) # [B, H, P]\n", " attn = attn * patch_valid.unsqueeze(1) # mask invalid patches\n", " attn = attn / (attn.sum(dim=-1, keepdim=True) + 1e-8)\n", "\n", " # Weighted sum of values\n", " out = torch.einsum('bhp,bhpd->bhd', attn, v) # [B, H, head_dim]\n", "\n", " # Apply head validity\n", " out = out * head_valid.unsqueeze(-1) # [B, H, head_dim]\n", " out = out.reshape(B, self.fdim) # [B, fdim]\n", "\n", " # Geometry gate per-patch (broadcast from 4-simplex geometry)\n", " geo_expanded = geo.mean(dim=1) # [B, 11] - average across heads\n", " gate = self._geo_gate(geo_expanded) # [B, fdim]\n", "\n", " out = self._out(out) * gate # [B, fdim]\n", "\n", " # Broadcast back to all patches (residual style)\n", " out = out.unsqueeze(1).expand(-1, P, -1) # [B, P, fdim]\n", "\n", " return out, {\n", " 'd2_3': patch_d2, # tetrahedron edges\n", " 'vol2_3': patch_vol2, # tetrahedron volumes\n", " 'valid_3': patch_valid, # tetrahedron validity\n", " 'd2_4': d2_4, # 4-simplex edges\n", " 'vol2_4': vol2_4, # 4-simplex volumes\n", " 'valid_4': valid_4, # 4-simplex validity\n", " 'sel_weights': sel_weights, # which patches selected\n", " 'head_valid': head_valid,\n", " }\n", "\n", "\n", "# ============================================================================\n", "# TRANSFORMER BLOCK\n", "# ============================================================================\n", "\n", "class GeoTransformerBlock(nn.Module):\n", " def __init__(self, fdim, edim, num_heads, mlp_ratio=4.0):\n", " super().__init__()\n", " self.norm1_feats = nn.LayerNorm(fdim)\n", " self.attn = Simplex4Attention(fdim, edim, num_heads)\n", " self.norm2 = nn.LayerNorm(fdim)\n", " self.mlp = nn.Sequential(\n", " nn.Linear(fdim, int(fdim * mlp_ratio)),\n", " nn.GELU(),\n", " nn.Linear(int(fdim * mlp_ratio), fdim),\n", " )\n", "\n", " def forward(self, patch_verts, patch_feats, patch_d2, patch_vol2, patch_valid):\n", " \"\"\"\n", " Maintains full geometric structure throughout.\n", " \"\"\"\n", " B, P, nv, fdim = patch_feats.shape\n", "\n", " # Normalize features\n", " feats_norm = self.norm1_feats(patch_feats)\n", "\n", " # 4-simplex attention\n", " attn_out, attn_info = self.attn(\n", " patch_verts, feats_norm, patch_d2, patch_vol2, patch_valid\n", " )\n", "\n", " # Residual: add to all vertices of all patches\n", " patch_feats = patch_feats + attn_out.unsqueeze(2) # [B, P, 4, fdim]\n", "\n", " # MLP\n", " feats_flat = patch_feats.view(B * P * nv, fdim)\n", " feats_flat = self.norm2(feats_flat)\n", " feats_flat = feats_flat + self.mlp(feats_flat)\n", " patch_feats = feats_flat.view(B, P, nv, fdim)\n", "\n", " return patch_verts, patch_feats, patch_d2, patch_vol2, patch_valid, attn_info\n", "\n", "\n", "# ============================================================================\n", "# FULL MODEL\n", "# ============================================================================\n", "\n", "class GeometricPatchworkViT(nn.Module):\n", " def __init__(\n", " self,\n", " img_size=28,\n", " patch_size=7,\n", " in_chans=1,\n", " num_classes=10,\n", " fdim=64,\n", " edim=8,\n", " num_heads=4,\n", " depth=4,\n", " ):\n", " super().__init__()\n", "\n", " self.img_size = img_size\n", " self.patch_size = patch_size\n", " self.num_patches = (img_size // patch_size) ** 2\n", " self.patch_dim = patch_size * patch_size * in_chans\n", " self.fdim = fdim\n", " self.edim = edim\n", "\n", " # 3-simplex patch embedding\n", " self.patch_embed = SimplexPatch(self.patch_dim, fdim, edim)\n", "\n", " # Position embedding for tetrahedron centroids\n", " self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, edim) * 0.02)\n", "\n", " # Transformer blocks maintaining full geometry\n", " self.blocks = nn.ModuleList([\n", " GeoTransformerBlock(fdim, edim, num_heads)\n", " for _ in range(depth)\n", " ])\n", "\n", " # Classification from aggregated tetrahedra\n", " # Use both geometry and features\n", " self.norm = nn.LayerNorm(fdim)\n", " geom_feat_dim = fdim + 6 + 1 # features + d² + vol²\n", " self.head = nn.Sequential(\n", " nn.Linear(geom_feat_dim, fdim),\n", " nn.GELU(),\n", " nn.Linear(fdim, num_classes),\n", " )\n", "\n", " print(f\"\\nGeometricPatchworkViT:\")\n", " print(f\" {self.num_patches} patches × {patch_size}×{patch_size}\")\n", " print(f\" Each patch = 3-simplex (tetrahedron): 4 verts in R^{edim}\")\n", " print(f\" Attention = 4-simplex: selects 5 tetrahedra per head\")\n", " print(f\" {depth} blocks × {num_heads} heads\")\n", "\n", " def forward(self, x, ret_diag=False):\n", " B = x.shape[0]\n", "\n", " # Extract patches\n", " patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)\n", " patches = patches.contiguous().view(B, -1, self.patch_size * self.patch_size) # [B, P, patch_dim]\n", "\n", " # Embed as tetrahedra\n", " patch_verts, patch_feats, patch_d2, patch_vol2, patch_valid = self.patch_embed(patches)\n", " # verts: [B, P, 4, edim], feats: [B, P, 4, fdim]\n", "\n", " # Add position embedding to centroids (affects geometry)\n", " centroids = patch_verts.mean(dim=2) # [B, P, edim]\n", " pos_offset = self.pos_embed # [1, P, edim]\n", " patch_verts = patch_verts + pos_offset.unsqueeze(2) # Add to all vertices\n", "\n", " # Recompute CM after position embedding\n", " patch_d2, patch_vol2, patch_valid = self.patch_embed._cm(patch_verts)\n", "\n", " # Transformer blocks\n", " block_infos = []\n", " validity_loss = F.relu(-patch_vol2).mean() # Initial patch validity\n", "\n", " for block in self.blocks:\n", " patch_verts, patch_feats, patch_d2, patch_vol2, patch_valid, info = block(\n", " patch_verts, patch_feats, patch_d2, patch_vol2, patch_valid\n", " )\n", " block_infos.append(info)\n", " validity_loss = validity_loss + F.relu(-info['vol2_4']).mean()\n", "\n", " # Aggregate for classification\n", " # Validity-weighted mean over patches\n", " patch_valid_f = patch_valid.float() # [B, P]\n", "\n", " # Mean over vertices, then weighted mean over patches\n", " feats_agg = patch_feats.mean(dim=2) # [B, P, fdim]\n", " feats_agg = self.norm(feats_agg)\n", " feats_agg = (feats_agg * patch_valid_f.unsqueeze(-1)).sum(dim=1) / (patch_valid_f.sum(dim=1, keepdim=True) + 1e-8) # [B, fdim]\n", "\n", " # Mean geometry\n", " d2_agg = (patch_d2 * patch_valid_f.unsqueeze(-1)).sum(dim=1) / (patch_valid_f.sum(dim=1, keepdim=True) + 1e-8) # [B, 6]\n", " vol2_agg = (patch_vol2 * patch_valid_f).sum(dim=1) / (patch_valid_f.sum(dim=1) + 1e-8) # [B]\n", "\n", " # Combine features + geometry\n", " combined = torch.cat([feats_agg, d2_agg, vol2_agg.unsqueeze(-1)], dim=-1) # [B, fdim+7]\n", "\n", " logits = self.head(combined)\n", "\n", " if ret_diag:\n", " return logits, {\n", " 'patch_d2': patch_d2,\n", " 'patch_vol2': patch_vol2,\n", " 'patch_valid': patch_valid,\n", " 'blocks': block_infos,\n", " }, validity_loss\n", "\n", " return logits, validity_loss\n", "\n", "\n", "# ============================================================================\n", "# TRAIN\n", "# ============================================================================\n", "\n", "def train():\n", " transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.2860,), (0.3530,))\n", " ])\n", "\n", " train_ds = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)\n", " test_ds = datasets.FashionMNIST('./data', train=False, download=True, transform=transform)\n", " train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2)\n", " test_dl = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=2)\n", "\n", " print(\"Building model...\")\n", " model = GeometricPatchworkViT(\n", " img_size=28,\n", " patch_size=7,\n", " in_chans=1,\n", " num_classes=10,\n", " fdim=64,\n", " edim=8,\n", " num_heads=4,\n", " depth=4,\n", " ).to(device)\n", "\n", " params = sum(p.numel() for p in model.parameters())\n", " print(f\"Total params: {params:,}\")\n", "\n", " opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)\n", " sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=30)\n", "\n", " best = 0\n", " print(\"\\nTraining...\")\n", " print(\"=\" * 115)\n", "\n", " for ep in range(30):\n", " model.train()\n", " lsum, vsum, cor, tot = 0, 0, 0, 0\n", "\n", " for img, lab in train_dl:\n", " img, lab = img.to(device), lab.to(device)\n", "\n", " opt.zero_grad()\n", " logits, vloss = model(img)\n", " ce = F.cross_entropy(logits, lab)\n", " loss = ce + 0.1 * vloss\n", " loss.backward()\n", " opt.step()\n", "\n", " lsum += ce.item() * img.size(0)\n", " vsum += vloss.item() * img.size(0)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", "\n", " tr_acc = cor / tot\n", " tr_loss = lsum / tot\n", "\n", " model.eval()\n", " cor, tot = 0, 0\n", " samp = None\n", " with torch.no_grad():\n", " for img, lab in test_dl:\n", " img, lab = img.to(device), lab.to(device)\n", " logits, diag, _ = model(img, ret_diag=True)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", " if samp is None:\n", " samp = diag\n", "\n", " te_acc = cor / tot\n", " sched.step()\n", " if te_acc > best:\n", " best = te_acc\n", "\n", " if ep % 5 == 0 or ep == 29:\n", " # Patch 3-simplex stats\n", " p_vol = samp['patch_vol2'].mean().item()\n", " p_valid = samp['patch_valid'].float().mean().item()\n", " p_d2 = samp['patch_d2'].mean().item()\n", "\n", " print(f\"\\nEp {ep+1:2d} | CE {tr_loss:.4f} | Tr {tr_acc:.2%} | Te {te_acc:.2%} | Best {best:.2%}\")\n", " print(f\" 3-simplex patches: vol²={p_vol:.4f} valid={p_valid:.1%} d²={p_d2:.4f}\")\n", "\n", " # Per-block 4-simplex stats\n", " for i, bi in enumerate(samp['blocks']):\n", " v4 = bi['vol2_4'].mean().item()\n", " valid4 = bi['valid_4'].float().mean().item()\n", " d4 = bi['d2_4'].mean().item()\n", " hw = bi['head_valid'].mean(dim=0).cpu().numpy()\n", " print(f\" Block {i+1} 4-simplex: vol²={v4:.4f} valid={valid4:.1%} d²={d4:.4f} heads={['%.2f' % h for h in hw]}\")\n", "\n", " print(\"=\" * 115)\n", " print(f\"Best: {best:.2%}\")\n", " return model\n", "\n", "\n", "if __name__ == \"__main__\":\n", " model = train()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "cellView": "form", "id": "jcvID8JBh5la", "outputId": "640ef88c-a3cb-4533-a7d2-f251ed0990f2" }, "execution_count": 9, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Device: cuda\n", "Building model...\n", "\n", "GeometricPatchworkViT:\n", " 16 patches × 7×7\n", " Each patch = 3-simplex (tetrahedron): 4 verts in R^8\n", " Attention = 4-simplex: selects 5 tetrahedra per head\n", " 4 blocks × 4 heads\n", "Total params: 220,642\n", "\n", "Training...\n", "===================================================================================================================\n", "\n", "Ep 1 | CE 0.8372 | Tr 70.18% | Te 77.86% | Best 77.86%\n", " 3-simplex patches: vol²=0.0160 valid=100.0% d²=1.0490\n", " Block 1 4-simplex: vol²=0.0005 valid=100.0% d²=1.0008 heads=['1.00', '1.00', '1.00', '1.00']\n", " Block 2 4-simplex: vol²=0.0005 valid=100.0% d²=1.0004 heads=['1.00', '1.00', '1.00', '1.00']\n", " Block 3 4-simplex: vol²=0.0005 valid=100.0% d²=0.9998 heads=['1.00', '1.00', '1.00', '1.00']\n", " Block 4 4-simplex: vol²=0.0005 valid=100.0% d²=0.9997 heads=['1.00', '1.00', '1.00', '1.00']\n", "\n", "Ep 6 | CE 0.3629 | Tr 86.72% | Te 85.57% | Best 85.57%\n", " 3-simplex patches: vol²=0.0158 valid=100.0% d²=1.0448\n", " Block 1 4-simplex: vol²=0.0005 valid=100.0% d²=1.0019 heads=['1.00', '1.00', '1.00', '1.00']\n", " Block 2 4-simplex: vol²=0.0005 valid=100.0% d²=1.0013 heads=['1.00', '1.00', '1.00', '1.00']\n", " Block 3 4-simplex: vol²=0.0005 valid=100.0% d²=0.9988 heads=['1.00', '1.00', '1.00', '1.00']\n", " Block 4 4-simplex: vol²=0.0005 valid=100.0% d²=0.9993 heads=['1.00', '1.00', '1.00', '1.00']\n", "\n", "Ep 11 | CE 0.2780 | Tr 89.74% | Te 87.42% | Best 87.42%\n", " 3-simplex patches: vol²=0.0151 valid=100.0% d²=1.0293\n", " Block 1 4-simplex: vol²=0.0005 valid=100.0% d²=1.0026 heads=['1.00', '1.00', '1.00', '1.00']\n", " Block 2 4-simplex: vol²=0.0005 valid=100.0% d²=1.0013 heads=['1.00', '1.00', '1.00', '1.00']\n", " Block 3 4-simplex: vol²=0.0005 valid=100.0% d²=0.9990 heads=['1.00', '1.00', '1.00', '1.00']\n", " Block 4 4-simplex: vol²=0.0005 valid=100.0% d²=0.9996 heads=['1.00', '1.00', '1.00', '1.00']\n", "\n", "Ep 16 | CE 0.2135 | Tr 92.12% | Te 88.07% | Best 88.07%\n", " 3-simplex patches: vol²=0.0148 valid=100.0% d²=1.0227\n", " Block 1 4-simplex: vol²=0.0005 valid=100.0% d²=1.0027 heads=['1.00', '1.00', '1.00', '1.00']\n", " Block 2 4-simplex: vol²=0.0005 valid=100.0% d²=1.0012 heads=['1.00', '1.00', '1.00', '1.00']\n", " Block 3 4-simplex: vol²=0.0005 valid=100.0% d²=0.9991 heads=['1.00', '1.00', '1.00', '1.00']\n", " Block 4 4-simplex: vol²=0.0005 valid=100.0% d²=0.9992 heads=['1.00', '1.00', '1.00', '1.00']\n" ] }, { "output_type": "error", "ename": "KeyboardInterrupt", "evalue": "", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[0;32m/tmp/ipython-input-4143309422.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 516\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 517\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"__main__\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 518\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m/tmp/ipython-input-4143309422.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m()\u001b[0m\n\u001b[1;32m 465\u001b[0m \u001b[0mce\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcross_entropy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlab\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 466\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mce\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m0.1\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mvloss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 467\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 468\u001b[0m \u001b[0mopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 469\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 623\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 624\u001b[0m )\n\u001b[0;32m--> 625\u001b[0;31m torch.autograd.backward(\n\u001b[0m\u001b[1;32m 626\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 627\u001b[0m )\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 352\u001b[0m \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 353\u001b[0m \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 354\u001b[0;31m _engine_run_backward(\n\u001b[0m\u001b[1;32m 355\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 356\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py\u001b[0m in \u001b[0;36m_engine_run_backward\u001b[0;34m(t_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 839\u001b[0m \u001b[0munregister_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_register_logging_hooks_on_whole_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt_outputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 840\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 841\u001b[0;31m return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m 842\u001b[0m \u001b[0mt_outputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 843\u001b[0m ) # Calls into the C++ engine to run the backward pass\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ] }, { "cell_type": "code", "source": [ "#@title Geometric Patchwork ViT - Using Full GeoComplex\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets, transforms\n", "import math\n", "from itertools import combinations\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print(f\"Device: {device}\")\n", "\n", "from geovocab2.shapes.factory.simplex_factory import SimplexFactory\n", "\n", "# ============================================================================\n", "# CAYLEY-MENGER VALIDATOR\n", "# ============================================================================\n", "\n", "class CMValidator(nn.Module):\n", " def __init__(self, k):\n", " super().__init__()\n", " self._k = k\n", " self._nv = k + 1\n", "\n", " pairs = list(combinations(range(self._nv), 2))\n", " self._npairs = len(pairs)\n", " self.register_buffer('_pi', torch.tensor([p[0] for p in pairs], dtype=torch.long))\n", " self.register_buffer('_pj', torch.tensor([p[1] for p in pairs], dtype=torch.long))\n", "\n", " sign = (-1.0) ** (k + 1)\n", " fact = math.factorial(k)\n", " self._prefactor = sign / ((2.0 ** k) * (fact ** 2))\n", "\n", " def forward(self, verts):\n", " gram = torch.einsum('...ve,...we->...vw', verts, verts)\n", " norms = torch.diagonal(gram, dim1=-2, dim2=-1)\n", " d2_mat = norms.unsqueeze(-1) + norms.unsqueeze(-2) - 2 * gram\n", " d2_mat = F.relu(d2_mat)\n", "\n", " d2_pairs = d2_mat[..., self._pi, self._pj]\n", "\n", " shape = d2_mat.shape[:-2]\n", " V = d2_mat.shape[-1]\n", " cm = torch.zeros(*shape, V+1, V+1, device=d2_mat.device, dtype=d2_mat.dtype)\n", " cm[..., 0, 1:] = 1.0\n", " cm[..., 1:, 0] = 1.0\n", " cm[..., 1:, 1:] = d2_mat\n", " det = torch.linalg.det(cm)\n", "\n", " vol2 = self._prefactor * det\n", " valid = vol2 > 1e-8\n", "\n", " return d2_pairs, vol2, valid\n", "\n", "\n", "# ============================================================================\n", "# GEOMETRIC LEVEL (from our validated design)\n", "# ============================================================================\n", "\n", "class GeoLevel(nn.Module):\n", " BASE_DEFORM = 0.05\n", "\n", " def __init__(self, order, nin, nout, fdim, edim):\n", " super().__init__()\n", " self._order = order\n", " self._nin = nin\n", " self._nout = nout\n", " self._fdim = fdim\n", " self._edim = edim\n", " self._nv = order + 1\n", "\n", " self._potential = math.comb(nin, self._nv) if nin >= self._nv else 0\n", "\n", " self._cm = CMValidator(order)\n", "\n", " # Factory template - guaranteed valid\n", " factory = SimplexFactory(k=order, embed_dim=edim, method=\"regular\", scale=1.0)\n", " template = factory.build_torch(dtype=torch.float32)\n", " self.register_buffer('_template', template)\n", "\n", " # Vertex selection\n", " _sel = torch.randn(nout, self._nv, nin) * 0.1\n", " self.register_parameter('_W_select', nn.Parameter(_sel))\n", "\n", " # Deformation network\n", " self._deform = nn.Sequential(\n", " nn.Linear(fdim, fdim),\n", " nn.LayerNorm(fdim),\n", " nn.GELU(),\n", " nn.Linear(fdim, self._nv * edim),\n", " )\n", "\n", " # Feature MLP\n", " self._mlp = nn.Sequential(\n", " nn.Linear(fdim, fdim * 2),\n", " nn.LayerNorm(fdim * 2),\n", " nn.GELU(),\n", " nn.Linear(fdim * 2, fdim),\n", " )\n", "\n", " # Geometry gate\n", " gate_in = self._cm._npairs + 1\n", " self._gate = nn.Sequential(\n", " nn.Linear(gate_in, fdim),\n", " nn.Sigmoid(),\n", " )\n", "\n", " # Residual\n", " if nin != nout:\n", " self._proj = nn.Linear(nin * fdim, nout * fdim)\n", " else:\n", " self._proj = None\n", "\n", " def forward(self, x):\n", " \"\"\"x: [B, nin, fdim] -> [B, nout, fdim]\"\"\"\n", " B = x.shape[0]\n", "\n", " sel = F.softmax(self._W_select, dim=-1)\n", " picked = torch.einsum('ovn,bnf->bovf', sel, x)\n", " agg = picked.mean(dim=2)\n", "\n", " # Deform template\n", " deform = self._deform(agg).view(B, self._nout, self._nv, self._edim)\n", " template = self._template.unsqueeze(0).unsqueeze(0).expand(B, self._nout, -1, -1)\n", " verts = template + self.BASE_DEFORM * deform\n", "\n", " # CM validation\n", " d2, vol2, valid = self._cm(verts)\n", "\n", " # Geometry gating\n", " d2_norm = d2 / (d2.mean(dim=-1, keepdim=True) + 1e-8)\n", " geo = torch.cat([d2_norm, vol2.unsqueeze(-1) / (vol2.mean() + 1e-8)], dim=-1)\n", " gate = self._gate(geo)\n", "\n", " out = self._mlp(agg) * gate\n", "\n", " # Validity attenuation\n", " attn = torch.sigmoid(vol2 * 1e6)\n", " out = out * attn.unsqueeze(-1)\n", "\n", " # Residual\n", " if self._proj is not None:\n", " res = self._proj(x.flatten(1)).view(B, self._nout, self._fdim)\n", " else:\n", " res = x\n", " out = out + 0.1 * res\n", "\n", " return out, verts, {\n", " 'd2': d2, 'vol2': vol2, 'valid': valid, 'attn': attn,\n", " 'order': self._order, 'potential': self._potential,\n", " }\n", "\n", "\n", "# ============================================================================\n", "# GEOMETRIC COMPLEX (exponential growth k-simplex hierarchy)\n", "# ============================================================================\n", "\n", "class GeoComplex(nn.Module):\n", " def __init__(self, nbase, fdim, depth, edim):\n", " super().__init__()\n", " self._nbase = nbase\n", " self._fdim = fdim\n", " self._depth = depth\n", " self._edim = edim\n", "\n", " # EXPONENTIAL GROWTH\n", " sizes = [nbase]\n", " for _ in range(depth):\n", " sizes.append(sizes[-1] * 2)\n", " self._sizes = sizes\n", "\n", " # Potentials\n", " potentials = []\n", " for i in range(depth):\n", " k = i + 1\n", " nv = k + 1\n", " nin = sizes[i]\n", " pot = math.comb(nin, nv) if nin >= nv else 0\n", " potentials.append(pot)\n", " self._potentials = potentials\n", "\n", " # Input projection\n", " self._proj_in = nn.Linear(fdim, nbase * fdim)\n", "\n", " # Levels\n", " level_list = []\n", " for i in range(depth):\n", " lv = GeoLevel(\n", " order=i + 1,\n", " nin=sizes[i],\n", " nout=sizes[i + 1],\n", " fdim=fdim,\n", " edim=edim\n", " )\n", " level_list.append(lv)\n", " self._levels = nn.ModuleList(level_list)\n", "\n", " # Level weights\n", " self.register_parameter('_lw', nn.Parameter(torch.ones(depth) / depth))\n", "\n", " def forward(self, x):\n", " \"\"\"\n", " x: [B, fdim]\n", " Returns:\n", " pooled: [B, fdim] - weighted pooled output\n", " all_verts: list of [B, nout, nv, edim] per level\n", " all_feats: list of [B, nout, fdim] per level\n", " infos: list of info dicts per level\n", " \"\"\"\n", " B = x.shape[0]\n", " h = self._proj_in(x).view(B, self._nbase, self._fdim)\n", "\n", " all_feats, all_verts, infos = [], [], []\n", " for lv in self._levels:\n", " h, verts, info = lv(h)\n", " all_feats.append(h)\n", " all_verts.append(verts)\n", " infos.append(info)\n", "\n", " # Weighted pool\n", " w = F.softmax(self._lw, dim=0)\n", " pooled = []\n", " for i, (feat, info) in enumerate(zip(all_feats, infos)):\n", " a = info['attn']\n", " wm = (feat * a.unsqueeze(-1)).sum(1) / (a.sum(1, keepdim=True) + 1e-8)\n", " pooled.append(w[i] * wm)\n", "\n", " return sum(pooled), all_verts, all_feats, {\n", " 'sizes': self._sizes,\n", " 'potentials': self._potentials,\n", " 'infos': infos,\n", " 'weights': w.detach(),\n", " }\n", "\n", "\n", "# ============================================================================\n", "# GEOMETRIC PATCH: Each patch through full k-simplex hierarchy\n", "# ============================================================================\n", "\n", "class GeoPatch(nn.Module):\n", " \"\"\"\n", " Each image patch is processed through a full GeoComplex.\n", " Maintains the complete k-simplex hierarchy per patch.\n", " \"\"\"\n", " def __init__(self, patch_dim, fdim, nbase, depth, edim):\n", " super().__init__()\n", " self._fdim = fdim\n", " self._depth = depth\n", "\n", " # Patch to initial features\n", " self._patch_proj = nn.Sequential(\n", " nn.Linear(patch_dim, fdim),\n", " nn.LayerNorm(fdim),\n", " nn.GELU(),\n", " )\n", "\n", " # Full geometric complex\n", " self._complex = GeoComplex(nbase, fdim, depth, edim)\n", "\n", " def forward(self, patches):\n", " \"\"\"\n", " patches: [B, P, patch_dim]\n", "\n", " Returns per patch, per level:\n", " - vertices (the actual k-simplex geometry)\n", " - features\n", " - CM info\n", " \"\"\"\n", " B, P, _ = patches.shape\n", "\n", " # Project patches\n", " patch_feats = self._patch_proj(patches) # [B, P, fdim]\n", "\n", " # Process each patch through complex\n", " # Reshape: [B*P, fdim]\n", " flat = patch_feats.view(B * P, self._fdim)\n", "\n", " pooled, all_verts, all_feats, info = self._complex(flat)\n", "\n", " # Reshape back: each has shape [B*P, ...] -> [B, P, ...]\n", " pooled = pooled.view(B, P, self._fdim)\n", "\n", " # Reshape vertices and features per level\n", " all_verts_reshaped = []\n", " all_feats_reshaped = []\n", " for verts, feats in zip(all_verts, all_feats):\n", " # verts: [B*P, nout, nv, edim] -> [B, P, nout, nv, edim]\n", " nout, nv, edim = verts.shape[1], verts.shape[2], verts.shape[3]\n", " all_verts_reshaped.append(verts.view(B, P, nout, nv, edim))\n", "\n", " # feats: [B*P, nout, fdim] -> [B, P, nout, fdim]\n", " nout, fdim = feats.shape[1], feats.shape[2]\n", " all_feats_reshaped.append(feats.view(B, P, nout, fdim))\n", "\n", " # Reshape info\n", " for inf in info['infos']:\n", " for key in ['d2', 'vol2', 'valid', 'attn']:\n", " if key in inf:\n", " shape = inf[key].shape\n", " if len(shape) == 2: # [B*P, X]\n", " inf[key] = inf[key].view(B, P, shape[-1])\n", " elif len(shape) == 1: # [B*P]\n", " inf[key] = inf[key].view(B, P)\n", "\n", " return pooled, all_verts_reshaped, all_feats_reshaped, info\n", "\n", "\n", "# ============================================================================\n", "# 4-SIMPLEX SEQUENCE ATTENTION\n", "# ============================================================================\n", "\n", "class Simplex4SequenceAttention(nn.Module):\n", " \"\"\"\n", " 4-simplex attention over patches.\n", "\n", " Each patch has full k-simplex hierarchy.\n", " We use the highest level (k=depth) simplices as the \"tokens\".\n", " Select 5 patches to form a 4-simplex over patch-level geometry.\n", " \"\"\"\n", " def __init__(self, fdim, edim, num_patches, num_heads=4):\n", " super().__init__()\n", " self._fdim = fdim\n", " self._edim = edim\n", " self._num_patches = num_patches\n", " self._num_heads = num_heads\n", " self._head_dim = fdim // num_heads\n", "\n", " self._k = 4\n", " self._nv = 5\n", "\n", " self._cm = CMValidator(self._k)\n", "\n", " # Factory template\n", " factory = SimplexFactory(k=self._k, embed_dim=edim, method=\"regular\", scale=1.0)\n", " template = factory.build_torch(dtype=torch.float32)\n", " self.register_buffer('_template', template)\n", "\n", " # Selection: which 5 patches per head\n", " self._to_select = nn.Linear(fdim, num_heads * self._nv)\n", "\n", " # Patch features to coordinate\n", " self._to_coord = nn.Linear(fdim, edim)\n", "\n", " # Value projection\n", " self._to_v = nn.Linear(fdim, fdim)\n", "\n", " # Geometry gate\n", " geom_dim = self._cm._npairs + 1\n", " self._geo_gate = nn.Sequential(\n", " nn.Linear(geom_dim, num_heads),\n", " nn.Sigmoid(),\n", " )\n", "\n", " # Output\n", " self._out = nn.Linear(fdim, fdim)\n", "\n", " self._deform_scale = 0.05\n", "\n", " def forward(self, patch_pooled, patch_verts_highest, patch_feats_highest, patch_info):\n", " \"\"\"\n", " patch_pooled: [B, P, fdim] - pooled features per patch\n", " patch_verts_highest: [B, P, nout, nv, edim] - highest k-level vertices\n", " patch_feats_highest: [B, P, nout, fdim] - highest k-level features\n", " patch_info: info from GeoComplex\n", "\n", " Returns:\n", " out: [B, P, fdim]\n", " attn_info: dict\n", " \"\"\"\n", " B, P, _ = patch_pooled.shape\n", " H = self._num_heads\n", "\n", " # Selection scores: [B, P, H*5] -> [B, H, 5, P]\n", " sel = self._to_select(patch_pooled)\n", " sel = sel.view(B, P, H, self._nv).permute(0, 2, 3, 1)\n", " sel_weights = F.softmax(sel, dim=-1) # [B, H, 5, P]\n", "\n", " # Patch coordinates (from pooled features)\n", " coords = self._to_coord(patch_pooled) # [B, P, edim]\n", "\n", " # Select 5 patches -> their coordinates become 4-simplex vertices\n", " sel_coords = torch.einsum('bhvp,bpe->bhve', sel_weights, coords) # [B, H, 5, edim]\n", "\n", " # Deform template\n", " template = self._template.unsqueeze(0).unsqueeze(0)\n", " verts_4 = template + self._deform_scale * sel_coords\n", "\n", " # CM validation\n", " d2_4, vol2_4, valid_4 = self._cm(verts_4)\n", "\n", " # Head validity\n", " head_valid = torch.sigmoid(vol2_4 * 1e6) # [B, H]\n", "\n", " # Geometry gate\n", " geo = torch.cat([d2_4, vol2_4.unsqueeze(-1)], dim=-1)\n", " geo_weight = self._geo_gate(geo).mean(dim=-1) # [B, H]\n", " head_weight = geo_weight * head_valid\n", " head_weight = head_weight / (head_weight.sum(dim=-1, keepdim=True) + 1e-8)\n", "\n", " # Values\n", " v = self._to_v(patch_pooled) # [B, P, fdim]\n", " v = v.view(B, P, H, self._head_dim).permute(0, 2, 1, 3) # [B, H, P, hd]\n", "\n", " # Attention from selection\n", " attn = sel_weights.mean(dim=2) # [B, H, P]\n", " attn = attn / (attn.sum(dim=-1, keepdim=True) + 1e-8)\n", "\n", " # Aggregate\n", " out = torch.einsum('bhp,bhpd->bhd', attn, v) # [B, H, hd]\n", " out = out * head_weight.unsqueeze(-1)\n", " out = out.reshape(B, self._fdim)\n", " out = self._out(out)\n", "\n", " # Broadcast to patches\n", " out = out.unsqueeze(1).expand(-1, P, -1)\n", "\n", " return out, {\n", " 'd2_4': d2_4,\n", " 'vol2_4': vol2_4,\n", " 'valid_4': valid_4,\n", " 'head_weight': head_weight,\n", " 'sel_weights': sel_weights,\n", " }\n", "\n", "\n", "# ============================================================================\n", "# TRANSFORMER BLOCK\n", "# ============================================================================\n", "\n", "class GeoTransformerBlock(nn.Module):\n", " def __init__(self, fdim, edim, num_patches, num_heads, mlp_ratio=4.0):\n", " super().__init__()\n", " self.norm1 = nn.LayerNorm(fdim)\n", " self.attn = Simplex4SequenceAttention(fdim, edim, num_patches, num_heads)\n", " self.norm2 = nn.LayerNorm(fdim)\n", " self.mlp = nn.Sequential(\n", " nn.Linear(fdim, int(fdim * mlp_ratio)),\n", " nn.GELU(),\n", " nn.Linear(int(fdim * mlp_ratio), fdim),\n", " )\n", "\n", " def forward(self, patch_pooled, patch_verts, patch_feats, patch_info):\n", " # Attention\n", " attn_out, attn_info = self.attn(\n", " self.norm1(patch_pooled),\n", " patch_verts[-1], # Highest k-level\n", " patch_feats[-1],\n", " patch_info\n", " )\n", " patch_pooled = patch_pooled + attn_out\n", "\n", " # MLP\n", " patch_pooled = patch_pooled + self.mlp(self.norm2(patch_pooled))\n", "\n", " return patch_pooled, attn_info\n", "\n", "\n", "# ============================================================================\n", "# FULL MODEL\n", "# ============================================================================\n", "\n", "class GeometricPatchworkViT(nn.Module):\n", " def __init__(\n", " self,\n", " img_size=28,\n", " patch_size=7,\n", " in_chans=1,\n", " num_classes=10,\n", " fdim=64,\n", " edim=8,\n", " nbase=4,\n", " geo_depth=4,\n", " attn_depth=4,\n", " num_heads=4,\n", " ):\n", " super().__init__()\n", "\n", " self.img_size = img_size\n", " self.patch_size = patch_size\n", " self.num_patches = (img_size // patch_size) ** 2\n", " self.patch_dim = patch_size * patch_size * in_chans\n", "\n", " # Geometric patch embedding with full k-hierarchy\n", " self.patch_embed = GeoPatch(\n", " patch_dim=self.patch_dim,\n", " fdim=fdim,\n", " nbase=nbase,\n", " depth=geo_depth,\n", " edim=edim,\n", " )\n", "\n", " # Position embedding\n", " self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches, fdim) * 0.02)\n", "\n", " # Transformer blocks with 4-simplex attention\n", " self.blocks = nn.ModuleList([\n", " GeoTransformerBlock(fdim, edim, self.num_patches, num_heads)\n", " for _ in range(attn_depth)\n", " ])\n", "\n", " # Classification head uses geometry\n", " self.norm = nn.LayerNorm(fdim)\n", " self.head = nn.Linear(fdim, num_classes)\n", "\n", " # Print architecture\n", " print(f\"\\nGeometricPatchworkViT:\")\n", " print(f\" Patches: {self.num_patches} × {patch_size}×{patch_size}\")\n", " print(f\" K-simplex hierarchy per patch:\")\n", " print(f\" Sizes: {self.patch_embed._complex._sizes}\")\n", " print(f\" Potentials: {self.patch_embed._complex._potentials}\")\n", " for i, pot in enumerate(self.patch_embed._complex._potentials):\n", " k = i + 1\n", " print(f\" k={k}: {pot:,} potential {k}-simplices\")\n", " print(f\" 4-simplex attention: {attn_depth} blocks × {num_heads} heads\")\n", " print(f\" Selects 5 of {self.num_patches} patches per head\")\n", " print(f\" C({self.num_patches},5) = {math.comb(self.num_patches, 5):,} potential 4-simplices\")\n", "\n", " def forward(self, x, ret_diag=False):\n", " B = x.shape[0]\n", "\n", " # Extract patches\n", " patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)\n", " patches = patches.contiguous().view(B, -1, self.patch_size * self.patch_size)\n", "\n", " # Geometric patch embedding - full k-hierarchy\n", " patch_pooled, all_verts, all_feats, patch_info = self.patch_embed(patches)\n", "\n", " # Add position embedding\n", " patch_pooled = patch_pooled + self.pos_embed\n", "\n", " # Transformer blocks\n", " block_infos = []\n", " validity_loss = 0\n", "\n", " # Patch hierarchy validity\n", " for inf in patch_info['infos']:\n", " validity_loss = validity_loss + F.relu(-inf['vol2']).mean()\n", "\n", " for block in self.blocks:\n", " patch_pooled, attn_info = block(patch_pooled, all_verts, all_feats, patch_info)\n", " block_infos.append(attn_info)\n", " validity_loss = validity_loss + F.relu(-attn_info['vol2_4']).mean()\n", "\n", " # Aggregate for classification\n", " patch_pooled = self.norm(patch_pooled)\n", " cls = patch_pooled.mean(dim=1) # [B, fdim]\n", " logits = self.head(cls)\n", "\n", " if ret_diag:\n", " return logits, {'patch': patch_info, 'blocks': block_infos}, validity_loss\n", "\n", " return logits, validity_loss\n", "\n", "\n", "# ============================================================================\n", "# TRAIN\n", "# ============================================================================\n", "\n", "def train():\n", " transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.2860,), (0.3530,))\n", " ])\n", "\n", " train_ds = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)\n", " test_ds = datasets.FashionMNIST('./data', train=False, download=True, transform=transform)\n", " train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2)\n", " test_dl = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=2)\n", "\n", " print(\"Building model...\")\n", " model = GeometricPatchworkViT(\n", " img_size=28,\n", " patch_size=7,\n", " in_chans=1,\n", " num_classes=10,\n", " fdim=64,\n", " edim=8,\n", " nbase=4,\n", " geo_depth=4,\n", " attn_depth=4,\n", " num_heads=4,\n", " ).to(device)\n", "\n", " params = sum(p.numel() for p in model.parameters())\n", " print(f\"Total params: {params:,}\")\n", "\n", " opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)\n", " sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=30)\n", "\n", " best = 0\n", " print(\"\\nTraining...\")\n", " print(\"=\" * 115)\n", "\n", " for ep in range(30):\n", " model.train()\n", " lsum, vsum, cor, tot = 0, 0, 0, 0\n", "\n", " for img, lab in train_dl:\n", " img, lab = img.to(device), lab.to(device)\n", "\n", " opt.zero_grad()\n", " logits, vloss = model(img)\n", " ce = F.cross_entropy(logits, lab)\n", " loss = ce + 0.1 * vloss\n", " loss.backward()\n", " opt.step()\n", "\n", " lsum += ce.item() * img.size(0)\n", " vsum += vloss.item() * img.size(0)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", "\n", " tr_acc = cor / tot\n", " tr_loss = lsum / tot\n", "\n", " model.eval()\n", " cor, tot = 0, 0\n", " samp = None\n", " with torch.no_grad():\n", " for img, lab in test_dl:\n", " img, lab = img.to(device), lab.to(device)\n", " logits, diag, _ = model(img, ret_diag=True)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", " if samp is None:\n", " samp = diag\n", "\n", " te_acc = cor / tot\n", " sched.step()\n", " if te_acc > best:\n", " best = te_acc\n", "\n", " if ep % 5 == 0 or ep == 29:\n", " print(f\"\\nEp {ep+1:2d} | CE {tr_loss:.4f} | Tr {tr_acc:.2%} | Te {te_acc:.2%} | Best {best:.2%}\")\n", "\n", " # K-hierarchy stats\n", " pw = samp['patch']['weights'].cpu().numpy()\n", " print(f\" Patch k-hierarchy weights: {['%.3f' % w for w in pw]}\")\n", " for inf in samp['patch']['infos']:\n", " k = inf['order']\n", " v = inf['vol2'].mean().item()\n", " vld = inf['valid'].float().mean().item()\n", " att = inf['attn'].mean().item()\n", " d = inf['d2'].mean().item()\n", " pot = inf['potential']\n", " print(f\" k={k}: pot={pot:>8,} vol²={v:.2e} valid={vld:5.1%} attn={att:.3f} d²={d:.4f}\")\n", "\n", " # 4-simplex attention stats\n", " for i, bi in enumerate(samp['blocks']):\n", " v4 = bi['vol2_4'].mean().item()\n", " valid4 = bi['valid_4'].float().mean().item()\n", " d4 = bi['d2_4'].mean().item()\n", " hw = bi['head_weight'].mean(dim=0).cpu().numpy()\n", " print(f\" Block {i+1} 4-simplex: vol²={v4:.2e} valid={valid4:5.1%} d²={d4:.4f} heads={['%.2f' % h for h in hw]}\")\n", "\n", " print(\"=\" * 115)\n", " print(f\"Best: {best:.2%}\")\n", " return model\n", "\n", "\n", "if __name__ == \"__main__\":\n", " model = train()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "FugBnGC-i2tF", "outputId": "dd681fc1-2609-4751-d77c-4ae8c4dc3305" }, "execution_count": 2, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Device: cuda\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "100%|██████████| 26.4M/26.4M [00:00<00:00, 109MB/s]\n", "100%|██████████| 29.5k/29.5k [00:00<00:00, 4.02MB/s]\n", "100%|██████████| 4.42M/4.42M [00:00<00:00, 63.3MB/s]\n", "100%|██████████| 5.15k/5.15k [00:00<00:00, 20.7MB/s]\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Building model...\n", "\n", "GeometricPatchworkViT:\n", " Patches: 16 × 7×7\n", " K-simplex hierarchy per patch:\n", " Sizes: [4, 8, 16, 32, 64]\n", " Potentials: [6, 56, 1820, 201376]\n", " k=1: 6 potential 1-simplices\n", " k=2: 56 potential 2-simplices\n", " k=3: 1,820 potential 3-simplices\n", " k=4: 201,376 potential 4-simplices\n", " 4-simplex attention: 4 blocks × 4 heads\n", " Selects 5 of 16 patches per head\n", " C(16,5) = 4,368 potential 4-simplices\n", "Total params: 11,450,990\n", "\n", "Training...\n", "===================================================================================================================\n", "\n", "Ep 1 | CE 0.8320 | Tr 69.77% | Te 79.11% | Best 79.11%\n", " Patch k-hierarchy weights: ['0.285', '0.256', '0.240', '0.219']\n", " k=1: pot= 6 vol²=1.08e+00 valid=100.0% attn=1.000 d²=1.0760\n", " k=2: pot= 56 vol²=1.67e-01 valid=100.0% attn=1.000 d²=0.9428\n", " k=3: pot= 1,820 vol²=1.46e-02 valid=100.0% attn=1.000 d²=1.0262\n", " k=4: pot= 201,376 vol²=5.85e-04 valid=100.0% attn=1.000 d²=1.0350\n", " Block 1 4-simplex: vol²=5.34e-04 valid=100.0% d²=0.9994 heads=['0.25', '0.25', '0.25', '0.25']\n", " Block 2 4-simplex: vol²=6.01e-04 valid=100.0% d²=1.0270 heads=['0.26', '0.26', '0.25', '0.23']\n", " Block 3 4-simplex: vol²=5.39e-04 valid=100.0% d²=0.9985 heads=['0.25', '0.25', '0.25', '0.25']\n", " Block 4 4-simplex: vol²=5.71e-04 valid=100.0% d²=1.0132 heads=['0.25', '0.25', '0.25', '0.25']\n", "\n", "Ep 6 | CE 0.3281 | Tr 87.88% | Te 87.19% | Best 87.19%\n", " Patch k-hierarchy weights: ['0.344', '0.245', '0.241', '0.170']\n", " k=1: pot= 6 vol²=6.21e-01 valid=100.0% attn=1.000 d²=0.6207\n", " k=2: pot= 56 vol²=1.31e-01 valid=100.0% attn=1.000 d²=0.8210\n", " k=3: pot= 1,820 vol²=1.41e-02 valid=100.0% attn=1.000 d²=1.0366\n", " k=4: pot= 201,376 vol²=5.17e-04 valid=100.0% attn=1.000 d²=1.1102\n", " Block 1 4-simplex: vol²=5.94e-04 valid=100.0% d²=1.0406 heads=['0.26', '0.25', '0.25', '0.25']\n", " Block 2 4-simplex: vol²=8.14e-04 valid=100.0% d²=1.1329 heads=['0.31', '0.28', '0.17', '0.24']\n", " Block 3 4-simplex: vol²=5.48e-04 valid=100.0% d²=1.0050 heads=['0.27', '0.23', '0.25', '0.25']\n", " Block 4 4-simplex: vol²=8.22e-04 valid=100.0% d²=1.1456 heads=['0.28', '0.24', '0.21', '0.27']\n", "\n", "Ep 11 | CE 0.2621 | Tr 90.27% | Te 88.15% | Best 88.28%\n", " Patch k-hierarchy weights: ['0.385', '0.241', '0.232', '0.143']\n", " k=1: pot= 6 vol²=3.67e-01 valid=100.0% attn=1.000 d²=0.3675\n", " k=2: pot= 56 vol²=8.56e-02 valid=100.0% attn=1.000 d²=0.6449\n", " k=3: pot= 1,820 vol²=1.29e-02 valid=100.0% attn=1.000 d²=0.9979\n", " k=4: pot= 201,376 vol²=3.25e-04 valid=100.0% attn=1.000 d²=0.9680\n", " Block 1 4-simplex: vol²=5.81e-04 valid=100.0% d²=1.0404 heads=['0.27', '0.25', '0.24', '0.24']\n", " Block 2 4-simplex: vol²=8.57e-04 valid=100.0% d²=1.1579 heads=['0.32', '0.31', '0.14', '0.22']\n", " Block 3 4-simplex: vol²=6.19e-04 valid=100.0% d²=1.0418 heads=['0.28', '0.20', '0.26', '0.26']\n", " Block 4 4-simplex: vol²=7.56e-04 valid=100.0% d²=1.1167 heads=['0.29', '0.23', '0.22', '0.26']\n", "\n", "Ep 16 | CE 0.2119 | Tr 92.09% | Te 89.15% | Best 89.15%\n", " Patch k-hierarchy weights: ['0.404', '0.242', '0.227', '0.128']\n", " k=1: pot= 6 vol²=3.70e-01 valid=100.0% attn=1.000 d²=0.3697\n", " k=2: pot= 56 vol²=1.13e-01 valid=100.0% attn=1.000 d²=0.7210\n", " k=3: pot= 1,820 vol²=1.06e-02 valid=100.0% attn=1.000 d²=0.9167\n", " k=4: pot= 201,376 vol²=3.76e-04 valid=100.0% attn=1.000 d²=0.9582\n", " Block 1 4-simplex: vol²=5.87e-04 valid=100.0% d²=1.0405 heads=['0.28', '0.25', '0.25', '0.23']\n", " Block 2 4-simplex: vol²=8.92e-04 valid=100.0% d²=1.1717 heads=['0.32', '0.32', '0.15', '0.20']\n", " Block 3 4-simplex: vol²=6.41e-04 valid=100.0% d²=1.0514 heads=['0.31', '0.20', '0.23', '0.26']\n", " Block 4 4-simplex: vol²=7.89e-04 valid=100.0% d²=1.1320 heads=['0.27', '0.23', '0.26', '0.24']\n", "\n", "Ep 21 | CE 0.1558 | Tr 94.25% | Te 89.91% | Best 89.91%\n", " Patch k-hierarchy weights: ['0.405', '0.244', '0.226', '0.125']\n", " k=1: pot= 6 vol²=2.68e-01 valid=100.0% attn=1.000 d²=0.2681\n", " k=2: pot= 56 vol²=5.76e-02 valid=100.0% attn=1.000 d²=0.5542\n", " k=3: pot= 1,820 vol²=7.03e-03 valid=100.0% attn=1.000 d²=0.7703\n", " k=4: pot= 201,376 vol²=4.61e-04 valid=100.0% attn=1.000 d²=0.8753\n", " Block 1 4-simplex: vol²=5.91e-04 valid=100.0% d²=1.0420 heads=['0.27', '0.25', '0.25', '0.23']\n", " Block 2 4-simplex: vol²=8.80e-04 valid=100.0% d²=1.1653 heads=['0.31', '0.32', '0.16', '0.22']\n", " Block 3 4-simplex: vol²=6.73e-04 valid=100.0% d²=1.0667 heads=['0.30', '0.20', '0.23', '0.27']\n", " Block 4 4-simplex: vol²=8.03e-04 valid=100.0% d²=1.1387 heads=['0.25', '0.30', '0.23', '0.22']\n" ] }, { "output_type": "error", "ename": "KeyboardInterrupt", "evalue": "", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[0;32m/tmp/ipython-input-3727292590.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 660\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 661\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"__main__\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 662\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m/tmp/ipython-input-3727292590.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m()\u001b[0m\n\u001b[1;32m 600\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 601\u001b[0m \u001b[0mopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 602\u001b[0;31m \u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 603\u001b[0m \u001b[0mce\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcross_entropy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlab\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 604\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mce\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m0.1\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mvloss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1773\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1774\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1775\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1776\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1777\u001b[0m \u001b[0;31m# torchrec tests the code consistency with the following code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1784\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1785\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1786\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1787\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1788\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/tmp/ipython-input-3727292590.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x, ret_diag)\u001b[0m\n\u001b[1;32m 538\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 539\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mblock\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mblocks\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 540\u001b[0;31m \u001b[0mpatch_pooled\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattn_info\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mblock\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpatch_pooled\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mall_verts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mall_feats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpatch_info\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 541\u001b[0m \u001b[0mblock_infos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mattn_info\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 542\u001b[0m \u001b[0mvalidity_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalidity_loss\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mattn_info\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'vol2_4'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1773\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1774\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1775\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1776\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1777\u001b[0m \u001b[0;31m# torchrec tests the code consistency with the following code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1784\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1785\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1786\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1787\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1788\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/tmp/ipython-input-3727292590.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, patch_pooled, patch_verts, patch_feats, patch_info)\u001b[0m\n\u001b[1;32m 442\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpatch_pooled\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpatch_verts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpatch_feats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpatch_info\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 443\u001b[0m \u001b[0;31m# Attention\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 444\u001b[0;31m attn_out, attn_info = self.attn(\n\u001b[0m\u001b[1;32m 445\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpatch_pooled\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 446\u001b[0m \u001b[0mpatch_verts\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;31m# Highest k-level\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1773\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1774\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1775\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1776\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1777\u001b[0m \u001b[0;31m# torchrec tests the code consistency with the following code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1784\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1785\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1786\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1787\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1788\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/tmp/ipython-input-3727292590.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, patch_pooled, patch_verts_highest, patch_feats_highest, patch_info)\u001b[0m\n\u001b[1;32m 387\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[0;31m# CM validation\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 389\u001b[0;31m \u001b[0md2_4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvol2_4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalid_4\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mverts_4\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 390\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 391\u001b[0m \u001b[0;31m# Head validity\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1773\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1774\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1775\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1776\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1777\u001b[0m \u001b[0;31m# torchrec tests the code consistency with the following code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1784\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1785\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1786\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1787\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1788\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/tmp/ipython-input-3727292590.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, verts)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0mnorms\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdiagonal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgram\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim1\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim2\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0md2_mat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnorms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mnorms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mgram\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 38\u001b[0;31m \u001b[0md2_mat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md2_mat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 39\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0md2_pairs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0md2_mat\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m...\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_pj\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mrelu\u001b[0;34m(input, inplace)\u001b[0m\n\u001b[1;32m 1684\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1685\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1686\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minplace\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# noqa: D400,D402\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1687\u001b[0m r\"\"\"relu(input, inplace=False) -> Tensor\n\u001b[1;32m 1688\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ] }, { "cell_type": "code", "source": [ "#@title Geometric Patchwork ViT - NO FLATTEN\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets, transforms\n", "import math\n", "from itertools import combinations\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print(f\"Device: {device}\")\n", "\n", "from geovocab2.shapes.factory.simplex_factory import SimplexFactory\n", "\n", "# ============================================================================\n", "# CAYLEY-MENGER VALIDATOR\n", "# ============================================================================\n", "\n", "class CMValidator(nn.Module):\n", " def __init__(self, k):\n", " super().__init__()\n", " self._k = k\n", " self._nv = k + 1\n", "\n", " pairs = list(combinations(range(self._nv), 2))\n", " self._npairs = len(pairs)\n", " self.register_buffer('_pi', torch.tensor([p[0] for p in pairs], dtype=torch.long))\n", " self.register_buffer('_pj', torch.tensor([p[1] for p in pairs], dtype=torch.long))\n", "\n", " sign = (-1.0) ** (k + 1)\n", " fact = math.factorial(k)\n", " self._prefactor = sign / ((2.0 ** k) * (fact ** 2))\n", "\n", " def forward(self, verts):\n", " \"\"\"verts: [..., nv, edim] - works on any leading dims\"\"\"\n", " gram = torch.einsum('...ve,...we->...vw', verts, verts)\n", " norms = torch.diagonal(gram, dim1=-2, dim2=-1)\n", " d2_mat = norms.unsqueeze(-1) + norms.unsqueeze(-2) - 2 * gram\n", " d2_mat = F.relu(d2_mat)\n", "\n", " d2_pairs = d2_mat[..., self._pi, self._pj]\n", "\n", " shape = d2_mat.shape[:-2]\n", " V = d2_mat.shape[-1]\n", " cm = torch.zeros(*shape, V+1, V+1, device=d2_mat.device, dtype=d2_mat.dtype)\n", " cm[..., 0, 1:] = 1.0\n", " cm[..., 1:, 0] = 1.0\n", " cm[..., 1:, 1:] = d2_mat\n", "\n", " vol2 = self._prefactor * torch.linalg.det(cm)\n", "\n", " return d2_pairs, vol2, (vol2 > 1e-8)\n", "\n", "\n", "# ============================================================================\n", "# GEOMETRIC LEVEL - OPERATES ON [..., nin, fdim]\n", "# ============================================================================\n", "\n", "class GeoLevel(nn.Module):\n", " BASE_DEFORM = 0.05\n", "\n", " def __init__(self, order, nin, nout, fdim, edim):\n", " super().__init__()\n", " self._order = order\n", " self._nin = nin\n", " self._nout = nout\n", " self._fdim = fdim\n", " self._edim = edim\n", " self._nv = order + 1\n", "\n", " self._cm = CMValidator(order)\n", "\n", " factory = SimplexFactory(k=order, embed_dim=edim, method=\"regular\", scale=1.0)\n", " self.register_buffer('_template', factory.build_torch(dtype=torch.float32))\n", "\n", " self.register_parameter('_W_sel', nn.Parameter(torch.randn(nout, self._nv, nin) * 0.1))\n", " self._deform = nn.Linear(fdim, self._nv * edim)\n", " self._transform = nn.Linear(fdim, fdim)\n", " self._gate = nn.Linear(self._cm._npairs + 1, fdim)\n", " self._norm = nn.LayerNorm(fdim)\n", "\n", " def forward(self, feats, coords):\n", " \"\"\"\n", " feats: [..., nin, fdim] - arbitrary leading dims preserved\n", " coords: [..., nin, edim]\n", "\n", " Returns:\n", " out_feats: [..., nout, fdim]\n", " out_coords: [..., nout, edim]\n", " vol2: [..., nout]\n", " \"\"\"\n", " # Soft selection: [nout, nv, nin]\n", " sel = F.softmax(self._W_sel, dim=-1)\n", "\n", " # Select: [..., nin, fdim] @ [nout, nv, nin] -> [..., nout, nv, fdim]\n", " picked_feats = torch.einsum('...nf,ovn->...ovf', feats, sel)\n", " picked_coords = torch.einsum('...ne,ovn->...ove', coords, sel)\n", "\n", " # Aggregate per simplex: [..., nout, fdim]\n", " agg = picked_feats.mean(dim=-2)\n", "\n", " # Deform template: [..., nout, nv, edim]\n", " deform = self._deform(agg).unflatten(-1, (self._nv, self._edim))\n", " verts = self._template + self.BASE_DEFORM * deform + 0.1 * picked_coords\n", "\n", " # CM validation\n", " d2, vol2, valid = self._cm(verts)\n", "\n", " # Geometry gate\n", " geo = torch.cat([\n", " d2 / (d2.mean(dim=-1, keepdim=True) + 1e-8),\n", " vol2.unsqueeze(-1) / (vol2.abs().mean() + 1e-8)\n", " ], dim=-1)\n", " gate = torch.sigmoid(self._gate(geo))\n", "\n", " # Transform with validity attenuation\n", " out_feats = self._norm(self._transform(agg)) * gate * torch.sigmoid(vol2 * 1e6).unsqueeze(-1)\n", " out_coords = verts.mean(dim=-2)\n", "\n", " return out_feats, out_coords, vol2\n", "\n", "\n", "# ============================================================================\n", "# GEOMETRIC COMPLEX - MAINTAINS [...] STRUCTURE\n", "# ============================================================================\n", "\n", "class GeoComplex(nn.Module):\n", " def __init__(self, nbase, fdim, depth, edim):\n", " super().__init__()\n", " self._nbase = nbase\n", " self._fdim = fdim\n", " self._edim = edim\n", " self._depth = depth\n", "\n", " self._sizes = [nbase * (2 ** i) for i in range(depth + 1)]\n", "\n", " self._proj_f = nn.Linear(fdim, nbase * fdim)\n", " self._proj_c = nn.Linear(fdim, nbase * edim)\n", "\n", " self._levels = nn.ModuleList([\n", " GeoLevel(order=i+1, nin=self._sizes[i], nout=self._sizes[i+1], fdim=fdim, edim=edim)\n", " for i in range(depth)\n", " ])\n", "\n", " def forward(self, x):\n", " \"\"\"\n", " x: [..., fdim] - arbitrary leading dims\n", "\n", " Returns:\n", " feats: [..., final_size, fdim]\n", " coords: [..., final_size, edim]\n", " vol2_stack: [..., depth, max_size]\n", " \"\"\"\n", " # Project: [..., fdim] -> [..., nbase, fdim]\n", " feats = self._proj_f(x).unflatten(-1, (self._nbase, self._fdim))\n", " coords = self._proj_c(x).unflatten(-1, (self._nbase, self._edim))\n", "\n", " max_size = self._sizes[-1]\n", " vol2_list = []\n", "\n", " for lv in self._levels:\n", " feats, coords, vol2 = lv(feats, coords)\n", " # Pad to max_size: [..., nout] -> [..., max_size]\n", " pad_size = max_size - vol2.shape[-1]\n", " vol2_padded = F.pad(vol2, (0, pad_size), value=1.0)\n", " vol2_list.append(vol2_padded)\n", "\n", " # Stack: list of [..., max_size] -> [..., depth, max_size]\n", " vol2_stack = torch.stack(vol2_list, dim=-2)\n", "\n", " return feats, coords, vol2_stack\n", "\n", "\n", "# ============================================================================\n", "# GEOMETRIC PATCH EMBEDDING - NO FLATTEN\n", "# ============================================================================\n", "\n", "class GeoPatchEmbed(nn.Module):\n", " def __init__(self, patch_dim, fdim, nbase, depth, edim):\n", " super().__init__()\n", " self._fdim = fdim\n", " self._edim = edim\n", "\n", " self._proj = nn.Sequential(\n", " nn.Linear(patch_dim, fdim),\n", " nn.LayerNorm(fdim),\n", " nn.GELU(),\n", " )\n", " self._complex = GeoComplex(nbase, fdim, depth, edim)\n", " self._final_size = self._complex._sizes[-1]\n", "\n", " def forward(self, patches):\n", " \"\"\"\n", " patches: [B, P, patch_dim]\n", "\n", " Returns:\n", " feats: [B, P, S, fdim] - S = final_size\n", " coords: [B, P, S, edim]\n", " vol2: [B, P, depth, max_size]\n", " \"\"\"\n", " # [B, P, patch_dim] -> [B, P, fdim]\n", " x = self._proj(patches)\n", "\n", " # [B, P, fdim] -> [B, P, S, fdim], [B, P, S, edim], [B, P, depth, max_size]\n", " # GeoComplex operates on last dim, preserves [B, P, ...]\n", " feats, coords, vol2 = self._complex(x)\n", "\n", " return feats, coords, vol2\n", "\n", "\n", "# ============================================================================\n", "# 4-SIMPLEX ATTENTION - OPERATES ON [B, P, S, fdim]\n", "# ============================================================================\n", "\n", "class Simplex4Attn(nn.Module):\n", " def __init__(self, fdim, edim, num_heads=4):\n", " super().__init__()\n", " self._fdim = fdim\n", " self._edim = edim\n", " self._num_heads = num_heads\n", " self._head_dim = fdim // num_heads\n", " self._nv = 5\n", "\n", " self._cm = CMValidator(4)\n", "\n", " factory = SimplexFactory(k=4, embed_dim=edim, method=\"regular\", scale=1.0)\n", " self.register_buffer('_template', factory.build_torch(dtype=torch.float32))\n", "\n", " self._to_sel = nn.Linear(fdim, num_heads * self._nv)\n", " self._to_coord = nn.Linear(fdim, edim)\n", " self._to_v = nn.Linear(fdim, fdim)\n", " self._geo_gate = nn.Linear(self._cm._npairs + 1, fdim)\n", " self._out = nn.Linear(fdim, fdim)\n", "\n", " def forward(self, feats, coords):\n", " \"\"\"\n", " feats: [B, P, S, fdim]\n", " coords: [B, P, S, edim]\n", "\n", " Returns:\n", " out: [B, P, S, fdim]\n", " vol2: [B, H]\n", " \"\"\"\n", " B, P, S, _ = feats.shape\n", " H = self._num_heads\n", "\n", " # Aggregate per patch: [B, P, fdim]\n", " patch_feats = feats.mean(dim=2)\n", " patch_coords = self._to_coord(patch_feats) # [B, P, edim]\n", "\n", " # Selection: [B, P, H*5] -> [B, H, 5, P]\n", " sel = self._to_sel(patch_feats).view(B, P, H, self._nv).permute(0, 2, 3, 1)\n", " sel_w = F.softmax(sel, dim=-1)\n", "\n", " # Select coords: [B, H, 5, P] @ [B, P, edim] -> [B, H, 5, edim]\n", " sel_coords = torch.einsum('bhvp,bpe->bhve', sel_w, patch_coords)\n", "\n", " # Form 4-simplex\n", " verts = self._template + 0.05 * sel_coords\n", " d2, vol2, valid = self._cm(verts) # vol2: [B, H]\n", "\n", " # Gate\n", " geo = torch.cat([d2, vol2.unsqueeze(-1)], dim=-1)\n", " gate = torch.sigmoid(self._geo_gate(geo)).mean(dim=1) # [B, fdim]\n", "\n", " # Values: [B, P, S, fdim] -> [B, P, S, H, hd]\n", " v = self._to_v(feats).view(B, P, S, H, self._head_dim)\n", "\n", " # Attention: [B, H, P] from selection\n", " attn = sel_w.mean(dim=2)\n", " attn = attn / (attn.sum(dim=-1, keepdim=True) + 1e-8)\n", "\n", " # Aggregate: [B, H, P] @ [B, P, S, H, hd] -> [B, H, hd]\n", " v_perm = v.permute(0, 3, 1, 2, 4) # [B, H, P, S, hd]\n", " v_agg = torch.einsum('bhp,bhpd->bhd', attn, v_perm.mean(dim=3))\n", "\n", " # Validity\n", " head_valid = torch.sigmoid(vol2 * 1e6)\n", " v_agg = v_agg * head_valid.unsqueeze(-1)\n", "\n", " # Output\n", " out = self._out(v_agg.reshape(B, self._fdim)) * gate\n", " out = out[:, None, None, :].expand(-1, P, S, -1)\n", "\n", " return out, vol2\n", "\n", "\n", "# ============================================================================\n", "# TRANSFORMER BLOCK\n", "# ============================================================================\n", "\n", "class GeoBlock(nn.Module):\n", " def __init__(self, fdim, edim, num_heads, mlp_ratio=4.0):\n", " super().__init__()\n", " self._norm1 = nn.LayerNorm(fdim)\n", " self._attn = Simplex4Attn(fdim, edim, num_heads)\n", " self._norm2 = nn.LayerNorm(fdim)\n", " self._mlp = nn.Sequential(\n", " nn.Linear(fdim, int(fdim * mlp_ratio)),\n", " nn.GELU(),\n", " nn.Linear(int(fdim * mlp_ratio), fdim),\n", " )\n", "\n", " def forward(self, feats, coords):\n", " \"\"\"\n", " feats: [B, P, S, fdim]\n", " coords: [B, P, S, edim]\n", " \"\"\"\n", " attn_out, vol2 = self._attn(self._norm1(feats), coords)\n", " feats = feats + attn_out\n", " feats = feats + self._mlp(self._norm2(feats))\n", " return feats, vol2\n", "\n", "\n", "# ============================================================================\n", "# FULL MODEL\n", "# ============================================================================\n", "\n", "class GeometricPatchworkViT(nn.Module):\n", " def __init__(\n", " self,\n", " img_size=28,\n", " patch_size=7,\n", " in_chans=1,\n", " num_classes=10,\n", " fdim=64,\n", " edim=8,\n", " nbase=4,\n", " geo_depth=4,\n", " attn_depth=4,\n", " num_heads=4,\n", " ):\n", " super().__init__()\n", "\n", " self.patch_size = patch_size\n", " self.num_patches = (img_size // patch_size) ** 2\n", " patch_dim = patch_size * patch_size * in_chans\n", "\n", " self._embed = GeoPatchEmbed(patch_dim, fdim, nbase, geo_depth, edim)\n", " self._pos = nn.Parameter(torch.randn(1, self.num_patches, 1, edim) * 0.02)\n", "\n", " self._blocks = nn.ModuleList([\n", " GeoBlock(fdim, edim, num_heads) for _ in range(attn_depth)\n", " ])\n", "\n", " self._norm = nn.LayerNorm(fdim)\n", " self._head = nn.Linear(fdim, num_classes)\n", "\n", " print(f\"\\nGeometricPatchworkViT (No Flatten):\")\n", " print(f\" Patches: {self.num_patches}\")\n", " print(f\" K-hierarchy sizes: {self._embed._complex._sizes}\")\n", " print(f\" Structure: [B, {self.num_patches}, {self._embed._final_size}, {fdim}]\")\n", "\n", " def forward(self, x):\n", " \"\"\"\n", " x: [B, C, H, W]\n", "\n", " Returns:\n", " logits: [B, num_classes]\n", " vol2_patch: [B, P, depth, max_size]\n", " vol2_attn: [B, attn_depth, H]\n", " \"\"\"\n", " B = x.shape[0]\n", "\n", " # Extract patches: [B, P, patch_dim]\n", " patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)\n", " patches = patches.contiguous().view(B, self.num_patches, -1)\n", "\n", " # Embed: [B, P, S, fdim], [B, P, S, edim], [B, P, depth, max_size]\n", " feats, coords, vol2_patch = self._embed(patches)\n", " coords = coords + self._pos\n", "\n", " # Blocks\n", " vol2_attn = []\n", " for blk in self._blocks:\n", " feats, vol2 = blk(feats, coords)\n", " vol2_attn.append(vol2)\n", "\n", " vol2_attn = torch.stack(vol2_attn, dim=1) # [B, attn_depth, H]\n", "\n", " # Classify\n", " logits = self._head(self._norm(feats.mean(dim=[1, 2])))\n", "\n", " return logits, vol2_patch, vol2_attn\n", "\n", "\n", "# ============================================================================\n", "# LOSS FUNCTION\n", "# ============================================================================\n", "\n", "def geometric_loss(logits, labels, vol2_patch, vol2_attn, ce_weight=1.0, validity_weight=0.1):\n", " ce = F.cross_entropy(logits, labels)\n", " validity = F.relu(-vol2_patch).mean() + F.relu(-vol2_attn).mean()\n", "\n", " total = ce_weight * ce + validity_weight * validity\n", "\n", " with torch.no_grad():\n", " info = {\n", " 'ce': ce.item(),\n", " 'validity': validity.item(),\n", " 'patch_valid': (vol2_patch > 0).float().mean().item(),\n", " 'attn_valid': (vol2_attn > 0).float().mean().item(),\n", " }\n", "\n", " return total, info\n", "\n", "\n", "# ============================================================================\n", "# TRAIN\n", "# ============================================================================\n", "\n", "def train():\n", " transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.2860,), (0.3530,))\n", " ])\n", "\n", " train_ds = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)\n", " test_ds = datasets.FashionMNIST('./data', train=False, download=True, transform=transform)\n", " train_dl = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=2)\n", " test_dl = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=2)\n", "\n", " print(\"Building model...\")\n", " model = GeometricPatchworkViT(\n", " img_size=28,\n", " patch_size=7,\n", " in_chans=1,\n", " num_classes=10,\n", " fdim=64,\n", " edim=8,\n", " nbase=4,\n", " geo_depth=4,\n", " attn_depth=4,\n", " num_heads=4,\n", " ).to(device)\n", "\n", " params = sum(p.numel() for p in model.parameters())\n", " print(f\"Total params: {params:,}\")\n", "\n", " opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)\n", " sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=30)\n", "\n", " best = 0\n", " print(\"\\nTraining...\")\n", " print(\"=\" * 90)\n", "\n", " for ep in range(30):\n", " model.train()\n", " ce_sum, val_sum, cor, tot = 0, 0, 0, 0\n", "\n", " for img, lab in train_dl:\n", " img, lab = img.to(device), lab.to(device)\n", "\n", " opt.zero_grad()\n", " logits, vol2_p, vol2_a = model(img)\n", " loss, info = geometric_loss(logits, lab, vol2_p, vol2_a)\n", " loss.backward()\n", " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n", " opt.step()\n", "\n", " ce_sum += info['ce'] * img.size(0)\n", " val_sum += info['validity'] * img.size(0)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", "\n", " tr_acc = cor / tot\n", "\n", " model.eval()\n", " cor, tot = 0, 0\n", " p_valid, a_valid = 0, 0\n", " with torch.no_grad():\n", " for img, lab in test_dl:\n", " img, lab = img.to(device), lab.to(device)\n", " logits, vol2_p, vol2_a = model(img)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", " p_valid += (vol2_p > 0).float().mean().item() * img.size(0)\n", " a_valid += (vol2_a > 0).float().mean().item() * img.size(0)\n", "\n", " te_acc = cor / tot\n", " p_valid /= tot\n", " a_valid /= tot\n", " sched.step()\n", " if te_acc > best:\n", " best = te_acc\n", "\n", " if ep % 5 == 0 or ep == 29:\n", " print(f\"Ep {ep+1:2d} | CE {ce_sum/tot:.4f} | Val {val_sum/tot:.4f} | \"\n", " f\"Tr {tr_acc:.2%} | Te {te_acc:.2%} | Best {best:.2%} | \"\n", " f\"Patch {p_valid:.1%} | Attn {a_valid:.1%}\")\n", "\n", " print(\"=\" * 90)\n", " print(f\"Best: {best:.2%}\")\n", " return model\n", "\n", "\n", "if __name__ == \"__main__\":\n", " model = train()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "gUA5MxPslWZV", "outputId": "92b842ab-ed95-40de-dfba-0ed2e0331337" }, "execution_count": 3, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Device: cuda\n", "Building model...\n", "\n", "GeometricPatchworkViT (No Flatten):\n", " Patches: 16\n", " K-hierarchy sizes: [4, 8, 16, 32, 64]\n", " Structure: [B, 16, 64, 64]\n", "Total params: 238,922\n", "\n", "Training...\n", "==========================================================================================\n", "Ep 1 | CE 4.9395 | Val 0.0000 | Tr 70.26% | Te 76.86% | Best 76.86% | Patch 100.0% | Attn 100.0%\n", "Ep 6 | CE 2.5194 | Val 0.0000 | Tr 84.74% | Te 83.96% | Best 83.96% | Patch 100.0% | Attn 100.0%\n", "Ep 11 | CE 2.0789 | Val 0.0000 | Tr 87.31% | Te 85.96% | Best 85.96% | Patch 100.0% | Attn 100.0%\n", "Ep 16 | CE 1.7462 | Val 0.0000 | Tr 89.17% | Te 86.77% | Best 86.82% | Patch 100.0% | Attn 100.0%\n", "Ep 21 | CE 1.4106 | Val 0.0000 | Tr 91.19% | Te 87.88% | Best 88.07% | Patch 100.0% | Attn 100.0%\n", "Ep 26 | CE 1.0994 | Val 0.0000 | Tr 93.19% | Te 88.31% | Best 88.34% | Patch 100.0% | Attn 100.0%\n", "Ep 30 | CE 0.9765 | Val 0.0000 | Tr 94.16% | Te 88.27% | Best 88.37% | Patch 100.0% | Attn 100.0%\n", "==========================================================================================\n", "Best: 88.37%\n" ] } ] }, { "cell_type": "code", "source": [ "#@title Geometric Patchwork - K-Simplex as Channels\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets, transforms\n", "import math\n", "from itertools import combinations\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print(f\"Device: {device}\")\n", "\n", "from geovocab2.shapes.factory.simplex_factory import SimplexFactory\n", "\n", "# ============================================================================\n", "# CAYLEY-MENGER VALIDATOR\n", "# ============================================================================\n", "\n", "class CMValidator(nn.Module):\n", " def __init__(self, k):\n", " super().__init__()\n", " self._k = k\n", " self._nv = k + 1\n", "\n", " pairs = list(combinations(range(self._nv), 2))\n", " self._npairs = len(pairs)\n", " self.register_buffer('_pi', torch.tensor([p[0] for p in pairs], dtype=torch.long))\n", " self.register_buffer('_pj', torch.tensor([p[1] for p in pairs], dtype=torch.long))\n", "\n", " sign = (-1.0) ** (k + 1)\n", " fact = math.factorial(k)\n", " self._prefactor = sign / ((2.0 ** k) * (fact ** 2))\n", "\n", " def forward(self, verts):\n", " \"\"\"verts: [..., nv, edim]\"\"\"\n", " gram = torch.einsum('...ve,...we->...vw', verts, verts)\n", " norms = torch.diagonal(gram, dim1=-2, dim2=-1)\n", " d2_mat = norms.unsqueeze(-1) + norms.unsqueeze(-2) - 2 * gram\n", " d2_mat = F.relu(d2_mat)\n", "\n", " d2_pairs = d2_mat[..., self._pi, self._pj]\n", "\n", " shape = d2_mat.shape[:-2]\n", " V = d2_mat.shape[-1]\n", " cm = torch.zeros(*shape, V+1, V+1, device=d2_mat.device, dtype=d2_mat.dtype)\n", " cm[..., 0, 1:] = 1.0\n", " cm[..., 1:, 0] = 1.0\n", " cm[..., 1:, 1:] = d2_mat\n", "\n", " vol2 = self._prefactor * torch.linalg.det(cm)\n", "\n", " return d2_pairs, vol2\n", "\n", "\n", "# ============================================================================\n", "# K-SIMPLEX CHANNEL ENCODER\n", "# ============================================================================\n", "\n", "class KSimplexChannel(nn.Module):\n", " \"\"\"\n", " Encodes input into a single k-simplex representation.\n", " Output is the geometric features: [d², vol²] flattened.\n", " \"\"\"\n", " BASE_DEFORM = 0.05\n", "\n", " def __init__(self, k, in_dim, edim):\n", " super().__init__()\n", " self._k = k\n", " self._nv = k + 1\n", " self._edim = edim\n", "\n", " self._cm = CMValidator(k)\n", " self._out_dim = self._cm._npairs + 1 # distances + volume\n", "\n", " # Factory template\n", " factory = SimplexFactory(k=k, embed_dim=edim, method=\"regular\", scale=1.0)\n", " self.register_buffer('_template', factory.build_torch(dtype=torch.float32))\n", "\n", " # Input → vertex deformations\n", " self._to_deform = nn.Linear(in_dim, self._nv * edim)\n", "\n", " @property\n", " def out_dim(self):\n", " return self._out_dim\n", "\n", " def forward(self, x):\n", " \"\"\"\n", " x: [..., in_dim]\n", "\n", " Returns:\n", " geo: [..., npairs + 1] (d² pairs + vol²)\n", " vol2: [...] (for validity loss)\n", " \"\"\"\n", " # Deform template\n", " deform = self._to_deform(x).unflatten(-1, (self._nv, self._edim))\n", " verts = self._template + self.BASE_DEFORM * deform\n", "\n", " # CM validation\n", " d2, vol2 = self._cm(verts)\n", "\n", " # Geometric features\n", " geo = torch.cat([d2, vol2.unsqueeze(-1)], dim=-1)\n", "\n", " return geo, vol2\n", "\n", "\n", "# ============================================================================\n", "# PATCH TO K-SIMPLEX CHANNELS\n", "# ============================================================================\n", "\n", "class PatchToKChannels(nn.Module):\n", " \"\"\"\n", " Converts a patch into k-simplex channel representation.\n", "\n", " Input: [B, P_h, P_w, C*H*W]\n", " Output: [B, P_h, P_w, K, F] where K=depth k-levels, F=features per level\n", " \"\"\"\n", " def __init__(self, patch_dim, depth, edim):\n", " super().__init__()\n", " self._depth = depth\n", " self._edim = edim\n", "\n", " # Project patch to intermediate\n", " hidden = max(patch_dim, 64)\n", " self._proj = nn.Sequential(\n", " nn.Linear(patch_dim, hidden),\n", " nn.LayerNorm(hidden),\n", " nn.GELU(),\n", " )\n", "\n", " # K-simplex encoders: k=1, 2, 3, 4, ...\n", " self._k_encoders = nn.ModuleList([\n", " KSimplexChannel(k=k+1, in_dim=hidden, edim=edim)\n", " for k in range(depth)\n", " ])\n", "\n", " # Feature dim per level (varies by k: npairs + 1)\n", " self._k_dims = [enc.out_dim for enc in self._k_encoders]\n", " # k=1: 1+1=2, k=2: 3+1=4, k=3: 6+1=7, k=4: 10+1=11\n", "\n", " # Pad to uniform size for stacking\n", " self._max_dim = max(self._k_dims)\n", "\n", " def forward(self, patches):\n", " \"\"\"\n", " patches: [B, P_h, P_w, patch_dim]\n", "\n", " Returns:\n", " k_channels: [B, P_h, P_w, K, F] K=depth, F=max_dim\n", " vol2: [B, P_h, P_w, K] for loss\n", " \"\"\"\n", " # Project\n", " h = self._proj(patches) # [B, P_h, P_w, hidden]\n", "\n", " # Encode each k-level\n", " geo_list = []\n", " vol2_list = []\n", "\n", " for enc in self._k_encoders:\n", " geo, vol2 = enc(h) # geo: [..., k_dim], vol2: [...]\n", "\n", " # Pad to max_dim\n", " pad_size = self._max_dim - geo.shape[-1]\n", " if pad_size > 0:\n", " geo = F.pad(geo, (0, pad_size))\n", "\n", " geo_list.append(geo)\n", " vol2_list.append(vol2)\n", "\n", " # Stack: [B, P_h, P_w, K, F]\n", " k_channels = torch.stack(geo_list, dim=-2)\n", " vol2 = torch.stack(vol2_list, dim=-1) # [B, P_h, P_w, K]\n", "\n", " return k_channels, vol2\n", "\n", "\n", "# ============================================================================\n", "# K-SIMPLEX CHANNEL ATTENTION\n", "# ============================================================================\n", "\n", "class KChannelAttention(nn.Module):\n", " \"\"\"\n", " Attention over patches using their k-simplex channel structure.\n", "\n", " Input: [B, P_h, P_w, K, F]\n", " Output: [B, P_h, P_w, K, F]\n", " \"\"\"\n", " def __init__(self, depth, feat_dim, num_heads=4):\n", " super().__init__()\n", " self._depth = depth\n", " self._feat_dim = feat_dim\n", " self._num_heads = num_heads\n", "\n", " # Total features across k-channels\n", " total_dim = depth * feat_dim\n", "\n", " self._norm = nn.LayerNorm(total_dim)\n", " self._to_qkv = nn.Linear(total_dim, 3 * total_dim)\n", " self._out = nn.Linear(total_dim, total_dim)\n", "\n", " self._scale = (total_dim // num_heads) ** -0.5\n", "\n", " def forward(self, x):\n", " \"\"\"\n", " x: [B, P_h, P_w, K, F]\n", " \"\"\"\n", " B, Ph, Pw, K, F = x.shape\n", "\n", " # Flatten k-channels: [B, P_h, P_w, K*F]\n", " x_flat = x.flatten(-2)\n", "\n", " # Reshape to sequence: [B, P_h*P_w, K*F]\n", " x_seq = x_flat.view(B, Ph * Pw, K * F)\n", "\n", " # Attention\n", " x_norm = self._norm(x_seq)\n", " qkv = self._to_qkv(x_norm).chunk(3, dim=-1)\n", " q, k, v = [t.view(B, Ph * Pw, self._num_heads, -1).transpose(1, 2) for t in qkv]\n", "\n", " attn = (q @ k.transpose(-2, -1)) * self._scale\n", " attn = attn.softmax(dim=-1)\n", "\n", " out = (attn @ v).transpose(1, 2).reshape(B, Ph * Pw, K * F)\n", " out = self._out(out)\n", "\n", " # Reshape back: [B, P_h, P_w, K, F]\n", " out = out.view(B, Ph, Pw, K, F)\n", "\n", " return x + out\n", "\n", "\n", "# ============================================================================\n", "# FULL MODEL\n", "# ============================================================================\n", "\n", "class GeometricPatchworkViT(nn.Module):\n", " def __init__(\n", " self,\n", " img_size=28,\n", " patch_size=7,\n", " in_chans=1,\n", " num_classes=10,\n", " depth=4, # k=1,2,3,4\n", " edim=8,\n", " num_heads=4,\n", " num_blocks=4,\n", " ):\n", " super().__init__()\n", "\n", " self._patch_size = patch_size\n", " self._ph = img_size // patch_size\n", " self._pw = img_size // patch_size\n", " self._patch_dim = patch_size * patch_size * in_chans\n", "\n", " # Patch to k-channels\n", " self._patch_enc = PatchToKChannels(self._patch_dim, depth, edim)\n", " self._max_dim = self._patch_enc._max_dim\n", "\n", " # Position embedding: [1, P_h, P_w, K, F]\n", " self._pos = nn.Parameter(torch.randn(1, self._ph, self._pw, depth, self._max_dim) * 0.02)\n", "\n", " # Attention blocks\n", " self._blocks = nn.ModuleList([\n", " KChannelAttention(depth, self._max_dim, num_heads)\n", " for _ in range(num_blocks)\n", " ])\n", "\n", " # Classification\n", " self._norm = nn.LayerNorm(depth * self._max_dim)\n", " self._head = nn.Linear(depth * self._max_dim, num_classes)\n", "\n", " print(f\"\\nGeometricPatchworkViT:\")\n", " print(f\" Image: {img_size}×{img_size}, Patch: {patch_size}×{patch_size}\")\n", " print(f\" Grid: {self._ph}×{self._pw} = {self._ph * self._pw} patches\")\n", " print(f\" K-simplex channels: k=1..{depth}\")\n", " print(f\" Feature dims per k: {self._patch_enc._k_dims}\")\n", " print(f\" Structure: [B, {self._ph}, {self._pw}, {depth}, {self._max_dim}]\")\n", "\n", " def forward(self, x):\n", " \"\"\"\n", " x: [B, C, H, W]\n", "\n", " Returns:\n", " logits: [B, num_classes]\n", " vol2: [B, P_h, P_w, K]\n", " \"\"\"\n", " B = x.shape[0]\n", "\n", " # Extract patches: [B, P_h, P_w, patch_dim]\n", " patches = x.unfold(2, self._patch_size, self._patch_size) \\\n", " .unfold(3, self._patch_size, self._patch_size)\n", " patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()\n", " patches = patches.view(B, self._ph, self._pw, -1)\n", "\n", " # Encode to k-channels: [B, P_h, P_w, K, F]\n", " k_channels, vol2 = self._patch_enc(patches)\n", "\n", " # Add position\n", " k_channels = k_channels + self._pos\n", "\n", " # Attention blocks\n", " for blk in self._blocks:\n", " k_channels = blk(k_channels)\n", "\n", " # Classify: mean over spatial, flatten k-channels\n", " pooled = k_channels.mean(dim=[1, 2]) # [B, K, F]\n", " pooled = pooled.flatten(1) # [B, K*F]\n", " logits = self._head(self._norm(pooled))\n", "\n", " return logits, vol2\n", "\n", "\n", "# ============================================================================\n", "# LOSS\n", "# ============================================================================\n", "\n", "def geometric_loss(logits, labels, vol2, ce_weight=1.0, validity_weight=0.1):\n", " ce = F.cross_entropy(logits, labels)\n", " validity = F.relu(-vol2).mean()\n", "\n", " total = ce_weight * ce + validity_weight * validity\n", "\n", " with torch.no_grad():\n", " info = {\n", " 'ce': ce.item(),\n", " 'validity': validity.item(),\n", " 'valid_rate': (vol2 > 0).float().mean().item(),\n", " }\n", "\n", " return total, info\n", "\n", "\n", "# ============================================================================\n", "# TRAIN\n", "# ============================================================================\n", "\n", "def train():\n", " transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.2860,), (0.3530,))\n", " ])\n", "\n", " train_ds = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)\n", " test_ds = datasets.FashionMNIST('./data', train=False, download=True, transform=transform)\n", " train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2)\n", " test_dl = DataLoader(test_ds, batch_size=256, shuffle=False, num_workers=2)\n", "\n", " print(\"Building model...\")\n", " model = GeometricPatchworkViT(\n", " img_size=28,\n", " patch_size=2,\n", " in_chans=1,\n", " num_classes=10,\n", " depth=4,\n", " edim=8,\n", " num_heads=4,\n", " num_blocks=4,\n", " ).to(device)\n", "\n", " params = sum(p.numel() for p in model.parameters())\n", " print(f\"Params: {params:,}\")\n", "\n", " model = torch.compile(model, mode='reduce-overhead')\n", "\n", " opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)\n", " sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=30)\n", "\n", " best = 0\n", " print(\"\\nTraining...\")\n", " print(\"=\" * 85)\n", "\n", " for ep in range(30):\n", " model.train()\n", " ce_sum, val_sum, cor, tot = 0, 0, 0, 0\n", "\n", " for img, lab in train_dl:\n", " img, lab = img.to(device), lab.to(device)\n", "\n", " opt.zero_grad()\n", " logits, vol2 = model(img)\n", " loss, info = geometric_loss(logits, lab, vol2)\n", " loss.backward()\n", " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n", " opt.step()\n", "\n", " ce_sum += info['ce'] * img.size(0)\n", " val_sum += info['validity'] * img.size(0)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", "\n", " tr_acc = cor / tot\n", "\n", " model.eval()\n", " cor, tot, valid_sum = 0, 0, 0\n", " with torch.no_grad():\n", " for img, lab in test_dl:\n", " img, lab = img.to(device), lab.to(device)\n", " logits, vol2 = model(img)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", " valid_sum += (vol2 > 0).float().mean().item() * img.size(0)\n", "\n", " te_acc = cor / tot\n", " valid_rate = valid_sum / tot\n", " sched.step()\n", " if te_acc > best:\n", " best = te_acc\n", "\n", " if ep % 5 == 0 or ep == 29:\n", " print(f\"Ep {ep+1:2d} | CE {ce_sum/tot:.4f} | Val {val_sum/tot:.4f} | \"\n", " f\"Tr {tr_acc:.2%} | Te {te_acc:.2%} | Best {best:.2%} | Valid {valid_rate:.1%}\")\n", "\n", " print(\"=\" * 85)\n", " print(f\"Best: {best:.2%}\")\n", " return model\n", "\n", "\n", "if __name__ == \"__main__\":\n", " model = train()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "kSPu4uK7zHZq", "outputId": "0ae2ffc4-415b-4dd2-ad48-e36d63a7ae7e" }, "execution_count": 5, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Device: cuda\n", "Building model...\n", "\n", "GeometricPatchworkViT:\n", " Image: 28×28, Patch: 2×2\n", " Grid: 14×14 = 196 patches\n", " K-simplex channels: k=1..4\n", " Feature dims per k: [2, 4, 7, 11]\n", " Structure: [B, 14, 14, 4, 11]\n", "Params: 48,922\n", "\n", "Training...\n", "=====================================================================================\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:1744: FutureWarning: `torch._prims_common.check` is deprecated and will be removed in the future. Please use `torch._check*` functions instead.\n", " check(\n", "/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:1744: FutureWarning: `torch._prims_common.check` is deprecated and will be removed in the future. Please use `torch._check*` functions instead.\n", " check(\n", "/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:1744: FutureWarning: `torch._prims_common.check` is deprecated and will be removed in the future. Please use `torch._check*` functions instead.\n", " check(\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "Ep 1 | CE 7.3460 | Val 0.0000 | Tr 53.06% | Te 77.26% | Best 77.26% | Valid 100.0%\n", "Ep 6 | CE 2.2888 | Val 0.0000 | Tr 85.96% | Te 85.46% | Best 85.46% | Valid 100.0%\n", "Ep 11 | CE 1.8840 | Val 0.0000 | Tr 88.44% | Te 87.16% | Best 87.77% | Valid 100.0%\n", "Ep 16 | CE 1.6329 | Val 0.0000 | Tr 89.96% | Te 87.67% | Best 88.12% | Valid 100.0%\n", "Ep 21 | CE 1.4449 | Val 0.0000 | Tr 91.02% | Te 88.76% | Best 88.82% | Valid 100.0%\n", "Ep 26 | CE 1.2977 | Val 0.0000 | Tr 92.05% | Te 88.88% | Best 89.14% | Valid 100.0%\n", "Ep 30 | CE 1.2468 | Val 0.0000 | Tr 92.40% | Te 89.05% | Best 89.14% | Valid 100.0%\n", "=====================================================================================\n", "Best: 89.14%\n" ] } ] }, { "cell_type": "code", "source": [ "#@title Geometric Patchwork - CIFAR-10, Cross-Attention, Compiled\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets, transforms\n", "import math\n", "from itertools import combinations\n", "import time\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print(f\"Device: {device}\")\n", "\n", "from geovocab2.shapes.factory.simplex_factory import SimplexFactory\n", "\n", "# ============================================================================\n", "# CAYLEY-MENGER VALIDATOR\n", "# ============================================================================\n", "\n", "class CMValidator(nn.Module):\n", " def __init__(self, k):\n", " super().__init__()\n", " self._k = k\n", " self._nv = k + 1\n", "\n", " pairs = list(combinations(range(self._nv), 2))\n", " self._npairs = len(pairs)\n", " self.register_buffer('_pi', torch.tensor([p[0] for p in pairs], dtype=torch.long))\n", " self.register_buffer('_pj', torch.tensor([p[1] for p in pairs], dtype=torch.long))\n", "\n", " sign = (-1.0) ** (k + 1)\n", " fact = math.factorial(k)\n", " self._prefactor = sign / ((2.0 ** k) * (fact ** 2))\n", "\n", " def forward(self, verts):\n", " gram = torch.einsum('...ve,...we->...vw', verts, verts)\n", " norms = torch.diagonal(gram, dim1=-2, dim2=-1)\n", " d2_mat = norms.unsqueeze(-1) + norms.unsqueeze(-2) - 2 * gram\n", " d2_mat = F.relu(d2_mat)\n", "\n", " d2_pairs = d2_mat[..., self._pi, self._pj]\n", "\n", " shape = d2_mat.shape[:-2]\n", " V = d2_mat.shape[-1]\n", " cm = torch.zeros(*shape, V+1, V+1, device=d2_mat.device, dtype=d2_mat.dtype)\n", " cm[..., 0, 1:] = 1.0\n", " cm[..., 1:, 0] = 1.0\n", " cm[..., 1:, 1:] = d2_mat\n", "\n", " vol2 = self._prefactor * torch.linalg.det(cm)\n", "\n", " return d2_pairs, vol2\n", "\n", "\n", "# ============================================================================\n", "# K-SIMPLEX CHANNEL ENCODER\n", "# ============================================================================\n", "\n", "class KSimplexChannel(nn.Module):\n", " BASE_DEFORM = 0.05\n", "\n", " def __init__(self, k, in_dim, edim):\n", " super().__init__()\n", " self._k = k\n", " self._nv = k + 1\n", " self._edim = edim\n", "\n", " self._cm = CMValidator(k)\n", " self._out_dim = self._cm._npairs + 1\n", "\n", " factory = SimplexFactory(k=k, embed_dim=edim, method=\"regular\", scale=1.0)\n", " self.register_buffer('_template', factory.build_torch(dtype=torch.float32))\n", "\n", " self._to_deform = nn.Linear(in_dim, self._nv * edim)\n", "\n", " @property\n", " def out_dim(self):\n", " return self._out_dim\n", "\n", " def forward(self, x):\n", " deform = self._to_deform(x).unflatten(-1, (self._nv, self._edim))\n", " verts = self._template + self.BASE_DEFORM * deform\n", " d2, vol2 = self._cm(verts)\n", " geo = torch.cat([d2, vol2.unsqueeze(-1)], dim=-1)\n", " return geo, vol2\n", "\n", "\n", "# ============================================================================\n", "# PATCH TO K-SIMPLEX CHANNELS\n", "# ============================================================================\n", "\n", "class PatchToKChannels(nn.Module):\n", " def __init__(self, patch_dim, depth, edim, hidden=128):\n", " super().__init__()\n", " self._depth = depth\n", " self._edim = edim\n", "\n", " self._proj = nn.Sequential(\n", " nn.Linear(patch_dim, hidden),\n", " nn.LayerNorm(hidden),\n", " nn.GELU(),\n", " )\n", "\n", " self._k_encoders = nn.ModuleList([\n", " KSimplexChannel(k=k+1, in_dim=hidden, edim=edim)\n", " for k in range(depth)\n", " ])\n", "\n", " self._k_dims = [enc.out_dim for enc in self._k_encoders]\n", " self._max_dim = max(self._k_dims)\n", "\n", " def forward(self, patches):\n", " h = self._proj(patches)\n", "\n", " geo_list = []\n", " vol2_list = []\n", " d2_list = []\n", "\n", " for enc in self._k_encoders:\n", " geo, vol2 = enc(h)\n", " d2 = geo[..., :-1] # All but last (volume)\n", "\n", " pad_size = self._max_dim - geo.shape[-1]\n", " if pad_size > 0:\n", " geo = F.pad(geo, (0, pad_size))\n", "\n", " geo_list.append(geo)\n", " vol2_list.append(vol2)\n", " d2_list.append(d2.mean(dim=-1)) # Mean d² per k\n", "\n", " k_channels = torch.stack(geo_list, dim=-2)\n", " vol2 = torch.stack(vol2_list, dim=-1)\n", " d2_mean = torch.stack(d2_list, dim=-1)\n", "\n", " return k_channels, vol2, d2_mean\n", "\n", "\n", "# ============================================================================\n", "# K-CHANNEL CROSS-ATTENTION\n", "# ============================================================================\n", "\n", "class KChannelCrossAttention(nn.Module):\n", " \"\"\"\n", " Cross-attention between k-levels.\n", " Each k-level attends to all other k-levels.\n", " \"\"\"\n", " def __init__(self, depth, feat_dim, num_heads=4):\n", " super().__init__()\n", " self._depth = depth\n", " self._feat_dim = feat_dim\n", " self._num_heads = num_heads\n", " self._head_dim = feat_dim // num_heads\n", "\n", " self._norm_q = nn.LayerNorm(feat_dim)\n", " self._norm_kv = nn.LayerNorm(feat_dim)\n", "\n", " self._to_q = nn.Linear(feat_dim, feat_dim)\n", " self._to_k = nn.Linear(feat_dim, feat_dim)\n", " self._to_v = nn.Linear(feat_dim, feat_dim)\n", " self._out = nn.Linear(feat_dim, feat_dim)\n", "\n", " self._scale = self._head_dim ** -0.5\n", "\n", " def forward(self, x):\n", " \"\"\"\n", " x: [B, P_h, P_w, K, F]\n", "\n", " Cross-attention across K dimension for each spatial location.\n", " \"\"\"\n", " B, Ph, Pw, K, F = x.shape\n", "\n", " # Reshape: [B*Ph*Pw, K, F]\n", " x_flat = x.view(B * Ph * Pw, K, F)\n", "\n", " q = self._to_q(self._norm_q(x_flat))\n", " k = self._to_k(self._norm_kv(x_flat))\n", " v = self._to_v(self._norm_kv(x_flat))\n", "\n", " # Multi-head: [B*Ph*Pw, K, H, D] -> [B*Ph*Pw, H, K, D]\n", " q = q.view(-1, K, self._num_heads, self._head_dim).transpose(1, 2)\n", " k = k.view(-1, K, self._num_heads, self._head_dim).transpose(1, 2)\n", " v = v.view(-1, K, self._num_heads, self._head_dim).transpose(1, 2)\n", "\n", " attn = (q @ k.transpose(-2, -1)) * self._scale\n", " attn = attn.softmax(dim=-1)\n", "\n", " out = (attn @ v).transpose(1, 2).reshape(B * Ph * Pw, K, F)\n", " out = self._out(out)\n", "\n", " return x + out.view(B, Ph, Pw, K, F), attn.view(B, Ph, Pw, self._num_heads, K, K)\n", "\n", "\n", "# ============================================================================\n", "# SPATIAL ATTENTION\n", "# ============================================================================\n", "\n", "class SpatialAttention(nn.Module):\n", " \"\"\"\n", " Attention across spatial patches, preserving k-channel structure.\n", " \"\"\"\n", " def __init__(self, depth, feat_dim, num_heads=4):\n", " super().__init__()\n", " self._num_heads = num_heads\n", "\n", " total_dim = depth * feat_dim\n", " self._head_dim = total_dim // num_heads\n", "\n", " self._norm = nn.LayerNorm(total_dim)\n", " self._to_qkv = nn.Linear(total_dim, 3 * total_dim)\n", " self._out = nn.Linear(total_dim, total_dim)\n", "\n", " self._scale = self._head_dim ** -0.5\n", "\n", " def forward(self, x):\n", " \"\"\"\n", " x: [B, P_h, P_w, K, F]\n", " \"\"\"\n", " B, Ph, Pw, K, F = x.shape\n", " N = Ph * Pw\n", "\n", " x_flat = x.view(B, N, K * F)\n", " x_norm = self._norm(x_flat)\n", "\n", " qkv = self._to_qkv(x_norm).chunk(3, dim=-1)\n", " q, k, v = [t.view(B, N, self._num_heads, self._head_dim).transpose(1, 2) for t in qkv]\n", "\n", " attn = (q @ k.transpose(-2, -1)) * self._scale\n", " attn = attn.softmax(dim=-1)\n", "\n", " out = (attn @ v).transpose(1, 2).reshape(B, N, K * F)\n", " out = self._out(out).view(B, Ph, Pw, K, F)\n", "\n", " return x + out, attn\n", "\n", "\n", "# ============================================================================\n", "# TRANSFORMER BLOCK\n", "# ============================================================================\n", "\n", "class GeoBlock(nn.Module):\n", " def __init__(self, depth, feat_dim, num_heads, mlp_ratio=4.0):\n", " super().__init__()\n", "\n", " # Cross-attention across k-levels\n", " self._k_attn = KChannelCrossAttention(depth, feat_dim, num_heads)\n", "\n", " # Spatial attention\n", " self._spatial_attn = SpatialAttention(depth, feat_dim, num_heads)\n", "\n", " # MLP\n", " total_dim = depth * feat_dim\n", " self._norm = nn.LayerNorm(total_dim)\n", " self._mlp = nn.Sequential(\n", " nn.Linear(total_dim, int(total_dim * mlp_ratio)),\n", " nn.GELU(),\n", " nn.Linear(int(total_dim * mlp_ratio), total_dim),\n", " )\n", "\n", " def forward(self, x):\n", " \"\"\"\n", " x: [B, P_h, P_w, K, F]\n", " \"\"\"\n", " B, Ph, Pw, K, F = x.shape\n", "\n", " # K-channel cross-attention\n", " x, k_attn = self._k_attn(x)\n", "\n", " # Spatial attention\n", " x, s_attn = self._spatial_attn(x)\n", "\n", " # MLP\n", " x_flat = x.view(B, Ph, Pw, K * F)\n", " x_flat = x_flat + self._mlp(self._norm(x_flat))\n", " x = x_flat.view(B, Ph, Pw, K, F)\n", "\n", " return x, k_attn, s_attn\n", "\n", "\n", "# ============================================================================\n", "# FULL MODEL\n", "# ============================================================================\n", "\n", "class GeometricPatchworkViT(nn.Module):\n", " def __init__(\n", " self,\n", " img_size=32,\n", " patch_size=8,\n", " in_chans=3,\n", " num_classes=10,\n", " depth=4,\n", " edim=8,\n", " hidden=128,\n", " num_heads=4,\n", " num_blocks=6,\n", " ):\n", " super().__init__()\n", "\n", " self._patch_size = patch_size\n", " self._ph = img_size // patch_size\n", " self._pw = img_size // patch_size\n", " self._patch_dim = patch_size * patch_size * in_chans\n", " self._depth = depth\n", "\n", " # Patch encoder\n", " self._patch_enc = PatchToKChannels(self._patch_dim, depth, edim, hidden)\n", " self._max_dim = self._patch_enc._max_dim\n", "\n", " # Project to larger feature dim\n", " self._feat_dim = 64\n", " self._proj_up = nn.Linear(self._max_dim, self._feat_dim)\n", "\n", " # Position embedding\n", " self._pos = nn.Parameter(torch.randn(1, self._ph, self._pw, depth, self._feat_dim) * 0.02)\n", "\n", " # Transformer blocks\n", " self._blocks = nn.ModuleList([\n", " GeoBlock(depth, self._feat_dim, num_heads)\n", " for _ in range(num_blocks)\n", " ])\n", "\n", " # Classification\n", " total_dim = depth * self._feat_dim\n", " self._norm = nn.LayerNorm(total_dim)\n", " self._head = nn.Linear(total_dim, num_classes)\n", "\n", " # Store config\n", " self._config = {\n", " 'img_size': img_size,\n", " 'patch_size': patch_size,\n", " 'grid': f'{self._ph}x{self._pw}',\n", " 'depth': depth,\n", " 'k_dims': self._patch_enc._k_dims,\n", " 'feat_dim': self._feat_dim,\n", " 'num_blocks': num_blocks,\n", " }\n", "\n", " def forward(self, x):\n", " B = x.shape[0]\n", "\n", " # Extract patches: [B, P_h, P_w, patch_dim]\n", " patches = x.unfold(2, self._patch_size, self._patch_size) \\\n", " .unfold(3, self._patch_size, self._patch_size)\n", " patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()\n", " patches = patches.view(B, self._ph, self._pw, -1)\n", "\n", " # Encode: [B, P_h, P_w, K, max_dim]\n", " k_channels, vol2, d2_mean = self._patch_enc(patches)\n", "\n", " # Project up: [B, P_h, P_w, K, feat_dim]\n", " k_channels = self._proj_up(k_channels)\n", " k_channels = k_channels + self._pos\n", "\n", " # Blocks\n", " k_attns, s_attns = [], []\n", " for blk in self._blocks:\n", " k_channels, k_attn, s_attn = blk(k_channels)\n", " k_attns.append(k_attn)\n", " s_attns.append(s_attn)\n", "\n", " # Classify\n", " pooled = k_channels.mean(dim=[1, 2]).flatten(1)\n", " logits = self._head(self._norm(pooled))\n", "\n", " return logits, {\n", " 'vol2': vol2, # [B, P_h, P_w, K]\n", " 'd2_mean': d2_mean, # [B, P_h, P_w, K]\n", " 'k_attns': k_attns, # List of [B, P_h, P_w, H, K, K]\n", " 's_attns': s_attns, # List of [B, H, N, N]\n", " }\n", "\n", "\n", "# ============================================================================\n", "# LOSS\n", "# ============================================================================\n", "\n", "def geometric_loss(logits, labels, info, ce_weight=1.0, validity_weight=0.1):\n", " ce = F.cross_entropy(logits, labels)\n", " vol2 = info['vol2']\n", " validity = F.relu(-vol2).mean()\n", "\n", " total = ce_weight * ce + validity_weight * validity\n", "\n", " return total, ce, validity\n", "\n", "\n", "# ============================================================================\n", "# METRICS\n", "# ============================================================================\n", "\n", "@torch.no_grad()\n", "def compute_metrics(info, depth):\n", " vol2 = info['vol2'] # [B, P_h, P_w, K]\n", " d2_mean = info['d2_mean'] # [B, P_h, P_w, K]\n", "\n", " metrics = {}\n", "\n", " # Per-k validity and volume\n", " for k in range(depth):\n", " v = vol2[..., k]\n", " d = d2_mean[..., k]\n", " metrics[f'k{k+1}_valid'] = (v > 0).float().mean().item()\n", " metrics[f'k{k+1}_vol2'] = v.mean().item()\n", " metrics[f'k{k+1}_d2'] = d.mean().item()\n", "\n", " # Overall\n", " metrics['valid_rate'] = (vol2 > 0).float().mean().item()\n", " metrics['vol2_mean'] = vol2.mean().item()\n", " metrics['vol2_std'] = vol2.std().item()\n", "\n", " # K-attention entropy (from last block)\n", " if info['k_attns']:\n", " k_attn = info['k_attns'][-1] # [B, P_h, P_w, H, K, K]\n", " k_attn_flat = k_attn.flatten(0, 3) # [B*P*H, K, K]\n", " entropy = -(k_attn_flat * (k_attn_flat + 1e-8).log()).sum(dim=-1).mean()\n", " metrics['k_attn_entropy'] = entropy.item()\n", "\n", " # Spatial attention entropy (from last block)\n", " if info['s_attns']:\n", " s_attn = info['s_attns'][-1] # [B, H, N, N]\n", " s_attn_flat = s_attn.flatten(0, 1) # [B*H, N, N]\n", " entropy = -(s_attn_flat * (s_attn_flat + 1e-8).log()).sum(dim=-1).mean()\n", " metrics['s_attn_entropy'] = entropy.item()\n", "\n", " return metrics\n", "\n", "\n", "# ============================================================================\n", "# TRAIN\n", "# ============================================================================\n", "\n", "def train():\n", " # CIFAR-10 transforms\n", " train_transform = transforms.Compose([\n", " transforms.RandomCrop(32, padding=4),\n", " transforms.RandomHorizontalFlip(),\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))\n", " ])\n", " test_transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))\n", " ])\n", "\n", " train_ds = datasets.CIFAR10('./data', train=True, download=True, transform=train_transform)\n", " test_ds = datasets.CIFAR10('./data', train=False, download=True, transform=test_transform)\n", " train_dl = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=6, pin_memory=True, persistent_workers=True)\n", " test_dl = DataLoader(test_ds, batch_size=512, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True)\n", "\n", " print(\"\\nBuilding model...\")\n", " model = GeometricPatchworkViT(\n", " img_size=32,\n", " patch_size=8,\n", " in_chans=3,\n", " num_classes=10,\n", " depth=4,\n", " edim=16,\n", " hidden=64,\n", " num_heads=8,\n", " num_blocks=6,\n", " ).to(device)\n", "\n", " # Print config\n", " print(f\"\\nConfig:\")\n", " for k, v in model._config.items():\n", " print(f\" {k}: {v}\")\n", "\n", " params = sum(p.numel() for p in model.parameters())\n", " print(f\" params: {params:,}\")\n", "\n", " # Compile with reduce-overhead\n", " print(\"\\nCompiling model...\")\n", " model = torch.compile(model, mode=\"reduce-overhead\")\n", "\n", " opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)\n", " sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=100)\n", "\n", " best = 0\n", " print(\"\\nTraining...\")\n", " print(\"=\" * 130)\n", "\n", " header = (\n", " f\"{'Ep':>3} | {'CE':>6} | {'Val':>6} | {'Tr':>6} | {'Te':>6} | {'Best':>6} | \"\n", " f\"{'k1':>5} | {'k2':>5} | {'k3':>5} | {'k4':>5} | \"\n", " f\"{'vol²':>8} | {'KAtt':>5} | {'SAtt':>5} | {'s/ep':>5}\"\n", " )\n", " print(header)\n", " print(\"-\" * 130)\n", "\n", " for ep in range(100):\n", " t0 = time.time()\n", "\n", " model.train()\n", " ce_sum, val_sum, cor, tot = 0, 0, 0, 0\n", "\n", " for img, lab in train_dl:\n", " img, lab = img.to(device, non_blocking=True), lab.to(device, non_blocking=True)\n", "\n", " opt.zero_grad()\n", " logits, info = model(img)\n", " loss, ce, validity = geometric_loss(logits, lab, info)\n", " loss.backward()\n", " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n", " opt.step()\n", "\n", " ce_sum += ce.item() * img.size(0)\n", " val_sum += validity.item() * img.size(0)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", "\n", " tr_acc = cor / tot\n", "\n", " model.eval()\n", " cor, tot = 0, 0\n", " all_metrics = []\n", " with torch.no_grad():\n", " for img, lab in test_dl:\n", " img, lab = img.to(device, non_blocking=True), lab.to(device, non_blocking=True)\n", " logits, info = model(img)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", " all_metrics.append(compute_metrics(info, model._config['depth']))\n", "\n", " te_acc = cor / tot\n", " sched.step()\n", " if te_acc > best:\n", " best = te_acc\n", "\n", " # Aggregate metrics\n", " m = {}\n", " for key in all_metrics[0]:\n", " m[key] = sum(d[key] for d in all_metrics) / len(all_metrics)\n", "\n", " elapsed = time.time() - t0\n", "\n", " if ep % 5 == 0 or ep == 99:\n", " print(\n", " f\"{ep+1:3d} | {ce_sum/tot:6.4f} | {val_sum/tot:6.4f} | {tr_acc:6.2%} | {te_acc:6.2%} | {best:6.2%} | \"\n", " f\"{m['k1_valid']:5.1%} | {m['k2_valid']:5.1%} | {m['k3_valid']:5.1%} | {m['k4_valid']:5.1%} | \"\n", " f\"{m['vol2_mean']:8.2e} | {m.get('k_attn_entropy', 0):5.2f} | {m.get('s_attn_entropy', 0):5.2f} | {elapsed:5.1f}\"\n", " )\n", "\n", " # Detailed per-k stats every 20 epochs\n", " if ep % 20 == 0 or ep == 99:\n", " print(f\" Per-k details:\")\n", " for k in range(model._config['depth']):\n", " print(f\" k={k+1}: valid={m[f'k{k+1}_valid']:.1%} vol²={m[f'k{k+1}_vol2']:.2e} d²={m[f'k{k+1}_d2']:.4f}\")\n", "\n", " print(\"=\" * 130)\n", " print(f\"Best: {best:.2%}\")\n", " return model\n", "\n", "\n", "if __name__ == \"__main__\":\n", " model = train()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "VVqUhAsd094b", "outputId": "d07620b9-a5b2-4f64-a538-cded1a6ae353" }, "execution_count": 8, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Device: cuda\n", "\n", "Building model...\n", "\n", "Config:\n", " img_size: 32\n", " patch_size: 8\n", " grid: 4x4\n", " depth: 4\n", " k_dims: [2, 4, 7, 11]\n", " feat_dim: 64\n", " num_blocks: 6\n", " params: 4,874,922\n", "\n", "Compiling model...\n", "\n", "Training...\n", "==================================================================================================================================\n", " Ep | CE | Val | Tr | Te | Best | k1 | k2 | k3 | k4 | vol² | KAtt | SAtt | s/ep\n", "----------------------------------------------------------------------------------------------------------------------------------\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:1744: FutureWarning: `torch._prims_common.check` is deprecated and will be removed in the future. Please use `torch._check*` functions instead.\n", " check(\n", "/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:1744: FutureWarning: `torch._prims_common.check` is deprecated and will be removed in the future. Please use `torch._check*` functions instead.\n", " check(\n", "/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:1744: FutureWarning: `torch._prims_common.check` is deprecated and will be removed in the future. Please use `torch._check*` functions instead.\n", " check(\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ " 1 | 10.5853 | 0.0000 | 19.10% | 25.67% | 25.67% | 100.0% | 100.0% | 100.0% | 100.0% | 3.66e-01 | 1.34 | 1.73 | 124.4\n", " Per-k details:\n", " k=1: valid=100.0% vol²=1.18e+00 d²=1.1759\n", " k=2: valid=100.0% vol²=2.69e-01 d²=1.1436\n", " k=3: valid=100.0% vol²=1.68e-02 d²=1.0621\n", " k=4: valid=100.0% vol²=5.69e-04 d²=1.0712\n", " 6 | 7.6766 | 0.0000 | 43.43% | 45.24% | 45.24% | 100.0% | 100.0% | 100.0% | 100.0% | 3.72e-01 | 1.32 | 0.35 | 5.1\n", " 11 | 6.9712 | 0.0000 | 49.01% | 50.56% | 50.70% | 100.0% | 100.0% | 100.0% | 100.0% | 3.71e-01 | 1.20 | 0.38 | 5.0\n", " 16 | 6.3937 | 0.0000 | 53.54% | 53.82% | 53.82% | 100.0% | 100.0% | 100.0% | 100.0% | 3.54e-01 | 1.04 | 0.46 | 5.0\n", " 21 | 5.8676 | 0.0000 | 57.77% | 58.21% | 58.21% | 100.0% | 100.0% | 100.0% | 100.0% | 3.49e-01 | 0.93 | 0.55 | 5.0\n", " Per-k details:\n", " k=1: valid=100.0% vol²=1.14e+00 d²=1.1359\n", " k=2: valid=100.0% vol²=2.44e-01 d²=1.1750\n", " k=3: valid=100.0% vol²=1.44e-02 d²=1.0294\n", " k=4: valid=100.0% vol²=6.06e-04 d²=1.1132\n", " 26 | 5.4143 | 0.0000 | 61.06% | 59.13% | 60.33% | 100.0% | 100.0% | 100.0% | 100.0% | 3.39e-01 | 0.88 | 0.72 | 5.1\n", " 31 | 4.9114 | 0.0000 | 64.64% | 63.27% | 63.27% | 100.0% | 100.0% | 100.0% | 100.0% | 3.36e-01 | 0.81 | 0.84 | 5.1\n", " 36 | 4.4394 | 0.0000 | 68.23% | 65.53% | 65.53% | 100.0% | 100.0% | 100.0% | 100.0% | 3.28e-01 | 0.81 | 1.02 | 5.2\n", " 41 | 3.9912 | 0.0000 | 71.22% | 66.72% | 66.72% | 100.0% | 100.0% | 100.0% | 100.0% | 3.27e-01 | 0.79 | 1.17 | 5.1\n", " Per-k details:\n", " k=1: valid=100.0% vol²=1.07e+00 d²=1.0727\n", " k=2: valid=100.0% vol²=2.20e-01 d²=1.1206\n", " k=3: valid=100.0% vol²=1.34e-02 d²=1.0162\n", " k=4: valid=100.0% vol²=5.99e-04 d²=1.1258\n", " 46 | 3.4967 | 0.0000 | 74.78% | 67.02% | 67.40% | 100.0% | 100.0% | 100.0% | 100.0% | 3.17e-01 | 0.82 | 1.30 | 5.0\n", " 51 | 3.0084 | 0.0000 | 78.43% | 67.87% | 68.08% | 100.0% | 100.0% | 100.0% | 100.0% | 3.20e-01 | 0.83 | 1.39 | 5.0\n" ] }, { "output_type": "error", "ename": "KeyboardInterrupt", "evalue": "", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[0;32m/tmp/ipython-input-1997069037.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 552\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 553\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"__main__\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 554\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m/tmp/ipython-input-1997069037.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m()\u001b[0m\n\u001b[1;32m 499\u001b[0m \u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minfo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 500\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mce\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalidity\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgeometric_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlogits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlab\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minfo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 501\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 502\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclip_grad_norm_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1.0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 503\u001b[0m \u001b[0mopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 623\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 624\u001b[0m )\n\u001b[0;32m--> 625\u001b[0;31m torch.autograd.backward(\n\u001b[0m\u001b[1;32m 626\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 627\u001b[0m )\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 352\u001b[0m \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 353\u001b[0m \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 354\u001b[0;31m _engine_run_backward(\n\u001b[0m\u001b[1;32m 355\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 356\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py\u001b[0m in \u001b[0;36m_engine_run_backward\u001b[0;34m(t_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 839\u001b[0m \u001b[0munregister_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_register_logging_hooks_on_whole_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt_outputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 840\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 841\u001b[0;31m return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m 842\u001b[0m \u001b[0mt_outputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 843\u001b[0m ) # Calls into the C++ engine to run the backward pass\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ] }, { "cell_type": "code", "source": [ "#@title Geometric Patchwork - Fixed Bottleneck\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets, transforms\n", "import math\n", "from itertools import combinations\n", "import time\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print(f\"Device: {device}\")\n", "\n", "from geovocab2.shapes.factory.simplex_factory import SimplexFactory\n", "\n", "# ============================================================================\n", "# CAYLEY-MENGER VALIDATOR\n", "# ============================================================================\n", "\n", "class CMValidator(nn.Module):\n", " def __init__(self, k):\n", " super().__init__()\n", " self._k = k\n", " self._nv = k + 1\n", "\n", " pairs = list(combinations(range(self._nv), 2))\n", " self._npairs = len(pairs)\n", " self.register_buffer('_pi', torch.tensor([p[0] for p in pairs], dtype=torch.long))\n", " self.register_buffer('_pj', torch.tensor([p[1] for p in pairs], dtype=torch.long))\n", "\n", " sign = (-1.0) ** (k + 1)\n", " fact = math.factorial(k)\n", " self._prefactor = sign / ((2.0 ** k) * (fact ** 2))\n", "\n", " def forward(self, verts):\n", " gram = torch.einsum('...ve,...we->...vw', verts, verts)\n", " norms = torch.diagonal(gram, dim1=-2, dim2=-1)\n", " d2_mat = norms.unsqueeze(-1) + norms.unsqueeze(-2) - 2 * gram\n", " d2_mat = F.relu(d2_mat)\n", "\n", " d2_pairs = d2_mat[..., self._pi, self._pj]\n", "\n", " shape = d2_mat.shape[:-2]\n", " V = d2_mat.shape[-1]\n", " cm = torch.zeros(*shape, V+1, V+1, device=d2_mat.device, dtype=d2_mat.dtype)\n", " cm[..., 0, 1:] = 1.0\n", " cm[..., 1:, 0] = 1.0\n", " cm[..., 1:, 1:] = d2_mat\n", "\n", " vol2 = self._prefactor * torch.linalg.det(cm)\n", "\n", " return d2_pairs, vol2\n", "\n", "\n", "# ============================================================================\n", "# K-SIMPLEX CHANNEL ENCODER - FIXED: KEEPS VERTEX FEATURES\n", "# ============================================================================\n", "\n", "class KSimplexChannel(nn.Module):\n", " BASE_DEFORM = 0.05\n", "\n", " def __init__(self, k, in_dim, edim, feat_dim):\n", " super().__init__()\n", " self._k = k\n", " self._nv = k + 1\n", " self._edim = edim\n", " self._feat_dim = feat_dim\n", "\n", " self._cm = CMValidator(k)\n", " self._geo_dim = self._cm._npairs + 1 # d² + vol²\n", "\n", " factory = SimplexFactory(k=k, embed_dim=edim, method=\"regular\", scale=1.0)\n", " self.register_buffer('_template', factory.build_torch(dtype=torch.float32))\n", "\n", " # Input → vertex coordinates (for CM)\n", " self._to_coords = nn.Linear(in_dim, self._nv * edim)\n", "\n", " # Input → vertex features (KEPT, not discarded)\n", " self._to_feats = nn.Linear(in_dim, self._nv * feat_dim)\n", "\n", " # Geometry modulates features\n", " self._geo_gate = nn.Sequential(\n", " nn.Linear(self._geo_dim, feat_dim),\n", " nn.Sigmoid(),\n", " )\n", "\n", " # Output: aggregate vertex features + geometry\n", " self._out_dim = feat_dim + self._geo_dim\n", "\n", " @property\n", " def out_dim(self):\n", " return self._out_dim\n", "\n", " def forward(self, x):\n", " \"\"\"\n", " x: [..., in_dim]\n", "\n", " Returns:\n", " out: [..., feat_dim + geo_dim] (vertex features + geometry)\n", " vol2: [...]\n", " \"\"\"\n", " # Vertex coordinates for CM\n", " coords = self._to_coords(x).unflatten(-1, (self._nv, self._edim))\n", " verts = self._template + self.BASE_DEFORM * coords\n", "\n", " # Vertex features (the actual learned representation)\n", " vert_feats = self._to_feats(x).unflatten(-1, (self._nv, self._feat_dim))\n", "\n", " # CM validation\n", " d2, vol2 = self._cm(verts)\n", " geo = torch.cat([d2, vol2.unsqueeze(-1)], dim=-1)\n", "\n", " # Gate features by geometry (valid simplices pass more info)\n", " gate = self._geo_gate(geo)\n", " validity = torch.sigmoid(vol2 * 1e6).unsqueeze(-1)\n", "\n", " # Aggregate vertex features (mean) with gating\n", " feat_agg = vert_feats.mean(dim=-2) * gate * validity\n", "\n", " # Output: features + geometry\n", " out = torch.cat([feat_agg, geo], dim=-1)\n", "\n", " return out, vol2, d2.mean(dim=-1)\n", "\n", "\n", "# ============================================================================\n", "# PATCH TO K-SIMPLEX CHANNELS - FIXED\n", "# ============================================================================\n", "\n", "class PatchToKChannels(nn.Module):\n", " def __init__(self, patch_dim, depth, edim, feat_dim, hidden=256):\n", " super().__init__()\n", " self._depth = depth\n", " self._edim = edim\n", " self._feat_dim = feat_dim\n", "\n", " self._proj = nn.Sequential(\n", " nn.Linear(patch_dim, hidden),\n", " nn.LayerNorm(hidden),\n", " nn.GELU(),\n", " nn.Linear(hidden, hidden),\n", " nn.LayerNorm(hidden),\n", " nn.GELU(),\n", " )\n", "\n", " self._k_encoders = nn.ModuleList([\n", " KSimplexChannel(k=k+1, in_dim=hidden, edim=edim, feat_dim=feat_dim)\n", " for k in range(depth)\n", " ])\n", "\n", " self._k_out_dims = [enc.out_dim for enc in self._k_encoders]\n", " self._max_out_dim = max(self._k_out_dims)\n", "\n", " def forward(self, patches):\n", " \"\"\"\n", " patches: [B, P_h, P_w, patch_dim]\n", "\n", " Returns:\n", " k_channels: [B, P_h, P_w, K, max_out_dim]\n", " vol2: [B, P_h, P_w, K]\n", " d2_mean: [B, P_h, P_w, K]\n", " \"\"\"\n", " h = self._proj(patches)\n", "\n", " out_list = []\n", " vol2_list = []\n", " d2_list = []\n", "\n", " for enc in self._k_encoders:\n", " out, vol2, d2_mean = enc(h)\n", "\n", " # Pad to max_out_dim\n", " pad_size = self._max_out_dim - out.shape[-1]\n", " if pad_size > 0:\n", " out = F.pad(out, (0, pad_size))\n", "\n", " out_list.append(out)\n", " vol2_list.append(vol2)\n", " d2_list.append(d2_mean)\n", "\n", " k_channels = torch.stack(out_list, dim=-2)\n", " vol2 = torch.stack(vol2_list, dim=-1)\n", " d2_mean = torch.stack(d2_list, dim=-1)\n", "\n", " return k_channels, vol2, d2_mean\n", "\n", "\n", "# ============================================================================\n", "# K-CHANNEL CROSS-ATTENTION\n", "# ============================================================================\n", "\n", "class KChannelCrossAttention(nn.Module):\n", " def __init__(self, depth, feat_dim, num_heads=4, dropout=0.1):\n", " super().__init__()\n", " self._depth = depth\n", " self._feat_dim = feat_dim\n", " self._num_heads = num_heads\n", " self._head_dim = feat_dim // num_heads\n", "\n", " self._norm_q = nn.LayerNorm(feat_dim)\n", " self._norm_kv = nn.LayerNorm(feat_dim)\n", "\n", " self._to_q = nn.Linear(feat_dim, feat_dim)\n", " self._to_k = nn.Linear(feat_dim, feat_dim)\n", " self._to_v = nn.Linear(feat_dim, feat_dim)\n", " self._out = nn.Linear(feat_dim, feat_dim)\n", " self._drop = nn.Dropout(dropout)\n", "\n", " self._scale = self._head_dim ** -0.5\n", "\n", " def forward(self, x):\n", " B, Ph, Pw, K, F = x.shape\n", "\n", " x_flat = x.view(B * Ph * Pw, K, F)\n", "\n", " q = self._to_q(self._norm_q(x_flat))\n", " k = self._to_k(self._norm_kv(x_flat))\n", " v = self._to_v(self._norm_kv(x_flat))\n", "\n", " q = q.view(-1, K, self._num_heads, self._head_dim).transpose(1, 2)\n", " k = k.view(-1, K, self._num_heads, self._head_dim).transpose(1, 2)\n", " v = v.view(-1, K, self._num_heads, self._head_dim).transpose(1, 2)\n", "\n", " attn = (q @ k.transpose(-2, -1)) * self._scale\n", " attn = attn.softmax(dim=-1)\n", " attn = self._drop(attn)\n", "\n", " out = (attn @ v).transpose(1, 2).reshape(B * Ph * Pw, K, F)\n", " out = self._out(out)\n", " out = self._drop(out)\n", "\n", " return x + out.view(B, Ph, Pw, K, F)\n", "\n", "\n", "# ============================================================================\n", "# SPATIAL ATTENTION\n", "# ============================================================================\n", "\n", "class SpatialAttention(nn.Module):\n", " def __init__(self, depth, feat_dim, num_heads=4, dropout=0.1):\n", " super().__init__()\n", " self._num_heads = num_heads\n", "\n", " total_dim = depth * feat_dim\n", " self._head_dim = total_dim // num_heads\n", "\n", " self._norm = nn.LayerNorm(total_dim)\n", " self._to_qkv = nn.Linear(total_dim, 3 * total_dim)\n", " self._out = nn.Linear(total_dim, total_dim)\n", " self._drop = nn.Dropout(dropout)\n", "\n", " self._scale = self._head_dim ** -0.5\n", "\n", " def forward(self, x):\n", " B, Ph, Pw, K, F = x.shape\n", " N = Ph * Pw\n", "\n", " x_flat = x.view(B, N, K * F)\n", " x_norm = self._norm(x_flat)\n", "\n", " qkv = self._to_qkv(x_norm).chunk(3, dim=-1)\n", " q, k, v = [t.view(B, N, self._num_heads, self._head_dim).transpose(1, 2) for t in qkv]\n", "\n", " attn = (q @ k.transpose(-2, -1)) * self._scale\n", " attn = attn.softmax(dim=-1)\n", " attn = self._drop(attn)\n", "\n", " out = (attn @ v).transpose(1, 2).reshape(B, N, K * F)\n", " out = self._out(out)\n", " out = self._drop(out)\n", "\n", " return x + out.view(B, Ph, Pw, K, F)\n", "\n", "\n", "# ============================================================================\n", "# TRANSFORMER BLOCK\n", "# ============================================================================\n", "\n", "class GeoBlock(nn.Module):\n", " def __init__(self, depth, feat_dim, num_heads, mlp_ratio=4.0, dropout=0.1):\n", " super().__init__()\n", "\n", " self._k_attn = KChannelCrossAttention(depth, feat_dim, num_heads, dropout)\n", " self._spatial_attn = SpatialAttention(depth, feat_dim, num_heads, dropout)\n", "\n", " total_dim = depth * feat_dim\n", " self._norm = nn.LayerNorm(total_dim)\n", " self._mlp = nn.Sequential(\n", " nn.Linear(total_dim, int(total_dim * mlp_ratio)),\n", " nn.GELU(),\n", " nn.Dropout(dropout),\n", " nn.Linear(int(total_dim * mlp_ratio), total_dim),\n", " nn.Dropout(dropout),\n", " )\n", "\n", " def forward(self, x):\n", " B, Ph, Pw, K, F = x.shape\n", "\n", " x = self._k_attn(x)\n", " x = self._spatial_attn(x)\n", "\n", " x_flat = x.view(B, Ph, Pw, K * F)\n", " x_flat = x_flat + self._mlp(self._norm(x_flat))\n", " x = x_flat.view(B, Ph, Pw, K, F)\n", "\n", " return x\n", "\n", "\n", "# ============================================================================\n", "# FULL MODEL\n", "# ============================================================================\n", "\n", "class GeometricPatchworkViT(nn.Module):\n", " def __init__(\n", " self,\n", " img_size=32,\n", " patch_size=4, # Smaller patches = more spatial resolution\n", " in_chans=3,\n", " num_classes=10,\n", " depth=4,\n", " edim=16, # Larger embedding dim\n", " feat_dim=64, # Feature dim per k-level\n", " hidden=256, # Larger hidden\n", " num_heads=8,\n", " num_blocks=8,\n", " dropout=0.1,\n", " ):\n", " super().__init__()\n", "\n", " self._patch_size = patch_size\n", " self._ph = img_size // patch_size\n", " self._pw = img_size // patch_size\n", " self._patch_dim = patch_size * patch_size * in_chans\n", " self._depth = depth\n", " self._feat_dim = feat_dim\n", "\n", " # Patch encoder\n", " self._patch_enc = PatchToKChannels(self._patch_dim, depth, edim, feat_dim, hidden)\n", " self._max_out_dim = self._patch_enc._max_out_dim\n", "\n", " # Project to uniform feat_dim for attention\n", " self._proj = nn.Linear(self._max_out_dim, feat_dim)\n", "\n", " # Position embedding\n", " self._pos = nn.Parameter(torch.randn(1, self._ph, self._pw, depth, feat_dim) * 0.02)\n", "\n", " # Blocks\n", " self._blocks = nn.ModuleList([\n", " GeoBlock(depth, feat_dim, num_heads, dropout=dropout)\n", " for _ in range(num_blocks)\n", " ])\n", "\n", " # Classification\n", " total_dim = depth * feat_dim\n", " self._norm = nn.LayerNorm(total_dim)\n", " self._head = nn.Sequential(\n", " nn.Linear(total_dim, total_dim),\n", " nn.GELU(),\n", " nn.Dropout(dropout),\n", " nn.Linear(total_dim, num_classes),\n", " )\n", "\n", " self._config = {\n", " 'img_size': img_size,\n", " 'patch_size': patch_size,\n", " 'grid': f'{self._ph}x{self._pw}',\n", " 'depth': depth,\n", " 'edim': edim,\n", " 'feat_dim': feat_dim,\n", " 'k_out_dims': self._patch_enc._k_out_dims,\n", " 'num_blocks': num_blocks,\n", " 'dropout': dropout,\n", " }\n", "\n", " def forward(self, x):\n", " B = x.shape[0]\n", "\n", " # Extract patches\n", " patches = x.unfold(2, self._patch_size, self._patch_size) \\\n", " .unfold(3, self._patch_size, self._patch_size)\n", " patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()\n", " patches = patches.view(B, self._ph, self._pw, -1)\n", "\n", " # Encode\n", " k_channels, vol2, d2_mean = self._patch_enc(patches)\n", "\n", " # Project to uniform dim\n", " k_channels = self._proj(k_channels)\n", " k_channels = k_channels + self._pos\n", "\n", " # Blocks\n", " for blk in self._blocks:\n", " k_channels = blk(k_channels)\n", "\n", " # Classify\n", " pooled = k_channels.mean(dim=[1, 2]).flatten(1)\n", " logits = self._head(self._norm(pooled))\n", "\n", " return logits, {'vol2': vol2, 'd2_mean': d2_mean}\n", "\n", "\n", "# ============================================================================\n", "# LOSS & METRICS\n", "# ============================================================================\n", "\n", "def geometric_loss(logits, labels, info, ce_weight=1.0, validity_weight=0.1):\n", " ce = F.cross_entropy(logits, labels)\n", " vol2 = info['vol2']\n", " validity = F.relu(-vol2).mean()\n", " total = ce_weight * ce + validity_weight * validity\n", " return total, ce, validity\n", "\n", "\n", "@torch.no_grad()\n", "def compute_metrics(info, depth):\n", " vol2 = info['vol2']\n", " d2_mean = info['d2_mean']\n", "\n", " m = {'valid_rate': (vol2 > 0).float().mean().item()}\n", " for k in range(depth):\n", " m[f'k{k+1}_valid'] = (vol2[..., k] > 0).float().mean().item()\n", " m[f'k{k+1}_vol2'] = vol2[..., k].mean().item()\n", " m[f'k{k+1}_d2'] = d2_mean[..., k].mean().item()\n", " return m\n", "\n", "\n", "# ============================================================================\n", "# TRAIN\n", "# ============================================================================\n", "\n", "def train():\n", " train_transform = transforms.Compose([\n", " transforms.RandomCrop(32, padding=4),\n", " transforms.RandomHorizontalFlip(),\n", " transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))\n", " ])\n", " test_transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))\n", " ])\n", "\n", " train_ds = datasets.CIFAR10('./data', train=True, download=True, transform=train_transform)\n", " test_ds = datasets.CIFAR10('./data', train=False, download=True, transform=test_transform)\n", " train_dl = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=6, pin_memory=True)\n", " test_dl = DataLoader(test_ds, batch_size=512, shuffle=False, num_workers=4, pin_memory=True)\n", "\n", " print(\"\\nBuilding model...\")\n", " model = GeometricPatchworkViT(\n", " img_size=32,\n", " patch_size=8, # 8x8 grid\n", " in_chans=3,\n", " num_classes=10,\n", " depth=4,\n", " edim=32,\n", " feat_dim=128,\n", " hidden=512,\n", " num_heads=4,\n", " num_blocks=8,\n", " dropout=0.1,\n", " ).to(device)\n", "\n", " print(f\"\\nConfig:\")\n", " for k, v in model._config.items():\n", " print(f\" {k}: {v}\")\n", "\n", " params = sum(p.numel() for p in model.parameters())\n", " print(f\" params: {params:,}\")\n", "\n", " print(\"\\nCompiling...\")\n", " model = torch.compile(model, mode=\"reduce-overhead\")\n", "\n", " opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)\n", " sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=100)\n", "\n", " best = 0\n", " print(\"\\nTraining...\")\n", " print(\"=\" * 120)\n", " print(f\"{'Ep':>3} | {'CE':>6} | {'Val':>6} | {'Tr':>6} | {'Te':>6} | {'Best':>6} | \"\n", " f\"{'k1':>5} | {'k2':>5} | {'k3':>5} | {'k4':>5} | {'s/ep':>5}\")\n", " print(\"-\" * 120)\n", "\n", " for ep in range(100):\n", " t0 = time.time()\n", "\n", " model.train()\n", " ce_sum, val_sum, cor, tot = 0, 0, 0, 0\n", "\n", " for img, lab in train_dl:\n", " img, lab = img.to(device, non_blocking=True), lab.to(device, non_blocking=True)\n", "\n", " opt.zero_grad()\n", " logits, info = model(img)\n", " loss, ce, val = geometric_loss(logits, lab, info)\n", " loss.backward()\n", " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n", " opt.step()\n", "\n", " ce_sum += ce.item() * img.size(0)\n", " val_sum += val.item() * img.size(0)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", "\n", " tr_acc = cor / tot\n", "\n", " model.eval()\n", " cor, tot = 0, 0\n", " metrics_agg = []\n", " with torch.no_grad():\n", " for img, lab in test_dl:\n", " img, lab = img.to(device, non_blocking=True), lab.to(device, non_blocking=True)\n", " logits, info = model(img)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", " metrics_agg.append(compute_metrics(info, model._config['depth']))\n", "\n", " te_acc = cor / tot\n", " sched.step()\n", " if te_acc > best:\n", " best = te_acc\n", "\n", " m = {k: sum(d[k] for d in metrics_agg) / len(metrics_agg) for k in metrics_agg[0]}\n", " elapsed = time.time() - t0\n", "\n", " if ep % 5 == 0 or ep == 99:\n", " print(f\"{ep+1:3d} | {ce_sum/tot:6.4f} | {val_sum/tot:6.4f} | {tr_acc:6.2%} | {te_acc:6.2%} | {best:6.2%} | \"\n", " f\"{m['k1_valid']:5.1%} | {m['k2_valid']:5.1%} | {m['k3_valid']:5.1%} | {m['k4_valid']:5.1%} | {elapsed:5.1f}\")\n", "\n", " if ep % 20 == 0 or ep == 99:\n", " print(f\" k-details: \" + \" | \".join(\n", " f\"k{k+1}: vol²={m[f'k{k+1}_vol2']:.2e} d²={m[f'k{k+1}_d2']:.3f}\"\n", " for k in range(model._config['depth'])\n", " ))\n", "\n", " print(\"=\" * 120)\n", " print(f\"Best: {best:.2%}\")\n", " return model\n", "\n", "\n", "if __name__ == \"__main__\":\n", " model = train()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "cellView": "form", "id": "bSdSvsYBCvvh", "outputId": "7036a621-52ba-4a1b-df76-ee9e6954d26d" }, "execution_count": 9, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Device: cuda\n", "\n", "Building model...\n", "\n", "Config:\n", " img_size: 32\n", " patch_size: 4\n", " grid: 8x8\n", " depth: 4\n", " edim: 16\n", " feat_dim: 64\n", " k_out_dims: [66, 68, 71, 75]\n", " num_blocks: 8\n", " dropout: 0.1\n", " params: 6,912,362\n", "\n", "Compiling...\n", "\n", "Training...\n", "========================================================================================================================\n", " Ep | CE | Val | Tr | Te | Best | k1 | k2 | k3 | k4 | s/ep\n", "------------------------------------------------------------------------------------------------------------------------\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:1744: FutureWarning: `torch._prims_common.check` is deprecated and will be removed in the future. Please use `torch._check*` functions instead.\n", " check(\n", "/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:1744: FutureWarning: `torch._prims_common.check` is deprecated and will be removed in the future. Please use `torch._check*` functions instead.\n", " check(\n", "/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:1744: FutureWarning: `torch._prims_common.check` is deprecated and will be removed in the future. Please use `torch._check*` functions instead.\n", " check(\n", "/usr/local/lib/python3.12/dist-packages/torch/_inductor/lowering.py:1744: FutureWarning: `torch._prims_common.check` is deprecated and will be removed in the future. Please use `torch._check*` functions instead.\n", " check(\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ " 1 | 10.5158 | 0.0000 | 19.14% | 26.45% | 26.45% | 100.0% | 100.0% | 100.0% | 100.0% | 201.8\n", " k-details: k1: vol²=9.10e-01 d²=0.910 | k2: vol²=1.63e-01 d²=0.932 | k3: vol²=1.17e-02 d²=0.941 | k4: vol²=4.67e-04 d²=0.955\n", " 6 | 7.9527 | 0.0000 | 41.10% | 51.10% | 51.10% | 100.0% | 100.0% | 100.0% | 100.0% | 20.8\n", " 11 | 7.0038 | 0.0000 | 49.44% | 58.15% | 58.15% | 100.0% | 100.0% | 100.0% | 100.0% | 20.8\n", " 16 | 6.3388 | 0.0000 | 54.49% | 62.76% | 62.76% | 100.0% | 100.0% | 100.0% | 100.0% | 20.8\n", " 21 | 5.9750 | 0.0000 | 57.18% | 65.30% | 66.68% | 100.0% | 100.0% | 100.0% | 100.0% | 20.8\n", " k-details: k1: vol²=9.02e-01 d²=0.902 | k2: vol²=1.03e-01 d²=0.750 | k3: vol²=4.54e-03 d²=0.680 | k4: vol²=1.28e-04 d²=0.691\n", " 26 | 5.6326 | 0.0000 | 59.76% | 69.83% | 69.83% | 100.0% | 100.0% | 100.0% | 100.0% | 20.8\n", " 31 | 5.3184 | 0.0000 | 62.05% | 70.82% | 71.15% | 100.0% | 100.0% | 100.0% | 100.0% | 20.8\n", " 36 | 5.0402 | 0.0000 | 63.99% | 71.96% | 72.37% | 100.0% | 100.0% | 100.0% | 100.0% | 20.8\n", " 41 | 4.7983 | 0.0000 | 65.98% | 74.46% | 74.46% | 100.0% | 100.0% | 100.0% | 100.0% | 20.8\n", " k-details: k1: vol²=1.03e+00 d²=1.028 | k2: vol²=1.17e-01 d²=0.804 | k3: vol²=4.87e-03 d²=0.689 | k4: vol²=1.42e-04 d²=0.712\n", " 46 | 4.5811 | 0.0000 | 67.53% | 75.72% | 75.72% | 100.0% | 100.0% | 100.0% | 100.0% | 20.8\n", " 51 | 4.3392 | 0.0000 | 69.25% | 76.94% | 77.45% | 100.0% | 100.0% | 100.0% | 100.0% | 20.8\n", " 56 | 4.0985 | 0.0000 | 70.81% | 78.93% | 78.93% | 100.0% | 100.0% | 100.0% | 100.0% | 20.8\n", " 61 | 3.8610 | 0.0000 | 72.54% | 78.36% | 79.45% | 100.0% | 100.0% | 100.0% | 100.0% | 20.8\n", " k-details: k1: vol²=1.01e+00 d²=1.008 | k2: vol²=1.20e-01 d²=0.812 | k3: vol²=5.66e-03 d²=0.712 | k4: vol²=1.35e-04 d²=0.704\n", " 66 | 3.7021 | 0.0000 | 73.70% | 80.21% | 80.46% | 100.0% | 100.0% | 100.0% | 100.0% | 20.8\n", " 71 | 3.5193 | 0.0000 | 74.97% | 81.27% | 81.27% | 100.0% | 100.0% | 100.0% | 100.0% | 20.8\n", " 76 | 3.3362 | 0.0000 | 76.15% | 81.62% | 81.84% | 100.0% | 100.0% | 100.0% | 100.0% | 20.8\n", " 81 | 3.1914 | 0.0000 | 77.41% | 82.43% | 82.43% | 100.0% | 100.0% | 100.0% | 100.0% | 20.8\n", " k-details: k1: vol²=9.96e-01 d²=0.996 | k2: vol²=1.19e-01 d²=0.810 | k3: vol²=6.31e-03 d²=0.738 | k4: vol²=1.52e-04 d²=0.727\n", " 86 | 3.0744 | 0.0000 | 78.21% | 83.17% | 83.17% | 100.0% | 100.0% | 100.0% | 100.0% | 20.8\n", " 91 | 2.9843 | 0.0000 | 78.73% | 83.02% | 83.17% | 100.0% | 100.0% | 100.0% | 100.0% | 20.9\n", " 96 | 2.9350 | 0.0000 | 79.13% | 83.19% | 83.19% | 100.0% | 100.0% | 100.0% | 100.0% | 20.9\n", "100 | 2.9431 | 0.0000 | 79.19% | 83.11% | 83.19% | 100.0% | 100.0% | 100.0% | 100.0% | 20.9\n", " k-details: k1: vol²=9.94e-01 d²=0.994 | k2: vol²=1.18e-01 d²=0.805 | k3: vol²=6.45e-03 d²=0.743 | k4: vol²=1.56e-04 d²=0.735\n", "========================================================================================================================\n", "Best: 83.19%\n" ] } ] }, { "cell_type": "code", "source": [ "#@title Geometric Patchwork - Fixed Bottleneck\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets, transforms\n", "import math\n", "from itertools import combinations\n", "import time\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "print(f\"Device: {device}\")\n", "\n", "from geovocab2.shapes.factory.simplex_factory import SimplexFactory\n", "\n", "# ============================================================================\n", "# CAYLEY-MENGER VALIDATOR\n", "# ============================================================================\n", "\n", "class CMValidator(nn.Module):\n", " def __init__(self, k):\n", " super().__init__()\n", " self._k = k\n", " self._nv = k + 1\n", "\n", " pairs = list(combinations(range(self._nv), 2))\n", " self._npairs = len(pairs)\n", " self.register_buffer('_pi', torch.tensor([p[0] for p in pairs], dtype=torch.long))\n", " self.register_buffer('_pj', torch.tensor([p[1] for p in pairs], dtype=torch.long))\n", "\n", " sign = (-1.0) ** (k + 1)\n", " fact = math.factorial(k)\n", " self._prefactor = sign / ((2.0 ** k) * (fact ** 2))\n", "\n", " def forward(self, verts):\n", " gram = torch.einsum('...ve,...we->...vw', verts, verts)\n", " norms = torch.diagonal(gram, dim1=-2, dim2=-1)\n", " d2_mat = norms.unsqueeze(-1) + norms.unsqueeze(-2) - 2 * gram\n", " d2_mat = F.relu(d2_mat)\n", "\n", " d2_pairs = d2_mat[..., self._pi, self._pj]\n", "\n", " shape = d2_mat.shape[:-2]\n", " V = d2_mat.shape[-1]\n", " cm = torch.zeros(*shape, V+1, V+1, device=d2_mat.device, dtype=d2_mat.dtype)\n", " cm[..., 0, 1:] = 1.0\n", " cm[..., 1:, 0] = 1.0\n", " cm[..., 1:, 1:] = d2_mat\n", "\n", " vol2 = self._prefactor * torch.linalg.det(cm)\n", "\n", " return d2_pairs, vol2\n", "\n", "\n", "# ============================================================================\n", "# K-SIMPLEX CHANNEL ENCODER - FIXED: KEEPS VERTEX FEATURES\n", "# ============================================================================\n", "\n", "class KSimplexChannel(nn.Module):\n", " BASE_DEFORM = 0.05\n", "\n", " def __init__(self, k, in_dim, edim, feat_dim):\n", " super().__init__()\n", " self._k = k\n", " self._nv = k + 1\n", " self._edim = edim\n", " self._feat_dim = feat_dim\n", "\n", " self._cm = CMValidator(k)\n", " self._geo_dim = self._cm._npairs + 1 # d² + vol²\n", "\n", " factory = SimplexFactory(k=k, embed_dim=edim, method=\"regular\", scale=1.0)\n", " self.register_buffer('_template', factory.build_torch(dtype=torch.float32))\n", "\n", " # Input → vertex coordinates (for CM)\n", " self._to_coords = nn.Linear(in_dim, self._nv * edim)\n", "\n", " # Input → vertex features (KEPT, not discarded)\n", " self._to_feats = nn.Linear(in_dim, self._nv * feat_dim)\n", "\n", " # Geometry modulates features\n", " self._geo_gate = nn.Sequential(\n", " nn.Linear(self._geo_dim, feat_dim),\n", " nn.Sigmoid(),\n", " )\n", "\n", " # Output: aggregate vertex features + geometry\n", " self._out_dim = feat_dim + self._geo_dim\n", "\n", " @property\n", " def out_dim(self):\n", " return self._out_dim\n", "\n", " def forward(self, x):\n", " \"\"\"\n", " x: [..., in_dim]\n", "\n", " Returns:\n", " out: [..., feat_dim + geo_dim] (vertex features + geometry)\n", " vol2: [...]\n", " \"\"\"\n", " # Vertex coordinates for CM\n", " coords = self._to_coords(x).unflatten(-1, (self._nv, self._edim))\n", " verts = self._template + self.BASE_DEFORM * coords\n", "\n", " # Vertex features (the actual learned representation)\n", " vert_feats = self._to_feats(x).unflatten(-1, (self._nv, self._feat_dim))\n", "\n", " # CM validation\n", " d2, vol2 = self._cm(verts)\n", " geo = torch.cat([d2, vol2.unsqueeze(-1)], dim=-1)\n", "\n", " # Gate features by geometry (valid simplices pass more info)\n", " gate = self._geo_gate(geo)\n", " validity = torch.sigmoid(vol2 * 1e6).unsqueeze(-1)\n", "\n", " # Aggregate vertex features (mean) with gating\n", " feat_agg = vert_feats.mean(dim=-2) * gate * validity\n", "\n", " # Output: features + geometry\n", " out = torch.cat([feat_agg, geo], dim=-1)\n", "\n", " return out, vol2, d2.mean(dim=-1)\n", "\n", "\n", "# ============================================================================\n", "# PATCH TO K-SIMPLEX CHANNELS - FIXED\n", "# ============================================================================\n", "\n", "class PatchToKChannels(nn.Module):\n", " def __init__(self, patch_dim, depth, edim, feat_dim, hidden=256):\n", " super().__init__()\n", " self._depth = depth\n", " self._edim = edim\n", " self._feat_dim = feat_dim\n", "\n", " self._proj = nn.Sequential(\n", " nn.Linear(patch_dim, hidden),\n", " nn.LayerNorm(hidden),\n", " nn.GELU(),\n", " nn.Linear(hidden, hidden),\n", " nn.LayerNorm(hidden),\n", " nn.GELU(),\n", " )\n", "\n", " self._k_encoders = nn.ModuleList([\n", " KSimplexChannel(k=k+1, in_dim=hidden, edim=edim, feat_dim=feat_dim)\n", " for k in range(depth)\n", " ])\n", "\n", " self._k_out_dims = [enc.out_dim for enc in self._k_encoders]\n", " self._max_out_dim = max(self._k_out_dims)\n", "\n", " def forward(self, patches):\n", " \"\"\"\n", " patches: [B, P_h, P_w, patch_dim]\n", "\n", " Returns:\n", " k_channels: [B, P_h, P_w, K, max_out_dim]\n", " vol2: [B, P_h, P_w, K]\n", " d2_mean: [B, P_h, P_w, K]\n", " \"\"\"\n", " h = self._proj(patches)\n", "\n", " out_list = []\n", " vol2_list = []\n", " d2_list = []\n", "\n", " for enc in self._k_encoders:\n", " out, vol2, d2_mean = enc(h)\n", "\n", " # Pad to max_out_dim\n", " pad_size = self._max_out_dim - out.shape[-1]\n", " if pad_size > 0:\n", " out = F.pad(out, (0, pad_size))\n", "\n", " out_list.append(out)\n", " vol2_list.append(vol2)\n", " d2_list.append(d2_mean)\n", "\n", " k_channels = torch.stack(out_list, dim=-2)\n", " vol2 = torch.stack(vol2_list, dim=-1)\n", " d2_mean = torch.stack(d2_list, dim=-1)\n", "\n", " return k_channels, vol2, d2_mean\n", "\n", "\n", "# ============================================================================\n", "# K-CHANNEL CROSS-ATTENTION\n", "# ============================================================================\n", "\n", "class KChannelCrossAttention(nn.Module):\n", " def __init__(self, depth, feat_dim, num_heads=4, dropout=0.1):\n", " super().__init__()\n", " self._depth = depth\n", " self._feat_dim = feat_dim\n", " self._num_heads = num_heads\n", " self._head_dim = feat_dim // num_heads\n", "\n", " self._norm_q = nn.LayerNorm(feat_dim)\n", " self._norm_kv = nn.LayerNorm(feat_dim)\n", "\n", " self._to_q = nn.Linear(feat_dim, feat_dim)\n", " self._to_k = nn.Linear(feat_dim, feat_dim)\n", " self._to_v = nn.Linear(feat_dim, feat_dim)\n", " self._out = nn.Linear(feat_dim, feat_dim)\n", " self._drop = nn.Dropout(dropout)\n", "\n", " self._scale = self._head_dim ** -0.5\n", "\n", " def forward(self, x):\n", " B, Ph, Pw, K, F = x.shape\n", "\n", " x_flat = x.view(B * Ph * Pw, K, F)\n", "\n", " q = self._to_q(self._norm_q(x_flat))\n", " k = self._to_k(self._norm_kv(x_flat))\n", " v = self._to_v(self._norm_kv(x_flat))\n", "\n", " q = q.view(-1, K, self._num_heads, self._head_dim).transpose(1, 2)\n", " k = k.view(-1, K, self._num_heads, self._head_dim).transpose(1, 2)\n", " v = v.view(-1, K, self._num_heads, self._head_dim).transpose(1, 2)\n", "\n", " attn = (q @ k.transpose(-2, -1)) * self._scale\n", " attn = attn.softmax(dim=-1)\n", " attn = self._drop(attn)\n", "\n", " out = (attn @ v).transpose(1, 2).reshape(B * Ph * Pw, K, F)\n", " out = self._out(out)\n", " out = self._drop(out)\n", "\n", " return x + out.view(B, Ph, Pw, K, F)\n", "\n", "\n", "# ============================================================================\n", "# SPATIAL ATTENTION\n", "# ============================================================================\n", "\n", "class SpatialAttention(nn.Module):\n", " def __init__(self, depth, feat_dim, num_heads=4, dropout=0.1):\n", " super().__init__()\n", " self._num_heads = num_heads\n", "\n", " total_dim = depth * feat_dim\n", " self._head_dim = total_dim // num_heads\n", "\n", " self._norm = nn.LayerNorm(total_dim)\n", " self._to_qkv = nn.Linear(total_dim, 3 * total_dim)\n", " self._out = nn.Linear(total_dim, total_dim)\n", " self._drop = nn.Dropout(dropout)\n", "\n", " self._scale = self._head_dim ** -0.5\n", "\n", " def forward(self, x):\n", " B, Ph, Pw, K, F = x.shape\n", " N = Ph * Pw\n", "\n", " x_flat = x.view(B, N, K * F)\n", " x_norm = self._norm(x_flat)\n", "\n", " qkv = self._to_qkv(x_norm).chunk(3, dim=-1)\n", " q, k, v = [t.view(B, N, self._num_heads, self._head_dim).transpose(1, 2) for t in qkv]\n", "\n", " attn = (q @ k.transpose(-2, -1)) * self._scale\n", " attn = attn.softmax(dim=-1)\n", " attn = self._drop(attn)\n", "\n", " out = (attn @ v).transpose(1, 2).reshape(B, N, K * F)\n", " out = self._out(out)\n", " out = self._drop(out)\n", "\n", " return x + out.view(B, Ph, Pw, K, F)\n", "\n", "\n", "# ============================================================================\n", "# TRANSFORMER BLOCK\n", "# ============================================================================\n", "\n", "class GeoBlock(nn.Module):\n", " def __init__(self, depth, feat_dim, num_heads, mlp_ratio=4.0, dropout=0.1):\n", " super().__init__()\n", "\n", " self._k_attn = KChannelCrossAttention(depth, feat_dim, num_heads, dropout)\n", " self._spatial_attn = SpatialAttention(depth, feat_dim, num_heads, dropout)\n", "\n", " total_dim = depth * feat_dim\n", " self._norm = nn.LayerNorm(total_dim)\n", " self._mlp = nn.Sequential(\n", " nn.Linear(total_dim, int(total_dim * mlp_ratio)),\n", " nn.GELU(),\n", " nn.Dropout(dropout),\n", " nn.Linear(int(total_dim * mlp_ratio), total_dim),\n", " nn.Dropout(dropout),\n", " )\n", "\n", " def forward(self, x):\n", " B, Ph, Pw, K, F = x.shape\n", "\n", " x = self._k_attn(x)\n", " x = self._spatial_attn(x)\n", "\n", " x_flat = x.view(B, Ph, Pw, K * F)\n", " x_flat = x_flat + self._mlp(self._norm(x_flat))\n", " x = x_flat.view(B, Ph, Pw, K, F)\n", "\n", " return x\n", "\n", "\n", "# ============================================================================\n", "# FULL MODEL\n", "# ============================================================================\n", "\n", "class GeometricPatchworkViT(nn.Module):\n", " def __init__(\n", " self,\n", " img_size=32,\n", " patch_size=4, # Smaller patches = more spatial resolution\n", " in_chans=3,\n", " num_classes=10,\n", " depth=4,\n", " edim=16, # Larger embedding dim\n", " feat_dim=64, # Feature dim per k-level\n", " hidden=256, # Larger hidden\n", " num_heads=8,\n", " num_blocks=8,\n", " dropout=0.1,\n", " ):\n", " super().__init__()\n", "\n", " self._patch_size = patch_size\n", " self._ph = img_size // patch_size\n", " self._pw = img_size // patch_size\n", " self._patch_dim = patch_size * patch_size * in_chans\n", " self._depth = depth\n", " self._feat_dim = feat_dim\n", "\n", " # Patch encoder\n", " self._patch_enc = PatchToKChannels(self._patch_dim, depth, edim, feat_dim, hidden)\n", " self._max_out_dim = self._patch_enc._max_out_dim\n", "\n", " # Project to uniform feat_dim for attention\n", " self._proj = nn.Linear(self._max_out_dim, feat_dim)\n", "\n", " # Position embedding\n", " self._pos = nn.Parameter(torch.randn(1, self._ph, self._pw, depth, feat_dim) * 0.02)\n", "\n", " # Blocks\n", " self._blocks = nn.ModuleList([\n", " GeoBlock(depth, feat_dim, num_heads, dropout=dropout)\n", " for _ in range(num_blocks)\n", " ])\n", "\n", " # Classification\n", " total_dim = depth * feat_dim\n", " self._norm = nn.LayerNorm(total_dim)\n", " self._head = nn.Sequential(\n", " nn.Linear(total_dim, total_dim),\n", " nn.GELU(),\n", " nn.Dropout(dropout),\n", " nn.Linear(total_dim, num_classes),\n", " )\n", "\n", " self._config = {\n", " 'img_size': img_size,\n", " 'patch_size': patch_size,\n", " 'grid': f'{self._ph}x{self._pw}',\n", " 'depth': depth,\n", " 'edim': edim,\n", " 'feat_dim': feat_dim,\n", " 'k_out_dims': self._patch_enc._k_out_dims,\n", " 'num_blocks': num_blocks,\n", " 'dropout': dropout,\n", " }\n", "\n", " def forward(self, x):\n", " B = x.shape[0]\n", "\n", " # Extract patches\n", " patches = x.unfold(2, self._patch_size, self._patch_size) \\\n", " .unfold(3, self._patch_size, self._patch_size)\n", " patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()\n", " patches = patches.view(B, self._ph, self._pw, -1)\n", "\n", " # Encode\n", " k_channels, vol2, d2_mean = self._patch_enc(patches)\n", "\n", " # Project to uniform dim\n", " k_channels = self._proj(k_channels)\n", " k_channels = k_channels + self._pos\n", "\n", " # Blocks\n", " for blk in self._blocks:\n", " k_channels = blk(k_channels)\n", "\n", " # Classify\n", " pooled = k_channels.mean(dim=[1, 2]).flatten(1)\n", " logits = self._head(self._norm(pooled))\n", "\n", " return logits, {'vol2': vol2, 'd2_mean': d2_mean}\n", "\n", "\n", "# ============================================================================\n", "# LOSS & METRICS\n", "# ============================================================================\n", "\n", "def geometric_loss(logits, labels, info, ce_weight=1.0, validity_weight=0.1):\n", " ce = F.cross_entropy(logits, labels)\n", " vol2 = info['vol2']\n", " validity = F.relu(-vol2).mean()\n", " total = ce_weight * ce + validity_weight * validity\n", " return total, ce, validity\n", "\n", "\n", "@torch.no_grad()\n", "def compute_metrics(info, depth):\n", " vol2 = info['vol2']\n", " d2_mean = info['d2_mean']\n", "\n", " m = {'valid_rate': (vol2 > 0).float().mean().item()}\n", " for k in range(depth):\n", " m[f'k{k+1}_valid'] = (vol2[..., k] > 0).float().mean().item()\n", " m[f'k{k+1}_vol2'] = vol2[..., k].mean().item()\n", " m[f'k{k+1}_d2'] = d2_mean[..., k].mean().item()\n", " return m\n", "\n", "\n", "# ============================================================================\n", "# TRAIN\n", "# ============================================================================\n", "\n", "def train():\n", " train_transform = transforms.Compose([\n", " transforms.RandomCrop(32, padding=4),\n", " transforms.RandomHorizontalFlip(),\n", " transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))\n", " ])\n", " test_transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))\n", " ])\n", "\n", " train_ds = datasets.CIFAR10('./data', train=True, download=True, transform=train_transform)\n", " test_ds = datasets.CIFAR10('./data', train=False, download=True, transform=test_transform)\n", " train_dl = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=6, pin_memory=True)\n", " test_dl = DataLoader(test_ds, batch_size=512, shuffle=False, num_workers=4, pin_memory=True)\n", "\n", " print(\"\\nBuilding model...\")\n", " model = GeometricPatchworkViT(\n", " img_size=32,\n", " patch_size=4, # 8x8 grid\n", " in_chans=3,\n", " num_classes=10,\n", " depth=4,\n", " edim=16,\n", " feat_dim=64,\n", " hidden=256,\n", " num_heads=8,\n", " num_blocks=8,\n", " dropout=0.1,\n", " ).to(device)\n", "\n", " print(f\"\\nConfig:\")\n", " for k, v in model._config.items():\n", " print(f\" {k}: {v}\")\n", "\n", " params = sum(p.numel() for p in model.parameters())\n", " print(f\" params: {params:,}\")\n", "\n", " print(\"\\nCompiling...\")\n", " model = torch.compile(model, mode=\"reduce-overhead\")\n", "\n", " opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)\n", " sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=100)\n", "\n", " best = 0\n", " print(\"\\nTraining...\")\n", " print(\"=\" * 120)\n", " print(f\"{'Ep':>3} | {'CE':>6} | {'Val':>6} | {'Tr':>6} | {'Te':>6} | {'Best':>6} | \"\n", " f\"{'k1':>5} | {'k2':>5} | {'k3':>5} | {'k4':>5} | {'s/ep':>5}\")\n", " print(\"-\" * 120)\n", "\n", " for ep in range(300):\n", " t0 = time.time()\n", "\n", " model.train()\n", " ce_sum, val_sum, cor, tot = 0, 0, 0, 0\n", "\n", " for img, lab in train_dl:\n", " img, lab = img.to(device, non_blocking=True), lab.to(device, non_blocking=True)\n", "\n", " opt.zero_grad()\n", " logits, info = model(img)\n", " loss, ce, val = geometric_loss(logits, lab, info)\n", " loss.backward()\n", " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n", " opt.step()\n", "\n", " ce_sum += ce.item() * img.size(0)\n", " val_sum += val.item() * img.size(0)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", "\n", " tr_acc = cor / tot\n", "\n", " model.eval()\n", " cor, tot = 0, 0\n", " metrics_agg = []\n", " with torch.no_grad():\n", " for img, lab in test_dl:\n", " img, lab = img.to(device, non_blocking=True), lab.to(device, non_blocking=True)\n", " logits, info = model(img)\n", " cor += (logits.argmax(1) == lab).sum().item()\n", " tot += img.size(0)\n", " metrics_agg.append(compute_metrics(info, model._config['depth']))\n", "\n", " te_acc = cor / tot\n", " sched.step()\n", " if te_acc > best:\n", " best = te_acc\n", "\n", " m = {k: sum(d[k] for d in metrics_agg) / len(metrics_agg) for k in metrics_agg[0]}\n", " elapsed = time.time() - t0\n", "\n", " if ep % 5 == 0 or ep == 99:\n", " print(f\"{ep+1:3d} | {ce_sum/tot:6.4f} | {val_sum/tot:6.4f} | {tr_acc:6.2%} | {te_acc:6.2%} | {best:6.2%} | \"\n", " f\"{m['k1_valid']:5.1%} | {m['k2_valid']:5.1%} | {m['k3_valid']:5.1%} | {m['k4_valid']:5.1%} | {elapsed:5.1f}\")\n", "\n", " if ep % 20 == 0 or ep == 99:\n", " print(f\" k-details: \" + \" | \".join(\n", " f\"k{k+1}: vol²={m[f'k{k+1}_vol2']:.2e} d²={m[f'k{k+1}_d2']:.3f}\"\n", " for k in range(model._config['depth'])\n", " ))\n", "\n", " print(\"=\" * 120)\n", " print(f\"Best: {best:.2%}\")\n", " return model\n", "\n", "\n", "if __name__ == \"__main__\":\n", " model = train()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Z1_7pl34MqvQ", "outputId": "04732ef2-22fb-4db5-d409-d6188d0a2ea4" }, "execution_count": 12, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Device: cuda\n", "\n", "Building model...\n", "\n", "Config:\n", " img_size: 32\n", " patch_size: 4\n", " grid: 8x8\n", " depth: 4\n", " edim: 16\n", " feat_dim: 64\n", " k_out_dims: [66, 68, 71, 75]\n", " num_blocks: 8\n", " dropout: 0.1\n", " params: 6,912,362\n", "\n", "Compiling...\n", "\n", "Training...\n", "========================================================================================================================\n", " Ep | CE | Val | Tr | Te | Best | k1 | k2 | k3 | k4 | s/ep\n", "------------------------------------------------------------------------------------------------------------------------\n", " 1 | 10.7341 | 0.0000 | 17.61% | 22.68% | 22.68% | 100.0% | 100.0% | 100.0% | 100.0% | 37.1\n", " k-details: k1: vol²=1.27e+00 d²=1.270 | k2: vol²=2.13e-01 d²=1.052 | k3: vol²=2.36e-02 d²=1.168 | k4: vol²=1.11e-03 d²=1.171\n", " 6 | 8.1430 | 0.0000 | 39.54% | 47.90% | 47.90% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", " 11 | 6.9631 | 0.0000 | 49.49% | 59.71% | 59.71% | 100.0% | 100.0% | 100.0% | 100.0% | 21.1\n", " 16 | 6.3790 | 0.0000 | 54.18% | 62.40% | 63.07% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", " 21 | 5.9825 | 0.0000 | 57.42% | 66.64% | 66.64% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", " k-details: k1: vol²=9.97e-01 d²=0.997 | k2: vol²=1.08e-01 d²=0.773 | k3: vol²=8.13e-03 d²=0.908 | k4: vol²=1.70e-04 d²=0.769\n", " 26 | 5.6247 | 0.0000 | 59.81% | 68.86% | 68.86% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", " 31 | 5.3297 | 0.0000 | 62.29% | 70.94% | 70.94% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", " 36 | 5.0943 | 0.0000 | 63.86% | 73.07% | 73.07% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", " 41 | 4.8257 | 0.0000 | 65.78% | 74.38% | 74.63% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", " k-details: k1: vol²=9.75e-01 d²=0.975 | k2: vol²=1.11e-01 d²=0.779 | k3: vol²=7.68e-03 d²=0.874 | k4: vol²=1.21e-04 d²=0.719\n", " 46 | 4.5609 | 0.0000 | 67.71% | 75.12% | 75.72% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", " 51 | 4.3283 | 0.0000 | 69.25% | 77.77% | 77.77% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", " 56 | 4.1123 | 0.0000 | 70.77% | 78.38% | 78.50% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", " 61 | 3.8848 | 0.0000 | 72.38% | 79.06% | 79.52% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", " k-details: k1: vol²=9.58e-01 d²=0.958 | k2: vol²=1.11e-01 d²=0.780 | k3: vol²=6.92e-03 d²=0.835 | k4: vol²=1.16e-04 d²=0.711\n", " 66 | 3.6766 | 0.0000 | 73.93% | 79.97% | 80.63% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", " 71 | 3.4712 | 0.0000 | 75.13% | 81.30% | 81.30% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", " 76 | 3.3240 | 0.0000 | 76.67% | 81.82% | 81.82% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", " 81 | 3.1819 | 0.0000 | 77.33% | 82.45% | 82.45% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", " k-details: k1: vol²=9.40e-01 d²=0.940 | k2: vol²=1.15e-01 d²=0.796 | k3: vol²=7.40e-03 d²=0.857 | k4: vol²=1.26e-04 d²=0.725\n", " 86 | 3.0697 | 0.0000 | 78.18% | 82.46% | 82.46% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", " 91 | 2.9876 | 0.0000 | 78.74% | 82.80% | 82.83% | 100.0% | 100.0% | 100.0% | 100.0% | 21.1\n", " 96 | 2.9530 | 0.0000 | 79.06% | 82.78% | 82.88% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", "100 | 2.9485 | 0.0000 | 79.08% | 82.74% | 82.88% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", " k-details: k1: vol²=9.23e-01 d²=0.923 | k2: vol²=1.17e-01 d²=0.802 | k3: vol²=7.46e-03 d²=0.859 | k4: vol²=1.30e-04 d²=0.731\n", "101 | 2.9604 | 0.0000 | 79.25% | 82.74% | 82.88% | 100.0% | 100.0% | 100.0% | 100.0% | 21.1\n", " k-details: k1: vol²=9.23e-01 d²=0.923 | k2: vol²=1.17e-01 d²=0.802 | k3: vol²=7.46e-03 d²=0.859 | k4: vol²=1.30e-04 d²=0.731\n", "106 | 2.9602 | 0.0000 | 79.07% | 82.76% | 82.88% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", "111 | 2.9279 | 0.0000 | 79.12% | 82.67% | 82.88% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", "116 | 3.0012 | 0.0000 | 78.77% | 82.71% | 82.88% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", "121 | 3.0037 | 0.0000 | 78.68% | 82.63% | 82.88% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", " k-details: k1: vol²=9.23e-01 d²=0.923 | k2: vol²=1.16e-01 d²=0.799 | k3: vol²=7.33e-03 d²=0.854 | k4: vol²=1.31e-04 d²=0.733\n", "126 | 3.0616 | 0.0000 | 78.21% | 82.52% | 82.93% | 100.0% | 100.0% | 100.0% | 100.0% | 21.1\n", "131 | 3.1285 | 0.0000 | 77.99% | 81.94% | 82.93% | 100.0% | 100.0% | 100.0% | 100.0% | 21.1\n", "136 | 3.1699 | 0.0000 | 77.45% | 81.73% | 82.93% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", "141 | 3.2558 | 0.0000 | 76.87% | 82.36% | 82.93% | 100.0% | 100.0% | 100.0% | 100.0% | 21.1\n", " k-details: k1: vol²=9.26e-01 d²=0.926 | k2: vol²=1.18e-01 d²=0.804 | k3: vol²=6.88e-03 d²=0.823 | k4: vol²=1.40e-04 d²=0.737\n", "146 | 3.2993 | 0.0000 | 76.66% | 81.21% | 82.93% | 100.0% | 100.0% | 100.0% | 100.0% | 21.1\n", "151 | 3.3385 | 0.0000 | 76.40% | 80.96% | 82.93% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", "156 | 3.3985 | 0.0000 | 76.07% | 81.65% | 82.93% | 100.0% | 100.0% | 100.0% | 100.0% | 21.0\n", "161 | 3.4146 | 0.0000 | 75.87% | 81.12% | 82.93% | 100.0% | 100.0% | 100.0% | 100.0% | 21.1\n", " k-details: k1: vol²=9.40e-01 d²=0.940 | k2: vol²=1.16e-01 d²=0.799 | k3: vol²=6.68e-03 d²=0.814 | k4: vol²=1.39e-04 d²=0.732\n", "166 | 3.4092 | 0.0000 | 75.85% | 81.13% | 82.93% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", "171 | 3.4158 | 0.0000 | 75.81% | 82.12% | 82.93% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", "176 | 3.3912 | 0.0000 | 76.09% | 80.91% | 82.93% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", "181 | 3.3550 | 0.0000 | 76.16% | 81.88% | 82.93% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", " k-details: k1: vol²=9.66e-01 d²=0.966 | k2: vol²=1.15e-01 d²=0.799 | k3: vol²=5.96e-03 d²=0.788 | k4: vol²=1.28e-04 d²=0.718\n", "186 | 3.2810 | 0.0000 | 76.82% | 81.05% | 82.93% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", "191 | 3.2636 | 0.0000 | 76.79% | 82.75% | 82.93% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", "196 | 3.0914 | 0.0000 | 78.19% | 82.86% | 83.14% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", "201 | 3.0576 | 0.0000 | 78.38% | 82.47% | 83.14% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", " k-details: k1: vol²=9.91e-01 d²=0.991 | k2: vol²=1.07e-01 d²=0.776 | k3: vol²=4.87e-03 d²=0.741 | k4: vol²=1.01e-04 d²=0.671\n", "206 | 2.9441 | 0.0000 | 79.48% | 82.73% | 83.36% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", "211 | 2.8108 | 0.0000 | 80.09% | 83.97% | 83.97% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", "216 | 2.7292 | 0.0000 | 80.61% | 84.44% | 84.44% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", "221 | 2.6265 | 0.0000 | 81.48% | 83.51% | 84.44% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", " k-details: k1: vol²=9.92e-01 d²=0.992 | k2: vol²=1.05e-01 d²=0.765 | k3: vol²=4.42e-03 d²=0.722 | k4: vol²=9.80e-05 d²=0.652\n", "226 | 2.4497 | 0.0000 | 82.78% | 84.68% | 84.83% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", "231 | 2.3336 | 0.0000 | 83.54% | 84.98% | 84.98% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", "236 | 2.1651 | 0.0000 | 84.70% | 85.17% | 85.19% | 100.0% | 100.0% | 100.0% | 100.0% | 21.3\n", "241 | 2.0805 | 0.0000 | 85.44% | 85.26% | 85.50% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", " k-details: k1: vol²=9.93e-01 d²=0.993 | k2: vol²=1.06e-01 d²=0.765 | k3: vol²=4.37e-03 d²=0.710 | k4: vol²=1.12e-04 d²=0.673\n", "246 | 1.8865 | 0.0000 | 86.71% | 85.79% | 86.17% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", "251 | 1.7596 | 0.0000 | 87.56% | 85.88% | 86.17% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", "256 | 1.6192 | 0.0000 | 88.61% | 86.30% | 86.70% | 100.0% | 100.0% | 100.0% | 100.0% | 21.3\n", "261 | 1.5038 | 0.0000 | 89.36% | 86.39% | 86.74% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", " k-details: k1: vol²=9.50e-01 d²=0.950 | k2: vol²=1.09e-01 d²=0.777 | k3: vol²=4.59e-03 d²=0.728 | k4: vol²=1.16e-04 d²=0.679\n", "266 | 1.3770 | 0.0000 | 90.35% | 86.66% | 86.76% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", "271 | 1.2501 | 0.0000 | 91.18% | 87.24% | 87.28% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", "276 | 1.1740 | 0.0000 | 91.95% | 87.18% | 87.33% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", "281 | 1.0817 | 0.0000 | 92.54% | 87.66% | 87.73% | 100.0% | 100.0% | 100.0% | 100.0% | 21.3\n", " k-details: k1: vol²=9.39e-01 d²=0.939 | k2: vol²=1.10e-01 d²=0.779 | k3: vol²=4.71e-03 d²=0.728 | k4: vol²=1.29e-04 d²=0.697\n", "286 | 1.0079 | 0.0000 | 92.83% | 87.71% | 87.80% | 100.0% | 100.0% | 100.0% | 100.0% | 21.3\n", "291 | 0.9460 | 0.0000 | 93.37% | 87.77% | 87.86% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", "296 | 0.9487 | 0.0000 | 93.33% | 87.82% | 87.91% | 100.0% | 100.0% | 100.0% | 100.0% | 21.2\n", "========================================================================================================================\n", "Best: 87.98%\n" ] } ] } ] }