diff --git a/core/lowering/passes/reduce_to.cpp b/core/lowering/passes/reduce_to.cpp index 1aef153ba2..5b19e63a29 100644 --- a/core/lowering/passes/reduce_to.cpp +++ b/core/lowering/passes/reduce_to.cpp @@ -8,16 +8,6 @@ namespace lowering { namespace passes { void ReduceToOperation(std::shared_ptr& graph) { - std::string to_dtype_layout_pattern = R"IR( - graph(%x, %dtype, %layout, %device, %pm, %nb, %copy, %format): - %out : Tensor = aten::to(%x, %dtype, %layout, %device, %pm, %nb, %copy, %format) - return (%out))IR"; - - std::string to_dtype_multi_input_pattern = R"IR( - graph(%x, %dtype, %layout, %device, %pm, %nb, %copy, %format): - %out : Tensor = aten::to(%x, %device, %dtype, %nb, %copy, %format) - return (%out))IR"; - std::string to_type_as_pattern = R"IR( graph(%input, %other): %out : Tensor = aten::type_as(%input, %other) @@ -30,11 +20,6 @@ void ReduceToOperation(std::shared_ptr& graph) { %out : Tensor = aten::to(%input, %other, %5, %5, %6) return (%out))IR"; - // replace aten::to.dtype_layout with aten::to.dtype - torch::jit::SubgraphRewriter map_aten_dtype_layout; - map_aten_dtype_layout.RegisterRewritePattern(to_dtype_layout_pattern, to_dtype_multi_input_pattern); - map_aten_dtype_layout.runOnGraph(graph); - // replace aten::type_as with aten::to.other torch::jit::SubgraphRewriter map_aten_type_as_to_other; map_aten_type_as_to_other.RegisterRewritePattern(to_type_as_pattern, to_other_pattern); diff --git a/tests/core/lowering/test_reduce_to_pass.cpp b/tests/core/lowering/test_reduce_to_pass.cpp index 0ea8feaf5e..aada20190e 100644 --- a/tests/core/lowering/test_reduce_to_pass.cpp +++ b/tests/core/lowering/test_reduce_to_pass.cpp @@ -6,28 +6,6 @@ #include "torch/csrc/jit/ir/irparser.h" #include "torch/csrc/jit/ir/subgraph_matcher.h" -TEST(LoweringPasses, ReduceToDtypeLayoutCorrectly) { - std::string source_graph = R"IR( - graph(%x, %dtype, %layout, %device, %pm, %nb, %copy, %format): - %out : Tensor = aten::to(%x, %dtype, %layout, %device, %pm, %nb, %copy, %format) - return (%out))IR"; - std::string target_graph = R"IR( - graph(%x, %dtype, %layout, %device, %pm, %nb, %copy, %format): - %out : Tensor = aten::to(%x, %device, %dtype, %nb, %copy, %format) - return (%out))IR"; - - torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level( - torch_tensorrt::core::util::logging::LogLevel::kGRAPH); - auto sg = std::make_shared(); - torch::jit::parseIR(source_graph, &*sg); - torch_tensorrt::core::lowering::passes::ReduceToOperation(sg); - - auto tg = std::make_shared(); - torch::jit::parseIR(target_graph, &*tg); - - ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty()); -} - TEST(LoweringPasses, ReduceAtenTypeAsCorrectly) { std::string source_graph = R"IR( graph(%input, %other):