diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp
index cee0ccd212f5c4..1b8e68a1bd8ffe 100644
--- a/aten/src/ATen/native/Convolution.cpp
+++ b/aten/src/ATen/native/Convolution.cpp
@@ -3,6 +3,9 @@
 #include "ATen/Config.h"
 
+static const int MIOPEN_DIM_MAX = 4;  // maximum tensor rank the MIOpen kernels support
+static const bool MIOPEN_ENABLED = getenv("DISABLE_MIOPEN") == NULL;  // runtime opt-out via env var
+
 namespace at { namespace native {
 
 struct ConvParams {
 
@@ -123,7 +126,8 @@ auto ConvParams::use_miopen(const at::Tensor& input) const -> bool {
   return ((input.type().scalarType() == at::kFloat) || (input.type().scalarType() == at::kHalf))
          && detail::getCUDAHooks().compiledWithMIOpen()
          && input.type().is_cuda()
-         && cudnn_enabled
+         && input.dim() <= MIOPEN_DIM_MAX
+         && MIOPEN_ENABLED
          ;
 }
 
diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index ed0a94ae496718..cb37465dab1a56 100644
--- a/aten/src/ATen/native/Normalization.cpp
+++ b/aten/src/ATen/native/Normalization.cpp
@@ -7,6 +7,9 @@
 #include <vector>
 
+static const int MIOPEN_DIM_MAX = 4;  // maximum tensor rank the MIOpen kernels support
+static const bool MIOPEN_ENABLED = getenv("DISABLE_MIOPEN") == NULL;  // runtime opt-out via env var
+
 namespace at { namespace native {
 
 namespace {
 
@@ -67,12 +70,14 @@
   }
 
   bool use_miopen = (input.type().is_cuda()
-                     && (input.type().scalarType() != at::kHalf
-                      || weight.type().scalarType() == at::kFloat)
+                     && input.dim() <= MIOPEN_DIM_MAX
+                     && input.type().scalarType() != at::kDouble
+                     && (input.type().scalarType() == weight.type().scalarType())
                      && weight.defined() && bias.defined()
                      && ((running_mean.defined() && running_var.defined())
                          || (!running_mean.defined() && !running_var.defined() && training))
                      && detail::getCUDAHooks().compiledWithMIOpen()
+                     && MIOPEN_ENABLED
                      );
 
   if (use_miopen) {
diff --git a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp
index c9d25780bd65d3..bfad2392b55a89 100644
--- a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp
+++ b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp
@@ -67,9 +67,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
     checkAllDefined(c, {running_mean, running_var});
   }
   checkAllSameGPU(c, {input, weight, bias, running_mean, running_var});
-  if (input->type().scalarType() == ScalarType::Half) {
-    checkScalarType(c, weight, ScalarType::Float);
-  } else {
+  if (input->type().scalarType() != ScalarType::Half) {
     checkAllSameType(c, {input, weight});
   }
   checkAllSameType(c, {weight, bias, running_mean, running_var});
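
For reference, the gating these hunks introduce can be exercised in isolation. The sketch below is illustrative only: would_use_miopen() and its parameters are stand-ins for the real predicates in ConvParams::use_miopen and batch_norm, which additionally consult detail::getCUDAHooks().compiledWithMIOpen() and the actual tensor metadata.

// Minimal standalone sketch of the MIOpen gating added in this diff.
// would_use_miopen() is a hypothetical helper, not part of ATen.
#include <cstdio>
#include <cstdlib>

static const int MIOPEN_DIM_MAX = 4;                                  // MIOpen kernels handle at most 4-D tensors
static const bool MIOPEN_ENABLED = getenv("DISABLE_MIOPEN") == NULL;  // opt out by setting DISABLE_MIOPEN

// Mirrors the combined gate: MIOpen is used only for float/half CUDA
// tensors of supported rank, and only when not disabled at runtime.
bool would_use_miopen(int dim, bool is_cuda, bool is_float_or_half) {
  return is_float_or_half
      && is_cuda
      && dim <= MIOPEN_DIM_MAX
      && MIOPEN_ENABLED;
}

int main() {
  printf("4-D float CUDA tensor:  %d\n", would_use_miopen(4, true, true));   // 1 unless DISABLE_MIOPEN is set
  printf("5-D float CUDA tensor:  %d\n", would_use_miopen(5, true, true));   // 0: rank above MIOPEN_DIM_MAX
  printf("4-D double CUDA tensor: %d\n", would_use_miopen(4, true, false));  // 0: unsupported dtype
  return 0;
}

Running the sketch with DISABLE_MIOPEN=1 set in the environment should print 0 for every case, mirroring how the env var vetoes both the convolution and batch-norm MIOpen paths at process startup.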