Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 431b145

Browse files
committed Aug 15, 2023
feat: Add support for device compilation setting
- Add updated Device utilities and automatic context-aware device detection for torch compile
- Add testing for new utilities
1 parent b57d83e commit 431b145

File tree

8 files changed

+105
-28
lines changed

8 files changed

+105
-28
lines changed
 

‎py/torch_tensorrt/_Device.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@
88

99
import warnings
1010

11-
import torch
12-
from torch_tensorrt import logging
13-
1411
# from torch_tensorrt import _enums
1512
import tensorrt as trt
13+
import torch
14+
from torch_tensorrt import logging
1615

1716
try:
1817
from torch_tensorrt import _C
@@ -120,6 +119,9 @@ def __str__(self) -> str:
120119
)
121120
)
122121

122+
def __repr__(self) -> str:
123+
return self.__str__()
124+
123125
def _to_internal(self) -> _C.Device:
124126
internal_dev = _C.Device()
125127
if self.device_type == trt.DeviceType.GPU:

‎py/torch_tensorrt/_compile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,12 +192,12 @@ def compile(
192192
import collections.abc
193193

194194
from torch_tensorrt import Device
195-
from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
195+
from torch_tensorrt.dynamo.utils import prepare_inputs, to_torch_device
196196

197197
if not isinstance(inputs, collections.abc.Sequence):
198198
inputs = [inputs]
199199
device = kwargs.get("device", Device._current_device())
200-
torchtrt_inputs, torch_inputs = prepare_inputs(inputs, prepare_device(device))
200+
torchtrt_inputs, torch_inputs = prepare_inputs(inputs, to_torch_device(device))
201201
module = torch_tensorrt.dynamo.trace(module, torch_inputs, **kwargs)
202202
compiled_aten_module: torch.fx.GraphModule = dynamo_compile(
203203
module,

‎py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
from torch_tensorrt._Device import Device
23

34
PRECISION = torch.float32
45
DEBUG = False
@@ -11,3 +12,4 @@
1112
TRUNCATE_LONG_AND_DOUBLE = False
1213
USE_PYTHON_RUNTIME = False
1314
USE_FAST_PARTITIONER = True
15+
DEVICE = Device._current_device()

‎py/torch_tensorrt/dynamo/_settings.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
from typing import Optional, Set
33

44
import torch
5+
from torch_tensorrt._Device import Device
56
from torch_tensorrt.dynamo._defaults import (
67
DEBUG,
8+
DEVICE,
79
MAX_AUX_STREAMS,
810
MIN_BLOCK_SIZE,
911
OPTIMIZATION_LEVEL,
@@ -31,3 +33,4 @@ class CompilationSettings:
3133
use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
3234
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
3335
use_fast_partitioner: bool = USE_FAST_PARTITIONER
36+
device: Device = DEVICE

‎py/torch_tensorrt/dynamo/compile.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import collections.abc
44
import logging
5-
from typing import Any, List, Optional, Set, Tuple
5+
from typing import Any, List, Optional, Set, Tuple, Union
66

77
import torch
88
import torch_tensorrt
@@ -14,6 +14,7 @@
1414
from torch_tensorrt.dynamo import CompilationSettings
1515
from torch_tensorrt.dynamo._defaults import (
1616
DEBUG,
17+
DEVICE,
1718
MAX_AUX_STREAMS,
1819
MIN_BLOCK_SIZE,
1920
OPTIMIZATION_LEVEL,
@@ -30,7 +31,11 @@
3031
fuse_permute_linear,
3132
fuse_permute_matmul,
3233
)
33-
from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
34+
from torch_tensorrt.dynamo.utils import (
35+
prepare_inputs,
36+
to_torch_device,
37+
to_torch_tensorrt_device,
38+
)
3439

3540
logger = logging.getLogger(__name__)
3641

@@ -39,7 +44,7 @@ def compile(
3944
gm: Any,
4045
inputs: Any,
4146
*,
42-
device: Device = Device._current_device(),
47+
device: Union[Device, torch.device, str] = DEVICE,
4348
disable_tf32: bool = False,
4449
sparse_weights: bool = False,
4550
enabled_precisions: Set[torch.dtype] | Tuple[torch.dtype] = (torch.float32,),
@@ -80,7 +85,9 @@ def compile(
8085
if not isinstance(inputs, collections.abc.Sequence):
8186
inputs = [inputs]
8287

83-
_, torch_inputs = prepare_inputs(inputs, prepare_device(device))
88+
device = to_torch_tensorrt_device(device)
89+
90+
_, torch_inputs = prepare_inputs(inputs, to_torch_device(device))
8491

8592
if (
8693
torch.float16 in enabled_precisions
@@ -103,6 +110,7 @@ def compile(
103110
compilation_options = {
104111
"precision": precision,
105112
"debug": debug,
113+
"device": device,
106114
"workspace_size": workspace_size,
107115
"min_block_size": min_block_size,
108116
"torch_executed_ops": torch_executed_ops

‎py/torch_tensorrt/dynamo/conversion/conversion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@
33
import io
44
from typing import Sequence
55

6+
import tensorrt as trt
67
import torch
78
from torch_tensorrt._Input import Input
89
from torch_tensorrt.dynamo import CompilationSettings
910
from torch_tensorrt.dynamo.conversion import TRTInterpreter
1011
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
1112

12-
import tensorrt as trt
13-
1413

1514
def convert_module(
1615
module: torch.fx.GraphModule,
@@ -72,4 +71,5 @@ def convert_module(
7271
name=name,
7372
input_binding_names=list(interpreter_result.input_names),
7473
output_binding_names=list(interpreter_result.output_names),
74+
target_device=settings.device,
7575
)

‎py/torch_tensorrt/dynamo/utils.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import logging
44
from dataclasses import fields, replace
5-
from typing import Any, Callable, Dict, Optional, Sequence
5+
from typing import Any, Callable, Dict, Optional, Sequence, Union
66

77
import torch
88
from torch_tensorrt._Device import Device
@@ -114,23 +114,37 @@ def prepare_inputs(
114114
)
115115

116116

117-
def prepare_device(device: Device | torch.device) -> torch.device:
118-
_device: torch.device
117+
def to_torch_device(device: Union[Device, torch.device, str]) -> torch.device:
118+
"""Cast a device-type to torch.device
119+
120+
Returns the corresponding torch.device
121+
"""
119122
if isinstance(device, Device):
120123
if device.gpu_id != -1:
121-
_device = torch.device(device.gpu_id)
124+
return torch.device(device.gpu_id)
122125
else:
123126
raise ValueError("Invalid GPU ID provided for the CUDA device provided")
124127

125128
elif isinstance(device, torch.device):
126-
_device = device
129+
return device
127130

128131
else:
129-
raise ValueError(
130-
"Invalid device provided. Supported options: torch.device | torch_tensorrt.Device"
131-
)
132+
return torch.device(device)
132133

133-
return _device
134+
135+
def to_torch_tensorrt_device(device: Union[Device, torch.device, str]) -> Device:
136+
"""Cast a device-type to torch_tensorrt.Device
137+
138+
Returns the corresponding torch_tensorrt.Device
139+
"""
140+
if isinstance(device, Device):
141+
return device
142+
143+
elif isinstance(device, torch.device):
144+
return Device(gpu_id=device.index)
145+
146+
else:
147+
return Device(device)
134148

135149

136150
def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
@@ -164,6 +178,19 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
164178
# Parse input runtime specification
165179
settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime)
166180

181+
# Ensure device is a torch_tensorrt Device
182+
settings.device = to_torch_tensorrt_device(settings.device)
183+
184+
# Check and update device settings
185+
default_torch_gpu_idx = torch.cuda.default_stream().device.index
186+
if "device" not in kwargs and default_torch_gpu_idx != settings.device.gpu_id:
187+
logger.warning(
188+
f"No device specified, detected differing gpu IDs for CUDA default: {settings.device.gpu_id} "
189+
f"and Torch default: {default_torch_gpu_idx}. Using Torch default gpu ID: {default_torch_gpu_idx}. "
190+
"If this is incorrect, please specify an input device, via the device keyword."
191+
)
192+
settings.device = Device(gpu_id=default_torch_gpu_idx)
193+
167194
logger.debug(f"Compiling with Settings:\n{settings}")
168195

169196
return settings

‎tests/py/dynamo/backend/test_compiler_utils.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,61 @@
1-
from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
2-
from utils import same_output_format
3-
import torch_tensorrt
41
import unittest
2+
53
import torch
4+
import torch_tensorrt
5+
from torch_tensorrt.dynamo.utils import (
6+
prepare_inputs,
7+
to_torch_device,
8+
to_torch_tensorrt_device,
9+
)
10+
from utils import same_output_format
611

712

8-
class TestPrepareDevice(unittest.TestCase):
9-
def test_prepare_cuda_device(self):
13+
class TestToTorchDevice(unittest.TestCase):
14+
def test_cast_cuda_device(self):
1015
gpu_id = 0
1116
device = torch.device(f"cuda:{gpu_id}")
12-
prepared_device = prepare_device(device)
17+
prepared_device = to_torch_device(device)
1318
self.assertTrue(isinstance(prepared_device, torch.device))
1419
self.assertTrue(prepared_device.index == gpu_id)
1520

16-
def test_prepare_trt_device(self):
21+
def test_cast_trt_device(self):
1722
gpu_id = 4
1823
device = torch_tensorrt.Device(gpu_id=gpu_id)
19-
prepared_device = prepare_device(device)
24+
prepared_device = to_torch_device(device)
25+
self.assertTrue(isinstance(prepared_device, torch.device))
26+
self.assertTrue(prepared_device.index == gpu_id)
27+
28+
def test_cast_str_device(self):
29+
gpu_id = 2
30+
device = f"cuda:{2}"
31+
prepared_device = to_torch_device(device)
2032
self.assertTrue(isinstance(prepared_device, torch.device))
2133
self.assertTrue(prepared_device.index == gpu_id)
2234

2335

36+
class TestToTorchTRTDevice(unittest.TestCase):
37+
def test_cast_cuda_device(self):
38+
gpu_id = 0
39+
device = torch.device(f"cuda:{gpu_id}")
40+
prepared_device = to_torch_tensorrt_device(device)
41+
self.assertTrue(isinstance(prepared_device, torch_tensorrt.Device))
42+
self.assertTrue(prepared_device.gpu_id == gpu_id)
43+
44+
def test_cast_trt_device(self):
45+
gpu_id = 4
46+
device = torch_tensorrt.Device(gpu_id=gpu_id)
47+
prepared_device = to_torch_tensorrt_device(device)
48+
self.assertTrue(isinstance(prepared_device, torch_tensorrt.Device))
49+
self.assertTrue(prepared_device.gpu_id == gpu_id)
50+
51+
def test_cast_str_device(self):
52+
gpu_id = 2
53+
device = f"cuda:{2}"
54+
prepared_device = to_torch_tensorrt_device(device)
55+
self.assertTrue(isinstance(prepared_device, torch_tensorrt.Device))
56+
self.assertTrue(prepared_device.gpu_id == gpu_id)
57+
58+
2459
class TestPrepareInputs(unittest.TestCase):
2560
def test_prepare_single_tensor_input(self):
2661
inputs = [torch.ones((4, 4))]

0 commit comments

Comments (0)
Please sign in to comment.