| |
| import flash_attn_cuda |
| import torch |
| import torch.nn as nn |
|
|
|
|
def convert_blockmask(blockmask, causal):
    """Convert a 0-1 block mask into the layout consumed by the CUDA kernels.

    In the input, 0 means the block is skipped and any nonzero value means the
    block is computed.

    Args:
        blockmask: (nrow, ncol) tensor of 0/1 (or bool) values.
        causal: must be False; the causal case is not handled by this conversion.

    Returns:
        torch.int32 tensor of shape (ncol, nrow). Row ``c`` of the result lists,
        in increasing order, the row indices of the nonzero blocks in column ``c``
        of the input, padded with -1 up to length nrow. Each index is stored as
        ``4 * row + flags``: bit 0 is set when the block is the first nonzero in
        its row, bit 1 when it is the last nonzero in its row.

    Raises:
        AssertionError: if ``causal`` is true (plain ``assert``, so the check is
            skipped under ``python -O``).
    """
    assert not causal
    nrow, _ = blockmask.shape

    # Cast to uint8 before sorting (presumably because sort lacks a bool kernel on
    # some devices -- TODO confirm).
    mask = blockmask.to(dtype=torch.uint8)
    row_ids = torch.arange(nrow, device=mask.device)

    # Stable descending sort per column: nonzero blocks move to the top of each
    # column while keeping their original (increasing) row order.
    col_vals, col_sorted_rows = mask.sort(dim=0, stable=True, descending=True)
    # Inverse permutation: slot_of[r, c] is where row r landed in sorted column c.
    slot_of = col_sorted_rows.argsort(dim=0)

    # For each input row, the column holding its last / first nonzero block. The
    # stable sorts make this the right-most / left-most such column.
    last_col = mask.sort(dim=-1, stable=True).indices[:, -1]
    first_col = mask.sort(dim=-1, stable=True, descending=True).indices[:, 0]

    encoded = col_sorted_rows * 4
    encoded[slot_of[row_ids, last_col], last_col] += 2
    encoded[slot_of[row_ids, first_col], first_col] += 1
    # Slots holding zero blocks become -1 padding; this also wipes any flag that
    # was written above for a row with no nonzero block at all.
    encoded[col_vals == 0] = -1
    return encoded.T.contiguous().to(dtype=torch.int32)
|
|
|
|
def _flash_blocksparse_attn_forward(
    qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax
):
    """Run the block-sparse forward kernel ``flash_attn_cuda.fwd_block``.

    Returns:
        ``(context, softmax_lse, S_dmask)``: the first two kernel outputs, plus the
        third kernel output when ``return_softmax`` is true (``None`` otherwise).
    """
    kernel_args = (qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal)
    # The final positional argument is always None here; presumably an optional
    # RNG generator in the kernel binding -- TODO confirm.
    context, softmax_lse, *extra_outputs = flash_attn_cuda.fwd_block(
        *kernel_args, return_softmax, None
    )
    if not return_softmax:
        return context, softmax_lse, None
    return context, softmax_lse, extra_outputs[0]
|
|
|
|
def _flash_blocksparse_attn_backward(
    dout,
    qkv,
    out,
    S_dmask,
    softmax_lse,
    cu_seqlens,
    blockmask,
    dropout_p,
    max_s,
    softmax_scale,
    causal,
):
    """Run the block-sparse backward kernel and return only its first output (dqkv).

    Note: this helper takes ``max_s`` before ``softmax_scale`` (same order as the
    forward helper), but ``flash_attn_cuda.bwd_block`` is called with
    ``softmax_scale`` before ``max_s``; the swap below is intentional.
    """
    kernel_args = (
        dout,
        qkv,
        out,
        S_dmask,
        softmax_lse,
        cu_seqlens,
        blockmask,
        dropout_p,
        softmax_scale,  # kernel order: softmax_scale, then max_s
        max_s,
        causal,
    )
    # Final argument is always None here (presumably an optional RNG generator --
    # TODO confirm). The kernel's second and third outputs are unused.
    dqkv, _unused_dp, _unused_softmax_d = flash_attn_cuda.bwd_block(*kernel_args, None)
    return dqkv
|
|
|
|
class FlashBlocksparseAttnFun(torch.autograd.Function):
    """Autograd wrapper around the block-sparse flash-attention kernels.

    Forward requests ``return_softmax=False`` and returns only the attention
    output ``context``.
    """

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
        """Compute block-sparse attention.

        Args:
            qkv: packed input tensor; its last dimension sets the default scale.
            cu_seqlens: forwarded to the kernel (cumulative sequence lengths).
            blockmask: block mask already in kernel format (see ``convert_blockmask``).
            dropout_p: dropout probability. When > 0 the current CUDA RNG state is
                saved so backward can run under the same RNG state (presumably to
                regenerate the same dropout mask).
            max_s: forwarded to the kernel (maximum sequence length).
            softmax_scale: forwarded to the kernel; defaults to
                ``qkv.shape[-1] ** -0.5`` when None.
            causal: forwarded to the kernel.

        Returns:
            The attention output ``context``.
        """
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
            qkv,
            cu_seqlens,
            blockmask,
            dropout_p,
            max_s,
            softmax_scale,
            causal=causal,
            return_softmax=False,
        )
        # S_dmask is None here (return_softmax=False); save_for_backward accepts None.
        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_s = max_s
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return context

    @staticmethod
    def backward(ctx, dout):
        """Return the gradient for ``qkv``; the other six forward inputs get None."""
        qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        try:
            # S_dmask was never materialized in forward, so ``context`` is passed in
            # its slot as a placeholder tensor. NOTE(review): assumes bwd_block does
            # not need real S_dmask contents in this configuration -- confirm.
            dqkv = _flash_blocksparse_attn_backward(
                dout,
                qkv,
                context,
                context,
                softmax_lse,
                cu_seqlens,
                blockmask,
                ctx.dropout_p,
                ctx.max_s,
                ctx.softmax_scale,
                ctx.causal,
            )
        finally:
            # Always restore the caller's RNG state, even if the kernel raises.
            if rng_state is not None:
                torch.cuda.set_rng_state(cur_rng_state)
        # Exactly one gradient per forward input (7 inputs).
        return dqkv, None, None, None, None, None, None
|
|
|
|
| |
| |
class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
    """Variant of the block-sparse attention autograd function that also returns
    the kernel's ``S_dmask`` output and ``softmax_lse``.

    Forward requests ``return_softmax=True``, and backward passes the real
    ``S_dmask`` to the backward kernel.
    """

    @staticmethod
    def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
        """Compute block-sparse attention and also return the softmax outputs.

        Arguments are the same as for the output-only variant: when
        ``dropout_p > 0`` the CUDA RNG state is saved for backward, and
        ``softmax_scale`` defaults to ``qkv.shape[-1] ** -0.5`` when None.

        Returns:
            ``(context, S_dmask, softmax_lse)``.
        """
        rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
        if softmax_scale is None:
            softmax_scale = qkv.shape[-1] ** (-0.5)
        context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
            qkv,
            cu_seqlens,
            blockmask,
            dropout_p,
            max_s,
            softmax_scale,
            causal=causal,
            return_softmax=True,
        )
        ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
        ctx.dropout_p = dropout_p
        ctx.max_s = max_s
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        return context, S_dmask, softmax_lse

    @staticmethod
    def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored):
        """Return the gradient for ``qkv``.

        Gradients flowing into the ``S_dmask`` and ``softmax_lse`` outputs are
        ignored. The other six forward inputs get None.
        """
        qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors
        if rng_state is not None:
            cur_rng_state = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(rng_state)
        try:
            dqkv = _flash_blocksparse_attn_backward(
                dout,
                qkv,
                context,
                S_dmask,
                softmax_lse,
                cu_seqlens,
                blockmask,
                ctx.dropout_p,
                ctx.max_s,
                ctx.softmax_scale,
                ctx.causal,
            )
        finally:
            # Always restore the caller's RNG state, even if the kernel raises.
            if rng_state is not None:
                torch.cuda.set_rng_state(cur_rng_state)
        return dqkv, None, None, None, None, None, None
|
|
|
|
def flash_blocksparse_attn_func(
    qkv,
    cu_seqlens,
    blockmask,
    dropout_p,
    max_s,
    softmax_scale=None,
    causal=False,
    return_attn_probs=False,
    convert_mask=True,
):
    """Block-sparse flash attention entry point.

    dropout_p should be set to 0.0 during evaluation.

    Args:
        qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal: forwarded to the
            autograd function (``softmax_scale=None`` selects the default scale).
        blockmask: a 0-1 (row, col) mask when ``convert_mask`` is true; otherwise
            it must already be in the kernel format produced by ``convert_blockmask``.
        return_attn_probs: when true, return ``(context, S_dmask, softmax_lse)``
            instead of ``context`` alone.
        convert_mask: whether to run ``convert_blockmask`` on ``blockmask`` first.
    """
    if convert_mask:
        blockmask = convert_blockmask(blockmask, causal=causal)
    if return_attn_probs:
        attn_fn = FlashBlocksparseAttnFunWithS
    else:
        attn_fn = FlashBlocksparseAttnFun
    return attn_fn.apply(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal)
|
|