
Commit 51589d3

Some cleanups on tests and heuristics params (#1866)
Refactor heuristics params to make them more extensible
1 parent a6b3e70 commit 51589d3

25 files changed: +458 −581 lines

benchmarks/cpp/nvfuser/bert.cpp

Lines changed: 18 additions & 18 deletions
@@ -133,8 +133,8 @@ static void MagicScheduler_DivMaxSoftDropFwd(
   std::vector<at::Tensor> cg_outputs;

   auto norm_params = getPersistentHeuristics(&fusion, at_inputs);
-  TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
-  schedulePersistentKernel(&fusion, norm_params.value());
+  TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
+  schedulePersistentKernel(&fusion, *norm_params);

   FusionExecutor fe;
   fe.compileFusion(&fusion);
@@ -143,7 +143,7 @@ static void MagicScheduler_DivMaxSoftDropFwd(
   cudaDeviceSynchronize();
   for (auto _ : benchmark_state) {
     CudaKernelTimer timer;
-    cg_outputs = fe.runFusion({t0, t1}, norm_params.value().lparams);
+    cg_outputs = fe.runFusion({t0, t1}, norm_params->lparams);
     benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
   }
   // Sync everything up before we're finished, don't want to run ahead on the
@@ -193,8 +193,8 @@ static void MagicScheduler_DivMaxSoftDropBwd(
   std::vector<at::Tensor> cg_outputs;

   auto norm_params = getPersistentHeuristics(&fusion, at_inputs);
-  TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
-  schedulePersistentKernel(&fusion, norm_params.value());
+  TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
+  schedulePersistentKernel(&fusion, *norm_params);

   FusionExecutor fe;
   fe.compileFusion(&fusion);
@@ -203,7 +203,7 @@ static void MagicScheduler_DivMaxSoftDropBwd(
   cudaDeviceSynchronize();
   for (auto _ : benchmark_state) {
     CudaKernelTimer timer;
-    cg_outputs = fe.runFusion({t0, t1, t2, t3}, norm_params.value().lparams);
+    cg_outputs = fe.runFusion({t0, t1, t2, t3}, norm_params->lparams);
     benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
   }
   // Sync everything up before we're finished, don't want to run ahead on the
@@ -308,8 +308,8 @@ static void MagicScheduler_BiasDropoutAddLayernormFwd(
   std::vector<at::Tensor> cg_outputs;

   auto norm_params = getPersistentHeuristics(&fusion, at_inputs);
-  TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
-  schedulePersistentKernel(&fusion, norm_params.value());
+  TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
+  schedulePersistentKernel(&fusion, *norm_params);

   FusionExecutor fe;
   fe.compileFusion(&fusion);
@@ -319,7 +319,7 @@ static void MagicScheduler_BiasDropoutAddLayernormFwd(
   cudaDeviceSynchronize();
   for (auto _ : benchmark_state) {
     CudaKernelTimer timer;
-    cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
+    cg_outputs = fe.runFusion(at_inputs, norm_params->lparams);
     benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
   }
   // Sync everything up before we're finished, don't want to run ahead on the
@@ -423,8 +423,8 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd1(
   std::vector<at::Tensor> cg_outputs;

   auto norm_params = getReductionHeuristics(&fusion, at_inputs);
-  TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
-  scheduleReduction(&fusion, norm_params.value());
+  TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
+  scheduleReduction(&fusion, *norm_params);

   FusionExecutor fe;
   fe.compileFusion(&fusion);
@@ -434,7 +434,7 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd1(
   cudaDeviceSynchronize();
   for (auto _ : benchmark_state) {
     clearL2Cache();
-    cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
+    cg_outputs = fe.runFusion(at_inputs, norm_params->lparams);
     benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
   }
   // Sync everything up before we're finished, don't want to run ahead on the
@@ -534,8 +534,8 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd2(
   std::vector<at::Tensor> cg_outputs;

   auto norm_params = getPersistentHeuristics(&fusion, at_inputs);
-  TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
-  schedulePersistentKernel(&fusion, norm_params.value());
+  TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
+  schedulePersistentKernel(&fusion, *norm_params);

   FusionExecutor fe;
   fe.compileFusion(&fusion);
@@ -545,7 +545,7 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd2(
   cudaDeviceSynchronize();
   for (auto _ : benchmark_state) {
     CudaKernelTimer timer;
-    cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
+    cg_outputs = fe.runFusion(at_inputs, norm_params->lparams);
     benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
   }
   // Sync everything up before we're finished, don't want to run ahead on the
@@ -625,8 +625,8 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd3(
   std::vector<at::Tensor> cg_outputs;

   auto norm_params = getReductionHeuristics(&fusion, at_inputs);
-  TORCH_CHECK(norm_params.has_value(), "Norm scheduler can't be used!");
-  scheduleReduction(&fusion, norm_params.value());
+  TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
+  scheduleReduction(&fusion, *norm_params);

   FusionExecutor fe;
   fe.compileFusion(&fusion);
@@ -636,7 +636,7 @@ static void MagicScheduler_BiasDropoutAddLayernormBwd3(
   cudaDeviceSynchronize();
   for (auto _ : benchmark_state) {
     CudaKernelTimer timer;
-    cg_outputs = fe.runFusion(at_inputs, norm_params.value().lparams);
+    cg_outputs = fe.runFusion(at_inputs, norm_params->lparams);
     benchmark_state.SetIterationTime(fe.kernelTimeMs() / 1000.0);
   }
   // Sync everything up before we're finished, don't want to run ahead on the
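
Note: every hunk in this file is the same mechanical change. getPersistentHeuristics and getReductionHeuristics now return a nullable pointer (a shared-pointer-style handle, judging by the nullptr check and dereference) instead of a c10::optional, so call sites test against nullptr, dereference to schedule, and reach launch parameters through ->. A condensed sketch of the updated pattern, with the benchmark setup elided:

  // Sketch only; assumes the heuristics call returns a null handle when the
  // persistent scheduler cannot take this fusion.
  auto norm_params = getPersistentHeuristics(&fusion, at_inputs);
  TORCH_CHECK(norm_params != nullptr, "Norm scheduler can't be used!");
  schedulePersistentKernel(&fusion, *norm_params);

  FusionExecutor fe;
  fe.compileFusion(&fusion);
  // Launch parameters now live behind the pointer:
  auto cg_outputs = fe.runFusion(at_inputs, norm_params->lparams);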

benchmarks/cpp/nvfuser/broadcast.cpp

Lines changed: 1 addition & 2 deletions
@@ -69,8 +69,7 @@ static void NvFuserScheduler_Broadcast(

   auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
   auto executor_instance = compile_log.fusion_executor;
-  TORCH_INTERNAL_ASSERT(compile_log.pointwise_params.has_value());
-  auto params = toString(compile_log.pointwise_params.value());
+  auto params = toString(compile_log.params);
   auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());

   benchmark_state.SetLabel(params + lparams);

benchmarks/cpp/nvfuser/reduction.cpp

Lines changed: 1 addition & 2 deletions
@@ -65,8 +65,7 @@ static void NvFuserScheduler_Reduction(

   auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
   auto executor_instance = compile_log.fusion_executor;
-  TORCH_INTERNAL_ASSERT(compile_log.reduction_params.has_value());
-  auto rparams = toString(compile_log.reduction_params.value());
+  auto rparams = toString(compile_log.params);
   auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());

   benchmark_state.SetLabel(rparams + lparams);

benchmarks/cpp/nvfuser/scale_bias_relu.cpp

Lines changed: 2 additions & 4 deletions
@@ -135,8 +135,7 @@ static void NvFuserScheduler_SBR(

   auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
   auto executor_instance = compile_log.fusion_executor;
-  TORCH_INTERNAL_ASSERT(compile_log.pointwise_params.has_value());
-  auto params = toString(compile_log.pointwise_params.value());
+  auto params = toString(compile_log.params);
   auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());

   benchmark_state.SetLabel(params + lparams);
@@ -238,8 +237,7 @@ static void NvFuserScheduler_SBR_Norm(

   auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
   auto executor_instance = compile_log.fusion_executor;
-  TORCH_INTERNAL_ASSERT(compile_log.pointwise_params.has_value());
-  auto params = toString(compile_log.pointwise_params.value());
+  auto params = toString(compile_log.params);
   auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());

   benchmark_state.SetLabel(params + lparams);

benchmarks/cpp/nvfuser/utils.cpp

Lines changed: 18 additions & 12 deletions
@@ -89,6 +89,20 @@ std::string toString(PointwiseParams params) {
   return ss.str();
 }

+std::string toString(const std::shared_ptr<HeuristicParams>& params) {
+  auto rparams = std::dynamic_pointer_cast<ReductionParams>(params);
+  if (rparams) {
+    return toString(*rparams);
+  }
+  auto pparams = std::dynamic_pointer_cast<PointwiseParams>(params);
+  if (pparams) {
+    return toString(*pparams);
+  }
+  TORCH_INTERNAL_ASSERT(
+      false,
+      "Unknown heuristic parameter type. Did you add a new heuristic parameter type but forget to update this function?");
+}
+
 std::string toString(LaunchParams lparams) {
   std::stringstream ss;
   lparams.toString();
@@ -123,9 +137,7 @@ TensorView* makeContigTensor(size_t ndims, DataType dtype) {
       .build();
 }

-TensorView* makeConcreteTensor(
-    std::vector<int64_t> shape,
-    DataType dtype) {
+TensorView* makeConcreteTensor(std::vector<int64_t> shape, DataType dtype) {
   return TensorViewBuilder().shape(shape).dtype(dtype).build();
 }

@@ -157,15 +169,9 @@ void runBenchmarkIterations(
   auto compile_log = fusion_executor_cache->getMostRecentExecutorInfo();
   auto executor_instance = compile_log.fusion_executor;

-  if (compile_log.reduction_params.has_value()) {
-    auto rparams = toString(compile_log.reduction_params.value());
-    auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());
-    benchmark_state.SetLabel(rparams + lparams);
-  } else if (compile_log.pointwise_params.has_value()) {
-    auto pparams = toString(compile_log.pointwise_params.value());
-    auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());
-    benchmark_state.SetLabel(pparams + lparams);
-  }
+  auto params = toString(compile_log.params);
+  auto lparams = toString(compile_log.fusion_executor->lastLaunchParams());
+  benchmark_state.SetLabel(params + lparams);

   executor_instance->setMeasureKernelTimeFlag(true);
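
Note on the new toString overload above: it dispatches on the dynamic type with std::dynamic_pointer_cast instead of calling a virtual printer, because these benchmark-side toString helpers take the concrete parameter types. The cost is that every future HeuristicParams subclass needs its own branch, which is exactly the omission the closing TORCH_INTERNAL_ASSERT is there to catch. A hypothetical extension would look like the following (MatmulParams is an invented placeholder, not a type in this commit):

  // Hypothetical only: one more branch for a new heuristic type.
  auto mparams = std::dynamic_pointer_cast<MatmulParams>(params);
  if (mparams) {
    return toString(*mparams);
  }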

benchmarks/cpp/nvfuser/utils.h

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ TensorView* makeContigConcreteTensor(

 std::string toString(ReductionParams rparams);
 std::string toString(PointwiseParams params);
+std::string toString(const std::shared_ptr<HeuristicParams>& params);
 std::string toString(LaunchParams lparams);

 // Run benchmark iterations with provided inputs. If not segmented, report

torch/csrc/jit/codegen/cuda/kernel_cache.cpp

Lines changed: 6 additions & 31 deletions
@@ -341,32 +341,16 @@ std::vector<at::Tensor> FusionKernelRuntime::runKernelWithInput(
     options.index_mode = scheduler_entry->indexMode();
     FusionGuard fg(fusion_to_run.get());
     scheduler_entry->schedule(fusion_to_run.get());
-    // Load launch params for reduction and normalization kernels
-    if (scheduler_entry->hasReductionParam()) {
-      launch_params = scheduler_entry->reductionParams().lparams;
-    } else {
-      launch_params = scheduler_entry->pointwiseParams().lparams;
-    }
+    launch_params = scheduler_entry->params()->lparams;
     executors_[group_id].compileFusion(
         fusion_to_run.get(), inputs, launch_params, options);
   } else {
-    // Load launch params for reduction and normalization kernels
-    if (scheduler_entry->hasReductionParam()) {
-      launch_params = scheduler_entry->reductionParams().lparams;
-    } else {
-      launch_params = scheduler_entry->pointwiseParams().lparams;
-    }
+    launch_params = scheduler_entry->params()->lparams;
   }

   if (profiling_) {
     most_recent_executor_log_.fusion_executor = &executors_[group_id];
-    if (scheduler_entry->hasReductionParam()) {
-      most_recent_executor_log_.reduction_params =
-          scheduler_entry->reductionParams();
-    } else {
-      most_recent_executor_log_.pointwise_params =
-          scheduler_entry->pointwiseParams();
-    }
+    most_recent_executor_log_.params = scheduler_entry->params()->clone();
   }

   auto& executor = executors_[group_id];
@@ -395,11 +379,7 @@ std::vector<at::Tensor> FusionKernelRuntime::runKernelWithInput(
       }
     }
     std::cout << "Compiler log: " << executor.compilerLog() << "\n";
-    if (scheduler_entry->hasReductionParam()) {
-      std::cout << scheduler_entry->reductionParams().toString() << "\n";
-    } else {
-      std::cout << scheduler_entry->pointwiseParams().toString() << "\n";
-    }
+    std::cout << scheduler_entry->params()->toString() << "\n";
     std::cout << "With arguments: " << executor.lastLaunchParams().toString();
     std::cout << executor.kernelName() << " " << executor.bytesProcessed()
               << " bytes/ " << std::setprecision(3) << executor.kernelTimeMs()
@@ -604,13 +584,8 @@ void FusionKernelRuntime::updateHeuristicsLaunchParams(
       update_heuristics->heuristicsList().size() == scheduler_list_length);
   for (const auto i : c10::irange(scheduler_list_length)) {
     auto& schedulerPtr = heuristics_->heuristicsList()[i];
-    if (schedulerPtr->hasReductionParam()) {
-      schedulerPtr->updateLaunchConstraint(
-          update_heuristics->heuristicsList()[i]->reductionParams().lparams);
-    } else {
-      schedulerPtr->updateLaunchConstraint(
-          update_heuristics->heuristicsList()[i]->pointwiseParams().lparams);
-    }
+    schedulerPtr->updateLaunchConstraint(
+        update_heuristics->heuristicsList()[i]->params()->lparams);
   }
 }
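
Note on the profiling hunk above: the log stores scheduler_entry->params()->clone() rather than the live pointer, so ExecutorLog owns an independent copy instead of aliasing params that the runtime may later adjust (e.g. via updateLaunchConstraint in the last hunk). Condensed, the runtime logic is now:

  // Condensed from the hunks above: one polymorphic access path replaces
  // the per-type hasReductionParam() branching.
  launch_params = scheduler_entry->params()->lparams;
  if (profiling_) {
    most_recent_executor_log_.fusion_executor = &executors_[group_id];
    // clone() gives the log its own copy of the heuristics.
    most_recent_executor_log_.params = scheduler_entry->params()->clone();
  }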

torch/csrc/jit/codegen/cuda/kernel_cache.h

Lines changed: 1 addition & 2 deletions
@@ -25,8 +25,7 @@ class SchedulerRuntimeInfo;

 // Utilities for benchmarking and profiling
 struct ExecutorLog {
-  c10::optional<ReductionParams> reduction_params = c10::nullopt;
-  c10::optional<PointwiseParams> pointwise_params = c10::nullopt;
+  std::shared_ptr<HeuristicParams> params = nullptr;
   FusionExecutor* fusion_executor = nullptr;
 };

New file (path not shown in this view) · Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+#pragma once
+
+#include <torch/csrc/jit/codegen/cuda/executor_launch_params.h>
+
+#include <string>
+
+namespace torch {
+namespace jit {
+namespace fuser {
+namespace cuda {
+
+class HeuristicParams {
+ public:
+  std::string tag = "";
+
+  LaunchParams lparams;
+
+  virtual std::string toString() const {
+    return "Undefined Heuristic Params";
+  }
+
+  virtual size_t hash() const = 0;
+
+  virtual ~HeuristicParams() = default;
+
+  virtual bool sameAs(const std::shared_ptr<HeuristicParams>& other) const = 0;
+
+  virtual std::shared_ptr<HeuristicParams> clone() const = 0;
+
+  HeuristicParams() = default;
+  HeuristicParams(const std::string& tag) : tag(tag) {}
+};
+
+} // namespace cuda
+} // namespace fuser
+} // namespace jit
+} // namespace torch
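
This class is the extension point of the refactor: a scheduler's parameter struct can flow through ExecutorLog, the kernel runtime, and the benchmarks as long as it implements these virtuals. Purely as illustration, a minimal subclass might look like the sketch below; DummyParams and its field are invented for this example (the commit's real parameter types are the existing ReductionParams and PointwiseParams):

  // Illustrative sketch only; DummyParams is not a type in this commit.
  #include <functional> // std::hash
  #include <memory>     // std::shared_ptr, std::dynamic_pointer_cast

  class DummyParams : public HeuristicParams {
   public:
    int unroll_factor = 1;

    using HeuristicParams::HeuristicParams; // inherit the tag constructor

    std::string toString() const override {
      return "DummyParams: unroll_factor = " + std::to_string(unroll_factor);
    }

    size_t hash() const override {
      return std::hash<int>{}(unroll_factor);
    }

    bool sameAs(const std::shared_ptr<HeuristicParams>& other) const override {
      auto o = std::dynamic_pointer_cast<DummyParams>(other);
      return o != nullptr && o->unroll_factor == unroll_factor;
    }

    std::shared_ptr<HeuristicParams> clone() const override {
      return std::make_shared<DummyParams>(*this);
    }
  };

Callers then only ever touch the base interface: clone() for ownership, toString() for labels, and lparams for launch configuration.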
