
Commit 4958d8e

TroyGarden authored and facebook-github-bot committed
FBGEMM kernel for KeyedTensor (PooledEmbedding) permute mapping
Summary:

# context
* Currently we have a working function `permute_pooled_embs_auto_grad` that does a full permute of KTs, including forward and backward.
* It has several limitations: a) it has to be a full permute, duplicates are not supported; b) in the main [use case](https://fburl.com/code/89od0rqm) there has to be a torch.concat on the input KTs, which is not very efficient; c) the function outputs a single KT, which then requires a split operation.
* There has been some attempt to support duplicated outputs, but the backward doesn't work.
* This diff creates a new kernel (named `multi_permute_pooled_embedding`) to support a multiple-KT to multiple-KT mapping operation with backward support.

# operator example usage
* used in python
```
# test inputs: 3 KTs with batch_size=2048
batch_size = 2048
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[96, 256], [512, 128, 768], [1024]]
values = [
    torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
    for lens in lengths
]

# target outputs: 4 KTs with re-arranged keys (features), duplicates are allowed
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]

# auxiliary arguments to the op/kernel
permutes, in_lengths, out_lengths = _multi_remap_to_groups(
    keys, lengths, groups
)

# arguments
outputs = torch.ops.fbgemm.permute_multi_embedding(
    values,  # list of tensors (on device)
    permutes.to(device=torch.device("cuda")),  # tensor on device
    out_lengths.tolist(),  # List[int] on CPU
    in_lengths.to(device=torch.device("cuda")),  # tensor on device
    out_lengths.to(device=torch.device("cuda")),  # tensor on device
)
```
* the resulting `permutes` tensor, shown here for the smaller per-key lengths `[[3, 4], [5, 6, 7], [8]]` used in the unit tests
```
permutes = tensor(
    [
        [0, 0, 0, 0, 3, 4],   # f1
        [1, 0, 0, 3, 5, 0],   # f3
        [0, 1, 3, 0, 4, 0],   # f2
        [1, 2, 5, 0, 6, 0],   # f4
        [0, 2, 0, 6, 3, -6],  # f1
        [2, 2, 0, 9, 8, 0],   # f6
        [0, 3, 0, 0, 3, -8],  # f1
        [1, 3, 11, 3, 7, 0],  # f5
    ]
)
```

# details
1. From the example usage above, we can clearly see that the operator takes in the following:
   a) values: List[torch.Tensor], which represents the input KTs
   b) permutes: torch.Tensor, which contains the permute information, explained below
   c) output_lengths_list: List[int], the lengths of the output tensors (KTs), which is needed to allocate memory on the device ahead of time
   d) in_lengths: torch.Tensor, lengths of the input tensors, which lives on the device
   e) out_lengths: torch.Tensor, lengths of the output tensors, which lives on the device
2. The operator returns a list of tensors, which represents the permuted KTs.
3. `permutes` is the most critical argument of this operator:
   a) it is a 2-D tensor
   b) each row represents one key (feature) permute move
   c) a permute move = [input_tensor_id, output_tensor_id, input_start_idx, output_start_idx, feature_length, jump]
   d) jump is used in the backward pass when a key (feature) from an input tensor is mapped to multiple places in the output tensors

# performance notes
The good:
1. The algorithm is designed so that it doesn't need to know in advance whether a 1-to-N mapping exists in the permutes.
2. `_all_keys_used_once` is no longer needed.
3. A torch.cat is no longer needed before calling the operator, unlike with the old one.

The same bad:
1. It still requires several HtoD transfers (moving tensors to the device):
   a) 3 tensors: `permutes`, `input_lengths`, and `output_lengths`. These tensors need to be on the device so that the CUDA kernels have access to them.
   b) 2 lists of (scalar_t*) pointers, for the input and output tensor lists.
   c) We didn't find a good way to let the kernel know the addresses of the input/output tensor lists directly, because those lists also need to be on the device.
2. tensor.contiguous is needed in the backward function; it looks like the grads coming from the backward pass are somehow not contiguous.

Differential Revision: D57055616
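For reference, below is a minimal pure-PyTorch sketch of what the forward pass of the operator computes, mirroring the reference logic used in the unit tests added in this diff. `permute_multi_embedding_ref` is a hypothetical name, not part of this commit; the committed op is the FBGEMM kernel itself, which additionally takes the `in_lengths`/`out_lengths` device tensors that a pure-Python reference does not need.

```python
from typing import List

import torch


def permute_multi_embedding_ref(
    values: List[torch.Tensor],
    permutes: torch.Tensor,
    out_lengths: List[int],
) -> List[torch.Tensor]:
    """Reference forward pass: copy each key's slice into its output slot."""
    batch_size = values[0].size(0)
    outputs = [values[0].new_zeros(batch_size, length) for length in out_lengths]
    # Each row of `permutes` describes one move:
    # [input_tensor_id, output_tensor_id, input_start, output_start, length, jump]
    for in_idx, out_idx, in_start, out_start, length, _jump in permutes.tolist():
        outputs[out_idx][:, out_start : out_start + length] = values[in_idx][
            :, in_start : in_start + length
        ]
    return outputs
```

Because a key may appear in several output groups, the forward pass simply copies its slice more than once; per the summary, the `jump` column only matters in the backward pass, where the gradients of those duplicated slices have to be accumulated back into a single input slice.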
1 parent be40210 commit 4958d8e

File tree

2 files changed: +235 -2 lines changed


torchrec/sparse/jagged_tensor.py

Lines changed: 67 additions & 0 deletions
```diff
@@ -240,6 +240,73 @@ def _remap_to_groups(
     return permute, inv_permute, offsets, inv_offsets, splits
 
 
+@torch.fx.wrap
+def _multi_remap_to_groups(
+    keys: List[List[str]],
+    key_lengths: List[List[int]],
+    groups: List[List[str]],
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Given the keys and lengths per key of each input tensor and the desired output groups, return the 2D permute tensor
+    [[input_tensor_idx, output_tensor_idx, input_start, output_start, length, jump]] and the 1D input/output length tensors.
+    """
+    # key => (tensor_idx, key_index)
+    key_map: Dict[str, Tuple[int, int]] = {
+        key: (tensor_idx, key_idx)
+        for tensor_idx, tensor in enumerate(keys)
+        for key_idx, key in enumerate(tensor)
+    }
+
+    # [offsets per tensor]
+    offsets_list: List[List[int]] = [_cumsum(tensor) for tensor in key_lengths]
+
+    # [input_tensor_idx, output_tensor_idx, input_start, output_start, length, jump]
+    permute_list: List[List[int]] = []
+    output_lengths: List[int] = [0] * len(groups)
+
+    total_permutes = sum(len(tensor) for tensor in groups)
+    last_seen: Dict[str, int] = {}
+    for output_tensor_idx, output_tensor in enumerate(groups):
+        output_start = 0
+        for output_key in output_tensor:
+            input_tensor_idx, input_key_idx = key_map[output_key]
+            input_start = offsets_list[input_tensor_idx][input_key_idx]
+            length = key_lengths[input_tensor_idx][input_key_idx]
+
+            # add jump data
+            if output_key in last_seen:
+                jump = last_seen[output_key]
+                if jump >= 0:  # it's a jump start
+                    # change previous jump to current
+                    permute_list[jump][5] = len(permute_list)
+                else:  # it's already in a jump sequence
+                    permute_list[-jump][5] = -len(permute_list)
+                last_seen[output_key] = -len(permute_list)  # it's already in jump
+                jump = -total_permutes
+            else:
+                jump = 0
+                last_seen[output_key] = len(permute_list)  # potential jump start
+
+            permute_list.append(
+                [
+                    input_tensor_idx,
+                    output_tensor_idx,
+                    input_start,
+                    output_start,
+                    length,
+                    jump,
+                ]
+            )
+            output_start += length
+        output_lengths[output_tensor_idx] = output_start
+    permutes = torch.tensor(permute_list, dtype=torch.int64)
+    in_lengths = torch.tensor(
+        [offsets[-1] for offsets in offsets_list], dtype=torch.int64
+    )
+    out_lengths = torch.tensor(output_lengths, dtype=torch.int64)
+    return permutes, in_lengths, out_lengths
+
+
 def _values_string(values: torch.Tensor, start: int, end: int) -> str:
     size = values.size()
     if len(size) == 1:
```
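A note on the `jump` column built by `_multi_remap_to_groups` above: reading the helper, each duplicated key appears to form a forward-linked chain of permute rows, where the first occurrence stores the (positive) row index of its next duplicate, later occurrences store the negated index of their successor, and the final occurrence stores `-total_permutes` as a sentinel. Below is a small sketch, using the same inputs as the unit tests further down, that walks this chain for the repeated "f1" key; it is an illustration of my reading of the encoding, not part of the commit.

```python
from torchrec.sparse.jagged_tensor import _multi_remap_to_groups

keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[3, 4], [5, 6, 7], [8]]
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]

permutes, in_lengths, out_lengths = _multi_remap_to_groups(keys, lengths, groups)
total_permutes = sum(len(group) for group in groups)  # 8 permute rows in total

# Follow the duplicate chain that starts at the first "f1" row (row 0):
# a positive jump points at the next duplicate, a negative jump is the negated
# index of the successor, and |jump| == total_permutes terminates the chain.
row, chain = 0, [0]
while True:
    jump = int(permutes[row, 5])
    if jump == 0 or abs(jump) == total_permutes:
        break
    row = abs(jump)
    chain.append(row)
print(chain)  # [0, 4, 6] -- the three permute rows that read "f1"
```

With these inputs the chain matches the three `# f1` rows in the expected `permutes` tensor checked by `test_multi_remap_to_group` below.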

torchrec/sparse/tests/test_jagged_tensor.py

Lines changed: 168 additions & 2 deletions
```diff
@@ -16,6 +16,7 @@
 from torch.testing import FileCheck
 from torchrec.fx import symbolic_trace
 from torchrec.sparse.jagged_tensor import (
+    _multi_remap_to_groups,
     _regroup_keyed_tensors,
     ComputeJTDictToKJT,
     ComputeKJTToJTDict,
@@ -1374,6 +1375,173 @@ def test_permute_vb(self) -> None:
         )
         self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
 
+    def test_multi_remap_to_group(self) -> None:
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        res, in_lengths, out_lengths = _multi_remap_to_groups(keys, lengths, groups)
+        ref = torch.tensor(
+            [
+                [0, 0, 0, 0, 3, 4],  # f1
+                [1, 0, 0, 3, 5, 0],  # f3
+                [0, 1, 3, 0, 4, 0],  # f2
+                [1, 2, 5, 0, 6, 0],  # f4
+                [0, 2, 0, 6, 3, -6],  # f1
+                [2, 2, 0, 9, 8, 0],  # f6
+                [0, 3, 0, 0, 3, -8],  # f1
+                [1, 3, 11, 3, 7, 0],  # f5
+            ]
+        )
+        self.assertEqual(in_lengths.tolist(), [7, 18, 8])
+        self.assertEqual(out_lengths.tolist(), [8, 4, 17, 10])
+        self.assertTrue(torch.equal(res, ref))
+
+    def test_multi_permute_forward_cpu(self) -> None:
+        batch_size = 5
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True)
+            for lens in lengths
+        ]
+        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
+            keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for in_idx, out_idx, in_start, _, length, _ in permutes.tolist():
+            refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, out_lengths.tolist(), in_lengths, out_lengths
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+    def test_multi_permute_forward_meta(self) -> None:
+        batch_size = 5
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="meta", requires_grad=True)
+            for lens in lengths
+        ]
+        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
+            keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for in_idx, out_idx, in_start, _, length, _ in permutes.tolist():
+            refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values, permutes, out_lengths.tolist(), in_lengths, out_lengths
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertEqual(out.shape, ref.shape)
+
+    def test_multi_permute_forward_gpu(self) -> None:
+        batch_size = 5
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
+            for lens in lengths
+        ]
+        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
+            keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for in_idx, out_idx, in_start, _, length, _ in permutes.tolist():
+            refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values,
+            permutes.to(device=torch.device("cuda")),
+            out_lengths.tolist(),
+            in_lengths.to(device=torch.device("cuda")),
+            out_lengths.to(device=torch.device("cuda")),
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+    def test_multi_permute_backward_cpu(self) -> None:
+        batch_size = 5
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True)
+            for lens in lengths
+        ]
+        ref_values = [v.detach() for v in values]
+        for v in ref_values:
+            v.requires_grad = True
+        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
+            keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for in_idx, out_idx, in_start, _, length, _ in permutes.tolist():
+            refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values,
+            permutes,
+            out_lengths.tolist(),
+            in_lengths,
+            out_lengths,
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+        ref_loss = sum((i + 1.1) * ref.sum() for i, ref in enumerate(refs))
+        self.assertTrue(isinstance(ref_loss, torch.Tensor))
+        ref_loss.backward()
+        loss = sum((i + 1.1) * out.sum() for i, out in enumerate(outputs))
+        self.assertTrue(isinstance(loss, torch.Tensor))
+        loss.backward()
+        for val, ref in zip(values, ref_values):
+            self.assertTrue(torch.allclose(val.grad, ref.grad))
+
+    def test_multi_permute_backward_gpu(self) -> None:
+        batch_size = 2048
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[96, 256], [512, 128, 768], [1024]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [
+            torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
+            for lens in lengths
+        ]
+        ref_values = [v.detach() for v in values]
+        for v in ref_values:
+            v.requires_grad = True
+        permutes, in_lengths, out_lengths = _multi_remap_to_groups(
+            keys, lengths, groups
+        )
+        refs = [[] for _ in groups]
+        for in_idx, out_idx, in_start, _, length, _ in permutes.tolist():
+            refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)])
+        refs = [torch.cat(ref, dim=1) for ref in refs]
+        outputs = torch.ops.fbgemm.permute_multi_embedding(
+            values,
+            permutes.to(device=torch.device("cuda")),
+            out_lengths.tolist(),
+            in_lengths.to(device=torch.device("cuda")),
+            out_lengths.to(device=torch.device("cuda")),
+        )
+        for out, ref in zip(outputs, refs):
+            self.assertTrue(torch.allclose(out, ref))
+
+        ref_loss = sum((i + 1.1) * ref.sum() for i, ref in enumerate(refs))
+        self.assertTrue(isinstance(ref_loss, torch.Tensor))
+        ref_loss.backward()
+        loss = sum((i + 1.1) * out.sum() for i, out in enumerate(outputs))
+        self.assertTrue(isinstance(loss, torch.Tensor))
+        loss.backward()
+        for val, ref in zip(values, ref_values):
+            self.assertTrue(torch.allclose(val.grad, ref.grad))
+
     def test_permute_duplicates(self) -> None:
         values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
         lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0])
@@ -1650,8 +1818,6 @@ def test_string_vb(self) -> None:
             stride_per_key_per_rank=stride_per_key_per_rank,
         )
 
-        print(str(jag_tensor))
-
         self.assertEqual(
             str(jag_tensor),
             """\
```