"""Self-contained model class for binomial-marks-1.
Distributed alongside the weights on HuggingFace Hub so anyone can do:
from transformers import AutoTokenizer, AutoModel
tok = AutoTokenizer.from_pretrained("BinomialTechnologies/binomial-marks-1")
model = AutoModel.from_pretrained("BinomialTechnologies/binomial-marks-1",
trust_remote_code=True)
This file imports only from `transformers` + `torch` — no project-internal
dependencies.
Architecture:
ModernBERT-large encoder (with optional YaRN RoPE extension to 16k)
↓ (CLS + masked mean pool concatenated)
↓ (3 × MLP heads)
23 outputs:
10 × topic_mentioned (binary classification, sigmoid → BCE loss)
10 × topic_score (regression in [-2, +2] after clamp at inference)
3 × tone_score (regression in [1, 5] after clamp at inference)
"""
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import ModelOutput

# Relative import: HF's `trust_remote_code` loader bundles sibling .py files
# into the dynamic module it builds, so this resolves even though the package
# is never installed.
from .configuration_marks import MarksConfig, TOPICS, TONES

# ---------------------------------------------------------------------------
# YaRN RoPE extension — per-dim ramp; applied after model load
# ---------------------------------------------------------------------------
def _yarn_inv_freq(
    head_dim: int,
    base: float,
    scale: float,
    original_max_position: int,
    beta_fast: float = 32.0,
    beta_slow: float = 1.0,
    device=None,
    dtype=torch.float32,
) -> torch.Tensor:
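    """Per-dimension YaRN inverse frequencies.

    Blends interpolated (``inv_freq / scale``) and extrapolated (original)
    frequencies with a linear ramp over how many full rotations each dim
    completes within ``original_max_position``; ``beta_fast=32`` and
    ``beta_slow=1`` match common YaRN defaults. Worked example with
    illustrative values (head_dim=64, base=10000, original_max_position=8192):
    rotations run from ~1304 at dim pair 0 down to ~0.17 at dim pair 31, so
    roughly the first 13 pairs keep their original frequency, the last 7 are
    fully interpolated, and the rest blend linearly.
    """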
    inv_freq_extrap = 1.0 / (
        base ** (torch.arange(0, head_dim, 2, device=device, dtype=dtype) / head_dim)
    )
    if scale <= 1.0:
        return inv_freq_extrap
    inv_freq_interp = inv_freq_extrap / scale
    # Number of full rotations each dim completes within the original window.
    wavelengths = 2.0 * math.pi / inv_freq_extrap
    rotations = original_max_position / wavelengths
    # Ramp is 1 where a dim rotates >= beta_fast times (keep the original
    # frequency) and 0 where it rotates <= beta_slow times (fully interpolate).
    ramp = ((rotations - beta_slow) / (beta_fast - beta_slow)).clamp(0.0, 1.0)
    return inv_freq_interp * (1.0 - ramp) + inv_freq_extrap * ramp


def _apply_yarn_to_modernbert(
    encoder,
    new_max_position: int,
    original_max_position: int = 8192,
    beta_fast: float = 32.0,
    beta_slow: float = 1.0,
):
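    """Patch the encoder's global-attention RoPE buffers in place.

    Only ``ModernBertRotaryEmbedding`` modules exposing a correctly sized
    ``full_attention_inv_freq`` buffer are rewritten; everything else is
    skipped, so the call degrades to a no-op on encoders that don't match
    (the intent being that sliding-window layers keep their native theta).
    """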
    if new_max_position == original_max_position:
        return
    scale = new_max_position / original_max_position
    cfg = encoder.config
    head_dim = cfg.hidden_size // cfg.num_attention_heads
    global_base = float(
        getattr(cfg, "global_rope_theta", getattr(cfg, "rope_theta", 10000.0))
    )
    rotary_modules = [
        m for _, m in encoder.named_modules()
        if m.__class__.__name__ == "ModernBertRotaryEmbedding"
    ]
    for mod in rotary_modules:
        full_buf = getattr(mod, "full_attention_inv_freq", None)
        if full_buf is None or full_buf.numel() != head_dim // 2:
            continue
        new_inv = _yarn_inv_freq(
            head_dim=head_dim,
            base=global_base,
            scale=scale,
            original_max_position=original_max_position,
            beta_fast=beta_fast,
            beta_slow=beta_slow,
            device=full_buf.device,
            dtype=full_buf.dtype,
        )
        full_buf.data.copy_(new_inv)


# ---------------------------------------------------------------------------
# Output dataclass
# ---------------------------------------------------------------------------
@dataclass
class MarksOutput(ModelOutput):
    loss: Optional[torch.Tensor] = None
    loss_components: Optional[dict] = None
    topic_mentioned_logits: Optional[torch.Tensor] = None  # (B, 10)
    topic_score: Optional[torch.Tensor] = None             # (B, 10)
    tone_score: Optional[torch.Tensor] = None              # (B, 3)


# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class MarksMultiHead(PreTrainedModel):
    """Multi-head ModernBERT-large fine-tuned for earnings-call NLP scoring.

    23 outputs per call:
      * topic_mentioned  (binary, 10 dims)
      * topic_score      (regression in [-2, +2], 10 dims)
      * tone_score       (regression in [1, 5], 3 dims)
    """

    config_class = MarksConfig
    base_model_prefix = "encoder"
    supports_gradient_checkpointing = True
    def __init__(self, config: MarksConfig):
        super().__init__(config)
        self.n_topics = len(config.topics)
        self.n_tones = len(config.tones)

        # Encoder is built from config so the base weights aren't redownloaded;
        # the real weights come from this repo's safetensors.
        if config.encoder_config:
            # AutoConfig has no `from_dict`; `for_model` dispatches on the
            # dict's `model_type` entry, which the stored config must carry.
            enc_cfg = AutoConfig.for_model(**config.encoder_config)
        else:
            enc_cfg = AutoConfig.from_pretrained(config.encoder_name_or_path)
        # Override the encoder context length to the trained value (16384 for v1).
        enc_cfg.max_position_embeddings = config.max_position_embeddings

        # Config-only constructor → random init; the PreTrainedModel
        # .from_pretrained caller restores the real weights afterwards.
        self.encoder = AutoModel.from_config(enc_cfg)
        H = enc_cfg.hidden_size

        # Head input is CLS + mean pool concatenated → 2H.
        head_in = 2 * H
        head_hidden = H // config.head_dim_ratio

        def _mlp(out_dim: int) -> nn.Sequential:
            return nn.Sequential(
                nn.Linear(head_in, head_hidden),
                nn.GELU(),
                nn.Dropout(config.dropout),
                nn.Linear(head_hidden, out_dim),
            )

        self.dropout = nn.Dropout(config.dropout)
        self.head_topic_mentioned = _mlp(self.n_topics)
        self.head_topic_score = _mlp(self.n_topics)
        self.head_tone_score = _mlp(self.n_tones)

        # Loss weights (used only if labels are passed for fine-tuning).
        self._loss_weights = config.loss_weights

        # Apply YaRN to the encoder (no-op if max_position == native).
        if config.marks_rope_strategy == "yarn":
            _apply_yarn_to_modernbert(
                self.encoder,
                new_max_position=config.max_position_embeddings,
                original_max_position=config.original_max_position,
            )
        # NTK scaling lives inside the encoder config; nothing to do here.

        self.post_init()

    # -------------------------------------------------------------------------
    # Forward
    # -------------------------------------------------------------------------
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        topic_mentioned: Optional[torch.Tensor] = None,
        topic_score: Optional[torch.Tensor] = None,
        tone_score: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> MarksOutput:
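        """Encode, pool, score; optionally computes a weighted training loss.

        When all three label tensors are supplied (``topic_mentioned`` (B, 10),
        ``topic_score`` (B, 10), ``tone_score`` (B, 3)), ``loss`` is the
        ``config.loss_weights``-weighted sum of one BCE and two MSE terms,
        with the unweighted parts reported in ``loss_components``.
        """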
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden = out.last_hidden_state  # (B, T, H)
        # Two pooled views, concatenated: the CLS token plus a mean over
        # non-padding tokens (mask-weighted sum / token count).
        cls = last_hidden[:, 0]  # (B, H)
        m = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
        mean_pool = (last_hidden * m).sum(1) / m.sum(1).clamp(min=1.0)  # (B, H)
        pooled = self.dropout(torch.cat([cls, mean_pool], dim=-1))  # (B, 2H)

        tm_logits = self.head_topic_mentioned(pooled)
        ts_pred = self.head_topic_score(pooled)
        tn_pred = self.head_tone_score(pooled)
        loss, components = None, {}
        # Losses are computed in fp32 for stability under mixed precision.
        # All three label tensors must be present together.
        if (
            topic_mentioned is not None
            and topic_score is not None
            and tone_score is not None
        ):
            l_tm = F.binary_cross_entropy_with_logits(
                tm_logits.float(), topic_mentioned.float()
            )
            l_ts = F.mse_loss(ts_pred.float(), topic_score.float())
            l_tn = F.mse_loss(tn_pred.float(), tone_score.float())
            components = {
                "topic_mentioned": l_tm.detach(),
                "topic_score": l_ts.detach(),
                "tone_scores": l_tn.detach(),
            }
            w = self._loss_weights
            loss = (
                w["topic_mentioned"] * l_tm
                + w["topic_score"] * l_ts
                + w["tone_scores"] * l_tn
            )
        return MarksOutput(
            loss=loss,
            loss_components=components or None,
            topic_mentioned_logits=tm_logits,
            topic_score=ts_pred,
            tone_score=tn_pred,
        )
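
    # Example fine-tuning step (a sketch; `batch`, `tm`, `ts`, `tn` and
    # `optimizer` are hypothetical names, with labels shaped (B, 10),
    # (B, 10) and (B, 3) to match the heads):
    #
    #   out = model(input_ids=batch["input_ids"],
    #               attention_mask=batch["attention_mask"],
    #               topic_mentioned=tm, topic_score=ts, tone_score=tn)
    #   out.loss.backward()
    #   optimizer.step()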

    # -------------------------------------------------------------------------
    # Convenience predict
    # -------------------------------------------------------------------------
    @torch.no_grad()
    def predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        mention_threshold: float = 0.5,
    ) -> dict:
"""Run a forward pass and return clamped + masked predictions.
Returns a dict with:
topic_mentioned (B, 10) hard 0/1
topic_mentioned_prob (B, 10) sigmoid confidence
topic_score (B, 10) clamped to [-2, +2], zeroed where mentioned=0
tone_score (B, 3) clamped to [1, 5]
"""
        out = self.forward(input_ids=input_ids, attention_mask=attention_mask)
        prob = torch.sigmoid(out.topic_mentioned_logits)
        mentioned = (prob >= mention_threshold).float()
        ts_lo, ts_hi = self.config.topic_score_range
        tn_lo, tn_hi = self.config.tone_score_range
        # Zero out topic scores for topics the model considers unmentioned.
        ts = out.topic_score.clamp(ts_lo, ts_hi) * mentioned
        tn = out.tone_score.clamp(tn_lo, tn_hi)
        return {
            "topic_mentioned": mentioned,
            "topic_mentioned_prob": prob,
            "topic_score": ts,
            "tone_score": tn,
        }

    # -------------------------------------------------------------------------
    # Gradient checkpointing — delegate to encoder
    # -------------------------------------------------------------------------
    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        if hasattr(self.encoder, "gradient_checkpointing_enable"):
            self.encoder.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs=gradient_checkpointing_kwargs
            )

    def gradient_checkpointing_disable(self):
        if hasattr(self.encoder, "gradient_checkpointing_disable"):
            self.encoder.gradient_checkpointing_disable()
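

if __name__ == "__main__":
    # Minimal smoke test: a sketch, assuming network access to the Hub and
    # that the repo's auto_map routes AutoModel to this class (the repo id
    # comes from the module docstring above).
    from transformers import AutoTokenizer

    repo = "BinomialTechnologies/binomial-marks-1"
    tok = AutoTokenizer.from_pretrained(repo)
    model = AutoModel.from_pretrained(repo, trust_remote_code=True)
    model.eval()
    enc = tok(
        ["Revenue grew 12% year over year, though margins compressed."],
        return_tensors="pt",
        truncation=True,
    )
    preds = model.predict(enc["input_ids"], enc["attention_mask"])
    print({k: tuple(v.shape) for k, v in preds.items()})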