@@ -1296,75 +1296,48 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
1296
1296
if getattr (self .__class__ , "get_extra_state" , Module .get_extra_state ) is not Module .get_extra_state :
1297
1297
destination [extra_state_key ] = self .get_extra_state ()
1298
1298
1299
- def _state_dict_impl (self , destination , prefix , keep_vars ):
1300
- r"""Holds the actual implementation of
1301
- :meth:`~torch.nn.Module.state_dict`, with recursive calls for
1302
- descendants of this module.
1303
-
1304
- In rare cases, users can call this directly to provide a custom
1305
- :attr:`destination`.
1306
-
1307
- Args:
1308
- destination (dict): a dict where state will be stored
1309
- prefix (str): the prefix for parameters and buffers used in this
1310
- module
1311
- keep_vars (bool): whether NOT to return buffers detached from
1312
- autograd
1313
- """
1314
- destination ._metadata [prefix [:- 1 ]] = local_metadata = dict (version = self ._version )
1315
- self ._save_to_state_dict (destination , prefix , keep_vars )
1316
- for name , module in self ._modules .items ():
1317
- if module is not None :
1318
- module .state_dict (destination , prefix + name + '.' , keep_vars = keep_vars )
1319
- for hook in self ._state_dict_hooks .values ():
1320
- hook_result = hook (self , destination , prefix , local_metadata )
1321
- if hook_result is not None :
1322
- destination = hook_result
1323
- return destination
1324
-
1325
- # TODO: Deprecated, destination is becoming private. Remove this signature when BC allows
1326
- # See https://github.com/pytorch/pytorch/issues/72778#issuecomment-1039263869
1299
+ # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
1300
+ # back that same object. But if they pass nothing, an `OrederedDict` is created and returned.
1327
1301
T_destination = TypeVar ('T_destination' , bound = Dict [str , Any ])
1328
1302
1329
1303
@overload
1330
- def state_dict (self , destination : T_destination , prefix : str = ..., keep_vars : bool = ...) -> T_destination :
1304
+ def state_dict (self , * , destination : T_destination , prefix : str = ..., keep_vars : bool = ...) -> T_destination :
1331
1305
...
1332
1306
1333
1307
@overload
1334
1308
def state_dict (self , * , prefix : str = ..., keep_vars : bool = ...) -> Dict [str , Any ]:
1335
1309
...
1336
1310
1311
+ # TODO: Change `*args` to `*` and remove the copprespinding warning in docs when BC allows.
1312
+ # Also remove the logic for arg parsing together.
1337
1313
def state_dict (self , * args , destination = None , prefix = '' , keep_vars = False ):
1338
1314
r"""Returns a dictionary containing a whole state of the module.
1339
1315
1340
1316
Both parameters and persistent buffers (e.g. running averages) are
1341
1317
included. Keys are corresponding parameter and buffer names.
1342
1318
Parameters and buffers set to ``None`` are not included.
1343
1319
1344
- This can be called as
1345
-
1346
- .. function:: state_dict(*, prefix='', keep_vars=False)
1347
- :noindex:
1348
-
1349
- .. function:: state_dict(destination, prefix='', keep_vars=False)
1350
- :noindex:
1320
+ .. warning::
1321
+ Currently ``state_dict()`` also accepts positional arguments for
1322
+ ``destination``, ``prefix`` and ``keep_vars`` in order. However,
1323
+ this is being deprecated and keyword arguments will be enforced in
1324
+ future releases.
1351
1325
1352
1326
.. warning::
1353
- The second signature is deprecated and should not be used. It's only
1354
- temporarily kept for backward compatibility and will be removed in
1355
- a future release. Use the first signature instead.
1327
+ Please avoid the use of argument ``destination`` as it is not
1328
+ designed for end-users.
1356
1329
1357
1330
Args:
1358
- destination (dict, optional): Deprecated. This dict is returned
1359
- with the module state saved in it. It should also have an
1360
- attribute ``_metadata: dict`` to save metadata of the module
1361
- state. If it's not provided, an ``OrderedDict`` is created and
1362
- returned. Default: ``None``
1331
+ destination (dict, optional): If provided, the state of module will
1332
+ be updated into the dict and the same object is returned.
1333
+ Otherwise, an ``OrderedDict`` will be created and returned.
1334
+ Default: ``None``.
1363
1335
prefix (str, optional): a prefix added to parameter and buffer
1364
- names to compose the keys in dict . Default: ``''``
1336
+ names to compose the keys in state_dict . Default: ``''``.
1365
1337
keep_vars (bool, optional): by default the :class:`~torch.Tensor` s
1366
1338
returned in the state dict are detached from autograd. If it's
1367
- set to ``True``, detaching is not performed. Default: ``False``
1339
+ set to ``True``, detaching will not be performed.
1340
+ Default: ``False``.
1368
1341
1369
1342
Returns:
1370
1343
dict:
@@ -1377,30 +1350,37 @@ def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
1377
1350
1378
1351
"""
1379
1352
1380
- # TODO: positional args parsing is just for BC. Remove on transition to kwargs-only
1381
- warn_msg = []
1353
+ # TODO: Remove `args` and the parsing logic when BC allows.
1382
1354
if len (args ) > 0 :
1383
- warn_msg .append ('positional arguments' )
1384
1355
if destination is None :
1385
1356
destination = args [0 ]
1386
1357
if len (args ) > 1 and prefix == '' :
1387
1358
prefix = args [1 ]
1388
1359
if len (args ) > 2 and keep_vars is False :
1389
1360
keep_vars = args [2 ]
1361
+ # DeprecationWarning is ignored by default
1362
+ warnings .warn (
1363
+ "Positional args are being deprecated, use kwargs instead. Refer to "
1364
+ "https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict"
1365
+ " for details." )
1390
1366
1391
- if destination is not None :
1392
- warn_msg .append ('argument "destination"' )
1393
- else :
1367
+ if destination is None :
1394
1368
destination = OrderedDict ()
1395
1369
destination ._metadata = OrderedDict ()
1396
1370
1397
- if warn_msg :
1398
- # DeprecationWarning is ignored by default
1399
- warnings .warn (
1400
- " and " .join (warn_msg ) + " are deprecated. nn.Module.state_dict will not accept them in the future. "
1401
- "Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details." )
1371
+ local_metadata = dict (version = self ._version )
1372
+ if hasattr (destination , "_metadata" ):
1373
+ destination ._metadata [prefix [:- 1 ]] = local_metadata
1402
1374
1403
- return self ._state_dict_impl (destination , prefix , keep_vars )
1375
+ self ._save_to_state_dict (destination , prefix , keep_vars )
1376
+ for name , module in self ._modules .items ():
1377
+ if module is not None :
1378
+ module .state_dict (destination = destination , prefix = prefix + name + '.' , keep_vars = keep_vars )
1379
+ for hook in self ._state_dict_hooks .values ():
1380
+ hook_result = hook (self , destination , prefix , local_metadata )
1381
+ if hook_result is not None :
1382
+ destination = hook_result
1383
+ return destination
1404
1384
1405
1385
def _register_load_state_dict_pre_hook (self , hook , with_module = False ):
1406
1386
r"""These hooks will be called with arguments: `state_dict`, `prefix`,
0 commit comments