@@ -2,7 +2,7 @@
 
 import logging
 from dataclasses import fields, replace
-from typing import Any, Callable, Dict, Optional, Sequence
+from typing import Any, Callable, Dict, Optional, Sequence, Union
 
 import torch
 from torch_tensorrt._Device import Device
@@ -114,23 +114,37 @@ def prepare_inputs(
         )
 
 
-def prepare_device(device: Device | torch.device) -> torch.device:
-    _device: torch.device
+def to_torch_device(device: Union[Device, torch.device, str]) -> torch.device:
+    """Cast a device-type to torch.device
+
+    Returns the corresponding torch.device
+    """
     if isinstance(device, Device):
         if device.gpu_id != -1:
-            _device = torch.device(device.gpu_id)
+            return torch.device(device.gpu_id)
         else:
             raise ValueError("Invalid GPU ID provided for the CUDA device provided")
 
     elif isinstance(device, torch.device):
-        _device = device
+        return device
 
     else:
-        raise ValueError(
-            "Invalid device provided. Supported options: torch.device | torch_tensorrt.Device"
-        )
+        return torch.device(device)
 
-    return _device
+
+def to_torch_tensorrt_device(device: Union[Device, torch.device, str]) -> Device:
+    """Cast a device-type to torch_tensorrt.Device
+
+    Returns the corresponding torch_tensorrt.Device
+    """
+    if isinstance(device, Device):
+        return device
+
+    elif isinstance(device, torch.device):
+        return Device(gpu_id=device.index)
+
+    else:
+        return Device(device)
 
 
 def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
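For reference, a minimal sketch of how the two new casting helpers behave, assuming a CUDA-capable machine and that the helpers are importable from torch_tensorrt.dynamo.utils (the module path is not shown in the diff and is assumed here):

    # Illustrative usage of the new helpers; adjust the import path to
    # wherever this file actually lives in the package.
    import torch
    from torch_tensorrt._Device import Device
    from torch_tensorrt.dynamo.utils import to_torch_device, to_torch_tensorrt_device

    # All three accepted input types normalize to the same torch.device
    # (assumes at least one CUDA GPU, so index 0 is valid).
    print(to_torch_device("cuda:0"))                 # cuda:0
    print(to_torch_device(torch.device("cuda", 0)))  # cuda:0
    print(to_torch_device(Device(gpu_id=0)))         # cuda:0

    # The reverse cast wraps the same GPU index in a torch_tensorrt.Device.
    trt_device = to_torch_tensorrt_device(torch.device("cuda", 0))
    print(trt_device.gpu_id)                         # 0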
@@ -164,6 +178,19 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
     # Parse input runtime specification
     settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime)
 
+    # Ensure device is a torch_tensorrt Device
+    settings.device = to_torch_tensorrt_device(settings.device)
+
+    # Check and update device settings
+    default_torch_gpu_idx = torch.cuda.default_stream().device.index
+    if "device" not in kwargs and default_torch_gpu_idx != settings.device.gpu_id:
+        logger.warning(
+            f"No device specified, detected differing gpu IDs for CUDA default: {settings.device.gpu_id} "
+            f"and Torch default: {default_torch_gpu_idx}. Using Torch default gpu ID: {default_torch_gpu_idx}. "
+            "If this is incorrect, please specify an input device, via the device keyword."
+        )
+        settings.device = Device(gpu_id=default_torch_gpu_idx)
+
     logger.debug(f"Compiling with Settings:\n{settings}")
 
     return settings
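And a minimal sketch of the device-reconciliation check added to parse_dynamo_kwargs, restated as a standalone function so the behavior can be exercised in isolation; resolve_device and its parameters are hypothetical names used only for illustration, not part of the patch:

    import logging

    import torch
    from torch_tensorrt._Device import Device

    logger = logging.getLogger(__name__)


    def resolve_device(settings_device: Device, device_passed_by_user: bool) -> Device:
        # Torch's current default CUDA device, read off the default stream.
        default_torch_gpu_idx = torch.cuda.default_stream().device.index
        # Only override when the caller gave no device and the defaults disagree.
        if not device_passed_by_user and default_torch_gpu_idx != settings_device.gpu_id:
            logger.warning(
                "Differing default GPU IDs (settings: %d, torch: %d); using torch default.",
                settings_device.gpu_id,
                default_torch_gpu_idx,
            )
            return Device(gpu_id=default_torch_gpu_idx)
        return settings_device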