"""DFlash LoRA Draft Model: Qwen3-8B with LoRA for parallel block generation."""
from typing import Optional
import torch
import torch.nn as nn
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoConfig
class DFlashLoRADraftModel(nn.Module):
"""
Wraps a full Qwen3-8B (or any CausalLM) with PEFT LoRA adapters.
The model learns to predict all tokens in a block in parallel (1-step diffusion),
using a modified DFlash attention mask over the full sequence.
Attention mask design:
- context token i: standard causal (attends to j <= i)
- block token i (in block b): attends to all context tokens + all tokens in block b (bidirectional)
"""
def __init__(
self,
base_model: nn.Module,
block_size: int,
mask_token_id: int,
):
super().__init__()
self.model = base_model
self.block_size = block_size
self.mask_token_id = mask_token_id
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
lora_rank: int = 16,
lora_alpha: int = 32,
lora_dropout: float = 0.05,
lora_target_modules: Optional[list] = None,
block_size: int = 16,
mask_token_id: int = 151669,
torch_dtype: torch.dtype = torch.bfloat16,
device_map: str = "cuda",
trust_remote_code: bool = False,
attn_implementation: str = "sdpa",
**kwargs,
) -> "DFlashLoRADraftModel":
"""
        attn_implementation: use 'sdpa' (default), 'eager', or 'flex_attention'.
        'flex_attention' uses a PyTorch BlockMask, which adds no extra memory
        for attention masks. Do NOT use 'flash_attention_2': it does not
        support 4D additive attention masks.
"""
if lora_target_modules is None:
lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
        base_model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            device_map=device_map,
            trust_remote_code=trust_remote_code,
            attn_implementation=attn_implementation,
            **kwargs,
        )
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=lora_rank,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
target_modules=lora_target_modules,
bias="none",
)
base_model = get_peft_model(base_model, lora_config)
# Cast LoRA parameters to match base model dtype for FSDP compatibility
for param in base_model.parameters():
if param.requires_grad and param.dtype != torch_dtype:
param.data = param.data.to(torch_dtype)
base_model.print_trainable_parameters()
return cls(
base_model=base_model,
block_size=block_size,
mask_token_id=mask_token_id,
)
def gradient_checkpointing_enable(self, **kwargs):
self.model.gradient_checkpointing_enable(**kwargs)
def parameters(self, *args, **kwargs):
return self.model.parameters(*args, **kwargs)
def named_parameters(self, *args, **kwargs):
return self.model.named_parameters(*args, **kwargs)
def train(self, mode=True):
self.model.train(mode)
return self
def eval(self):
self.model.eval()
return self
def save_pretrained(self, save_dir: str, **kwargs):
"""Save only the LoRA adapter weights."""
self.model.save_pretrained(save_dir, **kwargs)
def get_lm_head(self) -> nn.Module:
"""Return a reference to the lm_head through the PEFT model hierarchy."""
base_model = self.model.get_base_model()
return base_model.lm_head
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_hidden_states: bool = False,
):
"""
Forward pass through the LoRA-adapted model.
Args:
            input_ids: [bsz, seq_len] noisy input (context tokens are real;
                each block is its anchor token followed by MASK tokens)
            attention_mask: DFlash attention mask, either [bsz, 1, seq_len, seq_len]
                (4D additive) or a BlockMask (for flex_attention).
position_ids: [bsz, seq_len]
output_hidden_states: if True, return last hidden state instead of logits.
Used for chunked cross-entropy loss to avoid materializing full logits.
Returns:
logits [bsz, seq_len, vocab_size] when output_hidden_states=False, or
hidden_states [bsz, seq_len, hidden_dim] when output_hidden_states=True.
"""
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
)
if output_hidden_states:
return outputs.hidden_states[-1]
return outputs.logits
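# Sketch of the chunked cross-entropy pattern that output_hidden_states=True
# and get_lm_head() are designed for (an assumption about the training loop;
# the function name, chunk_size, and the -100 ignore_index are illustrative).
# Projecting one chunk of hidden states at a time avoids materializing the
# full [bsz, seq_len, vocab_size] logits tensor at once.
def chunked_cross_entropy_sketch(
    model: DFlashLoRADraftModel,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    labels: torch.Tensor,
    chunk_size: int = 1024,
) -> torch.Tensor:
    hidden = model(input_ids, attention_mask, output_hidden_states=True)
    lm_head = model.get_lm_head()
    total_loss = hidden.new_zeros((), dtype=torch.float32)
    total_tokens = 0
    for h_chunk, y_chunk in zip(
        hidden.split(chunk_size, dim=1), labels.split(chunk_size, dim=1)
    ):
        # Only this chunk's logits are live at any point.
        logits = lm_head(h_chunk)  # [bsz, chunk, vocab_size]
        total_loss = total_loss + nn.functional.cross_entropy(
            logits.flatten(0, 1).float(),
            y_chunk.flatten(),
            ignore_index=-100,
            reduction="sum",
        )
        total_tokens += int((y_chunk != -100).sum())
    return total_loss / max(total_tokens, 1)
# Example wiring (hypothetical inputs):
#   draft = DFlashLoRADraftModel.from_pretrained("Qwen/Qwen3-8B", block_size=16)
#   loss = chunked_cross_entropy_sketch(draft, input_ids, mask, labels)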