tvastr committed on
Commit
eb5ce6f
·
verified ·
1 Parent(s): 984ce24

chore: add modeling_rabbit.py (safety-scrubbed AutoModel wrapper)

Browse files
Files changed (1) hide show
  1. modeling_rabbit.py +68 -0
modeling_rabbit.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RabbitForCausalLM — AutoModel-compatible wrapper for Anvaya-Rabbit.
3
+
4
+ pip install rtaforge transformers
5
+ model = AutoModelForCausalLM.from_pretrained(
6
+ "RtaForge/Anvaya-Rabbit-2.7B", trust_remote_code=True
7
+ )
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+ from transformers import PreTrainedModel
14
+ from transformers.modeling_outputs import CausalLMOutputWithPast
15
+
16
+ try:
17
+ from configuration_rabbit import RabbitConfig
18
+ except ImportError:
19
+ from .configuration_rabbit import RabbitConfig
20
+
21
+ try:
22
+ from white_rabbit.rabbit_model import RabbitCausalLM, RabbitModelConfig
23
+ except ImportError as _e:
24
+ raise ImportError(
25
+ "The rtaforge package is required to load this model.\n"
26
+ "Install it with: pip install rtaforge"
27
+ ) from _e
28
+
29
+
30
class RabbitForCausalLM(PreTrainedModel):
    """Expose a ``RabbitCausalLM`` through the ``PreTrainedModel`` interface.

    The wrapped network is stored in ``self._inner``. Every method below
    delegates either to that object directly or to its ``embed_tokens`` /
    ``lm_head`` attributes.
    """

    config_class = RabbitConfig
    supports_gradient_checkpointing = True

    def __init__(self, config: RabbitConfig):
        """Build the inner model from the size fields of ``config``."""
        super().__init__(config)
        # Only the three size fields are read from ``config``; the variant
        # string is fixed here rather than taken from the config.
        inner_config = RabbitModelConfig(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            n_layers=config.n_layers,
            durga_variant="fu-64",
        )
        self._inner = RabbitCausalLM(inner_config)

    def get_input_embeddings(self):
        """Return the inner model's ``embed_tokens`` attribute."""
        inner = self._inner
        return inner.embed_tokens

    def set_input_embeddings(self, value):
        """Install ``value`` as the inner model's ``embed_tokens``.

        The output head's ``weight`` is then rebound to ``value.weight``, so
        both attributes refer to the same weight object afterwards.
        """
        inner = self._inner
        inner.embed_tokens = value
        inner.lm_head.weight = value.weight

    def get_output_embeddings(self):
        """Return the inner model's ``lm_head`` attribute."""
        inner = self._inner
        return inner.lm_head

    def set_output_embeddings(self, value):
        """Replace the inner model's ``lm_head`` with ``value``."""
        inner = self._inner
        inner.lm_head = value

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor | None = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the inner model and repackage its output.

        Args:
            input_ids: Token ids, passed to the inner model unchanged.
            labels: Optional labels, passed to the inner model unchanged.
            **kwargs: Accepted for call compatibility but NOT forwarded to
                the inner model.

        Returns:
            A ``CausalLMOutputWithPast`` whose ``logits`` is the inner
            output's ``"logits"`` entry and whose ``loss`` is its ``"loss"``
            entry, or ``None`` when that key is absent.
        """
        result = self._inner(input_ids=input_ids, labels=labels)
        loss = result.get("loss")
        logits = result["logits"]
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Return the model inputs for one generation step.

        Only ``input_ids`` is kept; every other keyword argument is discarded.
        """
        return dict(input_ids=input_ids)