
Commit 8fe09df

joshuadeng authored and facebook-github-bot committed
concat VBE embeddings in lookup module (#2215)
Summary: Pull Request resolved: #2215 Previously to handle multiple VBE TBE output which is 1d tensor ordered by rank, we grouped sharding info such that there would only be one TBE created per sharding module. This avoided the issue of concatting multiple 1d tensors that are ordered by rank (not a problem in on VBE bc of 2d output which we can concat on dim 1). This grouping which would be done only applies to specific UVM caching setups that used prefetch pipeline, as each sharding type could require multiple TBE to handle both HBM and UVM caching setups. In most cases the TBE could be fused for each sharding type, so we grouped by such. Each sharding module handles individual input dist, lookup, output dist, and by creating a sharding module per each TBE in EMO setups would cause regression, as there would be an increase in comms to handle the increased input dists and output dists. This diff removes the need for the grouping logic to circumvent the VBE TBE output concatenation by implementing output concatenation, which removes the necessity for specialized sharding grouping logic for specific EMO cases. Reviewed By: dstaay-fb Differential Revision: D58894728
1 parent 4651535 commit 8fe09df
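To make the issue in the summary concrete, here is a minimal standalone sketch (illustrative only, not the TorchRec implementation; `world_size` and the per-rank split sizes are made up): each VBE TBE returns a flat 1D tensor laid out rank by rank, so merging the outputs of two TBEs requires splitting each by rank and re-interleaving the pieces rather than a plain concatenation.

```python
import torch

world_size = 2

# Two hypothetical VBE TBE outputs, each a flat tensor ordered by rank:
#   emb_a = [rank0 block (4 values) | rank1 block (2 values)]
#   emb_b = [rank0 block (3 values) | rank1 block (3 values)]
emb_a = torch.arange(6.0)
emb_b = torch.arange(10.0, 16.0)
splits_a = [4, 2]  # per-rank sizes of emb_a
splits_b = [3, 3]  # per-rank sizes of emb_b

# A plain torch.cat([emb_a, emb_b]) would put emb_a's rank1 block before
# emb_b's rank0 block. Instead, split each output by rank and re-interleave
# the chunks so the merged tensor is again ordered by rank.
chunks_a = emb_a.split(splits_a)
chunks_b = emb_b.split(splits_b)
merged = torch.cat(
    [chunk for rank in range(world_size) for chunk in (chunks_a[rank], chunks_b[rank])]
)
print(merged)
# tensor([ 0.,  1.,  2.,  3., 10., 11., 12.,  4.,  5., 13., 14., 15.])
```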

File tree

3 files changed: +35 −25 lines changed


torchrec/distributed/embedding_lookup.py

Lines changed: 32 additions & 2 deletions
@@ -46,7 +46,6 @@
     EmbeddingComputeKernel,
     GroupedEmbeddingConfig,
     InputDistOutputs,
-    KJTList,
 )
 from torchrec.distributed.fused_params import (
     get_tbes_to_register_from_iterable,
@@ -442,8 +441,9 @@ def _create_lookup(
         self.grouped_configs = grouped_configs
         self._feature_processor = feature_processor
 
+        self._world_size: int = dist.get_world_size(pg)
         self._scale_gradient_factor: int = (
-            dist.get_world_size(pg)
+            self._world_size
             if scale_weight_gradients and get_gradient_division()
             else 1
         )
@@ -487,11 +487,24 @@ def _need_prefetch(config: GroupedEmbeddingConfig) -> bool:
             ),
         )
 
+    def _merge_variable_batch_embeddings(
+        self, embeddings: List[torch.Tensor], splits: List[List[int]]
+    ) -> List[torch.Tensor]:
+        split_embs = [e.split(s) for e, s in zip(embeddings, splits)]
+        combined_embs = [
+            emb
+            for rank in range(self._world_size)
+            for n, embs in zip(self._feature_splits, split_embs)
+            for emb in embs[n * rank : n * rank + n]
+        ]
+        return [torch.cat(combined_embs)]
+
     def forward(
         self,
         sparse_features: KeyedJaggedTensor,
     ) -> torch.Tensor:
         embeddings: List[torch.Tensor] = []
+        vbe_splits = []
         if len(self._emb_modules) > 0:
             assert sparse_features is not None
             features_by_group = sparse_features.split(
@@ -514,6 +527,23 @@ def forward(
 
                 embeddings.append(emb_op(features))
 
+                if features.variable_stride_per_key():
+                    stride_per_rank_per_key = list(
+                        zip(*features.stride_per_key_per_rank())
+                    )
+                    vbe_splits.append(
+                        [
+                            stride * dim
+                            for stride_per_rank in stride_per_rank_per_key
+                            for stride, dim in zip(
+                                stride_per_rank, config.embedding_dims()
+                            )
+                        ]
+                    )
+
+        if sparse_features.variable_stride_per_key():
+            embeddings = self._merge_variable_batch_embeddings(embeddings, vbe_splits)
+
         dummy_embedding = (
             self._dummy_embs_tensor
             if sparse_features.variable_stride_per_key()
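As a follow-up to the `vbe_splits` computation above, here is a small sketch of how the per-rank split sizes fall out of the per-feature strides and embedding dims. The numbers are invented, and `stride_per_key_per_rank` / `embedding_dims` are plain lists standing in for the `KeyedJaggedTensor` and config accessors used in the diff:

```python
# Illustrative values only (2 features, 2 ranks); these lists stand in for
# features.stride_per_key_per_rank() and config.embedding_dims() in the diff.
stride_per_key_per_rank = [[2, 3], [1, 4]]  # stride of feature k on rank r
embedding_dims = [8, 16]                    # embedding dim of each feature's table

# Transpose to rank-major order; each (stride, dim) pair contributes
# stride * dim flat values to that rank's slice of the 1D VBE output.
stride_per_rank_per_key = list(zip(*stride_per_key_per_rank))
vbe_split = [
    stride * dim
    for stride_per_rank in stride_per_rank_per_key
    for stride, dim in zip(stride_per_rank, embedding_dims)
]
print(vbe_split)  # [16, 16, 24, 64] -> rank0: (feat0, feat1), rank1: (feat0, feat1)
```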

torchrec/distributed/embeddingbag.py

Lines changed: 1 addition & 22 deletions
@@ -205,28 +205,7 @@ def get_sharding_group(
     param_sharding: ParameterSharding,
     fused_params: Optional[Dict[str, Any]] = None,
 ) -> str:
-    if fused_params and fused_params.get(USE_ONE_TBE_PER_TABLE, False):
-        return config.name
-    if param_sharding.sharding_type in {
-        ShardingType.COLUMN_WISE.value,
-        ShardingType.TABLE_COLUMN_WISE.value,
-    }:
-        assert param_sharding.ranks
-        num_ranks = len(param_sharding.ranks)
-        assert config.embedding_dim % num_ranks == 0
-        dim = config.embedding_dim // num_ranks
-    else:
-        dim = config.embedding_dim
-
-    group = f"{param_sharding.sharding_type}"
-    if param_sharding.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING.value:
-        group += f"@{param_sharding.compute_kernel}"
-    if (fused_params and fused_params.get("prefetch_pipeline", False)) or (
-        param_sharding.cache_params
-        and param_sharding.cache_params.prefetch_pipeline
-    ):
-        group += f"@{dim}"
-    return group
+    return param_sharding.sharding_type
 
 
 def create_sharding_infos_by_group(
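With the group key reduced to the sharding type, tables that previously landed in separate groups because of compute kernel or prefetch dimension now fuse into one TBE per sharding type. A toy sketch of the resulting grouping (plain strings stand in for `ShardingType` and compute-kernel values, and the table names are hypothetical):

```python
# Hypothetical (table name, sharding type, compute kernel) triples; the strings
# mirror ShardingType / EmbeddingComputeKernel values but are hard-coded here.
shardings = [
    ("table_a", "table_wise", "fused_uvm_caching"),
    ("table_b", "table_wise", "fused"),
    ("table_c", "column_wise", "fused"),
]

groups: dict = {}
for name, sharding_type, _kernel in shardings:
    # New behavior: the group key is just the sharding type, so table_a and
    # table_b share a group even though their compute kernels differ.
    groups.setdefault(sharding_type, []).append(name)

print(groups)  # {'table_wise': ['table_a', 'table_b'], 'column_wise': ['table_c']}
```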

torchrec/distributed/test_utils/test_model_parallel.py

Lines changed: 2 additions & 1 deletion
@@ -653,9 +653,10 @@ def test_sharding_multiple_kernels(self, sharding_type: str) -> None:
             )
             for i, table in enumerate(self.tables)
         }
+        fused_params = {"prefetch_pipeline": True}
         self._test_sharding(
             # pyre-ignore[6]
-            sharders=[EmbeddingBagCollectionSharder()],
+            sharders=[EmbeddingBagCollectionSharder(fused_params=fused_params)],
             backend=self.backend,
             constraints=constraints,
             variable_batch_per_feature=True,
