Add test for sigmoid_focal_loss #5783

Merged: merged 9 commits on Apr 8, 2022
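
For context, a minimal sketch of the property these tests exercise (values are illustrative only): elementwise, sigmoid_focal_loss scales binary cross-entropy with logits by alpha_t * (1 - p_t) ** gamma, so with alpha=-1 and gamma=0 it reduces to F.binary_cross_entropy_with_logits.

    import torch
    import torch.nn.functional as F
    from torchvision import ops

    inputs = torch.randn(4, 2)  # raw logits
    targets = torch.randint(0, 2, (4, 2)).float()  # binary targets

    # alpha=-1 disables alpha weighting and gamma=0 makes the modulating
    # factor equal to 1, so the focal loss matches plain BCE-with-logits.
    fl = ops.sigmoid_focal_loss(inputs, targets, alpha=-1, gamma=0, reduction="none")
    bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    torch.testing.assert_close(fl, bce)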
119 changes: 119 additions & 0 deletions test/test_ops.py
@@ -9,6 +9,7 @@
import pytest
import torch
import torch.fx
import torch.nn.functional as F
from common_utils import assert_equal, cpu_and_gpu, needs_cuda
from PIL import Image
from torch import nn, Tensor
@@ -1450,5 +1451,123 @@ def test_is_leaf_node(self, dim, p, block_size, inplace):
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs


class TestFocalLoss:
    def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs):
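        # logit is the inverse of sigmoid: it recovers the raw score that
        # sigmoid_focal_loss expects from a probability in (0, 1).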
        def logit(p: Tensor) -> Tensor:
            return torch.log(p / (1 - p))

        def generate_tensor_with_range_type(shape, range_type, **kwargs):
            if range_type != "random_binary":
                low, high = {
                    "small": (0.0, 0.2),
                    "big": (0.8, 1.0),
                    "zeros": (0.0, 0.0),
                    "ones": (1.0, 1.0),
                    "random": (0.0, 1.0),
                }[range_type]
                return torch.testing.make_tensor(shape, low=low, high=high, **kwargs)
            else:
                return torch.randint(0, 2, shape, **kwargs)

        # This function will return inputs and targets with shape: (shape[0]*9, shape[1])
        inputs = []
        targets = []
        for input_range_type, target_range_type in [
            ("small", "zeros"),
            ("small", "ones"),
            ("small", "random_binary"),
            ("big", "zeros"),
            ("big", "ones"),
            ("big", "random_binary"),
            ("random", "zeros"),
            ("random", "ones"),
            ("random", "random_binary"),
        ]:
            inputs.append(logit(generate_tensor_with_range_type(shape, input_range_type, **kwargs)))
            targets.append(generate_tensor_with_range_type(shape, target_range_type, **kwargs))

        return torch.cat(inputs), torch.cat(targets)

@pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
@pytest.mark.parametrize("gamma", [0, 2])
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
@pytest.mark.parametrize("seed", [0, 1])
def test_correct_ratio(self, alpha, gamma, device, dtype, seed) -> None:
if device == "cpu" and dtype is torch.half:
pytest.skip("Currently torch.half is not fully supported on cpu")
# For testing the ratio with manual calculation, we require the reduction to be "none"
reduction = "none"
torch.random.manual_seed(seed)
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction)
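        # Elementwise, FL = alpha_t * (1 - p_t) ** gamma * CE with CE = -log(p_t)
        # (Lin et al., "Focal Loss for Dense Object Detection"), so the ratio
        # focal_loss / ce_loss must equal alpha_t * (1 - p_t) ** gamma.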

        assert torch.all(
            focal_loss <= ce_loss
        ), "focal loss must be less than or equal to cross entropy loss with the same input"

        loss_ratio = (focal_loss / ce_loss).squeeze()
        prob = torch.sigmoid(inputs)
        p_t = prob * targets + (1 - prob) * (1 - targets)
        correct_ratio = (1.0 - p_t) ** gamma
        if alpha >= 0:
            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
            correct_ratio = correct_ratio * alpha_t

        tol = 1e-3 if dtype is torch.half else 1e-5
        torch.testing.assert_close(correct_ratio, loss_ratio, rtol=tol, atol=tol)

@pytest.mark.parametrize("reduction", ["mean", "sum"])
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
@pytest.mark.parametrize("seed", [2, 3])
def test_equal_ce_loss(self, reduction, device, dtype, seed) -> None:
if device == "cpu" and dtype is torch.half:
pytest.skip("Currently torch.half is not fully supported on cpu")
# focal loss should be equal ce_loss if alpha=-1 and gamma=0
alpha = -1
gamma = 0
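        # gamma=0 makes the modulating factor (1 - p_t) ** gamma equal to 1, and a
        # negative alpha disables alpha weighting, so only plain BCE remains.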
        torch.random.manual_seed(seed)
        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
        inputs_fl = inputs.clone().requires_grad_()
        targets_fl = targets.clone()
        inputs_ce = inputs.clone().requires_grad_()
        targets_ce = targets.clone()
        focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction)
        ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction)

        tol = 1e-3 if dtype is torch.half else 1e-5
        torch.testing.assert_close(focal_loss, ce_loss, rtol=tol, atol=tol)

        focal_loss.backward()
        ce_loss.backward()
        torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad, rtol=tol, atol=tol)

    @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
    @pytest.mark.parametrize("gamma", [0, 2])
    @pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
    @pytest.mark.parametrize("device", cpu_and_gpu())
    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
    @pytest.mark.parametrize("seed", [4, 5])
    def test_jit(self, alpha, gamma, reduction, device, dtype, seed) -> None:
        if device == "cpu" and dtype is torch.half:
            pytest.skip("Currently torch.half is not fully supported on cpu")
        script_fn = torch.jit.script(ops.sigmoid_focal_loss)
        torch.random.manual_seed(seed)
        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
        focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
        if device == "cpu":
            scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
        else:
            with torch.jit.fuser("fuser2"):
                # Use fuser2 to work around a bug in the default fuser: https://github.com/pytorch/pytorch/issues/75476
                # We may remove this condition once the bug is resolved
                scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)

        tol = 1e-3 if dtype is torch.half else 1e-5
        torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol)


if __name__ == "__main__":
pytest.main([__file__])
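
To run just the new tests locally, a pytest keyword filter along these lines should do (path assumed from the diff above):

    pytest test/test_ops.py -k TestFocalLoss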