Instructions to use kernels-community/paged-attention with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/paged-attention with Kernels:
    # !pip install kernels
    from kernels import get_kernel
    kernel = get_kernel("kernels-community/paged-attention")
- Notebooks
- Google Colab
- Kaggle
| import random | |
| from typing import List, Tuple | |
| import paged_attention as ops | |
| import pytest | |
| import torch | |
| from paged_attention.platforms import current_platform | |
| from .utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck | |
# Copy directions for test_swap_blocks: "gpu" is replaced by the test's
# accelerator device string (a CUDA or MPS device from DEVICES), "cpu" by host.
COPYING_DIRECTION = [("gpu", "cpu"), ("gpu", "gpu"), ("cpu", "gpu")]
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 120, 256]
BLOCK_SIZES = [8, 16, 32]

# Arbitrary values for testing
# don't make it too large. e.g. [1024, 36000] will OOM
NUM_BLOCKS = [1024, 10000]

NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]

if current_platform.is_mps():
    DEVICES = ["mps:0"]
else:
    # Only cuda:0 when exactly one CUDA device is visible; otherwise both
    # cuda:0 and cuda:1 are parametrized.
    DEVICES = [
        f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
    ]

# The original chose this per platform, but both branches were identical:
# every platform exercises both the unquantized and the fp8 cache.
KV_CACHE_DTYPE = ["auto", "fp8"]
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("device", DEVICES)
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Compare ``ops.copy_blocks`` with a block-by-block ``copy_`` reference.

    ``num_mappings`` distinct source blocks are each mapped to two distinct
    destination blocks, all drawn from blocks that are not sources. The
    mapping is applied to every layer of both the key and the value caches.

    Every argument except the ``kv_cache_factory`` fixture is supplied by the
    ``parametrize`` decorators above.
    """
    # fp8 caches are only exercised for head sizes that are multiples of 16.
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    current_platform.seed_everything(seed)
    # Don't set MPS as default device to avoid placeholder storage error
    if not device.startswith("mps"):
        torch.set_default_device(device)

    # Sources and destinations are disjoint, and destinations are unique, so
    # the reference result does not depend on the order the pairs are applied.
    assert 2 * num_mappings <= num_blocks
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    block_mapping: List[Tuple[int, int]] = []
    for i, src in enumerate(src_blocks):
        block_mapping.append((src, dst_blocks[2 * i]))
        block_mapping.append((src, dst_blocks[2 * i + 1]))

    # Create the KV caches (one tensor per layer for keys and for values).
    key_caches, value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        num_layers,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        device,
    )

    # Snapshot the caches before the kernel mutates them in place.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # Call the copy blocks kernel; the mapping is an (N, 2) int64 tensor of
    # (src, dst) rows.
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device=device
    ).view(-1, 2)
    opcheck(
        ops.ops.copy_blocks,
        (key_caches, value_caches, block_mapping_tensor),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)

    # Run the reference implementation on the snapshots.
    for src, dst in block_mapping:
        for cloned_key_cache in cloned_key_caches:
            cloned_key_cache[dst].copy_(cloned_key_cache[src])
        for cloned_value_cache in cloned_value_caches:
            cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # Compare the results layer by layer.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        torch.testing.assert_close(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
        torch.testing.assert_close(value_cache, cloned_value_cache)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``ops.reshape_and_cache`` scatters per-token K/V into cache slots.

    Each of ``num_tokens`` tokens is written to a distinct random slot
    (``block = slot // block_size``, ``offset = slot % block_size``). The
    result is compared with a Python reference that indexes the cache as
    ``[block, :, :, offset, :]`` for keys and ``[block, :, :, offset]`` for
    values. fp8 caches are dequantized to float16 with
    ``ops.convert_fp8`` and compared with a loose tolerance.

    Every argument except the ``kv_cache_factory`` fixture is supplied by the
    ``parametrize`` decorators above.
    """
    # fp8 caches are only exercised for head sizes that are multiples of 16.
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    current_platform.seed_everything(seed)
    # Don't set MPS as default device to avoid placeholder storage error
    if not device.startswith("mps"):
        torch.set_default_device(device)

    # Create a random slot mapping (distinct slots, one per token).
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    # key/value are non-contiguous views into one packed qkv tensor.
    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device)
    _, key, value = qkv.unbind(dim=1)

    # Create single-layer KV caches.
    key_caches, value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Snapshot the caches; fp8 caches are dequantized so the reference can be
    # built in float16.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache)
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache)
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Using default kv_scale
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Call the reshape_and_cache kernel.
    kernel_args = (
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )
    opcheck(
        ops.ops.reshape_and_cache,
        kernel_args,
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.reshape_and_cache(*kernel_args)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache, key_cache)
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(result_value_cache, value_cache)

    # Run the reference implementation. Keys are reshaped to the per-slot
    # layout of the key cache, i.e. key_cache[0, :, :, 0, :].shape.
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_offsets = slot_mapping % block_size
    for i, (block_idx, block_offset) in enumerate(
        zip(block_indices.cpu().tolist(), block_offsets.cpu().tolist())
    ):
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(
            result_key_cache, cloned_key_cache, atol=0.02, rtol=0.2
        )
        torch.testing.assert_close(
            result_value_cache, cloned_value_cache, atol=0.02, rtol=0.2
        )
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``ops.reshape_and_cache_flash`` against a Python reference.

    The flash layout is indexed as ``cache[block, offset, :, :]`` for both
    keys and values. Unlike ``test_reshape_and_cache``, non-unit scales
    (``amax / 256``) are used. For fp8 caches, both the snapshot and the
    result are dequantized with those same scales before a loose comparison.

    Every argument except the ``kv_cache_factory_flashinfer`` fixture is
    supplied by the ``parametrize`` decorators above.
    """
    # Flash variant doesn't support FP8 on MPS devices yet
    if current_platform.is_mps() and kv_cache_dtype == "fp8":
        pytest.skip("reshape_and_cache_flash doesn't support FP8 on MPS")
    current_platform.seed_everything(seed)
    # Don't set MPS as default device to avoid placeholder storage error
    if not device.startswith("mps"):
        torch.set_default_device(device)

    # Create a random slot mapping (distinct slots, one per token).
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    # key/value are non-contiguous views into one packed qkv tensor.
    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device)
    _, key, value = qkv.unbind(dim=1)

    # Create single-layer KV caches and keep contiguous copies of layer 0.
    key_caches, value_caches = kv_cache_factory_flashinfer(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache, value_cache = key_caches[0].contiguous(), value_caches[0].contiguous()
    del key_caches
    del value_caches

    k_scale = (key.amax() / 256.0).to(torch.float32)
    v_scale = (value.amax() / 256.0).to(torch.float32)

    # Snapshot the caches. fp8 caches are dequantized with the same scales
    # that are used after the kernel; both calls pass a Python float scale
    # and the dtype as the ``kv_dtype`` keyword.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(
            cloned_key_cache, key_cache, k_scale.item(), kv_dtype=kv_cache_dtype
        )
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(
            cloned_value_cache, value_cache, v_scale.item(), kv_dtype=kv_cache_dtype
        )
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Call the reshape_and_cache_flash kernel.
    kernel_args = (
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )
    opcheck(
        ops.ops.reshape_and_cache_flash,
        kernel_args,
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.reshape_and_cache_flash(*kernel_args)

    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(
            result_key_cache, key_cache, k_scale.item(), kv_dtype=kv_cache_dtype
        )
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(
            result_value_cache, value_cache, v_scale.item(), kv_dtype=kv_cache_dtype
        )

    # Run the reference implementation.
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_offsets = slot_mapping % block_size
    for i, (block_idx, block_offset) in enumerate(
        zip(block_indices.cpu().tolist(), block_offsets.cpu().tolist())
    ):
        cloned_key_cache[block_idx, block_offset, :, :] = key[i]
        cloned_value_cache[block_idx, block_offset, :, :] = value[i]

    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(
            result_key_cache, cloned_key_cache, atol=0.02, rtol=0.2
        )
        torch.testing.assert_close(
            result_value_cache, cloned_value_cache, atol=0.02, rtol=0.2
        )
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)
@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
def test_swap_blocks(
    kv_cache_factory,
    direction: Tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``ops.swap_blocks`` copies mapped blocks between two caches.

    ``direction`` selects the source and destination memory ("gpu" means
    ``device``, "cpu" means host). After the kernel runs, every destination
    block must equal the pre-kernel snapshot of its source block. Only the
    destination side is asserted.

    Every argument except the ``kv_cache_factory`` fixture is supplied by the
    ``parametrize`` decorators above.
    """
    if kv_cache_dtype == "fp8" and "cpu" in direction:
        pytest.skip()
    # fp8 caches are only exercised for head sizes that are multiples of 16.
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    current_platform.seed_everything(seed)

    src_device = device if direction[0] == "gpu" else "cpu"
    dst_device = device if direction[1] == "gpu" else "cpu"

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # For the same device, mapping must not overlap
    if src_device == dst_device:
        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
        dst_blocks = random.sample(remaining_blocks, num_mappings)
    else:
        dst_blocks = random.sample(range(num_blocks), num_mappings)

    block_mapping = list(zip(src_blocks, dst_blocks))
    # The mapping tensor is passed on the CPU regardless of direction.
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device="cpu"
    ).view(-1, 2)

    # Create the KV caches on the source device.
    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        src_device,
    )

    # Create the KV caches on the destination device.
    dst_key_caches, dst_value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        dst_device,
    )

    src_key_caches_clone = src_key_caches[0].clone()
    src_value_caches_clone = src_value_caches[0].clone()

    # Call the swap_blocks kernel, once for keys and once for values.
    do_opcheck = head_size == HEAD_SIZES[0]
    opcheck(
        ops.ops.swap_blocks,
        (src_key_caches[0], dst_key_caches[0], block_mapping_tensor),
        cond=do_opcheck,
    )
    opcheck(
        ops.ops.swap_blocks,
        (src_value_caches[0], dst_value_caches[0], block_mapping_tensor),
        cond=do_opcheck,
    )

    ops.swap_blocks(src_key_caches[0], dst_key_caches[0], block_mapping_tensor)
    ops.swap_blocks(src_value_caches[0], dst_value_caches[0], block_mapping_tensor)

    # Compare on the CPU so source and destination devices may differ.
    for src, dst in block_mapping:
        torch.testing.assert_close(
            src_key_caches_clone[src].cpu(), dst_key_caches[0][dst].cpu()
        )
        torch.testing.assert_close(
            src_value_caches_clone[src].cpu(), dst_value_caches[0][dst].cpu()
        )
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
def test_fp8_e4m3_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Round-trip a random tensor through ``ops.convert_fp8`` and back.

    A tensor of shape ``(num_blocks, num_heads, head_size, block_size)`` is
    filled uniformly in [-224, 224]. It is converted to a uint8-backed fp8
    buffer and then back to ``dtype``. The result must match the input within
    fp8-level tolerance (atol=0.02, rtol=0.2).

    All arguments are supplied by the ``parametrize`` decorators above.
    """
    current_platform.seed_everything(seed)

    # NOTE(review): +/-224 is presumably chosen to stay inside the
    # representable e4m3 range; confirm against the kernel's fp8 format.
    low = -224.0
    high = 224.0
    shape = (num_blocks, num_heads, head_size, block_size)
    cache = torch.empty(shape, dtype=dtype, device=device)
    cache.uniform_(low, high)

    # Quantize into a uint8 buffer, then dequantize back to the source dtype.
    cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
    ops.convert_fp8(cache_fp8, cache)
    converted_cache = torch.empty_like(cache)
    ops.convert_fp8(converted_cache, cache_fp8)

    torch.testing.assert_close(cache, converted_cache, atol=0.02, rtol=0.2)