Commit fa81096

feat: Add support for device compilation setting
- Update Device utilities and add automatic context-aware device detection for torch.compile
1 parent c3a65ef commit fa81096

6 files changed (+60, -18 lines)

py/torch_tensorrt/_Device.py

Lines changed: 5 additions & 3 deletions
@@ -8,11 +8,10 @@
 
 import warnings
 
-import torch
-from torch_tensorrt import logging
-
 # from torch_tensorrt import _enums
 import tensorrt as trt
+import torch
+from torch_tensorrt import logging
 
 try:
     from torch_tensorrt import _C
@@ -120,6 +119,9 @@ def __str__(self) -> str:
             )
         )
 
+    def __repr__(self) -> str:
+        return self.__str__()
+
     def _to_internal(self) -> _C.Device:
         internal_dev = _C.Device()
         if self.device_type == trt.DeviceType.GPU:
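
A small side effect of the added __repr__: Device objects now render readably inside containers and dataclasses such as CompilationSettings (illustrative snippet, not part of the commit):

import torch_tensorrt

d = torch_tensorrt.Device(gpu_id=0)
print([d])                # list printing uses __repr__, previously the default object repr
print(repr(d) == str(d))  # True after this change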

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,5 @@
 import torch
+from torch_tensorrt._Device import Device
 
 PRECISION = torch.float32
 DEBUG = False
@@ -10,3 +11,4 @@
 OPTIMIZATION_LEVEL = None
 TRUNCATE_LONG_AND_DOUBLE = False
 USE_PYTHON_RUNTIME = False
+DEVICE = Device._current_device()

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 3 additions & 0 deletions
@@ -2,8 +2,10 @@
 from typing import Optional, Set
 
 import torch
+from torch_tensorrt._Device import Device
 from torch_tensorrt.dynamo._defaults import (
     DEBUG,
+    DEVICE,
     MAX_AUX_STREAMS,
     MIN_BLOCK_SIZE,
     OPTIMIZATION_LEVEL,
@@ -29,3 +31,4 @@ class CompilationSettings:
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL
     use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
     truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
+    device: Device = DEVICE
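
For illustration, the new field can also be set when constructing settings directly; when omitted it falls back to DEVICE, i.e. Device._current_device(). A minimal sketch, not taken from the commit:

from torch_tensorrt._Device import Device
from torch_tensorrt.dynamo import CompilationSettings

# Explicitly target GPU 0; leaving `device` unset uses the DEVICE default.
settings = CompilationSettings(device=Device(gpu_id=0))
print(settings.device)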

py/torch_tensorrt/dynamo/compile.py

Lines changed: 12 additions & 4 deletions
@@ -2,7 +2,7 @@
 
 import collections.abc
 import logging
-from typing import Any, List, Optional, Set, Tuple
+from typing import Any, List, Optional, Set, Tuple, Union
 
 import torch
 import torch_tensorrt
@@ -15,6 +15,7 @@
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo._defaults import (
     DEBUG,
+    DEVICE,
     MAX_AUX_STREAMS,
     MIN_BLOCK_SIZE,
     OPTIMIZATION_LEVEL,
@@ -31,7 +32,11 @@
     fuse_permute_linear,
     fuse_permute_matmul,
 )
-from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
+from torch_tensorrt.dynamo.utils import (
+    prepare_inputs,
+    to_torch_device,
+    to_torch_tensorrt_device,
+)
 from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
 
 logger = logging.getLogger(__name__)
@@ -41,7 +46,7 @@ def compile(
     gm: Any,
     inputs: Any,
     *,
-    device: Device = Device._current_device(),
+    device: Union[Device, torch.device, str] = DEVICE,
     disable_tf32: bool = False,
     sparse_weights: bool = False,
     enabled_precisions: Set[torch.dtype] | Tuple[torch.dtype] = (torch.float32,),
@@ -81,7 +86,9 @@ def compile(
     if not isinstance(inputs, collections.abc.Sequence):
         inputs = [inputs]
 
-    _, torch_inputs = prepare_inputs(inputs, prepare_device(device))
+    device = to_torch_tensorrt_device(device)
+
+    _, torch_inputs = prepare_inputs(inputs, to_torch_device(device))
 
     if (
         torch.float16 in enabled_precisions
@@ -104,6 +111,7 @@ def compile(
     compilation_options = {
         "precision": precision,
         "debug": debug,
+        "device": device,
         "workspace_size": workspace_size,
         "min_block_size": min_block_size,
         "torch_executed_ops": torch_executed_ops

py/torch_tensorrt/dynamo/conversion/conversion.py

Lines changed: 2 additions & 2 deletions
@@ -3,14 +3,13 @@
 import io
 from typing import Sequence
 
+import tensorrt as trt
 import torch
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.conversion import TRTInterpreter
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
 
-import tensorrt as trt
-
 
 def convert_module(
     module: torch.fx.GraphModule,
@@ -72,4 +71,5 @@ def convert_module(
         name=name,
         input_binding_names=list(interpreter_result.input_names),
         output_binding_names=list(interpreter_result.output_names),
+        target_device=settings.device,
     )

py/torch_tensorrt/dynamo/utils.py

Lines changed: 36 additions & 9 deletions
@@ -2,7 +2,7 @@
 
 import logging
 from dataclasses import fields, replace
-from typing import Any, Callable, Dict, Optional, Sequence
+from typing import Any, Callable, Dict, Optional, Sequence, Union
 
 import torch
 from torch_tensorrt._Device import Device
@@ -114,23 +114,37 @@ def prepare_inputs(
         )
 
 
-def prepare_device(device: Device | torch.device) -> torch.device:
-    _device: torch.device
+def to_torch_device(device: Union[Device, torch.device, str]) -> torch.device:
+    """Cast a device-type to torch.device
+
+    Returns the corresponding torch.device
+    """
     if isinstance(device, Device):
         if device.gpu_id != -1:
-            _device = torch.device(device.gpu_id)
+            return torch.device(device.gpu_id)
         else:
             raise ValueError("Invalid GPU ID provided for the CUDA device provided")
 
     elif isinstance(device, torch.device):
-        _device = device
+        return device
 
     else:
-        raise ValueError(
-            "Invalid device provided. Supported options: torch.device | torch_tensorrt.Device"
-        )
+        return torch.device(device)
 
-    return _device
+
+def to_torch_tensorrt_device(device: Union[Device, torch.device, str]) -> Device:
+    """Cast a device-type to torch_tensorrt.Device
+
+    Returns the corresponding torch_tensorrt.Device
+    """
+    if isinstance(device, Device):
+        return device
+
+    elif isinstance(device, torch.device):
+        return Device(gpu_id=device.index)
+
+    else:
+        return Device(device)
 
 
 def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
@@ -164,6 +178,19 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
     # Parse input runtime specification
     settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime)
 
+    # Ensure device is a torch_tensorrt Device
+    settings.device = to_torch_tensorrt_device(settings.device)
+
+    # Check and update device settings
+    default_torch_gpu_idx = torch.cuda.default_stream().device.index
+    if "device" not in kwargs and default_torch_gpu_idx != settings.device.gpu_id:
+        logger.warning(
+            f"No device specified, detected differing gpu IDs for CUDA default: {settings.device.gpu_id} "
+            f"and Torch default: {default_torch_gpu_idx}. Using Torch default gpu ID: {default_torch_gpu_idx}. "
+            "If this is incorrect, please specify an input device, via the device keyword."
+        )
+        settings.device = Device(gpu_id=default_torch_gpu_idx)
+
     logger.debug(f"Compiling with Settings:\n{settings}")
 
     return settings
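
End to end, the device keyword is consumed by parse_dynamo_kwargs when Torch-TensorRT is used as a torch.compile backend; if it is omitted, the Torch default CUDA device wins and the warning above fires whenever the IDs differ. A hedged sketch (the backend name and the routing of options into parse_dynamo_kwargs are assumptions about the registered dynamo backend; a CUDA GPU is required):

import torch
import torch_tensorrt  # assumed to register the "torch_tensorrt" backend

model = torch.nn.Linear(16, 4).eval().cuda()  # placeholder module
x = torch.randn(2, 16).cuda()

# "device" may be a str, torch.device, or torch_tensorrt.Device;
# omit it to fall back to torch.cuda's default device.
optimized = torch.compile(
    model,
    backend="torch_tensorrt",
    options={"device": "cuda:0", "min_block_size": 1},
)
print(optimized(x).shape)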
