Skip to content

Commit 8d2a431

Browse files
authored
Fix colors type
1 parent f8a0b29 commit 8d2a431

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

torchvision/utils.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,12 @@ def draw_bounding_boxes(
164164
else: # BBoxes.__instancecheck__(Dict[str, Sequence[BBox]])
165165
draw_labels = True
166166

167+
# colors: Union[Sequence[Color], Dict[str, Color]]
167168
if colors is None:
168169
# TODO: default to one of @pmeir's suggestions as a seq
169-
pass
170+
colors_: Sequence[Color] = colors
171+
else:
172+
colors_: Dict[str, Color] = colors
170173

171174
from PIL import Image, ImageDraw
172175
# Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
@@ -176,20 +179,20 @@ def draw_bounding_boxes(
176179
draw = ImageDraw.Draw(im)
177180

178181
if bboxes_is_dict:
179-
if Sequence[Color].__instancecheck__(colors):
182+
if Sequence[Color].__instancecheck__(colors_):
180183
# align the colors seq with the bbox classes
181-
colors = dict(zip(sorted(bboxes.keys()), colors))
184+
colors = dict(zip(sorted(bboxes.keys()), colors_))
182185

183-
for i, (bbox_class, bbox) in enumerate(bboxes.items()):
184-
draw.rectangle(bbox, outline=colors[bbox_class], width=width)
186+
for bbox_class, bbox in enumerate(bboxes.items()):
187+
draw.rectangle(bbox, outline=colors_[bbox_class], width=width)
185188
if draw_labels:
186189
# TODO: this will probably overlap with the bbox
187190
# hard-code in a margin for the label?
188191
label_tl_x, label_tl_y, _, _ = bbox
189192
draw.text((label_tl_x, label_tl_y), bbox_class)
190193
else: # bboxes_is_seq
191194
for i, bbox in enumerate(bboxes):
192-
draw.rectangle(bbox, outline=colors[i], width=width)
195+
draw.rectangle(bbox, outline=colors_[i], width=width)
193196

194197
from numpy import array as to_numpy_array
195198
return torch.from_numpy(to_numpy_array(im))

0 commit comments

Comments
 (0)