File size: 3,635 Bytes
b6ff324
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# ------------------------------------------------------------------------
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Flex attention layers."""

from itertools import accumulate
from typing import Callable, List, Optional

import torch
from torch import nn

try:
    from torch.nn.attention.flex_attention import create_block_mask
    from torch.nn.attention.flex_attention import flex_attention
except ImportError:
    flex_attention = create_block_mask = None


class FlexAttentionCausal2D(nn.Module):
    """Block-wise causal flex attention.

    Every query attends causally (``qi >= ki``) over the whole sequence. Queries
    additionally attend to all keys of their own block, unless per-block flags
    (see ``set_offsets_by_lens``) mark that block as lower triangular.

    The block mask is built lazily on the first ``forward`` and cached. It is
    rebuilt when the offsets, the flags, or the query batch/heads/length/device
    change.
    """

    def __init__(self):
        super().__init__()
        self.attn_func = self.offsets = self.flags = None
        self.cu_offsets = self.block_mask = None
        # (batch, heads, q_len, device) that the cached ``block_mask`` was built for.
        self.block_mask_key = None

    def set_offsets(self, offsets: List[int]):
        """Set block-wise mask offsets.

        Args:
            offsets: Cumulative block boundaries, e.g. ``[2, 5]`` or ``[0, 2, 5]``.
                A leading ``0`` is prepended when missing. Any iterable of
                integers (list, tuple, 1-D tensor) is accepted.
        """
        # Normalize to a plain int list. For tensors, ``type(x)([0]) + x`` would
        # add element-wise instead of concatenating.
        offsets = [int(x) for x in offsets]
        if offsets[0] != 0:
            offsets = [0] + offsets
        if offsets != self.offsets:
            self.offsets, self.block_mask = offsets, None

    def set_offsets_by_lens(self, lens: List[int], flags: Optional[List[int]] = None):
        """Set block-wise mask offsets by lengths.

        Args:
            lens: Length of each block. A leading ``0`` is taken as the starting
                offset rather than as an empty block.
            flags: Optional per-block bidirectional flags (-1: lower triangular,
                1: full). Blocks without a flag default to -1.
        """
        lens = [int(n) for n in lens]
        if lens[0] != 0:
            lens = [0] + lens
        self.set_offsets(list(accumulate(lens)))
        flags = None if flags is None else [int(f) for f in flags]
        if flags != self.flags:
            # The mask also depends on the flags, so a flags-only change must
            # invalidate the cached block mask.
            self.flags, self.block_mask = flags, None

    def get_mask_mod(self) -> Callable:
        """Return the ``mask_mod(b, h, qi, ki)`` passed to ``create_block_mask``.

        Reads ``self.cu_offsets`` (int32 block boundaries), which
        ``get_block_mask`` sets before calling this.
        """
        counts = self.cu_offsets[1:] - self.cu_offsets[:-1]
        num_blocks = len(counts)
        ids = torch.arange(num_blocks, device=self.cu_offsets.device, dtype=torch.int32)
        ids = ids.repeat_interleave(counts)  # Block id of every token.
        if self.flags is None:
            return lambda b, h, qi, ki: (qi >= ki) | (ids[qi] == ids[ki])
        # Exactly one flag per block: pad missing ones with -1 and drop extras.
        flags = (list(self.flags) + [-1] * num_blocks)[:num_blocks]
        flags = torch.as_tensor(flags, device=self.cu_offsets.device, dtype=torch.int32)
        flags = flags.repeat_interleave(counts)  # Flag of every token's block.
        # With flag 1, ``ids[qi] * 1 == ids[ki]`` is the same-block test. With
        # flag -1 it only holds when both ids are 0.
        # NOTE(review): block 0 therefore stays bidirectional even when flagged -1,
        # and flags other than +/-1 can match other blocks. Confirm this is intended.
        return lambda b, h, qi, ki: (qi >= ki) | ((ids[qi] * flags[qi]) == ids[ki])

    def get_attn_func(self) -> Callable:
        """Return the (lazily compiled) attention function.

        Raises:
            NotImplementedError: If flex attention is unavailable (torch<2.5).
        """
        if flex_attention is None:
            raise NotImplementedError(f"FlexAttn requires torch>=2.5 but got {torch.__version__}")
        if self.attn_func is None:
            self.attn_func = torch.compile(flex_attention)
        return self.attn_func

    def get_block_mask(self, q: torch.Tensor) -> torch.Tensor:
        """Return the attention block mask according to inputs.

        Args:
            q: Query of shape ``(B, H, L, E)``, as expected by ``flex_attention``.
                Keys are assumed to have the same length ``L``.

        Returns:
            The ``BlockMask`` produced by ``create_block_mask``. It is cached
            until the offsets, flags, or the query's B/H/L/device change.

        Raises:
            NotImplementedError: If flex attention is unavailable (torch<2.5).
        """
        b, h, q_len = q.shape[:3]
        key = (b, h, q_len, q.device)
        if self.block_mask is not None and key == self.block_mask_key:
            return self.block_mask
        if create_block_mask is None:
            raise NotImplementedError(f"FlexAttn requires torch>=2.5 but got {torch.__version__}")
        self.cu_offsets = torch.as_tensor(self.offsets, device=q.device, dtype=torch.int32)
        # Build on q's device; create_block_mask otherwise defaults to "cuda".
        args = {"B": b, "H": h, "Q_LEN": q_len, "KV_LEN": q_len, "device": q.device}
        self.block_mask = create_block_mask(self.get_mask_mod(), _compile=True, **args)
        self.block_mask_key = key
        return self.block_mask

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Apply block-wise causal attention (``enable_gqa=True`` lets k/v have fewer heads)."""
        return self.get_attn_func()(q, k, v, block_mask=self.get_block_mask(q), enable_gqa=True)