diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 13b27a0b11..9b9bb58acd 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -837,10 +837,9 @@ def _test_basic_linear( pytest.skip("FP8 output is only supported with FP8 GEMMs") if quantized_grad_input and not quantized_compute: pytest.skip("FP8 grad input is only supported with FP8 GEMMs") - if quantization == "mxfp8" and quantized_output: - pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs") - if quantization == "mxfp8" and quantized_grad_input: - pytest.skip("MXFP8 grad input is not supported with MXFP8 GEMMs") + if quantization not in (None, "fp8"): + if quantized_output or quantized_grad_input: + pytest.skip("Recipe does not support quantized GEMM output") # Random data x_ref, x_test = make_reference_and_test_tensors( diff --git a/transformer_engine/common/libtransformer_engine.version b/transformer_engine/common/libtransformer_engine.version index 4412d0c5fe..706c237ccc 100644 --- a/transformer_engine/common/libtransformer_engine.version +++ b/transformer_engine/common/libtransformer_engine.version @@ -8,6 +8,7 @@ transformer_engine::cuda::stream_priority_range*; transformer_engine::cuda::current_device*; transformer_engine::cuda_driver::get_symbol*; + transformer_engine::cuda_driver::ensure_context_exists*; transformer_engine::ubuf_built_with_mpi*; *transformer_engine::rtc*; transformer_engine::nvte_cudnn_handle_init*; diff --git a/transformer_engine/common/util/cuda_driver.cpp b/transformer_engine/common/util/cuda_driver.cpp index 59d490e58e..4812435f7b 100644 --- a/transformer_engine/common/util/cuda_driver.cpp +++ b/transformer_engine/common/util/cuda_driver.cpp @@ -44,6 +44,19 @@ void *get_symbol(const char *symbol, int cuda_version) { return entry_point; } +void ensure_context_exists() { + CUcontext context; + NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxGetCurrent, &context); + if (context == nullptr) { + // Add primary context to context stack + CUdevice device; + NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &device, cuda::current_device()); + NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &context, device); + NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context); + NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRelease, device); + } +} + } // namespace cuda_driver } // namespace transformer_engine diff --git a/transformer_engine/common/util/cuda_driver.h b/transformer_engine/common/util/cuda_driver.h index a0fcd65c85..3425e0af35 100644 --- a/transformer_engine/common/util/cuda_driver.h +++ b/transformer_engine/common/util/cuda_driver.h @@ -39,6 +39,14 @@ inline CUresult call(const char *symbol, ArgTs... args) { return (*func)(args...); } +/*! \brief Ensure that the calling thread has a CUDA context + * + * Each thread maintains a stack of CUDA contexts. If the calling + * thread has an empty stack, the primary context is added to the + * stack. 
+ */ +void ensure_context_exists(); + } // namespace cuda_driver } // namespace transformer_engine diff --git a/transformer_engine/common/util/rtc.h b/transformer_engine/common/util/rtc.h index 820b16c206..7de1e4d55c 100644 --- a/transformer_engine/common/util/rtc.h +++ b/transformer_engine/common/util/rtc.h @@ -59,6 +59,7 @@ class Kernel { template void launch(int device_id, const dim3 grid_dim, const dim3 block_dim, unsigned int shared_mem_bytes, cudaStream_t stream, ArgTs &&...args) { + cuda_driver::ensure_context_exists(); void *arg_ptrs[] = {const_cast(static_cast(&args))...}; NVTE_CALL_CHECK_CUDA_DRIVER(cuLaunchKernel, get_function(device_id), grid_dim.x, grid_dim.y, grid_dim.z, block_dim.x, block_dim.y, block_dim.z, shared_mem_bytes, diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index f86b60f612..ab3b7abec4 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -12,7 +12,7 @@ namespace transformer_engine::pytorch { -std::vector getTensorShape(at::Tensor t) { +std::vector getTensorShape(const at::Tensor& t) { std::vector shape; for (auto s : t.sizes()) { shape.push_back(s); diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index b5b63f7574..be3b995a13 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -98,9 +98,21 @@ class Quantizer { virtual void set_quantization_params(TensorWrapper* tensor) const = 0; - virtual std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const = 0; + /*! @brief Construct a tensor with uninitialized data */ + virtual std::pair create_tensor(const std::vector& shape, + DType dtype) const = 0; + + /*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor + * + * The PyTorch tensor's attributes are modified to match the + * quantizer's configuration. + */ + virtual std::pair convert_and_update_tensor( + py::object tensor) const = 0; + + /*! @brief Convert to a quantized data format */ + virtual void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) = 0; virtual ~Quantizer() = default; @@ -121,9 +133,17 @@ class NoneQuantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override {} - std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const override; + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; + + /*! @brief Construct a tensor with pre-initialized data */ + std::pair create_tensor(const std::vector& shape, DType dtype, + at::Tensor data) const; + + std::pair convert_and_update_tensor(py::object tensor) const override; + + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; }; class Float8Quantizer : public Quantizer { @@ -139,9 +159,19 @@ class Float8Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const override; + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; + + /*! 
@brief Construct a tensor with pre-initialized data */ + std::pair create_tensor(const std::vector& shape, DType dtype, + std::optional data, + std::optional transpose, + std::optional scale_inv) const; + + std::pair convert_and_update_tensor(py::object shape) const override; + + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; }; class Float8CurrentScalingQuantizer : public Quantizer { @@ -161,9 +191,13 @@ class Float8CurrentScalingQuantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const override; + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; + + std::pair convert_and_update_tensor(py::object shape) const override; + + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; }; class Float8BlockQuantizer : public Quantizer { @@ -195,9 +229,13 @@ class Float8BlockQuantizer : public Quantizer { // Create a python Float8BlockQuantized tensor and C++ wrapper // for the tensor. Should set quantized data, scales for rowwise // and optionally columnwise usage. - std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const override; + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; + + std::pair convert_and_update_tensor(py::object shape) const override; + + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; @@ -212,16 +250,20 @@ class MXFP8Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const override; + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; + + std::pair convert_and_update_tensor(py::object shape) const override; + + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; std::unique_ptr convert_quantizer(py::handle quantizer); -std::vector getTensorShape(at::Tensor t); +std::vector getTensorShape(const at::Tensor& t); transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe); diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index dfc8a82913..c9eae092b0 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -13,87 +13,74 @@ namespace transformer_engine::pytorch { template py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) { init_extension(); - auto my_quantizer = convert_quantizer(quantizer); - auto input_tensor = input.contiguous(); - - const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); - const auto& te_input_shape = te_input.shape(); - std::vector input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim); - input_shape[input_shape.size() - 1] /= shape_divisor; - auto fake_tensor_type = 
input.scalar_type(); - - auto [te_output, out] = - my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); - - // for current scaling, we need to compute amax first and then quantize - // because cache cannot fit in the entire tensor to compute amax and quantize - // the quantizer should not need amax reduction, no process group needed here - if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - // activation function might change the input data range, we need to first call the activation function - // and then find the amax and scale of that and then do the quantization - // get a NoneQuantizer to calculate amax of activation output - auto my_quantizer_none = std::make_unique(py::none()); - auto [te_output_act, out_act] = - my_quantizer_none->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); - - NVTE_SCOPED_GIL_RELEASE({ - act_func(te_input.data(), te_output_act.data(), at::cuda::getCurrentCUDAStream()); - // use te_output_act as input to the compute amax and find the amax of activated tensor - nvte_compute_amax(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); - }); - // my_quantizer here has to be a Float8CurrentScalingQuantizer - auto my_quantizer_cs = static_cast(my_quantizer.get()); - if (my_quantizer_cs->with_amax_reduction) { - NVTE_ERROR( - "per-tensor current scaling amax reduction is not supported in activation functions."); - } - QuantizationConfigWrapper quant_config; - quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); - - NVTE_SCOPED_GIL_RELEASE({ - nvte_compute_scale_from_amax(te_output.data(), quant_config, - at::cuda::getCurrentCUDAStream()); - // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel - te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); - nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config, - at::cuda::getCurrentCUDAStream()); - }); - } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) { - // sanity check, since activation fusion is not supported for blockwise quantization yet - // need to raise an error here instead of silently going into act_func with wrong numerics - NVTE_ERROR("Activation fusion is not supported for blockwise quantization yet."); + // Input tensor + auto input_tensor = input.contiguous(); + const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor); + + // Construct output tensor + auto quantizer_cpp = convert_quantizer(quantizer); + const auto input_shape = input_cpp.shape(); + std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); + output_shape.back() /= shape_divisor; + auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); + auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype); + + // Compute activation + if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || + detail::IsMXFP8Quantizers(quantizer.ptr())) { + // Compute activation directly + NVTE_SCOPED_GIL_RELEASE( + { act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); }); } else { + // Compute activation in high-precision, then quantize + auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE( - { act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); }); + { act_func(input_cpp.data(), temp_cpp.data(), 
at::cuda::getCurrentCUDAStream()); }); + quantizer_cpp->quantize(temp_cpp, out_cpp); } - return out; + return out_py; } -template -py::object dactivation_helper(const at::Tensor& grad, const at::Tensor& input, +template +py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input, py::handle quantizer) { init_extension(); - auto my_quantizer = convert_quantizer(quantizer); - auto input_tensor = input.contiguous(); - auto grad_tensor = grad.contiguous(); - - const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); - const TensorWrapper& te_grad = makeTransformerEngineTensor(grad_tensor); - const auto& te_input_shape = te_input.shape(); - std::vector input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim); - auto fake_tensor_type = input.scalar_type(); - - auto [te_output, out] = - my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); - NVTE_SCOPED_GIL_RELEASE({ - act_func(te_grad.data(), te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); - }); + // Grad output and input tensors + auto grad_output_tensor = grad_output.contiguous(); + auto input_tensor = input.contiguous(); + const TensorWrapper& grad_output_cpp = makeTransformerEngineTensor(grad_output_tensor); + const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor); + + // Construct grad input tensor + auto quantizer_cpp = convert_quantizer(quantizer); + const auto input_shape_te = input_cpp.shape(); + const std::vector input_shape(input_shape_te.data, + input_shape_te.data + input_shape_te.ndim); + auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); + auto [grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype); + + // Compute activation backward + if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || + detail::IsMXFP8Quantizers(quantizer.ptr())) { + // Compute activation backward directly + NVTE_SCOPED_GIL_RELEASE({ + dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), + at::cuda::getCurrentCUDAStream()); + }); + } else { + // Compute activation backward in high-precision, then quantize + auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ + dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), + at::cuda::getCurrentCUDAStream()); + }); + quantizer_cpp->quantize(temp_cpp, grad_input_cpp); + } - return out; + return grad_input_py; } py::object gelu(const at::Tensor& input, py::handle quantizer) { diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 71a8062b1a..6d835a5c94 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -18,7 +18,7 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s auto max_tokens = shape[0]; auto fcd_size = 1; - for (int i = 1; i <= shape.size(); i++) { + for (size_t i = 1; i <= shape.size(); i++) { fcd_size *= shape[i]; } @@ -103,8 +103,20 @@ std::vector fused_attn_fwd( auto o_shape = std::vector{q_shape.begin(), q_shape.end()}; o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1]; py::object o_python, s_python; - std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te); - std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); + if (qkv_type == DType::kFloat8E4M3 || 
qkv_type == DType::kFloat8E5M2) { + // Initialize FP8 tensor with scale-inverse + auto *O_quantizer_fp8 = dynamic_cast(O_quantizer.get()); + auto *S_quantizer_fp8 = dynamic_cast(S_quantizer.get()); + NVTE_CHECK(O_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); + NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); + std::tie(te_O, o_python) = O_quantizer_fp8->create_tensor(o_shape, fake_dtype_te, std::nullopt, + std::nullopt, std::nullopt); + std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, + std::nullopt, std::nullopt); + } else { + std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te); + std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); + } auto o_shape_int64 = std::vector{o_shape.begin(), o_shape.end()}; // construct NVTE tensors @@ -284,8 +296,20 @@ std::vector fused_attn_bwd( py::object s_python, dp_python; std::unique_ptr S_quantizer = convert_quantizer(s_quantizer); std::unique_ptr dP_quantizer = convert_quantizer(dp_quantizer); - std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); - std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32); + + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + auto *S_quantizer_fp8 = dynamic_cast(S_quantizer.get()); + auto *dP_quantizer_fp8 = dynamic_cast(dP_quantizer.get()); + NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); + NVTE_CHECK(dP_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); + std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, + std::nullopt, std::nullopt); + std::tie(te_dP, dp_python) = dP_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, + std::nullopt, std::nullopt); + } else { + std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); + std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32); + } std::vector q_shape = convertShape(te_Q.shape()); std::vector k_shape = convertShape(te_K.shape()); @@ -374,9 +398,22 @@ std::vector fused_attn_bwd( default: NVTE_ERROR("QKV layout not supported!"); } - std::tie(te_dQ, py_dQ) = dQKV_quantizer->create_tensor(q_shape, fake_dtype_te, dQ); - std::tie(te_dK, py_dK) = dQKV_quantizer->create_tensor(k_shape, fake_dtype_te, dK); - std::tie(te_dV, py_dV) = dQKV_quantizer->create_tensor(v_shape, fake_dtype_te, dV); + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + auto *fp8_quantizer = dynamic_cast(dQKV_quantizer.get()); + NVTE_CHECK(fp8_quantizer != nullptr, "Expected Float8Quantizer when dtype is FP8"); + std::tie(te_dQ, py_dQ) = + fp8_quantizer->create_tensor(q_shape, fake_dtype_te, dQ, std::nullopt, std::nullopt); + std::tie(te_dK, py_dK) = + fp8_quantizer->create_tensor(k_shape, fake_dtype_te, dK, std::nullopt, std::nullopt); + std::tie(te_dV, py_dV) = + fp8_quantizer->create_tensor(v_shape, fake_dtype_te, dV, std::nullopt, std::nullopt); + } else { + auto *none_quantizer = dynamic_cast(dQKV_quantizer.get()); + NVTE_CHECK(none_quantizer != nullptr, "Expected NoneQuantizer when dtype is not FP8"); + std::tie(te_dQ, py_dQ) = none_quantizer->create_tensor(q_shape, fake_dtype_te, dQ); + std::tie(te_dK, py_dK) = none_quantizer->create_tensor(k_shape, fake_dtype_te, dK); + std::tie(te_dV, py_dV) = none_quantizer->create_tensor(v_shape, fake_dtype_te, dV); + } // construct NVTE tensors if 
(qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 07f2be9df6..5408cf1a6b 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -28,60 +28,6 @@ std::vector get_tensor_shape(const TensorWrapper &tensor) { return std::vector(shape.data, shape.data + shape.ndim); } -void quantize_impl(const TensorWrapper &input, py::handle &quantizer_py, - std::unique_ptr &quantizer_cpp, TensorWrapper &output, - TensorWrapper &noop_flag) { - // Check tensor dims - NVTE_CHECK(get_tensor_shape(input) == get_tensor_shape(output), - "Input tensor (shape=", get_tensor_shape(input), - ") and output tensor (shape=", get_tensor_shape(output), ") do not match"); - if (input.numel() == 0) { - return; - } - - // Recipe-specific configuration - QuantizationConfigWrapper quant_config; - quant_config.set_noop_tensor(noop_flag.data()); - if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) { - auto my_quantizer_cs = static_cast(quantizer_cpp.get()); - NVTE_SCOPED_GIL_RELEASE( - { nvte_compute_amax(input.data(), output.data(), at::cuda::getCurrentCUDAStream()); }); - // check if we need to do amax reudction (depending on model parallel configs) - if (my_quantizer_cs->with_amax_reduction) { - c10::intrusive_ptr process_group_ptr = my_quantizer_cs->amax_reduction_group; - // construct torch tesnor from NVTEBasicTensor without reallocating memory - at::Tensor &amax_tensor_torch = my_quantizer_cs->amax; - std::vector tensors = {amax_tensor_torch}; - // allreduce amax tensor - c10d::AllreduceOptions allreduce_opts; - allreduce_opts.reduceOp = c10d::ReduceOp::MAX; - process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); - } - // this config is used for cs scaling factor computation - // because compute scale is cannot be fused with quantize kernel - // so in nvte_quantize_v2 with current scaling, the quant config is not used again - quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); - NVTE_SCOPED_GIL_RELEASE({ - nvte_compute_scale_from_amax(output.data(), quant_config, at::cuda::getCurrentCUDAStream()); - }); - // set amax ptr to null in output TensorWrapper to avoid atomic amax updates in kernel - output.set_amax(nullptr, DType::kFloat32, output.defaultShape); - } else if (detail::IsFloat8BlockwiseQuantizers(quantizer_py.ptr())) { - auto my_quantizer_bw = static_cast(quantizer_cpp.get()); - quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); - if (my_quantizer_bw->all_gather_usage) { - quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT); - } - } - - // Perform quantization - NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_v2(input.data(), output.data(), quant_config, at::cuda::getCurrentCUDAStream()); - }); -} - } // namespace py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output, @@ -101,18 +47,17 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob const auto fake_dtype = input_cpp.dtype(); std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(shape, fake_dtype); } else { - output_py = output; - output_cpp = makeTransformerEngineTensor(output_py, quantizer); + std::tie(output_cpp, output_py) = 
quantizer_cpp->convert_and_update_tensor(output); } // Initialize no-op flag - TensorWrapper noop_flag_cpp; + std::optional noop_flag_cpp; if (noop_flag.has_value()) { noop_flag_cpp = makeTransformerEngineTensor(*noop_flag); } // Perform quantization - quantize_impl(input_cpp, quantizer, quantizer_cpp, output_cpp, noop_flag_cpp); + quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp); return output_py; } @@ -182,10 +127,8 @@ void multi_tensor_quantize_impl(const std::vector &input_list, }); } else { // Quantize kernels individually - TensorWrapper dummy_noop_flag; for (size_t i = 0; i < num_tensors; ++i) { - quantize_impl(input_list[i], quantizer_py_list[i], quantizer_cpp_list[i], output_list[i], - dummy_noop_flag); + quantizer_cpp_list[i]->quantize(input_list[i], output_list[i]); } } } diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index d2f7107fe5..d6ae0c86a1 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -18,27 +18,35 @@ namespace pytorch { at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional output) { init_extension(); - const auto dim = input.dim(); - NVTE_CHECK(dim >= 2, "Need at least 2D tensor to transpose."); - - if (input.dim() > 2) { - input = input.view({-1, input.size(dim - 1)}); + // Tensor dimensions + const auto shape = getTensorShape(input); + std::vector transpose_shape_int64; + if (shape.size() > 0) { + transpose_shape_int64.push_back(shape.back()); + for (size_t i = 0; i < shape.size() - 1; ++i) { + transpose_shape_int64.push_back(shape[i]); + } } + const size_t M = shape.size() > 0 ? product(shape) / shape.back() : 1; + const size_t N = shape.size() > 0 ? shape.back() : 1; - size_t M = static_cast(input.size(0)); - size_t N = static_cast(input.size(1)); - + // Output tensor at::Tensor out; if (output.has_value()) { out = *output; } else { - out = allocateTorchTensor(input.size(1), input.size(0), DType::kByte); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + out = at::empty(transpose_shape_int64, opts); } - if (M == 0 || N == 0) return out; + // Return immediately if tensor is empty + if (M == 0 || N == 0) { + return out; + } + + // Compute transpose auto input_cu = makeTransformerEngineTensor(input.data_ptr(), std::vector{M, N}, otype); auto output_cu = makeTransformerEngineTensor(out.data_ptr(), std::vector{N, M}, otype); - nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); return out; diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 0ce1fc90ec..a7b7f58891 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -12,6 +12,27 @@ namespace transformer_engine::pytorch { +namespace { + +/*! @brief Transposed tensor shape + * + * The tensor is interpreted as a 2D matrix by flattening all but the + * last dimension, and then transposed. 
+ */ +template +std::vector make_transpose_shape(const std::vector& shape) { + std::vector ret; + if (shape.size() > 0) { + ret.push_back(shape.back()); + for (size_t i = 0; i < shape.size() - 1; ++i) { + ret.push_back(shape[i]); + } + } + return ret; +} + +} // namespace + constexpr size_t MXFP8_BLOCK_SIZE = 32; Quantizer::Quantizer(const py::handle& quantizer) { @@ -37,24 +58,36 @@ Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quanti this->dtype = type; } -std::pair NoneQuantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { - at::TensorOptions opts; - opts = opts.dtype(GetATenDType(dtype)).device(torch::kCUDA); - std::vector torch_shape; - for (auto s : shape) { - torch_shape.emplace_back(static_cast(s)); - } - at::Tensor ret; - if (rowwise_data.has_value()) { - ret = std::move(*rowwise_data); - } else { - ret = at::empty(torch_shape, opts); - } +std::pair NoneQuantizer::create_tensor(const std::vector& shape, + DType dtype) const { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA); + return create_tensor(shape, dtype, at::empty(shape_int64, opts)); +} - TensorWrapper tensor; - tensor.set_rowwise_data(ret.data_ptr(), dtype, shape); - return {std::move(tensor), py::cast(ret)}; +std::pair NoneQuantizer::create_tensor(const std::vector& shape, + DType dtype, + at::Tensor data) const { + TensorWrapper out_cpp; + out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape); + set_quantization_params(&out_cpp); + return {std::move(out_cpp), py::cast(data)}; +} + +std::pair NoneQuantizer::convert_and_update_tensor( + py::object tensor) const { + auto tensor_pyt = tensor.cast(); + TensorWrapper out_cpp; + out_cpp.set_rowwise_data(tensor_pyt.data_ptr(), + GetTransformerEngineDType(tensor_pyt.scalar_type()), + getTensorShape(tensor_pyt)); + set_quantization_params(&out_cpp); + return {std::move(out_cpp), std::move(tensor)}; +} + +void NoneQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { + NVTE_ERROR("NoneQuantizer does not support quantization"); } void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { @@ -76,68 +109,180 @@ void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { } std::pair Float8Quantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { + const std::vector& shape, DType dtype) const { + const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + at::Tensor scale_inv = at::empty(std::vector{1}, opts); + return create_tensor(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv)); +} + +std::pair Float8Quantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional data, + std::optional transpose, std::optional scale_inv) const { using namespace pybind11::literals; - std::vector rowwise_torch_shape; - std::vector columnwise_torch_shape; - if (!shape.empty()) { - columnwise_torch_shape.emplace_back(static_cast(shape.back())); + // Initialize data tensor + const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + if (with_data && !data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + data = at::empty(shape_int64, opts); + } else if (!with_data && data) { + data.reset(); } - for (size_t i = 0; i < shape.size(); 
++i) { - if (i < shape.size() - 1) { - columnwise_torch_shape.emplace_back(static_cast(shape[i])); - } - rowwise_torch_shape.emplace_back(static_cast(shape[i])); + py::object data_py = with_data ? py::cast(*data) : py::none(); + + // Initialize transpose tensor + const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + if (with_transpose && !transpose) { + const auto transpose_shape = make_transpose_shape(shape); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + transpose = at::empty(transpose_shape, opts); + } else if (!with_transpose && transpose) { + transpose.reset(); } - at::TensorOptions opts; - opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); - at::Tensor data; - if (rowwise_usage) { - if (rowwise_data.has_value()) { - data = std::move(*rowwise_data); - } else { - data = at::empty(rowwise_torch_shape, opts); - } - } - const py::object py_data = rowwise_usage ? py::cast(data) : py::none(); - at::Tensor columnwise_data; - bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); - if (create_transpose) { - columnwise_data = at::empty(columnwise_torch_shape, opts); + py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); + + // Initialize scale-inverse tensor + if (!scale_inv) { + scale_inv = at::reciprocal(scale); } - const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none(); - opts = opts.dtype(torch::kFloat32); - // TODO: Replace with an empty tensor. - at::Tensor scale_inv = at::reciprocal(scale); - py::object ret; + + // Construct Python FP8 tensor + py::object out_py; if (internal) { py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); - ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, - "quantizer"_a = this->quantizer); + out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, + "quantizer"_a = this->quantizer); } else { py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); - ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype), - "data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, - "quantizer"_a = this->quantizer); + const std::vector shape_int64(shape.begin(), shape.end()); + out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), + "data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, + "quantizer"_a = this->quantizer); } - TensorWrapper tensor(this->get_scaling_mode()); - if (rowwise_usage) { - tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); - tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + + // Construct C++ FP8 tensor + TensorWrapper out_cpp(this->get_scaling_mode()); + if (with_data) { + out_cpp.set_rowwise_data(data->data_ptr(), this->dtype, shape); + out_cpp.set_rowwise_scale_inv(scale_inv->data_ptr(), DType::kFloat32, std::vector{1}); + } + if (with_transpose) { + const auto transpose_shape = make_transpose_shape(shape); + out_cpp.set_columnwise_data(transpose->data_ptr(), this->dtype, transpose_shape); + out_cpp.set_columnwise_scale_inv(scale_inv->data_ptr(), DType::kFloat32, + std::vector{1}); + } + 
this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(out_py)}; +} + +std::pair Float8Quantizer::convert_and_update_tensor( + py::object tensor) const { + NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor."); + + // Expected buffers + const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + NVTE_CHECK(need_data || need_transpose, "Invalid usages for Float8Quantizer."); + + // Extract buffers from Python tensor + auto data_py = tensor.attr("_data"); + auto transpose_py = tensor.attr("_transpose"); + const bool has_data = !data_py.is_none(); + const bool has_transpose = !transpose_py.is_none(); + NVTE_CHECK(has_data || has_transpose, "Float8Tensor has no data."); + std::optional data_tensor, transpose_tensor; + if (has_data) { + data_tensor = data_py.cast(); } - if (create_transpose) { - std::vector transposed_shape; - for (auto s : columnwise_torch_shape) { - transposed_shape.emplace_back(static_cast(s)); + if (has_transpose) { + transpose_tensor = transpose_py.cast(); + } + at::Tensor scale_inv_tensor = tensor.attr("_scale_inv").cast(); + + // Tensor dimensions + std::vector shape; + if (has_transpose) { + const auto transpose_shape = getTensorShape(*transpose_tensor); + if (transpose_shape.size() > 0) { + for (size_t i = 1; i < transpose_shape.size(); ++i) { + shape.push_back(transpose_shape[i]); + } + shape.push_back(transpose_shape.front()); + } + if (has_data) { + auto expected_shape = getTensorShape(*data_tensor); + NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, + ") and transpose (shape=", transpose_shape, ") do not match"); } - tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape); - tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + } else { // Already checked has_data == true + shape = getTensorShape(*data_tensor); } - this->set_quantization_params(&tensor); - return {std::move(tensor), std::move(ret)}; + + // Coerce data tensor + if (has_data && !need_data) { + data_tensor.reset(); + data_py = py::none(); + tensor.attr("_data") = data_py; + } else if (!has_data && need_data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + data_tensor = at::empty(shape_int64, opts); + data_py = py::cast(data_tensor); + tensor.attr("_data") = data_py; + } + + // Coerce transpose tensor + if (has_transpose && !need_transpose) { + transpose_tensor.reset(); + transpose_py = py::none(); + tensor.attr("_transpose") = transpose_py; + } else if (!has_transpose && need_transpose) { + const auto transpose_shape = make_transpose_shape(shape); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + transpose_tensor = at::empty(transpose_shape, opts); + transpose_py = py::cast(transpose_tensor); + tensor.attr("_transpose") = transpose_py; + } + tensor.attr("_transpose_invalid") = !need_transpose; + + // Coerce other attrs + tensor.attr("_fp8_dtype") = dtype; + + // Construct C++ FP8 tensor + TensorWrapper out_cpp; + if (data_tensor) { + out_cpp.set_rowwise_data(data_tensor->data_ptr(), this->dtype, shape); + out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); + } + if (transpose_tensor) { + const auto transpose_shape = make_transpose_shape(shape); + 
out_cpp.set_columnwise_data(transpose_tensor->data_ptr(), this->dtype, transpose_shape); + out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(tensor)}; +} + +void Float8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { + if (input.numel() == 0) { + return; + } + QuantizationConfigWrapper quant_config; + if (noop_flag) { + quant_config.set_noop_tensor(noop_flag->data()); + } + NVTE_SCOPED_GIL_RELEASE({ + nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream()); + }); } Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& quantizer) @@ -187,71 +332,198 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso } std::pair Float8CurrentScalingQuantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { + const std::vector& shape, DType dtype) const { using namespace pybind11::literals; - std::vector rowwise_torch_shape; - std::vector columnwise_torch_shape; - std::vector scale_inv_torch_shape = {1}; // Shape of 1 element for scale_inv - if (!shape.empty()) { - columnwise_torch_shape.emplace_back(static_cast(shape.back())); - } - for (size_t i = 0; i < shape.size(); ++i) { - if (i < shape.size() - 1) { - columnwise_torch_shape.emplace_back(static_cast(shape[i])); - } - rowwise_torch_shape.emplace_back(static_cast(shape[i])); + // Initialize data tensor + at::Tensor data_tensor; + const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + if (with_data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + data_tensor = at::empty(shape_int64, opts); } - at::TensorOptions opts; - opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); - at::Tensor data; - if (rowwise_usage) { - if (rowwise_data.has_value()) { - data = std::move(*rowwise_data); - } else { - data = at::empty(rowwise_torch_shape, opts); - } - } - const py::object py_data = rowwise_usage ? py::cast(data) : py::none(); - at::Tensor columnwise_data; - bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); - if (create_transpose) { - columnwise_data = at::empty(columnwise_torch_shape, opts); + + // Initialize transpose tensor + at::Tensor transpose_tensor; + const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + if (with_transpose) { + const auto transpose_shape = make_transpose_shape(shape); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + transpose_tensor = at::empty(transpose_shape, opts); } - const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none(); - // In current scaling, scale is not known but we initialize it with 1 to avoid division by zero. If scale is already calculated, it can be correctly set. - at::Tensor scale_inv = at::reciprocal(scale); + // Initialize scale-inverse tensor + at::Tensor scale_inv_tensor; + { + const std::vector scale_inv_shape = {1}; + const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + scale_inv_tensor = at::empty(scale_inv_shape, opts); + } - py::object ret; + // Construct Python FP8 tensor + py::object out_py; + py::object data_py = with_data ? 
py::cast(data_tensor) : py::none(); + py::object transpose_py = with_transpose ? py::cast(transpose_tensor) : py::none(); if (internal) { py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); - ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, - "quantizer"_a = this->quantizer); + out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, + "quantizer"_a = this->quantizer); } else { py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); - ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype), - "data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, - "quantizer"_a = this->quantizer); + const std::vector shape_int64(shape.begin(), shape.end()); + out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), + "data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, + "quantizer"_a = this->quantizer); } - TensorWrapper tensor(this->get_scaling_mode()); - if (rowwise_usage) { - tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); - tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + + // Construct C++ FP8 tensor + TensorWrapper out_cpp(this->get_scaling_mode()); + if (with_data) { + out_cpp.set_rowwise_data(data_tensor.data_ptr(), this->dtype, shape); + out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); + } + if (with_transpose) { + const auto transpose_shape = make_transpose_shape(shape); + out_cpp.set_columnwise_data(transpose_tensor.data_ptr(), this->dtype, transpose_shape); + out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); } - if (create_transpose) { - std::vector transposed_shape; - for (auto s : columnwise_torch_shape) { - transposed_shape.emplace_back(static_cast(s)); + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(out_py)}; +} + +std::pair Float8CurrentScalingQuantizer::convert_and_update_tensor( + py::object tensor) const { + NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), + "Float8CurrentScalingQuantizer must output to Float8Tensor."); + + // Expected buffers + const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + NVTE_CHECK(need_data || need_transpose, "Invalid quantizer usages."); + + // Extract buffers from Python tensor + auto data_py = tensor.attr("_data"); + auto transpose_py = tensor.attr("_transpose"); + const bool has_data = !data_py.is_none(); + const bool has_transpose = !transpose_py.is_none(); + NVTE_CHECK(has_data || has_transpose, "Tensor has no data."); + std::optional data_tensor, transpose_tensor; + if (has_data) { + data_tensor = data_py.cast(); + } + if (has_transpose) { + transpose_tensor = transpose_py.cast(); + } + at::Tensor scale_inv_tensor = tensor.attr("_scale_inv").cast(); + + // Tensor dimensions + std::vector shape; + if (has_transpose) { + const auto transpose_shape = getTensorShape(*transpose_tensor); + if (transpose_shape.size() > 0) { + for (size_t i = 1; i < transpose_shape.size(); ++i) { + shape.push_back(transpose_shape[i]); + } + 
shape.push_back(transpose_shape.front()); } - tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape); - tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + if (has_data) { + auto expected_shape = getTensorShape(*data_tensor); + NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, + ") and transpose (shape=", transpose_shape, ") do not match"); + } + } else { // Already checked has_data == true + shape = getTensorShape(*data_tensor); } - this->set_quantization_params(&tensor); - return {std::move(tensor), std::move(ret)}; + // Coerce data tensor in Python tensor + if (has_data && !need_data) { + data_tensor.reset(); + data_py = py::none(); + tensor.attr("_data") = data_py; + } else if (!has_data && need_data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + data_tensor = at::empty(shape_int64, opts); + data_py = py::cast(data_tensor); + tensor.attr("_data") = data_py; + } + + // Coerce transpose tensor + if (has_transpose && !need_transpose) { + transpose_tensor.reset(); + transpose_py = py::none(); + tensor.attr("_transpose") = transpose_py; + } else if (!has_transpose && need_transpose) { + const auto transpose_shape = make_transpose_shape(shape); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + transpose_tensor = at::empty(transpose_shape, opts); + transpose_py = py::cast(transpose_tensor); + tensor.attr("_transpose") = transpose_py; + } + tensor.attr("_transpose_invalid") = !need_transpose; + + // Coerce other attrs + tensor.attr("_fp8_dtype") = dtype; + + // Construct C++ FP8 tensor + TensorWrapper out_cpp; + if (data_tensor) { + out_cpp.set_rowwise_data(data_tensor->data_ptr(), this->dtype, shape); + out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); + } + if (transpose_tensor) { + const auto transpose_shape = make_transpose_shape(shape); + out_cpp.set_columnwise_data(transpose_tensor->data_ptr(), this->dtype, transpose_shape); + out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(tensor)}; +} + +void Float8CurrentScalingQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { + auto stream = at::cuda::getCurrentCUDAStream(); + + // Nothing to be done if input is empty + if (input.numel() == 0) { + return; + } + + // Quantization configs + QuantizationConfigWrapper quant_config; + if (noop_flag) { + quant_config.set_noop_tensor(noop_flag->data()); + } + quant_config.set_force_pow_2_scales(force_pow_2_scales); + quant_config.set_amax_epsilon(amax_epsilon); + + // Compute amax + NVTE_SCOPED_GIL_RELEASE({ nvte_compute_amax(input.data(), out.data(), stream); }); + + // Perform amax reduction if needed + if (with_amax_reduction) { + // allreduce amax tensor + c10d::AllreduceOptions opts; + opts.reduceOp = c10d::ReduceOp::MAX; + std::vector tensors = {amax}; + NVTE_SCOPED_GIL_RELEASE({ amax_reduction_group->allreduce(tensors, opts)->wait(); }); + } + + // Compute scaling factor + NVTE_SCOPED_GIL_RELEASE({ nvte_compute_scale_from_amax(out.data(), quant_config, stream); }); + + // Cast to FP8 + out.set_amax(nullptr, DType::kFloat32, out.defaultShape); // Avoid atomic amax updates + NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), 
quant_config, stream); }); } Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) { @@ -280,7 +552,7 @@ void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const } std::pair Float8BlockQuantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { + const std::vector& shape, DType dtype) const { using namespace pybind11::literals; std::vector torch_shape; for (auto s : shape) { @@ -299,11 +571,7 @@ std::pair Float8BlockQuantizer::create_tensor( : Float8BlockScaleTensorFormat::GEMM_READY); if (rowwise_usage) { - if (rowwise_data.has_value()) { - data_rowwise = std::move(*rowwise_data); - } else { - data_rowwise = at::empty(torch_shape, opts); - } + data_rowwise = at::empty(torch_shape, opts); auto scale_shape = get_scale_shape(shape, false); size_t sinv0 = scale_shape[0]; size_t sinv1 = scale_shape[1]; @@ -373,6 +641,62 @@ std::pair Float8BlockQuantizer::create_tensor( return {std::move(tensor), std::move(ret)}; } +std::pair Float8BlockQuantizer::convert_and_update_tensor( + py::object tensor) const { + const DType dtype = tensor.attr("_fp8_dtype").cast(); + bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast(); + + // Check the data matches quantizer usages + NVTE_CHECK(!tensor.attr("_rowwise_data").is_none() == rowwise_usage, + "Float8BlockwiseQTensor does not match quantizer usages (has_rowwise_data=", + !tensor.attr("_rowwise_data").is_none(), ", rowwise_usage=", rowwise_usage); + NVTE_CHECK(!tensor.attr("_columnwise_data").is_none() == columnwise_usage, + "Float8BlockwiseQTensor does not match quantizer usages (has_columnwise_data=", + !tensor.attr("_columnwise_data").is_none(), ", columnwise_usage=", columnwise_usage); + + auto ret = TensorWrapper(is_2D_scaled ? 
NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D); + + if (rowwise_usage) { + const at::Tensor& data_rowwise = tensor.attr("_rowwise_data").cast(); + const at::Tensor& scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast(); + void* scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr(); + const auto& rowwise_shape = getTensorShape(data_rowwise); + ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, rowwise_shape); + const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise); + ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape); + } + if (columnwise_usage) { + const at::Tensor& data_colwise = tensor.attr("_columnwise_data").cast(); + const at::Tensor& scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast(); + void* scale_inv_colwise_dptr = scale_inv_colwise.data_ptr(); + const auto& shape = getTensorShape(data_colwise); + ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape); + const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); + ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape); + } + set_quantization_params(&ret); + return {std::move(ret), std::move(tensor)}; +} + +void Float8BlockQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { + if (input.numel() == 0) { + return; + } + QuantizationConfigWrapper quant_config; + if (noop_flag) { + quant_config.set_noop_tensor(noop_flag->data()); + } + quant_config.set_force_pow_2_scales(force_pow_2_scales); + quant_config.set_amax_epsilon(amax_epsilon); + if (all_gather_usage) { + quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT); + } + NVTE_SCOPED_GIL_RELEASE({ + nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream()); + }); +} + std::vector Float8BlockQuantizer::get_scale_shape(const std::vector& shape, bool columnwise) const { size_t numel = 1; @@ -465,71 +789,204 @@ void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const { columnwise_data.shape); } -std::pair MXFP8Quantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { +std::pair MXFP8Quantizer::create_tensor(const std::vector& shape, + DType dtype) const { using namespace pybind11::literals; - std::vector torch_shape; - size_t numel = 1; - for (auto s : shape) { - torch_shape.emplace_back(static_cast(s)); - numel *= s; - } - TensorWrapper tensor(NVTE_MXFP8_1D_SCALING); - at::TensorOptions opts; - at::Tensor rowwise_data1, columnwise_data, rowwise_scale_inv, - columnwise_scale_inv; // TODO(pgadzinski) - change - opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); - - at::Tensor data; - if (rowwise_usage) { - if (rowwise_data.has_value()) { - data = std::move(*rowwise_data); - } else { - data = at::empty(torch_shape, opts); + // Tensor dimensions + const std::vector shape_int64(shape.begin(), shape.end()); + size_t flat_first_dim = 1; + if (shape.size() > 0) { + for (size_t i = 0; i < shape.size() - 1; ++i) { + flat_first_dim *= shape[i]; } - auto scale_shape = get_scale_shape(shape, false); - size_t sinv0 = scale_shape[0]; - size_t sinv1 = scale_shape[1]; - rowwise_scale_inv = at::zeros({static_cast(sinv0), static_cast(sinv1)}, opts); - tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); - tensor.set_rowwise_scale_inv( - rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0, - std::vector{static_cast(sinv0), static_cast(sinv1)}); } + const size_t flat_last_dim = 
shape.size() > 0 ? shape.back() : 1; + NVTE_CHECK(flat_first_dim % MXFP8_BLOCK_SIZE == 0 && flat_last_dim % MXFP8_BLOCK_SIZE == 0, + "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE, + " (got shape=", shape, ")"); + const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); + const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); + // Allocate tensors + at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor; + at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor; + const auto uint8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + if (rowwise_usage) { + const std::vector scale_inv_shape_int64(rowwise_scale_inv_shape.begin(), + rowwise_scale_inv_shape.end()); + rowwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts); + rowwise_scale_inv_tensor = at::zeros(scale_inv_shape_int64, uint8_tensor_opts); + } if (columnwise_usage) { - auto scale_shape = get_scale_shape(shape, true); - size_t sinv0 = scale_shape[0]; - size_t sinv1 = scale_shape[1]; - columnwise_data = at::empty(torch_shape, opts); - columnwise_scale_inv = - at::zeros({static_cast(sinv0), static_cast(sinv1)}, opts); - - tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape); - tensor.set_columnwise_scale_inv( - columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0, - std::vector{static_cast(sinv0), static_cast(sinv1)}); + const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape.begin(), + columnwise_scale_inv_shape.end()); + columnwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts); + columnwise_scale_inv_tensor = at::zeros(scale_inv_shape_int64, uint8_tensor_opts); } - this->set_quantization_params(&tensor); - py::object ret; + // Convert tensors to Python + auto py_cast = [](at::Tensor& tensor, bool need_cast) -> py::object { + return need_cast ? 
py::cast(tensor) : py::none(); + }; + auto rowwise_data_py = py_cast(rowwise_data_tensor, rowwise_usage); + auto rowwise_scale_inv_py = py_cast(rowwise_scale_inv_tensor, rowwise_usage); + auto columnwise_data_py = py_cast(columnwise_data_tensor, columnwise_usage); + auto columnwise_scale_inv_py = py_cast(columnwise_scale_inv_tensor, columnwise_usage); + + // Construct Python MXFP8 tensor + py::object out_py; if (internal) { py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorBasePythonClass)); - ret = MXFP8TensorClass("rowwise_data"_a = data, "columnwise_data"_a = columnwise_data, - "rowwise_scale_inv"_a = rowwise_scale_inv, - "columnwise_scale_inv"_a = columnwise_scale_inv, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py, + "columnwise_data"_a = columnwise_data_py, + "rowwise_scale_inv"_a = rowwise_scale_inv_py, + "columnwise_scale_inv"_a = columnwise_scale_inv_py, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); } else { py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorPythonClass)); - ret = MXFP8TensorClass("shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), - "rowwise_data"_a = data, "columnwise_data"_a = columnwise_data, - "rowwise_scale_inv"_a = rowwise_scale_inv, - "columnwise_scale_inv"_a = columnwise_scale_inv, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + out_py = MXFP8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), + "rowwise_data"_a = rowwise_data_py, + "columnwise_data"_a = columnwise_data_py, + "rowwise_scale_inv"_a = rowwise_scale_inv_py, + "columnwise_scale_inv"_a = columnwise_scale_inv_py, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); } - return {std::move(tensor), std::move(ret)}; + // Construct C++ MXFP8 tensor + TensorWrapper out_cpp(NVTE_MXFP8_1D_SCALING); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), this->dtype, shape); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0, + rowwise_scale_inv_shape); + } + if (columnwise_usage) { + out_cpp.set_columnwise_data(columnwise_data_tensor.data_ptr(), this->dtype, shape); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0, + columnwise_scale_inv_shape); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(out_py)}; +} + +std::pair MXFP8Quantizer::convert_and_update_tensor( + py::object tensor) const { + NVTE_CHECK(detail::IsMXFP8Tensor(tensor.ptr()), "MXFP8Quantizer must output to MXFP8Tensor."); + + // Extract buffers from Python tensor + auto get_tensor = [&tensor](const char* name) -> std::optional { + auto attr_py = tensor.attr(name); + if (attr_py.is_none()) { + return std::nullopt; + } + return attr_py.cast(); + }; + auto rowwise_data = get_tensor("_rowwise_data"); + auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv"); + auto columnwise_data = get_tensor("_columnwise_data"); + auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv"); + NVTE_CHECK(rowwise_data || columnwise_data, "MXFP8Tensor has no data."); + + // Tensor dimensions + std::vector shape; + if (columnwise_data) { + shape = getTensorShape(*columnwise_data); + if (rowwise_data) { + auto expected_shape = getTensorShape(*rowwise_data); + NVTE_CHECK(shape == expected_shape, "MXFP8 row-wise data (shape=", expected_shape, + ") and column-wise data (shape=", shape, ") do not match"); + } + } else { // Already checked 
diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py
index 383efc8237..7f10336ced 100644
--- a/transformer_engine/pytorch/ops/basic/basic_linear.py
+++ b/transformer_engine/pytorch/ops/basic/basic_linear.py
@@ -22,8 +22,7 @@ from ...fp8 import FP8GlobalStateManager, Recipe
 from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD
 from ...tensor import Quantizer
-from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
-from ...tensor.mxfp8_tensor import MXFP8Quantizer
+from ...tensor.float8_tensor import Float8Quantizer
 from ...tensor._internal.float8_tensor_base import Float8TensorBase
 from ..op import BasicOperation, OperationContext
 from .._common import maybe_dequantize, is_quantized_tensor
@@ -480,18 +479,11 @@ def _functional_forward(
                 raise ValueError("Output tensor is quantized, but quantizer was not provided")
         else:
             output_quantizer = None
-        if isinstance(output_quantizer, MXFP8Quantizer):
-            raise RuntimeError(
-                "Attempting to generate MXFP8 output tensor, "
-                "but GEMM with MXFP8 output is not supported"
-            )
-        if isinstance(output_quantizer, Float8BlockQuantizer):
-            raise RuntimeError(
-                "Attempting to generate Float8BlockQuantized output tensor, "
-                "but GEMM with Float8BlockQuantized output is not supported"
-            )
-
         if output_quantizer is not None:
+            if not isinstance(output_quantizer, Float8Quantizer):
+                raise RuntimeError(
+                    "Attempting to generate quantized output tensor with unsupported quantizer"
+                )
             output_quantizer.set_usage(rowwise=True, columnwise=False)
 
         # Check if accumulating into output tensor
@@ -765,11 +757,12 @@ def _functional_backward(
             )
         else:
             grad_input_quantizer = None
-        if isinstance(grad_input_quantizer, MXFP8Quantizer):
-            raise RuntimeError(
-                "Attempting to generate MXFP8 grad input tensor, "
-                "but GEMM with MXFP8 output is not supported"
-            )
+        if grad_input_quantizer is not None:
+            if not isinstance(grad_input_quantizer, Float8Quantizer):
+                raise RuntimeError(
+                    "Attempting to generate quantized grad input tensor "
+                    "with unsupported quantizer"
+                )
 
         # Check if accumulating into grad input tensor
         if accumulate_into_grad_input:
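
Both hunks above enforce the same rule: a quantized GEMM output or grad input is only produced for plain per-tensor FP8, i.e. a `Float8Quantizer`, and any other quantizer raises. A minimal sketch of that guard as a standalone helper; `check_output_quantizer` is a hypothetical function, not the operator code itself.

```python
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer


def check_output_quantizer(output_quantizer):
    """Hypothetical helper mirroring the guard in _functional_forward."""
    if output_quantizer is None:
        return  # unquantized output, nothing to configure
    if not isinstance(output_quantizer, Float8Quantizer):
        raise RuntimeError(
            "Attempting to generate quantized output tensor with unsupported quantizer"
        )
    # The GEMM only needs the row-wise representation of its output
    output_quantizer.set_usage(rowwise=True, columnwise=False)
```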
diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
index 9316f3d791..61853f9f41 100644
--- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
+++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
@@ -182,7 +182,7 @@ def _functional_forward(
             if weight_quantizer is None:
                 raise ValueError("Missing quantizer for weight tensor")
             if output_quantizer is not None:
-                raise ValueError("FP8 output is not supported")
+                raise ValueError("Quantized output is not supported")
         else:
             input_quantizer = None
             weight_quantizer = None
diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py
index 882650ffba..787c322a0c 100644
--- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py
+++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py
@@ -59,7 +59,7 @@ def __new__(
         instance = super().__new__(cls, *args, **kwargs)
         instance._rowwise_data = rowwise_data
         instance._columnwise_data = columnwise_data
-        instance._quantizer = quantizer
+        instance._quantizer = quantizer.copy() if quantizer is not None else None
         instance._fp8_dtype = fp8_dtype
         instance._rowwise_scale_inv = rowwise_scale_inv
         instance._columnwise_scale_inv = columnwise_scale_inv
diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py
index c0dc6e6519..a88ae33f09 100644
--- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py
+++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py
@@ -86,7 +86,7 @@ def __new__(
         else:
             instance = super().__new__(cls, *args, **kwargs)
         instance._data = data
-        instance._quantizer = quantizer
+        instance._quantizer = quantizer.copy() if quantizer is not None else None
         instance._fp8_dtype = fp8_dtype
         instance._scale_inv = fp8_scale_inv
         instance._transpose = data_transpose
diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py
index 8f87e5c73d..a093904bc9 100644
--- a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py
+++ b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py
@@ -83,7 +83,7 @@ def __new__(
         instance = super().__new__(cls, *args, **kwargs)
         instance._rowwise_data = rowwise_data
         instance._columnwise_data = columnwise_data
-        instance._quantizer = quantizer
+        instance._quantizer = quantizer.copy() if quantizer is not None else None
         instance._fp8_dtype = fp8_dtype
         instance._rowwise_scale_inv = rowwise_scale_inv
         instance._columnwise_scale_inv = columnwise_scale_inv
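
The `quantizer.copy()` changes here, and in the `_set_data` methods further down, all target the same aliasing hazard: when several quantized tensors hold references to one quantizer object, toggling usage flags for one tensor silently reconfigures the others. The sketch below reproduces the hazard with a hypothetical `ToyQuantizer`, independent of Transformer Engine.

```python
import copy
from dataclasses import dataclass


@dataclass
class ToyQuantizer:
    """Hypothetical stand-in carrying the usage flags a TE quantizer holds."""
    rowwise_usage: bool = True
    columnwise_usage: bool = True

    def set_usage(self, *, rowwise=None, columnwise=None):
        if rowwise is not None:
            self.rowwise_usage = rowwise
        if columnwise is not None:
            self.columnwise_usage = columnwise

    def copy(self):
        return copy.copy(self)


# Aliasing (old behavior): both "tensors" hold the same quantizer object, so
# configuring one for row-wise-only output also reconfigures the other.
shared = ToyQuantizer()
tensor_a_quantizer = shared
tensor_b_quantizer = shared
tensor_a_quantizer.set_usage(rowwise=True, columnwise=False)
assert tensor_b_quantizer.columnwise_usage is False  # unintended side effect

# Copying (new behavior): each tensor owns an independent quantizer.
shared = ToyQuantizer()
tensor_a_quantizer = shared.copy()
tensor_b_quantizer = shared.copy()
tensor_a_quantizer.set_usage(rowwise=True, columnwise=False)
assert tensor_b_quantizer.columnwise_usage is True
```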
diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
index bac7159491..0e41fc9c51 100644
--- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
@@ -521,7 +521,7 @@ def _set_data(self, tensor: torch.Tensor) -> None:
         def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor):
             dst._rowwise_data = src._rowwise_data
             dst._columnwise_data = src._columnwise_data
-            dst._quantizer = src._quantizer
+            dst._quantizer = src._quantizer.copy()
             dst._fp8_dtype = src._fp8_dtype
             dst._rowwise_scale_inv = src._rowwise_scale_inv
             dst._columnwise_scale_inv = src._columnwise_scale_inv
diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py
index bccfc49dbd..895e68bf02 100644
--- a/transformer_engine/pytorch/tensor/float8_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_tensor.py
@@ -108,10 +108,9 @@ def make_empty(
         # Allocate FP8 data transpose if needed
         data_transpose = None
         if self.columnwise_usage:
-            inner_dim = data.size(-1)
+            transpose_shape = [data.size(-1)] + list(data.shape[:-1])
             data_transpose = torch.empty(
-                inner_dim,
-                data.numel() // inner_dim,
+                transpose_shape,
                 dtype=torch.uint8,
                 device=device,
             )
@@ -230,7 +229,7 @@ def __init__(
         amax_epsilon: float = 0.0,
     ) -> None:
         super().__init__(rowwise=rowwise, columnwise=columnwise)
-        self.scale = torch.ones(1, dtype=torch.float32, device=device)
+        self.scale = torch.empty(1, dtype=torch.float32, device=device)
         self.amax = torch.empty(1, dtype=torch.float32, device=device)
         self.dtype = fp8_dtype
         self.with_amax_reduction = with_amax_reduction
@@ -690,7 +689,7 @@ def _set_data(self, tensor: torch.Tensor) -> None:
 
         # Float8Tensor attributes
         self._data = tensor._data
-        self._quantizer = tensor._quantizer
+        self._quantizer = tensor._quantizer.copy()
         self._fp8_dtype = tensor._fp8_dtype
         self._scale_inv = tensor._scale_inv
         self._transpose = tensor._transpose
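
The `make_empty` hunk above keeps the transpose buffer N-dimensional instead of flattening it to 2-D: the last dimension moves to the front and the remaining dimensions keep their order. A shape-only illustration in plain PyTorch, with no Transformer Engine calls:

```python
import torch

data = torch.empty(4, 8, 16, dtype=torch.uint8)

# Old layout: flatten to 2-D, losing the leading dimensions.
inner_dim = data.size(-1)
old_transpose = torch.empty(inner_dim, data.numel() // inner_dim, dtype=torch.uint8)
assert tuple(old_transpose.shape) == (16, 32)

# New layout: keep every dimension, with the inner dimension moved to the front.
transpose_shape = [data.size(-1)] + list(data.shape[:-1])
new_transpose = torch.empty(transpose_shape, dtype=torch.uint8)
assert tuple(new_transpose.shape) == (16, 4, 8)
```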
diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py
index 10b587e17e..b96575d37b 100644
--- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py
+++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py
@@ -433,7 +433,7 @@ def _set_data(self, tensor: torch.Tensor) -> None:
         super(MXFP8Tensor, type(self)).data.__set__(self, dummy_tensor)
         self._rowwise_data = tensor._rowwise_data
         self._columnwise_data = tensor._columnwise_data
-        self._quantizer = tensor._quantizer
+        self._quantizer = tensor._quantizer.copy()
         self._fp8_dtype = tensor._fp8_dtype
         self._rowwise_scale_inv = tensor._rowwise_scale_inv
         self._columnwise_scale_inv = tensor._columnwise_scale_inv