@@ -491,6 +491,36 @@ TEST_F(NVFuserTest, FusionScheduleBroadcastOnly_CUDA) {
}
}

+ // mermaid graph:
+ // ```mermaid
+ // %%{
+ // init: {
+ // 'theme': 'base',
+ // 'themeVariables': { 'fontSize': '30px', 'fontFamily': 'times'}}
+ // }%%
+ // graph TD
+ // T0("T0(M, N, K)")
+ // T1("T1(N, M, K)")
+ // T2("T2(M, K, N)")
+ // T0 --> A("transpose(1, 2)") --> T3("T3(M, K, N)")
+ // T1 ---> sigmoid --> T5("T5(N, M, K)")
+ // T5 --> B("transpose(0, 2)") --> T7("T7(K, M, N)")
+ // T2 ----> C("add")
+ // T3 --> C --> T6("T6(M, K, N)")
+ // T6 --> D("transpose(0, 1)") --> T11("T11(K, M, N)")
+ // T11 --> E("add") --> T12("T12(K, M, N)")
+ // T7 --> E
+ // T1 ---> F("transpose(0, 1)") --> T4("T4(M, N, K)")
+ // T0 --> G("add") --> T8("T8(M, N, K)") --> relu ---> T9("T9(M, N, K)")
+ // T4 --> G
+ // T6 ---> sin ---> T10("T10(M, K, N)")
+ // style T0 fill:lightgreen
+ // style T1 fill:lightgreen
+ // style T2 fill:lightgreen
+ // style T12 fill:lightblue
+ // style T9 fill:lightblue
+ // style T10 fill:lightblue
+ // ```
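+ // (lightgreen: fusion inputs; lightblue: fusion outputs)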
TEST_F(NVFuserTest, FusionScheduleTransposeComplexDAG1_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);
@@ -546,6 +576,36 @@ TEST_F(NVFuserTest, FusionScheduleTransposeComplexDAG1_CUDA) {
      __FILE__);
}

+ // mermaid graph:
+ // ```mermaid
+ // %%{
+ // init: {
+ // 'theme': 'base',
+ // 'themeVariables': { 'fontSize': '30px', 'fontFamily': 'times'}}
+ // }%%
+ // graph TD
+ // T0("T0(M, N, K)")
+ // T1("T1(N, M, K)")
+ // T2("T2(M, K, N)")
+ // T0 --> A("transpose(1, 2)") --> T3("T3(M, K, N)")
+ // T1 ---> sigmoid --> T5("T5(N, M, K)")
+ // T5 --> B("transpose(0, 2)") --> T7("T7(K, M, N)")
+ // T2 ----> C("add")
+ // T3 --> C --> T6("T6(M, K, N)")
+ // T6 --> D("transpose(0, 1)") --> T11("T11(K, M, N)")
+ // T11 --> E("add") --> T12("T12(K, M, N)")
+ // T7 --> E
+ // T1 ---> F("transpose(0, 1)") --> T4("T4(M, N, K)")
+ // T0 --> G("add") --> T8("T8(M, N, K)") --> relu ---> T9("T9(M, N, K)")
+ // T4 --> G
+ // T6 ---> sin ---> T10("T10(M, K, N)")
+ // style T0 fill:lightgreen
+ // style T1 fill:lightgreen
+ // style T2 fill:lightgreen
+ // style T12 fill:lightblue
+ // style T9 fill:lightblue
+ // style T10 fill:lightblue
+ // ```
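+ // Same DAG as in FusionScheduleTransposeComplexDAG1_CUDA above, but
+ // scheduled by hand here.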
TEST_F(NVFuserTest, FusionManualScheduleTransposeComplexDAG1_CUDA) {
  // achieved: 833.526 GB/s on RTX 3090 (theoretical bandwidth: 936 GB/s)
  Fusion fusion;
@@ -729,6 +789,45 @@ TEST_F(NVFuserTest, FusionViewNoTranspose_CUDA) {
  TORCH_CHECK(!hasAtLeastTwoValidGroups(&fusion));
}

+ // t0------------.
+ // t2->broadcast->sub->mul->relu->t6
+ // t1------------------'
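+ // tv1 has concrete size-1 dimensions and tv2 is 1-D, so this exercises the
+ // transpose scheduler on inputs that are missing dimensions.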
+ TEST_F(NVFuserTest, FusionScheduleTransposeMissingDim_CUDA) {
+   Fusion fusion;
+   FusionGuard fg(&fusion);
+
+   auto tv0 = makeContigTensor(3);
+   auto tv1 = makeContigConcreteTensor({1, -1, 1});
+   auto tv2 = makeContigTensor(1);
+   fusion.addInput(tv0);
+   fusion.addInput(tv1);
+   fusion.addInput(tv2);
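+   // broadcast(tv2, {true, false, true}) maps the 1-D tv2 to a (b, N, b)
+   // tensor, matching input2.unsqueeze(0).unsqueeze(-1) in the reference below.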
+   auto tv3 = broadcast(tv2, {true, false, true});
+   auto tv4 = sub(tv0, tv3);
+   auto tv5 = mul(tv4, tv1);
+   auto tv6 = relu(tv5);
+   fusion.addOutput(tv6);
+
+   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+   at::Tensor input0 = at::randn({512, 1024, 512}, options);
+   at::Tensor input1 = at::randn({1, 1024, 1}, options);
+   at::Tensor input2 = at::randn({1024}, options);
+
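+   // scheduleTranspose applies the transpose scheduler to the fusion and
+   // returns the launch parameters used to compile and run it.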
+   auto lparams = scheduleTranspose(&fusion, {input0, input1, input2});
+
+   FusionExecutor fe;
+   fe.compileFusion(&fusion, {input0, input1, input2}, lparams);
+   auto outputs = fe.runFusion({input0, input1, input2}, lparams);
+
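+   // unscheduled ATen reference used for validation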
+   auto t3 = input2.unsqueeze(0).unsqueeze(-1);
+   auto t4 = input0 - t3;
+   auto t5 = t4 * input1;
+   auto t6 = at::relu(t5);
+
+   testValidate(
+       &fusion, outputs, {input0, input1, input2}, {t6}, __LINE__, __FILE__);
+ }
+
} // namespace jit
} // namespace torch
#endif // #if defined(USE_CUDA)