Skip to content

Replacement Apply and Resample #5436

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Nov 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions monai/data/meta_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,9 @@ def push_pending_operation(self, t: Any) -> None:
def pop_pending_operation(self) -> Any:
return self._pending_operations.pop()

def clear_pending_operations(self) -> Any:
self._pending_operations = MetaObj.get_default_applied_operations()

@property
def is_batch(self) -> bool:
"""Return whether object is part of batch or not."""
Expand Down
2 changes: 2 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@
from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict
from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
from .lazy.functional import apply
from .lazy.utils import combine_transforms, resample
from .meta_utility.dictionary import (
FromMetaTensord,
FromMetaTensorD,
Expand Down
10 changes: 10 additions & 0 deletions monai/transforms/lazy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
62 changes: 62 additions & 0 deletions monai/transforms/lazy/functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union

import torch

from monai.data.meta_tensor import MetaTensor
from monai.data.utils import to_affine_nd
from monai.transforms.lazy.utils import (
affine_from_pending,
combine_transforms,
is_compatible_apply_kwargs,
kwargs_from_pending,
resample,
)

__all__ = ["apply"]


def apply(data: Union[torch.Tensor, MetaTensor], pending: Optional[list] = None):
"""
This method applies pending transforms to `data` tensors.

Args:
data: A torch Tensor or a monai MetaTensor.
pending: pending transforms. This must be set if data is a Tensor, but is optional if data is a MetaTensor.
"""
if isinstance(data, MetaTensor) and pending is None:
pending = data.pending_operations
pending = [] if pending is None else pending

if not pending:
return data

cumulative_xform = affine_from_pending(pending[0])
cur_kwargs = kwargs_from_pending(pending[0])

for p in pending[1:]:
new_kwargs = kwargs_from_pending(p)
if not is_compatible_apply_kwargs(cur_kwargs, new_kwargs):
# carry out an intermediate resample here due to incompatibility between arguments
data = resample(data, cumulative_xform, cur_kwargs)
next_matrix = affine_from_pending(p)
cumulative_xform = combine_transforms(cumulative_xform, next_matrix)
cur_kwargs.update(new_kwargs)
data = resample(data, cumulative_xform, cur_kwargs)
if isinstance(data, MetaTensor):
data.clear_pending_operations()
data.affine = data.affine @ to_affine_nd(3, cumulative_xform)
for p in pending:
data.push_applied_operation(p)

return data, pending
125 changes: 125 additions & 0 deletions monai/transforms/lazy/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import numpy as np
import torch

import monai
from monai.config import NdarrayOrTensor
from monai.utils import LazyAttr, convert_to_tensor

__all__ = ["resample", "combine_transforms"]


class Affine:
"""A class to represent an affine transform matrix."""

__slots__ = ("data",)

def __init__(self, data):
self.data = data

@staticmethod
def is_affine_shaped(data):
"""Check if the data is an affine matrix."""
if isinstance(data, Affine):
return True
if isinstance(data, DisplacementField):
return False
if not hasattr(data, "shape") or len(data.shape) < 2:
return False
return data.shape[-1] in (3, 4) and data.shape[-2] in (3, 4) and data.shape[-1] == data.shape[-2]


class DisplacementField:
"""A class to represent a dense displacement field."""

__slots__ = ("data",)

def __init__(self, data):
self.data = data

@staticmethod
def is_ddf_shaped(data):
"""Check if the data is a DDF."""
if isinstance(data, DisplacementField):
return True
if isinstance(data, Affine):
return False
if not hasattr(data, "shape") or len(data.shape) < 3:
return False
return not Affine.is_affine_shaped(data)


def combine_transforms(left: torch.Tensor, right: torch.Tensor) -> torch.Tensor:
"""Given transforms A and B to be applied to x, return the combined transform (AB), so that A(B(x)) becomes AB(x)"""
if Affine.is_affine_shaped(left) and Affine.is_affine_shaped(right): # linear transforms
left = convert_to_tensor(left.data if isinstance(left, Affine) else left, wrap_sequence=True)
right = convert_to_tensor(right.data if isinstance(right, Affine) else right, wrap_sequence=True)
return torch.matmul(left, right)
if DisplacementField.is_ddf_shaped(left) and DisplacementField.is_ddf_shaped(
right
): # adds DDFs, do we need metadata if metatensor input?
left = convert_to_tensor(left.data if isinstance(left, DisplacementField) else left, wrap_sequence=True)
right = convert_to_tensor(right.data if isinstance(right, DisplacementField) else right, wrap_sequence=True)
return left + right
raise NotImplementedError


def affine_from_pending(pending_item):
"""Extract the affine matrix from a pending transform item."""
if isinstance(pending_item, (torch.Tensor, np.ndarray)):
return pending_item
if isinstance(pending_item, dict):
return pending_item[LazyAttr.AFFINE]
return pending_item


def kwargs_from_pending(pending_item):
"""Extract kwargs from a pending transform item."""
if not isinstance(pending_item, dict):
return {}
ret = {
LazyAttr.INTERP_MODE: pending_item.get(LazyAttr.INTERP_MODE, None), # interpolation mode
LazyAttr.PADDING_MODE: pending_item.get(LazyAttr.PADDING_MODE, None), # padding mode
}
if LazyAttr.SHAPE in pending_item:
ret[LazyAttr.SHAPE] = pending_item[LazyAttr.SHAPE]
if LazyAttr.DTYPE in pending_item:
ret[LazyAttr.DTYPE] = pending_item[LazyAttr.DTYPE]
return ret


def is_compatible_apply_kwargs(kwargs_1, kwargs_2):
"""Check if two sets of kwargs are compatible (to be combined in `apply`)."""
return True


def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: Optional[dict] = None):
"""
This is a minimal implementation of resample that always uses Affine.
"""
if not Affine.is_affine_shaped(matrix):
raise NotImplementedError("calling dense grid resample API not implemented")
kwargs = {} if kwargs is None else kwargs
init_kwargs = {
"spatial_size": kwargs.pop(LazyAttr.SHAPE, data.shape)[1:],
"dtype": kwargs.pop(LazyAttr.DTYPE, data.dtype),
}
call_kwargs = {
"mode": kwargs.pop(LazyAttr.INTERP_MODE, None),
"padding_mode": kwargs.pop(LazyAttr.PADDING_MODE, None),
}
resampler = monai.transforms.Affine(affine=matrix, image_only=True, **init_kwargs)
with resampler.trace_transform(False): # don't track this transform in `data`
return resampler(img=data, **call_kwargs)
1 change: 1 addition & 0 deletions monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,3 +630,4 @@ class LazyAttr(StrEnum):
AFFINE = "lazy_affine"
PADDING_MODE = "lazy_padding_mode"
INTERP_MODE = "lazy_interpolation_mode"
DTYPE = "lazy_dtype"
71 changes: 71 additions & 0 deletions tests/test_apply.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch

from monai.transforms.lazy.functional import apply
from monai.transforms.utils import create_rotate
from monai.utils import LazyAttr, convert_to_tensor
from tests.utils import get_arange_img


def single_2d_transform_cases():
return [
(
torch.as_tensor(get_arange_img((32, 32))),
[{LazyAttr.AFFINE: create_rotate(2, np.pi / 4)}, {LazyAttr.AFFINE: create_rotate(2, -np.pi / 4)}],
(1, 32, 32),
),
(torch.as_tensor(get_arange_img((32, 32))), [create_rotate(2, np.pi / 2)], (1, 32, 32)),
(
torch.as_tensor(get_arange_img((16, 16))),
[{LazyAttr.AFFINE: create_rotate(2, np.pi / 2), LazyAttr.SHAPE: (1, 45, 45)}],
(1, 45, 45),
),
]


class TestApply(unittest.TestCase):
def _test_apply_impl(self, tensor, pending_transforms, expected_shape):
result = apply(tensor, pending_transforms)
self.assertListEqual(result[1], pending_transforms)
self.assertEqual(result[0].shape, expected_shape)

def _test_apply_metatensor_impl(self, tensor, pending_transforms, expected_shape, pending_as_parameter):
tensor_ = convert_to_tensor(tensor, track_meta=True)
if pending_as_parameter:
result, transforms = apply(tensor_, pending_transforms)
else:
for p in pending_transforms:
tensor_.push_pending_operation(p)
result, transforms = apply(tensor_)
self.assertEqual(result.shape, expected_shape)

SINGLE_TRANSFORM_CASES = single_2d_transform_cases()

def test_apply_single_transform(self):
for case in self.SINGLE_TRANSFORM_CASES:
self._test_apply_impl(*case)

def test_apply_single_transform_metatensor(self):
for case in self.SINGLE_TRANSFORM_CASES:
self._test_apply_metatensor_impl(*case, False)

def test_apply_single_transform_metatensor_override(self):
for case in self.SINGLE_TRANSFORM_CASES:
self._test_apply_metatensor_impl(*case, True)


if __name__ == "__main__":
unittest.main()
40 changes: 40 additions & 0 deletions tests/test_resample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
from parameterized import parameterized

from monai.transforms.lazy.functional import resample
from monai.utils import convert_to_tensor
from tests.utils import assert_allclose, get_arange_img


def rotate_90_2d():
t = torch.eye(3)
t[:, 0] = torch.FloatTensor([0, -1, 0])
t[:, 1] = torch.FloatTensor([1, 0, 0])
return t


RESAMPLE_FUNCTION_CASES = [(get_arange_img((3, 3)), rotate_90_2d(), [[2, 5, 8], [1, 4, 7], [0, 3, 6]])]


class TestResampleFunction(unittest.TestCase):
@parameterized.expand(RESAMPLE_FUNCTION_CASES)
def test_resample_function_impl(self, img, matrix, expected):
out = resample(convert_to_tensor(img), matrix)
assert_allclose(out[0], expected, type_test=False)


if __name__ == "__main__":
unittest.main()
10 changes: 10 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,16 @@ def make_rand_affine(ndim: int = 3, random_state: Optional[np.random.RandomState
return af


def get_arange_img(size, dtype=np.float32, offset=0):
"""
Returns an image as a numpy array (complete with channel as dim 0)
with contents that iterate like an arange.
"""
n_elem = np.prod(size)
img = np.arange(offset, offset + n_elem, dtype=dtype).reshape(size)
return np.expand_dims(img, 0)


class DistTestCase(unittest.TestCase):
"""
testcase without _outcome, so that it's picklable.
Expand Down