Skip to content

Commit 63e36b6

Browse files
committed
[DLMED] update inverse and spatial_pad
Signed-off-by: Nic Ma <[email protected]>
1 parent 000f035 commit 63e36b6

File tree

3 files changed

+228
-98
lines changed

3 files changed

+228
-98
lines changed

monai/transforms/croppad/array.py

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from monai.config import IndexSelection
2525
from monai.config.type_definitions import NdarrayOrTensor
2626
from monai.data.meta_obj import get_track_meta
27+
from monai.data.meta_tensor import MetaTensor
2728
from monai.data.utils import get_random_patch, get_valid_patch_size
2829
from monai.transforms.inverse import InvertibleTransform
2930
from monai.transforms.transform import Randomizable, Transform
@@ -89,6 +90,8 @@ class Pad(InvertibleTransform):
8990
9091
"""
9192

93+
backend = [TransformBackends.TORCH]
94+
9295
def __init__(
9396
self,
9497
to_pad: Optional[List[Tuple[int, int]]] = None,
@@ -137,13 +140,16 @@ def __call__(
137140
img_t = pad_pt(img_t.unsqueeze(0), pad_width, mode=mode_, **kwargs_).squeeze(0)
138141

139142
if get_track_meta():
140-
spatial_rank = max(len(img_t.affine) - 1, 1)
141-
to_shift = [-s[0] for s in to_pad_[1:]] # skipping the channel pad
142-
mat = create_translate(spatial_rank, to_shift)
143-
img_t.meta["affine"] = img_t.affine @ convert_to_dst_type(mat, img_t.affine)[0]
144-
self.push_transform(img_t, extra_info={"padded": to_pad})
143+
self._update_meta(tensor=img_t, to_pad=to_pad_)
144+
self.push_transform(img_t, extra_info={"padded": to_pad_})
145145
return img_t
146146

147+
def _update_meta(self, tensor: MetaTensor, to_pad: List[Tuple[int, int]]):
148+
spatial_rank = max(len(tensor.affine) - 1, 1)
149+
to_shift = [-s[0] for s in to_pad[1:]] # skipping the channel pad
150+
mat = create_translate(spatial_rank, to_shift)
151+
tensor.meta["affine"] = tensor.affine @ convert_to_dst_type(mat, tensor.affine)[0]
152+
147153
def inverse(self, data: torch.Tensor) -> torch.Tensor:
148154
transform = self.pop_transform(data)
149155
padded = transform[TraceKeys.EXTRA_INFO]["padded"]
@@ -158,16 +164,10 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
158164
return cropper(data)
159165

160166

161-
class SpatialPad(Transform):
167+
class SpatialPad(Pad):
162168
"""
163169
Performs padding to the data, symmetric for all sides or all on one side for each dimension.
164170
165-
If input is `torch.Tensor` and mode is `constant`, `torch.nn.functional.pad` will be used.
166-
Otherwise, `np.pad` will be used (input converted to `np.ndarray` if necessary).
167-
168-
Uses np.pad so in practice, a mode needs to be provided. See numpy.lib.arraypad.pad
169-
for additional details.
170-
171171
Args:
172172
spatial_size: the spatial size of output data after padding, if a dimension of the input
173173
data size is bigger than the pad size, will not pad that dimension.
@@ -176,30 +176,24 @@ class SpatialPad(Transform):
176176
`spatial_size=[32, 25, -1]`, the spatial size of output data will be [32, 30, 30].
177177
method: {``"symmetric"``, ``"end"``}
178178
Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``.
179-
mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
180-
``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
181-
available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
179+
mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
182180
One of the listed string values or a user supplied function. Defaults to ``"constant"``.
183-
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
184-
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
185-
kwargs: other arguments for the `np.pad` or `torch.pad` function.
186-
note that `np.pad` treats channel dimension as the first dimension.
181+
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html.
182+
default to `self.mode`.
183+
kwargs: other arguments for the `torch.pad` function, will override `self.kwargs`.
187184
188185
"""
189186

190-
backend = Pad.backend
191-
192187
def __init__(
193188
self,
194189
spatial_size: Union[Sequence[int], int],
195190
method: Union[Method, str] = Method.SYMMETRIC,
196-
mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.CONSTANT,
191+
mode: Union[PytorchPadMode, str] = NumpyPadMode.CONSTANT,
197192
**kwargs,
198193
) -> None:
199194
self.spatial_size = spatial_size
200195
self.method: Method = look_up_option(method, Method)
201-
self.mode = mode
202-
self.kwargs = kwargs
196+
super().__init__(mode=mode, **kwargs)
203197

204198
def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int, int]]:
205199
spatial_size = fall_back_tuple(self.spatial_size, data_shape)
@@ -212,8 +206,11 @@ def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int
212206
return [(0, max(sp_i - data_shape[i], 0)) for i, sp_i in enumerate(spatial_size)]
213207

214208
def __call__(
215-
self, img: NdarrayOrTensor, mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None
216-
) -> NdarrayOrTensor:
209+
self,
210+
img: torch.Tensor,
211+
mode: Optional[Union[PytorchPadMode, str]] = None,
212+
**kwargs,
213+
) -> torch.Tensor:
217214
"""
218215
Args:
219216
img: data to be transformed, assuming `img` is channel-first and
@@ -228,12 +225,8 @@ def __call__(
228225
"""
229226
data_pad_width = self._determine_data_pad_width(img.shape[1:])
230227
all_pad_width = [(0, 0)] + data_pad_width
231-
if not np.asarray(all_pad_width).any():
232-
# all zeros, skip padding
233-
return img
234228

235-
padder = Pad(to_pad=all_pad_width, mode=mode or self.mode, **self.kwargs)
236-
return padder(img)
229+
return super().__call__(img=img, to_pad=all_pad_width, mode=mode, **kwargs)
237230

238231

239232
class BorderPad(Transform):

0 commit comments

Comments
 (0)