| """DFlash LoRA Draft Model: Qwen3-8B with LoRA for parallel block generation.""" |
|
|
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| from peft import LoraConfig, TaskType, get_peft_model |
| from transformers import AutoModelForCausalLM, AutoConfig |
|
|
|
|
class DFlashLoRADraftModel(nn.Module):
    """
    Wraps a full Qwen3-8B (or any CausalLM) with PEFT LoRA adapters.
    The model learns to predict all tokens in a block in parallel (1-step diffusion),
    using a modified DFlash attention mask over the full sequence.

    Attention mask design (see the illustrative sketch below):
      - context token i: standard causal (attends to j <= i)
      - block token i (in block b): attends to all context tokens + all tokens in block b (bidirectional)
    """
    def __init__(
        self,
        base_model: nn.Module,
        block_size: int,
        mask_token_id: int,
    ):
        super().__init__()
        self.model = base_model
        self.block_size = block_size
        self.mask_token_id = mask_token_id

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        lora_rank: int = 16,
        lora_alpha: int = 32,
        lora_dropout: float = 0.05,
        lora_target_modules: Optional[list] = None,
        block_size: int = 16,
        mask_token_id: int = 151669,
        torch_dtype: torch.dtype = torch.bfloat16,
        device_map: str = "cuda",
        trust_remote_code: bool = False,
        attn_implementation: str = "sdpa",
        **kwargs,
    ) -> "DFlashLoRADraftModel":
| """ |
| attn_implementation: use 'sdpa' (default), 'eager', or 'flex_attention'. |
| 'flex_attention' uses torch BlockMask — zero extra memory for attention masks. |
| Do NOT use 'flash_attention_2' — it does not support 4D additive attention masks. |
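
        Example (illustrative; the checkpoint name and settings below are assumptions,
        adjust them to your setup):

            draft = DFlashLoRADraftModel.from_pretrained(
                "Qwen/Qwen3-8B",
                lora_rank=16,
                block_size=16,
                mask_token_id=151669,
                attn_implementation="sdpa",
            )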
| """ |
| if lora_target_modules is None: |
| lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"] |
|
|
        base_model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            attn_implementation=attn_implementation,
            **kwargs,
        )
        # Honor the requested device (expects a simple device string such as "cuda").
        base_model = base_model.to(device_map)

        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            target_modules=lora_target_modules,
            bias="none",
        )
        base_model = get_peft_model(base_model, lora_config)

        # PEFT may leave the freshly created LoRA matrices in float32; cast the trainable
        # parameters to the base model's dtype so adapter and frozen weights match.
        for param in base_model.parameters():
            if param.requires_grad and param.dtype != torch_dtype:
                param.data = param.data.to(torch_dtype)
        base_model.print_trainable_parameters()

        return cls(
            base_model=base_model,
            block_size=block_size,
            mask_token_id=mask_token_id,
        )

    def gradient_checkpointing_enable(self, **kwargs):
        self.model.gradient_checkpointing_enable(**kwargs)

    def parameters(self, *args, **kwargs):
        return self.model.parameters(*args, **kwargs)

    def named_parameters(self, *args, **kwargs):
        return self.model.named_parameters(*args, **kwargs)

    def train(self, mode=True):
        self.model.train(mode)
        return self

    def eval(self):
        self.model.eval()
        return self

    def save_pretrained(self, save_dir: str, **kwargs):
        """Save only the LoRA adapter weights."""
        self.model.save_pretrained(save_dir, **kwargs)

    def get_lm_head(self) -> nn.Module:
        """Return a reference to the lm_head through the PEFT model hierarchy."""
        base_model = self.model.get_base_model()
        return base_model.lm_head

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_hidden_states: bool = False,
    ):
        """
        Forward pass through the LoRA-adapted model.

        Args:
            input_ids: [bsz, seq_len] — noisy input (context real, block = anchor + MASKs)
            attention_mask: DFlash attention mask — either [bsz, 1, seq_len, seq_len]
                (4D additive) or a BlockMask (for flex_attention).
            position_ids: [bsz, seq_len]
            use_cache: forwarded unchanged to the underlying model.
            output_hidden_states: if True, return the last hidden state instead of logits.
                Used for chunked cross-entropy loss to avoid materializing the full logits
                (see the sketch at the end of this file).

        Returns:
            logits [bsz, seq_len, vocab_size] when output_hidden_states=False, or
            hidden_states [bsz, seq_len, hidden_dim] when output_hidden_states=True.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
        )
        if output_hidden_states:
            return outputs.hidden_states[-1]
        return outputs.logits
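

def _chunked_cross_entropy_sketch(
    hidden_states: torch.Tensor,
    labels: torch.Tensor,
    lm_head: nn.Module,
    chunk_size: int = 1024,
    ignore_index: int = -100,
) -> torch.Tensor:
    """
    Illustrative sketch of the chunked cross-entropy mentioned in
    `DFlashLoRADraftModel.forward` (a hypothetical helper, not part of the training code):
    run the model with output_hidden_states=True, then project the hidden states through
    `get_lm_head()` chunk by chunk so the full [bsz * seq_len, vocab_size] logits tensor
    is never allocated in one piece. The chunk_size and ignore_index defaults are
    assumptions, not values taken from this repo.
    """
    hidden = hidden_states.reshape(-1, hidden_states.size(-1))
    flat_labels = labels.reshape(-1)
    total_loss = hidden.new_zeros((), dtype=torch.float32)
    total_valid = 0
    for start in range(0, hidden.size(0), chunk_size):
        chunk_logits = lm_head(hidden[start : start + chunk_size]).float()
        chunk_labels = flat_labels[start : start + chunk_size]
        valid = int((chunk_labels != ignore_index).sum())
        if valid == 0:
            continue
        # Sum per chunk; normalize once at the end over all non-ignored labels.
        total_loss = total_loss + nn.functional.cross_entropy(
            chunk_logits, chunk_labels, ignore_index=ignore_index, reduction="sum"
        )
        total_valid += valid
    return total_loss / max(total_valid, 1)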
|