Skip to content

Commit 024a6b2

Browse files
committed
fix(): need to fix gather converter
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent d9c0e84 commit 024a6b2

File tree

2 files changed

+69
-1
lines changed

2 files changed

+69
-1
lines changed

core/conversion/converters/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ cc_library(
2828
"impl/shuffle.cpp",
2929
"impl/softmax.cpp",
3030
"impl/unary.cpp",
31-
"impl/interpolate.cpp"
31+
"impl/interpolate.cpp",
32+
"impl/select.cpp"
3233
],
3334
deps = [
3435
"@tensorrt//:nvinfer",
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#include "torch/torch.h"
2+
#include "core/util/prelude.h"
3+
#include "core/conversion/converters/converters.h"
4+
#include "NvInfer.h"
5+
#include "torch/csrc/autograd/generated/variable_factories.h"
6+
7+
#include <ATen/ATen.h>
8+
#include <vector>
9+
10+
#include <csignal>
11+
12+
namespace trtorch {
13+
namespace core {
14+
namespace conversion {
15+
namespace converters {
16+
namespace impl {
17+
namespace {
18+
19+
auto select_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
20+
.pattern({
21+
"aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))",
22+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
23+
std::cout << "select.int converter recognized" << std::endl;
24+
25+
auto in = args[0].ITensor();
26+
auto axis = args[1].unwrapToInt();
27+
auto ind = (int32_t) args[2].unwrapToInt();
28+
29+
// tried: vector for input
30+
//std::vector<int32_t> indices_input = {ind};
31+
32+
auto options = torch::TensorOptions().device(torch::kCUDA, 1).dtype(torch::kInt32);
33+
at::Tensor indices = torch::tensor(torch::detail::TensorDataContainer(ind), options);
34+
35+
auto weights = Weights(ctx, indices);
36+
// manually setting weights
37+
// weights.data.type = nvinfer1::DataType::kINT32;
38+
39+
auto const_layer = ctx->net->addConstant(weights.shape, weights.data);
40+
const_layer->setName(util::node_info(n).c_str());
41+
// manually setting output type
42+
// const_layer->setOutputType(0, nvinfer1::DataType::kINT32);
43+
44+
auto const_out = ctx->AssociateValueAndTensor(n->outputs()[0], const_layer->getOutput(0));
45+
46+
auto gather_layer = ctx->net->addGather(*in, *const_out, axis);
47+
gather_layer->setName(util::node_info(n).c_str());
48+
// manually setting output type
49+
// gather_layer->setOutputType(0, nvinfer1::DataType::kINT32);
50+
51+
auto gather_output = ctx->AssociateValueAndTensor(n->outputs()[0], gather_layer->getOutput(0));
52+
53+
LOG_DEBUG("Output tensor shape: " << gather_output->getDimensions());
54+
55+
// for debugging
56+
// std::raise(SIGTRAP);
57+
58+
return true;
59+
}
60+
});
61+
62+
} // namespace
63+
} // namespace impl
64+
} // namespace converters
65+
} // namespace conversion
66+
} // namespace core
67+
} // namespace trtorch

0 commit comments

Comments
 (0)