Commit 8412f20

Andrew Gu authored and pytorchmergebot committed
[FSDP] Remove unneeded padding logic for optim state dict
Pull Request resolved: pytorch#78208
Approved by: https://github.com/rohan-varma
1 parent cdb009c commit 8412f20

2 files changed: +10, -27 lines

test/distributed/fsdp/test_fsdp_optim_state.py

Lines changed: 4 additions & 8 deletions
@@ -446,8 +446,7 @@ def test_full_optim_state_dict_nested(
             full_osd, ref_osd, check_same_param_keys=check_same_param_keys,
         )

-    # Require 4 GPUs since we test halving the world size
-    @skip_if_lt_x_gpu(4)
+    @skip_if_lt_x_gpu(2)
     @parametrize("use_multiple_param_groups", [False, True])
     @parametrize("wrap_alt", [False, True])
     @parametrize("halve_world_size", [False, True])
@@ -467,8 +466,7 @@ def test_shard_full_optim_state_dict_nested(
             wrap_alt=wrap_alt,
         )

-    # Require 4 GPUs since we test halving the world size
-    @skip_if_lt_x_gpu(4)
+    @skip_if_lt_x_gpu(2)
     def test_shard_full_optim_state_dict_transformer(self) -> None:
         """Tests :meth:`shard_full_optim_state_dict` for an FSDP-root
         transformer model with shared parameters."""
@@ -478,8 +476,7 @@ def test_shard_full_optim_state_dict_transformer(self) -> None:
            osd_comm_method=_OSDCommMethod.BROADCAST_OBJECT_LIST,
        )

-    # Require 4 GPUs since we test halving the world size
-    @skip_if_lt_x_gpu(4)
+    @skip_if_lt_x_gpu(2)
     @parametrize("use_multiple_param_groups", [False, True])
     @parametrize("wrap_alt", [False, True])
     @parametrize("halve_world_size", [False, True])
@@ -499,8 +496,7 @@ def test_scatter_full_optim_state_dict_nested(
             wrap_alt=wrap_alt,
         )

-    # Require 4 GPUs since we test halving the world size
-    @skip_if_lt_x_gpu(4)
+    @skip_if_lt_x_gpu(2)
     def test_scatter_full_optim_state_dict_transformer(self) -> None:
         """Tests :meth:`scatter_full_optim_state_dict` for an FSDP-root
         transformer model with shared parameters."""
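
The decorator change above lowers each test's GPU requirement from 4 to 2. As a rough illustration of what a `skip_if_lt_x_gpu(n)`-style gate does (the real decorator lives in `torch.testing._internal.common_distributed`; the `_require_n_gpus` helper below is a hypothetical stand-in, not the actual implementation):

    import functools
    import unittest

    import torch


    def _require_n_gpus(n):
        # Hypothetical stand-in for `skip_if_lt_x_gpu(n)`: skip the wrapped
        # test when fewer than `n` CUDA devices are visible.
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                if torch.cuda.device_count() < n:
                    raise unittest.SkipTest(f"requires at least {n} GPUs")
                return fn(*args, **kwargs)
            return wrapper
        return decorator


    class ExampleTest(unittest.TestCase):
        @_require_n_gpus(2)  # mirrors the new `@skip_if_lt_x_gpu(2)` gating above
        def test_runs_on_two_gpus(self):
            self.assertGreaterEqual(torch.cuda.device_count(), 2)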

torch/distributed/fsdp/_optim_utils.py

Lines changed: 6 additions & 19 deletions
@@ -487,7 +487,9 @@ def _flatten_tensor_optim_state(
                 "Tensor optimizer state does not have same shape as its "
                 f"parameter: {tensor.shape} {shape}"
             )
-    # Flatten the tensor states
+    # Flatten the tensor states: we do not need to add any padding since the
+    # flattened optimizer state tensor is sharded via `_get_shard()`, which
+    # pads the shard as needed (just like for the flattened parameter)
     cpu_device = torch.device("cpu")
     tensors = [
         torch.flatten(state_value.to(cpu_device)) if state_value is not None
@@ -497,26 +499,11 @@ def _flatten_tensor_optim_state(
         for state_value, shape
         in zip(pos_dim_tensors, unflat_param_shapes)
     ]
-    padding = flat_param.num_padded
-    if padding > 0:
-        tensors.append(torch.zeros(padding, dtype=dtype, device=cpu_device))
     flat_tensor = torch.cat(tensors)
-    # `flat_tensor`'s shape should be 1D and less than or equal to the
-    # flattened parameter's shape (where the inequality is strict for positive
-    # padding)
-    if not flat_param._is_sharded:  # currently, only when world size is 1
-        # If the parameter is not sharded, then `_full_param_padded` is not
-        # used, so we skip the shape check
-        return flat_tensor
-    full_padded_dim = flat_param._full_param_padded.dim()  # type: ignore[attr-defined]
-    full_padded_shape = flat_param._full_param_padded.shape  # type: ignore[attr-defined]
-    assert flat_tensor.dim() == 1, \
-        f"`flat_tensor` should be 1D but got {flat_tensor.dim()} dims"
-    assert full_padded_dim == 1, \
-        f"`_full_param_padded` should be 1D but got {full_padded_dim} dims"
-    assert flat_tensor.shape[0] <= full_padded_shape[0], \
+    flat_param_shape = flat_param._orig_size  # type: ignore[attr-defined]
+    assert flat_tensor.shape == flat_param_shape, \
         f"tensor optim state: {flat_tensor.shape} " \
-        f"parameter: {full_padded_shape}"
+        f"flattened parameter: {flat_param_shape}"
     return flat_tensor
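
To make the rationale concrete, here is a small standalone sketch (not the actual FSDP internals) of the intended division of labor: flattening concatenates the per-parameter optimizer state with no padding, and padding is applied only when the flat tensor is chunked into per-rank shards, analogous to what `_get_shard()` does for the flattened parameter. The `flatten_optim_state` and `shard_flat_tensor` helpers are hypothetical names used only for illustration.

    from typing import List, Tuple

    import torch
    import torch.nn.functional as F


    def flatten_optim_state(state_tensors: List[torch.Tensor]) -> torch.Tensor:
        # Flatten and concatenate per-parameter optimizer state with no padding,
        # mirroring the simplified behavior of `_flatten_tensor_optim_state`.
        cpu_device = torch.device("cpu")
        return torch.cat([torch.flatten(t.to(cpu_device)) for t in state_tensors])


    def shard_flat_tensor(
        flat_tensor: torch.Tensor, rank: int, world_size: int
    ) -> Tuple[torch.Tensor, int]:
        # Chunk a 1D flat tensor across ranks and pad the local chunk to the
        # uniform chunk size -- analogous to what `_get_shard()` does for the
        # flattened parameter (and now for the flat optimizer state as well).
        chunks = flat_tensor.chunk(world_size)
        chunk = chunks[rank] if rank < len(chunks) else flat_tensor.new_zeros(0)
        chunk_size = chunks[0].numel()  # the first chunk is the largest
        num_to_pad = chunk_size - chunk.numel()
        shard = F.pad(chunk, [0, num_to_pad]) if num_to_pad > 0 else chunk.clone()
        return shard, num_to_pad


    # Example: Adam `exp_avg` state for two parameters, world size 2.
    exp_avg = [torch.randn(3, 4), torch.randn(5)]
    flat = flatten_optim_state(exp_avg)  # 17 elements, no padding added here
    shard0, pad0 = shard_flat_tensor(flat, rank=0, world_size=2)  # 9 elements, 0 padding
    shard1, pad1 = shard_flat_tensor(flat, rank=1, world_size=2)  # 8 elements padded to 9

The assert kept in the diff reflects this split: the unpadded flat state must match the flattened parameter's original size (`flat_param._orig_size`) exactly, since any padding is deferred to shard time.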
