Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions examples/elementwise/example_elementwise_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import tilelang.language as T
from tilelang.autotuner import AutoTuner

tilelang.disable_cache()


def ref_program(x, y):
return x + y
Expand All @@ -19,12 +21,17 @@ def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor(
(M, N), out_dtype)):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
start_x = bx * block_N
start_y = by * block_M
A_shared = T.alloc_shared((block_M, block_N), in_dtype)
B_shared = T.alloc_shared((block_M, block_N), in_dtype)
C_local = T.alloc_fragment((block_M, block_N), out_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)

T.copy(A[by * block_M, bx * block_N], A_shared)
T.copy(B[by * block_M, bx * block_N], B_shared)
for (local_y, local_x) in T.Parallel(block_M, block_N):
y = start_y + local_y
x = start_x + local_x
C[y, x] = A[y, x] + B[y, x]
C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
Comment on lines +33 to +34
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Instead of copying C_local to C_shared and then to C, consider directly copying C_local to C. This eliminates the need for C_shared and reduces memory operations, potentially improving performance.

Suggested change
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
A_shared = T.alloc_shared((block_M, block_N), in_dtype)
B_shared = T.alloc_shared((block_M, block_N), in_dtype)
C_local = T.alloc_fragment((block_M, block_N), out_dtype)
T.copy(A[by * block_M, bx * block_N], A_shared)
T.copy(B[by * block_M, bx * block_N], B_shared)
for (local_y, local_x) in T.Parallel(block_M, block_N):
C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]
T.copy(C_local, C[by * block_M, bx * block_N])


return elem_add

Expand Down Expand Up @@ -56,7 +63,7 @@ def kernel(block_M=None, block_N=None, threads=None):

def main():
parser = argparse.ArgumentParser()
parser.add_argument("--m", type=int, default=512)
parser.add_argument("--m", type=int, default=1024)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider using a more descriptive variable name than m for matrix dimension. Using matrix_size or num_rows would improve readability.

Suggested change
parser.add_argument("--m", type=int, default=1024)
parser.add_argument("--matrix_size", type=int, default=1024)

parser.add_argument("--n", type=int, default=1024)
parser.add_argument("--use_autotune", action="store_true", default=False)
args, _ = parser.parse_known_args()
Expand All @@ -70,9 +77,8 @@ def main():
kernel = result.kernel
else:
# Default config
config = {"block_M": 128, "block_N": 256, "threads": 128}
config = {"block_M": 128, "block_N": 128, "threads": 128}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The block sizes are hardcoded. Consider adding command-line arguments for block_M and block_N to allow users to experiment with different block sizes without modifying the code.

Suggested change
config = {"block_M": 128, "block_N": 128, "threads": 128}
parser.add_argument("--block_M", type=int, default=128)
parser.add_argument("--block_N", type=int, default=128)
config = {"block_M": args.block_M, "block_N": args.block_N, "threads": 128}

kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")

out = kernel(a, b)
torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)

Expand Down
4 changes: 2 additions & 2 deletions examples/quickstart.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the MIT License.
import tilelang
import tilelang.language as T

# `make_mma_swizzle_layout` is a python defined layout function
# specifically designed for MMA operations
# which ensures the consistency with the nvidia CUTLASS Library.
Expand Down Expand Up @@ -73,8 +74,7 @@ def main(
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", execution_backend="cython")
# jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", execution_backend="dlpack")
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")

# 3. Test the kernel in Python with PyTorch data
import torch
Expand Down
57 changes: 37 additions & 20 deletions src/layout/gemm_layouts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ static IterVar make_itervar(std::string name, PrimExpr dom) {
return IterVar(Range(0, dom), var, IterVarType::kDataPar);
}

// 8x4 fragment layout: thread (4*i + j) owns exactly one element, so the
// per-thread local index is always 0 (rep extent 1 — no replication).
Fragment makeGemmFragment8x4() {
  IterVar i = make_itervar("i", 8);
  IterVar j = make_itervar("j", 4);
  IterVar rep = make_itervar("rep", 1);
  // FloorDiv(j, 1) is just j and FloorMod(j, 1) is 0 — write the simplified
  // forms directly instead of relying on a later simplification pass.
  PrimExpr forward_thread = j->var + 4 * i;
  PrimExpr index = 0;
  return Fragment({i, j}, {index}, forward_thread, rep);
}

Fragment makeGemmFragment8x8() {
IterVar i = make_itervar("i", 8);
IterVar j = make_itervar("j", 8);
Expand All @@ -29,6 +38,25 @@ Fragment makeGemmFragment8x8() {
PrimExpr index = FloorMod(j->var, 2);
return Fragment({i, j}, {index}, forward_thread, rep);
}

// 8x16 fragment layout: thread id is FloorDiv(col, 4) + 4 * row, and each
// thread holds the four consecutive elements (col % 4) along the j axis.
Fragment makeGemmFragment8x16() {
  IterVar row = make_itervar("i", 8);
  IterVar col = make_itervar("j", 16);
  IterVar rep = make_itervar("rep", 1);
  PrimExpr thread_id = FloorDiv(col->var, 4) + 4 * row;
  PrimExpr local_idx = FloorMod(col->var, 4);
  return Fragment({row, col}, {local_idx}, thread_id, rep);
}

// Transposed 8x8 fragment layout: thread id is FloorDiv(row, 2) + 4 * col
// (the i/j roles of makeGemmFragment8x8 swapped); each thread holds the two
// elements (row % 2) along the i axis.
Fragment makeGemmFragment8x8Transposed() {
  IterVar row = make_itervar("i", 8);
  IterVar col = make_itervar("j", 8);
  IterVar rep = make_itervar("rep", 1);
  PrimExpr thread_id = FloorDiv(row->var, 2) + 4 * col;
  PrimExpr local_idx = FloorMod(row->var, 2);
  return Fragment({row, col}, {local_idx}, thread_id, rep);
}

/*
From https://github.com/RadeonOpenCompute/amd_matrix_instruction_calculator
./matrix_calculator.py --architecture cdna1 --instruction v_mfma_f32_16x16x16f16
Expand Down Expand Up @@ -61,24 +89,6 @@ Fragment makeGemmFragmentC16x16CDNA() {
return Fragment({i, j}, {index}, forward_thread, rep);
}

// Transposed 8x8 fragment layout: thread id is FloorDiv(i, 2) + 4 * j, and
// each thread holds the two elements (i % 2) along the i axis.
Fragment makeGemmFragment8x8Transposed() {
IterVar i = make_itervar("i", 8);
IterVar j = make_itervar("j", 8);
// rep has extent 1: no replication of the fragment across threads.
IterVar rep = make_itervar("rep", 1);
PrimExpr forward_thread = FloorDiv(i->var, 2) + 4 * j;
PrimExpr index = FloorMod(i->var, 2);
return Fragment({i, j}, {index}, forward_thread, rep);
}

// 8x16 fragment layout: thread id is FloorDiv(j, 4) + 4 * i, and each
// thread holds the four consecutive elements (j % 4) along the j axis.
Fragment makeGemmFragment8x16() {
IterVar i = make_itervar("i", 8);
IterVar j = make_itervar("j", 16);
// rep has extent 1: no replication of the fragment across threads.
IterVar rep = make_itervar("rep", 1);
PrimExpr forward_thread = FloorDiv(j->var, 4) + 4 * i;
PrimExpr index = FloorMod(j->var, 4);
return Fragment({i, j}, {index}, forward_thread, rep);
}

Fragment makeGemmFragmentC_F64(const int block_m, const int block_n,
const int warp_m, const int warp_n) {
ICHECK(block_m % warp_m == 0);
Expand Down Expand Up @@ -150,8 +160,8 @@ Fragment makeGemmFragmentA(const int block_m, const int block_n,
ICHECK(warp_m % 16 == 0);
ICHECK(block_k % 16 == 0);
// Only support 8-bit and 16-bit
ICHECK(element_size == 8 || element_size == 16)
<< "element bitwidth=" << element_size;
ICHECK(element_size == 8 || element_size == 16 || element_size == 32)
<< "unsupported element bitwidth=" << element_size;

if (transposed) {
auto base_layout =
Expand All @@ -176,6 +186,13 @@ Fragment makeGemmFragmentA(const int block_m, const int block_n,
auto block_layout =
warp_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
return block_layout;
} else if (element_size == 32) {
auto base_layout = makeGemmFragment8x4()->Repeat({2, 2}, false, false);
auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)
->Replicate(block_n / warp_n);
auto block_layout =
warp_layout->Repeat({warp_m / 16, block_k / 8}, false, false);
return block_layout;
} else {
ICHECK(0);
return Fragment();
Expand Down
Loading