Merged

32 commits
c6efd90
remove reciprocal op
zhongbozhu Jul 14, 2025
2fdef53
Refactor Quantizer::create_tensor function
timmoon10 Jul 15, 2025
bd5e1dd
Merge branch 'main' into refactor-quantizer-create-tensor-func
timmoon10 Jul 16, 2025
1338edf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 16, 2025
6d30bb9
Fix bug when constructing FP8 tensor
timmoon10 Jul 17, 2025
5fca0a0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 17, 2025
dc6fae5
Add quantize function to C++ quantizers
timmoon10 Jul 17, 2025
7ac091d
Prototype function to coerce Python quantized tensors to match quantizer
timmoon10 Jul 17, 2025
b30a4b4
Use quantizer class in tex.quantize
timmoon10 Jul 17, 2025
23be7be
Add FP8 current scaling support for activation backward
timmoon10 Jul 17, 2025
302a77d
Disable quantized GEMM output with FP8 current scaling
timmoon10 Jul 17, 2025
952333a
Add coerce_tensor functions for MXFP8 and DSv3
timmoon10 Jul 17, 2025
86af34c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 17, 2025
d0479a9
Merge branch 'main' into refactor-quantizer-create-tensor-func
timmoon10 Jul 17, 2025
596ead5
Avoid quantizing empty tensors
timmoon10 Jul 18, 2025
c4270b3
Use consistent shapes for FP8 transposes
timmoon10 Jul 18, 2025
34d1fde
In attention impl, construct FP8 tensors with pre-initialized scale-invs
timmoon10 Jul 19, 2025
a49cb5e
Initialize MXFP8 scales to zero
timmoon10 Jul 19, 2025
ba68676
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 19, 2025
0a79048
Merge branch 'main' into refactor-quantizer-create-tensor-func
timmoon10 Jul 21, 2025
76d2d53
Store copy of quantizer when creating quantized tensors
timmoon10 Jul 21, 2025
c54d821
Fix linter warnings
timmoon10 Jul 21, 2025
c5d0e46
Merge branch 'main' into refactor-quantizer-create-tensor-func
timmoon10 Jul 22, 2025
c252dc0
Make sure quantized tensors have private quantizer
timmoon10 Jul 22, 2025
c3c1df3
Merge branch 'main' into refactor-quantizer-create-tensor-func
timmoon10 Jul 22, 2025
df6313c
Rename "coerce_tensor" to "convert_and_update_tensor"
timmoon10 Jul 22, 2025
27cf92a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 22, 2025
3e7dbb1
Make sure CUDA context is available when launching NVRTC kernel
timmoon10 Jul 24, 2025
6bdbb12
Merge branch 'main' into refactor-quantizer-create-tensor-func
timmoon10 Jul 24, 2025
261f60f
Expose CUDA context creation function externally
timmoon10 Jul 24, 2025
970e54d
Merge branch 'main' into refactor-quantizer-create-tensor-func
timmoon10 Jul 24, 2025
2cd7fb2
Merge branch 'main' into refactor-quantizer-create-tensor-func
ksivaman Jul 29, 2025
7 changes: 3 additions & 4 deletions tests/pytorch/test_fusible_ops.py
@@ -837,10 +837,9 @@ def _test_basic_linear(
pytest.skip("FP8 output is only supported with FP8 GEMMs")
if quantized_grad_input and not quantized_compute:
pytest.skip("FP8 grad input is only supported with FP8 GEMMs")
if quantization == "mxfp8" and quantized_output:
pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs")
if quantization == "mxfp8" and quantized_grad_input:
pytest.skip("MXFP8 grad input is not supported with MXFP8 GEMMs")
if quantization not in (None, "fp8"):
if quantized_output or quantized_grad_input:
pytest.skip("Recipe does not support quantized GEMM output")

# Random data
x_ref, x_test = make_reference_and_test_tensors(
1 change: 1 addition & 0 deletions transformer_engine/common/libtransformer_engine.version
@@ -8,6 +8,7 @@
transformer_engine::cuda::stream_priority_range*;
transformer_engine::cuda::current_device*;
transformer_engine::cuda_driver::get_symbol*;
transformer_engine::cuda_driver::ensure_context_exists*;
transformer_engine::ubuf_built_with_mpi*;
*transformer_engine::rtc*;
transformer_engine::nvte_cudnn_handle_init*;
13 changes: 13 additions & 0 deletions transformer_engine/common/util/cuda_driver.cpp
@@ -44,6 +44,19 @@ void *get_symbol(const char *symbol, int cuda_version) {
return entry_point;
}

void ensure_context_exists() {
CUcontext context;
NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxGetCurrent, &context);
if (context == nullptr) {
// Add primary context to context stack
CUdevice device;
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &device, cuda::current_device());
NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &context, device);
NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context);
NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRelease, device);
}
}

} // namespace cuda_driver

} // namespace transformer_engine
8 changes: 8 additions & 0 deletions transformer_engine/common/util/cuda_driver.h
@@ -39,6 +39,14 @@ inline CUresult call(const char *symbol, ArgTs... args) {
return (*func)(args...);
}

/*! \brief Ensure that the calling thread has a CUDA context
*
* Each thread maintains a stack of CUDA contexts. If the calling
* thread has an empty stack, the primary context is added to the
* stack.
*/
void ensure_context_exists();

} // namespace cuda_driver

} // namespace transformer_engine
1 change: 1 addition & 0 deletions transformer_engine/common/util/rtc.h
@@ -59,6 +59,7 @@ class Kernel {
template <typename... ArgTs>
void launch(int device_id, const dim3 grid_dim, const dim3 block_dim,
unsigned int shared_mem_bytes, cudaStream_t stream, ArgTs &&...args) {
cuda_driver::ensure_context_exists();
@timmoon10 (Collaborator, Author) commented on Jul 24, 2025:
This PR exposed a bug in our NVRTC infrastructure. Three facts:

  1. The CUDA driver maintains a thread-local stack of CUDA contexts.
  2. PyTorch will initialize the CUDA context if needed for jitting.
  3. PyTorch performs autograd on a separate thread.

By removing the unnecessary at::reciprocal calls from create_tensor, I hit cases where the backward pass launched an NVRTC kernel before launching any PyTorch ops (namely in the FP8 linear op with UB). Since the autograd thread's context stack was empty, this resulted in "invalid device context" errors.

A collaborator replied:
This is interesting, thanks!
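For context, the failure mode described above can be reproduced outside Transformer Engine. The following standalone sketch is an editorial illustration, not code from this PR; the device ordinal 0 and the printf output are assumptions. It shows that a freshly spawned thread starts with an empty context stack and applies the same primary-context remedy as ensure_context_exists():

#include <cuda.h>
#include <cstdio>
#include <thread>

int main() {
  cuInit(0);  // initializes the driver but creates no context

  std::thread worker([] {
    CUcontext ctx = nullptr;
    cuCtxGetCurrent(&ctx);
    // On a fresh thread the context stack is empty, so ctx is null here and a
    // cuLaunchKernel issued now would fail with an invalid-context error.
    std::printf("context before: %p\n", static_cast<void *>(ctx));

    // Same remedy as ensure_context_exists(): make the primary context current.
    // (ensure_context_exists additionally releases its reference afterwards.)
    CUdevice device;
    cuDeviceGet(&device, 0);
    cuDevicePrimaryCtxRetain(&ctx, device);
    cuCtxSetCurrent(ctx);

    cuCtxGetCurrent(&ctx);
    std::printf("context after:  %p\n", static_cast<void *>(ctx));
  });
  worker.join();
  return 0;
}

When linked against the CUDA driver library, the first pointer prints as null and the second as a non-null context, matching the "three facts" above.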

void *arg_ptrs[] = {const_cast<void *>(static_cast<const void *>(&args))...};
NVTE_CALL_CHECK_CUDA_DRIVER(cuLaunchKernel, get_function(device_id), grid_dim.x, grid_dim.y,
grid_dim.z, block_dim.x, block_dim.y, block_dim.z, shared_mem_bytes,
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/csrc/common.cpp
@@ -12,7 +12,7 @@

namespace transformer_engine::pytorch {

std::vector<size_t> getTensorShape(at::Tensor t) {
std::vector<size_t> getTensorShape(const at::Tensor& t) {
std::vector<size_t> shape;
for (auto s : t.sizes()) {
shape.push_back(s);
80 changes: 61 additions & 19 deletions transformer_engine/pytorch/csrc/common.h
@@ -98,9 +98,21 @@ class Quantizer {

virtual void set_quantization_params(TensorWrapper* tensor) const = 0;

virtual std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const = 0;
/*! @brief Construct a tensor with uninitialized data */
virtual std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const = 0;

/*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor
*
* The PyTorch tensor's attributes are modified to match the
* quantizer's configuration.
*/
virtual std::pair<TensorWrapper, py::object> convert_and_update_tensor(
py::object tensor) const = 0;

/*! @brief Convert to a quantized data format */
virtual void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) = 0;

virtual ~Quantizer() = default;

@@ -121,9 +133,17 @@ class NoneQuantizer : public Quantizer {

void set_quantization_params(TensorWrapper* tensor) const override {}

std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;

/*! @brief Construct a tensor with pre-initialized data */
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
at::Tensor data) const;

std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object tensor) const override;

void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
};

class Float8Quantizer : public Quantizer {
@@ -139,9 +159,19 @@

void set_quantization_params(TensorWrapper* tensor) const override;

std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;

/*! @brief Construct a tensor with pre-initialized data */
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> data,
std::optional<at::Tensor> transpose,
std::optional<at::Tensor> scale_inv) const;

std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;

void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
};

class Float8CurrentScalingQuantizer : public Quantizer {
@@ -161,9 +191,13 @@

void set_quantization_params(TensorWrapper* tensor) const override;

std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;

std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;

void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;
};

class Float8BlockQuantizer : public Quantizer {
@@ -195,9 +229,13 @@ class Float8BlockQuantizer : public Quantizer {
// Create a python Float8BlockQuantized tensor and C++ wrapper
// for the tensor. Should set quantized data, scales for rowwise
// and optionally columnwise usage.
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;

std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;

void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;

std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
};
Expand All @@ -212,16 +250,20 @@ class MXFP8Quantizer : public Quantizer {

void set_quantization_params(TensorWrapper* tensor) const override;

std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;

std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;

void quantize(const TensorWrapper& input, TensorWrapper& out,
const std::optional<TensorWrapper>& noop_flag = std::nullopt) override;

std::vector<size_t> get_scale_shape(const std::vector<size_t>& shape, bool columnwise) const;
};

std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer);

std::vector<size_t> getTensorShape(at::Tensor t);
std::vector<size_t> getTensorShape(const at::Tensor& t);

transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
const std::string& fp8_recipe);
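Taken together, the refactored interface separates allocation (create_tensor), adoption of existing Python tensors (convert_and_update_tensor), and casting (quantize). A minimal caller sketch follows; the helper name quantize_to_recipe, its arguments, and the include are made up for illustration, while Quantizer, NoneQuantizer, convert_quantizer, TensorWrapper, and DType are the declarations above. The rewritten activation code below uses the same create_tensor-then-quantize pattern.

// #include "common.h"  // declarations shown above; include path assumed

namespace transformer_engine::pytorch {

// Quantize high-precision scratch data into whatever format the given
// Python-side quantizer describes, and hand back the resulting Python tensor.
py::object quantize_to_recipe(py::handle quantizer, const std::vector<size_t> &shape,
                              DType fake_dtype) {
  // Resolve the Python quantizer into its C++ counterpart.
  auto quantizer_cpp = convert_quantizer(quantizer);

  // Allocate an output tensor whose data layout matches the quantizer's recipe.
  auto [out_cpp, out_py] = quantizer_cpp->create_tensor(shape, fake_dtype);

  // Produce high-precision data in a temporary tensor (a real caller would
  // launch a kernel that writes into temp_cpp), then cast it into the output.
  auto [temp_cpp, temp_py] = NoneQuantizer(py::none()).create_tensor(shape, fake_dtype);
  quantizer_cpp->quantize(temp_cpp, out_cpp);

  return out_py;
}

}  // namespace transformer_engine::pytorch

Keeping quantize virtual lets recipe-specific quantizers (current scaling, blockwise) handle their amax and scale bookkeeping internally instead of at every call site, which is what the simplified activation_helper below relies on.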
127 changes: 57 additions & 70 deletions transformer_engine/pytorch/csrc/extensions/activation.cpp
@@ -13,87 +13,74 @@ namespace transformer_engine::pytorch {
template <void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t)>
py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) {
init_extension();
auto my_quantizer = convert_quantizer(quantizer);
auto input_tensor = input.contiguous();

const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
const auto& te_input_shape = te_input.shape();
std::vector<size_t> input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim);
input_shape[input_shape.size() - 1] /= shape_divisor;
auto fake_tensor_type = input.scalar_type();

auto [te_output, out] =
my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));

// for current scaling, we need to compute amax first and then quantize
// because cache cannot fit in the entire tensor to compute amax and quantize
// the quantizer should not need amax reduction, no process group needed here
if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
// activation function might change the input data range, we need to first call the activation function
// and then find the amax and scale of that and then do the quantization
// get a NoneQuantizer to calculate amax of activation output
auto my_quantizer_none = std::make_unique<NoneQuantizer>(py::none());
auto [te_output_act, out_act] =
my_quantizer_none->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));

NVTE_SCOPED_GIL_RELEASE({
act_func(te_input.data(), te_output_act.data(), at::cuda::getCurrentCUDAStream());
// use te_output_act as input to the compute amax and find the amax of activated tensor
nvte_compute_amax(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
});

// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer*>(my_quantizer.get());
if (my_quantizer_cs->with_amax_reduction) {
NVTE_ERROR(
"per-tensor current scaling amax reduction is not supported in activation functions.");
}
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);

NVTE_SCOPED_GIL_RELEASE({
nvte_compute_scale_from_amax(te_output.data(), quant_config,
at::cuda::getCurrentCUDAStream());
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config,
at::cuda::getCurrentCUDAStream());
});
} else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
// sanity check, since activation fusion is not supported for blockwise quantization yet
// need to raise an error here instead of silently going into act_func with wrong numerics
NVTE_ERROR("Activation fusion is not supported for blockwise quantization yet.");
// Input tensor
auto input_tensor = input.contiguous();
const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor);

// Construct output tensor
auto quantizer_cpp = convert_quantizer(quantizer);
const auto input_shape = input_cpp.shape();
std::vector<size_t> output_shape(input_shape.data, input_shape.data + input_shape.ndim);
output_shape.back() /= shape_divisor;
auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype);

// Compute activation
if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
detail::IsMXFP8Quantizers(quantizer.ptr())) {
// Compute activation directly
NVTE_SCOPED_GIL_RELEASE(
{ act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); });
} else {
// Compute activation in high-precision, then quantize
auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE(
{ act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); });
{ act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); });
quantizer_cpp->quantize(temp_cpp, out_cpp);
}

return out;
return out_py;
}

template <void (*act_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t)>
py::object dactivation_helper(const at::Tensor& grad, const at::Tensor& input,
template <void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t)>
py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input,
py::handle quantizer) {
init_extension();
auto my_quantizer = convert_quantizer(quantizer);
auto input_tensor = input.contiguous();
auto grad_tensor = grad.contiguous();

const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
const TensorWrapper& te_grad = makeTransformerEngineTensor(grad_tensor);
const auto& te_input_shape = te_input.shape();
std::vector<size_t> input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim);
auto fake_tensor_type = input.scalar_type();

auto [te_output, out] =
my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type));

NVTE_SCOPED_GIL_RELEASE({
act_func(te_grad.data(), te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
});
// Grad output and input tensors
auto grad_output_tensor = grad_output.contiguous();
auto input_tensor = input.contiguous();
const TensorWrapper& grad_output_cpp = makeTransformerEngineTensor(grad_output_tensor);
const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor);

// Construct grad input tensor
auto quantizer_cpp = convert_quantizer(quantizer);
const auto input_shape_te = input_cpp.shape();
const std::vector<size_t> input_shape(input_shape_te.data,
input_shape_te.data + input_shape_te.ndim);
auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type());
auto [grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype);

// Compute activation backward
if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) ||
detail::IsMXFP8Quantizers(quantizer.ptr())) {
// Compute activation backward directly
NVTE_SCOPED_GIL_RELEASE({
dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(),
at::cuda::getCurrentCUDAStream());
});
} else {
// Compute activation backward in high-precision, then quantize
auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({
dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(),
at::cuda::getCurrentCUDAStream());
});
quantizer_cpp->quantize(temp_cpp, grad_input_cpp);
}

return out;
return grad_input_py;
}

py::object gelu(const at::Tensor& input, py::handle quantizer) {