Skip to content

Commit bb46e70

Browse files
committed
feat(//core/conversion/converters/impl): added support for trilinear3d op
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent 4416d1f commit bb46e70

File tree

1 file changed

+34
-2
lines changed

1 file changed

+34
-2
lines changed

core/conversion/converters/impl/interpolate.cpp

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,45 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
163163
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
164164
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
165165
} else {
166-
TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_linear1d not supported yet.");
166+
TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_bilinear2d not supported yet.");
167167
}
168168

169169
return true;
170170
}
171-
});
171+
}).pattern({
172+
"aten::upsample_trilinear3d(Tensor self, int[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> (Tensor)",
173+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
174+
auto in = args[0].ITensor();
175+
auto in_shape = util::toVec(in->getDimensions());
176+
177+
bool align_corners = args[2].IValue()->to<bool>();
172178

179+
// Case 1: user uses output size and not scales_d, scales_h, scales_w
180+
if (!args[1].IValue()->isNone() && args[3].IValue()->isNone() && args[4].IValue()->isNone() && args[5].IValue()->isNone()) {
181+
auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
182+
183+
TRTORCH_ASSERT(out_size.size() == 3, "aten::upsample_trilinear3d input Tensor and output size dimension mismatch");
184+
185+
auto out_shape = in_shape;
186+
std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));
187+
188+
auto resize_layer = ctx->net->addResize(*in);
189+
TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n);
190+
191+
resize_layer->setOutputDimensions(util::toDims(out_shape));
192+
resize_layer->setResizeMode(nvinfer1::ResizeMode::kLINEAR);
193+
resize_layer->setAlignCorners(align_corners);
194+
resize_layer->setName(util::node_info(n).c_str());
195+
196+
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
197+
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
198+
} else {
199+
TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_trilinear3d not supported yet.");
200+
}
201+
202+
return true;
203+
}
204+
});
173205

174206
} // namespace
175207
} // namespace impl

0 commit comments

Comments
 (0)