Skip to content

Commit b8776e1

Browse files
lkctpytorchmergebot
authored andcommitted
Fix false DeprecationWarning in Module.state_dict
Fixes pytorch#75404 TODO: - [x] add tests Pull Request resolved: pytorch#75507 Approved by: https://github.com/jbschlosser
1 parent 429a80d commit b8776e1

File tree

4 files changed

+42
-60
lines changed

4 files changed

+42
-60
lines changed

test/test_nn.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6170,6 +6170,9 @@ def test_state_dict(self):
61706170
self.assertEqual(state_dict['weight'].data_ptr(), l.weight.data_ptr())
61716171
self.assertEqual(state_dict['bias'].data_ptr(), l.bias.data_ptr())
61726172

6173+
# Reference https://github.com/pytorch/pytorch/pull/75507#issuecomment-1110291545
6174+
self.assertNotWarn(lambda: l.state_dict(destination=dict()), "Should not warn kwarg destination w/o _metadata")
6175+
61736176
def test_load_state_dict(self):
61746177
l = nn.Linear(5, 5)
61756178
block = nn.Module()

torch/distributed/nn/api/remote_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandl
362362
def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle: # type: ignore[return]
363363
_raise_not_supported(self.register_forward_hook.__name__)
364364

365-
def state_dict(self, destination=None, prefix="", keep_vars=False):
365+
def state_dict(self, *args, **kwargs):
366366
_raise_not_supported(self.state_dict.__name__)
367367

368368
def load_state_dict(

torch/jit/_script.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -893,7 +893,6 @@ def _get_methods(cls):
893893
"double",
894894
"half",
895895
"state_dict",
896-
"_state_dict_impl",
897896
"_save_to_state_dict",
898897
"load_state_dict",
899898
"_load_from_state_dict",

torch/nn/modules/module.py

Lines changed: 38 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1296,75 +1296,48 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
12961296
if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state:
12971297
destination[extra_state_key] = self.get_extra_state()
12981298

1299-
def _state_dict_impl(self, destination, prefix, keep_vars):
1300-
r"""Holds the actual implementation of
1301-
:meth:`~torch.nn.Module.state_dict`, with recursive calls for
1302-
descendants of this module.
1303-
1304-
In rare cases, users can call this directly to provide a custom
1305-
:attr:`destination`.
1306-
1307-
Args:
1308-
destination (dict): a dict where state will be stored
1309-
prefix (str): the prefix for parameters and buffers used in this
1310-
module
1311-
keep_vars (bool): whether NOT to return buffers detached from
1312-
autograd
1313-
"""
1314-
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
1315-
self._save_to_state_dict(destination, prefix, keep_vars)
1316-
for name, module in self._modules.items():
1317-
if module is not None:
1318-
module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
1319-
for hook in self._state_dict_hooks.values():
1320-
hook_result = hook(self, destination, prefix, local_metadata)
1321-
if hook_result is not None:
1322-
destination = hook_result
1323-
return destination
1324-
1325-
# TODO: Deprecated, destination is becoming private. Remove this signature when BC allows
1326-
# See https://github.com/pytorch/pytorch/issues/72778#issuecomment-1039263869
1299+
# The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
1300+
# back that same object. But if they pass nothing, an `OrederedDict` is created and returned.
13271301
T_destination = TypeVar('T_destination', bound=Dict[str, Any])
13281302

13291303
@overload
1330-
def state_dict(self, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination:
1304+
def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination:
13311305
...
13321306

13331307
@overload
13341308
def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]:
13351309
...
13361310

1311+
# TODO: Change `*args` to `*` and remove the copprespinding warning in docs when BC allows.
1312+
# Also remove the logic for arg parsing together.
13371313
def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
13381314
r"""Returns a dictionary containing a whole state of the module.
13391315
13401316
Both parameters and persistent buffers (e.g. running averages) are
13411317
included. Keys are corresponding parameter and buffer names.
13421318
Parameters and buffers set to ``None`` are not included.
13431319
1344-
This can be called as
1345-
1346-
.. function:: state_dict(*, prefix='', keep_vars=False)
1347-
:noindex:
1348-
1349-
.. function:: state_dict(destination, prefix='', keep_vars=False)
1350-
:noindex:
1320+
.. warning::
1321+
Currently ``state_dict()`` also accepts positional arguments for
1322+
``destination``, ``prefix`` and ``keep_vars`` in order. However,
1323+
this is being deprecated and keyword arguments will be enforced in
1324+
future releases.
13511325
13521326
.. warning::
1353-
The second signature is deprecated and should not be used. It's only
1354-
temporarily kept for backward compatibility and will be removed in
1355-
a future release. Use the first signature instead.
1327+
Please avoid the use of argument ``destination`` as it is not
1328+
designed for end-users.
13561329
13571330
Args:
1358-
destination (dict, optional): Deprecated. This dict is returned
1359-
with the module state saved in it. It should also have an
1360-
attribute ``_metadata: dict`` to save metadata of the module
1361-
state. If it's not provided, an ``OrderedDict`` is created and
1362-
returned. Default: ``None``
1331+
destination (dict, optional): If provided, the state of module will
1332+
be updated into the dict and the same object is returned.
1333+
Otherwise, an ``OrderedDict`` will be created and returned.
1334+
Default: ``None``.
13631335
prefix (str, optional): a prefix added to parameter and buffer
1364-
names to compose the keys in dict. Default: ``''``
1336+
names to compose the keys in state_dict. Default: ``''``.
13651337
keep_vars (bool, optional): by default the :class:`~torch.Tensor` s
13661338
returned in the state dict are detached from autograd. If it's
1367-
set to ``True``, detaching is not performed. Default: ``False``
1339+
set to ``True``, detaching will not be performed.
1340+
Default: ``False``.
13681341
13691342
Returns:
13701343
dict:
@@ -1377,30 +1350,37 @@ def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
13771350
13781351
"""
13791352

1380-
# TODO: positional args parsing is just for BC. Remove on transition to kwargs-only
1381-
warn_msg = []
1353+
# TODO: Remove `args` and the parsing logic when BC allows.
13821354
if len(args) > 0:
1383-
warn_msg.append('positional arguments')
13841355
if destination is None:
13851356
destination = args[0]
13861357
if len(args) > 1 and prefix == '':
13871358
prefix = args[1]
13881359
if len(args) > 2 and keep_vars is False:
13891360
keep_vars = args[2]
1361+
# DeprecationWarning is ignored by default
1362+
warnings.warn(
1363+
"Positional args are being deprecated, use kwargs instead. Refer to "
1364+
"https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict"
1365+
" for details.")
13901366

1391-
if destination is not None:
1392-
warn_msg.append('argument "destination"')
1393-
else:
1367+
if destination is None:
13941368
destination = OrderedDict()
13951369
destination._metadata = OrderedDict()
13961370

1397-
if warn_msg:
1398-
# DeprecationWarning is ignored by default
1399-
warnings.warn(
1400-
" and ".join(warn_msg) + " are deprecated. nn.Module.state_dict will not accept them in the future. "
1401-
"Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.")
1371+
local_metadata = dict(version=self._version)
1372+
if hasattr(destination, "_metadata"):
1373+
destination._metadata[prefix[:-1]] = local_metadata
14021374

1403-
return self._state_dict_impl(destination, prefix, keep_vars)
1375+
self._save_to_state_dict(destination, prefix, keep_vars)
1376+
for name, module in self._modules.items():
1377+
if module is not None:
1378+
module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
1379+
for hook in self._state_dict_hooks.values():
1380+
hook_result = hook(self, destination, prefix, local_metadata)
1381+
if hook_result is not None:
1382+
destination = hook_result
1383+
return destination
14041384

14051385
def _register_load_state_dict_pre_hook(self, hook, with_module=False):
14061386
r"""These hooks will be called with arguments: `state_dict`, `prefix`,

0 commit comments

Comments
 (0)