[Experimental][Language] add T.GEMM_SP for sm90 sparse tensor core #526
Merged
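In the benchmark added by this PR, the new block-level intrinsic is invoked roughly as follows (snippet lifted from the example file below; shapes and dtypes are those used there and are shown only for orientation):

# A_shared holds the 2:4-compressed A tile, E_shared the sparsity metadata,
# B_shared the dense B tile, and C_local the accumulator fragment.
T.gemm_sp(
    A_shared,   # (block_M, block_K // 2), float16
    E_shared,   # (block_M, block_K // 8), uint8
    B_shared,   # (block_N, block_K), float16
    C_local,    # (block_M, block_N), float accumulator
    transpose_B=True,
    policy=policy,
)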
Commits (47, changes from all commits)
cb41ed1  botbw        [experimental] add a draft gemm_sp
95492be  botbw        [3rdparty] bump cutlass to v3.9.3
2c72c01  botbw        [lint] run format.sh
86fe989  botbw        [chore] rebase
fdd1828  botbw        [chore] use abs path
acad673  botbw        [gemm_sp] add metadata layout
213762c  botbw        [ci] add more example
775d2b9  botbw        [lint] run format.sh
eceab43  botbw        [chore] polish
e8c0d4d  botbw        [chore] move gemm_sp to experimental
0a1e366  botbw        [chore] polish
621e3cf  botbw        [lint] run format.sh
3f17184  LeiWang1999  Merge branch 'main' of https://github.com/tile-ai/tilelang into gemm_sp
70d0549  LeiWang1999  [Enhancement] Improve bulk copy handling and update GEMM sparse tenso…
2ae80f7  LeiWang1999  Implement Test
b98a0ed  LeiWang1999  [Enhancement] Update GEMM SP and SM89 templates for improved function…
297603e  LeiWang1999  lint fix
b37899b  botbw        [gemm_sp] support more layout and data types
bc3c83c  botbw        Enhancement: sync T.gemm_sp's layout inference with T.gemm
27ed04a  botbw        Enhancement: support more block_k in compress util
f698ed7  botbw        [Enhancement] enable block_k=64
556a3f3  botbw        [Lint] run format.sh
f3a1ccc  botbw        [Enhancement] compressor support more dtype
0a803f9  botbw        Merge remote-tracking branch 'upstream/main' into gemm_sp
7fdcbbf  botbw        Enhancement: enable block_K=32
cecf234  botbw        [Lint] format.sh
d8905c5  botbw        [Fixbug] fix shape
ffe0cee  botbw        Refactor: sync gemm
6c8156e  botbw        [Enhancement] enable transpose
03132de  botbw        [Enhancement] enable fp8_e4m3
4cf3f4f  botbw        [Enhancement] enable int8
a51d8f1  botbw        [Lint] run format.sh
c603e5d  botbw        [Benchmark] add gemm_sp benchmark
cff57ee  botbw        [Example] fix 256 threads hang
32dd9b1  botbw        [CI] fix ci
29be5ea  botbw        [Chore] resolve gemini feedback
57b9b57  botbw        [Benchmark] increase search space
bc88a99  botbw        [Lint] format
a9dcfc3  botbw        Merge remote-tracking branch 'upstream/main' into gemm_sp
cf903b5  botbw        [CI] skip sparse tensor core related tests as only sm90 is supported
299b68a  botbw        [CI] pass local run
b873dbf  LeiWang1999  Update gemm_sm89.h
017c67d  LeiWang1999  lint fix
2dc3ca9  LeiWang1999  Merge branch 'main' into gemm_sp
b18ecb3  LeiWang1999  lint fix
4a07736  LeiWang1999  [Enhancement] Add support for sparse GEMM and initialize CUDA archite…
3ce7992  LeiWang1999  Update test_compress_utils.py
@@ -0,0 +1,267 @@
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.

import argparse
import itertools
import logging
import torch
from triton.testing import do_bench

import tilelang.language as T
from tilelang.autotuner import autotune
from tilelang import jit
from tilelang.layout import make_metadata_layout

# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def ref_program(A, B):
    """
    A reference matrix multiplication program, used to compare performance.

    Parameters
    ----------
    A : torch.Tensor
        The matrix with shape (M, K).
    B : torch.Tensor
        The matrix with shape (N, K).

    Returns
    -------
    torch.Tensor
        The result of A @ B.T, shape (M, N).
    """
    return A @ B.T


def get_configs(M, N, K):
    """
    Generate a list of configuration dictionaries that will be used for tuning.

    Parameters
    ----------
    M : int
        The dimension M of the matrix multiplication.
    N : int
        The dimension N of the matrix multiplication.
    K : int
        The dimension K of the matrix multiplication.

    Returns
    -------
    list of dict
        Each configuration dict includes various block sizes, pipeline stages,
        thread numbers, and other parameters to explore during autotuning.
    """
    block_M = [64, 128, 256]
    block_N = [64, 128, 256]
    block_K = [64, 128]
    num_stages = [0, 1, 2, 3]
    thread_num = [128, 256]
    enable_rasterization = [True, False]
    policy = [T.GemmWarpPolicy.Square]
    _configs = list(
        itertools.product(
            block_M,
            block_N,
            block_K,
            num_stages,
            thread_num,
            policy,
            enable_rasterization,
        ))

    configs = [
        {
            "block_M": c[0],
            "block_N": c[1],
            "block_K": c[2],
            "num_stages": c[3],
            "thread_num": c[4],
            "policy": c[5],
            "enable_rasterization": c[6],  # keep param name for backward-compat
        } for c in _configs
    ]
    return configs
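
# Note: the lists above span 3 * 3 * 2 * 4 * 2 * 1 * 2 = 288 candidate
# configurations, all of which the autotuner will compile and time.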


def matmul_sp(M, N, K):
    """
    Create an autotuned matrix multiplication kernel for matrices of shape:
      - A: (M, K)
      - B: (N, K)
      - C: (M, N)

    Parameters
    ----------
    M : int
        The dimension M of the matrix multiplication.
    N : int
        The dimension N of the matrix multiplication.
    K : int
        The dimension K of the matrix multiplication.

    Returns
    -------
    result
        The autotuning result object; the best measured latency and the
        configuration that produced it are available as result.latency and
        result.config.
    """

    # Decorate the kernel with autotune & jit, specifying:
    #   - the list of tuning configs to sweep
    #   - warmup and repetition counts for better measurement
    #   - out_idx, telling the JIT wrapper which argument to treat as the output

    @autotune(
        configs=get_configs(M, N, K),
        warmup=3,
        rep=20,
    )
    @jit(out_idx=[2],)
    def kernel(
            block_M=None,
            block_N=None,
            block_K=None,
            num_stages=None,
            thread_num=None,
            policy=None,
            enable_rasterization=None,
    ):
        """
        The actual kernel to compute C = A @ B^T.

        Parameters
        ----------
        block_M : int
            Block size in M dimension.
        block_N : int
            Block size in N dimension.
        block_K : int
            Block size in K dimension.
        num_stages : int
            Number of pipelined stages (for asynchronous load).
        thread_num : int
            Number of threads to use per block.
        policy : T.GemmWarpPolicy
            Warp partitioning policy for the GEMM.
        enable_rasterization : bool
            Whether to enable thread-block swizzling (rasterization).

        Returns
        -------
        Function
            A TVM Tensor Language function (T.prim_func) that computes matmul.
        """
        # Use half-precision for input data to reduce memory bandwidth,
        # accumulate in float for better numerical accuracy
        dtype = "float16"
        accum_dtype = "float"

        @T.prim_func
        def main(
                A_sparse: T.Tensor((M, K // 2), dtype),
                E: T.Tensor((M, K // 8), 'uint8'),
                B: T.Tensor((N, K), dtype),
                C: T.Tensor((M, N), dtype),
        ):
""" | ||
The compiled TVM function for block-level matrix multiplication. | ||
|
||
- We divide the entire (M, N) domain into blocks of shape | ||
(block_M, block_N). | ||
- Each block has its own allocated shared memory for sub-blocks | ||
of A and B. | ||
- The partial results go into C_local, and then we copy them back | ||
to global memory C. | ||
""" | ||
# Bind x-dimension to block index in N, | ||
# y-dimension to block index in M. | ||
with T.Kernel( | ||
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): | ||
|
||
# Allocate shared memory for A sub-block of shape (block_M, block_K) | ||
A_shared = T.alloc_shared((block_M, block_K // 2), dtype) | ||
# Allocate shared memory for B sub-block of shape (block_N, block_K) | ||
B_shared = T.alloc_shared((block_N, block_K), dtype) | ||
# Allocate shared memory for E sub-block of shape (block_M, block_K // E_factor) | ||
E_shared = T.alloc_shared((block_M, block_K // 8), 'uint8') | ||
# Allocate a local fragment for intermediate accumulation | ||
C_local = T.alloc_fragment((block_M, block_N), accum_dtype) | ||
# Allocate a shared memory for C sub-block of shape (block_M, block_N) | ||
C_shared = T.alloc_shared((block_M, block_N), dtype) | ||
|
||
# Clear out the accumulation buffer | ||
T.clear(C_local) | ||
T.no_set_max_nreg() | ||
|
||
T.use_swizzle(panel_size=10, enable=enable_rasterization) | ||
T.annotate_layout({ | ||
E: | ||
make_metadata_layout( | ||
E, mma_dtype="float16", arch="sm90", backend="cutlass", | ||
block_k=block_K), | ||
E_shared: | ||
make_metadata_layout( | ||
E_shared, | ||
mma_dtype="float16", | ||
arch="sm90", | ||
backend="cutlass", | ||
block_k=block_K), | ||
}) | ||
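                # Both the global metadata tensor E and its shared-memory staging
                # buffer E_shared are annotated with the cutlass sm90 metadata
                # layout; block_k matches the block_K tile consumed by the
                # T.gemm_sp call below.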
                # Loop over sub-blocks in K dimension, pipelined by num_stages
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                    # Load a sub-block of A_sparse (compressed, so half the K
                    # extent) from global memory into A_shared
                    T.copy(A_sparse[by * block_M, k * block_K // 2], A_shared)
                    # Load a sub-block of E from global memory into E_shared
                    T.copy(E[by * block_M, k * block_K // 8], E_shared)
                    # Load a sub-block of B from global memory into B_shared
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                    # Perform a partial sparse matrix multiplication:
                    #   C_local += A_shared @ B_shared^T
                    T.gemm_sp(
                        A_shared,
                        E_shared,
                        B_shared,
                        C_local,
                        transpose_B=True,
                        policy=policy,
                    )
                # Write back the results from C_local to the global memory C
                T.copy(C_local, C_shared)
                T.copy(C_shared, C[by * block_M, bx * block_N])

        return main

    return kernel()


if __name__ == "__main__":
    # Parse command-line arguments for matrix dimensions
    parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
    parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M")
    parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N")
    parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K")
    args = parser.parse_args()

    M, N, K = args.m, args.n, args.k

    # Compute total floating-point operations to measure throughput
    total_flops = 2 * M * N * K

    # matmul_sp(...) returns the autotuning result; the best latency and the
    # winning configuration are read off the result object
    best_result = matmul_sp(M, N, K)
    best_latency = best_result.latency
    best_config = best_result.config
    A = torch.randn(M, K, dtype=torch.float16, device="cuda")
    B = torch.randn(N, K, dtype=torch.float16, device="cuda")
    ref_latency = do_bench(lambda: A @ B.T)

    # Print out the benchmark results (latencies are in milliseconds)
    print(f"Best latency (ms): {best_latency}")
    print(f"Best TFlops: {total_flops / best_latency * 1e-9:.3f}")
    print(f"Best config: {best_config}")

    print(f"Reference TFlops: {total_flops / ref_latency * 1e-9:.3f}")
examples/sparse_tensorcore/test_example_sparse_tensorcore.py (15 additions, 0 deletions)
@@ -0,0 +1,15 @@
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
import tilelang.testing
import tilelang
import tilelang_example_sparse_tensorcore


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9, 0)
def test_tilelang_example_sparse_tensorcore():
    tilelang_example_sparse_tensorcore.main()


if __name__ == "__main__":
    tilelang.testing.main()
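
The two decorators restrict the test to CUDA devices with compute capability 9.0 (Hopper); elsewhere it is skipped rather than failed, matching the CI note in the commit list. Running the file directly (python examples/sparse_tensorcore/test_example_sparse_tensorcore.py) goes through tilelang.testing.main(), so it works both standalone and under the project's test runner.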
Review comment: The docstring mentions a with_roller parameter, which isn't in the function signature. Please correct the docstring or add the parameter.