Commit 8bc4369

feat(aten::prelu): Basic prelu support
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent fe06d09 commit 8bc4369

File tree

core/conversion/converters/impl/activation.cpp
tests/core/converters/test_activation.cpp

2 files changed: +68 −0 lines changed

core/conversion/converters/impl/activation.cpp

Lines changed: 21 additions & 0 deletions

@@ -79,6 +79,27 @@ auto acthardtanh TRTORCH_UNUSED = RegisterNodeConversionPatterns()
         new_layer->setName(util::node_info(n).c_str());
         auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));

+        LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+        return true;
+      }
+    }).pattern({
+      "aten::prelu(Tensor self, Tensor weight) -> (Tensor)",
+      [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+        auto in = args[0].ITensor();
+        auto slopes = args[1].unwrapToTensor();
+
+        //if (slopes.numel() != 1) {
+        //  auto in_dims = util::toVec(in.getDimensions());
+        //  auto per_channel_shape = std::vector<int64_t>(in_dims.begin() + 2, in_dims.end());
+        //  for ()
+        //}
+
+        auto slope_tensor = tensor_to_const(ctx, slopes);
+
+        auto new_layer = ctx->net->addParametricReLU(*in, *slope_tensor);
+        new_layer->setName(util::node_info(n).c_str());
+        auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
+
         LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
         return true;
       }
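The commented-out block in the converter above sketches per-channel slope handling that this commit leaves unfinished; as committed, the slope tensor is passed to addParametricReLU as-is. One way the reshape might be completed (a hedged sketch for illustration, not part of this commit) is to pad a {C} slope tensor out to {1, C, 1, ..., 1} so it broadcasts against an NCHW-style input:

    // Hypothetical completion of the commented-out reshape (an assumption,
    // not in this commit): make a per-channel slope broadcastable against
    // the input by matching the input's rank.
    if (slopes.numel() != 1) {
      auto in_dims = util::toVec(in->getDimensions());
      std::vector<int64_t> slope_shape = {1, slopes.numel()};
      for (size_t i = 2; i < in_dims.size(); i++) {
        slope_shape.push_back(1); // pad remaining (spatial) dims with 1s
      }
      slopes = slopes.reshape(slope_shape); // e.g. {10} -> {1, 10, 1, 1}
    }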

tests/core/converters/test_activation.cpp

Lines changed: 47 additions & 0 deletions

@@ -109,3 +109,50 @@ TEST(Converters, ATenHardTanhCustomRangeConvertsCorrectly) {
   ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
 }

+TEST(Converters, ATenPReLUConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor,
+          %1 : Float(1)):
+      %3 : Tensor = aten::prelu(%0, %1)
+      return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randint(-5, 5, {5}, {at::kCUDA});
+  auto slope = at::randint(-5, 5, {1}, {at::kCUDA});
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {slope});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  in = at::clone(in);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {slope});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
+
+TEST(Converters, ATenPReLUMultiChannelConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor,
+          %1 : Float(10)):
+      %3 : Tensor = aten::prelu(%0, %1)
+      return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto in = at::randint(-5, 5, {1, 10, 1, 1}, {at::kCUDA});
+  auto slope = at::randint(-5, 5, {10}, {at::kCUDA});
+
+  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {slope});
+  auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
+
+  in = at::clone(in);
+  params = trtorch::core::conversion::get_named_params(g->inputs(), {slope});
+  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
+
+  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+}
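Both tests compare the TorchScript interpreter's output (RunGraph) against the TensorRT engine's output (RunGraphEngine) on the same random inputs, with a single shared slope in the first test and one slope per channel in the second. For reference, a standalone sketch of the function under test, using the standard PReLU definition (illustration only, not part of the commit):

    #include <algorithm>
    #include <cstdio>

    // PReLU with a single slope: f(x) = max(0, x) + slope * min(0, x),
    // i.e. identity for x >= 0 and a line of slope `slope` for x < 0.
    float prelu(float x, float slope) {
      return std::max(0.0f, x) + slope * std::min(0.0f, x);
    }

    int main() {
      const float slope = 0.25f;
      for (float x : {-2.0f, -0.5f, 0.0f, 1.5f}) {
        std::printf("prelu(%+.1f) = %+.3f\n", x, prelu(x, slope));
      }
      return 0;
    }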
