diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp index 79f061ad5d..567735dfbd 100644 --- a/core/conversion/converters/impl/select.cpp +++ b/core/conversion/converters/impl/select.cpp @@ -721,6 +721,23 @@ auto select_registrations TORCHTRT_UNUSED = layer->setName(util::node_info(n).c_str()); + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0)); + LOG_DEBUG("Output shape: " << out_tensor->getDimensions()); + return true; + }}) + .pattern( + {"aten::where.self(Tensor condition, Tensor self, Tensor other) -> (Tensor)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto condition = args[0].ITensorOrFreeze(ctx); + auto x = args[1].ITensorOrFreeze(ctx); + auto y = args[2].ITensorOrFreeze(ctx); + + auto layer = ctx->net->addSelect(*condition, *x, *y); + + TORCHTRT_CHECK(layer, "Unable to create select layer for aten::where.self"); + + layer->setName(util::node_info(n).c_str()); + auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0)); LOG_DEBUG("Output shape: " << out_tensor->getDimensions()); return true; diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp index c04036e9ba..2b70ac3dfc 100644 --- a/tests/core/conversion/converters/test_select.cpp +++ b/tests/core/conversion/converters/test_select.cpp @@ -1138,3 +1138,33 @@ TEST(Converters, ScatterSrcConvertsCorrectly) { ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[i], trt, 2e-6)); } } + +TEST(Converters, WhereConvertsCorrectly) { + const auto graph = R"IR( + graph(%condition : Tensor, + %x : Tensor, + %y : Tensor): + %out : Tensor = aten::where(%condition, %x, %y) + return (%out))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto condition = at::randint(0, 2, {5, 5}, {at::kCUDA}).to(torch::kBool); + auto x = at::randn({5, 5}, {at::kCUDA}); + auto y = at::randn({5, 5}, {at::kCUDA}); + + auto jit_condition = at::clone(condition); + auto jit_x = at::clone(x); + auto jit_y = at::clone(y); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_condition, jit_x, jit_y}); + + auto trt_condition = at::clone(condition); + auto trt_x = at::clone(x); + auto trt_y = at::clone(y); + auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_condition, trt_x, trt_y}); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6)); +}