Add backward RMSNorm+Add fusion #2028

Draft
wants to merge 5 commits into base: main
96 changes: 65 additions & 31 deletions tests/cpp/operator/test_normalization.cu
@@ -27,10 +27,19 @@ namespace {

template <typename InputType, typename OutputType>
void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
NormType norm_type, bool use_cudnn, const bool zero_centered_gamma_in_weight_dtype) {
const NormType norm_type, const bool use_cudnn,
const bool zero_centered_gamma_in_weight_dtype, const bool fused_bwd_add) {
if (sizeof(InputType) < sizeof(OutputType)) {
GTEST_SKIP() << "LN kernel does not support OutputType > InputType";
return;
}

if (norm_type == LayerNorm && fused_bwd_add) {
GTEST_SKIP() << "Fused LN backward+add not currently supported";
}

if (fused_bwd_add && zero_centered_gamma_in_weight_dtype) {
GTEST_SKIP() << "zero_centered_gamma_in_weight_dtype not currently supported "
<< "in fused norm backward+add";
}

if (getDeviceComputeCapability() < hopperComputeCapability && use_cudnn) {
@@ -45,7 +54,6 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
if ((itype == DType::kBFloat16 && otype == DType::kFloat16) ||
(itype == DType::kFloat16 && otype == DType::kBFloat16)) {
GTEST_SKIP() << "LN kernel does not support mixing Float16 and BFloat16";
return;
}

Tensor input("input", std::vector<size_t>{ N, H }, itype);
@@ -55,6 +63,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
Tensor mu("mu", std::vector<size_t>{ N }, DType::kFloat32);
Tensor rsigma("rsigma", std::vector<size_t>{ N }, DType::kFloat32);
Tensor dz("dz", std::vector<size_t>{ N, H }, wtype);
Tensor bwd_add("bwd_add", std::vector<size_t>{ N, H }, wtype);
Tensor dx("dx", std::vector<size_t>{ N, H }, itype);
Tensor dgamma("dgamma", std::vector<size_t>{ H }, wtype);
Tensor dbeta("dbeta", std::vector<size_t>{ H }, wtype);
@@ -65,6 +74,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
fillUniform(&beta);
setRandomScale(&z);
fillUniform(&dz);
if (fused_bwd_add) {
fillUniform(&bwd_add);
} else {
fillCase<WeightType>(&bwd_add, zeros);
}

std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(N * H);
std::unique_ptr<float[]> ref_mu = std::make_unique<float[]>(N);
@@ -85,7 +99,6 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
nvte_enable_cudnn_norm_fwd(true);
nvte_enable_cudnn_norm_bwd(true);


// Zero-centered gamma in weight dtype only supported by CuDNN backend currently
if (zero_centered_gamma_in_weight_dtype) {
nvte_enable_zero_centered_gamma_in_weight_dtype(true);
@@ -125,15 +138,23 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
z.data(), rsigma.data(), workspace_fwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);

nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(),
workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype());
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(),
workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
if (fused_bwd_add) {
nvte_rmsnorm_bwd_add(dz.data(), input.data(), bwd_add.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype());
nvte_rmsnorm_bwd_add(dz.data(), input.data(), bwd_add.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), workspace_bwd.data(),
prop.multiProcessorCount, zero_centered_gamma, 0);
} else {
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), workspace_bwd.data(), prop.multiProcessorCount,
zero_centered_gamma, 0);
workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype());
nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(),
dx.data(), dgamma.data(), workspace_bwd.data(), prop.multiProcessorCount,
zero_centered_gamma, 0);
}
}

if (use_cudnn){
@@ -167,6 +188,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
use_cudnn,
zero_centered_gamma_in_weight_dtype);
compute_ref_backward(norm_type, dz.rowwise_cpu_dptr<WeightType>(),
bwd_add.rowwise_cpu_dptr<WeightType>(),
input.rowwise_cpu_dptr<InputType>(),
mu.rowwise_cpu_dptr<float>(), rsigma.rowwise_cpu_dptr<float>(),
gamma.rowwise_cpu_dptr<WeightType>(),
@@ -214,30 +236,40 @@ std::vector<std::pair<size_t, size_t>> test_cases = {
} // namespace

class NormTestSuite : public ::testing::TestWithParam<std::tuple<bool,
NormType,
transformer_engine::DType,
transformer_engine::DType,
std::pair<size_t, size_t>,
bool,
bool>> {};
NormType,
transformer_engine::DType,
transformer_engine::DType,
std::pair<size_t, size_t>,
bool,
bool,
bool>> {};

TEST_P(NormTestSuite, TestNorm) {
using namespace transformer_engine;
using namespace test;
using namespace transformer_engine;
using namespace test;

const bool use_cudnn = std::get<0>(GetParam());
const NormType norm_type = std::get<1>(GetParam());
const DType input_type = std::get<2>(GetParam());
const DType output_type = std::get<3>(GetParam());
const auto size = std::get<4>(GetParam());
const bool zero_centered_gamma = std::get<5>(GetParam());
const bool cudnn_zero_centered_gamm_in_weight_dtype = std::get<6>(GetParam());

TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(size.first, size.second, zero_centered_gamma, norm_type, use_cudnn, cudnn_zero_centered_gamm_in_weight_dtype);
const DType input_type = std::get<2>(GetParam());
const DType output_type = std::get<3>(GetParam());
const auto size = std::get<4>(GetParam());
const bool zero_centered_gamma = std::get<5>(GetParam());
const bool cudnn_zero_centered_gamma_in_weight_dtype = std::get<6>(GetParam());
const bool fused_bwd_add = std::get<7>(GetParam());

TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
performTest<InputType, OutputType>(
size.first,
size.second,
zero_centered_gamma,
norm_type,
use_cudnn,
cudnn_zero_centered_gamma_in_weight_dtype,
fused_bwd_add
);
);
);
}

INSTANTIATE_TEST_SUITE_P(
@@ -250,6 +282,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3),
::testing::ValuesIn(test_cases),
::testing::Values(false, true),
::testing::Values(false, true),
::testing::Values(false, true)),
[](const testing::TestParamInfo<NormTestSuite::ParamType>& info) {
auto backend = std::get<0>(info.param) == false ? "Te" : "Cudnn";
@@ -261,6 +294,7 @@ INSTANTIATE_TEST_SUITE_P(
std::to_string(std::get<4>(info.param).first) + "X" +
std::to_string(std::get<4>(info.param).second) + "X" +
std::to_string(std::get<5>(info.param)) + "X" +
std::to_string(std::get<6>(info.param));
std::to_string(std::get<6>(info.param)) + "X" +
std::to_string(std::get<7>(info.param));
return name;
});
6 changes: 4 additions & 2 deletions tests/cpp/operator/test_normalization.h
@@ -126,7 +126,8 @@ void compute_ref_output(NormType norm_type,


template <typename InputType, typename OutputType>
void compute_ref_backward(const NormType norm_type, const OutputType *output_grad, const InputType *data,
void compute_ref_backward(const NormType norm_type, const OutputType *output_grad,
const OutputType *add, const InputType *data,
const float *mu, const float *rsigma,
const InputType *gamma,
InputType *data_grad,
@@ -165,7 +166,8 @@ void compute_ref_backward(const NormType norm_type, const OutputType *output_gra
compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype);
const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
const compute_t dy = g * dz;
const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy);
const compute_t a = static_cast<compute_t>(add[i * H + j]);
const compute_t dx = a + rsigma[i] * (dy - mdyy * y - mdy);
data_grad[i * H + j] = static_cast<InputType>(dx);
}
}
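In other words, relative to the existing reference backward, the fusion only adds the extra tensor elementwise on top of the usual input gradient. Per element of row i and column j, using the variable names from the loop above (mdyy, mdy and y are the per-row reduction terms and normalized values computed earlier in the reference, outside this hunk), the sketch is:

$$ dx_{ij} = a_{ij} + \mathrm{rsigma}_i \left( g_j\, dz_{ij} - \mathrm{mdyy}_i\, y_{ij} - \mathrm{mdy}_i \right) $$

Setting a = 0 recovers the original nvte_rmsnorm_bwd behaviour, which is exactly what the C++ test does in the non-fused path by filling bwd_add with zeros.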
11 changes: 10 additions & 1 deletion tests/cpp/test_common.cu
@@ -844,9 +844,18 @@ void fillCase(Tensor *t, const InputsFillCase fill_case) {
}
}

template void fillCase<byte>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<int16>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<int32>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<int64>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp32>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp16>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<bf16>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp8e4m3>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp8e5m2>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp32>(Tensor *t, const InputsFillCase fill_case);
#if FP4_TYPE_SUPPORTED
template void fillCase<fp4e2m1>(Tensor *t, const InputsFillCase fill_case);
#endif

void setRandomScale(Tensor *t) {
std::uniform_real_distribution<> dis(-2.0, 1.0);
89 changes: 89 additions & 0 deletions tests/pytorch/test_fusible_ops.py
@@ -21,6 +21,7 @@
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.pytorch.ops.fused import (
BackwardActivationBias,
BackwardAddRMSNorm,
BackwardLinearAdd,
ForwardLinearBiasActivation,
ForwardLinearBiasAdd,
@@ -2100,6 +2101,94 @@ def test_backward_activation_bias(
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(db_test, b_ref.grad, **tols)

@pytest.mark.parametrize("weight_shape", ((19,), (64,)))
@pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("zero_centered_gamma", (False, True))
def test_backward_add_rmsnorm(
self,
*,
weight_shape: Iterable[int],
in_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device = "cuda",
eps: float = 0.3,
zero_centered_gamma: bool,
) -> None:
"""Fused backward RMNorm + add"""

# Make input and weight shapes consistent
in_shape = list(in_shape)[:-1] + list(weight_shape)

# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
)
w_ref, w_test = make_reference_and_test_tensors(
weight_shape,
test_dtype=dtype,
test_device=device,
)
dy1_ref, dy1_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
dy2_ref, dy2_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)

# Plain PyTorch implementation
inner_dims = tuple(range(len(in_shape) - len(weight_shape), len(in_shape)))
var_ref = x_ref.square().sum(dim=inner_dims, keepdim=True) / math.prod(weight_shape)
if zero_centered_gamma:
y1_ref = x_ref / torch.sqrt(eps + var_ref) * (1 + w_ref)
else:
y1_ref = x_ref / torch.sqrt(eps + var_ref) * w_ref
y2_ref = x_ref
(y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward()

# Implementation with fusible operations
model = te_ops.Sequential(
te_ops.MakeExtraOutput(),
te_ops.RMSNorm(
weight_shape,
eps=eps,
device=device,
dtype=dtype,
zero_centered_gamma=zero_centered_gamma,
),
)
with torch.no_grad():
model[1].weight.copy_(w_test)
del w_test
y1_test, y2_test = model(x_test)
(y1_test * dy1_test + y2_test * dy2_test).sum().backward()

# Check that backward operations have been fused
backward_ops = model._module_groups[0]._backward_ops
assert len(backward_ops) == 1
assert isinstance(backward_ops[0][0], BackwardAddRMSNorm)

# Expected numerical error
tols = dtype_tols(dtype)

# Check results
y1_test = y1_test.to(dtype=torch.float64, device="cpu")
y2_test = y2_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y1_test, y1_ref, **tols)
torch.testing.assert_close(y2_test, y2_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
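Why this graph fuses into a single BackwardAddRMSNorm op: MakeExtraOutput forwards its input unchanged as a second output (y2 = x in the reference above), so in the backward pass the gradient arriving at y2 is added elementwise to the input gradient produced by the RMSNorm backward; that extra gradient corresponds to the add input of the nvte_rmsnorm_bwd_add kernel introduced in this PR. As an equation (a sketch, with L the scalar loss and w the RMSNorm weight):

$$ \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y_2} + \mathrm{RMSNormBackward}\!\left(\frac{\partial L}{\partial y_1},\, x,\, w\right), \qquad y_2 = x $$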

@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
def test_backward_linear_add(
@@ -24,7 +24,7 @@ extern "C" {
* y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}} \gamma + \beta
* @f]
*
* Calling this function with workspace set to empty tensor will not perform the operation,
* Calling this function with workspace set to an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] x Input tensor of shape [N, H].
@@ -55,8 +55,8 @@ void nvte_layernorm_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETe
* else
* with respect to \f$x\f$, \f$\gamma\f$ and \f$\beta\f$.
*
* Calling this function with workspace set to empty tensor will not perform the operation,
* but instead set the shape and type of these tensors to the required values.
* Calling this function with workspace set to an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] dz Incoming gradient tensor of shape [N, H].
* \param[in] x Forward input tensor of shape [N, H].
@@ -90,9 +90,8 @@ void nvte_layernorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETenso
* RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon}
* @f]
*
* Calling this function with workspace and barrier set to empty tensor will not
* perform the operation, but instead set the shape and type of the workspace
* and barrier tensors to the required values.
* Calling this function with workspace set to an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] x Input tensor of shape [N, H].
* \param[in] gamma Gamma tensor of shape [H].
@@ -121,9 +120,8 @@ void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float ep
* @f]
* with respect to \f$x\f$ and \f$gamma\f$.
*
* Calling this function with workspace, barrier, dgamma_part set
* to empty tensor will not perform the operation, but instead set the shape and type
* of these tensors to the required values.
* Calling this function with workspace set to an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] dz Incoming gradient tensor of shape [N, H].
* \param[in] x Forward input tensor of shape [N, H].
@@ -142,6 +140,29 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor
NVTETensor workspace, const int multiprocessorCount,
const bool zero_centered_gamma, cudaStream_t stream);

/*! \brief Compute backward of RMSNorm and add an additional tensor to the output gradient
*
* Calling this function with workspace set to an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] dz Incoming gradient tensor of shape [N, H].
* \param[in] x Forward input tensor of shape [N, H].
* \param[in] add Additional tensor of shape [N, H] to add to the output gradient.
* \param[in] rsigma Reciprocal of the root mean square of the input
* calculated over the last dimension. Shape: [N].
* \param[in] gamma Gamma tensor of shape [H].
* \param[out] dx Output gradient of shape [N, H].
* \param[out] dgamma Gradient for gamma tensor of shape [H].
* \param[out] workspace Workspace tensor.
* \param[in] multiprocessorCount Number of SMs in the device.
* \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_rmsnorm_bwd_add(const NVTETensor dz, const NVTETensor x, const NVTETensor add,
const NVTETensor rsigma, const NVTETensor gamma, NVTETensor dx,
NVTETensor dgamma, NVTETensor workspace, const int multiprocessorCount,
const bool zero_centered_gamma, cudaStream_t stream);

/*! \brief Helper to enable cuDNN backend for normalization
*
* \param[in] bool Enable if True
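As a usage note for the new entry point: like the other normalization functions documented above, nvte_rmsnorm_bwd_add follows the two-call workspace protocol, where a first call with a placeholder workspace only reports the required workspace shape and dtype, and the second call performs the computation. The sketch below mirrors the pattern used in performTest in the C++ test earlier in this PR and reuses the test suite's Tensor helper; the include paths, the initial placeholder workspace shape and dtype, and the pre-allocated, pre-filled dz, x, add, rsigma, gamma, dx and dgamma tensors are assumptions for illustration, not part of this PR.

```cpp
#include <vector>
#include <cuda_runtime.h>
#include <transformer_engine/normalization.h>  // assumed header path for nvte_rmsnorm_bwd_add
#include "test_common.h"                       // assumed location of the test suite's Tensor helper

using transformer_engine::DType;
using test::Tensor;

// Runs the fused RMSNorm backward + add: dx = add + RMSNorm-backward(dz), plus dgamma.
void run_fused_rmsnorm_bwd_add(Tensor &dz, Tensor &x, Tensor &add, Tensor &rsigma,
                               Tensor &gamma, Tensor &dx, Tensor &dgamma,
                               bool zero_centered_gamma, cudaStream_t stream) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);

  // Placeholder workspace (shape/dtype assumed): the first call performs no work and
  // only writes the required workspace shape and dtype into this tensor.
  Tensor workspace("workspace", std::vector<size_t>{1}, DType::kByte);
  nvte_rmsnorm_bwd_add(dz.data(), x.data(), add.data(), rsigma.data(), gamma.data(),
                       dx.data(), dgamma.data(), workspace.data(),
                       prop.multiProcessorCount, zero_centered_gamma, stream);

  // Re-create the workspace with the queried shape and dtype, as performTest does.
  workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());

  // Second call: the fused backward actually runs.
  nvte_rmsnorm_bwd_add(dz.data(), x.data(), add.data(), rsigma.data(), gamma.data(),
                       dx.data(), dgamma.data(), workspace.data(),
                       prop.multiProcessorCount, zero_centered_gamma, stream);
}
```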