diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index f359d67c72e786..1af688744fe8a5 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -120,9 +120,11 @@ auto ConvParams::use_cudnn(const at::Tensor& input) const -> bool { } auto ConvParams::use_miopen(const at::Tensor& input) const -> bool { - if (!detail::getCUDAHooks().compiledWithMIOpen() || !input.type().is_cuda() || !cudnn_enabled) - return false; - return true; + return ((input.type().scalarType() == at::kFloat) || (input.type().scalarType() == at::kHalf)) + && detail::getCUDAHooks().compiledWithMIOpen() + && input.type().is_cuda() + && cudnn_enabled + ; } auto ConvParams::use_mkldnn(const at::Tensor& input) const -> bool {