import torch
import warnings
from itertools import chain
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from ..modules import Module
from .scatter_gather import scatter_kwargs, gather
from .replicate import replicate
@@ -15,7 +16,7 @@

__all__ = ['DataParallel', 'data_parallel']

-def _check_balance(device_ids):
+def _check_balance(device_ids: Sequence[Union[int, torch.device]]) -> None:
    imbalance_warn = """
    There is an imbalance between your GPUs. You may want to exclude GPU {} which
    has less than 75% of the memory or cores of GPU {}. You can do so by setting
@@ -121,7 +122,13 @@ class DataParallel(Module):

    # TODO: update notes/cuda.rst when this class handles 8+ GPUs well

-    def __init__(self, module, device_ids=None, output_device=None, dim=0):
+    def __init__(
+        self,
+        module: Module,
+        device_ids: Optional[Sequence[Union[int, torch.device]]] = None,
+        output_device: Optional[Union[int, torch.device]] = None,
+        dim: int = 0,
+    ) -> None:
        super().__init__()
        torch._C._log_api_usage_once("torch.nn.parallel.DataParallel")
        device_type = _get_available_device_type()
@@ -133,6 +140,9 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0):
        if device_ids is None:
            device_ids = _get_all_device_indices()

+        if device_ids is None:
+            raise RuntimeError("no available devices were found")
+
        if output_device is None:
            output_device = device_ids[0]

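Aside (not part of the patch): besides failing fast when no devices are visible, the guard added above lets a static checker narrow device_ids from an Optional type to a concrete sequence before it is indexed. A minimal standalone sketch of that raise-to-narrow pattern, with function and variable names of our own choosing:

from typing import Optional, Sequence

def _require_devices(ids: Optional[Sequence[int]]) -> Sequence[int]:
    # Raising on None narrows `ids` to Sequence[int] for the type checker,
    # so indexing the result below is known to be safe.
    if ids is None:
        raise RuntimeError("no available devices were found")
    return ids

first = _require_devices([0, 1])[0]  # accepted by mypy; passing None raises at runtime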
@@ -147,7 +157,7 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0):
        if len(self.device_ids) == 1:
            self.module.to(self.src_device_obj)

-    def forward(self, *inputs, **kwargs):
+    def forward(self, *inputs: Any, **kwargs: Any) -> Any:
        with torch.autograd.profiler.record_function("DataParallel.forward"):
            if not self.device_ids:
                return self.module(*inputs, **kwargs)
@@ -158,33 +168,45 @@ def forward(self, *inputs, **kwargs):
                                   "on device {} (device_ids[0]) but found one of "
                                   "them on device: {}".format(self.src_device_obj, t.device))

-            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+            inputs, module_kwargs = self.scatter(inputs, kwargs, self.device_ids)
            # for forward function without any inputs, empty list and dict will be created
            # so the module can be executed on one device which is the first one in device_ids
-            if not inputs and not kwargs:
+            if not inputs and not module_kwargs:
                inputs = ((),)
-                kwargs = ({},)
+                module_kwargs = ({},)

            if len(self.device_ids) == 1:
-                return self.module(*inputs[0], **kwargs[0])
+                return self.module(*inputs[0], **module_kwargs[0])
            replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
-            outputs = self.parallel_apply(replicas, inputs, kwargs)
+            outputs = self.parallel_apply(replicas, inputs, module_kwargs)
            return self.gather(outputs, self.output_device)

-    def replicate(self, module, device_ids):
+    def replicate(self, module: Module, device_ids: Sequence[Union[int, torch.device]]) -> List[Module]:
        return replicate(module, device_ids, not torch.is_grad_enabled())

-    def scatter(self, inputs, kwargs, device_ids):
+    def scatter(
+        self,
+        inputs: Tuple[Any, ...],
+        kwargs: Optional[Dict[str, Any]],
+        device_ids: Sequence[Union[int, torch.device]],
+    ) -> Any:
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

-    def parallel_apply(self, replicas, inputs, kwargs):
+    def parallel_apply(self, replicas: Sequence[Module], inputs: Sequence[Any], kwargs: Any) -> List[Any]:
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

-    def gather(self, outputs, output_device):
+    def gather(self, outputs: Any, output_device: Union[int, torch.device]) -> Any:
        return gather(outputs, output_device, dim=self.dim)


-def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None):
+def data_parallel(
+    module: Module,
+    inputs: Any,
+    device_ids: Optional[Sequence[Union[int, torch.device]]] = None,
+    output_device: Optional[Union[int, torch.device]] = None,
+    dim: int = 0,
+    module_kwargs: Optional[Any] = None,
+) -> torch.Tensor:
    r"""Evaluates module(input) in parallel across the GPUs given in device_ids.

    This is the functional version of the DataParallel module.
@@ -204,9 +226,15 @@ def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, mo

    device_type = _get_available_device_type()

+    if device_type is None:
+        raise RuntimeError("device type could not be determined")
+
    if device_ids is None:
        device_ids = _get_all_device_indices()

+    if device_ids is None:
+        raise RuntimeError("no available devices were found")
+
    if output_device is None:
        output_device = device_ids[0]

@@ -227,6 +255,8 @@ def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, mo
        inputs = ((),)
        module_kwargs = ({},)

+    assert module_kwargs is not None
+
    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    used_device_ids = device_ids[:len(inputs)]
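Not part of the patch: a brief usage sketch of the two entry points whose signatures are annotated above, assuming a machine with at least two CUDA devices (the toy module and tensors here are our own, chosen for illustration):

import torch
import torch.nn as nn
from torch.nn.parallel import data_parallel

model = nn.Linear(16, 4).to("cuda:0")        # any nn.Module works the same way
inp = torch.randn(8, 16, device="cuda:0")

# Class-based API: inputs are scattered along dim 0 to GPUs 0 and 1,
# the module is replicated, and outputs are gathered back onto output_device.
dp = nn.DataParallel(model, device_ids=[0, 1], output_device=0)
out = dp(inp)            # tensor on cuda:0

# Functional form, annotated above as returning torch.Tensor.
out2 = data_parallel(model, inp, device_ids=[0, 1], output_device=0)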