diff --git a/3rdparty/cutlass b/3rdparty/cutlass
index afa177220..ad7b2f5e8 160000
--- a/3rdparty/cutlass
+++ b/3rdparty/cutlass
@@ -1 +1 @@
-Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008
+Subproject commit ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e
diff --git a/benchmark/matmul/benchmark_matmul.py b/benchmark/matmul/benchmark_matmul.py
index c0f2c7583..50a5b805d 100644
--- a/benchmark/matmul/benchmark_matmul.py
+++ b/benchmark/matmul/benchmark_matmul.py
@@ -161,6 +161,7 @@ def matmul(M, N, K, with_roller):
         configs=get_configs(M, N, K, with_roller),
         warmup=3,
         rep=20,
+        ref_prog=ref_program,
     )
     @jit(out_idx=[2],)
     def kernel(
diff --git a/benchmark/matmul/benchmark_matmul_sp.py b/benchmark/matmul/benchmark_matmul_sp.py
new file mode 100644
index 000000000..de1851477
--- /dev/null
+++ b/benchmark/matmul/benchmark_matmul_sp.py
@@ -0,0 +1,267 @@
+# Copyright (c) Tile-AI Corporation.
+# Licensed under the MIT License.
+
+import argparse
+import itertools
+import logging
+import torch
+from triton.testing import do_bench
+
+import tilelang.language as T
+from tilelang.autotuner import autotune
+from tilelang import jit
+from tilelang.layout import make_metadata_layout
+# Configure logger
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+def ref_program(A, B):
+    """
+    A reference matrix multiplication program, used to compare performance.
+
+    Parameters
+    ----------
+    A : torch.Tensor
+        The matrix with shape (M, K).
+    B : torch.Tensor
+        The matrix with shape (N, K).
+
+    Returns
+    -------
+    torch.Tensor
+        The result of A @ B.T, shape (M, N).
+    """
+    return A @ B.T
+
+
+def get_configs(M, N, K):
+    """
+    Generate a list of configuration dictionaries that will be used for tuning.
+
+    Parameters
+    ----------
+    M, N, K : int
+        The GEMM problem sizes (currently not used to prune the search space).
+
+    Returns
+    -------
+    list of dict
+        Each configuration dict includes various block sizes, pipeline stages,
+        thread numbers, and other parameters to explore during autotuning.
+    """
+    block_M = [64, 128, 256]
+    block_N = [64, 128, 256]
+    block_K = [64, 128]
+    num_stages = [0, 1, 2, 3]
+    thread_num = [128, 256]
+    enable_rasterization = [True, False]
+    policy = [T.GemmWarpPolicy.Square]
+    _configs = list(
+        itertools.product(
+            block_M,
+            block_N,
+            block_K,
+            num_stages,
+            thread_num,
+            policy,
+            enable_rasterization,
+        ))
+
+    configs = [
+        {
+            "block_M": c[0],
+            "block_N": c[1],
+            "block_K": c[2],
+            "num_stages": c[3],
+            "thread_num": c[4],
+            "policy": c[5],
+            "enable_rasterization": c[6],  # keep param name for backward-compat
+        } for c in _configs
+    ]
+    return configs
+
+
+def matmul_sp(M, N, K):
+    """
+    Create an autotuned 2:4 structured-sparse matrix multiplication kernel for
+    matrices of shape:
+      - A: (M, K)
+      - B: (N, K)
+      - C: (M, N)
+
+    Parameters
+    ----------
+    M : int
+        The dimension M of the matrix multiplication.
+    N : int
+        The dimension N of the matrix multiplication.
+    K : int
+        The dimension K of the matrix multiplication.
+
+    Returns
+    -------
+    autotune result
+        The result of tuning the kernel; its ``latency`` and ``config``
+        attributes hold the best measured latency (in milliseconds) and the
+        configuration that achieved it, as used under ``__main__`` below.
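+
+    Notes
+    -----
+    A is consumed in 2:4 structured-sparse form: the generated kernel takes
+    A_sparse of shape (M, K // 2) holding the kept values and E of shape
+    (M, K // 8) holding the uint8 compression metadata, together with the
+    dense B of shape (N, K).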
+ """ + + # Decorate the kernel with autotune & jit, specifying: + # - Tuning config list + # - Profiling keys + # - Warmup and repetition counts for better measurement + # - A reference program for correctness verification + # - The "tvm" profiler backend + # - HIP as the compilation target (modify as needed for your hardware) + + @autotune( + configs=get_configs(M, N, K), + warmup=3, + rep=20, + ) + @jit(out_idx=[2],) + def kernel( + block_M=None, + block_N=None, + block_K=None, + num_stages=None, + thread_num=None, + policy=None, + enable_rasterization=None, + ): + """ + The actual kernel to compute C = A @ B^T. + + Parameters + ---------- + block_M : int + Block size in M dimension. + block_N : int + Block size in N dimension. + block_K : int + Block size in K dimension. + num_stages : int + Number of pipelined stages (for asynchronous load). + thread_num : int + Number of threads to use per block. + k_pack : int + K dimension packing factor to improve memory coalescing. + + Returns + ------- + Function + A TVM Tensor Language function (T.prim_func) that computes matmul. + """ + # Use half-precision for input data to reduce memory bandwidth, + # accumulate in float for better numerical accuracy + dtype = "float16" + accum_dtype = "float" + + @T.prim_func + def main( + A_sparse: T.Tensor((M, K // 2), dtype), + E: T.Tensor((M, K // 8), 'uint8'), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + """ + The compiled TVM function for block-level matrix multiplication. + + - We divide the entire (M, N) domain into blocks of shape + (block_M, block_N). + - Each block has its own allocated shared memory for sub-blocks + of A and B. + - The partial results go into C_local, and then we copy them back + to global memory C. + """ + # Bind x-dimension to block index in N, + # y-dimension to block index in M. 
+ with T.Kernel( + T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + + # Allocate shared memory for A sub-block of shape (block_M, block_K) + A_shared = T.alloc_shared((block_M, block_K // 2), dtype) + # Allocate shared memory for B sub-block of shape (block_N, block_K) + B_shared = T.alloc_shared((block_N, block_K), dtype) + # Allocate shared memory for E sub-block of shape (block_M, block_K // E_factor) + E_shared = T.alloc_shared((block_M, block_K // 8), 'uint8') + # Allocate a local fragment for intermediate accumulation + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + # Allocate a shared memory for C sub-block of shape (block_M, block_N) + C_shared = T.alloc_shared((block_M, block_N), dtype) + + # Clear out the accumulation buffer + T.clear(C_local) + T.no_set_max_nreg() + + T.use_swizzle(panel_size=10, enable=enable_rasterization) + T.annotate_layout({ + E: + make_metadata_layout( + E, mma_dtype="float16", arch="sm90", backend="cutlass", + block_k=block_K), + E_shared: + make_metadata_layout( + E_shared, + mma_dtype="float16", + arch="sm90", + backend="cutlass", + block_k=block_K), + }) + # Loop over sub-blocks in K dimension, pipelined by num_stages + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + # Load a sub-block of A from global memory into A_shared + T.copy(A_sparse[by * block_M, k * block_K], A_shared) + # Load a sub-block of E from global memory into E_shared + T.copy(E[by * block_M, k * block_K // 8], E_shared) + # Load a sub-block of B from global memory into B_shared + T.copy(B[bx * block_N, k * block_K], B_shared) + # Perform a partial matrix multiplication: + # C_local += A_shared @ B_shared^T + T.gemm_sp( + A_shared, + E_shared, + B_shared, + C_local, + transpose_B=True, + policy=policy, + ) + # Write back the results from C_local to the global memory C + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + return kernel() + + +if __name__ == "__main__": + # Parse command-line arguments for matrix dimensions + parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") + parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") + args = parser.parse_args() + + M, N, K = args.m, args.n, args.k + + # Compute total floating-point operations to measure throughput + total_flops = 2 * M * N * K + + # matmul(...) 
returns an autotune result; its .latency and .config fields are read below
+    best_result = matmul_sp(M, N, K)
+    best_latency = best_result.latency
+    best_config = best_result.config
+    A = torch.randn(M, K, dtype=torch.float16, device="cuda")
+    B = torch.randn(N, K, dtype=torch.float16, device="cuda")
+    ref_latency = do_bench(lambda: A @ B.T)
+
+    # Print out the benchmark results
+    print(f"Best latency (ms): {best_latency}")
+    print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}")
+    print(f"Best config: {best_config}")
+
+    print(f"Reference TFlops: {total_flops / ref_latency * 1e-9:.3f}")
diff --git a/examples/elementwise/example_elementwise_add.py b/examples/elementwise/example_elementwise_add.py
index effb0f70d..7263067fb 100644
--- a/examples/elementwise/example_elementwise_add.py
+++ b/examples/elementwise/example_elementwise_add.py
@@ -77,8 +77,9 @@ def main():
         kernel = result.kernel
     else:
         # Default config
-        config = {"block_M": 128, "block_N": 128, "threads": 128}
+        config = {"block_M": 32, "block_N": 32, "threads": 128}
         kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")
+
     out = kernel(a, b)
     torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
diff --git a/examples/sparse_tensorcore/test_example_sparse_tensorcore.py b/examples/sparse_tensorcore/test_example_sparse_tensorcore.py
new file mode 100644
index 000000000..fdb126837
--- /dev/null
+++ b/examples/sparse_tensorcore/test_example_sparse_tensorcore.py
@@ -0,0 +1,15 @@
+# Copyright (c) Tile-AI Corporation.
+# Licensed under the MIT License.
+import tilelang.testing
+import tilelang
+import tilelang_example_sparse_tensorcore
+
+
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version(9, 0)
+def test_tilelang_example_sparse_tensorcore():
+    tilelang_example_sparse_tensorcore.main()
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
diff --git a/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py b/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
new file mode 100644
index 000000000..1ec197469
--- /dev/null
+++ b/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py
@@ -0,0 +1,133 @@
+# Copyright (c) Tile-AI Organization.
+# Licensed under the MIT License.
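+#
+# This example demonstrates a 2:4 structured-sparse GEMM on SM90: a dense fp16
+# matrix A is masked to 2:4 sparsity, compressed by compress_sm90 into packed
+# values plus metadata (A_sparse, E), and the T.gemm_sp kernel below consumes
+# (A_sparse, E, B) and is checked against the dense torch.matmul(A, B).
+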
+import torch +import tilelang +from tilelang.utils.sparse import compress_sm90 +from tilelang.layout import make_metadata_layout +import tilelang.testing + + +@tilelang.jit(out_idx=[-1]) +def matmul_sp( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, +): + A_sparse_shape = (M, K // 2) + B_shape = (K, N) + A_shared_shape = (block_M, block_K // 2) + B_shared_shape = (block_K, block_N) + + import tilelang.language as T + + @T.prim_func + def main( + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // 8), 'uint8'), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + E_shared = T.alloc_shared((block_M, block_K // 8), 'uint8') + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.annotate_layout({ + E: + make_metadata_layout( + E, mma_dtype="float16", arch="sm90", backend="cutlass", block_k=block_K), + E_shared: + make_metadata_layout( + E_shared, + mma_dtype="float16", + arch="sm90", + backend="cutlass", + block_k=block_K), + }) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(E[by * block_M, k * block_K // 8], E_shared) + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm_sp(A_shared, E_shared, B_shared, C_local, False, False) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device='cpu'): + if shape[-1] % 4 != 0: + raise ValueError("Last dimension must be divisible by 4 for 2:4 sparsity.") + + full_tensor = torch.randn(shape, dtype=dtype, device=device) + mask = torch.zeros_like(full_tensor, dtype=torch.bool) + + group_count = shape[-1] // 4 + group_shape = shape[:-1] + (group_count, 4) + + reshaped = full_tensor.view(*group_shape) + + for idx in range(reshaped.numel() // 4): + flat_idx = torch.randint(0, 4, (2,), dtype=torch.int64) + while flat_idx[0] == flat_idx[1]: + flat_idx[1] = torch.randint(0, 4, (1,), dtype=torch.int64) + i = idx // group_count + j = idx % group_count + mask.view(*group_shape)[i, j, flat_idx[0]] = True + mask.view(*group_shape)[i, j, flat_idx[1]] = True + + sparse_tensor = full_tensor * mask + return sparse_tensor + + +def run_gemm_sp( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + block_M, + block_N, + block_K, + num_stages, + num_threads, +): + kernel = matmul_sp( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + num_threads, + ) + + A = generate_2_to_4_sparse_tensor((M, K), dtype=torch.float16, device='cuda') + A_sparse, E = compress_sm90(A, block_k=block_K, transposed=False) + B = torch.randn((K, N), device='cuda', dtype=torch.float16) + + C_sp = kernel(A_sparse, E, B).half() + C = torch.matmul(A, B) + torch.testing.assert_close(C_sp, C, atol=1e-3, rtol=1e-3) + print("pass") + + +def main(): + run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 128, 128, 128, 2, 128) + + +if __name__ == "__main__": + main() diff --git a/src/op/bulk_copy.cc b/src/op/bulk_copy.cc index 3175feaa3..ae2b09b7f 100644 --- a/src/op/bulk_copy.cc +++ b/src/op/bulk_copy.cc @@ -110,6 +110,12 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const { Array global_range = is_load ? 
src_range : dst_range; Array shared_range = is_load ? dst_range : src_range; + if (T.layout_map.count(global_tensor)) { + LOG(WARNING) << "TMA bulk copy cannot support a non-swizzled global " + "layout, fallback to normal copy."; + return Stmt(); + } + Array indices; for (auto r : shared_range) indices.push_back(r->min); @@ -135,10 +141,6 @@ Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer) const { shared_layout = T.layout_map[shared_tensor]; shared_tensor = T.buffer_remap[shared_tensor]; } - if (T.layout_map.count(global_tensor)) { - ICHECK(T.layout_map.count(global_tensor) == 0) - << "Cannot support global layout."; - } TMADesc desc; diff --git a/src/op/gemm_sp.cc b/src/op/gemm_sp.cc new file mode 100644 index 000000000..5f0ac0329 --- /dev/null +++ b/src/op/gemm_sp.cc @@ -0,0 +1,322 @@ +// Copyright (c) Tile-AI Corporation. +// Licensed under the MIT License. + +/*! + * \file tl/op/gemm_sp.cc + * + * Define gemm_sp operator. + */ + +#include "gemm_sp.h" + +#include +#include +#include +#include + +#include "../target/utils.h" +#include "builtin.h" +#include "gemm.h" + +namespace tvm { +namespace tl { +static std::vector toPrimeFactors(int x) { + int i = 2; + std::vector result; + while (x > 1) { + if (x % i == 0) { + x /= i; + result.push_back(i); + } else { + i++; + } + } + return result; +} + +GemmSP::GemmSP(Array args, BufferMap vmap) { + A = vmap[GetVarFromAccessPtr(args[0])]; + E = vmap[GetVarFromAccessPtr(args[1])]; + B = vmap[GetVarFromAccessPtr(args[2])]; + C = vmap[GetVarFromAccessPtr(args[3])]; + trans_A = args[4].as().value(); + trans_B = args[5].as().value(); + M = args[6].as().value()->value; + N = args[7].as().value()->value; + K = args[8].as().value()->value; + policy = static_cast(args[9].as().value()->value); + clear_accum = args[10].as().value(); + if (args.size() > 11) { + kPack = args[11].as().value()->value; + if (kPack != 1 && kPack != 2) { + ICHECK(false) << "kPack must be 1 or 2"; + } + } + if (args.size() > 12) { + wg_wait = args[12].as().value()->value; + } +} + +std::pair +GemmSP::ComputeWarpPartition(int num_warps, Target target, + bool maybe_hopper_wgmma) const { + int m_warp = 1, n_warp = 1; + constexpr int kMPerWarp = 16; // Rows processed by a single warp + constexpr int kNPerWarp = 8; // Columns processed by a single warp + bool allow_wgmma = TargetIsHopper(target) && maybe_hopper_wgmma && + (this->M >= 64) && (num_warps % 4 == 0); + if (allow_wgmma) { + ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads."; + + constexpr int kGroup = 4; // Number of warps in a warp-group + + m_warp = kGroup; // Initially, only one warp-group on M dimension + n_warp = num_warps / m_warp; // Rest all on N dimension + + if (this->policy == GemmWarpPolicy::kFullRow) { + // Try to put as many warp-groups as possible on M dimension + // (decreasing multiples of 4, ensuring divisibility by M) + for (int cand = num_warps; cand >= kGroup; cand -= kGroup) { + if (this->M % (cand * kMPerWarp) == 0) { + m_warp = cand; + n_warp = num_warps / m_warp; + break; + } + } + } else if (this->policy == GemmWarpPolicy::kFullCol) { + // Try to use warps on N dimension; if N is not divisible, split excess + // groups to M + int cand_n = n_warp; // Initially assume all on N + if (this->N % (cand_n * kNPerWarp) != 0) { // N direction division fails + int max_n = this->N / kNPerWarp; + // Find a feasible n_warp from max possible downwards, ensuring + // num_warps/n_warp is multiple of 4 + for (int n = std::min(cand_n, max_n); n >= 1; --n) { + if (num_warps % 
n == 0 && (num_warps / n) % kGroup == 0) { + n_warp = n; + m_warp = num_warps / n_warp; + break; + } + } + } + } else if (this->policy == GemmWarpPolicy::kSquare) { + // Exhaustive search, but m must be multiple of 4 + int max_m = this->M / kMPerWarp; + int max_n = this->N / kNPerWarp; + + float ideal = this->N > 0 ? static_cast(this->M) / this->N : 1.f; + + float best_score = std::numeric_limits::max(); + int best_m = kGroup, best_n = n_warp; + + for (int m = kGroup; m <= num_warps && m <= max_m; m += kGroup) { + if (num_warps % m) + continue; + int n = num_warps / m; + if (n > max_n) + continue; + + float m_per_warp = static_cast(this->M) / (m * kMPerWarp); + float n_per_warp = static_cast(this->N) / (n * kNPerWarp); + float score = std::abs(m_per_warp / n_per_warp - ideal); + + if (score < best_score) { + best_score = score; + best_m = m; + best_n = n; + } + } + m_warp = best_m; + n_warp = best_n; + } else { + ICHECK(0) << "Unknown GemmWarpPolicy"; + } + + ICHECK(m_warp * n_warp == num_warps) + << "m_warp * n_warp must equal num_warps"; + return {m_warp, n_warp}; + } + + if (this->policy == GemmWarpPolicy::kFullRow) { + // Try to partition M first + m_warp = num_warps; + n_warp = 1; + + // If M cannot be evenly divided by m_warp*16, try to split remaining warps + // to N + if (this->M % (m_warp * kMPerWarp) != 0) { + // Calculate how many warps we can use for M + int max_m_warps = this->M / kMPerWarp; + m_warp = max_m_warps; + // Use remaining warps for N + n_warp = num_warps / m_warp; + if (n_warp == 0) + n_warp = 1; + } + } else if (this->policy == GemmWarpPolicy::kFullCol) { + // Try to partition N first + m_warp = 1; + n_warp = num_warps; + + // If N cannot be evenly divided by n_warp*8, try to split remaining warps + // to M + if (this->N % (n_warp * kNPerWarp) != 0) { + // Calculate how many warps we can use for N + int max_n_warps = this->N / kNPerWarp; + n_warp = max_n_warps; + // Use remaining warps for M + m_warp = num_warps / n_warp; + if (m_warp == 0) + m_warp = 1; + } + } else if (this->policy == GemmWarpPolicy::kSquare) { + // First calculate the maximum possible warps for each dimension + int max_m_warps = + this->M / kMPerWarp; // Each warp needs at least 16 elements in M + int max_n_warps = + this->N / kNPerWarp; // Each warp needs at least 8 elements in N + + // Calculate the ideal ratio of M/N warps based on the matrix dimensions + float ideal_ratio = 1.0f; + if (this->N > 0) { + ideal_ratio = static_cast(this->M) / this->N; + } + + // Start with a balanced initial guess + m_warp = 1; + n_warp = 1; + + // Try to find the best balanced partition + int best_m = 1; + int best_n = 1; + float best_balance = std::numeric_limits::max(); + + // Try all possible combinations that satisfy the constraints + for (int m = 1; m <= max_m_warps && m <= num_warps; m++) { + int n = num_warps / m; + + // Calculate how balanced this partition is + float m_per_warp = static_cast(this->M) / (m * kMPerWarp); + float n_per_warp = static_cast(this->N) / (n * kNPerWarp); + float balance = std::abs(m_per_warp / n_per_warp - ideal_ratio); + + if (balance < best_balance) { + best_balance = balance; + best_m = m; + best_n = n; + } + } + + m_warp = best_m; + n_warp = best_n; + } else { + ICHECK(0) << "Unknown GemmWarpPolicy"; + } + return {m_warp, n_warp}; +} + +Stmt GemmSP::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { + int warp_size = 32; + + auto block_size = *as_const_int(T.thread_bounds->extent); + bool maybe_wgmma = TargetIsHopper(T.target) && (this->M >= 64) && + (block_size 
/ warp_size % 4 == 0); + + auto [warp_m, warp_n] = + ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma); + + std::stringstream ss; + std::string op_name = "tl::gemm_sp_ss"; + ICHECK((A.scope() == "shared" || A.scope() == "shared.dyn") && + (B.scope() == "shared" || B.scope() == "shared.dyn")) + << "Only support shared.dyn scope for A and B, but received " << A.scope() + << " and " << B.scope(); + ICHECK((E.scope() == "shared" || E.scope() == "shared.dyn")) + << "Only support shared.dyn scope for E as copy from smem to rmem are " + "delegated to cute implemntation, found " + << E.scope(); + ss << op_name << "<" << M << ", " << N << ", " << K << ", "; + ss << warp_m << ", " << warp_n << ", "; + ss << trans_A << ", " << trans_B; + ss << ", " << clear_accum; + if (TargetIsHopper(T.target)) { + ss << ", " << (maybe_wgmma ? "true" : "false"); + } + if (wg_wait != 0) { + ss << ", " << wg_wait; + } + ss << ">"; + auto A_buffer = T.buffer_remap.count(A) ? T.buffer_remap[A] : A; + auto B_buffer = T.buffer_remap.count(B) ? T.buffer_remap[B] : B; + auto C_buffer = T.buffer_remap[C]; + auto E_buffer = T.buffer_remap.count(E) ? T.buffer_remap[E] : E; + + Array new_args; + new_args.push_back(StringImm(ss.str())); + new_args.push_back(A_buffer.access_ptr(1)); + new_args.push_back(B_buffer.access_ptr(1)); + new_args.push_back(C_buffer.access_ptr(3)); + new_args.push_back(E_buffer.access_ptr(1)); + auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args); + return Evaluate(new_call); +} + +LayoutMap GemmSP::InferLayout(const LayoutInferArgs &T, InferLevel level) { + if (completed_) + return {}; + LayoutMap results; + ICHECK(C.scope() == "local.fragment"); + auto thread_range = T.thread_bounds; + auto block_size = *as_const_int(thread_range->extent); + if (TargetIsHopper(T.target)) { + const int warp_size = 32; + constexpr int wgmma_m = 16 * 4; + bool maybe_wgmma = + (this->M >= wgmma_m) && (block_size / warp_size % 4 == 0); + auto [warp_m, warp_n] = + ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma); + auto fragment = + maybe_wgmma + ? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n, + C->dtype.bits()) + : makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits()); + results.Set(C, fragment->BindThreadRange(thread_range)); + if (A.scope() == "shared" || A.scope() == "shared.dyn") { + int dim_A = A->shape.size(); + const int64_t mat_stride = *as_const_int(A->shape[dim_A - 2]); + const int64_t mat_continuous = *as_const_int(A->shape[dim_A - 1]); + const int64_t continuity = + trans_A ? 4 * mat_continuous / warp_m : mat_continuous; + results.Set(A, makeGemmABLayoutHopper(mat_stride, mat_continuous, + mat_continuous, A->dtype.bits(), + trans_A ? 1 : 2)); + } else { + ICHECK(false) << "Not implemented"; + } + + if (B.scope() == "shared" || B.scope() == "shared.dyn") { + int dim_B = B->shape.size(); + const int64_t mat_stride = *as_const_int(B->shape[dim_B - 2]); + const int64_t mat_continuous = *as_const_int(B->shape[dim_B - 1]); + const int64_t continuity = + trans_B ? mat_continuous : mat_continuous / warp_n; + results.Set(B, + makeGemmABLayoutHopper(mat_stride, mat_continuous, continuity, + B->dtype.bits(), trans_B ? 
2 : 1)); + } else { + ICHECK(false) << "WGMMA only support B in shared."; + } + } else { + ICHECK(0) << "Not supported " << T.target->str() + << " Currently only Hopper are supported"; + } + completed_ = true; + return results; +} +TIR_REGISTER_TL_OP(GemmSP, gemm_sp) + .set_num_inputs(5) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +} // namespace tl +} // namespace tvm diff --git a/src/op/gemm_sp.h b/src/op/gemm_sp.h new file mode 100644 index 000000000..6ffaa75ed --- /dev/null +++ b/src/op/gemm_sp.h @@ -0,0 +1,52 @@ +// Copyright (c) Tile-AI Corporation. +// Licensed under the MIT License. + +/*! + * \file tl/op/gemm_sp.h + * \brief Define gemm_sp operator. + * + */ + +#ifndef TVM_TL_OP_GEMM_SP_H_ +#define TVM_TL_OP_GEMM_SP_H_ + +#include "op.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +class GemmSP : public Operator { +public: + GemmSP(Array args, BufferMap vmap); + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final; + LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final; + static const Op &Get(); + enum class GemmWarpPolicy { + kSquare = 0, + kFullRow = 1, + kFullCol = 2, + } policy; + +private: + std::pair + ComputeWarpPartition(int num_warps, Target target, + bool maybe_hopper_wgmma = true) const; + + Array call_args; + tir::Buffer A, B, C, E; + bool trans_A, trans_B; + int M, N, K; + bool clear_accum = false; + // k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack + // only will be enabled under cdna mfma instructions + int kPack = 1; + int wg_wait = 0; + bool completed_ = false; +}; + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_GEMM_SP_H_ diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index e52b95ce1..914dda8a0 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -127,6 +127,9 @@ std::string CodeGenTileLangCUDA::Finish() { } decl_stream << "#include \n"; + if (enable_sparse_gemm_) { + decl_stream << "#include \n"; + } decl_stream << "#include \n"; decl_stream << "#include \n"; decl_stream << "#include \n"; @@ -1390,6 +1393,14 @@ void CodeGenTileLangCUDA::VisitStmt_(const EvaluateNode *op) { stream << " " << vid_global_barrier_expect_ << " = 0;\n"; PrintIndent(); stream << "}\n"; + } else if (call && call->op.same_as(builtin::call_extern())) { + ICHECK(call->args.size() >= 1) + << "call_extern must have at least 1 argument"; + std::string func_name = call->args[0].as()->value; + if (func_name.find("tl::gemm_sp") == 0) { + enable_sparse_gemm_ = true; + } + CodeGenC::VisitStmt_(op); } else { CodeGenC::VisitStmt_(op); } diff --git a/src/target/codegen_cuda.h b/src/target/codegen_cuda.h index ded20900c..7cf594b53 100644 --- a/src/target/codegen_cuda.h +++ b/src/target/codegen_cuda.h @@ -89,6 +89,8 @@ class CodeGenTileLangCUDA final : public CodeGenC { bool enable_bf16_{false}; // whether enable fp8 bool enable_fp8_{false}; + // whether enable sparse gemm + bool enable_sparse_gemm_{false}; // whether enable int8 bool enable_int8_{false}; // whether enable warp shuffle intrinsics diff --git a/src/tl_templates/cuda/compress_sm90.cu b/src/tl_templates/cuda/compress_sm90.cu new file mode 100644 index 000000000..6635220cd --- /dev/null +++ b/src/tl_templates/cuda/compress_sm90.cu @@ -0,0 +1,167 @@ +#include + +#include + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include 
"cutlass/transform/device/transform_universal_adapter.hpp" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" + +using namespace cute; + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) \ + << " at: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } +template +std::tuple compress_impl(torch::Tensor A) { + using ElementA = T; + using ElementE = uint8_t; + using LayoutTagA = conditional_t; + using ProblemShape = cute::Shape; + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideE = StrideA; + + // NOTE: this is derived from sparse sm90 mma atoms + // Ref: https://github.com/NVIDIA/cutlass/blob/dc4817921edda44a549197ff3a9dcf5df0636e7b/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp + using SparseE = conditional_t<(sizeof_bits_v == 32), cute::sparse_elem<4, ElementE>, cute::sparse_elem<8, ElementE>>; + static constexpr GMMA::Major GmmaMajorA = transposed ? cute::SM90::GMMA::Major::MN : cute::SM90::GMMA::Major::K; + using SparseConfig = cutlass::Sm90GemmSparseConfig< + cute::sparse_elem<2, ElementA>, GmmaMajorA, + SparseE, cute::C>; + + using CompressorUtility = + cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShape, ElementA, LayoutTagA, SparseConfig>; + + using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, ElementA, LayoutTagA, SparseConfig, cutlass::arch::Sm90>; + + using Compressor = cutlass::transform::device::TransformUniversalAdapter; + + TORCH_CHECK(A.is_contiguous(), "A need to be contiguous"); + TORCH_CHECK(A.dim() == 2, "Might support batch dim in the future "); + + int M = -1; + int K = -1; + int N = -1; // not used, but required for config + int L = 1; + if constexpr(transposed) { + M = A.size(1); + K = A.size(0); + } else { + M = A.size(0); + K = A.size(1); + } + + ProblemShape problem_shape = make_tuple(M, N, K, L); + StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + + CompressorUtility compressor_utility(problem_shape, stride_A); + int ME = compressor_utility.get_metadata_m_physical(); + int KE = compressor_utility.get_metadata_k_physical(); + int KC = compressor_utility.get_tensorA_k_physical(); + + StrideE stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(ME, KE, L)); + auto dtype = A.dtype().toScalarType(); + torch::Tensor A_compressed = torch::zeros(KC * M, + torch::TensorOptions().dtype(dtype).device(A.device())); + torch::Tensor E = torch::zeros({ME, KE}, + torch::TensorOptions().dtype(torch::kUInt8).device(A.device())); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = A.device().index(); + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + typename Compressor::Arguments arguments{problem_shape, + { + A.data_ptr(), + stride_A, + A_compressed.data_ptr(), + E.data_ptr(), + }, + {hw_info}}; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + 
CUTLASS_CHECK(compressor_op.can_implement(arguments)); + CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get())); + CUTLASS_CHECK(compressor_op.run()); + CUDA_CHECK(cudaDeviceSynchronize()); + + if constexpr (transposed) { + return std::make_tuple(A_compressed.view({KC, M}), E); + } else { + return std::make_tuple(A_compressed.view({M, KC}), E); + } +} + +// block <= 128 +// Ref https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146 +#define DISPATCH_BLOCK_K(TYPE, BLOCK_K, FACTOR, TENSOR, TRANSPOSED) \ + [&]() -> std::tuple { \ + switch (BLOCK_K) { \ + case int(32 * FACTOR): return compress_impl(TENSOR); \ + case int(64 * FACTOR): return compress_impl(TENSOR); \ + case int(128 * FACTOR): return compress_impl(TENSOR); \ + default: \ + TORCH_CHECK(false, "Unsupported block_k: ", BLOCK_K); \ + } \ + }() + +#define DISPATCH_CONTIGUOUS(TRANSPOSED) \ + [&]() -> std::tuple { \ + switch (dtype) { \ + case torch::kFloat32: \ + return DISPATCH_BLOCK_K(float, block_k, 0.5, A, TRANSPOSED); \ + case torch::kFloat16: \ + case torch::kBFloat16: \ + return DISPATCH_BLOCK_K(cute::half_t, block_k, 1, A, TRANSPOSED); \ + case torch::kFloat8_e4m3fn: \ + return DISPATCH_BLOCK_K(cute::float_e4m3_t, block_k, 2, A, TRANSPOSED); \ + case torch::kFloat8_e5m2: \ + return DISPATCH_BLOCK_K(cute::float_e5m2_t, block_k, 2, A, TRANSPOSED); \ + case torch::kChar: \ + return DISPATCH_BLOCK_K(int8_t, block_k, 2, A, TRANSPOSED); \ + case torch::kByte: \ + return DISPATCH_BLOCK_K(uint8_t, block_k, 2, A, TRANSPOSED); \ + default: \ + TORCH_CHECK(false, "Unsupported dtype"); \ + } \ + }() + +std::tuple compress_sm90(torch::Tensor A, int64_t block_k, bool transposed) { + auto dtype = A.dtype().toScalarType(); + return transposed ? 
DISPATCH_CONTIGUOUS(true) : DISPATCH_CONTIGUOUS(false); +} + +#undef DISPATCH_BLOCK_K +#undef DISPATCH_CONTIGUOUS + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("compress_sm90", torch::wrap_pybind_function(compress_sm90), + "compress_sm90"); +} diff --git a/src/tl_templates/cuda/gemm_sm89.h b/src/tl_templates/cuda/gemm_sm89.h index f7d8c21bc..8176b460f 100644 --- a/src/tl_templates/cuda/gemm_sm89.h +++ b/src/tl_templates/cuda/gemm_sm89.h @@ -4,6 +4,8 @@ #include #include +#include + #include #include #include @@ -21,104 +23,16 @@ using _X = Underscore; #if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 890)) -struct SM89_16x8x32_F32F8F8F32_E4M3_TN { - using DRegisters = float[4]; - using ARegisters = uint32_t[4]; - using BRegisters = uint32_t[2]; - using CRegisters = float[4]; - - CUTE_HOST_DEVICE static void fma(float &d0, float &d1, float &d2, float &d3, - uint32_t const &a0, uint32_t const &a1, - uint32_t const &a2, uint32_t const &a3, - uint32_t const &b0, uint32_t const &b1, - float const &c0, float const &c1, - float const &c2, float const &c3) { - asm volatile("mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - "{%8, %9}," - "{%10, %11, %12, %13};\n" - : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), - "f"(c0), "f"(c1), "f"(c2), "f"(c3)); - } -}; - -struct SM89_16x8x32_F32F8F8F32_E5M2_TN { - using DRegisters = float[4]; - using ARegisters = uint32_t[4]; - using BRegisters = uint32_t[2]; - using CRegisters = float[4]; - - CUTE_HOST_DEVICE static void fma(float &d0, float &d1, float &d2, float &d3, - uint32_t const &a0, uint32_t const &a1, - uint32_t const &a2, uint32_t const &a3, - uint32_t const &b0, uint32_t const &b1, - float const &c0, float const &c1, - float const &c2, float const &c3) { - asm volatile("mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 " - "{%0, %1, %2, %3}," - "{%4, %5, %6, %7}," - "{%8, %9}," - "{%10, %11, %12, %13};\n" - : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) - : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), - "f"(c0), "f"(c1), "f"(c2), "f"(c3)); - } -}; - -// (T32,V1) -> (M8,N8) -using SM80_8x4 = Layout, _1>, Stride, _0>>; -// (T32,V2) -> (M8,N8) -using SM80_8x8_Row = - Layout, _2>, Stride, _8>>; -// (T32,V4) -> (M8,N16) -using SM80_8x16_Row = - Layout, _4>, Stride, _8>>; -// (T32,V4) -> (M16,N8) -using SM80_16x8_Row = Layout, Shape<_2, _2>>, - Stride, Stride<_16, _8>>>; - -template <> struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = fp8_e4_t; - using ValTypeB = fp8_e4_t; - using ValTypeC = float; - - using Shape_MNK = Shape<_16, _8, _32>; - using ThrID = Layout<_32>; - using ALayout = Layout, Shape<_4, _2, _2>>, - Stride, Stride<_16, _8, _256>>>; - using BLayout = Layout, Shape<_4, _2>>, - Stride, Stride<_8, _128>>>; - using CLayout = SM80_16x8_Row; -}; - -template <> struct MMA_Traits { - using ValTypeD = float; - using ValTypeA = fp8_e5_t; - using ValTypeB = fp8_e5_t; - using ValTypeC = float; - - using Shape_MNK = Shape<_16, _8, _32>; - using ThrID = Layout<_32>; - using ALayout = Layout, Shape<_4, _2, _2>>, - Stride, Stride<_16, _8, _256>>>; - using BLayout = Layout, Shape<_4, _2>>, - Stride, Stride<_8, _128>>>; - using CLayout = SM80_16x8_Row; -}; - template struct DispatchInstruction { - using MMA = MMA_Atom; + using MMA = MMA_Atom; using MMA_Group = Tile<_X, Int, _X>; }; template struct DispatchInstruction { - using MMA = MMA_Atom; + using MMA = MMA_Atom; using MMA_Group = Tile<_X, Int, _X>; }; diff --git 
a/src/tl_templates/cuda/gemm_sp.h b/src/tl_templates/cuda/gemm_sp.h new file mode 100644 index 000000000..ce5352753 --- /dev/null +++ b/src/tl_templates/cuda/gemm_sp.h @@ -0,0 +1,8 @@ +// Copyright (c) Tile-AI Corporation. +// Licensed under the MIT License. +#pragma once +#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900)) +#include "gemm_sp_sm90.h" +#else + +#endif diff --git a/src/tl_templates/cuda/gemm_sp_sm90.h b/src/tl_templates/cuda/gemm_sp_sm90.h new file mode 100644 index 000000000..d0df667e8 --- /dev/null +++ b/src/tl_templates/cuda/gemm_sp_sm90.h @@ -0,0 +1,234 @@ +// Copyright (c) Tile-AI Corporation. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include + +namespace cute { +namespace tl_wgmma_sp { +template +class GemmTensorOp { +public: + static_assert(num_warp_m % 4 == 0, "num_warp_m must be a multiple of 4"); + + using A_type = conditional_t::value, + tfloat32_t, A_type_raw>; + using B_type = conditional_t::value, + tfloat32_t, B_type_raw>; + using C_type = C_type_raw; + + static constexpr bool need_tfloat32_cast = + std::is_same::value && + std::is_same::value; + + static constexpr GMMA::Major GmmaMajorA = + trans_A ? GMMA::Major::MN : GMMA::Major::K; + static constexpr GMMA::Major GmmaMajorB = + trans_B ? GMMA::Major::K : GMMA::Major::MN; + + using TiledMma = decltype(make_tiled_mma( + GMMA::ss_op_selector_sparse< + A_type, B_type, C_type, + Shape, Int, Int>, + GmmaMajorA, GmmaMajorB>(), + Layout, Int, _1>>{})); + + using ElementAMma = typename TiledMma::ValTypeA; + using ElementAMmaSparsity = Int; + using ElementBMma = typename TiledMma::ValTypeB; + using ElementEMma = typename TiledMma::ValTypeE; + using ElementEMmaSparsity = Int; + using E_type_raw = typename ElementEMma::raw_type; + + using SparseConfig = + cutlass::Sm90GemmSparseConfig{}, _128{}))>; + + using LayoutA = decltype(SparseConfig::deduce_layoutA()); + using LayoutE = decltype(SparseConfig::deduce_layoutE()); + + using SmemLayoutAtomA = + decltype(cutlass::gemm::collective::detail::ss_smem_selector_sparse< + GmmaMajorA, A_type, Int, Int, ElementAMmaSparsity>()); + using SmemLayoutAtomB = + decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GmmaMajorB, B_type, Int, Int>()); + + using SmemLayoutAtomE_ = typename SparseConfig::TensorEAtom; + using SmemLayoutAtomE = + ComposedLayout, + smem_sparse_ptr_flag_bits>, + SmemLayoutAtomE_>; + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, Shape, Int>{}, + conditional_t, Step<_1, _2>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, Shape, Int>{}, + conditional_t, Step<_2, _1>>{})); + using SmemLayoutE = decltype(tile_to_shape( + SmemLayoutAtomE{}, Shape, Int>{}, + conditional_t, Step<_1, _2>>{})); + + using SmemCopyAtomE = AutoVectorizingCopy; + + template + static CUTE_DEVICE void body(A_type_raw *pA, B_type_raw *pB, C_type_raw *pC, + E_type_raw *pE) { + const int tid = threadIdx.x; + Tensor sA = + make_tensor(make_smem_ptr(recast_ptr(pA)), SmemLayoutA{}); + Tensor sB = + make_tensor(make_smem_ptr(recast_ptr(pB)), SmemLayoutB{}); + Tensor sE = as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(recast_ptr(pE)), SmemLayoutE{})); + + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tid); + + Tensor tCsA = thr_mma.partition_A(sA); + Tensor tCsB = thr_mma.partition_B(sB); + Tensor tCsE = partition_E(thr_mma, sE(_, _)); + + Tensor tCrA = thr_mma.make_fragment_A(tCsA); + Tensor tCrB = thr_mma.make_fragment_B(tCsB); + Tensor tCrE = 
make_fragment_like(tCsE); + + auto copy_atom_E = Copy_Atom{}; + auto smem_tiled_copy_E = make_tiled_copy_E(copy_atom_E, tiled_mma); + auto smem_thr_copy_E = smem_tiled_copy_E.get_thread_slice(tid); + Tensor tEsE = smem_thr_copy_E.partition_S(sE); + Tensor tErE = smem_thr_copy_E.retile_D(tCrE); + + Tensor acc = + make_tensor(make_rmem_ptr(pC), + partition_shape_C(tiled_mma, Shape, Int>{})); + + warpgroup_fence_operand(acc); + warpgroup_arrive(); + if constexpr (clear_accum) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + copy(smem_tiled_copy_E, tEsE, tErE); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + gemm(tiled_mma, make_zip_tensor(tCrA(_, _, k_block), tCrE(_, _, k_block)), + tCrB(_, _, k_block), acc); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + if constexpr (wg_wait >= 0) { + warpgroup_wait(); + } + warpgroup_fence_operand(acc); + } + + template + CUTE_HOST_DEVICE static constexpr auto + thrfrg_E(TiledMMA const &mma, + ETensor &&etensor) { + using TiledMma = TiledMMA; + + CUTE_STATIC_ASSERT_V(rank(etensor) >= Int<2>{}); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(get<0>(PermutationMNK{}), get<2>(PermutationMNK{})); + auto t_tensor = logical_divide(etensor, t_tile); // (PermM,PermK) + + // Tile the tensor for the Atom + auto e_tile = + make_tile(make_layout(size<0>(typename TiledMma::AtomShape_MNK{})), + make_layout(size<2>(typename TiledMma::AtomShape_MNK{}))); + auto e_tensor = + zipped_divide(t_tensor, e_tile); // ((AtomM,AtomK),(RestM,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + using AtomLayoutE_TV = typename TiledMma::Atom::Traits::ELayout; + auto tv_tensor = + e_tensor.compose(AtomLayoutE_TV{}, _); // ((ThrV,FrgV),(RestM,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = + make_tile(_, make_tile(make_layout(size<1>(mma.thr_layout_vmnk_)), + make_layout(size<3>(mma.thr_layout_vmnk_)))); + auto thr_tensor = zipped_divide( + tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) + + return thr_tensor; + } + + template + CUTE_HOST_DEVICE static constexpr auto + get_layoutE_TV(TiledMMA const &mma) { + // (M,K) -> (M,K) + auto ref_E = make_layout(make_shape(tile_size<0>(mma), tile_size<2>(mma))); + // (ethrid,val) -> (M,K) + auto layoutE_TV = thrfrg_E(mma, ref_E); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto etile = make_tile( + _, make_tile(make_layout(make_shape(size<1>(mma.thr_layout_vmnk_), + size<2>(mma.thr_layout_vmnk_)), + make_stride(Int<1>{}, Int<0>{})), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(mma.thr_layout_vmnk_); + + // (thr_idx,val) -> (M,K) + return layoutE_TV.compose(etile, _).compose(thridx_2_thrid, _); + } + + template + CUTE_HOST_DEVICE static constexpr auto + partition_E(ThrMMA const &thr_mma, ETensor &&etensor) { + auto thr_tensor = make_tensor(static_cast(etensor).data(), + thrfrg_E(thr_mma, etensor.layout())); + + auto thr_vmk = make_coord( + get<0>(thr_mma.thr_vmnk_), + make_coord(get<1>(thr_mma.thr_vmnk_), get<3>(thr_mma.thr_vmnk_))); + return thr_tensor(thr_vmk, + make_coord(_, repeat(thr_tensor)>(_))); + } + + template + CUTE_HOST_DEVICE static constexpr auto + make_tiled_copy_E(Copy_Atom const ©_atom, + TiledMMA const &mma) { + return make_tiled_copy_impl( + copy_atom, get_layoutE_TV(mma), + make_shape(tile_size<0>(mma), tile_size<2>(mma))); + } +}; + +} // namespace tl_wgmma_sp +} // 
namespace cute + +namespace tl { +template , + typename E_type = typename MMA::ElementEMma::raw_type> +TL_DEVICE void gemm_sp_ss(A_type *pA, B_type *pB, C_type *accum, E_type *pE) { + static_assert(use_wgmma, "only wgmma is supported for now"); + if constexpr (use_wgmma) { + MMA::body(pA, pB, accum, pE); + } else { + CUTE_GCC_UNREACHABLE; + } +} +} // namespace tl \ No newline at end of file diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py index 11ceb8891..69861aa35 100644 --- a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py @@ -115,6 +115,8 @@ def test_gemm(): 32, 2) # pad_f16f16f16_nn # GEMM tests for mixed precision (float16 + float32) + run_gemm(512, 1024, 768, False, False, "float16", "float16", "float32", 128, 128, + 16) # f16f16f32_nn run_gemm(512, 1024, 768, False, False, "float16", "float16", "float32", 128, 128, 32) # f16f16f32_nn run_gemm(512 + 19, 1024 + 17, 768 + 15, False, False, "float16", "float16", "float32", 128, 64, diff --git a/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py new file mode 100644 index 000000000..9c1455964 --- /dev/null +++ b/testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py @@ -0,0 +1,239 @@ +# Copyright (c) Tile-AI Organization. +# Licensed under the MIT License. +import torch +import tilelang +import tilelang.testing + +from tilelang.utils.sparse import compress_sm90 +from tilelang.layout import make_metadata_layout + +torch.set_printoptions(threshold=float('inf'), edgeitems=float('inf'), linewidth=10000) +torch.manual_seed(42) + +STR_TO_TYPE = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "e4m3_float8": torch.float8_e4m3fn, + "int8": torch.int8, +} + +SPARSITY_MAP = { + torch.float16: (2, 4), + torch.bfloat16: (2, 4), + torch.float8_e4m3fn: (2, 4), + torch.int8: (2, 4), +} + + +def matmul_sp( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + threads, + trans_A, + trans_B, +): + E_factor = 4 if in_dtype == "float32" else 8 + A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) + B_shape = (K, N) if not trans_B else (N, K) + A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) + B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K) + + import tilelang.language as T + + @T.prim_func + def main( + A_sparse: T.Tensor(A_sparse_shape, in_dtype), + E: T.Tensor((M, K // E_factor), 'uint8'), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + E_shared = T.alloc_shared((block_M, block_K // E_factor), 'uint8') + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.annotate_layout({ + E: + make_metadata_layout( + E, mma_dtype="float16", arch="sm90", backend="cutlass", block_k=block_K), + E_shared: + make_metadata_layout( + E_shared, + mma_dtype="float16", + arch="sm90", + backend="cutlass", + block_k=block_K), + }) + T.no_set_max_nreg() + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(E[by * block_M, k * block_K // E_factor], E_shared) + if trans_A: + T.copy(A_sparse[k * block_K // 2, by * 
block_M], A_shared) + else: + T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared) + if trans_B: + T.copy(B[bx * block_N, k * block_K], B_shared) + else: + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm_sp(A_shared, E_shared, B_shared, C_local, trans_A, trans_B) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def generate_sparse_tensor_float32(M: int, K: int, dtype: torch.dtype, device='cpu', trans_A=False): + elem, group = SPARSITY_MAP[dtype] + if K % group != 0: + raise ValueError( + f"Last dimension must be divisible by {group} for {elem}:{group} sparsity.") + + if trans_A: + full_tensor = torch.randn(K * M, dtype=torch.float32, device=device).view(K, M) + mask = torch.zeros_like(full_tensor, dtype=torch.bool) + for j in range(M): + for i in range(0, K, group): + flat_idx = torch.randint(0, group, (elem,), dtype=torch.int64) + for k in range(1, len(flat_idx)): + while flat_idx[k] in flat_idx[:k]: + flat_idx[k] = torch.randint(0, group, (1,), dtype=torch.int64) + for idx in flat_idx: + mask[i + idx, j] = True + else: + full_tensor = torch.randn((M, K), dtype=torch.float32, device=device).view(M, K) + mask = torch.zeros_like(full_tensor, dtype=torch.bool) + for i in range(M): + for j in range(0, K, group): + flat_idx = torch.randint(0, group, (elem,), dtype=torch.int64) + for k in range(1, len(flat_idx)): + while flat_idx[k] in flat_idx[:k]: + flat_idx[k] = torch.randint(0, group, (1,), dtype=torch.int64) + for idx in flat_idx: + mask[i, j + idx] = True + + return full_tensor * mask + + +def normalize(tensor, max_range=100.0): + assert max_range <= 448.0 + max_v = tensor.abs().max().clamp(1e-4) + scaler = max_range / max_v + return tensor * scaler + + +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +def run_gemm_sp( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + block_M, + block_N, + block_K, + num_stages, + num_threads, + trans_A=False, + trans_B=False, +): + program = matmul_sp( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + num_threads, + trans_A, + trans_B, + ) + if in_dtype == "float32": + torch.backends.cuda.matmul.allow_tf32 = True + + kernel = tilelang.compile( + program, + out_idx=[-1], + ) + A = generate_sparse_tensor_float32( + M, K, dtype=STR_TO_TYPE[in_dtype], device='cuda', trans_A=trans_A) + if trans_B: + B = torch.randn((N, K), device='cuda', dtype=torch.float32) + else: + B = torch.randn((K, N), device='cuda', dtype=torch.float32) + + if "float8" in in_dtype or "int8" in in_dtype: + A = normalize(A) + B = normalize(B) + + A = A.to(STR_TO_TYPE[in_dtype]) + B = B.to(STR_TO_TYPE[in_dtype]) + + A_sparse, E = compress_sm90(A, block_K, trans_A) + + C_sp = kernel(A_sparse, E, B) + + def _matmul(A, B): + if trans_A: + A = A.T + if trans_B: + B = B.T + if "float8" in in_dtype or "int8" in in_dtype: + A = A.to(torch.float32) + B = B.to(torch.float32) + return torch.matmul(A, B).to(STR_TO_TYPE[out_dtype]) + + C = _matmul(A, B) + if 'float8' in in_dtype: + diff = calc_diff(C_sp, C) + assert diff < 1e-3, f"{diff=}" + else: + torch.testing.assert_close(C_sp, C, atol=1e-3, rtol=1e-3) + print("pass") + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(9, 0) +def test_gemm_sp(): + run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 32, 2, 128) + run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 32, 0, 256) + + 
run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 64, 0, 128) + run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 64, 2, 128) + + run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 128, 128, 128, 0, 128) + run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 128, 128, 128, 2, 128) + + run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 128, 256, 0, 128) + run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 128, 256, 2, 128) + + run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 64, 0, 128, False, True) + run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 64, 0, 128, True, False) + run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 64, 64, 64, 0, 128, True, True) + + run_gemm_sp(512, 1024, 768, "e4m3_float8", "float16", "float16", 64, 64, 64, 2, 128, False, + True) + + run_gemm_sp(512, 1024, 768, "int8", "int8", "int32", 64, 64, 64, 2, 128, False, True) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/testing/python/utils/test_compress_utils.py b/testing/python/utils/test_compress_utils.py new file mode 100644 index 000000000..7f7b338ed --- /dev/null +++ b/testing/python/utils/test_compress_utils.py @@ -0,0 +1,64 @@ +# Copyright (c) Tile-AI Organization. +# Licensed under the MIT License. +import torch +import tilelang +from tilelang.utils.sparse import compress_sm90 + + +def generate_2_to_4_sparse_tensor(shape, dtype=torch.float32, device='cpu'): + if shape[-1] % 4 != 0: + raise ValueError("Last dimension must be divisible by 4 for 2:4 sparsity.") + + full_tensor = torch.randn(shape, dtype=torch.float32, device=device) + mask = torch.zeros_like(full_tensor, dtype=torch.bool) + + group_count = shape[-1] // 4 + group_shape = shape[:-1] + (group_count, 4) + + reshaped = full_tensor.view(*group_shape) + + for idx in range(reshaped.numel() // 4): + flat_idx = torch.randint(0, 4, (2,), dtype=torch.int64) + while flat_idx[0] == flat_idx[1]: + flat_idx[1] = torch.randint(0, 4, (1,), dtype=torch.int64) + i = idx // group_count + j = idx % group_count + mask.view(*group_shape)[i, j, flat_idx[0]] = True + mask.view(*group_shape)[i, j, flat_idx[1]] = True + + sparse_tensor = full_tensor * mask + return sparse_tensor.to(dtype) + + +def _test_compress_sm90(M, K, block_k, dtype): + A = generate_2_to_4_sparse_tensor((M, K), dtype=dtype, device='cuda') + A_sparse, E = compress_sm90(A, block_k, False) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(9, 0) +def test_compress_sm90(): + _test_compress_sm90(1024, 1024, 128, torch.float16) + _test_compress_sm90(1024, 1024, 64, torch.float16) + _test_compress_sm90(1024, 1024, 32, torch.float16) + + _test_compress_sm90(1024, 1024, 128, torch.bfloat16) + _test_compress_sm90(1024, 1024, 64, torch.bfloat16) + _test_compress_sm90(1024, 1024, 32, torch.bfloat16) + + _test_compress_sm90(1024, 1024, 64, torch.float32) + _test_compress_sm90(1024, 1024, 32, torch.float32) + _test_compress_sm90(1024, 1024, 16, torch.float32) + + _test_compress_sm90(1024, 1024, 256, torch.float8_e4m3fn) + _test_compress_sm90(1024, 1024, 128, torch.float8_e4m3fn) + _test_compress_sm90(1024, 1024, 64, torch.float8_e4m3fn) + + _test_compress_sm90(1024, 1024, 256, torch.float8_e5m2) + _test_compress_sm90(1024, 1024, 128, torch.float8_e5m2) + _test_compress_sm90(1024, 1024, 64, torch.float8_e5m2) + + +if __name__ == "__main__": + test_compress_sm90() + print("All tests passed.") diff --git a/tilelang/env.py 
b/tilelang/env.py index 4dd504b72..ef8c28922 100644 --- a/tilelang/env.py +++ b/tilelang/env.py @@ -49,6 +49,21 @@ def _find_rocm_home() -> str: return rocm_home if rocm_home is not None else "" +def _initialize_torch_cuda_arch_flags(): + import os + from tilelang.contrib import nvcc + from tilelang.utils.target import determine_target + + target = determine_target(return_object=True) + # create tmp source file for torch cpp extension + compute_version = "".join(nvcc.get_target_compute_version(target).split(".")) + # set TORCH_CUDA_ARCH_LIST + major = compute_version[0] + minor = compute_version[1] + + os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}" + + CUDA_HOME = _find_cuda_home() ROCM_HOME = _find_rocm_home() @@ -197,4 +212,5 @@ def is_enabled(cls) -> bool: "enable_cache", "disable_cache", "is_cache_enabled", + "_initialize_torch_cuda_arch_flags", ] diff --git a/tilelang/jit/env.py b/tilelang/jit/env.py index ed8e38408..df8415975 100644 --- a/tilelang/jit/env.py +++ b/tilelang/jit/env.py @@ -31,21 +31,6 @@ ) -def _initialize_torch_cuda_arch_flags(): - import os - from tilelang.contrib import nvcc - from tilelang.utils.target import determine_target - - target = determine_target(return_object=True) - # create tmp source file for torch cpp extension - compute_version = "".join(nvcc.get_target_compute_version(target).split(".")) - # set TORCH_CUDA_ARCH_LIST - major = compute_version[0] - minor = compute_version[1] - - os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}" - - def _get_workspace_dir_name() -> pathlib.Path: try: from tilelang.contrib import nvcc @@ -64,7 +49,6 @@ def _get_workspace_dir_name() -> pathlib.Path: return pathlib.Path.home() / ".cache" / "tilelang" / arch -# _initialize_torch_cuda_arch_flags() TILELANG_JIT_WORKSPACE_DIR = _get_workspace_dir_name() TILELANG_JIT_DIR = TILELANG_JIT_WORKSPACE_DIR / "cached_ops" TILELANG_GEN_SRC_DIR = TILELANG_JIT_WORKSPACE_DIR / "generated" diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py index f62145b8b..bc4e96356 100644 --- a/tilelang/language/__init__.py +++ b/tilelang/language/__init__.py @@ -44,6 +44,7 @@ ) from .copy import copy, c2d_im2col # noqa: F401 from .gemm import GemmWarpPolicy, gemm # noqa: F401 +from .experimental.gemm_sp import gemm_sp # noqa: F401 from .fill import fill, clear # noqa: F401 from .reduce import ( reduce, # noqa: F401 diff --git a/tilelang/language/experimental/__init__.py b/tilelang/language/experimental/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tilelang/language/experimental/gemm_sp.py b/tilelang/language/experimental/gemm_sp.py new file mode 100644 index 000000000..7a1dfe062 --- /dev/null +++ b/tilelang/language/experimental/gemm_sp.py @@ -0,0 +1,88 @@ +# Copyright (c) Tile-AI Corporation. +# Licensed under the MIT License. +"""The language interface for tl programs.""" + +from tilelang.primitives.gemm.base import GemmWarpPolicy +import tilelang.language as T +from tvm import tir +from typing import Union + + +def gemm_sp( + A_sparse: Union[tir.Buffer, tir.Var], + E: Union[tir.Buffer, tir.Var], + B: Union[tir.Buffer, tir.Var], + C: Union[tir.Buffer, tir.Var], + transpose_A: bool = False, + transpose_B: bool = False, + policy: GemmWarpPolicy = GemmWarpPolicy.Square, + clear_accum: bool = False, + k_pack: int = 1, + wg_wait: int = 0, +): + """Perform a Sparse General Matrix Multiplication (GEMM-sp) operation. + + This function computes C = A @ B where A and B can optionally be transposed. 
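+    A is supplied in 2:4 structured-sparse form: A_sparse carries the kept values
+    (its K extent is half of B's K extent) and E carries the compression metadata.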
diff --git a/tilelang/language/experimental/gemm_sp.py b/tilelang/language/experimental/gemm_sp.py
new file mode 100644
index 000000000..7a1dfe062
--- /dev/null
+++ b/tilelang/language/experimental/gemm_sp.py
@@ -0,0 +1,88 @@
+# Copyright (c) Tile-AI Corporation.
+# Licensed under the MIT License.
+"""The language interface for tl programs."""
+
+from tilelang.primitives.gemm.base import GemmWarpPolicy
+import tilelang.language as T
+from tvm import tir
+from typing import Union
+
+
+def gemm_sp(
+    A_sparse: Union[tir.Buffer, tir.Var],
+    E: Union[tir.Buffer, tir.Var],
+    B: Union[tir.Buffer, tir.Var],
+    C: Union[tir.Buffer, tir.Var],
+    transpose_A: bool = False,
+    transpose_B: bool = False,
+    policy: GemmWarpPolicy = GemmWarpPolicy.Square,
+    clear_accum: bool = False,
+    k_pack: int = 1,
+    wg_wait: int = 0,
+):
+    """Perform a Sparse General Matrix Multiplication (GEMM-sp) operation.
+
+    This function computes C = A @ B where A and B can optionally be transposed.
+    The operation supports various warp policies and accumulation modes.
+
+    Args:
+        A_sparse (Union[tir.Buffer, tir.Var]): Compressed non-zero values of the first (2:4 sparse) input matrix
+        E (Union[tir.Buffer, tir.Var]): 2:4 sparsity metadata for the first input matrix
+        B (Union[tir.Buffer, tir.Var]): Second input matrix
+        C (Union[tir.Buffer, tir.Var]): Output matrix for results
+        transpose_A (bool, optional): Whether to transpose matrix A. Defaults to False.
+        transpose_B (bool, optional): Whether to transpose matrix B. Defaults to False.
+        policy (GemmWarpPolicy, optional): Warp execution policy. Defaults to GemmWarpPolicy.Square.
+        clear_accum (bool, optional): Whether to clear accumulator before computation. Defaults to False.
+        k_pack (int, optional): Number of k dimensions packed into a single warp. Defaults to 1.
+        wg_wait (int, optional): Warp group wait count. Defaults to 0.
+
+    Returns:
+        tir.Call: A handle to the GEMM operation
+
+    Raises:
+        AssertionError: If the K dimensions of matrices A and B don't match
+    """
+
+    def legalize_arguments(arg: Union[tir.Buffer, tir.Var]):
+        """Convert let-bound variables to their corresponding buffers.
+
+        Args:
+            arg (Union[tir.Buffer, tir.Var]): Input argument to legalize
+
+        Returns:
+            Union[tir.Buffer, tir.Var]: The legalized argument
+        """
+        if isinstance(arg, tir.Var) and T.has_let_value(arg):
+            return T.get_let_value(arg).buffer
+        return arg
+
+    A_sparse = legalize_arguments(A_sparse)
+    B = legalize_arguments(B)
+    C = legalize_arguments(C)
+    M = C.shape[0]
+    N = C.shape[1]
+    K_A = A_sparse.shape[0] if transpose_A else A_sparse.shape[1]
+    K_B = B.shape[1] if transpose_B else B.shape[0]
+    assert K_A * 2 == K_B, f"T.gemm_sp K shape check failed: K_A = {K_A}, K_B = {K_B}"
+    Aptr = A_sparse.access_ptr("r")
+    Bptr = B.access_ptr("r")
+    Cptr = C.access_ptr("rw")
+    Eptr = E.access_ptr("r")
+    return tir.call_intrin(
+        "handle",
+        tir.op.Op.get("tl.gemm_sp"),
+        Aptr,
+        Eptr,
+        Bptr,
+        Cptr,
+        transpose_A,
+        transpose_B,
+        M,
+        N,
+        K_B,
+        policy,
+        clear_accum,
+        k_pack,
+        wg_wait,
+    )
diff --git a/tilelang/layout/__init__.py b/tilelang/layout/__init__.py
index 41db70448..73fccb2b2 100644
--- a/tilelang/layout/__init__.py
+++ b/tilelang/layout/__init__.py
@@ -6,3 +6,4 @@
 from .layout import Layout  # noqa: F401
 from .fragment import Fragment  # noqa: F401
 from .swizzle import make_swizzled_layout  # noqa: F401
+from .gemm_sp import make_metadata_layout  # noqa: F401
\ No newline at end of file
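For orientation, here is a hedged sketch of how the new `T.gemm_sp` intrinsic and `make_metadata_layout` are meant to be wired into a block-level kernel. The `K_A * 2 == K_B` contract asserted above drives the buffer shapes; the metadata tile sizes (K // 8 bytes per row for fp16) and the use of the usual TileLang block primitives (`T.Kernel`, `T.alloc_shared`, `T.Pipelined`, `T.annotate_layout`) are assumptions for illustration, not part of this diff:

```python
import tilelang.language as T
from tilelang.layout import make_metadata_layout


def sparse_gemm_kernel(M, N, K, block_M=64, block_N=64, block_K=64, dtype="float16"):

    @T.prim_func
    def main(
            A_sparse: T.Tensor((M, K // 2), dtype),  # compressed 2:4 values
            E: T.Tensor((M, K // 8), "uint8"),       # assumed metadata packing for fp16
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K // 2), dtype)
            E_shared = T.alloc_shared((block_M, block_K // 8), "uint8")
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), "float")

            # Metadata uses the cutlass sm90 layout defined in tilelang/layout/gemm_sp.py.
            T.annotate_layout({
                E_shared: make_metadata_layout(E_shared, mma_dtype=dtype, block_k=block_K),
            })
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                T.copy(A_sparse[by * block_M, ko * (block_K // 2)], A_shared)
                T.copy(E[by * block_M, ko * (block_K // 8)], E_shared)
                T.copy(B[bx * block_N, ko * block_K], B_shared)
                T.gemm_sp(A_shared, E_shared, B_shared, C_local, transpose_B=True)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```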
+"""Wrapping Layouts.""" +# pylint: disable=invalid-name, unsupported-binary-operation + +import tvm +import tilelang.language as T +import warnings + +from typing import List +from math import prod + + +def decompose_col_major(index_1d: int, basis: List[int]) -> List[int]: + res = [] + for x in basis: + res.append(index_1d % x) + index_1d //= x + return res + + +def __make_metadata_layout_sm90_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str, block_k: int): + if block_k > 128: + block_k = 128 + # Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146 + warnings.warn(f"block_k {block_k} is too large, set to 128 for {mma_dtype}.", stacklevel=2) + if mma_dtype not in ["float16", "bfloat16", "float32", "int8", "float8"]: + raise NotImplementedError(f"Unsupported dtype: {mma_dtype}") + + if buffer.dtype not in ["uint8", "int8"]: + raise ValueError(f"metadata should be 8 bit, got {buffer.dtype}") + + bits_map = { + "float16": 16, + "bfloat16": 16, + "float32": 32, + "int8": 8, + "float8": 8, + } + + # ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_config.inl#L108-L117 + # get atom layout according to mma dtype + BlockK = 512 // bits_map[mma_dtype] + if block_k % BlockK != 0: + raise ValueError(f"Tile K is too small, which should be at least {BlockK} for {mma_dtype}") + NumK = block_k // BlockK # block_k is MinTileShapeK + + def gen_stride(shape_ik, order): + stride_ik = [None for _ in range(len(shape_ik))] + order = [(i, o) for i, o in enumerate(order)] + order.sort(key=lambda x: x[1]) + accu_shape = 1 + for i, (o, _) in enumerate(order): + if i == 0: + stride_ik[o] = 1 + else: + stride_ik[o] = accu_shape + accu_shape *= shape_ik[o] + return stride_ik + + if bits_map[mma_dtype] == 32: # x // 8 is to convert bits into uint8 + shape_ik = [8, 2, 4, 8 // 8, 2, NumK] + stride_ik = gen_stride(shape_ik, [3, 1, 5, 0, 4, 2]) + shape_i, shape_k = shape_ik[:3], shape_ik[3:] + stride_i, stride_k = stride_ik[:3], stride_ik[3:] + elif bits_map[mma_dtype] == 16: + shape_ik = [8, 2, 4, 16 // 8, 2, NumK] + stride_ik = gen_stride(shape_ik, [3, 1, 5, 0, 4, 2]) + shape_i, shape_k = shape_ik[:3], shape_ik[3:] + stride_i, stride_k = stride_ik[:3], stride_ik[3:] + elif bits_map[mma_dtype] == 8: + shape_i, shape_k = [64], [BlockK] + stride_i, stride_k = [BlockK], [1] + else: + raise NotImplementedError(f"Unknown mma type {mma_dtype}") + + shape = buffer.shape + + # repeat to buffer size in col major + rep_i = (shape[0] + 63) // 64 + rep_k = (shape[1] + prod(shape_k) - 1) // prod(shape_k) + rep_i_stride = prod(shape_i + shape_k) + shape_i.append(rep_i) + stride_i.append(rep_i_stride) + rep_k_stirde = prod(shape_i + shape_k) + shape_k.append(rep_k) + stride_k.append(rep_k_stirde) + + def transform(i: int, k: int) -> int: + nonlocal shape_i, shape_k, stride_i, stride_k + i_decomposed = decompose_col_major(i, shape_i) + k_decomposed = decompose_col_major(k, shape_k) + i_offset = sum(i_decomposed[k] * stride_i[k] for k in range(len(i_decomposed))) + k_offset = sum(k_decomposed[k] * stride_k[k] for k in range(len(k_decomposed))) + return i_offset + k_offset + + return T.Layout(shape, transform) + + +def make_metadata_layout(buffer: tvm.tir.Buffer, + mma_dtype: str = "float16", + arch: str = "sm90", + backend: str = 'cutlass', + **extra_args): + if arch == "sm90": + if backend == 'cutlass': + return 
diff --git a/tilelang/utils/sparse.py b/tilelang/utils/sparse.py
new file mode 100644
index 000000000..6f97ebad2
--- /dev/null
+++ b/tilelang/utils/sparse.py
@@ -0,0 +1,58 @@
+# Copyright (c) Tile-AI Corporation.
+# Licensed under the MIT License.
+
+import os
+import torch
+import warnings
+from torch.utils.cpp_extension import load, _import_module_from_library
+from tilelang.env import TILELANG_CACHE_DIR, TILELANG_TEMPLATE_PATH, CUTLASS_INCLUDE_DIR
+
+# Define paths
+compress_util = os.path.join(TILELANG_TEMPLATE_PATH, "tl_templates/cuda/compress_sm90.cu")
+# Cache directory for compiled extensions
+_CACHE_DIR = os.path.join(TILELANG_CACHE_DIR, "sparse_compressor")
+os.makedirs(_CACHE_DIR, exist_ok=True)
+
+
+def _get_cached_lib():
+    name = 'compress_lib'
+    cached_path = os.path.join(_CACHE_DIR, f"{name}.so")
+
+    if os.path.exists(cached_path):
+        try:
+            return _import_module_from_library(name, cached_path)
+        except Exception:
+            # If loading fails, recompile
+            pass
+
+    from tilelang.env import _initialize_torch_cuda_arch_flags
+    # Set TORCH_CUDA_ARCH_LIST
+    _initialize_torch_cuda_arch_flags()
+
+    # Compile if not cached or loading failed
+    return load(
+        name=name,
+        sources=[compress_util],
+        extra_cuda_cflags=[
+            '-O2',
+            '-std=c++17',
+            '-lineinfo',
+            f'-I{CUTLASS_INCLUDE_DIR}',
+            f'-I{CUTLASS_INCLUDE_DIR}/../tools/util/include',
+            '-arch=sm_90',
+        ],
+        build_directory=_CACHE_DIR,
+    )
+
+
+def compress_sm90(A: torch.Tensor, block_k: int,
+                  transposed: bool) -> tuple[torch.Tensor, torch.Tensor]:
+    if block_k > 128:
+        # Ref: https://github.com/NVIDIA/cutlass/blob/c2ad7c5b20f131c4ba33601860f1da3f9c9df0f3/include/cutlass/gemm/collective/builders/sm90_sparse_gmma_builder.inl#L145-L146
+        warnings.warn(
+            f"block_k {block_k} is too large, set to 128 for sm90 compression.", stacklevel=2)
+        block_k = 128
+    # Load the library (will use cache if available)
+    compress_lib = _get_cached_lib()
+
+    return compress_lib.compress_sm90(A, block_k, transposed)
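An end-to-end usage sketch for the compressor. The magnitude-based 2:4 masking step and the expected output shapes are illustrative assumptions, not something this patch asserts:

```python
import torch

from tilelang.utils.sparse import compress_sm90

M, K = 1024, 1024
A = torch.randn(M, K, dtype=torch.float16, device="cuda")
# Make A 2:4 sparse by keeping the 2 largest-magnitude values in each group of 4 along K.
groups = A.reshape(M, K // 4, 4)
keep = groups.abs().topk(2, dim=-1).indices
mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, keep, True)
A = (groups * mask).reshape(M, K)

A_sparse, E = compress_sm90(A, block_k=128, transposed=False)
print(A_sparse.shape)    # expected roughly (M, K // 2) for fp16
print(E.shape, E.dtype)  # reordered 2:4 metadata consumed by T.gemm_sp
```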