support aten::index.Tensor #921

Merged: 3 commits, Apr 5, 2022

91 changes: 91 additions & 0 deletions core/conversion/converters/impl/select.cpp
@@ -67,6 +67,34 @@ bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool s
return true;
}

nvinfer1::ITensor* roll(
ConversionCtx* ctx,
nvinfer1::ITensor* in,
int shift,
int dim,
const std::vector<int64_t>& in_shape) {
auto in_dim = in_shape[dim];

auto start = (in_dim - shift) % in_dim;
// The behavior of % differs between C++ and Python for negative numbers; this corrects the
// difference (see the standalone sketch after this function).
if (start < 0) {
start = start + in_dim;
}
at::Tensor index0 = at::arange(start, in_dim, 1, torch::kInt32);
at::Tensor index;
if (start == 0) {
index = index0;
} else {
at::Tensor index1 = at::arange(start, torch::kInt32);
index = at::cat({index0, index1}, 0);
}
auto index_tensor = tensor_to_const(ctx, index);
auto gather_layer = ctx->net->addGather(*in, *index_tensor, dim);
auto out = gather_layer->getOutput(0);
return out;
}
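
The modulo correction above exists because C++'s % can return a negative remainder when the shift exceeds the dimension size, while gather indices must lie in [0, in_dim). Below is a minimal standalone sketch, not part of this PR, of the index order that roll() gathers; the helper name roll_index_order and the use of a plain std::vector are illustrative only.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative only: reproduces the index permutation built by roll() above.
// For an axis of length in_dim rolled by `shift`, the gathered order is
// [start, ..., in_dim - 1, 0, ..., start - 1] with start normalized into [0, in_dim).
std::vector<int64_t> roll_index_order(int64_t in_dim, int64_t shift) {
  int64_t start = (in_dim - shift) % in_dim;
  if (start < 0) {
    start += in_dim; // C++ % may be negative; Python's is not for a positive modulus
  }
  std::vector<int64_t> order;
  for (int64_t i = start; i < in_dim; ++i) order.push_back(i);
  for (int64_t i = 0; i < start; ++i) order.push_back(i);
  return order;
}

int main() {
  // torch.roll([a, b, c, d, e], 2) == [d, e, a, b, c], i.e. gather indices [3, 4, 0, 1, 2].
  assert((roll_index_order(5, 2) == std::vector<int64_t>{3, 4, 0, 1, 2}));
  // A negative shift rolls the other way: [c, d, e, a, b], i.e. indices [2, 3, 4, 0, 1].
  assert((roll_index_order(5, -2) == std::vector<int64_t>{2, 3, 4, 0, 1}));
  return 0;
}
```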

auto select_registrations TORCHTRT_UNUSED =
RegisterNodeConversionPatterns()
.pattern({"aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))",
@@ -200,6 +228,69 @@ auto select_registrations TORCHTRT_UNUSED =

LOG_DEBUG("Output tensor shape: " << out->getDimensions());

return true;
}})
.pattern({"aten::roll(Tensor self, int[1] shifts, int[1] dims=[]) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensor();
auto shifts = args[1].unwrapToIntList().vec();
auto dims = args[2].unwrapToIntList().vec();

TORCHTRT_CHECK(dims.size() == shifts.size(), "dims.size() should be equal to shifts.size()");
if (ctx->input_is_dynamic) {
TORCHTRT_THROW_ERROR("aten::roll is currently not supported in dynamic input shape compilation");
} else {
auto in_shape = util::toVec(in->getDimensions());
for (size_t i = 0; i < dims.size(); i++) {
auto dim = dims[i] < 0 ? (in_shape.size() + dims[i]) : dims[i];
TORCHTRT_CHECK(dim < in_shape.size(), "Dimension out of range");
in = roll(ctx, in, shifts[i], dim, in_shape);
}
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);

LOG_DEBUG("Output tensor shape: " << out->getDimensions());

return true;
}
}})
.pattern(
{"aten::index.Tensor(Tensor self, Tensor?[] indices) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensorOrFreeze(ctx);
auto ts = args[1].IValue()->toListRef();

std::vector<nvinfer1::ITensor*> tensors;
for (auto t : ts) {
if (t.isTensor()) {
auto torch_tensor = t.toTensor();
tensors.push_back(tensor_to_const(ctx, torch_tensor));
} else {
auto cont = t.toCustomClass<TensorContainer>();
tensors.push_back(cont->tensor());
}
}

// In TorchScript, aten::index.Tensor indexes the self tensor along each of its dimensions using a
// list of index tensors. In this version of Torch-TensorRT, only a single index tensor is supported,
// so the self tensor is indexed along dimension 0 only (see the standalone sketch after this file's diff).
TORCHTRT_CHECK(
tensors.size() == 1,
"In this version of Torch-TensorRT, aten::index.Tensor can only receive one index tensor which means it only indexes the self tensor along dimension 0.");
auto indicesTensor = tensors[0];
// Set datatype for indices tensor to INT32
auto identity = ctx->net->addIdentity(*indicesTensor);
identity->setOutputType(0, nvinfer1::DataType::kINT32);
indicesTensor = identity->getOutput(0);

// IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices
// from
auto gather_layer = ctx->net->addGather(*in, *indicesTensor, 0);
TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
auto gather_out = gather_layer->getOutput(0);

auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_out);

LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
.pattern(
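
With a single index tensor, the converter above reduces aten::index.Tensor to a gather along dimension 0. A minimal eager-mode sketch of that equivalence, not part of this PR and using only libtorch's public API for illustration (the example shapes and index values are arbitrary):

```cpp
#include <torch/torch.h>
#include <cassert>

// Illustrative only: with a single index tensor, aten::index.Tensor behaves like
// index_select along dimension 0, which is what the converter emits through
// TensorRT's IGatherLayer.
int main() {
  auto self = torch::rand({5, 10});
  auto idx = torch::tensor({4, 4}, torch::kLong);
  auto via_index = self.index({idx});                  // self[idx], i.e. aten::index.Tensor
  auto via_gather = torch::index_select(self, 0, idx); // gather along dimension 0
  assert(torch::allclose(via_index, via_gather));
  return 0;
}
```
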
104 changes: 104 additions & 0 deletions tests/core/conversion/converters/test_select.cpp
@@ -195,6 +195,84 @@ TEST(Converters, ATenEmbeddingConvertsCorrectly) {
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRollConvertsCorrectly) {
const auto graph = R"IR(
graph(%1 : Tensor):
%2 : int[] = prim::Constant[value=[1, 0, 3, 7]]()
%3 : int[] = prim::Constant[value=[0, 1, 2, 3]]()
%4 : Tensor = aten::roll(%1, %2, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, g.get());

// Run PyTorch
auto in = at::randint(1, 10, {2, 3, 4, 5}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRollShiftsNegativeConvertsCorrectly) {
const auto graph = R"IR(
graph(%1 : Tensor):
%2 : int[] = prim::Constant[value=[0, -3, -3]]()
%3 : int[] = prim::Constant[value=[1, 2, 3]]()
%4 : Tensor = aten::roll(%1, %2, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, g.get());

// Run PyTorch
auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRollDimsNegativeConvertsCorrectly) {
const auto graph = R"IR(
graph(%1 : Tensor):
%2 : int[] = prim::Constant[value=[0, -3, -3]]()
%3 : int[] = prim::Constant[value=[1, 2, -1]]()
%4 : Tensor = aten::roll(%1, %2, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, g.get());

// Run PyTorch
auto in = at::randint(1, 10, {1, 3, 4, 5}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {trt_in});
auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}
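
The three roll tests above cover multi-axis rolls as well as negative shifts and negative dims. Below is a minimal sketch, not part of this PR, of the identity the converter's per-dimension loop relies on: rolling several axes at once matches rolling them one axis at a time. It uses only libtorch's public API; the tensor shape and shift values are arbitrary.

```cpp
#include <torch/torch.h>
#include <cassert>

// Illustrative only: a multi-axis roll equals successive single-axis rolls,
// which is how the converter lowers aten::roll into repeated gathers.
int main() {
  auto x = torch::arange(24).reshape({2, 3, 4});
  auto all_at_once = torch::roll(x, /*shifts=*/{1, 2}, /*dims=*/{0, 2});
  auto one_at_a_time = torch::roll(torch::roll(x, {1}, {0}), {2}, {2});
  assert(torch::equal(all_at_once, one_at_a_time));
  return 0;
}
```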

TEST(Converters, ATenSliceConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
@@ -463,3 +541,29 @@ TEST(Converters, ATenMaskedFillZerosConvertsCorrectly) {
ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}

TEST(Converters, ATenIndexTensorConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor,
%index : Tensor):
%18 : Tensor?[] = prim::ListConstruct(%index)
%19 : Tensor = aten::index(%x.1, %18)
return (%19))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in1 = at::randint(1, 10, {5, 10}, {at::kCUDA});
auto in2 = at::full({2}, 4, {at::kCUDA});
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
auto in2_trt = at::full({2}, 4, {options});

auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in1, in2});

params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in1, in2_trt});

ASSERT_TRUE(
torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}