Skip to content

Commit 65bc360

Browse files
committed
fix: Add support for truncate_long_and_double in FX
- Add support and testing for `double` type inputs
1 parent 5fe7c23 commit 65bc360

File tree

6 files changed

+100
-6
lines changed

6 files changed

+100
-6
lines changed

py/torch_tensorrt/fx/fx2trt.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(
4040
explicit_batch_dimension: bool = False,
4141
explicit_precision: bool = False,
4242
logger_level=None,
43+
truncate_long_and_double=False,
4344
):
4445
super().__init__(module)
4546

@@ -69,6 +70,7 @@ def __init__(
6970

7071
self.optimization_profiles: Optional[List] = None
7172
self.input_specs = input_specs
73+
self.truncate_long_and_double = truncate_long_and_double
7274
self.input_specs_iter = 0
7375
self.validate_input_specs()
7476
self._cur_node_name: Optional[str] = None
@@ -300,7 +302,9 @@ def placeholder(self, target, args, kwargs):
300302
self.optimization_profiles[i].set_shape(target, *shape_range)
301303

302304
return self.network.add_input(
303-
name=target, shape=tuple(shape), dtype=torch_dtype_to_trt(dtype)
305+
name=target,
306+
shape=tuple(shape),
307+
dtype=torch_dtype_to_trt(dtype, self.truncate_long_and_double),
304308
)
305309

306310
def call_module(self, target, args, kwargs):

py/torch_tensorrt/fx/lower.py

+4
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def compile(
4141
dynamic_batch=True,
4242
is_aten=False,
4343
use_experimental_fx_rt=False,
44+
truncate_long_and_double=False,
4445
) -> nn.Module:
4546
"""
4647
Takes in original module, input and lowering setting, run lowering workflow to turn module
@@ -60,6 +61,7 @@ def compile(
6061
cuda_graph_batch_size: Cuda graph batch size, default to be -1.
6162
dynamic_batch: batch dimension (dim=0) is dynamic.
6263
use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
64+
truncate_long_and_double: Whether to automatically truncate long and double-type tensor inputs to TRT Engines
6365
Returns:
6466
A torch.nn.Module lowered by TensorRT.
6567
"""
@@ -81,6 +83,7 @@ def compile(
8183
dynamic_batch=dynamic_batch,
8284
is_aten=is_aten,
8385
use_experimental_rt=use_experimental_fx_rt,
86+
truncate_long_and_double=truncate_long_and_double,
8487
)
8588
lowerer = Lowerer.create(lower_setting=lower_setting)
8689
return lowerer(module, input)
@@ -125,6 +128,7 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
125128
logger_level=trt.Logger.VERBOSE
126129
if self.lower_setting.verbose_log
127130
else trt.Logger.WARNING,
131+
truncate_long_and_double=self.lower_setting.truncate_long_and_double,
128132
)
129133

130134
interp_result: TRTInterpreterResult = interpreter.run(

py/torch_tensorrt/fx/lower_setting.py

+1
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,4 @@ class LowerSetting(LowerSettingBasic):
101101
correctness_atol: float = 0.1
102102
correctness_rtol: float = 0.1
103103
use_experimental_rt: bool = False
104+
truncate_long_and_double: bool = False

py/torch_tensorrt/fx/test/core/test_trt_module.py

+57
Original file line numberDiff line numberDiff line change
@@ -199,5 +199,62 @@ def forward(self, x):
199199
)
200200

201201

202+
class TestTRTModuleFloat64Input(TestCase):
    """Round-trip tests for a TRTModule whose input is a float64 tensor.

    The interpreter forwards ``truncate_long_and_double`` to
    ``torch_dtype_to_trt`` when it registers network inputs, and that helper
    raises ``AssertionError`` for float64 unless the flag is True. Both tests
    therefore opt in to truncation explicitly; with the default (False) the
    engine cannot be built from a double input spec.
    """

    def test_save_and_load_trt_module(self):
        """A TRTModule saved with torch.save and reloaded matches eager output."""
        # Local import so this block does not depend on the file's import list.
        import tempfile

        class TestModule(torch.nn.Module):
            def forward(self, x):
                return x + x

        inputs = [torch.randn(5, 5).double()]
        mod = TestModule().eval()
        ref_output = mod(*inputs)

        mod = acc_tracer.trace(mod, inputs)
        interp = TRTInterpreter(
            mod,
            input_specs=InputTensorSpec.from_tensors(inputs),
            # Required for double inputs: requests the float64 -> float32 cast
            # instead of the AssertionError raised when the flag is False.
            truncate_long_and_double=True,
        )
        trt_mod = TRTModule(*interp.run(lower_precision=LowerPrecision.FP32))

        # Serialize into a throwaway directory so the file is cleaned up even
        # when a later assertion fails (and the working directory stays clean).
        with tempfile.TemporaryDirectory() as tmp_dir:
            save_path = os.path.join(tmp_dir, "trt.pt")
            torch.save(trt_mod, save_path)
            reload_trt_mod = torch.load(save_path)

        # check_dtype=False: the reference is computed in float64, while the
        # engine input was truncated (presumably yielding a float32 result).
        torch.testing.assert_close(
            reload_trt_mod(inputs[0].cuda()).cpu(),
            ref_output,
            rtol=1e-04,
            atol=1e-04,
            check_dtype=False,
        )

    def test_save_and_load_state_dict(self):
        """A TRTModule restored through state_dict matches eager output."""

        class TestModule(torch.nn.Module):
            def forward(self, x):
                return x + x

        inputs = [torch.randn(5, 5).double()]
        mod = TestModule().eval()
        ref_output = mod(*inputs)

        mod = acc_tracer.trace(mod, inputs)
        interp = TRTInterpreter(
            mod,
            input_specs=InputTensorSpec.from_tensors(inputs),
            # Required for double inputs: requests the float64 -> float32 cast
            # instead of the AssertionError raised when the flag is False.
            truncate_long_and_double=True,
        )
        trt_mod = TRTModule(*interp.run(lower_precision=LowerPrecision.FP32))
        st = trt_mod.state_dict()

        # Rebuild an empty module and restore it purely from the state dict.
        new_trt_mod = TRTModule()
        new_trt_mod.load_state_dict(st)

        # check_dtype=False: the reference is computed in float64, while the
        # engine input was truncated (presumably yielding a float32 result).
        torch.testing.assert_close(
            new_trt_mod(inputs[0].cuda()).cpu(),
            ref_output,
            rtol=1e-04,
            atol=1e-04,
            check_dtype=False,
        )
257+
258+
202259
if __name__ == "__main__":
203260
run_tests()

py/torch_tensorrt/fx/trt_module.py

+9
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,15 @@ def forward(self, *inputs):
156156
inputs = (
157157
inputs[:i] + (inputs[i].to(torch.int32),) + inputs[i + 1 :]
158158
)
159+
elif (
160+
inputs[i].dtype == torch.float64
161+
and self.input_dtypes[i] == torch.float32
162+
):
163+
inputs = (
164+
inputs[:i]
165+
+ (inputs[i].to(torch.float32),)
166+
+ inputs[i + 1 :]
167+
)
159168

160169
assert (
161170
inputs[i].dtype == self.input_dtypes[i]

py/torch_tensorrt/fx/utils.py

+24-5
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ class LowerPrecision(Enum):
2525
INT8 = "int8"
2626

2727

28-
def torch_dtype_to_trt(dtype: torch.dtype) -> TRTDataType:
28+
def torch_dtype_to_trt(
29+
dtype: torch.dtype, truncate_long_and_double: bool = False
30+
) -> TRTDataType:
2931
"""
3032
Convert PyTorch data types to TensorRT data types.
3133
@@ -42,14 +44,31 @@ def torch_dtype_to_trt(dtype: torch.dtype) -> TRTDataType:
4244
elif dtype == torch.int32:
4345
return trt.int32
4446
elif dtype == torch.int64:
45-
_LOGGER.warn(
46-
"Detected Int64 Input, Casting to Int32 for TRT Engine Compatibility"
47-
)
48-
return trt.int32
47+
if truncate_long_and_double:
48+
_LOGGER.warn(
49+
"Detected Int64 Input, Casting to Int32 for TRT Engine Compatibility"
50+
)
51+
return trt.int32
52+
else:
53+
raise AssertionError(
54+
"Detected Int64 Input, enable truncate_long_and_double=True to cast "
55+
+ "input to Int32 for TRT Engine"
56+
)
4957
elif dtype == torch.float16:
5058
return trt.float16
5159
elif dtype == torch.float32:
5260
return trt.float32
61+
elif dtype == torch.float64:
62+
if truncate_long_and_double:
63+
_LOGGER.warn(
64+
"Detected Float64 Input, Casting to Float32 for TRT Engine Compatibility"
65+
)
66+
return trt.float32
67+
else:
68+
raise AssertionError(
69+
"Detected Float64 Input, enable truncate_long_and_double=True to cast "
70+
+ "input to Float32 for TRT Engine"
71+
)
5372
else:
5473
raise TypeError("%s is not supported by tensorrt" % dtype)
5574

0 commit comments

Comments
 (0)