Skip to content

Commit 4416d1f

Browse files
committed
feat(//core/conversion/converters/impl): added support for linear1d and bilinear2d ops
Signed-off-by: Abhiram Iyer <[email protected]>
Signed-off-by: Abhiram Iyer <[email protected]>
1 parent 5ddab8b commit 4416d1f

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

core/conversion/converters/impl/interpolate.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,72 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
100100
TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_nearest3d not supported yet.");
101101
}
102102

103+
return true;
104+
}
105+
}).pattern({
    "aten::upsample_linear1d(Tensor self, int[1] output_size, bool align_corners, float? scales=None) -> (Tensor)",
    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
      // Converter for aten::upsample_linear1d. Only the explicit
      // output-size form is supported; the scale-factor form is rejected.
      auto input = args[0].ITensor();
      auto input_dims = util::toVec(input->getDimensions());

      bool align_corners = args[2].IValue()->to<bool>();

      // Guard: bail out unless an output size was given and no scale factor was.
      if (args[1].IValue()->isNone() || !args[3].IValue()->isNone()) {
        TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_linear1d not supported yet.");
      }

      auto requested_size = util::toVec(util::toDims(args[1].unwrapToIntList()));
      TRTORCH_ASSERT(requested_size.size() == 1, "aten::upsample_linear1d input Tensor and output size dimension mismatch");

      // Keep the leading (batch/channel) dims and overwrite only the
      // trailing spatial dim with the requested size.
      auto target_dims = input_dims;
      std::copy(requested_size.begin(), requested_size.end(), target_dims.begin() + (input_dims.size() - requested_size.size()));

      auto resize = ctx->net->addResize(*input);
      TRTORCH_CHECK(resize, "Unable to create interpolation (resizing) layer from node" << *n);

      resize->setOutputDimensions(util::toDims(target_dims));
      resize->setResizeMode(nvinfer1::ResizeMode::kLINEAR);
      resize->setAlignCorners(align_corners);
      resize->setName(util::node_info(n).c_str());

      auto converted = ctx->AssociateValueAndTensor(n->outputs()[0], resize->getOutput(0));
      LOG_DEBUG("Output tensor shape: " << converted->getDimensions());

      return true;
    }
138+
}).pattern({
    "aten::upsample_bilinear2d(Tensor self, int[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> (Tensor)",
    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
      // Converter for aten::upsample_bilinear2d. Only the explicit
      // output-size form is supported; the scales_h/scales_w form throws.
      auto in = args[0].ITensor();
      auto in_shape = util::toVec(in->getDimensions());

      bool align_corners = args[2].IValue()->to<bool>();

      // Case 1: user uses output size and not scales_h, scales_w
      if (!args[1].IValue()->isNone() && args[3].IValue()->isNone() && args[4].IValue()->isNone()) {
        auto out_size = util::toVec(util::toDims(args[1].unwrapToIntList()));

        TRTORCH_ASSERT(out_size.size() == 2, "aten::upsample_bilinear2d input Tensor and output size dimension mismatch");

        // Preserve the leading (batch/channel) dims; overwrite only the
        // trailing H/W dims with the requested output size.
        auto out_shape = in_shape;
        std::copy(out_size.begin(), out_size.end(), out_shape.begin() + (in_shape.size() - out_size.size()));

        auto resize_layer = ctx->net->addResize(*in);
        TRTORCH_CHECK(resize_layer, "Unable to create interpolation (resizing) layer from node" << *n);

        resize_layer->setOutputDimensions(util::toDims(out_shape));
        resize_layer->setResizeMode(nvinfer1::ResizeMode::kLINEAR);
        resize_layer->setAlignCorners(align_corners);
        resize_layer->setName(util::node_info(n).c_str());

        auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
        LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());
      } else {
        // BUGFIX: message previously said "upsample_linear1d" — a copy-paste
        // error from the preceding converter; it must name this op.
        TRTORCH_THROW_ERROR("Unable to convert node: " << util::node_info(n) << "\nScale factor parameter for upsample_bilinear2d not supported yet.");
      }

      return true;
    }
  });

0 commit comments

Comments
 (0)