diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py index 00726e3c2..e21f14805 100644 --- a/torchrec/modules/embedding_modules.py +++ b/torchrec/modules/embedding_modules.py @@ -46,6 +46,10 @@ def process_pooled_embeddings( pooled_embeddings = torch.ops.fbgemm.group_index_select_dim0( pooled_embeddings, list(torch.unbind(inverse_indices)) ) + if not pooled_embeddings: + # Return a tensor with shape [batch_size, 0] if pooled_embeddings is an empty list + batch_size = inverse_indices.shape[0] + return torch.zeros((batch_size, 0), device=inverse_indices.device) return torch.cat(pooled_embeddings, dim=1)