From 2c682fe8f3ba7796a4b9b3ea4801372da2b23875 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 18:14:18 +0100 Subject: [PATCH 01/52] Apply and MetaMatrix; partial functionality Signed-off-by: Ben Murray --- monai/transforms/__init__.py | 3 + monai/transforms/apply.py | 229 ++++++++++++++++++++++++++++++++ monai/transforms/meta_matrix.py | 213 +++++++++++++++++++++++++++++ monai/transforms/utils.py | 97 ++++++++++++++ tests/test_apply.py | 9 ++ tests/test_matmul.py | 129 ++++++++++++++++++ 6 files changed, 680 insertions(+) create mode 100644 monai/transforms/apply.py create mode 100644 monai/transforms/meta_matrix.py create mode 100644 tests/test_apply.py create mode 100644 tests/test_matmul.py diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 389571d16f..08c6b6a8f5 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -10,6 +10,7 @@ # limitations under the License. from .adaptors import FunctionSignature, adaptor, apply_alias, to_kwargs +from .apply import Apply, apply from .compose import Compose, OneOf from .croppad.array import ( BorderPad, @@ -621,6 +622,8 @@ generate_label_classes_crop_centers, generate_pos_neg_label_crop_centers, generate_spatial_bounding_box, + get_backend_from_tensor_like, + get_device_from_tensor_like, get_extreme_points, get_largest_connected_component_mask, get_number_image_type_conversions, diff --git a/monai/transforms/apply.py b/monai/transforms/apply.py new file mode 100644 index 0000000000..c58505e929 --- /dev/null +++ b/monai/transforms/apply.py @@ -0,0 +1,229 @@ +from typing import Optional, Sequence, Union + +import itertools as it + +import numpy as np + +import torch + +from monai.config import DtypeLike +from monai.data import MetaTensor +from monai.transforms import Affine +from monai.transforms.inverse import InvertibleTransform +from monai.transforms.transform import MapTransform +from monai.utils import GridSampleMode, GridSamplePadMode +from monai.transforms.utils import get_backend_from_tensor_like, get_device_from_tensor_like, dtypes_to_str_or_identity +from monai.transforms.meta_matrix import MatrixFactory, MetaMatrix, Matrix, matmul + +__all__ = [ + "apply", + "Apply" +] + +# TODO: This should move to a common place to be shared with dictionary +#from monai.utils.type_conversion import dtypes_to_str_or_identity + +GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str] +GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str] +DtypeSequence = Union[Sequence[DtypeLike], DtypeLike] + + +# TODO: move to mapping_stack.py +def extents_from_shape(shape, dtype=np.float64): + extents = [[0, shape[i]] for i in range(1, len(shape))] + + extents = it.product(*extents) + return list(np.asarray(e + (1,), dtype=dtype) for e in extents) + + +# TODO: move to mapping_stack.py +def shape_from_extents( + src_shape: Sequence, + extents: Union[Sequence[np.ndarray], Sequence[torch.Tensor], np.ndarray, torch.Tensor] +): + if isinstance(extents, (list, tuple)): + if isinstance(extents[0], np.ndarray): + aextents = np.asarray(extents) + aextents = torch.from_numpy(aextents) + else: + aextents = torch.stack(extents) + else: + if isinstance(extents, np.ndarray): + aextents = torch.from_numpy(extents) + else: + aextents = extents + + mins = aextents.min(axis=0)[0] + maxes = aextents.max(axis=0)[0] + values = torch.round(maxes - mins).type(torch.IntTensor)[:-1] + return torch.cat((torch.as_tensor([src_shape[0]]), values)) + + 
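# Editor's note: a worked illustration of the two helpers above, assuming a
# channel-first shape (C, H, W); the values below are illustrative only and are
# not part of the original patch.
#
#   extents_from_shape((1, 4, 4))
#     -> the four spatial corner points in homogeneous coordinates:
#        [0, 0, 1], [0, 4, 1], [4, 0, 1], [4, 4, 1]
#
#   shape_from_extents((1, 4, 4), extents_from_shape((1, 4, 4)))
#     -> tensor([1, 4, 4])   # channel count kept, spatial size from the extent bounds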
+def metadata_is_compatible(value_1, value_2): + if value_1 is None: + return True + else: + if value_2 is None: + return True + return value_1 == value_2 + + +def metadata_dtype_is_compatible(value_1, value_2): + if value_1 is None: + return True + else: + if value_2 is None: + return True + + # if we are here, value_1 and value_2 are both set + # TODO: this is not a good enough solution + value_1_ = dtypes_to_str_or_identity(value_1) + value_2_ = dtypes_to_str_or_identity(value_2) + return value_1_ == value_2_ + + +def starting_matrix_and_extents(matrix_factory, data): + # set up the identity matrix and metadata + cumulative_matrix = matrix_factory.identity() + cumulative_extents = extents_from_shape(data.shape) + return cumulative_matrix, cumulative_extents + + +def prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype): + kwargs = {} + if cur_mode is not None: + kwargs['mode'] = cur_mode + if cur_padding_mode is not None: + kwargs['padding_mode'] = cur_padding_mode + if cur_device is not None: + kwargs['device'] = cur_device + if cur_dtype is not None: + kwargs['dtype'] = cur_dtype + + return kwargs + + +def matrix_from_matrix_container(matrix): + if isinstance(matrix, MetaMatrix): + return matrix.matrix.data + elif isinstance(matrix, Matrix): + return matrix.data + else: + return matrix + + +def apply(data: Union[torch.Tensor, MetaTensor], + pending: Optional[dict, list] = None): + + # TODO: if data is a dict, then pending must also be a dict + if isinstance(data, dict): + rd = dict() + for k, v in data.items(): + result = apply(v, pending) + rd[k] = result + return rd + + if isinstance(data, MetaTensor) or pending is not None: + pending_ = [] if pending is None else pending + else: + pending_ = data.pending_transforms + + if len(pending_) == 0: + return data + + dim_count = len(data.shape) - 1 + matrix_factory = MatrixFactory(dim_count, + get_backend_from_tensor_like(data), + get_device_from_tensor_like(data)) + + # set up the identity matrix and metadata + cumulative_matrix, cumulative_extents = starting_matrix_and_extents(matrix_factory, data) + + # set the various resampling parameters to an initial state + cur_mode = None + cur_padding_mode = None + cur_device = None + cur_dtype = None + cur_shape = data.shape + + for meta_matrix in pending_: + next_matrix = meta_matrix.data + # print("intermediate matrix\n", matrix_from_matrix_container(cumulative_matrix)) + cumulative_matrix = matmul(cumulative_matrix, next_matrix) + cumulative_extents = [matmul(e, cumulative_matrix) for e in cumulative_extents] + + new_mode = meta_matrix.metadata.get('mode', None) + new_padding_mode = meta_matrix.metadata.get('padding_mode', None) + new_device = meta_matrix.metadata.get('device', None) + new_dtype = meta_matrix.metadata.get('dtype', None) + new_shape = meta_matrix.metadata.get('shape_override', None) + + mode_compat = metadata_is_compatible(cur_mode, new_mode) + padding_mode_compat = metadata_is_compatible(cur_padding_mode, new_padding_mode) + device_compat = metadata_is_compatible(cur_device, new_device) + dtype_compat = metadata_dtype_is_compatible(cur_dtype, new_dtype) + + if (mode_compat is False or padding_mode_compat is False or + device_compat is False or dtype_compat is False): + # carry out an intermediate resample here due to incompatibility between arguments + kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype) + + cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix) + a = Affine(norm_coords=False, + 
affine=cumulative_matrix_, + **kwargs) + data, _ = a(img=data) + + cumulative_matrix, cumulative_extents =\ + starting_matrix_and_extents(matrix_factory, data) + + cur_mode = cur_mode if new_mode is None else new_mode + cur_padding_mode = cur_padding_mode if new_padding_mode is None else new_padding_mode + cur_device = cur_device if new_device is None else new_device + cur_dtype = cur_dtype if new_dtype is None else new_dtype + cur_shape = cur_shape if new_shape is None else new_shape + + kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype) + + cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix) + + # print(f"applying with cumulative matrix\n {cumulative_matrix_}") + a = Affine(norm_coords=False, + affine=cumulative_matrix_, + spatial_size=cur_shape[1:], + normalized=False, + **kwargs) + data, tx = a(img=data) + data.clear_pending_transforms() + + return data + + +# make Apply universal for arrays and dictionaries; it just calls through to functional apply +class Apply(InvertibleTransform): + + def __init__(self): + super().__init__() + + def __call__(self, *args, **kwargs): + return apply(*args, **kwargs) + + def inverse(self, data): + return NotImplementedError() + + +# class Applyd(MapTransform, InvertibleTransform): +# +# def __init__(self): +# super().__init__() +# +# def __call__( +# self, +# d: dict +# ): +# rd = dict() +# for k, v in d.items(): +# rd[k] = apply(v) +# +# def inverse(self, data): +# return NotImplementedError() diff --git a/monai/transforms/meta_matrix.py b/monai/transforms/meta_matrix.py new file mode 100644 index 0000000000..3d0da051de --- /dev/null +++ b/monai/transforms/meta_matrix.py @@ -0,0 +1,213 @@ +from typing import Optional, Union + +import numpy as np + +import torch + +from monai.config import NdarrayOrTensor + + +def is_matrix_shaped(data): + + return (len(data.shape) == 2 and data.shape[0] in (3, 4) and + data.shape[1] in (3, 4) and data.shape[0] == data.shape[1]) + + +def is_grid_shaped(data): + + return (len(data.shape) == 3 and data.shape[0] == 3 or + len(data.shape) == 4 and data.shape[0] == 4) + + +class MatrixFactory: + pass + + +def ensure_tensor(data: NdarrayOrTensor): + if isinstance(data, torch.Tensor): + return data + + return torch.as_tensor(data) + + +class Matrix: + + def __init__(self, matrix: NdarrayOrTensor): + self.data = ensure_tensor(matrix) + + # def __matmul__(self, other): + # if isinstance(other, Matrix): + # other_matrix = other.data + # else: + # other_matrix = other + # return self.data @ other_matrix + # + # def __rmatmul__(self, other): + # return other.__matmul__(self.data) + + +class Grid: + def __init__(self, grid): + self.data = ensure_tensor(grid) + + # def __matmul__(self, other): + # raise NotImplementedError() + + +class MetaMatrix: + + def __init__( + self, + matrix: Union[NdarrayOrTensor, Matrix, Grid], + metadata: Optional[dict] = None): + + if not isinstance(matrix, (Matrix, Grid)): + if matrix.shape == 2: + if matrix.shape[0] != matrix.shape[1] or matrix.shape[0] not in (3, 4): + raise ValueError("If 'matrix' is passed a numpy ndarray/torch Tensor, it must" + f" be 3x3 or 4x4 ('matrix' has has shape {matrix.shape})") + matrix_ = Matrix(matrix) + else: + matrix_ = matrix + self.matrix = matrix_ + + self.metadata = metadata or {} + + def __matmul__(self, other): + if isinstance(other, MetaMatrix): + other_ = other.matrix + else: + other_ = other + return MetaMatrix(self.matrix @ other_) + + def __rmatmul__(self, other): + if isinstance(other, MetaMatrix): + 
other_ = other.matrix + else: + other_ = other + return MetaMatrix(other_ @ self.matrix) + + +def matmul( + left: Union[MetaMatrix, Grid, Matrix, NdarrayOrTensor], + right: Union[MetaMatrix, Grid, Matrix, NdarrayOrTensor] +): + matrix_types = (MetaMatrix, Grid, Matrix, torch.Tensor, np.ndarray) + + if not isinstance(left, matrix_types): + raise TypeError(f"'left' must be one of {matrix_types} but is {type(left)}") + if not isinstance(right, matrix_types): + raise TypeError(f"'second' must be one of {matrix_types} but is {type(right)}") + + left_ = left.matrix if isinstance(left, MetaMatrix) else left + right_ = right.matrix if isinstance(right, MetaMatrix) else right + + # TODO: it might be better to not return a metamatrix, unless we pass in the resulting + # metadata also + put_in_metamatrix = isinstance(left, MetaMatrix) or isinstance(right, MetaMatrix) + + put_in_grid = isinstance(left, Grid) or isinstance(right, Grid) + + put_in_matrix = isinstance(left, Matrix) or isinstance(right, Matrix) + put_in_matrix = False if put_in_grid is True else put_in_matrix + + promote_to_tensor = not (isinstance(left_, np.ndarray) and isinstance(right_, np.ndarray)) + + left_raw = left_.data if isinstance(left_, (Matrix, Grid)) else left_ + right_raw = right_.data if isinstance(right_, (Matrix, Grid)) else right_ + + if promote_to_tensor: + left_raw = torch.as_tensor(left_raw) + right_raw = torch.as_tensor(right_raw) + + if isinstance(left_, Grid): + if isinstance(right_, Grid): + raise RuntimeError("Unable to matrix multiply two Grids") + else: + result = matmul_grid_matrix(left_raw, right_raw) + else: + if isinstance(right_, Grid): + result = matmul_matrix_grid(left_raw, right_raw) + else: + result = matmul_matrix_matrix(left_raw, right_raw) + + if put_in_grid: + result = Grid(result) + elif put_in_matrix: + result = Matrix(result) + + if put_in_metamatrix: + result = MetaMatrix(result) + + return result + + +def matmul_matrix_grid( + left: NdarrayOrTensor, + right: NdarrayOrTensor +): + if not is_matrix_shaped(left): + raise ValueError(f"'left' should be a 2D or 3D homogenous matrix but has shape {left.shape}") + + if not is_grid_shaped(right): + raise ValueError("'right' should be a 3D array with shape[0] == 2 or a " + f"4D array with shape[0] == 3 but has shape {right.shape}") + + # flatten the grid to take advantage of torch batch matrix multiply + right_flat = right.reshape(right.shape[0], -1) + result_flat = left @ right_flat + # restore the grid shape + result = result_flat.reshape((-1,) + result_flat.shape[1:]) + return result + + +def matmul_grid_matrix( + left: NdarrayOrTensor, + right: NdarrayOrTensor +): + if not is_grid_shaped(left): + raise ValueError("'left' should be a 3D array with shape[0] == 2 or a " + f"4D array with shape[0] == 3 but has shape {left.shape}") + + if not is_matrix_shaped(right): + raise ValueError(f"'right' should be a 2D or 3D homogenous matrix but has shape {right.shape}") + + try: + inv_matrix = torch.inverse(right) + except RuntimeError: + # the matrix is not invertible, so we will have to perform a slow grid to matrix operation + return matmul_grid_matrix_slow(left, right) + + # invert the matrix and swap the arguments, taking advantage of + # matrix @ vector == vector_transposed @ matrix_inverse + return matmul_matrix_grid(torch.inverse(right), left) + + +def matmul_grid_matrix_slow( + left: NdarrayOrTensor, + right: NdarrayOrTensor +): + if not is_grid_shaped(left): + raise ValueError("'left' should be a 3D array with shape[0] == 2 or a " + f"4D array with 
shape[0] == 3 but has shape {left.shape}") + + if not is_matrix_shaped(right): + raise ValueError(f"'right' should be a 2D or 3D homogenous matrix but has shape {right.shape}") + + flat_left = left.reshape(left.shape[0], -1) + result_flat = torch.zeros_like(flat_left) + for i in range(flat_left.shape[1]): + vector = flat_left[:, i][None, :] + result_vector = vector @ right + result_flat[:, i] = result_vector[0, :] + + # restore the grid shape + result = result_flat.reshape((-1,) + result_flat.shape[1:]) + return result + + +def matmul_matrix_matrix( + left: NdarrayOrTensor, + right: NdarrayOrTensor, +): + return left @ right diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index e96d906f20..091d3e5741 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -85,6 +85,8 @@ "generate_label_classes_crop_centers", "generate_pos_neg_label_crop_centers", "generate_spatial_bounding_box", + "get_backend_from_tensor_like", + "get_device_from_tensor_like", "get_extreme_points", "get_largest_connected_component_mask", "remove_small_objects", @@ -1790,5 +1792,100 @@ def squarepulse(sig, duty: float = 0.5): return y +def get_device_from_tensor_like( + data: NdarrayOrTensor +): + """ + This function returns the device of `data`, which must either be a numpy ndarray or a + pytorch Tensor. + + Args: + data: the ndarray/tensor to return the device of + + Returns: + None if `data` is a numpy array, or the device of the pytorch Tensor + """ + if isinstance(data, np.ndarray): + return None + elif isinstance(data, torch.Tensor): + return data.device + else: + msg = "'data' must be one of numpy ndarray or torch Tensor but is {}" + raise ValueError(msg.format(type(data))) + + +def get_backend_from_tensor_like( + data: NdarrayOrTensor +): + """ + This function returns the backend of `data`, which must either be a numpy ndarray or a + pytorch Tensor. 
+ + Args: + data: the ndarray/tensor to return the device of + + Returns: + None if `data` is a numpy array, or the device of the pytorch Tensor + """ + if isinstance(data, np.ndarray): + return TransformBackends.NUMPY + elif isinstance(data, torch.Tensor): + return TransformBackends.TORCH + else: + msg = "'data' must be one of numpy ndarray or torch Tensor but is {}" + raise ValueError(msg.format(type(data))) + + +def dtype_torch_to_numpy(dtype: torch.dtype) -> np.dtype: + """Convert a torch dtype to its numpy equivalent.""" + return torch.empty([], dtype=dtype).numpy().dtype # type: ignore + + +def dtype_numpy_to_torch(dtype: np.dtype) -> torch.dtype: + """Convert a numpy dtype to its torch equivalent.""" + return torch.from_numpy(np.empty([], dtype=dtype)).dtype + + +__dtype_dict = { + np.int8: 'int8', + torch.int8: 'int8', + np.int16: 'int16', + torch.int16: 'int16', + int: 'int32', + np.int32: 'int32', + torch.int32: 'int32', + np.int64: 'int64', + torch.int64: 'int64', + np.uint8: 'uint8', + torch.uint8: 'uint8', + np.uint16: 'uint16', + np.uint32: 'uint32', + np.uint64: 'uint64', + float: 'float32', + np.float16: 'float16', + np.float: 'float32', + np.float32: 'float32', + np.float64: 'float64', + torch.float16: 'float16', + torch.float: 'float32', + torch.float32: 'float32', + torch.double: 'float64', + torch.float64: 'float64' +} + +def dtypes_to_str_or_identity(dtype: Any) -> Any: + return __dtype_dict.get(dtype, dtype) + + +def get_numpy_dtype_from_string(dtype: str) -> np.dtype: + """Get a numpy dtype (e.g., `np.float32`) from its string (e.g., `"float32"`).""" + return np.empty([], dtype=dtype).dtype + + +def get_torch_dtype_from_string(dtype: str) -> torch.dtype: + """Get a torch dtype (e.g., `torch.float32`) from its string (e.g., `"float32"`).""" + return dtype_numpy_to_torch(get_numpy_dtype_from_string(dtype)) + + if __name__ == "__main__": print_transform_backends() diff --git a/tests/test_apply.py b/tests/test_apply.py new file mode 100644 index 0000000000..0e8500544f --- /dev/null +++ b/tests/test_apply.py @@ -0,0 +1,9 @@ +import unittest + +from monai.transforms import apply + + +class TestApply(unittest.TestCase): + + def _test_apply_impl(self): + result = apply(None) diff --git a/tests/test_matmul.py b/tests/test_matmul.py new file mode 100644 index 0000000000..19fa2ab36b --- /dev/null +++ b/tests/test_matmul.py @@ -0,0 +1,129 @@ +import unittest +from parameterized import parameterized + +import numpy as np + +import torch + +from monai.transforms.meta_matrix import ( + Grid, is_grid_shaped, is_matrix_shaped, matmul, matmul_matrix_grid, matmul_grid_matrix_slow, + matmul_grid_matrix, Matrix +) + + +class TestMatmulFunctions(unittest.TestCase): + + def test_matrix_grid_and_grid_inv_matrix_are_equivalent(self): + grid = torch.randn((3, 32, 32)) + + matrix = torch.eye(3, 3) + matrix[:, 0] = torch.FloatTensor([0, -1, 0]) + matrix[:, 1] = torch.FloatTensor([1, 0, 0]) + + inv_matrix = torch.inverse(matrix) + + result1 = matmul_matrix_grid(matrix, grid) + result2 = matmul_grid_matrix(grid, inv_matrix) + self.assertTrue(torch.allclose(result1, result2)) + + def test_matmul_grid_matrix_slow(self): + grid = torch.randn((3, 32, 32)) + + matrix = torch.eye(3, 3) + matrix[:, 0] = torch.FloatTensor([0, -1, 0]) + matrix[:, 1] = torch.FloatTensor([1, 0, 0]) + + result1 = matmul_grid_matrix_slow(grid, matrix) + result2 = matmul_grid_matrix(grid, matrix) + self.assertTrue(torch.allclose(result1, result2)) + + MATMUL_TEST_CASES = [ + [np.eye(3, dtype=np.float32), np.eye(3, 
dtype=np.float32), np.ndarray], + [np.eye(3, dtype=np.float32), torch.eye(3), torch.Tensor], + [np.eye(3, dtype=np.float32), Matrix(torch.eye(3)), Matrix], + [np.eye(3, dtype=np.float32), Grid(torch.randn((3, 8, 8))), Grid], + [torch.eye(3), np.eye(3, dtype=np.float32), torch.Tensor], + [torch.eye(3), torch.eye(3), torch.Tensor], + [torch.eye(3), Matrix(torch.eye(3)), Matrix], + [torch.eye(3), Grid(torch.randn((3, 8, 8))), Grid], + [Matrix(torch.eye(3)), np.eye(3, dtype=np.float32), Matrix], + [Matrix(torch.eye(3)), torch.eye(3), Matrix], + [Matrix(torch.eye(3)), Matrix(torch.eye(3)), Matrix], + [Matrix(torch.eye(3)), Grid(torch.randn((3, 8, 8))), Grid], + [Grid(torch.randn((3, 8, 8))), np.eye(3, dtype=np.float32), Grid], + [Grid(torch.randn((3, 8, 8))), torch.eye(3), Grid], + [Grid(torch.randn((3, 8, 8))), Matrix(torch.eye(3)), Grid], + [Grid(torch.randn((3, 8, 8))), Grid(torch.randn((3, 8, 8))), None], + ] + + def _test_matmul_correct_return_type_impl(self, left, right, expected): + if expected is None: + with self.assertRaises(RuntimeError): + result = matmul(left, right) + else: + result = matmul(left, right) + self.assertIsInstance(result, expected) + + @parameterized.expand(MATMUL_TEST_CASES) + def test_matmul_correct_return_type(self, left, right, expected): + self._test_matmul_correct_return_type_impl(left, right, expected) + + # def test_all_matmul_correct_return_type(self): + # for case in self.MATMUL_TEST_CASES: + # with self.subTest(f"{case}"): + # self._test_matmul_correct_return_type_impl(*case) + + MATRIX_SHAPE_TESTCASES = [ + (torch.randn(2, 2), False), + (torch.randn(3, 3), True), + (torch.randn(4, 4), True), + (torch.randn(5, 5), False), + (torch.randn(3, 4), False), + (torch.randn(4, 3), False), + (torch.randn(3), False), + (torch.randn(4), False), + (torch.randn(5), False), + (torch.randn(3, 3, 3), False), + (torch.randn(4, 4, 4), False), + (torch.randn(5, 5, 5), False) + ] + + def _test_is_matrix_shaped_impl(self, matrix, expected): + self.assertEqual(is_matrix_shaped(matrix), expected) + + @parameterized.expand(MATRIX_SHAPE_TESTCASES) + def test_is_matrix_shaped(self, matrix, expected): + self._test_is_matrix_shaped_impl(matrix, expected) + + # def test_all_is_matrix_shaped(self): + # for case in self.MATRIX_SHAPE_TESTCASES: + # with self.subTest(f"{case[0].shape}"): + # self._test_is_matrix_shaped_impl(*case) + + + GRID_SHAPE_TESTCASES = [ + (torch.randn(1, 16, 32), False), + (torch.randn(2, 16, 32), False), + (torch.randn(3, 16, 32), True), + (torch.randn(4, 16, 32), False), + (torch.randn(5, 16, 32), False), + (torch.randn(3, 16, 32, 64), False), + (torch.randn(4, 16, 32, 64), True), + (torch.randn(5, 16, 32, 64), False) + ] + + def _test_is_grid_shaped_impl(self, grid, expected): + self.assertEqual(is_grid_shaped(grid), expected) + + @parameterized.expand(GRID_SHAPE_TESTCASES) + def test_is_grid_shaped(self, grid, expected): + self._test_is_grid_shaped_impl(grid, expected) + + # def test_all_is_grid_shaped(self): + # for case in self.GRID_SHAPE_TESTCASES: + # with self.subTest(f"{case[0].shape}"): + # self._test_is_grid_shaped_impl(*case) + + +if __name__ == "__main__": + unittest.main() From dbbe26cd9a5bc433a8e03501995db0dc5aa7be2d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 27 Oct 2022 17:16:45 +0000 Subject: [PATCH 02/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/apply.py | 1 - 1 file changed, 1 deletion(-) 
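Editor's note: the MATMUL_TEST_CASES table in the new tests/test_matmul.py above encodes the
return-type promotion rules of matmul() from monai/transforms/meta_matrix.py (PATCH 01). The
sketch below restates those rules as a usage example; it assumes the classes behave exactly as
defined in that patch and is illustrative rather than part of the change set.

    import torch
    from monai.transforms.meta_matrix import Grid, Matrix, MetaMatrix, matmul

    m = Matrix(torch.eye(3))             # homogeneous 3x3 matrix wrapper
    g = Grid(torch.randn(3, 8, 8))       # dense sampling-grid wrapper

    matmul(m, m)                         # -> Matrix: matrix @ matrix stays a Matrix
    matmul(m, g)                         # -> Grid: a Grid operand promotes the result to Grid
    matmul(MetaMatrix(torch.eye(3)), m)  # -> MetaMatrix: a MetaMatrix operand wraps the result again
    matmul(g, g)                         # raises RuntimeError: two Grids cannot be multiplied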
diff --git a/monai/transforms/apply.py b/monai/transforms/apply.py index c58505e929..f97705b847 100644 --- a/monai/transforms/apply.py +++ b/monai/transforms/apply.py @@ -10,7 +10,6 @@ from monai.data import MetaTensor from monai.transforms import Affine from monai.transforms.inverse import InvertibleTransform -from monai.transforms.transform import MapTransform from monai.utils import GridSampleMode, GridSamplePadMode from monai.transforms.utils import get_backend_from_tensor_like, get_device_from_tensor_like, dtypes_to_str_or_identity from monai.transforms.meta_matrix import MatrixFactory, MetaMatrix, Matrix, matmul From 9598480e16ed42fd774209dc20d802a32cf1a05a Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 21:04:47 +0100 Subject: [PATCH 03/52] making import for MetaTensor more specific to avoid circular reference Signed-off-by: Ben Murray --- monai/transforms/apply.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/apply.py b/monai/transforms/apply.py index c58505e929..1b963adf8f 100644 --- a/monai/transforms/apply.py +++ b/monai/transforms/apply.py @@ -7,7 +7,7 @@ import torch from monai.config import DtypeLike -from monai.data import MetaTensor +from monai.data.meta_tensor import MetaTensor from monai.transforms import Affine from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform From 0ed7c2df751c784cd24c97c8b34ebe9842ec6519 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 21:07:35 +0100 Subject: [PATCH 04/52] Making Affine import more specific in apply to avoid circular reference Signed-off-by: Ben Murray --- monai/transforms/apply.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/apply.py b/monai/transforms/apply.py index d613540952..9b56321eaf 100644 --- a/monai/transforms/apply.py +++ b/monai/transforms/apply.py @@ -8,7 +8,7 @@ from monai.config import DtypeLike from monai.data.meta_tensor import MetaTensor -from monai.transforms import Affine +from monai.transforms.spatial.array import Affine from monai.transforms.inverse import InvertibleTransform from monai.utils import GridSampleMode, GridSamplePadMode from monai.transforms.utils import get_backend_from_tensor_like, get_device_from_tensor_like, dtypes_to_str_or_identity From 56717f6d4d3c5c6720ad4389a5e477a029b386fd Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 21:10:02 +0100 Subject: [PATCH 05/52] Fixing typing signature issue on apply method Signed-off-by: Ben Murray --- monai/transforms/apply.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/apply.py b/monai/transforms/apply.py index 9b56321eaf..b2fa8fc708 100644 --- a/monai/transforms/apply.py +++ b/monai/transforms/apply.py @@ -112,7 +112,7 @@ def matrix_from_matrix_container(matrix): def apply(data: Union[torch.Tensor, MetaTensor], - pending: Optional[dict, list] = None): + pending: Optional[Union[dict, list]] = None): # TODO: if data is a dict, then pending must also be a dict if isinstance(data, dict): From f15f46dea6c61e3f535e5335e80defd08e78e3a5 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 21:20:39 +0100 Subject: [PATCH 06/52] Adding missing license boilerplate Signed-off-by: Ben Murray --- monai/transforms/apply.py | 11 ++++++++ monai/transforms/meta_matrix.py | 11 ++++++++ tests/tempscript.py | 47 +++++++++++++++++++++++++++++++++ tests/test_apply.py | 11 ++++++++ tests/test_matmul.py | 11 ++++++++ 5 files changed, 91 insertions(+) 
create mode 100644 tests/tempscript.py diff --git a/monai/transforms/apply.py b/monai/transforms/apply.py index b2fa8fc708..20d031b448 100644 --- a/monai/transforms/apply.py +++ b/monai/transforms/apply.py @@ -1,3 +1,14 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Optional, Sequence, Union import itertools as it diff --git a/monai/transforms/meta_matrix.py b/monai/transforms/meta_matrix.py index 3d0da051de..559cd7e451 100644 --- a/monai/transforms/meta_matrix.py +++ b/monai/transforms/meta_matrix.py @@ -1,3 +1,14 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Optional, Union import numpy as np diff --git a/tests/tempscript.py b/tests/tempscript.py new file mode 100644 index 0000000000..5a9e5945b6 --- /dev/null +++ b/tests/tempscript.py @@ -0,0 +1,47 @@ +import numpy as np + +import matplotlib.pyplot as plt + +import torch +from monai.utils import GridSampleMode, GridSamplePadMode + +from monai.transforms.atmostonce.apply import Applyd +from monai.transforms.atmostonce.dictionary import Rotated +from monai.transforms import Compose + + + +def test_rotate_tensor(): + r = Rotated(('image', 'label'), [0.0, 1.0, 0.0]) + + d = { + 'image': torch.zeros((1, 64, 64, 32), device="cpu", dtype=torch.float32), + 'label': torch.ones((1, 64, 64, 32), device="cpu", dtype=torch.int8) + } + d = r(d) + + for k, v in d.items(): + if isinstance(v, (np.ndarray, torch.Tensor)): + print(k, v.shape) + else: + print(k, v) + + +def test_rotate_apply(): + c = Compose([ + Rotated(('image', 'label'), (0.0, 3.14159265 / 2, 0.0)), + Applyd(('image', 'label'), + modes=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), + padding_modes=(GridSamplePadMode.BORDER, GridSamplePadMode.BORDER)) + ]) + + d = { + 'image': torch.zeros((1, 64, 64, 32), device="cpu", dtype=torch.float32), + 'label': torch.ones((1, 64, 64, 32), device="cpu", dtype=torch.int8) + } + plt.imshow(d['image'][0, ..., d['image'].shape[-1]//2]) + d = c(d) + plt.imshow(d['image'][0, ..., d['image'].shape[-1]//2]) + print(d['image'].shape) + +test_rotate_apply() diff --git a/tests/test_apply.py b/tests/test_apply.py index 0e8500544f..953c249db4 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -1,3 +1,14 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest from monai.transforms import apply diff --git a/tests/test_matmul.py b/tests/test_matmul.py index 19fa2ab36b..bf427ce06c 100644 --- a/tests/test_matmul.py +++ b/tests/test_matmul.py @@ -1,3 +1,14 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest from parameterized import parameterized From 69041065132dd2392d546bfc56b33c57fb3269d7 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 21:32:27 +0100 Subject: [PATCH 07/52] Minimal docstrings for apply / Apply Signed-off-by: Ben Murray --- monai/transforms/apply.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/monai/transforms/apply.py b/monai/transforms/apply.py index 20d031b448..9290035c10 100644 --- a/monai/transforms/apply.py +++ b/monai/transforms/apply.py @@ -122,10 +122,17 @@ def matrix_from_matrix_container(matrix): return matrix -def apply(data: Union[torch.Tensor, MetaTensor], +def apply(data: Union[torch.Tensor, MetaTensor, dict], pending: Optional[Union[dict, list]] = None): - - # TODO: if data is a dict, then pending must also be a dict + """ + This method applies pending transforms to tensors. + + Args: + data: A torch Tensor, monai MetaTensor, or a dictionary containing Tensors / MetaTensors + pending: Optional arg containing pending transforms. This must be set if data is a Tensor + or dictionary of Tensors, but is optional if data is a MetaTensor / dictionary of + MetaTensors. + """ if isinstance(data, dict): rd = dict() for k, v in data.items(): @@ -211,7 +218,10 @@ def apply(data: Union[torch.Tensor, MetaTensor], # make Apply universal for arrays and dictionaries; it just calls through to functional apply class Apply(InvertibleTransform): - + """ + Apply wraps the apply method and can function as a Transform in either array or dictionary + mode. 
+ """ def __init__(self): super().__init__() From 54c6f2aa81d3ee6d40acea513b5bb880d3b1d049 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 21:48:44 +0100 Subject: [PATCH 08/52] Splitting apply.py into lazy/functional and lazy/array Signed-off-by: Ben Murray --- monai/transforms/__init__.py | 3 +- monai/transforms/lazy/__init__.py | 10 ++++ monai/transforms/lazy/array.py | 46 +++++++++++++++++++ .../{apply.py => lazy/functional.py} | 37 +-------------- 4 files changed, 59 insertions(+), 37 deletions(-) create mode 100644 monai/transforms/lazy/__init__.py create mode 100644 monai/transforms/lazy/array.py rename monai/transforms/{apply.py => lazy/functional.py} (90%) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 08c6b6a8f5..1350b3bd95 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -10,7 +10,6 @@ # limitations under the License. from .adaptors import FunctionSignature, adaptor, apply_alias, to_kwargs -from .apply import Apply, apply from .compose import Compose, OneOf from .croppad.array import ( BorderPad, @@ -228,6 +227,8 @@ from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict from .io.array import SUPPORTED_READERS, LoadImage, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict +from .lazy.array import Apply +from .lazy.functional import apply from .meta_utility.dictionary import ( FromMetaTensord, FromMetaTensorD, diff --git a/monai/transforms/lazy/__init__.py b/monai/transforms/lazy/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/transforms/lazy/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/transforms/lazy/array.py b/monai/transforms/lazy/array.py new file mode 100644 index 0000000000..b958010b4f --- /dev/null +++ b/monai/transforms/lazy/array.py @@ -0,0 +1,46 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from monai.transforms.inverse import InvertibleTransform + +from monai.transforms.lazy.functional import apply + + +class Apply(InvertibleTransform): + """ + Apply wraps the apply method and can function as a Transform in either array or dictionary + mode. 
+ """ + def __init__(self): + super().__init__() + + def __call__(self, *args, **kwargs): + return apply(*args, **kwargs) + + def inverse(self, data): + return NotImplementedError() + + +# class Applyd(MapTransform, InvertibleTransform): +# +# def __init__(self): +# super().__init__() +# +# def __call__( +# self, +# d: dict +# ): +# rd = dict() +# for k, v in d.items(): +# rd[k] = apply(v) +# +# def inverse(self, data): +# return NotImplementedError() \ No newline at end of file diff --git a/monai/transforms/apply.py b/monai/transforms/lazy/functional.py similarity index 90% rename from monai/transforms/apply.py rename to monai/transforms/lazy/functional.py index 9290035c10..af320598b4 100644 --- a/monai/transforms/apply.py +++ b/monai/transforms/lazy/functional.py @@ -20,14 +20,12 @@ from monai.config import DtypeLike from monai.data.meta_tensor import MetaTensor from monai.transforms.spatial.array import Affine -from monai.transforms.inverse import InvertibleTransform from monai.utils import GridSampleMode, GridSamplePadMode from monai.transforms.utils import get_backend_from_tensor_like, get_device_from_tensor_like, dtypes_to_str_or_identity from monai.transforms.meta_matrix import MatrixFactory, MetaMatrix, Matrix, matmul __all__ = [ - "apply", - "Apply" + "apply" ] # TODO: This should move to a common place to be shared with dictionary @@ -214,36 +212,3 @@ def apply(data: Union[torch.Tensor, MetaTensor, dict], data.clear_pending_transforms() return data - - -# make Apply universal for arrays and dictionaries; it just calls through to functional apply -class Apply(InvertibleTransform): - """ - Apply wraps the apply method and can function as a Transform in either array or dictionary - mode. - """ - def __init__(self): - super().__init__() - - def __call__(self, *args, **kwargs): - return apply(*args, **kwargs) - - def inverse(self, data): - return NotImplementedError() - - -# class Applyd(MapTransform, InvertibleTransform): -# -# def __init__(self): -# super().__init__() -# -# def __call__( -# self, -# d: dict -# ): -# rd = dict() -# for k, v in d.items(): -# rd[k] = apply(v) -# -# def inverse(self, data): -# return NotImplementedError() From 472c3cdba8a01b226076824ab7dd739daf3e6942 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 27 Oct 2022 20:49:21 +0000 Subject: [PATCH 09/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/lazy/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/lazy/array.py b/monai/transforms/lazy/array.py index b958010b4f..a3de0dfe7d 100644 --- a/monai/transforms/lazy/array.py +++ b/monai/transforms/lazy/array.py @@ -43,4 +43,4 @@ def inverse(self, data): # rd[k] = apply(v) # # def inverse(self, data): -# return NotImplementedError() \ No newline at end of file +# return NotImplementedError() From 5f529591d786183688f74bc5e13202b37f9b4245 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 21:54:57 +0100 Subject: [PATCH 10/52] Removing spurious test file tempscript Signed-off-by: Ben Murray --- tests/tempscript.py | 47 --------------------------------------------- 1 file changed, 47 deletions(-) delete mode 100644 tests/tempscript.py diff --git a/tests/tempscript.py b/tests/tempscript.py deleted file mode 100644 index 5a9e5945b6..0000000000 --- a/tests/tempscript.py +++ /dev/null @@ -1,47 +0,0 @@ -import numpy as np - -import matplotlib.pyplot as plt - -import torch 
-from monai.utils import GridSampleMode, GridSamplePadMode - -from monai.transforms.atmostonce.apply import Applyd -from monai.transforms.atmostonce.dictionary import Rotated -from monai.transforms import Compose - - - -def test_rotate_tensor(): - r = Rotated(('image', 'label'), [0.0, 1.0, 0.0]) - - d = { - 'image': torch.zeros((1, 64, 64, 32), device="cpu", dtype=torch.float32), - 'label': torch.ones((1, 64, 64, 32), device="cpu", dtype=torch.int8) - } - d = r(d) - - for k, v in d.items(): - if isinstance(v, (np.ndarray, torch.Tensor)): - print(k, v.shape) - else: - print(k, v) - - -def test_rotate_apply(): - c = Compose([ - Rotated(('image', 'label'), (0.0, 3.14159265 / 2, 0.0)), - Applyd(('image', 'label'), - modes=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), - padding_modes=(GridSamplePadMode.BORDER, GridSamplePadMode.BORDER)) - ]) - - d = { - 'image': torch.zeros((1, 64, 64, 32), device="cpu", dtype=torch.float32), - 'label': torch.ones((1, 64, 64, 32), device="cpu", dtype=torch.int8) - } - plt.imshow(d['image'][0, ..., d['image'].shape[-1]//2]) - d = c(d) - plt.imshow(d['image'][0, ..., d['image'].shape[-1]//2]) - print(d['image'].shape) - -test_rotate_apply() From f0b010a37537aefe141e07dbc7a23e86f78ca57c Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 22:13:08 +0100 Subject: [PATCH 11/52] Auto formatting fixes Signed-off-by: Ben Murray --- monai/transforms/lazy/array.py | 2 +- monai/transforms/lazy/functional.py | 58 ++++++++++---------------- monai/transforms/meta_matrix.py | 63 ++++++++++++----------------- monai/transforms/utils.py | 57 +++++++++++++------------- tests/test_apply.py | 1 - tests/test_matmul.py | 19 +++++---- 6 files changed, 86 insertions(+), 114 deletions(-) diff --git a/monai/transforms/lazy/array.py b/monai/transforms/lazy/array.py index a3de0dfe7d..bc61a24a6e 100644 --- a/monai/transforms/lazy/array.py +++ b/monai/transforms/lazy/array.py @@ -10,7 +10,6 @@ # limitations under the License. from monai.transforms.inverse import InvertibleTransform - from monai.transforms.lazy.functional import apply @@ -19,6 +18,7 @@ class Apply(InvertibleTransform): Apply wraps the apply method and can function as a Transform in either array or dictionary mode. """ + def __init__(self): super().__init__() diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index af320598b4..607e90d6bb 100644 --- a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -9,27 +9,23 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional, Sequence, Union - import itertools as it +from typing import Optional, Sequence, Union import numpy as np - import torch from monai.config import DtypeLike from monai.data.meta_tensor import MetaTensor +from monai.transforms.meta_matrix import Matrix, MatrixFactory, MetaMatrix, matmul from monai.transforms.spatial.array import Affine +from monai.transforms.utils import dtypes_to_str_or_identity, get_backend_from_tensor_like, get_device_from_tensor_like from monai.utils import GridSampleMode, GridSamplePadMode -from monai.transforms.utils import get_backend_from_tensor_like, get_device_from_tensor_like, dtypes_to_str_or_identity -from monai.transforms.meta_matrix import MatrixFactory, MetaMatrix, Matrix, matmul -__all__ = [ - "apply" -] +__all__ = ["apply"] # TODO: This should move to a common place to be shared with dictionary -#from monai.utils.type_conversion import dtypes_to_str_or_identity +# from monai.utils.type_conversion import dtypes_to_str_or_identity GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str] GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str] @@ -46,8 +42,7 @@ def extents_from_shape(shape, dtype=np.float64): # TODO: move to mapping_stack.py def shape_from_extents( - src_shape: Sequence, - extents: Union[Sequence[np.ndarray], Sequence[torch.Tensor], np.ndarray, torch.Tensor] + src_shape: Sequence, extents: Union[Sequence[np.ndarray], Sequence[torch.Tensor], np.ndarray, torch.Tensor] ): if isinstance(extents, (list, tuple)): if isinstance(extents[0], np.ndarray): @@ -100,13 +95,13 @@ def starting_matrix_and_extents(matrix_factory, data): def prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype): kwargs = {} if cur_mode is not None: - kwargs['mode'] = cur_mode + kwargs["mode"] = cur_mode if cur_padding_mode is not None: - kwargs['padding_mode'] = cur_padding_mode + kwargs["padding_mode"] = cur_padding_mode if cur_device is not None: - kwargs['device'] = cur_device + kwargs["device"] = cur_device if cur_dtype is not None: - kwargs['dtype'] = cur_dtype + kwargs["dtype"] = cur_dtype return kwargs @@ -120,8 +115,7 @@ def matrix_from_matrix_container(matrix): return matrix -def apply(data: Union[torch.Tensor, MetaTensor, dict], - pending: Optional[Union[dict, list]] = None): +def apply(data: Union[torch.Tensor, MetaTensor, dict], pending: Optional[Union[dict, list]] = None): """ This method applies pending transforms to tensors. 
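Editor's note: for readers following the reformatting of monai/transforms/lazy/functional.py, the
sketch below shows the call pattern apply() is aiming for. PATCH 01 describes the functionality as
partial (for example, MatrixFactory is still a stub at this point), so treat the names and the
pending-transform mechanics as an illustration of the intended API rather than code that runs
against this commit.

    import torch
    from monai.data.meta_tensor import MetaTensor
    from monai.transforms.lazy.functional import apply
    from monai.transforms.meta_matrix import MetaMatrix

    img = MetaTensor(torch.rand(1, 32, 32))                  # channel-first 2D image
    rot = torch.eye(3)
    rot[:2, :2] = torch.tensor([[0.0, -1.0], [1.0, 0.0]])    # 90-degree rotation, homogeneous

    # each pending transform carries its matrix plus the resampling arguments it needs
    pending = [MetaMatrix(rot, {"mode": "bilinear", "padding_mode": "zeros"})]

    resampled = apply(img, pending)                          # fuse pending matrices, resample once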
@@ -147,9 +141,7 @@ def apply(data: Union[torch.Tensor, MetaTensor, dict], return data dim_count = len(data.shape) - 1 - matrix_factory = MatrixFactory(dim_count, - get_backend_from_tensor_like(data), - get_device_from_tensor_like(data)) + matrix_factory = MatrixFactory(dim_count, get_backend_from_tensor_like(data), get_device_from_tensor_like(data)) # set up the identity matrix and metadata cumulative_matrix, cumulative_extents = starting_matrix_and_extents(matrix_factory, data) @@ -167,30 +159,26 @@ def apply(data: Union[torch.Tensor, MetaTensor, dict], cumulative_matrix = matmul(cumulative_matrix, next_matrix) cumulative_extents = [matmul(e, cumulative_matrix) for e in cumulative_extents] - new_mode = meta_matrix.metadata.get('mode', None) - new_padding_mode = meta_matrix.metadata.get('padding_mode', None) - new_device = meta_matrix.metadata.get('device', None) - new_dtype = meta_matrix.metadata.get('dtype', None) - new_shape = meta_matrix.metadata.get('shape_override', None) + new_mode = meta_matrix.metadata.get("mode", None) + new_padding_mode = meta_matrix.metadata.get("padding_mode", None) + new_device = meta_matrix.metadata.get("device", None) + new_dtype = meta_matrix.metadata.get("dtype", None) + new_shape = meta_matrix.metadata.get("shape_override", None) mode_compat = metadata_is_compatible(cur_mode, new_mode) padding_mode_compat = metadata_is_compatible(cur_padding_mode, new_padding_mode) device_compat = metadata_is_compatible(cur_device, new_device) dtype_compat = metadata_dtype_is_compatible(cur_dtype, new_dtype) - if (mode_compat is False or padding_mode_compat is False or - device_compat is False or dtype_compat is False): + if mode_compat is False or padding_mode_compat is False or device_compat is False or dtype_compat is False: # carry out an intermediate resample here due to incompatibility between arguments kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype) cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix) - a = Affine(norm_coords=False, - affine=cumulative_matrix_, - **kwargs) + a = Affine(norm_coords=False, affine=cumulative_matrix_, **kwargs) data, _ = a(img=data) - cumulative_matrix, cumulative_extents =\ - starting_matrix_and_extents(matrix_factory, data) + cumulative_matrix, cumulative_extents = starting_matrix_and_extents(matrix_factory, data) cur_mode = cur_mode if new_mode is None else new_mode cur_padding_mode = cur_padding_mode if new_padding_mode is None else new_padding_mode @@ -203,11 +191,7 @@ def apply(data: Union[torch.Tensor, MetaTensor, dict], cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix) # print(f"applying with cumulative matrix\n {cumulative_matrix_}") - a = Affine(norm_coords=False, - affine=cumulative_matrix_, - spatial_size=cur_shape[1:], - normalized=False, - **kwargs) + a = Affine(norm_coords=False, affine=cumulative_matrix_, spatial_size=cur_shape[1:], normalized=False, **kwargs) data, tx = a(img=data) data.clear_pending_transforms() diff --git a/monai/transforms/meta_matrix.py b/monai/transforms/meta_matrix.py index 559cd7e451..e955914c45 100644 --- a/monai/transforms/meta_matrix.py +++ b/monai/transforms/meta_matrix.py @@ -12,7 +12,6 @@ from typing import Optional, Union import numpy as np - import torch from monai.config import NdarrayOrTensor @@ -20,14 +19,14 @@ def is_matrix_shaped(data): - return (len(data.shape) == 2 and data.shape[0] in (3, 4) and - data.shape[1] in (3, 4) and data.shape[0] == data.shape[1]) + return ( + len(data.shape) == 2 and 
data.shape[0] in (3, 4) and data.shape[1] in (3, 4) and data.shape[0] == data.shape[1] + ) def is_grid_shaped(data): - return (len(data.shape) == 3 and data.shape[0] == 3 or - len(data.shape) == 4 and data.shape[0] == 4) + return len(data.shape) == 3 and data.shape[0] == 3 or len(data.shape) == 4 and data.shape[0] == 4 class MatrixFactory: @@ -42,7 +41,6 @@ def ensure_tensor(data: NdarrayOrTensor): class Matrix: - def __init__(self, matrix: NdarrayOrTensor): self.data = ensure_tensor(matrix) @@ -66,17 +64,15 @@ def __init__(self, grid): class MetaMatrix: - - def __init__( - self, - matrix: Union[NdarrayOrTensor, Matrix, Grid], - metadata: Optional[dict] = None): + def __init__(self, matrix: Union[NdarrayOrTensor, Matrix, Grid], metadata: Optional[dict] = None): if not isinstance(matrix, (Matrix, Grid)): if matrix.shape == 2: if matrix.shape[0] != matrix.shape[1] or matrix.shape[0] not in (3, 4): - raise ValueError("If 'matrix' is passed a numpy ndarray/torch Tensor, it must" - f" be 3x3 or 4x4 ('matrix' has has shape {matrix.shape})") + raise ValueError( + "If 'matrix' is passed a numpy ndarray/torch Tensor, it must" + f" be 3x3 or 4x4 ('matrix' has has shape {matrix.shape})" + ) matrix_ = Matrix(matrix) else: matrix_ = matrix @@ -100,8 +96,7 @@ def __rmatmul__(self, other): def matmul( - left: Union[MetaMatrix, Grid, Matrix, NdarrayOrTensor], - right: Union[MetaMatrix, Grid, Matrix, NdarrayOrTensor] + left: Union[MetaMatrix, Grid, Matrix, NdarrayOrTensor], right: Union[MetaMatrix, Grid, Matrix, NdarrayOrTensor] ): matrix_types = (MetaMatrix, Grid, Matrix, torch.Tensor, np.ndarray) @@ -153,16 +148,15 @@ def matmul( return result -def matmul_matrix_grid( - left: NdarrayOrTensor, - right: NdarrayOrTensor -): +def matmul_matrix_grid(left: NdarrayOrTensor, right: NdarrayOrTensor): if not is_matrix_shaped(left): raise ValueError(f"'left' should be a 2D or 3D homogenous matrix but has shape {left.shape}") if not is_grid_shaped(right): - raise ValueError("'right' should be a 3D array with shape[0] == 2 or a " - f"4D array with shape[0] == 3 but has shape {right.shape}") + raise ValueError( + "'right' should be a 3D array with shape[0] == 2 or a " + f"4D array with shape[0] == 3 but has shape {right.shape}" + ) # flatten the grid to take advantage of torch batch matrix multiply right_flat = right.reshape(right.shape[0], -1) @@ -172,13 +166,12 @@ def matmul_matrix_grid( return result -def matmul_grid_matrix( - left: NdarrayOrTensor, - right: NdarrayOrTensor -): +def matmul_grid_matrix(left: NdarrayOrTensor, right: NdarrayOrTensor): if not is_grid_shaped(left): - raise ValueError("'left' should be a 3D array with shape[0] == 2 or a " - f"4D array with shape[0] == 3 but has shape {left.shape}") + raise ValueError( + "'left' should be a 3D array with shape[0] == 2 or a " + f"4D array with shape[0] == 3 but has shape {left.shape}" + ) if not is_matrix_shaped(right): raise ValueError(f"'right' should be a 2D or 3D homogenous matrix but has shape {right.shape}") @@ -194,13 +187,12 @@ def matmul_grid_matrix( return matmul_matrix_grid(torch.inverse(right), left) -def matmul_grid_matrix_slow( - left: NdarrayOrTensor, - right: NdarrayOrTensor -): +def matmul_grid_matrix_slow(left: NdarrayOrTensor, right: NdarrayOrTensor): if not is_grid_shaped(left): - raise ValueError("'left' should be a 3D array with shape[0] == 2 or a " - f"4D array with shape[0] == 3 but has shape {left.shape}") + raise ValueError( + "'left' should be a 3D array with shape[0] == 2 or a " + f"4D array with shape[0] == 3 but has shape 
{left.shape}" + ) if not is_matrix_shaped(right): raise ValueError(f"'right' should be a 2D or 3D homogenous matrix but has shape {right.shape}") @@ -217,8 +209,5 @@ def matmul_grid_matrix_slow( return result -def matmul_matrix_matrix( - left: NdarrayOrTensor, - right: NdarrayOrTensor, -): +def matmul_matrix_matrix(left: NdarrayOrTensor, right: NdarrayOrTensor): return left @ right diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 091d3e5741..b505d27aa8 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1792,9 +1792,7 @@ def squarepulse(sig, duty: float = 0.5): return y -def get_device_from_tensor_like( - data: NdarrayOrTensor -): +def get_device_from_tensor_like(data: NdarrayOrTensor): """ This function returns the device of `data`, which must either be a numpy ndarray or a pytorch Tensor. @@ -1814,9 +1812,7 @@ def get_device_from_tensor_like( raise ValueError(msg.format(type(data))) -def get_backend_from_tensor_like( - data: NdarrayOrTensor -): +def get_backend_from_tensor_like(data: NdarrayOrTensor): """ This function returns the backend of `data`, which must either be a numpy ndarray or a pytorch Tensor. @@ -1847,32 +1843,33 @@ def dtype_numpy_to_torch(dtype: np.dtype) -> torch.dtype: __dtype_dict = { - np.int8: 'int8', - torch.int8: 'int8', - np.int16: 'int16', - torch.int16: 'int16', - int: 'int32', - np.int32: 'int32', - torch.int32: 'int32', - np.int64: 'int64', - torch.int64: 'int64', - np.uint8: 'uint8', - torch.uint8: 'uint8', - np.uint16: 'uint16', - np.uint32: 'uint32', - np.uint64: 'uint64', - float: 'float32', - np.float16: 'float16', - np.float: 'float32', - np.float32: 'float32', - np.float64: 'float64', - torch.float16: 'float16', - torch.float: 'float32', - torch.float32: 'float32', - torch.double: 'float64', - torch.float64: 'float64' + np.int8: "int8", + torch.int8: "int8", + np.int16: "int16", + torch.int16: "int16", + int: "int32", + np.int32: "int32", + torch.int32: "int32", + np.int64: "int64", + torch.int64: "int64", + np.uint8: "uint8", + torch.uint8: "uint8", + np.uint16: "uint16", + np.uint32: "uint32", + np.uint64: "uint64", + float: "float32", + np.float16: "float16", + np.float: "float32", + np.float32: "float32", + np.float64: "float64", + torch.float16: "float16", + torch.float: "float32", + torch.float32: "float32", + torch.double: "float64", + torch.float64: "float64", } + def dtypes_to_str_or_identity(dtype: Any) -> Any: return __dtype_dict.get(dtype, dtype) diff --git a/tests/test_apply.py b/tests/test_apply.py index 953c249db4..9a050ab6ae 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -15,6 +15,5 @@ class TestApply(unittest.TestCase): - def _test_apply_impl(self): result = apply(None) diff --git a/tests/test_matmul.py b/tests/test_matmul.py index bf427ce06c..2ac3b70a2e 100644 --- a/tests/test_matmul.py +++ b/tests/test_matmul.py @@ -10,20 +10,24 @@ # limitations under the License. 
import unittest -from parameterized import parameterized import numpy as np - import torch +from parameterized import parameterized from monai.transforms.meta_matrix import ( - Grid, is_grid_shaped, is_matrix_shaped, matmul, matmul_matrix_grid, matmul_grid_matrix_slow, - matmul_grid_matrix, Matrix + Grid, + Matrix, + is_grid_shaped, + is_matrix_shaped, + matmul, + matmul_grid_matrix, + matmul_grid_matrix_slow, + matmul_matrix_grid, ) class TestMatmulFunctions(unittest.TestCase): - def test_matrix_grid_and_grid_inv_matrix_are_equivalent(self): grid = torch.randn((3, 32, 32)) @@ -96,7 +100,7 @@ def test_matmul_correct_return_type(self, left, right, expected): (torch.randn(5), False), (torch.randn(3, 3, 3), False), (torch.randn(4, 4, 4), False), - (torch.randn(5, 5, 5), False) + (torch.randn(5, 5, 5), False), ] def _test_is_matrix_shaped_impl(self, matrix, expected): @@ -111,7 +115,6 @@ def test_is_matrix_shaped(self, matrix, expected): # with self.subTest(f"{case[0].shape}"): # self._test_is_matrix_shaped_impl(*case) - GRID_SHAPE_TESTCASES = [ (torch.randn(1, 16, 32), False), (torch.randn(2, 16, 32), False), @@ -120,7 +123,7 @@ def test_is_matrix_shaped(self, matrix, expected): (torch.randn(5, 16, 32), False), (torch.randn(3, 16, 32, 64), False), (torch.randn(4, 16, 32, 64), True), - (torch.randn(5, 16, 32, 64), False) + (torch.randn(5, 16, 32, 64), False), ] def _test_is_grid_shaped_impl(self, grid, expected): From 69568ff87921750e8d202b956a5b61cdb725d6e7 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 23:10:52 +0100 Subject: [PATCH 12/52] Fixing issues raised by linter Signed-off-by: Ben Murray --- monai/transforms/lazy/functional.py | 2 +- monai/transforms/meta_matrix.py | 2 +- tests/test_apply.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index 607e90d6bb..aa153b8e02 100644 --- a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -37,7 +37,7 @@ def extents_from_shape(shape, dtype=np.float64): extents = [[0, shape[i]] for i in range(1, len(shape))] extents = it.product(*extents) - return list(np.asarray(e + (1,), dtype=dtype) for e in extents) + return [np.asarray(e + (1,), dtype=dtype) for e in extents] # TODO: move to mapping_stack.py diff --git a/monai/transforms/meta_matrix.py b/monai/transforms/meta_matrix.py index e955914c45..aff2d23c61 100644 --- a/monai/transforms/meta_matrix.py +++ b/monai/transforms/meta_matrix.py @@ -184,7 +184,7 @@ def matmul_grid_matrix(left: NdarrayOrTensor, right: NdarrayOrTensor): # invert the matrix and swap the arguments, taking advantage of # matrix @ vector == vector_transposed @ matrix_inverse - return matmul_matrix_grid(torch.inverse(right), left) + return matmul_matrix_grid(inv_matrix, left) def matmul_grid_matrix_slow(left: NdarrayOrTensor, right: NdarrayOrTensor): diff --git a/tests/test_apply.py b/tests/test_apply.py index 9a050ab6ae..036549d0b4 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -16,4 +16,5 @@ class TestApply(unittest.TestCase): def _test_apply_impl(self): - result = apply(None) + # result = apply(None) + pass From 9bf2ed8ef3bab6657a60bbf93110c5519b1efc1d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 27 Oct 2022 22:11:41 +0000 Subject: [PATCH 13/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_apply.py | 1 - 1 file changed, 1 
deletion(-) diff --git a/tests/test_apply.py b/tests/test_apply.py index 036549d0b4..a4279aa0ba 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -11,7 +11,6 @@ import unittest -from monai.transforms import apply class TestApply(unittest.TestCase): From 6e77d57938a4253353265616b54ab8be3f2d5ab3 Mon Sep 17 00:00:00 2001 From: Wenqi Li <831580+wyli@users.noreply.github.com> Date: Fri, 28 Oct 2022 08:54:33 +0100 Subject: [PATCH 14/52] 5422 update attentionunet parameters (#5423) Signed-off-by: Wenqi Li Fixes #5422 ### Description the kernel size and strides are hard-coded: https://github.com/Project-MONAI/MONAI/blob/a209b06438343830e561a0afd41b1025516a8977/monai/networks/nets/attentionunet.py#L151 this PR makes the values tunable. `level` parameter is not used and removed in this PR. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Wenqi Li --- monai/networks/nets/attentionunet.py | 35 +++++++++++++++++++++------- tests/test_attentionunet.py | 2 +- 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/monai/networks/nets/attentionunet.py b/monai/networks/nets/attentionunet.py index 177a54e105..a57b57425e 100644 --- a/monai/networks/nets/attentionunet.py +++ b/monai/networks/nets/attentionunet.py @@ -143,12 +143,27 @@ def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor: class AttentionLayer(nn.Module): - def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, submodule: nn.Module, dropout=0.0): + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + submodule: nn.Module, + up_kernel_size=3, + strides=2, + dropout=0.0, + ): super().__init__() self.attention = AttentionBlock( spatial_dims=spatial_dims, f_g=in_channels, f_l=in_channels, f_int=in_channels // 2 ) - self.upconv = UpConv(spatial_dims=spatial_dims, in_channels=out_channels, out_channels=in_channels, strides=2) + self.upconv = UpConv( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=in_channels, + strides=strides, + kernel_size=up_kernel_size, + ) self.merge = Convolution( spatial_dims=spatial_dims, in_channels=2 * in_channels, out_channels=in_channels, dropout=dropout ) @@ -174,7 +189,7 @@ class AttentionUnet(nn.Module): channels (Sequence[int]): sequence of channels. Top block first. The length of `channels` should be no less than 2. strides (Sequence[int]): stride to use for convolutions. kernel_size: convolution kernel size. - upsample_kernel_size: convolution kernel size for transposed convolution layers. + up_kernel_size: convolution kernel size for transposed convolution layers. dropout: dropout ratio. Defaults to no dropout. 
""" @@ -210,9 +225,9 @@ def __init__( ) self.up_kernel_size = up_kernel_size - def _create_block(channels: Sequence[int], strides: Sequence[int], level: int = 0) -> nn.Module: + def _create_block(channels: Sequence[int], strides: Sequence[int]) -> nn.Module: if len(channels) > 2: - subblock = _create_block(channels[1:], strides[1:], level=level + 1) + subblock = _create_block(channels[1:], strides[1:]) return AttentionLayer( spatial_dims=spatial_dims, in_channels=channels[0], @@ -227,17 +242,19 @@ def _create_block(channels: Sequence[int], strides: Sequence[int], level: int = ), subblock, ), + up_kernel_size=self.up_kernel_size, + strides=strides[0], dropout=dropout, ) else: # the next layer is the bottom so stop recursion, - # create the bottom layer as the sublock for this layer - return self._get_bottom_layer(channels[0], channels[1], strides[0], level=level + 1) + # create the bottom layer as the subblock for this layer + return self._get_bottom_layer(channels[0], channels[1], strides[0]) encdec = _create_block(self.channels, self.strides) self.model = nn.Sequential(head, encdec, reduce_channels) - def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int, level: int) -> nn.Module: + def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int) -> nn.Module: return AttentionLayer( spatial_dims=self.dimensions, in_channels=in_channels, @@ -249,6 +266,8 @@ def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int, l strides=strides, dropout=self.dropout, ), + up_kernel_size=self.up_kernel_size, + strides=strides, dropout=self.dropout, ) diff --git a/tests/test_attentionunet.py b/tests/test_attentionunet.py index b2f53f9c16..e1df7b8acd 100644 --- a/tests/test_attentionunet.py +++ b/tests/test_attentionunet.py @@ -39,7 +39,7 @@ def test_attentionunet(self): shape = (3, 1) + (92,) * dims input = torch.rand(*shape) model = att.AttentionUnet( - spatial_dims=dims, in_channels=1, out_channels=2, channels=(3, 4, 5), strides=(2, 2) + spatial_dims=dims, in_channels=1, out_channels=2, channels=(3, 4, 5), up_kernel_size=5, strides=(1, 2) ) output = model(input) self.assertEqual(output.shape[2:], input.shape[2:]) From 95e37c4775cefdda145a75cf5ee8860b04ff5926 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 28 Oct 2022 11:20:42 +0100 Subject: [PATCH 15/52] Starting tests for apply Signed-off-by: Ben Murray --- tests/test_apply.py | 48 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/tests/test_apply.py b/tests/test_apply.py index 036549d0b4..7d8af36e75 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -11,10 +11,50 @@ import unittest -from monai.transforms import apply +import numpy as np + +import torch + +from monai.transforms.lazy.functional import apply +from monai.transforms.meta_matrix import MetaMatrix + + +def get_img(size, dtype=torch.float32, offset=0): + img = torch.zeros(size, dtype=dtype) + if len(size) == 2: + for j in range(size[0]): + for i in range(size[1]): + img[j, i] = i + j * size[0] + offset + else: + for k in range(size[0]): + for j in range(size[1]): + for i in range(size[2]): + img[k, j, i] = i + j * size[0] + k * size[0] * size[1] + return np.expand_dims(img, 0) + + +def rotate_45_2D(): + t = torch.eye(3) + t[:, 0] = torch.FloatTensor([0, -1, 0]) + t[:, 1] = torch.FloatTensor([1, 0, 0]) + return t class TestApply(unittest.TestCase): - def _test_apply_impl(self): - # result = apply(None) - pass + + def _test_apply_impl(self, tensor, 
pending_transforms): + print(tensor.shape) + # for m in pending_transforms: + # print(m.matrix) + # print(m.metadata) + result = apply(tensor, pending_transforms) + print(result) + + SINGLE_TRANSFORM_CASES = [ + (torch.randn((1, 16, 16)), [MetaMatrix(rotate_45_2D(), {"id", "rotate"})]) + ] + + def test_apply_single_transform(self): + for case in self.SINGLE_TRANSFORM_CASES: + self._test_apply_impl(*case) + From 01b8e12b53cf16e40bba7bfac2d4b111f2dc3c21 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 28 Oct 2022 10:22:39 +0000 Subject: [PATCH 16/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_apply.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_apply.py b/tests/test_apply.py index 7d8af36e75..6cc5719019 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -57,4 +57,3 @@ def _test_apply_impl(self, tensor, pending_transforms): def test_apply_single_transform(self): for case in self.SINGLE_TRANSFORM_CASES: self._test_apply_impl(*case) - From 5d0db756d0bdb6c056461b87643ee0a83742ea4d Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 28 Oct 2022 14:57:52 +0100 Subject: [PATCH 17/52] Further array functionality and testing; waiting on PR #5107 Signed-off-by: Ben Murray --- monai/transforms/__init__.py | 7 ++ monai/transforms/lazy/array.py | 4 +- monai/transforms/lazy/functional.py | 18 +++-- monai/transforms/meta_matrix.py | 76 ++++++++++++++++- monai/transforms/utils.py | 121 ++++++++++++++++++++++++++++ tests/test_apply.py | 20 ++++- 6 files changed, 235 insertions(+), 11 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 1350b3bd95..344a68f0e1 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -237,6 +237,13 @@ ToMetaTensorD, ToMetaTensorDict, ) +from .meta_matrix import ( + Grid, + matmul, + Matrix, + MatrixFactory, + MetaMatrix, +) from .nvtx import ( Mark, Markd, diff --git a/monai/transforms/lazy/array.py b/monai/transforms/lazy/array.py index bc61a24a6e..ae165cf566 100644 --- a/monai/transforms/lazy/array.py +++ b/monai/transforms/lazy/array.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from monai.transforms.inverse import InvertibleTransform from monai.transforms.lazy.functional import apply +from monai.transforms.inverse import InvertibleTransform + +__all__ = ["Apply"] class Apply(InvertibleTransform): diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index aa153b8e02..adb44fc9a3 100644 --- a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -33,7 +33,7 @@ # TODO: move to mapping_stack.py -def extents_from_shape(shape, dtype=np.float64): +def extents_from_shape(shape, dtype=np.float32): extents = [[0, shape[i]] for i in range(1, len(shape))] extents = it.product(*extents) @@ -132,10 +132,10 @@ def apply(data: Union[torch.Tensor, MetaTensor, dict], pending: Optional[Union[d rd[k] = result return rd - if isinstance(data, MetaTensor) or pending is not None: - pending_ = [] if pending is None else pending - else: + if isinstance(data, MetaTensor) and pending is None: pending_ = data.pending_transforms + else: + pending_ = [] if pending is None else pending if len(pending_) == 0: return data @@ -154,7 +154,7 @@ def apply(data: Union[torch.Tensor, MetaTensor, dict], pending: Optional[Union[d cur_shape = data.shape for meta_matrix in pending_: - next_matrix = meta_matrix.data + next_matrix = meta_matrix.matrix # print("intermediate matrix\n", matrix_from_matrix_container(cumulative_matrix)) cumulative_matrix = matmul(cumulative_matrix, next_matrix) cumulative_extents = [matmul(e, cumulative_matrix) for e in cumulative_extents] @@ -193,6 +193,10 @@ def apply(data: Union[torch.Tensor, MetaTensor, dict], pending: Optional[Union[d # print(f"applying with cumulative matrix\n {cumulative_matrix_}") a = Affine(norm_coords=False, affine=cumulative_matrix_, spatial_size=cur_shape[1:], normalized=False, **kwargs) data, tx = a(img=data) - data.clear_pending_transforms() + if isinstance(data, MetaTensor): + data.clear_pending_transforms() + for p in pending_: + data.affine = p.matrix.data + data.push_applied_operation(p) - return data + return data, pending_ diff --git a/monai/transforms/meta_matrix.py b/monai/transforms/meta_matrix.py index aff2d23c61..845fe3f2c5 100644 --- a/monai/transforms/meta_matrix.py +++ b/monai/transforms/meta_matrix.py @@ -9,13 +9,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional, Union +from typing import Optional, Sequence, Union import numpy as np import torch +from monai.transforms.utils import _create_rotate, _create_rotate_90, _create_flip, _create_shear, _create_scale, \ + _create_translate + +from monai.transforms.utils import get_backend_from_tensor_like, get_device_from_tensor_like +from monai.utils import TransformBackends from monai.config import NdarrayOrTensor +__all__ = ["Grid", "matmul", "Matrix", "MatrixFactory", "MetaMatrix"] def is_matrix_shaped(data): @@ -30,7 +36,73 @@ def is_grid_shaped(data): class MatrixFactory: - pass + + def __init__(self, + dims: int, + backend: TransformBackends, + device: Optional[torch.device] = None): + + if backend == TransformBackends.NUMPY: + if device is not None: + raise ValueError("'device' must be None with TransformBackends.NUMPY") + self._device = None + self._sin = lambda th: np.sin(th, dtype=np.float32) + self._cos = lambda th: np.cos(th, dtype=np.float32) + self._eye = lambda th: np.eye(th, dtype=np.float32) + self._diag = lambda th: np.diag(th, dtype=np.float32) + else: + if device is None: + raise ValueError("'device' must be set with TransformBackends.TORCH") + self._device = device + self._sin = lambda th: torch.sin(torch.as_tensor(th, + dtype=torch.float32, + device=self._device)) + self._cos = lambda th: torch.cos(torch.as_tensor(th, + dtype=torch.float32, + device=self._device)) + self._eye = lambda rank: torch.eye(rank, + device=self._device, + dtype=torch.float32); + self._diag = lambda size: torch.diag(torch.as_tensor(size, + device=self._device, + dtype=torch.float32)) + + self._backend = backend + self._dims = dims + + @staticmethod + def from_tensor(data): + return MatrixFactory(len(data.shape)-1, + get_backend_from_tensor_like(data), + get_device_from_tensor_like(data)) + + def identity(self): + matrix = self._eye(self._dims + 1) + return MetaMatrix(matrix, {}) + + def rotate_euler(self, radians: Union[Sequence[float], float], **extra_args): + matrix = _create_rotate(self._dims, radians, self._sin, self._cos, self._eye) + return MetaMatrix(matrix, extra_args) + + def rotate_90(self, rotations, axis, **extra_args): + matrix = _create_rotate_90(self._dims, rotations, axis) + return MetaMatrix(matrix, extra_args) + + def flip(self, axis, **extra_args): + matrix = _create_flip(self._dims, axis, self._eye) + return MetaMatrix(matrix, extra_args) + + def shear(self, coefs: Union[Sequence[float], float], **extra_args): + matrix = _create_shear(self._dims, coefs, self._eye) + return MetaMatrix(matrix, extra_args) + + def scale(self, factors: Union[Sequence[float], float], **extra_args): + matrix = _create_scale(self._dims, factors, self._diag) + return MetaMatrix(matrix, extra_args) + + def translate(self, offsets: Union[Sequence[float], float], **extra_args): + matrix = _create_translate(self._dims, offsets, self._eye) + return MetaMatrix(matrix, extra_args) def ensure_tensor(data: NdarrayOrTensor): diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index b505d27aa8..a982471cfe 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -891,6 +891,127 @@ def _create_translate( return array_func(affine) # type: ignore +def _create_rotate_90( + spatial_dims: int, + axis: Tuple[int, int], + steps: Optional[int] = 1, + eye_func: Callable = np.eye +) -> NdarrayOrTensor: + + values = [(1, 0, 0, 1), + (0, -1, 1, 0), + (-1, 0, 0, -1), + (0, 1, -1, 0)] + + if spatial_dims == 2: + if axis != (0, 1): + raise ValueError(f"if 'spatial_dims' is 2, 'axis' 
must be (0, 1) but is {axis}") + elif spatial_dims == 3: + if axis not in ((0, 1), (0, 2), (1, 2)): + raise ValueError("if 'spatial_dims' is 3, 'axis' must be (0,1), (0, 2), or (1, 2) " + f"but is {axis}") + else: + raise ValueError(f"'spatial_dims' must be 2 or 3 but is {spatial_dims}") + + steps_ = steps % 4 + + affine = eye_func(spatial_dims + 1) + + if spatial_dims == 2: + a, b = 0, 1 + else: + a, b = axis + + affine[a, a], affine[a, b], affine[b, a], affine[b, b] = values[steps] + return affine + + +def create_rotate_90( + spatial_dims: int, + axis: int, + steps: Optional[int] = 1, + device: Optional[torch.device] = None, + backend: str = TransformBackends.NUMPY, +) -> NdarrayOrTensor: + """ + create a 2D or 3D rotation matrix + + Args: + spatial_dims: {``2``, ``3``} spatial rank + radians: rotation radians + when spatial_dims == 3, the `radians` sequence corresponds to + rotation in the 1st, 2nd, and 3rd dim respectively. + device: device to compute and store the output (when the backend is "torch"). + backend: APIs to use, ``numpy`` or ``torch``. + + Raises: + ValueError: When ``radians`` is empty. + ValueError: When ``spatial_dims`` is not one of [2, 3]. + + """ + _backend = look_up_option(backend, TransformBackends) + if _backend == TransformBackends.NUMPY: + return _create_rotate_90( + spatial_dims=spatial_dims, + axis=axis, + steps=steps, + eye_func=np.eye) + if _backend == TransformBackends.TORCH: + return _create_rotate_90( + spatial_dims=spatial_dims, + axis=axis, + steps=steps, + eye_func=lambda rank: torch.eye(rank, device=device), + ) + raise ValueError(f"backend {backend} is not supported") + + +def _create_flip( + spatial_dims: int, + spatial_axis: Union[Sequence[int], int], + eye_func: Callable = np.eye +): + affine = eye_func(spatial_dims + 1) + if isinstance(spatial_axis, int): + if spatial_axis < -spatial_dims or spatial_axis >= spatial_dims: + raise ValueError("'spatial_axis' values must be between " + f"{-spatial_dims} and {spatial_dims-1} inclusive " + f"('spatial_axis' is {spatial_axis})") + affine[spatial_axis, spatial_axis] = -1 + else: + if any((s < -spatial_dims or s >= spatial_dims) for s in spatial_axis): + raise ValueError("'spatial_axis' values must be between " + f"{-spatial_dims} and {spatial_dims-1} inclusive " + f"('spatial_axis' is {spatial_axis})") + + for i in range(spatial_dims): + if i in spatial_axis: + affine[i, i] = -1 + + return affine + + +def create_flip( + spatial_dims: int, + spatial_axis: Union[Sequence[int], int], + device: Optional[torch.device] = None, + backend: str = TransformBackends.NUMPY, +) -> NdarrayOrTensor: + _backend = look_up_option(backend, TransformBackends) + if _backend == TransformBackends.NUMPY: + return _create_flip( + spatial_dims=spatial_dims, + spatial_axis=spatial_axis, + eye_func=np.eye) + if _backend == TransformBackends.TORCH: + return _create_flip( + spatial_dims=spatial_dims, + spatial_axis=spatial_axis, + eye_func=lambda rank: torch.eye(rank, device=device), + ) + raise ValueError(f"backend {backend} is not supported") + + def generate_spatial_bounding_box( img: NdarrayOrTensor, select_fn: Callable = is_positive, diff --git a/tests/test_apply.py b/tests/test_apply.py index 7d8af36e75..088d275a12 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -14,6 +14,7 @@ import numpy as np import torch +from monai.utils import convert_to_tensor from monai.transforms.lazy.functional import apply from monai.transforms.meta_matrix import MetaMatrix @@ -50,11 +51,28 @@ def _test_apply_impl(self, tensor, 
pending_transforms): result = apply(tensor, pending_transforms) print(result) + def _test_apply_metatensor_impl(self, tensor, pending_transforms, pending_as_parameter): + tensor_ = convert_to_tensor(tensor) + if pending_as_parameter: + result = apply(tensor_, pending_transforms) + else: + for p in pending_transforms: + # TODO: cannot do the next part until the MetaTensor PR #5107 is in + # tensor_.push_pending(p) + raise NotImplementedError() + SINGLE_TRANSFORM_CASES = [ - (torch.randn((1, 16, 16)), [MetaMatrix(rotate_45_2D(), {"id", "rotate"})]) + (torch.randn((1, 16, 16)), [MetaMatrix(rotate_45_2D(), {"id": "rotate"})]) ] def test_apply_single_transform(self): for case in self.SINGLE_TRANSFORM_CASES: self._test_apply_impl(*case) + def test_apply_single_transform_metatensor(self): + for case in self.SINGLE_TRANSFORM_CASES: + self._test_apply_metatensor_impl(*case, False) + + def test_apply_single_transform_metatensor_override(self): + for case in self.SINGLE_TRANSFORM_CASES: + self._test_apply_metatensor_impl(*case, True) From aca9a8b10a7d555262fad73bdd12c41708c96cfc Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 28 Oct 2022 15:53:06 +0100 Subject: [PATCH 18/52] Adding resample function for unified resampling and test placeholder Signed-off-by: Ben Murray --- monai/transforms/utility/functional.py | 32 ++++++++++++++++++ tests/test_apply.py | 23 +++---------- tests/test_resample.py | 46 ++++++++++++++++++++++++++ tests/utils.py | 18 ++++++++++ 4 files changed, 100 insertions(+), 19 deletions(-) create mode 100644 monai/transforms/utility/functional.py create mode 100644 tests/test_resample.py diff --git a/monai/transforms/utility/functional.py b/monai/transforms/utility/functional.py new file mode 100644 index 0000000000..7aec22c723 --- /dev/null +++ b/monai/transforms/utility/functional.py @@ -0,0 +1,32 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Union + +import torch +from monai.transforms import Affine + +from monai.config import NdarrayOrTensor +from monai.transforms.meta_matrix import Grid, Matrix + + +def resample( + data: torch.Tensor, + matrix: Union[NdarrayOrTensor, Matrix, Grid], + kwargs: Optional[dict] = None +): + """ + This is a minimal implementation of resample that always uses Affine. 
+ """ + if kwargs is not None: + a = Affine(affine=matrix, **kwargs) + else: + a = Affine(affine=matrix) + return a(img=data) diff --git a/tests/test_apply.py b/tests/test_apply.py index 088d275a12..19fca1dae2 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -19,22 +19,10 @@ from monai.transforms.lazy.functional import apply from monai.transforms.meta_matrix import MetaMatrix +from tests.utils import get_arange_img -def get_img(size, dtype=torch.float32, offset=0): - img = torch.zeros(size, dtype=dtype) - if len(size) == 2: - for j in range(size[0]): - for i in range(size[1]): - img[j, i] = i + j * size[0] + offset - else: - for k in range(size[0]): - for j in range(size[1]): - for i in range(size[2]): - img[k, j, i] = i + j * size[0] + k * size[0] * size[1] - return np.expand_dims(img, 0) - -def rotate_45_2D(): +def rotate_45_2d(): t = torch.eye(3) t[:, 0] = torch.FloatTensor([0, -1, 0]) t[:, 1] = torch.FloatTensor([1, 0, 0]) @@ -45,11 +33,8 @@ class TestApply(unittest.TestCase): def _test_apply_impl(self, tensor, pending_transforms): print(tensor.shape) - # for m in pending_transforms: - # print(m.matrix) - # print(m.metadata) result = apply(tensor, pending_transforms) - print(result) + self.assertListEqual(result[1], pending_transforms) def _test_apply_metatensor_impl(self, tensor, pending_transforms, pending_as_parameter): tensor_ = convert_to_tensor(tensor) @@ -62,7 +47,7 @@ def _test_apply_metatensor_impl(self, tensor, pending_transforms, pending_as_par raise NotImplementedError() SINGLE_TRANSFORM_CASES = [ - (torch.randn((1, 16, 16)), [MetaMatrix(rotate_45_2D(), {"id": "rotate"})]) + (torch.randn((1, 16, 16)), [MetaMatrix(rotate_45_2d(), {"id": "rotate"})]) ] def test_apply_single_transform(self): diff --git a/tests/test_resample.py b/tests/test_resample.py new file mode 100644 index 0000000000..ce9c8ede9e --- /dev/null +++ b/tests/test_resample.py @@ -0,0 +1,46 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np + +import torch + +from monai.transforms.utility.functional import resample +from monai.utils import convert_to_tensor + +from monai.transforms.lazy.functional import apply +from monai.transforms.meta_matrix import MetaMatrix + +from tests.utils import get_arange_img + + +def rotate_45_2d(): + t = torch.eye(3) + t[:, 0] = torch.FloatTensor([0, -1, 0]) + t[:, 1] = torch.FloatTensor([1, 0, 0]) + return t + + +class TestResampleFunction(unittest.TestCase): + + def _test_resample_function_impl(self, img, matrix): + result = resample(convert_to_tensor(img), matrix) + print(result) + + RESAMPLE_FUNCTION_CASES = [ + (get_arange_img((1, 16, 16)), rotate_45_2d()) + ] + + def test_resample_function(self): + for case in self.RESAMPLE_FUNCTION_CASES: + self._test_resample_function_impl(*case) diff --git a/tests/utils.py b/tests/utils.py index b16b4b13fb..bbceb9cfc4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -347,6 +347,24 @@ def make_rand_affine(ndim: int = 3, random_state: Optional[np.random.RandomState return af +def get_arange_img(size, dtype=torch.float32, offset=0): + """ + Returns an 2d or 3d image as a numpy tensor (complete with channel as dim 0) + with contents that iterate like an arange. + """ + img = torch.zeros(size, dtype=dtype) + if len(size) == 2: + for j in range(size[0]): + for i in range(size[1]): + img[j, i] = i + j * size[0] + offset + else: + for k in range(size[0]): + for j in range(size[1]): + for i in range(size[2]): + img[k, j, i] = i + j * size[0] + k * size[0] * size[1] + offset + return np.expand_dims(img, 0) + + class DistTestCase(unittest.TestCase): """ testcase without _outcome, so that it's picklable. From 5788a2345dbeb399061ccad15ade1708c568926a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 28 Oct 2022 14:56:28 +0000 Subject: [PATCH 19/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_apply.py | 2 -- tests/test_resample.py | 3 --- 2 files changed, 5 deletions(-) diff --git a/tests/test_apply.py b/tests/test_apply.py index 19fca1dae2..2447d8a3af 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -11,7 +11,6 @@ import unittest -import numpy as np import torch from monai.utils import convert_to_tensor @@ -19,7 +18,6 @@ from monai.transforms.lazy.functional import apply from monai.transforms.meta_matrix import MetaMatrix -from tests.utils import get_arange_img def rotate_45_2d(): diff --git a/tests/test_resample.py b/tests/test_resample.py index ce9c8ede9e..c032481389 100644 --- a/tests/test_resample.py +++ b/tests/test_resample.py @@ -11,15 +11,12 @@ import unittest -import numpy as np import torch from monai.transforms.utility.functional import resample from monai.utils import convert_to_tensor -from monai.transforms.lazy.functional import apply -from monai.transforms.meta_matrix import MetaMatrix from tests.utils import get_arange_img From 3f82016b5e85b5325546c7f4840b7baf82dddc93 Mon Sep 17 00:00:00 2001 From: Wenqi Li <831580+wyli@users.noreply.github.com> Date: Fri, 28 Oct 2022 16:03:50 +0100 Subject: [PATCH 20/52] 5425 conda tests (#5426) Fixes #5425 tested: https://github.com/Project-MONAI/MONAI/actions/runs/3344552788 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. 
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Wenqi Li --- tests/test_nifti_rw.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/test_nifti_rw.py b/tests/test_nifti_rw.py index 31e1de3fe9..6edf53d339 100644 --- a/tests/test_nifti_rw.py +++ b/tests/test_nifti_rw.py @@ -157,8 +157,8 @@ def test_write_2d(self): writer_obj.set_metadata({"affine": np.diag([1, 1, 1]), "original_affine": np.diag([1.4, 1, 1])}) writer_obj.write(image_name, verbose=True) out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), [[0, 1, 2], [3.0, 4, 5]]) - np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1])) + np.testing.assert_allclose(out.get_fdata(), [[0, 1, 2], [3.0, 4, 5]], atol=1e-4, rtol=1e-4) + np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1]), atol=1e-4, rtol=1e-4) image_name = os.path.join(out_dir, "test1.nii.gz") img = np.arange(5).reshape((1, 5)) @@ -168,8 +168,8 @@ def test_write_2d(self): ) writer_obj.write(image_name, verbose=True) out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), [[0, 2, 4]]) - np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 1, 1])) + np.testing.assert_allclose(out.get_fdata(), [[0, 2, 4]], atol=1e-4, rtol=1e-4) + np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 1, 1]), atol=1e-4, rtol=1e-4) def test_write_3d(self): with tempfile.TemporaryDirectory() as out_dir: @@ -192,8 +192,8 @@ def test_write_3d(self): ) writer_obj.write(image_name, verbose=True) out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), [[[0, 2, 4]]]) - np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) + np.testing.assert_allclose(out.get_fdata(), [[[0, 2, 4]]], atol=1e-4, rtol=1e-4) + np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1]), atol=1e-4, rtol=1e-4) def test_write_4d(self): with tempfile.TemporaryDirectory() as out_dir: @@ -216,8 +216,8 @@ def test_write_4d(self): ) writer_obj.write(image_name, verbose=True) out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), [[[[0], [2], [4]]]]) - np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) + np.testing.assert_allclose(out.get_fdata(), [[[[0], [2], [4]]]], atol=1e-4, rtol=1e-4) + np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1]), atol=1e-4, rtol=1e-4) def test_write_5d(self): with tempfile.TemporaryDirectory() as out_dir: @@ -241,8 +241,10 @@ def test_write_5d(self): writer_obj.set_metadata({"affine": np.diag([1, 1, 1, 3]), "original_affine": np.diag([1.4, 2.0, 2, 3])}) writer_obj.write(image_name, verbose=True) out = nib.load(image_name) - np.testing.assert_allclose(out.get_fdata(), np.array([[[[[0.0, 2.0]], [[4.0, 5.0]], [[7.0, 9.0]]]]])) - np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1])) + np.testing.assert_allclose( + out.get_fdata(), np.array([[[[[0.0, 2.0]], [[4.0, 5.0]], [[7.0, 9.0]]]]]), atol=1e-4, rtol=1e-4 + ) + np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1]), atol=1e-4, rtol=1e-4) if __name__ == "__main__": From 8b1f0c35326d162d30ebb72827ba06b65065bf3e Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 28 Oct 2022 21:28:14 +0100 Subject: [PATCH 21/52] Type classes for lazy resampling (#5418) Signed-off-by: Ben Murray Fixes # . 
### Description A few sentences describing the changes proposed in this pull request. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Ben Murray --- docs/source/transforms.rst | 20 ++++++ monai/transforms/__init__.py | 13 +++- monai/transforms/transform.py | 85 ++++++++++++++++++++++- monai/visualize/gradient_based.py | 2 +- tests/test_randomizable_transform_type.py | 33 +++++++++ 5 files changed, 149 insertions(+), 4 deletions(-) create mode 100644 tests/test_randomizable_transform_type.py diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 874f01a945..7b728fde48 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -22,11 +22,31 @@ Generic Interfaces :members: :special-members: __call__ +`RandomizableTrait` +^^^^^^^^^^^^^^^^^^^ +.. autoclass:: RandomizableTrait + :members: + +`LazyTrait` +^^^^^^^^^^^ +.. autoclass:: LazyTrait + :members: + +`MultiSampleTrait` +^^^^^^^^^^^^^^^^^^ +.. autoclass:: MultiSampleTrait + :members: + `Randomizable` ^^^^^^^^^^^^^^ .. autoclass:: Randomizable :members: +`LazyTransform` +^^^^^^^^^^^^^^^ +.. autoclass:: LazyTransform + :members: + `RandomizableTransform` ^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: RandomizableTransform diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 389571d16f..9cabc167a7 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -449,7 +449,18 @@ ZoomD, ZoomDict, ) -from .transform import MapTransform, Randomizable, RandomizableTransform, ThreadUnsafe, Transform, apply_transform +from .transform import ( + LazyTrait, + LazyTransform, + MapTransform, + MultiSampleTrait, + Randomizable, + RandomizableTrait, + RandomizableTransform, + ThreadUnsafe, + Transform, + apply_transform, +) from .utility.array import ( AddChannel, AddCoordinateChannels, diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 21d057f5d3..b1a7d9b4db 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -26,7 +26,18 @@ from monai.utils.enums import TransformBackends from monai.utils.misc import MONAIEnvVars -__all__ = ["ThreadUnsafe", "apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"] +__all__ = [ + "ThreadUnsafe", + "apply_transform", + "LazyTrait", + "RandomizableTrait", + "MultiSampleTrait", + "Randomizable", + "LazyTransform", + "RandomizableTransform", + "Transform", + "MapTransform", +] ReturnType = TypeVar("ReturnType") @@ -118,6 +129,56 @@ def _log_stats(data, prefix: Optional[str] = "Data"): raise RuntimeError(f"applying transform {transform}") from e +class LazyTrait: + """ + An interface to indicate that the transform has the capability to execute using + MONAI's lazy resampling feature. In order to do this, the implementing class needs + to be able to describe its operation as an affine matrix or grid with accompanying metadata. + This interface can be extended from by people adapting transforms to the MONAI framework as + well as by implementors of MONAI transforms. 
+ """ + + @property + def lazy_evaluation(self): + """ + Get whether lazy_evaluation is enabled for this transform instance. + Returns: + True if the transform is operating in a lazy fashion, False if not. + """ + raise NotImplementedError() + + @lazy_evaluation.setter + def lazy_evaluation(self, enabled: bool): + """ + Set whether lazy_evaluation is enabled for this transform instance. + Args: + enabled: True if the transform should operate in a lazy fashion, False if not. + """ + raise NotImplementedError() + + +class RandomizableTrait: + """ + An interface to indicate that the transform has the capability to perform + randomized transforms to the data that it is called upon. This interface + can be extended from by people adapting transforms to the MONAI framework as well as by + implementors of MONAI transforms. + """ + + pass + + +class MultiSampleTrait: + """ + An interface to indicate that the transform has the capability to return multiple samples + given an input, such as when performing random crops of a sample. This interface can be + extended from by people adapting transforms to the MONAI framework as well as by implementors + of MONAI transforms. + """ + + pass + + class ThreadUnsafe: """ A class to denote that the transform will mutate its member variables, @@ -251,7 +312,27 @@ def __call__(self, data: Any): raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") -class RandomizableTransform(Randomizable, Transform): +class LazyTransform(Transform, LazyTrait): + """ + An implementation of functionality for lazy transforms that can be subclassed by array and + dictionary transforms to simplify implementation of new lazy transforms. + """ + + def __init__(self, lazy_evaluation: Optional[bool] = True): + self.lazy_evaluation = lazy_evaluation + + @property + def lazy_evaluation(self): + return self.lazy_evaluation + + @lazy_evaluation.setter + def lazy_evaluation(self, lazy_evaluation: bool): + if not isinstance(lazy_evaluation, bool): + raise TypeError("'lazy_evaluation must be a bool but is of " f"type {type(lazy_evaluation)}'") + self.lazy_evaluation = lazy_evaluation + + +class RandomizableTransform(Randomizable, Transform, RandomizableTrait): """ An interface for handling random state locally, currently based on a class variable `R`, which is an instance of `np.random.RandomState`. diff --git a/monai/visualize/gradient_based.py b/monai/visualize/gradient_based.py index 7f4ddce1d0..7ab6ef260d 100644 --- a/monai/visualize/gradient_based.py +++ b/monai/visualize/gradient_based.py @@ -90,7 +90,7 @@ def get_grad(self, x: torch.Tensor, index: torch.Tensor | int | None, retain_gra x.requires_grad = True self._model(x, class_idx=index, retain_graph=retain_graph, **kwargs) - grad: torch.Tensor = x.grad.detach() + grad: torch.Tensor = x.grad.detach() # type: ignore return grad def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor: diff --git a/tests/test_randomizable_transform_type.py b/tests/test_randomizable_transform_type.py new file mode 100644 index 0000000000..9f77d2cd5a --- /dev/null +++ b/tests/test_randomizable_transform_type.py @@ -0,0 +1,33 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from monai.transforms.transform import RandomizableTrait, RandomizableTransform + + +class InheritsInterface(RandomizableTrait): + pass + + +class InheritsImplementation(RandomizableTransform): + def __call__(self, data): + return data + + +class TestRandomizableTransformType(unittest.TestCase): + def test_is_randomizable_transform_type(self): + inst = InheritsInterface() + self.assertIsInstance(inst, RandomizableTrait) + + def test_set_random_state_randomizable_transform(self): + inst = InheritsImplementation() + inst.set_random_state(0) From 350fe6e5c742826e27da2cb0a6c6a6226c0b8e86 Mon Sep 17 00:00:00 2001 From: Wenqi Li <831580+wyli@users.noreply.github.com> Date: Sat, 29 Oct 2022 12:01:36 +0100 Subject: [PATCH 22/52] 5432 convert metadict types (#5433) Signed-off-by: Wenqi Li Fixes #5432 Fixes https://github.com/Project-MONAI/MONAILabel/issues/1064 ### Description convert potential itk types in `img.GetMetaDataDictionary()` to numpy arrays ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. 
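As a rough, self-contained illustration of the conversion described above: `np.asarray` is applied to any metadata value whose type name starts with `itk` (e.g. an `itkMatrixF44`), while plain values are passed through unchanged. The `itkFakeMatrixF44` class below is purely a stand-in assumption for demonstration; it is not part of ITK or of this patch.

```python
import numpy as np


class itkFakeMatrixF44:
    """Stand-in for an itk metadata value; only the 'itk' type-name prefix matters here."""

    def __init__(self, rows):
        self.rows = rows

    def __array__(self):
        # allows np.asarray() to extract a plain ndarray from this object
        return np.asarray(self.rows, dtype=np.float32)


val = itkFakeMatrixF44([[1.4, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 1]])
# the same predicate used by the image reader change in the diff below
converted = np.asarray(val) if type(val).__name__.startswith("itk") else val
print(type(converted).__name__, converted.shape)  # ndarray (4, 4)
```
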
Signed-off-by: Wenqi Li --- .github/workflows/pythonapp.yml | 4 ++++ monai/data/image_reader.py | 7 ++++++- tests/test_load_image.py | 2 +- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 484026626f..a710e8e46e 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -88,6 +88,10 @@ jobs: name: Install torch cpu from pytorch.org (Windows only) run: | python -m pip install torch==1.12.1+cpu torchvision==0.13.1+cpu -f https://download.pytorch.org/whl/torch_stable.html + - if: runner.os == 'Linux' + name: Install itk pre-release (Linux only) + run: | + python -m pip install --pre -U itk - name: Install the dependencies run: | python -m pip install torch==1.12.1 torchvision==0.13.1 diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 087d4d1950..34e1368fe2 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -318,7 +318,12 @@ def _get_meta_dict(self, img) -> Dict: """ img_meta_dict = img.GetMetaDataDictionary() - meta_dict = {key: img_meta_dict[key] for key in img_meta_dict.GetKeys() if not key.startswith("ITK_")} + meta_dict = {} + for key in img_meta_dict.GetKeys(): + if key.startswith("ITK_"): + continue + val = img_meta_dict[key] + meta_dict[key] = np.asarray(val) if type(val).__name__.startswith("itk") else val meta_dict["spacing"] = np.asarray(img.GetSpacing()) return meta_dict diff --git a/tests/test_load_image.py b/tests/test_load_image.py index 834c9f9d59..1db39a310b 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -315,7 +315,7 @@ def test_channel_dim(self, input_param, filename, expected_shape): with tempfile.TemporaryDirectory() as tempdir: filename = os.path.join(tempdir, filename) nib.save(nib.Nifti1Image(test_image, np.eye(4)), filename) - result = LoadImage(image_only=True, **input_param)(filename) + result = LoadImage(image_only=True, **input_param)(filename) # with itk, meta has 'qto_xyz': itkMatrixF44 self.assertTupleEqual( result.shape, (3, 128, 128, 128) if input_param.get("ensure_channel_first", False) else expected_shape From f2018835d4de600c89b895ecd3e16527c24f5e37 Mon Sep 17 00:00:00 2001 From: Wenqi Li <831580+wyli@users.noreply.github.com> Date: Sat, 29 Oct 2022 15:04:40 +0100 Subject: [PATCH 23/52] 4922 adding a minimal lazy transform interface (#5407) follow-up of #4922 ### Description - minimal interface to track the pending transforms via metatensor - transforms.Flip is modified as an example for discussion - discussion points: - maintaining `pending_operations` and `applied_operations` independently? - the data structure for `pending_operations` element is a python dictionary - transform "functional" refactoring ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. 
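Before the diff, a minimal usage sketch of the pending-operation queue this patch adds to `MetaObj`/`MetaTensor`; the shape, affine, and values used here are illustrative assumptions only, not taken from the patch:

```python
import torch
from monai.data import MetaTensor
from monai.utils.enums import LazyAttr

img = MetaTensor(torch.zeros(1, 10, 8))  # channel-first 2D image

# a pending operation is a plain dict; a lazy transform records the spatial
# shape/affine it would produce, without resampling the array yet
img.push_pending_operation({LazyAttr.SHAPE: (8, 10), LazyAttr.AFFINE: torch.eye(3)})

print(img.peek_pending_shape())    # (8, 10), taken from the last pending op
print(img.peek_pending_affine())   # the recorded 3x3 affine
print(img.pop_pending_operation())  # removes and returns the dict
```
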
Signed-off-by: Wenqi Li --- monai/data/meta_obj.py | 14 ++++++++++++++ monai/data/meta_tensor.py | 18 ++++++++++++++++-- monai/utils/__init__.py | 1 + monai/utils/enums.py | 14 ++++++++++++++ tests/test_meta_tensor.py | 9 +++++++++ 5 files changed, 54 insertions(+), 2 deletions(-) diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index 5061efc1ce..6aab05dc94 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -82,6 +82,7 @@ class MetaObj: def __init__(self): self._meta: dict = MetaObj.get_default_meta() self._applied_operations: list = MetaObj.get_default_applied_operations() + self._pending_operations: list = MetaObj.get_default_applied_operations() # the same default as applied_ops self._is_batch: bool = False @staticmethod @@ -199,6 +200,19 @@ def push_applied_operation(self, t: Any) -> None: def pop_applied_operation(self) -> Any: return self._applied_operations.pop() + @property + def pending_operations(self) -> list[dict]: + """Get the pending operations. Defaults to ``[]``.""" + if hasattr(self, "_pending_operations"): + return self._pending_operations + return MetaObj.get_default_applied_operations() # the same default as applied_ops + + def push_pending_operation(self, t: Any) -> None: + self._pending_operations.append(t) + + def pop_pending_operation(self) -> Any: + return self._pending_operations.pop() + @property def is_batch(self) -> bool: """Return whether object is part of batch or not.""" diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 5a7d81ad8e..493aef848b 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -23,8 +23,8 @@ from monai.data.meta_obj import MetaObj, get_track_meta from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata from monai.utils import look_up_option -from monai.utils.enums import MetaKeys, PostFix, SpaceKeys -from monai.utils.type_conversion import convert_data_type, convert_to_tensor +from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys +from monai.utils.type_conversion import convert_data_type, convert_to_numpy, convert_to_tensor __all__ = ["MetaTensor"] @@ -445,6 +445,20 @@ def pixdim(self): return [affine_to_spacing(a) for a in self.affine] return affine_to_spacing(self.affine) + def peek_pending_shape(self): + """Get the currently expected spatial shape as if all the pending operations are executed.""" + res = None + if self.pending_operations: + res = self.pending_operations[-1].get(LazyAttr.SHAPE, None) + # default to spatial shape (assuming channel-first input) + return tuple(convert_to_numpy(self.shape, wrap_sequence=True).tolist()[1:]) if res is None else res + + def peek_pending_affine(self): + res = None + if self.pending_operations: + res = self.pending_operations[-1].get(LazyAttr.AFFINE, None) + return self.affine if res is None else res + def new_empty(self, size, dtype=None, device=None, requires_grad=False): """ must be defined for deepcopy to work diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index c5419cb9af..21d3621090 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -34,6 +34,7 @@ InterpolateMode, InverseKeys, JITMetadataKeys, + LazyAttr, LossReduction, MetaKeys, Method, diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 79edbd7451..4fd9bea557 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -54,6 +54,7 @@ "AlgoEnsembleKeys", "HoVerNetMode", "HoVerNetBranch", + "LazyAttr", ] @@ -616,3 +617,16 @@ class HoVerNetBranch(StrEnum): HV 
= "horizontal_vertical" NP = "nucleus_prediction" NC = "type_prediction" + + +class LazyAttr(StrEnum): + """ + MetaTensor with pending operations requires some key attributes tracked especially when the primary array + is not up-to-date due to lazy evaluation. + This class specifies the set of key attributes to be tracked for each MetaTensor. + """ + + SHAPE = "lazy_shape" # spatial shape + AFFINE = "lazy_affine" + PADDING_MODE = "lazy_padding_mode" + INTERP_MODE = "lazy_interpolation_mode" diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index b46905f3c1..20d25ef61c 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -495,6 +495,15 @@ def test_construct_with_pre_applied_transforms(self): m = MetaTensor(im, applied_operations=data["im"].applied_operations) self.assertEqual(len(m.applied_operations), len(tr.transforms)) + def test_pending_ops(self): + m, _ = self.get_im() + self.assertEqual(m.pending_operations, []) + self.assertEqual(m.peek_pending_shape(), (10, 8)) + self.assertIsInstance(m.peek_pending_affine(), torch.Tensor) + m.push_pending_operation({}) + self.assertEqual(m.peek_pending_shape(), (10, 8)) + self.assertIsInstance(m.peek_pending_affine(), torch.Tensor) + @parameterized.expand(TESTS) def test_multiprocessing(self, device=None, dtype=None): """multiprocessing sharing with 'device' and 'dtype'""" From 5b9345b4dd9c3e1b6acef1ab88e8b16d5c403911 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 18:14:18 +0100 Subject: [PATCH 24/52] Apply and MetaMatrix; partial functionality Signed-off-by: Ben Murray --- monai/transforms/__init__.py | 3 + monai/transforms/apply.py | 229 ++++++++++++++++++++++++++++++++ monai/transforms/meta_matrix.py | 213 +++++++++++++++++++++++++++++ monai/transforms/utils.py | 97 ++++++++++++++ tests/test_apply.py | 9 ++ tests/test_matmul.py | 129 ++++++++++++++++++ 6 files changed, 680 insertions(+) create mode 100644 monai/transforms/apply.py create mode 100644 monai/transforms/meta_matrix.py create mode 100644 tests/test_apply.py create mode 100644 tests/test_matmul.py diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 9cabc167a7..51e83a94e8 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -10,6 +10,7 @@ # limitations under the License. 
from .adaptors import FunctionSignature, adaptor, apply_alias, to_kwargs +from .apply import Apply, apply from .compose import Compose, OneOf from .croppad.array import ( BorderPad, @@ -632,6 +633,8 @@ generate_label_classes_crop_centers, generate_pos_neg_label_crop_centers, generate_spatial_bounding_box, + get_backend_from_tensor_like, + get_device_from_tensor_like, get_extreme_points, get_largest_connected_component_mask, get_number_image_type_conversions, diff --git a/monai/transforms/apply.py b/monai/transforms/apply.py new file mode 100644 index 0000000000..c58505e929 --- /dev/null +++ b/monai/transforms/apply.py @@ -0,0 +1,229 @@ +from typing import Optional, Sequence, Union + +import itertools as it + +import numpy as np + +import torch + +from monai.config import DtypeLike +from monai.data import MetaTensor +from monai.transforms import Affine +from monai.transforms.inverse import InvertibleTransform +from monai.transforms.transform import MapTransform +from monai.utils import GridSampleMode, GridSamplePadMode +from monai.transforms.utils import get_backend_from_tensor_like, get_device_from_tensor_like, dtypes_to_str_or_identity +from monai.transforms.meta_matrix import MatrixFactory, MetaMatrix, Matrix, matmul + +__all__ = [ + "apply", + "Apply" +] + +# TODO: This should move to a common place to be shared with dictionary +#from monai.utils.type_conversion import dtypes_to_str_or_identity + +GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str] +GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str] +DtypeSequence = Union[Sequence[DtypeLike], DtypeLike] + + +# TODO: move to mapping_stack.py +def extents_from_shape(shape, dtype=np.float64): + extents = [[0, shape[i]] for i in range(1, len(shape))] + + extents = it.product(*extents) + return list(np.asarray(e + (1,), dtype=dtype) for e in extents) + + +# TODO: move to mapping_stack.py +def shape_from_extents( + src_shape: Sequence, + extents: Union[Sequence[np.ndarray], Sequence[torch.Tensor], np.ndarray, torch.Tensor] +): + if isinstance(extents, (list, tuple)): + if isinstance(extents[0], np.ndarray): + aextents = np.asarray(extents) + aextents = torch.from_numpy(aextents) + else: + aextents = torch.stack(extents) + else: + if isinstance(extents, np.ndarray): + aextents = torch.from_numpy(extents) + else: + aextents = extents + + mins = aextents.min(axis=0)[0] + maxes = aextents.max(axis=0)[0] + values = torch.round(maxes - mins).type(torch.IntTensor)[:-1] + return torch.cat((torch.as_tensor([src_shape[0]]), values)) + + +def metadata_is_compatible(value_1, value_2): + if value_1 is None: + return True + else: + if value_2 is None: + return True + return value_1 == value_2 + + +def metadata_dtype_is_compatible(value_1, value_2): + if value_1 is None: + return True + else: + if value_2 is None: + return True + + # if we are here, value_1 and value_2 are both set + # TODO: this is not a good enough solution + value_1_ = dtypes_to_str_or_identity(value_1) + value_2_ = dtypes_to_str_or_identity(value_2) + return value_1_ == value_2_ + + +def starting_matrix_and_extents(matrix_factory, data): + # set up the identity matrix and metadata + cumulative_matrix = matrix_factory.identity() + cumulative_extents = extents_from_shape(data.shape) + return cumulative_matrix, cumulative_extents + + +def prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype): + kwargs = {} + if cur_mode is not None: + kwargs['mode'] = cur_mode + if 
cur_padding_mode is not None: + kwargs['padding_mode'] = cur_padding_mode + if cur_device is not None: + kwargs['device'] = cur_device + if cur_dtype is not None: + kwargs['dtype'] = cur_dtype + + return kwargs + + +def matrix_from_matrix_container(matrix): + if isinstance(matrix, MetaMatrix): + return matrix.matrix.data + elif isinstance(matrix, Matrix): + return matrix.data + else: + return matrix + + +def apply(data: Union[torch.Tensor, MetaTensor], + pending: Optional[dict, list] = None): + + # TODO: if data is a dict, then pending must also be a dict + if isinstance(data, dict): + rd = dict() + for k, v in data.items(): + result = apply(v, pending) + rd[k] = result + return rd + + if isinstance(data, MetaTensor) or pending is not None: + pending_ = [] if pending is None else pending + else: + pending_ = data.pending_transforms + + if len(pending_) == 0: + return data + + dim_count = len(data.shape) - 1 + matrix_factory = MatrixFactory(dim_count, + get_backend_from_tensor_like(data), + get_device_from_tensor_like(data)) + + # set up the identity matrix and metadata + cumulative_matrix, cumulative_extents = starting_matrix_and_extents(matrix_factory, data) + + # set the various resampling parameters to an initial state + cur_mode = None + cur_padding_mode = None + cur_device = None + cur_dtype = None + cur_shape = data.shape + + for meta_matrix in pending_: + next_matrix = meta_matrix.data + # print("intermediate matrix\n", matrix_from_matrix_container(cumulative_matrix)) + cumulative_matrix = matmul(cumulative_matrix, next_matrix) + cumulative_extents = [matmul(e, cumulative_matrix) for e in cumulative_extents] + + new_mode = meta_matrix.metadata.get('mode', None) + new_padding_mode = meta_matrix.metadata.get('padding_mode', None) + new_device = meta_matrix.metadata.get('device', None) + new_dtype = meta_matrix.metadata.get('dtype', None) + new_shape = meta_matrix.metadata.get('shape_override', None) + + mode_compat = metadata_is_compatible(cur_mode, new_mode) + padding_mode_compat = metadata_is_compatible(cur_padding_mode, new_padding_mode) + device_compat = metadata_is_compatible(cur_device, new_device) + dtype_compat = metadata_dtype_is_compatible(cur_dtype, new_dtype) + + if (mode_compat is False or padding_mode_compat is False or + device_compat is False or dtype_compat is False): + # carry out an intermediate resample here due to incompatibility between arguments + kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype) + + cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix) + a = Affine(norm_coords=False, + affine=cumulative_matrix_, + **kwargs) + data, _ = a(img=data) + + cumulative_matrix, cumulative_extents =\ + starting_matrix_and_extents(matrix_factory, data) + + cur_mode = cur_mode if new_mode is None else new_mode + cur_padding_mode = cur_padding_mode if new_padding_mode is None else new_padding_mode + cur_device = cur_device if new_device is None else new_device + cur_dtype = cur_dtype if new_dtype is None else new_dtype + cur_shape = cur_shape if new_shape is None else new_shape + + kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype) + + cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix) + + # print(f"applying with cumulative matrix\n {cumulative_matrix_}") + a = Affine(norm_coords=False, + affine=cumulative_matrix_, + spatial_size=cur_shape[1:], + normalized=False, + **kwargs) + data, tx = a(img=data) + data.clear_pending_transforms() + + return data + + +# make Apply 
universal for arrays and dictionaries; it just calls through to functional apply +class Apply(InvertibleTransform): + + def __init__(self): + super().__init__() + + def __call__(self, *args, **kwargs): + return apply(*args, **kwargs) + + def inverse(self, data): + return NotImplementedError() + + +# class Applyd(MapTransform, InvertibleTransform): +# +# def __init__(self): +# super().__init__() +# +# def __call__( +# self, +# d: dict +# ): +# rd = dict() +# for k, v in d.items(): +# rd[k] = apply(v) +# +# def inverse(self, data): +# return NotImplementedError() diff --git a/monai/transforms/meta_matrix.py b/monai/transforms/meta_matrix.py new file mode 100644 index 0000000000..3d0da051de --- /dev/null +++ b/monai/transforms/meta_matrix.py @@ -0,0 +1,213 @@ +from typing import Optional, Union + +import numpy as np + +import torch + +from monai.config import NdarrayOrTensor + + +def is_matrix_shaped(data): + + return (len(data.shape) == 2 and data.shape[0] in (3, 4) and + data.shape[1] in (3, 4) and data.shape[0] == data.shape[1]) + + +def is_grid_shaped(data): + + return (len(data.shape) == 3 and data.shape[0] == 3 or + len(data.shape) == 4 and data.shape[0] == 4) + + +class MatrixFactory: + pass + + +def ensure_tensor(data: NdarrayOrTensor): + if isinstance(data, torch.Tensor): + return data + + return torch.as_tensor(data) + + +class Matrix: + + def __init__(self, matrix: NdarrayOrTensor): + self.data = ensure_tensor(matrix) + + # def __matmul__(self, other): + # if isinstance(other, Matrix): + # other_matrix = other.data + # else: + # other_matrix = other + # return self.data @ other_matrix + # + # def __rmatmul__(self, other): + # return other.__matmul__(self.data) + + +class Grid: + def __init__(self, grid): + self.data = ensure_tensor(grid) + + # def __matmul__(self, other): + # raise NotImplementedError() + + +class MetaMatrix: + + def __init__( + self, + matrix: Union[NdarrayOrTensor, Matrix, Grid], + metadata: Optional[dict] = None): + + if not isinstance(matrix, (Matrix, Grid)): + if matrix.shape == 2: + if matrix.shape[0] != matrix.shape[1] or matrix.shape[0] not in (3, 4): + raise ValueError("If 'matrix' is passed a numpy ndarray/torch Tensor, it must" + f" be 3x3 or 4x4 ('matrix' has has shape {matrix.shape})") + matrix_ = Matrix(matrix) + else: + matrix_ = matrix + self.matrix = matrix_ + + self.metadata = metadata or {} + + def __matmul__(self, other): + if isinstance(other, MetaMatrix): + other_ = other.matrix + else: + other_ = other + return MetaMatrix(self.matrix @ other_) + + def __rmatmul__(self, other): + if isinstance(other, MetaMatrix): + other_ = other.matrix + else: + other_ = other + return MetaMatrix(other_ @ self.matrix) + + +def matmul( + left: Union[MetaMatrix, Grid, Matrix, NdarrayOrTensor], + right: Union[MetaMatrix, Grid, Matrix, NdarrayOrTensor] +): + matrix_types = (MetaMatrix, Grid, Matrix, torch.Tensor, np.ndarray) + + if not isinstance(left, matrix_types): + raise TypeError(f"'left' must be one of {matrix_types} but is {type(left)}") + if not isinstance(right, matrix_types): + raise TypeError(f"'second' must be one of {matrix_types} but is {type(right)}") + + left_ = left.matrix if isinstance(left, MetaMatrix) else left + right_ = right.matrix if isinstance(right, MetaMatrix) else right + + # TODO: it might be better to not return a metamatrix, unless we pass in the resulting + # metadata also + put_in_metamatrix = isinstance(left, MetaMatrix) or isinstance(right, MetaMatrix) + + put_in_grid = isinstance(left, Grid) or isinstance(right, Grid) + + 
put_in_matrix = isinstance(left, Matrix) or isinstance(right, Matrix) + put_in_matrix = False if put_in_grid is True else put_in_matrix + + promote_to_tensor = not (isinstance(left_, np.ndarray) and isinstance(right_, np.ndarray)) + + left_raw = left_.data if isinstance(left_, (Matrix, Grid)) else left_ + right_raw = right_.data if isinstance(right_, (Matrix, Grid)) else right_ + + if promote_to_tensor: + left_raw = torch.as_tensor(left_raw) + right_raw = torch.as_tensor(right_raw) + + if isinstance(left_, Grid): + if isinstance(right_, Grid): + raise RuntimeError("Unable to matrix multiply two Grids") + else: + result = matmul_grid_matrix(left_raw, right_raw) + else: + if isinstance(right_, Grid): + result = matmul_matrix_grid(left_raw, right_raw) + else: + result = matmul_matrix_matrix(left_raw, right_raw) + + if put_in_grid: + result = Grid(result) + elif put_in_matrix: + result = Matrix(result) + + if put_in_metamatrix: + result = MetaMatrix(result) + + return result + + +def matmul_matrix_grid( + left: NdarrayOrTensor, + right: NdarrayOrTensor +): + if not is_matrix_shaped(left): + raise ValueError(f"'left' should be a 2D or 3D homogenous matrix but has shape {left.shape}") + + if not is_grid_shaped(right): + raise ValueError("'right' should be a 3D array with shape[0] == 2 or a " + f"4D array with shape[0] == 3 but has shape {right.shape}") + + # flatten the grid to take advantage of torch batch matrix multiply + right_flat = right.reshape(right.shape[0], -1) + result_flat = left @ right_flat + # restore the grid shape + result = result_flat.reshape((-1,) + result_flat.shape[1:]) + return result + + +def matmul_grid_matrix( + left: NdarrayOrTensor, + right: NdarrayOrTensor +): + if not is_grid_shaped(left): + raise ValueError("'left' should be a 3D array with shape[0] == 2 or a " + f"4D array with shape[0] == 3 but has shape {left.shape}") + + if not is_matrix_shaped(right): + raise ValueError(f"'right' should be a 2D or 3D homogenous matrix but has shape {right.shape}") + + try: + inv_matrix = torch.inverse(right) + except RuntimeError: + # the matrix is not invertible, so we will have to perform a slow grid to matrix operation + return matmul_grid_matrix_slow(left, right) + + # invert the matrix and swap the arguments, taking advantage of + # matrix @ vector == vector_transposed @ matrix_inverse + return matmul_matrix_grid(torch.inverse(right), left) + + +def matmul_grid_matrix_slow( + left: NdarrayOrTensor, + right: NdarrayOrTensor +): + if not is_grid_shaped(left): + raise ValueError("'left' should be a 3D array with shape[0] == 2 or a " + f"4D array with shape[0] == 3 but has shape {left.shape}") + + if not is_matrix_shaped(right): + raise ValueError(f"'right' should be a 2D or 3D homogenous matrix but has shape {right.shape}") + + flat_left = left.reshape(left.shape[0], -1) + result_flat = torch.zeros_like(flat_left) + for i in range(flat_left.shape[1]): + vector = flat_left[:, i][None, :] + result_vector = vector @ right + result_flat[:, i] = result_vector[0, :] + + # restore the grid shape + result = result_flat.reshape((-1,) + result_flat.shape[1:]) + return result + + +def matmul_matrix_matrix( + left: NdarrayOrTensor, + right: NdarrayOrTensor, +): + return left @ right diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index e96d906f20..091d3e5741 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -85,6 +85,8 @@ "generate_label_classes_crop_centers", "generate_pos_neg_label_crop_centers", "generate_spatial_bounding_box", + 
"get_backend_from_tensor_like", + "get_device_from_tensor_like", "get_extreme_points", "get_largest_connected_component_mask", "remove_small_objects", @@ -1790,5 +1792,100 @@ def squarepulse(sig, duty: float = 0.5): return y +def get_device_from_tensor_like( + data: NdarrayOrTensor +): + """ + This function returns the device of `data`, which must either be a numpy ndarray or a + pytorch Tensor. + + Args: + data: the ndarray/tensor to return the device of + + Returns: + None if `data` is a numpy array, or the device of the pytorch Tensor + """ + if isinstance(data, np.ndarray): + return None + elif isinstance(data, torch.Tensor): + return data.device + else: + msg = "'data' must be one of numpy ndarray or torch Tensor but is {}" + raise ValueError(msg.format(type(data))) + + +def get_backend_from_tensor_like( + data: NdarrayOrTensor +): + """ + This function returns the backend of `data`, which must either be a numpy ndarray or a + pytorch Tensor. + + Args: + data: the ndarray/tensor to return the device of + + Returns: + None if `data` is a numpy array, or the device of the pytorch Tensor + """ + if isinstance(data, np.ndarray): + return TransformBackends.NUMPY + elif isinstance(data, torch.Tensor): + return TransformBackends.TORCH + else: + msg = "'data' must be one of numpy ndarray or torch Tensor but is {}" + raise ValueError(msg.format(type(data))) + + +def dtype_torch_to_numpy(dtype: torch.dtype) -> np.dtype: + """Convert a torch dtype to its numpy equivalent.""" + return torch.empty([], dtype=dtype).numpy().dtype # type: ignore + + +def dtype_numpy_to_torch(dtype: np.dtype) -> torch.dtype: + """Convert a numpy dtype to its torch equivalent.""" + return torch.from_numpy(np.empty([], dtype=dtype)).dtype + + +__dtype_dict = { + np.int8: 'int8', + torch.int8: 'int8', + np.int16: 'int16', + torch.int16: 'int16', + int: 'int32', + np.int32: 'int32', + torch.int32: 'int32', + np.int64: 'int64', + torch.int64: 'int64', + np.uint8: 'uint8', + torch.uint8: 'uint8', + np.uint16: 'uint16', + np.uint32: 'uint32', + np.uint64: 'uint64', + float: 'float32', + np.float16: 'float16', + np.float: 'float32', + np.float32: 'float32', + np.float64: 'float64', + torch.float16: 'float16', + torch.float: 'float32', + torch.float32: 'float32', + torch.double: 'float64', + torch.float64: 'float64' +} + +def dtypes_to_str_or_identity(dtype: Any) -> Any: + return __dtype_dict.get(dtype, dtype) + + +def get_numpy_dtype_from_string(dtype: str) -> np.dtype: + """Get a numpy dtype (e.g., `np.float32`) from its string (e.g., `"float32"`).""" + return np.empty([], dtype=dtype).dtype + + +def get_torch_dtype_from_string(dtype: str) -> torch.dtype: + """Get a torch dtype (e.g., `torch.float32`) from its string (e.g., `"float32"`).""" + return dtype_numpy_to_torch(get_numpy_dtype_from_string(dtype)) + + if __name__ == "__main__": print_transform_backends() diff --git a/tests/test_apply.py b/tests/test_apply.py new file mode 100644 index 0000000000..0e8500544f --- /dev/null +++ b/tests/test_apply.py @@ -0,0 +1,9 @@ +import unittest + +from monai.transforms import apply + + +class TestApply(unittest.TestCase): + + def _test_apply_impl(self): + result = apply(None) diff --git a/tests/test_matmul.py b/tests/test_matmul.py new file mode 100644 index 0000000000..19fa2ab36b --- /dev/null +++ b/tests/test_matmul.py @@ -0,0 +1,129 @@ +import unittest +from parameterized import parameterized + +import numpy as np + +import torch + +from monai.transforms.meta_matrix import ( + Grid, is_grid_shaped, is_matrix_shaped, matmul, 
matmul_matrix_grid, matmul_grid_matrix_slow, + matmul_grid_matrix, Matrix +) + + +class TestMatmulFunctions(unittest.TestCase): + + def test_matrix_grid_and_grid_inv_matrix_are_equivalent(self): + grid = torch.randn((3, 32, 32)) + + matrix = torch.eye(3, 3) + matrix[:, 0] = torch.FloatTensor([0, -1, 0]) + matrix[:, 1] = torch.FloatTensor([1, 0, 0]) + + inv_matrix = torch.inverse(matrix) + + result1 = matmul_matrix_grid(matrix, grid) + result2 = matmul_grid_matrix(grid, inv_matrix) + self.assertTrue(torch.allclose(result1, result2)) + + def test_matmul_grid_matrix_slow(self): + grid = torch.randn((3, 32, 32)) + + matrix = torch.eye(3, 3) + matrix[:, 0] = torch.FloatTensor([0, -1, 0]) + matrix[:, 1] = torch.FloatTensor([1, 0, 0]) + + result1 = matmul_grid_matrix_slow(grid, matrix) + result2 = matmul_grid_matrix(grid, matrix) + self.assertTrue(torch.allclose(result1, result2)) + + MATMUL_TEST_CASES = [ + [np.eye(3, dtype=np.float32), np.eye(3, dtype=np.float32), np.ndarray], + [np.eye(3, dtype=np.float32), torch.eye(3), torch.Tensor], + [np.eye(3, dtype=np.float32), Matrix(torch.eye(3)), Matrix], + [np.eye(3, dtype=np.float32), Grid(torch.randn((3, 8, 8))), Grid], + [torch.eye(3), np.eye(3, dtype=np.float32), torch.Tensor], + [torch.eye(3), torch.eye(3), torch.Tensor], + [torch.eye(3), Matrix(torch.eye(3)), Matrix], + [torch.eye(3), Grid(torch.randn((3, 8, 8))), Grid], + [Matrix(torch.eye(3)), np.eye(3, dtype=np.float32), Matrix], + [Matrix(torch.eye(3)), torch.eye(3), Matrix], + [Matrix(torch.eye(3)), Matrix(torch.eye(3)), Matrix], + [Matrix(torch.eye(3)), Grid(torch.randn((3, 8, 8))), Grid], + [Grid(torch.randn((3, 8, 8))), np.eye(3, dtype=np.float32), Grid], + [Grid(torch.randn((3, 8, 8))), torch.eye(3), Grid], + [Grid(torch.randn((3, 8, 8))), Matrix(torch.eye(3)), Grid], + [Grid(torch.randn((3, 8, 8))), Grid(torch.randn((3, 8, 8))), None], + ] + + def _test_matmul_correct_return_type_impl(self, left, right, expected): + if expected is None: + with self.assertRaises(RuntimeError): + result = matmul(left, right) + else: + result = matmul(left, right) + self.assertIsInstance(result, expected) + + @parameterized.expand(MATMUL_TEST_CASES) + def test_matmul_correct_return_type(self, left, right, expected): + self._test_matmul_correct_return_type_impl(left, right, expected) + + # def test_all_matmul_correct_return_type(self): + # for case in self.MATMUL_TEST_CASES: + # with self.subTest(f"{case}"): + # self._test_matmul_correct_return_type_impl(*case) + + MATRIX_SHAPE_TESTCASES = [ + (torch.randn(2, 2), False), + (torch.randn(3, 3), True), + (torch.randn(4, 4), True), + (torch.randn(5, 5), False), + (torch.randn(3, 4), False), + (torch.randn(4, 3), False), + (torch.randn(3), False), + (torch.randn(4), False), + (torch.randn(5), False), + (torch.randn(3, 3, 3), False), + (torch.randn(4, 4, 4), False), + (torch.randn(5, 5, 5), False) + ] + + def _test_is_matrix_shaped_impl(self, matrix, expected): + self.assertEqual(is_matrix_shaped(matrix), expected) + + @parameterized.expand(MATRIX_SHAPE_TESTCASES) + def test_is_matrix_shaped(self, matrix, expected): + self._test_is_matrix_shaped_impl(matrix, expected) + + # def test_all_is_matrix_shaped(self): + # for case in self.MATRIX_SHAPE_TESTCASES: + # with self.subTest(f"{case[0].shape}"): + # self._test_is_matrix_shaped_impl(*case) + + + GRID_SHAPE_TESTCASES = [ + (torch.randn(1, 16, 32), False), + (torch.randn(2, 16, 32), False), + (torch.randn(3, 16, 32), True), + (torch.randn(4, 16, 32), False), + (torch.randn(5, 16, 32), False), + (torch.randn(3, 
16, 32, 64), False), + (torch.randn(4, 16, 32, 64), True), + (torch.randn(5, 16, 32, 64), False) + ] + + def _test_is_grid_shaped_impl(self, grid, expected): + self.assertEqual(is_grid_shaped(grid), expected) + + @parameterized.expand(GRID_SHAPE_TESTCASES) + def test_is_grid_shaped(self, grid, expected): + self._test_is_grid_shaped_impl(grid, expected) + + # def test_all_is_grid_shaped(self): + # for case in self.GRID_SHAPE_TESTCASES: + # with self.subTest(f"{case[0].shape}"): + # self._test_is_grid_shaped_impl(*case) + + +if __name__ == "__main__": + unittest.main() From e2018534869cf0f410f3cc158a2fac10385d2336 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 21:04:47 +0100 Subject: [PATCH 25/52] making import for MetaTensor more specific to avoid circular reference Signed-off-by: Ben Murray --- monai/transforms/apply.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/apply.py b/monai/transforms/apply.py index c58505e929..1b963adf8f 100644 --- a/monai/transforms/apply.py +++ b/monai/transforms/apply.py @@ -7,7 +7,7 @@ import torch from monai.config import DtypeLike -from monai.data import MetaTensor +from monai.data.meta_tensor import MetaTensor from monai.transforms import Affine from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform From 4f5c44eb4d5b5028f3328717acb41123a83eb97f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 27 Oct 2022 17:16:45 +0000 Subject: [PATCH 26/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/apply.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/transforms/apply.py b/monai/transforms/apply.py index 1b963adf8f..d613540952 100644 --- a/monai/transforms/apply.py +++ b/monai/transforms/apply.py @@ -10,7 +10,6 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import Affine from monai.transforms.inverse import InvertibleTransform -from monai.transforms.transform import MapTransform from monai.utils import GridSampleMode, GridSamplePadMode from monai.transforms.utils import get_backend_from_tensor_like, get_device_from_tensor_like, dtypes_to_str_or_identity from monai.transforms.meta_matrix import MatrixFactory, MetaMatrix, Matrix, matmul From 4370e5a348a5c21d8bb325b78848f7455a88601f Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 21:07:35 +0100 Subject: [PATCH 27/52] Making Affine import more specific in apply to avoid circular reference Signed-off-by: Ben Murray --- monai/transforms/apply.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/apply.py b/monai/transforms/apply.py index d613540952..9b56321eaf 100644 --- a/monai/transforms/apply.py +++ b/monai/transforms/apply.py @@ -8,7 +8,7 @@ from monai.config import DtypeLike from monai.data.meta_tensor import MetaTensor -from monai.transforms import Affine +from monai.transforms.spatial.array import Affine from monai.transforms.inverse import InvertibleTransform from monai.utils import GridSampleMode, GridSamplePadMode from monai.transforms.utils import get_backend_from_tensor_like, get_device_from_tensor_like, dtypes_to_str_or_identity From d959925fd9fbd3cdfd700dc0522967d96cd2850e Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 21:10:02 +0100 Subject: [PATCH 28/52] Fixing typing signature issue on apply method Signed-off-by: Ben Murray --- monai/transforms/apply.py | 2 +- 
1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/apply.py b/monai/transforms/apply.py index 9b56321eaf..b2fa8fc708 100644 --- a/monai/transforms/apply.py +++ b/monai/transforms/apply.py @@ -112,7 +112,7 @@ def matrix_from_matrix_container(matrix): def apply(data: Union[torch.Tensor, MetaTensor], - pending: Optional[dict, list] = None): + pending: Optional[Union[dict, list]] = None): # TODO: if data is a dict, then pending must also be a dict if isinstance(data, dict): From 5f254c7a4df99b179f3b29ae86580b81700e734c Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 21:20:39 +0100 Subject: [PATCH 29/52] Adding missing license boilerplate Signed-off-by: Ben Murray --- monai/transforms/apply.py | 11 ++++++++ monai/transforms/meta_matrix.py | 11 ++++++++ tests/tempscript.py | 47 +++++++++++++++++++++++++++++++++ tests/test_apply.py | 11 ++++++++ tests/test_matmul.py | 11 ++++++++ 5 files changed, 91 insertions(+) create mode 100644 tests/tempscript.py diff --git a/monai/transforms/apply.py b/monai/transforms/apply.py index b2fa8fc708..20d031b448 100644 --- a/monai/transforms/apply.py +++ b/monai/transforms/apply.py @@ -1,3 +1,14 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Optional, Sequence, Union import itertools as it diff --git a/monai/transforms/meta_matrix.py b/monai/transforms/meta_matrix.py index 3d0da051de..559cd7e451 100644 --- a/monai/transforms/meta_matrix.py +++ b/monai/transforms/meta_matrix.py @@ -1,3 +1,14 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from typing import Optional, Union import numpy as np diff --git a/tests/tempscript.py b/tests/tempscript.py new file mode 100644 index 0000000000..5a9e5945b6 --- /dev/null +++ b/tests/tempscript.py @@ -0,0 +1,47 @@ +import numpy as np + +import matplotlib.pyplot as plt + +import torch +from monai.utils import GridSampleMode, GridSamplePadMode + +from monai.transforms.atmostonce.apply import Applyd +from monai.transforms.atmostonce.dictionary import Rotated +from monai.transforms import Compose + + + +def test_rotate_tensor(): + r = Rotated(('image', 'label'), [0.0, 1.0, 0.0]) + + d = { + 'image': torch.zeros((1, 64, 64, 32), device="cpu", dtype=torch.float32), + 'label': torch.ones((1, 64, 64, 32), device="cpu", dtype=torch.int8) + } + d = r(d) + + for k, v in d.items(): + if isinstance(v, (np.ndarray, torch.Tensor)): + print(k, v.shape) + else: + print(k, v) + + +def test_rotate_apply(): + c = Compose([ + Rotated(('image', 'label'), (0.0, 3.14159265 / 2, 0.0)), + Applyd(('image', 'label'), + modes=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), + padding_modes=(GridSamplePadMode.BORDER, GridSamplePadMode.BORDER)) + ]) + + d = { + 'image': torch.zeros((1, 64, 64, 32), device="cpu", dtype=torch.float32), + 'label': torch.ones((1, 64, 64, 32), device="cpu", dtype=torch.int8) + } + plt.imshow(d['image'][0, ..., d['image'].shape[-1]//2]) + d = c(d) + plt.imshow(d['image'][0, ..., d['image'].shape[-1]//2]) + print(d['image'].shape) + +test_rotate_apply() diff --git a/tests/test_apply.py b/tests/test_apply.py index 0e8500544f..953c249db4 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -1,3 +1,14 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest from monai.transforms import apply diff --git a/tests/test_matmul.py b/tests/test_matmul.py index 19fa2ab36b..bf427ce06c 100644 --- a/tests/test_matmul.py +++ b/tests/test_matmul.py @@ -1,3 +1,14 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import unittest from parameterized import parameterized From cd28575703a8ff8774158303343dde03e1532d0d Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 21:32:27 +0100 Subject: [PATCH 30/52] Minimal docstrings for apply / Apply Signed-off-by: Ben Murray --- monai/transforms/apply.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/monai/transforms/apply.py b/monai/transforms/apply.py index 20d031b448..9290035c10 100644 --- a/monai/transforms/apply.py +++ b/monai/transforms/apply.py @@ -122,10 +122,17 @@ def matrix_from_matrix_container(matrix): return matrix -def apply(data: Union[torch.Tensor, MetaTensor], +def apply(data: Union[torch.Tensor, MetaTensor, dict], pending: Optional[Union[dict, list]] = None): - - # TODO: if data is a dict, then pending must also be a dict + """ + This method applies pending transforms to tensors. + + Args: + data: A torch Tensor, monai MetaTensor, or a dictionary containing Tensors / MetaTensors + pending: Optional arg containing pending transforms. This must be set if data is a Tensor + or dictionary of Tensors, but is optional if data is a MetaTensor / dictionary of + MetaTensors. + """ if isinstance(data, dict): rd = dict() for k, v in data.items(): @@ -211,7 +218,10 @@ def apply(data: Union[torch.Tensor, MetaTensor], # make Apply universal for arrays and dictionaries; it just calls through to functional apply class Apply(InvertibleTransform): - + """ + Apply wraps the apply method and can function as a Transform in either array or dictionary + mode. + """ def __init__(self): super().__init__() From 5653a64f1e21365f9f3f4f12951e8f39d78fc4ae Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 21:48:44 +0100 Subject: [PATCH 31/52] Splitting apply.py into lazy/functional and lazy/array Signed-off-by: Ben Murray --- monai/transforms/__init__.py | 3 +- monai/transforms/lazy/__init__.py | 10 ++++ monai/transforms/lazy/array.py | 46 +++++++++++++++++++ .../{apply.py => lazy/functional.py} | 37 +-------------- 4 files changed, 59 insertions(+), 37 deletions(-) create mode 100644 monai/transforms/lazy/__init__.py create mode 100644 monai/transforms/lazy/array.py rename monai/transforms/{apply.py => lazy/functional.py} (90%) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 51e83a94e8..431c554f6f 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -10,7 +10,6 @@ # limitations under the License. from .adaptors import FunctionSignature, adaptor, apply_alias, to_kwargs -from .apply import Apply, apply from .compose import Compose, OneOf from .croppad.array import ( BorderPad, @@ -228,6 +227,8 @@ from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict from .io.array import SUPPORTED_READERS, LoadImage, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict +from .lazy.array import Apply +from .lazy.functional import apply from .meta_utility.dictionary import ( FromMetaTensord, FromMetaTensorD, diff --git a/monai/transforms/lazy/__init__.py b/monai/transforms/lazy/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/transforms/lazy/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/transforms/lazy/array.py b/monai/transforms/lazy/array.py new file mode 100644 index 0000000000..b958010b4f --- /dev/null +++ b/monai/transforms/lazy/array.py @@ -0,0 +1,46 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from monai.transforms.inverse import InvertibleTransform + +from monai.transforms.lazy.functional import apply + + +class Apply(InvertibleTransform): + """ + Apply wraps the apply method and can function as a Transform in either array or dictionary + mode. + """ + def __init__(self): + super().__init__() + + def __call__(self, *args, **kwargs): + return apply(*args, **kwargs) + + def inverse(self, data): + return NotImplementedError() + + +# class Applyd(MapTransform, InvertibleTransform): +# +# def __init__(self): +# super().__init__() +# +# def __call__( +# self, +# d: dict +# ): +# rd = dict() +# for k, v in d.items(): +# rd[k] = apply(v) +# +# def inverse(self, data): +# return NotImplementedError() \ No newline at end of file diff --git a/monai/transforms/apply.py b/monai/transforms/lazy/functional.py similarity index 90% rename from monai/transforms/apply.py rename to monai/transforms/lazy/functional.py index 9290035c10..af320598b4 100644 --- a/monai/transforms/apply.py +++ b/monai/transforms/lazy/functional.py @@ -20,14 +20,12 @@ from monai.config import DtypeLike from monai.data.meta_tensor import MetaTensor from monai.transforms.spatial.array import Affine -from monai.transforms.inverse import InvertibleTransform from monai.utils import GridSampleMode, GridSamplePadMode from monai.transforms.utils import get_backend_from_tensor_like, get_device_from_tensor_like, dtypes_to_str_or_identity from monai.transforms.meta_matrix import MatrixFactory, MetaMatrix, Matrix, matmul __all__ = [ - "apply", - "Apply" + "apply" ] # TODO: This should move to a common place to be shared with dictionary @@ -214,36 +212,3 @@ def apply(data: Union[torch.Tensor, MetaTensor, dict], data.clear_pending_transforms() return data - - -# make Apply universal for arrays and dictionaries; it just calls through to functional apply -class Apply(InvertibleTransform): - """ - Apply wraps the apply method and can function as a Transform in either array or dictionary - mode. 
- """ - def __init__(self): - super().__init__() - - def __call__(self, *args, **kwargs): - return apply(*args, **kwargs) - - def inverse(self, data): - return NotImplementedError() - - -# class Applyd(MapTransform, InvertibleTransform): -# -# def __init__(self): -# super().__init__() -# -# def __call__( -# self, -# d: dict -# ): -# rd = dict() -# for k, v in d.items(): -# rd[k] = apply(v) -# -# def inverse(self, data): -# return NotImplementedError() From 70badbf8e2a086e763e41f130217ce724d742a0e Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 21:54:57 +0100 Subject: [PATCH 32/52] Removing spurious test file tempscript Signed-off-by: Ben Murray --- tests/tempscript.py | 47 --------------------------------------------- 1 file changed, 47 deletions(-) delete mode 100644 tests/tempscript.py diff --git a/tests/tempscript.py b/tests/tempscript.py deleted file mode 100644 index 5a9e5945b6..0000000000 --- a/tests/tempscript.py +++ /dev/null @@ -1,47 +0,0 @@ -import numpy as np - -import matplotlib.pyplot as plt - -import torch -from monai.utils import GridSampleMode, GridSamplePadMode - -from monai.transforms.atmostonce.apply import Applyd -from monai.transforms.atmostonce.dictionary import Rotated -from monai.transforms import Compose - - - -def test_rotate_tensor(): - r = Rotated(('image', 'label'), [0.0, 1.0, 0.0]) - - d = { - 'image': torch.zeros((1, 64, 64, 32), device="cpu", dtype=torch.float32), - 'label': torch.ones((1, 64, 64, 32), device="cpu", dtype=torch.int8) - } - d = r(d) - - for k, v in d.items(): - if isinstance(v, (np.ndarray, torch.Tensor)): - print(k, v.shape) - else: - print(k, v) - - -def test_rotate_apply(): - c = Compose([ - Rotated(('image', 'label'), (0.0, 3.14159265 / 2, 0.0)), - Applyd(('image', 'label'), - modes=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), - padding_modes=(GridSamplePadMode.BORDER, GridSamplePadMode.BORDER)) - ]) - - d = { - 'image': torch.zeros((1, 64, 64, 32), device="cpu", dtype=torch.float32), - 'label': torch.ones((1, 64, 64, 32), device="cpu", dtype=torch.int8) - } - plt.imshow(d['image'][0, ..., d['image'].shape[-1]//2]) - d = c(d) - plt.imshow(d['image'][0, ..., d['image'].shape[-1]//2]) - print(d['image'].shape) - -test_rotate_apply() From 227b58b778c245e14d6c8f35fdf09765b96d7a8c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 27 Oct 2022 20:49:21 +0000 Subject: [PATCH 33/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/lazy/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/lazy/array.py b/monai/transforms/lazy/array.py index b958010b4f..a3de0dfe7d 100644 --- a/monai/transforms/lazy/array.py +++ b/monai/transforms/lazy/array.py @@ -43,4 +43,4 @@ def inverse(self, data): # rd[k] = apply(v) # # def inverse(self, data): -# return NotImplementedError() \ No newline at end of file +# return NotImplementedError() From 47f00ca67e6f7e1c873c3c570486b2eaa12c0fd8 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 22:13:08 +0100 Subject: [PATCH 34/52] Auto formatting fixes Signed-off-by: Ben Murray --- monai/transforms/lazy/array.py | 2 +- monai/transforms/lazy/functional.py | 58 ++++++++++---------------- monai/transforms/meta_matrix.py | 63 ++++++++++++----------------- monai/transforms/utils.py | 57 +++++++++++++------------- tests/test_apply.py | 1 - tests/test_matmul.py | 19 +++++---- 6 files changed, 86 insertions(+), 114 
deletions(-) diff --git a/monai/transforms/lazy/array.py b/monai/transforms/lazy/array.py index a3de0dfe7d..bc61a24a6e 100644 --- a/monai/transforms/lazy/array.py +++ b/monai/transforms/lazy/array.py @@ -10,7 +10,6 @@ # limitations under the License. from monai.transforms.inverse import InvertibleTransform - from monai.transforms.lazy.functional import apply @@ -19,6 +18,7 @@ class Apply(InvertibleTransform): Apply wraps the apply method and can function as a Transform in either array or dictionary mode. """ + def __init__(self): super().__init__() diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index af320598b4..607e90d6bb 100644 --- a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -9,27 +9,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Sequence, Union - import itertools as it +from typing import Optional, Sequence, Union import numpy as np - import torch from monai.config import DtypeLike from monai.data.meta_tensor import MetaTensor +from monai.transforms.meta_matrix import Matrix, MatrixFactory, MetaMatrix, matmul from monai.transforms.spatial.array import Affine +from monai.transforms.utils import dtypes_to_str_or_identity, get_backend_from_tensor_like, get_device_from_tensor_like from monai.utils import GridSampleMode, GridSamplePadMode -from monai.transforms.utils import get_backend_from_tensor_like, get_device_from_tensor_like, dtypes_to_str_or_identity -from monai.transforms.meta_matrix import MatrixFactory, MetaMatrix, Matrix, matmul -__all__ = [ - "apply" -] +__all__ = ["apply"] # TODO: This should move to a common place to be shared with dictionary -#from monai.utils.type_conversion import dtypes_to_str_or_identity +# from monai.utils.type_conversion import dtypes_to_str_or_identity GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str] GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str] @@ -46,8 +42,7 @@ def extents_from_shape(shape, dtype=np.float64): # TODO: move to mapping_stack.py def shape_from_extents( - src_shape: Sequence, - extents: Union[Sequence[np.ndarray], Sequence[torch.Tensor], np.ndarray, torch.Tensor] + src_shape: Sequence, extents: Union[Sequence[np.ndarray], Sequence[torch.Tensor], np.ndarray, torch.Tensor] ): if isinstance(extents, (list, tuple)): if isinstance(extents[0], np.ndarray): @@ -100,13 +95,13 @@ def starting_matrix_and_extents(matrix_factory, data): def prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype): kwargs = {} if cur_mode is not None: - kwargs['mode'] = cur_mode + kwargs["mode"] = cur_mode if cur_padding_mode is not None: - kwargs['padding_mode'] = cur_padding_mode + kwargs["padding_mode"] = cur_padding_mode if cur_device is not None: - kwargs['device'] = cur_device + kwargs["device"] = cur_device if cur_dtype is not None: - kwargs['dtype'] = cur_dtype + kwargs["dtype"] = cur_dtype return kwargs @@ -120,8 +115,7 @@ def matrix_from_matrix_container(matrix): return matrix -def apply(data: Union[torch.Tensor, MetaTensor, dict], - pending: Optional[Union[dict, list]] = None): +def apply(data: Union[torch.Tensor, MetaTensor, dict], pending: Optional[Union[dict, list]] = None): """ This method applies pending transforms to tensors. 
@@ -147,9 +141,7 @@ def apply(data: Union[torch.Tensor, MetaTensor, dict], return data dim_count = len(data.shape) - 1 - matrix_factory = MatrixFactory(dim_count, - get_backend_from_tensor_like(data), - get_device_from_tensor_like(data)) + matrix_factory = MatrixFactory(dim_count, get_backend_from_tensor_like(data), get_device_from_tensor_like(data)) # set up the identity matrix and metadata cumulative_matrix, cumulative_extents = starting_matrix_and_extents(matrix_factory, data) @@ -167,30 +159,26 @@ def apply(data: Union[torch.Tensor, MetaTensor, dict], cumulative_matrix = matmul(cumulative_matrix, next_matrix) cumulative_extents = [matmul(e, cumulative_matrix) for e in cumulative_extents] - new_mode = meta_matrix.metadata.get('mode', None) - new_padding_mode = meta_matrix.metadata.get('padding_mode', None) - new_device = meta_matrix.metadata.get('device', None) - new_dtype = meta_matrix.metadata.get('dtype', None) - new_shape = meta_matrix.metadata.get('shape_override', None) + new_mode = meta_matrix.metadata.get("mode", None) + new_padding_mode = meta_matrix.metadata.get("padding_mode", None) + new_device = meta_matrix.metadata.get("device", None) + new_dtype = meta_matrix.metadata.get("dtype", None) + new_shape = meta_matrix.metadata.get("shape_override", None) mode_compat = metadata_is_compatible(cur_mode, new_mode) padding_mode_compat = metadata_is_compatible(cur_padding_mode, new_padding_mode) device_compat = metadata_is_compatible(cur_device, new_device) dtype_compat = metadata_dtype_is_compatible(cur_dtype, new_dtype) - if (mode_compat is False or padding_mode_compat is False or - device_compat is False or dtype_compat is False): + if mode_compat is False or padding_mode_compat is False or device_compat is False or dtype_compat is False: # carry out an intermediate resample here due to incompatibility between arguments kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype) cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix) - a = Affine(norm_coords=False, - affine=cumulative_matrix_, - **kwargs) + a = Affine(norm_coords=False, affine=cumulative_matrix_, **kwargs) data, _ = a(img=data) - cumulative_matrix, cumulative_extents =\ - starting_matrix_and_extents(matrix_factory, data) + cumulative_matrix, cumulative_extents = starting_matrix_and_extents(matrix_factory, data) cur_mode = cur_mode if new_mode is None else new_mode cur_padding_mode = cur_padding_mode if new_padding_mode is None else new_padding_mode @@ -203,11 +191,7 @@ def apply(data: Union[torch.Tensor, MetaTensor, dict], cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix) # print(f"applying with cumulative matrix\n {cumulative_matrix_}") - a = Affine(norm_coords=False, - affine=cumulative_matrix_, - spatial_size=cur_shape[1:], - normalized=False, - **kwargs) + a = Affine(norm_coords=False, affine=cumulative_matrix_, spatial_size=cur_shape[1:], normalized=False, **kwargs) data, tx = a(img=data) data.clear_pending_transforms() diff --git a/monai/transforms/meta_matrix.py b/monai/transforms/meta_matrix.py index 559cd7e451..e955914c45 100644 --- a/monai/transforms/meta_matrix.py +++ b/monai/transforms/meta_matrix.py @@ -12,7 +12,6 @@ from typing import Optional, Union import numpy as np - import torch from monai.config import NdarrayOrTensor @@ -20,14 +19,14 @@ def is_matrix_shaped(data): - return (len(data.shape) == 2 and data.shape[0] in (3, 4) and - data.shape[1] in (3, 4) and data.shape[0] == data.shape[1]) + return ( + len(data.shape) == 2 and 
data.shape[0] in (3, 4) and data.shape[1] in (3, 4) and data.shape[0] == data.shape[1] + ) def is_grid_shaped(data): - return (len(data.shape) == 3 and data.shape[0] == 3 or - len(data.shape) == 4 and data.shape[0] == 4) + return len(data.shape) == 3 and data.shape[0] == 3 or len(data.shape) == 4 and data.shape[0] == 4 class MatrixFactory: @@ -42,7 +41,6 @@ def ensure_tensor(data: NdarrayOrTensor): class Matrix: - def __init__(self, matrix: NdarrayOrTensor): self.data = ensure_tensor(matrix) @@ -66,17 +64,15 @@ def __init__(self, grid): class MetaMatrix: - - def __init__( - self, - matrix: Union[NdarrayOrTensor, Matrix, Grid], - metadata: Optional[dict] = None): + def __init__(self, matrix: Union[NdarrayOrTensor, Matrix, Grid], metadata: Optional[dict] = None): if not isinstance(matrix, (Matrix, Grid)): if matrix.shape == 2: if matrix.shape[0] != matrix.shape[1] or matrix.shape[0] not in (3, 4): - raise ValueError("If 'matrix' is passed a numpy ndarray/torch Tensor, it must" - f" be 3x3 or 4x4 ('matrix' has has shape {matrix.shape})") + raise ValueError( + "If 'matrix' is passed a numpy ndarray/torch Tensor, it must" + f" be 3x3 or 4x4 ('matrix' has has shape {matrix.shape})" + ) matrix_ = Matrix(matrix) else: matrix_ = matrix @@ -100,8 +96,7 @@ def __rmatmul__(self, other): def matmul( - left: Union[MetaMatrix, Grid, Matrix, NdarrayOrTensor], - right: Union[MetaMatrix, Grid, Matrix, NdarrayOrTensor] + left: Union[MetaMatrix, Grid, Matrix, NdarrayOrTensor], right: Union[MetaMatrix, Grid, Matrix, NdarrayOrTensor] ): matrix_types = (MetaMatrix, Grid, Matrix, torch.Tensor, np.ndarray) @@ -153,16 +148,15 @@ def matmul( return result -def matmul_matrix_grid( - left: NdarrayOrTensor, - right: NdarrayOrTensor -): +def matmul_matrix_grid(left: NdarrayOrTensor, right: NdarrayOrTensor): if not is_matrix_shaped(left): raise ValueError(f"'left' should be a 2D or 3D homogenous matrix but has shape {left.shape}") if not is_grid_shaped(right): - raise ValueError("'right' should be a 3D array with shape[0] == 2 or a " - f"4D array with shape[0] == 3 but has shape {right.shape}") + raise ValueError( + "'right' should be a 3D array with shape[0] == 2 or a " + f"4D array with shape[0] == 3 but has shape {right.shape}" + ) # flatten the grid to take advantage of torch batch matrix multiply right_flat = right.reshape(right.shape[0], -1) @@ -172,13 +166,12 @@ def matmul_matrix_grid( return result -def matmul_grid_matrix( - left: NdarrayOrTensor, - right: NdarrayOrTensor -): +def matmul_grid_matrix(left: NdarrayOrTensor, right: NdarrayOrTensor): if not is_grid_shaped(left): - raise ValueError("'left' should be a 3D array with shape[0] == 2 or a " - f"4D array with shape[0] == 3 but has shape {left.shape}") + raise ValueError( + "'left' should be a 3D array with shape[0] == 2 or a " + f"4D array with shape[0] == 3 but has shape {left.shape}" + ) if not is_matrix_shaped(right): raise ValueError(f"'right' should be a 2D or 3D homogenous matrix but has shape {right.shape}") @@ -194,13 +187,12 @@ def matmul_grid_matrix( return matmul_matrix_grid(torch.inverse(right), left) -def matmul_grid_matrix_slow( - left: NdarrayOrTensor, - right: NdarrayOrTensor -): +def matmul_grid_matrix_slow(left: NdarrayOrTensor, right: NdarrayOrTensor): if not is_grid_shaped(left): - raise ValueError("'left' should be a 3D array with shape[0] == 2 or a " - f"4D array with shape[0] == 3 but has shape {left.shape}") + raise ValueError( + "'left' should be a 3D array with shape[0] == 2 or a " + f"4D array with shape[0] == 3 but has shape 
{left.shape}" + ) if not is_matrix_shaped(right): raise ValueError(f"'right' should be a 2D or 3D homogenous matrix but has shape {right.shape}") @@ -217,8 +209,5 @@ def matmul_grid_matrix_slow( return result -def matmul_matrix_matrix( - left: NdarrayOrTensor, - right: NdarrayOrTensor, -): +def matmul_matrix_matrix(left: NdarrayOrTensor, right: NdarrayOrTensor): return left @ right diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 091d3e5741..b505d27aa8 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1792,9 +1792,7 @@ def squarepulse(sig, duty: float = 0.5): return y -def get_device_from_tensor_like( - data: NdarrayOrTensor -): +def get_device_from_tensor_like(data: NdarrayOrTensor): """ This function returns the device of `data`, which must either be a numpy ndarray or a pytorch Tensor. @@ -1814,9 +1812,7 @@ def get_device_from_tensor_like( raise ValueError(msg.format(type(data))) -def get_backend_from_tensor_like( - data: NdarrayOrTensor -): +def get_backend_from_tensor_like(data: NdarrayOrTensor): """ This function returns the backend of `data`, which must either be a numpy ndarray or a pytorch Tensor. @@ -1847,32 +1843,33 @@ def dtype_numpy_to_torch(dtype: np.dtype) -> torch.dtype: __dtype_dict = { - np.int8: 'int8', - torch.int8: 'int8', - np.int16: 'int16', - torch.int16: 'int16', - int: 'int32', - np.int32: 'int32', - torch.int32: 'int32', - np.int64: 'int64', - torch.int64: 'int64', - np.uint8: 'uint8', - torch.uint8: 'uint8', - np.uint16: 'uint16', - np.uint32: 'uint32', - np.uint64: 'uint64', - float: 'float32', - np.float16: 'float16', - np.float: 'float32', - np.float32: 'float32', - np.float64: 'float64', - torch.float16: 'float16', - torch.float: 'float32', - torch.float32: 'float32', - torch.double: 'float64', - torch.float64: 'float64' + np.int8: "int8", + torch.int8: "int8", + np.int16: "int16", + torch.int16: "int16", + int: "int32", + np.int32: "int32", + torch.int32: "int32", + np.int64: "int64", + torch.int64: "int64", + np.uint8: "uint8", + torch.uint8: "uint8", + np.uint16: "uint16", + np.uint32: "uint32", + np.uint64: "uint64", + float: "float32", + np.float16: "float16", + np.float: "float32", + np.float32: "float32", + np.float64: "float64", + torch.float16: "float16", + torch.float: "float32", + torch.float32: "float32", + torch.double: "float64", + torch.float64: "float64", } + def dtypes_to_str_or_identity(dtype: Any) -> Any: return __dtype_dict.get(dtype, dtype) diff --git a/tests/test_apply.py b/tests/test_apply.py index 953c249db4..9a050ab6ae 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -15,6 +15,5 @@ class TestApply(unittest.TestCase): - def _test_apply_impl(self): result = apply(None) diff --git a/tests/test_matmul.py b/tests/test_matmul.py index bf427ce06c..2ac3b70a2e 100644 --- a/tests/test_matmul.py +++ b/tests/test_matmul.py @@ -10,20 +10,24 @@ # limitations under the License. 
import unittest -from parameterized import parameterized import numpy as np - import torch +from parameterized import parameterized from monai.transforms.meta_matrix import ( - Grid, is_grid_shaped, is_matrix_shaped, matmul, matmul_matrix_grid, matmul_grid_matrix_slow, - matmul_grid_matrix, Matrix + Grid, + Matrix, + is_grid_shaped, + is_matrix_shaped, + matmul, + matmul_grid_matrix, + matmul_grid_matrix_slow, + matmul_matrix_grid, ) class TestMatmulFunctions(unittest.TestCase): - def test_matrix_grid_and_grid_inv_matrix_are_equivalent(self): grid = torch.randn((3, 32, 32)) @@ -96,7 +100,7 @@ def test_matmul_correct_return_type(self, left, right, expected): (torch.randn(5), False), (torch.randn(3, 3, 3), False), (torch.randn(4, 4, 4), False), - (torch.randn(5, 5, 5), False) + (torch.randn(5, 5, 5), False), ] def _test_is_matrix_shaped_impl(self, matrix, expected): @@ -111,7 +115,6 @@ def test_is_matrix_shaped(self, matrix, expected): # with self.subTest(f"{case[0].shape}"): # self._test_is_matrix_shaped_impl(*case) - GRID_SHAPE_TESTCASES = [ (torch.randn(1, 16, 32), False), (torch.randn(2, 16, 32), False), @@ -120,7 +123,7 @@ def test_is_matrix_shaped(self, matrix, expected): (torch.randn(5, 16, 32), False), (torch.randn(3, 16, 32, 64), False), (torch.randn(4, 16, 32, 64), True), - (torch.randn(5, 16, 32, 64), False) + (torch.randn(5, 16, 32, 64), False), ] def _test_is_grid_shaped_impl(self, grid, expected): From 3043ae7798fddb8fc91166da1a04354bf00631ae Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 23:10:52 +0100 Subject: [PATCH 35/52] Fixing issues raised by linter Signed-off-by: Ben Murray --- monai/transforms/lazy/functional.py | 2 +- monai/transforms/meta_matrix.py | 2 +- tests/test_apply.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index 607e90d6bb..aa153b8e02 100644 --- a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -37,7 +37,7 @@ def extents_from_shape(shape, dtype=np.float64): extents = [[0, shape[i]] for i in range(1, len(shape))] extents = it.product(*extents) - return list(np.asarray(e + (1,), dtype=dtype) for e in extents) + return [np.asarray(e + (1,), dtype=dtype) for e in extents] # TODO: move to mapping_stack.py diff --git a/monai/transforms/meta_matrix.py b/monai/transforms/meta_matrix.py index e955914c45..aff2d23c61 100644 --- a/monai/transforms/meta_matrix.py +++ b/monai/transforms/meta_matrix.py @@ -184,7 +184,7 @@ def matmul_grid_matrix(left: NdarrayOrTensor, right: NdarrayOrTensor): # invert the matrix and swap the arguments, taking advantage of # matrix @ vector == vector_transposed @ matrix_inverse - return matmul_matrix_grid(torch.inverse(right), left) + return matmul_matrix_grid(inv_matrix, left) def matmul_grid_matrix_slow(left: NdarrayOrTensor, right: NdarrayOrTensor): diff --git a/tests/test_apply.py b/tests/test_apply.py index 9a050ab6ae..036549d0b4 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -16,4 +16,5 @@ class TestApply(unittest.TestCase): def _test_apply_impl(self): - result = apply(None) + # result = apply(None) + pass From 4aae9da7c7ed2dc2d1355900ee43bada95be4449 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 28 Oct 2022 11:20:42 +0100 Subject: [PATCH 36/52] Starting tests for apply Signed-off-by: Ben Murray --- tests/test_apply.py | 48 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/tests/test_apply.py 
b/tests/test_apply.py index 036549d0b4..7d8af36e75 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -11,10 +11,50 @@ import unittest -from monai.transforms import apply +import numpy as np + +import torch + +from monai.transforms.lazy.functional import apply +from monai.transforms.meta_matrix import MetaMatrix + + +def get_img(size, dtype=torch.float32, offset=0): + img = torch.zeros(size, dtype=dtype) + if len(size) == 2: + for j in range(size[0]): + for i in range(size[1]): + img[j, i] = i + j * size[0] + offset + else: + for k in range(size[0]): + for j in range(size[1]): + for i in range(size[2]): + img[k, j, i] = i + j * size[0] + k * size[0] * size[1] + return np.expand_dims(img, 0) + + +def rotate_45_2D(): + t = torch.eye(3) + t[:, 0] = torch.FloatTensor([0, -1, 0]) + t[:, 1] = torch.FloatTensor([1, 0, 0]) + return t class TestApply(unittest.TestCase): - def _test_apply_impl(self): - # result = apply(None) - pass + + def _test_apply_impl(self, tensor, pending_transforms): + print(tensor.shape) + # for m in pending_transforms: + # print(m.matrix) + # print(m.metadata) + result = apply(tensor, pending_transforms) + print(result) + + SINGLE_TRANSFORM_CASES = [ + (torch.randn((1, 16, 16)), [MetaMatrix(rotate_45_2D(), {"id", "rotate"})]) + ] + + def test_apply_single_transform(self): + for case in self.SINGLE_TRANSFORM_CASES: + self._test_apply_impl(*case) + From b6f1bd9f736529398be2d0a4a9d1ab879a4011fc Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 28 Oct 2022 14:57:52 +0100 Subject: [PATCH 37/52] Further array functionality and testing; waiting on PR #5107 Signed-off-by: Ben Murray --- monai/transforms/__init__.py | 7 ++ monai/transforms/lazy/array.py | 4 +- monai/transforms/lazy/functional.py | 18 +++-- monai/transforms/meta_matrix.py | 76 ++++++++++++++++- monai/transforms/utils.py | 121 ++++++++++++++++++++++++++++ tests/test_apply.py | 20 ++++- 6 files changed, 235 insertions(+), 11 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 431c554f6f..307b5cda28 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -237,6 +237,13 @@ ToMetaTensorD, ToMetaTensorDict, ) +from .meta_matrix import ( + Grid, + matmul, + Matrix, + MatrixFactory, + MetaMatrix, +) from .nvtx import ( Mark, Markd, diff --git a/monai/transforms/lazy/array.py b/monai/transforms/lazy/array.py index bc61a24a6e..ae165cf566 100644 --- a/monai/transforms/lazy/array.py +++ b/monai/transforms/lazy/array.py @@ -9,8 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from monai.transforms.inverse import InvertibleTransform from monai.transforms.lazy.functional import apply +from monai.transforms.inverse import InvertibleTransform + +__all__ = ["Apply"] class Apply(InvertibleTransform): diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index aa153b8e02..adb44fc9a3 100644 --- a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -33,7 +33,7 @@ # TODO: move to mapping_stack.py -def extents_from_shape(shape, dtype=np.float64): +def extents_from_shape(shape, dtype=np.float32): extents = [[0, shape[i]] for i in range(1, len(shape))] extents = it.product(*extents) @@ -132,10 +132,10 @@ def apply(data: Union[torch.Tensor, MetaTensor, dict], pending: Optional[Union[d rd[k] = result return rd - if isinstance(data, MetaTensor) or pending is not None: - pending_ = [] if pending is None else pending - else: + if isinstance(data, MetaTensor) and pending is None: pending_ = data.pending_transforms + else: + pending_ = [] if pending is None else pending if len(pending_) == 0: return data @@ -154,7 +154,7 @@ def apply(data: Union[torch.Tensor, MetaTensor, dict], pending: Optional[Union[d cur_shape = data.shape for meta_matrix in pending_: - next_matrix = meta_matrix.data + next_matrix = meta_matrix.matrix # print("intermediate matrix\n", matrix_from_matrix_container(cumulative_matrix)) cumulative_matrix = matmul(cumulative_matrix, next_matrix) cumulative_extents = [matmul(e, cumulative_matrix) for e in cumulative_extents] @@ -193,6 +193,10 @@ def apply(data: Union[torch.Tensor, MetaTensor, dict], pending: Optional[Union[d # print(f"applying with cumulative matrix\n {cumulative_matrix_}") a = Affine(norm_coords=False, affine=cumulative_matrix_, spatial_size=cur_shape[1:], normalized=False, **kwargs) data, tx = a(img=data) - data.clear_pending_transforms() + if isinstance(data, MetaTensor): + data.clear_pending_transforms() + for p in pending_: + data.affine = p.matrix.data + data.push_applied_operation(p) - return data + return data, pending_ diff --git a/monai/transforms/meta_matrix.py b/monai/transforms/meta_matrix.py index aff2d23c61..845fe3f2c5 100644 --- a/monai/transforms/meta_matrix.py +++ b/monai/transforms/meta_matrix.py @@ -9,13 +9,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional, Union +from typing import Optional, Sequence, Union import numpy as np import torch +from monai.transforms.utils import _create_rotate, _create_rotate_90, _create_flip, _create_shear, _create_scale, \ + _create_translate + +from monai.transforms.utils import get_backend_from_tensor_like, get_device_from_tensor_like +from monai.utils import TransformBackends from monai.config import NdarrayOrTensor +__all__ = ["Grid", "matmul", "Matrix", "MatrixFactory", "MetaMatrix"] def is_matrix_shaped(data): @@ -30,7 +36,73 @@ def is_grid_shaped(data): class MatrixFactory: - pass + + def __init__(self, + dims: int, + backend: TransformBackends, + device: Optional[torch.device] = None): + + if backend == TransformBackends.NUMPY: + if device is not None: + raise ValueError("'device' must be None with TransformBackends.NUMPY") + self._device = None + self._sin = lambda th: np.sin(th, dtype=np.float32) + self._cos = lambda th: np.cos(th, dtype=np.float32) + self._eye = lambda th: np.eye(th, dtype=np.float32) + self._diag = lambda th: np.diag(th, dtype=np.float32) + else: + if device is None: + raise ValueError("'device' must be set with TransformBackends.TORCH") + self._device = device + self._sin = lambda th: torch.sin(torch.as_tensor(th, + dtype=torch.float32, + device=self._device)) + self._cos = lambda th: torch.cos(torch.as_tensor(th, + dtype=torch.float32, + device=self._device)) + self._eye = lambda rank: torch.eye(rank, + device=self._device, + dtype=torch.float32); + self._diag = lambda size: torch.diag(torch.as_tensor(size, + device=self._device, + dtype=torch.float32)) + + self._backend = backend + self._dims = dims + + @staticmethod + def from_tensor(data): + return MatrixFactory(len(data.shape)-1, + get_backend_from_tensor_like(data), + get_device_from_tensor_like(data)) + + def identity(self): + matrix = self._eye(self._dims + 1) + return MetaMatrix(matrix, {}) + + def rotate_euler(self, radians: Union[Sequence[float], float], **extra_args): + matrix = _create_rotate(self._dims, radians, self._sin, self._cos, self._eye) + return MetaMatrix(matrix, extra_args) + + def rotate_90(self, rotations, axis, **extra_args): + matrix = _create_rotate_90(self._dims, rotations, axis) + return MetaMatrix(matrix, extra_args) + + def flip(self, axis, **extra_args): + matrix = _create_flip(self._dims, axis, self._eye) + return MetaMatrix(matrix, extra_args) + + def shear(self, coefs: Union[Sequence[float], float], **extra_args): + matrix = _create_shear(self._dims, coefs, self._eye) + return MetaMatrix(matrix, extra_args) + + def scale(self, factors: Union[Sequence[float], float], **extra_args): + matrix = _create_scale(self._dims, factors, self._diag) + return MetaMatrix(matrix, extra_args) + + def translate(self, offsets: Union[Sequence[float], float], **extra_args): + matrix = _create_translate(self._dims, offsets, self._eye) + return MetaMatrix(matrix, extra_args) def ensure_tensor(data: NdarrayOrTensor): diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index b505d27aa8..a982471cfe 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -891,6 +891,127 @@ def _create_translate( return array_func(affine) # type: ignore +def _create_rotate_90( + spatial_dims: int, + axis: Tuple[int, int], + steps: Optional[int] = 1, + eye_func: Callable = np.eye +) -> NdarrayOrTensor: + + values = [(1, 0, 0, 1), + (0, -1, 1, 0), + (-1, 0, 0, -1), + (0, 1, -1, 0)] + + if spatial_dims == 2: + if axis != (0, 1): + raise ValueError(f"if 'spatial_dims' is 2, 'axis' 
must be (0, 1) but is {axis}") + elif spatial_dims == 3: + if axis not in ((0, 1), (0, 2), (1, 2)): + raise ValueError("if 'spatial_dims' is 3, 'axis' must be (0,1), (0, 2), or (1, 2) " + f"but is {axis}") + else: + raise ValueError(f"'spatial_dims' must be 2 or 3 but is {spatial_dims}") + + steps_ = steps % 4 + + affine = eye_func(spatial_dims + 1) + + if spatial_dims == 2: + a, b = 0, 1 + else: + a, b = axis + + affine[a, a], affine[a, b], affine[b, a], affine[b, b] = values[steps] + return affine + + +def create_rotate_90( + spatial_dims: int, + axis: int, + steps: Optional[int] = 1, + device: Optional[torch.device] = None, + backend: str = TransformBackends.NUMPY, +) -> NdarrayOrTensor: + """ + create a 2D or 3D rotation matrix + + Args: + spatial_dims: {``2``, ``3``} spatial rank + radians: rotation radians + when spatial_dims == 3, the `radians` sequence corresponds to + rotation in the 1st, 2nd, and 3rd dim respectively. + device: device to compute and store the output (when the backend is "torch"). + backend: APIs to use, ``numpy`` or ``torch``. + + Raises: + ValueError: When ``radians`` is empty. + ValueError: When ``spatial_dims`` is not one of [2, 3]. + + """ + _backend = look_up_option(backend, TransformBackends) + if _backend == TransformBackends.NUMPY: + return _create_rotate_90( + spatial_dims=spatial_dims, + axis=axis, + steps=steps, + eye_func=np.eye) + if _backend == TransformBackends.TORCH: + return _create_rotate_90( + spatial_dims=spatial_dims, + axis=axis, + steps=steps, + eye_func=lambda rank: torch.eye(rank, device=device), + ) + raise ValueError(f"backend {backend} is not supported") + + +def _create_flip( + spatial_dims: int, + spatial_axis: Union[Sequence[int], int], + eye_func: Callable = np.eye +): + affine = eye_func(spatial_dims + 1) + if isinstance(spatial_axis, int): + if spatial_axis < -spatial_dims or spatial_axis >= spatial_dims: + raise ValueError("'spatial_axis' values must be between " + f"{-spatial_dims} and {spatial_dims-1} inclusive " + f"('spatial_axis' is {spatial_axis})") + affine[spatial_axis, spatial_axis] = -1 + else: + if any((s < -spatial_dims or s >= spatial_dims) for s in spatial_axis): + raise ValueError("'spatial_axis' values must be between " + f"{-spatial_dims} and {spatial_dims-1} inclusive " + f"('spatial_axis' is {spatial_axis})") + + for i in range(spatial_dims): + if i in spatial_axis: + affine[i, i] = -1 + + return affine + + +def create_flip( + spatial_dims: int, + spatial_axis: Union[Sequence[int], int], + device: Optional[torch.device] = None, + backend: str = TransformBackends.NUMPY, +) -> NdarrayOrTensor: + _backend = look_up_option(backend, TransformBackends) + if _backend == TransformBackends.NUMPY: + return _create_flip( + spatial_dims=spatial_dims, + spatial_axis=spatial_axis, + eye_func=np.eye) + if _backend == TransformBackends.TORCH: + return _create_flip( + spatial_dims=spatial_dims, + spatial_axis=spatial_axis, + eye_func=lambda rank: torch.eye(rank, device=device), + ) + raise ValueError(f"backend {backend} is not supported") + + def generate_spatial_bounding_box( img: NdarrayOrTensor, select_fn: Callable = is_positive, diff --git a/tests/test_apply.py b/tests/test_apply.py index 7d8af36e75..088d275a12 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -14,6 +14,7 @@ import numpy as np import torch +from monai.utils import convert_to_tensor from monai.transforms.lazy.functional import apply from monai.transforms.meta_matrix import MetaMatrix @@ -50,11 +51,28 @@ def _test_apply_impl(self, tensor, 
pending_transforms): result = apply(tensor, pending_transforms) print(result) + def _test_apply_metatensor_impl(self, tensor, pending_transforms, pending_as_parameter): + tensor_ = convert_to_tensor(tensor) + if pending_as_parameter: + result = apply(tensor_, pending_transforms) + else: + for p in pending_transforms: + # TODO: cannot do the next part until the MetaTensor PR #5107 is in + # tensor_.push_pending(p) + raise NotImplementedError() + SINGLE_TRANSFORM_CASES = [ - (torch.randn((1, 16, 16)), [MetaMatrix(rotate_45_2D(), {"id", "rotate"})]) + (torch.randn((1, 16, 16)), [MetaMatrix(rotate_45_2D(), {"id": "rotate"})]) ] def test_apply_single_transform(self): for case in self.SINGLE_TRANSFORM_CASES: self._test_apply_impl(*case) + def test_apply_single_transform_metatensor(self): + for case in self.SINGLE_TRANSFORM_CASES: + self._test_apply_metatensor_impl(*case, False) + + def test_apply_single_transform_metatensor_override(self): + for case in self.SINGLE_TRANSFORM_CASES: + self._test_apply_metatensor_impl(*case, True) From f6b889f1ff1217ebb50c87625d4301ffd7588388 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 28 Oct 2022 15:53:06 +0100 Subject: [PATCH 38/52] Adding resample function for unified resampling and test placeholder Signed-off-by: Ben Murray --- monai/transforms/utility/functional.py | 32 ++++++++++++++++++ tests/test_apply.py | 23 +++---------- tests/test_resample.py | 46 ++++++++++++++++++++++++++ tests/utils.py | 18 ++++++++++ 4 files changed, 100 insertions(+), 19 deletions(-) create mode 100644 monai/transforms/utility/functional.py create mode 100644 tests/test_resample.py diff --git a/monai/transforms/utility/functional.py b/monai/transforms/utility/functional.py new file mode 100644 index 0000000000..7aec22c723 --- /dev/null +++ b/monai/transforms/utility/functional.py @@ -0,0 +1,32 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Union + +import torch +from monai.transforms import Affine + +from monai.config import NdarrayOrTensor +from monai.transforms.meta_matrix import Grid, Matrix + + +def resample( + data: torch.Tensor, + matrix: Union[NdarrayOrTensor, Matrix, Grid], + kwargs: Optional[dict] = None +): + """ + This is a minimal implementation of resample that always uses Affine. 
+ """ + if kwargs is not None: + a = Affine(affine=matrix, **kwargs) + else: + a = Affine(affine=matrix) + return a(img=data) diff --git a/tests/test_apply.py b/tests/test_apply.py index 088d275a12..19fca1dae2 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -19,22 +19,10 @@ from monai.transforms.lazy.functional import apply from monai.transforms.meta_matrix import MetaMatrix +from tests.utils import get_arange_img -def get_img(size, dtype=torch.float32, offset=0): - img = torch.zeros(size, dtype=dtype) - if len(size) == 2: - for j in range(size[0]): - for i in range(size[1]): - img[j, i] = i + j * size[0] + offset - else: - for k in range(size[0]): - for j in range(size[1]): - for i in range(size[2]): - img[k, j, i] = i + j * size[0] + k * size[0] * size[1] - return np.expand_dims(img, 0) - -def rotate_45_2D(): +def rotate_45_2d(): t = torch.eye(3) t[:, 0] = torch.FloatTensor([0, -1, 0]) t[:, 1] = torch.FloatTensor([1, 0, 0]) @@ -45,11 +33,8 @@ class TestApply(unittest.TestCase): def _test_apply_impl(self, tensor, pending_transforms): print(tensor.shape) - # for m in pending_transforms: - # print(m.matrix) - # print(m.metadata) result = apply(tensor, pending_transforms) - print(result) + self.assertListEqual(result[1], pending_transforms) def _test_apply_metatensor_impl(self, tensor, pending_transforms, pending_as_parameter): tensor_ = convert_to_tensor(tensor) @@ -62,7 +47,7 @@ def _test_apply_metatensor_impl(self, tensor, pending_transforms, pending_as_par raise NotImplementedError() SINGLE_TRANSFORM_CASES = [ - (torch.randn((1, 16, 16)), [MetaMatrix(rotate_45_2D(), {"id": "rotate"})]) + (torch.randn((1, 16, 16)), [MetaMatrix(rotate_45_2d(), {"id": "rotate"})]) ] def test_apply_single_transform(self): diff --git a/tests/test_resample.py b/tests/test_resample.py new file mode 100644 index 0000000000..ce9c8ede9e --- /dev/null +++ b/tests/test_resample.py @@ -0,0 +1,46 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np + +import torch + +from monai.transforms.utility.functional import resample +from monai.utils import convert_to_tensor + +from monai.transforms.lazy.functional import apply +from monai.transforms.meta_matrix import MetaMatrix + +from tests.utils import get_arange_img + + +def rotate_45_2d(): + t = torch.eye(3) + t[:, 0] = torch.FloatTensor([0, -1, 0]) + t[:, 1] = torch.FloatTensor([1, 0, 0]) + return t + + +class TestResampleFunction(unittest.TestCase): + + def _test_resample_function_impl(self, img, matrix): + result = resample(convert_to_tensor(img), matrix) + print(result) + + RESAMPLE_FUNCTION_CASES = [ + (get_arange_img((1, 16, 16)), rotate_45_2d()) + ] + + def test_resample_function(self): + for case in self.RESAMPLE_FUNCTION_CASES: + self._test_resample_function_impl(*case) diff --git a/tests/utils.py b/tests/utils.py index b16b4b13fb..bbceb9cfc4 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -347,6 +347,24 @@ def make_rand_affine(ndim: int = 3, random_state: Optional[np.random.RandomState return af +def get_arange_img(size, dtype=torch.float32, offset=0): + """ + Returns an 2d or 3d image as a numpy tensor (complete with channel as dim 0) + with contents that iterate like an arange. + """ + img = torch.zeros(size, dtype=dtype) + if len(size) == 2: + for j in range(size[0]): + for i in range(size[1]): + img[j, i] = i + j * size[0] + offset + else: + for k in range(size[0]): + for j in range(size[1]): + for i in range(size[2]): + img[k, j, i] = i + j * size[0] + k * size[0] * size[1] + offset + return np.expand_dims(img, 0) + + class DistTestCase(unittest.TestCase): """ testcase without _outcome, so that it's picklable. From 98143d047610da244507d8c71c01824938a0900e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 28 Oct 2022 14:56:28 +0000 Subject: [PATCH 39/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_apply.py | 2 -- tests/test_resample.py | 3 --- 2 files changed, 5 deletions(-) diff --git a/tests/test_apply.py b/tests/test_apply.py index 19fca1dae2..2447d8a3af 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -11,7 +11,6 @@ import unittest -import numpy as np import torch from monai.utils import convert_to_tensor @@ -19,7 +18,6 @@ from monai.transforms.lazy.functional import apply from monai.transforms.meta_matrix import MetaMatrix -from tests.utils import get_arange_img def rotate_45_2d(): diff --git a/tests/test_resample.py b/tests/test_resample.py index ce9c8ede9e..c032481389 100644 --- a/tests/test_resample.py +++ b/tests/test_resample.py @@ -11,15 +11,12 @@ import unittest -import numpy as np import torch from monai.transforms.utility.functional import resample from monai.utils import convert_to_tensor -from monai.transforms.lazy.functional import apply -from monai.transforms.meta_matrix import MetaMatrix from tests.utils import get_arange_img From eae58258e105f2cc25193a7a27f95aa8c03ecb6d Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 18:14:18 +0100 Subject: [PATCH 40/52] Apply and MetaMatrix; partial functionality Signed-off-by: Ben Murray --- monai/transforms/__init__.py | 1 + monai/transforms/apply.py | 229 +++++++++++++++++++++++++++++++++++ 2 files changed, 230 insertions(+) create mode 100644 monai/transforms/apply.py diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 307b5cda28..fb57831c17 100644 --- 
a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -10,6 +10,7 @@ # limitations under the License. from .adaptors import FunctionSignature, adaptor, apply_alias, to_kwargs +from .apply import Apply, apply from .compose import Compose, OneOf from .croppad.array import ( BorderPad, diff --git a/monai/transforms/apply.py b/monai/transforms/apply.py new file mode 100644 index 0000000000..c58505e929 --- /dev/null +++ b/monai/transforms/apply.py @@ -0,0 +1,229 @@ +from typing import Optional, Sequence, Union + +import itertools as it + +import numpy as np + +import torch + +from monai.config import DtypeLike +from monai.data import MetaTensor +from monai.transforms import Affine +from monai.transforms.inverse import InvertibleTransform +from monai.transforms.transform import MapTransform +from monai.utils import GridSampleMode, GridSamplePadMode +from monai.transforms.utils import get_backend_from_tensor_like, get_device_from_tensor_like, dtypes_to_str_or_identity +from monai.transforms.meta_matrix import MatrixFactory, MetaMatrix, Matrix, matmul + +__all__ = [ + "apply", + "Apply" +] + +# TODO: This should move to a common place to be shared with dictionary +#from monai.utils.type_conversion import dtypes_to_str_or_identity + +GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str] +GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str] +DtypeSequence = Union[Sequence[DtypeLike], DtypeLike] + + +# TODO: move to mapping_stack.py +def extents_from_shape(shape, dtype=np.float64): + extents = [[0, shape[i]] for i in range(1, len(shape))] + + extents = it.product(*extents) + return list(np.asarray(e + (1,), dtype=dtype) for e in extents) + + +# TODO: move to mapping_stack.py +def shape_from_extents( + src_shape: Sequence, + extents: Union[Sequence[np.ndarray], Sequence[torch.Tensor], np.ndarray, torch.Tensor] +): + if isinstance(extents, (list, tuple)): + if isinstance(extents[0], np.ndarray): + aextents = np.asarray(extents) + aextents = torch.from_numpy(aextents) + else: + aextents = torch.stack(extents) + else: + if isinstance(extents, np.ndarray): + aextents = torch.from_numpy(extents) + else: + aextents = extents + + mins = aextents.min(axis=0)[0] + maxes = aextents.max(axis=0)[0] + values = torch.round(maxes - mins).type(torch.IntTensor)[:-1] + return torch.cat((torch.as_tensor([src_shape[0]]), values)) + + +def metadata_is_compatible(value_1, value_2): + if value_1 is None: + return True + else: + if value_2 is None: + return True + return value_1 == value_2 + + +def metadata_dtype_is_compatible(value_1, value_2): + if value_1 is None: + return True + else: + if value_2 is None: + return True + + # if we are here, value_1 and value_2 are both set + # TODO: this is not a good enough solution + value_1_ = dtypes_to_str_or_identity(value_1) + value_2_ = dtypes_to_str_or_identity(value_2) + return value_1_ == value_2_ + + +def starting_matrix_and_extents(matrix_factory, data): + # set up the identity matrix and metadata + cumulative_matrix = matrix_factory.identity() + cumulative_extents = extents_from_shape(data.shape) + return cumulative_matrix, cumulative_extents + + +def prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype): + kwargs = {} + if cur_mode is not None: + kwargs['mode'] = cur_mode + if cur_padding_mode is not None: + kwargs['padding_mode'] = cur_padding_mode + if cur_device is not None: + kwargs['device'] = cur_device + if cur_dtype is not None: + 
kwargs['dtype'] = cur_dtype + + return kwargs + + +def matrix_from_matrix_container(matrix): + if isinstance(matrix, MetaMatrix): + return matrix.matrix.data + elif isinstance(matrix, Matrix): + return matrix.data + else: + return matrix + + +def apply(data: Union[torch.Tensor, MetaTensor], + pending: Optional[dict, list] = None): + + # TODO: if data is a dict, then pending must also be a dict + if isinstance(data, dict): + rd = dict() + for k, v in data.items(): + result = apply(v, pending) + rd[k] = result + return rd + + if isinstance(data, MetaTensor) or pending is not None: + pending_ = [] if pending is None else pending + else: + pending_ = data.pending_transforms + + if len(pending_) == 0: + return data + + dim_count = len(data.shape) - 1 + matrix_factory = MatrixFactory(dim_count, + get_backend_from_tensor_like(data), + get_device_from_tensor_like(data)) + + # set up the identity matrix and metadata + cumulative_matrix, cumulative_extents = starting_matrix_and_extents(matrix_factory, data) + + # set the various resampling parameters to an initial state + cur_mode = None + cur_padding_mode = None + cur_device = None + cur_dtype = None + cur_shape = data.shape + + for meta_matrix in pending_: + next_matrix = meta_matrix.data + # print("intermediate matrix\n", matrix_from_matrix_container(cumulative_matrix)) + cumulative_matrix = matmul(cumulative_matrix, next_matrix) + cumulative_extents = [matmul(e, cumulative_matrix) for e in cumulative_extents] + + new_mode = meta_matrix.metadata.get('mode', None) + new_padding_mode = meta_matrix.metadata.get('padding_mode', None) + new_device = meta_matrix.metadata.get('device', None) + new_dtype = meta_matrix.metadata.get('dtype', None) + new_shape = meta_matrix.metadata.get('shape_override', None) + + mode_compat = metadata_is_compatible(cur_mode, new_mode) + padding_mode_compat = metadata_is_compatible(cur_padding_mode, new_padding_mode) + device_compat = metadata_is_compatible(cur_device, new_device) + dtype_compat = metadata_dtype_is_compatible(cur_dtype, new_dtype) + + if (mode_compat is False or padding_mode_compat is False or + device_compat is False or dtype_compat is False): + # carry out an intermediate resample here due to incompatibility between arguments + kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype) + + cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix) + a = Affine(norm_coords=False, + affine=cumulative_matrix_, + **kwargs) + data, _ = a(img=data) + + cumulative_matrix, cumulative_extents =\ + starting_matrix_and_extents(matrix_factory, data) + + cur_mode = cur_mode if new_mode is None else new_mode + cur_padding_mode = cur_padding_mode if new_padding_mode is None else new_padding_mode + cur_device = cur_device if new_device is None else new_device + cur_dtype = cur_dtype if new_dtype is None else new_dtype + cur_shape = cur_shape if new_shape is None else new_shape + + kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype) + + cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix) + + # print(f"applying with cumulative matrix\n {cumulative_matrix_}") + a = Affine(norm_coords=False, + affine=cumulative_matrix_, + spatial_size=cur_shape[1:], + normalized=False, + **kwargs) + data, tx = a(img=data) + data.clear_pending_transforms() + + return data + + +# make Apply universal for arrays and dictionaries; it just calls through to functional apply +class Apply(InvertibleTransform): + + def __init__(self): + super().__init__() + + 
def __call__(self, *args, **kwargs): + return apply(*args, **kwargs) + + def inverse(self, data): + return NotImplementedError() + + +# class Applyd(MapTransform, InvertibleTransform): +# +# def __init__(self): +# super().__init__() +# +# def __call__( +# self, +# d: dict +# ): +# rd = dict() +# for k, v in d.items(): +# rd[k] = apply(v) +# +# def inverse(self, data): +# return NotImplementedError() From 6ba693b3290d5106b2ebc0cf5296270c6b482868 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 21:04:47 +0100 Subject: [PATCH 41/52] making import for MetaTensor more specific to avoid circular reference Signed-off-by: Ben Murray --- monai/transforms/apply.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/apply.py b/monai/transforms/apply.py index c58505e929..1b963adf8f 100644 --- a/monai/transforms/apply.py +++ b/monai/transforms/apply.py @@ -7,7 +7,7 @@ import torch from monai.config import DtypeLike -from monai.data import MetaTensor +from monai.data.meta_tensor import MetaTensor from monai.transforms import Affine from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform From 209eb301fff2b5f488edcc6956243aaf4d23da8c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 27 Oct 2022 17:16:45 +0000 Subject: [PATCH 42/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/apply.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/transforms/apply.py b/monai/transforms/apply.py index 1b963adf8f..d613540952 100644 --- a/monai/transforms/apply.py +++ b/monai/transforms/apply.py @@ -10,7 +10,6 @@ from monai.data.meta_tensor import MetaTensor from monai.transforms import Affine from monai.transforms.inverse import InvertibleTransform -from monai.transforms.transform import MapTransform from monai.utils import GridSampleMode, GridSamplePadMode from monai.transforms.utils import get_backend_from_tensor_like, get_device_from_tensor_like, dtypes_to_str_or_identity from monai.transforms.meta_matrix import MatrixFactory, MetaMatrix, Matrix, matmul From e7cb64811a1512c9325bddb4569b0ac4aeaa1fae Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 21:07:35 +0100 Subject: [PATCH 43/52] Making Affine import more specific in apply to avoid circular reference Signed-off-by: Ben Murray --- monai/transforms/apply.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/apply.py b/monai/transforms/apply.py index d613540952..9b56321eaf 100644 --- a/monai/transforms/apply.py +++ b/monai/transforms/apply.py @@ -8,7 +8,7 @@ from monai.config import DtypeLike from monai.data.meta_tensor import MetaTensor -from monai.transforms import Affine +from monai.transforms.spatial.array import Affine from monai.transforms.inverse import InvertibleTransform from monai.utils import GridSampleMode, GridSamplePadMode from monai.transforms.utils import get_backend_from_tensor_like, get_device_from_tensor_like, dtypes_to_str_or_identity From bfd94584a995a0d37099250f4d9812d326c3db69 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 21:10:02 +0100 Subject: [PATCH 44/52] Fixing typing signature issue on apply method Signed-off-by: Ben Murray --- monai/transforms/apply.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/apply.py b/monai/transforms/apply.py index 9b56321eaf..b2fa8fc708 100644 --- 
a/monai/transforms/apply.py +++ b/monai/transforms/apply.py @@ -112,7 +112,7 @@ def matrix_from_matrix_container(matrix): def apply(data: Union[torch.Tensor, MetaTensor], - pending: Optional[dict, list] = None): + pending: Optional[Union[dict, list]] = None): # TODO: if data is a dict, then pending must also be a dict if isinstance(data, dict): From 45551d171cfda1b7e9c72c8f6314b4dea72095b5 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 21:20:39 +0100 Subject: [PATCH 45/52] Adding missing license boilerplate Signed-off-by: Ben Murray --- monai/transforms/apply.py | 11 +++++++++ tests/tempscript.py | 47 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) create mode 100644 tests/tempscript.py diff --git a/monai/transforms/apply.py b/monai/transforms/apply.py index b2fa8fc708..20d031b448 100644 --- a/monai/transforms/apply.py +++ b/monai/transforms/apply.py @@ -1,3 +1,14 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Optional, Sequence, Union import itertools as it diff --git a/tests/tempscript.py b/tests/tempscript.py new file mode 100644 index 0000000000..5a9e5945b6 --- /dev/null +++ b/tests/tempscript.py @@ -0,0 +1,47 @@ +import numpy as np + +import matplotlib.pyplot as plt + +import torch +from monai.utils import GridSampleMode, GridSamplePadMode + +from monai.transforms.atmostonce.apply import Applyd +from monai.transforms.atmostonce.dictionary import Rotated +from monai.transforms import Compose + + + +def test_rotate_tensor(): + r = Rotated(('image', 'label'), [0.0, 1.0, 0.0]) + + d = { + 'image': torch.zeros((1, 64, 64, 32), device="cpu", dtype=torch.float32), + 'label': torch.ones((1, 64, 64, 32), device="cpu", dtype=torch.int8) + } + d = r(d) + + for k, v in d.items(): + if isinstance(v, (np.ndarray, torch.Tensor)): + print(k, v.shape) + else: + print(k, v) + + +def test_rotate_apply(): + c = Compose([ + Rotated(('image', 'label'), (0.0, 3.14159265 / 2, 0.0)), + Applyd(('image', 'label'), + modes=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), + padding_modes=(GridSamplePadMode.BORDER, GridSamplePadMode.BORDER)) + ]) + + d = { + 'image': torch.zeros((1, 64, 64, 32), device="cpu", dtype=torch.float32), + 'label': torch.ones((1, 64, 64, 32), device="cpu", dtype=torch.int8) + } + plt.imshow(d['image'][0, ..., d['image'].shape[-1]//2]) + d = c(d) + plt.imshow(d['image'][0, ..., d['image'].shape[-1]//2]) + print(d['image'].shape) + +test_rotate_apply() From 3dd6164a27a2762f9ae03115ac64d493d8064d4f Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 21:32:27 +0100 Subject: [PATCH 46/52] Minimal docstrings for apply / Apply Signed-off-by: Ben Murray --- monai/transforms/apply.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/monai/transforms/apply.py b/monai/transforms/apply.py index 20d031b448..9290035c10 100644 --- a/monai/transforms/apply.py +++ b/monai/transforms/apply.py @@ -122,10 +122,17 @@ def matrix_from_matrix_container(matrix): return 
matrix -def apply(data: Union[torch.Tensor, MetaTensor], +def apply(data: Union[torch.Tensor, MetaTensor, dict], pending: Optional[Union[dict, list]] = None): - - # TODO: if data is a dict, then pending must also be a dict + """ + This method applies pending transforms to tensors. + + Args: + data: A torch Tensor, monai MetaTensor, or a dictionary containing Tensors / MetaTensors + pending: Optional arg containing pending transforms. This must be set if data is a Tensor + or dictionary of Tensors, but is optional if data is a MetaTensor / dictionary of + MetaTensors. + """ if isinstance(data, dict): rd = dict() for k, v in data.items(): @@ -211,7 +218,10 @@ def apply(data: Union[torch.Tensor, MetaTensor], # make Apply universal for arrays and dictionaries; it just calls through to functional apply class Apply(InvertibleTransform): - + """ + Apply wraps the apply method and can function as a Transform in either array or dictionary + mode. + """ def __init__(self): super().__init__() From 29123aa5c863c17b6174e4843f0d7227c8a213bd Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 21:48:44 +0100 Subject: [PATCH 47/52] Splitting apply.py into lazy/functional and lazy/array Signed-off-by: Ben Murray --- monai/transforms/__init__.py | 1 - monai/transforms/apply.py | 249 --------------------------------- monai/transforms/lazy/array.py | 1 - 3 files changed, 251 deletions(-) delete mode 100644 monai/transforms/apply.py diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index fb57831c17..307b5cda28 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -10,7 +10,6 @@ # limitations under the License. from .adaptors import FunctionSignature, adaptor, apply_alias, to_kwargs -from .apply import Apply, apply from .compose import Compose, OneOf from .croppad.array import ( BorderPad, diff --git a/monai/transforms/apply.py b/monai/transforms/apply.py deleted file mode 100644 index 9290035c10..0000000000 --- a/monai/transforms/apply.py +++ /dev/null @@ -1,249 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
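Once the split into lazy/functional and lazy/array lands, a minimal use of the functional entry point looks roughly like the sketch below. It is illustrative only: the imports mirror the ones used in tests/test_apply.py, and the "mode" / "padding_mode" metadata keys are the ones consumed in the resampling loop shown above.

    import torch

    from monai.transforms.lazy.functional import apply
    from monai.transforms.meta_matrix import MetaMatrix

    img = torch.randn(1, 16, 16)
    # a quarter-turn in the plane, expressed as a 3x3 homogeneous matrix
    quarter_turn = torch.tensor([[0.0, -1.0, 0.0],
                                 [1.0,  0.0, 0.0],
                                 [0.0,  0.0, 1.0]])
    pending = [MetaMatrix(quarter_turn, {"mode": "bilinear", "padding_mode": "zeros"})]
    # all pending matrices are fused into one cumulative matrix and resampled once via Affine
    result = apply(img, pending)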
- -from typing import Optional, Sequence, Union - -import itertools as it - -import numpy as np - -import torch - -from monai.config import DtypeLike -from monai.data.meta_tensor import MetaTensor -from monai.transforms.spatial.array import Affine -from monai.transforms.inverse import InvertibleTransform -from monai.utils import GridSampleMode, GridSamplePadMode -from monai.transforms.utils import get_backend_from_tensor_like, get_device_from_tensor_like, dtypes_to_str_or_identity -from monai.transforms.meta_matrix import MatrixFactory, MetaMatrix, Matrix, matmul - -__all__ = [ - "apply", - "Apply" -] - -# TODO: This should move to a common place to be shared with dictionary -#from monai.utils.type_conversion import dtypes_to_str_or_identity - -GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str] -GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str] -DtypeSequence = Union[Sequence[DtypeLike], DtypeLike] - - -# TODO: move to mapping_stack.py -def extents_from_shape(shape, dtype=np.float64): - extents = [[0, shape[i]] for i in range(1, len(shape))] - - extents = it.product(*extents) - return list(np.asarray(e + (1,), dtype=dtype) for e in extents) - - -# TODO: move to mapping_stack.py -def shape_from_extents( - src_shape: Sequence, - extents: Union[Sequence[np.ndarray], Sequence[torch.Tensor], np.ndarray, torch.Tensor] -): - if isinstance(extents, (list, tuple)): - if isinstance(extents[0], np.ndarray): - aextents = np.asarray(extents) - aextents = torch.from_numpy(aextents) - else: - aextents = torch.stack(extents) - else: - if isinstance(extents, np.ndarray): - aextents = torch.from_numpy(extents) - else: - aextents = extents - - mins = aextents.min(axis=0)[0] - maxes = aextents.max(axis=0)[0] - values = torch.round(maxes - mins).type(torch.IntTensor)[:-1] - return torch.cat((torch.as_tensor([src_shape[0]]), values)) - - -def metadata_is_compatible(value_1, value_2): - if value_1 is None: - return True - else: - if value_2 is None: - return True - return value_1 == value_2 - - -def metadata_dtype_is_compatible(value_1, value_2): - if value_1 is None: - return True - else: - if value_2 is None: - return True - - # if we are here, value_1 and value_2 are both set - # TODO: this is not a good enough solution - value_1_ = dtypes_to_str_or_identity(value_1) - value_2_ = dtypes_to_str_or_identity(value_2) - return value_1_ == value_2_ - - -def starting_matrix_and_extents(matrix_factory, data): - # set up the identity matrix and metadata - cumulative_matrix = matrix_factory.identity() - cumulative_extents = extents_from_shape(data.shape) - return cumulative_matrix, cumulative_extents - - -def prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype): - kwargs = {} - if cur_mode is not None: - kwargs['mode'] = cur_mode - if cur_padding_mode is not None: - kwargs['padding_mode'] = cur_padding_mode - if cur_device is not None: - kwargs['device'] = cur_device - if cur_dtype is not None: - kwargs['dtype'] = cur_dtype - - return kwargs - - -def matrix_from_matrix_container(matrix): - if isinstance(matrix, MetaMatrix): - return matrix.matrix.data - elif isinstance(matrix, Matrix): - return matrix.data - else: - return matrix - - -def apply(data: Union[torch.Tensor, MetaTensor, dict], - pending: Optional[Union[dict, list]] = None): - """ - This method applies pending transforms to tensors. 
- - Args: - data: A torch Tensor, monai MetaTensor, or a dictionary containing Tensors / MetaTensors - pending: Optional arg containing pending transforms. This must be set if data is a Tensor - or dictionary of Tensors, but is optional if data is a MetaTensor / dictionary of - MetaTensors. - """ - if isinstance(data, dict): - rd = dict() - for k, v in data.items(): - result = apply(v, pending) - rd[k] = result - return rd - - if isinstance(data, MetaTensor) or pending is not None: - pending_ = [] if pending is None else pending - else: - pending_ = data.pending_transforms - - if len(pending_) == 0: - return data - - dim_count = len(data.shape) - 1 - matrix_factory = MatrixFactory(dim_count, - get_backend_from_tensor_like(data), - get_device_from_tensor_like(data)) - - # set up the identity matrix and metadata - cumulative_matrix, cumulative_extents = starting_matrix_and_extents(matrix_factory, data) - - # set the various resampling parameters to an initial state - cur_mode = None - cur_padding_mode = None - cur_device = None - cur_dtype = None - cur_shape = data.shape - - for meta_matrix in pending_: - next_matrix = meta_matrix.data - # print("intermediate matrix\n", matrix_from_matrix_container(cumulative_matrix)) - cumulative_matrix = matmul(cumulative_matrix, next_matrix) - cumulative_extents = [matmul(e, cumulative_matrix) for e in cumulative_extents] - - new_mode = meta_matrix.metadata.get('mode', None) - new_padding_mode = meta_matrix.metadata.get('padding_mode', None) - new_device = meta_matrix.metadata.get('device', None) - new_dtype = meta_matrix.metadata.get('dtype', None) - new_shape = meta_matrix.metadata.get('shape_override', None) - - mode_compat = metadata_is_compatible(cur_mode, new_mode) - padding_mode_compat = metadata_is_compatible(cur_padding_mode, new_padding_mode) - device_compat = metadata_is_compatible(cur_device, new_device) - dtype_compat = metadata_dtype_is_compatible(cur_dtype, new_dtype) - - if (mode_compat is False or padding_mode_compat is False or - device_compat is False or dtype_compat is False): - # carry out an intermediate resample here due to incompatibility between arguments - kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype) - - cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix) - a = Affine(norm_coords=False, - affine=cumulative_matrix_, - **kwargs) - data, _ = a(img=data) - - cumulative_matrix, cumulative_extents =\ - starting_matrix_and_extents(matrix_factory, data) - - cur_mode = cur_mode if new_mode is None else new_mode - cur_padding_mode = cur_padding_mode if new_padding_mode is None else new_padding_mode - cur_device = cur_device if new_device is None else new_device - cur_dtype = cur_dtype if new_dtype is None else new_dtype - cur_shape = cur_shape if new_shape is None else new_shape - - kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype) - - cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix) - - # print(f"applying with cumulative matrix\n {cumulative_matrix_}") - a = Affine(norm_coords=False, - affine=cumulative_matrix_, - spatial_size=cur_shape[1:], - normalized=False, - **kwargs) - data, tx = a(img=data) - data.clear_pending_transforms() - - return data - - -# make Apply universal for arrays and dictionaries; it just calls through to functional apply -class Apply(InvertibleTransform): - """ - Apply wraps the apply method and can function as a Transform in either array or dictionary - mode. 
- """ - def __init__(self): - super().__init__() - - def __call__(self, *args, **kwargs): - return apply(*args, **kwargs) - - def inverse(self, data): - return NotImplementedError() - - -# class Applyd(MapTransform, InvertibleTransform): -# -# def __init__(self): -# super().__init__() -# -# def __call__( -# self, -# d: dict -# ): -# rd = dict() -# for k, v in d.items(): -# rd[k] = apply(v) -# -# def inverse(self, data): -# return NotImplementedError() diff --git a/monai/transforms/lazy/array.py b/monai/transforms/lazy/array.py index ae165cf566..f295d0cae1 100644 --- a/monai/transforms/lazy/array.py +++ b/monai/transforms/lazy/array.py @@ -20,7 +20,6 @@ class Apply(InvertibleTransform): Apply wraps the apply method and can function as a Transform in either array or dictionary mode. """ - def __init__(self): super().__init__() From b827cffe24c4bfc2e41b61b7f5e52d33998dc6e2 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 21:54:57 +0100 Subject: [PATCH 48/52] Removing spurious test file tempscript Signed-off-by: Ben Murray --- tests/tempscript.py | 47 --------------------------------------------- 1 file changed, 47 deletions(-) delete mode 100644 tests/tempscript.py diff --git a/tests/tempscript.py b/tests/tempscript.py deleted file mode 100644 index 5a9e5945b6..0000000000 --- a/tests/tempscript.py +++ /dev/null @@ -1,47 +0,0 @@ -import numpy as np - -import matplotlib.pyplot as plt - -import torch -from monai.utils import GridSampleMode, GridSamplePadMode - -from monai.transforms.atmostonce.apply import Applyd -from monai.transforms.atmostonce.dictionary import Rotated -from monai.transforms import Compose - - - -def test_rotate_tensor(): - r = Rotated(('image', 'label'), [0.0, 1.0, 0.0]) - - d = { - 'image': torch.zeros((1, 64, 64, 32), device="cpu", dtype=torch.float32), - 'label': torch.ones((1, 64, 64, 32), device="cpu", dtype=torch.int8) - } - d = r(d) - - for k, v in d.items(): - if isinstance(v, (np.ndarray, torch.Tensor)): - print(k, v.shape) - else: - print(k, v) - - -def test_rotate_apply(): - c = Compose([ - Rotated(('image', 'label'), (0.0, 3.14159265 / 2, 0.0)), - Applyd(('image', 'label'), - modes=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST), - padding_modes=(GridSamplePadMode.BORDER, GridSamplePadMode.BORDER)) - ]) - - d = { - 'image': torch.zeros((1, 64, 64, 32), device="cpu", dtype=torch.float32), - 'label': torch.ones((1, 64, 64, 32), device="cpu", dtype=torch.int8) - } - plt.imshow(d['image'][0, ..., d['image'].shape[-1]//2]) - d = c(d) - plt.imshow(d['image'][0, ..., d['image'].shape[-1]//2]) - print(d['image'].shape) - -test_rotate_apply() From 1cfe0a0e8d7ce2c77ca41fa033b15591df2a333f Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Thu, 27 Oct 2022 22:13:08 +0100 Subject: [PATCH 49/52] Auto formatting fixes Signed-off-by: Ben Murray --- monai/transforms/lazy/array.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/transforms/lazy/array.py b/monai/transforms/lazy/array.py index f295d0cae1..ae165cf566 100644 --- a/monai/transforms/lazy/array.py +++ b/monai/transforms/lazy/array.py @@ -20,6 +20,7 @@ class Apply(InvertibleTransform): Apply wraps the apply method and can function as a Transform in either array or dictionary mode. 
""" + def __init__(self): super().__init__() From adbd5fab9e9282f599208a698de3f071fdb7289d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 27 Oct 2022 22:11:41 +0000 Subject: [PATCH 50/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Ben Murray --- tests/test_apply.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_apply.py b/tests/test_apply.py index 2447d8a3af..bc7798ef76 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -11,7 +11,6 @@ import unittest - import torch from monai.utils import convert_to_tensor @@ -19,7 +18,6 @@ from monai.transforms.meta_matrix import MetaMatrix - def rotate_45_2d(): t = torch.eye(3) t[:, 0] = torch.FloatTensor([0, -1, 0]) From e62a3162900fce9ba06cf0d62ad049a65bc51db1 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 28 Oct 2022 14:57:52 +0100 Subject: [PATCH 51/52] Further array functionality and testing; waiting on PR #5107 Signed-off-by: Ben Murray --- tests/test_apply.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_apply.py b/tests/test_apply.py index bc7798ef76..019bae42f1 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -42,6 +42,16 @@ def _test_apply_metatensor_impl(self, tensor, pending_transforms, pending_as_par # tensor_.push_pending(p) raise NotImplementedError() + def _test_apply_metatensor_impl(self, tensor, pending_transforms, pending_as_parameter): + tensor_ = convert_to_tensor(tensor) + if pending_as_parameter: + result = apply(tensor_, pending_transforms) + else: + for p in pending_transforms: + # TODO: cannot do the next part until the MetaTensor PR #5107 is in + # tensor_.push_pending(p) + raise NotImplementedError() + SINGLE_TRANSFORM_CASES = [ (torch.randn((1, 16, 16)), [MetaMatrix(rotate_45_2d(), {"id": "rotate"})]) ] From d10d74ff1a3ad52c0387b1a3918f75cd3f6c61c3 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 28 Oct 2022 15:53:06 +0100 Subject: [PATCH 52/52] Adding resample function for unified resampling and test placeholder Signed-off-by: Ben Murray --- tests/test_resample.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_resample.py b/tests/test_resample.py index c032481389..4e33517bd1 100644 --- a/tests/test_resample.py +++ b/tests/test_resample.py @@ -11,12 +11,11 @@ import unittest - import torch from monai.transforms.utility.functional import resample -from monai.utils import convert_to_tensor +from monai.utils import convert_to_tensor from tests.utils import get_arange_img