@@ -447,7 +447,32 @@ TEST_F(NVFuserTest, FusionFrontendBasic_CUDA) {
447
447
448
448
std::vector<IValue> inputs = {t0, t1};
449
449
450
- // Define fusion
450
+ Fusion fauto;
451
+ { // Do automatic scheduling on fauto
452
+ FusionGuard fg (&fauto);
453
+
454
+ auto tv0 = makeSymbolicTensor (3 );
455
+ auto tv1 = makeSymbolicTensor (3 );
456
+ auto c0 = IrBuilder::create<Double>(3.0 );
457
+
458
+ fauto.addInput (tv0);
459
+ fauto.addInput (tv1);
460
+
461
+ auto tv2 = add (tv0, tv1);
462
+ auto tv3 = mul (tv2, c0);
463
+ auto tv4 = sum (tv3, {-1 }, false , DataType::Float);
464
+
465
+ fauto.addOutput (tv4);
466
+
467
+ // Run automatic scheduler
468
+ auto reduction_params = getReductionHeuristics (&fauto, inputs);
469
+ TORCH_CHECK (reduction_params, " Reduction schedule was not generated!" );
470
+ scheduleReduction (&fauto, *reduction_params);
471
+ }
472
+
473
+ // Re-define the identical fusion for manual scheduling.
474
+ // This is necessary so that all the constructors are caught inside each
475
+ // Fusion independently.
451
476
Fusion fusion;
452
477
FusionGuard fg (&fusion);
453
478
@@ -464,39 +489,55 @@ TEST_F(NVFuserTest, FusionFrontendBasic_CUDA) {
464
489
465
490
fusion.addOutput (tv4);
466
491
467
- // Run automatic scheduler
468
- auto fauto = Fusion (fusion); // unique_ptr to copy of fusion
469
- auto reduction_params = getReductionHeuristics (&fauto, inputs);
470
- TORCH_CHECK (reduction_params, " Reduction schedule was not generated!" );
471
- scheduleReduction (&fauto, *reduction_params);
472
-
473
492
// Perform manual scheduling
474
- tv4->merge (0 , 1 ); // {i0*i1, i2}
475
- tv4->split (
476
- 1 ,
477
- NamedScalar::getParallelDim (
478
- ParallelType::TIDx)); // {i0*i1, r2 / bDx, bDx}
479
- tv4->split (-2 , 1 );
480
- tv4->reorder ({{-2 , -1 }, {-1 , -2 }});
481
- tv4->split (0 , 2 );
482
- tv4->reorder ({{1 , 2 }, {2 , 1 }});
483
- tv4->split (0 , 1 );
484
- tv4->reorder ({{1 , 2 }, {2 , 1 }});
485
- tv4->axis (0 )->parallelize (ParallelType::BIDx);
486
- tv4->axis (2 )->parallelize (ParallelType::Unswitch);
487
- tv4->axis (3 )->parallelize (ParallelType::Unroll);
488
- tv4->axis (4 )->parallelize (ParallelType::TIDx);
489
- tv4->axis (5 )->parallelize (ParallelType::Unswitch);
490
493
491
- auto tv5 = tv0->cacheAfter ();
492
- auto tv6 = tv1->cacheAfter ();
493
- auto tv7 = tv4->cacheBefore ();
494
+ auto tv5 = tv0->cacheAfter (); // tv5
495
+ auto tv6 = tv1->cacheAfter (); // tv6
496
+ auto tv7 = tv4->cacheBefore (); // tv7
497
+
498
+ tv7->reorder ({{2 , 0 }});
499
+ tv7->merge (1 , 2 );
500
+ tv7->reorder ({{1 , 0 }});
501
+ tv7->split (1 , NamedScalar::getParallelDim (ParallelType::TIDx));
502
+ tv7->axis (2 )->parallelize (ParallelType::TIDx);
503
+ tv7->split (1 , 1 );
504
+ tv7->axis (2 )->parallelize (ParallelType::Unswitch);
505
+ tv7->split (0 , 2 );
506
+ tv7->axis (1 )->parallelize (ParallelType::Unroll);
507
+ tv7->split (0 , 1 );
508
+ tv7->axis (1 )->parallelize (ParallelType::Unswitch);
509
+ tv7->axis (0 )->parallelize (ParallelType::BIDx);
510
+
511
+ tv7->reorder ({{0 , 0 }, {1 , 2 }, {2 , 3 }, {3 , 1 }, {4 , 5 }, {5 , 4 }});
512
+
494
513
auto tv8 = tv7->rFactor ({1 , 5 });
495
514
515
+ // NOTE: see multiReductionInliner for more info on how propagation and
516
+ // inlining work in the reduction scheduler.
517
+
496
518
// propagate the mapping to other tensors
497
- TransformPropagatorWithCheck propagator (tv7);
498
- MaxRootDomainInfoSpanningTree (tv7).traverse (&propagator);
499
- scheduler_utils::parallelizeAllLike (tv7, {tv2, tv3, tv4, tv5, tv6, tv8});
519
+ TransformPropagatorWithCheck propagator (tv8);
520
+ MaxRootDomainInfoSpanningTree (tv8).traverse (&propagator);
521
+ // Propagate parallelization except vectorization and unrolling
522
+ scheduler_utils::parallelizeAllLike (
523
+ tv8,
524
+ {},
525
+ allParallelTypesExcept (
526
+ {ParallelType::Unroll,
527
+ ParallelType::Vectorize,
528
+ ParallelType::MisalignedVectorize}));
529
+ // Propagate vectorization/unrolling to those tensors that need it
530
+ scheduler_utils::parallelizeAllLike (
531
+ tv8,
532
+ {tv4, tv6, tv5},
533
+ {
534
+ ParallelType::Unroll,
535
+ ParallelType::Vectorize,
536
+ ParallelType::MisalignedVectorize,
537
+ });
538
+ // If reference shouldn't be unrolled, clear that parallel type.
539
+ tv8->axis (3 )->parallelize (ParallelType::Serial);
540
+ tv7->axis (2 )->parallelize (ParallelType::Serial);
500
541
501
542
inlineMost ();
502
543
0 commit comments