From 1f0baf0b9078c9081864dc74b6be8bc3d38c5a0d Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Mon, 20 Jul 2020 14:35:34 -0700
Subject: [PATCH 1/9] Add a test case of computeAt cycles

---
 test/cpp/jit/test_gpu.cpp | 47 +++++++++++++++++++++++++++++++++++++++
 test/cpp/jit/tests.h      |  3 ++-
 2 files changed, 49 insertions(+), 1 deletion(-)

diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp
index 6962e52196dc..533954ff44e2 100644
--- a/test/cpp/jit/test_gpu.cpp
+++ b/test/cpp/jit/test_gpu.cpp
@@ -4503,6 +4503,53 @@ void testGPU_FusionIsOneInt() {
   TORCH_CHECK(!z->isOneInt());
 }
 
+// This is to verify no cycle of computeAt is created. A more complex
+// variation of this pattern appears in one of the Python tests
+// (test_random_topo).
+void testGPU_FusionComputeAtNonterminatingOutput() {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  TensorView* tv0 = makeDummyTensor(1);
+  fusion.addInput(tv0);
+
+  // Common intermediate tensor
+  auto tv1 = add(tv0, new Float(0));
+  // tv1 -> tv2
+  auto tv2 = add(tv1, new Float(0));
+  // tv1 -> tv3 -> tv4
+  auto tv3 = add(tv1, new Float(0));
+  auto tv4 = add(tv3, new Float(0));
+
+  // The order of adding outputs matters. If tv3 is added before tv4,
+  // it should be fine. However, if tv4 is added before tv3, there
+  // will be a cycle of tv3->tv4 and tv4->tv3. tv3->tv4 is created
+  // first, and then tv4->tv3 is created at the final phase of
+  // computeAt (ComputeAt::setupOutputs).
+  if (true) {
+    // A cycle of tv3 <-> tv4 will be created.
+    fusion.addOutput(tv2);
+    fusion.addOutput(tv4);
+    fusion.addOutput(tv3);
+  } else {
+    // This should work fine.
+    fusion.addOutput(tv2);
+    fusion.addOutput(tv3);
+    fusion.addOutput(tv4);
+  }
+
+  tv0->computeAt(tv2, -1);
+
+  fusion.printMath();
+
+  TORCH_CHECK(
+      !(tv3->getComputeAtView() == tv4 && tv4->getComputeAtView() == tv3),
+      "ComputeAt cycle detected between tv3 and tv4");
+
+  fusion.printKernel();
+  return;
+}
+
 } // namespace jit
 } // namespace torch
diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h
index 53cbe78cd68d..58b982ffbc20 100644
--- a/test/cpp/jit/tests.h
+++ b/test/cpp/jit/tests.h
@@ -184,7 +184,8 @@ namespace jit {
   _(GPU_FusionSymbolicReduction) \
   _(GPU_FusionUnrollWithAlloc)   \
   _(GPU_FusionIsZeroInt)         \
-  _(GPU_FusionIsOneInt)
+  _(GPU_FusionIsOneInt)          \
+  _(GPU_FusionComputeAtNonterminatingOutput)
 #else
 #define TH_FORALL_TESTS_CUDA(_) \
   _(ArgumentSpec)               \

From 035c9995afa5f799c050fb95c427bc0f043c760a Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Mon, 20 Jul 2020 15:09:10 -0700
Subject: [PATCH 2/9] Move getTerminatingOutputs to Fusion

---
 torch/csrc/jit/codegen/cuda/fusion.cpp       | 41 ++++++++++++++++++++
 torch/csrc/jit/codegen/cuda/fusion.h         |  6 +++
 torch/csrc/jit/codegen/cuda/iter_visitor.cpp | 27 +------------
 torch/csrc/jit/codegen/cuda/iter_visitor.h   |  2 -
 4 files changed, 49 insertions(+), 27 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp
index a53714f21e62..5431a2fe166e 100644
--- a/torch/csrc/jit/codegen/cuda/fusion.cpp
+++ b/torch/csrc/jit/codegen/cuda/fusion.cpp
@@ -38,6 +38,14 @@ std::vector<Expr*> ExprSort::getExprs(
   return es.exprs;
 }
 
+std::vector<Expr*> ExprSort::getExprs(
+    Fusion* fusion,
+    const std::vector<Val*>& from) {
+  ExprSort es;
+  es.traverseFrom(fusion, from, false);
+  return es.exprs;
+}
+
 void InputsOf::handle(Val* v) {
   if (FusionGuard::getCurFusion()->origin(v) == nullptr)
     inputs.emplace(v);
@@ -541,6 +549,39 @@ bool Fusion::hasGridReduction() {
   return false;
 }
 
+std::vector<Val*> Fusion::getTerminatingOutputs() {
+  FusionGuard fg(this);
+
+  std::unordered_set<Val*> used_vals;
+
+  const auto exprs = ExprSort::getExprs(
+      this, std::vector<Val*>(outputs().begin(), outputs().end()));
+
+  for (auto expr : exprs) {
+    for (auto inp : expr->inputs())
+      used_vals.emplace(inp);
+  }
+
+  std::unordered_set<Val*> terminating_outputs;
+  for (auto out : outputs()) {
+    if (used_vals.find(out) != used_vals.end())
+      continue;
+    terminating_outputs.emplace(out);
+  }
+
+  std::vector<Val*> sorted_outputs{terminating_outputs.begin(),
+                                   terminating_outputs.end()};
+
+  // Sort the outputs in order to give a deterministic traversal
+  // order.
+  std::sort(
+      sorted_outputs.begin(),
+      sorted_outputs.end(),
+      [](const Val* v0, const Val* v1) { return v0->name() < v1->name(); });
+
+  return sorted_outputs;
+}
+
 } // namespace fuser
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/codegen/cuda/fusion.h b/torch/csrc/jit/codegen/cuda/fusion.h
index f882390b8584..59bcac402946 100644
--- a/torch/csrc/jit/codegen/cuda/fusion.h
+++ b/torch/csrc/jit/codegen/cuda/fusion.h
@@ -79,6 +79,10 @@ class ExprSort : public IterVisitor {
       bool from_outputs_only,
       bool breadth_first,
       bool respect_compute_at);
+
+  static std::vector<Expr*> getExprs(
+      Fusion* fusion,
+      const std::vector<Val*>& from);
 };
 
 class InputsOf : public IterVisitor {
@@ -236,6 +240,8 @@ class TORCH_CUDA_API Fusion final {
     return outputs_;
   }
 
+  std::vector<Val*> getTerminatingOutputs();
+
   bool hasInput(const Val* val) const;
   bool hasOutput(const Val* val) const;
 
diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp
index e9a78110bd1d..5f5f9182d764 100644
--- a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp
+++ b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp
@@ -142,8 +142,7 @@ void IterVisitor::traverse_(
     TORCH_INTERNAL_ASSERT(false, "Not implemented yet.");
 
   if (from_outputs_only) {
-    auto term_outs = IterVisitor::getTerminatingOutputs(fusion);
-    std::vector<Val*> term_val_outs(term_outs.begin(), term_outs.end());
+    auto term_val_outs = fusion->getTerminatingOutputs();
     if (!term_val_outs.empty())
       traverseFrom(
           fusion, term_val_outs, traverse_all_paths, respect_compute_at);
@@ -179,6 +178,7 @@ void IterVisitor::traverseAllPaths(
 
 namespace {
 
+// TODO: Remove this in favor of ExprSort
 // Expr sort will take a fusion and return a topologically sorted list of
 // expressions.
 class Exprs : public IterVisitor {
@@ -222,29 +222,6 @@ class Inputs : public IterVisitor {
 
 } // namespace
 
-std::unordered_set<Val*> IterVisitor::getTerminatingOutputs(
-    Fusion* const fusion) {
-  FusionGuard fg(fusion);
-
-  std::unordered_set<Val*> used_vals;
-
-  const auto exprs = Exprs::getExprs(
-      fusion,
-      std::vector<Val*>(fusion->outputs().begin(), fusion->outputs().end()));
-
-  for (auto expr : exprs) {
-    for (auto inp : expr->inputs())
-      used_vals.emplace(inp);
-  }
-
-  std::unordered_set<Val*> terminating_outputs;
-  for (auto out : fusion->outputs())
-    if (used_vals.find(out) == used_vals.end())
-      terminating_outputs.emplace(out);
-
-  return terminating_outputs;
-}
-
 std::unordered_set<Val*> IterVisitor::getInputsTo(
     const std::vector<Val*>& vals) {
   return Inputs::getInputs(vals);
diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.h b/torch/csrc/jit/codegen/cuda/iter_visitor.h
index 117c5f1a244f..e8aed5f8e47d 100644
--- a/torch/csrc/jit/codegen/cuda/iter_visitor.h
+++ b/torch/csrc/jit/codegen/cuda/iter_visitor.h
@@ -117,8 +117,6 @@ class TORCH_CUDA_API IterVisitor : public OptOutDispatch {
       bool breadth_first = false,
      bool respect_compute_at = false);
 
-  static std::unordered_set<Val*> getTerminatingOutputs(Fusion* const);
-
   static std::unordered_set<Val*> getInputsTo(const std::vector<Val*>& vals);
 };
 

From 652e8603349cd1ff2cf2da38e795fb5fc6035909 Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Mon, 20 Jul 2020 15:11:53 -0700
Subject: [PATCH 3/9] Fix #200

---
 torch/csrc/jit/codegen/cuda/compute_at.cpp | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/torch/csrc/jit/codegen/cuda/compute_at.cpp b/torch/csrc/jit/codegen/cuda/compute_at.cpp
index f57163345c45..55df035e2fb4 100644
--- a/torch/csrc/jit/codegen/cuda/compute_at.cpp
+++ b/torch/csrc/jit/codegen/cuda/compute_at.cpp
@@ -387,12 +387,21 @@ void ComputeAt::setupOutputs() {
     return;
 
   std::vector<TensorView*> touched_output_order;
+  const auto& terminating_outputs =
+      FusionGuard::getCurFusion()->getTerminatingOutputs();
 
   for (auto out : FusionGuard::getCurFusion()->outputs()) {
     if (out->getValType() == ValType::TensorView) {
       if (tv_data.find(out->as<TensorView>()) != tv_data.end()) {
         if (tv_data[out->as<TensorView>()].touched()) {
-          touched_output_order.push_back(out->as<TensorView>());
+          // No need to adjust computeAt when an output is not
+          // a terminating output.
+          if (std::find(
+                  terminating_outputs.begin(),
+                  terminating_outputs.end(),
+                  out) != terminating_outputs.end()) {
+            touched_output_order.push_back(out->as<TensorView>());
+          }
         }
       }
     }

From 13cf2d995753dc4c19b0ecbdf2074949ffde0415 Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Mon, 20 Jul 2020 16:10:25 -0700
Subject: [PATCH 4/9] Do not try to fix but just fail if a cycle of computeAt is detected.

---
 torch/csrc/jit/codegen/cuda/iter_visitor.cpp | 14 +++++++++++
 .../jit/codegen/cuda/lower_validation.cpp    | 23 +++++++++++++------
 .../csrc/jit/codegen/cuda/lower_validation.h |  4 ----
 3 files changed, 30 insertions(+), 11 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp
index 5f5f9182d764..6207aa5f8e5b 100644
--- a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp
+++ b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp
@@ -143,6 +143,20 @@ void IterVisitor::traverse_(
 
   if (from_outputs_only) {
     auto term_val_outs = fusion->getTerminatingOutputs();
+    // Reorder outputs such that tensors that are computed at other
+    // tensors are visited earlier than them.
+    auto swap_pos = term_val_outs.begin();
+    for (auto it = term_val_outs.begin(); it != term_val_outs.end(); ++it) {
+      Val* val = *it;
+      if (val->getValType() == ValType::TensorView) {
+        auto tv = val->as<TensorView>();
+        if (tv->hasComputeAt()) {
+          std::swap(*swap_pos, *it);
+          ++swap_pos;
+          continue;
+        }
+      }
+    }
     if (!term_val_outs.empty())
       traverseFrom(
           fusion, term_val_outs, traverse_all_paths, respect_compute_at);
diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp
index 31a2fc27a95e..4219823bca36 100644
--- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp
@@ -2,13 +2,16 @@
 #include
 #include
 #include
+#include
 
 namespace torch {
 namespace jit {
 namespace fuser {
 
+namespace {
+
 // Some pre-compilation checks
-static void IrValidate(Fusion* fusion) {
+void IrValidate(Fusion* fusion) {
   fusion->validateInputs();
   for (Val* val : fusion->vals()) {
     if (ir_utils::isTV(val)) {
@@ -28,8 +31,7 @@ static void IrValidate(Fusion* fusion) {
   }
 }
 
-// Remove circular computeAt references
-void IrFixComputeAt(Fusion* fusion) {
+void IrValidateNoBackwardComputeAt(Fusion* fusion) {
   std::vector<Expr*> exprs = fusion->exprs(true);
   std::set<TensorView*> visited;
   for (auto it = exprs.rbegin(); it != exprs.rend(); it++) {
@@ -40,14 +42,21 @@ void IrFixComputeAt(Fusion* fusion) {
     TensorView* tv = ir_utils::asTV(expr->output(0));
     TensorView* ctv = tv->getComputeAtView();
 
-    if (ctv != nullptr && visited.find(ctv) == visited.end()) {
-      ctv->setComputeAt(tv, (int)tv->getThisComputeAtAxis());
-      tv->clearComputeAt();
+    if (ctv != nullptr) {
+      TORCH_INTERNAL_ASSERT(
+          visited.find(ctv) != visited.end(),
+          "Inconsistent computeAt detected. ",
+          tv,
+          " is computed at ",
+          ctv,
+          ", which is not yet visited.");
     }
     visited.emplace(tv);
   }
 }
 
+} // namespace
+
 void IrBuildSizesMap(Fusion* fusion) {
   // Sizes of inputs/outputs -> T.size[...]
   std::unordered_map<Val*, Val*> size_map;
@@ -119,7 +128,7 @@ void IrAdjustMemoryTypes(Fusion* fusion) {
 
 void PrepareForLowering(Fusion* fusion) {
   FusionGuard fg(fusion);
-  IrFixComputeAt(fusion);
+  IrValidateNoBackwardComputeAt(fusion);
   IrValidate(fusion);
   IrBuildSizesMap(fusion);
   IrAdjustMemoryTypes(fusion);
diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.h b/torch/csrc/jit/codegen/cuda/lower_validation.h
index 6990012a51cb..b7a9df98c18d 100644
--- a/torch/csrc/jit/codegen/cuda/lower_validation.h
+++ b/torch/csrc/jit/codegen/cuda/lower_validation.h
@@ -25,10 +25,6 @@ namespace fuser {
 
 void TORCH_CUDA_API PrepareForLowering(Fusion* fusion);
 
-// Compute at can have some circular references. Before we can call any tv
-// with tv->getComputeAtAxis(i) we need to break those circular dependencies.
-void IrFixComputeAt(Fusion* fusion);
-
 // TensorViews are all based on symbolic sizes. When we first initialize them we
 // don't know if they're inputs or outputs which would mean that they have
 // runtime shapes. Intermediate tensors (those not going to global memory) do

From b5345cbe5ac748b827279d9285fc732c797a08e8 Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Mon, 20 Jul 2020 16:47:35 -0700
Subject: [PATCH 5/9] Revert "Do not try to fix but just fail if a cycle of computeAt is detected."

This reverts commit 69b0b22158e831b33378a9938cec9f10b0879c34.

Reordering output tensors isn't trivial. It isn't necessary either, so
this validation doesn't buy us much.
---
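Illustrative note, not part of the original commit: a minimal sketch of the
fusion DAG used by testGPU_FusionComputeAtNonterminatingOutput and of what
Fusion::getTerminatingOutputs() (patch 2) is expected to return for it, which
is the mechanism patch 3 relies on instead of reordering. Only helpers that
already appear in the patches (makeDummyTensor, add, Float) are assumed; the
function name sketchTerminatingOutputs and the size check are illustrative
assumptions, not taken from the series.

// Sketch: tv3 is registered as an output but also feeds tv4, so it is not a
// terminating output. Only tv2 and tv4 are terminating, and patch 3 limits
// the computeAt adjustment in ComputeAt::setupOutputs() to such outputs.
void sketchTerminatingOutputs() {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* tv0 = makeDummyTensor(1);
  fusion.addInput(tv0);

  auto tv1 = add(tv0, new Float(1)); // common intermediate
  auto tv2 = add(tv1, new Float(2)); // tv1 -> tv2
  auto tv3 = add(tv1, new Float(3)); // tv1 -> tv3
  auto tv4 = add(tv3, new Float(4)); // tv3 -> tv4

  // Same problematic registration order as the test: tv4 before tv3.
  fusion.addOutput(tv2);
  fusion.addOutput(tv4);
  fusion.addOutput(tv3);

  // Expected: two terminating outputs, tv2 and tv4; tv3 is excluded because
  // it is consumed by tv4.
  TORCH_CHECK(fusion.getTerminatingOutputs().size() == 2);
}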
 torch/csrc/jit/codegen/cuda/iter_visitor.cpp | 14 ----------
 .../jit/codegen/cuda/lower_validation.cpp    | 23 ++++++-------------
 .../csrc/jit/codegen/cuda/lower_validation.h |  4 ++++
 3 files changed, 11 insertions(+), 30 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp
index 6207aa5f8e5b..5f5f9182d764 100644
--- a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp
+++ b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp
@@ -143,20 +143,6 @@ void IterVisitor::traverse_(
 
   if (from_outputs_only) {
     auto term_val_outs = fusion->getTerminatingOutputs();
-    // Reorder outputs such that tensors that are computed at other
-    // tensors are visited earlier than them.
-    auto swap_pos = term_val_outs.begin();
-    for (auto it = term_val_outs.begin(); it != term_val_outs.end(); ++it) {
-      Val* val = *it;
-      if (val->getValType() == ValType::TensorView) {
-        auto tv = val->as<TensorView>();
-        if (tv->hasComputeAt()) {
-          std::swap(*swap_pos, *it);
-          ++swap_pos;
-          continue;
-        }
-      }
-    }
     if (!term_val_outs.empty())
       traverseFrom(
           fusion, term_val_outs, traverse_all_paths, respect_compute_at);
diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp
index 4219823bca36..31a2fc27a95e 100644
--- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp
@@ -2,16 +2,13 @@
 #include
 #include
 #include
-#include
 
 namespace torch {
 namespace jit {
 namespace fuser {
 
-namespace {
-
 // Some pre-compilation checks
-void IrValidate(Fusion* fusion) {
+static void IrValidate(Fusion* fusion) {
   fusion->validateInputs();
   for (Val* val : fusion->vals()) {
     if (ir_utils::isTV(val)) {
@@ -31,7 +28,8 @@ void IrValidate(Fusion* fusion) {
   }
 }
 
-void IrValidateNoBackwardComputeAt(Fusion* fusion) {
+// Remove circular computeAt references
+void IrFixComputeAt(Fusion* fusion) {
   std::vector<Expr*> exprs = fusion->exprs(true);
   std::set<TensorView*> visited;
   for (auto it = exprs.rbegin(); it != exprs.rend(); it++) {
@@ -42,21 +40,14 @@ void IrValidateNoBackwardComputeAt(Fusion* fusion) {
     TensorView* tv = ir_utils::asTV(expr->output(0));
     TensorView* ctv = tv->getComputeAtView();
 
-    if (ctv != nullptr) {
-      TORCH_INTERNAL_ASSERT(
-          visited.find(ctv) != visited.end(),
-          "Inconsistent computeAt detected. ",
-          tv,
-          " is computed at ",
-          ctv,
-          ", which is not yet visited.");
+    if (ctv != nullptr && visited.find(ctv) == visited.end()) {
+      ctv->setComputeAt(tv, (int)tv->getThisComputeAtAxis());
+      tv->clearComputeAt();
     }
     visited.emplace(tv);
   }
 }
 
-} // namespace
-
 void IrBuildSizesMap(Fusion* fusion) {
   // Sizes of inputs/outputs -> T.size[...]
   std::unordered_map<Val*, Val*> size_map;
@@ -128,7 +119,7 @@ void IrAdjustMemoryTypes(Fusion* fusion) {
 
 void PrepareForLowering(Fusion* fusion) {
   FusionGuard fg(fusion);
-  IrValidateNoBackwardComputeAt(fusion);
+  IrFixComputeAt(fusion);
   IrValidate(fusion);
   IrBuildSizesMap(fusion);
   IrAdjustMemoryTypes(fusion);
diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.h b/torch/csrc/jit/codegen/cuda/lower_validation.h
index b7a9df98c18d..6990012a51cb 100644
--- a/torch/csrc/jit/codegen/cuda/lower_validation.h
+++ b/torch/csrc/jit/codegen/cuda/lower_validation.h
@@ -25,6 +25,10 @@ namespace fuser {
 
 void TORCH_CUDA_API PrepareForLowering(Fusion* fusion);
 
+// Compute at can have some circular references. Before we can call any tv
+// with tv->getComputeAtAxis(i) we need to break those circular dependencies.
+void IrFixComputeAt(Fusion* fusion);
+
 // TensorViews are all based on symbolic sizes. When we first initialize them we
 // don't know if they're inputs or outputs which would mean that they have
 // runtime shapes. Intermediate tensors (those not going to global memory) do

From d3d8640de6802778d95824008a7df91915ecbc1a Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Mon, 20 Jul 2020 16:54:20 -0700
Subject: [PATCH 6/9] Remove IrFixComputeAt as it is not necessary

---
 .../jit/codegen/cuda/lower_validation.cpp    | 21 -------------------
 .../csrc/jit/codegen/cuda/lower_validation.h |  4 ----
 2 files changed, 25 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.cpp b/torch/csrc/jit/codegen/cuda/lower_validation.cpp
index 31a2fc27a95e..e5314670e72e 100644
--- a/torch/csrc/jit/codegen/cuda/lower_validation.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_validation.cpp
@@ -28,26 +28,6 @@ static void IrValidate(Fusion* fusion) {
   }
 }
 
-// Remove circular computeAt references
-void IrFixComputeAt(Fusion* fusion) {
-  std::vector<Expr*> exprs = fusion->exprs(true);
-  std::set<TensorView*> visited;
-  for (auto it = exprs.rbegin(); it != exprs.rend(); it++) {
-    Expr* expr = *it;
-    if (!ir_utils::isTVOp(expr))
-      continue;
-
-    TensorView* tv = ir_utils::asTV(expr->output(0));
-    TensorView* ctv = tv->getComputeAtView();
-
-    if (ctv != nullptr && visited.find(ctv) == visited.end()) {
-      ctv->setComputeAt(tv, (int)tv->getThisComputeAtAxis());
-      tv->clearComputeAt();
-    }
-    visited.emplace(tv);
-  }
-}
-
 void IrBuildSizesMap(Fusion* fusion) {
   // Sizes of inputs/outputs -> T.size[...]
   std::unordered_map<Val*, Val*> size_map;
@@ -119,7 +99,6 @@ void IrAdjustMemoryTypes(Fusion* fusion) {
 
 void PrepareForLowering(Fusion* fusion) {
   FusionGuard fg(fusion);
-  IrFixComputeAt(fusion);
   IrValidate(fusion);
   IrBuildSizesMap(fusion);
   IrAdjustMemoryTypes(fusion);
diff --git a/torch/csrc/jit/codegen/cuda/lower_validation.h b/torch/csrc/jit/codegen/cuda/lower_validation.h
index 6990012a51cb..b7a9df98c18d 100644
--- a/torch/csrc/jit/codegen/cuda/lower_validation.h
+++ b/torch/csrc/jit/codegen/cuda/lower_validation.h
@@ -25,10 +25,6 @@ namespace fuser {
 
 void TORCH_CUDA_API PrepareForLowering(Fusion* fusion);
 
-// Compute at can have some circular references. Before we can call any tv
-// with tv->getComputeAtAxis(i) we need to break those circular dependencies.
-void IrFixComputeAt(Fusion* fusion);
-
 // TensorViews are all based on symbolic sizes. When we first initialize them we
 // don't know if they're inputs or outputs which would mean that they have
 // runtime shapes. Intermediate tensors (those not going to global memory) do

From f378a5bbe1e5ea2157d8e4e83c9d38507e7a3188 Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Mon, 20 Jul 2020 17:31:17 -0700
Subject: [PATCH 7/9] Add validation to the test case

---
 test/cpp/jit/test_gpu.cpp | 56 ++++++++++++++++++++++++++-------------
 1 file changed, 38 insertions(+), 18 deletions(-)

diff --git a/test/cpp/jit/test_gpu.cpp b/test/cpp/jit/test_gpu.cpp
index 533954ff44e2..7d06aaf88e05 100644
--- a/test/cpp/jit/test_gpu.cpp
+++ b/test/cpp/jit/test_gpu.cpp
@@ -4514,39 +4514,59 @@ void testGPU_FusionComputeAtNonterminatingOutput() {
   fusion.addInput(tv0);
 
   // Common intermediate tensor
-  auto tv1 = add(tv0, new Float(0));
+  auto tv1 = add(tv0, new Float(1));
   // tv1 -> tv2
-  auto tv2 = add(tv1, new Float(0));
+  auto tv2 = add(tv1, new Float(2));
   // tv1 -> tv3 -> tv4
-  auto tv3 = add(tv1, new Float(0));
-  auto tv4 = add(tv3, new Float(0));
+  auto tv3 = add(tv1, new Float(3));
+  auto tv4 = add(tv3, new Float(4));
 
+  // NOTE: This should no longer occur as of PR #201.
   // The order of adding outputs matters. If tv3 is added before tv4,
   // it should be fine. However, if tv4 is added before tv3, there
   // will be a cycle of tv3->tv4 and tv4->tv3. tv3->tv4 is created
   // first, and then tv4->tv3 is created at the final phase of
   // computeAt (ComputeAt::setupOutputs).
-  if (true) {
-    // A cycle of tv3 <-> tv4 will be created.
-    fusion.addOutput(tv2);
-    fusion.addOutput(tv4);
-    fusion.addOutput(tv3);
-  } else {
-    // This should work fine.
-    fusion.addOutput(tv2);
-    fusion.addOutput(tv3);
-    fusion.addOutput(tv4);
-  }
+  fusion.addOutput(tv2);
+  fusion.addOutput(tv4);
+  fusion.addOutput(tv3);
 
   tv0->computeAt(tv2, -1);
 
-  fusion.printMath();
-
   TORCH_CHECK(
       !(tv3->getComputeAtView() == tv4 && tv4->getComputeAtView() == tv3),
       "ComputeAt cycle detected between tv3 and tv4");
 
-  fusion.printKernel();
+  const auto options =
+      at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  at::Tensor input = at::rand(100, options);
+
+  torch::jit::fuser::cuda::FusionExecutor fe;
+  fe.compileFusion(&fusion);
+  auto outputs = fe.runFusion({input});
+
+  auto& output_tv2 = outputs[0];
+  auto& output_tv4 = outputs[1];
+  auto& output_tv3 = outputs[2];
+
+  auto aten_t1 = input + 1;
+  auto aten_t2 = aten_t1 + 2;
+  auto aten_t3 = aten_t1 + 3;
+  auto aten_t4 = aten_t3 + 4;
+
+  TORCH_CHECK(
+      aten_t2.allclose(output_tv2),
+      "Error of: ",
+      aten_t2.sub(output_tv2).abs().max());
+  TORCH_CHECK(
+      aten_t3.allclose(output_tv3),
+      "Error of: ",
+      aten_t3.sub(output_tv3).abs().max());
+  TORCH_CHECK(
+      aten_t4.allclose(output_tv4),
+      "Error of: ",
+      aten_t4.sub(output_tv4).abs().max());
+
   return;
 }

From 71f6b1ca13a51996caecfe229fb8763158d4b08f Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Mon, 20 Jul 2020 17:48:40 -0700
Subject: [PATCH 8/9] Remove Exprs in favor of ExprSort

---
 torch/csrc/jit/codegen/cuda/iter_visitor.cpp | 23 +-------------------
 1 file changed, 1 insertion(+), 22 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp
index 5f5f9182d764..574caf6d51cf 100644
--- a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp
+++ b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp
@@ -178,27 +178,6 @@ void IterVisitor::traverseAllPaths(
 
 namespace {
 
-// TODO: Remove this in favor of ExprSort
-// Expr sort will take a fusion and return a topologically sorted list of
-// expressions.
-class Exprs : public IterVisitor {
- private:
-  std::vector<Expr*> exprs;
-
-  void handle(Expr* expr) override {
-    exprs.push_back(expr);
-  }
-
- public:
-  static std::vector<Expr*> getExprs(
-      Fusion* fusion,
-      const std::vector<Val*>& from) {
-    Exprs ex;
-    ex.traverseFrom(fusion, from, false);
-    return ex.exprs;
-  }
-};
-
 // Expr sort will take a fusion and return a topologically sorted list of
 // expressions.
 class Inputs : public IterVisitor {
@@ -300,7 +279,7 @@ void BackwardVisitor::traverseFrom(
 
   auto vals = AllVals::get(fusion, from);
 
-  auto exprs = Exprs::getExprs(fusion, from);
+  auto exprs = ExprSort::getExprs(fusion, from);
 
   {
     size_t pos = 0;

From fc7a832801ceed14da32be8ebbe2c5c3c4db574f Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Tue, 21 Jul 2020 09:57:47 -0700
Subject: [PATCH 9/9] Simplify getTerminatingOutputs

---
 torch/csrc/jit/codegen/cuda/fusion.cpp | 17 +++--------------
 1 file changed, 3 insertions(+), 14 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/fusion.cpp b/torch/csrc/jit/codegen/cuda/fusion.cpp
index 5431a2fe166e..06170256c87a 100644
--- a/torch/csrc/jit/codegen/cuda/fusion.cpp
+++ b/torch/csrc/jit/codegen/cuda/fusion.cpp
@@ -562,24 +562,13 @@ std::vector<Val*> Fusion::getTerminatingOutputs() {
       used_vals.emplace(inp);
   }
 
-  std::unordered_set<Val*> terminating_outputs;
+  std::vector<Val*> terminating_outputs;
   for (auto out : outputs()) {
     if (used_vals.find(out) != used_vals.end())
      continue;
-    terminating_outputs.emplace(out);
+    terminating_outputs.push_back(out);
  }
-
-  std::vector<Val*> sorted_outputs{terminating_outputs.begin(),
-                                   terminating_outputs.end()};
-
-  // Sort the outputs in order to give a deterministic traversal
-  // order.
-  std::sort(
-      sorted_outputs.begin(),
-      sorted_outputs.end(),
-      [](const Val* v0, const Val* v1) { return v0->name() < v1->name(); });
-
-  return sorted_outputs;
+  return terminating_outputs;
 }
 
 } // namespace fuser