@@ -76,18 +76,18 @@ def test_from_4D_image_with_invalid_data(self):
76
76
img_utils .from_4D_image (tf .ones (shape = (2 , 2 , 4 , 1 )), tf .constant (2 ))
77
77
)
78
78
79
- def test_from_4D_image_with_invalid_shape (self ):
80
- errors = (ValueError , tf .errors .InvalidArgumentError )
81
- for rank in 2 , tf .constant (2 ):
82
- with self .subTest (rank = rank ):
83
- with self .assertRaisesRegexp (errors , "`image` must be 4D tensor" ):
84
- img_utils .from_4D_image (tf .ones (shape = (2 , 4 )), rank )
85
79
86
- with self .assertRaisesRegexp (errors , "`image` must be 4D tensor" ):
87
- img_utils .from_4D_image (tf .ones (shape = (2 , 4 , 1 )), rank )
80
+ @pytest .mark .parametrize ("rank" , [2 , tf .constant (2 )])
81
+ def test_from_4d_image_with_invalid_shape (rank ):
82
+ errors = (ValueError , tf .errors .InvalidArgumentError )
83
+ with pytest .raises (errors , match = "`image` must be 4D tensor" ):
84
+ img_utils .from_4D_image (tf .ones (shape = (2 , 4 )), rank )
85
+
86
+ with pytest .raises (errors , match = "`image` must be 4D tensor" ):
87
+ img_utils .from_4D_image (tf .ones (shape = (2 , 4 , 1 )), rank )
88
88
89
- with self . assertRaisesRegexp (errors , "`image` must be 4D tensor" ):
90
- img_utils .from_4D_image (tf .ones (shape = (1 , 2 , 4 , 1 , 1 )), rank )
89
+ with pytest . raises (errors , match = "`image` must be 4D tensor" ):
90
+ img_utils .from_4D_image (tf .ones (shape = (1 , 2 , 4 , 1 , 1 )), rank )
91
91
92
92
93
93
if __name__ == "__main__" :
0 commit comments