Skip to content

Commit 8a5afd2

Browse files
lcskrishnaiotamudelta
authored andcommitted
[Pytorch] enable mixed precision fp16 training.
* enabled miopen fp16 for mixed precision training * removed ununsed mathType completely
1 parent e502360 commit 8a5afd2

File tree

2 files changed

+1
-3
lines changed

2 files changed

+1
-3
lines changed

aten/src/ATen/miopen/Descriptors.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,6 @@ struct ConvolutionDescriptor
122122
&miopenDestroyConvolutionDescriptor>
123123
{
124124
void set(miopenDataType_t dataType, int dim, int* pad, int* stride, int * upscale /* aka dilation */, int groups) {
125-
miopenDataType_t mathType = dataType;
126-
if (dataType == miopenHalf) mathType = miopenFloat;
127125
MIOPEN_CHECK(miopenInitConvolutionDescriptor(mut_desc(), miopenConvolution, pad[0], pad[1], stride[0], stride[1], upscale[0], upscale[1]));
128126
MIOPEN_CHECK(miopenSetConvolutionGroupCount(mut_desc(), groups));
129127
}

aten/src/ATen/native/Normalization.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ Tensor batch_norm(
265265
bool use_miopen = (input.is_cuda()
266266
&& input.dim() <= MIOPEN_DIM_MAX
267267
&& input.type().scalarType() != at::kDouble
268-
&& (input.type().scalarType() == weight.type().scalarType())
268+
&& (weight.type().scalarType() != at::kHalf)
269269
&& weight.defined() && bias.defined()
270270
&& ((running_mean.defined() && running_var.defined())
271271
|| (!running_mean.defined() && !running_var.defined() && training))

0 commit comments

Comments
 (0)