|
11 | 11 | import unittest
|
12 | 12 |
|
13 | 13 | import torch
|
14 |
| -from torchrec.sparse.jagged_tensor import _regroup_keyed_tensors, KeyedTensor |
| 14 | +from torchrec.sparse.jagged_tensor import ( |
| 15 | + _regroup_keyed_tensors, |
| 16 | + KeyedJaggedTensor, |
| 17 | + KeyedTensor, |
| 18 | +) |
15 | 19 | from torchrec.sparse.tests.utils import build_groups, build_kts
|
16 | 20 | from torchrec.test_utils import skip_if_asan_class
|
17 | 21 |
|
@@ -111,3 +115,128 @@ def test_regroup_backward(self) -> None:
|
111 | 115 |
|
112 | 116 | torch.allclose(actual_kt_0_grad, expected_kt_0_grad)
|
113 | 117 | torch.allclose(actual_kt_1_grad, expected_kt_1_grad)
|
| 118 | + |
| 119 | + # pyre-ignore |
| 120 | + @unittest.skipIf( |
| 121 | + torch.cuda.device_count() <= 0, |
| 122 | + "Not enough GPUs, this test requires at least one GPUs", |
| 123 | + ) |
| 124 | + def test_permute(self) -> None: |
| 125 | + values = torch.tensor( |
| 126 | + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device |
| 127 | + ) |
| 128 | + lengths = torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device) |
| 129 | + keys = ["index_0", "index_1", "index_2"] |
| 130 | + |
| 131 | + jag_tensor = KeyedJaggedTensor.from_lengths_sync( |
| 132 | + values=values, |
| 133 | + keys=keys, |
| 134 | + lengths=lengths, |
| 135 | + ) |
| 136 | + indices = [1, 0, 2] |
| 137 | + permuted_jag_tensor = jag_tensor.permute(indices) |
| 138 | + |
| 139 | + self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"]) |
| 140 | + self.assertEqual( |
| 141 | + permuted_jag_tensor.offset_per_key(), |
| 142 | + [0, 3, 5, 8], |
| 143 | + ) |
| 144 | + self.assertEqual( |
| 145 | + permuted_jag_tensor.values().tolist(), |
| 146 | + [3.0, 4.0, 5.0, 1.0, 2.0, 6.0, 7.0, 8.0], |
| 147 | + ) |
| 148 | + self.assertEqual( |
| 149 | + permuted_jag_tensor.lengths().tolist(), [1, 1, 1, 0, 2, 0, 0, 3, 0] |
| 150 | + ) |
| 151 | + self.assertEqual(permuted_jag_tensor.weights_or_none(), None) |
| 152 | + |
| 153 | + # pyre-ignore |
| 154 | + @unittest.skipIf( |
| 155 | + torch.cuda.device_count() <= 0, |
| 156 | + "Not enough GPUs, this test requires at least one GPUs", |
| 157 | + ) |
| 158 | + def test_permute_vb(self) -> None: |
| 159 | + values = torch.tensor( |
| 160 | + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device |
| 161 | + ) |
| 162 | + lengths = torch.tensor([1, 0, 1, 3, 0, 1, 0, 2, 0], device=self.device) |
| 163 | + keys = ["index_0", "index_1", "index_2"] |
| 164 | + stride_per_key_per_rank = [[2], [4], [3]] |
| 165 | + |
| 166 | + jag_tensor = KeyedJaggedTensor.from_lengths_sync( |
| 167 | + values=values, |
| 168 | + keys=keys, |
| 169 | + lengths=lengths, |
| 170 | + stride_per_key_per_rank=stride_per_key_per_rank, |
| 171 | + ) |
| 172 | + |
| 173 | + indices = [1, 0, 2] |
| 174 | + permuted_jag_tensor = jag_tensor.permute(indices) |
| 175 | + |
| 176 | + self.assertEqual(permuted_jag_tensor.keys(), ["index_1", "index_0", "index_2"]) |
| 177 | + self.assertEqual( |
| 178 | + permuted_jag_tensor.offset_per_key(), |
| 179 | + [0, 5, 6, 8], |
| 180 | + ) |
| 181 | + self.assertEqual( |
| 182 | + permuted_jag_tensor.values().tolist(), |
| 183 | + [2.0, 3.0, 4.0, 5.0, 6.0, 1.0, 7.0, 8.0], |
| 184 | + ) |
| 185 | + self.assertEqual( |
| 186 | + permuted_jag_tensor.lengths().tolist(), [1, 3, 0, 1, 1, 0, 0, 2, 0] |
| 187 | + ) |
| 188 | + self.assertEqual(permuted_jag_tensor.weights_or_none(), None) |
| 189 | + |
| 190 | + # pyre-ignore |
| 191 | + @unittest.skipIf( |
| 192 | + torch.cuda.device_count() <= 0, |
| 193 | + "Not enough GPUs, this test requires at least one GPUs", |
| 194 | + ) |
| 195 | + def test_permute_duplicates(self) -> None: |
| 196 | + values = torch.tensor( |
| 197 | + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], device=self.device |
| 198 | + ) |
| 199 | + lengths = torch.tensor([0, 2, 0, 1, 1, 1, 0, 3, 0], device=self.device) |
| 200 | + keys = ["index_0", "index_1", "index_2"] |
| 201 | + |
| 202 | + jag_tensor = KeyedJaggedTensor.from_lengths_sync( |
| 203 | + values=values, |
| 204 | + keys=keys, |
| 205 | + lengths=lengths, |
| 206 | + ) |
| 207 | + |
| 208 | + indices = [1, 0, 2, 1, 1] |
| 209 | + permuted_jag_tensor = jag_tensor.permute(indices) |
| 210 | + |
| 211 | + self.assertEqual( |
| 212 | + permuted_jag_tensor.keys(), |
| 213 | + ["index_1", "index_0", "index_2", "index_1", "index_1"], |
| 214 | + ) |
| 215 | + self.assertEqual( |
| 216 | + permuted_jag_tensor.offset_per_key(), |
| 217 | + [0, 3, 5, 8, 11, 14], |
| 218 | + ) |
| 219 | + self.assertEqual( |
| 220 | + permuted_jag_tensor.values().tolist(), |
| 221 | + [ |
| 222 | + 3.0, |
| 223 | + 4.0, |
| 224 | + 5.0, |
| 225 | + 1.0, |
| 226 | + 2.0, |
| 227 | + 6.0, |
| 228 | + 7.0, |
| 229 | + 8.0, |
| 230 | + 3.0, |
| 231 | + 4.0, |
| 232 | + 5.0, |
| 233 | + 3.0, |
| 234 | + 4.0, |
| 235 | + 5.0, |
| 236 | + ], |
| 237 | + ) |
| 238 | + self.assertEqual( |
| 239 | + permuted_jag_tensor.lengths().tolist(), |
| 240 | + [1, 1, 1, 0, 2, 0, 0, 3, 0, 1, 1, 1, 1, 1, 1], |
| 241 | + ) |
| 242 | + self.assertEqual(permuted_jag_tensor.weights_or_none(), None) |
0 commit comments