
[PyTorch] fix input_quantizer usage for save_original_input; fix blockwise FP8 convert_and_update_tensor #1978

Merged: 9 commits, Aug 6, 2025
tests/pytorch/test_float8blockwisetensor.py (2 changes: 1 addition & 1 deletion)
@@ -219,7 +219,7 @@ def test_quantize_dequantize_compact_format(
             rowwise=True,
             columnwise=dq_columnwise,
             block_scaling_dim=block_scaling_dim,
-            all_gather_usage=True,
+            all_gather_usage=(block_scaling_dim == 1),
         )
         self._test_quantize_dequantize(
             quantizer=quantizer,
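Note on the test change: the compact (all-gather) column-wise format only exists for 1D 1x128 block scaling, and the reworked quantizer.cpp below now rejects it for 128x128 2D scaling, so the test enables all_gather_usage only when block_scaling_dim == 1 instead of unconditionally.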
transformer_engine/pytorch/csrc/quantizer.cpp (129 changes: 122 additions & 7 deletions)
@@ -671,13 +671,128 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
   const DType dtype = tensor.attr("_fp8_dtype").cast<DType>();
   bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast<bool>();
 
-  // Check the data matches quantizer usages
-  NVTE_CHECK(!tensor.attr("_rowwise_data").is_none() == rowwise_usage,
-             "Float8BlockwiseQTensor does not match quantizer usages (has_rowwise_data=",
-             !tensor.attr("_rowwise_data").is_none(), ", rowwise_usage=", rowwise_usage);
-  NVTE_CHECK(!tensor.attr("_columnwise_data").is_none() == columnwise_usage,
-             "Float8BlockwiseQTensor does not match quantizer usages (has_columnwise_data=",
-             !tensor.attr("_columnwise_data").is_none(), ", columnwise_usage=", columnwise_usage);
+  // Extract buffers from Python tensor
+  auto get_tensor = [&tensor](const char* name) -> std::optional<at::Tensor> {
+    auto attr_py = tensor.attr(name);
+    if (attr_py.is_none()) {
+      return std::nullopt;
+    }
+    return attr_py.cast<at::Tensor>();
+  };
+  auto rowwise_data = get_tensor("_rowwise_data");
+  auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv");
+  auto columnwise_data = get_tensor("_columnwise_data");
+  auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv");
+  NVTE_CHECK(rowwise_data || columnwise_data, "FP8BlockwiseTensor has no data.");
+
+  // Tensor options and dimensions
+  at::TensorOptions opts;
+  at::TensorOptions scale_opts;
+  opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
+  scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA);
+
+  auto get_columnwise_shape = [&columnwise_data](bool all_gather_usage) -> std::vector<size_t> {
+    if (!columnwise_data) {
+      return std::vector<size_t>();
+    }
+    if (all_gather_usage) {
+      return getTensorShape(*columnwise_data);
+    }
+    std::vector<size_t> shape = getTensorShape(*columnwise_data);
+    std::vector<size_t> shape_transposed(shape.size());
+    for (size_t i = 0; i + 1 < shape.size(); ++i) {
+      shape_transposed[i] = shape[i + 1];
+    }
+    if (shape.size() > 0) {
+      shape_transposed[shape.size() - 1] = shape[0];
+    }
+    return shape_transposed;
+  };
+  std::vector<size_t> shape;
+  if (rowwise_data) {
+    shape = getTensorShape(*rowwise_data);
+    if (columnwise_data) {
+      auto expected_shape = get_columnwise_shape(all_gather_usage);
+      NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape,
+                 ") and column-wise data (shape=", expected_shape, ") do not match");
+    }
+  } else {
+    shape = get_columnwise_shape(all_gather_usage);
+  }
+  std::vector<int64_t> torch_shape;
+  for (auto s : shape) {
+    torch_shape.emplace_back(static_cast<int64_t>(s));
+  }
+
+  // Coerce row-wise data
+  if (rowwise_usage) {
+    if (!rowwise_data) {
+      rowwise_data = at::empty(torch_shape, opts);
+      tensor.attr("_rowwise_data") = *rowwise_data;
+    }
+    if (!rowwise_scale_inv) {
+      auto scale_shape = get_scale_shape(shape, false);
+      size_t sinv0 = scale_shape[0];
+      size_t sinv1 = scale_shape[1];
+      rowwise_scale_inv =
+          at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
+      tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv;
+    }
+  } else {  // rowwise_usage == false
+    if (rowwise_data) {
+      rowwise_data.reset();
+      tensor.attr("_rowwise_data") = py::none();
+    }
+    if (rowwise_scale_inv) {
+      rowwise_scale_inv.reset();
+      tensor.attr("_rowwise_scale_inv") = py::none();
+    }
+  }
+
+  // Coerce column-wise data
+  if (columnwise_usage) {
+    std::vector<size_t> columnwise_shape;
+    std::vector<int64_t> torch_columnwise_shape;
+    if (torch_shape.size() > 0) {
+      if (!all_gather_usage) {
+        torch_columnwise_shape.reserve(torch_shape.size());
+        columnwise_shape.reserve(shape.size());
+        torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
+        columnwise_shape.push_back(shape[shape.size() - 1]);
+        for (size_t i = 0; i < torch_shape.size() - 1; ++i) {
+          torch_columnwise_shape.push_back(torch_shape[i]);
+          columnwise_shape.push_back(shape[i]);
+        }
+      } else {
+        // assert we are doing 1D scaling
+        NVTE_CHECK(block_scaling_dim == 1,
+                   "Compact columnwise format is not supported for 128x128 2D block scaling.");
+        torch_columnwise_shape = torch_shape;
+        columnwise_shape = shape;
+      }
+    }
+    if (!columnwise_data) {
+      columnwise_data = at::empty(torch_columnwise_shape, opts);
+      tensor.attr("_columnwise_data") = *columnwise_data;
+    }
+    if (!columnwise_scale_inv) {
+      auto scale_shape = get_scale_shape(shape, true);
+      size_t sinv0 = scale_shape[0];
+      size_t sinv1 = scale_shape[1];
+      columnwise_scale_inv =
+          at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
+      tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv;
+    }
+  } else {  // columnwise_usage == false
+    if (columnwise_data) {
+      columnwise_data.reset();
+      tensor.attr("_columnwise_data") = py::none();
+    }
+    if (columnwise_scale_inv) {
+      columnwise_scale_inv.reset();
+      tensor.attr("_columnwise_scale_inv") = py::none();
+    }
+  }
 
   auto ret = TensorWrapper(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D);
 
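Taken together, the new convert_and_update_tensor no longer hard-fails when the tensor's populated buffers disagree with the quantizer's usage flags; it allocates any missing row-wise or column-wise buffers (data and scale-inverse) and drops the ones the quantizer no longer needs. The one shape subtlety is the column-wise layout: outside the compact all-gather format, the column-wise buffer stores the transpose with the last logical dimension first. A minimal Python sketch of that mapping (the helper name is mine, not part of the codebase):

def logical_shape_from_columnwise(cw_shape, all_gather_usage):
    # Compact (all-gather) format keeps the logical layout untransposed.
    if all_gather_usage:
        return list(cw_shape)
    # Otherwise rotate left by one: (d_last, d0, ..., d_{n-2}) -> (d0, ..., d_last).
    return list(cw_shape[1:]) + list(cw_shape[:1])

assert logical_shape_from_columnwise((8, 4, 6), all_gather_usage=False) == [4, 6, 8]

This is the same rotation the get_columnwise_shape lambda applies when recovering the logical shape from the stored column-wise buffer.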
transformer_engine/pytorch/module/linear.py (11 changes: 6 additions & 5 deletions)
@@ -589,13 +589,14 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             else:
                 # Quantize input tensor
                 quantizer = ctx.input_quantizer
-                if ctx.backward_input_needs_gather and isinstance(
-                    quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
-                ):
+                if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
                     # All-gather is not supported with FP8 column-wise data
-                    quantizer.set_usage(rowwise=True, columnwise=False)
+                    quantizer.set_usage(
+                        rowwise=True,
+                        columnwise=not ctx.backward_input_needs_gather,
+                    )
                 else:
-                    quantizer.set_usage(rowwise=True, columnwise=True)
+                    quantizer.set_usage(rowwise=False, columnwise=True)
                 inputmat = quantizer(inputmat)
         else:
             if isinstance(inputmat, QuantizedTensorBase):
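The linear.py change fixes the input_quantizer usage flags on the save_original_input backward path: usage is now set on every call rather than only when a gather was pending, and quantizers other than FP8 delayed/current scaling (e.g. blockwise FP8) now request only column-wise data, which is what the wgrad GEMM (grad_weight = grad_output.T @ input) consumes. A sketch of the resulting decision table; the helper is illustrative only, with the import path assumed from the TE 2.x layout:

from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8CurrentScalingQuantizer,
    Float8Quantizer,
)

def backward_input_usage(quantizer, needs_gather):
    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
        # FP8 all-gather only works on row-wise data, so column-wise data is
        # produced up front only when the input is already fully local.
        return {"rowwise": True, "columnwise": not needs_gather}
    # Other quantizers need only the column-wise view for the wgrad GEMM.
    return {"rowwise": False, "columnwise": True}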