From 6ddcaf62d4fd216de1b6364a4b53b98cb88b0cd4 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 14 Aug 2022 13:16:14 +0100 Subject: [PATCH 01/10] fixes 4907 Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 28 +++++++++++++++------------- tests/test_resize.py | 2 -- tests/test_resized.py | 10 +++++++++- 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index a8e76e098b..14da37300a 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -842,10 +842,12 @@ def __call__( scale = self.spatial_size / max(img_size) spatial_size_ = tuple(int(round(s * scale)) for s in img_size) - if tuple(img.shape[1:]) == spatial_size_: # spatial shape is already the desired - return convert_to_tensor(img, track_meta=get_track_meta()) # type: ignore - original_sp_size = img.shape[1:] + _mode = look_up_option(self.mode if mode is None else mode, InterpolateMode) + _align_corners = self.align_corners if align_corners is None else align_corners + if tuple(img.shape[1:]) == spatial_size_: # spatial shape is already the desired + img = convert_to_tensor(img, track_meta=get_track_meta()) # type: ignore + return self._post_process(img, original_sp_size, spatial_size_, _mode, _align_corners, input_ndim) img_ = convert_to_tensor(img, dtype=torch.float, track_meta=False) if anti_aliasing and any(x < y for x, y in zip(spatial_size_, img_.shape[1:])): @@ -862,25 +864,25 @@ def __call__( img_ = convert_to_tensor(anti_aliasing_filter(img_), track_meta=False) img = convert_to_tensor(img, track_meta=get_track_meta()) - _mode = look_up_option(self.mode if mode is None else mode, InterpolateMode) - _align_corners = self.align_corners if align_corners is None else align_corners - resized = torch.nn.functional.interpolate( input=img_.unsqueeze(0), size=spatial_size_, mode=_mode, align_corners=_align_corners ) out, *_ = convert_to_dst_type(resized.squeeze(0), img) + return self._post_process(out, original_sp_size, spatial_size_, _mode, _align_corners, input_ndim) + + def _post_process(self, img: torch.Tensor, orig_size, sp_size, mode, align_corners, ndim) -> torch.Tensor: if get_track_meta(): - self.update_meta(out, original_sp_size, spatial_size_) + self.update_meta(img, orig_size, sp_size) self.push_transform( - out, - orig_size=original_sp_size, + img, + orig_size=orig_size, extra_info={ - "mode": _mode, - "align_corners": _align_corners if _align_corners is not None else TraceKeys.NONE, - "new_dim": len(original_sp_size) - input_ndim, # additional dims appended + "mode": mode, + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, + "new_dim": len(orig_size) - ndim, # additional dims appended }, ) - return out + return img def update_meta(self, img, spatial_size, new_spatial_size): affine = convert_to_tensor(img.affine, track_meta=False) diff --git a/tests/test_resize.py b/tests/test_resize.py index 8927b5dba5..b755bb3faf 100644 --- a/tests/test_resize.py +++ b/tests/test_resize.py @@ -74,8 +74,6 @@ def test_correct_results(self, spatial_size, mode, anti_aliasing): im = p(self.imt[0]) out = resize(im) if isinstance(im, MetaTensor): - if not out.applied_operations: - return # skipped because good shape im_inv = resize.inverse(out) self.assertTrue(not im_inv.applied_operations) assert_allclose(im_inv.shape, im.shape) diff --git a/tests/test_resized.py b/tests/test_resized.py index b8db666357..277243db96 100644 --- a/tests/test_resized.py +++ b/tests/test_resized.py @@ 
-16,7 +16,7 @@ from parameterized import parameterized from monai.data import MetaTensor, set_track_meta -from monai.transforms import Resized +from monai.transforms import Invertd, Resized from tests.utils import TEST_NDARRAYS_ALL, NumpyImageTestCase2D, assert_allclose, test_local_inversion TEST_CASE_0 = [{"keys": "img", "spatial_size": 15}, (6, 10, 15)] @@ -80,6 +80,14 @@ def test_longest_shape(self, input_param, expected_shape): np.testing.assert_allclose(result["img"].shape[1:], expected_shape) set_track_meta(True) + def test_identical_spatial(self): + test_input = {"X": np.ones((1, 10, 16, 17))} + xform = Resized("X", (-1, 18, 19)) + out = xform(test_input) + out["Y"] = 2 * out["X"] + transform_inverse = Invertd(keys="Y", transform=xform, orig_keys="X") + assert_allclose(transform_inverse(out)["Y"].array, np.ones((1, 10, 16, 17)) * 2) + if __name__ == "__main__": unittest.main() From b8715483e9ede84d56adcc0fc5515a42ecb494ae Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 13 Aug 2022 18:07:42 +0100 Subject: [PATCH 02/10] compose knows when to eagerly execute Signed-off-by: Wenqi Li --- monai/data/meta_tensor.py | 58 +++++++++++++- monai/transforms/compose.py | 34 ++++++++ monai/transforms/croppad/array.py | 104 +++++++++++++++++++------ monai/transforms/croppad/dictionary.py | 19 ++++- monai/transforms/transform.py | 19 ++++- monai/utils/enums.py | 1 + 6 files changed, 204 insertions(+), 31 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 8897371903..96439dd937 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -24,7 +24,7 @@ from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata from monai.utils import look_up_option from monai.utils.enums import MetaKeys, PostFix, SpaceKeys -from monai.utils.type_conversion import convert_data_type, convert_to_tensor +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor __all__ = ["MetaTensor"] @@ -125,7 +125,7 @@ def __init__( super().__init__() # set meta if meta is not None: - self.meta = meta + self.meta = dict(meta) elif isinstance(x, MetaObj): self.__dict__ = deepcopy(x.__dict__) # set the affine @@ -150,6 +150,60 @@ def __init__( if MetaKeys.SPACE not in self.meta: self.meta[MetaKeys.SPACE] = SpaceKeys.RAS # defaulting to the right-anterior-superior space + if MetaKeys.ORIGINAL_CHANNEL_DIM not in self.meta: + self.meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0 # defaulting to channel first + + @property + def evaluated(self) -> bool: + """a flag indicating whether the array content is up-to-date with the affine/spatial_shape properties.""" + if MetaKeys.EVALUATED not in self.meta: + self.meta[MetaKeys.EVALUATED] = True + return bool(self.meta[MetaKeys.EVALUATED]) + + @evaluated.setter + def evaluated(self, value: bool): + """when setting an evaluated metatensor to a lazy status, original affine will be stored.""" + if not value and (MetaKeys.SPATIAL_SHAPE not in self.meta or MetaKeys.AFFINE not in self.meta): + warnings.warn("Setting MetaTensor to lazy evaluation requires spatial_shape and affine.") + if self.evaluated and not value: + self.meta[MetaKeys.ORIGINAL_AFFINE] = self.affine # switch to lazy evaluation, store current affine + self.meta[MetaKeys.SPATIAL_SHAPE] = self.spatial_shape + self.meta[MetaKeys.EVALUATED] = value + + def evaluate(self, mode="bilinear", padding_mode="border"): + if self.evaluated: + self.spatial_shape = self.array.shape[1:] + return + # how to ensure channel 
first?
+        resampler = monai.transforms.SpatialResample(mode=mode, padding_mode=padding_mode)
+        dst_affine, self.affine = self.affine, self.meta[MetaKeys.ORIGINAL_AFFINE]
+        with resampler.trace_transform(False):
+            output = resampler(self, dst_affine=dst_affine, spatial_size=self.spatial_shape)
+        self.array = output.array
+        self.spatial_shape = self.array.shape[1:]
+        self.affine = dst_affine
+        self.evaluated = True
+        return
+
+    @property
+    def spatial_shape(self):
+        """if spatial shape is undefined, it infers the shape from array shape and original channel dim."""
+        if MetaKeys.SPATIAL_SHAPE not in self.meta:
+            _shape = list(self.array.shape)
+            channel_dim = self.meta.get(MetaKeys.ORIGINAL_CHANNEL_DIM, 0)
+            if _shape and channel_dim != "no_channel":
+                _shape.pop(int(channel_dim))
+        else:
+            _shape = self.meta.get(MetaKeys.SPATIAL_SHAPE)
+        if not isinstance(_shape, torch.Tensor):
+            self.meta[MetaKeys.SPATIAL_SHAPE] = convert_to_tensor(
+                _shape, device=torch.device("cpu"), wrap_sequence=True, track_meta=False
+            )
+        return self.meta[MetaKeys.SPATIAL_SHAPE]
+
+    @spatial_shape.setter
+    def spatial_shape(self, value):
+        self.meta[MetaKeys.SPATIAL_SHAPE] = convert_to_dst_type(value, self.spatial_shape, wrap_sequence=True)[0]

     @staticmethod
     def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py
index 1d60c34c3e..2a61169345 100644
--- a/monai/transforms/compose.py
+++ b/monai/transforms/compose.py
@@ -17,10 +17,12 @@

 import numpy as np

+import monai
 from monai.transforms.inverse import InvertibleTransform

 # For backwards compatibility (so this still works: from monai.transforms.compose import MapTransform)
 from monai.transforms.transform import (  # noqa: F401
+    LazyTransform,
     MapTransform,
     Randomizable,
     RandomizableTransform,
@@ -33,6 +35,28 @@
 __all__ = ["Compose", "OneOf"]


+def eval_lazy_stack(data, upcoming, lazy_resample: bool = False):
+    """
+    Given the upcoming transform ``upcoming``, if lazy_resample is True, go through the MetaTensors and
+    evaluate their pending lazy operations. The returned ``data`` will then be ready for the ``upcoming`` transform.
+    """
+    if not lazy_resample:
+        return data  # eager evaluation
+    if isinstance(data, monai.data.MetaTensor):
+        if lazy_resample and not isinstance(upcoming, LazyTransform):
+            data.evaluate()
+        return data
+    if isinstance(data, Mapping):
+        if isinstance(upcoming, MapTransform):
+            return {
+                k: eval_lazy_stack(v, upcoming, lazy_resample) if k in upcoming.keys else v for k, v in data.items()
+            }
+        return {k: eval_lazy_stack(v, upcoming, lazy_resample) for k, v in data.items()}
+    if isinstance(data, (list, tuple)):
+        return [eval_lazy_stack(v, upcoming, lazy_resample) for v in data]
+    return data
+
+
 class Compose(Randomizable, InvertibleTransform):
     """
     ``Compose`` provides the ability to chain a series of callables together in
@@ -110,6 +134,7 @@ class Compose(Randomizable, InvertibleTransform):
         log_stats: whether to log the detailed information of data and applied transform when error happened,
             for NumPy array and PyTorch Tensor, log the data shape and value range,
             for other metadata, log the values directly. default to `False`.
+        lazy_resample: whether to resample consecutive spatial transforms lazily. Defaults to False.
""" @@ -119,6 +144,7 @@ def __init__( map_items: bool = True, unpack_items: bool = False, log_stats: bool = False, + lazy_resample: bool = False, ) -> None: if transforms is None: transforms = [] @@ -126,8 +152,14 @@ def __init__( self.map_items = map_items self.unpack_items = unpack_items self.log_stats = log_stats + self.lazy_resample = lazy_resample self.set_random_state(seed=get_seed()) + if self.lazy_resample: + for t in self.flatten().transforms: # TODO: test Compose of Compose/OneOf + if isinstance(t, LazyTransform): + t.set_eager_mode(False) + def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> "Compose": super().set_random_state(seed=seed, state=state) for _transform in self.transforms: @@ -170,7 +202,9 @@ def __len__(self): def __call__(self, input_): for _transform in self.transforms: + input_ = eval_lazy_stack(input_, upcoming=_transform, lazy_resample=self.lazy_resample) input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items, self.log_stats) + input_ = eval_lazy_stack(input_, upcoming=None, lazy_resample=self.lazy_resample) return input_ def inverse(self, data): diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index bd9f73bbc9..3741499ad4 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -27,7 +27,7 @@ from monai.data.meta_tensor import MetaTensor from monai.data.utils import get_random_patch, get_valid_patch_size from monai.transforms.inverse import InvertibleTransform, TraceableTransform -from monai.transforms.transform import Randomizable, Transform +from monai.transforms.transform import LazyTransform, Randomizable, Transform from monai.transforms.utils import ( compute_divisible_spatial_size, convert_pad_mode, @@ -48,6 +48,7 @@ TransformBackends, convert_data_type, convert_to_dst_type, + convert_to_numpy, convert_to_tensor, ensure_tuple, ensure_tuple_rep, @@ -77,7 +78,7 @@ ] -class Pad(InvertibleTransform): +class Pad(InvertibleTransform, LazyTransform): """ Perform padding for a given an amount of padding in each dimension. 
@@ -137,6 +138,15 @@ def _pt_pad(img: torch.Tensor, pad_width, mode, **kwargs) -> torch.Tensor: # torch.pad expects `[B, C, H, W, [D]]` shape return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **kwargs).squeeze(0) + def lazy_call(self, img: torch.Tensor, to_pad) -> torch.Tensor: + if get_track_meta() and isinstance(img, MetaTensor): + img.evaluated = False + orig_size = img.spatial_shape + self.update_meta(img, to_pad=to_pad) + self.push_transform(img, orig_size=orig_size, extra_info={"padded": to_pad}) + return img + return img + def __call__( # type: ignore self, img: torch.Tensor, to_pad: Optional[List[Tuple[int, int]]] = None, mode: Optional[str] = None, **kwargs ) -> torch.Tensor: @@ -157,19 +167,27 @@ def __call__( # type: ignore """ to_pad_ = self.to_pad if to_pad is None else to_pad if to_pad_ is None: - to_pad_ = self.compute_pad_width(img.shape[1:]) + spatial_shape = convert_to_numpy( + img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:], + wrap_sequence=True, + ) + to_pad_ = self.compute_pad_width(spatial_shape) mode_ = self.mode if mode is None else mode kwargs_ = dict(self.kwargs) kwargs_.update(kwargs) img_t = convert_to_tensor(data=img, track_meta=get_track_meta()) - _orig_size = img_t.shape[1:] + _orig_size = img_t.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img_t.shape[1:] # all zeros, skip padding if np.asarray(to_pad_).any(): to_pad_ = list(to_pad_) if len(to_pad_) < len(img_t.shape): to_pad_ = list(to_pad_) + [(0, 0)] * (len(img_t.shape) - len(to_pad_)) + if not self.eager_mode: + return self.lazy_call(img_t, to_pad=to_pad_) + if not img_t.evaluated: + img_t.evaluate() if mode_ in {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"}: out = self._np_pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_) else: @@ -199,6 +217,8 @@ def update_meta(self, tensor: MetaTensor, to_pad: List[Tuple[int, int]]): to_shift = [-s[0] for s in to_pad[1:]] # skipping the channel pad mat = create_translate(spatial_rank, to_shift) tensor.affine = tensor.affine @ convert_to_dst_type(mat, tensor.affine)[0] + new_shape = [d + s + e for d, (s, e) in zip(tensor.spatial_shape, to_pad[1:])] + tensor.spatial_shape = new_shape def inverse(self, data: MetaTensor) -> MetaTensor: transform = self.pop_transform(data) @@ -361,7 +381,7 @@ def compute_pad_width(self, spatial_shape: Sequence[int]) -> List[Tuple[int, int return spatial_pad.compute_pad_width(spatial_shape) -class Crop(InvertibleTransform): +class Crop(InvertibleTransform, LazyTransform): """ Perform crop operation on the input image. @@ -421,36 +441,47 @@ def compute_slices( else: return [slice(int(s), int(e)) for s, e in zip(roi_start_t.tolist(), roi_end_t.tolist())] + def lazy_call(self, img: torch.Tensor, slices, cropped) -> torch.Tensor: + if get_track_meta() and isinstance(img, MetaTensor): + img.evaluated = False + orig_size = img.spatial_shape + self.update_meta(img, slices=slices, orig_size=orig_size) + self.push_transform(img, orig_size=orig_size, extra_info={"cropped": cropped}) + return img + return img + def __call__(self, img: torch.Tensor, slices: Tuple[slice, ...]) -> torch.Tensor: # type: ignore """ Apply the transform to `img`, assuming `img` is channel-first and slicing doesn't apply to the channel dim. 
""" - orig_size = img.shape[1:] + orig_size = img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:] slices_ = list(slices) - sd = len(img.shape[1:]) # spatial dims + sd = len(orig_size) # spatial dims if len(slices_) < sd: slices_ += [slice(None)] * (sd - len(slices_)) # Add in the channel (no cropping) slices = tuple([slice(None)] + slices_[:sd]) - + cropped_np = np.asarray([[s.indices(o)[0], o - s.indices(o)[1]] for s, o in zip(slices[1:], orig_size)]) + cropped = cropped_np.flatten().tolist() img_t: MetaTensor = convert_to_tensor(data=img, track_meta=get_track_meta()) - _orig_size = img_t.shape[1:] + if not self.eager_mode: + return self.lazy_call(img_t, slices, cropped) + if not img_t.evaluated: + img_t.evaluate() img_t = img_t[slices] # type: ignore if get_track_meta(): - self.update_meta(tensor=img_t, slices=slices) - cropped_from_start = np.asarray([s.indices(o)[0] for s, o in zip(slices[1:], orig_size)]) - cropped_from_end = np.asarray(orig_size) - img_t.shape[1:] - cropped_from_start - cropped = list(chain(*zip(cropped_from_start.tolist(), cropped_from_end.tolist()))) - self.push_transform(img_t, orig_size=_orig_size, extra_info={"cropped": cropped}) + self.update_meta(tensor=img_t, slices=slices, orig_size=orig_size) + self.push_transform(img_t, orig_size=orig_size, extra_info={"cropped": cropped}) return img_t - def update_meta(self, tensor: MetaTensor, slices: Tuple[slice, ...]): + def update_meta(self, tensor: MetaTensor, slices: Tuple[slice, ...], orig_size): spatial_rank = max(len(tensor.affine) - 1, 1) - to_shift = [s.start if s.start is not None else 0 for s in ensure_tuple(slices)[1:]] + to_shift = [s.start if s.start is not None else 0 for s in slices[1:]] mat = create_translate(spatial_rank, to_shift) tensor.affine = tensor.affine @ convert_to_dst_type(mat, tensor.affine)[0] + tensor.spatial_shape = [s.indices(o)[1] - s.indices(o)[0] for s, o in zip(slices[1:], orig_size)] def inverse(self, img: MetaTensor) -> MetaTensor: transform = self.pop_transform(img) @@ -526,6 +557,7 @@ def __init__(self, roi_size: Union[Sequence[int], int]) -> None: self.roi_size = roi_size def compute_slices(self, spatial_size: Sequence[int]): # type: ignore + spatial_size = convert_to_numpy(spatial_size, wrap_sequence=True) roi_size = fall_back_tuple(self.roi_size, spatial_size) roi_center = [i // 2 for i in spatial_size] return super().compute_slices(roi_center=roi_center, roi_size=roi_size) @@ -536,7 +568,12 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore slicing doesn't apply to the channel dim. 
""" - return super().__call__(img=img, slices=self.compute_slices(img.shape[1:])) + return super().__call__( + img=img, + slices=self.compute_slices( + img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:] + ), + ) class CenterScaleCrop(Crop): @@ -553,11 +590,16 @@ def __init__(self, roi_scale: Union[Sequence[float], float]): self.roi_scale = roi_scale def __call__(self, img: torch.Tensor) -> torch.Tensor: # type: ignore - img_size = img.shape[1:] + img_size = img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:] ndim = len(img_size) roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] cropper = CenterSpatialCrop(roi_size=roi_size) - return super().__call__(img=img, slices=cropper.compute_slices(img.shape[1:])) + return super().__call__( + img=img, + slices=cropper.compute_slices( + img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:] + ), + ) class RandSpatialCrop(Randomizable, Crop): @@ -616,13 +658,18 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: """ if randomize: - self.randomize(img.shape[1:]) + self.randomize(img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:]) if self._size is None: raise RuntimeError("self._size not specified.") if self.random_center: return super().__call__(img=img, slices=self._slices) cropper = CenterSpatialCrop(self._size) - return super().__call__(img=img, slices=cropper.compute_slices(img.shape[1:])) + return super().__call__( + img=img, + slices=cropper.compute_slices( + img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:] + ), + ) class RandScaleCrop(RandSpatialCrop): @@ -675,7 +722,7 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: slicing doesn't apply to the channel dim. """ - self.get_max_roi_size(img.shape[1:]) + self.get_max_roi_size(img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:]) return super().__call__(img=img, randomize=randomize) @@ -824,6 +871,10 @@ def __init__( self.k_divisible = k_divisible self.padder = Pad(mode=mode, **pad_kwargs) + def set_eager_mode(self, value): + super().set_eager_mode(True) + self.padder.set_eager_mode(True) + def compute_bounding_box(self, img: torch.Tensor): """ Compute the start points and end points of bounding box to crop. @@ -1264,7 +1315,7 @@ def __call__( return results -class ResizeWithPadOrCrop(InvertibleTransform): +class ResizeWithPadOrCrop(InvertibleTransform, LazyTransform): """ Resize an image to a target spatial size by either centrally cropping the image or padding it evenly with a user-specified mode. @@ -1299,6 +1350,11 @@ def __init__( self.padder = SpatialPad(spatial_size=spatial_size, method=method, mode=mode, **pad_kwargs) self.cropper = CenterSpatialCrop(roi_size=spatial_size) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.padder.set_eager_mode(value) + self.cropper.set_eager_mode(value) + def __call__(self, img: torch.Tensor, mode: Optional[str] = None, **pad_kwargs) -> torch.Tensor: # type: ignore """ Args: @@ -1314,7 +1370,7 @@ def __call__(self, img: torch.Tensor, mode: Optional[str] = None, **pad_kwargs) note that `np.pad` treats channel dimension as the first dimension. 
""" - orig_size = img.shape[1:] + orig_size = img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:] ret = self.padder(self.cropper(img), mode=mode, **pad_kwargs) # remove the individual info and combine if get_track_meta(): diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index bae6705c22..35aa51b354 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -44,7 +44,7 @@ SpatialPad, ) from monai.transforms.inverse import InvertibleTransform -from monai.transforms.transform import MapTransform, Randomizable +from monai.transforms.transform import LazyTransform, MapTransform, Randomizable from monai.transforms.utils import is_positive from monai.utils import MAX_SEED, Method, PytorchPadMode, ensure_tuple_rep from monai.utils.deprecate_utils import deprecated_arg @@ -107,7 +107,7 @@ ] -class Padd(MapTransform, InvertibleTransform): +class Padd(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Pad`. @@ -141,6 +141,11 @@ def __init__( self.padder = padder self.mode = ensure_tuple_rep(mode, len(self.keys)) + def set_eager_mode(self, value): + super().set_eager_mode(value) + if isinstance(self.padder, LazyTransform): + self.padder.set_eager_mode(value) + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for key, m in self.key_iterator(d, self.mode): @@ -288,7 +293,7 @@ def __init__( super().__init__(keys, padder=padder, mode=mode, allow_missing_keys=allow_missing_keys) -class Cropd(MapTransform, InvertibleTransform): +class Cropd(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of abstract class :py:class:`monai.transforms.Crop`. 
@@ -306,6 +311,11 @@ def __init__(self, keys: KeysCollection, cropper: Crop, allow_missing_keys: bool
         super().__init__(keys, allow_missing_keys)
         self.cropper = cropper

+    def set_eager_mode(self, value):
+        super().set_eager_mode(value)
+        if isinstance(self.cropper, LazyTransform):
+            self.cropper.set_eager_mode(value)
+
     def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
         d = dict(data)
         for key in self.key_iterator(d):
@@ -351,7 +361,8 @@ def randomize(self, img_size: Sequence[int]) -> None:
     def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
         d = dict(data)
         # the first key must exist to execute random operations
-        self.randomize(d[self.first_key(d)].shape[1:])
+        first_item = d[self.first_key(d)]
+        self.randomize(first_item.spatial_shape if isinstance(first_item, MetaTensor) else first_item.shape[1:])
         for key in self.key_iterator(d):
             kwargs = {"randomize": False} if isinstance(self.cropper, Randomizable) else {}
             d[key] = self.cropper(d[key], **kwargs)  # type: ignore
diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py
index 730cb634c0..d61f7aa727 100644
--- a/monai/transforms/transform.py
+++ b/monai/transforms/transform.py
@@ -25,7 +25,15 @@
 from monai.utils import MAX_SEED, ensure_tuple, first
 from monai.utils.enums import TransformBackends

-__all__ = ["ThreadUnsafe", "apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"]
+__all__ = [
+    "ThreadUnsafe",
+    "apply_transform",
+    "Randomizable",
+    "RandomizableTransform",
+    "Transform",
+    "MapTransform",
+    "LazyTransform",
+]

 ReturnType = TypeVar("ReturnType")

@@ -127,6 +135,15 @@ class ThreadUnsafe:
     pass


+class LazyTransform:
+    """Marks a transform that can accept lazy MetaTensors (``MetaTensor.evaluated`` is False) and can be evaluated lazily."""
+
+    eager_mode = True
+
+    def set_eager_mode(self, value):
+        self.eager_mode = value
+
+
 class Randomizable(ABC, ThreadUnsafe):
     """
     An interface for handling random state locally, currently based on a class
diff --git a/monai/utils/enums.py b/monai/utils/enums.py
index a6d9a23309..cb2241ba23 100644
--- a/monai/utils/enums.py
+++ b/monai/utils/enums.py
@@ -453,3 +453,4 @@ class MetaKeys(StrEnum):
     SPATIAL_SHAPE = "spatial_shape"  # optional key for the length in each spatial dimension
     SPACE = "space"  # possible values of space type are defined in `SpaceKeys`
     ORIGINAL_CHANNEL_DIM = "original_channel_dim"  # an integer or "no_channel"
+    EVALUATED = "evaluated"  # whether the array is up-to-date with the applied_operations (lazy evaluation)

From be9b48ff775f5f28d1b14b750fcec461d5c70c4f Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Sun, 14 Aug 2022 12:31:03 +0100
Subject: [PATCH 03/10] update spacing and orientation

Signed-off-by: Wenqi Li
---
 monai/transforms/compose.py            |  2 +-
 monai/transforms/croppad/array.py      | 23 +++++---------
 monai/transforms/croppad/dictionary.py |  6 +++-
 monai/transforms/spatial/array.py      | 44 ++++++++++++++++++++------
 monai/transforms/spatial/dictionary.py | 26 ++++++++++++---
 5 files changed, 70 insertions(+), 31 deletions(-)

diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py
index 2a61169345..9fe4f6421e 100644
--- a/monai/transforms/compose.py
+++ b/monai/transforms/compose.py
@@ -44,7 +44,7 @@ def eval_lazy_stack(data, upcoming, lazy_resample: bool = False):
         return data  # eager evaluation
     if isinstance(data, monai.data.MetaTensor):
         if lazy_resample and not isinstance(upcoming, LazyTransform):
-            data.evaluate()
+            data.evaluate("nearest")
return data if isinstance(data, Mapping): if isinstance(upcoming, MapTransform): diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 3741499ad4..587cbe4188 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -141,9 +141,9 @@ def _pt_pad(img: torch.Tensor, pad_width, mode, **kwargs) -> torch.Tensor: def lazy_call(self, img: torch.Tensor, to_pad) -> torch.Tensor: if get_track_meta() and isinstance(img, MetaTensor): img.evaluated = False - orig_size = img.spatial_shape self.update_meta(img, to_pad=to_pad) - self.push_transform(img, orig_size=orig_size, extra_info={"padded": to_pad}) + self.push_transform(img, orig_size=img.spatial_shape, extra_info={"padded": to_pad}) + img.spatial_shape = [d + s + e for d, (s, e) in zip(img.spatial_shape, to_pad[1:])] return img return img @@ -186,8 +186,6 @@ def __call__( # type: ignore to_pad_ = list(to_pad_) + [(0, 0)] * (len(img_t.shape) - len(to_pad_)) if not self.eager_mode: return self.lazy_call(img_t, to_pad=to_pad_) - if not img_t.evaluated: - img_t.evaluate() if mode_ in {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"}: out = self._np_pad(img_t, pad_width=to_pad_, mode=mode_, **kwargs_) else: @@ -217,8 +215,6 @@ def update_meta(self, tensor: MetaTensor, to_pad: List[Tuple[int, int]]): to_shift = [-s[0] for s in to_pad[1:]] # skipping the channel pad mat = create_translate(spatial_rank, to_shift) tensor.affine = tensor.affine @ convert_to_dst_type(mat, tensor.affine)[0] - new_shape = [d + s + e for d, (s, e) in zip(tensor.spatial_shape, to_pad[1:])] - tensor.spatial_shape = new_shape def inverse(self, data: MetaTensor) -> MetaTensor: transform = self.pop_transform(data) @@ -444,9 +440,9 @@ def compute_slices( def lazy_call(self, img: torch.Tensor, slices, cropped) -> torch.Tensor: if get_track_meta() and isinstance(img, MetaTensor): img.evaluated = False - orig_size = img.spatial_shape - self.update_meta(img, slices=slices, orig_size=orig_size) - self.push_transform(img, orig_size=orig_size, extra_info={"cropped": cropped}) + self.update_meta(img, slices=slices) + self.push_transform(img, orig_size=img.spatial_shape, extra_info={"cropped": cropped}) + img.spatial_shape = [s.indices(o)[1] - s.indices(o)[0] for s, o in zip(slices[1:], img.spatial_shape)] return img return img @@ -468,20 +464,17 @@ def __call__(self, img: torch.Tensor, slices: Tuple[slice, ...]) -> torch.Tensor img_t: MetaTensor = convert_to_tensor(data=img, track_meta=get_track_meta()) if not self.eager_mode: return self.lazy_call(img_t, slices, cropped) - if not img_t.evaluated: - img_t.evaluate() img_t = img_t[slices] # type: ignore if get_track_meta(): - self.update_meta(tensor=img_t, slices=slices, orig_size=orig_size) + self.update_meta(tensor=img_t, slices=slices) self.push_transform(img_t, orig_size=orig_size, extra_info={"cropped": cropped}) return img_t - def update_meta(self, tensor: MetaTensor, slices: Tuple[slice, ...], orig_size): + def update_meta(self, tensor: MetaTensor, slices: Tuple[slice, ...]): spatial_rank = max(len(tensor.affine) - 1, 1) - to_shift = [s.start if s.start is not None else 0 for s in slices[1:]] + to_shift = [s.start if s.start is not None else 0 for s in ensure_tuple(slices)[1:]] mat = create_translate(spatial_rank, to_shift) tensor.affine = tensor.affine @ convert_to_dst_type(mat, tensor.affine)[0] - tensor.spatial_shape = [s.indices(o)[1] - s.indices(o)[0] for s, o in zip(slices[1:], orig_size)] def inverse(self, img: MetaTensor) -> MetaTensor: 
transform = self.pop_transform(img) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 35aa51b354..e7a427c83b 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -362,7 +362,11 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc d = dict(data) # the first key must exist to execute random operations first_item = d[self.first_key(d)] - self.randomize(first_item.spatial_shape if isinstance(first_item, MetaTensor) else first_item.shape[1:]) + self.randomize( + first_item.spatial_shape + if isinstance(first_item, MetaTensor) and not first_item.evaluated + else first_item.shape[1:] + ) for key in self.key_iterator(d): kwargs = {"randomize": False} if isinstance(self.cropper, Randomizable) else {} d[key] = self.cropper(d[key], **kwargs) # type: ignore diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 14da37300a..c54ad98c9b 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -30,7 +30,7 @@ from monai.transforms.croppad.array import CenterSpatialCrop, ResizeWithPadOrCrop from monai.transforms.intensity.array import GaussianSmooth from monai.transforms.inverse import InvertibleTransform -from monai.transforms.transform import Randomizable, RandomizableTransform, Transform +from monai.transforms.transform import LazyTransform, Randomizable, RandomizableTransform, Transform from monai.transforms.utils import ( convert_pad_mode, create_control_grid, @@ -99,7 +99,7 @@ RandRange = Optional[Union[Sequence[Union[Tuple[float, float], float]], float]] -class SpatialResample(InvertibleTransform): +class SpatialResample(InvertibleTransform, LazyTransform): """ Resample input image from the orientation/spacing defined by ``src_affine`` affine matrix into the ones specified by ``dst_affine`` affine matrix. @@ -176,6 +176,13 @@ def _post_process( def update_meta(self, img, dst_affine): img.affine = dst_affine + def lazy_call(self, img: torch.Tensor, output_shape, *args) -> torch.Tensor: + if get_track_meta() and isinstance(img, MetaTensor): + img.evaluated = False + img.spatial_shape = output_shape + return self._post_process(img, *args) + return img + @deprecated_arg( name="src_affine", since="0.9", msg_suffix="img should be `MetaTensor`, so affine can be extracted directly." ) @@ -248,6 +255,11 @@ def __call__( spatial_size, _ = compute_shape_offset(in_spatial_size, src_affine_, dst_affine) # type: ignore spatial_size = torch.tensor(fall_back_tuple(ensure_tuple(spatial_size)[:spatial_rank], in_spatial_size)) + if not self.eager_mode: + return self.lazy_call( + img, spatial_size, src_affine_, dst_affine, mode, padding_mode, align_corners, original_spatial_shape + ) + if ( allclose(src_affine_, dst_affine, atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size) @@ -407,7 +419,7 @@ def __call__( return img -class Spacing(InvertibleTransform): +class Spacing(InvertibleTransform, LazyTransform): """ Resample input image into the specified `pixdim`. 
""" @@ -471,6 +483,10 @@ def __init__( dtype=dtype, ) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.sp_resample.set_eager_mode(value) + @deprecated_arg(name="affine", since="0.9", msg_suffix="Not needed, input should be `MetaTensor`.") def __call__( self, @@ -555,7 +571,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return self.sp_resample.inverse(data) -class Orientation(InvertibleTransform): +class Orientation(InvertibleTransform, LazyTransform): """ Change the input image's orientation into the specified based on `axcodes`. """ @@ -596,6 +612,15 @@ def __init__( self.as_closest_canonical = as_closest_canonical self.labels = labels + def lazy_call(self, img: torch.Tensor, new_affine, original_affine, ordering) -> torch.Tensor: + if get_track_meta() and isinstance(img, MetaTensor): + img.evaluated = False + self.update_meta(img, new_affine) + self.push_transform(img, extra_info={"original_affine": original_affine}) + img.spatial_shape = img.spatial_shape[[i - 1 for i in ordering if i != 0]] + return img + return img + def __call__(self, data_array: torch.Tensor) -> torch.Tensor: """ If input type is `MetaTensor`, original affine is extracted with `data_array.affine`. @@ -655,15 +680,16 @@ def __call__(self, data_array: torch.Tensor) -> torch.Tensor: spatial_ornt[:, 0] += 1 # skip channel dim spatial_ornt = np.concatenate([np.array([[0, 1]]), spatial_ornt]) axes = [ax for ax, flip in enumerate(spatial_ornt[:, 1]) if flip == -1] - if axes: - data_array = torch.flip(data_array, dims=axes) full_transpose = np.arange(len(data_array.shape)) full_transpose[: len(spatial_ornt)] = np.argsort(spatial_ornt[:, 0]) - if not np.all(full_transpose == np.arange(len(data_array.shape))): - data_array = data_array.permute(full_transpose.tolist()) - new_affine = to_affine_nd(affine_np, new_affine) new_affine, *_ = convert_data_type(new_affine, torch.Tensor, dtype=torch.float32, device=data_array.device) + if not self.eager_mode: + return self.lazy_call(data_array, new_affine, affine_np, full_transpose) + if axes: + data_array = torch.flip(data_array, dims=axes) + if not np.all(full_transpose == np.arange(len(data_array.shape))): + data_array = data_array.permute(full_transpose.tolist()) if get_track_meta(): self.update_meta(data_array, new_affine) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 493369d258..754a565493 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -51,7 +51,7 @@ SpatialResample, Zoom, ) -from monai.transforms.transform import MapTransform, RandomizableTransform +from monai.transforms.transform import LazyTransform, MapTransform, RandomizableTransform from monai.transforms.utils import create_grid from monai.utils import ( GridSampleMode, @@ -142,7 +142,7 @@ ] -class SpatialResampled(MapTransform, InvertibleTransform): +class SpatialResampled(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.SpatialResample`. 
@@ -204,6 +204,10 @@ def __init__( self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.dst_keys = ensure_tuple_rep(dst_keys, len(self.keys)) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.sp_transform.set_eager_mode(value) + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d: Dict = dict(data) for (key, mode, padding_mode, align_corners, dtype, dst_key) in self.key_iterator( @@ -227,7 +231,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class ResampleToMatchd(MapTransform, InvertibleTransform): +class ResampleToMatchd(MapTransform, InvertibleTransform, LazyTransform): """Dictionary-based wrapper of :py:class:`monai.transforms.ResampleToMatch`.""" backend = ResampleToMatch.backend @@ -273,6 +277,10 @@ def __init__( self.dtype = ensure_tuple_rep(dtype, len(self.keys)) self.resampler = ResampleToMatch() + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.resampler.set_eager_mode(value) + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for (key, mode, padding_mode, align_corners, dtype) in self.key_iterator( @@ -295,7 +303,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class Spacingd(MapTransform, InvertibleTransform): +class Spacingd(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Spacing`. @@ -372,6 +380,10 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.spacing_transform.set_eager_mode(value) + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d: Dict = dict(data) for key, mode, padding_mode, align_corners, dtype in self.key_iterator( @@ -390,7 +402,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd return d -class Orientationd(MapTransform, InvertibleTransform): +class Orientationd(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Orientation`. 
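With the dictionary wrappers forwarding eager mode as above, a lazy dict pipeline might look as follows (a sketch assuming this patch series). Note that ``eval_lazy_stack`` evaluates only the keys an upcoming non-lazy ``MapTransform`` actually reads; here both transforms are lazy, so each entry is resampled once at the end:

import numpy as np

from monai.data import MetaTensor
from monai.transforms import Compose, Orientationd, Spacingd

keys = ["img", "seg"]
pipe = Compose(
    [Orientationd(keys=keys, axcodes="RAS"), Spacingd(keys=keys, pixdim=(1.0, 1.0, 1.0))],
    lazy_resample=True,
)
data = {k: MetaTensor(np.zeros((1, 16, 16, 16), dtype=np.float32)) for k in keys}
out = pipe(data)  # pending operations evaluated once per key, at the end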
@@ -433,6 +445,10 @@ def __init__( super().__init__(keys, allow_missing_keys) self.ornt_transform = Orientation(axcodes=axcodes, as_closest_canonical=as_closest_canonical, labels=labels) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.ornt_transform.set_eager_mode(value) + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d: Dict = dict(data) for key in self.key_iterator(d): From fb943c3006c93a0d4b580379e2281411b8948023 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 14 Aug 2022 14:22:11 +0100 Subject: [PATCH 04/10] reviewing spatial transforms Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 30 +++++++++++++++---- monai/transforms/spatial/dictionary.py | 12 ++++++-- .../utils_pytorch_numpy_unification.py | 4 +-- 3 files changed, 36 insertions(+), 10 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index c54ad98c9b..248c7b092b 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -712,7 +712,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return data -class Flip(InvertibleTransform): +class Flip(InvertibleTransform, LazyTransform): """ Reverses the order of elements along the given spatial axis. Preserves shape. See `torch.flip` documentation for additional details: @@ -744,6 +744,15 @@ def update_meta(self, img, shape, axes): def forward_image(self, img, axes) -> torch.Tensor: return torch.flip(img, axes) + def lazy_call(self, img: torch.Tensor, axes) -> torch.Tensor: + if get_track_meta() and isinstance(img, MetaTensor): + img.evaluated = False + self.update_meta(img, img.shape, axes) + self.push_transform(img) + img.spatial_shape = img.spatial_shape + return img + return img + def __call__(self, img: torch.Tensor) -> torch.Tensor: """ Args: @@ -751,6 +760,8 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: """ img = convert_to_tensor(img, track_meta=get_track_meta()) axes = map_spatial_axes(img.ndim, self.spatial_axis) + if not self.eager_mode: + return self.lazy_call(img, axes) out = self.forward_image(img, axes) if get_track_meta(): self.update_meta(out, out.shape, axes) @@ -764,7 +775,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return flipper(data) -class Resize(InvertibleTransform): +class Resize(InvertibleTransform, LazyTransform): """ Resize the input image to given spatial size (with scaling, not cropping/padding). Implemented using :py:class:`torch.nn.functional.interpolate`. @@ -860,15 +871,16 @@ def __call__( "len(spatial_size) must be greater or equal to img spatial dimensions, " f"got spatial_size={output_ndim} img={input_ndim}." 
) - spatial_size_ = fall_back_tuple(self.spatial_size, img.shape[1:]) + img_size = img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:] + spatial_size_ = fall_back_tuple(self.spatial_size, img_size) else: # for the "longest" mode - img_size = img.shape[1:] + img_size = img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:] if not isinstance(self.spatial_size, int): raise ValueError("spatial_size must be an int number if size_mode is 'longest'.") scale = self.spatial_size / max(img_size) spatial_size_ = tuple(int(round(s * scale)) for s in img_size) - original_sp_size = img.shape[1:] + original_sp_size = img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:] _mode = look_up_option(self.mode if mode is None else mode, InterpolateMode) _align_corners = self.align_corners if align_corners is None else align_corners if tuple(img.shape[1:]) == spatial_size_: # spatial shape is already the desired @@ -877,6 +889,8 @@ def __call__( img_ = convert_to_tensor(img, dtype=torch.float, track_meta=False) if anti_aliasing and any(x < y for x, y in zip(spatial_size_, img_.shape[1:])): + if not self.eager_mode: + raise ValueError("anti aliasing is not compatible with lazy evaluation.") factors = torch.div(torch.Tensor(list(img_.shape[1:])), torch.Tensor(spatial_size_)) if anti_aliasing_sigma is None: # if sigma is not given, use the default sigma in skimage.transform.resize @@ -1478,7 +1492,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return Rotate(0).inverse_transform(data, xform_info[TraceKeys.EXTRA_INFO]) -class RandFlip(RandomizableTransform, InvertibleTransform): +class RandFlip(RandomizableTransform, InvertibleTransform, LazyTransform): """ Randomly flips the image along axes. Preserves shape. See numpy.flip for additional details. @@ -1495,6 +1509,10 @@ def __init__(self, prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int] RandomizableTransform.__init__(self, prob) self.flipper = Flip(spatial_axis=spatial_axis) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.flipper.set_eager_mode(value) + def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: """ Args: diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 754a565493..22b5644dc5 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -1128,7 +1128,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc return d -class Flipd(MapTransform, InvertibleTransform): +class Flipd(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Flip`. @@ -1152,6 +1152,10 @@ def __init__( super().__init__(keys, allow_missing_keys) self.flipper = Flip(spatial_axis=spatial_axis) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.flipper.set_eager_mode(value) + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): @@ -1165,7 +1169,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class RandFlipd(RandomizableTransform, MapTransform, InvertibleTransform): +class RandFlipd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based version :py:class:`monai.transforms.RandFlip`. 
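Since a flip changes neither the spatial shape nor the array content until evaluation, the lazy path above only updates the affine and pushes the transform record (hence the ``img.spatial_shape = img.spatial_shape`` no-op). A sketch of the expected behaviour, assuming this patch:

import numpy as np

from monai.data import MetaTensor
from monai.transforms import Flip

flip = Flip(spatial_axis=0)
flip.set_eager_mode(False)  # record only; torch.flip is deferred
img = MetaTensor(np.arange(16, dtype=np.float32).reshape(1, 4, 4))
out = flip(img)             # affine updated, data left untouched
assert not out.evaluated    # pending until evaluate()/eval_lazy_stack runs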
@@ -1192,6 +1196,10 @@ def __init__( RandomizableTransform.__init__(self, prob) self.flipper = Flip(spatial_axis=spatial_axis) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.flipper.set_eager_mode(value) + def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None ) -> "RandFlipd": diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index f6f67d4c7d..55c2c4de89 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -285,7 +285,7 @@ def cumsum(a: NdarrayOrTensor, axis=None, **kwargs) -> NdarrayOrTensor: def isfinite(x: NdarrayOrTensor) -> NdarrayOrTensor: """`np.isfinite` with equivalent implementation for torch.""" if not isinstance(x, torch.Tensor): - return np.isfinite(x) + return np.isfinite(x) # type: ignore return torch.isfinite(x) @@ -333,7 +333,7 @@ def isnan(x: NdarrayOrTensor) -> NdarrayOrTensor: """ if isinstance(x, np.ndarray): - return np.isnan(x) + return np.isnan(x) # type: ignore return torch.isnan(x) From 58caa93d4b04aeaec6e6719404589526b32d2a92 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 14 Aug 2022 23:35:14 +0100 Subject: [PATCH 05/10] update spatial xforms Signed-off-by: Wenqi Li --- monai/transforms/compose.py | 2 +- monai/transforms/croppad/array.py | 2 - monai/transforms/spatial/array.py | 193 ++++++++++++++++++++----- monai/transforms/spatial/dictionary.py | 42 +++++- 4 files changed, 188 insertions(+), 51 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 9fe4f6421e..8af02d4f2d 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -44,7 +44,7 @@ def eval_lazy_stack(data, upcoming, lazy_resample: bool = False): return data # eager evaluation if isinstance(data, monai.data.MetaTensor): if lazy_resample and not isinstance(upcoming, LazyTransform): - data.evaluate("nearest") + data.evaluate("bilinear") return data if isinstance(data, Mapping): if isinstance(upcoming, MapTransform): diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index 587cbe4188..bd0048ffb7 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -144,7 +144,6 @@ def lazy_call(self, img: torch.Tensor, to_pad) -> torch.Tensor: self.update_meta(img, to_pad=to_pad) self.push_transform(img, orig_size=img.spatial_shape, extra_info={"padded": to_pad}) img.spatial_shape = [d + s + e for d, (s, e) in zip(img.spatial_shape, to_pad[1:])] - return img return img def __call__( # type: ignore @@ -443,7 +442,6 @@ def lazy_call(self, img: torch.Tensor, slices, cropped) -> torch.Tensor: self.update_meta(img, slices=slices) self.push_transform(img, orig_size=img.spatial_shape, extra_info={"cropped": cropped}) img.spatial_shape = [s.indices(o)[1] - s.indices(o)[0] for s, o in zip(slices[1:], img.spatial_shape)] - return img return img def __call__(self, img: torch.Tensor, slices: Tuple[slice, ...]) -> torch.Tensor: # type: ignore diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 248c7b092b..567ffea1ec 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -12,6 +12,7 @@ A collection of "vanilla" transforms for spatial operations https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ +import math import warnings from copy import deepcopy from enum import Enum @@ -62,7 +63,12 @@ from monai.utils.enums 
import GridPatchSort, PytorchPadMode, TraceKeys, TransformBackends from monai.utils.misc import ImageMetaKey as Key from monai.utils.module import look_up_option -from monai.utils.type_conversion import convert_data_type, get_equivalent_dtype, get_torch_dtype_from_string +from monai.utils.type_conversion import ( + convert_data_type, + convert_to_numpy, + get_equivalent_dtype, + get_torch_dtype_from_string, +) nib, has_nib = optional_import("nibabel") @@ -618,7 +624,6 @@ def lazy_call(self, img: torch.Tensor, new_affine, original_affine, ordering) -> self.update_meta(img, new_affine) self.push_transform(img, extra_info={"original_affine": original_affine}) img.spatial_shape = img.spatial_shape[[i - 1 for i in ordering if i != 0]] - return img return img def __call__(self, data_array: torch.Tensor) -> torch.Tensor: @@ -747,10 +752,10 @@ def forward_image(self, img, axes) -> torch.Tensor: def lazy_call(self, img: torch.Tensor, axes) -> torch.Tensor: if get_track_meta() and isinstance(img, MetaTensor): img.evaluated = False - self.update_meta(img, img.shape, axes) + spatial_chn_shape = [1, *convert_to_numpy(img.spatial_shape).tolist()] + self.update_meta(img, spatial_chn_shape, axes) self.push_transform(img) img.spatial_shape = img.spatial_shape - return img return img def __call__(self, img: torch.Tensor) -> torch.Tensor: @@ -880,17 +885,19 @@ def __call__( scale = self.spatial_size / max(img_size) spatial_size_ = tuple(int(round(s * scale)) for s in img_size) - original_sp_size = img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:] _mode = look_up_option(self.mode if mode is None else mode, InterpolateMode) _align_corners = self.align_corners if align_corners is None else align_corners - if tuple(img.shape[1:]) == spatial_size_: # spatial shape is already the desired - img = convert_to_tensor(img, track_meta=get_track_meta()) # type: ignore + img = convert_to_tensor(img, track_meta=get_track_meta()) # type: ignore + if not self.eager_mode: + if anti_aliasing: + raise ValueError("anti-aliasing is not compatible with lazy evaluation.") + return self.lazy_call(img, spatial_size_, _mode, _align_corners, input_ndim) + original_sp_size = img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:] + if tuple(convert_to_numpy(original_sp_size)) == spatial_size_: # spatial shape is already the desired return self._post_process(img, original_sp_size, spatial_size_, _mode, _align_corners, input_ndim) img_ = convert_to_tensor(img, dtype=torch.float, track_meta=False) if anti_aliasing and any(x < y for x, y in zip(spatial_size_, img_.shape[1:])): - if not self.eager_mode: - raise ValueError("anti aliasing is not compatible with lazy evaluation.") factors = torch.div(torch.Tensor(list(img_.shape[1:])), torch.Tensor(spatial_size_)) if anti_aliasing_sigma is None: # if sigma is not given, use the default sigma in skimage.transform.resize @@ -903,7 +910,6 @@ def __call__( anti_aliasing_filter = GaussianSmooth(sigma=anti_aliasing_sigma) img_ = convert_to_tensor(anti_aliasing_filter(img_), track_meta=False) - img = convert_to_tensor(img, track_meta=get_track_meta()) resized = torch.nn.functional.interpolate( input=img_.unsqueeze(0), size=spatial_size_, mode=_mode, align_corners=_align_corners ) @@ -924,6 +930,13 @@ def _post_process(self, img: torch.Tensor, orig_size, sp_size, mode, align_corne ) return img + def lazy_call(self, img: torch.Tensor, sp_size, *args) -> torch.Tensor: + if get_track_meta() and isinstance(img, MetaTensor): + 
img.evaluated = False + img = self._post_process(img, img.spatial_shape, sp_size, *args) + img.spatial_shape = sp_size + return img + def update_meta(self, img, spatial_size, new_spatial_size): affine = convert_to_tensor(img.affine, track_meta=False) img.affine = scale_affine(affine, spatial_size, new_spatial_size) @@ -946,7 +959,7 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: return data -class Rotate(InvertibleTransform): +class Rotate(InvertibleTransform, LazyTransform): """ Rotates an input image by given angle using :py:class:`monai.networks.layers.AffineTransform`. @@ -1018,7 +1031,7 @@ def __call__( img = convert_to_tensor(img, track_meta=get_track_meta()) _dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor) - im_shape = np.asarray(img.shape[1:]) # spatial dimensions + im_shape = np.asarray(img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:]) input_ndim = len(im_shape) if input_ndim not in (2, 3): raise ValueError(f"Unsupported image dimension: {input_ndim}, available options are [2, 3].") @@ -1041,6 +1054,8 @@ def __call__( _mode = look_up_option(mode or self.mode, GridSampleMode) _padding_mode = look_up_option(padding_mode or self.padding_mode, GridSamplePadMode) _align_corners = self.align_corners if align_corners is None else align_corners + if not self.eager_mode: + return self.lazy_call(img, output_shape, transform_t, _mode, _padding_mode, _align_corners, _dtype) xform = AffineTransform( normalized=False, mode=_mode, @@ -1050,20 +1065,32 @@ def __call__( ) output: torch.Tensor = xform(img_t.unsqueeze(0), transform_t, spatial_size=output_shape).float().squeeze(0) out, *_ = convert_to_dst_type(output, dst=img, dtype=output.dtype) + return self._post_process(out, im_shape, transform_t, _mode, _padding_mode, _align_corners, _dtype) + + def _post_process( + self, img: torch.Tensor, orig_size, transform_t, mode, padding_mode, align_corners, dtype + ) -> torch.Tensor: if get_track_meta(): - self.update_meta(out, transform_t) + self.update_meta(img, transform_t) self.push_transform( - out, - orig_size=img_t.shape[1:], + img, + orig_size=orig_size, extra_info={ - "rot_mat": transform, - "mode": _mode, - "padding_mode": _padding_mode, - "align_corners": _align_corners if _align_corners is not None else TraceKeys.NONE, - "dtype": str(_dtype)[6:], # dtype as string; remove "torch": torch.float32 -> float32 + "rot_mat": transform_t, + "mode": mode, + "padding_mode": padding_mode, + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, + "dtype": str(dtype)[6:], # dtype as string; remove "torch": torch.float32 -> float32 }, ) - return out + return img + + def lazy_call(self, img, output_shape, transform_t, *args): + if get_track_meta() and isinstance(img, MetaTensor): + img.evaluated = False + img = self._post_process(img, img.spatial_shape, transform_t, *args) + img.spatial_shape = output_shape + return img def update_meta(self, img, rotate_mat): affine = convert_to_tensor(img.affine, track_meta=False) @@ -1099,7 +1126,7 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: return out -class Zoom(InvertibleTransform): +class Zoom(InvertibleTransform, LazyTransform): """ Zooms an ND image using :py:class:`torch.nn.functional.interpolate`. For details, please see https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html. 
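The lazy ``Rotate`` above computes the rotation matrix and the output shape exactly as the eager path does, but ``lazy_call`` only records them on the ``MetaTensor``. A sketch, assuming this patch:

import math

import numpy as np

from monai.data import MetaTensor
from monai.transforms import Rotate

rot = Rotate(angle=math.pi / 4, keep_size=False)
rot.set_eager_mode(False)
img = MetaTensor(np.zeros((1, 10, 10), dtype=np.float32))
out = rot(img)            # rotation matrix pushed; no interpolation yet
print(out.spatial_shape)  # grown to the bounding box of the rotated image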
@@ -1182,6 +1209,13 @@ def __call__( _align_corners = self.align_corners if align_corners is None else align_corners _padding_mode = padding_mode or self.padding_mode + if not self.eager_mode: + if self.keep_size: + raise NotImplementedError("keep_size=True is currently not compatible with lazy evaluation") + else: + output_size = [int(math.floor(float(i) * z)) for i, z in zip(img.spatial_shape, _zoom)] + return self.lazy_call(img, output_size, _mode, _align_corners) + zoomed: NdarrayOrTensor = torch.nn.functional.interpolate( recompute_scale_factor=True, input=img_t.unsqueeze(0), @@ -1213,6 +1247,23 @@ def __call__( ) return out + def lazy_call(self, img, zoom_size, mode, align_corners): + if get_track_meta() and isinstance(img, MetaTensor): + img.evaluated = False + self.update_meta(img, img.spatial_shape, zoom_size) + self.push_transform( + img, + orig_size=img.spatial_shape, + extra_info={ + "mode": mode, + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, + "do_padcrop": False, + "padcrop": {}, + }, + ) + img.spatial_shape = zoom_size + return img + def update_meta(self, img, spatial_size, new_spatial_size): affine = convert_to_tensor(img.affine, track_meta=False) img.affine = scale_affine(affine, spatial_size, new_spatial_size) @@ -1367,7 +1418,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return Rotate90().inverse_transform(data, rotate_xform) -class RandRotate(RandomizableTransform, InvertibleTransform): +class RandRotate(RandomizableTransform, InvertibleTransform, LazyTransform): """ Randomly rotate the input arrays. @@ -1477,6 +1528,7 @@ def __call__( align_corners=self.align_corners if align_corners is None else align_corners, dtype=dtype or self.dtype or img.dtype, ) + rotator.set_eager_mode(self.eager_mode) out = rotator(img) else: out = convert_to_tensor(img, track_meta=get_track_meta()) @@ -1536,7 +1588,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return self.flipper.inverse(data) -class RandAxisFlip(RandomizableTransform, InvertibleTransform): +class RandAxisFlip(RandomizableTransform, InvertibleTransform, LazyTransform): """ Randomly select a spatial axis and flip along it. See numpy.flip for additional details. @@ -1554,6 +1606,10 @@ def __init__(self, prob: float = 0.1) -> None: self._axis: Optional[int] = None self.flipper = Flip(spatial_axis=self._axis) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.flipper.set_eager_mode(value) + def randomize(self, data: NdarrayOrTensor) -> None: super().randomize(None) if not self._do_transform: @@ -1589,7 +1645,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return flipper(data) -class RandZoom(RandomizableTransform, InvertibleTransform): +class RandZoom(RandomizableTransform, InvertibleTransform, LazyTransform): """ Randomly zooms input arrays with given probability within given zoom range. 
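The lazy ``Zoom`` above derives the pending output size as ``floor(size * zoom)`` per axis (``keep_size=True`` is rejected because the implied pad/crop is not supported lazily yet). A worked example of that rule:

import math

spatial_shape, zoom = (100, 64), (0.9, 1.25)
output_size = [int(math.floor(float(i) * z)) for i, z in zip(spatial_shape, zoom)]
assert output_size == [90, 80]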
@@ -1696,14 +1752,16 @@ def __call__( if not self._do_transform: out = convert_to_tensor(img, track_meta=get_track_meta()) else: - out = Zoom( + xform = Zoom( self._zoom, keep_size=self.keep_size, mode=look_up_option(mode or self.mode, InterpolateMode), padding_mode=padding_mode or self.padding_mode, align_corners=self.align_corners if align_corners is None else align_corners, **self.kwargs, - )(img) + ) + xform.set_eager_mode(self.eager_mode) + out = xform(img) if get_track_meta(): z_info = self.pop_transform(out, check=False) if self._do_transform else {} self.push_transform(out, extra_info=z_info) @@ -1741,8 +1799,8 @@ class AffineGrid(Transform): If ``None``, use the data type of input data (if `grid` is provided). device: device on which the tensor will be allocated, if a new grid is generated. affine: If applied, ignore the params (`rotate_params`, etc.) and use the - supplied matrix. Should be square with each side = num of image spatial - dimensions + 1. + supplied matrix. Should be square with each side = num of image spatial dimensions + 1. + affine_only: whether to return an affine matrix without computing the actual grid. Defaults to False. .. deprecated:: 0.6.0 ``as_tensor_output`` is deprecated. @@ -1762,6 +1820,7 @@ def __init__( device: Optional[torch.device] = None, dtype: DtypeLike = np.float32, affine: Optional[NdarrayOrTensor] = None, + affine_only: bool = False, ) -> None: self.rotate_params = rotate_params self.shear_params = shear_params @@ -1770,6 +1829,7 @@ def __init__( self.device = device self.dtype = dtype self.affine = affine + self.affine_only = affine_only def __call__( self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[torch.Tensor] = None @@ -1787,19 +1847,23 @@ def __call__( ValueError: When ``grid=None`` and ``spatial_size=None``. Incompatible values. 
""" - if grid is None: # create grid from spatial_size - if spatial_size is None: - raise ValueError("Incompatible values: grid=None and spatial_size=None.") - grid_ = create_grid(spatial_size, device=self.device, backend="torch", dtype=self.dtype) + if not self.affine_only: + if grid is None: # create grid from spatial_size + if spatial_size is None: + raise ValueError("Incompatible values: grid=None and spatial_size=None.") + grid_ = create_grid(spatial_size, device=self.device, backend="torch", dtype=self.dtype) + else: + grid_ = grid + _dtype = self.dtype or grid_.dtype + grid_: torch.Tensor = convert_to_tensor(grid_, dtype=_dtype, track_meta=get_track_meta()) # type: ignore + _device = grid_.device # type: ignore + spatial_dims = len(grid_.shape) - 1 else: - grid_ = grid - _dtype = self.dtype or grid_.dtype - grid_: torch.Tensor = convert_to_tensor(grid_, dtype=_dtype, track_meta=get_track_meta()) # type: ignore + _device = self.device + spatial_dims = len(spatial_size) _b = TransformBackends.TORCH - _device = grid_.device # type: ignore affine: NdarrayOrTensor if self.affine is None: - spatial_dims = len(grid_.shape) - 1 affine = torch.eye(spatial_dims + 1, device=_device) if self.rotate_params: affine = affine @ create_rotate(spatial_dims, self.rotate_params, device=_device, backend=_b) @@ -1811,6 +1875,8 @@ def __call__( affine = affine @ create_scale(spatial_dims, self.scale_params, device=_device, backend=_b) else: affine = self.affine + if self.affine_only: + return None, affine affine = to_affine_nd(len(grid_) - 1, affine) affine = convert_to_tensor(affine, device=grid_.device, dtype=grid_.dtype, track_meta=False) # type: ignore @@ -1835,6 +1901,7 @@ def __init__( scale_range: RandRange = None, as_tensor_output: bool = True, device: Optional[torch.device] = None, + affine_only: bool = False, ) -> None: """ Args: @@ -1883,6 +1950,7 @@ def __init__( self.scale_params: Optional[List[float]] = None self.device = device + self.affine_only = affine_only self.affine: Optional[torch.Tensor] = torch.eye(4, dtype=torch.float64) def _get_rand_param(self, param_range, add_scalar: float = 0.0): @@ -1925,7 +1993,10 @@ def __call__( translate_params=self.translate_params, scale_params=self.scale_params, device=self.device, + affine_only=self.affine_only, ) + if affine_grid.affine_only: + return affine_grid(spatial_size, grid)[1] _grid: torch.Tensor _grid, self.affine = affine_grid(spatial_size, grid) return _grid @@ -2104,7 +2175,7 @@ def __call__( # type: ignore return out_val -class Affine(InvertibleTransform): +class Affine(InvertibleTransform, LazyTransform): """ Transform ``img`` given the affine parameters. A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/0.6.0/modules/transforms_demo_2d.ipynb. 
@@ -2234,6 +2305,10 @@ def __call__( sp_size = fall_back_tuple(self.spatial_size if spatial_size is None else spatial_size, img_size) _mode = mode or self.mode _padding_mode = padding_mode or self.padding_mode + if not self.eager_mode: + self.affine_grid.affine_only = True + _, affine = self.affine_grid(spatial_size=sp_size) + return self.lazy_call(img, affine, sp_size, _mode, _padding_mode) grid, affine = self.affine_grid(spatial_size=sp_size) out = self.resampler(img, grid=grid, mode=_mode, padding_mode=_padding_mode) if not isinstance(out, MetaTensor): @@ -2259,6 +2334,17 @@ def update_meta(self, img, mat, img_size, sp_size): affine = convert_data_type(img.affine, torch.Tensor)[0] img.affine = Affine.compute_w_affine(affine, mat, img_size, sp_size) + def lazy_call(self, img: torch.Tensor, affine, output_size, mode, padding_mode) -> torch.Tensor: + if get_track_meta() and isinstance(img, MetaTensor): + img.evaluated = False + orig_size = img.spatial_shape + self.update_meta(img, affine, orig_size, output_size) + self.push_transform( + img, orig_size=orig_size, extra_info={"affine": affine, "mode": mode, "padding_mode": padding_mode} + ) + img.spatial_shape = output_size + return img + def inverse(self, data: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(data) orig_size = transform[TraceKeys.ORIG_SIZE] @@ -2280,7 +2366,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return out # type: ignore -class RandAffine(RandomizableTransform, InvertibleTransform): +class RandAffine(RandomizableTransform, InvertibleTransform, LazyTransform): """ Random affine transform. A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/0.6.0/modules/transforms_demo_2d.ipynb. @@ -2401,6 +2487,8 @@ def get_identity_grid(self, spatial_size: Sequence[int]): Args: spatial_size: non-dynamic spatial size """ + if not self.eager_mode: + return None ndim = len(spatial_size) if spatial_size != fall_back_tuple(spatial_size, [1] * ndim) or spatial_size != fall_back_tuple( spatial_size, [2] * ndim @@ -2461,6 +2549,13 @@ def __call__( _mode = mode or self.mode _padding_mode = padding_mode or self.padding_mode img = convert_to_tensor(img, track_meta=get_track_meta()) + if not self.eager_mode: + if self._do_transform: + self.rand_affine_grid.affine_only = True + affine = self.rand_affine_grid(sp_size, grid=grid, randomize=randomize) + else: + affine = convert_to_dst_type(torch.eye(len(sp_size) + 1), img)[0] + return self.lazy_call(img, affine, sp_size, _mode, _padding_mode, do_resampling) if not do_resampling: out: torch.Tensor = convert_data_type(img, dtype=torch.float32, device=self.resampler.device)[0] else: @@ -2489,6 +2584,24 @@ def update_meta(self, img, mat, img_size, sp_size): affine = convert_data_type(img.affine, torch.Tensor)[0] img.affine = Affine.compute_w_affine(affine, mat, img_size, sp_size) + def lazy_call(self, img: torch.Tensor, affine, output_size, mode, padding_mode, do_resampling): + if get_track_meta() and isinstance(img, MetaTensor): + img.evaluated = False + orig_size = img.spatial_shape + self.update_meta(img, affine, orig_size, output_size) + self.push_transform( + img, + orig_size=orig_size, + extra_info={ + "affine": affine, + "mode": mode, + "padding_mode": padding_mode, + "do_resampling": do_resampling, + }, + ) + img.spatial_shape = output_size + return img + def inverse(self, data: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(data) # if transform was not performed nothing to do. 
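
Note that when `_do_transform` is False, the lazy branch of `RandAffine` still records an identity matrix: keeping the pending stack dense means a later evaluation step can fold everything with one matrix product and resample once. A sketch of that folding (illustrative only, not the `MetaTensor.evaluate` implementation):

    import numpy as np


    def fold_pending(pending, ndim=3):
        # compose a stack of homogeneous matrices into a single transform
        out = np.eye(ndim + 1)
        for mat in pending:
            out = out @ mat  # identity entries recorded for skipped transforms are no-ops
        return out


    pending = [np.eye(4), np.diag([2.0, 2.0, 2.0, 1.0])]  # "not applied" + a scaling
    combined = fold_pending(pending)
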
diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 22b5644dc5..d0a5d48666 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -561,7 +561,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class Resized(MapTransform, InvertibleTransform): +class Resized(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Resize`. @@ -604,6 +604,10 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.resizer = Resize(spatial_size=spatial_size, size_mode=size_mode) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.resizer.set_eager_mode(value) + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): @@ -718,7 +722,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class RandAffined(RandomizableTransform, MapTransform, InvertibleTransform): +class RandAffined(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.RandAffine`. """ @@ -812,6 +816,10 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.rand_affine.set_eager_mode(value) + def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None ) -> "RandAffined": @@ -830,7 +838,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N # all the keys share the same random Affine factor self.rand_affine.randomize() - spatial_size = d[first_key].shape[1:] # type: ignore + item = d[first_key] + spatial_size = item.spatial_shape if isinstance(item, MetaTensor) and not item.evaluated else item.shape[1:] # type: ignore sp_size = fall_back_tuple(self.rand_affine.spatial_size, spatial_size) # change image size or do random transform do_resampling = self._do_transform or (sp_size != ensure_tuple(spatial_size)) @@ -839,7 +848,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N if do_resampling: # need to prepare grid grid = self.rand_affine.get_identity_grid(sp_size) if self._do_transform: # add some random factors - grid = self.rand_affine.rand_affine_grid(grid=grid) + self.rand_affine.rand_affine_grid.affine_only = not self.eager_mode + grid = self.rand_affine.rand_affine_grid(sp_size, grid=grid) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): # do the transform @@ -1289,7 +1299,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class Rotated(MapTransform, InvertibleTransform): +class Rotated(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rotate`. 
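
The dictionary wrappers own their array-transform instance, so propagating the flag is a one-liner per class: override `set_eager_mode` and forward to the member. The pattern reduced to toy classes (not the MONAI implementations):

    class LazyBase:
        eager_mode = True

        def set_eager_mode(self, value: bool):
            self.eager_mode = value


    class ArrayXform(LazyBase):
        pass


    class DictWrapper(LazyBase):
        def __init__(self):
            self.inner = ArrayXform()

        def set_eager_mode(self, value: bool):
            super().set_eager_mode(value)
            self.inner.set_eager_mode(value)  # keep the wrapped transform in sync


    w = DictWrapper()
    w.set_eager_mode(False)
    assert w.eager_mode is False and w.inner.eager_mode is False
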
@@ -1338,6 +1348,10 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.rotator.set_eager_mode(value) + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for key, mode, padding_mode, align_corners, dtype in self.key_iterator( @@ -1355,7 +1369,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform): +class RandRotated(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based version :py:class:`monai.transforms.RandRotate` Randomly rotates the input arrays. @@ -1414,6 +1428,10 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.rand_rotate.set_eager_mode(value) + def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None ) -> "RandRotated": @@ -1456,7 +1474,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class Zoomd(MapTransform, InvertibleTransform): +class Zoomd(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Zoom`. @@ -1506,6 +1524,10 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, **kwargs) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.zoomer.set_eager_mode(value) + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for key, mode, padding_mode, align_corners in self.key_iterator( @@ -1521,7 +1543,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform): +class RandZoomd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform): """ Dict-based version :py:class:`monai.transforms.RandZoom`. 
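
Two propagation patterns coexist in this series: wrappers such as `Zoomd` forward the flag once to a long-lived member, while random transforms such as `RandZoom` construct a fresh inner transform on every call and must re-apply the flag each time. The second pattern, condensed to toy classes:

    class ZoomLike:
        eager_mode = True

        def set_eager_mode(self, value: bool):
            self.eager_mode = value

        def __call__(self, img):
            return ("pending-zoom", img) if not self.eager_mode else ("zoomed", img)


    class RandZoomLike(ZoomLike):
        def __call__(self, img):
            inner = ZoomLike()  # built fresh, e.g. with newly sampled zoom factors
            inner.set_eager_mode(self.eager_mode)  # hand the flag to the per-call instance
            return inner(img)


    assert RandZoomLike()(0)[0] == "zoomed"
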
@@ -1582,6 +1604,10 @@ def __init__( self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.rand_zoom.set_eager_mode(value) + def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None ) -> "RandZoomd": From e71476b04149fb30c2d450e51ac4e4a9a45ebe09 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 15 Aug 2022 11:06:18 +0100 Subject: [PATCH 06/10] review spatial transforms Signed-off-by: Wenqi Li --- monai/transforms/compose.py | 33 ++++++++-- monai/transforms/spatial/array.py | 85 +++++++++++++++++--------- monai/transforms/spatial/dictionary.py | 24 ++++++-- monai/transforms/transform.py | 6 +- 4 files changed, 106 insertions(+), 42 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 8af02d4f2d..0f54c05000 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -30,12 +30,14 @@ apply_transform, ) from monai.utils import MAX_SEED, ensure_tuple, get_seed -from monai.utils.enums import TraceKeys +from monai.utils.enums import GridSampleMode, GridSamplePadMode, TraceKeys __all__ = ["Compose", "OneOf"] -def eval_lazy_stack(data, upcoming, lazy_resample: bool = False): +def eval_lazy_stack( + data, upcoming, lazy_resample: bool = False, mode=GridSampleMode.BILINEAR, padding_mode=GridSamplePadMode.BORDER +): """ Given the upcoming transform ``upcoming``, if lazy_resample is True, go through the Metatensors and evaluate the lazy applied operations. The returned `data` will then be ready for the ``upcoming`` transform. @@ -44,7 +46,7 @@ def eval_lazy_stack(data, upcoming, lazy_resample: bool = False): return data # eager evaluation if isinstance(data, monai.data.MetaTensor): if lazy_resample and not isinstance(upcoming, LazyTransform): - data.evaluate("bilinear") + data.evaluate(mode=mode, padding_mode=padding_mode) return data if isinstance(data, Mapping): if isinstance(upcoming, MapTransform): @@ -135,6 +137,15 @@ class Compose(Randomizable, InvertibleTransform): for NumPy array and PyTorch Tensor, log the data shape and value range, for other metadata, log the values directly. default to `False`. lazy_resample: whether to compute consecutive spatial transforms resampling lazily. Default to False. + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode when ``lazy_resample=True``. Defaults to ``"bilinear"``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + When `USE_COMPILED` is `True`, this argument uses + ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations. + See also: https://docs.monai.io/en/stable/networks.html#grid-pull + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values when ``lazy_resample=True``. Defaults to ``"border"``. 
+ See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html """ @@ -145,6 +156,8 @@ def __init__( unpack_items: bool = False, log_stats: bool = False, lazy_resample: bool = False, + mode=GridSampleMode.BILINEAR, + padding_mode=GridSamplePadMode.BORDER, ) -> None: if transforms is None: transforms = [] @@ -153,6 +166,8 @@ def __init__( self.unpack_items = unpack_items self.log_stats = log_stats self.lazy_resample = lazy_resample + self.mode = mode + self.padding_mode = padding_mode self.set_random_state(seed=get_seed()) if self.lazy_resample: @@ -202,9 +217,17 @@ def __len__(self): def __call__(self, input_): for _transform in self.transforms: - input_ = eval_lazy_stack(input_, upcoming=_transform, lazy_resample=self.lazy_resample) + input_ = eval_lazy_stack( + input_, + upcoming=_transform, + lazy_resample=self.lazy_resample, + mode=self.mode, + padding_mode=self.padding_mode, + ) input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items, self.log_stats) - input_ = eval_lazy_stack(input_, upcoming=None, lazy_resample=self.lazy_resample) + input_ = eval_lazy_stack( + input_, upcoming=None, lazy_resample=self.lazy_resample, mode=self.mode, padding_mode=self.padding_mode + ) return input_ def inverse(self, data): diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 567ffea1ec..e297867bd3 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -934,7 +934,7 @@ def lazy_call(self, img: torch.Tensor, sp_size, *args) -> torch.Tensor: if get_track_meta() and isinstance(img, MetaTensor): img.evaluated = False img = self._post_process(img, img.spatial_shape, sp_size, *args) - img.spatial_shape = sp_size + img.spatial_shape = sp_size # type: ignore return img def update_meta(self, img, spatial_size, new_spatial_size): @@ -1085,11 +1085,11 @@ def _post_process( ) return img - def lazy_call(self, img, output_shape, transform_t, *args): + def lazy_call(self, img: torch.Tensor, output_shape, transform_t, *args) -> torch.Tensor: if get_track_meta() and isinstance(img, MetaTensor): img.evaluated = False img = self._post_process(img, img.spatial_shape, transform_t, *args) - img.spatial_shape = output_shape + img.spatial_shape = output_shape # type: ignore return img def update_meta(self, img, rotate_mat): @@ -1209,7 +1209,7 @@ def __call__( _align_corners = self.align_corners if align_corners is None else align_corners _padding_mode = padding_mode or self.padding_mode - if not self.eager_mode: + if not self.eager_mode and isinstance(img, MetaTensor): if self.keep_size: raise NotImplementedError("keep_size=True is currently not compatible with lazy evaluation") else: @@ -1247,7 +1247,7 @@ def __call__( ) return out - def lazy_call(self, img, zoom_size, mode, align_corners): + def lazy_call(self, img: torch.Tensor, zoom_size, mode, align_corners) -> torch.Tensor: if get_track_meta() and isinstance(img, MetaTensor): img.evaluated = False self.update_meta(img, img.spatial_shape, zoom_size) @@ -1261,7 +1261,7 @@ def lazy_call(self, img, zoom_size, mode, align_corners): "padcrop": {}, }, ) - img.spatial_shape = zoom_size + img.spatial_shape = zoom_size # type: ignore return img def update_meta(self, img, spatial_size, new_spatial_size): @@ -1293,7 +1293,7 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: return out -class Rotate90(InvertibleTransform): +class Rotate90(InvertibleTransform, LazyTransform): """ Rotate an array by 90 degrees in the plane specified by 
`axes`. See `torch.rot90` for additional details: @@ -1311,7 +1311,7 @@ def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: Default: (0, 1), this is the first two axis in spatial dimensions. If axis is negative it counts from the last to the first axis. """ - self.k = k + self.k = (4 + (k % 4)) % 4 # 0, 1, 2, 3 spatial_axes_: Tuple[int, int] = ensure_tuple(spatial_axes) # type: ignore if len(spatial_axes_) != 2: raise ValueError("spatial_axes must be 2 int numbers to indicate the axes to rotate 90 degrees.") @@ -1324,7 +1324,9 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: """ img = convert_to_tensor(img, track_meta=get_track_meta()) axes = map_spatial_axes(img.ndim, self.spatial_axes) - ori_shape = img.shape[1:] + ori_shape = img.spatial_shape if isinstance(img, MetaTensor) and not img.evaluated else img.shape[1:] + if not self.eager_mode: + return self.lazy_call(img, axes, self.k) out: NdarrayOrTensor = torch.rot90(img, self.k, axes) out = convert_to_dst_type(out, img)[0] if get_track_meta(): @@ -1332,6 +1334,19 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: self.push_transform(out, extra_info={"axes": [d - 1 for d in axes], "k": self.k}) # compensate spatial dim return out + def lazy_call(self, img: torch.Tensor, axes, k) -> torch.Tensor: + if get_track_meta() and isinstance(img, MetaTensor): + img.evaluated = False + ori_shape = img.spatial_shape.cpu().tolist() + output_shape = img.spatial_shape.cpu().tolist() + if k in (1, 3): + a_0, a_1 = axes[0] - 1, axes[1] - 1 + output_shape[a_0], output_shape[a_1] = ori_shape[a_1], ori_shape[a_0] + self.update_meta(img, ori_shape, output_shape, axes, k) + self.push_transform(img, extra_info={"axes": [d - 1 for d in axes], "k": k}) + img.spatial_shape = output_shape + return img + def update_meta(self, img, spatial_size, new_spatial_size, axes, k): affine = convert_data_type(img.affine, torch.Tensor)[0] r, sp_r = len(affine) - 1, len(spatial_size) @@ -1362,7 +1377,7 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: return xform(data) -class RandRotate90(RandomizableTransform, InvertibleTransform): +class RandRotate90(RandomizableTransform, InvertibleTransform, LazyTransform): """ With probability `prob`, input arrays are rotated by 90 degrees in the plane specified by `spatial_axes`. @@ -1401,7 +1416,9 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: self.randomize() if self._do_transform: - out = Rotate90(self._rand_k, self.spatial_axes)(img) + xform = Rotate90(self._rand_k, self.spatial_axes) + xform.set_eager_mode(self.eager_mode) + out = xform(img) else: out = convert_to_tensor(img, track_meta=get_track_meta()) @@ -1774,7 +1791,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return Zoom(self._zoom).inverse_transform(data, xform_info[TraceKeys.EXTRA_INFO]) -class AffineGrid(Transform): +class AffineGrid(Transform, LazyTransform): """ Affine transforms on the coordinates. @@ -1800,7 +1817,6 @@ class AffineGrid(Transform): device: device on which the tensor will be allocated, if a new grid is generated. affine: If applied, ignore the params (`rotate_params`, etc.) and use the supplied matrix. Should be square with each side = num of image spatial dimensions + 1. - affine_only: whether to return an affine matrix without computing the actual grid. Defaults to False. .. deprecated:: 0.6.0 ``as_tensor_output`` is deprecated. 
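
For `Rotate90` the lazy shape update is purely combinatorial: normalize k to {0, 1, 2, 3}, then swap the two rotated axes' extents only when k is odd. Extracted as a standalone helper (mirrors `__init__` and `lazy_call` above):

    def rot90_spatial_shape(spatial_shape, k: int, axes=(0, 1)):
        k = (4 + (k % 4)) % 4  # normalize, so negative k also maps to 0, 1, 2, 3
        out = list(spatial_shape)
        if k in (1, 3):  # odd quarter-turns swap the extents of the rotated plane
            a0, a1 = axes
            out[a0], out[a1] = out[a1], out[a0]
        return out


    print(rot90_spatial_shape((10, 16, 17), k=-1))  # [16, 10, 17]
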
@@ -1820,7 +1836,6 @@ def __init__( device: Optional[torch.device] = None, dtype: DtypeLike = np.float32, affine: Optional[NdarrayOrTensor] = None, - affine_only: bool = False, ) -> None: self.rotate_params = rotate_params self.shear_params = shear_params @@ -1829,11 +1844,10 @@ def __init__( self.device = device self.dtype = dtype self.affine = affine - self.affine_only = affine_only def __call__( self, spatial_size: Optional[Sequence[int]] = None, grid: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[Optional[torch.Tensor], torch.Tensor]: """ The grid can be initialized with a `spatial_size` parameter, or provided directly as `grid`. Therefore, either `spatial_size` or `grid` must be provided. @@ -1847,7 +1861,7 @@ def __call__( ValueError: When ``grid=None`` and ``spatial_size=None``. Incompatible values. """ - if not self.affine_only: + if self.eager_mode: if grid is None: # create grid from spatial_size if spatial_size is None: raise ValueError("Incompatible values: grid=None and spatial_size=None.") @@ -1860,7 +1874,7 @@ def __call__( spatial_dims = len(grid_.shape) - 1 else: _device = self.device - spatial_dims = len(spatial_size) + spatial_dims = len(spatial_size) # type: ignore _b = TransformBackends.TORCH affine: NdarrayOrTensor if self.affine is None: @@ -1875,8 +1889,8 @@ def __call__( affine = affine @ create_scale(spatial_dims, self.scale_params, device=_device, backend=_b) else: affine = self.affine - if self.affine_only: - return None, affine + if not self.eager_mode: + return None, affine # type: ignore affine = to_affine_nd(len(grid_) - 1, affine) affine = convert_to_tensor(affine, device=grid_.device, dtype=grid_.dtype, track_meta=False) # type: ignore @@ -1884,7 +1898,7 @@ def __call__( return grid_, affine # type: ignore -class RandAffineGrid(Randomizable, Transform): +class RandAffineGrid(Randomizable, Transform, LazyTransform): """ Generate randomised affine grid. @@ -1901,7 +1915,7 @@ def __init__( scale_range: RandRange = None, as_tensor_output: bool = True, device: Optional[torch.device] = None, - affine_only: bool = False, + dtype: DtypeLike = np.float32, ) -> None: """ Args: @@ -1928,6 +1942,8 @@ def __init__( the scale factor to translate for every spatial dims. A value of 1.0 is added to the result. This allows 0 to correspond to no change (i.e., a scaling of 1.0). device: device to store the output grid data. + dtype: data type for the grid computation. Defaults to ``np.float32``. + If ``None``, use the data type of input data (if `grid` is provided). 
See also: - :py:meth:`monai.transforms.utils.create_rotate` @@ -1950,7 +1966,7 @@ def __init__( self.scale_params: Optional[List[float]] = None self.device = device - self.affine_only = affine_only + self.dtype = dtype self.affine: Optional[torch.Tensor] = torch.eye(4, dtype=torch.float64) def _get_rand_param(self, param_range, add_scalar: float = 0.0): @@ -1993,10 +2009,11 @@ def __call__( translate_params=self.translate_params, scale_params=self.scale_params, device=self.device, - affine_only=self.affine_only, + dtype=self.dtype, ) - if affine_grid.affine_only: - return affine_grid(spatial_size, grid)[1] + affine_grid.set_eager_mode(self.eager_mode) + if not self.eager_mode: # return the affine only, don't construct the grid + return affine_grid(spatial_size, grid)[1] # type: ignore _grid: torch.Tensor _grid, self.affine = affine_grid(spatial_size, grid) return _grid @@ -2275,6 +2292,10 @@ def __init__( self.mode: str = look_up_option(mode, GridSampleMode) self.padding_mode: str = look_up_option(padding_mode, GridSamplePadMode) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.affine_grid.set_eager_mode(value) + def __call__( self, img: torch.Tensor, @@ -2306,7 +2327,6 @@ def __call__( _mode = mode or self.mode _padding_mode = padding_mode or self.padding_mode if not self.eager_mode: - self.affine_grid.affine_only = True _, affine = self.affine_grid(spatial_size=sp_size) return self.lazy_call(img, affine, sp_size, _mode, _padding_mode) grid, affine = self.affine_grid(spatial_size=sp_size) @@ -2458,10 +2478,16 @@ def __init__( self.mode: str = GridSampleMode(mode) self.padding_mode: str = GridSamplePadMode(padding_mode) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.rand_affine_grid.set_eager_mode(value) + def _init_identity_cache(self): """ Create cache of the identity grid if cache_grid=True and spatial_size is known. 
""" + if not self.eager_mode: + return None if self.spatial_size is None: if self.cache_grid: warnings.warn( @@ -2551,10 +2577,9 @@ def __call__( img = convert_to_tensor(img, track_meta=get_track_meta()) if not self.eager_mode: if self._do_transform: - self.rand_affine_grid.affine_only = True affine = self.rand_affine_grid(sp_size, grid=grid, randomize=randomize) else: - affine = convert_to_dst_type(torch.eye(len(sp_size) + 1), img)[0] + affine = convert_to_dst_type(torch.eye(len(sp_size) + 1), img, dtype=self.rand_affine_grid.dtype)[0] return self.lazy_call(img, affine, sp_size, _mode, _padding_mode, do_resampling) if not do_resampling: out: torch.Tensor = convert_data_type(img, dtype=torch.float32, device=self.resampler.device)[0] @@ -2584,7 +2609,7 @@ def update_meta(self, img, mat, img_size, sp_size): affine = convert_data_type(img.affine, torch.Tensor)[0] img.affine = Affine.compute_w_affine(affine, mat, img_size, sp_size) - def lazy_call(self, img: torch.Tensor, affine, output_size, mode, padding_mode, do_resampling): + def lazy_call(self, img: torch.Tensor, affine, output_size, mode, padding_mode, do_resampling) -> torch.Tensor: if get_track_meta() and isinstance(img, MetaTensor): img.evaluated = False orig_size = img.spatial_shape diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index d0a5d48666..46f165f8f3 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -462,7 +462,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class Rotate90d(MapTransform, InvertibleTransform): +class Rotate90d(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Rotate90`. """ @@ -482,6 +482,10 @@ def __init__( super().__init__(keys, allow_missing_keys) self.rotator = Rotate90(k, spatial_axes) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.rotator.set_eager_mode(value) + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): @@ -495,7 +499,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class RandRotate90d(RandomizableTransform, MapTransform, InvertibleTransform): +class RandRotate90d(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based version :py:class:`monai.transforms.RandRotate90`. With probability `prob`, input arrays are rotated by 90 degrees @@ -543,6 +547,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, t # FIXME: here we didn't use array version `RandRotate90` transform as others, because we need # to be compatible with the random status of some previous integration tests rotator = Rotate90(self._rand_k, self.spatial_axes) + rotator.set_eager_mode(value) for key in self.key_iterator(d): d[key] = rotator(d[key]) if self._do_transform else convert_to_tensor(d[key], track_meta=get_track_meta()) if get_track_meta(): @@ -621,7 +626,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class Affined(MapTransform, InvertibleTransform): +class Affined(MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.Affine`. 
""" @@ -709,6 +714,10 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.affine.set_eager_mode(value) + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): @@ -829,7 +838,7 @@ def set_random_state( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - first_key: Union[Hashable, List] = self.first_key(d) + first_key: Hashable = self.first_key(d) if first_key == []: out: Dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta()) return out @@ -848,7 +857,6 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N if do_resampling: # need to prepare grid grid = self.rand_affine.get_identity_grid(sp_size) if self._do_transform: # add some random factors - self.rand_affine.rand_affine_grid.affine_only = not self.eager_mode grid = self.rand_affine.rand_affine_grid(sp_size, grid=grid) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): @@ -1241,7 +1249,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch return d -class RandAxisFlipd(RandomizableTransform, MapTransform, InvertibleTransform): +class RandAxisFlipd(RandomizableTransform, MapTransform, InvertibleTransform, LazyTransform): """ Dictionary-based version :py:class:`monai.transforms.RandAxisFlip`. @@ -1262,6 +1270,10 @@ def __init__(self, keys: KeysCollection, prob: float = 0.1, allow_missing_keys: RandomizableTransform.__init__(self, prob) self.flipper = RandAxisFlip(prob=1.0) + def set_eager_mode(self, value): + super().set_eager_mode(value) + self.flipper.set_eager_mode(value) + def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None ) -> "RandAxisFlipd": diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index d61f7aa727..b0909030b0 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -140,7 +140,11 @@ class LazyTransform: eager_mode = True - def set_eager_mode(self, value): + def set_eager_mode(self, value: bool): + """ + when eager_mode is True, the transform should return the transformed array with up-to-date metadata. + When it's False, the transform may return updated metadata and not running the actual data array transform. 
+ """ self.eager_mode = value From 0bf0efa17275e718a5962a7e74614a7ee08433c8 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 15 Aug 2022 11:12:38 +0100 Subject: [PATCH 07/10] autofix Signed-off-by: Wenqi Li --- monai/transforms/spatial/dictionary.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 46f165f8f3..a0bc702456 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -547,7 +547,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, t # FIXME: here we didn't use array version `RandRotate90` transform as others, because we need # to be compatible with the random status of some previous integration tests rotator = Rotate90(self._rand_k, self.spatial_axes) - rotator.set_eager_mode(value) + rotator.set_eager_mode(self.eager_mode) for key in self.key_iterator(d): d[key] = rotator(d[key]) if self._do_transform else convert_to_tensor(d[key], track_meta=get_track_meta()) if get_track_meta(): @@ -838,7 +838,7 @@ def set_random_state( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - first_key: Hashable = self.first_key(d) + first_key = self.first_key(d) if first_key == []: out: Dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta()) return out From aa1bcdd9d5530932bfb53c1dcd81b709165af0b5 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 15 Aug 2022 12:10:32 +0100 Subject: [PATCH 08/10] fixes tests Signed-off-by: Wenqi Li --- tests/test_meta_tensor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 2f873c2d73..891bd2258a 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -74,8 +74,8 @@ def check_meta(self, a: MetaTensor, b: MetaTensor) -> None: aff_a = meta_a.get("affine", None) aff_b = meta_b.get("affine", None) assert_allclose(aff_a, aff_b) - meta_a = {k: v for k, v in meta_a.items() if k != "affine"} - meta_b = {k: v for k, v in meta_b.items() if k != "affine"} + meta_a = {k: v for k, v in meta_a.items() if k not in ("affine", "original_channel_dim")} + meta_b = {k: v for k, v in meta_b.items() if k not in ("affine", "original_channel_dim")} self.assertEqual(meta_a, meta_b) def check( @@ -122,7 +122,7 @@ def test_as_tensor(self, device, dtype): def test_as_dict(self): m, _ = self.get_im() m_dict = m.as_dict("im") - im, meta = m_dict["im"], m_dict[PostFix.meta("im")] + im, meta = m_dict["im"], deepcopy(m_dict[PostFix.meta("im")]) affine = meta.pop("affine") m2 = MetaTensor(im, affine, meta) self.check(m2, m, check_ids=False) From 7c6f71420620b31d39a53332f1ead876d6ba209b Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 15 Aug 2022 13:29:14 +0100 Subject: [PATCH 09/10] fixes tests Signed-off-by: Wenqi Li --- monai/data/meta_tensor.py | 4 +++- tests/test_ensure_channel_first.py | 4 +++- tests/test_meta_tensor.py | 4 ++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 96439dd937..f11d629dc3 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -150,8 +150,10 @@ def __init__( if MetaKeys.SPACE not in self.meta: self.meta[MetaKeys.SPACE] = SpaceKeys.RAS # defaulting to the right-anterior-superior space + if MetaKeys.EVALUATED not in self.meta: + self.meta[MetaKeys.EVALUATED] = True if MetaKeys.ORIGINAL_CHANNEL_DIM not in 
self.meta:
-            self.meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = 0  # defaulting to channel first
+            self.meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = "no_channel"  # defaulting to no channel dimension
 
     @property
     def evaluated(self) -> bool:
diff --git a/tests/test_ensure_channel_first.py b/tests/test_ensure_channel_first.py
index fdf776d81a..dae1e884e1 100644
--- a/tests/test_ensure_channel_first.py
+++ b/tests/test_ensure_channel_first.py
@@ -81,7 +81,9 @@ def test_check(self):
         with self.assertRaises(ValueError):  # not MetaTensor
             EnsureChannelFirst(add_channel_default=False)(im)
         with self.assertRaises(ValueError):  # no meta
-            EnsureChannelFirst(add_channel_default=False)(MetaTensor(im))
+            test_case = MetaTensor(im)
+            test_case.meta.pop("original_channel_dim")
+            EnsureChannelFirst(add_channel_default=False)(test_case)
         with self.assertRaises(ValueError):  # no meta channel
             EnsureChannelFirst(add_channel_default=False)(im_nodim)
 
diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py
index 891bd2258a..64d604b65b 100644
--- a/tests/test_meta_tensor.py
+++ b/tests/test_meta_tensor.py
@@ -74,8 +74,8 @@ def check_meta(self, a: MetaTensor, b: MetaTensor) -> None:
         aff_a = meta_a.get("affine", None)
         aff_b = meta_b.get("affine", None)
         assert_allclose(aff_a, aff_b)
-        meta_a = {k: v for k, v in meta_a.items() if k not in ("affine", "original_channel_dim")}
-        meta_b = {k: v for k, v in meta_b.items() if k not in ("affine", "original_channel_dim")}
+        meta_a = {k: v for k, v in meta_a.items() if k not in ("affine", "original_channel_dim", "evaluated")}
+        meta_b = {k: v for k, v in meta_b.items() if k not in ("affine", "original_channel_dim", "evaluated")}
         self.assertEqual(meta_a, meta_b)
 
     def check(

From 713fde934fb7f1040762052c6611dbd6a63af4c6 Mon Sep 17 00:00:00 2001
From: Wenqi Li 
Date: Fri, 19 Aug 2022 12:16:26 +0100
Subject: [PATCH 10/10] testing backends

Signed-off-by: Wenqi Li 
---
 monai/data/meta_tensor.py         | 2 +-
 monai/transforms/compose.py       | 7 ++++---
 monai/transforms/spatial/array.py | 2 +-
 3 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py
index f11d629dc3..ba89ed91c7 100644
--- a/monai/data/meta_tensor.py
+++ b/monai/data/meta_tensor.py
@@ -180,7 +180,7 @@ def evaluate(self, mode="bilinear", padding_mode="border"):
         resampler = monai.transforms.SpatialResample(mode=mode, padding_mode=padding_mode)
         dst_affine, self.affine = self.affine, self.meta[MetaKeys.ORIGINAL_AFFINE]
         with resampler.trace_transform(False):
-            output = resampler(self, dst_affine=dst_affine, spatial_size=self.spatial_shape)
+            output = resampler(self, dst_affine=dst_affine, spatial_size=self.spatial_shape, align_corners=True)
         self.array = output.array
         self.spatial_shape = self.array.shape[1:]
         self.affine = dst_affine
diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py
index 0f54c05000..ea51ba36fb 100644
--- a/monai/transforms/compose.py
+++ b/monai/transforms/compose.py
@@ -51,11 +51,12 @@ def eval_lazy_stack(
     if isinstance(data, Mapping):
         if isinstance(upcoming, MapTransform):
             return {
-                k: eval_lazy_stack(v, upcoming, lazy_resample) if k in upcoming.keys else v for k, v in data.items()
+                k: eval_lazy_stack(v, upcoming, lazy_resample, mode, padding_mode) if k in upcoming.keys else v
+                for k, v in data.items()
             }
-        return {k: eval_lazy_stack(v, upcoming, lazy_resample) for k, v in data.items()}
+        return {k: eval_lazy_stack(v, upcoming, lazy_resample, mode, padding_mode) for k, v in data.items()}
     if isinstance(data, (list, tuple)):
-        return [eval_lazy_stack(v, 
upcoming, lazy_resample) for v in data] + return [eval_lazy_stack(v, upcoming, lazy_resample, mode, padding_mode) for v in data] return data diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 8cff79ef4f..20ee1c0fb6 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -308,7 +308,7 @@ def __call__( for idx, d_dst in enumerate(spatial_size[:spatial_rank]): _t_r[idx, -1] = (max(d_dst, 2) - 1.0) / 2.0 xform = xform @ _t_r - if not USE_COMPILED: + if not USE_COMPILED and not isinstance(mode, int): _t_l = normalize_transform( in_spatial_size, xform.device, xform.dtype, align_corners=True # type: ignore )[0]
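
Putting the pieces together, the intended usage is a `Compose` chain that defers resampling until a non-lazy transform, or the end of the chain, is reached. A sketch against the API added in this series (`lazy_resample`, `mode`, and `padding_mode` are the new `Compose` arguments; the individual transform parameters are illustrative):

    import torch
    from monai.data import MetaTensor
    from monai.transforms import Compose, RandAffined, RandRotate90d, Zoomd

    data = {"img": MetaTensor(torch.rand(1, 32, 32, 32))}
    pipeline = Compose(
        [
            RandRotate90d(keys="img", prob=1.0),
            Zoomd(keys="img", zoom=1.2, keep_size=False),  # keep_size=True is not lazy-compatible yet
            RandAffined(keys="img", prob=1.0, rotate_range=0.3),
        ],
        lazy_resample=True,  # record affines/shapes instead of resampling at every step
        mode="bilinear",  # interpolation used when the pending operations are evaluated
        padding_mode="border",
    )
    out = pipeline(data)["img"]  # resampled once, when the chain is evaluated
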