@@ -671,13 +671,128 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::convert_and_update_te
   const DType dtype = tensor.attr("_fp8_dtype").cast<DType>();
   bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast<bool>();
 
-  // Check the data matches quantizer usages
-  NVTE_CHECK(!tensor.attr("_rowwise_data").is_none() == rowwise_usage,
-             "Float8BlockwiseQTensor does not match quantizer usages (has_rowwise_data=",
-             !tensor.attr("_rowwise_data").is_none(), ", rowwise_usage=", rowwise_usage);
-  NVTE_CHECK(!tensor.attr("_columnwise_data").is_none() == columnwise_usage,
-             "Float8BlockwiseQTensor does not match quantizer usages (has_columnwise_data=",
-             !tensor.attr("_columnwise_data").is_none(), ", columnwise_usage=", columnwise_usage);
+  // Extract buffers from Python tensor
+  auto get_tensor = [&tensor](const char* name) -> std::optional<at::Tensor> {
+    auto attr_py = tensor.attr(name);
+    if (attr_py.is_none()) {
+      return std::nullopt;
+    }
+    return attr_py.cast<at::Tensor>();
+  };
+  auto rowwise_data = get_tensor("_rowwise_data");
+  auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv");
+  auto columnwise_data = get_tensor("_columnwise_data");
+  auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv");
+  NVTE_CHECK(rowwise_data || columnwise_data, "FP8BlockwiseTensor has no data.");
+
+  // Tensor options and dimensions
+  at::TensorOptions opts;
+  at::TensorOptions scale_opts;
+  opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
+  scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA);
+
+  auto get_columnwise_shape = [&columnwise_data](bool all_gather_usage) -> std::vector<size_t> {
+    if (!columnwise_data) {
+      return std::vector<size_t>();
+    }
+    if (all_gather_usage) {
+      return getTensorShape(*columnwise_data);
+    }
+    std::vector<size_t> shape = getTensorShape(*columnwise_data);
+    std::vector<size_t> shape_transposed(shape.size());
+    for (size_t i = 0; i + 1 < shape.size(); ++i) {
+      shape_transposed[i] = shape[i + 1];
+    }
+    if (shape.size() > 0) {
+      shape_transposed[shape.size() - 1] = shape[0];
+    }
+    return shape_transposed;
+  };
+  std::vector<size_t> shape;
+  if (rowwise_data) {
+    shape = getTensorShape(*rowwise_data);
+    if (columnwise_data) {
+      auto expected_shape = get_columnwise_shape(all_gather_usage);
+      NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape,
+                 ") and column-wise data (shape=", expected_shape, ") do not match");
+    }
+  } else {
+    shape = get_columnwise_shape(all_gather_usage);
+  }
+  std::vector<int64_t> torch_shape;
+  for (auto s : shape) {
+    torch_shape.emplace_back(static_cast<int64_t>(s));
+  }
+
+  // Coerce row-wise data
+  if (rowwise_usage) {
+    if (!rowwise_data) {
+      rowwise_data = at::empty(torch_shape, opts);
+      tensor.attr("_rowwise_data") = *rowwise_data;
+    }
+    if (!rowwise_scale_inv) {
+      auto scale_shape = get_scale_shape(shape, false);
+      size_t sinv0 = scale_shape[0];
+      size_t sinv1 = scale_shape[1];
+      rowwise_scale_inv =
+          at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
+      tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv;
+    }
+  } else {  // rowwise_usage == false
+    if (rowwise_data) {
+      rowwise_data.reset();
+      tensor.attr("_rowwise_data") = py::none();
+    }
+    if (rowwise_scale_inv) {
+      rowwise_scale_inv.reset();
+      tensor.attr("_rowwise_scale_inv") = py::none();
+    }
+  }
+
+  // Coerce column-wise data
+  if (columnwise_usage) {
+    std::vector<size_t> columnwise_shape;
+    std::vector<int64_t> torch_columnwise_shape;
+    if (torch_shape.size() > 0) {
+      if (!all_gather_usage) {
+        torch_columnwise_shape.reserve(torch_shape.size());
+        columnwise_shape.reserve(shape.size());
+        torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
+        columnwise_shape.push_back(shape[shape.size() - 1]);
+        for (size_t i = 0; i < torch_shape.size() - 1; ++i) {
+          torch_columnwise_shape.push_back(torch_shape[i]);
+          columnwise_shape.push_back(shape[i]);
+        }
+      } else {
+        // assert we are doing 1D scaling
+        NVTE_CHECK(block_scaling_dim == 1,
+                   "Compact columnwise format is not supported for 128x128 2D block scaling.");
+        torch_columnwise_shape = torch_shape;
+        columnwise_shape = shape;
+      }
+    }
+    if (!columnwise_data) {
+      columnwise_data = at::empty(torch_columnwise_shape, opts);
+      tensor.attr("_columnwise_data") = *columnwise_data;
+    }
+    if (!columnwise_scale_inv) {
+      auto scale_shape = get_scale_shape(shape, true);
+      size_t sinv0 = scale_shape[0];
+      size_t sinv1 = scale_shape[1];
+      columnwise_scale_inv =
+          at::empty({static_cast<int64_t>(sinv0), static_cast<int64_t>(sinv1)}, scale_opts);
+      tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv;
+    }
+  } else {  // columnwise_usage == false
+    if (columnwise_data) {
+      columnwise_data.reset();
+      tensor.attr("_columnwise_data") = py::none();
+    }
+    if (columnwise_scale_inv) {
+      columnwise_scale_inv.reset();
+      tensor.attr("_columnwise_scale_inv") = py::none();
+    }
+  }
 
   auto ret = TensorWrapper(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D);
 
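For reference, the non-compact column-wise layout used above is a rotation of the row-wise shape: the last dimension moves to the front when the column-wise buffer is allocated, and get_columnwise_shape() undoes that rotation to recover the logical shape. A minimal standalone C++ sketch of that convention (the rotate_to_columnwise helper name is illustrative and not part of this change):

#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative helper: given a logical (row-wise) shape, produce the
// column-wise (transposed) shape, i.e. the last dimension moves to the
// front and the remaining dimensions shift right by one.
std::vector<size_t> rotate_to_columnwise(const std::vector<size_t>& shape) {
  if (shape.empty()) return {};
  std::vector<size_t> out(shape.size());
  out[0] = shape.back();
  for (size_t i = 1; i < shape.size(); ++i) {
    out[i] = shape[i - 1];
  }
  return out;
}

int main() {
  // A [4, 128, 1024] row-wise buffer gets a [1024, 4, 128] column-wise
  // buffer; get_columnwise_shape() applies the inverse rotation.
  for (size_t d : rotate_to_columnwise({4, 128, 1024})) {
    std::cout << d << " ";  // prints: 1024 4 128
  }
  std::cout << std::endl;
  return 0;
}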