@@ -241,7 +241,7 @@ MatmulParam getMatmulParams(
241
241
return params;
242
242
}
243
243
244
- static void Nvfuser_Matmul_4warp (
244
+ static void Nvfuser_Matmul_4warp3stage (
245
245
benchmark::State& benchmark_state,
246
246
MatmulLayout layout) {
247
247
auto cta_tile = GemmTile (128 , 128 , 32 );
@@ -256,7 +256,7 @@ static void Nvfuser_Matmul_4warp(
256
256
SingleMatmulBase (benchmark_state, layout, params);
257
257
}
258
258
259
- static void Nvfuser_Matmul_8warp (
259
+ static void Nvfuser_Matmul_8warp3stage (
260
260
benchmark::State& benchmark_state,
261
261
MatmulLayout layout) {
262
262
auto cta_tile = GemmTile (256 , 128 , 32 );
@@ -271,6 +271,36 @@ static void Nvfuser_Matmul_8warp(
271
271
SingleMatmulBase (benchmark_state, layout, params);
272
272
}
273
273
274
+ static void Nvfuser_Matmul_4warp4stage (
275
+ benchmark::State& benchmark_state,
276
+ MatmulLayout layout) {
277
+ auto cta_tile = GemmTile (128 , 128 , 32 );
278
+ int number_of_stage = 4 ;
279
+
280
+ auto params = getMatmulParams (cta_tile, number_of_stage, layout);
281
+
282
+ NVFUSER_BENCHMARK_ARCH_SMEM_GUARD (
283
+ 8 , 0 , getSmemSize (cta_tile, number_of_stage), benchmark_state);
284
+
285
+ // Run benchmark:
286
+ SingleMatmulBase (benchmark_state, layout, params);
287
+ }
288
+
289
+ static void Nvfuser_Matmul_8warp4stage (
290
+ benchmark::State& benchmark_state,
291
+ MatmulLayout layout) {
292
+ auto cta_tile = GemmTile (256 , 128 , 32 );
293
+ int number_of_stage = 4 ;
294
+
295
+ auto params = getMatmulParams (cta_tile, number_of_stage, layout);
296
+
297
+ NVFUSER_BENCHMARK_ARCH_SMEM_GUARD (
298
+ 8 , 0 , getSmemSize (cta_tile, number_of_stage), benchmark_state);
299
+
300
+ // Run benchmark:
301
+ SingleMatmulBase (benchmark_state, layout, params);
302
+ }
303
+
274
304
// ----------------------------- Benchmark Instantiation-------
275
305
276
306
// Common utils:
@@ -286,21 +316,41 @@ static void Nvfuser_Matmul_8warp(
286
316
run (NT, MatmulLayout::NT)
287
317
288
318
// Instantiations:
289
- #define Nvfuser_4warp_test (layout_label, layout ) \
290
- BENCHMARK_CAPTURE ( \
291
- Nvfuser_Matmul_4warp, no_quant_nvfuser_4warp_##layout_label, layout) \
319
+ #define Nvfuser_4warp3stage_test (layout_label, layout ) \
320
+ BENCHMARK_CAPTURE ( \
321
+ Nvfuser_Matmul_4warp3stage, \
322
+ no_quant_nvfuser_4warp_##layout_label, \
323
+ layout) \
324
+ ->NO_TILE_QUANTIZATION_ARGS
325
+
326
+ #define Nvfuser_8warp3stage_test (layout_label, layout ) \
327
+ BENCHMARK_CAPTURE ( \
328
+ Nvfuser_Matmul_8warp3stage, \
329
+ no_quant_nvfuser_8warp_##layout_label, \
330
+ layout) \
331
+ ->NO_TILE_QUANTIZATION_ARGS
332
+
333
+ #define Nvfuser_4warp4stage_test (layout_label, layout ) \
334
+ BENCHMARK_CAPTURE ( \
335
+ Nvfuser_Matmul_4warp4stage, \
336
+ no_quant_nvfuser_4warp_##layout_label, \
337
+ layout) \
292
338
->NO_TILE_QUANTIZATION_ARGS
293
339
294
- #define Nvfuser_8warp_test (layout_label, layout ) \
295
- BENCHMARK_CAPTURE ( \
296
- Nvfuser_Matmul_8warp, no_quant_nvfuser_8warp_##layout_label, layout) \
340
+ #define Nvfuser_8warp4stage_test (layout_label, layout ) \
341
+ BENCHMARK_CAPTURE ( \
342
+ Nvfuser_Matmul_8warp4stage, \
343
+ no_quant_nvfuser_8warp_##layout_label, \
344
+ layout) \
297
345
->NO_TILE_QUANTIZATION_ARGS
298
346
299
347
#define Eagermode_test (layout_label, layout ) \
300
348
BENCHMARK_CAPTURE ( \
301
349
EagerModeMatmul, no_quant_eagermode_##layout_label, layout) \
302
350
->NO_TILE_QUANTIZATION_ARGS
303
351
304
- ForAllLayouts(Nvfuser_4warp_test);
305
- ForAllLayouts (Nvfuser_8warp_test);
352
+ ForAllLayouts(Nvfuser_4warp3stage_test);
353
+ ForAllLayouts (Nvfuser_4warp4stage_test);
354
+ ForAllLayouts (Nvfuser_8warp3stage_test);
355
+ ForAllLayouts (Nvfuser_8warp4stage_test);
306
356
ForAllLayouts (Eagermode_test);
0 commit comments