@@ -24872,7 +24872,7 @@ TEST_F(NVFuserTest, FusionInsertMagicZero1_CUDA) {
24872
24872
tv2->reorder({{1, 2}, {2, 1}});
24873
24873
tv2->merge(0);
24874
24874
24875
- TransformPropagator propagator(tv2);
24875
+ TransformPropagatorWithCheck propagator(tv2);
24876
24876
MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator);
24877
24877
24878
24878
tv0->computeAt(tv2, 1);
@@ -24992,7 +24992,7 @@ TEST_F(NVFuserTest, FusionExpandReduce2_CUDA) {
24992
24992
// [iBIDx, iTIDx, rTIDy, rBIDy, rO]
24993
24993
auto tv3 = tv2->rFactor({-1});
24994
24994
24995
- TransformPropagator propagator(tv3);
24995
+ TransformPropagatorWithCheck propagator(tv3);
24996
24996
MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator);
24997
24997
scheduler_utils::parallelizeAllLike(tv3);
24998
24998
tv0->computeAt(tv3, -1, ComputeAtMode::MostInlined);
@@ -25693,6 +25693,77 @@ TEST_F(NVFuserTest, AsyncCompilation_CUDA) {
25693
25693
executor_cache.fusion(), outputs, aten_inputs, {t6}, __LINE__, __FILE__);
25694
25694
}
25695
25695
25696
+ TEST_F(NVFuserTest, FusionMergeBroadcastingTrivialReduction1_CUDA) {
25697
+ std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
25698
+ auto fusion = fusion_ptr.get();
25699
+ FusionGuard fg(fusion);
25700
+
25701
+ TensorView* tv0 = makeConcreteTensor({1, 1});
25702
+ TensorView* tv1 = makeConcreteTensor({-1});
25703
+ fusion->addInput(tv0);
25704
+ fusion->addInput(tv1);
25705
+ auto tv2 = sum(tv0, {1});
25706
+ auto tv3 = add(tv2, tv1);
25707
+ fusion->addOutput(tv3);
25708
+
25709
+ tv0->merge(0);
25710
+
25711
+ MaxRootDomainInfoSpanningTree tree(tv0);
25712
+ TransformPropagatorWithCheck tp(tv0);
25713
+ tree.traverse(&tp);
25714
+
25715
+ InlinePropagator ip(tv0, -1, ComputeAtMode::MostInlined);
25716
+ tree.traverse(&ip);
25717
+
25718
+ auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
25719
+ at::Tensor t0 = at::randn({1, 1}, options);
25720
+ at::Tensor t1 = at::randn({10}, options);
25721
+
25722
+ FusionExecutor fe;
25723
+ fe.compileFusion(fusion, {t0, t1});
25724
+ auto cg_outputs = fe.runFusion({t0, t1});
25725
+ auto out = cg_outputs[0];
25726
+
25727
+ testValidate(
25728
+ fusion, {out}, {t0, t1}, {t1 + t0.flatten()}, __LINE__, __FILE__);
25729
+ }
25730
+
25731
+ TEST_F(NVFuserTest, FusionMergeBroadcastingTrivialReduction2_CUDA) {
25732
+ std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
25733
+ auto fusion = fusion_ptr.get();
25734
+ FusionGuard fg(fusion);
25735
+
25736
+ TensorView* tv0 = makeConcreteTensor({-1, 1, 1});
25737
+ TensorView* tv1 = makeConcreteTensor({-1, -1});
25738
+ fusion->addInput(tv0);
25739
+ fusion->addInput(tv1);
25740
+ auto tv2 = sum(tv0, {1});
25741
+ auto tv3 = add(tv2, tv1);
25742
+ fusion->addOutput(tv3);
25743
+
25744
+ tv2->merge(1);
25745
+ tv2->merge(0);
25746
+
25747
+ MaxRootDomainInfoSpanningTree tree(tv0);
25748
+ TransformPropagatorWithCheck tp(tv0);
25749
+ tree.traverse(&tp);
25750
+
25751
+ InlinePropagator ip(tv0, -1, ComputeAtMode::MostInlined);
25752
+ tree.traverse(&ip);
25753
+
25754
+ auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
25755
+ at::Tensor t0 = at::randn({10, 1, 1}, options);
25756
+ at::Tensor t1 = at::randn({10, 10}, options);
25757
+
25758
+ FusionExecutor fe;
25759
+ fe.compileFusion(fusion, {t0, t1});
25760
+ auto cg_outputs = fe.runFusion({t0, t1});
25761
+ auto out = cg_outputs[0];
25762
+
25763
+ testValidate(
25764
+ fusion, {out}, {t0, t1}, {t1 + t0.squeeze(-1)}, __LINE__, __FILE__);
25765
+ }
25766
+
25696
25767
} // namespace jit
25697
25768
} // namespace torch
25698
25769
#endif // #if defined(USE_CUDA)
0 commit comments