-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Cleaning up Ops Boxes and Losses 🧹 #5979
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0d728c4
475f656
e28511d
77f8f7a
4d55891
8fd0e30
6aea76e
d3b4951
5fdd7a8
2488305
4175be3
4237e4e
6599ec0
9b6bfb1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
from typing import List, Callable | ||
|
||
import pytest | ||
import torch | ||
import torch.fx | ||
from torch import Tensor | ||
from torchvision import ops | ||
|
||
|
||
class IouTestBase: | ||
@staticmethod | ||
def _run_test(target_fn: Callable, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List): | ||
def assert_close(box: Tensor, expected: Tensor, tolerance): | ||
out = target_fn(box, box) | ||
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance) | ||
|
||
for dtype in dtypes: | ||
actual_box = torch.tensor(test_input, dtype=dtype) | ||
expected_box = torch.tensor(expected) | ||
assert_close(actual_box, expected_box, tolerance) | ||
|
||
@staticmethod | ||
def _run_jit_test(target_fn: Callable, test_input: List): | ||
box_tensor = torch.tensor(test_input, dtype=torch.float) | ||
expected = target_fn(box_tensor, box_tensor) | ||
scripted_fn = torch.jit.script(target_fn) | ||
scripted_out = scripted_fn(box_tensor, box_tensor) | ||
torch.testing.assert_close(scripted_out, expected, rtol=0.0, atol=1e-3) | ||
|
||
|
||
def _generate_int_input(): | ||
return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] | ||
|
||
|
||
def _generate_float_input(): | ||
return [ | ||
[285.3538, 185.5758, 1193.5110, 851.4551], | ||
[285.1472, 188.7374, 1192.4984, 851.0669], | ||
[279.2440, 197.9812, 1189.4746, 849.2019], | ||
] | ||
|
||
|
||
class TestBoxIou(IouTestBase): | ||
def _generate_int_expected(): | ||
return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] | ||
|
||
def _generate_float_input(): | ||
return [ | ||
[285.3538, 185.5758, 1193.5110, 851.4551], | ||
[285.1472, 188.7374, 1192.4984, 851.0669], | ||
[279.2440, 197.9812, 1189.4746, 849.2019], | ||
] | ||
|
||
def _generate_float_expected(): | ||
return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] | ||
|
||
@pytest.mark.parametrize( | ||
"test_input, dtypes, tolerance, expected", | ||
[ | ||
pytest.param( | ||
_generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() | ||
), | ||
pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), | ||
pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), | ||
], | ||
) | ||
def test_iou(self, test_input, dtypes, tolerance, expected): | ||
self._run_test(ops.box_iou, test_input, dtypes, tolerance, expected) | ||
|
||
def test_iou_jit(self): | ||
self._run_jit_test(ops.box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) | ||
|
||
|
||
class TestGenBoxIou(IouTestBase): | ||
def _generate_int_expected(): | ||
return [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]] | ||
|
||
def _generate_float_expected(): | ||
return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] | ||
|
||
@pytest.mark.parametrize( | ||
"test_input, dtypes, tolerance, expected", | ||
[ | ||
pytest.param( | ||
_generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() | ||
), | ||
pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), | ||
pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), | ||
], | ||
) | ||
def test_iou(self, test_input, dtypes, tolerance, expected): | ||
self._run_test(ops.generalized_box_iou, test_input, dtypes, tolerance, expected) | ||
|
||
def test_iou_jit(self): | ||
self._run_jit_test(ops.generalized_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) | ||
|
||
|
||
class TestDistanceBoxIoU(IouTestBase): | ||
def _generate_int_expected(): | ||
return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] | ||
|
||
def _generate_float_expected(): | ||
return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] | ||
|
||
@pytest.mark.parametrize( | ||
"test_input, dtypes, tolerance, expected", | ||
[ | ||
pytest.param( | ||
_generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() | ||
), | ||
pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), | ||
pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), | ||
], | ||
) | ||
def test_iou(self, test_input, dtypes, tolerance, expected): | ||
self._run_test(ops.distance_box_iou, test_input, dtypes, tolerance, expected) | ||
|
||
def test_iou_jit(self): | ||
self._run_jit_test(ops.distance_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) | ||
|
||
|
||
class TestCompleteBoxIou(IouTestBase): | ||
def _generate_int_expected(): | ||
return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] | ||
|
||
def _generate_float_expected(): | ||
return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] | ||
|
||
@pytest.mark.parametrize( | ||
"test_input, dtypes, tolerance, expected", | ||
[ | ||
pytest.param( | ||
_generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() | ||
), | ||
pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), | ||
pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), | ||
], | ||
) | ||
def test_iou(self, test_input, dtypes, tolerance, expected): | ||
self._run_test(ops.complete_box_iou, test_input, dtypes, tolerance, expected) | ||
|
||
def test_iou_jit(self): | ||
self._run_jit_test(ops.complete_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) | ||
|
||
|
||
if __name__ == "__main__": | ||
pytest.main([__file__]) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,229 @@ | ||
import pytest | ||
import torch | ||
import torch.nn.functional as F | ||
from common_utils import cpu_and_gpu | ||
from torchvision import ops | ||
|
||
|
||
def get_boxes(dtype, device): | ||
box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not super happy with this choice of box. Since this is actually invalid input There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree with the concern. Detectron2 used the same set of boxes. see this. I think we should use valid input boxes. That being said, should we also check if the input boxes have non-negative values? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We cannot assert here as it will lead to cuda call and cause trouble. I'm not sure if we can use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hmm, ig this situation is similar to #5776 (comment) |
||
box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) | ||
box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device) | ||
box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device) | ||
|
||
box1s = torch.stack([box2, box2], dim=0) | ||
box2s = torch.stack([box3, box4], dim=0) | ||
|
||
return box1, box2, box3, box4, box1s, box2s | ||
|
||
|
||
def assert_iou_loss(iou_fn, box1, box2, expected_loss, dtype, device, reduction="none"): | ||
tol = 1e-3 if dtype is torch.half else 1e-5 | ||
computed_loss = iou_fn(box1, box2, reduction=reduction) | ||
expected_loss = torch.tensor(expected_loss, device=device) | ||
torch.testing.assert_close(computed_loss, expected_loss, rtol=tol, atol=tol) | ||
|
||
|
||
def assert_empty_loss(iou_fn, dtype, device): | ||
box1 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_() | ||
box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_() | ||
loss = iou_fn(box1, box2, reduction="mean") | ||
loss.backward() | ||
tol = 1e-3 if dtype is torch.half else 1e-5 | ||
torch.testing.assert_close(loss, torch.tensor(0.0, device=device), rtol=tol, atol=tol) | ||
assert box1.grad is not None, "box1.grad should not be None after backward is called" | ||
assert box2.grad is not None, "box2.grad should not be None after backward is called" | ||
loss = iou_fn(box1, box2, reduction="none") | ||
assert loss.numel() == 0, f"{str(iou_fn)} for two empty box should be empty" | ||
|
||
|
||
class TestGeneralizedBoxIouLoss: | ||
# We refer to original test: https://github.com/facebookresearch/fvcore/blob/main/tests/test_giou_loss.py | ||
@pytest.mark.parametrize("device", cpu_and_gpu()) | ||
@pytest.mark.parametrize("dtype", [torch.float32, torch.half]) | ||
def test_giou_loss(self, dtype, device): | ||
|
||
box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) | ||
|
||
# Identical boxes should have loss of 0 | ||
assert_iou_loss(ops.generalized_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device) | ||
|
||
# quarter size box inside other box = IoU of 0.25 | ||
assert_iou_loss(ops.generalized_box_iou_loss, box1, box2, 0.75, dtype=dtype, device=device) | ||
|
||
# Two side by side boxes, area=union | ||
# IoU=0 and GIoU=0 (loss 1.0) | ||
assert_iou_loss(ops.generalized_box_iou_loss, box2, box3, 1.0, dtype=dtype, device=device) | ||
|
||
# Two diagonally adjacent boxes, area=2*union | ||
# IoU=0 and GIoU=-0.5 (loss 1.5) | ||
assert_iou_loss(ops.generalized_box_iou_loss, box2, box4, 1.5, dtype=dtype, device=device) | ||
|
||
# Test batched loss and reductions | ||
assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 2.5, dtype=dtype, device=device, reduction="sum") | ||
assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 1.25, dtype=dtype, device=device, reduction="mean") | ||
|
||
@pytest.mark.parametrize("device", cpu_and_gpu()) | ||
@pytest.mark.parametrize("dtype", [torch.float32, torch.half]) | ||
def test_empty_inputs(self, dtype, device): | ||
assert_empty_loss(ops.generalized_box_iou_loss, dtype, device) | ||
|
||
|
||
class TestCIOULoss: | ||
@pytest.mark.parametrize("dtype", [torch.float32, torch.half]) | ||
@pytest.mark.parametrize("device", cpu_and_gpu()) | ||
def test_ciou_loss(self, dtype, device): | ||
box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) | ||
|
||
assert_iou_loss(ops.complete_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device) | ||
assert_iou_loss(ops.complete_box_iou_loss, box1, box2, 0.8125, dtype=dtype, device=device) | ||
assert_iou_loss(ops.complete_box_iou_loss, box1, box3, 1.1923, dtype=dtype, device=device) | ||
assert_iou_loss(ops.complete_box_iou_loss, box1, box4, 1.2500, dtype=dtype, device=device) | ||
assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 1.2250, dtype=dtype, device=device, reduction="mean") | ||
assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 2.4500, dtype=dtype, device=device, reduction="sum") | ||
|
||
@pytest.mark.parametrize("device", cpu_and_gpu()) | ||
@pytest.mark.parametrize("dtype", [torch.float32, torch.half]) | ||
def test_empty_inputs(self, dtype, device): | ||
assert_empty_loss(ops.complete_box_iou_loss, dtype, device) | ||
|
||
|
||
class TestDIouLoss: | ||
@pytest.mark.parametrize("device", cpu_and_gpu()) | ||
@pytest.mark.parametrize("dtype", [torch.float32, torch.half]) | ||
def test_distance_iou_loss(self, dtype, device): | ||
box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) | ||
|
||
assert_iou_loss(ops.distance_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device) | ||
assert_iou_loss(ops.distance_box_iou_loss, box1, box2, 0.8125, dtype=dtype, device=device) | ||
assert_iou_loss(ops.distance_box_iou_loss, box1, box3, 1.1923, dtype=dtype, device=device) | ||
assert_iou_loss(ops.distance_box_iou_loss, box1, box4, 1.2500, dtype=dtype, device=device) | ||
assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 1.2250, dtype=dtype, device=device, reduction="mean") | ||
assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 2.4500, dtype=dtype, device=device, reduction="sum") | ||
|
||
@pytest.mark.parametrize("device", cpu_and_gpu()) | ||
@pytest.mark.parametrize("dtype", [torch.float32, torch.half]) | ||
def test_empty_distance_iou_inputs(self, dtype, device): | ||
assert_empty_loss(ops.distance_box_iou_loss, dtype, device) | ||
|
||
|
||
class TestFocalLoss: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since we are testing losses here, I felt to add this here . To avoid confusion between files. |
||
def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs): | ||
def logit(p): | ||
return torch.log(p / (1 - p)) | ||
|
||
def generate_tensor_with_range_type(shape, range_type, **kwargs): | ||
if range_type != "random_binary": | ||
low, high = { | ||
"small": (0.0, 0.2), | ||
"big": (0.8, 1.0), | ||
"zeros": (0.0, 0.0), | ||
"ones": (1.0, 1.0), | ||
"random": (0.0, 1.0), | ||
}[range_type] | ||
return torch.testing.make_tensor(shape, low=low, high=high, **kwargs) | ||
else: | ||
return torch.randint(0, 2, shape, **kwargs) | ||
|
||
# This function will return inputs and targets with shape: (shape[0]*9, shape[1]) | ||
inputs = [] | ||
targets = [] | ||
for input_range_type, target_range_type in [ | ||
("small", "zeros"), | ||
("small", "ones"), | ||
("small", "random_binary"), | ||
("big", "zeros"), | ||
("big", "ones"), | ||
("big", "random_binary"), | ||
("random", "zeros"), | ||
("random", "ones"), | ||
("random", "random_binary"), | ||
]: | ||
inputs.append(logit(generate_tensor_with_range_type(shape, input_range_type, **kwargs))) | ||
targets.append(generate_tensor_with_range_type(shape, target_range_type, **kwargs)) | ||
|
||
return torch.cat(inputs), torch.cat(targets) | ||
|
||
@pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0]) | ||
@pytest.mark.parametrize("gamma", [0, 2]) | ||
@pytest.mark.parametrize("device", cpu_and_gpu()) | ||
@pytest.mark.parametrize("dtype", [torch.float32, torch.half]) | ||
@pytest.mark.parametrize("seed", [0, 1]) | ||
def test_correct_ratio(self, alpha, gamma, device, dtype, seed): | ||
if device == "cpu" and dtype is torch.half: | ||
pytest.skip("Currently torch.half is not fully supported on cpu") | ||
# For testing the ratio with manual calculation, we require the reduction to be "none" | ||
reduction = "none" | ||
torch.random.manual_seed(seed) | ||
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) | ||
focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) | ||
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction) | ||
|
||
assert torch.all( | ||
focal_loss <= ce_loss | ||
), "focal loss must be less or equal to cross entropy loss with same input" | ||
|
||
loss_ratio = (focal_loss / ce_loss).squeeze() | ||
prob = torch.sigmoid(inputs) | ||
p_t = prob * targets + (1 - prob) * (1 - targets) | ||
correct_ratio = (1.0 - p_t) ** gamma | ||
if alpha >= 0: | ||
alpha_t = alpha * targets + (1 - alpha) * (1 - targets) | ||
correct_ratio = correct_ratio * alpha_t | ||
|
||
tol = 1e-3 if dtype is torch.half else 1e-5 | ||
torch.testing.assert_close(correct_ratio, loss_ratio, rtol=tol, atol=tol) | ||
|
||
@pytest.mark.parametrize("reduction", ["mean", "sum"]) | ||
@pytest.mark.parametrize("device", cpu_and_gpu()) | ||
@pytest.mark.parametrize("dtype", [torch.float32, torch.half]) | ||
@pytest.mark.parametrize("seed", [2, 3]) | ||
def test_equal_ce_loss(self, reduction, device, dtype, seed): | ||
if device == "cpu" and dtype is torch.half: | ||
pytest.skip("Currently torch.half is not fully supported on cpu") | ||
# focal loss should be equal ce_loss if alpha=-1 and gamma=0 | ||
alpha = -1 | ||
gamma = 0 | ||
torch.random.manual_seed(seed) | ||
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) | ||
inputs_fl = inputs.clone().requires_grad_() | ||
targets_fl = targets.clone() | ||
inputs_ce = inputs.clone().requires_grad_() | ||
targets_ce = targets.clone() | ||
focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction) | ||
ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction) | ||
|
||
tol = 1e-3 if dtype is torch.half else 1e-5 | ||
torch.testing.assert_close(focal_loss, ce_loss, rtol=tol, atol=tol) | ||
|
||
focal_loss.backward() | ||
ce_loss.backward() | ||
torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad, rtol=tol, atol=tol) | ||
|
||
@pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0]) | ||
@pytest.mark.parametrize("gamma", [0, 2]) | ||
@pytest.mark.parametrize("reduction", ["none", "mean", "sum"]) | ||
@pytest.mark.parametrize("device", cpu_and_gpu()) | ||
@pytest.mark.parametrize("dtype", [torch.float32, torch.half]) | ||
@pytest.mark.parametrize("seed", [4, 5]) | ||
def test_jit(self, alpha, gamma, reduction, device, dtype, seed): | ||
if device == "cpu" and dtype is torch.half: | ||
pytest.skip("Currently torch.half is not fully supported on cpu") | ||
script_fn = torch.jit.script(ops.sigmoid_focal_loss) | ||
torch.random.manual_seed(seed) | ||
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) | ||
focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) | ||
if device == "cpu": | ||
scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) | ||
else: | ||
with torch.jit.fuser("fuser2"): | ||
# Use fuser2 to prevent a bug on fuser: https://github.com/pytorch/pytorch/issues/75476 | ||
# We may remove this condition once the bug is resolved | ||
scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) | ||
|
||
tol = 1e-3 if dtype is torch.half else 1e-5 | ||
torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol) | ||
|
||
|
||
if __name__ == "__main__": | ||
pytest.main([__file__]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Making a class, inheriting and the calling method is also possible. For now is this fine?