-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Added CIOU loss function #5776
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added CIOU loss function #5776
Changes from all commits
d868ec1
abb09eb
9c2ee2e
f3f1d92
2d0f627
a158ca3
5760487
56147d2
38e7a19
9a1cf90
755fa07
1a6b59a
99a3951
d89dbec
c9b0cab
8c2feee
b1d33fa
c531b1d
19b23d1
916418f
96c6dda
844e0da
38f9ede
2422913
ada4471
b8a7d96
c8a18ce
9b4803a
5cf1591
d25a5a0
14add84
03ecb91
1c4ae7f
2cbc6a2
9c88d92
e36fb15
1e57b6b
7e244fb
47c7e09
f5a352c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import torch | ||
|
||
from ..utils import _log_api_usage_once | ||
from .giou_loss import _upcast | ||
abhi-glitchhg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
def complete_box_iou_loss( | ||
boxes1: torch.Tensor, | ||
boxes2: torch.Tensor, | ||
reduction: str = "none", | ||
eps: float = 1e-7, | ||
) -> torch.Tensor: | ||
|
||
""" | ||
abhi-glitchhg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Gradient-friendly IoU loss with an additional penalty that is non-zero when the | ||
boxes do not overlap overlap area, This loss function considers important geometrical | ||
factors such as overlap area, normalized central point distance and aspect ratio. | ||
This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable. | ||
|
||
abhi-glitchhg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with | ||
``0 <= x1 < x2`` and ``0 <= y1 < y2``, and The two boxes should have the | ||
same dimensions. | ||
|
||
Args: | ||
boxes1 : (Tensor[N, 4] or Tensor[4]) first set of boxes | ||
boxes2 : (Tensor[N, 4] or Tensor[4]) second set of boxes | ||
reduction : (string, optional) Specifies the reduction to apply to the output: | ||
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be | ||
applied to the output. ``'mean'``: The output will be averaged. | ||
``'sum'``: The output will be summed. Default: ``'none'`` | ||
eps : (float): small number to prevent division by zero. Default: 1e-7 | ||
|
||
Reference: | ||
|
||
Complete Intersection over Union Loss (Zhaohui Zheng et. al) | ||
https://arxiv.org/abs/1911.08287 | ||
|
||
""" | ||
|
||
abhi-glitchhg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Original Implementation : https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py | ||
|
||
if not torch.jit.is_scripting() and not torch.jit.is_tracing(): | ||
_log_api_usage_once(complete_box_iou_loss) | ||
|
||
boxes1 = _upcast(boxes1) | ||
boxes2 = _upcast(boxes2) | ||
|
||
x1, y1, x2, y2 = boxes1.unbind(dim=-1) | ||
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) | ||
|
||
# Intersection keypoints | ||
xkis1 = torch.max(x1, x1g) | ||
ykis1 = torch.max(y1, y1g) | ||
xkis2 = torch.min(x2, x2g) | ||
ykis2 = torch.min(y2, y2g) | ||
|
||
intsct = torch.zeros_like(x1) | ||
mask = (ykis2 > ykis1) & (xkis2 > xkis1) | ||
intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]) | ||
union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: Moreover it's worth noting that cIoU and dIoU share a large number of common code that could be shared. |
||
iou = intsct / union | ||
|
||
# smallest enclosing box | ||
xc1 = torch.min(x1, x1g) | ||
yc1 = torch.min(y1, y1g) | ||
xc2 = torch.max(x2, x2g) | ||
yc2 = torch.max(y2, y2g) | ||
diag_len = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps | ||
|
||
# centers of boxes | ||
x_p = (x2 + x1) / 2 | ||
y_p = (y2 + y1) / 2 | ||
x_g = (x1g + x2g) / 2 | ||
y_g = (y1g + y2g) / 2 | ||
distance = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2) | ||
|
||
# width and height of boxes | ||
w_pred = x2 - x1 | ||
h_pred = y2 - y1 | ||
w_gt = x2g - x1g | ||
h_gt = y2g - y1g | ||
v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2) | ||
with torch.no_grad(): | ||
alpha = v / (1 - iou + v + eps) | ||
|
||
loss = 1 - iou + (distance / diag_len) + alpha * v | ||
if reduction == "mean": | ||
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() | ||
elif reduction == "sum": | ||
loss = loss.sum() | ||
|
||
return loss |
Uh oh!
There was an error while loading. Please reload this page.