| from __future__ import annotations |
|
|
| import copy |
| import math |
| import time |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from ultralytics.nn.modules import Conv, DWConv, Detect, Segment |
| from ultralytics.nn.modules.block import Proto26 |
|
|
|
|
class PrimaryCaps(nn.Module):
    r"""Primary convolutional capsules.

    Projects an ``A``-channel feature map into ``B`` capsule types, each
    carrying a ``P*P`` pose vector plus one sigmoid activation, and also
    returns both packed together as an NHWC tensor.

    Args:
        A: Input feature channels.
        B: Number of capsule types.
        K: Convolution kernel size.
        P: Pose matrix side length (pose size is ``P*P``).
        stride: Convolution stride.

    Input shape:
        x: ``(N, A, H, W)``

    Output shape:
        a: ``(N, B, H_out, W_out)``
        p: ``(N, B*P*P, H_out, W_out)``
        out: ``(N, H_out, W_out, B*(P*P+1))``

    Parameter size:
        pose conv + act conv
        ``(K*K*A*B*P*P + B*P*P) + (K*K*A*B + B)``
    """

    def __init__(self, A: int = 32, B: int = 32, K: int = 1, P: int = 4, stride: int = 1):
        super().__init__()
        self.B = B
        self.P = P
        self.psize = P * P

        # One convolution emits every pose component, a second one the activations.
        self.pose = nn.Conv2d(in_channels=A, out_channels=B * self.psize, kernel_size=K, stride=stride, bias=True)
        self.a = nn.Conv2d(in_channels=A, out_channels=B, kernel_size=K, stride=stride, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (activations, poses, NHWC concat of poses-then-activations)."""
        pose_map = self.pose(x)
        act_map = self.sigmoid(self.a(x))
        # Pack [pose | act] along channels, then move channels last for the routing layers.
        packed = torch.cat([pose_map, act_map], dim=1)
        packed = packed.permute(0, 2, 3, 1).contiguous()
        return act_map, pose_map, packed
|
|
|
|
class ConvCaps(nn.Module):
    r"""Convolutional capsules with EM routing.

    Args:
        B: Input capsule types.
        C: Output capsule types.
        K: Patch kernel size.
        P: Pose matrix side length (pose size is ``P*P``).
        stride: Spatial stride for patch extraction.
        iters: Number of EM routing iterations.
        coor_add: Add coordinate offsets (class-caps style option;
            only applied in the ``w_shared=True`` branch of ``forward``).
        w_shared: Share transform matrices across spatial positions.

    Input shape:
        x: ``(N, H, W, B*(P*P+1))``

    Output shape:
        p_out: ``(N, H_out, W_out, C*P*P)``
        a_out: ``(N, H_out, W_out, C)``
        out: ``(N, H_out, W_out, C*(P*P+1))``

    Parameter size:
        If ``w_shared=False``:
            ``weights: (K*K*B*C*P*P*P*P)``, ``beta_u: C``, ``beta_a: C``

        If ``w_shared=True``:
            ``weights: (B*C*P*P*P*P)``, ``beta_u: C``, ``beta_a: C``

        Total = ``weights + 2*C`` (excluding non-trainable buffers).
    """

    def __init__(
        self,
        B: int = 32,
        C: int = 32,
        K: int = 3,
        P: int = 4,
        stride: int = 1,
        iters: int = 3,
        coor_add: bool = False,
        w_shared: bool = False,
    ):
        super().__init__()
        self.B = B
        self.C = C
        self.K = K
        self.P = P
        self.psize = P * P
        self.stride = stride
        self.iters = iters
        self.coor_add = coor_add
        self.w_shared = w_shared

        # Numerical floor and fixed EM inverse temperature (no annealing schedule here).
        self.eps = 1e-6
        self._lambda = 1e-3
        # log(2*pi) constant; non-persistent buffer so it follows .to(device/dtype) moves.
        self.register_buffer("ln_2pi", torch.tensor(math.log(2 * math.pi), dtype=torch.float32), persistent=False)

        # Learned per-output-type cost offsets used by the M-step.
        self.beta_u = nn.Parameter(torch.zeros(C))
        self.beta_a = nn.Parameter(torch.zeros(C))

        # One (P*P x P*P) pose transform per (input slot, output type). Shared
        # mode stores only B input slots and tiles them at runtime.
        weight_in = B if w_shared else (K * K * B)
        self.weights = nn.Parameter(torch.randn(1, weight_in, C, self.psize, self.psize) * 0.02)

        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=2)  # normalizes routing logits over the C output types

    def m_step(
        self,
        a_in: torch.Tensor,
        r: torch.Tensor,
        v: torch.Tensor,
        eps: float,
        b: int,
        B: int,
        C: int,
        psize: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """EM M-step: fit a Gaussian per output capsule from the weighted votes.

        Args:
            a_in: Input activations, ``(b, in_votes, 1)`` or already ``(b, in_votes, 1, 1)``.
            r: Routing assignments, ``(b, in_votes, C, 1)``.
            v: Votes, ``(b, in_votes, C, psize)``.
            eps: Stability constant for divisions/logs.
            b: Flattened batch size (batch * output positions in conv mode).
            B: Unused here beyond interface parity with ``e_step`` callers.
            C: Output capsule types.
            psize: Pose vector length ``P*P``.

        Returns:
            a_out ``(b, C)``, mu ``(b, 1, C, psize)``, sigma_sq ``(b, 1, C, psize)``.
        """
        if a_in.ndim == 3:
            a_in = a_in.unsqueeze(2)
        # Weight assignments by input activation, then renormalize over output types.
        r = r * a_in
        r = r / (r.sum(dim=2, keepdim=True) + eps)
        r_sum = r.sum(dim=1, keepdim=True)
        coeff = r / (r_sum + eps)

        # Weighted mean and variance of the votes per output capsule.
        mu = torch.sum(coeff * v, dim=1, keepdim=True)
        sigma_sq = torch.sum(coeff * (v - mu).pow(2), dim=1, keepdim=True) + eps
        sigma_sq = sigma_sq.clamp_min(1e-4)

        # Activation cost: beta_u + log(sigma) per pose dim, scaled by assignment mass.
        r_sum_flat = r_sum.view(b, C, 1)
        sigma_sq_flat = sigma_sq.view(b, C, psize).clamp_min(1e-4)
        cost_h = (self.beta_u.view(1, C, 1) + torch.log(torch.sqrt(sigma_sq_flat))) * r_sum_flat
        a_out = self.sigmoid(self._lambda * (self.beta_a.view(1, C) - cost_h.sum(dim=2))).clamp(1e-4, 1.0 - 1e-4)

        # Defensive scrubbing: keep routing finite even if a vote blew up upstream.
        mu = torch.nan_to_num(mu, nan=0.0, posinf=1e4, neginf=-1e4)
        sigma_sq = torch.nan_to_num(sigma_sq, nan=1e-4, posinf=1e4, neginf=1e-4)
        a_out = torch.nan_to_num(a_out, nan=0.5, posinf=1.0 - 1e-4, neginf=1e-4)
        return a_out, mu, sigma_sq

    def e_step(
        self,
        mu: torch.Tensor,
        sigma_sq: torch.Tensor,
        a_out: torch.Tensor,
        v: torch.Tensor,
        eps: float,
        b: int,
        C: int,
    ) -> torch.Tensor:
        """EM E-step: recompute assignments from Gaussian log-likelihood of each vote.

        Returns routing tensor ``r`` of shape ``(b, in_votes, C, 1)``.
        """
        sigma_sq = sigma_sq.clamp_min(1e-4)
        a_out = a_out.clamp(1e-4, 1.0 - 1e-4)
        # Per-dimension Gaussian log density of votes under (mu, sigma_sq).
        ln_p_j_h = -1.0 * (v - mu).pow(2) / (2.0 * sigma_sq) - torch.log(torch.sqrt(sigma_sq)) - 0.5 * self.ln_2pi
        # Combine with log activation prior; clamp to keep softmax well-conditioned.
        ln_ap = ln_p_j_h.sum(dim=3) + torch.log(a_out.view(b, 1, C) + eps)
        ln_ap = torch.nan_to_num(ln_ap, nan=0.0, posinf=50.0, neginf=-50.0)
        r = self.softmax(ln_ap).unsqueeze(-1)
        r = torch.nan_to_num(r, nan=(1.0 / max(C, 1)), posinf=1.0, neginf=0.0)
        return r

    def caps_em_routing(self, v: torch.Tensor, a_in: torch.Tensor, C: int, eps: float) -> tuple[torch.Tensor, torch.Tensor]:
        """Run ``iters`` EM iterations; the final iteration skips the E-step.

        Args:
            v: Votes ``(b, in_votes, C, psize)`` where ``b`` is the flattened batch.
            a_in: Input activations ``(b, in_votes, 1)``.

        Returns:
            p_out ``(b, C, psize)``, a_out ``(b, C)``.
        """
        b, B, _, psize = v.shape
        # Uniform initial assignment over the C output types.
        r = v.new_full((b, B, C, 1), 1.0 / C)

        for t in range(self.iters):
            a_out, mu, sigma_sq = self.m_step(a_in, r, v, eps, b, B, C, psize)
            if t < self.iters - 1:
                r = self.e_step(mu, sigma_sq, a_out, v, eps, b, C)

        p_out = torch.nan_to_num(mu.squeeze(1), nan=0.0, posinf=1e4, neginf=-1e4)
        a_out = torch.nan_to_num(a_out, nan=0.5, posinf=1.0 - 1e-4, neginf=1e-4)
        return p_out, a_out

    def add_pathes(self, x: torch.Tensor, B: int, K: int, psize: int, stride: int) -> tuple[torch.Tensor, int, int]:
        """Extract KxK patches from an NHWC capsule map (name keeps the historical
        "pathes" spelling for compatibility with existing callers).

        Returns:
            patches ``(b, oh, ow, K*K, c)`` plus output height/width.
        """
        b, h, w, c = x.shape
        x_chw = x.permute(0, 3, 1, 2).contiguous()
        pad = K // 2
        patches = F.unfold(x_chw, kernel_size=K, padding=pad, stride=stride)

        # Standard conv output-size arithmetic; must agree with F.unfold above.
        oh = (h + 2 * pad - K) // stride + 1
        ow = (w + 2 * pad - K) // stride + 1
        patches = patches.transpose(1, 2).contiguous().view(b, oh, ow, K * K, c)
        return patches, oh, ow

    def transform_view(self, x: torch.Tensor, w: torch.Tensor, C: int, P: int, w_shared: bool = False) -> torch.Tensor:
        """Produce votes by multiplying each input pose with its transform matrix.

        Args:
            x: Input poses ``(b, in_votes, P*P)``.
            w: Transform weights ``(1, weight_in, C, P*P, P*P)``.

        Returns:
            Votes ``(b, in_votes, C, P*P)``.
        """
        b, in_votes, psize = x.shape
        assert psize == P * P

        w0 = w[0]
        if w_shared:
            # Tile the B shared transforms across all spatial positions
            # (assumes in_votes is a multiple of the stored weight count).
            base = w0.size(0)
            reps = in_votes // base
            w0 = w0.repeat(reps, 1, 1, 1)

        # (b,i,p) x (i,C,p,q) -> (b,i,C,q): one matrix product per (input, output type).
        v = torch.einsum("bip,icpq->bicq", x, w0)
        return v

    def add_coord(self, v: torch.Tensor, b: int, h: int, w: int, B: int, C: int, psize: int) -> torch.Tensor:
        """Coordinate addition: inject normalized (row, col) into pose dims 0 and 1.

        Used in shared (class-capsule) mode so spatial position survives pooling.
        """
        v = v.view(b, h, w, B, C, psize)

        device = v.device
        dtype = v.dtype
        coor_h_vals = torch.arange(h, dtype=dtype, device=device) / float(max(h, 1))
        coor_w_vals = torch.arange(w, dtype=dtype, device=device) / float(max(w, 1))

        # Row offset goes into pose element 0, column offset into element 1.
        coor_h = torch.zeros(1, h, 1, 1, 1, psize, dtype=dtype, device=device)
        coor_w = torch.zeros(1, 1, w, 1, 1, psize, dtype=dtype, device=device)
        coor_h[0, :, 0, 0, 0, 0] = coor_h_vals
        coor_w[0, 0, :, 0, 0, 1] = coor_w_vals

        v = (v + coor_h + coor_w).view(b, h * w * B, C, psize)
        return v

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route input capsules to output capsules.

        Non-shared mode routes per spatial patch; shared mode pools every position
        into a single set of class capsules and then broadcasts it back over (h, w).
        """
        b, h, w, c = x.shape

        if not self.w_shared:
            patches, oh, ow = self.add_pathes(x, self.B, self.K, self.psize, self.stride)

            # Split each patch into pose and activation parts, flattening (b, oh, ow).
            p_in = patches[..., : self.B * self.psize].contiguous().view(b * oh * ow, self.K * self.K * self.B, self.psize)
            a_in = patches[..., self.B * self.psize :].contiguous().view(b * oh * ow, self.K * self.K * self.B, 1)

            v = self.transform_view(p_in, self.weights, self.C, self.P, w_shared=False)
            p_out, a_out = self.caps_em_routing(v, a_in, self.C, self.eps)

            p_out = p_out.view(b, oh, ow, self.C * self.psize)
            a_out = a_out.view(b, oh, ow, self.C)
            out = torch.cat([p_out, a_out], dim=3)
        else:
            # Shared mode is only defined for 1x1 / stride-1 class capsules.
            assert c == self.B * (self.psize + 1)
            assert self.K == 1
            assert self.stride == 1

            p_in = x[..., : self.B * self.psize].contiguous().view(b, h * w * self.B, self.psize)
            a_in = x[..., self.B * self.psize :].contiguous().view(b, h * w * self.B, 1)

            v = self.transform_view(p_in, self.weights, self.C, self.P, w_shared=True)
            if self.coor_add:
                v = self.add_coord(v, b, h, w, self.B, self.C, self.psize)

            p_cls, a_cls = self.caps_em_routing(v, a_in, self.C, self.eps)

            # Broadcast the single per-image class capsules back to every position.
            p_out = p_cls.reshape(b, 1, 1, self.C * self.psize).expand(b, h, w, self.C * self.psize)
            a_out = a_cls.unsqueeze(1).unsqueeze(1).expand(b, h, w, self.C)
            out = torch.cat([p_out, a_out], dim=3)

        return p_out, a_out, out
|
|
|
|
|
|
class DynamicConvCaps(nn.Module):
    r"""Convolutional capsules with Sabour-style dynamic routing.

    This layer keeps the same tensor interface as ``ConvCaps``:
        input: (N, H, W, B*(P*P+1))
        output: p_out (N, H_out, W_out, C*P*P), a_out (N, H_out, W_out, C), out concat

    Args:
        B: Input capsule types.
        C: Output capsule types.
        K: Patch kernel size.
        P: Pose matrix side length.
        stride: Patch stride.
        iters: Routing iterations.
        coor_add: Add coordinates in shared mode.
        w_shared: Share transforms across spatial positions (requires K=1, stride=1).
    """

    def __init__(
        self,
        B: int = 32,
        C: int = 32,
        K: int = 3,
        P: int = 4,
        stride: int = 1,
        iters: int = 3,
        coor_add: bool = False,
        w_shared: bool = False,
    ):
        super().__init__()
        self.B = B
        self.C = C
        self.K = K
        self.P = P
        self.psize = P * P
        self.stride = stride
        self.iters = iters
        self.coor_add = coor_add
        self.w_shared = w_shared
        self.eps = 1e-6

        # Pose transforms, same layout as ConvCaps.weights; shared mode stores
        # only B slots and tiles them at runtime.
        weight_in = B if w_shared else (K * K * B)
        self.weights = nn.Parameter(torch.randn(1, weight_in, C, self.psize, self.psize) * 0.02)

    @staticmethod
    def squash(s: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
        """Squash nonlinearity: scale vectors to length < 1 while keeping direction."""
        s2 = (s * s).sum(dim=dim, keepdim=True)
        scale = s2 / (1.0 + s2)
        return scale * s / torch.sqrt(s2 + eps)

    def add_pathes(self, x: torch.Tensor, K: int, stride: int) -> tuple[torch.Tensor, int, int]:
        """Extract KxK patches from NHWC input; returns (b, oh, ow, K*K, c) plus oh/ow.

        (Name keeps the historical "pathes" spelling for interface parity with ConvCaps.)
        """
        b, h, w, c = x.shape
        x_chw = x.permute(0, 3, 1, 2).contiguous()
        pad = K // 2
        patches = F.unfold(x_chw, kernel_size=K, padding=pad, stride=stride)
        # Conv output-size arithmetic; must agree with F.unfold above.
        oh = (h + 2 * pad - K) // stride + 1
        ow = (w + 2 * pad - K) // stride + 1
        patches = patches.transpose(1, 2).contiguous().view(b, oh, ow, K * K, c)
        return patches, oh, ow

    def transform_view(self, x: torch.Tensor, w_shared: bool) -> torch.Tensor:
        """Compute votes ``(b, in_votes, C, psize)`` from poses ``(b, in_votes, psize)``."""
        b, in_votes, psize = x.shape
        if psize != self.psize:
            raise ValueError('Invalid pose size for DynamicConvCaps')

        w0 = self.weights[0]
        if w_shared:
            # Tile shared transforms across all spatial positions.
            base = w0.size(0)
            reps = in_votes // base
            w0 = w0.repeat(reps, 1, 1, 1)

        return torch.einsum('bip,icpq->bicq', x, w0)

    def add_coord(self, v: torch.Tensor, b: int, h: int, w: int, B: int, C: int, psize: int) -> torch.Tensor:
        """Coordinate addition: normalized row into pose dim 0, column into dim 1."""
        v = v.view(b, h, w, B, C, psize)
        device, dtype = v.device, v.dtype
        coor_h_vals = torch.arange(h, dtype=dtype, device=device) / float(max(h, 1))
        coor_w_vals = torch.arange(w, dtype=dtype, device=device) / float(max(w, 1))

        coor_h = torch.zeros(1, h, 1, 1, 1, psize, dtype=dtype, device=device)
        coor_w = torch.zeros(1, 1, w, 1, 1, psize, dtype=dtype, device=device)
        coor_h[0, :, 0, 0, 0, 0] = coor_h_vals
        coor_w[0, 0, :, 0, 0, 1] = coor_w_vals

        return (v + coor_h + coor_w).view(b, h * w * B, C, psize)

    def dynamic_routing(self, v: torch.Tensor, a_in: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Routing-by-agreement over votes ``(n, in_votes, C, psize)``.

        Args:
            v: Votes.
            a_in: Input activations ``(n, in_votes, 1)``; used to weight coupling.

        Returns:
            v_j ``(n, C, psize)`` squashed output poses, a_out ``(n, C)`` pose lengths.
        """
        n, in_votes, C, psize = v.shape
        # Routing logits start at zero -> uniform couplings on the first pass.
        b_ij = v.new_zeros(n, in_votes, C)

        a_in = a_in.clamp(1e-4, 1.0)
        for t in range(self.iters):
            # Couplings over output types, modulated by input activation.
            c_ij = F.softmax(b_ij, dim=2)
            c_ij = c_ij * a_in
            c_ij = c_ij / (c_ij.sum(dim=2, keepdim=True) + self.eps)

            s_j = (c_ij.unsqueeze(-1) * v).sum(dim=1)
            v_j = self.squash(s_j, dim=-1, eps=self.eps)
            if t < self.iters - 1:
                # Agreement = dot product between each vote and the current output.
                agreement = (v * v_j.unsqueeze(1)).sum(dim=-1)
                b_ij = b_ij + agreement

        # Activation = output vector length (squash keeps it in [0, 1)).
        a_out = torch.sqrt((v_j * v_j).sum(dim=-1) + self.eps).clamp(1e-4, 1.0 - 1e-4)
        return v_j, a_out

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route input capsules; shared mode pools globally then broadcasts back."""
        b, h, w, c = x.shape

        if not self.w_shared:
            patches, oh, ow = self.add_pathes(x, self.K, self.stride)
            p_in = patches[..., : self.B * self.psize].contiguous().view(b * oh * ow, self.K * self.K * self.B, self.psize)
            a_in = patches[..., self.B * self.psize :].contiguous().view(b * oh * ow, self.K * self.K * self.B, 1)

            votes = self.transform_view(p_in, w_shared=False)
            p_vec, a_vec = self.dynamic_routing(votes, a_in)

            p_out = p_vec.view(b, oh, ow, self.C * self.psize)
            a_out = a_vec.view(b, oh, ow, self.C)
            out = torch.cat([p_out, a_out], dim=3)
        else:
            if c != self.B * (self.psize + 1) or self.K != 1 or self.stride != 1:
                raise ValueError('DynamicConvCaps shared mode requires K=1, stride=1 and matching capsule channels')

            p_in = x[..., : self.B * self.psize].contiguous().view(b, h * w * self.B, self.psize)
            a_in = x[..., self.B * self.psize :].contiguous().view(b, h * w * self.B, 1)
            votes = self.transform_view(p_in, w_shared=True)
            if self.coor_add:
                votes = self.add_coord(votes, b, h, w, self.B, self.C, self.psize)

            p_vec, a_vec = self.dynamic_routing(votes, a_in)
            # Broadcast the per-image class capsules back over every position.
            p_out = p_vec.reshape(b, 1, 1, self.C * self.psize).expand(b, h, w, self.C * self.psize)
            a_out = a_vec.unsqueeze(1).unsqueeze(1).expand(b, h, w, self.C)
            out = torch.cat([p_out, a_out], dim=3)

        # Final scrubbing so downstream heads never see NaN/Inf.
        p_out = torch.nan_to_num(p_out, nan=0.0, posinf=1e4, neginf=-1e4)
        a_out = torch.nan_to_num(a_out, nan=0.5, posinf=1.0 - 1e-4, neginf=1e-4)
        out = torch.nan_to_num(out, nan=0.0, posinf=1e4, neginf=-1e4)
        return p_out, a_out, out
|
|
|
|
class SelfRoutingConvCaps(nn.Module):
    r"""Convolutional self-routing capsules.

    Routing coefficients are predicted directly from each input pose (a learned
    per-capsule linear map + softmax) instead of being iterated, so a single
    pass suffices.

    Keeps the same output contract as ``ConvCaps``/``DynamicConvCaps``:
        input: (N, H, W, B*(P*P+1))
        output: p_out (N, H_out, W_out, C*P*P), a_out (N, H_out, W_out, C), out concat
    """

    def __init__(
        self,
        B: int = 32,
        C: int = 32,
        K: int = 3,
        P: int = 4,
        stride: int = 1,
        iters: int = 1,
        coor_add: bool = False,
        w_shared: bool = False,
    ):
        super().__init__()
        # iters/w_shared are accepted only for interface parity; self-routing
        # is single-pass and position-independent, so both are ignored.
        _ = (iters, w_shared)

        self.B = B
        self.C = C
        self.K = K
        self.P = P
        self.psize = P * P
        self.stride = stride
        self.coor_add = coor_add
        self.eps = 1e-6

        self.kk = K * K
        self.kkB = self.kk * B

        # W1: pose-vote transforms, one (psize x psize) matrix per (input slot, output type).
        self.W1 = nn.Parameter(torch.empty(self.kkB, C, self.psize, self.psize))
        nn.init.kaiming_uniform_(self.W1, a=math.sqrt(5))

        # W2/b2: linear map from each input pose to its C routing logits.
        self.W2 = nn.Parameter(torch.zeros(self.kkB, C, self.psize))
        self.b2 = nn.Parameter(torch.zeros(1, 1, self.kkB, C))

    def _output_hw(self, h: int, w: int) -> tuple[int, int]:
        """Conv output size for this layer's K/stride with same-style padding."""
        pad = self.K // 2
        oh = (h + 2 * pad - self.K) // self.stride + 1
        ow = (w + 2 * pad - self.K) // self.stride + 1
        return oh, ow

    def _add_coord(self, pose_unf: torch.Tensor, oh: int, ow: int) -> torch.Tensor:
        """Add normalized (row, col) grid coordinates to the first two pose dims."""
        if self.psize < 2:
            return pose_unf

        b, L, kkB, _ = pose_unf.shape
        device, dtype = pose_unf.device, pose_unf.dtype
        gy = torch.arange(oh, device=device, dtype=dtype) / float(max(oh, 1))
        gx = torch.arange(ow, device=device, dtype=dtype) / float(max(ow, 1))
        yy, xx = torch.meshgrid(gy, gx, indexing='ij')
        # L == oh * ow by construction in forward().
        coords = torch.stack((yy, xx), dim=-1).view(1, L, 1, 2)

        # Clone to avoid mutating the unfolded view in place.
        pose_unf = pose_unf.clone()
        pose_unf[..., :2] = pose_unf[..., :2] + coords
        return pose_unf

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Single-pass self-routing from NHWC capsules to NHWC capsules."""
        b, h, w, c = x.shape
        expected = self.B * (self.psize + 1)
        if c != expected:
            raise ValueError(f'SelfRoutingConvCaps expected {expected} channels, got {c}')

        # Split the packed channels-last tensor into pose and activation parts.
        pose = x[..., : self.B * self.psize]
        act = x[..., self.B * self.psize :]

        pose_chw = pose.permute(0, 3, 1, 2).contiguous()
        act_chw = act.permute(0, 3, 1, 2).contiguous()

        pad = self.K // 2
        pose_unf = F.unfold(pose_chw, kernel_size=self.K, stride=self.stride, padding=pad)
        act_unf = F.unfold(act_chw, kernel_size=self.K, stride=self.stride, padding=pad)

        oh, ow = self._output_hw(h, w)
        l = pose_unf.shape[-1]  # number of output positions (oh * ow)

        # Rearrange unfolded columns to (b, l, K*K*B, psize) / (b, l, K*K*B).
        pose_unf = pose_unf.view(b, self.B, self.psize, self.kk, l).permute(0, 4, 3, 1, 2).contiguous()
        pose_unf = pose_unf.view(b, l, self.kkB, self.psize)

        act_unf = act_unf.view(b, self.B, self.kk, l).permute(0, 3, 2, 1).contiguous()
        act_unf = act_unf.view(b, l, self.kkB)

        if self.coor_add:
            pose_unf = self._add_coord(pose_unf, oh, ow)

        # Routing coefficients predicted per input capsule, softmaxed over C.
        logit = torch.einsum('blip,icp->blic', pose_unf, self.W2) + self.b2
        r = F.softmax(logit, dim=3)

        # Activation-weighted routing; coeff normalizes per input slot's mass.
        ar = act_unf.unsqueeze(-1) * r
        ar_sum = ar.sum(dim=2, keepdim=True) + self.eps
        coeff = ar / ar_sum

        # Output activation = routed activation mass / total input activation.
        a_norm = act_unf.sum(dim=2, keepdim=True) + self.eps
        a_out = (ar_sum.squeeze(2) / a_norm).clamp(1e-4, 1.0 - 1e-4)

        # Votes via W1, then coefficient-weighted sum over input slots.
        pose_votes = torch.einsum('blip,icpq->blicq', pose_unf, self.W1)
        pose_out = (coeff.unsqueeze(-1) * pose_votes).sum(dim=2)

        p_out = pose_out.view(b, oh, ow, self.C * self.psize)
        a_out = a_out.view(b, oh, ow, self.C)
        out = torch.cat([p_out, a_out], dim=3)

        # Final scrubbing so downstream heads never see NaN/Inf.
        p_out = torch.nan_to_num(p_out, nan=0.0, posinf=1e4, neginf=-1e4)
        a_out = torch.nan_to_num(a_out, nan=0.5, posinf=1.0 - 1e-4, neginf=1e-4)
        out = torch.nan_to_num(out, nan=0.0, posinf=1e4, neginf=-1e4)
        return p_out, a_out, out
|
|
|
|
|
|
| class CapsuleDualHead(nn.Module): |
| """Capsule detection head for one feature level. |
| |
| Args: |
| c_in: Input channels of this feature scale (from parser-provided ``ch``). |
| nc: Number of classes (final activation capsule count in ``ConvCaps2``). |
| reg_max: Detect DFL bins, box channels are ``4 * reg_max``. |
| k: Number of capsule types in ``PrimaryCaps``. |
| d: Requested pose descriptor size; internally mapped to square ``P*P``. |
| |
| Input shape: |
| x: ``(N, c_in, H, W)`` |
| |
| Output shape: |
| boxes: ``(N, 4*reg_max, H, W)`` |
| scores: ``(N, nc, H, W)`` |
| aux: dict with final capsule activations when ``return_aux=True`` else ``None`` |
| |
| Parameter size: |
| ``PrimaryCaps(c_in,k) + ConvCaps(k,nc,w_shared=True) + box_bias(4*reg_max)`` |
| |
| Structure: |
| PrimaryCaps -> ConvCaps(class caps only, shared) |
| """ |
|
|
| def __init__(self, c_in: int, nc: int, reg_max: int, k: int, d: int): |
| super().__init__() |
| |
| p = max(1, int(math.ceil(math.sqrt(d)))) |
|
|
| self.nc = nc |
| self.reg_max = reg_max |
| self.P = p |
| self.psize = self.P * self.P |
|
|
| |
| self.primary = PrimaryCaps(A=c_in, B=k, K=1, P=self.P, stride=1) |
| |
| self.conv_caps2 = ConvCaps(B=k, C=nc, K=1, P=self.P, stride=1, iters=1, coor_add=True, w_shared=True) |
|
|
| |
| self.box_bias = nn.Parameter(torch.zeros(4 * reg_max)) |
|
|
| def _pose_to_box(self, p_out: torch.Tensor, a_out: torch.Tensor) -> torch.Tensor: |
| |
| |
| _ = a_out |
| box_ch = 4 * self.reg_max |
|
|
| if p_out.shape[-1] >= box_ch: |
| box = p_out[..., :box_ch] |
| else: |
| |
| reps = math.ceil(box_ch / p_out.shape[-1]) |
| box = p_out.repeat(1, 1, 1, reps)[..., :box_ch] |
|
|
| return box + self.box_bias.view(1, 1, 1, box_ch) |
|
|
| def forward(self, x: torch.Tensor, return_aux: bool = False) -> tuple[torch.Tensor, torch.Tensor, dict | None]: |
| _, _, caps0 = self.primary(x) |
| p2, a2, _ = self.conv_caps2(caps0) |
|
|
| boxes = self._pose_to_box(p2, a2).permute(0, 3, 1, 2).contiguous() |
| a2_logits = torch.logit(a2.clamp(1e-4, 1.0 - 1e-4)) |
| scores = a2_logits.permute(0, 3, 1, 2).contiguous() |
|
|
| aux = None |
| if return_aux: |
| aux = { |
| "caps2_a": a2.permute(0, 3, 1, 2).contiguous(), |
| } |
| return boxes, scores, aux |
|
|
|
|
class CapsuleClsHead(nn.Module):
    """Capsule classification branch used as a drop-in replacement for Detect.cv3.

    Pipeline: ``PrimaryCaps`` -> two stacked shared ``SelfRoutingConvCaps`` stages
    (k types -> (k+nc)//2 types -> nc class capsules); class activations become logits.
    """

    def __init__(self, c_in: int, nc: int, k: int = 4, d: int = 16, iters: int = 1):
        super().__init__()
        # Smallest square pose side covering the requested descriptor size.
        side = max(1, int(math.ceil(math.sqrt(d))))
        mid = int((k + nc) / 2)

        self.primary = PrimaryCaps(A=c_in, B=k, K=1, P=side, stride=1)
        self.mid_caps = SelfRoutingConvCaps(B=k, C=mid, K=1, P=side, stride=1, iters=iters, coor_add=False, w_shared=True)
        self.class_caps = SelfRoutingConvCaps(B=mid, C=nc, K=1, P=side, stride=1, iters=iters, coor_add=False, w_shared=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class logits of shape ``(N, nc, H, W)``."""
        caps = self.primary(x)[2]
        caps = self.mid_caps(caps)[2]
        a_cls = self.class_caps(caps)[1]

        logits = torch.logit(a_cls.clamp(1e-4, 1.0 - 1e-4))
        logits = logits.permute(0, 3, 1, 2).contiguous()
        return torch.nan_to_num(logits, nan=0.0, posinf=20.0, neginf=-20.0).float()
|
|
|
|
class CapsuleDetect(Detect):
    """Detect head with capsule vote aggregation for both box and cls branches.

    Input feature of level i is packed as interleaved channels:
        [pose(d_i), act(1)] repeated k_i times -> C_i = k_i * (d_i + 1)

    In forward_head:
        - split pose/act per capsule type
        - run Detect box/cls heads on each type-specific pose tensor
        - aggregate type predictions with act-driven vote weights

    Detect decode/postprocess/end2end flow is reused unchanged.
    """

    def __init__(
        self,
        nc: int = 80,
        *args,
        reg_max: int = 16,
        end2end: bool = False,
        k: list[int] | tuple[int, ...] = (4, 8, 16),
        d: list[int] | tuple[int, ...] = (16, 16, 16),
        ch: tuple = (),
    ):
        """Build the capsule detect head.

        Args:
            nc: Number of classes.
            *args: Parser-style positionals ``k, d[, reg_max, end2end][, ch]``.
                When empty, the keyword arguments below are used instead.
            reg_max: DFL bins; box channels are ``4 * reg_max``.
            end2end: Enable Detect's end-to-end decoding path.
            k: Per-level capsule type counts.
            d: Per-level pose descriptor sizes.
            ch: Per-level input channels; level i must equal ``k[i] * (d[i] + 1)``.
        """
        parsed = list(args)
        # The YAML parser appends the channel list as the final positional argument.
        if parsed and isinstance(parsed[-1], (list, tuple)):
            ch = tuple(parsed.pop(-1))

        # Accept [k, d] or [k, d, reg_max, end2end]. With no positional settings,
        # fall back to the keyword arguments (previously this case always raised,
        # making the keyword k/d/reg_max/end2end defaults unreachable).
        if len(parsed) not in (0, 2, 4):
            raise ValueError('CapsuleDetect expects [k_list, d_list, reg_max, end2end, ch].')

        if parsed:
            k, d = parsed[0], parsed[1]
            if len(parsed) == 4:
                reg_max = int(parsed[2])
                end2end = bool(parsed[3])

        if not isinstance(k, (list, tuple)) or not isinstance(d, (list, tuple)):
            raise TypeError('CapsuleDetect requires list/tuple k and d (per-level settings).')

        ch = tuple(int(c) for c in ch)
        nl = len(ch)
        if len(k) != nl or len(d) != nl:
            raise ValueError(f'CapsuleDetect k/d length must equal number of levels ({nl}).')

        # Plain tuples are safe to set before nn.Module.__init__ (no Parameter/Module yet).
        self.k_list = tuple(int(v) for v in k)
        self.d_list = tuple(int(v) for v in d)

        # Validate the packed channel layout per level.
        for i, c in enumerate(ch):
            expected = self.k_list[i] * (self.d_list[i] + 1)
            if c != expected:
                raise ValueError(
                    f'CapsuleDetect level-{i} channel mismatch: got {c}, expected {expected} from k={self.k_list[i]}, d={self.d_list[i]}.'
                )

        # Detect's cv2/cv3 heads run on per-type pose tensors, i.e. d_i channels per level.
        super().__init__(nc=nc, reg_max=reg_max, end2end=end2end, ch=self.d_list)

        # Per-level vote nets mapping K activation maps to K vote logits.
        self.box_vote = nn.ModuleList(
            nn.Sequential(Conv(k_i, k_i, 3), nn.Conv2d(k_i, k_i, 1, bias=True)) for k_i in self.k_list
        )
        self.cls_vote = nn.ModuleList(
            nn.Sequential(Conv(k_i, k_i, 3), nn.Conv2d(k_i, k_i, 1, bias=True)) for k_i in self.k_list
        )

    def _split_caps(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Split packed feature into pose and activation tensors per level.

        Returns:
            pose_caps: list of tensors, each (B, K, D, H, W)
            act_map: list of tensors, each (B, K, H, W)
        """
        pose_caps, act_map = [], []
        for i, xi in enumerate(x):
            k_i = self.k_list[i]
            d_i = self.d_list[i]
            c = int(xi.shape[1])
            expected = k_i * (d_i + 1)
            if c != expected:
                raise ValueError(f'CapsuleDetect level-{i} channel mismatch: got {c}, expected {expected}.')

            b, _, h, w = xi.shape
            # Channels are interleaved per type: [pose(d_i), act(1)] * k_i.
            caps = xi.view(b, k_i, d_i + 1, h, w)
            pose_caps.append(caps[:, :, :d_i].contiguous())
            act_map.append(caps[:, :, d_i].contiguous())
        return pose_caps, act_map

    @staticmethod
    def _normalized_votes(raw: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
        """Turn raw vote logits into positive weights that sum to ~1 over types (dim 1)."""
        w = F.softplus(raw) + eps
        return w / (w.sum(dim=1, keepdim=True) + eps)

    def _run_voted_head(
        self,
        pose: torch.Tensor,
        act: torch.Tensor,
        head: torch.nn.Module,
        vote_head: torch.nn.Module,
        out_ch: int,
    ) -> torch.Tensor:
        """Apply one Detect head per type and aggregate by vote weights.

        Args:
            pose: (B, K, D, H, W)
            act: (B, K, H, W)
            head: Detect box or cls head module for this level
            vote_head: vote logits module for this level
            out_ch: output channels of target prediction

        Returns:
            (B, out_ch, H, W)
        """
        b, k, d, h, w = pose.shape

        # Single type: the normalized vote weight would be ~1, skip the vote net.
        if k == 1:
            return head(pose[:, 0])

        # Fold the type axis into the batch so the head runs once for all types.
        pose_bt = pose.reshape(b * k, d, h, w)
        pred_bt = head(pose_bt).reshape(b, k, out_ch, h, w)

        vote_raw = vote_head(act)
        vote = self._normalized_votes(vote_raw).unsqueeze(2)
        pred = (pred_bt * vote).sum(dim=1)
        return pred

    def forward_head(
        self, x: list[torch.Tensor], box_head: torch.nn.Module = None, cls_head: torch.nn.Module = None
    ) -> dict[str, torch.Tensor]:
        """Run voted box/cls heads per level; returns Detect-compatible dict."""
        if box_head is None or cls_head is None:
            return dict()

        pose_caps, act_map = self._split_caps(x)
        bs = x[0].shape[0]

        box_list = []
        cls_list = []
        for i in range(self.nl):
            box_i = self._run_voted_head(
                pose_caps[i],
                act_map[i],
                box_head[i],
                self.box_vote[i],
                out_ch=4 * self.reg_max,
            )
            cls_i = self._run_voted_head(
                pose_caps[i],
                act_map[i],
                cls_head[i],
                self.cls_vote[i],
                out_ch=self.nc,
            )
            box_list.append(box_i.view(bs, 4 * self.reg_max, -1))
            cls_list.append(cls_i.view(bs, self.nc, -1))

        boxes = torch.cat(box_list, dim=-1)
        scores = torch.cat(cls_list, dim=-1)
        return dict(boxes=boxes, scores=scores, feats=x)
| |
|
|
class CapsuleDetectv1(Detect):
    """Capsule Detect variant with activation-gated pose fusion.

    Per level:
        1) Split packed capsule channels into pose/activation (interleaved by type).
        2) Use a 2-layer 1x1 gate net on activation channels.
        3) Gate pose channels with residual scaling.
        4) Flatten to K*D channels and run original Detect cv2/cv3 heads.
    """

    def __init__(
        self,
        nc: int = 80,
        *args,
        reg_max: int = 16,
        end2end: bool = False,
        k: list[int] | tuple[int, ...] = (4, 8, 16),
        d: list[int] | tuple[int, ...] = (16, 16, 16),
        ch: tuple = (),
    ):
        """Build the gated capsule detect head.

        Args:
            nc: Number of classes.
            *args: Parser-style positionals ``k, d[, reg_max, end2end][, ch]``.
                When empty, the keyword arguments below are used instead.
            reg_max: DFL bins; box channels are ``4 * reg_max``.
            end2end: Enable Detect's end-to-end decoding path.
            k: Per-level capsule type counts.
            d: Per-level pose descriptor sizes.
            ch: Per-level input channels; level i must equal ``k[i] * (d[i] + 1)``.
        """
        parsed = list(args)
        # The YAML parser appends the channel list as the final positional argument.
        if parsed and isinstance(parsed[-1], (list, tuple)):
            ch = tuple(parsed.pop(-1))

        # Accept [k, d] or [k, d, reg_max, end2end]. With no positional settings,
        # fall back to the keyword arguments (previously this case always raised,
        # making the keyword k/d/reg_max/end2end defaults unreachable).
        if len(parsed) not in (0, 2, 4):
            raise ValueError("CapsuleDetectv1 expects [k_list, d_list, reg_max, end2end, ch].")

        if parsed:
            k, d = parsed[0], parsed[1]
            if len(parsed) == 4:
                reg_max = int(parsed[2])
                end2end = bool(parsed[3])

        if not isinstance(k, (list, tuple)) or not isinstance(d, (list, tuple)):
            raise TypeError("CapsuleDetectv1 requires list/tuple k and d (per-level settings).")

        ch = tuple(int(c) for c in ch)
        nl = len(ch)
        if len(k) != nl or len(d) != nl:
            raise ValueError(f"CapsuleDetectv1 k/d length must equal number of levels ({nl}).")

        # Plain tuples are safe to set before nn.Module.__init__ (no Parameter/Module yet).
        self.k_list = tuple(int(v) for v in k)
        self.d_list = tuple(int(v) for v in d)

        # Validate the packed channel layout per level.
        for i, c in enumerate(ch):
            expected = self.k_list[i] * (self.d_list[i] + 1)
            if c != expected:
                raise ValueError(
                    f"CapsuleDetectv1 level-{i} channel mismatch: got {c}, "
                    f"expected {expected} from k={self.k_list[i]}, d={self.d_list[i]}."
                )

        # Detect heads consume the flattened K*D pose channels per level.
        merged_ch = tuple(k_i * d_i for k_i, d_i in zip(self.k_list, self.d_list))
        super().__init__(nc=nc, reg_max=reg_max, end2end=end2end, ch=merged_ch)

        # Per-level 1x1 gate nets (activations -> per-pose-channel gates) with a
        # learnable residual scale.
        self.pose_gates = nn.ModuleList()
        self.gate_alpha = nn.ParameterList()
        for k_i, d_i in zip(self.k_list, self.d_list):
            out_ch = k_i * d_i
            hidden = max(8, k_i * 2)
            self.pose_gates.append(
                nn.Sequential(
                    nn.Conv2d(k_i, hidden, 1, bias=True),
                    nn.SiLU(inplace=True),
                    nn.Conv2d(hidden, out_ch, 1, bias=True),
                )
            )
            self.gate_alpha.append(nn.Parameter(torch.tensor(0.5)))

    def _split_pose_act(self, x: torch.Tensor, i: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Split one level packed tensor into pose (B, K*D, H, W) and activation (B, K, H, W)."""
        k_i = self.k_list[i]
        d_i = self.d_list[i]
        b, c, h, w = x.shape
        expected = k_i * (d_i + 1)
        if c != expected:
            raise ValueError(f"CapsuleDetectv1 level-{i} channel mismatch: got {c}, expected {expected}.")

        # Channels are interleaved per type: [pose(d_i), act(1)] * k_i.
        caps = x.view(b, k_i, d_i + 1, h, w)
        pose = caps[:, :, :d_i].reshape(b, k_i * d_i, h, w).contiguous()
        act = caps[:, :, d_i].contiguous()
        return pose, act

    def _merge_pose(self, x: list[torch.Tensor]) -> list[torch.Tensor]:
        """Gate each level's pose channels with its activation-derived gates."""
        merged = []
        for i, xi in enumerate(x):
            pose, act = self._split_pose_act(xi, i)
            gate = torch.sigmoid(self.pose_gates[i](act))
            # Residual gating: identity at alpha=0, amplification up to (1 + alpha).
            pose = pose * (1.0 + self.gate_alpha[i] * gate)
            merged.append(pose)
        return merged

    def forward_head(
        self, x: list[torch.Tensor], box_head: torch.nn.Module = None, cls_head: torch.nn.Module = None
    ) -> dict[str, torch.Tensor]:
        """Run Detect heads on the gated pose features; returns Detect-compatible dict."""
        if box_head is None or cls_head is None:
            return dict()

        pose_feats = self._merge_pose(x)
        bs = pose_feats[0].shape[0]

        box_list = [box_head[i](pose_feats[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)]
        cls_list = [cls_head[i](pose_feats[i]).view(bs, self.nc, -1) for i in range(self.nl)]
        boxes = torch.cat(box_list, dim=-1)
        scores = torch.cat(cls_list, dim=-1)
        return dict(boxes=boxes, scores=scores, feats=x)
|
|
|
|
class CapsuleDetectv2(Detect):
    """Capsule Detect v2: activation-gated pose + activation bypass for classification.

    Like ``CapsuleDetectv1`` for the box branch; the classification branch
    additionally receives a projected activation skip connection.
    """

    def __init__(
        self,
        nc: int = 80,
        *args,
        reg_max: int = 16,
        end2end: bool = False,
        k: list[int] | tuple[int, ...] = (4, 8, 16),
        d: list[int] | tuple[int, ...] = (16, 16, 16),
        ch: tuple = (),
    ):
        """Build the gated+bypass capsule detect head.

        Args:
            nc: Number of classes.
            *args: Parser-style positionals ``k, d[, reg_max, end2end][, ch]``.
                When empty, the keyword arguments below are used instead.
            reg_max: DFL bins; box channels are ``4 * reg_max``.
            end2end: Enable Detect's end-to-end decoding path.
            k: Per-level capsule type counts.
            d: Per-level pose descriptor sizes.
            ch: Per-level input channels; level i must equal ``k[i] * (d[i] + 1)``.
        """
        parsed = list(args)
        # The YAML parser appends the channel list as the final positional argument.
        if parsed and isinstance(parsed[-1], (list, tuple)):
            ch = tuple(parsed.pop(-1))

        # Accept [k, d] or [k, d, reg_max, end2end]. With no positional settings,
        # fall back to the keyword arguments (previously this case always raised,
        # making the keyword k/d/reg_max/end2end defaults unreachable).
        if len(parsed) not in (0, 2, 4):
            raise ValueError("CapsuleDetectv2 expects [k_list, d_list, reg_max, end2end, ch].")

        if parsed:
            k, d = parsed[0], parsed[1]
            if len(parsed) == 4:
                reg_max = int(parsed[2])
                end2end = bool(parsed[3])

        if not isinstance(k, (list, tuple)) or not isinstance(d, (list, tuple)):
            raise TypeError("CapsuleDetectv2 requires list/tuple k and d (per-level settings).")

        ch = tuple(int(c) for c in ch)
        nl = len(ch)
        if len(k) != nl or len(d) != nl:
            raise ValueError(f"CapsuleDetectv2 k/d length must equal number of levels ({nl}).")

        # Plain tuples are safe to set before nn.Module.__init__ (no Parameter/Module yet).
        self.k_list = tuple(int(v) for v in k)
        self.d_list = tuple(int(v) for v in d)

        # Validate the packed channel layout per level.
        for i, c in enumerate(ch):
            expected = self.k_list[i] * (self.d_list[i] + 1)
            if c != expected:
                raise ValueError(
                    f"CapsuleDetectv2 level-{i} channel mismatch: got {c}, "
                    f"expected {expected} from k={self.k_list[i]}, d={self.d_list[i]}."
                )

        # Detect heads consume the flattened K*D pose channels per level.
        merged_ch = tuple(k_i * d_i for k_i, d_i in zip(self.k_list, self.d_list))
        super().__init__(nc=nc, reg_max=reg_max, end2end=end2end, ch=merged_ch)

        self.pose_gates = nn.ModuleList()
        self.gate_alpha = nn.ParameterList()
        self.cls_bypass = nn.ModuleList()
        self.cls_beta = nn.ParameterList()

        for k_i, d_i in zip(self.k_list, self.d_list):
            pose_ch = k_i * d_i
            # Gate net: activations -> per-pose-channel gates, with residual scale alpha.
            gate_hidden = max(8, k_i * 2)
            self.pose_gates.append(
                nn.Sequential(
                    nn.Conv2d(k_i, gate_hidden, 1, bias=True),
                    nn.SiLU(inplace=True),
                    nn.Conv2d(gate_hidden, pose_ch, 1, bias=True),
                )
            )
            self.gate_alpha.append(nn.Parameter(torch.tensor(0.5)))

            # Bypass net: activations projected to pose channels, scaled by beta.
            cls_hidden = max(16, k_i * 2)
            self.cls_bypass.append(
                nn.Sequential(
                    nn.Conv2d(k_i, cls_hidden, 1, bias=True),
                    nn.SiLU(inplace=True),
                    nn.Conv2d(cls_hidden, pose_ch, 1, bias=True),
                )
            )
            self.cls_beta.append(nn.Parameter(torch.tensor(0.1)))

    def _split_pose_act(self, x: torch.Tensor, i: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Split one level packed tensor into pose (B, K*D, H, W) and activation (B, K, H, W)."""
        k_i = self.k_list[i]
        d_i = self.d_list[i]
        b, c, h, w = x.shape
        expected = k_i * (d_i + 1)
        if c != expected:
            raise ValueError(f"CapsuleDetectv2 level-{i} channel mismatch: got {c}, expected {expected}.")

        # Channels are interleaved per type: [pose(d_i), act(1)] * k_i.
        caps = x.view(b, k_i, d_i + 1, h, w)
        pose = caps[:, :, :d_i].reshape(b, k_i * d_i, h, w).contiguous()
        act = caps[:, :, d_i].contiguous()
        return pose, act

    def _fuse_pose(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Produce gated pose features for the box branch and bypass-augmented
        features for the classification branch, per level."""
        box_feats, cls_feats = [], []
        for i, xi in enumerate(x):
            pose, act = self._split_pose_act(xi, i)
            gate = torch.sigmoid(self.pose_gates[i](act))
            # Residual gating: identity at alpha=0, amplification up to (1 + alpha).
            pose_g = pose * (1.0 + self.gate_alpha[i] * gate)

            # Classification additionally sees a projected activation skip path.
            act_skip = self.cls_bypass[i](act)
            cls_in = pose_g + self.cls_beta[i] * act_skip

            box_feats.append(pose_g)
            cls_feats.append(cls_in)
        return box_feats, cls_feats

    def forward_head(
        self, x: list[torch.Tensor], box_head: torch.nn.Module = None, cls_head: torch.nn.Module = None
    ) -> dict[str, torch.Tensor]:
        """Run Detect heads on the fused features; returns Detect-compatible dict."""
        if box_head is None or cls_head is None:
            return dict()

        box_feats, cls_feats = self._fuse_pose(x)
        bs = x[0].shape[0]

        box_list = [box_head[i](box_feats[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)]
        cls_list = [cls_head[i](cls_feats[i]).view(bs, self.nc, -1) for i in range(self.nl)]
        boxes = torch.cat(box_list, dim=-1)
        scores = torch.cat(cls_list, dim=-1)
        return dict(boxes=boxes, scores=scores, feats=x)
|
|
|
|
class CapsuleDetectv4(Detect):
    """Capsule Detect v4: box uses raw pose, cls uses act bypass + symbolic type prior."""

    def __init__(
        self,
        nc: int = 80,
        *args,
        reg_max: int = 16,
        end2end: bool = False,
        k: list[int] | tuple[int, ...] = (4, 8, 16),
        d: list[int] | tuple[int, ...] = (16, 16, 16),
        ch: tuple = (),
    ):
        """Build the v4 capsule head.

        Args:
            nc: Number of classes.
            *args: Positional YAML-style config ``[k_list, d_list, (reg_max, end2end), ch]``;
                a trailing list/tuple is consumed as ``ch``.
            reg_max: DFL bins per box side (overridden when 4 positionals are given).
            end2end: Enable the one2one branch (overridden when 4 positionals are given).
            k: Per-level capsule counts.
            d: Per-level pose dims.
            ch: Per-level input channels; each must equal ``k[i] * (d[i] + 1)``.

        Raises:
            ValueError: On unexpected positional count or channel/layout mismatch.
            TypeError: If k or d is not a list/tuple.
        """
        parsed = list(args)
        if parsed and isinstance(parsed[-1], (list, tuple)):
            ch = tuple(parsed.pop(-1))

        if len(parsed) not in (2, 4):
            raise ValueError("CapsuleDetectv4 expects [k_list, d_list, reg_max, end2end, ch].")

        k, d = parsed[0], parsed[1]
        if len(parsed) == 4:
            reg_max = int(parsed[2])
            end2end = bool(parsed[3])

        # Consistency: use the shared layout helper (same error messages as the
        # previous inline checks) instead of duplicating the validation here.
        self.k_list, self.d_list, merged_ch = _setup_capsule_layout(k, d, ch, "CapsuleDetectv4")
        # Parent head is sized on pose-only channels.
        super().__init__(nc=nc, reg_max=reg_max, end2end=end2end, ch=merged_ch)

        self.cls_bypass = nn.ModuleList()  # per level: act -> pose-sized additive cls feature
        self.cls_beta = nn.ParameterList()  # per level: bypass strength (unbounded, init 0.1)
        self.sym_prior = nn.ModuleList()  # per level: act -> per-class prior logits
        self.sym_beta = nn.ParameterList()  # per level: prior strength (unbounded, init 0.1)

        for k_i, d_i in zip(self.k_list, self.d_list):
            pose_ch = k_i * d_i
            cls_hidden = max(16, k_i * 2)
            self.cls_bypass.append(
                nn.Sequential(
                    nn.Conv2d(k_i, cls_hidden, 1, bias=True),
                    nn.SiLU(inplace=True),
                    nn.Conv2d(cls_hidden, pose_ch, 1, bias=True),
                )
            )
            self.cls_beta.append(nn.Parameter(torch.tensor(0.1)))
            self.sym_prior.append(nn.Conv2d(k_i, self.nc, 1, bias=False))
            self.sym_beta.append(nn.Parameter(torch.tensor(0.1)))

    def _split_pose_act(self, x: torch.Tensor, i: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Split one packed level into (pose, activation); delegates to the shared helper
        with an identical error message to the previous inline version."""
        return _capsule_split_pose_act(x, self.k_list[i], self.d_list[i], "CapsuleDetectv4", i)

    def _build_feats(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """Per-level features: raw pose for box; pose plus scaled act bypass for cls;
        an un-normalized symbolic class prior (no tanh/norm/centering, unlike v5)."""
        box_feats, cls_feats, cls_priors = [], [], []
        for i, xi in enumerate(x):
            pose, act = self._split_pose_act(xi, i)
            cls_in = pose + self.cls_beta[i] * self.cls_bypass[i](act)
            cls_prior = self.sym_beta[i] * self.sym_prior[i](act)
            box_feats.append(pose)
            cls_feats.append(cls_in)
            cls_priors.append(cls_prior)
        return box_feats, cls_feats, cls_priors

    def forward_head(
        self, x: list[torch.Tensor], box_head: torch.nn.Module = None, cls_head: torch.nn.Module = None
    ) -> dict[str, torch.Tensor]:
        """Run box/cls sub-heads; cls logits get the symbolic prior added per level.

        Returns an empty dict when either sub-head is missing.
        """
        if box_head is None or cls_head is None:
            return dict()

        box_feats, cls_feats, cls_priors = self._build_feats(x)
        bs = x[0].shape[0]

        box_list = [box_head[i](box_feats[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)]
        cls_list = [
            (cls_head[i](cls_feats[i]) + cls_priors[i]).view(bs, self.nc, -1)
            for i in range(self.nl)
        ]
        boxes = torch.cat(box_list, dim=-1)
        scores = torch.cat(cls_list, dim=-1)
        return dict(boxes=boxes, scores=scores, feats=x)
|
|
|
|
| def _setup_capsule_layout( |
| k: list[int] | tuple[int, ...], |
| d: list[int] | tuple[int, ...], |
| ch: tuple, |
| cls_name: str, |
| ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: |
| if not isinstance(k, (list, tuple)) or not isinstance(d, (list, tuple)): |
| raise TypeError(f"{cls_name} requires list/tuple k and d (per-level settings).") |
|
|
| ch = tuple(int(c) for c in ch) |
| nl = len(ch) |
| if len(k) != nl or len(d) != nl: |
| raise ValueError(f"{cls_name} k/d length must equal number of levels ({nl}).") |
|
|
| k_list = tuple(int(v) for v in k) |
| d_list = tuple(int(v) for v in d) |
| for i, c in enumerate(ch): |
| expected = k_list[i] * (d_list[i] + 1) |
| if c != expected: |
| raise ValueError( |
| f"{cls_name} level-{i} channel mismatch: got {c}, " |
| f"expected {expected} from k={k_list[i]}, d={d_list[i]}." |
| ) |
| merged_ch = tuple(k_i * d_i for k_i, d_i in zip(k_list, d_list)) |
| return k_list, d_list, merged_ch |
|
|
|
|
| def _init_capsule_semantic_heads(obj: nn.Module) -> None: |
| obj.cls_bypass = nn.ModuleList() |
| obj.cls_beta = nn.ParameterList() |
| obj.sym_prior = nn.ModuleList() |
| obj.sym_norm = nn.ModuleList() |
| obj.sym_dropout = nn.ModuleList() |
| obj.sym_beta = nn.ParameterList() |
|
|
| for k_i, d_i in zip(obj.k_list, obj.d_list): |
| pose_ch = k_i * d_i |
| cls_hidden = max(16, k_i * 2) |
| obj.cls_bypass.append( |
| nn.Sequential( |
| nn.Conv2d(k_i, cls_hidden, 1, bias=True), |
| nn.SiLU(inplace=True), |
| nn.Conv2d(cls_hidden, pose_ch, 1, bias=True), |
| ) |
| ) |
| obj.cls_beta.append(nn.Parameter(torch.tensor(0.1))) |
| obj.sym_dropout.append(nn.Dropout2d(p=0.1)) |
| obj.sym_prior.append(nn.Conv2d(k_i, obj.nc, 1, bias=False)) |
| obj.sym_norm.append(nn.GroupNorm(1, obj.nc)) |
| obj.sym_beta.append(nn.Parameter(torch.tensor(0.1))) |
|
|
|
|
| def _capsule_split_pose_act( |
| x: torch.Tensor, |
| k_i: int, |
| d_i: int, |
| cls_name: str, |
| level_i: int, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| b, c, h, w = x.shape |
| expected = k_i * (d_i + 1) |
| if c != expected: |
| raise ValueError(f"{cls_name} level-{level_i} channel mismatch: got {c}, expected {expected}.") |
| caps = x.view(b, k_i, d_i + 1, h, w) |
| pose = caps[:, :, :d_i].reshape(b, k_i * d_i, h, w).contiguous() |
| act = caps[:, :, d_i].contiguous() |
| return pose, act |
|
|
|
|
def _capsule_build_feats(obj: nn.Module, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
    """Per-level features: raw pose for box, pose + bypass for cls, centered symbolic prior.

    Both learned mixing scales (``cls_beta``, ``sym_beta``) are squashed through
    tanh so they stay bounded in (-1, 1).
    """
    box_feats: list[torch.Tensor] = []
    cls_feats: list[torch.Tensor] = []
    cls_priors: list[torch.Tensor] = []
    head_name = obj.__class__.__name__
    for level, feat in enumerate(x):
        pose, act = _capsule_split_pose_act(feat, obj.k_list[level], obj.d_list[level], head_name, level)
        bypass_scale = torch.tanh(obj.cls_beta[level])
        cls_feats.append(pose + bypass_scale * obj.cls_bypass[level](act))
        # Symbolic prior: dropout -> 1x1 conv -> norm -> zero-mean across the class dim.
        prior = obj.sym_norm[level](obj.sym_prior[level](obj.sym_dropout[level](act)))
        prior = prior - prior.mean(dim=1, keepdim=True)
        cls_priors.append(torch.tanh(obj.sym_beta[level]) * prior)
        box_feats.append(pose)
    return box_feats, cls_feats, cls_priors
|
|
|
|
def _capsule_build_feats_gated(
    obj: nn.Module, x: list[torch.Tensor]
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
    """Like ``_capsule_build_feats`` but the cls path gates pose multiplicatively.

    ``cls_in = pose * (1 + tanh(cls_beta) * sigmoid(cls_bypass(act)))`` replaces
    the additive bypass; the symbolic-prior branch is unchanged.
    """
    box_feats: list[torch.Tensor] = []
    cls_feats: list[torch.Tensor] = []
    cls_priors: list[torch.Tensor] = []
    head_name = obj.__class__.__name__
    for level, feat in enumerate(x):
        pose, act = _capsule_split_pose_act(feat, obj.k_list[level], obj.d_list[level], head_name, level)
        gate = torch.sigmoid(obj.cls_bypass[level](act))
        cls_feats.append(pose * (1.0 + torch.tanh(obj.cls_beta[level]) * gate))

        dropped = obj.sym_dropout[level](act)
        prior = obj.sym_norm[level](obj.sym_prior[level](dropped))
        prior = prior - prior.mean(dim=1, keepdim=True)  # zero-mean across classes
        cls_priors.append(torch.tanh(obj.sym_beta[level]) * prior)

        box_feats.append(pose)
    return box_feats, cls_feats, cls_priors
|
|
|
|
def _capsule_build_feats_boxcls(
    obj: nn.Module, x: list[torch.Tensor]
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
    """Box and cls heads both read the raw pose; only the centered symbolic
    prior (built from activations) is later added to the cls logits."""
    box_feats: list[torch.Tensor] = []
    cls_feats: list[torch.Tensor] = []
    cls_priors: list[torch.Tensor] = []
    head_name = obj.__class__.__name__
    for level, feat in enumerate(x):
        pose, act = _capsule_split_pose_act(feat, obj.k_list[level], obj.d_list[level], head_name, level)
        prior = obj.sym_norm[level](obj.sym_prior[level](obj.sym_dropout[level](act)))
        prior = prior - prior.mean(dim=1, keepdim=True)  # zero-mean across classes
        box_feats.append(pose)
        cls_feats.append(pose)
        cls_priors.append(torch.tanh(obj.sym_beta[level]) * prior)
    return box_feats, cls_feats, cls_priors
|
|
|
|
def _capsule_build_feats_boxcls_simpleprior(
    obj: nn.Module, x: list[torch.Tensor]
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
    """Raw-pose box/cls features with a simplified symbolic prior:
    no GroupNorm and no per-class centering, just dropout -> conv -> tanh scale."""
    box_feats: list[torch.Tensor] = []
    cls_feats: list[torch.Tensor] = []
    cls_priors: list[torch.Tensor] = []
    head_name = obj.__class__.__name__
    for level, feat in enumerate(x):
        pose, act = _capsule_split_pose_act(feat, obj.k_list[level], obj.d_list[level], head_name, level)
        prior = obj.sym_prior[level](obj.sym_dropout[level](act))
        box_feats.append(pose)
        cls_feats.append(pose)
        cls_priors.append(torch.tanh(obj.sym_beta[level]) * prior)
    return box_feats, cls_feats, cls_priors
|
|
|
|
def _capsule_build_feats_open_vocab(
    obj: nn.Module, x: list[torch.Tensor]
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
    """Per-level pose/cls features plus raw activations for the open-vocab head.

    When ``obj.with_act_gate`` is truthy, the cls feature is pose modulated by an
    activation-driven gate; otherwise the cls feature is the raw pose.
    """
    box_feats: list[torch.Tensor] = []
    cls_feats: list[torch.Tensor] = []
    acts: list[torch.Tensor] = []
    head_name = obj.__class__.__name__
    for level, feat in enumerate(x):
        pose, act = _capsule_split_pose_act(feat, obj.k_list[level], obj.d_list[level], head_name, level)
        if getattr(obj, "with_act_gate", False):
            gate = torch.sigmoid(obj.ov_gate[level](act))
            cls_in = pose * (1.0 + torch.tanh(obj.ov_beta[level]) * gate)
        else:
            cls_in = pose
        box_feats.append(pose)
        cls_feats.append(cls_in)
        acts.append(act)
    return box_feats, cls_feats, acts
|
|
|
|
class CapsuleDetectv5(Detect):
    """Capsule Detect v5: box uses raw pose, cls uses stabilized symbolic prior."""

    def __init__(
        self,
        nc: int = 80,
        *args,
        reg_max: int = 16,
        end2end: bool = False,
        k: list[int] | tuple[int, ...] = (4, 8, 16),
        d: list[int] | tuple[int, ...] = (16, 16, 16),
        ch: tuple = (),
    ):
        """Parse YAML-style positional config, validate the capsule layout and build sub-heads."""
        cfg = list(args)
        if cfg and isinstance(cfg[-1], (list, tuple)):
            ch = tuple(cfg.pop(-1))  # trailing list/tuple is the per-level channels

        if len(cfg) not in (2, 4):
            raise ValueError("CapsuleDetectv5 expects [k_list, d_list, reg_max, end2end, ch].")

        k, d = cfg[0], cfg[1]
        if len(cfg) == 4:
            reg_max = int(cfg[2])
            end2end = bool(cfg[3])

        self.k_list, self.d_list, merged_ch = _setup_capsule_layout(k, d, ch, "CapsuleDetectv5")
        super().__init__(nc=nc, reg_max=reg_max, end2end=end2end, ch=merged_ch)
        _init_capsule_semantic_heads(self)

    def _split_pose_act(self, x: torch.Tensor, i: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Split one packed level into (pose, activation)."""
        return _capsule_split_pose_act(x, self.k_list[i], self.d_list[i], "CapsuleDetectv5", i)

    def _build_feats(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """Per-level (box_feats, cls_feats, cls_priors); see ``_capsule_build_feats``."""
        return _capsule_build_feats(self, x)

    def forward_head(
        self, x: list[torch.Tensor], box_head: torch.nn.Module = None, cls_head: torch.nn.Module = None
    ) -> dict[str, torch.Tensor]:
        """Run box/cls sub-heads; per-level symbolic priors are added to the cls logits.

        Returns an empty dict when either sub-head is missing.
        """
        if box_head is None or cls_head is None:
            return {}

        box_feats, cls_feats, cls_priors = self._build_feats(x)
        bs = x[0].shape[0]

        box_list, cls_list = [], []
        for i in range(self.nl):
            box_list.append(box_head[i](box_feats[i]).view(bs, 4 * self.reg_max, -1))
            cls_list.append((cls_head[i](cls_feats[i]) + cls_priors[i]).view(bs, self.nc, -1))
        return dict(boxes=torch.cat(box_list, dim=-1), scores=torch.cat(cls_list, dim=-1), feats=x)
|
|
|
|
class CapsuleDetectv6(CapsuleDetectv5):
    """Capsule Detect v6: replace additive cls correction with multiplicative act gate."""

    def _build_feats(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """Same layout as v5, but cls features are pose gated by activations."""
        return _capsule_build_feats_gated(self, x)
|
|
|
|
class CapsuleDetectv7(CapsuleDetectv5):
    """Capsule Detect v7: cls head consumes raw pose features plus symbolic priors only."""

    def __init__(self, *args, **kwargs):
        """Initialize the v5 head, then add optional per-stage wall-clock profiling state."""
        super().__init__(*args, **kwargs)
        # Profiling is opt-in: set ``profile_head = True`` on the instance after construction.
        self.profile_head = False
        self._head_profile: dict[str, float] = {}  # accumulated per-stage times (ms)
        self._head_profile_calls = 0  # number of profiled forward_head calls

    def _ensure_profile_attrs(self) -> None:
        """Backfill profiling attributes missing on instances created before they existed
        (e.g. deserialized checkpoints)."""
        if not hasattr(self, "profile_head"):
            self.profile_head = False
        if not hasattr(self, "_head_profile"):
            self._head_profile = {}
        if not hasattr(self, "_head_profile_calls"):
            self._head_profile_calls = 0

    def reset_head_profile(self) -> None:
        """Zero all per-stage accumulators and the call counter."""
        self._ensure_profile_attrs()
        self._head_profile = {
            "split_pose_act_ms": 0.0,
            "cls_prior_ms": 0.0,
            "box_head_ms": 0.0,
            "cls_head_ms": 0.0,
            "cat_ms": 0.0,
        }
        self._head_profile_calls = 0

    def get_head_profile(self) -> dict[str, float]:
        """Return accumulated timings plus ``calls``, ``total_ms`` and per-call ``*_avg_ms``.

        Returns an empty dict when no profiling data has been collected yet.
        """
        self._ensure_profile_attrs()
        if not self._head_profile:
            return {}
        out = dict(self._head_profile)
        calls = max(self._head_profile_calls, 1)  # avoid division by zero before the first call
        out["calls"] = float(self._head_profile_calls)
        # total_ms is computed before the *_avg_ms keys are added, so it sums raw stages only.
        out["total_ms"] = sum(v for k, v in out.items() if k.endswith("_ms"))
        for key in list(self._head_profile):
            out[key.replace("_ms", "_avg_ms")] = self._head_profile[key] / calls
        return out

    def _sync_profile(self) -> None:
        """Synchronize CUDA around timed regions so perf_counter measures real GPU work.

        NOTE(review): synchronizes whenever CUDA is available, even if the tensors
        live on CPU — harmless but adds overhead; confirm intent for CPU profiling.
        """
        self._ensure_profile_attrs()
        if self.profile_head and torch.cuda.is_available():
            torch.cuda.synchronize()

    def _build_feats(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """Per-level (box, cls, prior) features; instrumented copy when profiling is on."""
        self._ensure_profile_attrs()
        if not self.profile_head:
            # Fast path: semantics identical to the instrumented loop below.
            return _capsule_build_feats_boxcls(self, x)

        if not self._head_profile:
            self.reset_head_profile()

        # Instrumented re-implementation of _capsule_build_feats_boxcls.
        box_feats, cls_feats, cls_priors = [], [], []
        cls_name = self.__class__.__name__
        for i, xi in enumerate(x):
            self._sync_profile()
            t0 = time.perf_counter()
            pose, act = _capsule_split_pose_act(xi, self.k_list[i], self.d_list[i], cls_name, i)
            self._sync_profile()
            self._head_profile["split_pose_act_ms"] += (time.perf_counter() - t0) * 1000.0

            self._sync_profile()
            t0 = time.perf_counter()
            act_s = self.sym_dropout[i](act)
            prior = self.sym_prior[i](act_s)
            prior = self.sym_norm[i](prior)
            prior = prior - prior.mean(dim=1, keepdim=True)  # zero-mean across classes
            sym_scale = torch.tanh(self.sym_beta[i])  # bounded mixing scale
            cls_prior = sym_scale * prior
            self._sync_profile()
            self._head_profile["cls_prior_ms"] += (time.perf_counter() - t0) * 1000.0

            box_feats.append(pose)
            cls_feats.append(pose)  # cls consumes the raw pose, same as box
            cls_priors.append(cls_prior)
        return box_feats, cls_feats, cls_priors

    def forward_head(
        self, x: list[torch.Tensor], box_head: torch.nn.Module = None, cls_head: torch.nn.Module = None
    ) -> dict[str, torch.Tensor]:
        """Run box/cls sub-heads; per-stage timings are accumulated when profiling is on.

        Returns an empty dict when either sub-head is missing.
        """
        self._ensure_profile_attrs()
        if box_head is None or cls_head is None:
            return dict()

        box_feats, cls_feats, cls_priors = self._build_feats(x)
        bs = x[0].shape[0]

        if not self.profile_head:
            # Untimed path: same outputs as the instrumented loop below.
            box_list = [box_head[i](box_feats[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)]
            cls_list = [
                (cls_head[i](cls_feats[i]) + cls_priors[i]).view(bs, self.nc, -1)
                for i in range(self.nl)
            ]
            boxes = torch.cat(box_list, dim=-1)
            scores = torch.cat(cls_list, dim=-1)
            return dict(boxes=boxes, scores=scores, feats=x)

        if not self._head_profile:
            self.reset_head_profile()
        self._head_profile_calls += 1

        box_list, cls_list = [], []
        for i in range(self.nl):
            self._sync_profile()
            t0 = time.perf_counter()
            box_i = box_head[i](box_feats[i]).view(bs, 4 * self.reg_max, -1)
            self._sync_profile()
            self._head_profile["box_head_ms"] += (time.perf_counter() - t0) * 1000.0
            box_list.append(box_i)

            self._sync_profile()
            t0 = time.perf_counter()
            cls_i = (cls_head[i](cls_feats[i]) + cls_priors[i]).view(bs, self.nc, -1)
            self._sync_profile()
            self._head_profile["cls_head_ms"] += (time.perf_counter() - t0) * 1000.0
            cls_list.append(cls_i)

        self._sync_profile()
        t0 = time.perf_counter()
        boxes = torch.cat(box_list, dim=-1)
        scores = torch.cat(cls_list, dim=-1)
        self._sync_profile()
        self._head_profile["cat_ms"] += (time.perf_counter() - t0) * 1000.0
        return dict(boxes=boxes, scores=scores, feats=x)
|
|
|
|
class CapsuleDetectv8(CapsuleDetectv5):
    """Capsule Detect v8: raw pose cls path with simplified cls_prior (no norm/centering)."""

    def _build_feats(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """Same layout as v5, but the symbolic prior skips GroupNorm and class-centering."""
        return _capsule_build_feats_boxcls_simpleprior(self, x)
|
|
|
|
class CapsuleOpenVocabDetect(Detect):
    """Capsule detection head with open-vocabulary classification via text embedding matching."""

    def __init__(
        self,
        nc: int = 80,
        *args,
        reg_max: int = 16,
        end2end: bool = False,
        embed: int = 256,
        with_act_gate: bool = False,
        with_objectness_prior: bool = True,
        k: list[int] | tuple[int, ...] = (4, 8, 16),
        d: list[int] | tuple[int, ...] = (16, 16, 16),
        ch: tuple = (),
    ):
        """Build the open-vocabulary capsule head.

        Args:
            nc: Number of classes for the parent head (runtime OV class count may differ).
            *args: Positional YAML-style config; a trailing list/tuple is consumed as ``ch``.
                After popping ``ch``, 2, 4 or 7 positionals are accepted (see below).
            reg_max: DFL bins per box side.
            end2end: Enable the one2one branch.
            embed: Dimension of the visual/text joint embedding space.
            with_act_gate: Gate cls pose features with activations before embedding.
            with_objectness_prior: Scale embeddings by a per-level objectness map.
            k: Per-level capsule counts.
            d: Per-level pose dims.
            ch: Per-level input channels; each must equal ``k[i] * (d[i] + 1)``.
        """
        parsed = list(args)
        if parsed and isinstance(parsed[-1], (list, tuple)):
            ch = tuple(parsed.pop(-1))

        if len(parsed) not in (2, 4, 7):
            raise ValueError(
                "CapsuleOpenVocabDetect expects [k_list, d_list, (reg_max, end2end, embed, with_act_gate, "
                "with_objectness_prior), ch]."
            )

        k, d = parsed[0], parsed[1]
        if len(parsed) == 4:
            reg_max = int(parsed[2])
            end2end = bool(parsed[3])
        elif len(parsed) == 7:
            # Two 7-positional orderings are supported:
            #   A: [k, d, embed, with_act_gate, with_objectness_prior, reg_max, end2end]
            #   B: [k, d, reg_max, end2end, embed, with_act_gate, with_objectness_prior]
            # Heuristic: in ordering A, positions 3, 4 and 6 are all strictly bool.
            # NOTE(review): the heuristic is ambiguous if a config passes bool-typed
            # values in B's slots 3/4/6 — verify against actual YAML callers.
            if type(parsed[3]) is bool and type(parsed[4]) is bool and type(parsed[6]) is bool:
                embed = int(parsed[2])
                with_act_gate = bool(parsed[3])
                with_objectness_prior = bool(parsed[4])
                reg_max = int(parsed[5])
                end2end = bool(parsed[6])
            else:
                reg_max = int(parsed[2])
                end2end = bool(parsed[3])
                embed = int(parsed[4])
                with_act_gate = bool(parsed[5])
                with_objectness_prior = bool(parsed[6])

        self.k_list, self.d_list, merged_ch = _setup_capsule_layout(k, d, ch, "CapsuleOpenVocabDetect")
        super().__init__(nc=nc, reg_max=reg_max, end2end=end2end, ch=merged_ch)

        self.embed = int(embed)
        self.with_act_gate = bool(with_act_gate)
        self.with_objectness_prior = bool(with_objectness_prior)

        self.emb_head = nn.ModuleList()  # per level: pose -> embed-dim visual embedding
        self.ov_gate = nn.ModuleList()  # per level: act -> pose-sized gate (Identity when disabled)
        self.ov_beta = nn.ParameterList()  # per level: gate strength (frozen 0 when disabled)
        self.obj_prior = nn.ModuleList()  # per level: act -> objectness logit (Identity when disabled)

        for k_i, d_i in zip(self.k_list, self.d_list):
            pose_ch = k_i * d_i
            self.emb_head.append(
                nn.Sequential(
                    Conv(pose_ch, pose_ch, 3),
                    DWConv(pose_ch, pose_ch, 3),
                    nn.Conv2d(pose_ch, self.embed, 1, bias=True),
                )
            )
            if self.with_act_gate:
                hidden = max(16, k_i * 2)
                self.ov_gate.append(
                    nn.Sequential(
                        nn.Conv2d(k_i, hidden, 1, bias=True),
                        nn.SiLU(inplace=True),
                        nn.Conv2d(hidden, pose_ch, 1, bias=True),
                    )
                )
                self.ov_beta.append(nn.Parameter(torch.tensor(0.1)))
            else:
                # Placeholders keep per-level indexing valid when gating is off.
                self.ov_gate.append(nn.Identity())
                self.ov_beta.append(nn.Parameter(torch.tensor(0.0), requires_grad=False))

            if self.with_objectness_prior:
                self.obj_prior.append(nn.Conv2d(k_i, 1, 1, bias=True))
            else:
                self.obj_prior.append(nn.Identity())

        # CLIP-style learnable temperature, initialized to log(1/0.07).
        self.logit_scale = nn.Parameter(torch.tensor(math.log(1 / 0.07), dtype=torch.float32))
        # Non-persistent cache of normalized text embeddings used at inference.
        self.register_buffer("cached_text_embeddings", torch.empty(0), persistent=False)

    def set_text_embeddings(self, text_embs: torch.Tensor | None) -> None:
        """Cache normalized text embeddings for inference."""
        if text_embs is None:
            # Clearing the cache: store an empty tensor on the module's device.
            self.cached_text_embeddings = torch.empty(0, device=self.logit_scale.device)
            return
        if text_embs.ndim != 2:
            raise ValueError(f"text_embs must be 2D [num_classes, embed_dim], got shape {tuple(text_embs.shape)}.")
        self.cached_text_embeddings = F.normalize(text_embs.detach().to(self.logit_scale.device), dim=-1)

    def _split_pose_act(self, x: torch.Tensor, i: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Split one packed level into (pose, activation)."""
        return _capsule_split_pose_act(x, self.k_list[i], self.d_list[i], "CapsuleOpenVocabDetect", i)

    def _build_feats(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """Per-level (box_feats, cls_feats, acts); see ``_capsule_build_feats_open_vocab``."""
        return _capsule_build_feats_open_vocab(self, x)

    def _prepare_text_embeddings(self, text_embs: torch.Tensor | None, bs: int, device: torch.device) -> torch.Tensor | None:
        """Resolve, broadcast and normalize text embeddings to ``(bs, num_classes, embed)``.

        Falls back to the cached embeddings when ``text_embs`` is None; returns
        None when neither is available.
        """
        if text_embs is None:
            if self.cached_text_embeddings.numel() == 0:
                return None
            text = self.cached_text_embeddings
        else:
            text = text_embs

        if text.ndim == 2:
            # Shared class set: broadcast across the batch without copying.
            text = text.unsqueeze(0).expand(bs, -1, -1)
        elif text.ndim != 3:
            raise ValueError(f"text_embs must be 2D or 3D, got shape {tuple(text.shape)}.")

        if text.shape[-1] != self.embed:
            raise ValueError(f"text_embs last dim must equal embed={self.embed}, got {text.shape[-1]}.")
        # Re-normalizing cached (already normalized) embeddings is a harmless no-op.
        return F.normalize(text.to(device=device, dtype=self.logit_scale.dtype), dim=-1)

    def _compute_ov_scores(
        self, cls_feats: list[torch.Tensor], acts: list[torch.Tensor], text_embs: torch.Tensor | None
    ) -> tuple[torch.Tensor | None, list[torch.Tensor], torch.Tensor | None]:
        """Project cls features to the joint space and score them against text embeddings.

        Returns:
            scores: ``(bs, num_classes, num_anchors)`` similarity logits, or None
                when no text embeddings are available.
            level_embeddings: Per-level visual embedding maps (pre-normalization).
            text: The prepared text embeddings actually used, or None.
        """
        bs = cls_feats[0].shape[0]
        level_embeddings = []
        for i in range(self.nl):
            emb = self.emb_head[i](cls_feats[i])
            if self.with_objectness_prior:
                # Scale embeddings by (1 + objectness) so confident locations dominate.
                emb = emb * (1.0 + torch.sigmoid(self.obj_prior[i](acts[i])))
            level_embeddings.append(emb)

        text = self._prepare_text_embeddings(text_embs, bs, cls_feats[0].device)
        if text is None:
            return None, level_embeddings, None

        # Flatten all levels into one token sequence of L2-normalized visual embeddings.
        visual_tokens = torch.cat(
            [F.normalize(emb.flatten(2).transpose(1, 2), dim=-1) for emb in level_embeddings],
            dim=1,
        )
        # Clamp the temperature like CLIP to keep logits bounded.
        scale = self.logit_scale.exp().clamp(max=100.0)
        scores = torch.einsum("bnd,bcd->bcn", visual_tokens, text) * scale
        return scores, level_embeddings, text

    def forward_head(
        self,
        x: list[torch.Tensor],
        text_embs: torch.Tensor | None = None,
        box_head: torch.nn.Module = None,
        cls_head: torch.nn.Module = None,
    ) -> dict[str, torch.Tensor]:
        """Run the box sub-head and OV scoring; the conventional cls_head is unused here.

        Returns an empty dict when ``box_head`` is missing. ``scores`` and
        ``text_embeddings`` are present only when text embeddings were resolved.
        """
        del cls_head
        if box_head is None:
            return dict()

        box_feats, cls_feats, acts = self._build_feats(x)
        bs = x[0].shape[0]
        boxes = torch.cat([box_head[i](box_feats[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)], dim=-1)
        scores, level_embeddings, text = self._compute_ov_scores(cls_feats, acts, text_embs)

        preds = {
            "boxes": boxes,
            "embeddings": level_embeddings,
            "cls_feats": cls_feats,
            "acts": acts,
            "feats": x,
        }
        if scores is not None:
            preds["scores"] = scores
            preds["text_embeddings"] = text
        return preds

    def forward(
        self, x: list[torch.Tensor], text_embs: torch.Tensor | None = None
    ) -> dict[str, torch.Tensor] | torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Full forward: one2many (and optionally one2one) heads, then decoded inference.

        Raises:
            ValueError: At inference when no text embeddings are given or cached.
        """
        preds = self.forward_head(x, text_embs=text_embs, **self.one2many)
        if self.end2end:
            # one2one branch runs on detached features so its gradients stay separate.
            x_detach = [xi.detach() for xi in x]
            one2one = self.forward_head(x_detach, text_embs=text_embs, **self.one2one)
            preds = {"one2many": preds, "one2one": one2one}
        if self.training:
            return preds

        infer_preds = preds["one2one"] if self.end2end else preds
        if "scores" not in infer_preds:
            raise ValueError("CapsuleOpenVocabDetect inference requires text_embs or cached text embeddings.")

        # Temporarily widen nc to the runtime OV class count so _inference decodes
        # the score tensor; restored in finally.
        # NOTE(review): mutating self.nc makes this non-reentrant / not thread-safe.
        original_nc = self.nc
        self.nc = int(infer_preds["scores"].shape[1])
        try:
            y = self._inference(infer_preds)
            if self.end2end:
                y = self.postprocess(y.permute(0, 2, 1))
        finally:
            self.nc = original_nc
        return y if self.export else (y, preds)
|
|
|
|
class CapsuleSegmentv1(Segment):
    """Capsule-style Segment head aligned with CapsuleDetectv6 semantics."""

    def __init__(
        self,
        nc: int = 80,
        *args,
        nm: int = 32,
        npr: int = 256,
        reg_max: int = 16,
        end2end: bool = False,
        k: list[int] | tuple[int, ...] = (4, 8, 16),
        d: list[int] | tuple[int, ...] = (16, 16, 16),
        ch: tuple = (),
    ):
        """Build the capsule segmentation head.

        Args:
            nc: Number of classes.
            *args: Positional YAML-style config; a trailing list/tuple is consumed as ``ch``.
                Accepted lengths after popping ``ch``: 2 ``[k, d]``; 4, which is either
                ``[k, d, reg_max, end2end]`` (when the 4th value is bool) or
                ``[k, d, nm, npr]``; or 6 ``[k, d, nm, npr, reg_max, end2end]``.
            nm: Number of mask coefficients.
            npr: Proto channel width.
            reg_max: DFL bins per box side.
            end2end: Enable the one2one branch.
            k: Per-level capsule counts.
            d: Per-level pose dims.
            ch: Per-level input channels; each must equal ``k[i] * (d[i] + 1)``.
        """
        parsed = list(args)
        if parsed and isinstance(parsed[-1], (list, tuple)):
            ch = tuple(parsed.pop(-1))

        if len(parsed) not in (2, 4, 6):
            raise ValueError("CapsuleSegmentv1 expects [k_list, d_list, (nm, npr), reg_max, end2end, ch].")

        k, d = parsed[0], parsed[1]
        if len(parsed) == 4:
            # Disambiguate the two 4-positional forms: a bool in slot 3 means
            # [reg_max, end2end]; otherwise it is [nm, npr].
            if isinstance(parsed[3], bool):
                reg_max = int(parsed[2])
                end2end = bool(parsed[3])
            else:
                nm = int(parsed[2])
                npr = int(parsed[3])
        elif len(parsed) == 6:
            nm = int(parsed[2])
            npr = int(parsed[3])
            reg_max = int(parsed[4])
            end2end = bool(parsed[5])

        self.k_list, self.d_list, merged_ch = _setup_capsule_layout(k, d, ch, "CapsuleSegmentv1")
        super().__init__(nc=nc, nm=nm, npr=npr, reg_max=reg_max, end2end=end2end, ch=merged_ch)
        _init_capsule_semantic_heads(self)
        # Replace the parent's proto module with one sized on pose-only channels.
        self.proto = Proto26(merged_ch, self.npr, self.nm, nc)

    def _split_pose_act(self, x: torch.Tensor, i: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Split one packed level into (pose, activation)."""
        return _capsule_split_pose_act(x, self.k_list[i], self.d_list[i], "CapsuleSegmentv1", i)

    def _build_feats(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """Per-level (box_feats, cls_feats, cls_priors) using the v6-style act gate."""
        return _capsule_build_feats_gated(self, x)

    def forward_head(
        self,
        x: list[torch.Tensor],
        box_head: torch.nn.Module = None,
        cls_head: torch.nn.Module = None,
        mask_head: torch.nn.Module = None,
    ) -> dict[str, torch.Tensor]:
        """Run box/cls (and optional mask-coefficient) sub-heads.

        Returns an empty dict when box_head or cls_head is missing. Unlike the
        detect heads, ``feats`` here is the gated cls features, not the raw input
        — presumably so downstream consumers (e.g. proto) see the same features;
        confirm against callers.
        """
        if box_head is None or cls_head is None:
            return dict()

        box_feats, cls_feats, cls_priors = self._build_feats(x)
        bs = x[0].shape[0]
        boxes = torch.cat([box_head[i](box_feats[i]).view(bs, 4 * self.reg_max, -1) for i in range(self.nl)], dim=-1)
        scores = torch.cat(
            [(cls_head[i](cls_feats[i]) + cls_priors[i]).view(bs, self.nc, -1) for i in range(self.nl)],
            dim=-1,
        )
        preds = dict(boxes=boxes, scores=scores, feats=cls_feats)
        if mask_head is not None:
            preds["mask_coefficient"] = torch.cat(
                [mask_head[i](cls_feats[i]).view(bs, self.nm, -1) for i in range(self.nl)],
                dim=-1,
            )
        return preds

    def forward(self, x: list[torch.Tensor]) -> tuple | list[torch.Tensor] | dict[str, torch.Tensor]:
        """Detection forward plus prototype masks built from gated cls features.

        NOTE(review): ``_build_feats`` is computed here for the proto input and
        again inside ``Detect.forward`` -> ``forward_head`` — duplicated work per
        call; consider reusing one result if profiling shows it matters.
        """
        _, cls_feats, _ = self._build_feats(x)
        outputs = Detect.forward(self, x)
        preds = outputs[1] if isinstance(outputs, tuple) else outputs
        proto_in = cls_feats
        proto = self.proto(proto_in)
        if isinstance(preds, dict):
            if self.end2end:
                preds["one2many"]["proto"] = proto
                # one2one branch gets detached protos so its gradients stay separate.
                preds["one2one"]["proto"] = tuple(p.detach() for p in proto) if isinstance(proto, tuple) else proto.detach()
            else:
                preds["proto"] = proto
        if self.training:
            return preds
        return (outputs, proto) if self.export else ((outputs[0], proto), preds)
|
|
|
|
class CapsuleSegmentv2(CapsuleSegmentv1):
    """Capsule Segment v2: cls head consumes raw pose features and symbolic priors only."""

    def _build_feats(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """Same layout as v1 but without the activation gate on the cls path."""
        return _capsule_build_feats_boxcls(self, x)
|
|
|
|
class CapsuleSegmentv3(CapsuleSegmentv1):
    """Capsule Segment v3: raw pose cls path with simplified cls_prior (no norm/centering)."""

    def _build_feats(self, x: list[torch.Tensor]) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
        """Same layout as v2, but the symbolic prior skips GroupNorm and class-centering."""
        return _capsule_build_feats_boxcls_simpleprior(self, x)
|
|
|
|