Merged
47 commits
cb41ed1
[experimental] add a draft gemm_sp
botbw Apr 8, 2025
95492be
[3rdparty] bump cutlass to v3.9.3
botbw May 27, 2025
2c72c01
[lint] run format.sh
botbw May 27, 2025
86fe989
[chore] rebase
botbw May 27, 2025
fdd1828
[chore] use abs path
botbw May 27, 2025
acad673
[gemm_sp] add metadata layout
botbw Jun 6, 2025
213762c
[ci] add more example
botbw Jun 6, 2025
775d2b9
[lint] run format.sh
botbw Jun 6, 2025
eceab43
[chore] polish
botbw Jun 6, 2025
e8c0d4d
[chore] move gemm_sp to experimental
botbw Jun 6, 2025
0a1e366
[chore] polish
botbw Jun 6, 2025
621e3cf
[lint] run format.sh
botbw Jun 6, 2025
3f17184
Merge branch 'main' of https://github.com/tile-ai/tilelang into gemm_sp
LeiWang1999 Jun 7, 2025
70d0549
[Enhancement] Improve bulk copy handling and update GEMM sparse tenso…
LeiWang1999 Jun 8, 2025
2ae80f7
Implement Test
LeiWang1999 Jun 8, 2025
b98a0ed
[Enhancement] Update GEMM SP and SM89 templates for improved function…
LeiWang1999 Jun 8, 2025
297603e
lint fix
LeiWang1999 Jun 8, 2025
b37899b
[gemm_sp] support more layout and data types
botbw Jun 9, 2025
bc3c83c
Enhancement: sync T.gemm_sp's layout inference with T.gemm
botbw Jun 10, 2025
27ed04a
Enhancement: support more block_k in compress util
botbw Jun 10, 2025
f698ed7
[Enhancement] enable block_k=64
botbw Jun 11, 2025
556a3f3
[Lint] run format.sh
botbw Jun 11, 2025
f3a1ccc
[Enhancement] compressor support more dtype
botbw Jun 11, 2025
0a803f9
Merge remote-tracking branch 'upstream/main' into gemm_sp
botbw Jun 12, 2025
7fdcbbf
Enhancement: enable block_K=32
botbw Jun 12, 2025
cecf234
[Lint] format.sh
botbw Jun 12, 2025
d8905c5
[Fixbug] fix shape
botbw Jun 12, 2025
ffe0cee
Refactor: sync gemm
botbw Jun 12, 2025
6c8156e
[Enhancement] enable transpose
botbw Jun 12, 2025
03132de
[Enhancement] enable fp8_e4m3
botbw Jun 16, 2025
4cf3f4f
[Enhancement] enable int8
botbw Jun 16, 2025
a51d8f1
[Lint] run format.sh
botbw Jun 16, 2025
c603e5d
[Benchmark] add gemm_sp benchmark
botbw Jun 16, 2025
cff57ee
[Example] fix 256 threads hang
botbw Jun 16, 2025
32dd9b1
[CI] fix ci
botbw Jun 16, 2025
29be5ea
[Chore] resolve gemini feedback
botbw Jun 16, 2025
57b9b57
[Benchmark] increase search space
botbw Jun 17, 2025
bc88a99
[Lint] format
botbw Jul 1, 2025
a9dcfc3
Merge remote-tracking branch 'upstream/main' into gemm_sp
botbw Jul 1, 2025
cf903b5
[CI] skip sparse tensor core related tests as only sm90 is supported
botbw Jul 1, 2025
299b68a
[CI] pass local run
botbw Jul 1, 2025
b873dbf
Update gemm_sm89.h
LeiWang1999 Jul 1, 2025
017c67d
lint fix
LeiWang1999 Jul 2, 2025
2dc3ca9
Merge branch 'main' into gemm_sp
LeiWang1999 Jul 2, 2025
b18ecb3
lint fix
LeiWang1999 Jul 2, 2025
4a07736
[Enhancement] Add support for sparse GEMM and initialize CUDA archite…
LeiWang1999 Jul 3, 2025
3ce7992
Update test_compress_utils.py
LeiWang1999 Jul 3, 2025
2 changes: 1 addition & 1 deletion 3rdparty/cutlass
Submodule cutlass updated 530 files
1 change: 1 addition & 0 deletions benchmark/matmul/benchmark_matmul.py
@@ -161,6 +161,7 @@ def matmul(M, N, K, with_roller):
configs=get_configs(M, N, K, with_roller),
warmup=3,
rep=20,
ref_prog=ref_program,
)
@jit(out_idx=[2],)
def kernel(
267 changes: 267 additions & 0 deletions benchmark/matmul/benchmark_matmul_sp.py
@@ -0,0 +1,267 @@
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.

import argparse
import itertools
import logging
import torch
from triton.testing import do_bench

import tilelang.language as T
from tilelang.autotuner import autotune
from tilelang import jit
from tilelang.layout import make_metadata_layout
# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def ref_program(A, B):
"""
A reference matrix multiplication program, used to compare performance.

Parameters
----------
A : numpy.ndarray
The matrix with shape (M, K).
B : numpy.ndarray
The matrix with shape (N, K).

Returns
-------
np.ndarray
The result of A @ B.T, shape (M, N).
"""
return A @ B.T


def get_configs(M, N, K):
"""
Generate a list of configuration dictionaries that will be used for tuning.

Parameters
----------
with_roller : bool
Whether to enable bitblas roller to deduce search spaces

Returns
-------
list of dict
Each configuration dict includes various block sizes, pipeline stages,
thread numbers, and other parameters to explore during autotuning.
"""
Comment on lines +38 to +52 (Contributor review, severity: medium):

The docstring mentions a with_roller parameter, which isn't in the function signature. Please correct the docstring or add the parameter.
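One possible way to resolve this, sketched here rather than applied in the diff, is to document only the actual parameters:

def get_configs(M, N, K):
    """Generate configuration dicts (block sizes, pipeline stages, thread
    counts, warp policy, rasterization flag) to explore during autotuning
    for the given M, N, K problem size."""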

block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [64, 128]
num_stages = [0, 1, 2, 3]
thread_num = [128, 256]
enable_rasterization = [True, False]
policy = [T.GemmWarpPolicy.Square]
_configs = list(
itertools.product(
block_M,
block_N,
block_K,
num_stages,
thread_num,
policy,
enable_rasterization,
))

configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"policy": c[5],
"enable_rasterization": c[6], # keep param name for backward-compat
} for c in _configs
]
return configs


def matmul_sp(M, N, K):
"""
Create an autotuned sparse matrix multiplication kernel (C = A @ B^T) where A
is 2:4 structured-sparse and is passed in compressed form:
- A_sparse: (M, K // 2), with metadata E: (M, K // 8)
- B: (N, K)
- C: (M, N)

Parameters
----------
M : int
The dimension M of the matrix multiplication.
N : int
The dimension N of the matrix multiplication.
K : int
The dimension K of the matrix multiplication.

Returns
-------
(best_latency, best_config, ref_latency)
best_latency : float
The best latency found among the tuned configurations.
best_config : dict
The parameter configuration that yielded best_latency.
ref_latency : float
The baseline latency of the reference program (for computing speedup).
"""
Comment on lines +102 to +110 (Contributor review, severity: medium):

The docstring for matmul_sp should be updated to reflect that it returns an AutotuneResult object, not (best_latency, best_config, ref_latency) directly.
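For reference, a short sketch of how the returned object is consumed; the attribute names mirror the __main__ block further down in this file:

result = matmul_sp(M, N, K)       # an AutotuneResult, not a tuple
best_latency = result.latency     # best measured latency among the tuned configs
best_config = result.config       # the configuration that produced it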


# Decorate the kernel with autotune & jit, specifying:
# - The tuning config list produced by get_configs
# - Warmup and repetition counts for more stable measurements

@autotune(
configs=get_configs(M, N, K),
warmup=3,
rep=20,
)
@jit(out_idx=[2],)
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
policy=None,
enable_rasterization=None,
):
"""
The actual kernel to compute C = A @ B^T.

Parameters
----------
block_M : int
Block size in M dimension.
block_N : int
Block size in N dimension.
block_K : int
Block size in K dimension.
num_stages : int
Number of pipelined stages (for asynchronous load).
thread_num : int
Number of threads to use per block.
policy : T.GemmWarpPolicy
Warp partitioning policy used by the tensor core GEMM.
enable_rasterization : bool
Whether to enable swizzle-based block rasterization.

Returns
-------
Function
A TVM Tensor Language function (T.prim_func) that computes matmul.
"""
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "float16"
accum_dtype = "float"

@T.prim_func
def main(
A_sparse: T.Tensor((M, K // 2), dtype),
E: T.Tensor((M, K // 8), 'uint8'),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
Comment on lines +165 to +168 (Contributor review, severity: medium):

Consider adding assertions or checks for K % 8 == 0 to ensure A_sparse and E tensor shape divisions are valid.
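A minimal sketch of such a guard, hypothetically placed at the top of matmul_sp before the kernel is constructed (not part of this diff):

# Hypothetical shape guard: with fp16 2:4 structured sparsity, A is stored
# compressed as (M, K // 2) and its metadata E as (M, K // 8), so K must be
# a multiple of 8 for both divisions to be exact.
assert K % 8 == 0, f"K ({K}) must be a multiple of 8 for A_sparse (M, K // 2) and E (M, K // 8)"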

):
"""
The compiled TVM function for block-level matrix multiplication.

- We divide the entire (M, N) domain into blocks of shape
(block_M, block_N).
- Each block has its own allocated shared memory for sub-blocks
of A_sparse, E (the sparsity metadata), and B.
- The partial results go into C_local, and then we copy them back
to global memory C.
"""
# Bind x-dimension to block index in N,
# y-dimension to block index in M.
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):

# Allocate shared memory for the compressed A sub-block of shape (block_M, block_K // 2)
A_shared = T.alloc_shared((block_M, block_K // 2), dtype)
# Allocate shared memory for B sub-block of shape (block_N, block_K)
B_shared = T.alloc_shared((block_N, block_K), dtype)
# Allocate shared memory for the E (metadata) sub-block of shape (block_M, block_K // 8)
E_shared = T.alloc_shared((block_M, block_K // 8), 'uint8')
# Allocate a local fragment for intermediate accumulation
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Allocate shared memory for the C sub-block of shape (block_M, block_N)
C_shared = T.alloc_shared((block_M, block_N), dtype)

# Clear out the accumulation buffer
T.clear(C_local)
T.no_set_max_nreg()

T.use_swizzle(panel_size=10, enable=enable_rasterization)
T.annotate_layout({
E:
make_metadata_layout(
E, mma_dtype="float16", arch="sm90", backend="cutlass",
block_k=block_K),
E_shared:
make_metadata_layout(
E_shared,
mma_dtype="float16",
arch="sm90",
backend="cutlass",
block_k=block_K),
})
# Loop over sub-blocks in K dimension, pipelined by num_stages
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
# Load a sub-block of A from global memory into A_shared
T.copy(A_sparse[by * block_M, k * block_K], A_shared)
# Load a sub-block of E from global memory into E_shared
T.copy(E[by * block_M, k * block_K // 8], E_shared)
# Load a sub-block of B from global memory into B_shared
T.copy(B[bx * block_N, k * block_K], B_shared)
# Perform a partial sparse matrix multiplication:
# C_local += decompress(A_shared, E_shared) @ B_shared^T
T.gemm_sp(
A_shared,
E_shared,
B_shared,
C_local,
transpose_B=True,
policy=policy,
)
# Write back the results from C_local to the global memory C
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])

return main

return kernel()


if __name__ == "__main__":
# Parse command-line arguments for matrix dimensions
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
args = parser.parse_args()

M, N, K = args.m, args.n, args.k

# Compute total floating-point operations to measure throughput
total_flops = 2 * M * N * K

# matmul_sp(...) returns an AutotuneResult; pull out the best latency and config
best_result = matmul_sp(M, N, K)
best_latency = best_result.latency
best_config = best_result.config
A = torch.randn(M, K, dtype=torch.float16, device="cuda")
B = torch.randn(N, K, dtype=torch.float16, device="cuda")
ref_latency = do_bench(lambda: A @ B.T)

# Print out the benchmark results
print(f"Best latency (s): {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}")
print(f"Best config: {best_config}")

print(f"Reference TFlops: {total_flops / ref_latency * 1e-9:.3f}")
3 changes: 2 additions & 1 deletion examples/elementwise/example_elementwise_add.py
@@ -77,8 +77,9 @@ def main():
kernel = result.kernel
else:
# Default config
config = {"block_M": 128, "block_N": 128, "threads": 128}
config = {"block_M": 32, "block_N": 32, "threads": 128}
kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")

out = kernel(a, b)
torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)

15 changes: 15 additions & 0 deletions examples/sparse_tensorcore/test_example_sparse_tensorcore.py
@@ -0,0 +1,15 @@
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
import tilelang.testing
import tilelang
import tilelang_example_sparse_tensorcore


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9, 0)
def test_tilelang_example_sparse_tensorcore():
tilelang_example_sparse_tensorcore.main()


if __name__ == "__main__":
tilelang.testing.main()