[Enhancement] Support tf32 gemm_rs #607
First changed file (the elementwise-add example):

@@ -7,6 +7,8 @@
 import tilelang.language as T
 from tilelang.autotuner import AutoTuner
 
+tilelang.disable_cache()
+
 
 def ref_program(x, y):
     return x + y
@@ -19,12 +21,17 @@ def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
     def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor(
             (M, N), out_dtype)):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
-            start_x = bx * block_N
-            start_y = by * block_M
+            A_shared = T.alloc_shared((block_M, block_N), in_dtype)
+            B_shared = T.alloc_shared((block_M, block_N), in_dtype)
+            C_local = T.alloc_fragment((block_M, block_N), out_dtype)
+            C_shared = T.alloc_shared((block_M, block_N), out_dtype)
+
+            T.copy(A[by * block_M, bx * block_N], A_shared)
+            T.copy(B[by * block_M, bx * block_N], B_shared)
             for (local_y, local_x) in T.Parallel(block_M, block_N):
-                y = start_y + local_y
-                x = start_x + local_x
-                C[y, x] = A[y, x] + B[y, x]
+                C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]
+            T.copy(C_local, C_shared)
+            T.copy(C_shared, C[by * block_M, bx * block_N])
 
     return elem_add
@@ -56,7 +63,7 @@ def kernel(block_M=None, block_N=None, threads=None):
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--m", type=int, default=512)
+    parser.add_argument("--m", type=int, default=1024)
     parser.add_argument("--n", type=int, default=1024)
     parser.add_argument("--use_autotune", action="store_true", default=False)
     args, _ = parser.parse_known_args()
@@ -70,9 +77,8 @@ def main():
         kernel = result.kernel
     else:
         # Default config
-        config = {"block_M": 128, "block_N": 256, "threads": 128}
+        config = {"block_M": 128, "block_N": 128, "threads": 128}
         kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")
 
     out = kernel(a, b)
     torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)

Review comment on the default config: The block sizes are hardcoded. Consider adding command-line arguments for the block sizes.
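The suggestion above could look roughly like the following; this is a minimal sketch, assuming the example's existing `argparse` setup in `main()`, and the flag names `--block_m`, `--block_n`, and `--threads` are illustrative additions rather than part of this PR:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--m", type=int, default=1024)
parser.add_argument("--n", type=int, default=1024)
parser.add_argument("--use_autotune", action="store_true", default=False)
# Hypothetical flags so the tile shape can be changed without editing the script.
parser.add_argument("--block_m", type=int, default=128)
parser.add_argument("--block_n", type=int, default=128)
parser.add_argument("--threads", type=int, default=128)
args, _ = parser.parse_known_args()

# Build the default config from the flags instead of hardcoding it.
config = {"block_M": args.block_m, "block_N": args.block_n, "threads": args.threads}
```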
Second changed file (the C++ GEMM fragment layout definitions):

@@ -21,6 +21,15 @@ static IterVar make_itervar(std::string name, PrimExpr dom) {
   return IterVar(Range(0, dom), var, IterVarType::kDataPar);
 }
 
+Fragment makeGemmFragment8x4() {
+  IterVar i = make_itervar("i", 8);
+  IterVar j = make_itervar("j", 4);
+  IterVar rep = make_itervar("rep", 1);
+  PrimExpr forward_thread = FloorDiv(j->var, 1) + 4 * i;
+  PrimExpr index = FloorMod(j->var, 1);
+  return Fragment({i, j}, {index}, forward_thread, rep);
+}
+
 Fragment makeGemmFragment8x8() {
   IterVar i = make_itervar("i", 8);
   IterVar j = make_itervar("j", 8);

Review comment on lines +28 to +29: The expressions `FloorDiv(j->var, 1)` and `FloorMod(j->var, 1)` divide and take the remainder by 1, so they simplify to `j->var` and `0` respectively.
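As a quick sanity check of that comment, the mapping can be enumerated in plain Python (standing in for the TIR `FloorDiv`/`FloorMod` over this small integer domain); this is an illustration, not code from the PR:

```python
# Enumerate the 8x4 fragment used by makeGemmFragment8x4 and confirm that
# FloorDiv(j, 1) == j and FloorMod(j, 1) == 0 for every (i, j) in the tile,
# i.e. forward_thread reduces to 4 * i + j and the per-thread index to 0.
for i in range(8):
    for j in range(4):
        forward_thread = (j // 1) + 4 * i  # FloorDiv(j->var, 1) + 4 * i
        index = j % 1                      # FloorMod(j->var, 1)
        assert forward_thread == 4 * i + j
        assert index == 0
print("8x4 fragment: one element per thread, thread id = 4 * i + j")
```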
@@ -29,6 +38,25 @@
   PrimExpr index = FloorMod(j->var, 2);
   return Fragment({i, j}, {index}, forward_thread, rep);
 }
 
+Fragment makeGemmFragment8x16() {
+  IterVar i = make_itervar("i", 8);
+  IterVar j = make_itervar("j", 16);
+  IterVar rep = make_itervar("rep", 1);
+  PrimExpr forward_thread = FloorDiv(j->var, 4) + 4 * i;
+  PrimExpr index = FloorMod(j->var, 4);
+  return Fragment({i, j}, {index}, forward_thread, rep);
+}
+
+Fragment makeGemmFragment8x8Transposed() {
+  IterVar i = make_itervar("i", 8);
+  IterVar j = make_itervar("j", 8);
+  IterVar rep = make_itervar("rep", 1);
+  PrimExpr forward_thread = FloorDiv(i->var, 2) + 4 * j;
+  PrimExpr index = FloorMod(i->var, 2);
+  return Fragment({i, j}, {index}, forward_thread, rep);
+}
+
 /*
 From https://github.com/RadeonOpenCompute/amd_matrix_instruction_calculator
 ./matrix_calculator.py --architecture cdna1 --instruction v_mfma_f32_16x16x16f16
@@ -61,24 +89,6 @@ Fragment makeGemmFragmentC16x16CDNA() {
   return Fragment({i, j}, {index}, forward_thread, rep);
 }
 
-Fragment makeGemmFragment8x8Transposed() {
-  IterVar i = make_itervar("i", 8);
-  IterVar j = make_itervar("j", 8);
-  IterVar rep = make_itervar("rep", 1);
-  PrimExpr forward_thread = FloorDiv(i->var, 2) + 4 * j;
-  PrimExpr index = FloorMod(i->var, 2);
-  return Fragment({i, j}, {index}, forward_thread, rep);
-}
-
-Fragment makeGemmFragment8x16() {
-  IterVar i = make_itervar("i", 8);
-  IterVar j = make_itervar("j", 16);
-  IterVar rep = make_itervar("rep", 1);
-  PrimExpr forward_thread = FloorDiv(j->var, 4) + 4 * i;
-  PrimExpr index = FloorMod(j->var, 4);
-  return Fragment({i, j}, {index}, forward_thread, rep);
-}
-
 Fragment makeGemmFragmentC_F64(const int block_m, const int block_n,
                                const int warp_m, const int warp_n) {
   ICHECK(block_m % warp_m == 0);
@@ -150,8 +160,8 @@ Fragment makeGemmFragmentA(const int block_m, const int block_n,
   ICHECK(warp_m % 16 == 0);
   ICHECK(block_k % 16 == 0);
   // Only support 8-bit and 16-bit
-  ICHECK(element_size == 8 || element_size == 16)
-      << "element bitwidth=" << element_size;
+  ICHECK(element_size == 8 || element_size == 16 || element_size == 32)
+      << "unsupported element bitwidth=" << element_size;
 
   if (transposed) {
     auto base_layout =
@@ -176,6 +186,13 @@ Fragment makeGemmFragmentA(const int block_m, const int block_n,
     auto block_layout =
         warp_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
     return block_layout;
+  } else if (element_size == 32) {
+    auto base_layout = makeGemmFragment8x4()->Repeat({2, 2}, false, false);
+    auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)
+                           ->Replicate(block_n / warp_n);
+    auto block_layout =
+        warp_layout->Repeat({warp_m / 16, block_k / 8}, false, false);
+    return block_layout;
   } else {
     ICHECK(0);
     return Fragment();
Review comment on the elementwise kernel: Instead of copying `C_local` to `C_shared` and then to `C`, consider directly copying `C_local` to `C`. This eliminates the need for `C_shared` and reduces memory operations, potentially improving performance.
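For illustration, the simplification described in this comment might look like the sketch below. It assumes the `@T.prim_func` decorator used by the tilelang examples (not visible in the hunk) and that `T.copy` accepts a fragment-to-global copy here, as it does elsewhere in tilelang; it is not the code from this PR:

```python
import tilelang.language as T


def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):

    @T.prim_func
    def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype),
                 C: T.Tensor((M, N), out_dtype)):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), in_dtype)
            B_shared = T.alloc_shared((block_M, block_N), in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), out_dtype)

            T.copy(A[by * block_M, bx * block_N], A_shared)
            T.copy(B[by * block_M, bx * block_N], B_shared)
            for local_y, local_x in T.Parallel(block_M, block_N):
                C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]
            # Write the result fragment straight back to global memory,
            # skipping the extra C_shared staging buffer.
            T.copy(C_local, C[by * block_M, bx * block_N])

    return elem_add
```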