diff --git a/captum/robust/_core/fgsm.py b/captum/robust/_core/fgsm.py index 6e57f17640..7ce267e0c8 100644 --- a/captum/robust/_core/fgsm.py +++ b/captum/robust/_core/fgsm.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Optional, Tuple, Union import torch from captum._utils.common import ( @@ -82,6 +82,7 @@ def perturb( target: Any, additional_forward_args: Any = None, targeted: bool = False, + mask: Optional[TensorOrTupleOfTensorsGeneric] = None, ) -> TensorOrTupleOfTensorsGeneric: r""" This method computes and returns the perturbed input for each input tensor. @@ -130,6 +131,12 @@ def perturb( Default: None. targeted (bool, optional): If attack should be targeted. Default: False. + mask (Tensor or tuple[Tensor, ...], optional): mask of zeroes and ones + that defines which elements within the input tensor(s) are + perturbed. This mask must have the same shape and + dimensionality as the inputs. If this argument is not + provided, all elements will be perturbed. + Default: None. Returns: @@ -144,6 +151,11 @@ def perturb( """ is_inputs_tuple = _is_tuple(inputs) inputs: Tuple[Tensor, ...] = _format_tensor_into_tuples(inputs) + masks: Union[Tuple[int, ...], Tuple[Tensor, ...]] = ( + _format_tensor_into_tuples(mask) + if (mask is not None) + else (1,) * len(inputs) + ) gradient_mask = apply_gradient_requirements(inputs) def _forward_with_loss() -> Tensor: @@ -161,7 +173,7 @@ def _forward_with_loss() -> Tensor: grads = compute_gradients(_forward_with_loss, inputs) undo_gradient_requirements(inputs, gradient_mask) - perturbed_inputs = self._perturb(inputs, grads, epsilon, targeted) + perturbed_inputs = self._perturb(inputs, grads, epsilon, targeted, masks) perturbed_inputs = tuple( self.bound(perturbed_inputs[i]) for i in range(len(perturbed_inputs)) ) @@ -173,6 +185,7 @@ def _perturb( grads: Tuple, epsilon: float, targeted: bool, + masks: Tuple, ) -> Tuple: r""" A helper function to calculate the perturbed inputs given original @@ -183,9 +196,9 @@ def _perturb( inputs = tuple( torch.where( torch.abs(grad) > self.zero_thresh, - inp + multiplier * epsilon * torch.sign(grad), + inp + multiplier * epsilon * torch.sign(grad) * mask, inp, ) - for grad, inp in zip(grads, inputs) + for grad, inp, mask in zip(grads, inputs, masks) ) return inputs diff --git a/captum/robust/_core/pgd.py b/captum/robust/_core/pgd.py index 14e6374f0a..5391b39cfb 100644 --- a/captum/robust/_core/pgd.py +++ b/captum/robust/_core/pgd.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -from typing import Any, Callable +from typing import Any, Callable, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -78,6 +78,7 @@ def perturb( targeted: bool = False, random_start: bool = False, norm: str = "Linf", + mask: Optional[TensorOrTupleOfTensorsGeneric] = None, ) -> TensorOrTupleOfTensorsGeneric: r""" This method computes and returns the perturbed input for each input tensor. @@ -134,6 +135,12 @@ def perturb( norm (str, optional): Specifies the norm to calculate distance from original inputs: ``Linf`` | ``L2``. Default: ``Linf`` + mask (Tensor or tuple[Tensor, ...], optional): mask of zeroes and ones + that defines which elements within the input tensor(s) are + perturbed. This mask must have the same shape and + dimensionality as the inputs. If this argument is not + provided, all elements are perturbed. + Default: None. Returns: @@ -157,15 +164,29 @@ def _clip(inputs: Tensor, outputs: Tensor) -> Tensor: is_inputs_tuple = _is_tuple(inputs) formatted_inputs = _format_tensor_into_tuples(inputs) + formatted_masks: Union[Tuple[int, ...], Tuple[Tensor, ...]] = ( + _format_tensor_into_tuples(mask) + if (mask is not None) + else (1,) * len(formatted_inputs) + ) perturbed_inputs = formatted_inputs if random_start: perturbed_inputs = tuple( - self.bound(self._random_point(formatted_inputs[i], radius, norm)) + self.bound( + self._random_point( + formatted_inputs[i], radius, norm, formatted_masks[i] + ) + ) for i in range(len(formatted_inputs)) ) for _i in range(step_num): perturbed_inputs = self.fgsm.perturb( - perturbed_inputs, step_size, target, additional_forward_args, targeted + perturbed_inputs, + step_size, + target, + additional_forward_args, + targeted, + formatted_masks, ) perturbed_inputs = tuple( _clip(formatted_inputs[j], perturbed_inputs[j]) @@ -178,7 +199,9 @@ def _clip(inputs: Tensor, outputs: Tensor) -> Tensor: ) return _format_output(is_inputs_tuple, perturbed_inputs) - def _random_point(self, center: Tensor, radius: float, norm: str) -> Tensor: + def _random_point( + self, center: Tensor, radius: float, norm: str, mask: Union[Tensor, int] + ) -> Tensor: r""" A helper function that returns a uniform random point within the ball with the given center and radius. Norm should be either L2 or Linf. @@ -190,9 +213,9 @@ def _random_point(self, center: Tensor, radius: float, norm: str) -> Tensor: r = (torch.rand(u.size(0)) ** (1.0 / d)) * radius r = r[(...,) + (None,) * (r.dim() - 1)] x = r * unit_u - return center + x + return center + (x * mask) elif norm == "Linf": x = torch.rand_like(center) * radius * 2 - radius - return center + x + return center + (x * mask) else: raise AssertionError("Norm constraint must be L2 or Linf.") diff --git a/tests/robust/test_FGSM.py b/tests/robust/test_FGSM.py index 595d8c7b0e..acdcbe7eb0 100644 --- a/tests/robust/test_FGSM.py +++ b/tests/robust/test_FGSM.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -from typing import Any, Callable, List, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union import torch from captum._utils.typing import TensorLikeList, TensorOrTupleOfTensorsGeneric @@ -128,6 +128,60 @@ def test_attack_bound(self) -> None: upper_bound=5.0, ) + def test_attack_masked_tensor(self) -> None: + model = BasicModel() + input = torch.tensor([[2.0, -9.0, 9.0, 1.0, -3.0]], requires_grad=True) + mask = torch.tensor([[1, 0, 0, 1, 1]]) + self._FGSM_assert( + model, input, 1, 0.1, [[2.0, -9.0, 9.0, 1.0, -3.0]], mask=mask + ) + + def test_attack_masked_multiinput(self) -> None: + model = BasicModel2() + input1 = torch.tensor([[4.0, -1.0], [3.0, 10.0]], requires_grad=True) + input2 = torch.tensor([[2.0, -5.0], [-2.0, 1.0]], requires_grad=True) + mask1 = torch.tensor([[1, 0], [1, 0]]) + mask2 = torch.tensor([[0, 0], [0, 0]]) + self._FGSM_assert( + model, + (input1, input2), + 0, + 0.25, + ([[3.75, -1.0], [2.75, 10.0]], [[2.0, -5.0], [-2.0, 1.0]]), + mask=(mask1, mask2), + ) + + def test_attack_masked_loss_defined(self) -> None: + model = BasicModel_MultiLayer() + add_input = torch.tensor([[-1.0, 2.0, 2.0]]) + input = torch.tensor([[1.0, 6.0, -3.0]]) + labels = torch.tensor([0]) + mask = torch.tensor([[0, 0, 1]]) + loss_func = CrossEntropyLoss(reduction="none") + adv = FGSM(model, loss_func) + perturbed_input = adv.perturb( + input, 0.2, labels, additional_forward_args=(add_input,), mask=mask + ) + assertTensorAlmostEqual( + self, perturbed_input, [[1.0, 6.0, -3.0]], delta=0.01, mode="max" + ) + + def test_attack_masked_bound(self) -> None: + model = BasicModel() + input = torch.tensor([[9.0, 10.0, -6.0, -1.0]]) + mask = torch.tensor([[1, 0, 1, 0]]) + self._FGSM_assert( + model, + input, + 3, + 0.2, + [[5.0, 5.0, -5.0, -1.0]], + targeted=True, + lower_bound=-5.0, + upper_bound=5.0, + mask=mask, + ) + def _FGSM_assert( self, model: Callable, @@ -139,10 +193,11 @@ def _FGSM_assert( additional_inputs: Any = None, lower_bound: float = float("-inf"), upper_bound: float = float("inf"), + mask: Optional[TensorOrTupleOfTensorsGeneric] = None, ) -> None: adv = FGSM(model, lower_bound=lower_bound, upper_bound=upper_bound) perturbed_input = adv.perturb( - inputs, epsilon, target, additional_inputs, targeted + inputs, epsilon, target, additional_inputs, targeted, mask ) if isinstance(perturbed_input, Tensor): assertTensorAlmostEqual( diff --git a/tests/robust/test_PGD.py b/tests/robust/test_PGD.py index 340026182f..7e39ca99d9 100644 --- a/tests/robust/test_PGD.py +++ b/tests/robust/test_PGD.py @@ -108,3 +108,103 @@ def test_attack_random_start(self) -> None: ) norm = torch.norm((perturbed_input - input).squeeze()).numpy() self.assertLessEqual(norm, 0.25) + + def test_attack_masked_nontargeted(self) -> None: + model = BasicModel() + input = torch.tensor([[2.0, -9.0, 9.0, 1.0, -3.0]]) + mask = torch.tensor([[1, 1, 0, 0, 0]]) + adv = PGD(model) + perturbed_input = adv.perturb(input, 0.25, 0.1, 2, 4, mask=mask) + assertTensorAlmostEqual( + self, + perturbed_input, + [[2.0, -9.0, 9.0, 1.0, -3.0]], + delta=0.01, + mode="max", + ) + + def test_attack_masked_targeted(self) -> None: + model = BasicModel() + input = torch.tensor([[9.0, 10.0, -6.0, -1.0]], requires_grad=True) + mask = torch.tensor([[1, 1, 1, 0]]) + adv = PGD(model) + perturbed_input = adv.perturb(input, 0.2, 0.1, 3, 3, targeted=True, mask=mask) + assertTensorAlmostEqual( + self, + perturbed_input, + [[9.0, 10.0, -6.0, -1.0]], + delta=0.01, + mode="max", + ) + + def test_attack_masked_multiinput(self) -> None: + model = BasicModel2() + input1 = torch.tensor([[4.0, -1.0], [3.0, 10.0]], requires_grad=True) + input2 = torch.tensor([[2.0, -5.0], [-2.0, 1.0]], requires_grad=True) + mask1 = torch.tensor([[1, 1], [0, 0]]) + mask2 = torch.tensor([[0, 1], [0, 1]]) + adv = PGD(model) + perturbed_input = adv.perturb( + (input1, input2), 0.25, 0.1, 3, 0, norm="L2", mask=(mask1, mask2) + ) + answer = ([[3.75, -1.0], [3.0, 10.0]], [[2.0, -5.0], [-2.0, 1.0]]) + for i in range(len(perturbed_input)): + assertTensorAlmostEqual( + self, + perturbed_input[i], + answer[i], + delta=0.01, + mode="max", + ) + + def test_attack_masked_random_start(self) -> None: + model = BasicModel() + input = torch.tensor([[2.0, -9.0, 9.0, 1.0, -3.0]]) + mask = torch.tensor([[1, 0, 1, 0, 1]]) + adv = PGD(model) + perturbed_input = adv.perturb( + input, 0.25, 0.1, 0, 4, random_start=True, mask=mask + ) + assertTensorAlmostEqual( + self, + perturbed_input, + [[2.0, -9.0, 9.0, 1.0, -3.0]], + delta=0.25, + mode="max", + ) + perturbed_input = adv.perturb( + input, 0.25, 0.1, 0, 4, norm="L2", random_start=True, mask=mask + ) + norm = torch.norm((perturbed_input - input).squeeze()).numpy() + self.assertLessEqual(norm, 0.25) + + def test_attack_masked_3dimensional_input(self) -> None: + model = BasicModel() + input = torch.tensor( + [[[4.0, 2.0], [-1.0, -2.0]], [[3.0, -4.0], [10.0, 5.0]]], requires_grad=True + ) + mask = torch.tensor([[[1, 0], [0, 1]], [[1, 0], [1, 1]]]) + adv = PGD(model) + perturbed_input = adv.perturb(input, 0.25, 0.1, 3, (0, 1), mask=mask) + assertTensorAlmostEqual( + self, + perturbed_input, + [[[4.0, 2.0], [-1.0, -2.0]], [[3.0, -4.0], [10.0, 5.0]]], + delta=0.01, + mode="max", + ) + + def test_attack_masked_loss_defined(self) -> None: + model = BasicModel_MultiLayer() + add_input = torch.tensor([[-1.0, 2.0, 2.0]]) + input = torch.tensor([[1.0, 6.0, -3.0]]) + mask = torch.tensor([[0, 1, 0]]) + labels = torch.tensor([0]) + loss_func = CrossEntropyLoss(reduction="none") + adv = PGD(model, loss_func) + perturbed_input = adv.perturb( + input, 0.25, 0.1, 3, labels, additional_forward_args=(add_input,), mask=mask + ) + assertTensorAlmostEqual( + self, perturbed_input, [[1.0, 6.0, -3.0]], delta=0.01, mode="max" + )