Now I have a complete picture. Here's the full analysis and proposed plan.
---

Current Logic of train_dflash.py

This script trains a small, standalone DFlash draft model (1–10 transformer layers) that learns to predict token blocks in parallel, guided by hidden states extracted from a frozen target model.

Architecture (two separate models)
1. Target model (frozen, e.g. Qwen3-8B/32B) – runs a full forward pass to produce hidden states at selected layers.
2. Draft model (DFlashDraftModel, ~1–10 layers) – a lightweight Qwen3-based decoder that takes noise embeddings + target hidden states and predicts the block tokens.
3. Target embed_tokens + lm_head – loaded separately via TargetEmbeddingsAndHead to avoid duplicating the full target model in memory.
Key locations

| Component                | File                               | Lines                                                 |
|--------------------------|------------------------------------|-------------------------------------------------------|
| Model init               | scripts/train_dflash.py            | build_models() L254–311                               |
| Target hidden extraction | scripts/train_dflash.py            | L644–647 (target_model.generate_dflash_data)          |
| Forward pass             | specforge/core/dflash.py           | OnlineDFlashModel.forward() L243–332                  |
| Loss calculation         | specforge/core/dflash.py           | _full_lm_loss() L382–417, _chunked_lm_loss() L419–478 |
| Loss mask                | specforge/core/dflash.py           | create_dflash_loss_mask() L481–509                    |
| Draft model architecture | specforge/modeling/draft/dflash.py | DFlashDraftModel L212–266                             |
| DFlash attention         | specforge/modeling/draft/dflash.py | Qwen3DFlashAttention L42–134                          |
Forward pass flow (per training step)

input_ids, attention_mask, loss_mask ← target_model.generate_dflash_data()
                   ↓
      hidden_states (from target layers [1,9,17,25,33])
                   ↓
      OnlineDFlashModel.forward():
        1. Truncate to block boundary
        2. prepare_noise_input(): anchor tokens kept, rest → MASK
        3. embed_tokens(noise_input_ids) → noise_embedding
        4. Build DFlash attention mask (flex_attention or additive)
        5. draft_model(noise_embedding, target_hidden, mask)
        6. lm_head(hidden) → logits
        7. CE loss on non-anchor positions (weighted by loss_mask × decay)
The draft model's custom Qwen3DFlashAttention concatenates [context_hidden, noise_hidden] as KV, with queries only from noise tokens. The attention mask enforces: block tokens see all preceding blocks' context + bidirectional attention within their own block.
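As a rough illustration of that visibility rule, here is a minimal additive-mask sketch, not the repo's implementation; it assumes context positions are assigned to blocks by position // block_size:

import torch

def build_dflash_attention_mask(ctx_len, num_blocks, block_size, dtype=torch.float32):
    # Queries come from the noise tokens only; keys are [context_hidden, noise_hidden].
    # A noise token in block b may attend to context positions of strictly earlier
    # blocks and to every noise token inside its own block (bidirectional).
    noise_len = num_blocks * block_size
    q_block = torch.arange(noise_len) // block_size            # block id per query

    ctx_block = torch.arange(ctx_len) // block_size            # assumed block id per context key
    ctx_visible = q_block.unsqueeze(1) > ctx_block.unsqueeze(0)

    noise_block = torch.arange(noise_len) // block_size
    noise_visible = q_block.unsqueeze(1) == noise_block.unsqueeze(0)

    allowed = torch.cat([ctx_visible, noise_visible], dim=1)   # [noise_len, ctx_len + noise_len]
    mask = torch.full(allowed.shape, torch.finfo(dtype).min, dtype=dtype)
    mask[allowed] = 0.0
    return mask  # broadcastable to [bsz, num_heads, noise_len, ctx_len + noise_len]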
---

What already exists: train_dflash_lora.py

Interestingly, the repo already has a LoRA variant at scripts/train_dflash_lora.py with its own model (DFlashLoRADraftModel) and wrapper (OnlineDFlashLoRAModel). This is exactly the approach you described – Qwen3-8B + LoRA, no separate target model, 1-step diffusion training. The key differences from train_dflash.py:
| Aspect           | train_dflash.py                                                  | train_dflash_lora.py                                    |
|------------------|------------------------------------------------------------------|---------------------------------------------------------|
| Draft model      | Small custom DFlashDraftModel (1–10 layers)                      | Full Qwen3-8B + LoRA adapters                           |
| Target model     | Separate frozen model for hidden state extraction                | None – model uses its own representations              |
| Attention        | Custom Qwen3DFlashAttention (Q from noise, KV from [ctx, noise]) | Standard HF attention with 4D additive DFlash mask      |
| Forward          | draft_model(noise_emb, target_hidden, mask)                      | model(noise_input_ids, 4d_mask, position_ids) → logits  |
| Trainable params | All draft model params                                           | Only LoRA (q/k/v/o_proj), base frozen                   |
| FSDP strategy    | SHARD_GRAD_OP                                                    | FULL_SHARD                                              |
---

Proposed Modification Plan

Since train_dflash_lora.py already implements the core idea, the plan focuses on what's missing or needs improvement to make it a proper "1-step dLLM draft model" for your research:

Phase 1: Validate and extend the existing LoRA pipeline
1. Add MLP modules to the LoRA targets – the current config only targets q_proj, k_proj, v_proj, o_proj. For stronger 1-step diffusion capability, add gate_proj, up_proj, down_proj to lora_target_modules. This gives the model more capacity to learn the non-autoregressive distribution shift (see the config sketch after this phase's items).
2. Add multi-step noise schedule support – currently training is strictly 1-step (all non-anchors → MASK). For a proper diffusion/AR fusion, add an option for a noise schedule in which a fraction of the block tokens is revealed (not just the anchor), controlled by a noise_ratio parameter. This would modify prepare_noise_input() in OnlineDFlashLoRAModel (see the sketch after this phase's items):
# Instead of: all non-anchor → MASK
# Allow: randomly keep each non-anchor token with probability (1 - noise_ratio)
3. Add a configurable context_len strategy – currently context_len=0 treats the whole sequence as blocks. Add a --context-ratio arg that dynamically sets context_len as a fraction of the sequence length, so the model learns to condition on varying amounts of AR-decoded prefix.
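For item 1, a minimal sketch of the extended adapter configuration, written as a PEFT LoraConfig for illustration; the actual configs/qwen3-8b-dflash-lora.json schema and the rank/alpha values are assumptions, not taken from the repo:

from peft import LoraConfig

# Hypothetical values; only the target_modules list reflects the proposal above.
lora_config = LoraConfig(
    r=64,
    lora_alpha=128,
    lora_dropout=0.05,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",   # current attention-only targets
        "gate_proj", "up_proj", "down_proj",      # proposed MLP additions
    ],
    task_type="CAUSAL_LM",
)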
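For item 2, a minimal sketch of a noise-ratio variant of prepare_noise_input(); the signature and the "anchor = first token of each block" convention are assumptions about the repo's interface:

import torch

def prepare_noise_input(input_ids, block_size, mask_token_id, noise_ratio=1.0):
    # noise_ratio=1.0 reproduces the current 1-step behaviour (mask every
    # non-anchor token); lower values keep each non-anchor token with
    # probability (1 - noise_ratio). Assumes the anchor is the first token
    # of each block.
    noise_input_ids = input_ids.clone()
    positions = torch.arange(input_ids.shape[-1], device=input_ids.device)
    is_anchor = positions % block_size == 0
    drop = torch.rand(input_ids.shape, device=input_ids.device) < noise_ratio
    noise_input_ids[drop & ~is_anchor] = mask_token_id
    return noise_input_ids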
Phase 2: Training logic improvements
4. Add a KL divergence loss – in addition to the CE loss against ground truth, add an optional KL loss against the base model's AR distribution (teacher forcing). This regularizes the LoRA model to stay close to the original Qwen3-8B distribution. Modify OnlineDFlashLoRAModel.forward() (see the sketch after this phase's items):
# Compute base model logits (no_grad, no LoRA) as teacher
# KL(draft_logits || teacher_logits) on block positions
# total_loss = ce_loss + kl_weight * kl_loss
5. Add evaluation with speculative decoding metrics – the current accuracy metric is block-wise acceptance rate. Add an eval loop that actually runs speculative decoding (draft → verify) to measure real speedup, using the LoRA model as the drafter and the base model (with LoRA disabled) as the verifier (see the acceptance-length sketch after this phase's items).
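For item 4, a minimal sketch of the KL term, assuming draft_logits and teacher_logits have already been computed for the same positions and that block_mask marks the non-anchor block positions; the names and the temperature argument are illustrative:

import torch
import torch.nn.functional as F

def kl_to_teacher(draft_logits, teacher_logits, block_mask, temperature=1.0):
    # KL(draft || teacher) averaged over block positions. teacher_logits would be
    # produced under torch.no_grad() with the LoRA adapter disabled (e.g. via
    # peft's disable_adapter() context manager).
    draft_logp = F.log_softmax(draft_logits / temperature, dim=-1)
    teacher_logp = F.log_softmax(teacher_logits / temperature, dim=-1)
    kl = (draft_logp.exp() * (draft_logp - teacher_logp)).sum(dim=-1)  # [bsz, seq_len]
    return (kl * block_mask).sum() / block_mask.sum().clamp(min=1)

# total_loss = ce_loss + kl_weight * kl_to_teacher(draft_logits, teacher_logits, block_mask)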
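For item 5, a sketch of the greedy acceptance rule the eval loop could use to score one drafted block; the surrounding draft/verify loop and batching are omitted:

import torch

@torch.no_grad()
def accepted_prefix_length(draft_tokens, verifier_logits):
    # draft_tokens:    [block_size] token ids proposed by the LoRA drafter
    # verifier_logits: [block_size, vocab] base-model (LoRA disabled) logits at
    #                  the same positions; tokens are accepted until the first
    #                  mismatch with the verifier's argmax prediction.
    verifier_tokens = verifier_logits.argmax(dim=-1)
    mismatch = draft_tokens != verifier_tokens
    if not mismatch.any():
        return draft_tokens.numel()
    return int(mismatch.int().argmax().item())

Averaging this over many blocks gives an expected accepted length per draft step, which translates directly into a speculative decoding speedup estimate.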
Phase 3: Integration with train_dflash.py style features
6. Port random anchor sampling – train_dflash.py has --random-anchor (L147–156), which samples diverse anchor positions per sequence. This is missing from the LoRA variant and would improve training diversity. Port _sample_anchor_positions and _build_blocks_from_anchors from OnlineDFlashModel to OnlineDFlashLoRAModel.
7. Port the chunked lm_head loss – the LoRA variant materializes the full [bsz, seq_len, vocab_size] logits tensor. For Qwen3-8B's ~152K vocab, that is roughly 1.2 GB per sample at seq_len=2048 in fp32. Port _chunked_lm_loss from OnlineDFlashModel for memory efficiency (see the sketch after this phase's items).
8. Add tensor parallelism support – the LoRA script currently forces tp_size=1. For Qwen3-8B on multi-GPU setups, add TP support to shard the base model across GPUs while keeping the LoRA parameters on each rank.
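For item 7, a minimal sketch of a chunked lm_head + CE loss (not the repo's _chunked_lm_loss); it uses activation checkpointing per chunk so the full-vocab logits are recomputed in the backward pass rather than kept alive:

import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

def chunked_lm_loss(hidden, lm_head, labels, loss_weight, chunk_size=1024):
    # hidden:      [bsz, seq_len, hidden_size]  draft model outputs
    # labels:      [bsz, seq_len]               target token ids
    # loss_weight: [bsz, seq_len]               loss mask × decay weights
    bsz, seq_len, _ = hidden.shape

    def chunk_loss(h, y, w):
        logits = lm_head(h)                                    # [bsz, chunk, vocab]
        ce = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            y.reshape(-1),
            reduction="none",
        ).view(h.shape[0], -1)
        return (ce * w).sum()

    total = hidden.new_zeros(())
    for start in range(0, seq_len, chunk_size):
        end = min(start + chunk_size, seq_len)
        total = total + checkpoint(
            chunk_loss,
            hidden[:, start:end], labels[:, start:end], loss_weight[:, start:end],
            use_reentrant=False,
        )
    return total / loss_weight.sum().clamp(min=1e-6)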
Minimal starting point

If you want to just get training running immediately, the existing train_dflash_lora.py + configs/qwen3-8b-dflash-lora.json should work out of the box:
torchrun --nproc_per_node=4 scripts/train_dflash_lora.py \
  --model-path Qwen/Qwen3-8B \
  --lora-config configs/qwen3-8b-dflash-lora.json \
  --train-data-path /path/to/data.jsonl \
  --output-dir ./outputs/qwen3-8b-dflash-lora \
  --block-size 16 \
  --gradient-checkpointing \
  --learning-rate 2e-4 \
  --batch-size 1 \
  --max-length 2048 \
  --num-epochs 3

Want me to start implementing any of these phases? I'd suggest starting with Phase 1 (noise schedule + MLP targets), since that's the most impactful for the diffusion/AR fusion research angle.