
Commit a209b06

5413 update dicefocal include foreground (#5416)
Fixes #5413

### Description

excluding background shouldn't be done before softmax

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: Wenqi Li <[email protected]>
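
The fix matters because softmax normalizes across the channel dimension: slicing out the background channel *before* the activation renormalizes the remaining logits over fewer classes, so the foreground probabilities no longer match what the loss should see. A minimal sketch of the difference (illustrative values, not part of the commit):

```python
import torch

# one pixel, 3 classes; channel 0 is the background
logits = torch.tensor([[2.0, 1.0, 0.5]])

# correct order: softmax over all classes, then drop the background channel
probs_after = torch.softmax(logits, dim=1)[:, 1:]   # tensor([[0.2312, 0.1402]])

# buggy order: drop the background channel first, then softmax
probs_before = torch.softmax(logits[:, 1:], dim=1)  # tensor([[0.6225, 0.3775]])

# the sliced-first version renormalizes over only the 2 remaining classes,
# yielding different (inflated) foreground probabilities
```
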
1 parent 8925e3e commit a209b06

2 files changed: 57 additions & 44 deletions

monai/losses/dice.py

Lines changed: 43 additions & 38 deletions
```diff
@@ -60,12 +60,12 @@ def __init__(
             include_background: if False, channel index 0 (background category) is excluded from the calculation.
                 if the non-background segmentations are small compared to the total image size they can get overwhelmed
                 by the signal from the background so excluding it in such cases helps convergence.
-            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+            to_onehot_y: whether to convert the ``target`` into the one-hot format,
+                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
             sigmoid: if True, apply a sigmoid function to the prediction.
             softmax: if True, apply a softmax function to the prediction.
-            other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
-                other activation layers, Defaults to ``None``. for example:
-                `other_act = torch.tanh`.
+            other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
+                ``other_act = torch.tanh``.
             squared_pred: use squared versions of targets and predictions in the denominator or not.
             jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
             reduction: {``"none"``, ``"mean"``, ``"sum"``}
@@ -247,12 +247,12 @@ def __init__(
         """
         Args:
             include_background: If False channel index 0 (background category) is excluded from the calculation.
-            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+            to_onehot_y: whether to convert the ``target`` into the one-hot format,
+                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
             sigmoid: If True, apply a sigmoid function to the prediction.
             softmax: If True, apply a softmax function to the prediction.
-            other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
-                other activation layers, Defaults to ``None``. for example:
-                `other_act = torch.tanh`.
+            other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
+                ``other_act = torch.tanh``.
             w_type: {``"square"``, ``"simple"``, ``"uniform"``}
                 Type of function to transform ground truth volume to a weight factor. Defaults to ``"square"``.
             reduction: {``"none"``, ``"mean"``, ``"sum"``}
@@ -639,14 +639,14 @@ def __init__(
             ``reduction`` is used for both losses and other parameters are only used for dice loss.
 
             include_background: if False channel index 0 (background category) is excluded from the calculation.
-            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+            to_onehot_y: whether to convert the ``target`` into the one-hot format,
+                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
             sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,
                 don't need to specify activation function for `CrossEntropyLoss`.
             softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,
                 don't need to specify activation function for `CrossEntropyLoss`.
-            other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
-                other activation layers, Defaults to ``None``. for example: `other_act = torch.tanh`.
-                only used by the `DiceLoss`, don't need to specify activation function for `CrossEntropyLoss`.
+            other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
+                ``other_act = torch.tanh``. only used by the `DiceLoss`, not for the `CrossEntropyLoss`.
             squared_pred: use squared versions of targets and predictions in the denominator or not.
             jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
             reduction: {``"mean"``, ``"sum"``}
@@ -728,7 +728,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
 
         """
         if len(input.shape) != len(target.shape):
-            raise ValueError("the number of dimensions for input and target should be the same.")
+            raise ValueError(
+                "the number of dimensions for input and target should be the same, "
+                f"got shape {input.shape} and {target.shape}."
+            )
 
         dice_loss = self.dice(input, target)
         ce_loss = self.ce(input, target)
@@ -743,6 +746,10 @@ class DiceFocalLoss(_Loss):
     The details of Dice loss is shown in ``monai.losses.DiceLoss``.
     The details of Focal Loss is shown in ``monai.losses.FocalLoss``.
 
+    ``gamma``, ``focal_weight`` and ``lambda_focal`` are only used for the focal loss.
+    ``include_background`` and ``reduction`` are used for both losses
+    and other parameters are only used for dice loss.
+
     """
 
     def __init__(
@@ -765,18 +772,15 @@ def __init__(
     ) -> None:
         """
         Args:
-            ``gamma``, ``focal_weight`` and ``lambda_focal`` are only used for focal loss.
-            ``include_background``, ``to_onehot_y``and ``reduction`` are used for both losses
-            and other parameters are only used for dice loss.
             include_background: if False channel index 0 (background category) is excluded from the calculation.
-            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+            to_onehot_y: whether to convert the ``target`` into the one-hot format,
+                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
             sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,
                 don't need to specify activation function for `FocalLoss`.
             softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,
                 don't need to specify activation function for `FocalLoss`.
-            other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
-                other activation layers, Defaults to ``None``. for example: `other_act = torch.tanh`.
-                only used by the `DiceLoss`, don't need to specify activation function for `FocalLoss`.
+            other_act: callable function to execute other activation layers, Defaults to ``None``.
+                for example: `other_act = torch.tanh`. only used by the `DiceLoss`, not for `FocalLoss`.
             squared_pred: use squared versions of targets and predictions in the denominator or not.
             jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
             reduction: {``"none"``, ``"mean"``, ``"sum"``}
@@ -803,6 +807,8 @@ def __init__(
         """
         super().__init__()
         self.dice = DiceLoss(
+            include_background=include_background,
+            to_onehot_y=False,
             sigmoid=sigmoid,
             softmax=softmax,
             other_act=other_act,
@@ -813,15 +819,20 @@ def __init__(
             smooth_dr=smooth_dr,
             batch=batch,
         )
-        self.focal = FocalLoss(gamma=gamma, weight=focal_weight, reduction=reduction)
+        self.focal = FocalLoss(
+            include_background=include_background,
+            to_onehot_y=False,
+            gamma=gamma,
+            weight=focal_weight,
+            reduction=reduction,
+        )
         if lambda_dice < 0.0:
             raise ValueError("lambda_dice should be no less than 0.0.")
         if lambda_focal < 0.0:
             raise ValueError("lambda_focal should be no less than 0.0.")
         self.lambda_dice = lambda_dice
         self.lambda_focal = lambda_focal
         self.to_onehot_y = to_onehot_y
-        self.include_background = include_background
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -836,24 +847,16 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
 
         """
         if len(input.shape) != len(target.shape):
-            raise ValueError("the number of dimensions for input and target should be the same.")
-
-        n_pred_ch = input.shape[1]
-
+            raise ValueError(
+                "the number of dimensions for input and target should be the same, "
+                f"got shape {input.shape} and {target.shape}."
+            )
         if self.to_onehot_y:
+            n_pred_ch = input.shape[1]
             if n_pred_ch == 1:
                 warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
             else:
                 target = one_hot(target, num_classes=n_pred_ch)
-
-        if not self.include_background:
-            if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `include_background=False` ignored.")
-            else:
-                # if skipping background, removing first channel
-                target = target[:, 1:]
-                input = input[:, 1:]
-
         dice_loss = self.dice(input, target)
         focal_loss = self.focal(input, target)
         total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_focal * focal_loss
@@ -867,11 +870,13 @@ class GeneralizedDiceFocalLoss(torch.nn.modules.loss._Loss):
     Args:
         include_background (bool, optional): if False channel index 0 (background category) is excluded from the calculation.
             Defaults to True.
-        to_onehot_y (bool, optional): whether to convert `y` into the one-hot format. Defaults to False.
+        to_onehot_y: whether to convert the ``target`` into the one-hot format,
+            using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
         sigmoid (bool, optional): if True, apply a sigmoid function to the prediction. Defaults to False.
         softmax (bool, optional): if True, apply a softmax function to the prediction. Defaults to False.
-        other_act (Optional[Callable], optional): if don't want to use sigmoid or softmax, use other callable
-            function to execute other activation layers. Defaults to None.
+        other_act (Optional[Callable], optional): callable function to execute other activation layers,
+            Defaults to ``None``. for example: `other_act = torch.tanh`.
+            only used by the `GeneralizedDiceLoss`, not for the `FocalLoss`.
         w_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
             ground-truth volume to a weight factor. Defaults to ``"square"``.
         reduction (Union[LossReduction, str], optional): {``"none"``, ``"mean"``, ``"sum"``}. Specified the reduction to
```
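
With this change the one-hot conversion happens once in `DiceFocalLoss.forward`, and background exclusion is delegated to each sub-loss, so it runs after the sub-loss applies its activation. A minimal usage sketch of the updated behaviour; the shapes and values below are illustrative assumptions, not from the diff:

```python
import torch
from monai.losses import DiceFocalLoss

# 2 samples, 3 classes (channel 0 = background), 16x16 images
pred = torch.randn(2, 3, 16, 16)
target = torch.randint(0, 3, (2, 1, 16, 16))  # class-index labels, one channel

# `softmax=True` is used by the Dice term only: probabilities are computed
# over all 3 channels before the background channel is excluded
loss_fn = DiceFocalLoss(include_background=False, to_onehot_y=True, softmax=True)
loss = loss_fn(pred, target)  # scalar tensor with the default "mean" reduction
```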

tests/test_dice_focal_loss.py

Lines changed: 14 additions & 6 deletions
```diff
@@ -13,6 +13,7 @@
 
 import numpy as np
 import torch
+from parameterized import parameterized
 
 from monai.losses import DiceFocalLoss, DiceLoss, FocalLoss
 from tests.utils import test_script_save
@@ -36,17 +37,24 @@ def test_result_onehot_target_include_bg(self):
         expected_val = dice(pred, label) + lambda_focal * focal(pred, label)
         np.testing.assert_allclose(result, expected_val)
 
-    def test_result_no_onehot_no_bg(self):
-        size = [3, 3, 5, 5]
-        label = torch.randint(low=0, high=2, size=size)
-        label = torch.argmax(label, dim=1, keepdim=True)
+    @parameterized.expand([[[3, 3, 5, 5], True], [[3, 2, 5, 5], False]])
+    def test_result_no_onehot_no_bg(self, size, onehot):
+        label = torch.randint(low=0, high=size[1] - 1, size=size)
+        if onehot:
+            label = torch.argmax(label, dim=1, keepdim=True)
         pred = torch.randn(size)
         for reduction in ["sum", "mean", "none"]:
-            common_params = {"include_background": False, "to_onehot_y": True, "reduction": reduction}
-            for focal_weight in [2.0, torch.tensor([1.0, 2.0]), (2.0, 1)]:
+            for focal_weight in [2.0] + [] if size[1] != 3 else [torch.tensor([1.0, 2.0]), (2.0, 1)]:
                 for lambda_focal in [0.5, 1.0, 1.5]:
+                    common_params = {
+                        "include_background": False,
+                        "softmax": True,
+                        "to_onehot_y": onehot,
+                        "reduction": reduction,
+                    }
                     dice_focal = DiceFocalLoss(focal_weight=focal_weight, lambda_focal=lambda_focal, **common_params)
                     dice = DiceLoss(**common_params)
+                    common_params.pop("softmax", None)
                     focal = FocalLoss(weight=focal_weight, **common_params)
                     result = dice_focal(pred, label)
                     expected_val = dice(pred, label) + lambda_focal * focal(pred, label)
```
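
For reference, `parameterized.expand` generates one independent test per argument list, so the updated test runs twice: once with a 3-channel size whose label is collapsed to class indices (`onehot=True`, converted back via `to_onehot_y`), and once with a 2-channel multi-channel label (`onehot=False`). A standalone sketch of that pattern; the class name and assertions here are illustrative:

```python
import unittest

from parameterized import parameterized


class TestExpandPattern(unittest.TestCase):
    # two generated tests, one per [size, onehot] pair
    @parameterized.expand([[[3, 3, 5, 5], True], [[3, 2, 5, 5], False]])
    def test_case(self, size, onehot):
        self.assertEqual(len(size), 4)
        self.assertIsInstance(onehot, bool)


if __name__ == "__main__":
    unittest.main()
```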
