| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import math |
| | from typing import Optional, Union |
| |
|
| | import torch |
| | from torch import nn |
| |
|
| | from .ar_modules_embedding import RotaryPositionEmbedding |
| | from .ar_modules_normalization import create_norm |
| |
|
| |
|
| | class Attention(nn.Module): |
| | """ |
| | Attenion layer with KV cache. |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | n_heads: int, |
| | n_kv_heads: Union[int, None], |
| | dim: int, |
| | max_batch_size: int, |
| | max_seq_len: int, |
| | context_dim: Optional[int] = None, |
| | use_qk_normalization: bool = False, |
| | norm_type: str = "rmsnorm", |
| | norm_eps: float = 1e-5, |
| | causal_mask: Optional[bool] = True, |
| | head_dim: Optional[int] = None, |
| | fuse_qkv: bool = False, |
| | precision: str = "bfloat16", |
| | attn_type: str = "self", |
| | ): |
| | """ |
| | Initializes the GQA module. |
| | |
| | Args: |
| | n_heads (int): The number of attention heads. |
| | n_kv_heads (int, optional): The number of key-value attention heads. None defaults to n_heads. |
| | dim (int): The dimensionality of the input and output. |
| | max_batch_size (int): The maximum batch size. |
| | max_seq_len (int): The maximum sequence length. |
| | context_dim (int, optional): The dimensionality of the context for cross-attn. Defaults to None. |
| | use_qk_normalization (bool, optional): Whether to apply QK normalization. Defaults to False. |
| | norm_type (str, optional): The type of normalization layer. Defaults to "rmsnorm". |
| | norm_eps (float, optional): The epsilon value for normalization. Defaults to 1e-5. |
| | causal_mask (bool, optional): Whether to use causal mask. Defaults to True. |
| | head_dim (int, optional): The dimensionality of each attention head. If None, defaults to dim // n_heads. |
| | fuse_qkv (bool, optional): Whether to fuse QKV. Defaults to False. |
| | precision (str, optional): The precision of the module. Defaults to "bfloat16". |
| | attn_type (str, optional): The type of attention. Defaults to "self". |
| | """ |
| | super().__init__() |
| | assert attn_type in ["self", "cross", "full"], f"Invalid attention type: {attn_type}" |
| | self.attn_type = attn_type |
| | context_dim = dim if context_dim is None else context_dim |
| |
|
| | self.dim = dim |
| | self.context_dim = context_dim |
| | self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads |
| | self.n_local_kv_heads = self.n_kv_heads |
| | self.n_local_heads = n_heads |
| | self.n_rep = self.n_local_heads // self.n_local_kv_heads |
| | self.head_dim = dim // n_heads if head_dim is None else head_dim |
| | self.causal_mask = causal_mask |
| | self.fuse_qkv = fuse_qkv |
| | self.precision = precision |
| |
|
| | if fuse_qkv: |
| | assert context_dim == dim, f"Fuse QKV requires context_dim ({context_dim}) to be equal to dim ({dim})" |
| | self.total_local_head_dim = (self.n_local_heads + 2 * self.n_local_kv_heads) * self.head_dim |
| | self.wqkv = nn.Linear(dim, self.total_local_head_dim, bias=False) |
| | |
| | self._register_load_state_dict_pre_hook(self.load_hook) |
| | else: |
| | self.wq = nn.Linear(dim, self.n_local_heads * self.head_dim, bias=False) |
| | self.wk = nn.Linear(context_dim, self.n_local_kv_heads * self.head_dim, bias=False) |
| | self.wv = nn.Linear(context_dim, self.n_local_kv_heads * self.head_dim, bias=False) |
| | self.wo = nn.Linear(self.n_local_heads * self.head_dim, dim, bias=False) |
| |
|
| | self.max_batch_size = max_batch_size |
| | self.max_seq_len = max_seq_len |
| |
|
| | if self.attn_type == "self": |
| | |
| | self.init_kv_cache() |
| |
|
| | |
| | if use_qk_normalization: |
| | self.q_norm = create_norm(norm_type, dim=self.head_dim, eps=norm_eps) |
| | self.k_norm = create_norm(norm_type, dim=self.head_dim, eps=norm_eps) |
| |
|
| | self.use_qk_normalization = use_qk_normalization |
| |
|
| | self.to(dtype=getattr(torch, self.precision)) |
| |
|
| | def load_hook(self, state_dict, prefix, *args): |
| | if prefix + "wq.weight" in state_dict: |
| | wq = state_dict.pop(prefix + "wq.weight") |
| | wk = state_dict.pop(prefix + "wk.weight") |
| | wv = state_dict.pop(prefix + "wv.weight") |
| | state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) |
| |
|
| | def init_kv_cache(self, dtype=None): |
| | cache_shape = (self.max_batch_size, self.n_local_kv_heads, self.max_seq_len, self.head_dim) |
| | if dtype is None: |
| | dtype = getattr(torch, self.precision) |
| | if self.attn_type == "self": |
| | self.cache_k = torch.zeros(cache_shape, dtype=dtype).cuda() |
| | self.cache_v = torch.zeros(cache_shape, dtype=dtype).cuda() |
| |
|
| | def forward( |
| | self, |
| | x: torch.Tensor, |
| | rope: RotaryPositionEmbedding, |
| | input_pos: torch.Tensor, |
| | mask: Optional[torch.Tensor] = None, |
| | context: Optional[torch.Tensor] = None, |
| | ): |
| | """ |
| | Forward pass of GQA. |
| | |
| | Args: |
| | x: The input tensor of shape (batch_size, seq_len, dim). |
| | rope: The rotary positional embedding module. |
| | input_pos: The starting position of the current sequence. |
| | mask: The attention mask tensor. |
| | context: The context tensor of shape (batch_size, context_len, dim). |
| | |
| | Returns: |
| | The output tensor after applying GQA. |
| | """ |
| | bsz, seqlen, _ = x.shape |
| |
|
| | |
| | context = x if context is None else context |
| | context_len = seqlen if context is None else context.shape[1] |
| |
|
| | if self.fuse_qkv: |
| | q_size = self.n_local_heads * self.head_dim |
| | kv_size = self.n_local_kv_heads * self.head_dim |
| | xq, xk, xv = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1) |
| | else: |
| | |
| | xq, xk, xv = self.wq(x), self.wk(context), self.wv(context) |
| |
|
| | |
| | xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) |
| | xk = xk.view(bsz, context_len, self.n_local_kv_heads, self.head_dim) |
| | xv = xv.view(bsz, context_len, self.n_local_kv_heads, self.head_dim) |
| |
|
| | |
| | if self.use_qk_normalization: |
| | xq = self.q_norm(xq) |
| | xk = self.k_norm(xk) |
| |
|
| | |
| | |
| | if self.attn_type in ["self", "full"]: |
| | xq, xk = rope(xq, xk, input_pos, seqlen) |
| |
|
| | xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv)) |
| | |
| | |
| | |
| | if self.attn_type == "self": |
| | |
| | assert input_pos is not None |
| | self.cache_k[:bsz, :, input_pos] = xk |
| | self.cache_v[:bsz, :, input_pos] = xv |
| | keys, values = ( |
| | self.cache_k[:bsz, :, :], |
| | self.cache_v[:bsz, :, :], |
| | ) |
| | else: |
| | keys, values = xk, xv |
| |
|
| | |
| | keys = keys.repeat_interleave(self.n_rep, dim=1) |
| | values = values.repeat_interleave(self.n_rep, dim=1) |
| |
|
| | |
| | |
| | |
| | is_causal = False |
| | output = scaled_dot_product_attention( |
| | xq, |
| | keys, |
| | values, |
| | head_dim=self.head_dim, |
| | mask=mask, |
| | is_causal=is_causal, |
| | dropout_p=0.0, |
| | ) |
| | output = output.view(bsz, seqlen, -1) |
| | output = self.wo(output) |
| | return output |
| |
|
| |
|
| | def scaled_dot_product_attention( |
| | q: torch.Tensor, |
| | k: torch.Tensor, |
| | v: torch.Tensor, |
| | head_dim: int, |
| | mask: Optional[torch.Tensor] = None, |
| | is_causal: Optional[bool] = None, |
| | dropout_p: float = 0.0, |
| | ) -> torch.Tensor: |
| | """ |
| | PyTorch's native implementation of Flash Attention 2. |
| | |
| | If `is_causal` is given, then the causal attention mask is applied accordingly: |
| | - If `is_causal` is True, the standard upper-left causal attention masking is applied. |
| | - If `is_causal` is False, no attention mask is applied, unless an explicit mask tensor is |
| | provided (i.e., `mask is not None`). |
| | |
| | If `is_causal` is not given (i.e., `is_causal is None`), then the attention mask is applied |
| | based on the provided mask tensor: |
| | - If no explicit attention mask is given (i.e., `mask is None`), `is_causal` is set to True, |
| | leading to the standard upper-left causal attention masking. |
| | - If an attention mask is given (i.e., `mask is not None`), the provided mask is used, |
| | and `is_causal` is set to False. |
| | |
| | Args: |
| | q (torch.Tensor): Query tensor |
| | k (torch.Tensor): Key tensor |
| | v (torch.Tensor): Value tensor |
| | head_dim (int): Dimension of each attention head |
| | mask (Optional[torch.Tensor], optional): Attention mask. Defaults to None. |
| | is_causal (Optional[bool], optional): Whether to apply causal attention mask. Defaults to None. |
| | dropout_p (float, optional): Dropout rate. Defaults to 0.0. |
| | |
| | Returns: |
| | torch.Tensor: Output tensor after applying scaled dot-product attention |
| | """ |
| | scale = 1.0 / math.sqrt(head_dim) |
| | if is_causal is None: |
| | is_causal = mask is None |
| | y = torch.nn.functional.scaled_dot_product_attention( |
| | q, |
| | k, |
| | v, |
| | attn_mask=mask, |
| | dropout_p=dropout_p, |
| | scale=scale, |
| | is_causal=is_causal, |
| | ) |
| | return y.transpose(1, 2).contiguous() |
| |
|