File size: 2,005 Bytes
eb5ce6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""
RabbitForCausalLM — AutoModel-compatible wrapper for Anvaya-Rabbit.

  pip install rtaforge transformers
  model = AutoModelForCausalLM.from_pretrained(
      "RtaForge/Anvaya-Rabbit-2.7B", trust_remote_code=True
  )
"""

from __future__ import annotations

import torch
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

try:
    from configuration_rabbit import RabbitConfig
except ImportError:
    from .configuration_rabbit import RabbitConfig

try:
    from white_rabbit.rabbit_model import RabbitCausalLM, RabbitModelConfig
except ImportError as _e:
    raise ImportError(
        "The rtaforge package is required to load this model.\n"
        "Install it with:  pip install rtaforge"
    ) from _e


class RabbitForCausalLM(PreTrainedModel):
    """Expose ``RabbitCausalLM`` through the ``PreTrainedModel`` interface.

    The wrapped network is stored on ``self._inner``; every method below
    delegates to it.
    """

    config_class = RabbitConfig
    supports_gradient_checkpointing = True

    def __init__(self, config: RabbitConfig):
        """Build the inner model from the sizes carried by *config*.

        Only ``vocab_size``, ``d_model`` and ``n_layers`` are read from
        *config*; ``durga_variant`` is fixed to ``"fu-64"`` here.
        """
        super().__init__(config)
        inner_cfg = RabbitModelConfig(
            vocab_size=config.vocab_size,
            d_model=config.d_model,
            n_layers=config.n_layers,
            durga_variant="fu-64",
        )
        self._inner = RabbitCausalLM(inner_cfg)

    def get_input_embeddings(self):
        """Return the inner model's ``embed_tokens`` module."""
        inner = self._inner
        return inner.embed_tokens

    def set_input_embeddings(self, value):
        """Install *value* as ``embed_tokens`` and share its weight with ``lm_head``."""
        inner = self._inner
        inner.embed_tokens = value
        # The output head is pointed at the very same weight object.
        inner.lm_head.weight = value.weight

    def get_output_embeddings(self):
        """Return the inner model's ``lm_head`` module."""
        inner = self._inner
        return inner.lm_head

    def set_output_embeddings(self, value):
        """Replace the inner model's ``lm_head`` with *value*."""
        inner = self._inner
        inner.lm_head = value

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor | None = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Run the inner model and repackage its result.

        Args:
            input_ids: Passed to the inner model unchanged.
            labels: Passed to the inner model unchanged; may be ``None``.
            **kwargs: Accepted for call compatibility but not forwarded.

        Returns:
            A ``CausalLMOutputWithPast`` carrying the inner result's
            ``"logits"`` entry and its ``"loss"`` entry (``None`` when the
            inner result has no ``"loss"`` key).
        """
        inner_result = self._inner(input_ids=input_ids, labels=labels)
        loss = inner_result.get("loss")
        logits = inner_result["logits"]
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Return the model inputs for one generation step.

        Only ``input_ids`` is kept; every other keyword argument is dropped.
        """
        model_inputs = {"input_ids": input_ids}
        return model_inputs