Fix bug for https://github.com/ByteDance-Seed/Stable-DiffCoder/issues/1

#7
Opened by Natt1e
Files changed (1)
  1. modeling_stable_diffcoder.py +17 -7
modeling_stable_diffcoder.py CHANGED
@@ -156,16 +156,20 @@ class StableDiffcoderForCausalLM(LlamaForCausalLM):
156
  nfe = 0
157
  final_flag = False
158
  prefill_length = prompt_length // block_length * block_length
 
159
  if prefill_length > 0:
160
  cur_attn_mask = block_diffusion_attention_mask[
161
  ..., :prefill_length, :prefill_length
162
  ]
 
 
163
  self(
164
  x[:, :prefill_length],
165
  past_key_values=past_key_values,
166
  attention_mask=cur_attn_mask,
167
  use_cache=True,
168
- ).past_key_values
 
169
 
170
  for block_id, block_size in enumerate(gen_block_list):
171
  block_start = (
@@ -182,7 +186,7 @@ class StableDiffcoderForCausalLM(LlamaForCausalLM):
182
  replace_position[:, block_start:block_end] = True
183
 
184
  for token_count in num_transfer_tokens:
185
- if token_count:
186
  nfe += 1
187
  mask_map = x[:, block_start:block_end] == mask_id
188
  attention_mask = block_diffusion_attention_mask[
@@ -205,22 +209,28 @@ class StableDiffcoderForCausalLM(LlamaForCausalLM):
205
  remasking,
206
  mask_map,
207
  x[:, block_start:block_end],
208
- token_count if threshold is None else None,
209
  threshold,
210
- shift=False,
211
  )
212
  x[:, block_start:block_end][transfer_map] = x0[transfer_map]
213
 
214
  if (x[:, block_start:block_end] == mask_id).sum() == 0:
 
 
 
 
215
  if (
216
  eos_id is not None
217
- and (x[:, block_start:block_end] == eos_id).sum() > 0
 
218
  ):
219
  final_flag = True
220
  x = x[:, :block_end]
221
- eos_pos = (x == eos_id).nonzero(as_tuple=True)[1][0].item()
222
  x[0, eos_pos:] = eos_id
223
  break
 
224
  nfe += 1
225
  self(
226
  x[:, block_start:block_end],
@@ -231,7 +241,7 @@ class StableDiffcoderForCausalLM(LlamaForCausalLM):
231
  use_cache=True,
232
  cache_position=replace_position.nonzero(as_tuple=True)[1],
233
  )
234
- break
235
 
236
  if final_flag:
237
  break
 
156
  nfe = 0
157
  final_flag = False
158
  prefill_length = prompt_length // block_length * block_length
159
+
160
  if prefill_length > 0:
161
  cur_attn_mask = block_diffusion_attention_mask[
162
  ..., :prefill_length, :prefill_length
163
  ]
164
+ # Fix 1: Explicitly pass cache_position for newer transformers prefill
165
+ cache_pos = torch.arange(prefill_length, device=x.device)
166
  self(
167
  x[:, :prefill_length],
168
  past_key_values=past_key_values,
169
  attention_mask=cur_attn_mask,
170
  use_cache=True,
171
+ cache_position=cache_pos,
172
+ )
173
 
174
  for block_id, block_size in enumerate(gen_block_list):
175
  block_start = (
 
186
  replace_position[:, block_start:block_end] = True
187
 
188
  for token_count in num_transfer_tokens:
189
+ if token_count > 0:
190
  nfe += 1
191
  mask_map = x[:, block_start:block_end] == mask_id
192
  attention_mask = block_diffusion_attention_mask[
 
209
  remasking,
210
  mask_map,
211
  x[:, block_start:block_end],
212
+ token_count.item() if threshold is None else None,
213
  threshold,
214
+ shift=shift,
215
  )
216
  x[:, block_start:block_end][transfer_map] = x0[transfer_map]
217
 
218
  if (x[:, block_start:block_end] == mask_id).sum() == 0:
219
+
220
+ # Fix 2: Calculate where the generated tokens ACTUALLY start in this block
221
+ gen_start = max(block_start, prompt_length)
222
+
223
  if (
224
  eos_id is not None
225
+ and gen_start < block_end
226
+ and (x[:, gen_start:block_end] == eos_id).sum() > 0
227
  ):
228
  final_flag = True
229
  x = x[:, :block_end]
230
+ eos_pos = (x[:, gen_start:block_end] == eos_id).nonzero(as_tuple=True)[1][0].item() + gen_start
231
  x[0, eos_pos:] = eos_id
232
  break
233
+
234
  nfe += 1
235
  self(
236
  x[:, block_start:block_end],
 
241
  use_cache=True,
242
  cache_position=replace_position.nonzero(as_tuple=True)[1],
243
  )
244
+ break
245
 
246
  if final_flag:
247
  break