
Commit 4d8daea

more benchmark and test extension

1 parent: 2efe1b6

2 files changed: +113, -10 lines

benchmarks/cpp/nvfuser/matmul.cpp

Lines changed: 60 additions & 10 deletions
@@ -241,7 +241,7 @@ MatmulParam getMatmulParams(
   return params;
 }
 
-static void Nvfuser_Matmul_4warp(
+static void Nvfuser_Matmul_4warp3stage(
     benchmark::State& benchmark_state,
     MatmulLayout layout) {
   auto cta_tile = GemmTile(128, 128, 32);
@@ -256,7 +256,7 @@ static void Nvfuser_Matmul_4warp(
   SingleMatmulBase(benchmark_state, layout, params);
 }
 
-static void Nvfuser_Matmul_8warp(
+static void Nvfuser_Matmul_8warp3stage(
     benchmark::State& benchmark_state,
     MatmulLayout layout) {
   auto cta_tile = GemmTile(256, 128, 32);
@@ -271,6 +271,36 @@ static void Nvfuser_Matmul_8warp(
   SingleMatmulBase(benchmark_state, layout, params);
 }
 
+static void Nvfuser_Matmul_4warp4stage(
+    benchmark::State& benchmark_state,
+    MatmulLayout layout) {
+  auto cta_tile = GemmTile(128, 128, 32);
+  int number_of_stage = 4;
+
+  auto params = getMatmulParams(cta_tile, number_of_stage, layout);
+
+  NVFUSER_BENCHMARK_ARCH_SMEM_GUARD(
+      8, 0, getSmemSize(cta_tile, number_of_stage), benchmark_state);
+
+  // Run benchmark:
+  SingleMatmulBase(benchmark_state, layout, params);
+}
+
+static void Nvfuser_Matmul_8warp4stage(
+    benchmark::State& benchmark_state,
+    MatmulLayout layout) {
+  auto cta_tile = GemmTile(256, 128, 32);
+  int number_of_stage = 4;
+
+  auto params = getMatmulParams(cta_tile, number_of_stage, layout);
+
+  NVFUSER_BENCHMARK_ARCH_SMEM_GUARD(
+      8, 0, getSmemSize(cta_tile, number_of_stage), benchmark_state);
+
+  // Run benchmark:
+  SingleMatmulBase(benchmark_state, layout, params);
+}
+
 // ----------------------------- Benchmark Instantiation-------
 
 // Common utils:
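An aside on the NVFUSER_BENCHMARK_ARCH_SMEM_GUARD(8, 0, ...) calls above: they skip the benchmark unless the device is at least sm_80 and offers enough shared memory for the requested pipeline depth. A back-of-the-envelope sketch of why the 4-stage variants need the guard, assuming getSmemSize (whose definition is not part of this diff) counts just the two circularly buffered half-precision operand tiles:

#include <cstdio>

// Hedged sketch: approximate smem footprint of a multi-stage matmul
// pipeline, assuming only the A and B operand tiles are staged in
// shared memory as fp16 (2 bytes per element). This is NOT the actual
// getSmemSize from the benchmark, just the dominant term.
int approxSmemBytes(int m, int n, int k, int stages) {
  int elems_per_stage = m * k + n * k; // one A tile + one B tile
  return stages * elems_per_stage * 2; // 2 bytes per half
}

int main() {
  // 4-warp tile (128x128x32), 4 stages: 65536 bytes = 64 KB.
  std::printf("4warp4stage: %d bytes\n", approxSmemBytes(128, 128, 32, 4));
  // 8-warp tile (256x128x32), 4 stages: 98304 bytes = 96 KB.
  std::printf("8warp4stage: %d bytes\n", approxSmemBytes(256, 128, 32, 4));
}

Under that assumption the 8-warp, 4-stage configuration needs roughly 96 KB of shared memory per CTA, well beyond the 48 KB default limit and only reachable with the dynamic shared-memory opt-in on sm_80-class devices, which is what the guard checks for.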
@@ -286,21 +316,41 @@ static void Nvfuser_Matmul_8warp(
   run(NT, MatmulLayout::NT)
 
 // Instantiations:
-#define Nvfuser_4warp_test(layout_label, layout)                           \
-  BENCHMARK_CAPTURE(                                                       \
-      Nvfuser_Matmul_4warp, no_quant_nvfuser_4warp_##layout_label, layout) \
+#define Nvfuser_4warp3stage_test(layout_label, layout) \
+  BENCHMARK_CAPTURE(                                   \
+      Nvfuser_Matmul_4warp3stage,                      \
+      no_quant_nvfuser_4warp_##layout_label,           \
+      layout)                                          \
+      ->NO_TILE_QUANTIZATION_ARGS
+
+#define Nvfuser_8warp3stage_test(layout_label, layout) \
+  BENCHMARK_CAPTURE(                                   \
+      Nvfuser_Matmul_8warp3stage,                      \
+      no_quant_nvfuser_8warp_##layout_label,           \
+      layout)                                          \
+      ->NO_TILE_QUANTIZATION_ARGS
+
+#define Nvfuser_4warp4stage_test(layout_label, layout) \
+  BENCHMARK_CAPTURE(                                   \
+      Nvfuser_Matmul_4warp4stage,                      \
+      no_quant_nvfuser_4warp_##layout_label,           \
+      layout)                                          \
       ->NO_TILE_QUANTIZATION_ARGS
 
-#define Nvfuser_8warp_test(layout_label, layout)                           \
-  BENCHMARK_CAPTURE(                                                       \
-      Nvfuser_Matmul_8warp, no_quant_nvfuser_8warp_##layout_label, layout) \
+#define Nvfuser_8warp4stage_test(layout_label, layout) \
+  BENCHMARK_CAPTURE(                                   \
+      Nvfuser_Matmul_8warp4stage,                      \
+      no_quant_nvfuser_8warp_##layout_label,           \
+      layout)                                          \
       ->NO_TILE_QUANTIZATION_ARGS
 
 #define Eagermode_test(layout_label, layout)                      \
   BENCHMARK_CAPTURE(                                              \
       EagerModeMatmul, no_quant_eagermode_##layout_label, layout) \
       ->NO_TILE_QUANTIZATION_ARGS
 
-ForAllLayouts(Nvfuser_4warp_test);
-ForAllLayouts(Nvfuser_8warp_test);
+ForAllLayouts(Nvfuser_4warp3stage_test);
+ForAllLayouts(Nvfuser_4warp4stage_test);
+ForAllLayouts(Nvfuser_8warp3stage_test);
+ForAllLayouts(Nvfuser_8warp4stage_test);
 ForAllLayouts(Eagermode_test);
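For reference, a sketch of how these instantiation macros compose, assuming ForAllLayouts is the usual stamp-out helper in this file (only its run(NT, MatmulLayout::NT) arm is visible in the hunk context above; the TT and TN arms are an assumption):

// Assumed shape of the helper, extrapolated from the visible NT arm:
#define ForAllLayouts(run)   \
  run(TT, MatmulLayout::TT); \
  run(TN, MatmulLayout::TN); \
  run(NT, MatmulLayout::NT)

// So ForAllLayouts(Nvfuser_4warp4stage_test) expands into three
// google-benchmark registrations, e.g. for the TT layout:
//   BENCHMARK_CAPTURE(
//       Nvfuser_Matmul_4warp4stage,
//       no_quant_nvfuser_4warp_TT,
//       MatmulLayout::TT)
//       ->NO_TILE_QUANTIZATION_ARGS;

Each new pipeline-depth variant therefore adds one benchmark per operand layout, which is why the four ForAllLayouts lines at the end of the hunk cover the full 4/8-warp by 3/4-stage grid.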

torch/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp

Lines changed: 53 additions & 0 deletions
@@ -3056,6 +3056,59 @@ TEST_F(NVFuserTest, FusionTuringMatmulLargeLoad_CUDA) {
   }
 }
 
+// Matmul test on Ampere using ldmatrix.x4 to load operands
+TEST_F(NVFuserTest, FusionAmpereMatmulLargeLoadLargeK_CUDA) {
+  // Keep multiples of 8 to keep vectorizable.
+  int M = 504, N = 136, K = 2048;
+  for (auto layout : kAllSupportedLayout) {
+    Fusion fusion;
+    FusionGuard fg(&fusion);
+    auto tv0 = makeContigTensor(2, DataType::Half);
+    auto tv1 = makeContigTensor(2, DataType::Half);
+
+    fusion.addInput(tv0);
+    fusion.addInput(tv1);
+
+    auto tv2 = matmul(tv0, tv1, layout);
+
+    fusion.addOutput(tv2);
+
+    MatMulTileOptions gemm_tile;
+    gemm_tile.cta_tile = GemmTile(128, 128, 64);
+    gemm_tile.warp_tile = GemmTile(64, 64, 64);
+    gemm_tile.instruction_tile = GemmTile(16, 16, 16);
+
+    auto mma_builder =
+        MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile)
+            .layout(layout);
+
+    MatmulParam params(mma_builder);
+    params.tile_sizes = gemm_tile;
+    params.async_gmem_load_operands = true;
+    params.double_buffer_options.double_buffer_smem_write = true;
+    params.double_buffer_options.double_buffer_smem_read = true;
+    params.double_buffer_options.smem_double_buffer_stage = 3;
+    scheduleMatmul(tv2, tv0, tv1, params);
+
+    at::manual_seed(0);
+    auto inputs = fp16MatmulAtInput(M, N, K, layout);
+
+    CompileOptions co;
+    co.index_mode = KernelIndexMode::INT32;
+
+    FusionExecutor fe;
+    NVFUSER_TEST_CUDA_ARCH_COMPILE_CHECK(
+        8,
+        0,
+        fe.compileFusion(
+            &fusion, {inputs.first, inputs.second}, LaunchParams(), co));
+    auto cg_outputs = fe.runFusion({inputs.first, inputs.second});
+    auto tref = atMatmul(
+        inputs.first.to(at::kFloat), inputs.second.to(at::kFloat), layout);
+    TORCH_CHECK(cg_outputs[0].allclose(tref, 0.001, 0.001));
+  }
+}
+
 #undef NVFUSER_TEST_CUDA_ARCH_GUARD
 
 } // namespace jit
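The double_buffer_options set in this test (async gmem loads, smem double buffering on both write and read, smem_double_buffer_stage = 3) ask the scheduler to keep several K tiles in flight at once. The code nvfuser actually generates is far more involved, but a minimal hand-written CUDA sketch of the same multi-stage circular-buffering pattern, with illustrative names and tile sizes and a plain float payload standing in for the fp16 operands, looks like this:

#include <cuda_pipeline.h>

constexpr int kStages = 3; // matches smem_double_buffer_stage = 3
constexpr int kTileK = 32; // illustrative tile width, one element per thread

// Launch as: stagedReduceK<<<1, kTileK>>>(gmem, out, k_tiles);
__global__ void stagedReduceK(const float* gmem, float* out, int k_tiles) {
  __shared__ float smem[kStages][kTileK];
  const int tid = threadIdx.x;

  // Prologue: asynchronously prefetch the first kStages - 1 tiles,
  // committing one (possibly empty) batch per stage so the batch
  // count stays uniform even for tiny k_tiles.
  for (int s = 0; s < kStages - 1; ++s) {
    if (s < k_tiles) {
      __pipeline_memcpy_async(
          &smem[s][tid], &gmem[s * kTileK + tid], sizeof(float));
    }
    __pipeline_commit();
  }

  float acc = 0.f;
  for (int k = 0; k < k_tiles; ++k) {
    // Block until tile k has landed; up to kStages - 2 younger
    // copies stay in flight behind the compute.
    __pipeline_wait_prior(kStages - 2);
    __syncthreads();

    acc += smem[k % kStages][tid]; // stand-in for the tensor-core math

    // Refill the slot freed at the end of the previous iteration with
    // tile k + kStages - 1; the barrier above guarantees no thread is
    // still reading it.
    const int next = k + kStages - 1;
    if (next < k_tiles) {
      __pipeline_memcpy_async(
          &smem[next % kStages][tid],
          &gmem[next * kTileK + tid],
          sizeof(float));
    }
    __pipeline_commit(); // one batch per iteration, even when empty
  }
  out[tid] = acc;
}

Each iteration consumes the oldest staged tile while the younger cp.async copies are still outstanding, which is what hides global-memory latency behind the math; bumping the benchmark variants from 3 to 4 stages deepens exactly this queue at the cost of the extra shared memory estimated earlier.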
