76 changes: 37 additions & 39 deletions examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
@@ -1,14 +1,31 @@
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, einsum
import argparse

tilelang.disable_cache()


def get_configs():
import itertools

Severity: medium

For better code organization and to follow PEP 8 guidelines, it's recommended to place all imports at the top of the file. Please move this import to the top-level of the module.
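
A minimal sketch of the suggested layout, for illustration only (the imports mirror the ones already shown in this diff; only the placement of itertools changes, and get_configs() below would then simply drop its local "import itertools" line):

import itertools
import argparse

import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, einsum

tilelang.disable_cache()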

BLOCK_N = [16, 32, 64, 128]
BLOCK_H = [16, 32, 64, 128]
num_split = [1, 2, 4, 8, 16, 32]
threads = [128, 256]

_configs = list(itertools.product(BLOCK_N, BLOCK_H, num_split, threads))

return [{
"block_N": c[0],
"block_H": c[1],
"num_split": c[2],
"threads": c[3],
} for c in _configs]
Comment on lines +20 to +25

Severity: medium

The creation of configuration dictionaries can be made more concise and maintainable. Using zip with explicit keys avoids relying on hardcoded indices like c[0], c[1], etc., which is less error-prone if the order of parameters in itertools.product changes.

    return [dict(zip(("block_N", "block_H", "num_split", "threads"), c)) for c in _configs]
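
For illustration, the whole helper with that suggestion applied might read as follows (a sketch only; the key names and candidate values are taken from the diff above, and the itertools import is kept local here just to keep the snippet self-contained):

def get_configs():
    import itertools  # could equally live at module level, per the earlier comment
    BLOCK_N = [16, 32, 64, 128]
    BLOCK_H = [16, 32, 64, 128]
    num_split = [1, 2, 4, 8, 16, 32]
    threads = [128, 256]
    keys = ("block_N", "block_H", "num_split", "threads")
    # zip pairs each key with the value at the same position in the product tuple,
    # e.g. (16, 16, 1, 128) -> {"block_N": 16, "block_H": 16, "num_split": 1, "threads": 128}
    return [dict(zip(keys, c)) for c in itertools.product(BLOCK_N, BLOCK_H, num_split, threads)]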



@tilelang.autotune(configs=get_configs())
@tilelang.jit(
out_idx=[6], pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
@@ -273,26 +290,39 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):

if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--batch', type=int, default=1, help='batch size')
+ parser.add_argument('--batch', type=int, default=128, help='batch size')
parser.add_argument('--heads', type=int, default=128, help='q heads number')
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
- parser.add_argument('--kv_ctx', type=int, default=1024, help='kv context length')
+ parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
parser.add_argument('--dim', type=int, default=512, help='head dim')
parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
- parser.add_argument('--auto_tune', action='store_true', help='auto tune')
+ parser.add_argument('--autotune', action='store_true', help='auto tune')
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
- enable_autotune = args.auto_tune
+ enable_autotune = args.autotune

qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
pv_flops = 2 * batch * heads * kv_ctx * dim
total_flops = qk_flops + pv_flops
BLOCK_N = 32
BLOCK_H = 64
num_split = 4
+ threads = 128

- kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H,
-                          num_split)
+ if enable_autotune:
+     kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
+ else:
+     kernel = flashmla_decode(
+         batch,
+         heads,
+         kv_heads,
+         kv_ctx,
+         dim,
+         pe_dim,
+         BLOCK_N,
+         BLOCK_H,
+         num_split,
+         threads=threads)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
input_tensors = profiler._get_inputs()
tilelang_output = kernel(*input_tensors)
@@ -303,35 +333,3 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
latency = profiler.do_bench(warmup=500)
print(f"Latency: {latency} ms")
print(f"TFlops: {total_flops / latency * 1e-9} TFlops")

- # Enable Auto Tuning
-
-
- def get_configs():
-     import itertools
-     BLOCK_N = [16, 32, 64, 128]
-     BLOCK_H = [16, 32, 64, 128]
-     num_split = [1, 2, 4, 8, 16, 32]
-     thread_num = [128, 256]
-
-     _configs = list(itertools.product(BLOCK_N, BLOCK_H, num_split, thread_num))
-
-     return [{
-         "block_N": c[0],
-         "block_H": c[1],
-         "num_split": c[2],
-         "thread_num": c[3],
-     } for c in _configs]
-
- def wrapped_kernel(block_N=None, block_H=None, num_split=None, thread_num=None):
-     return flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, block_N, block_H,
-                            num_split, thread_num)
-
- if enable_autotune:
-     autotuner = AutoTuner.from_kernel(kernel=wrapped_kernel, configs=get_configs())
-     tune_result = autotuner.run(warmup=3, rep=20)
-     best_latency = tune_result.latency
-     best_config = tune_result.config
-     print(f"Best latency: {best_latency} ms")
-     print(f"Best TFlops: {total_flops / best_latency * 1e-9} TFlops")
-     print(f"Best config: {best_config}")
17 changes: 13 additions & 4 deletions tilelang/autotuner/tuner.py
@@ -104,6 +104,7 @@ class AutoTuner:
profile_args = ProfileArgs()

_kernel_parameters: Optional[Tuple[str, ...]] = None
+ _function_parameters: Optional[Dict[str, Any]] = None
_lock = threading.Lock() # For thread safety
_memory_cache = {} # In-memory cache dictionary
cache_dir: Path = Path(env.TILELANG_CACHE_DIR) / "autotuner"
@@ -222,9 +223,10 @@ def set_profile_args(self,

return self

- def set_kernel_parameters(self, parameters: Tuple[str, ...]):
+ def set_kernel_parameters(self, k_parameters: Tuple[str, ...], f_parameters: Dict[str, Any]):
# for cache key generation
- self._kernel_parameters = parameters
+ self._kernel_parameters = k_parameters
+ self._function_parameters = f_parameters
Comment on lines +226 to +229

Severity: medium

The parameter names k_parameters and f_parameters are a bit cryptic. Using more descriptive names like kernel_params_key and function_parameters would improve readability and make the code easier to understand and maintain.

Suggested change:
- def set_kernel_parameters(self, k_parameters: Tuple[str, ...], f_parameters: Dict[str, Any]):
-     # for cache key generation
-     self._kernel_parameters = k_parameters
-     self._function_parameters = f_parameters
+ def set_kernel_parameters(self, kernel_params_key: Tuple[str, ...], function_parameters: Dict[str, Any]):
+     # for cache key generation
+     self._kernel_parameters = kernel_params_key
+     self._function_parameters = function_parameters
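
If this rename were adopted, the call site added later in this diff, autotuner.set_kernel_parameters(key, inspect.signature(fn).parameters), would keep working unchanged, since it passes both arguments positionally rather than by keyword.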


def generate_cache_key(self, parameters: Dict[str, Any]) -> Optional[AutotuneResult]:
"""Generate a cache key for the auto-tuning process.
@@ -417,8 +419,15 @@ def shape_equal(a, b):
key_args_tuple, key_kwargs_tuple = self._kernel_parameters
tunable_arguments = [key for key, _ in top_config.items()]

+ def check_tunable_argument_value(key, parameters, key_args_tuple) -> bool:
+     params_list = list(parameters.keys())
+     assert key in params_list, f"Tunable argument {key} not found in function parameters"
+     return params_list.index(key) < len(key_args_tuple)
+
# Check if all tunable arguments have been tuned by comparing config keys with key_kwargs_tuple
- if any(key in top_config for key, _ in key_kwargs_tuple):
+ if any(key in top_config for key, _ in key_kwargs_tuple) or any(
+         check_tunable_argument_value(key, self._function_parameters, key_args_tuple)
+         for key in tunable_arguments):
logger.warning(
f"Tunable parameters {tunable_arguments} already provided during auto-tuning. Skipping compilation and using direct JIT"
)
@@ -676,7 +685,7 @@ def jit_compile(**config_arg):
)

autotuner.jit_compile = jit_compile
- autotuner.set_kernel_parameters(key)
+ autotuner.set_kernel_parameters(key, inspect.signature(fn).parameters)

autotuner.run = partial(autotuner.run, warmup, rep, timeout)
