"""DFlash LoRA Draft Model: Qwen3-8B with LoRA for parallel block generation.""" from typing import Optional import torch import torch.nn as nn from peft import LoraConfig, TaskType, get_peft_model from transformers import AutoModelForCausalLM, AutoConfig class DFlashLoRADraftModel(nn.Module): """ Wraps a full Qwen3-8B (or any CausalLM) with PEFT LoRA adapters. The model learns to predict all tokens in a block in parallel (1-step diffusion), using a modified DFlash attention mask over the full sequence. Attention mask design: - context token i: standard causal (attends to j <= i) - block token i (in block b): attends to all context tokens + all tokens in block b (bidirectional) """ def __init__( self, base_model: nn.Module, block_size: int, mask_token_id: int, ): super().__init__() self.model = base_model self.block_size = block_size self.mask_token_id = mask_token_id @classmethod def from_pretrained( cls, pretrained_model_name_or_path: str, lora_rank: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05, lora_target_modules: Optional[list] = None, block_size: int = 16, mask_token_id: int = 151669, torch_dtype: torch.dtype = torch.bfloat16, device_map: str = "cuda", trust_remote_code: bool = False, attn_implementation: str = "sdpa", **kwargs, ) -> "DFlashLoRADraftModel": """ attn_implementation: use 'sdpa' (default), 'eager', or 'flex_attention'. 'flex_attention' uses torch BlockMask — zero extra memory for attention masks. Do NOT use 'flash_attention_2' — it does not support 4D additive attention masks. """ if lora_target_modules is None: lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"] base_model = AutoModelForCausalLM.from_pretrained( pretrained_model_name_or_path, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation, **kwargs, ) base_model = base_model.cuda() lora_config = LoraConfig( task_type=TaskType.CAUSAL_LM, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, target_modules=lora_target_modules, bias="none", ) base_model = get_peft_model(base_model, lora_config) # Cast LoRA parameters to match base model dtype for FSDP compatibility for param in base_model.parameters(): if param.requires_grad and param.dtype != torch_dtype: param.data = param.data.to(torch_dtype) base_model.print_trainable_parameters() return cls( base_model=base_model, block_size=block_size, mask_token_id=mask_token_id, ) def gradient_checkpointing_enable(self, **kwargs): self.model.gradient_checkpointing_enable(**kwargs) def parameters(self, *args, **kwargs): return self.model.parameters(*args, **kwargs) def named_parameters(self, *args, **kwargs): return self.model.named_parameters(*args, **kwargs) def train(self, mode=True): self.model.train(mode) return self def eval(self): self.model.eval() return self def save_pretrained(self, save_dir: str, **kwargs): """Save only the LoRA adapter weights.""" self.model.save_pretrained(save_dir, **kwargs) def get_lm_head(self) -> nn.Module: """Return a reference to the lm_head through the PEFT model hierarchy.""" base_model = self.model.get_base_model() return base_model.lm_head def forward( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, position_ids: Optional[torch.Tensor] = None, use_cache: bool = False, output_hidden_states: bool = False, ): """ Forward pass through the LoRA-adapted model. 
        """
        Forward pass through the LoRA-adapted model.

        Args:
            input_ids: [bsz, seq_len] noisy input (context tokens are real;
                the block is an anchor token followed by MASK tokens).
            attention_mask: DFlash attention mask, either [bsz, 1, seq_len, seq_len]
                (4D additive) or a BlockMask (for flex_attention).
            position_ids: [bsz, seq_len]
            output_hidden_states: if True, return the last hidden state instead
                of logits. Used for chunked cross-entropy loss to avoid
                materializing the full logits tensor (see the sketch at the end
                of this module).

        Returns:
            logits [bsz, seq_len, vocab_size] when output_hidden_states=False, or
            hidden_states [bsz, seq_len, hidden_dim] when output_hidden_states=True.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
        )
        if output_hidden_states:
            return outputs.hidden_states[-1]
        return outputs.logits
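

# ---------------------------------------------------------------------------
# Illustrative sketch, not used by the class above: one way to build the 4D
# additive DFlash attention mask described in the class docstring, assuming a
# single block of `block_size` positions appended after the context. The
# helper name is hypothetical; the real training pipeline may construct the
# mask differently (e.g. multiple blocks or padding handling).
# ---------------------------------------------------------------------------
def _sketch_build_dflash_attention_mask(
    seq_len: int,
    block_size: int,
    dtype: torch.dtype = torch.bfloat16,
    device: str = "cpu",
) -> torch.Tensor:
    """Return a [1, 1, seq_len, seq_len] additive mask (0 = visible, dtype min = hidden)."""
    ctx_len = seq_len - block_size
    allowed = torch.zeros(seq_len, seq_len, dtype=torch.bool, device=device)
    # Context rows: standard causal attention (token i attends to j <= i).
    allowed[:ctx_len, :ctx_len] = torch.tril(
        torch.ones(ctx_len, ctx_len, dtype=torch.bool, device=device)
    )
    # Block rows: attend to every context token and, bidirectionally, to every
    # token inside the block (here the block spans the last block_size positions).
    allowed[ctx_len:, :] = True
    additive = torch.zeros(seq_len, seq_len, dtype=dtype, device=device)
    additive = additive.masked_fill(~allowed, torch.finfo(dtype).min)
    return additive[None, None]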
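

# ---------------------------------------------------------------------------
# Illustrative sketch of the flex_attention path mentioned in from_pretrained:
# the same visibility rule expressed as a mask_mod for
# torch.nn.attention.flex_attention.create_block_mask (requires a PyTorch
# version that ships flex_attention). A single trailing block is assumed; this
# sketches the idea and is not necessarily the project's actual mask builder.
# ---------------------------------------------------------------------------
def _sketch_build_dflash_block_mask(seq_len: int, block_size: int, device: str = "cuda"):
    from torch.nn.attention.flex_attention import create_block_mask

    ctx_len = seq_len - block_size

    def mask_mod(b, h, q_idx, kv_idx):
        # Context queries: causal. Block queries: see all context + the whole block.
        context_causal = (q_idx < ctx_len) & (kv_idx <= q_idx)
        block_query = q_idx >= ctx_len
        return context_causal | block_query

    return create_block_mask(
        mask_mod, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=device
    )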
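

# ---------------------------------------------------------------------------
# Illustrative usage sketch (hypothetical helper, not part of this module's
# API): combining output_hidden_states=True with get_lm_head() for a chunked
# cross-entropy loss, so the full [bsz, seq_len, vocab_size] logits tensor is
# never materialized. The chunk size and the ignore_index=-100 label
# convention are assumptions.
# ---------------------------------------------------------------------------
def _sketch_chunked_cross_entropy(
    model: "DFlashLoRADraftModel",
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    labels: torch.Tensor,
    chunk_size: int = 1024,
) -> torch.Tensor:
    # [bsz, seq_len, hidden_dim] instead of full logits.
    hidden = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
    )
    lm_head = model.get_lm_head()
    total_loss = hidden.new_zeros((), dtype=torch.float32)
    total_tokens = 0
    for start in range(0, hidden.shape[1], chunk_size):
        h = hidden[:, start : start + chunk_size]
        y = labels[:, start : start + chunk_size]
        logits = lm_head(h)  # only [bsz, chunk, vocab_size] is live at once
        total_loss = total_loss + nn.functional.cross_entropy(
            logits.float().reshape(-1, logits.shape[-1]),
            y.reshape(-1),
            ignore_index=-100,
            reduction="sum",
        )
        total_tokens += int((y != -100).sum())
    return total_loss / max(total_tokens, 1)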