diff --git a/examples/fx/lower_example.py b/examples/fx/lower_example.py index 7f3b374f44..cd9215712b 100644 --- a/examples/fx/lower_example.py +++ b/examples/fx/lower_example.py @@ -4,7 +4,7 @@ import torch import torchvision -from torch_tensorrt.fx.lower import compile +from torch_tensorrt.fx import compile from torch_tensorrt.fx.utils import LowerPrecision diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 8b5f235531..18b9901c56 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -7,7 +7,6 @@ from enum import Enum import torch_tensorrt.fx -import torch_tensorrt.fx.lower from torch_tensorrt.fx.utils import LowerPrecision @@ -140,7 +139,7 @@ def compile( else: raise ValueError(f"Precision {enabled_precisions} not supported on FX") - return torch_tensorrt.fx.lower.compile( + return torch_tensorrt.fx.compile( module, inputs, lower_precision=lower_precision, diff --git a/py/torch_tensorrt/fx/__init__.py b/py/torch_tensorrt/fx/__init__.py index c1c42c446f..03eb7174b5 100644 --- a/py/torch_tensorrt/fx/__init__.py +++ b/py/torch_tensorrt/fx/__init__.py @@ -11,5 +11,6 @@ from .input_tensor_spec import generate_input_specs, InputTensorSpec # noqa from .lower_setting import LowerSetting # noqa from .trt_module import TRTModule # noqa +from .lower import compile # usort: skip #noqa logging.basicConfig(level=logging.INFO)