feat: Support weight-stripped engine and REFIT_IDENTICAL flag #3167

Merged
merged 59 commits on Dec 12, 2024
Commits (59)
40349a8
support weight-stripped engine and REFIT_IDENTICAL flag
zewenli98 Sep 19, 2024
5d7c677
refactor with new design
zewenli98 Sep 20, 2024
82b7ddc
lint
zewenli98 Oct 1, 2024
9f6a771
small fix
zewenli98 Oct 1, 2024
7ea3c0f
remove make_refittable
zewenli98 Oct 1, 2024
bf7553b
fast refit -> slow refit
zewenli98 Oct 2, 2024
46e9bc8
fix np.bool_, group_norm
zewenli98 Oct 2, 2024
d783fdd
add immutable_weights
zewenli98 Oct 2, 2024
160588e
skip engine caching for non-refittable engines, slow refit -> fast refit
zewenli98 Oct 2, 2024
493f981
refactored, there are 3 types of engines
zewenli98 Oct 5, 2024
f204104
fix and add tests
zewenli98 Oct 5, 2024
4663c83
fix issues #3206 #3217
zewenli98 Oct 8, 2024
c57ab06
small fix
zewenli98 Oct 15, 2024
402c9b0
resolve comments
zewenli98 Oct 15, 2024
d8e59da
WIP: cache weight-stripped engine
zewenli98 Oct 22, 2024
e8811fd
Merge branch 'main' into weight_stripped_engine
zewenli98 Oct 31, 2024
f2e3f00
redesigned hash func and add constant mapping to fast refit
zewenli98 Nov 4, 2024
31af308
refactor and add tests
zewenli98 Nov 6, 2024
1ae33f4
Merge branch 'main' into weight_stripped_engine
zewenli98 Nov 6, 2024
90bf679
update
zewenli98 Nov 6, 2024
a8a34f6
increase ENGINE_CACHE_SIZE
zewenli98 Nov 6, 2024
285bc90
skip some tests
zewenli98 Nov 7, 2024
2d152cf
fix tests
zewenli98 Nov 7, 2024
d461608
try fixing cumsum
zewenli98 Nov 8, 2024
d57b885
Merge branch 'main' into weight_stripped_engine
zewenli98 Nov 8, 2024
23d68d5
fix windows cross compile, TODO: whether windows support stripping en…
zewenli98 Nov 8, 2024
a928f67
CI debug test 1
zewenli98 Nov 13, 2024
02625ca
CI debug test 2
zewenli98 Nov 14, 2024
c462e40
CI debug test 3
zewenli98 Nov 16, 2024
9ba33b5
Merge branch 'main' into weight_stripped_engine
Nov 19, 2024
3d68039
reduce -n to 4 for converter tests on CI
zewenli98 Nov 20, 2024
2e7ef3b
reduce -n to 4 for converter tests on CI
zewenli98 Nov 20, 2024
9ff165c
simplify test_different_args_dont_share_cached_engine
zewenli98 Nov 22, 2024
8ca8e2d
reduce -n to 2
zewenli98 Nov 22, 2024
f9f2a70
reduce -n to 1
zewenli98 Nov 22, 2024
c69c61a
revert -n back to 4 and chunk converter
zewenli98 Nov 23, 2024
05b560d
change to opt-in feature
zewenli98 Nov 28, 2024
7feea97
fix conflict
zewenli98 Nov 28, 2024
d1521c3
fix typo
zewenli98 Nov 28, 2024
5a193a2
Merge branch 'main' into weight_stripped_engine
Nov 29, 2024
0b345be
small fix
zewenli98 Dec 3, 2024
6754481
Merge branch 'main' into weight_stripped_engine
zewenli98 Dec 6, 2024
4a7e957
update to manylinux2_28-builder
zewenli98 Dec 10, 2024
6e840ba
remove cuda12.6 tests
zewenli98 Dec 10, 2024
9a8473a
remove one_user_validator for native_layer_norm
zewenli98 Dec 10, 2024
6a07767
clear tests
zewenli98 Dec 10, 2024
ed3424a
remove the whole chunk
zewenli98 Dec 10, 2024
ef54239
add cuda12.6 back and export D_GLIBCXX_USE_CXX11_ABI=1
zewenli98 Dec 10, 2024
f166562
fix env
zewenli98 Dec 10, 2024
80aae71
fix container
zewenli98 Dec 10, 2024
676c9ce
fix env
zewenli98 Dec 11, 2024
bf2edc6
fix env
zewenli98 Dec 11, 2024
627d510
fix env
zewenli98 Dec 11, 2024
b393b6f
fix env
zewenli98 Dec 11, 2024
78d72b6
fix env
zewenli98 Dec 11, 2024
a5d3c18
export USE_CXX11_ABI=1 for cuda12.6
zewenli98 Dec 11, 2024
4f02da8
remove chunk
zewenli98 Dec 11, 2024
7d7423a
resolve comments
zewenli98 Dec 12, 2024
9f76304
Merge branch 'main' into weight_stripped_engine
zewenli98 Dec 12, 2024
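Taken together, the commits above ("remove make_refittable", "add immutable_weights", "change to opt-in feature", "refactored, there are 3 types of engines") replace the opt-out `make_refittable=True` flag with the opt-in `immutable_weights=False` setting and distinguish three kinds of engines. A minimal sketch of that decision logic, with illustrative names only — `strip_engine_weights` and the returned labels are stand-ins, not the exact Torch-TensorRT API:

```python
# Hedged sketch of the flag semantics this PR introduces. The helper and
# its parameter names are illustrative, not the Torch-TensorRT implementation.
def engine_kind(immutable_weights=True, strip_engine_weights=False):
    if strip_engine_weights:
        # weight-stripped engines are built with the REFIT_IDENTICAL flag
        # and must be refitted with the original weights before use
        return "weight-stripped"
    if not immutable_weights:
        # weights can be swapped after build (fast/slow refit)
        return "refittable"
    # the new default: refit is an opt-in feature
    return "immutable"

print(engine_kind())                                                     # immutable
print(engine_kind(immutable_weights=False))                              # refittable
print(engine_kind(immutable_weights=False, strip_engine_weights=True))   # weight-stripped
```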
10 changes: 5 additions & 5 deletions .github/scripts/generate_binary_build_matrix.py
Original file line number Diff line number Diff line change
@@ -152,33 +152,33 @@ def initialize_globals(channel: str, build_python_only: bool) -> None:
"12.4": "pytorch/manylinux2_28-builder:cuda12.4",
"12.6": "pytorch/manylinux2_28-builder:cuda12.6",
**{
gpu_arch: f"pytorch/manylinux-builder:rocm{gpu_arch}"
gpu_arch: f"pytorch/manylinux2_28-builder:rocm{gpu_arch}"
for gpu_arch in ROCM_ARCHES
},
CPU: "pytorch/manylinux-builder:cpu",
CPU: "pytorch/manylinux2_28-builder:cpu",
XPU: "pytorch/manylinux2_28-builder:xpu",
# TODO: Migrate CUDA_AARCH64 image to manylinux2_28_aarch64-builder:cuda12.4
CPU_AARCH64: "pytorch/manylinux2_28_aarch64-builder:cpu-aarch64",
CUDA_AARCH64: "pytorch/manylinuxaarch64-builder:cuda12.4",
}
LIBTORCH_CONTAINER_IMAGES = {
**{
(gpu_arch, PRE_CXX11_ABI): f"pytorch/manylinux-builder:cuda{gpu_arch}"
(gpu_arch, PRE_CXX11_ABI): f"pytorch/manylinux2_28-builder:cuda{gpu_arch}"
for gpu_arch in CUDA_ARCHES
},
**{
(gpu_arch, CXX11_ABI): f"pytorch/libtorch-cxx11-builder:cuda{gpu_arch}"
for gpu_arch in CUDA_ARCHES
},
**{
(gpu_arch, PRE_CXX11_ABI): f"pytorch/manylinux-builder:rocm{gpu_arch}"
(gpu_arch, PRE_CXX11_ABI): f"pytorch/manylinux2_28-builder:rocm{gpu_arch}"
for gpu_arch in ROCM_ARCHES
},
**{
(gpu_arch, CXX11_ABI): f"pytorch/libtorch-cxx11-builder:rocm{gpu_arch}"
for gpu_arch in ROCM_ARCHES
},
(CPU, PRE_CXX11_ABI): "pytorch/manylinux-builder:cpu",
(CPU, PRE_CXX11_ABI): "pytorch/manylinux2_28-builder:cpu",
(CPU, CXX11_ABI): "pytorch/libtorch-cxx11-builder:cpu",
}

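The hunk above swaps the remaining `manylinux-builder` images for `manylinux2_28-builder`. The resulting (arch, ABI) → image lookup can be sketched self-contained; the architecture lists and ABI labels below are illustrative stand-ins for the real constants in `generate_binary_build_matrix.py`:

```python
# Illustrative reproduction of the post-change LIBTORCH_CONTAINER_IMAGES
# mapping; the arch lists are examples, not the values used in CI.
CUDA_ARCHES = ["11.8", "12.1", "12.4"]   # assumed for illustration
ROCM_ARCHES = ["6.1", "6.2"]             # assumed for illustration
PRE_CXX11_ABI, CXX11_ABI, CPU = "pre-cxx11", "cxx11-abi", "cpu"

LIBTORCH_CONTAINER_IMAGES = {
    **{(a, PRE_CXX11_ABI): f"pytorch/manylinux2_28-builder:cuda{a}" for a in CUDA_ARCHES},
    **{(a, CXX11_ABI): f"pytorch/libtorch-cxx11-builder:cuda{a}" for a in CUDA_ARCHES},
    **{(a, PRE_CXX11_ABI): f"pytorch/manylinux2_28-builder:rocm{a}" for a in ROCM_ARCHES},
    **{(a, CXX11_ABI): f"pytorch/libtorch-cxx11-builder:rocm{a}" for a in ROCM_ARCHES},
    (CPU, PRE_CXX11_ABI): "pytorch/manylinux2_28-builder:cpu",
    (CPU, CXX11_ABI): "pytorch/libtorch-cxx11-builder:cpu",
}

print(LIBTORCH_CONTAINER_IMAGES[("12.4", PRE_CXX11_ABI)])
# pytorch/manylinux2_28-builder:cuda12.4
```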
2 changes: 1 addition & 1 deletion .github/workflows/build-test-linux.yml
@@ -137,7 +137,7 @@ jobs:
export CI_BUILD=1
pushd .
cd tests/py/dynamo
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 8 conversion/
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/
popd

tests-py-dynamo-fe:
4 changes: 2 additions & 2 deletions .github/workflows/build-test-tensorrt-linux.yml
@@ -129,7 +129,7 @@ jobs:
export CI_BUILD=1
pushd .
cd tests/py/dynamo
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 8 conversion/
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/
popd

tests-py-dynamo-fe:
@@ -314,4 +314,4 @@ jobs:

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }}-${{ inputs.job-name }}
cancel-in-progress: true
cancel-in-progress: true
4 changes: 2 additions & 2 deletions .github/workflows/build-test-tensorrt-windows.yml
@@ -132,7 +132,7 @@ jobs:
export CI_BUILD=1
pushd .
cd tests/py/dynamo
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 8 conversion/
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/
popd

tests-py-dynamo-fe:
@@ -298,4 +298,4 @@ jobs:

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }}-${{ inputs.job-name }}
cancel-in-progress: true
cancel-in-progress: true
2 changes: 1 addition & 1 deletion .github/workflows/build-test-windows.yml
@@ -119,7 +119,7 @@ jobs:
export CI_BUILD=1
pushd .
cd tests/py/dynamo
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 8 conversion/
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/
popd

tests-py-dynamo-fe:
2 changes: 1 addition & 1 deletion examples/dynamo/engine_caching_bert_example.py
@@ -52,7 +52,7 @@ def compile_bert(iterations=3):
"truncate_double": True,
"debug": False,
"min_block_size": 1,
"make_refittable": True,
"immutable_weights": False,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"engine_cache_dir": "/tmp/torch_trt_bert_engine_cache",
8 changes: 4 additions & 4 deletions examples/dynamo/engine_caching_example.py
@@ -63,7 +63,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
# in a subsequent compilation, either as part of this session or a new session, the cache will
# pull the built engine and **refit** the weights which can reduce compilation times by orders of magnitude.
# As such, in order to insert a new engine into the cache (i.e. ``cache_built_engines=True``),
# the engine must be refittable (``make_refittable=True``). See :ref:`refit_engine_example` for more details.
# the engine must be refittable (``immutable_weights=False``). See :ref:`refit_engine_example` for more details.


def torch_compile(iterations=3):
@@ -97,7 +97,7 @@ def torch_compile(iterations=3):
"enabled_precisions": enabled_precisions,
"debug": debug,
"min_block_size": min_block_size,
"make_refittable": True,
"immutable_weights": False,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
},
@@ -157,7 +157,7 @@ def dynamo_compile(iterations=3):
enabled_precisions=enabled_precisions,
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
immutable_weights=False,
cache_built_engines=cache_built_engines,
reuse_cached_engines=reuse_cached_engines,
engine_cache_size=1 << 30, # 1GB
@@ -268,7 +268,7 @@ def torch_compile_my_cache(iterations=3):
"enabled_precisions": enabled_precisions,
"debug": debug,
"min_block_size": min_block_size,
"make_refittable": True,
"immutable_weights": False,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"custom_engine_cache": engine_cache,
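The caching contract described in the comment above — only refittable engines (`immutable_weights=False`) are inserted, and a cache hit pulls the built engine and refits it with the current weights — can be illustrated with a toy cache. `ToyEngineCache` is a stand-in for exposition, not the Torch-TensorRT engine cache:

```python
# Toy model of the engine-cache contract; not a Torch-TensorRT API.
class ToyEngineCache:
    def __init__(self):
        self._store = {}

    def insert(self, key, engine, immutable_weights):
        if immutable_weights:
            # mirrors "skip engine caching for non-refittable engines":
            # an immutable-weight engine cannot be refitted on a later hit
            return False
        self._store[key] = engine
        return True

    def lookup(self, key, weights):
        engine = self._store.get(key)
        if engine is None:
            return None
        # on a hit, the cached engine is refitted with the caller's weights
        # instead of being rebuilt from scratch
        return {"engine": engine, "refitted_with": weights}

cache = ToyEngineCache()
assert not cache.insert("graph-1", "engine-bytes", immutable_weights=True)
assert cache.insert("graph-1", "engine-bytes", immutable_weights=False)
print(cache.lookup("graph-1", {"w": 1}))
```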
4 changes: 2 additions & 2 deletions examples/dynamo/mutable_torchtrt_module_example.py
@@ -31,7 +31,7 @@
settings = {
"use_python": False,
"enabled_precisions": {torch.float32},
"make_refittable": True,
"immutable_weights": False,
}

model = models.resnet18(pretrained=True).eval().to("cuda")
@@ -80,7 +80,7 @@
"use_python_runtime": True,
"enabled_precisions": {torch.float16},
"debug": True,
"make_refittable": True,
"immutable_weights": False,
}

model_id = "runwayml/stable-diffusion-v1-5"
4 changes: 2 additions & 2 deletions examples/dynamo/refit_engine_example.py
@@ -47,7 +47,7 @@
# ---------------------------------------
#
# The initial step is to compile a module and save it as normal. Note that there is an
# additional parameter `make_refittable` that is set to `True`. This parameter is used to
# additional parameter `immutable_weights` that is set to `False`. This parameter is used to
# indicate that the engine being built should support weight refitting later. Engines built without
# these settings will not be able to be refit.
#
@@ -69,7 +69,7 @@
debug=debug,
min_block_size=min_block_size,
torch_executed_ops=torch_executed_ops,
make_refittable=True,
immutable_weights=False,
reuse_cached_engines=False,
) # Output is a torch.fx.GraphModule

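The refit workflow this example describes — compile once with `immutable_weights=False`, then swap weights without rebuilding — reduces to the following toy model (`ToyEngine` is illustrative, not the real engine class):

```python
# Toy illustration of weight refitting: nothing is rebuilt, weights are swapped.
class ToyEngine:
    def __init__(self, weights, refittable):
        self.weights = dict(weights)
        # refittable corresponds to compiling with immutable_weights=False
        self.refittable = refittable

    def refit(self, new_weights):
        if not self.refittable:
            raise RuntimeError("engine was built with immutable weights")
        self.weights.update(new_weights)

engine = ToyEngine({"fc": 1.0}, refittable=True)
engine.refit({"fc": 2.0})
print(engine.weights["fc"])  # 2.0

frozen = ToyEngine({"fc": 1.0}, refittable=False)
try:
    frozen.refit({"fc": 3.0})
except RuntimeError as err:
    print("refit refused:", err)
```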
2 changes: 1 addition & 1 deletion py/ci/Dockerfile.ci
@@ -1,4 +1,4 @@
FROM pytorch/manylinux-builder:cuda12.4
FROM pytorch/manylinux2_28-builder:cuda12.6

RUN yum install -y ninja-build

4 changes: 2 additions & 2 deletions py/torch_tensorrt/_enums.py
@@ -220,7 +220,7 @@ def _from(
return dtype.f32
elif t == np.float64:
return dtype.f64
elif t == np.bool:
elif t == np.bool_:
return dtype.b
# TODO: Consider using ml_dtypes when issues like this are resolved:
# https://github.com/pytorch/pytorch/issues/109873
@@ -1384,7 +1384,7 @@ def current_platform(cls) -> Platform:
def __str__(self) -> str:
return str(self.name)

@needs_torch_tensorrt_runtime
@needs_torch_tensorrt_runtime # type: ignore
def _to_serialized_rt_platform(self) -> str:
val: str = torch.ops.tensorrt._platform_unknown()

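The `_enums.py` hunk fixes the `np.bool` comparison: the bare alias was deprecated in NumPy 1.20 and removed in 1.24, so `t == np.bool` raised `AttributeError` on recent NumPy; `np.bool_` is the stable scalar boolean type. A small sketch of the corrected check, assuming NumPy is installed — `to_trt_kind` is an illustrative stand-in for `dtype._from`, not the real method:

```python
import numpy as np

# Hypothetical mapping in the spirit of the fixed dtype._from branch;
# only the np.bool_ comparison mirrors the actual change.
def to_trt_kind(t):
    if t == np.float32:
        return "f32"
    if t == np.float64:
        return "f64"
    if t == np.bool_:   # was `np.bool`, which no longer exists on NumPy 1.24+
        return "b"
    raise TypeError(f"unsupported dtype: {t!r}")

print(to_trt_kind(np.bool_))  # b
```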