diff --git a/test/test_ops.py b/test/test_ops.py index bc4f9d19464..8ec0e6c7ea9 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -995,7 +995,7 @@ def test_frozenbatchnorm2d_eps(self, seed): torch.testing.assert_close(fbn(x), bn(x), rtol=1e-5, atol=1e-6) -class TestBoxConversion: +class TestBoxConversionToRoi: def _get_box_sequences(): # Define here the argument type of `boxes` supported by region pooling operations box_tensor = torch.tensor([[0, 0, 0, 100, 100], [1, 0, 0, 100, 100]], dtype=torch.float) @@ -1021,7 +1021,7 @@ def test_convert_boxes_to_roi_format(self, box_sequence): assert_equal(ref_tensor, ops._utils.convert_boxes_to_roi_format(box_sequence)) -class TestBox: +class TestBoxConvert: def test_bbox_same(self): box_tensor = torch.tensor( [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float @@ -1051,7 +1051,7 @@ def test_bbox_xyxy_xywh(self): assert_equal(box_xyxy, box_tensor) def test_bbox_xyxy_cxcywh(self): - # Simple test convert boxes to xywh and back. Make sure they are same. + # Simple test convert boxes to cxcywh and back. Make sure they are same. # box_tensor is in x1 y1 x2 y2 format. box_tensor = torch.tensor( [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float @@ -1073,7 +1073,6 @@ def test_bbox_xywh_cxcywh(self): [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float ) - # This is wrong exp_cxcywh = torch.tensor( [[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float ) @@ -1102,296 +1101,374 @@ def test_bbox_convert_jit(self): ) scripted_fn = torch.jit.script(ops.box_convert) - TOLERANCE = 1e-3 box_xywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="xywh") scripted_xywh = scripted_fn(box_tensor, "xyxy", "xywh") - torch.testing.assert_close(scripted_xywh, box_xywh, rtol=0.0, atol=TOLERANCE) + torch.testing.assert_close(scripted_xywh, box_xywh) box_cxcywh = ops.box_convert(box_tensor, in_fmt="xyxy", out_fmt="cxcywh") scripted_cxcywh = scripted_fn(box_tensor, "xyxy", "cxcywh") - torch.testing.assert_close(scripted_cxcywh, box_cxcywh, rtol=0.0, atol=TOLERANCE) + torch.testing.assert_close(scripted_cxcywh, box_cxcywh) -class BoxTestBase(ABC): - @abstractmethod - def _target_fn(self) -> Tuple[bool, Callable]: - pass +INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] +FLOAT_BOXES = [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], +] + + +class TestBoxArea: + def area_check(self, box, expected, atol=1e-4): + out = ops.box_area(box) + torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=atol) + + @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64]) + def test_int_boxes(self, dtype): + box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype) + expected = torch.tensor([10000, 0], dtype=torch.int32) + self.area_check(box_tensor, expected) + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) + def test_float_boxes(self, dtype): + box_tensor = torch.tensor(FLOAT_BOXES, dtype=dtype) + expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=dtype) + self.area_check(box_tensor, expected) + + def test_float16_box(self): + box_tensor = torch.tensor( + [[2.825, 1.8625, 3.90, 4.85], [2.825, 4.875, 19.20, 5.10], [2.925, 1.80, 8.90, 4.90]], dtype=torch.float16 + ) + + expected = torch.tensor([3.2170, 3.7108, 18.5071], 
dtype=torch.float16) + self.area_check(box_tensor, expected, atol=0.01) - def _perform_box_operation(self, box: Tensor, run_as_script: bool = False) -> Tensor: - is_binary_fn = self._target_fn()[0] - target_fn = self._target_fn()[1] - box_operation = torch.jit.script(target_fn) if run_as_script else target_fn - return box_operation(box, box) if is_binary_fn else box_operation(box) + def test_box_area_jit(self): + box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float) + expected = ops.box_area(box_tensor) + scripted_fn = torch.jit.script(ops.box_area) + scripted_area = scripted_fn(box_tensor) + torch.testing.assert_close(scripted_area, expected) - def _run_test(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None: - def assert_close(box: Tensor, expected: Tensor, tolerance): - out = self._perform_box_operation(box) - torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance) +class TestIouBase: + @staticmethod + def _run_test(target_fn: Callable, test_input: List, dtypes: List[torch.dtype], atol: float, expected: List): for dtype in dtypes: actual_box = torch.tensor(test_input, dtype=dtype) expected_box = torch.tensor(expected) - assert_close(actual_box, expected_box, tolerance) + out = target_fn(actual_box, actual_box) + torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol) - def _run_jit_test(self, test_input: List) -> None: + @staticmethod + def _run_jit_test(target_fn: Callable, test_input: List): box_tensor = torch.tensor(test_input, dtype=torch.float) - expected = self._perform_box_operation(box_tensor, True) - scripted_area = self._perform_box_operation(box_tensor, True) - torch.testing.assert_close(scripted_area, expected, rtol=0.0, atol=1e-3) + expected = target_fn(box_tensor, box_tensor) + scripted_fn = torch.jit.script(target_fn) + scripted_out = scripted_fn(box_tensor, box_tensor) + torch.testing.assert_close(scripted_out, expected) -class TestBoxArea(BoxTestBase): - def _target_fn(self) -> Tuple[bool, Callable]: - return (False, ops.box_area) +class TestBoxIou(TestIouBase): + int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] - def _generate_int_input() -> List[List[int]]: - return [[0, 0, 100, 100], [0, 0, 0, 0]] + @pytest.mark.parametrize( + "test_input, dtypes, atol, expected", + [ + pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), + pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected), + pytest.param(FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), + ], + ) + def test_iou(self, test_input, dtypes, atol, expected): + self._run_test(ops.box_iou, test_input, dtypes, atol, expected) - def _generate_int_expected() -> List[int]: - return [10000, 0] + def test_iou_jit(self): + self._run_jit_test(ops.box_iou, INT_BOXES) - def _generate_float_input(index: int) -> List[List[float]]: - return [ - [ - [285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019], - ], - [[285.25, 185.625, 1194.0, 851.5], [285.25, 188.75, 1192.0, 851.0], [279.25, 198.0, 1189.0, 849.0]], - ][index] - def _generate_float_expected(index: int) -> List[float]: - return [[604723.0806, 600965.4666, 592761.0085], [605113.875, 600495.1875, 592247.25]][index] +class TestGeneralizedBoxIou(TestIouBase): + int_expected = [[1.0, 0.25, -0.7778], [0.25, 
1.0, -0.8611], [-0.7778, -0.8611, 1.0]] + float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] @pytest.mark.parametrize( - "test_input, dtypes, tolerance, expected", + "test_input, dtypes, atol, expected", [ - pytest.param( - _generate_int_input(), - [torch.int8, torch.int16, torch.int32, torch.int64], - 1e-4, - _generate_int_expected(), - ), - pytest.param(_generate_float_input(0), [torch.float32, torch.float64], 0.05, _generate_float_expected(0)), - pytest.param(_generate_float_input(1), [torch.float16], 1e-4, _generate_float_expected(1)), + pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), + pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected), + pytest.param(FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), ], ) - def test_box_area(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None: - self._run_test(test_input, dtypes, tolerance, expected) + def test_iou(self, test_input, dtypes, atol, expected): + self._run_test(ops.generalized_box_iou, test_input, dtypes, atol, expected) - def test_box_area_jit(self) -> None: - self._run_jit_test([[0, 0, 100, 100], [0, 0, 0, 0]]) + def test_iou_jit(self): + self._run_jit_test(ops.generalized_box_iou, INT_BOXES) -class TestBoxIou(BoxTestBase): - def _target_fn(self) -> Tuple[bool, Callable]: - return (True, ops.box_iou) +class TestDistanceBoxIoU(TestIouBase): + int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] - def _generate_int_input() -> List[List[int]]: - return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] + @pytest.mark.parametrize( + "test_input, dtypes, atol, expected", + [ + pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), + pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected), + pytest.param(FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected), + ], + ) + def test_iou(self, test_input, dtypes, atol, expected): + self._run_test(ops.distance_box_iou, test_input, dtypes, atol, expected) - def _generate_int_expected() -> List[List[float]]: - return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + def test_iou_jit(self): + self._run_jit_test(ops.distance_box_iou, INT_BOXES) - def _generate_float_input() -> List[List[float]]: - return [ - [285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019], - ] - def _generate_float_expected() -> List[List[float]]: - return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] +class TestCompleteBoxIou(TestIouBase): + int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] @pytest.mark.parametrize( - "test_input, dtypes, tolerance, expected", + "test_input, dtypes, atol, expected", [ - pytest.param( - _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() - ), - pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), - pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-4, _generate_float_expected()), + pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected), + pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected), + pytest.param(FLOAT_BOXES, [torch.float32, 
torch.float64], 1e-3, float_expected),
+        ],
+    )
-    def test_iou(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None:
-        self._run_test(test_input, dtypes, tolerance, expected)
+    def test_iou(self, test_input, dtypes, atol, expected):
+        self._run_test(ops.complete_box_iou, test_input, dtypes, atol, expected)
 
-    def test_iou_jit(self) -> None:
-        self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])
+    def test_iou_jit(self):
+        self._run_jit_test(ops.complete_box_iou, INT_BOXES)
 
 
-class TestGenBoxIou(BoxTestBase):
-    def _target_fn(self) -> Tuple[bool, Callable]:
-        return (True, ops.generalized_box_iou)
+def get_boxes(dtype, device):
+    box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device)
+    box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device)
+    box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device)
+    box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device)
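+    # Geometry of these fixtures (all in xyxy format): box2 covers the
+    # upper-right quadrant of box1, i.e. a quarter of its area (IoU 0.25);
+    # box3 shares a full edge with box2, and box4 touches box2 only at a
+    # corner, so both pairs have IoU 0 but different enclosing-box penalties.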
-    def _generate_int_input() -> List[List[int]]:
-        return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
+    box1s = torch.stack([box2, box2], dim=0)
+    box2s = torch.stack([box3, box4], dim=0)
 
-    def _generate_int_expected() -> List[List[float]]:
-        return [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]]
+    return box1, box2, box3, box4, box1s, box2s
 
-    def _generate_float_input() -> List[List[float]]:
-        return [
-            [285.3538, 185.5758, 1193.5110, 851.4551],
-            [285.1472, 188.7374, 1192.4984, 851.0669],
-            [279.2440, 197.9812, 1189.4746, 849.2019],
-        ]
+def assert_iou_loss(iou_fn, box1, box2, expected_loss, device, reduction="none"):
+    computed_loss = iou_fn(box1, box2, reduction=reduction)
+    expected_loss = torch.tensor(expected_loss, device=device)
+    torch.testing.assert_close(computed_loss, expected_loss)
 
-    @pytest.mark.parametrize(
-        "test_input, dtypes, tolerance, expected",
-        [
-            pytest.param(
-                _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected()
-            ),
-            pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()),
-            pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.001, _generate_float_expected()),
-        ],
-    )
-    def test_gen_iou(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None:
-        self._run_test(test_input, dtypes, tolerance, expected)
 
-    def test_giou_jit(self) -> None:
-        self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])
+def assert_empty_loss(iou_fn, dtype, device):
+    box1 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()
+    box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()
+    loss = iou_fn(box1, box2, reduction="mean")
+    loss.backward()
+    torch.testing.assert_close(loss, torch.tensor(0.0, device=device))
+    assert box1.grad is not None, "box1.grad should not be None after backward is called"
+    assert box2.grad is not None, "box2.grad should not be None after backward is called"
+    loss = iou_fn(box1, box2, reduction="none")
+    assert loss.numel() == 0, f"{iou_fn} for two empty boxes should be empty"
 
-class TestDistanceBoxIoU(BoxTestBase):
-    def _target_fn(self):
-        return (True, ops.distance_box_iou)
+class TestGeneralizedBoxIouLoss:
+    # We refer to the original test: https://github.com/facebookresearch/fvcore/blob/main/tests/test_giou_loss.py
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
+    def test_giou_loss(self, dtype, device):
 
-    def _generate_int_input():
-        return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
+        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)
 
-    def _generate_int_expected():
-        return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]
+        # Identical boxes should have loss of 0
+        assert_iou_loss(ops.generalized_box_iou_loss, box1, box1, 0.0, device=device)
 
-    def _generate_float_input():
-        return [
-            [285.3538, 185.5758, 1193.5110, 851.4551],
-            [285.1472, 188.7374, 1192.4984, 851.0669],
-            [279.2440, 197.9812, 1189.4746, 849.2019],
-        ]
+        # A quarter-size box inside the other box = IoU of 0.25
+        assert_iou_loss(ops.generalized_box_iou_loss, box1, box2, 0.75, device=device)
 
-    def _generate_float_expected():
-        return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]
+        # Two side-by-side boxes, area=union
+        # IoU=0 and GIoU=0 (loss 1.0)
+        assert_iou_loss(ops.generalized_box_iou_loss, box2, box3, 1.0, device=device)
 
-    @pytest.mark.parametrize(
-        "test_input, dtypes, tolerance, expected",
-        [
-            pytest.param(
-                _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected()
-            ),
-            pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()),
-            pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.001, _generate_float_expected()),
-        ],
-    )
-    def test_distance_iou(self, test_input, dtypes, tolerance, expected):
-        self._run_test(test_input, dtypes, tolerance, expected)
+        # Two diagonally adjacent boxes, area=2*union
+        # IoU=0 and GIoU=-0.5 (loss 1.5)
+        assert_iou_loss(ops.generalized_box_iou_loss, box2, box4, 1.5, device=device)
 
-    def test_distance_iou_jit(self):
-        self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])
+        # Test batched loss and reductions
+        assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 2.5, device=device, reduction="sum")
+        assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 1.25, device=device, reduction="mean")
 
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
+    def test_empty_inputs(self, dtype, device):
+        assert_empty_loss(ops.generalized_box_iou_loss, dtype, device)
 
-@pytest.mark.parametrize("device", cpu_and_gpu())
-@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
-def test_distance_iou_loss(dtype, device):
-    box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device)
-    box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device)
-    box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device)
-    box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device)
-    box1s = torch.stack(
-        [box2, box2],
-        dim=0,
-    )
-    box2s = torch.stack(
-        [box3, box4],
-        dim=0,
-    )
 
+class TestCompleteBoxIouLoss:
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    def test_ciou_loss(self, dtype, device):
+        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)
 
-    def assert_distance_iou_loss(box1, box2, expected_output, reduction="none"):
-        output = ops.distance_box_iou_loss(box1, box2, reduction=reduction)
-        # TODO: When passing the dtype, the torch.half fails as usual.
-        expected_output = torch.tensor(expected_output, device=device)
-        tol = 1e-5 if dtype != torch.half else 1e-3
-        torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol)
+        assert_iou_loss(ops.complete_box_iou_loss, box1, box1, 0.0, device=device)
+        assert_iou_loss(ops.complete_box_iou_loss, box1, box2, 0.8125, device=device)
+        assert_iou_loss(ops.complete_box_iou_loss, box1, box3, 1.1923, device=device)
+        assert_iou_loss(ops.complete_box_iou_loss, box1, box4, 1.2500, device=device)
+        assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 1.2250, device=device, reduction="mean")
+        assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 2.4500, device=device, reduction="sum")
 
-    assert_distance_iou_loss(box1, box1, 0.0)
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
+    def test_empty_inputs(self, dtype, device):
+        assert_empty_loss(ops.complete_box_iou_loss, dtype, device)
 
-    assert_distance_iou_loss(box1, box2, 0.8125)
-    assert_distance_iou_loss(box1, box3, 1.1923)
 
+class TestDistanceBoxIouLoss:
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
+    def test_distance_iou_loss(self, dtype, device):
+        box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)
 
-    assert_distance_iou_loss(box1, box4, 1.2500)
+        assert_iou_loss(ops.distance_box_iou_loss, box1, box1, 0.0, device=device)
+        assert_iou_loss(ops.distance_box_iou_loss, box1, box2, 0.8125, device=device)
+        assert_iou_loss(ops.distance_box_iou_loss, box1, box3, 1.1923, device=device)
+        assert_iou_loss(ops.distance_box_iou_loss, box1, box4, 1.2500, device=device)
+        assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 1.2250, device=device, reduction="mean")
+        assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 2.4500, device=device, reduction="sum")
 
-    assert_distance_iou_loss(box1s, box2s, 1.2250, reduction="mean")
-    assert_distance_iou_loss(box1s, box2s, 2.4500, reduction="sum")
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
+    def test_empty_inputs(self, dtype, device):
+        assert_empty_loss(ops.distance_box_iou_loss, dtype, device)
 
-@pytest.mark.parametrize("device", cpu_and_gpu())
-@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
-def test_empty_distance_iou_inputs(dtype, device) -> None:
-    box1 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()
-    box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()
+class TestFocalLoss:
+    def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs):
+        def logit(p):
+            return torch.log(p / (1 - p))
 
-    loss = ops.distance_box_iou_loss(box1, box2, reduction="mean")
-    loss.backward()
+        def generate_tensor_with_range_type(shape, range_type, **kwargs):
+            if range_type != "random_binary":
+                low, high = {
+                    "small": (0.0, 0.2),
+                    "big": (0.8, 1.0),
+                    "zeros": (0.0, 0.0),
+                    "ones": (1.0, 1.0),
+                    "random": (0.0, 1.0),
+                }[range_type]
+                return torch.testing.make_tensor(shape, low=low, high=high, **kwargs)
+            else:
+                return torch.randint(0, 2, shape, **kwargs)
 
-    tol = 1e-3 if dtype is torch.half else 1e-5
-    torch.testing.assert_close(loss, torch.tensor(0.0, device=device), rtol=tol, atol=tol)
-    assert box1.grad is not None, "box1.grad should not be None after backward is called"
-    assert box2.grad is not None, "box2.grad should not be None after backward is called"
+        # This function will return inputs and targets with shape: (shape[0]*9, shape[1])
+        inputs = []
+        targets = []
+        for input_range_type, target_range_type in [
+            ("small", "zeros"),
+            ("small", "ones"),
+            ("small", "random_binary"),
+            ("big", "zeros"),
+            ("big", "ones"),
+            ("big", "random_binary"),
+            ("random", "zeros"),
+            ("random", "ones"),
+            ("random", "random_binary"),
+        ]:
+            inputs.append(logit(generate_tensor_with_range_type(shape, input_range_type, **kwargs)))
+            targets.append(generate_tensor_with_range_type(shape, target_range_type, **kwargs))
 
-    loss = ops.distance_box_iou_loss(box1, box2, reduction="none")
-    assert loss.numel() == 0, "diou_loss for two empty box should be empty"
+        return torch.cat(inputs), torch.cat(targets)
 
+    @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
+    @pytest.mark.parametrize("gamma", [0, 2])
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
+    @pytest.mark.parametrize("seed", [0, 1])
+    def test_correct_ratio(self, alpha, gamma, device, dtype, seed):
+        if device == "cpu" and dtype is torch.half:
+            pytest.skip("Currently torch.half is not fully supported on cpu")
+        # For testing the ratio with manual calculation, we require the reduction to be "none"
+        reduction = "none"
+        torch.random.manual_seed(seed)
+        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
+        focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
+        ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction)
 
-class TestCompleteBoxIou(BoxTestBase):
-    def _target_fn(self) -> Tuple[bool, Callable]:
-        return (True, ops.complete_box_iou)
+        assert torch.all(
+            focal_loss <= ce_loss
+        ), "focal loss must be less than or equal to cross entropy loss with the same input"
 
-    def _generate_int_input() -> List[List[int]]:
-        return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
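+        # sigmoid_focal_loss scales the elementwise BCE by (1 - p_t) ** gamma and,
+        # when alpha >= 0, additionally by alpha_t, so the elementwise ratio
+        # focal_loss / ce_loss should recover exactly that factor; we rebuild it
+        # manually below and compare.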
+        loss_ratio = (focal_loss / ce_loss).squeeze()
+        prob = torch.sigmoid(inputs)
+        p_t = prob * targets + (1 - prob) * (1 - targets)
+        correct_ratio = (1.0 - p_t) ** gamma
+        if alpha >= 0:
+            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+            correct_ratio = correct_ratio * alpha_t
 
-    def _generate_int_expected() -> List[List[float]]:
-        return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]
+        tol = 1e-3 if dtype is torch.half else 1e-5
+        torch.testing.assert_close(correct_ratio, loss_ratio, atol=tol, rtol=tol)
 
-    def _generate_float_input() -> List[List[float]]:
-        return [
-            [285.3538, 185.5758, 1193.5110, 851.4551],
-            [285.1472, 188.7374, 1192.4984, 851.0669],
-            [279.2440, 197.9812, 1189.4746, 849.2019],
-        ]
+    @pytest.mark.parametrize("reduction", ["mean", "sum"])
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
+    @pytest.mark.parametrize("seed", [2, 3])
+    def test_equal_ce_loss(self, reduction, device, dtype, seed):
+        if device == "cpu" and dtype is torch.half:
+            pytest.skip("Currently torch.half is not fully supported on cpu")
+        # focal loss should equal ce_loss if alpha=-1 and gamma=0
+        alpha = -1
+        gamma = 0
+        torch.random.manual_seed(seed)
+        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
+        inputs_fl = inputs.clone().requires_grad_()
+        targets_fl = targets.clone()
+        inputs_ce = inputs.clone().requires_grad_()
+        targets_ce = targets.clone()
+        focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction)
+        ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction)
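+        # With gamma=0 the modulating factor (1 - p_t) ** gamma is identically 1,
+        # and alpha=-1 disables the alpha_t weighting, so the focal loss and its
+        # gradients must match plain BCE-with-logits exactly.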
-    def _generate_float_expected() -> List[List[float]]:
-        return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]
+        torch.testing.assert_close(focal_loss, ce_loss)
 
-    @pytest.mark.parametrize(
-        "test_input, dtypes, tolerance, expected",
-        [
-            pytest.param(
-                _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected()
-            ),
-            pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.002, _generate_float_expected()),
-            pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.001, _generate_float_expected()),
-        ],
-    )
-    def test_complete_iou(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None:
-        self._run_test(test_input, dtypes, tolerance, expected)
+        focal_loss.backward()
+        ce_loss.backward()
+        torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad)
 
-    def test_ciou_jit(self) -> None:
-        self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])
+    @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
+    @pytest.mark.parametrize("gamma", [0, 2])
+    @pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
+    @pytest.mark.parametrize("seed", [4, 5])
+    def test_jit(self, alpha, gamma, reduction, device, dtype, seed):
+        if device == "cpu" and dtype is torch.half:
+            pytest.skip("Currently torch.half is not fully supported on cpu")
+        script_fn = torch.jit.script(ops.sigmoid_focal_loss)
+        torch.random.manual_seed(seed)
+        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
+        focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
+        if device == "cpu":
+            scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
+        else:
+            with torch.jit.fuser("fuser2"):
+                # Use fuser2 to prevent a bug in the fuser: https://github.com/pytorch/pytorch/issues/75476
+                # We may remove this condition once the bug is resolved
+                scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
+
+        tol = 1e-3 if dtype is torch.half else 1e-5
+        torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol)
 
 
 class TestMasksToBoxes:
     def test_masks_box(self):
-        def masks_box_check(masks, expected, tolerance=1e-4):
+        def masks_box_check(masks, expected, atol=1e-4):
             out = ops.masks_to_boxes(masks)
             assert out.dtype == torch.float
-            torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)
+            torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=True, atol=atol)
 
         # Check for int type boxes.
def _get_image(): @@ -1579,227 +1656,5 @@ def test_is_leaf_node(self, dim, p, block_size, inplace): assert len(graph_node_names[0]) == 1 + op_obj.n_inputs -class TestFocalLoss: - def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs): - def logit(p: Tensor) -> Tensor: - return torch.log(p / (1 - p)) - - def generate_tensor_with_range_type(shape, range_type, **kwargs): - if range_type != "random_binary": - low, high = { - "small": (0.0, 0.2), - "big": (0.8, 1.0), - "zeros": (0.0, 0.0), - "ones": (1.0, 1.0), - "random": (0.0, 1.0), - }[range_type] - return torch.testing.make_tensor(shape, low=low, high=high, **kwargs) - else: - return torch.randint(0, 2, shape, **kwargs) - - # This function will return inputs and targets with shape: (shape[0]*9, shape[1]) - inputs = [] - targets = [] - for input_range_type, target_range_type in [ - ("small", "zeros"), - ("small", "ones"), - ("small", "random_binary"), - ("big", "zeros"), - ("big", "ones"), - ("big", "random_binary"), - ("random", "zeros"), - ("random", "ones"), - ("random", "random_binary"), - ]: - inputs.append(logit(generate_tensor_with_range_type(shape, input_range_type, **kwargs))) - targets.append(generate_tensor_with_range_type(shape, target_range_type, **kwargs)) - - return torch.cat(inputs), torch.cat(targets) - - @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0]) - @pytest.mark.parametrize("gamma", [0, 2]) - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - @pytest.mark.parametrize("seed", [0, 1]) - def test_correct_ratio(self, alpha, gamma, device, dtype, seed) -> None: - if device == "cpu" and dtype is torch.half: - pytest.skip("Currently torch.half is not fully supported on cpu") - # For testing the ratio with manual calculation, we require the reduction to be "none" - reduction = "none" - torch.random.manual_seed(seed) - inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) - focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) - ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction) - - assert torch.all( - focal_loss <= ce_loss - ), "focal loss must be less or equal to cross entropy loss with same input" - - loss_ratio = (focal_loss / ce_loss).squeeze() - prob = torch.sigmoid(inputs) - p_t = prob * targets + (1 - prob) * (1 - targets) - correct_ratio = (1.0 - p_t) ** gamma - if alpha >= 0: - alpha_t = alpha * targets + (1 - alpha) * (1 - targets) - correct_ratio = correct_ratio * alpha_t - - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(correct_ratio, loss_ratio, rtol=tol, atol=tol) - - @pytest.mark.parametrize("reduction", ["mean", "sum"]) - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - @pytest.mark.parametrize("seed", [2, 3]) - def test_equal_ce_loss(self, reduction, device, dtype, seed) -> None: - if device == "cpu" and dtype is torch.half: - pytest.skip("Currently torch.half is not fully supported on cpu") - # focal loss should be equal ce_loss if alpha=-1 and gamma=0 - alpha = -1 - gamma = 0 - torch.random.manual_seed(seed) - inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) - inputs_fl = inputs.clone().requires_grad_() - targets_fl = targets.clone() - inputs_ce = inputs.clone().requires_grad_() - targets_ce = targets.clone() - focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, 
gamma=gamma, alpha=alpha, reduction=reduction) - ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction) - - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(focal_loss, ce_loss, rtol=tol, atol=tol) - - focal_loss.backward() - ce_loss.backward() - torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad, rtol=tol, atol=tol) - - @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0]) - @pytest.mark.parametrize("gamma", [0, 2]) - @pytest.mark.parametrize("reduction", ["none", "mean", "sum"]) - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - @pytest.mark.parametrize("seed", [4, 5]) - def test_jit(self, alpha, gamma, reduction, device, dtype, seed) -> None: - if device == "cpu" and dtype is torch.half: - pytest.skip("Currently torch.half is not fully supported on cpu") - script_fn = torch.jit.script(ops.sigmoid_focal_loss) - torch.random.manual_seed(seed) - inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) - focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) - if device == "cpu": - scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) - else: - with torch.jit.fuser("fuser2"): - # Use fuser2 to prevent a bug on fuser: https://github.com/pytorch/pytorch/issues/75476 - # We may remove this condition once the bug is resolved - scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) - - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol) - - -class TestGeneralizedBoxIouLoss: - # We refer to original test: https://github.com/facebookresearch/fvcore/blob/main/tests/test_giou_loss.py - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - def test_giou_loss(self, dtype, device) -> None: - box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) - box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) - box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device) - box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device) - - box1s = torch.stack([box2, box2], dim=0) - box2s = torch.stack([box3, box4], dim=0) - - def assert_giou_loss(box1, box2, expected_loss, reduction="none"): - tol = 1e-3 if dtype is torch.half else 1e-5 - computed_loss = ops.generalized_box_iou_loss(box1, box2, reduction=reduction) - expected_loss = torch.tensor(expected_loss, device=device) - torch.testing.assert_close(computed_loss, expected_loss, rtol=tol, atol=tol) - - # Identical boxes should have loss of 0 - assert_giou_loss(box1, box1, 0.0) - - # quarter size box inside other box = IoU of 0.25 - assert_giou_loss(box1, box2, 0.75) - - # Two side by side boxes, area=union - # IoU=0 and GIoU=0 (loss 1.0) - assert_giou_loss(box2, box3, 1.0) - - # Two diagonally adjacent boxes, area=2*union - # IoU=0 and GIoU=-0.5 (loss 1.5) - assert_giou_loss(box2, box4, 1.5) - - # Test batched loss and reductions - assert_giou_loss(box1s, box2s, 2.5, reduction="sum") - assert_giou_loss(box1s, box2s, 1.25, reduction="mean") - - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - def test_empty_inputs(self, dtype, device) -> None: - box1 = torch.randn([0, 4], dtype=dtype).requires_grad_() - box2 = torch.randn([0, 4], 
dtype=dtype).requires_grad_() - - loss = ops.generalized_box_iou_loss(box1, box2, reduction="mean") - loss.backward() - - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(loss, torch.tensor(0.0), rtol=tol, atol=tol) - assert box1.grad is not None, "box1.grad should not be None after backward is called" - assert box2.grad is not None, "box2.grad should not be None after backward is called" - - loss = ops.generalized_box_iou_loss(box1, box2, reduction="none") - assert loss.numel() == 0, "giou_loss for two empty box should be empty" - - -class TestCIOULoss: - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - @pytest.mark.parametrize("device", cpu_and_gpu()) - def test_ciou_loss(self, dtype, device): - box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) - box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) - box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device) - box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device) - - box1s = torch.stack([box2, box2], dim=0) - box2s = torch.stack([box3, box4], dim=0) - - def assert_ciou_loss(box1, box2, expected_output, reduction="none"): - - output = ops.complete_box_iou_loss(box1, box2, reduction=reduction) - # TODO: When passing the dtype, the torch.half test doesn't pass... - expected_output = torch.tensor(expected_output, device=device) - tol = 1e-5 if dtype != torch.half else 1e-3 - torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol) - - assert_ciou_loss(box1, box1, 0.0) - - assert_ciou_loss(box1, box2, 0.8125) - - assert_ciou_loss(box1, box3, 1.1923) - - assert_ciou_loss(box1, box4, 1.2500) - - assert_ciou_loss(box1s, box2s, 1.2250, reduction="mean") - assert_ciou_loss(box1s, box2s, 2.4500, reduction="sum") - - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - def test_empty_inputs(self, dtype, device) -> None: - box1 = torch.randn([0, 4], dtype=dtype).requires_grad_() - box2 = torch.randn([0, 4], dtype=dtype).requires_grad_() - - loss = ops.complete_box_iou_loss(box1, box2, reduction="mean") - loss.backward() - - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(loss, torch.tensor(0.0), rtol=tol, atol=tol) - assert box1.grad is not None, "box1.grad should not be None after backward is called" - assert box2.grad is not None, "box2.grad should not be None after backward is called" - - loss = ops.complete_box_iou_loss(box1, box2, reduction="none") - assert loss.numel() == 0, "ciou_loss for two empty box should be empty" - - if __name__ == "__main__": pytest.main([__file__])