Now I have a complete picture. Here's the full analysis and proposed plan.

---

Current Logic of train_dflash.py

This script trains a small, standalone DFlash draft model (1–10 transformer layers) that learns to predict token blocks in parallel, guided by hidden states extracted from a frozen target model.
| |
Architecture (two separate models)

1. Target model (frozen, e.g. Qwen3-8B/32B): runs a full forward pass to produce hidden states at selected layers.
2. Draft model (DFlashDraftModel, ~1–10 layers): a lightweight Qwen3-based decoder that takes noise embeddings + target hidden states and predicts the block tokens.
3. Target embed_tokens + lm_head: loaded separately via TargetEmbeddingsAndHead to avoid duplicating the full target model in memory.
| |
Key locations

| Component | File | Lines |
|---|---|---|
| Model init | scripts/train_dflash.py | build_models() L254–311 |
| Target hidden extraction | scripts/train_dflash.py | L644–647 (target_model.generate_dflash_data) |
| Forward pass | specforge/core/dflash.py | OnlineDFlashModel.forward() L243–332 |
| Loss calculation | specforge/core/dflash.py | _full_lm_loss() L382–417, _chunked_lm_loss() L419–478 |
| Loss mask | specforge/core/dflash.py | create_dflash_loss_mask() L481–509 |
| Draft model architecture | specforge/modeling/draft/dflash.py | DFlashDraftModel L212–266 |
| DFlash attention | specforge/modeling/draft/dflash.py | Qwen3DFlashAttention L42–134 |
|
|
Forward pass flow (per training step)

    input_ids, attention_mask, loss_mask → target_model.generate_dflash_data()
                    ↓
    hidden_states (from target layers [1, 9, 17, 25, 33])
                    ↓
    OnlineDFlashModel.forward():
      1. Truncate to block boundary
      2. prepare_noise_input(): anchor tokens kept, rest → MASK
      3. embed_tokens(noise_input_ids) → noise_embedding
      4. Build DFlash attention mask (flex_attention or additive)
      5. draft_model(noise_embedding, target_hidden, mask)
      6. lm_head(hidden) → logits
      7. CE loss on non-anchor positions (weighted by loss_mask × decay)
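Step 2 can be sketched in a few lines. This is a hedged illustration only: the real `prepare_noise_input()` lives in specforge/core/dflash.py, and its exact signature and MASK token id may differ.

```python
def prepare_noise_input(input_ids, block_size, mask_token_id):
    """Sketch of 1-step noising: keep each block's first (anchor) token,
    replace every other token in the block with MASK.

    input_ids: list of token ids, length a multiple of block_size.
    """
    noise = list(input_ids)
    for start in range(0, len(noise), block_size):
        # the anchor at `start` stays visible; the rest of the block is masked
        for i in range(start + 1, start + block_size):
            noise[i] = mask_token_id
    return noise
```

The CE loss in step 7 is then computed only on these masked (non-anchor) positions, weighted by loss_mask × positional decay.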
| |
The draft model's custom Qwen3DFlashAttention concatenates [context_hidden, noise_hidden] as KV, with queries coming only from the noise tokens. The attention mask enforces that block tokens see all preceding blocks' context plus bidirectional attention within their own block.
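As a rough predicate form of that masking rule (an assumption about the pattern described above, not the repo's implementation; `dflash_can_attend` is a name invented here), with `context_len=0`-style blocks starting at position 0:

```python
def dflash_can_attend(q_pos, kv_kind, k_pos, block_size):
    """Sketch of the DFlash mask rule for a noise-token query at q_pos.

    kv_kind is "context" (target hidden states of preceding positions) or
    "noise" (the draft's own block tokens). A query sees all context before
    its block, and all noise positions inside its own block (bidirectional).
    """
    block_start = (q_pos // block_size) * block_size
    if kv_kind == "context":
        return k_pos < block_start                      # preceding blocks' context
    return block_start <= k_pos < block_start + block_size  # own block, both ways
```

The flex_attention path would express the same predicate as a mask_mod; the additive path bakes it into a 4D float mask of 0 / -inf.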
| |
---

What already exists: train_dflash_lora.py

Interestingly, the repo already has a LoRA variant at scripts/train_dflash_lora.py with its own model (DFlashLoRADraftModel) and wrapper (OnlineDFlashLoRAModel). This is exactly the approach you described: Qwen3-8B + LoRA, no separate target model, 1-step diffusion training. The key differences from train_dflash.py:
|
|
| Aspect | train_dflash.py | train_dflash_lora.py |
|---|---|---|
| Draft model | Small custom DFlashDraftModel (1–10 layers) | Full Qwen3-8B + LoRA adapters |
| Target model | Separate frozen model for hidden state extraction | None; the model uses its own representations |
| Attention | Custom Qwen3DFlashAttention (Q from noise, KV from [ctx, noise]) | Standard HF attention with 4D additive DFlash mask |
| Forward | draft_model(noise_emb, target_hidden, mask) | model(noise_input_ids, 4d_mask, position_ids) → logits |
| Trainable params | All draft model params | Only LoRA (q/k/v/o_proj); base frozen |
| FSDP strategy | SHARD_GRAD_OP | FULL_SHARD |
|
|
---

Proposed Modification Plan

Since train_dflash_lora.py already implements the core idea, the plan focuses on what's missing or needs improvement to make it a proper "1-step dLLM draft model" for your research:
|
|
Phase 1: Validate and extend the existing LoRA pipeline

1. Add MLP to LoRA targets: the current config only targets q_proj, k_proj, v_proj, o_proj. For stronger 1-step diffusion capability, add gate_proj, up_proj, down_proj to lora_target_modules. This gives the model more capacity to learn the non-autoregressive distribution shift.
2. Add multi-step noise schedule support: training is currently strictly 1-step (all non-anchors → MASK). For a proper diffusion/AR fusion, add an option for a noise schedule where a fraction of block tokens is revealed (not just the anchor), controlled by a noise_ratio parameter. This would modify prepare_noise_input() in OnlineDFlashLoRAModel:
   # Instead of: all non-anchor → MASK
   # Allow: randomly keep some non-anchor tokens with probability (1 - noise_ratio)
3. Add a configurable context_len strategy: currently context_len=0 treats the whole sequence as blocks. Add a --context-ratio arg that dynamically sets context_len as a fraction of the sequence, so the model learns to condition on varying amounts of AR-decoded prefix.
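The noise-schedule change in item 2 could look like this (a hedged sketch of the proposed modification, not existing code; the signature is illustrative):

```python
import random

def prepare_noise_input(input_ids, block_size, mask_token_id,
                        noise_ratio=1.0, rng=random):
    """Proposed noise-schedule variant: anchors are always kept; each other
    block token is masked with probability noise_ratio.

    noise_ratio=1.0 recovers the current strict 1-step setup.
    """
    noise = list(input_ids)
    for start in range(0, len(noise), block_size):
        for i in range(start + 1, start + block_size):  # anchor at `start` kept
            if rng.random() < noise_ratio:
                noise[i] = mask_token_id
    return noise
```

Positions that stay visible should be excluded (or down-weighted) in the loss mask, mirroring masked-diffusion training where the loss is taken on masked positions only.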
|
|
Phase 2: Training logic improvements

4. Add a KL divergence loss: in addition to the CE loss against ground truth, add an optional KL loss against the base model's AR distribution (teacher forcing). This regularizes the LoRA model to stay close to the original Qwen3-8B distribution. Modify OnlineDFlashLoRAModel.forward():
   # Compute base model logits (no_grad, LoRA disabled) as teacher
   # KL(draft_logits || teacher_logits) on block positions
   # total_loss = ce_loss + kl_weight * kl_loss
5. Add evaluation with speculative decoding metrics: the current accuracy metric is block-wise acceptance rate. Add an eval loop that actually runs speculative decoding (draft → verify) to measure real speedup, using the LoRA model as the drafter and the base model (with LoRA disabled) as the verifier.
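The KL term in item 4 reduces to the following per-position computation (pure Python here for clarity; in the training code this would be batched torch ops over [*, vocab] logits, with the teacher obtained from the same base model under no_grad with adapters disabled):

```python
import math

def log_softmax(logits):
    # numerically stable log-softmax over a flat list of logits
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def kl_draft_vs_teacher(draft_logits, teacher_logits):
    """KL(draft || teacher) from raw logits for a single block position."""
    log_p = log_softmax(draft_logits)    # draft (LoRA) distribution
    log_q = log_softmax(teacher_logits)  # frozen base-model teacher
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))

# total_loss = ce_loss + kl_weight * mean KL over block positions
```

Whether to penalize KL(draft || teacher) or the reverse direction is a design choice; the forward direction above matches the plan's comment and discourages the draft from placing mass where the teacher has none.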
| |
Phase 3: Integration with train_dflash.py style features

6. Port random anchor sampling: train_dflash.py has --random-anchor (L147–156), which samples diverse anchor positions per sequence. This is missing from the LoRA variant and would improve training diversity. Port _sample_anchor_positions and _build_blocks_from_anchors from OnlineDFlashModel to OnlineDFlashLoRAModel.
7. Port the chunked lm_head: the LoRA variant materializes full [bsz, seq_len, vocab_size] logits. For Qwen3-8B's ~152K vocab, that is ~1.2 GB per sample at seq_len=2048 in fp32. Port _chunked_lm_loss from OnlineDFlashModel for memory efficiency.
8. Add tensor parallelism support: the LoRA script currently forces tp_size=1. For Qwen3-8B on multi-GPU setups, add TP support to shard the base model across GPUs while keeping LoRA params on each rank.
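The chunking in item 7 amounts to applying the lm_head and the cross-entropy chunk-by-chunk along the sequence, so the full [seq_len, vocab_size] logits tensor is never materialized. A sketch with numpy standing in for torch (shapes and names are illustrative, not the repo's `_chunked_lm_loss`):

```python
import numpy as np

def chunked_ce_loss(hidden, lm_head_w, targets, chunk=256):
    """Mean CE over the sequence, computing at most [chunk, vocab] logits
    at a time. hidden: [seq, d]; lm_head_w: [vocab, d]; targets: [seq]."""
    total, count = 0.0, 0
    for s in range(0, hidden.shape[0], chunk):
        h = hidden[s:s + chunk]                          # [chunk, d]
        logits = h @ lm_head_w.T                         # only [chunk, vocab]
        logits -= logits.max(axis=-1, keepdims=True)     # stability shift
        logsumexp = np.log(np.exp(logits).sum(axis=-1))
        tgt = targets[s:s + chunk]
        total += (logsumexp - logits[np.arange(len(tgt)), tgt]).sum()
        count += len(tgt)
    return total / count
```

In the torch version, each chunk's logits tensor is freed (or its loss contribution backpropagated) before the next chunk, capping activation memory at chunk × vocab instead of seq_len × vocab.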
|
|
---

Minimal starting point

If you just want to get training running immediately, the existing train_dflash_lora.py + configs/qwen3-8b-dflash-lora.json should work out of the box:
|
|
    torchrun --nproc_per_node=4 scripts/train_dflash_lora.py \
      --model-path Qwen/Qwen3-8B \
      --lora-config configs/qwen3-8b-dflash-lora.json \
      --train-data-path /path/to/data.jsonl \
      --output-dir ./outputs/qwen3-8b-dflash-lora \
      --block-size 16 \
      --gradient-checkpointing \
      --learning-rate 2e-4 \
      --batch-size 1 \
      --max-length 2048 \
      --num-epochs 3
|
|
Want me to start implementing any of these phases? I'd suggest starting with Phase 1 (noise schedule + MLP targets), since that's the most impactful for the diffusion/AR fusion research angle.