Commit 0f9f0b4

Add matmul benchmark (#2007)
1 parent 45045cd commit 0f9f0b4

File tree

  benchmarks/cpp/nvfuser/CMakeLists.txt
  benchmarks/cpp/nvfuser/matmul.cpp

2 files changed: +358 -0 lines changed

benchmarks/cpp/nvfuser/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ if(USE_CUDA)
   softmax_backward.cpp
   scale_bias_relu.cpp
   transpose.cpp
+  matmul.cpp
   timm.cpp
   utils.cpp
   main.cpp)

benchmarks/cpp/nvfuser/matmul.cpp

Lines changed: 357 additions & 0 deletions
@@ -0,0 +1,357 @@
#include <torch/csrc/jit/codegen/cuda/arith.h>
#include <torch/csrc/jit/codegen/cuda/executor.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
#include <torch/csrc/jit/codegen/cuda/ir_utils.h>
#include <torch/csrc/jit/codegen/cuda/ops/all_ops.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/all_schedulers.h>
#include <torch/csrc/jit/codegen/cuda/scheduler/matmul.h>

#include <benchmark/benchmark.h>

#include <cuda_runtime.h>

#include <benchmarks/cpp/nvfuser/utils.h>

using namespace torch::jit::fuser::cuda;

bool cudaArchGuardShouldSkip(int required_major, int required_minor) {
  int capability_major = at::cuda::getCurrentDeviceProperties()->major;
  int capability_minor = at::cuda::getCurrentDeviceProperties()->minor;

  if (capability_major < required_major ||
      (capability_major == required_major &&
       capability_minor < required_minor)) {
    return true;
  }
  return false;
}

bool hasRequiredSmemSize(size_t required_size) {
  // Only checking device 0
  return at::cuda::getDeviceProperties(0)->sharedMemPerBlockOptin >=
      required_size;
}

#define NVFUSER_BENCHMARK_ARCH_SMEM_GUARD(                         \
    REQUIRED_MAJOR, REQUIRED_MINOR, SMEM_SIZE, STATE)              \
  if (cudaArchGuardShouldSkip(REQUIRED_MAJOR, REQUIRED_MINOR) ||   \
      !hasRequiredSmemSize(SMEM_SIZE)) {                           \
    STATE.SkipWithError("Unsupported arch or not enough smem!");   \
    return;                                                        \
  }
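// Every benchmark below invokes this guard with (8, 0), i.e. compute
// capability 8.0 (Ampere) or newer, and with the shared memory footprint
// computed by getSmemSize(); on older devices, or when the requested smem
// exceeds sharedMemPerBlockOptin, the benchmark is skipped with an error.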

// Util to track supported matmul operand layouts.
using MatmulLayout = MmaOptions::MmaInputLayout;

static constexpr std::array<MatmulLayout, 3> kAllSupportedLayout = {
    MatmulLayout::TT,
    MatmulLayout::NT,
    MatmulLayout::TN};

// Generic interface to get matmul op with the given layout.
TensorView* matmul(TensorView* a, TensorView* b, MatmulLayout layout) {
  TORCH_CHECK(
      a->nDims() == 2 && b->nDims() == 2, "only pure matmuls for these tests");
  TensorView *tv2 = nullptr, *tv0b = nullptr, *tv1b = nullptr;
  switch (layout) {
    case MatmulLayout::TT:
      tv0b = broadcast(a, {false, false, true});
      tv1b = broadcast(b, {true, false, false});
      tv2 = fusedMultiplySum(tv0b, tv1b, {1});
      break;
    case MatmulLayout::TN:
      tv0b = broadcast(a, {false, true, false});
      tv1b = broadcast(b, {true, false, false});
      tv2 = fusedMultiplySum(tv0b, tv1b, {2});
      break;
    case MatmulLayout::NT:
      tv0b = broadcast(a, {false, false, true});
      tv1b = broadcast(b, {false, true, false});
      tv2 = fusedMultiplySum(tv0b, tv1b, {0});
      break;
    default:
      TORCH_CHECK(false, "unsupported data layout.");
  }
  return tv2;
}
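// In all three cases the operands are broadcast so that their K extents line
// up in a common 3-D shape, and fusedMultiplySum then reduces over that K
// axis (axis 1 for TT, 2 for TN, 0 for NT), yielding an [M, N] output.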

// Utility to generate reference results based on given layout
at::Tensor atMatmul(at::Tensor a, at::Tensor b, MatmulLayout layout) {
  switch (layout) {
    case MatmulLayout::TT:
      return a.matmul(b);
    case MatmulLayout::TN:
      return a.matmul(b.t());
    case MatmulLayout::NT:
      return a.t().matmul(b);
    default:
      TORCH_CHECK(false, "unsupported data layout.");
  }
  return at::Tensor();
}

// Utility to generate matmul input tensors based on given layout
std::pair<at::Tensor, at::Tensor> fp16MatmulAtInput(
    int M,
    int N,
    int K,
    MatmulLayout layout) {
  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);

  switch (layout) {
    case MatmulLayout::TT:
      return std::make_pair(
          at::randn({M, K}, options), at::randn({K, N}, options));
    case MatmulLayout::TN:
      return std::make_pair(
          at::randn({M, K}, options), at::randn({N, K}, options));
    case MatmulLayout::NT:
      return std::make_pair(
          at::randn({K, M}, options), at::randn({K, N}, options));
    default:
      TORCH_CHECK(false, "unsupported data layout.");
  }
  return std::make_pair(at::Tensor(), at::Tensor());
}

// TODO: separate compute and schedule definition once the canSchedule
// logic and pattern matching are ready.
void setupMatmul(Fusion* fusion, MatmulLayout layout, MatmulParam params) {
  // Only hgemm on the initial setup
  auto a = makeContigTensor(2, DataType::Half);
  auto b = makeContigTensor(2, DataType::Half);

  auto c = matmul(a, b, layout);

  fusion->addInput(a);
  fusion->addInput(b);
  fusion->addOutput(c);

  scheduleMatmul(c, a, b, params);
}

static void SingleMatmulBase(
    benchmark::State& benchmark_state,
    MatmulLayout layout,
    MatmulParam params) {
  std::vector<int64_t> input_mnk{
      benchmark_state.range(0),
      benchmark_state.range(1),
      benchmark_state.range(2)};

  auto fusion_ptr = std::make_unique<Fusion>();
  auto fusion = fusion_ptr.get();
  FusionGuard fg(fusion);

  // Define fusion graph
  setupMatmul(fusion, layout, params);

  // Inputs
  at::manual_seed(0);

  // Tensor inputs
  auto inputs = fp16MatmulAtInput(
      input_mnk.at(0), input_mnk.at(1), input_mnk.at(2), layout);

  KernelArgumentHolder args = KernelArgumentHolder::createKernelArgumentHolder(
      {inputs.first, inputs.second});

  // Always use 32b indexing mode for now.
  TORCH_INTERNAL_ASSERT(args.getIndexMode() == KernelIndexMode::INT32);

  // Compile kernel
  FusionExecutor fe;
  fe.compileFusion(fusion, args, LaunchParams());

  // Warm up run
  auto outputs = fe.runFusion({inputs.first, inputs.second});
  fe.setMeasureKernelTimeFlag(true);

  // Benchmark loop: clear L2 each iteration and record measured kernel time.
  for (auto _ : benchmark_state) {
    clearL2Cache();
    auto outputs = fe.runFusion({inputs.first, inputs.second});
    benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
  }
  // Sync everything up before we're finished, don't want to run ahead on the
  // CPU while benchmarking.
  cudaDeviceSynchronize();

  // TODO: FLOPS calculation
}
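// Note: kernel time (fe.kernelTimeMs(), in milliseconds) is converted to
// seconds for SetIterationTime(), which pairs with the UseManualTime()
// registration below so that only measured device time is reported.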

static void EagerModeMatmul(
    benchmark::State& benchmark_state,
    MatmulLayout layout) {
  std::vector<int64_t> input_mnk{
      benchmark_state.range(0),
      benchmark_state.range(1),
      benchmark_state.range(2)};

  at::manual_seed(0);

  auto inputs = fp16MatmulAtInput(
      input_mnk.at(0), input_mnk.at(1), input_mnk.at(2), layout);

  // Warm up run
  auto outputs = atMatmul(inputs.first, inputs.second, layout);

  for (auto _ : benchmark_state) {
    clearL2Cache();
    CudaKernelTimer timer;
    outputs = atMatmul(inputs.first, inputs.second, layout);
    benchmark_state.SetIterationTime(timer.elapsed() / 1000.0);
  }
  // Sync everything up before we're finished, don't want to run ahead on the
  // CPU while benchmarking.
  cudaDeviceSynchronize();
}

// Actual benchmarking
// -----------------------------------------------------------------

size_t getSmemSize(GemmTile cta_tile, int stage_number) {
  return ((cta_tile.m * cta_tile.k) + (cta_tile.n * cta_tile.k)) *
      dataTypeSize(DataType::Half) * stage_number;
}
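// Rough smem footprint of the operand tiles only (epilogue smem not counted):
// a 128x128x32 CTA tile with 3 stages needs (128*32 + 128*32) * 2 B * 3
// = 49152 B (48 KB); a 256x128x32 tile with 4 stages needs
// (256*32 + 128*32) * 2 B * 4 = 98304 B (96 KB), hence the smem guard above.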

// TODO: this part eventually will be automated by heuristics
MatmulParam getMatmulParams(
    GemmTile cta_tile,
    int stage_number,
    MatmulLayout layout) {
  MatMulTileOptions gemm_tile;
  gemm_tile.cta_tile = cta_tile;
  // TODO: pipe through split K
  gemm_tile.warp_tile = GemmTile(64, 64, cta_tile.k);
  gemm_tile.instruction_tile = GemmTile(16, 16, 16);

  // Collect mma swizzle info
  auto mma_builder =
      MmaBuilder(MmaOptions::MacroType::Ampere_16_16_16, gemm_tile)
          .layout(layout);

  MatmulParam params(mma_builder);
  params.tile_sizes = gemm_tile;
  params.async_gmem_load_operands = true;
  params.double_buffer_options.double_buffer_smem_write = true;
  params.double_buffer_options.double_buffer_smem_read = true;
  params.double_buffer_options.smem_double_buffer_stage = stage_number;

  return params;
}
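// With a fixed 64x64 warp tile, a 128x128 CTA tile maps to (128/64)*(128/64)
// = 4 warps and a 256x128 CTA tile to 8 warps, which is where the "4warp" /
// "8warp" names below come from.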

static void Nvfuser_Matmul_4warp3stage(
    benchmark::State& benchmark_state,
    MatmulLayout layout) {
  auto cta_tile = GemmTile(128, 128, 32);
  int number_of_stage = 3;

  auto params = getMatmulParams(cta_tile, number_of_stage, layout);

  NVFUSER_BENCHMARK_ARCH_SMEM_GUARD(
      8, 0, getSmemSize(cta_tile, number_of_stage), benchmark_state);

  // Run benchmark:
  SingleMatmulBase(benchmark_state, layout, params);
}

static void Nvfuser_Matmul_8warp3stage(
    benchmark::State& benchmark_state,
    MatmulLayout layout) {
  auto cta_tile = GemmTile(256, 128, 32);
  int number_of_stage = 3;

  auto params = getMatmulParams(cta_tile, number_of_stage, layout);

  NVFUSER_BENCHMARK_ARCH_SMEM_GUARD(
      8, 0, getSmemSize(cta_tile, number_of_stage), benchmark_state);

  // Run benchmark:
  SingleMatmulBase(benchmark_state, layout, params);
}

static void Nvfuser_Matmul_4warp4stage(
    benchmark::State& benchmark_state,
    MatmulLayout layout) {
  auto cta_tile = GemmTile(128, 128, 32);
  int number_of_stage = 4;

  auto params = getMatmulParams(cta_tile, number_of_stage, layout);

  NVFUSER_BENCHMARK_ARCH_SMEM_GUARD(
      8, 0, getSmemSize(cta_tile, number_of_stage), benchmark_state);

  // Run benchmark:
  SingleMatmulBase(benchmark_state, layout, params);
}

static void Nvfuser_Matmul_8warp4stage(
    benchmark::State& benchmark_state,
    MatmulLayout layout) {
  auto cta_tile = GemmTile(256, 128, 32);
  int number_of_stage = 4;

  auto params = getMatmulParams(cta_tile, number_of_stage, layout);

  NVFUSER_BENCHMARK_ARCH_SMEM_GUARD(
      8, 0, getSmemSize(cta_tile, number_of_stage), benchmark_state);

  // Run benchmark:
  SingleMatmulBase(benchmark_state, layout, params);
}

// ----------------------------- Benchmark Instantiation -------

// Common utils:
#define NO_TILE_QUANTIZATION_ARGS                                             \
  ArgsProduct(                                                                \
      {{2048}, {3456}, benchmark::CreateDenseRange(512, 4096, /*step=*/512)}) \
      ->Unit(benchmark::kMicrosecond)                                         \
      ->UseManualTime();
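// The sweep fixes M = 2048 and N = 3456 and varies K over 512, 1024, ...,
// 4096 (8 problem sizes per benchmark); these sizes are multiples of the CTA
// tiles used above, so there is no tile-quantization effect, hence the name.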

#define ForAllLayouts(run)   \
  run(TT, MatmulLayout::TT); \
  run(TN, MatmulLayout::TN); \
  run(NT, MatmulLayout::NT)

// Instantiations:
#define Nvfuser_4warp3stage_test(layout_label, layout) \
  BENCHMARK_CAPTURE(                                   \
      Nvfuser_Matmul_4warp3stage,                      \
      no_quant_nvfuser_4warp_##layout_label,           \
      layout)                                          \
      ->NO_TILE_QUANTIZATION_ARGS

#define Nvfuser_8warp3stage_test(layout_label, layout) \
  BENCHMARK_CAPTURE(                                   \
      Nvfuser_Matmul_8warp3stage,                      \
      no_quant_nvfuser_8warp_##layout_label,           \
      layout)                                          \
      ->NO_TILE_QUANTIZATION_ARGS

#define Nvfuser_4warp4stage_test(layout_label, layout) \
  BENCHMARK_CAPTURE(                                   \
      Nvfuser_Matmul_4warp4stage,                      \
      no_quant_nvfuser_4warp_##layout_label,           \
      layout)                                          \
      ->NO_TILE_QUANTIZATION_ARGS

#define Nvfuser_8warp4stage_test(layout_label, layout) \
  BENCHMARK_CAPTURE(                                   \
      Nvfuser_Matmul_8warp4stage,                      \
      no_quant_nvfuser_8warp_##layout_label,           \
      layout)                                          \
      ->NO_TILE_QUANTIZATION_ARGS

#define Eagermode_test(layout_label, layout)                      \
  BENCHMARK_CAPTURE(                                              \
      EagerModeMatmul, no_quant_eagermode_##layout_label, layout) \
      ->NO_TILE_QUANTIZATION_ARGS

ForAllLayouts(Nvfuser_4warp3stage_test);
ForAllLayouts(Nvfuser_4warp4stage_test);
ForAllLayouts(Nvfuser_8warp3stage_test);
ForAllLayouts(Nvfuser_8warp4stage_test);
ForAllLayouts(Eagermode_test);
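For reference, each ForAllLayouts(...) line above registers three benchmarks,
one per operand layout; expanding the first instantiation by hand for the TT
case gives roughly the following (names and arguments taken from the macros
above, modulo exact whitespace):

BENCHMARK_CAPTURE(
    Nvfuser_Matmul_4warp3stage,
    no_quant_nvfuser_4warp_TT,
    MatmulLayout::TT)
    ->ArgsProduct(
        {{2048}, {3456}, benchmark::CreateDenseRange(512, 4096, /*step=*/512)})
    ->Unit(benchmark::kMicrosecond)
    ->UseManualTime();

In total the file registers 5 benchmark families x 3 layouts x 8 K values
= 120 benchmark cases.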
