Hanrui / progress /step1.md
Lekr0's picture
Add files using upload-large-folder tool
40d87dd verified
Now I have a complete picture. Here's the full analysis and proposed plan.Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β 
Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β 
Β  ---Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β Β 
Β  Current Logic of train_dflash.pyΒ  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β 
Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β 
Β  This script trains a small, standalone DFlash draft model (1–10 transformer layers) that learns to predict token blocks in parallel,Β  Β  Β 
Β  guided by hidden states extracted from a frozen target model.
Β  Architecture (two separate models)
Β  1. Target model (frozen, e.g. Qwen3-8B/32B) β€” runs a full forward pass to produce hidden states at selected layers.
Β  2. Draft model (DFlashDraftModel, ~1–10 layers) β€” a lightweight Qwen3-based decoder that takes noise embeddings + target hidden states
Β  and predicts the block tokens.
Β  3. Target embed_tokens + lm_head β€” loaded separately via TargetEmbeddingsAndHead to avoid duplicating the full target model in memory.
Β  Key locations
Β  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
Β  β”‚Β  Β  Β  Β  ComponentΒ  Β  Β  Β  Β β”‚Β  Β  Β  Β  Β  Β  Β  Β  FileΒ  Β  Β  Β  Β  Β  Β  Β  β”‚Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β LinesΒ  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β β”‚
Β  β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
Β  β”‚ Model initΒ  Β  Β  Β  Β  Β  Β  Β β”‚ scripts/train_dflash.pyΒ  Β  Β  Β  Β  Β  β”‚ build_models() L254–311Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β β”‚
Β  β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
Β  β”‚ Target hidden extraction β”‚ scripts/train_dflash.pyΒ  Β  Β  Β  Β  Β  β”‚ L644–647 (target_model.generate_dflash_data)Β  Β  Β  Β  Β  β”‚
Β  β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
Β  β”‚ Forward passΒ  Β  Β  Β  Β  Β  Β β”‚ specforge/core/dflash.pyΒ  Β  Β  Β  Β  Β β”‚ OnlineDFlashModel.forward() L243–332Β  Β  Β  Β  Β  Β  Β  Β  Β  β”‚
Β  β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
Β  β”‚ Loss calculationΒ  Β  Β  Β  Β β”‚ specforge/core/dflash.pyΒ  Β  Β  Β  Β  Β β”‚ _full_lm_loss() L382–417, _chunked_lm_loss() L419–478 β”‚
Β  β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
Β  β”‚ Loss maskΒ  Β  Β  Β  Β  Β  Β  Β  β”‚ specforge/core/dflash.pyΒ  Β  Β  Β  Β  Β β”‚ create_dflash_loss_mask() L481–509Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  β”‚
Β  β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
Β  β”‚ Draft model architecture β”‚ specforge/modeling/draft/dflash.py β”‚ DFlashDraftModel L212–266Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β β”‚
Β  β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
Β  β”‚ DFlash attentionΒ  Β  Β  Β  Β β”‚ specforge/modeling/draft/dflash.py β”‚ Qwen3DFlashAttention L42–134Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  β”‚
Β  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
Β  Forward pass flow (per training step)
Β  input_ids, attention_mask, loss_maskΒ  β†’Β  target_model.generate_dflash_data()
Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  ↓
Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β hidden_states (from target layers [1,9,17,25,33])
Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  ↓
Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  OnlineDFlashModel.forward():
Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  1. Truncate to block boundary
Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  2. prepare_noise_input(): anchor tokens kept, rest β†’ MASK
Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  3. embed_tokens(noise_input_ids) β†’ noise_embedding
Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  4. Build DFlash attention mask (flex_attention or additive)
Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  5. draft_model(noise_embedding, target_hidden, mask)
Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  6. lm_head(hidden) β†’ logits
Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  7. CE loss on non-anchor positions (weighted by loss_mask Γ— decay)
Β  The draft model's custom Qwen3DFlashAttention concatenates [context_hidden, noise_hidden] as KV, with queries only from noise tokens. The
Β  Β attention mask enforces: block tokens see all preceding blocks' context + bidirectional within their own block.
Β  ---
Β  What already exists: train_dflash_lora.py
Β  Interestingly, the repo already has a LoRA variant at scripts/train_dflash_lora.py with its own model (DFlashLoRADraftModel) and wrapper
Β  (OnlineDFlashLoRAModel). This is exactly the approach you described β€” Qwen3-8B + LoRA, no separate target model, 1-step diffusion
Β  training. The key differences from train_dflash.py:
Β  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
Β  β”‚Β  Β  Β AspectΒ  Β  Β  β”‚Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β train_dflash.pyΒ  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β β”‚Β  Β  Β  Β  Β  Β  Β  Β  train_dflash_lora.pyΒ  Β  Β  Β  Β  Β  Β  Β  β”‚
Β  β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
Β  β”‚ Draft modelΒ  Β  Β β”‚ Small custom DFlashDraftModel (1–10 layers)Β  Β  Β  Β  Β  Β  Β  Β  Β β”‚ Full Qwen3-8B + LoRA adaptersΒ  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  β”‚
Β  β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
Β  β”‚ Target modelΒ  Β  β”‚ Separate frozen model for hidden state extractionΒ  Β  Β  Β  Β  Β β”‚ None β€” model uses its own representationsΒ  Β  Β  Β  Β  β”‚
Β  β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
Β  β”‚ AttentionΒ  Β  Β  Β β”‚ Custom Qwen3DFlashAttention (Q from noise, KV from [ctx,Β  Β  β”‚ Standard HF attention with 4D additive DFlash mask β”‚
Β  β”‚Β  Β  Β  Β  Β  Β  Β  Β  Β β”‚ noise])Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β β”‚Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  β”‚
Β  β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
Β  β”‚ ForwardΒ  Β  Β  Β  Β β”‚ draft_model(noise_emb, target_hidden, mask)Β  Β  Β  Β  Β  Β  Β  Β  Β β”‚ model(noise_input_ids, 4d_mask, position_ids) β†’Β  Β  β”‚
Β  β”‚Β  Β  Β  Β  Β  Β  Β  Β  Β β”‚Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β β”‚ logitsΒ  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β β”‚
Β  β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
Β  β”‚ TrainableΒ  Β  Β  Β β”‚ All draft model paramsΒ  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  β”‚ Only LoRA (q/k/v/o_proj), base frozenΒ  Β  Β  Β  Β  Β  Β  β”‚
Β  β”‚ paramsΒ  Β  Β  Β  Β  β”‚Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β β”‚Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  β”‚
Β  β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
Β  β”‚ FSDP strategyΒ  Β β”‚ SHARD_GRAD_OPΒ  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β β”‚ FULL_SHARDΒ  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β  Β β”‚
Β  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
Β  ---
Β  Proposed Modification Plan
Β  Since train_dflash_lora.py already implements the core idea, the plan focuses on what's missing or needs improvement to make it a proper
Β  "1-step dLLM draft model" for your research:
Β  Phase 1: Validate and extend the existing LoRA pipeline
Β  1. Add MLP to LoRA targets β€” The current config only targets q_proj, k_proj, v_proj, o_proj. For stronger 1-step diffusion capability,
Β  add gate_proj, up_proj, down_proj to lora_target_modules. This gives the model more capacity to learn the non-autoregressive distribution
Β  Β shift.
Β  2. Add multi-step noise schedule support β€” Currently the training is strictly 1-step (all non-anchors β†’ MASK). For a proper diffusion/AR
Β  fusion, add an option for a noise schedule where a fraction of block tokens are revealed (not just the anchor), controlled by a
Β  noise_ratio parameter. This would modify prepare_noise_input() in OnlineDFlashLoRAModel:
Β  # Instead of: all non-anchor β†’ MASK
Β  # Allow: randomly keep some non-anchor tokens with probability (1 - noise_ratio)
Β  3. Add configurable context_len strategy β€” Currently context_len=0 treats the whole sequence as blocks. Add a --context-ratio arg that
Β  dynamically sets context_len as a fraction of the sequence, so the model learns to condition on varying amounts of AR-decoded prefix.
Β  Phase 2: Training logic improvements
Β  4. Add KL divergence loss β€” In addition to CE loss against ground truth, add an optional KL loss against the base model's AR distribution
Β  Β (teacher forcing). This regularizes the LoRA model to stay close to the original Qwen3-8B distribution. Modify
Β  OnlineDFlashLoRAModel.forward():
Β  # Compute base model logits (no_grad, no LoRA) as teacher
Β  # KL(draft_logits || teacher_logits) on block positions
Β  # total_loss = ce_loss + kl_weight * kl_loss
Β  5. Add evaluation with speculative decoding metrics β€” The current accuracy metric is block-wise acceptance rate. Add an eval loop that
Β  actually runs speculative decoding (draft β†’ verify) to measure real speedup, using the LoRA model as the drafter and the base model (with
Β  Β LoRA disabled) as the verifier.
Β  Phase 3: Integration with train_dflash.py style features
Β  6. Port random anchor sampling β€” train_dflash.py has --random-anchor (L147–156) which samples diverse anchor positions per sequence. This
Β  Β is missing from the LoRA variant and would improve training diversity. Port _sample_anchor_positions and _build_blocks_from_anchors from
Β  Β OnlineDFlashModel to OnlineDFlashLoRAModel.
Β  7. Port chunked lm_head β€” The LoRA variant materializes full [bsz, seq_len, vocab_size] logits. For Qwen3-8B's 152K vocab, this is ~1.2GB
Β  Β per sample at seq_len=2048. Port the _chunked_lm_loss from OnlineDFlashModel for memory efficiency.
Β  8. Add tensor parallelism support β€” The LoRA script currently forces tp_size=1. For Qwen3-8B on multi-GPU setups, add TP support to shard
Β  Β the base model across GPUs while keeping LoRA params on each rank.
Β  Minimal starting point
Β  If you want to just get training running immediately, the existing train_dflash_lora.py + configs/qwen3-8b-dflash-lora.json should work
Β  out of the box:
Β  torchrun --nproc_per_node=4 scripts/train_dflash_lora.py \
Β  Β  --model-path Qwen/Qwen3-8B \
Β  Β  --lora-config configs/qwen3-8b-dflash-lora.json \
Β  Β  --train-data-path /path/to/data.jsonl \
Β  Β  --output-dir ./outputs/qwen3-8b-dflash-lora \
Β  Β  --block-size 16 \
Β  Β  --gradient-checkpointing \
Β  Β  --learning-rate 2e-4 \
Β  Β  --batch-size 1 \
Β  Β  --max-length 2048 \
Β  Β  --num-epochs 3
Β  Want me to start implementing any of these phases? I'd suggest starting with Phase 1 (noise schedule + MLP targets) since that's the most
Β  Β impactful for the diffusion/AR fusion research angle.