
Commit 0ed22fc

Matthew Hoffman authored and pytorchmergebot committed
Merge type stubs torch nn parallel (pytorch#102194)
Fixes the merge issue for pytorch#101528. In that PR, `torch.nn.parallel.parallel_apply.get_a_var` was marked private to appease the [public interface linter](https://github.com/pytorch/pytorch/actions/runs/4999216467/jobs/8955582204#step:14:21666): pytorch@ceeb242. This broke CI pipelines running external dependencies that expected `get_a_var`'s name not to change. In this PR, we change the name back to `get_a_var` and include it in `__all__` instead.

Pull Request resolved: pytorch#102194
Approved by: https://github.com/ezyang
1 parent 7b6438d commit 0ed22fc
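
For context on why the name matters: downstream code imports the helper by its public name, so renaming it to `_get_a_var` broke those imports even though the behavior was unchanged. A minimal sketch of the kind of external usage this PR keeps working (the caller below is hypothetical, not part of the PR):

import torch
from torch.nn.parallel.parallel_apply import get_a_var  # public name restored by this PR

# get_a_var returns the first tensor it can find in a (possibly nested) input.
batch = ([torch.randn(2, 3)], {"mask": torch.ones(2, 3)})
tensor = get_a_var(batch)
if tensor is not None:  # the merged annotation types the return as Optional[torch.Tensor]
    print(tensor.shape)  # torch.Size([2, 3])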


12 files changed, +140 -147 lines changed


torch/distributed/utils.py

Lines changed: 1 addition & 3 deletions
@@ -5,9 +5,7 @@
 import torch
 import torch.distributed as dist
 from torch.nn.parallel._functions import _get_stream
-from torch.nn.parallel.scatter_gather import (  # type: ignore[attr-defined]
-    _is_namedtuple,
-)
+from torch.nn.parallel.scatter_gather import _is_namedtuple
 from torch.nn.utils.rnn import PackedSequence
 
 __all__ = []  # type: ignore[var-annotated]

torch/nn/parallel/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 from .parallel_apply import parallel_apply
 from .replicate import replicate
 from .data_parallel import DataParallel, data_parallel
-from .scatter_gather import scatter, gather
+from .scatter_gather import gather, scatter
 from .distributed import DistributedDataParallel
 
 __all__ = ['replicate', 'scatter', 'parallel_apply', 'gather', 'data_parallel',

torch/nn/parallel/__init__.pyi

Lines changed: 0 additions & 5 deletions
This file was deleted.

torch/nn/parallel/common_types.pyi

Lines changed: 0 additions & 6 deletions
This file was deleted.

torch/nn/parallel/data_parallel.py

Lines changed: 43 additions & 13 deletions
@@ -2,6 +2,7 @@
 import torch
 import warnings
 from itertools import chain
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 from ..modules import Module
 from .scatter_gather import scatter_kwargs, gather
 from .replicate import replicate
@@ -15,7 +16,7 @@
 
 __all__ = ['DataParallel', 'data_parallel']
 
-def _check_balance(device_ids):
+def _check_balance(device_ids: Sequence[Union[int, torch.device]]) -> None:
     imbalance_warn = """
     There is an imbalance between your GPUs. You may want to exclude GPU {} which
     has less than 75% of the memory or cores of GPU {}. You can do so by setting
@@ -121,7 +122,13 @@ class DataParallel(Module):
 
     # TODO: update notes/cuda.rst when this class handles 8+ GPUs well
 
-    def __init__(self, module, device_ids=None, output_device=None, dim=0):
+    def __init__(
+        self,
+        module: Module,
+        device_ids: Optional[Sequence[Union[int, torch.device]]] = None,
+        output_device: Optional[Union[int, torch.device]] = None,
+        dim: int = 0,
+    ) -> None:
         super().__init__()
         torch._C._log_api_usage_once("torch.nn.parallel.DataParallel")
         device_type = _get_available_device_type()
@@ -133,6 +140,9 @@ def __init__(self, module, device_ids=None, output_device=None, dim=0):
         if device_ids is None:
             device_ids = _get_all_device_indices()
 
+        if device_ids is None:
+            raise RuntimeError("no available devices were found")
+
         if output_device is None:
             output_device = device_ids[0]
 
@@ -147,7 +157,7 @@
         if len(self.device_ids) == 1:
             self.module.to(self.src_device_obj)
 
-    def forward(self, *inputs, **kwargs):
+    def forward(self, *inputs: Any, **kwargs: Any) -> Any:
         with torch.autograd.profiler.record_function("DataParallel.forward"):
             if not self.device_ids:
                 return self.module(*inputs, **kwargs)
@@ -158,33 +168,45 @@ def forward(self, *inputs, **kwargs):
                                        "on device {} (device_ids[0]) but found one of "
                                        "them on device: {}".format(self.src_device_obj, t.device))
 
-            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+            inputs, module_kwargs = self.scatter(inputs, kwargs, self.device_ids)
             # for forward function without any inputs, empty list and dict will be created
             # so the module can be executed on one device which is the first one in device_ids
-            if not inputs and not kwargs:
+            if not inputs and not module_kwargs:
                 inputs = ((),)
-                kwargs = ({},)
+                module_kwargs = ({},)
 
             if len(self.device_ids) == 1:
-                return self.module(*inputs[0], **kwargs[0])
+                return self.module(*inputs[0], **module_kwargs[0])
             replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
-            outputs = self.parallel_apply(replicas, inputs, kwargs)
+            outputs = self.parallel_apply(replicas, inputs, module_kwargs)
             return self.gather(outputs, self.output_device)
 
-    def replicate(self, module, device_ids):
+    def replicate(self, module: Module, device_ids: Sequence[Union[int, torch.device]]) -> List[Module]:
         return replicate(module, device_ids, not torch.is_grad_enabled())
 
-    def scatter(self, inputs, kwargs, device_ids):
+    def scatter(
+        self,
+        inputs: Tuple[Any, ...],
+        kwargs: Optional[Dict[str, Any]],
+        device_ids: Sequence[Union[int, torch.device]],
+    ) -> Any:
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
 
-    def parallel_apply(self, replicas, inputs, kwargs):
+    def parallel_apply(self, replicas: Sequence[Module], inputs: Sequence[Any], kwargs: Any) -> List[Any]:
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
 
-    def gather(self, outputs, output_device):
+    def gather(self, outputs: Any, output_device: Union[int, torch.device]) -> Any:
         return gather(outputs, output_device, dim=self.dim)
 
 
-def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None):
+def data_parallel(
+    module: Module,
+    inputs: Any,
+    device_ids: Optional[Sequence[Union[int, torch.device]]] = None,
+    output_device: Optional[Union[int, torch.device]] = None,
+    dim: int = 0,
+    module_kwargs: Optional[Any] = None,
+) -> torch.Tensor:
     r"""Evaluates module(input) in parallel across the GPUs given in device_ids.
 
     This is the functional version of the DataParallel module.
@@ -204,9 +226,15 @@ def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None):
 
     device_type = _get_available_device_type()
 
+    if device_type is None:
+        raise RuntimeError("device type could not be determined")
+
     if device_ids is None:
         device_ids = _get_all_device_indices()
 
+    if device_ids is None:
+        raise RuntimeError("no available devices were found")
+
     if output_device is None:
         output_device = device_ids[0]
 
@@ -227,6 +255,8 @@ def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None):
         inputs = ((),)
         module_kwargs = ({},)
 
+    assert module_kwargs is not None
+
     if len(device_ids) == 1:
         return module(*inputs[0], **module_kwargs[0])
     used_device_ids = device_ids[:len(inputs)]
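
The annotations added above describe the existing runtime contract rather than changing it: `device_ids` may be plain ints or `torch.device` objects, and inputs are still scattered along `dim` and gathered on `output_device`. A minimal usage sketch, assuming a machine with at least two CUDA devices (the model and shapes are illustrative only):

import torch
from torch import nn

# Assumes CUDA devices 0 and 1 are available; shapes are arbitrary.
model = nn.Linear(10, 5).cuda(0)          # module must start on device_ids[0]
dp = nn.DataParallel(model, device_ids=[0, 1], output_device=0, dim=0)
out = dp(torch.randn(8, 10).cuda(0))      # scattered along dim 0, gathered on device 0
print(out.shape)                          # torch.Size([8, 5])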

torch/nn/parallel/data_parallel.pyi

Lines changed: 0 additions & 29 deletions
This file was deleted.

torch/nn/parallel/parallel_apply.py

Lines changed: 27 additions & 6 deletions
@@ -1,11 +1,14 @@
 import threading
 import torch
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
+from ..modules import Module
 from torch.cuda._utils import _get_device_index
 from torch.cuda.amp import autocast
 from torch._utils import ExceptionWrapper
 
+__all__ = ['get_a_var', 'parallel_apply']
 
-def get_a_var(obj):
+def get_a_var(obj: Union[torch.Tensor, List[Any], Tuple[Any, ...], Dict[Any, Any]]) -> Optional[torch.Tensor]:
     if isinstance(obj, torch.Tensor):
         return obj
 
@@ -19,8 +22,12 @@ def get_a_var(obj):
                 return result
     return None
 
-
-def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
+def parallel_apply(
+    modules: Sequence[Module],
+    inputs: Sequence[Any],
+    kwargs_tup: Optional[Sequence[Dict[str, Any]]] = None,
+    devices: Optional[Sequence[Optional[Union[int, torch.device]]]] = None,
+) -> List[Any]:
     r"""Applies each `module` in :attr:`modules` in parallel on arguments
     contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
     on each of :attr:`devices`.
@@ -39,7 +46,7 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
     if kwargs_tup is not None:
         assert len(modules) == len(kwargs_tup)
     else:
-        kwargs_tup = ({},) * len(modules)
+        kwargs_tup = (cast(Dict[str, Any], {}),) * len(modules)
     if devices is not None:
         assert len(modules) == len(devices)
     else:
@@ -50,10 +57,24 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
     results = {}
     grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()
 
-    def _worker(i, module, input, kwargs, device=None, stream=None):
+    def _worker(
+        i: int,
+        module: Module,
+        input: Any,
+        kwargs: Dict[str, Any],
+        device: Optional[Union[int, torch.device]] = None,
+        stream: Optional[torch.cuda.Stream] = None,
+    ) -> None:
         torch.set_grad_enabled(grad_enabled)
         if device is None:
-            device = get_a_var(input).get_device()
+            t = get_a_var(input)
+            if t is None:
+                with lock:
+                    results[i] = ExceptionWrapper(
+                        where="in replica {}, no device was provided and no tensor input was found; "
+                        "device cannot be resolved".format(i))
+                return
+            device = t.get_device()
         if stream is None:
             stream = torch.cuda.current_stream(device)
         try:
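
Since `get_a_var` can legitimately return `None`, the new branch in `_worker` reports an `ExceptionWrapper` instead of failing on `None.get_device()`. A small sketch of the helper's behavior on nested inputs (the inputs below are illustrative, not from the PR):

import torch
from torch.nn.parallel.parallel_apply import get_a_var

# Recurses through tensors, lists/tuples, and dicts, returning the first tensor it finds.
print(get_a_var([{"x": torch.ones(2)}]))        # tensor([1., 1.])
print(get_a_var(((), {"note": "no tensors"})))  # None, so a device must be supplied explicitly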

torch/nn/parallel/parallel_apply.pyi

Lines changed: 0 additions & 11 deletions
This file was deleted.

torch/nn/parallel/replicate.py

Lines changed: 33 additions & 17 deletions
@@ -1,26 +1,34 @@
+import torch
+from ..modules import Module
 from . import comm
+from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Sequence, Set, Union
 from torch._utils import _get_device_index
 
 from collections import OrderedDict
 
+if TYPE_CHECKING:
+    import torch.jit
+    import torch.jit._state
+
+__all__ = ['replicate']
 
-def _is_script_module(module):
+def _is_script_module(module: Module) -> bool:
     import torch.jit
     return isinstance(module, torch.jit.ScriptModule)
 
 
-def _is_script_method(module):
+def _is_script_method(module: Module) -> bool:
     import torch.jit
     return isinstance(module, torch._C.ScriptMethod)
 
 
-def _init_script_module():
+def _init_script_module() -> "torch.jit.ScriptModule":
     import torch.jit
     return torch.jit.ScriptModule()
 
 
-def _is_jit_enabled():
-    import torch.jit
+def _is_jit_enabled() -> "torch.jit._state.EnabledProxy":
+    import torch.jit._state
     return torch.jit._state._enabled
 
 
@@ -31,10 +39,10 @@ def _is_jit_enabled():
 #
 # currently a module cannot be replicated properly if the descendants of
 # any ScriptModule contains python module (type 1 above)
-def _replicatable_module(module, memo=None):
+def _replicatable_module(module: Module, memo: Optional[Set[Module]] = None) -> bool:
 
     # module.modules() contains module itself as the first element
-    def descendant_modules(module):
+    def descendant_modules(module: Module) -> Iterator[Module]:
         gen = module.modules()
         next(gen)
         return gen
@@ -61,7 +69,11 @@ def descendant_modules(module):
 
     return True
 
-def _broadcast_coalesced_reshape(tensors, devices, detach=False):
+def _broadcast_coalesced_reshape(
+    tensors: Sequence[torch.Tensor],
+    devices: Sequence[Union[int, torch.device]],
+    detach: bool = False,
+) -> List[List[torch.Tensor]]:
     from ._functions import Broadcast
     if detach:
         return comm.broadcast_coalesced(tensors, devices)
@@ -75,7 +87,11 @@ def _broadcast_coalesced_reshape(tensors, devices, detach=False):
         return []
 
 
-def replicate(network, devices, detach=False):
+def replicate(
+    network: Module,
+    devices: Sequence[Union[int, torch.device]],
+    detach: bool = False,
+) -> List[Module]:
     if not _replicatable_module(network):
         raise RuntimeError("Cannot replicate network where python modules are "
                            "childrens of ScriptModule")
@@ -91,8 +107,8 @@ def replicate(network, devices, detach=False):
     param_copies = _broadcast_coalesced_reshape(params, devices, detach)
 
     buffers = list(network.buffers())
-    buffers_rg = []
-    buffers_not_rg = []
+    buffers_rg: List[torch.Tensor] = []
+    buffers_not_rg: List[torch.Tensor] = []
     for buf in buffers:
         if buf.requires_grad and not detach:
             buffers_rg.append(buf)
@@ -106,8 +122,8 @@ def replicate(network, devices, detach=False):
     buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg, devices, detach=True)
 
     modules = list(network.modules())
-    module_copies = [[] for device in devices]
-    module_indices = {}
+    module_copies: List[List[Module]] = [[] for _ in devices]
+    module_indices: Dict[Module, int] = {}
 
     for i, module in enumerate(modules):
         module_indices[module] = i
@@ -142,13 +158,13 @@ def replicate(network, devices, detach=False):
                 param_idx = param_indices[param]
                 for j in range(num_replicas):
                     replica = module_copies[j][i]
-                    param = param_copies[j][param_idx]
+                    param_copy = param_copies[j][param_idx]
                     # parameters in replicas are no longer leaves,
                     # so setattr them as non-parameter attributes
-                    setattr(replica, key, param)
+                    setattr(replica, key, param_copy)
                     # expose the parameter for DDP
-                    replica._former_parameters[key] = param
-        for key, buf in module._buffers.items():
+                    replica._former_parameters[key] = param_copy
+        for key, buf in module._buffers.items():  # type: ignore[assignment]
             if buf is None:
                 for j in range(num_replicas):
                     replica = module_copies[j][i]
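
The typed `replicate` signature mirrors how `DataParallel.replicate` already calls it: one module copy per entry in `devices`, which may be ints or `torch.device` objects. A brief sketch, assuming CUDA devices 0 and 1 exist (illustrative only):

import torch
from torch import nn
from torch.nn.parallel import replicate

# Assumes CUDA devices 0 and 1; the network is arbitrary.
net = nn.Sequential(nn.Linear(4, 4), nn.ReLU()).cuda(0)
replicas = replicate(net, [0, 1])               # -> List[Module], one copy per device
out = replicas[1](torch.randn(2, 4).cuda(1))    # each replica runs on its own device
print(len(replicas), out.device)                # 2 cuda:1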
