You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
auto topk_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
18
+
{"aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)",
19
+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
20
+
auto self = args[0].ITensorOrFreeze(ctx);
21
+
auto k = args[1].unwrapToInt();
22
+
auto dim = args[2].unwrapToInt();
23
+
auto largest = args[3].unwrapToBool();
24
+
auto sorted = args[4].unwrapToBool();
25
+
26
+
auto selfDim = util::toVec(self->getDimensions());
27
+
28
+
//reduceAxes The reduction dimensions. The bit in position i of bitmask reduceAxes corresponds to explicit dimension i of the result.
29
+
//E.g., the least significant bit corresponds to the first explicit dimension and the next to least significant bit corresponds to the second explicit dimension.
30
+
31
+
if (dim < 0) {
32
+
dim = selfDim.size() + dim;
33
+
}
34
+
35
+
uint32_t shiftDim = 1 << dim;
36
+
37
+
LOG_DEBUG("Output topk reduce dim: " << dim);
38
+
39
+
auto TopKOperation = largest ? (nvinfer1::TopKOperation::kMAX) : (nvinfer1::TopKOperation::kMIN);
40
+
41
+
auto new_layer = ctx->net->addTopK(*self, TopKOperation, k, shiftDim);
42
+
43
+
TRTORCH_CHECK(new_layer, "Unable to create topk layer from node: " << *n);
44
+
45
+
auto out0 = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
46
+
auto out1 = ctx->AssociateValueAndTensor(n->outputs()[1], new_layer->getOutput(1));
0 commit comments