diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt
index 396be2dd54aee..2a9c139109996 100644
--- a/.ci/docker/ci_commit_pins/triton.txt
+++ b/.ci/docker/ci_commit_pins/triton.txt
@@ -1 +1 @@
-6da9e66008b58a7b8553f96c69021cca0d0028f0
+a34a79dbd711ea9f8fb5090bcaf24a7717574206
diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py
index 2828f48b79c23..fe1620cc90ca3 100644
--- a/torch/_inductor/autotune_process.py
+++ b/torch/_inductor/autotune_process.py
@@ -630,6 +630,8 @@ def __init__(
         num_stages: int,
         num_warps: int,
         matrix_instr_nonkdim: int = 0,  # only used for hip to choose the shape of mfma instruction.
+        waves_per_eu: int = 0,  # only used for hip to schedule waves per execution unit
+        kpack: int = 0,  # ROCm-specific gemm parameter
         workspace_arg: Optional[WorkspaceArg] = None,
     ) -> None:
         super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
@@ -639,6 +641,8 @@ def __init__(
         self.num_stages = num_stages
         self.num_warps = num_warps
         self.matrix_instr_nonkdim = matrix_instr_nonkdim
+        self.waves_per_eu = waves_per_eu
+        self.kpack = kpack
         self.workspace_arg = workspace_arg
 
     def make_run_fn(
diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py
index 2318be5c423e5..2b1884cd00504 100644
--- a/torch/_inductor/kernel/flex_attention.py
+++ b/torch/_inductor/kernel/flex_attention.py
@@ -1,5 +1,7 @@
 # mypy: allow-untyped-defs
 """ Triton Implementation of the flex_attention Kernel"""
+import os
+import itertools
 
 import logging
 import math
@@ -1206,10 +1208,24 @@ def flex_attention(
     if torch.version.hip:
         configs = [(c[0], c[1], c[2], 1) for c in configs]
 
+    # Check if the environment variable is set
+    if os.getenv("TORCHINDUCTOR_EXHAUSTIVE_FLEX_ATTENTION_EXPERIMENTAL") == "1":
+        param1 = [16, 32, 64, 128, 256, 512]
+        param2 = [16, 32, 64, 128, 256, 512]
+        param3 = [2, 4, 8, 16]
+        param4 = [1]
+
+        # Generate full search space
+        configs = list(itertools.product(param1, param2, param3, param4))
+
     # Mark SPARSE_KV_BLOCK_SIZE & SPARSE_Q_BLOCK_SIZE as static shapes and add guards.
     SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
     SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
 
+    # ROCm specific considerations
+    if torch.version.hip:
+        kernel_options["kpack"] = 2
+
     # Note, we don't need to pass in the captured buffers explicitly
     # because they're implicitly added by the score_mod function
     # We do need to explicitly pass it in for autotuning though.
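Editorial note, not part of the patch: the TORCHINDUCTOR_EXHAUSTIVE_FLEX_ATTENTION_EXPERIMENTAL branch above simply takes the Cartesian product of the candidate block sizes and warp counts. Below is a minimal standalone sketch of that expansion, reusing the placeholder names param1..param4 from the hunk; the values are copied from the patch for illustration and are not part of the Inductor API.

    import itertools

    # Candidate values mirrored from the hunk above (illustrative only).
    param1 = [16, 32, 64, 128, 256, 512]  # BLOCK_M candidates
    param2 = [16, 32, 64, 128, 256, 512]  # BLOCK_N candidates
    param3 = [2, 4, 8, 16]                # num_warps candidates
    param4 = [1]                          # num_stages pinned to 1 (ROCm path)

    # Full search space: 6 * 6 * 4 * 1 = 144 (BLOCK_M, BLOCK_N, num_warps, num_stages) tuples.
    configs = list(itertools.product(param1, param2, param3, param4))
    assert len(configs) == 144

Each of these tuples is later benchmarked through maybe_append_choice in the hunk below, so the exhaustive mode trades substantially longer autotuning time for potentially better ROCm configurations.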
@@ -1234,33 +1250,67 @@ def flex_attention(
         cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)
         cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)
 
-        error = flex_attention_template.maybe_append_choice(
-            choices=choices,
-            input_nodes=[
-                query,
-                key,
-                value,
-                logsumexp,
-                kv_num_blocks,
-                kv_indices,
-                full_kv_num_blocks,
-                full_kv_indices,
-            ],
-            layout=layout,
-            subgraphs=[
-                subgraph_buffer,
-                mask_graph_buffer,
-            ],
-            mutated_inputs=[
-                logsumexp,
-            ],
-            num_stages=num_stages,
-            num_warps=num_warps,
-            call_sizes=query.get_size(),
-            **cur_kernel_options,
-        )
-        if error is not None and len(configs) == 1:
-            raise error
+        if os.getenv("TORCHINDUCTOR_EXHAUSTIVE_FLEX_ATTENTION_EXPERIMENTAL") == "1":
+            for mfma in [0, 16]:
+                for wpeu in [0, 1, 2, 4, 8]:
+                    cur_kernel_options["waves_per_eu"] = wpeu
+                    cur_kernel_options["matrix_instr_non_kdim"] = mfma
+                    error = flex_attention_template.maybe_append_choice(
+                        choices=choices,
+                        input_nodes=[
+                            query,
+                            key,
+                            value,
+                            logsumexp,
+                            kv_num_blocks,
+                            kv_indices,
+                            full_kv_num_blocks,
+                            full_kv_indices,
+                        ],
+                        layout=layout,
+                        subgraphs=[
+                            subgraph_buffer,
+                            mask_graph_buffer,
+                        ],
+                        mutated_inputs=[
+                            logsumexp,
+                        ],
+                        num_stages=num_stages,
+                        num_warps=num_warps,
+                        call_sizes=query.get_size(),
+                        **cur_kernel_options,
+                    )
+                    if error is not None and len(configs) == 1:
+                        raise error
+        else:
+            error = flex_attention_template.maybe_append_choice(
+                choices=choices,
+                input_nodes=[
+                    query,
+                    key,
+                    value,
+                    logsumexp,
+                    kv_num_blocks,
+                    kv_indices,
+                    full_kv_num_blocks,
+                    full_kv_indices,
+                ],
+                layout=layout,
+                subgraphs=[
+                    subgraph_buffer,
+                    mask_graph_buffer,
+                ],
+                mutated_inputs=[
+                    logsumexp,
+                ],
+                num_stages=num_stages,
+                num_warps=num_warps,
+                call_sizes=query.get_size(),
+                **cur_kernel_options,
+            )
+            if error is not None and len(configs) == 1:
+                raise error
+
     inputs_for_autotuning = (
         [
             query,
@@ -2257,13 +2307,15 @@ def flex_attention_backward(*args, **kwargs):
     configs.extend(
         [
             (BLOCK1, BLOCK2, w, s)
-            for BLOCK1 in [32, 64]
-            for BLOCK2 in [32, 64, 128]
-            for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4])
+            for BLOCK1 in [16, 32, 64, 128, 256, 512]
+            for BLOCK2 in [16, 32, 64, 128, 256, 512]
+            for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4, 8])
             for s in num_stages_list
             if BLOCK2 % BLOCK1 == 0
         ]
     )
+
+    original_kernel_options = kernel_options.copy()
 
     for BLOCK1, BLOCK2, num_warps, num_stages in configs:
         if (
@@ -2273,9 +2325,6 @@ def flex_attention_backward(*args, **kwargs):
             or SPARSE_Q_BLOCK_SIZE % BLOCK2 != 0
         ):
             continue
-        if num_warps == 8:
-            # Working around https://github.com/pytorch/pytorch/issues/141603
-            continue
 
         # Performance tuning
         cur_kernel_options = original_kernel_options.copy()
@@ -2287,43 +2336,47 @@ def flex_attention_backward(*args, **kwargs):
         cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE)
         cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE)
 
-        flex_attention_backward_template.maybe_append_choice(
-            choices=choices,
-            input_nodes=[
-                query,
-                key,
-                value,
-                logsumexp,
-                delta,
-                grad_out,
-                grad_query,
-                broadcasted_grad_value,
-                kv_num_blocks,
-                kv_indices,
-                q_num_blocks,
-                q_indices,
-                full_kv_num_blocks,
-                full_kv_indices,
-                full_q_num_blocks,
-                full_q_indices,
-            ],
-            layout=layout_broadcasted_k,  # We use store_output only for grad_key
-            subgraphs=[
-                fw_subgraph_buffer,
-                joint_outputs.grad_input,
-                mask_graph_buffer,
-                joint_outputs.captured_grads_compute,
-            ],
-            mutated_inputs=[
-                grad_query,
-                broadcasted_grad_value,
-                *joint_outputs.mutated_grads,
-            ],
-            call_sizes=query.get_size() + key.get_size()[1:3],
-            num_stages=num_stages,
-            num_warps=num_warps,
-            **cur_kernel_options,
-        )
+        for wpeu in [0, 1, 2, 4, 8]:
+            for mfma in [0, 16]:
+                cur_kernel_options["waves_per_eu"] = wpeu
+                cur_kernel_options["matrix_instr_non_kdim"] = mfma
+                flex_attention_backward_template.maybe_append_choice(
+                    choices=choices,
+                    input_nodes=[
+                        query,
+                        key,
+                        value,
+                        logsumexp,
+                        delta,
+                        grad_out,
+                        grad_query,
+                        broadcasted_grad_value,
+                        kv_num_blocks,
+                        kv_indices,
+                        q_num_blocks,
+                        q_indices,
+                        full_kv_num_blocks,
+                        full_kv_indices,
+                        full_q_num_blocks,
+                        full_q_indices,
+                    ],
+                    layout=layout_broadcasted_k,  # We use store_output only for grad_key
+                    subgraphs=[
+                        fw_subgraph_buffer,
+                        joint_outputs.grad_input,
+                        mask_graph_buffer,
+                        joint_outputs.captured_grads_compute,
+                    ],
+                    mutated_inputs=[
+                        grad_query,
+                        broadcasted_grad_value,
+                        *joint_outputs.mutated_grads,
+                    ],
+                    call_sizes=query.get_size() + key.get_size()[1:3],
+                    num_stages=num_stages,
+                    num_warps=num_warps,
+                    **cur_kernel_options,
+                )
     inputs_for_autotuning = (
         [
             query,
diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py
index cb3b2d7836c1a..6e6e7faf4e3e2 100644
--- a/torch/_inductor/kernel/mm_common.py
+++ b/torch/_inductor/kernel/mm_common.py
@@ -4,6 +4,8 @@
 import logging
 from typing import Any, cast, Dict, Sequence, Tuple
 
+from torch.utils._ordered_set import OrderedSet
+
 import sympy
 
 import torch
@@ -75,7 +77,7 @@ def filtered_configs(
         ),
         min_block_size_k,
     )
-    used = set()
+    used = OrderedSet[tuple[int, ...]]()
     for block_m, block_n, block_k, num_stages, num_warps in configs:
         # shrink configs for small sizes
         block_m = max(min(int(block_m * scale), m), min_block_size)
@@ -88,6 +90,7 @@ def filtered_configs(
         # each warp computes 16x16 tile = 256
         num_warps = min(num_warps, block_m * block_n // 256)
         if torch.version.hip:
+            kpack = 2
             for matrix_instr_nonkdim in [0, 16]:
                 if matrix_instr_nonkdim != 0 and (
                     block_m % matrix_instr_nonkdim != 0
@@ -95,6 +98,7 @@ def filtered_configs(
                 ):
                     # block_m and block_n must be a multiple of matrix_instr_nonkdim
                     continue
+
                 if (
                     block_m,
                     block_n,
@@ -102,6 +106,7 @@ def filtered_configs(
                     num_stages,
                     num_warps,
                     matrix_instr_nonkdim,
+                    kpack,
                 ) not in used:
                     used.add(
                         (
@@ -111,6 +116,7 @@ def filtered_configs(
                             num_stages,
                             num_warps,
                             matrix_instr_nonkdim,
+                            kpack,
                         )
                     )
                     yield triton_config(
@@ -120,6 +126,7 @@ def filtered_configs(
                         num_stages=num_stages,
                         num_warps=num_warps,
                         matrix_instr_nonkdim=matrix_instr_nonkdim,
+                        kpack=kpack,
                     )
         else:
             if (block_m, block_n, block_k, num_stages, num_warps, 0) not in used:
diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py
index 62a2abcea8d2d..5fe978276637a 100644
--- a/torch/_inductor/runtime/coordinate_descent_tuner.py
+++ b/torch/_inductor/runtime/coordinate_descent_tuner.py
@@ -21,6 +21,8 @@ def get_field(config, name):
         return config.num_warps
     elif name == "num_stages":
         return config.num_stages
+    elif name == "waves_per_eu":
+        return config.kwargs.get(name, int(8 // config.num_warps))
     else:
         return config.kwargs.get(name, None)
 
@@ -97,6 +99,8 @@ def tunable_fields(self):
         ]
         if self.is_mm:
             out.append("num_stages")
+        if self.inductor_meta.get("is_hip") is True:
+            out.append("waves_per_eu")
 
         return out
 
@@ -105,6 +109,8 @@ def value_too_large(self, name: str, val: int) -> bool:
             return val > self.get_config_max(name[0].lower())
         if name == "num_warps":
             return val > self.get_warpsmax()
+        if name == "waves_per_eu":
+            return val > 8
 
         return False
diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py
index b3fde21699dba..82ed98921d133 100644
--- a/torch/_inductor/select_algorithm.py
+++ b/torch/_inductor/select_algorithm.py
@@ -354,9 +354,16 @@ def jit_lines(self):
         triton_meta["configs"] = [config_of(signature)]
         for arg_num in triton_meta["configs"][0].equal_to_1:  # type: ignore[index]
             triton_meta["constants"][signature[arg_num].name] = 1  # type: ignore[index]
-        matrix_instr_nonkdim = self.meta.get("matrix_instr_nonkdim", 0)
-        if matrix_instr_nonkdim != 0:
+        matrix_instr_nonkdim = self.meta.get("matrix_instr_nonkdim", None)
+        waves_per_eu = self.meta.get("waves_per_eu", None)
+        kpack = self.meta.get("kpack", None)
+        if matrix_instr_nonkdim:
             triton_meta["matrix_instr_nonkdim"] = matrix_instr_nonkdim
+        if waves_per_eu:
+            triton_meta["waves_per_eu"] = waves_per_eu
+        if kpack:
+            triton_meta["kpack"] = kpack
+
         self.triton_meta = triton_meta
 
@@ -920,6 +927,8 @@ def make_kernel_render(out_node):
             num_stages=num_stages,
             num_warps=num_warps,
             matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0),
+            waves_per_eu=kwargs.get("waves_per_eu", 0),
+            kpack=kwargs.get("kpack", 2),
             input_tensor_meta=TensorMeta.from_irnodes(full_input_nodes),  # type: ignore[arg-type]
             output_tensor_meta=TensorMeta.from_irnodes(layout),
             workspace_arg=workspace_arg,
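Editorial note, not part of the patch: the select_algorithm.py hunk forwards the ROCm-only knobs (matrix_instr_nonkdim, waves_per_eu, kpack) into triton_meta only when they are truthy, so a 0/None value leaves the compiler defaults untouched on non-ROCm builds. A minimal sketch of that filtering logic follows, using hypothetical dict stand-ins for self.meta and triton_meta; none of the names below are Inductor API.

    from typing import Any, Dict

    ROCM_KNOBS = ("matrix_instr_nonkdim", "waves_per_eu", "kpack")

    def forward_rocm_knobs(meta: Dict[str, Any], triton_meta: Dict[str, Any]) -> Dict[str, Any]:
        # Copy each ROCm tuning knob only when it is set; 0 or None means "keep the default".
        for knob in ROCM_KNOBS:
            value = meta.get(knob)
            if value:
                triton_meta[knob] = value
        return triton_meta

    # Example: only the knobs that were actually tuned reach the compile metadata.
    print(forward_rocm_knobs({"waves_per_eu": 2, "kpack": 0}, {}))  # -> {'waves_per_eu': 2}

The same pattern explains why the BenchmarkRequest defaults of 0 are harmless: a zero simply never reaches the Triton compile options.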