diff --git a/README.md b/README.md index 550a95072e..f1c6e608fa 100644 --- a/README.md +++ b/README.md @@ -217,6 +217,14 @@ bazel build //:libtorchtrt --compilation_mode opt ``` shell bazel build //:libtorchtrt --compilation_mode=dbg ``` +### Install only FX module + +FX2TRT is python based module in torch_tensorrt. If the users would like to install only the FX module of torch_tensorrt, please run +``` shell +pushd py +python3 setup.py install --fx2trt-only +popd +``` ### Native compilation on NVIDIA Jetson AGX We performed end to end testing on Jetson platform using Jetpack SDK 4.6. @@ -316,4 +324,4 @@ Take a look at the [CONTRIBUTING.md](CONTRIBUTING.md) ## License -The Torch-TensorRT license can be found in the LICENSE file. It is licensed with a BSD Style licence \ No newline at end of file +The Torch-TensorRT license can be found in the LICENSE file. It is licensed with a BSD Style licence diff --git a/py/setup.py b/py/setup.py index 890a0e1e8e..4046b71221 100644 --- a/py/setup.py +++ b/py/setup.py @@ -23,11 +23,14 @@ JETPACK_VERSION = None __version__ = '1.2.0a0' - +FX2TRT_ONLY = False def get_git_revision_short_hash() -> str: return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode('ascii').strip() +if "--fx2trt-only" in sys.argv: + FX2TRT_ONLY = True + sys.argv.remove("--fx2trt-only") if "--release" not in sys.argv: __version__ = __version__ + "+" + get_git_revision_short_hash() @@ -138,11 +141,14 @@ def finalize_options(self): develop.finalize_options(self) def run(self): - global CXX11_ABI - build_libtorchtrt_pre_cxx11_abi(develop=True, cxx11_abi=CXX11_ABI) - gen_version_file() - copy_libtorchtrt() - develop.run(self) + if FX2TRT_ONLY: + develop.run(self) + else: + global CXX11_ABI + build_libtorchtrt_pre_cxx11_abi(develop=True, cxx11_abi=CXX11_ABI) + gen_version_file() + copy_libtorchtrt() + develop.run(self) class InstallCommand(install): @@ -155,11 +161,14 @@ def finalize_options(self): install.finalize_options(self) def run(self): - global CXX11_ABI - build_libtorchtrt_pre_cxx11_abi(develop=False, cxx11_abi=CXX11_ABI) - gen_version_file() - copy_libtorchtrt() - install.run(self) + if FX2TRT_ONLY: + install.run(self) + else: + global CXX11_ABI + build_libtorchtrt_pre_cxx11_abi(develop=False, cxx11_abi=CXX11_ABI) + gen_version_file() + copy_libtorchtrt() + install.run(self) class BdistCommand(bdist_wheel): @@ -254,6 +263,23 @@ def run(self): ] + (["-D_GLIBCXX_USE_CXX11_ABI=1"] if CXX11_ABI else ["-D_GLIBCXX_USE_CXX11_ABI=0"]), undef_macros=["NDEBUG"]) ] +if FX2TRT_ONLY: + ext_modules=None + packages=[ + "torch_tensorrt.fx", + "torch_tensorrt.fx.converters", + "torch_tensorrt.fx.passes", + "torch_tensorrt.fx.tools", + "torch_tensorrt.fx.tracer.acc_tracer", + ] + package_dir={ + "torch_tensorrt.fx": "torch_tensorrt/fx", + "torch_tensorrt.fx.converters": "torch_tensorrt/fx/converters", + "torch_tensorrt.fx.passes": "torch_tensorrt/fx/passes", + "torch_tensorrt.fx.tools": "torch_tensorrt/fx/tools", + "torch_tensorrt.fx.tracer.acc_tracer": "torch_tensorrt/fx/tracer/acc_tracer", + } + with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() @@ -282,7 +308,8 @@ def run(self): }, zip_safe=False, license="BSD", - packages=find_packages(), + packages=packages if FX2TRT_ONLY else find_packages(), + package_dir=package_dir if FX2TRT_ONLY else {}, classifiers=[ "Development Status :: 5 - Stable", "Environment :: GPU :: NVIDIA CUDA", "License :: OSI Approved :: BSD License", "Intended Audience :: Developers", @@ -311,4 +338,4 @@ def run(self): exclude_package_data={ '': ['*.cpp'], 'torch_tensorrt': ['csrc/*.cpp'], - }) + }),