2 changes: 1 addition & 1 deletion docs/examples/onnx/onnx_export.ipynb
@@ -10,7 +10,7 @@
"\n",
"<b>Note:</b>\n",
"\n",
"Currently, export to ONNX is supported only for high precision, FP8 delayed scaling and MXFP8.\n",
"Currently, export to ONNX is supported only for high precision, FP8 delayed scaling, FP8 current scaling and MXFP8.\n",
"\n",
"</div>\n",
"\n",
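For reference, a minimal sketch of what the newly documented path looks like from user code. It reuses the module, input shape, and `fp8_autocast` call shown in the tests below; the `te.onnx_export` context manager and the dynamo-based `torch.onnx.export` entry point are assumptions about the surrounding TE export machinery, not part of this diff.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

model = te.LayerNormMLP(128, 128).eval()            # same module as in test_trt_integration
inp = torch.randn([16, 16, 128], device="cuda")
fp8_recipe = recipe.Float8CurrentScaling()           # the recipe enabled by this PR

with torch.no_grad(), te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    model(inp)                                       # warm-up pass so FP8 state exists
    with te.onnx_export(enabled=True):               # assumed TE export context manager
        torch.onnx.export(model, (inp,), "layernorm_mlp_fp8_cs.onnx", dynamo=True)
```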
18 changes: 13 additions & 5 deletions tests/pytorch/test_onnx_export.py
@@ -65,6 +65,7 @@
fp8_recipes.append(recipe.MXFP8BlockScaling())
if fp8_available:
fp8_recipes.append(recipe.DelayedScaling())
fp8_recipes.append(recipe.Float8CurrentScaling())
fp8_recipes.append(None)

supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
@@ -81,11 +82,11 @@
],
outputs=[PyCustomOpDef.dt_uint8],
)
def trt_fp8_quantize(t, scale):
def trt_fp8_quantize(t, scale_inv):
"""FP8 quantization extension for ONNX Runtime."""
x = torch.from_numpy(t).cuda()
q = te.tensor.float8_tensor.Float8Quantizer(
scale=1 / torch.from_numpy(scale).cuda(),
scale=1 / torch.from_numpy(scale_inv).cuda(),
amax=torch.zeros([1]).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
)
@@ -101,11 +102,11 @@ def trt_fp8_quantize(t, scale):
],
outputs=[PyCustomOpDef.dt_float],
)
def trt_fp8_dequantize(t, scale):
def trt_fp8_dequantize(t, scale_inv):
"""FP8 dequantization extension for ONNX Runtime."""
x = torch.from_numpy(t).cuda()
q = te.tensor.float8_tensor.Float8Quantizer(
scale=1 / torch.from_numpy(scale).cuda(),
scale=1 / torch.from_numpy(scale_inv).cuda(),
amax=torch.zeros([1]).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
)
@@ -593,7 +594,9 @@ def _test_export_layernorm_linear(
fname,
inp,
model,
atol=1e-3,
# For current scaling we use Float8Quantizer in tests + amax computed by hand,
# which has slightly different numerics than Float8CurrentScalingQuantizer.
atol=1e-3 if fp8_recipe.__class__ is not recipe.Float8CurrentScaling else 2e-2,
Collaborator:
Shouldn't current scaling have better numerics than delayed scaling?

Collaborator Author:
We compare FP8 from TE with FP8 from onnxruntime, not FP8 with high precision. Since onnxruntime will use Float8Quantizer, it is expected that (Float8Quantizer vs Float8Quantizer) will be closer than (Float8CurrentScalingQuantizer vs an emulation of current scaling using Float8Quantizer).

is_fp8=fp8_recipe is not None,
te_outputs=te_outputs,
)
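To illustrate the reply above with concrete (hypothetical) numbers: two nearly identical per-tensor scales, one from a hand-computed amax and one from the native current-scaling quantizer, can round a borderline value into different FP8 bins, which is why the tolerance is relaxed to 2e-2 only for this recipe.

```python
import torch

x = torch.tensor([1.7385], dtype=torch.float32)   # value near an E4M3 rounding boundary
s1 = 448.0 / 3.1400                                # scale from one amax estimate
s2 = 448.0 / 3.1415                                # scale from a marginally different amax
y1 = (x * s1).to(torch.float8_e4m3fn).to(torch.float32) / s1
y2 = (x * s2).to(torch.float8_e4m3fn).to(torch.float32) / s2
print(y1.item(), y2.item())   # can land on adjacent representable FP8 values
```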
@@ -1150,6 +1153,11 @@ def test_trt_integration(fp8_recipe: recipe.Recipe):
ffn_hidden_size=128,
num_attention_heads=4,
).eval()

if type(fp8_recipe) == recipe.Float8CurrentScaling:
# TODO(pgadzinski): Attention does not work with TRT for FP8CurrentScaling
model = te.LayerNormMLP(128, 128)

inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),)

with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe):
54 changes: 47 additions & 7 deletions transformer_engine/pytorch/onnx_extensions.py
@@ -112,7 +112,9 @@ def onnx_quantize_fp8_symbolic(
doc="TRT FP8 Quantize Linear used for inference.",
inputs=[
defs.OpSchema.FormalParameter("tensor", "tensor(float)", "Input tensor to quantize"),
defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for quantization"),
defs.OpSchema.FormalParameter(
"scale_inv", "tensor(float)", "Inverse scale factor for quantization"
),
],
outputs=[defs.OpSchema.FormalParameter("output", "tensor(uint8)", "Quantized output tensor")],
)
@@ -126,11 +128,10 @@ def onnx_quantize_fp8_symbolic(


@torch.library.custom_op("tex::fp8_dequantize", mutates_args=[])
def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale: float) -> torch.Tensor:
def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor:
"""Dequantize from Float8Tensor used for inference."""
scale_tensor = torch.tensor(scale, dtype=torch.float32, device=tensor.device)
quantizer = Float8Quantizer(
scale_tensor, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3
1 / scale_inv, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3
)
quantizer_tensor = quantizer.create_tensor_from_data(tensor, fake_dtype=torch.float32)
return quantizer_tensor.dequantize()
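The convention after this rename, as a small sketch (assumes a CUDA build of TE with these custom ops registered): the torch-side quantize op still receives `scale`, while the dequantize op and the TRT-facing ONNX schemas carry `scale_inv`.

```python
import torch

x = torch.randn(64, 64, device="cuda", dtype=torch.float32)
scale = (448.0 / x.abs().max()).reshape(1)          # per-tensor E4M3 scale
q = torch.ops.tex.fp8_quantize(x, scale)            # uint8 FP8 payload
y = torch.ops.tex.fp8_dequantize(q, 1.0 / scale)    # dequantize now takes scale_inv
print((x - y).abs().max())                          # bounded by E4M3 quantization error
```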
@@ -143,10 +144,9 @@ def _(tensor: torch.Tensor, _) -> torch.Tensor:


def onnx_dequantize_fp8_symbolic(
tensor: onnxscript.onnx_types.TensorType, scale: float
tensor: onnxscript.onnx_types.TensorType, scale_inv: onnxscript.onnx_types.TensorType
) -> onnxscript.onnx_types.TensorType:
"""Symbolic dequantize from Float8Tensor used for inference."""
scale_inv = op.Constant(value_float=1 / scale)
return TRT_FP8DequantizeLinear(tensor, scale_inv)


@@ -157,7 +157,9 @@ def onnx_dequantize_fp8_symbolic(
doc="TRT FP8 Dequantize Linear from Float8Tensor used for inference.",
inputs=[
defs.OpSchema.FormalParameter("tensor", "tensor(uint8)", "Input tensor to dequantize"),
defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for dequantization"),
defs.OpSchema.FormalParameter(
"scale_inv", "tensor(float)", "Inverse scale factor for dequantization"
),
],
outputs=[defs.OpSchema.FormalParameter("output", "tensor(float)", "Dequantized output tensor")],
)
@@ -166,6 +168,43 @@ def onnx_dequantize_fp8_symbolic(
opset=trt_opset, name="TRT_FP8DequantizeLinear", op_schema=schema
)

# ONNX FP8 Current Scaling Quantization


@torch.library.custom_op("tex::fp8_cs_quantize", mutates_args=[])
def onnx_cs_quantize_fp8_op(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Quantize to FP8 with current scaling; returns (uint8, scale_inv)."""
if tensor.dtype != torch.float32:
tensor = tensor.to(torch.float32)
amax = tensor.abs().max()
eps = torch.tensor(1e-12, dtype=torch.float32, device=tensor.device)
amax = torch.maximum(amax, eps)
fp8_max = torch.tensor(448, dtype=torch.float32, device=tensor.device)
scale = fp8_max / amax
q = torch.ops.tex.fp8_quantize(tensor, scale)
scale_inv = 1 / scale
return q, scale_inv


@onnx_cs_quantize_fp8_op.register_fake
def _(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return torch.empty(tensor.shape, dtype=torch.uint8, device=tensor.device), torch.ones(
1, dtype=torch.float32, device=tensor.device
)


def onnx_quantize_fp8_cs_symbolic(
tensor: onnxscript.onnx_types.TensorType,
):
"""Symbolic quantize with current scaling; computes scale_inv from tensor."""
# scale_inv = 1 / max(abs(tensor))
amax = op.ReduceMax(op.Abs(tensor), keepdims=0)
eps = op.Constant(value_float=1.0e-12)
amax = op.Max(amax, eps)
scale_inv = op.Div(amax, op.Constant(value_float=448.0))
q = TRT_FP8QuantizeLinear(tensor, scale_inv)
return q, scale_inv
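A quick eager-mode check of the math above (assumes a CUDA build of TE): the op returns scale_inv = max(amax, 1e-12) / 448, which is exactly what the symbolic graph computes with ReduceMax/Abs/Max/Div.

```python
import torch

x = torch.randn(32, 32, device="cuda", dtype=torch.float32)
q, scale_inv = torch.ops.tex.fp8_cs_quantize(x)
expected = torch.clamp(x.abs().max(), min=1e-12) / 448.0
print(torch.allclose(scale_inv, expected))   # expected: True
```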


# ONNX MXFP8 Quantization

@@ -356,6 +395,7 @@ def onnx_attention_mask_func(
torch.ops.tex.gemm_inf.default: onnx_gemm_inf_symbolic,
torch.ops.tex.fp8_quantize.default: onnx_quantize_fp8_symbolic,
torch.ops.tex.fp8_dequantize.default: onnx_dequantize_fp8_symbolic,
torch.ops.tex.fp8_cs_quantize.default: onnx_quantize_fp8_cs_symbolic,
torch.ops.tex.mxfp8_quantize.default: onnx_quantize_mxfp8_symbolic,
torch.ops.tex.mxfp8_dequantize.default: onnx_dequantize_mxfp8_symbolic,
torch.ops.tex.layernorm.default: onnx_layernorm_symbolic,
22 changes: 16 additions & 6 deletions transformer_engine/pytorch/tensor/float8_tensor.py
@@ -177,7 +177,7 @@ def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:

def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor:
"""Function using primitives with ONNX defined translations."""
out = torch.ops.tex.fp8_dequantize(tensor._data, self.scale.item())
out = torch.ops.tex.fp8_dequantize(tensor._data, tensor._scale_inv)
out = out.to(tensor.dtype)
return out

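A minimal roundtrip through the updated Float8Quantizer path, assuming a CUDA build and the same `te`/`tex` aliases used in tests/pytorch/test_onnx_export.py:

```python
import torch
import transformer_engine.pytorch as te
import transformer_engine_torch as tex   # alias assumed from the test file

x = torch.randn(64, 64, device="cuda", dtype=torch.float32)
quantizer = te.tensor.float8_tensor.Float8Quantizer(
    scale=(448.0 / x.abs().max()).reshape(1),
    amax=torch.zeros(1, device="cuda"),
    fp8_dtype=tex.DType.kFloat8E4M3,
)
qt = quantizer.onnx_quantize(x)     # Float8Tensor carrying uint8 data and _scale_inv
y = quantizer.onnx_dequantize(qt)   # now reads qt._scale_inv instead of scale.item()
print((x - y).abs().max())
```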
@@ -350,15 +350,25 @@ def create_tensor_from_data(

def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Function using primitives with ONNX defined translations."""
raise NotImplementedError(
"Float8CurrentScalingQuantizer does not support ONNX quantization yet."
if tensor.dtype != torch.float32:
tensor = tensor.to(torch.float32)
data, scale_inv = torch.ops.tex.fp8_cs_quantize(tensor)
return Float8Tensor(
shape=data.shape,
dtype=torch.float32,
data=data,
fp8_scale_inv=scale_inv,
fp8_dtype=self.dtype,
requires_grad=False,
data_transpose=None,
quantizer=self,
)

def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor:
"""Function using primitives with ONNX defined translations."""
raise NotImplementedError(
"Float8CurrentScalingQuantizer does not support ONNX dequantization yet."
)
out = torch.ops.tex.fp8_dequantize(tensor._data, tensor._scale_inv)
out = out.to(tensor.dtype)
return out

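How the two new methods compose, as a sketch; the current-scaling quantizer instance is assumed to come from TE's usual recipe/quantizer machinery, since its constructor is not part of this diff.

```python
import torch

def onnx_roundtrip(cs_quantizer, x: torch.Tensor) -> torch.Tensor:
    qt = cs_quantizer.onnx_quantize(x)        # tex.fp8_cs_quantize -> Float8Tensor + scale_inv
    return cs_quantizer.onnx_dequantize(qt)   # tex.fp8_dequantize using qt._scale_inv
```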
def _canonicalized_amax_reduction_group(self) -> dist_group_type:
"""Get process group for amax reduction"""