[PyTorch] ONNX export of FP8 Current Scaling #2068
Changes from 8 commits
@@ -112,7 +112,9 @@ def onnx_quantize_fp8_symbolic(
     doc="TRT FP8 Quantize Linear used for inference.",
     inputs=[
         defs.OpSchema.FormalParameter("tensor", "tensor(float)", "Input tensor to quantize"),
-        defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for quantization"),
+        defs.OpSchema.FormalParameter(
+            "scale_inv", "tensor(float)", "Inverse scale factor for quantization"
+        ),
     ],
     outputs=[defs.OpSchema.FormalParameter("output", "tensor(uint8)", "Quantized output tensor")],
 )
@@ -157,7 +159,9 @@ def onnx_dequantize_fp8_symbolic(
     doc="TRT FP8 Dequantize Linear from Float8Tensor used for inference.",
     inputs=[
         defs.OpSchema.FormalParameter("tensor", "tensor(uint8)", "Input tensor to dequantize"),
-        defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for dequantization"),
+        defs.OpSchema.FormalParameter(
+            "scale_inv", "tensor(float)", "Inverse scale factor for dequantization"
+        ),
     ],
     outputs=[defs.OpSchema.FormalParameter("output", "tensor(float)", "Dequantized output tensor")],
 )
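The two schema changes above switch the TRT custom ops from a scale input to scale_inv. On my reading of the diff (this sketch is not code from the PR), the convention is that quantization divides by scale_inv and dequantization multiplies by it, so the pair stays self-consistent:

import torch

# Illustrative stand-ins for the TRT_FP8QuantizeLinear /
# TRT_FP8DequantizeLinear plugins; assumes PyTorch >= 2.1 for float8_e4m3fn.
x = torch.tensor([0.5, -2.0, 1.25])
scale_inv = x.abs().max() / 448.0            # current scaling: amax / fp8_max
q = (x / scale_inv).to(torch.float8_e4m3fn)  # quantize
x_hat = q.to(torch.float32) * scale_inv      # dequantize, x_hat ~= x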
@@ -166,6 +170,68 @@ def onnx_dequantize_fp8_symbolic(
     opset=trt_opset, name="TRT_FP8DequantizeLinear", op_schema=schema
 )
 
+
+# ONNX FP8 Current Scaling Quantization
+
+
+@torch.library.custom_op("tex::fp8_cs_quantize", mutates_args=[])
+def onnx_cs_quantize_fp8_op(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Quantize to FP8 with current scaling; returns (uint8, scale_inv)."""
+    if tensor.dtype != torch.float32:
+        tensor = tensor.to(torch.float32)
+    amax = tensor.abs().max()
+    eps = torch.tensor(1e-12, dtype=torch.float32, device=tensor.device)
+    amax = torch.maximum(amax, eps)
+    fp8_max = torch.tensor(448, dtype=torch.float32, device=tensor.device)
+    scale = fp8_max / amax
+    q = torch.ops.tex.fp8_quantize(tensor, scale)
+    scale_inv = 1 / scale
+    return q, scale_inv
+
+
+@onnx_cs_quantize_fp8_op.register_fake
+def _(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    return torch.empty(tensor.shape, dtype=torch.uint8, device=tensor.device), torch.ones(
+        1, dtype=torch.float32, device=tensor.device
+    )
+
+
+def onnx_quantize_fp8_cs_symbolic(
+    tensor: onnxscript.onnx_types.TensorType,
+):
+    """Symbolic quantize with current scaling; computes scale_inv from tensor."""
+    # scale_inv = max(abs(tensor)) / 448
+    amax = op.ReduceMax(op.Abs(tensor), keepdims=0)
+    eps = op.Constant(value_float=1.0e-12)
+    amax = op.Max(amax, eps)
+    scale_inv = op.Div(amax, op.Constant(value_float=448.0))
+    q = TRT_FP8QuantizeLinear(tensor, scale_inv)
+    return q, scale_inv
+
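To make the arithmetic concrete: for a tensor whose absolute maximum is 2.0, the eager op computes scale = 448 / 2.0 = 224 and returns scale_inv = 2.0 / 448 ≈ 0.00446, and the symbolic path emits the same scale_inv directly as ReduceMax(Abs(x)) / 448. A minimal eager-mode emulation, with a plain float8 cast standing in for torch.ops.tex.fp8_quantize:

import torch

def cs_quantize_reference(x: torch.Tensor):
    # Mirrors onnx_cs_quantize_fp8_op above; the float8 cast is a stand-in
    # for the tex.fp8_quantize kernel.
    x = x.to(torch.float32)
    amax = torch.maximum(x.abs().max(), torch.tensor(1e-12))
    scale = 448.0 / amax                     # amax = 2.0 -> scale = 224.0
    q = (x * scale).to(torch.float8_e4m3fn)
    return q, 1.0 / scale                    # scale_inv ~= 0.004464 for amax = 2.0

q, s_inv = cs_quantize_reference(torch.tensor([0.5, -2.0, 1.25]))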
+
+# ONNX FP8 Current Scaling Dequantization
+
+
+@torch.library.custom_op("tex::fp8_cs_dequantize", mutates_args=[])
+def onnx_cs_dequantize_fp8_op(tensor: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
+    """Dequantize FP8 with provided dynamic scale_inv."""
+    quantizer = Float8Quantizer(
+        1 / scale_inv.to(torch.float32), torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3
+    )
+    quantizer_tensor = quantizer.create_tensor_from_data(tensor, fake_dtype=torch.float32)
+    return quantizer_tensor.dequantize()
+
+
+@onnx_cs_dequantize_fp8_op.register_fake
+def _(tensor: torch.Tensor, _) -> torch.Tensor:
+    return torch.empty(tensor.shape, dtype=torch.float32, device=tensor.device)
+
+
+def onnx_dequantize_fp8_cs_symbolic(
+    tensor: onnxscript.onnx_types.UINT8, scale_inv: onnxscript.onnx_types.TensorType
+) -> onnxscript.onnx_types.TensorType:
+    """Symbolic dequantize with current scaling."""
+    return TRT_FP8DequantizeLinear(tensor, scale_inv)
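Assuming the TE torch extension is loaded so the tex ops above are registered, a round trip through the pair should recover the input up to E4M3 rounding (illustrative usage, not a test from this PR):

import torch

x = torch.randn(16, 32, dtype=torch.float32, device="cuda")
q, scale_inv = torch.ops.tex.fp8_cs_quantize(x)
x_hat = torch.ops.tex.fp8_cs_dequantize(q, scale_inv)
# Error is bounded by the FP8 step size, since current scaling maps
# the tensor's amax onto the E4M3 maximum of 448.
print((x - x_hat).abs().max())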
 
 
 # ONNX MXFP8 Quantization
@@ -356,6 +422,8 @@ def onnx_attention_mask_func(
     torch.ops.tex.gemm_inf.default: onnx_gemm_inf_symbolic,
     torch.ops.tex.fp8_quantize.default: onnx_quantize_fp8_symbolic,
     torch.ops.tex.fp8_dequantize.default: onnx_dequantize_fp8_symbolic,
+    torch.ops.tex.fp8_cs_quantize.default: onnx_quantize_fp8_cs_symbolic,
+    torch.ops.tex.fp8_cs_dequantize.default: onnx_dequantize_fp8_cs_symbolic,
     torch.ops.tex.mxfp8_quantize.default: onnx_quantize_mxfp8_symbolic,
     torch.ops.tex.mxfp8_dequantize.default: onnx_dequantize_mxfp8_symbolic,
     torch.ops.tex.layernorm.default: onnx_layernorm_symbolic,
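This table maps each tex op to its ONNX decomposition. As a sketch of how such a table is typically consumed (recent PyTorch versions expose custom_translation_table on the dynamo-based exporter; the exact wiring inside TE may differ, and model/inputs are placeholders):

import torch

onnx_program = torch.onnx.export(
    model,
    (inputs,),
    dynamo=True,
    custom_translation_table={
        torch.ops.tex.fp8_cs_quantize.default: onnx_quantize_fp8_cs_symbolic,
        torch.ops.tex.fp8_cs_dequantize.default: onnx_dequantize_fp8_cs_symbolic,
    },
)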
Shouldn't current scaling have better numerics than delayed scaling?
We compare FP8 output from TE with FP8 output from onnxruntime, not FP8 with a high-precision reference. Since onnxruntime uses the FP8 quantizer, (Float8Quantizer vs Float8Quantizer) is expected to be closer than (Float8CurrentScalingQuantizer vs an emulation of current scaling built on Float8Quantizer).
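For context, a rough sketch of the comparison being described; module, input, and file names are hypothetical and the actual test harness may differ:

import numpy as np
import onnxruntime as ort

y_te = te_linear(x)  # TE forward pass with FP8 current scaling
sess = ort.InferenceSession("te_linear.onnx")
y_ort = sess.run(None, {"input": x.detach().cpu().numpy()})[0]
# FP8-to-FP8 comparison, so tolerances can stay tight.
np.testing.assert_allclose(y_te.detach().cpu().numpy(), y_ort, rtol=1e-2, atol=1e-3)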