| import random |
| from typing import List, Tuple |
|
|
| import paged_attention as ops |
| import pytest |
| import torch |
| from paged_attention.platforms import current_platform |
|
|
| from .utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck |
|
|
# Parameter grids for the cache-op tests below; pytest builds the cross
# product of these lists via stacked @pytest.mark.parametrize decorators.
COPYING_DIRECTION = [("gpu", "cpu"), ("gpu", "gpu"), ("cpu", "gpu")]
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [42]  # arbitrary token count for testing
NUM_LAYERS = [1]  # arbitrary layer count for testing
NUM_HEADS = [8]  # arbitrary head count for testing
HEAD_SIZES = [64, 80, 120, 256]
BLOCK_SIZES = [8, 16, 32]


NUM_BLOCKS = [1024, 10000]


NUM_MAPPINGS = [256]
SEEDS = [0]
if current_platform.is_mps():
    DEVICES = ["mps:0"]
else:
    # NOTE(review): when torch.cuda.device_count() == 0 this still yields
    # ["cuda:0", "cuda:1"]; presumably CUDA tests only run with at least one
    # GPU available -- confirm.
    DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]


# Both branches of the former platform check assigned the same value, so the
# redundant if/else is collapsed into a single assignment.
KV_CACHE_DTYPE = ["auto", "fp8"]
|
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
    kv_cache_factory,
    num_mappings: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Check ops.copy_blocks against a per-block reference copy on clones."""
    # fp8 kernels require head_size to be a multiple of 16.
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    current_platform.seed_everything(seed)

    if not device.startswith("mps"):
        torch.set_default_device(device)

    # Build a random mapping where each source block is copied to two distinct
    # destination blocks; sources and destinations never overlap.
    assert 2 * num_mappings <= num_blocks
    src_blocks = random.sample(range(num_blocks), num_mappings)
    remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
    dst_blocks = random.sample(remaining_blocks, 2 * num_mappings)
    block_mapping: List[Tuple[int, int]] = []
    for i in range(num_mappings):
        src = src_blocks[i]
        dst1 = dst_blocks[2 * i]
        dst2 = dst_blocks[2 * i + 1]
        block_mapping.append((src, dst1))
        block_mapping.append((src, dst2))

    # Create the KV caches (one key/value cache pair per layer).
    key_caches, value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        num_layers,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        device,
    )

    # Clone the caches to run the reference implementation against.
    cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
    cloned_value_caches = [value_cache.clone() for value_cache in value_caches]

    # The kernel expects an (N, 2) tensor of (src, dst) pairs.
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device=device
    ).view(-1, 2)

    # Run the kernel (opcheck only for the first head size to bound test time).
    opcheck(
        ops.ops.copy_blocks,
        (key_caches, value_caches, block_mapping_tensor),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)

    # Reference implementation: copy block by block on the clones.
    for src, dst in block_mapping:
        for cloned_key_cache in cloned_key_caches:
            cloned_key_cache[dst].copy_(cloned_key_cache[src])
        for cloned_value_cache in cloned_value_caches:
            cloned_value_cache[dst].copy_(cloned_value_cache[src])

    # Kernel output must match the reference exactly.
    for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
        torch.testing.assert_close(key_cache, cloned_key_cache)
    for value_cache, cloned_value_cache in zip(value_caches, cloned_value_caches):
        torch.testing.assert_close(value_cache, cloned_value_cache)
|
|
|
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache(
    kv_cache_factory,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ops.reshape_and_cache against a per-token reference scatter."""
    # fp8 kernels require head_size to be a multiple of 16.
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()
    current_platform.seed_everything(seed)

    if not device.startswith("mps"):
        torch.set_default_device(device)

    # Pick a unique cache slot for every token.
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    # Random K/V inputs (the Q component is unused here).
    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device)
    _, key, value = qkv.unbind(dim=1)

    # Create a single-layer KV cache.
    key_caches, value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Snapshot the caches for the reference; for fp8, dequantize to fp16 first
    # so the reference scatter operates on real values.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache)
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache)
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Unit scales: quantization should be a pure cast.
    k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # Run the kernel (opcheck only for the first head size to bound test time).
    opcheck(
        ops.ops.reshape_and_cache,
        (
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            kv_cache_dtype,
            k_scale,
            v_scale,
        ),
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.reshape_and_cache(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )

    # Dequantize the kernel output for comparison when using fp8.
    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(result_key_cache, key_cache)
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(result_value_cache, value_cache)

    # Reference implementation: scatter each token into its (block, offset)
    # slot. The key cache is 5-D with the block offset at dim 3; the value
    # cache is 4-D with the offset at dim 3.
    reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices_lst = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets_lst = block_offsets.cpu().tolist()
    for i in range(num_tokens):
        block_idx = block_indices_lst[i]
        block_offset = block_offsets_lst[i]
        cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
        cloned_value_cache[block_idx, :, :, block_offset] = value[i]

    # fp8 round-trips are lossy, so compare with loose tolerances there.
    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(
            result_key_cache, cloned_key_cache, atol=0.02, rtol=0.2
        )
        torch.testing.assert_close(
            result_value_cache, cloned_value_cache, atol=0.02, rtol=0.2
        )
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)
|
|
|
|
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ops.reshape_and_cache_flash against a per-token reference
    scatter on the flash-layout (4-D, block offset at dim 1) cache."""
    if current_platform.is_mps() and kv_cache_dtype == "fp8":
        pytest.skip("reshape_and_cache_flash doesn't support FP8 on MPS")
    current_platform.seed_everything(seed)

    if not device.startswith("mps"):
        torch.set_default_device(device)

    # Pick a unique cache slot for every token.
    num_slots = block_size * num_blocks
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    # Random K/V inputs (the Q component is unused here).
    qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype, device=device)
    _, key, value = qkv.unbind(dim=1)

    # Create a single-layer flash-layout KV cache.
    key_caches, value_caches = kv_cache_factory_flashinfer(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache, value_cache = key_caches[0].contiguous(), value_caches[0].contiguous()
    del key_caches
    del value_caches

    # Data-dependent scales (amax / 256) for the fp8 path.
    k_scale = (key.amax() / 256.0).to(torch.float32)
    v_scale = (value.amax() / 256.0).to(torch.float32)

    # Snapshot the caches for the reference; for fp8, dequantize to fp16 with
    # the same scales so the reference scatter operates on real values.
    if kv_cache_dtype == "fp8":
        cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_key_cache, key_cache, k_scale, kv_cache_dtype)
        cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(cloned_value_cache, value_cache, v_scale, kv_cache_dtype)
    else:
        cloned_key_cache = key_cache.clone()
        cloned_value_cache = value_cache.clone()

    # Run the kernel (opcheck only for the first head size to bound test time).
    opcheck(
        ops.ops.reshape_and_cache_flash,
        (
            key,
            value,
            key_cache,
            value_cache,
            slot_mapping,
            kv_cache_dtype,
            k_scale,
            v_scale,
        ),
        cond=(head_size == HEAD_SIZES[0]),
    )
    ops.reshape_and_cache_flash(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )

    # Dequantize the kernel output for comparison when using fp8.
    if kv_cache_dtype == "fp8":
        result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
        ops.convert_fp8(
            result_key_cache, key_cache, k_scale.item(), kv_dtype=kv_cache_dtype
        )
        result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
        ops.convert_fp8(
            result_value_cache, value_cache, v_scale.item(), kv_dtype=kv_cache_dtype
        )

    # Reference implementation: scatter each token into (block, offset).
    block_indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
    block_indices_lst = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets_lst = block_offsets.cpu().tolist()
    for i in range(num_tokens):
        block_idx = block_indices_lst[i]
        block_offset = block_offsets_lst[i]
        cloned_key_cache[block_idx, block_offset, :, :] = key[i]
        cloned_value_cache[block_idx, block_offset, :, :] = value[i]

    # fp8 round-trips are lossy, so compare with loose tolerances there.
    if kv_cache_dtype == "fp8":
        torch.testing.assert_close(
            result_key_cache, cloned_key_cache, atol=0.02, rtol=0.2
        )
        torch.testing.assert_close(
            result_value_cache, cloned_value_cache, atol=0.02, rtol=0.2
        )
    else:
        torch.testing.assert_close(key_cache, cloned_key_cache)
        torch.testing.assert_close(value_cache, cloned_value_cache)
|
|
|
|
@pytest.mark.parametrize("direction", COPYING_DIRECTION)
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_swap_blocks(
    kv_cache_factory,
    direction: Tuple[str, str],
    num_mappings: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ops.swap_blocks copies mapped blocks between src/dst caches,
    covering device-to-device and device<->host directions."""
    # fp8 caches are not supported on the CPU side of a swap.
    if kv_cache_dtype == "fp8" and "cpu" in direction:
        pytest.skip()
    # fp8 kernels require head_size to be a multiple of 16.
    if kv_cache_dtype == "fp8" and head_size % 16:
        pytest.skip()

    current_platform.seed_everything(seed)

    # Resolve the "gpu"/"cpu" direction labels to concrete devices.
    src_device = device if direction[0] == "gpu" else "cpu"
    dst_device = device if direction[1] == "gpu" else "cpu"

    src_blocks = random.sample(range(num_blocks), num_mappings)
    # On the same device, source and destination blocks must not overlap.
    if src_device == dst_device:
        remaining_blocks = list(set(range(num_blocks)) - set(src_blocks))
        dst_blocks = random.sample(remaining_blocks, num_mappings)
    else:
        dst_blocks = random.sample(range(num_blocks), num_mappings)

    # The kernel expects an (N, 2) CPU tensor of (src, dst) pairs.
    block_mapping = list(zip(src_blocks, dst_blocks))
    block_mapping_tensor = torch.tensor(
        block_mapping, dtype=torch.int64, device="cpu"
    ).view(-1, 2)

    # Create the source KV caches.
    src_key_caches, src_value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        src_device,
    )

    # Create the destination KV caches.
    dst_key_caches, dst_value_caches = kv_cache_factory(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        seed,
        dst_device,
    )

    # Snapshot the source so we can verify after the in-place swap.
    src_key_caches_clone = src_key_caches[0].clone()
    src_value_caches_clone = src_value_caches[0].clone()

    # Run the kernel (opcheck only for the first head size to bound test time).
    do_opcheck = head_size == HEAD_SIZES[0]
    opcheck(
        ops.ops.swap_blocks,
        (src_key_caches[0], dst_key_caches[0], block_mapping_tensor),
        cond=do_opcheck,
    )
    opcheck(
        ops.ops.swap_blocks,
        (src_value_caches[0], dst_value_caches[0], block_mapping_tensor),
        cond=do_opcheck,
    )

    ops.swap_blocks(src_key_caches[0], dst_key_caches[0], block_mapping_tensor)
    ops.swap_blocks(src_value_caches[0], dst_value_caches[0], block_mapping_tensor)

    # Every mapped destination block must equal the original source block.
    for src, dst in block_mapping:
        torch.testing.assert_close(
            src_key_caches_clone[src].cpu(), dst_key_caches[0][dst].cpu()
        )
        torch.testing.assert_close(
            src_value_caches_clone[src].cpu(), dst_value_caches[0][dst].cpu()
        )
|
|
|
|
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_fp8_e4m3_conversion(
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """Round-trip a random cache through fp8 and back, checking closeness."""
    current_platform.seed_everything(seed)

    # Fill a cache-shaped tensor with random values in [-224, 224].
    cache_shape = (num_blocks, num_heads, head_size, block_size)
    original = torch.empty(cache_shape, dtype=dtype, device=device).uniform_(
        -224.0, 224.0
    )

    # Quantize to fp8 (stored as uint8)...
    quantized = torch.empty_like(original, dtype=torch.uint8)
    ops.convert_fp8(quantized, original)

    # ...then dequantize back to the original dtype.
    roundtripped = torch.empty_like(original)
    ops.convert_fp8(roundtripped, quantized)

    # fp8 is lossy, so allow loose tolerances.
    torch.testing.assert_close(original, roundtripped, atol=0.02, rtol=0.2)
|
|