Commit 3ecb427

Authored by: csarofeen, tlemo, naoyam, jjsjann123, kevinstephano

Continue broadcast support (#77)
* Simplify a few test cases: replace custom exception checks with ASSERT_THROW macros.
* ExpressionEvaluator
* Stricter EvaluationContext binding rules:
  1. Don't allow overwriting concrete values.
  2. Don't allow binding values to expression results.
* Fix clang-format errors
* Switch to Int::ScalarType: the expression evaluator now uses Int::ScalarType instead of plain int.
* Avoid a fight with clang-tidy
* Check the numbers of kernel input and output parameters
* Add an optional arc from TensorView to its root domain; this is generated for detail_level >= DetailLevel::Explicit
* Check kernel arguments
* Prefer pointers over references
* Bug fix
* Fix accidental construction of IValue
* Use noReduction
* Add const to const pointer
* Make an integer tensor an error, as it is not yet supported
* clang-tidy
* Incorporate review feedback
* Add lerp support in the parser
* Add missing addcmul parser and tests
* clang-format
* Return TensorView* from binary/compound/ternary ops
* clang-format
* Use TensorView* param in reductionOp and sum
* Prefer as instead of static_cast
* Transform replay refactor (#53): the goal of this work is to make transformation history specific to IterDomains instead of TensorDomains. This should make it much easier to match up IterDomains during replay, which can be complicated when taking reduction axes, rfactors, and broadcast axes into consideration. Co-authored-by: Jie <[email protected]> Co-authored-by: Kevin Stephano <[email protected]>
* Python test fixes (#52):
  1. Put Fusion inside CudaKernel to facilitate runtime arg checks.
  2. Relax the rank check for broadcast support in integration.
  3. Add shape propagation for the newly added operations: addcmul, lerp.
  4. Add a utility function to create a FusionGuard from a CudaKernel directly.
* [nvFuser] Add torch.jit.fuser context manager (pytorch#38993) (#54). Summary:
  1. The torch.jit.fuser(str) context manager facilitates switching between backend fusers: 'fuser0' enables only the legacy fuser, 'fuser1' enables only NNC, and 'fuser2' enables only nvFuser.
  2. Clean up updated Python tests.
  Pull Request resolved: pytorch#38993
  Reviewed By: nairbv, pbelevich
  Differential Revision: D21800620
  Pulled By: soumith
  fbshipit-source-id: 7fe855f5a5b97368e5e84c98c28d04b2e1276c85
* Add another reduction example; change fusion printMath.
* Small test fix.
* Change Reduction4 test to use TIDx.x
* Minor cleanup.
* Clean up some noexcepts.
* More cleanup.
* Refactor computeAt; get the first broadcast example working.
* Validate the first non-trivial broadcast kernel.
* Fix replay when a broadcast dim is merged with a non-broadcast dim.
* Add constness in replay and index compute.
* Add another broadcast test; rework index computation for producers, based on consumer-computed indices.
* Val isConst fix.
* Add dot-product GEMM example.
* Clang.
* Minor bug fixes.
* Format and add comments to the GEMM test.
* WIP: Fix for enabling broadcast after reduction, plus a Softmax test (#66); cleaner way of fixing checks for matching non-broadcast dims to non-reduction dims. Co-authored-by: Kevin Stephano <[email protected]> Co-authored-by: Christian Sarofeen <[email protected]>
* Back out bad merge conflict resolutions.
* More post-rebase cleanup.
* Re-fix a few tests, some from a bad rebase.
* Address comments.
* Missed some review comments.
* tmp

Co-authored-by: Lemo <[email protected]>
Co-authored-by: Naoya Maruyama <[email protected]>
Co-authored-by: Jie <[email protected]>
Co-authored-by: Kevin Stephano <[email protected]>
Co-authored-by: Kevin Stephano <[email protected]>
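The torch.jit.fuser(str) context manager described above can be exercised as follows. This is a minimal usage sketch based only on the description in this message; the scripted function and shapes are illustrative, not part of the commit:

import torch

@torch.jit.script
def f(x, y):
    return x * y + y

x = torch.randn(4, 4, device="cuda")
y = torch.randn(4, 4, device="cuda")

# 'fuser0' = legacy fuser only, 'fuser1' = NNC only, 'fuser2' = nvFuser only
with torch.jit.fuser("fuser2"):
    out = f(x, y)  # TorchScript profiles and fuses with the selected backend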
1 parent b9af528 · commit 3ecb427

19 files changed: +674 -324 lines
test/cpp/jit/test_gpu.cpp

Lines changed: 228 additions & 22 deletions
@@ -1060,7 +1060,7 @@ void testGPU_FusionAdvancedComputeAt() {
   TORCH_CHECK(tv0->getComputeAtView() == tv3);
   TORCH_CHECK(tv1->getComputeAtView() == tv4);
   TORCH_CHECK(tv2->getComputeAtView() == tv4);
-  TORCH_CHECK(tv3->getComputeAtView() == tv6);
+  TORCH_CHECK(tv3->getComputeAtView() == tv5);
   TORCH_CHECK(tv4->getComputeAtView() == tv5);
   TORCH_CHECK(tv5->getComputeAtView() == tv6);
   TORCH_CHECK(!tv6->hasComputeAt());
@@ -1074,7 +1074,7 @@ void testGPU_FusionAdvancedComputeAt() {
 
   tv0->computeAt(tv6, 1);
 
-  TORCH_CHECK(tv0->getComputeAtView() == tv3 && tv0->nDims() == 3);
+  TORCH_CHECK(tv0->getComputeAtView() == tv6 && tv0->nDims() == 3);
   TORCH_CHECK(tv1->getComputeAtView() == tv4 && tv1->nDims() == 3);
   TORCH_CHECK(tv2->getComputeAtView() == tv4 && tv2->nDims() == 3);
   TORCH_CHECK(tv3->getComputeAtView() == tv6 && tv3->nDims() == 3);
@@ -1148,6 +1148,7 @@ void testGPU_FusionAdvancedComputeAt() {
   fusion.addOutput(tv6);
 
   tv2->computeAt(tv4, 1);
+
   TORCH_CHECK(!tv0->hasComputeAt());
   TORCH_CHECK(!tv1->hasComputeAt());
   TORCH_CHECK(tv2->getComputeAtView() == tv4);
@@ -1495,9 +1496,6 @@ void testGPU_FusionLoopUnroll() {
 
   int inp_size = 129 * 13 * 3;
 
-  // GPULower lower(&fusion);
-  // lower.printKernel(std::cout);
-
   prog.device_ = 0;
   prog.grid((inp_size + 63) / 64);
   prog.block(block_size);
@@ -2163,11 +2161,6 @@ void testGPU_FusionReduction() {
   tv2->axis(-1)->parallelize(ParallelType::TIDx);
   tv3->axis(-1)->parallelize(ParallelType::TIDx);
 
-  // for(auto expr : fusion.exprs(true))
-  //   std::cout<<expr<<std::endl;
-  // GPULower lower(&fusion);
-  // lower.printKernel(std::cout);
-
   int numel_x = 65000;
   int numel_y = 1025;

@@ -2585,7 +2578,8 @@ void testGPU_FusionReductionTFT() {
 
 void testGPU_FusionSimpleBCast() {
   {
-    Fusion fusion;
+    torch::jit::fuser::cuda::CudaKernel prog;
+    Fusion& fusion = *prog.fusion_;
     FusionGuard fg(&fusion);
 
     // Set up your input tensor views
@@ -2594,17 +2588,45 @@ void testGPU_FusionSimpleBCast() {
     fusion.addInput(tv0);
     fusion.addInput(tv1);
 
-    TensorView* tv2 = add(tv0, tv1);
+    TensorView* tv2 = broadcast(tv0, {false, false, true});
+    TensorView* tv3 = broadcast(tv1, {true, false, false});
 
-    // tv1[I0, R1] = tv0[I0, I1]
-    TensorView* tv3 = broadcast(tv2, {false, true, true, false});
-
-    Val* tv4 = mul(tv3, makeDummyTensor(4));
+    TensorView* tv4 = add(tv2, tv3);
+    tv4->split(-1, 4);
+    tv4->split(0, 8);
     fusion.addOutput(tv4);
+
+    tv0->computeAt(tv4, -1);
+    tv1->computeAt(tv4, -1);
+
+    tv4->axis(0)->parallelize(ParallelType::BIDx);
+    tv4->axis(-1)->parallelize(ParallelType::TIDx);
+
+    size_t x = 63, y = 33, z = 15;
+
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+
+    at::Tensor t0 = at::randn({x, y}, options);
+    at::Tensor t1 = at::randn({y, z}, options);
+
+    at::Tensor cg_output = at::empty({x, y, z}, options);
+
+    prog.device_ = 0;
+    prog.grid(ceilDiv_(x, 8));
+    prog.block(4);
+    torch::jit::fuser::cuda::compileKernel(&prog);
+    torch::jit::fuser::cuda::runTestKernel(&prog, {t0, t1}, {cg_output});
+
+    auto t2 = t0.unsqueeze(-1).expand({x, y, z});
+    auto t3 = t1.expand({x, y, z});
+    auto t4 = t2.add(t3);
+
+    TORCH_CHECK(t4.allclose(cg_output));
   }
 
   {
-    Fusion fusion;
+    torch::jit::fuser::cuda::CudaKernel prog;
+    Fusion& fusion = *prog.fusion_;
     FusionGuard fg(&fusion);
 
     // Set up your input tensor views
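For orientation, the broadcast masks in this hunk mark which output positions are new broadcast dimensions. A rough eager-mode PyTorch analogue of the kernel being built, mirroring the test's own reference computation (shapes as in the test; this sketch is not part of the commit):

import torch

x, y, z = 63, 33, 15
t0 = torch.randn(x, y)
t1 = torch.randn(y, z)

# broadcast(tv0, {false, false, true}): [x, y] -> [x, y, 1 (broadcast)]
t2 = t0.unsqueeze(-1)
# broadcast(tv1, {true, false, false}): [y, z] -> [1 (broadcast), y, z]
t3 = t1.unsqueeze(0)

# the pointwise add then broadcasts both operands to [x, y, z]
t4 = t2 + t3
assert t4.shape == (x, y, z)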
@@ -2613,17 +2635,127 @@ void testGPU_FusionSimpleBCast() {
     fusion.addInput(tv0);
     fusion.addInput(tv1);
 
-    TensorView* tv2 = broadcast(tv0, {true, false, false});
-    TensorView* tv3 = broadcast(tv1, {false, false, true});
+    // TODO: add pointwise ops at the beginning, before the bcast.
+
+    TensorView* tv2 = broadcast(tv0, {false, false, true});
+    TensorView* tv3 = broadcast(tv1, {true, false, false});
+
+    TensorView* tv4 = add(tv2, tv3);
+
+    tv4->merge(0, 1);
 
-    TensorView* tv4 = mul(tv3, tv2);
     fusion.addOutput(tv4);
 
     tv0->computeAt(tv4, -1);
     tv1->computeAt(tv4, -1);
 
-    // GPULower lower(&fusion);
-    // lower.printKernel(std::cout);
+    tv4->axis(0)->parallelize(ParallelType::BIDx);
+
+    size_t x = 63, y = 33, z = 15;
+
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+
+    at::Tensor t0 = at::randn({x, y}, options);
+    at::Tensor t1 = at::randn({y, z}, options);
+
+    at::Tensor cg_output = at::empty({x, y, z}, options);
+
+    prog.device_ = 0;
+    prog.grid(x * y);
+    prog.block(1);
+    torch::jit::fuser::cuda::compileKernel(&prog);
+    torch::jit::fuser::cuda::runTestKernel(&prog, {t0, t1}, {cg_output});
+
+    auto t2 = t0.unsqueeze(-1).expand({x, y, z});
+    auto t3 = t1.expand({x, y, z});
+    auto t4 = t2.add(t3);
+
+    TORCH_CHECK(t4.allclose(cg_output));
+  }
+}
+
+void testGPU_FusionSimpleGemm() {
+  {
+    torch::jit::fuser::cuda::CudaKernel prog;
+    Fusion& fusion = *prog.fusion_;
+    FusionGuard fg(&fusion);
+
+    // Set up your input tensor views
+    TensorView* tv0 = makeDummyTensor(2); // M, K
+    TensorView* tv1 = makeDummyTensor(2); // K, N
+    fusion.addInput(tv0);
+    fusion.addInput(tv1);
+
+    TensorView* tv2 = broadcast(tv0, {false, false, true});
+    // tv2[I0, I1, B] = tv0[I0, I1]
+
+    TensorView* tv3 = broadcast(tv1, {true, false, false});
+    // tv3[B, I1, I2] = tv1[I1, I2]
+
+    // tv4[I0, I1, I2] = tv2[I0, I1, B] * tv3[B, I1, I2]
+    TensorView* tv4 = mul(tv2, tv3);
+    // tv5[I0, R1, I2] = tv4[I0, I1, I2]
+    TensorView* tv5 = sum(tv4, {1});
+    fusion.addOutput(tv5);
+
+    tv5->split(1, 32);
+    // tv5[I0, R1o, R1i{32}, I2]
+
+    auto tv6 = tv5->rFactor({1});
+    // tv6[I0, R1o, I1i{32}, I2] = tv4[I0, I1, I2]
+    // tv5[I0,    , R1i{32}, I2] = tv6[I0, R1o, I1i{32}, I2]
+
+    tv5->split(0, 4);
+    tv5->split(-1, 4);
+    // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]
+    // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]
+
+    tv0->computeAt(tv5, -1);
+    tv1->computeAt(tv5, -1);
+
+    // tv6[I0o, I0i{4}, R1o, I1i{32}, I2o, I2i{4}]
+    // tv5[I0o, I0i{4},    , R1i{32}, I2o, I2i{4}]
+    // --> (the | symbolizes the compute-at location)
+    // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, I1o]
+    // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, R1o]
+    // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
+
+    tv0->computeAt(tv6, -1);
+    tv1->computeAt(tv6, -1);
+    // tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, I1o |]
+    // tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, R1o |]
+    // tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]
+
+    tv5->axis(0)->parallelize(ParallelType::BIDz);
+    tv5->axis(1)->parallelize(ParallelType::TIDz);
+
+    tv5->axis(-2)->parallelize(ParallelType::BIDy);
+    tv5->axis(-1)->parallelize(ParallelType::TIDy);
+
+    tv5->axis(2)->parallelize(ParallelType::TIDx);
+    tv6->axis(2)->parallelize(ParallelType::TIDx);
+
+    size_t M = 65, K = 33, N = 17;
+
+    auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+
+    at::Tensor t0 = at::randn({M, K}, options);
+    at::Tensor t1 = at::randn({K, N}, options);
+
+    at::Tensor cg_output = at::empty({M, N}, options);
+
+    prog.device_ = 0;
+    prog.grid(1, ceilDiv_(N, 4), ceilDiv_(M, 4));
+
+    prog.block(32, 4, 4);
+    torch::jit::fuser::cuda::compileKernel(&prog);
+    torch::jit::fuser::cuda::runTestKernel(&prog, {t0, t1}, {cg_output});
+
+    auto t2 = t0.matmul(t1);
+    TORCH_CHECK(
+        t2.allclose(cg_output, 1e-5, 1e-5),
+        "Error of: ",
+        t2.sub(cg_output).abs().max());
   }
 }
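The SimpleGemm test added above composes a matmul from broadcast, pointwise multiply, and a sum reduction over K. A hedged eager-mode sketch of the identity the test relies on (its TORCH_CHECK validates against t0.matmul(t1); this sketch is not part of the commit):

import torch

M, K, N = 65, 33, 17
t0 = torch.randn(M, K)
t1 = torch.randn(K, N)

# tv2[M, K, b] * tv3[b, K, N] -> [M, K, N], then reduce over K (dim 1)
prod = t0.unsqueeze(-1) * t1.unsqueeze(0)
out = prod.sum(dim=1)

assert torch.allclose(out, t0.matmul(t1), rtol=1e-5, atol=1e-5)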

@@ -2630,3 +2762,77 @@
+// This test currently requires a combination of broadcast and reduction
+// operations and a parallelization strategy that is currently not supported.
+// It is a goal to get this example working, and this test is added so we
+// can continue working on getting this example fixed. Right now it
+// produces an incorrect result. Either we need to error out coherently on
+// the optimization strategies we don't support and change this test to one
+// we do support, or we need to get this schedule working correctly.
+void testGPU_FusionSoftmax() {
+  torch::jit::fuser::cuda::CudaKernel prog;
+  Fusion& fusion = *prog.fusion_;
+  FusionGuard fg(&fusion);
+
+  // Set up your input tensor views
+  TensorView* input_tv0 = makeDummyTensor(3);
+  fusion.addInput(input_tv0);
+
+  TensorView* max_val_tv1 =
+      reductionOp(BinaryOpType::Max, {2}, new Float(0), input_tv0);
+  TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {false, false, true});
+  TensorView* exp_tv3 = sub(input_tv0, bcast_max_tv2);
+  TensorView* sum_exp_tv4 =
+      reductionOp(BinaryOpType::Add, {2}, new Float(0), exp_tv3);
+  TensorView* bcast_sum_tv5 = broadcast(sum_exp_tv4, {false, false, true});
+  TensorView* output_tv6 = div(exp_tv3, bcast_sum_tv5);
+
+  max_val_tv1->split(-1, 32);
+  TensorView* max_val_rf_tv7 = max_val_tv1->rFactor({-2});
+  sum_exp_tv4->split(-1, 32);
+  TensorView* sum_exp_rf_tv8 = sum_exp_tv4->rFactor({-2});
+
+  exp_tv3->computeAt(sum_exp_rf_tv8, {2});
+
+  max_val_rf_tv7->axis(0)->parallelize(ParallelType::BIDx);
+  max_val_tv1->axis(0)->parallelize(ParallelType::BIDx);
+  bcast_max_tv2->axis(0)->parallelize(ParallelType::BIDx);
+  sum_exp_rf_tv8->axis(0)->parallelize(ParallelType::BIDx);
+  sum_exp_tv4->axis(0)->parallelize(ParallelType::BIDx);
+  bcast_sum_tv5->axis(0)->parallelize(ParallelType::BIDx);
+  output_tv6->axis(0)->parallelize(ParallelType::BIDx);
+
+  max_val_rf_tv7->axis(1)->parallelize(ParallelType::BIDy);
+  max_val_tv1->axis(1)->parallelize(ParallelType::BIDy);
+  bcast_max_tv2->axis(1)->parallelize(ParallelType::BIDy);
+  sum_exp_rf_tv8->axis(1)->parallelize(ParallelType::BIDy);
+  sum_exp_tv4->axis(1)->parallelize(ParallelType::BIDy);
+  bcast_sum_tv5->axis(1)->parallelize(ParallelType::BIDy);
+  output_tv6->axis(1)->parallelize(ParallelType::BIDy);
+
+  max_val_rf_tv7->axis(-1)->parallelize(ParallelType::TIDx);
+  max_val_tv1->axis(-1)->parallelize(ParallelType::TIDx);
+  bcast_max_tv2->axis(-1)->parallelize(ParallelType::TIDx);
+  exp_tv3->axis(-1)->parallelize(ParallelType::TIDx);
+  sum_exp_rf_tv8->axis(-1)->parallelize(ParallelType::TIDx);
+  sum_exp_tv4->axis(-1)->parallelize(ParallelType::TIDx);
+  bcast_sum_tv5->axis(-1)->parallelize(ParallelType::TIDx);
+  output_tv6->axis(-1)->parallelize(ParallelType::TIDx);
+
+  fusion.addOutput(output_tv6);
+
+  prog.device_ = 0;
+  prog.grid(32, 32);
+  prog.block(32);
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn({32, 32, 128}, options);
+  at::Tensor cg_output = at::empty({32, 32, 128}, options);
+  torch::jit::fuser::cuda::compileKernel(&prog);
+  torch::jit::fuser::cuda::runTestKernel(&prog, {t0}, {cg_output});
+
+  auto t2 = at::_softmax(t0, -1, false);
+  // TORCH_CHECK(
+  //     t2.allclose(cg_output, 1e-5, 1e-5),
+  //     "Error of: ",
+  //     t2.sub(cg_output).abs().max());
+}
 // Similar to FusionReduction but uses grid reduction
 void testGPU_FusionGridReduction1() {
   const int gdimx = 32;
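The Softmax test builds its pipeline from max-reduce, broadcast, subtract, sum-reduce, broadcast, and divide. For reference, a hedged eager-mode sketch of the standard numerically stable softmax decomposition those pieces target (the test itself compares against at::_softmax and currently leaves the check commented out; this sketch is not part of the commit):

import torch

t0 = torch.randn(32, 32, 128)

# max-reduce, broadcast-subtract, exponentiate, sum-reduce, broadcast-divide
m = t0.max(dim=-1, keepdim=True).values
e = (t0 - m).exp()
out = e / e.sum(dim=-1, keepdim=True)

assert torch.allclose(out, torch.softmax(t0, dim=-1), atol=1e-6)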

test/cpp/jit/tests.h

Lines changed: 2 additions & 0 deletions
@@ -143,6 +143,8 @@ namespace jit {
   _(GPU_FusionReduction5) \
   _(GPU_FusionReductionTFT) \
   _(GPU_FusionSimpleBCast) \
+  _(GPU_FusionSimpleGemm) \
+  _(GPU_FusionSoftmax) \
   _(GPU_FusionGridReduction1) \
   _(GPU_FusionGridReduction2) \
   _(GPU_FusionGridReduction3dim1) \

torch/csrc/jit/codegen/cuda/arith.cpp

Lines changed: 2 additions & 2 deletions
@@ -460,9 +460,9 @@ TORCH_CUDA_API TensorView* broadcast(
     if (ent)
       n_broadcasts++;
   TORCH_CHECK(
-      nBCastDims - n_broadcasts == inp->nDims(),
+      nBCastDims - n_broadcasts == inp->domain()->noReductions().size(),
       "Invalid broadcast, number of false entries in is_broadcast_dim expected to be ",
-      inp->nDims(),
+      inp->domain()->noReductions().size(),
       " but received ",
       nBCastDims - n_broadcasts);
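The updated check enforces that the non-broadcast (false) entries of is_broadcast_dim line up one-to-one with the input's non-reduction dimensions, rather than its full rank. A hypothetical Python mirror of the invariant (the function name and signature are illustrative, not an API in this commit):

def check_broadcast_mask(is_broadcast_dim, n_non_reduction_dims):
    # nBCastDims - n_broadcasts in the C++ check equals the number of
    # false entries in the mask
    n_false = sum(1 for b in is_broadcast_dim if not b)
    if n_false != n_non_reduction_dims:
        raise ValueError(
            "Invalid broadcast, number of false entries expected to be "
            f"{n_non_reduction_dims} but received {n_false}")

# a 2-D input broadcast to 3-D: three entries, one broadcast, 3 - 1 == 2
check_broadcast_mask([False, False, True], 2)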

torch/csrc/jit/codegen/cuda/arith.h

Lines changed: 2 additions & 1 deletion
@@ -60,7 +60,8 @@ TORCH_CUDA_API TensorView* broadcast(
     TensorView* inp,
     const std::vector<bool>& is_broadcast_dim);
 
-// BINARY OPAERATIONS
+// BINARY OPERATIONS
+// add
 TORCH_CUDA_API Val* add(Val* v1, Val* v2);
 TORCH_CUDA_API TensorView* add(TensorView* v1, Val* v2);
 TORCH_CUDA_API TensorView* add(Val* v1, TensorView* v2);
