torchrec/sparse/jagged_tensor.py (108 additions, 0 deletions)
@@ -36,6 +36,12 @@
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu"
)
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_cpu"
)
torch.ops.load_library(
"//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu"
)
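# these libraries provide the torch.ops.fbgemm.permute_multi_embedding operator
# used below; as with the existing loads above, loading is best-effort and an
# OSError is ignored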
except OSError:
pass

@@ -164,6 +170,24 @@ def _all_keys_used_once(
return len(key_set) == len(group_set) == len(flat_keys) == len(flat_groups)


@torch.fx.wrap
def permute_multi_embedding(
keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
) -> List[torch.Tensor]:
keys, lengths, values = _desugar_keyed_tensors(keyed_tensors)
permutes, in_shape, out_shape, out_lengths = _kt_regroup_permutes(
values[0], keys, lengths, groups
)
permuted_values = torch.ops.fbgemm.permute_multi_embedding(
values,
permutes,
in_shape,
out_shape,
out_lengths,
)
return permuted_values
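# Example (illustrative sketch; the keys, shapes, and KeyedTensor constructor
# arguments below are assumptions for demonstration):
#
#     kt1 = KeyedTensor(keys=["f1", "f2"], length_per_key=[3, 4],
#                       values=torch.randn(8, 7), key_dim=1)
#     kt2 = KeyedTensor(keys=["f3"], length_per_key=[5],
#                       values=torch.randn(8, 5), key_dim=1)
#     out = permute_multi_embedding([kt1, kt2], groups=[["f1", "f3"], ["f2"]])
#     # out[0] has shape (8, 3 + 5) and out[1] has shape (8, 4)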


@torch.fx.wrap
def _fbgemm_permute_pooled_embs(
keyed_tensors: List["KeyedTensor"], groups: List[List["str"]]
@@ -240,6 +264,90 @@ def _remap_to_groups(
return permute, inv_permute, offsets, inv_offsets, splits


def _kt_regroup_permutes(
value: torch.Tensor,
keys: List[List[str]],
key_lengths: List[List[int]],
groups: List[List[str]],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
"""
returns: permutes, in_shapes, out_shapes, out_lengths
"""
# key => (tensor_idx, key_index)
key_map: Dict[str, Tuple[int, int]] = {
key: (tensor_idx, key_idx)
for tensor_idx, tensor in enumerate(keys)
for key_idx, key in enumerate(tensor)
}

# [offsets per tensor]
in_offsets: List[List[int]] = [[] for _ in key_lengths]
for i, tensor in enumerate(key_lengths):
in_offsets[i] = _cumsum(tensor)
in_lengths: List[int] = [sum(lengths) for lengths in key_lengths]

# total_permutes is used as the out-of-range sentinel that marks the end of a jump chain
total_permutes: int = sum(len(tensor) for tensor in groups)
out_lengths: List[int] = [0] * len(groups)

# [input_tensor_idx, output_tensor_idx, input_start, output_start, length, jump]
permute_param = 6
permutes: List[List[int]] = [[0] * permute_param for _ in range(total_permutes)]

# record the index at which each key was last seen, so a jump can link that entry to the current one
last_seen: Dict[str, int] = {}
permute_idx = 0
for output_tensor_idx, output_tensor in enumerate(groups):
output_start = 0
for output_key in output_tensor:
input_tensor_idx, input_key_idx = key_map[output_key]
input_start = in_offsets[input_tensor_idx][input_key_idx]
length = key_lengths[input_tensor_idx][input_key_idx]

# fill in the jump column
if output_key not in last_seen:
jump = 0  # no jump needed yet
# store a positive index: this entry is a potential jump start
last_seen[output_key] = permute_idx
else:
prev = last_seen[output_key]
if prev >= 0:  # positive ==> that entry is the jump start
# point the jump start at the current index (positive value)
permutes[prev][5] = permute_idx
else:  # already inside a jump chain, so link it with a negative index
permutes[-prev][5] = -permute_idx
# store a negative index since the key is already part of a jump chain
last_seen[output_key] = -permute_idx
# mark this entry as a potential jump stop
jump = -total_permutes

permutes[permute_idx][:] = [
input_tensor_idx,
output_tensor_idx,
input_start,
output_start,
length,
jump,
]
permute_idx += 1
output_start += length
out_lengths[output_tensor_idx] = output_start

permute_tensor = torch.tensor(permutes, dtype=torch.int32)
in_shapes = torch.tensor(in_lengths, dtype=torch.int32)
out_shapes = torch.tensor(out_lengths, dtype=torch.int32)
device = value.device
permute_tensor = _pin_and_move(permute_tensor, device)
in_shapes = _pin_and_move(in_shapes, device)
out_shapes = _pin_and_move(out_shapes, device)
return (
permute_tensor,
in_shapes,
out_shapes,
out_lengths,
)
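# A pure-PyTorch sketch of what the fbgemm kernel computes from this plan; it
# mirrors the reference computation used in the unit tests below. Here `values`
# is the list of input tensors and `groups` the requested key grouping.
#
#     refs: List[List[torch.Tensor]] = [[] for _ in groups]
#     for i in range(permutes.size(0)):
#         in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist()
#         refs[out_idx].append(values[in_idx][:, in_start : in_start + length])
#     outputs = [torch.cat(ref, dim=1) for ref in refs]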


def _values_string(values: torch.Tensor, start: int, end: int) -> str:
size = values.size()
if len(size) == 1:
torchrec/sparse/tests/test_jagged_tensor.py (187 additions, 2 deletions)
@@ -16,6 +16,7 @@
from torch.testing import FileCheck
from torchrec.fx import symbolic_trace
from torchrec.sparse.jagged_tensor import (
_kt_regroup_permutes,
_regroup_keyed_tensors,
ComputeJTDictToKJT,
ComputeKJTToJTDict,
@@ -1374,6 +1375,192 @@ def test_permute_vb(self) -> None:
)
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)

def test_kt_regroup_permutes(self) -> None:
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[3, 4], [5, 6, 7], [8]]
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
for device in ["cpu", "meta", "cuda"]:
if device == "cuda" and not torch.cuda.is_available():
continue
device = torch.device(device)
permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
torch.empty(0, device=device), keys, lengths, groups
)
ref_permutes = [
[0, 0, 0, 0, 3, 4],  # f1: jump start, next occurrence at index 4
[1, 0, 0, 3, 5, 0],  # f3
[0, 1, 3, 0, 4, 0],  # f2
[1, 2, 5, 0, 6, 0],  # f4
[0, 2, 0, 6, 3, -6],  # f1: mid jump chain, next occurrence at index 6
[2, 2, 0, 9, 8, 0],  # f6
[0, 3, 0, 0, 3, -8],  # f1: jump stop (out-of-range sentinel, -total_permutes)
[1, 3, 11, 3, 7, 0],  # f5
]
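# per-input-tensor widths: [3 + 4, 5 + 6 + 7, 8] = [7, 18, 8]
# per-output-group widths: [3 + 5, 4, 6 + 3 + 8, 3 + 7] = [8, 4, 17, 10]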
if device.type == "meta":
self.assertEqual(
permutes.shape, (len(ref_permutes), len(ref_permutes[0]))
)
self.assertEqual(in_shapes.shape, (3,))
self.assertEqual(out_shapes.shape, (4,))
else:
self.assertTrue(
torch.equal(
permutes,
torch.tensor(ref_permutes, dtype=torch.int32, device=device),
)
)
self.assertEqual(in_shapes.tolist(), [7, 18, 8])
self.assertEqual(out_shapes.tolist(), [8, 4, 17, 10])
self.assertEqual(out_lengths, [8, 4, 17, 10])

def test_multi_permute_forward_cpu(self) -> None:
batch_size = 32
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[3, 4], [5, 6, 7], [8]]
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
values = [
torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True)
for lens in lengths
]
permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
values[0], keys, lengths, groups
)
refs = [[] for _ in groups]
for i in range(permutes.size(0)):
in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist()
refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
refs = [torch.cat(ref, dim=1) for ref in refs]
outputs = torch.ops.fbgemm.permute_multi_embedding(
values, permutes, in_shapes, out_shapes, out_lengths
)
for out, ref in zip(outputs, refs):
self.assertTrue(torch.allclose(out, ref))

def test_multi_permute_forward_meta(self) -> None:
batch_size = 32
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[3, 4], [5, 6, 7], [8]]
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
values = [
torch.randn(batch_size, sum(lens), device="meta", requires_grad=True)
for lens in lengths
]
permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
values[0], keys, lengths, groups
)
outputs = torch.ops.fbgemm.permute_multi_embedding(
values, permutes, in_shapes, out_shapes, out_lengths
)
for out, ref in zip(outputs, out_lengths):
self.assertEqual(out.shape, (batch_size, ref))

# pyre-ignore[56]
@unittest.skipIf(
torch.cuda.device_count() <= 0,
"CUDA is not available",
)
def test_multi_permute_forward_gpu(self) -> None:
batch_size = 1024
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[96, 256], [512, 128, 768], [1024]]
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
values = [
torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
for lens in lengths
]
permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
values[0], keys, lengths, groups
)
refs = [[] for _ in groups]
for i in range(permutes.size(0)):
in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist()
refs[out_idx].append(values[in_idx][:, in_start : (in_start + length)])
refs = [torch.cat(ref, dim=1) for ref in refs]
outputs = torch.ops.fbgemm.permute_multi_embedding(
values, permutes, in_shapes, out_shapes, out_lengths
)
for out, ref in zip(outputs, refs):
self.assertTrue(torch.allclose(out, ref))

def test_multi_permute_backward_cpu(self) -> None:
batch_size = 32
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[3, 4], [5, 6, 7], [8]]
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
values = [
torch.randn(batch_size, sum(lens), device="cpu", requires_grad=True)
for lens in lengths
]
ref_values = [v.detach() for v in values]
for v in ref_values:
v.requires_grad = True
permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
values[0], keys, lengths, groups
)
refs = [[] for _ in groups]
for i in range(permutes.size(0)):
in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist()
refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)])
refs = [torch.cat(ref, dim=1) for ref in refs]
outputs = torch.ops.fbgemm.permute_multi_embedding(
values, permutes, in_shapes, out_shapes, out_lengths
)
for out, ref in zip(outputs, refs):
self.assertTrue(torch.allclose(out, ref))

ref_loss, loss = refs[0].sum(), outputs[0].sum()
for i in range(1, len(refs)):
ref_loss += (i + 1.1) * refs[i].sum()
loss += (i + 1.1) * outputs[i].sum()
ref_loss.backward()
loss.backward()
for val, ref in zip(values, ref_values):
val_grad, ref_grad = val.grad, ref.grad
assert isinstance(val_grad, torch.Tensor)
self.assertTrue(torch.allclose(val_grad, ref_grad))

# pyre-ignore[56]
@unittest.skipIf(
torch.cuda.device_count() <= 0,
"CUDA is not available",
)
def test_multi_permute_backward_gpu(self) -> None:
batch_size = 2048
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[96, 256], [512, 128, 768], [1024]]
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
values = [
torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
for lens in lengths
]
ref_values = [v.detach() for v in values]
for v in ref_values:
v.requires_grad = True
permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_permutes(
values[0], keys, lengths, groups
)
refs = [[] for _ in groups]
for i in range(permutes.size(0)):
in_idx, out_idx, in_start, _, length, _ = permutes[i].tolist()
refs[out_idx].append(ref_values[in_idx][:, in_start : (in_start + length)])
refs = [torch.cat(ref, dim=1) for ref in refs]
outputs = torch.ops.fbgemm.permute_multi_embedding(
values, permutes, in_shapes, out_shapes, out_lengths
)
for out, ref in zip(outputs, refs):
self.assertTrue(torch.allclose(out, ref))

ref_loss, loss = refs[0].sum(), outputs[0].sum()
for i in range(1, len(refs)):
ref_loss += (i + 1.1) * refs[i].sum()
loss += (i + 1.1) * outputs[i].sum()
ref_loss.backward()
loss.backward()
for val, ref in zip(values, ref_values):
val_grad, ref_grad = val.grad, ref.grad
assert isinstance(val_grad, torch.Tensor)
self.assertTrue(torch.allclose(val_grad, ref_grad))

def test_permute_duplicates(self) -> None:
values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
lengths = torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3, 0])
@@ -1650,8 +1837,6 @@ def test_string_vb(self) -> None:
stride_per_key_per_rank=stride_per_key_per_rank,
)

print(str(jag_tensor))

self.assertEqual(
str(jag_tensor),
"""\