
Commit 71f61c6

gs-olive authored and bowang007 committed
fix: Implement aten.mean.default and aten.mean.dim converters (#1810)
1 parent 47df4cd commit 71f61c6

File tree

2 files changed: +108 -11 lines


py/torch_tensorrt/fx/converters/aten_ops_converters.py

+24 -11
@@ -41,7 +41,6 @@ def aten_ops_add(
     return acc_ops_converters.acc_ops_add(network, target, None, kwargs_new, name)
 
 
-@tensorrt_converter(torch.ops.aten.mean.dim)
 @tensorrt_converter(torch.ops.aten._adaptive_avg_pool3d.default)
 @tensorrt_converter(torch.ops.aten._adaptive_avg_pool2d.default)
 def aten_ops_adaptive_avg_poolnd(
@@ -51,24 +50,38 @@ def aten_ops_adaptive_avg_poolnd(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    if target == torch.ops.aten.mean.dim:
-
-        if list(args[1]) != [-1, -2]:
-            raise RuntimeError(f"We do not support {target} has dim={args[1]}")
-        else:
-            output_size = [1, 1]
-    else:
-        output_size = args[1]
-
     kwargs_new = {
         "input": args[0],
-        "output_size": output_size,
+        "output_size": args[1],
     }
     return acc_ops_converters.acc_ops_adaptive_avg_poolnd(
         network, target, None, kwargs_new, name
     )
 
 
+@tensorrt_converter(torch.ops.aten.mean.default)
+@tensorrt_converter(torch.ops.aten.mean.dim)
+def aten_ops_mean(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> TRTTensor:
+    # Default invocation of aten.mean only uses first argument and
+    # averages over all elements (all dimensions)
+    # aten.mean.dim invocation allows specification of dimensions to average
+    # over, as well as the option to keep the dimension or not
+    kwargs_new = {
+        "input": args[0],
+        "dim": args[1] if len(args) >= 2 else list(range(len(args[0].shape))),
+        "keepdim": args[2] if len(args) >= 3 else False,
+    }
+    return add_reduce_layer(
+        network, target, args, kwargs_new, trt.ReduceOperation.AVG, name
+    )
+
+
 @tensorrt_converter(torch.ops.aten.batch_norm)
 def aten_ops_batch_norm(
     network: TRTNetwork,
New test file (+84 -0)
@@ -0,0 +1,84 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestMeanDimConverter(DispatchTestCase):
+    def test_mean_dim_keepdims(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return torch.mean(x, dim=[0, 1], keepdim=True)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.mean.dim})
+
+    def test_mean_dim_keepdims_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return torch.mean(x, dim=[0, 1, 2], keepdim=True)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.mean.dim}
+        )
+
+    def test_mean_dim_keepdims_false(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return torch.mean(x, dim=0, keepdim=False)
+
+        inputs = [torch.randn(3, 5, 7)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.mean.dim})
+
+    def test_mean_dim_keepdims_false_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return torch.mean(x, dim=-1, keepdim=False)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.mean.dim}
+        )
+
+
+class TestMeanConverter(DispatchTestCase):
+    def test_mean(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return torch.mean(x)
+
+        inputs = [torch.randn(3, 8, 5, 7, 1)]
+        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.mean.default})
+
+    def test_mean_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return torch.mean(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 5, 8), (3, 10, 10))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.mean.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
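The expected_ops assertions above hinge on which aten overload a given torch.mean call lowers to. Below is a minimal sketch for checking that mapping outside the test harness, assuming make_fx from torch.fx.experimental.proxy_tensor is available in the PyTorch version this FX path targets.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def full_mean(x):
    # no dim argument -> expected to lower to torch.ops.aten.mean.default
    return torch.mean(x)

def dim_mean(x):
    # explicit dim/keepdim -> expected to lower to torch.ops.aten.mean.dim
    return torch.mean(x, dim=[0, 1], keepdim=True)

gm_full = make_fx(full_mean)(torch.randn(1, 10))
gm_dim = make_fx(dim_mean)(torch.randn(1, 10))

# Inspect the call_function targets recorded in each traced graph.
print([n.target for n in gm_full.graph.nodes if n.op == "call_function"])
print([n.target for n in gm_dim.graph.nodes if n.op == "call_function"])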
