Skip to content

Commit ec01956

Browse files
cosineFishfatcat-z
andauthored
Fix ResizeBilinear and ResizeNearestNeighbor and update related tests (onnx#2130)
1) Fix tf ResizeBilinear and ResizeNearestNeighbor op with align_corners=True or half_pixel_centers=True 2) Update related tests to make input data include x.5 and add tests for align_corners=True Signed-off-by: cosine <[email protected]> Co-authored-by: Jay Zhang <[email protected]>
1 parent 5f918ab commit ec01956

File tree

2 files changed

+42
-11
lines changed

2 files changed

+42
-11
lines changed

tests/test_backend.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3192,7 +3192,7 @@ def func(x):
31923192
def test_resize_nearest_neighbor(self):
31933193
x_shape = [1, 15, 20, 2]
31943194
x_new_size = [30, 40]
3195-
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
3195+
x_val = np.arange(1, 1 + np.prod(x_shape)/10, 0.1).astype("float32").reshape(x_shape)
31963196
def func(x):
31973197
x_new_size_ = tf.constant(x_new_size)
31983198
x_ = resize_nearest_neighbor(x, x_new_size_)
@@ -3202,7 +3202,7 @@ def func(x):
32023202
@check_opset_min_version(9, "resize_nearest_neighbor")
32033203
def test_resize_nearest_neighbor_with_non_const(self):
32043204
x_shape = [3, 10, 8, 5]
3205-
x_val = np.arange(1, 1 + np.prod(x_shape), dtype=np.float32).reshape(x_shape)
3205+
x_val = np.arange(1, 1 + np.prod(x_shape)/10, 0.1, dtype=np.float32).reshape(x_shape)
32063206
x_new_size = np.array([20, 16]).astype(np.int32)
32073207
def func(x, x_new_size_):
32083208
x_ = resize_nearest_neighbor(x, x_new_size_)
@@ -3221,13 +3221,26 @@ def func(x):
32213221
return tf.identity(x_, name=_TFOUTPUT)
32223222
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
32233223

3224+
@skip_caffe2_backend()
3225+
@check_tf_min_version("1.14")
3226+
@check_opset_min_version(11, "coordinate_transformation_mode attr of resize_bilinear")
3227+
def test_resize_bilinear_align_coreners(self):
3228+
x_shape = [1, 15, 20, 2]
3229+
x_new_size = [30, 40]
3230+
x_val = np.arange(1, 1 + np.prod(x_shape)/10, 0.1).astype("float32").reshape(x_shape)
3231+
def func(x):
3232+
x_new_size_ = tf.constant(x_new_size)
3233+
x_ = resize_bilinear(x, x_new_size_, align_corners=True)
3234+
return tf.identity(x_, name=_TFOUTPUT)
3235+
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3236+
32243237
@skip_caffe2_backend()
32253238
@check_tf_min_version("1.14")
32263239
@check_opset_min_version(11, "coordinate_transformation_mode attr")
32273240
def test_resize_bilinear_half_pixel_centers(self):
32283241
x_shape = [1, 15, 20, 2]
32293242
x_new_size = [30, 40]
3230-
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
3243+
x_val = np.arange(1, 1 + np.prod(x_shape)/10, 0.1).astype("float32").reshape(x_shape)
32313244
def func(x):
32323245
x_new_size_ = tf.constant(x_new_size)
32333246
x_ = resize_bilinear(x, x_new_size_, half_pixel_centers=True)
@@ -3237,7 +3250,7 @@ def func(x):
32373250
@check_opset_min_version(9, "resize_bilinear")
32383251
def test_resize_bilinear_with_non_const(self):
32393252
x_shape = [3, 10, 8, 5]
3240-
x_val = np.arange(1, 1 + np.prod(x_shape), dtype=np.float32).reshape(x_shape)
3253+
x_val = np.arange(1, 1 + np.prod(x_shape)/10, 0.1, dtype=np.float32).reshape(x_shape)
32413254
x_new_size = np.array([20, 16]).astype(np.int32)
32423255
def func(x, x_new_size_):
32433256
x_ = resize_bilinear(x, x_new_size_)
@@ -3248,7 +3261,7 @@ def func(x, x_new_size_):
32483261
def test_resize_bilinear_with_non_const2(self):
32493262
# scales has an element larger than 1 and also has an element less that 1
32503263
x_shape = [3, 100, 8, 5]
3251-
x_val = np.arange(1, 1 + np.prod(x_shape), dtype=np.float32).reshape(x_shape)
3264+
x_val = np.arange(1, 1 + np.prod(x_shape)/10, 0.1, dtype=np.float32).reshape(x_shape)
32523265
x_new_size = np.array([20, 16]).astype(np.int32)
32533266
def func(x, x_new_size_):
32543267
x_ = resize_bilinear(x, x_new_size_)
@@ -3259,7 +3272,7 @@ def func(x, x_new_size_):
32593272
@check_opset_min_version(11, "resize_bilinear_v2")
32603273
def test_resize_bilinear_v2_with_non_const(self):
32613274
x_shape = [3, 10, 8, 5]
3262-
x_val = np.arange(1, 1 + np.prod(x_shape), dtype=np.float32).reshape(x_shape)
3275+
x_val = np.arange(1, 1 + np.prod(x_shape)/10, 0.1, dtype=np.float32).reshape(x_shape)
32633276
x_new_size = np.array([20, 16]).astype(np.int32)
32643277
def func(x, x_new_size_):
32653278
x_ = resize_bilinear_v2(x, x_new_size_)
@@ -3304,7 +3317,7 @@ def func(x, y):
33043317
def test_resize_bicubic(self):
33053318
x_shape = [1, 15, 20, 2]
33063319
new_size_val = np.array([30, 40], dtype=np.int32)
3307-
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
3320+
x_val = np.arange(1, 1 + np.prod(x_shape)/10, 0.1).astype("float32").reshape(x_shape)
33083321
def func(x, new_size):
33093322
y = tf.image.resize(x, new_size, method=tf.image.ResizeMethod.BICUBIC)
33103323
return tf.identity(y, name=_TFOUTPUT)
@@ -3314,7 +3327,7 @@ def func(x, new_size):
33143327
def test_resize_nearest_neighbor2(self):
33153328
x_shape = [1, 300, 20, 2]
33163329
x_new_size = [30, 40]
3317-
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
3330+
x_val = np.arange(1, 1 + np.prod(x_shape)/10, 0.1).astype("float32").reshape(x_shape)
33183331
def func(x):
33193332
x_new_size_ = tf.constant(x_new_size)
33203333
x_ = resize_nearest_neighbor(x, x_new_size_)
@@ -3326,13 +3339,25 @@ def func(x):
33263339
def test_resize_nearest_neighbor_half_pixel_centers(self):
33273340
x_shape = [1, 10, 20, 2]
33283341
x_new_size = [20, 40]
3329-
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
3342+
x_val = np.arange(1, 1 + np.prod(x_shape)/10, 0.1).astype("float32").reshape(x_shape)
33303343
def func(x):
33313344
x_new_size_ = tf.constant(x_new_size)
33323345
x_ = resize_nearest_neighbor(x, x_new_size_, half_pixel_centers=True)
33333346
return tf.identity(x_, name=_TFOUTPUT)
33343347
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
33353348

3349+
@check_tf_min_version("1.14")
3350+
@check_opset_min_version(11, "coordinate_transformation_mode and nearest_mode attr")
3351+
def test_resize_nearest_neighbor_align_corners(self):
3352+
x_shape = [1, 10, 20, 2]
3353+
x_new_size = [20, 40]
3354+
x_val = np.arange(1, 1 + np.prod(x_shape)/10, 0.1).astype("float32").reshape(x_shape)
3355+
def func(x):
3356+
x_new_size_ = tf.constant(x_new_size)
3357+
x_ = resize_nearest_neighbor(x, x_new_size_, align_corners=True)
3358+
return tf.identity(x_, name=_TFOUTPUT)
3359+
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3360+
33363361
@check_opset_min_version(9, "fill")
33373362
def test_fill_float32(self):
33383363
x_shape = [1, 15, 20, 2]

tf2onnx/onnx_opset/nn.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,11 +1409,11 @@ def version_11(cls, ctx, node, **kwargs):
14091409
nearest_mode = "floor"
14101410
if "align_corners" in node.attr and node.attr["align_corners"].i:
14111411
transformation_mode = "align_corners"
1412+
nearest_mode = "round_prefer_ceil"
14121413
if "half_pixel_centers" in node.attr and node.attr["half_pixel_centers"].i:
14131414
if node.type == "ResizeNearestNeighbor" and not ctx.is_target(constants.TARGET_TENSORRT):
14141415
# TensorRT only supports nearest_mode = "floor" for mode = "nearest"
1415-
transformation_mode = "half_pixel"
1416-
nearest_mode = "round_prefer_ceil"
1416+
transformation_mode = "tf_half_pixel_for_nn"
14171417
else:
14181418
transformation_mode = "half_pixel"
14191419
attr = {"mode": mode, "nearest_mode": nearest_mode, "coordinate_transformation_mode": transformation_mode,
@@ -1435,6 +1435,12 @@ def _convert_since_9(cls, ctx, node, op_type, use_target_size=False):
14351435
# wants the input to be NHWC - adjust target_shape to this.
14361436
utils.make_sure(node.type != "ResizeBicubic", "Opset 11 is required for bicubic interpolation for node %s",
14371437
node.name)
1438+
if "align_corners" in node.attr:
1439+
utils.make_sure(not node.attr["align_corners"].i,
1440+
"Opset 11 is required for align_corners=True for node %s", node.name)
1441+
if "half_pixel_centers" in node.attr:
1442+
utils.make_sure(not node.attr["half_pixel_centers"].i,
1443+
"Opset 11 is required for half_pixel_centers=True for node %s", node.name)
14381444
mode = "linear" if node.type == "ResizeBilinear" else "nearest"
14391445

14401446
# because onnxruntime only supports to scale the last two dims so transpose is inserted

0 commit comments

Comments
 (0)