
Commit 057237f

Fix CUDA driver error: misaligned address for transpose scheduler (#1918)
1 parent 3fb3d80 commit 057237f

File tree

2 files changed: 143 additions, 8 deletions
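Background (not part of the commit): CUDA reports a `misaligned address` error when a kernel issues a vectorized access, such as a 16-byte `float4` load, through an address that is not aligned to the vector width. The scheduler change below stops propagating `ParallelType::Vectorize` to tensors whose axes do not exactly map to the reference tensor's vectorized axis, and skips vectorization entirely when the chosen vectorize factor is 1, which presumably avoids emitting such accesses for inputs like the broadcast tensors in the new regression test. A minimal, self-contained sketch of the underlying failure mode (a hypothetical kernel, not taken from nvfuser):

// misaligned_address_demo.cu (illustrative only, not from this commit)
// Build and run with: nvcc misaligned_address_demo.cu -o demo && ./demo
#include <cstdio>
#include <cuda_runtime.h>

__global__ void vec4_load(const float* in, float* out) {
  // A float4 load promises the hardware a 16-byte-aligned address.
  float4 v = *reinterpret_cast<const float4*>(in);
  out[0] = v.x + v.y + v.z + v.w;
}

int main() {
  float *in = nullptr, *out = nullptr;
  cudaMalloc((void**)&in, 8 * sizeof(float)); // cudaMalloc returns aligned memory
  cudaMalloc((void**)&out, sizeof(float));

  vec4_load<<<1, 1>>>(in, out); // aligned: runs fine
  printf("aligned:    %s\n", cudaGetErrorString(cudaDeviceSynchronize()));

  vec4_load<<<1, 1>>>(in + 1, out); // 4-byte offset: "misaligned address"
  printf("misaligned: %s\n", cudaGetErrorString(cudaDeviceSynchronize()));

  cudaFree(in);
  cudaFree(out);
  return 0;
}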

torch/csrc/jit/codegen/cuda/scheduler/transpose.cpp

Lines changed: 44 additions & 8 deletions
@@ -533,7 +533,9 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams& params) {
 
   // parallelize group2 and its cached inputs
   {
-    reference2->axis(-1)->parallelize(ParallelType::Vectorize);
+    if (params.vectorize_factor2 > 1) {
+      reference2->axis(-1)->parallelize(ParallelType::Vectorize);
+    }
     reference2->axis(-2)->parallelize(ParallelType::TIDx);
     reference2->axis(-3)->parallelize(ParallelType::Unroll);
 
@@ -542,9 +544,27 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams& params) {
     scheduler_utils::parallelizeAllLike(
         reference2,
         {group2_and_cached_inputs.begin(), group2_and_cached_inputs.end()},
-        {ParallelType::Vectorize, ParallelType::TIDx});
+        {ParallelType::TIDx});
 
-    // Only unrolled the axes that exactly maps to the unrolled axes
+    // Only vectorize the axes that exactly maps to the vectorized axes
+    // on reference as support for permissively mapped axes are not
+    // yet clearly defined.
+    std::vector<TensorView*> vectorized_group2_cached_inputs;
+    for (auto gin : group2_and_cached_inputs) {
+      if (std::any_of(
+              gin->domain()->domain().begin(),
+              gin->domain()->domain().end(),
+              [&ca_map, reference2](IterDomain* id) {
+                return ca_map.areMapped(
+                    id, reference2->axis(-1), IdMappingMode::EXACT);
+              })) {
+        vectorized_group2_cached_inputs.push_back(gin);
+      }
+    }
+    scheduler_utils::parallelizeAllLike(
+        reference2, vectorized_group2_cached_inputs, {ParallelType::Vectorize});
+
+    // Only unroll the axes that exactly maps to the unrolled axes
     // on reference as support for permissively mapped axes are not
     // yet clearly defined.
     std::vector<TensorView*> unrolled_group2_cached_inputs;
@@ -559,7 +579,6 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams& params) {
         unrolled_group2_cached_inputs.push_back(gin);
       }
     }
-
     scheduler_utils::parallelizeAllLike(
         reference2, unrolled_group2_cached_inputs, {ParallelType::Unroll});
   }
@@ -571,7 +590,9 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams& params) {
   reference1->merge(pos);
   reference1->split(pos, params.vectorize_factor1);
   reference1->split(pos, kThreadsPerBlock);
-  reference1->axis(-1)->parallelize(ParallelType::Vectorize);
+  if (params.vectorize_factor1 > 1) {
+    reference1->axis(-1)->parallelize(ParallelType::Vectorize);
+  }
   reference1->axis(-2)->parallelize(ParallelType::TIDx);
   reference1->axis(-3)->parallelize(ParallelType::Unroll);
   // [..., Unroll, TIDx, Vectorize]
@@ -600,10 +621,26 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams& params) {
       group1_and_cached_inputs.emplace_back(ir_utils::consumerTvsOf(tv)[0]);
     }
   }
+
+  // Only vectorize the axes that exactly maps to the vectorized axes
+  // on reference as support for permissively mapped axes are not
+  // yet clearly defined.
+  std::vector<TensorView*> vectorized_group1_cached_inputs;
+  for (auto gin : group1_and_cached_inputs) {
+    if (std::any_of(
+            gin->domain()->domain().begin(),
+            gin->domain()->domain().end(),
+            [&ca_map, reference1](IterDomain* id) {
+              return ca_map.areMapped(
+                  id, reference1->axis(-1), IdMappingMode::EXACT);
+            })) {
+      vectorized_group1_cached_inputs.push_back(gin);
+    }
+  }
   scheduler_utils::parallelizeAllLike(
-      reference1, group1_and_cached_inputs, {ParallelType::Vectorize});
+      reference1, vectorized_group1_cached_inputs, {ParallelType::Vectorize});
 
-  // Only unrolled the axes that exactly maps to the unrolled axes
+  // Only unroll the axes that exactly maps to the unrolled axes
   // on reference as support for permissively mapped axes are not
   // yet clearly defined.
   std::vector<TensorView*> unrolled_group1_cached_inputs;
@@ -618,7 +655,6 @@ void scheduleTranspose(Fusion* fusion, const TransposeParams& params) {
       unrolled_group1_cached_inputs.push_back(gin);
     }
   }
-
  scheduler_utils::parallelizeAllLike(
      reference1, unrolled_group1_cached_inputs, {ParallelType::Unroll});
}
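The two hunks above apply the same filter-then-propagate pattern, once for `group2_and_cached_inputs` and once for `group1_and_cached_inputs`. As a reading aid only, here is a hedged distillation of that pattern into a standalone helper; the helper does not exist in the codebase, it assumes nvfuser's internal `TensorView`, `IterDomain`, `ComputeAtMap`, and `IdMappingMode` types, and it reuses only the calls visible in the diff:

// Sketch only (assumes <algorithm>, <vector>, and the scheduler's usual
// nvfuser headers are already included, as in transpose.cpp).
// Return the subset of `tvs` that has at least one IterDomain EXACT-mapped
// to `reference->axis(-1)`, i.e. the tensors that are safe to vectorize
// alongside the reference.
template <typename TvContainer>
std::vector<TensorView*> exactlyMappedToReferenceInnermost(
    const TvContainer& tvs,
    TensorView* reference,
    ComputeAtMap& ca_map) {
  std::vector<TensorView*> filtered;
  for (auto tv : tvs) {
    if (std::any_of(
            tv->domain()->domain().begin(),
            tv->domain()->domain().end(),
            [&](IterDomain* id) {
              return ca_map.areMapped(
                  id, reference->axis(-1), IdMappingMode::EXACT);
            })) {
      filtered.push_back(tv);
    }
  }
  return filtered;
}

// Hypothetical usage mirroring the group2 call site in the diff:
//   scheduler_utils::parallelizeAllLike(
//       reference2,
//       exactlyMappedToReferenceInnermost(
//           group2_and_cached_inputs, reference2, ca_map),
//       {ParallelType::Vectorize});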

torch/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp

Lines changed: 99 additions & 0 deletions
@@ -491,6 +491,36 @@ TEST_F(NVFuserTest, FusionScheduleBroadcastOnly_CUDA) {
   }
 }
 
+// mermaid graph:
+// ```mermaid
+// %%{
+// init: {
+// 'theme': 'base',
+// 'themeVariables': { 'fontSize': '30px', 'fontFamily': 'times'}}
+// }%%
+// graph TD
+// T0("T0(M, N, K)")
+// T1("T1(N, M, K)")
+// T2("T2(M, K, N)")
+// T0 --> A("transpose(1, 2)") --> T3("T3(M, K, N)")
+// T1 ---> sigmoid --> T5("T5(N, M, K)")
+// T5 --> B("transpose(0, 2)") --> T7("T7(K, M, N)")
+// T2 ----> C("add")
+// T3 --> C --> T6("T6(M, K, N)")
+// T6 --> D("transpose(0, 1)") --> T11("T11(K, M, N)")
+// T11 --> E("add") -->T12("T12(K, M, N)")
+// T7 --> E
+// T1 ---> F("transpose(0, 1)") --> T4("T4(M, N, K)")
+// T0 --> G("add") --> T8("T8(M, N, K)") --> relu ---> T9("T9(M, N, K)")
+// T4 --> G
+// T6 ---> sin ---> T10("T10(M, K, N)")
+// style T0 fill:lightgreen
+// style T1 fill:lightgreen
+// style T2 fill:lightgreen
+// style T12 fill:lightblue
+// style T9 fill:lightblue
+// style T10 fill:lightblue
+// ```
 TEST_F(NVFuserTest, FusionScheduleTransposeComplexDAG1_CUDA) {
   Fusion fusion;
   FusionGuard fg(&fusion);
@@ -546,6 +576,36 @@ TEST_F(NVFuserTest, FusionScheduleTransposeComplexDAG1_CUDA) {
       __FILE__);
 }
 
+// mermaid graph:
+// ```mermaid
+// %%{
+// init: {
+// 'theme': 'base',
+// 'themeVariables': { 'fontSize': '30px', 'fontFamily': 'times'}}
+// }%%
+// graph TD
+// T0("T0(M, N, K)")
+// T1("T1(N, M, K)")
+// T2("T2(M, K, N)")
+// T0 --> A("transpose(1, 2)") --> T3("T3(M, K, N)")
+// T1 ---> sigmoid --> T5("T5(N, M, K)")
+// T5 --> B("transpose(0, 2)") --> T7("T7(K, M, N)")
+// T2 ----> C("add")
+// T3 --> C --> T6("T6(M, K, N)")
+// T6 --> D("transpose(0, 1)") --> T11("T11(K, M, N)")
+// T11 --> E("add") -->T12("T12(K, M, N)")
+// T7 --> E
+// T1 ---> F("transpose(0, 1)") --> T4("T4(M, N, K)")
+// T0 --> G("add") --> T8("T8(M, N, K)") --> relu ---> T9("T9(M, N, K)")
+// T4 --> G
+// T6 ---> sin ---> T10("T10(M, K, N)")
+// style T0 fill:lightgreen
+// style T1 fill:lightgreen
+// style T2 fill:lightgreen
+// style T12 fill:lightblue
+// style T9 fill:lightblue
+// style T10 fill:lightblue
+// ```
 TEST_F(NVFuserTest, FusionManualScheduleTransposeComplexDAG1_CUDA) {
   // achieved: 833.526 GB/s on RTX 3090 (theoretical bandwidth: 936 GB/s)
   Fusion fusion;
@@ -729,6 +789,45 @@ TEST_F(NVFuserTest, FusionViewNoTranspose_CUDA) {
   TORCH_CHECK(!hasAtLeastTwoValidGroups(&fusion));
 }
 
+// t0------------.
+// t2->broadcast->sub->mul->relu->t6
+// t1------------------'
+TEST_F(NVFuserTest, FusionScheduleTransposeMissingDim_CUDA) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigTensor(3);
+  auto tv1 = makeContigConcreteTensor({1, -1, 1});
+  auto tv2 = makeContigTensor(1);
+  fusion.addInput(tv0);
+  fusion.addInput(tv1);
+  fusion.addInput(tv2);
+  auto tv3 = broadcast(tv2, {true, false, true});
+  auto tv4 = sub(tv0, tv3);
+  auto tv5 = mul(tv4, tv1);
+  auto tv6 = relu(tv5);
+  fusion.addOutput(tv6);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input0 = at::randn({512, 1024, 512}, options);
+  at::Tensor input1 = at::randn({1, 1024, 1}, options);
+  at::Tensor input2 = at::randn({1024}, options);
+
+  auto lparams = scheduleTranspose(&fusion, {input0, input1, input2});
+
+  FusionExecutor fe;
+  fe.compileFusion(&fusion, {input0, input1, input2}, lparams);
+  auto outputs = fe.runFusion({input0, input1, input2}, lparams);
+
+  auto t3 = input2.unsqueeze(0).unsqueeze(-1);
+  auto t4 = input0 - t3;
+  auto t5 = t4 * input1;
+  auto t6 = at::relu(t5);
+
+  testValidate(
+      &fusion, outputs, {input0, input1, input2}, {t6}, __LINE__, __FILE__);
+}
+
 } // namespace jit
 } // namespace torch
 #endif // #if defined(USE_CUDA)
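Note on verification (an assumption about the local build, not stated in the commit): the new regression case is a standard gtest, so it can presumably be selected on whichever binary the build produces for the nvfuser C++ tests with a filter like --gtest_filter=NVFuserTest.FusionScheduleTransposeMissingDim_CUDA; before this fix it would be expected to reproduce the misaligned address error, and after it the fusion compiles and validates against the ATen reference.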
