Johnblick187 commited on
Commit
d9c1d79
·
verified ·
1 Parent(s): d00ea53

Update modeling_smartcoder_moe.py

Browse files
Files changed (1) hide show
  1. modeling_smartcoder_moe.py +73 -161
modeling_smartcoder_moe.py CHANGED
@@ -19,9 +19,8 @@ import math
19
  import torch
20
  import torch.nn as nn
21
  import torch.nn.functional as F
22
- from transformers import PreTrainedModel, PretrainedConfig
23
  from transformers.modeling_outputs import CausalLMOutputWithPast
24
- from typing import Optional, Tuple, List
25
 
26
 
27
  # ── Config ────────────────────────────────────────────────────────────────────
@@ -84,24 +83,24 @@ class RotaryEmbedding(nn.Module):
84
  super().__init__()
85
  inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
86
  self.register_buffer("inv_freq", inv_freq)
87
- self.max_pos = max_pos
88
- self._build_cache(max_pos)
89
 
90
- def _build_cache(self, seq_len):
91
- t = torch.arange(seq_len, device=self.inv_freq.device).float()
92
- freqs = torch.outer(t, self.inv_freq)
93
  emb = torch.cat([freqs, freqs], dim=-1)
94
- self.register_buffer("cos_cached", emb.cos()[None, None, :, :])
95
- self.register_buffer("sin_cached", emb.sin()[None, None, :, :])
 
96
 
97
- def forward(self, seq_len):
98
- if seq_len > self.max_pos:
99
- self._build_cache(seq_len)
100
  return self.cos_cached[:, :, :seq_len, :], \
101
  self.sin_cached[:, :, :seq_len, :]
102
 
103
 
104
- # ── LayerNorm (with bias) ─────────────────────────────────────────────────────
105
  class LayerNormWithBias(nn.Module):
106
  def __init__(self, hidden_size, eps=1e-5):
107
  super().__init__()
@@ -117,120 +116,80 @@ class LayerNormWithBias(nn.Module):
117
  class SmartCoderAttention(nn.Module):
118
  def __init__(self, config: SmartCoderMoEConfig):
119
  super().__init__()
120
- self.hidden_size = config.hidden_size
121
- self.num_heads = config.num_attention_heads
122
  self.num_kv_heads = config.num_key_value_heads
123
- self.head_dim = config.head_dim
124
  self.num_kv_groups = self.num_heads // self.num_kv_heads
125
 
126
  self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * config.head_dim, bias=True)
127
  self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=True)
128
  self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=True)
129
  self.o_proj = nn.Linear(config.num_attention_heads * config.head_dim, config.hidden_size, bias=True)
130
-
131
  self.rotary_emb = RotaryEmbedding(config.head_dim, config.max_position_embeddings, config.rope_theta)
132
 
133
- def forward(self, hidden_states, attention_mask=None, past_key_value=None, use_cache=False):
134
  B, T, _ = hidden_states.shape
135
 
136
  q = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
137
  k = self.k_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
138
  v = self.v_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
139
 
140
- cos, sin = self.rotary_emb(T)
141
  cos = cos[:, :, :T, :self.head_dim]
142
  sin = sin[:, :, :T, :self.head_dim]
143
  q, k = apply_rotary_emb(q, k, cos, sin)
144
 
145
- if past_key_value is not None:
146
- k = torch.cat([past_key_value[0], k], dim=2)
147
- v = torch.cat([past_key_value[1], v], dim=2)
148
- present = (k, v) if use_cache else None
149
-
150
- # Expand KV heads to match Q heads (GQA)
151
  k = k.repeat_interleave(self.num_kv_groups, dim=1)
152
  v = v.repeat_interleave(self.num_kv_groups, dim=1)
153
 
154
- scale = math.sqrt(self.head_dim)
155
- attn = torch.matmul(q, k.transpose(-2, -1)) / scale
156
-
157
- kv_len = k.shape[2]
158
- causal_mask = torch.triu(
159
- torch.full((T, kv_len), float("-inf"), device=q.device, dtype=q.dtype),
160
- diagonal=1 + kv_len - T
161
- )
162
- attn = attn + causal_mask.unsqueeze(0).unsqueeze(0)
163
-
164
  if attention_mask is not None:
165
  attn = attn + attention_mask
166
-
167
  attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
168
- out = torch.matmul(attn, v)
169
- out = out.transpose(1, 2).contiguous().view(B, T, -1)
170
- return self.o_proj(out), present
171
 
172
 
173
  # ── MoE MLP ───────────────────────────────────────────────────────────────────
174
  class SmartCoderMoEMLP(nn.Module):
175
- """
176
- Hybrid Dense + MoE MLP.
177
- dense path: hidden -> dense_fc (8192) -> gelu -> dense_proj (2048)
178
- expert path: router picks top-k experts from experts_fc/experts_proj
179
- output = dense_out + expert_out
180
- """
181
  def __init__(self, config: SmartCoderMoEConfig):
182
  super().__init__()
183
  H = config.hidden_size
184
  DI = config.dense_intermediate_size
185
  NE = config.num_experts
186
  EI = config.expert_intermediate_size
187
- K = config.num_experts_per_tok
188
-
189
- self.num_experts = NE
190
- self.top_k = K
191
 
192
- # Dense residual path
193
- self.dense_fc = nn.Linear(H, DI, bias=True)
194
- self.dense_proj = nn.Linear(DI, H, bias=True)
195
 
196
- # MoE path — stored as batched weight matrices matching safetensors layout
197
- # experts_fc: [NE, EI, H]
198
- # experts_proj: [NE, H, EI]
199
- self.experts_fc = nn.Parameter(torch.empty(NE, EI, H))
200
  self.experts_proj = nn.Parameter(torch.empty(NE, H, EI))
201
- self.router = nn.Linear(H, NE, bias=False)
202
 
203
  def forward(self, x):
204
  B, T, H = x.shape
205
 
206
- # Dense path
207
  dense_out = self.dense_proj(F.gelu(self.dense_fc(x)))
208
 
209
- # Router
210
- router_logits = self.router(x) # [B, T, NE]
211
  router_weights = F.softmax(router_logits, dim=-1)
212
- top_weights, top_indices = router_weights.topk(self.top_k, dim=-1) # [B, T, K]
213
- top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True) # normalize
214
 
215
- # Expert computation — iterate over top-k (K is small so this is fine)
216
  expert_out = torch.zeros_like(x)
217
  x_flat = x.view(B * T, H)
218
 
219
  for k in range(self.top_k):
220
- expert_ids = top_indices[:, :, k].reshape(B * T) # [B*T]
221
- weights = top_weights[:, :, k].reshape(B * T, 1) # [B*T, 1]
222
-
223
- # Batched expert forward using einsum
224
- # For each token, pick its expert's weights
225
- # experts_fc: [NE, EI, H] → gather → [B*T, EI, H]
226
- fc_w = self.experts_fc[expert_ids] # [B*T, EI, H]
227
- proj_w = self.experts_proj[expert_ids] # [B*T, H, EI]
228
-
229
- # up: [B*T, EI]
230
  hidden = F.gelu(torch.bmm(fc_w, x_flat.unsqueeze(-1)).squeeze(-1))
231
- # down: [B*T, H]
232
- out = torch.bmm(proj_w, hidden.unsqueeze(-1)).squeeze(-1)
233
-
234
  expert_out = expert_out + (out * weights).view(B, T, H)
235
 
236
  return dense_out + expert_out
@@ -240,59 +199,42 @@ class SmartCoderMoEMLP(nn.Module):
240
  class SmartCoderDecoderLayer(nn.Module):
241
  def __init__(self, config: SmartCoderMoEConfig):
242
  super().__init__()
243
- self.input_layernorm = LayerNormWithBias(config.hidden_size, config.rms_norm_eps)
244
- self.self_attn = SmartCoderAttention(config)
245
  self.post_attention_layernorm = LayerNormWithBias(config.hidden_size, config.rms_norm_eps)
246
- self.mlp = SmartCoderMoEMLP(config)
247
 
248
- def forward(self, hidden_states, attention_mask=None, past_key_value=None, use_cache=False):
249
- # Attention
250
  residual = hidden_states
251
  hidden_states = self.input_layernorm(hidden_states)
252
- hidden_states, present = self.self_attn(
253
- hidden_states, attention_mask=attention_mask,
254
- past_key_value=past_key_value, use_cache=use_cache
255
- )
256
  hidden_states = residual + hidden_states
257
 
258
- # MLP
259
  residual = hidden_states
260
  hidden_states = self.post_attention_layernorm(hidden_states)
261
  hidden_states = self.mlp(hidden_states)
262
  hidden_states = residual + hidden_states
263
 
264
- return hidden_states, present
265
 
266
 
267
- # ── Full Model ────────────────────────────────────────────────────────────────
268
  class SmartCoderMoEModel(nn.Module):
269
  def __init__(self, config: SmartCoderMoEConfig):
270
  super().__init__()
271
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
272
- self.layers = nn.ModuleList([
273
- SmartCoderDecoderLayer(config) for _ in range(config.num_hidden_layers)
274
- ])
275
- self.norm = LayerNormWithBias(config.hidden_size, config.rms_norm_eps)
276
 
277
- def forward(self, input_ids, attention_mask=None, past_key_values=None, use_cache=False):
278
  hidden_states = self.embed_tokens(input_ids)
279
- presents = [] if use_cache else None
280
-
281
- for i, layer in enumerate(self.layers):
282
- pkv = past_key_values[i] if past_key_values else None
283
- hidden_states, present = layer(
284
- hidden_states, attention_mask=attention_mask,
285
- past_key_value=pkv, use_cache=use_cache
286
- )
287
- if use_cache:
288
- presents.append(present)
289
 
290
- hidden_states = self.norm(hidden_states)
291
- return hidden_states, presents
292
 
293
-
294
- # ── CausalLM wrapper ──────────────────────────────────────────────────────────
295
- class SmartCoderMoEForCausalLM(PreTrainedModel):
296
  config_class = SmartCoderMoEConfig
297
  base_model_prefix = "model"
298
  supports_gradient_checkpointing = False
@@ -303,11 +245,8 @@ class SmartCoderMoEForCausalLM(PreTrainedModel):
303
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
304
  self.post_init()
305
 
306
- def get_input_embeddings(self):
307
- return self.model.embed_tokens
308
-
309
- def get_output_embeddings(self):
310
- return self.lm_head
311
 
312
  def forward(
313
  self,
@@ -316,13 +255,10 @@ class SmartCoderMoEForCausalLM(PreTrainedModel):
316
  past_key_values=None,
317
  inputs_embeds=None,
318
  labels=None,
319
- use_cache=True,
320
  **kwargs,
321
  ):
322
- hidden_states, presents = self.model(
323
- input_ids, attention_mask=attention_mask,
324
- past_key_values=past_key_values, use_cache=use_cache
325
- )
326
  logits = self.lm_head(hidden_states)
327
 
328
  loss = None
@@ -335,24 +271,18 @@ class SmartCoderMoEForCausalLM(PreTrainedModel):
335
  ignore_index=-100,
336
  )
337
 
338
- return CausalLMOutputWithPast(
339
- loss=loss,
340
- logits=logits,
341
- past_key_values=presents,
342
- )
343
 
344
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
345
- if past_key_values:
346
- input_ids = input_ids[:, -1:]
347
- return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": True}
348
 
349
 
350
  # ── Loader ────────────────────────────────────────────────────────────────────
351
  def load_smartcoder_moe(model_id="Johnblick187/SmartCoderMoE", dtype=torch.bfloat16):
352
- """Load SmartCoderMoE with correct custom architecture."""
353
  import os
354
  from huggingface_hub import snapshot_download
355
  from safetensors.torch import load_file
 
356
 
357
  os.environ["HF_HUB_DISABLE_XET"] = "1"
358
 
@@ -364,50 +294,32 @@ def load_smartcoder_moe(model_id="Johnblick187/SmartCoderMoE", dtype=torch.bfloa
364
  model = SmartCoderMoEForCausalLM(config)
365
 
366
  print("Loading weights...")
367
- from pathlib import Path
368
  sf_files = sorted(Path(model_dir).glob("*.safetensors"))
369
  state_dict = {}
370
  for f in sf_files:
371
  state_dict.update(load_file(str(f)))
372
 
373
- # experts_fc in safetensors: [32, 512, 2048] matches our [NE, EI, H]
374
- # experts_proj in safetensors: [32, 2048, 512] — matches our [NE, H, EI] ✓
375
- # router in safetensors: [32, 2048] — stored as Linear weight [out, in]
 
 
 
 
 
 
 
376
 
377
  missing, unexpected = model.load_state_dict(state_dict, strict=False)
378
  if missing:
379
- print(f"Missing keys: {missing[:5]}{'...' if len(missing)>5 else ''}")
380
  if unexpected:
381
- print(f"Unexpected keys: {unexpected[:5]}{'...' if len(unexpected)>5 else ''}")
382
 
383
  model = model.to(dtype)
384
  print(f"Loaded! Params: {sum(p.numel() for p in model.parameters())/1e9:.2f}B")
385
  return model, config
386
 
387
-
388
- if __name__ == "__main__":
389
- from transformers import AutoTokenizer
390
- import torch
391
-
392
- model, config = load_smartcoder_moe()
393
- model.eval()
394
- model = model.cuda()
395
-
396
- tokenizer = AutoTokenizer.from_pretrained("Johnblick187/SmartCoderMoE", trust_remote_code=True)
397
-
398
- prompt = "def fibonacci(n):"
399
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
400
- input_len = inputs["input_ids"].shape[-1]
401
-
402
- with torch.no_grad():
403
- out = model.generate(
404
- **inputs,
405
- max_new_tokens=150,
406
- do_sample=True,
407
- temperature=0.7,
408
- top_p=0.95,
409
- repetition_penalty=1.3,
410
- pad_token_id=tokenizer.eos_token_id,
411
- )
412
-
413
- print(tokenizer.decode(out[0][input_len:], skip_special_tokens=True))
 
19
  import torch
20
  import torch.nn as nn
21
  import torch.nn.functional as F
22
+ from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
23
  from transformers.modeling_outputs import CausalLMOutputWithPast
 
24
 
25
 
26
  # ── Config ────────────────────────────────────────────────────────────────────
 
83
  super().__init__()
84
  inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
85
  self.register_buffer("inv_freq", inv_freq)
86
+ self._cached_len = 0
 
87
 
88
+ def _build_cache(self, seq_len, device):
89
+ t = torch.arange(seq_len, device=device).float()
90
+ freqs = torch.outer(t, self.inv_freq.to(device))
91
  emb = torch.cat([freqs, freqs], dim=-1)
92
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
93
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
94
+ self._cached_len = seq_len
95
 
96
+ def forward(self, seq_len, device):
97
+ if seq_len > self._cached_len:
98
+ self._build_cache(seq_len, device)
99
  return self.cos_cached[:, :, :seq_len, :], \
100
  self.sin_cached[:, :, :seq_len, :]
101
 
102
 
103
+ # ── LayerNorm with bias ───────────────────────────────────────────────────────
104
  class LayerNormWithBias(nn.Module):
105
  def __init__(self, hidden_size, eps=1e-5):
106
  super().__init__()
 
116
  class SmartCoderAttention(nn.Module):
117
  def __init__(self, config: SmartCoderMoEConfig):
118
  super().__init__()
119
+ self.num_heads = config.num_attention_heads
 
120
  self.num_kv_heads = config.num_key_value_heads
121
+ self.head_dim = config.head_dim
122
  self.num_kv_groups = self.num_heads // self.num_kv_heads
123
 
124
  self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * config.head_dim, bias=True)
125
  self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=True)
126
  self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * config.head_dim, bias=True)
127
  self.o_proj = nn.Linear(config.num_attention_heads * config.head_dim, config.hidden_size, bias=True)
 
128
  self.rotary_emb = RotaryEmbedding(config.head_dim, config.max_position_embeddings, config.rope_theta)
129
 
130
+ def forward(self, hidden_states, attention_mask=None, **kwargs):
131
  B, T, _ = hidden_states.shape
132
 
133
  q = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
134
  k = self.k_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
135
  v = self.v_proj(hidden_states).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
136
 
137
+ cos, sin = self.rotary_emb(T, hidden_states.device)
138
  cos = cos[:, :, :T, :self.head_dim]
139
  sin = sin[:, :, :T, :self.head_dim]
140
  q, k = apply_rotary_emb(q, k, cos, sin)
141
 
 
 
 
 
 
 
142
  k = k.repeat_interleave(self.num_kv_groups, dim=1)
143
  v = v.repeat_interleave(self.num_kv_groups, dim=1)
144
 
145
+ attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
146
+ causal = torch.triu(torch.full((T, T), float("-inf"), device=q.device, dtype=q.dtype), diagonal=1)
147
+ attn = attn + causal.unsqueeze(0).unsqueeze(0)
 
 
 
 
 
 
 
148
  if attention_mask is not None:
149
  attn = attn + attention_mask
 
150
  attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
151
+ out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(B, T, -1)
152
+ return self.o_proj(out)
 
153
 
154
 
155
  # ── MoE MLP ───────────────────────────────────────────────────────────────────
156
  class SmartCoderMoEMLP(nn.Module):
 
 
 
 
 
 
157
  def __init__(self, config: SmartCoderMoEConfig):
158
  super().__init__()
159
  H = config.hidden_size
160
  DI = config.dense_intermediate_size
161
  NE = config.num_experts
162
  EI = config.expert_intermediate_size
 
 
 
 
163
 
164
+ self.num_experts = NE
165
+ self.top_k = config.num_experts_per_tok
 
166
 
167
+ self.dense_fc = nn.Linear(H, DI, bias=True)
168
+ self.dense_proj = nn.Linear(DI, H, bias=True)
169
+ self.experts_fc = nn.Parameter(torch.empty(NE, EI, H))
 
170
  self.experts_proj = nn.Parameter(torch.empty(NE, H, EI))
171
+ self.router = nn.Linear(H, NE, bias=False)
172
 
173
  def forward(self, x):
174
  B, T, H = x.shape
175
 
 
176
  dense_out = self.dense_proj(F.gelu(self.dense_fc(x)))
177
 
178
+ router_logits = self.router(x)
 
179
  router_weights = F.softmax(router_logits, dim=-1)
180
+ top_weights, top_indices = router_weights.topk(self.top_k, dim=-1)
181
+ top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
182
 
 
183
  expert_out = torch.zeros_like(x)
184
  x_flat = x.view(B * T, H)
185
 
186
  for k in range(self.top_k):
187
+ expert_ids = top_indices[:, :, k].reshape(B * T)
188
+ weights = top_weights[:, :, k].reshape(B * T, 1)
189
+ fc_w = self.experts_fc[expert_ids]
190
+ proj_w = self.experts_proj[expert_ids]
 
 
 
 
 
 
191
  hidden = F.gelu(torch.bmm(fc_w, x_flat.unsqueeze(-1)).squeeze(-1))
192
+ out = torch.bmm(proj_w, hidden.unsqueeze(-1)).squeeze(-1)
 
 
193
  expert_out = expert_out + (out * weights).view(B, T, H)
194
 
195
  return dense_out + expert_out
 
199
  class SmartCoderDecoderLayer(nn.Module):
200
  def __init__(self, config: SmartCoderMoEConfig):
201
  super().__init__()
202
+ self.input_layernorm = LayerNormWithBias(config.hidden_size, config.rms_norm_eps)
203
+ self.self_attn = SmartCoderAttention(config)
204
  self.post_attention_layernorm = LayerNormWithBias(config.hidden_size, config.rms_norm_eps)
205
+ self.mlp = SmartCoderMoEMLP(config)
206
 
207
+ def forward(self, hidden_states, attention_mask=None, **kwargs):
 
208
  residual = hidden_states
209
  hidden_states = self.input_layernorm(hidden_states)
210
+ hidden_states = self.self_attn(hidden_states, attention_mask=attention_mask)
 
 
 
211
  hidden_states = residual + hidden_states
212
 
 
213
  residual = hidden_states
214
  hidden_states = self.post_attention_layernorm(hidden_states)
215
  hidden_states = self.mlp(hidden_states)
216
  hidden_states = residual + hidden_states
217
 
218
+ return hidden_states
219
 
220
 
221
+ # ── Model ─────────────────────────────────────────────────────────────────────
222
  class SmartCoderMoEModel(nn.Module):
223
  def __init__(self, config: SmartCoderMoEConfig):
224
  super().__init__()
225
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
226
+ self.layers = nn.ModuleList([SmartCoderDecoderLayer(config) for _ in range(config.num_hidden_layers)])
227
+ self.norm = LayerNormWithBias(config.hidden_size, config.rms_norm_eps)
 
 
228
 
229
+ def forward(self, input_ids, attention_mask=None, **kwargs):
230
  hidden_states = self.embed_tokens(input_ids)
231
+ for layer in self.layers:
232
+ hidden_states = layer(hidden_states, attention_mask=attention_mask)
233
+ return self.norm(hidden_states)
 
 
 
 
 
 
 
234
 
 
 
235
 
236
+ # ── CausalLM ──────────────────────────────────────────────────────────────────
237
+ class SmartCoderMoEForCausalLM(PreTrainedModel, GenerationMixin):
 
238
  config_class = SmartCoderMoEConfig
239
  base_model_prefix = "model"
240
  supports_gradient_checkpointing = False
 
245
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
246
  self.post_init()
247
 
248
+ def get_input_embeddings(self): return self.model.embed_tokens
249
+ def get_output_embeddings(self): return self.lm_head
 
 
 
250
 
251
  def forward(
252
  self,
 
255
  past_key_values=None,
256
  inputs_embeds=None,
257
  labels=None,
258
+ use_cache=None,
259
  **kwargs,
260
  ):
261
+ hidden_states = self.model(input_ids, attention_mask=attention_mask)
 
 
 
262
  logits = self.lm_head(hidden_states)
263
 
264
  loss = None
 
271
  ignore_index=-100,
272
  )
273
 
274
+ return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=None)
 
 
 
 
275
 
276
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
277
+ return {"input_ids": input_ids}
 
 
278
 
279
 
280
  # ── Loader ────────────────────────────────────────────────────────────────────
281
  def load_smartcoder_moe(model_id="Johnblick187/SmartCoderMoE", dtype=torch.bfloat16):
 
282
  import os
283
  from huggingface_hub import snapshot_download
284
  from safetensors.torch import load_file
285
+ from pathlib import Path
286
 
287
  os.environ["HF_HUB_DISABLE_XET"] = "1"
288
 
 
294
  model = SmartCoderMoEForCausalLM(config)
295
 
296
  print("Loading weights...")
 
297
  sf_files = sorted(Path(model_dir).glob("*.safetensors"))
298
  state_dict = {}
299
  for f in sf_files:
300
  state_dict.update(load_file(str(f)))
301
 
302
+ # Remap expert keyssafetensors has .weight suffix, our params don't
303
+ remapped = {}
304
+ for k, v in state_dict.items():
305
+ if 'experts_fc.weight' in k:
306
+ remapped[k.replace('experts_fc.weight', 'experts_fc')] = v
307
+ elif 'experts_proj.weight' in k:
308
+ remapped[k.replace('experts_proj.weight', 'experts_proj')] = v
309
+ else:
310
+ remapped[k] = v
311
+ state_dict = remapped
312
 
313
  missing, unexpected = model.load_state_dict(state_dict, strict=False)
314
  if missing:
315
+ print(f"Missing: {missing[:3]}{'...' if len(missing)>3 else ''}")
316
  if unexpected:
317
+ print(f"Unexpected: {unexpected[:3]}{'...' if len(unexpected)>3 else ''}")
318
 
319
  model = model.to(dtype)
320
  print(f"Loaded! Params: {sum(p.numel() for p in model.parameters())/1e9:.2f}B")
321
  return model, config
322
 
323
+ from transformers import AutoConfig, AutoModelForCausalLM
324
+ AutoConfig.register("smartcoder_moe", SmartCoderMoEConfig)
325
+ AutoModelForCausalLM.register(SmartCoderMoEConfig, SmartCoderMoEForCausalLM)