From 7da0a41c6fea1f59041547eb9ddb7504a9de022f Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Mon, 16 May 2022 21:57:43 +0530 Subject: [PATCH 01/14] Refactor tests --- test/test_ious.py | 147 +++++++++++++ test/test_losses.py | 229 +++++++++++++++++++ test/test_ops.py | 525 +++----------------------------------------- 3 files changed, 412 insertions(+), 489 deletions(-) create mode 100644 test/test_ious.py create mode 100644 test/test_losses.py diff --git a/test/test_ious.py b/test/test_ious.py new file mode 100644 index 00000000000..4e87d64b477 --- /dev/null +++ b/test/test_ious.py @@ -0,0 +1,147 @@ +from typing import List, Callable + +import pytest +import torch +import torch.fx +from torch import Tensor +from torchvision import ops + + +class IouTestBase: + @staticmethod + def _run_test(target_fn: Callable, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List): + def assert_close(box: Tensor, expected: Tensor, tolerance): + out = target_fn(box, box) + torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance) + + for dtype in dtypes: + actual_box = torch.tensor(test_input, dtype=dtype) + expected_box = torch.tensor(expected) + assert_close(actual_box, expected_box, tolerance) + + @staticmethod + def _run_jit_test(target_fn: Callable, test_input: List): + box_tensor = torch.tensor(test_input, dtype=torch.float) + expected = target_fn(box_tensor, box_tensor) + scripted_fn = torch.jit.script(target_fn) + scripted_out = scripted_fn(box_tensor, box_tensor) + torch.testing.assert_close(scripted_out, expected, rtol=0.0, atol=1e-3) + + +def _generate_int_input(): + return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] + + +def _generate_float_input(): + return [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], + ] + + +class TestBoxIou(IouTestBase): + def _generate_int_expected(): + return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + + def _generate_float_input(): + return [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], + ] + + def _generate_float_expected(): + return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] + + @pytest.mark.parametrize( + "test_input, dtypes, tolerance, expected", + [ + pytest.param( + _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() + ), + pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), + pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), + ], + ) + def test_iou(self, test_input, dtypes, tolerance, expected): + self._run_test(ops.box_iou, test_input, dtypes, tolerance, expected) + + def test_iou_jit(self): + self._run_jit_test(ops.box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) + + +class TestGenBoxIou(IouTestBase): + def _generate_int_expected(): + return [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]] + + def _generate_float_expected(): + return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] + + @pytest.mark.parametrize( + "test_input, dtypes, tolerance, expected", + [ + pytest.param( + _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() + ), + pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), + 
pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), + ], + ) + def test_iou(self, test_input, dtypes, tolerance, expected): + self._run_test(ops.generalized_box_iou, test_input, dtypes, tolerance, expected) + + def test_iou_jit(self): + self._run_jit_test(ops.generalized_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) + + +class TestDistanceBoxIoU(IouTestBase): + def _generate_int_expected(): + return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + + def _generate_float_expected(): + return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] + + @pytest.mark.parametrize( + "test_input, dtypes, tolerance, expected", + [ + pytest.param( + _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() + ), + pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), + pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), + ], + ) + def test_iou(self, test_input, dtypes, tolerance, expected): + self._run_test(ops.distance_box_iou, test_input, dtypes, tolerance, expected) + + def test_iou_jit(self): + self._run_jit_test(ops.distance_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) + + +class TestCompleteBoxIou(IouTestBase): + def _generate_int_expected(): + return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + + def _generate_float_expected(): + return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] + + @pytest.mark.parametrize( + "test_input, dtypes, tolerance, expected", + [ + pytest.param( + _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() + ), + pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), + pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), + ], + ) + def test_iou(self, test_input, dtypes, tolerance, expected): + self._run_test(ops.complete_box_iou, test_input, dtypes, tolerance, expected) + + def test_iou_jit(self): + self._run_jit_test(ops.complete_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/test_losses.py b/test/test_losses.py new file mode 100644 index 00000000000..c2d7f9452ef --- /dev/null +++ b/test/test_losses.py @@ -0,0 +1,229 @@ +import pytest +import torch +import torch.nn.functional as F +from common_utils import cpu_and_gpu +from torchvision import ops + + +def get_boxes(dtype, device): + box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) + box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) + box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device) + box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device) + + box1s = torch.stack([box2, box2], dim=0) + box2s = torch.stack([box3, box4], dim=0) + + return box1, box2, box3, box4, box1s, box2s + + +def assert_iou_loss(iou_fn, box1, box2, expected_loss, dtype, device, reduction="none"): + tol = 1e-3 if dtype is torch.half else 1e-5 + computed_loss = iou_fn(box1, box2, reduction=reduction) + expected_loss = torch.tensor(expected_loss, device=device) + torch.testing.assert_close(computed_loss, expected_loss, rtol=tol, atol=tol) + + +def assert_empty_loss(iou_fn, dtype, device): + box1 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_() + box2 = torch.randn([0, 4], dtype=dtype, 
device=device).requires_grad_() + loss = iou_fn(box1, box2, reduction="mean") + loss.backward() + tol = 1e-3 if dtype is torch.half else 1e-5 + torch.testing.assert_close(loss, torch.tensor(0.0, device=device), rtol=tol, atol=tol) + assert box1.grad is not None, "box1.grad should not be None after backward is called" + assert box2.grad is not None, "box2.grad should not be None after backward is called" + loss = iou_fn(box1, box2, reduction="none") + assert loss.numel() == 0, f"{str(iou_fn)} for two empty box should be empty" + + +class TestGeneralizedBoxIouLoss: + # We refer to original test: https://github.com/facebookresearch/fvcore/blob/main/tests/test_giou_loss.py + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + def test_giou_loss(self, dtype, device): + + box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) + + # Identical boxes should have loss of 0 + assert_iou_loss(ops.generalized_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device) + + # quarter size box inside other box = IoU of 0.25 + assert_iou_loss(ops.generalized_box_iou_loss, box1, box2, 0.75, dtype=dtype, device=device) + + # Two side by side boxes, area=union + # IoU=0 and GIoU=0 (loss 1.0) + assert_iou_loss(ops.generalized_box_iou_loss, box2, box3, 1.0, dtype=dtype, device=device) + + # Two diagonally adjacent boxes, area=2*union + # IoU=0 and GIoU=-0.5 (loss 1.5) + assert_iou_loss(ops.generalized_box_iou_loss, box2, box4, 1.5, dtype=dtype, device=device) + + # Test batched loss and reductions + assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 2.5, dtype=dtype, device=device, reduction="sum") + assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 1.25, dtype=dtype, device=device, reduction="mean") + + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + def test_empty_inputs(self, dtype, device): + assert_empty_loss(ops.generalized_box_iou_loss, dtype, device) + + +class TestCIOULoss: + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + def test_ciou_loss(self, dtype, device): + box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) + + assert_iou_loss(ops.complete_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device) + assert_iou_loss(ops.complete_box_iou_loss, box1, box2, 0.8125, dtype=dtype, device=device) + assert_iou_loss(ops.complete_box_iou_loss, box1, box3, 1.1923, dtype=dtype, device=device) + assert_iou_loss(ops.complete_box_iou_loss, box1, box4, 1.2500, dtype=dtype, device=device) + assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 1.2250, dtype=dtype, device=device, reduction="mean") + assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 2.4500, dtype=dtype, device=device, reduction="sum") + + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + def test_empty_inputs(self, dtype, device): + assert_empty_loss(ops.complete_box_iou_loss, dtype, device) + + +class TestDIouLoss: + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + def test_distance_iou_loss(self, dtype, device): + box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) + + assert_iou_loss(ops.distance_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device) + assert_iou_loss(ops.distance_box_iou_loss, box1, box2, 0.8125, dtype=dtype, device=device) + 
assert_iou_loss(ops.distance_box_iou_loss, box1, box3, 1.1923, dtype=dtype, device=device) + assert_iou_loss(ops.distance_box_iou_loss, box1, box4, 1.2500, dtype=dtype, device=device) + assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 1.2250, dtype=dtype, device=device, reduction="mean") + assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 2.4500, dtype=dtype, device=device, reduction="sum") + + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + def test_empty_distance_iou_inputs(self, dtype, device): + assert_empty_loss(ops.distance_box_iou_loss, dtype, device) + + +class TestFocalLoss: + def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs): + def logit(p): + return torch.log(p / (1 - p)) + + def generate_tensor_with_range_type(shape, range_type, **kwargs): + if range_type != "random_binary": + low, high = { + "small": (0.0, 0.2), + "big": (0.8, 1.0), + "zeros": (0.0, 0.0), + "ones": (1.0, 1.0), + "random": (0.0, 1.0), + }[range_type] + return torch.testing.make_tensor(shape, low=low, high=high, **kwargs) + else: + return torch.randint(0, 2, shape, **kwargs) + + # This function will return inputs and targets with shape: (shape[0]*9, shape[1]) + inputs = [] + targets = [] + for input_range_type, target_range_type in [ + ("small", "zeros"), + ("small", "ones"), + ("small", "random_binary"), + ("big", "zeros"), + ("big", "ones"), + ("big", "random_binary"), + ("random", "zeros"), + ("random", "ones"), + ("random", "random_binary"), + ]: + inputs.append(logit(generate_tensor_with_range_type(shape, input_range_type, **kwargs))) + targets.append(generate_tensor_with_range_type(shape, target_range_type, **kwargs)) + + return torch.cat(inputs), torch.cat(targets) + + @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0]) + @pytest.mark.parametrize("gamma", [0, 2]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + @pytest.mark.parametrize("seed", [0, 1]) + def test_correct_ratio(self, alpha, gamma, device, dtype, seed): + if device == "cpu" and dtype is torch.half: + pytest.skip("Currently torch.half is not fully supported on cpu") + # For testing the ratio with manual calculation, we require the reduction to be "none" + reduction = "none" + torch.random.manual_seed(seed) + inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) + focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction) + + assert torch.all( + focal_loss <= ce_loss + ), "focal loss must be less or equal to cross entropy loss with same input" + + loss_ratio = (focal_loss / ce_loss).squeeze() + prob = torch.sigmoid(inputs) + p_t = prob * targets + (1 - prob) * (1 - targets) + correct_ratio = (1.0 - p_t) ** gamma + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + correct_ratio = correct_ratio * alpha_t + + tol = 1e-3 if dtype is torch.half else 1e-5 + torch.testing.assert_close(correct_ratio, loss_ratio, rtol=tol, atol=tol) + + @pytest.mark.parametrize("reduction", ["mean", "sum"]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + @pytest.mark.parametrize("seed", [2, 3]) + def test_equal_ce_loss(self, reduction, device, dtype, seed): + if device == "cpu" and dtype is torch.half: + pytest.skip("Currently torch.half is 
not fully supported on cpu") + # focal loss should be equal ce_loss if alpha=-1 and gamma=0 + alpha = -1 + gamma = 0 + torch.random.manual_seed(seed) + inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) + inputs_fl = inputs.clone().requires_grad_() + targets_fl = targets.clone() + inputs_ce = inputs.clone().requires_grad_() + targets_ce = targets.clone() + focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction) + ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction) + + tol = 1e-3 if dtype is torch.half else 1e-5 + torch.testing.assert_close(focal_loss, ce_loss, rtol=tol, atol=tol) + + focal_loss.backward() + ce_loss.backward() + torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad, rtol=tol, atol=tol) + + @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0]) + @pytest.mark.parametrize("gamma", [0, 2]) + @pytest.mark.parametrize("reduction", ["none", "mean", "sum"]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + @pytest.mark.parametrize("seed", [4, 5]) + def test_jit(self, alpha, gamma, reduction, device, dtype, seed): + if device == "cpu" and dtype is torch.half: + pytest.skip("Currently torch.half is not fully supported on cpu") + script_fn = torch.jit.script(ops.sigmoid_focal_loss) + torch.random.manual_seed(seed) + inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) + focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) + if device == "cpu": + scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) + else: + with torch.jit.fuser("fuser2"): + # Use fuser2 to prevent a bug on fuser: https://github.com/pytorch/pytorch/issues/75476 + # We may remove this condition once the bug is resolved + scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) + + tol = 1e-3 if dtype is torch.half else 1e-5 + torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/test_ops.py b/test/test_ops.py index 96cfb630e8d..df5d397713c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -3,13 +3,12 @@ from abc import ABC, abstractmethod from functools import lru_cache from itertools import product -from typing import Callable, List, Tuple +from typing import Tuple import numpy as np import pytest import torch import torch.fx -import torch.nn.functional as F from common_utils import assert_equal, cpu_and_gpu, needs_cuda from PIL import Image from torch import nn, Tensor @@ -1021,7 +1020,7 @@ def test_convert_boxes_to_roi_format(self, box_sequence): assert_equal(ref_tensor, ops._utils.convert_boxes_to_roi_format(box_sequence)) -class TestBox: +class TestBoxConvert: def test_bbox_same(self): box_tensor = torch.tensor( [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float @@ -1051,7 +1050,7 @@ def test_bbox_xyxy_xywh(self): assert_equal(box_xyxy, box_tensor) def test_bbox_xyxy_cxcywh(self): - # Simple test convert boxes to xywh and back. Make sure they are same. + # Simple test convert boxes to cxcywh and back. Make sure they are same. # box_tensor is in x1 y1 x2 y2 format. 
box_tensor = torch.tensor( [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float @@ -1073,7 +1072,6 @@ def test_bbox_xywh_cxcywh(self): [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float ) - # This is wrong exp_cxcywh = torch.tensor( [[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float ) @@ -1113,277 +1111,48 @@ def test_bbox_convert_jit(self): torch.testing.assert_close(scripted_cxcywh, box_cxcywh, rtol=0.0, atol=TOLERANCE) -class BoxTestBase(ABC): - @abstractmethod - def _target_fn(self) -> Tuple[bool, Callable]: - pass +def area_check(box, expected, tolerance=1e-4): + out = ops.box_area(box) + torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance) - def _perform_box_operation(self, box: Tensor, run_as_script: bool = False) -> Tensor: - is_binary_fn = self._target_fn()[0] - target_fn = self._target_fn()[1] - box_operation = torch.jit.script(target_fn) if run_as_script else target_fn - return box_operation(box, box) if is_binary_fn else box_operation(box) - def _run_test(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None: - def assert_close(box: Tensor, expected: Tensor, tolerance): - out = self._perform_box_operation(box) - torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance) - - for dtype in dtypes: - actual_box = torch.tensor(test_input, dtype=dtype) - expected_box = torch.tensor(expected) - assert_close(actual_box, expected_box, tolerance) - - def _run_jit_test(self, test_input: List) -> None: - box_tensor = torch.tensor(test_input, dtype=torch.float) - expected = self._perform_box_operation(box_tensor, True) - scripted_area = self._perform_box_operation(box_tensor, True) - torch.testing.assert_close(scripted_area, expected, rtol=0.0, atol=1e-3) +class TestBoxArea: + @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64]) + def test_int_boxes(self, dtype): + # Check for int boxes + box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype) + expected = torch.tensor([10000, 0]) + area_check(box_tensor, expected) - -class TestBoxArea(BoxTestBase): - def _target_fn(self) -> Tuple[bool, Callable]: - return (False, ops.box_area) - - def _generate_int_input() -> List[List[int]]: - return [[0, 0, 100, 100], [0, 0, 0, 0]] - - def _generate_int_expected() -> List[int]: - return [10000, 0] - - def _generate_float_input(index: int) -> List[List[float]]: - return [ + # Check for float32 and float64 boxes + @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) + def test_float_boxes(self, dtype): + box_tensor = torch.tensor( [ [285.3538, 185.5758, 1193.5110, 851.4551], [285.1472, 188.7374, 1192.4984, 851.0669], [279.2440, 197.9812, 1189.4746, 849.2019], ], - [[285.25, 185.625, 1194.0, 851.5], [285.25, 188.75, 1192.0, 851.0], [279.25, 198.0, 1189.0, 849.0]], - ][index] - - def _generate_float_expected(index: int) -> List[float]: - return [[604723.0806, 600965.4666, 592761.0085], [605113.875, 600495.1875, 592247.25]][index] - - @pytest.mark.parametrize( - "test_input, dtypes, tolerance, expected", - [ - pytest.param( - _generate_int_input(), - [torch.int8, torch.int16, torch.int32, torch.int64], - 1e-4, - _generate_int_expected(), - ), - pytest.param(_generate_float_input(0), [torch.float32, torch.float64], 0.05, _generate_float_expected(0)), - pytest.param(_generate_float_input(1), [torch.float16], 1e-4, 
_generate_float_expected(1)), - ], - ) - def test_box_area(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None: - self._run_test(test_input, dtypes, tolerance, expected) - - def test_box_area_jit(self) -> None: - self._run_jit_test([[0, 0, 100, 100], [0, 0, 0, 0]]) - - -class TestBoxIou(BoxTestBase): - def _target_fn(self) -> Tuple[bool, Callable]: - return (True, ops.box_iou) - - def _generate_int_input() -> List[List[int]]: - return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] - - def _generate_int_expected() -> List[List[float]]: - return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] - - def _generate_float_input() -> List[List[float]]: - return [ - [285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019], - ] - - def _generate_float_expected() -> List[List[float]]: - return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] - - @pytest.mark.parametrize( - "test_input, dtypes, tolerance, expected", - [ - pytest.param( - _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() - ), - pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), - pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-4, _generate_float_expected()), - ], - ) - def test_iou(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None: - self._run_test(test_input, dtypes, tolerance, expected) - - def test_iou_jit(self) -> None: - self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) - - -class TestGenBoxIou(BoxTestBase): - def _target_fn(self) -> Tuple[bool, Callable]: - return (True, ops.generalized_box_iou) - - def _generate_int_input() -> List[List[int]]: - return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] - - def _generate_int_expected() -> List[List[float]]: - return [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]] - - def _generate_float_input() -> List[List[float]]: - return [ - [285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019], - ] - - def _generate_float_expected() -> List[List[float]]: - return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] - - @pytest.mark.parametrize( - "test_input, dtypes, tolerance, expected", - [ - pytest.param( - _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() - ), - pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), - pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.001, _generate_float_expected()), - ], - ) - def test_gen_iou(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None: - self._run_test(test_input, dtypes, tolerance, expected) - - def test_giou_jit(self) -> None: - self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) - - -class TestDistanceBoxIoU(BoxTestBase): - def _target_fn(self): - return (True, ops.distance_box_iou) - - def _generate_int_input(): - return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] - - def _generate_int_expected(): - return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] - - def _generate_float_input(): - return [ - [285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 
1189.4746, 849.2019], - ] - - def _generate_float_expected(): - return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] - - @pytest.mark.parametrize( - "test_input, dtypes, tolerance, expected", - [ - pytest.param( - _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() - ), - pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), - pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.001, _generate_float_expected()), - ], - ) - def test_distance_iou(self, test_input, dtypes, tolerance, expected): - self._run_test(test_input, dtypes, tolerance, expected) - - def test_distance_iou_jit(self): - self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) - - -@pytest.mark.parametrize("device", cpu_and_gpu()) -@pytest.mark.parametrize("dtype", [torch.float32, torch.half]) -def test_distance_iou_loss(dtype, device): - box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) - box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) - box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device) - box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device) - - box1s = torch.stack( - [box2, box2], - dim=0, - ) - box2s = torch.stack( - [box3, box4], - dim=0, - ) - - def assert_distance_iou_loss(box1, box2, expected_output, reduction="none"): - output = ops.distance_box_iou_loss(box1, box2, reduction=reduction) - # TODO: When passing the dtype, the torch.half fails as usual. - expected_output = torch.tensor(expected_output, device=device) - tol = 1e-5 if dtype != torch.half else 1e-3 - torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol) - - assert_distance_iou_loss(box1, box1, 0.0) - - assert_distance_iou_loss(box1, box2, 0.8125) - - assert_distance_iou_loss(box1, box3, 1.1923) - - assert_distance_iou_loss(box1, box4, 1.2500) - - assert_distance_iou_loss(box1s, box2s, 1.2250, reduction="mean") - assert_distance_iou_loss(box1s, box2s, 2.4500, reduction="sum") - - -@pytest.mark.parametrize("device", cpu_and_gpu()) -@pytest.mark.parametrize("dtype", [torch.float32, torch.half]) -def test_empty_distance_iou_inputs(dtype, device) -> None: - box1 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_() - box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_() - - loss = ops.distance_box_iou_loss(box1, box2, reduction="mean") - loss.backward() - - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(loss, torch.tensor(0.0, device=device), rtol=tol, atol=tol) - assert box1.grad is not None, "box1.grad should not be None after backward is called" - assert box2.grad is not None, "box2.grad should not be None after backward is called" - - loss = ops.distance_box_iou_loss(box1, box2, reduction="none") - assert loss.numel() == 0, "diou_loss for two empty box should be empty" - - -class TestCompleteBoxIou(BoxTestBase): - def _target_fn(self) -> Tuple[bool, Callable]: - return (True, ops.complete_box_iou) - - def _generate_int_input() -> List[List[int]]: - return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] - - def _generate_int_expected() -> List[List[float]]: - return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] - - def _generate_float_input() -> List[List[float]]: - return [ - [285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019], - ] - - def _generate_float_expected() -> List[List[float]]: - return 
[[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] - - @pytest.mark.parametrize( - "test_input, dtypes, tolerance, expected", - [ - pytest.param( - _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() - ), - pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.002, _generate_float_expected()), - pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.001, _generate_float_expected()), - ], - ) - def test_complete_iou(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None: - self._run_test(test_input, dtypes, tolerance, expected) + dtype=dtype, + ) + expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=torch.float64) + area_check(box_tensor, expected, tolerance=0.05) - def test_ciou_jit(self) -> None: - self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) + def test_float16_box(self): + # Check for float16 box + box_tensor = torch.tensor( + [[285.25, 185.625, 1194.0, 851.5], [285.25, 188.75, 1192.0, 851.0], [279.25, 198.0, 1189.0, 849.0]], + dtype=torch.float16, + ) + expected = torch.tensor([605113.875, 600495.1875, 592247.25]) + area_check(box_tensor, expected) + + def test_box_area_jit(self): + box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float) + expected = ops.box_area(box_tensor) + scripted_fn = torch.jit.script(ops.box_area) + scripted_area = scripted_fn(box_tensor) + torch.testing.assert_close(scripted_area, expected, rtol=0.0, atol=1e-3) class TestMasksToBoxes: @@ -1579,227 +1348,5 @@ def test_is_leaf_node(self, dim, p, block_size, inplace): assert len(graph_node_names[0]) == 1 + op_obj.n_inputs -class TestFocalLoss: - def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs): - def logit(p: Tensor) -> Tensor: - return torch.log(p / (1 - p)) - - def generate_tensor_with_range_type(shape, range_type, **kwargs): - if range_type != "random_binary": - low, high = { - "small": (0.0, 0.2), - "big": (0.8, 1.0), - "zeros": (0.0, 0.0), - "ones": (1.0, 1.0), - "random": (0.0, 1.0), - }[range_type] - return torch.testing.make_tensor(shape, low=low, high=high, **kwargs) - else: - return torch.randint(0, 2, shape, **kwargs) - - # This function will return inputs and targets with shape: (shape[0]*9, shape[1]) - inputs = [] - targets = [] - for input_range_type, target_range_type in [ - ("small", "zeros"), - ("small", "ones"), - ("small", "random_binary"), - ("big", "zeros"), - ("big", "ones"), - ("big", "random_binary"), - ("random", "zeros"), - ("random", "ones"), - ("random", "random_binary"), - ]: - inputs.append(logit(generate_tensor_with_range_type(shape, input_range_type, **kwargs))) - targets.append(generate_tensor_with_range_type(shape, target_range_type, **kwargs)) - - return torch.cat(inputs), torch.cat(targets) - - @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0]) - @pytest.mark.parametrize("gamma", [0, 2]) - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - @pytest.mark.parametrize("seed", [0, 1]) - def test_correct_ratio(self, alpha, gamma, device, dtype, seed) -> None: - if device == "cpu" and dtype is torch.half: - pytest.skip("Currently torch.half is not fully supported on cpu") - # For testing the ratio with manual calculation, we require the reduction to be "none" - reduction = "none" - torch.random.manual_seed(seed) - inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, 
device=device) - focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) - ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction) - - assert torch.all( - focal_loss <= ce_loss - ), "focal loss must be less or equal to cross entropy loss with same input" - - loss_ratio = (focal_loss / ce_loss).squeeze() - prob = torch.sigmoid(inputs) - p_t = prob * targets + (1 - prob) * (1 - targets) - correct_ratio = (1.0 - p_t) ** gamma - if alpha >= 0: - alpha_t = alpha * targets + (1 - alpha) * (1 - targets) - correct_ratio = correct_ratio * alpha_t - - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(correct_ratio, loss_ratio, rtol=tol, atol=tol) - - @pytest.mark.parametrize("reduction", ["mean", "sum"]) - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - @pytest.mark.parametrize("seed", [2, 3]) - def test_equal_ce_loss(self, reduction, device, dtype, seed) -> None: - if device == "cpu" and dtype is torch.half: - pytest.skip("Currently torch.half is not fully supported on cpu") - # focal loss should be equal ce_loss if alpha=-1 and gamma=0 - alpha = -1 - gamma = 0 - torch.random.manual_seed(seed) - inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) - inputs_fl = inputs.clone().requires_grad_() - targets_fl = targets.clone() - inputs_ce = inputs.clone().requires_grad_() - targets_ce = targets.clone() - focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction) - ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction) - - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(focal_loss, ce_loss, rtol=tol, atol=tol) - - focal_loss.backward() - ce_loss.backward() - torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad, rtol=tol, atol=tol) - - @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0]) - @pytest.mark.parametrize("gamma", [0, 2]) - @pytest.mark.parametrize("reduction", ["none", "mean", "sum"]) - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - @pytest.mark.parametrize("seed", [4, 5]) - def test_jit(self, alpha, gamma, reduction, device, dtype, seed) -> None: - if device == "cpu" and dtype is torch.half: - pytest.skip("Currently torch.half is not fully supported on cpu") - script_fn = torch.jit.script(ops.sigmoid_focal_loss) - torch.random.manual_seed(seed) - inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) - focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) - if device == "cpu": - scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) - else: - with torch.jit.fuser("fuser2"): - # Use fuser2 to prevent a bug on fuser: https://github.com/pytorch/pytorch/issues/75476 - # We may remove this condition once the bug is resolved - scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) - - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol) - - -class TestGeneralizedBoxIouLoss: - # We refer to original test: https://github.com/facebookresearch/fvcore/blob/main/tests/test_giou_loss.py - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) 
- def test_giou_loss(self, dtype, device) -> None: - box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) - box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) - box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device) - box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device) - - box1s = torch.stack([box2, box2], dim=0) - box2s = torch.stack([box3, box4], dim=0) - - def assert_giou_loss(box1, box2, expected_loss, reduction="none"): - tol = 1e-3 if dtype is torch.half else 1e-5 - computed_loss = ops.generalized_box_iou_loss(box1, box2, reduction=reduction) - expected_loss = torch.tensor(expected_loss, device=device) - torch.testing.assert_close(computed_loss, expected_loss, rtol=tol, atol=tol) - - # Identical boxes should have loss of 0 - assert_giou_loss(box1, box1, 0.0) - - # quarter size box inside other box = IoU of 0.25 - assert_giou_loss(box1, box2, 0.75) - - # Two side by side boxes, area=union - # IoU=0 and GIoU=0 (loss 1.0) - assert_giou_loss(box2, box3, 1.0) - - # Two diagonally adjacent boxes, area=2*union - # IoU=0 and GIoU=-0.5 (loss 1.5) - assert_giou_loss(box2, box4, 1.5) - - # Test batched loss and reductions - assert_giou_loss(box1s, box2s, 2.5, reduction="sum") - assert_giou_loss(box1s, box2s, 1.25, reduction="mean") - - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - def test_empty_inputs(self, dtype, device) -> None: - box1 = torch.randn([0, 4], dtype=dtype).requires_grad_() - box2 = torch.randn([0, 4], dtype=dtype).requires_grad_() - - loss = ops.generalized_box_iou_loss(box1, box2, reduction="mean") - loss.backward() - - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(loss, torch.tensor(0.0), rtol=tol, atol=tol) - assert box1.grad is not None, "box1.grad should not be None after backward is called" - assert box2.grad is not None, "box2.grad should not be None after backward is called" - - loss = ops.generalized_box_iou_loss(box1, box2, reduction="none") - assert loss.numel() == 0, "giou_loss for two empty box should be empty" - - -class TestCIOULoss: - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - @pytest.mark.parametrize("device", cpu_and_gpu()) - def test_ciou_loss(self, dtype, device): - box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) - box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) - box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device) - box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device) - - box1s = torch.stack([box2, box2], dim=0) - box2s = torch.stack([box3, box4], dim=0) - - def assert_ciou_loss(box1, box2, expected_output, reduction="none"): - - output = ops.complete_box_iou_loss(box1, box2, reduction=reduction) - # TODO: When passing the dtype, the torch.half test doesn't pass... 
- expected_output = torch.tensor(expected_output, device=device) - tol = 1e-5 if dtype != torch.half else 1e-3 - torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol) - - assert_ciou_loss(box1, box1, 0.0) - - assert_ciou_loss(box1, box2, 0.8125) - - assert_ciou_loss(box1, box3, 1.1923) - - assert_ciou_loss(box1, box4, 1.2500) - - assert_ciou_loss(box1s, box2s, 1.2250, reduction="mean") - assert_ciou_loss(box1s, box2s, 2.4500, reduction="sum") - - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - def test_empty_inputs(self, dtype, device) -> None: - box1 = torch.randn([0, 4], dtype=dtype).requires_grad_() - box2 = torch.randn([0, 4], dtype=dtype).requires_grad_() - - loss = ops.complete_box_iou_loss(box1, box2, reduction="mean") - loss.backward() - - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(loss, torch.tensor(0.0), rtol=tol, atol=tol) - assert box1.grad is not None, "box1.grad should not be None after backward is called" - assert box2.grad is not None, "box2.grad should not be None after backward is called" - - loss = ops.complete_box_iou_loss(box1, box2, reduction="none") - assert loss.numel() == 0, "ciou_loss for two empty box should be empty" - - if __name__ == "__main__": pytest.main([__file__]) From 7f788f119d2479d50011daf96435ad626359e8ce Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Mon, 6 Jun 2022 17:00:10 +0530 Subject: [PATCH 02/14] Remove tol, fix comments --- test/test_losses.py | 22 +++++++++++----------- test/test_ops.py | 3 --- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/test/test_losses.py b/test/test_losses.py index c2d7f9452ef..34dd6a93296 100644 --- a/test/test_losses.py +++ b/test/test_losses.py @@ -18,10 +18,10 @@ def get_boxes(dtype, device): def assert_iou_loss(iou_fn, box1, box2, expected_loss, dtype, device, reduction="none"): - tol = 1e-3 if dtype is torch.half else 1e-5 + # tol = 1e-3 if dtype is torch.half else 1e-5 computed_loss = iou_fn(box1, box2, reduction=reduction) expected_loss = torch.tensor(expected_loss, device=device) - torch.testing.assert_close(computed_loss, expected_loss, rtol=tol, atol=tol) + torch.testing.assert_close(computed_loss, expected_loss) def assert_empty_loss(iou_fn, dtype, device): @@ -29,8 +29,8 @@ def assert_empty_loss(iou_fn, dtype, device): box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_() loss = iou_fn(box1, box2, reduction="mean") loss.backward() - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(loss, torch.tensor(0.0, device=device), rtol=tol, atol=tol) + # tol = 1e-3 if dtype is torch.half else 1e-5 + torch.testing.assert_close(loss, torch.tensor(0.0, device=device)) assert box1.grad is not None, "box1.grad should not be None after backward is called" assert box2.grad is not None, "box2.grad should not be None after backward is called" loss = iou_fn(box1, box2, reduction="none") @@ -171,8 +171,8 @@ def test_correct_ratio(self, alpha, gamma, device, dtype, seed): alpha_t = alpha * targets + (1 - alpha) * (1 - targets) correct_ratio = correct_ratio * alpha_t - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(correct_ratio, loss_ratio, rtol=tol, atol=tol) + # tol = 1e-3 if dtype is torch.half else 1e-5 + torch.testing.assert_close(correct_ratio, loss_ratio) @pytest.mark.parametrize("reduction", ["mean", "sum"]) @pytest.mark.parametrize("device", cpu_and_gpu()) @@ -193,12 +193,12 @@ def test_equal_ce_loss(self, reduction, 
device, dtype, seed): focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction) ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction) - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(focal_loss, ce_loss, rtol=tol, atol=tol) + # tol = 1e-3 if dtype is torch.half else 1e-5 + torch.testing.assert_close(focal_loss, ce_loss) focal_loss.backward() ce_loss.backward() - torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad, rtol=tol, atol=tol) + torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad) @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0]) @pytest.mark.parametrize("gamma", [0, 2]) @@ -221,8 +221,8 @@ def test_jit(self, alpha, gamma, reduction, device, dtype, seed): # We may remove this condition once the bug is resolved scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol) + # tol = 1e-3 if dtype is torch.half else 1e-5 + torch.testing.assert_close(focal_loss, scripted_focal_loss) if __name__ == "__main__": diff --git a/test/test_ops.py b/test/test_ops.py index df5d397713c..ae7810f0591 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1119,12 +1119,10 @@ def area_check(box, expected, tolerance=1e-4): class TestBoxArea: @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64]) def test_int_boxes(self, dtype): - # Check for int boxes box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype) expected = torch.tensor([10000, 0]) area_check(box_tensor, expected) - # Check for float32 and float64 boxes @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) def test_float_boxes(self, dtype): box_tensor = torch.tensor( @@ -1139,7 +1137,6 @@ def test_float_boxes(self, dtype): area_check(box_tensor, expected, tolerance=0.05) def test_float16_box(self): - # Check for float16 box box_tensor = torch.tensor( [[285.25, 185.625, 1194.0, 851.5], [285.25, 188.75, 1192.0, 851.0], [279.25, 198.0, 1189.0, 849.0]], dtype=torch.float16, From 6ab501adb60e5a0880ad543bf7fce3d6f35d8aca Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Mon, 6 Jun 2022 18:16:32 +0530 Subject: [PATCH 03/14] Add tolerance only where necessary --- test/test_losses.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/test/test_losses.py b/test/test_losses.py index 34dd6a93296..6c0a1cf711a 100644 --- a/test/test_losses.py +++ b/test/test_losses.py @@ -18,7 +18,6 @@ def get_boxes(dtype, device): def assert_iou_loss(iou_fn, box1, box2, expected_loss, dtype, device, reduction="none"): - # tol = 1e-3 if dtype is torch.half else 1e-5 computed_loss = iou_fn(box1, box2, reduction=reduction) expected_loss = torch.tensor(expected_loss, device=device) torch.testing.assert_close(computed_loss, expected_loss) @@ -29,7 +28,6 @@ def assert_empty_loss(iou_fn, dtype, device): box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_() loss = iou_fn(box1, box2, reduction="mean") loss.backward() - # tol = 1e-3 if dtype is torch.half else 1e-5 torch.testing.assert_close(loss, torch.tensor(0.0, device=device)) assert box1.grad is not None, "box1.grad should not be None after backward is called" assert box2.grad is not None, "box2.grad should not be None after backward is called" @@ -171,7 +169,6 @@ def test_correct_ratio(self, alpha, gamma, device, dtype, seed): alpha_t = 
alpha * targets + (1 - alpha) * (1 - targets) correct_ratio = correct_ratio * alpha_t - # tol = 1e-3 if dtype is torch.half else 1e-5 torch.testing.assert_close(correct_ratio, loss_ratio) @pytest.mark.parametrize("reduction", ["mean", "sum"]) @@ -193,8 +190,8 @@ def test_equal_ce_loss(self, reduction, device, dtype, seed): focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction) ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction) - # tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(focal_loss, ce_loss) + tol = 1e-3 if dtype is torch.half else 1e-5 + torch.testing.assert_close(focal_loss, ce_loss, atol=tol, rtol=tol) focal_loss.backward() ce_loss.backward() @@ -221,8 +218,8 @@ def test_jit(self, alpha, gamma, reduction, device, dtype, seed): # We may remove this condition once the bug is resolved scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) - # tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(focal_loss, scripted_focal_loss) + tol = 1e-3 if dtype is torch.half else 1e-5 + torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol) if __name__ == "__main__": From b83d745c7bc67c71827055e0aa2954cfc353d363 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Mon, 6 Jun 2022 18:53:20 +0530 Subject: [PATCH 04/14] Add tolerance only where necessary --- test/test_losses.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_losses.py b/test/test_losses.py index 6c0a1cf711a..c7659e05119 100644 --- a/test/test_losses.py +++ b/test/test_losses.py @@ -169,7 +169,8 @@ def test_correct_ratio(self, alpha, gamma, device, dtype, seed): alpha_t = alpha * targets + (1 - alpha) * (1 - targets) correct_ratio = correct_ratio * alpha_t - torch.testing.assert_close(correct_ratio, loss_ratio) + tol = 1e-3 if dtype is torch.half else 1e-5 + torch.testing.assert_close(correct_ratio, loss_ratio, atol=tol, rtol=tol) @pytest.mark.parametrize("reduction", ["mean", "sum"]) @pytest.mark.parametrize("device", cpu_and_gpu()) From 7e49682864cbbc421cb2668a22a304fb05b4d020 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Mon, 6 Jun 2022 18:53:43 +0530 Subject: [PATCH 05/14] Add tolerance only where necessary --- test/test_losses.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/test_losses.py b/test/test_losses.py index c7659e05119..87398c44bfa 100644 --- a/test/test_losses.py +++ b/test/test_losses.py @@ -191,8 +191,7 @@ def test_equal_ce_loss(self, reduction, device, dtype, seed): focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction) ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction) - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(focal_loss, ce_loss, atol=tol, rtol=tol) + torch.testing.assert_close(focal_loss, ce_loss) focal_loss.backward() ce_loss.backward() From 485d1fc2b612746e1eb88aaf09053fd54b9c603e Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Tue, 7 Jun 2022 00:00:39 +0530 Subject: [PATCH 06/14] Refactor to adapt suggestions --- test/test_ious.py | 93 +++++++++++++++------------------------------ test/test_losses.py | 4 +- test/test_ops.py | 4 +- 3 files changed, 35 insertions(+), 66 deletions(-) diff --git a/test/test_ious.py b/test/test_ious.py index 4e87d64b477..84f2521e360 100644 --- a/test/test_ious.py +++ b/test/test_ious.py @@ -7,7 +7,7 @@ from 
torchvision import ops -class IouTestBase: +class TestIouBase: @staticmethod def _run_test(target_fn: Callable, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List): def assert_close(box: Tensor, expected: Tensor, tolerance): @@ -28,40 +28,24 @@ def _run_jit_test(target_fn: Callable, test_input: List): torch.testing.assert_close(scripted_out, expected, rtol=0.0, atol=1e-3) -def _generate_int_input(): - return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] +IOU_INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] +IOU_FLOAT_BOXES = [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], +] -def _generate_float_input(): - return [ - [285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019], - ] - - -class TestBoxIou(IouTestBase): - def _generate_int_expected(): - return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] - - def _generate_float_input(): - return [ - [285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019], - ] - - def _generate_float_expected(): - return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] +class TestBoxIou(TestIouBase): + generate_int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] @pytest.mark.parametrize( "test_input, dtypes, tolerance, expected", [ - pytest.param( - _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() - ), - pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), - pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), + pytest.param(IOU_INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, generate_int_expected), + pytest.param(IOU_FLOAT_BOXES, [torch.float16], 0.002, float_expected), + pytest.param(IOU_FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), ], ) def test_iou(self, test_input, dtypes, tolerance, expected): @@ -71,21 +55,16 @@ def test_iou_jit(self): self._run_jit_test(ops.box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) -class TestGenBoxIou(IouTestBase): - def _generate_int_expected(): - return [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]] - - def _generate_float_expected(): - return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] +class TestGeneralizedBoxIou(TestIouBase): + int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]] + float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] @pytest.mark.parametrize( "test_input, dtypes, tolerance, expected", [ - pytest.param( - _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() - ), - pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), - pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), + pytest.param(IOU_INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), + pytest.param(IOU_FLOAT_BOXES, [torch.float16], 0.002, float_expected), + pytest.param(IOU_FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), ], ) def test_iou(self, test_input, dtypes, 
tolerance, expected): @@ -95,21 +74,16 @@ def test_iou_jit(self): self._run_jit_test(ops.generalized_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) -class TestDistanceBoxIoU(IouTestBase): - def _generate_int_expected(): - return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] - - def _generate_float_expected(): - return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] +class TestDistanceBoxIoU(TestIouBase): + int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] @pytest.mark.parametrize( "test_input, dtypes, tolerance, expected", [ - pytest.param( - _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() - ), - pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), - pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), + pytest.param(IOU_INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), + pytest.param(IOU_FLOAT_BOXES, [torch.float16], 0.002, float_expected), + pytest.param(IOU_FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), ], ) def test_iou(self, test_input, dtypes, tolerance, expected): @@ -119,21 +93,16 @@ def test_iou_jit(self): self._run_jit_test(ops.distance_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) -class TestCompleteBoxIou(IouTestBase): - def _generate_int_expected(): - return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] - - def _generate_float_expected(): - return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] +class TestCompleteBoxIou(TestIouBase): + int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] @pytest.mark.parametrize( "test_input, dtypes, tolerance, expected", [ - pytest.param( - _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() - ), - pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), - pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), + pytest.param(IOU_INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), + pytest.param(IOU_FLOAT_BOXES, [torch.float16], 0.002, float_expected), + pytest.param(IOU_FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), ], ) def test_iou(self, test_input, dtypes, tolerance, expected): diff --git a/test/test_losses.py b/test/test_losses.py index 87398c44bfa..098a56c518e 100644 --- a/test/test_losses.py +++ b/test/test_losses.py @@ -67,7 +67,7 @@ def test_empty_inputs(self, dtype, device): assert_empty_loss(ops.generalized_box_iou_loss, dtype, device) -class TestCIOULoss: +class TestCompleteBoxIouLoss: @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) @pytest.mark.parametrize("device", cpu_and_gpu()) def test_ciou_loss(self, dtype, device): @@ -86,7 +86,7 @@ def test_empty_inputs(self, dtype, device): assert_empty_loss(ops.complete_box_iou_loss, dtype, device) -class TestDIouLoss: +class TestDistanceBoxIouLoss: @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) def test_distance_iou_loss(self, dtype, device): diff --git a/test/test_ops.py b/test/test_ops.py index ae7810f0591..0f1811cd901 100644 --- a/test/test_ops.py +++ 
b/test/test_ops.py
@@ -1120,7 +1120,7 @@ class TestBoxArea:
     @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64])
     def test_int_boxes(self, dtype):
         box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype)
-        expected = torch.tensor([10000, 0])
+        expected = torch.tensor([10000, 0], dtype=torch.int32)
         area_check(box_tensor, expected)
 
     @pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
@@ -1157,7 +1157,7 @@ def test_masks_box(self):
         def masks_box_check(masks, expected, tolerance=1e-4):
             out = ops.masks_to_boxes(masks)
             assert out.dtype == torch.float
-            torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)
+            torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=True, atol=tolerance)
 
         # Check for int type boxes.
         def _get_image():

From 5c8f4fb7b07c03c0000f264ee6faea7779e06d78 Mon Sep 17 00:00:00 2001
From: Aditya Oke
Date: Tue, 7 Jun 2022 20:30:27 +0530
Subject: [PATCH 07/14] Refactor and add nits

---
 test/test_ious.py | 18 +++++++-----------
 1 file changed, 7 insertions(+), 11 deletions(-)

diff --git a/test/test_ious.py b/test/test_ious.py
index 84f2521e360..61a632c0374 100644
--- a/test/test_ious.py
+++ b/test/test_ious.py
@@ -2,22 +2,18 @@
 
 import pytest
 import torch
-import torch.fx
-from torch import Tensor
 from torchvision import ops
 
 
 class TestIouBase:
     @staticmethod
     def _run_test(target_fn: Callable, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List):
-        def assert_close(box: Tensor, expected: Tensor, tolerance):
-            out = target_fn(box, box)
-            torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)
-
         for dtype in dtypes:
             actual_box = torch.tensor(test_input, dtype=dtype)
             expected_box = torch.tensor(expected)
-            assert_close(actual_box, expected_box, tolerance)
+            out = target_fn(actual_box, actual_box)
+            torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=tolerance)
+            # assert_close(actual_box, expected_box, tolerance)
 
     @staticmethod
     def _run_jit_test(target_fn: Callable, test_input: List):
@@ -52,7 +48,7 @@ def test_iou(self, test_input, dtypes, tolerance, expected):
         self._run_test(ops.box_iou, test_input, dtypes, tolerance, expected)
 
     def test_iou_jit(self):
-        self._run_jit_test(ops.box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])
+        self._run_jit_test(ops.box_iou, IOU_INT_BOXES)
 
 
 class TestGeneralizedBoxIou(TestIouBase):
@@ -71,7 +67,7 @@ def test_iou(self, test_input, dtypes, tolerance, expected):
         self._run_test(ops.generalized_box_iou, test_input, dtypes, tolerance, expected)
 
     def test_iou_jit(self):
-        self._run_jit_test(ops.generalized_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])
+        self._run_jit_test(ops.generalized_box_iou, IOU_INT_BOXES)
 
 
 class TestDistanceBoxIoU(TestIouBase):
@@ -90,7 +86,7 @@ def test_iou(self, test_input, dtypes, tolerance, expected):
         self._run_test(ops.distance_box_iou, test_input, dtypes, tolerance, expected)
 
     def test_iou_jit(self):
-        self._run_jit_test(ops.distance_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])
+        self._run_jit_test(ops.distance_box_iou, IOU_INT_BOXES)
 
 
 class TestCompleteBoxIou(TestIouBase):
@@ -109,7 +105,7 @@ def test_iou(self, test_input, dtypes, tolerance, expected):
         self._run_test(ops.complete_box_iou, test_input, dtypes, tolerance, expected)
 
     def test_iou_jit(self):
-        self._run_jit_test(ops.complete_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])
+        self._run_jit_test(ops.complete_box_iou, IOU_INT_BOXES)
 
 
 if __name__ == "__main__":
     pytest.main([__file__])

From 1ed639fc2f4ffab5b7ac773a0a4fe3403a25bec2 Mon Sep 17 00:00:00 2001
From: Aditya Oke
Date: Tue, 7 Jun 2022 20:32:23 +0530
Subject: [PATCH 08/14] Refactor box area

---
 test/test_ops.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index 0f1811cd901..c52b6751ef8 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1111,17 +1111,16 @@ def test_bbox_convert_jit(self):
         torch.testing.assert_close(scripted_cxcywh, box_cxcywh, rtol=0.0, atol=TOLERANCE)
 
 
-def area_check(box, expected, tolerance=1e-4):
-    out = ops.box_area(box)
-    torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)
-
-
 class TestBoxArea:
+    def area_check(self, box, expected, tolerance=1e-4):
+        out = ops.box_area(box)
+        torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)
+
     @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64])
     def test_int_boxes(self, dtype):
         box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype)
         expected = torch.tensor([10000, 0], dtype=torch.int32)
-        area_check(box_tensor, expected)
+        self.area_check(box_tensor, expected)
 
     @pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
     def test_float_boxes(self, dtype):
@@ -1134,7 +1133,7 @@ def test_float_boxes(self, dtype):
             dtype=dtype,
         )
         expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=torch.float64)
-        area_check(box_tensor, expected, tolerance=0.05)
+        self.area_check(box_tensor, expected, tolerance=0.05)
 
     def test_float16_box(self):
         box_tensor = torch.tensor(
@@ -1142,7 +1141,7 @@ def test_float16_box(self):
             dtype=torch.float16,
         )
         expected = torch.tensor([605113.875, 600495.1875, 592247.25])
-        area_check(box_tensor, expected)
+        self.area_check(box_tensor, expected)
 
     def test_box_area_jit(self):
         box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)

From fd96c07ca3b136a1874496bcd75730ed59df6080 Mon Sep 17 00:00:00 2001
From: Aditya Oke
Date: Tue, 7 Jun 2022 20:50:06 +0530
Subject: [PATCH 09/14] Refactor to one file

---
 test/test_ious.py   | 112 ---------------
 test/test_losses.py | 226 -------------------------------
 test/test_ops.py    | 323 +++++++++++++++++++++++++++++++++++++++++++-
 3 files changed, 322 insertions(+), 339 deletions(-)
 delete mode 100644 test/test_ious.py
 delete mode 100644 test/test_losses.py

diff --git a/test/test_ious.py b/test/test_ious.py
deleted file mode 100644
index 61a632c0374..00000000000
--- a/test/test_ious.py
+++ /dev/null
@@ -1,112 +0,0 @@
-from typing import List, Callable
-
-import pytest
-import torch
-from torchvision import ops
-
-
-class TestIouBase:
-    @staticmethod
-    def _run_test(target_fn: Callable, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List):
-        for dtype in dtypes:
-            actual_box = torch.tensor(test_input, dtype=dtype)
-            expected_box = torch.tensor(expected)
-            out = target_fn(actual_box, actual_box)
-            torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=tolerance)
-            # assert_close(actual_box, expected_box, tolerance)
-
-    @staticmethod
-    def _run_jit_test(target_fn: Callable, test_input: List):
-        box_tensor = torch.tensor(test_input, dtype=torch.float)
-        expected = target_fn(box_tensor, box_tensor)
-        scripted_fn = torch.jit.script(target_fn)
-        scripted_out = scripted_fn(box_tensor, box_tensor)
-        torch.testing.assert_close(scripted_out, expected, rtol=0.0, atol=1e-3)
-
-
-IOU_INT_BOXES =
[[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] -IOU_FLOAT_BOXES = [ - [285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019], -] - - -class TestBoxIou(TestIouBase): - generate_int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] - float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] - - @pytest.mark.parametrize( - "test_input, dtypes, tolerance, expected", - [ - pytest.param(IOU_INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, generate_int_expected), - pytest.param(IOU_FLOAT_BOXES, [torch.float16], 0.002, float_expected), - pytest.param(IOU_FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), - ], - ) - def test_iou(self, test_input, dtypes, tolerance, expected): - self._run_test(ops.box_iou, test_input, dtypes, tolerance, expected) - - def test_iou_jit(self): - self._run_jit_test(ops.box_iou, IOU_INT_BOXES) - - -class TestGeneralizedBoxIou(TestIouBase): - int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]] - float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] - - @pytest.mark.parametrize( - "test_input, dtypes, tolerance, expected", - [ - pytest.param(IOU_INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), - pytest.param(IOU_FLOAT_BOXES, [torch.float16], 0.002, float_expected), - pytest.param(IOU_FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), - ], - ) - def test_iou(self, test_input, dtypes, tolerance, expected): - self._run_test(ops.generalized_box_iou, test_input, dtypes, tolerance, expected) - - def test_iou_jit(self): - self._run_jit_test(ops.generalized_box_iou, IOU_INT_BOXES) - - -class TestDistanceBoxIoU(TestIouBase): - int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] - float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] - - @pytest.mark.parametrize( - "test_input, dtypes, tolerance, expected", - [ - pytest.param(IOU_INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), - pytest.param(IOU_FLOAT_BOXES, [torch.float16], 0.002, float_expected), - pytest.param(IOU_FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), - ], - ) - def test_iou(self, test_input, dtypes, tolerance, expected): - self._run_test(ops.distance_box_iou, test_input, dtypes, tolerance, expected) - - def test_iou_jit(self): - self._run_jit_test(ops.distance_box_iou, IOU_INT_BOXES) - - -class TestCompleteBoxIou(TestIouBase): - int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] - float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] - - @pytest.mark.parametrize( - "test_input, dtypes, tolerance, expected", - [ - pytest.param(IOU_INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), - pytest.param(IOU_FLOAT_BOXES, [torch.float16], 0.002, float_expected), - pytest.param(IOU_FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), - ], - ) - def test_iou(self, test_input, dtypes, tolerance, expected): - self._run_test(ops.complete_box_iou, test_input, dtypes, tolerance, expected) - - def test_iou_jit(self): - self._run_jit_test(ops.complete_box_iou, IOU_INT_BOXES) - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/test/test_losses.py b/test/test_losses.py deleted file mode 100644 index 098a56c518e..00000000000 --- a/test/test_losses.py +++ /dev/null @@ -1,226 +0,0 @@ 
-import pytest -import torch -import torch.nn.functional as F -from common_utils import cpu_and_gpu -from torchvision import ops - - -def get_boxes(dtype, device): - box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) - box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) - box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device) - box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device) - - box1s = torch.stack([box2, box2], dim=0) - box2s = torch.stack([box3, box4], dim=0) - - return box1, box2, box3, box4, box1s, box2s - - -def assert_iou_loss(iou_fn, box1, box2, expected_loss, dtype, device, reduction="none"): - computed_loss = iou_fn(box1, box2, reduction=reduction) - expected_loss = torch.tensor(expected_loss, device=device) - torch.testing.assert_close(computed_loss, expected_loss) - - -def assert_empty_loss(iou_fn, dtype, device): - box1 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_() - box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_() - loss = iou_fn(box1, box2, reduction="mean") - loss.backward() - torch.testing.assert_close(loss, torch.tensor(0.0, device=device)) - assert box1.grad is not None, "box1.grad should not be None after backward is called" - assert box2.grad is not None, "box2.grad should not be None after backward is called" - loss = iou_fn(box1, box2, reduction="none") - assert loss.numel() == 0, f"{str(iou_fn)} for two empty box should be empty" - - -class TestGeneralizedBoxIouLoss: - # We refer to original test: https://github.com/facebookresearch/fvcore/blob/main/tests/test_giou_loss.py - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - def test_giou_loss(self, dtype, device): - - box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) - - # Identical boxes should have loss of 0 - assert_iou_loss(ops.generalized_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device) - - # quarter size box inside other box = IoU of 0.25 - assert_iou_loss(ops.generalized_box_iou_loss, box1, box2, 0.75, dtype=dtype, device=device) - - # Two side by side boxes, area=union - # IoU=0 and GIoU=0 (loss 1.0) - assert_iou_loss(ops.generalized_box_iou_loss, box2, box3, 1.0, dtype=dtype, device=device) - - # Two diagonally adjacent boxes, area=2*union - # IoU=0 and GIoU=-0.5 (loss 1.5) - assert_iou_loss(ops.generalized_box_iou_loss, box2, box4, 1.5, dtype=dtype, device=device) - - # Test batched loss and reductions - assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 2.5, dtype=dtype, device=device, reduction="sum") - assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 1.25, dtype=dtype, device=device, reduction="mean") - - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - def test_empty_inputs(self, dtype, device): - assert_empty_loss(ops.generalized_box_iou_loss, dtype, device) - - -class TestCompleteBoxIouLoss: - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - @pytest.mark.parametrize("device", cpu_and_gpu()) - def test_ciou_loss(self, dtype, device): - box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) - - assert_iou_loss(ops.complete_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device) - assert_iou_loss(ops.complete_box_iou_loss, box1, box2, 0.8125, dtype=dtype, device=device) - assert_iou_loss(ops.complete_box_iou_loss, box1, box3, 1.1923, dtype=dtype, device=device) - assert_iou_loss(ops.complete_box_iou_loss, box1, 
box4, 1.2500, dtype=dtype, device=device) - assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 1.2250, dtype=dtype, device=device, reduction="mean") - assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 2.4500, dtype=dtype, device=device, reduction="sum") - - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - def test_empty_inputs(self, dtype, device): - assert_empty_loss(ops.complete_box_iou_loss, dtype, device) - - -class TestDistanceBoxIouLoss: - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - def test_distance_iou_loss(self, dtype, device): - box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) - - assert_iou_loss(ops.distance_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device) - assert_iou_loss(ops.distance_box_iou_loss, box1, box2, 0.8125, dtype=dtype, device=device) - assert_iou_loss(ops.distance_box_iou_loss, box1, box3, 1.1923, dtype=dtype, device=device) - assert_iou_loss(ops.distance_box_iou_loss, box1, box4, 1.2500, dtype=dtype, device=device) - assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 1.2250, dtype=dtype, device=device, reduction="mean") - assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 2.4500, dtype=dtype, device=device, reduction="sum") - - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - def test_empty_distance_iou_inputs(self, dtype, device): - assert_empty_loss(ops.distance_box_iou_loss, dtype, device) - - -class TestFocalLoss: - def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs): - def logit(p): - return torch.log(p / (1 - p)) - - def generate_tensor_with_range_type(shape, range_type, **kwargs): - if range_type != "random_binary": - low, high = { - "small": (0.0, 0.2), - "big": (0.8, 1.0), - "zeros": (0.0, 0.0), - "ones": (1.0, 1.0), - "random": (0.0, 1.0), - }[range_type] - return torch.testing.make_tensor(shape, low=low, high=high, **kwargs) - else: - return torch.randint(0, 2, shape, **kwargs) - - # This function will return inputs and targets with shape: (shape[0]*9, shape[1]) - inputs = [] - targets = [] - for input_range_type, target_range_type in [ - ("small", "zeros"), - ("small", "ones"), - ("small", "random_binary"), - ("big", "zeros"), - ("big", "ones"), - ("big", "random_binary"), - ("random", "zeros"), - ("random", "ones"), - ("random", "random_binary"), - ]: - inputs.append(logit(generate_tensor_with_range_type(shape, input_range_type, **kwargs))) - targets.append(generate_tensor_with_range_type(shape, target_range_type, **kwargs)) - - return torch.cat(inputs), torch.cat(targets) - - @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0]) - @pytest.mark.parametrize("gamma", [0, 2]) - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - @pytest.mark.parametrize("seed", [0, 1]) - def test_correct_ratio(self, alpha, gamma, device, dtype, seed): - if device == "cpu" and dtype is torch.half: - pytest.skip("Currently torch.half is not fully supported on cpu") - # For testing the ratio with manual calculation, we require the reduction to be "none" - reduction = "none" - torch.random.manual_seed(seed) - inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) - focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) - ce_loss = 
F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction) - - assert torch.all( - focal_loss <= ce_loss - ), "focal loss must be less or equal to cross entropy loss with same input" - - loss_ratio = (focal_loss / ce_loss).squeeze() - prob = torch.sigmoid(inputs) - p_t = prob * targets + (1 - prob) * (1 - targets) - correct_ratio = (1.0 - p_t) ** gamma - if alpha >= 0: - alpha_t = alpha * targets + (1 - alpha) * (1 - targets) - correct_ratio = correct_ratio * alpha_t - - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(correct_ratio, loss_ratio, atol=tol, rtol=tol) - - @pytest.mark.parametrize("reduction", ["mean", "sum"]) - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - @pytest.mark.parametrize("seed", [2, 3]) - def test_equal_ce_loss(self, reduction, device, dtype, seed): - if device == "cpu" and dtype is torch.half: - pytest.skip("Currently torch.half is not fully supported on cpu") - # focal loss should be equal ce_loss if alpha=-1 and gamma=0 - alpha = -1 - gamma = 0 - torch.random.manual_seed(seed) - inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) - inputs_fl = inputs.clone().requires_grad_() - targets_fl = targets.clone() - inputs_ce = inputs.clone().requires_grad_() - targets_ce = targets.clone() - focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction) - ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction) - - torch.testing.assert_close(focal_loss, ce_loss) - - focal_loss.backward() - ce_loss.backward() - torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad) - - @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0]) - @pytest.mark.parametrize("gamma", [0, 2]) - @pytest.mark.parametrize("reduction", ["none", "mean", "sum"]) - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - @pytest.mark.parametrize("seed", [4, 5]) - def test_jit(self, alpha, gamma, reduction, device, dtype, seed): - if device == "cpu" and dtype is torch.half: - pytest.skip("Currently torch.half is not fully supported on cpu") - script_fn = torch.jit.script(ops.sigmoid_focal_loss) - torch.random.manual_seed(seed) - inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) - focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) - if device == "cpu": - scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) - else: - with torch.jit.fuser("fuser2"): - # Use fuser2 to prevent a bug on fuser: https://github.com/pytorch/pytorch/issues/75476 - # We may remove this condition once the bug is resolved - scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) - - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol) - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/test/test_ops.py b/test/test_ops.py index c52b6751ef8..4f918f35d58 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -3,12 +3,13 @@ from abc import ABC, abstractmethod from functools import lru_cache from itertools import product -from typing import Tuple +from typing import Tuple, Callable, List import numpy as np import pytest import torch import torch.fx +import torch.nn.functional as F from common_utils import 
assert_equal, cpu_and_gpu, needs_cuda from PIL import Image from torch import nn, Tensor @@ -1151,6 +1152,326 @@ def test_box_area_jit(self): torch.testing.assert_close(scripted_area, expected, rtol=0.0, atol=1e-3) +class TestIouBase: + @staticmethod + def _run_test(target_fn: Callable, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List): + for dtype in dtypes: + actual_box = torch.tensor(test_input, dtype=dtype) + expected_box = torch.tensor(expected) + out = target_fn(actual_box, actual_box) + torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=tolerance) + # assert_close(actual_box, expected_box, tolerance) + + @staticmethod + def _run_jit_test(target_fn: Callable, test_input: List): + box_tensor = torch.tensor(test_input, dtype=torch.float) + expected = target_fn(box_tensor, box_tensor) + scripted_fn = torch.jit.script(target_fn) + scripted_out = scripted_fn(box_tensor, box_tensor) + torch.testing.assert_close(scripted_out, expected, rtol=0.0, atol=1e-3) + + +IOU_INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] +IOU_FLOAT_BOXES = [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], +] + + +class TestBoxIou(TestIouBase): + generate_int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] + + @pytest.mark.parametrize( + "test_input, dtypes, tolerance, expected", + [ + pytest.param(IOU_INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, generate_int_expected), + pytest.param(IOU_FLOAT_BOXES, [torch.float16], 0.002, float_expected), + pytest.param(IOU_FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), + ], + ) + def test_iou(self, test_input, dtypes, tolerance, expected): + self._run_test(ops.box_iou, test_input, dtypes, tolerance, expected) + + def test_iou_jit(self): + self._run_jit_test(ops.box_iou, IOU_INT_BOXES) + + +class TestGeneralizedBoxIou(TestIouBase): + int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]] + float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] + + @pytest.mark.parametrize( + "test_input, dtypes, tolerance, expected", + [ + pytest.param(IOU_INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), + pytest.param(IOU_FLOAT_BOXES, [torch.float16], 0.002, float_expected), + pytest.param(IOU_FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), + ], + ) + def test_iou(self, test_input, dtypes, tolerance, expected): + self._run_test(ops.generalized_box_iou, test_input, dtypes, tolerance, expected) + + def test_iou_jit(self): + self._run_jit_test(ops.generalized_box_iou, IOU_INT_BOXES) + + +class TestDistanceBoxIoU(TestIouBase): + int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] + + @pytest.mark.parametrize( + "test_input, dtypes, tolerance, expected", + [ + pytest.param(IOU_INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), + pytest.param(IOU_FLOAT_BOXES, [torch.float16], 0.002, float_expected), + pytest.param(IOU_FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), + ], + ) + def test_iou(self, test_input, dtypes, tolerance, expected): + self._run_test(ops.distance_box_iou, test_input, dtypes, tolerance, expected) + + def test_iou_jit(self): + 
self._run_jit_test(ops.distance_box_iou, IOU_INT_BOXES) + + +class TestCompleteBoxIou(TestIouBase): + int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] + + @pytest.mark.parametrize( + "test_input, dtypes, tolerance, expected", + [ + pytest.param(IOU_INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), + pytest.param(IOU_FLOAT_BOXES, [torch.float16], 0.002, float_expected), + pytest.param(IOU_FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), + ], + ) + def test_iou(self, test_input, dtypes, tolerance, expected): + self._run_test(ops.complete_box_iou, test_input, dtypes, tolerance, expected) + + def test_iou_jit(self): + self._run_jit_test(ops.complete_box_iou, IOU_INT_BOXES) + + +def get_boxes(dtype, device): + box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) + box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) + box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device) + box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device) + + box1s = torch.stack([box2, box2], dim=0) + box2s = torch.stack([box3, box4], dim=0) + + return box1, box2, box3, box4, box1s, box2s + + +def assert_iou_loss(iou_fn, box1, box2, expected_loss, dtype, device, reduction="none"): + computed_loss = iou_fn(box1, box2, reduction=reduction) + expected_loss = torch.tensor(expected_loss, device=device) + torch.testing.assert_close(computed_loss, expected_loss) + + +def assert_empty_loss(iou_fn, dtype, device): + box1 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_() + box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_() + loss = iou_fn(box1, box2, reduction="mean") + loss.backward() + torch.testing.assert_close(loss, torch.tensor(0.0, device=device)) + assert box1.grad is not None, "box1.grad should not be None after backward is called" + assert box2.grad is not None, "box2.grad should not be None after backward is called" + loss = iou_fn(box1, box2, reduction="none") + assert loss.numel() == 0, f"{str(iou_fn)} for two empty box should be empty" + + +class TestGeneralizedBoxIouLoss: + # We refer to original test: https://github.com/facebookresearch/fvcore/blob/main/tests/test_giou_loss.py + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + def test_giou_loss(self, dtype, device): + + box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) + + # Identical boxes should have loss of 0 + assert_iou_loss(ops.generalized_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device) + + # quarter size box inside other box = IoU of 0.25 + assert_iou_loss(ops.generalized_box_iou_loss, box1, box2, 0.75, dtype=dtype, device=device) + + # Two side by side boxes, area=union + # IoU=0 and GIoU=0 (loss 1.0) + assert_iou_loss(ops.generalized_box_iou_loss, box2, box3, 1.0, dtype=dtype, device=device) + + # Two diagonally adjacent boxes, area=2*union + # IoU=0 and GIoU=-0.5 (loss 1.5) + assert_iou_loss(ops.generalized_box_iou_loss, box2, box4, 1.5, dtype=dtype, device=device) + + # Test batched loss and reductions + assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 2.5, dtype=dtype, device=device, reduction="sum") + assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 1.25, dtype=dtype, device=device, reduction="mean") + + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + 
def test_empty_inputs(self, dtype, device): + assert_empty_loss(ops.generalized_box_iou_loss, dtype, device) + + +class TestCompleteBoxIouLoss: + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + def test_ciou_loss(self, dtype, device): + box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) + + assert_iou_loss(ops.complete_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device) + assert_iou_loss(ops.complete_box_iou_loss, box1, box2, 0.8125, dtype=dtype, device=device) + assert_iou_loss(ops.complete_box_iou_loss, box1, box3, 1.1923, dtype=dtype, device=device) + assert_iou_loss(ops.complete_box_iou_loss, box1, box4, 1.2500, dtype=dtype, device=device) + assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 1.2250, dtype=dtype, device=device, reduction="mean") + assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 2.4500, dtype=dtype, device=device, reduction="sum") + + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + def test_empty_inputs(self, dtype, device): + assert_empty_loss(ops.complete_box_iou_loss, dtype, device) + + +class TestDistanceBoxIouLoss: + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + def test_distance_iou_loss(self, dtype, device): + box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) + + assert_iou_loss(ops.distance_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device) + assert_iou_loss(ops.distance_box_iou_loss, box1, box2, 0.8125, dtype=dtype, device=device) + assert_iou_loss(ops.distance_box_iou_loss, box1, box3, 1.1923, dtype=dtype, device=device) + assert_iou_loss(ops.distance_box_iou_loss, box1, box4, 1.2500, dtype=dtype, device=device) + assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 1.2250, dtype=dtype, device=device, reduction="mean") + assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 2.4500, dtype=dtype, device=device, reduction="sum") + + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + def test_empty_distance_iou_inputs(self, dtype, device): + assert_empty_loss(ops.distance_box_iou_loss, dtype, device) + + +class TestFocalLoss: + def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs): + def logit(p): + return torch.log(p / (1 - p)) + + def generate_tensor_with_range_type(shape, range_type, **kwargs): + if range_type != "random_binary": + low, high = { + "small": (0.0, 0.2), + "big": (0.8, 1.0), + "zeros": (0.0, 0.0), + "ones": (1.0, 1.0), + "random": (0.0, 1.0), + }[range_type] + return torch.testing.make_tensor(shape, low=low, high=high, **kwargs) + else: + return torch.randint(0, 2, shape, **kwargs) + + # This function will return inputs and targets with shape: (shape[0]*9, shape[1]) + inputs = [] + targets = [] + for input_range_type, target_range_type in [ + ("small", "zeros"), + ("small", "ones"), + ("small", "random_binary"), + ("big", "zeros"), + ("big", "ones"), + ("big", "random_binary"), + ("random", "zeros"), + ("random", "ones"), + ("random", "random_binary"), + ]: + inputs.append(logit(generate_tensor_with_range_type(shape, input_range_type, **kwargs))) + targets.append(generate_tensor_with_range_type(shape, target_range_type, **kwargs)) + + return torch.cat(inputs), torch.cat(targets) + + @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0]) + @pytest.mark.parametrize("gamma", [0, 2]) + 
@pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + @pytest.mark.parametrize("seed", [0, 1]) + def test_correct_ratio(self, alpha, gamma, device, dtype, seed): + if device == "cpu" and dtype is torch.half: + pytest.skip("Currently torch.half is not fully supported on cpu") + # For testing the ratio with manual calculation, we require the reduction to be "none" + reduction = "none" + torch.random.manual_seed(seed) + inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) + focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction) + + assert torch.all( + focal_loss <= ce_loss + ), "focal loss must be less or equal to cross entropy loss with same input" + + loss_ratio = (focal_loss / ce_loss).squeeze() + prob = torch.sigmoid(inputs) + p_t = prob * targets + (1 - prob) * (1 - targets) + correct_ratio = (1.0 - p_t) ** gamma + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + correct_ratio = correct_ratio * alpha_t + + tol = 1e-3 if dtype is torch.half else 1e-5 + torch.testing.assert_close(correct_ratio, loss_ratio, atol=tol, rtol=tol) + + @pytest.mark.parametrize("reduction", ["mean", "sum"]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + @pytest.mark.parametrize("seed", [2, 3]) + def test_equal_ce_loss(self, reduction, device, dtype, seed): + if device == "cpu" and dtype is torch.half: + pytest.skip("Currently torch.half is not fully supported on cpu") + # focal loss should be equal ce_loss if alpha=-1 and gamma=0 + alpha = -1 + gamma = 0 + torch.random.manual_seed(seed) + inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) + inputs_fl = inputs.clone().requires_grad_() + targets_fl = targets.clone() + inputs_ce = inputs.clone().requires_grad_() + targets_ce = targets.clone() + focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction) + ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction) + + torch.testing.assert_close(focal_loss, ce_loss) + + focal_loss.backward() + ce_loss.backward() + torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad) + + @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0]) + @pytest.mark.parametrize("gamma", [0, 2]) + @pytest.mark.parametrize("reduction", ["none", "mean", "sum"]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + @pytest.mark.parametrize("seed", [4, 5]) + def test_jit(self, alpha, gamma, reduction, device, dtype, seed): + if device == "cpu" and dtype is torch.half: + pytest.skip("Currently torch.half is not fully supported on cpu") + script_fn = torch.jit.script(ops.sigmoid_focal_loss) + torch.random.manual_seed(seed) + inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) + focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) + if device == "cpu": + scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) + else: + with torch.jit.fuser("fuser2"): + # Use fuser2 to prevent a bug on fuser: https://github.com/pytorch/pytorch/issues/75476 + # We may remove this condition once the bug is resolved + scripted_focal_loss = 
script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) + + tol = 1e-3 if dtype is torch.half else 1e-5 + torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol) + + class TestMasksToBoxes: def test_masks_box(self): def masks_box_check(masks, expected, tolerance=1e-4): From 5c00ebc3c041a9005d6dd82f97f738ffc41832b4 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Thu, 9 Jun 2022 00:29:26 +0530 Subject: [PATCH 10/14] Adapt almost all except area --- test/test_ops.py | 91 ++++++++++++++++++++++++------------------------ 1 file changed, 45 insertions(+), 46 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 4f918f35d58..8470346b56f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -995,7 +995,7 @@ def test_frozenbatchnorm2d_eps(self, seed): torch.testing.assert_close(fbn(x), bn(x), rtol=1e-5, atol=1e-6) -class TestBoxConversion: +class TestBoxConversionToRoi: def _get_box_sequences(): # Define here the argument type of `boxes` supported by region pooling operations box_tensor = torch.tensor([[0, 0, 0, 100, 100], [1, 0, 0, 100, 100]], dtype=torch.float) @@ -1101,21 +1101,21 @@ def test_bbox_convert_jit(self): ) scripted_fn = torch.jit.script(ops.box_convert) - TOLERANCE = 1e-3 + atol = 1e-3 box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh") scripted_xywh = scripted_fn(box_tensor, "xyxy", "xywh") - torch.testing.assert_close(scripted_xywh, box_xywh, rtol=0.0, atol=TOLERANCE) + torch.testing.assert_close(scripted_xywh, box_xywh, rtol=0.0, atol=atol) box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh") scripted_cxcywh = scripted_fn(box_tensor, "xyxy", "cxcywh") - torch.testing.assert_close(scripted_cxcywh, box_cxcywh, rtol=0.0, atol=TOLERANCE) + torch.testing.assert_close(scripted_cxcywh, box_cxcywh, rtol=0.0, atol=atol) class TestBoxArea: - def area_check(self, box, expected, tolerance=1e-4): + def area_check(self, box, expected, atol=1e-4): out = ops.box_area(box) - torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance) + torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=atol) @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64]) def test_int_boxes(self, dtype): @@ -1133,8 +1133,8 @@ def test_float_boxes(self, dtype): ], dtype=dtype, ) - expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=torch.float64) - self.area_check(box_tensor, expected, tolerance=0.05) + expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=dtype) + self.area_check(box_tensor, expected, atol=0.05) def test_float16_box(self): box_tensor = torch.tensor( @@ -1154,13 +1154,12 @@ def test_box_area_jit(self): class TestIouBase: @staticmethod - def _run_test(target_fn: Callable, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List): + def _run_test(target_fn: Callable, test_input: List, dtypes: List[torch.dtype], atol: float, expected: List): for dtype in dtypes: actual_box = torch.tensor(test_input, dtype=dtype) expected_box = torch.tensor(expected) out = target_fn(actual_box, actual_box) - torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=tolerance) - # assert_close(actual_box, expected_box, tolerance) + torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol) @staticmethod def _run_jit_test(target_fn: Callable, test_input: List): @@ -1180,19 +1179,19 @@ def _run_jit_test(target_fn: Callable, test_input: List): class 
TestBoxIou(TestIouBase): - generate_int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] @pytest.mark.parametrize( - "test_input, dtypes, tolerance, expected", + "test_input, dtypes, atol, expected", [ - pytest.param(IOU_INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, generate_int_expected), + pytest.param(IOU_INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), pytest.param(IOU_FLOAT_BOXES, [torch.float16], 0.002, float_expected), pytest.param(IOU_FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), ], ) - def test_iou(self, test_input, dtypes, tolerance, expected): - self._run_test(ops.box_iou, test_input, dtypes, tolerance, expected) + def test_iou(self, test_input, dtypes, atol, expected): + self._run_test(ops.box_iou, test_input, dtypes, atol, expected) def test_iou_jit(self): self._run_jit_test(ops.box_iou, IOU_INT_BOXES) @@ -1203,15 +1202,15 @@ class TestGeneralizedBoxIou(TestIouBase): float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] @pytest.mark.parametrize( - "test_input, dtypes, tolerance, expected", + "test_input, dtypes, atol, expected", [ pytest.param(IOU_INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), pytest.param(IOU_FLOAT_BOXES, [torch.float16], 0.002, float_expected), pytest.param(IOU_FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), ], ) - def test_iou(self, test_input, dtypes, tolerance, expected): - self._run_test(ops.generalized_box_iou, test_input, dtypes, tolerance, expected) + def test_iou(self, test_input, dtypes, atol, expected): + self._run_test(ops.generalized_box_iou, test_input, dtypes, atol, expected) def test_iou_jit(self): self._run_jit_test(ops.generalized_box_iou, IOU_INT_BOXES) @@ -1222,15 +1221,15 @@ class TestDistanceBoxIoU(TestIouBase): float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] @pytest.mark.parametrize( - "test_input, dtypes, tolerance, expected", + "test_input, dtypes, atol, expected", [ pytest.param(IOU_INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), pytest.param(IOU_FLOAT_BOXES, [torch.float16], 0.002, float_expected), pytest.param(IOU_FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), ], ) - def test_iou(self, test_input, dtypes, tolerance, expected): - self._run_test(ops.distance_box_iou, test_input, dtypes, tolerance, expected) + def test_iou(self, test_input, dtypes, atol, expected): + self._run_test(ops.distance_box_iou, test_input, dtypes, atol, expected) def test_iou_jit(self): self._run_jit_test(ops.distance_box_iou, IOU_INT_BOXES) @@ -1241,15 +1240,15 @@ class TestCompleteBoxIou(TestIouBase): float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] @pytest.mark.parametrize( - "test_input, dtypes, tolerance, expected", + "test_input, dtypes, atol, expected", [ pytest.param(IOU_INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), pytest.param(IOU_FLOAT_BOXES, [torch.float16], 0.002, float_expected), pytest.param(IOU_FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), ], ) - def test_iou(self, test_input, dtypes, tolerance, expected): - self._run_test(ops.complete_box_iou, test_input, dtypes, tolerance, expected) + def test_iou(self, test_input, dtypes, atol, expected): + 
self._run_test(ops.complete_box_iou, test_input, dtypes, atol, expected) def test_iou_jit(self): self._run_jit_test(ops.complete_box_iou, IOU_INT_BOXES) @@ -1267,7 +1266,7 @@ def get_boxes(dtype, device): return box1, box2, box3, box4, box1s, box2s -def assert_iou_loss(iou_fn, box1, box2, expected_loss, dtype, device, reduction="none"): +def assert_iou_loss(iou_fn, box1, box2, expected_loss, device, reduction="none"): computed_loss = iou_fn(box1, box2, reduction=reduction) expected_loss = torch.tensor(expected_loss, device=device) torch.testing.assert_close(computed_loss, expected_loss) @@ -1294,22 +1293,22 @@ def test_giou_loss(self, dtype, device): box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) # Identical boxes should have loss of 0 - assert_iou_loss(ops.generalized_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device) + assert_iou_loss(ops.generalized_box_iou_loss, box1, box1, 0.0, device=device) # quarter size box inside other box = IoU of 0.25 - assert_iou_loss(ops.generalized_box_iou_loss, box1, box2, 0.75, dtype=dtype, device=device) + assert_iou_loss(ops.generalized_box_iou_loss, box1, box2, 0.75, device=device) # Two side by side boxes, area=union # IoU=0 and GIoU=0 (loss 1.0) - assert_iou_loss(ops.generalized_box_iou_loss, box2, box3, 1.0, dtype=dtype, device=device) + assert_iou_loss(ops.generalized_box_iou_loss, box2, box3, 1.0, device=device) # Two diagonally adjacent boxes, area=2*union # IoU=0 and GIoU=-0.5 (loss 1.5) - assert_iou_loss(ops.generalized_box_iou_loss, box2, box4, 1.5, dtype=dtype, device=device) + assert_iou_loss(ops.generalized_box_iou_loss, box2, box4, 1.5, device=device) # Test batched loss and reductions - assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 2.5, dtype=dtype, device=device, reduction="sum") - assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 1.25, dtype=dtype, device=device, reduction="mean") + assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 2.5, device=device, reduction="sum") + assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 1.25, device=device, reduction="mean") @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) @@ -1323,12 +1322,12 @@ class TestCompleteBoxIouLoss: def test_ciou_loss(self, dtype, device): box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) - assert_iou_loss(ops.complete_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device) - assert_iou_loss(ops.complete_box_iou_loss, box1, box2, 0.8125, dtype=dtype, device=device) - assert_iou_loss(ops.complete_box_iou_loss, box1, box3, 1.1923, dtype=dtype, device=device) - assert_iou_loss(ops.complete_box_iou_loss, box1, box4, 1.2500, dtype=dtype, device=device) - assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 1.2250, dtype=dtype, device=device, reduction="mean") - assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 2.4500, dtype=dtype, device=device, reduction="sum") + assert_iou_loss(ops.complete_box_iou_loss, box1, box1, 0.0, device=device) + assert_iou_loss(ops.complete_box_iou_loss, box1, box2, 0.8125, device=device) + assert_iou_loss(ops.complete_box_iou_loss, box1, box3, 1.1923, device=device) + assert_iou_loss(ops.complete_box_iou_loss, box1, box4, 1.2500, device=device) + assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 1.2250, device=device, reduction="mean") + assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 2.4500, device=device, reduction="sum") @pytest.mark.parametrize("device", cpu_and_gpu()) 
@pytest.mark.parametrize("dtype", [torch.float32, torch.half]) @@ -1342,12 +1341,12 @@ class TestDistanceBoxIouLoss: def test_distance_iou_loss(self, dtype, device): box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) - assert_iou_loss(ops.distance_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device) - assert_iou_loss(ops.distance_box_iou_loss, box1, box2, 0.8125, dtype=dtype, device=device) - assert_iou_loss(ops.distance_box_iou_loss, box1, box3, 1.1923, dtype=dtype, device=device) - assert_iou_loss(ops.distance_box_iou_loss, box1, box4, 1.2500, dtype=dtype, device=device) - assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 1.2250, dtype=dtype, device=device, reduction="mean") - assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 2.4500, dtype=dtype, device=device, reduction="sum") + assert_iou_loss(ops.distance_box_iou_loss, box1, box1, 0.0, device=device) + assert_iou_loss(ops.distance_box_iou_loss, box1, box2, 0.8125, device=device) + assert_iou_loss(ops.distance_box_iou_loss, box1, box3, 1.1923, device=device) + assert_iou_loss(ops.distance_box_iou_loss, box1, box4, 1.2500, device=device) + assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 1.2250, device=device, reduction="mean") + assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 2.4500, device=device, reduction="sum") @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) @@ -1474,10 +1473,10 @@ def test_jit(self, alpha, gamma, reduction, device, dtype, seed): class TestMasksToBoxes: def test_masks_box(self): - def masks_box_check(masks, expected, tolerance=1e-4): + def masks_box_check(masks, expected, atol=1e-4): out = ops.masks_to_boxes(masks) assert out.dtype == torch.float - torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=True, atol=tolerance) + torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=True, atol=atol) # Check for int type boxes. 
         def _get_image():

From 141bb68f6a5575bda5d87cacfae4641ed3245ab0 Mon Sep 17 00:00:00 2001
From: Aditya Oke
Date: Fri, 10 Jun 2022 16:20:52 +0530
Subject: [PATCH 11/14] final update

---
 test/test_ops.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index 8470346b56f..87147059ec6 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1141,7 +1141,7 @@ def test_float16_box(self):
             [[285.25, 185.625, 1194.0, 851.5], [285.25, 188.75, 1192.0, 851.0], [279.25, 198.0, 1189.0, 849.0]],
             dtype=torch.float16,
         )
-        expected = torch.tensor([605113.875, 600495.1875, 592247.25])
+        expected = torch.tensor([605113.875, 600495.1875, 592247.25], dtype=torch.float32)
         self.area_check(box_tensor, expected)
 
     def test_box_area_jit(self):

From 1f183c528b207d0941a42341d413753330420ae6 Mon Sep 17 00:00:00 2001
From: Aditya Oke
Date: Fri, 10 Jun 2022 16:48:25 +0530
Subject: [PATCH 12/14] Tighten for jit

---
 test/test_ops.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index 87147059ec6..a7afd6e90a7 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1101,15 +1101,14 @@ def test_bbox_convert_jit(self):
         )
         scripted_fn = torch.jit.script(ops.box_convert)
 
-        atol = 1e-3
         box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh")
         scripted_xywh = scripted_fn(box_tensor, "xyxy", "xywh")
-        torch.testing.assert_close(scripted_xywh, box_xywh, rtol=0.0, atol=atol)
+        torch.testing.assert_close(scripted_xywh, box_xywh)
 
         box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh")
         scripted_cxcywh = scripted_fn(box_tensor, "xyxy", "cxcywh")
-        torch.testing.assert_close(scripted_cxcywh, box_cxcywh, rtol=0.0, atol=atol)
+        torch.testing.assert_close(scripted_cxcywh, box_cxcywh)
 
 
 class TestBoxArea:
@@ -1134,22 +1133,22 @@ def test_float_boxes(self, dtype):
             dtype=dtype,
         )
         expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=dtype)
-        self.area_check(box_tensor, expected, atol=0.05)
+        self.area_check(box_tensor, expected)
 
     def test_float16_box(self):
         box_tensor = torch.tensor(
-            [[285.25, 185.625, 1194.0, 851.5], [285.25, 188.75, 1192.0, 851.0], [279.25, 198.0, 1189.0, 849.0]],
+            [[28.25, 18.625, 39.0, 48.5], [28.25, 48.75, 192.0, 51.0], [29.25, 18.0, 89.0, 49.0]],
             dtype=torch.float16,
         )
-        expected = torch.tensor([605113.875, 600495.1875, 592247.25], dtype=torch.float32)
-        self.area_check(box_tensor, expected)
+        expected = torch.tensor([321.1562, 368.4375, 1852.2500], dtype=torch.float16)
+        self.area_check(box_tensor, expected, atol=0.3)
 
     def test_box_area_jit(self):
         box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)
         expected = ops.box_area(box_tensor)
         scripted_fn = torch.jit.script(ops.box_area)
         scripted_area = scripted_fn(box_tensor)
-        torch.testing.assert_close(scripted_area, expected, rtol=0.0, atol=1e-3)
+        torch.testing.assert_close(scripted_area, expected)
 
 
 class TestIouBase:
@@ -1167,7 +1166,7 @@ def _run_jit_test(target_fn: Callable, test_input: List):
         expected = target_fn(box_tensor, box_tensor)
         scripted_fn = torch.jit.script(target_fn)
         scripted_out = scripted_fn(box_tensor, box_tensor)
-        torch.testing.assert_close(scripted_out, expected, rtol=0.0, atol=1e-3)
+        torch.testing.assert_close(scripted_out, expected)

From 5171187bde3f040efcfd83b6245a59ae34cc3125 Mon Sep 17 00:00:00 2001
From: Aditya Oke
Date: Tue, 21 Jun 2022 15:47:05 +0530
Subject: [PATCH 13/14] Refactor
slightly --- test/test_ops.py | 57 +++++++++++++++++++++--------------------------- 1 file changed, 25 insertions(+), 32 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index a7afd6e90a7..5d34cdbc87a 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1111,6 +1111,14 @@ def test_bbox_convert_jit(self): torch.testing.assert_close(scripted_cxcywh, box_cxcywh) +INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] +FLOAT_BOXES = [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], +] + + class TestBoxArea: def area_check(self, box, expected, atol=1e-4): out = ops.box_area(box) @@ -1124,14 +1132,7 @@ def test_int_boxes(self, dtype): @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) def test_float_boxes(self, dtype): - box_tensor = torch.tensor( - [ - [285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019], - ], - dtype=dtype, - ) + box_tensor = torch.tensor(FLOAT_BOXES, dtype=dtype) expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=dtype) self.area_check(box_tensor, expected) @@ -1169,14 +1170,6 @@ def _run_jit_test(target_fn: Callable, test_input: List): torch.testing.assert_close(scripted_out, expected) -IOU_INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] -IOU_FLOAT_BOXES = [ - [285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019], -] - - class TestBoxIou(TestIouBase): int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] @@ -1184,16 +1177,16 @@ class TestBoxIou(TestIouBase): @pytest.mark.parametrize( "test_input, dtypes, atol, expected", [ - pytest.param(IOU_INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), - pytest.param(IOU_FLOAT_BOXES, [torch.float16], 0.002, float_expected), - pytest.param(IOU_FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), + pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), + pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected), + pytest.param(FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), ], ) def test_iou(self, test_input, dtypes, atol, expected): self._run_test(ops.box_iou, test_input, dtypes, atol, expected) def test_iou_jit(self): - self._run_jit_test(ops.box_iou, IOU_INT_BOXES) + self._run_jit_test(ops.box_iou, INT_BOXES) class TestGeneralizedBoxIou(TestIouBase): @@ -1203,16 +1196,16 @@ class TestGeneralizedBoxIou(TestIouBase): @pytest.mark.parametrize( "test_input, dtypes, atol, expected", [ - pytest.param(IOU_INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), - pytest.param(IOU_FLOAT_BOXES, [torch.float16], 0.002, float_expected), - pytest.param(IOU_FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), + pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), + pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected), + pytest.param(FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), ], ) def test_iou(self, test_input, dtypes, atol, expected): self._run_test(ops.generalized_box_iou, test_input, dtypes, atol, expected) def test_iou_jit(self): - self._run_jit_test(ops.generalized_box_iou, IOU_INT_BOXES) + 
self._run_jit_test(ops.generalized_box_iou, INT_BOXES) class TestDistanceBoxIoU(TestIouBase): @@ -1222,16 +1215,16 @@ class TestDistanceBoxIoU(TestIouBase): @pytest.mark.parametrize( "test_input, dtypes, atol, expected", [ - pytest.param(IOU_INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), - pytest.param(IOU_FLOAT_BOXES, [torch.float16], 0.002, float_expected), - pytest.param(IOU_FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), + pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), + pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected), + pytest.param(FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), ], ) def test_iou(self, test_input, dtypes, atol, expected): self._run_test(ops.distance_box_iou, test_input, dtypes, atol, expected) def test_iou_jit(self): - self._run_jit_test(ops.distance_box_iou, IOU_INT_BOXES) + self._run_jit_test(ops.distance_box_iou, INT_BOXES) class TestCompleteBoxIou(TestIouBase): @@ -1241,16 +1234,16 @@ class TestCompleteBoxIou(TestIouBase): @pytest.mark.parametrize( "test_input, dtypes, atol, expected", [ - pytest.param(IOU_INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), - pytest.param(IOU_FLOAT_BOXES, [torch.float16], 0.002, float_expected), - pytest.param(IOU_FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), + pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), + pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected), + pytest.param(FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), ], ) def test_iou(self, test_input, dtypes, atol, expected): self._run_test(ops.complete_box_iou, test_input, dtypes, atol, expected) def test_iou_jit(self): - self._run_jit_test(ops.complete_box_iou, IOU_INT_BOXES) + self._run_jit_test(ops.complete_box_iou, INT_BOXES) def get_boxes(dtype, device): From 8f7364501d0014bbcda10a98aa4d9eeb82b5341f Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Wed, 27 Jul 2022 11:43:51 +0530 Subject: [PATCH 14/14] Fix tests --- test/test_ops.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 53637baf347..8ec0e6c7ea9 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from functools import lru_cache from itertools import product -from typing import Tuple, Callable, List +from typing import Callable, List, Tuple import numpy as np import pytest @@ -1138,11 +1138,11 @@ def test_float_boxes(self, dtype): def test_float16_box(self): box_tensor = torch.tensor( - [[28.25, 18.625, 39.0, 48.5], [28.25, 48.75, 192.0, 51.0], [29.25, 18.0, 89.0, 49.0]], - dtype=torch.float16, + [[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]], dtype=torch.float16 ) - expected = torch.tensor([321.1562, 368.4375, 1852.2500], dtype=torch.float16) - self.area_check(box_tensor, expected, atol=0.3) + + expected = torch.tensor([3.2170, 3.7108, 18.5071], dtype=torch.float16) + self.area_check(box_tensor, expected, atol=0.01) def test_box_area_jit(self): box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)