"""DFlash Training Wrapper."""

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from specforge.modeling.draft.dflash import DFlashDraftModel

try:
    from torch.nn.attention.flex_attention import BlockMask, create_block_mask

    FLEX_ATTENTION_AVAILABLE = True
except ImportError:
    FLEX_ATTENTION_AVAILABLE = False
    BlockMask = None
    create_block_mask = None


class OnlineDFlashModel(nn.Module):
    """DFlash online training wrapper with block-wise CE loss."""

    def __init__(
        self,
        draft_model: DFlashDraftModel,
        target_lm_head: nn.Module,
        target_embed_tokens: nn.Module,
        mask_token_id: int,
        block_size: int = 16,
        attention_backend: str = "flex_attention",
    ):
        super().__init__()
        self.draft_model = draft_model
        self.lm_head = target_lm_head
        self.embed_tokens = target_embed_tokens
        self.block_size = block_size
        self.mask_token_id = mask_token_id
        self.attention_backend = attention_backend

        # Cached flex-attention BlockMask; rebuilt whenever (bsz, seq_len) changes.
        self._cached_block_mask: Optional[BlockMask] = None
        self._cached_seq_len: Optional[int] = None
        self._cached_bsz: Optional[int] = None

    def prepare_noise_input(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Prepare noise input: the first token of each block is real, the rest are MASK."""
        seq_len = input_ids.shape[1]
        device = input_ids.device

        positions = torch.arange(seq_len, device=device)
        is_block_start = (positions % self.block_size) == 0

        noise_input_ids = torch.full_like(input_ids, self.mask_token_id)
        noise_input_ids[:, is_block_start] = input_ids[:, is_block_start]

        return noise_input_ids

    def _get_or_create_block_mask(
        self, bsz: int, q_len: int, kv_len: int, device: torch.device
    ) -> "BlockMask":
        """Get cached BlockMask or create a new one."""
        if (
            self._cached_block_mask is not None
            and self._cached_seq_len == q_len
            and self._cached_bsz == bsz
        ):
            return self._cached_block_mask

        block_size = self.block_size

        def dflash_mask_fn(b, h, q_idx, kv_idx):
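            # The KV sequence is laid out as [context | noise], each half of length
            # q_len: a query attends to context keys from strictly earlier blocks
            # and to noise keys within its own block only.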
            L = q_len
            is_ctx = kv_idx < L
            q_block = q_idx // block_size
            k_block_ctx = kv_idx // block_size
            k_block_noise = (kv_idx - L) // block_size
            ctx_visible = is_ctx & (k_block_ctx < q_block)
            noise_visible = (~is_ctx) & (k_block_noise == q_block)
            return ctx_visible | noise_visible

        block_mask = create_block_mask(
            dflash_mask_fn,
            B=bsz,
            H=1,
            Q_LEN=q_len,
            KV_LEN=kv_len,
            device=device,
        )

        self._cached_block_mask = block_mask
        self._cached_seq_len = q_len
        self._cached_bsz = bsz

        return block_mask

    def _create_parallel_attention_mask(
        self, seq_len: int, device: torch.device
    ) -> torch.Tensor:
        """
        Create an [L, 2L] additive attention mask for parallel training.

        - Left half (ctx): Q can see K_ctx only if K's block index is strictly
          smaller than Q's block index.
        - Right half (noise): Q can see K_noise only within the same block
          (bidirectional).
        """
        indices = torch.arange(seq_len, device=device)
        block_ids = indices // self.block_size

        q_block_ids = block_ids.unsqueeze(1)
        k_block_ids = block_ids.unsqueeze(0)

        ctx_mask = k_block_ids < q_block_ids
        noise_mask = q_block_ids == k_block_ids

        full_mask_bool = torch.cat([ctx_mask, noise_mask], dim=1)
        full_mask = torch.zeros_like(full_mask_bool, dtype=torch.float32)
        full_mask.masked_fill_(~full_mask_bool, torch.finfo(torch.float32).min)

        return full_mask

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        hidden_states: torch.Tensor,
        loss_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Parallel block-wise training forward pass. Returns (loss, accuracy)."""
        bsz, seq_len = input_ids.shape
        device = input_ids.device

        # Truncate everything to a whole number of blocks.
        n_blocks = seq_len // self.block_size
        effective_len = n_blocks * self.block_size
        input_ids = input_ids[:, :effective_len]
        hidden_states = hidden_states[:, :effective_len, :]
        loss_mask = loss_mask[:, :effective_len]
        attention_mask = attention_mask[:, :effective_len]

        # Noise branch: block starts keep the real token, all other positions are MASK.
        noise_input_ids = self.prepare_noise_input(input_ids)
        noise_embedding = self.embed_tokens(noise_input_ids)

        # Positions are shared between the context and noise halves.
        pos_seq = torch.arange(effective_len, device=device)
        position_ids = torch.cat([pos_seq, pos_seq], dim=0).unsqueeze(0).expand(bsz, -1)

        # Build the block-wise attention pattern with the selected backend.
        if (
            self.attention_backend == "flex_attention"
            and FLEX_ATTENTION_AVAILABLE
            and create_block_mask is not None
        ):
            dflash_attn_mask = self._get_or_create_block_mask(
                bsz=bsz,
                q_len=effective_len,
                kv_len=effective_len * 2,
                device=device,
            )
        else:
            dflash_attn_mask = self._create_parallel_attention_mask(
                effective_len, device
            )
            dflash_attn_mask = dflash_attn_mask.to(dtype=hidden_states.dtype)
            dflash_attn_mask = (
                dflash_attn_mask.unsqueeze(0).unsqueeze(0).expand(bsz, -1, -1, -1)
            )

        hidden = self.draft_model(
            position_ids=position_ids,
            noise_embedding=noise_embedding,
            target_hidden=hidden_states,
            attention_mask=dflash_attn_mask,
        )
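        # The draft model is expected to return hidden states for the noise positions
        # only ([bsz, effective_len, hidden]), so that the logits below line up with
        # input_ids position by position.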
        # Only positions outside block 0 that are not block anchors contribute to the loss.
        dflash_loss_mask_base = create_dflash_loss_mask(
            effective_len, self.block_size, device
        )
        combined_mask = loss_mask * dflash_loss_mask_base.unsqueeze(0)

        logits = self.lm_head(hidden)

        logits_flat = logits.reshape(-1, logits.size(-1))
        labels_flat = input_ids.reshape(-1)
        mask_flat = combined_mask.reshape(-1)

        active_indices = mask_flat > 0.5
        active_logits = logits_flat[active_indices]
        active_labels = labels_flat[active_indices]

        loss = F.cross_entropy(active_logits, active_labels)
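        # Note: if the combined mask selects no positions (e.g. fewer than two full
        # blocks survive truncation), both this loss and the accuracy below come out
        # as NaN.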
        with torch.no_grad():
            preds = active_logits.argmax(dim=-1)
            correct = (preds == active_labels).float().sum()
            total = active_labels.numel()
            accuracy = correct / total

        return loss, accuracy


def create_dflash_loss_mask(
    seq_len: int, block_size: int, device: torch.device
) -> torch.Tensor:
    """Create DFlash loss mask: excludes block 0 and first position of each block."""
    positions = torch.arange(seq_len, device=device)
    block_ids = positions // block_size

    is_block_0 = block_ids == 0
    is_block_start = (positions % block_size) == 0

    valid_mask = ~is_block_0 & ~is_block_start
    return valid_mask.float()
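# Minimal sanity-check sketch (illustrative values, CPU only); it exercises just the
# standalone loss-mask helper, since constructing OnlineDFlashModel itself needs a
# real DFlashDraftModel plus the target model's lm_head and embedding modules.
if __name__ == "__main__":
    mask = create_dflash_loss_mask(seq_len=8, block_size=4, device=torch.device("cpu"))
    # Block 0 (positions 0-3) and the block-start position 4 are excluded:
    # tensor([0., 0., 0., 0., 0., 1., 1., 1.])
    print(mask)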