Commit ec19059

[DSD] Correctly handle shared parameters for optimizer state_dict (#1… (pytorch#129252)

[DSD] Correctly handle shared parameters for optimizer state_dict (pytorch#128685)

Fixes pytorch#128011. See the discussion in pytorch#128076.

The current implementation of `set_optimizer_state_dict()` assumes that all the FQNs returned by `_get_fqns()` must exist in the optimizer state_dict. This is not true if the model has shared parameters; in such a case, only one FQN of the shared parameters will appear in the optimizer state_dict. This PR addresses the issue.

Differential Revision: [D58573487](https://our.internmc.facebook.com/intern/diff/D58573487/)
Pull Request resolved: pytorch#128685
Approved by: https://github.com/LucasLLC
(cherry picked from commit 1a52791)
1 parent 04e98d3 commit ec19059
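
Before the diffs, a small non-distributed sketch (not part of the commit) of why tied weights break the old assumption may help: two fully qualified names (FQNs) resolve to the same tensor, but `nn.Module.parameters()` deduplicates shared tensors, so the optimizer, and therefore its state_dict, tracks the tied weight only once. The model below mirrors the `TiedEmbeddingModel` added in the test; the printed values follow standard `nn.Module` behavior.

import torch
import torch.nn as nn

class TiedEmbeddingModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.decoder = nn.Linear(embedding_dim, vocab_size)
        self.decoder.weight = self.embedding.weight  # tie the weights

model = TiedEmbeddingModel(100, 16)

# Two FQNs refer to the same tensor object ...
tied_fqns = [
    name
    for name, param in model.named_parameters(remove_duplicate=False)
    if param is model.embedding.weight
]
print(tied_fqns)  # ['embedding.weight', 'decoder.weight']

# ... but parameters() deduplicates shared tensors, so the optimizer (and its
# state_dict) only ever sees one entry for the tied weight.
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
print(len(list(model.parameters())))         # 2: the tied weight and decoder.bias
print(len(optim.param_groups[0]["params"]))  # 2 as well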

File tree

2 files changed: 63 additions & 6 deletions


test/distributed/checkpoint/test_state_dict.py

Lines changed: 27 additions & 0 deletions
@@ -813,6 +813,33 @@ def test_deprecate_fsdp_api(self) -> None:
         ):
             get_model_state_dict(model)
 
+    @with_comms
+    @skip_if_lt_x_gpu(2)
+    def test_shared_weight(self):
+        class TiedEmbeddingModel(nn.Module):
+            def __init__(self, vocab_size, embedding_dim):
+                super().__init__()
+                self.embedding = nn.Embedding(vocab_size, embedding_dim)
+                self.decoder = nn.Linear(embedding_dim, vocab_size)
+                self.decoder.weight = self.embedding.weight  # Tying weights
+
+            def forward(self, input):
+                input = (input * 10).to(torch.int)
+                embedded = self.embedding(input)
+                output = self.decoder(embedded)
+                return output
+
+        def init_model_optim():
+            device_mesh = init_device_mesh("cuda", (self.world_size,))
+            orig_model = TiedEmbeddingModel(10000, 300).to(torch.device("cuda"))
+            orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3)
+            copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3)
+            dist_model = FSDP(copy.deepcopy(orig_model), device_mesh=device_mesh)
+            dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-3)
+            return orig_model, orig_optim, copy_optim, dist_model, dist_optim
+
+        self._test_save_load(init_model_optim)
+
 
 class TestNoComm(MultiProcessTestCase):
     def setUp(self) -> None:
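
The new test delegates to the suite's `_test_save_load()` helper, which exercises saving and loading of the model and optimizer state_dicts for the FSDP-wrapped copy. As a rough sketch of the public API path the fix targets (not part of the commit, and only an assumption about typical usage), the flow looks roughly like the following; it assumes `torch.distributed` and the CUDA devices are already initialized and reuses `TiedEmbeddingModel` from the diff above.

import torch
from torch.distributed.checkpoint.state_dict import (
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

device_mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))
model = FSDP(TiedEmbeddingModel(10000, 300).cuda(), device_mesh=device_mesh)
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)

# FQN-keyed optimizer state_dict suitable for checkpointing; with tied weights
# only one of the two FQNs of the shared parameter appears here.
osd = get_optimizer_state_dict(model, optim)

# Restore into a freshly constructed optimizer. Before this fix,
# set_optimizer_state_dict() could fail because it expected every FQN of the
# shared parameter to be present in osd.
new_optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
set_optimizer_state_dict(model, new_optim, optim_state_dict=osd)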

torch/distributed/checkpoint/state_dict.py

Lines changed: 36 additions & 6 deletions
@@ -151,6 +151,9 @@ class _StateDictInfo(StateDictOptions):
     fqn_param_mapping: Dict[
         Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
     ] = field(default_factory=dict)
+    shared_params_mapping: Dict[
+        Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
+    ] = field(default_factory=dict)
     submodule_prefixes: Set[str] = field(default_factory=set)
     handle_model: bool = True
     handle_optim: bool = True
@@ -284,14 +287,29 @@ def _verify_options(
     fqn_param_mapping: Dict[
         Union[str, torch.Tensor], Union[Set[str], torch.Tensor]
     ] = {}
+    shared_params_mapping: Dict[
+        Union[str, torch.Tensor], Union[Set[str], torch.Tensor]
+    ] = {}
     for name, param in _iterate_valid_model_state(model):
+        if isinstance(param, _EXTRA_STATE):
+            continue
+
         fqns = _get_fqns(model, name)
-        if not isinstance(param, _EXTRA_STATE):
-            fqn_param_mapping[param] = fqns
+        fqn = fqn_param_mapping.get(param, None)
+        if fqn is not None:
+            cast(Set[str], fqn_param_mapping[param]).update(fqns)
+            shared_params_mapping[param] = fqn_param_mapping[param]
+        else:
+            # We need to do copy as _get_fqns is lru_cached
+            fqn_param_mapping[param] = fqns.copy()
         for fqn in fqns:
             if not isinstance(param, _EXTRA_STATE):
                 fqn_param_mapping[fqn] = param
 
+    for param_, fqns_ in list(shared_params_mapping.items()):
+        for fqn in fqns_:
+            shared_params_mapping[fqn] = cast(torch.Tensor, param_)
+
     submodule_prefixes: Set[str] = set()
     if submodules:
         submodules = set(submodules)
@@ -359,6 +377,7 @@ def fsdp_state_dict_type_without_warning(
     return _StateDictInfo(
         **asdict(options),
         fqn_param_mapping=fqn_param_mapping,
+        shared_params_mapping=shared_params_mapping,
         submodule_prefixes=submodule_prefixes,
         fsdp_context=fsdp_context,
         fsdp_modules=cast(List[nn.Module], fsdp_modules),
@@ -448,7 +467,7 @@ def _get_model_state_dict(
 
     for key in list(state_dict.keys()):
         fqns = _get_fqns(model, key)
-        assert len(fqns) == 1
+        assert len(fqns) == 1, (key, fqns)
         fqn = next(iter(fqns))
         if fqn != key:
             # As we only support FSDP, DDP, and TP, the only cases are
@@ -795,6 +814,19 @@ def _split_optim_state_dict(
         pg_state.append({_PARAMS: []})
         for param in param_group[_PARAMS]:
             for fqn in info.fqn_param_mapping[param]:
+                if fqn in info.shared_params_mapping:
+                    in_params = False
+                    for loaded_param_group in cast(
+                        ListDictValueType, optim_state_dict[_PG]
+                    ):
+                        if fqn in cast(List[str], loaded_param_group[_PARAMS]):
+                            in_params = True
+                            break
+                else:
+                    in_params = True
+                if not in_params:
+                    continue
+
                 params = pg_state[-1][_PARAMS]
                 assert isinstance(params, list)
                 params.append(fqn)
@@ -803,9 +835,7 @@ def _split_optim_state_dict(
                 for loaded_param_group in cast(
                     ListDictValueType, optim_state_dict[_PG]
                 ):
-                    params = loaded_param_group[_PARAMS]
-                    assert isinstance(params, list)
-                    if fqn in params:
+                    if fqn in cast(List[str], loaded_param_group[_PARAMS]):
                         pg_mapping[id(loaded_param_group)] = len(return_osd[_PG]) - 1
 
     for param_group in cast(ListDictValueType, optim_state_dict[_PG]):
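
Putting the pieces together: `_verify_options()` now merges the FQN set whenever a parameter is encountered again under a new name and records such parameters in `shared_params_mapping`, which is then also keyed by FQN; `_split_optim_state_dict()` uses that mapping to skip an FQN of a shared parameter when the loaded optimizer state_dict does not contain it. The following is a simplified, standalone sketch of the mapping construction only; it is not the library code, uses plain `named_parameters()` in place of `_iterate_valid_model_state()`/`_get_fqns()`, and omits the FSDP and extra-state handling.

from typing import Dict, Set, Union

import torch
import torch.nn as nn


def build_mappings(model: nn.Module):
    # Maps each parameter tensor to its set of FQNs, and each FQN back to its tensor.
    fqn_param_mapping: Dict[Union[str, torch.Tensor], Union[Set[str], torch.Tensor]] = {}
    # Only populated for tensors reachable under more than one FQN.
    shared_params_mapping: Dict[Union[str, torch.Tensor], Union[Set[str], torch.Tensor]] = {}

    for name, param in model.named_parameters(remove_duplicate=False):
        seen = fqn_param_mapping.get(param, None)
        if seen is not None:
            # Second (or later) name for a tensor we already saw: merge the FQNs
            # and remember the tensor as shared.
            seen.add(name)
            shared_params_mapping[param] = seen
        else:
            fqn_param_mapping[param] = {name}
        fqn_param_mapping[name] = param

    # Also key the shared mapping by FQN so "is this FQN one name of a shared
    # parameter?" becomes a plain dict lookup, as _split_optim_state_dict() needs.
    for param_, fqns_ in list(shared_params_mapping.items()):
        for fqn in fqns_:
            shared_params_mapping[fqn] = param_

    return fqn_param_mapping, shared_params_mapping


class TinyTied(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(100, 16)
        self.decoder = nn.Linear(16, 100)
        self.decoder.weight = self.embedding.weight  # tie


model = TinyTied()
fqn_map, shared_map = build_mappings(model)
print(fqn_map[model.embedding.weight])  # {'embedding.weight', 'decoder.weight'}
print("decoder.weight" in shared_map)   # True: one name of the shared tensor
print("decoder.bias" in shared_map)     # False: not shared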
