# ============================================================================
# CaptionBERT-8192: HuggingFace AutoModel with Alignment Bank
#
# Usage:
# from transformers import AutoModel, AutoTokenizer
# model = AutoModel.from_pretrained("AbstractPhil/geolip-captionbert-8192",
# trust_remote_code=True)
# tokenizer = AutoTokenizer.from_pretrained("AbstractPhil/geolip-captionbert-8192",
# trust_remote_code=True)
# inputs = tokenizer("A cat on a windowsill", return_tensors="pt",
# padding=True, truncation=True, max_length=512)
# outputs = model(**inputs)
#
# # Core embedding (consensus-distilled, L2-normalized)
# embedding = outputs.last_hidden_state # (B, 768)
#
# # Enriched embedding (with geometric context from 5-expert bank)
# enriched = outputs.enriched # (B, 768 + bank_dim)
#
# # Token-level representations (pre-pooling, for sequence tasks)
# tokens = outputs.token_embeddings # (B, L, 384)
#
# # Geometric diagnostics
# geo = outputs.geometric_context # dict with expert cos, anchors, etc.
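#
# # Cosine similarity between two captions (illustrative sketch, not part of
# # the checkpoint; embeddings are L2-normalized, so a dot product is the cosine)
# pair = tokenizer(["a cat on a windowsill", "a kitten by the window"],
#                  return_tensors="pt", padding=True)
# emb = model(**pair).last_hidden_state   # (2, 768)
# similarity = (emb[0] * emb[1]).sum().item()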
# ============================================================================
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
class CaptionBertConfig(PretrainedConfig):
model_type = "caption_bert"
def __init__(
self,
vocab_size=30522,
max_position_embeddings=8192,
hidden_size=384,
num_attention_heads=6,
num_hidden_layers=6,
intermediate_size=1536,
output_dim=768,
hidden_dropout_prob=0.0,
pad_token_id=0,
# Alignment bank
bank_enabled=True,
bank_n_experts=5,
bank_n_anchors=512,
bank_dim=128,
bank_cv_target=0.082,
**kwargs,
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.num_hidden_layers = num_hidden_layers
self.intermediate_size = intermediate_size
self.output_dim = output_dim
self.hidden_dropout_prob = hidden_dropout_prob
self.bank_enabled = bank_enabled
self.bank_n_experts = bank_n_experts
self.bank_n_anchors = bank_n_anchors
self.bank_dim = bank_dim
self.bank_cv_target = bank_cv_target
class AlignmentBank(nn.Module):
"""
Geometric interface layer preserving 5-expert differentiation structure.
Trained post-hoc on frozen encoder via GPA + whitened Procrustes.
Stores per-expert rotation matrices, whiteners, and means that encode
how each expert's geometric perspective differs from the consensus center.
Provides geometric context annotations (128-dim) alongside the core
768-dim consensus embedding for downstream heads.
"""
def __init__(self, d_embed=768, n_experts=5, n_anchors=512, d_bank=128):
super().__init__()
self.d_embed = d_embed
self.n_experts = n_experts
self.n_anchors = n_anchors
self.d_bank = d_bank
# Per-expert Procrustes components (the differentiation structure)
self.expert_rotations = nn.ParameterList([
nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)])
self.expert_whiteners = nn.ParameterList([
nn.Parameter(torch.eye(d_embed)) for _ in range(n_experts)])
self.expert_means = nn.ParameterList([
nn.Parameter(torch.zeros(d_embed)) for _ in range(n_experts)])
# Consensus landmarks on the hypersphere
self.anchors = nn.Parameter(
F.normalize(torch.randn(n_anchors, d_embed), dim=-1))
# Geometric context projection
n_cross = n_experts * (n_experts - 1) // 2
geo_dim = n_experts + n_experts + n_cross + 1 + n_experts + n_anchors
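        # With the defaults (5 experts, 512 anchors): 5 (expert cos) + 5 (expert mse)
        # + 10 (cross-expert pairs) + 1 (disagreement ratio) + 5 (norm ratios)
        # + 512 (anchor cos) = 538 geometric input features.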
self.geo_proj = nn.Sequential(
nn.Linear(geo_dim, d_bank * 2), nn.GELU(), nn.LayerNorm(d_bank * 2),
nn.Linear(d_bank * 2, d_bank), nn.LayerNorm(d_bank))
# Calibrated consensus targets (preserved from training)
self.register_buffer("target_cv", torch.tensor(0.082))
self.register_buffer("target_mean_cos", torch.tensor(0.0))
self.register_buffer("target_spectral", torch.zeros(50))
self.register_buffer("target_cross_cos_mean", torch.tensor(0.0))
self.register_buffer("target_cross_cos_std", torch.tensor(0.0))
self.register_buffer("target_disagreement_ratio", torch.tensor(0.0))
def forward(self, embedding):
B = embedding.shape[0]
emb = embedding.float()
        # Full whitened Procrustes per expert: center → whiten → normalize → rotate
expert_consistency = []
expert_recon = []
expert_projected = []
for i in range(self.n_experts):
R = self.expert_rotations[i]
W = self.expert_whiteners[i]
mu = self.expert_means[i]
centered = emb - mu
whitened = centered @ W
whitened_n = F.normalize(whitened, dim=-1)
in_expert = whitened_n @ R.T
back = in_expert @ R
cos = F.cosine_similarity(whitened_n, back, dim=-1)
recon = (whitened_n - back).pow(2).mean(dim=-1)
expert_consistency.append(cos)
expert_recon.append(recon)
expert_projected.append(in_expert)
expert_cos = torch.stack(expert_consistency, dim=-1)
expert_mse = torch.stack(expert_recon, dim=-1)
# Cross-expert differentiation (10 pairs for 5 experts)
cross_cos = []
for i in range(self.n_experts):
for j in range(i + 1, self.n_experts):
cc = F.cosine_similarity(
expert_projected[i], expert_projected[j], dim=-1)
cross_cos.append(cc)
cross_features = torch.stack(cross_cos, dim=-1)
# Per-sample disagreement
per_sample_agreement = expert_cos.mean(dim=-1)
per_sample_disagreement = expert_cos.std(dim=-1)
disagreement_ratio = per_sample_disagreement / (per_sample_agreement + 1e-8)
# Expert norm ratios
expert_norms = []
for i in range(self.n_experts):
            W = self.expert_whiteners[i]
            mu = self.expert_means[i]
whitened = (emb - mu) @ W
expert_norms.append(whitened.norm(dim=-1))
norm_ratio = torch.stack(expert_norms, dim=-1)
norm_ratio = norm_ratio / (norm_ratio.mean(dim=-1, keepdim=True) + 1e-8)
# Anchor distances
anchors_n = F.normalize(self.anchors, dim=-1)
anchor_cos = emb @ anchors_n.T
# Geometric context vector
geo_input = torch.cat([
expert_cos, expert_mse, cross_features,
disagreement_ratio.unsqueeze(-1), norm_ratio, anchor_cos
], dim=-1)
geo_context = self.geo_proj(geo_input)
enriched = torch.cat([embedding, geo_context], dim=-1)
# Diagnostics
diagnostics = {
"expert_cos_mean": expert_cos.mean().item(),
"expert_cos_std": expert_cos.std().item(),
"cross_expert_cos": cross_features.mean().item(),
"cross_expert_cos_std": cross_features.std().item(),
"anchor_max_cos": anchor_cos.max(dim=-1).values.mean().item(),
"anchor_mean_cos": anchor_cos.mean().item(),
"disagreement_ratio": disagreement_ratio.mean().item(),
"norm_ratio_spread": norm_ratio.std(dim=-1).mean().item(),
}
return enriched, geo_context, diagnostics
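
# Shape note (defaults): AlignmentBank maps a (B, 768) consensus embedding to
# `enriched` of shape (B, 768 + 128 = 896) and `geo_context` of shape (B, 128).
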
class CaptionBertModel(PreTrainedModel):
"""
Consensus-distilled caption encoder with geometric alignment bank.
The encoder produces L2-normalized 768-dim embeddings in the geometric
consensus space of 5 BERT-family models (BERT, ModernBERT, RoBERTa,
ALBERT, DistilBERT), aligned via Generalized Procrustes Analysis.
The alignment bank annotates each embedding with 128-dim geometric
    context from the 5-expert differentiation structure: per-expert
consistency, cross-expert disagreement, and anchor distances.
Output fields:
last_hidden_state: (B, 768) L2-normalized consensus embedding
pooler_output: (B, 768) same (HF compatibility)
token_embeddings: (B, L, 384) pre-pooling token representations
enriched: (B, 896) embedding + bank geometric context
geometric_context: dict expert cos, cross-expert, anchors, etc.
hidden_states: tuple per-layer outputs (if requested)
"""
config_class = CaptionBertConfig
def __init__(self, config):
super().__init__(config)
self.config = config
# ── Encoder ──
self.token_emb = nn.Embedding(
config.vocab_size, config.hidden_size,
padding_idx=config.pad_token_id)
self.pos_emb = nn.Embedding(
config.max_position_embeddings, config.hidden_size)
self.emb_norm = nn.LayerNorm(config.hidden_size)
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
encoder_layer = nn.TransformerEncoderLayer(
d_model=config.hidden_size,
nhead=config.num_attention_heads,
dim_feedforward=config.intermediate_size,
dropout=config.hidden_dropout_prob,
activation="gelu",
batch_first=True,
norm_first=True,
)
self.encoder = nn.TransformerEncoder(
encoder_layer, num_layers=config.num_hidden_layers,
enable_nested_tensor=False)
self.output_proj = nn.Sequential(
nn.Linear(config.hidden_size, config.hidden_size),
nn.GELU(),
nn.LayerNorm(config.hidden_size),
nn.Linear(config.hidden_size, config.output_dim),
)
# ── Alignment Bank ──
if getattr(config, 'bank_enabled', False):
self.bank = AlignmentBank(
d_embed=config.output_dim,
n_experts=config.bank_n_experts,
n_anchors=config.bank_n_anchors,
d_bank=config.bank_dim,
)
else:
self.bank = None
self.post_init()
def forward(self, input_ids=None, attention_mask=None,
output_hidden_states=False, **kwargs):
B, L = input_ids.shape
device = input_ids.device
# ── Encode ──
positions = torch.arange(L, device=device).unsqueeze(0)
x = self.token_emb(input_ids) + self.pos_emb(positions)
x = self.emb_drop(self.emb_norm(x))
if attention_mask is not None:
key_padding_mask = ~attention_mask.bool()
else:
key_padding_mask = (input_ids == self.config.pad_token_id)
hidden_states = [x] if output_hidden_states else None
for layer in self.encoder.layers:
x = layer(x, src_key_padding_mask=key_padding_mask)
if output_hidden_states:
hidden_states.append(x)
# ── Pool + Project ──
if attention_mask is not None:
mask = attention_mask.unsqueeze(-1).float()
else:
mask = (~key_padding_mask).unsqueeze(-1).float()
pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1)
embedding = F.normalize(self.output_proj(pooled), dim=-1)
# ── Alignment Bank ──
enriched = None
geo_diagnostics = None
if self.bank is not None:
enriched, _, geo_diagnostics = self.bank(embedding)
# ── Output ──
result = {
'last_hidden_state': embedding, # (B, 768)
'pooler_output': embedding, # (B, 768) compat
'token_embeddings': x, # (B, L, 384)
'enriched': enriched, # (B, 896) or None
'geometric_context': geo_diagnostics, # dict or None
}
if output_hidden_states:
result['hidden_states'] = tuple(hidden_states)
        # Lightweight attribute-access object: exposes the fields above as
        # outputs.last_hidden_state, outputs.enriched, etc.
        return type('Output', (), result)()
def encode(self, texts, tokenizer=None, max_length=512, batch_size=128,
device=None):
"""Convenience: raw text β†’ L2-normalized (N, 768) embeddings."""
if isinstance(texts, str):
texts = [texts]
if tokenizer is None:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
if device is None:
device = next(self.parameters()).device
self.eval()
all_emb = []
with torch.no_grad():
for i in range(0, len(texts), batch_size):
batch = texts[i:i+batch_size]
inputs = tokenizer(
batch, max_length=max_length, padding="max_length",
truncation=True, return_tensors="pt"
).to(device)
out = self(input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"])
all_emb.append(out.last_hidden_state.cpu())
return torch.cat(all_emb)
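

# ============================================================================
# Minimal smoke test (illustrative sketch, not part of the checkpoint logic):
# loads the published checkpoint with trust_remote_code and compares two
# captions via `encode`. Requires network access; guarded so that importing
# this module stays side-effect free.
# ============================================================================
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    repo = "AbstractPhil/geolip-captionbert-8192"
    model = AutoModel.from_pretrained(repo, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)

    emb = model.encode(["A cat on a windowsill", "A dog in the yard"],
                       tokenizer=tokenizer)
    # Embeddings are L2-normalized, so the dot product is the cosine similarity.
    print("cosine similarity:", float(emb[0] @ emb[1]))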