@@ -1459,6 +1459,19 @@ def _create_inverse_indices_permute_indices(
                 inverse_indices[1].device,
             )
 
+    def _is_optimizer_enabled(
+        self,
+        has_local_optimizer: bool,
+        env: ShardingEnv,
+        device: Optional[torch.device],
+    ) -> bool:
+        flag = torch.tensor(
+            [has_local_optimizer], dtype=torch.uint8, device=device
+        )  # example: True
+        # Reduce with MAX to check if any process has True
+        dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=env.process_group)
+        return bool(flag.item())
+
     # pyre-ignore [14]
     def input_dist(
         self,
@@ -1698,10 +1711,17 @@ def update_shards(
             return
 
         current_state = self.state_dict()
-        has_optimizer = len(self._optim._optims) > 0 and all(
+        has_local_optimizer = len(self._optim._optims) > 0 and all(
             len(i) > 0 for i in self._optim.state_dict()["state"].values()
         )
 
+        # Communicate optimizer state across all ranks: if one rank owns all tables
+        # and the other ranks own none, transferring weights to an initially empty rank
+        # creates inconsistent state, because that rank has no optimizer state and
+        # therefore computes the tensor splits incorrectly.
+
+        has_optimizer = self._is_optimizer_enabled(has_local_optimizer, env, device)
+
         # TODO: Saving lookups tensors to CPU to eventually avoid recreating them completely again
         # TODO: Ensure lookup tensors are actually being deleted
         for _, lookup in enumerate(self._lookups):
@@ -1715,7 +1735,7 @@ def update_shards(
         max_dim_0, max_dim_1 = get_largest_dims_from_sharding_plan_updates(
             changed_sharding_params
         )
-        old_optimizer_state = self._optim.state_dict() if has_optimizer else None
+        old_optimizer_state = self._optim.state_dict() if has_local_optimizer else None
 
         local_shard_names_by_src_rank, local_output_tensor = shards_all_to_all(
             module=self,
@@ -1727,6 +1747,7 @@ def update_shards(
             max_dim_0=max_dim_0,
             max_dim_1=max_dim_1,
             optimizer_state=old_optimizer_state,
+            has_optimizer=has_optimizer,
         )
 
         for name, param in changed_sharding_params.items():
@@ -1791,30 +1812,25 @@ def update_shards(
         self._optim: CombinedOptimizer = CombinedOptimizer(optims)
 
         if has_optimizer:
-            split_index = len(local_output_tensor) // 2
-            local_weight_tensors = local_output_tensor[:split_index]
-            local_optimizer_tensors = local_output_tensor[split_index:]
-            # Modifies new_opt_state in place and returns it
             optimizer_state = update_optimizer_state_post_resharding(
                 old_opt_state=old_optimizer_state,  # pyre-ignore
                 new_opt_state=copy.deepcopy(self._optim.state_dict()),
                 ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
-                output_tensor=local_optimizer_tensors,
+                output_tensor=local_output_tensor,
                 max_dim_0=max_dim_0,
+                extend_shard_name=self.extend_shard_name,
             )
-
             self._optim.load_state_dict(optimizer_state)
-        else:
-            local_weight_tensors = local_output_tensor
 
         current_state = update_state_dict_post_resharding(
             state_dict=current_state,
             ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
-            output_tensor=local_weight_tensors,
+            output_tensor=local_output_tensor,
             new_sharding_params=changed_sharding_params,
             curr_rank=dist.get_rank(),
             extend_shard_name=self.extend_shard_name,
             max_dim_0=max_dim_0,
+            has_optimizer=has_optimizer,
         )
 
         self.load_state_dict(current_state)
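
For reference, a minimal standalone sketch of the rank-consensus check that `_is_optimizer_enabled` performs: the local boolean is encoded as a `uint8` tensor and all-reduced with `ReduceOp.MAX`, which acts as a logical OR across ranks. The helper name and the explicit process-group/device arguments below are illustrative assumptions, not part of this change.

```python
import torch
import torch.distributed as dist


def any_rank_has_optimizer_state(
    has_local_optimizer: bool,
    process_group: dist.ProcessGroup,
    device: torch.device,
) -> bool:
    # Encode the local boolean as a uint8 flag so it can participate in a collective.
    flag = torch.tensor([has_local_optimizer], dtype=torch.uint8, device=device)
    # MAX over {0, 1} is a logical OR: the result is 1 if any rank has optimizer state.
    dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=process_group)
    return bool(flag.item())
```

Using the globally agreed `has_optimizer` (rather than `has_local_optimizer`) keeps ranks that currently own no tables consistent with the rest when `shards_all_to_all` and the post-resharding state/optimizer updates compute their tensor splits.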