From c6efd90180a101b2fb6f60182b11c7f42f0da9b1 Mon Sep 17 00:00:00 2001 From: zhongboz Date: Mon, 14 Jul 2025 15:44:44 -0700 Subject: [PATCH 01/24] remove reciprocal op Signed-off-by: zhongboz --- transformer_engine/pytorch/csrc/quantizer.cpp | 5 +++-- transformer_engine/pytorch/tensor/float8_tensor.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index dc4d55d2fc..8829a38398 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -220,8 +220,9 @@ std::pair Float8CurrentScalingQuantizer::create_tenso } const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none(); - // In current scaling, scale is not known but we initialize it with 1 to avoid division by zero. If scale is already calculated, it can be correctly set. - at::Tensor scale_inv = at::reciprocal(scale); + // In current scaling, scale_inv is not known but we initialize it as an empty tensor to be filled later. + at::Tensor scale_inv = + at::empty({1}, at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); py::object ret; if (internal) { diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 1c3e575473..fc9c16af5b 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -215,7 +215,7 @@ def __init__( amax_epsilon: float = 0.0, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) - self.scale = torch.ones(1, dtype=torch.float32, device=device) + self.scale = torch.empty(1, dtype=torch.float32, device=device) self.amax = torch.empty(1, dtype=torch.float32, device=device) self.dtype = fp8_dtype self.with_amax_reduction = with_amax_reduction From 2fdef539e94f6c38481e2e292f72aa91c5b55c18 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 15 Jul 2025 08:28:52 +0000 Subject: [PATCH 02/24] Refactor Quantizer::create_tensor function Signed-off-by: Tim Moon --- transformer_engine/pytorch/csrc/common.h | 24 +- .../pytorch/csrc/extensions/attention.cpp | 19 +- transformer_engine/pytorch/csrc/quantizer.cpp | 394 ++++++++++-------- 3 files changed, 253 insertions(+), 184 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index d8c08651f2..4c39f504b1 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -98,8 +98,7 @@ class Quantizer { virtual void set_quantization_params(TensorWrapper* tensor) const = 0; virtual std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const = 0; + const std::vector& shape, DType dtype) const = 0; virtual ~Quantizer() = default; @@ -121,8 +120,10 @@ class NoneQuantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override {} std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const override; + const std::vector& shape, DType dtype) const override; + + std::pair create_tensor( + const std::vector& shape, DType dtype, at::Tensor data) const; }; class Float8Quantizer : public Quantizer { @@ -138,9 +139,13 @@ class Float8Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; + std::pair create_tensor( + const std::vector& shape, DType dtype) const override; + 
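+  // A hedged usage sketch (shape and dtype values assumed): the overload
+  // above allocates FP8 data, transpose, and scale-inv buffers itself,
+  // while the overload below wraps buffers supplied by the caller:
+  //
+  //   auto [te_out, py_out] = quantizer.create_tensor({128, 256}, DType::kBFloat16);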
std::pair create_tensor( const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const override; + std::optional data, std::optional transpose, + std::optional scale_inv) const; }; class Float8CurrentScalingQuantizer : public Quantizer { @@ -161,8 +166,7 @@ class Float8CurrentScalingQuantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const override; + const std::vector& shape, DType dtype) const override; }; class Float8BlockQuantizer : public Quantizer { @@ -195,8 +199,7 @@ class Float8BlockQuantizer : public Quantizer { // for the tensor. Should set quantized data, scales for rowwise // and optionally columnwise usage. std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const override; + const std::vector& shape, DType dtype) const override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; @@ -212,8 +215,7 @@ class MXFP8Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const override; + const std::vector& shape, DType dtype) const override; }; std::unique_ptr convert_quantizer(py::handle quantizer); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 71a8062b1a..7a68cfc0e1 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -374,9 +374,22 @@ std::vector fused_attn_bwd( default: NVTE_ERROR("QKV layout not supported!"); } - std::tie(te_dQ, py_dQ) = dQKV_quantizer->create_tensor(q_shape, fake_dtype_te, dQ); - std::tie(te_dK, py_dK) = dQKV_quantizer->create_tensor(k_shape, fake_dtype_te, dK); - std::tie(te_dV, py_dV) = dQKV_quantizer->create_tensor(v_shape, fake_dtype_te, dV); + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + auto* fp8_quantizer = dynamic_cast(dQKV_quantizer.get()); + NVTE_CHECK(fp8_quantizer != nullptr, "Expected Float8Quantizer when dtype is FP8"); + std::tie(te_dQ, py_dQ) = fp8_quantizer->create_tensor(q_shape, fake_dtype_te, dQ, + std::nullopt, std::nullopt); + std::tie(te_dK, py_dK) = fp8_quantizer->create_tensor(k_shape, fake_dtype_te, dK, + std::nullopt, std::nullopt); + std::tie(te_dV, py_dV) = fp8_quantizer->create_tensor(v_shape, fake_dtype_te, dV, + std::nullopt, std::nullopt); + } else { + auto* none_quantizer = dynamic_cast(dQKV_quantizer.get()); + NVTE_CHECK(none_quantizer != nullptr, "Expected NoneQuantizer when dtype is not FP8"); + std::tie(te_dQ, py_dQ) = none_quantizer->create_tensor(q_shape, fake_dtype_te, dQ); + std::tie(te_dK, py_dK) = none_quantizer->create_tensor(k_shape, fake_dtype_te, dK); + std::tie(te_dV, py_dV) = none_quantizer->create_tensor(v_shape, fake_dtype_te, dV); + } // construct NVTE tensors if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 8829a38398..09e3e91b55 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -38,23 +38,18 @@ Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quanti } std::pair 
NoneQuantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { - at::TensorOptions opts; - opts = opts.dtype(GetATenDType(dtype)).device(torch::kCUDA); - std::vector torch_shape; - for (auto s : shape) { - torch_shape.emplace_back(static_cast(s)); - } - at::Tensor ret; - if (rowwise_data.has_value()) { - ret = std::move(*rowwise_data); - } else { - ret = at::empty(torch_shape, opts); - } + const std::vector& shape, DType dtype) const { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA); + return create_tensor(shape, dtype, at::empty(shape_int64, opts)); +} - TensorWrapper tensor; - tensor.set_rowwise_data(ret.data_ptr(), dtype, shape); - return {std::move(tensor), py::cast(ret)}; +std::pair NoneQuantizer::create_tensor( + const std::vector& shape, DType dtype, at::Tensor data) const { + py::object out_py = py::cast(data); + TensorWrapper out_cpp; + out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape); + return {std::move(out_cpp), std::move(out_py)}; } void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { @@ -76,68 +71,110 @@ void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { } std::pair Float8Quantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { - using namespace pybind11::literals; - std::vector rowwise_torch_shape; - std::vector columnwise_torch_shape; - - if (!shape.empty()) { - columnwise_torch_shape.emplace_back(static_cast(shape.back())); + const std::vector& shape, DType dtype) const { + // Allocate data tensor if needed + std::optional data; + if (rowwise_usage) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + data.emplace(at::empty(shape_int64, opts)); } - for (size_t i = 0; i < shape.size(); ++i) { - if (i < shape.size() - 1) { - columnwise_torch_shape.emplace_back(static_cast(shape[i])); + + // Allocate transpose tensor if needed + std::optional transpose; + const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + if (with_transpose) { + std::vector transpose_shape_int64; + if (shape.size() > 0) { + transpose_shape_int64.push_back(shape.back()); + for (size_t i = 0; i < shape.size() - 1; ++i) { + transpose_shape_int64.push_back(shape[i]); + } } - rowwise_torch_shape.emplace_back(static_cast(shape[i])); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + transpose.emplace(at::empty(transpose_shape_int64, opts)); } - at::TensorOptions opts; - opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); - at::Tensor data; + + // Allocate scale-inverse tensor + std::optional scale_inv; + { + const std::vector scale_inv_shape = {1}; + const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + scale_inv.emplace(at::empty(scale_inv_shape, opts)); + }; + + // Construct FP8 tensor + return create_tensor(shape, dtype, std::move(data), + std::move(transpose), std::move(scale_inv)); +} + +std::pair Float8Quantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional data, + std::optional transpose, + std::optional scale_inv) const { + using namespace pybind11::literals; + + // Initialize data tensor + at::Tensor data_tensor; if (rowwise_usage) { - if (rowwise_data.has_value()) { - data = std::move(*rowwise_data); - } else { - data = 
at::empty(rowwise_torch_shape, opts); - } + NVTE_CHECK(data, + "Constructing Float8Tensor with row-wise usage, but no FP8 data was provided"); + data_tensor = std::move(*data); } - const py::object py_data = rowwise_usage ? py::cast(data) : py::none(); - at::Tensor columnwise_data; - bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); - if (create_transpose) { - columnwise_data = at::empty(columnwise_torch_shape, opts); + py::object data_py = rowwise_usage ? py::cast(data_tensor) : py::none(); + + // Initialize transpose tensor + at::Tensor transpose_tensor; + const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + if (with_transpose) { + NVTE_CHECK(transpose, + "Constructing Float8Tensor with column-wise usage, but no FP8 transpose was provided"); + transpose_tensor = std::move(*transpose); } - const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none(); - opts = opts.dtype(torch::kFloat32); - // TODO: Replace with an empty tensor. - at::Tensor scale_inv = at::reciprocal(scale); - py::object ret; + py::object transpose_py = with_transpose ? py::cast(transpose_tensor) : py::none(); + + // Initialize scale-inverse tensor + at::Tensor scale_inv_tensor = scale_inv ? std::move(*scale_inv) : at::reciprocal(scale); + py::object scale_inv_py = py::cast(scale_inv_tensor); + + // Construct Python FP8 tensor + py::object out_py; if (internal) { py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); - ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, - "quantizer"_a = this->quantizer); + out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_py, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, + "quantizer"_a = this->quantizer); } else { py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); - ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype), - "data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, - "quantizer"_a = this->quantizer); + const std::vector shape_int64(shape.begin(), shape.end()); + out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), + "data"_a = data_py, "fp8_scale_inv"_a = scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, + "quantizer"_a = this->quantizer); } - TensorWrapper tensor(this->get_scaling_mode()); + + // Construct C++ FP8 tensor + TensorWrapper out_cpp(this->get_scaling_mode()); if (rowwise_usage) { - tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); - tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + out_cpp.set_rowwise_data(data_tensor.data_ptr(), this->dtype, shape); + out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); } - if (create_transpose) { - std::vector transposed_shape; - for (auto s : columnwise_torch_shape) { - transposed_shape.emplace_back(static_cast(s)); + if (with_transpose) { + std::vector transpose_shape; + if (shape.size() > 0) { + transpose_shape.push_back(shape.back()); + for (size_t i = 0; i < shape.size() - 1; ++i) { + transpose_shape.push_back(shape[i]); + } } - tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape); - tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, 
std::vector{1}); + out_cpp.set_columnwise_data(transpose_tensor.data_ptr(), this->dtype, transpose_shape); + out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); } - this->set_quantization_params(&tensor); - return {std::move(tensor), std::move(ret)}; + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(out_py)}; } Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& quantizer) @@ -187,72 +224,81 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso } std::pair Float8CurrentScalingQuantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { + const std::vector& shape, DType dtype) const { using namespace pybind11::literals; - std::vector rowwise_torch_shape; - std::vector columnwise_torch_shape; - std::vector scale_inv_torch_shape = {1}; // Shape of 1 element for scale_inv - if (!shape.empty()) { - columnwise_torch_shape.emplace_back(static_cast(shape.back())); - } - for (size_t i = 0; i < shape.size(); ++i) { - if (i < shape.size() - 1) { - columnwise_torch_shape.emplace_back(static_cast(shape[i])); - } - rowwise_torch_shape.emplace_back(static_cast(shape[i])); - } - at::TensorOptions opts; - opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); - at::Tensor data; + // Initialize data tensor + at::Tensor data_tensor; if (rowwise_usage) { - if (rowwise_data.has_value()) { - data = std::move(*rowwise_data); - } else { - data = at::empty(rowwise_torch_shape, opts); - } + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + data_tensor = at::empty(shape_int64, opts); } - const py::object py_data = rowwise_usage ? py::cast(data) : py::none(); - at::Tensor columnwise_data; - bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); - if (create_transpose) { - columnwise_data = at::empty(columnwise_torch_shape, opts); + + // Initialize transpose tensor + at::Tensor transpose_tensor; + const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + if (with_transpose) { + std::vector transpose_shape_int64; + if (shape.size() > 0) { + transpose_shape_int64.push_back(shape.back()); + for (size_t i = 0; i < shape.size() - 1; ++i) { + transpose_shape_int64.push_back(shape[i]); + } + } + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + transpose_tensor = at::empty(transpose_shape_int64, opts); } - const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none(); - // In current scaling, scale_inv is not known but we initialize it as an empty tensor to be filled later. - at::Tensor scale_inv = - at::empty({1}, at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + // Initialize scale-inverse tensor + at::Tensor scale_inv_tensor; + { + const std::vector scale_inv_shape = {1}; + const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + scale_inv_tensor = at::empty(scale_inv_shape, opts); + } - py::object ret; + // Construct Python FP8 tensor + py::object out_py; + py::object data_py = rowwise_usage ? py::cast(data_tensor) : py::none(); + py::object transpose_py = with_transpose ? 
py::cast(transpose_tensor) : py::none(); + py::object scale_inv_py = py::cast(scale_inv_tensor); if (internal) { py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); - ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, - "quantizer"_a = this->quantizer); + out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_py, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, + "quantizer"_a = this->quantizer); } else { py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); - ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype), - "data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, - "quantizer"_a = this->quantizer); + const std::vector shape_int64(shape.begin(), shape.end()); + out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), + "data"_a = data_py, "fp8_scale_inv"_a = scale_inv_py, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, + "quantizer"_a = this->quantizer); } - TensorWrapper tensor(this->get_scaling_mode()); + + // Construct C++ FP8 tensor + TensorWrapper out_cpp(this->get_scaling_mode()); if (rowwise_usage) { - tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); - tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + out_cpp.set_rowwise_data(data_tensor.data_ptr(), this->dtype, shape); + out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); } - if (create_transpose) { - std::vector transposed_shape; - for (auto s : columnwise_torch_shape) { - transposed_shape.emplace_back(static_cast(s)); + if (with_transpose) { + std::vector transpose_shape; + if (shape.size() > 0) { + transpose_shape.push_back(shape.back()); + for (size_t i = 0; i < shape.size() - 1; ++i) { + transpose_shape.push_back(shape[i]); + } } - tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape); - tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + out_cpp.set_columnwise_data(transpose_tensor.data_ptr(), this->dtype, transpose_shape); + out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); } - this->set_quantization_params(&tensor); + this->set_quantization_params(&out_cpp); - return {std::move(tensor), std::move(ret)}; + return {std::move(out_cpp), std::move(out_py)}; } Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) { @@ -281,7 +327,7 @@ void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const } std::pair Float8BlockQuantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { + const std::vector& shape, DType dtype) const { using namespace pybind11::literals; std::vector torch_shape; for (auto s : shape) { @@ -300,11 +346,7 @@ std::pair Float8BlockQuantizer::create_tensor( : Float8BlockScaleTensorFormat::GEMM_READY); if (rowwise_usage) { - if (rowwise_data.has_value()) { - data_rowwise = std::move(*rowwise_data); - } else { - data_rowwise = at::empty(torch_shape, opts); - } + data_rowwise = at::empty(torch_shape, opts); auto scale_shape = get_scale_shape(shape, false); size_t sinv0 = scale_shape[0]; size_t sinv1 = scale_shape[1]; @@ -467,72 +509,84 @@ void 
MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const { } std::pair MXFP8Quantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { + const std::vector& shape, DType dtype) const { using namespace pybind11::literals; - std::vector torch_shape; - size_t numel = 1; - for (auto s : shape) { - torch_shape.emplace_back(static_cast(s)); - numel *= s; - } - TensorWrapper tensor(NVTE_MXFP8_1D_SCALING); - at::TensorOptions opts; - at::Tensor rowwise_data1, columnwise_data, rowwise_scale_inv, - columnwise_scale_inv; // TODO(pgadzinski) - change - opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); - auto last_dim = static_cast(torch_shape.back()); - - NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0, + // Tensor dimensions + const std::vector shape_int64(shape.begin(), shape.end()); + size_t flat_first_dim = 1; + if (shape.size() > 0) { + for (size_t i = 0; i < shape.size() - 1; ++i) { + flat_first_dim *= shape[i]; + } + } + const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; + NVTE_CHECK(flat_first_dim % MXFP8_BLOCK_SIZE == 0 && flat_last_dim % MXFP8_BLOCK_SIZE == 0, "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE, - " (got shape=", torch_shape, ")"); - - at::Tensor data; + " (got shape=", shape, ")"); + const std::vector rowwise_scale_inv_shape_int64 = { + roundup(flat_first_dim, 128), + roundup(flat_last_dim / MXFP8_BLOCK_SIZE, 4)}; + const std::vector columnwise_scale_inv_shape_int64 = { + roundup(flat_first_dim / MXFP8_BLOCK_SIZE, 4), + roundup(flat_last_dim, 128)}; + + // Allocate tensors + at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor; + at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor; + const auto uint8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); if (rowwise_usage) { - if (rowwise_data.has_value()) { - data = std::move(*rowwise_data); - } else { - data = at::empty(torch_shape, opts); - } - auto sinv0 = roundup(numel / last_dim, 128); - auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4); - rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts); - tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); - tensor.set_rowwise_scale_inv( - rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0, - std::vector{static_cast(sinv0), static_cast(sinv1)}); + rowwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts); + rowwise_scale_inv_tensor = at::empty(rowwise_scale_inv_shape_int64, uint8_tensor_opts); } - if (columnwise_usage) { - auto sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4); - auto sinv1 = roundup(last_dim, 128); - columnwise_data = at::empty(torch_shape, opts); - columnwise_scale_inv = at::zeros({sinv0, sinv1}, opts); - - tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape); - tensor.set_columnwise_scale_inv( - columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0, - std::vector{static_cast(sinv0), static_cast(sinv1)}); + columnwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts); + columnwise_scale_inv_tensor = at::empty(columnwise_scale_inv_shape_int64, uint8_tensor_opts); } - this->set_quantization_params(&tensor); - py::object ret; + // Construct Python MXFP8 tensor + py::object out_py; + auto py_cast = [] (at::Tensor &tensor, bool need_cast) -> py::object { + return need_cast ? 
py::cast(tensor) : py::none(); + }; + py::object rowwise_data_py = py_cast(rowwise_data_tensor, rowwise_usage); + py::object rowwise_scale_inv_py = py_cast(rowwise_scale_inv_tensor, rowwise_usage); + py::object columnwise_data_py = py_cast(columnwise_data_tensor, columnwise_usage); + py::object columnwise_scale_inv_py = py_cast(columnwise_scale_inv_tensor, columnwise_usage); if (internal) { py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorBasePythonClass)); - ret = MXFP8TensorClass("rowwise_data"_a = data, "columnwise_data"_a = columnwise_data, - "rowwise_scale_inv"_a = rowwise_scale_inv, - "columnwise_scale_inv"_a = columnwise_scale_inv, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, + "rowwise_scale_inv"_a = rowwise_scale_inv_py, + "columnwise_scale_inv"_a = columnwise_scale_inv_py, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); } else { py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorPythonClass)); - ret = MXFP8TensorClass("shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), - "rowwise_data"_a = data, "columnwise_data"_a = columnwise_data, - "rowwise_scale_inv"_a = rowwise_scale_inv, - "columnwise_scale_inv"_a = columnwise_scale_inv, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + out_py = MXFP8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), + "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, + "rowwise_scale_inv"_a = rowwise_scale_inv_py, + "columnwise_scale_inv"_a = columnwise_scale_inv_py, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); } - return {std::move(tensor), std::move(ret)}; + // Construct C++ MXFP8 tensor + TensorWrapper out_cpp(NVTE_MXFP8_1D_SCALING); + if (rowwise_usage) { + const std::vector scale_inv_shape(rowwise_scale_inv_shape_int64.begin(), + rowwise_scale_inv_shape_int64.end()); + out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), this->dtype, shape); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0, + scale_inv_shape); + } + if (columnwise_usage) { + const std::vector scale_inv_shape(columnwise_scale_inv_shape_int64.begin(), + columnwise_scale_inv_shape_int64.end()); + out_cpp.set_columnwise_data(columnwise_data_tensor.data_ptr(), this->dtype, shape); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0, + scale_inv_shape); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(out_py)}; } } // namespace transformer_engine::pytorch From 1338edf940216f6860bc10395906e8a18440f5bf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 16 Jul 2025 01:54:58 +0000 Subject: [PATCH 03/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/common.h | 36 +++++++++---------- .../pytorch/csrc/extensions/attention.cpp | 16 ++++----- transformer_engine/pytorch/csrc/quantizer.cpp | 35 +++++++++--------- 3 files changed, 44 insertions(+), 43 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 277dc0c24c..4b73ec0347 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -98,8 +98,8 @@ class Quantizer { virtual void set_quantization_params(TensorWrapper* tensor) 
const = 0; - virtual std::pair create_tensor( - const std::vector& shape, DType dtype) const = 0; + virtual std::pair create_tensor(const std::vector& shape, + DType dtype) const = 0; virtual ~Quantizer() = default; @@ -120,11 +120,11 @@ class NoneQuantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override {} - std::pair create_tensor( - const std::vector& shape, DType dtype) const override; + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; - std::pair create_tensor( - const std::vector& shape, DType dtype, at::Tensor data) const; + std::pair create_tensor(const std::vector& shape, DType dtype, + at::Tensor data) const; }; class Float8Quantizer : public Quantizer { @@ -140,13 +140,13 @@ class Float8Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor( - const std::vector& shape, DType dtype) const override; + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; - std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional data, std::optional transpose, - std::optional scale_inv) const; + std::pair create_tensor(const std::vector& shape, DType dtype, + std::optional data, + std::optional transpose, + std::optional scale_inv) const; }; class Float8CurrentScalingQuantizer : public Quantizer { @@ -166,8 +166,8 @@ class Float8CurrentScalingQuantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor( - const std::vector& shape, DType dtype) const override; + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; }; class Float8BlockQuantizer : public Quantizer { @@ -199,8 +199,8 @@ class Float8BlockQuantizer : public Quantizer { // Create a python Float8BlockQuantized tensor and C++ wrapper // for the tensor. Should set quantized data, scales for rowwise // and optionally columnwise usage. 
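  // A minimal, hypothetical call site (shape and dtype values assumed):
  //   auto [te_out, py_out] =
  //       block_quantizer.create_tensor({1024, 1024}, DType::kBFloat16);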
- std::pair create_tensor( - const std::vector& shape, DType dtype) const override; + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; @@ -215,8 +215,8 @@ class MXFP8Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor( - const std::vector& shape, DType dtype) const override; + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 7a68cfc0e1..74c5c5e7a7 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -375,16 +375,16 @@ std::vector fused_attn_bwd( NVTE_ERROR("QKV layout not supported!"); } if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - auto* fp8_quantizer = dynamic_cast(dQKV_quantizer.get()); + auto *fp8_quantizer = dynamic_cast(dQKV_quantizer.get()); NVTE_CHECK(fp8_quantizer != nullptr, "Expected Float8Quantizer when dtype is FP8"); - std::tie(te_dQ, py_dQ) = fp8_quantizer->create_tensor(q_shape, fake_dtype_te, dQ, - std::nullopt, std::nullopt); - std::tie(te_dK, py_dK) = fp8_quantizer->create_tensor(k_shape, fake_dtype_te, dK, - std::nullopt, std::nullopt); - std::tie(te_dV, py_dV) = fp8_quantizer->create_tensor(v_shape, fake_dtype_te, dV, - std::nullopt, std::nullopt); + std::tie(te_dQ, py_dQ) = + fp8_quantizer->create_tensor(q_shape, fake_dtype_te, dQ, std::nullopt, std::nullopt); + std::tie(te_dK, py_dK) = + fp8_quantizer->create_tensor(k_shape, fake_dtype_te, dK, std::nullopt, std::nullopt); + std::tie(te_dV, py_dV) = + fp8_quantizer->create_tensor(v_shape, fake_dtype_te, dV, std::nullopt, std::nullopt); } else { - auto* none_quantizer = dynamic_cast(dQKV_quantizer.get()); + auto *none_quantizer = dynamic_cast(dQKV_quantizer.get()); NVTE_CHECK(none_quantizer != nullptr, "Expected NoneQuantizer when dtype is not FP8"); std::tie(te_dQ, py_dQ) = none_quantizer->create_tensor(q_shape, fake_dtype_te, dQ); std::tie(te_dK, py_dK) = none_quantizer->create_tensor(k_shape, fake_dtype_te, dK); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index be513d9561..a25c099047 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -37,15 +37,16 @@ Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quanti this->dtype = type; } -std::pair NoneQuantizer::create_tensor( - const std::vector& shape, DType dtype) const { +std::pair NoneQuantizer::create_tensor(const std::vector& shape, + DType dtype) const { const std::vector shape_int64(shape.begin(), shape.end()); const auto opts = at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA); return create_tensor(shape, dtype, at::empty(shape_int64, opts)); } -std::pair NoneQuantizer::create_tensor( - const std::vector& shape, DType dtype, at::Tensor data) const { +std::pair NoneQuantizer::create_tensor(const std::vector& shape, + DType dtype, + at::Tensor data) const { py::object out_py = py::cast(data); TensorWrapper out_cpp; out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape); @@ -104,21 +105,18 @@ std::pair Float8Quantizer::create_tensor( }; // Construct FP8 tensor - 
return create_tensor(shape, dtype, std::move(data), - std::move(transpose), std::move(scale_inv)); + return create_tensor(shape, dtype, std::move(data), std::move(transpose), std::move(scale_inv)); } std::pair Float8Quantizer::create_tensor( const std::vector& shape, DType dtype, std::optional data, - std::optional transpose, - std::optional scale_inv) const { + std::optional transpose, std::optional scale_inv) const { using namespace pybind11::literals; // Initialize data tensor at::Tensor data_tensor; if (rowwise_usage) { - NVTE_CHECK(data, - "Constructing Float8Tensor with row-wise usage, but no FP8 data was provided"); + NVTE_CHECK(data, "Constructing Float8Tensor with row-wise usage, but no FP8 data was provided"); data_tensor = std::move(*data); } py::object data_py = rowwise_usage ? py::cast(data_tensor) : py::none(); @@ -127,8 +125,9 @@ std::pair Float8Quantizer::create_tensor( at::Tensor transpose_tensor; const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); if (with_transpose) { - NVTE_CHECK(transpose, - "Constructing Float8Tensor with column-wise usage, but no FP8 transpose was provided"); + NVTE_CHECK( + transpose, + "Constructing Float8Tensor with column-wise usage, but no FP8 transpose was provided"); transpose_tensor = std::move(*transpose); } py::object transpose_py = with_transpose ? py::cast(transpose_tensor) : py::none(); @@ -508,8 +507,8 @@ void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const { columnwise_data.shape); } -std::pair MXFP8Quantizer::create_tensor( - const std::vector& shape, DType dtype) const { +std::pair MXFP8Quantizer::create_tensor(const std::vector& shape, + DType dtype) const { using namespace pybind11::literals; // Tensor dimensions @@ -546,7 +545,7 @@ std::pair MXFP8Quantizer::create_tensor( // Construct Python MXFP8 tensor py::object out_py; - auto py_cast = [] (at::Tensor &tensor, bool need_cast) -> py::object { + auto py_cast = [](at::Tensor& tensor, bool need_cast) -> py::object { return need_cast ? 
py::cast(tensor) : py::none(); }; py::object rowwise_data_py = py_cast(rowwise_data_tensor, rowwise_usage); @@ -555,14 +554,16 @@ std::pair MXFP8Quantizer::create_tensor( py::object columnwise_scale_inv_py = py_cast(columnwise_scale_inv_tensor, columnwise_usage); if (internal) { py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorBasePythonClass)); - out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, + out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py, + "columnwise_data"_a = columnwise_data_py, "rowwise_scale_inv"_a = rowwise_scale_inv_py, "columnwise_scale_inv"_a = columnwise_scale_inv_py, "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); } else { py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorPythonClass)); out_py = MXFP8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, + "rowwise_data"_a = rowwise_data_py, + "columnwise_data"_a = columnwise_data_py, "rowwise_scale_inv"_a = rowwise_scale_inv_py, "columnwise_scale_inv"_a = columnwise_scale_inv_py, "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); From 6d30bb9a40ace6800841508ae2e4765821cbcc3a Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 17 Jul 2025 00:11:27 +0000 Subject: [PATCH 04/24] Fix bug when constructing FP8 tensor Signed-off-by: Tim Moon --- transformer_engine/pytorch/csrc/quantizer.cpp | 42 ++++++++++--------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index a25c099047..ef4521e157 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -44,13 +44,11 @@ std::pair NoneQuantizer::create_tensor(const std::vec return create_tensor(shape, dtype, at::empty(shape_int64, opts)); } -std::pair NoneQuantizer::create_tensor(const std::vector& shape, - DType dtype, - at::Tensor data) const { - py::object out_py = py::cast(data); +std::pair NoneQuantizer::create_tensor( + const std::vector& shape, DType dtype, at::Tensor data) const { TensorWrapper out_cpp; out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape); - return {std::move(out_cpp), std::move(out_py)}; + return {std::move(out_cpp), py::cast(data)}; } void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { @@ -75,7 +73,8 @@ std::pair Float8Quantizer::create_tensor( const std::vector& shape, DType dtype) const { // Allocate data tensor if needed std::optional data; - if (rowwise_usage) { + const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + if (with_data) { const std::vector shape_int64(shape.begin(), shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); data.emplace(at::empty(shape_int64, opts)); @@ -115,11 +114,12 @@ std::pair Float8Quantizer::create_tensor( // Initialize data tensor at::Tensor data_tensor; - if (rowwise_usage) { - NVTE_CHECK(data, "Constructing Float8Tensor with row-wise usage, but no FP8 data was provided"); + const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + if (with_data) { + NVTE_CHECK(data, "Constructing Float8Tensor, but no FP8 data was provided"); data_tensor = std::move(*data); } - py::object data_py = rowwise_usage ? py::cast(data_tensor) : py::none(); + py::object data_py = with_data ? 
py::cast(data_tensor) : py::none(); // Initialize transpose tensor at::Tensor transpose_tensor; @@ -133,28 +133,30 @@ std::pair Float8Quantizer::create_tensor( py::object transpose_py = with_transpose ? py::cast(transpose_tensor) : py::none(); // Initialize scale-inverse tensor - at::Tensor scale_inv_tensor = scale_inv ? std::move(*scale_inv) : at::reciprocal(scale); - py::object scale_inv_py = py::cast(scale_inv_tensor); + if (!scale_inv) { + scale_inv.emplace(at::reciprocal(scale)); + } + auto& scale_inv_tensor = *scale_inv; // Construct Python FP8 tensor py::object out_py; if (internal) { py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); - out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_py, + out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, "quantizer"_a = this->quantizer); } else { py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); const std::vector shape_int64(shape.begin(), shape.end()); out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "data"_a = data_py, "fp8_scale_inv"_a = scale_inv, + "data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, "quantizer"_a = this->quantizer); } // Construct C++ FP8 tensor TensorWrapper out_cpp(this->get_scaling_mode()); - if (rowwise_usage) { + if (with_data) { out_cpp.set_rowwise_data(data_tensor.data_ptr(), this->dtype, shape); out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, std::vector{1}); @@ -228,7 +230,8 @@ std::pair Float8CurrentScalingQuantizer::create_tenso // Initialize data tensor at::Tensor data_tensor; - if (rowwise_usage) { + const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + if (with_data) { const std::vector shape_int64(shape.begin(), shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); data_tensor = at::empty(shape_int64, opts); @@ -259,26 +262,25 @@ std::pair Float8CurrentScalingQuantizer::create_tenso // Construct Python FP8 tensor py::object out_py; - py::object data_py = rowwise_usage ? py::cast(data_tensor) : py::none(); + py::object data_py = with_data ? py::cast(data_tensor) : py::none(); py::object transpose_py = with_transpose ? 
py::cast(transpose_tensor) : py::none(); - py::object scale_inv_py = py::cast(scale_inv_tensor); if (internal) { py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); - out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_py, + out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, "quantizer"_a = this->quantizer); } else { py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); const std::vector shape_int64(shape.begin(), shape.end()); out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "data"_a = data_py, "fp8_scale_inv"_a = scale_inv_py, + "data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, "quantizer"_a = this->quantizer); } // Construct C++ FP8 tensor TensorWrapper out_cpp(this->get_scaling_mode()); - if (rowwise_usage) { + if (with_data) { out_cpp.set_rowwise_data(data_tensor.data_ptr(), this->dtype, shape); out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, std::vector{1}); From 5fca0a0855b151ba17fb795da3a70b4dc03b2bdc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 17 Jul 2025 00:12:22 +0000 Subject: [PATCH 05/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/quantizer.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index ef4521e157..435a68f9fa 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -44,8 +44,9 @@ std::pair NoneQuantizer::create_tensor(const std::vec return create_tensor(shape, dtype, at::empty(shape_int64, opts)); } -std::pair NoneQuantizer::create_tensor( - const std::vector& shape, DType dtype, at::Tensor data) const { +std::pair NoneQuantizer::create_tensor(const std::vector& shape, + DType dtype, + at::Tensor data) const { TensorWrapper out_cpp; out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape); return {std::move(out_cpp), py::cast(data)}; From dc6fae53d19ac534282ce75f69de34a5f60d593f Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 17 Jul 2025 02:08:03 +0000 Subject: [PATCH 06/24] Add quantize function to C++ quantizers Signed-off-by: Tim Moon --- transformer_engine/pytorch/csrc/common.h | 12 +++ transformer_engine/pytorch/csrc/quantizer.cpp | 78 +++++++++++++++++++ 2 files changed, 90 insertions(+) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 4b73ec0347..c43e5d3bc1 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -101,6 +101,8 @@ class Quantizer { virtual std::pair create_tensor(const std::vector& shape, DType dtype) const = 0; + virtual void quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) = 0; + virtual ~Quantizer() = default; bool rowwise_usage = true; @@ -125,6 +127,8 @@ class NoneQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype, at::Tensor data) const; + + void quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) override; }; class Float8Quantizer : public Quantizer { @@ -147,6 +151,8 @@ class 
Float8Quantizer : public Quantizer { std::optional data, std::optional transpose, std::optional scale_inv) const; + + void quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) override; }; class Float8CurrentScalingQuantizer : public Quantizer { @@ -168,6 +174,8 @@ class Float8CurrentScalingQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + + void quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) override; }; class Float8BlockQuantizer : public Quantizer { @@ -202,6 +210,8 @@ class Float8BlockQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + void quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) override; + std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; @@ -218,6 +228,8 @@ class MXFP8Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + void quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) override; + std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 435a68f9fa..e896bd67ba 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -52,6 +52,10 @@ std::pair NoneQuantizer::create_tensor(const std::vec return {std::move(out_cpp), py::cast(data)}; } +void NoneQuantizer::quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) { + NVTE_ERROR("Not yet implemented"); /// TODO Implement +} + void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { tensor->set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()), getTensorShape(scale)); @@ -179,6 +183,17 @@ std::pair Float8Quantizer::create_tensor( return {std::move(out_cpp), std::move(out_py)}; } +void Float8Quantizer::quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) { + QuantizationConfigWrapper quant_config; + if (noop_flag) { + quant_config.set_noop_tensor(noop_flag->data()); + } + set_quantization_params(&out); + NVTE_SCOPED_GIL_RELEASE({ + nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream()); + }); +} + Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& quantizer) : Quantizer(quantizer) { const at::Tensor& scale = quantizer.attr("scale").cast(); @@ -303,6 +318,42 @@ std::pair Float8CurrentScalingQuantizer::create_tenso return {std::move(out_cpp), std::move(out_py)}; } +void Float8CurrentScalingQuantizer::quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) { + auto stream = at::cuda::getCurrentCUDAStream(); + + // Quantization configs + QuantizationConfigWrapper quant_config; + if (noop_flag) { + quant_config.set_noop_tensor(noop_flag->data()); + } + quant_config.set_force_pow_2_scales(force_pow_2_scales); + quant_config.set_amax_epsilon(amax_epsilon); + + // Compute amax + set_quantization_params(&out); + NVTE_SCOPED_GIL_RELEASE({ nvte_compute_amax(input.data(), out.data(), stream); }); + + // Perform amax reduction if needed + if (with_amax_reduction) { + // allreduce amax tensor + c10d::AllreduceOptions opts; + opts.reduceOp = c10d::ReduceOp::MAX; + std::vector tensors = 
{amax}; + NVTE_SCOPED_GIL_RELEASE({ amax_reduction_group->allreduce(tensors, opts)->wait(); }); + } + + // Compute scaling factor + NVTE_SCOPED_GIL_RELEASE({ + nvte_compute_scale_from_amax(out.data(), quant_config, stream); + }); + + // Cast to FP8 + out.set_amax(nullptr, DType::kFloat32, out.defaultShape); // Avoid atomic amax updates + NVTE_SCOPED_GIL_RELEASE({ + nvte_quantize_v2(input.data(), out.data(), quant_config, stream); + }); +} + Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) { this->dtype = quantizer.attr("dtype").cast(); this->block_scaling_dim = quantizer.attr("block_scaling_dim").cast(); @@ -418,6 +469,22 @@ std::pair Float8BlockQuantizer::create_tensor( return {std::move(tensor), std::move(ret)}; } +void Float8BlockQuantizer::quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) { + QuantizationConfigWrapper quant_config; + if (noop_flag) { + quant_config.set_noop_tensor(noop_flag->data()); + } + quant_config.set_force_pow_2_scales(force_pow_2_scales); + quant_config.set_amax_epsilon(amax_epsilon); + if (all_gather_usage) { + quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT); + } + set_quantization_params(&out); + NVTE_SCOPED_GIL_RELEASE({ + nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream()); + }); +} + std::vector Float8BlockQuantizer::get_scale_shape(const std::vector& shape, bool columnwise) const { size_t numel = 1; @@ -589,6 +656,17 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve return {std::move(out_cpp), std::move(out_py)}; } +void MXFP8Quantizer::quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) { + QuantizationConfigWrapper quant_config; + if (noop_flag) { + quant_config.set_noop_tensor(noop_flag->data()); + } + set_quantization_params(&out); + NVTE_SCOPED_GIL_RELEASE({ + nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream()); + }); +} + std::vector MXFP8Quantizer::get_scale_shape(const std::vector& shape, bool columnwise) const { size_t numel = 1; From 7ac091d7eb1fc9e340571bfc68607d87aea294f5 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 17 Jul 2025 03:58:42 +0000 Subject: [PATCH 07/24] Prototype function to coerce Python quantized tensors to match quantizer Signed-off-by: Tim Moon --- transformer_engine/pytorch/csrc/common.cpp | 2 +- transformer_engine/pytorch/csrc/common.h | 28 ++- transformer_engine/pytorch/csrc/quantizer.cpp | 224 +++++++++++++++++- 3 files changed, 237 insertions(+), 17 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index f86b60f612..6d99ec88ae 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -12,7 +12,7 @@ namespace transformer_engine::pytorch { -std::vector getTensorShape(at::Tensor t) { +std::vector getTensorShape(const at::Tensor &t) { std::vector shape; for (auto s : t.sizes()) { shape.push_back(s); diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index c43e5d3bc1..d41ac1e798 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -98,9 +98,11 @@ class Quantizer { virtual void set_quantization_params(TensorWrapper* tensor) const = 0; - virtual std::pair create_tensor(const std::vector& shape, + virtual std::pair create_tensor(const std::vector &shape, DType dtype) const = 
0; + virtual std::pair coerce_tensor(py::object tensor) const = 0; + virtual void quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) = 0; virtual ~Quantizer() = default; @@ -122,12 +124,14 @@ class NoneQuantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override {} - std::pair create_tensor(const std::vector& shape, + std::pair create_tensor(const std::vector &shape, DType dtype) const override; - std::pair create_tensor(const std::vector& shape, DType dtype, + std::pair create_tensor(const std::vector &shape, DType dtype, at::Tensor data) const; + std::pair coerce_tensor(py::object tensor) const override; + void quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) override; }; @@ -144,14 +148,16 @@ class Float8Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor(const std::vector& shape, + std::pair create_tensor(const std::vector &shape, DType dtype) const override; - std::pair create_tensor(const std::vector& shape, DType dtype, + std::pair create_tensor(const std::vector &shape, DType dtype, std::optional data, std::optional transpose, std::optional scale_inv) const; + std::pair coerce_tensor(py::object shape) const override; + void quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) override; }; @@ -175,6 +181,8 @@ class Float8CurrentScalingQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair coerce_tensor(py::object shape) const override; + void quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) override; }; @@ -207,9 +215,11 @@ class Float8BlockQuantizer : public Quantizer { // Create a python Float8BlockQuantized tensor and C++ wrapper // for the tensor. Should set quantized data, scales for rowwise // and optionally columnwise usage. 
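  // A hedged sketch of the intended flow, assuming an existing Python tensor
  // `py_tensor` and input wrapper `te_in` (illustrative names only):
  //   auto [te_out, py_out] = block_quantizer.coerce_tensor(py_tensor);
  //   block_quantizer.quantize(te_in, te_out, std::nullopt);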
- std::pair create_tensor(const std::vector& shape, + std::pair create_tensor(const std::vector &shape, DType dtype) const override; + std::pair coerce_tensor(py::object shape) const override; + void quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; @@ -225,9 +235,11 @@ class MXFP8Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor(const std::vector& shape, + std::pair create_tensor(const std::vector &shape, DType dtype) const override; + std::pair coerce_tensor(py::object shape) const override; + void quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; @@ -235,7 +247,7 @@ class MXFP8Quantizer : public Quantizer { std::unique_ptr convert_quantizer(py::handle quantizer); -std::vector getTensorShape(at::Tensor t); +std::vector getTensorShape(const at::Tensor &t); transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index e896bd67ba..b488b280a0 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -52,6 +52,10 @@ std::pair NoneQuantizer::create_tensor(const std::vec return {std::move(out_cpp), py::cast(data)}; } +std::pair NoneQuantizer::coerce_tensor(py::object tensor) const { + NVTE_ERROR("Not yet implemented"); /// TODO Implement +} + void NoneQuantizer::quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) { NVTE_ERROR("Not yet implemented"); /// TODO Implement } @@ -82,7 +86,7 @@ std::pair Float8Quantizer::create_tensor( if (with_data) { const std::vector shape_int64(shape.begin(), shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - data.emplace(at::empty(shape_int64, opts)); + data = at::empty(shape_int64, opts); } // Allocate transpose tensor if needed @@ -97,7 +101,7 @@ std::pair Float8Quantizer::create_tensor( } } const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - transpose.emplace(at::empty(transpose_shape_int64, opts)); + transpose = at::empty(transpose_shape_int64, opts); } // Allocate scale-inverse tensor @@ -105,7 +109,7 @@ std::pair Float8Quantizer::create_tensor( { const std::vector scale_inv_shape = {1}; const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - scale_inv.emplace(at::empty(scale_inv_shape, opts)); + scale_inv = at::empty(scale_inv_shape, opts); }; // Construct FP8 tensor @@ -139,7 +143,7 @@ std::pair Float8Quantizer::create_tensor( // Initialize scale-inverse tensor if (!scale_inv) { - scale_inv.emplace(at::reciprocal(scale)); + scale_inv = at::reciprocal(scale); } auto& scale_inv_tensor = *scale_inv; @@ -183,12 +187,111 @@ std::pair Float8Quantizer::create_tensor( return {std::move(out_cpp), std::move(out_py)}; } +std::pair Float8Quantizer::coerce_tensor(py::object tensor) const { + NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor."); + + // Expected buffers + const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + 
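+  // Coercion strategy: keep buffers that already match this quantizer's
+  // usages, allocate any that are missing, and drop any that are no longer
+  // needed, so an existing Float8Tensor can be recycled when usages change.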
NVTE_CHECK(need_data || need_transpose, "Invalid quantizer usages."); + + // Extract buffers from Python tensor + auto data_py = tensor.attr("_data"); + auto transpose_py = tensor.attr("_transpose"); + const bool has_data = !data_py.is_none(); + const bool has_transpose = !transpose_py.is_none(); + NVTE_CHECK(has_data || has_transpose, "Tensor has no data."); + std::optional data_tensor, transpose_tensor; + if (has_data) { data_tensor = data_py.cast(); } + if (has_transpose) { transpose_tensor = transpose_py.cast(); } + at::Tensor scale_inv_tensor = tensor.attr("_scale_inv").cast(); + + // Tensor dimensions + std::vector shape; + if (has_transpose) { + const auto transpose_shape = getTensorShape(*transpose_tensor); + if (transpose_shape.size() > 0) { + for (size_t i = 1; i < transpose_shape.size(); ++i) { + shape.push_back(transpose_shape[i]); + } + shape.push_back(transpose_shape.front()); + } + if (has_data) { + auto expected_shape = getTensorShape(*data_tensor); + NVTE_CHECK(shape == expected_shape, + "FP8 data (shape=", expected_shape, ") and transpose (shape=", + transpose_shape, ") do not match"); + } + } else { // Already checked has_data == true + shape = getTensorShape(*data_tensor); + } + + // Coerce data tensor in Python tensor + if (has_data && !need_data) { + data_tensor.reset(); + data_py = py::none(); + tensor.attr("_data") = data_py; + } else if (!has_data && need_data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + data_tensor = at::empty(shape_int64, opts); + data_py = py::cast(data_tensor); + tensor.attr("_data") = data_py; + } + + // Coerce transpose tensor + if (has_transpose && !need_transpose) { + transpose_tensor.reset(); + transpose_py = py::none(); + tensor.attr("_transpose") = transpose_py; + } else if (!has_transpose && need_transpose) { + std::vector transpose_shape_int64; + if (shape.size() > 0) { + transpose_shape_int64.push_back(shape.back()); + for (size_t i = 0; i < shape.size() - 1; ++i) { + transpose_shape_int64.push_back(shape[i]); + } + } + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + transpose_tensor = at::empty(transpose_shape_int64, opts); + transpose_py = py::cast(transpose_tensor); + tensor.attr("_transpose") = transpose_py; + } + tensor.attr("_transpose_invalid") = !need_transpose; + + // Coerce other attrs + tensor.attr("_fp8_dtype") = dtype; + tensor.attr("_quantizer") = quantizer; /// TODO Need to make copy?
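+ // Note: the Python tensor keeps a reference to this quantizer (not a copy), + // so later changes to the quantizer remain visible through the tensor.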
+ + // Construct C++ FP8 tensor + TensorWrapper out_cpp; + if (data_tensor) { + out_cpp.set_rowwise_data(data_tensor->data_ptr(), this->dtype, shape); + out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); + } + if (transpose_tensor) { + std::vector transpose_shape; + if (shape.size() > 0) { + transpose_shape.push_back(shape.back()); + for (size_t i = 0; i < shape.size() - 1; ++i) { + transpose_shape.push_back(shape[i]); + } + } + out_cpp.set_columnwise_data(transpose_tensor->data_ptr(), this->dtype, transpose_shape); + out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(tensor)}; +} + void Float8Quantizer::quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) { QuantizationConfigWrapper quant_config; if (noop_flag) { quant_config.set_noop_tensor(noop_flag->data()); } - set_quantization_params(&out); NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream()); }); @@ -318,6 +421,106 @@ std::pair Float8CurrentScalingQuantizer::create_tenso return {std::move(out_cpp), std::move(out_py)}; } +std::pair Float8CurrentScalingQuantizer::coerce_tensor(py::object tensor) const { + NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8CurrentScalingQuantizer must output to Float8Tensor."); + + // Expected buffers + const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + NVTE_CHECK(need_data || need_transpose, "Invalid quantizer usages."); + + // Extract buffers from Python tensor + auto data_py = tensor.attr("_data"); + auto transpose_py = tensor.attr("_transpose"); + const bool has_data = !data_py.is_none(); + const bool has_transpose = !transpose_py.is_none(); + NVTE_CHECK(has_data || has_transpose, "Tensor has no data."); + std::optional data_tensor, transpose_tensor; + if (has_data) { data_tensor = data_py.cast(); } + if (has_transpose) { transpose_tensor = transpose_py.cast(); } + at::Tensor scale_inv_tensor = tensor.attr("_scale_inv").cast(); + + // Tensor dimensions + std::vector shape; + if (has_transpose) { + const auto transpose_shape = getTensorShape(*transpose_tensor); + if (transpose_shape.size() > 0) { + for (size_t i = 1; i < transpose_shape.size(); ++i) { + shape.push_back(transpose_shape[i]); + } + shape.push_back(transpose_shape.front()); + } + if (has_data) { + auto expected_shape = getTensorShape(*data_tensor); + NVTE_CHECK(shape == expected_shape, + "FP8 data (shape=", expected_shape, ") and transpose (shape=", + transpose_shape, ") do not match"); + } + } else { // Already checked has_data == true + shape = getTensorShape(*data_tensor); + } + + // Coerce data tensor in Python tensor + if (has_data && !need_data) { + data_tensor.reset(); + data_py = py::none(); + tensor.attr("_data") = data_py; + } else if (!has_data && need_data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + data_tensor = at::empty(shape_int64, opts); + data_py = py::cast(data_tensor); + tensor.attr("_data") = data_py; + } + + // Coerce transpose tensor + if (has_transpose && !need_transpose) { + transpose_tensor.reset(); + transpose_py = py::none(); + tensor.attr("_transpose") = transpose_py; + } else if (!has_transpose && 
need_transpose) { + std::vector transpose_shape_int64; + if (shape.size() > 0) { + transpose_shape_int64.push_back(shape.back()); + for (size_t i = 0; i < shape.size() - 1; ++i) { + transpose_shape_int64.push_back(shape[i]); + } + } + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + transpose_tensor = at::empty(transpose_shape_int64, opts); + transpose_py = py::cast(transpose_tensor); + tensor.attr("_transpose") = transpose_py; + } + tensor.attr("_transpose_invalid") = !need_transpose; + + // Coerce other attrs + tensor.attr("_fp8_dtype") = dtype; + tensor.attr("_quantizer") = quantizer; /// TODO Need to make copy? + + // Construct C++ FP8 tensor + TensorWrapper out_cpp; + if (data_tensor) { + out_cpp.set_rowwise_data(data_tensor->data_ptr(), this->dtype, shape); + out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); + } + if (transpose_tensor) { + std::vector transpose_shape; + if (shape.size() > 0) { + transpose_shape.push_back(shape.back()); + for (size_t i = 0; i < shape.size() - 1; ++i) { + transpose_shape.push_back(shape[i]); + } + } + out_cpp.set_columnwise_data(transpose_tensor->data_ptr(), this->dtype, transpose_shape); + out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(tensor)}; +} + void Float8CurrentScalingQuantizer::quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) { auto stream = at::cuda::getCurrentCUDAStream(); @@ -330,7 +533,6 @@ void Float8CurrentScalingQuantizer::quantize(const TensorWrapper &input, TensorW quant_config.set_amax_epsilon(amax_epsilon); // Compute amax - set_quantization_params(&out); NVTE_SCOPED_GIL_RELEASE({ nvte_compute_amax(input.data(), out.data(), stream); }); // Perform amax reduction if needed @@ -469,6 +671,10 @@ std::pair Float8BlockQuantizer::create_tensor( return {std::move(tensor), std::move(ret)}; } +std::pair Float8BlockQuantizer::coerce_tensor(py::object tensor) const { + NVTE_ERROR("Not yet implemented"); /// TODO Implement +} + void Float8BlockQuantizer::quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) { QuantizationConfigWrapper quant_config; if (noop_flag) { @@ -479,7 +685,6 @@ void Float8BlockQuantizer::quantize(const TensorWrapper &input, TensorWrapper &o if (all_gather_usage) { quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT); } - set_quantization_params(&out); NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream()); }); @@ -656,12 +861,15 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve return {std::move(out_cpp), std::move(out_py)}; } +std::pair MXFP8Quantizer::coerce_tensor(py::object tensor) const { + NVTE_ERROR("Not yet implemented"); /// TODO Implement +} + void MXFP8Quantizer::quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) { QuantizationConfigWrapper quant_config; if (noop_flag) { quant_config.set_noop_tensor(noop_flag->data()); } - set_quantization_params(&out); NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream()); }); From b30a4b40f6c59cfdaf3b03dd5e0f25a2d4d52449 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 17 Jul 2025 04:11:18 +0000 Subject: [PATCH 08/24] Use quantizer class in tex.quantize Signed-off-by: Tim Moon ---
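Notes: with this change every recipe funnels through the same three entry points (create_tensor / coerce_tensor / quantize). A rough usage sketch, where quantizer_py, output_py, input_cpp, shape, and fake_dtype stand in for the caller's values:

    auto quantizer_cpp = convert_quantizer(quantizer_py);
    TensorWrapper out_cpp;
    py::object out_py;
    if (output_py.is_none()) {
      // Allocate fresh buffers matching the quantizer's usages
      std::tie(out_cpp, out_py) = quantizer_cpp->create_tensor(shape, fake_dtype);
    } else {
      // Reuse a caller-provided tensor, coercing its buffers as needed
      std::tie(out_cpp, out_py) = quantizer_cpp->coerce_tensor(output_py);
    }
    quantizer_cpp->quantize(input_cpp, out_cpp);  // noop_flag defaults to std::nullopt

Recipe-specific handling (current-scaling amax reduction, blockwise scale formats) moves out of quantize_impl and into each quantizer's quantize() override.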
transformer_engine/pytorch/csrc/common.h | 18 +++-- .../pytorch/csrc/extensions/cast.cpp | 65 ++----------------- 2 files changed, 16 insertions(+), 67 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index d41ac1e798..ebf4431e27 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -103,7 +103,8 @@ class Quantizer { virtual std::pair coerce_tensor(py::object tensor) const = 0; - virtual void quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) = 0; + virtual void quantize(const TensorWrapper &input, TensorWrapper &out, + const std::optional &noop_flag = std::nullopt) = 0; virtual ~Quantizer() = default; @@ -132,7 +133,8 @@ class NoneQuantizer : public Quantizer { std::pair coerce_tensor(py::object tensor) const override; - void quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) override; + void quantize(const TensorWrapper &input, TensorWrapper &out, + const std::optional &noop_flag = std::nullopt) override; }; class Float8Quantizer : public Quantizer { @@ -158,7 +160,8 @@ class Float8Quantizer : public Quantizer { std::pair coerce_tensor(py::object shape) const override; - void quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) override; + void quantize(const TensorWrapper &input, TensorWrapper &out, + const std::optional &noop_flag = std::nullopt) override; }; class Float8CurrentScalingQuantizer : public Quantizer { @@ -183,7 +186,8 @@ class Float8CurrentScalingQuantizer : public Quantizer { std::pair coerce_tensor(py::object shape) const override; - void quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) override; + void quantize(const TensorWrapper &input, TensorWrapper &out, + const std::optional &noop_flag = std::nullopt) override; }; class Float8BlockQuantizer : public Quantizer { @@ -220,7 +224,8 @@ class Float8BlockQuantizer : public Quantizer { std::pair coerce_tensor(py::object shape) const override; - void quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) override; + void quantize(const TensorWrapper &input, TensorWrapper &out, + const std::optional &noop_flag = std::nullopt) override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; @@ -240,7 +245,8 @@ class MXFP8Quantizer : public Quantizer { std::pair coerce_tensor(py::object shape) const override; - void quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) override; + void quantize(const TensorWrapper &input, TensorWrapper &out, + const std::optional &noop_flag = std::nullopt) override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 27a355ca92..e7f0da647e 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -28,60 +28,6 @@ std::vector get_tensor_shape(const TensorWrapper &tensor) { return std::vector(shape.data, shape.data + shape.ndim); } -void quantize_impl(const TensorWrapper &input, py::handle &quantizer_py, - std::unique_ptr &quantizer_cpp, TensorWrapper &output, - TensorWrapper &noop_flag) { - // Check tensor dims - NVTE_CHECK(get_tensor_shape(input) == get_tensor_shape(output), - "Input tensor (shape=", get_tensor_shape(input), 
- ") and output tensor (shape=", get_tensor_shape(output), ") do not match"); - if (input.numel() == 0) { - return; - } - - // Recipe-specific configuration - QuantizationConfigWrapper quant_config; - quant_config.set_noop_tensor(noop_flag.data()); - if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) { - auto my_quantizer_cs = static_cast(quantizer_cpp.get()); - NVTE_SCOPED_GIL_RELEASE( - { nvte_compute_amax(input.data(), output.data(), at::cuda::getCurrentCUDAStream()); }); - // check if we need to do amax reudction (depending on model parallel configs) - if (my_quantizer_cs->with_amax_reduction) { - c10::intrusive_ptr process_group_ptr = my_quantizer_cs->amax_reduction_group; - // construct torch tesnor from NVTEBasicTensor without reallocating memory - at::Tensor &amax_tensor_torch = my_quantizer_cs->amax; - std::vector tensors = {amax_tensor_torch}; - // allreduce amax tensor - c10d::AllreduceOptions allreduce_opts; - allreduce_opts.reduceOp = c10d::ReduceOp::MAX; - process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); - } - // this config is used for cs scaling factor computation - // because compute scale is cannot be fused with quantize kernel - // so in nvte_quantize_v2 with current scaling, the quant config is not used again - quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); - NVTE_SCOPED_GIL_RELEASE({ - nvte_compute_scale_from_amax(output.data(), quant_config, at::cuda::getCurrentCUDAStream()); - }); - // set amax ptr to null in output TensorWrapper to avoid atomic amax updates in kernel - output.set_amax(nullptr, DType::kFloat32, output.defaultShape); - } else if (detail::IsFloat8BlockwiseQuantizers(quantizer_py.ptr())) { - auto my_quantizer_bw = static_cast(quantizer_cpp.get()); - quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); - if (my_quantizer_bw->all_gather_usage) { - quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT); - } - } - - // Perform quantization - NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_v2(input.data(), output.data(), quant_config, at::cuda::getCurrentCUDAStream()); - }); -} - } // namespace py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output, @@ -101,18 +47,17 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob const auto fake_dtype = input_cpp.dtype(); std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(shape, fake_dtype); } else { - output_py = output; - output_cpp = makeTransformerEngineTensor(output_py, quantizer); + std::tie(output_cpp, output_py) = quantizer_cpp->coerce_tensor(output); } // Initialize no-op flag - TensorWrapper noop_flag_cpp; + std::optional noop_flag_cpp; if (noop_flag.has_value()) { noop_flag_cpp = makeTransformerEngineTensor(*noop_flag); } // Perform quantization - quantize_impl(input_cpp, quantizer, quantizer_cpp, output_cpp, noop_flag_cpp); + quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp); return output_py; } @@ -181,10 +126,8 @@ void multi_tensor_quantize_impl(const std::vector &input_list, }); } else { // Quantize kernels individually - TensorWrapper dummy_noop_flag; for (size_t i = 0; i < num_tensors; ++i) { - quantize_impl(input_list[i], quantizer_py_list[i], quantizer_cpp_list[i], output_list[i], - dummy_noop_flag); + quantizer_cpp_list[i]->quantize(input_list[i], output_list[i]); } } } From 
23be7bedadd2bb7e811d82e1e7f738a537c5d6c6 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 17 Jul 2025 04:58:52 +0000 Subject: [PATCH 09/24] Add FP8 current scaling support for activation backward Signed-off-by: Tim Moon --- .../pytorch/csrc/extensions/activation.cpp | 126 ++++++++---------- 1 file changed, 56 insertions(+), 70 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index dfc8a82913..3e65fbf586 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -13,87 +13,73 @@ namespace transformer_engine::pytorch { template py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) { init_extension(); - auto my_quantizer = convert_quantizer(quantizer); - auto input_tensor = input.contiguous(); - - const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); - const auto& te_input_shape = te_input.shape(); - std::vector input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim); - input_shape[input_shape.size() - 1] /= shape_divisor; - auto fake_tensor_type = input.scalar_type(); - - auto [te_output, out] = - my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); - - // for current scaling, we need to compute amax first and then quantize - // because cache cannot fit in the entire tensor to compute amax and quantize - // the quantizer should not need amax reduction, no process group needed here - if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - // activation function might change the input data range, we need to first call the activation function - // and then find the amax and scale of that and then do the quantization - // get a NoneQuantizer to calculate amax of activation output - auto my_quantizer_none = std::make_unique(py::none()); - auto [te_output_act, out_act] = - my_quantizer_none->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); - - NVTE_SCOPED_GIL_RELEASE({ - act_func(te_input.data(), te_output_act.data(), at::cuda::getCurrentCUDAStream()); - // use te_output_act as input to the compute amax and find the amax of activated tensor - nvte_compute_amax(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); - }); - // my_quantizer here has to be a Float8CurrentScalingQuantizer - auto my_quantizer_cs = static_cast(my_quantizer.get()); - if (my_quantizer_cs->with_amax_reduction) { - NVTE_ERROR( - "per-tensor current scaling amax reduction is not supported in activation functions."); - } - QuantizationConfigWrapper quant_config; - quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); - - NVTE_SCOPED_GIL_RELEASE({ - nvte_compute_scale_from_amax(te_output.data(), quant_config, - at::cuda::getCurrentCUDAStream()); - // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel - te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); - nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config, - at::cuda::getCurrentCUDAStream()); - }); - } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) { - // sanity check, since activation fusion is not supported for blockwise quantization yet - // need to raise an error here instead of silently going into act_func with wrong numerics - NVTE_ERROR("Activation fusion 
is not supported for blockwise quantization yet."); + // Input tensor + auto input_tensor = input.contiguous(); + const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor); + + // Construct output tensor + auto quantizer_cpp = convert_quantizer(quantizer); + const auto input_shape = input_cpp.shape(); + std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); + output_shape.back() /= shape_divisor; + auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); + auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype); + + // Compute activation + if (quantizer.is_none() + || detail::IsFloat8Quantizers(quantizer.ptr()) + || detail::IsMXFP8Quantizers(quantizer.ptr())) { + // Compute activation directly + NVTE_SCOPED_GIL_RELEASE( + { act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); }); } else { + // Compute activation in high-precision, then quantize + auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE( - { act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); }); + { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); }); + quantizer_cpp->quantize(temp_cpp, out_cpp); } - return out; + return out_py; } -template -py::object dactivation_helper(const at::Tensor& grad, const at::Tensor& input, +template +py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input, py::handle quantizer) { init_extension(); - auto my_quantizer = convert_quantizer(quantizer); - auto input_tensor = input.contiguous(); - auto grad_tensor = grad.contiguous(); - - const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); - const TensorWrapper& te_grad = makeTransformerEngineTensor(grad_tensor); - const auto& te_input_shape = te_input.shape(); - std::vector input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim); - auto fake_tensor_type = input.scalar_type(); - - auto [te_output, out] = - my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); - NVTE_SCOPED_GIL_RELEASE({ - act_func(te_grad.data(), te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); - }); + // Grad output and input tensors + auto grad_output_tensor = grad_output.contiguous(); + auto input_tensor = input.contiguous(); + const TensorWrapper& grad_output_cpp = makeTransformerEngineTensor(grad_output_tensor); + const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor); + + // Construct grad input tensor + auto quantizer_cpp = convert_quantizer(quantizer); + const auto input_shape_te = input_cpp.shape(); + const std::vector input_shape(input_shape_te.data, input_shape_te.data + input_shape_te.ndim); + auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); + auto [grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype); + + // Compute activation backward + if (quantizer.is_none() + || detail::IsFloat8Quantizers(quantizer.ptr()) + || detail::IsMXFP8Quantizers(quantizer.ptr())) { + // Compute activation backward directly + NVTE_SCOPED_GIL_RELEASE({ + dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), at::cuda::getCurrentCUDAStream()); + }); + } else { + // Compute activation backward in high-precision, then quantize + auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ + 
dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); + }); + quantizer_cpp->quantize(temp_cpp, grad_input_cpp); + } - return out; + return grad_input_py; } py::object gelu(const at::Tensor& input, py::handle quantizer) { From 302a77d6e43ee0e53e626536c4102e86b9a5832e Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 17 Jul 2025 05:07:10 +0000 Subject: [PATCH 10/24] Disable quantized GEMM output with FP8 current scaling Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 7 +++-- .../pytorch/ops/basic/basic_linear.py | 25 ++++++++----------- .../ops/fused/userbuffers_forward_linear.py | 2 +- 3 files changed, 14 insertions(+), 20 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 778ae687f3..3bd39dd667 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -786,10 +786,9 @@ def _test_basic_linear( pytest.skip("FP8 output is only supported with FP8 GEMMs") if quantized_grad_input and not quantized_compute: pytest.skip("FP8 grad input is only supported with FP8 GEMMs") - if quantization == "mxfp8" and quantized_output: - pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs") - if quantization == "mxfp8" and quantized_grad_input: - pytest.skip("MXFP8 grad input is not supported with MXFP8 GEMMs") + if quantization not in (None, "fp8"): + if quantized_output or quantized_grad_input: + pytest.skip("Recipe does not support quantized GEMM output") # Random data x_ref, x_test = make_reference_and_test_tensors( diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 59fc096072..5766aaaab9 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -516,18 +516,12 @@ def _functional_forward( raise ValueError("Output tensor is quantized, but quantizer was not provided") else: output_quantizer = None - if isinstance(output_quantizer, MXFP8Quantizer): - raise RuntimeError( - "Attempting to generate MXFP8 output tensor, " - "but GEMM with MXFP8 output is not supported" - ) - if isinstance(output_quantizer, Float8BlockQuantizer): - raise RuntimeError( - "Attempting to generate Float8BlockQuantized output tensor, " - "but GEMM with Float8BlockQuantized output is not supported" - ) if output_quantizer is not None: + if not isinstance(output_quantizer, Float8Quantizer): + raise RuntimeError( + "Attempting to generate quantized output tensor with unsupported quantizer" + ) output_quantizer.set_usage(rowwise=True, columnwise=False) # Check if accumulating into output tensor @@ -801,11 +794,12 @@ def _functional_backward( ) else: grad_input_quantizer = None - if isinstance(grad_input_quantizer, MXFP8Quantizer): - raise RuntimeError( - "Attempting to generate MXFP8 grad input tensor, " - "but GEMM with MXFP8 output is not supported" - ) + if grad_input_quantizer is not None: + if not isinstance(grad_input_quantizer, Float8Quantizer): + raise RuntimeError( + "Attempting to generate quantized grad input tensor " + "with unsupported quantizer" + ) # Check if accumulating into grad input tensor if accumulate_into_grad_input: diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 30d9cdaaeb..0808511654 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++
b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -182,7 +182,7 @@ def _functional_forward( if weight_quantizer is None: raise ValueError("Missing quantizer for weight tensor") if output_quantizer is not None: - raise ValueError("FP8 output is not supported") + raise ValueError("Quantized output is not supported") else: input_quantizer = None weight_quantizer = None From 952333af7c868b5d51bd19b591aa0a95d14edc71 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 17 Jul 2025 05:55:08 +0000 Subject: [PATCH 11/24] Add coerce_tensor functions for MXFP8 and DSv3 Signed-off-by: Tim Moon --- transformer_engine/pytorch/csrc/quantizer.cpp | 155 +++++++++++++++++- 1 file changed, 148 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index b488b280a0..7c6def8ea0 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -49,15 +49,22 @@ std::pair NoneQuantizer::create_tensor(const std::vec at::Tensor data) const { TensorWrapper out_cpp; out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape); + set_quantization_params(&out_cpp); return {std::move(out_cpp), py::cast(data)}; } std::pair NoneQuantizer::coerce_tensor(py::object tensor) const { - NVTE_ERROR("Not yet implemented"); /// TODO Implement + auto tensor_pyt = tensor.cast(); + TensorWrapper out_cpp; + out_cpp.set_rowwise_data(tensor_pyt.data_ptr(), + GetTransformerEngineDType(tensor_pyt.scalar_type()), + getTensorShape(tensor_pyt)); + set_quantization_params(&out_cpp); + return {std::move(out_cpp), std::move(tensor)}; } void NoneQuantizer::quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) { - NVTE_ERROR("Not yet implemented"); /// TODO Implement + NVTE_ERROR("NoneQuantizer does not support quantization"); } void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { @@ -193,14 +200,14 @@ std::pair Float8Quantizer::coerce_tensor(py::object t // Expected buffers const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); - NVTE_CHECK(need_data || need_transpose, "Invalid quantizer usages."); + NVTE_CHECK(need_data || need_transpose, "Invalid usages for Float8Quantizer."); // Extract buffers from Python tensor auto data_py = tensor.attr("_data"); auto transpose_py = tensor.attr("_transpose"); const bool has_data = !data_py.is_none(); const bool has_transpose = !transpose_py.is_none(); - NVTE_CHECK(has_data || has_transpose, "Tensor has no data."); + NVTE_CHECK(has_data || has_transpose, "Float8Tensor has no data."); std::optional data_tensor, transpose_tensor; if (has_data) { data_tensor = data_py.cast(); } if (has_transpose) { transpose_tensor = transpose_py.cast(); } @@ -226,7 +233,7 @@ std::pair Float8Quantizer::coerce_tensor(py::object t shape = getTensorShape(*data_tensor); } - // Coerce data tensor in Python tensor + // Coerce data tensor if (has_data && !need_data) { data_tensor.reset(); data_py = py::none(); @@ -672,7 +679,39 @@ std::pair Float8BlockQuantizer::create_tensor( } std::pair Float8BlockQuantizer::coerce_tensor(py::object tensor) const { - NVTE_ERROR("Not yet implemented"); /// TODO Implement + const DType dtype = tensor.attr("_fp8_dtype").cast(); + bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast(); + + // Check the data matches quantizer usages + NVTE_CHECK(!tensor.attr("_rowwise_data").is_none() == 
rowwise_usage, + "Float8BlockwiseQTensor does not match quantizer usages (has_rowwise_data=", + !tensor.attr("_rowwise_data").is_none(), ", rowwise_usage=", rowwise_usage, ")"); + NVTE_CHECK(!tensor.attr("_columnwise_data").is_none() == columnwise_usage, + "Float8BlockwiseQTensor does not match quantizer usages (has_columnwise_data=", + !tensor.attr("_columnwise_data").is_none(), ", columnwise_usage=", columnwise_usage, ")"); + + auto ret = TensorWrapper(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D); + + if (rowwise_usage) { + const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast(); + const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast(); + void *scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr(); + const auto &rowwise_shape = getTensorShape(data_rowwise); + ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, rowwise_shape); + const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise); + ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape); + } + if (columnwise_usage) { + const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast(); + const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast(); + void *scale_inv_colwise_dptr = scale_inv_colwise.data_ptr(); + const auto &shape = getTensorShape(data_colwise); + ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape); + const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); + ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape); + } + set_quantization_params(&ret); + return {std::move(ret), std::move(tensor)}; } void Float8BlockQuantizer::quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) { QuantizationConfigWrapper quant_config; if (noop_flag) { @@ -862,7 +901,109 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve } std::pair MXFP8Quantizer::coerce_tensor(py::object tensor) const { - NVTE_ERROR("Not yet implemented"); /// TODO Implement + NVTE_CHECK(detail::IsMXFP8Tensor(tensor.ptr()), "MXFP8Quantizer must output to MXFP8Tensor."); + + // Extract buffers from Python tensor + auto get_tensor = [&tensor](const char *name) -> std::optional { + auto attr_py = tensor.attr(name); + if (attr_py.is_none()) { + return std::nullopt; + } + return attr_py.cast(); + }; + auto rowwise_data = get_tensor("_rowwise_data"); + auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv"); + auto columnwise_data = get_tensor("_columnwise_data"); + auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv"); + NVTE_CHECK(rowwise_data || columnwise_data, "MXFP8Tensor has no data."); + + // Tensor dimensions + std::vector shape; + if (columnwise_data) { + shape = getTensorShape(*columnwise_data); + if (rowwise_data) { + auto expected_shape = getTensorShape(*rowwise_data); + NVTE_CHECK(shape == expected_shape, + "MXFP8 row-wise data (shape=", expected_shape, ") and column-wise data (shape=", + shape, ") do not match"); + } + } else { // Already checked rowwise_data == true + shape = getTensorShape(*rowwise_data); + } + + // Coerce row-wise data + if (rowwise_usage) { + if (!rowwise_data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + rowwise_data = at::empty(shape_int64, opts); + tensor.attr("_rowwise_data") = *rowwise_data; + } + if (!rowwise_scale_inv) { + const auto scale_inv_shape = get_scale_shape(shape, false); + const std::vector
scale_inv_shape_int64(scale_inv_shape.begin(), + scale_inv_shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts); + tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv; + } + } else { // rowwise_usage == false + if (rowwise_data) { + rowwise_data.reset(); + tensor.attr("_rowwise_data") = py::none(); + } + if (rowwise_scale_inv) { + rowwise_scale_inv.reset(); + tensor.attr("_rowwise_scale_inv") = py::none(); + } + } + + // Coerce column-wise data + if (columnwise_usage) { + if (!columnwise_data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + columnwise_data = at::empty(shape_int64, opts); + tensor.attr("_columnwise_data") = *columnwise_data; + } + if (!columnwise_scale_inv) { + const auto scale_inv_shape = get_scale_shape(shape, true); + const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), + scale_inv_shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + columnwise_scale_inv = at::empty(scale_inv_shape_int64, opts); + tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv; + } + } else { // columnwise_usage == false + if (columnwise_data) { + columnwise_data.reset(); + tensor.attr("_columnwise_data") = py::none(); + } + if (columnwise_scale_inv) { + columnwise_scale_inv.reset(); + tensor.attr("_columnwise_scale_inv") = py::none(); + } + } + + // Coerce other attrs + tensor.attr("_fp8_dtype") = dtype; + tensor.attr("_quantizer") = quantizer; /// TODO Need to make copy? + + // Construct C++ MXFP8 tensor + TensorWrapper out_cpp(NVTE_MXFP8_1D_SCALING); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), dtype, shape); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E8M0, + getTensorShape(*rowwise_scale_inv)); + } + if (columnwise_usage) { + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), dtype, shape); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E8M0, + getTensorShape(*columnwise_scale_inv)); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(tensor)}; } void MXFP8Quantizer::quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) { From 86af34cdc0fc1797b75c3cd4f9c5902380f3895c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 17 Jul 2025 05:56:31 +0000 Subject: [PATCH 12/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/common.cpp | 2 +- transformer_engine/pytorch/csrc/common.h | 40 +++++----- .../pytorch/csrc/extensions/activation.cpp | 19 ++--- transformer_engine/pytorch/csrc/quantizer.cpp | 78 ++++++++++--------- 4 files changed, 74 insertions(+), 65 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 6d99ec88ae..ab3b7abec4 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -12,7 +12,7 @@ namespace transformer_engine::pytorch { -std::vector getTensorShape(const at::Tensor &t) { +std::vector getTensorShape(const at::Tensor& t) { std::vector shape; for (auto s : t.sizes()) { shape.push_back(s); diff --git a/transformer_engine/pytorch/csrc/common.h 
b/transformer_engine/pytorch/csrc/common.h index ebf4431e27..e12fc85ffb 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -98,13 +98,13 @@ class Quantizer { virtual void set_quantization_params(TensorWrapper* tensor) const = 0; - virtual std::pair create_tensor(const std::vector &shape, + virtual std::pair create_tensor(const std::vector& shape, DType dtype) const = 0; virtual std::pair coerce_tensor(py::object tensor) const = 0; - virtual void quantize(const TensorWrapper &input, TensorWrapper &out, - const std::optional &noop_flag = std::nullopt) = 0; + virtual void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) = 0; virtual ~Quantizer() = default; @@ -125,16 +125,16 @@ class NoneQuantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override {} - std::pair create_tensor(const std::vector &shape, + std::pair create_tensor(const std::vector& shape, DType dtype) const override; - std::pair create_tensor(const std::vector &shape, DType dtype, + std::pair create_tensor(const std::vector& shape, DType dtype, at::Tensor data) const; std::pair coerce_tensor(py::object tensor) const override; - void quantize(const TensorWrapper &input, TensorWrapper &out, - const std::optional &noop_flag = std::nullopt) override; + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; }; class Float8Quantizer : public Quantizer { @@ -150,18 +150,18 @@ class Float8Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor(const std::vector &shape, + std::pair create_tensor(const std::vector& shape, DType dtype) const override; - std::pair create_tensor(const std::vector &shape, DType dtype, + std::pair create_tensor(const std::vector& shape, DType dtype, std::optional data, std::optional transpose, std::optional scale_inv) const; std::pair coerce_tensor(py::object shape) const override; - void quantize(const TensorWrapper &input, TensorWrapper &out, - const std::optional &noop_flag = std::nullopt) override; + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; }; class Float8CurrentScalingQuantizer : public Quantizer { @@ -186,8 +186,8 @@ class Float8CurrentScalingQuantizer : public Quantizer { std::pair coerce_tensor(py::object shape) const override; - void quantize(const TensorWrapper &input, TensorWrapper &out, - const std::optional &noop_flag = std::nullopt) override; + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; }; class Float8BlockQuantizer : public Quantizer { @@ -219,13 +219,13 @@ class Float8BlockQuantizer : public Quantizer { // Create a python Float8BlockQuantized tensor and C++ wrapper // for the tensor. Should set quantized data, scales for rowwise // and optionally columnwise usage. 
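  // Scale shapes for each usage are given by get_scale_shape().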
- std::pair create_tensor(const std::vector &shape, + std::pair create_tensor(const std::vector& shape, DType dtype) const override; std::pair coerce_tensor(py::object shape) const override; - void quantize(const TensorWrapper &input, TensorWrapper &out, - const std::optional &noop_flag = std::nullopt) override; + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; @@ -240,20 +240,20 @@ class MXFP8Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor(const std::vector &shape, + std::pair create_tensor(const std::vector& shape, DType dtype) const override; std::pair coerce_tensor(py::object shape) const override; - void quantize(const TensorWrapper &input, TensorWrapper &out, - const std::optional &noop_flag = std::nullopt) override; + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; std::unique_ptr convert_quantizer(py::handle quantizer); -std::vector getTensorShape(const at::Tensor &t); +std::vector getTensorShape(const at::Tensor& t); transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe); diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 3e65fbf586..c9eae092b0 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -27,9 +27,8 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype); // Compute activation - if (quantizer.is_none() - || detail::IsFloat8Quantizers(quantizer.ptr()) - || detail::IsMXFP8Quantizers(quantizer.ptr())) { + if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || + detail::IsMXFP8Quantizers(quantizer.ptr())) { // Compute activation directly NVTE_SCOPED_GIL_RELEASE( { act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); }); @@ -58,23 +57,25 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i // Construct grad input tensor auto quantizer_cpp = convert_quantizer(quantizer); const auto input_shape_te = input_cpp.shape(); - const std::vector input_shape(input_shape_te.data, input_shape_te.data + input_shape_te.ndim); + const std::vector input_shape(input_shape_te.data, + input_shape_te.data + input_shape_te.ndim); auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); auto [grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype); // Compute activation backward - if (quantizer.is_none() - || detail::IsFloat8Quantizers(quantizer.ptr()) - || detail::IsMXFP8Quantizers(quantizer.ptr())) { + if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || + detail::IsMXFP8Quantizers(quantizer.ptr())) { // Compute activation backward directly NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), at::cuda::getCurrentCUDAStream()); + dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), + at::cuda::getCurrentCUDAStream()); }); } else { // Compute activation backward in high-precision, then 
quantize auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); + dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), + at::cuda::getCurrentCUDAStream()); }); quantizer_cpp->quantize(temp_cpp, grad_input_cpp); } diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 7c6def8ea0..e44e216401 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -63,7 +63,8 @@ std::pair NoneQuantizer::coerce_tensor(py::object ten return {std::move(out_cpp), std::move(tensor)}; } -void NoneQuantizer::quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) { +void NoneQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { NVTE_ERROR("NoneQuantizer does not support quantization"); } @@ -209,8 +210,12 @@ std::pair Float8Quantizer::coerce_tensor(py::object t const bool has_transpose = !transpose_py.is_none(); NVTE_CHECK(has_data || has_transpose, "Float8Tensor has no data."); std::optional data_tensor, transpose_tensor; - if (has_data) { data_tensor = data_py.cast(); } - if (has_transpose) { transpose_tensor = transpose_py.cast(); } + if (has_data) { + data_tensor = data_py.cast(); + } + if (has_transpose) { + transpose_tensor = transpose_py.cast(); + } at::Tensor scale_inv_tensor = tensor.attr("_scale_inv").cast(); // Tensor dimensions @@ -225,9 +230,8 @@ std::pair Float8Quantizer::coerce_tensor(py::object t } if (has_data) { auto expected_shape = getTensorShape(*data_tensor); - NVTE_CHECK(shape == expected_shape, - "FP8 data (shape=", expected_shape, ") and transpose (shape=", - transpose_shape, ") do not match"); + NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, + ") and transpose (shape=", transpose_shape, ") do not match"); } } else { // Already checked has_data == true shape = getTensorShape(*data_tensor); @@ -294,7 +298,8 @@ std::pair Float8Quantizer::coerce_tensor(py::object t return {std::move(out_cpp), std::move(tensor)}; } -void Float8Quantizer::quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) { +void Float8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { QuantizationConfigWrapper quant_config; if (noop_flag) { quant_config.set_noop_tensor(noop_flag->data()); @@ -428,8 +433,10 @@ std::pair Float8CurrentScalingQuantizer::create_tenso return {std::move(out_cpp), std::move(out_py)}; } -std::pair Float8CurrentScalingQuantizer::coerce_tensor(py::object tensor) const { - NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8CurrentScalingQuantizer must output to Float8Tensor."); +std::pair Float8CurrentScalingQuantizer::coerce_tensor( + py::object tensor) const { + NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), + "Float8CurrentScalingQuantizer must output to Float8Tensor."); // Expected buffers const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); @@ -443,8 +450,12 @@ std::pair Float8CurrentScalingQuantizer::coerce_tenso const bool has_transpose = !transpose_py.is_none(); NVTE_CHECK(has_data || has_transpose, "Tensor has no data."); std::optional data_tensor, transpose_tensor; - if (has_data) { data_tensor = data_py.cast(); } - if (has_transpose) { transpose_tensor = transpose_py.cast(); } + 
if (has_data) { + data_tensor = data_py.cast(); + } + if (has_transpose) { + transpose_tensor = transpose_py.cast(); + } at::Tensor scale_inv_tensor = tensor.attr("_scale_inv").cast(); // Tensor dimensions @@ -459,9 +470,8 @@ std::pair Float8CurrentScalingQuantizer::coerce_tenso } if (has_data) { auto expected_shape = getTensorShape(*data_tensor); - NVTE_CHECK(shape == expected_shape, - "FP8 data (shape=", expected_shape, ") and transpose (shape=", - transpose_shape, ") do not match"); + NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, + ") and transpose (shape=", transpose_shape, ") do not match"); } } else { // Already checked has_data == true shape = getTensorShape(*data_tensor); @@ -528,7 +538,8 @@ std::pair Float8CurrentScalingQuantizer::coerce_tenso return {std::move(out_cpp), std::move(tensor)}; } -void Float8CurrentScalingQuantizer::quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) { +void Float8CurrentScalingQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { auto stream = at::cuda::getCurrentCUDAStream(); // Quantization configs @@ -552,15 +563,11 @@ void Float8CurrentScalingQuantizer::quantize(const TensorWrapper &input, TensorW } // Compute scaling factor - NVTE_SCOPED_GIL_RELEASE({ - nvte_compute_scale_from_amax(out.data(), quant_config, stream); - }); + NVTE_SCOPED_GIL_RELEASE({ nvte_compute_scale_from_amax(out.data(), quant_config, stream); }); // Cast to FP8 out.set_amax(nullptr, DType::kFloat32, out.defaultShape); // Avoid atomic amax updates - NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_v2(input.data(), out.data(), quant_config, stream); - }); + NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); }); } Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) { @@ -693,19 +700,19 @@ std::pair Float8BlockQuantizer::coerce_tensor(py::obj auto ret = TensorWrapper(is_2D_scaled ? 
NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D); if (rowwise_usage) { - const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast(); - const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast(); - void *scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr(); - const auto &rowwise_shape = getTensorShape(data_rowwise); + const at::Tensor& data_rowwise = tensor.attr("_rowwise_data").cast(); + const at::Tensor& scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast(); + void* scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr(); + const auto& rowwise_shape = getTensorShape(data_rowwise); ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, rowwise_shape); const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise); ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape); } if (columnwise_usage) { - const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast(); - const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast(); - void *scale_inv_colwise_dptr = scale_inv_colwise.data_ptr(); - const auto &shape = getTensorShape(data_colwise); + const at::Tensor& data_colwise = tensor.attr("_columnwise_data").cast(); + const at::Tensor& scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast(); + void* scale_inv_colwise_dptr = scale_inv_colwise.data_ptr(); + const auto& shape = getTensorShape(data_colwise); ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape); const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape); @@ -714,7 +721,8 @@ std::pair Float8BlockQuantizer::coerce_tensor(py::obj return {std::move(ret), std::move(tensor)}; } -void Float8BlockQuantizer::quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) { +void Float8BlockQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { QuantizationConfigWrapper quant_config; if (noop_flag) { quant_config.set_noop_tensor(noop_flag->data()); @@ -904,7 +912,7 @@ std::pair MXFP8Quantizer::coerce_tensor(py::object te NVTE_CHECK(detail::IsMXFP8Tensor(tensor.ptr()), "MXFP8Quantizer must output to MXFP8Tensor."); // Extract buffers from Python tensor - auto get_tensor = [&tensor](const char *name) -> std::optional { + auto get_tensor = [&tensor](const char* name) -> std::optional { auto attr_py = tensor.attr(name); if (attr_py.is_none()) { return std::nullopt; @@ -923,9 +931,8 @@ std::pair MXFP8Quantizer::coerce_tensor(py::object te shape = getTensorShape(*columnwise_data); if (rowwise_data) { auto expected_shape = getTensorShape(*rowwise_data); - NVTE_CHECK(shape == expected_shape, - "MXFP8 row-wise data (shape=", expected_shape, ") and column-wise data (shape=", - shape, ") do not match"); + NVTE_CHECK(shape == expected_shape, "MXFP8 row-wise data (shape=", expected_shape, + ") and column-wise data (shape=", shape, ") do not match"); } } else { // Already checked rowwise_data == true shape = getTensorShape(*rowwise_data); @@ -1006,7 +1013,8 @@ std::pair MXFP8Quantizer::coerce_tensor(py::object te return {std::move(out_cpp), std::move(tensor)}; } -void MXFP8Quantizer::quantize(const TensorWrapper &input, TensorWrapper &out, const std::optional &noop_flag) { +void MXFP8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { QuantizationConfigWrapper quant_config; if (noop_flag) {
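    // Forward the no-op flag pointer so the quantization kernel can skip work when the flag is set.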
quant_config.set_noop_tensor(noop_flag->data()); From 596ead52f572b3d1aa9749025ab260820713fc4b Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 18 Jul 2025 02:22:38 +0000 Subject: [PATCH 13/24] Avoid quantizing empty tensors Signed-off-by: Tim Moon --- transformer_engine/pytorch/csrc/quantizer.cpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index e44e216401..faa9515d90 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -300,6 +300,9 @@ std::pair Float8Quantizer::coerce_tensor(py::object t void Float8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag) { + if (input.numel() == 0) { + return; + } QuantizationConfigWrapper quant_config; if (noop_flag) { quant_config.set_noop_tensor(noop_flag->data()); @@ -542,6 +545,11 @@ void Float8CurrentScalingQuantizer::quantize(const TensorWrapper& input, TensorW const std::optional& noop_flag) { auto stream = at::cuda::getCurrentCUDAStream(); + // Nothing to be done if input is empty + if (input.numel() == 0) { + return; + } + // Quantization configs QuantizationConfigWrapper quant_config; if (noop_flag) { @@ -723,6 +731,9 @@ std::pair Float8BlockQuantizer::coerce_tensor(py::obj void Float8BlockQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag) { + if (input.numel() == 0) { + return; + } QuantizationConfigWrapper quant_config; if (noop_flag) { quant_config.set_noop_tensor(noop_flag->data()); @@ -1015,6 +1026,9 @@ std::pair MXFP8Quantizer::coerce_tensor(py::object te void MXFP8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag) { + if (input.numel() == 0) { + return; + } QuantizationConfigWrapper quant_config; if (noop_flag) { quant_config.set_noop_tensor(noop_flag->data()); From c4270b3a4dd0723a8e44ac58ab6fa07477f2f83f Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 18 Jul 2025 02:47:28 +0000 Subject: [PATCH 14/24] Use consistent shapes for FP8 transposes Signed-off-by: Tim Moon --- .../pytorch/csrc/extensions/transpose.cpp | 30 ++++++++++++------- .../pytorch/tensor/float8_tensor.py | 5 ++-- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index d2f7107fe5..d6ae0c86a1 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -18,27 +18,35 @@ namespace pytorch { at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional output) { init_extension(); - const auto dim = input.dim(); - NVTE_CHECK(dim >= 2, "Need at least 2D tensor to transpose."); - - if (input.dim() > 2) { - input = input.view({-1, input.size(dim - 1)}); + // Tensor dimensions + const auto shape = getTensorShape(input); + std::vector transpose_shape_int64; + if (shape.size() > 0) { + transpose_shape_int64.push_back(shape.back()); + for (size_t i = 0; i < shape.size() - 1; ++i) { + transpose_shape_int64.push_back(shape[i]); + } } + const size_t M = shape.size() > 0 ? product(shape) / shape.back() : 1; + const size_t N = shape.size() > 0 ? 
shape.back() : 1; - size_t M = static_cast(input.size(0)); - size_t N = static_cast(input.size(1)); - + // Output tensor at::Tensor out; if (output.has_value()) { out = *output; } else { - out = allocateTorchTensor(input.size(1), input.size(0), DType::kByte); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + out = at::empty(transpose_shape_int64, opts); } - if (M == 0 || N == 0) return out; + // Return immediately if tensor is empty + if (M == 0 || N == 0) { + return out; + } + + // Compute transpose auto input_cu = makeTransformerEngineTensor(input.data_ptr(), std::vector{M, N}, otype); auto output_cu = makeTransformerEngineTensor(out.data_ptr(), std::vector{N, M}, otype); - nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); return out; diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index c1476945a8..4c57ffd3e5 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -108,10 +108,9 @@ def make_empty( # Allocate FP8 data transpose if needed data_transpose = None if self.columnwise_usage: - inner_dim = data.size(-1) + transpose_shape = [data.size(-1)] + list(data.shape[:-1]) data_transpose = torch.empty( - inner_dim, - data.numel() // inner_dim, + transpose_shape, dtype=torch.uint8, device=device, ) From 34d1fde6aef6675645a80424ed46bc062f19732d Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 19 Jul 2025 02:13:20 +0000 Subject: [PATCH 15/24] In attention impl, construct FP8 tensors with pre-initialized scale-invs Signed-off-by: Tim Moon --- transformer_engine/pytorch/csrc/common.h | 9 + .../pytorch/csrc/extensions/attention.cpp | 30 +++- transformer_engine/pytorch/csrc/quantizer.cpp | 159 ++++++------------ 3 files changed, 88 insertions(+), 110 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index e12fc85ffb..078a96c113 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -98,11 +98,18 @@ class Quantizer { virtual void set_quantization_params(TensorWrapper* tensor) const = 0; + /*! @brief Construct a tensor with uninitialized data */ virtual std::pair create_tensor(const std::vector& shape, DType dtype) const = 0; + /*! @brief Load a PyTorch tensor + * + * The PyTorch tensor's attributes are modified to match the + * quantizer's configuration. + */ virtual std::pair coerce_tensor(py::object tensor) const = 0; + /*! @brief Convert to a quantized data format */ virtual void quantize(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag = std::nullopt) = 0; @@ -128,6 +135,7 @@ class NoneQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + /*! @brief Construct a tensor with pre-initialized data */ std::pair create_tensor(const std::vector& shape, DType dtype, at::Tensor data) const; @@ -153,6 +161,7 @@ class Float8Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + /*! 
@brief Construct a tensor with pre-initialized data */ std::pair create_tensor(const std::vector& shape, DType dtype, std::optional data, std::optional transpose, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 74c5c5e7a7..cbe1fd25f4 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -18,7 +18,7 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s auto max_tokens = shape[0]; auto fcd_size = 1; - for (int i = 1; i <= shape.size(); i++) { + for (size_t i = 1; i <= shape.size(); i++) { fcd_size *= shape[i]; } @@ -103,8 +103,18 @@ std::vector fused_attn_fwd( auto o_shape = std::vector{q_shape.begin(), q_shape.end()}; o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1]; py::object o_python, s_python; - std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te); - std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + // Initialize FP8 tensor with scale-inverse + auto *O_quantizer_fp8 = dynamic_cast(O_quantizer.get()); + auto *S_quantizer_fp8 = dynamic_cast(S_quantizer.get()); + NVTE_CHECK(O_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); + NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); + std::tie(te_O, o_python) = O_quantizer_fp8->create_tensor(o_shape, fake_dtype_te, std::nullopt, std::nullopt, std::nullopt); + std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, std::nullopt, std::nullopt); + } else { + std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te); + std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); + } auto o_shape_int64 = std::vector{o_shape.begin(), o_shape.end()}; // construct NVTE tensors @@ -284,8 +294,18 @@ std::vector fused_attn_bwd( py::object s_python, dp_python; std::unique_ptr S_quantizer = convert_quantizer(s_quantizer); std::unique_ptr dP_quantizer = convert_quantizer(dp_quantizer); - std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); - std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32); + + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + auto *S_quantizer_fp8 = dynamic_cast(S_quantizer.get()); + auto *dP_quantizer_fp8 = dynamic_cast(dP_quantizer.get()); + NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); + NVTE_CHECK(dP_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); + std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, std::nullopt, std::nullopt); + std::tie(te_dP, dp_python) = dP_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, std::nullopt, std::nullopt); + } else { + std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); + std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32); + } std::vector q_shape = convertShape(te_Q.shape()); std::vector k_shape = convertShape(te_K.shape()); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index faa9515d90..f9da7cb2f7 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -12,6 +12,27 @@ namespace 
transformer_engine::pytorch {
 
+namespace {
+
+/*! @brief Transposed tensor shape
+ *
+ * The tensor is interpreted as a 2D matrix by flattening all but the
+ * last dimension, and then transposed.
+ */
+template <typename T>
+std::vector<T> make_transpose_shape(const std::vector<T> &shape) {
+  std::vector<T> ret;
+  if (shape.size() > 0) {
+    ret.push_back(shape.back());
+    for (size_t i = 0; i < shape.size() - 1; ++i) {
+      ret.push_back(shape[i]);
+    }
+  }
+  return ret;
+}
+
+}  // namespace
+
 constexpr size_t MXFP8_BLOCK_SIZE = 32;
 
 Quantizer::Quantizer(const py::handle& quantizer) {
@@ -88,40 +109,9 @@ void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
 
 std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
     const std::vector<size_t>& shape, DType dtype) const {
-  // Allocate data tensor if needed
-  std::optional<at::Tensor> data;
-  const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported();
-  if (with_data) {
-    const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
-    const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
-    data = at::empty(shape_int64, opts);
-  }
-
-  // Allocate transpose tensor if needed
-  std::optional<at::Tensor> transpose;
-  const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
-  if (with_transpose) {
-    std::vector<int64_t> transpose_shape_int64;
-    if (shape.size() > 0) {
-      transpose_shape_int64.push_back(shape.back());
-      for (size_t i = 0; i < shape.size() - 1; ++i) {
-        transpose_shape_int64.push_back(shape[i]);
-      }
-    }
-    const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
-    transpose = at::empty(transpose_shape_int64, opts);
-  }
-
-  // Allocate scale-inverse tensor
-  std::optional<at::Tensor> scale_inv;
-  {
-    const std::vector<int64_t> scale_inv_shape = {1};
-    const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
-    scale_inv = at::empty(scale_inv_shape, opts);
-  };
-
-  // Construct FP8 tensor
-  return create_tensor(shape, dtype, std::move(data), std::move(transpose), std::move(scale_inv));
+  const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+  at::Tensor scale_inv = at::empty(std::vector<int64_t>{1}, opts);
+  return create_tensor(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv));
 }
 
 std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
@@ -130,43 +120,44 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
   using namespace pybind11::literals;
 
   // Initialize data tensor
-  at::Tensor data_tensor;
   const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported();
-  if (with_data) {
-    NVTE_CHECK(data, "Constructing Float8Tensor, but no FP8 data was provided");
-    data_tensor = std::move(*data);
+  if (with_data && !data) {
+    const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
+    const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
+    data = at::empty(shape_int64, opts);
+  } else if (!with_data && data) {
+    data.reset();
   }
-  py::object data_py = with_data ?
py::cast(*data) : py::none(); // Initialize transpose tensor - at::Tensor transpose_tensor; const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); - if (with_transpose) { - NVTE_CHECK( - transpose, - "Constructing Float8Tensor with column-wise usage, but no FP8 transpose was provided"); - transpose_tensor = std::move(*transpose); + if (with_transpose && !transpose) { + const auto transpose_shape = make_transpose_shape(shape); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + transpose = at::empty(transpose_shape, opts); + } else if (!with_transpose && transpose) { + transpose.reset(); } - py::object transpose_py = with_transpose ? py::cast(transpose_tensor) : py::none(); + py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); // Initialize scale-inverse tensor if (!scale_inv) { scale_inv = at::reciprocal(scale); } - auto& scale_inv_tensor = *scale_inv; // Construct Python FP8 tensor py::object out_py; if (internal) { py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); - out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, + out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, "quantizer"_a = this->quantizer); } else { py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); const std::vector shape_int64(shape.begin(), shape.end()); out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, + "data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, "quantizer"_a = this->quantizer); } @@ -174,20 +165,14 @@ std::pair Float8Quantizer::create_tensor( // Construct C++ FP8 tensor TensorWrapper out_cpp(this->get_scaling_mode()); if (with_data) { - out_cpp.set_rowwise_data(data_tensor.data_ptr(), this->dtype, shape); - out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + out_cpp.set_rowwise_data(data->data_ptr(), this->dtype, shape); + out_cpp.set_rowwise_scale_inv(scale_inv->data_ptr(), DType::kFloat32, std::vector{1}); } if (with_transpose) { - std::vector transpose_shape; - if (shape.size() > 0) { - transpose_shape.push_back(shape.back()); - for (size_t i = 0; i < shape.size() - 1; ++i) { - transpose_shape.push_back(shape[i]); - } - } - out_cpp.set_columnwise_data(transpose_tensor.data_ptr(), this->dtype, transpose_shape); - out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + const auto transpose_shape = make_transpose_shape(shape); + out_cpp.set_columnwise_data(transpose->data_ptr(), this->dtype, transpose_shape); + out_cpp.set_columnwise_scale_inv(scale_inv->data_ptr(), DType::kFloat32, std::vector{1}); } this->set_quantization_params(&out_cpp); @@ -256,15 +241,9 @@ std::pair Float8Quantizer::coerce_tensor(py::object t transpose_py = py::none(); tensor.attr("_transpose") = transpose_py; } else if (!has_transpose && need_transpose) { - std::vector transpose_shape_int64; - if (shape.size() > 0) { - transpose_shape_int64.push_back(shape.back()); - for (size_t i = 0; i < shape.size(); ++i) { - transpose_shape_int64.push_back(shape[i]); - } - } + const auto transpose_shape = make_transpose_shape(shape); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - transpose_tensor = at::empty(transpose_shape_int64, opts); + 
transpose_tensor = at::empty(transpose_shape, opts); transpose_py = py::cast(transpose_tensor); tensor.attr("_transpose") = transpose_py; } @@ -282,13 +261,7 @@ std::pair Float8Quantizer::coerce_tensor(py::object t std::vector{1}); } if (transpose_tensor) { - std::vector transpose_shape; - if (shape.size() > 0) { - transpose_shape.push_back(shape.back()); - for (size_t i = 0; i < shape.size() - 1; ++i) { - transpose_shape.push_back(shape[i]); - } - } + const auto transpose_shape = make_transpose_shape(shape); out_cpp.set_columnwise_data(transpose_tensor->data_ptr(), this->dtype, transpose_shape); out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, std::vector{1}); @@ -375,15 +348,9 @@ std::pair Float8CurrentScalingQuantizer::create_tenso at::Tensor transpose_tensor; const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); if (with_transpose) { - std::vector transpose_shape_int64; - if (shape.size() > 0) { - transpose_shape_int64.push_back(shape.back()); - for (size_t i = 0; i < shape.size() - 1; ++i) { - transpose_shape_int64.push_back(shape[i]); - } - } + const auto transpose_shape = make_transpose_shape(shape); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - transpose_tensor = at::empty(transpose_shape_int64, opts); + transpose_tensor = at::empty(transpose_shape, opts); } // Initialize scale-inverse tensor @@ -420,13 +387,7 @@ std::pair Float8CurrentScalingQuantizer::create_tenso std::vector{1}); } if (with_transpose) { - std::vector transpose_shape; - if (shape.size() > 0) { - transpose_shape.push_back(shape.back()); - for (size_t i = 0; i < shape.size() - 1; ++i) { - transpose_shape.push_back(shape[i]); - } - } + const auto transpose_shape = make_transpose_shape(shape); out_cpp.set_columnwise_data(transpose_tensor.data_ptr(), this->dtype, transpose_shape); out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, std::vector{1}); @@ -499,15 +460,9 @@ std::pair Float8CurrentScalingQuantizer::coerce_tenso transpose_py = py::none(); tensor.attr("_transpose") = transpose_py; } else if (!has_transpose && need_transpose) { - std::vector transpose_shape_int64; - if (shape.size() > 0) { - transpose_shape_int64.push_back(shape.back()); - for (size_t i = 0; i < shape.size(); ++i) { - transpose_shape_int64.push_back(shape[i]); - } - } + const auto transpose_shape = make_transpose_shape(shape); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - transpose_tensor = at::empty(transpose_shape_int64, opts); + transpose_tensor = at::empty(transpose_shape, opts); transpose_py = py::cast(transpose_tensor); tensor.attr("_transpose") = transpose_py; } @@ -525,13 +480,7 @@ std::pair Float8CurrentScalingQuantizer::coerce_tenso std::vector{1}); } if (transpose_tensor) { - std::vector transpose_shape; - if (shape.size() > 0) { - transpose_shape.push_back(shape.back()); - for (size_t i = 0; i < shape.size() - 1; ++i) { - transpose_shape.push_back(shape[i]); - } - } + const auto transpose_shape = make_transpose_shape(shape); out_cpp.set_columnwise_data(transpose_tensor->data_ptr(), this->dtype, transpose_shape); out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, std::vector{1}); From a49cb5e875ce58fcf7d5951f39d6e7f975226db0 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sat, 19 Jul 2025 02:19:35 +0000 Subject: [PATCH 16/24] Initialize MXFP8 scales to zero Signed-off-by: Tim Moon --- transformer_engine/pytorch/csrc/quantizer.cpp | 10 +++++----- 1 
file changed, 5 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index f9da7cb2f7..b547e1cc40 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -816,13 +816,13 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve const std::vector scale_inv_shape_int64(rowwise_scale_inv_shape.begin(), rowwise_scale_inv_shape.end()); rowwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts); - rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts); + rowwise_scale_inv_tensor = at::zeros(scale_inv_shape_int64, uint8_tensor_opts); } if (columnwise_usage) { const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape.begin(), columnwise_scale_inv_shape.end()); columnwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts); - columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts); + columnwise_scale_inv_tensor = at::zeros(scale_inv_shape_int64, uint8_tensor_opts); } // Construct Python MXFP8 tensor @@ -883,7 +883,7 @@ std::pair MXFP8Quantizer::coerce_tensor(py::object te auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv"); auto columnwise_data = get_tensor("_columnwise_data"); auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv"); - NVTE_CHECK(!rowwise_data || !columnwise_data, "MXFP8Tensor has no data."); + NVTE_CHECK(rowwise_data || columnwise_data, "MXFP8Tensor has no data."); // Tensor dimensions std::vector shape; @@ -911,7 +911,7 @@ std::pair MXFP8Quantizer::coerce_tensor(py::object te const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), scale_inv_shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts); + rowwise_scale_inv = at::zeros(scale_inv_shape_int64, opts); tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv; } } else { // rowwise_usage == false @@ -938,7 +938,7 @@ std::pair MXFP8Quantizer::coerce_tensor(py::object te const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), scale_inv_shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - columnwise_scale_inv = at::empty(scale_inv_shape_int64, opts); + columnwise_scale_inv = at::zeros(scale_inv_shape_int64, opts); tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv; } } else { // columnwise_usage == false From ba6867690c26b8d55f972faf57e35dc17e30f1a1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 19 Jul 2025 02:20:18 +0000 Subject: [PATCH 17/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/csrc/extensions/attention.cpp | 12 ++++++++---- transformer_engine/pytorch/csrc/quantizer.cpp | 7 +++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index cbe1fd25f4..6d835a5c94 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -109,8 +109,10 @@ std::vector fused_attn_fwd( auto *S_quantizer_fp8 = dynamic_cast(S_quantizer.get()); NVTE_CHECK(O_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); - 
std::tie(te_O, o_python) = O_quantizer_fp8->create_tensor(o_shape, fake_dtype_te, std::nullopt, std::nullopt, std::nullopt); - std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, std::nullopt, std::nullopt); + std::tie(te_O, o_python) = O_quantizer_fp8->create_tensor(o_shape, fake_dtype_te, std::nullopt, + std::nullopt, std::nullopt); + std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, + std::nullopt, std::nullopt); } else { std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te); std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); @@ -300,8 +302,10 @@ std::vector fused_attn_bwd( auto *dP_quantizer_fp8 = dynamic_cast(dP_quantizer.get()); NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); NVTE_CHECK(dP_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); - std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, std::nullopt, std::nullopt); - std::tie(te_dP, dp_python) = dP_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, std::nullopt, std::nullopt); + std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, + std::nullopt, std::nullopt); + std::tie(te_dP, dp_python) = dP_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, + std::nullopt, std::nullopt); } else { std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index b547e1cc40..b84b1b0f8d 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -19,8 +19,8 @@ namespace { * The tensor is interpreted as a 2D matrix by flattening all but the * last dimension, and then transposed. 
*/ -template -std::vector make_transpose_shape(const std::vector &shape) { +template +std::vector make_transpose_shape(const std::vector& shape) { std::vector ret; if (shape.size() > 0) { ret.push_back(shape.back()); @@ -166,8 +166,7 @@ std::pair Float8Quantizer::create_tensor( TensorWrapper out_cpp(this->get_scaling_mode()); if (with_data) { out_cpp.set_rowwise_data(data->data_ptr(), this->dtype, shape); - out_cpp.set_rowwise_scale_inv(scale_inv->data_ptr(), DType::kFloat32, - std::vector{1}); + out_cpp.set_rowwise_scale_inv(scale_inv->data_ptr(), DType::kFloat32, std::vector{1}); } if (with_transpose) { const auto transpose_shape = make_transpose_shape(shape); From 76d2d53c55609b30359204651bf829c4d08dbaae Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 21 Jul 2025 23:12:38 +0000 Subject: [PATCH 18/24] Store copy of quantizer when creating quantized tensors Signed-off-by: Tim Moon --- transformer_engine/pytorch/csrc/quantizer.cpp | 49 ++++++++++++------- 1 file changed, 32 insertions(+), 17 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index b84b1b0f8d..311dfc0c75 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -146,20 +146,24 @@ std::pair Float8Quantizer::create_tensor( scale_inv = at::reciprocal(scale); } + // Make shallow copy of quantizer so in-place ops aren't influenced + // by future usage changes + auto quantizer_py = this->quantizer.attr("copy")(); + // Construct Python FP8 tensor py::object out_py; if (internal) { py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + "quantizer"_a = quantizer_py); } else { py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); const std::vector shape_int64(shape.begin(), shape.end()); out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), "data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + "quantizer"_a = quantizer_py); } // Construct C++ FP8 tensor @@ -250,7 +254,6 @@ std::pair Float8Quantizer::coerce_tensor(py::object t // Coerce other attrs tensor.attr("_fp8_dtype") = dtype; - tensor.attr("_quantizer") = quantizer; /// TODO Need to make copy? // Construct C++ FP8 tensor TensorWrapper out_cpp; @@ -360,6 +363,10 @@ std::pair Float8CurrentScalingQuantizer::create_tenso scale_inv_tensor = at::empty(scale_inv_shape, opts); } + // Make shallow copy of quantizer so in-place ops aren't influenced + // by future usage changes + auto quantizer_py = this->quantizer.attr("copy")(); + // Construct Python FP8 tensor py::object out_py; py::object data_py = with_data ? 
py::cast(data_tensor) : py::none(); @@ -368,14 +375,14 @@ std::pair Float8CurrentScalingQuantizer::create_tenso py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + "quantizer"_a = quantizer_py); } else { py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); const std::vector shape_int64(shape.begin(), shape.end()); out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), "data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + "quantizer"_a = quantizer_py); } // Construct C++ FP8 tensor @@ -469,7 +476,6 @@ std::pair Float8CurrentScalingQuantizer::coerce_tenso // Coerce other attrs tensor.attr("_fp8_dtype") = dtype; - tensor.attr("_quantizer") = quantizer; /// TODO Need to make copy? // Construct C++ FP8 tensor TensorWrapper out_cpp; @@ -618,6 +624,10 @@ std::pair Float8BlockQuantizer::create_tensor( } this->set_quantization_params(&tensor); + // Make shallow copy of quantizer so in-place ops aren't influenced + // by future usage changes + auto quantizer_py = this->quantizer.attr("copy")(); + py::object ret; if (internal) { py::handle Float8BlockwiseQTensorClass( @@ -625,7 +635,7 @@ std::pair Float8BlockQuantizer::create_tensor( ret = Float8BlockwiseQTensorClass( "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer, + "fp8_dtype"_a = this->dtype, "quantizer"_a = quantizer_py, "is_2D_scaled"_a = (block_scaling_dim == 2), "data_format"_a = data_format); } else { py::handle Float8BlockwiseQTensorClass( @@ -634,7 +644,7 @@ std::pair Float8BlockQuantizer::create_tensor( "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype, - "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2), + "quantizer"_a = quantizer_py, "is_2D_scaled"_a = (block_scaling_dim == 2), "data_format"_a = data_format); } @@ -824,22 +834,28 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve columnwise_scale_inv_tensor = at::zeros(scale_inv_shape_int64, uint8_tensor_opts); } - // Construct Python MXFP8 tensor - py::object out_py; + // Convert tensors to Python auto py_cast = [](at::Tensor& tensor, bool need_cast) -> py::object { return need_cast ? 
py::cast(tensor) : py::none(); }; - py::object rowwise_data_py = py_cast(rowwise_data_tensor, rowwise_usage); - py::object rowwise_scale_inv_py = py_cast(rowwise_scale_inv_tensor, rowwise_usage); - py::object columnwise_data_py = py_cast(columnwise_data_tensor, columnwise_usage); - py::object columnwise_scale_inv_py = py_cast(columnwise_scale_inv_tensor, columnwise_usage); + auto rowwise_data_py = py_cast(rowwise_data_tensor, rowwise_usage); + auto rowwise_scale_inv_py = py_cast(rowwise_scale_inv_tensor, rowwise_usage); + auto columnwise_data_py = py_cast(columnwise_data_tensor, columnwise_usage); + auto columnwise_scale_inv_py = py_cast(columnwise_scale_inv_tensor, columnwise_usage); + + // Make shallow copy of quantizer so in-place ops aren't influenced + // by future usage changes + auto quantizer_py = this->quantizer.attr("copy")(); + + // Construct Python MXFP8 tensor + py::object out_py; if (internal) { py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorBasePythonClass)); out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, "rowwise_scale_inv"_a = rowwise_scale_inv_py, "columnwise_scale_inv"_a = columnwise_scale_inv_py, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + "fp8_dtype"_a = this->dtype, "quantizer"_a = quantizer_py); } else { py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorPythonClass)); out_py = MXFP8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), @@ -847,7 +863,7 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve "columnwise_data"_a = columnwise_data_py, "rowwise_scale_inv"_a = rowwise_scale_inv_py, "columnwise_scale_inv"_a = columnwise_scale_inv_py, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + "fp8_dtype"_a = this->dtype, "quantizer"_a = quantizer_py); } // Construct C++ MXFP8 tensor @@ -953,7 +969,6 @@ std::pair MXFP8Quantizer::coerce_tensor(py::object te // Coerce other attrs tensor.attr("_fp8_dtype") = dtype; - tensor.attr("_quantizer") = quantizer; /// TODO Need to make copy? // Construct C++ MXFP8 tensor TensorWrapper out_cpp(NVTE_MXFP8_1D_SCALING); From c54d821877a79a6f9aff1e1e71bc4126f49c26e6 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 21 Jul 2025 23:15:26 +0000 Subject: [PATCH 19/24] Fix linter warnings Signed-off-by: Tim Moon --- transformer_engine/pytorch/ops/basic/basic_linear.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 5766aaaab9..05e8170c63 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -23,8 +23,6 @@ from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD from ...tensor import Quantizer from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer -from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer -from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...tensor._internal.float8_tensor_base import Float8TensorBase from ..op import BasicOperation, OperationContext from .._common import maybe_dequantize, is_quantized_tensor From c252dc0ba5dfb99874f5d6ee9f21e47629ed419c Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 22 Jul 2025 06:36:28 +0000 Subject: [PATCH 20/24] Make sure quantized tensors have private quantizer Avoid problems with in-place ops after quantizer usages are changed externally. 
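For illustration, a minimal sketch of the hazard (hypothetical usage; assumes the
existing Float8Quantizer constructor, make_empty, and set_usage Python APIs, and
the private _quantizer attribute touched by this change):

    import torch
    import transformer_engine_torch as tex
    from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

    # Build a quantizer and allocate an FP8 tensor from it
    quantizer = Float8Quantizer(
        scale=torch.ones(1, dtype=torch.float32, device="cuda"),
        amax=torch.zeros(1, dtype=torch.float32, device="cuda"),
        fp8_dtype=tex.DType.kFloat8E4M3,
    )
    x = quantizer.make_empty((128, 128), dtype=torch.float32, device="cuda")

    # An external usage change must not leak into tensors created earlier;
    # with a private copy, in-place ops on x keep the original usages.
    quantizer.set_usage(rowwise=True, columnwise=False)
    assert x._quantizer is not quantizer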
Signed-off-by: Tim Moon --- transformer_engine/pytorch/csrc/quantizer.cpp | 32 +++++-------------- .../_internal/float8_blockwise_tensor_base.py | 2 +- .../tensor/_internal/float8_tensor_base.py | 2 +- .../tensor/_internal/mxfp8_tensor_base.py | 2 +- .../pytorch/tensor/float8_blockwise_tensor.py | 2 +- .../pytorch/tensor/float8_tensor.py | 2 +- .../pytorch/tensor/mxfp8_tensor.py | 2 +- 7 files changed, 14 insertions(+), 30 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 311dfc0c75..2558cafbf8 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -146,24 +146,20 @@ std::pair Float8Quantizer::create_tensor( scale_inv = at::reciprocal(scale); } - // Make shallow copy of quantizer so in-place ops aren't influenced - // by future usage changes - auto quantizer_py = this->quantizer.attr("copy")(); - // Construct Python FP8 tensor py::object out_py; if (internal) { py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = quantizer_py); + "quantizer"_a = this->quantizer); } else { py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); const std::vector shape_int64(shape.begin(), shape.end()); out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), "data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = quantizer_py); + "quantizer"_a = this->quantizer); } // Construct C++ FP8 tensor @@ -363,10 +359,6 @@ std::pair Float8CurrentScalingQuantizer::create_tenso scale_inv_tensor = at::empty(scale_inv_shape, opts); } - // Make shallow copy of quantizer so in-place ops aren't influenced - // by future usage changes - auto quantizer_py = this->quantizer.attr("copy")(); - // Construct Python FP8 tensor py::object out_py; py::object data_py = with_data ? 
py::cast(data_tensor) : py::none(); @@ -375,14 +367,14 @@ std::pair Float8CurrentScalingQuantizer::create_tenso py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = quantizer_py); + "quantizer"_a = this->quantizer); } else { py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); const std::vector shape_int64(shape.begin(), shape.end()); out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), "data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = quantizer_py); + "quantizer"_a = this->quantizer); } // Construct C++ FP8 tensor @@ -624,10 +616,6 @@ std::pair Float8BlockQuantizer::create_tensor( } this->set_quantization_params(&tensor); - // Make shallow copy of quantizer so in-place ops aren't influenced - // by future usage changes - auto quantizer_py = this->quantizer.attr("copy")(); - py::object ret; if (internal) { py::handle Float8BlockwiseQTensorClass( @@ -635,7 +623,7 @@ std::pair Float8BlockQuantizer::create_tensor( ret = Float8BlockwiseQTensorClass( "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, - "fp8_dtype"_a = this->dtype, "quantizer"_a = quantizer_py, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2), "data_format"_a = data_format); } else { py::handle Float8BlockwiseQTensorClass( @@ -644,7 +632,7 @@ std::pair Float8BlockQuantizer::create_tensor( "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype, - "quantizer"_a = quantizer_py, "is_2D_scaled"_a = (block_scaling_dim == 2), + "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2), "data_format"_a = data_format); } @@ -843,10 +831,6 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve auto columnwise_data_py = py_cast(columnwise_data_tensor, columnwise_usage); auto columnwise_scale_inv_py = py_cast(columnwise_scale_inv_tensor, columnwise_usage); - // Make shallow copy of quantizer so in-place ops aren't influenced - // by future usage changes - auto quantizer_py = this->quantizer.attr("copy")(); - // Construct Python MXFP8 tensor py::object out_py; if (internal) { @@ -855,7 +839,7 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve "columnwise_data"_a = columnwise_data_py, "rowwise_scale_inv"_a = rowwise_scale_inv_py, "columnwise_scale_inv"_a = columnwise_scale_inv_py, - "fp8_dtype"_a = this->dtype, "quantizer"_a = quantizer_py); + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); } else { py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorPythonClass)); out_py = MXFP8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), @@ -863,7 +847,7 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve "columnwise_data"_a = columnwise_data_py, "rowwise_scale_inv"_a = rowwise_scale_inv_py, "columnwise_scale_inv"_a = columnwise_scale_inv_py, - "fp8_dtype"_a = this->dtype, "quantizer"_a = quantizer_py); + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); } // Construct C++ MXFP8 tensor diff 
--git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index 882650ffba..787c322a0c 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -59,7 +59,7 @@ def __new__( instance = super().__new__(cls, *args, **kwargs) instance._rowwise_data = rowwise_data instance._columnwise_data = columnwise_data - instance._quantizer = quantizer + instance._quantizer = quantizer.copy() if quantizer is not None else None instance._fp8_dtype = fp8_dtype instance._rowwise_scale_inv = rowwise_scale_inv instance._columnwise_scale_inv = columnwise_scale_inv diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py index c0dc6e6519..a88ae33f09 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -86,7 +86,7 @@ def __new__( else: instance = super().__new__(cls, *args, **kwargs) instance._data = data - instance._quantizer = quantizer + instance._quantizer = quantizer.copy() if quantizer is not None else None instance._fp8_dtype = fp8_dtype instance._scale_inv = fp8_scale_inv instance._transpose = data_transpose diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py index 8f87e5c73d..a093904bc9 100644 --- a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py @@ -83,7 +83,7 @@ def __new__( instance = super().__new__(cls, *args, **kwargs) instance._rowwise_data = rowwise_data instance._columnwise_data = columnwise_data - instance._quantizer = quantizer + instance._quantizer = quantizer.copy() if quantizer is not None else None instance._fp8_dtype = fp8_dtype instance._rowwise_scale_inv = rowwise_scale_inv instance._columnwise_scale_inv = columnwise_scale_inv diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index bac7159491..0e41fc9c51 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -521,7 +521,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor): dst._rowwise_data = src._rowwise_data dst._columnwise_data = src._columnwise_data - dst._quantizer = src._quantizer + dst._quantizer = src._quantizer.copy() dst._fp8_dtype = src._fp8_dtype dst._rowwise_scale_inv = src._rowwise_scale_inv dst._columnwise_scale_inv = src._columnwise_scale_inv diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 4c57ffd3e5..895e68bf02 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -689,7 +689,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Float8Tensor attributes self._data = tensor._data - self._quantizer = tensor._quantizer + self._quantizer = tensor._quantizer.copy() self._fp8_dtype = tensor._fp8_dtype self._scale_inv = tensor._scale_inv self._transpose = tensor._transpose diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py 
b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 10b587e17e..b96575d37b 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -433,7 +433,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: super(MXFP8Tensor, type(self)).data.__set__(self, dummy_tensor) self._rowwise_data = tensor._rowwise_data self._columnwise_data = tensor._columnwise_data - self._quantizer = tensor._quantizer + self._quantizer = tensor._quantizer.copy() self._fp8_dtype = tensor._fp8_dtype self._rowwise_scale_inv = tensor._rowwise_scale_inv self._columnwise_scale_inv = tensor._columnwise_scale_inv From df6313c7e43f6785bf1e3be8446d8ca41594414b Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 22 Jul 2025 21:36:23 +0000 Subject: [PATCH 21/24] Rename "coerce_tensor" to "convert_and_update_tensor" Signed-off-by: Tim Moon --- transformer_engine/pytorch/csrc/common.h | 14 +++++++------- .../pytorch/csrc/extensions/cast.cpp | 2 +- transformer_engine/pytorch/csrc/quantizer.cpp | 10 +++++----- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 078a96c113..5c0fff6b06 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -102,12 +102,12 @@ class Quantizer { virtual std::pair create_tensor(const std::vector& shape, DType dtype) const = 0; - /*! @brief Load a PyTorch tensor + /*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor * * The PyTorch tensor's attributes are modified to match the * quantizer's configuration. */ - virtual std::pair coerce_tensor(py::object tensor) const = 0; + virtual std::pair convert_and_update_tensor(py::object tensor) const = 0; /*! 
@brief Convert to a quantized data format */
   virtual void quantize(const TensorWrapper& input, TensorWrapper& out,
                         const std::optional<at::Tensor>& noop_flag = std::nullopt) = 0;
@@ -139,7 +139,7 @@ class NoneQuantizer : public Quantizer {
   std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
                                                      at::Tensor data) const;
 
-  std::pair<TensorWrapper, py::object> coerce_tensor(py::object tensor) const override;
+  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object tensor) const override;
 
   void quantize(const TensorWrapper& input, TensorWrapper& out,
                 const std::optional<at::Tensor>& noop_flag = std::nullopt) override;
@@ -167,7 +167,7 @@ class Float8Quantizer : public Quantizer {
                                                      std::optional<at::Tensor> transpose,
                                                      std::optional<at::Tensor> scale_inv) const;
 
-  std::pair<TensorWrapper, py::object> coerce_tensor(py::object shape) const override;
+  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
 
   void quantize(const TensorWrapper& input, TensorWrapper& out,
                 const std::optional<at::Tensor>& noop_flag = std::nullopt) override;
@@ -193,7 +193,7 @@ class Float8CurrentScalingQuantizer : public Quantizer {
   std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                      DType dtype) const override;
 
-  std::pair<TensorWrapper, py::object> coerce_tensor(py::object shape) const override;
+  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
 
   void quantize(const TensorWrapper& input, TensorWrapper& out,
                 const std::optional<at::Tensor>& noop_flag = std::nullopt) override;
@@ -231,7 +231,7 @@ class Float8BlockQuantizer : public Quantizer {
   std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                      DType dtype) const override;
 
-  std::pair<TensorWrapper, py::object> coerce_tensor(py::object shape) const override;
+  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
 
   void quantize(const TensorWrapper& input, TensorWrapper& out,
                 const std::optional<at::Tensor>& noop_flag = std::nullopt) override;
@@ -252,7 +252,7 @@ class MXFP8Quantizer : public Quantizer {
   std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
                                                      DType dtype) const override;
 
-  std::pair<TensorWrapper, py::object> coerce_tensor(py::object shape) const override;
+  std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
 
   void quantize(const TensorWrapper& input, TensorWrapper& out,
                 const std::optional<at::Tensor>& noop_flag = std::nullopt) override;
diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp
index da054ae99a..5408cf1a6b 100644
--- a/transformer_engine/pytorch/csrc/extensions/cast.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -47,7 +47,7 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
     const auto fake_dtype = input_cpp.dtype();
     std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(shape, fake_dtype);
   } else {
-    std::tie(output_cpp, output_py) = quantizer_cpp->coerce_tensor(output);
+    std::tie(output_cpp, output_py) = quantizer_cpp->convert_and_update_tensor(output);
   }
 
   // Initialize no-op flag
diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp
index 2558cafbf8..fcf743c85a 100644
--- a/transformer_engine/pytorch/csrc/quantizer.cpp
+++ b/transformer_engine/pytorch/csrc/quantizer.cpp
@@ -74,7 +74,7 @@ std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(const std::vec
   return {std::move(out_cpp), py::cast(data)};
 }
 
-std::pair<TensorWrapper, py::object> NoneQuantizer::coerce_tensor(py::object tensor) const {
+std::pair<TensorWrapper, py::object> NoneQuantizer::convert_and_update_tensor(py::object tensor) const {
   auto tensor_pyt = tensor.cast<at::Tensor>();
   TensorWrapper out_cpp;
   out_cpp.set_rowwise_data(tensor_pyt.data_ptr(),
@@ -179,7 +179,7 @@ std::pair
Float8Quantizer::coerce_tensor(py::object tensor) const { +std::pair Float8Quantizer::convert_and_update_tensor(py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor."); // Expected buffers @@ -395,7 +395,7 @@ std::pair Float8CurrentScalingQuantizer::create_tenso return {std::move(out_cpp), std::move(out_py)}; } -std::pair Float8CurrentScalingQuantizer::coerce_tensor( +std::pair Float8CurrentScalingQuantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8CurrentScalingQuantizer must output to Float8Tensor."); @@ -639,7 +639,7 @@ std::pair Float8BlockQuantizer::create_tensor( return {std::move(tensor), std::move(ret)}; } -std::pair Float8BlockQuantizer::coerce_tensor(py::object tensor) const { +std::pair Float8BlockQuantizer::convert_and_update_tensor(py::object tensor) const { const DType dtype = tensor.attr("_fp8_dtype").cast(); bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast(); @@ -867,7 +867,7 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve return {std::move(out_cpp), std::move(out_py)}; } -std::pair MXFP8Quantizer::coerce_tensor(py::object tensor) const { +std::pair MXFP8Quantizer::convert_and_update_tensor(py::object tensor) const { NVTE_CHECK(detail::IsMXFP8Tensor(tensor.ptr()), "MXFP8Quantizer must output to MXFP8Tensor."); // Extract buffers from Python tensor From 27cf92a4d099f70dd0664a5761456386231feb3f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 22 Jul 2025 21:37:01 +0000 Subject: [PATCH 22/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/common.h | 3 ++- transformer_engine/pytorch/csrc/quantizer.cpp | 12 ++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 5c0fff6b06..be3b995a13 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -107,7 +107,8 @@ class Quantizer { * The PyTorch tensor's attributes are modified to match the * quantizer's configuration. */ - virtual std::pair convert_and_update_tensor(py::object tensor) const = 0; + virtual std::pair convert_and_update_tensor( + py::object tensor) const = 0; /*! 
@brief Convert to a quantized data format */ virtual void quantize(const TensorWrapper& input, TensorWrapper& out, diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index fcf743c85a..a7b7f58891 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -74,7 +74,8 @@ std::pair NoneQuantizer::create_tensor(const std::vec return {std::move(out_cpp), py::cast(data)}; } -std::pair NoneQuantizer::convert_and_update_tensor(py::object tensor) const { +std::pair NoneQuantizer::convert_and_update_tensor( + py::object tensor) const { auto tensor_pyt = tensor.cast(); TensorWrapper out_cpp; out_cpp.set_rowwise_data(tensor_pyt.data_ptr(), @@ -179,7 +180,8 @@ std::pair Float8Quantizer::create_tensor( return {std::move(out_cpp), std::move(out_py)}; } -std::pair Float8Quantizer::convert_and_update_tensor(py::object tensor) const { +std::pair Float8Quantizer::convert_and_update_tensor( + py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor."); // Expected buffers @@ -639,7 +641,8 @@ std::pair Float8BlockQuantizer::create_tensor( return {std::move(tensor), std::move(ret)}; } -std::pair Float8BlockQuantizer::convert_and_update_tensor(py::object tensor) const { +std::pair Float8BlockQuantizer::convert_and_update_tensor( + py::object tensor) const { const DType dtype = tensor.attr("_fp8_dtype").cast(); bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast(); @@ -867,7 +870,8 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve return {std::move(out_cpp), std::move(out_py)}; } -std::pair MXFP8Quantizer::convert_and_update_tensor(py::object tensor) const { +std::pair MXFP8Quantizer::convert_and_update_tensor( + py::object tensor) const { NVTE_CHECK(detail::IsMXFP8Tensor(tensor.ptr()), "MXFP8Quantizer must output to MXFP8Tensor."); // Extract buffers from Python tensor From 3e7dbb1be6c1a17f3d08f80e2538f591eb9b5c88 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 24 Jul 2025 05:41:01 +0000 Subject: [PATCH 23/24] Make sure CUDA context is available when launching NVRTC kernel Signed-off-by: Tim Moon --- transformer_engine/common/util/cuda_driver.cpp | 13 +++++++++++++ transformer_engine/common/util/cuda_driver.h | 8 ++++++++ transformer_engine/common/util/rtc.h | 1 + 3 files changed, 22 insertions(+) diff --git a/transformer_engine/common/util/cuda_driver.cpp b/transformer_engine/common/util/cuda_driver.cpp index 59d490e58e..4812435f7b 100644 --- a/transformer_engine/common/util/cuda_driver.cpp +++ b/transformer_engine/common/util/cuda_driver.cpp @@ -44,6 +44,19 @@ void *get_symbol(const char *symbol, int cuda_version) { return entry_point; } +void ensure_context_exists() { + CUcontext context; + NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxGetCurrent, &context); + if (context == nullptr) { + // Add primary context to context stack + CUdevice device; + NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &device, cuda::current_device()); + NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &context, device); + NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context); + NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRelease, device); + } +} + } // namespace cuda_driver } // namespace transformer_engine diff --git a/transformer_engine/common/util/cuda_driver.h b/transformer_engine/common/util/cuda_driver.h index a0fcd65c85..3425e0af35 100644 --- a/transformer_engine/common/util/cuda_driver.h +++ b/transformer_engine/common/util/cuda_driver.h @@ -39,6 +39,14 
@@ inline CUresult call(const char *symbol, ArgTs... args) {
   return (*func)(args...);
 }
 
+/*! \brief Ensure that the calling thread has a CUDA context
+ *
+ * Each thread maintains a stack of CUDA contexts. If the calling
+ * thread has an empty stack, the primary context is added to the
+ * stack.
+ */
+void ensure_context_exists();
+
 }  // namespace cuda_driver
 
 }  // namespace transformer_engine
diff --git a/transformer_engine/common/util/rtc.h b/transformer_engine/common/util/rtc.h
index 820b16c206..7de1e4d55c 100644
--- a/transformer_engine/common/util/rtc.h
+++ b/transformer_engine/common/util/rtc.h
@@ -59,6 +59,7 @@ class Kernel {
   template <typename... ArgTs>
   void launch(int device_id, const dim3 grid_dim, const dim3 block_dim,
               unsigned int shared_mem_bytes, cudaStream_t stream, ArgTs &&...args) {
+    cuda_driver::ensure_context_exists();
     void *arg_ptrs[] = {const_cast<void *>(static_cast<const void *>(&args))...};
     NVTE_CALL_CHECK_CUDA_DRIVER(cuLaunchKernel, get_function(device_id), grid_dim.x, grid_dim.y,
                                 grid_dim.z, block_dim.x, block_dim.y, block_dim.z, shared_mem_bytes,

From 261f60f138e5baae263ab170b4ab41abfa20fffd Mon Sep 17 00:00:00 2001
From: Tim Moon
Date: Thu, 24 Jul 2025 19:28:39 +0000
Subject: [PATCH 24/24] Expose CUDA context creation function externally

Signed-off-by: Tim Moon
---
 transformer_engine/common/libtransformer_engine.version | 1 +
 1 file changed, 1 insertion(+)

diff --git a/transformer_engine/common/libtransformer_engine.version b/transformer_engine/common/libtransformer_engine.version
index 4412d0c5fe..706c237ccc 100644
--- a/transformer_engine/common/libtransformer_engine.version
+++ b/transformer_engine/common/libtransformer_engine.version
@@ -8,6 +8,7 @@
     transformer_engine::cuda::stream_priority_range*;
     transformer_engine::cuda::current_device*;
    transformer_engine::cuda_driver::get_symbol*;
+    transformer_engine::cuda_driver::ensure_context_exists*;
     transformer_engine::ubuf_built_with_mpi*;
     *transformer_engine::rtc*;
     transformer_engine::nvte_cudnn_handle_init*;