Natt1e committed on
Commit
8b68c5b
·
verified ·
1 Parent(s): bc14582

Update modeling_stable_diffcoder.py

Browse files

This PR is for https://github.com/ByteDance-Seed/Stable-DiffCoder/issues/1

Files changed (1) hide show
  1. modeling_stable_diffcoder.py +22 -8
modeling_stable_diffcoder.py CHANGED
@@ -160,12 +160,15 @@ class StableDiffcoderForCausalLM(LlamaForCausalLM):
160
  cur_attn_mask = block_diffusion_attention_mask[
161
  ..., :prefill_length, :prefill_length
162
  ]
 
 
163
  self(
164
  x[:, :prefill_length],
165
  past_key_values=past_key_values,
166
  attention_mask=cur_attn_mask,
167
  use_cache=True,
168
- ).past_key_values
 
169
 
170
  for block_id, block_size in enumerate(gen_block_list):
171
  block_start = (
@@ -181,8 +184,13 @@ class StableDiffcoderForCausalLM(LlamaForCausalLM):
181
  replace_position = torch.zeros_like(x, dtype=torch.bool)
182
  replace_position[:, block_start:block_end] = True
183
 
184
- for token_count in num_transfer_tokens:
185
- if token_count:
 
 
 
 
 
186
  nfe += 1
187
  mask_map = x[:, block_start:block_end] == mask_id
188
  attention_mask = block_diffusion_attention_mask[
@@ -207,20 +215,26 @@ class StableDiffcoderForCausalLM(LlamaForCausalLM):
207
  x[:, block_start:block_end],
208
  token_count if threshold is None else None,
209
  threshold,
210
- shift=False,
211
  )
212
  x[:, block_start:block_end][transfer_map] = x0[transfer_map]
213
 
214
  if (x[:, block_start:block_end] == mask_id).sum() == 0:
 
 
 
 
215
  if (
216
  eos_id is not None
217
- and (x[:, block_start:block_end] == eos_id).sum() > 0
 
218
  ):
219
  final_flag = True
220
  x = x[:, :block_end]
221
- eos_pos = (x == eos_id).nonzero(as_tuple=True)[1][0].item()
222
  x[0, eos_pos:] = eos_id
223
  break
 
224
  nfe += 1
225
  self(
226
  x[:, block_start:block_end],
@@ -231,13 +245,13 @@ class StableDiffcoderForCausalLM(LlamaForCausalLM):
231
  use_cache=True,
232
  cache_position=replace_position.nonzero(as_tuple=True)[1],
233
  )
234
- break
235
 
236
  if final_flag:
237
  break
238
 
239
  return x, nfe
240
-
241
  @torch.no_grad()
242
  def generate(
243
  self,
 
160
  cur_attn_mask = block_diffusion_attention_mask[
161
  ..., :prefill_length, :prefill_length
162
  ]
163
+ # Fix 1: Explicitly pass cache_position for newer transformers prefill
164
+ cache_pos = torch.arange(prefill_length, device=x.device)
165
  self(
166
  x[:, :prefill_length],
167
  past_key_values=past_key_values,
168
  attention_mask=cur_attn_mask,
169
  use_cache=True,
170
+ cache_position=cache_pos,
171
+ )
172
 
173
  for block_id, block_size in enumerate(gen_block_list):
174
  block_start = (
 
184
  replace_position = torch.zeros_like(x, dtype=torch.bool)
185
  replace_position[:, block_start:block_end] = True
186
 
187
+ step_idx = 0
188
+ while True:
189
+ idx = min(step_idx, len(num_transfer_tokens) - 1)
190
+ token_count = num_transfer_tokens[idx].item()
191
+ step_idx += 1
192
+
193
+ if token_count > 0:
194
  nfe += 1
195
  mask_map = x[:, block_start:block_end] == mask_id
196
  attention_mask = block_diffusion_attention_mask[
 
215
  x[:, block_start:block_end],
216
  token_count if threshold is None else None,
217
  threshold,
218
+ shift=shift,
219
  )
220
  x[:, block_start:block_end][transfer_map] = x0[transfer_map]
221
 
222
  if (x[:, block_start:block_end] == mask_id).sum() == 0:
223
+ # Fix 2: Calculate where the generated tokens ACTUALLY start in this block
224
+ gen_start = max(block_start, prompt_length)
225
+
226
+ # Only check for eos_id in the freshly generated region, ignoring the prompt overlap
227
  if (
228
  eos_id is not None
229
+ and gen_start < block_end
230
+ and (x[:, gen_start:block_end] == eos_id).sum() > 0
231
  ):
232
  final_flag = True
233
  x = x[:, :block_end]
234
+ eos_pos = (x[:, gen_start:block_end] == eos_id).nonzero(as_tuple=True)[1][0].item() + gen_start
235
  x[0, eos_pos:] = eos_id
236
  break
237
+
238
  nfe += 1
239
  self(
240
  x[:, block_start:block_end],
 
245
  use_cache=True,
246
  cache_position=replace_position.nonzero(as_tuple=True)[1],
247
  )
248
+ break
249
 
250
  if final_flag:
251
  break
252
 
253
  return x, nfe
254
+
255
  @torch.no_grad()
256
  def generate(
257
  self,