Skip to content

Commit f21ced4

Browse files
Shirong Wufacebook-github-bot
Shirong Wu
authored andcommitted
Add split_with_sizes converter (pytorch#71953)
Summary: Pull Request resolved: pytorch#71953 Add converter for split_with_sizes Reviewed By: yinghai Differential Revision: D33829024 fbshipit-source-id: 50de383797a347ef7afecfbda80b2c84e244e404
1 parent 931ae4a commit f21ced4

File tree

3 files changed

+27
-1
lines changed

3 files changed

+27
-1
lines changed

test/fx2trt/converters/acc_op/test_split.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,23 @@ def forward(self, x):
3232
test_explicit_batch_dim=False,
3333
)
3434

35+
@parameterized.expand(
36+
[
37+
("split_with_size", [2, 3, 5], 1),
38+
]
39+
)
40+
def test_split_with_size(self, _, split_size, dim):
41+
class Split(nn.Module):
42+
def forward(self, x):
43+
return x.split_with_sizes(split_size, dim)
44+
45+
inputs = [torch.randn(1, 10)]
46+
self.run_test(
47+
Split(),
48+
inputs,
49+
expected_ops={acc_ops.slice_tensor},
50+
test_explicit_batch_dim=False,
51+
)
52+
3553
if __name__ == '__main__':
3654
run_tests()

torch/fx/experimental/fx2trt/converters/acc_ops_converters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1661,7 +1661,7 @@ def slice_to_trt_params(py_slice, dim_size):
16611661
size = math.ceil((stop - start) * 1.0 / stride)
16621662
return start, size, stride
16631663

1664-
if not isinstance(slices, tuple):
1664+
if not isinstance(slices, tuple) and not isinstance(slices, list):
16651665
slices = (slices,)
16661666

16671667
if network.has_implicit_batch_dimension:

torch/fx/experimental/fx_acc/acc_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1277,6 +1277,14 @@ def linalg_norm(*, input, ord, dim, keepdim):
12771277
("dim", "dim"),
12781278
],
12791279
)
1280+
@register_custom_acc_mapper_fn(
1281+
op_and_target=("call_method", "split_with_sizes"),
1282+
arg_replacement_tuples=[
1283+
("tensor", "input"),
1284+
("split_sizes", "split_size_or_sections"),
1285+
("dim", "dim"),
1286+
],
1287+
)
12801288
@register_custom_acc_mapper_fn(
12811289
op_and_target=("call_function", torch.split),
12821290
arg_replacement_tuples=[

0 commit comments

Comments
 (0)