Skip to content

Commit 186d0fe

Browse files
YosuaMichaelpmeier
authored andcommitted
[fbsync] Add test for sigmoid_focal_loss (#5783)
Summary: * Add test for sigmoid_focal_loss * Update test/test_ops.py Improve code by using torch.testing.make_tensor * Update test/test_ops.py Remove unnecessary assert * Update test/test_ops.py Refactor code for generating inputs and targets * Improve focal_loss test code suggested on comment by Philip * Use fuser2 to prevent fuser bug * Combine function to generate input, dont set the fuser when device is cpu * Add github issue for the fuser problem (Note: this ignores all push blocking failures!) Reviewed By: jdsgomes, NicolasHug Differential Revision: D36095694 fbshipit-source-id: 36ed7ece4188eef159fa7f84fc3ebdc346a63f2d Co-authored-by: Philip Meier <[email protected]> Co-authored-by: Philip Meier <[email protected]> Co-authored-by: Philip Meier <[email protected]> Co-authored-by: Philip Meier <[email protected]>
1 parent db13442 commit 186d0fe

File tree

1 file changed

+119
-0
lines changed

1 file changed

+119
-0
lines changed

test/test_ops.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import pytest
1010
import torch
1111
import torch.fx
12+
import torch.nn.functional as F
1213
from common_utils import assert_equal, cpu_and_gpu, needs_cuda
1314
from PIL import Image
1415
from torch import nn, Tensor
@@ -1450,5 +1451,123 @@ def test_is_leaf_node(self, dim, p, block_size, inplace):
14501451
assert len(graph_node_names[0]) == 1 + op_obj.n_inputs
14511452

14521453

1454+
class TestFocalLoss:
1455+
def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs):
1456+
def logit(p: Tensor) -> Tensor:
1457+
return torch.log(p / (1 - p))
1458+
1459+
def generate_tensor_with_range_type(shape, range_type, **kwargs):
1460+
if range_type != "random_binary":
1461+
low, high = {
1462+
"small": (0.0, 0.2),
1463+
"big": (0.8, 1.0),
1464+
"zeros": (0.0, 0.0),
1465+
"ones": (1.0, 1.0),
1466+
"random": (0.0, 1.0),
1467+
}[range_type]
1468+
return torch.testing.make_tensor(shape, low=low, high=high, **kwargs)
1469+
else:
1470+
return torch.randint(0, 2, shape, **kwargs)
1471+
1472+
# This function will return inputs and targets with shape: (shape[0]*9, shape[1])
1473+
inputs = []
1474+
targets = []
1475+
for input_range_type, target_range_type in [
1476+
("small", "zeros"),
1477+
("small", "ones"),
1478+
("small", "random_binary"),
1479+
("big", "zeros"),
1480+
("big", "ones"),
1481+
("big", "random_binary"),
1482+
("random", "zeros"),
1483+
("random", "ones"),
1484+
("random", "random_binary"),
1485+
]:
1486+
inputs.append(logit(generate_tensor_with_range_type(shape, input_range_type, **kwargs)))
1487+
targets.append(generate_tensor_with_range_type(shape, target_range_type, **kwargs))
1488+
1489+
return torch.cat(inputs), torch.cat(targets)
1490+
1491+
@pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
1492+
@pytest.mark.parametrize("gamma", [0, 2])
1493+
@pytest.mark.parametrize("device", cpu_and_gpu())
1494+
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
1495+
@pytest.mark.parametrize("seed", [0, 1])
1496+
def test_correct_ratio(self, alpha, gamma, device, dtype, seed) -> None:
1497+
if device == "cpu" and dtype is torch.half:
1498+
pytest.skip("Currently torch.half is not fully supported on cpu")
1499+
# For testing the ratio with manual calculation, we require the reduction to be "none"
1500+
reduction = "none"
1501+
torch.random.manual_seed(seed)
1502+
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
1503+
focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
1504+
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction)
1505+
1506+
assert torch.all(
1507+
focal_loss <= ce_loss
1508+
), "focal loss must be less or equal to cross entropy loss with same input"
1509+
1510+
loss_ratio = (focal_loss / ce_loss).squeeze()
1511+
prob = torch.sigmoid(inputs)
1512+
p_t = prob * targets + (1 - prob) * (1 - targets)
1513+
correct_ratio = (1.0 - p_t) ** gamma
1514+
if alpha >= 0:
1515+
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
1516+
correct_ratio = correct_ratio * alpha_t
1517+
1518+
tol = 1e-3 if dtype is torch.half else 1e-5
1519+
torch.testing.assert_close(correct_ratio, loss_ratio, rtol=tol, atol=tol)
1520+
1521+
@pytest.mark.parametrize("reduction", ["mean", "sum"])
1522+
@pytest.mark.parametrize("device", cpu_and_gpu())
1523+
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
1524+
@pytest.mark.parametrize("seed", [2, 3])
1525+
def test_equal_ce_loss(self, reduction, device, dtype, seed) -> None:
1526+
if device == "cpu" and dtype is torch.half:
1527+
pytest.skip("Currently torch.half is not fully supported on cpu")
1528+
# focal loss should be equal ce_loss if alpha=-1 and gamma=0
1529+
alpha = -1
1530+
gamma = 0
1531+
torch.random.manual_seed(seed)
1532+
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
1533+
inputs_fl = inputs.clone().requires_grad_()
1534+
targets_fl = targets.clone()
1535+
inputs_ce = inputs.clone().requires_grad_()
1536+
targets_ce = targets.clone()
1537+
focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction)
1538+
ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction)
1539+
1540+
tol = 1e-3 if dtype is torch.half else 1e-5
1541+
torch.testing.assert_close(focal_loss, ce_loss, rtol=tol, atol=tol)
1542+
1543+
focal_loss.backward()
1544+
ce_loss.backward()
1545+
torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad, rtol=tol, atol=tol)
1546+
1547+
@pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
1548+
@pytest.mark.parametrize("gamma", [0, 2])
1549+
@pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
1550+
@pytest.mark.parametrize("device", cpu_and_gpu())
1551+
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
1552+
@pytest.mark.parametrize("seed", [4, 5])
1553+
def test_jit(self, alpha, gamma, reduction, device, dtype, seed) -> None:
1554+
if device == "cpu" and dtype is torch.half:
1555+
pytest.skip("Currently torch.half is not fully supported on cpu")
1556+
script_fn = torch.jit.script(ops.sigmoid_focal_loss)
1557+
torch.random.manual_seed(seed)
1558+
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
1559+
focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
1560+
if device == "cpu":
1561+
scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
1562+
else:
1563+
with torch.jit.fuser("fuser2"):
1564+
# Use fuser2 to prevent a bug on fuser: https://github.com/pytorch/pytorch/issues/75476
1565+
# We may remove this condition once the bug is resolved
1566+
scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
1567+
1568+
tol = 1e-3 if dtype is torch.half else 1e-5
1569+
torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol)
1570+
1571+
14531572
if __name__ == "__main__":
14541573
pytest.main([__file__])

0 commit comments

Comments
 (0)