""" Collins-RoPE 极简 Embedding 模型(HuggingFace 原生实现) 架构:Hash Embedding (2-Universal + Sign Hash) -> RoPE -> Transformer Encoder -> Mean Pooling 目标参数量:~2M """ import math from dataclasses import dataclass from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from transformers import PretrainedConfig, PreTrainedModel from transformers.modeling_outputs import BaseModelOutput class CollinsConfig(PretrainedConfig): model_type = "collins" def __init__( self, vocab_size: int = 30522, num_buckets: int = 2048, hidden_size: int = 256, num_hidden_layers: int = 3, num_attention_heads: int = 8, intermediate_size: int = 1024, hidden_dropout_prob: float = 0.1, attention_probs_dropout_prob: float = 0.1, max_position_embeddings: int = 512, # 2-Universal Hash 固定种子(保证 load 后哈希一致) hash_seed: int = 42, **kwargs, ): super().__init__(**kwargs) self.vocab_size = vocab_size self.num_buckets = num_buckets self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.intermediate_size = intermediate_size self.hidden_dropout_prob = hidden_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob self.max_position_embeddings = max_position_embeddings self.hash_seed = hash_seed class CollinsHashEmbedding(nn.Module): """ 2-Universal Hash + Sign Hash 压缩 Embedding。 哈希参数从 config.hash_seed 确定性生成,保证 save/load 后一致。 """ def __init__(self, config: CollinsConfig): super().__init__() self.num_buckets = config.num_buckets self.hidden_size = config.hidden_size self.hash_table = nn.Parameter( torch.randn(config.num_buckets, config.hidden_size) / math.sqrt(config.hidden_size) ) prime = 2147483647 # 梅森素数 2^31 - 1 rng = torch.Generator() rng.manual_seed(config.hash_seed) a1 = torch.randint(1, prime, (1,), generator=rng, dtype=torch.long) b1 = torch.randint(0, prime, (1,), generator=rng, dtype=torch.long) a2 = torch.randint(1, prime, (1,), generator=rng, dtype=torch.long) b2 = torch.randint(0, prime, (1,), generator=rng, dtype=torch.long) self.register_buffer("prime", torch.tensor(prime, dtype=torch.long)) self.register_buffer("a1", a1) self.register_buffer("b1", b1) self.register_buffer("a2", a2) self.register_buffer("b2", b2) def forward(self, input_ids: torch.Tensor) -> torch.Tensor: x = input_ids.long() bucket_idx = ((x * self.a1 + self.b1) % self.prime) % self.num_buckets sign = ((x * self.a2 + self.b2) % self.prime) % 2 sign = (sign * 2 - 1).float() return self.hash_table[bucket_idx] * sign.unsqueeze(-1) class CollinsModel(PreTrainedModel): """ Collins-RoPE Encoder,输出 last_hidden_state 和 pooler_output。 使用 transformers.models.bert 的 BertEncoder + RoPE 替换 BertEmbeddings。 """ config_class = CollinsConfig base_model_prefix = "collins" supports_gradient_checkpointing = True def __init__(self, config: CollinsConfig): super().__init__(config) self.config = config self.embeddings = CollinsHashEmbedding(config) # 直接复用 HF BertEncoder(含 Multi-Head Attention + FFN + LayerNorm) from transformers.models.bert.modeling_bert import BertEncoder, BertConfig bert_cfg = BertConfig( hidden_size=config.hidden_size, num_hidden_layers=config.num_hidden_layers, num_attention_heads=config.num_attention_heads, intermediate_size=config.intermediate_size, hidden_dropout_prob=config.hidden_dropout_prob, attention_probs_dropout_prob=config.attention_probs_dropout_prob, max_position_embeddings=config.max_position_embeddings, # 关闭 Bert 自带的位置编码,我们用 RoPE position_embedding_type="relative_key_query", ) bert_cfg._attn_implementation = "eager" 
        self.encoder = BertEncoder(bert_cfg)

        # RoPE frequency buffers (no learnable parameters)
        dim = config.hidden_size
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(config.max_position_embeddings).float()
        freqs = torch.einsum("i,j->ij", t, inv_freq)  # [max_pos, dim/2]
        self.register_buffer("rope_cos", freqs.cos())
        self.register_buffer("rope_sin", freqs.sin())

        self.post_init()

    def _apply_rope(self, x: torch.Tensor) -> torch.Tensor:
        # Rotate each (even, odd) feature pair by a position-dependent angle;
        # the rotated components are laid out as two halves of the last dimension.
        seq_len = x.shape[1]
        cos = self.rope_cos[:seq_len].unsqueeze(0)
        sin = self.rope_sin[:seq_len].unsqueeze(0)
        x1, x2 = x[..., 0::2], x[..., 1::2]
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

    def get_extended_attention_mask(self, attention_mask: torch.Tensor) -> torch.Tensor:
        # BertEncoder expects an additive mask of shape [B, 1, 1, L]:
        # 0 for positions to keep, a large negative value for positions to ignore.
        extended = attention_mask[:, None, None, :]
        extended = (1.0 - extended.float()) * torch.finfo(torch.float32).min
        return extended

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        x = self.embeddings(input_ids)  # [B, L, D]
        x = self._apply_rope(x)         # [B, L, D]

        ext_mask = self.get_extended_attention_mask(attention_mask)
        encoder_out = self.encoder(x, attention_mask=ext_mask)
        hidden_states = encoder_out.last_hidden_state  # [B, L, D]

        # Mean pooling over non-padded positions, then L2 normalization
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (hidden_states * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        pooled = F.normalize(pooled, p=2, dim=-1)

        if not return_dict:
            return (hidden_states, pooled)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=None,
            attentions=None,
        ), pooled


class CollinsSTWrapper(nn.Module):
    """
    Compatibility wrapper for sentence-transformers 5.x.

    Holds the tokenizer, implements the tokenize() interface, and injects
    sentence_embedding into the feature dict.
    """

    def __init__(
        self,
        collins_model: CollinsModel,
        tokenizer_name_or_path: str = "bert-base-uncased",
        max_seq_length: int = 128,
    ):
        super().__init__()
        from transformers import AutoTokenizer

        self.collins_model = collins_model
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
        self.max_seq_length = max_seq_length

    def tokenize(self, texts: list[str], padding: str | bool = True) -> dict:
        return self.tokenizer(
            texts,
            padding=padding,
            truncation=True,
            max_length=self.max_seq_length,
            return_tensors="pt",
        )

    def forward(self, features: dict) -> dict:
        input_ids = features["input_ids"]
        attention_mask = features.get("attention_mask", None)
        _, pooled = self.collins_model(input_ids, attention_mask)
        features["sentence_embedding"] = pooled
        return features

    def save(self, output_path: str):
        self.collins_model.save_pretrained(output_path)
        self.tokenizer.save_pretrained(output_path)

    @staticmethod
    def load(input_path: str) -> "CollinsSTWrapper":
        model = CollinsModel.from_pretrained(input_path)
        return CollinsSTWrapper(model, tokenizer_name_or_path=input_path)
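

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model definition).
# It assumes the "bert-base-uncased" tokenizer is available; the example texts
# and the "./collins-rope" output path are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = CollinsConfig()
    model = CollinsModel(config)
    wrapper = CollinsSTWrapper(model, tokenizer_name_or_path="bert-base-uncased", max_seq_length=128)

    texts = [
        "Hash embeddings map a large vocabulary onto a small bucket table.",
        "RoPE encodes positions by rotating feature pairs.",
    ]
    features = wrapper.tokenize(texts)
    with torch.no_grad():
        out = wrapper(features)
    emb = out["sentence_embedding"]  # [2, hidden_size], L2-normalized by the model

    print("embedding shape:", tuple(emb.shape))
    # Embeddings are unit-norm, so the dot product is the cosine similarity
    print("cosine similarity:", float(emb[0] @ emb[1]))

    # Round-trip save/load through the HF-style helpers on the wrapper
    wrapper.save("./collins-rope")
    reloaded = CollinsSTWrapper.load("./collins-rope")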