Skip to content

Commit 737e283

Browse files
edqwerty10facebook-github-bot
authored andcommitted
handle empty (keys) sparse features case (#1883)
Summary: Pull Request resolved: #1883 handle empty keys on meta device path Reviewed By: MarcioPorto Differential Revision: D56175026 fbshipit-source-id: 7707cdecc4a68069ee809d630f20aeebbc1db1d7
1 parent a80219a commit 737e283

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,7 @@ def _maybe_compute_length_per_key(
786786
values: Optional[torch.Tensor],
787787
) -> List[int]:
788788
if length_per_key is None:
789-
if values is not None and values.is_meta:
789+
if len(keys) and values is not None and values.is_meta:
790790
# create dummy lengths per key when on meta device
791791
total_length = values.numel()
792792
_length = [total_length // len(keys)] * len(keys)

torchrec/sparse/tests/test_jagged_tensor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1836,6 +1836,13 @@ def test_meta_device_compatibility(self) -> None:
18361836
keys=keys, values=values, weights=weights, offsets=offsets
18371837
)
18381838

1839+
# test empty keys case
1840+
kjt = KeyedJaggedTensor.from_lengths_sync(
1841+
keys=[],
1842+
values=torch.tensor([], device=torch.device("meta")),
1843+
lengths=torch.tensor([], device=torch.device("meta")),
1844+
)
1845+
18391846

18401847
class TestKeyedJaggedTensorScripting(unittest.TestCase):
18411848
def test_scriptable_forward(self) -> None:

0 commit comments

Comments
 (0)