Johnblick187 committed on
Commit
32dba37
·
verified ·
1 Parent(s): 21425ee

Upload modeling_smartcoder.py

Browse files
Files changed (1) hide show
  1. modeling_smartcoder.py +413 -0
modeling_smartcoder.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ modeling_smartcoder_moe.py
3
+ Custom model class for SmartCoderMoE.
4
+
5
+ Architecture (from tensor inspection):
6
+ - vocab_size: 65536, hidden: 2048, layers: 40
7
+ - Attention: q[2048,2048], k/v[512,2048] — 16 heads, 4 KV heads, head_dim=128
8
+ - MLP (hybrid dense + MoE):
9
+ dense_fc: [8192, 2048] up
10
+ dense_proj: [2048, 8192] down
11
+ experts_fc: [32, 512, 2048] expert up (batched)
12
+ experts_proj: [32, 2048, 512] expert down (batched)
13
+ router: [32, 2048] router logits
14
+ - LayerNorm: weight+bias (input_layernorm, post_attention_layernorm)
15
+ - Final norm: model.norm.weight/bias
16
+ """
17
+
18
+ import math
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ from transformers import PreTrainedModel, PretrainedConfig
23
+ from transformers.modeling_outputs import CausalLMOutputWithPast
24
+ from typing import Optional, Tuple, List
25
+
26
+
27
+ # ── Config ────────────────────────────────────────────────────────────────────
28
class SmartCoderMoEConfig(PretrainedConfig):
    """Hyper-parameters of the SmartCoderMoE hybrid dense + MoE decoder.

    ``head_dim`` is not a constructor argument: it is derived as
    ``hidden_size // num_attention_heads``. The special-token ids,
    ``tie_word_embeddings`` and any extra ``kwargs`` are forwarded to
    ``PretrainedConfig.__init__``.
    """

    model_type = "smartcoder_moe"

    def __init__(
        self,
        vocab_size=65536,
        hidden_size=2048,
        num_hidden_layers=40,
        num_attention_heads=16,
        num_key_value_heads=4,
        dense_intermediate_size=8192,
        num_experts=32,
        expert_intermediate_size=512,
        num_experts_per_tok=2,
        max_position_embeddings=16384,
        rope_theta=10000.0,
        rms_norm_eps=1e-5,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=0,
        tie_word_embeddings=False,
        **kwargs,
    ):
        # Backbone geometry.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.max_position_embeddings = max_position_embeddings

        # Attention geometry; the per-head width follows from the two sizes.
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = hidden_size // num_attention_heads
        self.rope_theta = rope_theta

        # Feed-forward: one dense path plus a routed mixture of experts.
        self.dense_intermediate_size = dense_intermediate_size
        self.num_experts = num_experts
        self.expert_intermediate_size = expert_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok

        # Epsilon handed to every normalisation layer of the model.
        self.rms_norm_eps = rms_norm_eps

        # The base initialiser runs last, after the attributes above are set.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
71
+
72
+
73
+ # ── RoPE ──────────────────────────────────────────────────────────────────────
74
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the moved-up half.

    With ``x = [a, b]`` split at ``x.shape[-1] // 2`` the result is
    ``[-b, a]`` concatenated along the last dimension.
    """
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat((-back, front), dim=-1)
77
+
78
def apply_rotary_emb(q, k, cos, sin):
    """Rotate queries and keys with the given cosine/sine tables.

    Each tensor ``t`` becomes ``t * cos + rotate_half(t) * sin``; ``cos`` and
    ``sin`` must broadcast against both ``q`` and ``k``.

    Returns:
        The rotated ``(q, k)`` pair.
    """
    q_rotated = (q * cos) + (rotate_half(q) * sin)
    k_rotated = (k * cos) + (rotate_half(k) * sin)
    return q_rotated, k_rotated
81
+
82
class RotaryEmbedding(nn.Module):
    """Cached cosine/sine tables for rotary position embeddings.

    The tables are stored as buffers of shape ``[1, 1, max_pos, dim]`` and are
    grown on demand when a longer sequence is requested.

    Args:
        dim: Width of the rotated (per-head) dimension.
        max_pos: Number of positions to precompute.
        base: Base of the geometric frequency progression.
    """

    def __init__(self, dim, max_pos=16384, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_pos = max_pos
        self._build_cache(max_pos)

    def _build_cache(self, seq_len):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``."""
        # Keep the dtype of an existing cache (e.g. after ``module.to(bf16)``)
        # so a regrown table does not silently switch back to float32.
        previous = getattr(self, "cos_cached", None)
        cache_dtype = previous.dtype if previous is not None else torch.float32

        # Angles are always computed in float32 for accuracy.
        inv_freq = self.inv_freq.float()
        t = torch.arange(seq_len, device=inv_freq.device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(cache_dtype))
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(cache_dtype))

        # Record the new capacity; otherwise every later call with
        # ``seq_len > max_pos`` would rebuild the tables again.
        self.max_pos = seq_len

    def forward(self, seq_len):
        """Return ``(cos, sin)`` for the first ``seq_len`` positions.

        Both tensors have shape ``[1, 1, seq_len, dim]``.
        """
        if seq_len > self.max_pos:
            self._build_cache(seq_len)
        return self.cos_cached[:, :, :seq_len, :], \
               self.sin_cached[:, :, :seq_len, :]
102
+
103
+
104
+ # ── LayerNorm (with bias) ─────────────────────────────────────────────────────
105
class LayerNormWithBias(nn.Module):
    """Layer normalisation over the last dimension with learnable scale and shift.

    Args:
        hidden_size: Size of the normalised (last) dimension.
        eps: Value passed to ``F.layer_norm`` as ``eps``.
    """

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` along its last axis and apply weight and bias."""
        normalized_shape = x.shape[-1:]
        return F.layer_norm(
            x,
            normalized_shape,
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
114
+
115
+
116
+ # ── Attention ─────────────────────────────────────────────────────────────────
117
class SmartCoderAttention(nn.Module):
    """Causal grouped-query self-attention with rotary position embeddings.

    Queries use ``num_attention_heads`` heads, keys/values use
    ``num_key_value_heads`` heads; each KV head is repeated
    ``num_heads // num_kv_heads`` times before the score computation.
    """

    def __init__(self, config: SmartCoderMoEConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.num_kv_groups = self.num_heads // self.num_kv_heads

        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * config.head_dim, bias=True)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=True)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=True)
        self.o_proj = nn.Linear(config.num_attention_heads * config.head_dim, config.hidden_size, bias=True)

        self.rotary_emb = RotaryEmbedding(config.head_dim, config.max_position_embeddings, config.rope_theta)

    def forward(self, hidden_states, attention_mask=None, past_key_value=None, use_cache=False):
        """Attend over the current tokens plus any cached keys/values.

        Args:
            hidden_states: Tensor of shape ``[B, T, hidden_size]``.
            attention_mask: Optional mask. A floating-point mask is added to
                the attention scores (2-D ``[B, kv_len]`` masks are first
                reshaped to ``[B, 1, 1, kv_len]``). A 2-D integer/bool mask is
                read as a padding mask where non-zero means "attend".
            past_key_value: Optional ``(k, v)`` pair of cached tensors shaped
                ``[B, num_kv_heads, past_len, head_dim]``.
            use_cache: When true, the updated ``(k, v)`` pair is returned.

        Returns:
            ``(output, present)`` where ``output`` is ``[B, T, hidden_size]``
            and ``present`` is the ``(k, v)`` cache or ``None``.
        """
        B, T, _ = hidden_states.shape

        q = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # The new tokens sit *after* the cached ones, so their rotary positions
        # start at ``past_len``. Starting at 0 would rotate every incrementally
        # decoded token as if it were the first token of the sequence.
        past_len = past_key_value[0].shape[2] if past_key_value is not None else 0
        cos, sin = self.rotary_emb(past_len + T)
        cos = cos[:, :, past_len:past_len + T, :self.head_dim]
        sin = sin[:, :, past_len:past_len + T, :self.head_dim]
        q, k = apply_rotary_emb(q, k, cos, sin)

        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)
        present = (k, v) if use_cache else None

        # Expand KV heads to match Q heads (GQA)
        k = k.repeat_interleave(self.num_kv_groups, dim=1)
        v = v.repeat_interleave(self.num_kv_groups, dim=1)

        scale = math.sqrt(self.head_dim)
        attn = torch.matmul(q, k.transpose(-2, -1)) / scale

        # Query i (absolute position past_len + i) may see keys 0 .. past_len + i.
        kv_len = k.shape[2]
        causal_mask = torch.triu(
            torch.full((T, kv_len), float("-inf"), device=q.device, dtype=q.dtype),
            diagonal=1 + kv_len - T
        )
        attn = attn + causal_mask.unsqueeze(0).unsqueeze(0)

        if attention_mask is not None:
            if attention_mask.dim() == 2:
                # [B, kv_len] -> [B, 1, 1, kv_len] so it broadcasts over heads
                # and query positions instead of over the wrong axes.
                attention_mask = attention_mask[:, None, None, :]
                if not attention_mask.is_floating_point():
                    # Integer/bool padding mask: turn "0 = ignore" into a large
                    # negative bias. finfo.min (not -inf) keeps fully padded
                    # rows finite through the softmax.
                    keep = attention_mask.to(torch.bool)
                    attention_mask = torch.zeros_like(keep, dtype=attn.dtype).masked_fill(
                        ~keep, torch.finfo(attn.dtype).min
                    )
            attn = attn + attention_mask

        # Softmax in float32 for stability, then back to the compute dtype.
        attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.o_proj(out), present
171
+
172
+
173
+ # ── MoE MLP ───────────────────────────────────────────────────────────────────
174
class SmartCoderMoEMLP(nn.Module):
    """
    Hybrid Dense + MoE MLP.
    dense path: hidden -> dense_fc -> gelu -> dense_proj
    expert path: router picks top-k experts from experts_fc/experts_proj
    output = dense_out + expert_out

    The expert weights are stored as two batched parameters,
    ``experts_fc`` of shape ``[NE, EI, H]`` and ``experts_proj`` of shape
    ``[NE, H, EI]``.
    """
    def __init__(self, config: "SmartCoderMoEConfig"):
        super().__init__()
        H = config.hidden_size
        DI = config.dense_intermediate_size
        NE = config.num_experts
        EI = config.expert_intermediate_size
        K = config.num_experts_per_tok

        self.num_experts = NE
        self.top_k = K

        # Dense residual path
        self.dense_fc = nn.Linear(H, DI, bias=True)
        self.dense_proj = nn.Linear(DI, H, bias=True)

        # MoE path — stored as batched weight matrices matching safetensors layout
        # experts_fc: [NE, EI, H]
        # experts_proj: [NE, H, EI]
        self.experts_fc = nn.Parameter(torch.empty(NE, EI, H))
        self.experts_proj = nn.Parameter(torch.empty(NE, H, EI))
        # torch.empty leaves arbitrary memory contents (possibly NaN/inf).
        # Give the experts a defined starting point in case a checkpoint does
        # not provide them; loaded weights simply overwrite these values.
        nn.init.normal_(self.experts_fc, mean=0.0, std=0.02)
        nn.init.normal_(self.experts_proj, mean=0.0, std=0.02)
        self.router = nn.Linear(H, NE, bias=False)

    def forward(self, x):
        """Apply the dense path plus the weighted top-k expert mixture.

        Args:
            x: Tensor of shape ``[B, T, H]``.

        Returns:
            Tensor of shape ``[B, T, H]``.
        """
        B, T, H = x.shape

        # Dense path
        dense_out = self.dense_proj(F.gelu(self.dense_fc(x)))

        # Router: softmax over all experts, keep the top-k and renormalise
        # their weights so they sum to one per token.
        router_logits = self.router(x)  # [B, T, NE]
        router_weights = F.softmax(router_logits, dim=-1)
        top_weights, top_indices = router_weights.topk(self.top_k, dim=-1)  # [B, T, K]
        top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)

        x_flat = x.reshape(B * T, H)
        flat_indices = top_indices.reshape(B * T, self.top_k)
        flat_weights = top_weights.reshape(B * T, self.top_k)
        expert_out = torch.zeros_like(x_flat)

        # Group tokens by expert instead of gathering one full weight matrix
        # per token: the gather materialises [B*T, EI, H] tensors, whereas this
        # only ever touches one expert's [EI, H] / [H, EI] matrices at a time.
        for expert_id in flat_indices.unique().tolist():
            token_ids, slot_ids = (flat_indices == expert_id).nonzero(as_tuple=True)
            hidden = F.gelu(F.linear(x_flat[token_ids], self.experts_fc[expert_id]))  # [n, EI]
            out = F.linear(hidden, self.experts_proj[expert_id])  # [n, H]
            gate = flat_weights[token_ids, slot_ids].unsqueeze(-1)  # [n, 1]
            expert_out.index_add_(0, token_ids, out * gate)

        return dense_out + expert_out.view(B, T, H)
237
+
238
+
239
+ # ── Decoder Layer ─────────────────────────────────────────────────────────────
240
class SmartCoderDecoderLayer(nn.Module):
    """One pre-norm decoder block: self-attention, then the hybrid MLP.

    Each sub-layer is applied to a normalised copy of its input and its output
    is added back onto that input.
    """

    def __init__(self, config: SmartCoderMoEConfig):
        super().__init__()
        hidden, eps = config.hidden_size, config.rms_norm_eps
        self.input_layernorm = LayerNormWithBias(hidden, eps)
        self.self_attn = SmartCoderAttention(config)
        self.post_attention_layernorm = LayerNormWithBias(hidden, eps)
        self.mlp = SmartCoderMoEMLP(config)

    def forward(self, hidden_states, attention_mask=None, past_key_value=None, use_cache=False):
        """Run the block and return ``(hidden_states, present)``.

        ``present`` is whatever the attention module returns as its cache
        entry for this call.
        """
        # Attention sub-layer with residual connection.
        attn_input = self.input_layernorm(hidden_states)
        attn_output, present = self.self_attn(
            attn_input,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        after_attn = hidden_states + attn_output

        # Feed-forward sub-layer with residual connection.
        mlp_output = self.mlp(self.post_attention_layernorm(after_attn))
        return after_attn + mlp_output, present
265
+
266
+
267
+ # ── Full Model ────────────────────────────────────────────────────────────────
268
class SmartCoderMoEModel(nn.Module):
    """Token embedding, a stack of decoder layers, and a final normalisation."""

    def __init__(self, config: SmartCoderMoEConfig):
        super().__init__()
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        decoder_layers = [
            SmartCoderDecoderLayer(config) for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = LayerNormWithBias(config.hidden_size, config.rms_norm_eps)

    def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False):
        """Embed ``input_ids`` and run them through every decoder layer.

        Args:
            input_ids: Token ids fed to the embedding table.
            attention_mask: Passed unchanged to every layer.
            past_key_values: Optional per-layer cache entries, indexed by
                layer position; a falsy value means "no cache".
            use_cache: When true, the per-layer cache entries are collected.

        Returns:
            ``(hidden_states, presents)`` where ``presents`` is a list with one
            entry per layer, or ``None`` when ``use_cache`` is false.
        """
        hidden_states = self.embed_tokens(input_ids)
        presents = [] if use_cache else None

        for layer_idx, decoder_layer in enumerate(self.layers):
            layer_past = None
            if past_key_values:
                layer_past = past_key_values[layer_idx]

            hidden_states, layer_present = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                past_key_value=layer_past,
                use_cache=use_cache,
            )
            if presents is not None:
                presents.append(layer_present)

        return self.norm(hidden_states), presents
292
+
293
+
294
+ # ── CausalLM wrapper ──────────────────────────────────────────────────────────
295
class SmartCoderMoEForCausalLM(PreTrainedModel):
    """SmartCoderMoE decoder with a language-modelling head on top."""

    config_class = SmartCoderMoEConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False

    def __init__(self, config: SmartCoderMoEConfig):
        super().__init__(config)
        self.model = SmartCoderMoEModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def _expand_padding_mask(self, attention_mask):
        """Turn a 2-D mask into a ``[B, 1, 1, kv_len]`` additive bias.

        Integer/bool masks are read as "non-zero = attend"; floating-point
        masks are assumed to be additive already and are only reshaped.
        ``None`` and masks of any other rank are returned unchanged.
        """
        if attention_mask is None or attention_mask.dim() != 2:
            return attention_mask
        attention_mask = attention_mask[:, None, None, :]
        if attention_mask.is_floating_point():
            return attention_mask
        dtype = self.lm_head.weight.dtype
        keep = attention_mask.to(torch.bool)
        # finfo.min rather than -inf so fully padded rows stay finite.
        return torch.zeros_like(keep, dtype=dtype).masked_fill(~keep, torch.finfo(dtype).min)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=True,
        **kwargs,
    ):
        """Compute next-token logits and, if ``labels`` is given, the LM loss.

        Args:
            input_ids: Token ids; required (``inputs_embeds`` is accepted for
                signature compatibility but not used).
            attention_mask: Optional mask; 2-D masks are expanded to an
                additive ``[B, 1, 1, kv_len]`` bias before reaching the layers.
            past_key_values: Optional per-layer key/value cache.
            labels: Optional targets; position ``t`` is scored against
                ``labels[t + 1]`` and entries equal to ``-100`` are ignored.
            use_cache: Whether to return the updated key/value cache.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss``, ``logits`` and
            ``past_key_values``.

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        if input_ids is None:
            raise ValueError(
                "SmartCoderMoEForCausalLM requires input_ids; inputs_embeds is not supported."
            )

        attention_mask = self._expand_padding_mask(attention_mask)

        hidden_states, presents = self.model(
            input_ids, attention_mask=attention_mask,
            past_key_values=past_key_values, use_cache=use_cache
        )
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=presents,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
        """Build the ``forward`` kwargs for one generation step.

        Once a cache exists only the newest token is fed. The attention mask
        is forwarded so that padded positions in a batch are not attended to.
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            "use_cache": True,
        }
348
+
349
+
350
+ # ── Loader ────────────────────────────────────────────────────────────────────
351
def load_smartcoder_moe(model_id="Johnblick187/SmartCoderMoE", dtype=torch.bfloat16):
    """Load SmartCoderMoE with correct custom architecture.

    Downloads a snapshot of ``model_id``, builds the model from the default
    ``SmartCoderMoEConfig`` and copies every ``*.safetensors`` shard found in
    the snapshot directory into it (non-strict, so missing/unexpected keys are
    only reported).

    Args:
        model_id: Hub repository to download.
        dtype: Dtype the model is cast to after the weights are loaded.

    Returns:
        ``(model, config)``.

    Raises:
        FileNotFoundError: If the snapshot contains no ``*.safetensors`` file.
    """
    import os
    from pathlib import Path

    # NOTE(review): huggingface_hub may read this flag at import time, and it
    # is likely already imported via transformers — confirm this has an effect.
    os.environ["HF_HUB_DISABLE_XET"] = "1"

    from huggingface_hub import snapshot_download
    from safetensors.torch import load_file

    print(f"Downloading {model_id}...")
    model_dir = snapshot_download(model_id)

    sf_files = sorted(Path(model_dir).glob("*.safetensors"))
    if not sf_files:
        # Without this check a randomly initialised model would be returned
        # and reported as "Loaded!".
        raise FileNotFoundError(f"No .safetensors files found in {model_dir}")

    config = SmartCoderMoEConfig()
    print("Initializing model...")
    model = SmartCoderMoEForCausalLM(config)

    print("Loading weights...")
    # experts_fc in safetensors: [32, 512, 2048] — matches our [NE, EI, H] ✓
    # experts_proj in safetensors: [32, 2048, 512] — matches our [NE, H, EI] ✓
    # router in safetensors: [32, 2048] — stored as Linear weight [out, in] ✓

    # Load shard by shard so only one shard is held in memory next to the
    # model; missing keys are computed over all shards afterwards.
    expected_keys = list(model.state_dict().keys())
    loaded_keys = set()
    unexpected = []
    for f in sf_files:
        shard = load_file(str(f))
        result = model.load_state_dict(shard, strict=False)
        unexpected.extend(result.unexpected_keys)
        loaded_keys.update(shard.keys())
        del shard
    missing = [key for key in expected_keys if key not in loaded_keys]

    if missing:
        print(f"Missing keys: {missing[:5]}{'...' if len(missing)>5 else ''}")
    if unexpected:
        print(f"Unexpected keys: {unexpected[:5]}{'...' if len(unexpected)>5 else ''}")

    model = model.to(dtype)
    print(f"Loaded! Params: {sum(p.numel() for p in model.parameters())/1e9:.2f}B")
    return model, config
386
+
387
+
388
if __name__ == "__main__":
    # Smoke test: load the model and sample a completion for a short prompt.
    from transformers import AutoTokenizer

    model, config = load_smartcoder_moe()
    model.eval()
    # Use the GPU when one is present instead of failing on CPU-only hosts.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    tokenizer = AutoTokenizer.from_pretrained("Johnblick187/SmartCoderMoE", trust_remote_code=True)

    prompt = "def fibonacci(n):"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_len = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.3,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Print only the newly generated tokens, not the echoed prompt.
    print(tokenizer.decode(out[0][input_len:], skip_special_tokens=True))