From 227834e531eb073994060a8575801826224dc392 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Tue, 1 Jul 2025 17:05:52 +0800
Subject: [PATCH] Enhancement: Update quickstart and elementwise examples for improved performance and clarity

- Added a line break in `quickstart.py` for better readability.
- Simplified the JIT kernel compilation in `quickstart.py` by removing the unused execution backend option.
- Disabled the `tilelang` cache in `example_elementwise_add.py` and optimized the element-wise addition kernel by staging the input tensors through shared memory, improving performance.
- Updated the default matrix dimensions and block sizes in the argument parser to improve usability.
- Added a `makeGemmFragment8x4` layout and 32-bit element support in `makeGemmFragmentA` in `gemm_layouts.cc`.
---
 .../elementwise/example_elementwise_add.py | 22 ++++---
 examples/quickstart.py                     |  4 +-
 src/layout/gemm_layouts.cc                 | 57 ++++++++++++-------
 3 files changed, 53 insertions(+), 30 deletions(-)

diff --git a/examples/elementwise/example_elementwise_add.py b/examples/elementwise/example_elementwise_add.py
index 82d15cf5a..effb0f70d 100644
--- a/examples/elementwise/example_elementwise_add.py
+++ b/examples/elementwise/example_elementwise_add.py
@@ -7,6 +7,8 @@
 import tilelang.language as T
 from tilelang.autotuner import AutoTuner
 
+tilelang.disable_cache()
+
 
 def ref_program(x, y):
     return x + y
@@ -19,12 +21,17 @@ def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
     def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor(
         (M, N), out_dtype)):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
-            start_x = bx * block_N
-            start_y = by * block_M
+            A_shared = T.alloc_shared((block_M, block_N), in_dtype)
+            B_shared = T.alloc_shared((block_M, block_N), in_dtype)
+            C_local = T.alloc_fragment((block_M, block_N), out_dtype)
+            C_shared = T.alloc_shared((block_M, block_N), out_dtype)
+
+            T.copy(A[by * block_M, bx * block_N], A_shared)
+            T.copy(B[by * block_M, bx * block_N], B_shared)
             for (local_y, local_x) in T.Parallel(block_M, block_N):
-                y = start_y + local_y
-                x = start_x + local_x
-                C[y, x] = A[y, x] + B[y, x]
+                C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]
+            T.copy(C_local, C_shared)
+            T.copy(C_shared, C[by * block_M, bx * block_N])
 
     return elem_add
 
@@ -56,7 +63,7 @@ def kernel(block_M=None, block_N=None, threads=None):
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--m", type=int, default=512)
+    parser.add_argument("--m", type=int, default=1024)
     parser.add_argument("--n", type=int, default=1024)
     parser.add_argument("--use_autotune", action="store_true", default=False)
     args, _ = parser.parse_known_args()
@@ -70,9 +77,8 @@ def main():
         kernel = result.kernel
     else:
         # Default config
-        config = {"block_M": 128, "block_N": 256, "threads": 128}
+        config = {"block_M": 128, "block_N": 128, "threads": 128}
         kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")
-
     out = kernel(a, b)
     torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
 
diff --git a/examples/quickstart.py b/examples/quickstart.py
index e4f2110a9..297c855ad 100644
--- a/examples/quickstart.py
+++ b/examples/quickstart.py
@@ -2,6 +2,7 @@
 # Licensed under the MIT License.
 import tilelang
 import tilelang.language as T
+
 # `make_mma_swizzle_layout` is a python defined layout function
 # specifically designed for MMA operations
 # which ensures the consistency with the nvidia CUTLASS Library.
@@ -73,8 +74,7 @@ def main(
 # out_idx specifies the index of the output buffer in the argument list
 # if out_idx is specified, the tensor will be created during runtime
 # target currently can be "cuda" or "hip" or "cpu".
-jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", execution_backend="cython")
-# jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", execution_backend="dlpack")
+jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")
 
 # 3. Test the kernel in Python with PyTorch data
 import torch
diff --git a/src/layout/gemm_layouts.cc b/src/layout/gemm_layouts.cc
index 1461a5c67..c5df17fb7 100644
--- a/src/layout/gemm_layouts.cc
+++ b/src/layout/gemm_layouts.cc
@@ -21,6 +21,15 @@ static IterVar make_itervar(std::string name, PrimExpr dom) {
   return IterVar(Range(0, dom), var, IterVarType::kDataPar);
 }
 
+Fragment makeGemmFragment8x4() {
+  IterVar i = make_itervar("i", 8);
+  IterVar j = make_itervar("j", 4);
+  IterVar rep = make_itervar("rep", 1);
+  PrimExpr forward_thread = FloorDiv(j->var, 1) + 4 * i;
+  PrimExpr index = FloorMod(j->var, 1);
+  return Fragment({i, j}, {index}, forward_thread, rep);
+}
+
 Fragment makeGemmFragment8x8() {
   IterVar i = make_itervar("i", 8);
   IterVar j = make_itervar("j", 8);
@@ -29,6 +38,25 @@ Fragment makeGemmFragment8x8() {
   PrimExpr index = FloorMod(j->var, 2);
   return Fragment({i, j}, {index}, forward_thread, rep);
 }
+
+Fragment makeGemmFragment8x16() {
+  IterVar i = make_itervar("i", 8);
+  IterVar j = make_itervar("j", 16);
+  IterVar rep = make_itervar("rep", 1);
+  PrimExpr forward_thread = FloorDiv(j->var, 4) + 4 * i;
+  PrimExpr index = FloorMod(j->var, 4);
+  return Fragment({i, j}, {index}, forward_thread, rep);
+}
+
+Fragment makeGemmFragment8x8Transposed() {
+  IterVar i = make_itervar("i", 8);
+  IterVar j = make_itervar("j", 8);
+  IterVar rep = make_itervar("rep", 1);
+  PrimExpr forward_thread = FloorDiv(i->var, 2) + 4 * j;
+  PrimExpr index = FloorMod(i->var, 2);
+  return Fragment({i, j}, {index}, forward_thread, rep);
+}
+
 /*
  From https://github.com/RadeonOpenCompute/amd_matrix_instruction_calculator
  ./matrix_calculator.py --architecture cdna1 --instruction v_mfma_f32_16x16x16f16
@@ -61,24 +89,6 @@ Fragment makeGemmFragmentC16x16CDNA() {
   return Fragment({i, j}, {index}, forward_thread, rep);
 }
 
-Fragment makeGemmFragment8x8Transposed() {
-  IterVar i = make_itervar("i", 8);
-  IterVar j = make_itervar("j", 8);
-  IterVar rep = make_itervar("rep", 1);
-  PrimExpr forward_thread = FloorDiv(i->var, 2) + 4 * j;
-  PrimExpr index = FloorMod(i->var, 2);
-  return Fragment({i, j}, {index}, forward_thread, rep);
-}
-
-Fragment makeGemmFragment8x16() {
-  IterVar i = make_itervar("i", 8);
-  IterVar j = make_itervar("j", 16);
-  IterVar rep = make_itervar("rep", 1);
-  PrimExpr forward_thread = FloorDiv(j->var, 4) + 4 * i;
-  PrimExpr index = FloorMod(j->var, 4);
-  return Fragment({i, j}, {index}, forward_thread, rep);
-}
-
 Fragment makeGemmFragmentC_F64(const int block_m, const int block_n,
                                const int warp_m, const int warp_n) {
   ICHECK(block_m % warp_m == 0);
@@ -150,8 +160,8 @@ Fragment makeGemmFragmentA(const int block_m, const int block_n,
   ICHECK(warp_m % 16 == 0);
   ICHECK(block_k % 16 == 0);
   // Only support 8-bit and 16-bit
-  ICHECK(element_size == 8 || element_size == 16)
-      << "element bitwidth=" << element_size;
+  ICHECK(element_size == 8 || element_size == 16 || element_size == 32)
+      << "unsupported element bitwidth=" << element_size;
 
   if (transposed) {
     auto base_layout =
@@ -176,6 +186,13 @@ Fragment makeGemmFragmentA(const int block_m, const int block_n,
     auto block_layout =
         warp_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
     return block_layout;
+  } else if (element_size == 32) {
+    auto base_layout = makeGemmFragment8x4()->Repeat({2, 2}, false, false);
+    auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)
+                           ->Replicate(block_n / warp_n);
+    auto block_layout =
+        warp_layout->Repeat({warp_m / 16, block_k / 8}, false, false);
+    return block_layout;
   } else {
     ICHECK(0);
     return Fragment();