Skip to content

Commit b304f33

Browse files
mwoottoniotamudelta
authored andcommitted
MIOpen: Honor Max Dim (#222)
* MIOpen: Batchnorm - Allow half/half and half/float, disallow double * MIOpen: Honor DIM_MAX * Limit MIOpen batchnorm to same-precision
1 parent 3075155 commit b304f33

File tree

3 files changed

+13
-6
lines changed

3 files changed

+13
-6
lines changed

aten/src/ATen/native/Convolution.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
#include "ATen/Config.h"
55

6+
static const int MIOPEN_DIM_MAX = 4;
7+
static const bool MIOPEN_ENABLED = getenv("DISABLE_MIOPEN") != NULL;
8+
69
namespace at { namespace native {
710

811
struct ConvParams {
@@ -123,7 +126,8 @@ auto ConvParams::use_miopen(const at::Tensor& input) const -> bool {
123126
return ((input.type().scalarType() == at::kFloat) || (input.type().scalarType() == at::kHalf))
124127
&& detail::getCUDAHooks().compiledWithMIOpen()
125128
&& input.type().is_cuda()
126-
&& cudnn_enabled
129+
&& input.dim() > MIOPEN_DIM_MAX
130+
&& MIOPEN_ENABLED
127131
;
128132
}
129133

aten/src/ATen/native/Normalization.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
#include <vector>
99

10+
static const int MIOPEN_DIM_MAX = 4;
11+
static const bool MIOPEN_ENABLED = getenv("DISABLE_MIOPEN") != NULL;
12+
1013
namespace at { namespace native {
1114

1215
namespace {
@@ -67,12 +70,14 @@ Tensor batch_norm(
6770
}
6871

6972
bool use_miopen = (input.type().is_cuda()
70-
&& (input.type().scalarType() != at::kHalf
71-
|| weight.type().scalarType() == at::kFloat)
73+
&& input.dim() < MIOPEN_DIM_MAX
74+
&& input.type().scalarType() != at::kDouble
75+
&& (input.type().scalarType() == weight.type().scalarType())
7276
&& weight.defined() && bias.defined()
7377
&& ((running_mean.defined() && running_var.defined())
7478
|| (!running_mean.defined() && !running_var.defined() && training))
7579
&& detail::getCUDAHooks().compiledWithMIOpen()
80+
&& MIOPEN_ENABLED
7681
);
7782

7883
if (use_miopen) {

aten/src/ATen/native/miopen/BatchNorm_miopen.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
6767
checkAllDefined(c, {running_mean, running_var});
6868
}
6969
checkAllSameGPU(c, {input, weight, bias, running_mean, running_var});
70-
if (input->type().scalarType() == ScalarType::Half) {
71-
checkScalarType(c, weight, ScalarType::Float);
72-
} else {
70+
if (input->type().scalarType() != ScalarType::Half) {
7371
checkAllSameType(c, {input, weight});
7472
}
7573
checkAllSameType(c, {weight, bias, running_mean, running_var});

0 commit comments

Comments
 (0)