Skip to content

Commit d9c0e84

Browse files
committed
fix(): fixed FP16 bug, fixed README, addressed some other PR comments
Signed-off-by: Abhiram Iyer <[email protected]>
Signed-off-by: Abhiram Iyer <[email protected]>
1 parent 7794c78 commit d9c0e84

File tree

3 files changed

+12
-20
lines changed

3 files changed

+12
-20
lines changed

README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,6 @@ You can register a converter for your op using the `NodeConverterRegistry` insid
205205

206206
## Known Limitations
207207

208-
- You cannot use both Adaptive Pooling in PyTorch and also use TRTorch Dynamic input shape (follow [#49](https://github.com/NVIDIA/TRTorch/issues/49) for the latest on the issue)
209-
210208
## Structure of the repo
211209

212210
| Component | Description |

core/conversion/converters/impl/interpolate.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
#include "NvInfer.h"
66
#include "NvInferRuntimeCommon.h"
77

8-
#include <tuple>
9-
108
namespace trtorch {
119
namespace core {
1210
namespace conversion {

core/conversion/converters/impl/plugins/interpolate_plugin.cpp

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -134,15 +134,20 @@ size_t InterpolatePlugin::getSerializationSize() const {
134134
}
135135

136136
bool InterpolatePlugin::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) {
137-
if (inOut->format != nvinfer1::TensorFormat::kLINEAR) {
138-
return false;
139-
}
137+
TRTORCH_ASSERT(0 <= pos && pos <= 1, "There should be exactly 2 connections to the plugin - 1 input, 1 output");
138+
TRTORCH_ASSERT(nbInputs == 1, "Expected a single tensor as input to interpolate plugin");
139+
TRTORCH_ASSERT(nbOutputs == 1, "Expected a single tensor as output to interpolate plugin");
140140

141-
if (inOut->type == DataType::kINT32 || inOut->type == DataType::kINT8) {
142-
return false;
141+
const PluginTensorDesc& in = inOut[0];
142+
143+
if (pos == 0) {
144+
return (in.type == nvinfer1::DataType::kFLOAT) && (in.format == nvinfer1::TensorFormat::kLINEAR);
143145
}
144146

145-
return true;
147+
// pos == 1, accessing information about output tensor
148+
const PluginTensorDesc& out = inOut[1];
149+
150+
return (in.type == out.type) && (in.format == out.format);
146151
}
147152

148153
void InterpolatePlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) {
@@ -156,16 +161,7 @@ size_t InterpolatePlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inp
156161
int InterpolatePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, const void *const *inputs,
157162
void *const *outputs, void *workspace,
158163
cudaStream_t stream) {
159-
at::Tensor input;
160-
161-
if (mode == "adaptive_pool2d") {
162-
// use dynamically inferred input shape (for pooling)
163-
input = at::from_blob((void*) inputs[0], util::toVec(inputDesc->dims), [](void*){}, tensor_options);
164-
} else {
165-
// use precomputed input shape (for interpolation/upsampling)
166-
input = at::from_blob((void*) inputs[0], in_shape, [](void*){}, tensor_options);
167-
}
168-
164+
at::Tensor input = at::from_blob((void*) inputs[0], util::toVec(inputDesc->dims), [](void*){}, tensor_options);
169165
at::Tensor output = at::from_blob(outputs[0], out_shape, [](void*){}, tensor_options);
170166

171167
at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool();

0 commit comments

Comments (0)