diff --git a/core/conversion/converters/impl/unsqueeze.cpp b/core/conversion/converters/impl/unsqueeze.cpp index 16a320c8bf..1c7e3e8968 100644 --- a/core/conversion/converters/impl/unsqueeze.cpp +++ b/core/conversion/converters/impl/unsqueeze.cpp @@ -32,7 +32,7 @@ auto unsqueeze_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns(). auto shuffle_layer = ctx->net->addShuffle(*self); TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n); - shuffle_layer->setReshapeDimensions(util::unsqueezeDims(self->getDimensions(), dim)); + shuffle_layer->setReshapeDimensions(util::unsqueezeDims(self->getDimensions(), dim, 1, false)); auto out = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_layer->getOutput(0)); diff --git a/tests/core/conversion/converters/test_unsqueeze.cpp b/tests/core/conversion/converters/test_unsqueeze.cpp index 199abe58bd..88203ae5ad 100644 --- a/tests/core/conversion/converters/test_unsqueeze.cpp +++ b/tests/core/conversion/converters/test_unsqueeze.cpp @@ -47,3 +47,25 @@ TEST(Converters, ATenUnsqueezeNegativeDimConvertsCorrectly) { ASSERT_TRUE( torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); } + +TEST(Converters, ATenUnsqueezeConvertsCorrectlyWithDynamicInput) { + const auto graph = R"IR( + graph(%0 : Tensor): + %1 : int = prim::Constant[value=1]() + %2 : Tensor = aten::unsqueeze(%0, %1) + return (%2))IR"; + + auto g = std::make_shared(); + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {1, 10}, {at::kCUDA}); + + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in}); + + params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in}); + + ASSERT_TRUE( + torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6)); +}