| import json |
| import logging |
| from pathlib import Path |
|
|
| import torch |
| from transformers import AutoModel, PreTrainedModel |
| from transformers import ModernBertConfig |
|
|
# Silence the transformers loggers that are noisy during model construction
# by raising their threshold to ERROR.
_NOISY_LOGGERS = (
    "transformers.modeling_utils",
    "transformers.configuration_utils",
)
for _noisy in _NOISY_LOGGERS:
    logging.getLogger(_noisy).setLevel(logging.ERROR)
|
|
| from .configuration_hare import HareConfig |
| from .birwkv7 import BiRWKV7Layer, init_from_attention |
|
|
|
|
| def _find_encoder(model): |
| for attr in ['encoder', 'model']: |
| if hasattr(model, attr): |
| candidate = getattr(model, attr) |
| if hasattr(candidate, 'layers'): |
| return candidate |
| if hasattr(model, 'layers'): |
| return model |
| raise RuntimeError(f"Cannot find encoder layers in {type(model).__name__}") |
|
|
|
|
def _perform_surgery(model, replaced_layers, hidden_size, num_heads):
    """Swap selected attention modules of *model* for ``BiRWKV7Layer``s.

    Args:
        model: the transformer whose encoder layers are modified in place.
        replaced_layers: mapping of layer index (as a string key) to
            per-layer metadata; only the keys are used here.
        hidden_size: hidden dimension for each replacement layer.
        num_heads: head count for each replacement layer.

    Layers whose attention attribute cannot be located are silently skipped.
    """
    encoder = _find_encoder(model)
    attr_candidates = ('attn', 'attention', 'self_attn', 'self_attention')
    for idx_str in replaced_layers:
        target = encoder.layers[int(idx_str)]
        # Probe the common names under which the attention module lives.
        attr_name = next(
            (name for name in attr_candidates if hasattr(target, name)), None)
        if attr_name is None:
            continue
        old_attn = getattr(target, attr_name)
        replacement = BiRWKV7Layer(hidden_size, num_heads)
        # Match the device/dtype of the module being replaced.
        ref_param = next(old_attn.parameters())
        replacement = replacement.to(device=ref_param.device, dtype=ref_param.dtype)
        setattr(target, attr_name, replacement)
|
|
|
|
class HareModel(PreTrainedModel):
    """ModernBERT-based encoder whose attention layers may be replaced by
    BiRWKV7 layers ("surgery"), as listed in ``config.replaced_layers``.

    Wraps an ``AutoModel`` built from a ``ModernBertConfig`` that is derived
    field-by-field from the ``HareConfig``.
    """

    config_class = HareConfig


    def __init__(self, config):
        # config: a HareConfig carrying both the base ModernBERT
        # hyperparameters and the surgery metadata (replaced_layers).
        super().__init__(config)
        # Rebuild a ModernBertConfig from the Hare config so AutoModel can
        # instantiate the stock architecture before surgery is applied.
        base_config = ModernBertConfig(
            hidden_size=config.hidden_size,
            num_attention_heads=config.num_attention_heads,
            num_hidden_layers=config.num_hidden_layers,
            intermediate_size=config.intermediate_size,
            vocab_size=config.vocab_size,
            max_position_embeddings=config.max_position_embeddings,
            pad_token_id=config.pad_token_id,
            bos_token_id=config.bos_token_id,
            eos_token_id=config.eos_token_id,
            # Fall back to bos/eos ids when the Hare config lacks cls/sep ids.
            cls_token_id=getattr(config, 'cls_token_id', config.bos_token_id),
            sep_token_id=getattr(config, 'sep_token_id', config.eos_token_id),
            global_attn_every_n_layers=getattr(config, 'global_attn_every_n_layers', 3),
            local_attention=getattr(config, 'local_attention', 128),
        )
        self.inner_model = AutoModel.from_config(base_config)


        # Replace the listed attention modules with freshly constructed
        # BiRWKV7 layers; trained weights for them are loaded later in
        # from_pretrained (if a model.pt is found).
        if config.replaced_layers:
            _perform_surgery(
                self.inner_model,
                config.replaced_layers,
                config.hidden_size,
                config.num_attention_heads,
            )


    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Delegate straight to the inner (post-surgery) model and return
        its output object unchanged."""
        outputs = self.inner_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
        return outputs


    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load a Hare checkpoint from a local directory or the HF Hub.

        Resolution order:
          1. Look for ``surgery_meta.json`` locally, then on the Hub.
          2. If it cannot be found at all, fall back to the stock
             ``PreTrainedModel.from_pretrained`` (plain checkpoint).
          3. Otherwise rebuild the config with the surgery metadata,
             construct the model (which performs the layer surgery), and
             load weights from ``model.pt`` (local dir, then Hub).

        NOTE(review): on the surgery path, ``*args``/``**kwargs`` are NOT
        forwarded to config/model construction (e.g. a ``torch_dtype``
        argument is silently ignored) — confirm this is intended.
        """
        model_dir = Path(pretrained_model_name_or_path)
        surgery_meta_path = model_dir / "surgery_meta.json"


        if not surgery_meta_path.exists():
            from huggingface_hub import hf_hub_download
            try:
                surgery_meta_path = Path(hf_hub_download(
                    pretrained_model_name_or_path, "surgery_meta.json"))
                # The download lands in the HF cache; look for model.pt
                # alongside it there.
                model_dir = surgery_meta_path.parent
            except Exception:
                # No surgery metadata anywhere: treat this as a plain
                # (un-operated) checkpoint.
                return super().from_pretrained(
                    pretrained_model_name_or_path, *args, **kwargs)


        with open(surgery_meta_path) as f:
            meta = json.load(f)


        config = cls.config_class.from_pretrained(pretrained_model_name_or_path)
        config.replaced_layers = meta.get("replaced_layers")
        config.surgery_variant = meta.get("variant", "conservative")


        # Constructing the model applies the surgery (BiRWKV7 layers start
        # with fresh weights until model.pt is loaded below).
        model = cls(config)


        weights_path = model_dir / "model.pt"
        if not weights_path.exists():
            from huggingface_hub import hf_hub_download
            try:
                weights_path = Path(hf_hub_download(
                    pretrained_model_name_or_path, "model.pt"))
            except Exception:
                # Best-effort: if no model.pt exists anywhere, fall through
                # and keep the freshly initialized weights.
                pass


        if weights_path.exists():
            # weights_only=True restricts torch.load to tensor payloads
            # (no arbitrary pickled code).
            state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
            model.inner_model.load_state_dict(state_dict)


        # Always returned in float32 eval mode.
        return model.float().eval()
|
|