From 9ba01764e6b6686f0caae9b9d127bf3bc8cd5ae6 Mon Sep 17 00:00:00 2001 From: Lemo Date: Tue, 30 Jun 2020 10:20:29 -0700 Subject: [PATCH 1/5] Stylistic changes --- torch/csrc/jit/codegen/cuda/iter_visitor.cpp | 7 +++---- torch/csrc/jit/codegen/cuda/iter_visitor.h | 12 +++--------- torch/csrc/jit/codegen/cuda/transform_iter.h | 4 ++-- 3 files changed, 8 insertions(+), 15 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp index 6d0dfcc9e908..2430f492e242 100644 --- a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp +++ b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp @@ -64,8 +64,7 @@ std::vector IterVisitor::next(Expr* expr, bool respect_compute_at) { } // Remove any stmt in stmts that is in visited -namespace { -void remove_visited( +static void remove_visited( std::vector& stmts, const std::unordered_set& visited) { std::deque::iterator> to_erase; @@ -79,7 +78,6 @@ void remove_visited( to_erase.pop_back(); } } -} // namespace void IterVisitor::traverseFrom( Fusion* const fusion, @@ -104,7 +102,7 @@ void IterVisitor::traverseFrom( all_inputs_visited = true; continue; } - auto& stmt = current_inputs.back(); + const auto stmt = current_inputs.back(); // Visit stmt when all_inputs_visited is true. if (all_inputs_visited) { // Mark visited @@ -217,6 +215,7 @@ struct Inputs : public IterVisitor { return inps.inputs; } }; + } // namespace std::unordered_set IterVisitor::getTerminatingOutputs( diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.h b/torch/csrc/jit/codegen/cuda/iter_visitor.h index 832f2a41507c..2ed699fa6814 100644 --- a/torch/csrc/jit/codegen/cuda/iter_visitor.h +++ b/torch/csrc/jit/codegen/cuda/iter_visitor.h @@ -56,17 +56,17 @@ struct TORCH_CUDA_API IterVisitor : public OptOutDispatch { // This handle functions is called on every Statement* in topological order, // starting from outputs to inputs. - virtual void handle(Statement* s) override { + void handle(Statement* s) override { OptOutDispatch::handle(s); } // This handle functions is called on every Expr* in topological order, // starting from outputs to inputs. - virtual void handle(Expr* e) override { + void handle(Expr* e) override { OptOutDispatch::handle(e); } // This handle functions is called on every Val* in topological order, // starting from outputs to inputs. - virtual void handle(Val* v) override { + void handle(Val* v) override { OptOutDispatch::handle(v); } @@ -96,12 +96,6 @@ struct TORCH_CUDA_API IterVisitor : public OptOutDispatch { bool traverseAllPaths = false, bool respectComputeAt = false); - void traverseFrom2( - Fusion* const fusion, - const std::vector& from, - bool traverseAllPaths = false, - bool respectComputeAt = false); - // from_outputs_only = true start from outputs registered with fusion, // from_outputs_only = false start from all leaf nodes, // bool breadth_first = true is not implemented yet diff --git a/torch/csrc/jit/codegen/cuda/transform_iter.h b/torch/csrc/jit/codegen/cuda/transform_iter.h index 925179babfa3..9b33f96ef764 100644 --- a/torch/csrc/jit/codegen/cuda/transform_iter.h +++ b/torch/csrc/jit/codegen/cuda/transform_iter.h @@ -73,10 +73,10 @@ struct TORCH_CUDA_API ReplayTransformations : public IterVisitor { // TODO: HANDLE RFACTOR DOMAINS // We're going to replay this split operation on the corresponding ID - virtual void handle(Split* s) override; + void handle(Split* s) override; // We're going to replay this merge operation on the corresponding IDs - virtual void handle(Merge* m) override; + void handle(Merge* m) override; public: ReplayTransformations( From 6ed271d80f0e8b81eebe92709894fd2f3f671683 Mon Sep 17 00:00:00 2001 From: Lemo Date: Tue, 30 Jun 2020 10:20:55 -0700 Subject: [PATCH 2/5] Constness workaround for the ExpressionEvaluator interface --- torch/csrc/jit/codegen/cuda/expr_evaluator.cpp | 7 +++++-- torch/csrc/jit/codegen/cuda/expr_evaluator.h | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp index 2a333fcbf331..724586594484 100644 --- a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp +++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp @@ -38,11 +38,14 @@ void EvaluationContext::print() const { } c10::optional ExpressionEvaluator::evaluate( - Val* val, + const Val* val, const EvaluationContext* context) { TORCH_CHECK(context != nullptr); ExpressionEvaluator evaluator(context); - evaluator.traverseFrom(context->fusion(), {val}, false); + // IterVisitor only supports Val*, so we need to strip away const from + // our target val. This is ugly, but it's ok since we know that all + // the original Fusion nodes are allocated from the heap as non-const + evaluator.traverseFrom(context->fusion(), {const_cast(val)}, false); return evaluator.value(val); } diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.h b/torch/csrc/jit/codegen/cuda/expr_evaluator.h index 05f5b63a587f..4dfdd93c5352 100644 --- a/torch/csrc/jit/codegen/cuda/expr_evaluator.h +++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.h @@ -48,7 +48,7 @@ class TORCH_CUDA_API ExpressionEvaluator : private IterVisitor { // Returns the result of the specified expression, or nullopt if // the result cannot be evaluated static c10::optional evaluate( - Val* val, + const Val* val, const EvaluationContext* context); private: From 2d4b9546d1362bf9a7ff159d2c1308fa7c6d93af Mon Sep 17 00:00:00 2001 From: Lemo Date: Tue, 30 Jun 2020 10:34:33 -0700 Subject: [PATCH 3/5] Formatting fix --- torch/csrc/jit/codegen/cuda/expr_evaluator.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp index 724586594484..7ec553ff95d5 100644 --- a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp +++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp @@ -43,7 +43,7 @@ c10::optional ExpressionEvaluator::evaluate( TORCH_CHECK(context != nullptr); ExpressionEvaluator evaluator(context); // IterVisitor only supports Val*, so we need to strip away const from - // our target val. This is ugly, but it's ok since we know that all + // our target val. This is ugly, but it's ok since we know that all // the original Fusion nodes are allocated from the heap as non-const evaluator.traverseFrom(context->fusion(), {const_cast(val)}, false); return evaluator.value(val); From 98d859f0b07201b2f58a35fc4011f8e9a907b4d8 Mon Sep 17 00:00:00 2001 From: Lemo Date: Tue, 30 Jun 2020 11:46:49 -0700 Subject: [PATCH 4/5] Revert "Constness workaround for the ExpressionEvaluator interface" This reverts commit 6ed271d80f0e8b81eebe92709894fd2f3f671683. --- torch/csrc/jit/codegen/cuda/expr_evaluator.cpp | 7 ++----- torch/csrc/jit/codegen/cuda/expr_evaluator.h | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp index 7ec553ff95d5..2a333fcbf331 100644 --- a/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp +++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.cpp @@ -38,14 +38,11 @@ void EvaluationContext::print() const { } c10::optional ExpressionEvaluator::evaluate( - const Val* val, + Val* val, const EvaluationContext* context) { TORCH_CHECK(context != nullptr); ExpressionEvaluator evaluator(context); - // IterVisitor only supports Val*, so we need to strip away const from - // our target val. This is ugly, but it's ok since we know that all - // the original Fusion nodes are allocated from the heap as non-const - evaluator.traverseFrom(context->fusion(), {const_cast(val)}, false); + evaluator.traverseFrom(context->fusion(), {val}, false); return evaluator.value(val); } diff --git a/torch/csrc/jit/codegen/cuda/expr_evaluator.h b/torch/csrc/jit/codegen/cuda/expr_evaluator.h index 4dfdd93c5352..05f5b63a587f 100644 --- a/torch/csrc/jit/codegen/cuda/expr_evaluator.h +++ b/torch/csrc/jit/codegen/cuda/expr_evaluator.h @@ -48,7 +48,7 @@ class TORCH_CUDA_API ExpressionEvaluator : private IterVisitor { // Returns the result of the specified expression, or nullopt if // the result cannot be evaluated static c10::optional evaluate( - const Val* val, + Val* val, const EvaluationContext* context); private: From 42dcd3391bf95418e1e7ecdbc1890bf22302358e Mon Sep 17 00:00:00 2001 From: Lemo Date: Tue, 30 Jun 2020 13:22:05 -0700 Subject: [PATCH 5/5] Incorporating feedback --- torch/csrc/jit/codegen/cuda/iter_visitor.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp index 2430f492e242..6a0e6efcaa3d 100644 --- a/torch/csrc/jit/codegen/cuda/iter_visitor.cpp +++ b/torch/csrc/jit/codegen/cuda/iter_visitor.cpp @@ -63,8 +63,10 @@ std::vector IterVisitor::next(Expr* expr, bool respect_compute_at) { return next_stmts; } +namespace { + // Remove any stmt in stmts that is in visited -static void remove_visited( +void remove_visited( std::vector& stmts, const std::unordered_set& visited) { std::deque::iterator> to_erase; @@ -79,6 +81,8 @@ static void remove_visited( } } +} // namespace + void IterVisitor::traverseFrom( Fusion* const fusion, const std::vector& from, @@ -102,7 +106,7 @@ void IterVisitor::traverseFrom( all_inputs_visited = true; continue; } - const auto stmt = current_inputs.back(); + const auto& stmt = current_inputs.back(); // Visit stmt when all_inputs_visited is true. if (all_inputs_visited) { // Mark visited