palette-edit-classifier / models / edit_classifier.py
"""
31-Class Edit Operation Classifier β€” Neuroswarm Tier 2 Verification Engine
Verification stack:
Tier 1: 33-dim profile cosine similarity (nanoseconds, GPU)
Tier 2: THIS β€” edit classifier inference (milliseconds, GPU)
Tier 3: LLM review (seconds, API call, costs tokens)
Pipeline:
(before_hsl, after_hsl) each (B, H, W, 3)
β†’ Circular hue encoding: h β†’ (sin(2Ο€h), cos(2Ο€h)), stack with S,L β†’ 4D
β†’ HSLFeatureExtractor (ViT spatial features)
β†’ HybridRegionPooler (DETR-style learned queries, no scope markers)
β†’ Delta computation + fusion
β†’ Concat: [global_feat, profile_delta_33, oklab_magnitude_1]
β†’ Hierarchical classifier: level (3) β†’ op (31)
Fixes over v1:
1. Circular hue encoding (HSLFeatureExtractor) β€” hue wraparound correct
2. HybridRegionPooler β€” DETR learned queries with iterative refinement
3. 33-dim profile delta conditioning β€” structural direction signal
4. OKLab delta magnitude β€” perceptual change size signal
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Dict
from .edit_ops import TRAINABLE_OPS, NUM_OPS, OP_TO_IDX, IDX_TO_OP, OpCode, OP_LEVEL
from .hsl_feature_extractor import HSLFeatureExtractor
from .hybrid_pooler import HybridRegionPooler
from .oklab_utils import hsl_to_oklab_batch
class EditOpClassifier(nn.Module):
"""
Neuroswarm Tier 2: Classifies edit ops from before/after palette pairs.
Managers call this thousands of times per cycle to verify sub-agent work
without spending tokens on LLM review. ~1ms inference on GPU.
Input: (before_hsl, after_hsl) each (B, H, W, 3) normalized HSL [0,1]
Output: (op_logits_31, level_logits_3, global_features)
"""
PROFILE_DIM = 33 # Structural profile vector dimensionality
OKLAB_DIM = 1 # Perceptual delta magnitude (scalar)
def __init__(
self,
hidden_dim: int = 256,
vit_layers: int = 4,
vit_heads: int = 8,
num_regions: int = 8,
patch_size: int = 4,
num_refinement_iters: int = 2,
dropout: float = 0.1,
):
super().__init__()
self.hidden_dim = hidden_dim
        # Fix 1: HSLFeatureExtractor with circular hue encoding
        # h → (sin(2πh), cos(2πh)) handles hue wraparound correctly:
        # 359° and 1° are adjacent, not 358° apart
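        # A minimal sketch of the encoding (the real implementation lives in
        # HSLFeatureExtractor; variable names here are illustrative only):
        #     hue = hsl[..., 0:1]                               # (B, H, W, 1), in [0, 1]
        #     hue_sin = torch.sin(2 * torch.pi * hue)
        #     hue_cos = torch.cos(2 * torch.pi * hue)
        #     encoded = torch.cat([hue_sin, hue_cos, hsl[..., 1:]], dim=-1)  # (B, H, W, 4)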
self.feature_extractor = HSLFeatureExtractor(
hidden_dim=hidden_dim,
num_layers=vit_layers,
num_heads=vit_heads,
patch_size=patch_size,
dropout=dropout,
)
        # Fix 2: HybridRegionPooler - DETR-style learned queries
# use_structure=False because HSL palettes have NO scope markers
# Iterative refinement (Slot Attention style)
self.region_pooler = HybridRegionPooler(
hidden_dim=hidden_dim,
num_learned_queries=num_regions,
num_heads=vit_heads,
use_structure=False,
dropout=dropout,
num_refinement_iters=num_refinement_iters,
)
# Delta fusion: (before_regions, after_regions, delta) β†’ fused
self.delta_fusion = nn.Sequential(
nn.Linear(hidden_dim * 3, hidden_dim * 2),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim * 2, hidden_dim),
nn.LayerNorm(hidden_dim),
)
# Global pooling via attention
self.global_query = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)
self.global_attn = nn.MultiheadAttention(
hidden_dim, vit_heads, dropout=dropout, batch_first=True
)
# Fix 3: 33-dim profile delta projection
# Structural profile captures category distribution, color stats,
        # scope depth, spectral alignment - compressed direction signal
self.profile_proj = nn.Sequential(
nn.Linear(self.PROFILE_DIM, hidden_dim // 4),
nn.GELU(),
nn.LayerNorm(hidden_dim // 4),
)
# Fix 4: OKLab delta magnitude projection
        # Single scalar - "how big was this change" in perceptual space
self.oklab_proj = nn.Sequential(
nn.Linear(self.OKLAB_DIM, hidden_dim // 8),
nn.GELU(),
)
# Conditioning input size: hidden_dim + profile_proj + oklab_proj
cond_dim = hidden_dim + hidden_dim // 4 + hidden_dim // 8
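        # e.g. with the default hidden_dim=256: cond_dim = 256 + 64 + 32 = 352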
# Level classifier (primitive / structural / semantic)
self.level_head = nn.Sequential(
nn.Linear(cond_dim, hidden_dim // 2),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, 3),
)
# Fine-grained op classifier (31 classes)
# Conditioned on level logits (hierarchical)
self.op_head = nn.Sequential(
nn.Linear(cond_dim + 3, hidden_dim), # +3 for level logits
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim // 2, NUM_OPS),
)
self._init_weights()
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
def encode_palette(self, hsl: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
        Encode HSL palette → region embeddings + importance scores.
Args:
hsl: (B, H, W, 3) normalized HSL [0,1]
Returns:
regions: (B, R, hidden_dim) region embeddings
importance: (B, R) importance scores
"""
        # HSLFeatureExtractor: circular hue → ViT spatial features
features = self.feature_extractor(hsl) # (B, H, W, D)
        # HybridRegionPooler: DETR queries → region embeddings
regions, importance = self.region_pooler(features) # (B, R, D), (B, R)
return regions, importance
@staticmethod
def compute_oklab_delta(before_hsl: torch.Tensor, after_hsl: torch.Tensor) -> torch.Tensor:
"""
Compute perceptual change magnitude in OKLab space.
Returns:
            (B, 1) scalar - mean DeltaE across all spatial positions
"""
# Convert to OKLab
before_oklab = hsl_to_oklab_batch(before_hsl) # (B, H, W, 3)
after_oklab = hsl_to_oklab_batch(after_hsl) # (B, H, W, 3)
# Per-pixel DeltaE
delta_e = (before_oklab - after_oklab).pow(2).sum(dim=-1).sqrt() # (B, H, W)
# Mean across spatial dimensions
mean_delta_e = delta_e.mean(dim=(1, 2), keepdim=False) # (B,)
return mean_delta_e.unsqueeze(-1) # (B, 1)
def forward(
self,
before_hsl: torch.Tensor,
after_hsl: torch.Tensor,
profile_delta: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Classify edit operation from before/after palette pair.
Args:
before_hsl: (B, H, W, 3) palette before edit, HSL [0,1]
after_hsl: (B, H, W, 3) palette after edit, HSL [0,1]
profile_delta: (B, 33) optional structural profile delta (after - before)
If None, zeros are used (graceful degradation)
Returns:
op_logits: (B, 31) logits over edit operations
level_logits: (B, 3) logits over levels
global_feat: (B, hidden_dim) fused delta representation
"""
B = before_hsl.shape[0]
device = before_hsl.device
# Encode both palettes through shared feature extractor + pooler
before_regions, before_imp = self.encode_palette(before_hsl) # (B, R, D)
after_regions, after_imp = self.encode_palette(after_hsl) # (B, R, D)
# Compute delta (importance-weighted)
imp = (before_imp + after_imp) / 2 # (B, R)
imp_w = imp.unsqueeze(-1) # (B, R, 1)
delta = (after_regions - before_regions) * imp_w
# Fuse: [before, after, delta] β†’ fused features
fused = torch.cat([before_regions, after_regions, delta], dim=-1) # (B, R, 3*D)
fused = self.delta_fusion(fused) # (B, R, D)
# Global pool via attention
query = self.global_query.expand(B, -1, -1)
global_feat, _ = self.global_attn(query, fused, fused)
global_feat = global_feat.squeeze(1) # (B, D)
# Fix 3: Profile delta conditioning
if profile_delta is None:
profile_delta = torch.zeros(B, self.PROFILE_DIM, device=device)
profile_feat = self.profile_proj(profile_delta) # (B, D//4)
# Fix 4: OKLab delta magnitude
oklab_delta = self.compute_oklab_delta(before_hsl, after_hsl) # (B, 1)
oklab_feat = self.oklab_proj(oklab_delta) # (B, D//8)
# Concatenate all conditioning signals
conditioned = torch.cat([global_feat, profile_feat, oklab_feat], dim=-1) # (B, D + D//4 + D//8)
# Level classification
level_logits = self.level_head(conditioned) # (B, 3)
# Fine op classification (conditioned on level)
op_input = torch.cat([conditioned, level_logits], dim=-1)
op_logits = self.op_head(op_input) # (B, 31)
return op_logits, level_logits, global_feat
# ====================================================================
# Tier 1: Profile cosine similarity (nanoseconds)
# ====================================================================
class Tier1ProfileVerifier:
"""
Neuroswarm Tier 1: Nanosecond verification via 33-dim profile cosine similarity.
    Usage:
        verifier = Tier1ProfileVerifier()
        result = verifier.verify(expected_delta, actual_delta)
        tier = result['tiers'][0]
        if tier == 'pass': ...
        elif tier == 'escalate': ...  # → Tier 2
        elif tier == 'reject': ...    # → retry agent
"""
def __init__(
self,
pass_threshold: float = 0.7,
reject_threshold: float = 0.3,
):
self.pass_threshold = pass_threshold
self.reject_threshold = reject_threshold
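    # Worked example with the default thresholds above (illustrative numbers):
    #   cosine alignment 0.85 → 'pass', 0.55 → 'escalate' to Tier 2,
    #   0.10 → 'reject' and retry the agent.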
def verify(
self,
expected_delta: torch.Tensor,
actual_delta: torch.Tensor,
) -> dict:
"""
Compare expected vs actual structural profile delta.
Args:
expected_delta: (33,) or (B, 33) expected profile change
actual_delta: (33,) or (B, 33) actual profile change
Returns:
            dict with 'alignment' (B,), 'tiers' (list of 'pass'/'escalate'/'reject'),
            and 'mean_alignment'
"""
if expected_delta.dim() == 1:
expected_delta = expected_delta.unsqueeze(0)
actual_delta = actual_delta.unsqueeze(0)
# Cosine similarity
alignment = F.cosine_similarity(expected_delta, actual_delta, dim=-1) # (B,)
tiers = []
for a in alignment:
a_val = a.item()
if a_val >= self.pass_threshold:
tiers.append('pass')
elif a_val >= self.reject_threshold:
tiers.append('escalate')
else:
tiers.append('reject')
return {
'alignment': alignment,
'tiers': tiers,
'mean_alignment': alignment.mean().item(),
}
# ====================================================================
# Tier 2: Edit classifier inference wrapper
# ====================================================================
class Tier2EditVerifier:
"""
Neuroswarm Tier 2: Millisecond verification via edit classifier.
Usage:
verifier = Tier2EditVerifier(model, device='cuda')
result = verifier.verify(before_hsl, after_hsl, expected_op, profile_delta)
if result['match']: ... # agent did the right thing
else: ... # escalate to Tier 3
"""
def __init__(
self,
model: EditOpClassifier,
device: str = 'cpu',
confidence_threshold: float = 0.8,
):
self.model = model.to(device).eval()
self.device = device
self.confidence_threshold = confidence_threshold
@torch.no_grad()
def verify(
self,
before_hsl: torch.Tensor,
after_hsl: torch.Tensor,
expected_op: OpCode,
profile_delta: Optional[torch.Tensor] = None,
) -> dict:
"""
Verify that an agent performed the expected edit operation.
Returns:
dict with 'match', 'predicted_op', 'confidence', 'escalate'
"""
before = before_hsl.unsqueeze(0).to(self.device) if before_hsl.dim() == 3 else before_hsl.to(self.device)
after = after_hsl.unsqueeze(0).to(self.device) if after_hsl.dim() == 3 else after_hsl.to(self.device)
if profile_delta is not None:
profile_delta = profile_delta.unsqueeze(0).to(self.device) if profile_delta.dim() == 1 else profile_delta.to(self.device)
op_logits, level_logits, _ = self.model(before, after, profile_delta)
probs = F.softmax(op_logits, dim=-1)
pred_idx = probs.argmax(dim=-1).item()
confidence = probs[0, pred_idx].item()
predicted_op = IDX_TO_OP[pred_idx]
expected_idx = OP_TO_IDX[expected_op]
match = (pred_idx == expected_idx) and (confidence >= self.confidence_threshold)
escalate = not match
return {
'match': match,
'predicted_op': predicted_op,
'predicted_op_name': predicted_op.name,
'expected_op_name': expected_op.name,
'confidence': confidence,
'escalate': escalate,
'op_probs': probs[0].cpu(),
}
# ====================================================================
# Loss
# ====================================================================
class EditOpLoss(nn.Module):
"""
Combined loss for edit op classification.
Components:
- Cross-entropy on 31-class op prediction
- Cross-entropy on 3-class level prediction (auxiliary)
- Level-op consistency penalty
"""
def __init__(self, level_weight: float = 0.3, consistency_weight: float = 0.1):
super().__init__()
self.level_weight = level_weight
self.consistency_weight = consistency_weight
self.op_loss_fn = nn.CrossEntropyLoss(label_smoothing=0.05)
self.level_loss_fn = nn.CrossEntropyLoss(label_smoothing=0.05)
        # Build op → level mapping
self._op_to_level = {}
level_names = ['primitive', 'structural', 'semantic']
for op in TRAINABLE_OPS:
level = OP_LEVEL[op]
self._op_to_level[OP_TO_IDX[op]] = level_names.index(level)
def forward(
self,
op_logits: torch.Tensor,
level_logits: torch.Tensor,
op_labels: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, float]]:
"""
Args:
op_logits: (B, 31) predicted op logits
level_logits: (B, 3) predicted level logits
op_labels: (B,) integer labels in [0, 30]
Returns:
total_loss, metrics_dict
"""
op_loss = self.op_loss_fn(op_logits, op_labels)
level_labels = torch.tensor(
[self._op_to_level[l.item()] for l in op_labels],
device=op_labels.device, dtype=torch.long
)
level_loss = self.level_loss_fn(level_logits, level_labels)
        # Level-op consistency: does the predicted op's level agree with the level
        # head's prediction? Both sides come from argmax, so this term carries no
        # gradient; it acts as a monitored penalty rather than a training signal.
        pred_ops = op_logits.argmax(dim=-1)
        pred_levels = level_logits.argmax(dim=-1)
        expected_levels = torch.tensor(
            [self._op_to_level[p.item()] for p in pred_ops],
            device=op_labels.device, dtype=torch.long,
        )
        consistency = (pred_levels == expected_levels).float().mean()
        consistency_loss = 1.0 - consistency
total = op_loss + self.level_weight * level_loss + self.consistency_weight * consistency_loss
metrics = {
'loss': total.item(),
'op_loss': op_loss.item(),
'level_loss': level_loss.item(),
'consistency': consistency.item(),
'op_acc': (pred_ops == op_labels).float().mean().item(),
'level_acc': (pred_levels == level_labels).float().mean().item(),
}
return total, metrics
@staticmethod
def op_label_from_opcode(opcode: OpCode) -> int:
return OP_TO_IDX[opcode]
@staticmethod
def opcode_from_label(label: int) -> OpCode:
return IDX_TO_OP[label]
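

# Minimal end-to-end smoke test tying the pieces together. This is an
# illustrative sketch: the 16x16 palette size, batch size, and random inputs
# are assumptions for the demo, not the training configuration, and it relies
# on the sibling modules behaving as described above. Because this file uses
# relative imports, run it as a module from the repo root, e.g.
# `python -m models.edit_classifier`.
if __name__ == "__main__":
    torch.manual_seed(0)

    # Tier 1: profile cosine similarity on a 33-dim structural delta
    expected_delta = torch.randn(33)
    actual_delta = expected_delta + 0.1 * torch.randn(33)
    tier1 = Tier1ProfileVerifier()
    print("Tier 1:", tier1.verify(expected_delta, actual_delta)["tiers"])

    # Tier 2: edit classifier on random before/after HSL palettes
    model = EditOpClassifier()
    before = torch.rand(2, 16, 16, 3)
    after = torch.rand(2, 16, 16, 3)
    op_logits, level_logits, global_feat = model(before, after)
    print("Tier 2 shapes:", op_logits.shape, level_logits.shape, global_feat.shape)

    tier2 = Tier2EditVerifier(model, device="cpu")
    result = tier2.verify(before[0], after[0], expected_op=TRAINABLE_OPS[0])
    print("Tier 2 verify:", result["predicted_op_name"], result["confidence"], result["escalate"])

    # Loss on random labels
    loss_fn = EditOpLoss()
    labels = torch.randint(0, NUM_OPS, (2,))
    total, metrics = loss_fn(op_logits, level_logits, labels)
    print("Loss:", metrics)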