| """ |
| 31-Class Edit Operation Classifier β Neuroswarm Tier 2 Verification Engine |
| |
| Verification stack: |
| Tier 1: 33-dim profile cosine similarity (nanoseconds, GPU) |
| Tier 2: THIS β edit classifier inference (milliseconds, GPU) |
| Tier 3: LLM review (seconds, API call, costs tokens) |
| |
| Pipeline: |
| (before_hsl, after_hsl) each (B, H, W, 3) |
| β Circular hue encoding: h β (sin(2Οh), cos(2Οh)), stack with S,L β 4D |
| β HSLFeatureExtractor (ViT spatial features) |
| β HybridRegionPooler (DETR-style learned queries, no scope markers) |
| β Delta computation + fusion |
| β Concat: [global_feat, profile_delta_33, oklab_magnitude_1] |
| β Hierarchical classifier: level (3) β op (31) |
| |
| Fixes over v1: |
| 1. Circular hue encoding (HSLFeatureExtractor) β hue wraparound correct |
| 2. HybridRegionPooler β DETR learned queries with iterative refinement |
| 3. 33-dim profile delta conditioning β structural direction signal |
| 4. OKLab delta magnitude β perceptual change size signal |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional, Tuple, Dict, List |
|
|
| from .edit_ops import TRAINABLE_OPS, NUM_OPS, OP_TO_IDX, IDX_TO_OP, OpCode, OP_LEVEL |
| from .hsl_feature_extractor import HSLFeatureExtractor |
| from .hybrid_pooler import HybridRegionPooler |
| from .oklab_utils import hsl_to_oklab_batch |
|
|
|
|
class EditOpClassifier(nn.Module):
    """
    Neuroswarm Tier 2: Classifies edit ops from before/after palette pairs.

    Managers call this thousands of times per cycle to verify sub-agent work
    without spending tokens on LLM review. ~1ms inference on GPU.

    Input: (before_hsl, after_hsl) each (B, H, W, 3) normalized HSL [0,1]
    Output: (op_logits_31, level_logits_3, global_features)
    """

    # Width of the optional structural-profile delta conditioning vector.
    PROFILE_DIM = 33
    # OKLab perceptual change is summarized as a single scalar per sample.
    OKLAB_DIM = 1

    def __init__(
        self,
        hidden_dim: int = 256,
        vit_layers: int = 4,
        vit_heads: int = 8,
        num_regions: int = 8,
        patch_size: int = 4,
        num_refinement_iters: int = 2,
        dropout: float = 0.1,
    ):
        """
        Args:
            hidden_dim: width of all internal embeddings.
            vit_layers: number of transformer layers in the feature extractor.
            vit_heads: attention heads (shared by extractor, pooler, global attn).
            num_regions: number of learned DETR-style region queries.
            patch_size: ViT patch size for the feature extractor.
            num_refinement_iters: iterative refinement passes in the pooler.
            dropout: dropout probability used throughout.
        """
        super().__init__()
        self.hidden_dim = hidden_dim

        # ViT backbone over circular-hue-encoded HSL input.
        self.feature_extractor = HSLFeatureExtractor(
            hidden_dim=hidden_dim,
            num_layers=vit_layers,
            num_heads=vit_heads,
            patch_size=patch_size,
            dropout=dropout,
        )

        # DETR-style learned-query pooling into `num_regions` region embeddings.
        self.region_pooler = HybridRegionPooler(
            hidden_dim=hidden_dim,
            num_learned_queries=num_regions,
            num_heads=vit_heads,
            use_structure=False,
            dropout=dropout,
            num_refinement_iters=num_refinement_iters,
        )

        # Fuses per-region [before, after, importance-weighted delta] triples.
        self.delta_fusion = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
        )

        # One learned query attends over fused regions -> single global feature.
        self.global_query = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)
        self.global_attn = nn.MultiheadAttention(
            hidden_dim, vit_heads, dropout=dropout, batch_first=True
        )

        # Projects the 33-dim structural profile delta (direction signal).
        self.profile_proj = nn.Sequential(
            nn.Linear(self.PROFILE_DIM, hidden_dim // 4),
            nn.GELU(),
            nn.LayerNorm(hidden_dim // 4),
        )

        # Projects the scalar OKLab delta magnitude (change-size signal).
        self.oklab_proj = nn.Sequential(
            nn.Linear(self.OKLAB_DIM, hidden_dim // 8),
            nn.GELU(),
        )

        # Conditioned feature = global feature + profile + oklab projections.
        cond_dim = hidden_dim + hidden_dim // 4 + hidden_dim // 8

        # Auxiliary head: 3-way level prediction.
        self.level_head = nn.Sequential(
            nn.Linear(cond_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 3),
        )

        # Main head: 31-way op prediction, conditioned on the level logits
        # (hierarchical: level informs the op decision).
        self.op_head = nn.Sequential(
            nn.Linear(cond_dim + 3, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, NUM_OPS),
        )

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal init (std=0.02) for Linear weights, zeros for biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def encode_palette(self, hsl: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode HSL palette -> region embeddings + importance scores.

        Args:
            hsl: (B, H, W, 3) normalized HSL [0,1]

        Returns:
            regions: (B, R, hidden_dim) region embeddings
            importance: (B, R) importance scores
        """
        # (B, H, W, 3) -> spatial ViT features.
        features = self.feature_extractor(hsl)

        # Spatial features -> learned-query region pooling.
        regions, importance = self.region_pooler(features)

        return regions, importance

    @staticmethod
    def compute_oklab_delta(before_hsl: torch.Tensor, after_hsl: torch.Tensor) -> torch.Tensor:
        """
        Compute perceptual change magnitude in OKLab space.

        Returns:
            (B, 1) scalar — mean DeltaE across all spatial positions
        """
        before_oklab = hsl_to_oklab_batch(before_hsl)
        after_oklab = hsl_to_oklab_batch(after_hsl)

        # Euclidean DeltaE per position. clamp_min guards sqrt(0): the gradient
        # of sqrt at 0 is infinite, which produced NaNs during training whenever
        # before == after (a common case for no-op / rejected edits).
        delta_e = (before_oklab - after_oklab).pow(2).sum(dim=-1).clamp_min(1e-12).sqrt()

        # Mean over the spatial dims — assumes delta_e is (B, H', W'); confirm
        # against hsl_to_oklab_batch's output layout.
        mean_delta_e = delta_e.mean(dim=(1, 2), keepdim=False)

        return mean_delta_e.unsqueeze(-1)

    def forward(
        self,
        before_hsl: torch.Tensor,
        after_hsl: torch.Tensor,
        profile_delta: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Classify edit operation from before/after palette pair.

        Args:
            before_hsl: (B, H, W, 3) palette before edit, HSL [0,1]
            after_hsl: (B, H, W, 3) palette after edit, HSL [0,1]
            profile_delta: (B, 33) optional structural profile delta (after - before)
                If None, zeros are used (graceful degradation)

        Returns:
            op_logits: (B, 31) logits over edit operations
            level_logits: (B, 3) logits over levels
            global_feat: (B, hidden_dim) fused delta representation
        """
        B = before_hsl.shape[0]
        device = before_hsl.device

        # Encode both palettes with the shared extractor + pooler.
        before_regions, before_imp = self.encode_palette(before_hsl)
        after_regions, after_imp = self.encode_palette(after_hsl)

        # Weight the region delta by the average importance of each region.
        imp = (before_imp + after_imp) / 2
        imp_w = imp.unsqueeze(-1)
        delta = (after_regions - before_regions) * imp_w

        # Fuse before/after/delta per region into hidden_dim.
        fused = torch.cat([before_regions, after_regions, delta], dim=-1)
        fused = self.delta_fusion(fused)

        # Pool regions into one global feature via a learned query.
        query = self.global_query.expand(B, -1, -1)
        global_feat, _ = self.global_attn(query, fused, fused)
        global_feat = global_feat.squeeze(1)

        # Graceful degradation: zero conditioning when no profile delta is
        # supplied. Match the input dtype so mixed-precision (e.g. fp16)
        # forward passes don't fail at the concat below.
        if profile_delta is None:
            profile_delta = torch.zeros(
                B, self.PROFILE_DIM, device=device, dtype=before_hsl.dtype
            )
        profile_feat = self.profile_proj(profile_delta)

        # Scalar perceptual change magnitude.
        oklab_delta = self.compute_oklab_delta(before_hsl, after_hsl)
        oklab_feat = self.oklab_proj(oklab_delta)

        # Condition the global feature on structural + perceptual signals.
        conditioned = torch.cat([global_feat, profile_feat, oklab_feat], dim=-1)

        # Hierarchical prediction: level first, then op conditioned on level.
        level_logits = self.level_head(conditioned)

        op_input = torch.cat([conditioned, level_logits], dim=-1)
        op_logits = self.op_head(op_input)

        return op_logits, level_logits, global_feat
|
|
|
|
| |
| |
| |
|
|
class Tier1ProfileVerifier:
    """
    Neuroswarm Tier 1: Nanosecond verification via 33-dim profile cosine similarity.

    Usage:
        verifier = Tier1ProfileVerifier()
        result = verifier.verify(expected_delta, actual_delta)
        for tier in result['tiers']:
            if tier == 'pass': ...        # accept
            elif tier == 'escalate': ...  # -> Tier 2
            elif tier == 'reject': ...    # -> retry agent
    """

    def __init__(
        self,
        pass_threshold: float = 0.7,
        reject_threshold: float = 0.3,
    ):
        """
        Args:
            pass_threshold: alignment >= this -> 'pass'.
            reject_threshold: alignment < this -> 'reject';
                values in [reject_threshold, pass_threshold) -> 'escalate'.
        """
        self.pass_threshold = pass_threshold
        self.reject_threshold = reject_threshold

    def verify(
        self,
        expected_delta: torch.Tensor,
        actual_delta: torch.Tensor,
    ) -> dict:
        """
        Compare expected vs actual structural profile delta.

        Args:
            expected_delta: (33,) or (B, 33) expected profile change
            actual_delta: (33,) or (B, 33) actual profile change

        Returns:
            dict with:
              'alignment': (B,) cosine similarity per sample
              'tiers': list of 'pass'/'escalate'/'reject', one per sample
              'mean_alignment': float mean alignment over the batch
        """
        # Promote single samples to a batch of one.
        if expected_delta.dim() == 1:
            expected_delta = expected_delta.unsqueeze(0)
            actual_delta = actual_delta.unsqueeze(0)

        alignment = F.cosine_similarity(expected_delta, actual_delta, dim=-1)

        def classify(a: float) -> str:
            # Threshold the alignment into the three verification tiers.
            if a >= self.pass_threshold:
                return 'pass'
            if a >= self.reject_threshold:
                return 'escalate'
            return 'reject'

        tiers = [classify(a.item()) for a in alignment]

        return {
            'alignment': alignment,
            'tiers': tiers,
            'mean_alignment': alignment.mean().item(),
        }
|
|
|
|
| |
| |
| |
|
|
class Tier2EditVerifier:
    """
    Neuroswarm Tier 2: Millisecond verification via edit classifier.

    Wraps an EditOpClassifier (moved to the target device, set to eval) and
    checks whether the op it predicts for a (before, after) palette pair is
    the op the agent was asked to perform, with sufficient confidence.

    Usage:
        verifier = Tier2EditVerifier(model, device='cuda')
        result = verifier.verify(before_hsl, after_hsl, expected_op, profile_delta)
        if result['match']: ...  # agent did the right thing
        else: ...  # escalate to Tier 3
    """

    def __init__(
        self,
        model: EditOpClassifier,
        device: str = 'cpu',
        confidence_threshold: float = 0.8,
    ):
        """
        Args:
            model: trained EditOpClassifier to run inference with.
            device: device the model and inputs are moved to.
            confidence_threshold: minimum softmax probability of the predicted
                op for a positive match.
        """
        self.model = model.to(device).eval()
        self.device = device
        self.confidence_threshold = confidence_threshold

    def _prepare(self, tensor: torch.Tensor, unbatched_dims: int) -> torch.Tensor:
        """Add a leading batch axis if missing, then move to the device."""
        if tensor.dim() == unbatched_dims:
            tensor = tensor.unsqueeze(0)
        return tensor.to(self.device)

    @torch.no_grad()
    def verify(
        self,
        before_hsl: torch.Tensor,
        after_hsl: torch.Tensor,
        expected_op: OpCode,
        profile_delta: Optional[torch.Tensor] = None,
    ) -> dict:
        """
        Verify that an agent performed the expected edit operation.

        Returns:
            dict with 'match', 'predicted_op', 'confidence', 'escalate'
        """
        before = self._prepare(before_hsl, 3)
        after = self._prepare(after_hsl, 3)
        if profile_delta is not None:
            profile_delta = self._prepare(profile_delta, 1)

        op_logits, _, _ = self.model(before, after, profile_delta)

        # Single-sample decision: top op and its softmax probability.
        probs = F.softmax(op_logits, dim=-1)
        pred_idx = probs.argmax(dim=-1).item()
        confidence = probs[0, pred_idx].item()
        predicted_op = IDX_TO_OP[pred_idx]

        # Match requires both the right op AND enough confidence.
        right_op = pred_idx == OP_TO_IDX[expected_op]
        confident = confidence >= self.confidence_threshold
        match = right_op and confident

        return {
            'match': match,
            'predicted_op': predicted_op,
            'predicted_op_name': predicted_op.name,
            'expected_op_name': expected_op.name,
            'confidence': confidence,
            'escalate': not match,
            'op_probs': probs[0].cpu(),
        }
|
|
|
|
| |
| |
| |
|
|
class EditOpLoss(nn.Module):
    """
    Combined loss for edit op classification.

    Components:
    - Cross-entropy on 31-class op prediction (primary)
    - Cross-entropy on 3-class level prediction (auxiliary, `level_weight`)
    - Level-op consistency penalty (`consistency_weight`)

    NOTE(review): the consistency penalty is built from argmax comparisons, so
    it carries no gradient — it shifts the reported loss value but does not
    directly train consistency. Kept as-is for metric/loss parity; confirm
    whether a differentiable formulation is wanted.
    """

    def __init__(self, level_weight: float = 0.3, consistency_weight: float = 0.1):
        """
        Args:
            level_weight: weight on the auxiliary level cross-entropy.
            consistency_weight: weight on the consistency term (see class NOTE).
        """
        super().__init__()
        self.level_weight = level_weight
        self.consistency_weight = consistency_weight
        self.op_loss_fn = nn.CrossEntropyLoss(label_smoothing=0.05)
        self.level_loss_fn = nn.CrossEntropyLoss(label_smoothing=0.05)

        # op index -> level index (0=primitive, 1=structural, 2=semantic).
        self._op_to_level = {}
        level_names = ['primitive', 'structural', 'semantic']
        for op in TRAINABLE_OPS:
            level = OP_LEVEL[op]
            self._op_to_level[OP_TO_IDX[op]] = level_names.index(level)

        # Same mapping as a LongTensor LUT so forward() can map ops -> levels
        # with a single vectorized index instead of a per-batch Python loop.
        # Non-persistent buffer: follows the module across .to(device) but is
        # excluded from state_dict, so existing checkpoints still load.
        lut = torch.zeros(NUM_OPS, dtype=torch.long)
        for op_idx, level_idx in self._op_to_level.items():
            lut[op_idx] = level_idx
        self.register_buffer('_op_level_lut', lut, persistent=False)

    def forward(
        self,
        op_logits: torch.Tensor,
        level_logits: torch.Tensor,
        op_labels: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Args:
            op_logits: (B, 31) predicted op logits
            level_logits: (B, 3) predicted level logits
            op_labels: (B,) integer labels in [0, 30]

        Returns:
            total_loss, metrics_dict
        """
        op_loss = self.op_loss_fn(op_logits, op_labels)

        # Vectorized op -> level mapping (replaces the old per-item loop).
        lut = self._op_level_lut.to(op_labels.device)
        level_labels = lut[op_labels]
        level_loss = self.level_loss_fn(level_logits, level_labels)

        # Consistency: does the predicted level agree with the level implied
        # by the predicted op? Argmax-based, hence non-differentiable.
        pred_ops = op_logits.argmax(dim=-1)
        pred_levels = level_logits.argmax(dim=-1)
        expected_levels = lut[pred_ops]
        consistency = (pred_levels == expected_levels).float().mean()
        consistency_loss = 1.0 - consistency

        total = op_loss + self.level_weight * level_loss + self.consistency_weight * consistency_loss

        metrics = {
            'loss': total.item(),
            'op_loss': op_loss.item(),
            'level_loss': level_loss.item(),
            'consistency': consistency.item(),
            'op_acc': (pred_ops == op_labels).float().mean().item(),
            'level_acc': (pred_levels == level_labels).float().mean().item(),
        }

        return total, metrics

    @staticmethod
    def op_label_from_opcode(opcode: OpCode) -> int:
        """Map an OpCode to its integer training label."""
        return OP_TO_IDX[opcode]

    @staticmethod
    def opcode_from_label(label: int) -> OpCode:
        """Map an integer training label back to its OpCode."""
        return IDX_TO_OP[label]
|
|