bbkdevops's picture
download
raw
9.21 kB
"""Stable Mesh security classifier for defensive CWE-style classification.
This module is a lightweight, testable implementation of the mesh blueprint:
- base code-model backbones are frozen by default,
- only projections, cycle gates/norms, router, and security head are trainable,
- Stage 1 can additionally freeze the router,
- tokenization/data collation is isolated from the model so raw code or
disassembly can be handled as classification inputs.
The class supports injected backbones for tests and small local experiments.
Loading real Hugging Face models should be done by the caller with quantization
and device mapping appropriate to the machine; this module never calls
``model.cuda()`` or ``model.to("cuda")`` on quantized bases.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from .layers import RMSNorm
class TokenizerLike(Protocol):
    """Structural type for the batch tokenizer callable used by this module.

    Any object whose ``__call__`` accepts a list of strings plus the
    keyword-only options below satisfies the protocol. The call site in
    ``tokenize_security_records`` indexes the result with ``"input_ids"``
    and ``"attention_mask"``.
    """

    def __call__(
        self,
        texts: list[str],
        *,
        padding: str,
        truncation: bool,
        max_length: int,
        return_tensors: str | None = None,
    ):
        """Tokenize ``texts`` as one batch.

        Args:
            texts: Input strings to encode.
            padding: Padding strategy name (this module passes
                ``"max_length"``).
            truncation: Whether to cut inputs longer than ``max_length``.
            max_length: Maximum number of tokens per encoded sequence.
            return_tensors: Optional tensor-format selector; ``None`` is what
                this module passes.
        """
@dataclass
class StableMeshConfig:
    """Hyperparameters for ``StableMeshSecurityClassifier``.

    All values are validated on construction so that a bad configuration
    fails before any layer is allocated.

    Raises:
        ValueError: If a count or dimension is not strictly positive, or if
            ``dropout`` lies outside ``[0, 1]``.
    """

    # Output width of the classification head (number of target classes).
    num_cwe_classes: int = 118
    # Number of routing/refinement cycles run in the classifier forward pass.
    num_cycles: int = 3
    # Width of the shared mesh space every node feature is projected into.
    target_dim: int = 4096
    # Expected width of the pooled feature each node returns.
    node_dim: int = 4096
    # Dropout probability applied just before the classification head.
    dropout: float = 0.05
    # When True the router parameters are frozen at construction (stage 1).
    stage1_freeze_router: bool = True

    def __post_init__(self) -> None:
        if self.num_cwe_classes <= 0:
            raise ValueError("num_cwe_classes must be positive")
        if self.num_cycles <= 0:
            raise ValueError("num_cycles must be positive")
        if self.target_dim <= 0 or self.node_dim <= 0:
            raise ValueError("target_dim and node_dim must be positive")
        # nn.Dropout only accepts probabilities in [0, 1]; checking here keeps
        # the failure next to the other config errors. The chained comparison
        # also rejects NaN.
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError("dropout must be in the range [0, 1]")
class FrozenBackboneNode(nn.Module):
    """Small wrapper around a frozen feature extractor.

    The wrapped backbone has every parameter marked ``requires_grad=False``,
    is always run under ``torch.no_grad()``, and is pinned to evaluation mode
    so that dropout or normalisation statistics inside it cannot make the
    "frozen" features change between calls while the outer model trains.

    Args:
        backbone: Module called as ``backbone(input_ids=..., attention_mask=...)``.
        output_dim: Feature width advertised by this node; stored as an
            attribute and not otherwise checked here.
    """

    def __init__(self, backbone: nn.Module, output_dim: int):
        super().__init__()
        self.backbone = backbone
        self.output_dim = output_dim
        for param in self.backbone.parameters():
            param.requires_grad = False
        # Frozen weights alone do not disable dropout; eval mode does.
        self.backbone.eval()

    def train(self, mode: bool = True) -> "FrozenBackboneNode":
        """Switch the wrapper's mode but keep the backbone in eval mode.

        ``nn.Module.train`` recurses into children, which would re-enable
        dropout inside the frozen backbone whenever the parent model calls
        ``.train()``; the backbone is reset to eval afterwards.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    @staticmethod
    def _extract_hidden(out) -> torch.Tensor | None:
        """Return the final hidden-state tensor from a backbone output, or None."""
        if isinstance(out, torch.Tensor):
            # Plain feature extractors may return the hidden states directly.
            return out
        if isinstance(out, dict):
            hidden = out.get("last_hidden_state")
            if hidden is None:
                states = out.get("hidden_states")
                # Guard against a present-but-None or empty entry, which would
                # otherwise fail with TypeError/IndexError on ``[-1]``.
                if states is not None and len(states) > 0:
                    hidden = states[-1]
            return hidden
        hidden = getattr(out, "last_hidden_state", None)
        if hidden is None and isinstance(out, (tuple, list)) and len(out) > 0:
            hidden = out[0]
        return hidden

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
        """Run the backbone without gradients and mean-pool over dimension 1.

        Args:
            input_ids: Token ids forwarded to the backbone unchanged.
            attention_mask: Optional mask forwarded to the backbone and used
                as pooling weights; positions with 0 are excluded from the
                mean. Assumed to be ``(batch, seq)`` matching the hidden
                states' first two dimensions -- TODO confirm with callers.

        Returns:
            The hidden states averaged over dimension 1, weighted by the mask
            (or uniformly when no mask is given).

        Raises:
            ValueError: If no hidden-state tensor can be found in the
                backbone output.
        """
        with torch.no_grad():
            out = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        hidden = self._extract_hidden(out)
        if hidden is None:
            raise ValueError("backbone must return last_hidden_state-compatible output")
        if attention_mask is not None:
            # Device-mapped backbones can return hidden states on a different
            # device than the inputs, so align device as well as dtype.
            mask = attention_mask.unsqueeze(-1).to(device=hidden.device, dtype=hidden.dtype)
        else:
            mask = torch.ones_like(hidden[..., :1])
        # clamp_min keeps fully-masked rows from dividing by zero.
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1.0)
class StableMeshSecurityClassifier(nn.Module):
    """Mesh classifier over three frozen code-understanding nodes.

    Each node maps ``(input_ids, attention_mask)`` to a pooled feature of
    width ``cfg.node_dim``. The features are projected to ``cfg.target_dim``,
    mixed over ``cfg.num_cycles`` routing cycles, and classified by a linear
    head with ``cfg.num_cwe_classes`` outputs.

    Only the modules/parameters that this class itself owns under the names in
    ``TRAINABLE_MARKERS`` are ever trainable; everything under ``nodes`` stays
    frozen.
    """

    # Top-level attribute names of this module that may receive gradients.
    TRAINABLE_MARKERS = ("proj_in", "proj_out", "mesh_router", "cycle_gate", "cycle_norm", "security_head")

    def __init__(
        self,
        *,
        g_node: nn.Module,
        s_node: nn.Module,
        y_node: nn.Module,
        cfg: StableMeshConfig | None = None,
    ):
        """Build the mesh around three caller-supplied nodes.

        Args:
            g_node, s_node, y_node: Feature nodes called as
                ``node(input_ids=..., attention_mask=...)``.
            cfg: Hyperparameters; a default ``StableMeshConfig`` is used when
                omitted.
        """
        super().__init__()
        self.cfg = cfg or StableMeshConfig()
        self.nodes = nn.ModuleList([g_node, s_node, y_node])
        self.proj_in = nn.ModuleList(
            [nn.Linear(self.cfg.node_dim, self.cfg.target_dim, bias=False) for _ in range(3)]
        )
        self.proj_out = nn.Linear(self.cfg.target_dim, self.cfg.target_dim, bias=False)
        self.mesh_router = nn.Linear(self.cfg.target_dim, 3, bias=True)
        # One gate logit per (cycle, node); sigmoid(2.0) ~= 0.88 at init.
        self.cycle_gate = nn.Parameter(torch.full((self.cfg.num_cycles, 3), 2.0))
        self.cycle_norm = nn.ModuleList([RMSNorm(self.cfg.target_dim) for _ in range(self.cfg.num_cycles)])
        self.dropout = nn.Dropout(self.cfg.dropout)
        self.security_head = nn.Linear(self.cfg.target_dim, self.cfg.num_cwe_classes)
        self.freeze_for_stage1(stage1_freeze_router=self.cfg.stage1_freeze_router)

    def freeze_for_stage1(self, *, stage1_freeze_router: bool = True) -> dict[str, int]:
        """Set ``requires_grad`` on every parameter according to the stage policy.

        A parameter is trainable only when the *top-level* attribute that owns
        it is listed in ``TRAINABLE_MARKERS``. Matching on the first component
        of the dotted name (rather than any substring) keeps backbone
        parameters such as ``nodes.0.backbone.proj_out.weight`` frozen even
        when their inner names collide with a marker.

        Args:
            stage1_freeze_router: When True, ``mesh_router`` is frozen too.

        Returns:
            ``{"trainable_params": int, "frozen_params": int}`` element counts.
        """
        trainable = 0
        frozen = 0
        for name, param in self.named_parameters():
            owner = name.split(".", 1)[0]
            should_train = owner in self.TRAINABLE_MARKERS
            if stage1_freeze_router and owner == "mesh_router":
                should_train = False
            param.requires_grad = should_train
            if should_train:
                trainable += param.numel()
            else:
                frozen += param.numel()
        return {"trainable_params": trainable, "frozen_params": frozen}

    def unfreeze_router_for_stage2(self) -> dict[str, int]:
        """Re-apply the freezing policy with the router trainable.

        Returns:
            The same count dictionary as ``freeze_for_stage1``.
        """
        return self.freeze_for_stage1(stage1_freeze_router=False)

    def _node_features(self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None) -> list[torch.Tensor]:
        """Run every node and project its pooled feature into the mesh space.

        Raises:
            ValueError: If a node's feature width differs from ``cfg.node_dim``.
        """
        feats = []
        for node, proj in zip(self.nodes, self.proj_in):
            feature = node(input_ids=input_ids, attention_mask=attention_mask)
            if feature.shape[-1] != self.cfg.node_dim:
                raise ValueError(f"node feature dim {feature.shape[-1]} != expected {self.cfg.node_dim}")
            # Half-precision or device-mapped backbones can emit features whose
            # dtype/device differ from the projection weights; align them so
            # the linear layer does not fail on a mismatch.
            feature = feature.to(device=proj.weight.device, dtype=proj.weight.dtype)
            feats.append(proj(feature))
        return feats

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        return_metrics: bool = False,
    ) -> dict[str, torch.Tensor | dict]:
        """Classify a batch and optionally compute loss and routing metrics.

        Args:
            input_ids: Token ids passed to every node.
            attention_mask: Optional mask passed to every node.
            labels: Optional class indices; when given, a cross-entropy loss
                is added under ``"loss"``.
            return_metrics: When True, routing statistics and parameter
                counts are added under ``"metrics"``.

        Returns:
            A dict with ``"logits"`` and, depending on the arguments,
            ``"loss"`` and ``"metrics"``.
        """
        features = self._node_features(input_ids, attention_mask)
        # Assumes each node returns a 2-D (batch, node_dim) feature, so the
        # stack is (batch, 3, target_dim) -- the gate reshape below relies on it.
        stacked = torch.stack(features, dim=1)
        state = stacked.mean(dim=1)
        route_snapshots = []
        for cycle_idx in range(self.cfg.num_cycles):
            normed = self.cycle_norm[cycle_idx](state)
            # Per-sample softmax weights over the three nodes.
            route = torch.softmax(self.mesh_router(normed), dim=-1)
            gate = torch.sigmoid(self.cycle_gate[cycle_idx]).view(1, 3, 1)
            mixed = (stacked * route.unsqueeze(-1) * gate).sum(dim=1)
            # Bounded residual update whose step shrinks as 1/2, 1/3, 1/4, ...
            state = state + torch.tanh(mixed) / float(cycle_idx + 2)
            # Detached: snapshots are for reporting only.
            route_snapshots.append(route.detach())
        logits = self.security_head(self.dropout(self.proj_out(state)))
        result: dict[str, torch.Tensor | dict] = {"logits": logits}
        if labels is not None:
            result["loss"] = torch.nn.functional.cross_entropy(logits, labels)
        if return_metrics:
            # Average routing weights over cycles and batch -> one value per node.
            route_mean = torch.stack(route_snapshots).mean(dim=(0, 1))
            result["metrics"] = {
                "route_mean": route_mean,
                "route_entropy": -(route_mean * torch.log(route_mean.clamp_min(1e-8))).sum(),
                "trainable_params": sum(p.numel() for p in self.parameters() if p.requires_grad),
                "frozen_params": sum(p.numel() for p in self.parameters() if not p.requires_grad),
            }
        return result
class PurpleBinaryDataset(Dataset):
    """Tokenized defensive security classification dataset.

    Holds pre-tokenized encodings and integer labels and serves one example
    at a time as ``torch.long`` tensors.

    Args:
        encodings: Mapping with ``"input_ids"`` and ``"attention_mask"``
            entries, each a list with one token-id / mask list per example.
        label_indices: Integer class index per example.

    Raises:
        ValueError: If ``input_ids`` or ``attention_mask`` does not contain
            exactly one entry per label.
    """

    def __init__(self, encodings: dict[str, list[list[int]]], label_indices: list[int]):
        expected = len(label_indices)
        if len(encodings.get("input_ids", [])) != expected:
            raise ValueError("input_ids and labels must have the same length")
        # __getitem__ always reads attention_mask, so a missing or misaligned
        # mask is rejected here instead of failing (or silently pairing the
        # wrong rows) during iteration.
        if len(encodings.get("attention_mask", [])) != expected:
            raise ValueError("attention_mask and labels must have the same length")
        self.encodings = encodings
        self.labels = label_indices

    def __len__(self) -> int:
        """Return the number of examples."""
        return len(self.labels)

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Return example ``idx`` as ``input_ids``, ``attention_mask`` and ``labels`` tensors."""
        return {
            "input_ids": torch.tensor(self.encodings["input_ids"][idx], dtype=torch.long),
            "attention_mask": torch.tensor(self.encodings["attention_mask"][idx], dtype=torch.long),
            "labels": torch.tensor(self.labels[idx], dtype=torch.long),
        }
def tokenize_security_records(
    records: list[dict],
    tokenizer: TokenizerLike,
    *,
    label_to_id: dict[str, int],
    max_length: int = 2048,
    text_key: str = "text",
    label_key: str = "label",
) -> PurpleBinaryDataset:
    """Turn raw labelled records into a tokenized ``PurpleBinaryDataset``.

    Records whose text is missing, ``None`` or blank after stripping are
    skipped. Every kept record must carry a label present in ``label_to_id``.

    Args:
        records: Dict-like records holding a text and a label field.
        tokenizer: Batch tokenizer; called once with ``padding="max_length"``
            and ``truncation=True``.
        label_to_id: Mapping from label string to integer class index.
        max_length: Token length each example is padded/truncated to.
        text_key: Record key holding the input text.
        label_key: Record key holding the label.

    Returns:
        A dataset with one entry per kept record, in input order.

    Raises:
        ValueError: If a kept record's label is not in ``label_to_id``.
    """
    texts: list[str] = []
    labels: list[int] = []
    for record in records:
        raw_text = record.get(text_key)
        # An explicit None must count as "no text"; str(None) would otherwise
        # feed the literal string "None" to the tokenizer as a real sample.
        text = "" if raw_text is None else str(raw_text).strip()
        if not text:
            continue
        label = str(record.get(label_key, "")).strip()
        if label not in label_to_id:
            raise ValueError(f"unknown label: {label}")
        texts.append(text)
        labels.append(label_to_id[label])
    if not texts:
        # Nothing to encode: skip the tokenizer, since batch tokenizers are
        # not guaranteed to accept an empty list.
        return PurpleBinaryDataset({"input_ids": [], "attention_mask": []}, labels)
    encoded = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors=None,
    )
    return PurpleBinaryDataset(
        {"input_ids": encoded["input_ids"], "attention_mask": encoded["attention_mask"]},
        labels,
    )
def build_stage1_optimizer(model: StableMeshSecurityClassifier, *, lr: float = 1e-4, weight_decay: float = 0.01):
    """Create an AdamW optimizer over the model's currently trainable parameters.

    Args:
        model: Model whose parameters with ``requires_grad=True`` are optimized.
        lr: Learning rate passed to AdamW.
        weight_decay: Weight-decay coefficient passed to AdamW.

    Returns:
        A ``torch.optim.AdamW`` instance with a single parameter group.

    Raises:
        ValueError: If the model has no parameter that requires gradients.
    """
    trainable = list(filter(lambda tensor: tensor.requires_grad, model.parameters()))
    if len(trainable) == 0:
        raise ValueError("no trainable parameters after stage1 freezing")
    optimizer = torch.optim.AdamW(trainable, lr=lr, weight_decay=weight_decay)
    return optimizer

Xet Storage Details

Size:
9.21 kB
·
Xet hash:
e4d4f999b18784602a5ed5b1640121cd840f4724eac071b21d24276350281048

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.