Skip to content

Commit d923805

Browse files
authored
fix: aten::where with differing-shape inputs bugfix (#1533)
1 parent 0567b34 commit d923805

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,8 +736,22 @@ auto select_registrations TORCHTRT_UNUSED =
736736
{"aten::where.self(Tensor condition, Tensor self, Tensor other) -> (Tensor)",
737737
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
738738
auto condition = args[0].ITensorOrFreeze(ctx);
739+
auto condition_nbDims = condition->getDimensions().nbDims;
739740
auto x = args[1].ITensorOrFreeze(ctx);
741+
auto x_nbDims = x->getDimensions().nbDims;
740742
auto y = args[2].ITensorOrFreeze(ctx);
743+
auto y_nbDims = y->getDimensions().nbDims;
744+
745+
// Get maximum rank of all input tensors
746+
auto max_nbDims = std::max(condition_nbDims, std::max(x_nbDims, y_nbDims));
747+
748+
// TensorRT requires all inputs to Select layers to have the same rank, so for each
749+
// tensor input, ensure that its rank is equal to the maximum number of dimensions
750+
// If not, left-pad the tensor dimension with 1s until the max rank is achieved
751+
condition =
752+
addPadding(ctx, n, condition, max_nbDims, /*bool trailing =*/false, /*bool use_zeros =*/false);
753+
x = addPadding(ctx, n, x, max_nbDims, /*bool trailing =*/false, /*bool use_zeros =*/false);
754+
y = addPadding(ctx, n, y, max_nbDims, /*bool trailing =*/false, /*bool use_zeros =*/false);
741755

742756
auto layer = ctx->net->addSelect(*condition, *x, *y);
743757

tests/core/conversion/converters/test_select.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,3 +1224,35 @@ TEST(Converters, WhereConvertsCorrectly) {
12241224

12251225
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
12261226
}
1227+
1228+
// Regression test for aten::where (Select) conversion when the three inputs
// have differing ranks (condition: 2-D, self: 3-D, other: 1-D). The converter
// must left-pad lower-rank inputs with 1s so TensorRT's ISelectLayer, which
// requires equal ranks, can broadcast them the same way Torch does.
TEST(Converters, WhereConvertsMismatchedShapesCorrectly) {
  const auto graph = R"IR(
    graph(%condition : Tensor,
          %x : Tensor,
          %y : Tensor):
      %out : Tensor = aten::where(%condition, %x, %y)
      return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, g.get());

  // As per Torch behavior, the input Tensors are expected to be broadcasted
  // along their respective dimension in the largest-rank Tensor provided
  auto condition = at::randint(0, 2, {7, 5}, {at::kCUDA}).to(torch::kBool);
  auto x = at::randn({2, 7, 5}, {at::kCUDA});
  auto y = at::randn({5}, {at::kCUDA});

  // Clone inputs so the JIT reference run and the TensorRT run each get
  // identical, independent tensors.
  auto jit_condition = at::clone(condition);
  auto jit_x = at::clone(x);
  auto jit_y = at::clone(y);
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  // Reference result: run the graph through Torch JIT.
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_condition, jit_x, jit_y});

  auto trt_condition = at::clone(condition);
  auto trt_x = at::clone(x);
  auto trt_y = at::clone(y);
  // Result under test: convert the graph and execute it as a TensorRT engine.
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_condition, trt_x, trt_y});

  // Outputs must match elementwise within a small absolute tolerance.
  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

0 commit comments

Comments
 (0)