Commit abbfe77

[release/2.5] Enable bf16 with fp32 weights for MIOpen batchnorm (#1672)
This PR enables:

* using the MIOpen OCL_mix backend for bf16 batchnorm with fp32 weights (via torch autocast). This was required and tested for a customer workload using NCHW, which is the only memory_format enabled.
* logging for MIOpen batchnorm via the `PYTORCH_MIOPEN_EXTRA_LOGGING` env var.

TODO in a separate PR: implement PyTorch unit tests for the bf16/fp16 inputs + fp32 weights case.
1 parent f0927c2 · commit abbfe77
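A minimal usage sketch (not part of this commit) of the path the PR enables: a ROCm build where `torch.cuda` maps to the HIP device, a Conv + BatchNorm model with default fp32 parameters, and an NCHW bf16 autocast region. The module sizes and tensor shapes are illustrative only; with the env var set, the new logging prints the dtypes seen at each batch_norm entry point.

```python
import os
# The diff reads PYTORCH_MIOPEN_EXTRA_LOGGING into a global at library-load time,
# so set it before importing torch.
os.environ["PYTORCH_MIOPEN_EXTRA_LOGGING"] = "1"

import torch
import torch.nn as nn

# Conv + BN in the default (NCHW) layout; BatchNorm2d keeps fp32 weights/bias.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
).cuda()

x = torch.randn(4, 3, 32, 32, device="cuda")  # NCHW input

# Under autocast, the conv emits bf16 activations, so batchnorm sees
# bf16 input with fp32 weights -- the mixed case this commit enables.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = model(x)

print(out.dtype)  # expected: torch.bfloat16
```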

2 files changed: +56 −3 lines changed

aten/src/ATen/native/Normalization.cpp (54 additions, 1 deletion)

```diff
@@ -61,6 +61,7 @@
 #include <c10/core/SymIntArrayRef.h>
 #include <utility>
 #include <vector>
+#include <iostream>
 
 static const int MIOPEN_DIM_MAX = 5;
 
@@ -514,8 +515,8 @@ BatchNormBackend _select_batch_norm_backend(
       input.is_cuda()
       && input.dim() <= MIOPEN_DIM_MAX
       && input.scalar_type() != at::kDouble
-      && input.scalar_type() != at::kBFloat16
       && (weight.scalar_type() != at::kHalf)
+      && (weight.scalar_type() != at::kBFloat16)
       && weight.defined() && bias.defined()
       && ((running_mean.defined() && running_var.defined())
         || (!running_mean.defined() && !running_var.defined() && training))
@@ -531,6 +532,7 @@ BatchNormBackend _select_batch_norm_backend(
   return BatchNormBackend::Native;
 }
 
+bool PYTORCH_MIOPEN_EXTRA_LOGGING = c10::utils::check_env("PYTORCH_MIOPEN_EXTRA_LOGGING").value_or(false);
 
 // _batch_norm_impl_index(_backward) are used in the JIT be able to keep the run-time selection
 // of backends, while enabling it to keep the information about the used backend, so that it can
@@ -541,6 +543,20 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
     const Tensor& input, const std::optional<Tensor>& weight_opt /* optional */, const std::optional<Tensor>& bias_opt /* optional */, const std::optional<Tensor>& running_mean_opt /* optional */, const std::optional<Tensor>& running_var_opt /* optional */,
     bool training, double momentum, double eps, bool cudnn_enabled) {
   // See [Note: hacky wrapper removal for optional tensor]
+  if (PYTORCH_MIOPEN_EXTRA_LOGGING)
+    std :: cout
+        << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index"
+        << " input=" << input.scalar_type()
+        << " weight=" << (weight_opt.has_value() ? weight_opt.value().scalar_type() : at::ScalarType::Undefined)
+        << " bias=" << (bias_opt.has_value() ? bias_opt.value().scalar_type() : at::ScalarType::Undefined)
+        << " running_mean=" << (running_mean_opt.has_value() ? running_mean_opt.value().scalar_type() : at::ScalarType::Undefined)
+        << " running_var=" << (running_var_opt.has_value() ? running_var_opt.value().scalar_type() : at::ScalarType::Undefined)
+        << " training=" << training
+        // << " momentum=" << momentum
+        // << " eps=" << eps
+        << " cudnn_enabled=" << cudnn_enabled
+        << std::endl;
+
   c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
   const Tensor& weight = *weight_maybe_owned;
   const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
@@ -600,7 +616,24 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
 
   Tensor reserve = at::empty({0}, input.options().dtype(kByte));
 
+  if (PYTORCH_MIOPEN_EXTRA_LOGGING)
+    std::cout
+        << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index (use_miopen)"
+        << " use_miopen=" << (backend == BatchNormBackend::Miopen)
+        << " cudnn_enabled=" << cudnn_enabled
+        << " dim=" << input.dim()
+        << " memory_format=" << input.suggest_memory_format()
+        << " input.dtype=" << input.scalar_type()
+        << " weight.dtype=" << (weight.defined()?"+":"-") << weight.scalar_type()
+        << " bias.dtype=" << (bias.defined()?"+":"-") << bias.scalar_type()
+        << " running_mean.dtype=" << (running_mean.defined()?"+":"-") << running_mean.scalar_type()
+        << " running_var.dtype=" << (running_mean.defined()?"+":"-") << running_mean.scalar_type()
+        << " training=" << training
+        << std::endl;
+
   if (backend == BatchNormBackend::Miopen) {
+    if (PYTORCH_MIOPEN_EXTRA_LOGGING)
+      std::cout << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index (calling miopen_batch_norm)" << std::endl;
     return std::tuple_cat(
         at::miopen_batch_norm(
             input.contiguous(), weight.contiguous(), bias.contiguous(),
@@ -623,6 +656,8 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
     const Tensor& input, const Tensor& grad_output, const std::optional<Tensor>& weight_opt /* optional */, const std::optional<Tensor>& running_mean_opt /* optional */, const std::optional<Tensor>& running_var_opt /* optional */, const std::optional<Tensor>& save_mean_opt /* optional */, const std::optional<Tensor>& save_var_transform_opt /* optional */,
     bool train, double epsilon, std::array<bool, 3> output_mask, const Tensor &reservedSpace) {
   // See [Note: hacky wrapper removal for optional tensor]
+  if (PYTORCH_MIOPEN_EXTRA_LOGGING)
+    std :: cout << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index_backward" << std::endl;
   c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
   const Tensor& weight = *weight_maybe_owned;
   const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
@@ -653,12 +688,16 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
 
   // backward in inference mode is not supported in cudnn, fallback to native
   if (impl_index == 0 || (!train)) {
+    if (PYTORCH_MIOPEN_EXTRA_LOGGING)
+      std :: cout << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index_backward (calling native_batch_norm_backward)" << std::endl;
     return at::native_batch_norm_backward(grad_output, input, weight, running_mean, running_var, save_mean, save_var_transform, train, epsilon, output_mask);
   } else if (impl_index == 1) {
     // TODO: _batch_norm_impl_index_backward is only used in JIT. cudnn NHWC
     // format conversion is done inside cudnn_batch_norm_backward instead
     return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon, reservedSpace);
   } else if (impl_index == 2) {
+    if (PYTORCH_MIOPEN_EXTRA_LOGGING)
+      std :: cout << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index_backward (calling miopen_batch_norm_backward)" << std::endl;
     return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon);
   }
   TORCH_INTERNAL_ASSERT(false, "Unsupported impl_index in _batch_norm_impl_index_backward: ", impl_index);
@@ -669,6 +708,20 @@ Tensor batch_norm(
     const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
     const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
     bool training, double momentum, double eps, bool cudnn_enabled) {
+  if (PYTORCH_MIOPEN_EXTRA_LOGGING)
+    std :: cout
+        << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* batch_norm"
+        << " input=" << input.scalar_type()
+        << " weight=" << (weight_opt.has_value() ? weight_opt.value().scalar_type() : at::ScalarType::Undefined)
+        << " bias=" << (bias_opt.has_value() ? bias_opt.value().scalar_type() : at::ScalarType::Undefined)
+        << " running_mean=" << (running_mean_opt.has_value() ? running_mean_opt.value().scalar_type() : at::ScalarType::Undefined)
+        << " running_var=" << (running_var_opt.has_value() ? running_var_opt.value().scalar_type() : at::ScalarType::Undefined)
+        << " training=" << training
+        // << " momentum=" << momentum
+        // << " eps=" << eps
+        << " cudnn_enabled=" << cudnn_enabled
+        << std::endl;
+
   const Tensor& weight = c10::value_or_else(weight_opt, [] {return Tensor();});
   const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
   const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
```

aten/src/ATen/native/miopen/BatchNorm_miopen.cpp (2 additions, 2 deletions)

```diff
@@ -79,7 +79,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
     checkAllDefined(c, {running_mean, running_var});
   }
   checkAllSameGPU(c, {input, weight, bias, running_mean, running_var});
-  if (input->scalar_type() != ScalarType::Half) {
+  if (input->scalar_type() != ScalarType::Half && input->scalar_type() != ScalarType::BFloat16) {
     checkAllSameType(c, {input, weight});
   }
   checkAllSameType(c, {weight, bias, running_mean, running_var});
@@ -186,7 +186,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
 
   checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});
   checkAllSameGPU(c, {input, grad_output, weight, save_mean, save_var});
-  if (input->scalar_type() == ScalarType::Half) {
+  if (input->scalar_type() == ScalarType::Half || input->scalar_type() == ScalarType::BFloat16) {
     checkScalarType(c, weight, ScalarType::Float);
   } else {
     checkAllSameType(c, {input, weight});
```
