diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py index ea1778edfe..31f0b61fff 100644 --- a/py/torch_tensorrt/dynamo/__init__.py +++ b/py/torch_tensorrt/dynamo/__init__.py @@ -1,2 +1,6 @@ -from torch_tensorrt.dynamo import fx_ts_compat -from .backend import compile +import torch +from packaging import version + +if version.parse(torch.__version__) >= version.parse("2.1.dev"): + from torch_tensorrt.dynamo import fx_ts_compat + from .backend import compile diff --git a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py index edcce20d65..217aee973e 100644 --- a/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py +++ b/py/torch_tensorrt/fx/tracer/dispatch_tracer/aten_tracer.py @@ -2,10 +2,11 @@ import sys from contextlib import contextmanager from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union +from packaging import version import torch -if not torch.__version__.startswith("1"): +if version.parse(torch.__version__) >= version.parse("2.dev"): import torch._dynamo as torchdynamo from torch.fx.passes.infra.pass_base import PassResult