Skip to content

[WIP] Implement torchvision.utils.draw_bounding_boxes #2631

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
38 changes: 38 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
from PIL import Image


def path_to_tensor(filepath):
    """Load the image stored at ``filepath`` and return its pixels as a tensor.

    The file is decoded with PIL, converted to a NumPy array and wrapped with
    ``torch.from_numpy``, so the result keeps whatever layout and dtype the
    NumPy conversion of the PIL image produces.

    Args:
        filepath: path of the image file to read.

    Returns:
        Tensor: the decoded pixel data.
    """
    from numpy import array as to_numpy_array

    # Image.open is lazy and keeps the underlying file open; the context
    # manager releases it as soon as the pixel data has been copied out.
    with Image.open(filepath) as img:
        return torch.from_numpy(to_numpy_array(img))


class Tester(unittest.TestCase):

def test_make_grid_not_inplace(self):
Expand Down Expand Up @@ -41,6 +46,39 @@ def test_normalize_in_make_grid(self):
self.assertTrue(torch.equal(norm_max, rounded_grid_max), 'Normalized max is not equal to 1')
self.assertTrue(torch.equal(norm_min, rounded_grid_min), 'Normalized min is not equal to 0')

# def test_bboxes_not_inplace(self):
# t = torch.rand(5, 3, 10, 10) * 255
# t_clone = t.clone()
#
# TODO: this doesn't work yet; draw_bounding_boxes requires a bboxes argument
# utils.draw_bounding_boxes(t, draw_labels=False)
# self.assertTrue(torch.equal(t, t_clone), 'draw_bounding_boxes modified tensor in-place')
#
# utils.draw_bounding_boxes(t, draw_labels=True)
# self.assertTrue(torch.equal(t, t_clone), 'draw_bounding_boxes modified tensor in-place')

def test_bboxes(self):
    # Draw two boxes on a fixture image and compare the result with a
    # pre-rendered expected image stored next to it.
    image_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
    image_dir = os.path.join(image_root, "fakedata", "imagefolder")
    inp_img_path = os.path.join(image_dir, 'a4.png')
    out_img_path = os.path.join(image_dir, 'b5.png')

    # draw_bounding_boxes permutes its input with (1, 2, 0) after scaling by
    # 255, i.e. it expects a (C, H, W) image with values in [0, 1], whereas
    # path_to_tensor returns the array exactly as PIL decodes it.
    # NOTE(review): assumes the fixture is a multi-channel uint8 image laid
    # out as (H, W, C) -- confirm against the actual asset.
    inp_img = path_to_tensor(inp_img_path).permute(2, 0, 1).float().div(255)
    bboxes = ((1, 2, 10, 18), (4, 8, 9, 11))
    # TODO: maybe write the rectangle programmatically in this test instead of
    # statically loading output?
    expected_img = path_to_tensor(out_img_path)

    self.assertTrue(
        torch.equal(
            utils.draw_bounding_boxes(inp_img, bboxes, draw_labels=False),
            expected_img,
        ),
        'draw_bounding_boxes returned an incorrect result',
    )

@unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
def test_save_image(self):
with tempfile.NamedTemporaryFile(suffix='.png') as f:
Expand Down
70 changes: 69 additions & 1 deletion torchvision/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, Optional, List, Tuple, Text, BinaryIO
from typing import Union, Optional, List, Tuple, Text, BinaryIO, Sequence, Dict
import io
import pathlib
import torch
Expand Down Expand Up @@ -128,3 +128,71 @@ def save_image(
ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
im = Image.fromarray(ndarr)
im.save(fp, format=format)


BBox = Tuple[int, int, int, int]
BBoxes = Sequence[BBox]
Color = Tuple[int, int, int]
# Fallback RGB palette used when the caller passes no ``colors``. It is cycled
# when there are more boxes (or classes) than palette entries.
DEFAULT_COLORS: Sequence[Color] = (
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 255),
    (255, 128, 0),
    (128, 0, 255),
)


def draw_bounding_boxes(
    image: torch.Tensor,
    bboxes: Union[BBoxes, Dict[str, Sequence[BBox]]],
    colors: Optional[Dict[str, Color]] = None,
    draw_labels: Optional[bool] = None,
    width: int = 1,
) -> torch.Tensor:
    """Draw bounding boxes on an image and return the annotated image.

    Args:
        image (Tensor): image of shape (C, H, W); C is expected to be 3 because
            outlines are drawn with RGB colors. Floating-point tensors are read
            as values in [0, 1] and scaled to [0, 255]; tensors of any other
            dtype are taken to already hold values in [0, 255]. The input
            tensor is not modified.
        bboxes: either a sequence of boxes, or a dict mapping a class label to
            the sequence of boxes belonging to that class. Each box is four
            numbers handed to PIL's ``ImageDraw.rectangle`` as
            ``(x0, y0, x1, y1)``.
        colors (dict, optional): RGB outline color per class label. Only
            allowed when ``bboxes`` is a dict, and must then contain every
            label of ``bboxes``. When omitted, ``DEFAULT_COLORS`` is used:
            per box for a sequence, per class (labels taken in sorted order)
            for a dict.
        draw_labels (bool, optional): whether to write the class label at the
            top-left corner of each box. Defaults to ``True`` for dict
            ``bboxes`` and ``False`` for a sequence, which has no labels.
        width (int): outline width in pixels.

    Returns:
        Tensor: ``uint8`` tensor of shape (H, W, C) holding the drawn image.

    Raises:
        ValueError: if ``colors`` or ``draw_labels=True`` is combined with a
            sequence of boxes, if ``image`` is not 3-dimensional, or if
            ``colors`` lacks an entry for one of the labels in ``bboxes``.
    """
    # Subscripted typing generics cannot be used for runtime instance checks,
    # so the two accepted shapes are told apart by testing for a dict.
    bboxes_is_dict = isinstance(bboxes, dict)

    if not bboxes_is_dict:
        if colors is not None:
            raise ValueError(
                "colors can only be given when bboxes is a dict keyed by class label"
            )
        if draw_labels:
            raise ValueError(
                "draw_labels=True requires bboxes to be a dict keyed by class label"
            )

    if draw_labels is None:
        draw_labels = bboxes_is_dict

    if image.dim() != 3:
        raise ValueError(
            "image must have shape (C, H, W), got a tensor with {} dimension(s)".format(image.dim())
        )

    # Resolve every box to (box, outline color, label) before touching PIL so
    # that argument errors surface before any drawing work is done.
    n_default = len(DEFAULT_COLORS)
    if bboxes_is_dict:
        if colors is None:
            # Sorting the labels makes the palette assignment independent of
            # the dict's insertion order.
            class_colors = {
                label: DEFAULT_COLORS[i % n_default]
                for i, label in enumerate(sorted(bboxes))
            }
        else:
            missing = [label for label in bboxes if label not in colors]
            if missing:
                raise ValueError("no color given for label(s): {}".format(missing))
            class_colors = colors
        boxes_to_draw = [
            (bbox, class_colors[label], label)
            for label, class_bboxes in bboxes.items()
            for bbox in class_bboxes
        ]
    else:
        boxes_to_draw = [
            (bbox, DEFAULT_COLORS[i % n_default], None)
            for i, bbox in enumerate(bboxes)
        ]

    from PIL import Image, ImageDraw
    from numpy import array as to_numpy_array

    if image.is_floating_point():
        # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer.
        # ``mul`` allocates a new tensor, so the in-place ops that follow never
        # touch the caller's image.
        scaled = image.mul(255).add_(0.5).clamp_(0, 255)
    else:
        scaled = image.clamp(0, 255)
    ndarr = scaled.permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    draw = ImageDraw.Draw(im)

    for bbox, outline, label in boxes_to_draw:
        x0, y0, x1, y1 = bbox
        draw.rectangle((x0, y0, x1, y1), outline=outline, width=width)
        if draw_labels:
            # TODO: the label is anchored at the box corner and can overlap
            # the outline; consider a margin.
            draw.text((x0, y0), str(label))

    return torch.from_numpy(to_numpy_array(im))