@@ -8,7 +8,11 @@
 import torch
 from torch.nn import Module
 from torch_tensorrt._Device import Device
-from torch_tensorrt.dynamo.runtime.tools import _is_switch_required, _select_rt_device
+from torch_tensorrt.dynamo.runtime.tools import (
+    _is_switch_required,
+    _select_rt_device,
+    multi_gpu_device_check,
+)
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
 import torch_tensorrt
@@ -33,6 +37,10 @@ def __init__(
     ):
         super(PythonTorchTensorRTModule, self).__init__()
         self._register_state_dict_hook(PythonTorchTensorRTModule._on_state_dict)
+
+        # Run multi-gpu device check to validate engine instantiation
+        multi_gpu_device_check()
+
         self.engine = engine
         self.input_names = input_names if input_names is not None else []
         self.output_names = output_names if output_names is not None else []
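For context, `multi_gpu_device_check` is imported from `torch_tensorrt.dynamo.runtime.tools`; its body is not shown in this diff. Below is a minimal sketch of the behavior it plausibly implements, assuming it simply warns when several GPUs are visible while multi-device safe mode is off. The flag and logic here are illustrative stand-ins, not the actual implementation:

# Illustrative sketch only -- not the actual tools.multi_gpu_device_check.
import logging

import torch

logger = logging.getLogger(__name__)


def multi_gpu_device_check() -> None:
    # Hypothetical stand-in for the runtime's multi-device safe mode flag
    safe_mode_enabled = False

    if not safe_mode_enabled and torch.cuda.device_count() > 1:
        logger.warning(
            "Multiple GPUs detected with multi-device safe mode disabled. "
            "TensorRT engines are bound to the device they were built on; "
            "enable safe mode to validate the active device at runtime."
        )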
@@ -133,6 +141,9 @@ def _load_from_state_dict(
     ) -> None:
         engine_bytes = state_dict[prefix + "engine"]
 
+        # Run multi-gpu device check to validate engine instantiation
+        multi_gpu_device_check()
+
         logger = trt.Logger()
         runtime = trt.Runtime(logger)
         self.engine = runtime.deserialize_cuda_engine(engine_bytes)
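The same check runs when a module is rebuilt from a serialized state dict, since `_load_from_state_dict` deserializes the engine with a fresh `trt.Runtime`. A hedged round-trip sketch, where `trt_module` is a placeholder for an already-compiled `PythonTorchTensorRTModule`:

# Both construction and reload below trigger multi_gpu_device_check().
import torch

torch.save(trt_module.state_dict(), "trt_module.pt")     # engine bytes serialized
trt_module.load_state_dict(torch.load("trt_module.pt"))  # deserialized; check runs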
@@ -162,7 +173,9 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         self._check_initialized()
 
         # If in safe mode, check at each iteration whether a switch is required
-        if torch_tensorrt._compile.SAFE_MODE:
+        if (
+            torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
+        ):
             curr_device_id = torch.cuda.current_device()
             curr_device_properties = torch.cuda.get_device_properties(
                 curr_device_id
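The flag consulted here is module-level state in `torch_tensorrt.runtime.multi_device_safe_mode`, replacing the old `torch_tensorrt._compile.SAFE_MODE` global. A hedged usage sketch, assuming the public setter that accompanies this flag; `trt_module` and the input shape are placeholders:

# Enable multi-device safe mode so forward() re-validates the active
# CUDA device on every call; assumes torch_tensorrt.runtime exposes a setter.
import torch
import torch_tensorrt

torch_tensorrt.runtime.set_multi_device_safe_mode(True)

inputs = torch.randn(1, 3, 224, 224, device="cuda:0")  # placeholder shape
outputs = trt_module(inputs)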
@@ -202,24 +215,22 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             )
 
         for i, input_name in enumerate(self.input_names):
-            # Check that the inputs are on cuda and have the correct data type if in safe mode
-            if torch_tensorrt._compile.SAFE_MODE:
-                if not contiguous_inputs[i].is_cuda:
-                    logger.warning(
-                        f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
-                        "This tensor is being moved by the runtime but for performance considerations, "
-                        "ensure your inputs are all on GPU and open an issue here "
-                        "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
-                    )
-                    contiguous_inputs = (
-                        contiguous_inputs[:i]
-                        + [contiguous_inputs[i].cuda()]
-                        + contiguous_inputs[i + 1 :]
-                    )
-
-                assert (
-                    contiguous_inputs[i].dtype == self.input_dtypes[i]
-                ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
+            if not contiguous_inputs[i].is_cuda:
+                logger.warning(
+                    f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
+                    "This tensor is being moved by the runtime but for performance considerations, "
+                    "ensure your inputs are all on GPU and open an issue here "
+                    "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
+                )
+                contiguous_inputs = (
+                    contiguous_inputs[:i]
+                    + [contiguous_inputs[i].cuda()]
+                    + contiguous_inputs[i + 1 :]
+                )
+
+            assert (
+                contiguous_inputs[i].dtype == self.input_dtypes[i]
+            ), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {contiguous_inputs[i].dtype}."
 
             idx = self.input_binding_indices_in_order[i]
             bindings[idx] = contiguous_inputs[i].data_ptr()
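With the `SAFE_MODE` guard removed, the device and dtype checks on inputs now run on every call, not just in safe mode: a CPU tensor is warned about and moved to the GPU, and a dtype mismatch always raises. A behavior sketch under those assumptions, with a placeholder module and shapes:

# `trt_module` is a placeholder compiled module expecting one float32 input.
import torch

out = trt_module(torch.randn(1, 3, 224, 224))  # CPU input: warns, then .cuda()

half_input = torch.randn(1, 3, 224, 224).half().cuda()
out = trt_module(half_input)                   # raises AssertionError (dtype mismatch)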