Skip to content

fix: Add support for truncate_long_and_double in FX #1865

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion py/torch_tensorrt/fx/fx2trt.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
explicit_batch_dimension: bool = False,
explicit_precision: bool = False,
logger_level=None,
truncate_long_and_double=False,
):
super().__init__(module)

Expand Down Expand Up @@ -70,6 +71,7 @@ def __init__(

self.optimization_profiles: Optional[List] = None
self.input_specs = input_specs
self.truncate_long_and_double = truncate_long_and_double
self.input_specs_iter = 0
self.validate_input_specs()
self._cur_node_name: Optional[str] = None
Expand Down Expand Up @@ -306,7 +308,9 @@ def placeholder(self, target, args, kwargs):
self.optimization_profiles[i].set_shape(target, *shape_range)

return self.network.add_input(
name=target, shape=tuple(shape), dtype=torch_dtype_to_trt(dtype)
name=target,
shape=tuple(shape),
dtype=torch_dtype_to_trt(dtype, self.truncate_long_and_double),
)

def call_module(self, target, args, kwargs):
Expand Down
4 changes: 4 additions & 0 deletions py/torch_tensorrt/fx/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def compile(
use_experimental_fx_rt=False,
correctness_atol=1e-1,
correctness_rtol=1e-1,
truncate_long_and_double=False,
) -> nn.Module:
"""
Takes in original module, input and lowering setting, run lowering workflow to turn module
Expand All @@ -62,6 +63,7 @@ def compile(
cuda_graph_batch_size: Cuda graph batch size, default to be -1.
dynamic_batch: batch dimension (dim=0) is dynamic.
use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
truncate_long_and_double: Whether to automatically truncate long and double-type tensor inputs to TRT Engines
Returns:
A torch.nn.Module lowered by TensorRT.
"""
Expand All @@ -85,6 +87,7 @@ def compile(
use_experimental_rt=use_experimental_fx_rt,
correctness_atol=correctness_atol,
correctness_rtol=correctness_rtol,
truncate_long_and_double=truncate_long_and_double,
)
lowerer = Lowerer.create(lower_setting=lower_setting)
return lowerer(module, input)
Expand Down Expand Up @@ -129,6 +132,7 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
logger_level=trt.Logger.VERBOSE
if self.lower_setting.verbose_log
else trt.Logger.WARNING,
truncate_long_and_double=self.lower_setting.truncate_long_and_double,
)

interp_result: TRTInterpreterResult = interpreter.run(
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/fx/lower_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,4 @@ class LowerSetting(LowerSettingBasic):
correctness_atol: float = 0.1
correctness_rtol: float = 0.1
use_experimental_rt: bool = False
truncate_long_and_double: bool = False
287 changes: 202 additions & 85 deletions py/torch_tensorrt/fx/test/core/test_trt_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from torch.testing._internal.common_utils import run_tests, TestCase
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule

# from torch_tensorrt import TRTModuleNext
# from torch_tensorrt import Device
from torch_tensorrt import TRTModuleNext
from torch_tensorrt import Device
from torch_tensorrt.fx.utils import LowerPrecision


Expand Down Expand Up @@ -58,89 +58,206 @@ def forward(self, x):
)


# TODO add unittest.skip later
# class TestTRTModuleNext(TestCase):
# def test_save_and_load_trt_module(self):
# class TestModule(torch.nn.Module):
# def forward(self, x):
# return x + x

# inputs = [torch.randn(1, 1)]
# mod = TestModule().eval()
# ref_output = mod(*inputs)

# mod = acc_tracer.trace(mod, inputs)

# interp = TRTInterpreter(
# mod,
# input_specs=InputTensorSpec.from_tensors(inputs),
# explicit_batch_dimension=True,
# )
# interp_res = interp.run(lower_precision=LowerPrecision.FP32)

# with io.BytesIO() as engine_bytes:
# engine_bytes.write(interp_res.engine.serialize())
# engine_str = engine_bytes.getvalue()

# trt_mod = TRTModuleNext(
# name="TestModule",
# serialized_engine=engine_str,
# input_binding_names=interp_res.input_names,
# output_binding_names=interp_res.output_names,
# target_device=Device(f"cuda:{torch.cuda.current_device()}"),
# )

# torch.save(trt_mod, "trt.pt")
# reload_trt_mod = torch.load("trt.pt")

# torch.testing.assert_allclose(
# reload_trt_mod(inputs[0].cuda()).cpu().reshape_as(ref_output),
# ref_output,
# rtol=1e-04,
# atol=1e-04,
# )
# os.remove(f"{os.getcwd()}/trt.pt")

# def test_save_and_load_state_dict(self):
# class TestModule(torch.nn.Module):
# def forward(self, x):
# return x + x

# inputs = [torch.randn(1, 1)]
# mod = TestModule().eval()
# ref_output = mod(*inputs)

# mod = acc_tracer.trace(mod, inputs)
# interp = TRTInterpreter(
# mod,
# input_specs=InputTensorSpec.from_tensors(inputs),
# explicit_batch_dimension=True,
# )
# interp_res = interp.run(lower_precision=LowerPrecision.FP32)

# with io.BytesIO() as engine_bytes:
# engine_bytes.write(interp_res.engine.serialize())
# engine_str = engine_bytes.getvalue()

# trt_mod = TRTModuleNext(
# name="TestModule",
# serialized_engine=engine_str,
# input_binding_names=interp_res.input_names,
# output_binding_names=interp_res.output_names,
# target_device=Device(f"cuda:{torch.cuda.current_device()}"),
# )

# st = trt_mod.state_dict()

# new_trt_mod = TRTModuleNext()
# new_trt_mod.load_state_dict(st)

# torch.testing.assert_allclose(
# new_trt_mod(inputs[0].cuda()).cpu().reshape_as(ref_output),
# ref_output,
# rtol=1e-04,
# atol=1e-04,
# )
class TestTRTModuleInt64Input(TestCase):
def test_save_and_load_trt_module(self):
class TestModule(torch.nn.Module):
def forward(self, x):
return x + x

inputs = [torch.randn(5, 5).long()]
mod = TestModule().eval()
ref_output = mod(*inputs)

mod = acc_tracer.trace(mod, inputs)
interp = TRTInterpreter(
mod,
input_specs=InputTensorSpec.from_tensors(inputs),
truncate_long_and_double=True,
)
trt_mod = TRTModule(*interp.run(lower_precision=LowerPrecision.FP32))
torch.save(trt_mod, "trt.pt")
reload_trt_mod = torch.load("trt.pt")

torch.testing.assert_close(
reload_trt_mod(inputs[0].cuda()).cpu(),
ref_output,
rtol=1e-04,
atol=1e-04,
check_dtype=False,
)
os.remove(f"{os.getcwd()}/trt.pt")

def test_save_and_load_state_dict(self):
class TestModule(torch.nn.Module):
def forward(self, x):
return x + x

inputs = [torch.randn(5, 5).long()]
mod = TestModule().eval()
ref_output = mod(*inputs)

mod = acc_tracer.trace(mod, inputs)
interp = TRTInterpreter(
mod,
input_specs=InputTensorSpec.from_tensors(inputs),
truncate_long_and_double=True,
)
trt_mod = TRTModule(*interp.run(lower_precision=LowerPrecision.FP32))
st = trt_mod.state_dict()

new_trt_mod = TRTModule()
new_trt_mod.load_state_dict(st)

torch.testing.assert_close(
new_trt_mod(inputs[0].cuda()).cpu(),
ref_output,
rtol=1e-04,
atol=1e-04,
check_dtype=False,
)


class TestTRTModuleFloat64Input(TestCase):
def test_save_and_load_trt_module(self):
class TestModule(torch.nn.Module):
def forward(self, x):
return x + x

inputs = [torch.randn(5, 5).double()]
mod = TestModule().eval()
ref_output = mod(*inputs)

mod = acc_tracer.trace(mod, inputs)
interp = TRTInterpreter(
mod,
input_specs=InputTensorSpec.from_tensors(inputs),
truncate_long_and_double=True,
)
trt_mod = TRTModule(*interp.run(lower_precision=LowerPrecision.FP32))
torch.save(trt_mod, "trt.pt")
reload_trt_mod = torch.load("trt.pt")

torch.testing.assert_close(
reload_trt_mod(inputs[0].cuda()).cpu(),
ref_output,
rtol=1e-04,
atol=1e-04,
check_dtype=False,
)
os.remove(f"{os.getcwd()}/trt.pt")

def test_save_and_load_state_dict(self):
class TestModule(torch.nn.Module):
def forward(self, x):
return x + x

inputs = [torch.randn(5, 5).double()]
mod = TestModule().eval()
ref_output = mod(*inputs)

mod = acc_tracer.trace(mod, inputs)
interp = TRTInterpreter(
mod,
input_specs=InputTensorSpec.from_tensors(inputs),
truncate_long_and_double=True,
)
trt_mod = TRTModule(*interp.run(lower_precision=LowerPrecision.FP32))
st = trt_mod.state_dict()

new_trt_mod = TRTModule()
new_trt_mod.load_state_dict(st)

torch.testing.assert_close(
new_trt_mod(inputs[0].cuda()).cpu(),
ref_output,
rtol=1e-04,
atol=1e-04,
check_dtype=False,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The output data type will be different, since TRT cannot output int64 types

)


class TestTRTModuleNext(TestCase):
def test_save_and_load_trt_module(self):
class TestModule(torch.nn.Module):
def forward(self, x):
return x + x

inputs = [torch.randn(1, 1)]
mod = TestModule().eval()
ref_output = mod(*inputs)

mod = acc_tracer.trace(mod, inputs)

interp = TRTInterpreter(
mod,
input_specs=InputTensorSpec.from_tensors(inputs),
explicit_batch_dimension=True,
)
interp_res = interp.run(lower_precision=LowerPrecision.FP32)

with io.BytesIO() as engine_bytes:
engine_bytes.write(interp_res.engine.serialize())
engine_str = engine_bytes.getvalue()

trt_mod = TRTModuleNext(
name="TestModule",
serialized_engine=engine_str,
input_binding_names=interp_res.input_names,
output_binding_names=interp_res.output_names,
target_device=Device(f"cuda:{torch.cuda.current_device()}"),
)

torch.save(trt_mod, "trt.pt")
reload_trt_mod = torch.load("trt.pt")

torch.testing.assert_allclose(
reload_trt_mod(inputs[0].cuda()).cpu().reshape_as(ref_output),
ref_output,
rtol=1e-04,
atol=1e-04,
)
os.remove(f"{os.getcwd()}/trt.pt")

def test_save_and_load_state_dict(self):
class TestModule(torch.nn.Module):
def forward(self, x):
return x + x

inputs = [torch.randn(1, 1)]
mod = TestModule().eval()
ref_output = mod(*inputs)

mod = acc_tracer.trace(mod, inputs)
interp = TRTInterpreter(
mod,
input_specs=InputTensorSpec.from_tensors(inputs),
explicit_batch_dimension=True,
)
interp_res = interp.run(lower_precision=LowerPrecision.FP32)

with io.BytesIO() as engine_bytes:
engine_bytes.write(interp_res.engine.serialize())
engine_str = engine_bytes.getvalue()

trt_mod = TRTModuleNext(
name="TestModule",
serialized_engine=engine_str,
input_binding_names=interp_res.input_names,
output_binding_names=interp_res.output_names,
target_device=Device(f"cuda:{torch.cuda.current_device()}"),
)

st = trt_mod.state_dict()

new_trt_mod = TRTModuleNext()
new_trt_mod.load_state_dict(st)

torch.testing.assert_allclose(
new_trt_mod(inputs[0].cuda()).cpu().reshape_as(ref_output),
ref_output,
rtol=1e-04,
atol=1e-04,
)


if __name__ == "__main__":
Expand Down
Loading