From b99849f3dc40c12f2132d55f4dd29e0ba5f0f038 Mon Sep 17 00:00:00 2001 From: joshuadeng Date: Wed, 10 Jul 2024 13:31:31 -0700 Subject: [PATCH 1/2] concat VBE embeddings in lookup module (#2215) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2215 Previously to handle multiple VBE TBE output which is 1d tensor ordered by rank, we grouped sharding info such that there would only be one TBE created per sharding module. This avoided the issue of concatting multiple 1d tensors that are ordered by rank (not a problem in on VBE bc of 2d output which we can concat on dim 1). This grouping which would be done only applies to specific UVM caching setups that used prefetch pipeline, as each sharding type could require multiple TBE to handle both HBM and UVM caching setups. In most cases the TBE could be fused for each sharding type, so we grouped by such. Each sharding module handles individual input dist, lookup, output dist, and by creating a sharding module per each TBE in EMO setups would cause regression, as there would be an increase in comms to handle the increased input dists and output dists. This diff removes the need for the grouping logic to circumvent the VBE TBE output concatenation by implementing output concatenation, which removes the necessity for specialized sharding grouping logic for specific EMO cases. Differential Revision: D58894728 Reviewed By: dstaay-fb, levythu --- torchrec/distributed/embedding_lookup.py | 34 +++++++++++++++++-- torchrec/distributed/embeddingbag.py | 23 +------------ .../test_utils/test_model_parallel.py | 3 +- .../tests/test_pt2_multiprocess.py | 6 ++-- 4 files changed, 39 insertions(+), 27 deletions(-) diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py index fd5948ea5..1c157a174 100644 --- a/torchrec/distributed/embedding_lookup.py +++ b/torchrec/distributed/embedding_lookup.py @@ -46,7 +46,6 @@ EmbeddingComputeKernel, GroupedEmbeddingConfig, InputDistOutputs, - KJTList, ) from torchrec.distributed.fused_params import ( get_tbes_to_register_from_iterable, @@ -442,8 +441,9 @@ def _create_lookup( self.grouped_configs = grouped_configs self._feature_processor = feature_processor + self._world_size: int = dist.get_world_size(pg) self._scale_gradient_factor: int = ( - dist.get_world_size(pg) + self._world_size if scale_weight_gradients and get_gradient_division() else 1 ) @@ -487,11 +487,24 @@ def _need_prefetch(config: GroupedEmbeddingConfig) -> bool: ), ) + def _merge_variable_batch_embeddings( + self, embeddings: List[torch.Tensor], splits: List[List[int]] + ) -> List[torch.Tensor]: + split_embs = [e.split(s) for e, s in zip(embeddings, splits)] + combined_embs = [ + emb + for rank in range(self._world_size) + for n, embs in zip(self._feature_splits, split_embs) + for emb in embs[n * rank : n * rank + n] + ] + return [torch.cat(combined_embs)] + def forward( self, sparse_features: KeyedJaggedTensor, ) -> torch.Tensor: embeddings: List[torch.Tensor] = [] + vbe_splits = [] if len(self._emb_modules) > 0: assert sparse_features is not None features_by_group = sparse_features.split( @@ -514,6 +527,23 @@ def forward( embeddings.append(emb_op(features)) + if features.variable_stride_per_key(): + stride_per_rank_per_key = list( + zip(*features.stride_per_key_per_rank()) + ) + vbe_splits.append( + [ + stride * dim + for stride_per_rank in stride_per_rank_per_key + for stride, dim in zip( + stride_per_rank, config.embedding_dims() + ) + ] + ) + + if sparse_features.variable_stride_per_key(): + embeddings = self._merge_variable_batch_embeddings(embeddings, vbe_splits) + dummy_embedding = ( self._dummy_embs_tensor if sparse_features.variable_stride_per_key() diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 73969c0fe..9c729016e 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -205,28 +205,7 @@ def get_sharding_group( param_sharding: ParameterSharding, fused_params: Optional[Dict[str, Any]] = None, ) -> str: - if fused_params and fused_params.get(USE_ONE_TBE_PER_TABLE, False): - return config.name - if param_sharding.sharding_type in { - ShardingType.COLUMN_WISE.value, - ShardingType.TABLE_COLUMN_WISE.value, - }: - assert param_sharding.ranks - num_ranks = len(param_sharding.ranks) - assert config.embedding_dim % num_ranks == 0 - dim = config.embedding_dim // num_ranks - else: - dim = config.embedding_dim - - group = f"{param_sharding.sharding_type}" - if param_sharding.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING.value: - group += f"@{param_sharding.compute_kernel}" - if (fused_params and fused_params.get("prefetch_pipeline", False)) or ( - param_sharding.cache_params - and param_sharding.cache_params.prefetch_pipeline - ): - group += f"@{dim}" - return group + return param_sharding.sharding_type def create_sharding_infos_by_group( diff --git a/torchrec/distributed/test_utils/test_model_parallel.py b/torchrec/distributed/test_utils/test_model_parallel.py index f7c702970..92a98ab2e 100644 --- a/torchrec/distributed/test_utils/test_model_parallel.py +++ b/torchrec/distributed/test_utils/test_model_parallel.py @@ -653,9 +653,10 @@ def test_sharding_multiple_kernels(self, sharding_type: str) -> None: ) for i, table in enumerate(self.tables) } + fused_params = {"prefetch_pipeline": True} self._test_sharding( # pyre-ignore[6] - sharders=[EmbeddingBagCollectionSharder()], + sharders=[EmbeddingBagCollectionSharder(fused_params=fused_params)], backend=self.backend, constraints=constraints, variable_batch_per_feature=True, diff --git a/torchrec/distributed/tests/test_pt2_multiprocess.py b/torchrec/distributed/tests/test_pt2_multiprocess.py index 250b01542..17b1c60cc 100644 --- a/torchrec/distributed/tests/test_pt2_multiprocess.py +++ b/torchrec/distributed/tests/test_pt2_multiprocess.py @@ -518,7 +518,8 @@ def disable_cuda_tf32(self) -> bool: ShardingType.TABLE_WISE.value, _InputType.SINGLE_BATCH, _ConvertToVariableBatch.TRUE, - "inductor", + # TODO: Revert to "inductor" once https://github.com/pytorch/pytorch/pull/130431 is landed + "eager", _TestConfig(), ), ( @@ -526,7 +527,8 @@ def disable_cuda_tf32(self) -> bool: ShardingType.COLUMN_WISE.value, _InputType.SINGLE_BATCH, _ConvertToVariableBatch.TRUE, - "inductor", + # TODO: Revert to "inductor" once https://github.com/pytorch/pytorch/pull/130431 is landed + "eager", _TestConfig(), ), ( From c5144fdb64c6d7a31ba2366d3339d3514c5ac378 Mon Sep 17 00:00:00 2001 From: Joshua Deng Date: Wed, 10 Jul 2024 14:03:38 -0700 Subject: [PATCH 2/2] revert sharding grouping logic for vbe (#2216) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2216 reverting sharding grouping logic in EBC/VLE modules that supported specific UVM caching + prefetch pipeline uses cases to circumvent VBE TBE output concatenation. As concatenation is implemented in the preceding diff, this diff cleans up the logic left behind from grouping sharding by UVM caching kernel conditions to avoid VBE TBE output concatenation. Reviewed By: dstaay-fb, levythu Differential Revision: D58989195 --- torchrec/distributed/embeddingbag.py | 28 +++++++------------ .../planner/tests/test_embeddingbag_utils.py | 12 ++++---- 2 files changed, 16 insertions(+), 24 deletions(-) diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 9c729016e..fa096001c 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -13,7 +13,6 @@ from functools import partial from typing import ( Any, - Callable, cast, Dict, Iterator, @@ -39,7 +38,6 @@ EmbeddingShardingInfo, KJTListSplitsAwaitable, Multistreamable, - USE_ONE_TBE_PER_TABLE, ) from torchrec.distributed.embedding_types import ( BaseEmbeddingSharder, @@ -77,7 +75,6 @@ optimizer_type_to_emb_opt_type, ) from torchrec.modules.embedding_configs import ( - BaseEmbeddingConfig, EmbeddingBagConfig, EmbeddingTableConfig, PoolingType, @@ -200,15 +197,7 @@ def create_embedding_bag_sharding( raise ValueError(f"Sharding type not supported {sharding_type}") -def get_sharding_group( - config: BaseEmbeddingConfig, - param_sharding: ParameterSharding, - fused_params: Optional[Dict[str, Any]] = None, -) -> str: - return param_sharding.sharding_type - - -def create_sharding_infos_by_group( +def create_sharding_infos_by_sharding( module: EmbeddingBagCollectionInterface, table_name_to_parameter_sharding: Dict[str, ParameterSharding], prefix: str, @@ -229,7 +218,9 @@ def create_sharding_infos_by_group( else: shared_feature[feature_name] = True - group_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = defaultdict(list) + sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = ( + defaultdict(list) + ) # state_dict returns parameter.Tensor, which loses parameter level attributes parameter_by_name = dict(module.named_parameters()) @@ -283,7 +274,6 @@ def create_sharding_infos_by_group( ) per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params) - group = get_sharding_group(config, parameter_sharding, fused_params) sharding_info = EmbeddingShardingInfo( embedding_config=EmbeddingTableConfig( num_embeddings=config.num_embeddings, @@ -303,8 +293,10 @@ def create_sharding_infos_by_group( param=param, fused_params=per_table_fused_params, ) - group_to_sharding_infos[group].append(sharding_info) - return group_to_sharding_infos + sharding_type_to_sharding_infos[parameter_sharding.sharding_type].append( + sharding_info + ) + return sharding_type_to_sharding_infos def create_sharding_infos_by_sharding_device_group( @@ -581,7 +573,7 @@ def __init__( ) self._env = env - group_to_sharding_infos = create_sharding_infos_by_group( + sharding_type_to_sharding_infos = create_sharding_infos_by_sharding( module, table_name_to_parameter_sharding, "embedding_bags.", @@ -602,7 +594,7 @@ def __init__( permute_embeddings=True, qcomm_codecs_registry=self.qcomm_codecs_registry, ) - for embedding_configs in group_to_sharding_infos.values() + for embedding_configs in sharding_type_to_sharding_infos.values() ] self._is_weighted: bool = module.is_weighted() diff --git a/torchrec/distributed/planner/tests/test_embeddingbag_utils.py b/torchrec/distributed/planner/tests/test_embeddingbag_utils.py index 9b9ebfb29..eae6c113e 100644 --- a/torchrec/distributed/planner/tests/test_embeddingbag_utils.py +++ b/torchrec/distributed/planner/tests/test_embeddingbag_utils.py @@ -11,7 +11,7 @@ import unittest from torchrec.distributed.embeddingbag import ( - create_sharding_infos_by_group, + create_sharding_infos_by_sharding, EmbeddingBagCollectionSharder, ) from torchrec.distributed.planner import ( @@ -79,7 +79,7 @@ def setUp(self) -> None: ) self.expected_plan = planner.plan(self.model, [self.sharder]) # pyre-ignore[6] - self.expected_sharding_infos = create_sharding_infos_by_group( + self.expected_sharding_infos = create_sharding_infos_by_sharding( self.model, self.expected_plan.get_plan_for_module(""), # pyre-ignore[6] prefix="embedding_bags.", @@ -93,7 +93,7 @@ def test_create_sharding_infos_by_group_override(self) -> None: # with sharder fused params that will get overridden sharder_fused_params = {"enforce_hbm": False} - overriden_sharding_infos = create_sharding_infos_by_group( + overriden_sharding_infos = create_sharding_infos_by_sharding( self.model, self.expected_plan.get_plan_for_module(""), prefix="embedding_bags.", @@ -106,7 +106,7 @@ def test_create_sharding_infos_by_group_override(self) -> None: # with sharder fused params that won't get overridden sharder_fused_params = {"ABC": True} - not_overriden_sharding_infos = create_sharding_infos_by_group( + not_overriden_sharding_infos = create_sharding_infos_by_sharding( self.model, self.expected_plan.get_plan_for_module(""), prefix="embedding_bags.", @@ -141,7 +141,7 @@ def test_create_sharding_infos_by_group_combine(self) -> None: # provide that two fused params from sharder sharder_fused_params = {"enforce_hbm": True, "stochastic_rounding": False} - combined_sharding_infos = create_sharding_infos_by_group( + combined_sharding_infos = create_sharding_infos_by_sharding( self.model, new_plan.get_plan_for_module(""), # pyre-ignore[6] prefix="embedding_bags.", @@ -156,7 +156,7 @@ def test_create_sharding_infos_by_group_combine(self) -> None: # provide that two fused params from sharder wrongly sharder_fused_params = {"enforce_hbm": True, "stochastic_rounding": True} - wrong_combined_sharding_infos = create_sharding_infos_by_group( + wrong_combined_sharding_infos = create_sharding_infos_by_sharding( self.model, new_plan.get_plan_for_module(""), # pyre-ignore[6] prefix="embedding_bags.",