Skip to content

Commit fa6db66

Browse files
cosineFishfatcat-z
andauthored
Fix issue #2102: Set reduction axis of mean to height and width for adjust_contrast op (#2140)
Signed-off-by: cosine <[email protected]> Co-authored-by: Jay Zhang <[email protected]>
1 parent 5708e10 commit fa6db66

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

tests/test_backend.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3280,13 +3280,13 @@ def func(x, x_new_size_):
32803280
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: x_new_size})
32813281

32823282
def test_adjust_contrast(self):
3283-
x_shape = [4, 3, 2]
3284-
x_val = np.arange(1, 1 + np.prod(x_shape), dtype=np.float32).reshape(x_shape)
3285-
y_val = np.array(2.1, np.float32)
32863283
def func(x, y):
32873284
x_ = tf.image.adjust_contrast(x, y)
32883285
return tf.identity(x_, name=_TFOUTPUT)
3289-
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
3286+
for x_shape in [[4, 3, 2], [2, 3, 4, 5], [3, 4, 2, 4, 3]]:
3287+
x_val = np.arange(1, 1 + np.prod(x_shape), dtype=np.float32).reshape(x_shape)
3288+
y_val = np.array(2.1, np.float32)
3289+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
32903290

32913291
@check_opset_min_version(11, "GatherElements")
32923292
def test_adjust_saturation(self):

tf2onnx/onnx_opset/nn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1522,8 +1522,8 @@ def version_1(cls, ctx, node, **kwargs):
15221522
contrast_factor = ctx.make_node("Cast", [dtype], attr={'to': dtype}).output[0]
15231523
rank = ctx.get_rank(images)
15241524
utils.make_sure(rank is not None, "AdjustContrastv2 requires input of known rank")
1525-
# Reduce everything except channels
1526-
axes_to_reduce = list(range(rank))[:-1]
1525+
# Reduce height and width only
1526+
axes_to_reduce = list(range(rank))[-3:-1]
15271527
mean = ctx.make_node("ReduceMean", [images], attr={'axes': axes_to_reduce, 'keepdims': True},
15281528
op_name_scope=node.name).output[0]
15291529
diff = ctx.make_node("Sub", [images, mean], op_name_scope=node.name).output[0]

0 commit comments

Comments
 (0)