@@ -487,7 +487,9 @@ def _flatten_tensor_optim_state(
                 "Tensor optimizer state does not have same shape as its "
                 f"parameter: {tensor.shape} {shape}"
             )
-    # Flatten the tensor states
+    # Flatten the tensor states: we do not need to add any padding since the
+    # flattened optimizer state tensor is sharded via `_get_shard()`, which
+    # pads the shard as needed (just like for the flattened parameter)
     cpu_device = torch.device("cpu")
     tensors = [
         torch.flatten(state_value.to(cpu_device)) if state_value is not None
@@ -497,26 +499,11 @@ def _flatten_tensor_optim_state(
         for state_value, shape
         in zip(pos_dim_tensors, unflat_param_shapes)
     ]
-    padding = flat_param.num_padded
-    if padding > 0:
-        tensors.append(torch.zeros(padding, dtype=dtype, device=cpu_device))
     flat_tensor = torch.cat(tensors)
-    # `flat_tensor`'s shape should be 1D and less than or equal to the
-    # flattened parameter's shape (where the inequality is strict for positive
-    # padding)
-    if not flat_param._is_sharded:  # currently, only when world size is 1
-        # If the parameter is not sharded, then `_full_param_padded` is not
-        # used, so we skip the shape check
-        return flat_tensor
-    full_padded_dim = flat_param._full_param_padded.dim()  # type: ignore[attr-defined]
-    full_padded_shape = flat_param._full_param_padded.shape  # type: ignore[attr-defined]
-    assert flat_tensor.dim() == 1, \
-        f"`flat_tensor` should be 1D but got {flat_tensor.dim()} dims"
-    assert full_padded_dim == 1, \
-        f"`_full_param_padded` should be 1D but got {full_padded_dim} dims"
-    assert flat_tensor.shape[0] <= full_padded_shape[0], \
+    flat_param_shape = flat_param._orig_size  # type: ignore[attr-defined]
+    assert flat_tensor.shape == flat_param_shape, \
         f"tensor optim state: {flat_tensor.shape} " \
-        f"parameter: {full_padded_shape}"
+        f"flattened parameter: {flat_param_shape}"
     return flat_tensor
 
 
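For intuition, here is a minimal, self-contained sketch of the flatten-then-shard flow this change relies on. It is not the FSDP implementation itself: the flattened optimizer state carries no padding, and zero-padding is applied only when the flat tensor is sharded, analogous to what `_get_shard()` does for the flattened parameter. The helper names `flatten_tensor_state` and `shard_with_padding` are hypothetical, made up for illustration.

```python
import torch
import torch.nn.functional as F
from typing import List, Optional

def flatten_tensor_state(
    state_values: List[Optional[torch.Tensor]],
    shapes: List[torch.Size],
    dtype: torch.dtype,
) -> torch.Tensor:
    """Concatenate per-parameter state into one 1D tensor -- no padding here."""
    cpu_device = torch.device("cpu")
    tensors = [
        torch.flatten(v.to(cpu_device)) if v is not None
        else torch.zeros(s.numel(), dtype=dtype, device=cpu_device)
        for v, s in zip(state_values, shapes)
    ]
    return torch.cat(tensors)

def shard_with_padding(flat: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Chunk a 1D tensor and zero-pad the local chunk so every rank holds a
    uniformly sized shard (a simplified analogue of `_get_shard()`)."""
    chunks = flat.chunk(world_size)
    shard = chunks[rank].clone()
    num_to_pad = chunks[0].numel() - shard.numel()  # first chunk is the largest
    return F.pad(shard, [0, num_to_pad]) if num_to_pad > 0 else shard

# Two "parameters" with 3 and 4 elements -> flat length 7, matching the
# flattened parameter's original (unpadded) size, as the new assert checks.
states = [torch.ones(3), None]
shapes = [torch.Size([3]), torch.Size([2, 2])]
flat = flatten_tensor_state(states, shapes, torch.float32)
assert flat.shape == torch.Size([7])
# Sharding across 2 ranks pads rank 1's chunk (3 elements) up to 4.
shard0 = shard_with_padding(flat, rank=0, world_size=2)
shard1 = shard_with_padding(flat, rank=1, world_size=2)
assert shard0.numel() == shard1.numel() == 4
```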