@@ -2674,23 +2674,44 @@ void testGPU_FusionSimpleGemm() {
2674
2674
fusion.addInput (tv1);
2675
2675
2676
2676
TensorView* tv2 = broadcast (tv0, {false , false , true });
2677
+ // tv2[I0, I1, B] = tv0[I0, I1]
2678
+
2677
2679
TensorView* tv3 = broadcast (tv1, {true , false , false });
2680
+ // tv3[B, I1, I2] = tv1[I1, I2]
2678
2681
2682
+ // tv4[I0, I1, I2] = tv2[I0, I1, B] * tv3[B, I1, I2]
2679
2683
TensorView* tv4 = mul (tv2, tv3);
2684
+ // tv5[I0, R1, I2] = tv4[I0, I1, I2]
2680
2685
TensorView* tv5 = sum (tv4, {1 });
2681
2686
fusion.addOutput (tv5);
2682
2687
2683
2688
tv5->split (1 , 32 );
2689
+ // tv5[I0, R1o, R1i{32}, I2]
2690
+
2684
2691
auto tv6 = tv5->rFactor ({1 });
2692
+ // tv6[I0, R1o, I1i{32}, I2] = tv4[I0, I1, I2]
2693
+ // tv5[I0, , R1i{32}, I2] = tv6[I0, R1o, I1i{32}, I2]
2685
2694
2686
2695
tv5->split (0 , 4 );
2687
2696
tv5->split (-1 , 4 );
2697
+ // tv6[I0, R1o, I1i{32}, I2] (not touched by the tv5 splits above)
2698
+ // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]
2688
2699
2689
2700
tv0->computeAt (tv5, -1 );
2690
2701
tv1->computeAt (tv5, -1 );
2691
2702
2703
+ // tv6[I0o, I0i{4}, R1o, I1i{32}, I2o, I2i{4}]
2704
+ // tv5[I0o, I0i{4}, , R1i{32}, I2o, I2i{4}]
2705
+ // --> (the `|` below marks each tensor's computeAt position)
2706
+ // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, I1o]
2707
+ // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, R1o]
2708
+ // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
2709
+
2692
2710
tv0->computeAt (tv6, -1 );
2693
2711
tv1->computeAt (tv6, -1 );
2712
+ // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, I1o |]
2713
+ // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, R1o |]
2714
+ // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
2694
2715
2695
2716
tv5->axis (0 )->parallelize (ParallelType::BIDz);
2696
2717
tv5->axis (1 )->parallelize (ParallelType::TIDz);
0 commit comments