
Commit de69ca0

Authored by hxbai, timmoon10, pre-commit-ci[bot], and ksivaman
[PyTorch] fix input_quantizer usage for save_original_input; fix blockwise FP8 convert_and_update_tensor (#1978)
* fix input_quantizer in save_original_input bwd
* fix get shape of blockwise tensor with only compact colwise data
* fix blockwise FP8 convert_and_update_tensor
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

Signed-off-by: Hongxiao Bai <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
1 parent c0d2f1a commit de69ca0

3 files changed: +129 / -13 lines


tests/pytorch/test_float8blockwisetensor.py

Lines changed: 1 addition & 1 deletion

@@ -219,7 +219,7 @@ def test_quantize_dequantize_compact_format(
             rowwise=True,
             columnwise=dq_columnwise,
             block_scaling_dim=block_scaling_dim,
-            all_gather_usage=True,
+            all_gather_usage=(block_scaling_dim == 1),
         )
         self._test_quantize_dequantize(
             quantizer=quantizer,
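
For context, a minimal sketch (not part of the commit) of how the test's quantizer might be constructed with this change, assuming the Float8BlockQuantizer Python class accepts the keyword arguments shown in the diff; the import paths, fp8_dtype value, and helper name are illustrative:

import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer

def make_test_quantizer(block_scaling_dim: int, dq_columnwise: bool) -> Float8BlockQuantizer:
    # The compact (all-gather) column-wise format is only meaningful for 1D
    # block scaling, so the flag is derived from block_scaling_dim rather
    # than being hard-coded to True.
    return Float8BlockQuantizer(
        fp8_dtype=tex.DType.kFloat8E4M3,
        rowwise=True,
        columnwise=dq_columnwise,
        block_scaling_dim=block_scaling_dim,
        all_gather_usage=(block_scaling_dim == 1),
    )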

transformer_engine/pytorch/csrc/quantizer.cpp

Lines changed: 122 additions & 7 deletions
@@ -671,13 +671,128 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
   const DType dtype = tensor.attr("_fp8_dtype").cast<DType>();
   bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast<bool>();
 
-  // Check the data matches quantizer usages
-  NVTE_CHECK(!tensor.attr("_rowwise_data").is_none() == rowwise_usage,
-             "Float8BlockwiseQTensor does not match quantizer usages (has_rowwise_data=",
-             !tensor.attr("_rowwise_data").is_none(), ", rowwise_usage=", rowwise_usage);
-  NVTE_CHECK(!tensor.attr("_columnwise_data").is_none() == columnwise_usage,
-             "Float8BlockwiseQTensor does not match quantizer usages (has_columnwise_data=",
-             !tensor.attr("_columnwise_data").is_none(), ", columnwise_usage=", columnwise_usage);
+  // Extract buffers from Python tensor
+  auto get_tensor = [&tensor](const char* name) -> std::optional<at::Tensor> {
+    auto attr_py = tensor.attr(name);
+    if (attr_py.is_none()) {
+      return std::nullopt;
+    }
+    return attr_py.cast<at::Tensor>();
+  };
+  auto rowwise_data = get_tensor("_rowwise_data");
+  auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv");
+  auto columnwise_data = get_tensor("_columnwise_data");
+  auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv");
+  NVTE_CHECK(rowwise_data || columnwise_data, "FP8BlockwiseTensor has no data.");
+
+  // Tensor options and dimensions
+  at::TensorOptions opts;
+  at::TensorOptions scale_opts;
+  opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
+  scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA);
+
+  auto get_columnwise_shape = [&columnwise_data](bool all_gather_usage) -> std::vector<size_t> {
+    if (!columnwise_data) {
+      return std::vector<size_t>();
+    }
+    if (all_gather_usage) {
+      return getTensorShape(*columnwise_data);
+    }
+    std::vector<size_t> shape = getTensorShape(*columnwise_data);
+    std::vector<size_t> shape_transposed(shape.size());
+    for (size_t i = 0; i + 1 < shape.size(); ++i) {
+      shape_transposed[i] = shape[i + 1];
+    }
+    if (shape.size() > 0) {
+      shape_transposed[shape.size() - 1] = shape[0];
+    }
+    return shape_transposed;
+  };
+  std::vector<size_t> shape;
+  if (rowwise_data) {
+    shape = getTensorShape(*rowwise_data);
+    if (columnwise_data) {
+      auto expected_shape = get_columnwise_shape(all_gather_usage);
+      NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape,
+                 ") and column-wise data (shape=", expected_shape, ") do not match");
+    }
+  } else {
+    shape = get_columnwise_shape(all_gather_usage);
+  }
+  std::vector<int64_t> torch_shape;
+  for (auto s : shape) {
+    torch_shape.emplace_back(static_cast<int64_t>(s));
+  }
+
+  // Coerce row-wise data
+  if (rowwise_usage) {
+    if (!rowwise_data) {
+      rowwise_data = at::empty(torch_shape, opts);
+      tensor.attr("_rowwise_data") = *rowwise_data;
+    }
+    if (!rowwise_scale_inv) {
+      auto scale_shape = get_scale_shape(shape, false);
+      size_t sinv0 = scale_shape[0];
+      size_t sinv1 = scale_shape[1];
+      rowwise_scale_inv =
+          at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
+      tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv;
+    }
+  } else {  // rowwise_usage == false
+    if (rowwise_data) {
+      rowwise_data.reset();
+      tensor.attr("_rowwise_data") = py::none();
+    }
+    if (rowwise_scale_inv) {
+      rowwise_scale_inv.reset();
+      tensor.attr("_rowwise_scale_inv") = py::none();
+    }
+  }
+
+  // Coerce column-wise data
+  if (columnwise_usage) {
+    std::vector<size_t> columnwise_shape;
+    std::vector<int64_t> torch_columnwise_shape;
+    if (torch_shape.size() > 0) {
+      if (!all_gather_usage) {
+        torch_columnwise_shape.reserve(torch_shape.size());
+        columnwise_shape.reserve(shape.size());
+        torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
+        columnwise_shape.push_back(shape[shape.size() - 1]);
+        for (size_t i = 0; i < torch_shape.size() - 1; ++i) {
+          torch_columnwise_shape.push_back(torch_shape[i]);
+          columnwise_shape.push_back(shape[i]);
+        }
+      } else {
+        // assert we are doing 1D scaling
+        NVTE_CHECK(block_scaling_dim == 1,
+                   "Compact columnwise format is not supported for 128x128 2D block scaling.");
+        torch_columnwise_shape = torch_shape;
+        columnwise_shape = shape;
+      }
+    }
+    if (!columnwise_data) {
+      columnwise_data = at::empty(torch_columnwise_shape, opts);
+      tensor.attr("_columnwise_data") = *columnwise_data;
+    }
+    if (!columnwise_scale_inv) {
+      auto scale_shape = get_scale_shape(shape, true);
+      size_t sinv0 = scale_shape[0];
+      size_t sinv1 = scale_shape[1];
+      columnwise_scale_inv =
+          at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
+      tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv;
+    }
+  } else {  // columnwise_usage == false
+    if (columnwise_data) {
+      columnwise_data.reset();
+      tensor.attr("_columnwise_data") = py::none();
+    }
+    if (columnwise_scale_inv) {
+      columnwise_scale_inv.reset();
+      tensor.attr("_columnwise_scale_inv") = py::none();
+    }
+  }
 
   auto ret = TensorWrapper(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D);
 
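To make the layouts handled above explicit, here is a rough Python rendering of the column-wise shape logic (the function names are illustrative, not part of the library): in the compact all-gather format the stored column-wise data keeps the logical shape, while the default transposed format stores the logical last dimension first.

def logical_shape_from_columnwise(stored_shape, all_gather_usage):
    # Compact (all-gather) column-wise data keeps the logical layout;
    # otherwise it is stored transposed with the logical last dim first,
    # so rotate the dims left by one to recover the logical shape.
    if all_gather_usage:
        return list(stored_shape)
    return list(stored_shape[1:]) + list(stored_shape[:1])

def columnwise_shape_from_logical(logical_shape, all_gather_usage):
    # Inverse mapping: move the logical last dim to the front for the
    # transposed layout, keep the shape unchanged for the compact format.
    if all_gather_usage:
        return list(logical_shape)
    return list(logical_shape[-1:]) + list(logical_shape[:-1])

# e.g. a logical (4, 32, 64) tensor stored column-wise in the transposed
# format has stored shape (64, 4, 32)
assert logical_shape_from_columnwise((64, 4, 32), False) == [4, 32, 64]
assert columnwise_shape_from_logical((4, 32, 64), False) == [64, 4, 32]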

transformer_engine/pytorch/module/linear.py

Lines changed: 6 additions & 5 deletions
@@ -589,13 +589,14 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
            else:
                # Quantize input tensor
                quantizer = ctx.input_quantizer
-                if ctx.backward_input_needs_gather and isinstance(
-                    quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
-                ):
+                if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
                    # All-gather is not supported with FP8 column-wise data
-                    quantizer.set_usage(rowwise=True, columnwise=False)
+                    quantizer.set_usage(
+                        rowwise=True,
+                        columnwise=not ctx.backward_input_needs_gather,
+                    )
                else:
-                    quantizer.set_usage(rowwise=True, columnwise=True)
+                    quantizer.set_usage(rowwise=False, columnwise=True)
                inputmat = quantizer(inputmat)
        else:
            if isinstance(inputmat, QuantizedTensorBase):
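
A short sketch of the backward-path decision this diff encodes, assuming Float8Quantizer and Float8CurrentScalingQuantizer are imported as they are in linear.py (the helper name is illustrative):

def set_bwd_input_usage(quantizer, backward_input_needs_gather: bool) -> None:
    if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
        # Per-tensor FP8 column-wise (transposed) data cannot be all-gathered,
        # so only request it when no gather of the saved input is needed.
        quantizer.set_usage(
            rowwise=True,
            columnwise=not backward_input_needs_gather,
        )
    else:
        # Other quantizers only need the column-wise copy of the input in
        # backward (e.g. for the wgrad GEMM), so skip the row-wise copy.
        quantizer.set_usage(rowwise=False, columnwise=True)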
