Skip to content

Commit 429e813

Browse files
committed
Fix issue onnx#2102: Set reduction axis of mean to height and width for adjust_contrast op
Signed-off-by: cosine <[email protected]>
1 parent ec01956 commit 429e813

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

tests/test_backend.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3280,13 +3280,13 @@ def func(x, x_new_size_):
32803280
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: x_new_size})
32813281

32823282
def test_adjust_contrast(self):
3283-
x_shape = [4, 3, 2]
3284-
x_val = np.arange(1, 1 + np.prod(x_shape), dtype=np.float32).reshape(x_shape)
3285-
y_val = np.array(2.1, np.float32)
32863283
def func(x, y):
32873284
x_ = tf.image.adjust_contrast(x, y)
32883285
return tf.identity(x_, name=_TFOUTPUT)
3289-
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
3286+
for x_shape in [[4, 3, 2], [2, 3, 4, 5], [3, 4, 2, 4, 3]]:
3287+
x_val = np.arange(1, 1 + np.prod(x_shape), dtype=np.float32).reshape(x_shape)
3288+
y_val = np.array(2.1, np.float32)
3289+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
32903290

32913291
@check_opset_min_version(11, "GatherElements")
32923292
def test_adjust_saturation(self):

tf2onnx/onnx_opset/nn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1505,8 +1505,8 @@ def version_1(cls, ctx, node, **kwargs):
15051505
contrast_factor = ctx.make_node("Cast", [dtype], attr={'to': dtype}).output[0]
15061506
rank = ctx.get_rank(images)
15071507
utils.make_sure(rank is not None, "AdjustContrastv2 requires input of known rank")
1508-
# Reduce everything except channels
1509-
axes_to_reduce = list(range(rank))[:-1]
1508+
# Reduce height and width only
1509+
axes_to_reduce = list(range(rank))[-3:-1]
15101510
mean = ctx.make_node("ReduceMean", [images], attr={'axes': axes_to_reduce, 'keepdims': True},
15111511
op_name_scope=node.name).output[0]
15121512
diff = ctx.make_node("Sub", [images, mean], op_name_scope=node.name).output[0]

0 commit comments

Comments
 (0)