From 1d7776fdfe186a7252496e82e3d0410acfabf936 Mon Sep 17 00:00:00 2001
From: Przemek Tredak
Date: Tue, 8 Apr 2025 15:17:53 -0700
Subject: [PATCH 01/25] Beginning of work to properly reuse the output given to quantize

Signed-off-by: Przemek Tredak
---
 tests/pytorch/test_sanity.py | 25 +++++
 transformer_engine/pytorch/csrc/common.h | 6 ++
 .../pytorch/csrc/extensions/attention.cpp | 6 +-
 .../pytorch/csrc/extensions/cast.cpp | 16 +---
 transformer_engine/pytorch/csrc/quantizer.cpp | 96 ++++++++++++++-----
 5 files changed, 111 insertions(+), 38 deletions(-)

diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py
index 2ca133e77b..0fc6b34779 100644
--- a/tests/pytorch/test_sanity.py
+++ b/tests/pytorch/test_sanity.py
@@ -395,6 +395,31 @@ def _test_sanity_common(
     loss.backward()
     torch.cuda.synchronize()

+    # now try eval with weight caching
+    block.eval()
+
+    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+        te_out = block(te_inp, is_first_microbatch=True)
+    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+        te_out = block(te_inp, is_first_microbatch=False)
+    torch.cuda.synchronize()
+
+    # now try regular execution again with weight caching
+    block.train()
+
+    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+        te_out = block(te_inp, is_first_microbatch=True)
+    if isinstance(te_out, tuple):
+        te_out = te_out[0]
+    loss = te_out.sum()
+    loss.backward()
+    with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
+        te_out = block(te_inp, is_first_microbatch=False)
+    if isinstance(te_out, tuple):
+        te_out = te_out[0]
+    loss = te_out.sum()
+    loss.backward()
+    torch.cuda.synchronize()

 def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
     if skip_dgrad and skip_wgrad:
diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h
index b84b95cb23..2de60e9e8a 100644
--- a/transformer_engine/pytorch/csrc/common.h
+++ b/transformer_engine/pytorch/csrc/common.h
@@ -98,6 +98,7 @@ class Quantizer {
   virtual std::pair create_tensor(
       const std::vector& shape, DType dtype,
+      const py::object& output = py::none(),
       std::optional rowwise_data = std::nullopt) const = 0;

   virtual ~Quantizer() = default;
@@ -121,6 +122,7 @@ class NoneQuantizer : public Quantizer {
   std::pair create_tensor(
       const std::vector& shape, DType dtype,
+      const py::object& output = py::none(),
       std::optional rowwise_data = std::nullopt) const override;
 };

@@ -139,6 +141,7 @@ class Float8Quantizer : public Quantizer {
   std::pair create_tensor(
       const std::vector& shape, DType dtype,
+      const py::object& output = py::none(),
       std::optional rowwise_data = std::nullopt) const override;
 };

@@ -161,6 +164,7 @@ class Float8CurrentScalingQuantizer : public Quantizer {
   std::pair create_tensor(
       const std::vector& shape, DType dtype,
+      const py::object& output = py::none(),
       std::optional rowwise_data = std::nullopt) const override;
 };

@@ -193,6 +197,7 @@ class Float8BlockQuantizer : public Quantizer {
   // and optionally columnwise usage.
std::pair create_tensor( const std::vector& shape, DType dtype, + const py::object& output = py::none(), std::optional rowwise_data = std::nullopt) const override; }; @@ -208,6 +213,7 @@ class MXFP8Quantizer : public Quantizer { std::pair create_tensor( const std::vector& shape, DType dtype, + const py::object& output = py::none(), std::optional rowwise_data = std::nullopt) const override; }; diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index efe825f0db..b7ee42cf18 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -374,9 +374,9 @@ std::vector fused_attn_bwd( default: NVTE_ERROR("QKV layout not supported!"); } - std::tie(te_dQ, py_dQ) = dQKV_quantizer->create_tensor(q_shape, fake_dtype_te, dQ); - std::tie(te_dK, py_dK) = dQKV_quantizer->create_tensor(k_shape, fake_dtype_te, dK); - std::tie(te_dV, py_dV) = dQKV_quantizer->create_tensor(v_shape, fake_dtype_te, dV); + std::tie(te_dQ, py_dQ) = dQKV_quantizer->create_tensor(q_shape, fake_dtype_te, py::none(), dQ); + std::tie(te_dK, py_dK) = dQKV_quantizer->create_tensor(k_shape, fake_dtype_te, py::none(), dK); + std::tie(te_dV, py_dV) = dQKV_quantizer->create_tensor(v_shape, fake_dtype_te, py::none(), dV); // construct NVTE tensors if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 1edbef8cd6..8585c9f86b 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -27,15 +27,8 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob fake_tensor_type = at::kFloat; } - TensorWrapper te_output; - py::object out; - if (output.is_none()) { - DType fake_te_type = GetTransformerEngineDType(fake_tensor_type); - std::tie(te_output, out) = my_quantizer->create_tensor(input_shape, fake_te_type); - } else { - out = output; - te_output = makeTransformerEngineTensor(output, quantizer); - } + DType fake_te_type = GetTransformerEngineDType(fake_tensor_type); + auto [te_output, out] = my_quantizer->create_tensor(input_shape, fake_te_type, output); TensorWrapper te_noop; if (noop.has_value()) { @@ -101,7 +94,7 @@ py::object dequantize(const py::handle& input, transformer_engine::DType otype) const auto& shape = convertShape(input_tensor.shape()); - auto [out_tensor, out] = q.create_tensor(shape, otype); + auto [out_tensor, out] = q.create_tensor(shape, otype, none); NVTE_SCOPED_GIL_RELEASE({ nvte_dequantize(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream()); @@ -123,7 +116,8 @@ std::vector dbias_dact(const at::Tensor& grad_output, const at::Tens auto act_input_tensor = makeTransformerEngineTensor(act_input); const auto& shape = convertShape(grad_tensor.shape()); - auto [dact_tensor, dact] = my_quantizer->create_tensor(shape, act_input_tensor.dtype()); + auto [dact_tensor, dact] = my_quantizer->create_tensor(shape, act_input_tensor.dtype(), + py::none()); auto dbias_tensor = makeTransformerEngineTensor(grad_bias); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 4744d8ca92..4ab315ecd7 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -14,6 +14,18 @@ namespace transformer_engine::pytorch { constexpr size_t 
MXFP8_BLOCK_SIZE = 32; +bool tensor_is_reusable(const at::Tensor& tensor, + const std::vector& shape, + const at::TensorOptions& opts) { + const auto& tensor_shape = tensor.sizes(); + if (tensor_shape.equals(shape)) return false; + const at::TensorOptions& tensor_opts = tensor.options(); + if (opts.dtype() == tensor_opts.dtype()) return false; + if (opts.device() == tensor_opts.device()) return false; + if (opts.device_index() == tensor_opts.device_index()) return false; + return true; +} + Quantizer::Quantizer(const py::handle& quantizer) { if (quantizer.is_none()) { this->rowwise_usage = true; @@ -37,8 +49,23 @@ Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quanti this->dtype = type; } +at::Tensor create_torch_tensor(const std::vector& shape, + const at::TensorOptions& opts, + const py::object& tensor_to_reuse) { + if (!tensor_to_reuse.is_none()) { + // Reuse output + const at::Tensor temp = tensor_to_reuse.cast(); + if (tensor_is_reusable(temp, shape, opts)) { + return temp; + } + } + return at::empty(shape, opts); +} + std::pair NoneQuantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { + const std::vector& shape, DType dtype, + const py::object& output, + std::optional rowwise_data) const { at::TensorOptions opts; opts = opts.dtype(GetATenDType(dtype)).device(torch::kCUDA); std::vector torch_shape; @@ -49,7 +76,7 @@ std::pair NoneQuantizer::create_tensor( if (rowwise_data.has_value()) { ret = std::move(*rowwise_data); } else { - ret = at::empty(torch_shape, opts); + ret = create_torch_tensor(torch_shape, opts, output); } TensorWrapper tensor; @@ -60,7 +87,6 @@ std::pair NoneQuantizer::create_tensor( void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { tensor->set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()), getTensorShape(scale)); - at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), getTensorShape(amax)); auto rowwise_data = tensor->get_rowwise_data(); @@ -76,10 +102,13 @@ void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { } std::pair Float8Quantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { + const std::vector& shape, DType dtype, + const py::object& output, std::optional rowwise_data) const { using namespace pybind11::literals; std::vector rowwise_torch_shape; std::vector columnwise_torch_shape; + rowwise_torch_shape.reserve(shape.size()); + columnwise_torch_shape.reserve(shape.size()); if (!shape.empty()) { columnwise_torch_shape.emplace_back(static_cast(shape.back())); @@ -93,35 +122,55 @@ std::pair Float8Quantizer::create_tensor( at::TensorOptions opts; opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); at::Tensor data; + at::Tensor columnwise_data; + py::object py_data{py::none()}; + py::object py_columnwise_data{py::none()}; + bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + if (!output.is_none()) { + NVTE_CHECK(detail::IsFloat8Tensor(output.ptr()), + "Wrong Tensor type provided for reuse. 
", + "Expected Float8Tensor or Float8TensorBase, but got ", + py::repr(output).cast()); + py_data = output.attr("_data"); + py_columnwise_data = output.attr("_transpose"); + } if (rowwise_usage) { if (rowwise_data.has_value()) { data = std::move(*rowwise_data); } else { - data = at::empty(rowwise_torch_shape, opts); + data = create_torch_tensor(rowwise_torch_shape, opts, py_data); } } - const py::object py_data = rowwise_usage ? py::cast(data) : py::none(); - at::Tensor columnwise_data; - bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + py_data = rowwise_usage ? py::cast(data) : py::none(); if (create_transpose) { - columnwise_data = at::empty(columnwise_torch_shape, opts); + columnwise_data = create_torch_tensor(columnwise_torch_shape, opts, py_columnwise_data); } - const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none(); + py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none(); opts = opts.dtype(torch::kFloat32); // TODO: Replace with an empty tensor. at::Tensor scale_inv = at::reciprocal(scale); py::object ret; - if (internal) { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); - ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, - "quantizer"_a = this->quantizer); + if (output.is_none()) { + if (internal) { + py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); + ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, + "quantizer"_a = this->quantizer); + } else { + py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); + ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype), + "data"_a = py_data, "fp8_scale_inv"_a = scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, + "quantizer"_a = this->quantizer); + } } else { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); - ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype), - "data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, - "quantizer"_a = this->quantizer); + output.attr("_data") = py_data; + output.attr("_scale_inv") = scale_inv; + output.attr("_fp8_dtype") = this->dtype; + output.attr("_transpose") = py_columnwise_data; + output.attr("_quantizer") = this->quantizer; + output.attr("_transpose_invalid") = !py_columnwise_data.is_none(); + ret = output; } TensorWrapper tensor(this->get_scaling_mode()); if (rowwise_usage) { @@ -170,7 +219,6 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso // transfer amax and scale pointer from quantizer to output tensor (only as gpu buffer, no meaningful data in them) tensor->set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()), getTensorShape(scale)); - at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), getTensorShape(amax)); // quantize output and its transpose @@ -187,7 +235,7 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso } std::pair Float8CurrentScalingQuantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { + 
const std::vector& shape, DType dtype, const py::object& output, std::optional rowwise_data) const { using namespace pybind11::literals; std::vector rowwise_torch_shape; std::vector columnwise_torch_shape; @@ -279,7 +327,7 @@ void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const } std::pair Float8BlockQuantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { + const std::vector& shape, DType dtype, const py::object& output, std::optional rowwise_data) const { using namespace pybind11::literals; std::vector torch_shape; size_t numel = 1; @@ -405,7 +453,7 @@ void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const { } std::pair MXFP8Quantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { + const std::vector& shape, DType dtype, const py::object& output, std::optional rowwise_data) const { using namespace pybind11::literals; std::vector torch_shape; size_t numel = 1; From d207cea713f6f7e045ed0caea75ec9b319de3c66 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 8 Apr 2025 17:43:32 -0700 Subject: [PATCH 02/25] Add current scaling Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/csrc/quantizer.cpp | 56 +++++++++++++------ 1 file changed, 39 insertions(+), 17 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 4ab315ecd7..23474097b6 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -169,7 +169,7 @@ std::pair Float8Quantizer::create_tensor( output.attr("_fp8_dtype") = this->dtype; output.attr("_transpose") = py_columnwise_data; output.attr("_quantizer") = this->quantizer; - output.attr("_transpose_invalid") = !py_columnwise_data.is_none(); + output.attr("_transpose_invalid") = py_columnwise_data.is_none(); ret = output; } TensorWrapper tensor(this->get_scaling_mode()); @@ -240,6 +240,8 @@ std::pair Float8CurrentScalingQuantizer::create_tenso std::vector rowwise_torch_shape; std::vector columnwise_torch_shape; std::vector scale_inv_torch_shape = {1}; // Shape of 1 element for scale_inv + rowwise_torch_shape.reserve(shape.size()); + columnwise_torch_shape.reserve(shape.size()); if (!shape.empty()) { columnwise_torch_shape.emplace_back(static_cast(shape.back())); @@ -253,36 +255,56 @@ std::pair Float8CurrentScalingQuantizer::create_tenso at::TensorOptions opts; opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); at::Tensor data; + at::Tensor columnwise_data; + py::object py_data{py::none()}; + py::object py_columnwise_data{py::none()}; + bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + if (!output.is_none()) { + NVTE_CHECK(detail::IsFloat8Tensor(output.ptr()), + "Wrong Tensor type provided for reuse. ", + "Expected Float8Tensor or Float8TensorBase, but got ", + py::repr(output).cast()); + py_data = output.attr("_data"); + py_columnwise_data = output.attr("_transpose"); + } if (rowwise_usage) { if (rowwise_data.has_value()) { data = std::move(*rowwise_data); } else { - data = at::empty(rowwise_torch_shape, opts); + data = create_torch_tensor(rowwise_torch_shape, opts, py_data); } } - const py::object py_data = rowwise_usage ? py::cast(data) : py::none(); - at::Tensor columnwise_data; - bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + py_data = rowwise_usage ? 
py::cast(data) : py::none(); if (create_transpose) { - columnwise_data = at::empty(columnwise_torch_shape, opts); + columnwise_data = create_torch_tensor(columnwise_torch_shape, opts, py_columnwise_data); } - const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none(); + py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none(); // In current scaling, scale is not known but we initialize it with 1 to avoid division by zero. If scale is already calculated, it can be correctly set. at::Tensor scale_inv = at::reciprocal(scale); py::object ret; - if (internal) { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); - ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, - "quantizer"_a = this->quantizer); + if (output.is_none()) { + if (internal) { + py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); + ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, + "quantizer"_a = this->quantizer); + } else { + py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); + ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype), + "data"_a = py_data, "fp8_scale_inv"_a = scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, + "quantizer"_a = this->quantizer); + } } else { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); - ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype), - "data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, - "quantizer"_a = this->quantizer); + output.attr("_data") = py_data; + output.attr("_scale_inv") = scale_inv; + output.attr("_fp8_dtype") = this->dtype; + output.attr("_transpose") = py_columnwise_data; + output.attr("_quantizer") = this->quantizer; + output.attr("_transpose_invalid") = py_columnwise_data.is_none(); + ret = output; } TensorWrapper tensor(this->get_scaling_mode()); if (rowwise_usage) { From fa28e4993f2ebbc7ad01f70a3a4cb1ce049f260e Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 14 Apr 2025 16:43:05 -0700 Subject: [PATCH 03/25] Beginning of the other recipes Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/csrc/quantizer.cpp | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 23474097b6..30f234a319 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -369,11 +369,26 @@ std::pair Float8BlockQuantizer::create_tensor( size_t m_dim = numel / k_dim; constexpr size_t kBlockLen = 128; + py::object py_data{py::none()}; + py::object py_columnwise_data{py::none()}; + py::object py_scale_inv_rowwise{py::none()}; + py::object py_scale_inv_columnwise{py::none()}; + if (!output.is_none()) { + NVTE_CHECK(detail::IsFloat8BlockwiseQTensor(output.ptr()), + "Wrong Tensor type provided for reuse. 
", + "Expected Float8BlockwiseQTensor or Float8BlockwiseQTensorBase, but got ", + py::repr(output).cast()); + py_data = output.attr("_rowwise_data"); + py_columnwise_data = output.attr("_columnwise_data"); + py_scale_inv_rowwise = output.attr("_rowwise_scale_inv"); + py_scale_inv_columnwise = output.attr("_columnwise_scale_inv"); + } + if (rowwise_usage) { if (rowwise_data.has_value()) { data_rowwise = std::move(*rowwise_data); } else { - data_rowwise = at::empty(torch_shape, opts); + data_rowwise = create_torch_tensor(torch_shape, opts, py_data); } size_t sinv0 = 0; size_t sinv1 = 0; From 49ad1223cf566b3b8bf69b701f48fd209cb6bdcf Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 25 Apr 2025 17:12:14 -0700 Subject: [PATCH 04/25] Added MXFP8 and cleanup Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/csrc/common.cpp | 5 - transformer_engine/pytorch/csrc/common.h | 11 +- transformer_engine/pytorch/csrc/quantizer.cpp | 420 +++++++++--------- 3 files changed, 228 insertions(+), 208 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index f86b60f612..69f129cbc2 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -286,9 +286,4 @@ std::vector convertShape(const NVTEShape& shape) { return std::vector(shape.data, shape.data + shape.ndim); } -int roundup(const int value, const int multiple) { - assert(multiple > 0); - return ((value + multiple - 1) / multiple) * multiple; -} - } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 2de60e9e8a..3e3993f644 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -355,7 +355,16 @@ void* getDataPtr(at::Tensor tensor, int offset = 0); std::vector convertShape(const NVTEShape& shape); -int roundup(const int value, const int multiple); +template +T divup(const T value, const T multiple) { + assert(multiple > 0); + return ((value + multiple - 1) / multiple); +} + +template +T roundup(const T value, const T multiple) { + return divup(value, multiple) * multiple; +} NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 30f234a319..dd7683f13f 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -8,7 +8,6 @@ #include "common.h" #include "pybind.h" -#include "torch/torch.h" namespace transformer_engine::pytorch { @@ -49,6 +48,7 @@ Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quanti this->dtype = type; } +// Create torch tensor reusing existing data if possible at::Tensor create_torch_tensor(const std::vector& shape, const at::TensorOptions& opts, const py::object& tensor_to_reuse) { @@ -62,16 +62,26 @@ at::Tensor create_torch_tensor(const std::vector& shape, return at::empty(shape, opts); } +// Create torch tensor reusing existing data is possible +// The reused tensor is tensor_to_reuse.attr_name +at::Tensor create_torch_tensor(const std::vector& shape, + const at::TensorOptions& opts, + const py::object& tensor_to_reuse, + const std::string_view& attr_name) { + py::object tensor{py::none()}; + if (!tensor_to_reuse.is_none()) { + tensor = tensor_to_reuse.attr(attr_name.data()); + } + return create_torch_tensor(shape, opts, tensor); +} + 
std::pair NoneQuantizer::create_tensor( const std::vector& shape, DType dtype, const py::object& output, std::optional rowwise_data) const { at::TensorOptions opts; opts = opts.dtype(GetATenDType(dtype)).device(torch::kCUDA); - std::vector torch_shape; - for (auto s : shape) { - torch_shape.emplace_back(static_cast(s)); - } + std::vector torch_shape(shape.begin(), shape.end()); at::Tensor ret; if (rowwise_data.has_value()) { ret = std::move(*rowwise_data); @@ -105,86 +115,78 @@ std::pair Float8Quantizer::create_tensor( const std::vector& shape, DType dtype, const py::object& output, std::optional rowwise_data) const { using namespace pybind11::literals; - std::vector rowwise_torch_shape; - std::vector columnwise_torch_shape; - rowwise_torch_shape.reserve(shape.size()); - columnwise_torch_shape.reserve(shape.size()); - if (!shape.empty()) { - columnwise_torch_shape.emplace_back(static_cast(shape.back())); - } - for (size_t i = 0; i < shape.size(); ++i) { - if (i < shape.size() - 1) { - columnwise_torch_shape.emplace_back(static_cast(shape[i])); - } - rowwise_torch_shape.emplace_back(static_cast(shape[i])); - } at::TensorOptions opts; opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); - at::Tensor data; - at::Tensor columnwise_data; - py::object py_data{py::none()}; - py::object py_columnwise_data{py::none()}; + std::vector rowwise_torch_shape(shape.begin(), shape.end()); + + std::optional data = std::nullopt; + std::optional columnwise_data = std::nullopt; + // TODO: Replace with an empty tensor. + at::Tensor scale_inv = at::reciprocal(scale); + bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + TensorWrapper tensor(this->get_scaling_mode()); if (!output.is_none()) { NVTE_CHECK(detail::IsFloat8Tensor(output.ptr()), "Wrong Tensor type provided for reuse. ", "Expected Float8Tensor or Float8TensorBase, but got ", py::repr(output).cast()); - py_data = output.attr("_data"); - py_columnwise_data = output.attr("_transpose"); } + if (rowwise_usage) { if (rowwise_data.has_value()) { data = std::move(*rowwise_data); } else { - data = create_torch_tensor(rowwise_torch_shape, opts, py_data); + data = create_torch_tensor(rowwise_torch_shape, opts, output, "_data"); } + + tensor.set_rowwise_data(data->data_ptr(), this->dtype, shape); + tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); } - py_data = rowwise_usage ? py::cast(data) : py::none(); + if (create_transpose) { - columnwise_data = create_torch_tensor(columnwise_torch_shape, opts, py_columnwise_data); + std::vector columnwise_torch_shape; + columnwise_torch_shape.reserve(shape.size()); + if (!shape.empty()) { + columnwise_torch_shape.emplace_back(static_cast(shape.back())); + } + for (size_t i = 0; i < shape.size() - 1; ++i) { + columnwise_torch_shape.emplace_back(static_cast(shape[i])); + } + std::vector transposed_shape(columnwise_torch_shape.begin(), + columnwise_torch_shape.end()); + + columnwise_data = create_torch_tensor(columnwise_torch_shape, opts, output, "_transpose"); + tensor.set_columnwise_data(columnwise_data->data_ptr(), this->dtype, transposed_shape); + tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); } - py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none(); + opts = opts.dtype(torch::kFloat32); - // TODO: Replace with an empty tensor. 
- at::Tensor scale_inv = at::reciprocal(scale); py::object ret; if (output.is_none()) { if (internal) { py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); - ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, + ret = Float8TensorClass("data"_a = data, "fp8_scale_inv"_a = scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = columnwise_data, "quantizer"_a = this->quantizer); } else { py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype), - "data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, + "data"_a = data, "fp8_scale_inv"_a = scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = columnwise_data, "quantizer"_a = this->quantizer); } } else { - output.attr("_data") = py_data; + output.attr("_data") = data; output.attr("_scale_inv") = scale_inv; output.attr("_fp8_dtype") = this->dtype; - output.attr("_transpose") = py_columnwise_data; + output.attr("_transpose") = columnwise_data; output.attr("_quantizer") = this->quantizer; - output.attr("_transpose_invalid") = py_columnwise_data.is_none(); + output.attr("_transpose_invalid") = (columnwise_data == std::nullopt); ret = output; } - TensorWrapper tensor(this->get_scaling_mode()); - if (rowwise_usage) { - tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); - tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); - } - if (create_transpose) { - std::vector transposed_shape; - for (auto s : columnwise_torch_shape) { - transposed_shape.emplace_back(static_cast(s)); - } - tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape); - tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); - } + this->set_quantization_params(&tensor); return {std::move(tensor), std::move(ret)}; } @@ -237,88 +239,77 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso std::pair Float8CurrentScalingQuantizer::create_tensor( const std::vector& shape, DType dtype, const py::object& output, std::optional rowwise_data) const { using namespace pybind11::literals; - std::vector rowwise_torch_shape; - std::vector columnwise_torch_shape; std::vector scale_inv_torch_shape = {1}; // Shape of 1 element for scale_inv - rowwise_torch_shape.reserve(shape.size()); - columnwise_torch_shape.reserve(shape.size()); - if (!shape.empty()) { - columnwise_torch_shape.emplace_back(static_cast(shape.back())); - } - for (size_t i = 0; i < shape.size(); ++i) { - if (i < shape.size() - 1) { - columnwise_torch_shape.emplace_back(static_cast(shape[i])); - } - rowwise_torch_shape.emplace_back(static_cast(shape[i])); - } at::TensorOptions opts; opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); - at::Tensor data; - at::Tensor columnwise_data; - py::object py_data{py::none()}; - py::object py_columnwise_data{py::none()}; + std::vector rowwise_torch_shape(shape.begin(), shape.end()); + + std::optional data = std::nullopt; + std::optional columnwise_data = std::nullopt; + // In current scaling, scale is not known but we initialize it with 1 to avoid division by zero. If scale is already calculated, it can be correctly set. 
+ at::Tensor scale_inv = at::reciprocal(scale); + + TensorWrapper tensor(this->get_scaling_mode()); + bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); if (!output.is_none()) { NVTE_CHECK(detail::IsFloat8Tensor(output.ptr()), "Wrong Tensor type provided for reuse. ", "Expected Float8Tensor or Float8TensorBase, but got ", py::repr(output).cast()); - py_data = output.attr("_data"); - py_columnwise_data = output.attr("_transpose"); } if (rowwise_usage) { if (rowwise_data.has_value()) { data = std::move(*rowwise_data); } else { - data = create_torch_tensor(rowwise_torch_shape, opts, py_data); + data = create_torch_tensor(rowwise_torch_shape, opts, output, "_data"); } + + tensor.set_rowwise_data(data->data_ptr(), this->dtype, shape); + tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); } - py_data = rowwise_usage ? py::cast(data) : py::none(); + if (create_transpose) { - columnwise_data = create_torch_tensor(columnwise_torch_shape, opts, py_columnwise_data); + std::vector columnwise_torch_shape; + columnwise_torch_shape.reserve(shape.size()); + if (!shape.empty()) { + columnwise_torch_shape.emplace_back(static_cast(shape.back())); + } + for (size_t i = 0; i < shape.size() - 1; ++i) { + columnwise_torch_shape.emplace_back(static_cast(shape[i])); + } + columnwise_data = create_torch_tensor(columnwise_torch_shape, opts, output, "_transpose"); + std::vector transposed_shape(columnwise_torch_shape.begin(), + columnwise_torch_shape.end()); + tensor.set_columnwise_data(columnwise_data->data_ptr(), this->dtype, transposed_shape); + tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); } - py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none(); - - // In current scaling, scale is not known but we initialize it with 1 to avoid division by zero. If scale is already calculated, it can be correctly set. 
- at::Tensor scale_inv = at::reciprocal(scale); py::object ret; if (output.is_none()) { if (internal) { py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); - ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, + ret = Float8TensorClass("data"_a = data, "fp8_scale_inv"_a = scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = columnwise_data, "quantizer"_a = this->quantizer); } else { py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype), - "data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, + "data"_a = data, "fp8_scale_inv"_a = scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = columnwise_data, "quantizer"_a = this->quantizer); } } else { - output.attr("_data") = py_data; + output.attr("_data") = data; output.attr("_scale_inv") = scale_inv; output.attr("_fp8_dtype") = this->dtype; - output.attr("_transpose") = py_columnwise_data; + output.attr("_transpose") = columnwise_data; output.attr("_quantizer") = this->quantizer; - output.attr("_transpose_invalid") = py_columnwise_data.is_none(); + output.attr("_transpose_invalid") = (columnwise_data == std::nullopt); ret = output; } - TensorWrapper tensor(this->get_scaling_mode()); - if (rowwise_usage) { - tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); - tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); - } - if (create_transpose) { - std::vector transposed_shape; - for (auto s : columnwise_torch_shape) { - transposed_shape.emplace_back(static_cast(s)); - } - tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape); - tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); - } + this->set_quantization_params(&tensor); return {std::move(tensor), std::move(ret)}; @@ -351,122 +342,130 @@ void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const std::pair Float8BlockQuantizer::create_tensor( const std::vector& shape, DType dtype, const py::object& output, std::optional rowwise_data) const { using namespace pybind11::literals; - std::vector torch_shape; - size_t numel = 1; - for (auto s : shape) { - torch_shape.emplace_back(static_cast(s)); - numel *= s; - } + std::vector torch_shape(shape.begin(), shape.end()); + size_t numel = product(shape); TensorWrapper tensor(this->get_scaling_mode()); at::TensorOptions opts; - at::TensorOptions scale_opts; - at::Tensor data_rowwise, data_colwise, scale_inv_rowwise, scale_inv_colwise; opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); + at::TensorOptions scale_opts; scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA); - size_t k_dim = torch_shape.size() == 0 ? 1u : torch_shape.back(); + std::optional data_rowwise, data_colwise, scale_inv_rowwise, scale_inv_colwise; + + size_t k_dim = shape.size() == 0 ? 1u : shape.back(); size_t m_dim = numel / k_dim; constexpr size_t kBlockLen = 128; - py::object py_data{py::none()}; - py::object py_columnwise_data{py::none()}; - py::object py_scale_inv_rowwise{py::none()}; - py::object py_scale_inv_columnwise{py::none()}; if (!output.is_none()) { NVTE_CHECK(detail::IsFloat8BlockwiseQTensor(output.ptr()), "Wrong Tensor type provided for reuse. 
", "Expected Float8BlockwiseQTensor or Float8BlockwiseQTensorBase, but got ", py::repr(output).cast()); - py_data = output.attr("_rowwise_data"); - py_columnwise_data = output.attr("_columnwise_data"); - py_scale_inv_rowwise = output.attr("_rowwise_scale_inv"); - py_scale_inv_columnwise = output.attr("_columnwise_scale_inv"); } if (rowwise_usage) { if (rowwise_data.has_value()) { data_rowwise = std::move(*rowwise_data); } else { - data_rowwise = create_torch_tensor(torch_shape, opts, py_data); + data_rowwise = create_torch_tensor(torch_shape, opts, output, "_rowwise_data"); } size_t sinv0 = 0; size_t sinv1 = 0; - if (block_scaling_dim == 2) { - sinv0 = (m_dim + kBlockLen - 1) / kBlockLen; - sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4); - } else if (block_scaling_dim == 1) { - sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; - sinv1 = roundup(m_dim, 4); - } else { - NVTE_CHECK(false, - "Unsupported block_scaling_dim in create_tensor rowwise." - "Expected 1 or 2. Got ", - block_scaling_dim); + switch (block_scaling_dim) { + case 1: { + sinv0 = divup(k_dim, kBlockLen); + sinv1 = roundup(m_dim, 4lu); + } + break; + case 2: { + sinv0 = divup(m_dim, kBlockLen); + sinv1 = roundup(divup(k_dim, kBlockLen), 4lu); + } + break; + default: { + NVTE_ERROR("Unsupported block_scaling_dim in create_tensor rowwise." + "Expected 1 or 2. Got ", block_scaling_dim); + } + break; } - scale_inv_rowwise = - at::empty({static_cast(sinv0), static_cast(sinv1)}, scale_opts); - tensor.set_rowwise_data(data_rowwise.data_ptr(), this->dtype, shape); - tensor.set_rowwise_scale_inv(scale_inv_rowwise.data_ptr(), DType::kFloat32, + scale_inv_rowwise = create_torch_tensor( + {static_cast(sinv0), static_cast(sinv1)}, + scale_opts, output, "_rowwise_scale_inv"); + tensor.set_rowwise_data(data_rowwise->data_ptr(), this->dtype, shape); + tensor.set_rowwise_scale_inv(scale_inv_rowwise->data_ptr(), DType::kFloat32, std::vector{sinv0, sinv1}); } if (columnwise_usage) { std::vector torch_columnwise_shape; - std::vector columnwise_shape; - NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape. Shape ", - columnwise_shape, " torch shape: ", torch_columnwise_shape); if (torch_shape.size() > 0) { torch_columnwise_shape.reserve(torch_shape.size()); - columnwise_shape.reserve(shape.size()); torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]); - columnwise_shape.push_back(shape[shape.size() - 1]); for (size_t i = 0; i < torch_shape.size() - 1; ++i) { torch_columnwise_shape.push_back(torch_shape[i]); - columnwise_shape.push_back(shape[i]); } } + std::vector columnwise_shape(torch_columnwise_shape.begin(), + torch_columnwise_shape.end()); size_t sinv0 = 0; size_t sinv1 = 0; - if (block_scaling_dim == 2) { - sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; - sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4); - } else if (block_scaling_dim == 1) { - sinv0 = (m_dim + kBlockLen - 1) / kBlockLen; - sinv1 = roundup(k_dim, 4); - } else { - NVTE_CHECK(false, - "Unsupported block_scaling_dim in create_tensor columnwise." - "Expected 1 or 2. Got ", - block_scaling_dim); + switch (block_scaling_dim) { + case 1: { + sinv0 = divup(m_dim, kBlockLen); + sinv1 = roundup(k_dim, 4lu); + } + break; + case 2: { + sinv0 = divup(k_dim, kBlockLen); + sinv1 = roundup(divup(m_dim, kBlockLen), 4lu); + } + break; + default: { + NVTE_ERROR("Unsupported block_scaling_dim in create_tensor columnwise." + "Expected 1 or 2. 
Got ", block_scaling_dim); + } + break; } - data_colwise = at::empty(torch_columnwise_shape, opts); - scale_inv_colwise = - at::empty({static_cast(sinv0), static_cast(sinv1)}, scale_opts); + data_colwise = create_torch_tensor(torch_columnwise_shape, opts, output, "_columnwise_data"); + scale_inv_colwise = create_torch_tensor( + {static_cast(sinv0), static_cast(sinv1)}, + scale_opts, output, "_columnwise_scale_inv"); - tensor.set_columnwise_data(data_colwise.data_ptr(), this->dtype, columnwise_shape); - tensor.set_columnwise_scale_inv(scale_inv_colwise.data_ptr(), DType::kFloat32, + tensor.set_columnwise_data(data_colwise->data_ptr(), this->dtype, columnwise_shape); + tensor.set_columnwise_scale_inv(scale_inv_colwise->data_ptr(), DType::kFloat32, std::vector{sinv0, sinv1}); } this->set_quantization_params(&tensor); py::object ret; - if (internal) { - py::handle Float8BlockwiseQTensorClass( - reinterpret_cast(Float8BlockwiseQTensorBasePythonClass)); - ret = Float8BlockwiseQTensorClass( - "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, - "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer, - "is_2D_scaled"_a = (block_scaling_dim == 2)); + if (output.is_none()) { + if (internal) { + py::handle Float8BlockwiseQTensorClass( + reinterpret_cast(Float8BlockwiseQTensorBasePythonClass)); + ret = Float8BlockwiseQTensorClass( + "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, + "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer, + "is_2D_scaled"_a = (block_scaling_dim == 2)); + } else { + py::handle Float8BlockwiseQTensorClass( + reinterpret_cast(Float8BlockwiseQTensorPythonClass)); + ret = Float8BlockwiseQTensorClass( + "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise, + "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, + "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype, + "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2)); + } } else { - py::handle Float8BlockwiseQTensorClass( - reinterpret_cast(Float8BlockwiseQTensorPythonClass)); - ret = Float8BlockwiseQTensorClass( - "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise, - "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, - "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype, - "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2)); + output.attr("_rowwise_data") = data_rowwise; + output.attr("_columnwise_data") = data_colwise; + output.attr("_quantizer") = this->quantizer; + output.attr("_fp8_dtype") = this->dtype; + output.attr("_rowwise_scale_inv") = scale_inv_colwise; + output.attr("_columnwise_scale_inv") = scale_inv_rowwise; + output.attr("_is_2D_scaled") = (block_scaling_dim == 2); + ret = output; } return {std::move(tensor), std::move(ret)}; @@ -492,67 +491,84 @@ void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const { std::pair MXFP8Quantizer::create_tensor( const std::vector& shape, DType dtype, const py::object& output, std::optional rowwise_data) const { using namespace pybind11::literals; - std::vector torch_shape; - size_t numel = 1; - for (auto s : shape) { - torch_shape.emplace_back(static_cast(s)); - numel *= s; - } + std::vector 
torch_shape(shape.begin(), shape.end()); + size_t numel = product(shape); TensorWrapper tensor(NVTE_MXFP8_1D_SCALING); at::TensorOptions opts; - at::Tensor rowwise_data1, columnwise_data, rowwise_scale_inv, - columnwise_scale_inv; // TODO(pgadzinski) - change opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); + + std::optional data_rowwise, data_colwise, rowwise_scale_inv, columnwise_scale_inv; + auto last_dim = static_cast(torch_shape.back()); NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0, "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE, " (got shape=", torch_shape, ")"); - at::Tensor data; + if (!output.is_none()) { + NVTE_CHECK(detail::IsMXFP8Tensor(output.ptr()), + "Wrong Tensor type provided for reuse. ", + "Expected MXFP8Tensor or MXFP8TensorBase, but got ", + py::repr(output).cast()); + } + if (rowwise_usage) { if (rowwise_data.has_value()) { - data = std::move(*rowwise_data); + data_rowwise = std::move(*rowwise_data); } else { - data = at::empty(torch_shape, opts); + data_rowwise = create_torch_tensor(torch_shape, opts, output, "_rowwise_data"); } - auto sinv0 = roundup(numel / last_dim, 128); - auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4); - rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts); - tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); + auto sinv0 = roundup(numel / last_dim, 128lu); + auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4lu); + rowwise_scale_inv = create_torch_tensor( + {static_cast(sinv0), static_cast(sinv1)}, + opts, output, "_rowwise_scale_inv"); + tensor.set_rowwise_data(data_rowwise->data_ptr(), this->dtype, shape); tensor.set_rowwise_scale_inv( - rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0, + rowwise_scale_inv->data_ptr(), DType::kFloat8E8M0, std::vector{static_cast(sinv0), static_cast(sinv1)}); } if (columnwise_usage) { - auto sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4); - auto sinv1 = roundup(last_dim, 128); - columnwise_data = at::empty(torch_shape, opts); - columnwise_scale_inv = at::zeros({sinv0, sinv1}, opts); - - tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape); + auto sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4lu); + auto sinv1 = roundup(last_dim, 128lu); + data_colwise = create_torch_tensor(torch_shape, opts, output, "_columnwise_data"); + columnwise_scale_inv = create_torch_tensor( + {static_cast(sinv0), static_cast(sinv1)}, + opts, output, "_columnwise_scale_inv"); + + tensor.set_columnwise_data(data_colwise->data_ptr(), this->dtype, shape); tensor.set_columnwise_scale_inv( - columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0, - std::vector{static_cast(sinv0), static_cast(sinv1)}); + columnwise_scale_inv->data_ptr(), DType::kFloat8E8M0, + std::vector{sinv0, sinv1}); } this->set_quantization_params(&tensor); py::object ret; - if (internal) { - py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorBasePythonClass)); - ret = MXFP8TensorClass("rowwise_data"_a = data, "columnwise_data"_a = columnwise_data, - "rowwise_scale_inv"_a = rowwise_scale_inv, - "columnwise_scale_inv"_a = columnwise_scale_inv, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + if (output.is_none()) { + if (internal) { + py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorBasePythonClass)); + ret = MXFP8TensorClass("rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, + "rowwise_scale_inv"_a = rowwise_scale_inv, + "columnwise_scale_inv"_a = columnwise_scale_inv, + "fp8_dtype"_a = 
this->dtype, "quantizer"_a = this->quantizer); + } else { + py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorPythonClass)); + ret = MXFP8TensorClass("shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), + "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, + "rowwise_scale_inv"_a = rowwise_scale_inv, + "columnwise_scale_inv"_a = columnwise_scale_inv, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + } } else { - py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorPythonClass)); - ret = MXFP8TensorClass("shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), - "rowwise_data"_a = data, "columnwise_data"_a = columnwise_data, - "rowwise_scale_inv"_a = rowwise_scale_inv, - "columnwise_scale_inv"_a = columnwise_scale_inv, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + output.attr("_rowwise_data") = data_rowwise; + output.attr("_columnwise_data") = data_colwise; + output.attr("_quantizer") = this->quantizer; + output.attr("_fp8_dtype") = this->dtype; + output.attr("_rowwise_scale_inv") = rowwise_scale_inv; + output.attr("_columnwise_scale_inv") = columnwise_scale_inv; + ret = output; } return {std::move(tensor), std::move(ret)}; From 17678d95a4e7f953c74be6355271589acd2f0bc8 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 30 May 2025 15:54:32 -0700 Subject: [PATCH 05/25] Fix Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/csrc/quantizer.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index dd7683f13f..54ae428343 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -17,11 +17,11 @@ bool tensor_is_reusable(const at::Tensor& tensor, const std::vector& shape, const at::TensorOptions& opts) { const auto& tensor_shape = tensor.sizes(); - if (tensor_shape.equals(shape)) return false; + if (!tensor_shape.equals(shape)) return false; const at::TensorOptions& tensor_opts = tensor.options(); - if (opts.dtype() == tensor_opts.dtype()) return false; - if (opts.device() == tensor_opts.device()) return false; - if (opts.device_index() == tensor_opts.device_index()) return false; + if (opts.dtype() != tensor_opts.dtype()) return false; + if (opts.device() != tensor_opts.device()) return false; + if (opts.device_index() != tensor_opts.device_index()) return false; return true; } From b4881015bf7f0b74acf03272ff82ddff1cfb6787 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 30 May 2025 16:54:05 -0700 Subject: [PATCH 06/25] Actually reuse tensors and get rid of the hack for MXFP8 Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/csrc/quantizer.cpp | 4 ++-- transformer_engine/pytorch/module/base.py | 7 ------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 54ae428343..db8dd610fd 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -20,8 +20,8 @@ bool tensor_is_reusable(const at::Tensor& tensor, if (!tensor_shape.equals(shape)) return false; const at::TensorOptions& tensor_opts = tensor.options(); if (opts.dtype() != tensor_opts.dtype()) return false; - if (opts.device() != tensor_opts.device()) return false; - if (opts.device_index() != tensor_opts.device_index()) return false; + if (opts.device().type() != tensor_opts.device().type()) return false; + if 
(opts.device_index() != tensor_opts.device_index() && opts.device_index() != -1) return false; return true; } diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 87794cc63b..9f52e77a83 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1280,13 +1280,6 @@ def get_weight_workspace( if cache_name is not None: out = self._fp8_workspaces.get(cache_name, None) - if quantizer is not None and isinstance(out, MXFP8TensorBase): - if quantizer.rowwise_usage and out._rowwise_data is None: - out = None - del self._fp8_workspaces[cache_name] - elif quantizer.columnwise_usage and out._columnwise_data is None: - out = None - del self._fp8_workspaces[cache_name] is_debug = isinstance(quantizer, DebugQuantizer) is_out_debug_tensor = out is not None and isinstance(out, DebugQuantizedTensor) From 209cb9fca9d7efa84e5e43fa0c7f05bfaba92699 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 30 May 2025 16:59:01 -0700 Subject: [PATCH 07/25] Small cleaning Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/module/base.py | 25 ++++------------------- 1 file changed, 4 insertions(+), 21 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 9f52e77a83..6df35251a8 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -43,7 +43,6 @@ from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..utils import torch_get_autocast_gpu_dtype -from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ...common.recipe import DelayedScaling, Recipe from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor @@ -1074,15 +1073,7 @@ def grad_output_preprocess( grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) if ctx.ub_overlap_ag: # Quantize the gradient if needed - if not isinstance( - grad_output, - ( - QuantizedTensor, - Float8TensorBase, - MXFP8TensorBase, - Float8BlockwiseQTensorBase, - ), - ): + if not isinstance(grad_output, QuantizedTensorBase): grad_output = quantizer(grad_output) # Copy into communication buffer, and replace original gradient with it @@ -1105,15 +1096,7 @@ def grad_output_preprocess( if ctx.debug: grad_output_ = quantizer(grad_output) if ( - isinstance( - grad_output_.get_tensor(True), - ( - QuantizedTensor, - Float8TensorBase, - MXFP8TensorBase, - Float8BlockwiseQTensorBase, - ), - ) + isinstance(grad_output_.get_tensor(True), QuantizedTensorBase) and ctx.use_bias ): grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) @@ -1127,7 +1110,7 @@ def grad_output_preprocess( if ctx.use_bias: if isinstance( grad_output, - (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase), + QuantizedTensorBase, ): grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) else: @@ -1138,7 +1121,7 @@ def grad_output_preprocess( grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) if not isinstance( grad_output, - (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase), + QuantizedTensorBase, ): grad_output = quantizer(grad_output) return grad_output, grad_bias From c61b14bd388be694a579c8f3666839b410fccc69 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 30 May 2025 17:06:19 -0700 
Subject: [PATCH 08/25] Make sure dgrad is not needed in the test during eval phase

Signed-off-by: Przemek Tredak
---
 tests/pytorch/test_sanity.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py
index 0fc6b34779..255b3d0bc7 100644
--- a/tests/pytorch/test_sanity.py
+++ b/tests/pytorch/test_sanity.py
@@ -397,6 +397,7 @@ def _test_sanity_common(

     # now try eval with weight caching
     block.eval()
+    te_inp.requires_grad = False

     with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
         te_out = block(te_inp, is_first_microbatch=True)
@@ -406,6 +407,7 @@ def _test_sanity_common(

     # now try regular execution again with weight caching
     block.train()
+    te_inp.requires_grad = True

     with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
         te_out = block(te_inp, is_first_microbatch=True)

From 7561fb4cc87633c6ea58c99acf400065ecf73d71 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 31 May 2025 00:16:16 +0000
Subject: [PATCH 09/25] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/pytorch/test_sanity.py | 1 +
 transformer_engine/pytorch/csrc/common.h | 18 +--
 .../pytorch/csrc/extensions/cast.cpp | 4 +-
 transformer_engine/pytorch/csrc/quantizer.cpp | 109 ++++++++----------
 transformer_engine/pytorch/module/base.py | 5 +-
 5 files changed, 61 insertions(+), 76 deletions(-)

diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py
index 255b3d0bc7..fcf6511be9 100644
--- a/tests/pytorch/test_sanity.py
+++ b/tests/pytorch/test_sanity.py
@@ -423,6 +423,7 @@ def _test_sanity_common(
     loss.backward()
     torch.cuda.synchronize()

+
 def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad):
     if skip_dgrad and skip_wgrad:
         pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.")
diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h
index 3e3993f644..97f38effb3 100644
--- a/transformer_engine/pytorch/csrc/common.h
+++ b/transformer_engine/pytorch/csrc/common.h
@@ -97,8 +97,7 @@ class Quantizer {
   virtual void set_quantization_params(TensorWrapper* tensor) const = 0;

   virtual std::pair create_tensor(
-      const std::vector& shape, DType dtype,
-      const py::object& output = py::none(),
+      const std::vector& shape, DType dtype, const py::object& output = py::none(),
       std::optional rowwise_data = std::nullopt) const = 0;

   virtual ~Quantizer() = default;
@@ -121,8 +120,7 @@ class NoneQuantizer : public Quantizer {
   void set_quantization_params(TensorWrapper* tensor) const override {}

   std::pair create_tensor(
-      const std::vector& shape, DType dtype,
-      const py::object& output = py::none(),
+      const std::vector& shape, DType dtype, const py::object& output = py::none(),
       std::optional rowwise_data = std::nullopt) const override;
 };

@@ -140,8 +138,7 @@ class Float8Quantizer : public Quantizer {
   void set_quantization_params(TensorWrapper* tensor) const override;

   std::pair create_tensor(
-      const std::vector& shape, DType dtype,
-      const py::object& output = py::none(),
+      const std::vector& shape, DType dtype, const py::object& output = py::none(),
       std::optional rowwise_data = std::nullopt) const override;
 };

@@ -163,8 +160,7 @@ class Float8CurrentScalingQuantizer : public Quantizer {
   void set_quantization_params(TensorWrapper* tensor) const override;

   std::pair create_tensor(
-
const std::vector& shape, DType dtype, const py::object& output = py::none(), std::optional rowwise_data = std::nullopt) const override; }; @@ -196,8 +192,7 @@ class Float8BlockQuantizer : public Quantizer { // for the tensor. Should set quantized data, scales for rowwise // and optionally columnwise usage. std::pair create_tensor( - const std::vector& shape, DType dtype, - const py::object& output = py::none(), + const std::vector& shape, DType dtype, const py::object& output = py::none(), std::optional rowwise_data = std::nullopt) const override; }; @@ -212,8 +207,7 @@ class MXFP8Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; std::pair create_tensor( - const std::vector& shape, DType dtype, - const py::object& output = py::none(), + const std::vector& shape, DType dtype, const py::object& output = py::none(), std::optional rowwise_data = std::nullopt) const override; }; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 8585c9f86b..32ed05ea7f 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -116,8 +116,8 @@ std::vector dbias_dact(const at::Tensor& grad_output, const at::Tens auto act_input_tensor = makeTransformerEngineTensor(act_input); const auto& shape = convertShape(grad_tensor.shape()); - auto [dact_tensor, dact] = my_quantizer->create_tensor(shape, act_input_tensor.dtype(), - py::none()); + auto [dact_tensor, dact] = + my_quantizer->create_tensor(shape, act_input_tensor.dtype(), py::none()); auto dbias_tensor = makeTransformerEngineTensor(grad_bias); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index db8dd610fd..2b237f7fa7 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -13,8 +13,7 @@ namespace transformer_engine::pytorch { constexpr size_t MXFP8_BLOCK_SIZE = 32; -bool tensor_is_reusable(const at::Tensor& tensor, - const std::vector& shape, +bool tensor_is_reusable(const at::Tensor& tensor, const std::vector& shape, const at::TensorOptions& opts) { const auto& tensor_shape = tensor.sizes(); if (!tensor_shape.equals(shape)) return false; @@ -49,8 +48,7 @@ Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quanti } // Create torch tensor reusing existing data if possible -at::Tensor create_torch_tensor(const std::vector& shape, - const at::TensorOptions& opts, +at::Tensor create_torch_tensor(const std::vector& shape, const at::TensorOptions& opts, const py::object& tensor_to_reuse) { if (!tensor_to_reuse.is_none()) { // Reuse output @@ -64,8 +62,7 @@ at::Tensor create_torch_tensor(const std::vector& shape, // Create torch tensor reusing existing data is possible // The reused tensor is tensor_to_reuse.attr_name -at::Tensor create_torch_tensor(const std::vector& shape, - const at::TensorOptions& opts, +at::Tensor create_torch_tensor(const std::vector& shape, const at::TensorOptions& opts, const py::object& tensor_to_reuse, const std::string_view& attr_name) { py::object tensor{py::none()}; @@ -76,8 +73,7 @@ at::Tensor create_torch_tensor(const std::vector& shape, } std::pair NoneQuantizer::create_tensor( - const std::vector& shape, DType dtype, - const py::object& output, + const std::vector& shape, DType dtype, const py::object& output, std::optional rowwise_data) const { at::TensorOptions opts; opts = 
opts.dtype(GetATenDType(dtype)).device(torch::kCUDA); @@ -112,8 +108,8 @@ void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { } std::pair Float8Quantizer::create_tensor( - const std::vector& shape, DType dtype, - const py::object& output, std::optional rowwise_data) const { + const std::vector& shape, DType dtype, const py::object& output, + std::optional rowwise_data) const { using namespace pybind11::literals; at::TensorOptions opts; @@ -128,8 +124,7 @@ std::pair Float8Quantizer::create_tensor( bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); TensorWrapper tensor(this->get_scaling_mode()); if (!output.is_none()) { - NVTE_CHECK(detail::IsFloat8Tensor(output.ptr()), - "Wrong Tensor type provided for reuse. ", + NVTE_CHECK(detail::IsFloat8Tensor(output.ptr()), "Wrong Tensor type provided for reuse. ", "Expected Float8Tensor or Float8TensorBase, but got ", py::repr(output).cast()); } @@ -237,7 +232,8 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso } std::pair Float8CurrentScalingQuantizer::create_tensor( - const std::vector& shape, DType dtype, const py::object& output, std::optional rowwise_data) const { + const std::vector& shape, DType dtype, const py::object& output, + std::optional rowwise_data) const { using namespace pybind11::literals; std::vector scale_inv_torch_shape = {1}; // Shape of 1 element for scale_inv @@ -254,8 +250,7 @@ std::pair Float8CurrentScalingQuantizer::create_tenso bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); if (!output.is_none()) { - NVTE_CHECK(detail::IsFloat8Tensor(output.ptr()), - "Wrong Tensor type provided for reuse. ", + NVTE_CHECK(detail::IsFloat8Tensor(output.ptr()), "Wrong Tensor type provided for reuse. ", "Expected Float8Tensor or Float8TensorBase, but got ", py::repr(output).cast()); } @@ -340,7 +335,8 @@ void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const } std::pair Float8BlockQuantizer::create_tensor( - const std::vector& shape, DType dtype, const py::object& output, std::optional rowwise_data) const { + const std::vector& shape, DType dtype, const py::object& output, + std::optional rowwise_data) const { using namespace pybind11::literals; std::vector torch_shape(shape.begin(), shape.end()); size_t numel = product(shape); @@ -374,24 +370,23 @@ std::pair Float8BlockQuantizer::create_tensor( size_t sinv1 = 0; switch (block_scaling_dim) { case 1: { - sinv0 = divup(k_dim, kBlockLen); - sinv1 = roundup(m_dim, 4lu); - } - break; + sinv0 = divup(k_dim, kBlockLen); + sinv1 = roundup(m_dim, 4lu); + } break; case 2: { - sinv0 = divup(m_dim, kBlockLen); - sinv1 = roundup(divup(k_dim, kBlockLen), 4lu); - } - break; + sinv0 = divup(m_dim, kBlockLen); + sinv1 = roundup(divup(k_dim, kBlockLen), 4lu); + } break; default: { - NVTE_ERROR("Unsupported block_scaling_dim in create_tensor rowwise." - "Expected 1 or 2. Got ", block_scaling_dim); - } - break; + NVTE_ERROR( + "Unsupported block_scaling_dim in create_tensor rowwise." + "Expected 1 or 2. 
Got ", + block_scaling_dim); + } break; } - scale_inv_rowwise = create_torch_tensor( - {static_cast(sinv0), static_cast(sinv1)}, - scale_opts, output, "_rowwise_scale_inv"); + scale_inv_rowwise = + create_torch_tensor({static_cast(sinv0), static_cast(sinv1)}, scale_opts, + output, "_rowwise_scale_inv"); tensor.set_rowwise_data(data_rowwise->data_ptr(), this->dtype, shape); tensor.set_rowwise_scale_inv(scale_inv_rowwise->data_ptr(), DType::kFloat32, std::vector{sinv0, sinv1}); @@ -412,25 +407,24 @@ std::pair Float8BlockQuantizer::create_tensor( size_t sinv1 = 0; switch (block_scaling_dim) { case 1: { - sinv0 = divup(m_dim, kBlockLen); - sinv1 = roundup(k_dim, 4lu); - } - break; + sinv0 = divup(m_dim, kBlockLen); + sinv1 = roundup(k_dim, 4lu); + } break; case 2: { - sinv0 = divup(k_dim, kBlockLen); - sinv1 = roundup(divup(m_dim, kBlockLen), 4lu); - } - break; + sinv0 = divup(k_dim, kBlockLen); + sinv1 = roundup(divup(m_dim, kBlockLen), 4lu); + } break; default: { - NVTE_ERROR("Unsupported block_scaling_dim in create_tensor columnwise." - "Expected 1 or 2. Got ", block_scaling_dim); - } - break; + NVTE_ERROR( + "Unsupported block_scaling_dim in create_tensor columnwise." + "Expected 1 or 2. Got ", + block_scaling_dim); + } break; } data_colwise = create_torch_tensor(torch_columnwise_shape, opts, output, "_columnwise_data"); - scale_inv_colwise = create_torch_tensor( - {static_cast(sinv0), static_cast(sinv1)}, - scale_opts, output, "_columnwise_scale_inv"); + scale_inv_colwise = + create_torch_tensor({static_cast(sinv0), static_cast(sinv1)}, scale_opts, + output, "_columnwise_scale_inv"); tensor.set_columnwise_data(data_colwise->data_ptr(), this->dtype, columnwise_shape); tensor.set_columnwise_scale_inv(scale_inv_colwise->data_ptr(), DType::kFloat32, @@ -489,7 +483,8 @@ void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const { } std::pair MXFP8Quantizer::create_tensor( - const std::vector& shape, DType dtype, const py::object& output, std::optional rowwise_data) const { + const std::vector& shape, DType dtype, const py::object& output, + std::optional rowwise_data) const { using namespace pybind11::literals; std::vector torch_shape(shape.begin(), shape.end()); size_t numel = product(shape); @@ -507,8 +502,7 @@ std::pair MXFP8Quantizer::create_tensor( " (got shape=", torch_shape, ")"); if (!output.is_none()) { - NVTE_CHECK(detail::IsMXFP8Tensor(output.ptr()), - "Wrong Tensor type provided for reuse. ", + NVTE_CHECK(detail::IsMXFP8Tensor(output.ptr()), "Wrong Tensor type provided for reuse. 
", "Expected MXFP8Tensor or MXFP8TensorBase, but got ", py::repr(output).cast()); } @@ -521,9 +515,9 @@ std::pair MXFP8Quantizer::create_tensor( } auto sinv0 = roundup(numel / last_dim, 128lu); auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4lu); - rowwise_scale_inv = create_torch_tensor( - {static_cast(sinv0), static_cast(sinv1)}, - opts, output, "_rowwise_scale_inv"); + rowwise_scale_inv = + create_torch_tensor({static_cast(sinv0), static_cast(sinv1)}, opts, + output, "_rowwise_scale_inv"); tensor.set_rowwise_data(data_rowwise->data_ptr(), this->dtype, shape); tensor.set_rowwise_scale_inv( rowwise_scale_inv->data_ptr(), DType::kFloat8E8M0, @@ -534,14 +528,13 @@ std::pair MXFP8Quantizer::create_tensor( auto sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4lu); auto sinv1 = roundup(last_dim, 128lu); data_colwise = create_torch_tensor(torch_shape, opts, output, "_columnwise_data"); - columnwise_scale_inv = create_torch_tensor( - {static_cast(sinv0), static_cast(sinv1)}, - opts, output, "_columnwise_scale_inv"); + columnwise_scale_inv = + create_torch_tensor({static_cast(sinv0), static_cast(sinv1)}, opts, + output, "_columnwise_scale_inv"); tensor.set_columnwise_data(data_colwise->data_ptr(), this->dtype, shape); - tensor.set_columnwise_scale_inv( - columnwise_scale_inv->data_ptr(), DType::kFloat8E8M0, - std::vector{sinv0, sinv1}); + tensor.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E8M0, + std::vector{sinv0, sinv1}); } this->set_quantization_params(&tensor); diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 6df35251a8..fbb8f8efc1 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1095,10 +1095,7 @@ def grad_output_preprocess( # bgrad only if wgrad is in FP8, otherwise it is fused with wgrad and we return None if ctx.debug: grad_output_ = quantizer(grad_output) - if ( - isinstance(grad_output_.get_tensor(True), QuantizedTensorBase) - and ctx.use_bias - ): + if isinstance(grad_output_.get_tensor(True), QuantizedTensorBase) and ctx.use_bias: grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) else: grad_bias = None From 41b8fb44dfae452636e79ca790f434a175057faa Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 2 Jun 2025 14:12:31 -0700 Subject: [PATCH 10/25] Fixes Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/csrc/quantizer.cpp | 4 ++-- transformer_engine/pytorch/module/layernorm_mlp.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 2b237f7fa7..293ea171f2 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -456,8 +456,8 @@ std::pair Float8BlockQuantizer::create_tensor( output.attr("_columnwise_data") = data_colwise; output.attr("_quantizer") = this->quantizer; output.attr("_fp8_dtype") = this->dtype; - output.attr("_rowwise_scale_inv") = scale_inv_colwise; - output.attr("_columnwise_scale_inv") = scale_inv_rowwise; + output.attr("_rowwise_scale_inv") = scale_inv_rowwise; + output.attr("_columnwise_scale_inv") = scale_inv_colwise; output.attr("_is_2D_scaled") = (block_scaling_dim == 2); ret = output; } diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 6ff2763ee1..fd4f86a3e3 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ 
b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -610,7 +610,9 @@ def forward( ctx.debug = debug ctx.requires_dgrad = ( - inp.requires_grad or ln_weight.requires_grad or ln_bias.requires_grad + inp.requires_grad or + ln_weight.requires_grad or + (ln_bias.requires_grad if ln_bias is not None else False) ) ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False From 20363a4e78fa7858b56defe02ce65fcea8af4416 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Jun 2025 21:13:02 +0000 Subject: [PATCH 11/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_mlp.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index fd4f86a3e3..83b6014a1c 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -610,9 +610,9 @@ def forward( ctx.debug = debug ctx.requires_dgrad = ( - inp.requires_grad or - ln_weight.requires_grad or - (ln_bias.requires_grad if ln_bias is not None else False) + inp.requires_grad + or ln_weight.requires_grad + or (ln_bias.requires_grad if ln_bias is not None else False) ) ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False From 2acb07a89dde16647912d05c9588779b239132b9 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 5 Jun 2025 17:49:37 -0700 Subject: [PATCH 12/25] Fixes Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/csrc/quantizer.cpp | 22 +++++++++++++------ .../pytorch/ops/basic/basic_linear.py | 15 ------------- 2 files changed, 15 insertions(+), 22 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 293ea171f2..227674e772 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -48,15 +48,22 @@ Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quanti } // Create torch tensor reusing existing data if possible -at::Tensor create_torch_tensor(const std::vector& shape, const at::TensorOptions& opts, - const py::object& tensor_to_reuse) { +at::Tensor _create_torch_tensor(const std::vector& shape, const at::TensorOptions& opts, + const py::object& tensor_to_reuse, + bool zero_out) { if (!tensor_to_reuse.is_none()) { // Reuse output const at::Tensor temp = tensor_to_reuse.cast(); if (tensor_is_reusable(temp, shape, opts)) { + if (zero_out) { + temp.zero_(); + } return temp; } } + if (zero_out) { + return at::zeros(shape, opts); + } return at::empty(shape, opts); } @@ -64,12 +71,13 @@ at::Tensor create_torch_tensor(const std::vector& shape, const at::Tens // The reused tensor is tensor_to_reuse.attr_name at::Tensor create_torch_tensor(const std::vector& shape, const at::TensorOptions& opts, const py::object& tensor_to_reuse, - const std::string_view& attr_name) { + const std::string_view& attr_name, + bool zero_out = false) { py::object tensor{py::none()}; if (!tensor_to_reuse.is_none()) { tensor = tensor_to_reuse.attr(attr_name.data()); } - return create_torch_tensor(shape, opts, tensor); + return _create_torch_tensor(shape, opts, tensor, zero_out); } std::pair NoneQuantizer::create_tensor( @@ -82,7 +90,7 @@ std::pair NoneQuantizer::create_tensor( if (rowwise_data.has_value()) { ret = 
std::move(*rowwise_data); } else { - ret = create_torch_tensor(torch_shape, opts, output); + ret = _create_torch_tensor(torch_shape, opts, output, false); } TensorWrapper tensor; @@ -517,7 +525,7 @@ std::pair MXFP8Quantizer::create_tensor( auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4lu); rowwise_scale_inv = create_torch_tensor({static_cast(sinv0), static_cast(sinv1)}, opts, - output, "_rowwise_scale_inv"); + output, "_rowwise_scale_inv", true); tensor.set_rowwise_data(data_rowwise->data_ptr(), this->dtype, shape); tensor.set_rowwise_scale_inv( rowwise_scale_inv->data_ptr(), DType::kFloat8E8M0, @@ -530,7 +538,7 @@ std::pair MXFP8Quantizer::create_tensor( data_colwise = create_torch_tensor(torch_shape, opts, output, "_columnwise_data"); columnwise_scale_inv = create_torch_tensor({static_cast(sinv0), static_cast(sinv1)}, opts, - output, "_columnwise_scale_inv"); + output, "_columnwise_scale_inv", true); tensor.set_columnwise_data(data_colwise->data_ptr(), this->dtype, shape); tensor.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E8M0, diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 0e786ca96f..886a56cec6 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -313,16 +313,7 @@ def pre_forward(self, *args, **kwargs) -> None: # Configure quantizers if FP8GlobalStateManager.is_fp8_enabled(): - input_quantizer = self.get_quantizer("forward", 0) weight_quantizer = self.get_quantizer("forward", 1) - grad_output_quantizer = self.get_quantizer("backward", 0) - - # Specify required tensor formats - is_grad_enabled = torch.is_grad_enabled() - weight_requires_grad = is_grad_enabled and weight.requires_grad - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) - weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) - grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) # Make sure weight tensor has correct quantizer # Note: Quantizer might have changed if quantization @@ -898,12 +889,6 @@ def op_forward( if prev_op is not None and prev_op.num_quantizers("backward") > 0: grad_input_quantizer = prev_op.get_quantizer("backward", 0) - # Configure quantizers - # Note: We cache the quantized input for backward pass, - # but discard the quantized weights. 
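        # (A reading of the set_usage flags below, inferred from their assignments
        #  rather than stated anywhere in this series: the rowwise copy feeds the
        #  forward GEMM, while the columnwise copy is what the wgrad GEMM consumes
        #  in backward, which is why the input's columnwise usage tracks
        #  weight_requires_grad.)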
- input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) - weight_quantizer.set_usage(rowwise=True, columnwise=False) - # Get autocast dtype if needed dtype = None if torch.is_autocast_enabled(): From 955064400078c81ea47a4216170f007cab84877d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Jun 2025 00:50:08 +0000 Subject: [PATCH 13/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/quantizer.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 227674e772..d8942ba67f 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -49,8 +49,7 @@ Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quanti // Create torch tensor reusing existing data if possible at::Tensor _create_torch_tensor(const std::vector& shape, const at::TensorOptions& opts, - const py::object& tensor_to_reuse, - bool zero_out) { + const py::object& tensor_to_reuse, bool zero_out) { if (!tensor_to_reuse.is_none()) { // Reuse output const at::Tensor temp = tensor_to_reuse.cast(); @@ -70,8 +69,7 @@ at::Tensor _create_torch_tensor(const std::vector& shape, const at::Ten // Create torch tensor reusing existing data is possible // The reused tensor is tensor_to_reuse.attr_name at::Tensor create_torch_tensor(const std::vector& shape, const at::TensorOptions& opts, - const py::object& tensor_to_reuse, - const std::string_view& attr_name, + const py::object& tensor_to_reuse, const std::string_view& attr_name, bool zero_out = false) { py::object tensor{py::none()}; if (!tensor_to_reuse.is_none()) { From f0f96b97f09a53c8683464987a98168156347374 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 6 Jun 2025 22:18:53 +0000 Subject: [PATCH 14/25] Fix for integer overflow Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/csrc/quantizer.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index d8942ba67f..6cbc3ce1e7 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -151,9 +151,9 @@ std::pair Float8Quantizer::create_tensor( columnwise_torch_shape.reserve(shape.size()); if (!shape.empty()) { columnwise_torch_shape.emplace_back(static_cast(shape.back())); - } - for (size_t i = 0; i < shape.size() - 1; ++i) { - columnwise_torch_shape.emplace_back(static_cast(shape[i])); + for (size_t i = 0; i < shape.size() - 1; ++i) { + columnwise_torch_shape.emplace_back(static_cast(shape[i])); + } } std::vector transposed_shape(columnwise_torch_shape.begin(), columnwise_torch_shape.end()); @@ -276,9 +276,9 @@ std::pair Float8CurrentScalingQuantizer::create_tenso columnwise_torch_shape.reserve(shape.size()); if (!shape.empty()) { columnwise_torch_shape.emplace_back(static_cast(shape.back())); - } - for (size_t i = 0; i < shape.size() - 1; ++i) { - columnwise_torch_shape.emplace_back(static_cast(shape[i])); + for (size_t i = 0; i < shape.size() - 1; ++i) { + columnwise_torch_shape.emplace_back(static_cast(shape[i])); + } } columnwise_data = create_torch_tensor(columnwise_torch_shape, opts, output, "_transpose"); std::vector transposed_shape(columnwise_torch_shape.begin(), From 
eb49987e9bc65a32502fbc37b935b9bde98baf13 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 11 Jun 2025 11:23:55 -0700 Subject: [PATCH 15/25] Try copying the quantizer Signed-off-by: Przemek Tredak --- .../pytorch/csrc/extensions/pybind.cpp | 19 ++++++- transformer_engine/pytorch/csrc/pybind.h | 28 ++++++---- transformer_engine/pytorch/csrc/quantizer.cpp | 52 ++++++++++++++----- 3 files changed, 75 insertions(+), 24 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 0a1b76e697..ebb6c3e270 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -12,8 +12,6 @@ #include #include -#include - #include "../common.h" #include "../extensions.h" #include "common.h" @@ -31,6 +29,22 @@ PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr; PyTypeObject *Float8BlockwiseQTensorBasePythonClass = nullptr; PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; +py::object PythonCopy; + +py::object python_copy(const py::handle& src) { + return PythonCopy(src); +} + +void init_copy_extension() { + if (!PythonCopy.is_none()) return; + + // Import Python's copy module + py::module_ copy_module = py::module_::import("copy"); + + // Get the copy.copy function + PythonCopy = copy_module.attr("copy"); +} + void init_float8_extension() { if (Float8TensorPythonClass) return; auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor"); @@ -88,6 +102,7 @@ void init_extension() { init_float8_extension(); init_mxfp8_extension(); init_float8blockwise_extension(); + init_copy_extension(); } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 9fd1ae4de9..0e2c869084 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -42,10 +42,13 @@ extern PyTypeObject *Float8BlockwiseQTensorBasePythonClass; extern PyTypeObject *Float8BlockwiseQuantizerClass; void init_extension(); +py::object python_copy(const py::handle &src); -void init_float8_extension(); - -void init_mxfp8_extension(); +enum class PythonTensorType { + INVALID = 0, + TENSOR = 1, + TENSOR_BASE = 2, +}; namespace detail { @@ -55,23 +58,28 @@ inline bool IsFloat8CurrentScalingQuantizers(PyObject *obj) { return Py_TYPE(obj) == Float8CurrentScalingQuantizerClass; } -inline bool IsFloat8Tensor(PyObject *obj) { - return Py_TYPE(obj) == Float8TensorPythonClass || Py_TYPE(obj) == Float8TensorBasePythonClass; +inline PythonTensorType IsFloat8Tensor(PyObject *obj) { + if (Py_TYPE(obj) == Float8TensorPythonClass) return PythonTensorType::TENSOR; + if (Py_TYPE(obj) == Float8TensorBasePythonClass) return PythonTensorType::TENSOR_BASE; + return PythonTensorType::INVALID; } inline bool IsMXFP8Quantizers(PyObject *obj) { return Py_TYPE(obj) == MXFP8QuantizerClass; } -inline bool IsMXFP8Tensor(PyObject *obj) { - return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass; +inline PythonTensorType IsMXFP8Tensor(PyObject *obj) { + if (Py_TYPE(obj) == MXFP8TensorPythonClass) return PythonTensorType::TENSOR; + if (Py_TYPE(obj) == MXFP8TensorBasePythonClass) return PythonTensorType::TENSOR_BASE; + return PythonTensorType::INVALID; } inline bool IsFloat8BlockwiseQuantizers(PyObject *obj) { return Py_TYPE(obj) == Float8BlockwiseQuantizerClass; } -inline bool IsFloat8BlockwiseQTensor(PyObject *obj) { - return 
Py_TYPE(obj) == Float8BlockwiseQTensorPythonClass || - Py_TYPE(obj) == Float8BlockwiseQTensorBasePythonClass; +inline PythonTensorType IsFloat8BlockwiseQTensor(PyObject *obj) { + if (Py_TYPE(obj) == Float8BlockwiseQTensorPythonClass) return PythonTensorType::TENSOR; + if (Py_TYPE(obj) == Float8BlockwiseQTensorBasePythonClass) return PythonTensorType::TENSOR_BASE; + return PythonTensorType::INVALID; } TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 6cbc3ce1e7..c4875d8fa8 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -129,8 +129,11 @@ std::pair Float8Quantizer::create_tensor( bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); TensorWrapper tensor(this->get_scaling_mode()); + auto output_tensor_type = PythonTensorType::INVALID; if (!output.is_none()) { - NVTE_CHECK(detail::IsFloat8Tensor(output.ptr()), "Wrong Tensor type provided for reuse. ", + output_tensor_type = detail::IsFloat8Tensor(output.ptr()); + NVTE_CHECK(output_tensor_type != PythonTensorType::INVALID, + "Wrong Tensor type provided for reuse. ", "Expected Float8Tensor or Float8TensorBase, but got ", py::repr(output).cast()); } @@ -176,14 +179,18 @@ std::pair Float8Quantizer::create_tensor( ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype), "data"_a = data, "fp8_scale_inv"_a = scale_inv, "fp8_dtype"_a = this->dtype, "data_transpose"_a = columnwise_data, - "quantizer"_a = this->quantizer); + "quantizer"_a = python_copy(this->quantizer)); } } else { output.attr("_data") = data; output.attr("_scale_inv") = scale_inv; output.attr("_fp8_dtype") = this->dtype; output.attr("_transpose") = columnwise_data; - output.attr("_quantizer") = this->quantizer; + if (output_tensor_type == PythonTensorType::TENSOR_BASE) { + output.attr("_quantizer") = this->quantizer; + } else { + output.attr("_quantizer") = python_copy(this->quantizer); + } output.attr("_transpose_invalid") = (columnwise_data == std::nullopt); ret = output; } @@ -255,8 +262,11 @@ std::pair Float8CurrentScalingQuantizer::create_tenso TensorWrapper tensor(this->get_scaling_mode()); bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + auto output_tensor_type = PythonTensorType::INVALID; if (!output.is_none()) { - NVTE_CHECK(detail::IsFloat8Tensor(output.ptr()), "Wrong Tensor type provided for reuse. ", + output_tensor_type = detail::IsFloat8Tensor(output.ptr()); + NVTE_CHECK(output_tensor_type != PythonTensorType::INVALID, + "Wrong Tensor type provided for reuse. 
", "Expected Float8Tensor or Float8TensorBase, but got ", py::repr(output).cast()); } @@ -299,14 +309,18 @@ std::pair Float8CurrentScalingQuantizer::create_tenso ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype), "data"_a = data, "fp8_scale_inv"_a = scale_inv, "fp8_dtype"_a = this->dtype, "data_transpose"_a = columnwise_data, - "quantizer"_a = this->quantizer); + "quantizer"_a = python_copy(this->quantizer)); } } else { output.attr("_data") = data; output.attr("_scale_inv") = scale_inv; output.attr("_fp8_dtype") = this->dtype; output.attr("_transpose") = columnwise_data; - output.attr("_quantizer") = this->quantizer; + if (output_tensor_type == PythonTensorType::TENSOR_BASE) { + output.attr("_quantizer") = this->quantizer; + } else { + output.attr("_quantizer") = python_copy(this->quantizer); + } output.attr("_transpose_invalid") = (columnwise_data == std::nullopt); ret = output; } @@ -359,8 +373,10 @@ std::pair Float8BlockQuantizer::create_tensor( size_t m_dim = numel / k_dim; constexpr size_t kBlockLen = 128; + auto output_tensor_type = PythonTensorType::INVALID; if (!output.is_none()) { - NVTE_CHECK(detail::IsFloat8BlockwiseQTensor(output.ptr()), + output_tensor_type = detail::IsFloat8BlockwiseQTensor(output.ptr()); + NVTE_CHECK(output_tensor_type != PythonTensorType::INVALID, "Wrong Tensor type provided for reuse. ", "Expected Float8BlockwiseQTensor or Float8BlockwiseQTensorBase, but got ", py::repr(output).cast()); @@ -455,12 +471,17 @@ std::pair Float8BlockQuantizer::create_tensor( "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype, - "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2)); + "quantizer"_a = python_copy(this->quantizer), + "is_2D_scaled"_a = (block_scaling_dim == 2)); } } else { output.attr("_rowwise_data") = data_rowwise; output.attr("_columnwise_data") = data_colwise; - output.attr("_quantizer") = this->quantizer; + if (output_tensor_type == PythonTensorType::TENSOR_BASE) { + output.attr("_quantizer") = this->quantizer; + } else { + output.attr("_quantizer") = python_copy(this->quantizer); + } output.attr("_fp8_dtype") = this->dtype; output.attr("_rowwise_scale_inv") = scale_inv_rowwise; output.attr("_columnwise_scale_inv") = scale_inv_colwise; @@ -507,8 +528,10 @@ std::pair MXFP8Quantizer::create_tensor( "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE, " (got shape=", torch_shape, ")"); + auto output_tensor_type = PythonTensorType::INVALID; if (!output.is_none()) { - NVTE_CHECK(detail::IsMXFP8Tensor(output.ptr()), "Wrong Tensor type provided for reuse. ", + output_tensor_type = detail::IsMXFP8Tensor(output.ptr()); + NVTE_CHECK(output_tensor_type != PythonTensorType::INVALID, "Wrong Tensor type provided for reuse. 
", "Expected MXFP8Tensor or MXFP8TensorBase, but got ", py::repr(output).cast()); } @@ -558,12 +581,17 @@ std::pair MXFP8Quantizer::create_tensor( "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = rowwise_scale_inv, "columnwise_scale_inv"_a = columnwise_scale_inv, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + "fp8_dtype"_a = this->dtype, + "quantizer"_a = python_copy(this->quantizer)); } } else { output.attr("_rowwise_data") = data_rowwise; output.attr("_columnwise_data") = data_colwise; - output.attr("_quantizer") = this->quantizer; + if (output_tensor_type == PythonTensorType::TENSOR_BASE) { + output.attr("_quantizer") = this->quantizer; + } else { + output.attr("_quantizer") = python_copy(this->quantizer); + } output.attr("_fp8_dtype") = this->dtype; output.attr("_rowwise_scale_inv") = rowwise_scale_inv; output.attr("_columnwise_scale_inv") = columnwise_scale_inv; From 6dcd4807a7399f6641fd9568af25fc557bc79113 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 11 Jun 2025 11:26:57 -0700 Subject: [PATCH 16/25] Fix Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/csrc/common.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 69f129cbc2..5f90657973 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -67,7 +67,7 @@ TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantize // also during dequantize, the quantizer param is unknown -> so quantizer is NoneQuantizer for (auto [check_type, check_quantizer_type, create_tensor, _] : detail::custom_types_converters) { - if (check_type(tensor.ptr())) { + if (check_type(tensor.ptr()) != PythonTensorType::INVALID) { if (!(quantizer.is_none() || check_quantizer_type(quantizer.ptr()))) { continue; } From b6f1aeb8c3d85409709ec521a4d922dc1a008b60 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 11 Jun 2025 16:31:28 -0700 Subject: [PATCH 17/25] Fix CUDA graphs test Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/csrc/quantizer.cpp | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index c4875d8fa8..af3aac15d9 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -124,8 +124,12 @@ std::pair Float8Quantizer::create_tensor( std::optional data = std::nullopt; std::optional columnwise_data = std::nullopt; - // TODO: Replace with an empty tensor. - at::Tensor scale_inv = at::reciprocal(scale); + at::Tensor scale_inv = create_torch_tensor(scale.sizes().vec(), scale.options(), + output, "_scale_inv"); + // TODO: Remove + if (rowwise_data.has_value()) { + at::reciprocal_out(scale_inv, scale); + } bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); TensorWrapper tensor(this->get_scaling_mode()); @@ -257,7 +261,13 @@ std::pair Float8CurrentScalingQuantizer::create_tenso std::optional data = std::nullopt; std::optional columnwise_data = std::nullopt; // In current scaling, scale is not known but we initialize it with 1 to avoid division by zero. If scale is already calculated, it can be correctly set. 
- at::Tensor scale_inv = at::reciprocal(scale); + at::Tensor scale_inv = create_torch_tensor(scale.sizes().vec(), scale.options(), + output, "_scale_inv"); + // TODO: Remove + if (rowwise_data.has_value()) { + at::reciprocal_out(scale_inv, scale); + } + TensorWrapper tensor(this->get_scaling_mode()); From 1f4f894b9b0f16b4fa6d2926b29d48f7b9be49ec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Jun 2025 00:01:29 +0000 Subject: [PATCH 18/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/csrc/extensions/pybind.cpp | 4 +-- transformer_engine/pytorch/csrc/quantizer.cpp | 29 +++++++++---------- 2 files changed, 14 insertions(+), 19 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index ab7eb41814..d3274f3ee5 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -31,9 +31,7 @@ PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; py::object PythonCopy; -py::object python_copy(const py::handle& src) { - return PythonCopy(src); -} +py::object python_copy(const py::handle &src) { return PythonCopy(src); } void init_copy_extension() { if (!PythonCopy.is_none()) return; diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 4209d2298e..86bd7164bd 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -124,8 +124,8 @@ std::pair Float8Quantizer::create_tensor( std::optional data = std::nullopt; std::optional columnwise_data = std::nullopt; - at::Tensor scale_inv = create_torch_tensor(scale.sizes().vec(), scale.options(), - output, "_scale_inv"); + at::Tensor scale_inv = + create_torch_tensor(scale.sizes().vec(), scale.options(), output, "_scale_inv"); // TODO: Remove if (rowwise_data.has_value()) { at::reciprocal_out(scale_inv, scale); @@ -261,14 +261,13 @@ std::pair Float8CurrentScalingQuantizer::create_tenso std::optional data = std::nullopt; std::optional columnwise_data = std::nullopt; // In current scaling, scale is not known but we initialize it with 1 to avoid division by zero. If scale is already calculated, it can be correctly set. 
- at::Tensor scale_inv = create_torch_tensor(scale.sizes().vec(), scale.options(), - output, "_scale_inv"); + at::Tensor scale_inv = + create_torch_tensor(scale.sizes().vec(), scale.options(), output, "_scale_inv"); // TODO: Remove if (rowwise_data.has_value()) { at::reciprocal_out(scale_inv, scale); } - TensorWrapper tensor(this->get_scaling_mode()); bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); @@ -509,8 +508,7 @@ std::pair Float8BlockQuantizer::create_tensor( "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype, - "quantizer"_a = python_copy(this->quantizer), - "is_2D_scaled"_a = (block_scaling_dim == 2), + "quantizer"_a = python_copy(this->quantizer), "is_2D_scaled"_a = (block_scaling_dim == 2), "data_format"_a = data_format); } } else { @@ -571,9 +569,9 @@ std::pair MXFP8Quantizer::create_tensor( auto output_tensor_type = PythonTensorType::INVALID; if (!output.is_none()) { output_tensor_type = detail::IsMXFP8Tensor(output.ptr()); - NVTE_CHECK(output_tensor_type != PythonTensorType::INVALID, "Wrong Tensor type provided for reuse. ", - "Expected MXFP8Tensor or MXFP8TensorBase, but got ", - py::repr(output).cast()); + NVTE_CHECK( + output_tensor_type != PythonTensorType::INVALID, "Wrong Tensor type provided for reuse. ", + "Expected MXFP8Tensor or MXFP8TensorBase, but got ", py::repr(output).cast()); } if (rowwise_usage) { @@ -617,12 +615,11 @@ std::pair MXFP8Quantizer::create_tensor( "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); } else { py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorPythonClass)); - ret = MXFP8TensorClass("shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), - "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, - "rowwise_scale_inv"_a = rowwise_scale_inv, - "columnwise_scale_inv"_a = columnwise_scale_inv, - "fp8_dtype"_a = this->dtype, - "quantizer"_a = python_copy(this->quantizer)); + ret = MXFP8TensorClass( + "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise, + "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = rowwise_scale_inv, + "columnwise_scale_inv"_a = columnwise_scale_inv, "fp8_dtype"_a = this->dtype, + "quantizer"_a = python_copy(this->quantizer)); } } else { output.attr("_rowwise_data") = data_rowwise; From 53554f27a06bb752c11fa417f8c2fffaef87d512 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 11 Jun 2025 17:06:53 -0700 Subject: [PATCH 19/25] Fix Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/csrc/extensions/pybind.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d3274f3ee5..d7f78d4e30 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -29,7 +29,7 @@ PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr; PyTypeObject *Float8BlockwiseQTensorBasePythonClass = nullptr; PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; -py::object PythonCopy; +py::object PythonCopy = py::none(); py::object python_copy(const py::handle &src) { return PythonCopy(src); } From 343d43d335992302ce338e7855fbfe50f3ac8e53 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 13 Jun 2025 14:05:25 -0700 Subject: [PATCH 20/25] 
Fix the float8blockwise tests and MXFP8 cuda graphs tests Signed-off-by: Przemek Tredak --- tests/pytorch/test_float8blockwisetensor.py | 2 +- transformer_engine/pytorch/csrc/quantizer.cpp | 86 ++++++++++++------- 2 files changed, 57 insertions(+), 31 deletions(-) diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index 1f23be3626..0b8b445856 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -206,7 +206,7 @@ def test_quantize_dequantize_dims( @pytest.mark.parametrize( "dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]] ) - @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + @pytest.mark.parametrize("block_scaling_dim", [1]) @pytest.mark.parametrize("dq_columnwise", [True, False]) @pytest.mark.xfail(raises=NotImplementedError) def test_quantize_dequantize_compact_format( diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 86bd7164bd..afaee2cc61 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -48,34 +48,35 @@ Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quanti } // Create torch tensor reusing existing data if possible -at::Tensor _create_torch_tensor(const std::vector& shape, const at::TensorOptions& opts, - const py::object& tensor_to_reuse, bool zero_out) { +std::pair create_torch_tensor_ex( + const std::vector& shape, const at::TensorOptions& opts, + const py::object& tensor_to_reuse) { if (!tensor_to_reuse.is_none()) { // Reuse output const at::Tensor temp = tensor_to_reuse.cast(); if (tensor_is_reusable(temp, shape, opts)) { - if (zero_out) { - temp.zero_(); - } - return temp; + return {temp, true}; } } - if (zero_out) { - return at::zeros(shape, opts); - } - return at::empty(shape, opts); + return {at::empty(shape, opts), false}; } // Create torch tensor reusing existing data is possible // The reused tensor is tensor_to_reuse.attr_name -at::Tensor create_torch_tensor(const std::vector& shape, const at::TensorOptions& opts, - const py::object& tensor_to_reuse, const std::string_view& attr_name, - bool zero_out = false) { +std::pair create_torch_tensor_ex( + const std::vector& shape, const at::TensorOptions& opts, + const py::object& tensor_to_reuse, const std::string_view& attr_name) { py::object tensor{py::none()}; if (!tensor_to_reuse.is_none()) { tensor = tensor_to_reuse.attr(attr_name.data()); } - return _create_torch_tensor(shape, opts, tensor, zero_out); + return create_torch_tensor_ex(shape, opts, tensor); +} + +at::Tensor create_torch_tensor( + const std::vector& shape, const at::TensorOptions& opts, + const py::object& tensor_to_reuse, const std::string_view& attr_name) { + return create_torch_tensor_ex(shape, opts, tensor_to_reuse, attr_name).first; } std::pair NoneQuantizer::create_tensor( @@ -88,7 +89,7 @@ std::pair NoneQuantizer::create_tensor( if (rowwise_data.has_value()) { ret = std::move(*rowwise_data); } else { - ret = _create_torch_tensor(torch_shape, opts, output, false); + std::tie(ret, std::ignore) = create_torch_tensor_ex(torch_shape, opts, output); } TensorWrapper tensor; @@ -124,8 +125,7 @@ std::pair Float8Quantizer::create_tensor( std::optional data = std::nullopt; std::optional columnwise_data = std::nullopt; - at::Tensor scale_inv = - create_torch_tensor(scale.sizes().vec(), scale.options(), output, "_scale_inv"); + auto scale_inv = 
create_torch_tensor(scale.sizes().vec(), scale.options(), output, "_scale_inv"); // TODO: Remove if (rowwise_data.has_value()) { at::reciprocal_out(scale_inv, scale); @@ -261,8 +261,7 @@ std::pair Float8CurrentScalingQuantizer::create_tenso std::optional data = std::nullopt; std::optional columnwise_data = std::nullopt; // In current scaling, scale is not known but we initialize it with 1 to avoid division by zero. If scale is already calculated, it can be correctly set. - at::Tensor scale_inv = - create_torch_tensor(scale.sizes().vec(), scale.options(), output, "_scale_inv"); + auto scale_inv = create_torch_tensor(scale.sizes().vec(), scale.options(), output, "_scale_inv"); // TODO: Remove if (rowwise_data.has_value()) { at::reciprocal_out(scale_inv, scale); @@ -429,9 +428,8 @@ std::pair Float8BlockQuantizer::create_tensor( block_scaling_dim); } break; } - scale_inv_rowwise = - create_torch_tensor({static_cast(sinv0), static_cast(sinv1)}, scale_opts, - output, "_rowwise_scale_inv"); + scale_inv_rowwise = create_torch_tensor({static_cast(sinv0), static_cast(sinv1)}, + scale_opts, output, "_rowwise_scale_inv"); tensor.set_rowwise_data(data_rowwise->data_ptr(), this->dtype, shape); tensor.set_rowwise_scale_inv(scale_inv_rowwise->data_ptr(), DType::kFloat32, std::vector{sinv0, sinv1}); @@ -575,16 +573,30 @@ std::pair MXFP8Quantizer::create_tensor( } if (rowwise_usage) { + bool data_reused = false; if (rowwise_data.has_value()) { data_rowwise = std::move(*rowwise_data); } else { - data_rowwise = create_torch_tensor(torch_shape, opts, output, "_rowwise_data"); + std::tie(data_rowwise, data_reused) = create_torch_tensor_ex(torch_shape, opts, output, + "_rowwise_data"); } auto sinv0 = roundup(numel / last_dim, 128lu); auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4lu); - rowwise_scale_inv = - create_torch_tensor({static_cast(sinv0), static_cast(sinv1)}, opts, - output, "_rowwise_scale_inv", true); + bool scale_inv_reused = false; + std::tie(rowwise_scale_inv, scale_inv_reused) = + create_torch_tensor_ex({static_cast(sinv0), static_cast(sinv1)}, opts, + output, "_rowwise_scale_inv"); + if (!data_reused || !scale_inv_reused) { + // cuBLAS requires scale inverse tensor to be zero-padded. + // If both the data and the scale inverse are reused though + // then the provided output tensor was a valid MXFP8 tensor + // so we can assume that the zero padding is already done. + // Not doing this zero-padding in such scenario is important + // in the CUDA graphs + weight caching case where the reuse + // would happen but the quantization would not be performed + // again. 
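// (Worked example of the padding that the zero_() call below clears, using the
//  roundup() expressions above; the concrete sizes are only an illustration:
//    shape = [256, 64]  ->  numel / last_dim = 256,  last_dim / MXFP8_BLOCK_SIZE = 2
//    sinv0 = roundup(256, 128lu) = 256,  sinv1 = roundup(2, 4lu) = 4
//  so the scale-inverse buffer is 256 x 4 while quantization only writes a
//  256 x 2 region; the trailing columns are exactly the padding that cuBLAS
//  expects to be zero-filled.)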
+ rowwise_scale_inv->zero_(); + } tensor.set_rowwise_data(data_rowwise->data_ptr(), this->dtype, shape); tensor.set_rowwise_scale_inv( rowwise_scale_inv->data_ptr(), DType::kFloat8E8M0, @@ -594,10 +606,24 @@ std::pair MXFP8Quantizer::create_tensor( if (columnwise_usage) { auto sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4lu); auto sinv1 = roundup(last_dim, 128lu); - data_colwise = create_torch_tensor(torch_shape, opts, output, "_columnwise_data"); - columnwise_scale_inv = - create_torch_tensor({static_cast(sinv0), static_cast(sinv1)}, opts, - output, "_columnwise_scale_inv", true); + bool data_reused = false; + std::tie(data_colwise, data_reused) = create_torch_tensor_ex(torch_shape, opts, output, + "_columnwise_data"); + bool scale_inv_reused = false; + std::tie(columnwise_scale_inv, scale_inv_reused) = + create_torch_tensor_ex({static_cast(sinv0), static_cast(sinv1)}, opts, + output, "_columnwise_scale_inv"); + if (!data_reused || !scale_inv_reused) { + // cuBLAS requires scale inverse tensor to be zero-padded. + // If both the data and the scale inverse are reused though + // then the provided output tensor was a valid MXFP8 tensor + // so we can assume that the zero padding is already done. + // Not doing this zero-padding in such scenario is important + // in the CUDA graphs + weight caching case where the reuse + // would happen but the quantization would not be performed + // again. + columnwise_scale_inv->zero_(); + } tensor.set_columnwise_data(data_colwise->data_ptr(), this->dtype, shape); tensor.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E8M0, From 817d8ceb1d14b158adbc1ca6248e5c44ab5111ed Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Jun 2025 21:06:12 +0000 Subject: [PATCH 21/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/quantizer.cpp | 32 ++++++++++--------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index afaee2cc61..20da35f262 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -48,9 +48,9 @@ Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quanti } // Create torch tensor reusing existing data if possible -std::pair create_torch_tensor_ex( - const std::vector& shape, const at::TensorOptions& opts, - const py::object& tensor_to_reuse) { +std::pair create_torch_tensor_ex(const std::vector& shape, + const at::TensorOptions& opts, + const py::object& tensor_to_reuse) { if (!tensor_to_reuse.is_none()) { // Reuse output const at::Tensor temp = tensor_to_reuse.cast(); @@ -63,9 +63,10 @@ std::pair create_torch_tensor_ex( // Create torch tensor reusing existing data is possible // The reused tensor is tensor_to_reuse.attr_name -std::pair create_torch_tensor_ex( - const std::vector& shape, const at::TensorOptions& opts, - const py::object& tensor_to_reuse, const std::string_view& attr_name) { +std::pair create_torch_tensor_ex(const std::vector& shape, + const at::TensorOptions& opts, + const py::object& tensor_to_reuse, + const std::string_view& attr_name) { py::object tensor{py::none()}; if (!tensor_to_reuse.is_none()) { tensor = tensor_to_reuse.attr(attr_name.data()); @@ -73,9 +74,9 @@ std::pair create_torch_tensor_ex( return create_torch_tensor_ex(shape, opts, tensor); } 
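// (Usage sketch for the helpers above; the dtype, shape and attribute name are
//  only illustrative, and example_output may be py::none() when there is
//  nothing to reuse:)

py::object example_output = py::none();
auto example_opts = at::TensorOptions().dtype(at::kByte).device(at::kCUDA);
auto [example_scale_inv, example_reused] =
    create_torch_tensor_ex({128, 4}, example_opts, example_output, "_rowwise_scale_inv");
if (!example_reused) example_scale_inv.zero_();  // freshly allocated buffers still need the zero padding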
-at::Tensor create_torch_tensor( - const std::vector& shape, const at::TensorOptions& opts, - const py::object& tensor_to_reuse, const std::string_view& attr_name) { +at::Tensor create_torch_tensor(const std::vector& shape, const at::TensorOptions& opts, + const py::object& tensor_to_reuse, + const std::string_view& attr_name) { return create_torch_tensor_ex(shape, opts, tensor_to_reuse, attr_name).first; } @@ -428,8 +429,9 @@ std::pair Float8BlockQuantizer::create_tensor( block_scaling_dim); } break; } - scale_inv_rowwise = create_torch_tensor({static_cast(sinv0), static_cast(sinv1)}, - scale_opts, output, "_rowwise_scale_inv"); + scale_inv_rowwise = + create_torch_tensor({static_cast(sinv0), static_cast(sinv1)}, scale_opts, + output, "_rowwise_scale_inv"); tensor.set_rowwise_data(data_rowwise->data_ptr(), this->dtype, shape); tensor.set_rowwise_scale_inv(scale_inv_rowwise->data_ptr(), DType::kFloat32, std::vector{sinv0, sinv1}); @@ -577,8 +579,8 @@ std::pair MXFP8Quantizer::create_tensor( if (rowwise_data.has_value()) { data_rowwise = std::move(*rowwise_data); } else { - std::tie(data_rowwise, data_reused) = create_torch_tensor_ex(torch_shape, opts, output, - "_rowwise_data"); + std::tie(data_rowwise, data_reused) = + create_torch_tensor_ex(torch_shape, opts, output, "_rowwise_data"); } auto sinv0 = roundup(numel / last_dim, 128lu); auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4lu); @@ -607,8 +609,8 @@ std::pair MXFP8Quantizer::create_tensor( auto sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4lu); auto sinv1 = roundup(last_dim, 128lu); bool data_reused = false; - std::tie(data_colwise, data_reused) = create_torch_tensor_ex(torch_shape, opts, output, - "_columnwise_data"); + std::tie(data_colwise, data_reused) = + create_torch_tensor_ex(torch_shape, opts, output, "_columnwise_data"); bool scale_inv_reused = false; std::tie(columnwise_scale_inv, scale_inv_reused) = create_torch_tensor_ex({static_cast(sinv0), static_cast(sinv1)}, opts, From 207e4b7e04009a271cd86ad8adf58c442414d3bc Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 13 Jun 2025 14:48:16 -0700 Subject: [PATCH 22/25] Fix issue from merge Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/ops/basic/basic_linear.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 507d0d77e1..bc01e40301 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -313,7 +313,9 @@ def pre_forward(self, *args, **kwargs) -> None: # Configure quantizers if FP8GlobalStateManager.is_fp8_enabled(): + input_quantizer = self.get_quantizer("forward", 0) weight_quantizer = self.get_quantizer("forward", 1) + grad_output_quantizer = self.get_quantizer("backward", 0) # Recipe-specific configuration recipe = FP8GlobalStateManager.get_fp8_recipe() From 715cc53f0ea6f0acb4246f46823923f454d533b4 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 13 Jun 2025 15:35:32 -0700 Subject: [PATCH 23/25] Always use tex.quantize when updating cache to use proper quantizer Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/module/base.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 1b283362c1..d76dae476a 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1320,10 +1320,7 @@ def 
get_weight_workspace( if update_workspace: if tensor is None: raise ValueError("tensor kwarg must be provided to update FP8 workspace") - if hasattr(out, "quantize_"): - out.quantize_(tensor, noop_flag=skip_update_flag) - else: - tex.quantize(tensor, quantizer, out, skip_update_flag) + tex.quantize(tensor, quantizer, out, skip_update_flag) return out def _load_from_state_dict( From d682178c1b4e1a4d9b9eef48aa445ae7db2150a1 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 17 Jun 2025 19:11:05 +0000 Subject: [PATCH 24/25] Debug Signed-off-by: Przemek Tredak --- tests/pytorch/test_sanity.py | 80 ++++++++++++++++--- .../dot_product_attention/backends.py | 9 +++ .../pytorch/cpp_extensions/fused_attn.py | 3 + transformer_engine/pytorch/csrc/extensions.h | 2 +- .../pytorch/csrc/extensions/attention.cpp | 79 +++++++++++++++++- 5 files changed, 162 insertions(+), 11 deletions(-) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 7db5aef202..623836f1cf 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -104,7 +104,7 @@ def is_fp8_supported(self): model_configs = { "126m": ModelConfig(12, 2048, 2, 768, 12), - "small": ModelConfig(2, 32, 2, 64, 2), + "small": ModelConfig(2, 16, 2, 128, 1), "weird": ModelConfig(2, 37, 3, 69, 3), "large": ModelConfig(1, 128, 2, 512, 4, 128), } @@ -1152,27 +1152,31 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype): @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.") @pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.") -@pytest.mark.parametrize("model", ["large"]) +@pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_sanity_attention_extra_state(model, dtype): config = model_configs[model] + print("regular") outputs = _run_attention_extra_state(dtype, config, checkpoint=False) + print("checkpointed") outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True) - outputs_checkpoint_v1_6 = _run_attention_extra_state( - dtype, config, mimic_v1_6=True, checkpoint=True - ) + # outputs_checkpoint_v1_6 = _run_attention_extra_state( + # dtype, config, mimic_v1_6=True, checkpoint=True + # ) # Check that results match tols = dtype_tols(dtype) if dtype in (torch.float16, torch.bfloat16): tols.update(dict(rtol=2e-2, atol=2e-3)) for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)): + print(i) torch.testing.assert_close( test, ref, **tols, ) for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)): + print(f"Second loop {i}") torch.testing.assert_close( test, ref, @@ -1201,6 +1205,8 @@ def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False requires_grad=True, ) + torch.set_printoptions(threshold=100_000_000) + def get_model(dtype, config): sigma = 0.023 init_method = init_method_normal(sigma) @@ -1219,15 +1225,58 @@ def get_model(dtype, config): params_dtype=dtype, device="cuda", ) + # block = torch.nn.Sequential( + # Linear(config.hidden_size, + # config.hidden_size), + # Linear(config.hidden_size, + # config.hidden_size), + # Linear(config.hidden_size, + # config.hidden_size), + # Linear(config.hidden_size, + # config.hidden_size)) + # block.to(dtype=dtype) return block block = get_model(dtype, config) + print("Before the first loop") + # for n,p in block.named_parameters(): + # print(n) + # print(p) + print("data") + 
print(block.self_attention.proj.weight._data) + print("scale_inv") + print(block.self_attention.proj.weight._scale_inv) + print("transpose") + print(block.self_attention.proj.weight._transpose) + + import transformer_engine.pytorch.attention.dot_product_attention.backends as bbb + bbb.DEBUG_BLOCK = block + print("set!") + + print("End before the first loop") + print(f"scale inv: {block.self_attention.proj.weight._scale_inv}") for i in range(steps // 2): with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): - output = block(hidden_states, None) + print(f"scale inv 0: {block.self_attention.proj.weight._scale_inv}") + output = block(hidden_states) + print(f"scale inv 1: {block.self_attention.proj.weight._scale_inv}") + print(f"output {i}") + print(output) loss = output.sum() - loss.backward() - + loss.backward() + print(f"scale inv 2: {block.self_attention.proj.weight._scale_inv}") + + print("Before the checkpoint") + # for n,p in block.named_parameters(): + # print(n) + # print(p) + print("data") + print(block.self_attention.proj.weight._data) + print("scale_inv") + print(block.self_attention.proj.weight._scale_inv) + print("transpose") + print(block.self_attention.proj.weight._transpose) + print("End before the checkpoint") if checkpoint: sd = block.state_dict() if mimic_v1_6: @@ -1259,10 +1308,23 @@ def get_model(dtype, config): for i in range((steps + 1) // 2): with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): - output = block(hidden_states, None) + output = block(hidden_states) + print(f"after output {i}") + print(output) loss = output.sum() loss.backward() + print("After the checkpoint") + # for n,p in block.named_parameters(): + # print(n) + # print(p) + print("data") + print(block.self_attention.proj.weight._data) + print("scale_inv") + print(block.self_attention.proj.weight._scale_inv) + print("transpose") + print(block.self_attention.proj.weight._transpose) + print("End after the checkpoint") torch.cuda.synchronize() if os.path.exists(path): diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index b3b7630df3..a370ff349d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -57,6 +57,8 @@ AttentionLogging as attn_log, ) +DEBUG_BLOCK = None + # Global vars for flash attn v2 and v3 imports flash_attn_cuda_bwd = None flash_attn_func = None @@ -964,6 +966,8 @@ def forward( case _: raise "Invalid qkv_layout " + qkv_layout # q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn + print(f"Q quantizer scale: {q_fp8._quantizer.scale.shape}") + print(f"mixed quantizer scale: {qkv_fp8._quantizer.scale.shape}") out_fp8, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, @@ -1190,6 +1194,9 @@ def backward(ctx, d_out): dqkv_dtype = TE_DType[d_out_fp8._data.dtype] # q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn # d_out_fp8, dq_fp8, dk_fp8, dv_fp8: torch.float8_e5m2 + print(DEBUG_BLOCK) + if DEBUG_BLOCK is not None: + print(f"Inside attention: {DEBUG_BLOCK.self_attention.proj.weight._scale_inv}") dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -1218,6 +1225,8 @@ def backward(ctx, d_out): ctx.window_size, ctx.deterministic, ) + if DEBUG_BLOCK is not None: + print(f"After Inside attention: {DEBUG_BLOCK.self_attention.proj.weight._scale_inv}") # is_input_fp8 = False: dq, dk, dv: torch.float16 or torch.bfloat16 # 
is_input_fp8 = True: dq, dk, dv: torch.float8_e5m2 diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index b9810bf861..f5ed383241 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -445,6 +445,8 @@ def fused_attn_bwd( len(aux_ctx_tensors) == 3 ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." + import transformer_engine.pytorch.attention.dot_product_attention.backends as bbb + debug = bbb.DEBUG_BLOCK.self_attention.proj.weight._scale_inv output_tensors = tex.fused_attn_bwd( max_seqlen_q, max_seqlen_kv, @@ -471,6 +473,7 @@ def fused_attn_bwd( s_quantizer, dp_quantizer, dqkv_quantizer, + debug, ) return output_tensors diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 72f6f27596..61a48534ad 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -62,7 +62,7 @@ std::vector fused_attn_bwd( const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, - py::handle dp_quantizer, py::handle dqkv_quantizer); + py::handle dp_quantizer, py::handle dqkv_quantizer, at::Tensor debug); at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 1ba459a36d..d13d78212d 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -4,6 +4,7 @@ * See LICENSE for license information. 
************************************************************************/ +#include #include "../extensions.h" #include "common.h" #include "pybind.h" @@ -101,6 +102,7 @@ std::vector fused_attn_fwd( // create output tensor O auto o_shape = std::vector{q_shape.begin(), q_shape.end()}; + std::cout << "O shape size: " << o_shape.size() << " V: " << v_shape.size() << std::endl; o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1]; py::object o_python, s_python; std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te); @@ -258,6 +260,41 @@ std::vector fused_attn_fwd( return output_tensors; } +float debug_print(const std::string& name, const at::Tensor& t) { + float ret; + cudaMemcpy(&ret, t.data_ptr(), sizeof(float), cudaMemcpyDeviceToHost); + std::cout << name << " " << ret << " " << *reinterpret_cast(&ret) << std::endl; + return ret; +} + +void debug_print(const std::string& name, const NVTETensor t, bool with_values = false) { + int sizes[] = {1,2,4,8,4,2,2,1,1,1,1}; + for (int i = 0; i < 6; ++i) { + auto param = nvte_get_tensor_param(t, (NVTETensorParam)i); + uintptr_t start = reinterpret_cast(param.data_ptr); + if (start != 0) { + auto num = product(param.shape, 0, param.shape.ndim); + auto end = start + num * sizes[(int)(param.dtype)]; + std::cout << name << " " << start << " " << end << std::endl; + std::cout << name << " shape: " << std::to_string(param.shape) << " dtype: " << std::to_string((int)(param.dtype)) << std::endl; + if (with_values) { + if ((int)(param.dtype) == 2) { + int32_t * values = new int32_t[num]; + cudaMemcpy(values, param.data_ptr, num * sizeof(uint32_t), cudaMemcpyDeviceToHost); + std::cout << name << " Values" << std::endl; + for (int i = 0; i < num; ++i) { + std::cout << i << " " << values[i] << std::endl; + } + std::cout << name << " End Values" << std::endl; + delete[] values; + } + } + } else { + std::cout << name << " 0" << std::endl; + } + } +} + // fused attention BWD with separate Q, K and V std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, @@ -268,7 +305,7 @@ std::vector fused_attn_bwd( const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, - py::handle dp_quantizer, py::handle dqkv_quantizer) { + py::handle dp_quantizer, py::handle dqkv_quantizer, at::Tensor debug) { auto none = py::none(); TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV; te_Q = makeTransformerEngineTensor(Q, none); @@ -276,6 +313,7 @@ std::vector fused_attn_bwd( te_V = makeTransformerEngineTensor(V, none); te_O = makeTransformerEngineTensor(O, none); te_dO = makeTransformerEngineTensor(dO, none); + debug_print("Beginning", debug); // qkv type from the te_Q std::unique_ptr dQKV_quantizer = convert_quantizer(dqkv_quantizer); const DType qkv_type = te_Q.dtype(); @@ -374,9 +412,13 @@ std::vector fused_attn_bwd( default: NVTE_ERROR("QKV layout not supported!"); } + std::cout << "Creating dQKV" << std::endl; std::tie(te_dQ, py_dQ) = dQKV_quantizer->create_tensor(q_shape, fake_dtype_te, py::none(), dQ); std::tie(te_dK, py_dK) = dQKV_quantizer->create_tensor(k_shape, fake_dtype_te, py::none(), dK); std::tie(te_dV, py_dV) = dQKV_quantizer->create_tensor(v_shape, fake_dtype_te, py::none(), dV); + std::cout << "dq " << reinterpret_cast(dQ.data_ptr()) << std::endl; + std::cout << "dk " << reinterpret_cast(dK.data_ptr()) << std::endl; + std::cout << "dv " << reinterpret_cast(dV.data_ptr()) 
<< std::endl; // construct NVTE tensors if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { @@ -464,6 +506,7 @@ std::vector fused_attn_bwd( // create workspace TensorWrapper workspace; + debug_print("Before nvte call1", debug); // populate tensors with appropriate shapes and dtypes NVTE_SCOPED_GIL_RELEASE({ @@ -481,6 +524,38 @@ std::vector fused_attn_bwd( workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); + debug_print("Before nvte call2", debug); + std::cout << "debug data ptr: " << reinterpret_cast(debug.data_ptr()) << std::endl; + debug_print("Q", te_Q.data()); + debug_print("K", te_K.data()); + debug_print("V", te_V.data()); + debug_print("O", te_O.data()); + debug_print("dO", te_dO.data()); + debug_print("S", te_S.data()); + debug_print("dP", te_dP.data()); + debug_print("dQ", te_dQ.data()); + debug_print("dK", te_dK.data()); + debug_print("dV", te_dV.data()); + debug_print("dBias", te_dBias.data()); + debug_print("cuseq_q", te_cu_seqlens_q.data(), true); + debug_print("cuseq_q_padded", te_cu_seqlens_q_padded.data(), true); + debug_print("cuseq_kv", te_cu_seqlens_kv.data(), true); + debug_print("cuseq_kv_padded", te_cu_seqlens_kv_padded.data(), true); + debug_print("workspace", workspace.data()); + std::cout << "max_seqlen_q " << max_seqlen_q << std::endl; + std::cout << "max_seqlen_kv " << max_seqlen_kv << std::endl; + std::cout << "attn_scale " << attn_scale << std::endl; + std::cout << "p_dropout " << p_dropout << std::endl; + std::cout << "qkv_layout " << qkv_layout << std::endl; + std::cout << "bias_type " << bias_type << std::endl; + std::cout << "attn_mask_type " << attn_mask_type << std::endl; + std::cout << "window_size[0] " << window_size[0] << std::endl; + std::cout << "window_size[1] " << window_size[1] << std::endl; + std::cout << "deterministic " << deterministic << std::endl; + + for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { + debug_print("Aux" + std::to_string(i), nvte_aux_tensor_pack.tensors[i]); + } // execute kernel NVTE_SCOPED_GIL_RELEASE({ nvte_fused_attn_bwd( @@ -492,8 +567,10 @@ std::vector fused_attn_bwd( workspace.data(), at::cuda::getCurrentCUDAStream()); }); + float result = debug_print("Before destroy", debug); // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); + debug_print("Ending", debug); return {py_dQ, py_dK, py_dV, py::cast(dBias)}; } From e6f38d1a9ebf1aa49ac20991505d3241c4afb62a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Jun 2025 19:12:06 +0000 Subject: [PATCH 25/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_sanity.py | 3 ++- .../attention/dot_product_attention/backends.py | 9 +++++++-- .../pytorch/cpp_extensions/fused_attn.py | 1 + .../pytorch/csrc/extensions/attention.cpp | 14 ++++++++------ 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 623836f1cf..e30ccd048e 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -1248,8 +1248,9 @@ def get_model(dtype, config): print(block.self_attention.proj.weight._scale_inv) print("transpose") print(block.self_attention.proj.weight._transpose) - + import transformer_engine.pytorch.attention.dot_product_attention.backends as bbb + bbb.DEBUG_BLOCK = block print("set!") diff --git 
a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index a370ff349d..92a89057be 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1196,7 +1196,9 @@ def backward(ctx, d_out): # d_out_fp8, dq_fp8, dk_fp8, dv_fp8: torch.float8_e5m2 print(DEBUG_BLOCK) if DEBUG_BLOCK is not None: - print(f"Inside attention: {DEBUG_BLOCK.self_attention.proj.weight._scale_inv}") + print( + f"Inside attention: {DEBUG_BLOCK.self_attention.proj.weight._scale_inv}" + ) dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -1226,7 +1228,10 @@ def backward(ctx, d_out): ctx.deterministic, ) if DEBUG_BLOCK is not None: - print(f"After Inside attention: {DEBUG_BLOCK.self_attention.proj.weight._scale_inv}") + print( + "After Inside attention:" + f" {DEBUG_BLOCK.self_attention.proj.weight._scale_inv}" + ) # is_input_fp8 = False: dq, dk, dv: torch.float16 or torch.bfloat16 # is_input_fp8 = True: dq, dk, dv: torch.float8_e5m2 diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index f5ed383241..fc3317b2fa 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -446,6 +446,7 @@ def fused_attn_bwd( ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." import transformer_engine.pytorch.attention.dot_product_attention.backends as bbb + debug = bbb.DEBUG_BLOCK.self_attention.proj.weight._scale_inv output_tensors = tex.fused_attn_bwd( max_seqlen_q, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index d13d78212d..5b16cb04ad 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -5,6 +5,7 @@ ************************************************************************/ #include + #include "../extensions.h" #include "common.h" #include "pybind.h" @@ -260,15 +261,15 @@ std::vector fused_attn_fwd( return output_tensors; } -float debug_print(const std::string& name, const at::Tensor& t) { +float debug_print(const std::string &name, const at::Tensor &t) { float ret; cudaMemcpy(&ret, t.data_ptr(), sizeof(float), cudaMemcpyDeviceToHost); - std::cout << name << " " << ret << " " << *reinterpret_cast(&ret) << std::endl; + std::cout << name << " " << ret << " " << *reinterpret_cast(&ret) << std::endl; return ret; } -void debug_print(const std::string& name, const NVTETensor t, bool with_values = false) { - int sizes[] = {1,2,4,8,4,2,2,1,1,1,1}; +void debug_print(const std::string &name, const NVTETensor t, bool with_values = false) { + int sizes[] = {1, 2, 4, 8, 4, 2, 2, 1, 1, 1, 1}; for (int i = 0; i < 6; ++i) { auto param = nvte_get_tensor_param(t, (NVTETensorParam)i); uintptr_t start = reinterpret_cast(param.data_ptr); @@ -276,10 +277,11 @@ void debug_print(const std::string& name, const NVTETensor t, bool with_values = auto num = product(param.shape, 0, param.shape.ndim); auto end = start + num * sizes[(int)(param.dtype)]; std::cout << name << " " << start << " " << end << std::endl; - std::cout << name << " shape: " << std::to_string(param.shape) << " dtype: " << std::to_string((int)(param.dtype)) << std::endl; + std::cout << name << " shape: " << 
std::to_string(param.shape) + << " dtype: " << std::to_string((int)(param.dtype)) << std::endl; if (with_values) { if ((int)(param.dtype) == 2) { - int32_t * values = new int32_t[num]; + int32_t *values = new int32_t[num]; cudaMemcpy(values, param.data_ptr, num * sizeof(uint32_t), cudaMemcpyDeviceToHost); std::cout << name << " Values" << std::endl; for (int i = 0; i < num; ++i) {
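            // Copies the int32 elements back to the host and prints index/value pairs;
            // in fused_attn_bwd this path is used to inspect the cu_seqlens tensors.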