diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index 6a0192841..33ea870ae 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -27,6 +27,9 @@
 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
+from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
+    DenseTableBatchedEmbeddingBagsCodegen,
+)
 from tensordict import TensorDict
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
@@ -50,6 +53,10 @@
 )
 from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
 from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
+from torchrec.distributed.sharding.dynamic_sharding import (
+    shards_all_to_all,
+    update_state_dict_post_resharding,
+)
 from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding
 from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding
 from torchrec.distributed.sharding.tw_sharding import TwPooledEmbeddingSharding
@@ -635,14 +642,17 @@ def __init__(
         self._env = env
         # output parameters as DTensor in state dict
         self._output_dtensor: bool = env.output_dtensor
-
-        sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
-            module,
-            table_name_to_parameter_sharding,
-            "embedding_bags.",
-            fused_params,
+        self.sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = (
+            create_sharding_infos_by_sharding(
+                module,
+                table_name_to_parameter_sharding,
+                "embedding_bags.",
+                fused_params,
+            )
+        )
+        self._sharding_types: List[str] = list(
+            self.sharding_type_to_sharding_infos.keys()
         )
-        self._sharding_types: List[str] = list(sharding_type_to_sharding_infos.keys())
         self._embedding_shardings: List[
             EmbeddingSharding[
                 EmbeddingShardingContext,
@@ -658,7 +668,7 @@ def __init__(
                 permute_embeddings=True,
                 qcomm_codecs_registry=self.qcomm_codecs_registry,
             )
-            for embedding_configs in sharding_type_to_sharding_infos.values()
+            for embedding_configs in self.sharding_type_to_sharding_infos.values()
         ]
 
         self._is_weighted: bool = module.is_weighted()
@@ -833,7 +843,7 @@ def _pre_load_state_dict_hook(
                     lookup = lookup.module
                 lookup.purge()
 
-    def _initialize_torch_state(self) -> None:  # noqa
+    def _initialize_torch_state(self, skip_registering: bool = False) -> None:  # noqa
         """
         This provides consistency between this class and the EmbeddingBagCollection's
         nn.Module API calls (state_dict, named_modules, etc)
@@ -1063,11 +1073,12 @@ def post_state_dict_hook(
                 destination_key = f"{prefix}embedding_bags.{table_name}.weight"
                 destination[destination_key] = sharded_kvtensor
 
-        self.register_state_dict_pre_hook(self._pre_state_dict_hook)
-        self._register_state_dict_hook(post_state_dict_hook)
-        self._register_load_state_dict_pre_hook(
-            self._pre_load_state_dict_hook, with_module=True
-        )
+        if not skip_registering:
+            self.register_state_dict_pre_hook(self._pre_state_dict_hook)
+            self._register_state_dict_hook(post_state_dict_hook)
+            self._register_load_state_dict_pre_hook(
+                self._pre_load_state_dict_hook, with_module=True
+            )
         self.reset_parameters()
 
     def reset_parameters(self) -> None:
@@ -1164,6 +1175,7 @@ def _create_output_dist(self) -> None:
             self._uncombined_embedding_dims.extend(sharding.uncombined_embedding_dims())
             embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
         self._dim_per_key = torch.tensor(self._embedding_dims, device=self._device)
+
         embedding_shard_offsets: List[int] = [
             meta.shard_offsets[1] if meta is not None else 0
             for meta in embedding_shard_metadata
@@ -1179,6 +1191,38 @@
                     embedding_shard_offsets[i],
                 ),
             )
+
+        self._permute_op: PermutePooledEmbeddings = PermutePooledEmbeddings(
+            self._uncombined_embedding_dims, permute_indices, self._device
+        )
+
+    def _update_output_dist(self) -> None:
+        embedding_shard_metadata: List[Optional[ShardMetadata]] = []
+        # TODO: Optimize to only go through embedding shardings with new ranks
+        self._output_dists: List[nn.Module] = []
+        self._embedding_names: List[str] = []
+        for sharding in self._embedding_shardings:
+            # TODO: if sharding type of table completely changes, need to regenerate everything
+            self._embedding_names.extend(sharding.embedding_names())
+            self._output_dists.append(sharding.create_output_dist(device=self._device))
+            embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
+
+        embedding_shard_offsets: List[int] = [
+            meta.shard_offsets[1] if meta is not None else 0
+            for meta in embedding_shard_metadata
+        ]
+        embedding_name_order: Dict[str, int] = {}
+        for i, name in enumerate(self._uncombined_embedding_names):
+            embedding_name_order.setdefault(name, i)
+
+        permute_indices = sorted(
+            range(len(self._uncombined_embedding_names)),
+            key=lambda i: (
+                embedding_name_order[self._uncombined_embedding_names[i]],
+                embedding_shard_offsets[i],
+            ),
+        )
+
+        self._permute_op: PermutePooledEmbeddings = PermutePooledEmbeddings(
+            self._uncombined_embedding_dims, permute_indices, self._device
+        )
@@ -1396,6 +1440,108 @@ def compute_and_output_dist(
         return awaitable
 
+    def update_shards(
+        self,
+        changed_sharding_params: Dict[str, ParameterSharding],  # NOTE: only delta
+        env: ShardingEnv,
+        device: Optional[torch.device],
+    ) -> None:
+        """
+        Update shards for this module based on the changed_sharding_params. This will:
+        1. Move current lookup tensors to CPU
+        2. Purge lookups
+        3. Call the shards_all_to_all collective to redistribute tensors
+        4. Update state_dict and other attributes to reflect new placements and shards
+        5. Create new lookups, and load in the updated state_dict
+
+        Args:
+            changed_sharding_params (Dict[str, ParameterSharding]): A dictionary mapping
+                table names to their new parameter sharding configs. This should only
+                contain shards/table names that need to be moved.
+            env (ShardingEnv): The sharding environment for the module.
+            device (Optional[torch.device]): The device to place the updated module on.
+        """
+
+        if env.output_dtensor:
+            raise RuntimeError("We do not yet support DTensor for resharding")
+
+        current_state = self.state_dict()
+        # TODO: Save Optimizers
+
+        saved_weights = {}
+        # TODO: Saving lookup tensors to CPU to eventually avoid recreating them completely again
+        for i, lookup in enumerate(self._lookups):
+            for attribute, tbe_module in lookup.named_modules():
+                if type(tbe_module) is DenseTableBatchedEmbeddingBagsCodegen:
+                    saved_weights[str(i) + "." + attribute] = tbe_module.weights.cpu()
+                    # Note: lookup.purge should delete tbe_module and weights
+                    # del tbe_module.weights
+                    # del tbe_module
+            # pyre-ignore
+            lookup.purge()
+
+        # Deleting all lookups
+        self._lookups.clear()
+
+        local_shard_names_by_src_rank, local_output_tensor = shards_all_to_all(
+            module=self,
+            state_dict=current_state,
+            device=device,  # pyre-ignore
+            changed_sharding_params=changed_sharding_params,
+            env=env,
+            extend_shard_name=self.extend_shard_name,
+        )
+
+        current_state = update_state_dict_post_resharding(
+            state_dict=current_state,
+            shard_names_by_src_rank=local_shard_names_by_src_rank,
+            output_tensor=local_output_tensor,
+            new_sharding_params=changed_sharding_params,
+            curr_rank=dist.get_rank(),
+            extend_shard_name=self.extend_shard_name,
+        )
+
+        for name, param in changed_sharding_params.items():
+            self.module_sharding_plan[name] = param
+            # TODO: Support detecting old sharding type when sharding type is changing
+            for sharding_info in self.sharding_type_to_sharding_infos[
+                param.sharding_type
+            ]:
+                if sharding_info.embedding_config.name == name:
+                    sharding_info.param_sharding = param
+
+        self._sharding_types: List[str] = list(
+            self.sharding_type_to_sharding_infos.keys()
+        )
+        # TODO: Optimize to update only the changed embedding shardings
+        self._embedding_shardings: List[
+            EmbeddingSharding[
+                EmbeddingShardingContext,
+                KeyedJaggedTensor,
+                torch.Tensor,
+                torch.Tensor,
+            ]
+        ] = [
+            create_embedding_bag_sharding(
+                embedding_configs,
+                env,
+                device,
+                permute_embeddings=True,
+                qcomm_codecs_registry=self.qcomm_codecs_registry,
+            )
+            for embedding_configs in self.sharding_type_to_sharding_infos.values()
+        ]
+
+        self._create_lookups()
+        self._update_output_dist()
+
+        if env.process_group and dist.get_backend(env.process_group) != "fake":
+            self._initialize_torch_state(skip_registering=True)
+
+        self.load_state_dict(current_state)
+        return
+
     @property
     def fused_optimizer(self) -> KeyedOptimizer:
         return self._optim
@@ -1403,6 +1549,10 @@ def fused_optimizer(self) -> KeyedOptimizer:
     def create_context(self) -> EmbeddingBagCollectionContext:
         return EmbeddingBagCollectionContext()
 
+    @staticmethod
+    def extend_shard_name(shard_name: str) -> str:
+        return f"embedding_bags.{shard_name}.weight"
+
 
 class EmbeddingBagCollectionSharder(BaseEmbeddingSharder[EmbeddingBagCollection]):
     """
@@ -1435,6 +1585,33 @@ def shardable_parameters(
             for name, param in module.embedding_bags.named_parameters()
         }
 
+    def reshard(
+        self,
+        sharded_module: ShardedEmbeddingBagCollection,
+        changed_shard_to_params: Dict[str, ParameterSharding],
+        env: ShardingEnv,
+        device: Optional[torch.device] = None,
+    ) -> ShardedEmbeddingBagCollection:
+        """
+        Updates the sharded module in place based on the changed_shard_to_params,
+        which contains the new ParameterSharding with different shard placements.
+
+        Args:
+            sharded_module (ShardedEmbeddingBagCollection): The module to update.
+            changed_shard_to_params (Dict[str, ParameterSharding]): A dictionary mapping
+                table names to their new parameter sharding configs. This should only
+                contain shards/table names that need to be moved.
+            env (ShardingEnv): The sharding environment.
+            device (Optional[torch.device]): The device to place the updated module on.
+
+        Returns:
+            ShardedEmbeddingBagCollection: The updated sharded module.
+        """
+
+        if len(changed_shard_to_params) > 0:
+            sharded_module.update_shards(changed_shard_to_params, env, device)
+        return sharded_module
+
     @property
     def module_type(self) -> Type[EmbeddingBagCollection]:
         return EmbeddingBagCollection
diff --git a/torchrec/distributed/sharding/dynamic_sharding.py b/torchrec/distributed/sharding/dynamic_sharding.py
new file mode 100644
index 000000000..4e50c4f72
--- /dev/null
+++ b/torchrec/distributed/sharding/dynamic_sharding.py
@@ -0,0 +1,190 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Any, Callable, Dict, List, Tuple
+
+import torch
+import torch.distributed as dist
+from torch.distributed._shard.sharded_tensor import Shard
+from torchrec.distributed.types import (
+    ParameterSharding,
+    ShardedModule,
+    ShardedTensor,
+    ShardingEnv,
+)
+
+
+def shards_all_to_all(
+    module: ShardedModule[Any, Any, Any, Any],  # pyre-ignore
+    state_dict: Dict[str, ShardedTensor],
+    device: torch.device,
+    changed_sharding_params: Dict[str, ParameterSharding],
+    env: ShardingEnv,
+    extend_shard_name: Callable[[str], str] = lambda x: x,
+) -> Tuple[List[str], torch.Tensor]:
+    """
+    Performs an all-to-all communication to redistribute shards across ranks based on new sharding parameters.
+    Assumes ranks are ordered in ParameterSharding.ranks.
+
+    Args:
+        module (ShardedEmbeddingBagCollection): The module containing sharded tensors to be redistributed.
+            TODO: Update to support more modules.
+
+        state_dict (Dict[str, ShardedTensor]): The state dictionary containing the current sharded tensors.
+
+        device (torch.device): The device on which the output tensors will be placed.
+
+        changed_sharding_params (Dict[str, ParameterSharding]): A dictionary mapping shard names to their new sharding parameters.
+
+        env (ShardingEnv): The sharding environment containing world size and other distributed information.
+
+        extend_shard_name (Callable[[str], str], optional): A function to extend shard names to the full name in state_dict.
+
+    Returns:
+        Tuple[List[str], torch.Tensor]: A tuple containing:
+            - A list of shard names that were sent from a specific rank to the current rank, ordered by rank, then shard order.
+            - The tensor containing all shards received by the current rank after the all-to-all operation.
+    """
+    if env.output_dtensor:
+        raise RuntimeError("We do not yet support DTensor for resharding")
+
+    # Module sharding plan is used to get the source ranks for each shard
+    assert hasattr(module, "module_sharding_plan")
+
+    world_size = env.world_size
+    rank = dist.get_rank()
+    input_splits_per_rank = [[0] * world_size for _ in range(world_size)]
+    output_splits_per_rank = [[0] * world_size for _ in range(world_size)]
+    local_input_tensor = torch.empty([0], device=device)
+    local_output_tensor = torch.empty([0], device=device)
+
+    shard_names_by_src_rank = []
+    for shard_name, param in changed_sharding_params.items():
+        sharded_t = state_dict[extend_shard_name(shard_name)]
+        assert param.ranks is not None
+        dst_ranks = param.ranks
+        # pyre-ignore
+        src_ranks = module.module_sharding_plan[shard_name].ranks
+
+        # TODO: Implement changing rank sizes for beyond TW sharding
+        assert len(dst_ranks) == len(src_ranks)
+
+        # The index i is needed to distinguish between multiple shards
+        # within the same ShardedTensor for each table
+        for i in range(len(src_ranks)):
+            dst_rank = dst_ranks[i]
+            src_rank = src_ranks[i]
+
+            shard_size = sharded_t.metadata().shards_metadata[i].shard_sizes
+            shard_size_dim_0 = shard_size[0]
+            input_splits_per_rank[src_rank][dst_rank] += shard_size_dim_0
+            output_splits_per_rank[dst_rank][src_rank] += shard_size_dim_0
+            if src_rank == rank:
+                local_shards = sharded_t.local_shards()
+                assert len(local_shards) == 1
+                local_input_tensor = torch.cat(
+                    (
+                        local_input_tensor,
+                        sharded_t.local_shards()[0].tensor,
+                    )
+                )
+            if dst_rank == rank:
+                shard_names_by_src_rank.append(shard_name)
+                local_output_tensor = torch.cat(
+                    (local_output_tensor, torch.empty(shard_size, device=device))
+                )
+
+    local_input_splits = input_splits_per_rank[rank]
+    local_output_splits = output_splits_per_rank[rank]
+
+    assert sum(local_output_splits) == len(local_output_tensor)
+    assert sum(local_input_splits) == len(local_input_tensor)
+    dist.all_to_all_single(
+        output=local_output_tensor,
+        input=local_input_tensor,
+        output_split_sizes=local_output_splits,
+        input_split_sizes=local_input_splits,
+        group=dist.group.WORLD,
+    )
+
+    return shard_names_by_src_rank, local_output_tensor
+
+
+def update_state_dict_post_resharding(
+    state_dict: Dict[str, ShardedTensor],
+    shard_names_by_src_rank: List[str],
+    output_tensor: torch.Tensor,
+    new_sharding_params: Dict[str, ParameterSharding],
+    curr_rank: int,
+    extend_shard_name: Callable[[str], str] = lambda x: x,
+) -> Dict[str, ShardedTensor]:
+    """
+    Updates and returns the given state_dict with new placements and
+    local_shards based on the output tensor of the AllToAll collective.
+
+    Args:
+        state_dict (Dict[str, ShardedTensor]): The state dict to be updated with new shard placements and local shards.
+
+        shard_names_by_src_rank (List[str]): A list of shard names that were sent from a specific rank to the
+            current rank, ordered by rank, then shard order.
+
+        output_tensor (torch.Tensor): The tensor containing the output data from the AllToAll operation.
+
+        new_sharding_params (Dict[str, ParameterSharding]): A dictionary mapping shard names to their new sharding parameters.
+            This should only contain shard names that were updated during the AllToAll operation.
+
+        curr_rank (int): The current rank of the process in the distributed environment.
+
+        extend_shard_name (Callable[[str], str], optional): A function to extend shard names to the full name in state_dict.
+
+    Returns:
+        Dict[str, ShardedTensor]: The updated state dictionary with new shard placements and local shards.
+    """
+    slice_index = 0
+
+    shard_name_to_local_output_tensor: Dict[str, torch.Tensor] = {}
+
+    for shard_name in shard_names_by_src_rank:
+        shard_size = state_dict[extend_shard_name(shard_name)].size(0)
+        end_slice_index = slice_index + shard_size
+        shard_name_to_local_output_tensor[shard_name] = output_tensor[
+            slice_index:end_slice_index
+        ]
+        slice_index = end_slice_index
+
+    for shard_name, param in new_sharding_params.items():
+        extended_name = extend_shard_name(shard_name)
+        # pyre-ignore
+        for i in range(len(param.ranks)):
+            # pyre-ignore
+            r = param.ranks[i]
+            sharded_t = state_dict[extended_name]
+            # Update placements
+            sharded_t.metadata().shards_metadata[i].placement = (
+                torch.distributed._remote_device(f"rank:{r}/cuda:{r}")
+            )
+            if r == curr_rank:
+                assert len(output_tensor) > 0
+                # Slice the output tensor to the correct size.
+                sharded_t._local_shards = [
+                    Shard(
+                        tensor=shard_name_to_local_output_tensor[shard_name],
+                        metadata=state_dict[extended_name]
+                        .metadata()
+                        .shards_metadata[i],
+                    )
+                ]
+                break
+            else:
+                sharded_t._local_shards = []
+
+    return state_dict
diff --git a/torchrec/distributed/tests/test_dynamic_sharding.py b/torchrec/distributed/tests/test_dynamic_sharding.py
new file mode 100644
index 000000000..ccc46cc94
--- /dev/null
+++ b/torchrec/distributed/tests/test_dynamic_sharding.py
@@ -0,0 +1,437 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+
+import copy
+
+import random
+import unittest
+
+from typing import Any, Dict, List, Optional, Union
+
+import hypothesis.strategies as st
+
+import torch
+
+from hypothesis import given, settings, Verbosity
+from torch import nn
+
+from torchrec import distributed as trec_dist, EmbeddingBagCollection, KeyedJaggedTensor
+from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
+
+from torchrec.distributed.sharding_plan import (
+    construct_module_sharding_plan,
+    get_module_to_default_sharders,
+    table_wise,
+)
+
+from torchrec.distributed.test_utils.multi_process import (
+    MultiProcessContext,
+    MultiProcessTestBase,
+)
+from torchrec.distributed.test_utils.test_sharding import copy_state_dict
+
+from torchrec.distributed.types import (
+    EmbeddingModuleShardingPlan,
+    ParameterSharding,
+    ShardingEnv,
+    ShardingType,
+)
+from torchrec.modules.embedding_configs import data_type_to_dtype, EmbeddingBagConfig
+
+from torchrec.test_utils import skip_if_asan_class
+from torchrec.types import DataType
+
+
+# Utils:
+def table_name(i: int) -> str:
+    return "table_" + str(i)
+
+
+def feature_name(i: int) -> str:
+    return "feature_" + str(i)
+
+
+def generate_input_by_world_size(
+    world_size: int,
+    num_tables: int,
+    num_embeddings: int = 4,
+    max_mul: int = 3,
+) -> List[KeyedJaggedTensor]:
+    # TODO: merge with new ModelInput generator in TestUtils
+    kjt_input_per_rank = []
+    mul = random.randint(1, max_mul)
+    total_size = num_tables * mul
+
+    for _ in range(world_size):
+        feature_names = [feature_name(i) for i in range(num_tables)]
+        lengths = []
+        values = []
+        counting_l = 0
+        for i in range(total_size):
+            if i == total_size - 1:
+                lengths.append(total_size - counting_l)
+                break
+            next_l = random.randint(0, total_size - counting_l)
+            values.extend(
+                [random.randint(0, num_embeddings - 1) for _ in range(next_l)]
+            )
+            lengths.append(next_l)
+            counting_l += next_l
+
+        kjt_input_per_rank.append(
+            KeyedJaggedTensor.from_lengths_sync(
+                keys=feature_names,
+                values=torch.LongTensor(values),
+                lengths=torch.LongTensor(lengths),
+            )
+        )
+
+    return kjt_input_per_rank
+
+
+def generate_embedding_bag_config(
+    data_type: DataType,
+    num_tables: int = 3,
+    embedding_dim: int = 16,
+    num_embeddings: int = 4,
+) -> List[EmbeddingBagConfig]:
+    embedding_bag_config = []
+    for i in range(num_tables):
+        embedding_bag_config.append(
+            EmbeddingBagConfig(
+                name=table_name(i),
+                feature_names=[feature_name(i)],
+                embedding_dim=embedding_dim,
+                num_embeddings=num_embeddings,
+                data_type=data_type,
+            ),
+        )
+    return embedding_bag_config
+
+
+def create_test_initial_state_dict(
+    sharded_module_type: nn.Module,
+    num_tables: int,
+    data_type: DataType,
+    embedding_dim: int = 16,
+    num_embeddings: int = 4,
+) -> Dict[str, torch.Tensor]:
+    """
+    Helpful for debugging:
+
+    initial_state_dict = {
+        "embedding_bags.table_0.weight": torch.tensor(
+            [
+                [1] * 16,
+                [2] * 16,
+                [3] * 16,
+                [4] * 16,
+            ],
+        ),
+        "embedding_bags.table_1.weight": torch.tensor(
+            [
+                [101] * 16,
+                [102] * 16,
+                [103] * 16,
+                [104] * 16,
+            ],
+            dtype=data_type_to_dtype(data_type),
+        ),
+        ...
+    }
+    """
+
+    initial_state_dict = {}
+    for i in range(num_tables):
+        # pyre-ignore
+        extended_name = sharded_module_type.extend_shard_name(table_name(i))
+        initial_state_dict[extended_name] = torch.tensor(
+            [[j + (i * 100)] * embedding_dim for j in range(num_embeddings)],
+            dtype=data_type_to_dtype(data_type),
+        )
+
+    return initial_state_dict
+
+
+def are_modules_identical(
+    module1: Union[EmbeddingBagCollection, ShardedEmbeddingBagCollection],
+    module2: Union[EmbeddingBagCollection, ShardedEmbeddingBagCollection],
+) -> None:
+    # Check if both modules have the same type
+    assert type(module1) is type(module2)
+
+    # Check if both modules have the same parameters
+    params1 = list(module1.named_parameters())
+    params2 = list(module2.named_parameters())
+
+    assert len(params1) == len(params2)
+
+    for param1, param2 in zip(params1, params2):
+        # Check parameter names
+        assert param1[0] == param2[0]
+        # Check parameter values
+        assert torch.allclose(param1[1], param2[1])
+
+    # Check if both modules have the same buffers
+    buffers1 = list(module1.named_buffers())
+    buffers2 = list(module2.named_buffers())
+
+    assert len(buffers1) == len(buffers2)
+
+    for buffer1, buffer2 in zip(buffers1, buffers2):
+        assert buffer1[0] == buffer2[0]  # Check buffer names
+        assert torch.allclose(buffer1[1], buffer2[1])  # Check buffer values
+
+
+def output_sharding_plan_delta(
+    old_plan: EmbeddingModuleShardingPlan, new_plan: EmbeddingModuleShardingPlan
+) -> EmbeddingModuleShardingPlan:
+    assert len(old_plan) == len(new_plan)
+    return_plan = copy.deepcopy(new_plan)
+    for shard_name, old_param in old_plan.items():
+        if shard_name not in return_plan:
+            raise ValueError(f"Shard {shard_name} not found in new plan")
+        new_param = return_plan[shard_name]
+        old_ranks = old_param.ranks
+        new_ranks = new_param.ranks
+        if old_ranks == new_ranks:
+            del return_plan[shard_name]
+
+    return return_plan
+
+
+def _test_ebc_resharding(
+    tables: List[EmbeddingBagConfig],
+    initial_state_dict: Dict[str, Any],
+    rank: int,
+    world_size: int,
+    kjt_input_per_rank: List[KeyedJaggedTensor],
+    backend: str,
+    module_sharding_plan: EmbeddingModuleShardingPlan,
+    new_module_sharding_plan: EmbeddingModuleShardingPlan,
+    local_size: Optional[int] = None,
+) -> None:
+    """
+    Distributed call to test resharding for EBC by creating 2 models with identical config and
+    states:
+        m1 sharded with new_module_sharding_plan
+        m2 sharded with module_sharding_plan, then resharded with new_module_sharding_plan
+
+    Expects m1 and resharded m2 to be the same, and predictions output from the same KJT
+    inputs to be the same.
+
+    TODO: modify to include other modules once dynamic sharding is built out.
+    """
+    trec_dist.comm_ops.set_gradient_division(False)
+    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
+        kjt_input_per_rank = [kjt.to(ctx.device) for kjt in kjt_input_per_rank]
+
+        initial_state_dict = {
+            fqn: tensor.to(ctx.device) for fqn, tensor in initial_state_dict.items()
+        }
+        m1 = EmbeddingBagCollection(
+            tables=tables,
+            device=ctx.device,
+        )
+
+        m2 = EmbeddingBagCollection(
+            tables=tables,
+            device=ctx.device,
+        )
+
+        # Load initial state - making sure models are identical
+        m1.load_state_dict(initial_state_dict)
+        copy_state_dict(
+            loc=m1.state_dict(),
+            glob=copy.deepcopy(initial_state_dict),
+        )
+
+        m2.load_state_dict(initial_state_dict)
+        copy_state_dict(
+            loc=m2.state_dict(),
+            glob=copy.deepcopy(initial_state_dict),
+        )
+
+        sharder = get_module_to_default_sharders()[type(m1)]
+
+        # pyre-ignore
+        env = ShardingEnv.from_process_group(ctx.pg)
+
+        sharded_m1 = sharder.shard(
+            module=m1,
+            params=new_module_sharding_plan,
+            env=env,
+            device=ctx.device,
+        )
+
+        sharded_m2 = sharder.shard(
+            module=m2,
+            params=module_sharding_plan,
+            env=env,
+            device=ctx.device,
+        )
+
+        new_module_sharding_plan_delta = output_sharding_plan_delta(
+            module_sharding_plan, new_module_sharding_plan
+        )
+
+        # pyre-ignore
+        resharded_m2 = sharder.reshard(
+            sharded_module=sharded_m2,
+            changed_shard_to_params=new_module_sharding_plan_delta,
+            env=env,
+            device=ctx.device,
+        )
+
+        are_modules_identical(sharded_m1, resharded_m2)
+
+        feature_keys = []
+        for table in tables:
+            feature_keys.extend(table.feature_names)
+
+        # For the current test model and inputs, the predictions should be exactly the same
+        rtol = 0
+        atol = 0
+
+        for _ in range(world_size):
+            # sharded model
+            # each rank gets a subbatch
+            sharded_m1_pred_kt_no_dict = sharded_m1(kjt_input_per_rank[ctx.rank])
+            resharded_m2_pred_kt_no_dict = resharded_m2(kjt_input_per_rank[ctx.rank])
+
+            sharded_m1_pred_kt = sharded_m1_pred_kt_no_dict.to_dict()
+            resharded_m2_pred_kt = resharded_m2_pred_kt_no_dict.to_dict()
+            sharded_m1_pred = torch.stack(
+                [sharded_m1_pred_kt[feature] for feature in feature_keys]
+            )
+
+            resharded_m2_pred = torch.stack(
+                [resharded_m2_pred_kt[feature] for feature in feature_keys]
+            )
+            # Cast to CPU because calling .to on the same module from multiple ranks could cause
+            # race conditions. In normal model authoring code this won't be an issue because each
+            # rank would individually create its model; the sharded predictions are already on the
+            # correct device.
+
+            # Compare predictions of sharded vs resharded models.
+            torch.testing.assert_close(
+                sharded_m1_pred.cpu(), resharded_m2_pred.cpu(), rtol=rtol, atol=atol
+            )
+
+            sharded_m1_pred.sum().backward()
+            resharded_m2_pred.sum().backward()
+
+
+@skip_if_asan_class
+class MultiRankDynamicShardingTest(MultiProcessTestBase):
+    def _run_ebc_resharding_test(
+        self,
+        per_param_sharding: Dict[str, ParameterSharding],
+        new_per_param_sharding: Dict[str, ParameterSharding],
+        num_tables: int,
+        world_size: int,
+        data_type: DataType,
+        embedding_dim: int = 16,
+        num_embeddings: int = 4,
+    ) -> None:
+        embedding_bag_config = generate_embedding_bag_config(
+            data_type, num_tables, embedding_dim, num_embeddings
+        )
+
+        module_sharding_plan = construct_module_sharding_plan(
+            EmbeddingBagCollection(tables=embedding_bag_config),
+            # pyre-ignore
+            per_param_sharding=per_param_sharding,
+            local_size=world_size,
+            world_size=world_size,
+            device_type="cuda" if torch.cuda.is_available() else "cpu",
+        )
+
+        new_module_sharding_plan = construct_module_sharding_plan(
+            EmbeddingBagCollection(tables=embedding_bag_config),
+            # pyre-ignore
+            per_param_sharding=new_per_param_sharding,
+            local_size=world_size,
+            world_size=world_size,
+            device_type="cuda" if torch.cuda.is_available() else "cpu",
+        )
+
+        # Row-wise not supported on gloo
+        if (
+            not torch.cuda.is_available()
+            and new_module_sharding_plan["table_0"].sharding_type
+            == ShardingType.ROW_WISE.value
+        ):
+            return
+
+        kjt_input_per_rank = generate_input_by_world_size(
+            world_size, num_tables, num_embeddings
+        )
+
+        # initial_state_dict filled with deterministic dummy values
+        initial_state_dict = create_test_initial_state_dict(
+            ShardedEmbeddingBagCollection,  # pyre-ignore
+            num_tables,
+            data_type,
+            embedding_dim,
+            num_embeddings,
+        )
+
+        self._run_multi_process_test(
+            callable=_test_ebc_resharding,
+            world_size=world_size,
+            tables=embedding_bag_config,
+            initial_state_dict=initial_state_dict,
+            kjt_input_per_rank=kjt_input_per_rank,
+            backend="nccl" if torch.cuda.is_available() else "gloo",
+            module_sharding_plan=module_sharding_plan,
+            new_module_sharding_plan=new_module_sharding_plan,
+        )
+
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs, this test requires at least two GPUs",
+    )
+    @given(  # pyre-ignore
+        num_tables=st.sampled_from([2, 3, 4]),
+        data_type=st.sampled_from([DataType.FP32, DataType.FP16]),
+        world_size=st.sampled_from([2, 4]),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None)
+    def test_dynamic_sharding_ebc_tw(
+        self,
+        num_tables: int,
+        data_type: DataType,
+        world_size: int,
+    ) -> None:
+        # Tests EBC dynamic sharding implementation for TW
+
+        # Cannot generate old/new ranks with the hypothesis library due to the dependency on world_size
+        old_ranks = [random.randint(0, world_size - 1) for _ in range(num_tables)]
+        new_ranks = [random.randint(0, world_size - 1) for _ in range(num_tables)]
+
+        if new_ranks == old_ranks:
+            return
+        per_param_sharding = {}
+        new_per_param_sharding = {}
+
+        # Construct parameter shardings
+        for i in range(num_tables):
+            per_param_sharding[table_name(i)] = table_wise(rank=old_ranks[i])
+            new_per_param_sharding[table_name(i)] = table_wise(rank=new_ranks[i])
+
+        self._run_ebc_resharding_test(
+            per_param_sharding,
+            new_per_param_sharding,
+            num_tables,
+            world_size,
+            data_type,
+        )
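Usage sketch (not part of the diff above): the snippet below shows how the new reshard entry point could be driven end to end, distilled from _test_ebc_resharding. It is a minimal sketch, assuming torch.distributed is already initialized with 2 ranks (e.g. via torchrun with the NCCL backend); the table name, dimensions, and ranks are illustrative, and the delta computation mirrors output_sharding_plan_delta from the test rather than any dedicated API.

# Hypothetical 2-rank driver; assumes dist.init_process_group() has already run on each rank.
import torch
import torch.distributed as dist
from torchrec import EmbeddingBagCollection
from torchrec.distributed.sharding_plan import (
    construct_module_sharding_plan,
    get_module_to_default_sharders,
    table_wise,
)
from torchrec.distributed.types import ShardingEnv
from torchrec.modules.embedding_configs import EmbeddingBagConfig

device = torch.device("cuda", dist.get_rank())
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="table_0",
            feature_names=["feature_0"],
            embedding_dim=16,
            num_embeddings=4,
        )
    ],
    device=device,
)

sharder = get_module_to_default_sharders()[type(ebc)]
env = ShardingEnv.from_process_group(dist.group.WORLD)

# Shard table_0 onto rank 0 first.
old_plan = construct_module_sharding_plan(
    ebc,
    per_param_sharding={"table_0": table_wise(rank=0)},
    local_size=2,
    world_size=2,
    device_type="cuda",
)
sharded_ebc = sharder.shard(module=ebc, params=old_plan, env=env, device=device)

# Later, decide table_0 should live on rank 1; pass only the changed entries,
# since update_shards expects a delta rather than the full plan.
new_plan = construct_module_sharding_plan(
    ebc,
    per_param_sharding={"table_0": table_wise(rank=1)},
    local_size=2,
    world_size=2,
    device_type="cuda",
)
changed = {
    name: param
    for name, param in new_plan.items()
    if old_plan[name].ranks != param.ranks
}
sharded_ebc = sharder.reshard(
    sharded_module=sharded_ebc,
    changed_shard_to_params=changed,
    env=env,
    device=device,
)

Passing only the changed tables keeps the all-to-all in shards_all_to_all limited to the shards that actually move, which is the same contract the test exercises via output_sharding_plan_delta.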