|
23 | 23 | os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
|
24 | 24 |
|
25 | 25 |
|
| 26 | +def cycle_over(objs): |
| 27 | + objs = list(objs) |
| 28 | + for idx, obj in enumerate(objs): |
| 29 | + yield obj, objs[:idx] + objs[idx + 1:] |
| 30 | + |
| 31 | +def int_dtypes(): |
| 32 | + yield from iter( |
| 33 | + (torch.uint8, torch.int8, torch.int16, torch.short, torch.int32, torch.int, torch.int64, torch.long,) |
| 34 | + ) |
| 35 | + |
| 36 | +def float_dtypes(): |
| 37 | + yield from iter((torch.float32, torch.float, torch.float64, torch.double)) |
| 38 | + |
| 39 | + |
26 | 40 | class Tester(unittest.TestCase):
|
27 | 41 |
|
28 | 42 | def test_crop(self):
|
@@ -510,54 +524,99 @@ def test_to_tensor(self):
|
510 | 524 | output = trans(img)
|
511 | 525 | self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))
|
512 | 526 |
|
513 |
| - def test_convert_image_dtype(self): |
514 |
| - def cycle_over(objs): |
515 |
| - objs = list(objs) |
516 |
| - for idx, obj in enumerate(objs): |
517 |
| - yield obj, objs[:idx] + objs[idx + 1:] |
518 |
| - |
519 |
| - # dtype_max_value = { |
520 |
| - # dtype: 1.0 |
521 |
| - # for dtype in (torch.float32, torch.float, torch.float64, torch.double)#, torch.bool,) |
522 |
| - # # torch.float16 and torch.half are disabled for now since they do not support torch.max |
523 |
| - # # See https://github.com/pytorch/pytorch/issues/28623#issuecomment-611379051 |
524 |
| - # # (torch.float32, torch.float, torch.float64, torch.double, torch.float16, torch.half, torch.bool, ) |
525 |
| - # } |
526 |
| - dtype_max_value = {} |
527 |
| - dtype_max_value.update( |
528 |
| - { |
529 |
| - dtype: torch.iinfo(dtype).max |
530 |
| - for dtype in ( |
531 |
| - torch.uint8, |
532 |
| - torch.int8, |
533 |
| - torch.int16, |
534 |
| - torch.short, |
535 |
| - torch.int32, |
536 |
| - torch.int, |
537 |
| - torch.int64, |
538 |
| - torch.long, |
539 |
| - ) |
540 |
| - } |
541 |
| - ) |
| 527 | + def test_convert_image_dtype_float_to_float(self): |
| 528 | + for input_dtype, output_dtypes in cycle_over(float_dtypes()): |
| 529 | + input_image = torch.tensor((0.0, 1.0), dtype=input_dtype) |
| 530 | + for output_dtype in output_dtypes: |
| 531 | + with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): |
| 532 | + transform = transforms.ConvertImageDtype(output_dtype) |
| 533 | + output_image = transform(input_image) |
| 534 | + |
| 535 | + actual_min, actual_max = output_image.tolist() |
| 536 | + desired_min, desired_max = 0.0, 1.0 |
| 537 | + |
| 538 | + self.assertAlmostEqual(actual_min, desired_min) |
| 539 | + self.assertAlmostEqual(actual_max, desired_max) |
| 540 | + |
| 541 | + def test_convert_image_dtype_float_to_int(self): |
| 542 | + for input_dtype in float_dtypes(): |
| 543 | + input_image = torch.tensor((0.0, 1.0), dtype=input_dtype) |
| 544 | + for output_dtype in int_dtypes(): |
| 545 | + with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): |
| 546 | + transform = transforms.ConvertImageDtype(output_dtype) |
| 547 | + |
| 548 | + if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or ( |
| 549 | + input_dtype == torch.float64 and output_dtype == torch.int64 |
| 550 | + ): |
| 551 | + with self.assertRaises(RuntimeError): |
| 552 | + transform(input_image) |
| 553 | + else: |
| 554 | + output_image = transform(input_image) |
542 | 555 |
|
543 |
| - for input_dtype, output_dtypes in cycle_over(dtype_max_value.keys()): |
544 |
| - input_image = torch.ones(1, dtype=input_dtype) * dtype_max_value[input_dtype] |
| 556 | + actual_min, actual_max = output_image.tolist() |
| 557 | + desired_min, desired_max = 0, torch.iinfo(output_dtype).max |
545 | 558 |
|
| 559 | + self.assertEqual(actual_min, desired_min) |
| 560 | + self.assertEqual(actual_max, desired_max) |
| 561 | + |
| 562 | + def test_convert_image_dtype_int_to_float(self): |
| 563 | + for input_dtype in int_dtypes(): |
| 564 | + input_image = torch.tensor((0, torch.iinfo(input_dtype).max), dtype=input_dtype) |
| 565 | + for output_dtype in float_dtypes(): |
| 566 | + with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): |
| 567 | + transform = transforms.ConvertImageDtype(output_dtype) |
| 568 | + output_image = transform(input_image) |
| 569 | + |
| 570 | + actual_min, actual_max = output_image.tolist() |
| 571 | + desired_min, desired_max = 0.0, 1.0 |
| 572 | + |
| 573 | + self.assertAlmostEqual(actual_min, desired_min) |
| 574 | + self.assertGreaterEqual(actual_min, desired_min) |
| 575 | + self.assertAlmostEqual(actual_max, desired_max) |
| 576 | + self.assertLessEqual(actual_max, desired_max) |
| 577 | + |
| 578 | + def test_convert_image_dtype_int_to_int(self): |
| 579 | + for input_dtype, output_dtypes in cycle_over(int_dtypes()): |
| 580 | + input_max = torch.iinfo(input_dtype).max |
| 581 | + input_image = torch.tensor((0, input_max), dtype=input_dtype) |
546 | 582 | for output_dtype in output_dtypes:
|
| 583 | + output_max = torch.iinfo(output_dtype).max |
| 584 | + |
547 | 585 | with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
|
548 | 586 | transform = transforms.ConvertImageDtype(output_dtype)
|
549 | 587 | output_image = transform(input_image)
|
550 | 588 |
|
551 |
| - actual = output_image.dtype |
552 |
| - desired = output_dtype |
553 |
| - self.assertEqual(actual, desired) |
| 589 | + actual_min, actual_max = output_image.tolist() |
| 590 | + desired_min, desired_max = 0, output_max |
554 | 591 |
|
555 |
| - actual = torch.max(output_image).item() |
556 |
| - desired = dtype_max_value[output_dtype] |
557 |
| - if output_dtype.is_floating_point: |
558 |
| - self.assertAlmostEqual(actual, desired) |
| 592 | + # see https://github.com/pytorch/vision/pull/2078#issuecomment-641036236 for details |
| 593 | + if input_max >= output_max: |
| 594 | + error_term = 0 |
559 | 595 | else:
|
560 |
| - self.assertEqual(actual, desired) |
| 596 | + error_term = 1 - (torch.iinfo(output_dtype).max + 1) // (torch.iinfo(input_dtype).max + 1) |
| 597 | + |
| 598 | + self.assertEqual(actual_min, desired_min) |
| 599 | + self.assertEqual(actual_max, desired_max + error_term) |
| 600 | + |
| 601 | + def test_convert_image_dtype_int_to_int_consistency(self): |
| 602 | + for input_dtype, output_dtypes in cycle_over(int_dtypes()): |
| 603 | + input_max = torch.iinfo(input_dtype).max |
| 604 | + input_image = torch.tensor((0, input_max), dtype=input_dtype) |
| 605 | + for output_dtype in output_dtypes: |
| 606 | + output_max = torch.iinfo(output_dtype).max |
| 607 | + if output_max <= input_max: |
| 608 | + continue |
| 609 | + |
| 610 | + with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype): |
| 611 | + transform = transforms.ConvertImageDtype(output_dtype) |
| 612 | + inverse_transfrom = transforms.ConvertImageDtype(input_dtype) |
| 613 | + output_image = inverse_transfrom(transform(input_image)) |
| 614 | + |
| 615 | + actual_min, actual_max = output_image.tolist() |
| 616 | + desired_min, desired_max = 0, input_max |
| 617 | + |
| 618 | + self.assertEqual(actual_min, desired_min) |
| 619 | + self.assertEqual(actual_max, desired_max) |
561 | 620 |
|
562 | 621 | @unittest.skipIf(accimage is None, 'accimage not available')
|
563 | 622 | def test_accimage_to_tensor(self):
|
|
0 commit comments