Commit f617898

feat: support cumsum dynamo converter (#2403)
1 parent b5efb6e commit f617898

File tree: 3 files changed, +139 −1 lines changed
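The commit wires `torch.ops.aten.cumsum` into the Torch-TensorRT dynamo converter registry, so graphs containing `torch.cumsum` can lower to TensorRT instead of falling back to PyTorch. A minimal usage sketch, assuming the standard `torch_tensorrt.compile` entry point with the dynamo frontend (the module and input shapes are illustrative, not from the commit):

import torch
import torch_tensorrt

class CumsumModule(torch.nn.Module):
    def forward(self, x):
        return torch.cumsum(x, dim=1)

model = CumsumModule().eval().cuda()
inputs = [torch.randn(4, 2, 3).cuda()]
# ir="dynamo" routes compilation through the dynamo converter registry,
# which is where this commit registers the aten.cumsum converter.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs))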

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+23)

@@ -691,6 +691,29 @@ def aten_ops_chunk(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.cumsum.default)  # type: ignore[misc]
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+    }
+)  # type: ignore[misc]
+def aten_ops_cumsum(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.slice.cumsum(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.permute.default)  # type: ignore[misc]
 @enforce_tensor_types(
     {
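The registration follows the existing pattern in this file: `dynamo_tensorrt_converter` keys `aten_ops_cumsum` to `torch.ops.aten.cumsum.default`, `enforce_tensor_types` requires argument 0 (the input) to already be a `TRTTensor`, and the converter body simply forwards `args[0]` (the input tensor) and `args[1]` (the dimension) to the `impl.slice.cumsum` implementation added below.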

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py (+47 −1)

@@ -1,10 +1,16 @@
 import math
 from typing import Optional
 
+import numpy as np
+import tensorrt as trt
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    get_positive_dim,
+    get_trt_tensor,
+)
 from torch_tensorrt.dynamo.conversion.impl.slice.base import slice
 from torch_tensorrt.fx.converters.converter_utils import (
     has_dynamic_shape,

@@ -157,3 +163,43 @@ def chunk(
         cnt += 1
 
     return result
+
+
+def cumsum(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: int,
+) -> TRTTensor:
+    input_shape = input.shape
+    dim = get_positive_dim(dim, len(input_shape))
+    loop = ctx.net.add_loop()
+    axis = np.array(input_shape[dim])
+    trip_limit = get_trt_tensor(ctx, axis, f"{name}_trip_limit")
+    loop.add_trip_limit(trip_limit, trt.TripLimit.COUNT)
+    iterator = loop.add_iterator(input, dim, reverse=False)
+    data = iterator.get_output(0)
+    new_dims = tuple(data.shape)
+    zeros = np.zeros(new_dims)
+    zero_trttensor = get_trt_tensor(ctx, zeros, f"{name}_initial_value")
+
+    running_sum = loop.add_recurrence(zero_trttensor)
+    set_layer_name(running_sum, target, f"{name}_running_sum", source_ir)
+    running_sum_tensor = running_sum.get_output(0)
+
+    current_sum = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_elementwise_add",
+        data,
+        running_sum_tensor,
+    )
+    running_sum.set_input(1, current_sum)
+
+    loop_output = loop.add_loop_output(current_sum, trt.LoopOutput.CONCATENATE, dim)
+    set_layer_name(loop_output, target, f"{name}_loop_output", source_ir)
+    loop_output.set_input(1, trip_limit)
+    return loop_output.get_output(0)
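TensorRT has no dedicated cumulative-sum layer, so the implementation builds one from a TensorRT loop: the trip count is the size of `dim`, an iterator yields one slice along `dim` per trip, a recurrence carries the running sum (initialized to zeros), and a CONCATENATE loop output stitches the partial sums back together along `dim`. A minimal eager-PyTorch sketch of the same semantics (illustrative only, not the converter code):

import torch

def cumsum_by_loop(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Mirrors the TensorRT loop: slice along `dim` (the iterator),
    # keep a running sum (the recurrence), and concatenate each
    # partial sum (the CONCATENATE loop output).
    dim = dim % x.ndim
    running = torch.zeros_like(x.select(dim, 0))  # recurrence initial value
    partials = []
    for i in range(x.shape[dim]):  # trip count = input_shape[dim]
        running = running + x.select(dim, i)  # elementwise add per trip
        partials.append(running)
    return torch.stack(partials, dim=dim)

x = torch.randn(2, 3)
assert torch.allclose(cumsum_by_loop(x, 1), torch.cumsum(x, dim=1))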
New test file for the cumsum converter (+69; file path not shown in this view)

@@ -0,0 +1,69 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestCumsumConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((1,), 0),
+            ((2,), 0),
+            ((3,), -1),
+        ]
+    )
+    def test_cumsum_1D(self, shape, dim):
+        class Cumsum(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.cumsum.default(x, dim)
+
+        inputs = [torch.randn(shape)]
+        self.run_test(
+            Cumsum(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 1), 0),
+            ((3, 1), 1),
+            ((2, 3), -1),
+            ((2, 3), -2),
+        ]
+    )
+    def test_cumsum_2D(self, shape, dims):
+        class Cumsum(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.cumsum.default(x, dims)
+
+        inputs = [torch.randn(shape)]
+        self.run_test(
+            Cumsum(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ((4, 2, 3), 0),
+            ((4, 2, 3), 1),
+            ((1, 2, 3), 2),
+            ((1, 2, 3), -1),
+            ((1, 2, 3), -2),
+        ]
+    )
+    def test_cumsum_3D(self, shape, dims):
+        class Cumsum(nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.cumsum.default(x, dims)
+
+        inputs = [torch.randn(shape)]
+        self.run_test(
+            Cumsum(),
+            inputs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
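The tests cover 1D, 2D, and 3D inputs with both positive and negative `dim` values; `DispatchTestCase.run_test` (per the shared test harness) compiles the module through the dynamo path and compares the TensorRT output against eager PyTorch, and the `run_tests()` entry point lets the file be executed directly with `python`.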
