# coding=utf-8 """DFlash Training Wrapper.""" from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from specforge.modeling.draft.dflash import DFlashDraftModel try: from torch.nn.attention.flex_attention import BlockMask, create_block_mask FLEX_ATTENTION_AVAILABLE = True except ImportError: FLEX_ATTENTION_AVAILABLE = False BlockMask = None create_block_mask = None def create_dflash_sdpa_mask(anchor_positions, block_keep_mask, S, block_size, device): B, N = anchor_positions.shape Q_LEN = N * block_size KV_LEN = S + N * block_size q_indices = torch.arange(Q_LEN, device=device).view(1, 1, -1, 1) # (1, 1, Q_LEN, 1) kv_indices = torch.arange(KV_LEN, device=device).view( 1, 1, 1, -1 ) # (1, 1, 1, KV_LEN) q_block_ids = q_indices // block_size anchor_expanded = anchor_positions.view(B, 1, N, 1).repeat_interleave( block_size, dim=2 ) mask_context = (kv_indices < S) & (kv_indices < anchor_expanded) is_draft = kv_indices >= S kv_block_ids = (kv_indices - S) // block_size mask_draft = is_draft & (q_block_ids == kv_block_ids) valid_block = block_keep_mask.view(B, 1, N, 1).repeat_interleave(block_size, dim=2) final_mask = (mask_context | mask_draft) & valid_block return final_mask def create_dflash_block_mask( anchor_positions: torch.Tensor, block_keep_mask: torch.Tensor, S: int, block_size: int, device: torch.device, ): """Construct Flex Attention BlockMask for DFlash training. KV: [Context (S tokens) | Block_0 | Block_1 | ... | Block_{n-1}] Q: [Block_0 | Block_1 | ... | Block_{n-1}] Rules: 1. Each block sees context strictly before its anchor (kv_idx < anchor_pos). 2. Intra-block attention is bidirectional. 3. Different blocks are invisible to each other. 4. Invalid blocks (block_keep_mask=False) see nothing. """ def dflash_mask_mod(b, h, q_idx, kv_idx): q_block_id = q_idx // block_size safe_q_block_id = q_block_id.clamp(max=N - 1) anchor_pos = anchor_positions[b, safe_q_block_id] is_context = kv_idx < S # Strictly less than: matches inference where target_hidden[anchor_pos] # is not available as context. 
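
# Illustration (assumed toy values, not from the original code) of the mask layout
# produced by create_dflash_sdpa_mask; create_dflash_block_mask encodes the same rules:
#
#     anchors = torch.tensor([[2, 7]])                  # B=1, N=2 anchors
#     keep = torch.ones(1, 2, dtype=torch.bool)
#     mask = create_dflash_sdpa_mask(
#         anchors, keep, S=10, block_size=4, device=torch.device("cpu")
#     )
#     # mask.shape == (1, 1, 8, 18): 8 draft queries over [10 context | 2 blocks of 4].
#     # Query 0 (first token of block 0, anchor 2) sees context positions 0-1 only,
#     # plus the four draft positions of its own block (kv indices 10-13).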


class OnlineDFlashModel(nn.Module):
    """DFlash online training wrapper with block-wise CE loss."""

    def __init__(
        self,
        draft_model: DFlashDraftModel,
        target_lm_head: nn.Module,
        target_embed_tokens: nn.Module,
        mask_token_id: int,
        block_size: int = 16,
        attention_backend: str = "flex_attention",
        num_anchors: int = 512,
        loss_decay_gamma: Optional[float] = None,
    ):
        super().__init__()
        self.draft_model = draft_model
        self.lm_head = target_lm_head
        self.embed_tokens = target_embed_tokens
        self.block_size = block_size
        self.mask_token_id = mask_token_id
        self.attention_backend = attention_backend
        self.num_anchors = num_anchors
        self.loss_decay_gamma = loss_decay_gamma
        self._cached_block_mask: Optional[BlockMask] = None
        self._cached_seq_len: Optional[int] = None
        self._cached_bsz: Optional[int] = None

    def _sample_anchor_positions(
        self, seq_len: int, loss_mask: torch.Tensor, device: torch.device
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Randomly sample anchor positions per sample; returns (anchors, keep_mask)."""
        bs = self.block_size
        bsz = loss_mask.shape[0]
        max_anchor = max(seq_len - bs, 0)

        valid = loss_mask[:, : max_anchor + 1] > 0.5
        valid_counts = valid.sum(dim=1)
        max_n = min(self.num_anchors, int(valid_counts.max().item()) - 1)
        if max_n <= 0:
            raise ValueError(
                "No valid anchor positions found; preprocess the data so that "
                "loss_mask marks enough trainable tokens per sample."
            )

        # Randomly pick up to max_n valid positions per sample by sorting random keys.
        indices = (
            torch.arange(max_anchor + 1, device=device).unsqueeze(0).expand(bsz, -1)
        )
        masked_indices = torch.where(
            valid, indices, torch.tensor(seq_len + 1, device=device)
        )
        random_vals = torch.rand(bsz, max_anchor + 1, device=device)
        random_vals = torch.where(valid, random_vals, torch.tensor(2.0, device=device))
        _, sorted_idx = random_vals.sort(dim=1)
        gathered = torch.gather(masked_indices, 1, sorted_idx)
        anchors = gathered[:, :max_n].sort(dim=1).values

        # Samples with fewer than max_n valid positions get padded, masked-out anchors.
        keep_mask = torch.arange(max_n, device=device).unsqueeze(
            0
        ) < valid_counts.unsqueeze(1).clamp(max=max_n)
        anchors = torch.where(
            keep_mask, anchors, torch.tensor(0, dtype=torch.long, device=device)
        )
        return anchors, keep_mask

    def prepare_noise_input(
        self, input_ids: torch.Tensor, block_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Prepare noise input: first token of each block is real, rest are MASK."""
        bsz, seq_len = input_ids.shape
        device = input_ids.device
        if block_ids is not None:
            is_block_start = torch.ones(bsz, seq_len, dtype=torch.bool, device=device)
            is_block_start[:, 1:] = block_ids[:, 1:] != block_ids[:, :-1]
        else:
            positions = torch.arange(seq_len, device=device)
            is_block_start = (positions % self.block_size) == 0
            is_block_start = is_block_start.unsqueeze(0).expand(bsz, -1)
        noise_input_ids = torch.full_like(input_ids, self.mask_token_id)
        noise_input_ids[is_block_start] = input_ids[is_block_start]
        return noise_input_ids

    def _create_position_ids(self, anchor_positions: torch.Tensor) -> torch.Tensor:
        """Create absolute position IDs for parallel draft blocks."""
        bsz, n_blocks = anchor_positions.shape
        device = anchor_positions.device
        offsets = torch.arange(self.block_size, device=device).view(1, 1, -1)
        pos_ids = anchor_positions.unsqueeze(-1) + offsets
        return pos_ids.view(bsz, -1)
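
    # Illustration (assumed toy values, not from the original code): with
    # block_size=4 and anchor_positions=[[2, 7]], _create_position_ids returns
    # [[2, 3, 4, 5, 7, 8, 9, 10]], i.e. each draft block is assigned the absolute
    # positions anchor, anchor+1, ..., anchor+block_size-1.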

    def _create_noise_embed(self, input_ids, anchor_positions, block_keep_mask):
        """Embed the draft blocks: the anchor token is real, all other slots are MASK."""
        bsz, seq_len = input_ids.shape
        n = anchor_positions.shape[1]
        bs = self.block_size
        device = input_ids.device

        noise_ids = torch.full(
            (bsz, n * bs), self.mask_token_id, dtype=torch.long, device=device
        )
        block_starts = torch.arange(n, device=device) * bs
        block_starts = block_starts.unsqueeze(0).expand(bsz, -1)

        valid_anchor_positions = anchor_positions.clamp(0, seq_len - 1)
        anchor_tokens = torch.gather(input_ids, 1, valid_anchor_positions)
        flat_batch_idx = torch.arange(bsz, device=device).unsqueeze(1).expand(bsz, n)
        noise_ids[flat_batch_idx, block_starts] = torch.where(
            block_keep_mask,
            anchor_tokens,
            torch.tensor(self.mask_token_id, dtype=torch.long, device=device),
        )
        return self.embed_tokens(noise_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        loss_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Parallel block-wise training forward pass."""
        bsz, seq_len = input_ids.shape
        device = input_ids.device

        anchor_positions, block_keep_mask = self._sample_anchor_positions(
            seq_len, loss_mask, device
        )
        noise_embedding = self._create_noise_embed(
            input_ids, anchor_positions, block_keep_mask
        )

        context_position_ids = (
            torch.arange(seq_len, device=device).unsqueeze(0).expand(bsz, -1)
        )
        draft_position_ids = self._create_position_ids(anchor_positions)
        full_position_ids = torch.cat(
            [context_position_ids, draft_position_ids], dim=1
        )

        if self.attention_backend == "flex_attention":
            dflash_attn_mask = create_dflash_block_mask(
                anchor_positions=anchor_positions,
                block_keep_mask=block_keep_mask,
                S=seq_len,
                block_size=self.block_size,
                device=device,
            )
        else:
            dflash_attn_mask = create_dflash_sdpa_mask(
                anchor_positions=anchor_positions,
                block_keep_mask=block_keep_mask,
                S=seq_len,
                block_size=self.block_size,
                device=device,
            )

        output_hidden = self.draft_model(
            position_ids=full_position_ids,
            noise_embedding=noise_embedding,
            target_hidden=hidden_states,
            attention_mask=dflash_attn_mask,
        )
        logits = self.lm_head(output_hidden)

        # --- Labels: same-position prediction (position k predicts token anchor+k) ---
        label_offsets = torch.arange(0, self.block_size, device=device).view(1, 1, -1)
        label_indices = anchor_positions.unsqueeze(-1) + label_offsets
        valid_label_mask = label_indices < seq_len
        safe_label_indices = label_indices.clamp(max=seq_len - 1)
        target_ids = torch.gather(
            input_ids.unsqueeze(1).expand(-1, anchor_positions.size(1), -1),
            2,
            safe_label_indices,
        )

        # --- Weight mask: block validity * bounds * exclude anchor (pos 0) * loss_mask ---
        weight_mask = (
            block_keep_mask.unsqueeze(-1).expand(-1, -1, self.block_size).float()
        )
        weight_mask = weight_mask * valid_label_mask.float()
        pos_in_block = torch.arange(self.block_size, device=device).view(1, 1, -1)
        weight_mask = weight_mask * (pos_in_block > 0).float()
        original_loss_mask_gathered = torch.gather(
            loss_mask.unsqueeze(1).expand(-1, anchor_positions.size(1), -1),
            2,
            safe_label_indices,
        )
        weight_mask = weight_mask * original_loss_mask_gathered
        # Captured before loss decay so accuracy counts every supervised token equally.
        binary_eval_mask = weight_mask.view(-1)

        # --- Loss decay: exp(-(k-1)/γ) so k=1 (1st prediction) gets weight 1.0 ---
        if self.loss_decay_gamma is not None and self.loss_decay_gamma > 0:
            k = torch.arange(self.block_size, device=device).view(1, 1, -1)
            decay_weights = torch.exp(
                -(k - 1).clamp(min=0).float() / self.loss_decay_gamma
            )
            weight_mask = weight_mask * decay_weights
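        # Worked example (assumed values, for illustration only): with block_size=5
        # and loss_decay_gamma=2.0, decay_weights is roughly
        # [1.000, 1.000, 0.607, 0.368, 0.223]; position 0 is the anchor and has
        # already been zeroed out by weight_mask above.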

        # --- Cross entropy ---
        flat_logits = logits.view(-1, logits.size(-1))
        flat_targets = target_ids.view(-1)
        flat_weights = weight_mask.view(-1)
        loss_per_token = F.cross_entropy(flat_logits, flat_targets, reduction="none")
        valid_token_count = flat_weights.sum() + 1e-6
        loss = (loss_per_token * flat_weights).sum() / valid_token_count

        # --- Accuracy ---
        with torch.no_grad():
            pred_ids = torch.argmax(flat_logits, dim=-1)
            correct = (pred_ids == flat_targets) & (binary_eval_mask > 0.5)
            actual_token_count = binary_eval_mask.sum() + 1e-6
            accuracy = correct.sum().float() / actual_token_count

        return loss, accuracy
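

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only; not part of the original module).
    # `_StubDraft` is a hypothetical stand-in for DFlashDraftModel: it only mimics the
    # call signature and output shape (bsz, n_blocks * block_size, hidden), so this
    # exercises the masking, label, and loss plumbing rather than a real draft model.
    torch.manual_seed(0)
    vocab, hidden, seq_len, bsz = 128, 32, 40, 2

    class _StubDraft(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(hidden, hidden)

        def forward(self, position_ids, noise_embedding, target_hidden, attention_mask):
            # A real draft model would attend over [context | draft blocks] using
            # `attention_mask`; the stub just projects the noise embeddings.
            return self.proj(noise_embedding)

    model = OnlineDFlashModel(
        draft_model=_StubDraft(),
        target_lm_head=nn.Linear(hidden, vocab, bias=False),
        target_embed_tokens=nn.Embedding(vocab, hidden),
        mask_token_id=0,
        block_size=4,
        attention_backend="sdpa",  # avoids requiring flex_attention support
        num_anchors=8,
    )
    input_ids = torch.randint(1, vocab, (bsz, seq_len))
    hidden_states = torch.randn(bsz, seq_len, hidden)
    loss_mask = torch.ones(bsz, seq_len)

    loss, accuracy = model(input_ids, hidden_states, loss_mask)
    loss.backward()
    print(f"loss={loss.item():.4f} accuracy={accuracy.item():.4f}")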