-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Refactor test hsv2rgb #3988
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor test hsv2rgb #3988
Conversation
Thanks @vivekkumar7089 ! I'll review in more details but it looks like something went wrong: I see no deleted tests, and it looks like you included |
Thanks, @NicolasHug , for reviewing. I missed deleting the tests while removing merge conflicts. I will change as soon as possible. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! Thanks a lot @vivekkumar7089 !
def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers): | ||
img_size = pil_img.size | ||
dt = tensor.dtype | ||
for r in [NEAREST, ]: | ||
for a in range(-180, 180, 17): | ||
for e in [True, False]: | ||
for c in centers: | ||
for f in [None, [0, 0, 0], (1, 2, 3), [255, 255, 255], [1, ], (2.0, )]: | ||
f_pil = int(f[0]) if f is not None and len(f) == 1 else f | ||
out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c, fill=f_pil) | ||
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) | ||
for fn in [F.rotate, scripted_rotate]: | ||
out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c, fill=f).cpu() | ||
|
||
if out_tensor.dtype != torch.uint8: | ||
out_tensor = out_tensor.to(torch.uint8) | ||
|
||
self.assertEqual( | ||
out_tensor.shape, | ||
out_pil_tensor.shape, | ||
msg="{}: {} vs {}".format( | ||
(img_size, r, dt, a, e, c), out_tensor.shape, out_pil_tensor.shape | ||
)) | ||
|
||
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0 | ||
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2] | ||
# Tolerance : less than 3% of different pixels | ||
self.assertLess( | ||
ratio_diff_pixels, | ||
0.03, | ||
msg="{}: {}\n{} vs \n{}".format( | ||
(img_size, r, dt, a, e, c, f), | ||
ratio_diff_pixels, | ||
out_tensor[0, :7, :7], | ||
out_pil_tensor[0, :7, :7] | ||
) | ||
) | ||
|
||
def test_rotate(self): | ||
# Tests on square image | ||
scripted_rotate = torch.jit.script(F.rotate) | ||
|
||
data = [_create_data(26, 26, device=self.device), _create_data(32, 26, device=self.device)] | ||
for tensor, pil_img in data: | ||
|
||
img_size = pil_img.size | ||
centers = [ | ||
None, | ||
(int(img_size[0] * 0.3), int(img_size[0] * 0.4)), | ||
[int(img_size[0] * 0.5), int(img_size[0] * 0.6)] | ||
] | ||
|
||
for dt in [None, torch.float32, torch.float64, torch.float16]: | ||
|
||
if dt == torch.float16 and torch.device(self.device).type == "cpu": | ||
# skip float16 on CPU case | ||
continue | ||
|
||
if dt is not None: | ||
tensor = tensor.to(dtype=dt) | ||
|
||
self._test_rotate_all_options(tensor, pil_img, scripted_rotate, centers) | ||
|
||
batch_tensors = _create_data_batch(26, 36, num_samples=4, device=self.device) | ||
if dt is not None: | ||
batch_tensors = batch_tensors.to(dtype=dt) | ||
|
||
center = (20, 22) | ||
_test_fn_on_batch( | ||
batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center | ||
) | ||
tensor, pil_img = data[0] | ||
# assert deprecation warning and non-BC | ||
with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"): | ||
res1 = F.rotate(tensor, 45, resample=2) | ||
res2 = F.rotate(tensor, 45, interpolation=BILINEAR) | ||
assert_equal(res1, res2) | ||
|
||
# scriptable function test | ||
s_tuple_transformed_batches = script_ten_crop(batch_tensors, [10, 11]) | ||
for transformed_batch, s_transformed_batch in zip(tuple_transformed_batches, s_tuple_transformed_batches): | ||
assert_equal(transformed_batch, s_transformed_batch) | ||
# assert changed type warning | ||
with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"): | ||
res1 = F.rotate(tensor, 45, interpolation=2) | ||
res2 = F.rotate(tensor, 45, interpolation=BILINEAR) | ||
assert_equal(res1, res2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this still left-over from a merge conflict?
Reviewed By: fmassa Differential Revision: D29097714 fbshipit-source-id: 8ed9f670893f43b7b68755dcbd21234ba26f70ba
Refactor group B1 as mentioned in #3956
Group B1
test_hsv2rgb
test_rgb2hsv
test_rgb_to_grayscale -- parametrize over num_output_channels
test_center_crop
test_five_crop
test_ten_crop