Reviewer note: line structure and the stripped <torch::jit::Graph> template arguments were restored, and
explanatory comments were added on the new side (hunk counts adjusted to match). Indentation of the context
lines is reconstructed, so apply with whitespace tolerance (e.g. `git apply --ignore-whitespace`). The
`index` hashes are those of the original patch.

diff --git a/core/lowering/passes/remove_unnecessary_casts.cpp b/core/lowering/passes/remove_unnecessary_casts.cpp
index d7c9c77d71..3386608f0d 100644
--- a/core/lowering/passes/remove_unnecessary_casts.cpp
+++ b/core/lowering/passes/remove_unnecessary_casts.cpp
@@ -131,6 +131,17 @@ void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
                               user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
                               user->destroy();
                               break;
+                            case c10::aten::floor_divide:
+                              // Scalar overload has a different name (aten::floordiv), so swap the node kind.
+                              new_node = g->create(c10::aten::floordiv, user->inputs(), 1);
+                              new_node->insertAfter(user);
+                              // NOTE(review): output is typed int even for float operands (see the
+                              // FloorDivFloat tests) - confirm this is intended.
+                              new_node->outputs()[0]->setType(c10::IntType::get());
+                              // Reroute every consumer to the scalar node before dropping the tensor op.
+                              user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
+                              user->destroy();
+                              break;
                             default:
                               new_node = g->create(user->kind(), user->inputs(), 1);
                               new_node->insertAfter(user);
diff --git a/tests/core/lowering/test_remove_unnecessary_casts.cpp b/tests/core/lowering/test_remove_unnecessary_casts.cpp
index f1a8f9ff4f..dc4c397148 100644
--- a/tests/core/lowering/test_remove_unnecessary_casts.cpp
+++ b/tests/core/lowering/test_remove_unnecessary_casts.cpp
@@ -153,3 +153,146 @@ TEST(LoweringPasses, RemoveSingleUse0DTensorsFloatCorrectly) {
 
   ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
 }
+
+// A tensor constant floor-divided by a NumToTensor'd int and cast back with aten::Int
+// should lower to a single scalar aten::floordiv.
+TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivIntCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%0: int):
+      %1: Tensor = prim::Constant[value=[7]]()
+      %3: Tensor = prim::NumToTensor(%0)
+      %4: Tensor = aten::floor_divide(%1, %3)
+      %5: int = aten::Int(%4)
+      return (%5))IR";
+  std::string target_graph = R"IR(
+    graph(%0: int):
+      %1: int = prim::Constant[value=7]()
+      %4: int = aten::floordiv(%1, %0)
+      return (%4))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, sg.get());
+
+  // Swap the parsed constant for one built with c10::scalar_to_tensor (presumably so the pass sees a 0D tensor).
+  auto first_op = *(sg->block()->nodes().begin());
+  torch::jit::WithInsertPoint guard(first_op);
+  torch::jit::Value* r = sg->insertConstant(c10::scalar_to_tensor(7), c10::nullopt, first_op->scope());
+  r->copyMetadata(first_op->output());
+  r->setType(c10::TensorType::get());
+  first_op->output()->replaceAllUsesWith(r);
+  first_op->destroy();
+
+  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
+
+// A tensor constant floor-divided by a NumToTensor'd float and cast back with aten::Float
+// should lower to a single scalar aten::floordiv.
+TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivFloatCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%0: float):
+      %1: Tensor = prim::Constant[value=[8.]]()
+      %3: Tensor = prim::NumToTensor(%0)
+      %4: Tensor = aten::floor_divide(%1, %3)
+      %5: float = aten::Float(%4)
+      return (%5))IR";
+  std::string target_graph = R"IR(
+    graph(%0: float):
+      %1: float = prim::Constant[value=8.]()
+      %4: float = aten::floordiv(%1, %0)
+      return (%4))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, sg.get());
+
+  // Swap the parsed constant for one built with c10::scalar_to_tensor (presumably so the pass sees a 0D tensor).
+  auto first_op = *(sg->block()->nodes().begin());
+  torch::jit::WithInsertPoint guard(first_op);
+  torch::jit::Value* r = sg->insertConstant(c10::scalar_to_tensor(8.0), c10::nullopt, first_op->scope());
+  r->copyMetadata(first_op->output());
+  r->setType(c10::TensorType::get());
+  first_op->output()->replaceAllUsesWith(r);
+  first_op->destroy();
+
+  torch_tensorrt::core::lowering::passes::RemoveSingleUse0DTensors(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, tg.get());
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
+
+// Evaluates the unlowered and hand-lowered int graphs in the JIT and requires identical results;
+// the lowering pass itself is not invoked here.
+TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivIntValuesAgree) {
+  std::string source_graph_no_inputs = R"IR(
+    graph():
+      %0: int = prim::Constant[value=2]()
+      %11: int = prim::Constant[value=7]()
+      %3: Tensor = prim::NumToTensor(%0)
+      %1: Tensor = prim::NumToTensor(%11)
+      %4: Tensor = aten::floor_divide(%1, %3)
+      %50: int = aten::Int(%4)
+      %5: Tensor = prim::NumToTensor(%50)
+      return (%5))IR";
+  std::string target_graph_no_inputs = R"IR(
+    graph():
+      %0: int = prim::Constant[value=2]()
+      %1: int = prim::Constant[value=7]()
+      %40: int = aten::floordiv(%1, %0)
+      %4: Tensor = prim::NumToTensor(%40)
+      return (%4))IR";
+
+  auto g_in = std::make_shared<torch::jit::Graph>();
+  auto g_out = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(source_graph_no_inputs, g_in.get());
+  torch::jit::parseIR(target_graph_no_inputs, g_out.get());
+
+  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_in, {});
+  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_out, {});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor()));
+}
+
+// Evaluates the unlowered and hand-lowered float graphs in the JIT and compares the results
+// with almostEqual (tolerance argument 2e-6); the lowering pass itself is not invoked here.
+TEST(LoweringPasses, RemoveSingleUse0DTensorsFloorDivFloatValuesAgree) {
+  std::string source_graph_no_inputs = R"IR(
+    graph():
+      %0: float = prim::Constant[value=2.]()
+      %11: float = prim::Constant[value=7.]()
+      %3: Tensor = prim::NumToTensor(%0)
+      %1: Tensor = prim::NumToTensor(%11)
+      %4: Tensor = aten::floor_divide(%1, %3)
+      %50: float = aten::Float(%4)
+      %5: Tensor = prim::NumToTensor(%50)
+      return (%5))IR";
+  std::string target_graph_no_inputs = R"IR(
+    graph():
+      %0: float = prim::Constant[value=2.]()
+      %1: float = prim::Constant[value=7.]()
+      %40: float = aten::floordiv(%1, %0)
+      %4: Tensor = prim::NumToTensor(%40)
+      return (%4))IR";
+
+  auto g_in = std::make_shared<torch::jit::Graph>();
+  auto g_out = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(source_graph_no_inputs, g_in.get());
+  torch::jit::parseIR(target_graph_no_inputs, g_out.get());
+
+  auto jit_pre_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_in, {});
+  auto jit_post_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g_out, {});
+
+  ASSERT_TRUE(
+      torch_tensorrt::tests::util::almostEqual(jit_pre_results[0].toTensor(), jit_post_results[0].toTensor(), 2e-6));
+}