File size: 7,270 Bytes
d522318 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 | # coding=utf-8
"""DFlash Training Wrapper."""
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from specforge.modeling.draft.dflash import DFlashDraftModel
try:
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
FLEX_ATTENTION_AVAILABLE = True
except ImportError:
FLEX_ATTENTION_AVAILABLE = False
BlockMask = None
create_block_mask = None
class OnlineDFlashModel(nn.Module):
    """DFlash online training wrapper with block-wise CE loss.

    Combines a DFlash draft model with two modules taken from the target
    model (``target_embed_tokens`` and ``target_lm_head``). The sequence is
    split into blocks of ``block_size`` tokens. In the draft model's "noise"
    input, only the first token of each block keeps its real id; every other
    position is replaced by ``mask_token_id``. The draft model then processes
    all blocks in one parallel forward pass. It is trained with cross-entropy
    to recover the original tokens at the masked positions, and block 0 is
    excluded from the loss.
    """

    def __init__(
        self,
        draft_model: DFlashDraftModel,
        target_lm_head: nn.Module,
        target_embed_tokens: nn.Module,
        mask_token_id: int,
        block_size: int = 16,
        attention_backend: str = "flex_attention",
    ):
        """
        Args:
            draft_model: The DFlash draft network being trained.
            target_lm_head: Module applied to the draft model's output to
                produce vocabulary logits.
            target_embed_tokens: Module used to embed the noise input ids.
            mask_token_id: Token id written into every non-leading position
                of each block.
            block_size: Number of tokens per block.
            attention_backend: ``"flex_attention"`` selects a flex-attention
                ``BlockMask`` when flex attention can be imported. Any other
                value, or a missing flex-attention install, selects a dense
                float mask instead.
        """
        super().__init__()
        self.draft_model = draft_model
        self.lm_head = target_lm_head
        self.embed_tokens = target_embed_tokens
        self.block_size = block_size
        self.mask_token_id = mask_token_id
        self.attention_backend = attention_backend

        # Cache for the flex-attention BlockMask. It is rebuilt whenever any
        # input it was built from (batch size, q/kv lengths, device) changes.
        self._cached_block_mask: Optional[BlockMask] = None
        self._cached_seq_len: Optional[int] = None
        self._cached_bsz: Optional[int] = None
        self._cached_kv_len: Optional[int] = None
        self._cached_device: Optional[torch.device] = None

    def prepare_noise_input(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Prepare noise input: first token of each block is real, rest are MASK.

        Args:
            input_ids: Token ids indexed as ``[batch, position]``.

        Returns:
            A tensor shaped like ``input_ids``. Positions whose index is a
            multiple of ``block_size`` keep their original id; all other
            positions hold ``mask_token_id``.
        """
        seq_len = input_ids.shape[1]
        device = input_ids.device
        positions = torch.arange(seq_len, device=device)
        is_block_start = (positions % self.block_size) == 0
        noise_input_ids = torch.full_like(input_ids, self.mask_token_id)
        noise_input_ids[:, is_block_start] = input_ids[:, is_block_start]
        return noise_input_ids

    def _get_or_create_block_mask(
        self, bsz: int, q_len: int, kv_len: int, device: torch.device
    ) -> "BlockMask":
        """Get cached BlockMask or create a new one.

        The key/value axis is laid out as ``[context (q_len) | noise]``:

        - a query may see a context key whose block index is strictly smaller
          than its own, i.e. earlier blocks only;
        - a query may see every noise key in its own block, in both directions.

        The mask is reused while ``bsz``, ``q_len``, ``kv_len`` and ``device``
        are all unchanged.
        """
        if (
            self._cached_block_mask is not None
            and self._cached_seq_len == q_len
            and self._cached_bsz == bsz
            and self._cached_kv_len == kv_len
            and self._cached_device == device
        ):
            return self._cached_block_mask

        block_size = self.block_size

        def dflash_mask_fn(b, h, q_idx, kv_idx):
            L = q_len
            is_ctx = kv_idx < L
            q_block = q_idx // block_size
            k_block_ctx = kv_idx // block_size
            # For context keys this value is meaningless (kv_idx - L < 0);
            # it is discarded by the ~is_ctx term below.
            k_block_noise = (kv_idx - L) // block_size
            ctx_visible = is_ctx & (k_block_ctx < q_block)
            noise_visible = (~is_ctx) & (k_block_noise == q_block)
            return ctx_visible | noise_visible

        block_mask = create_block_mask(
            dflash_mask_fn,
            B=bsz,
            H=1,
            Q_LEN=q_len,
            KV_LEN=kv_len,
            device=device,
        )
        self._cached_block_mask = block_mask
        self._cached_seq_len = q_len
        self._cached_bsz = bsz
        self._cached_kv_len = kv_len
        self._cached_device = device
        return block_mask

    def _create_parallel_attention_mask(
        self,
        seq_len: int,
        device: torch.device,
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """
        Create [L, 2L] attention mask for parallel training.
        - Left half (ctx): Q can see K_ctx if K's block < Q's block
        - Right half (noise): Q can see K_noise if same block (bidirectional)

        Visible entries are ``0`` and hidden entries are
        ``torch.finfo(dtype).min``, giving an additive-style float mask.
        The mask is built directly in ``dtype`` rather than cast from
        float32. Casting would overflow float32's minimum to ``-inf`` in
        half precision.
        """
        indices = torch.arange(seq_len, device=device)
        block_ids = indices // self.block_size
        q_block_ids = block_ids.unsqueeze(1)
        k_block_ids = block_ids.unsqueeze(0)
        ctx_mask = k_block_ids < q_block_ids
        noise_mask = q_block_ids == k_block_ids
        full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=1)
        full_mask = torch.zeros_like(full_mask_bool, dtype=dtype)
        full_mask.masked_fill_(~full_mask_bool, torch.finfo(dtype).min)
        return full_mask

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        hidden_states: torch.Tensor,
        loss_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Parallel block-wise training forward pass.

        All inputs are truncated to the largest multiple of ``block_size``.

        Args:
            input_ids: Token ids, ``[batch, seq_len]``. Also used as labels.
            attention_mask: Truncated, but otherwise not used by this method.
                NOTE(review): padding is therefore not masked out of
                attention. Confirm this is intended; presumably ``loss_mask``
                already excludes padded positions from the loss.
            hidden_states: Passed to the draft model as ``target_hidden``,
                indexed ``[batch, seq_len, hidden]``.
            loss_mask: Per-position weights. It is multiplied by the DFlash
                block mask, and positions whose product exceeds 0.5 are
                scored.

        Returns:
            ``(loss, accuracy)``. ``loss`` is the mean cross-entropy over the
            scored positions. ``accuracy`` is the fraction of those positions
            whose argmax prediction equals the label. When no position is
            scored, both are zero and the loss stays attached to the graph,
            so ``backward()`` still works.
        """
        bsz, seq_len = input_ids.shape
        device = input_ids.device

        # Truncate to a multiple of block_size so every block is complete.
        n_blocks = seq_len // self.block_size
        effective_len = n_blocks * self.block_size
        input_ids = input_ids[:, :effective_len]
        hidden_states = hidden_states[:, :effective_len, :]
        loss_mask = loss_mask[:, :effective_len]
        attention_mask = attention_mask[:, :effective_len]

        # Prepare inputs
        noise_input_ids = self.prepare_noise_input(input_ids)
        noise_embedding = self.embed_tokens(noise_input_ids)

        # Position IDs are [ctx_pos, noise_pos], both 0..L-1, so the context
        # token and the noise token at the same index share a position id.
        pos_seq = torch.arange(effective_len, device=device)
        position_ids = torch.cat([pos_seq, pos_seq], dim=0).unsqueeze(0).expand(bsz, -1)

        # Construct attention mask
        if (
            self.attention_backend == "flex_attention"
            and FLEX_ATTENTION_AVAILABLE
            and create_block_mask is not None
        ):
            dflash_attn_mask = self._get_or_create_block_mask(
                bsz=bsz,
                q_len=effective_len,
                kv_len=effective_len * 2,
                device=device,
            )
        else:
            # Dense [L, 2L] mask built directly in the activations' dtype,
            # then broadcast to [bsz, 1, L, 2L].
            dflash_attn_mask = self._create_parallel_attention_mask(
                effective_len, device, dtype=hidden_states.dtype
            )
            dflash_attn_mask = (
                dflash_attn_mask.unsqueeze(0).unsqueeze(0).expand(bsz, -1, -1, -1)
            )

        # Forward pass
        hidden = self.draft_model(
            position_ids=position_ids,
            noise_embedding=noise_embedding,
            target_hidden=hidden_states,
            attention_mask=dflash_attn_mask,
        )

        # Compute loss (skip block 0 and block starts)
        dflash_loss_mask_base = create_dflash_loss_mask(
            effective_len, self.block_size, device
        )
        combined_mask = loss_mask * dflash_loss_mask_base.unsqueeze(0)

        logits = self.lm_head(hidden)
        logits_flat = logits.reshape(-1, logits.size(-1))
        labels_flat = input_ids.reshape(-1)
        active_indices = combined_mask.reshape(-1) > 0.5

        active_logits = logits_flat[active_indices]
        active_labels = labels_flat[active_indices]
        total = active_labels.numel()

        if total == 0:
            # A mean cross-entropy over zero elements is NaN, which would
            # propagate NaN gradients on backward. Return a zero loss that is
            # still connected to the graph, plus a zero accuracy.
            loss = logits.sum() * 0.0
            accuracy = torch.zeros((), device=logits.device)
            return loss, accuracy

        loss = F.cross_entropy(active_logits, active_labels)

        with torch.no_grad():
            preds = active_logits.argmax(dim=-1)
            correct = (preds == active_labels).float().sum()
            accuracy = correct / total

        return loss, accuracy
def create_dflash_loss_mask(
    seq_len: int, block_size: int, device: torch.device
) -> torch.Tensor:
    """Build the per-position DFlash loss weights.

    A position gets weight ``1.0`` only if it lies outside the first block
    and is not the first position of its own block. Every other position
    gets ``0.0``.

    Args:
        seq_len: Number of positions to produce weights for.
        block_size: Number of positions per block.
        device: Device on which the mask tensor is allocated.

    Returns:
        A 1-D float tensor of length ``seq_len``.
    """
    idx = torch.arange(seq_len, device=device)
    outside_first_block = (idx // block_size) != 0
    not_block_head = (idx % block_size) != 0
    return (outside_first_block & not_block_head).float()
|