
Commit 51a7921

isururanawaka authored and facebook-github-bot committed
Refactoring Resharding API (#3245)
Summary: Pull Request resolved: #3245 Differential Revision: D79023990
1 parent 5d388f3 commit 51a7921

File tree

7 files changed

+292 −127 lines changed


torchrec/distributed/embeddingbag.py

Lines changed: 27 additions & 11 deletions
@@ -1459,6 +1459,19 @@ def _create_inverse_indices_permute_indices(
             inverse_indices[1].device,
         )

+    def _is_optimizer_enabled(
+        self,
+        has_local_optimizer: bool,
+        env: ShardingEnv,
+        device: Optional[torch.device],
+    ) -> bool:
+        flag = torch.tensor(
+            [has_local_optimizer], dtype=torch.uint8, device=device
+        )  # example: True
+        # Reduce with MAX to check if any process has True
+        dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=env.process_group)
+        return bool(flag.item())
+
     # pyre-ignore [14]
     def input_dist(
         self,
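
The new `_is_optimizer_enabled` helper turns a per-rank boolean into a global one: each rank contributes a uint8 flag, and an all-reduce with `ReduceOp.MAX` acts as a logical OR, so every rank sees True if any rank holds local optimizer state. The standalone Python sketch below shows the same pattern outside of TorchRec; the function name `any_rank_true` and the gloo/torchrun setup are illustrative assumptions, not part of this commit.

# Minimal sketch of the cross-rank "any" check, independent of TorchRec.
# Run with e.g.: torchrun --nproc_per_node=2 any_rank_flag.py
import torch
import torch.distributed as dist


def any_rank_true(local_flag: bool, group=None, device=None) -> bool:
    # Encode the local boolean as a 0/1 tensor so it can be reduced.
    flag = torch.tensor([local_flag], dtype=torch.uint8, device=device)
    # MAX over 0/1 values is a logical OR across all ranks.
    dist.all_reduce(flag, op=dist.ReduceOp.MAX, group=group)
    return bool(flag.item())


if __name__ == "__main__":
    dist.init_process_group("gloo")  # gloo keeps the sketch CPU-only
    # Pretend only rank 0 holds local optimizer state.
    has_local_optimizer = dist.get_rank() == 0
    print(dist.get_rank(), any_rank_true(has_local_optimizer))  # every rank prints True
    dist.destroy_process_group()
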
@@ -1698,10 +1711,17 @@ def update_shards(
             return

         current_state = self.state_dict()
-        has_optimizer = len(self._optim._optims) > 0 and all(
+        has_local_optimizer = len(self._optim._optims) > 0 and all(
             len(i) > 0 for i in self._optim.state_dict()["state"].values()
         )

+        # Communicate optimizer state across all ranks: if one rank owns all tables
+        # and the other ranks own none, transferring weights to an initially empty
+        # rank would create inconsistent state, since that rank has no optimizer
+        # state and would therefore compute the tensor splits incorrectly.
+
+        has_optimizer = self._is_optimizer_enabled(has_local_optimizer, env, device)
+
         # TODO: Saving lookups tensors to CPU to eventually avoid recreating them completely again
         # TODO: Ensure lookup tensors are actually being deleted
         for _, lookup in enumerate(self._lookups):
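
The added comment is the core rationale for the change: every rank must derive its tensor splits from the same global `has_optimizer` value, otherwise a rank that starts with no tables (and therefore no local optimizer state) sizes its buffers for weights only while another rank sends weights plus optimizer rows. The numbers and the doubling rule in the sketch below are illustrative assumptions (mirroring the old `len(local_output_tensor) // 2` split), not the exact layout used by `shards_all_to_all`.

# Illustrative only: how disagreeing on has_optimizer skews the split sizes.
def tensor_split_rows(weight_rows_per_shard, has_optimizer):
    # Assume optimizer state contributes one extra block of rows per shard,
    # so each split doubles when optimizer state travels with the weights.
    factor = 2 if has_optimizer else 1
    return [rows * factor for rows in weight_rows_per_shard]


shards = [8, 4]  # weight rows of two shards being redistributed

# A rank that owns tables sees has_local_optimizer=True:
print(tensor_split_rows(shards, has_optimizer=True))   # [16, 8]

# An initially empty rank has no local optimizer state (flag False). Without the
# MAX all-reduce it would compute different splits and the exchange would break:
print(tensor_split_rows(shards, has_optimizer=False))  # [8, 4]

# After _is_optimizer_enabled(), both ranks use the global flag and agree on [16, 8].
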
@@ -1715,7 +1735,7 @@ def update_shards(
         max_dim_0, max_dim_1 = get_largest_dims_from_sharding_plan_updates(
             changed_sharding_params
         )
-        old_optimizer_state = self._optim.state_dict() if has_optimizer else None
+        old_optimizer_state = self._optim.state_dict() if has_local_optimizer else None

         local_shard_names_by_src_rank, local_output_tensor = shards_all_to_all(
             module=self,
@@ -1727,6 +1747,7 @@ def update_shards(
             max_dim_0=max_dim_0,
             max_dim_1=max_dim_1,
             optimizer_state=old_optimizer_state,
+            has_optimizer=has_optimizer,
         )

         for name, param in changed_sharding_params.items():
@@ -1791,30 +1812,25 @@ def update_shards(
         self._optim: CombinedOptimizer = CombinedOptimizer(optims)

         if has_optimizer:
-            split_index = len(local_output_tensor) // 2
-            local_weight_tensors = local_output_tensor[:split_index]
-            local_optimizer_tensors = local_output_tensor[split_index:]
-            # Modifies new_opt_state in place and returns it
             optimizer_state = update_optimizer_state_post_resharding(
                 old_opt_state=old_optimizer_state,  # pyre-ignore
                 new_opt_state=copy.deepcopy(self._optim.state_dict()),
                 ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
-                output_tensor=local_optimizer_tensors,
+                output_tensor=local_output_tensor,
                 max_dim_0=max_dim_0,
+                extend_shard_name=self.extend_shard_name,
             )
-
             self._optim.load_state_dict(optimizer_state)
-        else:
-            local_weight_tensors = local_output_tensor

         current_state = update_state_dict_post_resharding(
             state_dict=current_state,
             ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
-            output_tensor=local_weight_tensors,
+            output_tensor=local_output_tensor,
             new_sharding_params=changed_sharding_params,
             curr_rank=dist.get_rank(),
             extend_shard_name=self.extend_shard_name,
             max_dim_0=max_dim_0,
+            has_optimizer=has_optimizer,
         )

         self.load_state_dict(current_state)
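
With this refactor, `update_shards` no longer pre-splits `local_output_tensor` into weight and optimizer halves; the full tensor is handed to both `update_optimizer_state_post_resharding` and `update_state_dict_post_resharding` together with `has_optimizer`, so any split or offset logic now lives inside those helpers. The sketch below is only a guess at what such consumer-side splitting could look like, mirroring the removed `split_index = len(local_output_tensor) // 2` behaviour; the actual helpers in this commit may lay the data out differently.

# Hypothetical consumer-side split, not the actual helper internals.
import torch


def split_weight_and_optimizer(output_tensor: torch.Tensor, has_optimizer: bool):
    if not has_optimizer:
        return output_tensor, None
    # Assume weight rows come first and optimizer rows second, as the removed
    # caller-side code did with split_index = len(local_output_tensor) // 2.
    split_index = len(output_tensor) // 2
    return output_tensor[:split_index], output_tensor[split_index:]


combined = torch.zeros(8, 4)  # e.g. 4 weight rows followed by 4 optimizer rows
weights, opt_rows = split_weight_and_optimizer(combined, has_optimizer=True)
print(weights.shape, opt_rows.shape)  # torch.Size([4, 4]) torch.Size([4, 4])
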
