Skip to content

Commit 3950433

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] convert all bounding boxes to XYXY in dataset wrapper (#7275)
Reviewed By: vmoens Differential Revision: D44416599 fbshipit-source-id: 0786448bdfe4ca50bcf948cb1254d29a3e7f9367
1 parent f70371c commit 3950433

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

torchvision/datapoints/_dataset_wrapper.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -230,10 +230,13 @@ def wrapper(idx, sample):
230230
batched_target["image_id"] = image_id
231231

232232
spatial_size = tuple(F.get_spatial_size(image))
233-
batched_target["boxes"] = datapoints.BoundingBox(
234-
batched_target["bbox"],
235-
format=datapoints.BoundingBoxFormat.XYWH,
236-
spatial_size=spatial_size,
233+
batched_target["boxes"] = F.convert_format_bounding_box(
234+
datapoints.BoundingBox(
235+
batched_target["bbox"],
236+
format=datapoints.BoundingBoxFormat.XYWH,
237+
spatial_size=spatial_size,
238+
),
239+
new_format=datapoints.BoundingBoxFormat.XYXY,
237240
)
238241
batched_target["masks"] = datapoints.Mask(
239242
torch.stack(
@@ -323,8 +326,13 @@ def wrapper(idx, sample):
323326
target,
324327
target_types=dataset.target_type,
325328
type_wrappers={
326-
"bbox": lambda item: datapoints.BoundingBox(
327-
item, format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width)
329+
"bbox": lambda item: F.convert_format_bounding_box(
330+
datapoints.BoundingBox(
331+
item,
332+
format=datapoints.BoundingBoxFormat.XYWH,
333+
spatial_size=(image.height, image.width),
334+
),
335+
new_format=datapoints.BoundingBoxFormat.XYXY,
328336
),
329337
},
330338
)
@@ -416,8 +424,11 @@ def wrapper(idx, sample):
416424
image, target = sample
417425

418426
if target is not None:
419-
target["bbox"] = datapoints.BoundingBox(
420-
target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width)
427+
target["bbox"] = F.convert_format_bounding_box(
428+
datapoints.BoundingBox(
429+
target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width)
430+
),
431+
new_format=datapoints.BoundingBoxFormat.XYXY,
421432
)
422433

423434
return image, target

0 commit comments

Comments
 (0)