@@ -191,54 +191,56 @@ def test_convert_image_color_space(self, transform, input):
191
191
transform (input )
192
192
193
193
194
+ @pytest .mark .parametrize ("p" , [0.0 , 1.0 ])
194
195
class TestRandomHorizontalFlip :
195
- def input_tensor (self , dtype = torch .float32 ):
196
- return torch .tensor ([[[0 , 1 ], [0 , 1 ]], [[1 , 0 ], [1 , 0 ]]], dtype = dtype )
196
+ def input_expected_image_tensor (self , p , dtype = torch .float32 ):
197
+ input = torch .tensor ([[[0 , 1 ], [0 , 1 ]], [[1 , 0 ], [1 , 0 ]]], dtype = dtype )
198
+ expected = torch .tensor ([[[1 , 0 ], [1 , 0 ]], [[0 , 1 ], [0 , 1 ]]], dtype = dtype )
197
199
198
- def expected_tensor (self , dtype = torch .float32 ):
199
- return torch .tensor ([[[1 , 0 ], [1 , 0 ]], [[0 , 1 ], [0 , 1 ]]], dtype = dtype )
200
+ if p == 1.0 :
201
+ return input , expected
202
+ return input , input
200
203
201
- @pytest .mark .parametrize ("p" , [0.0 , 1.0 ], ids = ["p=0" , "p=1" ])
202
204
def test_simple_tensor (self , p ):
203
- input = self .input_tensor ()
205
+ input , expected = self .input_expected_image_tensor (p )
206
+ transform = transforms .RandomHorizontalFlip (p = p )
204
207
205
- actual = transforms . RandomHorizontalFlip ( p = p ) (input )
208
+ actual = transform (input )
206
209
207
- expected = self .expected_tensor () if p == 1.0 else input
208
210
assert_equal (expected , actual )
209
211
210
- @pytest .mark .parametrize ("p" , [0.0 , 1.0 ], ids = ["p=0" , "p=1" ])
211
212
def test_pil_image (self , p ):
212
- input = self .input_tensor (dtype = torch .uint8 )
213
+ input , expected = self .input_expected_image_tensor (p , dtype = torch .uint8 )
214
+ transform = transforms .RandomHorizontalFlip (p = p )
213
215
214
- actual = transforms . RandomHorizontalFlip ( p = p ) (to_pil_image (input ))
216
+ actual = transform (to_pil_image (input ))
215
217
216
- expected = self .expected_tensor (dtype = torch .uint8 ) if p == 1.0 else input
217
218
assert_equal (expected , pil_to_tensor (actual ))
218
219
219
- @pytest .mark .parametrize ("p" , [0.0 , 1.0 ], ids = ["p=0" , "p=1" ])
220
220
def test_features_image (self , p ):
221
- input = self .input_tensor ()
221
+ input , expected = self .input_expected_image_tensor (p )
222
+ transform = transforms .RandomHorizontalFlip (p = p )
222
223
223
- actual = transforms . RandomHorizontalFlip ( p = p ) (features .Image (input ))
224
+ actual = transform (features .Image (input ))
224
225
225
- expected = self .expected_tensor () if p == 1.0 else input
226
226
assert_equal (features .Image (expected ), actual )
227
227
228
- @pytest .mark .parametrize ("p" , [0.0 , 1.0 ], ids = ["p=0" , "p=1" ])
229
228
def test_features_segmentation_mask (self , p ):
230
- input = features .SegmentationMask (self .input_tensor ())
229
+ input , expected = self .input_expected_image_tensor (p )
230
+ transform = transforms .RandomHorizontalFlip (p = p )
231
231
232
- actual = transforms . RandomHorizontalFlip ( p = p )( input )
232
+ actual = transform ( features . SegmentationMask ( input ) )
233
233
234
- expected = self .expected_tensor () if p == 1.0 else input
235
234
assert_equal (features .SegmentationMask (expected ), actual )
236
235
237
- @pytest .mark .parametrize ("p" , [0.0 , 1.0 ], ids = ["p=0" , "p=1" ])
238
236
def test_features_bounding_box (self , p ):
239
237
input = features .BoundingBox ([0 , 0 , 5 , 5 ], format = features .BoundingBoxFormat .XYXY , image_size = (10 , 10 ))
238
+ transform = transforms .RandomHorizontalFlip (p = p )
240
239
241
- actual = transforms . RandomHorizontalFlip ( p = p ) (input )
240
+ actual = transform (input )
242
241
243
- expected = torch .tensor ([5 , 0 , 10 , 5 ]) if p == 1.0 else input
244
- assert_equal (features .BoundingBox .new_like (input , expected ), actual )
242
+ expected_image_tensor = torch .tensor ([5 , 0 , 10 , 5 ]) if p == 1.0 else input
243
+ expected = features .BoundingBox .new_like (input , data = expected_image_tensor )
244
+ assert_equal (expected , actual )
245
+ assert actual .format == expected .format
246
+ assert actual .image_size == expected .image_size
0 commit comments