Skip to content

Commit de9cbaf

Browse files
committed
fix: Bug in aten::where with differing-shape inputs
- Behavior of Torch-TRT differing from that of Torch in the case where the input tensors to `aten::where` have different rank - Torch automatically broadcasts tensors to the highest-rank variant whereas the TRT Select layer requires tensors of the same rank and throws an error - Add dimension checking and unsqueeze operator to ensure broadcasting is enabled - Add test case to catch error
1 parent 063be0d commit de9cbaf

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

core/conversion/converters/impl/select.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,8 +736,41 @@ auto select_registrations TORCHTRT_UNUSED =
736736
{"aten::where.self(Tensor condition, Tensor self, Tensor other) -> (Tensor)",
737737
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
738738
auto condition = args[0].ITensorOrFreeze(ctx);
739+
auto c_nbDims = condition->getDimensions().nbDims;
739740
auto x = args[1].ITensorOrFreeze(ctx);
741+
auto x_nbDims = x->getDimensions().nbDims;
740742
auto y = args[2].ITensorOrFreeze(ctx);
743+
auto y_nbDims = y->getDimensions().nbDims;
744+
745+
// Get maximum rank of all input tensors
746+
auto max_nbDims = std::max(c_nbDims, std::max(x_nbDims, y_nbDims));
747+
748+
// TensorRT requires all inputs to Select layers to have the same rank, so for each
749+
// tensor input, ensure that its rank is equal to the maximum number of dimensions
750+
// If not, left-pad the tensor dimension with 1s until the max rank is achieved
751+
auto add_reshape = [&ctx, &max_nbDims](nvinfer1::ITensor*& tensor) {
752+
nvinfer1::Dims dimensions = tensor->getDimensions();
753+
754+
// If the rank of this tensor is smaller than the max rank, use reshape
755+
if (dimensions.nbDims < max_nbDims) {
756+
auto shuffle_layer = ctx->net->addShuffle(*tensor);
757+
758+
// For each dimension from the rank of the smaller tensor to the max rank,
759+
// unsqueeze dimensions by 1
760+
for (auto i = dimensions.nbDims; i < max_nbDims; i++) {
761+
dimensions = util::unsqueezeDims(dimensions, 0, 1, false);
762+
}
763+
764+
// Reshape to the unsqueezed dimensions
765+
shuffle_layer->setReshapeDimensions(dimensions);
766+
tensor = shuffle_layer->getOutput(0);
767+
}
768+
};
769+
770+
// Apply reshape to each tensor input
771+
add_reshape(condition);
772+
add_reshape(x);
773+
add_reshape(y);
741774

742775
auto layer = ctx->net->addSelect(*condition, *x, *y);
743776

tests/core/conversion/converters/test_select.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,3 +1224,35 @@ TEST(Converters, WhereConvertsCorrectly) {
12241224

12251225
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
12261226
}
1227+
1228+
TEST(Converters, WhereConvertsMismatchedShapesCorrectly) {
  // Regression test: aten::where must broadcast inputs of differing rank,
  // matching Torch semantics, rather than failing in the TRT Select layer.
  const auto graph = R"IR(
    graph(%condition : Tensor,
          %x : Tensor,
          %y : Tensor):
      %out : Tensor = aten::where(%condition, %x, %y)
      return (%out))IR";

  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  // Inputs deliberately use three different ranks; Torch broadcasts each
  // against the largest-rank tensor ({2, 7, 5}), and TRT must match.
  auto cond_in = at::randint(0, 2, {7, 5}, {at::kCUDA}).to(torch::kBool);
  auto x_in = at::randn({2, 7, 5}, {at::kCUDA});
  auto y_in = at::randn({5}, {at::kCUDA});

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});

  // Reference result from TorchScript execution.
  auto jit_results = torch_tensorrt::tests::util::RunGraph(
      g, params, {at::clone(cond_in), at::clone(x_in), at::clone(y_in)});

  // Result from the converted TensorRT engine.
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(
      g, params, {at::clone(cond_in), at::clone(x_in), at::clone(y_in)});

  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

0 commit comments

Comments
 (0)