@@ -1970,6 +1970,7 @@ def choose_qparams_affine_float8(
     tensor: torch.Tensor,
     float8_dtype: torch.dtype = torch.float8_e4m3fn,
     scale_dtype: torch.dtype = torch.float32,
+    block_size: Optional[Tuple[int, ...]] = None,
 ) -> torch.Tensor:
     """
     Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity.
@@ -1978,12 +1979,27 @@ def choose_qparams_affine_float8(
         tensor (torch.Tensor): Input tensor to be quantized.
         float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
     """
+    quant_max = torch.finfo(float8_dtype).max
     # only tensorwise scaling is supported for now:
-    quant_min, quant_max = torch.finfo(float8_dtype).min, torch.finfo(float8_dtype).max
-    min_val_neg = torch.min(tensor)
-    max_val_pos = torch.max(tensor)
-    max_val_pos = torch.max(-min_val_neg, max_val_pos)
-    scale = max_val_pos / (float(quant_max - quant_min) / 2)
+    if block_size is None:
+        max_abs = tensor.abs().max()
+        scale = max_abs / quant_max
+    else:
+        shape_for_reduction, reduction_dims = _get_reduction_params(
+            block_size, tensor.shape
+        )
+        tensor_reshaped = tensor.view(shape_for_reduction)
+        max_abs = tensor_reshaped.abs().amax(dim=reduction_dims, keepdim=True)
+
+        scale = max_abs / quant_max
+        # Reshape scale back to match the expected output shape
+        # The scale tensor should have the same shape as the input divided by block_size
+        output_shape = [
+            input_size // block_size[i] if block_size[i] > 1 else input_size
+            for i, input_size in enumerate(tensor.shape)
+        ]
+        scale = scale.reshape(output_shape)
+
     return scale.to(dtype=scale_dtype)


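A minimal sketch of the scale shapes the two paths above produce, with the `_get_reduction_params` reshape mirrored in plain PyTorch. The `(4, 8)` input and `(1, 4)` block size are illustrative, and the assumption is that `_get_reduction_params` splits each blocked dimension into `(num_blocks, block_size)`:

```python
import torch

tensor = torch.randn(4, 8, dtype=torch.float32)
quant_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

# Tensorwise path (block_size=None): a single scalar scale.
scale_tensorwise = tensor.abs().max() / quant_max  # shape: torch.Size([])

# Block-wise path with block_size=(1, 4): for a (4, 8) input,
# _get_reduction_params is assumed to yield shape (4, 2, 4) with dim 2
# as the reduction dim, so each (1, 4) block gets its own max.
max_abs = tensor.view(4, 2, 4).abs().amax(dim=2, keepdim=True)  # (4, 2, 1)
scale_blockwise = (max_abs / quant_max).reshape(4, 2)  # output_shape: (4, 2)
```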
@@ -2027,5 +2043,24 @@ def dequantize_affine_float8(
     # upcasted to `float32` to divide by the scale, since scale is a fp32 for float8 quantization.
     # In order to match numerics between eager and compile, we upcast manually here.
     fp8_tensor = tensor.to(torch.float32)
-    hp_tensor = fp8_tensor * scale
+    # For block-wise quantization, we need to broadcast the scale to match tensor dimensions
+    if scale.shape != tensor.shape:
+        # Calculate the block size from the shape difference
+        block_size = tuple(
+            tensor.shape[i] // scale.shape[i]
+            if scale.shape[i] != tensor.shape[i]
+            else 1
+            for i in range(len(tensor.shape))
+        )
+
+        scale_expanded = scale
+        for i in range(len(tensor.shape)):
+            if block_size[i] > 1:
+                # Repeat the scale values for each block along this dimension
+                scale_expanded = scale_expanded.repeat_interleave(block_size[i], dim=i)
+    else:
+        # Tensor-wise quantization: the scale already matches, so no expansion is needed
+        scale_expanded = scale
+
+    hp_tensor = fp8_tensor * scale_expanded
     return hp_tensor.to(output_dtype)
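A round-trip sketch of the new broadcast (again illustrative, reusing the `(1, 4)` block layout from above): the `(4, 2)` per-block scales are expanded back to `(4, 8)` with the same `repeat_interleave` the patch uses, so each scale value covers its four-element block:

```python
import torch

quant_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
tensor = torch.randn(4, 8, dtype=torch.float32) * 10

# Per-block scales for block_size=(1, 4), shape (4, 2), as computed above.
scale = tensor.view(4, 2, 4).abs().amax(dim=2) / quant_max

# Quantize: divide by the expanded scale, clamp to the fp8 range, and cast.
scale_expanded = scale.repeat_interleave(4, dim=1)  # (4, 2) -> (4, 8)
fp8 = (tensor / scale_expanded).clamp(-quant_max, quant_max).to(torch.float8_e4m3fn)

# Dequantize with the same block-wise broadcast dequantize_affine_float8 performs.
hp = fp8.to(torch.float32) * scale_expanded
rel_err = (hp - tensor).abs().max() / tensor.abs().max()
print(f"max relative error: {rel_err:.4f}")  # a few percent, as expected for e4m3
```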