Skip to content

Commit 629fd71

Browse files
dstaay-fbfacebook-github-bot
authored andcommitted
Fix bug on VBE+CPU (#2256)
Summary: Pull Request resolved: #2256 Internal users reported a bug working with VBE + CPU. Identified regression was introduced by stray edit in D55695198. Simple 1-line fix, but added unit test to cover this edge case for both CPU + GPU setups. Reviewed By: TroyGarden Differential Revision: D60430765
1 parent 2771a90 commit 629fd71

File tree

4 files changed

+122
-2
lines changed

4 files changed

+122
-2
lines changed

torchrec/modules/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def _permute_tensor_by_segments(
184184
segment_sizes,
185185
tensor,
186186
weights,
187-
tensor.numel(),
187+
output_size,
188188
)
189189
return permuted_tensor, permuted_weights
190190

torchrec/sparse/jagged_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def _permute_tensor_by_segments(
453453
segment_sizes,
454454
tensor,
455455
weights,
456-
tensor.numel(),
456+
output_size,
457457
)
458458
return permuted_tensor, permuted_weights
459459

torchrec/sparse/tests/test_jagged_tensor.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,6 +1400,59 @@ def test_permute_vb(self) -> None:
14001400
)
14011401
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
14021402

1403+
def test_permute_vb_duplicate(self) -> None:
1404+
values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
1405+
lengths = torch.IntTensor([1, 0, 1, 3, 0, 1, 0, 2, 0])
1406+
keys = ["index_0", "index_1", "index_2"]
1407+
stride_per_key_per_rank = [[2], [4], [3]]
1408+
1409+
jag_tensor = KeyedJaggedTensor.from_lengths_sync(
1410+
values=values,
1411+
keys=keys,
1412+
lengths=lengths,
1413+
stride_per_key_per_rank=stride_per_key_per_rank,
1414+
)
1415+
1416+
indices = [1, 1, 0, 0, 2, 2]
1417+
permuted_jag_tensor = jag_tensor.permute(indices)
1418+
1419+
self.assertEqual(
1420+
permuted_jag_tensor.keys(),
1421+
["index_1", "index_1", "index_0", "index_0", "index_2", "index_2"],
1422+
)
1423+
self.assertTrue(
1424+
torch.equal(
1425+
permuted_jag_tensor.values(),
1426+
torch.Tensor(
1427+
[
1428+
2.0,
1429+
3.0,
1430+
4.0,
1431+
5.0,
1432+
6.0,
1433+
2.0,
1434+
3.0,
1435+
4.0,
1436+
5.0,
1437+
6.0,
1438+
1.0,
1439+
1.0,
1440+
7.0,
1441+
8.0,
1442+
7.0,
1443+
8.0,
1444+
]
1445+
),
1446+
)
1447+
)
1448+
self.assertTrue(
1449+
torch.equal(
1450+
permuted_jag_tensor.lengths(),
1451+
torch.IntTensor([1, 3, 0, 1, 1, 3, 0, 1, 1, 0, 1, 0, 0, 2, 0, 0, 2, 0]),
1452+
)
1453+
)
1454+
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
1455+
14031456
def test_permute_duplicates(self) -> None:
14041457
values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
14051458
lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0])

torchrec/sparse/tests/test_jagged_tensor_gpu.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,13 @@ def test_regroup_backward(self) -> None:
116116
torch.allclose(actual_kt_0_grad, expected_kt_0_grad)
117117
torch.allclose(actual_kt_1_grad, expected_kt_1_grad)
118118

119+
120+
@skip_if_asan_class
121+
class TestKeyedJaggedTensorGPU(unittest.TestCase):
122+
def setUp(self) -> None:
123+
super().setUp()
124+
self.device = torch.cuda.current_device()
125+
119126
# pyre-ignore
120127
@unittest.skipIf(
121128
torch.cuda.device_count() <= 0,
@@ -187,6 +194,66 @@ def test_permute_vb(self) -> None:
187194
)
188195
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
189196

197+
# pyre-ignore
198+
@unittest.skipIf(
199+
torch.cuda.device_count() <= 0,
200+
"Not enough GPUs, this test requires at least one GPUs",
201+
)
202+
def test_permute_vb_duplicate(self) -> None:
203+
values = torch.tensor(
204+
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
205+
)
206+
lengths = torch.tensor([1, 0, 1, 3, 0, 1, 0, 2, 0], device=self.device)
207+
keys = ["index_0", "index_1", "index_2"]
208+
stride_per_key_per_rank = [[2], [4], [3]]
209+
210+
jag_tensor = KeyedJaggedTensor.from_lengths_sync(
211+
values=values,
212+
keys=keys,
213+
lengths=lengths,
214+
stride_per_key_per_rank=stride_per_key_per_rank,
215+
)
216+
217+
indices = [1, 1, 0, 0, 2, 2]
218+
permuted_jag_tensor = jag_tensor.permute(indices)
219+
220+
self.assertEqual(
221+
permuted_jag_tensor.keys(),
222+
["index_1", "index_1", "index_0", "index_0", "index_2", "index_2"],
223+
)
224+
self.assertTrue(
225+
torch.equal(
226+
permuted_jag_tensor.values().cpu(),
227+
torch.Tensor(
228+
[
229+
2.0,
230+
3.0,
231+
4.0,
232+
5.0,
233+
6.0,
234+
2.0,
235+
3.0,
236+
4.0,
237+
5.0,
238+
6.0,
239+
1.0,
240+
1.0,
241+
7.0,
242+
8.0,
243+
7.0,
244+
8.0,
245+
]
246+
),
247+
)
248+
)
249+
self.assertTrue(
250+
torch.equal(
251+
permuted_jag_tensor.lengths().cpu(),
252+
torch.IntTensor([1, 3, 0, 1, 1, 3, 0, 1, 1, 0, 1, 0, 0, 2, 0, 0, 2, 0]),
253+
)
254+
)
255+
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
256+
190257
# pyre-ignore
191258
@unittest.skipIf(
192259
torch.cuda.device_count() <= 0,

0 commit comments

Comments
 (0)