24
24
from monai .config import IndexSelection
25
25
from monai .config .type_definitions import NdarrayOrTensor
26
26
from monai .data .meta_obj import get_track_meta
27
+ from monai .data .meta_tensor import MetaTensor
27
28
from monai .data .utils import get_random_patch , get_valid_patch_size
28
29
from monai .transforms .inverse import InvertibleTransform
29
30
from monai .transforms .transform import Randomizable , Transform
@@ -89,6 +90,8 @@ class Pad(InvertibleTransform):
89
90
90
91
"""
91
92
93
+ backend = [TransformBackends .TORCH ]
94
+
92
95
def __init__ (
93
96
self ,
94
97
to_pad : Optional [List [Tuple [int , int ]]] = None ,
@@ -137,13 +140,16 @@ def __call__(
137
140
img_t = pad_pt (img_t .unsqueeze (0 ), pad_width , mode = mode_ , ** kwargs_ ).squeeze (0 )
138
141
139
142
if get_track_meta ():
140
- spatial_rank = max (len (img_t .affine ) - 1 , 1 )
141
- to_shift = [- s [0 ] for s in to_pad_ [1 :]] # skipping the channel pad
142
- mat = create_translate (spatial_rank , to_shift )
143
- img_t .meta ["affine" ] = img_t .affine @ convert_to_dst_type (mat , img_t .affine )[0 ]
144
- self .push_transform (img_t , extra_info = {"padded" : to_pad })
143
+ self ._update_meta (tensor = img_t , to_pad = to_pad_ )
144
+ self .push_transform (img_t , extra_info = {"padded" : to_pad_ })
145
145
return img_t
146
146
147
+ def _update_meta (self , tensor : MetaTensor , to_pad : List [Tuple [int , int ]]):
148
+ spatial_rank = max (len (tensor .affine ) - 1 , 1 )
149
+ to_shift = [- s [0 ] for s in to_pad [1 :]] # skipping the channel pad
150
+ mat = create_translate (spatial_rank , to_shift )
151
+ tensor .meta ["affine" ] = tensor .affine @ convert_to_dst_type (mat , tensor .affine )[0 ]
152
+
147
153
def inverse (self , data : torch .Tensor ) -> torch .Tensor :
148
154
transform = self .pop_transform (data )
149
155
padded = transform [TraceKeys .EXTRA_INFO ]["padded" ]
@@ -158,16 +164,10 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
158
164
return cropper (data )
159
165
160
166
161
- class SpatialPad (Transform ):
167
+ class SpatialPad (Pad ):
162
168
"""
163
169
Performs padding to the data, symmetric for all sides or all on one side for each dimension.
164
170
165
- If input is `torch.Tensor` and mode is `constant`, `torch.nn.functional.pad` will be used.
166
- Otherwise, `np.pad` will be used (input converted to `np.ndarray` if necessary).
167
-
168
- Uses np.pad so in practice, a mode needs to be provided. See numpy.lib.arraypad.pad
169
- for additional details.
170
-
171
171
Args:
172
172
spatial_size: the spatial size of output data after padding, if a dimension of the input
173
173
data size is bigger than the pad size, will not pad that dimension.
@@ -176,30 +176,24 @@ class SpatialPad(Transform):
176
176
`spatial_size=[32, 25, -1]`, the spatial size of output data will be [32, 30, 30].
177
177
method: {``"symmetric"``, ``"end"``}
178
178
Pad image symmetrically on every side or only pad at the end sides. Defaults to ``"symmetric"``.
179
- mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
180
- ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
181
- available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
179
+ mode: available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
182
180
One of the listed string values or a user supplied function. Defaults to ``"constant"``.
183
- See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
184
- https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
185
- kwargs: other arguments for the `np.pad` or `torch.pad` function.
186
- note that `np.pad` treats channel dimension as the first dimension.
181
+ See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html.
182
+ default to `self.mode`.
183
+ kwargs: other arguments for the `torch.pad` function, will override `self.kwargs`.
187
184
188
185
"""
189
186
190
- backend = Pad .backend
191
-
192
187
def __init__ (
193
188
self ,
194
189
spatial_size : Union [Sequence [int ], int ],
195
190
method : Union [Method , str ] = Method .SYMMETRIC ,
196
- mode : Union [NumpyPadMode , PytorchPadMode , str ] = NumpyPadMode .CONSTANT ,
191
+ mode : Union [PytorchPadMode , str ] = NumpyPadMode .CONSTANT ,
197
192
** kwargs ,
198
193
) -> None :
199
194
self .spatial_size = spatial_size
200
195
self .method : Method = look_up_option (method , Method )
201
- self .mode = mode
202
- self .kwargs = kwargs
196
+ super ().__init__ (mode = mode , ** kwargs )
203
197
204
198
def _determine_data_pad_width (self , data_shape : Sequence [int ]) -> List [Tuple [int , int ]]:
205
199
spatial_size = fall_back_tuple (self .spatial_size , data_shape )
@@ -212,8 +206,11 @@ def _determine_data_pad_width(self, data_shape: Sequence[int]) -> List[Tuple[int
212
206
return [(0 , max (sp_i - data_shape [i ], 0 )) for i , sp_i in enumerate (spatial_size )]
213
207
214
208
def __call__ (
215
- self , img : NdarrayOrTensor , mode : Optional [Union [NumpyPadMode , PytorchPadMode , str ]] = None
216
- ) -> NdarrayOrTensor :
209
+ self ,
210
+ img : torch .Tensor ,
211
+ mode : Optional [Union [PytorchPadMode , str ]] = None ,
212
+ ** kwargs ,
213
+ ) -> torch .Tensor :
217
214
"""
218
215
Args:
219
216
img: data to be transformed, assuming `img` is channel-first and
@@ -228,12 +225,8 @@ def __call__(
228
225
"""
229
226
data_pad_width = self ._determine_data_pad_width (img .shape [1 :])
230
227
all_pad_width = [(0 , 0 )] + data_pad_width
231
- if not np .asarray (all_pad_width ).any ():
232
- # all zeros, skip padding
233
- return img
234
228
235
- padder = Pad (to_pad = all_pad_width , mode = mode or self .mode , ** self .kwargs )
236
- return padder (img )
229
+ return super ().__call__ (img = img , to_pad = all_pad_width , mode = mode , ** kwargs )
237
230
238
231
239
232
class BorderPad (Transform ):
0 commit comments