From dfac864e6a1c50bd4d013237c655c6fb7776ed8d Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 10 Sep 2024 13:55:31 -0700 Subject: [PATCH 1/6] chore: make engine caching opt-in feature --- py/torch_tensorrt/dynamo/_defaults.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 83e85cb3c7..edb7bff10b 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -35,8 +35,8 @@ tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin" ) LAZY_ENGINE_INIT = False -CACHE_BUILT_ENGINES = True -REUSE_CACHED_ENGINES = True +CACHE_BUILT_ENGINES = False +REUSE_CACHED_ENGINES = False ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache") ENGINE_CACHE_SIZE = 1073741824 CUSTOM_ENGINE_CACHE = None From 6cd3b839341b375702a49306bff9cb5a344702ef Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 12 Sep 2024 00:55:07 -0700 Subject: [PATCH 2/6] feat: exclude refit sensitive ops from TRT compilation --- py/torch_tensorrt/dynamo/_compiler.py | 12 ++++++++++++ py/torch_tensorrt/dynamo/_defaults.py | 2 +- py/torch_tensorrt/dynamo/_refit.py | 9 +++++++-- py/torch_tensorrt/dynamo/utils.py | 4 +--- tests/py/dynamo/conversion/harness.py | 14 +++++++++++--- 5 files changed, 32 insertions(+), 9 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 2e6ff039b4..cc40e37fcd 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -19,6 +19,7 @@ parse_non_trt_nodes, ) from torch_tensorrt.dynamo._engine_cache import BaseEngineCache, DiskEngineCache +from torch_tensorrt.dynamo._refit import REFIT_SENSITIVE_OPS from torch_tensorrt.dynamo.conversion import ( CompilationSettings, UnsupportedOperatorException, @@ -317,6 +318,10 @@ def compile_module( # Assume converters support dynamic shapes and disable validation CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support) + # Set non-refitable ops as disallowed targets. + if settings.make_refitable: + CONVERTERS.set_disallowed_targets(REFIT_SENSITIVE_OPS) + # Set torch-executed ops CONVERTERS.set_disallowed_targets(settings.torch_executed_ops) @@ -673,6 +678,13 @@ def convert_exported_program_to_serialized_trt_engine( # Assume converters support dynamic shapes and disable validation CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support) + # Set non-refitable ops as disallowed targets. 
+ if settings.make_refitable: + CONVERTERS.set_disallowed_targets(REFIT_SENSITIVE_OPS) + + # Set torch-executed ops + CONVERTERS.set_disallowed_targets(settings.torch_executed_ops) + try: interpreter_result = interpret_module_to_result( gm, diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index edb7bff10b..bc8d56a262 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -26,7 +26,7 @@ USE_PYTHON_RUNTIME = False USE_FAST_PARTITIONER = True ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False -MAKE_REFITABLE = False +MAKE_REFITABLE = True REQUIRE_FULL_COMPILATION = False DRYRUN = False HARDWARE_COMPATIBLE = False diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index c68b0a22aa..fd5a871fab 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -6,6 +6,7 @@ from typing import Any, List, Optional, Sequence, Tuple import numpy as np +import tensorrt as trt import torch from torch.export import ExportedProgram from torch_tensorrt._enums import dtype @@ -42,10 +43,14 @@ ) from torch_tensorrt.logging import TRT_LOGGER -import tensorrt as trt - logger = logging.getLogger(__name__) +# These ops are not refitable. +REFIT_SENSITIVE_OPS = { + torch.ops.aten.cumsum.default, + torch.ops.aten.embedding_bag.default, +} + def construct_refit_mapping( module: torch.fx.GraphModule, diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 75fbf4c935..67bce691f9 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -195,9 +195,7 @@ def get_model_device(module: torch.fx.GraphModule) -> torch.device: if device is None: device = to_torch_device(default_device()) - logger.warning( - "Could not detect the device on which the model exists. 
Assuming the model is on CPU" - ) + return device diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index f53bdf5d59..03cc9caea2 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -23,7 +23,7 @@ pre_export_lowering, ) from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule -from torch_tensorrt.dynamo.utils import ATOL, RTOL, get_torch_inputs +from torch_tensorrt.dynamo.utils import ATOL, RTOL, get_model_device, get_torch_inputs _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -225,7 +225,6 @@ def generate_graph( propagate_shapes: bool = False, ): mod = mod.eval() - torch_inputs = get_torch_inputs(original_inputs, _defaults.DEVICE) if use_dynamo_tracer: exported_program = torch_tensorrt.dynamo.trace(mod, tuple(original_inputs)) exported_program = pre_export_lowering(exported_program) @@ -242,6 +241,8 @@ def generate_graph( if propagate_shapes: # TODO: This is currently being used to test embedding_bag_aten due to https://github.com/pytorch/TensorRT/issues/2843 try: + device = get_model_device(fx_module) + torch_inputs = get_torch_inputs(original_inputs, device) ShapeProp(fx_module).propagate(*torch_inputs) except (RuntimeError, AssertionError): _LOGGER.warning( @@ -262,6 +263,7 @@ def run_test( enable_passes=False, propagate_shapes=False, int32_reqd=False, + make_refitable=False, ): mod = self.generate_graph( mod, @@ -277,6 +279,7 @@ def run_test( enabled_precisions={dtype._from(precision)}, truncate_double=True, debug=True, + make_refitable=make_refitable, ) num_inputs = len(inputs) @@ -345,6 +348,7 @@ def run_test_compare_tensor_attributes_only( output_dtypes=None, use_dynamo_tracer=False, enable_passes=False, + make_refitable=False, ): mod = self.generate_graph( mod, @@ -358,6 +362,7 @@ def run_test_compare_tensor_attributes_only( enabled_precisions={dtype._from(precision)}, truncate_double=True, debug=True, + make_refitable=make_refitable, ) interp = TRTInterpreter( @@ -383,6 +388,7 @@ def run_test_with_dynamic_shape( pyt_inputs=None, propagate_shapes=False, check_dtype=True, + make_refitable=False, ): mod = self.generate_graph( mod, @@ -394,7 +400,9 @@ def run_test_with_dynamic_shape( # Previous instance of the interpreter auto-casted 64-bit inputs # We replicate this behavior here - compilation_settings = CompilationSettings(truncate_double=True) + compilation_settings = CompilationSettings( + truncate_double=True, make_refitable=make_refitable + ) if check_dtype: output_dtypes = infer_module_output_dtypes( From 1b255924bca1417323504767e366e427bd534dbb Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 12 Sep 2024 00:58:17 -0700 Subject: [PATCH 3/6] chore: change refit default value --- py/torch_tensorrt/dynamo/_defaults.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index bc8d56a262..edb7bff10b 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -26,7 +26,7 @@ USE_PYTHON_RUNTIME = False USE_FAST_PARTITIONER = True ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False -MAKE_REFITABLE = True +MAKE_REFITABLE = False REQUIRE_FULL_COMPILATION = False DRYRUN = False HARDWARE_COMPATIBLE = False From cd54e88ec4e0e19bd781288c0ad14657c6b8742e Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Thu, 12 Sep 2024 02:04:19 -0700 Subject: [PATCH 4/6] chore: add testcase --- py/torch_tensorrt/dynamo/_compiler.py | 4 +- py/torch_tensorrt/dynamo/_refit.py | 10 
++-- .../py/dynamo/conversion/test_cumsum_aten.py | 4 ++ .../conversion/test_embedding_bag_aten.py | 4 ++ tests/py/dynamo/models/test_model_refit.py | 53 +++++++++++++++++++ 5 files changed, 69 insertions(+), 6 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index cc40e37fcd..671d7f87a0 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -320,7 +320,7 @@ def compile_module( # Set non-refitable ops as disallowed targets. if settings.make_refitable: - CONVERTERS.set_disallowed_targets(REFIT_SENSITIVE_OPS) + settings.torch_executed_ops.update(REFIT_SENSITIVE_OPS) # Set torch-executed ops CONVERTERS.set_disallowed_targets(settings.torch_executed_ops) @@ -680,7 +680,7 @@ def convert_exported_program_to_serialized_trt_engine( # Set non-refitable ops as disallowed targets. if settings.make_refitable: - CONVERTERS.set_disallowed_targets(REFIT_SENSITIVE_OPS) + settings.torch_executed_ops.update(REFIT_SENSITIVE_OPS) # Set torch-executed ops CONVERTERS.set_disallowed_targets(settings.torch_executed_ops) diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index fd5a871fab..0aff9b39eb 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -46,10 +46,12 @@ logger = logging.getLogger(__name__) # These ops are not refitable. -REFIT_SENSITIVE_OPS = { - torch.ops.aten.cumsum.default, - torch.ops.aten.embedding_bag.default, -} +REFIT_SENSITIVE_OPS = frozenset( + { + "torch.ops.aten.cumsum.default", + "torch.ops.aten.embedding_bag.default", + } +) def construct_refit_mapping( diff --git a/tests/py/dynamo/conversion/test_cumsum_aten.py b/tests/py/dynamo/conversion/test_cumsum_aten.py index 4143401bd4..b08a2ee5a0 100644 --- a/tests/py/dynamo/conversion/test_cumsum_aten.py +++ b/tests/py/dynamo/conversion/test_cumsum_aten.py @@ -24,6 +24,7 @@ def forward(self, x): self.run_test( Cumsum(), inputs, + make_refitable=False, ) @parameterized.expand( @@ -43,6 +44,7 @@ def forward(self, x): self.run_test( Cumsum(), inputs, + make_refitable=False, ) @parameterized.expand( @@ -63,6 +65,7 @@ def forward(self, x): self.run_test( Cumsum(), inputs, + make_refitable=False, ) @parameterized.expand( @@ -92,6 +95,7 @@ def forward(self, x): self.run_test_with_dynamic_shape( Cumsum(), inputs, + make_refitable=False, ) diff --git a/tests/py/dynamo/conversion/test_embedding_bag_aten.py b/tests/py/dynamo/conversion/test_embedding_bag_aten.py index 3fef3d70cf..87d36de9ca 100644 --- a/tests/py/dynamo/conversion/test_embedding_bag_aten.py +++ b/tests/py/dynamo/conversion/test_embedding_bag_aten.py @@ -148,6 +148,7 @@ def forward(self, weight, indices): precision=weight.dtype, enable_passes=True, propagate_shapes=True, + make_refitable=False, ) @parameterized.expand( @@ -345,6 +346,7 @@ def forward(self, weight, indices, offsets): precision=weight.dtype, enable_passes=True, propagate_shapes=True, + make_refitable=False, ) @parameterized.expand( @@ -409,6 +411,7 @@ def forward(self, weight, indices, offsets): precision=weight.dtype, enable_passes=True, propagate_shapes=True, + make_refitable=False, ) @parameterized.expand( @@ -490,6 +493,7 @@ def forward(self, weights, indices, offsets, per_sample_weights=None): min_block_size=1, cache_built_engines=False, reuse_cached_engines=False, + make_refitable=False, ) # use the inputs with different shape to inference: if per_sample_weights is None: diff --git a/tests/py/dynamo/models/test_model_refit.py 
b/tests/py/dynamo/models/test_model_refit.py index 9782cd829c..2d03442cec 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -729,3 +729,56 @@ def forward(self, x): # Clean up model env torch._dynamo.reset() + + +@pytest.mark.unit +def test_refit_cumsum_fallback(): + + class net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 12, 3, padding=1) + self.fc1 = nn.Linear(12 * 16 * 16, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = torch.flatten(x, 1) + x = torch.cumsum(self.fc1(x), 1) + x = x**2 + return x + + model = net().eval().to("cuda") + inputs = [torch.randn((1, 3, 16, 16)).to("cuda")] + model(*inputs) + exp_program = torch.export.export(model, tuple(inputs)) + with torchtrt.logging.debug(): + trt_gm = torchtrt.dynamo.compile( + exp_program, + tuple(inputs), + enabled_precisions={torch.float}, + debug=True, + min_block_size=1, + make_refitable=True, + ) + + num_pyt_segments = len( + [1 for submod in list(trt_gm.named_children()) if "_run_on_gpu" in submod[0]] + ) + + # Number of pyt segments should be 1 (because of cumsum being non-refitable) + assertions.assertTrue( + num_pyt_segments == 1, + f"test_refit_cumsum_fallback test found {num_pyt_segments} pytorch segments but expected 1", + ) + + # Check the output + pyt_outputs, trt_outputs = exp_program.module()(*inputs), trt_gm(*inputs) + for pyt_output, trt_output in zip(pyt_outputs, trt_outputs): + assertions.assertTrue( + torch.allclose(pyt_output, trt_output, 1e-2, 1e-2), + "Refit Result is not correct. Refit failed", + ) + # Clean up model env + + torch._dynamo.reset() From cf4928b921e65d872d5e66316cd9e022ce828074 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Mon, 16 Sep 2024 16:59:33 -0700 Subject: [PATCH 5/6] chore: updates --- py/torch_tensorrt/dynamo/_compiler.py | 22 ++----- py/torch_tensorrt/dynamo/_refit.py | 8 --- .../dynamo/conversion/_ConverterRegistry.py | 40 +++++++---- .../dynamo/conversion/_TRTInterpreter.py | 8 ++- .../dynamo/conversion/aten_ops_converters.py | 66 +++++++++++++------ .../dynamo/conversion/converter_utils.py | 5 +- .../dynamo/conversion/ops_evaluators.py | 11 ++-- .../dynamo/conversion/prims_ops_converters.py | 5 +- 8 files changed, 99 insertions(+), 66 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 671d7f87a0..8425589abf 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -19,7 +19,6 @@ parse_non_trt_nodes, ) from torch_tensorrt.dynamo._engine_cache import BaseEngineCache, DiskEngineCache -from torch_tensorrt.dynamo._refit import REFIT_SENSITIVE_OPS from torch_tensorrt.dynamo.conversion import ( CompilationSettings, UnsupportedOperatorException, @@ -315,15 +314,9 @@ def compile_module( dryrun_tracker = DryRunTracker() if sample_kwarg_inputs is None: sample_kwarg_inputs = {} - # Assume converters support dynamic shapes and disable validation - CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support) - # Set non-refitable ops as disallowed targets. - if settings.make_refitable: - settings.torch_executed_ops.update(REFIT_SENSITIVE_OPS) - - # Set torch-executed ops - CONVERTERS.set_disallowed_targets(settings.torch_executed_ops) + # Configure user compilation settings to converters. 
+ CONVERTERS.set_compilation_settings(settings) # Check the number of supported operations in the graph num_supported_ops, total_ops = partitioning.get_graph_converter_support( @@ -675,15 +668,8 @@ def convert_exported_program_to_serialized_trt_engine( settings = CompilationSettings(**compilation_options) logger.info("Compilation Settings: %s\n", settings) - # Assume converters support dynamic shapes and disable validation - CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support) - - # Set non-refitable ops as disallowed targets. - if settings.make_refitable: - settings.torch_executed_ops.update(REFIT_SENSITIVE_OPS) - - # Set torch-executed ops - CONVERTERS.set_disallowed_targets(settings.torch_executed_ops) + # Configure user compilation settings to converters. + CONVERTERS.set_compilation_settings(settings) try: interpreter_result = interpret_module_to_result( diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 0aff9b39eb..4463af6350 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -45,14 +45,6 @@ logger = logging.getLogger(__name__) -# These ops are not refitable. -REFIT_SENSITIVE_OPS = frozenset( - { - "torch.ops.aten.cumsum.default", - "torch.ops.aten.embedding_bag.default", - } -) - def construct_refit_mapping( module: torch.fx.GraphModule, diff --git a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py index e4ea91c196..4801834e56 100644 --- a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py +++ b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py @@ -23,6 +23,7 @@ from torch import SymBool, SymFloat, SymInt from torch._ops import OpOverloadPacket from torch.fx.node import Argument, Node, Target, _get_qualified_name +from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.fx.converter_registry import CONVERTERS as FX_CONVERTERS @@ -82,7 +83,9 @@ class ConverterSupport: """ converter_implementation: ConverterImplSignature - capability_validator: Callable[[Node], bool] = field(default=lambda node: True) + capability_validator: Callable[[Node, CompilationSettings], bool] = field( + default=lambda node, compilation_settings: True + ) supports_dynamic_shapes: bool = False @@ -112,10 +115,10 @@ def has_dynamic_shapes_in_args( def has_static_shapes_in_args( arg_positions_to_check: Optional[List[int]] = None, -) -> Callable[[torch.fx.Node], bool]: +) -> Callable[[torch.fx.Node, CompilationSettings], bool]: """Returns True if a node has static inputs in node.args at specified positions""" - _has_static_shapes = lambda node, arg_positions_to_check: not _has_dynamic_shapes( - node, arg_positions_to_check + _has_static_shapes = lambda node, compilation_settings, arg_positions_to_check: not _has_dynamic_shapes( + node, compilation_settings, arg_positions_to_check ) return functools.partial( _has_static_shapes, arg_positions_to_check=arg_positions_to_check @@ -123,7 +126,9 @@ def has_static_shapes_in_args( def _has_dynamic_shapes( - node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None + node: torch.fx.Node, + compilation_settings: CompilationSettings = None, + arg_positions_to_check: Optional[List[int]] = None, ) -> bool: # Validate that none of the inputs to the node have Dynamic shapes assert isinstance( @@ -188,7 +193,7 @@ def dynamo_tensorrt_converter( key: Target, *, enabled: bool = 
True, - capability_validator: Optional[Callable[[Node], bool]] = None, + capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None, priority: ConverterPriority = ConverterPriority.STANDARD, supports_dynamic_shapes: bool = False, ) -> Callable[[ConverterImplSignature], ConverterImplSignature]: @@ -297,7 +302,6 @@ def __init__( ], registry_names: Optional[Sequence[str]] = None, registry_calling_conventions: Optional[Sequence[CallingConvention]] = None, - assume_dynamic_shape_support: bool = False, ): # Copy reference to each dictionary object into attribute list self.registries = list(registries) @@ -318,12 +322,16 @@ def __init__( CallingConvention.CTX for _ in range(len(self.registries)) ] + self.compilation_settings: CompilationSettings = None self.disallowed_targets: Collection[Target] = set() - self.assume_dynamic_shape_support = assume_dynamic_shape_support self.validate_invariants() - def set_dynamic_shape_support(self, assume_dynamic_shape_support: bool) -> None: - self.assume_dynamic_shape_support = assume_dynamic_shape_support + def set_compilation_settings( + self, compilation_settings: CompilationSettings + ) -> None: + self.compilation_settings = compilation_settings + # set torch executed ops as disallowed targets + self.set_disallowed_targets(compilation_settings.torch_executed_ops) def set_disallowed_targets(self, torch_executed_ops: Collection[Target]) -> None: self.disallowed_targets = torch_executed_ops @@ -412,7 +420,11 @@ def __getitem__( self.validate_invariants() key = node.target - + assume_dynamic_shape_support = False + if self.compilation_settings: + assume_dynamic_shape_support = ( + self.compilation_settings.assume_dynamic_shape_support + ) if ( key in self.disallowed_targets or self.qualified_name_or_str(key) in self.disallowed_targets @@ -436,8 +448,10 @@ def __getitem__( # 2) Assume dynamic_shape support is True # 3) Node only has static shaped inputs # 4) Node has dynamic inputs and the converter has supports_dynamic_shapes=True - if candidate.capability_validator(node) and ( - self.assume_dynamic_shape_support + if candidate.capability_validator( + node, self.compilation_settings + ) and ( + assume_dynamic_shape_support or not node_has_dynamic_shapes(node) or candidate.supports_dynamic_shapes ): diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 84fe345137..ae76ea8c37 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -18,6 +18,7 @@ ) import numpy as np +import tensorrt as trt import torch import torch.fx from torch.fx.node import _get_qualified_name @@ -43,7 +44,6 @@ from torch_tensorrt.fx.observer import Observer from torch_tensorrt.logging import TRT_LOGGER -import tensorrt as trt from packaging import version _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -89,6 +89,11 @@ def __init__( self.builder.create_network(flag), compilation_settings ) + self.compilation_settings = compilation_settings + if not CONVERTERS.compilation_settings: + # Configure user compilation settings to converters. 
+ CONVERTERS.set_compilation_settings(compilation_settings) + assert TRTInterpreter._all_precisions_supported( compilation_settings.enabled_precisions ), f"Attempted to enable kernel precisions that are not supported (got: {compilation_settings.enabled_precisions}, support: {_defaults.SUPPORTED_KERNEL_PRECISIONS})" @@ -117,7 +122,6 @@ def __init__( self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = ( dict() ) - self.compilation_settings = compilation_settings # Data types for TRT Module output Tensors self.output_dtypes = ( diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index a757cf023e..1735fba9b1 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -7,6 +7,7 @@ import numpy as np import torch from torch.fx.node import Argument, Node, Target +from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext @@ -48,7 +49,7 @@ def get_ir(target: Target) -> SourceIR: return SourceIR.UNKNOWN -def one_user_validator(node: Node) -> bool: +def one_user_validator(node: Node, settings: CompilationSettings = None) -> bool: # Validate only one user, which is a getitem node that accesses the first element in the list return ( len(node.users) == 1 @@ -270,7 +271,11 @@ def aten_ops_embedding( ) -def embedding_bag_validator(node: Node) -> bool: +def embedding_bag_validator(node: Node, settings: CompilationSettings = None) -> bool: + # Embedding bag op is not refitable + if settings.make_refitable: + return False + if not one_user_validator(node): return False meta = node.args[1].meta @@ -416,7 +421,7 @@ def aten_ops_symsize_int( return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], args[1]) -def index_dtype_validator(node: Node) -> bool: +def index_dtype_validator(node: Node, settings: CompilationSettings = None) -> bool: index = node.args[1] for ind in index: if ind is not None: @@ -837,7 +842,7 @@ def aten_ops_select( ) -def index_put_validator(node: Node) -> bool: +def index_put_validator(node: Node, settings: CompilationSettings = None) -> bool: if args_bounds_check(node.args, 3, False): # Check if accumulate is valid _LOGGER.debug("We do not support accumulate=True for aten.index_put operation") accumulate_valid = False @@ -924,7 +929,18 @@ def aten_ops_slice( ) -@dynamo_tensorrt_converter(torch.ops.aten.cumsum.default, supports_dynamic_shapes=True) +def refit_validator(node: Node, settings: CompilationSettings = None) -> bool: + # cumsum op is not refitable + if settings and settings.make_refitable: + return False + return True + + +@dynamo_tensorrt_converter( + torch.ops.aten.cumsum.default, + capability_validator=refit_validator, + supports_dynamic_shapes=True, +) @enforce_tensor_types( { 0: (TRTTensor,), @@ -970,7 +986,7 @@ def aten_ops_tile( ) -def zero_output_validator(node: Node) -> bool: +def zero_output_validator(node: Node, settings: CompilationSettings = None) -> bool: if 0 in node.args[1]: _LOGGER.debug( f"We do not support output tensor {node.args[1]} tensors with zero-sized dimensions for this operation." 
@@ -1027,7 +1043,9 @@ def aten_ops_permute( ) -def to_copy_dtype_validator(placeholder_only: bool) -> Callable[[Node], bool]: +def to_copy_dtype_validator( + placeholder_only: bool, settings: CompilationSettings = None +) -> Callable[[Node, CompilationSettings], bool]: """Return validator for to_copy node with placeholder restrictions""" def validate_dtype(to_copy_node: Node) -> bool: @@ -1059,7 +1077,7 @@ def validate_dtype(to_copy_node: Node) -> bool: ) return False - def validator(to_copy_node: Node) -> bool: + def validator(to_copy_node: Node, settings: CompilationSettings = None) -> bool: """Returns true if the to_copy node can be converted to TRT and the placeholder restriction is satisfied """ @@ -1074,7 +1092,9 @@ def validator(to_copy_node: Node) -> bool: @dynamo_tensorrt_converter( torch.ops.aten.clone.default, - capability_validator=lambda node: not is_only_operator_on_placeholder(node), + capability_validator=lambda node, settings: not is_only_operator_on_placeholder( + node, settings + ), supports_dynamic_shapes=True, ) @dynamo_tensorrt_converter( @@ -2128,7 +2148,7 @@ def aten_ops_logical_xor( ) -def bitwise_type_validator(node: Node) -> bool: +def bitwise_type_validator(node: Node, settings: CompilationSettings = None) -> bool: supported_type = [torch.bool, bool] tensor_targets = [ @@ -2271,7 +2291,9 @@ def aten_ops_bitwise_xor( ) -def bitwise_not_type_validator(node: Node) -> bool: +def bitwise_not_type_validator( + node: Node, settings: CompilationSettings = None +) -> bool: val = node.args[0] val_meta = val.meta.get("tensor_meta") @@ -2453,7 +2475,7 @@ def aten_ops_le( ) -def conv_param_validator(conv_node: Node) -> bool: +def conv_param_validator(conv_node: Node, settings: CompilationSettings = None) -> bool: return conv_node.args[7] in ([0], [0, 0], [0, 0, 0]) @@ -2549,7 +2571,9 @@ def aten_ops_cdist_forward( ) -def avg_pool_param_validator(pool_node: Node) -> bool: +def avg_pool_param_validator( + pool_node: Node, settings: CompilationSettings = None +) -> bool: ceil_mode = args_bounds_check(pool_node.args, 4, False) divisor_override = args_bounds_check(pool_node.args, 6) @@ -2665,12 +2689,12 @@ def aten_ops_adaptive_avg_poolNd( ) -def topk_validator(node: Node) -> bool: +def topk_validator(node: Node, settings: CompilationSettings = None) -> bool: k = node.args[1] return topk_sort_validator(k) -def sort_validator(node: Node) -> bool: +def sort_validator(node: Node, settings: CompilationSettings = None) -> bool: meta_data = node.args[0].meta.get("tensor_meta") if meta_data is None: return False @@ -2692,7 +2716,9 @@ def topk_sort_validator(k: int) -> bool: return True -def max_pool_param_validator(pool_node: Node) -> bool: +def max_pool_param_validator( + pool_node: Node, settings: CompilationSettings = None +) -> bool: dilation = args_bounds_check(pool_node.args, 4, 1) ceil_mode = args_bounds_check(pool_node.args, 5, False) @@ -2746,7 +2772,7 @@ def aten_ops_max_pool( ) -def attention_validator(node: Node) -> bool: +def attention_validator(node: Node, settings: CompilationSettings = None) -> bool: # Currently, `attn_mask` is not supported return args_bounds_check(node.args, 3) is None @@ -3637,7 +3663,7 @@ def aten_ops_flip( ) -def zero_diag_size_validator(node: Node) -> bool: +def zero_diag_size_validator(node: Node, settings: CompilationSettings = None) -> bool: meta = node.args[0].meta.get("tensor_meta") if meta: input_shape = meta.shape @@ -3765,7 +3791,9 @@ def aten_ops_index_select( ) -def dropout_inference_validator(node: Node) -> bool: +def 
dropout_inference_validator( + node: Node, settings: CompilationSettings = None +) -> bool: train_mode = args_bounds_check(node.args, 2, None) if train_mode is False: return True diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 70135f86d3..39ace2a873 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -10,6 +10,7 @@ from torch.fx.node import Argument, Target from torch.fx.passes.shape_prop import TensorMetadata from torch_tensorrt import _enums +from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( @@ -111,7 +112,9 @@ def format_tensor_metadata(metadata: Union[Any, Sequence[Any]]) -> str: return metadata_string -def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool: +def is_only_operator_on_placeholder( + node: torch.fx.Node, settings: CompilationSettings = None +) -> bool: """Detects whether a call_function node is the only operator on a placeholder""" # Returns true if the node operates on a placeholder and is a direct output return ( diff --git a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py index 5eeb2db661..f320505c94 100644 --- a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py +++ b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py @@ -7,6 +7,7 @@ import numpy as np import torch from torch.fx.node import Argument, Node, Target +from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( ConverterRegistry, @@ -18,7 +19,7 @@ _LOGGER: logging.Logger = logging.getLogger(__name__) -def getitem_validator(getitem_node: Node) -> bool: +def getitem_validator(getitem_node: Node, settings: CompilationSettings = None) -> bool: from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS # Getitem nodes can only be converted if their parent node also can @@ -45,7 +46,7 @@ def generic_evaluator( return target(*args) -def rand_validator(rand_node: Node) -> bool: +def rand_validator(rand_node: Node, settings: CompilationSettings = None) -> bool: dtype = rand_node.kwargs.get("dtype", None) layout = rand_node.kwargs.get("layout", None) if dtype is not None: @@ -85,7 +86,9 @@ def aten_ops_randn( return np.random.randn(*args[0]) -def randperm_validator(randperm_node: Node) -> bool: +def randperm_validator( + randperm_node: Node, settings: CompilationSettings = None +) -> bool: dtype = randperm_node.kwargs.get("dtype", None) layout = randperm_node.kwargs.get("layout", None) input = randperm_node.args[0] @@ -116,7 +119,7 @@ def aten_ops_randperm( return np.random.permutation(args[0]) -def empty_validator(empty_node: Node) -> bool: +def empty_validator(empty_node: Node, settings: CompilationSettings = None) -> bool: device = empty_node.kwargs.get("device", None) if device is not None: _LOGGER.debug(f"Currently we don't support specifying device, got {device}.") diff --git a/py/torch_tensorrt/dynamo/conversion/prims_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/prims_ops_converters.py index 9548dc287a..923ca9be6c 100644 --- 
a/py/torch_tensorrt/dynamo/conversion/prims_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/prims_ops_converters.py @@ -3,6 +3,7 @@ import torch from torch.fx.node import Argument, Target +from torch_tensorrt.dynamo._settings import CompilationSettings from torch_tensorrt.dynamo._SourceIR import SourceIR from torch_tensorrt.dynamo.conversion import impl from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext @@ -15,7 +16,9 @@ # TODO: expand the scope of this converter with aten.expand implementation -def broadcast_checker(broadcast_node: torch.fx.Node) -> bool: +def broadcast_checker( + broadcast_node: torch.fx.Node, settings: CompilationSettings = None +) -> bool: # The current implementation of broadcast_in_dim can only handle unsqueeze return all( broadcast_node.args[1][i] == 1 From 2ce8a2bd62497980b6ce76b15afef5a42b3453b8 Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Tue, 17 Sep 2024 09:46:29 -0700 Subject: [PATCH 6/6] chore: updates --- .../dynamo/conversion/aten_ops_converters.py | 4 ++-- tests/py/dynamo/conversion/harness.py | 12 ++++++------ tests/py/dynamo/conversion/test_cumsum_aten.py | 8 ++++---- .../py/dynamo/conversion/test_embedding_bag_aten.py | 8 ++++---- tests/py/dynamo/models/test_model_refit.py | 5 ++--- 5 files changed, 18 insertions(+), 19 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 1735fba9b1..60a48d98e3 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -273,7 +273,7 @@ def aten_ops_embedding( def embedding_bag_validator(node: Node, settings: CompilationSettings = None) -> bool: # Embedding bag op is not refitable - if settings.make_refitable: + if settings.make_refittable: return False if not one_user_validator(node): @@ -931,7 +931,7 @@ def aten_ops_slice( def refit_validator(node: Node, settings: CompilationSettings = None) -> bool: # cumsum op is not refitable - if settings and settings.make_refitable: + if settings and settings.make_refittable: return False return True diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 03cc9caea2..632b73e2f3 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -263,7 +263,7 @@ def run_test( enable_passes=False, propagate_shapes=False, int32_reqd=False, - make_refitable=False, + make_refittable=False, ): mod = self.generate_graph( mod, @@ -279,7 +279,7 @@ def run_test( enabled_precisions={dtype._from(precision)}, truncate_double=True, debug=True, - make_refitable=make_refitable, + make_refittable=make_refittable, ) num_inputs = len(inputs) @@ -348,7 +348,7 @@ def run_test_compare_tensor_attributes_only( output_dtypes=None, use_dynamo_tracer=False, enable_passes=False, - make_refitable=False, + make_refittable=False, ): mod = self.generate_graph( mod, @@ -362,7 +362,7 @@ def run_test_compare_tensor_attributes_only( enabled_precisions={dtype._from(precision)}, truncate_double=True, debug=True, - make_refitable=make_refitable, + make_refittable=make_refittable, ) interp = TRTInterpreter( @@ -388,7 +388,7 @@ def run_test_with_dynamic_shape( pyt_inputs=None, propagate_shapes=False, check_dtype=True, - make_refitable=False, + make_refittable=False, ): mod = self.generate_graph( mod, @@ -401,7 +401,7 @@ def run_test_with_dynamic_shape( # Previous instance of the interpreter auto-casted 64-bit inputs # We 
replicate this behavior here compilation_settings = CompilationSettings( - truncate_double=True, make_refitable=make_refitable + truncate_double=True, make_refittable=make_refittable ) if check_dtype: diff --git a/tests/py/dynamo/conversion/test_cumsum_aten.py b/tests/py/dynamo/conversion/test_cumsum_aten.py index b08a2ee5a0..1c32be6dd6 100644 --- a/tests/py/dynamo/conversion/test_cumsum_aten.py +++ b/tests/py/dynamo/conversion/test_cumsum_aten.py @@ -24,7 +24,7 @@ def forward(self, x): self.run_test( Cumsum(), inputs, - make_refitable=False, + make_refittable=False, ) @parameterized.expand( @@ -44,7 +44,7 @@ def forward(self, x): self.run_test( Cumsum(), inputs, - make_refitable=False, + make_refittable=False, ) @parameterized.expand( @@ -65,7 +65,7 @@ def forward(self, x): self.run_test( Cumsum(), inputs, - make_refitable=False, + make_refittable=False, ) @parameterized.expand( @@ -95,7 +95,7 @@ def forward(self, x): self.run_test_with_dynamic_shape( Cumsum(), inputs, - make_refitable=False, + make_refittable=False, ) diff --git a/tests/py/dynamo/conversion/test_embedding_bag_aten.py b/tests/py/dynamo/conversion/test_embedding_bag_aten.py index 87d36de9ca..6543ac2306 100644 --- a/tests/py/dynamo/conversion/test_embedding_bag_aten.py +++ b/tests/py/dynamo/conversion/test_embedding_bag_aten.py @@ -148,7 +148,7 @@ def forward(self, weight, indices): precision=weight.dtype, enable_passes=True, propagate_shapes=True, - make_refitable=False, + make_refittable=False, ) @parameterized.expand( @@ -346,7 +346,7 @@ def forward(self, weight, indices, offsets): precision=weight.dtype, enable_passes=True, propagate_shapes=True, - make_refitable=False, + make_refittable=False, ) @parameterized.expand( @@ -411,7 +411,7 @@ def forward(self, weight, indices, offsets): precision=weight.dtype, enable_passes=True, propagate_shapes=True, - make_refitable=False, + make_refittable=False, ) @parameterized.expand( @@ -493,7 +493,7 @@ def forward(self, weights, indices, offsets, per_sample_weights=None): min_block_size=1, cache_built_engines=False, reuse_cached_engines=False, - make_refitable=False, + make_refittable=False, ) # use the inputs with different shape to inference: if per_sample_weights is None: diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index b80b46ea96..0f6fb05914 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -5,6 +5,7 @@ import numpy as np import pytest +import tensorrt as trt import torch import torch.nn.functional as F import torch_tensorrt as torchtrt @@ -24,8 +25,6 @@ from torch_tensorrt.logging import TRT_LOGGER from transformers import BertModel -import tensorrt as trt - assertions = unittest.TestCase() @@ -760,7 +759,7 @@ def forward(self, x): enabled_precisions={torch.float}, debug=True, min_block_size=1, - make_refitable=True, + make_refittable=True, ) num_pyt_segments = len(
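
Usage sketch (not part of the patch series above): the diffs make engine caching opt-in and cause refit-sensitive ops such as aten.cumsum to fall back to PyTorch when refit is requested, mirroring what test_refit_cumsum_fallback verifies. The keyword names are taken from the test code in the diffs; note the refit flag spelling changes within the series (make_refitable vs. make_refittable), so the exact name depends on which commits are installed. The model and input shapes below are placeholders.

    # Illustrative only -- assumes the compile keywords shown in the diffs above.
    import torch
    import torch_tensorrt as torchtrt

    class Net(torch.nn.Module):
        def forward(self, x):
            # cumsum is listed as refit-sensitive, so it should run in PyTorch
            # when refit support is requested at compile time.
            return torch.cumsum(x, 1)

    model = Net().eval().cuda()
    inputs = [torch.randn((2, 8), device="cuda")]
    exp_program = torch.export.export(model, tuple(inputs))

    trt_gm = torchtrt.dynamo.compile(
        exp_program,
        tuple(inputs),
        min_block_size=1,
        make_refittable=True,      # requests refit; refit-sensitive ops fall back to PyTorch
        cache_built_engines=True,  # engine caching is now opt-in (defaults changed to False)
        reuse_cached_engines=True,
    )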