diff --git a/core/lowering/passes/replace_aten_pad.cpp b/core/lowering/passes/replace_aten_pad.cpp
index f99a0349c1..dd5e721b0c 100644
--- a/core/lowering/passes/replace_aten_pad.cpp
+++ b/core/lowering/passes/replace_aten_pad.cpp
@@ -99,10 +99,17 @@ void ReplaceAtenPad(std::shared_ptr<torch::jit::Graph>& graph) {
     } else if (mode_str == "constant") {
       // aten::constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> (Tensor)
       torch::jit::Node* new_node;
+      auto pad_value = it->inputs()[3];
+      auto is_pad_none = torch::jit::toIValue(it->inputs()[3])->isNone();
+      if (is_pad_none) {
+        pad_value = graph->insertConstant(0.0);
+      }
+
       new_node = graph->create(
           c10::Symbol::fromQualString("aten::constant_pad_nd"),
-          torch::jit::ArrayRef<torch::jit::Value*>({it->inputs()[0], it->inputs()[1], it->inputs()[3]}),
+          torch::jit::ArrayRef<torch::jit::Value*>({it->inputs()[0], it->inputs()[1], pad_value}),
           1);
+
       new_node->insertAfter(*it);
       new_node->outputs()[0]->setType(c10::TensorType::get());
       it->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
diff --git a/tests/core/lowering/test_replace_aten_pad_pass.cpp b/tests/core/lowering/test_replace_aten_pad_pass.cpp
index ddad5efb16..a36a3dac22 100644
--- a/tests/core/lowering/test_replace_aten_pad_pass.cpp
+++ b/tests/core/lowering/test_replace_aten_pad_pass.cpp
@@ -43,6 +43,43 @@ TEST(LoweringPasses, AtenPadConstantCorrectly) {
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
 }
 
+TEST(LoweringPasses, AtenPadConstantNoneValueCorrectly) {
+  const auto source_graph = R"IR(
+    graph(%0 : Tensor):
+      %2 : str = prim::Constant[value="constant"]()
+      %1 : int[] = prim::Constant[value=[2, 3, 2, 3]]()
+      %3 : NoneType = prim::Constant()
+      %4 : Tensor = aten::pad(%0, %1, %2, %3)
+      return (%4))IR";
+
+  const auto target_graph = R"IR(
+    graph(%0 : Tensor):
+      %1 : int[] = prim::Constant[value=[2, 3, 2, 3]]()
+      %2 : Scalar = prim::Constant[value=0.0]()
+      %3 : Tensor = aten::constant_pad_nd(%0, %1, %2)
+      return (%3))IR";
+
+  torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
+      torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  torch_tensorrt::core::lowering::passes::ReplaceAtenPad(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA});
+
+  auto trt_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(sg->inputs(), {});
+  auto trt_results_sg = torch_tensorrt::tests::util::RunGraphEngine(sg, params, {trt_in});
+
+  params = torch_tensorrt::core::ir::get_static_params(tg->inputs(), {});
+  auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
+}
+
 TEST(LoweringPasses, AtenPadReflect1dCorrectly) {
   const auto source_graph = R"IR(
     graph(%0 : Tensor):
@@ -221,4 +258,4 @@ TEST(LoweringPasses, AtenPadReplicate3dCorrectly) {
   auto trt_results_tg = torch_tensorrt::tests::util::RunGraphEngine(tg, params, {trt_in});
 
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
-}
\ No newline at end of file
+}