| | from typing import * |
| | from enum import Enum |
| | import torch |
| | import math |
| | from .. import SparseTensor |
| | from .. import DEBUG, ATTN |
| |
|
| | if ATTN == 'xformers': |
| | import xformers.ops as xops |
| | elif ATTN == 'flash_attn': |
| | import flash_attn |
| | else: |
| | raise ValueError(f"Unknown attention module: {ATTN}") |
| |
|
| |
|
# Names re-exported by `from <this module> import *`.
__all__ = ['sparse_serialized_scaled_dot_product_self_attention']
| |
|
| |
|
class SerializeMode(Enum):
    """
    Orderings used to serialize sparse voxel coordinates into a 1D sequence.

    Each mode selects a curve passed to ``vox2seq.encode`` and an axis
    permutation; the ``*_TRANSPOSED`` variants swap the first two spatial axes.
    """
    Z_ORDER = 0             # z-order curve, axes (0, 1, 2)
    Z_ORDER_TRANSPOSED = 1  # z-order curve, axes (1, 0, 2)
    HILBERT = 2             # Hilbert curve, axes (0, 1, 2)
    HILBERT_TRANSPOSED = 3  # Hilbert curve, axes (1, 0, 2)
| |
|
| |
|
# Every serialization mode, in declaration order
# (Z_ORDER, Z_ORDER_TRANSPOSED, HILBERT, HILBERT_TRANSPOSED).
SerializeModes = list(SerializeMode)
| |
|
| |
|
def calc_serialization(
    tensor: SparseTensor,
    window_size: int,
    serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
    shift_sequence: int = 0,
    shift_window: Tuple[int, int, int] = (0, 0, 0)
) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
    """
    Calculate serialization and partitioning for a set of coordinates.

    The points of each batch element (one slice of ``tensor.layout``) are sorted
    by the code returned from ``vox2seq.encode`` and cut into sequences. A batch
    element with at most ``window_size`` points becomes one sequence of its own
    length. Larger batch elements are split into windows of exactly
    ``window_size`` points; each window is padded by wrapping around the sorted
    order, so a point may appear in more than one window. Only the "valid" slice
    of each window is mapped back by the backward indices.

    Args:
        tensor (SparseTensor): The input tensor. ``tensor.coords[:, 1:]`` holds the
            three spatial coordinates that are encoded; ``tensor.layout`` yields
            one slice per batch element.
        window_size (int): Number of points per window.
        serialize_mode (SerializeMode): The serialization mode to use.
        shift_sequence (int): Shift of the window boundaries along the serialized sequence.
        shift_window (Tuple[int, int, int]): Offset added to the spatial coordinates before encoding.

    Returns:
        fwd_indices (torch.Tensor): int64 indices into the tensor's rows that gather
            the serialized (and padded) sequence.
        bwd_indices (torch.Tensor): int64 indices, one per original point, giving
            that point's position in the serialized sequence.
        seq_lens (List[int]): Length of each sequence, in serialized order.
        seq_batch_indices (List[int]): Batch index each sequence belongs to.

    Raises:
        ValueError: If ``serialize_mode`` is not a known SerializeMode.
    """
    # Plain function-scope import. It is cheap after the first call because the
    # module is cached in sys.modules. The former `'vox2seq' not in globals()`
    # guard could never see this local binding, and if a global `vox2seq` ever
    # existed it would have skipped the import and raised UnboundLocalError.
    import vox2seq

    # serialize_mode -> (curve name for vox2seq.encode, axis permutation)
    encode_args = {
        SerializeMode.Z_ORDER: ('z_order', [0, 1, 2]),
        SerializeMode.Z_ORDER_TRANSPOSED: ('z_order', [1, 0, 2]),
        SerializeMode.HILBERT: ('hilbert', [0, 1, 2]),
        SerializeMode.HILBERT_TRANSPOSED: ('hilbert', [1, 0, 2]),
    }
    if serialize_mode not in encode_args:
        raise ValueError(f"Unknown serialize mode: {serialize_mode}")
    curve, permute = encode_args[serialize_mode]

    # Encode the shifted spatial coordinates (column 0 is not encoded).
    serialize_coords = tensor.coords[:, 1:].clone()
    serialize_coords += torch.tensor(shift_window, dtype=torch.int32, device=tensor.device).reshape(1, 3)
    code = vox2seq.encode(serialize_coords, mode=curve, permute=permute)

    fwd_indices = []
    bwd_indices = []
    seq_lens = []
    seq_batch_indices = []
    num_tokens = 0  # length of the serialized sequence emitted so far

    for bi, s in enumerate(tensor.layout):
        num_points = s.stop - s.start
        if num_points == 0:
            # An empty batch element yields no sequence; without this guard
            # `num_points / num_windows` below divides by zero.
            continue
        num_windows = (num_points + window_size - 1) // window_size
        # Point indices local to this batch element, sorted by curve code.
        to_ordered = torch.argsort(code[s.start:s.stop])

        if num_windows == 1:
            # The whole batch element forms a single sequence of num_points.
            bwd_index = torch.zeros_like(to_ordered).scatter_(
                0, to_ordered, torch.arange(num_points, device=tensor.device))
            fwd_indices.append(to_ordered + s.start)
            bwd_indices.append(bwd_index + num_tokens)
            seq_lens.append(num_points)
            seq_batch_indices.append(bi)
            num_tokens += num_points
            continue

        # Several windows of exactly `window_size` points. Valid (non-overlapping)
        # ranges are `valid_window_size` long on average; each window is padded
        # symmetrically around the centre of its valid range.
        valid_window_size = num_points / num_windows
        bwd_index = torch.zeros((num_points,), dtype=torch.int64, device=tensor.device)
        offset = 0  # position within this batch element's serialized output
        for i in range(num_windows):
            mid = (i + 0.5) * valid_window_size + shift_sequence
            valid_start = math.floor(i * valid_window_size + shift_sequence)
            valid_end = math.floor((i + 1) * valid_window_size + shift_sequence)
            padded_start = math.floor(mid - 0.5 * window_size)
            padded_end = padded_start + window_size
            # `%` (torch.remainder) wraps negative and overflowing positions
            # around the sorted sequence.
            window = to_ordered[torch.arange(padded_start, padded_end, device=tensor.device) % num_points]
            offset += valid_start - padded_start
            # Only the valid slice of the window is mapped back to the points.
            bwd_index.scatter_(
                0,
                window[valid_start - padded_start:valid_end - padded_start],
                torch.arange(offset, offset + valid_end - valid_start, device=tensor.device),
            )
            offset += padded_end - valid_start
            fwd_indices.append(window + s.start)
        seq_lens.extend([window_size] * num_windows)
        seq_batch_indices.extend([bi] * num_windows)
        bwd_indices.append(bwd_index + num_tokens)
        num_tokens += num_windows * window_size

    if not fwd_indices:
        # Every batch element was empty, and torch.cat rejects an empty list.
        empty = torch.zeros((0,), dtype=torch.int64, device=tensor.device)
        return empty, empty.clone(), seq_lens, seq_batch_indices

    return torch.cat(fwd_indices), torch.cat(bwd_indices), seq_lens, seq_batch_indices
| | |
| |
|
def sparse_serialized_scaled_dot_product_self_attention(
    qkv: SparseTensor,
    window_size: int,
    serialize_mode: SerializeMode = SerializeMode.Z_ORDER,
    shift_sequence: int = 0,
    shift_window: Tuple[int, int, int] = (0, 0, 0)
) -> SparseTensor:
    """
    Apply serialized scaled dot product self attention to a sparse tensor.

    Points are gathered into serialized windows (``calc_serialization``), and
    attention runs independently within each sequence. The outputs are then
    scattered back to the original point order.

    Args:
        qkv (SparseTensor): [N, *, 3, H, C] sparse tensor containing Qs, Ks, and Vs.
        window_size (int): The window size to use.
        serialize_mode (SerializeMode): The serialization mode to use.
        shift_sequence (int): The shift of serialized sequence.
        shift_window (Tuple[int, int, int]): The shift of serialized coordinates.

    Returns:
        SparseTensor: ``qkv`` with its features replaced by the attention output,
        one [H, C] row per original point.
    """
    assert len(qkv.shape) == 4 and qkv.shape[1] == 3, f"Invalid shape for qkv, got {qkv.shape}, expected [N, *, 3, H, C]"

    # The serialization depends only on the coordinates and these parameters,
    # so it is stored in the tensor's spatial cache under a key built from them.
    cache_name = f'serialization_{serialize_mode}_{window_size}_{shift_sequence}_{shift_window}'
    cached = qkv.get_spatial_cache(cache_name)
    if cached is None:
        cached = calc_serialization(qkv, window_size, serialize_mode, shift_sequence, shift_window)
        qkv.register_spatial_cache(cache_name, cached)
    fwd_indices, bwd_indices, seq_lens, seq_batch_indices = cached

    H = qkv.feats.shape[2]
    C = qkv.feats.shape[3]

    # Gather features into serialized (window-padded) order.
    qkv_feats = qkv.feats[fwd_indices]

    if DEBUG:
        start = 0
        qkv_coords = qkv.coords[fwd_indices]
        for seq_len, batch_idx in zip(seq_lens, seq_batch_indices):
            assert (qkv_coords[start:start + seq_len, 0] == batch_idx).all(), "SparseSerializedScaledDotProductSelfAttention: batch index mismatch"
            start += seq_len

    if all(seq_len == window_size for seq_len in seq_lens):
        # Every sequence is a full window: run dense batched attention.
        B = len(seq_lens)
        N = window_size
        qkv_feats = qkv_feats.reshape(B, N, 3, H, C)
        if ATTN == 'xformers':
            q, k, v = qkv_feats.unbind(dim=2)
            out = xops.memory_efficient_attention(q, k, v)
        elif ATTN == 'flash_attn':
            out = flash_attn.flash_attn_qkvpacked_func(qkv_feats)
        else:
            raise ValueError(f"Unknown attention module: {ATTN}")
        out = out.reshape(B * N, H, C)
    else:
        # Variable-length sequences: pack them into one sequence and use a
        # block-diagonal mask (xformers) or cumulative offsets (flash_attn).
        if ATTN == 'xformers':
            q, k, v = qkv_feats.unbind(dim=1)
            q = q.unsqueeze(0)
            k = k.unsqueeze(0)
            v = v.unsqueeze(0)
            mask = xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
            out = xops.memory_efficient_attention(q, k, v, mask)[0]
        elif ATTN == 'flash_attn':
            cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
                .to(qkv.device).int()
            out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens))
        else:
            raise ValueError(f"Unknown attention module: {ATTN}")

    # Scatter back to the original point order (drops window padding).
    out = out[bwd_indices]

    if DEBUG:
        qkv_coords = qkv_coords[bwd_indices]
        assert torch.equal(qkv_coords, qkv.coords), "SparseSerializedScaledDotProductSelfAttention: coordinate mismatch"

    return qkv.replace(out)
| |
|