|
16 | 16 | from torch.testing import FileCheck
|
17 | 17 | from torchrec.fx import symbolic_trace
|
18 | 18 | from torchrec.sparse.jagged_tensor import (
|
| 19 | + _multi_remap_to_groups, |
19 | 20 | _regroup_keyed_tensors,
|
20 | 21 | ComputeJTDictToKJT,
|
21 | 22 | ComputeKJTToJTDict,
|
@@ -1374,6 +1375,173 @@ def test_permute_vb(self) -> None:
|
1374 | 1375 | )
|
1375 | 1376 | self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
|
1376 | 1377 |
|
| 1378 | + def test_multi_remap_to_group(self) -> None: |
| 1379 | + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] |
| 1380 | + lengths = [[3, 4], [5, 6, 7], [8]] |
| 1381 | + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] |
| 1382 | + res, in_lengths, out_lengths = _multi_remap_to_groups(keys, lengths, groups) |
| 1383 | + ref = torch.tensor( |
| 1384 | + [ |
| 1385 | + [0, 0, 0, 0, 3, 4], # f1 |
| 1386 | + [1, 0, 0, 3, 5, 0], # f3 |
| 1387 | + [0, 1, 3, 0, 4, 0], # f2 |
| 1388 | + [1, 2, 5, 0, 6, 0], # f4 |
| 1389 | + [0, 2, 0, 6, 3, -6], # f1 |
| 1390 | + [2, 2, 0, 9, 8, 0], # f6 |
| 1391 | + [0, 3, 0, 0, 3, -8], # f1 |
| 1392 | + [1, 3, 11, 3, 7, 0], # f5 |
| 1393 | + ] |
| 1394 | + ) |
| 1395 | + self.assertEqual(in_lengths.tolist(), [7, 18, 8]) |
| 1396 | + self.assertEqual(out_lengths.tolist(), [8, 4, 17, 10]) |
| 1397 | + self.assertTrue(torch.equal(res, ref)) |
| 1398 | + |
| 1399 | + def test_multi_permute_forward_cpu(self) -> None: |
| 1400 | + batch_size = 5 |
| 1401 | + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] |
| 1402 | + lengths = [[3, 4], [5, 6, 7], [8]] |
| 1403 | + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] |
| 1404 | + values = [ |
| 1405 | + torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True) |
| 1406 | + for lens in lengths |
| 1407 | + ] |
| 1408 | + permutes, in_lengths, out_lengths = _multi_remap_to_groups( |
| 1409 | + keys, lengths, groups |
| 1410 | + ) |
| 1411 | + refs = [[] for _ in groups] |
| 1412 | + for in_idx, out_idx, in_start, _, length, _ in permutes.tolist(): |
| 1413 | + refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)]) |
| 1414 | + refs = [torch.cat(ref, dim=1) for ref in refs] |
| 1415 | + outputs = torch.ops.fbgemm.permute_multi_embedding( |
| 1416 | + values, permutes, out_lengths.tolist(), in_lengths, out_lengths |
| 1417 | + ) |
| 1418 | + for out, ref in zip(outputs, refs): |
| 1419 | + self.assertTrue(torch.allclose(out, ref)) |
| 1420 | + |
| 1421 | + def test_multi_permute_forward_meta(self) -> None: |
| 1422 | + batch_size = 5 |
| 1423 | + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] |
| 1424 | + lengths = [[3, 4], [5, 6, 7], [8]] |
| 1425 | + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] |
| 1426 | + values = [ |
| 1427 | + torch.randn(batch_size, sum(lens), device="meta", requires_grad=True) |
| 1428 | + for lens in lengths |
| 1429 | + ] |
| 1430 | + permutes, in_lengths, out_lengths = _multi_remap_to_groups( |
| 1431 | + keys, lengths, groups |
| 1432 | + ) |
| 1433 | + refs = [[] for _ in groups] |
| 1434 | + for in_idx, out_idx, in_start, _, length, _ in permutes.tolist(): |
| 1435 | + refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)]) |
| 1436 | + refs = [torch.cat(ref, dim=1) for ref in refs] |
| 1437 | + outputs = torch.ops.fbgemm.permute_multi_embedding( |
| 1438 | + values, permutes, out_lengths.tolist(), in_lengths, out_lengths |
| 1439 | + ) |
| 1440 | + for out, ref in zip(outputs, refs): |
| 1441 | + self.assertEqual(out.shape, ref.shape) |
| 1442 | + |
| 1443 | + def test_multi_permute_forward_gpu(self) -> None: |
| 1444 | + batch_size = 5 |
| 1445 | + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] |
| 1446 | + lengths = [[3, 4], [5, 6, 7], [8]] |
| 1447 | + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] |
| 1448 | + values = [ |
| 1449 | + torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True) |
| 1450 | + for lens in lengths |
| 1451 | + ] |
| 1452 | + permutes, in_lengths, out_lengths = _multi_remap_to_groups( |
| 1453 | + keys, lengths, groups |
| 1454 | + ) |
| 1455 | + refs = [[] for _ in groups] |
| 1456 | + for in_idx, out_idx, in_start, _, length, _ in permutes.tolist(): |
| 1457 | + refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)]) |
| 1458 | + refs = [torch.cat(ref, dim=1) for ref in refs] |
| 1459 | + outputs = torch.ops.fbgemm.permute_multi_embedding( |
| 1460 | + values, |
| 1461 | + permutes.to(device=torch.device("cuda")), |
| 1462 | + out_lengths.tolist(), |
| 1463 | + in_lengths.to(device=torch.device("cuda")), |
| 1464 | + out_lengths.to(device=torch.device("cuda")), |
| 1465 | + ) |
| 1466 | + for out, ref in zip(outputs, refs): |
| 1467 | + self.assertTrue(torch.allclose(out, ref)) |
| 1468 | + |
| 1469 | + def test_multi_permute_backward_cpu(self) -> None: |
| 1470 | + batch_size = 5 |
| 1471 | + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] |
| 1472 | + lengths = [[3, 4], [5, 6, 7], [8]] |
| 1473 | + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] |
| 1474 | + values = [ |
| 1475 | + torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True) |
| 1476 | + for lens in lengths |
| 1477 | + ] |
| 1478 | + ref_values = [v.detach() for v in values] |
| 1479 | + for v in ref_values: |
| 1480 | + v.requires_grad = True |
| 1481 | + permutes, in_lengths, out_lengths = _multi_remap_to_groups( |
| 1482 | + keys, lengths, groups |
| 1483 | + ) |
| 1484 | + refs = [[] for _ in groups] |
| 1485 | + for in_idx, out_idx, in_start, _, length, _ in permutes.tolist(): |
| 1486 | + refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)]) |
| 1487 | + refs = [torch.cat(ref, dim=1) for ref in refs] |
| 1488 | + outputs = torch.ops.fbgemm.permute_multi_embedding( |
| 1489 | + values, |
| 1490 | + permutes, |
| 1491 | + out_lengths.tolist(), |
| 1492 | + in_lengths, |
| 1493 | + out_lengths, |
| 1494 | + ) |
| 1495 | + for out, ref in zip(outputs, refs): |
| 1496 | + self.assertTrue(torch.allclose(out, ref)) |
| 1497 | + |
| 1498 | + ref_loss = sum((i + 1.1) * ref.sum() for i, ref in enumerate(refs)) |
| 1499 | + self.assertTrue(isinstance(ref_loss, torch.Tensor)) |
| 1500 | + ref_loss.backward() |
| 1501 | + loss = sum((i + 1.1) * out.sum() for i, out in enumerate(outputs)) |
| 1502 | + self.assertTrue(isinstance(loss, torch.Tensor)) |
| 1503 | + loss.backward() |
| 1504 | + for val, ref in zip(values, ref_values): |
| 1505 | + self.assertTrue(torch.allclose(val.grad, ref.grad)) |
| 1506 | + |
| 1507 | + def test_multi_permute_backward_gpu(self) -> None: |
| 1508 | + batch_size = 2048 |
| 1509 | + keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]] |
| 1510 | + lengths = [[96, 256], [512, 128, 768], [1024]] |
| 1511 | + groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]] |
| 1512 | + values = [ |
| 1513 | + torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True) |
| 1514 | + for lens in lengths |
| 1515 | + ] |
| 1516 | + ref_values = [v.detach() for v in values] |
| 1517 | + for v in ref_values: |
| 1518 | + v.requires_grad = True |
| 1519 | + permutes, in_lengths, out_lengths = _multi_remap_to_groups( |
| 1520 | + keys, lengths, groups |
| 1521 | + ) |
| 1522 | + refs = [[] for _ in groups] |
| 1523 | + for in_idx, out_idx, in_start, _, length, _ in permutes.tolist(): |
| 1524 | + refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)]) |
| 1525 | + refs = [torch.cat(ref, dim=1) for ref in refs] |
| 1526 | + outputs = torch.ops.fbgemm.permute_multi_embedding( |
| 1527 | + values, |
| 1528 | + permutes.to(device=torch.device("cuda")), |
| 1529 | + out_lengths.tolist(), |
| 1530 | + in_lengths.to(device=torch.device("cuda")), |
| 1531 | + out_lengths.to(device=torch.device("cuda")), |
| 1532 | + ) |
| 1533 | + for out, ref in zip(outputs, refs): |
| 1534 | + self.assertTrue(torch.allclose(out, ref)) |
| 1535 | + |
| 1536 | + ref_loss = sum((i + 1.1) * ref.sum() for i, ref in enumerate(refs)) |
| 1537 | + self.assertTrue(isinstance(ref_loss, torch.Tensor)) |
| 1538 | + ref_loss.backward() |
| 1539 | + loss = sum((i + 1.1) * out.sum() for i, out in enumerate(outputs)) |
| 1540 | + loss = sum((i + 1.1) * out.sum() for i, out in enumerate(outputs)) |
| 1541 | + loss.backward() |
| 1542 | + for val, ref in zip(values, ref_values): |
| 1543 | + self.assertTrue(torch.allclose(val.grad, ref.grad)) |
| 1544 | + |
1377 | 1545 | def test_permute_duplicates(self) -> None:
|
1378 | 1546 | values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
|
1379 | 1547 | lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0])
|
@@ -1650,8 +1818,6 @@ def test_string_vb(self) -> None:
|
1650 | 1818 | stride_per_key_per_rank=stride_per_key_per_rank,
|
1651 | 1819 | )
|
1652 | 1820 |
|
1653 |
| - print(str(jag_tensor)) |
1654 |
| - |
1655 | 1821 | self.assertEqual(
|
1656 | 1822 | str(jag_tensor),
|
1657 | 1823 | """\
|
|
0 commit comments