Merged
51 commits
95287c2
Simplify a few test cases
tlemo Apr 27, 2020
33a1b76
ExpressionEvaluator
tlemo May 26, 2020
2adcd3c
Stricter EvaluationContext binding rules
tlemo May 26, 2020
a15ca79
Fix clang-format errors
tlemo May 26, 2020
15fcfe0
Switch to Int::ScalarType
tlemo May 26, 2020
ca2ba1d
Avoid a fight with clang-tidy
tlemo May 26, 2020
9a1c4c5
Add an optional arc from TensorView to its root domain
tlemo May 27, 2020
aa1e89e
Check the numbers of kernel input and output parameters
naoyam May 13, 2020
787540d
Checks kernel arguments
naoyam May 18, 2020
2763c13
Prefer pointers over references
naoyam May 19, 2020
5d911c4
Bug fix
naoyam May 19, 2020
bc32d63
Fix accidental construction of IValue
naoyam May 19, 2020
b09fedd
Use noReduction
naoyam May 19, 2020
f7f261f
Add const to const pointer
naoyam May 19, 2020
de40ff6
Make an integer tensor an error as it is not yet supported
naoyam May 19, 2020
6ef8a8d
clang-tidy
naoyam May 27, 2020
1648bef
Incorporate review feedback
tlemo May 28, 2020
74b2ed5
added lerp support in parser
jjsjann123 May 22, 2020
977d1fe
add missing addcmul parser and tests
jjsjann123 May 26, 2020
d32e77e
clang_format
jjsjann123 May 29, 2020
42213f2
Return TensorView* from binary/compound/ternary ops
naoyam May 30, 2020
c38b5cd
clang-format
naoyam May 31, 2020
71c95c8
Use TensorView* param in reductionOp and sum
naoyam Jun 1, 2020
2604af6
Prefer as instead of static_cast
naoyam Jun 1, 2020
2f909f2
Transform replay refactor (#53)
csarofeen Jun 4, 2020
82226e5
python test fixes (#52)
jjsjann123 Jun 4, 2020
dcb796e
[nvFuser] add torch.jit.fuser context manager (#38993) (#54)
jjsjann123 Jun 5, 2020
65ff3eb
Add another reduction example, change fusion printMath.
csarofeen Jun 5, 2020
272aa1b
Small test fix.
csarofeen Jun 5, 2020
5726ab4
Change Reduction4 test to use TIDx.x
csarofeen Jun 5, 2020
73d7401
Minor cleanup.
csarofeen Jun 5, 2020
e1e5667
Clean up some noexcepts.
csarofeen Jun 5, 2020
490d101
More cleanup.
csarofeen Jun 5, 2020
e3d8441
Refactor computeAt, get first broadcast example working.
csarofeen Jun 6, 2020
831b222
Validate first non-trivial broadcast kernel.
csarofeen Jun 6, 2020
1a88ce8
Fix replay when broadcast is merged with non-broadcast dim.
csarofeen Jun 7, 2020
7c15591
Add constness in replay and index compute.
csarofeen Jun 7, 2020
581223b
Add another broadcast test. Rework index computation for producers, b…
csarofeen Jun 8, 2020
21e2989
Val isCconst fix.
csarofeen Jun 8, 2020
71e4219
Add dot product gemm example.
csarofeen Jun 8, 2020
550ca71
Clang.
csarofeen Jun 8, 2020
4693817
Minor bug fixes.
csarofeen Jun 9, 2020
943a15a
Format and add comments to GEMM test.
csarofeen Jun 9, 2020
9dd01bd
WIP: Fix for enabling broadcast after reduction plus a Softmax test. …
kevinstephano Jun 12, 2020
b48a826
Backout bad merge conflict resolutions.
csarofeen Jun 12, 2020
61ca498
More post rebase cleanup.
csarofeen Jun 12, 2020
fa80486
Refix a few tests. Some from a bad rebase.
csarofeen Jun 16, 2020
ad017a7
Address comments.
csarofeen Jun 17, 2020
95f9f80
Missed some review comments.
csarofeen Jun 18, 2020
2b77eb7
Merge branch '20_6_11_devel' of https://www.github.com/csarofeen/pyto…
csarofeen Jun 18, 2020
629ec01
tmp
csarofeen Jun 18, 2020
250 changes: 228 additions & 22 deletions test/cpp/jit/test_gpu.cpp
@@ -1060,7 +1060,7 @@ void testGPU_FusionAdvancedComputeAt() {
TORCH_CHECK(tv0->getComputeAtView() == tv3);
TORCH_CHECK(tv1->getComputeAtView() == tv4);
TORCH_CHECK(tv2->getComputeAtView() == tv4);
TORCH_CHECK(tv3->getComputeAtView() == tv6);
TORCH_CHECK(tv3->getComputeAtView() == tv5);
TORCH_CHECK(tv4->getComputeAtView() == tv5);
TORCH_CHECK(tv5->getComputeAtView() == tv6);
TORCH_CHECK(!tv6->hasComputeAt());
@@ -1074,7 +1074,7 @@ void testGPU_FusionAdvancedComputeAt() {

tv0->computeAt(tv6, 1);

TORCH_CHECK(tv0->getComputeAtView() == tv3 && tv0->nDims() == 3);
TORCH_CHECK(tv0->getComputeAtView() == tv6 && tv0->nDims() == 3);
TORCH_CHECK(tv1->getComputeAtView() == tv4 && tv1->nDims() == 3);
TORCH_CHECK(tv2->getComputeAtView() == tv4 && tv2->nDims() == 3);
TORCH_CHECK(tv3->getComputeAtView() == tv6 && tv3->nDims() == 3);
@@ -1148,6 +1148,7 @@ void testGPU_FusionAdvancedComputeAt() {
fusion.addOutput(tv6);

tv2->computeAt(tv4, 1);

TORCH_CHECK(!tv0->hasComputeAt());
TORCH_CHECK(!tv1->hasComputeAt());
TORCH_CHECK(tv2->getComputeAtView() == tv4);
@@ -1495,9 +1496,6 @@ void testGPU_FusionLoopUnroll() {

int inp_size = 129 * 13 * 3;

// GPULower lower(&fusion);
// lower.printKernel(std::cout);

prog.device_ = 0;
prog.grid((inp_size + 63) / 64);
prog.block(block_size);
@@ -2163,11 +2161,6 @@ void testGPU_FusionReduction() {
tv2->axis(-1)->parallelize(ParallelType::TIDx);
tv3->axis(-1)->parallelize(ParallelType::TIDx);

// for(auto expr : fusion.exprs(true))
// std::cout<<expr<<std::endl;
// GPULower lower(&fusion);
// lower.printKernel(std::cout);

int numel_x = 65000;
int numel_y = 1025;

@@ -2585,7 +2578,8 @@ void testGPU_FusionReductionTFT() {

void testGPU_FusionSimpleBCast() {
{
Fusion fusion;
torch::jit::fuser::cuda::CudaKernel prog;
Fusion& fusion = *prog.fusion_;
FusionGuard fg(&fusion);

// Set up your input tensor views
@@ -2594,17 +2588,45 @@
fusion.addInput(tv0);
fusion.addInput(tv1);

TensorView* tv2 = add(tv0, tv1);
TensorView* tv2 = broadcast(tv0, {false, false, true});
TensorView* tv3 = broadcast(tv1, {true, false, false});

// tv1[I0, R1] = tv0[I0, I1]
TensorView* tv3 = broadcast(tv2, {false, true, true, false});

Val* tv4 = mul(tv3, makeDummyTensor(4));
TensorView* tv4 = add(tv2, tv3);
tv4->split(-1, 4);
tv4->split(0, 8);
fusion.addOutput(tv4);

tv0->computeAt(tv4, -1);
tv1->computeAt(tv4, -1);

tv4->axis(0)->parallelize(ParallelType::BIDx);
tv4->axis(-1)->parallelize(ParallelType::TIDx);

size_t x = 63, y = 33, z = 15;

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

at::Tensor t0 = at::randn({x, y}, options);
at::Tensor t1 = at::randn({y, z}, options);

at::Tensor cg_output = at::empty({x, y, z}, options);

prog.device_ = 0;
prog.grid(ceilDiv_(x, 8));
prog.block(4);
torch::jit::fuser::cuda::compileKernel(&prog);
torch::jit::fuser::cuda::runTestKernel(&prog, {t0, t1}, {cg_output});

auto t2 = t0.unsqueeze(-1).expand({x, y, z});
auto t3 = t1.expand({x, y, z});
auto t4 = t2.add(t3);

TORCH_CHECK(t4.allclose(cg_output));
}

{
Fusion fusion;
torch::jit::fuser::cuda::CudaKernel prog;
Fusion& fusion = *prog.fusion_;
FusionGuard fg(&fusion);

// Set up your input tensor views
@@ -2613,20 +2635,204 @@
fusion.addInput(tv0);
fusion.addInput(tv1);

TensorView* tv2 = broadcast(tv0, {true, false, false});
TensorView* tv3 = broadcast(tv1, {false, false, true});
// TODO: add pointwise ops at the beginning, before the bcast.

TensorView* tv2 = broadcast(tv0, {false, false, true});
TensorView* tv3 = broadcast(tv1, {true, false, false});

TensorView* tv4 = add(tv2, tv3);

tv4->merge(0, 1);

TensorView* tv4 = mul(tv3, tv2);
fusion.addOutput(tv4);

tv0->computeAt(tv4, -1);
tv1->computeAt(tv4, -1);

// GPULower lower(&fusion);
// lower.printKernel(std::cout);
tv4->axis(0)->parallelize(ParallelType::BIDx);

size_t x = 63, y = 33, z = 15;

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

at::Tensor t0 = at::randn({x, y}, options);
at::Tensor t1 = at::randn({y, z}, options);

at::Tensor cg_output = at::empty({x, y, z}, options);

prog.device_ = 0;
prog.grid(x * y);
prog.block(1);
torch::jit::fuser::cuda::compileKernel(&prog);
torch::jit::fuser::cuda::runTestKernel(&prog, {t0, t1}, {cg_output});

auto t2 = t0.unsqueeze(-1).expand({x, y, z});
auto t3 = t1.expand({x, y, z});
auto t4 = t2.add(t3);

TORCH_CHECK(t4.allclose(cg_output));
}
}

void testGPU_FusionSimpleGemm() {
{
torch::jit::fuser::cuda::CudaKernel prog;
Fusion& fusion = *prog.fusion_;
FusionGuard fg(&fusion);

// Set up your input tensor views
TensorView* tv0 = makeDummyTensor(2); // M, K
TensorView* tv1 = makeDummyTensor(2); // K, N
fusion.addInput(tv0);
fusion.addInput(tv1);

TensorView* tv2 = broadcast(tv0, {false, false, true});
// tv2[I0, I1, B] = tv0[I0, I1]

TensorView* tv3 = broadcast(tv1, {true, false, false});
// tv3[B, I1, I2] = tv1[I1, I2]

// tv4[I0, I1, I2] = tv2[I0, I1, B] * tv3[B, I1, I2]
TensorView* tv4 = mul(tv2, tv3);
// tv5[I0, R1, I2] = tv4[I0, I1, I2]
TensorView* tv5 = sum(tv4, {1});
fusion.addOutput(tv5);

tv5->split(1, 32);
// tv5[I0, R1o, R1i{32}, I2]

auto tv6 = tv5->rFactor({1});
// tv6[I0, R1o, I1i{32}, I2] = tv4[I0, I1, I2]
// tv5[I0, , R1i{32}, I2] = tv6[I0, R1o, I1i{32}, I2]

tv5->split(0, 4);
tv5->split(-1, 4);
// tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}]

tv0->computeAt(tv5, -1);
tv1->computeAt(tv5, -1);

// tv6[I0o, I0i{4}, R1o, I1i{32}, I2o, I2i{4}]
// tv5[I0o, I0i{4}, , R1i{32}, I2o, I2i{4}]
// --> ('|' marks the compute-at location)
// tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, I1o]
// tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}|, R1o]
// tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]

tv0->computeAt(tv6, -1);
tv1->computeAt(tv6, -1);
// tv4[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, I1o |]
// tv6[I0o, I0i{4}, I1i{32}, I2o, I2i{4}, R1o |]
// tv5[I0o, I0i{4}, R1i{32}, I2o, I2i{4}|]

tv5->axis(0)->parallelize(ParallelType::BIDz);
tv5->axis(1)->parallelize(ParallelType::TIDz);

tv5->axis(-2)->parallelize(ParallelType::BIDy);
tv5->axis(-1)->parallelize(ParallelType::TIDy);

tv5->axis(2)->parallelize(ParallelType::TIDx);
tv6->axis(2)->parallelize(ParallelType::TIDx);

size_t M = 65, K = 33, N = 17;
Collaborator: constexpr?

Collaborator: Probably not. We want to verify that the tests work with dynamically sized tensors.

Owner (author): I don't think constexpr buys us anything here.

Collaborator: These are compile-time constants, right? If that's the case, constexpr is the most explicit language construct to express it (it would also prevent accidental value updates).
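As an illustration of the alternative under discussion, a minimal sketch (hypothetical, not the merged code; the test keeps plain size_t locals so the kernel is still exercised with runtime-sized tensors):

// Hypothetical constexpr variant from the review thread above: it makes
// the compile-time intent explicit and prevents accidental reassignment
// of the test dimensions.
constexpr size_t M = 65, K = 33, N = 17;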


auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);

at::Tensor t0 = at::randn({M, K}, options);
at::Tensor t1 = at::randn({K, N}, options);

at::Tensor cg_output = at::empty({M, N}, options);

prog.device_ = 0;
prog.grid(1, ceilDiv_(N, 4), ceilDiv_(M, 4));

prog.block(32, 4, 4);
torch::jit::fuser::cuda::compileKernel(&prog);
torch::jit::fuser::cuda::runTestKernel(&prog, {t0, t1}, {cg_output});

auto t2 = t0.matmul(t1);
TORCH_CHECK(
t2.allclose(cg_output, 1e-5, 1e-5),
"Error of: ",
t2.sub(cg_output).abs().max());
}
}

// This test currently requires a combination of broadcast and reduction
// operations and a parallelization strategy that is not yet supported.
// The goal is to get this example working, and the test is added so we
// can keep working on fixing it. Right now it produces an incorrect
// result. Either we need to error coherently on the unsupported
// optimization strategy and change this test to one we do support, or
// we need to get this schedule working correctly.
void testGPU_FusionSoftmax() {
torch::jit::fuser::cuda::CudaKernel prog;
Fusion& fusion = *prog.fusion_;
FusionGuard fg(&fusion);

// Set up your input tensor views
TensorView* input_tv0 = makeDummyTensor(3);
fusion.addInput(input_tv0);

TensorView* max_val_tv1 =
reductionOp(BinaryOpType::Max, {2}, new Float(0), input_tv0);
TensorView* bcast_max_tv2 = broadcast(max_val_tv1, {false, false, true});
TensorView* exp_tv3 = sub(input_tv0, bcast_max_tv2);
TensorView* sum_exp_tv4 =
reductionOp(BinaryOpType::Add, {2}, new Float(0), exp_tv3);
TensorView* bcast_sum_tv5 = broadcast(sum_exp_tv4, {false, false, true});
TensorView* output_tv6 = div(exp_tv3, bcast_sum_tv5);

max_val_tv1->split(-1, 32);
TensorView* max_val_rf_tv7 = max_val_tv1->rFactor({-2});
sum_exp_tv4->split(-1, 32);
TensorView* sum_exp_rf_tv8 = sum_exp_tv4->rFactor({-2});

exp_tv3->computeAt(sum_exp_rf_tv8, {2});

max_val_rf_tv7->axis(0)->parallelize(ParallelType::BIDx);
max_val_tv1->axis(0)->parallelize(ParallelType::BIDx);
bcast_max_tv2->axis(0)->parallelize(ParallelType::BIDx);
sum_exp_rf_tv8->axis(0)->parallelize(ParallelType::BIDx);
sum_exp_tv4->axis(0)->parallelize(ParallelType::BIDx);
bcast_sum_tv5->axis(0)->parallelize(ParallelType::BIDx);
output_tv6->axis(0)->parallelize(ParallelType::BIDx);

max_val_rf_tv7->axis(1)->parallelize(ParallelType::BIDy);
max_val_tv1->axis(1)->parallelize(ParallelType::BIDy);
bcast_max_tv2->axis(1)->parallelize(ParallelType::BIDy);
sum_exp_rf_tv8->axis(1)->parallelize(ParallelType::BIDy);
sum_exp_tv4->axis(1)->parallelize(ParallelType::BIDy);
bcast_sum_tv5->axis(1)->parallelize(ParallelType::BIDy);
output_tv6->axis(1)->parallelize(ParallelType::BIDy);

max_val_rf_tv7->axis(-1)->parallelize(ParallelType::TIDx);
max_val_tv1->axis(-1)->parallelize(ParallelType::TIDx);
bcast_max_tv2->axis(-1)->parallelize(ParallelType::TIDx);
exp_tv3->axis(-1)->parallelize(ParallelType::TIDx);
sum_exp_rf_tv8->axis(-1)->parallelize(ParallelType::TIDx);
sum_exp_tv4->axis(-1)->parallelize(ParallelType::TIDx);
bcast_sum_tv5->axis(-1)->parallelize(ParallelType::TIDx);
output_tv6->axis(-1)->parallelize(ParallelType::TIDx);

fusion.addOutput(output_tv6);

prog.device_ = 0;
prog.grid(32, 32);
prog.block(32);
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({32, 32, 128}, options);
at::Tensor cg_output = at::empty({32, 32, 128}, options);
torch::jit::fuser::cuda::compileKernel(&prog);
torch::jit::fuser::cuda::runTestKernel(&prog, {t0}, {cg_output});

auto t2 = at::_softmax(t0, -1, false);
// TORCH_CHECK(
// t2.allclose(cg_output, 1e-5, 1e-5),
// "Error of: ",
// t2.sub(cg_output).abs().max());
}
// Similar to FusionReduction but uses grid reduction
void testGPU_FusionGridReduction1() {
const int gdimx = 32;
2 changes: 2 additions & 0 deletions test/cpp/jit/tests.h
@@ -143,6 +143,8 @@ namespace jit {
_(GPU_FusionReduction5) \
_(GPU_FusionReductionTFT) \
_(GPU_FusionSimpleBCast) \
_(GPU_FusionSimpleGemm) \
_(GPU_FusionSoftmax) \
_(GPU_FusionGridReduction1) \
_(GPU_FusionGridReduction2) \
_(GPU_FusionGridReduction3dim1) \
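For context, these backslash-continued _(Name) entries belong to a test-list macro. A minimal sketch of the assumed registration pattern (macro and helper names here are illustrative, not verbatim from tests.h):

// Sketch (assumed): each _(Name) entry in the list macro is stamped out
// into a declaration of testName(); the harness reuses the same list to
// invoke the tests.
#define FORALL_GPU_FUSION_TESTS(_) \
  _(GPU_FusionSimpleGemm)          \
  _(GPU_FusionSoftmax)

#define DECLARE_GPU_FUSION_TEST(name) void test##name();
FORALL_GPU_FUSION_TESTS(DECLARE_GPU_FUSION_TEST)
#undef DECLARE_GPU_FUSION_TEST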
4 changes: 2 additions & 2 deletions torch/csrc/jit/codegen/cuda/arith.cpp
@@ -460,9 +460,9 @@ TORCH_CUDA_API TensorView* broadcast(
if (ent)
n_broadcasts++;
TORCH_CHECK(
nBCastDims - n_broadcasts == inp->nDims(),
nBCastDims - n_broadcasts == inp->domain()->noReductions().size(),
"Invalid broadcast, number of false entries in is_broadcast_dim expected to be ",
inp->nDims(),
inp->domain()->noReductions().size(),
" but received ",
nBCastDims - n_broadcasts);

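To see why the check moved from nDims() to noReductions().size(), consider a short sketch of broadcasting a reduction output, mirroring the softmax test above (assumed domain semantics, for illustration only):

// Sketch (assumed): a reduction output keeps its reduction axes in its
// domain, so nDims() over-counts the broadcastable dimensions.
TensorView* tv0 = makeDummyTensor(2); // [I0, I1]
TensorView* tv1 = sum(tv0, {1});      // [I0, R1], nDims() == 2
// is_broadcast_dim = {false, true} has one false entry, which must match
// noReductions().size() == 1 rather than nDims() == 2.
TensorView* tv2 = broadcast(tv1, {false, true});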
3 changes: 2 additions & 1 deletion torch/csrc/jit/codegen/cuda/arith.h
@@ -60,7 +60,8 @@ TORCH_CUDA_API TensorView* broadcast(
TensorView* inp,
const std::vector<bool>& is_broadcast_dim);

// BINARY OPAERATIONS
// BINARY OPERATIONS
// add
TORCH_CUDA_API Val* add(Val* v1, Val* v2);
TORCH_CUDA_API TensorView* add(TensorView* v1, Val* v2);
TORCH_CUDA_API TensorView* add(Val* v1, TensorView* v2);