Skip to content

Commit 051f393

Browse files
committed
add separate environment variable to enable NHWC batchnorm (#1896)
Setting the environment variable PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM=1 enables NHWC batchnorm separately from convolution.
1 parent d685372 commit 051f393

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

aten/src/ATen/native/Normalization.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -532,24 +532,24 @@ BatchNormBackend _select_batch_norm_backend(
532532
return BatchNormBackend::Cudnn;
533533
}
534534

535-
// TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
535+
// TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM once ROCm officially supports NHWC in MIOpen
536536
// See #64427
537537
// A non-static variable is used so that the environment variable can be changed at runtime for testing
538-
bool PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC").value_or(false);
538+
bool PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM").value_or(false);
539539

540540
if (
541541
input.is_cuda()
542542
&& (input.dim() <= MIOPEN_DIM_MAX)
543543
&& (input.scalar_type() != at::kDouble)
544-
&& (weight.scalar_type() == at::kFloat)
544+
&& (weight.scalar_type() == at::kFloat) // allow only fp32 and mixed fp16/bf16
545545
&& weight.defined() && bias.defined()
546546
&& ((running_mean.defined() && running_var.defined())
547547
|| (!running_mean.defined() && !running_var.defined() && training))
548548
&& (input.dim() >= 3)
549549
&& detail::getCUDAHooks().compiledWithMIOpen()
550550
&& cudnn_enabled
551551
&& (input.suggest_memory_format() == MemoryFormat::Contiguous
552-
|| (input.suggest_memory_format() == MemoryFormat::ChannelsLast && PYTORCH_MIOPEN_SUGGEST_NHWC))
552+
|| (input.suggest_memory_format() == MemoryFormat::ChannelsLast && PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM))
553553
) {
554554
return BatchNormBackend::Miopen;
555555
}

0 commit comments

Comments
 (0)