@@ -109,3 +109,50 @@ TEST(Converters, ATenHardTanhCustomRangeConvertsCorrectly) {
109
109
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
110
110
}
111
111
112
+ TEST (Converters, ATenPReLUConvertsCorrectly) {
113
+ const auto graph = R"IR(
114
+ graph(%0 : Tensor,
115
+ %1 : Float(1)):
116
+ %3 : Tensor = aten::prelu(%0, %1)
117
+ return (%3))IR" ;
118
+
119
+ auto g = std::make_shared<torch::jit::Graph>();
120
+ torch::jit::parseIR (graph, &*g);
121
+
122
+ auto in = at::randint (-5 , 5 , {5 }, {at::kCUDA });
123
+ auto slope = at::randint (-5 , 5 , {1 }, {at::kCUDA });
124
+
125
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {slope});
126
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
127
+
128
+ in = at::clone (in);
129
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {slope});
130
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
131
+
132
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
133
+ }
134
+
135
+ TEST (Converters, ATenPReLUMultiChannelConvertsCorrectly) {
136
+ const auto graph = R"IR(
137
+ graph(%0 : Tensor,
138
+ %1 : Float(10)):
139
+ %3 : Tensor = aten::prelu(%0, %1)
140
+ return (%3))IR" ;
141
+
142
+ auto g = std::make_shared<torch::jit::Graph>();
143
+ torch::jit::parseIR (graph, &*g);
144
+
145
+ auto in = at::randint (-5 , 5 , {1 ,10 , 1 , 1 }, {at::kCUDA });
146
+ auto slope = at::randint (-5 , 5 , {10 }, {at::kCUDA });
147
+
148
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {slope});
149
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
150
+
151
+ in = at::clone (in);
152
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {slope});
153
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
154
+
155
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
156
+ }
157
+
158
+
0 commit comments