| | """ |
| | Collins-RoPE 极简 Embedding 模型(HuggingFace 原生实现) |
| | 架构:Hash Embedding (2-Universal + Sign Hash) -> RoPE -> Transformer Encoder -> Mean Pooling |
| | 目标参数量:~2M |
| | """ |
| |
|
| | import math |
| | from dataclasses import dataclass |
| | from typing import Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import PretrainedConfig, PreTrainedModel |
| | from transformers.modeling_outputs import BaseModelOutput |
| |
|
| |
|
class CollinsConfig(PretrainedConfig):
    """
    Configuration for the Collins-RoPE minimal embedding model.

    Architecture: hash embedding (2-universal + sign hash) -> RoPE ->
    Transformer encoder -> mean pooling. Target size: roughly 2M parameters.

    Args:
        vocab_size: Nominal tokenizer vocabulary size (tokens are hashed, so
            this does not size any weight matrix directly).
        num_buckets: Number of rows in the shared hash-embedding table.
        hidden_size: Transformer hidden dimension.
        num_hidden_layers: Number of encoder layers.
        num_attention_heads: Attention heads per layer.
        intermediate_size: Feed-forward inner dimension.
        hidden_dropout_prob: Dropout on hidden states.
        attention_probs_dropout_prob: Dropout on attention probabilities.
        max_position_embeddings: Maximum supported sequence length.
        hash_seed: Seed for deterministic hash-parameter generation.
    """

    model_type = "collins"

    def __init__(
        self,
        vocab_size: int = 30522,
        num_buckets: int = 2048,
        hidden_size: int = 256,
        num_hidden_layers: int = 3,
        num_attention_heads: int = 8,
        intermediate_size: int = 1024,
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 512,
        hash_seed: int = 42,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Record every hyper-parameter verbatim on the config instance.
        hyperparams = {
            "vocab_size": vocab_size,
            "num_buckets": num_buckets,
            "hidden_size": hidden_size,
            "num_hidden_layers": num_hidden_layers,
            "num_attention_heads": num_attention_heads,
            "intermediate_size": intermediate_size,
            "hidden_dropout_prob": hidden_dropout_prob,
            "attention_probs_dropout_prob": attention_probs_dropout_prob,
            "max_position_embeddings": max_position_embeddings,
            "hash_seed": hash_seed,
        }
        for attr, value in hyperparams.items():
            setattr(self, attr, value)
| |
|
| |
|
class CollinsHashEmbedding(nn.Module):
    """
    Compressed token embedding via 2-universal hashing plus a sign hash.

    Each token id is mapped to one of ``num_buckets`` shared embedding rows,
    multiplied by a hashed sign in {-1, +1}. Hash coefficients are drawn
    deterministically from ``config.hash_seed`` so they are identical after
    save/load.
    """

    def __init__(self, config: CollinsConfig):
        super().__init__()
        self.num_buckets = config.num_buckets
        self.hidden_size = config.hidden_size

        # Shared bucket table, scaled by 1/sqrt(hidden_size) at init.
        self.hash_table = nn.Parameter(
            torch.randn(config.num_buckets, config.hidden_size)
            / math.sqrt(config.hidden_size)
        )

        # Coefficients for the universal hash family h(x) = (a*x + b) mod p.
        mersenne = 2147483647  # 2^31 - 1, a Mersenne prime
        generator = torch.Generator().manual_seed(config.hash_seed)

        def draw(low: int) -> torch.Tensor:
            # One coefficient in [low, mersenne); 'a' terms must be nonzero.
            return torch.randint(
                low, mersenne, (1,), generator=generator, dtype=torch.long
            )

        coefficients = {"a1": draw(1), "b1": draw(0), "a2": draw(1), "b2": draw(0)}

        # Buffers (not parameters): saved with the model, never trained.
        self.register_buffer("prime", torch.tensor(mersenne, dtype=torch.long))
        for name, value in coefficients.items():
            self.register_buffer(name, value)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up signed bucket embeddings; returns (..., hidden_size)."""
        ids = input_ids.long()
        # First hash selects the bucket row.
        buckets = ((ids * self.a1 + self.b1) % self.prime) % self.num_buckets
        # Second hash yields a parity bit, remapped from {0, 1} to {-1, +1}.
        parity = ((ids * self.a2 + self.b2) % self.prime) % 2
        signs = (parity * 2 - 1).float().unsqueeze(-1)
        return self.hash_table[buckets] * signs
| |
|
| |
|
class CollinsModel(PreTrainedModel):
    """
    Collins-RoPE encoder producing ``last_hidden_state`` plus a pooled
    sentence embedding.

    Reuses ``transformers.models.bert.BertEncoder``; ``BertEmbeddings`` is
    replaced by the hash embedding followed by a rotary position encoding
    applied directly to the input embeddings.
    """

    config_class = CollinsConfig
    base_model_prefix = "collins"
    supports_gradient_checkpointing = True

    def __init__(self, config: CollinsConfig):
        super().__init__(config)
        self.config = config

        # Hashed, sign-flipped token embedding (no positional component here).
        self.embeddings = CollinsHashEmbedding(config)

        # Imported lazily so importing this module stays cheap.
        from transformers.models.bert.modeling_bert import BertEncoder, BertConfig

        bert_cfg = BertConfig(
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            hidden_dropout_prob=config.hidden_dropout_prob,
            attention_probs_dropout_prob=config.attention_probs_dropout_prob,
            max_position_embeddings=config.max_position_embeddings,
            # NOTE(review): relative_key_query injects relative-position terms
            # inside attention, on top of the RoPE applied in forward() — two
            # positional mechanisms at once. Confirm this duplication is
            # intentional rather than an oversight.
            position_embedding_type="relative_key_query",
        )
        bert_cfg._attn_implementation = "eager"
        self.encoder = BertEncoder(bert_cfg)

        # Precompute rotary tables: freqs[t, j] = t / 10000^(2j / dim).
        # Registered as (persistent) buffers, so they end up in the checkpoint
        # even though they are fully deterministic.
        dim = config.hidden_size
        inv_freq = 1.0 / (
            10000 ** (torch.arange(0, dim, 2).float() / dim)
        )
        t = torch.arange(config.max_position_embeddings).float()
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        self.register_buffer("rope_cos", freqs.cos())
        self.register_buffer("rope_sin", freqs.sin())

        self.post_init()

    def _apply_rope(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate each (even, odd) feature pair of ``x`` (batch, seq, dim)
        by a position-dependent angle.

        The rotated components are concatenated half-by-half rather than
        re-interleaved, i.e. the output feature order is a fixed permutation
        of classic RoPE; harmless here since a learned encoder consumes it.
        NOTE(review): RoPE is applied to the input embeddings rather than the
        per-layer query/key projections as in the RoFormer paper — presumably
        a deliberate simplification; confirm.
        """
        seq_len = x.shape[1]
        cos = self.rope_cos[:seq_len].unsqueeze(0)
        sin = self.rope_sin[:seq_len].unsqueeze(0)
        x1, x2 = x[..., 0::2], x[..., 1::2]
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

    def get_extended_attention_mask(self, attention_mask: torch.Tensor) -> torch.Tensor:
        # Broadcast the (batch, seq) 0/1 mask to a (batch, 1, 1, seq) additive
        # mask: 0 where attended, a large negative value where masked.
        # NOTE(review): this shadows PreTrainedModel.get_extended_attention_mask,
        # which takes additional arguments; only the call in forward() below
        # uses this override, but any HF-internal caller would break on the
        # narrower signature — consider renaming.
        extended = attention_mask[:, None, None, :]
        extended = (1.0 - extended.float()) * torch.finfo(torch.float32).min
        return extended

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """Encode ``input_ids`` (batch, seq).

        Returns:
            If ``return_dict``: a ``(BaseModelOutput, pooled)`` tuple, where
            ``pooled`` is the L2-normalized, mask-weighted mean of the final
            hidden states, shape (batch, hidden_size).
            Otherwise: a plain ``(last_hidden_state, pooled)`` tuple.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Hash embedding, then rotary position encoding on the embeddings.
        x = self.embeddings(input_ids)
        x = self._apply_rope(x)

        ext_mask = self.get_extended_attention_mask(attention_mask)
        encoder_out = self.encoder(x, attention_mask=ext_mask)
        hidden_states = encoder_out.last_hidden_state

        # Masked mean pooling; clamp guards against an all-zero mask row.
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (hidden_states * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        pooled = F.normalize(pooled, p=2, dim=-1)

        if not return_dict:
            return (hidden_states, pooled)

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=None,
            attentions=None,
        ), pooled
| |
|
| |
|
class CollinsSTWrapper(nn.Module):
    """
    sentence-transformers 5.x compatible wrapper module.

    Holds the tokenizer, implements the ``tokenize()`` interface, and injects
    ``"sentence_embedding"`` into the feature dict in ``forward()``.
    """

    def __init__(self, collins_model: CollinsModel, tokenizer_name_or_path: str = "bert-base-uncased", max_seq_length: int = 128):
        super().__init__()
        # Imported lazily so the wrapper stays importable without transformers
        # being needed at module import time.
        from transformers import AutoTokenizer
        self.collins_model = collins_model
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
        self.max_seq_length = max_seq_length

    def tokenize(self, texts: list[str], padding: str | bool = True) -> dict:
        """Tokenize ``texts``, truncating to ``max_seq_length``; returns a
        dict of PyTorch tensors (input_ids, attention_mask, ...)."""
        return self.tokenizer(
            texts,
            padding=padding,
            truncation=True,
            max_length=self.max_seq_length,
            return_tensors="pt",
        )

    def forward(self, features: dict) -> dict:
        """Run the encoder and add ``sentence_embedding`` to ``features``.

        ``features`` must contain ``input_ids``; ``attention_mask`` is
        optional (the model defaults to all-ones when absent).
        """
        input_ids = features["input_ids"]
        attention_mask = features.get("attention_mask", None)
        _, pooled = self.collins_model(input_ids, attention_mask)
        features["sentence_embedding"] = pooled
        return features

    def save(self, output_path: str):
        """Save model weights/config and tokenizer files to ``output_path``.

        NOTE: ``max_seq_length`` is not persisted anywhere on disk; pass it
        explicitly to ``load()`` if a non-default value was used.
        """
        self.collins_model.save_pretrained(output_path)
        self.tokenizer.save_pretrained(output_path)

    @staticmethod
    def load(input_path: str, max_seq_length: int = 128) -> "CollinsSTWrapper":
        """Reload a wrapper saved with ``save()``.

        Fix: previously ``max_seq_length`` was always reset to the class
        default on reload; callers can now restore the value they trained
        with (default unchanged, so existing call sites behave identically).
        """
        model = CollinsModel.from_pretrained(input_path)
        return CollinsSTWrapper(
            model,
            tokenizer_name_or_path=input_path,
            max_seq_length=max_seq_length,
        )
| |
|