@@ -163,13 +163,45 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
163
163
auto layer_output = ctx->AssociateValueAndTensor (n->outputs ()[0 ], resize_layer->getOutput (0 ));
164
164
LOG_DEBUG (" Output tensor shape: " << layer_output->getDimensions ());
165
165
} else {
166
- TRTORCH_THROW_ERROR (" Unable to convert node: " << util::node_info (n) << " \n Scale factor parameter for upsample_linear1d not supported yet." );
166
+ TRTORCH_THROW_ERROR (" Unable to convert node: " << util::node_info (n) << " \n Scale factor parameter for upsample_bilinear2d not supported yet." );
167
167
}
168
168
169
169
return true ;
170
170
}
171
- });
171
+ }).pattern({
172
+ " aten::upsample_trilinear3d(Tensor self, int[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> (Tensor)" ,
173
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
174
+ auto in = args[0 ].ITensor ();
175
+ auto in_shape = util::toVec (in->getDimensions ());
176
+
177
+ bool align_corners = args[2 ].IValue ()->to <bool >();
172
178
179
+ // Case 1: user uses output size and not scales_d, scales_h, scales_w
180
+ if (!args[1 ].IValue ()->isNone () && args[3 ].IValue ()->isNone () && args[4 ].IValue ()->isNone () && args[5 ].IValue ()->isNone ()) {
181
+ auto out_size = util::toVec (util::toDims (args[1 ].unwrapToIntList ()));
182
+
183
+ TRTORCH_ASSERT (out_size.size () == 3 , " aten::upsample_trilinear3d input Tensor and output size dimension mismatch" );
184
+
185
+ auto out_shape = in_shape;
186
+ std::copy (out_size.begin (), out_size.end (), out_shape.begin () + (in_shape.size () - out_size.size ()));
187
+
188
+ auto resize_layer = ctx->net ->addResize (*in);
189
+ TRTORCH_CHECK (resize_layer, " Unable to create interpolation (resizing) layer from node" << *n);
190
+
191
+ resize_layer->setOutputDimensions (util::toDims (out_shape));
192
+ resize_layer->setResizeMode (nvinfer1::ResizeMode::kLINEAR );
193
+ resize_layer->setAlignCorners (align_corners);
194
+ resize_layer->setName (util::node_info (n).c_str ());
195
+
196
+ auto layer_output = ctx->AssociateValueAndTensor (n->outputs ()[0 ], resize_layer->getOutput (0 ));
197
+ LOG_DEBUG (" Output tensor shape: " << layer_output->getDimensions ());
198
+ } else {
199
+ TRTORCH_THROW_ERROR (" Unable to convert node: " << util::node_info (n) << " \n Scale factor parameter for upsample_trilinear3d not supported yet." );
200
+ }
201
+
202
+ return true ;
203
+ }
204
+ });
173
205
174
206
} // namespace
175
207
} // namespace impl
0 commit comments