Commit 5f8a495

sarckk authored and facebook-github-bot committed

Support prefetching for SSD TBE lookup (#2275)

Summary:
Pull Request resolved: #2275

Currently, we cannot use the prefetch pipeline with SSD-based TBE. This diff adds the required changes in the torchrec code to support it.

Reviewed By: chrisxcai

Differential Revision: D60838580

fbshipit-source-id: 71c837554e21651656a77e8e01b36c95d23d135f

1 parent f7c1ca1 commit 5f8a495
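
The crux of the change, visible in the diff below: SSDTableBatchedEmbeddingBags.prefetch() accepts only indices and offsets, so the torchrec lookup must branch on the module type before forwarding extra arguments such as forward_stream. Here is a minimal standalone sketch of that dispatch pattern; SsdLikeModule and UvmLikeModule are hypothetical stand-ins, not real fbgemm_gpu classes:

from typing import Optional

import torch


class SsdLikeModule:
    """Stand-in for SSDTableBatchedEmbeddingBags: prefetch() takes only indices/offsets."""

    def prefetch(self, indices: torch.Tensor, offsets: torch.Tensor) -> None:
        print("ssd-style prefetch:", indices.tolist(), offsets.tolist())


class UvmLikeModule:
    """Stand-in for SplitTableBatchedEmbeddingBagsCodegen: also accepts forward_stream."""

    def prefetch(
        self,
        indices: torch.Tensor,
        offsets: torch.Tensor,
        forward_stream: Optional["torch.cuda.Stream"] = None,
    ) -> None:
        print("uvm-style prefetch, forward_stream =", forward_stream)


def dispatch_prefetch(emb_module, indices, offsets, forward_stream=None) -> None:
    # Same shape as the diff: SSD-backed modules get the narrow call,
    # everything else also receives forward_stream.
    if isinstance(emb_module, SsdLikeModule):
        emb_module.prefetch(indices=indices, offsets=offsets)
    else:
        emb_module.prefetch(
            indices=indices, offsets=offsets, forward_stream=forward_stream
        )


indices = torch.tensor([1, 4, 2], dtype=torch.long)
offsets = torch.tensor([0, 2, 3], dtype=torch.long)
dispatch_prefetch(SsdLikeModule(), indices, offsets)
dispatch_prefetch(UvmLikeModule(), indices, offsets)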

File tree

1 file changed: +38 −17 lines

torchrec/distributed/embedding_lookup.py

Lines changed: 38 additions & 17 deletions
@@ -20,6 +20,7 @@
 from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
     SplitTableBatchedEmbeddingBagsCodegen,
 )
+from fbgemm_gpu.tbe.ssd.training import SSDTableBatchedEmbeddingBags
 from torch import nn
 
 from torch.autograd.function import FunctionCtx
@@ -182,7 +183,10 @@ def _create_lookup(
             config: GroupedEmbeddingConfig,
         ) -> BaseEmbedding:
             for table in config.embedding_tables:
-                if table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING:
+                if (
+                    table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING
+                    or table.compute_kernel == EmbeddingComputeKernel.KEY_VALUE
+                ):
                     self._need_prefetch = True
             if config.compute_kernel == EmbeddingComputeKernel.DENSE:
                 return BatchedDenseEmbedding(
@@ -254,11 +258,18 @@ def prefetch(
                     "If you don't turn on prefetch_pipeline, cache locations might be wrong in backward and can cause wrong results.\n"
                 )
             if hasattr(emb_op.emb_module, "prefetch"):
-                emb_op.emb_module.prefetch(
-                    indices=features.values(),
-                    offsets=features.offsets(),
-                    forward_stream=forward_stream,
-                )
+                if isinstance(emb_op.emb_module, SSDTableBatchedEmbeddingBags):
+                    # only takes indices and offsets
+                    emb_op.emb_module.prefetch(
+                        indices=features.values(),
+                        offsets=features.offsets(),
+                    )
+                else:
+                    emb_op.emb_module.prefetch(
+                        indices=features.values(),
+                        offsets=features.offsets(),
+                        forward_stream=forward_stream,
+                    )
 
     def forward(
         self,
@@ -455,7 +466,10 @@ def prefetch(
     ) -> None:
         def _need_prefetch(config: GroupedEmbeddingConfig) -> bool:
             for table in config.embedding_tables:
-                if table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING:
+                if (
+                    table.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING
+                    or table.compute_kernel == EmbeddingComputeKernel.KEY_VALUE
+                ):
                     return True
             return False
 
@@ -476,16 +490,23 @@ def _need_prefetch(config: GroupedEmbeddingConfig) -> bool:
                     "If you don't turn on prefetch_pipeline, cache locations might be wrong in backward and can cause wrong results.\n"
                 )
             if hasattr(emb_op.emb_module, "prefetch"):
-                emb_op.emb_module.prefetch(
-                    indices=features.values(),
-                    offsets=features.offsets(),
-                    forward_stream=forward_stream,
-                    batch_size_per_feature_per_rank=(
-                        features.stride_per_key_per_rank()
-                        if features.variable_stride_per_key()
-                        else None
-                    ),
-                )
+                if isinstance(emb_op.emb_module, SSDTableBatchedEmbeddingBags):
+                    # only takes indices and offsets
+                    emb_op.emb_module.prefetch(
+                        indices=features.values(),
+                        offsets=features.offsets(),
+                    )
+                else:
+                    emb_op.emb_module.prefetch(
+                        indices=features.values(),
+                        offsets=features.offsets(),
+                        forward_stream=forward_stream,
+                        batch_size_per_feature_per_rank=(
+                            features.stride_per_key_per_rank()
+                            if features.variable_stride_per_key()
+                            else None
+                        ),
+                    )
 
     def _merge_variable_batch_embeddings(
         self, embeddings: List[torch.Tensor], splits: List[List[int]]
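
With this change applied, both lookup classes can participate in a prefetch pipeline whether tables use the FUSED_UVM_CACHING or the KEY_VALUE (SSD) kernel. For context, a hedged sketch of how a caller might drive a lookup's prefetch() one batch ahead of forward(); run_with_prefetch, lookup, and batches are illustrative placeholders rather than torchrec's actual pipeline code, and a CUDA device is assumed:

import torch


def run_with_prefetch(lookup, batches, forward_stream=None):
    """Illustrative loop: prefetch batch i+1 while computing batch i.

    `lookup` stands in for a GroupedEmbeddingsLookup whose prefetch()
    matches the signature in the diff above; `batches` yields
    KeyedJaggedTensor sparse features. Hypothetical driver, not torchrec code.
    """
    prefetch_stream = torch.cuda.Stream()  # side stream for cache fills
    it = iter(batches)
    cur = next(it, None)
    if cur is None:
        return
    with torch.cuda.stream(prefetch_stream):
        lookup.prefetch(sparse_features=cur, forward_stream=forward_stream)
    for nxt in it:
        # Overlap the next batch's cache fill with this batch's lookup.
        with torch.cuda.stream(prefetch_stream):
            lookup.prefetch(sparse_features=nxt, forward_stream=forward_stream)
        _ = lookup(cur)  # rows for `cur` were prefetched one step earlier
        cur = nxt
    _ = lookup(cur)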
