diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index 6aab05dc94..74daf59ba8 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -213,6 +213,9 @@ def push_pending_operation(self, t: Any) -> None: def pop_pending_operation(self) -> Any: return self._pending_operations.pop() + def clear_pending_operations(self) -> Any: + self._pending_operations = MetaObj.get_default_applied_operations() + @property def is_batch(self) -> bool: """Return whether object is part of batch or not.""" diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 62b9b8d4d1..b227b5bd5e 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -227,6 +227,8 @@ from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict from .io.array import SUPPORTED_READERS, LoadImage, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict +from .lazy.functional import apply +from .lazy.utils import combine_transforms, resample from .meta_utility.dictionary import ( FromMetaTensord, FromMetaTensorD, diff --git a/monai/transforms/lazy/__init__.py b/monai/transforms/lazy/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/transforms/lazy/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py new file mode 100644 index 0000000000..0536bbe85b --- /dev/null +++ b/monai/transforms/lazy/functional.py @@ -0,0 +1,62 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +import torch + +from monai.data.meta_tensor import MetaTensor +from monai.data.utils import to_affine_nd +from monai.transforms.lazy.utils import ( + affine_from_pending, + combine_transforms, + is_compatible_apply_kwargs, + kwargs_from_pending, + resample, +) + +__all__ = ["apply"] + + +def apply(data: Union[torch.Tensor, MetaTensor], pending: Optional[list] = None): + """ + This method applies pending transforms to `data` tensors. + + Args: + data: A torch Tensor or a monai MetaTensor. + pending: pending transforms. This must be set if data is a Tensor, but is optional if data is a MetaTensor. + """ + if isinstance(data, MetaTensor) and pending is None: + pending = data.pending_operations + pending = [] if pending is None else pending + + if not pending: + return data + + cumulative_xform = affine_from_pending(pending[0]) + cur_kwargs = kwargs_from_pending(pending[0]) + + for p in pending[1:]: + new_kwargs = kwargs_from_pending(p) + if not is_compatible_apply_kwargs(cur_kwargs, new_kwargs): + # carry out an intermediate resample here due to incompatibility between arguments + data = resample(data, cumulative_xform, cur_kwargs) + next_matrix = affine_from_pending(p) + cumulative_xform = combine_transforms(cumulative_xform, next_matrix) + cur_kwargs.update(new_kwargs) + data = resample(data, cumulative_xform, cur_kwargs) + if isinstance(data, MetaTensor): + data.clear_pending_operations() + data.affine = data.affine @ to_affine_nd(3, cumulative_xform) + for p in pending: + data.push_applied_operation(p) + + return data, pending diff --git a/monai/transforms/lazy/utils.py b/monai/transforms/lazy/utils.py new file mode 100644 index 0000000000..4e37e78833 --- /dev/null +++ b/monai/transforms/lazy/utils.py @@ -0,0 +1,125 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import numpy as np +import torch + +import monai +from monai.config import NdarrayOrTensor +from monai.utils import LazyAttr, convert_to_tensor + +__all__ = ["resample", "combine_transforms"] + + +class Affine: + """A class to represent an affine transform matrix.""" + + __slots__ = ("data",) + + def __init__(self, data): + self.data = data + + @staticmethod + def is_affine_shaped(data): + """Check if the data is an affine matrix.""" + if isinstance(data, Affine): + return True + if isinstance(data, DisplacementField): + return False + if not hasattr(data, "shape") or len(data.shape) < 2: + return False + return data.shape[-1] in (3, 4) and data.shape[-2] in (3, 4) and data.shape[-1] == data.shape[-2] + + +class DisplacementField: + """A class to represent a dense displacement field.""" + + __slots__ = ("data",) + + def __init__(self, data): + self.data = data + + @staticmethod + def is_ddf_shaped(data): + """Check if the data is a DDF.""" + if isinstance(data, DisplacementField): + return True + if isinstance(data, Affine): + return False + if not hasattr(data, "shape") or len(data.shape) < 3: + return False + return not Affine.is_affine_shaped(data) + + +def combine_transforms(left: torch.Tensor, right: torch.Tensor) -> torch.Tensor: + """Given transforms A and B to be applied to x, return the combined transform (AB), so that A(B(x)) becomes AB(x)""" + if Affine.is_affine_shaped(left) and Affine.is_affine_shaped(right): # linear transforms + left = convert_to_tensor(left.data if isinstance(left, Affine) else left, wrap_sequence=True) + right = convert_to_tensor(right.data if isinstance(right, Affine) else right, wrap_sequence=True) + return torch.matmul(left, right) + if DisplacementField.is_ddf_shaped(left) and DisplacementField.is_ddf_shaped( + right + ): # adds DDFs, do we need metadata if metatensor input? + left = convert_to_tensor(left.data if isinstance(left, DisplacementField) else left, wrap_sequence=True) + right = convert_to_tensor(right.data if isinstance(right, DisplacementField) else right, wrap_sequence=True) + return left + right + raise NotImplementedError + + +def affine_from_pending(pending_item): + """Extract the affine matrix from a pending transform item.""" + if isinstance(pending_item, (torch.Tensor, np.ndarray)): + return pending_item + if isinstance(pending_item, dict): + return pending_item[LazyAttr.AFFINE] + return pending_item + + +def kwargs_from_pending(pending_item): + """Extract kwargs from a pending transform item.""" + if not isinstance(pending_item, dict): + return {} + ret = { + LazyAttr.INTERP_MODE: pending_item.get(LazyAttr.INTERP_MODE, None), # interpolation mode + LazyAttr.PADDING_MODE: pending_item.get(LazyAttr.PADDING_MODE, None), # padding mode + } + if LazyAttr.SHAPE in pending_item: + ret[LazyAttr.SHAPE] = pending_item[LazyAttr.SHAPE] + if LazyAttr.DTYPE in pending_item: + ret[LazyAttr.DTYPE] = pending_item[LazyAttr.DTYPE] + return ret + + +def is_compatible_apply_kwargs(kwargs_1, kwargs_2): + """Check if two sets of kwargs are compatible (to be combined in `apply`).""" + return True + + +def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: Optional[dict] = None): + """ + This is a minimal implementation of resample that always uses Affine. + """ + if not Affine.is_affine_shaped(matrix): + raise NotImplementedError("calling dense grid resample API not implemented") + kwargs = {} if kwargs is None else kwargs + init_kwargs = { + "spatial_size": kwargs.pop(LazyAttr.SHAPE, data.shape)[1:], + "dtype": kwargs.pop(LazyAttr.DTYPE, data.dtype), + } + call_kwargs = { + "mode": kwargs.pop(LazyAttr.INTERP_MODE, None), + "padding_mode": kwargs.pop(LazyAttr.PADDING_MODE, None), + } + resampler = monai.transforms.Affine(affine=matrix, image_only=True, **init_kwargs) + with resampler.trace_transform(False): # don't track this transform in `data` + return resampler(img=data, **call_kwargs) diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 4fd9bea557..b606cd8667 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -630,3 +630,4 @@ class LazyAttr(StrEnum): AFFINE = "lazy_affine" PADDING_MODE = "lazy_padding_mode" INTERP_MODE = "lazy_interpolation_mode" + DTYPE = "lazy_dtype" diff --git a/tests/test_apply.py b/tests/test_apply.py new file mode 100644 index 0000000000..afb29ad576 --- /dev/null +++ b/tests/test_apply.py @@ -0,0 +1,71 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch + +from monai.transforms.lazy.functional import apply +from monai.transforms.utils import create_rotate +from monai.utils import LazyAttr, convert_to_tensor +from tests.utils import get_arange_img + + +def single_2d_transform_cases(): + return [ + ( + torch.as_tensor(get_arange_img((32, 32))), + [{LazyAttr.AFFINE: create_rotate(2, np.pi / 4)}, {LazyAttr.AFFINE: create_rotate(2, -np.pi / 4)}], + (1, 32, 32), + ), + (torch.as_tensor(get_arange_img((32, 32))), [create_rotate(2, np.pi / 2)], (1, 32, 32)), + ( + torch.as_tensor(get_arange_img((16, 16))), + [{LazyAttr.AFFINE: create_rotate(2, np.pi / 2), LazyAttr.SHAPE: (1, 45, 45)}], + (1, 45, 45), + ), + ] + + +class TestApply(unittest.TestCase): + def _test_apply_impl(self, tensor, pending_transforms, expected_shape): + result = apply(tensor, pending_transforms) + self.assertListEqual(result[1], pending_transforms) + self.assertEqual(result[0].shape, expected_shape) + + def _test_apply_metatensor_impl(self, tensor, pending_transforms, expected_shape, pending_as_parameter): + tensor_ = convert_to_tensor(tensor, track_meta=True) + if pending_as_parameter: + result, transforms = apply(tensor_, pending_transforms) + else: + for p in pending_transforms: + tensor_.push_pending_operation(p) + result, transforms = apply(tensor_) + self.assertEqual(result.shape, expected_shape) + + SINGLE_TRANSFORM_CASES = single_2d_transform_cases() + + def test_apply_single_transform(self): + for case in self.SINGLE_TRANSFORM_CASES: + self._test_apply_impl(*case) + + def test_apply_single_transform_metatensor(self): + for case in self.SINGLE_TRANSFORM_CASES: + self._test_apply_metatensor_impl(*case, False) + + def test_apply_single_transform_metatensor_override(self): + for case in self.SINGLE_TRANSFORM_CASES: + self._test_apply_metatensor_impl(*case, True) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_resample.py b/tests/test_resample.py new file mode 100644 index 0000000000..0136552334 --- /dev/null +++ b/tests/test_resample.py @@ -0,0 +1,40 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms.lazy.functional import resample +from monai.utils import convert_to_tensor +from tests.utils import assert_allclose, get_arange_img + + +def rotate_90_2d(): + t = torch.eye(3) + t[:, 0] = torch.FloatTensor([0, -1, 0]) + t[:, 1] = torch.FloatTensor([1, 0, 0]) + return t + + +RESAMPLE_FUNCTION_CASES = [(get_arange_img((3, 3)), rotate_90_2d(), [[2, 5, 8], [1, 4, 7], [0, 3, 6]])] + + +class TestResampleFunction(unittest.TestCase): + @parameterized.expand(RESAMPLE_FUNCTION_CASES) + def test_resample_function_impl(self, img, matrix, expected): + out = resample(convert_to_tensor(img), matrix) + assert_allclose(out[0], expected, type_test=False) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils.py b/tests/utils.py index afe08e0bfa..e963e03668 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -348,6 +348,16 @@ def make_rand_affine(ndim: int = 3, random_state: Optional[np.random.RandomState return af +def get_arange_img(size, dtype=np.float32, offset=0): + """ + Returns an image as a numpy array (complete with channel as dim 0) + with contents that iterate like an arange. + """ + n_elem = np.prod(size) + img = np.arange(offset, offset + n_elem, dtype=dtype).reshape(size) + return np.expand_dims(img, 0) + + class DistTestCase(unittest.TestCase): """ testcase without _outcome, so that it's picklable.