"""DFlash LoRA Draft Model: Qwen3-8B with LoRA for parallel block generation."""

from typing import Optional

import torch
import torch.nn as nn
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoConfig


class DFlashLoRADraftModel(nn.Module):
"""
Wraps a full Qwen3-8B (or any CausalLM) with PEFT LoRA adapters.
The model learns to predict all tokens in a block in parallel (1-step diffusion),
using a modified DFlash attention mask over the full sequence.
Attention mask design:
- context token i: standard causal (attends to j <= i)
- block token i (in block b): attends to all context tokens + all tokens in block b (bidirectional)
"""

    def __init__(
        self,
        base_model: nn.Module,
        block_size: int,
        mask_token_id: int,
    ):
        super().__init__()
        self.model = base_model
        self.block_size = block_size
        self.mask_token_id = mask_token_id

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        lora_rank: int = 16,
        lora_alpha: int = 32,
        lora_dropout: float = 0.05,
        lora_target_modules: Optional[list] = None,
        block_size: int = 16,
        mask_token_id: int = 151669,
        torch_dtype: torch.dtype = torch.bfloat16,
        device_map: str = "cuda",
        trust_remote_code: bool = False,
        attn_implementation: str = "sdpa",
        **kwargs,
    ) -> "DFlashLoRADraftModel":
        """
        attn_implementation: use 'sdpa' (default), 'eager', or 'flex_attention'.
        'flex_attention' uses torch BlockMask — zero extra memory for attention masks.
        Do NOT use 'flash_attention_2' — it does not support 4D additive attention masks.
        """
        if lora_target_modules is None:
            lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]

        base_model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            attn_implementation=attn_implementation,
            **kwargs,
        )
        # Move the base model to the requested device before attaching LoRA adapters.
        base_model = base_model.to(device_map)

        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            target_modules=lora_target_modules,
            bias="none",
        )
        base_model = get_peft_model(base_model, lora_config)

        # Cast LoRA parameters to match the base model dtype for FSDP compatibility.
        for param in base_model.parameters():
            if param.requires_grad and param.dtype != torch_dtype:
                param.data = param.data.to(torch_dtype)

        base_model.print_trainable_parameters()

        return cls(
            base_model=base_model,
            block_size=block_size,
            mask_token_id=mask_token_id,
        )

    def gradient_checkpointing_enable(self, **kwargs):
        self.model.gradient_checkpointing_enable(**kwargs)

    def parameters(self, *args, **kwargs):
        return self.model.parameters(*args, **kwargs)

    def named_parameters(self, *args, **kwargs):
        return self.model.named_parameters(*args, **kwargs)

    def train(self, mode=True):
        self.model.train(mode)
        return self

    def eval(self):
        self.model.eval()
        return self

    def save_pretrained(self, save_dir: str, **kwargs):
        """Save only the LoRA adapter weights."""
        self.model.save_pretrained(save_dir, **kwargs)

    def get_lm_head(self) -> nn.Module:
        """Return a reference to the lm_head through the PEFT model hierarchy."""
        base_model = self.model.get_base_model()
        return base_model.lm_head

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_hidden_states: bool = False,
    ):
        """
        Forward pass through the LoRA-adapted model.

        Args:
            input_ids: [bsz, seq_len] — noisy input (context real, block = anchor + MASKs)
            attention_mask: DFlash attention mask — either [bsz, 1, seq_len, seq_len]
                (4D additive) or a BlockMask (for flex_attention).
            position_ids: [bsz, seq_len]
            output_hidden_states: if True, return last hidden state instead of logits.
                Used for chunked cross-entropy loss to avoid materializing full logits.

        Returns:
            logits [bsz, seq_len, vocab_size] when output_hidden_states=False, or
            hidden_states [bsz, seq_len, hidden_dim] when output_hidden_states=True.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
        )
        if output_hidden_states:
            return outputs.hidden_states[-1]
        return outputs.logits
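

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of the original module).
# It assumes a Hugging Face checkpoint name ("Qwen/Qwen3-8B"), builds the dense
# 4D additive DFlash mask described in the class docstring for a single context
# followed by one masked block, and shows the chunked-loss path via
# output_hidden_states=True together with get_lm_head().
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = DFlashLoRADraftModel.from_pretrained(
        "Qwen/Qwen3-8B",  # assumed checkpoint; any compatible CausalLM should work
        block_size=16,
    )

    ctx_len = 8
    seq_len = ctx_len + model.block_size

    # Context tokens are real ids; the block is an anchor token followed by MASKs.
    input_ids = torch.randint(0, 1000, (1, seq_len), device="cuda")
    input_ids[0, ctx_len + 1:] = model.mask_token_id

    # Boolean "allowed to attend" matrix: causal over the context, while block
    # tokens see the full context plus every token in their own block.
    allowed = torch.zeros(seq_len, seq_len, dtype=torch.bool, device="cuda")
    allowed[:ctx_len, :ctx_len] = torch.tril(
        torch.ones(ctx_len, ctx_len, dtype=torch.bool, device="cuda")
    )
    allowed[ctx_len:, :] = True

    # Convert to a 4D additive mask: 0 where attention is allowed, -inf elsewhere.
    attention_mask = torch.full((seq_len, seq_len), float("-inf"), device="cuda")
    attention_mask[allowed] = 0.0
    attention_mask = attention_mask.to(torch.bfloat16)[None, None]  # [1, 1, seq, seq]

    with torch.no_grad():
        # Standard path: full logits.
        logits = model(input_ids=input_ids, attention_mask=attention_mask)
        print(logits.shape)  # [1, seq_len, vocab_size]

        # Chunked-loss path: take hidden states and project only the block
        # positions through lm_head instead of materializing all logits.
        hidden = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        lm_head = model.get_lm_head()
        block_logits = lm_head(hidden[:, ctx_len:, :])
        print(block_logits.shape)  # [1, block_size, vocab_size]

    # Adapters saved with save_pretrained() can later be reloaded onto a fresh
    # base model with peft.PeftModel.from_pretrained(base_model, save_dir).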