| | import torch |
| | import random |
| | import torch.nn.functional as F |
| |
|
| | import flash_mla |
| |
|
| | |
| |
|
| |
|
def test_flash_mla():
    """Smoke-test ``flash_mla.get_mla_metadata`` on a synthetic paged KV cache.

    Builds per-sequence KV lengths (all equal to ``mean_sk``, or drawn from a
    normal distribution when ``varlen`` is set). It then lays a random KV cache
    out in fixed-size blocks addressed through ``block_table`` and fills every
    slot past each sequence's length with NaN. Finally it asks
    ``flash_mla.get_mla_metadata`` for the tile-scheduler metadata and split
    counts and prints them.

    Currently the test only checks that the metadata call completes without
    raising. It does not yet run an attention kernel or compare against a
    reference. Run with ``pytest -s`` to see the printed diagnostics.
    """
    # Fixed seeds so the varlen lengths and the random cache are reproducible.
    random.seed(0)
    torch.manual_seed(0)

    b = 16  # number of sequences in the batch
    s_q = 16  # query length per sequence (dim 1 of q); also the minimum KV length
    mean_sk = 16  # KV length used for every sequence (the mean when varlen)
    h_q = 16  # query heads (dim 2 of q)
    h_kv = 1  # KV heads (dim 2 of blocked_k)
    d = 576  # head dim of q / k
    dv = 512  # leading slice of the k head dim that is reused as v

    # NOTE(review): causal is only printed below; nothing consumes it yet.
    causal = True
    varlen = False

    print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}")

    cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
    if varlen:
        for i in range(b):
            # Draw a length around mean_sk, never shorter than s_q. int() makes
            # the truncation into the int32 tensor explicit.
            cache_seqlens[i] = int(max(random.normalvariate(mean_sk, mean_sk / 2), s_q))
    total_seqlens = cache_seqlens.sum().item()
    mean_seqlens = cache_seqlens.float().mean().int().item()
    max_seqlen = cache_seqlens.max().item()
    print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")

    # Round the longest length up to a multiple of 256 so every sequence gets
    # the same whole number of blocks. `+` already binds tighter than `&`;
    # the parentheses only make that explicit.
    max_seqlen_pad = (max_seqlen + 255) & ~255

    # NOTE(review): q is not consumed yet; it appears to be staged for a
    # later attention call. It is kept so the RNG draw order is unchanged.
    q = torch.randn(b, s_q, h_q, d)

    block_size = 64
    blocks_per_seq = max_seqlen_pad // block_size
    # Sequence j owns the consecutive block ids
    # j * blocks_per_seq .. (j + 1) * blocks_per_seq - 1.
    block_table = torch.arange(b * blocks_per_seq, dtype=torch.int32).view(b, blocks_per_seq)
    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
    print(blocked_k.shape)

    # The block ids are consecutive per sequence, so the cache can be viewed as
    # (b, max_seqlen_pad, h_kv, d). The view shares storage with blocked_k.
    # Every slot past each sequence's length is set to NaN, presumably so that
    # any read of padding would surface as NaN.
    kv_by_seq = blocked_k.view(b, max_seqlen_pad, h_kv, d)
    for i in range(b):
        kv_by_seq[i, cache_seqlens[i].item():] = float("nan")
    blocked_v = blocked_k[..., :dv]
    print(blocked_k.shape, blocked_v.shape)

    # Only the sequence lengths are needed on the GPU for the metadata query.
    cache_seqlens = cache_seqlens.to("cuda")

    tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata(
        seqlens_k=cache_seqlens,
        # All query heads that share one KV head are folded into the query
        # length for scheduling.
        s_q=s_q * h_q // h_kv,
        h_kv=h_kv,
    )
    print(tile_scheduler_metadata, num_splits)
    # TODO(review): assert on the metadata's shape/contents once the
    # get_mla_metadata contract for this binding is pinned down.
| |
|