| | |
| | |
| |
|
class CapacityHead(nn.Module):
    """Evidence-capacity gate.

    Measures scalar "evidence" in the input against a learned positive
    capacity, then splits a feature projection into a retained part
    (scaled by how full the capacity is) and an overflow part (scaled by
    how far evidence exceeds the capacity).
    """

    def __init__(self, in_dim, feat_dim, init_capacity=1.0):
        super().__init__()
        # Capacity is stored pre-softplus so it remains strictly positive
        # during training; initialize with the softplus inverse so that
        # self.capacity == init_capacity at step 0. Requires init_capacity > 0.
        inv_softplus = math.log(math.exp(init_capacity) - 1)
        self._raw_capacity = nn.Parameter(torch.tensor(inv_softplus))

        # NOTE: submodule creation order matters for RNG-reproducible init.
        self.evidence_net = nn.Sequential(
            nn.Linear(in_dim, feat_dim), nn.GELU(), nn.Linear(feat_dim, 1))
        self.feature_net = nn.Sequential(
            nn.Linear(in_dim, feat_dim), nn.GELU(), nn.Linear(feat_dim, feat_dim))
        self.retain_gate = nn.Sequential(
            nn.Linear(feat_dim + 1, feat_dim), nn.Sigmoid())
        self.overflow_gate = nn.Sequential(
            nn.Linear(feat_dim + 1, feat_dim), nn.Sigmoid())

    @property
    def capacity(self):
        # Softplus keeps the effective capacity strictly positive.
        return F.softplus(self._raw_capacity)

    def forward(self, x):
        """Returns (fill, overflow, retained, capacity, raw_evidence)."""
        cap = self.capacity
        evidence = F.relu(self.evidence_net(x))
        denom = cap + 1e-8
        # fill in [0, 1]: fraction of capacity consumed.
        fill = (evidence / denom).clamp(max=1.0)
        # sat >= 0: relative excess above capacity (unbounded above).
        sat = ((evidence - cap) / denom).clamp(min=0.0)
        feat = self.feature_net(x)
        retained = self.retain_gate(torch.cat([feat, fill], dim=-1)) * feat * fill
        overflow = (self.overflow_gate(torch.cat([feat, sat], dim=-1))
                    * feat * sat.clamp(max=1.0))
        return fill, overflow, retained, cap, evidence
| |
|
| |
|
| | |
| |
|
class DifferentiationGate(nn.Module):
    """
    Curvature direction analysis via occupancy field differentiation.

    Computes gradient and Laplacian of the 3D occupancy field to determine:
    - Curvature direction: convex (normals point outward) vs concave (inward)
    - Curvature alternation: where sign flips (saddle points, torus inner/outer)
    - Perturbation robustness: smoothed gradient features survive noise

    The key insight: a hemisphere and bowl occupy nearly identical voxels,
    but their occupancy gradients point in opposite directions relative
    to the center of mass. The Laplacian's sign distinguishes them.

    Outputs gate signals that modulate curvature features:
    - direction_gate: learned weighting based on gradient analysis
    - alternation_score: how much curvature sign varies spatially
    - directional_features: rich features encoding curvature orientation
    """

    def __init__(self, embed_dim=64):
        super().__init__()

        # Fixed (non-learned) finite-difference stencils evaluated by one
        # conv3d call. Channels 0-2: central differences along the three
        # spatial axes; channel 3: 6-neighbour discrete Laplacian.
        diff_kernels = torch.zeros(4, 1, 3, 3, 3)
        # d/dx (first spatial axis)
        diff_kernels[0, 0, 0, 1, 1] = -1; diff_kernels[0, 0, 2, 1, 1] = 1
        # d/dy
        diff_kernels[1, 0, 1, 0, 1] = -1; diff_kernels[1, 0, 1, 2, 1] = 1
        # d/dz
        diff_kernels[2, 0, 1, 1, 0] = -1; diff_kernels[2, 0, 1, 1, 2] = 1
        # Laplacian: -6 at center, +1 at each of the six face neighbours.
        diff_kernels[3, 0, 1, 1, 1] = -6
        diff_kernels[3, 0, 0, 1, 1] = 1; diff_kernels[3, 0, 2, 1, 1] = 1
        diff_kernels[3, 0, 1, 0, 1] = 1; diff_kernels[3, 0, 1, 2, 1] = 1
        diff_kernels[3, 0, 1, 1, 0] = 1; diff_kernels[3, 0, 1, 1, 2] = 1
        self.register_buffer("diff_kernels", diff_kernels)

        # Integer voxel coordinates of the GS^3 grid. GS is a module-level
        # constant; the forward docstring suggests GS == 5 — TODO confirm.
        coords = torch.stack(torch.meshgrid(
            torch.arange(GS, dtype=torch.float32),
            torch.arange(GS, dtype=torch.float32),
            torch.arange(GS, dtype=torch.float32),
            indexing="ij"), dim=-1)
        self.register_buffer("coords", coords)

        # Hand-crafted scalar features assembled in forward():
        # direction histogram (3) + Laplacian-sign histogram (3) +
        # alternation (1) + gradient asymmetry (3) + radial gradient (5).
        raw_feat_dim = 3 + 3 + 1 + 3 + 5

        # Learned conv summary of the Laplacian field (1 channel in).
        self.lap_conv = nn.Sequential(
            nn.Conv3d(1, 16, 3, padding=1), nn.GELU(),
            nn.Conv3d(16, 16, 3, padding=1), nn.GELU(),
            nn.AdaptiveAvgPool3d(2))
        lap_conv_dim = 16 * 8  # 16 channels x 2x2x2 pooled cells

        # Learned conv summary of the 3-channel gradient field.
        self.grad_conv = nn.Sequential(
            nn.Conv3d(3, 16, 3, padding=1), nn.GELU(),
            nn.Conv3d(16, 16, 3, padding=1), nn.GELU(),
            nn.AdaptiveAvgPool3d(2))
        grad_conv_dim = 16 * 8

        total_feat_dim = raw_feat_dim + lap_conv_dim + grad_conv_dim

        # Sigmoid gate applied multiplicatively to curvature features.
        self.direction_net = nn.Sequential(
            SwiGLU(total_feat_dim, embed_dim),
            nn.Linear(embed_dim, embed_dim), nn.Sigmoid())

        # Additive directional feature branch.
        self.direction_feat_net = nn.Sequential(
            SwiGLU(total_feat_dim, embed_dim),
            nn.Linear(embed_dim, embed_dim))

    def forward(self, grid):
        """
        grid: (B, 5, 5, 5) binary occupancy

        Returns:
            direction_gate: (B, embed_dim) sigmoid gate for curvature features
            direction_feat: (B, embed_dim) additive directional features
            alternation_score: (B, 1) how much curvature alternates
        """
        B = grid.shape[0]
        device = grid.device
        vox = grid.unsqueeze(1)  # (B, 1, GS, GS, GS) for conv/pool ops

        # 3x3x3 box smoothing with replicate padding so the finite
        # differences below are robust to single-voxel perturbations.
        vox_smooth = F.avg_pool3d(
            F.pad(vox, (1,1,1,1,1,1), mode='replicate'),
            kernel_size=3, stride=1, padding=0)

        # Single conv evaluates all four stencils: gradient (ch 0-2) and
        # Laplacian (ch 3).
        diff = F.conv3d(vox_smooth, self.diff_kernels, padding=1)
        grad_field = diff[:, :3]
        gx, gy, gz = diff[:, 0:1], diff[:, 1:2], diff[:, 2:3]
        lap = diff[:, 3:4]

        # Per-sample occupancy centroid (clamp avoids div-by-zero on empty grids).
        flat_grid = grid.reshape(B, -1)
        flat_coords = self.coords.reshape(-1, 3)
        total_occ = flat_grid.sum(dim=-1, keepdim=True).clamp(min=1)
        centroids = (flat_grid.unsqueeze(-1) * flat_coords.unsqueeze(0)).sum(dim=1) / total_occ

        # Project each voxel's gradient onto the unit vector away from the
        # centroid: positive dot product = outward-pointing (convex-like).
        grad_flat = grad_field.reshape(B, 3, -1).permute(0, 2, 1)
        diff_from_center = flat_coords.unsqueeze(0) - centroids.unsqueeze(1)
        diff_norm = diff_from_center / (diff_from_center.norm(dim=-1, keepdim=True) + 1e-8)
        dot_products = (grad_flat * diff_norm).sum(dim=-1)
        grad_mag = grad_flat.norm(dim=-1)
        # Only occupied voxels with a non-negligible gradient participate.
        active = (flat_grid > 0.5) & (grad_mag > 0.01)

        # 3-bin histogram of outward / inward / neutral gradient directions
        # over the active voxels (thresholds +-0.1 define "neutral").
        n_active = active.float().sum(-1).clamp(min=1)
        frac_outward = ((dot_products > 0.1) & active).float().sum(-1) / n_active
        frac_inward = ((dot_products < -0.1) & active).float().sum(-1) / n_active
        frac_neutral = 1.0 - frac_outward - frac_inward
        direction_hist = torch.stack([frac_outward, frac_inward, frac_neutral], dim=-1)

        # 3-bin histogram of Laplacian signs over occupied voxels.
        lap_flat = lap.reshape(B, -1)
        lap_active = flat_grid > 0.5
        n_lap_active = lap_active.float().sum(-1).clamp(min=1)
        frac_pos_lap = ((lap_flat > 0.1) & lap_active).float().sum(-1) / n_lap_active
        frac_neg_lap = ((lap_flat < -0.1) & lap_active).float().sum(-1) / n_lap_active
        frac_zero_lap = 1.0 - frac_pos_lap - frac_neg_lap
        lap_hist = torch.stack([frac_pos_lap, frac_neg_lap, frac_zero_lap], dim=-1)

        # Curvature alternation: fraction of axis-adjacent voxel pairs,
        # restricted to the 1-voxel dilation of the shape, whose Laplacian
        # changes sign — high for saddles / alternating-curvature shapes.
        lap_3d = lap.squeeze(1)
        # max_pool3d dilates the occupancy by one voxel in every direction.
        boundary_mask = F.max_pool3d(vox, kernel_size=3, stride=1, padding=1).squeeze(1)

        # x-adjacent pairs: both endpoints must be inside the dilated mask.
        bm_x = boundary_mask[:, 1:, :, :] * boundary_mask[:, :-1, :, :]
        flip_x = (torch.sign(lap_3d[:, 1:, :, :]) * torch.sign(lap_3d[:, :-1, :, :]) < 0).float()
        active_flips_x = (flip_x * bm_x).sum(dim=(1, 2, 3))
        active_pairs_x = bm_x.sum(dim=(1, 2, 3)).clamp(min=1)

        # y-adjacent pairs.
        bm_y = boundary_mask[:, :, 1:, :] * boundary_mask[:, :, :-1, :]
        flip_y = (torch.sign(lap_3d[:, :, 1:, :]) * torch.sign(lap_3d[:, :, :-1, :]) < 0).float()
        active_flips_y = (flip_y * bm_y).sum(dim=(1, 2, 3))
        active_pairs_y = bm_y.sum(dim=(1, 2, 3)).clamp(min=1)

        # z-adjacent pairs.
        bm_z = boundary_mask[:, :, :, 1:] * boundary_mask[:, :, :, :-1]
        flip_z = (torch.sign(lap_3d[:, :, :, 1:]) * torch.sign(lap_3d[:, :, :, :-1]) < 0).float()
        active_flips_z = (flip_z * bm_z).sum(dim=(1, 2, 3))
        active_pairs_z = bm_z.sum(dim=(1, 2, 3)).clamp(min=1)

        # Average flip rate across the three axes -> (B, 1).
        alternation = ((active_flips_x / active_pairs_x +
                        active_flips_y / active_pairs_y +
                        active_flips_z / active_pairs_z) / 3.0).unsqueeze(-1)

        # Mean gradient over occupied voxels: a nonzero mean indicates a
        # directional asymmetry of the occupancy along that axis.
        gx_mean = (gx.squeeze(1) * grid).sum(dim=(1, 2, 3)) / total_occ.squeeze(-1)
        gy_mean = (gy.squeeze(1) * grid).sum(dim=(1, 2, 3)) / total_occ.squeeze(-1)
        gz_mean = (gz.squeeze(1) * grid).sum(dim=(1, 2, 3)) / total_occ.squeeze(-1)
        grad_asym = torch.stack([gx_mean, gy_mean, gz_mean], dim=-1)

        # Mean gradient magnitude in 5 radial bins from the centroid.
        # The 5/3.5 scale maps distances up to 3.5 voxels onto the bins —
        # presumably matched to GS == 5; verify if GS changes.
        dists = diff_from_center.norm(dim=-1)
        # nan_to_num is a defensive guard; dists should always be finite here.
        bin_idx = torch.nan_to_num(dists * (5.0 / 3.5), nan=0.0).long().clamp(0, 4)
        active_mask = (flat_grid > 0.5)
        # NOTE(review): this zeros tensor is immediately overwritten below —
        # dead initialization kept for byte-compatibility.
        radial_grad = torch.zeros(B, 5, device=device)
        # Scatter-free binned mean via one-hot masks.
        weighted_mag = grad_mag * active_mask.float()
        one_hot = F.one_hot(bin_idx, 5).float()
        active_oh = one_hot * active_mask.float().unsqueeze(-1)
        counts = active_oh.sum(dim=1).clamp(min=1)
        radial_grad = (weighted_mag.unsqueeze(-1) * active_oh).sum(dim=1) / counts

        # Learned conv summaries of the Laplacian and gradient fields.
        lap_feat = self.lap_conv(lap).reshape(B, -1)

        grad_feat = self.grad_conv(grad_field).reshape(B, -1)

        # Concatenate the 15 hand-crafted scalars.
        raw_feat = torch.cat([
            direction_hist,
            lap_hist,
            alternation,
            grad_asym,
            radial_grad,
        ], dim=-1)

        all_feat = torch.cat([raw_feat, lap_feat, grad_feat], dim=-1)

        direction_gate = self.direction_net(all_feat)
        direction_feat = self.direction_feat_net(all_feat)

        return direction_gate, direction_feat, alternation
| |
|
| |
|
| | |
| |
|
| | def deform_grid(grid, p_dropout=0.1, p_add=0.1, p_shift=0.15): |
| | """Fully vectorized voxel augmentation — zero CPU-GPU sync points.""" |
| | B = grid.shape[0] |
| | device = grid.device |
| | r = torch.rand(B, 3, device=device) |
| | out = grid.clone() |
| |
|
| | |
| | drop_sel = (r[:, 0] < p_dropout).view(B, 1, 1, 1) |
| | keep = torch.rand_like(out) > 0.15 |
| | out = torch.where(drop_sel, out * keep.float(), out) |
| |
|
| | |
| | add_sel = (r[:, 1] < p_add).view(B, 1, 1, 1).float() |
| | dilated = F.max_pool3d(out.unsqueeze(1), kernel_size=3, stride=1, padding=1).squeeze(1) |
| | boundary = ((dilated > 0.5) & (out < 0.5)).float() |
| | add_noise = (torch.rand_like(out) < 0.3).float() |
| | out = (out + boundary * add_noise * add_sel).clamp(max=1.0) |
| |
|
| | |
| | shift_sel = (r[:, 2] < p_shift) |
| | axes = torch.randint(3, (B,), device=device) |
| | dirs = torch.randint(0, 2, (B,), device=device) * 2 - 1 |
| |
|
| | |
| | |
| | versions = [] |
| | for ax in range(3): |
| | for d in [-1, 1]: |
| | s = torch.roll(out, shifts=d, dims=ax + 1) |
| | |
| | if d == 1: |
| | if ax == 0: s[:, 0, :, :] = 0 |
| | elif ax == 1: s[:, :, 0, :] = 0 |
| | else: s[:, :, :, 0] = 0 |
| | else: |
| | if ax == 0: s[:, -1, :, :] = 0 |
| | elif ax == 1: s[:, :, -1, :] = 0 |
| | else: s[:, :, :, -1] = 0 |
| | versions.append(s) |
| | versions.append(out) |
| | stacked = torch.stack(versions, dim=0) |
| |
|
| | |
| | assign = torch.where(shift_sel, axes * 2 + (dirs == 1).long(), torch.full_like(axes, 6)) |
| | |
| | out = stacked[assign, torch.arange(B, device=device)] |
| |
|
| | return out |
| |
|
| |
|
| | |
| |
|
class CurvatureHead(nn.Module):
    """
    Axis-aware curvature detection with differentiation gating.

    1. Per-axis max projections -> 2D conv (keeps 2×2 spatial)
    2. Radial occupancy profile from centroid
    3. Axial symmetry + translation invariance scores
    4. 3D conv with spatial preservation (2×2×2)
    5. DifferentiationGate: gradient/Laplacian analysis for direction detection

    The DifferentiationGate modulates curvature features so that
    convex and concave shapes get distinct representations even when
    their occupancy patterns are nearly identical.
    """

    def __init__(self, rigid_feat_dim, fill_dim, embed_dim):
        super().__init__()

        # 2D conv applied to each of the three axis projections; weights are
        # shared across axes by batching the projections in forward().
        self.plane_conv = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1), nn.GELU(),
            nn.Conv2d(16, 16, 3, padding=1), nn.GELU(),
            nn.AdaptiveAvgPool2d(2))
        plane_feat_dim = 3 * 16 * 4  # 3 axes x 16 channels x 2x2 pooled

        # MLP over the 5-bin radial occupancy profile.
        n_radial = 5
        self.radial_net = nn.Sequential(
            nn.Linear(n_radial, 32), nn.GELU(), nn.Linear(32, 16))
        radial_feat_dim = 16

        # (symmetry, translation-invariance) per axis -> 6 scalars.
        symmetry_feat_dim = 6

        # 3D conv over the raw voxel grid, pooled to 2x2x2 cells.
        self.voxel_conv = nn.Sequential(
            nn.Conv3d(1, 16, 3, padding=1), nn.GELU(),
            nn.Conv3d(16, 32, 3, padding=1), nn.GELU(),
            nn.AdaptiveAvgPool3d(2))
        voxel3d_feat_dim = 32 * 8

        # Gradient/Laplacian direction analysis (convex vs concave).
        self.diff_gate = DifferentiationGate(embed_dim)

        # All descriptors plus upstream rigid/fill features feed the gate.
        pre_gate_dim = (plane_feat_dim + radial_feat_dim + symmetry_feat_dim +
                        voxel3d_feat_dim + rigid_feat_dim + fill_dim)

        # Projection applied before the direction gate multiplies in.
        self.pre_gate_proj = nn.Sequential(
            SwiGLU(pre_gate_dim, embed_dim * 2),
            nn.Linear(embed_dim * 2, embed_dim))

        # gated (embed) + dir_feat (embed) + alternation (1) + raw skip.
        post_gate_dim = embed_dim + embed_dim + 1 + pre_gate_dim

        # Output heads: curved-ness probability, curvature-type logits, and
        # a feature embedding consumed downstream.
        self.curved_head = nn.Sequential(
            SwiGLU(post_gate_dim, embed_dim),
            nn.Linear(embed_dim, 1), nn.Sigmoid())
        self.curv_type_head = nn.Sequential(
            SwiGLU(post_gate_dim, embed_dim),
            nn.Linear(embed_dim, NUM_CURVATURES))
        self.curv_features = nn.Sequential(
            SwiGLU(post_gate_dim, embed_dim * 2),
            nn.Linear(embed_dim * 2, embed_dim))

    def forward(self, grid, rigid_retained, fill_ratios):
        """
        grid: (B, GS, GS, GS) occupancy; rigid_retained: (B, rigid_feat_dim);
        fill_ratios: (B, fill_dim).
        Returns (is_curved, curv_logits, curv_feat, alternation).
        """
        B = grid.shape[0]

        # Max-projection along each spatial axis -> three 2D silhouettes.
        proj_x = grid.max(dim=1).values
        proj_y = grid.max(dim=2).values
        proj_z = grid.max(dim=3).values

        # Run all three projections through plane_conv in one batched call
        # (stacked along the batch dim, then unstacked).
        projs_batched = torch.cat([
            proj_x.unsqueeze(1), proj_y.unsqueeze(1), proj_z.unsqueeze(1)
        ], dim=0)
        plane_all = self.plane_conv(projs_batched).reshape(3, B, -1)
        plane_feat = plane_all.permute(1, 0, 2).reshape(B, -1)

        radial = self._radial_profile(grid)
        radial_feat = self.radial_net(radial)

        sym_feat = self._symmetry_features(proj_x, proj_y, proj_z)

        vox3d_feat = self.voxel_conv(grid.unsqueeze(1)).reshape(B, -1)

        # Concatenate all descriptors with the upstream features.
        raw_combined = torch.cat([
            plane_feat, radial_feat, sym_feat, vox3d_feat,
            rigid_retained, fill_ratios], dim=-1)

        pre_gate = self.pre_gate_proj(raw_combined)

        # Direction analysis distinguishes convex vs concave occupancy.
        dir_gate, dir_feat, alternation = self.diff_gate(grid)

        # Multiplicative gating by curvature direction.
        gated = pre_gate * dir_gate

        # Keep raw_combined as a skip path so gating cannot destroy signal.
        combined = torch.cat([gated, dir_feat, alternation, raw_combined], dim=-1)

        is_curved = self.curved_head(combined)
        curv_logits = self.curv_type_head(combined)
        curv_feat = self.curv_features(combined)
        return is_curved, curv_logits, curv_feat, alternation

    def _radial_profile(self, grid):
        # 5-bin histogram of occupancy vs distance from the per-sample
        # centroid, normalized by total occupancy. max_dist = 3.5 presumably
        # matches the GS == 5 grid radius — verify if GS changes.
        B = grid.shape[0]
        device = grid.device
        coords = torch.stack(torch.meshgrid(
            torch.arange(GS, device=device, dtype=torch.float32),
            torch.arange(GS, device=device, dtype=torch.float32),
            torch.arange(GS, device=device, dtype=torch.float32),
            indexing="ij"), dim=-1)
        flat_grid = grid.reshape(B, -1)
        flat_coords = coords.reshape(-1, 3)
        total_occ = flat_grid.sum(dim=-1, keepdim=True).clamp(min=1)
        centroids = (flat_grid.unsqueeze(-1) * flat_coords.unsqueeze(0)).sum(dim=1) / total_occ
        diffs = flat_coords.unsqueeze(0) - centroids.unsqueeze(1)
        dists = diffs.norm(dim=-1)
        max_dist = 3.5
        n_bins = 5
        # nan_to_num is a defensive guard; dists should be finite here.
        bin_idx = torch.nan_to_num(dists * (float(n_bins) / max_dist), nan=0.0).long().clamp(0, n_bins - 1)
        one_hot = F.one_hot(bin_idx, n_bins).float()
        weighted = flat_grid.unsqueeze(-1) * one_hot
        profile = weighted.sum(dim=1) / total_occ
        return profile

    def _symmetry_features(self, proj_x, proj_y, proj_z):
        # Per projection: mirror symmetry along both in-plane axes (averaged)
        # and translation invariance (1 - mean one-step shift difference).
        projs = torch.stack([proj_x, proj_y, proj_z], dim=1)
        fh = torch.flip(projs, dims=[2])
        fv = torch.flip(projs, dims=[3])
        sym = 1.0 - ((projs - fh).abs().mean(dim=(2, 3)) +
                     (projs - fv).abs().mean(dim=(2, 3))) / 2
        shift_diff = (projs[:, :, 1:, :] - projs[:, :, :-1, :]).abs().mean(dim=(2, 3))
        trans_inv = 1.0 - shift_diff
        # Interleave (sym, trans_inv) per axis -> (B, 6).
        return torch.stack([sym[:, 0], trans_inv[:, 0],
                            sym[:, 1], trans_inv[:, 1],
                            sym[:, 2], trans_inv[:, 2]], dim=-1)
| |
|
| |
|
| | |
| |
|
def compute_confidence(logits):
    """
    Compute real calibrated confidence metrics from logits.

    Args:
        logits: (B, C) unnormalized class scores; requires C >= 2
            (topk(2) raises otherwise).

    Returns dict with:
        max_prob: max(softmax(logits)) — calibrated top-class probability
        margin: top1_prob - top2_prob — disambiguation strength
        entropy: normalized entropy in [0, 1] (entropy / log C);
            lower = more confident
        confidence: margin — primary confidence signal for gating
    """
    probs = F.softmax(logits, dim=-1)

    # A single topk(2) call yields both the top probability and the
    # runner-up — no need for a separate max() pass over probs.
    top2 = probs.topk(2, dim=-1).values
    max_prob = top2[:, 0]
    margin = top2[:, 0] - top2[:, 1]

    # Entropy via log_softmax for numerical stability; normalize by the
    # maximum possible entropy log(C) so values are comparable across
    # heads with different class counts.
    log_probs = F.log_softmax(logits, dim=-1)
    entropy = -(probs * log_probs).sum(dim=-1)
    max_entropy = math.log(logits.shape[-1])
    norm_entropy = entropy / max_entropy

    return {
        "max_prob": max_prob,
        "margin": margin,
        "entropy": norm_entropy,
        "confidence": margin,
    }
| |
|
| |
|
| | |
| |
|
class RectifiedFlowArbiter(nn.Module):
    """
    Rectified flow matching for ambiguous classification refinement.

    Real flow matching requires a target endpoint to define the velocity field.
    We learn class prototypes in latent space as targets: for a sample of class c,
    the target is prototype[c]. The velocity field learns to transport the
    encoded feature z0 toward the correct prototype z1 in straight lines:

        v_target = z1 - z0 (rectified: straight path from source to target)
        loss = ||v_predicted - v_target||^2 (flow matching objective)

    At inference, the arbiter integrates the learned velocity field from z0,
    landing near the correct class prototype. Classification reads off the
    nearest prototype.

    Confidence gating: velocity magnitude is scaled by (1 - margin), so
    confident first-pass predictions receive minimal correction.
    """

    def __init__(self, feat_dim, n_classes, n_steps=4, latent_dim=128, embed_dim=64):
        super().__init__()
        self.n_steps = n_steps       # Euler integration steps at inference
        self.n_classes = n_classes
        self.dt = 1.0 / n_steps      # uniform step size over t in [0, 1]
        self.latent_dim = latent_dim

        # Feature -> latent encoder (produces z0, the flow source).
        self.encode = nn.Sequential(
            nn.Linear(feat_dim, latent_dim * 2), nn.GELU(),
            nn.Linear(latent_dim * 2, latent_dim))

        # Learned per-class latent prototypes (z1, the flow targets).
        self.prototypes = nn.Parameter(torch.randn(n_classes, latent_dim) * 0.05)

        # Maps the 16-dim sinusoidal time encoding to an embedding.
        self.time_embed = nn.Sequential(
            nn.Linear(16, embed_dim), nn.GELU(),
            nn.Linear(embed_dim, embed_dim))

        # Embeds the (max_prob, margin, entropy) triple of the first pass.
        self.conf_embed = nn.Sequential(
            nn.Linear(3, embed_dim), nn.GELU(),
            nn.Linear(embed_dim, embed_dim))

        # Velocity field v(z, t, conf).
        vel_in = latent_dim + embed_dim + embed_dim
        self.velocity = nn.Sequential(
            SwiGLU(vel_in, latent_dim),
            nn.Linear(latent_dim, latent_dim),
            SwiGLU(latent_dim, latent_dim),
            nn.Linear(latent_dim, latent_dim))

        # Per-dimension sigmoid gate on the velocity, driven by confidence.
        self.vel_gate = nn.Sequential(
            nn.Linear(embed_dim, latent_dim), nn.Sigmoid())

        # Reads class logits from the latent plus prototype distances.
        self.classifier_head = nn.Sequential(
            SwiGLU(latent_dim + n_classes, 96),
            nn.Linear(96, n_classes))

        # Scalar in (0, 1): how much to trust initial vs refined logits.
        self.blend_head = nn.Sequential(
            nn.Linear(feat_dim, 64), nn.GELU(),
            nn.Linear(64, 1), nn.Sigmoid())

        # Confidence of the refined prediction, read from the final latent.
        self.refined_confidence = nn.Sequential(
            SwiGLU(latent_dim, 32),
            nn.Linear(32, 1), nn.Sigmoid())

    def _time_encoding(self, t, device):
        # 8 log-spaced frequencies -> 16-dim sin/cos features of t.
        freqs = torch.exp(torch.linspace(0, -4, 8, device=device))
        args = t.unsqueeze(-1) * freqs.unsqueeze(0)
        return torch.cat([args.sin(), args.cos()], dim=-1)

    def _proto_logits(self, z):
        """Classify by negative distance to prototypes."""
        # cdist over a fake batch dim -> (B, n_classes) pairwise distances.
        dists = torch.cdist(z.unsqueeze(0), self.prototypes.unsqueeze(0)).squeeze(0)
        # Feed both the latent and -dists so the head can mix both signals.
        combined = torch.cat([z, -dists], dim=-1)
        return self.classifier_head(combined)

    def forward(self, features, initial_logits, labels=None):
        """
        features: (B, feat_dim)
        initial_logits: (B, n_classes)
        labels: (B,) — only during training, for flow matching target

        Returns:
            refined_logits, refined_conf, initial_conf, trajectory_logits,
            flow_loss, blend_weight
        """
        B = features.shape[0]
        device = features.device

        # First-pass confidence drives how much correction is applied.
        initial_conf = compute_confidence(initial_logits)
        conf_input = torch.stack([
            initial_conf["max_prob"],
            initial_conf["margin"],
            initial_conf["entropy"]], dim=-1)
        conf_emb = self.conf_embed(conf_input)

        # Velocity gate scaled by (1 - margin): confident samples barely move.
        gate = self.vel_gate(conf_emb)
        inv_conf = (1.0 - initial_conf["margin"]).unsqueeze(-1)
        adaptive_gate = gate * inv_conf

        # Flow source.
        z0 = self.encode(features)

        # Flow-matching loss (training only): regress v(z_t, t) onto
        # v_target = z1 - z0 at a random t along the straight path.
        flow_loss = torch.tensor(0.0, device=device)
        if labels is not None:
            # Target endpoint: the prototype of the true class.
            z1 = self.prototypes[labels]
            v_target = z1 - z0

            # Sample one random time per sample.
            t_rand = torch.rand(B, device=device)
            t_emb = self.time_embed(self._time_encoding(t_rand, device))

            # Point on the straight path z0 -> z1 at time t.
            z_t = z0 + t_rand.unsqueeze(-1) * v_target

            vel_input = torch.cat([z_t, t_emb, conf_emb], dim=-1)
            v_pred = self.velocity(vel_input) * adaptive_gate
            v_pred = v_pred.clamp(-20, 20)  # guard against velocity blow-up

            # Target clamped symmetrically so the loss is bounded too.
            flow_loss = F.mse_loss(v_pred, v_target.clamp(-20, 20))

        # Inference path: Euler-integrate the velocity field from z0,
        # recording per-step prototype logits.
        z = z0
        trajectory_logits = []
        for step in range(self.n_steps):
            t_val = torch.full((B,), step * self.dt, device=device)
            t_emb = self.time_embed(self._time_encoding(t_val, device))

            vel_input = torch.cat([z, t_emb, conf_emb], dim=-1)
            v = self.velocity(vel_input) * adaptive_gate
            # Same clamp as in training for consistency.
            v = v.clamp(-20, 20)

            z = z + self.dt * v
            trajectory_logits.append(self._proto_logits(z))

        refined_logits = trajectory_logits[-1]
        refined_conf = self.refined_confidence(z)

        # Mixing coefficient between initial and refined logits.
        blend_weight = self.blend_head(features)

        return refined_logits, refined_conf, initial_conf, trajectory_logits, flow_loss, blend_weight
| |
|
| |
|
| | |
| |
|
class GeometricShapeClassifier(nn.Module):
    """Top-level voxel-shape classifier.

    Pipeline: per-voxel embeddings (occupancy + normalized position) ->
    tracer-token cross-attention -> pairwise tracer interactions ->
    stacked CapacityHeads (overflow cascades 0->1->2->3) -> CurvatureHead ->
    linear classifier, refined by a RectifiedFlowArbiter and blended.

    NOTE(review): depends on module-level constants NUM_CLASSES and GS
    (grid size) defined elsewhere in this file.
    """

    def __init__(self, n_classes=NUM_CLASSES, embed_dim=64, n_tracers=5):
        super().__init__()
        self.n_tracers = n_tracers
        self.embed_dim = embed_dim

        # Embeds (occupancy, x, y, z) per voxel.
        self.voxel_embed = nn.Sequential(
            nn.Linear(4, embed_dim), nn.GELU(), nn.Linear(embed_dim, embed_dim))

        # Normalized voxel coordinates in [0, 1]^3, registered as a buffer.
        coords = torch.stack(torch.meshgrid(
            torch.arange(GS, dtype=torch.float32),
            torch.arange(GS, dtype=torch.float32),
            torch.arange(GS, dtype=torch.float32),
            indexing="ij"), dim=-1) / (GS - 1)
        self.register_buffer("pos_grid", coords)

        # Learned tracer queries that cross-attend over the voxel embeddings.
        self.tracer_tokens = nn.Parameter(torch.randn(n_tracers, embed_dim) * 0.02)
        self.tracer_attn = nn.MultiheadAttention(embed_dim, num_heads=4, batch_first=True)
        self.tracer_gate = nn.Sequential(nn.Linear(embed_dim * 2, embed_dim), nn.Sigmoid())
        self.tracer_interact = nn.Sequential(
            nn.Linear(embed_dim * 2, embed_dim), nn.GELU(), nn.Linear(embed_dim, embed_dim))
        # Predicts a scalar "edge length" per tracer pair (auxiliary output).
        self.edge_head = nn.Sequential(
            SwiGLU(embed_dim * 2, 32), nn.Linear(32, 1))

        # Precomputed index pairs (i < j) over tracers.
        _pi, _pj = [], []
        for i in range(n_tracers):
            for j in range(i + 1, n_tracers):
                _pi.append(i); _pj.append(j)
        self.register_buffer("_pair_i", torch.tensor(_pi, dtype=torch.long))
        self.register_buffer("_pair_j", torch.tensor(_pj, dtype=torch.long))
        self.n_pairs = len(_pi)

        pool_dim = embed_dim * n_tracers

        # Capacity cascade: each head consumes the previous head's overflow.
        self.dim0 = CapacityHead(pool_dim, embed_dim, init_capacity=0.5)
        self.dim1 = CapacityHead(pool_dim + embed_dim, embed_dim, init_capacity=1.0)
        self.dim2 = CapacityHead(pool_dim + embed_dim, embed_dim, init_capacity=1.5)
        self.dim3 = CapacityHead(pool_dim + embed_dim, embed_dim, init_capacity=2.0)

        # Four retained vectors concatenated.
        rigid_feat_dim = embed_dim * 4
        self.curvature = CurvatureHead(rigid_feat_dim, fill_dim=4, embed_dim=embed_dim)

        # pooled + 4 fill ratios + retained + curvature features + is_curved.
        class_in = pool_dim + 4 + rigid_feat_dim + embed_dim + 1
        self.class_in = class_in
        self.classifier = nn.Sequential(
            nn.Linear(class_in, 256), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(256, 128), nn.GELU(), nn.Linear(128, n_classes))

        # Auxiliary prediction heads (peaks, volume, center of mass).
        self.peak_head = nn.Sequential(
            SwiGLU(class_in, 32), nn.Linear(32, 4))
        self.volume_head = nn.Sequential(
            nn.Linear(class_in, 64), nn.GELU(), nn.Linear(64, 1))
        self.cm_head = nn.Sequential(
            SwiGLU(class_in, 64), nn.Linear(64, 1), nn.Tanh())

        # Flow-based refinement of ambiguous first-pass predictions.
        self.arbiter = RectifiedFlowArbiter(
            feat_dim=class_in, n_classes=n_classes,
            n_steps=4, latent_dim=128, embed_dim=embed_dim)

    def forward(self, grid, labels=None):
        """grid: (B, GS, GS, GS) occupancy; labels only during training
        (forwarded to the arbiter's flow-matching loss)."""
        B = grid.shape[0]
        occ = grid.reshape(B, GS**3, 1)
        pos = self.pos_grid.reshape(1, GS**3, 3).expand(B, -1, -1)
        voxel_emb = self.voxel_embed(torch.cat([occ, pos], dim=-1))

        # Tracer tokens attend over all voxel embeddings.
        tracers = self.tracer_tokens.unsqueeze(0).expand(B, -1, -1)
        tracers, _ = self.tracer_attn(tracers, voxel_emb, voxel_emb)

        # Gather the (i, j) tracer pairs and concatenate along features.
        left = tracers[:, self._pair_i]
        right = tracers[:, self._pair_j]
        pairs = torch.cat([left, right], dim=-1)

        # Run all pair MLPs in one flattened batch.
        flat_pairs = pairs.reshape(B * self.n_pairs, -1)
        gate = self.tracer_gate(flat_pairs).reshape(B, self.n_pairs, -1)
        interaction = self.tracer_interact(flat_pairs).reshape(B, self.n_pairs, -1)
        edge_lengths = self.edge_head(flat_pairs).reshape(B, self.n_pairs)

        # Scatter each gated interaction back onto BOTH of its endpoints
        # (clone keeps the attention output intact for autograd).
        gated = gate * interaction
        tracer_out = tracers.clone()
        pi_exp = self._pair_i.view(1, self.n_pairs, 1).expand(B, -1, self.embed_dim)
        pj_exp = self._pair_j.view(1, self.n_pairs, 1).expand(B, -1, self.embed_dim)
        tracer_out.scatter_add_(1, pi_exp, gated)
        tracer_out.scatter_add_(1, pj_exp, gated)
        pooled = tracer_out.reshape(B, -1)

        # Capacity cascade: each level sees pooled + previous overflow.
        fill0, ovf0, ret0, cap0, _ = self.dim0(pooled)
        fill1, ovf1, ret1, cap1, _ = self.dim1(torch.cat([pooled, ovf0], -1))
        fill2, ovf2, ret2, cap2, _ = self.dim2(torch.cat([pooled, ovf1], -1))
        fill3, ovf3, ret3, cap3, _ = self.dim3(torch.cat([pooled, ovf2], -1))

        fill_ratios = torch.cat([fill0, fill1, fill2, fill3], dim=-1)
        rigid_retained = torch.cat([ret0, ret1, ret2, ret3], dim=-1)
        ovf_norms = torch.stack([
            ovf0.norm(dim=-1), ovf1.norm(dim=-1),
            ovf2.norm(dim=-1), ovf3.norm(dim=-1)], dim=-1)

        is_curved, curv_logits, curv_feat, alternation = self.curvature(grid, rigid_retained, fill_ratios)
        full = torch.cat([pooled, fill_ratios, rigid_retained, curv_feat, is_curved], dim=-1)

        # First-pass classification.
        initial_logits = self.classifier(full)

        # Flow-based refinement (also returns the flow-matching loss).
        refined_logits, refined_conf, initial_conf, trajectory_logits, flow_loss, blend_weight = \
            self.arbiter(full, initial_logits, labels=labels)

        # Learned blend between first-pass and refined logits.
        final_logits = blend_weight * initial_logits + (1.0 - blend_weight) * refined_logits

        return {
            # Classification outputs.
            "class_logits": final_logits,
            "initial_logits": initial_logits,
            "refined_logits": refined_logits,
            "trajectory_logits": trajectory_logits,
            # Training signal from the arbiter.
            "flow_loss": flow_loss,
            # Confidence diagnostics.
            "confidence": initial_conf["confidence"],
            "max_prob": initial_conf["max_prob"],
            "entropy": initial_conf["entropy"],
            "refined_confidence": refined_conf,
            "blend_weight": blend_weight.squeeze(-1),
            # Auxiliary heads and intermediate signals.
            "peak_logits": self.peak_head(full),
            "volume_pred": self.volume_head(full).squeeze(-1),
            "cm_pred": self.cm_head(full).squeeze(-1),
            "edge_lengths": edge_lengths,
            "fill_ratios": fill_ratios,
            "overflows": ovf_norms,
            "capacities": torch.stack([cap0, cap1, cap2, cap3]),
            "is_curved_pred": is_curved,
            "curv_type_logits": curv_logits,
            "alternation": alternation,
            # Full feature vector, exposed for downstream probing.
            "features": full,
        }
| |
|
| |
|
| | |
# Import-time smoke test: instantiate the model once and report its size.
_model = GeometricShapeClassifier()
_n_params = sum(p.numel() for p in _model.parameters())
print(f'GeometricShapeClassifier: {_n_params:,} params')
del _model, _n_params