From d3a5429db61ff9e44ef7d2f2ef70a819297cbca1 Mon Sep 17 00:00:00 2001
From: Jithun Nair <37884920+jithunnair-amd@users.noreply.github.com>
Date: Mon, 2 Dec 2024 13:57:45 -0600
Subject: [PATCH] [release/2.5] Enable bf16 with fp32 weights for MIOpen batchnorm (#1672)

This PR enables:
* using the MIOpen OCL_mix backend for bf16 batchnorm with fp32 weights (via
  torch autocast). This was required and tested for a customer workload using
  NCHW, which is the only memory_format currently enabled.
* logging for MIOpen batchnorm via the `PYTORCH_MIOPEN_EXTRA_LOGGING` env var.

TODO in a separate PR: implement PyTorch unit tests for the bf16/fp16 inputs +
fp32 weights case.

(cherry picked from commit abbfe770c8db8f48ae0a5abdc788ef6705818de3)
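
A minimal, hypothetical usage sketch of the path this enables (for illustration
only; the model, shapes, and script are arbitrary and not taken from the
customer workload). It assumes a ROCm build of this branch; whether MIOpen is
actually selected still depends on the checks in `_select_batch_norm_backend`
(NCHW, `dim() <= MIOPEN_DIM_MAX`, defined weight/bias, etc.):

```python
import os
# The new flag is read once into a global when libtorch is loaded, so set it
# before importing torch.
os.environ["PYTORCH_MIOPEN_EXTRA_LOGGING"] = "1"

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),          # parameters and running stats stay fp32
).cuda()                        # the "cuda" device type maps to ROCm/HIP here

x = torch.randn(2, 3, 32, 32, device="cuda")   # NCHW input

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    # The conv emits a bf16 tensor under autocast, so batch_norm sees a bf16
    # input with fp32 weights -- the mixed case this patch routes to MIOpen.
    out = model(x)

out.float().sum().backward()    # also exercises miopen_batch_norm_backward
print(out.dtype)                # torch.bfloat16
```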
---
 aten/src/ATen/native/Normalization.cpp        | 55 ++++++++++++++++++-
 .../ATen/native/miopen/BatchNorm_miopen.cpp   |  4 +-
 2 files changed, 56 insertions(+), 3 deletions(-)

diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index e9e7c001837a4..e699b1867294d 100644
--- a/aten/src/ATen/native/Normalization.cpp
+++ b/aten/src/ATen/native/Normalization.cpp
@@ -61,6 +61,7 @@
 #include
 #include
 #include
+#include <c10/util/env.h>
 
 static const int MIOPEN_DIM_MAX = 5;
@@ -514,8 +515,8 @@ BatchNormBackend _select_batch_norm_backend(
       input.is_cuda()
       && input.dim() <= MIOPEN_DIM_MAX
       && input.scalar_type() != at::kDouble
-      && input.scalar_type() != at::kBFloat16
       && (weight.scalar_type() != at::kHalf)
+      && (weight.scalar_type() != at::kBFloat16)
       && weight.defined() && bias.defined()
       && ((running_mean.defined() && running_var.defined())
         || (!running_mean.defined() && !running_var.defined() && training))
@@ -531,6 +532,7 @@ BatchNormBackend _select_batch_norm_backend(
   return BatchNormBackend::Native;
 }
 
+bool PYTORCH_MIOPEN_EXTRA_LOGGING = c10::utils::check_env("PYTORCH_MIOPEN_EXTRA_LOGGING").value_or(false);
 
 // _batch_norm_impl_index(_backward) are used in the JIT be able to keep the run-time selection
 // of backends, while enabling it to keep the information about the used backend, so that it can
@@ -541,6 +543,20 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
     const Tensor& input, const std::optional<Tensor>& weight_opt /* optional */, const std::optional<Tensor>& bias_opt /* optional */, const std::optional<Tensor>& running_mean_opt /* optional */, const std::optional<Tensor>& running_var_opt /* optional */,
     bool training, double momentum, double eps, bool cudnn_enabled) {
   // See [Note: hacky wrapper removal for optional tensor]
+  if (PYTORCH_MIOPEN_EXTRA_LOGGING)
+    std::cout
+        << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index"
+        << " input=" << input.scalar_type()
+        << " weight=" << (weight_opt.has_value() ? weight_opt.value().scalar_type() : at::ScalarType::Undefined)
+        << " bias=" << (bias_opt.has_value() ? bias_opt.value().scalar_type() : at::ScalarType::Undefined)
+        << " running_mean=" << (running_mean_opt.has_value() ? running_mean_opt.value().scalar_type() : at::ScalarType::Undefined)
+        << " running_var=" << (running_var_opt.has_value() ? running_var_opt.value().scalar_type() : at::ScalarType::Undefined)
+        << " training=" << training
+        // << " momentum=" << momentum
+        // << " eps=" << eps
+        << " cudnn_enabled=" << cudnn_enabled
+        << std::endl;
+
   c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
   const Tensor& weight = *weight_maybe_owned;
   const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
@@ -600,7 +616,24 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
 
   Tensor reserve = at::empty({0}, input.options().dtype(kByte));
 
+  if (PYTORCH_MIOPEN_EXTRA_LOGGING)
+    std::cout
+        << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index (use_miopen)"
+        << " use_miopen=" << (backend == BatchNormBackend::Miopen)
+        << " cudnn_enabled=" << cudnn_enabled
+        << " dim=" << input.dim()
+        << " memory_format=" << input.suggest_memory_format()
+        << " input.dtype=" << input.scalar_type()
+        << " weight.dtype=" << (weight.defined()?"+":"-") << weight.scalar_type()
+        << " bias.dtype=" << (bias.defined()?"+":"-") << bias.scalar_type()
+        << " running_mean.dtype=" << (running_mean.defined()?"+":"-") << running_mean.scalar_type()
+        << " running_var.dtype=" << (running_var.defined()?"+":"-") << running_var.scalar_type()
+        << " training=" << training
+        << std::endl;
+
   if (backend == BatchNormBackend::Miopen) {
+    if (PYTORCH_MIOPEN_EXTRA_LOGGING)
+      std::cout << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index (calling miopen_batch_norm)" << std::endl;
     return std::tuple_cat(
         at::miopen_batch_norm(
             input.contiguous(), weight.contiguous(), bias.contiguous(),
@@ -623,6 +656,8 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
     const Tensor& input, const Tensor& grad_output, const std::optional<Tensor>& weight_opt /* optional */, const std::optional<Tensor>& running_mean_opt /* optional */, const std::optional<Tensor>& running_var_opt /* optional */, const std::optional<Tensor>& save_mean_opt /* optional */, const std::optional<Tensor>& save_var_transform_opt /* optional */,
     bool train, double epsilon, std::array<bool,3> output_mask, const Tensor &reservedSpace) {
   // See [Note: hacky wrapper removal for optional tensor]
+  if (PYTORCH_MIOPEN_EXTRA_LOGGING)
+    std::cout << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index_backward" << std::endl;
   c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
   const Tensor& weight = *weight_maybe_owned;
   const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
@@ -653,12 +688,16 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
 
   // backward in inference mode is not supported in cudnn, fallback to native
   if (impl_index == 0 || (!train)) {
+    if (PYTORCH_MIOPEN_EXTRA_LOGGING)
+      std::cout << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index_backward (calling native_batch_norm_backward)" << std::endl;
     return at::native_batch_norm_backward(grad_output, input, weight, running_mean, running_var, save_mean, save_var_transform, train, epsilon, output_mask);
   } else if (impl_index == 1) {
     // TODO: _batch_norm_impl_index_backward is only used in JIT. cudnn NHWC
     // format conversion is done inside cudnn_batch_norm_backward instead
     return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon, reservedSpace);
   } else if (impl_index == 2) {
+    if (PYTORCH_MIOPEN_EXTRA_LOGGING)
+      std::cout << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index_backward (calling miopen_batch_norm_backward)" << std::endl;
     return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon);
   }
   TORCH_INTERNAL_ASSERT(false, "Unsupported impl_index in _batch_norm_impl_index_backward: ", impl_index);
@@ -669,6 +708,20 @@ Tensor batch_norm(
     const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
     const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
     bool training, double momentum, double eps, bool cudnn_enabled) {
+  if (PYTORCH_MIOPEN_EXTRA_LOGGING)
+    std::cout
+        << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* batch_norm"
+        << " input=" << input.scalar_type()
+        << " weight=" << (weight_opt.has_value() ? weight_opt.value().scalar_type() : at::ScalarType::Undefined)
+        << " bias=" << (bias_opt.has_value() ? bias_opt.value().scalar_type() : at::ScalarType::Undefined)
+        << " running_mean=" << (running_mean_opt.has_value() ? running_mean_opt.value().scalar_type() : at::ScalarType::Undefined)
+        << " running_var=" << (running_var_opt.has_value() ? running_var_opt.value().scalar_type() : at::ScalarType::Undefined)
+        << " training=" << training
+        // << " momentum=" << momentum
+        // << " eps=" << eps
+        << " cudnn_enabled=" << cudnn_enabled
+        << std::endl;
+
   const Tensor& weight = c10::value_or_else(weight_opt, [] {return Tensor();});
   const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
   const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
diff --git a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp
index 9f9ef77f90a34..9ef81efe236af 100644
--- a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp
+++ b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp
@@ -79,7 +79,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
     checkAllDefined(c, {running_mean, running_var});
   }
   checkAllSameGPU(c, {input, weight, bias, running_mean, running_var});
-  if (input->scalar_type() != ScalarType::Half) {
+  if (input->scalar_type() != ScalarType::Half && input->scalar_type() != ScalarType::BFloat16) {
     checkAllSameType(c, {input, weight});
   }
   checkAllSameType(c, {weight, bias, running_mean, running_var});
@@ -186,7 +186,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
 
   checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});
   checkAllSameGPU(c, {input, grad_output, weight, save_mean, save_var});
-  if (input->scalar_type() == ScalarType::Half) {
+  if (input->scalar_type() == ScalarType::Half || input->scalar_type() == ScalarType::BFloat16) {
     checkScalarType(c, weight, ScalarType::Float);
   } else {
     checkAllSameType(c, {input, weight});