20
20
from fbgemm_gpu .split_table_batched_embeddings_ops_training import (
21
21
SplitTableBatchedEmbeddingBagsCodegen ,
22
22
)
23
+ from fbgemm_gpu .tbe .ssd .training import SSDTableBatchedEmbeddingBags
23
24
from torch import nn
24
25
25
26
from torch .autograd .function import FunctionCtx
@@ -182,7 +183,10 @@ def _create_lookup(
182
183
config : GroupedEmbeddingConfig ,
183
184
) -> BaseEmbedding :
184
185
for table in config .embedding_tables :
185
- if table .compute_kernel == EmbeddingComputeKernel .FUSED_UVM_CACHING :
186
+ if (
187
+ table .compute_kernel == EmbeddingComputeKernel .FUSED_UVM_CACHING
188
+ or table .compute_kernel == EmbeddingComputeKernel .KEY_VALUE
189
+ ):
186
190
self ._need_prefetch = True
187
191
if config .compute_kernel == EmbeddingComputeKernel .DENSE :
188
192
return BatchedDenseEmbedding (
@@ -254,11 +258,18 @@ def prefetch(
254
258
"If you don’t turn on prefetch_pipeline, cache locations might be wrong in backward and can cause wrong results.\n "
255
259
)
256
260
if hasattr (emb_op .emb_module , "prefetch" ):
257
- emb_op .emb_module .prefetch (
258
- indices = features .values (),
259
- offsets = features .offsets (),
260
- forward_stream = forward_stream ,
261
- )
261
+ if isinstance (emb_op .emb_module , SSDTableBatchedEmbeddingBags ):
262
+ # only takes indices and offsets
263
+ emb_op .emb_module .prefetch (
264
+ indices = features .values (),
265
+ offsets = features .offsets (),
266
+ )
267
+ else :
268
+ emb_op .emb_module .prefetch (
269
+ indices = features .values (),
270
+ offsets = features .offsets (),
271
+ forward_stream = forward_stream ,
272
+ )
262
273
263
274
def forward (
264
275
self ,
@@ -455,7 +466,10 @@ def prefetch(
455
466
) -> None :
456
467
def _need_prefetch (config : GroupedEmbeddingConfig ) -> bool :
457
468
for table in config .embedding_tables :
458
- if table .compute_kernel == EmbeddingComputeKernel .FUSED_UVM_CACHING :
469
+ if (
470
+ table .compute_kernel == EmbeddingComputeKernel .FUSED_UVM_CACHING
471
+ or table .compute_kernel == EmbeddingComputeKernel .KEY_VALUE
472
+ ):
459
473
return True
460
474
return False
461
475
@@ -476,16 +490,23 @@ def _need_prefetch(config: GroupedEmbeddingConfig) -> bool:
476
490
"If you don't turn on prefetch_pipeline, cache locations might be wrong in backward and can cause wrong results.\n "
477
491
)
478
492
if hasattr (emb_op .emb_module , "prefetch" ):
479
- emb_op .emb_module .prefetch (
480
- indices = features .values (),
481
- offsets = features .offsets (),
482
- forward_stream = forward_stream ,
483
- batch_size_per_feature_per_rank = (
484
- features .stride_per_key_per_rank ()
485
- if features .variable_stride_per_key ()
486
- else None
487
- ),
488
- )
493
+ if isinstance (emb_op .emb_module , SSDTableBatchedEmbeddingBags ):
494
+ # only takes indices and offsets
495
+ emb_op .emb_module .prefetch (
496
+ indices = features .values (),
497
+ offsets = features .offsets (),
498
+ )
499
+ else :
500
+ emb_op .emb_module .prefetch (
501
+ indices = features .values (),
502
+ offsets = features .offsets (),
503
+ forward_stream = forward_stream ,
504
+ batch_size_per_feature_per_rank = (
505
+ features .stride_per_key_per_rank ()
506
+ if features .variable_stride_per_key ()
507
+ else None
508
+ ),
509
+ )
489
510
490
511
def _merge_variable_batch_embeddings (
491
512
self , embeddings : List [torch .Tensor ], splits : List [List [int ]]
0 commit comments