from __future__ import annotations

from typing import Iterable, Sequence

import torch
from torch import nn
from torch.nn import functional as F


def squash(x: torch.Tensor, dim: int = -1, eps: float = 1e-7) -> torch.Tensor:
    """Apply the capsule "squash" nonlinearity along ``dim``.

    Each vector ``v`` along ``dim`` is rescaled to ``|v|^2 / (1 + |v|^2) * v / |v|``:
    its direction is kept and its length is mapped into ``[0, 1)``.

    Args:
        x: Input tensor.
        dim: Dimension holding the capsule vector components.
        eps: Added under the square root so an all-zero vector maps to zero
            instead of dividing by zero.

    Returns:
        Tensor with the same shape as ``x``.
    """
    squared_norm = (x * x).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1.0 + squared_norm)
    return scale * x / torch.sqrt(squared_norm + eps)


class ConvBNAct(nn.Module):
    """Convolution (no bias) followed by BatchNorm and SiLU.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        kernel_size: Square kernel size.
        stride: Convolution stride.
        padding: Zero padding; defaults to ``kernel_size // 2`` (keeps the spatial
            size for odd kernels at stride 1).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int | None = None,
    ) -> None:
        super().__init__()
        if padding is None:
            padding = kernel_size // 2
        # Bias is redundant because BatchNorm immediately re-centres the output.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.bn(self.conv(x)))


class PrimaryCaps2d(nn.Module):
    """Primary capsule layer for 2D feature maps.

    A Conv-BN-SiLU produces ``num_caps * dim_caps`` channels, which are regrouped into
    capsules and squashed along the capsule dimension.

    Input is ``[B, in_channels, H, W]``; output is ``[B, num_caps, dim_caps, H', W']``,
    where ``H'``/``W'`` follow from the convolution geometry.
    """

    def __init__(
        self,
        in_channels: int,
        num_caps: int,
        dim_caps: int,
        kernel_size: int = 1,
        stride: int = 1,
        padding: int | None = None,
    ) -> None:
        super().__init__()
        if padding is None:
            padding = kernel_size // 2
        out_channels = num_caps * dim_caps
        self.num_caps = num_caps
        self.dim_caps = dim_caps
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.SiLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.act(self.bn(self.conv(x)))
        bsz, _, h, w = x.shape
        x = x.view(bsz, self.num_caps, self.dim_caps, h, w)
        return squash(x, dim=2)


class RoutingCaps(nn.Module):
    """Dynamic routing-by-agreement between two capsule layers.

    Input is ``[B, num_in_caps, dim_in]``; output is ``[B, num_out_caps, dim_out]``.

    Args:
        num_in_caps: Number of input (child) capsules ``N``.
        dim_in: Dimension of each input capsule.
        num_out_caps: Number of output (parent) capsules ``M``.
        dim_out: Dimension of each output capsule.
        routing_iters: Number of routing iterations; must be at least 1.

    Raises:
        ValueError: If ``routing_iters < 1``.
    """

    def __init__(
        self,
        num_in_caps: int,
        dim_in: int,
        num_out_caps: int,
        dim_out: int,
        routing_iters: int = 3,
    ) -> None:
        super().__init__()
        if routing_iters < 1:
            # With zero iterations forward() would never produce an output.
            raise ValueError(f"routing_iters must be >= 1, got {routing_iters}")
        self.num_in_caps = num_in_caps
        self.dim_in = dim_in
        self.num_out_caps = num_out_caps
        self.dim_out = dim_out
        self.routing_iters = routing_iters
        # One [dim_out, dim_in] transform per (input capsule, output capsule) pair;
        # the leading singleton dimension broadcasts over the batch.
        weight = torch.randn(1, num_in_caps, num_out_caps, dim_out, dim_in) * 0.01
        self.W = nn.Parameter(weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route ``x`` (``[B, N, D]``) to ``[B, num_out_caps, dim_out]``.

        Raises:
            ValueError: If ``x`` is not 3-D or its capsule count/dimension does not
                match the layer (``matmul`` would otherwise broadcast silently).
        """
        if x.ndim != 3:
            raise ValueError(f"RoutingCaps expects [B, N, D], got {tuple(x.shape)}")
        if x.shape[1] != self.num_in_caps or x.shape[2] != self.dim_in:
            raise ValueError(
                f"RoutingCaps expects [B, {self.num_in_caps}, {self.dim_in}], "
                f"got {tuple(x.shape)}"
            )
        bsz = x.shape[0]
        # Prediction vectors: u_hat[b, i, j] = W[i, j] @ x[b, i]  -> [B, N, M, Dout].
        u_hat = torch.matmul(self.W, x[:, :, None, :, None]).squeeze(-1)
        logits = x.new_zeros(bsz, self.num_in_caps, self.num_out_caps)
        for idx in range(self.routing_iters):
            # Coupling coefficients: softmax over output capsules for each input capsule.
            coupling = F.softmax(logits, dim=-1)
            s = (coupling.unsqueeze(-1) * u_hat).sum(dim=1)  # [B, M, Dout]
            v = squash(s, dim=-1)
            if idx < self.routing_iters - 1:
                # Agreement (dot product) between predictions and current outputs.
                logits = logits + (u_hat * v.unsqueeze(1)).sum(dim=-1)
        return v


class DeformableCaps2d(nn.Module):
    """Deformable capsule layer with learned sampling offsets.

    For every pixel, ``num_samples`` (dx, dy) offsets are predicted from the input.
    The child capsule map is bilinearly sampled at each offset location, and the
    resulting ``num_samples * num_child_caps`` child capsules are routed into
    ``num_parent_caps`` parent capsules. The offset convolution is zero-initialised,
    so every sample initially reads the pixel itself.

    Input is ``[B, in_channels, H, W]``; output is ``[B, out_channels, H, W]``.

    Args:
        in_channels: Input channel count.
        num_child_caps: Child capsules per pixel.
        dim_child: Child capsule dimension.
        num_parent_caps: Parent capsules per pixel.
        dim_parent: Parent capsule dimension.
        num_samples: Number of learned sampling locations per pixel.
        routing_iters: Dynamic-routing iterations.
        offset_scale: Maximum offset magnitude per axis, in input pixels
            (offsets are ``tanh(...) * offset_scale``).
        out_channels: Output channels; if ``None``, ``num_parent_caps * dim_parent``
            is used. Otherwise a 1x1 ConvBNAct projects to this width.
    """

    def __init__(
        self,
        in_channels: int,
        num_child_caps: int = 8,
        dim_child: int = 8,
        num_parent_caps: int = 8,
        dim_parent: int = 8,
        num_samples: int = 4,
        routing_iters: int = 3,
        offset_scale: float = 1.0,
        out_channels: int | None = None,
    ) -> None:
        super().__init__()
        self.num_child_caps = num_child_caps
        self.dim_child = dim_child
        self.num_parent_caps = num_parent_caps
        self.dim_parent = dim_parent
        self.num_samples = num_samples
        self.routing_iters = routing_iters
        self.offset_scale = offset_scale

        self.primary = PrimaryCaps2d(
            in_channels, num_child_caps, dim_child, kernel_size=1, stride=1, padding=0
        )
        # Two channels (dx, dy) per sample; zero init => all samples start at offset 0.
        self.offset = nn.Conv2d(in_channels, 2 * num_samples, kernel_size=3, stride=1, padding=1)
        nn.init.zeros_(self.offset.weight)
        nn.init.zeros_(self.offset.bias)

        self.routing = RoutingCaps(
            num_in_caps=num_samples * num_child_caps,
            dim_in=dim_child,
            num_out_caps=num_parent_caps,
            dim_out=dim_parent,
            routing_iters=routing_iters,
        )

        caps_channels = num_parent_caps * dim_parent
        # Explicit None check: only an omitted width falls back to the capsule width.
        self.out_channels = caps_channels if out_channels is None else out_channels
        self.project = None
        if self.out_channels != caps_channels:
            self.project = ConvBNAct(
                caps_channels, self.out_channels, kernel_size=1, stride=1, padding=0
            )

    @staticmethod
    def _base_grid(h: int, w: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """Return an identity sampling grid ``[h, w, 2]`` with last dim ``(x, y)`` in [-1, 1]."""
        ys = torch.linspace(-1.0, 1.0, h, device=device, dtype=dtype)
        xs = torch.linspace(-1.0, 1.0, w, device=device, dtype=dtype)
        yy, xx = torch.meshgrid(ys, xs, indexing="ij")
        return torch.stack((xx, yy), dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        child = self.primary(x)  # [B, Nc, Dc, H, W]
        bsz, _, _, h, w = child.shape
        child_flat = child.view(bsz, self.num_child_caps * self.dim_child, h, w)

        # Pixel-space offsets, bounded to +/- offset_scale: [B, S, 2, H, W].
        offsets = self.offset(x).view(bsz, self.num_samples, 2, h, w)
        offsets = torch.tanh(offsets) * self.offset_scale
        # With align_corners=True one pixel spans 2 / (size - 1) in normalised
        # coordinates; max(..., 1) guards size-1 axes.
        scale_x = max(w - 1, 1) / 2.0
        scale_y = max(h - 1, 1) / 2.0
        scale = offsets.new_tensor([scale_x, scale_y]).view(1, 1, 2, 1, 1)
        offsets = offsets / scale

        base = self._base_grid(h, w, x.device, x.dtype).view(1, 1, h, w, 2)
        grids = base + offsets.permute(0, 1, 3, 4, 2)  # [B, S, H, W, 2]

        # Sample all S grids in a single call by stacking them along the output
        # height axis: grid [B, S*H, W, 2] -> features [B, Nc*Dc, S*H, W].
        sampled = F.grid_sample(
            child_flat,
            grids.reshape(bsz, self.num_samples * h, w, 2),
            mode="bilinear",
            padding_mode="zeros",
            align_corners=True,
        )
        # [B, Nc, Dc, S, H, W] -> [B, H, W, S, Nc, Dc] -> one routing problem per pixel.
        sampled = sampled.view(
            bsz, self.num_child_caps, self.dim_child, self.num_samples, h, w
        )
        sampled = sampled.permute(0, 4, 5, 3, 1, 2).reshape(
            bsz * h * w, self.num_samples * self.num_child_caps, self.dim_child
        )

        routed = self.routing(sampled)  # [B*H*W, Np, Dp]
        routed = routed.view(bsz, h, w, self.num_parent_caps, self.dim_parent)
        out = routed.permute(0, 3, 4, 1, 2).reshape(
            bsz, self.num_parent_caps * self.dim_parent, h, w
        )
        if self.project is not None:
            out = self.project(out)
        return out


class DeformableCapsBlock(nn.Module):
    """Backbone block: 3x3 ConvBNAct (optionally strided) + DeformableCaps2d.

    Input is ``[B, c1, H, W]``; output is ``[B, c2, H', W']``, where ``H'``/``W'`` are
    set by ``stride``.

    Args:
        c1: Input channels.
        c2: Output channels of both the downsampling conv and the capsule layer.
        offset_scale: Forwarded to ``DeformableCaps2d`` (max offset in pixels).
        Remaining arguments are forwarded to ``DeformableCaps2d``.
    """

    def __init__(
        self,
        c1: int,
        c2: int,
        num_child_caps: int = 8,
        dim_child: int = 8,
        num_parent_caps: int = 8,
        dim_parent: int = 8,
        num_samples: int = 4,
        routing_iters: int = 3,
        stride: int = 1,
        offset_scale: float = 1.0,
    ) -> None:
        super().__init__()
        self.down = ConvBNAct(c1, c2, kernel_size=3, stride=stride)
        self.caps = DeformableCaps2d(
            in_channels=c2,
            num_child_caps=num_child_caps,
            dim_child=dim_child,
            num_parent_caps=num_parent_caps,
            dim_parent=dim_parent,
            num_samples=num_samples,
            routing_iters=routing_iters,
            offset_scale=offset_scale,
            out_channels=c2,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.caps(self.down(x))


class CapsuleBackbone(nn.Module):
    """Capsule-based backbone that returns one feature map per stage.

    A stride-2 stem is followed by one stage per entry in ``stages``. Each stage is a
    stride-2 3x3 ConvBNAct followed by a ``DeformableCaps2d`` whose output width equals
    the stage width.

    Args:
        in_channels: Input image channels.
        stem_channels: Stem output channels.
        stages: Output channels for each stage.
        capsule_cfgs: Optional per-stage keyword arguments for ``DeformableCaps2d``.
            Must have one entry per stage and must not set ``out_channels``, which is
            fixed to the stage width.

    Raises:
        ValueError: If ``capsule_cfgs`` does not have one entry per stage.
    """

    def __init__(
        self,
        in_channels: int = 3,
        stem_channels: int = 64,
        stages: Sequence[int] = (128, 256, 512),
        capsule_cfgs: Iterable[dict] | None = None,
    ) -> None:
        super().__init__()
        self.stem = ConvBNAct(in_channels, stem_channels, kernel_size=3, stride=2)
        # A distinct dict per stage (rather than one shared object repeated).
        stage_cfgs = list(capsule_cfgs) if capsule_cfgs is not None else [{} for _ in stages]
        if len(stage_cfgs) != len(stages):
            raise ValueError("capsule_cfgs must match stages length")

        blocks = []
        in_ch = stem_channels
        for out_ch, cfg in zip(stages, stage_cfgs):
            # Kept as a plain Sequential so parameter names stay stages.<i>.0 / stages.<i>.1.
            blocks.append(
                nn.Sequential(
                    ConvBNAct(in_ch, out_ch, kernel_size=3, stride=2),
                    DeformableCaps2d(out_ch, out_channels=out_ch, **cfg),
                )
            )
            in_ch = out_ch
        self.stages = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Return the output of every stage, from highest to lowest resolution."""
        x = self.stem(x)
        outputs = []
        for stage in self.stages:
            x = stage(x)
            outputs.append(x)
        return tuple(outputs)


__all__ = [
    "CapsuleBackbone",
    "ConvBNAct",
    "DeformableCaps2d",
    "DeformableCapsBlock",
    "PrimaryCaps2d",
    "RoutingCaps",
    "squash",
]