from typing import Callable, Optional

import torch
from torch import nn
from transformers import DynamicCache
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.qwen3.modeling_qwen3 import (
    ALL_ATTENTION_FUNCTIONS,
    FlashAttentionKwargs,
    GradientCheckpointingLayer,
    Qwen3Config,
    Qwen3MLP,
    Qwen3PreTrainedModel,
    Qwen3RMSNorm,
    Qwen3RotaryEmbedding,
    eager_attention_forward,
    rotate_half,
)
from typing_extensions import Tuple, Unpack


def sample(logits: torch.Tensor, temperature: float = 0.0) -> torch.Tensor:
    """Greedy decoding at (near-)zero temperature, otherwise temperature-scaled multinomial sampling."""
    if temperature < 1e-5:
        return torch.argmax(logits, dim=-1)
    bsz, seq_len, vocab_size = logits.shape
    logits = logits.view(-1, vocab_size)
    logits = logits / temperature
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).view(bsz, seq_len)


def apply_rotary_pos_emb(
    q,
    k,
    q_cos,
    q_sin,
    k_cos=None,
    k_sin=None,
    position_ids=None,
    unsqueeze_dim=1,
):
    q_cos = q_cos.unsqueeze(unsqueeze_dim)
    q_sin = q_sin.unsqueeze(unsqueeze_dim)
    if k_cos is None:
        k_cos = q_cos
        k_sin = q_sin
    else:
        k_cos = k_cos.unsqueeze(unsqueeze_dim)
        k_sin = k_sin.unsqueeze(unsqueeze_dim)
    q_len = q.size(-2)
    q_embed = (q * q_cos[..., -q_len:, :]) + (
        rotate_half(q) * q_sin[..., -q_len:, :]
    )
    k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
    return q_embed, k_embed


class Qwen3DFlashAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(
            config, "head_dim", config.hidden_size // config.num_attention_heads
        )
        self.num_key_value_groups = (
            config.num_attention_heads // config.num_key_value_heads
        )
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = False
        self.q_proj = nn.Linear(
            config.hidden_size,
            config.num_attention_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.k_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.sliding_window = (
            config.sliding_window
            if config.layer_types[layer_idx] == "sliding_attention"
            else None
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        target_hidden: torch.Tensor,
        position_embeddings: tuple[
            tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]
        ],
        attention_mask: Optional[torch.Tensor],
        kv_hidden_states: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        update_kv_cache: bool = True,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        bsz, q_len = hidden_states.shape[:-1]
        ctx_len = target_hidden.shape[1]
        if kv_hidden_states is None:
            kv_hidden_states = hidden_states

        # Queries come from the noise block; keys/values attend over the fused
        # target-context features followed by the noise block.
        q = self.q_proj(hidden_states)
        q = q.view(bsz, q_len, -1, self.head_dim)
        q = self.q_norm(q).transpose(1, 2)

        k_ctx = self.k_proj(target_hidden)
        k_noise = self.k_proj(kv_hidden_states)
        v_ctx = self.v_proj(target_hidden)
        v_noise = self.v_proj(kv_hidden_states)
        k = torch.cat([k_ctx, k_noise], dim=1).view(
            bsz, ctx_len + q_len, -1, self.head_dim
        )
        v = torch.cat([v_ctx, v_noise], dim=1).view(
            bsz, ctx_len + q_len, -1, self.head_dim
        )
        k = self.k_norm(k).transpose(1, 2)
        v = v.transpose(1, 2)

        # Rotary embeddings are computed separately for the query (noise block)
        # positions and for the key positions (context + noise).
        (q_cos, q_sin), (k_cos, k_sin) = position_embeddings
        q, k = apply_rotary_pos_emb(q, k, q_cos, q_sin, k_cos, k_sin)

        if past_key_values is not None:
            if update_kv_cache:
                cache_kwargs = {
                    "sin": k_sin,
                    "cos": k_cos,
                    "cache_position": cache_position,
                }
                k, v = past_key_values.update(k, v, self.layer_idx, cache_kwargs)
            elif self.layer_idx < len(past_key_values.layers):
                # Read-only path: prepend cached keys/values without writing to the cache.
                cache_layer = past_key_values.layers[self.layer_idx]
                if cache_layer.get_seq_length() > 0:
                    k = torch.cat([cache_layer.keys, k], dim=-2)
                    v = torch.cat([cache_layer.values, v], dim=-2)

        attn_fn: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attn_fn = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attn_fn(
            self,
            q,
            k,
            v,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            **kwargs,
        )
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class Qwen3DFlashDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: Qwen3Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen3DFlashAttention(config=config, layer_idx=layer_idx)
        self.mlp = Qwen3MLP(config)
        self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        target_hidden: Optional[torch.Tensor] = None,
        hidden_states: Optional[torch.Tensor] = None,
        kv_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        update_kv_cache: bool = True,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            Tuple[torch.Tensor, torch.Tensor]
        ] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        if kv_hidden_states is None:
            kv_hidden_states = hidden_states
        else:
            kv_hidden_states = self.input_layernorm(kv_hidden_states)

        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            target_hidden=target_hidden,
            kv_hidden_states=kv_hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            update_kv_cache=update_kv_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )[0]
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


def build_target_layer_ids(num_target_layers: int, num_draft_layers: int):
    """Pick one target-model layer index per draft layer, spaced roughly evenly."""
    if num_draft_layers == 1:
        return [num_target_layers // 2]
    start = 1
    end = num_target_layers - 3
    span = end - start
    target_layer_ids = [
        int(round(start + (i * span) / (num_draft_layers - 1)))
        for i in range(num_draft_layers)
    ]
    return target_layer_ids


def extract_context_feature(
    hidden_states: list[torch.Tensor],
    layer_ids: Optional[list[int]],
) -> torch.Tensor:
    # hidden_states[0] is the embedding output, so layer i lives at index i + 1.
    offset = 1
    selected_states = []
    for layer_id in layer_ids:
        selected_states.append(hidden_states[layer_id + offset])
    target_hidden = torch.cat(selected_states, dim=-1)
    return target_hidden


class DFlashDraftModel(Qwen3PreTrainedModel):
    config_class = Qwen3Config
    _no_split_modules = ["Qwen3DFlashDecoderLayer"]

    def __init__(self, config) -> None:
        super().__init__(config)
        self.config = config
        self.layers = nn.ModuleList(
            [
                Qwen3DFlashDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        dflash_config = getattr(config, "dflash_config", {}) or {}
        self.target_layer_ids = dflash_config.get(
            "target_layer_ids",
            build_target_layer_ids(config.num_target_layers, config.num_hidden_layers),
        )
        self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3RotaryEmbedding(config)
        # Fuses the concatenated target-layer hidden states into one context feature.
        self.fc = nn.Linear(
            len(self.target_layer_ids) * config.hidden_size,
            config.hidden_size,
            bias=False,
        )
        self.hidden_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.block_size = config.block_size
        self.mask_token_id = dflash_config.get("mask_token_id", None)
        self.post_init()

    def _resolve_position_ids(
        self,
        position_ids: Optional[torch.LongTensor],
        noise_position_ids: Optional[torch.LongTensor],
        kv_position_ids: Optional[torch.LongTensor],
        noise_len: int,
        ctx_len: int,
    ) -> tuple[torch.LongTensor, torch.LongTensor]:
        if position_ids is not None:
            if kv_position_ids is None:
                kv_position_ids = position_ids
            if noise_position_ids is None:
                noise_position_ids = position_ids[:, -noise_len:]
        if noise_position_ids is None:
            raise ValueError(
                "DFlash forward requires noise_position_ids or position_ids."
            )
        if kv_position_ids is None:
            if ctx_len == 0:
                kv_position_ids = noise_position_ids
            else:
                raise ValueError(
                    "DFlash forward requires kv_position_ids for context+noise attention."
                )
        expected_kv_len = ctx_len + noise_len
        if noise_position_ids.shape[1] != noise_len:
            raise ValueError(
                f"noise_position_ids length {noise_position_ids.shape[1]} does not match noise length {noise_len}."
            )
        if kv_position_ids.shape[1] != expected_kv_len:
            raise ValueError(
                f"kv_position_ids length {kv_position_ids.shape[1]} does not match expected KV length {expected_kv_len}."
            )
        return noise_position_ids, kv_position_ids

    def forward(
        self,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        noise_embedding: Optional[torch.Tensor] = None,
        kv_noise_embedding: Optional[torch.Tensor] = None,
        target_hidden: Optional[torch.Tensor] = None,
        noise_position_ids: Optional[torch.LongTensor] = None,
        kv_position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: bool = False,
        **kwargs,
    ) -> torch.Tensor:
        hidden_states = noise_embedding
        kv_hidden_states = kv_noise_embedding
        target_hidden = self.hidden_norm(self.fc(target_hidden))

        noise_position_ids, kv_position_ids = self._resolve_position_ids(
            position_ids=position_ids,
            noise_position_ids=noise_position_ids,
            kv_position_ids=kv_position_ids,
            noise_len=hidden_states.shape[1],
            ctx_len=target_hidden.shape[1],
        )
        position_embeddings = (
            self.rotary_emb(hidden_states, noise_position_ids),
            self.rotary_emb(hidden_states, kv_position_ids),
        )

        for layer in self.layers:
            if kv_hidden_states is None:
                hidden_states = layer(
                    hidden_states=hidden_states,
                    target_hidden=target_hidden,
                    attention_mask=attention_mask,
                    position_ids=kv_position_ids,
                    past_key_value=past_key_values,
                    use_cache=use_cache,
                    update_kv_cache=use_cache,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )
            else:
                # Two-stream update: the main stream reads the separate KV stream
                # without touching the cache, then the KV stream is advanced.
                hidden_states = layer(
                    hidden_states=hidden_states,
                    target_hidden=target_hidden,
                    kv_hidden_states=kv_hidden_states,
                    attention_mask=attention_mask,
                    position_ids=kv_position_ids,
                    past_key_value=past_key_values,
                    use_cache=use_cache,
                    update_kv_cache=False,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )
                kv_hidden_states = layer(
                    hidden_states=kv_hidden_states,
                    target_hidden=target_hidden,
                    attention_mask=attention_mask,
                    position_ids=kv_position_ids,
                    past_key_value=past_key_values,
                    use_cache=use_cache,
                    update_kv_cache=use_cache,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )
        return self.norm(hidden_states)

    @torch.inference_mode()
    def spec_generate(
        self,
        target: nn.Module,
        input_ids: torch.LongTensor,
        max_new_tokens: int,
        stop_token_ids: list[int],
        temperature: float,
        num_denoise_steps: int = 1,
    ):
        self.eval()
        num_input_tokens = input_ids.shape[1]
        max_length = num_input_tokens + max_new_tokens
        block_size = self.block_size

        output_ids = torch.full(
            (1, max_length + block_size),
            self.mask_token_id,
            dtype=torch.long,
            device=target.device,
        )
        position_ids = torch.arange(
            output_ids.shape[1], device=target.device
        ).unsqueeze(0)
        past_key_values_target = DynamicCache()
        past_key_values_draft = DynamicCache()

        # Prefill stage
        output = target(
            input_ids,
            position_ids=position_ids[:, :num_input_tokens],
            past_key_values=past_key_values_target,
            use_cache=True,
            logits_to_keep=1,
            output_hidden_states=True,
        )
        output_ids[:, :num_input_tokens] = input_ids
        output_ids[:, num_input_tokens : num_input_tokens + 1] = sample(
            output.logits, temperature
        )
        target_hidden = extract_context_feature(
            output.hidden_states, self.target_layer_ids
        )

        # Decode stage
        acceptance_lengths = []
        start = input_ids.shape[1]
        while start < max_length:
            block_output_ids = output_ids[:, start : start + block_size].clone()
            block_position_ids = position_ids[:, start : start + block_size]
            draft_cache_prefix_len = past_key_values_draft.get_seq_length()
            draft_kv_position_ids = position_ids[
                :, draft_cache_prefix_len : start + block_size
            ]
            mask_noise_embedding = target.model.embed_tokens(block_output_ids)

            # Multi-step denoising loop
            for denoise_step in range(num_denoise_steps):
                noise_embedding = mask_noise_embedding
                if denoise_step > 0:
                    # Blend the mask embedding toward the embedding of the current
                    # draft prediction as denoising progresses.
                    pred_noise_embedding = target.model.embed_tokens(block_output_ids)
                    mix_weight = denoise_step / num_denoise_steps
                    noise_embedding = torch.lerp(
                        mask_noise_embedding, pred_noise_embedding, mix_weight
                    )

                draft_hidden = self(
                    target_hidden=target_hidden,
                    noise_embedding=noise_embedding,
                    noise_position_ids=block_position_ids,
                    kv_position_ids=draft_kv_position_ids,
                    past_key_values=past_key_values_draft,
                    use_cache=True,
                    is_causal=False,
                )[:, -block_size + 1 :, :]
                draft_logits = target.lm_head(draft_hidden)
                block_output_ids[:, 1:] = sample(draft_logits)

                if denoise_step + 1 < num_denoise_steps:
                    # Reuse the accepted-prefix cache, but rebuild the current block
                    # on the next denoise step.
                    past_key_values_draft.crop(draft_cache_prefix_len)
                    past_key_values_draft.crop(start)

            output = target(
                block_output_ids,
                position_ids=block_position_ids,
                past_key_values=past_key_values_target,
                use_cache=True,
                output_hidden_states=True,
            )
            posterior = sample(output.logits, temperature)
            acceptance_length = (
                (block_output_ids[:, 1:] == posterior[:, :-1])
                .cumprod(dim=1)
                .sum(dim=1)[0]
                .item()
            )
            output_ids[:, start : start + acceptance_length + 1] = block_output_ids[
                :, : acceptance_length + 1
            ]
            output_ids[:, start + acceptance_length + 1] = posterior[
                :, acceptance_length
            ]
            start += acceptance_length + 1
            past_key_values_target.crop(start)
            target_hidden = extract_context_feature(
                output.hidden_states, self.target_layer_ids
            )[:, : acceptance_length + 1, :]
            acceptance_lengths.append(acceptance_length + 1)

            if stop_token_ids is not None and any(
                stop_token_id in output_ids[:, num_input_tokens:]
                for stop_token_id in stop_token_ids
            ):
                break

        output_ids = output_ids[:, :max_length]
        output_ids = output_ids[:, output_ids[0] != self.mask_token_id]
        if stop_token_ids is not None:
            stop_token_ids = torch.tensor(stop_token_ids, device=output_ids.device)
            stop_token_indices = torch.isin(
                output_ids[0][num_input_tokens:], stop_token_ids
            ).nonzero(as_tuple=True)[0]
            if stop_token_indices.numel() > 0:
                output_ids = output_ids[
                    :, : num_input_tokens + stop_token_indices[0] + 1
                ]
        return output_ids, acceptance_lengths
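

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module above).
# It assumes a Qwen3 target checkpoint ("Qwen/Qwen3-8B"), a hypothetical draft
# checkpoint path ("path/to/dflash-draft"), and that the draft config carries
# the fields read above (block_size, num_target_layers, and dflash_config with
# mask_token_id / optional target_layer_ids).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Hypothetical checkpoints, for illustration only.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
    target = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen3-8B", torch_dtype=torch.bfloat16
    ).to(device)
    draft = DFlashDraftModel.from_pretrained(
        "path/to/dflash-draft", torch_dtype=torch.bfloat16
    ).to(device)

    prompt = "Explain speculative decoding in one paragraph."
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    output_ids, acceptance_lengths = draft.spec_generate(
        target=target,
        input_ids=input_ids,
        max_new_tokens=128,
        stop_token_ids=[tokenizer.eos_token_id],
        temperature=0.0,
        num_denoise_steps=1,
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
    print(
        "mean accepted tokens per block:",
        sum(acceptance_lengths) / len(acceptance_lengths),
    )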