-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Add masks to boundaries #7704
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add masks to boundaries #7704
Changes from 8 commits
79dcbb1
d171ffd
9d41c0a
e277308
330301c
a8bd95c
7311956
08485c0
59fb72c
091f3fb
fa68881
c2d8074
aa4b2e3
293e436
cf07bc0
7abbc3b
762992f
0991f93
4de4913
91df477
ebee25e
9fc12a9
080fa0d
e526765
1ec78df
78062c0
6cce8a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -398,7 +398,40 @@ def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Te | |
# distance between boxes' centers squared. | ||
return iou - (centers_distance_squared / diagonal_distance_squared), iou | ||
|
||
def masks_to_boundaries(masks: torch.Tensor, dilation_ratio: float = 0.02) -> torch.Tensor: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess it's OK to have the implementation in this file even though this isn't related to boxed. However, I don't think we should expose it here. I think we should just expose it in from the We probably just need to rename this to from .boxes import import _masks_to_boundaries as masks_to_boundaries Any other suggestion @pmeier @vfdev-5 @oke-aditya ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
No strong opinion, but could we maybe also have a new 👍 for only exposing it in the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tbh there is demand for mask_utils. Several of them, #4415 . Candidate utils like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we can always create an I'm OK with creating There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This sounds best solution! We can avoid the bloat inside this file as well as keep them private 😄 |
||
""" | ||
Compute the boundaries around the provided masks using morphological operations. | ||
|
||
Returns a tensor of the same shape as the input masks containing the boundaries of each mask. | ||
|
||
Args: | ||
masks (Tensor[N, H, W]): masks to transform where N is the number of masks | ||
and (H, W) are the spatial dimensions. | ||
dilation_ratio (float, optional): ratio used for the dilation operation. Default: 0.02 | ||
|
||
Returns: | ||
Tensor[N, H, W]: boundaries | ||
bhack marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
# If no masks are provided, return an empty tensor | ||
if masks.numel() == 0: | ||
return torch.zeros_like(masks) | ||
|
||
n, h, w = masks.shape | ||
img_diag = math.sqrt(h ** 2 + w ** 2) | ||
dilation = int(round(dilation_ratio * img_diag)) | ||
selem_size = dilation * 2 + 1 | ||
bhack marked this conversation as resolved.
Show resolved
Hide resolved
bhack marked this conversation as resolved.
Show resolved
Hide resolved
|
||
selem = torch.ones((n, 1, selem_size, selem_size), device=masks.device) | ||
|
||
# Compute the boundaries for each mask | ||
masks = masks.float().unsqueeze(1) | ||
eroded_masks = F.conv2d(masks, selem, padding=dilation) | ||
# Make the output binary | ||
eroded_masks = (eroded_masks == selem.view(n, -1).sum(-1).view(n, 1, 1, 1)).byte() | ||
|
||
contours = masks.byte() - eroded_masks | ||
|
||
return contours.squeeze(1) | ||
|
||
def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Compute the bounding boxes around the provided masks. | ||
|
@@ -431,3 +464,4 @@ def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor: | |
bounding_boxes[index, 3] = torch.max(y) | ||
|
||
return bounding_boxes | ||
|
Uh oh!
There was an error while loading. Please reload this page.