"""DFlash LoRA Draft Model: Qwen3-8B with LoRA for parallel block generation."""

from typing import Optional

import torch
import torch.nn as nn
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM


class DFlashLoRADraftModel(nn.Module):
    """
    Wraps a full Qwen3-8B (or any CausalLM) with PEFT LoRA adapters.
    The model learns to predict all tokens in a block in parallel (1-step diffusion),
    using a modified DFlash attention mask over the full sequence.

    Attention mask design:
      - context token i: standard causal (attends to j <= i)
      - block token i (in block b): attends to all context tokens + all tokens in block b (bidirectional)
    """

    def __init__(
        self,
        base_model: nn.Module,
        block_size: int,
        mask_token_id: int,
    ):
        super().__init__()
        self.model = base_model
        self.block_size = block_size
        self.mask_token_id = mask_token_id
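
    # --- Illustrative sketch (assumption, not part of the original module) ---
    # One possible realization of the mask design described in the class
    # docstring, as a 4D additive mask for 'sdpa'/'eager'. Assumes a single
    # draft block of `block_size` tokens appended after `ctx_len` context
    # tokens; the helper name and layout are illustrative only.
    @staticmethod
    def build_dflash_attention_mask(
        ctx_len: int,
        block_size: int,
        dtype: torch.dtype = torch.bfloat16,
        device: str = "cuda",
    ) -> torch.Tensor:
        seq_len = ctx_len + block_size
        # Standard causal base: position i attends to j <= i.
        allowed = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).tril_()
        # Block tokens additionally attend to every token in their own block
        # (bidirectional); the causal base already lets them see all context.
        allowed[ctx_len:, ctx_len:] = True
        # Boolean -> additive: 0.0 where allowed, large negative elsewhere.
        mask = torch.zeros(seq_len, seq_len, dtype=dtype, device=device)
        mask.masked_fill_(~allowed, torch.finfo(dtype).min)
        return mask[None, None]  # [1, 1, seq_len, seq_len], broadcast over bsz/heads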

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        lora_rank: int = 16,
        lora_alpha: int = 32,
        lora_dropout: float = 0.05,
        lora_target_modules: Optional[list] = None,
        block_size: int = 16,
        mask_token_id: int = 151669,
        torch_dtype: torch.dtype = torch.bfloat16,
        device_map: str = "cuda",
        trust_remote_code: bool = False,
        attn_implementation: str = "sdpa",
        **kwargs,
    ) -> "DFlashLoRADraftModel":
        """
        attn_implementation: use 'sdpa' (default), 'eager', or 'flex_attention'.
        'flex_attention' uses torch BlockMask — zero extra memory for attention masks.
        Do NOT use 'flash_attention_2' — it does not support 4D additive attention masks.
        """
        if lora_target_modules is None:
            lora_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]

        base_model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            device_map=device_map,
            trust_remote_code=trust_remote_code,
            attn_implementation=attn_implementation,
            **kwargs,
        )

        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            target_modules=lora_target_modules,
            bias="none",
        )
        base_model = get_peft_model(base_model, lora_config)
        # Cast LoRA parameters to match base model dtype for FSDP compatibility
        for param in base_model.parameters():
            if param.requires_grad and param.dtype != torch_dtype:
                param.data = param.data.to(torch_dtype)
        base_model.print_trainable_parameters()

        return cls(
            base_model=base_model,
            block_size=block_size,
            mask_token_id=mask_token_id,
        )
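
    # --- Illustrative sketch (assumption, not part of the original module) ---
    # The from_pretrained docstring notes that 'flex_attention' consumes a
    # torch BlockMask instead of a 4D additive tensor. One way to express the
    # same mask semantics, again assuming a single draft block appended after
    # `ctx_len` context tokens (requires PyTorch >= 2.5):
    @staticmethod
    def build_flex_block_mask(ctx_len: int, seq_len: int, device: str = "cuda"):
        from torch.nn.attention.flex_attention import create_block_mask

        def mask_mod(b, h, q_idx, kv_idx):
            causal = kv_idx <= q_idx                             # context: causal
            in_block = (q_idx >= ctx_len) & (kv_idx >= ctx_len)  # block: bidirectional
            return causal | in_block

        # B=None / H=None broadcast the mask over batch and heads.
        return create_block_mask(
            mask_mod, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=device
        )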

    def gradient_checkpointing_enable(self, **kwargs):
        self.model.gradient_checkpointing_enable(**kwargs)

    def parameters(self, *args, **kwargs):
        # Delegate directly to the wrapped PEFT model so parameter names match
        # its own keys, without this wrapper's extra "model." prefix.
        return self.model.parameters(*args, **kwargs)

    def named_parameters(self, *args, **kwargs):
        return self.model.named_parameters(*args, **kwargs)

    def train(self, mode=True):
        self.model.train(mode)
        return self

    def eval(self):
        self.model.eval()
        return self

    def save_pretrained(self, save_dir: str, **kwargs):
        """Save only the LoRA adapter weights."""
        self.model.save_pretrained(save_dir, **kwargs)

    def get_lm_head(self) -> nn.Module:
        """Return a reference to the lm_head through the PEFT model hierarchy."""
        base_model = self.model.get_base_model()
        return base_model.lm_head
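
    # --- Illustrative sketch (assumption, not part of the original module) ---
    # The forward docstring below describes the block input as "anchor + MASKs".
    # One possible way to assemble it: a real anchor token followed by
    # (block_size - 1) mask tokens appended to the context.
    def make_block_input(
        self, context_ids: torch.Tensor, anchor_ids: torch.Tensor
    ) -> torch.Tensor:
        bsz = context_ids.size(0)
        mask_fill = torch.full(
            (bsz, self.block_size - 1),
            self.mask_token_id,
            dtype=context_ids.dtype,
            device=context_ids.device,
        )
        # [bsz, ctx_len + block_size]: context, then anchor, then MASK tokens.
        return torch.cat([context_ids, anchor_ids.unsqueeze(1), mask_fill], dim=1)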

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_hidden_states: bool = False,
    ):
        """
        Forward pass through the LoRA-adapted model.

        Args:
            input_ids: [bsz, seq_len] — noisy input (context real, block = anchor + MASKs)
            attention_mask: DFlash attention mask — either [bsz, 1, seq_len, seq_len] (4D
                additive) or a BlockMask (for flex_attention).
            position_ids: [bsz, seq_len]
            output_hidden_states: if True, return last hidden state instead of logits.
                Used for chunked cross-entropy loss to avoid materializing full logits.

        Returns:
            logits [bsz, seq_len, vocab_size] when output_hidden_states=False, or
            hidden_states [bsz, seq_len, hidden_dim] when output_hidden_states=True.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
        )
        if output_hidden_states:
            return outputs.hidden_states[-1]
        return outputs.logits
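

# --- Illustrative usage (assumption, not part of the original module) --------
# Sketch of the chunked cross-entropy pattern the forward docstring refers to:
# run forward with output_hidden_states=True, then project hidden states
# through lm_head one chunk at a time so the full [bsz, seq_len, vocab_size]
# logits tensor is never materialized. The function name, chunk size, and the
# ignore_index convention are assumptions.
def chunked_cross_entropy(
    hidden_states: torch.Tensor,  # [bsz, seq_len, hidden_dim]
    targets: torch.Tensor,        # [bsz, seq_len], -100 marks ignored positions
    lm_head: nn.Module,
    chunk_size: int = 1024,
) -> torch.Tensor:
    import torch.nn.functional as F

    total = hidden_states.new_zeros((), dtype=torch.float32)
    n_tokens = 0
    for h, t in zip(hidden_states.split(chunk_size, dim=1),
                    targets.split(chunk_size, dim=1)):
        logits = lm_head(h)  # [bsz, chunk, vocab]: only one chunk alive at a time
        total = total + F.cross_entropy(
            logits.reshape(-1, logits.size(-1)).float(),
            t.reshape(-1),
            ignore_index=-100,
            reduction="sum",
        )
        n_tokens += int((t != -100).sum())
    return total / max(n_tokens, 1)


# Hypothetical call pattern:
#   hidden = model(input_ids, attention_mask, output_hidden_states=True)
#   loss = chunked_cross_entropy(hidden, targets, model.get_lm_head())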