diff --git a/.github/scripts/filter-matrix.py b/.github/scripts/filter-matrix.py
index 3710539f59..69ee24080a 100644
--- a/.github/scripts/filter-matrix.py
+++ b/.github/scripts/filter-matrix.py
@@ -3,8 +3,9 @@
 import argparse
 import json
 import sys
+from typing import List
 
-disabled_python_versions = "3.13"
+disabled_python_versions: List[str] = []
 
 
 def main(args: list[str]) -> None:
diff --git a/.github/scripts/generate-tensorrt-test-matrix.py b/.github/scripts/generate-tensorrt-test-matrix.py
index 546116d7c2..02b0f746ca 100644
--- a/.github/scripts/generate-tensorrt-test-matrix.py
+++ b/.github/scripts/generate-tensorrt-test-matrix.py
@@ -28,6 +28,10 @@
 # please update the future tensorRT version you want to test here
 TENSORRT_VERSIONS_DICT = {
     "windows": {
+        "10.3.0": {
+            "urls": "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/zip/TensorRT-10.3.0.26.Windows.win10.cuda-12.5.zip",
+            "strip_prefix": "TensorRT-10.3.0.26",
+        },
         "10.7.0": {
             "urls": "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.7.0/zip/TensorRT-10.7.0.23.Windows.win10.cuda-12.6.zip",
             "strip_prefix": "TensorRT-10.7.0.23",
@@ -42,6 +46,10 @@
         },
     },
     "linux": {
+        "10.3.0": {
+            "urls": "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/tars/TensorRT-10.3.0.26.Linux.x86_64-gnu.cuda-12.5.tar.gz",
+            "strip_prefix": "TensorRT-10.3.0.26",
+        },
         "10.7.0": {
             "urls": "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.7.0/tars/TensorRT-10.7.0.23.Linux.x86_64-gnu.cuda-12.6.tar.gz",
             "strip_prefix": "TensorRT-10.7.0.23",
diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py
index 52063ce9fc..331f1ad6ec 100644
--- a/.github/scripts/generate_binary_build_matrix.py
+++ b/.github/scripts/generate_binary_build_matrix.py
@@ -18,15 +18,16 @@
 import sys
 from typing import Any, Callable, Dict, List, Optional, Tuple
 
+PYTHON_VERSIONS_FOR_PR_BUILD = ["3.11"]
 PYTHON_ARCHES_DICT = {
-    "nightly": ["3.9", "3.10", "3.11", "3.12"],
-    "test": ["3.9", "3.10", "3.11", "3.12"],
-    "release": ["3.9", "3.10", "3.11", "3.12"],
+    "nightly": ["3.9", "3.10", "3.11", "3.12", "3.13"],
+    "test": ["3.9", "3.10", "3.11", "3.12", "3.13"],
+    "release": ["3.9", "3.10", "3.11", "3.12", "3.13"],
 }
 CUDA_ARCHES_DICT = {
     "nightly": ["11.8", "12.6", "12.8"],
     "test": ["11.8", "12.6", "12.8"],
-    "release": ["11.8", "12.6", "12.8"],
+    "release": ["11.8", "12.4", "12.6"],
 }
 ROCM_ARCHES_DICT = {
     "nightly": ["6.1", "6.2"],
@@ -422,11 +423,6 @@ def generate_wheels_matrix(
     # Define default python version
     python_versions = list(PYTHON_ARCHES)
 
-    # If the list of python versions is set explicitly by the caller, stick with it instead
-    # of trying to add more versions behind the scene
-    if channel == NIGHTLY and (os in (LINUX, MACOS_ARM64, LINUX_AARCH64)):
-        python_versions += ["3.13"]
-
     if os == LINUX:
         # NOTE: We only build manywheel packages for linux
         package_type = "manywheel"
@@ -456,7 +452,7 @@
         arches += [XPU]
 
     if limit_pr_builds:
-        python_versions = [python_versions[0]]
+        python_versions = PYTHON_VERSIONS_FOR_PR_BUILD
 
     global WHEEL_CONTAINER_IMAGES
diff --git a/.github/workflows/build-test-linux.yml b/.github/workflows/build-test-linux.yml
index 024afd8c62..4d252b24e4 100644
--- a/.github/workflows/build-test-linux.yml
+++ b/.github/workflows/build-test-linux.yml
@@ -23,7 +23,6 @@ jobs:
       test-infra-ref: main
       with-rocm: false
       with-cpu: false
-      python-versions: '["3.11", "3.12", "3.10", "3.9"]'
 
   filter-matrix:
     needs: [generate-matrix]
diff --git a/.github/workflows/build-test-windows.yml b/.github/workflows/build-test-windows.yml
index f78218e75d..2ee31b4b74 100644
--- a/.github/workflows/build-test-windows.yml
+++ b/.github/workflows/build-test-windows.yml
@@ -23,7 +23,6 @@ jobs:
       test-infra-ref: main
       with-rocm: false
       with-cpu: false
-      python-versions: '["3.11", "3.12", "3.10", "3.9"]'
 
   substitute-runner:
     needs: generate-matrix
diff --git a/py/torch_tensorrt/_features.py b/py/torch_tensorrt/_features.py
index 8da7ac6fff..29ab495fec 100644
--- a/py/torch_tensorrt/_features.py
+++ b/py/torch_tensorrt/_features.py
@@ -14,6 +14,7 @@
         "torch_tensorrt_runtime",
         "dynamo_frontend",
         "fx_frontend",
+        "refit",
     ],
 )
 
@@ -36,9 +37,10 @@
 _TORCHTRT_RT_AVAIL = _TS_FE_AVAIL or os.path.isfile(linked_file_runtime_full_path)
 _DYNAMO_FE_AVAIL = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev")
 _FX_FE_AVAIL = True
+_REFIT_AVAIL = version.parse(sys.version.split()[0]) < version.parse("3.13")
 
 ENABLED_FEATURES = FeatureSet(
-    _TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL
+    _TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL, _REFIT_AVAIL
 )
 
 
@@ -62,6 +64,22 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
     return wrapper
 
 
+def needs_refit(f: Callable[..., Any]) -> Callable[..., Any]:
+    def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
+        if ENABLED_FEATURES.refit:
+            return f(*args, **kwargs)
+        else:
+
+            def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
+                raise NotImplementedError(
+                    "Refit feature is currently not available in Python 3.13 or higher"
+                )
+
+            return not_implemented(*args, **kwargs)
+
+    return wrapper
+
+
 T = TypeVar("T")
diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index c128e9cc82..a1eda40b2d 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -11,6 +11,7 @@
 from torch.export import ExportedProgram
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt._enums import dtype
+from torch_tensorrt._features import needs_refit
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import partitioning
 from torch_tensorrt.dynamo._exporter import inline_torch_modules
@@ -47,6 +48,7 @@
 logger = logging.getLogger(__name__)
 
 
+@needs_refit
 def construct_refit_mapping(
     module: torch.fx.GraphModule,
     inputs: Sequence[Input],
@@ -108,8 +110,11 @@ def construct_refit_mapping(
     return weight_map
 
 
+@needs_refit
 def construct_refit_mapping_from_weight_name_map(
-    weight_name_map: dict[Any, Any], state_dict: dict[Any, Any]
+    weight_name_map: dict[Any, Any],
+    state_dict: dict[Any, Any],
+    settings: CompilationSettings,
 ) -> dict[Any, Any]:
     engine_weight_map = {}
     for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
@@ -120,7 +125,9 @@
             # If weights is not in sd, we can leave it unchanged
             continue
         else:
-            engine_weight_map[engine_weight_name] = state_dict[sd_weight_name]
+            engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to(
+                to_torch_device(settings.device)
+            )
 
         engine_weight_map[engine_weight_name] = (
             engine_weight_map[engine_weight_name]
@@ -134,6 +141,7 @@
     return engine_weight_map
 
 
+@needs_refit
 def _refit_single_trt_engine_with_gm(
     new_gm: torch.fx.GraphModule,
     old_engine: trt.ICudaEngine,
@@ -163,7 +171,7 @@ def _refit_single_trt_engine_with_gm(
             "constant_mapping", {}
         )  # type: ignore
         mapping = construct_refit_mapping_from_weight_name_map(
-            weight_name_map, new_gm.state_dict()
+            weight_name_map, new_gm.state_dict(), settings
         )
 
         constant_mapping_with_type = {}
@@ -213,6 +221,7 @@ def _refit_single_trt_engine_with_gm(
         raise AssertionError("Refitting failed.")
 
 
+@needs_refit
 def refit_module_weights(
     compiled_module: torch.fx.GraphModule | ExportedProgram,
     new_weight_module: ExportedProgram,
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 17f2fccbff..fde07bf1f5 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -26,6 +26,7 @@
 from torch.fx.passes.shape_prop import TensorMetadata
 from torch.utils._python_dispatch import _disable_current_modes
 from torch_tensorrt._enums import dtype
+from torch_tensorrt._features import needs_refit
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import _defaults
 from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
@@ -44,7 +45,7 @@
     get_trt_tensor,
     to_torch,
 )
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, get_model_device, to_torch_device
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
 
@@ -434,6 +435,7 @@ def check_weight_equal(
         except Exception:
             return torch.all(sd_weight == network_weight)
 
+    @needs_refit
     def _save_weight_mapping(self) -> None:
         """
         Construct the weight name mapping from engine weight name to state_dict weight name.
@@ -491,15 +493,10 @@ def _save_weight_mapping(self) -> None:
         _LOGGER.info("Building weight name mapping...")
         # Stage 1: Name mapping
         torch_device = to_torch_device(self.compilation_settings.device)
-        gm_is_on_cuda = get_model_device(self.module).type == "cuda"
-        if not gm_is_on_cuda:
-            # If the model original position is on CPU, move it GPU
-            sd = {
-                k: v.reshape(-1).to(torch_device)
-                for k, v in self.module.state_dict().items()
-            }
-        else:
-            sd = {k: v.reshape(-1) for k, v in self.module.state_dict().items()}
+        sd = {
+            k: v.reshape(-1).to(torch_device)
+            for k, v in self.module.state_dict().items()
+        }
         weight_name_map: dict[str, Any] = {}
         np_map = {}
         constant_mapping = {}
@@ -583,6 +580,7 @@ def _save_weight_mapping(self) -> None:
         gc.collect()
         torch.cuda.empty_cache()
 
+    @needs_refit
     def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
         # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
         # if not self.compilation_settings.strip_engine_weights:
@@ -610,6 +608,7 @@ def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> No
             ),
         )
 
+    @needs_refit
     def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
         # query the cached TRT engine
         cached_data = self.engine_cache.check(hash_val)  # type: ignore[union-attr]
@@ -720,7 +719,7 @@ def run(
             if self.compilation_settings.reuse_cached_engines:
                 interpreter_result = self._pull_cached_engine(hash_val)
                 if interpreter_result is not None:  # hit the cache
-                    return interpreter_result
+                    return interpreter_result  # type: ignore[no-any-return]
 
         self._construct_trt_network_def()
diff --git a/pyproject.toml b/pyproject.toml
index d7f9d16ea8..8fc6258f6f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [build-system]
 requires = [
-    "setuptools>=68.0.0",
+    "setuptools>=77.0.0",
"packaging>=23.1", "wheel>=0.40.0", "ninja>=1.11.0", diff --git a/setup.py b/setup.py index 09933307c8..9f74cdb9d0 100644 --- a/setup.py +++ b/setup.py @@ -18,12 +18,12 @@ import torch import yaml from setuptools import Extension, find_namespace_packages, setup +from setuptools.command.bdist_wheel import bdist_wheel from setuptools.command.build_ext import build_ext from setuptools.command.develop import develop from setuptools.command.editable_wheel import editable_wheel from setuptools.command.install import install from torch.utils.cpp_extension import IS_WINDOWS, BuildExtension, CUDAExtension -from wheel.bdist_wheel import bdist_wheel __version__: str = "0.0.0" __cuda_version__: str = "0.0" diff --git a/tests/modules/custom_models.py b/tests/modules/custom_models.py index b62faffd1b..4906a1d495 100644 --- a/tests/modules/custom_models.py +++ b/tests/modules/custom_models.py @@ -3,7 +3,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from transformers import BertConfig, BertModel, BertTokenizer # Sample Pool Model (for testing plugin serialization) @@ -165,6 +164,8 @@ def forward(self, z: List[torch.Tensor]): def BertModule(): + from transformers import BertConfig, BertModel, BertTokenizer + enc = BertTokenizer.from_pretrained("google-bert/bert-base-uncased") text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]" tokenized_text = enc.tokenize(text) diff --git a/tests/modules/hub.py b/tests/modules/hub.py index 0cce523fb3..d87635b435 100644 --- a/tests/modules/hub.py +++ b/tests/modules/hub.py @@ -4,10 +4,7 @@ import custom_models as cm import timm import torch -import torch.nn as nn -import torch.nn.functional as F import torchvision.models as models -from transformers import BertConfig, BertModel, BertTokenizer torch.hub._validate_not_a_forked_repo = lambda a, b, c: True diff --git a/tests/py/dynamo/models/test_engine_cache.py b/tests/py/dynamo/models/test_engine_cache.py index 36bf5edc95..0bc7c665b3 100644 --- a/tests/py/dynamo/models/test_engine_cache.py +++ b/tests/py/dynamo/models/test_engine_cache.py @@ -250,6 +250,10 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): msg=f"Engine caching didn't speed up the compilation. 
Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms", ) + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_dynamo_compile_with_custom_engine_cache(self): model = models.resnet18(pretrained=True).eval().to("cuda") @@ -314,6 +318,10 @@ def test_dynamo_compile_with_custom_engine_cache(self): for h, count in custom_engine_cache.hashes.items() ] + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_dynamo_compile_change_input_shape(self): """Runs compilation 3 times, the cache should miss each time""" model = models.resnet18(pretrained=True).eval().to("cuda") @@ -346,6 +354,10 @@ def test_dynamo_compile_change_input_shape(self): for h, count in custom_engine_cache.hashes.items() ] + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) @pytest.mark.xfail def test_torch_compile_with_default_disk_engine_cache(self): # Custom Engine Cache @@ -485,6 +497,10 @@ def test_torch_compile_with_custom_engine_cache(self): for h, count in custom_engine_cache.hashes.items() ] + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_torch_trt_compile_change_input_shape(self): # Custom Engine Cache model = models.resnet18(pretrained=True).eval().to("cuda") @@ -611,6 +627,10 @@ def forward(self, c, d): assertions.assertEqual(hash1, hash2) # @unittest.skip("benchmark on small models") + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_caching_small_model(self): from torch_tensorrt.dynamo._refit import refit_module_weights diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index a0b3292c29..d71091b04e 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -1,3 +1,4 @@ +import importlib import os import tempfile import unittest @@ -21,7 +22,6 @@ pre_export_lowering, ) from torch_tensorrt.logging import TRT_LOGGER -from transformers import BertModel assertions = unittest.TestCase() @@ -30,6 +30,10 @@ not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_mapping(): model = models.resnet18(pretrained=False).eval().to("cuda") @@ -85,6 +89,10 @@ def test_mapping(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_one_engine_with_weightmap(): model = models.resnet18(pretrained=False).eval().to("cuda") @@ -134,6 +142,10 @@ def test_refit_one_engine_with_weightmap(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_one_engine_no_map_with_weightmap(): model = 
models.resnet18(pretrained=False).eval().to("cuda") @@ -184,6 +196,10 @@ def test_refit_one_engine_no_map_with_weightmap(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_one_engine_with_wrong_weightmap(): model = models.resnet18(pretrained=False).eval().to("cuda") @@ -238,8 +254,18 @@ def test_refit_one_engine_with_wrong_weightmap(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not importlib.util.find_spec("transformers"), + "transformers is required to run this test", +) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_one_engine_bert_with_weightmap(): + from transformers import BertModel + inputs = [ torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"), ] @@ -293,6 +319,10 @@ def test_refit_one_engine_bert_with_weightmap(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_one_engine_inline_runtime__with_weightmap(): trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") @@ -339,6 +369,10 @@ def test_refit_one_engine_inline_runtime__with_weightmap(): torch._dynamo.reset() +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_one_engine_python_runtime_with_weightmap(): model = models.resnet18(pretrained=False).eval().to("cuda") @@ -387,6 +421,10 @@ def test_refit_one_engine_python_runtime_with_weightmap(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_multiple_engine_with_weightmap(): class net(nn.Module): @@ -458,6 +496,10 @@ def forward(self, x): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_one_engine_without_weightmap(): model = models.resnet18(pretrained=True).eval().to("cuda") @@ -506,8 +548,18 @@ def test_refit_one_engine_without_weightmap(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not importlib.util.find_spec("transformers"), + "transformers is required to run this test", +) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_one_engine_bert_without_weightmap(): + from transformers import BertModel + inputs = [ torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda"), ] @@ -561,6 +613,10 @@ def test_refit_one_engine_bert_without_weightmap(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def 
test_refit_one_engine_inline_runtime_without_weightmap(): trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep") @@ -607,6 +663,10 @@ def test_refit_one_engine_inline_runtime_without_weightmap(): torch._dynamo.reset() +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_one_engine_python_runtime_without_weightmap(): model = models.resnet18(pretrained=True).eval().to("cuda") @@ -655,6 +715,10 @@ def test_refit_one_engine_python_runtime_without_weightmap(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_multiple_engine_without_weightmap(): class net(nn.Module): @@ -722,6 +786,10 @@ def forward(self, x): torch._dynamo.reset() +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_refit_cumsum_fallback(): class net(nn.Module): diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index 6314baa5ec..aa48836590 100644 --- a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -1,5 +1,5 @@ # type: ignore - +import importlib import unittest import pytest @@ -8,7 +8,6 @@ import torch_tensorrt as torchtrt import torchvision.models as models from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity -from transformers import BertModel assertions = unittest.TestCase() @@ -109,10 +108,16 @@ def test_efficientnet_b0(ir): @pytest.mark.unit +@unittest.skipIf( + not importlib.util.find_spec("transformers"), + "transformers is required to run this test", +) def test_bert_base_uncased(ir): + from transformers import BertModel + model = BertModel.from_pretrained("bert-base-uncased").cuda().eval() - input = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda") - input2 = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda") + input = torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda") + input2 = torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda") compile_spec = { "inputs": [ diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py index 6f96e259b0..19fdeaa9ab 100644 --- a/tests/py/dynamo/models/test_models_export.py +++ b/tests/py/dynamo/models/test_models_export.py @@ -10,7 +10,6 @@ import torch_tensorrt as torchtrt import torchvision.models as models from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity -from transformers import BertModel from packaging.version import Version @@ -114,12 +113,18 @@ def test_efficientnet_b0(ir): @pytest.mark.unit +@unittest.skipIf( + not importlib.util.find_spec("transformers"), + "transformers is required to run this test", +) def test_bert_base_uncased(ir): + from transformers import BertModel + model = ( BertModel.from_pretrained("bert-base-uncased", return_dict=False).cuda().eval() ) - input = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda") - input2 = torch.randint(0, 1, (1, 14), dtype=torch.int32).to("cuda") + input = torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda") + input2 = torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda") compile_spec = { "inputs": [ diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py 
b/tests/py/dynamo/models/test_weight_stripped_engine.py index 0c79ba7a3f..33bf94e711 100644 --- a/tests/py/dynamo/models/test_weight_stripped_engine.py +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -16,6 +16,10 @@ class TestWeightStrippedEngine(TestCase): + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_three_ways_to_compile(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) @@ -57,6 +61,10 @@ def test_three_ways_to_compile(self): gm1_output, gm2_output, 1e-2, 1e-2 ), "gm2_output is not correct" + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_three_ways_to_compile_weight_stripped_engine(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) @@ -89,6 +97,10 @@ def test_three_ways_to_compile_weight_stripped_engine(self): gm1_output.sum(), 0, msg="gm1_output should be all zeros" ) + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_weight_stripped_engine_sizes(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) @@ -126,6 +138,10 @@ def test_weight_stripped_engine_sizes(self): msg=f"Weight-stripped refit-identical engine size is not smaller than the weight included engine size. Weight included engine size: {len(bytes(weight_included_engine))}, weight-stripped refit-identical engine size: {len(bytes(weight_stripped_refit_identical_engine))}", ) + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_weight_stripped_engine_results(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) @@ -187,6 +203,10 @@ def test_weight_stripped_engine_results(self): @unittest.skip( "For now, torch-trt will save weighted engine if strip_engine_weights is False. In the near future, we plan to save weight-stripped engine regardless of strip_engine_weights, which is pending on TRT's feature development: NVBug #4914602" ) + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_engine_caching_saves_weight_stripped_engine(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) @@ -233,6 +253,10 @@ def test_engine_caching_saves_weight_stripped_engine(self): msg=f"cached engine size is not smaller than the weight included engine size. 
Weight included engine size: {len(bytes(weight_included_engine))}, cached stripped engine size: {len(bytes(cached_stripped_engine))}", ) + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_dynamo_compile_with_refittable_weight_stripped_engine(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) @@ -397,6 +421,10 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms", ) + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_different_args_dont_share_cached_engine(self): class MyModel(torch.nn.Module): def __init__(self): @@ -446,6 +474,10 @@ def forward(self, x): msg=f"It has {len(os.listdir(engine_cache_dir))} cached engine(s) but should have 2 engines", ) + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_constant_mul_in_refitting(self): class MyModel(torch.nn.Module): def __init__(self): @@ -483,6 +515,10 @@ def forward(self, x): msg=f"TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_two_TRTRuntime_in_refitting(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) @@ -523,6 +559,10 @@ def test_two_TRTRuntime_in_refitting(self): ) @unittest.skip("Waiting for implementation") + @unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Engine caching requires refit feature that is not supported in Python 3.13 or higher", + ) def test_refit_identical_engine_weights(self): pyt_model = models.resnet18(pretrained=True).eval().to("cuda") example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) diff --git a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py index c07e04b6a4..f1af1098b1 100644 --- a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py +++ b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py @@ -74,6 +74,10 @@ def test_check_input_shape_dynamic(): ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_model_complex_dynamic_shape(): device = "cuda:0" @@ -194,6 +198,10 @@ def forward(self, a, b, c=None): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_resnet18(): torch.manual_seed(0) @@ -230,6 +238,10 @@ def test_resnet18(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_save(): torch.manual_seed(0) @@ -266,6 +278,10 @@ def test_save(): not 
torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_resnet18_modify_attribute(): torch.manual_seed(0) @@ -306,6 +322,10 @@ def test_resnet18_modify_attribute(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_resnet18_modify_attribute_no_refit(): torch.manual_seed(0) @@ -353,6 +373,10 @@ def test_resnet18_modify_attribute_no_refit(): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_custom_model_with_kwarg(): class net(nn.Module): @@ -420,6 +444,10 @@ def forward(self, x, b=5, c=None, d=None): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_custom_model_with_inplace_init(): class net(nn.Module): @@ -483,6 +511,10 @@ def set_weights(self): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_custom_model_with_init_recompile(): class net(nn.Module): @@ -546,6 +578,10 @@ def set_layer(self): not torch_trt.ENABLED_FEATURES.torch_tensorrt_runtime, "TorchScript Frontend is not available", ) +@unittest.skipIf( + not torch_trt.ENABLED_FEATURES.refit, + "Refit feature is not supported in Python 3.13 or higher", +) @pytest.mark.unit def test_custom_model_with_kwarg_different_input(): class net(nn.Module): diff --git a/tests/py/requirements.txt b/tests/py/requirements.txt index 6fb6128089..4f3c4e083b 100644 --- a/tests/py/requirements.txt +++ b/tests/py/requirements.txt @@ -8,6 +8,6 @@ pytest>=8.2.1 pytest-xdist>=3.6.1 pyyaml timm>=1.0.3 -transformers==4.40.2 -nvidia-modelopt[deploy,hf,torch]~=0.17.0 +transformers==4.49.0 +nvidia-modelopt[deploy,hf,torch]~=0.17.0; python_version < "3.13" --extra-index-url https://pypi.nvidia.com
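The core mechanism in this patch is the `refit` feature flag in `py/torch_tensorrt/_features.py`: `_REFIT_AVAIL` is derived from the running interpreter version, exposed as `ENABLED_FEATURES.refit`, and the `needs_refit` decorator turns every gated API into a `NotImplementedError` on Python 3.13+, while the tests skip themselves via `unittest.skipIf(not torch_trt.ENABLED_FEATURES.refit, ...)`. A minimal, self-contained sketch of the same gating pattern (the flag name `REFIT_AVAILABLE` and the stand-in function below are illustrative, not the library's API):

```python
import sys
from typing import Any, Callable

# Mirrors _REFIT_AVAIL: assume refit support is absent on Python >= 3.13.
REFIT_AVAILABLE = sys.version_info < (3, 13)


def needs_refit(f: Callable[..., Any]) -> Callable[..., Any]:
    """Gate an API behind the refit feature flag, failing loudly when absent."""

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if REFIT_AVAILABLE:
            return f(*args, **kwargs)
        raise NotImplementedError(
            "Refit feature is currently not available in Python 3.13 or higher"
        )

    return wrapper


@needs_refit
def refit_module_weights() -> None:  # hypothetical stand-in for a gated API
    print("refitting...")


if __name__ == "__main__":
    refit_module_weights()  # raises NotImplementedError on Python 3.13+
```

Because the decorator defers the flag check to call time rather than import time, modules that define gated functions (e.g. `_refit.py`) still import cleanly on Python 3.13; only actually invoking a gated API fails.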