Make FrontendBasic test match auto schedule
Again I used the fusion_debug dump from #2326 to trace what the reduction scheduler is doing. This time I learned about multiReductionInliner, which makes two calls to parallelizeAllLike with different sets of ParallelTypes, followed by undoing unrolling and vectorization on the reference tensor. The need for that last step is still a little unclear to me; the pattern is sketched below.
1 parent f09055e commit e667b74
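
For orientation, here is the core of that pattern as it shows up in this test: a condensed sketch of the calls from the diff below, not multiReductionInliner itself. In this schedule, tv8 is the rFactor reference tensor, tv7 is the cache in front of the output, and tv4/tv5/tv6 are the output and cache tensors that should actually be unrolled or vectorized.

    // Pass 1: propagate every parallel type except unroll/vectorize from
    // the reference tensor to all tensors in the fusion.
    scheduler_utils::parallelizeAllLike(
        tv8,
        {},
        allParallelTypesExcept(
            {ParallelType::Unroll,
             ParallelType::Vectorize,
             ParallelType::MisalignedVectorize}));

    // Pass 2: propagate unroll/vectorize only to the selected tensors.
    scheduler_utils::parallelizeAllLike(
        tv8,
        {tv4, tv5, tv6},
        {ParallelType::Unroll,
         ParallelType::Vectorize,
         ParallelType::MisalignedVectorize});

    // Finally, undo unrolling on the reference and the cached output;
    // the axis positions (3 and 2) are specific to this schedule.
    tv8->axis(3)->parallelize(ParallelType::Serial);
    tv7->axis(2)->parallelize(ParallelType::Serial);

The two-pass structure appears to give every tensor the same grid/thread bindings while restricting unrolling and vectorization to the tensors that benefit from them; as noted above, why the reference tensor then has those types cleared again is not fully clear to me.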

1 file changed: +70 -29 lines

third_party/nvfuser/test/test_gpu_match_frontend.cpp

@@ -447,7 +447,32 @@ TEST_F(NVFuserTest, FusionFrontendBasic_CUDA) {
 
   std::vector<IValue> inputs = {t0, t1};
 
-  // Define fusion
+  Fusion fauto;
+  { // Do automatic scheduling on fauto
+    FusionGuard fg(&fauto);
+
+    auto tv0 = makeSymbolicTensor(3);
+    auto tv1 = makeSymbolicTensor(3);
+    auto c0 = IrBuilder::create<Double>(3.0);
+
+    fauto.addInput(tv0);
+    fauto.addInput(tv1);
+
+    auto tv2 = add(tv0, tv1);
+    auto tv3 = mul(tv2, c0);
+    auto tv4 = sum(tv3, {-1}, false, DataType::Float);
+
+    fauto.addOutput(tv4);
+
+    // Run automatic scheduler
+    auto reduction_params = getReductionHeuristics(&fauto, inputs);
+    TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
+    scheduleReduction(&fauto, *reduction_params);
+  }
+
+  // Re-define the fusion exactly for manual scheduling
+  // This is necessary in order to catch all the constructors inside each
+  // Fusion independently.
   Fusion fusion;
   FusionGuard fg(&fusion);
 
@@ -464,39 +489,55 @@ TEST_F(NVFuserTest, FusionFrontendBasic_CUDA) {
 
   fusion.addOutput(tv4);
 
-  // Run automatic scheduler
-  auto fauto = Fusion(fusion); // unique_ptr to copy of fusion
-  auto reduction_params = getReductionHeuristics(&fauto, inputs);
-  TORCH_CHECK(reduction_params, "Reduction schedule was not generated!");
-  scheduleReduction(&fauto, *reduction_params);
-
   // Perform manual scheduling
-  tv4->merge(0, 1); // {i0*i1, i2}
-  tv4->split(
-      1,
-      NamedScalar::getParallelDim(
-          ParallelType::TIDx)); // {i0*i1, r2 / bDx, bDx}
-  tv4->split(-2, 1);
-  tv4->reorder({{-2, -1}, {-1, -2}});
-  tv4->split(0, 2);
-  tv4->reorder({{1, 2}, {2, 1}});
-  tv4->split(0, 1);
-  tv4->reorder({{1, 2}, {2, 1}});
-  tv4->axis(0)->parallelize(ParallelType::BIDx);
-  tv4->axis(2)->parallelize(ParallelType::Unswitch);
-  tv4->axis(3)->parallelize(ParallelType::Unroll);
-  tv4->axis(4)->parallelize(ParallelType::TIDx);
-  tv4->axis(5)->parallelize(ParallelType::Unswitch);
 
-  auto tv5 = tv0->cacheAfter();
-  auto tv6 = tv1->cacheAfter();
-  auto tv7 = tv4->cacheBefore();
+  auto tv5 = tv0->cacheAfter(); // tv5
+  auto tv6 = tv1->cacheAfter(); // tv6
+  auto tv7 = tv4->cacheBefore(); // tv7
+
+  tv7->reorder({{2, 0}});
+  tv7->merge(1, 2);
+  tv7->reorder({{1, 0}});
+  tv7->split(1, NamedScalar::getParallelDim(ParallelType::TIDx));
+  tv7->axis(2)->parallelize(ParallelType::TIDx);
+  tv7->split(1, 1);
+  tv7->axis(2)->parallelize(ParallelType::Unswitch);
+  tv7->split(0, 2);
+  tv7->axis(1)->parallelize(ParallelType::Unroll);
+  tv7->split(0, 1);
+  tv7->axis(1)->parallelize(ParallelType::Unswitch);
+  tv7->axis(0)->parallelize(ParallelType::BIDx);
+
+  tv7->reorder({{0, 0}, {1, 2}, {2, 3}, {3, 1}, {4, 5}, {5, 4}});
+
   auto tv8 = tv7->rFactor({1, 5});
 
+  // NOTE: see multiReductionInliner for more info on how propagation and
+  // inlining works in the reduction scheduler
+
   // propagate the mapping to other tensors
-  TransformPropagatorWithCheck propagator(tv7);
-  MaxRootDomainInfoSpanningTree(tv7).traverse(&propagator);
-  scheduler_utils::parallelizeAllLike(tv7, {tv2, tv3, tv4, tv5, tv6, tv8});
+  TransformPropagatorWithCheck propagator(tv8);
+  MaxRootDomainInfoSpanningTree(tv8).traverse(&propagator);
+  // Propagate parallelization except vectorization and unrolling
+  scheduler_utils::parallelizeAllLike(
+      tv8,
+      {},
+      allParallelTypesExcept(
+          {ParallelType::Unroll,
+           ParallelType::Vectorize,
+           ParallelType::MisalignedVectorize}));
+  // Propagate vectorization/unrolling to those tensors that need it
+  scheduler_utils::parallelizeAllLike(
+      tv8,
+      {tv4, tv6, tv5},
+      {
+          ParallelType::Unroll,
+          ParallelType::Vectorize,
+          ParallelType::MisalignedVectorize,
+      });
+  // If reference shouldn't be unrolled, clear that parallel type.
+  tv8->axis(3)->parallelize(ParallelType::Serial);
+  tv7->axis(2)->parallelize(ParallelType::Serial);
 
   inlineMost();
 