Skip to content

Commit 7794c78

Browse files
committed
fix(): fixed interpolate_plugin to handle dynamically sized inputs for adaptive_pool2d
Signed-off-by: Abhiram Iyer <[email protected]> Signed-off-by: Abhiram Iyer <[email protected]>
1 parent 549ca38 commit 7794c78

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

core/conversion/converters/impl/plugins/interpolate_plugin.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,16 @@ size_t InterpolatePlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inp
156156
int InterpolatePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, const void *const *inputs,
157157
void *const *outputs, void *workspace,
158158
cudaStream_t stream) {
159+
at::Tensor input;
160+
161+
if (mode == "adaptive_pool2d") {
162+
// use dynamically inferred input shape (for pooling)
163+
input = at::from_blob((void*) inputs[0], util::toVec(inputDesc->dims), [](void*){}, tensor_options);
164+
} else {
165+
// use precomputed input shape (for interpolation/upsampling)
166+
input = at::from_blob((void*) inputs[0], in_shape, [](void*){}, tensor_options);
167+
}
159168

160-
at::Tensor input = at::from_blob((void*) inputs[0], in_shape, [](void*){}, tensor_options);
161169
at::Tensor output = at::from_blob(outputs[0], out_shape, [](void*){}, tensor_options);
162170

163171
at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool();

0 commit comments

Comments
 (0)