diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 0e969b8e4..3670c3030 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -786,7 +786,7 @@ def _maybe_compute_length_per_key( values: Optional[torch.Tensor], ) -> List[int]: if length_per_key is None: - if values is not None and values.is_meta: + if len(keys) and values is not None and values.is_meta: # create dummy lengths per key when on meta device total_length = values.numel() _length = [total_length // len(keys)] * len(keys) diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py index 98edb4194..614ce5965 100644 --- a/torchrec/sparse/tests/test_jagged_tensor.py +++ b/torchrec/sparse/tests/test_jagged_tensor.py @@ -1836,6 +1836,13 @@ def test_meta_device_compatibility(self) -> None: keys=keys, values=values, weights=weights, offsets=offsets ) + # test empty keys case + kjt = KeyedJaggedTensor.from_lengths_sync( + keys=[], + values=torch.tensor([], device=torch.device("meta")), + lengths=torch.tensor([], device=torch.device("meta")), + ) + class TestKeyedJaggedTensorScripting(unittest.TestCase): def test_scriptable_forward(self) -> None: