From c3f3a7518cfb04619b465d443a1c108dbf5fb29f Mon Sep 17 00:00:00 2001
From: Yosua Michael Maranatha
Date: Thu, 7 Apr 2022 13:34:23 +0100
Subject: [PATCH 1/8] Add test for sigmoid_focal_loss

---
 test/test_ops.py | 136 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 136 insertions(+)

diff --git a/test/test_ops.py b/test/test_ops.py
index ad9aaefee52..290c400dfc3 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -9,6 +9,7 @@
 import pytest
 import torch
 import torch.fx
+import torch.nn.functional as F
 from common_utils import assert_equal, cpu_and_gpu, needs_cuda
 from PIL import Image
 from torch import nn, Tensor
@@ -1450,5 +1451,140 @@ def test_is_leaf_node(self, dim, p, block_size, inplace):
         assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
 
 
+def focal_loss_device_and_dtype():
+    # This is a helper function to provide a list of (device, dtype) pairs for TestFocalLoss
+    # We need this because torch.half is not fully supported on cpu
+    # For instance, torch.log(torch.rand(1, dtype=torch.half)) will currently produce an error
+    def augment_need_cuda_param(x):
+        if x[0] == "cuda":
+            return pytest.param(x, marks=pytest.mark.needs_cuda)
+        return x
+
+    result = [
+        (device, dtype)
+        for device in ["cpu", "cuda"]
+        for dtype in [torch.float32, torch.half]
+        if (device, dtype) != ("cpu", torch.half)
+    ]
+    result = [augment_need_cuda_param(x) for x in result]
+    return result
+
+
+class TestFocalLoss:
+    def _logit(self, p: Tensor) -> Tensor:
+        return torch.log(p / (1 - p))
+
+    def _get_subt(self, p: float, targets: Tensor):
+        return p * targets + (1 - p) * targets
+
+    def _generate_tensor_with_range_type(self, shape, range_type, **kwargs):
+        if range_type != "random_binary":
+            range_map = {
+                "small": (0.0, 0.2),
+                "big": (0.8, 1.0),
+                "zeros": (0.0, 0.0),
+                "ones": (1.0, 1.0),
+                "random": (0.0, 1.0),
+            }
+            range = range_map[range_type]
+            return (range[1] - range[0]) * torch.rand(shape, **kwargs) + range[0]
+        else:
+            return torch.randint(0, 2, shape, **kwargs)
+
+    def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs):
+        # This function will return inputs and targets with shape: (shape[0]*9, shape[1])
+        inputs_targets_range_type = [
+            ("small", "zeros"),
+            ("small", "ones"),
+            ("small", "random_binary"),
+            ("big", "zeros"),
+            ("big", "ones"),
+            ("big", "random_binary"),
+            ("random", "zeros"),
+            ("random", "ones"),
+            ("random", "random_binary"),
+        ]
+        inputs = self._logit(
+            torch.concat(
+                [
+                    self._generate_tensor_with_range_type(shape, x[0], **kwargs).squeeze()
+                    for x in inputs_targets_range_type
+                ]
+            )
+        )
+        targets = torch.concat(
+            [self._generate_tensor_with_range_type(shape, x[1], **kwargs).squeeze() for x in inputs_targets_range_type]
+        )
+        return inputs, targets
+
+    @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
+    @pytest.mark.parametrize("gamma", [0, 2])
+    @pytest.mark.parametrize("device_dtype", focal_loss_device_and_dtype())
+    @pytest.mark.parametrize("seed", [0, 1])
+    def test_correct_ratio(self, alpha, gamma, device_dtype, seed) -> None:
+        # For testing the ratio with manual calculation, we require the reduction to be "none"
+        reduction = "none"
+        device, dtype = device_dtype
+        torch.random.manual_seed(seed)
+        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
+        focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
+        ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction)
+
+        assert torch.all(focal_loss <= ce_loss)
+
+        loss_ratio = (focal_loss / ce_loss).squeeze()
+        prob = torch.sigmoid(inputs)
+        p_t = prob * targets + (1 - prob) * (1 - targets)
+        correct_ratio = (1.0 - p_t) ** gamma
+        if alpha >= 0:
+            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+            correct_ratio = correct_ratio * alpha_t
+
+        tol = 1e-3 if dtype is torch.half else 1e-5
+        torch.testing.assert_close(correct_ratio, loss_ratio, rtol=tol, atol=tol)
+
+    @pytest.mark.parametrize("reduction", ["mean", "sum"])
+    @pytest.mark.parametrize("device_dtype", focal_loss_device_and_dtype())
+    @pytest.mark.parametrize("seed", [2, 3])
+    def test_equal_ce_loss(self, reduction, device_dtype, seed) -> None:
+        # focal loss should equal ce_loss if alpha=-1 and gamma=0
+        alpha = -1
+        gamma = 0
+        device, dtype = device_dtype
+        torch.random.manual_seed(seed)
+        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
+        inputs_fl = inputs.clone().requires_grad_()
+        targets_fl = targets.clone()
+        inputs_ce = inputs.clone().requires_grad_()
+        targets_ce = targets.clone()
+        focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction)
+        ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction)
+
+        assert torch.all(focal_loss <= ce_loss)
+
+        tol = 1e-3 if dtype is torch.half else 1e-5
+        torch.testing.assert_close(focal_loss.data, ce_loss.data, rtol=tol, atol=tol)
+
+        focal_loss.backward()
+        ce_loss.backward()
+        torch.testing.assert_close(inputs_fl.grad.data, inputs_ce.grad.data, rtol=tol, atol=tol)
+
+    @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
+    @pytest.mark.parametrize("gamma", [0, 2])
+    @pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
+    @pytest.mark.parametrize("device_dtype", focal_loss_device_and_dtype())
+    @pytest.mark.parametrize("seed", [4, 5])
+    def test_jit(self, alpha, gamma, reduction, device_dtype, seed) -> None:
+        device, dtype = device_dtype
+        script_fn = torch.jit.script(ops.sigmoid_focal_loss)
+        torch.random.manual_seed(seed)
+        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
+        focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
+        scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
+
+        tol = 1e-3 if dtype is torch.half else 1e-5
+        torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

From 4810e28ff3e45e9309ec0c3a214832b706798d09 Mon Sep 17 00:00:00 2001
From: YosuaMichael
Date: Thu, 7 Apr 2022 15:21:34 +0100
Subject: [PATCH 2/8] Update test/test_ops.py

Improve code by using torch.testing.make_tensor

Co-authored-by: Philip Meier
---
 test/test_ops.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index 290c400dfc3..a8ed818a549 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1479,15 +1479,14 @@ def _get_subt(self, p: float, targets: Tensor):
 
     def _generate_tensor_with_range_type(self, shape, range_type, **kwargs):
         if range_type != "random_binary":
-            range_map = {
+            low, high = {
                 "small": (0.0, 0.2),
                 "big": (0.8, 1.0),
                 "zeros": (0.0, 0.0),
                 "ones": (1.0, 1.0),
                 "random": (0.0, 1.0),
-            }
-            range = range_map[range_type]
-            return (range[1] - range[0]) * torch.rand(shape, **kwargs) + range[0]
+            }[range_type]
+            return torch.testing.make_tensor(shape, low=low, high=high, **kwargs)
         else:
             return torch.randint(0, 2, shape, **kwargs)
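
Note: the ratio check in test_correct_ratio above relies on torchvision implementing sigmoid focal loss as an elementwise rescaling of binary cross-entropy, focal = alpha_t * (1 - p_t) ** gamma * ce. A minimal standalone sketch of that identity (this assumes an environment with torchvision installed; the shapes, seed, and alpha/gamma values are illustrative, not taken from the patches):

    import torch
    import torch.nn.functional as F
    from torchvision import ops

    torch.random.manual_seed(0)
    inputs = torch.randn(8, 2)                     # raw logits
    targets = torch.randint(0, 2, (8, 2)).float()  # binary targets
    alpha, gamma = 0.25, 2.0

    focal = ops.sigmoid_focal_loss(inputs, targets, alpha=alpha, gamma=gamma, reduction="none")

    # Recompute the same quantity by rescaling the BCE loss elementwise.
    p = torch.sigmoid(inputs)
    p_t = p * targets + (1 - p) * (1 - targets)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    ce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

    torch.testing.assert_close(focal, alpha_t * (1 - p_t) ** gamma * ce)
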
From 5c92ee78411776e5b5f60a1be1c6b186c2c87c57 Mon Sep 17 00:00:00 2001
From: YosuaMichael
Date: Thu, 7 Apr 2022 15:22:38 +0100
Subject: [PATCH 3/8] Update test/test_ops.py

Remove unnecessary assert

Co-authored-by: Philip Meier
---
 test/test_ops.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index a8ed818a549..8b21a16e264 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1559,8 +1559,6 @@ def test_equal_ce_loss(self, reduction, device_dtype, seed) -> None:
         focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction)
         ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction)
 
-        assert torch.all(focal_loss <= ce_loss)
-
         tol = 1e-3 if dtype is torch.half else 1e-5
         torch.testing.assert_close(focal_loss.data, ce_loss.data, rtol=tol, atol=tol)

From 98fe32eb1f2fcce43571bd1ad1b3b7879563cb36 Mon Sep 17 00:00:00 2001
From: YosuaMichael
Date: Thu, 7 Apr 2022 15:23:11 +0100
Subject: [PATCH 4/8] Update test/test_ops.py

Refactor code for generating inputs and targets

Co-authored-by: Philip Meier
---
 test/test_ops.py | 22 ++++++++--------------
 1 file changed, 8 insertions(+), 14 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index 8b21a16e264..ac8865cc2d7 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1492,7 +1492,9 @@ def _generate_tensor_with_range_type(self, shape, range_type, **kwargs):
 
     def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs):
         # This function will return inputs and targets with shape: (shape[0]*9, shape[1])
-        inputs_targets_range_type = [
+        inputs = []
+        targets = []
+        for input_range_type, target_range_type in [
             ("small", "zeros"),
             ("small", "ones"),
             ("small", "random_binary"),
@@ -1502,19 +1504,11 @@ def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs):
             ("random", "zeros"),
             ("random", "ones"),
             ("random", "random_binary"),
-        ]
-        inputs = self._logit(
-            torch.concat(
-                [
-                    self._generate_tensor_with_range_type(shape, x[0], **kwargs).squeeze()
-                    for x in inputs_targets_range_type
-                ]
-            )
-        )
-        targets = torch.concat(
-            [self._generate_tensor_with_range_type(shape, x[1], **kwargs).squeeze() for x in inputs_targets_range_type]
-        )
-        return inputs, targets
+        ]:
+            inputs.append(self._logit(self._generate_tensor_with_range_type(shape, input_range_type, **kwargs)))
+            targets.append(self._generate_tensor_with_range_type(shape, target_range_type, **kwargs))
+
+        return torch.cat(inputs), torch.cat(targets)

From c4720b47d0bd0fa293504e4fd2dcc6dc1c4b0577 Mon Sep 17 00:00:00 2001
From: Yosua Michael Maranatha
Date: Thu, 7 Apr 2022 17:29:29 +0100
Subject: [PATCH 5/8] Improve focal_loss test code as suggested in comments by Philip

---
 test/test_ops.py | 54 ++++++++++++++++++------------------------------
 1 file changed, 20 insertions(+), 34 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index ac8865cc2d7..bf742224ff0 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1451,32 +1451,10 @@ def test_is_leaf_node(self, dim, p, block_size, inplace):
         assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
 
 
-def focal_loss_device_and_dtype():
-    # This is a helper function to provide a list of (device, dtype) pairs for TestFocalLoss
-    # We need this because torch.half is not fully supported on cpu
-    # For instance, torch.log(torch.rand(1, dtype=torch.half)) will currently produce an error
-    def augment_need_cuda_param(x):
-        if x[0] == "cuda":
-            return pytest.param(x, marks=pytest.mark.needs_cuda)
-        return x
-
-    result = [
-        (device, dtype)
-        for device in ["cpu", "cuda"]
-        for dtype in [torch.float32, torch.half]
-        if (device, dtype) != ("cpu", torch.half)
-    ]
-    result = [augment_need_cuda_param(x) for x in result]
-    return result
-
-
 class TestFocalLoss:
     def _logit(self, p: Tensor) -> Tensor:
         return torch.log(p / (1 - p))
 
-    def _get_subt(self, p: float, targets: Tensor):
-        return p * targets + (1 - p) * targets
-
     def _generate_tensor_with_range_type(self, shape, range_type, **kwargs):
         if range_type != "random_binary":
             low, high = {
@@ -1512,18 +1490,22 @@ def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs):
 
     @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
     @pytest.mark.parametrize("gamma", [0, 2])
-    @pytest.mark.parametrize("device_dtype", focal_loss_device_and_dtype())
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
     @pytest.mark.parametrize("seed", [0, 1])
-    def test_correct_ratio(self, alpha, gamma, device_dtype, seed) -> None:
+    def test_correct_ratio(self, alpha, gamma, device, dtype, seed) -> None:
+        if device == "cpu" and dtype is torch.half:
+            pytest.skip("Currently torch.half is not fully supported on cpu")
         # For testing the ratio with manual calculation, we require the reduction to be "none"
         reduction = "none"
-        device, dtype = device_dtype
         torch.random.manual_seed(seed)
         inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
         focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
         ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction)
 
-        assert torch.all(focal_loss <= ce_loss)
+        assert torch.all(
+            focal_loss <= ce_loss
+        ), "focal loss must be less than or equal to cross entropy loss with the same input"
 
         loss_ratio = (focal_loss / ce_loss).squeeze()
         prob = torch.sigmoid(inputs)
@@ -1537,13 +1519,15 @@ def test_correct_ratio(self, alpha, gamma, device_dtype, seed) -> None:
         torch.testing.assert_close(correct_ratio, loss_ratio, rtol=tol, atol=tol)
 
     @pytest.mark.parametrize("reduction", ["mean", "sum"])
-    @pytest.mark.parametrize("device_dtype", focal_loss_device_and_dtype())
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("seed", [2, 3])
-    def test_equal_ce_loss(self, reduction, device_dtype, seed) -> None:
+    def test_equal_ce_loss(self, reduction, device, dtype, seed) -> None:
+        if device == "cpu" and dtype is torch.half:
+            pytest.skip("Currently torch.half is not fully supported on cpu")
         # focal loss should equal ce_loss if alpha=-1 and gamma=0
         alpha = -1
         gamma = 0
-        device, dtype = device_dtype
         torch.random.manual_seed(seed)
         inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
         inputs_fl = inputs.clone().requires_grad_()
         targets_fl = targets.clone()
         inputs_ce = inputs.clone().requires_grad_()
         targets_ce = targets.clone()
         focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction)
         ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction)
 
         tol = 1e-3 if dtype is torch.half else 1e-5
-        torch.testing.assert_close(focal_loss.data, ce_loss.data, rtol=tol, atol=tol)
+        torch.testing.assert_close(focal_loss, ce_loss, rtol=tol, atol=tol)
 
         focal_loss.backward()
         ce_loss.backward()
-        torch.testing.assert_close(inputs_fl.grad.data, inputs_ce.grad.data, rtol=tol, atol=tol)
+        torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad, rtol=tol, atol=tol)
 
     @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
     @pytest.mark.parametrize("gamma", [0, 2])
     @pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
-    @pytest.mark.parametrize("device_dtype", focal_loss_device_and_dtype())
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
     @pytest.mark.parametrize("seed", [4, 5])
-    def test_jit(self, alpha, gamma, reduction, device_dtype, seed) -> None:
-        device, dtype = device_dtype
+    def test_jit(self, alpha, gamma, reduction, device, dtype, seed) -> None:
+        if device == "cpu" and dtype is torch.half:
+            pytest.skip("Currently torch.half is not fully supported on cpu")
         script_fn = torch.jit.script(ops.sigmoid_focal_loss)
         torch.random.manual_seed(seed)
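
Note: patch 5 above swaps the custom focal_loss_device_and_dtype helper for torchvision's cpu_and_gpu() parametrization plus an explicit pytest.skip. The mechanism being replaced, attaching a skip mark to only the CUDA parameter, can be sketched in plain pytest as follows (a rough approximation; cpu_and_gpu and needs_cuda are torchvision test utilities, emulated here with a standard skipif mark):

    import pytest
    import torch

    def cpu_and_gpu():
        # Plain "cpu" always runs; "cuda" carries a mark so it is skipped
        # automatically on machines without a GPU.
        needs_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA")
        return ("cpu", pytest.param("cuda", marks=needs_cuda))

    @pytest.mark.parametrize("device", cpu_and_gpu())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    def test_example(device, dtype):
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        assert torch.rand(4, device=device, dtype=dtype).shape == (4,)
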
From 38fb25f35f219a6e7b7785f764162a2ca1315b2f Mon Sep 17 00:00:00 2001
From: Yosua Michael M
Date: Fri, 8 Apr 2022 10:29:53 +0100
Subject: [PATCH 6/8] Use fuser2 to prevent a fuser bug

---
 test/test_ops.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index bf742224ff0..02b6d053816 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1557,7 +1557,10 @@ def test_jit(self, alpha, gamma, reduction, device, dtype, seed) -> None:
         torch.random.manual_seed(seed)
         inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
         focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
-        scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
+        with torch.jit.fuser("fuser2"):
+            # Use fuser2 to prevent the bug from fuser that will be triggered in following condition:
+            # dtype=torch.half, device="cuda"
+            scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
 
         tol = 1e-3 if dtype is torch.half else 1e-5
         torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol)

From 8c324ac31fc6727eddeb6c9bb17c5e72f83c21f0 Mon Sep 17 00:00:00 2001
From: Yosua Michael Maranatha
Date: Fri, 8 Apr 2022 11:24:44 +0100
Subject: [PATCH 7/8] Combine functions to generate input, don't set the fuser when device is cpu

---
 test/test_ops.py | 45 ++++++++++++++++++++++++---------------------
 1 file changed, 24 insertions(+), 21 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index 02b6d053816..07d029a55d9 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1452,23 +1452,23 @@ def test_is_leaf_node(self, dim, p, block_size, inplace):
 
 
 class TestFocalLoss:
-    def _logit(self, p: Tensor) -> Tensor:
-        return torch.log(p / (1 - p))
-
-    def _generate_tensor_with_range_type(self, shape, range_type, **kwargs):
-        if range_type != "random_binary":
-            low, high = {
-                "small": (0.0, 0.2),
-                "big": (0.8, 1.0),
-                "zeros": (0.0, 0.0),
-                "ones": (1.0, 1.0),
-                "random": (0.0, 1.0),
-            }[range_type]
-            return torch.testing.make_tensor(shape, low=low, high=high, **kwargs)
-        else:
-            return torch.randint(0, 2, shape, **kwargs)
-
     def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs):
+        def logit(p: Tensor) -> Tensor:
+            return torch.log(p / (1 - p))
+
+        def generate_tensor_with_range_type(shape, range_type, **kwargs):
+            if range_type != "random_binary":
+                low, high = {
+                    "small": (0.0, 0.2),
+                    "big": (0.8, 1.0),
+                    "zeros": (0.0, 0.0),
+                    "ones": (1.0, 1.0),
+                    "random": (0.0, 1.0),
+                }[range_type]
+                return torch.testing.make_tensor(shape, low=low, high=high, **kwargs)
+            else:
+                return torch.randint(0, 2, shape, **kwargs)
+
         # This function will return inputs and targets with shape: (shape[0]*9, shape[1])
         inputs = []
         targets = []
@@ -1483,8 +1483,8 @@ def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs):
             ("random", "ones"),
             ("random", "random_binary"),
         ]:
-            inputs.append(self._logit(self._generate_tensor_with_range_type(shape, input_range_type, **kwargs)))
-            targets.append(self._generate_tensor_with_range_type(shape, target_range_type, **kwargs))
+            inputs.append(logit(generate_tensor_with_range_type(shape, input_range_type, **kwargs)))
+            targets.append(generate_tensor_with_range_type(shape, target_range_type, **kwargs))
 
         return torch.cat(inputs), torch.cat(targets)
@@ -1557,10 +1557,13 @@ def test_jit(self, alpha, gamma, reduction, device, dtype, seed) -> None:
         torch.random.manual_seed(seed)
         inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
         focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
-        with torch.jit.fuser("fuser2"):
-            # Use fuser2 to prevent the bug from fuser that will be triggered in following condition:
-            # dtype=torch.half, device="cuda"
+        if device == "cpu":
             scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
+        else:
+            with torch.jit.fuser("fuser2"):
+                # Use fuser2 to prevent the bug from fuser that will be triggered in following condition:
+                # dtype=torch.half, device="cuda"
+                scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)

From 22a22f52cd822079aeb7cbd7395c5ec50612ef82 Mon Sep 17 00:00:00 2001
From: Yosua Michael Maranatha
Date: Fri, 8 Apr 2022 14:19:23 +0100
Subject: [PATCH 8/8] Add GitHub issue for the fuser problem

---
 test/test_ops.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index 07d029a55d9..071f079e97c 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1561,8 +1561,8 @@ def test_jit(self, alpha, gamma, reduction, device, dtype, seed) -> None:
             scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
         else:
             with torch.jit.fuser("fuser2"):
-                # Use fuser2 to prevent the bug from fuser that will be triggered in following condition:
-                # dtype=torch.half, device="cuda"
+                # Use fuser2 to prevent a fuser bug: https://github.com/pytorch/pytorch/issues/75476
+                # We may remove this condition once the bug is resolved
                 scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
 
         tol = 1e-3 if dtype is torch.half else 1e-5
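
Note: the final form of test_jit scripts ops.sigmoid_focal_loss with TorchScript and compares it against eager execution, selecting the "fuser2" (nvFuser) backend on CUDA to sidestep https://github.com/pytorch/pytorch/issues/75476. A condensed standalone sketch of that comparison (assumes torchvision is installed; shapes and parameters are illustrative):

    import torch
    from torchvision import ops

    script_fn = torch.jit.script(ops.sigmoid_focal_loss)

    torch.random.manual_seed(0)
    inputs = torch.randn(16, 2)
    targets = torch.randint(0, 2, (16, 2)).float()

    eager = ops.sigmoid_focal_loss(inputs, targets, gamma=2, alpha=0.25, reduction="mean")
    # On CPU no fuser override is needed; on CUDA the tests wrap this call in
    # torch.jit.fuser("fuser2") to avoid the issue linked above.
    scripted = script_fn(inputs, targets, gamma=2, alpha=0.25, reduction="mean")

    torch.testing.assert_close(eager, scripted)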