diff --git a/examples/blocksparse_gemm/example_blocksparse_gemm.py b/examples/blocksparse_gemm/example_blocksparse_gemm.py index b7785f61e..c9d7804c8 100644 --- a/examples/blocksparse_gemm/example_blocksparse_gemm.py +++ b/examples/blocksparse_gemm/example_blocksparse_gemm.py @@ -5,7 +5,6 @@ import itertools import tilelang import tilelang.language as T -from tilelang.autotuner import AutoTuner from tilelang.engine.param import KernelParam from tilelang.utils.tensor import get_tensor_supply, TensorSupplyType import torch @@ -37,7 +36,7 @@ print(f"Using Autotuner: {use_autotune}\n") -def get_configs(M, N, K): +def get_configs(): block_M = [64, 128, 256] block_N = [64, 128, 256] block_K = [32, 64] @@ -93,55 +92,7 @@ def supply_program(params: List[KernelParam]): return input_tensors -def get_best_config(M, N, K): - - # Define the kernel function to be tuned. - # Parameters like block_M, block_N, etc., are tuned by the AutoTuner. - def kernel(block_M=None, - block_N=None, - block_K=None, - num_stages=None, - thread_num=None, - enable_rasteration=None): - return blocksparse_matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num, - enable_rasteration) - - autotuner = AutoTuner.from_kernel( - kernel=kernel, configs=get_configs(M, N, K) - ).set_compile_args( - out_idx=[-1], # Index of the output tensor - target="auto", # Automatically detect target - ).set_profile_args( - # supply_type should not set here because we provide a custom supply - # function `supply_prog` and `supply_type` will be ignored. - - # supply_prog: Provide the custom function to generate input tensors - # (A, B, BlockMask) for the kernel, allowing controlling sparsity via - # BlockMask generation. - supply_prog=supply_program, - - # ref_prog: Using dense matmul (A @ B) as a placeholder reference. - # The 'correct' block-sparse reference (`ref_program` above) requires - # block_M, block_N, block_K parameters. However, these parameters are - # part of the configuration being *tuned* by the AutoTuner and cannot - # be fixed inputs to a static `ref_prog` function signature. - # This dense matmul serves only as a performance baseline. - ref_prog=lambda A, B, BlockMask: A @ B, - - # skip_check: Set to True because the provided `ref_prog` does not - # compute the correct result for the block-sparse kernel. - skip_check=True, - - # cache_input_tensors: Set to False because the shape of the BlockMask tensor - # (dependent on block_M, block_N, block_K being tuned) changes between - # different configurations. Reusing cached tensors from a previous - # configuration would lead to shape mismatches. - cache_input_tensors=False, - ) - # Run the tuning process - return autotuner.run(warmup=3, rep=20) - - +@tilelang.autotune(configs=get_configs(),) @tilelang.jit(out_idx=[-1]) def blocksparse_matmul(M, N, @@ -195,22 +146,16 @@ def main(): # Run the autotuner to find the best kernel configuration and performance # get_best_config is expected to return an object containing the compiled kernel, # the best configuration found, latency, and reference latency. 
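+        # Illustrative note (assumes the @tilelang.autotune / @tilelang.jit stack defined
+        # above on blocksparse_matmul): calling the decorated factory with only the
+        # problem sizes, e.g.
+        #   kernel = blocksparse_matmul(M, N, K)
+        # runs the configuration search and returns the best compiled kernel, with the
+        # winning configuration and latency exposed as kernel.config and kernel.latency.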
- result = get_best_config(M, N, K) + kernel = blocksparse_matmul(M, N, K) - # Extract results from the autotuner run - kernel = result.kernel - best_config = result.config - block_M = best_config[0] - block_N = best_config[1] - block_K = best_config[2] - best_latency = result.latency - ref_latency = result.ref_latency + best_config = kernel.config + best_latency = kernel.latency + block_M, block_N, block_K = best_config["block_M"], best_config["block_N"], best_config[ + "block_K"] print(f"Best Config: {best_config}") - print(f"Block Dimensions (BM, BN, BK): ({block_M}, {block_N}, {block_K})") print(f"Sparsity Ratio: {sparsity}") print(f"Best Kernel Latency: {best_latency:.6f} ms") - print(f"Reference Latency: {ref_latency:.6f} ms") else: kernel = blocksparse_matmul(M, N, K, DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K, DEFAULT_NUM_STAGES, DEFAULT_THREAD_NUM, diff --git a/examples/convolution/example_convolution_autotune.py b/examples/convolution/example_convolution_autotune.py index 53ab8bd7b..923eb7112 100644 --- a/examples/convolution/example_convolution_autotune.py +++ b/examples/convolution/example_convolution_autotune.py @@ -3,8 +3,7 @@ import torch import argparse import itertools -import tilelang as tl -from tilelang.autotuner import * +import tilelang import tilelang.language as T from tilelang.autotuner import AutoTuner from tilelang.carver.template import ConvTemplate @@ -167,7 +166,7 @@ def main( out_idx=[2], target="auto", ).set_profile_args( - supply_type=tl.TensorSupplyType.Integer, + supply_type=tilelang.TensorSupplyType.Integer, ref_prog=ref_prog, skip_check=False, ) @@ -301,9 +300,9 @@ def main(n: int = 128, kernel = result.kernel else: config = get_heuristic_config() - kernel = tl.compile(convolution(N, C, H, W, F, K, S, D, P, **config), out_idx=[2]) + kernel = tilelang.compile(convolution(N, C, H, W, F, K, S, D, P, **config), out_idx=[2]) - profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) tilelang_latency = profiler.do_bench() ref_latency = profiler.do_bench(ref_prog) profiler.assert_allclose(ref_prog, atol=1e-2, rtol=1e-2) diff --git a/testing/python/autotune/test_tilelang_autotune_decorator.py b/testing/python/autotune/test_tilelang_autotune_decorator.py deleted file mode 100644 index 088bd5488..000000000 --- a/testing/python/autotune/test_tilelang_autotune_decorator.py +++ /dev/null @@ -1,265 +0,0 @@ -# Copyright (c) Tile-AI Corporation. -# Licensed under the MIT License. - -import itertools -import logging - -import tilelang.testing -import tilelang.language as T -from tilelang.autotuner import autotune - -# Configure logger -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - - -def ref_program(A, B): - """ - A reference matrix multiplication program, used to compare performance. - - Parameters - ---------- - A : numpy.ndarray - The matrix with shape (M, K). - B : numpy.ndarray - The matrix with shape (N, K). - - Returns - ------- - np.ndarray - The result of A @ B.T, shape (M, N). - """ - return A @ B.T - - -def get_configs(M, N, K, with_roller=False): - """ - Generate a list of configuration dictionaries that will be used for tuning. - - Parameters - ---------- - with_roller : bool - Whether to enable bitblas roller to deduce search spaces - - Returns - ------- - list of dict - Each configuration dict includes various block sizes, pipeline stages, - thread numbers, and other parameters to explore during autotuning. 
- """ - if with_roller: - from tilelang.carver.template import MatmulTemplate - from tilelang.carver.arch import CUDA - from tilelang.carver.roller.rasterization import NoRasterization - arch = CUDA("cuda") - topk = 20 - - # Simple TIR Compute Expression - carve_template = MatmulTemplate( - M=M, - N=N, - K=K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - ).with_arch(arch) - - func = carve_template.equivalent_function() - assert func is not None, "Function is None" - - roller_hints = carve_template.recommend_hints(topk=topk) - - if roller_hints is None: - raise ValueError("No Roller Hints Found for TensorCore Scheduling") - - configs = [] - for hint in roller_hints: - config = {} - block_m, block_n = hint.block - warp_m, warp_n = hint.warp - config["block_M"] = block_m - config["block_N"] = block_n - config["block_K"] = hint.rstep[0] - config["num_stages"] = 0 - config["thread_num"] = (block_m * block_n) // (warp_m * warp_n) * 32 - config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization - configs.append(config) - for config in configs: - print(config) - else: - - block_M = [64] - block_N = [64] - block_K = [32] - num_stages = [0, 1] - thread_num = [128] - enable_rasterization = [False] - - _configs = list( - itertools.product( - block_M, - block_N, - block_K, - num_stages, - thread_num, - enable_rasterization, - )) - - configs = [ - { - "block_M": c[0], - "block_N": c[1], - "block_K": c[2], - "num_stages": c[3], - "thread_num": c[4], - "enable_rasteration": c[5], # keep param name for backward-compat - } for c in _configs - ] - return configs - - -def matmul(M, N, K, with_roller): - """ - Create an autotuned matrix multiplication kernel for matrices of shape: - - A: (M, K) - - B: (N, K) - - C: (M, N) - - Parameters - ---------- - M : int - The dimension M of the matrix multiplication. - N : int - The dimension N of the matrix multiplication. - K : int - The dimension K of the matrix multiplication. - - Returns - ------- - (best_latency, best_config, ref_latency) - best_latency : float - The best latency found among the tuned configurations. - best_config : dict - The parameter configuration that yielded best_latency. - ref_latency : float - The baseline latency of the reference program (for computing speedup). - """ - - @autotune( - configs=get_configs(M, N, K, with_roller), - warmup=3, - rep=20, - ) - @tilelang.jit(out_idx=[-1],) - def kernel( - block_M=None, - block_N=None, - block_K=None, - num_stages=None, - thread_num=None, - enable_rasteration=None, - ): - """ - The actual kernel to compute C = A @ B^T. - - Parameters - ---------- - block_M : int - Block size in M dimension. - block_N : int - Block size in N dimension. - block_K : int - Block size in K dimension. - num_stages : int - Number of pipelined stages (for asynchronous load). - thread_num : int - Number of threads to use per block. - enable_rasteration : bool - Whether to enable rasterization (swizzling) optimization. - k_pack : int - K dimension packing factor to improve memory coalescing. - - Returns - ------- - Function - A TVM Tensor Language function (T.prim_func) that computes matmul. - """ - # Use half-precision for input data to reduce memory bandwidth, - # accumulate in float for better numerical accuracy - dtype = "float16" - accum_dtype = "float" - - @T.prim_func - def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), - ): - """ - The compiled TVM function for block-level matrix multiplication. 
- - - We divide the entire (M, N) domain into blocks of shape - (block_M, block_N). - - Each block has its own allocated shared memory for sub-blocks - of A and B. - - The partial results go into C_local, and then we copy them back - to global memory C. - """ - # Bind x-dimension to block index in N, - # y-dimension to block index in M. - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - - # Allocate shared memory for A sub-block of shape (block_M, block_K) - A_shared = T.alloc_shared((block_M, block_K), dtype) - # Allocate shared memory for B sub-block of shape (block_N, block_K) - B_shared = T.alloc_shared((block_N, block_K), dtype) - # Allocate a local fragment for intermediate accumulation - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - - # Enable (or disable) swizzling optimization - T.use_swizzle(panel_size=10, enable=enable_rasteration) - - # Clear out the accumulation buffer - T.clear(C_local) - - # Loop over sub-blocks in K dimension, pipelined by num_stages - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - # Load a sub-block of A from global memory into A_shared - T.copy( - A[by * block_M, k * block_K], - A_shared, - ) - # Load a sub-block of B from global memory into B_shared - T.copy( - B[bx * block_N, k * block_K], - B_shared, - ) - # Perform a partial matrix multiplication: - # C_local += A_shared @ B_shared^T - T.gemm( - A_shared, - B_shared, - C_local, - transpose_B=True, - ) - # Write back the results from C_local to the global memory C - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - - return kernel() - - -def test_autotune_get_configs(): - get_configs(8192, 8192, 8192, with_roller=True) - get_configs(8192, 8192, 8192, with_roller=False) - - -def test_autotune_matmul(): - matmul(8192, 8192, 8192, with_roller=True) - matmul(8192, 8192, 8192, with_roller=False) - - -if __name__ == "__main__": - tilelang.testing.main() diff --git a/testing/python/autotune/test_tilelang_autotune_with_inputs.py b/testing/python/autotune/test_tilelang_autotune_with_inputs.py new file mode 100644 index 000000000..8f76560de --- /dev/null +++ b/testing/python/autotune/test_tilelang_autotune_with_inputs.py @@ -0,0 +1,140 @@ +# Copyright (c) Tile-AI Corporation. +# Licensed under the MIT License. + +import itertools +import logging +import tilelang +import tilelang.testing +from tilelang.autotuner import set_autotune_inputs +import tilelang.language as T + +# Configure logger +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +def ref_program(A, B): + """ + A reference matrix multiplication program, used to compare performance. + + Parameters + ---------- + A : numpy.ndarray + The matrix with shape (M, K). + B : numpy.ndarray + The matrix with shape (N, K). + + Returns + ------- + np.ndarray + The result of A @ B.T, shape (M, N). 
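+
+    Examples
+    --------
+    Illustrative shape check only (small hypothetical sizes, not used by the test):
+
+    >>> import torch
+    >>> A = torch.randn(4, 8)
+    >>> B = torch.randn(6, 8)
+    >>> ref_program(A, B).shape
+    torch.Size([4, 6])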
+ """ + return A @ B.T + + +def get_configs(): + iter_params = dict( + block_M=[64], + block_N=[64], + block_K=[32], + num_stages=[0, 1], + thread_num=[128], + enable_rasterization=[False]) + return [{ + k: v for k, v in zip(iter_params, values) + } for values in itertools.product(*iter_params.values())] + + +@tilelang.jit(out_idx=[-1]) +def matmul(M, + N, + K, + block_M=128, + block_N=128, + block_K=32, + num_stages=0, + thread_num=128, + enable_rasteration=False): + + dtype = "float16" + accum_dtype = "float" + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + """ + The compiled TVM function for block-level matrix multiplication. + + - We divide the entire (M, N) domain into blocks of shape + (block_M, block_N). + - Each block has its own allocated shared memory for sub-blocks + of A and B. + - The partial results go into C_local, and then we copy them back + to global memory C. + """ + # Bind x-dimension to block index in N, + # y-dimension to block index in M. + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + + # Allocate shared memory for A sub-block of shape (block_M, block_K) + A_shared = T.alloc_shared((block_M, block_K), dtype) + # Allocate shared memory for B sub-block of shape (block_N, block_K) + B_shared = T.alloc_shared((block_N, block_K), dtype) + # Allocate a local fragment for intermediate accumulation + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Enable (or disable) swizzling optimization + T.use_swizzle(panel_size=10, enable=enable_rasteration) + + # Clear out the accumulation buffer + T.clear(C_local) + + # Loop over sub-blocks in K dimension, pipelined by num_stages + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + # Load a sub-block of A from global memory into A_shared + T.copy( + A[by * block_M, k * block_K], + A_shared, + ) + # Load a sub-block of B from global memory into B_shared + T.copy( + B[bx * block_N, k * block_K], + B_shared, + ) + # Perform a partial matrix multiplication: + # C_local += A_shared @ B_shared^T + T.gemm( + A_shared, + B_shared, + C_local, + transpose_B=True, + ) + # Write back the results from C_local to the global memory C + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_autotune(M: int, N: int, K: int): + import torch + a = torch.randn(M, K, dtype=torch.float16).cuda() + b = torch.randn(N, K, dtype=torch.float16).cuda() + + with set_autotune_inputs([a, b]): + kernel = matmul(M, N, K) + + c = kernel(a, b) + + ref_c = ref_program(a, b) + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + + +def test_autotune_matmul(): + run_autotune(8192, 8192, 8192) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/autotuner/__init__.py b/tilelang/autotuner/__init__.py index 44cd7df6f..6b8307f3f 100644 --- a/tilelang/autotuner/__init__.py +++ b/tilelang/autotuner/__init__.py @@ -1,782 +1,11 @@ # Copyright (c) Tile-AI Corporation. # Licensed under the MIT License. -"""The auto-tune module for tilelang programs. -This module provides functionality for auto-tuning tilelang programs, including JIT compilation -and performance optimization through configuration search. 
-""" - -import tilelang -from tilelang import tvm as tvm -from tvm.tir import PrimFunc, Var -from tvm.target import Target -import inspect -from functools import partial -from typing import (Callable, List, Literal, Any, Optional, Union, Dict, overload, Tuple) -from tqdm import tqdm -import logging -import functools -import concurrent.futures -import torch -import os -import sys -import signal -import json -import hashlib -import threading -import traceback -from pathlib import Path - -from tilelang.env import ( - TILELANG_CACHE_DIR, - TILELANG_AUTO_TUNING_CPU_UTILITIES, - TILELANG_AUTO_TUNING_CPU_COUNTS, - TILELANG_AUTO_TUNING_MAX_CPU_COUNT, - is_cache_enabled, +from .tuner import ( + autotune, # noqa: F401 + AutoTuner, # noqa: F401 +) +from .capture import ( + set_autotune_inputs, # noqa: F401 + get_autotune_inputs, # noqa: F401 ) -from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult -from tilelang.jit.param import _P, _RProg -from tilelang.version import __version__ - - -class TimeoutException(Exception): - pass - - -def timeout_handler(signum, frame): - raise TimeoutException("Operation timed out") - - -def run_with_timeout(func, timeout, *args, **kwargs): - signal.signal(signal.SIGALRM, timeout_handler) - signal.alarm(timeout) - try: - result = func(*args, **kwargs) - except Exception as e: - raise e - finally: - signal.alarm(0) - return result - - -# Configure logging for the autotuner module -# TODO: Consider creating a common logger in utils -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) -logger.propagate = False - -# Lazy handler initialization flag -_logger_handlers_initialized = False - - -def _init_logger_handlers(): - global _logger_handlers_initialized - if _logger_handlers_initialized: - return - formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s') - file_handler = logging.FileHandler('autotuner.log', mode='w') - file_handler.setLevel(logging.DEBUG) - file_handler.setFormatter(formatter) - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setLevel(logging.INFO) - console_handler.setFormatter(formatter) - logger.addHandler(file_handler) - logger.addHandler(console_handler) - _logger_handlers_initialized = True - - -def get_available_cpu_count() -> int: - """Gets the number of CPU cores available to the current process. - """ - try: - cpu_count = len(os.sched_getaffinity(0)) - except AttributeError: - cpu_count = os.cpu_count() - - return cpu_count - - -class AutoTuner: - """Auto-tuner for tilelang programs. - - This class handles the auto-tuning process by testing different configurations - and finding the optimal parameters for program execution. - - Args: - fn: The function to be auto-tuned. - configs: List of configurations to try during auto-tuning. - """ - compile_args = CompileArgs() - profile_args = ProfileArgs() - - _kernel_parameters: Optional[Tuple[str, ...]] = None - _lock = threading.Lock() # For thread safety - _memory_cache = {} # In-memory cache dictionary - cache_dir: Path = Path(TILELANG_CACHE_DIR) / "autotuner" - - def __init__(self, fn: Callable, configs): - self.fn = fn - self.configs = configs - self.ref_latency_cache = None - self.jit_input_tensors = None - self.ref_input_tensors = None - self.jit_compile = None - - @classmethod - def from_kernel(cls, kernel: Callable, configs): - """Create an AutoTuner instance from a kernel function. - - Args: - kernel: The kernel function to auto-tune. - configs: List of configurations to try. 
- - Returns: - AutoTuner: A new AutoTuner instance. - """ - return cls(kernel, configs) - - def set_compile_args(self, - out_idx: Union[List[int], int, None] = None, - target: Literal['auto', 'cuda', 'hip'] = 'auto', - execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython", - target_host: Union[str, Target] = None, - verbose: bool = False, - pass_configs: Optional[Dict[str, Any]] = None): - """Set compilation arguments for the auto-tuner. - - Args: - out_idx: List of output tensor indices. - target: Target platform. - execution_backend: Execution backend to use for kernel execution. - target_host: Target host for cross-compilation. - verbose: Whether to enable verbose output. - pass_configs: Additional keyword arguments to pass to the Compiler PassContext. - - Returns: - AutoTuner: Self for method chaining. - """ - self.compile_args = CompileArgs( - out_idx=out_idx, - target=target, - execution_backend=execution_backend, - target_host=target_host, - verbose=verbose, - pass_configs=pass_configs) - - return self - - def set_profile_args(self, - warmup: int = 25, - rep: int = 100, - timeout: int = 30, - supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto, - ref_prog: Callable = None, - supply_prog: Callable = None, - rtol: float = 1e-2, - atol: float = 1e-2, - max_mismatched_ratio: float = 0.01, - skip_check: bool = False, - manual_check_prog: Callable = None, - cache_input_tensors: bool = False): - """Set profiling arguments for the auto-tuner. - - Args: - supply_type: Type of tensor supply mechanism. Ignored if `supply_prog` is provided. - ref_prog: Reference program for validation. - supply_prog: Supply program for input tensors. - rtol: Relative tolerance for validation. - atol: Absolute tolerance for validation. - max_mismatched_ratio: Maximum allowed mismatch ratio. - skip_check: Whether to skip validation. - manual_check_prog: Manual check program for validation. - cache_input_tensors: Whether to cache input tensors. - warmup: Number of warmup iterations. - rep: Number of repetitions for timing. - timeout: Maximum time per configuration. - - Returns: - AutoTuner: Self for method chaining. - """ - self.profile_args = ProfileArgs( - supply_type=supply_type, - ref_prog=ref_prog, - supply_prog=supply_prog, - rtol=rtol, - atol=atol, - max_mismatched_ratio=max_mismatched_ratio, - skip_check=skip_check, - manual_check_prog=manual_check_prog, - cache_input_tensors=cache_input_tensors, - warmup=warmup, - rep=rep, - timeout=timeout) - - # If a custom `supply_prog` is provided, the profiler's `supply_type` setting - # becomes ineffective. The custom supply program will be used instead. - if supply_prog is not None and supply_type != tilelang.TensorSupplyType.Auto: - logger.warning("Ignoring `supply_type` passed to `set_profile_args` because " - "`supply_prog` is not None.") - - return self - - def set_kernel_parameters(self, parameters: Tuple[str, ...]): - # for cache key generation - self._kernel_parameters = parameters - - def generate_cache_key(self, parameters: Dict[str, Any]) -> Optional[AutotuneResult]: - """Generate a cache key for the auto-tuning process. 
- """ - # extract parameters from the function signature - op_parameters = [] - for _, default_value in parameters.items(): - if default_value.default is not inspect.Parameter.empty: - op_parameters.append(default_value.default) - - if self._kernel_parameters is not None: - op_parameters += self._kernel_parameters - - func_source = inspect.getsource(self.fn) - key_data = { - "version": __version__, - "op_parameters": tuple(op_parameters), - "func_source": func_source, - "configs": self.configs, - "compile_args": hash(self.compile_args), - "profile_args": hash(self.profile_args), - } - # Sort keys to ensure consistency - key_string = json.dumps(key_data, sort_keys=True) - return hashlib.sha256(key_string.encode()).hexdigest() - - def _save_result_to_disk(self, key, result: AutotuneResult): - result.save_to_disk(self.cache_dir / key) - - def _load_result_from_disk(self, key) -> AutotuneResult: - result = AutotuneResult.load_from_disk(self.cache_dir / key, self.compile_args) - return result - - def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30): - """Run the auto-tuning process. - - Args: - warmup: Number of warmup iterations. - rep: Number of repetitions for timing. - timeout: Maximum time per configuration. - - Returns: - AutotuneResult: Results of the auto-tuning process. - """ - _init_logger_handlers() - - sig = inspect.signature(self.fn) - parameters = sig.parameters - - if isinstance(self.configs, Callable): - self.configs = self.configs(*self._kernel_parameters) - - key = self.generate_cache_key(parameters) - - with self._lock: - if is_cache_enabled(): - # First check in-memory cache - if key in self._memory_cache: - logger.warning("Found kernel in memory cache. For better performance," \ - " consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel.") - return self._memory_cache[key] - - # Then check disk cache - result = self._load_result_from_disk(key) - if result is not None: - # Populate memory cache with disk result - self._memory_cache[key] = result - return result - - best_latency: float = 1e8 - best_config: Optional[Dict[str, Any]] = None - best_kernel: Optional[tilelang.JITKernel] = None - - def _compile(**config_arg) -> tilelang.JITKernel: - compile_args = self.compile_args - return compile_args.compile_program(self.fn(**config_arg)) - - if self.jit_compile is None: - self.jit_compile = _compile - - def target_fn(jit_kernel: tilelang.JITKernel): - # Unpack the context - profile_args = self.profile_args - supply_type = profile_args.supply_type - skip_check = profile_args.skip_check - manual_check_prog = profile_args.manual_check_prog - cache_input_tensors = profile_args.cache_input_tensors - ref_prog = profile_args.ref_prog - supply_prog = profile_args.supply_prog - rtol = profile_args.rtol - atol = profile_args.atol - max_mismatched_ratio = profile_args.max_mismatched_ratio - - profiler = jit_kernel.get_profiler(tensor_supply_type=supply_type) - - # Factory functions for generating input tensors. - # This encapsulates the logic of using either a custom supply program (`supply_prog`) - # or the default profiler input generation (`profiler._get_inputs`). 
- def get_input_tensors_supply(with_output: bool): - - def func(): - if supply_prog is not None: - return supply_prog(profiler._get_params(with_output=with_output)) - else: - return profiler._get_inputs(with_output=with_output) - - return func - - jit_input_tensors_supply = get_input_tensors_supply(with_output=False) - ref_input_tensors_supply = get_input_tensors_supply(with_output=False) - - if cache_input_tensors: - params = profiler._get_params(with_output=False) - if self.jit_input_tensors is None: - self.jit_input_tensors = jit_input_tensors_supply() - else: - # check if the cached tensors are compatible with the current configuration - assert len(params) == len( - self.jit_input_tensors), "len(params) != len(self.jit_input_tensors)" - for p, c in zip(params, self.jit_input_tensors): - if not isinstance(c, torch.Tensor): - # skip non-tensor inputs checking - continue - - # Check tensor compatibility using generator expression - def shape_equal(a, b): - return all( - a_dim == b_dim or isinstance(a_dim, Var) or isinstance(b_dim, Var) - for a_dim, b_dim in zip(a.shape, b.shape)) - - if p.dtype != c.dtype or not shape_equal(p, c): - logger.warning( - "\nIncompatible input tensor properties detected between cached tensors and " - "tensors regenerated for the current configuration trial. " - "This can happen if different tuning configurations require different input shapes/dtypes " - "and input tensor caching is enabled.\n" - "To ensure fresh, compatible inputs are generated for every trial " - "you can disable caching by setting:\n" - " `cache_input_tensors=False`\n" - "within your `.set_compile_args(...)` call.\n") - # otherwise, regenerate the input tensors for safety - self.jit_input_tensors = jit_input_tensors_supply() - break - else: - self.jit_input_tensors = jit_input_tensors_supply() - - if (not skip_check) and (ref_prog is not None): - if manual_check_prog is not None: - profiler.manual_assert_close( - ref_prog, - input_tensors=self.jit_input_tensors, - manual_check_prog=manual_check_prog) - else: - profiler.assert_allclose( - ref_prog, - input_tensors=self.jit_input_tensors, - rtol=rtol, - atol=atol, - max_mismatched_ratio=max_mismatched_ratio) - latency = profiler.do_bench( - warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors) - - if self.ref_latency_cache is None and ref_prog is not None: - self.ref_input_tensors = ref_input_tensors_supply() - self.ref_latency_cache = profiler.do_bench( - ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors) - - return latency, self.ref_latency_cache - - config_args = [] - for config in self.configs: - new_kwargs = {} - keys = config.keys() - for name, _ in parameters.items(): - if name in config: - new_kwargs[name] = config[name] - unused_keys = set(keys) - set(new_kwargs.keys()) - if len(unused_keys) > 0: - raise ValueError(f"Unused keys in config: {unused_keys}") - config_args.append(new_kwargs) - - if len(config_args) == 0: - raise ValueError("No configurations to tune, please check your `@autotune` decorator") - - # check if the tunable arguments has been set. 
- # get the back config argument - top_config, *rest = config_args - - if self._kernel_parameters is not None: - key_args_tuple, key_kwargs_tuple = self._kernel_parameters - tunable_arguments = [key for key, _ in top_config.items()] - - # Check if all tunable arguments have been tuned by comparing config keys with key_kwargs_tuple - if any(key in top_config for key, _ in key_kwargs_tuple): - logger.warning( - f"Tunable parameters {tunable_arguments} already provided during auto-tuning. Skipping compilation and using direct JIT" - ) - # compile the kernel with the provided parameters - jit_kernel = self.jit_compile() - autotuner_result = AutotuneResult( - libcode=jit_kernel.get_kernel_source(), - func=jit_kernel.prim_func, - kernel=jit_kernel) - self._memory_cache[key] = autotuner_result - return autotuner_result - # get the cpu count - available_cpu_count = get_available_cpu_count() - cpu_utilizations = float(TILELANG_AUTO_TUNING_CPU_UTILITIES) - cpu_counts = int(TILELANG_AUTO_TUNING_CPU_COUNTS) - max_cpu_count = int(TILELANG_AUTO_TUNING_MAX_CPU_COUNT) - if cpu_counts > 0: - num_workers = min(cpu_counts, available_cpu_count) - logger.info( - f"Auto-tuning with {cpu_counts} CPU counts, {available_cpu_count} CPUs available, {num_workers} CPUs will be used" - ) - else: - num_workers = max(1, int(available_cpu_count * cpu_utilizations)) - logger.info( - f"Auto-tuning with {cpu_utilizations} CPU utilizations, {available_cpu_count} CPUs available, {num_workers} CPUs will be used" - ) - - if max_cpu_count > 0 and num_workers > max_cpu_count: - logger.warning( - f"Auto-tuning with {cpu_utilizations} CPU utilizations, {available_cpu_count} CPUs available, {num_workers} CPUs will be used, but the max CPU count is {max_cpu_count}, so we will use {max_cpu_count} CPUs" - ) - num_workers = max_cpu_count - - pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) - futures = [] - future_to_index = {} - - def device_wrapper(func, device, **config_arg): - torch.cuda.set_device(device) - return func(**config_arg) - - for i, config_arg in enumerate(config_args): - future = pool.submit( - functools.partial(device_wrapper, self.jit_compile, torch.cuda.current_device()), - **config_arg, - ) - futures.append(future) - future_to_index[future] = i - - results_with_configs = [] - for future in tqdm( - concurrent.futures.as_completed(futures), - total=len(futures), - desc="Compiling configurations"): - idx = future_to_index[future] - config = config_args[idx] - try: - result = future.result() - results_with_configs.append((result, config)) - except Exception as e: - logger.debug( - f"Compilation failed for config {config} at index {idx} with error: {e}") - continue - - ref_latency = None - progress_bar = tqdm(range(len(results_with_configs)), desc="Bench configurations") - for i in progress_bar: - jit_kernel, config = results_with_configs[i] - try: - # Cannot ThreadPoolExecutor to enforce timeout on target_fn execution - # Because tma init may behave strangely with one thread - # latency, ref_latency = target_fn(jit_kernel) - latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel) - except TimeoutException: - logger.info( - f"A timeout occurred while testing config {config}, checkout autotuner.log for more details" - ) - continue - except Exception: - logger.info( - f"An error occurred while testing config {config}, checkout autotuner.log for more details" - ) - logger.debug(f"Error: {traceback.format_exc()}") - continue - - if latency < best_latency: - best_latency = latency - 
best_config = config - best_kernel = jit_kernel - - progress_bar.set_postfix({"best_latency": best_latency}) - tqdm.write(f"Tuned Latency {latency} with config {config} at index {i}") - - pool.shutdown() - - if best_kernel is None: - error_msg = ("Auto-tuning failed: No configuration successfully " - "compiled and passed benchmarking/validation.") - logger.error(error_msg) - raise RuntimeError(error_msg) - - best_kernel: tilelang.JITKernel = best_kernel.update_tuner_result( - latency=best_latency, - config=best_config, - ref_latency=ref_latency, - ) - - autotuner_result = AutotuneResult( - latency=best_latency, - config=best_config, - ref_latency=ref_latency, - libcode=best_kernel.get_kernel_source(), - func=best_kernel.prim_func, - kernel=best_kernel) - - if self.compile_args.execution_backend == "dlpack": - logger.warning("DLPack backend does not support cache saving to disk.") - else: - with self._lock: - if is_cache_enabled(): - self._save_result_to_disk(key, autotuner_result) - - self._memory_cache[key] = autotuner_result - - return autotuner_result - - def __call__(self) -> Any: - """Make the AutoTuner callable, running the auto-tuning process. - - Returns: - AutotuneResult: Results of the auto-tuning process. - """ - return self.run() - - -class _AutoTunerImplementation: - # Overload __init__ to help type checkers understand the effect of return_program - # The '-> None' is for __init__ itself. The crucial part is Literal for return_program. - - warmup: int = 25 - rep: int = 100 - timeout: int = 100 - configs: Union[Dict, Callable] = None - supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto - ref_prog: Callable = None - supply_prog: Callable = None - rtol: float = 1e-2 - atol: float = 1e-2 - max_mismatched_ratio: float = 0.01 - skip_check: bool = False - manual_check_prog: Callable = None - cache_input_tensors: bool = False - - def __init__(self, - configs: Union[Dict, Callable], - warmup: int = 25, - rep: int = 100, - timeout: int = 100, - supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto, - ref_prog: Callable = None, - supply_prog: Callable = None, - rtol: float = 1e-2, - atol: float = 1e-2, - max_mismatched_ratio: float = 0.01, - skip_check: bool = False, - manual_check_prog: Callable = None, - cache_input_tensors: bool = False) -> None: - """Initialize the AutoTunerImplementation. - - Args: - configs: Configuration space to explore during auto-tuning. - warmup: Number of warmup iterations before timing. - rep: Number of repetitions for timing measurements. - timeout: Maximum time (in seconds) allowed for each configuration. 
- supply_type: Strategy for generating input tensors (random/zeros/etc) - ref_prog: Reference implementation for validation - supply_prog: Custom function to provide input tensors - rtol: Relative tolerance for numerical validation - atol: Absolute tolerance for numerical validation - max_mismatched_ratio: Allowed percentage of mismatched values - skip_check: Bypass validation against reference implementation - manual_check_prog: Custom validation function - cache_input_tensors: Reuse input tensors across trials - """ - # Configuration and benchmarking parameters - self.configs = configs # Search space of tuning configurations - self.warmup = warmup # Warmup iterations for stable measurements - self.rep = rep # Measurement repetitions for statistics - self.timeout = timeout # Per-configuration timeout threshold - - # Tensor handling and validation setup - self.supply_type = supply_type # Input tensor generation strategy - self.ref_prog = ref_prog # Ground truth implementation - self.supply_prog = supply_prog # Custom input data provider - self.rtol = rtol # Relative error tolerance - self.atol = atol # Absolute error tolerance - self.max_mismatched_ratio = max_mismatched_ratio # Allowed mismatch - - # Validation control flags - self.skip_check = skip_check # Bypass accuracy verification - self.manual_check_prog = manual_check_prog # Custom validation - self.cache_input_tensors = cache_input_tensors # Reuse inputs - - # Cache for storing tuned kernel implementations - self._tuner_cache: Dict[tuple, tilelang.JITKernel] = {} # (args, kwargs) -> compiled kernel - - # This tells the type checker what the *wrapper* function will return. - # this is for linting, please do not remove it. - @overload - def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, AutotuneResult]]: - ... - - @overload - def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, AutotuneResult]: - ... 
- - # Actual implementation of __call__ - def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Any]: - warmup = self.warmup - rep = self.rep - timeout = self.timeout - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - - key_args_tuple = args - key_kwargs_tuple = tuple(sorted(kwargs.items())) - key = (key_args_tuple, key_kwargs_tuple) - - if key not in self._tuner_cache: - - def jit_compile(**config_arg): - return fn(*args, **kwargs, __tune_params=config_arg) - - compile_arguments = fn(__return_compile_arguments=True) - - autotuner = AutoTuner( - fn, configs=self.configs).set_profile_args( - supply_type=self.supply_type, - ref_prog=self.ref_prog, - supply_prog=self.supply_prog, - rtol=self.rtol, - atol=self.atol, - max_mismatched_ratio=self.max_mismatched_ratio, - skip_check=self.skip_check, - manual_check_prog=self.manual_check_prog, - cache_input_tensors=self.cache_input_tensors, - ).set_compile_args( - out_idx=compile_arguments['out_idx'], - execution_backend=compile_arguments['execution_backend'], - target=compile_arguments['target'], - target_host=compile_arguments['target_host'], - verbose=compile_arguments['verbose'], - pass_configs=compile_arguments['pass_configs'], - ) - - autotuner.jit_compile = jit_compile - autotuner.set_kernel_parameters(key) - - autotuner.run = partial(autotuner.run, warmup, rep, timeout) - - artifact = autotuner.run() - - self._tuner_cache[key] = artifact.kernel - - return self._tuner_cache[key] - - return wrapper - - -def autotune( # This is the new public interface - func: Union[Callable[_P, _RProg], PrimFunc, None] = None, - *, # Indicates subsequent arguments are keyword-only - configs: Union[Dict, Callable], - # profile arguments - warmup: int = 25, - rep: int = 100, - timeout: int = 100, - # compile arguments - supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto, - ref_prog: Callable = None, - supply_prog: Callable = None, - rtol: float = 1e-2, - atol: float = 1e-2, - max_mismatched_ratio: float = 0.01, - skip_check: bool = False, - manual_check_prog: Callable = None, - cache_input_tensors: bool = False, -): - """ - Just-In-Time (JIT) compiler decorator for TileLang functions. - - This decorator can be used without arguments (e.g., `@tilelang.jit`): - Applies JIT compilation with default settings. - - Tips: - - If you want to skip the auto-tuning process, you can set override the tunable parameters in the function signature. - ```python - if enable_autotune: - kernel = flashattn(batch, heads, seq_len, dim, is_causal) - else: - kernel = flashattn( - batch, heads, seq_len, dim, is_causal, groups=groups, block_M=128, block_N=128, num_stages=2, threads=256) - ``` - - Parameters - ---------- - func_or_out_idx : Any, optional - If using `@tilelang.jit(...)` to configure, this is the `out_idx` parameter. - If using `@tilelang.jit` directly on a function, this argument is implicitly - the function to be decorated (and `out_idx` will be `None`). - configs : Dict or Callable - Configuration space to explore during auto-tuning. - warmup : int, optional - Number of warmup iterations before timing. - rep : int, optional - Number of repetitions for timing measurements. - timeout : int, optional - target : Union[str, Target], optional - Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto". - target_host : Union[str, Target], optional - Target host for cross-compilation. Defaults to None. - execution_backend : Literal["dlpack", "ctypes", "cython"], optional - Backend for kernel execution and argument passing. 
Defaults to "cython". - verbose : bool, optional - Enables verbose logging during compilation. Defaults to False. - pass_configs : Optional[Dict[str, Any]], optional - Configurations for TVM's pass context. Defaults to None. - debug_root_path : Optional[str], optional - Directory to save compiled kernel source for debugging. Defaults to None. - - Returns - ------- - Callable - Either a JIT-compiled wrapper around the input function, or a configured decorator - instance that can then be applied to a function. - """ - if callable(func): - # Case 1: Used as @autotune (func_or_out_idx is the function, others are defaults) - # This is a placeholder for a real auto tuner implementation - raise ValueError( - "Use tilelang.autotune to decorate func without arguments is not supported yet.") - elif isinstance(func, PrimFunc): - raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.") - else: - # Case 2: Used as @autotune(...) to configure, or func_or_out_idx is meant as out_idx. - # Create a _AutoTunerImplementation instance with the provided/defaulted arguments. - # This instance is a decorator that will be applied to the function later. - configured_decorator = _AutoTunerImplementation( - configs=configs, - warmup=warmup, - rep=rep, - timeout=timeout, - supply_type=supply_type, - ref_prog=ref_prog, - supply_prog=supply_prog, - rtol=rtol, - atol=atol, - max_mismatched_ratio=max_mismatched_ratio, - skip_check=skip_check, - manual_check_prog=manual_check_prog, - cache_input_tensors=cache_input_tensors, - ) - return configured_decorator diff --git a/tilelang/autotuner/capture.py b/tilelang/autotuner/capture.py new file mode 100644 index 000000000..c0661be4b --- /dev/null +++ b/tilelang/autotuner/capture.py @@ -0,0 +1,129 @@ +# Copyright (c) Tile-AI Corporation. +# Licensed under the MIT License. + +import threading +from typing import List, Any, Optional + +# Use thread local to store the stack +# This is to avoid the cross-thread interference +_local = threading.local() + + +class CaptureStack: + """ + A simple stack implementation for capturing items in a thread-local context. + Used to manage a stack of items (e.g., input tensors) for auto-tuning capture. + """ + + def __init__(self): + # Initialize an empty list to use as the stack + self.stack = [] + + def push(self, item): + """ + Push an item onto the top of the stack. + + Args: + item: The item to be pushed onto the stack. + """ + self.stack.append(item) + + def pop(self): + """ + Pop and return the top item from the stack. + + Returns: + The item at the top of the stack. + + Raises: + IndexError: If the stack is empty. + """ + return self.stack.pop() + + def top(self): + """ + Return the item at the top of the stack without removing it. + + Returns: + The item at the top of the stack. + + Raises: + IndexError: If the stack is empty. + """ + return self.stack[-1] + + def size(self): + """ + Return the number of items in the stack. + + Returns: + int: The size of the stack. + """ + return len(self.stack) + + def __len__(self): + """ + Return the number of items in the stack (len operator support). + + Returns: + int: The size of the stack. + """ + return len(self.stack) + + def __bool__(self): + """ + Return True if the stack is not empty, False otherwise. + + Returns: + bool: Whether the stack contains any items. 
+ """ + return bool(self.stack) + + +def _get_current_stack() -> CaptureStack: + if not hasattr(_local, "capture_stack"): + _local.capture_stack = CaptureStack() + return _local.capture_stack + + +class AutotuneInputsCapture: + + __slots__ = ("tensors") + + def __init__(self, tensors: List[Any]): + self.tensors = tensors + + def __enter__(self) -> None: + _get_current_stack().push(self) + + def __exit__(self, exc_type, exc_val, exc_tb): + _get_current_stack().pop() + + +def set_autotune_inputs(*args) -> AutotuneInputsCapture: + """Set input tensors for auto-tuning. + + This function creates a context manager for capturing input tensors + during the auto-tuning process. It supports both: + set_autotune_inputs(a, b, c) + set_autotune_inputs([a, b, c]) + + Args: + *args: Either a single list/tuple of tensors, or multiple tensor arguments. + + Returns: + AutotuneInputsCapture: A context manager for auto-tuning inputs. + """ + if len(args) == 1 and isinstance(args[0], (list, tuple)): + tensors = list(args[0]) + else: + tensors = list(args) + return AutotuneInputsCapture(tensors) + + +def get_autotune_inputs() -> Optional[List[Any]]: + """ + Get the current autotune inputs from the stack. + """ + stack = _get_current_stack() + return stack.top().tensors if stack else None diff --git a/tilelang/autotuner/tuner.py b/tilelang/autotuner/tuner.py new file mode 100644 index 000000000..d6e500851 --- /dev/null +++ b/tilelang/autotuner/tuner.py @@ -0,0 +1,792 @@ +# Copyright (c) Tile-AI Corporation. +# Licensed under the MIT License. +"""The auto-tune module for tilelang programs. + +This module provides functionality for auto-tuning tilelang programs, including JIT compilation +and performance optimization through configuration search. +""" + +import tilelang +from tilelang import tvm as tvm +from tvm.tir import PrimFunc, Var +from tvm.target import Target +import inspect +from functools import partial +from typing import (Callable, List, Literal, Any, Optional, Union, Dict, overload, Tuple) +from tqdm import tqdm +import logging +import functools +import concurrent.futures +import torch +import os +import sys +import signal +import json +import hashlib +import threading +import traceback +from pathlib import Path + +from tilelang.env import ( + TILELANG_CACHE_DIR, + TILELANG_AUTO_TUNING_CPU_UTILITIES, + TILELANG_AUTO_TUNING_CPU_COUNTS, + TILELANG_AUTO_TUNING_MAX_CPU_COUNT, + is_cache_enabled, +) +from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult +from tilelang.autotuner.capture import get_autotune_inputs +from tilelang.jit.param import _P, _RProg +from tilelang.version import __version__ + + +class TimeoutException(Exception): + pass + + +def timeout_handler(signum, frame): + raise TimeoutException("Operation timed out") + + +def run_with_timeout(func, timeout, *args, **kwargs): + signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(timeout) + try: + result = func(*args, **kwargs) + except Exception as e: + raise e + finally: + signal.alarm(0) + return result + + +# Configure logging for the autotuner module +# TODO: Consider creating a common logger in utils +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) +logger.propagate = False + +# Lazy handler initialization flag +_logger_handlers_initialized = False + + +def _init_logger_handlers(): + global _logger_handlers_initialized + if _logger_handlers_initialized: + return + formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s') + file_handler = logging.FileHandler('autotuner.log', 
mode='w') + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(formatter) + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(logging.INFO) + console_handler.setFormatter(formatter) + logger.addHandler(file_handler) + logger.addHandler(console_handler) + _logger_handlers_initialized = True + + +def get_available_cpu_count() -> int: + """Gets the number of CPU cores available to the current process. + """ + try: + cpu_count = len(os.sched_getaffinity(0)) + except AttributeError: + cpu_count = os.cpu_count() + + return cpu_count or 1 + + +class AutoTuner: + """Auto-tuner for tilelang programs. + + This class handles the auto-tuning process by testing different configurations + and finding the optimal parameters for program execution. + + Args: + fn: The function to be auto-tuned. + configs: List of configurations to try during auto-tuning. + """ + compile_args = CompileArgs() + profile_args = ProfileArgs() + + _kernel_parameters: Optional[Tuple[str, ...]] = None + _lock = threading.Lock() # For thread safety + _memory_cache = {} # In-memory cache dictionary + cache_dir: Path = Path(TILELANG_CACHE_DIR) / "autotuner" + + def __init__(self, fn: Callable, configs): + self.fn = fn + self.configs = configs + self.ref_latency_cache = None + self.jit_input_tensors = None + self.ref_input_tensors = None + self.jit_compile = None + + @classmethod + def from_kernel(cls, kernel: Callable, configs): + """Create an AutoTuner instance from a kernel function. + + Args: + kernel: The kernel function to auto-tune. + configs: List of configurations to try. + + Returns: + AutoTuner: A new AutoTuner instance. + """ + return cls(kernel, configs) + + def set_compile_args(self, + out_idx: Union[List[int], int, None] = None, + target: Literal['auto', 'cuda', 'hip'] = 'auto', + execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython", + target_host: Union[str, Target] = None, + verbose: bool = False, + pass_configs: Optional[Dict[str, Any]] = None): + """Set compilation arguments for the auto-tuner. + + Args: + out_idx: List of output tensor indices. + target: Target platform. + execution_backend: Execution backend to use for kernel execution. + target_host: Target host for cross-compilation. + verbose: Whether to enable verbose output. + pass_configs: Additional keyword arguments to pass to the Compiler PassContext. + + Returns: + AutoTuner: Self for method chaining. + """ + self.compile_args = CompileArgs( + out_idx=out_idx, + target=target, + execution_backend=execution_backend, + target_host=target_host, + verbose=verbose, + pass_configs=pass_configs) + + return self + + def set_profile_args(self, + warmup: int = 25, + rep: int = 100, + timeout: int = 30, + supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto, + ref_prog: Callable = None, + supply_prog: Callable = None, + rtol: float = 1e-2, + atol: float = 1e-2, + max_mismatched_ratio: float = 0.01, + skip_check: bool = False, + manual_check_prog: Callable = None, + cache_input_tensors: bool = False): + """Set profiling arguments for the auto-tuner. + + Args: + supply_type: Type of tensor supply mechanism. Ignored if `supply_prog` is provided. + ref_prog: Reference program for validation. + supply_prog: Supply program for input tensors. + rtol: Relative tolerance for validation. + atol: Absolute tolerance for validation. + max_mismatched_ratio: Maximum allowed mismatch ratio. + skip_check: Whether to skip validation. + manual_check_prog: Manual check program for validation. 
+ cache_input_tensors: Whether to cache input tensors. + warmup: Number of warmup iterations. + rep: Number of repetitions for timing. + timeout: Maximum time per configuration. + + Returns: + AutoTuner: Self for method chaining. + """ + # If the program is under `with set_autotune_inputs` context, + # the `supply_prog` will be ignored and the `get_autotune_inputs` will be used instead. + if get_autotune_inputs() is not None: + if supply_prog is not None: + logger.warning( + "`supply_prog` will be ignored as this program is under `with set_autotune_inputs` context." + ) + supply_prog = lambda _: get_autotune_inputs() # noqa: E731ยท + + self.profile_args = ProfileArgs( + supply_type=supply_type, + ref_prog=ref_prog, + supply_prog=supply_prog, + rtol=rtol, + atol=atol, + max_mismatched_ratio=max_mismatched_ratio, + skip_check=skip_check, + manual_check_prog=manual_check_prog, + cache_input_tensors=cache_input_tensors, + warmup=warmup, + rep=rep, + timeout=timeout) + + # If a custom `supply_prog` is provided, the profiler's `supply_type` setting + # becomes ineffective. The custom supply program will be used instead. + if supply_prog is not None and supply_type != tilelang.TensorSupplyType.Auto: + logger.warning("Ignoring `supply_type` passed to `set_profile_args` because " + "`supply_prog` is not None.") + + return self + + def set_kernel_parameters(self, parameters: Tuple[str, ...]): + # for cache key generation + self._kernel_parameters = parameters + + def generate_cache_key(self, parameters: Dict[str, Any]) -> Optional[AutotuneResult]: + """Generate a cache key for the auto-tuning process. + """ + # extract parameters from the function signature + op_parameters = [] + for _, default_value in parameters.items(): + if default_value.default is not inspect.Parameter.empty: + op_parameters.append(default_value.default) + + if self._kernel_parameters is not None: + op_parameters += self._kernel_parameters + + func_source = inspect.getsource(self.fn) + key_data = { + "version": __version__, + "op_parameters": tuple(op_parameters), + "func_source": func_source, + "configs": self.configs, + "compile_args": hash(self.compile_args), + "profile_args": hash(self.profile_args), + } + # Sort keys to ensure consistency + key_string = json.dumps(key_data, sort_keys=True) + return hashlib.sha256(key_string.encode()).hexdigest() + + def _save_result_to_disk(self, key, result: AutotuneResult): + result.save_to_disk(self.cache_dir / key) + + def _load_result_from_disk(self, key) -> AutotuneResult: + result = AutotuneResult.load_from_disk(self.cache_dir / key, self.compile_args) + return result + + def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30): + """Run the auto-tuning process. + + Args: + warmup: Number of warmup iterations. + rep: Number of repetitions for timing. + timeout: Maximum time per configuration. + + Returns: + AutotuneResult: Results of the auto-tuning process. + """ + _init_logger_handlers() + + sig = inspect.signature(self.fn) + parameters = sig.parameters + + if isinstance(self.configs, Callable): + self.configs = self.configs(*self._kernel_parameters) + + key = self.generate_cache_key(parameters) + + with self._lock: + if is_cache_enabled(): + # First check in-memory cache + if key in self._memory_cache: + logger.warning("Found kernel in memory cache. 
For better performance," \ + " consider using `@tilelang.autotune` instead of direct AutoTuner.from_kernel.") + return self._memory_cache[key] + + # Then check disk cache + result = self._load_result_from_disk(key) + if result is not None: + # Populate memory cache with disk result + self._memory_cache[key] = result + return result + + best_latency: float = 1e8 + best_config: Optional[Dict[str, Any]] = None + best_kernel: Optional[tilelang.JITKernel] = None + + def _compile(**config_arg) -> tilelang.JITKernel: + compile_args = self.compile_args + return compile_args.compile_program(self.fn(**config_arg)) + + if self.jit_compile is None: + self.jit_compile = _compile + + def target_fn(jit_kernel: tilelang.JITKernel): + # Unpack the context + profile_args = self.profile_args + supply_type = profile_args.supply_type + skip_check = profile_args.skip_check + manual_check_prog = profile_args.manual_check_prog + cache_input_tensors = profile_args.cache_input_tensors + ref_prog = profile_args.ref_prog + supply_prog = profile_args.supply_prog + rtol = profile_args.rtol + atol = profile_args.atol + max_mismatched_ratio = profile_args.max_mismatched_ratio + + profiler = jit_kernel.get_profiler(tensor_supply_type=supply_type) + + # Factory functions for generating input tensors. + # This encapsulates the logic of using either a custom supply program (`supply_prog`) + # or the default profiler input generation (`profiler._get_inputs`). + def get_input_tensors_supply(with_output: bool): + + def func(): + if supply_prog is not None: + return supply_prog(profiler._get_params(with_output=with_output)) + else: + return profiler._get_inputs(with_output=with_output) + + return func + + jit_input_tensors_supply = get_input_tensors_supply(with_output=False) + ref_input_tensors_supply = get_input_tensors_supply(with_output=False) + + if cache_input_tensors: + params = profiler._get_params(with_output=False) + if self.jit_input_tensors is None: + self.jit_input_tensors = jit_input_tensors_supply() + else: + # check if the cached tensors are compatible with the current configuration + assert len(params) == len( + self.jit_input_tensors), "len(params) != len(self.jit_input_tensors)" + for p, c in zip(params, self.jit_input_tensors): + if not isinstance(c, torch.Tensor): + # skip non-tensor inputs checking + continue + + # Check tensor compatibility using generator expression + def shape_equal(a, b): + return all( + a_dim == b_dim or isinstance(a_dim, Var) or isinstance(b_dim, Var) + for a_dim, b_dim in zip(a.shape, b.shape)) + + if p.dtype != c.dtype or not shape_equal(p, c): + logger.warning( + "\nIncompatible input tensor properties detected between cached tensors and " + "tensors regenerated for the current configuration trial. 
" + "This can happen if different tuning configurations require different input shapes/dtypes " + "and input tensor caching is enabled.\n" + "To ensure fresh, compatible inputs are generated for every trial " + "you can disable caching by setting:\n" + " `cache_input_tensors=False`\n" + "within your `.set_compile_args(...)` call.\n") + # otherwise, regenerate the input tensors for safety + self.jit_input_tensors = jit_input_tensors_supply() + break + else: + self.jit_input_tensors = jit_input_tensors_supply() + + if (not skip_check) and (ref_prog is not None): + if manual_check_prog is not None: + profiler.manual_assert_close( + ref_prog, + input_tensors=self.jit_input_tensors, + manual_check_prog=manual_check_prog) + else: + profiler.assert_allclose( + ref_prog, + input_tensors=self.jit_input_tensors, + rtol=rtol, + atol=atol, + max_mismatched_ratio=max_mismatched_ratio) + latency = profiler.do_bench( + warmup=warmup, rep=rep, input_tensors=self.jit_input_tensors) + + if self.ref_latency_cache is None and ref_prog is not None: + self.ref_input_tensors = ref_input_tensors_supply() + self.ref_latency_cache = profiler.do_bench( + ref_prog, n_warmup=warmup, n_repeat=rep, input_tensors=self.ref_input_tensors) + + return latency, self.ref_latency_cache + + config_args = [] + for config in self.configs: + new_kwargs = {} + keys = config.keys() + for name, _ in parameters.items(): + if name in config: + new_kwargs[name] = config[name] + unused_keys = set(keys) - set(new_kwargs.keys()) + if len(unused_keys) > 0: + raise ValueError(f"Unused keys in config: {unused_keys}") + config_args.append(new_kwargs) + + if len(config_args) == 0: + raise ValueError("No configurations to tune, please check your `@autotune` decorator") + + # check if the tunable arguments has been set. + # get the back config argument + top_config, *rest = config_args + + if self._kernel_parameters is not None: + key_args_tuple, key_kwargs_tuple = self._kernel_parameters + tunable_arguments = [key for key, _ in top_config.items()] + + # Check if all tunable arguments have been tuned by comparing config keys with key_kwargs_tuple + if any(key in top_config for key, _ in key_kwargs_tuple): + logger.warning( + f"Tunable parameters {tunable_arguments} already provided during auto-tuning. 
+        config_args = []
+        for config in self.configs:
+            new_kwargs = {}
+            keys = config.keys()
+            for name, _ in parameters.items():
+                if name in config:
+                    new_kwargs[name] = config[name]
+            unused_keys = set(keys) - set(new_kwargs.keys())
+            if len(unused_keys) > 0:
+                raise ValueError(f"Unused keys in config: {unused_keys}")
+            config_args.append(new_kwargs)
+
+        if len(config_args) == 0:
+            raise ValueError("No configurations to tune, please check your `@autotune` decorator")
+
+        # Check whether the tunable arguments were already provided by the caller.
+        # Inspect the first config to get the names of the tunable parameters.
+        top_config, *rest = config_args
+
+        if self._kernel_parameters is not None:
+            key_args_tuple, key_kwargs_tuple = self._kernel_parameters
+            tunable_arguments = [key for key, _ in top_config.items()]
+
+            # If any tunable parameter was passed explicitly as a keyword argument,
+            # skip auto-tuning and compile directly with the provided values.
+            if any(key in top_config for key, _ in key_kwargs_tuple):
+                logger.warning(
+                    f"Tunable parameters {tunable_arguments} were provided explicitly; "
+                    "skipping auto-tuning and compiling directly.")
+                # compile the kernel with the provided parameters
+                jit_kernel = self.jit_compile()
+                autotuner_result = AutotuneResult(
+                    libcode=jit_kernel.get_kernel_source(),
+                    func=jit_kernel.prim_func,
+                    kernel=jit_kernel)
+                self._memory_cache[key] = autotuner_result
+                return autotuner_result
+
+        # Determine how many worker threads to use for parallel compilation.
+        available_cpu_count = get_available_cpu_count()
+        cpu_utilizations = float(TILELANG_AUTO_TUNING_CPU_UTILITIES)
+        cpu_counts = int(TILELANG_AUTO_TUNING_CPU_COUNTS)
+        max_cpu_count = int(TILELANG_AUTO_TUNING_MAX_CPU_COUNT)
+        if cpu_counts > 0:
+            num_workers = min(cpu_counts, available_cpu_count)
+            logger.info(
+                f"Auto-tuning with {cpu_counts} CPUs requested, {available_cpu_count} CPUs "
+                f"available; {num_workers} CPUs will be used.")
+        else:
+            num_workers = max(1, int(available_cpu_count * cpu_utilizations))
+            logger.info(
+                f"Auto-tuning with CPU utilization {cpu_utilizations}, {available_cpu_count} CPUs "
+                f"available; {num_workers} CPUs will be used.")
+
+        if max_cpu_count > 0 and num_workers > max_cpu_count:
+            logger.warning(
+                f"Requested {num_workers} workers, but the configured maximum CPU count is "
+                f"{max_cpu_count}; using {max_cpu_count} CPUs instead.")
+            num_workers = max_cpu_count
+
+        pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_workers)
+        futures = []
+        future_to_index = {}
+
+        def device_wrapper(func, device, **config_arg):
+            torch.cuda.set_device(device)
+            return func(**config_arg)
+
+        for i, config_arg in enumerate(config_args):
+            future = pool.submit(
+                functools.partial(device_wrapper, self.jit_compile, torch.cuda.current_device()),
+                **config_arg,
+            )
+            futures.append(future)
+            future_to_index[future] = i
+
+        results_with_configs = []
+        for future in tqdm(
+                concurrent.futures.as_completed(futures),
+                total=len(futures),
+                desc="Compiling configurations"):
+            idx = future_to_index[future]
+            config = config_args[idx]
+            try:
+                result = future.result()
+                results_with_configs.append((result, config))
+            except Exception as e:
+                logger.debug(
+                    f"Compilation failed for config {config} at index {idx} with error: {e}")
+                continue
+
+        ref_latency = None
+        progress_bar = tqdm(range(len(results_with_configs)), desc="Bench configurations")
+        for i in progress_bar:
+            jit_kernel, config = results_with_configs[i]
+            try:
+                # We cannot use a ThreadPoolExecutor to enforce the timeout on target_fn,
+                # because TMA initialization may behave strangely when run on a worker thread.
+                # latency, ref_latency = target_fn(jit_kernel)
+                latency, ref_latency = run_with_timeout(target_fn, timeout, jit_kernel)
+            except TimeoutException:
+                logger.info(
+                    f"A timeout occurred while testing config {config}; check autotuner.log for more details"
+                )
+                continue
+            except Exception:
+                logger.info(
+                    f"An error occurred while testing config {config}; check autotuner.log for more details"
+                )
+                logger.debug(f"Error: {traceback.format_exc()}")
+                continue
+
+            if latency < best_latency:
+                best_latency = latency
+                best_config = config
+                best_kernel = jit_kernel
+
+            progress_bar.set_postfix({"best_latency": best_latency})
+            tqdm.write(f"Tuned latency {latency} with config {config} at index {i}")
+
+        pool.shutdown()
+
+        if best_kernel is None:
+            error_msg = ("Auto-tuning failed: no configuration successfully "
+                         "compiled and passed benchmarking/validation.")
+            logger.error(error_msg)
+            raise RuntimeError(error_msg)
+
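+        # Attach the best latency/config (and the reference latency, if one was
+        # measured) to the winning kernel via `update_tuner_result`.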
+        best_kernel: tilelang.JITKernel = best_kernel.update_tuner_result(
+            latency=best_latency,
+            config=best_config,
+            ref_latency=ref_latency,
+        )
+
+        autotuner_result = AutotuneResult(
+            latency=best_latency,
+            config=best_config,
+            ref_latency=ref_latency,
+            libcode=best_kernel.get_kernel_source(),
+            func=best_kernel.prim_func,
+            kernel=best_kernel)
+
+        if self.compile_args.execution_backend == "dlpack":
+            logger.warning("The DLPack backend does not support saving results to the disk cache.")
+        else:
+            with self._lock:
+                if is_cache_enabled():
+                    self._save_result_to_disk(key, autotuner_result)
+
+        self._memory_cache[key] = autotuner_result
+
+        return autotuner_result
+
+    def __call__(self) -> Any:
+        """Make the AutoTuner callable, running the auto-tuning process.
+
+        Returns:
+            AutotuneResult: Results of the auto-tuning process.
+        """
+        return self.run()
+
+
+class _AutoTunerImplementation:
+    # Default values for the tuning and profiling options below; instances
+    # override them in __init__. The @overload stubs on __call__ exist only to
+    # help type checkers with the wrapper's return type.
+
+    warmup: int = 25
+    rep: int = 100
+    timeout: int = 100
+    configs: Union[Dict, Callable] = None
+    supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto
+    ref_prog: Callable = None
+    supply_prog: Callable = None
+    rtol: float = 1e-2
+    atol: float = 1e-2
+    max_mismatched_ratio: float = 0.01
+    skip_check: bool = False
+    manual_check_prog: Callable = None
+    cache_input_tensors: bool = False
+
+    def __init__(self,
+                 configs: Union[Dict, Callable],
+                 warmup: int = 25,
+                 rep: int = 100,
+                 timeout: int = 100,
+                 supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
+                 ref_prog: Callable = None,
+                 supply_prog: Callable = None,
+                 rtol: float = 1e-2,
+                 atol: float = 1e-2,
+                 max_mismatched_ratio: float = 0.01,
+                 skip_check: bool = False,
+                 manual_check_prog: Callable = None,
+                 cache_input_tensors: bool = False) -> None:
+        """Initialize the _AutoTunerImplementation.
+
+        Args:
+            configs: Configuration space to explore during auto-tuning.
+            warmup: Number of warmup iterations before timing.
+            rep: Number of repetitions for timing measurements.
+            timeout: Maximum time (in seconds) allowed for each configuration.
+            supply_type: Strategy for generating input tensors (random/zeros/etc.).
+            ref_prog: Reference implementation for validation.
+            supply_prog: Custom function that supplies input tensors.
+            rtol: Relative tolerance for numerical validation.
+            atol: Absolute tolerance for numerical validation.
+            max_mismatched_ratio: Maximum allowed fraction of mismatched values.
+            skip_check: Bypass validation against the reference implementation.
+            manual_check_prog: Custom validation function.
+            cache_input_tensors: Reuse input tensors across tuning trials.
+        """
+        # Configuration and benchmarking parameters
+        self.configs = configs  # Search space of tuning configurations
+        self.warmup = warmup  # Warmup iterations for stable measurements
+        self.rep = rep  # Measurement repetitions for statistics
+        self.timeout = timeout  # Per-configuration timeout threshold
+
+        # Tensor handling and validation setup
+        self.supply_type = supply_type  # Input tensor generation strategy
+        self.ref_prog = ref_prog  # Ground-truth implementation
+        self.supply_prog = supply_prog  # Custom input data provider
+        self.rtol = rtol  # Relative error tolerance
+        self.atol = atol  # Absolute error tolerance
+        self.max_mismatched_ratio = max_mismatched_ratio  # Allowed mismatch ratio
+
+        # Validation control flags
+        self.skip_check = skip_check  # Bypass accuracy verification
+        self.manual_check_prog = manual_check_prog  # Custom validation
+        self.cache_input_tensors = cache_input_tensors  # Reuse inputs
+
+        # Cache for storing tuned kernel implementations
+        self._tuner_cache: Dict[tuple, tilelang.JITKernel] = {}  # (args, kwargs) -> compiled kernel
+
+    # These overloads tell the type checker what the *wrapper* function will return.
+    # They exist for linting purposes; please do not remove them.
+    @overload
+    def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, AutotuneResult]]:
+        ...
+
+    @overload
+    def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, AutotuneResult]:
+        ...
+
+    # Actual implementation of __call__
+    def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Any]:
+        warmup = self.warmup
+        rep = self.rep
+        timeout = self.timeout
+
+        @functools.wraps(fn)
+        def wrapper(*args, **kwargs):
+
+            # Memoize per (args, kwargs) so each distinct call signature is tuned only once.
+            key_args_tuple = args
+            key_kwargs_tuple = tuple(sorted(kwargs.items()))
+            key = (key_args_tuple, key_kwargs_tuple)
+
+            if key not in self._tuner_cache:
+
+                # `__tune_params` is consumed by the underlying `tilelang.jit` wrapper
+                # to fill in the tunable parameters for this trial.
+                def jit_compile(**config_arg):
+                    return fn(*args, **kwargs, __tune_params=config_arg)
+
+                # Ask the underlying `tilelang.jit` wrapper for the compile arguments
+                # it was configured with, so the AutoTuner can reuse them.
+                compile_arguments = fn(__return_compile_arguments=True)
+
+                autotuner = AutoTuner(
+                    fn, configs=self.configs).set_profile_args(
+                        supply_type=self.supply_type,
+                        ref_prog=self.ref_prog,
+                        supply_prog=self.supply_prog,
+                        rtol=self.rtol,
+                        atol=self.atol,
+                        max_mismatched_ratio=self.max_mismatched_ratio,
+                        skip_check=self.skip_check,
+                        manual_check_prog=self.manual_check_prog,
+                        cache_input_tensors=self.cache_input_tensors,
+                    ).set_compile_args(
+                        out_idx=compile_arguments['out_idx'],
+                        execution_backend=compile_arguments['execution_backend'],
+                        target=compile_arguments['target'],
+                        target_host=compile_arguments['target_host'],
+                        verbose=compile_arguments['verbose'],
+                        pass_configs=compile_arguments['pass_configs'],
+                    )
+
+                autotuner.jit_compile = jit_compile
+                autotuner.set_kernel_parameters(key)
+
+                autotuner.run = partial(autotuner.run, warmup, rep, timeout)
+
+                artifact = autotuner.run()
+
+                self._tuner_cache[key] = artifact.kernel
+
+            return self._tuner_cache[key]
+
+        return wrapper
+
+
+def autotune(  # This is the new public interface
+        func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
+        *,  # Indicates subsequent arguments are keyword-only
+        configs: Union[Dict, Callable],
+        # benchmarking arguments
+        warmup: int = 25,
+        rep: int = 100,
+        timeout: int = 100,
+        # profiling / validation arguments
+        supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto,
+        ref_prog: Callable = None,
+        supply_prog: Callable = None,
+        rtol: float = 1e-2,
+        atol: float = 1e-2,
+        max_mismatched_ratio: float = 0.01,
+        skip_check: bool = False,
+        manual_check_prog: Callable = None,
+        cache_input_tensors: bool = False,
+):
+    """
+    Auto-tuning decorator for TileLang kernel functions.
+
+    This decorator must be called with arguments, e.g.
+    `@tilelang.autotune(configs=...)`; applying it directly to a function
+    without arguments is not supported yet.
+
+    Tips:
+    - If you want to skip the auto-tuning process, you can override the tunable
+      parameters when calling the decorated function.
+    ```python
+    if enable_autotune:
+        kernel = flashattn(batch, heads, seq_len, dim, is_causal)
+    else:
+        kernel = flashattn(
+            batch, heads, seq_len, dim, is_causal, groups=groups, block_M=128,
+            block_N=128, num_stages=2, threads=256)
+    ```
+
+    Parameters
+    ----------
+    func : Callable or PrimFunc, optional
+        The function to decorate. Decorating a function or PrimFunc directly
+        (without configuration arguments) is not supported yet and raises a
+        ValueError.
+    configs : Dict or Callable
+        Configuration space to explore during auto-tuning.
+    warmup : int, optional
+        Number of warmup iterations before timing.
+    rep : int, optional
+        Number of repetitions for timing measurements.
+    timeout : int, optional
+        Maximum time (in seconds) allowed for each configuration.
+    supply_type : tilelang.TensorSupplyType, optional
+        Strategy for generating input tensors. Defaults to Auto.
+    ref_prog : Callable, optional
+        Reference implementation used for validation.
+    supply_prog : Callable, optional
+        Custom function that supplies input tensors.
+    rtol : float, optional
+        Relative tolerance for numerical validation.
+    atol : float, optional
+        Absolute tolerance for numerical validation.
+    max_mismatched_ratio : float, optional
+        Maximum allowed fraction of mismatched values.
+    skip_check : bool, optional
+        Bypass validation against the reference implementation.
+    manual_check_prog : Callable, optional
+        Custom validation function.
+    cache_input_tensors : bool, optional
+        Reuse input tensors across tuning trials.
Defaults to "cython". + verbose : bool, optional + Enables verbose logging during compilation. Defaults to False. + pass_configs : Optional[Dict[str, Any]], optional + Configurations for TVM's pass context. Defaults to None. + debug_root_path : Optional[str], optional + Directory to save compiled kernel source for debugging. Defaults to None. + + Returns + ------- + Callable + Either a JIT-compiled wrapper around the input function, or a configured decorator + instance that can then be applied to a function. + """ + if callable(func): + # Case 1: Used as @autotune (func_or_out_idx is the function, others are defaults) + # This is a placeholder for a real auto tuner implementation + raise ValueError( + "Use tilelang.autotune to decorate func without arguments is not supported yet.") + elif isinstance(func, PrimFunc): + raise ValueError("Use tilelang.jit to decorate prim_func is not supported yet.") + else: + # Case 2: Used as @autotune(...) to configure, or func_or_out_idx is meant as out_idx. + # Create a _AutoTunerImplementation instance with the provided/defaulted arguments. + # This instance is a decorator that will be applied to the function later. + configured_decorator = _AutoTunerImplementation( + configs=configs, + warmup=warmup, + rep=rep, + timeout=timeout, + supply_type=supply_type, + ref_prog=ref_prog, + supply_prog=supply_prog, + rtol=rtol, + atol=atol, + max_mismatched_ratio=max_mismatched_ratio, + skip_check=skip_check, + manual_check_prog=manual_check_prog, + cache_input_tensors=cache_input_tensors, + ) + return configured_decorator