import torch
import torch._dynamo as dynamo
from torch.nn.attention.flex_attention import (
    create_block_mask,
    flex_attention,
    or_masks,
)
from transformers.utils import is_torchdynamo_compiling

dynamo.config.recompile_limit = 64


# Reference Implementation https://github.com/huggingface/transformers/blob/main/src/transformers/integrations/flex_attention.py
class WrappedFlexAttention:
    """
    Singleton class so that flex attention is compiled only once, the first time it is called.
    """

    _instance = None
    _is_flex_compiled = False
    _compiled_flex_attention = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            # Create a new instance if one doesn't already exist
            cls._instance = super().__new__(cls)
        return cls._instance

    @torch.compiler.disable(recursive=False)
    def __init__(self):
        """
        Initialize or update the singleton instance.
        """
        if not self._is_flex_compiled:
            # Dynamic shapes let the single compiled kernel handle different input sizes
            self._compiled_flex_attention = torch.compile(
                flex_attention,
                # mode="max-autotune-no-cudagraphs",
            )
            self._is_flex_compiled = True

    def __call__(self):
        return self._compiled_flex_attention


def compile_friendly_flex_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    **kwargs,
) -> torch.Tensor:
    # The first call initializes the singleton wrapper object; the second call invokes
    # the object's __call__ to return the compiled flex attention.
    # Do not use the compiled version if the forward pass is already being compiled
    # (it raises issues).
    flex_attention_compiled = (
        WrappedFlexAttention()() if not is_torchdynamo_compiling() else flex_attention
    )
    return flex_attention_compiled(
        query,
        key,
        value,
        **kwargs,
    )


class WrappedCreateBlockMask:
    _instance = None
    _is_create_block_mask_compiled = False
    _compiled_create_block_mask = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    @torch.compiler.disable(recursive=False)
    def __init__(self):
        if not self._is_create_block_mask_compiled:
            self._compiled_create_block_mask = torch.compile(create_block_mask)
            self._is_create_block_mask_compiled = True

    def __call__(self):
        return self._compiled_create_block_mask


def compile_friendly_create_block_mask(
    mask_mod,
    B,
    H,
    Q_LEN,
    KV_LEN,
    device,
):
    create_block_mask_compiled = (
        WrappedCreateBlockMask()()
        if not is_torchdynamo_compiling()
        else create_block_mask
    )
    return create_block_mask_compiled(
        mask_mod,
        B,
        H,
        Q_LEN,
        KV_LEN,
        device,
    )


def generate_eagle3_mask(
    seq_lengths: torch.Tensor, Q_LEN: int, KV_LEN: int, lck: int = 0
):
    """
    Build a flex-attention mask_mod for EAGLE-3 that combines a padded causal mask
    over the first Q_LEN positions with a diagonal mask over the appended suffix
    positions. `lck` is only folded into the generated mask_mod's name.
    """

    def causal_mask(b, h, q_idx, kv_idx):
        # The causal region keeps shrinking by one diagonal as suffix tokens are
        # appended to the KV cache.
        causal_mask = q_idx >= kv_idx
        # Mask out padding positions beyond the sample's valid sequence length.
        padding_mask = (kv_idx < seq_lengths[b]) & (q_idx < seq_lengths[b])
        return causal_mask & padding_mask

    def suffix_mask(b, h, q_idx, kv_idx):
        # Attend to appended suffix positions (beyond the original Q_LEN) ...
        suffix_mask = kv_idx >= Q_LEN
        # ... that are not padding within their Q_LEN-sized segment ...
        padding_mask = kv_idx % Q_LEN < seq_lengths[b]
        # ... and that lie on the same diagonal as the query token.
        diagonal_mask = (kv_idx - q_idx) % Q_LEN == 0
        return suffix_mask & padding_mask & diagonal_mask

    mask_mod = or_masks(causal_mask, suffix_mask)
    mask_mod.__name__ = f"eagle3_mask_Q_{Q_LEN}_KV_{KV_LEN}_lck_{lck}"
    return mask_mod
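

# --- Usage sketch (illustrative only, not part of the reference implementation) ---
# A minimal sketch of how the helpers above could be wired together, assuming a
# padded batch where Q_LEN query tokens attend to a KV cache of length KV_LEN
# (the original tokens plus one appended suffix segment). All shapes, dtypes, and
# lengths below are hypothetical and chosen purely for illustration.
if __name__ == "__main__":
    B, H, Q_LEN, KV_LEN, D = 2, 8, 128, 256, 64  # assumed sizes
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Per-sample valid lengths; positions at or beyond these are treated as padding.
    seq_lengths = torch.tensor([100, 128], device=device)

    # Build the combined causal + suffix mask_mod and the corresponding block mask.
    mask_mod = generate_eagle3_mask(seq_lengths, Q_LEN=Q_LEN, KV_LEN=KV_LEN)
    block_mask = compile_friendly_create_block_mask(
        mask_mod, B, H, Q_LEN, KV_LEN, device
    )

    q = torch.randn(B, H, Q_LEN, D, device=device)
    k = torch.randn(B, H, KV_LEN, D, device=device)
    v = torch.randn(B, H, KV_LEN, D, device=device)

    # block_mask is forwarded to flex_attention via **kwargs.
    out = compile_friendly_flex_attention(q, k, v, block_mask=block_mask)
    print(out.shape)  # expected: (B, H, Q_LEN, D)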