Skip to content

Commit d6e3ad3

Browse files
Chilleefacebook-github-bot
authored andcommitted
Added remove_duplicate parameter to nn.Module (#39)
Summary: Pull Request resolved: pytorch/torchrec#39 Pull Request resolved: pytorch/torchrec#6 This makes it so that shared parameters get their own entry in `named_parameters`. More broadly, this makes it so that ``` params_and_buffers = {**mod.named_named_parameters(remove_duplicate=False), **mod.named_buffers(remove_duplicate=False)} _stateless.functional_call(mod, params_and_buffers, args, kwargs) ``` is identical to calling the original module's forwards pass. cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang Pull Request resolved: pytorch#71542 Reviewed By: jbschlosser, albanD Differential Revision: D33716716 Pulled By: Chillee fbshipit-source-id: ff1ed9980bd1a3f7ebaf695ee5e401202b543213
1 parent 9705212 commit d6e3ad3

File tree

3 files changed

+37
-9
lines changed

3 files changed

+37
-9
lines changed

test/test_nn.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,6 +1118,24 @@ def names(named_buffers):
11181118
names(s.named_buffers()),
11191119
['0.dummy_buf', '0.l1.layer_dummy_buf'])
11201120

1121+
def test_named_parameters_buffers_duplicates(self):
1122+
class Foo(nn.Module):
1123+
def __init__(self):
1124+
super().__init__()
1125+
self.bias = nn.Parameter(torch.randn(3))
1126+
self.linear = nn.Linear(3, 3)
1127+
self.linear.bias = self.bias
1128+
self.linear_cloned = self.linear
1129+
self.register_buffer('buffer', torch.randn(3))
1130+
self.register_buffer('buffer_cloned', self.buffer)
1131+
1132+
mod = Foo()
1133+
self.assertEqual(len(list(mod.named_parameters())), 2)
1134+
self.assertEqual(len(list(mod.named_parameters(remove_duplicate=False))), 5)
1135+
1136+
self.assertEqual(len(list(mod.named_buffers())), 1)
1137+
self.assertEqual(len(list(mod.named_buffers(remove_duplicate=False))), 2)
1138+
11211139
def test_call_supports_python_dict_output(self):
11221140
class Net(nn.Module):
11231141
def __init__(self):

torch/distributed/nn/api/remote_module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,15 +374,15 @@ def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
374374
)
375375

376376
def named_parameters( # type: ignore[return]
377-
self, prefix: str = "", recurse: bool = True
377+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
378378
) -> Iterator[Tuple[str, Parameter]]:
379379
_raise_not_supported(self.named_parameters.__name__)
380380

381381
def buffers(self, recurse: bool = True) -> Iterator[Tensor]: # type: ignore[return]
382382
_raise_not_supported(self.buffers.__name__)
383383

384384
def named_buffers( # type: ignore[return]
385-
self, prefix: str = "", recurse: bool = True
385+
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
386386
) -> Iterator[Tuple[str, Tensor]]:
387387
_raise_not_supported(self.named_buffers.__name__)
388388

torch/nn/modules/module.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,16 +1498,17 @@ def load(module, prefix=''):
14981498
self.__class__.__name__, "\n\t".join(error_msgs)))
14991499
return _IncompatibleKeys(missing_keys, unexpected_keys)
15001500

1501-
def _named_members(self, get_members_fn, prefix='', recurse=True):
1501+
def _named_members(self, get_members_fn, prefix='', recurse=True, remove_duplicate=True):
15021502
r"""Helper method for yielding various names + members of modules."""
15031503
memo = set()
1504-
modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)]
1504+
modules = self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) if recurse else [(prefix, self)]
15051505
for module_prefix, module in modules:
15061506
members = get_members_fn(module)
15071507
for k, v in members:
15081508
if v is None or v in memo:
15091509
continue
1510-
memo.add(v)
1510+
if remove_duplicate:
1511+
memo.add(v)
15111512
name = module_prefix + ('.' if module_prefix else '') + k
15121513
yield name, v
15131514

@@ -1535,7 +1536,10 @@ def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
15351536
for name, param in self.named_parameters(recurse=recurse):
15361537
yield param
15371538

1538-
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
1539+
def named_parameters(self,
1540+
prefix: str = '',
1541+
recurse: bool = True,
1542+
remove_duplicate: bool = True) -> Iterator[Tuple[str, Parameter]]:
15391543
r"""Returns an iterator over module parameters, yielding both the
15401544
name of the parameter as well as the parameter itself.
15411545
@@ -1544,6 +1548,9 @@ def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[T
15441548
recurse (bool): if True, then yields parameters of this module
15451549
and all submodules. Otherwise, yields only parameters that
15461550
are direct members of this module.
1551+
remove_duplicate (bool): if True, then removes parameters
1552+
that are duplicates of each other. For example, if two
1553+
parameters are tied, it'll only return one.
15471554
15481555
Yields:
15491556
(string, Parameter): Tuple containing the name and parameter
@@ -1557,7 +1564,7 @@ def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[T
15571564
"""
15581565
gen = self._named_members(
15591566
lambda module: module._parameters.items(),
1560-
prefix=prefix, recurse=recurse)
1567+
prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
15611568
for elem in gen:
15621569
yield elem
15631570

@@ -1583,7 +1590,7 @@ def buffers(self, recurse: bool = True) -> Iterator[Tensor]:
15831590
for _, buf in self.named_buffers(recurse=recurse):
15841591
yield buf
15851592

1586-
def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Tensor]]:
1593+
def named_buffers(self, prefix: str = '', recurse: bool = True, remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]:
15871594
r"""Returns an iterator over module buffers, yielding both the
15881595
name of the buffer as well as the buffer itself.
15891596
@@ -1592,6 +1599,9 @@ def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tupl
15921599
recurse (bool): if True, then yields buffers of this module
15931600
and all submodules. Otherwise, yields only buffers that
15941601
are direct members of this module.
1602+
remove_duplicate (bool): if True, then removes buffers
1603+
that are duplicates of each other. For example, if two
1604+
buffers are tied, it'll only return one.
15951605
15961606
Yields:
15971607
(string, torch.Tensor): Tuple containing the name and buffer
@@ -1605,7 +1615,7 @@ def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tupl
16051615
"""
16061616
gen = self._named_members(
16071617
lambda module: module._buffers.items(),
1608-
prefix=prefix, recurse=recurse)
1618+
prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate)
16091619
for elem in gen:
16101620
yield elem
16111621

0 commit comments

Comments
 (0)