Buckets:
bbkdevops/unicosys-hypergraph-bucket / tinymind-native-8b-remote-handoff /bundle /model /stable_mesh_security.py
| """Stable Mesh security classifier for defensive CWE-style classification. | |
| This module is a lightweight, testable implementation of the mesh blueprint: | |
| - base code-model backbones are frozen by default, | |
| - only projections, cycle gates/norms, router, and security head are trainable, | |
| - Stage 1 can additionally freeze the router, | |
| - tokenization/data collation is isolated from the model so raw code or | |
| disassembly can be handled as classification inputs. | |
| The class supports injected backbones for tests and small local experiments. | |
| Loading real Hugging Face models should be done by the caller with quantization | |
| and device mapping appropriate to the machine; this module never calls | |
| ``model.cuda()`` or ``model.to("cuda")`` on quantized bases. | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import Protocol | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import Dataset | |
| from .layers import RMSNorm | |
class TokenizerLike(Protocol):
    """Structural type for the tokenizer callable consumed by this module.

    Any object whose ``__call__`` takes a batch of strings plus the
    keyword-only options below satisfies the protocol; no inheritance is
    required.
    """

    def __call__(
        self,
        texts: list[str],
        *,
        padding: str,
        truncation: bool,
        max_length: int,
        return_tensors: str | None = None,
    ):
        """Tokenize ``texts``.

        The return type is intentionally left unannotated by the protocol.
        """
        ...
@dataclass
class StableMeshConfig:
    """Hyper-parameters for the Stable Mesh security classifier.

    The class is a dataclass so that fields can be overridden by keyword at
    construction time and so that ``__post_init__`` validation actually runs.

    Raises:
        ValueError: if a size/count field is not strictly positive or if
            ``dropout`` lies outside ``[0, 1]``.
    """

    # Number of output classes of the security head.
    num_cwe_classes: int = 118
    # Number of routing/refinement cycles run in the classifier's forward pass.
    num_cycles: int = 3
    # Width of the shared mesh space the node features are projected into.
    target_dim: int = 4096
    # Expected feature width produced by every node.
    node_dim: int = 4096
    # Dropout probability applied before the security head.
    dropout: float = 0.05
    # When True, the router stays frozen during stage-1 training.
    stage1_freeze_router: bool = True

    def __post_init__(self) -> None:
        """Validate field values right after dataclass initialisation."""
        if self.num_cwe_classes <= 0:
            raise ValueError("num_cwe_classes must be positive")
        if self.num_cycles <= 0:
            raise ValueError("num_cycles must be positive")
        if self.target_dim <= 0 or self.node_dim <= 0:
            raise ValueError("target_dim and node_dim must be positive")
        # Fail at config time rather than later inside ``nn.Dropout``.
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError("dropout must be in the range [0, 1]")
class FrozenBackboneNode(nn.Module):
    """Wrap a feature extractor, freeze it, and mean-pool its hidden states.

    All backbone parameters get ``requires_grad = False`` at construction and
    the backbone is always executed under ``torch.no_grad()``.

    NOTE(review): the backbone's train/eval mode is not touched here, so any
    dropout inside it follows the parent module's mode - confirm that is
    intended.
    """

    def __init__(self, backbone: nn.Module, output_dim: int):
        """Store ``backbone`` and freeze every one of its parameters.

        Args:
            backbone: module called as
                ``backbone(input_ids=..., attention_mask=...)``.
            output_dim: feature width reported by this node; stored only,
                not checked against the backbone output.
        """
        super().__init__()
        self.backbone = backbone
        self.output_dim = output_dim
        for param in self.backbone.parameters():
            param.requires_grad = False

    @staticmethod
    def _extract_hidden(out) -> torch.Tensor:
        """Pull the final hidden-state tensor out of a backbone output.

        Accepted shapes of ``out``: a bare tensor, a mapping with
        ``last_hidden_state`` or ``hidden_states``, an object exposing a
        ``last_hidden_state`` attribute, or a non-empty tuple/list whose first
        element is the hidden state.

        Raises:
            ValueError: if no hidden state can be located.
        """
        if isinstance(out, torch.Tensor):
            return out
        if isinstance(out, dict):
            hidden = out.get("last_hidden_state")
            if hidden is None:
                states = out.get("hidden_states")
                # A present-but-None or empty entry means "no hidden states".
                if states is not None and len(states) > 0:
                    hidden = states[-1]
        else:
            hidden = getattr(out, "last_hidden_state", None)
            if hidden is None and isinstance(out, (tuple, list)) and len(out) > 0:
                hidden = out[0]
        if hidden is None:
            raise ValueError("backbone must return last_hidden_state-compatible output")
        return hidden

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
        """Run the backbone without gradients and return masked-mean features.

        Args:
            input_ids: forwarded to the backbone unchanged.
            attention_mask: optional mask; positions with 0 are excluded from
                the mean. When omitted, every position counts.

        Returns:
            The hidden states averaged over dimension 1.

        Raises:
            ValueError: if the backbone output carries no hidden state.
        """
        with torch.no_grad():
            out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        hidden = self._extract_hidden(out)
        if attention_mask is None:
            mask = torch.ones_like(hidden[..., :1])
        else:
            # Follow the hidden states' device as well as dtype so the pooling
            # also works when the backbone lives on a different device than
            # the inputs.
            mask = attention_mask.unsqueeze(-1).to(device=hidden.device, dtype=hidden.dtype)
        # clamp_min guards against division by zero for fully masked rows.
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1.0)
class StableMeshSecurityClassifier(nn.Module):
    """Mesh classifier over three frozen code-understanding nodes.

    Each node yields one feature vector per sample. The vectors are projected
    into a shared space, mixed over ``cfg.num_cycles`` routing cycles, and the
    resulting state is classified by the security head.
    """

    # Top-level attribute names whose parameters may be trained.
    TRAINABLE_MARKERS = ("proj_in", "proj_out", "mesh_router", "cycle_gate", "cycle_norm", "security_head")

    def __init__(
        self,
        *,
        g_node: nn.Module,
        s_node: nn.Module,
        y_node: nn.Module,
        cfg: StableMeshConfig | None = None,
    ):
        """Build the mesh layers around the three injected nodes.

        Args:
            g_node, s_node, y_node: modules called as
                ``node(input_ids=..., attention_mask=...)`` that return
                features whose last dimension equals ``cfg.node_dim``.
            cfg: configuration; a default ``StableMeshConfig`` is used when
                omitted.
        """
        super().__init__()
        self.cfg = cfg if cfg is not None else StableMeshConfig()
        self.nodes = nn.ModuleList([g_node, s_node, y_node])
        self.proj_in = nn.ModuleList([nn.Linear(self.cfg.node_dim, self.cfg.target_dim, bias=False) for _ in range(3)])
        self.proj_out = nn.Linear(self.cfg.target_dim, self.cfg.target_dim, bias=False)
        self.mesh_router = nn.Linear(self.cfg.target_dim, 3, bias=True)
        # One gate logit per (cycle, node); sigmoid(2.0) ~= 0.88 at init.
        self.cycle_gate = nn.Parameter(torch.full((self.cfg.num_cycles, 3), 2.0))
        self.cycle_norm = nn.ModuleList([RMSNorm(self.cfg.target_dim) for _ in range(self.cfg.num_cycles)])
        self.dropout = nn.Dropout(self.cfg.dropout)
        self.security_head = nn.Linear(self.cfg.target_dim, self.cfg.num_cwe_classes)
        self.freeze_for_stage1(stage1_freeze_router=self.cfg.stage1_freeze_router)

    @staticmethod
    def _owned_by(name: str, marker: str) -> bool:
        """Return True when parameter ``name`` is ``marker`` or lives under it."""
        return name == marker or name.startswith(marker + ".")

    def freeze_for_stage1(self, *, stage1_freeze_router: bool = True) -> dict[str, int]:
        """Set ``requires_grad`` on every parameter for stage-1 training.

        Only parameters that belong to one of ``TRAINABLE_MARKERS`` are
        trainable. Markers are matched against the start of the dotted
        parameter path, not as substrings, so a backbone parameter that merely
        contains a marker in its name (for example
        ``nodes.0.backbone.attn.proj_out.weight``) stays frozen.

        Args:
            stage1_freeze_router: when True, ``mesh_router`` is frozen too.

        Returns:
            ``{"trainable_params": int, "frozen_params": int}`` element counts.
        """
        trainable = 0
        frozen = 0
        for name, param in self.named_parameters():
            should_train = any(self._owned_by(name, marker) for marker in self.TRAINABLE_MARKERS)
            if stage1_freeze_router and self._owned_by(name, "mesh_router"):
                should_train = False
            param.requires_grad = should_train
            if should_train:
                trainable += param.numel()
            else:
                frozen += param.numel()
        return {"trainable_params": trainable, "frozen_params": frozen}

    def unfreeze_router_for_stage2(self) -> dict[str, int]:
        """Re-apply the freezing policy with the router trainable."""
        return self.freeze_for_stage1(stage1_freeze_router=False)

    def _node_features(self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None) -> list[torch.Tensor]:
        """Run every node and project its features to ``cfg.target_dim``.

        Raises:
            ValueError: if a node's last feature dimension is not
                ``cfg.node_dim``.
        """
        feats = []
        for node, proj in zip(self.nodes, self.proj_in):
            feature = node(input_ids=input_ids, attention_mask=attention_mask)
            if feature.shape[-1] != self.cfg.node_dim:
                raise ValueError(f"node feature dim {feature.shape[-1]} != expected {self.cfg.node_dim}")
            feats.append(proj(feature))
        return feats

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        return_metrics: bool = False,
    ) -> dict[str, torch.Tensor | dict]:
        """Classify a batch.

        Args:
            input_ids: token ids handed to every node.
            attention_mask: optional mask handed to every node.
            labels: optional class indices; when given, a cross-entropy loss
                is added to the result.
            return_metrics: when True, routing statistics and parameter counts
                are added under ``"metrics"``.

        Returns:
            A dict with ``"logits"`` and, depending on the arguments,
            ``"loss"`` and ``"metrics"``.
        """
        features = self._node_features(input_ids, attention_mask)
        # Node axis is dim 1; the initial state is the plain average of nodes.
        stacked = torch.stack(features, dim=1)
        state = stacked.mean(dim=1)
        route_snapshots = []
        for cycle_idx in range(self.cfg.num_cycles):
            normed = self.cycle_norm[cycle_idx](state)
            # Per-sample distribution over the three nodes.
            route = torch.softmax(self.mesh_router(normed), dim=-1)
            # Learned per-cycle, per-node gate broadcast over batch and width.
            gate = torch.sigmoid(self.cycle_gate[cycle_idx]).view(1, 3, 1)
            mixed = (stacked * route.unsqueeze(-1) * gate).sum(dim=1)
            # Bounded residual update whose step shrinks with the cycle index.
            state = state + torch.tanh(mixed) / float(cycle_idx + 2)
            route_snapshots.append(route.detach())
        logits = self.security_head(self.dropout(self.proj_out(state)))
        result: dict[str, torch.Tensor | dict] = {"logits": logits}
        if labels is not None:
            result["loss"] = torch.nn.functional.cross_entropy(logits, labels)
        if return_metrics:
            # Average routing weights over cycles and batch -> one value per node.
            route_mean = torch.stack(route_snapshots).mean(dim=(0, 1))
            result["metrics"] = {
                "route_mean": route_mean,
                "route_entropy": -(route_mean * torch.log(route_mean.clamp_min(1e-8))).sum(),
                "trainable_params": sum(p.numel() for p in self.parameters() if p.requires_grad),
                "frozen_params": sum(p.numel() for p in self.parameters() if not p.requires_grad),
            }
        return result
class PurpleBinaryDataset(Dataset):
    """Tokenized defensive security classification dataset.

    Holds pre-tokenized ``input_ids`` / ``attention_mask`` lists next to one
    integer label per sample and serves them as ``torch.long`` tensors.
    """

    # Encoding entries read by ``__getitem__``; each must align with the labels.
    _REQUIRED_KEYS = ("input_ids", "attention_mask")

    def __init__(self, encodings: dict[str, list[list[int]]], label_indices: list[int]):
        """Store the encodings after checking they line up with the labels.

        Args:
            encodings: mapping with ``"input_ids"`` and ``"attention_mask"``,
                each holding one token list per sample.
            label_indices: one integer class index per sample.

        Raises:
            ValueError: if either encoding entry does not contain exactly one
                row per label. A missing entry counts as zero rows.
        """
        expected = len(label_indices)
        # Check both entries up front so a missing or short attention mask
        # fails here instead of later inside ``__getitem__``.
        for key in self._REQUIRED_KEYS:
            if len(encodings.get(key, [])) != expected:
                raise ValueError(f"{key} and labels must have the same length")
        self.encodings = encodings
        self.labels = label_indices

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.labels)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Return sample ``idx`` as a dict of ``torch.long`` tensors."""
        item = {
            "input_ids": torch.tensor(self.encodings["input_ids"][idx], dtype=torch.long),
            "attention_mask": torch.tensor(self.encodings["attention_mask"][idx], dtype=torch.long),
            "labels": torch.tensor(self.labels[idx], dtype=torch.long),
        }
        return item
def tokenize_security_records(
    records: list[dict],
    tokenizer: TokenizerLike,
    *,
    label_to_id: dict[str, int],
    max_length: int = 2048,
    text_key: str = "text",
    label_key: str = "label",
) -> PurpleBinaryDataset:
    """Tokenize labelled records into a ``PurpleBinaryDataset``.

    Records whose text is missing, ``None``, or blank after stripping are
    skipped. The remaining texts are tokenized in a single call, padded and
    truncated to ``max_length``.

    Args:
        records: dicts carrying a text and a label entry.
        tokenizer: callable matching ``TokenizerLike``; its result must offer
            ``"input_ids"`` and ``"attention_mask"`` entries.
        label_to_id: maps the stripped label string to its class index.
        max_length: padding/truncation length handed to the tokenizer.
        text_key: record key holding the text.
        label_key: record key holding the label.

    Returns:
        A dataset with one sample per kept record.

    Raises:
        ValueError: if a kept record's label is not in ``label_to_id``.
    """
    texts: list[str] = []
    labels: list[int] = []
    for record in records:
        raw_text = record.get(text_key)
        # An explicit None must count as "no text"; str(None) would otherwise
        # yield the literal string "None" and be tokenized as real input.
        text = "" if raw_text is None else str(raw_text).strip()
        if not text:
            continue
        label = str(record.get(label_key, "")).strip()
        if label not in label_to_id:
            raise ValueError(f"unknown label: {label}")
        texts.append(text)
        labels.append(label_to_id[label])
    if not texts:
        # Nothing to encode: return an empty dataset without handing the
        # tokenizer an empty batch.
        return PurpleBinaryDataset({"input_ids": [], "attention_mask": []}, labels)
    encoded = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors=None,
    )
    return PurpleBinaryDataset(
        {"input_ids": encoded["input_ids"], "attention_mask": encoded["attention_mask"]},
        labels,
    )
def build_stage1_optimizer(model: StableMeshSecurityClassifier, *, lr: float = 1e-4, weight_decay: float = 0.01):
    """Create an AdamW optimizer over the parameters that still require gradients.

    Args:
        model: module whose ``requires_grad`` flags have already been set.
        lr: learning rate passed to AdamW.
        weight_decay: weight decay passed to AdamW.

    Returns:
        A ``torch.optim.AdamW`` instance covering only trainable parameters.

    Raises:
        ValueError: if the model has no parameter with ``requires_grad=True``.
    """
    trainable = list(filter(lambda tensor: tensor.requires_grad, model.parameters()))
    if len(trainable) == 0:
        raise ValueError("no trainable parameters after stage1 freezing")
    optimizer = torch.optim.AdamW(trainable, lr=lr, weight_decay=weight_decay)
    return optimizer
Xet Storage Details
- Size: 9.21 kB
- Xet hash: e4d4f999b18784602a5ed5b1640121cd840f4724eac071b21d24276350281048
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.