from __future__ import annotations

import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

from ultralytics.nn.modules import Conv, DWConv, Detect, Segment
from ultralytics.nn.modules.block import Proto26


class PrimaryCaps(nn.Module):
    r"""Primary convolutional capsules.

    Outputs pose and activation, plus a concatenated NHWC capsule tensor.

    Args:
        A: Input feature channels.
        B: Number of capsule types.
        K: Convolution kernel size.
        P: Pose matrix side length (pose size is ``P*P``).
        stride: Convolution stride.

    Input shape:
        x: ``(N, A, H, W)``
    Output shape:
        a: ``(N, B, H_out, W_out)``
        p: ``(N, B*P*P, H_out, W_out)``
        out: ``(N, H_out, W_out, B*(P*P+1))``
    Parameter size:
        pose conv + act conv ``(K*K*A*B*P*P + B*P*P) + (K*K*A*B + B)``
    """

    def __init__(self, A: int = 32, B: int = 32, K: int = 1, P: int = 4, stride: int = 1):
        super().__init__()
        self.B = B
        self.P = P
        self.psize = P * P
        self.pose = nn.Conv2d(in_channels=A, out_channels=B * self.psize, kernel_size=K, stride=stride, bias=True)
        self.a = nn.Conv2d(in_channels=A, out_channels=B, kernel_size=K, stride=stride, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # p: (N, B*psize, H, W), a: (N, B, H, W)
        p = self.pose(x)
        a = self.sigmoid(self.a(x))
        out = torch.cat([p, a], dim=1).permute(0, 2, 3, 1).contiguous()  # (N, H, W, B*(psize+1))
        return a, p, out
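

# Minimal shape sketch for PrimaryCaps (illustrative only; never called at import
# time). Tensor sizes follow the docstring above; the batch/channel numbers are
# arbitrary assumptions, not values used anywhere in this module.
def _example_primary_caps() -> None:
    caps = PrimaryCaps(A=32, B=8, K=1, P=4, stride=1)
    x = torch.randn(2, 32, 20, 20)  # (N, A, H, W)
    a, p, out = caps(x)
    assert a.shape == (2, 8, 20, 20)         # (N, B, H, W) activations in (0, 1)
    assert p.shape == (2, 8 * 16, 20, 20)    # (N, B*P*P, H, W) poses
    assert out.shape == (2, 20, 20, 8 * 17)  # NHWC pack: B*(P*P+1) channels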


class ConvCaps(nn.Module):
    r"""Convolutional capsules with EM routing.

    Args:
        B: Input capsule types.
        C: Output capsule types.
        K: Patch kernel size.
        P: Pose matrix side length (pose size is ``P*P``).
        stride: Spatial stride for patch extraction.
        iters: Number of EM routing iterations.
        coor_add: Add coordinate offsets (class-caps style option).
        w_shared: Share transform matrices across spatial positions.

    Input shape:
        x: ``(N, H, W, B*(P*P+1))``
    Output shape:
        p_out: ``(N, H_out, W_out, C*P*P)``
        a_out: ``(N, H_out, W_out, C)``
        out: ``(N, H_out, W_out, C*(P*P+1))``
    Parameter size:
        If ``w_shared=False``: ``weights: (K*K*B*C*P*P*P*P)``, ``beta_u: C``, ``beta_a: C``
        If ``w_shared=True``: ``weights: (B*C*P*P*P*P)``, ``beta_u: C``, ``beta_a: C``
        Total = ``weights + 2*C`` (excluding non-trainable buffers).
    """

    def __init__(
        self,
        B: int = 32,
        C: int = 32,
        K: int = 3,
        P: int = 4,
        stride: int = 1,
        iters: int = 3,
        coor_add: bool = False,
        w_shared: bool = False,
    ):
        super().__init__()
        self.B = B
        self.C = C
        self.K = K
        self.P = P
        self.psize = P * P
        self.stride = stride
        self.iters = iters
        self.coor_add = coor_add
        self.w_shared = w_shared
        self.eps = 1e-6
        self._lambda = 1e-3
        self.register_buffer("ln_2pi", torch.tensor(math.log(2 * math.pi), dtype=torch.float32), persistent=False)
        # Matrix-caps paper uses per-capsule beta scalars.
        self.beta_u = nn.Parameter(torch.zeros(C))
        self.beta_a = nn.Parameter(torch.zeros(C))
        # For non-shared conv-caps, input vote count is K*K*B. For shared mode it is B, then repeated by HW.
        weight_in = B if w_shared else (K * K * B)
        self.weights = nn.Parameter(torch.randn(1, weight_in, C, self.psize, self.psize) * 0.02)
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=2)

    def m_step(
        self,
        a_in: torch.Tensor,
        r: torch.Tensor,
        v: torch.Tensor,
        eps: float,
        b: int,
        B: int,
        C: int,
        psize: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # a_in: (b, B, 1) or (b, B, 1, 1), r: (b, B, C, 1), v: (b, B, C, psize)
        if a_in.ndim == 3:
            a_in = a_in.unsqueeze(2)
        r = r * a_in
        r = r / (r.sum(dim=2, keepdim=True) + eps)
        r_sum = r.sum(dim=1, keepdim=True)
        coeff = r / (r_sum + eps)
        mu = torch.sum(coeff * v, dim=1, keepdim=True)  # (b, 1, C, psize)
        sigma_sq = torch.sum(coeff * (v - mu).pow(2), dim=1, keepdim=True) + eps
        sigma_sq = sigma_sq.clamp_min(1e-4)
        r_sum_flat = r_sum.view(b, C, 1)
        sigma_sq_flat = sigma_sq.view(b, C, psize).clamp_min(1e-4)
        cost_h = (self.beta_u.view(1, C, 1) + torch.log(torch.sqrt(sigma_sq_flat))) * r_sum_flat
        a_out = self.sigmoid(self._lambda * (self.beta_a.view(1, C) - cost_h.sum(dim=2))).clamp(1e-4, 1.0 - 1e-4)
        mu = torch.nan_to_num(mu, nan=0.0, posinf=1e4, neginf=-1e4)
        sigma_sq = torch.nan_to_num(sigma_sq, nan=1e-4, posinf=1e4, neginf=1e-4)
        a_out = torch.nan_to_num(a_out, nan=0.5, posinf=1.0 - 1e-4, neginf=1e-4)
        return a_out, mu, sigma_sq

    def e_step(
        self,
        mu: torch.Tensor,
        sigma_sq: torch.Tensor,
        a_out: torch.Tensor,
        v: torch.Tensor,
        eps: float,
        b: int,
        C: int,
    ) -> torch.Tensor:
        # mu: (b,1,C,psize), sigma_sq: (b,1,C,psize), a_out: (b,C), v: (b,B,C,psize)
        sigma_sq = sigma_sq.clamp_min(1e-4)
        a_out = a_out.clamp(1e-4, 1.0 - 1e-4)
        ln_p_j_h = -1.0 * (v - mu).pow(2) / (2.0 * sigma_sq) - torch.log(torch.sqrt(sigma_sq)) - 0.5 * self.ln_2pi
        ln_ap = ln_p_j_h.sum(dim=3) + torch.log(a_out.view(b, 1, C) + eps)
        ln_ap = torch.nan_to_num(ln_ap, nan=0.0, posinf=50.0, neginf=-50.0)
        r = self.softmax(ln_ap).unsqueeze(-1)  # (b,B,C,1)
        r = torch.nan_to_num(r, nan=(1.0 / max(C, 1)), posinf=1.0, neginf=0.0)
        return r

    def caps_em_routing(self, v: torch.Tensor, a_in: torch.Tensor, C: int, eps: float) -> tuple[torch.Tensor, torch.Tensor]:
        b, B, _, psize = v.shape
        r = v.new_full((b, B, C, 1), 1.0 / C)
        for t in range(self.iters):
            a_out, mu, sigma_sq = self.m_step(a_in, r, v, eps, b, B, C, psize)
            if t < self.iters - 1:
                r = self.e_step(mu, sigma_sq, a_out, v, eps, b, C)
        # p_out: (b, C, psize), a_out: (b, C)
        p_out = torch.nan_to_num(mu.squeeze(1), nan=0.0, posinf=1e4, neginf=-1e4)
        a_out = torch.nan_to_num(a_out, nan=0.5, posinf=1.0 - 1e-4, neginf=1e-4)
        return p_out, a_out

    def add_patches(self, x: torch.Tensor, B: int, K: int, psize: int, stride: int) -> tuple[torch.Tensor, int, int]:
        # x: (b, h, w, B*(psize+1)) -> patches: (b, oh, ow, K*K, B*(psize+1))
        b, h, w, c = x.shape
        x_chw = x.permute(0, 3, 1, 2).contiguous()
        pad = K // 2
        patches = F.unfold(x_chw, kernel_size=K, padding=pad, stride=stride)
        oh = (h + 2 * pad - K) // stride + 1
        ow = (w + 2 * pad - K) // stride + 1
        patches = patches.transpose(1, 2).contiguous().view(b, oh, ow, K * K, c)
        return patches, oh, ow

    def transform_view(self, x: torch.Tensor, w: torch.Tensor, C: int, P: int, w_shared: bool = False) -> torch.Tensor:
        # x: (b, in_votes, psize), w: (1, in_votes_base, C, psize, psize)
        b, in_votes, psize = x.shape
        assert psize == P * P
        w0 = w[0]
        if w_shared:
            base = w0.size(0)
            reps = in_votes // base
            w0 = w0.repeat(reps, 1, 1, 1)
        # (b, in_votes, C, psize)
        v = torch.einsum("bip,icpq->bicq", x, w0)
        return v

    def add_coord(self, v: torch.Tensor, b: int, h: int, w: int, B: int, C: int, psize: int) -> torch.Tensor:
        # v: (b, h*w*B, C, psize)
        # Supports rectangular feature maps (h != w).
        v = v.view(b, h, w, B, C, psize)
        device = v.device
        dtype = v.dtype
        coor_h_vals = torch.arange(h, dtype=dtype, device=device) / float(max(h, 1))
        coor_w_vals = torch.arange(w, dtype=dtype, device=device) / float(max(w, 1))
        coor_h = torch.zeros(1, h, 1, 1, 1, psize, dtype=dtype, device=device)
        coor_w = torch.zeros(1, 1, w, 1, 1, psize, dtype=dtype, device=device)
        coor_h[0, :, 0, 0, 0, 0] = coor_h_vals
        coor_w[0, 0, :, 0, 0, 1] = coor_w_vals
        v = (v + coor_h + coor_w).view(b, h * w * B, C, psize)
        return v

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # x shape: (b, h, w, B*(psize+1))
        b, h, w, c = x.shape
        if not self.w_shared:
            patches, oh, ow = self.add_patches(x, self.B, self.K, self.psize, self.stride)
            p_in = patches[..., : self.B * self.psize].contiguous().view(b * oh * ow, self.K * self.K * self.B, self.psize)
            a_in = patches[..., self.B * self.psize :].contiguous().view(b * oh * ow, self.K * self.K * self.B, 1)
            v = self.transform_view(p_in, self.weights, self.C, self.P, w_shared=False)
            p_out, a_out = self.caps_em_routing(v, a_in, self.C, self.eps)
            p_out = p_out.view(b, oh, ow, self.C * self.psize)
            a_out = a_out.view(b, oh, ow, self.C)
            out = torch.cat([p_out, a_out], dim=3)
        else:
            assert c == self.B * (self.psize + 1)
            assert self.K == 1
            assert self.stride == 1
            p_in = x[..., : self.B * self.psize].contiguous().view(b, h * w * self.B, self.psize)
            a_in = x[..., self.B * self.psize :].contiguous().view(b, h * w * self.B, 1)
            v = self.transform_view(p_in, self.weights, self.C, self.P, w_shared=True)
            if self.coor_add:
                v = self.add_coord(v, b, h, w, self.B, self.C, self.psize)
            p_cls, a_cls = self.caps_em_routing(v, a_in, self.C, self.eps)
            # Broadcast class capsules back to spatial map for Detect-style dense outputs.
            p_out = p_cls.reshape(b, 1, 1, self.C * self.psize).expand(b, h, w, self.C * self.psize)
            a_out = a_cls.unsqueeze(1).unsqueeze(1).expand(b, h, w, self.C)
            out = torch.cat([p_out, a_out], dim=3)
        return p_out, a_out, out
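

# Hedged usage sketch for ConvCaps in shared class-caps mode (the configuration
# CapsuleDualHead uses below). Never executed at import; shapes and capsule
# counts are illustrative assumptions only.
def _example_conv_caps_em() -> None:
    primary = PrimaryCaps(A=32, B=8, K=1, P=4)
    class_caps = ConvCaps(B=8, C=5, K=1, P=4, stride=1, iters=2, coor_add=True, w_shared=True)
    x = torch.randn(2, 32, 10, 10)
    _, _, caps = primary(x)  # (2, 10, 10, 8*17)
    p_out, a_out, out = class_caps(caps)
    assert p_out.shape == (2, 10, 10, 5 * 16)  # class poses broadcast to the map
    assert a_out.shape == (2, 10, 10, 5)       # class activations in (0, 1)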


class DynamicConvCaps(nn.Module):
    r"""Convolutional capsules with Sabour-style dynamic routing.

    This layer keeps the same tensor interface as ``ConvCaps``:
        input: (N, H, W, B*(P*P+1))
        output: p_out (N, H_out, W_out, C*P*P), a_out (N, H_out, W_out, C), out concat

    Args:
        B: Input capsule types.
        C: Output capsule types.
        K: Patch kernel size.
        P: Pose matrix side length.
        stride: Patch stride.
        iters: Routing iterations.
        coor_add: Add coordinates in shared mode.
        w_shared: Share transforms across spatial positions (requires K=1, stride=1).
    """

    def __init__(
        self,
        B: int = 32,
        C: int = 32,
        K: int = 3,
        P: int = 4,
        stride: int = 1,
        iters: int = 3,
        coor_add: bool = False,
        w_shared: bool = False,
    ):
        super().__init__()
        self.B = B
        self.C = C
        self.K = K
        self.P = P
        self.psize = P * P
        self.stride = stride
        self.iters = iters
        self.coor_add = coor_add
        self.w_shared = w_shared
        self.eps = 1e-6
        weight_in = B if w_shared else (K * K * B)
        self.weights = nn.Parameter(torch.randn(1, weight_in, C, self.psize, self.psize) * 0.02)

    @staticmethod
    def squash(s: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
        s2 = (s * s).sum(dim=dim, keepdim=True)
        scale = s2 / (1.0 + s2)
        return scale * s / torch.sqrt(s2 + eps)

    def add_patches(self, x: torch.Tensor, K: int, stride: int) -> tuple[torch.Tensor, int, int]:
        b, h, w, c = x.shape
        x_chw = x.permute(0, 3, 1, 2).contiguous()
        pad = K // 2
        patches = F.unfold(x_chw, kernel_size=K, padding=pad, stride=stride)
        oh = (h + 2 * pad - K) // stride + 1
        ow = (w + 2 * pad - K) // stride + 1
        patches = patches.transpose(1, 2).contiguous().view(b, oh, ow, K * K, c)
        return patches, oh, ow

    def transform_view(self, x: torch.Tensor, w_shared: bool) -> torch.Tensor:
        # x: (b, in_votes, psize) -> votes: (b, in_votes, C, psize)
        b, in_votes, psize = x.shape
        if psize != self.psize:
            raise ValueError('Invalid pose size for DynamicConvCaps')
        w0 = self.weights[0]
        if w_shared:
            base = w0.size(0)
            reps = in_votes // base
            w0 = w0.repeat(reps, 1, 1, 1)
        return torch.einsum('bip,icpq->bicq', x, w0)

    def add_coord(self, v: torch.Tensor, b: int, h: int, w: int, B: int, C: int, psize: int) -> torch.Tensor:
        # v: (b, h*w*B, C, psize)
        v = v.view(b, h, w, B, C, psize)
        device, dtype = v.device, v.dtype
        coor_h_vals = torch.arange(h, dtype=dtype, device=device) / float(max(h, 1))
        coor_w_vals = torch.arange(w, dtype=dtype, device=device) / float(max(w, 1))
        coor_h = torch.zeros(1, h, 1, 1, 1, psize, dtype=dtype, device=device)
        coor_w = torch.zeros(1, 1, w, 1, 1, psize, dtype=dtype, device=device)
        coor_h[0, :, 0, 0, 0, 0] = coor_h_vals
        coor_w[0, 0, :, 0, 0, 1] = coor_w_vals
        return (v + coor_h + coor_w).view(b, h * w * B, C, psize)

    def dynamic_routing(self, v: torch.Tensor, a_in: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # v: (n, in_votes, C, psize), a_in: (n, in_votes, 1)
        n, in_votes, C, psize = v.shape
        b_ij = v.new_zeros(n, in_votes, C)
        a_in = a_in.clamp(1e-4, 1.0)
        for t in range(self.iters):
            c_ij = F.softmax(b_ij, dim=2)
            c_ij = c_ij * a_in
            c_ij = c_ij / (c_ij.sum(dim=2, keepdim=True) + self.eps)
            s_j = (c_ij.unsqueeze(-1) * v).sum(dim=1)
            v_j = self.squash(s_j, dim=-1, eps=self.eps)
            if t < self.iters - 1:
                agreement = (v * v_j.unsqueeze(1)).sum(dim=-1)
                b_ij = b_ij + agreement
        # activation from vector length in (0,1)
        a_out = torch.sqrt((v_j * v_j).sum(dim=-1) + self.eps).clamp(1e-4, 1.0 - 1e-4)
        return v_j, a_out

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        b, h, w, c = x.shape
        if not self.w_shared:
            patches, oh, ow = self.add_patches(x, self.K, self.stride)
            p_in = patches[..., : self.B * self.psize].contiguous().view(b * oh * ow, self.K * self.K * self.B, self.psize)
            a_in = patches[..., self.B * self.psize :].contiguous().view(b * oh * ow, self.K * self.K * self.B, 1)
            votes = self.transform_view(p_in, w_shared=False)
            p_vec, a_vec = self.dynamic_routing(votes, a_in)
            p_out = p_vec.view(b, oh, ow, self.C * self.psize)
            a_out = a_vec.view(b, oh, ow, self.C)
            out = torch.cat([p_out, a_out], dim=3)
        else:
            if c != self.B * (self.psize + 1) or self.K != 1 or self.stride != 1:
                raise ValueError('DynamicConvCaps shared mode requires K=1, stride=1 and matching capsule channels')
            p_in = x[..., : self.B * self.psize].contiguous().view(b, h * w * self.B, self.psize)
            a_in = x[..., self.B * self.psize :].contiguous().view(b, h * w * self.B, 1)
            votes = self.transform_view(p_in, w_shared=True)
            if self.coor_add:
                votes = self.add_coord(votes, b, h, w, self.B, self.C, self.psize)
            p_vec, a_vec = self.dynamic_routing(votes, a_in)
            p_out = p_vec.reshape(b, 1, 1, self.C * self.psize).expand(b, h, w, self.C * self.psize)
            a_out = a_vec.unsqueeze(1).unsqueeze(1).expand(b, h, w, self.C)
            out = torch.cat([p_out, a_out], dim=3)
        p_out = torch.nan_to_num(p_out, nan=0.0, posinf=1e4, neginf=-1e4)
        a_out = torch.nan_to_num(a_out, nan=0.5, posinf=1.0 - 1e-4, neginf=1e-4)
        out = torch.nan_to_num(out, nan=0.0, posinf=1e4, neginf=-1e4)
        return p_out, a_out, out


class SelfRoutingConvCaps(nn.Module):
    r"""Convolutional self-routing capsules.

    Keeps the same output contract as ``ConvCaps``/``DynamicConvCaps``:
        input: (N, H, W, B*(P*P+1))
        output: p_out (N, H_out, W_out, C*P*P), a_out (N, H_out, W_out, C), out concat
    """

    def __init__(
        self,
        B: int = 32,
        C: int = 32,
        K: int = 3,
        P: int = 4,
        stride: int = 1,
        iters: int = 1,
        coor_add: bool = False,
        w_shared: bool = False,
    ):
        super().__init__()
        _ = (iters, w_shared)  # kept for API compatibility with other capsule layers.
        self.B = B
        self.C = C
        self.K = K
        self.P = P
        self.psize = P * P
        self.stride = stride
        self.coor_add = coor_add
        self.eps = 1e-6
        self.kk = K * K
        self.kkB = self.kk * B
        # Pose transform for each input capsule vote -> output capsule pose.
        self.W1 = nn.Parameter(torch.empty(self.kkB, C, self.psize, self.psize))
        nn.init.kaiming_uniform_(self.W1, a=math.sqrt(5))
        # Routing logits from local pose vectors.
        self.W2 = nn.Parameter(torch.zeros(self.kkB, C, self.psize))
        self.b2 = nn.Parameter(torch.zeros(1, 1, self.kkB, C))

    def _output_hw(self, h: int, w: int) -> tuple[int, int]:
        pad = self.K // 2
        oh = (h + 2 * pad - self.K) // self.stride + 1
        ow = (w + 2 * pad - self.K) // self.stride + 1
        return oh, ow

    def _add_coord(self, pose_unf: torch.Tensor, oh: int, ow: int) -> torch.Tensor:
        # pose_unf: (b, L, kkB, psize)
        if self.psize < 2:
            return pose_unf
        b, L, kkB, _ = pose_unf.shape
        device, dtype = pose_unf.device, pose_unf.dtype
        gy = torch.arange(oh, device=device, dtype=dtype) / float(max(oh, 1))
        gx = torch.arange(ow, device=device, dtype=dtype) / float(max(ow, 1))
        yy, xx = torch.meshgrid(gy, gx, indexing='ij')
        coords = torch.stack((yy, xx), dim=-1).view(1, L, 1, 2)
        pose_unf = pose_unf.clone()
        pose_unf[..., :2] = pose_unf[..., :2] + coords
        return pose_unf

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # x: (b, h, w, B*(psize+1))
        b, h, w, c = x.shape
        expected = self.B * (self.psize + 1)
        if c != expected:
            raise ValueError(f'SelfRoutingConvCaps expected {expected} channels, got {c}')
        pose = x[..., : self.B * self.psize]
        act = x[..., self.B * self.psize :]
        pose_chw = pose.permute(0, 3, 1, 2).contiguous()
        act_chw = act.permute(0, 3, 1, 2).contiguous()
        pad = self.K // 2
        pose_unf = F.unfold(pose_chw, kernel_size=self.K, stride=self.stride, padding=pad)
        act_unf = F.unfold(act_chw, kernel_size=self.K, stride=self.stride, padding=pad)
        oh, ow = self._output_hw(h, w)
        L = pose_unf.shape[-1]
        pose_unf = pose_unf.view(b, self.B, self.psize, self.kk, L).permute(0, 4, 3, 1, 2).contiguous()
        pose_unf = pose_unf.view(b, L, self.kkB, self.psize)
        act_unf = act_unf.view(b, self.B, self.kk, L).permute(0, 3, 2, 1).contiguous()
        act_unf = act_unf.view(b, L, self.kkB)
        if self.coor_add:
            pose_unf = self._add_coord(pose_unf, oh, ow)
        # Routing logits and couplings.
        logit = torch.einsum('blip,icp->blic', pose_unf, self.W2) + self.b2
        r = F.softmax(logit, dim=3)
        ar = act_unf.unsqueeze(-1) * r
        ar_sum = ar.sum(dim=2, keepdim=True) + self.eps
        coeff = ar / ar_sum
        a_norm = act_unf.sum(dim=2, keepdim=True) + self.eps
        a_out = (ar_sum.squeeze(2) / a_norm).clamp(1e-4, 1.0 - 1e-4)
        pose_votes = torch.einsum('blip,icpq->blicq', pose_unf, self.W1)
        pose_out = (coeff.unsqueeze(-1) * pose_votes).sum(dim=2)
        p_out = pose_out.view(b, oh, ow, self.C * self.psize)
        a_out = a_out.view(b, oh, ow, self.C)
        out = torch.cat([p_out, a_out], dim=3)
        p_out = torch.nan_to_num(p_out, nan=0.0, posinf=1e4, neginf=-1e4)
        a_out = torch.nan_to_num(a_out, nan=0.5, posinf=1.0 - 1e-4, neginf=1e-4)
        out = torch.nan_to_num(out, nan=0.0, posinf=1e4, neginf=-1e4)
        return p_out, a_out, out


class CapsuleDualHead(nn.Module):
    """Capsule detection head for one feature level.

    Args:
        c_in: Input channels of this feature scale (from parser-provided ``ch``).
        nc: Number of classes (activation capsule count of the final class-caps layer, ``conv_caps2``).
        reg_max: Detect DFL bins, box channels are ``4 * reg_max``.
        k: Number of capsule types in ``PrimaryCaps``.
        d: Requested pose descriptor size; internally mapped to square ``P*P``.

    Input shape:
        x: ``(N, c_in, H, W)``
    Output shape:
        boxes: ``(N, 4*reg_max, H, W)``
        scores: ``(N, nc, H, W)``
        aux: dict with final capsule activations when ``return_aux=True``, else ``None``
    Parameter size:
        ``PrimaryCaps(c_in,k) + ConvCaps(k,nc,w_shared=True) + box_bias(4*reg_max)``
    Structure:
        PrimaryCaps -> ConvCaps (class caps only, shared)
    """

    def __init__(self, c_in: int, nc: int, reg_max: int, k: int, d: int):
        super().__init__()
        # Matrix-caps pose is square; choose the smallest square >= requested d.
        p = max(1, int(math.ceil(math.sqrt(d))))
        self.nc = nc
        self.reg_max = reg_max
        self.P = p
        self.psize = self.P * self.P
        # A=c_in, B=k, P controls pose channels as B*(P*P).
        self.primary = PrimaryCaps(A=c_in, B=k, K=1, P=self.P, stride=1)
        # Single class-caps layer with shared transforms for parameter reduction.
        self.conv_caps2 = ConvCaps(B=k, C=nc, K=1, P=self.P, stride=1, iters=1, coor_add=True, w_shared=True)
        # Detect-style localization prior set in CapsuleDetect.bias_init().
        self.box_bias = nn.Parameter(torch.zeros(4 * reg_max))

    def _pose_to_box(self, p_out: torch.Tensor, a_out: torch.Tensor) -> torch.Tensor:
        # p_out: (b,h,w,nc*psize); a_out is intentionally unused here.
        # Simple rule requested: use the first 4*reg_max pose values as box channels.
        _ = a_out
        box_ch = 4 * self.reg_max
        if p_out.shape[-1] >= box_ch:
            box = p_out[..., :box_ch]
        else:
            # If pose channels are fewer than required box channels, repeat and trim.
            reps = math.ceil(box_ch / p_out.shape[-1])
            box = p_out.repeat(1, 1, 1, reps)[..., :box_ch]
        return box + self.box_bias.view(1, 1, 1, box_ch)

    def forward(self, x: torch.Tensor, return_aux: bool = False) -> tuple[torch.Tensor, torch.Tensor, dict | None]:
        _, _, caps0 = self.primary(x)
        p2, a2, _ = self.conv_caps2(caps0)
        boxes = self._pose_to_box(p2, a2).permute(0, 3, 1, 2).contiguous()  # (b,4*reg_max,h,w)
        a2_logits = torch.logit(a2.clamp(1e-4, 1.0 - 1e-4))
        scores = a2_logits.permute(0, 3, 1, 2).contiguous()  # (b,nc,h,w) logits
        aux = None
        if return_aux:
            aux = {
                "caps2_a": a2.permute(0, 3, 1, 2).contiguous(),
            }
        return boxes, scores, aux


class CapsuleClsHead(nn.Module):
    """Capsule classification branch used as a drop-in replacement for Detect.cv3."""

    def __init__(self, c_in: int, nc: int, k: int = 4, d: int = 16, iters: int = 1):
        super().__init__()
        p = max(1, int(math.ceil(math.sqrt(d))))
        self.primary = PrimaryCaps(A=c_in, B=k, K=1, P=p, stride=1)
        # Internal capsule refinement layer.
        self.mid_caps = SelfRoutingConvCaps(
            B=k, C=int((k + nc) / 2), K=1, P=p, stride=1, iters=iters, coor_add=False, w_shared=True
        )
        self.class_caps = SelfRoutingConvCaps(
            B=int((k + nc) / 2), C=nc, K=1, P=p, stride=1, iters=iters, coor_add=False, w_shared=True
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Output Detect-compatible class logits in BCHW.
        _, _, caps = self.primary(x)
        _, _, caps_mid = self.mid_caps(caps)
        _, a_out, _ = self.class_caps(caps_mid)
        logits = torch.logit(a_out.clamp(1e-4, 1.0 - 1e-4)).permute(0, 3, 1, 2).contiguous()
        return torch.nan_to_num(logits, nan=0.0, posinf=20.0, neginf=-20.0).float()
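

# CapsuleClsHead chains PrimaryCaps -> mid caps -> class caps and emits logits
# shaped like a Detect cv3 branch. Channel counts below are illustrative
# assumptions; the hidden capsule count is int((k + nc) / 2) as defined above.
def _example_capsule_cls_head() -> None:
    head = CapsuleClsHead(c_in=64, nc=10, k=4, d=16, iters=1)
    x = torch.randn(2, 64, 8, 8)
    logits = head(x)
    assert logits.shape == (2, 10, 8, 8)  # BCHW class logits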


class CapsuleDetect(Detect):
    """Detect head with capsule vote aggregation for both box and cls branches.

    Input feature of level i is packed as interleaved channels:
        [pose(d_i), act(1)] repeated k_i times -> C_i = k_i * (d_i + 1)

    In forward_head:
        - split pose/act per capsule type
        - run Detect box/cls heads on each type-specific pose tensor
        - aggregate type predictions with act-driven vote weights

    Detect decode/postprocess/end2end flow is reused unchanged.
    """

    def __init__(
        self,
        nc: int = 80,
        *args,
        reg_max: int = 16,
        end2end: bool = False,
        k: list[int] | tuple[int, ...] = (4, 8, 16),
        d: list[int] | tuple[int, ...] = (16, 16, 16),
        ch: tuple = (),
    ):
        parsed = list(args)
        if parsed and isinstance(parsed[-1], (list, tuple)):
            ch = tuple(parsed.pop(-1))
        # Parser layout: [k_list, d_list, reg_max, end2end, ch]
        if len(parsed) not in (2, 4):
            raise ValueError('CapsuleDetect expects [k_list, d_list, reg_max, end2end, ch].')
        k, d = parsed[0], parsed[1]
        if len(parsed) == 4:
            reg_max = int(parsed[2])
            end2end = bool(parsed[3])
        if not isinstance(k, (list, tuple)) or not isinstance(d, (list, tuple)):
            raise TypeError('CapsuleDetect requires list/tuple k and d (per-level settings).')
        ch = tuple(int(c) for c in ch)
        nl = len(ch)
        if len(k) != nl or len(d) != nl:
            raise ValueError(f'CapsuleDetect k/d length must equal number of levels ({nl}).')
        self.k_list = tuple(int(v) for v in k)
        self.d_list = tuple(int(v) for v in d)
        for i, c in enumerate(ch):
            expected = self.k_list[i] * (self.d_list[i] + 1)
            if c != expected:
                raise ValueError(
                    f'CapsuleDetect level-{i} channel mismatch: got {c}, '
                    f'expected {expected} from k={self.k_list[i]}, d={self.d_list[i]}.'
                )
        # Detect heads operate on per-type pose tensors (d_i channels).
        super().__init__(nc=nc, reg_max=reg_max, end2end=end2end, ch=self.d_list)
        # Vote weights from activation channels (k_i channels), separate for cls/box.
        self.box_vote = nn.ModuleList(
            nn.Sequential(Conv(k_i, k_i, 3), nn.Conv2d(k_i, k_i, 1, bias=True)) for k_i in self.k_list
        )
        self.cls_vote = nn.ModuleList(
            nn.Sequential(Conv(k_i, k_i, 3), nn.Conv2d(k_i, k_i, 1, bias=True)) for k_i in self.k_list
        )

    def _split_caps(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Split packed feature into pose and activation tensors per level.

        Returns:
            pose_caps: list of tensors, each (B, K, D, H, W)
            act_map: list of tensors, each (B, K, H, W)
        """
        pose_caps, act_map = [], []
        for i, xi in enumerate(x):
            k_i = self.k_list[i]
            d_i = self.d_list[i]
            c = int(xi.shape[1])
            expected = k_i * (d_i + 1)
            if c != expected:
                raise ValueError(f'CapsuleDetect level-{i} channel mismatch: got {c}, expected {expected}.')
            b, _, h, w = xi.shape
            caps = xi.view(b, k_i, d_i + 1, h, w)
            pose_caps.append(caps[:, :, :d_i].contiguous())
            act_map.append(caps[:, :, d_i].contiguous())
        return pose_caps, act_map

    @staticmethod
    def _normalized_votes(raw: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
        # No softmax/sigmoid: use softplus + sum-normalization.
        w = F.softplus(raw) + eps
        return w / (w.sum(dim=1, keepdim=True) + eps)

    def _run_voted_head(
        self,
        pose: torch.Tensor,
        act: torch.Tensor,
        head: torch.nn.Module,
        vote_head: torch.nn.Module,
        out_ch: int,
    ) -> torch.Tensor:
        """Apply one Detect head per type and aggregate by vote weights.

        Args:
            pose: (B, K, D, H, W)
            act: (B, K, H, W)
            head: Detect box or cls head module for this level
            vote_head: vote logits module for this level
            out_ch: output channels of target prediction

        Returns:
            (B, out_ch, H, W)
        """
        b, k, d, h, w = pose.shape
        # No voting needed when there is only one capsule type.
        if k == 1:
            return head(pose[:, 0])
        pose_bt = pose.reshape(b * k, d, h, w)
        pred_bt = head(pose_bt).reshape(b, k, out_ch, h, w)
        vote_raw = vote_head(act)  # (B, K, H, W)
        vote = self._normalized_votes(vote_raw).unsqueeze(2)  # (B, K, 1, H, W)
        pred = (pred_bt * vote).sum(dim=1)
        return pred

    def forward_head(
        self, x: list[torch.Tensor], box_head: torch.nn.Module = None, cls_head: torch.nn.Module = None
    ) -> dict[str, torch.Tensor]:
        if box_head is None or cls_head is None:
            return dict()
        pose_caps, act_map = self._split_caps(x)
        bs = x[0].shape[0]
        box_list = []
        cls_list = []
        for i in range(self.nl):
            box_i = self._run_voted_head(
                pose_caps[i],
                act_map[i],
                box_head[i],
                self.box_vote[i],
                out_ch=4 * self.reg_max,
            )
            cls_i = self._run_voted_head(
                pose_caps[i],
                act_map[i],
                cls_head[i],
                self.cls_vote[i],
                out_ch=self.nc,
            )
            box_list.append(box_i.view(bs, 4 * self.reg_max, -1))
            cls_list.append(cls_i.view(bs, self.nc, -1))
        boxes = torch.cat(box_list, dim=-1)
        scores = torch.cat(cls_list, dim=-1)
        return dict(boxes=boxes, scores=scores, feats=x)
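

# Sketch of the packed capsule layout CapsuleDetect consumes, and the softplus
# vote normalization it applies across capsule types. Pure-tensor illustration;
# the k/d/shape values are assumptions and no Detect sub-heads are constructed.
def _example_packed_layout_and_votes() -> None:
    k_i, d_i, h, w = 4, 16, 8, 8
    xi = torch.randn(2, k_i * (d_i + 1), h, w)     # C_i = k_i * (d_i + 1) = 68
    caps = xi.view(2, k_i, d_i + 1, h, w)
    pose, act = caps[:, :, :d_i], caps[:, :, d_i]  # (2,4,16,8,8) and (2,4,8,8)
    vote = CapsuleDetect._normalized_votes(act)    # per-type weights, sum ~= 1
    assert torch.allclose(vote.sum(dim=1), torch.ones(2, h, w), atol=1e-3)
    assert pose.shape == (2, 4, 16, 8, 8)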


class CapsuleDetectv1(Detect):
    """Capsule Detect variant with activation-gated pose fusion.

    Per level:
        1) Split packed capsule channels into pose/activation (interleaved by type).
        2) Use a 2-layer 1x1 gate net on activation channels.
        3) Gate pose channels with residual scaling.
        4) Flatten to K*D channels and run the original Detect cv2/cv3 heads.
    """

    def __init__(
        self,
        nc: int = 80,
        *args,
        reg_max: int = 16,
        end2end: bool = False,
        k: list[int] | tuple[int, ...] = (4, 8, 16),
        d: list[int] | tuple[int, ...] = (16, 16, 16),
        ch: tuple = (),
    ):
        parsed = list(args)
        if parsed and isinstance(parsed[-1], (list, tuple)):
            ch = tuple(parsed.pop(-1))
        # Parser layout: [k_list, d_list, reg_max, end2end, ch]
        if len(parsed) not in (2, 4):
            raise ValueError("CapsuleDetectv1 expects [k_list, d_list, reg_max, end2end, ch].")
        k, d = parsed[0], parsed[1]
        if len(parsed) == 4:
            reg_max = int(parsed[2])
            end2end = bool(parsed[3])
        if not isinstance(k, (list, tuple)) or not isinstance(d, (list, tuple)):
            raise TypeError("CapsuleDetectv1 requires list/tuple k and d (per-level settings).")
        ch = tuple(int(c) for c in ch)
        nl = len(ch)
        if len(k) != nl or len(d) != nl:
            raise ValueError(f"CapsuleDetectv1 k/d length must equal number of levels ({nl}).")
        self.k_list = tuple(int(v) for v in k)
        self.d_list = tuple(int(v) for v in d)
        # Input from neck is packed as K*(D+1): [pose(D), act(1)] repeated K types.
        for i, c in enumerate(ch):
            expected = self.k_list[i] * (self.d_list[i] + 1)
            if c != expected:
                raise ValueError(
                    f"CapsuleDetectv1 level-{i} channel mismatch: got {c}, "
                    f"expected {expected} from k={self.k_list[i]}, d={self.d_list[i]}."
                )
        # Detect heads consume merged pose channels: K*D.
        merged_ch = tuple(k_i * d_i for k_i, d_i in zip(self.k_list, self.d_list))
        super().__init__(nc=nc, reg_max=reg_max, end2end=end2end, ch=merged_ch)
        self.pose_gates = nn.ModuleList()
        self.gate_alpha = nn.ParameterList()
        for k_i, d_i in zip(self.k_list, self.d_list):
            out_ch = k_i * d_i
            hidden = max(8, k_i * 2)
            self.pose_gates.append(
                nn.Sequential(
                    nn.Conv2d(k_i, hidden, 1, bias=True),
                    nn.SiLU(inplace=True),
                    nn.Conv2d(hidden, out_ch, 1, bias=True),
                )
            )
            self.gate_alpha.append(nn.Parameter(torch.tensor(0.5)))

    def _split_pose_act(self, x: torch.Tensor, i: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Split one level's packed tensor into pose and activation maps."""
        k_i = self.k_list[i]
        d_i = self.d_list[i]
        b, c, h, w = x.shape
        expected = k_i * (d_i + 1)
        if c != expected:
            raise ValueError(f"CapsuleDetectv1 level-{i} channel mismatch: got {c}, expected {expected}.")
        caps = x.view(b, k_i, d_i + 1, h, w)
        pose = caps[:, :, :d_i].reshape(b, k_i * d_i, h, w).contiguous()
        act = caps[:, :, d_i].contiguous()
        return pose, act

    def _merge_pose(self, x: list[torch.Tensor]) -> list[torch.Tensor]:
        merged = []
        for i, xi in enumerate(x):
            pose, act = self._split_pose_act(xi, i)
            gate = torch.sigmoid(self.pose_gates[i](act))
            # Residual gating keeps base pose information and improves stability.
            pose = pose * (1.0 + self.gate_alpha[i] * gate)
            merged.append(pose)
        return merged

    def forward_head(
        self, x: list[torch.Tensor], box_head: torch.nn.Module = None, cls_head: torch.nn.Module = None
    ) -> dict[str, torch.Tensor]:
        if box_head is None or cls_head is None:
            return dict()
        pose_feats = self._merge_pose(x)
        bs = pose_feats[0].shape[0]
        box_list = [box_head[i](pose_feats[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)]
        cls_list = [cls_head[i](pose_feats[i]).view(bs, self.nc, -1) for i in range(self.nl)]
        boxes = torch.cat(box_list, dim=-1)
        scores = torch.cat(cls_list, dim=-1)
        return dict(boxes=boxes, scores=scores, feats=x)
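

# The residual gate used by CapsuleDetectv1 scales pose features by
# 1 + alpha * sigmoid(g), so with a freshly initialized gate net (g ~ 0,
# alpha = 0.5) the pose is multiplied by roughly 1.25 and is never suppressed
# below 1x. Pure-tensor sketch with assumed channel counts (k_i*d_i = 64).
def _example_residual_gate() -> None:
    pose = torch.randn(2, 64, 8, 8)
    gate_logits = torch.zeros(2, 64, 8, 8)  # what a zero-init gate net would emit
    alpha = torch.tensor(0.5)
    gated = pose * (1.0 + alpha * torch.sigmoid(gate_logits))
    assert torch.allclose(gated, pose * 1.25)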


class CapsuleDetectv2(Detect):
    """Capsule Detect v2: activation-gated pose + activation bypass for classification."""

    def __init__(
        self,
        nc: int = 80,
        *args,
        reg_max: int = 16,
        end2end: bool = False,
        k: list[int] | tuple[int, ...] = (4, 8, 16),
        d: list[int] | tuple[int, ...] = (16, 16, 16),
        ch: tuple = (),
    ):
        parsed = list(args)
        if parsed and isinstance(parsed[-1], (list, tuple)):
            ch = tuple(parsed.pop(-1))
        # Parser layout: [k_list, d_list, reg_max, end2end, ch]
        if len(parsed) not in (2, 4):
            raise ValueError("CapsuleDetectv2 expects [k_list, d_list, reg_max, end2end, ch].")
        k, d = parsed[0], parsed[1]
        if len(parsed) == 4:
            reg_max = int(parsed[2])
            end2end = bool(parsed[3])
        if not isinstance(k, (list, tuple)) or not isinstance(d, (list, tuple)):
            raise TypeError("CapsuleDetectv2 requires list/tuple k and d (per-level settings).")
        ch = tuple(int(c) for c in ch)
        nl = len(ch)
        if len(k) != nl or len(d) != nl:
            raise ValueError(f"CapsuleDetectv2 k/d length must equal number of levels ({nl}).")
        self.k_list = tuple(int(v) for v in k)
        self.d_list = tuple(int(v) for v in d)
        # Input from neck is packed as K*(D+1): [pose(D), act(1)] repeated K types.
        for i, c in enumerate(ch):
            expected = self.k_list[i] * (self.d_list[i] + 1)
            if c != expected:
                raise ValueError(
                    f"CapsuleDetectv2 level-{i} channel mismatch: got {c}, "
                    f"expected {expected} from k={self.k_list[i]}, d={self.d_list[i]}."
                )
        # Detect heads consume merged pose channels: K*D.
        merged_ch = tuple(k_i * d_i for k_i, d_i in zip(self.k_list, self.d_list))
        super().__init__(nc=nc, reg_max=reg_max, end2end=end2end, ch=merged_ch)
        self.pose_gates = nn.ModuleList()
        self.gate_alpha = nn.ParameterList()
        self.cls_bypass = nn.ModuleList()
        self.cls_beta = nn.ParameterList()
        for k_i, d_i in zip(self.k_list, self.d_list):
            pose_ch = k_i * d_i
            gate_hidden = max(8, k_i * 2)
            self.pose_gates.append(
                nn.Sequential(
                    nn.Conv2d(k_i, gate_hidden, 1, bias=True),
                    nn.SiLU(inplace=True),
                    nn.Conv2d(gate_hidden, pose_ch, 1, bias=True),
                )
            )
            self.gate_alpha.append(nn.Parameter(torch.tensor(0.5)))
            cls_hidden = max(16, k_i * 2)
            self.cls_bypass.append(
                nn.Sequential(
                    nn.Conv2d(k_i, cls_hidden, 1, bias=True),
                    nn.SiLU(inplace=True),
                    nn.Conv2d(cls_hidden, pose_ch, 1, bias=True),
                )
            )
            self.cls_beta.append(nn.Parameter(torch.tensor(0.1)))

    def _split_pose_act(self, x: torch.Tensor, i: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Split one level's packed tensor into pose and activation maps."""
        k_i = self.k_list[i]
        d_i = self.d_list[i]
        b, c, h, w = x.shape
        expected = k_i * (d_i + 1)
        if c != expected:
            raise ValueError(f"CapsuleDetectv2 level-{i} channel mismatch: got {c}, expected {expected}.")
        caps = x.view(b, k_i, d_i + 1, h, w)
        pose = caps[:, :, :d_i].reshape(b, k_i * d_i, h, w).contiguous()
        act = caps[:, :, d_i].contiguous()
        return pose, act

    def _fuse_pose(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        box_feats, cls_feats = [], []
        for i, xi in enumerate(x):
            pose, act = self._split_pose_act(xi, i)
            gate = torch.sigmoid(self.pose_gates[i](act))
            pose_g = pose * (1.0 + self.gate_alpha[i] * gate)
            # Classification bypass from activation channels.
            act_skip = self.cls_bypass[i](act)
            cls_in = pose_g + self.cls_beta[i] * act_skip
            box_feats.append(pose_g)
            cls_feats.append(cls_in)
        return box_feats, cls_feats

    def forward_head(
        self, x: list[torch.Tensor], box_head: torch.nn.Module = None, cls_head: torch.nn.Module = None
    ) -> dict[str, torch.Tensor]:
        if box_head is None or cls_head is None:
            return dict()
        box_feats, cls_feats = self._fuse_pose(x)
        bs = x[0].shape[0]
        box_list = [box_head[i](box_feats[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)]
        cls_list = [cls_head[i](cls_feats[i]).view(bs, self.nc, -1) for i in range(self.nl)]
        boxes = torch.cat(box_list, dim=-1)
        scores = torch.cat(cls_list, dim=-1)
        return dict(boxes=boxes, scores=scores, feats=x)


class CapsuleDetectv4(Detect):
    """Capsule Detect v4: box uses raw pose, cls uses act bypass + symbolic type prior."""

    def __init__(
        self,
        nc: int = 80,
        *args,
        reg_max: int = 16,
        end2end: bool = False,
        k: list[int] | tuple[int, ...] = (4, 8, 16),
        d: list[int] | tuple[int, ...] = (16, 16, 16),
        ch: tuple = (),
    ):
        parsed = list(args)
        if parsed and isinstance(parsed[-1], (list, tuple)):
            ch = tuple(parsed.pop(-1))
        if len(parsed) not in (2, 4):
            raise ValueError("CapsuleDetectv4 expects [k_list, d_list, reg_max, end2end, ch].")
        k, d = parsed[0], parsed[1]
        if len(parsed) == 4:
            reg_max = int(parsed[2])
            end2end = bool(parsed[3])
        if not isinstance(k, (list, tuple)) or not isinstance(d, (list, tuple)):
            raise TypeError("CapsuleDetectv4 requires list/tuple k and d (per-level settings).")
        ch = tuple(int(c) for c in ch)
        nl = len(ch)
        if len(k) != nl or len(d) != nl:
            raise ValueError(f"CapsuleDetectv4 k/d length must equal number of levels ({nl}).")
        self.k_list = tuple(int(v) for v in k)
        self.d_list = tuple(int(v) for v in d)
        for i, c in enumerate(ch):
            expected = self.k_list[i] * (self.d_list[i] + 1)
            if c != expected:
                raise ValueError(
                    f"CapsuleDetectv4 level-{i} channel mismatch: got {c}, "
                    f"expected {expected} from k={self.k_list[i]}, d={self.d_list[i]}."
                )
        merged_ch = tuple(k_i * d_i for k_i, d_i in zip(self.k_list, self.d_list))
        super().__init__(nc=nc, reg_max=reg_max, end2end=end2end, ch=merged_ch)
        self.cls_bypass = nn.ModuleList()
        self.cls_beta = nn.ParameterList()
        self.sym_prior = nn.ModuleList()
        self.sym_beta = nn.ParameterList()
        for k_i, d_i in zip(self.k_list, self.d_list):
            pose_ch = k_i * d_i
            cls_hidden = max(16, k_i * 2)
            self.cls_bypass.append(
                nn.Sequential(
                    nn.Conv2d(k_i, cls_hidden, 1, bias=True),
                    nn.SiLU(inplace=True),
                    nn.Conv2d(cls_hidden, pose_ch, 1, bias=True),
                )
            )
            self.cls_beta.append(nn.Parameter(torch.tensor(0.1)))
            self.sym_prior.append(nn.Conv2d(k_i, self.nc, 1, bias=False))
            self.sym_beta.append(nn.Parameter(torch.tensor(0.1)))

    def _split_pose_act(self, x: torch.Tensor, i: int) -> tuple[torch.Tensor, torch.Tensor]:
        k_i = self.k_list[i]
        d_i = self.d_list[i]
        b, c, h, w = x.shape
        expected = k_i * (d_i + 1)
        if c != expected:
            raise ValueError(f"CapsuleDetectv4 level-{i} channel mismatch: got {c}, expected {expected}.")
        caps = x.view(b, k_i, d_i + 1, h, w)
        pose = caps[:, :, :d_i].reshape(b, k_i * d_i, h, w).contiguous()
        act = caps[:, :, d_i].contiguous()
        return pose, act

    def _build_feats(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        box_feats, cls_feats, cls_priors = [], [], []
        for i, xi in enumerate(x):
            pose, act = self._split_pose_act(xi, i)
            cls_in = pose + self.cls_beta[i] * self.cls_bypass[i](act)
            cls_prior = self.sym_beta[i] * self.sym_prior[i](act)
            box_feats.append(pose)
            cls_feats.append(cls_in)
            cls_priors.append(cls_prior)
        return box_feats, cls_feats, cls_priors

    def forward_head(
        self, x: list[torch.Tensor], box_head: torch.nn.Module = None, cls_head: torch.nn.Module = None
    ) -> dict[str, torch.Tensor]:
        if box_head is None or cls_head is None:
            return dict()
        box_feats, cls_feats, cls_priors = self._build_feats(x)
        bs = x[0].shape[0]
        box_list = [box_head[i](box_feats[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)]
        cls_list = [
            (cls_head[i](cls_feats[i]) + cls_priors[i]).view(bs, self.nc, -1) for i in range(self.nl)
        ]
        boxes = torch.cat(box_list, dim=-1)
        scores = torch.cat(cls_list, dim=-1)
        return dict(boxes=boxes, scores=scores, feats=x)


def _setup_capsule_layout(
    k: list[int] | tuple[int, ...],
    d: list[int] | tuple[int, ...],
    ch: tuple,
    cls_name: str,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
    if not isinstance(k, (list, tuple)) or not isinstance(d, (list, tuple)):
        raise TypeError(f"{cls_name} requires list/tuple k and d (per-level settings).")
    ch = tuple(int(c) for c in ch)
    nl = len(ch)
    if len(k) != nl or len(d) != nl:
        raise ValueError(f"{cls_name} k/d length must equal number of levels ({nl}).")
    k_list = tuple(int(v) for v in k)
    d_list = tuple(int(v) for v in d)
    for i, c in enumerate(ch):
        expected = k_list[i] * (d_list[i] + 1)
        if c != expected:
            raise ValueError(
                f"{cls_name} level-{i} channel mismatch: got {c}, "
                f"expected {expected} from k={k_list[i]}, d={d_list[i]}."
            )
    merged_ch = tuple(k_i * d_i for k_i, d_i in zip(k_list, d_list))
    return k_list, d_list, merged_ch
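

# _setup_capsule_layout validates the packed-channel contract shared by the v5+
# heads. The numbers below are illustrative: three levels of k=(4, 8, 16)
# capsule types with d=16-dim poses require ch = k*(d+1) channels per level.
def _example_setup_capsule_layout() -> None:
    k, d = (4, 8, 16), (16, 16, 16)
    ch = (4 * 17, 8 * 17, 16 * 17)  # (68, 136, 272) packed channels
    k_list, d_list, merged_ch = _setup_capsule_layout(k, d, ch, "Example")
    assert merged_ch == (64, 128, 256)  # pose-only channels fed to Detect heads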


def _init_capsule_semantic_heads(obj: nn.Module) -> None:
    obj.cls_bypass = nn.ModuleList()
    obj.cls_beta = nn.ParameterList()
    obj.sym_prior = nn.ModuleList()
    obj.sym_norm = nn.ModuleList()
    obj.sym_dropout = nn.ModuleList()
    obj.sym_beta = nn.ParameterList()
    for k_i, d_i in zip(obj.k_list, obj.d_list):
        pose_ch = k_i * d_i
        cls_hidden = max(16, k_i * 2)
        obj.cls_bypass.append(
            nn.Sequential(
                nn.Conv2d(k_i, cls_hidden, 1, bias=True),
                nn.SiLU(inplace=True),
                nn.Conv2d(cls_hidden, pose_ch, 1, bias=True),
            )
        )
        obj.cls_beta.append(nn.Parameter(torch.tensor(0.1)))
        obj.sym_dropout.append(nn.Dropout2d(p=0.1))
        obj.sym_prior.append(nn.Conv2d(k_i, obj.nc, 1, bias=False))
        obj.sym_norm.append(nn.GroupNorm(1, obj.nc))
        obj.sym_beta.append(nn.Parameter(torch.tensor(0.1)))


def _capsule_split_pose_act(
    x: torch.Tensor,
    k_i: int,
    d_i: int,
    cls_name: str,
    level_i: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    b, c, h, w = x.shape
    expected = k_i * (d_i + 1)
    if c != expected:
        raise ValueError(f"{cls_name} level-{level_i} channel mismatch: got {c}, expected {expected}.")
    caps = x.view(b, k_i, d_i + 1, h, w)
    pose = caps[:, :, :d_i].reshape(b, k_i * d_i, h, w).contiguous()
    act = caps[:, :, d_i].contiguous()
    return pose, act


def _capsule_build_feats(obj: nn.Module, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
    box_feats, cls_feats, cls_priors = [], [], []
    cls_name = obj.__class__.__name__
    for i, xi in enumerate(x):
        pose, act = _capsule_split_pose_act(xi, obj.k_list[i], obj.d_list[i], cls_name, i)
        cls_scale = torch.tanh(obj.cls_beta[i])
        cls_in = pose + cls_scale * obj.cls_bypass[i](act)
        act_s = obj.sym_dropout[i](act)
        prior = obj.sym_prior[i](act_s)
        prior = obj.sym_norm[i](prior)
        prior = prior - prior.mean(dim=1, keepdim=True)
        sym_scale = torch.tanh(obj.sym_beta[i])
        cls_prior = sym_scale * prior
        box_feats.append(pose)
        cls_feats.append(cls_in)
        cls_priors.append(cls_prior)
    return box_feats, cls_feats, cls_priors


def _capsule_build_feats_gated(
    obj: nn.Module, x: list[torch.Tensor]
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
    box_feats, cls_feats, cls_priors = [], [], []
    cls_name = obj.__class__.__name__
    for i, xi in enumerate(x):
        pose, act = _capsule_split_pose_act(xi, obj.k_list[i], obj.d_list[i], cls_name, i)
        cls_scale = torch.tanh(obj.cls_beta[i])
        gate = torch.sigmoid(obj.cls_bypass[i](act))
        cls_in = pose * (1.0 + cls_scale * gate)
        act_s = obj.sym_dropout[i](act)
        prior = obj.sym_prior[i](act_s)
        prior = obj.sym_norm[i](prior)
        prior = prior - prior.mean(dim=1, keepdim=True)
        sym_scale = torch.tanh(obj.sym_beta[i])
        cls_prior = sym_scale * prior
        box_feats.append(pose)
        cls_feats.append(cls_in)
        cls_priors.append(cls_prior)
    return box_feats, cls_feats, cls_priors


def _capsule_build_feats_boxcls(
    obj: nn.Module, x: list[torch.Tensor]
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
    box_feats, cls_feats, cls_priors = [], [], []
    cls_name = obj.__class__.__name__
    for i, xi in enumerate(x):
        pose, act = _capsule_split_pose_act(xi, obj.k_list[i], obj.d_list[i], cls_name, i)
        act_s = obj.sym_dropout[i](act)
        prior = obj.sym_prior[i](act_s)
        prior = obj.sym_norm[i](prior)
        prior = prior - prior.mean(dim=1, keepdim=True)
        sym_scale = torch.tanh(obj.sym_beta[i])
        cls_prior = sym_scale * prior
        box_feats.append(pose)
        cls_feats.append(pose)
        cls_priors.append(cls_prior)
    return box_feats, cls_feats, cls_priors


def _capsule_build_feats_boxcls_simpleprior(
    obj: nn.Module, x: list[torch.Tensor]
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
    box_feats, cls_feats, cls_priors = [], [], []
    cls_name = obj.__class__.__name__
    for i, xi in enumerate(x):
        pose, act = _capsule_split_pose_act(xi, obj.k_list[i], obj.d_list[i], cls_name, i)
        act_s = obj.sym_dropout[i](act)
        prior = obj.sym_prior[i](act_s)
        sym_scale = torch.tanh(obj.sym_beta[i])
        cls_prior = sym_scale * prior
        box_feats.append(pose)
        cls_feats.append(pose)
        cls_priors.append(cls_prior)
    return box_feats, cls_feats, cls_priors


def _capsule_build_feats_open_vocab(
    obj: nn.Module, x: list[torch.Tensor]
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
    box_feats, cls_feats, acts = [], [], []
    cls_name = obj.__class__.__name__
    for i, xi in enumerate(x):
        pose, act = _capsule_split_pose_act(xi, obj.k_list[i], obj.d_list[i], cls_name, i)
        cls_in = pose
        if getattr(obj, "with_act_gate", False):
            cls_scale = torch.tanh(obj.ov_beta[i])
            gate = torch.sigmoid(obj.ov_gate[i](act))
            cls_in = pose * (1.0 + cls_scale * gate)
        box_feats.append(pose)
        cls_feats.append(cls_in)
        acts.append(act)
    return box_feats, cls_feats, acts


class CapsuleDetectv5(Detect):
    """Capsule Detect v5: box uses raw pose, cls uses stabilized symbolic prior."""

    def __init__(
        self,
        nc: int = 80,
        *args,
        reg_max: int = 16,
        end2end: bool = False,
        k: list[int] | tuple[int, ...] = (4, 8, 16),
        d: list[int] | tuple[int, ...] = (16, 16, 16),
        ch: tuple = (),
    ):
        parsed = list(args)
        if parsed and isinstance(parsed[-1], (list, tuple)):
            ch = tuple(parsed.pop(-1))
        if len(parsed) not in (2, 4):
            raise ValueError("CapsuleDetectv5 expects [k_list, d_list, reg_max, end2end, ch].")
        k, d = parsed[0], parsed[1]
        if len(parsed) == 4:
            reg_max = int(parsed[2])
            end2end = bool(parsed[3])
        self.k_list, self.d_list, merged_ch = _setup_capsule_layout(k, d, ch, "CapsuleDetectv5")
        super().__init__(nc=nc, reg_max=reg_max, end2end=end2end, ch=merged_ch)
        _init_capsule_semantic_heads(self)

    def _split_pose_act(self, x: torch.Tensor, i: int) -> tuple[torch.Tensor, torch.Tensor]:
        return _capsule_split_pose_act(x, self.k_list[i], self.d_list[i], "CapsuleDetectv5", i)

    def _build_feats(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        return _capsule_build_feats(self, x)

    def forward_head(
        self, x: list[torch.Tensor], box_head: torch.nn.Module = None, cls_head: torch.nn.Module = None
    ) -> dict[str, torch.Tensor]:
        if box_head is None or cls_head is None:
            return dict()
        box_feats, cls_feats, cls_priors = self._build_feats(x)
        bs = x[0].shape[0]
        box_list = [box_head[i](box_feats[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)]
        cls_list = [
            (cls_head[i](cls_feats[i]) + cls_priors[i]).view(bs, self.nc, -1) for i in range(self.nl)
        ]
        boxes = torch.cat(box_list, dim=-1)
        scores = torch.cat(cls_list, dim=-1)
        return dict(boxes=boxes, scores=scores, feats=x)


class CapsuleDetectv6(CapsuleDetectv5):
    """Capsule Detect v6: replace additive cls correction with multiplicative act gate."""

    def _build_feats(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        return _capsule_build_feats_gated(self, x)


class CapsuleDetectv7(CapsuleDetectv5):
    """Capsule Detect v7: cls head consumes raw pose features plus symbolic priors only."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.profile_head = False
        self._head_profile: dict[str, float] = {}
        self._head_profile_calls = 0

    def _ensure_profile_attrs(self) -> None:
        if not hasattr(self, "profile_head"):
            self.profile_head = False
        if not hasattr(self, "_head_profile"):
            self._head_profile = {}
        if not hasattr(self, "_head_profile_calls"):
            self._head_profile_calls = 0

    def reset_head_profile(self) -> None:
        self._ensure_profile_attrs()
        self._head_profile = {
            "split_pose_act_ms": 0.0,
            "cls_prior_ms": 0.0,
            "box_head_ms": 0.0,
            "cls_head_ms": 0.0,
            "cat_ms": 0.0,
        }
        self._head_profile_calls = 0

    def get_head_profile(self) -> dict[str, float]:
        self._ensure_profile_attrs()
        if not self._head_profile:
            return {}
        out = dict(self._head_profile)
        calls = max(self._head_profile_calls, 1)
        out["calls"] = float(self._head_profile_calls)
        out["total_ms"] = sum(v for k, v in out.items() if k.endswith("_ms"))
        for key in list(self._head_profile):
            out[key.replace("_ms", "_avg_ms")] = self._head_profile[key] / calls
        return out

    def _sync_profile(self) -> None:
        self._ensure_profile_attrs()
        if self.profile_head and torch.cuda.is_available():
            torch.cuda.synchronize()

    def _build_feats(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        self._ensure_profile_attrs()
        if not self.profile_head:
            return _capsule_build_feats_boxcls(self, x)
        if not self._head_profile:
            self.reset_head_profile()
        box_feats, cls_feats, cls_priors = [], [], []
        cls_name = self.__class__.__name__
        for i, xi in enumerate(x):
            self._sync_profile()
            t0 = time.perf_counter()
            pose, act = _capsule_split_pose_act(xi, self.k_list[i], self.d_list[i], cls_name, i)
            self._sync_profile()
            self._head_profile["split_pose_act_ms"] += (time.perf_counter() - t0) * 1000.0
            self._sync_profile()
            t0 = time.perf_counter()
            act_s = self.sym_dropout[i](act)
            prior = self.sym_prior[i](act_s)
            prior = self.sym_norm[i](prior)
            prior = prior - prior.mean(dim=1, keepdim=True)
            sym_scale = torch.tanh(self.sym_beta[i])
            cls_prior = sym_scale * prior
            self._sync_profile()
            self._head_profile["cls_prior_ms"] += (time.perf_counter() - t0) * 1000.0
            box_feats.append(pose)
            cls_feats.append(pose)
            cls_priors.append(cls_prior)
        return box_feats, cls_feats, cls_priors

    def forward_head(
        self, x: list[torch.Tensor], box_head: torch.nn.Module = None, cls_head: torch.nn.Module = None
    ) -> dict[str, torch.Tensor]:
        self._ensure_profile_attrs()
        if box_head is None or cls_head is None:
            return dict()
        box_feats, cls_feats, cls_priors = self._build_feats(x)
        bs = x[0].shape[0]
        if not self.profile_head:
            box_list = [box_head[i](box_feats[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)]
            cls_list = [
                (cls_head[i](cls_feats[i]) + cls_priors[i]).view(bs, self.nc, -1) for i in range(self.nl)
            ]
            boxes = torch.cat(box_list, dim=-1)
            scores = torch.cat(cls_list, dim=-1)
            return dict(boxes=boxes, scores=scores, feats=x)
        if not self._head_profile:
            self.reset_head_profile()
        self._head_profile_calls += 1
        box_list, cls_list = [], []
        for i in range(self.nl):
            self._sync_profile()
            t0 = time.perf_counter()
            box_i = box_head[i](box_feats[i]).view(bs, 4 * self.reg_max, -1)
            self._sync_profile()
            self._head_profile["box_head_ms"] += (time.perf_counter() - t0) * 1000.0
            box_list.append(box_i)
            self._sync_profile()
            t0 = time.perf_counter()
            cls_i = (cls_head[i](cls_feats[i]) + cls_priors[i]).view(bs, self.nc, -1)
            self._sync_profile()
            self._head_profile["cls_head_ms"] += (time.perf_counter() - t0) * 1000.0
            cls_list.append(cls_i)
        self._sync_profile()
        t0 = time.perf_counter()
        boxes = torch.cat(box_list, dim=-1)
        scores = torch.cat(cls_list, dim=-1)
        self._sync_profile()
        self._head_profile["cat_ms"] += (time.perf_counter() - t0) * 1000.0
        return dict(boxes=boxes, scores=scores, feats=x)
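

# Hedged usage sketch for the v7 profiling hooks: toggle profile_head, run a few
# forward passes, then read accumulated/average timings. `head` stands in for an
# already-built CapsuleDetectv7 inside a model (it is not constructed here), and
# cv2/cv3 are assumed to be the standard Detect box/cls branch ModuleLists.
def _example_head_profiling(head: "CapsuleDetectv7", feats: list[torch.Tensor]) -> dict[str, float]:
    head.profile_head = True
    head.reset_head_profile()
    with torch.no_grad():
        for _ in range(3):
            head.forward_head(feats, box_head=head.cv2, cls_head=head.cv3)
    head.profile_head = False
    # e.g. {"box_head_ms": ..., "box_head_avg_ms": ..., "calls": 3.0, "total_ms": ...}
    return head.get_head_profile()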


class CapsuleDetectv8(CapsuleDetectv5):
    """Capsule Detect v8: raw pose cls path with simplified cls_prior (no norm/centering)."""

    def _build_feats(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        return _capsule_build_feats_boxcls_simpleprior(self, x)


class CapsuleOpenVocabDetect(Detect):
    """Capsule detection head with open-vocabulary classification via text embedding matching."""

    def __init__(
        self,
        nc: int = 80,
        *args,
        reg_max: int = 16,
        end2end: bool = False,
        embed: int = 256,
        with_act_gate: bool = False,
        with_objectness_prior: bool = True,
        k: list[int] | tuple[int, ...] = (4, 8, 16),
        d: list[int] | tuple[int, ...] = (16, 16, 16),
        ch: tuple = (),
    ):
        parsed = list(args)
        if parsed and isinstance(parsed[-1], (list, tuple)):
            ch = tuple(parsed.pop(-1))
        if len(parsed) not in (2, 4, 7):
            raise ValueError(
                "CapsuleOpenVocabDetect expects [k_list, d_list, (reg_max, end2end, embed, with_act_gate, "
                "with_objectness_prior), ch]."
            )
        k, d = parsed[0], parsed[1]
        if len(parsed) == 4:
            reg_max = int(parsed[2])
            end2end = bool(parsed[3])
        elif len(parsed) == 7:
            # Support both the direct argument order:
            #   [k_list, d_list, reg_max, end2end, embed, with_act_gate, with_objectness_prior]
            # and the parser-appended order:
            #   [k_list, d_list, embed, with_act_gate, with_objectness_prior, reg_max, end2end]
            if type(parsed[3]) is bool and type(parsed[4]) is bool and type(parsed[6]) is bool:
                embed = int(parsed[2])
                with_act_gate = bool(parsed[3])
                with_objectness_prior = bool(parsed[4])
                reg_max = int(parsed[5])
                end2end = bool(parsed[6])
            else:
                reg_max = int(parsed[2])
                end2end = bool(parsed[3])
                embed = int(parsed[4])
                with_act_gate = bool(parsed[5])
                with_objectness_prior = bool(parsed[6])
        self.k_list, self.d_list, merged_ch = _setup_capsule_layout(k, d, ch, "CapsuleOpenVocabDetect")
        super().__init__(nc=nc, reg_max=reg_max, end2end=end2end, ch=merged_ch)
        self.embed = int(embed)
        self.with_act_gate = bool(with_act_gate)
        self.with_objectness_prior = bool(with_objectness_prior)
        self.emb_head = nn.ModuleList()
        self.ov_gate = nn.ModuleList()
        self.ov_beta = nn.ParameterList()
        self.obj_prior = nn.ModuleList()
        for k_i, d_i in zip(self.k_list, self.d_list):
            pose_ch = k_i * d_i
            self.emb_head.append(
                nn.Sequential(
                    Conv(pose_ch, pose_ch, 3),
                    DWConv(pose_ch, pose_ch, 3),
                    nn.Conv2d(pose_ch, self.embed, 1, bias=True),
                )
            )
            if self.with_act_gate:
                hidden = max(16, k_i * 2)
                self.ov_gate.append(
                    nn.Sequential(
                        nn.Conv2d(k_i, hidden, 1, bias=True),
                        nn.SiLU(inplace=True),
                        nn.Conv2d(hidden, pose_ch, 1, bias=True),
                    )
                )
                self.ov_beta.append(nn.Parameter(torch.tensor(0.1)))
            else:
                self.ov_gate.append(nn.Identity())
                self.ov_beta.append(nn.Parameter(torch.tensor(0.0), requires_grad=False))
            if self.with_objectness_prior:
                self.obj_prior.append(nn.Conv2d(k_i, 1, 1, bias=True))
            else:
                self.obj_prior.append(nn.Identity())
        self.logit_scale = nn.Parameter(torch.tensor(math.log(1 / 0.07), dtype=torch.float32))
        self.register_buffer("cached_text_embeddings", torch.empty(0), persistent=False)

    def set_text_embeddings(self, text_embs: torch.Tensor | None) -> None:
        """Cache normalized text embeddings for inference."""
        if text_embs is None:
            self.cached_text_embeddings = torch.empty(0, device=self.logit_scale.device)
            return
        if text_embs.ndim != 2:
            raise ValueError(f"text_embs must be 2D [num_classes, embed_dim], got shape {tuple(text_embs.shape)}.")
        self.cached_text_embeddings = F.normalize(text_embs.detach().to(self.logit_scale.device), dim=-1)

    def _split_pose_act(self, x: torch.Tensor, i: int) -> tuple[torch.Tensor, torch.Tensor]:
        return _capsule_split_pose_act(x, self.k_list[i], self.d_list[i], "CapsuleOpenVocabDetect", i)

    def _build_feats(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        return _capsule_build_feats_open_vocab(self, x)

    def _prepare_text_embeddings(self, text_embs: torch.Tensor | None, bs: int, device: torch.device) -> torch.Tensor | None:
        if text_embs is None:
            if self.cached_text_embeddings.numel() == 0:
                return None
            text = self.cached_text_embeddings
        else:
            text = text_embs
        if text.ndim == 2:
            text = text.unsqueeze(0).expand(bs, -1, -1)
        elif text.ndim != 3:
            raise ValueError(f"text_embs must be 2D or 3D, got shape {tuple(text.shape)}.")
        if text.shape[-1] != self.embed:
            raise ValueError(f"text_embs last dim must equal embed={self.embed}, got {text.shape[-1]}.")
        return F.normalize(text.to(device=device, dtype=self.logit_scale.dtype), dim=-1)

    def _compute_ov_scores(
        self, cls_feats: list[torch.Tensor], acts: list[torch.Tensor], text_embs: torch.Tensor | None
    ) -> tuple[torch.Tensor | None, list[torch.Tensor], torch.Tensor | None]:
        bs = cls_feats[0].shape[0]
        level_embeddings = []
        for i in range(self.nl):
            emb = self.emb_head[i](cls_feats[i])
            if self.with_objectness_prior:
                emb = emb * (1.0 + torch.sigmoid(self.obj_prior[i](acts[i])))
            level_embeddings.append(emb)
        text = self._prepare_text_embeddings(text_embs, bs, cls_feats[0].device)
        if text is None:
            return None, level_embeddings, None
        visual_tokens = torch.cat(
            [F.normalize(emb.flatten(2).transpose(1, 2), dim=-1) for emb in level_embeddings],
            dim=1,
        )
        scale = self.logit_scale.exp().clamp(max=100.0)
        scores = torch.einsum("bnd,bcd->bcn", visual_tokens, text) * scale
        return scores, level_embeddings, text

    def forward_head(
        self,
        x: list[torch.Tensor],
        text_embs: torch.Tensor | None = None,
        box_head: torch.nn.Module = None,
        cls_head: torch.nn.Module = None,
    ) -> dict[str, torch.Tensor]:
        del cls_head  # the fixed-class cls head is unused in open-vocabulary mode
        if box_head is None:
            return dict()
        box_feats, cls_feats, acts = self._build_feats(x)
        bs = x[0].shape[0]
        boxes = torch.cat([box_head[i](box_feats[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)], dim=-1)
        scores, level_embeddings, text = self._compute_ov_scores(cls_feats, acts, text_embs)
        preds = {
            "boxes": boxes,
            "embeddings": level_embeddings,
            "cls_feats": cls_feats,
            "acts": acts,
            "feats": x,
        }
        if scores is not None:
            preds["scores"] = scores
            preds["text_embeddings"] = text
        return preds

    def forward(
        self, x: list[torch.Tensor], text_embs: torch.Tensor | None = None
    ) -> dict[str, torch.Tensor] | torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
        preds = self.forward_head(x, text_embs=text_embs, **self.one2many)
        if self.end2end:
            x_detach = [xi.detach() for xi in x]
            one2one = self.forward_head(x_detach, text_embs=text_embs, **self.one2one)
            preds = {"one2many": preds, "one2one": one2one}
        if self.training:
            return preds
        infer_preds = preds["one2one"] if self.end2end else preds
        if "scores" not in infer_preds:
            raise ValueError("CapsuleOpenVocabDetect inference requires text_embs or cached text embeddings.")
        original_nc = self.nc
        self.nc = int(infer_preds["scores"].shape[1])
        try:
            y = self._inference(infer_preds)
            if self.end2end:
                y = self.postprocess(y.permute(0, 2, 1))
        finally:
            self.nc = original_nc
        return y if self.export else (y, preds)
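

# Hedged sketch of the open-vocabulary text-embedding contract: cache
# L2-normalized class embeddings once, then run inference without passing
# text_embs on every call. `ov_head` stands in for an already-built
# CapsuleOpenVocabDetect; the embeddings are random placeholders rather than
# output from a real text encoder.
def _example_cache_text_embeddings(ov_head: "CapsuleOpenVocabDetect") -> None:
    num_classes = 12
    text = torch.randn(num_classes, ov_head.embed)  # [num_classes, embed_dim]
    ov_head.set_text_embeddings(text)               # stored normalized, non-persistent buffer
    assert ov_head.cached_text_embeddings.shape == (num_classes, ov_head.embed)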


class CapsuleSegmentv1(Segment):
    """Capsule-style Segment head aligned with CapsuleDetectv6 semantics."""

    def __init__(
        self,
        nc: int = 80,
        *args,
        nm: int = 32,
        npr: int = 256,
        reg_max: int = 16,
        end2end: bool = False,
        k: list[int] | tuple[int, ...] = (4, 8, 16),
        d: list[int] | tuple[int, ...] = (16, 16, 16),
        ch: tuple = (),
    ):
        parsed = list(args)
        if parsed and isinstance(parsed[-1], (list, tuple)):
            ch = tuple(parsed.pop(-1))
        if len(parsed) not in (2, 4, 6):
            raise ValueError("CapsuleSegmentv1 expects [k_list, d_list, (nm, npr), reg_max, end2end, ch].")
        k, d = parsed[0], parsed[1]
        if len(parsed) == 4:
            if isinstance(parsed[3], bool):
                reg_max = int(parsed[2])
                end2end = bool(parsed[3])
            else:
                nm = int(parsed[2])
                npr = int(parsed[3])
        elif len(parsed) == 6:
            nm = int(parsed[2])
            npr = int(parsed[3])
            reg_max = int(parsed[4])
            end2end = bool(parsed[5])
        self.k_list, self.d_list, merged_ch = _setup_capsule_layout(k, d, ch, "CapsuleSegmentv1")
        super().__init__(nc=nc, nm=nm, npr=npr, reg_max=reg_max, end2end=end2end, ch=merged_ch)
        _init_capsule_semantic_heads(self)
        self.proto = Proto26(merged_ch, self.npr, self.nm, nc)

    def _split_pose_act(self, x: torch.Tensor, i: int) -> tuple[torch.Tensor, torch.Tensor]:
        return _capsule_split_pose_act(x, self.k_list[i], self.d_list[i], "CapsuleSegmentv1", i)

    def _build_feats(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        return _capsule_build_feats_gated(self, x)

    def forward_head(
        self,
        x: list[torch.Tensor],
        box_head: torch.nn.Module = None,
        cls_head: torch.nn.Module = None,
        mask_head: torch.nn.Module = None,
    ) -> dict[str, torch.Tensor]:
        if box_head is None or cls_head is None:
            return dict()
        box_feats, cls_feats, cls_priors = self._build_feats(x)
        bs = x[0].shape[0]
        boxes = torch.cat([box_head[i](box_feats[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)], dim=-1)
        scores = torch.cat(
            [(cls_head[i](cls_feats[i]) + cls_priors[i]).view(bs, self.nc, -1) for i in range(self.nl)],
            dim=-1,
        )
        preds = dict(boxes=boxes, scores=scores, feats=cls_feats)
        if mask_head is not None:
            preds["mask_coefficient"] = torch.cat(
                [mask_head[i](cls_feats[i]).view(bs, self.nm, -1) for i in range(self.nl)],
                dim=-1,
            )
        return preds

    def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor] | dict[str, torch.Tensor]:
        _, cls_feats, _ = self._build_feats(x)
        outputs = Detect.forward(self, x)
        preds = outputs[1] if isinstance(outputs, tuple) else outputs
        proto_in = cls_feats
        proto = self.proto(proto_in)  # multi-level Proto26 over merged capsule features
        if isinstance(preds, dict):
            if self.end2end:
                preds["one2many"]["proto"] = proto
                preds["one2one"]["proto"] = (
                    tuple(p.detach() for p in proto) if isinstance(proto, tuple) else proto.detach()
                )
            else:
                preds["proto"] = proto
        if self.training:
            return preds
        return (outputs, proto) if self.export else ((outputs[0], proto), preds)


class CapsuleSegmentv2(CapsuleSegmentv1):
    """Capsule Segment v2: cls head consumes raw pose features and symbolic priors only."""

    def _build_feats(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        return _capsule_build_feats_boxcls(self, x)


class CapsuleSegmentv3(CapsuleSegmentv1):
    """Capsule Segment v3: raw pose cls path with simplified cls_prior (no norm/centering)."""

    def _build_feats(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        return _capsule_build_feats_boxcls_simpleprior(self, x)