Skip to content

Commit 1bb9f10

Browse files
TroyGardenfacebook-github-bot
authored andcommitted
Split of "[TorchRec][PT2] KJT custom op for 1d lengths input" (#2176)
Summary: Pull Request resolved: #2176 Pull Request resolved: #2163 X-link: pytorch/FBGEMM#2774 # context * move the `tensor.view(-1, stride)` from python into the operator (c++) * make the PT2 complier happy * reference: D58948987 # notes * not sure if we should directly change the op call in the jagged_tensor * tested on CPU and GPU * backward/autograd not tested Reviewed By: bearzx Differential Revision: D59031938 fbshipit-source-id: 3a80e2acedf06f842ab7b105f1498d0e26fc12ae
1 parent 46707d7 commit 1bb9f10

File tree

2 files changed

+157
-3
lines changed

2 files changed

+157
-3
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1955,8 +1955,20 @@ def permute(
19551955
indices_tensor,
19561956
self.weights_or_none(),
19571957
)
1958+
elif is_torchdynamo_compiling():
1959+
(
1960+
permuted_lengths,
1961+
permuted_values,
1962+
permuted_weights,
1963+
) = torch.ops.fbgemm.permute_2D_sparse_data_input1D(
1964+
indices_tensor,
1965+
self.lengths(),
1966+
self.values(),
1967+
self.stride(),
1968+
self.weights_or_none(),
1969+
permuted_length_per_key_sum,
1970+
)
19581971
else:
1959-
19601972
(
19611973
permuted_lengths,
19621974
permuted_values,
@@ -2338,7 +2350,20 @@ def dist_init(
23382350
s == stride for s in stride_per_rank
23392351
)
23402352

2341-
if single_batch_per_rank:
2353+
if single_batch_per_rank and is_torchdynamo_compiling():
2354+
(
2355+
lengths,
2356+
values,
2357+
weights,
2358+
) = torch.ops.fbgemm.permute_2D_sparse_data_input1D(
2359+
torch.jit._unwrap_optional(recat),
2360+
lengths,
2361+
values,
2362+
stride,
2363+
weights,
2364+
values.numel(),
2365+
)
2366+
elif single_batch_per_rank:
23422367
(
23432368
lengths,
23442369
values,

torchrec/sparse/tests/test_jagged_tensor_gpu.py

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
import unittest
1212

1313
import torch
14-
from torchrec.sparse.jagged_tensor import _regroup_keyed_tensors, KeyedTensor
14+
from torchrec.sparse.jagged_tensor import (
15+
_regroup_keyed_tensors,
16+
KeyedJaggedTensor,
17+
KeyedTensor,
18+
)
1519
from torchrec.sparse.tests.utils import build_groups, build_kts
1620
from torchrec.test_utils import skip_if_asan_class
1721

@@ -111,3 +115,128 @@ def test_regroup_backward(self) -> None:
111115

112116
torch.allclose(actual_kt_0_grad, expected_kt_0_grad)
113117
torch.allclose(actual_kt_1_grad, expected_kt_1_grad)
118+
119+
# pyre-ignore
120+
@unittest.skipIf(
121+
torch.cuda.device_count() <= 0,
122+
"Not enough GPUs, this test requires at least one GPUs",
123+
)
124+
def test_permute(self) -> None:
125+
values = torch.tensor(
126+
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
127+
)
128+
lengths = torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device)
129+
keys = ["index_0", "index_1", "index_2"]
130+
131+
jag_tensor = KeyedJaggedTensor.from_lengths_sync(
132+
values=values,
133+
keys=keys,
134+
lengths=lengths,
135+
)
136+
indices = [1, 0, 2]
137+
permuted_jag_tensor = jag_tensor.permute(indices)
138+
139+
self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"])
140+
self.assertEqual(
141+
permuted_jag_tensor.offset_per_key(),
142+
[0, 3, 5, 8],
143+
)
144+
self.assertEqual(
145+
permuted_jag_tensor.values().tolist(),
146+
[3.0, 4.0, 5.0, 1.0, 2.0, 6.0, 7.0, 8.0],
147+
)
148+
self.assertEqual(
149+
permuted_jag_tensor.lengths().tolist(), [1, 1, 1, 0, 2, 0, 0, 3, 0]
150+
)
151+
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
152+
153+
# pyre-ignore
154+
@unittest.skipIf(
155+
torch.cuda.device_count() <= 0,
156+
"Not enough GPUs, this test requires at least one GPUs",
157+
)
158+
def test_permute_vb(self) -> None:
159+
values = torch.tensor(
160+
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
161+
)
162+
lengths = torch.tensor([1, 0, 1, 3, 0, 1, 0, 2, 0], device=self.device)
163+
keys = ["index_0", "index_1", "index_2"]
164+
stride_per_key_per_rank = [[2], [4], [3]]
165+
166+
jag_tensor = KeyedJaggedTensor.from_lengths_sync(
167+
values=values,
168+
keys=keys,
169+
lengths=lengths,
170+
stride_per_key_per_rank=stride_per_key_per_rank,
171+
)
172+
173+
indices = [1, 0, 2]
174+
permuted_jag_tensor = jag_tensor.permute(indices)
175+
176+
self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"])
177+
self.assertEqual(
178+
permuted_jag_tensor.offset_per_key(),
179+
[0, 5, 6, 8],
180+
)
181+
self.assertEqual(
182+
permuted_jag_tensor.values().tolist(),
183+
[2.0, 3.0, 4.0, 5.0, 6.0, 1.0, 7.0, 8.0],
184+
)
185+
self.assertEqual(
186+
permuted_jag_tensor.lengths().tolist(), [1, 3, 0, 1, 1, 0, 0, 2, 0]
187+
)
188+
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)
189+
190+
# pyre-ignore
191+
@unittest.skipIf(
192+
torch.cuda.device_count() <= 0,
193+
"Not enough GPUs, this test requires at least one GPUs",
194+
)
195+
def test_permute_duplicates(self) -> None:
196+
values = torch.tensor(
197+
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device
198+
)
199+
lengths = torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device)
200+
keys = ["index_0", "index_1", "index_2"]
201+
202+
jag_tensor = KeyedJaggedTensor.from_lengths_sync(
203+
values=values,
204+
keys=keys,
205+
lengths=lengths,
206+
)
207+
208+
indices = [1, 0, 2, 1, 1]
209+
permuted_jag_tensor = jag_tensor.permute(indices)
210+
211+
self.assertEqual(
212+
permuted_jag_tensor.keys(),
213+
["index_1", "index_0", "index_2", "index_1", "index_1"],
214+
)
215+
self.assertEqual(
216+
permuted_jag_tensor.offset_per_key(),
217+
[0, 3, 5, 8, 11, 14],
218+
)
219+
self.assertEqual(
220+
permuted_jag_tensor.values().tolist(),
221+
[
222+
3.0,
223+
4.0,
224+
5.0,
225+
1.0,
226+
2.0,
227+
6.0,
228+
7.0,
229+
8.0,
230+
3.0,
231+
4.0,
232+
5.0,
233+
3.0,
234+
4.0,
235+
5.0,
236+
],
237+
)
238+
self.assertEqual(
239+
permuted_jag_tensor.lengths().tolist(),
240+
[1, 1, 1, 0, 2, 0, 0, 3, 0, 1, 1, 1, 1, 1, 1],
241+
)
242+
self.assertEqual(permuted_jag_tensor.weights_or_none(), None)

0 commit comments

Comments
 (0)