diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py
index fd5948ea5..1c157a174 100644
--- a/torchrec/distributed/embedding_lookup.py
+++ b/torchrec/distributed/embedding_lookup.py
@@ -46,7 +46,6 @@
     EmbeddingComputeKernel,
     GroupedEmbeddingConfig,
     InputDistOutputs,
-    KJTList,
 )
 from torchrec.distributed.fused_params import (
     get_tbes_to_register_from_iterable,
@@ -442,8 +441,9 @@ def _create_lookup(
         self.grouped_configs = grouped_configs
         self._feature_processor = feature_processor
 
+        self._world_size: int = dist.get_world_size(pg)
         self._scale_gradient_factor: int = (
-            dist.get_world_size(pg)
+            self._world_size
             if scale_weight_gradients and get_gradient_division()
             else 1
         )
@@ -487,11 +487,24 @@ def _need_prefetch(config: GroupedEmbeddingConfig) -> bool:
             ),
         )
 
+    def _merge_variable_batch_embeddings(
+        self, embeddings: List[torch.Tensor], splits: List[List[int]]
+    ) -> List[torch.Tensor]:
+        split_embs = [e.split(s) for e, s in zip(embeddings, splits)]
+        combined_embs = [
+            emb
+            for rank in range(self._world_size)
+            for n, embs in zip(self._feature_splits, split_embs)
+            for emb in embs[n * rank : n * rank + n]
+        ]
+        return [torch.cat(combined_embs)]
+
     def forward(
         self,
         sparse_features: KeyedJaggedTensor,
     ) -> torch.Tensor:
         embeddings: List[torch.Tensor] = []
+        vbe_splits = []
         if len(self._emb_modules) > 0:
             assert sparse_features is not None
             features_by_group = sparse_features.split(
@@ -514,6 +527,23 @@ def forward(
 
                 embeddings.append(emb_op(features))
 
+                if features.variable_stride_per_key():
+                    stride_per_rank_per_key = list(
+                        zip(*features.stride_per_key_per_rank())
+                    )
+                    vbe_splits.append(
+                        [
+                            stride * dim
+                            for stride_per_rank in stride_per_rank_per_key
+                            for stride, dim in zip(
+                                stride_per_rank, config.embedding_dims()
+                            )
+                        ]
+                    )
+
+        if sparse_features.variable_stride_per_key():
+            embeddings = self._merge_variable_batch_embeddings(embeddings, vbe_splits)
+
         dummy_embedding = (
             self._dummy_embs_tensor
             if sparse_features.variable_stride_per_key()
diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index 73969c0fe..9c729016e 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -205,28 +205,7 @@ def get_sharding_group(
     param_sharding: ParameterSharding,
     fused_params: Optional[Dict[str, Any]] = None,
 ) -> str:
-    if fused_params and fused_params.get(USE_ONE_TBE_PER_TABLE, False):
-        return config.name
-    if param_sharding.sharding_type in {
-        ShardingType.COLUMN_WISE.value,
-        ShardingType.TABLE_COLUMN_WISE.value,
-    }:
-        assert param_sharding.ranks
-        num_ranks = len(param_sharding.ranks)
-        assert config.embedding_dim % num_ranks == 0
-        dim = config.embedding_dim // num_ranks
-    else:
-        dim = config.embedding_dim
-
-    group = f"{param_sharding.sharding_type}"
-    if param_sharding.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING.value:
-        group += f"@{param_sharding.compute_kernel}"
-    if (fused_params and fused_params.get("prefetch_pipeline", False)) or (
-        param_sharding.cache_params
-        and param_sharding.cache_params.prefetch_pipeline
-    ):
-        group += f"@{dim}"
-    return group
+    return param_sharding.sharding_type
 
 
 def create_sharding_infos_by_group(
diff --git a/torchrec/distributed/test_utils/test_model_parallel.py b/torchrec/distributed/test_utils/test_model_parallel.py
index f7c702970..92a98ab2e 100644
--- a/torchrec/distributed/test_utils/test_model_parallel.py
+++ b/torchrec/distributed/test_utils/test_model_parallel.py
@@ -653,9 +653,10 @@ def test_sharding_multiple_kernels(self, sharding_type: str) -> None:
             )
             for i, table in enumerate(self.tables)
         }
+        fused_params = {"prefetch_pipeline": True}
         self._test_sharding(
             # pyre-ignore[6]
-            sharders=[EmbeddingBagCollectionSharder()],
+            sharders=[EmbeddingBagCollectionSharder(fused_params=fused_params)],
             backend=self.backend,
             constraints=constraints,
             variable_batch_per_feature=True,
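
The reordering performed by the new _merge_variable_batch_embeddings helper is easier to follow outside the nested comprehension. The sketch below is illustrative only and not part of the patch: it re-implements the same rank-major merge as a free function, taking feature_splits and world_size as explicit arguments instead of reading them from self, and uses made-up batch sizes and split values.

import torch
from typing import List


def merge_variable_batch_embeddings(
    embeddings: List[torch.Tensor],
    splits: List[List[int]],
    feature_splits: List[int],
    world_size: int,
) -> List[torch.Tensor]:
    # Each entry of `embeddings` is the flattened variable-batch output of one
    # TBE group, laid out rank by rank, then feature by feature within a rank.
    # `splits[g]` holds the element count of each (rank, feature) chunk for
    # group g; `feature_splits[g]` is the number of features in group g.
    split_embs = [e.split(s) for e, s in zip(embeddings, splits)]
    combined = [
        emb
        for rank in range(world_size)
        for n, embs in zip(feature_splits, split_embs)
        for emb in embs[n * rank : n * rank + n]
    ]
    return [torch.cat(combined)]


# Two ranks, two groups with one feature each (embedding dim 2 for both).
# Group 0: rank 0 has batch 1, rank 1 has batch 2 -> splits [1 * 2, 2 * 2].
# Group 1: rank 0 has batch 2, rank 1 has batch 1 -> splits [2 * 2, 1 * 2].
g0 = torch.arange(0.0, 6.0)
g1 = torch.arange(10.0, 16.0)
merged = merge_variable_batch_embeddings(
    embeddings=[g0, g1],
    splits=[[2, 4], [4, 2]],
    feature_splits=[1, 1],
    world_size=2,
)
# tensor([ 0.,  1., 10., 11., 12., 13.,  2.,  3.,  4.,  5., 14., 15.])
# i.e. rank 0's chunks from both groups first, then rank 1's chunks.
print(merged[0])

The merged order groups all chunks for rank 0 (group by group, feature by feature), then all chunks for rank 1, and so on, which is the layout downstream output-dist code expects from a single lookup.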