diff --git a/test/test_pallas.py b/test/test_pallas.py
index c1f9df9ba0c2..d3f0e0f30c31 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -637,8 +637,8 @@ def _test_ragged_paged_attention(
       sm_scale=1.0,
       sliding_window=None,
       soft_cap=None,
-      num_kv_pages_per_block=16,
-      num_queries_per_block=128,
+      num_kv_pages_per_block=None,
+      num_queries_per_block=None,
       pad_tokens_and_seqs=False,
       use_dynamo=True,
   ):
@@ -751,16 +751,6 @@ def ragged_paged_attention_wrapper(
     num_seqs_jax = jnp.array([num_seqs], dtype=jnp.int32)

     from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as jax_ragged_paged_attention
-    from torch_xla.experimental.tuned_block_sizes import get_ragged_attention_tuned_block_size
-    if num_kv_pages_per_block is None:
-      assert num_queries_per_block is None
-      token_num, q_head_num, _ = q.shape
-      _, page_size, num_combined_kv_heads, _ = kv_pages.shape
-      _, pages_per_seq = page_indices.shape
-      num_kv_heads = num_combined_kv_heads // 2
-      max_model_len = pages_per_seq * page_size
-      num_kv_pages_per_block, num_queries_per_block = get_ragged_attention_tuned_block_size(
-          q_head_num, num_kv_heads, token_num, max_model_len)
     jax_kernel_output = torch.from_numpy(
         np.array(
             jax_ragged_paged_attention(
@@ -790,8 +780,7 @@ def ragged_paged_attention_wrapper(
       sm_scale=[1.0, 0.5],
       sliding_window=[None, 128],
       soft_cap=[None, 10.0],
-      pad_tokens_and_seqs=[False, True],
-      block_sizes=[(16, 128), (None, None)])
+      pad_tokens_and_seqs=[False, True])
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
   def test_ragged_paged_attention_wrapper_with_dynamo(
@@ -803,12 +792,10 @@ def test_ragged_paged_attention_wrapper_with_dynamo(
       sliding_window,
       soft_cap,
       pad_tokens_and_seqs,
-      block_sizes,
   ):
     head_dim = 128
     page_size = 16
     num_pages = 1000
-    num_kv_pages_per_block, num_queries_per_block = block_sizes

     self._test_ragged_paged_attention(
         seq_lens,
@@ -822,8 +809,6 @@ def test_ragged_paged_attention_wrapper_with_dynamo(
         soft_cap=soft_cap,
         pad_tokens_and_seqs=pad_tokens_and_seqs,
         use_dynamo=True,
-        num_kv_pages_per_block=num_kv_pages_per_block,
-        num_queries_per_block=num_queries_per_block,
     )

   @parameterized.product(
@@ -834,7 +819,6 @@ def test_ragged_paged_attention_wrapper_with_dynamo(
       sliding_window=[None, 128],
       soft_cap=[None, 10.0],
       pad_tokens_and_seqs=[False, True],
-      block_sizes=[(16, 128), (None, None)],
   )
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
@@ -847,12 +831,10 @@ def test_ragged_paged_attention_wrapper_without_dynamo(
       sliding_window,
       soft_cap,
       pad_tokens_and_seqs,
-      block_sizes,
   ):
     head_dim = 128
     page_size = 16
     num_pages = 1000
-    num_kv_pages_per_block, num_queries_per_block = block_sizes

     self._test_ragged_paged_attention(
         seq_lens,
@@ -866,8 +848,6 @@ def test_ragged_paged_attention_wrapper_without_dynamo(
         soft_cap=soft_cap,
         pad_tokens_and_seqs=pad_tokens_and_seqs,
         use_dynamo=False,
-        num_kv_pages_per_block=num_kv_pages_per_block,
-        num_queries_per_block=num_queries_per_block,
     )

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
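With the torch_xla-side lookup gone, a (None, None) block-size pair now flows through the wrapper to the Pallas kernel, which resolves it from its own tuned table. A minimal sketch of the resulting call, assuming q, kv_pages, kv_lens, page_indices, cu_q_lens, and num_seqs are built as in _test_ragged_paged_attention above:

    from torch_xla.experimental.custom_kernel import ragged_paged_attention

    # Block-size kwargs omitted: they default to None, so the kernel picks
    # tuned values instead of using the deleted torch_xla helper.
    output = ragged_paged_attention(q, kv_pages, kv_lens, page_indices,
                                    cu_q_lens, num_seqs)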
diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index cfc3fb66617d..347ad7907d5f 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -9,7 +9,6 @@
 from torch_xla.distributed.spmd import Mesh
 import torch_xla.distributed.spmd as xs
 from torch_xla._internal.jax_workarounds import requires_jax
-from torch_xla.experimental.tuned_block_sizes import get_ragged_attention_tuned_block_size

 # Re-expose this API used that is referenced by docs
 from torch_xla._internal.jax_workarounds import jax_import_guard  # noqa: F401, pylint: disable=unused-import
@@ -956,16 +955,6 @@ def ragged_paged_attention(
   # in the global scope which could cause problems for xmp.spawn.
   from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as ragged_attention

-  if num_kv_pages_per_block is None:
-    assert num_queries_per_block is None
-    token_num, q_head_num, _ = q.shape
-    _, page_size, num_combined_kv_heads, _ = kv_pages.shape
-    _, pages_per_seq = page_indices.shape
-    num_kv_heads = num_combined_kv_heads // 2
-    max_model_len = pages_per_seq * page_size
-    num_kv_pages_per_block, num_queries_per_block = get_ragged_attention_tuned_block_size(
-        q_head_num, num_kv_heads, token_num, max_model_len)
-
   if vmem_limit_bytes is None:
     vmem_limit_bytes = 64 * 1024 * 1024

diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py
index 48c28825418b..4f2724a42deb 100644
--- a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py
+++ b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py
@@ -26,6 +26,445 @@ import jax.numpy as jnp

 DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)

+# If the page size is too small, pages_per_seq becomes too large. We only have
+# 32 SREGs in the TC, so the SREGs will spill.
+MAX_PAGES_PER_SEQ = 16
+
+# key:
+#   - q_dtype_name
+#   - kv_dtype_name
+#   - num_q_heads_per_blk
+#   - num_kv_heads_per_blk
+#   - head_dim
+#   - page_size
+#   - max_num_batched_tokens
+#   - max_model_len = page_size * pages_per_seq
+# value:
+#   - num_kv_pages_per_block
+#   - num_queries_per_block
+TUNED_BLOCK_SIZES = {
+    'TPU v6': {
+        # go/keep-sorted start
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 1024): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 2048): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 4096): (32, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 512): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 1024): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 2048): (16, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 4096): (32, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 512): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 1024): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 2048): (16, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 4096): (32, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 512): (4, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 1024): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 2048): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 4096): (32, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 512): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 1024): (64, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 128): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 2048): (128, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 256): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 512): (32, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 64): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 1024): (64, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 128): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 2048): (128, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 256): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 512): (32, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 64): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 1024): (64, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 128): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 2048): (128, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 256): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 512): (32, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 64): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 1024): (64, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 128): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 2048): (128, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 256): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 512): (32, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 64): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 1024): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 2048): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 4096): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 1024): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 2048): (8, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 4096): (16, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 1024): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 2048): (8, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 4096): (16, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 1024): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 2048): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 4096): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 1024): (32, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 128): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 2048): (64, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 256): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 4096): (128, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 512): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 1024): (32, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 128): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 2048): (64, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 256): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 4096): (128, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 512): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 1024): (32, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 128): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 2048): (64, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 256): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 4096): (128, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 512): (16, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 1024): (32, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 128): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 2048): (64, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 256): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 4096): (128, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 512): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 1024): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 2048): (32, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 256): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 4096): (64, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 512): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 1024): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 2048): (32, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 256): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 4096): (64, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 512): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 1024): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 2048): (32, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 256): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 4096): (64, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 512): (8, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 1024): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 2048): (32, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 256): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 4096): (64, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 512): (8, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 1024): (8, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 2048): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 4096): (32, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 512): (4, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 1024): (8, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 2048): (16, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 4096): (32, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 512): (4, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 1024): (8, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 2048): (16, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 4096): (32, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 512): (4, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 1024): (8, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 2048): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 4096): (32, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 512): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 1024): (64, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 128): (8, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 2048): (128, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 256): (16, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 512): (32, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 64): (4, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 1024): (64, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 128): (8, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 2048): (128, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 256): (16, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 512): (32, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 64): (4, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 1024): (64, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 128): (8, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 2048): (128, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 256): (16, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 512): (32, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 64): (4, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 1024): (64, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 128): (8, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 2048): (128, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 256): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 512): (32, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 64): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 1024): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 2048): (8, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 4096): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 1024): (4, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 2048): (8, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 4096): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 1024): (4, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 2048): (8, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 4096): (16, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 1024): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 2048): (8, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 4096): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 1024): (32, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 128): (4, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 2048): (64, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 256): (8, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 4096): (128, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 512): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 1024): (32, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 128): (4, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 2048): (64, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 256): (8, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 4096): (64, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 512): (16, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 1024): (32, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 128): (4, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 2048): (64, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 256): (8, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 4096): (64, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 512): (16, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 1024): (32, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 128): (4, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 2048): (64, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 256): (8, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 4096): (128, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 512): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 1024): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 2048): (32, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 256): (4, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 4096): (64, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 512): (8, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 1024): (16, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 2048): (32, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 256): (4, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 4096): (64, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 512): (8, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 1024): (16, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 2048): (32, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 256): (4, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 4096): (64, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 512): (8, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 1024): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 2048): (32, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 256): (4, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 4096): (64, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 512): (8, 32),
+        # go/keep-sorted end
+    },
+    'TPU v5': {
+        # go/keep-sorted start
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 1024): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 2048): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 1024, 512): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 1024): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 2048): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 2048, 512): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 1024): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 2048): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 4096, 512): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 1024): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 2048): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 128, 512, 512): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 128): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 256): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 1024, 64): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 128): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 256): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 64): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 128): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 256): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 4096, 64): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 128): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 256): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 16, 512, 64): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 1024): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 2048): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 1024, 4096): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 1024): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 2048): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 2048, 4096): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 1024): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 2048): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 4096, 4096): (16, 64),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 1024): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 2048): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 256, 512, 4096): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 128): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 256): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 1024, 512): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 128): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 256): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 2048, 512): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 128): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 256): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 4096, 512): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 128): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 256): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 32, 512, 512): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 1024): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 256): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 1024, 512): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 1024): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 256): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 2048, 512): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 1024): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 256): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 4096, 512): (8, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 1024): (16, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 256): (4, 32),
+        ('bfloat16', 'bfloat16', 32, 8, 128, 64, 512, 512): (8, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 1024): (8, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 2048): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 1024, 512): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 1024): (8, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 2048): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 2048, 512): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 1024): (8, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 2048): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 4096, 512): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 1024): (8, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 2048): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 128, 512, 512): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 128): (8, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 256): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 1024, 64): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 128): (8, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 256): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 2048, 64): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 128): (8, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 256): (16, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 4096, 64): (4, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 128): (8, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 256): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 16, 512, 64): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 1024): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 2048): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 1024, 4096): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 1024): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 2048): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 2048, 4096): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 1024): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 2048): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 4096, 4096): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 1024): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 2048): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 256, 512, 4096): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 128): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 256): (8, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 1024, 512): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 128): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 256): (8, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 2048, 512): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 128): (4, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 256): (8, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 4096, 512): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 128): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 256): (8, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 32, 512, 512): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 1024): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 256): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 1024, 512): (8, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 1024): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 256): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 2048, 512): (8, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 1024): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 256): (4, 64),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 4096, 512): (8, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 1024): (16, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 256): (4, 32),
+        ('bfloat16', 'bfloat16', 8, 1, 128, 64, 512, 512): (8, 32),
+        # go/keep-sorted end
+    },
+}
+
+
+def next_power_of_2(x: int):
+  """Finds the smallest power of 2 >= x using bit manipulation.
+
+  Args:
+    x: The input number (should be an integer).
+
+  Returns:
+    The smallest integer power of 2 that is >= x.
+  """
+  assert x > 0
+  if x == 1:
+    return 1
+  return 1 << (x - 1).bit_length()
+
+
+def simplify_key(key):
+  """Simplify the key to reduce the number of combinations."""
+  (
+      q_dtype,
+      kv_dtype,
+      num_q_heads_per_blk,
+      num_kv_heads_per_blk,
+      head_dim,
+      page_size,
+      max_num_batched_tokens,
+      pages_per_seq,
+  ) = key
+  return (
+      jnp.dtype(q_dtype).name,
+      jnp.dtype(kv_dtype).name,
+      next_power_of_2(num_q_heads_per_blk),
+      next_power_of_2(num_kv_heads_per_blk),
+      (head_dim + 127) // 128 * 128,
+      next_power_of_2(page_size),
+      next_power_of_2(max_num_batched_tokens),
+      next_power_of_2(page_size * pages_per_seq),
+  )
+
+
+def get_tpu_version() -> int:
+  """Returns the numeric version of the TPU, or -1 if not on TPU."""
+  kind = jax.devices()[0].device_kind
+  if 'TPU' not in kind:
+    return -1
+  if kind.endswith(' lite'):
+    kind = kind[:-len(' lite')]
+  assert kind[:-1] == 'TPU v', kind
+  return int(kind[-1])
+
+
+def get_device_name(num_devices: int | None = None):
+  name = ' '.join(jax.devices()[0].device_kind.split()[:2])
+  if num_devices is not None:
+    name += f'-{num_devices}'
+  return name
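To make the normalization concrete: head counts, page size, token count, and max model length are all rounded up to powers of two, and head_dim is rounded up to a multiple of 128, so many nearby configurations collapse onto a single table key. A small illustrative sketch (the shape values here are hypothetical, not from the tuning runs):

    import jax.numpy as jnp

    assert next_power_of_2(33) == 64
    # A hypothetical 10-query-head / 2-kv-head model at page_size=16,
    # 1000 batched tokens, and 100 pages per sequence:
    key = simplify_key(
        (jnp.bfloat16, jnp.bfloat16, 10, 2, 128, 16, 1000, 100))
    assert key == ('bfloat16', 'bfloat16', 16, 2, 128, 16, 1024, 2048)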
+
+
+def get_tuned_block_sizes(
+    q_dtype,
+    kv_dtype,
+    num_q_heads_per_blk,
+    num_kv_heads_per_blk,
+    head_dim,
+    page_size,
+    max_num_batched_tokens,
+    pages_per_seq,
+) -> tuple[int, int]:
+  """Looks up the best (num_kv_pages_per_blk, num_queries_per_blk) from the auto-tuned table."""
+  tpu_version = get_tpu_version()
+  if tpu_version < 4:
+    raise NotImplementedError('TPU version must be 4 or higher.')
+  key = (
+      q_dtype,
+      kv_dtype,
+      num_q_heads_per_blk,
+      num_kv_heads_per_blk,
+      head_dim,
+      page_size,
+      max_num_batched_tokens,
+      pages_per_seq,
+  )
+  key = simplify_key(key)
+  device_name = get_device_name()
+
+  # Default block sizes.
+  bkv, bq = (128, 32)
+  if tpu_version == 4:
+    # This default block size is not tuned; it only makes sure there's no
+    # OOM in VMEM.
+    bkv, bq = (32, 32)
+  elif device_name in TUNED_BLOCK_SIZES:
+    if key in TUNED_BLOCK_SIZES[device_name]:
+      bkv, bq = TUNED_BLOCK_SIZES[device_name][key]
+  return (min(pages_per_seq, bkv), min(max_num_batched_tokens, bq))
+
+
+def get_min_page_size(max_model_len, min_page_size=16):
+  """Recommended min page size for high-performance kernel."""
+  return max(next_power_of_2(max_model_len) // MAX_PAGES_PER_SEQ, min_page_size)
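A sketch of the lookup itself, using an entry that appears in the 'TPU v6' table above (when the simplified key is absent, the untuned (128, 32) default applies, and the result is always clamped to pages_per_seq and max_num_batched_tokens):

    # On a hypothetical TPU v6 host, the key simplifies to
    # ('bfloat16', 'bfloat16', 32, 8, 128, 16, 2048, 2048) -> (128, 64).
    bkv, bq = get_tuned_block_sizes(jnp.bfloat16, jnp.bfloat16, 32, 8, 128, 16,
                                    2048, 128)
    # For a 2048-token model, any page size below 128 would exceed
    # MAX_PAGES_PER_SEQ pages per sequence:
    assert get_min_page_size(2048) == 128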


 class MultiPageAsyncCopyDescriptor:
@@ -81,8 +520,18 @@ def ref_ragged_paged_attention(
     soft_cap: float | None = None,
     mask_value: float | None = DEFAULT_MASK_VALUE,
 ):
-  validate_static_inputs(queries, kv_pages, kv_lens, page_indices, cu_q_lens,
-                         num_seqs, sliding_window, soft_cap)
+  static_validate_inputs(
+      queries,
+      kv_pages,
+      kv_lens,
+      page_indices,
+      cu_q_lens,
+      num_seqs,
+      sm_scale=sm_scale,
+      sliding_window=sliding_window,
+      soft_cap=soft_cap,
+      mask_value=mask_value,
+  )
   if mask_value is None:
     mask_value = DEFAULT_MASK_VALUE
   _, _, num_combined_kv_heads, head_dim = kv_pages.shape
@@ -124,19 +573,40 @@ def ref_ragged_paged_attention(


 # Expect to run these checks during runtime.
-def validate_dynamic_inputs(
+def dynamic_validate_inputs(
     q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
     kv_pages: jax.
     Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
     kv_lens: jax.Array,  # i32[max_num_seqs]
     page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
-    num_seqs,  # i32[1]
+    num_seqs: jax.Array,  # i32[1]
+    *,
+    # These inputs are optional. If not specified, we will not validate them.
+    sm_scale: float | None = None,
     sliding_window: int | None = None,
     soft_cap: float | None = None,
+    mask_value: float | None = None,
+    # Kernel specific params.
+    num_kv_pages_per_block: int | None = None,
+    num_queries_per_block: int | None = None,
+    vmem_limit_bytes: int | None = None,
 ):
-  validate_static_inputs(q, kv_pages, kv_lens, page_indices, cu_q_lens,
-                         num_seqs, sliding_window, soft_cap)
+  static_validate_inputs(
+      q,
+      kv_pages,
+      kv_lens,
+      page_indices,
+      cu_q_lens,
+      num_seqs,
+      sm_scale=sm_scale,
+      sliding_window=sliding_window,
+      soft_cap=soft_cap,
+      mask_value=mask_value,
+      num_kv_pages_per_block=num_kv_pages_per_block,
+      num_queries_per_block=num_queries_per_block,
+      vmem_limit_bytes=vmem_limit_bytes,
+  )
   max_num_batched_tokens = q.shape[0]
   page_size = kv_pages.shape[1]
   max_num_seqs, pages_per_seq = page_indices.shape
@@ -161,22 +631,30 @@ def validate_dynamic_inputs(


 # Expect to run these checks during compile time.
-def validate_static_inputs(
+def static_validate_inputs(
     q: jax.Array,  # [max_num_batched_tokens, num_q_heads, head_dim]
     kv_pages: jax.
     Array,  # [total_num_pages, page_size, num_combined_kv_heads, head_dim]
     kv_lens: jax.Array,  # i32[max_num_seqs]
     page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
-    num_seqs,  # i32[1]
+    num_seqs: jax.Array,  # i32[1]
+    *,
+    # These inputs are optional. If not specified, we will not validate them.
+    sm_scale: float | None = None,
     sliding_window: int | None = None,
    soft_cap: float | None = None,
+    mask_value: float | None = None,
+    # Kernel specific params.
+    num_kv_pages_per_block: int | None = None,
+    num_queries_per_block: int | None = None,
+    vmem_limit_bytes: int | None = None,
 ):
   _, num_q_heads, head_dim = q.shape
   _, _, num_combined_kv_heads, head_dim_k = kv_pages.shape
   assert num_combined_kv_heads % 2 == 0
   num_kv_heads = num_combined_kv_heads // 2
-  max_num_seqs, _ = page_indices.shape
+  max_num_seqs, pages_per_seq = page_indices.shape
   if num_seqs.shape != (1,):
     raise ValueError(f"{num_seqs.shape=} must be (1,)")
   if head_dim_k != head_dim:
@@ -201,6 +679,16 @@ def validate_static_inputs(
     raise ValueError(f"{sliding_window=} must be positive.")
   if soft_cap is not None and soft_cap == 0.0:
     raise ValueError(f"{soft_cap=} must not be 0.0.")
+  if (num_kv_pages_per_block is not None and
+      not 0 < num_kv_pages_per_block <= pages_per_seq):
+    raise ValueError(
+        f"{num_kv_pages_per_block=} must be in range (0, {pages_per_seq}].")
+  if num_queries_per_block is not None and num_queries_per_block <= 0:
+    raise ValueError(f"{num_queries_per_block=} must be positive.")
+  if vmem_limit_bytes is not None and vmem_limit_bytes <= 0:
+    raise ValueError(f"{vmem_limit_bytes=} must be positive.")
+  del sm_scale  # No constraints on sm_scale.
+  del mask_value  # No constraints on mask_value.


 def ragged_paged_attention_kernel(
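Since the new checks are skipped for arguments left as None, existing callers keep compiling unchanged; only explicitly bad values are rejected. A hypothetical failing call, assuming arrays shaped per the signature comments above:

    # pages_per_seq is page_indices.shape[1]; asking for more KV pages per
    # block than a sequence can hold fails fast at compile time.
    static_validate_inputs(
        q, kv_pages, kv_lens, page_indices, cu_q_lens, num_seqs,
        num_kv_pages_per_block=page_indices.shape[1] + 1,
    )  # ValueError: num_kv_pages_per_block=... must be in range (0, pages_per_seq].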
""" - validate_static_inputs(q, kv_pages, kv_lens, page_indices, cu_q_lens, - num_seqs, sliding_window, soft_cap) + static_validate_inputs( + q, + kv_pages, + kv_lens, + page_indices, + cu_q_lens, + num_seqs, + sm_scale=sm_scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + mask_value=mask_value, + num_kv_pages_per_block=num_kv_pages_per_block, + num_queries_per_block=num_queries_per_block, + vmem_limit_bytes=vmem_limit_bytes, + ) if mask_value is None: mask_value = DEFAULT_MASK_VALUE - num_q, num_q_heads, head_dim = q.shape + num_q_tokens, num_q_heads, head_dim = q.shape _, page_size, num_combined_kv_heads, _ = kv_pages.shape assert num_combined_kv_heads % 2 == 0 num_kv_heads = num_combined_kv_heads // 2 + _, pages_per_seq = page_indices.shape + num_q_heads_per_blk, num_combined_kv_heads_per_blk = get_min_heads_per_blk( + num_q_heads, num_combined_kv_heads, q.dtype, kv_pages.dtype) num_q_per_blk = num_queries_per_block num_kv_pages_per_blk = num_kv_pages_per_block + if num_q_per_blk is None or num_kv_pages_per_blk is None: + num_kv_pages_per_blk, num_q_per_blk = get_tuned_block_sizes( + q.dtype, + kv_pages.dtype, + num_q_heads_per_blk, + num_combined_kv_heads_per_blk // 2, + head_dim, + page_size, + num_q_tokens, + pages_per_seq, + ) num_q_heads_per_kv_head = num_q_heads // num_kv_heads - num_q_blks = cdiv(num_q, num_q_per_blk) - num_q_heads_per_blk, num_combined_kv_heads_per_blk = get_min_heads_per_blk( - num_q_heads, num_combined_kv_heads, q.dtype, kv_pages.dtype) + num_q_blks = cdiv(num_q_tokens, num_q_per_blk) assert num_combined_kv_heads_per_blk % 2 == 0 num_kv_heads_per_blk = num_combined_kv_heads_per_blk // 2 assert num_q_heads_per_blk % num_q_heads_per_kv_head == 0 diff --git a/torch_xla/experimental/tuned_block_sizes.py b/torch_xla/experimental/tuned_block_sizes.py deleted file mode 100644 index d3cb0ae25da2..000000000000 --- a/torch_xla/experimental/tuned_block_sizes.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch_xla - - -def _next_power_of_2_bit_manipulation(x): - """ - Finds the smallest power of 2 >= x using bit manipulation. - Assumes x is an integer. - - Args: - x: The input number (should be an integer). - - Returns: - The smallest integer power of 2 that is >= x. - Returns 1 if x <= 0. 
- """ - if x <= 0: - return 1 - if x == 1: - return 1 - return 1 << (x - 1).bit_length() - - -# ragged_paged_attention -# key: (q_head_num, kv_head_num, token_num, max_model_len) -# value: (num_kv_pages_per_block, num_queries_per_block) - - -def _simplify_key_ragged_paged_attention(q_head_num, kv_head_num, token_num, - max_model_len): - token_num = _next_power_of_2_bit_manipulation(token_num) - max_model_len = _next_power_of_2_bit_manipulation(max_model_len) - return q_head_num, kv_head_num, token_num, max_model_len - - -# TODO: add more tuned block sizes in the table -# q_head_num, kv_head_num, token_num, max_model_len -_ragged_attention_table = { - (32, 8, 4096, 2048): (128, 64), - (4, 1, 4096, 2048): (128, 128), - (32, 8, 2048, 2048): (128, 32), - (4, 1, 2048, 2048): (128, 64), - (32, 8, 1024, 2048): (64, 32), - (1, 1, 1024, 2048): (64, 32), - (32, 8, 4096, 4096): (128, 64), - (4, 1, 4096, 4096): (128, 128), - (32, 8, 2048, 4096): (128, 32), - (4, 1, 2048, 4096): (128, 64), - (32, 8, 1024, 4096): (64, 32), - (1, 1, 1024, 4096): (64, 32), - (32, 8, 4096, 64): (32, 32), - (4, 1, 4096, 64): (32, 32), - (32, 8, 2048, 64): (32, 32), - (4, 1, 2048, 64): (32, 32), - (32, 8, 1024, 64): (32, 32), - (1, 1, 1024, 64): (32, 32), - (32, 8, 4096, 128): (32, 32), - (4, 1, 4096, 128): (32, 32), - (32, 8, 2048, 128): (32, 32), - (4, 1, 2048, 128): (32, 32), - (32, 8, 1024, 128): (32, 32), - (1, 1, 1024, 128): (32, 32), - (10, 2, 4096, 2048): (128, 32), # Qwen/Qwen2.5-32B - (10, 2, 2048, 2048): (128, 32), # Qwen/Qwen2.5-32B - (10, 2, 1024, 2048): (128, 32), # Qwen/Qwen2.5-32B - (5, 1, 4098, 2048): (128, 64), # Qwen/Qwen2.5-32B - (5, 1, 2048, 2048): (128, 32), # Qwen/Qwen2.5-32B - (5, 1, 1024, 2048): (128, 32), # Qwen/Qwen2.5-32B -} - - -def get_ragged_attention_tuned_block_size(q_head_num, kv_head_num, token_num, - max_model_len): - tpu_version = torch_xla.tpu.version() - if tpu_version < 4: - raise NotImplementedError("TPU version must be 4 or higher.") - if tpu_version == 4: - # This default block size is not tuned, only make sure there's no - # OOM in vmem - num_kv_pages_per_block = 16 - num_queries_per_block = 128 - return num_kv_pages_per_block, num_queries_per_block - - key = _simplify_key_ragged_paged_attention(q_head_num, kv_head_num, token_num, - max_model_len) - block_sizes = _ragged_attention_table.get(key, (128, 32)) - return block_sizes