[AMD][MLA] Fix mla autotune for rocm #861
Changes from all commits
6932d2e
3e15964
ff24b5d
eee70de
6b78083
@@ -1,14 +1,31 @@
 import torch
 import torch.nn.functional as F
 import tilelang
 from tilelang.autotuner import *
 import tilelang.language as T
 from einops import rearrange, einsum
 import argparse

 tilelang.disable_cache()


+def get_configs():
+    import itertools
+    BLOCK_N = [16, 32, 64, 128]
+    BLOCK_H = [16, 32, 64, 128]
+    num_split = [1, 2, 4, 8, 16, 32]
+    threads = [128, 256]
+
+    _configs = list(itertools.product(BLOCK_N, BLOCK_H, num_split, threads))
+
+    return [{
+        "block_N": c[0],
+        "block_H": c[1],
+        "num_split": c[2],
+        "threads": c[3],
+    } for c in _configs]
Review comment on lines +20 to +25: The creation of the configuration dictionaries can be made more concise and maintainable, e.g. return [dict(zip(("block_N", "block_H", "num_split", "threads"), c)) for c in _configs].
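A short sketch of the reviewer's suggestion, written as a drop-in replacement for the dict-literal version above (the value lists are copied from the diff; nothing else is assumed):

```python
import itertools


def get_configs():
    # zip each (block_N, block_H, num_split, threads) combination with its
    # parameter name instead of indexing the tuple by position
    BLOCK_N = [16, 32, 64, 128]
    BLOCK_H = [16, 32, 64, 128]
    num_split = [1, 2, 4, 8, 16, 32]
    threads = [128, 256]
    names = ("block_N", "block_H", "num_split", "threads")
    return [
        dict(zip(names, c))
        for c in itertools.product(BLOCK_N, BLOCK_H, num_split, threads)
    ]
```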


+@tilelang.autotune(configs=get_configs())
 @tilelang.jit(
     out_idx=[6], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,

@@ -273,26 +290,39 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=1, help='batch size')
+    parser.add_argument('--batch', type=int, default=128, help='batch size')
     parser.add_argument('--heads', type=int, default=128, help='q heads number')
     parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
-    parser.add_argument('--kv_ctx', type=int, default=1024, help='kv context length')
+    parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
     parser.add_argument('--dim', type=int, default=512, help='head dim')
     parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
-    parser.add_argument('--auto_tune', action='store_true', help='auto tune')
+    parser.add_argument('--autotune', action='store_true', help='auto tune')
     args = parser.parse_args()
     batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
-    enable_autotune = args.auto_tune
+    enable_autotune = args.autotune

     qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
     pv_flops = 2 * batch * heads * kv_ctx * dim
     total_flops = qk_flops + pv_flops
     BLOCK_N = 32
     BLOCK_H = 64
     num_split = 4
+    threads = 128

-    kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H,
-                             num_split)
+    if enable_autotune:
+        kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
+    else:
+        kernel = flashmla_decode(
+            batch,
+            heads,
+            kv_heads,
+            kv_ctx,
+            dim,
+            pe_dim,
+            BLOCK_N,
+            BLOCK_H,
+            num_split,
+            threads=threads)
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
     input_tensors = profiler._get_inputs()
     tilelang_output = kernel(*input_tensors)

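For reference, plugging the new default shapes into the FLOP formulas above gives the numbers below; this is only a sanity check of the reported TFlops figure, not part of the PR:

```python
# FLOP count for the new defaults: batch=128, heads=128, kv_ctx=8192, dim=512, pe_dim=64
batch, heads, kv_ctx, dim, pe_dim = 128, 128, 8192, 512, 64
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)  # 154,618,822,656
pv_flops = 2 * batch * heads * kv_ctx * dim             # 137,438,953,472
total_flops = qk_flops + pv_flops                       # 292,057,776,128 ~ 0.29 TFLOP per call

# With latency in milliseconds, total_flops / latency * 1e-9 reports TFLOP/s,
# so a 1.0 ms decode corresponds to roughly 292 TFLOP/s.
print(total_flops / 1.0 * 1e-9)
```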
@@ -303,35 +333,3 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
     latency = profiler.do_bench(warmup=500)
     print(f"Latency: {latency} ms")
     print(f"TFlops: {total_flops / latency * 1e-9} TFlops")
-
-    # Enable Auto Tuning
-
-    def get_configs():
-        import itertools
-        BLOCK_N = [16, 32, 64, 128]
-        BLOCK_H = [16, 32, 64, 128]
-        num_split = [1, 2, 4, 8, 16, 32]
-        thread_num = [128, 256]
-
-        _configs = list(itertools.product(BLOCK_N, BLOCK_H, num_split, thread_num))
-
-        return [{
-            "block_N": c[0],
-            "block_H": c[1],
-            "num_split": c[2],
-            "thread_num": c[3],
-        } for c in _configs]
-
-    def wrapped_kernel(block_N=None, block_H=None, num_split=None, thread_num=None):
-        return flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, block_N, block_H,
-                               num_split, thread_num)
-
-    if enable_autotune:
-        autotuner = AutoTuner.from_kernel(kernel=wrapped_kernel, configs=get_configs())
-        tune_result = autotuner.run(warmup=3, rep=20)
-        best_latency = tune_result.latency
-        best_config = tune_result.config
-        print(f"Best latency: {best_latency} ms")
-        print(f"Best TFlops: {total_flops / best_latency * 1e-9} TFlops")
-        print(f"Best config: {best_config}")
@@ -104,6 +104,7 @@ class AutoTuner:
     profile_args = ProfileArgs()

     _kernel_parameters: Optional[Tuple[str, ...]] = None
+    _function_parameters: Optional[Dict[str, Any]] = None
     _lock = threading.Lock()  # For thread safety
     _memory_cache = {}  # In-memory cache dictionary
     cache_dir: Path = Path(env.TILELANG_CACHE_DIR) / "autotuner"

@@ -222,9 +223,10 @@ def set_profile_args(self,

         return self

-    def set_kernel_parameters(self, parameters: Tuple[str, ...]):
+    def set_kernel_parameters(self, k_parameters: Tuple[str, ...], f_parameters: Dict[str, Any]):
         # for cache key generation
-        self._kernel_parameters = parameters
+        self._kernel_parameters = k_parameters
+        self._function_parameters = f_parameters
Review comment on lines +226 to +229: The parameter names … (suggested change)

     def generate_cache_key(self, parameters: Dict[str, Any]) -> Optional[AutotuneResult]:
         """Generate a cache key for the auto-tuning process.

@@ -417,8 +419,15 @@ def shape_equal(a, b):
         key_args_tuple, key_kwargs_tuple = self._kernel_parameters
         tunable_arguments = [key for key, _ in top_config.items()]

+        def check_tunable_argument_value(key, parameters, key_args_tuple) -> bool:
+            params_list = list(parameters.keys())
+            assert key in params_list, f"Tunable argument {key} not found in function parameters"
+            return params_list.index(key) < len(key_args_tuple)
+
         # Check if all tunable arguments have been tuned by comparing config keys with key_kwargs_tuple
-        if any(key in top_config for key, _ in key_kwargs_tuple):
+        if any(key in top_config for key, _ in key_kwargs_tuple) or any(
+                check_tunable_argument_value(key, self._function_parameters, key_args_tuple)
+                for key in tunable_arguments):
             logger.warning(
                 f"Tunable parameters {tunable_arguments} already provided during auto-tuning. Skipping compilation and using direct JIT"
             )

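This helper is the heart of the fix: previously only keyword arguments (key_kwargs_tuple) were checked, so tunable parameters passed positionally went unnoticed and autotuning was triggered anyway. The standalone sketch below mimics the same check with plain inspect; the flashmla_decode signature is copied from the example, while the helper name and driver lines are illustrative only:

```python
import inspect


def tunable_arg_supplied_positionally(fn, tunable_name, positional_args):
    # Same idea as check_tunable_argument_value: a tunable argument counts as
    # "already provided" when its index in the function signature falls within
    # the positional arguments the caller actually passed.
    params = list(inspect.signature(fn).parameters.keys())
    assert tunable_name in params, f"Tunable argument {tunable_name} not found in function parameters"
    return params.index(tunable_name) < len(positional_args)


def flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim,
                    block_N=None, block_H=None, num_split=None, threads=None):
    ...  # kernel construction elided


# block_N arrives positionally as the 7th argument, so tuning should be skipped:
print(tunable_arg_supplied_positionally(
    flashmla_decode, "block_N", (128, 128, 1, 8192, 512, 64, 32)))  # True
# Only the six shape arguments are passed, so block_N is still up for tuning:
print(tunable_arg_supplied_positionally(
    flashmla_decode, "block_N", (128, 128, 1, 8192, 512, 64)))      # False
```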
@@ -676,7 +685,7 @@ def jit_compile(**config_arg):
         )

         autotuner.jit_compile = jit_compile
-        autotuner.set_kernel_parameters(key)
+        autotuner.set_kernel_parameters(key, inspect.signature(fn).parameters)

         autotuner.run = partial(autotuner.run, warmup, rep, timeout)

Review comment: For better code organization and to follow PEP 8 guidelines, it's recommended to place all imports at the top of the file. Please move this import to the top-level of the module.
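A minimal sketch of what the reviewer is asking for, assuming the inspect import currently sits inside the decorator body (the diff does not show its present location):

```python
# Top of the autotuner module, alongside the other imports it already uses:
import inspect
import threading
from functools import partial

# ...

# Inside the autotune decorator, no local "import inspect" is then needed:
# autotuner.set_kernel_parameters(key, inspect.signature(fn).parameters)
```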