[release/2.4] Enable bf16 with fp32 weights for MIOpen batchnorm (#1801)
This PR enables:
* using the MIOpen OCL_mix backend for bf16 batchnorm with fp32 weights
(via torch autocast). This was required and tested for a customer
workload using NCHW (which is the only memory layout enabled); see the sketch below.
* logging for MIOpen batchnorm via the `PYTORCH_MIOPEN_EXTRA_LOGGING` env
var.

TODO in a separate PR: add PyTorch unit tests for the bf16/fp16 input + fp32 weight case.
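
Below is a minimal, hypothetical usage sketch (not part of this commit), assuming a
ROCm build of PyTorch with MIOpen: autocast produces bf16 activations that feed a
`BatchNorm2d` whose parameters stay fp32, and the env var enables the extra logging.
Module sizes and shapes are illustrative only.

```python
# Sketch only: bf16 batchnorm input with fp32 weights via autocast on a ROCm/MIOpen build.
import os

# Set before `import torch`: the flag is read once when the library is loaded.
os.environ["PYTORCH_MIOPEN_EXTRA_LOGGING"] = "1"

import torch

conv = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1).cuda()
bn = torch.nn.BatchNorm2d(64).cuda()            # weight/bias/running stats stay fp32
x = torch.randn(8, 3, 224, 224, device="cuda")  # NCHW, the only memory layout enabled

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = bn(conv(x))  # conv output is bf16 under autocast; bn sees bf16 input, fp32 params

print(y.dtype, bn.weight.dtype)  # typically torch.bfloat16 torch.float32
```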

(cherry picked from commit abbfe77)
jithunnair-amd authored Jan 15, 2025
1 parent db6b3c5 commit f5fe136
Showing 2 changed files with 56 additions and 3 deletions.
55 changes: 54 additions & 1 deletion aten/src/ATen/native/Normalization.cpp
@@ -61,6 +61,7 @@
#include <c10/core/SymIntArrayRef.h>
#include <utility>
#include <vector>
#include <iostream>

static const int MIOPEN_DIM_MAX = 5;

@@ -514,8 +515,8 @@ BatchNormBackend _select_batch_norm_backend(
input.is_cuda()
&& input.dim() <= MIOPEN_DIM_MAX
&& input.scalar_type() != at::kDouble
&& input.scalar_type() != at::kBFloat16
&& (weight.scalar_type() != at::kHalf)
&& (weight.scalar_type() != at::kBFloat16)
&& weight.defined() && bias.defined()
&& ((running_mean.defined() && running_var.defined())
|| (!running_mean.defined() && !running_var.defined() && training))
@@ -531,6 +532,7 @@ BatchNormBackend _select_batch_norm_backend(
return BatchNormBackend::Native;
}

bool PYTORCH_MIOPEN_EXTRA_LOGGING = c10::utils::check_env("PYTORCH_MIOPEN_EXTRA_LOGGING").value_or(false);

// _batch_norm_impl_index(_backward) are used in the JIT be able to keep the run-time selection
// of backends, while enabling it to keep the information about the used backend, so that it can
@@ -541,6 +543,20 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
const Tensor& input, const std::optional<Tensor>& weight_opt /* optional */, const std::optional<Tensor>& bias_opt /* optional */, const std::optional<Tensor>& running_mean_opt /* optional */, const std::optional<Tensor>& running_var_opt /* optional */,
bool training, double momentum, double eps, bool cudnn_enabled) {
// See [Note: hacky wrapper removal for optional tensor]
if (PYTORCH_MIOPEN_EXTRA_LOGGING)
std::cout
<< "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index"
<< " input=" << input.scalar_type()
<< " weight=" << (weight_opt.has_value() ? weight_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " bias=" << (bias_opt.has_value() ? bias_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " running_mean=" << (running_mean_opt.has_value() ? running_mean_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " running_var=" << (running_var_opt.has_value() ? running_var_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " training=" << training
// << " momentum=" << momentum
// << " eps=" << eps
<< " cudnn_enabled=" << cudnn_enabled
<< std::endl;

c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
@@ -600,7 +616,24 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(

Tensor reserve = at::empty({0}, input.options().dtype(kByte));

if (PYTORCH_MIOPEN_EXTRA_LOGGING)
std::cout
<< "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index (use_miopen)"
<< " use_miopen=" << (backend == BatchNormBackend::Miopen)
<< " cudnn_enabled=" << cudnn_enabled
<< " dim=" << input.dim()
<< " memory_format=" << input.suggest_memory_format()
<< " input.dtype=" << input.scalar_type()
<< " weight.dtype=" << (weight.defined()?"+":"-") << weight.scalar_type()
<< " bias.dtype=" << (bias.defined()?"+":"-") << bias.scalar_type()
<< " running_mean.dtype=" << (running_mean.defined()?"+":"-") << running_mean.scalar_type()
<< " running_var.dtype=" << (running_mean.defined()?"+":"-") << running_mean.scalar_type()
<< " training=" << training
<< std::endl;

if (backend == BatchNormBackend::Miopen) {
if (PYTORCH_MIOPEN_EXTRA_LOGGING)
std::cout << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index (calling miopen_batch_norm)" << std::endl;
return std::tuple_cat(
at::miopen_batch_norm(
input.contiguous(), weight.contiguous(), bias.contiguous(),
@@ -623,6 +656,8 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(
const Tensor& input, const Tensor& grad_output, const std::optional<Tensor>& weight_opt /* optional */, const std::optional<Tensor>& running_mean_opt /* optional */, const std::optional<Tensor>& running_var_opt /* optional */, const std::optional<Tensor>& save_mean_opt /* optional */, const std::optional<Tensor>& save_var_transform_opt /* optional */,
bool train, double epsilon, std::array<bool, 3> output_mask, const Tensor &reservedSpace) {
// See [Note: hacky wrapper removal for optional tensor]
if (PYTORCH_MIOPEN_EXTRA_LOGGING)
std::cout << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index_backward" << std::endl;
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
@@ -653,12 +688,16 @@ std::tuple<Tensor, Tensor, Tensor> _batch_norm_impl_index_backward(

// backward in inference mode is not supported in cudnn, fallback to native
if (impl_index == 0 || (!train)) {
if (PYTORCH_MIOPEN_EXTRA_LOGGING)
std::cout << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index_backward (calling native_batch_norm_backward)" << std::endl;
return at::native_batch_norm_backward(grad_output, input, weight, running_mean, running_var, save_mean, save_var_transform, train, epsilon, output_mask);
} else if (impl_index == 1) {
// TODO: _batch_norm_impl_index_backward is only used in JIT. cudnn NHWC
// format conversion is done inside cudnn_batch_norm_backward instead
return at::cudnn_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon, reservedSpace);
} else if (impl_index == 2) {
if (PYTORCH_MIOPEN_EXTRA_LOGGING)
std::cout << "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* _batch_norm_impl_index_backward (calling miopen_batch_norm_backward)" << std::endl;
return at::miopen_batch_norm_backward(input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, epsilon);
}
TORCH_INTERNAL_ASSERT(false, "Unsupported impl_index in _batch_norm_impl_index_backward: ", impl_index);
@@ -669,6 +708,20 @@ Tensor batch_norm(
const Tensor& input, const std::optional<Tensor>& weight_opt, const std::optional<Tensor>& bias_opt,
const std::optional<Tensor>& running_mean_opt, const std::optional<Tensor>& running_var_opt,
bool training, double momentum, double eps, bool cudnn_enabled) {
if (PYTORCH_MIOPEN_EXTRA_LOGGING)
std::cout
<< "PYTORCH_MIOPEN_EXTRA_LOGGING: ********************* batch_norm"
<< " input=" << input.scalar_type()
<< " weight=" << (weight_opt.has_value() ? weight_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " bias=" << (bias_opt.has_value() ? bias_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " running_mean=" << (running_mean_opt.has_value() ? running_mean_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " running_var=" << (running_var_opt.has_value() ? running_var_opt.value().scalar_type() : at::ScalarType::Undefined)
<< " training=" << training
// << " momentum=" << momentum
// << " eps=" << eps
<< " cudnn_enabled=" << cudnn_enabled
<< std::endl;

const Tensor& weight = c10::value_or_else(weight_opt, [] {return Tensor();});
const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();});
const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();});
4 changes: 2 additions & 2 deletions aten/src/ATen/native/miopen/BatchNorm_miopen.cpp
@@ -79,7 +79,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
checkAllDefined(c, {running_mean, running_var});
}
checkAllSameGPU(c, {input, weight, bias, running_mean, running_var});
if (input->scalar_type() != ScalarType::Half) {
if (input->scalar_type() != ScalarType::Half && input->scalar_type() != ScalarType::BFloat16) {
checkAllSameType(c, {input, weight});
}
checkAllSameType(c, {weight, bias, running_mean, running_var});
@@ -186,7 +186,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(

checkAllDefined(c, {input, grad_output, weight, save_mean, save_var});
checkAllSameGPU(c, {input, grad_output, weight, save_mean, save_var});
if (input->scalar_type() == ScalarType::Half) {
if (input->scalar_type() == ScalarType::Half || input->scalar_type() == ScalarType::BFloat16) {
checkScalarType(c, weight, ScalarType::Float);
} else {
checkAllSameType(c, {input, weight});
