harryrobert committed on
Commit
91bda10
·
verified ·
1 Parent(s): 39fce7f

pretrain checkpoint step 56000 — loss 1.1006

Browse files
config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "latex_decoder",
3
+ "architectures": [
4
+ "LaTeXDecoderForCausalLM"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_latex_decoder.LaTeXDecoderConfig",
8
+ "AutoModelForCausalLM": "modeling_latex_decoder.LaTeXDecoderForCausalLM"
9
+ },
10
+ "vocab_size": 2046,
11
+ "pad_id": 0,
12
+ "bos_id": 2,
13
+ "eos_id": 3,
14
+ "pad_token_id": 0,
15
+ "bos_token_id": 2,
16
+ "eos_token_id": 3,
17
+ "d_model": 512,
18
+ "n_heads": 8,
19
+ "n_layers": 6,
20
+ "d_ff": 1408,
21
+ "dropout": 0.1,
22
+ "max_seq_len": 200,
23
+ "rope_theta": 10000.0,
24
+ "tie_weights": true,
25
+ "pretrain_step": 56000,
26
+ "pretrain_loss": 1.100601
27
+ }
configuration_latex_decoder.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
class LaTeXDecoderConfig(PretrainedConfig):
    """Configuration for the LaTeX decoder-only causal language model.

    Holds the architecture hyper-parameters (width, depth, heads, FFN size,
    RoPE base, context length) plus the token-id bookkeeping. The short
    ``pad_id``/``bos_id``/``eos_id`` names are the source of truth; the
    HF-conventional ``*_token_id`` attributes are always derived from them.
    """

    model_type = "latex_decoder"

    def __init__(
        self,
        vocab_size: int = 8192,
        pad_id: int = 0,
        bos_id: int = 2,
        eos_id: int = 3,
        d_model: int = 512,
        n_heads: int = 8,
        n_layers: int = 6,
        d_ff: int = 1408,
        dropout: float = 0.1,
        max_seq_len: int = 200,
        rope_theta: float = 10000.0,
        tie_weights: bool = True,
        **kwargs,
    ):
        # Drop any incoming *_token_id duplicates (present e.g. when a saved
        # config.json is reloaded) so they can never disagree with the short
        # *_id names, which are authoritative.
        kwargs.pop("pad_token_id", None)
        kwargs.pop("bos_token_id", None)
        kwargs.pop("eos_token_id", None)
        super().__init__(
            pad_token_id=pad_id,
            bos_token_id=bos_id,
            eos_token_id=eos_id,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.pad_id = pad_id
        self.bos_id = bos_id
        self.eos_id = eos_id
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.d_ff = d_ff
        self.dropout = dropout
        self.max_seq_len = max_seq_len
        self.rope_theta = rope_theta
        self.tie_weights = tie_weights

    @property
    def head_dim(self) -> int:
        """Per-head width of the attention projections.

        Raises:
            ValueError: if ``d_model`` is not divisible by ``n_heads``.
            (Previously an ``assert``, which is silently stripped under
            ``python -O``; an explicit exception always fires.)
        """
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by "
                f"n_heads ({self.n_heads})"
            )
        return self.d_model // self.n_heads
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a2cca8bc685f1908c1bb8d004a79e95ce4874d045671aaa7565bbab6892144f0
3
+ size 81291512
modeling_latex_decoder.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # update v2
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from typing import Optional
7
+
8
+ from transformers import PreTrainedModel
9
+ from transformers.modeling_outputs import CausalLMOutput
10
+
11
+ from .configuration_latex_decoder import LaTeXDecoderConfig
12
+
13
+
14
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learned per-channel gain.

    Normalizes by the RMS of the last dimension (no mean subtraction, no
    bias), then scales by ``weight``.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # eps inside the sqrt keeps the denominator strictly positive.
        scale = (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
        return self.weight * (x / scale)
23
+
24
+
25
+ def _build_rope_cache(seq_len, head_dim, theta=10000.0, device=None, dtype=torch.float32):
26
+ half = head_dim // 2
27
+ inv_freq = 1.0 / (theta ** (torch.arange(0, half, device=device, dtype=torch.float32) / half))
28
+ pos = torch.arange(seq_len, device=device, dtype=torch.float32)
29
+ freqs = torch.outer(pos, inv_freq)
30
+ emb = torch.cat([freqs, freqs], dim=-1)
31
+ return emb.cos().to(dtype), emb.sin().to(dtype)
32
+
33
+
34
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
35
+ half = x.shape[-1] // 2
36
+ x1, x2 = x[..., :half], x[..., half:]
37
+ return torch.cat([-x2, x1], dim=-1)
38
+
39
+
40
def apply_rope(q, k, cos, sin):
    """Apply rotary position embedding to query and key tensors.

    ``q``/``k`` are ``(batch, heads, seq, head_dim)``; ``cos``/``sin`` are
    ``(seq, head_dim)`` tables broadcast over the first two dimensions.
    """
    cos = cos[None, None, :, :]
    sin = sin[None, None, :, :]
    rotated_q = q * cos + _rotate_half(q) * sin
    rotated_k = k * cos + _rotate_half(k) * sin
    return rotated_q, rotated_k
44
+
45
+
46
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings.

    Uses a fused QKV projection and ``F.scaled_dot_product_attention``; when a
    padding mask is supplied an explicit additive bias combining causal and
    padding masking is built, otherwise the fast ``is_causal=True`` path runs.
    """

    def __init__(self, cfg: LaTeXDecoderConfig):
        super().__init__()
        self.n_heads = cfg.n_heads
        self.head_dim = cfg.head_dim
        self.d_model = cfg.d_model
        self.dropout_p = cfg.dropout
        self.rope_theta = cfg.rope_theta

        # Single projection for Q, K, V; no biases.
        self.qkv_proj = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.out_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        # cos/sin tables cached per (seq_len, device, dtype) so they are
        # rebuilt only when the context shape changes.
        self._rope_cache: dict = {}

    def _get_rope(self, seq_len, device, dtype):
        """Return (and lazily build) the RoPE tables for this shape."""
        key = (seq_len, str(device), dtype)
        if key not in self._rope_cache:
            self._rope_cache[key] = _build_rope_cache(seq_len, self.head_dim, self.rope_theta, device, dtype)
        return self._rope_cache[key]

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend over ``x`` of shape (B, T, d_model); returns the same shape.

        ``attention_mask`` marks real tokens with nonzero/True and padding
        with 0/False, per the usual HF convention.
        """
        B, T, C = x.shape
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)

        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        cos, sin = self._get_rope(T, x.device, q.dtype)
        q, k = apply_rope(q, k, cos, sin)

        dropout_p = self.dropout_p if self.training else 0.0

        if attention_mask is not None:
            causal = torch.triu(torch.full((T, T), float("-inf"), device=x.device, dtype=q.dtype), diagonal=1)
            # BUG FIX: the original used `~attention_mask`, which is bitwise
            # NOT. For the integer 0/1 masks transformers conventionally
            # passes, ~1 == -2 and ~0 == -1 (and masked_fill requires a bool
            # mask anyway). `eq(0)` is correct for both bool and int masks.
            pad = attention_mask.eq(0).unsqueeze(1).unsqueeze(2)
            attn_bias = causal.unsqueeze(0).unsqueeze(0).expand(B, 1, T, T).clone()
            attn_bias = attn_bias.masked_fill(pad, float("-inf"))
            # NOTE(review): a row that is entirely padding yields an all -inf
            # bias row and NaN attention output — assumed unreachable because
            # every sequence contains at least one real token; confirm.
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias, dropout_p=dropout_p, is_causal=False)
        else:
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, is_causal=True)

        return self.out_proj(out.transpose(1, 2).contiguous().view(B, T, C))
88
+
89
+
90
class SwiGLUFFN(nn.Module):
    """Position-wise feed-forward block with a SwiGLU gate.

    ``silu(gate_proj(x))`` acts as a learned gate on the parallel
    ``up_proj(x)`` branch before projecting back down to ``d_model``.
    """

    def __init__(self, cfg: LaTeXDecoderConfig):
        super().__init__()
        d_model, d_ff = cfg.d_model, cfg.d_ff
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(gated))
100
+
101
+
102
class TransformerBlock(nn.Module):
    """Pre-norm decoder block: attention then feed-forward, each residual."""

    def __init__(self, cfg: LaTeXDecoderConfig):
        super().__init__()
        self.norm1 = RMSNorm(cfg.d_model)
        self.attn = CausalSelfAttention(cfg)
        self.norm2 = RMSNorm(cfg.d_model)
        self.ffn = SwiGLUFFN(cfg)
        self.drop = nn.Dropout(cfg.dropout)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Norm is applied *inside* each residual branch (pre-norm layout).
        attn_out = self.attn(self.norm1(x), attention_mask)
        x = x + self.drop(attn_out)
        ffn_out = self.ffn(self.norm2(x))
        return x + self.drop(ffn_out)
115
+
116
+
117
class LaTeXDecoderForCausalLM(PreTrainedModel):
    """Decoder-only causal LM over a LaTeX token vocabulary.

    Stack of pre-norm RoPE-attention / SwiGLU blocks with RMSNorm, an
    optional tied LM head, a cross-entropy training loss that ignores pad
    positions, and a simple greedy / nucleus-sampling ``generate``.
    """

    config_class = LaTeXDecoderConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False

    def __init__(self, config: LaTeXDecoderConfig):
        super().__init__(config)

        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_id)
        self.embed_drop = nn.Dropout(config.dropout)
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.norm_final = RMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        if config.tie_weights:
            # Share the embedding matrix with the output projection.
            self.lm_head.weight = self.embed_tokens.weight

        self.post_init()

    def _init_weights(self, module: nn.Module):
        """GPT-2-style init: N(0, 0.02) weights, zero biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> CausalLMOutput:
        """Run the decoder stack; compute shifted-LM loss when labels given.

        Args:
            input_ids: (B, T) token ids.
            attention_mask: optional (B, T) mask, nonzero for real tokens.
            labels: optional (B, T) targets; pad positions are ignored.

        Returns:
            ``CausalLMOutput`` with (B, T, vocab_size) logits and, when
            ``labels`` is provided, the mean cross-entropy loss.
        """
        x = self.embed_drop(self.embed_tokens(input_ids))
        for layer in self.layers:
            x = layer(x, attention_mask)
        logits = self.lm_head(self.norm_final(x))

        loss = None
        if labels is not None:
            # Standard causal shift: position t predicts token t+1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            # Map pad targets to the ignore_index so they carry no gradient.
            shift_labels = shift_labels.masked_fill(shift_labels == self.config.pad_id, -100)
            loss = F.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        return CausalLMOutput(loss=loss, logits=logits)

    @torch.inference_mode()
    def generate(
        self,
        prompt_ids: torch.Tensor,
        max_new_tokens: int = 200,
        temperature: float = 1.0,
        top_p: float = 0.9,
        eos_id: Optional[int] = None,
    ) -> torch.Tensor:
        """Autoregressively extend ``prompt_ids`` (batch size 1 only).

        ``temperature == 0.0`` means greedy decoding; otherwise nucleus
        (top-p) sampling. Generation stops at ``eos`` or after
        ``max_new_tokens`` tokens; the context is truncated to the last
        ``max_seq_len`` positions each step.
        """
        # BUG FIX: the stop check below uses ``next_id.item()``, which only
        # works for a single sequence. Previously a larger batch crashed with
        # an opaque RuntimeError mid-generation; fail fast with a clear error.
        if prompt_ids.size(0) != 1:
            raise ValueError("generate() supports batch size 1 only")

        eos = eos_id if eos_id is not None else self.config.eos_id
        generated = prompt_ids.clone()

        for _ in range(max_new_tokens):
            ctx = generated[:, -self.config.max_seq_len:]
            logits = self.forward(ctx).logits[:, -1, :]

            if temperature == 0.0:
                # Greedy decoding.
                next_id = logits.argmax(dim=-1, keepdim=True)
            else:
                probs = F.softmax(logits / temperature, dim=-1)
                sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
                cumsum = sorted_probs.cumsum(dim=-1)
                # Zero everything outside the nucleus; the top token always
                # survives because its (cumsum - prob) is 0 <= top_p.
                sorted_probs[cumsum - sorted_probs > top_p] = 0.0
                sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
                next_id = sorted_idx.gather(-1, torch.multinomial(sorted_probs, 1))

            generated = torch.cat([generated, next_id], dim=-1)
            if next_id.item() == eos:
                break

        return generated
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "pad_token": {
3
+ "content": "<pad>",
4
+ "single_word": false,
5
+ "lstrip": false,
6
+ "rstrip": false,
7
+ "normalized": false
8
+ },
9
+ "unk_token": {
10
+ "content": "<unk>",
11
+ "single_word": false,
12
+ "lstrip": false,
13
+ "rstrip": false,
14
+ "normalized": false
15
+ },
16
+ "bos_token": {
17
+ "content": "<bos>",
18
+ "single_word": false,
19
+ "lstrip": false,
20
+ "rstrip": false,
21
+ "normalized": false
22
+ },
23
+ "eos_token": {
24
+ "content": "<eos>",
25
+ "single_word": false,
26
+ "lstrip": false,
27
+ "rstrip": false,
28
+ "normalized": false
29
+ }
30
+ }
tokenizer/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "vocab_size": 2046,
3
+ "n_frozen": 697,
4
+ "special_tokens": [
5
+ "<pad>",
6
+ "<unk>",
7
+ "<bos>",
8
+ "<eos>"
9
+ ],
10
+ "pad_token": "<pad>",
11
+ "unk_token": "<unk>",
12
+ "bos_token": "<bos>",
13
+ "eos_token": "<eos>",
14
+ "pad_id": 0,
15
+ "unk_id": 1,
16
+ "bos_id": 2,
17
+ "eos_id": 3,
18
+ "model_max_length": 256,
19
+ "padding_side": "right",
20
+ "truncation_side": "right",
21
+ "tokenizer_version": 2,
22
+ "tokenizer_class": "PreTrainedTokenizerFast"
23
+ }