2 changes: 1 addition & 1 deletion torchrec/distributed/embeddingbag.py
@@ -1526,7 +1526,7 @@ def update_shards(

current_state = update_state_dict_post_resharding(
state_dict=current_state,
shard_names_by_src_rank=local_shard_names_by_src_rank,
ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
output_tensor=local_output_tensor,
new_sharding_params=changed_sharding_params,
curr_rank=dist.get_rank(),
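The renamed keyword mirrors the new return type of shards_all_to_all below: the call site now passes a flattened list of (shard name, dim-1 length) pairs rather than a plain list of shard names. A minimal sketch of the value bound to ordered_shard_names_and_lengths, with hypothetical table names and sizes:

from typing import List, Tuple

# Hypothetical contents of the flattened, pruned list returned by
# shards_all_to_all: pairs are ordered by source rank, then by shard order
# within each rank. Names and sizes here are illustrative only.
local_shard_names_by_src_rank: List[Tuple[str, int]] = [
    ("table_0", 64),   # received from rank 0, 64 columns in dim 1
    ("table_2", 128),  # received from rank 1, 128 columns in dim 1
]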
87 changes: 60 additions & 27 deletions torchrec/distributed/sharding/dynamic_sharding.py
@@ -27,14 +27,14 @@ def shards_all_to_all(
changed_sharding_params: Dict[str, ParameterSharding],
env: ShardingEnv,
extend_shard_name: Callable[[str], str] = lambda x: x,
) -> Tuple[List[str], torch.Tensor]:
) -> Tuple[List[Tuple[str, int]], torch.Tensor]:
"""
Performs an all-to-all communication to redistribute shards across ranks based on new sharding parameters.
Assumes ranks are ordered in ParameterSharding.ranks.

Args:
module (ShardedEmbeddingBagCollection): The module containing sharded tensors to be redistributed.
TODO: Update to support more modules
module (ShardedModule[Any, Any, Any, Any]): The module containing sharded tensors to be redistributed.
TODO: Update to support more modules, currently only supports ShardedEmbeddingBagCollection.

state_dict (Dict[str, ShardedTensor]): The state dictionary containing the current sharded tensors.

@@ -47,8 +47,9 @@ def shards_all_to_all(
extend_shard_name (Callable[[str], str], optional): A function to extend shard names to the full name in state_dict.

Returns:
Tuple[List[str], torch.Tensor]: A tuple containing:
- A list of shard names that were sent from a specific rank to the current rank, ordered by rank, then shard order.
Tuple[List[Tuple[str, int]], torch.Tensor]: A tuple containing:
- A list of (shard name, shard size in dim 1) pairs for shards sent to the current rank.
This is a flattened and pruned nested list, ordering the shard names and sizes by rank, then by shard order.
- The tensor containing all shards received by the current rank after the all-to-all operation.
"""
if env.output_dtensor:
@@ -62,10 +63,12 @@
rank = dist.get_rank()
input_splits_per_rank = [[0] * world_size for _ in range(world_size)]
output_splits_per_rank = [[0] * world_size for _ in range(world_size)]
local_input_tensor = torch.empty([0], device=device)
local_output_tensor = torch.empty([0], device=device)

shard_names_by_src_rank = []
# 0 by default, as the current rank may be receiving 0 shards
num_embeddings_received = 0
output_tensor_tensor_count = 0
shard_names_to_lengths_by_src_rank = [[] for _ in range(world_size)]
local_table_to_input_tensor_by_dst_rank = [[] for _ in range(world_size)]
for shard_name, param in changed_sharding_params.items():
sharded_t = state_dict[extend_shard_name(shard_name)]
assert param.ranks is not None
@@ -84,27 +87,52 @@ def shards_all_to_all(
src_rank = src_ranks[i]

shard_size = sharded_t.metadata().shards_metadata[i].shard_sizes
shard_size_dim_0 = shard_size[0]
input_splits_per_rank[src_rank][dst_rank] += shard_size_dim_0
output_splits_per_rank[dst_rank][src_rank] += shard_size_dim_0
shard_size_dim_1 = shard_size[1]
input_splits_per_rank[src_rank][dst_rank] += shard_size_dim_1
output_splits_per_rank[dst_rank][src_rank] += shard_size_dim_1
if src_rank == rank:
local_shards = sharded_t.local_shards()
assert len(local_shards) == 1
local_input_tensor = torch.cat(
(
local_input_tensor,
sharded_t.local_shards()[0].tensor,
)
local_table_to_input_tensor_by_dst_rank[dst_rank].append(
sharded_t.local_shards()[0].tensor
)
if dst_rank == rank:
shard_names_by_src_rank.append(shard_name)
local_output_tensor = torch.cat(
(local_output_tensor, torch.empty(shard_size, device=device))
shard_names_to_lengths_by_src_rank[src_rank].append(
(shard_name, shard_size_dim_1)
)
# NOTE: Only need to update num_embeddings_received to be the
# num_embeddings of shards if this rank is actually receiving
# any tensors
if num_embeddings_received == 0:
num_embeddings_received = shard_size[0]
else:
# TODO: for 2D and row-wise, shard_sizes in dim 0 may be variable
# For now, assume that shard_sizes in dim 0 are all the same
assert num_embeddings_received == shard_size[0]
output_tensor_tensor_count += shard_size[1]

local_input_splits = input_splits_per_rank[rank]
local_output_splits = output_splits_per_rank[rank]

local_input_tensor = torch.empty([0], device=device)
for sub_l in local_table_to_input_tensor_by_dst_rank:
for shard_info in sub_l:
local_input_tensor = torch.cat(
(
local_input_tensor,
shard_info,
),
dim=1,
)

# Transpose the tensors since they were concatenated along dimension 1
# (all_to_all_single splits along dim 0). In the transposed layout, dim 0 size
# may differ between shards, whereas dim 1 size is the same for all shards,
# since dim 1 size = num_embeddings per table.
local_output_tensor = torch.empty(
[output_tensor_tensor_count, num_embeddings_received], device=device
)
local_input_tensor = local_input_tensor.T.contiguous()

assert sum(local_output_splits) == len(local_output_tensor)
assert sum(local_input_splits) == len(local_input_tensor)
dist.all_to_all_single(
@@ -115,12 +143,18 @@
group=dist.group.WORLD,
)

return shard_names_by_src_rank, local_output_tensor
flattened_output_names_lengths = [
shard_info
for sub_l in shard_names_to_lengths_by_src_rank
for shard_info in sub_l
]

return flattened_output_names_lengths, local_output_tensor
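The dim-1 bookkeeping in shards_all_to_all can be illustrated on dummy tensors: shards handled by a rank share dim 0 (num_embeddings) but differ in dim 1, so local shards are concatenated along dim 1 and then transposed, which puts the per-shard dim-1 lengths on dim 0 where all_to_all_single applies its splits. A minimal single-process sketch with hypothetical sizes (the collective call itself is omitted):

import torch

# Two hypothetical local shards destined for other ranks: same dim 0
# (num_embeddings = 4), different dim 1 (embedding-dim slices of 2 and 3).
shard_a = torch.arange(8.0).reshape(4, 2)
shard_b = torch.arange(12.0).reshape(4, 3)

# Concatenate along dim 1, then transpose so that the concatenation axis
# becomes dim 0, which is the axis all_to_all_single splits over.
local_input_tensor = torch.cat((shard_a, shard_b), dim=1).T.contiguous()
input_splits = [shard_a.shape[1], shard_b.shape[1]]  # [2, 3]

assert local_input_tensor.shape == (sum(input_splits), 4)
# In the real code, this tensor and its splits feed dist.all_to_all_single
# together with a pre-allocated output tensor of shape
# [output_tensor_tensor_count, num_embeddings_received].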


def update_state_dict_post_resharding(
state_dict: Dict[str, ShardedTensor],
shard_names_by_src_rank: List[str],
ordered_shard_names_and_lengths: List[Tuple[str, int]],
output_tensor: torch.Tensor,
new_sharding_params: Dict[str, ParameterSharding],
curr_rank: int,
@@ -133,8 +167,9 @@ def update_state_dict_post_resharding(
Args:
state_dict (Dict[str, Any]): The state dict to be updated with new shard placements and local shards.

shard_names_by_src_rank (List[str]): A list of shard names that were sent from a specific rank to the
current rank, ordered by rank, then shard order.
ordered_shard_names_and_lengths (List[Tuple[str, int]]): A list of (shard name, shard size in dim 1) pairs for
shards sent to the current rank. This is a flattened and pruned nested list, ordering the shard names and
sizes by rank, then by shard order.

output_tensor (torch.Tensor): The tensor containing the output data from the AllToAll operation.

@@ -149,16 +184,14 @@
Dict[str, ShardedTensor]: The updated state dictionary with new shard placements and local shards.
"""
slice_index = 0
shard_names_by_src_rank

shard_name_to_local_output_tensor: Dict[str, torch.Tensor] = {}

for shard_name in shard_names_by_src_rank:
shard_size = state_dict[extend_shard_name(shard_name)].size(0)
for shard_name, shard_size in ordered_shard_names_and_lengths:
end_slice_index = slice_index + shard_size
shard_name_to_local_output_tensor[shard_name] = output_tensor[
slice_index:end_slice_index
]
].T
slice_index = end_slice_index

for shard_name, param in new_sharding_params.items():
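On the receive side, update_state_dict_post_resharding walks ordered_shard_names_and_lengths to carve the flat AllToAll output back into per-shard tensors, transposing each slice back to its original [num_embeddings, dim-1 length] orientation. A minimal sketch of that slicing loop, using the same hypothetical sizes as above (the helper name is illustrative, not part of the module):

import torch
from typing import Dict, List, Tuple

def slice_received_shards(
    output_tensor: torch.Tensor,
    ordered_shard_names_and_lengths: List[Tuple[str, int]],
) -> Dict[str, torch.Tensor]:
    # Hypothetical helper mirroring the slicing in update_state_dict_post_resharding:
    # rows of the transposed output tensor are grouped per shard by their dim-1
    # lengths, then transposed back to [num_embeddings, shard dim-1 length].
    shard_name_to_tensor: Dict[str, torch.Tensor] = {}
    slice_index = 0
    for shard_name, shard_size in ordered_shard_names_and_lengths:
        end_slice_index = slice_index + shard_size
        shard_name_to_tensor[shard_name] = output_tensor[slice_index:end_slice_index].T
        slice_index = end_slice_index
    return shard_name_to_tensor

# Usage with hypothetical sizes: 5 received rows total (2 + 3), 4 embeddings each.
received = torch.zeros(5, 4)
shards = slice_received_shards(received, [("table_0", 2), ("table_1", 3)])
assert shards["table_0"].shape == (4, 2) and shards["table_1"].shape == (4, 3)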