From 2b9806064bab55a7c84388e55cadfc18ef9e0a88 Mon Sep 17 00:00:00 2001 From: Fenghui Zhang Date: Wed, 15 Jan 2025 18:43:38 +0000 Subject: [PATCH 01/13] Pipes attn_logits_soft_cap through multi_queries_paged_attention --- test/test_pallas.py | 4 +--- torch_xla/experimental/custom_kernel.py | 24 ++++++++++++++++-------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/test/test_pallas.py b/test/test_pallas.py index d49df491dc06..d9ce3d919ac3 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -1,5 +1,4 @@ import logging -import os import unittest import torch @@ -606,6 +605,7 @@ def test_paged_attention_multi_queries_wrapper(self): effective_q_lens_xla, num_kv_pages_per_compute_block=block_kv_size // page_size, num_queries_per_compute_block=num_queries_per_compute_block, + attn_logits_soft_cap=0.3, ) nonkernel_output = multi_queries_paged_attention( @@ -822,7 +822,6 @@ def test_paged_attention_wrapper_with_dynamo(self): num_kv_heads = 8 q_kv_head_ratio = 8 head_dim = 256 - dtype = torch.float32 seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32) q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv( @@ -899,7 +898,6 @@ def test_paged_attention_wrapper_with_attn_logits_soft_cap(self): num_kv_heads = 8 q_kv_head_ratio = 8 head_dim = 256 - dtype = torch.float32 seq_lens = torch.tensor([0, 3, 256, 513, 1023, 2048], dtype=torch.int32) q, k_pages, v_pages, page_indices = self._pagedattention_generate_qkv( diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 8ccc4ddbc597..5e3eefdfdb7f 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -569,6 +569,7 @@ def multi_queries_paged_attention( num_kv_pages_per_compute_block, num_queries_per_compute_block, use_kernel=True, + attn_logits_soft_cap: float | None = None, ): # [batch_size, query_len, num_heads, head_dim]: assert len(q.shape) == 4, "q should have 4 dimensions." 
if not use_kernel: @@ -595,6 +596,7 @@ def multi_queries_paged_attention( effective_q_lens, num_kv_pages_per_compute_block=num_kv_pages_per_compute_block, num_queries_per_compute_block=num_queries_per_compute_block, + attn_logits_soft_cap=attn_logits_soft_cap, static_argnames=[ "num_kv_pages_per_compute_block", "num_queries_per_compute_block", @@ -1103,21 +1105,27 @@ def paged_attention_non_xla(q: torch.Tensor, XLA_LIB.define( - "multi_queries_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, Tensor effective_q_lens, int num_kv_pages_per_compute_block, int num_queries_per_compute_block, bool use_kernel) -> Tensor", -) + "multi_queries_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices," + " Tensor effective_q_lens, int num_kv_pages_per_compute_block, int num_queries_per_compute_block," + " bool use_kernel, float attn_logits_soft_cap=None) -> Tensor",) @impl(XLA_LIB, "multi_queries_paged_attention", "XLA") -def multi_queries_paged_attention_xla( - q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor, - lengths: torch.Tensor, page_indices: torch.Tensor, - effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int, - num_queries_per_compute_block: int, use_kernel: bool): +def multi_queries_paged_attention_xla(q: torch.Tensor, + k_pages: torch.Tensor, + v_pages: torch.Tensor, + lengths: torch.Tensor, + page_indices: torch.Tensor, + effective_q_lens: torch.Tensor, + num_kv_pages_per_compute_block: int, + num_queries_per_compute_block: int, + use_kernel: bool, + attn_logits_soft_cap: float = None): return multi_queries_paged_attention(q, k_pages, v_pages, lengths, page_indices, effective_q_lens, num_kv_pages_per_compute_block, num_queries_per_compute_block, - use_kernel) + use_kernel, attn_logits_soft_cap) @impl(XLA_LIB, "multi_queries_paged_attention", "CompositeExplicitAutograd") From 8106ad2db87bc00a9387bbaf1d7a1052d3608b61 Mon Sep 17 00:00:00 2001 From: Fenghui Zhang Date: Wed, 15 Jan 2025 18:43:38 +0000 Subject: [PATCH 02/13] Implements attn_logits_soft_cap and pass it through multi_queries_paged_attention --- test/test_pallas.py | 34 +++++++++++++++++-- torch_xla/experimental/custom_kernel.py | 1 + .../multi_queries_paged_attention_kernel.py | 13 ++++++- 3 files changed, 45 insertions(+), 3 deletions(-) diff --git a/test/test_pallas.py b/test/test_pallas.py index d9ce3d919ac3..9d516ee7dbf3 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -596,6 +596,17 @@ def test_paged_attention_multi_queries_wrapper(self): page_indices_xla = page_indices.to("xla") effective_q_lens_xla = effective_q_lens.to("xla") + output_no_cap = multi_queries_paged_attention( + q_xla, + k_pages_xla, + v_pages_xla, + kv_seq_lens_xla, + page_indices_xla, + effective_q_lens_xla, + num_kv_pages_per_compute_block=block_kv_size // page_size, + num_queries_per_compute_block=num_queries_per_compute_block, + ) + output = multi_queries_paged_attention( q_xla, k_pages_xla, @@ -605,7 +616,7 @@ def test_paged_attention_multi_queries_wrapper(self): effective_q_lens_xla, num_kv_pages_per_compute_block=block_kv_size // page_size, num_queries_per_compute_block=num_queries_per_compute_block, - attn_logits_soft_cap=0.3, + attn_logits_soft_cap=1.0, ) nonkernel_output = multi_queries_paged_attention( @@ -627,6 +638,19 @@ def test_paged_attention_multi_queries_wrapper(self): page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32) effective_q_lens_jax = jnp.array(effective_q_lens.numpy(), dtype=jnp.int32) 
expected_output = torch.from_numpy( + np.array( + jax_multi_queries_paged_attention( + q_jax, + k_pages_jax, + v_pages_jax, + kv_seq_lens_jax, + page_indices_jax, + effective_q_lens_jax, + num_kv_pages_per_compute_block=block_kv_size // page_size, + num_queries_per_compute_block=num_queries_per_compute_block, + attn_logits_soft_cap=1.0, + ))) + expected_output_no_cap = torch.from_numpy( np.array( jax_multi_queries_paged_attention( q_jax, @@ -642,9 +666,15 @@ def test_paged_attention_multi_queries_wrapper(self): self.assertTrue( torch.allclose( output.cpu(), expected_output.cpu(), atol=1e-5, rtol=1e-5)) + self.assertFalse( + torch.allclose( + output.cpu(), expected_output_no_cap.cpu(), atol=1e-5, rtol=1e-5)) self.assertTrue( torch.allclose( - output.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2)) + output_no_cap.cpu(), expected_output_no_cap.cpu(), atol=1e-5, rtol=1e-5)) + self.assertTrue( + torch.allclose( + output_no_cap.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2)) @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, "This test only works on TPUv4+.") diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 5e3eefdfdb7f..2d766a18ec1a 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -600,6 +600,7 @@ def multi_queries_paged_attention( static_argnames=[ "num_kv_pages_per_compute_block", "num_queries_per_compute_block", + "attn_logits_soft_cap", ], ) diff --git a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py index 84d6ad530e50..e368ccb5f89b 100644 --- a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py +++ b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py @@ -116,6 +116,7 @@ def _flash_attention( query_len: int, page_size: int, head_dim: int, + attn_logits_soft_cap: float = None, ): b, kv_head_idx, q_blk_idx, kv_blk_idx = ( pl.program_id(0), @@ -143,6 +144,10 @@ def start_new_sequence(): s = jnp.einsum( 'qd,td->qt', q, k, preferred_element_type=jnp.float32) # [block_q, block_k] + if attn_logits_soft_cap is not None: + capped_s = jnp.tanh(s / attn_logits_soft_cap) + s = capped_s * attn_logits_soft_cap + assert s.shape == (num_queries_per_compute_block, kv_seq_len_per_kv_compute_blk) @@ -266,6 +271,7 @@ def paged_flash_attention_kernel( num_kv_pages_per_compute_block: int, mask_value: float, query_len: int, + attn_logits_soft_cap: float | None = None, ): """Pallas kernel for paged attention.""" b, kv_head_idx, q_blk_idx, kv_blk_idx = ( @@ -411,6 +417,7 @@ def prefetch_next_block(): # pylint: disable=unused-variable query_len=query_len, page_size=page_size, head_dim=head_dim, + attn_logits_soft_cap=attn_logits_soft_cap, ) # o_ref.shape=[num_q_heads_per_kv_head, num_queries_per_compute_block, head_dim] step_ref[0] = step + 1 @@ -428,6 +435,7 @@ def prefetch_next_block(): # pylint: disable=unused-variable "num_kv_pages_per_compute_block", "num_queries_per_compute_block", "mask_value", + "attn_logits_soft_cap", ], ) def paged_attention( @@ -441,6 +449,7 @@ def paged_attention( mask_value: float = DEFAULT_MASK_VALUE, num_kv_pages_per_compute_block: int, num_queries_per_compute_block: int = 4, + attn_logits_soft_cap: float | None = None, ) -> jax.Array: """Paged grouped query attention. 
@@ -620,7 +629,9 @@ def lm_index_map(batch_index, kv_head_index, q_seq_blk_idx, *_): batch_size=batch_size, num_kv_pages_per_compute_block=num_kv_pages_per_compute_block, mask_value=mask_value, - query_len=query_len), + query_len=query_len, + attn_logits_soft_cap=attn_logits_soft_cap, + ), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=5, in_specs=in_specs, From 8802322e87ba955c99dc096e3432b762592e6fe3 Mon Sep 17 00:00:00 2001 From: Fenghui Zhang Date: Wed, 15 Jan 2025 18:43:38 +0000 Subject: [PATCH 03/13] Implements attn_logits_soft_cap and pass it through multi_queries_paged_attention --- test/test_pallas.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_pallas.py b/test/test_pallas.py index 9d516ee7dbf3..2a59cee9d9a0 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -671,7 +671,10 @@ def test_paged_attention_multi_queries_wrapper(self): output.cpu(), expected_output_no_cap.cpu(), atol=1e-5, rtol=1e-5)) self.assertTrue( torch.allclose( - output_no_cap.cpu(), expected_output_no_cap.cpu(), atol=1e-5, rtol=1e-5)) + output_no_cap.cpu(), + expected_output_no_cap.cpu(), + atol=1e-5, + rtol=1e-5)) self.assertTrue( torch.allclose( output_no_cap.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2)) From 9e57ad4274e7511f928fc589b2917da30efc4860 Mon Sep 17 00:00:00 2001 From: Fenghui Zhang Date: Wed, 15 Jan 2025 18:43:38 +0000 Subject: [PATCH 04/13] Implements attn_logits_soft_cap and pass it through multi_queries_paged_attention --- torch_xla/experimental/custom_kernel.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 2d766a18ec1a..8c87b5f0fcfc 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -1134,7 +1134,8 @@ def multi_queries_paged_attention_non_xla( q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor, lengths: torch.Tensor, page_indices: torch.Tensor, effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int, - num_queries_per_compute_block: int, use_kernel: bool): + num_queries_per_compute_block: int, use_kernel: bool, + attn_logits_soft_cap: float = None): return non_xla_attetion(q, k_pages, v_pages, "paged") From 18358764528ebe7f5a03dd2f66697ed8519d8745 Mon Sep 17 00:00:00 2001 From: Fenghui Zhang Date: Wed, 15 Jan 2025 18:43:38 +0000 Subject: [PATCH 05/13] Implements attn_logits_soft_cap and pass it through multi_queries_paged_attention --- torch_xla/experimental/custom_kernel.py | 4 ++-- .../pallas_kernels/multi_queries_paged_attention_kernel.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 8c87b5f0fcfc..9733fd4c5365 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -1121,7 +1121,7 @@ def multi_queries_paged_attention_xla(q: torch.Tensor, num_kv_pages_per_compute_block: int, num_queries_per_compute_block: int, use_kernel: bool, - attn_logits_soft_cap: float = None): + attn_logits_soft_cap: float | None): return multi_queries_paged_attention(q, k_pages, v_pages, lengths, page_indices, effective_q_lens, num_kv_pages_per_compute_block, @@ -1135,7 +1135,7 @@ def multi_queries_paged_attention_non_xla( lengths: torch.Tensor, page_indices: torch.Tensor, effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int, num_queries_per_compute_block: int, use_kernel: bool, - attn_logits_soft_cap: 
float = None): + attn_logits_soft_cap: float | None): return non_xla_attetion(q, k_pages, v_pages, "paged") diff --git a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py index e368ccb5f89b..2b2c0ff7f334 100644 --- a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py +++ b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py @@ -116,7 +116,7 @@ def _flash_attention( query_len: int, page_size: int, head_dim: int, - attn_logits_soft_cap: float = None, + attn_logits_soft_cap: float | None, ): b, kv_head_idx, q_blk_idx, kv_blk_idx = ( pl.program_id(0), From 68cd431c7711dcecadeb33fc4a5bc6e385a39e63 Mon Sep 17 00:00:00 2001 From: Fenghui Zhang Date: Wed, 15 Jan 2025 18:43:38 +0000 Subject: [PATCH 06/13] Implements attn_logits_soft_cap and pass it through multi_queries_paged_attention --- test/test_pallas.py | 1 + test/test_tpu_paged_attention_kernel.py | 8 ++++++++ test/tpu/run_tests.sh | 2 ++ torch_xla/experimental/custom_kernel.py | 21 +++++++++++---------- 4 files changed, 22 insertions(+), 10 deletions(-) diff --git a/test/test_pallas.py b/test/test_pallas.py index 2a59cee9d9a0..869eb08164a2 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -740,6 +740,7 @@ def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens, num_kv_pages_per_compute_block, num_queries_per_compute_block, use_kernel=use_kernel, + attn_logits_soft_cap=1.0, ) compiled_paged_attention = torch.compile( diff --git a/test/test_tpu_paged_attention_kernel.py b/test/test_tpu_paged_attention_kernel.py index 746439ba4d0f..8749334b7112 100644 --- a/test/test_tpu_paged_attention_kernel.py +++ b/test/test_tpu_paged_attention_kernel.py @@ -45,6 +45,7 @@ def _ref_jax_extended_paged_attention( lengths, # [batch_size], the effective kv_length. page_indices, # [batch_size, pages_per_sequence] effective_q_lens, # [batch_size] the effective q_length + attn_logits_soft_cap: float | None = None, ): batch_size, query_len, num_query_heads, head_size = q.shape num_kv_heads, total_num_pages, page_size, _ = k_pages.shape @@ -71,6 +72,9 @@ def _ref_jax_extended_paged_attention( v = jnp.repeat(v, num_query_per_kv, axis=1) attn = jnp.einsum("qhd,khd->hqk", q[i], k) + if attn_logits_soft_cap is not None: + capped_attn = jnp.tanh(attn / attn_logits_soft_cap) + attn = capped_attn * attn_logits_soft_cap attn = attn.astype('float32') effective_q_len = effective_q_lens[i] q_span = (kv_len - effective_q_len) + jax.lax.broadcasted_iota( @@ -111,6 +115,7 @@ def setUp(self): head_dim=(128, 256), num_queries_per_compute_block=(16, 32), block_kv_size=(128, 256), + attn_logits_soft_cap=(1.0, None), ) def test_paged_attention_without_query_padding( self, @@ -121,6 +126,7 @@ def test_paged_attention_without_query_padding( head_dim, num_queries_per_compute_block, block_kv_size, + attn_logits_soft_cap, ): max_kv_len = 2048 @@ -160,6 +166,7 @@ def test_paged_attention_without_query_padding( effective_q_lens, num_kv_pages_per_compute_block=num_kv_pages_per_compute_block, num_queries_per_compute_block=num_queries_per_compute_block, + attn_logits_soft_cap=attn_logits_soft_cap, ) # Note kernel execution is async. Without blocking, if an error happens in the kernel, the error may point to some irrelevant and confusing places. 
See https://github.com/pytorch/xla/pull/8356#issuecomment-2486861631 actual_output = jax.block_until_ready(actual_output) @@ -172,6 +179,7 @@ def test_paged_attention_without_query_padding( kv_seq_lens, page_indices, effective_q_lens, + attn_logits_soft_cap=attn_logits_soft_cap, ) self.assertEqual(actual_output.shape, expected_output.shape) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 99f8f6aa628b..5afcd5877ed8 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -36,6 +36,7 @@ python3 "$TEST_CDIR/scan/test_scan_layers.py" run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py" python3 "$TEST_CDIR/test_pallas.py" -v python3 "$TEST_CDIR/test_pallas_spmd.py" +XLA_DISABLE_FUNCTIONALIZATION=1 python3 "$TEST_CDIR/test_pallas_spmd.py" python3 "$TEST_CDIR/test_tpu_paged_attention_kernel.py" python3 "$TEST_CDIR/test_input_output_aliases.py" python3 "$TEST_CDIR/test_gmm.py" @@ -46,6 +47,7 @@ python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_multi_all_reduce_xl python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py" python3 "$TEST_CDIR/quantized_ops/test_dot_general.py" run_xla_ir_hlo_debug python3 "$TEST_CDIR/test_user_computation_debug_cache.py" +python3 "$TEST_CDIR/test_data_type.py" # run examples, each test should takes <2 minutes python3 "$TEST_CDIR/../examples/data_parallel/train_resnet_spmd_data_parallel.py" diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 9733fd4c5365..2a8dd94eee4e 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -506,6 +506,7 @@ def _multi_queries_paged_attention_nonkernel( lengths, # seq_lengths, [batch_size]. nb batch_size = len(seq_lens), the effective kv_length. page_indices, # [batch_size, pages_per_sequence] effective_q_lens, # [batch_size], the effective q_length + attn_logits_soft_cap: float | None = None, ) -> torch.Tensor: # [batch_size, query_len, num_heads, head_dim] batch_size, query_len, num_query_heads, head_size = q.shape num_kv_heads, total_num_pages, page_size, _ = k_pages.shape @@ -543,6 +544,9 @@ def _multi_queries_paged_attention_nonkernel( # For example, it can use bfloat16 instead of float32 or vice versa for performance or simplicity. 
attn = torch.einsum("qhd,khd->hqk", q[i], k) # [num_query_heads, query_len, kv_len] + if attn_logits_soft_cap is not None: + capped_attn = torch.tanh(attn / attn_logits_soft_cap) + attn = capped_attn * attn_logits_soft_cap attn = attn.float() empty_mask = torch.ones(query_len, kv_len, device=attn.device) effective_q_len = effective_q_lens[i] @@ -580,6 +584,7 @@ def multi_queries_paged_attention( lengths, page_indices, effective_q_lens, + attn_logits_soft_cap=attn_logits_soft_cap, ) # Import JAX within the function such that we don't need to call the jax_import_guard() @@ -1112,16 +1117,12 @@ def paged_attention_non_xla(q: torch.Tensor, @impl(XLA_LIB, "multi_queries_paged_attention", "XLA") -def multi_queries_paged_attention_xla(q: torch.Tensor, - k_pages: torch.Tensor, - v_pages: torch.Tensor, - lengths: torch.Tensor, - page_indices: torch.Tensor, - effective_q_lens: torch.Tensor, - num_kv_pages_per_compute_block: int, - num_queries_per_compute_block: int, - use_kernel: bool, - attn_logits_soft_cap: float | None): +def multi_queries_paged_attention_xla( + q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor, + lengths: torch.Tensor, page_indices: torch.Tensor, + effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int, + num_queries_per_compute_block: int, use_kernel: bool, + attn_logits_soft_cap: float | None): return multi_queries_paged_attention(q, k_pages, v_pages, lengths, page_indices, effective_q_lens, num_kv_pages_per_compute_block, From 491dbdb1ae940eccc6d8122c506b5c02159a202f Mon Sep 17 00:00:00 2001 From: Fenghui Zhang Date: Wed, 15 Jan 2025 18:43:38 +0000 Subject: [PATCH 07/13] Implements attn_logits_soft_cap and pass it through multi_queries_paged_attention --- .../pallas_kernels/multi_queries_paged_attention_kernel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py index 2b2c0ff7f334..dc03d7bca851 100644 --- a/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py +++ b/torch_xla/experimental/pallas_kernels/multi_queries_paged_attention_kernel.py @@ -271,7 +271,7 @@ def paged_flash_attention_kernel( num_kv_pages_per_compute_block: int, mask_value: float, query_len: int, - attn_logits_soft_cap: float | None = None, + attn_logits_soft_cap: float | None, ): """Pallas kernel for paged attention.""" b, kv_head_idx, q_blk_idx, kv_blk_idx = ( From b8660fedb5592d283d8d03d7d97e1bec170cc86e Mon Sep 17 00:00:00 2001 From: Fenghui Zhang Date: Wed, 15 Jan 2025 18:43:38 +0000 Subject: [PATCH 08/13] Implements attn_logits_soft_cap and pass it through multi_queries_paged_attention --- test/tpu/run_tests.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh index 5afcd5877ed8..99f8f6aa628b 100755 --- a/test/tpu/run_tests.sh +++ b/test/tpu/run_tests.sh @@ -36,7 +36,6 @@ python3 "$TEST_CDIR/scan/test_scan_layers.py" run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py" python3 "$TEST_CDIR/test_pallas.py" -v python3 "$TEST_CDIR/test_pallas_spmd.py" -XLA_DISABLE_FUNCTIONALIZATION=1 python3 "$TEST_CDIR/test_pallas_spmd.py" python3 "$TEST_CDIR/test_tpu_paged_attention_kernel.py" python3 "$TEST_CDIR/test_input_output_aliases.py" python3 "$TEST_CDIR/test_gmm.py" @@ -47,7 +46,6 @@ python3 "$TEST_CDIR/torch_distributed/test_torch_distributed_multi_all_reduce_xl python3 
"$TEST_CDIR/torch_distributed/test_torch_distributed_reduce_scatter_xla_backend.py" python3 "$TEST_CDIR/quantized_ops/test_dot_general.py" run_xla_ir_hlo_debug python3 "$TEST_CDIR/test_user_computation_debug_cache.py" -python3 "$TEST_CDIR/test_data_type.py" # run examples, each test should takes <2 minutes python3 "$TEST_CDIR/../examples/data_parallel/train_resnet_spmd_data_parallel.py" From 351de895b865685a858e21451df1412aedd872e3 Mon Sep 17 00:00:00 2001 From: Fenghui Zhang Date: Wed, 15 Jan 2025 18:43:38 +0000 Subject: [PATCH 09/13] Implements attn_logits_soft_cap and pass it through multi_queries_paged_attention --- test/test_pallas.py | 58 +++++++++++++------------ torch_xla/experimental/custom_kernel.py | 6 +-- 2 files changed, 34 insertions(+), 30 deletions(-) diff --git a/test/test_pallas.py b/test/test_pallas.py index 869eb08164a2..fdc904bd98e6 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -729,7 +729,8 @@ def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens, page_indices, effective_q_lens, num_kv_pages_per_compute_block, num_queries_per_compute_block, - use_kernel): + use_kernel, + attn_logits_soft_cap): return torch.ops.xla.multi_queries_paged_attention( q, k_pages, @@ -740,39 +741,42 @@ def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens, num_kv_pages_per_compute_block, num_queries_per_compute_block, use_kernel=use_kernel, - attn_logits_soft_cap=1.0, + attn_logits_soft_cap=attn_logits_soft_cap, ) compiled_paged_attention = torch.compile( multi_queries_paged_attention_wrapper, backend="openxla") - output = compiled_paged_attention( - q_xla, - k_pages_xla, - v_pages_xla, - kv_seq_lens_xla, - page_indices_xla, - effective_q_lens_xla, - num_kv_pages_per_compute_block=block_kv_size // page_size, - num_queries_per_compute_block=num_queries_per_compute_block, - use_kernel=True, - ) + for attn_logits_soft_cap in (1.0, None): + output = compiled_paged_attention( + q_xla, + k_pages_xla, + v_pages_xla, + kv_seq_lens_xla, + page_indices_xla, + effective_q_lens_xla, + num_kv_pages_per_compute_block=block_kv_size // page_size, + num_queries_per_compute_block=num_queries_per_compute_block, + use_kernel=True, + attn_logits_soft_cap=attn_logits_soft_cap, + ) - nonkernel_output = compiled_paged_attention( - q_xla, - k_pages_xla, - v_pages_xla, - kv_seq_lens_xla, - page_indices_xla, - effective_q_lens_xla, - num_kv_pages_per_compute_block=block_kv_size // page_size, - num_queries_per_compute_block=num_queries_per_compute_block, - use_kernel=False, - ) + nonkernel_output = compiled_paged_attention( + q_xla, + k_pages_xla, + v_pages_xla, + kv_seq_lens_xla, + page_indices_xla, + effective_q_lens_xla, + num_kv_pages_per_compute_block=block_kv_size // page_size, + num_queries_per_compute_block=num_queries_per_compute_block, + use_kernel=False, + attn_logits_soft_cap=attn_logits_soft_cap, + ) - self.assertTrue( - torch.allclose( - output.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2)) + self.assertTrue( + torch.allclose( + output.cpu(), nonkernel_output.cpu(), atol=1e-2, rtol=1e-2)) @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() != 4, "This test only works on TPUv4 and TPUv5p.") diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 2a8dd94eee4e..d26d9f649b7e 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -1113,7 +1113,7 @@ def paged_attention_non_xla(q: torch.Tensor, XLA_LIB.define( "multi_queries_paged_attention(Tensor q, 
Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices," " Tensor effective_q_lens, int num_kv_pages_per_compute_block, int num_queries_per_compute_block," - " bool use_kernel, float attn_logits_soft_cap=None) -> Tensor",) + " bool use_kernel, float? attn_logits_soft_cap=None) -> Tensor",) @impl(XLA_LIB, "multi_queries_paged_attention", "XLA") @@ -1122,7 +1122,7 @@ def multi_queries_paged_attention_xla( lengths: torch.Tensor, page_indices: torch.Tensor, effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int, num_queries_per_compute_block: int, use_kernel: bool, - attn_logits_soft_cap: float | None): + attn_logits_soft_cap: float | None = None): return multi_queries_paged_attention(q, k_pages, v_pages, lengths, page_indices, effective_q_lens, num_kv_pages_per_compute_block, @@ -1136,7 +1136,7 @@ def multi_queries_paged_attention_non_xla( lengths: torch.Tensor, page_indices: torch.Tensor, effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int, num_queries_per_compute_block: int, use_kernel: bool, - attn_logits_soft_cap: float | None): + attn_logits_soft_cap: float | None = None): return non_xla_attetion(q, k_pages, v_pages, "paged") From 2ce9e2fa2eccc6d829c8c10bc83bebf64e662f89 Mon Sep 17 00:00:00 2001 From: Fenghui Zhang Date: Wed, 15 Jan 2025 18:43:38 +0000 Subject: [PATCH 10/13] Implements attn_logits_soft_cap and pass it through multi_queries_paged_attention --- test/test_pallas.py | 14 ++++++++++ torch_xla/experimental/custom_kernel.py | 34 ++++++++++++++++--------- 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/test/test_pallas.py b/test/test_pallas.py index fdc904bd98e6..3e165dc80352 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -729,8 +729,22 @@ def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens, page_indices, effective_q_lens, num_kv_pages_per_compute_block, num_queries_per_compute_block, +<<<<<<< HEAD use_kernel, attn_logits_soft_cap): +======= +<<<<<<< HEAD + use_kernel, + attn_logits_soft_cap): +======= +<<<<<<< HEAD + use_kernel, + attn_logits_soft_cap): +======= + use_kernel, attn_logits_soft_cap): +>>>>>>> 0a91471da (Implements attn_logits_soft_cap and pass it through multi_queries_paged_attention) +>>>>>>> 47e8d1d00 (Implements attn_logits_soft_cap and pass it through multi_queries_paged_attention) +>>>>>>> d430cb4e9 (Implements attn_logits_soft_cap and pass it through multi_queries_paged_attention) return torch.ops.xla.multi_queries_paged_attention( q, k_pages, diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index d26d9f649b7e..185d2085e7f1 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -1117,12 +1117,17 @@ def paged_attention_non_xla(q: torch.Tensor, @impl(XLA_LIB, "multi_queries_paged_attention", "XLA") -def multi_queries_paged_attention_xla( - q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor, - lengths: torch.Tensor, page_indices: torch.Tensor, - effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int, - num_queries_per_compute_block: int, use_kernel: bool, - attn_logits_soft_cap: float | None = None): +def multi_queries_paged_attention_xla(q: torch.Tensor, + k_pages: torch.Tensor, + v_pages: torch.Tensor, + lengths: torch.Tensor, + page_indices: torch.Tensor, + effective_q_lens: torch.Tensor, + num_kv_pages_per_compute_block: int, + num_queries_per_compute_block: int, + use_kernel: bool, + attn_logits_soft_cap: float | + None = None): return 
multi_queries_paged_attention(q, k_pages, v_pages, lengths, page_indices, effective_q_lens, num_kv_pages_per_compute_block, @@ -1131,12 +1136,17 @@ def multi_queries_paged_attention_xla( @impl(XLA_LIB, "multi_queries_paged_attention", "CompositeExplicitAutograd") -def multi_queries_paged_attention_non_xla( - q: torch.Tensor, k_pages: torch.Tensor, v_pages: torch.Tensor, - lengths: torch.Tensor, page_indices: torch.Tensor, - effective_q_lens: torch.Tensor, num_kv_pages_per_compute_block: int, - num_queries_per_compute_block: int, use_kernel: bool, - attn_logits_soft_cap: float | None = None): +def multi_queries_paged_attention_non_xla(q: torch.Tensor, + k_pages: torch.Tensor, + v_pages: torch.Tensor, + lengths: torch.Tensor, + page_indices: torch.Tensor, + effective_q_lens: torch.Tensor, + num_kv_pages_per_compute_block: int, + num_queries_per_compute_block: int, + use_kernel: bool, + attn_logits_soft_cap: float | + None = None): return non_xla_attetion(q, k_pages, v_pages, "paged") From 19cf3a0f978dd7eb99e73441bb8979bd3df3002c Mon Sep 17 00:00:00 2001 From: Fenghui Zhang Date: Wed, 15 Jan 2025 18:43:38 +0000 Subject: [PATCH 11/13] Implements attn_logits_soft_cap and pass it through multi_queries_paged_attention --- test/test_pallas.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/test/test_pallas.py b/test/test_pallas.py index 3e165dc80352..fdc904bd98e6 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -729,22 +729,8 @@ def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens, page_indices, effective_q_lens, num_kv_pages_per_compute_block, num_queries_per_compute_block, -<<<<<<< HEAD use_kernel, attn_logits_soft_cap): -======= -<<<<<<< HEAD - use_kernel, - attn_logits_soft_cap): -======= -<<<<<<< HEAD - use_kernel, - attn_logits_soft_cap): -======= - use_kernel, attn_logits_soft_cap): ->>>>>>> 0a91471da (Implements attn_logits_soft_cap and pass it through multi_queries_paged_attention) ->>>>>>> 47e8d1d00 (Implements attn_logits_soft_cap and pass it through multi_queries_paged_attention) ->>>>>>> d430cb4e9 (Implements attn_logits_soft_cap and pass it through multi_queries_paged_attention) return torch.ops.xla.multi_queries_paged_attention( q, k_pages, From 172f9cdd459cc307a44f9e80773fabfbe8129733 Mon Sep 17 00:00:00 2001 From: Fenghui Zhang Date: Wed, 15 Jan 2025 18:43:38 +0000 Subject: [PATCH 12/13] Implements attn_logits_soft_cap and pass it through multi_queries_paged_attention --- test/test_pallas.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_pallas.py b/test/test_pallas.py index fdc904bd98e6..106b917528d3 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -729,8 +729,7 @@ def multi_queries_paged_attention_wrapper(q, k_pages, v_pages, kv_seq_lens, page_indices, effective_q_lens, num_kv_pages_per_compute_block, num_queries_per_compute_block, - use_kernel, - attn_logits_soft_cap): + use_kernel, attn_logits_soft_cap): return torch.ops.xla.multi_queries_paged_attention( q, k_pages, From 633792cfb9fe4aaab536f8416337bbadfbba1fa6 Mon Sep 17 00:00:00 2001 From: Fenghui Zhang Date: Wed, 22 Jan 2025 15:50:07 +0000 Subject: [PATCH 13/13] Fix the signature of paged_attention by marking attn_logits_soft_cap optional --- torch_xla/experimental/custom_kernel.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 185d2085e7f1..6743d562d8ab 100644 --- a/torch_xla/experimental/custom_kernel.py +++ 
b/torch_xla/experimental/custom_kernel.py @@ -1080,7 +1080,8 @@ def flash_attention_non_xla(q: torch.Tensor, XLA_LIB.define( - "paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices, int pages_per_compute_block, str megacore_mode=None, float attn_logits_soft_cap=None) -> Tensor", + "paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor lengths, Tensor page_indices," + " int pages_per_compute_block, str megacore_mode=None, float? attn_logits_soft_cap=None) -> Tensor", )
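
Editor's note, for readers skimming the series: the transform these patches thread through the kernel, the reference path, and the torch op registration is the tanh soft cap, `s = tanh(s / cap) * cap`, applied to the raw attention logits before masking and softmax. Below is a minimal standalone sketch of that transform for reference only — it is not part of the patch; the helper name, shapes, and the toy check are illustrative assumptions, not the kernel's actual block layout or API.

```python
# Minimal sketch of the attn_logits_soft_cap transform used in this series.
# Assumptions: illustrative helper name and tensor shapes; the real kernel
# applies the same math per compute block inside the Pallas kernel.
import torch


def soft_cap_logits(attn: torch.Tensor,
                    attn_logits_soft_cap: float | None) -> torch.Tensor:
  """Squashes attention logits into (-cap, cap) via tanh; no-op when cap is None."""
  if attn_logits_soft_cap is None:
    return attn
  return torch.tanh(attn / attn_logits_soft_cap) * attn_logits_soft_cap


# Toy usage with made-up shapes: logits are [num_heads, query_len, kv_len].
q = torch.randn(4, 8, 16)   # [heads, query_len, head_dim]
k = torch.randn(4, 32, 16)  # [heads, kv_len, head_dim]
attn = torch.einsum("hqd,hkd->hqk", q, k)
capped = soft_cap_logits(attn, attn_logits_soft_cap=1.0)
# tanh bounds the capped logits to (-cap, cap), which is what the new
# test asserts indirectly by checking capped vs. uncapped outputs differ.
assert capped.abs().max() <= 1.0
```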