34 changes: 32 additions & 2 deletions torchrec/distributed/embedding_lookup.py
@@ -46,7 +46,6 @@
     EmbeddingComputeKernel,
     GroupedEmbeddingConfig,
     InputDistOutputs,
-    KJTList,
 )
 from torchrec.distributed.fused_params import (
     get_tbes_to_register_from_iterable,
@@ -442,8 +441,9 @@ def _create_lookup(
         self.grouped_configs = grouped_configs
         self._feature_processor = feature_processor
 
+        self._world_size: int = dist.get_world_size(pg)
         self._scale_gradient_factor: int = (
-            dist.get_world_size(pg)
+            self._world_size
             if scale_weight_gradients and get_gradient_division()
             else 1
         )
@@ -487,11 +487,24 @@ def _need_prefetch(config: GroupedEmbeddingConfig) -> bool:
                 ),
             )
 
+    def _merge_variable_batch_embeddings(
+        self, embeddings: List[torch.Tensor], splits: List[List[int]]
+    ) -> List[torch.Tensor]:
+        split_embs = [e.split(s) for e, s in zip(embeddings, splits)]
+        combined_embs = [
+            emb
+            for rank in range(self._world_size)
+            for n, embs in zip(self._feature_splits, split_embs)
+            for emb in embs[n * rank : n * rank + n]
+        ]
+        return [torch.cat(combined_embs)]
+
     def forward(
         self,
         sparse_features: KeyedJaggedTensor,
     ) -> torch.Tensor:
         embeddings: List[torch.Tensor] = []
+        vbe_splits = []
         if len(self._emb_modules) > 0:
             assert sparse_features is not None
             features_by_group = sparse_features.split(
@@ -514,6 +527,23 @@ def forward(

                 embeddings.append(emb_op(features))
 
+                if features.variable_stride_per_key():
+                    stride_per_rank_per_key = list(
+                        zip(*features.stride_per_key_per_rank())
+                    )
+                    vbe_splits.append(
+                        [
+                            stride * dim
+                            for stride_per_rank in stride_per_rank_per_key
+                            for stride, dim in zip(
+                                stride_per_rank, config.embedding_dims()
+                            )
+                        ]
+                    )
+
+        if sparse_features.variable_stride_per_key():
+            embeddings = self._merge_variable_batch_embeddings(embeddings, vbe_splits)
+
         dummy_embedding = (
             self._dummy_embs_tensor
             if sparse_features.variable_stride_per_key()
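Reviewer note: the new _merge_variable_batch_embeddings helper re-orders per-group variable-batch (VBE) outputs into a single rank-major tensor. Below is a minimal, self-contained sketch of that re-ordering; it mirrors the comprehension in the diff but runs on toy tensors, and the names feature_splits/world_size plus the toy sizes are assumptions for illustration, not TorchRec's actual module wiring.

# Standalone sketch (assumed toy shapes) of the rank-major merge performed by
# _merge_variable_batch_embeddings: each group's flat output is split into
# (world_size * num_features) chunks, the chunks are re-ordered so that all
# features for rank 0 come first, then rank 1, and finally concatenated.
from typing import List

import torch


def merge_variable_batch_embeddings(
    embeddings: List[torch.Tensor],
    splits: List[List[int]],
    feature_splits: List[int],  # number of features per lookup group (assumed)
    world_size: int,
) -> List[torch.Tensor]:
    split_embs = [e.split(s) for e, s in zip(embeddings, splits)]
    combined = [
        emb
        for rank in range(world_size)
        for n, embs in zip(feature_splits, split_embs)
        for emb in embs[n * rank : n * rank + n]
    ]
    return [torch.cat(combined)]


# Toy example: 2 ranks, one group with 2 features, embedding dim 4.
# Rank 0 has batch size 1 per feature, rank 1 has batch size 2 per feature,
# so the per-(rank, feature) split sizes are stride * dim = [4, 4, 8, 8].
flat = torch.arange(24, dtype=torch.float32)
merged = merge_variable_batch_embeddings(
    [flat], splits=[[4, 4, 8, 8]], feature_splits=[2], world_size=2
)
print(merged[0].shape)  # torch.Size([24])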
23 changes: 1 addition & 22 deletions torchrec/distributed/embeddingbag.py
@@ -205,28 +205,7 @@ def get_sharding_group(
     param_sharding: ParameterSharding,
     fused_params: Optional[Dict[str, Any]] = None,
 ) -> str:
-    if fused_params and fused_params.get(USE_ONE_TBE_PER_TABLE, False):
-        return config.name
-    if param_sharding.sharding_type in {
-        ShardingType.COLUMN_WISE.value,
-        ShardingType.TABLE_COLUMN_WISE.value,
-    }:
-        assert param_sharding.ranks
-        num_ranks = len(param_sharding.ranks)
-        assert config.embedding_dim % num_ranks == 0
-        dim = config.embedding_dim // num_ranks
-    else:
-        dim = config.embedding_dim
-
-    group = f"{param_sharding.sharding_type}"
-    if param_sharding.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING.value:
-        group += f"@{param_sharding.compute_kernel}"
-    if (fused_params and fused_params.get("prefetch_pipeline", False)) or (
-        param_sharding.cache_params
-        and param_sharding.cache_params.prefetch_pipeline
-    ):
-        group += f"@{dim}"
-    return group
+    return param_sharding.sharding_type
 
 
 def create_sharding_infos_by_group(
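Reviewer note: with the body above removed, sharding infos are grouped only by sharding type; compute kernel and per-shard dim no longer split tables into separate groups. The sketch below illustrates the resulting bucketing behavior using a hypothetical stand-in dataclass (FakeSharding and its fields are illustrative, not TorchRec types).

# Illustrative sketch (hypothetical FakeSharding stand-in) of grouping keyed
# only by sharding type, matching the simplified get_sharding_group above.
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class FakeSharding:
    table_name: str
    sharding_type: str
    compute_kernel: str


def group_by_sharding_type(shardings: List[FakeSharding]) -> Dict[str, List[str]]:
    groups: Dict[str, List[str]] = defaultdict(list)
    for s in shardings:
        # Kernel (e.g. fused vs. fused_uvm_caching) no longer affects the key.
        groups[s.sharding_type].append(s.table_name)
    return dict(groups)


tables = [
    FakeSharding("t0", "column_wise", "fused"),
    FakeSharding("t1", "column_wise", "fused_uvm_caching"),
    FakeSharding("t2", "table_wise", "fused"),
]
print(group_by_sharding_type(tables))
# {'column_wise': ['t0', 't1'], 'table_wise': ['t2']}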
3 changes: 2 additions & 1 deletion torchrec/distributed/test_utils/test_model_parallel.py
@@ -653,9 +653,10 @@ def test_sharding_multiple_kernels(self, sharding_type: str) -> None:
             )
             for i, table in enumerate(self.tables)
         }
+        fused_params = {"prefetch_pipeline": True}
         self._test_sharding(
             # pyre-ignore[6]
-            sharders=[EmbeddingBagCollectionSharder()],
+            sharders=[EmbeddingBagCollectionSharder(fused_params=fused_params)],
             backend=self.backend,
             constraints=constraints,
             variable_batch_per_feature=True,
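Reviewer note: passing fused_params here means the multi-kernel sharding test now also exercises the prefetch pipeline together with variable batch per feature. A minimal construction sketch (test harness and constraints elided) is:

# Sketch of the sharder construction the test now uses; the surrounding
# _test_sharding call, tables, and constraints are omitted here.
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder

fused_params = {"prefetch_pipeline": True}
sharder = EmbeddingBagCollectionSharder(fused_params=fused_params)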