File tree 1 file changed +4
-4
lines changed 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -532,24 +532,24 @@ BatchNormBackend _select_batch_norm_backend(
532
532
return BatchNormBackend::Cudnn;
533
533
}
534
534
535
- // TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
535
+ // TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM once ROCm officially supports NHWC in MIOpen
536
536
// See #64427
537
537
// non static variable is used to be able to change environment variable in runtime for testing
538
- bool PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env (" PYTORCH_MIOPEN_SUGGEST_NHWC " ).value_or (false );
538
+ bool PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM = c10::utils::check_env (" PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM " ).value_or (false );
539
539
540
540
if (
541
541
input.is_cuda ()
542
542
&& (input.dim () <= MIOPEN_DIM_MAX)
543
543
&& (input.scalar_type () != at::kDouble )
544
- && (weight.scalar_type () == at::kFloat )
544
+ && (weight.scalar_type () == at::kFloat ) // allow only fp32 and mixed fp16/bf16
545
545
&& weight.defined () && bias.defined ()
546
546
&& ((running_mean.defined () && running_var.defined ())
547
547
|| (!running_mean.defined () && !running_var.defined () && training))
548
548
&& (input.dim () >= 3 )
549
549
&& detail::getCUDAHooks ().compiledWithMIOpen ()
550
550
&& cudnn_enabled
551
551
&& (input.suggest_memory_format () == MemoryFormat::Contiguous
552
- || (input.suggest_memory_format () == MemoryFormat::ChannelsLast && PYTORCH_MIOPEN_SUGGEST_NHWC ))
552
+ || (input.suggest_memory_format () == MemoryFormat::ChannelsLast && PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM ))
553
553
) {
554
554
return BatchNormBackend::Miopen;
555
555
}
You can’t perform that action at this time.
0 commit comments