1
+ #include " torch/torch.h"
2
+ #include " core/util/prelude.h"
3
+ #include " core/conversion/converters/converters.h"
4
+ #include " NvInfer.h"
5
+ #include " torch/csrc/autograd/generated/variable_factories.h"
6
+
7
+ #include < ATen/ATen.h>
8
+ #include < vector>
9
+
10
+ #include < csignal>
11
+
12
+ namespace trtorch {
13
+ namespace core {
14
+ namespace conversion {
15
+ namespace converters {
16
+ namespace impl {
17
+ namespace {
18
+
19
+ auto select_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
20
+ .pattern({
21
+ " aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))" ,
22
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
23
+ std::cout << " select.int converter recognized" << std::endl;
24
+
25
+ auto in = args[0 ].ITensor ();
26
+ auto axis = args[1 ].unwrapToInt ();
27
+ auto ind = (int32_t ) args[2 ].unwrapToInt ();
28
+
29
+ // tried: vector for input
30
+ // std::vector<int32_t> indices_input = {ind};
31
+
32
+ auto options = torch::TensorOptions ().device (torch::kCUDA , 1 ).dtype (torch::kInt32 );
33
+ at::Tensor indices = torch::tensor (torch::detail::TensorDataContainer (ind), options);
34
+
35
+ auto weights = Weights (ctx, indices);
36
+ // manually setting weights
37
+ // weights.data.type = nvinfer1::DataType::kINT32;
38
+
39
+ auto const_layer = ctx->net ->addConstant (weights.shape , weights.data );
40
+ const_layer->setName (util::node_info (n).c_str ());
41
+ // manually setting output type
42
+ // const_layer->setOutputType(0, nvinfer1::DataType::kINT32);
43
+
44
+ auto const_out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], const_layer->getOutput (0 ));
45
+
46
+ auto gather_layer = ctx->net ->addGather (*in, *const_out, axis);
47
+ gather_layer->setName (util::node_info (n).c_str ());
48
+ // manually setting output type
49
+ // gather_layer->setOutputType(0, nvinfer1::DataType::kINT32);
50
+
51
+ auto gather_output = ctx->AssociateValueAndTensor (n->outputs ()[0 ], gather_layer->getOutput (0 ));
52
+
53
+ LOG_DEBUG (" Output tensor shape: " << gather_output->getDimensions ());
54
+
55
+ // for debugging
56
+ // std::raise(SIGTRAP);
57
+
58
+ return true ;
59
+ }
60
+ });
61
+
62
+ } // namespace
63
+ } // namespace impl
64
+ } // namespace converters
65
+ } // namespace conversion
66
+ } // namespace core
67
+ } // namespace trtorch
0 commit comments