| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import math |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import PretrainedConfig, PreTrainedModel |
| |
|
| |
|
class CaptionBertConfig(PretrainedConfig):
    """Configuration for the consensus-distilled caption encoder.

    Bundles the transformer-encoder hyper-parameters together with the
    optional alignment-bank settings (the ``bank_*`` arguments).
    """

    model_type = "caption_bert"

    def __init__(
        self,
        vocab_size=30522,
        max_position_embeddings=8192,
        hidden_size=384,
        num_attention_heads=6,
        num_hidden_layers=6,
        intermediate_size=1536,
        output_dim=768,
        hidden_dropout_prob=0.0,
        pad_token_id=0,

        bank_enabled=True,
        bank_n_experts=5,
        bank_n_anchors=512,
        bank_dim=128,
        bank_cv_target=0.082,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)
        # Mirror every hyper-parameter onto the instance so HF config
        # serialization (to_dict / save_pretrained) round-trips them.
        settings = {
            "vocab_size": vocab_size,
            "max_position_embeddings": max_position_embeddings,
            "hidden_size": hidden_size,
            "num_attention_heads": num_attention_heads,
            "num_hidden_layers": num_hidden_layers,
            "intermediate_size": intermediate_size,
            "output_dim": output_dim,
            "hidden_dropout_prob": hidden_dropout_prob,
            "bank_enabled": bank_enabled,
            "bank_n_experts": bank_n_experts,
            "bank_n_anchors": bank_n_anchors,
            "bank_dim": bank_dim,
            "bank_cv_target": bank_cv_target,
        }
        for name, value in settings.items():
            setattr(self, name, value)
| |
|
| |
|
class AlignmentBank(nn.Module):
    """
    Geometric interface layer preserving 5-expert differentiation structure.

    Trained post-hoc on frozen encoder via GPA + whitened Procrustes.
    Stores per-expert rotation matrices, whiteners, and means that encode
    how each expert's geometric perspective differs from the consensus center.

    Provides geometric context annotations (128-dim) alongside the core
    768-dim consensus embedding for downstream heads.
    """

    def __init__(self, d_embed=768, n_experts=5, n_anchors=512, d_bank=128):
        super().__init__()
        self.d_embed = d_embed
        self.n_experts = n_experts
        self.n_anchors = n_anchors
        self.d_bank = d_bank

        # Per-expert alignment parameters, initialised to the identity map
        # (a no-op bank until the post-hoc fitting fills them in).
        self.expert_rotations = nn.ParameterList(
            [nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)])
        self.expert_whiteners = nn.ParameterList(
            [nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)])
        self.expert_means = nn.ParameterList(
            [nn.Parameter(torch.zeros(d_embed)) for _ in range(n_experts)])

        # Shared anchor directions on the unit sphere.
        self.anchors = nn.Parameter(
            F.normalize(torch.randn(n_anchors, d_embed), dim=-1))

        # Geometric feature layout fed to the projection head:
        #   n_experts cosines + n_experts recon errors
        #   + C(n_experts, 2) cross-expert cosines + 1 disagreement ratio
        #   + n_experts norm ratios + n_anchors anchor cosines.
        n_pairs = n_experts * (n_experts - 1) // 2
        feat_dim = n_experts + n_experts + n_pairs + 1 + n_experts + n_anchors
        self.geo_proj = nn.Sequential(
            nn.Linear(feat_dim, d_bank * 2),
            nn.GELU(),
            nn.LayerNorm(d_bank * 2),
            nn.Linear(d_bank * 2, d_bank),
            nn.LayerNorm(d_bank),
        )

        # Reference statistics; not read by forward() — presumably consumed
        # by external calibration/monitoring code (TODO confirm).
        self.register_buffer("target_cv", torch.tensor(0.082))
        self.register_buffer("target_mean_cos", torch.tensor(0.0))
        self.register_buffer("target_spectral", torch.zeros(50))
        self.register_buffer("target_cross_cos_mean", torch.tensor(0.0))
        self.register_buffer("target_cross_cos_std", torch.tensor(0.0))
        self.register_buffer("target_disagreement_ratio", torch.tensor(0.0))

    def forward(self, embedding):
        """Annotate ``embedding`` (B, d_embed) with geometric bank context.

        Returns ``(enriched, geo_context, diagnostics)``:
            enriched:    (B, d_embed + d_bank) embedding ++ context
            geo_context: (B, d_bank) projected geometric features
            diagnostics: dict of scalar batch statistics
        """
        emb = embedding.float()

        # Round-trip each sample through every expert's whiten -> rotate ->
        # un-rotate map and record how well it survives the trip.
        cos_list, mse_list, proj_list, norm_list = [], [], [], []
        for R, W, mu in zip(self.expert_rotations, self.expert_whiteners,
                            self.expert_means):
            whitened = (emb - mu) @ W
            unit = F.normalize(whitened, dim=-1)
            projected = unit @ R.T
            restored = projected @ R
            cos_list.append(F.cosine_similarity(unit, restored, dim=-1))
            mse_list.append((unit - restored).pow(2).mean(dim=-1))
            proj_list.append(projected)
            norm_list.append(whitened.norm(dim=-1))

        expert_cos = torch.stack(cos_list, dim=-1)
        expert_mse = torch.stack(mse_list, dim=-1)

        # Pairwise agreement between expert-space projections (i < j).
        pair_cos = [
            F.cosine_similarity(proj_list[i], proj_list[j], dim=-1)
            for i in range(self.n_experts)
            for j in range(i + 1, self.n_experts)
        ]
        cross_features = torch.stack(pair_cos, dim=-1)

        # Spread of per-expert consistency relative to its mean.
        agreement = expert_cos.mean(dim=-1)
        disagreement_ratio = expert_cos.std(dim=-1) / (agreement + 1e-8)

        # Relative magnitude of each expert's whitened representation.
        norm_ratio = torch.stack(norm_list, dim=-1)
        norm_ratio = norm_ratio / (norm_ratio.mean(dim=-1, keepdim=True) + 1e-8)

        # Similarity to the shared anchor set.
        anchor_cos = emb @ F.normalize(self.anchors, dim=-1).T

        geo_input = torch.cat([
            expert_cos, expert_mse, cross_features,
            disagreement_ratio.unsqueeze(-1), norm_ratio, anchor_cos,
        ], dim=-1)
        geo_context = self.geo_proj(geo_input)
        # Context is appended to the ORIGINAL (un-cast) embedding.
        enriched = torch.cat([embedding, geo_context], dim=-1)

        diagnostics = {
            "expert_cos_mean": expert_cos.mean().item(),
            "expert_cos_std": expert_cos.std().item(),
            "cross_expert_cos": cross_features.mean().item(),
            "cross_expert_cos_std": cross_features.std().item(),
            "anchor_max_cos": anchor_cos.max(dim=-1).values.mean().item(),
            "anchor_mean_cos": anchor_cos.mean().item(),
            "disagreement_ratio": disagreement_ratio.mean().item(),
            "norm_ratio_spread": norm_ratio.std(dim=-1).mean().item(),
        }

        return enriched, geo_context, diagnostics
| |
|
| |
|
class CaptionBertModel(PreTrainedModel):
    """
    Consensus-distilled caption encoder with geometric alignment bank.

    The encoder produces L2-normalized 768-dim embeddings in the geometric
    consensus space of 5 BERT-family models (BERT, ModernBERT, RoBERTa,
    ALBERT, DistilBERT), aligned via Generalized Procrustes Analysis.

    The alignment bank annotates each embedding with 128-dim geometric
    context from the 5-expert differentiation structure — per-expert
    consistency, cross-expert disagreement, and anchor distances.

    Output fields:
        last_hidden_state: (B, 768) L2-normalized consensus embedding
        pooler_output: (B, 768) same (HF compatibility)
        token_embeddings: (B, L, 384) pre-pooling token representations
        enriched: (B, 896) embedding + bank geometric context
        geometric_context: dict expert cos, cross-expert, anchors, etc.
        hidden_states: tuple per-layer outputs (if requested)
    """
    config_class = CaptionBertConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Embedding stage: token lookup + learned absolute positions,
        # followed by LayerNorm and dropout.
        self.token_emb = nn.Embedding(
            config.vocab_size, config.hidden_size,
            padding_idx=config.pad_token_id)
        self.pos_emb = nn.Embedding(
            config.max_position_embeddings, config.hidden_size)
        self.emb_norm = nn.LayerNorm(config.hidden_size)
        self.emb_drop = nn.Dropout(config.hidden_dropout_prob)

        # Pre-norm (norm_first) transformer encoder stack.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_attention_heads,
            dim_feedforward=config.intermediate_size,
            dropout=config.hidden_dropout_prob,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=config.num_hidden_layers,
            enable_nested_tensor=False)

        # Projects the pooled hidden state (hidden_size) up to the
        # consensus embedding dimension (output_dim).
        self.output_proj = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, config.output_dim),
        )

        # Optional geometric alignment bank (see AlignmentBank).
        if getattr(config, 'bank_enabled', False):
            self.bank = AlignmentBank(
                d_embed=config.output_dim,
                n_experts=config.bank_n_experts,
                n_anchors=config.bank_n_anchors,
                d_bank=config.bank_dim,
            )
        else:
            self.bank = None

        self.post_init()

    def forward(self, input_ids=None, attention_mask=None,
                output_hidden_states=False, **kwargs):
        """Encode token ids to consensus embeddings.

        NOTE(review): despite the ``input_ids=None`` default, a tensor is
        required — the first statement unpacks its shape.
        """
        B, L = input_ids.shape
        device = input_ids.device

        # Token + position embeddings, normalized then dropped out.
        positions = torch.arange(L, device=device).unsqueeze(0)
        x = self.token_emb(input_ids) + self.pos_emb(positions)
        x = self.emb_drop(self.emb_norm(x))

        # key_padding_mask: True marks positions attention must ignore.
        # Without an explicit mask, pad positions are inferred from
        # pad_token_id.
        if attention_mask is not None:
            key_padding_mask = ~attention_mask.bool()
        else:
            key_padding_mask = (input_ids == self.config.pad_token_id)

        # Run the layers manually (instead of self.encoder(x)) so that
        # intermediate states can be collected when requested.
        hidden_states = [x] if output_hidden_states else None
        for layer in self.encoder.layers:
            x = layer(x, src_key_padding_mask=key_padding_mask)
            if output_hidden_states:
                hidden_states.append(x)

        # Masked mean pooling over non-padding tokens (clamp guards
        # against an all-padding row), then project and L2-normalize.
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).float()
        else:
            mask = (~key_padding_mask).unsqueeze(-1).float()
        pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1)
        embedding = F.normalize(self.output_proj(pooled), dim=-1)

        # Optional bank annotation; fields stay None when bank is disabled.
        enriched = None
        geo_diagnostics = None
        if self.bank is not None:
            enriched, _, geo_diagnostics = self.bank(embedding)

        result = {
            'last_hidden_state': embedding,
            'pooler_output': embedding,
            'token_embeddings': x,
            'enriched': enriched,
            'geometric_context': geo_diagnostics,
        }
        if output_hidden_states:
            result['hidden_states'] = tuple(hidden_states)

        # Ad-hoc namespace: a throwaway class whose class attributes are the
        # result entries, so callers use attribute access (out.pooler_output).
        return type('Output', (), result)()

    def encode(self, texts, tokenizer=None, max_length=512, batch_size=128,
               device=None):
        """Convenience: raw text -> L2-normalized (N, 768) embeddings."""
        if isinstance(texts, str):
            texts = [texts]
        if tokenizer is None:
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        if device is None:
            device = next(self.parameters()).device
        self.eval()  # NOTE: leaves the model in eval mode after returning
        all_emb = []
        with torch.no_grad():
            for i in range(0, len(texts), batch_size):
                batch = texts[i:i+batch_size]
                # padding="max_length" pads every batch to max_length tokens.
                inputs = tokenizer(
                    batch, max_length=max_length, padding="max_length",
                    truncation=True, return_tensors="pt"
                ).to(device)
                out = self(input_ids=inputs["input_ids"],
                           attention_mask=inputs["attention_mask"])
                all_emb.append(out.last_hidden_state.cpu())
        return torch.cat(all_emb)