Skip to content

Commit 1557f6e

Browse files
inocsinnarendasan
authored andcommitted
support topk converter/test_case
Signed-off-by: inocsin <[email protected]>
1 parent b228bf2 commit 1557f6e

File tree

3 files changed

+90
-0
lines changed

3 files changed

+90
-0
lines changed

core/conversion/converters/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ cc_library(
5252
"impl/stack.cpp",
5353
"impl/lstm_cell.cpp",
5454
"impl/unsqueeze.cpp",
55+
"impl/topk.cpp",
5556
],
5657
deps = [
5758
"@tensorrt//:nvinfer",
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#include "NvInfer.h"
2+
#include "core/conversion/converters/converters.h"
3+
#include "core/conversion/tensorcontainer/TensorContainer.h"
4+
#include "core/util/prelude.h"
5+
#include "torch/torch.h"
6+
7+
#include <ATen/ATen.h>
8+
#include <vector>
9+
10+
namespace trtorch {
11+
namespace core {
12+
namespace conversion {
13+
namespace converters {
14+
namespace impl {
15+
namespace {
16+
17+
auto topk_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
18+
{"aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)",
19+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
20+
auto self = args[0].ITensorOrFreeze(ctx);
21+
auto k = args[1].unwrapToInt();
22+
auto dim = args[2].unwrapToInt();
23+
auto largest = args[3].unwrapToBool();
24+
auto sorted = args[4].unwrapToBool();
25+
26+
auto selfDim = util::toVec(self->getDimensions());
27+
28+
//reduceAxes The reduction dimensions. The bit in position i of bitmask reduceAxes corresponds to explicit dimension i of the result.
29+
//E.g., the least significant bit corresponds to the first explicit dimension and the next to least significant bit corresponds to the second explicit dimension.
30+
31+
if (dim < 0) {
32+
dim = selfDim.size() + dim;
33+
}
34+
35+
uint32_t shiftDim = 1 << dim;
36+
37+
LOG_DEBUG("Output topk reduce dim: " << dim);
38+
39+
auto TopKOperation = largest ? (nvinfer1::TopKOperation::kMAX) : (nvinfer1::TopKOperation::kMIN);
40+
41+
auto new_layer = ctx->net->addTopK(*self, TopKOperation, k, shiftDim);
42+
43+
TRTORCH_CHECK(new_layer, "Unable to create topk layer from node: " << *n);
44+
45+
auto out0 = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
46+
auto out1 = ctx->AssociateValueAndTensor(n->outputs()[1], new_layer->getOutput(1));
47+
48+
LOG_DEBUG("Output tensor(0) shape: " << out0->getDimensions());
49+
LOG_DEBUG("Output tensor(1) shape: " << out1->getDimensions());
50+
51+
return true;
52+
}});
53+
54+
} // namespace
55+
} // namespace impl
56+
} // namespace converters
57+
} // namespace conversion
58+
} // namespace core
59+
} // namespace trtorch

tests/core/converters/test_topk.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "gtest/gtest.h"
4+
#include "tests/util/util.h"
5+
#include "torch/csrc/jit/ir/irparser.h"
6+
7+
TEST(Converters, ATenTopKConvertsCorrectly) {
8+
const auto graph = R"IR(
9+
graph(%0 : Tensor):
10+
%1 : int = prim::Constant[value=20]()
11+
%2 : int = prim::Constant[value=-1]()
12+
%3 : bool = prim::Constant[value=1]()
13+
%4 : bool = prim::Constant[value=1]()
14+
%5 : Tensor, %6 : Tensor = aten::topk(%0, %1, %2, %3, %4)
15+
return (%5, %6))IR";
16+
17+
auto g = std::make_shared<torch::jit::Graph>();
18+
torch::jit::parseIR(graph, &*g);
19+
20+
auto in = at::rand({10, 10, 100}, {at::kCUDA});
21+
22+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
23+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
24+
25+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
26+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
27+
28+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
29+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[1], trt_results[1].reshape_as(jit_results[1]), 2e-6));
30+
}

0 commit comments

Comments
 (0)