Commit 1a52791

fegin authored and pytorchmergebot committed
[DSD] Correctly handle shared parameters for optimizer state_dict (pytorch#128685)
Fixes pytorch#128011. See the discussion in pytorch#128076.

The current implementation of `set_optimizer_state_dict()` assumes that every FQN returned by `_get_fqns()` exists in the optimizer state_dict. This is not true when the model has shared parameters: only one FQN of a shared parameter appears in the optimizer state_dict. This PR addresses the issue.

Differential Revision: [D58573487](https://our.internmc.facebook.com/intern/diff/D58573487/)

Pull Request resolved: pytorch#128685

Approved by: https://github.com/LucasLLC
1 parent d77a1aa commit 1a52791
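
To make the failure mode concrete, here is a minimal, self-contained sketch (not taken from the PR) of why tied weights break the assumption: two FQNs refer to one tensor, but the optimizer tracks that tensor only once, so an FQN-keyed optimizer state_dict can carry state under only one of the names.

# Illustrative sketch, not PR code: tied weights yield two FQNs for a single
# parameter, while the optimizer holds that parameter only once.
import torch
import torch.nn as nn

class Tied(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embedding = nn.Embedding(8, 4)
        self.decoder = nn.Linear(4, 8)
        self.decoder.weight = self.embedding.weight  # tie the weights

model = Tied()

# Two fully qualified names point at the same tensor object.
fqns = [name for name, _ in model.named_parameters(remove_duplicate=False)]
print(fqns)  # ['embedding.weight', 'decoder.weight', 'decoder.bias']

# The optimizer deduplicates shared parameters, so it tracks 2 tensors, not 3.
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
print(sum(len(g["params"]) for g in optim.param_groups))  # 2

# Consequently an FQN-keyed optimizer state_dict can hold state under only one
# of 'embedding.weight' / 'decoder.weight', which is what the fixed
# set_optimizer_state_dict() has to tolerate.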

File tree

2 files changed: +63, -6 lines

test/distributed/checkpoint/test_state_dict.py

Lines changed: 27 additions & 0 deletions
@@ -851,6 +851,33 @@ def test_deprecate_fsdp_api(self) -> None:
         ):
             get_model_state_dict(model)
 
+    @with_comms
+    @skip_if_lt_x_gpu(2)
+    def test_shared_weight(self):
+        class TiedEmbeddingModel(nn.Module):
+            def __init__(self, vocab_size, embedding_dim):
+                super().__init__()
+                self.embedding = nn.Embedding(vocab_size, embedding_dim)
+                self.decoder = nn.Linear(embedding_dim, vocab_size)
+                self.decoder.weight = self.embedding.weight  # Tying weights
+
+            def forward(self, input):
+                input = (input * 10).to(torch.int)
+                embedded = self.embedding(input)
+                output = self.decoder(embedded)
+                return output
+
+        def init_model_optim():
+            device_mesh = init_device_mesh("cuda", (self.world_size,))
+            orig_model = TiedEmbeddingModel(10000, 300).to(torch.device("cuda"))
+            orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3)
+            copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3)
+            dist_model = FSDP(copy.deepcopy(orig_model), device_mesh=device_mesh)
+            dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-3)
+            return orig_model, orig_optim, copy_optim, dist_model, dist_optim
+
+        self._test_save_load(init_model_optim)
+
 
 class TestNoComm(MultiProcessTestCase):
     def setUp(self) -> None:
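
For context, the following is a hedged sketch (not part of the test file) of roughly the save/load round trip that `_test_save_load` drives for a tied-weight model through the public DSD API; the exact harness details differ, and a process group plus CUDA GPUs are assumed.

# Hedged sketch, not the actual test harness: an FSDP-wrapped tied-weight model
# whose FQN-keyed state_dicts are taken and loaded back via the public DSD API.
# Assumes torch.distributed is initialized (e.g. under torchrun) with CUDA GPUs.
import copy

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def round_trip(orig_model: nn.Module) -> None:
    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
    dist_model = FSDP(copy.deepcopy(orig_model).cuda(), device_mesh=mesh)
    dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-3)

    # The optimizer state_dict is keyed by FQN; a tied weight appears under
    # only one of its FQNs even though two module attributes share the tensor.
    model_sd, optim_sd = get_state_dict(dist_model, dist_optim)

    # With the fix, loading no longer assumes that every FQN of the shared
    # parameter is present in optim_sd.
    set_state_dict(
        dist_model,
        dist_optim,
        model_state_dict=model_sd,
        optim_state_dict=optim_sd,
    )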

torch/distributed/checkpoint/state_dict.py

Lines changed: 36 additions & 6 deletions
@@ -153,6 +153,9 @@ class _StateDictInfo(StateDictOptions):
     fqn_param_mapping: Dict[
         Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
     ] = field(default_factory=dict)
+    shared_params_mapping: Dict[
+        Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
+    ] = field(default_factory=dict)
     submodule_prefixes: Set[str] = field(default_factory=set)
     handle_model: bool = True
     handle_optim: bool = True
@@ -286,14 +289,29 @@ def _verify_options(
     fqn_param_mapping: Dict[
         Union[str, torch.Tensor], Union[Set[str], torch.Tensor]
     ] = {}
+    shared_params_mapping: Dict[
+        Union[str, torch.Tensor], Union[Set[str], torch.Tensor]
+    ] = {}
     for name, param in _iterate_valid_model_state(model):
+        if isinstance(param, _EXTRA_STATE):
+            continue
+
         fqns = _get_fqns(model, name)
-        if not isinstance(param, _EXTRA_STATE):
-            fqn_param_mapping[param] = fqns
+        fqn = fqn_param_mapping.get(param, None)
+        if fqn is not None:
+            cast(Set[str], fqn_param_mapping[param]).update(fqns)
+            shared_params_mapping[param] = fqn_param_mapping[param]
+        else:
+            # We need to do copy as _get_fqns is lru_cached
+            fqn_param_mapping[param] = fqns.copy()
         for fqn in fqns:
             if not isinstance(param, _EXTRA_STATE):
                 fqn_param_mapping[fqn] = param
 
+    for param_, fqns_ in list(shared_params_mapping.items()):
+        for fqn in fqns_:
+            shared_params_mapping[fqn] = cast(torch.Tensor, param_)
+
     submodule_prefixes: Set[str] = set()
     if submodules:
         submodules = set(submodules)
@@ -361,6 +379,7 @@ def fsdp_state_dict_type_without_warning(
     return _StateDictInfo(
         **asdict(options),
         fqn_param_mapping=fqn_param_mapping,
+        shared_params_mapping=shared_params_mapping,
         submodule_prefixes=submodule_prefixes,
         fsdp_context=fsdp_context,
         fsdp_modules=cast(List[nn.Module], fsdp_modules),
@@ -450,7 +469,7 @@ def _get_model_state_dict(
 
     for key in list(state_dict.keys()):
         fqns = _get_fqns(model, key)
-        assert len(fqns) == 1
+        assert len(fqns) == 1, (key, fqns)
         fqn = next(iter(fqns))
         if fqn != key:
             # As we only support FSDP, DDP, and TP, the only cases are
@@ -797,6 +816,19 @@ def _split_optim_state_dict(
         pg_state.append({_PARAMS: []})
         for param in param_group[_PARAMS]:
             for fqn in info.fqn_param_mapping[param]:
+                if fqn in info.shared_params_mapping:
+                    in_params = False
+                    for loaded_param_group in cast(
+                        ListDictValueType, optim_state_dict[_PG]
+                    ):
+                        if fqn in cast(List[str], loaded_param_group[_PARAMS]):
+                            in_params = True
+                            break
+                else:
+                    in_params = True
+                if not in_params:
+                    continue
+
                 params = pg_state[-1][_PARAMS]
                 assert isinstance(params, list)
                 params.append(fqn)
@@ -805,9 +837,7 @@ def _split_optim_state_dict(
                 for loaded_param_group in cast(
                     ListDictValueType, optim_state_dict[_PG]
                 ):
-                    params = loaded_param_group[_PARAMS]
-                    assert isinstance(params, list)
-                    if fqn in params:
+                    if fqn in cast(List[str], loaded_param_group[_PARAMS]):
                         pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1
 
     for param_group in cast(ListDictValueType, optim_state_dict[_PG]):
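
Putting the two pieces together: `_verify_options` now records, for every parameter reached through more than one FQN, both the merged FQN set and an FQN-keyed `shared_params_mapping`, and `_split_optim_state_dict` consults that mapping so an FQN of a shared parameter is skipped when no loaded param_group contains it. Below is a hedged, plain-Python sketch of that bookkeeping; it uses `named_parameters(remove_duplicate=False)` as a stand-in for the library's `_iterate_valid_model_state`/`_get_fqns` helpers and is not the library code.

# Simplified, illustrative re-implementation of the bookkeeping added above.
from typing import Dict, List, Set, Union

import torch
import torch.nn as nn


def build_mappings(model: nn.Module):
    # Values are either a set of FQNs (when keyed by parameter) or a parameter
    # (when keyed by FQN), mirroring the structure used in _verify_options.
    fqn_param_mapping: Dict[Union[str, torch.Tensor], Union[Set[str], torch.Tensor]] = {}
    shared_params_mapping: Dict[Union[str, torch.Tensor], Union[Set[str], torch.Tensor]] = {}
    for name, param in model.named_parameters(remove_duplicate=False):
        fqns = {name}  # stand-in for _get_fqns(model, name)
        if param in fqn_param_mapping:
            # A parameter already seen under another FQN: merge the FQN sets
            # and remember the parameter as shared.
            fqn_param_mapping[param].update(fqns)
            shared_params_mapping[param] = fqn_param_mapping[param]
        else:
            fqn_param_mapping[param] = set(fqns)
        for fqn in fqns:
            fqn_param_mapping[fqn] = param
    # Also key shared_params_mapping by FQN so "fqn in shared_params_mapping"
    # is a cheap membership test while splitting the optimizer state_dict.
    for param_, fqns_ in list(shared_params_mapping.items()):
        for fqn in fqns_:
            shared_params_mapping[fqn] = param_
    return fqn_param_mapping, shared_params_mapping


def can_skip_fqn(
    fqn: str,
    shared_params_mapping: Dict[Union[str, torch.Tensor], Union[Set[str], torch.Tensor]],
    loaded_param_groups: List[Dict[str, List[str]]],
) -> bool:
    # Mirrors the new check in _split_optim_state_dict: an FQN of a shared
    # parameter is dropped only when no loaded param_group ("params" key, the
    # _PARAMS constant in the real code) lists it.
    if fqn not in shared_params_mapping:
        return False
    return not any(fqn in pg["params"] for pg in loaded_param_groups)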
