File size: 19,708 Bytes
62dca4c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 | # coding=utf-8
"""DFlash Training Wrapper."""
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as grad_checkpoint
from specforge.modeling.draft.dflash import DFlashDraftModel
try:
from torch.nn.attention.flex_attention import BlockMask, create_block_mask
FLEX_ATTENTION_AVAILABLE = True
except ImportError:
FLEX_ATTENTION_AVAILABLE = False
BlockMask = None
create_block_mask = None
class OnlineDFlashModel(nn.Module):
    """DFlash online training wrapper with block-wise CE loss.

    The wrapper owns a draft model plus the target model's ``lm_head`` and
    token embedding. A training sequence is cut into blocks of
    ``block_size`` tokens; in every block only the first token is kept as a
    real token and the rest are replaced by ``mask_token_id``. The draft
    model produces hidden states for all block positions in parallel,
    attending to (a) target hidden states of earlier context and (b) the
    noise tokens of its own block, and is trained with cross-entropy against
    the original tokens at the same positions.

    Block layouts:

    * Contiguous (default, and always when not ``self.training``): the
      sequence is truncated to a multiple of ``block_size``; block 0 has no
      context and is excluded from the loss.
    * Random anchors (``random_anchor=True`` while ``self.training``): up to
      ``num_anchors`` block starts are sampled per sample among positions
      with ``loss_mask > 0.5`` and blocks are gathered from them. Padding
      blocks get ``block_id == -1`` and zero loss weight. A query block only
      sees context tokens whose original position precedes its anchor, so
      overlapping blocks cannot expose the tokens being predicted.
    """

    def __init__(
        self,
        draft_model: "DFlashDraftModel",
        target_lm_head: nn.Module,
        target_embed_tokens: nn.Module,
        mask_token_id: int,
        block_size: int = 16,
        attention_backend: str = "flex_attention",
        random_anchor: bool = False,
        num_anchors: int = 512,
        loss_decay_gamma: Optional[float] = None,
        lm_head_chunk_size: int = 0,
    ):
        """
        Args:
            draft_model: Draft network; called in ``forward`` with keyword
                arguments ``position_ids``, ``noise_embedding``,
                ``target_hidden`` and ``attention_mask``.
            target_lm_head: Output head that turns draft hidden states into
                logits.
            target_embed_tokens: Embedding used for the noise block inputs.
            mask_token_id: Token id placed at every non-first block position.
            block_size: Number of tokens per block.
            attention_backend: ``"flex_attention"`` selects a FlexAttention
                ``BlockMask`` when available; anything else uses a dense
                additive mask.
            random_anchor: Sample block anchors at random while training
                instead of using contiguous blocks.
            num_anchors: Maximum number of anchors sampled per sample.
            loss_decay_gamma: If set, in-block position ``k`` is weighted by
                ``exp(-(k - 1) / gamma)`` and a weighted-mean loss is used.
            lm_head_chunk_size: If > 0 and smaller than the effective length,
                the lm_head + loss are computed chunk by chunk.
        """
        super().__init__()
        self.draft_model = draft_model
        self.lm_head = target_lm_head
        self.embed_tokens = target_embed_tokens
        self.block_size = block_size
        self.mask_token_id = mask_token_id
        self.attention_backend = attention_backend
        self.random_anchor = random_anchor
        self.num_anchors = num_anchors
        self.loss_decay_gamma = loss_decay_gamma
        self.lm_head_chunk_size = lm_head_chunk_size
        # Cache for the data-independent contiguous BlockMask, keyed by
        # (q_len, bsz).
        self._cached_block_mask: Optional[BlockMask] = None
        self._cached_seq_len: Optional[int] = None
        self._cached_bsz: Optional[int] = None

    def _sample_anchor_positions(
        self, seq_len: int, loss_mask: torch.Tensor, device: torch.device
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Randomly sample block anchor positions per sample.

        Candidates are positions ``p <= seq_len - block_size`` with
        ``loss_mask[:, p] > 0.5``. Each sample draws up to ``num_anchors``
        distinct candidates (sorted ascending). Rows with fewer draws are
        right-padded by repeating their last anchor (or zeros when they have
        none) and the padded slots are ``False`` in the keep mask.

        Args:
            seq_len: Sequence length the anchors index into.
            loss_mask: Per-token loss mask, expected ``[bsz, seq_len]``.
            device: Device for the returned tensors.

        Returns:
            ``(anchors, keep_mask)``, both ``[bsz, max_n]``. If no sample has
            any candidate, contiguous anchors ``0, block_size, ...`` that fit
            entirely inside the sequence are returned, all kept.
        """
        bs = self.block_size
        bsz = loss_mask.shape[0]
        # Last position from which a whole block still fits in the sequence.
        max_anchor = max(seq_len - bs, 0)
        valid = loss_mask[:, : max_anchor + 1] > 0.5
        max_n = min(self.num_anchors, int(valid.sum(dim=1).max().item()))

        if max_n == 0:
            # Stop at max_anchor so anchor + block_size - 1 stays in bounds
            # even when seq_len is not a multiple of block_size.
            anchors = torch.arange(0, max_anchor + 1, bs, device=device)
            anchors = anchors.unsqueeze(0).expand(bsz, -1)
            keep = torch.ones(bsz, anchors.shape[1], dtype=torch.bool, device=device)
            return anchors, keep

        anchor_rows = []
        keep_rows = []
        for i in range(bsz):
            candidates = valid[i].nonzero(as_tuple=False).squeeze(-1)
            n_i = min(self.num_anchors, candidates.numel())
            keep_i = torch.zeros(max_n, dtype=torch.bool, device=device)
            if n_i == 0:
                anchors_i = torch.zeros(max_n, dtype=torch.long, device=device)
            else:
                perm = torch.randperm(candidates.numel(), device=device)[:n_i]
                anchors_i = candidates[perm].sort().values
                if n_i < max_n:
                    # Pad to max_n so rows can be stacked.
                    anchors_i = torch.cat(
                        [anchors_i, anchors_i[-1:].expand(max_n - n_i)], dim=0
                    )
                keep_i[:n_i] = True
            anchor_rows.append(anchors_i)
            keep_rows.append(keep_i)
        return torch.stack(anchor_rows, dim=0), torch.stack(keep_rows, dim=0)

    def _build_blocks_from_anchors(
        self,
        input_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        loss_mask: torch.Tensor,
        anchor_positions: torch.Tensor,
        block_keep_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Gather fixed-size blocks starting at each anchor.

        Returns ``(block_input_ids, block_hidden, block_loss_mask, block_ids,
        gather_idx)``, each with ``n_anchors * block_size`` tokens along
        dim 1. ``gather_idx`` holds the original sequence position of every
        gathered token. Padding blocks get ``block_id == -1`` and a zeroed
        loss mask.
        """
        bs = self.block_size
        device = input_ids.device
        bsz = input_ids.shape[0]
        n = anchor_positions.shape[1]
        offsets = torch.arange(bs, device=device).unsqueeze(0)
        gather_idx = (anchor_positions.unsqueeze(-1) + offsets).reshape(bsz, -1)
        block_input_ids = torch.gather(input_ids, 1, gather_idx)
        block_hidden = torch.gather(
            hidden_states,
            1,
            gather_idx.unsqueeze(-1).expand(-1, -1, hidden_states.size(-1)),
        )
        block_loss_mask = torch.gather(loss_mask, 1, gather_idx)
        token_keep = block_keep_mask.repeat_interleave(bs, dim=1)
        block_loss_mask = block_loss_mask * token_keep.to(block_loss_mask.dtype)
        block_ids = torch.arange(n, device=device).repeat_interleave(bs)
        block_ids = block_ids.unsqueeze(0).expand(bsz, -1).clone()
        block_ids[~token_keep] = -1
        return block_input_ids, block_hidden, block_loss_mask, block_ids, gather_idx

    def prepare_noise_input(
        self, input_ids: torch.Tensor, block_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Prepare noise input: first token of each block is real, rest are MASK.

        Block starts come from changes in ``block_ids`` when given, otherwise
        from ``position % block_size == 0``.
        """
        bsz, seq_len = input_ids.shape
        device = input_ids.device
        if block_ids is not None:
            is_block_start = torch.ones(bsz, seq_len, dtype=torch.bool, device=device)
            is_block_start[:, 1:] = block_ids[:, 1:] != block_ids[:, :-1]
        else:
            positions = torch.arange(seq_len, device=device)
            is_block_start = (positions % self.block_size) == 0
            is_block_start = is_block_start.unsqueeze(0).expand(bsz, -1)
        noise_input_ids = torch.full_like(input_ids, self.mask_token_id)
        noise_input_ids[is_block_start] = input_ids[is_block_start]
        return noise_input_ids

    def _get_or_create_block_mask(
        self,
        bsz: int,
        q_len: int,
        kv_len: int,
        device: torch.device,
        block_ids: Optional[torch.Tensor] = None,
        block_positions: Optional[torch.Tensor] = None,
    ) -> "BlockMask":
        """Return a FlexAttention ``BlockMask`` for the DFlash layout.

        KV indices ``[0, q_len)`` are context (target hidden) tokens and
        ``[q_len, kv_len)`` are noise tokens. The contiguous layout
        (``block_ids is None``) is cached by ``(q_len, bsz)``; per-sample
        layouts are rebuilt every call. When ``block_positions`` (original
        positions of the gathered tokens) is given, a query only sees context
        tokens whose position precedes its block's anchor. Padding query rows
        stay fully masked here (NOTE(review): relies on FlexAttention handling
        fully-masked rows; confirm for the torch version in use).
        """
        block_size = self.block_size
        if block_ids is None:
            if (
                self._cached_block_mask is not None
                and self._cached_seq_len == q_len
                and self._cached_bsz == bsz
            ):
                return self._cached_block_mask

            def dflash_mask_fn(b, h, q_idx, kv_idx):
                L = q_len
                is_ctx = kv_idx < L
                q_block = q_idx // block_size
                k_block_ctx = kv_idx // block_size
                k_block_noise = (kv_idx - L) // block_size
                ctx_visible = is_ctx & (k_block_ctx < q_block)
                noise_visible = (~is_ctx) & (k_block_noise == q_block)
                return ctx_visible | noise_visible

        else:
            _block_ids = block_ids
            _positions = block_positions
            _anchors = None
            if block_positions is not None:
                # Anchor position of each query token = its position minus
                # its in-block offset.
                in_block = (
                    torch.arange(q_len, device=block_positions.device) % block_size
                )
                _anchors = block_positions - in_block.unsqueeze(0)

            def dflash_mask_fn(b, h, q_idx, kv_idx):
                L = q_len
                is_ctx = kv_idx < L
                ctx_kv = kv_idx.clamp(max=L - 1)
                q_b = _block_ids[b, q_idx]
                k_ctx = _block_ids[b, ctx_kv]
                k_noise = _block_ids[b, (kv_idx - L).clamp(min=0, max=L - 1)]
                q_valid = q_b >= 0
                ctx_visible = is_ctx & q_valid & (k_ctx >= 0) & (k_ctx < q_b)
                if _anchors is not None:
                    # Overlapping blocks: hide context at/after the anchor.
                    ctx_visible = ctx_visible & (
                        _positions[b, ctx_kv] < _anchors[b, q_idx]
                    )
                noise_visible = (
                    (~is_ctx) & q_valid & (k_noise >= 0) & (k_noise == q_b)
                )
                return ctx_visible | noise_visible

        block_mask = create_block_mask(
            dflash_mask_fn,
            B=bsz,
            H=1,
            Q_LEN=q_len,
            KV_LEN=kv_len,
            device=device,
        )
        if block_ids is None:
            self._cached_block_mask = block_mask
            self._cached_seq_len = q_len
            self._cached_bsz = bsz
        return block_mask

    @staticmethod
    def _to_additive_mask(allowed: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        """Map a bool "may attend" mask to 0 / ``finfo(dtype).min`` values."""
        additive = torch.zeros(allowed.shape, dtype=dtype, device=allowed.device)
        return additive.masked_fill_(~allowed, torch.finfo(dtype).min)

    def _create_parallel_attention_mask(
        self,
        bsz: int,
        seq_len: int,
        device: torch.device,
        block_ids: Optional[torch.Tensor] = None,
        block_positions: Optional[torch.Tensor] = None,
        dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """Create a ``[bsz, L, 2L]`` additive attention mask for parallel training.

        Columns ``[0, L)`` are context keys, ``[L, 2L)`` noise keys. Allowed
        entries are 0, blocked ones ``torch.finfo(dtype).min``. Building in
        the target ``dtype`` avoids the fp32 minimum overflowing to ``-inf``
        after a cast to bf16/fp16.

        With per-sample ``block_ids``, padding queries (id < 0) may attend
        only to padding noise keys. Otherwise their rows would be fully masked
        and yield NaN softmax/gradients. Their outputs carry no loss, and
        valid queries never see padding keys. ``block_positions``, when given,
        additionally restricts context to positions before the query's anchor.
        """
        if block_ids is None:
            ids = torch.arange(seq_len, device=device) // self.block_size
            q_ids = ids.unsqueeze(1)
            k_ids = ids.unsqueeze(0)
            allowed = torch.cat([k_ids < q_ids, q_ids == k_ids], dim=1)
            # Build once as [L, 2L] and expand as a view to avoid bsz copies.
            return self._to_additive_mask(allowed, dtype).unsqueeze(0).expand(bsz, -1, -1)

        q_ids = block_ids.unsqueeze(2)
        k_ids = block_ids.unsqueeze(1)
        q_valid = q_ids >= 0
        k_valid = k_ids >= 0
        ctx_mask = q_valid & k_valid & (k_ids < q_ids)
        if block_positions is not None:
            in_block = torch.arange(seq_len, device=block_positions.device) % self.block_size
            q_anchor = (block_positions - in_block.unsqueeze(0)).unsqueeze(2)
            ctx_mask = ctx_mask & (block_positions.unsqueeze(1) < q_anchor)
        noise_mask = (q_valid & k_valid & (k_ids == q_ids)) | (~q_valid & ~k_valid)
        allowed = torch.cat([ctx_mask, noise_mask], dim=2)
        return self._to_additive_mask(allowed, dtype)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        hidden_states: torch.Tensor,
        loss_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Parallel block-wise training forward pass.

        Args:
            input_ids: ``[bsz, seq_len]`` token ids; also used as labels.
            attention_mask: Accepted for interface compatibility; unused here.
            hidden_states: ``[bsz, seq_len, hidden]`` target hidden states
                used as the draft model's context.
            loss_mask: ``[bsz, seq_len]`` per-token loss mask.

        Returns:
            ``(loss, accuracy)``: the block-wise CE loss and the acceptance
            metric from ``_compute_acceptance_accuracy``.
        """
        bsz, seq_len = input_ids.shape
        device = input_ids.device
        block_ids: Optional[torch.Tensor] = None
        block_positions: Optional[torch.Tensor] = None

        if self.random_anchor and self.training:
            anchor_positions, block_keep_mask = self._sample_anchor_positions(
                seq_len, loss_mask, device
            )
            (
                input_ids,
                hidden_states,
                loss_mask,
                block_ids,
                block_positions,
            ) = self._build_blocks_from_anchors(
                input_ids,
                hidden_states,
                loss_mask,
                anchor_positions,
                block_keep_mask,
            )
            effective_len = input_ids.shape[1]
            base_positions = block_positions
        else:
            # Drop a trailing partial block, if any.
            effective_len = (seq_len // self.block_size) * self.block_size
            input_ids = input_ids[:, :effective_len]
            hidden_states = hidden_states[:, :effective_len, :]
            loss_mask = loss_mask[:, :effective_len]
            base_positions = (
                torch.arange(effective_len, device=device).unsqueeze(0).expand(bsz, -1)
            )

        noise_input_ids = self.prepare_noise_input(input_ids, block_ids)
        noise_embedding = self.embed_tokens(noise_input_ids)
        # Context and noise copies of a token share the same position id.
        position_ids = torch.cat([base_positions, base_positions], dim=1)

        if (
            self.attention_backend == "flex_attention"
            and FLEX_ATTENTION_AVAILABLE
            and create_block_mask is not None
        ):
            dflash_attn_mask = self._get_or_create_block_mask(
                bsz=bsz,
                q_len=effective_len,
                kv_len=effective_len * 2,
                device=device,
                block_ids=block_ids,
                block_positions=block_positions,
            )
        else:
            dflash_attn_mask = self._create_parallel_attention_mask(
                bsz,
                effective_len,
                device,
                block_ids,
                block_positions=block_positions,
                dtype=hidden_states.dtype,
            ).unsqueeze(1)  # add a broadcastable head dimension

        hidden = self.draft_model(
            position_ids=position_ids,
            noise_embedding=noise_embedding,
            target_hidden=hidden_states,
            attention_mask=dflash_attn_mask,
        )

        dflash_loss_weights = create_dflash_loss_mask(
            effective_len,
            self.block_size,
            device,
            gamma=self.loss_decay_gamma,
            block_ids=block_ids,
        )
        if block_ids is None:
            dflash_loss_weights = dflash_loss_weights.unsqueeze(0)
        combined_mask = loss_mask * dflash_loss_weights

        if self.lm_head_chunk_size > 0 and effective_len > self.lm_head_chunk_size:
            return self._chunked_lm_loss(
                hidden, input_ids, loss_mask, combined_mask, effective_len, block_ids
            )
        return self._full_lm_loss(
            hidden, input_ids, loss_mask, combined_mask, effective_len, block_ids
        )

    def _compute_acceptance_accuracy(
        self,
        preds_all: torch.Tensor,
        input_ids: torch.Tensor,
        loss_mask: torch.Tensor,
        effective_len: int,
        block_ids: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Compute the block-wise acceptance rate metric.

        For each fully supervised block, the acceptance length is the number
        of leading correct predictions among in-block positions
        ``1..block_size-1``. The result is the mean acceptance length over
        those blocks divided by ``block_size - 1``. The contiguous layout
        skips block 0. When no scorable blocks exist, plain token accuracy
        over ``loss_mask > 0.5`` is returned (0 if nothing is supervised).
        Always returns a 0-dim tensor.
        """
        bsz = input_ids.shape[0]
        bs = self.block_size
        n_blocks = effective_len // bs
        correct_all = (preds_all == input_ids).float()

        first = 0 if block_ids is not None else bs
        n_scored = n_blocks if block_ids is not None else n_blocks - 1

        if n_scored < 1 or correct_all.shape[1] != n_blocks * bs:
            # Fallback: token accuracy over supervised positions.
            valid = (loss_mask > 0.5).reshape(-1)
            correct_flat = correct_all.reshape(-1)[valid]
            if correct_flat.numel() == 0:
                return correct_all.new_zeros(())
            return correct_flat.mean()

        span = slice(first, first + n_scored * bs)
        correct_blocks = correct_all[:, span].reshape(bsz, n_scored, bs)
        loss_mask_blocks = loss_mask[:, span].reshape(bsz, n_scored, bs)
        # In-block position 0 is the given real token, not a prediction.
        correct_pred = correct_blocks[:, :, 1:]
        loss_mask_pred = loss_mask_blocks[:, :, 1:]
        block_valid = (loss_mask_pred.sum(dim=2) == (bs - 1)).float()
        # cumprod zeroes everything after the first miss.
        accepted = (correct_pred * loss_mask_pred).cumprod(dim=2).sum(dim=2)
        total_accepted = (accepted * block_valid).sum()
        n_valid_blocks = block_valid.sum().clamp_min(1)
        return (total_accepted / n_valid_blocks) / (bs - 1)

    def _full_lm_loss(
        self,
        hidden: torch.Tensor,
        input_ids: torch.Tensor,
        loss_mask: torch.Tensor,
        combined_mask: torch.Tensor,
        effective_len: int,
        block_ids: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Non-chunked lm_head + loss computation.

        Tokens with ``combined_mask > 1e-6`` are supervised. With
        ``loss_decay_gamma`` set, the loss is the weighted mean of per-token
        CE; otherwise it is the plain mean over supervised tokens. If no
        token is supervised, a zero still attached to the graph is returned
        (same as ``_chunked_lm_loss``) instead of NaN.
        """
        logits = self.lm_head(hidden)
        with torch.no_grad():
            preds_all = logits.argmax(dim=-1)
            accuracy = self._compute_acceptance_accuracy(
                preds_all, input_ids, loss_mask, effective_len, block_ids
            )
        logits_flat = logits.reshape(-1, logits.size(-1))
        labels_flat = input_ids.reshape(-1)
        mask_flat = combined_mask.reshape(-1)
        active = mask_flat > 1e-6
        if not bool(active.any()):
            return logits.sum() * 0.0, accuracy
        active_logits = logits_flat[active]
        active_labels = labels_flat[active]
        active_weights = mask_flat[active]
        if self.loss_decay_gamma is not None:
            per_token_loss = F.cross_entropy(
                active_logits, active_labels, reduction="none"
            )
            loss = (per_token_loss * active_weights).sum() / active_weights.sum()
        else:
            loss = F.cross_entropy(active_logits, active_labels)
        return loss, accuracy

    def _chunked_lm_loss(
        self,
        hidden: torch.Tensor,
        input_ids: torch.Tensor,
        loss_mask: torch.Tensor,
        combined_mask: torch.Tensor,
        effective_len: int,
        block_ids: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Chunked lm_head + loss that avoids materializing full logits.

        Processes the sequence in chunks of ``lm_head_chunk_size``. Each
        chunk runs under non-reentrant gradient checkpointing, so its logits
        are recomputed during backward rather than stored. The loss is the
        weighted mean of per-token CE over tokens with ``combined_mask > 1e-6``.
        """
        chunk_size = self.lm_head_chunk_size
        # 1. Accuracy: argmax per chunk without building the graph.
        with torch.no_grad():
            preds_chunks = []
            for start in range(0, effective_len, chunk_size):
                end = min(start + chunk_size, effective_len)
                chunk_logits = self.lm_head(hidden[:, start:end, :])
                preds_chunks.append(chunk_logits.argmax(dim=-1))
            preds_all = torch.cat(preds_chunks, dim=1)
            accuracy = self._compute_acceptance_accuracy(
                preds_all, input_ids, loss_mask, effective_len, block_ids
            )

        # 2. Loss: chunked with gradient checkpointing.
        total_loss = torch.tensor(0.0, device=hidden.device)
        total_weight = torch.tensor(0.0, device=hidden.device)

        def _chunk_ce(h_chunk, labels_chunk, weights_chunk):
            logits_chunk = self.lm_head(h_chunk)
            logits_flat = logits_chunk.reshape(-1, logits_chunk.size(-1))
            labels_flat = labels_chunk.reshape(-1)
            weights_flat = weights_chunk.reshape(-1)
            active = weights_flat > 1e-6
            if not active.any():
                # Keep the chunk attached to the graph with zero contribution.
                return logits_flat.sum() * 0, weights_flat.sum() * 0
            per_token = F.cross_entropy(
                logits_flat[active], labels_flat[active], reduction="none"
            )
            return (per_token * weights_flat[active]).sum(), weights_flat[active].sum()

        for start in range(0, effective_len, chunk_size):
            end = min(start + chunk_size, effective_len)
            chunk_loss, chunk_weight = grad_checkpoint(
                _chunk_ce,
                hidden[:, start:end, :],
                input_ids[:, start:end],
                combined_mask[:, start:end],
                use_reentrant=False,
            )
            total_loss = total_loss + chunk_loss
            total_weight = total_weight + chunk_weight
        loss = total_loss / total_weight.clamp_min(1e-8)
        return loss, accuracy
def create_dflash_loss_mask(
    seq_len: int,
    block_size: int,
    device: torch.device,
    gamma: Optional[float] = None,
    block_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Build per-token DFlash loss weights.

    Every block's first token is given to the draft model, so it is never
    supervised (weight 0).

    * ``block_ids is None`` (contiguous layout): blocks are
      ``position // block_size``; block 0 has no preceding context and is
      excluded entirely. Returns a ``[seq_len]`` tensor.
    * ``block_ids`` given (per-sample layout): a token starts a block when its
      id differs from its left neighbour; tokens with a negative id (padding)
      get weight 0. Returns a tensor shaped like ``block_ids``.

    When ``gamma`` is set, surviving tokens are further scaled by
    ``exp(-(k - 1) / gamma)`` where ``k = position % block_size``.
    """
    idx = torch.arange(seq_len, device=device)
    offset = idx % block_size

    if block_ids is None:
        # Skip block starts and the whole first block.
        keep = (offset != 0) & (idx >= block_size)
    else:
        starts = torch.ones_like(block_ids, dtype=torch.bool)
        starts[:, 1:] = block_ids[:, 1:] != block_ids[:, :-1]
        keep = (block_ids >= 0) & ~starts
        offset = offset.unsqueeze(0)  # broadcast over the batch dimension

    weights = keep.float()
    if gamma is None:
        return weights
    return weights * torch.exp(-(offset.float() - 1.0) / gamma)
|