Skip to content

Commit f1f3716

Browse files
committed
fix: Add automatic type promotion for FX ops
- Implement functionality to cast tensors to alternative types
- Add functionality to elementwise ops to promote types and perform the necessary casts
- Address issues in FX ops where mixed-precision computations can cause errors
- Add test cases to validate the fix
1 parent df401dd commit f1f3716

File tree

3 files changed

+62
-1
lines changed

3 files changed

+62
-1
lines changed

py/torch_tensorrt/fx/converters/converter_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,38 @@ def create_constant(
277277
return constant.get_output(0)
278278

279279

280+
def cast_trt_tensor(
    network: TRTNetwork,
    input_val: TRTTensor,
    dtype: TRTDataType,
    name: str,
) -> TRTTensor:
    """
    Return ``input_val`` converted to ``dtype``.

    When the tensor already has the requested dtype it is returned unchanged
    and nothing is added to the network. Otherwise an Identity layer is added
    to the network, its output type is set to ``dtype``, and that layer's
    output is returned.

    Args:
        network (TRTNetwork): Network the Identity layer is added to
        input_val (TRTTensor): Tensor to convert to a new data type
        dtype (TRTDataType): Data type the result should have
        name (str): Name of the calling layer; appended to the cast layer's name

    Returns:
        ``input_val`` itself if no conversion is needed, otherwise output 0 of
        the newly added Identity layer
    """
    # Nothing to convert: hand the tensor back without touching the network.
    if input_val.dtype == dtype:
        return input_val

    cast_layer = network.add_identity(input_val)
    cast_layer.set_output_type(0, dtype)
    # The layer name records source tensor, both dtypes and the caller.
    cast_layer.name = (
        f"Cast ITensor {input_val.name} from {input_val.dtype} to {dtype} - {name}"
    )
    return cast_layer.get_output(0)
310+
311+
280312
def get_trt_tensor(
281313
network: TRTNetwork,
282314
input_val: Any,

py/torch_tensorrt/fx/converters/impl/elementwise/base.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@
1010
from torch.fx.node import Target
1111

1212
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, TRTElementWiseOp
13-
from torch_tensorrt.fx.utils import torch_dtype_from_trt
13+
from torch_tensorrt.fx.utils import torch_dtype_from_trt, torch_dtype_to_trt
1414
from torch_tensorrt.fx.converters.converter_utils import (
1515
SourceIR,
1616
set_layer_name,
1717
broadcast,
1818
squeeze_left,
1919
get_trt_tensor,
20+
cast_trt_tensor,
2021
)
2122

2223

@@ -52,6 +53,7 @@ def convert_binary_elementwise(
5253
introduce constant via .size() op. Other scenario should be const folded first.
5354
If any operand is not a trt tensor, we make it a trt constant layer while preserve
5455
its dtype. Then we broadcast these two inputs to have the same number of dimensions.
56+
We also promote the types of the two tensors to avoid dtype errors in TRT.
5557
5658
Limitation:
5759
If we are using implicit batch dim mode, the operand that is not a trt
@@ -126,6 +128,16 @@ def convert_binary_elementwise(
126128
lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype)
127129
rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype)
128130

131+
promoted_type = torch.promote_types(
132+
torch_dtype_from_trt(lhs_val.dtype), torch_dtype_from_trt(rhs_val.dtype)
133+
)
134+
trt_promoted_type = torch_dtype_to_trt(promoted_type)
135+
136+
if trt_promoted_type != lhs_val.dtype:
137+
lhs_val = cast_trt_tensor(network, lhs_val, trt_promoted_type, name)
138+
if trt_promoted_type != rhs_val.dtype:
139+
rhs_val = cast_trt_tensor(network, rhs_val, trt_promoted_type, name)
140+
129141
# Check the limitation in the doc string.
130142
if network.has_implicit_batch_dimension:
131143
if is_lhs_trt_tensor and not is_rhs_trt_tensor:

py/torch_tensorrt/fx/test/converters/acc_op/test_binary_ops.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,23 @@ def forward(self, x):
5757
inputs = [torch.rand(1, 1) + 1]
5858
self.run_test(m, inputs, expected_ops={expected_op})
5959

60+
@parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops])
def test_elementwise_ops_mismatched_dtypes(
    self, name, orig_op: Callable, expected_op
):
    # Runs each elementwise op with an int operand on the left and the
    # original float input on the right, so the two operands reach the
    # converter with different dtypes.
    class TestModule(nn.Module):
        def __init__(self, orig_op):
            super().__init__()
            self.orig_op = orig_op

        def forward(self, x):
            # Integer-cast copy of the input is the left operand; the
            # untouched float input is the right operand.
            int_operand = x.int()
            return self.orig_op(int_operand, x)

    # Avoid dividing by 0: rand() is in [0, 1), so every value is >= 1.
    self.run_test(
        TestModule(orig_op),
        [2 * torch.rand(1, 1, dtype=torch.float) + 1],
        expected_ops={expected_op},
    )
76+
6077
@parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops])
6178
def test_elementwise_ops_with_one_constant(
6279
self, name, orig_op: Callable, expected_op

0 commit comments

Comments
 (0)