diff --git a/test/test_utils.py b/test/test_utils.py
index f1982130f75..f1a8d65db6d 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -9,6 +9,12 @@
 from PIL import Image
 
 
+def path_to_tensor(filepath):
+    """Load the image at ``filepath`` with PIL and return its pixel data as a tensor."""
+    from numpy import array as to_numpy_array
+    return torch.from_numpy(to_numpy_array(Image.open(filepath)))
+
+
 class Tester(unittest.TestCase):
 
     def test_make_grid_not_inplace(self):
@@ -41,6 +47,39 @@ def test_normalize_in_make_grid(self):
         self.assertTrue(torch.equal(norm_max, rounded_grid_max), 'Normalized max is not equal to 1')
         self.assertTrue(torch.equal(norm_min, rounded_grid_min), 'Normalized min is not equal to 0')
 
+    # def test_bboxes_not_inplace(self):
+    #     t = torch.rand(5, 3, 10, 10) * 255
+    #     t_clone = t.clone()
+    #
+    #     TODO: this doesn't work; we need to pass in bboxes
+    #     utils.draw_bounding_boxes(t, draw_labels=False)
+    #     self.assertTrue(torch.equal(t, t_clone), 'draw_bounding_boxes modified tensor in-place')
+    #
+    #     utils.draw_bounding_boxes(t, draw_labels=True)
+    #     self.assertTrue(torch.equal(t, t_clone), 'draw_bounding_boxes modified tensor in-place')
+
+    def test_bboxes(self):
+        IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
+        IMAGE_DIR = os.path.join(IMAGE_ROOT, "fakedata", "imagefolder")
+        inp_img_path = os.path.join(IMAGE_DIR, 'a4.png')
+        out_img_path = os.path.join(IMAGE_DIR, 'b5.png')
+
+        # draw_bounding_boxes expects a CHW float image in [0, 1], whereas
+        # path_to_tensor yields PIL's HWC uint8 layout (assumes a4.png is RGB).
+        inp_img_pil = path_to_tensor(inp_img_path).permute(2, 0, 1).float().div(255)
+        bboxes = ((1, 2, 10, 18), (4, 8, 9, 11))
+        # TODO: maybe write the rectangle programmatically in this test instead of
+        # statically loading output?
+        out_img_pil = path_to_tensor(out_img_path)
+
+        self.assertTrue(
+            torch.equal(
+                utils.draw_bounding_boxes(inp_img_pil, bboxes, draw_labels=False),
+                out_img_pil,
+            ),
+            'draw_bounding_boxes returned an incorrect result',
+        )
+
     @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
     def test_save_image(self):
         with tempfile.NamedTemporaryFile(suffix='.png') as f:
diff --git a/torchvision/utils.py b/torchvision/utils.py
index be373138c5f..54c5278aef0 100644
--- a/torchvision/utils.py
+++ b/torchvision/utils.py
@@ -1,4 +1,4 @@
-from typing import Union, Optional, List, Tuple, Text, BinaryIO
+from typing import Union, Optional, List, Tuple, Text, BinaryIO, Sequence, Dict
 import io
 import pathlib
 import torch
@@ -128,3 +128,91 @@ def save_image(
     ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
     im = Image.fromarray(ndarr)
     im.save(fp, format=format)
+
+
+BBox = Tuple[int, int, int, int]
+BBoxes = Sequence[BBox]
+Color = Tuple[int, int, int]
+# Placeholder palette, cycled through whenever a box or class has no explicit color.
+# TODO: default to one of @pmeir's suggestions
+DEFAULT_COLORS: Sequence[Color] = (
+    (255, 0, 0),
+    (0, 255, 0),
+    (0, 0, 255),
+    (255, 255, 0),
+    (255, 0, 255),
+    (0, 255, 255),
+)
+
+
+def draw_bounding_boxes(
+    image: torch.Tensor,
+    bboxes: Union[BBoxes, Dict[str, Sequence[BBox]]],
+    colors: Optional[Dict[str, Color]] = None,
+    draw_labels: Optional[bool] = None,
+    width: int = 1,
+) -> torch.Tensor:
+    """Draw bounding boxes on a copy of an image.
+
+    Args:
+        image (Tensor): Floating-point image of shape (C, H, W) with values in [0, 1].
+            It is not modified.
+        bboxes (sequence or dict): Either a sequence of ``(x0, y0, x1, y1)`` boxes, or a
+            dict mapping a class name to the sequence of boxes belonging to that class.
+        colors (dict, optional): RGB outline color per class name. Only allowed when
+            ``bboxes`` is a dict; classes without an entry use ``DEFAULT_COLORS``.
+        draw_labels (bool, optional): Whether to write the class name at the top-left
+            corner of each box. Defaults to ``True`` for dict ``bboxes`` and ``False``
+            for a plain sequence, which has no class names to draw.
+        width (int): Outline thickness in pixels.
+
+    Returns:
+        Tensor: The annotated image as a uint8 tensor in PIL's (H, W, C) layout.
+
+    Raises:
+        TypeError: If ``image`` is not a floating-point tensor.
+        ValueError: If ``image`` is not 3-dimensional, or if ``colors`` or
+            ``draw_labels=True`` is combined with a plain sequence of boxes.
+    """
+    if not image.is_floating_point():
+        raise TypeError('image must be a floating-point tensor with values in [0, 1]')
+    if image.dim() != 3:
+        raise ValueError('image must be a 3-dimensional (C, H, W) tensor')
+
+    # typing generics such as Sequence[BBox] cannot be used for runtime instance checks
+    bboxes_is_dict = isinstance(bboxes, dict)
+    if not bboxes_is_dict:
+        if colors is not None:
+            raise ValueError('colors can only be given when bboxes is a dict keyed by class name')
+        if draw_labels:
+            raise ValueError('labels cannot be drawn when bboxes is a plain sequence')
+    if draw_labels is None:
+        draw_labels = bboxes_is_dict
+
+    from PIL import Image, ImageDraw
+    # mul() allocates a new tensor, so the in-place ops below never touch the caller's image.
+    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
+    ndarr = image.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
+    im = Image.fromarray(ndarr)
+    draw = ImageDraw.Draw(im)
+
+    if bboxes_is_dict:
+        # sort the class names so the default color assignment is deterministic
+        class_names = sorted(bboxes)
+        palette = {name: DEFAULT_COLORS[i % len(DEFAULT_COLORS)] for i, name in enumerate(class_names)}
+        if colors is not None:
+            palette.update(colors)
+
+        for name in class_names:
+            for bbox in bboxes[name]:
+                draw.rectangle(bbox, outline=palette[name], width=width)
+                if draw_labels:
+                    # TODO: this will probably overlap with the bbox
+                    # hard-code in a margin for the label?
+                    draw.text(tuple(bbox[:2]), name)
+    else:
+        for i, bbox in enumerate(bboxes):
+            draw.rectangle(bbox, outline=DEFAULT_COLORS[i % len(DEFAULT_COLORS)], width=width)
+
+    from numpy import array as to_numpy_array
+    return torch.from_numpy(to_numpy_array(im))