from transformers import PretrainedConfig

class LaTeXDecoderConfig(PretrainedConfig):
    """Configuration for the LaTeX decoder model."""

    model_type = "latex_decoder"

    def __init__(
        self,
        vocab_size: int = 8192,
        pad_id: int = 0,
        bos_id: int = 2,
        eos_id: int = 3,
        d_model: int = 512,
        n_heads: int = 8,
        n_layers: int = 6,
        d_ff: int = 1408,
        dropout: float = 0.1,
        max_seq_len: int = 200,
        rope_theta: float = 10000.0,
        tie_weights: bool = True,
        **kwargs,
    ):
        # Drop any *_token_id overrides so the explicit pad_id/bos_id/eos_id
        # arguments below take precedence.
        kwargs.pop("pad_token_id", None)
        kwargs.pop("bos_token_id", None)
        kwargs.pop("eos_token_id", None)
        super().__init__(
            pad_token_id=pad_id,
            bos_token_id=bos_id,
            eos_token_id=eos_id,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.pad_id = pad_id
        self.bos_id = bos_id
        self.eos_id = eos_id
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.d_ff = d_ff
        self.dropout = dropout
        self.max_seq_len = max_seq_len
        self.rope_theta = rope_theta
        self.tie_weights = tie_weights

    @property
    def head_dim(self) -> int:
        # Per-head dimension; d_model must divide evenly across attention heads.
        assert self.d_model % self.n_heads == 0
        return self.d_model // self.n_heads
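

# A minimal usage sketch, not part of the original file: it instantiates the
# config and round-trips it through save_pretrained/from_pretrained, which
# PretrainedConfig provides. The directory name and overridden sizes are
# illustrative assumptions, not values taken from the source.
if __name__ == "__main__":
    config = LaTeXDecoderConfig(d_model=256, n_heads=4, n_layers=4)
    assert config.head_dim == 64  # 256 // 4
    config.save_pretrained("latex_decoder_config")  # writes config.json
    reloaded = LaTeXDecoderConfig.from_pretrained("latex_decoder_config")
    assert reloaded.d_model == 256 and reloaded.eos_token_id == 3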