import pytest
import torch
import torch.fx
+import torch.nn.functional as F
from common_utils import assert_equal, cpu_and_gpu, needs_cuda
from PIL import Image
from torch import nn, Tensor
@@ -1450,5 +1451,123 @@ def test_is_leaf_node(self, dim, p, block_size, inplace):
        assert len(graph_node_names[0]) == 1 + op_obj.n_inputs


+class TestFocalLoss:
+    def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs):
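+        # Build a diverse set of (logit, target) pairs: probabilities near 0, near 1, and uniform in
+        # [0, 1], combined with all-zeros, all-ones, and random binary targets. logit() inverts the
+        # sigmoid, so torch.sigmoid(inputs) recovers the intended probabilities.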
+        def logit(p: Tensor) -> Tensor:
+            return torch.log(p / (1 - p))
+
+        def generate_tensor_with_range_type(shape, range_type, **kwargs):
+            if range_type != "random_binary":
+                low, high = {
+                    "small": (0.0, 0.2),
+                    "big": (0.8, 1.0),
+                    "zeros": (0.0, 0.0),
+                    "ones": (1.0, 1.0),
+                    "random": (0.0, 1.0),
+                }[range_type]
+                return torch.testing.make_tensor(shape, low=low, high=high, **kwargs)
+            else:
+                return torch.randint(0, 2, shape, **kwargs)
+
+        # The concatenated inputs and targets returned below have shape (shape[0] * 9, shape[1])
+        inputs = []
+        targets = []
+        for input_range_type, target_range_type in [
+            ("small", "zeros"),
+            ("small", "ones"),
+            ("small", "random_binary"),
+            ("big", "zeros"),
+            ("big", "ones"),
+            ("big", "random_binary"),
+            ("random", "zeros"),
+            ("random", "ones"),
+            ("random", "random_binary"),
+        ]:
+            inputs.append(logit(generate_tensor_with_range_type(shape, input_range_type, **kwargs)))
+            targets.append(generate_tensor_with_range_type(shape, target_range_type, **kwargs))
+
+        return torch.cat(inputs), torch.cat(targets)
+
+    @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
+    @pytest.mark.parametrize("gamma", [0, 2])
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
+    @pytest.mark.parametrize("seed", [0, 1])
+    def test_correct_ratio(self, alpha, gamma, device, dtype, seed) -> None:
+        if device == "cpu" and dtype is torch.half:
+            pytest.skip("Currently torch.half is not fully supported on cpu")
+        # To check the ratio against a manual calculation, the reduction must be "none"
+        reduction = "none"
+        torch.random.manual_seed(seed)
+        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
+        focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
+        ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction)
+
+        assert torch.all(
+            focal_loss <= ce_loss
+        ), "focal loss must be less than or equal to cross entropy loss with the same input"
+
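+        # sigmoid_focal_loss scales the per-element BCE loss by alpha_t * (1 - p_t) ** gamma, so the
+        # elementwise ratio to the plain BCE loss should match the factor computed below.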
+        loss_ratio = (focal_loss / ce_loss).squeeze()
+        prob = torch.sigmoid(inputs)
+        p_t = prob * targets + (1 - prob) * (1 - targets)
+        correct_ratio = (1.0 - p_t) ** gamma
+        if alpha >= 0:
+            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+            correct_ratio = correct_ratio * alpha_t
+
+        tol = 1e-3 if dtype is torch.half else 1e-5
+        torch.testing.assert_close(correct_ratio, loss_ratio, rtol=tol, atol=tol)
+
+    @pytest.mark.parametrize("reduction", ["mean", "sum"])
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
+    @pytest.mark.parametrize("seed", [2, 3])
+    def test_equal_ce_loss(self, reduction, device, dtype, seed) -> None:
+        if device == "cpu" and dtype is torch.half:
+            pytest.skip("Currently torch.half is not fully supported on cpu")
+        # focal loss should be equal to ce_loss when alpha=-1 and gamma=0
+        alpha = -1
+        gamma = 0
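+        # A negative alpha disables the alpha weighting and gamma=0 makes the modulating factor 1,
+        # so the focal loss should reduce to plain binary cross entropy with logits.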
+        torch.random.manual_seed(seed)
+        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
+        inputs_fl = inputs.clone().requires_grad_()
+        targets_fl = targets.clone()
+        inputs_ce = inputs.clone().requires_grad_()
+        targets_ce = targets.clone()
+        focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction)
+        ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction)
+
+        tol = 1e-3 if dtype is torch.half else 1e-5
+        torch.testing.assert_close(focal_loss, ce_loss, rtol=tol, atol=tol)
+
+        focal_loss.backward()
+        ce_loss.backward()
+        torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad, rtol=tol, atol=tol)
+
+    @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
+    @pytest.mark.parametrize("gamma", [0, 2])
+    @pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.half])
+    @pytest.mark.parametrize("seed", [4, 5])
+    def test_jit(self, alpha, gamma, reduction, device, dtype, seed) -> None:
+        if device == "cpu" and dtype is torch.half:
+            pytest.skip("Currently torch.half is not fully supported on cpu")
+        script_fn = torch.jit.script(ops.sigmoid_focal_loss)
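+        # Scripting the public op checks that it is TorchScript-compatible; the scripted result is
+        # compared against the eager result below.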
+        torch.random.manual_seed(seed)
+        inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
+        focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
+        if device == "cpu":
+            scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
+        else:
+            with torch.jit.fuser("fuser2"):
+                # Use fuser2 to work around a bug in the default fuser: https://github.com/pytorch/pytorch/issues/75476
+                # This branch can be removed once the bug is resolved
+                scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
+
+        tol = 1e-3 if dtype is torch.half else 1e-5
+        torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol)
+
+
if __name__ == "__main__":
    pytest.main([__file__])