From ba18185f6da5daa5704dc5583e5b1ee01c12b80b Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Thu, 10 Aug 2023 13:45:32 -0700 Subject: [PATCH] feat: Add support for device compilation setting - Add updated Device utilities and automatic context-aware device detection for torch compile - Add testing for new utilities --- py/torch_tensorrt/_Device.py | 8 +-- py/torch_tensorrt/_compile.py | 4 +- py/torch_tensorrt/dynamo/_defaults.py | 6 +++ py/torch_tensorrt/dynamo/_settings.py | 3 ++ py/torch_tensorrt/dynamo/compile.py | 16 ++++-- .../dynamo/conversion/conversion.py | 4 +- py/torch_tensorrt/dynamo/utils.py | 52 +++++++++++++++---- .../py/dynamo/backend/test_compiler_utils.py | 51 +++++++++++++++--- 8 files changed, 115 insertions(+), 29 deletions(-) diff --git a/py/torch_tensorrt/_Device.py b/py/torch_tensorrt/_Device.py index 3ac276db13..0f8ce1e392 100644 --- a/py/torch_tensorrt/_Device.py +++ b/py/torch_tensorrt/_Device.py @@ -8,11 +8,10 @@ import warnings -import torch -from torch_tensorrt import logging - # from torch_tensorrt import _enums import tensorrt as trt +import torch +from torch_tensorrt import logging try: from torch_tensorrt import _C @@ -120,6 +119,9 @@ def __str__(self) -> str: ) ) + def __repr__(self) -> str: + return self.__str__() + def _to_internal(self) -> _C.Device: internal_dev = _C.Device() if self.device_type == trt.DeviceType.GPU: diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 598568816b..af40ce8dad 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -209,12 +209,12 @@ def compile( import collections.abc from torch_tensorrt import Device - from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs + from torch_tensorrt.dynamo.utils import prepare_inputs, to_torch_device if not isinstance(inputs, collections.abc.Sequence): inputs = [inputs] device = kwargs.get("device", Device._current_device()) - torchtrt_inputs, 
torch_inputs = prepare_inputs(inputs, prepare_device(device)) + torchtrt_inputs, torch_inputs = prepare_inputs(inputs, to_torch_device(device)) module = torch_tensorrt.dynamo.trace(module, torch_inputs, **kwargs) compiled_aten_module: torch.fx.GraphModule = dynamo_compile( module, diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index ec67a7a358..199674e6e7 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -1,7 +1,9 @@ import torch +from torch_tensorrt._Device import Device PRECISION = torch.float32 DEBUG = False +DEVICE = None WORKSPACE_SIZE = 0 MIN_BLOCK_SIZE = 5 PASS_THROUGH_BUILD_FAILURES = False @@ -12,3 +14,7 @@ USE_PYTHON_RUNTIME = False USE_FAST_PARTITIONER = True ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False + + +def default_device() -> Device: + return Device(gpu_id=torch.cuda.current_device()) diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 6f17ad768b..0bd644d006 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -2,6 +2,7 @@ from typing import Optional, Set import torch +from torch_tensorrt._Device import Device from torch_tensorrt.dynamo._defaults import ( DEBUG, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, @@ -15,6 +16,7 @@ USE_PYTHON_RUNTIME, VERSION_COMPATIBLE, WORKSPACE_SIZE, + default_device, ) @@ -54,3 +56,4 @@ class CompilationSettings: truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE use_fast_partitioner: bool = USE_FAST_PARTITIONER enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS + device: Device = field(default_factory=default_device) diff --git a/py/torch_tensorrt/dynamo/compile.py b/py/torch_tensorrt/dynamo/compile.py index c274a27b5d..5ae45c8b0b 100644 --- a/py/torch_tensorrt/dynamo/compile.py +++ b/py/torch_tensorrt/dynamo/compile.py @@ -2,7 +2,7 @@ import collections.abc import logging -from typing import Any, List, Optional, 
Sequence, Set, Tuple +from typing import Any, List, Optional, Sequence, Set, Tuple, Union import torch import torch_tensorrt @@ -13,6 +13,7 @@ from torch_tensorrt.dynamo import CompilationSettings, partitioning from torch_tensorrt.dynamo._defaults import ( DEBUG, + DEVICE, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, MAX_AUX_STREAMS, MIN_BLOCK_SIZE, @@ -29,7 +30,11 @@ convert_module, repair_long_or_double_inputs, ) -from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs +from torch_tensorrt.dynamo.utils import ( + prepare_inputs, + to_torch_device, + to_torch_tensorrt_device, +) logger = logging.getLogger(__name__) @@ -38,7 +43,7 @@ def compile( gm: Any, inputs: Any, *, - device: Device = Device._current_device(), + device: Optional[Union[Device, torch.device, str]] = DEVICE, disable_tf32: bool = False, sparse_weights: bool = False, enabled_precisions: Set[torch.dtype] | Tuple[torch.dtype] = (torch.float32,), @@ -82,7 +87,9 @@ def compile( if not isinstance(inputs, collections.abc.Sequence): inputs = [inputs] - _, torch_inputs = prepare_inputs(inputs, prepare_device(device)) + device = to_torch_tensorrt_device(device) + + _, torch_inputs = prepare_inputs(inputs, to_torch_device(device)) if ( torch.float16 in enabled_precisions @@ -105,6 +112,7 @@ def compile( compilation_options = { "precision": precision, "debug": debug, + "device": device, "workspace_size": workspace_size, "min_block_size": min_block_size, "torch_executed_ops": torch_executed_ops diff --git a/py/torch_tensorrt/dynamo/conversion/conversion.py b/py/torch_tensorrt/dynamo/conversion/conversion.py index 5c9bbd8c70..787a6d6c25 100644 --- a/py/torch_tensorrt/dynamo/conversion/conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/conversion.py @@ -3,14 +3,13 @@ import io from typing import Sequence +import tensorrt as trt import torch from torch_tensorrt._Input import Input from torch_tensorrt.dynamo import CompilationSettings from torch_tensorrt.dynamo.conversion import TRTInterpreter from 
torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule -import tensorrt as trt - def convert_module( module: torch.fx.GraphModule, @@ -72,4 +71,5 @@ def convert_module( name=name, input_binding_names=list(interpreter_result.input_names), output_binding_names=list(interpreter_result.output_names), + target_device=settings.device, ) diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 019d6b904c..980616f35f 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -2,7 +2,7 @@ import logging from dataclasses import fields, replace -from typing import Any, Callable, Dict, Optional, Sequence +from typing import Any, Callable, Dict, Optional, Sequence, Union import torch import torch_tensorrt @@ -116,23 +116,45 @@ def prepare_inputs( ) -def prepare_device(device: Device | torch.device) -> torch.device: - _device: torch.device +def to_torch_device(device: Optional[Union[Device, torch.device, str]]) -> torch.device: + """Cast a device-type to torch.device + + Returns the corresponding torch.device + """ if isinstance(device, Device): if device.gpu_id != -1: - _device = torch.device(device.gpu_id) + return torch.device(device.gpu_id) else: raise ValueError("Invalid GPU ID provided for the CUDA device provided") elif isinstance(device, torch.device): - _device = device + return device + + elif device is None: + return torch.device(torch.cuda.current_device()) else: - raise ValueError( - "Invalid device provided. 
Supported options: torch.device | torch_tensorrt.Device" - ) + return torch.device(device) - return _device + +def to_torch_tensorrt_device( + device: Optional[Union[Device, torch.device, str]] +) -> Device: + """Cast a device-type to torch_tensorrt.Device + + Returns the corresponding torch_tensorrt.Device + """ + if isinstance(device, Device): + return device + + elif isinstance(device, torch.device): + return Device(gpu_id=device.index) + + elif device is None: + return Device(gpu_id=torch.cuda.current_device()) + + else: + return Device(device) def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings: @@ -184,7 +206,17 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings: # Parse input runtime specification settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime) - logger.info("Compilation Settings: %s\n", settings) + # Ensure device is a torch_tensorrt Device + settings.device = to_torch_tensorrt_device(settings.device) + + # Check and update device settings + if "device" not in kwargs: + logger.info( + f"Device not specified, using Torch default current device - cuda:{settings.device.gpu_id}. " + "If this is incorrect, please specify an input device, via the device keyword." 
+ ) + + logger.info(f"Compiling with Settings:\n{settings}") return settings diff --git a/tests/py/dynamo/backend/test_compiler_utils.py b/tests/py/dynamo/backend/test_compiler_utils.py index 3ef81b4e1a..2a2cef1b08 100644 --- a/tests/py/dynamo/backend/test_compiler_utils.py +++ b/tests/py/dynamo/backend/test_compiler_utils.py @@ -1,26 +1,61 @@ -from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs -from utils import same_output_format -import torch_tensorrt import unittest + import torch +import torch_tensorrt +from torch_tensorrt.dynamo.utils import ( + prepare_inputs, + to_torch_device, + to_torch_tensorrt_device, +) +from utils import same_output_format -class TestPrepareDevice(unittest.TestCase): - def test_prepare_cuda_device(self): +class TestToTorchDevice(unittest.TestCase): + def test_cast_cuda_device(self): gpu_id = 0 device = torch.device(f"cuda:{gpu_id}") - prepared_device = prepare_device(device) + prepared_device = to_torch_device(device) self.assertTrue(isinstance(prepared_device, torch.device)) self.assertTrue(prepared_device.index == gpu_id) - def test_prepare_trt_device(self): + def test_cast_trt_device(self): gpu_id = 4 device = torch_tensorrt.Device(gpu_id=gpu_id) - prepared_device = prepare_device(device) + prepared_device = to_torch_device(device) + self.assertTrue(isinstance(prepared_device, torch.device)) + self.assertTrue(prepared_device.index == gpu_id) + + def test_cast_str_device(self): + gpu_id = 2 + device = f"cuda:{gpu_id}" + prepared_device = to_torch_device(device) self.assertTrue(isinstance(prepared_device, torch.device)) self.assertTrue(prepared_device.index == gpu_id) +class TestToTorchTRTDevice(unittest.TestCase): + def test_cast_cuda_device(self): + gpu_id = 0 + device = torch.device(f"cuda:{gpu_id}") + prepared_device = to_torch_tensorrt_device(device) + self.assertTrue(isinstance(prepared_device, torch_tensorrt.Device)) + self.assertTrue(prepared_device.gpu_id == gpu_id) + + def test_cast_trt_device(self): + gpu_id
= 4 + device = torch_tensorrt.Device(gpu_id=gpu_id) + prepared_device = to_torch_tensorrt_device(device) + self.assertTrue(isinstance(prepared_device, torch_tensorrt.Device)) + self.assertTrue(prepared_device.gpu_id == gpu_id) + + def test_cast_str_device(self): + gpu_id = 2 + device = f"cuda:{gpu_id}" + prepared_device = to_torch_tensorrt_device(device) + self.assertTrue(isinstance(prepared_device, torch_tensorrt.Device)) + self.assertTrue(prepared_device.gpu_id == gpu_id) + + class TestPrepareInputs(unittest.TestCase): def test_prepare_single_tensor_input(self): inputs = [torch.ones((4, 4))]