From d557549c639d091dcddc0ea8e378bb41453c67d8 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Sun, 30 Oct 2022 11:14:34 +0000 Subject: [PATCH 01/19] Commit of functionality on replacement lr_apply branch due to issues with rabasing and automatic signatures Signed-off-by: Ben Murray --- monai/transforms/__init__.py | 11 ++ monai/transforms/lazy/__init__.py | 10 + monai/transforms/lazy/array.py | 48 +++++ monai/transforms/lazy/functional.py | 201 +++++++++++++++++++ monai/transforms/meta_matrix.py | 285 +++++++++++++++++++++++++++ monai/transforms/utility/resample.py | 32 +++ monai/transforms/utils.py | 206 +++++++++++++++++++ tests/test_apply.py | 73 +++++++ tests/test_matmul.py | 143 ++++++++++++++ tests/test_resample.py | 46 +++++ tests/utils.py | 18 ++ 11 files changed, 1073 insertions(+) create mode 100644 monai/transforms/lazy/__init__.py create mode 100644 monai/transforms/lazy/array.py create mode 100644 monai/transforms/lazy/functional.py create mode 100644 monai/transforms/meta_matrix.py create mode 100644 monai/transforms/utility/resample.py create mode 100644 tests/test_apply.py create mode 100644 tests/test_matmul.py create mode 100644 tests/test_resample.py diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 9cabc167a7..6416414c05 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -227,6 +227,15 @@ from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict from .io.array import SUPPORTED_READERS, LoadImage, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict +from .lazy.array import Apply +from .lazy.functional import apply +from .meta_matrix import ( + Grid, + matmul, + Matrix, + MatrixFactory, + MetaMatrix, +) from .meta_utility.dictionary import ( FromMetaTensord, FromMetaTensorD, @@ -632,6 +641,8 @@ generate_label_classes_crop_centers, generate_pos_neg_label_crop_centers, generate_spatial_bounding_box, + get_backend_from_tensor_like, + get_device_from_tensor_like, get_extreme_points, get_largest_connected_component_mask, get_number_image_type_conversions, diff --git a/monai/transforms/lazy/__init__.py b/monai/transforms/lazy/__init__.py new file mode 100644 index 0000000000..1e97f89407 --- /dev/null +++ b/monai/transforms/lazy/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/transforms/lazy/array.py b/monai/transforms/lazy/array.py new file mode 100644 index 0000000000..e271d07703 --- /dev/null +++ b/monai/transforms/lazy/array.py @@ -0,0 +1,48 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from monai.transforms.lazy.functional import apply
+from monai.transforms.inverse import InvertibleTransform
+
+__all__ = ["Apply"]
+
+
+class Apply(InvertibleTransform):
+    """
+    Apply wraps the apply method and can function as a Transform in either array or dictionary
+    mode.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def __call__(self, *args, **kwargs):
+        return apply(*args, **kwargs)
+
+    def inverse(self, data):
+        raise NotImplementedError()
+
+
+# class Applyd(MapTransform, InvertibleTransform):
+#
+#     def __init__(self):
+#         super().__init__()
+#
+#     def __call__(
+#         self,
+#         d: dict
+#     ):
+#         rd = dict()
+#         for k, v in d.items():
+#             rd[k] = apply(v)
+#
+#     def inverse(self, data):
+#         raise NotImplementedError()
\ No newline at end of file
diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py
new file mode 100644
index 0000000000..93594bfa74
--- /dev/null
+++ b/monai/transforms/lazy/functional.py
@@ -0,0 +1,201 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
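+
+# This module implements the core of lazy resampling: pending transforms are
+# represented as homogeneous matrices (`MetaMatrix`, defined in
+# monai/transforms/meta_matrix.py) that are fused into a single cumulative
+# matrix by `apply`, with intermediate resamples performed only where
+# successive transforms disagree on arguments such as mode, padding_mode,
+# device or dtype.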
+ +import itertools as it +from typing import Optional, Sequence, Union + +import numpy as np +import torch + +from monai.config import DtypeLike +from monai.data.meta_tensor import MetaTensor +from monai.transforms.meta_matrix import Matrix, MatrixFactory, MetaMatrix, matmul +from monai.transforms.spatial.array import Affine +from monai.transforms.utils import dtypes_to_str_or_identity, get_backend_from_tensor_like, get_device_from_tensor_like +from monai.utils import GridSampleMode, GridSamplePadMode + +__all__ = ["apply"] + +# TODO: This should move to a common place to be shared with dictionary +# from monai.utils.type_conversion import dtypes_to_str_or_identity + +GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str] +GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str] +DtypeSequence = Union[Sequence[DtypeLike], DtypeLike] + + +# TODO: move to mapping_stack.py +def extents_from_shape(shape, dtype=np.float32): + extents = [[0, shape[i]] for i in range(1, len(shape))] + + extents = it.product(*extents) + return [np.asarray(e + (1,), dtype=dtype) for e in extents] + + +# TODO: move to mapping_stack.py +def shape_from_extents( + src_shape: Sequence, extents: Union[Sequence[np.ndarray], Sequence[torch.Tensor], np.ndarray, torch.Tensor] +): + if isinstance(extents, (list, tuple)): + if isinstance(extents[0], np.ndarray): + aextents = np.asarray(extents) + aextents = torch.from_numpy(aextents) + else: + aextents = torch.stack(extents) + else: + if isinstance(extents, np.ndarray): + aextents = torch.from_numpy(extents) + else: + aextents = extents + + mins = aextents.min(axis=0)[0] + maxes = aextents.max(axis=0)[0] + values = torch.round(maxes - mins).type(torch.IntTensor)[:-1] + return torch.cat((torch.as_tensor([src_shape[0]]), values)) + + +def metadata_is_compatible(value_1, value_2): + if value_1 is None: + return True + else: + if value_2 is None: + return True + return value_1 == value_2 + + +def metadata_dtype_is_compatible(value_1, value_2): + if value_1 is None: + return True + else: + if value_2 is None: + return True + + # if we are here, value_1 and value_2 are both set + # TODO: this is not a good enough solution + value_1_ = dtypes_to_str_or_identity(value_1) + value_2_ = dtypes_to_str_or_identity(value_2) + return value_1_ == value_2_ + + +def starting_matrix_and_extents(matrix_factory, data): + # set up the identity matrix and metadata + cumulative_matrix = matrix_factory.identity() + cumulative_extents = extents_from_shape(data.shape) + return cumulative_matrix, cumulative_extents + + +def prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype): + kwargs = {} + if cur_mode is not None: + kwargs["mode"] = cur_mode + if cur_padding_mode is not None: + kwargs["padding_mode"] = cur_padding_mode + if cur_device is not None: + kwargs["device"] = cur_device + if cur_dtype is not None: + kwargs["dtype"] = cur_dtype + + return kwargs + + +def matrix_from_matrix_container(matrix): + if isinstance(matrix, MetaMatrix): + return matrix.matrix.data + elif isinstance(matrix, Matrix): + return matrix.data + else: + return matrix + + +def apply(data: Union[torch.Tensor, MetaTensor, dict], pending: Optional[Union[dict, list]] = None): + """ + This method applies pending transforms to tensors. + Args: + data: A torch Tensor, monai MetaTensor, or a dictionary containing Tensors / MetaTensors + pending: Optional arg containing pending transforms. 
This must be set if data is a Tensor + or dictionary of Tensors, but is optional if data is a MetaTensor / dictionary of + MetaTensors. + """ + if isinstance(data, dict): + rd = dict() + for k, v in data.items(): + result = apply(v, pending) + rd[k] = result + return rd + + if isinstance(data, MetaTensor) and pending is None: + pending_ = data.pending_transforms + else: + pending_ = [] if pending is None else pending + + if len(pending_) == 0: + return data + + dim_count = len(data.shape) - 1 + matrix_factory = MatrixFactory(dim_count, get_backend_from_tensor_like(data), get_device_from_tensor_like(data)) + + # set up the identity matrix and metadata + cumulative_matrix, cumulative_extents = starting_matrix_and_extents(matrix_factory, data) + + # set the various resampling parameters to an initial state + cur_mode = None + cur_padding_mode = None + cur_device = None + cur_dtype = None + cur_shape = data.shape + + for meta_matrix in pending_: + next_matrix = meta_matrix.matrix + # print("intermediate matrix\n", matrix_from_matrix_container(cumulative_matrix)) + cumulative_matrix = matmul(cumulative_matrix, next_matrix) + cumulative_extents = [matmul(e, cumulative_matrix) for e in cumulative_extents] + + new_mode = meta_matrix.metadata.get("mode", None) + new_padding_mode = meta_matrix.metadata.get("padding_mode", None) + new_device = meta_matrix.metadata.get("device", None) + new_dtype = meta_matrix.metadata.get("dtype", None) + new_shape = meta_matrix.metadata.get("shape_override", None) + + mode_compat = metadata_is_compatible(cur_mode, new_mode) + padding_mode_compat = metadata_is_compatible(cur_padding_mode, new_padding_mode) + device_compat = metadata_is_compatible(cur_device, new_device) + dtype_compat = metadata_dtype_is_compatible(cur_dtype, new_dtype) + + if mode_compat is False or padding_mode_compat is False or device_compat is False or dtype_compat is False: + # carry out an intermediate resample here due to incompatibility between arguments + kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype) + + cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix) + a = Affine(norm_coords=False, affine=cumulative_matrix_, **kwargs) + data, _ = a(img=data) + + cumulative_matrix, cumulative_extents = starting_matrix_and_extents(matrix_factory, data) + + cur_mode = cur_mode if new_mode is None else new_mode + cur_padding_mode = cur_padding_mode if new_padding_mode is None else new_padding_mode + cur_device = cur_device if new_device is None else new_device + cur_dtype = cur_dtype if new_dtype is None else new_dtype + cur_shape = cur_shape if new_shape is None else new_shape + + kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype) + + cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix) + + # print(f"applying with cumulative matrix\n {cumulative_matrix_}") + a = Affine(norm_coords=False, affine=cumulative_matrix_, spatial_size=cur_shape[1:], normalized=False, **kwargs) + data, tx = a(img=data) + if isinstance(data, MetaTensor): + data.clear_pending_transforms() + for p in pending_: + data.affine = p.matrix.data + data.push_applied_operation(p) + + return data, pending_ diff --git a/monai/transforms/meta_matrix.py b/monai/transforms/meta_matrix.py new file mode 100644 index 0000000000..845fe3f2c5 --- /dev/null +++ b/monai/transforms/meta_matrix.py @@ -0,0 +1,285 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file 
except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Sequence, Union + +import numpy as np +import torch +from monai.transforms.utils import _create_rotate, _create_rotate_90, _create_flip, _create_shear, _create_scale, \ + _create_translate + +from monai.transforms.utils import get_backend_from_tensor_like, get_device_from_tensor_like +from monai.utils import TransformBackends + +from monai.config import NdarrayOrTensor + +__all__ = ["Grid", "matmul", "Matrix", "MatrixFactory", "MetaMatrix"] + +def is_matrix_shaped(data): + + return ( + len(data.shape) == 2 and data.shape[0] in (3, 4) and data.shape[1] in (3, 4) and data.shape[0] == data.shape[1] + ) + + +def is_grid_shaped(data): + + return len(data.shape) == 3 and data.shape[0] == 3 or len(data.shape) == 4 and data.shape[0] == 4 + + +class MatrixFactory: + + def __init__(self, + dims: int, + backend: TransformBackends, + device: Optional[torch.device] = None): + + if backend == TransformBackends.NUMPY: + if device is not None: + raise ValueError("'device' must be None with TransformBackends.NUMPY") + self._device = None + self._sin = lambda th: np.sin(th, dtype=np.float32) + self._cos = lambda th: np.cos(th, dtype=np.float32) + self._eye = lambda th: np.eye(th, dtype=np.float32) + self._diag = lambda th: np.diag(th, dtype=np.float32) + else: + if device is None: + raise ValueError("'device' must be set with TransformBackends.TORCH") + self._device = device + self._sin = lambda th: torch.sin(torch.as_tensor(th, + dtype=torch.float32, + device=self._device)) + self._cos = lambda th: torch.cos(torch.as_tensor(th, + dtype=torch.float32, + device=self._device)) + self._eye = lambda rank: torch.eye(rank, + device=self._device, + dtype=torch.float32); + self._diag = lambda size: torch.diag(torch.as_tensor(size, + device=self._device, + dtype=torch.float32)) + + self._backend = backend + self._dims = dims + + @staticmethod + def from_tensor(data): + return MatrixFactory(len(data.shape)-1, + get_backend_from_tensor_like(data), + get_device_from_tensor_like(data)) + + def identity(self): + matrix = self._eye(self._dims + 1) + return MetaMatrix(matrix, {}) + + def rotate_euler(self, radians: Union[Sequence[float], float], **extra_args): + matrix = _create_rotate(self._dims, radians, self._sin, self._cos, self._eye) + return MetaMatrix(matrix, extra_args) + + def rotate_90(self, rotations, axis, **extra_args): + matrix = _create_rotate_90(self._dims, rotations, axis) + return MetaMatrix(matrix, extra_args) + + def flip(self, axis, **extra_args): + matrix = _create_flip(self._dims, axis, self._eye) + return MetaMatrix(matrix, extra_args) + + def shear(self, coefs: Union[Sequence[float], float], **extra_args): + matrix = _create_shear(self._dims, coefs, self._eye) + return MetaMatrix(matrix, extra_args) + + def scale(self, factors: Union[Sequence[float], float], **extra_args): + matrix = _create_scale(self._dims, factors, self._diag) + return MetaMatrix(matrix, extra_args) + + def translate(self, offsets: Union[Sequence[float], float], **extra_args): + matrix = _create_translate(self._dims, offsets, self._eye) + return 
MetaMatrix(matrix, extra_args)
+
+
+def ensure_tensor(data: NdarrayOrTensor):
+    if isinstance(data, torch.Tensor):
+        return data
+
+    return torch.as_tensor(data)
+
+
+class Matrix:
+    def __init__(self, matrix: NdarrayOrTensor):
+        self.data = ensure_tensor(matrix)
+
+    # def __matmul__(self, other):
+    #     if isinstance(other, Matrix):
+    #         other_matrix = other.data
+    #     else:
+    #         other_matrix = other
+    #     return self.data @ other_matrix
+    #
+    # def __rmatmul__(self, other):
+    #     return other.__matmul__(self.data)
+
+
+class Grid:
+    def __init__(self, grid):
+        self.data = ensure_tensor(grid)
+
+    # def __matmul__(self, other):
+    #     raise NotImplementedError()
+
+
+class MetaMatrix:
+    def __init__(self, matrix: Union[NdarrayOrTensor, Matrix, Grid], metadata: Optional[dict] = None):
+
+        if not isinstance(matrix, (Matrix, Grid)):
+            if len(matrix.shape) == 2:
+                if matrix.shape[0] != matrix.shape[1] or matrix.shape[0] not in (3, 4):
+                    raise ValueError(
+                        "If 'matrix' is passed a numpy ndarray/torch Tensor, it must"
+                        f" be 3x3 or 4x4 ('matrix' has shape {matrix.shape})"
+                    )
+            matrix_ = Matrix(matrix)
+        else:
+            matrix_ = matrix
+        self.matrix = matrix_
+
+        self.metadata = metadata or {}
+
+    def __matmul__(self, other):
+        if isinstance(other, MetaMatrix):
+            other_ = other.matrix
+        else:
+            other_ = other
+        return MetaMatrix(self.matrix @ other_)
+
+    def __rmatmul__(self, other):
+        if isinstance(other, MetaMatrix):
+            other_ = other.matrix
+        else:
+            other_ = other
+        return MetaMatrix(other_ @ self.matrix)
+
+
+def matmul(
+    left: Union[MetaMatrix, Grid, Matrix, NdarrayOrTensor], right: Union[MetaMatrix, Grid, Matrix, NdarrayOrTensor]
+):
+    matrix_types = (MetaMatrix, Grid, Matrix, torch.Tensor, np.ndarray)
+
+    if not isinstance(left, matrix_types):
+        raise TypeError(f"'left' must be one of {matrix_types} but is {type(left)}")
+    if not isinstance(right, matrix_types):
+        raise TypeError(f"'right' must be one of {matrix_types} but is {type(right)}")
+
+    left_ = left.matrix if isinstance(left, MetaMatrix) else left
+    right_ = right.matrix if isinstance(right, MetaMatrix) else right
+
+    # TODO: it might be better to not return a metamatrix, unless we pass in the resulting
+    # metadata also
+    put_in_metamatrix = isinstance(left, MetaMatrix) or isinstance(right, MetaMatrix)
+
+    put_in_grid = isinstance(left, Grid) or isinstance(right, Grid)
+
+    put_in_matrix = isinstance(left, Matrix) or isinstance(right, Matrix)
+    put_in_matrix = False if put_in_grid is True else put_in_matrix
+
+    promote_to_tensor = not (isinstance(left_, np.ndarray) and isinstance(right_, np.ndarray))
+
+    left_raw = left_.data if isinstance(left_, (Matrix, Grid)) else left_
+    right_raw = right_.data if isinstance(right_, (Matrix, Grid)) else right_
+
+    if promote_to_tensor:
+        left_raw = torch.as_tensor(left_raw)
+        right_raw = torch.as_tensor(right_raw)
+
+    if isinstance(left_, Grid):
+        if isinstance(right_, Grid):
+            raise RuntimeError("Unable to matrix multiply two Grids")
+        else:
+            result = matmul_grid_matrix(left_raw, right_raw)
+    else:
+        if isinstance(right_, Grid):
+            result = matmul_matrix_grid(left_raw, right_raw)
+        else:
+            result = matmul_matrix_matrix(left_raw, right_raw)
+
+    if put_in_grid:
+        result = Grid(result)
+    elif put_in_matrix:
+        result = Matrix(result)
+
+    if put_in_metamatrix:
+        result = MetaMatrix(result)
+
+    return result
+
+
+def matmul_matrix_grid(left: NdarrayOrTensor, right: NdarrayOrTensor):
+    if not is_matrix_shaped(left):
+        raise ValueError(f"'left' should be a 2D or 3D homogeneous matrix but has shape {left.shape}")
+
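+    # a grid is a dense field of homogeneous coordinates: shape (3, H, W) for
+    # 2D data or (4, H, W, D) for 3D data (see `is_grid_shaped` above)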
+    if not is_grid_shaped(right):
+        raise ValueError(
+            "'right' should be a 3D array with shape[0] == 3 or a "
+            f"4D array with shape[0] == 4 but has shape {right.shape}"
+        )
+
+    # flatten the grid to take advantage of torch batch matrix multiply
+    right_flat = right.reshape(right.shape[0], -1)
+    result_flat = left @ right_flat
+    # restore the grid shape
+    result = result_flat.reshape((-1,) + right.shape[1:])
+    return result
+
+
+def matmul_grid_matrix(left: NdarrayOrTensor, right: NdarrayOrTensor):
+    if not is_grid_shaped(left):
+        raise ValueError(
+            "'left' should be a 3D array with shape[0] == 3 or a "
+            f"4D array with shape[0] == 4 but has shape {left.shape}"
+        )
+
+    if not is_matrix_shaped(right):
+        raise ValueError(f"'right' should be a 2D or 3D homogeneous matrix but has shape {right.shape}")
+
+    try:
+        inv_matrix = torch.inverse(right)
+    except RuntimeError:
+        # the matrix is not invertible, so we will have to perform a slow grid to matrix operation
+        return matmul_grid_matrix_slow(left, right)
+
+    # invert the matrix and swap the arguments, taking advantage of
+    # matrix @ vector == vector_transposed @ matrix_inverse
+    return matmul_matrix_grid(inv_matrix, left)
+
+
+def matmul_grid_matrix_slow(left: NdarrayOrTensor, right: NdarrayOrTensor):
+    if not is_grid_shaped(left):
+        raise ValueError(
+            "'left' should be a 3D array with shape[0] == 3 or a "
+            f"4D array with shape[0] == 4 but has shape {left.shape}"
+        )
+
+    if not is_matrix_shaped(right):
+        raise ValueError(f"'right' should be a 2D or 3D homogeneous matrix but has shape {right.shape}")
+
+    flat_left = left.reshape(left.shape[0], -1)
+    result_flat = torch.zeros_like(flat_left)
+    for i in range(flat_left.shape[1]):
+        vector = flat_left[:, i][None, :]
+        result_vector = vector @ right
+        result_flat[:, i] = result_vector[0, :]
+
+    # restore the grid shape
+    result = result_flat.reshape((-1,) + left.shape[1:])
+    return result
+
+
+def matmul_matrix_matrix(left: NdarrayOrTensor, right: NdarrayOrTensor):
+    return left @ right
diff --git a/monai/transforms/utility/resample.py b/monai/transforms/utility/resample.py
new file mode 100644
index 0000000000..7aec22c723
--- /dev/null
+++ b/monai/transforms/utility/resample.py
@@ -0,0 +1,32 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Union
+
+import torch
+from monai.transforms import Affine
+
+from monai.config import NdarrayOrTensor
+from monai.transforms.meta_matrix import Grid, Matrix
+
+
+def resample(
+    data: torch.Tensor,
+    matrix: Union[NdarrayOrTensor, Matrix, Grid],
+    kwargs: Optional[dict] = None
+):
+    """
+    This is a minimal implementation of resample that always uses Affine.
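+    `kwargs`, when provided, is passed through unchanged to the `Affine`
+    constructor (for example "mode", "padding_mode" or "dtype").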
+    """
+    if kwargs is not None:
+        a = Affine(affine=matrix, **kwargs)
+    else:
+        a = Affine(affine=matrix)
+    return a(img=data)
diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py
index e96d906f20..bdf392779a 100644
--- a/monai/transforms/utils.py
+++ b/monai/transforms/utils.py
@@ -889,6 +889,124 @@ def _create_translate(
     return array_func(affine)  # type: ignore
 
 
+def _create_rotate_90(
+    spatial_dims: int,
+    axis: Tuple[int, int],
+    steps: Optional[int] = 1,
+    eye_func: Callable = np.eye
+) -> NdarrayOrTensor:
+
+    values = [(1, 0, 0, 1),
+              (0, -1, 1, 0),
+              (-1, 0, 0, -1),
+              (0, 1, -1, 0)]
+
+    if spatial_dims == 2:
+        if axis != (0, 1):
+            raise ValueError(f"if 'spatial_dims' is 2, 'axis' must be (0, 1) but is {axis}")
+    elif spatial_dims == 3:
+        if axis not in ((0, 1), (0, 2), (1, 2)):
+            raise ValueError("if 'spatial_dims' is 3, 'axis' must be (0, 1), (0, 2), or (1, 2) "
+                             f"but is {axis}")
+    else:
+        raise ValueError(f"'spatial_dims' must be 2 or 3 but is {spatial_dims}")
+
+    steps_ = steps % 4
+
+    affine = eye_func(spatial_dims + 1)
+
+    if spatial_dims == 2:
+        a, b = 0, 1
+    else:
+        a, b = axis
+
+    affine[a, a], affine[a, b], affine[b, a], affine[b, b] = values[steps_]
+    return affine
+
+
+def create_rotate_90(
+    spatial_dims: int,
+    axis: Tuple[int, int],
+    steps: Optional[int] = 1,
+    device: Optional[torch.device] = None,
+    backend: str = TransformBackends.NUMPY,
+) -> NdarrayOrTensor:
+    """
+    create a 2D or 3D matrix that rotates by multiples of 90 degrees
+    Args:
+        spatial_dims: {``2``, ``3``} spatial rank
+        axis: the pair of axes that defines the plane of rotation; must be
+            (0, 1) when `spatial_dims` is 2, and one of (0, 1), (0, 2) or
+            (1, 2) when `spatial_dims` is 3.
+        steps: the number of 90 degree rotations to apply, taken modulo 4.
+        device: device to compute and store the output (when the backend is "torch").
+        backend: APIs to use, ``numpy`` or ``torch``.
+    Raises:
+        ValueError: When ``axis`` is not valid for the given ``spatial_dims``.
+        ValueError: When ``spatial_dims`` is not one of [2, 3].
+ """ + _backend = look_up_option(backend, TransformBackends) + if _backend == TransformBackends.NUMPY: + return _create_rotate_90( + spatial_dims=spatial_dims, + axis=axis, + steps=steps, + eye_func=np.eye) + if _backend == TransformBackends.TORCH: + return _create_rotate_90( + spatial_dims=spatial_dims, + axis=axis, + steps=steps, + eye_func=lambda rank: torch.eye(rank, device=device), + ) + raise ValueError(f"backend {backend} is not supported") + + +def _create_flip( + spatial_dims: int, + spatial_axis: Union[Sequence[int], int], + eye_func: Callable = np.eye +): + affine = eye_func(spatial_dims + 1) + if isinstance(spatial_axis, int): + if spatial_axis < -spatial_dims or spatial_axis >= spatial_dims: + raise ValueError("'spatial_axis' values must be between " + f"{-spatial_dims} and {spatial_dims-1} inclusive " + f"('spatial_axis' is {spatial_axis})") + affine[spatial_axis, spatial_axis] = -1 + else: + if any((s < -spatial_dims or s >= spatial_dims) for s in spatial_axis): + raise ValueError("'spatial_axis' values must be between " + f"{-spatial_dims} and {spatial_dims-1} inclusive " + f"('spatial_axis' is {spatial_axis})") + + for i in range(spatial_dims): + if i in spatial_axis: + affine[i, i] = -1 + + return affine + + +def create_flip( + spatial_dims: int, + spatial_axis: Union[Sequence[int], int], + device: Optional[torch.device] = None, + backend: str = TransformBackends.NUMPY, +) -> NdarrayOrTensor: + _backend = look_up_option(backend, TransformBackends) + if _backend == TransformBackends.NUMPY: + return _create_flip( + spatial_dims=spatial_dims, + spatial_axis=spatial_axis, + eye_func=np.eye) + if _backend == TransformBackends.TORCH: + return _create_flip( + spatial_dims=spatial_dims, + spatial_axis=spatial_axis, + eye_func=lambda rank: torch.eye(rank, device=device), + ) + raise ValueError(f"backend {backend} is not supported") + + def generate_spatial_bounding_box( img: NdarrayOrTensor, select_fn: Callable = is_positive, @@ -1790,5 +1908,93 @@ def squarepulse(sig, duty: float = 0.5): return y +def get_device_from_tensor_like(data: NdarrayOrTensor): + """ + This function returns the device of `data`, which must either be a numpy ndarray or a + pytorch Tensor. + Args: + data: the ndarray/tensor to return the device of + Returns: + None if `data` is a numpy array, or the device of the pytorch Tensor + """ + if isinstance(data, np.ndarray): + return None + elif isinstance(data, torch.Tensor): + return data.device + else: + msg = "'data' must be one of numpy ndarray or torch Tensor but is {}" + raise ValueError(msg.format(type(data))) + + +def get_backend_from_tensor_like(data: NdarrayOrTensor): + """ + This function returns the backend of `data`, which must either be a numpy ndarray or a + pytorch Tensor. 
+    Args:
+        data: the ndarray/tensor to return the backend of
+    Returns:
+        TransformBackends.NUMPY if `data` is a numpy array, or
+        TransformBackends.TORCH if it is a pytorch Tensor
+    """
+    if isinstance(data, np.ndarray):
+        return TransformBackends.NUMPY
+    elif isinstance(data, torch.Tensor):
+        return TransformBackends.TORCH
+    else:
+        msg = "'data' must be one of numpy ndarray or torch Tensor but is {}"
+        raise ValueError(msg.format(type(data)))
+
+
+def dtype_torch_to_numpy(dtype: torch.dtype) -> np.dtype:
+    """Convert a torch dtype to its numpy equivalent."""
+    return torch.empty([], dtype=dtype).numpy().dtype  # type: ignore
+
+
+def dtype_numpy_to_torch(dtype: np.dtype) -> torch.dtype:
+    """Convert a numpy dtype to its torch equivalent."""
+    return torch.from_numpy(np.empty([], dtype=dtype)).dtype
+
+
+__dtype_dict = {
+    np.int8: "int8",
+    torch.int8: "int8",
+    np.int16: "int16",
+    torch.int16: "int16",
+    int: "int32",
+    np.int32: "int32",
+    torch.int32: "int32",
+    np.int64: "int64",
+    torch.int64: "int64",
+    np.uint8: "uint8",
+    torch.uint8: "uint8",
+    np.uint16: "uint16",
+    np.uint32: "uint32",
+    np.uint64: "uint64",
+    float: "float32",
+    np.float16: "float16",
+    np.float32: "float32",
+    np.float64: "float64",
+    torch.float16: "float16",
+    torch.float: "float32",
+    torch.float32: "float32",
+    torch.double: "float64",
+    torch.float64: "float64",
+}
+
+
+def dtypes_to_str_or_identity(dtype: Any) -> Any:
+    return __dtype_dict.get(dtype, dtype)
+
+
+def get_numpy_dtype_from_string(dtype: str) -> np.dtype:
+    """Get a numpy dtype (e.g., `np.float32`) from its string (e.g., `"float32"`)."""
+    return np.empty([], dtype=dtype).dtype
+
+
+def get_torch_dtype_from_string(dtype: str) -> torch.dtype:
+    """Get a torch dtype (e.g., `torch.float32`) from its string (e.g., `"float32"`)."""
+    return dtype_numpy_to_torch(get_numpy_dtype_from_string(dtype))
+
+
 if __name__ == "__main__":
     print_transform_backends()
diff --git a/tests/test_apply.py b/tests/test_apply.py
new file mode 100644
index 0000000000..9cd9aee462
--- /dev/null
+++ b/tests/test_apply.py
@@ -0,0 +1,73 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from monai.utils import convert_to_tensor
+
+from monai.transforms.lazy.functional import apply
+from monai.transforms.meta_matrix import MetaMatrix
+
+
+def rotate_45_2d():
+    t = torch.eye(3)
+    t[:, 0] = torch.FloatTensor([0, -1, 0])
+    t[:, 1] = torch.FloatTensor([1, 0, 0])
+    return t
+
+
+class TestApply(unittest.TestCase):
+
+    def _test_apply_impl(self, tensor, pending_transforms):
+        print(tensor.shape)
+        result = apply(tensor, pending_transforms)
+        self.assertListEqual(result[1], pending_transforms)
+
+    def _test_apply_metatensor_impl(self, tensor, pending_transforms, pending_as_parameter):
+        tensor_ = convert_to_tensor(tensor)
+        if pending_as_parameter:
+            result = apply(tensor_, pending_transforms)
+        else:
+            for p in pending_transforms:
+                # TODO: cannot do the next part until the MetaTensor PR #5107 is in
+                # tensor_.push_pending(p)
+                raise NotImplementedError()
+
+    SINGLE_TRANSFORM_CASES = [
+        (torch.randn((1, 16, 16)), [MetaMatrix(rotate_45_2d(), {"id": "rotate"})])
+    ]
+
+    def test_apply_single_transform(self):
+        for case in self.SINGLE_TRANSFORM_CASES:
+            self._test_apply_impl(*case)
+
+    def test_apply_single_transform_metatensor(self):
+        for case in self.SINGLE_TRANSFORM_CASES:
+            self._test_apply_metatensor_impl(*case, False)
+
+    def test_apply_single_transform_metatensor_override(self):
+        for case in self.SINGLE_TRANSFORM_CASES:
+            self._test_apply_metatensor_impl(*case, True)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_matmul.py b/tests/test_matmul.py
new file mode 100644
index 0000000000..2ac3b70a2e
--- /dev/null
+++ b/tests/test_matmul.py
@@ -0,0 +1,143 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms.meta_matrix import ( + Grid, + Matrix, + is_grid_shaped, + is_matrix_shaped, + matmul, + matmul_grid_matrix, + matmul_grid_matrix_slow, + matmul_matrix_grid, +) + + +class TestMatmulFunctions(unittest.TestCase): + def test_matrix_grid_and_grid_inv_matrix_are_equivalent(self): + grid = torch.randn((3, 32, 32)) + + matrix = torch.eye(3, 3) + matrix[:, 0] = torch.FloatTensor([0, -1, 0]) + matrix[:, 1] = torch.FloatTensor([1, 0, 0]) + + inv_matrix = torch.inverse(matrix) + + result1 = matmul_matrix_grid(matrix, grid) + result2 = matmul_grid_matrix(grid, inv_matrix) + self.assertTrue(torch.allclose(result1, result2)) + + def test_matmul_grid_matrix_slow(self): + grid = torch.randn((3, 32, 32)) + + matrix = torch.eye(3, 3) + matrix[:, 0] = torch.FloatTensor([0, -1, 0]) + matrix[:, 1] = torch.FloatTensor([1, 0, 0]) + + result1 = matmul_grid_matrix_slow(grid, matrix) + result2 = matmul_grid_matrix(grid, matrix) + self.assertTrue(torch.allclose(result1, result2)) + + MATMUL_TEST_CASES = [ + [np.eye(3, dtype=np.float32), np.eye(3, dtype=np.float32), np.ndarray], + [np.eye(3, dtype=np.float32), torch.eye(3), torch.Tensor], + [np.eye(3, dtype=np.float32), Matrix(torch.eye(3)), Matrix], + [np.eye(3, dtype=np.float32), Grid(torch.randn((3, 8, 8))), Grid], + [torch.eye(3), np.eye(3, dtype=np.float32), torch.Tensor], + [torch.eye(3), torch.eye(3), torch.Tensor], + [torch.eye(3), Matrix(torch.eye(3)), Matrix], + [torch.eye(3), Grid(torch.randn((3, 8, 8))), Grid], + [Matrix(torch.eye(3)), np.eye(3, dtype=np.float32), Matrix], + [Matrix(torch.eye(3)), torch.eye(3), Matrix], + [Matrix(torch.eye(3)), Matrix(torch.eye(3)), Matrix], + [Matrix(torch.eye(3)), Grid(torch.randn((3, 8, 8))), Grid], + [Grid(torch.randn((3, 8, 8))), np.eye(3, dtype=np.float32), Grid], + [Grid(torch.randn((3, 8, 8))), torch.eye(3), Grid], + [Grid(torch.randn((3, 8, 8))), Matrix(torch.eye(3)), Grid], + [Grid(torch.randn((3, 8, 8))), Grid(torch.randn((3, 8, 8))), None], + ] + + def _test_matmul_correct_return_type_impl(self, left, right, expected): + if expected is None: + with self.assertRaises(RuntimeError): + result = matmul(left, right) + else: + result = matmul(left, right) + self.assertIsInstance(result, expected) + + @parameterized.expand(MATMUL_TEST_CASES) + def test_matmul_correct_return_type(self, left, right, expected): + self._test_matmul_correct_return_type_impl(left, right, expected) + + # def test_all_matmul_correct_return_type(self): + # for case in self.MATMUL_TEST_CASES: + # with self.subTest(f"{case}"): + # self._test_matmul_correct_return_type_impl(*case) + + MATRIX_SHAPE_TESTCASES = [ + (torch.randn(2, 2), False), + (torch.randn(3, 3), True), + (torch.randn(4, 4), True), + (torch.randn(5, 5), False), + (torch.randn(3, 4), False), + (torch.randn(4, 3), False), + (torch.randn(3), False), + (torch.randn(4), False), + (torch.randn(5), False), + (torch.randn(3, 3, 3), False), + (torch.randn(4, 4, 4), False), + (torch.randn(5, 5, 5), False), + ] + + def _test_is_matrix_shaped_impl(self, matrix, expected): + self.assertEqual(is_matrix_shaped(matrix), expected) + + @parameterized.expand(MATRIX_SHAPE_TESTCASES) + def test_is_matrix_shaped(self, matrix, expected): + self._test_is_matrix_shaped_impl(matrix, expected) + + # def test_all_is_matrix_shaped(self): + # for case in self.MATRIX_SHAPE_TESTCASES: + # with self.subTest(f"{case[0].shape}"): + # self._test_is_matrix_shaped_impl(*case) + + 
GRID_SHAPE_TESTCASES = [ + (torch.randn(1, 16, 32), False), + (torch.randn(2, 16, 32), False), + (torch.randn(3, 16, 32), True), + (torch.randn(4, 16, 32), False), + (torch.randn(5, 16, 32), False), + (torch.randn(3, 16, 32, 64), False), + (torch.randn(4, 16, 32, 64), True), + (torch.randn(5, 16, 32, 64), False), + ] + + def _test_is_grid_shaped_impl(self, grid, expected): + self.assertEqual(is_grid_shaped(grid), expected) + + @parameterized.expand(GRID_SHAPE_TESTCASES) + def test_is_grid_shaped(self, grid, expected): + self._test_is_grid_shaped_impl(grid, expected) + + # def test_all_is_grid_shaped(self): + # for case in self.GRID_SHAPE_TESTCASES: + # with self.subTest(f"{case[0].shape}"): + # self._test_is_grid_shaped_impl(*case) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_resample.py b/tests/test_resample.py new file mode 100644 index 0000000000..fc8dab06bf --- /dev/null +++ b/tests/test_resample.py @@ -0,0 +1,46 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from monai.transforms.utility.functional import resample + +from monai.utils import convert_to_tensor + +from tests.utils import get_arange_img + + +def rotate_45_2d(): + t = torch.eye(3) + t[:, 0] = torch.FloatTensor([0, -1, 0]) + t[:, 1] = torch.FloatTensor([1, 0, 0]) + return t + + +class TestResampleFunction(unittest.TestCase): + + def _test_resample_function_impl(self, img, matrix): + result = resample(convert_to_tensor(img), matrix) + print(result) + + RESAMPLE_FUNCTION_CASES = [ + (get_arange_img((1, 16, 16)), rotate_45_2d()) + ] + + def test_resample_function(self): + for case in self.RESAMPLE_FUNCTION_CASES: + self._test_resample_function_impl(*case) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/utils.py b/tests/utils.py index afe08e0bfa..89766657ba 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -348,6 +348,24 @@ def make_rand_affine(ndim: int = 3, random_state: Optional[np.random.RandomState return af +def get_arange_img(size, dtype=torch.float32, offset=0): + """ + Returns an 2d or 3d image as a numpy tensor (complete with channel as dim 0) + with contents that iterate like an arange. + """ + img = torch.zeros(size, dtype=dtype) + if len(size) == 2: + for j in range(size[0]): + for i in range(size[1]): + img[j, i] = i + j * size[0] + offset + else: + for k in range(size[0]): + for j in range(size[1]): + for i in range(size[2]): + img[k, j, i] = i + j * size[0] + k * size[0] * size[1] + offset + return np.expand_dims(img, 0) + + class DistTestCase(unittest.TestCase): """ testcase without _outcome, so that it's picklable. 
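A minimal usage sketch of the lazy pipeline introduced by this patch (names such as `rot`, `pending` and `img` are illustrative; only the `apply` and `MetaMatrix` definitions above are assumed, and the metadata keys shown are the ones `apply` reads back when fusing pending transforms):

    import torch

    from monai.transforms.lazy.functional import apply
    from monai.transforms.meta_matrix import MetaMatrix

    # a 90 degree in-plane rotation as a 3x3 homogeneous matrix
    rot = torch.eye(3)
    rot[0, 0], rot[0, 1] = 0.0, -1.0
    rot[1, 0], rot[1, 1] = 1.0, 0.0

    # queue the rotation together with its resampling arguments
    pending = [MetaMatrix(rot, {"mode": "bilinear", "padding_mode": "zeros"})]

    img = torch.randn(1, 16, 16)  # channel-first 2D image
    # fuse the pending matrices and resample once via Affine; returns the
    # resampled data and the list of transforms that were applied
    result, applied = apply(img, pending)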
From 49e7ed678e67ead9aa54f196ab06054ff13f47fb Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Tue, 1 Nov 2022 12:50:45 +0000 Subject: [PATCH 02/19] adding functional tests for flip, resize and spacing Signed-off-by: Ben Murray --- monai/transforms/croppad/functional.py | 50 ++ monai/transforms/lazy/functional.py | 2 +- monai/transforms/meta_matrix.py | 2 +- monai/transforms/spatial/functional.py | 498 ++++++++++++++++++ .../utility/{resample.py => functional.py} | 0 tests/test_flipf.py | 137 +++++ tests/test_resizef.py | 178 +++++++ tests/test_spacingf.py | 181 +++++++ 8 files changed, 1046 insertions(+), 2 deletions(-) create mode 100644 monai/transforms/croppad/functional.py create mode 100644 monai/transforms/spatial/functional.py rename monai/transforms/utility/{resample.py => functional.py} (100%) create mode 100644 tests/test_flipf.py create mode 100644 tests/test_resizef.py create mode 100644 tests/test_spacingf.py diff --git a/monai/transforms/croppad/functional.py b/monai/transforms/croppad/functional.py new file mode 100644 index 0000000000..6bdd45e64d --- /dev/null +++ b/monai/transforms/croppad/functional.py @@ -0,0 +1,50 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Sequence, Union + +import torch + +from monai.data.meta_obj import get_track_meta +from monai.transforms.lazy.functional import extents_from_shape, shape_from_extents +from monai.transforms.meta_matrix import MatrixFactory +from monai.utils import GridSamplePadMode, NumpyPadMode, convert_to_tensor + + +def croppad( + img: torch.Tensor, + slices: Union[Sequence[slice], slice], + padding_mode: Optional[Union[GridSamplePadMode, str]] = NumpyPadMode.EDGE, + shape_override: Optional[Sequence] = None +): + img_ = convert_to_tensor(img, track_meta=get_track_meta()) + input_shape = img_.shape if shape_override is None else shape_override + input_ndim = len(input_shape) - 1 + if len(slices) != input_ndim: + raise ValueError(f"'slices' length {len(slices)} must be equal to 'img' " + f"spatial dimensions of {input_ndim}") + + img_centers = [i / 2 for i in input_shape[1:]] + slice_centers = [(s.stop + s.start) / 2 for s in slices] + deltas = [s - i for i, s in zip(img_centers, slice_centers)] + transform = MatrixFactory.from_tensor(img).translate(deltas).matrix.matrix + im_extents = extents_from_shape([input_shape[0]] + [s.stop - s.start for s in slices]) + im_extents = [transform @ e for e in im_extents] + shape_override_ = shape_from_extents(input_shape, im_extents) + + metadata = { + "slices": slices, + "padding_mode": padding_mode, + "dtype": img.dtype, + "im_extents": im_extents, + "shape_override": shape_override_ + } + return img_, transform, metadata diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index 93594bfa74..50192514af 100644 --- a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -59,7 +59,7 @@ def shape_from_extents( mins = aextents.min(axis=0)[0] maxes = aextents.max(axis=0)[0] 
values = torch.round(maxes - mins).type(torch.IntTensor)[:-1] - return torch.cat((torch.as_tensor([src_shape[0]]), values)) + return torch.cat((torch.IntTensor([src_shape[0]]), values)) def metadata_is_compatible(value_1, value_2): diff --git a/monai/transforms/meta_matrix.py b/monai/transforms/meta_matrix.py index 845fe3f2c5..c8df87c1e6 100644 --- a/monai/transforms/meta_matrix.py +++ b/monai/transforms/meta_matrix.py @@ -49,7 +49,7 @@ def __init__(self, self._sin = lambda th: np.sin(th, dtype=np.float32) self._cos = lambda th: np.cos(th, dtype=np.float32) self._eye = lambda th: np.eye(th, dtype=np.float32) - self._diag = lambda th: np.diag(th, dtype=np.float32) + self._diag = lambda th: np.diag(th).astype(np.float32) else: if device is None: raise ValueError("'device' must be set with TransformBackends.TORCH") diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py new file mode 100644 index 0000000000..cf121d4465 --- /dev/null +++ b/monai/transforms/spatial/functional.py @@ -0,0 +1,498 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Sequence, Tuple, Union + +import numpy as np + +import torch +from monai.networks.layers import GaussianFilter + +from monai.networks.utils import meshgrid_ij + +from monai.transforms import create_grid, create_rotate, create_translate, map_spatial_axes +from monai.config import DtypeLike +from monai.data import get_track_meta +from monai.transforms.lazy.functional import extents_from_shape, shape_from_extents +from monai.transforms.meta_matrix import MatrixFactory +from monai.utils import ( + convert_to_tensor, + ensure_tuple, + ensure_tuple_rep, + ensure_tuple_size, + fall_back_tuple, + get_equivalent_dtype, + look_up_option, + GridSampleMode, + GridSamplePadMode, + InterpolateMode, + NumpyPadMode +) + + +def identity( + img: torch.Tensor, + mode: Optional[Union[InterpolateMode, str]] = None, + padding_mode: Optional[Union[NumpyPadMode, GridSamplePadMode, str]] = None, + dtype: Optional[Union[DtypeLike, torch.dtype]] = None +): + img_ = convert_to_tensor(img, track_meta=get_track_meta()) + + mode_ = None if mode is None else look_up_option(mode, GridSampleMode) + padding_mode_ = None if padding_mode is None else look_up_option(padding_mode, GridSamplePadMode) + dtype_ = get_equivalent_dtype(dtype or img_.dtype, torch.Tensor) + + transform = MatrixFactory.from_tensor(img_).identity().matrix.matrix + + metadata = dict() + if mode_ is not None: + metadata["mode"] = mode_ + if padding_mode_ is not None: + metadata["padding_mode"] = padding_mode_ + metadata["dtype"] = dtype_ + return img_, transform, metadata + + +def spacing( + img: torch.Tensor, + pixdim: Union[Sequence[float], float], + src_pixdim: Union[Sequence[float], float], + diagonal: Optional[bool] = False, + mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.AREA, + padding_mode: Optional[Union[NumpyPadMode, GridSamplePadMode, str]] = NumpyPadMode.EDGE, + align_corners: Optional[bool] = False, + 
dtype: Optional[Union[DtypeLike, torch.dtype]] = None, + shape_override: Optional[Sequence] = None +): + """ + Args: + img: channel first array, must have shape: (num_channels, H[, W, ..., ]). + mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``, + ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} + The interpolation mode. Defaults to ``self.mode``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + align_corners: This only has an effect when mode is + 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + anti_aliasing: bool, optional + Whether to apply a Gaussian filter to smooth the image prior + to downsampling. It is crucial to filter when downsampling + the image to avoid aliasing artifacts. See also ``skimage.transform.resize`` + anti_aliasing_sigma: {float, tuple of floats}, optional + Standard deviation for Gaussian filtering used when anti-aliasing. + By default, this value is chosen as (s - 1) / 2 where s is the + downsampling factor, where s > 1. For the up-size case, s < 1, no + anti-aliasing is performed prior to rescaling. + + Raises: + ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. + + """ + + img_ = convert_to_tensor(img, track_meta=get_track_meta()) + input_shape = img_.shape if shape_override is None else shape_override + input_ndim = len(input_shape) - 1 + + pixdim_ = ensure_tuple_rep(pixdim, input_ndim) + src_pixdim_ = ensure_tuple_rep(src_pixdim, input_ndim) + + if diagonal is True: + raise ValueError("'diagonal' value of True is not currently supported") + + mode_ = look_up_option(mode, GridSampleMode) + padding_mode_ = look_up_option(padding_mode, GridSamplePadMode) + dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor) + zoom_factors = [i / j for i, j in zip(src_pixdim_, pixdim_)] + + # TODO: decide whether we are consistently returning MetaMatrix or concrete transforms + transform = MatrixFactory.from_tensor(img).scale(zoom_factors).matrix.data + im_extents = extents_from_shape(input_shape) + im_extents = [transform @ e for e in im_extents] + shape_override_ = shape_from_extents(input_shape, im_extents) + + metadata = { + "pixdim": pixdim_, + "src_pixdim": src_pixdim_, + "diagonal": diagonal, + "mode": mode_, + "padding_mode": padding_mode_, + "align_corners": align_corners, + "dtype": dtype_, + # "im_extents": im_extents, + "shape_override": shape_override_ + } + return img_, transform, metadata + + +def orientation( + img: torch.Tensor +): + pass + + +def flip( + img: torch.Tensor, + spatial_axis: Union[Sequence[int], int], + shape_override: Optional[Sequence] = None +): + img_ = convert_to_tensor(img, track_meta=get_track_meta()) + input_shape = img_.shape if shape_override is None else shape_override + + spatial_axis_ = spatial_axis + if spatial_axis_ is None: + spatial_axis_ = tuple(i for i in range(len(input_shape[1:]))) + transform = MatrixFactory.from_tensor(img).flip(spatial_axis_).matrix.data + # im_extents = extents_from_shape(input_shape) + # im_extents = [transform @ e for e in im_extents] + # + # shape_override_ = shape_from_extents(input_shape, im_extents) + + metadata = { + "spatial_axis": spatial_axis_, + # "im_extents": im_extents, + "shape_override": shape_override + } + return img_, transform, metadata + + +def resize( + img: torch.Tensor, + spatial_size: Union[Sequence[int], int], + size_mode: str = "all", + mode: 
Optional[Union[InterpolateMode, str]] = InterpolateMode.AREA, + align_corners: Optional[bool] = False, + anti_aliasing: Optional[bool] = None, + anti_aliasing_sigma: Optional[Union[Sequence[float], float]] = None, + dtype: Optional[Union[DtypeLike, torch.dtype]] = None, + shape_override: Optional[Sequence] = None +): + """ + Args: + img: channel first array, must have shape: (num_channels, H[, W, ..., ]). + mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``, + ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} + The interpolation mode. Defaults to ``self.mode``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + align_corners: This only has an effect when mode is + 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + anti_aliasing: bool, optional + Whether to apply a Gaussian filter to smooth the image prior + to downsampling. It is crucial to filter when downsampling + the image to avoid aliasing artifacts. See also ``skimage.transform.resize`` + anti_aliasing_sigma: {float, tuple of floats}, optional + Standard deviation for Gaussian filtering used when anti-aliasing. + By default, this value is chosen as (s - 1) / 2 where s is the + downsampling factor, where s > 1. For the up-size case, s < 1, no + anti-aliasing is performed prior to rescaling. + + Raises: + ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. + + """ + + img_ = convert_to_tensor(img, track_meta=get_track_meta()) + input_shape = img_.shape if shape_override is None else shape_override + input_ndim = len(input_shape) - 1 + + if size_mode == "all": + output_ndim = len(ensure_tuple(spatial_size)) + if output_ndim > input_ndim: + input_shape = ensure_tuple_size(input_shape, output_ndim + 1, 1) + img = img.reshape(input_shape) + elif output_ndim < input_ndim: + raise ValueError( + "len(spatial_size) must be greater or equal to img spatial dimensions, " + f"got spatial_size={output_ndim} img={input_ndim}." 
+ ) + spatial_size_ = fall_back_tuple(spatial_size, input_shape[1:]) + else: # for the "longest" mode + img_size = input_shape[1:] + if not isinstance(spatial_size, int): + raise ValueError("spatial_size must be an int number if size_mode is 'longest'.") + scale = spatial_size / max(img_size) + spatial_size_ = tuple(int(round(s * scale)) for s in img_size) + + mode_ = look_up_option(mode, GridSampleMode) + dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor) + zoom_factors = [i / j for i, j in zip(spatial_size_, input_shape[1:])] + transform = MatrixFactory.from_tensor(img).scale(zoom_factors).matrix.data + im_extents = extents_from_shape(input_shape) + im_extents = [transform @ e for e in im_extents] + shape_override_ = shape_from_extents(input_shape, im_extents) + + metadata = { + "spatial_size": spatial_size, + "size_mode": size_mode, + "mode": mode_, + "align_corners": align_corners, + "anti_aliasing": anti_aliasing, + "anti_aliasing_sigma": anti_aliasing_sigma, + "dtype": dtype_, + # "im_extents": im_extents, + "shape_override": shape_override_ + } + return img_, transform, metadata + + +def rotate( + img: torch.Tensor, + angle: Union[Sequence[float], float], + keep_size: Optional[bool] = True, + mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.AREA, + padding_mode: Optional[Union[NumpyPadMode, GridSamplePadMode, str]] = NumpyPadMode.EDGE, + align_corners: Optional[bool] = False, + dtype: Optional[Union[DtypeLike, torch.dtype]] = None, + shape_override: Optional[Sequence] = None +): + """ + Args: + img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D]. + angle: Rotation angle(s) in radians. should a float for 2D, three floats for 3D. + keep_size: If it is True, the output shape is kept the same as the input. + If it is False, the output shape is adapted so that the + input array is contained completely in the output. Default is True. + mode: {``"bilinear"``, ``"nearest"``} + Interpolation mode to calculate output values. Defaults to ``self.mode``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} + Padding mode for outside grid values. Defaults to ``self.padding_mode``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + align_corners: Defaults to ``self.align_corners``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + align_corners: Defaults to ``self.align_corners``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + dtype: data type for resampling computation. Defaults to ``self.dtype``. + If None, use the data type of input data. To be compatible with other modules, + the output data type is always ``np.float32``. + + Raises: + ValueError: When ``img`` spatially is not one of [2D, 3D]. 
+ + """ + + img_ = convert_to_tensor(img, track_meta=get_track_meta()) + mode_ = look_up_option(mode, GridSampleMode) + padding_mode_ = look_up_option(padding_mode, GridSamplePadMode) + dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor) + input_shape = img_.shape if shape_override is None else shape_override + input_ndim = len(input_shape) - 1 + if input_ndim not in (2, 3): + raise ValueError(f"Unsupported image dimension: {input_ndim}, available options are [2, 3].") + + angle_ = ensure_tuple_rep(angle, 1 if input_ndim == 2 else 3) + rotate_tx = torch.from_numpy(create_rotate(input_ndim, angle_)) + im_extents = extents_from_shape(input_shape) + if not keep_size: + im_extents = [rotate_tx @ e for e in im_extents] + spatial_shape = shape_from_extents(input_shape, im_extents) + else: + spatial_shape = input_shape + transform = rotate_tx + metadata = { + "angle": angle_, + "keep_size": keep_size, + "mode": mode_, + "padding_mode": padding_mode_, + "align_corners": align_corners, + "dtype": dtype_, + "im_extents": im_extents, + "shape_override": spatial_shape + } + return img_, transform, metadata + + +def zoom( + img: torch.Tensor, + factor: Union[Sequence[float], float], + mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.BILINEAR, + padding_mode: Optional[Union[NumpyPadMode, GridSamplePadMode, str]] = NumpyPadMode.EDGE, + align_corners: Optional[bool] = False, + keep_size: Optional[bool] = True, + dtype: Optional[Union[DtypeLike, torch.dtype]] = None, + shape_override: Optional[Sequence] = None +): + """ + Args: + img: channel first array, must have shape: (num_channels, H[, W, ..., ]). + mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``, + ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} + The interpolation mode. Defaults to ``self.mode``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + align_corners: This only has an effect when mode is + 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html + + Raises: + ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. 
+
+    """
+
+    img_ = convert_to_tensor(img, track_meta=get_track_meta())
+    input_shape = img_.shape if shape_override is None else shape_override
+    input_ndim = len(input_shape) - 1
+
+    zoom_factors = ensure_tuple_rep(factor, input_ndim)
+    zoom_factors = [1 / f for f in zoom_factors]
+
+    mode_ = look_up_option(mode, GridSampleMode)
+    padding_mode_ = look_up_option(padding_mode, GridSamplePadMode)
+    dtype_ = get_equivalent_dtype(dtype or img_.dtype, torch.Tensor)
+
+    transform = MatrixFactory.from_tensor(img_).scale(zoom_factors).matrix.matrix
+    im_extents = extents_from_shape(input_shape)
+    if keep_size is False:
+        im_extents = [transform @ e for e in im_extents]
+        shape_override_ = shape_from_extents(input_shape, im_extents)
+    else:
+        shape_override_ = input_shape
+
+    metadata = {
+        "factor": zoom_factors,
+        "mode": mode_,
+        "padding_mode": padding_mode_,
+        "align_corners": align_corners,
+        "keep_size": keep_size,
+        "dtype": dtype_,
+        "im_extents": im_extents,
+        "shape_override": shape_override_
+    }
+    return img_, transform, metadata
+
+
+def rotate90(
+    img: torch.Tensor,
+    k: Optional[int] = 1,
+    spatial_axes: Optional[Tuple[int, int]] = (0, 1),
+    shape_override: Optional[bool] = None
+):
+    if len(spatial_axes) != 2:
+        raise ValueError("'spatial_axes' must be a tuple of two integers indicating the plane in which to rotate")
+
+    img_ = convert_to_tensor(img, track_meta=get_track_meta())
+    # axes = map_spatial_axes(img.ndim, spatial_axes)
+    # ori_shape = img.shape[1:]
+    input_shape = img_.shape if shape_override is None else shape_override
+    input_ndim = len(input_shape) - 1
+
+    transform = MatrixFactory.from_tensor(img_).rotate_90(k)
+
+    metadata = {
+        "k": k,
+        "spatial_axes": spatial_axes,
+        "shape_override": shape_override
+    }
+    return img_, transform, metadata
+
+
+def grid_distortion(
+    img: torch.Tensor,
+    num_cells: Union[Tuple[int], int],
+    distort_steps: Sequence[Sequence[float]],
+    mode: str = GridSampleMode.BILINEAR,
+    padding_mode: str = GridSamplePadMode.BORDER,
+    shape_override: Optional[Tuple[int]] = None
+):
+    all_ranges = []
+    num_cells = ensure_tuple_rep(num_cells, len(img.shape) - 1)
+    for dim_idx, dim_size in enumerate(img.shape[1:]):
+        dim_distort_steps = distort_steps[dim_idx]
+        ranges = torch.zeros(dim_size, dtype=torch.float32)
+        cell_size = dim_size // num_cells[dim_idx]
+        prev = 0
+        for idx in range(num_cells[dim_idx] + 1):
+            start = int(idx * cell_size)
+            end = start + cell_size
+            if end > dim_size:
+                end = dim_size
+                cur = dim_size
+            else:
+                cur = prev + cell_size * dim_distort_steps[idx]
+            # fill this cell with evenly spaced sample positions between prev and cur
+            ranges[start:end] = torch.linspace(prev, cur, end - start)
+            prev = cur
+        ranges = ranges - (dim_size - 1.0) / 2.0
+        all_ranges.append(ranges)
+    coords = meshgrid_ij(*all_ranges)
+    grid = torch.stack([*coords, torch.ones_like(coords[0])])
+
+    metadata = {
+        "num_cells": num_cells,
+        "distort_steps": distort_steps,
+        "mode": mode,
+        "padding_mode": padding_mode
+    }
+
+    return img, grid, metadata
+
+
+def elastic_3d(
+    img: torch.Tensor,
+    sigma: float,
+    magnitude: float,
+    offsets: torch.Tensor,
+    spatial_size: Optional[Union[Tuple[int, int, int], int]] = None,
+    mode: str = GridSampleMode.BILINEAR,
+    padding_mode: str = GridSamplePadMode.REFLECTION,
+    device: Optional[torch.device] = None,
+    shape_override: Optional[Tuple[float]] = None
+):
+    img_ = convert_to_tensor(img, track_meta=get_track_meta())
+
+    sp_size = fall_back_tuple(spatial_size, img.shape[1:])
+    device_ = img.device if isinstance(img, torch.Tensor) else device
+    grid = create_grid(spatial_size=sp_size, device=device_, backend="torch")
+    gaussian = GaussianFilter(3, sigma, 3.0).to(device=device_)
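+    # smooth the random offsets with the Gaussian filter and scale by `magnitude`
+    # to form the dense displacement field added to the sampling grid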
+    grid[:3] += gaussian(offsets)[0] * magnitude
+
+    metadata = {
+        "sigma": sigma,
+        "magnitude": magnitude,
+        "offsets": offsets,
+    }
+    if spatial_size is not None:
+        metadata["spatial_size"] = spatial_size
+    if mode is not None:
+        metadata["mode"] = mode
+    if padding_mode is not None:
+        metadata["padding_mode"] = padding_mode
+    if shape_override is not None:
+        metadata["shape_override"] = shape_override
+
+    return img_, grid, metadata
+
+
+def translate(
+    img: torch.Tensor,
+    translation: Sequence[float],
+    mode: Optional[Union[GridSampleMode, str]] = GridSampleMode.BILINEAR,
+    padding_mode: Optional[Union[GridSamplePadMode, str]] = NumpyPadMode.EDGE,
+    dtype: Union[DtypeLike, torch.dtype] = np.float32,
+    shape_override: Optional[Sequence] = None
+):
+    img_ = convert_to_tensor(img, track_meta=get_track_meta())
+    input_shape = img_.shape if shape_override is None else shape_override
+    input_ndim = len(input_shape) - 1
+    if len(translation) != input_ndim:
+        raise ValueError(f"'translation' length {len(translation)} must be equal to 'img' "
+                         f"spatial dimensions of {input_ndim}")
+
+    mode_ = look_up_option(mode, GridSampleMode)
+    padding_mode_ = look_up_option(padding_mode, GridSamplePadMode)
+    dtype_ = get_equivalent_dtype(dtype or img_.dtype, torch.Tensor)
+    transform = MatrixFactory.from_tensor(img).translate(translation).matrix.matrix
+    im_extents = extents_from_shape(input_shape)
+    im_extents = [transform @ e for e in im_extents]
+    # shape_override_ = shape_from_extents(input_shape, im_extents)
+
+    metadata = {
+        "translation": translation,
+        "mode": mode_,
+        "padding_mode": padding_mode_,
+        "dtype": dtype_,
+        "im_extents": im_extents,
+        # "shape_override": shape_override_
+    }
+    return img_, transform, metadata
diff --git a/monai/transforms/utility/resample.py b/monai/transforms/utility/functional.py
similarity index 100%
rename from monai/transforms/utility/resample.py
rename to monai/transforms/utility/functional.py
diff --git a/tests/test_flipf.py b/tests/test_flipf.py
new file mode 100644
index 0000000000..ad1b364e94
--- /dev/null
+++ b/tests/test_flipf.py
@@ -0,0 +1,137 @@
+import unittest
+
+from typing import Sequence, Union
+
+import torch
+
+from monai.transforms.spatial.functional import flip
+from tests.utils import get_arange_img
+
+
+def affine_flip(dims, axis: Union[int, Sequence[int]]):
+    if axis is None:
+        return torch.eye(dims + 1)
+
+    if not isinstance(axis, (list, tuple)):
+        axis = (axis,)
+    t = torch.eye(dims + 1)
+    for i in axis:
+        t[i, i] = -1
+    return t
+
+
+def get_metadata(overrides=None, remove=None):
+    metadata = {
+        "spatial_axis": None,
+        "shape_override": None,
+    }
+    if overrides is not None:
+        for k, v in overrides.items():
+            metadata[k] = v
+    if remove is not None:
+        for k in remove:
+            if k in metadata:
+                del metadata[k]
+    return metadata
+
+
+class TestFunctionalFlip(unittest.TestCase):
+
+    FLIP_CASES = [
+        # 2d cases
+        (
+            get_arange_img((32, 32)), affine_flip(2, (0, 1)),
+            get_metadata({"spatial_axis": None}), get_metadata({"spatial_axis": (0, 1)})
+        ),
+        (
+            get_arange_img((32, 32)), affine_flip(2, 0),
+            get_metadata({"spatial_axis": 0}), get_metadata({"spatial_axis": 0})
+        ),
+        (
+            get_arange_img((32, 32)), affine_flip(2, 1),
+            get_metadata({"spatial_axis": 1}), get_metadata({"spatial_axis": 1})
+        ),
+        (
+            get_arange_img((32, 32)), affine_flip(2, 0),
+            get_metadata({"spatial_axis": (0,)}), get_metadata({"spatial_axis": (0,)})
+        ),
+        (
+            get_arange_img((32, 32)), affine_flip(2, 1),
+            get_metadata({"spatial_axis": (1,)}), get_metadata({"spatial_axis": (1,)})
+        ),
+        (
+            get_arange_img((32, 32)), affine_flip(2, (0, 1)),
+            get_metadata({"spatial_axis": (0, 1)}), get_metadata({"spatial_axis": (0, 1)})
+        ),
+        (
+            get_arange_img((32, 32)),
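+            # axis order doesn't matter here: spatial_axis (1, 0) expects the same affine as (0, 1)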
affine_flip(2, (0, 1)), + get_metadata({"spatial_axis": (1, 0)}), get_metadata({"spatial_axis": (1, 0)}) + ), + # 3d cases + ( + get_arange_img((32, 32, 16)), affine_flip(3, (0, 1, 2)), + get_metadata({"spatial_axis": None}), get_metadata({"spatial_axis": (0, 1, 2)}) + ), + ( + get_arange_img((32, 32, 16)), affine_flip(3, 0), + get_metadata({"spatial_axis": 0}), get_metadata({"spatial_axis": 0}) + ), + ( + get_arange_img((32, 32, 16)), affine_flip(3, 1), + get_metadata({"spatial_axis": 1}), get_metadata({"spatial_axis": 1}) + ), + ( + get_arange_img((32, 32, 16)), affine_flip(3, 2), + get_metadata({"spatial_axis": 2}), get_metadata({"spatial_axis": 2}) + ), + ( + get_arange_img((32, 32, 16)), affine_flip(3, 0), + get_metadata({"spatial_axis": (0,)}), get_metadata({"spatial_axis": (0,)}) + ), + ( + get_arange_img((32, 32, 16)), affine_flip(3, 1), + get_metadata({"spatial_axis": (1,)}), get_metadata({"spatial_axis": (1,)}) + ), + ( + get_arange_img((32, 32, 16)), affine_flip(3, 2), + get_metadata({"spatial_axis": (2,)}), get_metadata({"spatial_axis": (2,)}) + ), + ( + get_arange_img((32, 32, 16)), affine_flip(3, (0, 1)), + get_metadata({"spatial_axis": (0, 1)}), get_metadata({"spatial_axis": (0, 1)}) + ), + ( + get_arange_img((32, 32, 16)), affine_flip(3, (0, 2)), + get_metadata({"spatial_axis": (0, 2)}), get_metadata({"spatial_axis": (0, 2)}) + ), + ( + get_arange_img((32, 32, 16)), affine_flip(3, (1, 2)), + get_metadata({"spatial_axis": (1, 2)}), get_metadata({"spatial_axis": (1, 2)}) + ), + ( + get_arange_img((32, 32, 16)), affine_flip(3, (0, 1, 2)), + get_metadata({"spatial_axis": (0, 1, 2)}), get_metadata({"spatial_axis": (0, 1, 2)}) + ), + ] + + def _test_functional_flip_impl( + self, img, expected_transform, call_params, expected_metadata + ): + img_, transform_, metadata = flip(img, **call_params) + self.assertTrue(torch.allclose(transform_, expected_transform), + msg=f"{transform_} != {expected_transform}") + actual_keys = set(metadata.keys()) + expected_keys = set(expected_metadata.keys()) + self.assertSetEqual(actual_keys, expected_keys) + for k in actual_keys: + if isinstance(metadata[k], torch.Tensor) and metadata[k] is not None: + self.assertTrue(torch.allclose(metadata[k], expected_metadata[k]), + msg=f"{metadata[k]} != {expected_metadata[k]} for key {k}") + else: + self.assertEqual(metadata[k], expected_metadata[k], + msg=f"{metadata[k]} != {expected_metadata[k]} for key {k}") + + def test_functional_flip(self): + for icase, case in enumerate(self.FLIP_CASES): + with self.subTest(f"{icase}"): + self._test_functional_flip_impl(*case) \ No newline at end of file diff --git a/tests/test_resizef.py b/tests/test_resizef.py new file mode 100644 index 0000000000..eb32e70e3b --- /dev/null +++ b/tests/test_resizef.py @@ -0,0 +1,178 @@ +import unittest + +import torch + +from monai.transforms.spatial.functional import resize +from tests.utils import get_arange_img + + +def affine_scale_2d(scale): + scale0, scale1 = scale if isinstance(scale, (tuple, list)) else (scale, scale) + t = torch.eye(3) + t[:, 0] = torch.FloatTensor([scale0, 0.0, 0.0]) + t[:, 1] = torch.FloatTensor([0.0, scale1, 0.0]) + return t + + +def affine_scale_3d(scale): + scale0, scale1, scale2 = scale if isinstance(scale, (tuple, list)) else (scale, scale, scale) + t = torch.eye(4) + t[:, 0] = torch.FloatTensor([scale0, 0.0, 0.0, 0.0]) + t[:, 1] = torch.FloatTensor([0.0, scale1, 0.0, 0.0]) + t[:, 2] = torch.FloatTensor([0.0, 0.0, scale2, 0.0]) + return t + + +def get_metadata(is_3d=True, overrides=None, remove=None): + 
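+    # build the expected metadata dict for a case: `overrides` updates the defaults
+    # below and `remove` drops keys that a given call is not expected to emit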
metadata = { + "size_mode": "all", + "mode": "nearest", + "align_corners": False, + "anti_aliasing": False, + "anti_aliasing_sigma": None, + "dtype": torch.float32, + # "im_extents": None, + # "shape_override": torch.IntTensor([1, 32, 32]) # shape override shouldn't always be in here + } + if overrides is not None: + for k, v in overrides.items(): + metadata[k] = v + if remove is not None: + for k in remove: + if k in metadata: + del metadata[k] + return metadata + + +class TestFunctionalSpacing(unittest.TestCase): + + SPACING_CASES = [ + # 2d - "all" + ( + get_arange_img((32, 32)), + affine_scale_2d(0.5), + get_metadata(False, {"spatial_size": (16, 16)}), + get_metadata(False, {"spatial_size": (16, 16), "shape_override": torch.IntTensor([1, 16, 16])}) + ), + ( + get_arange_img((32, 32)), + affine_scale_2d(2.0), + get_metadata(False, {"spatial_size": (64, 64)}), + get_metadata(False, {"spatial_size": (64, 64), "shape_override": torch.IntTensor([1, 64, 64])}) + ), + ( + get_arange_img((32, 16)), + affine_scale_2d((0.5, 1.0)), + get_metadata(False, {"spatial_size": (16, 16)}), + get_metadata(False, {"spatial_size": (16, 16), "shape_override": torch.IntTensor([1, 16, 16])}) + ), + # 2d - "longest" + ( + get_arange_img((32, 16)), + affine_scale_2d((0.5)), + get_metadata(False, {"spatial_size": 16, "size_mode": "longest"}), + get_metadata(False, {"spatial_size": 16, "size_mode": "longest", + "shape_override": torch.IntTensor([1, 16, 8])}) + ), + ( + get_arange_img((16, 32)), + affine_scale_2d((0.5)), + get_metadata(False, {"spatial_size": 16, "size_mode": "longest"}), + get_metadata(False, {"spatial_size": 16, "size_mode": "longest", + "shape_override": torch.IntTensor([1, 8, 16])}) + ), + ( + get_arange_img((32, 16)), + affine_scale_2d((2.0)), + get_metadata(False, {"spatial_size": 64, "size_mode": "longest"}), + get_metadata(False, {"spatial_size": 64, "size_mode": "longest", + "shape_override": torch.IntTensor([1, 64, 32])}) + ), + ( + get_arange_img((16, 32)), + affine_scale_2d((2.0)), + get_metadata(False, {"spatial_size": 64, "size_mode": "longest"}), + get_metadata(False, {"spatial_size": 64, "size_mode": "longest", + "shape_override": torch.IntTensor([1, 32, 64])}) + ), + # 3d - "all" + ( + get_arange_img((32, 32, 16)), + affine_scale_3d(0.5), + get_metadata(False, {"spatial_size": (16, 16, 8)}), + get_metadata(False, {"spatial_size": (16, 16, 8), + "shape_override": torch.IntTensor([1, 16, 16, 8])}) + ), + ( + get_arange_img((32, 32, 16)), + affine_scale_3d(2.0), + get_metadata(False, {"spatial_size": (64, 64, 32)}), + get_metadata(False, {"spatial_size": (64, 64, 32), + "shape_override": torch.IntTensor([1, 64, 64, 32])}) + ), + # 3d - "longest" + ( + get_arange_img((32, 16, 8)), + affine_scale_3d((0.5)), + get_metadata(False, {"spatial_size": 16, "size_mode": "longest"}), + get_metadata(False, {"spatial_size": 16, "size_mode": "longest", + "shape_override": torch.IntTensor([1, 16, 8, 4])}) + ), + ( + get_arange_img((16, 32, 8)), + affine_scale_3d((0.5)), + get_metadata(False, {"spatial_size": 16, "size_mode": "longest"}), + get_metadata(False, {"spatial_size": 16, "size_mode": "longest", + "shape_override": torch.IntTensor([1, 8, 16, 4])}) + ), + ( + get_arange_img((8, 16, 32)), + affine_scale_3d((0.5)), + get_metadata(False, {"spatial_size": 16, "size_mode": "longest"}), + get_metadata(False, {"spatial_size": 16, "size_mode": "longest", + "shape_override": torch.IntTensor([1, 4, 8, 16])}) + ), + ( + get_arange_img((32, 16, 8)), + affine_scale_3d((2.0)), + get_metadata(False, 
{"spatial_size": 64, "size_mode": "longest"}), + get_metadata(False, {"spatial_size": 64, "size_mode": "longest", + "shape_override": torch.IntTensor([1, 64, 32, 16])}) + ), + ( + get_arange_img((16, 32, 8)), + affine_scale_3d((2.0)), + get_metadata(False, {"spatial_size": 64, "size_mode": "longest"}), + get_metadata(False, {"spatial_size": 64, "size_mode": "longest", + "shape_override": torch.IntTensor([1, 32, 64, 16])}) + ), + ( + get_arange_img((8, 16, 32)), + affine_scale_3d((2.0)), + get_metadata(False, {"spatial_size": 64, "size_mode": "longest"}), + get_metadata(False, {"spatial_size": 64, "size_mode": "longest", + "shape_override": torch.IntTensor([1, 16, 32, 64])}) + ), + ] + + def _test_functional_resize_impl( + self, img, expected_transform, call_params, expected_metadata + ): + img_, transform_, metadata = resize(img, **call_params) + self.assertTrue(torch.allclose(transform_, expected_transform), + msg=f"{transform_} != {expected_transform}") + actual_keys = set(metadata.keys()) + expected_keys = set(expected_metadata.keys()) + self.assertSetEqual(actual_keys, expected_keys) + for k in actual_keys: + if isinstance(metadata[k], torch.Tensor): + self.assertTrue(torch.allclose(metadata[k], expected_metadata[k]), + msg=f"{metadata[k]} != {expected_metadata[k]} for key {k}") + else: + self.assertEqual(metadata[k], expected_metadata[k], + msg=f"{metadata[k]} != {expected_metadata[k]} for key {k}") + + def test_functional_resize(self): + for icase, case in enumerate(self.SPACING_CASES): + with self.subTest(f"{icase}"): + self._test_functional_resize_impl(*case) \ No newline at end of file diff --git a/tests/test_spacingf.py b/tests/test_spacingf.py new file mode 100644 index 0000000000..5b41532943 --- /dev/null +++ b/tests/test_spacingf.py @@ -0,0 +1,181 @@ +import unittest + +import torch + +from monai.transforms.spatial.functional import spacing +from tests.utils import get_arange_img + + +def affine_scale_2d(scale): + scale0, scale1 = scale if isinstance(scale, (tuple, list)) else (scale, scale) + t = torch.eye(3) + t[:, 0] = torch.FloatTensor([scale0, 0.0, 0.0]) + t[:, 1] = torch.FloatTensor([0.0, scale1, 0.0]) + return t + + +def affine_scale_3d(scale): + scale0, scale1, scale2 = scale if isinstance(scale, (tuple, list)) else (scale, scale, scale) + t = torch.eye(4) + t[:, 0] = torch.FloatTensor([scale0, 0.0, 0.0, 0.0]) + t[:, 1] = torch.FloatTensor([0.0, scale1, 0.0, 0.0]) + t[:, 2] = torch.FloatTensor([0.0, 0.0, scale2, 0.0]) + return t + + +def get_metadata(is_3d=True, overrides=None, remove=None): + metadata = { + "pixdim": (1.0, 1.0, 1.0) if is_3d else (1.0, 1.0), + "src_pixdim": (1.0, 1.0, 1.0) if is_3d else (1.0, 1.0), + "diagonal": False, + "mode": "nearest", + "padding_mode": "zeros", + "align_corners": False, + "dtype": torch.float32, + # "im_extents": None, + # "shape_override": torch.IntTensor([1, 32, 32]) # shape override shouldn't always be in here + } + if overrides is not None: + for k, v in overrides.items(): + metadata[k] = v + if remove is not None: + for k in remove: + if k in metadata: + del metadata[k] + return metadata + + +class TestFunctionalSpacing(unittest.TestCase): + + SPACING_CASES = [ + ( + get_arange_img((32, 32)), + affine_scale_2d(0.5), + get_metadata(False, {"pixdim": (2.0, 2.0)}), + get_metadata(False, {"pixdim": (2.0, 2.0), "shape_override": torch.IntTensor([1, 16, 16])}) + ), + ( + get_arange_img((32, 32)), + affine_scale_2d(0.5), + get_metadata(False, {"src_pixdim": (0.5, 0.5)}), + get_metadata(False, {"src_pixdim": (0.5, 0.5), 
"shape_override": torch.IntTensor([1, 16, 16])}) + ), + ( + get_arange_img((32, 32)), + affine_scale_2d(0.25), + get_metadata(False, {"pixdim": (2.0, 2.0), "src_pixdim": (0.5, 0.5)}), + get_metadata(False, {"pixdim": (2.0, 2.0), "src_pixdim": (0.5, 0.5), + "shape_override": torch.IntTensor([1, 8, 8])}) + ), + ( + get_arange_img((32, 32)), + affine_scale_2d(2.0), + get_metadata(False, {"src_pixdim": (2.0, 2.0)}), + get_metadata(False, {"src_pixdim": (2.0, 2.0), "shape_override": torch.IntTensor([1, 64, 64])}) + ), + ( + get_arange_img((32, 32)), + affine_scale_2d(2.0), + get_metadata(False, {"pixdim": (0.5, 0.5)}), + get_metadata(False, {"pixdim": (0.5, 0.5), "shape_override": torch.IntTensor([1, 64, 64])}) + ), + ( + get_arange_img((32, 32)), + affine_scale_2d(4.0), + get_metadata(False, {"pixdim": (0.5, 0.5), "src_pixdim": (2.0, 2.0)}), + get_metadata(False, {"pixdim": (0.5, 0.5), "src_pixdim": (2.0, 2.0), + "shape_override": torch.IntTensor([1, 128, 128])}) + ), + ( + get_arange_img((32, 32)), + affine_scale_2d((0.5, 2.0)), + get_metadata(False, {"pixdim": (2.0, 1.0), "src_pixdim": (1.0, 2.0)}), + get_metadata(False, {"pixdim": (2.0, 1.0), "src_pixdim": (1.0, 2.0), + "shape_override": torch.IntTensor([1, 16, 64])}) + ), + ( + get_arange_img((32, 32)), + affine_scale_2d((2.0, 0.5)), + get_metadata(False, {"pixdim": (1.0, 2.0), "src_pixdim": (2.0, 1.0)}), + get_metadata(False, {"pixdim": (1.0, 2.0), "src_pixdim": (2.0, 1.0), + "shape_override": torch.IntTensor([1, 64, 16])}) + ), + ( + get_arange_img((32, 32, 24)), + affine_scale_3d(0.5), + get_metadata(True, {"pixdim": (2.0, 2.0, 2.0)}), + get_metadata(True, {"pixdim": (2.0, 2.0, 2.0), + "shape_override": torch.IntTensor([1, 16, 16, 12])}) + ), + ( + get_arange_img((32, 32, 24)), + affine_scale_3d(0.5), + get_metadata(True, {"src_pixdim": (0.5, 0.5, 0.5)}), + get_metadata(True, {"src_pixdim": (0.5, 0.5, 0.5), + "shape_override": torch.IntTensor([1, 16, 16, 12])}) + ), + ( + get_arange_img((32, 32, 24)), + affine_scale_3d(0.25), + get_metadata(True, {"pixdim": (2.0, 2.0, 2.0), "src_pixdim": (0.5, 0.5, 0.5)}), + get_metadata(True, {"pixdim": (2.0, 2.0, 2.0), "src_pixdim": (0.5, 0.5, 0.5), + "shape_override": torch.IntTensor([1, 8, 8, 6])}) + ), + ( + get_arange_img((32, 32, 24)), + affine_scale_3d(2.0), + get_metadata(True, {"src_pixdim": (2.0, 2.0, 2.0)}), + get_metadata(True, {"src_pixdim": (2.0, 2.0, 2.0), + "shape_override": torch.IntTensor([1, 64, 64, 48])}) + ), + ( + get_arange_img((32, 32, 24)), + affine_scale_3d(2.0), + get_metadata(True, {"pixdim": (0.5, 0.5, 0.5)}), + get_metadata(True, {"pixdim": (0.5, 0.5, 0.5), + "shape_override": torch.IntTensor([1, 64, 64, 48])}) + ), + ( + get_arange_img((32, 32, 24)), + affine_scale_3d(4.0), + get_metadata(True, {"pixdim": (0.5, 0.5, 0.5), "src_pixdim": (2.0, 2.0, 2.0)}), + get_metadata(True, {"pixdim": (0.5, 0.5, 0.5), "src_pixdim": (2.0, 2.0, 2.0), + "shape_override": torch.IntTensor([1, 128, 128, 96])}) + ), + ( + get_arange_img((32, 32, 24)), + affine_scale_3d((0.5, 2.0, 1/3.0)), + get_metadata(True, {"pixdim": (2.0, 1.0, 4.5), "src_pixdim": (1.0, 2.0, 1.5)}), + get_metadata(True, {"pixdim": (2.0, 1.0, 4.5), "src_pixdim": (1.0, 2.0, 1.5), + "shape_override": torch.IntTensor([1, 16, 64, 8])}) + ), + ( + get_arange_img((32, 32, 24)), + affine_scale_3d((2.0, 0.5, 3.0)), + get_metadata(True, {"pixdim": (1.0, 2.0, 1.5), "src_pixdim": (2.0, 1.0, 4.5)}), + get_metadata(True, {"pixdim": (1.0, 2.0, 1.5), "src_pixdim": (2.0, 1.0, 4.5), + "shape_override": torch.IntTensor([1, 64, 16, 72])}) + ), 
+ ] + + def _test_functional_spacing_impl( + self, img, expected_transform, call_params, expected_metadata + ): + img_, transform_, metadata = spacing(img, **call_params) + self.assertTrue(torch.allclose(transform_, expected_transform), + msg=f"{transform_} != {expected_transform}") + actual_keys = set(metadata.keys()) + expected_keys = set(expected_metadata.keys()) + self.assertSetEqual(actual_keys, expected_keys) + for k in actual_keys: + if isinstance(metadata[k], torch.Tensor): + self.assertTrue(torch.allclose(metadata[k], expected_metadata[k]), + msg=f"{metadata[k]} != {expected_metadata[k]} for key {k}") + else: + self.assertEqual(metadata[k], expected_metadata[k], + msg=f"{metadata[k]} != {expected_metadata[k]} for key {k}") + + def test_functional_spacing(self): + for icase, case in enumerate(self.SPACING_CASES): + with self.subTest(f"{icase}"): + self._test_functional_spacing_impl(*case) From 61dbe1104382dc46b1d1cb6f90299bdf4fa1d51e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 30 Oct 2022 11:16:13 +0000 Subject: [PATCH 03/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/lazy/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/lazy/array.py b/monai/transforms/lazy/array.py index e271d07703..ae165cf566 100644 --- a/monai/transforms/lazy/array.py +++ b/monai/transforms/lazy/array.py @@ -45,4 +45,4 @@ def inverse(self, data): # rd[k] = apply(v) # # def inverse(self, data): -# return NotImplementedError() \ No newline at end of file +# return NotImplementedError() From 9c11d4d466bfd50e24158de2f54ce37bef0bbc3f Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Tue, 1 Nov 2022 17:17:10 +0000 Subject: [PATCH 04/19] Adding additional spatial functional unit tests Signed-off-by: Ben Murray --- monai/transforms/lazy/functional.py | 14 ++-- monai/transforms/meta_matrix.py | 7 ++ monai/transforms/spatial/functional.py | 22 +++-- tests/test_resizef.py | 30 +++---- tests/test_rotatef.py | 107 +++++++++++++++++++++++++ tests/test_spacingf.py | 32 ++++---- 6 files changed, 167 insertions(+), 45 deletions(-) create mode 100644 tests/test_rotatef.py diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index 50192514af..3ed005ec1d 100644 --- a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -47,19 +47,19 @@ def shape_from_extents( if isinstance(extents, (list, tuple)): if isinstance(extents[0], np.ndarray): aextents = np.asarray(extents) - aextents = torch.from_numpy(aextents) else: aextents = torch.stack(extents) + aextents = aextents.numpy() else: if isinstance(extents, np.ndarray): - aextents = torch.from_numpy(extents) - else: aextents = extents + else: + aextents = extents.numpy() - mins = aextents.min(axis=0)[0] - maxes = aextents.max(axis=0)[0] - values = torch.round(maxes - mins).type(torch.IntTensor)[:-1] - return torch.cat((torch.IntTensor([src_shape[0]]), values)) + mins = aextents.min(axis=0) + maxes = aextents.max(axis=0) + values = np.round(maxes - mins).astype(int)[:-1].tolist() + return (src_shape[0],) + tuple(values) def metadata_is_compatible(value_1, value_2): diff --git a/monai/transforms/meta_matrix.py b/monai/transforms/meta_matrix.py index c8df87c1e6..dc6a8c5902 100644 --- a/monai/transforms/meta_matrix.py +++ b/monai/transforms/meta_matrix.py @@ -112,6 +112,13 @@ def ensure_tensor(data: NdarrayOrTensor): return 
torch.as_tensor(data) +def apply_align_corners(matrix, spatial_size, factory): + inflated_spatial_size = tuple(s + 1 for s in spatial_size) + scale_factors = tuple(s / i for s, i in zip(spatial_size, inflated_spatial_size)) + scale_mat = factory.scale(scale_factors) + return matmul(scale_mat, matrix) + + class Matrix: def __init__(self, matrix: NdarrayOrTensor): self.data = ensure_tensor(matrix) diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py index cf121d4465..9bac7cb521 100644 --- a/monai/transforms/spatial/functional.py +++ b/monai/transforms/spatial/functional.py @@ -22,7 +22,7 @@ from monai.config import DtypeLike from monai.data import get_track_meta from monai.transforms.lazy.functional import extents_from_shape, shape_from_extents -from monai.transforms.meta_matrix import MatrixFactory +from monai.transforms.meta_matrix import MatrixFactory, apply_align_corners from monai.utils import ( convert_to_tensor, ensure_tuple, @@ -289,22 +289,28 @@ def rotate( raise ValueError(f"Unsupported image dimension: {input_ndim}, available options are [2, 3].") angle_ = ensure_tuple_rep(angle, 1 if input_ndim == 2 else 3) - rotate_tx = torch.from_numpy(create_rotate(input_ndim, angle_)) + rotate_tx = torch.from_numpy(create_rotate(input_ndim, angle_).astype(np.float32)) im_extents = extents_from_shape(input_shape) if not keep_size: im_extents = [rotate_tx @ e for e in im_extents] spatial_shape = shape_from_extents(input_shape, im_extents) else: spatial_shape = input_shape - transform = rotate_tx + + if align_corners is True: + transform = apply_align_corners(rotate_tx, spatial_shape[1:], + MatrixFactory.from_tensor(img_)).matrix.data + else: + transform = rotate_tx + metadata = { - "angle": angle_, + "angle": angle, "keep_size": keep_size, "mode": mode_, "padding_mode": padding_mode_, "align_corners": align_corners, "dtype": dtype_, - "im_extents": im_extents, + # "im_extents": im_extents, "shape_override": spatial_shape } return img_, transform, metadata @@ -362,9 +368,11 @@ def zoom( "align_corners": align_corners, "keep_size": keep_size, "dtype": dtype_, - "im_extents": im_extents, - "shape_override": shape_override_ + "im_extents": im_extents } + if keep_size is False or align_corners is True: + metadata["shape_override"] = shape_override_ + return img_, transform, metadata diff --git a/tests/test_resizef.py b/tests/test_resizef.py index eb32e70e3b..775098674a 100644 --- a/tests/test_resizef.py +++ b/tests/test_resizef.py @@ -52,19 +52,19 @@ class TestFunctionalSpacing(unittest.TestCase): get_arange_img((32, 32)), affine_scale_2d(0.5), get_metadata(False, {"spatial_size": (16, 16)}), - get_metadata(False, {"spatial_size": (16, 16), "shape_override": torch.IntTensor([1, 16, 16])}) + get_metadata(False, {"spatial_size": (16, 16), "shape_override": (1, 16, 16)}) ), ( get_arange_img((32, 32)), affine_scale_2d(2.0), get_metadata(False, {"spatial_size": (64, 64)}), - get_metadata(False, {"spatial_size": (64, 64), "shape_override": torch.IntTensor([1, 64, 64])}) + get_metadata(False, {"spatial_size": (64, 64), "shape_override": (1, 64, 64)}) ), ( get_arange_img((32, 16)), affine_scale_2d((0.5, 1.0)), get_metadata(False, {"spatial_size": (16, 16)}), - get_metadata(False, {"spatial_size": (16, 16), "shape_override": torch.IntTensor([1, 16, 16])}) + get_metadata(False, {"spatial_size": (16, 16), "shape_override": (1, 16, 16)}) ), # 2d - "longest" ( @@ -72,28 +72,28 @@ class TestFunctionalSpacing(unittest.TestCase): affine_scale_2d((0.5)), get_metadata(False, 
{"spatial_size": 16, "size_mode": "longest"}), get_metadata(False, {"spatial_size": 16, "size_mode": "longest", - "shape_override": torch.IntTensor([1, 16, 8])}) + "shape_override": (1, 16, 8)}) ), ( get_arange_img((16, 32)), affine_scale_2d((0.5)), get_metadata(False, {"spatial_size": 16, "size_mode": "longest"}), get_metadata(False, {"spatial_size": 16, "size_mode": "longest", - "shape_override": torch.IntTensor([1, 8, 16])}) + "shape_override": (1, 8, 16)}) ), ( get_arange_img((32, 16)), affine_scale_2d((2.0)), get_metadata(False, {"spatial_size": 64, "size_mode": "longest"}), get_metadata(False, {"spatial_size": 64, "size_mode": "longest", - "shape_override": torch.IntTensor([1, 64, 32])}) + "shape_override": (1, 64, 32)}) ), ( get_arange_img((16, 32)), affine_scale_2d((2.0)), get_metadata(False, {"spatial_size": 64, "size_mode": "longest"}), get_metadata(False, {"spatial_size": 64, "size_mode": "longest", - "shape_override": torch.IntTensor([1, 32, 64])}) + "shape_override": (1, 32, 64)}) ), # 3d - "all" ( @@ -101,14 +101,14 @@ class TestFunctionalSpacing(unittest.TestCase): affine_scale_3d(0.5), get_metadata(False, {"spatial_size": (16, 16, 8)}), get_metadata(False, {"spatial_size": (16, 16, 8), - "shape_override": torch.IntTensor([1, 16, 16, 8])}) + "shape_override": (1, 16, 16, 8)}) ), ( get_arange_img((32, 32, 16)), affine_scale_3d(2.0), get_metadata(False, {"spatial_size": (64, 64, 32)}), get_metadata(False, {"spatial_size": (64, 64, 32), - "shape_override": torch.IntTensor([1, 64, 64, 32])}) + "shape_override": (1, 64, 64, 32)}) ), # 3d - "longest" ( @@ -116,42 +116,42 @@ class TestFunctionalSpacing(unittest.TestCase): affine_scale_3d((0.5)), get_metadata(False, {"spatial_size": 16, "size_mode": "longest"}), get_metadata(False, {"spatial_size": 16, "size_mode": "longest", - "shape_override": torch.IntTensor([1, 16, 8, 4])}) + "shape_override": (1, 16, 8, 4)}) ), ( get_arange_img((16, 32, 8)), affine_scale_3d((0.5)), get_metadata(False, {"spatial_size": 16, "size_mode": "longest"}), get_metadata(False, {"spatial_size": 16, "size_mode": "longest", - "shape_override": torch.IntTensor([1, 8, 16, 4])}) + "shape_override": (1, 8, 16, 4)}) ), ( get_arange_img((8, 16, 32)), affine_scale_3d((0.5)), get_metadata(False, {"spatial_size": 16, "size_mode": "longest"}), get_metadata(False, {"spatial_size": 16, "size_mode": "longest", - "shape_override": torch.IntTensor([1, 4, 8, 16])}) + "shape_override": (1, 4, 8, 16)}) ), ( get_arange_img((32, 16, 8)), affine_scale_3d((2.0)), get_metadata(False, {"spatial_size": 64, "size_mode": "longest"}), get_metadata(False, {"spatial_size": 64, "size_mode": "longest", - "shape_override": torch.IntTensor([1, 64, 32, 16])}) + "shape_override": (1, 64, 32, 16)}) ), ( get_arange_img((16, 32, 8)), affine_scale_3d((2.0)), get_metadata(False, {"spatial_size": 64, "size_mode": "longest"}), get_metadata(False, {"spatial_size": 64, "size_mode": "longest", - "shape_override": torch.IntTensor([1, 32, 64, 16])}) + "shape_override": (1, 32, 64, 16)}) ), ( get_arange_img((8, 16, 32)), affine_scale_3d((2.0)), get_metadata(False, {"spatial_size": 64, "size_mode": "longest"}), get_metadata(False, {"spatial_size": 64, "size_mode": "longest", - "shape_override": torch.IntTensor([1, 16, 32, 64])}) + "shape_override": (1, 16, 32, 64)}) ), ] diff --git a/tests/test_rotatef.py b/tests/test_rotatef.py new file mode 100644 index 0000000000..ebe7c8b520 --- /dev/null +++ b/tests/test_rotatef.py @@ -0,0 +1,107 @@ +import unittest + +import math + +import torch + +from 
monai.transforms.spatial.functional import rotate +from tests.utils import get_arange_img + + +def affine_rotate_2d(radians, scale=None): + t = torch.eye(3) + t[:, 0] = torch.FloatTensor([math.cos(radians), math.sin(radians), 0.0]) + t[:, 1] = torch.FloatTensor([-math.sin(radians), math.cos(radians), 0.0]) + if scale is not None: + t[0, 0] *= scale + t[0, 1] *= scale + t[1, 0] *= scale + t[1, 1] *= scale + return t + + +def affine_scale_3d(scale): + scale0, scale1, scale2 = scale if isinstance(scale, (tuple, list)) else (scale, scale, scale) + t = torch.eye(4) + t[:, 0] = torch.FloatTensor([scale0, 0.0, 0.0, 0.0]) + t[:, 1] = torch.FloatTensor([0.0, scale1, 0.0, 0.0]) + t[:, 2] = torch.FloatTensor([0.0, 0.0, scale2, 0.0]) + return t + + +def get_metadata(is_3d=True, overrides=None, remove=None): + metadata = { + "angle": 0.0, + "keep_size": True, + "mode": "nearest", + "padding_mode": "zeros", + "align_corners": False, + "dtype": torch.float32, + # "im_extents": None, + # "shape_override": torch.IntTensor([1, 32, 32]) # shape override shouldn't always be in here + } + if overrides is not None: + for k, v in overrides.items(): + metadata[k] = v + if remove is not None: + for k in remove: + if k in metadata: + del metadata[k] + return metadata + + +class TestFunctionalRotate(unittest.TestCase): + + ROTATE_CASES = [ + # keep_size = True + ( + get_arange_img((32, 32)), + affine_rotate_2d(torch.pi / 4), + get_metadata(False, {"angle": torch.pi / 4}), + get_metadata(False, {"angle": torch.pi / 4, "shape_override": (1, 32, 32)}) + ), + ( + get_arange_img((32, 32)), + affine_rotate_2d(torch.pi / 4, 32/33), + get_metadata(False, {"angle": torch.pi / 4, "align_corners": True}), + get_metadata(False, {"angle": torch.pi / 4, "align_corners": True, + "shape_override": (1, 32, 32)}) + ), + # keep_size = False + ( + get_arange_img((32, 32)), + affine_rotate_2d(torch.pi / 4), + get_metadata(False, {"angle": torch.pi / 4, "keep_size": False}), + get_metadata(False, {"angle": torch.pi / 4, "keep_size": False, + "shape_override": (1, 45, 45)}) + ), + ( + get_arange_img((32, 32)), + affine_rotate_2d(torch.pi / 4, 45/46), + get_metadata(False, {"angle": torch.pi / 4, "keep_size": False, "align_corners": True}), + get_metadata(False, {"angle": torch.pi / 4, "keep_size": False, "align_corners": True, + "shape_override": (1, 45, 45)}) + ), + ] + + def _test_functional_rotate_impl( + self, img, expected_transform, call_params, expected_metadata + ): + img_, transform_, metadata = rotate(img, **call_params) + self.assertTrue(torch.allclose(transform_, expected_transform), + msg=f"{transform_} != {expected_transform}") + actual_keys = set(metadata.keys()) + expected_keys = set(expected_metadata.keys()) + self.assertSetEqual(actual_keys, expected_keys) + for k in actual_keys: + if isinstance(metadata[k], torch.Tensor): + self.assertTrue(torch.allclose(metadata[k], expected_metadata[k]), + msg=f"{metadata[k]} != {expected_metadata[k]} for key {k}") + else: + self.assertEqual(metadata[k], expected_metadata[k], + msg=f"{metadata[k]} != {expected_metadata[k]} for key {k}") + + def test_functional_rotate(self): + for icase, case in enumerate(self.ROTATE_CASES): + with self.subTest(f"{icase}"): + self._test_functional_rotate_impl(*case) diff --git a/tests/test_spacingf.py b/tests/test_spacingf.py index 5b41532943..0fcc1b43d6 100644 --- a/tests/test_spacingf.py +++ b/tests/test_spacingf.py @@ -52,109 +52,109 @@ class TestFunctionalSpacing(unittest.TestCase): get_arange_img((32, 32)), affine_scale_2d(0.5), 
get_metadata(False, {"pixdim": (2.0, 2.0)}), - get_metadata(False, {"pixdim": (2.0, 2.0), "shape_override": torch.IntTensor([1, 16, 16])}) + get_metadata(False, {"pixdim": (2.0, 2.0), "shape_override": (1, 16, 16)}) ), ( get_arange_img((32, 32)), affine_scale_2d(0.5), get_metadata(False, {"src_pixdim": (0.5, 0.5)}), - get_metadata(False, {"src_pixdim": (0.5, 0.5), "shape_override": torch.IntTensor([1, 16, 16])}) + get_metadata(False, {"src_pixdim": (0.5, 0.5), "shape_override": (1, 16, 16)}) ), ( get_arange_img((32, 32)), affine_scale_2d(0.25), get_metadata(False, {"pixdim": (2.0, 2.0), "src_pixdim": (0.5, 0.5)}), get_metadata(False, {"pixdim": (2.0, 2.0), "src_pixdim": (0.5, 0.5), - "shape_override": torch.IntTensor([1, 8, 8])}) + "shape_override": (1, 8, 8)}) ), ( get_arange_img((32, 32)), affine_scale_2d(2.0), get_metadata(False, {"src_pixdim": (2.0, 2.0)}), - get_metadata(False, {"src_pixdim": (2.0, 2.0), "shape_override": torch.IntTensor([1, 64, 64])}) + get_metadata(False, {"src_pixdim": (2.0, 2.0), "shape_override": (1, 64, 64)}) ), ( get_arange_img((32, 32)), affine_scale_2d(2.0), get_metadata(False, {"pixdim": (0.5, 0.5)}), - get_metadata(False, {"pixdim": (0.5, 0.5), "shape_override": torch.IntTensor([1, 64, 64])}) + get_metadata(False, {"pixdim": (0.5, 0.5), "shape_override": (1, 64, 64)}) ), ( get_arange_img((32, 32)), affine_scale_2d(4.0), get_metadata(False, {"pixdim": (0.5, 0.5), "src_pixdim": (2.0, 2.0)}), get_metadata(False, {"pixdim": (0.5, 0.5), "src_pixdim": (2.0, 2.0), - "shape_override": torch.IntTensor([1, 128, 128])}) + "shape_override": (1, 128, 128)}) ), ( get_arange_img((32, 32)), affine_scale_2d((0.5, 2.0)), get_metadata(False, {"pixdim": (2.0, 1.0), "src_pixdim": (1.0, 2.0)}), get_metadata(False, {"pixdim": (2.0, 1.0), "src_pixdim": (1.0, 2.0), - "shape_override": torch.IntTensor([1, 16, 64])}) + "shape_override": (1, 16, 64)}) ), ( get_arange_img((32, 32)), affine_scale_2d((2.0, 0.5)), get_metadata(False, {"pixdim": (1.0, 2.0), "src_pixdim": (2.0, 1.0)}), get_metadata(False, {"pixdim": (1.0, 2.0), "src_pixdim": (2.0, 1.0), - "shape_override": torch.IntTensor([1, 64, 16])}) + "shape_override": (1, 64, 16)}) ), ( get_arange_img((32, 32, 24)), affine_scale_3d(0.5), get_metadata(True, {"pixdim": (2.0, 2.0, 2.0)}), get_metadata(True, {"pixdim": (2.0, 2.0, 2.0), - "shape_override": torch.IntTensor([1, 16, 16, 12])}) + "shape_override": (1, 16, 16, 12)}) ), ( get_arange_img((32, 32, 24)), affine_scale_3d(0.5), get_metadata(True, {"src_pixdim": (0.5, 0.5, 0.5)}), get_metadata(True, {"src_pixdim": (0.5, 0.5, 0.5), - "shape_override": torch.IntTensor([1, 16, 16, 12])}) + "shape_override": (1, 16, 16, 12)}) ), ( get_arange_img((32, 32, 24)), affine_scale_3d(0.25), get_metadata(True, {"pixdim": (2.0, 2.0, 2.0), "src_pixdim": (0.5, 0.5, 0.5)}), get_metadata(True, {"pixdim": (2.0, 2.0, 2.0), "src_pixdim": (0.5, 0.5, 0.5), - "shape_override": torch.IntTensor([1, 8, 8, 6])}) + "shape_override": (1, 8, 8, 6)}) ), ( get_arange_img((32, 32, 24)), affine_scale_3d(2.0), get_metadata(True, {"src_pixdim": (2.0, 2.0, 2.0)}), get_metadata(True, {"src_pixdim": (2.0, 2.0, 2.0), - "shape_override": torch.IntTensor([1, 64, 64, 48])}) + "shape_override": (1, 64, 64, 48)}) ), ( get_arange_img((32, 32, 24)), affine_scale_3d(2.0), get_metadata(True, {"pixdim": (0.5, 0.5, 0.5)}), get_metadata(True, {"pixdim": (0.5, 0.5, 0.5), - "shape_override": torch.IntTensor([1, 64, 64, 48])}) + "shape_override": (1, 64, 64, 48)}) ), ( get_arange_img((32, 32, 24)), affine_scale_3d(4.0), get_metadata(True, 
{"pixdim": (0.5, 0.5, 0.5), "src_pixdim": (2.0, 2.0, 2.0)}), get_metadata(True, {"pixdim": (0.5, 0.5, 0.5), "src_pixdim": (2.0, 2.0, 2.0), - "shape_override": torch.IntTensor([1, 128, 128, 96])}) + "shape_override": (1, 128, 128, 96)}) ), ( get_arange_img((32, 32, 24)), affine_scale_3d((0.5, 2.0, 1/3.0)), get_metadata(True, {"pixdim": (2.0, 1.0, 4.5), "src_pixdim": (1.0, 2.0, 1.5)}), get_metadata(True, {"pixdim": (2.0, 1.0, 4.5), "src_pixdim": (1.0, 2.0, 1.5), - "shape_override": torch.IntTensor([1, 16, 64, 8])}) + "shape_override": (1, 16, 64, 8)}) ), ( get_arange_img((32, 32, 24)), affine_scale_3d((2.0, 0.5, 3.0)), get_metadata(True, {"pixdim": (1.0, 2.0, 1.5), "src_pixdim": (2.0, 1.0, 4.5)}), get_metadata(True, {"pixdim": (1.0, 2.0, 1.5), "src_pixdim": (2.0, 1.0, 4.5), - "shape_override": torch.IntTensor([1, 64, 16, 72])}) + "shape_override": (1, 64, 16, 72)}) ), ] From 77d7c315410d54112992a882651aba6d5c472148 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Tue, 1 Nov 2022 23:41:07 +0000 Subject: [PATCH 05/19] Work on apply and matmul tests Signed-off-by: Ben Murray --- monai/data/meta_obj.py | 3 ++ monai/transforms/lazy/functional.py | 10 +++--- tests/test_apply.py | 55 +++++++++++++++-------------- tests/test_matmul.py | 30 +++++++++++++++- 4 files changed, 66 insertions(+), 32 deletions(-) diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index 6aab05dc94..f804fc2e3d 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -213,6 +213,9 @@ def push_pending_operation(self, t: Any) -> None: def pop_pending_operation(self) -> Any: return self._pending_operations.pop() + def clear_pending_operations(self) -> Any: + self._pending_operations = list() + @property def is_batch(self) -> bool: """Return whether object is part of batch or not.""" diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index 3ed005ec1d..02e44150ec 100644 --- a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -132,7 +132,7 @@ def apply(data: Union[torch.Tensor, MetaTensor, dict], pending: Optional[Union[d return rd if isinstance(data, MetaTensor) and pending is None: - pending_ = data.pending_transforms + pending_ = data.pending_operations else: pending_ = [] if pending is None else pending @@ -174,7 +174,7 @@ def apply(data: Union[torch.Tensor, MetaTensor, dict], pending: Optional[Union[d kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype) cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix) - a = Affine(norm_coords=False, affine=cumulative_matrix_, **kwargs) + a = Affine(affine=cumulative_matrix_, **kwargs) data, _ = a(img=data) cumulative_matrix, cumulative_extents = starting_matrix_and_extents(matrix_factory, data) @@ -190,12 +190,12 @@ def apply(data: Union[torch.Tensor, MetaTensor, dict], pending: Optional[Union[d cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix) # print(f"applying with cumulative matrix\n {cumulative_matrix_}") - a = Affine(norm_coords=False, affine=cumulative_matrix_, spatial_size=cur_shape[1:], normalized=False, **kwargs) + a = Affine(affine=cumulative_matrix_, spatial_size=cur_shape[1:], normalized=False, **kwargs) data, tx = a(img=data) if isinstance(data, MetaTensor): - data.clear_pending_transforms() + data.clear_pending_operations() for p in pending_: data.affine = p.matrix.data data.push_applied_operation(p) - return data, pending_ + return data, None if pending is None else pending_ diff --git a/tests/test_apply.py 
b/tests/test_apply.py index 9cd9aee462..2a5d5aa2a7 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -12,17 +12,30 @@ import unittest import torch -from monai.utils import convert_to_tensor +from monai.utils import convert_to_tensor, TransformBackends from monai.transforms.lazy.functional import apply -from monai.transforms.meta_matrix import MetaMatrix - +from monai.transforms.meta_matrix import MetaMatrix, MatrixFactory + + +def single_2d_transform_cases(): + f = MatrixFactory(2, TransformBackends.TORCH, "cpu") + + cases = [ + ( + torch.randn((1, 32, 32)), + [MetaMatrix(f.rotate_euler(torch.pi / 4).matrix, {"id": "rotate"})], + (1, 32, 32) + ), + ( + torch.randn((1, 16, 16)), + [MetaMatrix(f.rotate_euler(torch.pi / 4).matrix, + {"id": "rotate", "shape_override": (1, 45, 45)})], + (1, 45, 45) + ) + ] -def rotate_45_2d(): - t = torch.eye(3) - t[:, 0] = torch.FloatTensor([0, -1, 0]) - t[:, 1] = torch.FloatTensor([1, 0, 0]) - return t + return cases class TestApply(unittest.TestCase): @@ -32,29 +45,19 @@ def _test_apply_impl(self, tensor, pending_transforms): result = apply(tensor, pending_transforms) self.assertListEqual(result[1], pending_transforms) - def _test_apply_metatensor_impl(self, tensor, pending_transforms, pending_as_parameter): - tensor_ = convert_to_tensor(tensor) + def _test_apply_metatensor_impl(self, tensor, pending_transforms, expected_shape, pending_as_parameter): + tensor_ = convert_to_tensor(tensor, track_meta=True) if pending_as_parameter: - result = apply(tensor_, pending_transforms) + result, transforms = apply(tensor_, pending_transforms) else: for p in pending_transforms: - # TODO: cannot do the next part until the MetaTensor PR #5107 is in - # tensor_.push_pending(p) - raise NotImplementedError() + tensor_.push_pending_operation(p) + result, transforms = apply(tensor_) - def _test_apply_metatensor_impl(self, tensor, pending_transforms, pending_as_parameter): - tensor_ = convert_to_tensor(tensor) - if pending_as_parameter: - result = apply(tensor_, pending_transforms) - else: - for p in pending_transforms: - # TODO: cannot do the next part until the MetaTensor PR #5107 is in - # tensor_.push_pending(p) - raise NotImplementedError() + self.assertEqual(result.shape, expected_shape) - SINGLE_TRANSFORM_CASES = [ - (torch.randn((1, 16, 16)), [MetaMatrix(rotate_45_2d(), {"id": "rotate"})]) - ] + + SINGLE_TRANSFORM_CASES = single_2d_transform_cases() def test_apply_single_transform(self): for case in self.SINGLE_TRANSFORM_CASES: diff --git a/tests/test_matmul.py b/tests/test_matmul.py index 2ac3b70a2e..813f623423 100644 --- a/tests/test_matmul.py +++ b/tests/test_matmul.py @@ -11,13 +11,15 @@ import unittest +from parameterized import parameterized + import numpy as np import torch -from parameterized import parameterized from monai.transforms.meta_matrix import ( Grid, Matrix, + MatrixFactory, is_grid_shaped, is_matrix_shaped, matmul, @@ -25,6 +27,20 @@ matmul_grid_matrix_slow, matmul_matrix_grid, ) +from monai.utils import TransformBackends + + +def get_matmul_2d_test_cases(): + f = MatrixFactory(2, TransformBackends.TORCH, "cpu") + cases = [ + ( + f.rotate_euler(torch.pi / 4), + f.scale((0.5, 0.5)), + torch.FloatTensor([[0.35355339, -0.35355339, 0], [0.35355339, 0.35355339, 0], [0, 0, 1]]) + ) + ] + + return cases class TestMatmulFunctions(unittest.TestCase): @@ -139,5 +155,17 @@ def test_is_grid_shaped(self, grid, expected): # self._test_is_grid_shaped_impl(*case) + def _test_matmul_outputs_impl(self, left, right, expected): + actual = matmul(left, right) 
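+        # `matmul` composes the two MetaMatrix operands into a single transform;
+        # compare its homogeneous matrix against the hand-computed product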
+ self.assertTrue(torch.allclose(actual.matrix.data, expected), + msg=f"{actual.matrix.data} is not close to {expected}") + + def test_matmul_outputs(self): + cases = get_matmul_2d_test_cases() + for case in cases: + self._test_matmul_outputs_impl(*case) + + + if __name__ == "__main__": unittest.main() From 828a7d8727b41055e9fbdd43645a82bd8fa9e6d7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 1 Nov 2022 12:52:00 +0000 Subject: [PATCH 06/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/spatial/functional.py | 2 +- tests/test_flipf.py | 2 +- tests/test_resizef.py | 22 +++++++++++----------- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py index 9bac7cb521..abf8312415 100644 --- a/monai/transforms/spatial/functional.py +++ b/monai/transforms/spatial/functional.py @@ -18,7 +18,7 @@ from monai.networks.utils import meshgrid_ij -from monai.transforms import create_grid, create_rotate, create_translate, map_spatial_axes +from monai.transforms import create_grid, create_rotate from monai.config import DtypeLike from monai.data import get_track_meta from monai.transforms.lazy.functional import extents_from_shape, shape_from_extents diff --git a/tests/test_flipf.py b/tests/test_flipf.py index ad1b364e94..58e5b34bb6 100644 --- a/tests/test_flipf.py +++ b/tests/test_flipf.py @@ -134,4 +134,4 @@ def _test_functional_flip_impl( def test_functional_flip(self): for icase, case in enumerate(self.FLIP_CASES): with self.subTest(f"{icase}"): - self._test_functional_flip_impl(*case) \ No newline at end of file + self._test_functional_flip_impl(*case) diff --git a/tests/test_resizef.py b/tests/test_resizef.py index 775098674a..cd69b5ea58 100644 --- a/tests/test_resizef.py +++ b/tests/test_resizef.py @@ -69,28 +69,28 @@ class TestFunctionalSpacing(unittest.TestCase): # 2d - "longest" ( get_arange_img((32, 16)), - affine_scale_2d((0.5)), + affine_scale_2d(0.5), get_metadata(False, {"spatial_size": 16, "size_mode": "longest"}), get_metadata(False, {"spatial_size": 16, "size_mode": "longest", "shape_override": (1, 16, 8)}) ), ( get_arange_img((16, 32)), - affine_scale_2d((0.5)), + affine_scale_2d(0.5), get_metadata(False, {"spatial_size": 16, "size_mode": "longest"}), get_metadata(False, {"spatial_size": 16, "size_mode": "longest", "shape_override": (1, 8, 16)}) ), ( get_arange_img((32, 16)), - affine_scale_2d((2.0)), + affine_scale_2d(2.0), get_metadata(False, {"spatial_size": 64, "size_mode": "longest"}), get_metadata(False, {"spatial_size": 64, "size_mode": "longest", "shape_override": (1, 64, 32)}) ), ( get_arange_img((16, 32)), - affine_scale_2d((2.0)), + affine_scale_2d(2.0), get_metadata(False, {"spatial_size": 64, "size_mode": "longest"}), get_metadata(False, {"spatial_size": 64, "size_mode": "longest", "shape_override": (1, 32, 64)}) @@ -113,42 +113,42 @@ class TestFunctionalSpacing(unittest.TestCase): # 3d - "longest" ( get_arange_img((32, 16, 8)), - affine_scale_3d((0.5)), + affine_scale_3d(0.5), get_metadata(False, {"spatial_size": 16, "size_mode": "longest"}), get_metadata(False, {"spatial_size": 16, "size_mode": "longest", "shape_override": (1, 16, 8, 4)}) ), ( get_arange_img((16, 32, 8)), - affine_scale_3d((0.5)), + affine_scale_3d(0.5), get_metadata(False, {"spatial_size": 16, "size_mode": "longest"}), get_metadata(False, {"spatial_size": 16, "size_mode": 
"longest", "shape_override": (1, 8, 16, 4)}) ), ( get_arange_img((8, 16, 32)), - affine_scale_3d((0.5)), + affine_scale_3d(0.5), get_metadata(False, {"spatial_size": 16, "size_mode": "longest"}), get_metadata(False, {"spatial_size": 16, "size_mode": "longest", "shape_override": (1, 4, 8, 16)}) ), ( get_arange_img((32, 16, 8)), - affine_scale_3d((2.0)), + affine_scale_3d(2.0), get_metadata(False, {"spatial_size": 64, "size_mode": "longest"}), get_metadata(False, {"spatial_size": 64, "size_mode": "longest", "shape_override": (1, 64, 32, 16)}) ), ( get_arange_img((16, 32, 8)), - affine_scale_3d((2.0)), + affine_scale_3d(2.0), get_metadata(False, {"spatial_size": 64, "size_mode": "longest"}), get_metadata(False, {"spatial_size": 64, "size_mode": "longest", "shape_override": (1, 32, 64, 16)}) ), ( get_arange_img((8, 16, 32)), - affine_scale_3d((2.0)), + affine_scale_3d(2.0), get_metadata(False, {"spatial_size": 64, "size_mode": "longest"}), get_metadata(False, {"spatial_size": 64, "size_mode": "longest", "shape_override": (1, 16, 32, 64)}) @@ -175,4 +175,4 @@ def _test_functional_resize_impl( def test_functional_resize(self): for icase, case in enumerate(self.SPACING_CASES): with self.subTest(f"{icase}"): - self._test_functional_resize_impl(*case) \ No newline at end of file + self._test_functional_resize_impl(*case) From 23a05fa23ba6f682e0727b1877ba999a2f1d202b Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Wed, 2 Nov 2022 11:02:33 +0000 Subject: [PATCH 07/19] Adding tests for matmul and matrix_matrix Signed-off-by: Ben Murray --- tests/test_matmul.py | 73 ++++++++++++++++++++++++++++++++++---------- 1 file changed, 57 insertions(+), 16 deletions(-) diff --git a/tests/test_matmul.py b/tests/test_matmul.py index 813f623423..33250ae05a 100644 --- a/tests/test_matmul.py +++ b/tests/test_matmul.py @@ -26,23 +26,11 @@ matmul_grid_matrix, matmul_grid_matrix_slow, matmul_matrix_grid, + matmul_matrix_matrix, ) from monai.utils import TransformBackends -def get_matmul_2d_test_cases(): - f = MatrixFactory(2, TransformBackends.TORCH, "cpu") - cases = [ - ( - f.rotate_euler(torch.pi / 4), - f.scale((0.5, 0.5)), - torch.FloatTensor([[0.35355339, -0.35355339, 0], [0.35355339, 0.35355339, 0], [0, 0, 1]]) - ) - ] - - return cases - - class TestMatmulFunctions(unittest.TestCase): def test_matrix_grid_and_grid_inv_matrix_are_equivalent(self): grid = torch.randn((3, 32, 32)) @@ -155,17 +143,70 @@ def test_is_grid_shaped(self, grid, expected): # self._test_is_grid_shaped_impl(*case) +def get_matmul_2d_test_cases(): + f = MatrixFactory(2, TransformBackends.TORCH, "cpu") + cases = [ + ( + f.rotate_euler(torch.pi / 4), + f.scale((0.5, 0.5)), + torch.FloatTensor([[0.35355339, -0.35355339, 0], [0.35355339, 0.35355339, 0], [0, 0, 1]]) + ), + ( + f.scale((0.5, 0.5)), + f.rotate_euler(torch.pi / 4), + torch.FloatTensor([[0.35355339, -0.35355339, 0], [0.35355339, 0.35355339, 0], [0, 0, 1]]) + + ), + ( + f.translate((8, 8)), + f.rotate_euler(torch.pi / 2), + torch.FloatTensor([[0, -1, 8], [1, 0, 8], [0, 0, 1]]) + ), + ( + f.rotate_euler(torch.pi / 2), + f.translate((8, 8)), + torch.FloatTensor([[0, -1, -8], [1, 0, 8], [0, 0, 1]]) + ), + ] + + return cases + + +MATMUL_2D_TEST_CASES = get_matmul_2d_test_cases() + + +class TestMatmulOutputs(unittest.TestCase): def _test_matmul_outputs_impl(self, left, right, expected): actual = matmul(left, right) - self.assertTrue(torch.allclose(actual.matrix.data, expected), + self.assertTrue(torch.allclose(actual.matrix.data, expected, atol=1e-7), msg=f"{actual.matrix.data} is not close 
to {expected}") - def test_matmul_outputs(self): - cases = get_matmul_2d_test_cases() + @parameterized.expand(MATMUL_2D_TEST_CASES) + def test_matmul_outputs(self, left, right, expected): + self._test_matmul_outputs_impl(left, right, expected) + + def test_all_matmul_outputs(self): + cases = MATMUL_2D_TEST_CASES for case in cases: self._test_matmul_outputs_impl(*case) +class TestMatrixMatrixOutputs(unittest.TestCase): + + def _test_matrix_matrix_outputs_impl(self, left, right, expected): + actual = matmul_matrix_matrix(left.matrix.data, right.matrix.data) + self.assertTrue(torch.allclose(actual, expected, atol=1e-7), + msg=f"{actual} is not close to {expected}") + + @parameterized.expand(MATMUL_2D_TEST_CASES) + def test_matrix_matrix_outputs(self, left, right, expected): + self._test_matrix_matrix_outputs_impl(left, right, expected) + + def test_all_matrix_matrix_outputs(self): + cases = MATMUL_2D_TEST_CASES + for case in cases: + self._test_matrix_matrix_outputs_impl(*case) + if __name__ == "__main__": unittest.main() From a67ef17a67c00ded2002656e716073fcff3232fe Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Wed, 2 Nov 2022 17:43:58 +0000 Subject: [PATCH 08/19] Removing spatial functional transforms from this PR Signed-off-by: Ben Murray --- monai/transforms/spatial/functional.py | 506 ------------------------- 1 file changed, 506 deletions(-) delete mode 100644 monai/transforms/spatial/functional.py diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py deleted file mode 100644 index abf8312415..0000000000 --- a/monai/transforms/spatial/functional.py +++ /dev/null @@ -1,506 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from typing import Optional, Sequence, Tuple, Union - -import numpy as np - -import torch -from monai.networks.layers import GaussianFilter - -from monai.networks.utils import meshgrid_ij - -from monai.transforms import create_grid, create_rotate -from monai.config import DtypeLike -from monai.data import get_track_meta -from monai.transforms.lazy.functional import extents_from_shape, shape_from_extents -from monai.transforms.meta_matrix import MatrixFactory, apply_align_corners -from monai.utils import ( - convert_to_tensor, - ensure_tuple, - ensure_tuple_rep, - ensure_tuple_size, - fall_back_tuple, - get_equivalent_dtype, - look_up_option, - GridSampleMode, - GridSamplePadMode, - InterpolateMode, - NumpyPadMode -) - - -def identity( - img: torch.Tensor, - mode: Optional[Union[InterpolateMode, str]] = None, - padding_mode: Optional[Union[NumpyPadMode, GridSamplePadMode, str]] = None, - dtype: Optional[Union[DtypeLike, torch.dtype]] = None -): - img_ = convert_to_tensor(img, track_meta=get_track_meta()) - - mode_ = None if mode is None else look_up_option(mode, GridSampleMode) - padding_mode_ = None if padding_mode is None else look_up_option(padding_mode, GridSamplePadMode) - dtype_ = get_equivalent_dtype(dtype or img_.dtype, torch.Tensor) - - transform = MatrixFactory.from_tensor(img_).identity().matrix.matrix - - metadata = dict() - if mode_ is not None: - metadata["mode"] = mode_ - if padding_mode_ is not None: - metadata["padding_mode"] = padding_mode_ - metadata["dtype"] = dtype_ - return img_, transform, metadata - - -def spacing( - img: torch.Tensor, - pixdim: Union[Sequence[float], float], - src_pixdim: Union[Sequence[float], float], - diagonal: Optional[bool] = False, - mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.AREA, - padding_mode: Optional[Union[NumpyPadMode, GridSamplePadMode, str]] = NumpyPadMode.EDGE, - align_corners: Optional[bool] = False, - dtype: Optional[Union[DtypeLike, torch.dtype]] = None, - shape_override: Optional[Sequence] = None -): - """ - Args: - img: channel first array, must have shape: (num_channels, H[, W, ..., ]). - mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``, - ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} - The interpolation mode. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html - align_corners: This only has an effect when mode is - 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html - anti_aliasing: bool, optional - Whether to apply a Gaussian filter to smooth the image prior - to downsampling. It is crucial to filter when downsampling - the image to avoid aliasing artifacts. See also ``skimage.transform.resize`` - anti_aliasing_sigma: {float, tuple of floats}, optional - Standard deviation for Gaussian filtering used when anti-aliasing. - By default, this value is chosen as (s - 1) / 2 where s is the - downsampling factor, where s > 1. For the up-size case, s < 1, no - anti-aliasing is performed prior to rescaling. - - Raises: - ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. 
- - """ - - img_ = convert_to_tensor(img, track_meta=get_track_meta()) - input_shape = img_.shape if shape_override is None else shape_override - input_ndim = len(input_shape) - 1 - - pixdim_ = ensure_tuple_rep(pixdim, input_ndim) - src_pixdim_ = ensure_tuple_rep(src_pixdim, input_ndim) - - if diagonal is True: - raise ValueError("'diagonal' value of True is not currently supported") - - mode_ = look_up_option(mode, GridSampleMode) - padding_mode_ = look_up_option(padding_mode, GridSamplePadMode) - dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor) - zoom_factors = [i / j for i, j in zip(src_pixdim_, pixdim_)] - - # TODO: decide whether we are consistently returning MetaMatrix or concrete transforms - transform = MatrixFactory.from_tensor(img).scale(zoom_factors).matrix.data - im_extents = extents_from_shape(input_shape) - im_extents = [transform @ e for e in im_extents] - shape_override_ = shape_from_extents(input_shape, im_extents) - - metadata = { - "pixdim": pixdim_, - "src_pixdim": src_pixdim_, - "diagonal": diagonal, - "mode": mode_, - "padding_mode": padding_mode_, - "align_corners": align_corners, - "dtype": dtype_, - # "im_extents": im_extents, - "shape_override": shape_override_ - } - return img_, transform, metadata - - -def orientation( - img: torch.Tensor -): - pass - - -def flip( - img: torch.Tensor, - spatial_axis: Union[Sequence[int], int], - shape_override: Optional[Sequence] = None -): - img_ = convert_to_tensor(img, track_meta=get_track_meta()) - input_shape = img_.shape if shape_override is None else shape_override - - spatial_axis_ = spatial_axis - if spatial_axis_ is None: - spatial_axis_ = tuple(i for i in range(len(input_shape[1:]))) - transform = MatrixFactory.from_tensor(img).flip(spatial_axis_).matrix.data - # im_extents = extents_from_shape(input_shape) - # im_extents = [transform @ e for e in im_extents] - # - # shape_override_ = shape_from_extents(input_shape, im_extents) - - metadata = { - "spatial_axis": spatial_axis_, - # "im_extents": im_extents, - "shape_override": shape_override - } - return img_, transform, metadata - - -def resize( - img: torch.Tensor, - spatial_size: Union[Sequence[int], int], - size_mode: str = "all", - mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.AREA, - align_corners: Optional[bool] = False, - anti_aliasing: Optional[bool] = None, - anti_aliasing_sigma: Optional[Union[Sequence[float], float]] = None, - dtype: Optional[Union[DtypeLike, torch.dtype]] = None, - shape_override: Optional[Sequence] = None -): - """ - Args: - img: channel first array, must have shape: (num_channels, H[, W, ..., ]). - mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``, - ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} - The interpolation mode. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html - align_corners: This only has an effect when mode is - 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html - anti_aliasing: bool, optional - Whether to apply a Gaussian filter to smooth the image prior - to downsampling. It is crucial to filter when downsampling - the image to avoid aliasing artifacts. See also ``skimage.transform.resize`` - anti_aliasing_sigma: {float, tuple of floats}, optional - Standard deviation for Gaussian filtering used when anti-aliasing. 
- By default, this value is chosen as (s - 1) / 2 where s is the - downsampling factor, where s > 1. For the up-size case, s < 1, no - anti-aliasing is performed prior to rescaling. - - Raises: - ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. - - """ - - img_ = convert_to_tensor(img, track_meta=get_track_meta()) - input_shape = img_.shape if shape_override is None else shape_override - input_ndim = len(input_shape) - 1 - - if size_mode == "all": - output_ndim = len(ensure_tuple(spatial_size)) - if output_ndim > input_ndim: - input_shape = ensure_tuple_size(input_shape, output_ndim + 1, 1) - img = img.reshape(input_shape) - elif output_ndim < input_ndim: - raise ValueError( - "len(spatial_size) must be greater or equal to img spatial dimensions, " - f"got spatial_size={output_ndim} img={input_ndim}." - ) - spatial_size_ = fall_back_tuple(spatial_size, input_shape[1:]) - else: # for the "longest" mode - img_size = input_shape[1:] - if not isinstance(spatial_size, int): - raise ValueError("spatial_size must be an int number if size_mode is 'longest'.") - scale = spatial_size / max(img_size) - spatial_size_ = tuple(int(round(s * scale)) for s in img_size) - - mode_ = look_up_option(mode, GridSampleMode) - dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor) - zoom_factors = [i / j for i, j in zip(spatial_size_, input_shape[1:])] - transform = MatrixFactory.from_tensor(img).scale(zoom_factors).matrix.data - im_extents = extents_from_shape(input_shape) - im_extents = [transform @ e for e in im_extents] - shape_override_ = shape_from_extents(input_shape, im_extents) - - metadata = { - "spatial_size": spatial_size, - "size_mode": size_mode, - "mode": mode_, - "align_corners": align_corners, - "anti_aliasing": anti_aliasing, - "anti_aliasing_sigma": anti_aliasing_sigma, - "dtype": dtype_, - # "im_extents": im_extents, - "shape_override": shape_override_ - } - return img_, transform, metadata - - -def rotate( - img: torch.Tensor, - angle: Union[Sequence[float], float], - keep_size: Optional[bool] = True, - mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.AREA, - padding_mode: Optional[Union[NumpyPadMode, GridSamplePadMode, str]] = NumpyPadMode.EDGE, - align_corners: Optional[bool] = False, - dtype: Optional[Union[DtypeLike, torch.dtype]] = None, - shape_override: Optional[Sequence] = None -): - """ - Args: - img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D]. - angle: Rotation angle(s) in radians. should a float for 2D, three floats for 3D. - keep_size: If it is True, the output shape is kept the same as the input. - If it is False, the output shape is adapted so that the - input array is contained completely in the output. Default is True. - mode: {``"bilinear"``, ``"nearest"``} - Interpolation mode to calculate output values. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``} - Padding mode for outside grid values. Defaults to ``self.padding_mode``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - align_corners: Defaults to ``self.align_corners``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - align_corners: Defaults to ``self.align_corners``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - dtype: data type for resampling computation. 
Defaults to ``self.dtype``. - If None, use the data type of input data. To be compatible with other modules, - the output data type is always ``np.float32``. - - Raises: - ValueError: When ``img`` spatially is not one of [2D, 3D]. - - """ - - img_ = convert_to_tensor(img, track_meta=get_track_meta()) - mode_ = look_up_option(mode, GridSampleMode) - padding_mode_ = look_up_option(padding_mode, GridSamplePadMode) - dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor) - input_shape = img_.shape if shape_override is None else shape_override - input_ndim = len(input_shape) - 1 - if input_ndim not in (2, 3): - raise ValueError(f"Unsupported image dimension: {input_ndim}, available options are [2, 3].") - - angle_ = ensure_tuple_rep(angle, 1 if input_ndim == 2 else 3) - rotate_tx = torch.from_numpy(create_rotate(input_ndim, angle_).astype(np.float32)) - im_extents = extents_from_shape(input_shape) - if not keep_size: - im_extents = [rotate_tx @ e for e in im_extents] - spatial_shape = shape_from_extents(input_shape, im_extents) - else: - spatial_shape = input_shape - - if align_corners is True: - transform = apply_align_corners(rotate_tx, spatial_shape[1:], - MatrixFactory.from_tensor(img_)).matrix.data - else: - transform = rotate_tx - - metadata = { - "angle": angle, - "keep_size": keep_size, - "mode": mode_, - "padding_mode": padding_mode_, - "align_corners": align_corners, - "dtype": dtype_, - # "im_extents": im_extents, - "shape_override": spatial_shape - } - return img_, transform, metadata - - -def zoom( - img: torch.Tensor, - factor: Union[Sequence[float], float], - mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.BILINEAR, - padding_mode: Optional[Union[NumpyPadMode, GridSamplePadMode, str]] = NumpyPadMode.EDGE, - align_corners: Optional[bool] = False, - keep_size: Optional[bool] = True, - dtype: Optional[Union[DtypeLike, torch.dtype]] = None, - shape_override: Optional[Sequence] = None -): - """ - Args: - img: channel first array, must have shape: (num_channels, H[, W, ..., ]). - mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``, - ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} - The interpolation mode. Defaults to ``self.mode``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html - align_corners: This only has an effect when mode is - 'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html - - Raises: - ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions. 
- - """ - - img_ = convert_to_tensor(img, track_meta=get_track_meta()) - input_shape = img_.shape if shape_override is None else shape_override - input_ndim = len(input_shape) - 1 - - zoom_factors = ensure_tuple_rep(factor, input_ndim) - zoom_factors = [1 / f for f in zoom_factors] - - mode_ = look_up_option(mode, GridSampleMode) - padding_mode_ = look_up_option(padding_mode, GridSamplePadMode) - dtype_ = get_equivalent_dtype(dtype or img_.dtype, torch.Tensor) - - transform = MatrixFactory.from_tensor(img_).scale(zoom_factors).matrix.matrix - im_extents = extents_from_shape(input_shape) - if keep_size is False: - im_extents = [transform @ e for e in im_extents] - shape_override_ = shape_from_extents(input_shape, im_extents) - else: - shape_override_ = input_shape - - metadata = { - "factor": zoom_factors, - "mode": mode_, - "padding_mode": padding_mode_, - "align_corners": align_corners, - "keep_size": keep_size, - "dtype": dtype_, - "im_extents": im_extents - } - if keep_size is False or align_corners is True: - metadata["shape_override"] = shape_override_ - - return img_, transform, metadata - - -def rotate90( - img: torch.Tensor, - k: Optional[int] = 1, - spatial_axes: Optional[Tuple[int, int]] = (0, 1), - shape_override: Optional[bool] = None -): - if len(spatial_axes) != 2: - raise ValueError("'spatial_axes' must be a tuple of two integers indicating") - - img_ = convert_to_tensor(img, track_meta=get_track_meta()) - # axes = map_spatial_axes(img.ndim, spatial_axes) - # ori_shape = img.shape[1:] - input_shape = img_.shape if shape_override is None else shape_override - input_ndim = len(input_shape) - 1 - - transform = MatrixFactory.from_tensor(img_).rotate_90(k, ) - - metadata = { - "k": k, - "spatial_axes": spatial_axes, - "shape_override": shape_override - } - return img_, transform, metadata - - -def grid_distortion( - img: torch.Tensor, - num_cells: Union[Tuple[int], int], - distort_steps: Sequence[Sequence[float]], - mode: str = GridSampleMode.BILINEAR, - padding_mode: str = GridSamplePadMode.BORDER, - shape_override: Optional[Tuple[int]] = None -): - all_ranges = [] - num_cells = ensure_tuple_rep(num_cells, len(img.shape) - 1) - for dim_idx, dim_size in enumerate(img.shape[1:]): - dim_distort_steps = distort_steps[dim_idx] - ranges = torch.zeros(dim_size, dtype=torch.float32) - cell_size = dim_size // num_cells[dim_idx] - prev = 0 - for idx in range(num_cells[dim_idx] + 1): - start = int(idx * cell_size) - end = start + cell_size - if end > dim_size: - end = dim_size - cur = dim_size - else: - cur = prev + cell_size * dim_distort_steps[idx] - prev = cur - ranges = range - (dim_size - 1.0) / 2.0 - all_ranges.append() - coords = meshgrid_ij(*all_ranges) - grid = torch.stack([*coords, torch.ones_like(coords[0])]) - - metadata = { - "num_cells": num_cells, - "distort_steps": distort_steps, - "mode": mode, - "padding_mode": padding_mode - } - - return img, grid, metadata - - -def elastic_3d( - img: torch.Tensor, - sigma: float, - magnitude: float, - offsets: torch.Tensor, - spatial_size: Optional[Union[Tuple[int, int, int], int]] = None, - mode: str = GridSampleMode.BILINEAR, - padding_mode: str = GridSamplePadMode.REFLECTION, - device: Optional[torch.device] = None, - shape_override: Optional[Tuple[float]] = None -): - img_ = convert_to_tensor(img, track_meta=get_track_meta()) - - sp_size = fall_back_tuple(spatial_size, img.shape[1:]) - device_ = img.device if isinstance(img, torch.Tensor) else device - grid = create_grid(spatial_size=sp_size, device=device_, backend="torch") - 
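-    # `grid` holds one reference coordinate per output voxel plus a trailing
-    # homogeneous ones channel; the random `offsets` field is smoothed by the
-    # Gaussian below (sigma controls how locally coherent the warp is) and
-    # scaled by `magnitude` before being added to the three spatial channels.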
gaussian = GaussianFilter(3, sigma, 3.0).to(device=device_) - grid[:3] += gaussian(offsets)[0] * magnitude - - metadata = { - "sigma": sigma, - "magnitude": magnitude, - "offsets": offsets, - } - if spatial_size is not None: - metadata["spatial_size"] = spatial_size - if mode is not None: - metadata["mode"] = mode - if padding_mode is not None: - metadata["padding_mode"] = padding_mode - if shape_override is not None: - metadata["shape_override"] = shape_override - - return img_, grid, metadata - - -def translate( - img: torch.Tensor, - translation: Sequence[float], - mode: Optional[Union[GridSampleMode, str]] = GridSampleMode.BILINEAR, - padding_mode: Optional[Union[GridSamplePadMode, str]] = NumpyPadMode.EDGE, - dtype: Union[DtypeLike, torch.dtype] = np.float32, - shape_override: Optional[Sequence] = None -): - img_ = convert_to_tensor(img, track_meta=get_track_meta()) - input_shape = img_.shape if shape_override is None else shape_override - input_ndim = len(input_shape) - 1 - if len(translation) != input_ndim: - raise ValueError(f"'translate' length {len(translation)} must be equal to 'img' " - f"spatial dimensions of {input_ndim}") - - transform = MatrixFactory.from_tensor(img).translate(translation).matrix.matrix - im_extents = extents_from_shape(input_shape) - im_extents = [transform @ e for e in im_extents] - # shape_override_ = shape_from_extents(input_shape, im_extents) - - metadata = { - "translation": translation, - "padding_mode": padding_mode, - "dtype": img.dtype, - "im_extents": im_extents, - # "shape_override": shape_override_ - } - return img_, transform, metadata From b25a6cb8fbb480e8d51b9d978c9ed362965eff08 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Wed, 2 Nov 2022 17:45:00 +0000 Subject: [PATCH 09/19] Removing croppad functional transforms from this PR Signed-off-by: Ben Murray --- monai/transforms/croppad/functional.py | 50 -------------------------- 1 file changed, 50 deletions(-) delete mode 100644 monai/transforms/croppad/functional.py diff --git a/monai/transforms/croppad/functional.py b/monai/transforms/croppad/functional.py deleted file mode 100644 index 6bdd45e64d..0000000000 --- a/monai/transforms/croppad/functional.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
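The `croppad` function removed below encodes a crop or pad lazily as a translation: the affine shifts the image centre onto the centre of the selected region, while the output shape travels separately as `shape_override` metadata. A worked example of that bookkeeping (the values are hypothetical; the arithmetic is taken from the deleted body):

    # hypothetical input: a (1, 32, 32) image cropped to a centred 16x16 patch
    input_shape = (1, 32, 32)
    slices = (slice(8, 24), slice(8, 24))

    img_centers = [i / 2 for i in input_shape[1:]]                # [16.0, 16.0]
    slice_centers = [(s.stop + s.start) / 2 for s in slices]      # [16.0, 16.0]
    deltas = [s - i for i, s in zip(img_centers, slice_centers)]  # [0.0, 0.0]
    # a centred crop needs no shift; slices of (0, 16) would instead give
    # deltas of -8.0, which MatrixFactory would encode as translate(deltas)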
- -from typing import Optional, Sequence, Union - -import torch - -from monai.data.meta_obj import get_track_meta -from monai.transforms.lazy.functional import extents_from_shape, shape_from_extents -from monai.transforms.meta_matrix import MatrixFactory -from monai.utils import GridSamplePadMode, NumpyPadMode, convert_to_tensor - - -def croppad( - img: torch.Tensor, - slices: Union[Sequence[slice], slice], - padding_mode: Optional[Union[GridSamplePadMode, str]] = NumpyPadMode.EDGE, - shape_override: Optional[Sequence] = None -): - img_ = convert_to_tensor(img, track_meta=get_track_meta()) - input_shape = img_.shape if shape_override is None else shape_override - input_ndim = len(input_shape) - 1 - if len(slices) != input_ndim: - raise ValueError(f"'slices' length {len(slices)} must be equal to 'img' " - f"spatial dimensions of {input_ndim}") - - img_centers = [i / 2 for i in input_shape[1:]] - slice_centers = [(s.stop + s.start) / 2 for s in slices] - deltas = [s - i for i, s in zip(img_centers, slice_centers)] - transform = MatrixFactory.from_tensor(img).translate(deltas).matrix.matrix - im_extents = extents_from_shape([input_shape[0]] + [s.stop - s.start for s in slices]) - im_extents = [transform @ e for e in im_extents] - shape_override_ = shape_from_extents(input_shape, im_extents) - - metadata = { - "slices": slices, - "padding_mode": padding_mode, - "dtype": img.dtype, - "im_extents": im_extents, - "shape_override": shape_override_ - } - return img_, transform, metadata From 42f33a276fc350ea9d3db41c72b1518c19019913 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Wed, 2 Nov 2022 17:47:02 +0000 Subject: [PATCH 10/19] Removing spatial functional tests as they are no longer part of this PR Signed-off-by: Ben Murray --- tests/test_flipf.py | 137 ------------------------------- tests/test_resizef.py | 178 ---------------------------------------- tests/test_rotatef.py | 107 ------------------------ tests/test_spacingf.py | 181 ----------------------------------------- 4 files changed, 603 deletions(-) delete mode 100644 tests/test_flipf.py delete mode 100644 tests/test_resizef.py delete mode 100644 tests/test_rotatef.py delete mode 100644 tests/test_spacingf.py diff --git a/tests/test_flipf.py b/tests/test_flipf.py deleted file mode 100644 index 58e5b34bb6..0000000000 --- a/tests/test_flipf.py +++ /dev/null @@ -1,137 +0,0 @@ -import unittest - -from typing import Sequence, Union - -import torch - -from monai.transforms.spatial.functional import flip -from tests.utils import get_arange_img - - -def affine_flip(dims, axis: Union[int, Sequence[int]]): - if axis is None: - return torch.eye(dims + 1) - - if not isinstance(axis, (list, tuple)): - axis = (axis,) - t = torch.eye(dims + 1) - for i in axis: - t[i, i] = -1 - return t - - -def get_metadata(overrides=None, remove=None): - metadata = { - "spatial_axis": None, - "shape_override": None, - } - if overrides is not None: - for k, v in overrides.items(): - metadata[k] = v - if remove is not None: - for k in remove: - if k in metadata: - del metadata[k] - return metadata - - -class TestFunctionalSpacing(unittest.TestCase): - - FLIP_CASES = [ - # 2d cases - ( - get_arange_img((32, 32)), affine_flip(2, (0, 1)), - get_metadata({"spatial_axis": None}), get_metadata({"spatial_axis": (0, 1)}) - ), - ( - get_arange_img((32, 32)), affine_flip(2, 0), - get_metadata({"spatial_axis": 0}), get_metadata({"spatial_axis": 0}) - ), - ( - get_arange_img((32, 32)), affine_flip(2, 1), - get_metadata({"spatial_axis": 1}), get_metadata({"spatial_axis": 1}) - ), - 
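-        # single-element tuple axes below are expected to behave exactly like
-        # the scalar axes above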
( - get_arange_img((32, 32)), affine_flip(2, 0), - get_metadata({"spatial_axis": (0,)}), get_metadata({"spatial_axis": (0,)}) - ), - ( - get_arange_img((32, 32)), affine_flip(2, 1), - get_metadata({"spatial_axis": (1,)}), get_metadata({"spatial_axis": (1,)}) - ), - ( - get_arange_img((32, 32)), affine_flip(2, (0, 1)), - get_metadata({"spatial_axis": (0, 1)}), get_metadata({"spatial_axis": (0, 1)}) - ), - ( - get_arange_img((32, 32)), affine_flip(2, (0, 1)), - get_metadata({"spatial_axis": (1, 0)}), get_metadata({"spatial_axis": (1, 0)}) - ), - # 3d cases - ( - get_arange_img((32, 32, 16)), affine_flip(3, (0, 1, 2)), - get_metadata({"spatial_axis": None}), get_metadata({"spatial_axis": (0, 1, 2)}) - ), - ( - get_arange_img((32, 32, 16)), affine_flip(3, 0), - get_metadata({"spatial_axis": 0}), get_metadata({"spatial_axis": 0}) - ), - ( - get_arange_img((32, 32, 16)), affine_flip(3, 1), - get_metadata({"spatial_axis": 1}), get_metadata({"spatial_axis": 1}) - ), - ( - get_arange_img((32, 32, 16)), affine_flip(3, 2), - get_metadata({"spatial_axis": 2}), get_metadata({"spatial_axis": 2}) - ), - ( - get_arange_img((32, 32, 16)), affine_flip(3, 0), - get_metadata({"spatial_axis": (0,)}), get_metadata({"spatial_axis": (0,)}) - ), - ( - get_arange_img((32, 32, 16)), affine_flip(3, 1), - get_metadata({"spatial_axis": (1,)}), get_metadata({"spatial_axis": (1,)}) - ), - ( - get_arange_img((32, 32, 16)), affine_flip(3, 2), - get_metadata({"spatial_axis": (2,)}), get_metadata({"spatial_axis": (2,)}) - ), - ( - get_arange_img((32, 32, 16)), affine_flip(3, (0, 1)), - get_metadata({"spatial_axis": (0, 1)}), get_metadata({"spatial_axis": (0, 1)}) - ), - ( - get_arange_img((32, 32, 16)), affine_flip(3, (0, 2)), - get_metadata({"spatial_axis": (0, 2)}), get_metadata({"spatial_axis": (0, 2)}) - ), - ( - get_arange_img((32, 32, 16)), affine_flip(3, (1, 2)), - get_metadata({"spatial_axis": (1, 2)}), get_metadata({"spatial_axis": (1, 2)}) - ), - ( - get_arange_img((32, 32, 16)), affine_flip(3, (0, 1, 2)), - get_metadata({"spatial_axis": (0, 1, 2)}), get_metadata({"spatial_axis": (0, 1, 2)}) - ), - ] - - def _test_functional_flip_impl( - self, img, expected_transform, call_params, expected_metadata - ): - img_, transform_, metadata = flip(img, **call_params) - self.assertTrue(torch.allclose(transform_, expected_transform), - msg=f"{transform_} != {expected_transform}") - actual_keys = set(metadata.keys()) - expected_keys = set(expected_metadata.keys()) - self.assertSetEqual(actual_keys, expected_keys) - for k in actual_keys: - if isinstance(metadata[k], torch.Tensor) and metadata[k] is not None: - self.assertTrue(torch.allclose(metadata[k], expected_metadata[k]), - msg=f"{metadata[k]} != {expected_metadata[k]} for key {k}") - else: - self.assertEqual(metadata[k], expected_metadata[k], - msg=f"{metadata[k]} != {expected_metadata[k]} for key {k}") - - def test_functional_flip(self): - for icase, case in enumerate(self.FLIP_CASES): - with self.subTest(f"{icase}"): - self._test_functional_flip_impl(*case) diff --git a/tests/test_resizef.py b/tests/test_resizef.py deleted file mode 100644 index cd69b5ea58..0000000000 --- a/tests/test_resizef.py +++ /dev/null @@ -1,178 +0,0 @@ -import unittest - -import torch - -from monai.transforms.spatial.functional import resize -from tests.utils import get_arange_img - - -def affine_scale_2d(scale): - scale0, scale1 = scale if isinstance(scale, (tuple, list)) else (scale, scale) - t = torch.eye(3) - t[:, 0] = torch.FloatTensor([scale0, 0.0, 0.0]) - t[:, 1] = torch.FloatTensor([0.0, 
scale1, 0.0]) - return t - - -def affine_scale_3d(scale): - scale0, scale1, scale2 = scale if isinstance(scale, (tuple, list)) else (scale, scale, scale) - t = torch.eye(4) - t[:, 0] = torch.FloatTensor([scale0, 0.0, 0.0, 0.0]) - t[:, 1] = torch.FloatTensor([0.0, scale1, 0.0, 0.0]) - t[:, 2] = torch.FloatTensor([0.0, 0.0, scale2, 0.0]) - return t - - -def get_metadata(is_3d=True, overrides=None, remove=None): - metadata = { - "size_mode": "all", - "mode": "nearest", - "align_corners": False, - "anti_aliasing": False, - "anti_aliasing_sigma": None, - "dtype": torch.float32, - # "im_extents": None, - # "shape_override": torch.IntTensor([1, 32, 32]) # shape override shouldn't always be in here - } - if overrides is not None: - for k, v in overrides.items(): - metadata[k] = v - if remove is not None: - for k in remove: - if k in metadata: - del metadata[k] - return metadata - - -class TestFunctionalSpacing(unittest.TestCase): - - SPACING_CASES = [ - # 2d - "all" - ( - get_arange_img((32, 32)), - affine_scale_2d(0.5), - get_metadata(False, {"spatial_size": (16, 16)}), - get_metadata(False, {"spatial_size": (16, 16), "shape_override": (1, 16, 16)}) - ), - ( - get_arange_img((32, 32)), - affine_scale_2d(2.0), - get_metadata(False, {"spatial_size": (64, 64)}), - get_metadata(False, {"spatial_size": (64, 64), "shape_override": (1, 64, 64)}) - ), - ( - get_arange_img((32, 16)), - affine_scale_2d((0.5, 1.0)), - get_metadata(False, {"spatial_size": (16, 16)}), - get_metadata(False, {"spatial_size": (16, 16), "shape_override": (1, 16, 16)}) - ), - # 2d - "longest" - ( - get_arange_img((32, 16)), - affine_scale_2d(0.5), - get_metadata(False, {"spatial_size": 16, "size_mode": "longest"}), - get_metadata(False, {"spatial_size": 16, "size_mode": "longest", - "shape_override": (1, 16, 8)}) - ), - ( - get_arange_img((16, 32)), - affine_scale_2d(0.5), - get_metadata(False, {"spatial_size": 16, "size_mode": "longest"}), - get_metadata(False, {"spatial_size": 16, "size_mode": "longest", - "shape_override": (1, 8, 16)}) - ), - ( - get_arange_img((32, 16)), - affine_scale_2d(2.0), - get_metadata(False, {"spatial_size": 64, "size_mode": "longest"}), - get_metadata(False, {"spatial_size": 64, "size_mode": "longest", - "shape_override": (1, 64, 32)}) - ), - ( - get_arange_img((16, 32)), - affine_scale_2d(2.0), - get_metadata(False, {"spatial_size": 64, "size_mode": "longest"}), - get_metadata(False, {"spatial_size": 64, "size_mode": "longest", - "shape_override": (1, 32, 64)}) - ), - # 3d - "all" - ( - get_arange_img((32, 32, 16)), - affine_scale_3d(0.5), - get_metadata(False, {"spatial_size": (16, 16, 8)}), - get_metadata(False, {"spatial_size": (16, 16, 8), - "shape_override": (1, 16, 16, 8)}) - ), - ( - get_arange_img((32, 32, 16)), - affine_scale_3d(2.0), - get_metadata(False, {"spatial_size": (64, 64, 32)}), - get_metadata(False, {"spatial_size": (64, 64, 32), - "shape_override": (1, 64, 64, 32)}) - ), - # 3d - "longest" - ( - get_arange_img((32, 16, 8)), - affine_scale_3d(0.5), - get_metadata(False, {"spatial_size": 16, "size_mode": "longest"}), - get_metadata(False, {"spatial_size": 16, "size_mode": "longest", - "shape_override": (1, 16, 8, 4)}) - ), - ( - get_arange_img((16, 32, 8)), - affine_scale_3d(0.5), - get_metadata(False, {"spatial_size": 16, "size_mode": "longest"}), - get_metadata(False, {"spatial_size": 16, "size_mode": "longest", - "shape_override": (1, 8, 16, 4)}) - ), - ( - get_arange_img((8, 16, 32)), - affine_scale_3d(0.5), - get_metadata(False, {"spatial_size": 16, "size_mode": "longest"}), - 
get_metadata(False, {"spatial_size": 16, "size_mode": "longest", - "shape_override": (1, 4, 8, 16)}) - ), - ( - get_arange_img((32, 16, 8)), - affine_scale_3d(2.0), - get_metadata(False, {"spatial_size": 64, "size_mode": "longest"}), - get_metadata(False, {"spatial_size": 64, "size_mode": "longest", - "shape_override": (1, 64, 32, 16)}) - ), - ( - get_arange_img((16, 32, 8)), - affine_scale_3d(2.0), - get_metadata(False, {"spatial_size": 64, "size_mode": "longest"}), - get_metadata(False, {"spatial_size": 64, "size_mode": "longest", - "shape_override": (1, 32, 64, 16)}) - ), - ( - get_arange_img((8, 16, 32)), - affine_scale_3d(2.0), - get_metadata(False, {"spatial_size": 64, "size_mode": "longest"}), - get_metadata(False, {"spatial_size": 64, "size_mode": "longest", - "shape_override": (1, 16, 32, 64)}) - ), - ] - - def _test_functional_resize_impl( - self, img, expected_transform, call_params, expected_metadata - ): - img_, transform_, metadata = resize(img, **call_params) - self.assertTrue(torch.allclose(transform_, expected_transform), - msg=f"{transform_} != {expected_transform}") - actual_keys = set(metadata.keys()) - expected_keys = set(expected_metadata.keys()) - self.assertSetEqual(actual_keys, expected_keys) - for k in actual_keys: - if isinstance(metadata[k], torch.Tensor): - self.assertTrue(torch.allclose(metadata[k], expected_metadata[k]), - msg=f"{metadata[k]} != {expected_metadata[k]} for key {k}") - else: - self.assertEqual(metadata[k], expected_metadata[k], - msg=f"{metadata[k]} != {expected_metadata[k]} for key {k}") - - def test_functional_resize(self): - for icase, case in enumerate(self.SPACING_CASES): - with self.subTest(f"{icase}"): - self._test_functional_resize_impl(*case) diff --git a/tests/test_rotatef.py b/tests/test_rotatef.py deleted file mode 100644 index ebe7c8b520..0000000000 --- a/tests/test_rotatef.py +++ /dev/null @@ -1,107 +0,0 @@ -import unittest - -import math - -import torch - -from monai.transforms.spatial.functional import rotate -from tests.utils import get_arange_img - - -def affine_rotate_2d(radians, scale=None): - t = torch.eye(3) - t[:, 0] = torch.FloatTensor([math.cos(radians), math.sin(radians), 0.0]) - t[:, 1] = torch.FloatTensor([-math.sin(radians), math.cos(radians), 0.0]) - if scale is not None: - t[0, 0] *= scale - t[0, 1] *= scale - t[1, 0] *= scale - t[1, 1] *= scale - return t - - -def affine_scale_3d(scale): - scale0, scale1, scale2 = scale if isinstance(scale, (tuple, list)) else (scale, scale, scale) - t = torch.eye(4) - t[:, 0] = torch.FloatTensor([scale0, 0.0, 0.0, 0.0]) - t[:, 1] = torch.FloatTensor([0.0, scale1, 0.0, 0.0]) - t[:, 2] = torch.FloatTensor([0.0, 0.0, scale2, 0.0]) - return t - - -def get_metadata(is_3d=True, overrides=None, remove=None): - metadata = { - "angle": 0.0, - "keep_size": True, - "mode": "nearest", - "padding_mode": "zeros", - "align_corners": False, - "dtype": torch.float32, - # "im_extents": None, - # "shape_override": torch.IntTensor([1, 32, 32]) # shape override shouldn't always be in here - } - if overrides is not None: - for k, v in overrides.items(): - metadata[k] = v - if remove is not None: - for k in remove: - if k in metadata: - del metadata[k] - return metadata - - -class TestFunctionalRotate(unittest.TestCase): - - ROTATE_CASES = [ - # keep_size = True - ( - get_arange_img((32, 32)), - affine_rotate_2d(torch.pi / 4), - get_metadata(False, {"angle": torch.pi / 4}), - get_metadata(False, {"angle": torch.pi / 4, "shape_override": (1, 32, 32)}) - ), - ( - get_arange_img((32, 32)), - 
affine_rotate_2d(torch.pi / 4, 32/33), - get_metadata(False, {"angle": torch.pi / 4, "align_corners": True}), - get_metadata(False, {"angle": torch.pi / 4, "align_corners": True, - "shape_override": (1, 32, 32)}) - ), - # keep_size = False - ( - get_arange_img((32, 32)), - affine_rotate_2d(torch.pi / 4), - get_metadata(False, {"angle": torch.pi / 4, "keep_size": False}), - get_metadata(False, {"angle": torch.pi / 4, "keep_size": False, - "shape_override": (1, 45, 45)}) - ), - ( - get_arange_img((32, 32)), - affine_rotate_2d(torch.pi / 4, 45/46), - get_metadata(False, {"angle": torch.pi / 4, "keep_size": False, "align_corners": True}), - get_metadata(False, {"angle": torch.pi / 4, "keep_size": False, "align_corners": True, - "shape_override": (1, 45, 45)}) - ), - ] - - def _test_functional_rotate_impl( - self, img, expected_transform, call_params, expected_metadata - ): - img_, transform_, metadata = rotate(img, **call_params) - self.assertTrue(torch.allclose(transform_, expected_transform), - msg=f"{transform_} != {expected_transform}") - actual_keys = set(metadata.keys()) - expected_keys = set(expected_metadata.keys()) - self.assertSetEqual(actual_keys, expected_keys) - for k in actual_keys: - if isinstance(metadata[k], torch.Tensor): - self.assertTrue(torch.allclose(metadata[k], expected_metadata[k]), - msg=f"{metadata[k]} != {expected_metadata[k]} for key {k}") - else: - self.assertEqual(metadata[k], expected_metadata[k], - msg=f"{metadata[k]} != {expected_metadata[k]} for key {k}") - - def test_functional_rotate(self): - for icase, case in enumerate(self.ROTATE_CASES): - with self.subTest(f"{icase}"): - self._test_functional_rotate_impl(*case) diff --git a/tests/test_spacingf.py b/tests/test_spacingf.py deleted file mode 100644 index 0fcc1b43d6..0000000000 --- a/tests/test_spacingf.py +++ /dev/null @@ -1,181 +0,0 @@ -import unittest - -import torch - -from monai.transforms.spatial.functional import spacing -from tests.utils import get_arange_img - - -def affine_scale_2d(scale): - scale0, scale1 = scale if isinstance(scale, (tuple, list)) else (scale, scale) - t = torch.eye(3) - t[:, 0] = torch.FloatTensor([scale0, 0.0, 0.0]) - t[:, 1] = torch.FloatTensor([0.0, scale1, 0.0]) - return t - - -def affine_scale_3d(scale): - scale0, scale1, scale2 = scale if isinstance(scale, (tuple, list)) else (scale, scale, scale) - t = torch.eye(4) - t[:, 0] = torch.FloatTensor([scale0, 0.0, 0.0, 0.0]) - t[:, 1] = torch.FloatTensor([0.0, scale1, 0.0, 0.0]) - t[:, 2] = torch.FloatTensor([0.0, 0.0, scale2, 0.0]) - return t - - -def get_metadata(is_3d=True, overrides=None, remove=None): - metadata = { - "pixdim": (1.0, 1.0, 1.0) if is_3d else (1.0, 1.0), - "src_pixdim": (1.0, 1.0, 1.0) if is_3d else (1.0, 1.0), - "diagonal": False, - "mode": "nearest", - "padding_mode": "zeros", - "align_corners": False, - "dtype": torch.float32, - # "im_extents": None, - # "shape_override": torch.IntTensor([1, 32, 32]) # shape override shouldn't always be in here - } - if overrides is not None: - for k, v in overrides.items(): - metadata[k] = v - if remove is not None: - for k in remove: - if k in metadata: - del metadata[k] - return metadata - - -class TestFunctionalSpacing(unittest.TestCase): - - SPACING_CASES = [ - ( - get_arange_img((32, 32)), - affine_scale_2d(0.5), - get_metadata(False, {"pixdim": (2.0, 2.0)}), - get_metadata(False, {"pixdim": (2.0, 2.0), "shape_override": (1, 16, 16)}) - ), - ( - get_arange_img((32, 32)), - affine_scale_2d(0.5), - get_metadata(False, {"src_pixdim": (0.5, 0.5)}), - 
get_metadata(False, {"src_pixdim": (0.5, 0.5), "shape_override": (1, 16, 16)}) - ), - ( - get_arange_img((32, 32)), - affine_scale_2d(0.25), - get_metadata(False, {"pixdim": (2.0, 2.0), "src_pixdim": (0.5, 0.5)}), - get_metadata(False, {"pixdim": (2.0, 2.0), "src_pixdim": (0.5, 0.5), - "shape_override": (1, 8, 8)}) - ), - ( - get_arange_img((32, 32)), - affine_scale_2d(2.0), - get_metadata(False, {"src_pixdim": (2.0, 2.0)}), - get_metadata(False, {"src_pixdim": (2.0, 2.0), "shape_override": (1, 64, 64)}) - ), - ( - get_arange_img((32, 32)), - affine_scale_2d(2.0), - get_metadata(False, {"pixdim": (0.5, 0.5)}), - get_metadata(False, {"pixdim": (0.5, 0.5), "shape_override": (1, 64, 64)}) - ), - ( - get_arange_img((32, 32)), - affine_scale_2d(4.0), - get_metadata(False, {"pixdim": (0.5, 0.5), "src_pixdim": (2.0, 2.0)}), - get_metadata(False, {"pixdim": (0.5, 0.5), "src_pixdim": (2.0, 2.0), - "shape_override": (1, 128, 128)}) - ), - ( - get_arange_img((32, 32)), - affine_scale_2d((0.5, 2.0)), - get_metadata(False, {"pixdim": (2.0, 1.0), "src_pixdim": (1.0, 2.0)}), - get_metadata(False, {"pixdim": (2.0, 1.0), "src_pixdim": (1.0, 2.0), - "shape_override": (1, 16, 64)}) - ), - ( - get_arange_img((32, 32)), - affine_scale_2d((2.0, 0.5)), - get_metadata(False, {"pixdim": (1.0, 2.0), "src_pixdim": (2.0, 1.0)}), - get_metadata(False, {"pixdim": (1.0, 2.0), "src_pixdim": (2.0, 1.0), - "shape_override": (1, 64, 16)}) - ), - ( - get_arange_img((32, 32, 24)), - affine_scale_3d(0.5), - get_metadata(True, {"pixdim": (2.0, 2.0, 2.0)}), - get_metadata(True, {"pixdim": (2.0, 2.0, 2.0), - "shape_override": (1, 16, 16, 12)}) - ), - ( - get_arange_img((32, 32, 24)), - affine_scale_3d(0.5), - get_metadata(True, {"src_pixdim": (0.5, 0.5, 0.5)}), - get_metadata(True, {"src_pixdim": (0.5, 0.5, 0.5), - "shape_override": (1, 16, 16, 12)}) - ), - ( - get_arange_img((32, 32, 24)), - affine_scale_3d(0.25), - get_metadata(True, {"pixdim": (2.0, 2.0, 2.0), "src_pixdim": (0.5, 0.5, 0.5)}), - get_metadata(True, {"pixdim": (2.0, 2.0, 2.0), "src_pixdim": (0.5, 0.5, 0.5), - "shape_override": (1, 8, 8, 6)}) - ), - ( - get_arange_img((32, 32, 24)), - affine_scale_3d(2.0), - get_metadata(True, {"src_pixdim": (2.0, 2.0, 2.0)}), - get_metadata(True, {"src_pixdim": (2.0, 2.0, 2.0), - "shape_override": (1, 64, 64, 48)}) - ), - ( - get_arange_img((32, 32, 24)), - affine_scale_3d(2.0), - get_metadata(True, {"pixdim": (0.5, 0.5, 0.5)}), - get_metadata(True, {"pixdim": (0.5, 0.5, 0.5), - "shape_override": (1, 64, 64, 48)}) - ), - ( - get_arange_img((32, 32, 24)), - affine_scale_3d(4.0), - get_metadata(True, {"pixdim": (0.5, 0.5, 0.5), "src_pixdim": (2.0, 2.0, 2.0)}), - get_metadata(True, {"pixdim": (0.5, 0.5, 0.5), "src_pixdim": (2.0, 2.0, 2.0), - "shape_override": (1, 128, 128, 96)}) - ), - ( - get_arange_img((32, 32, 24)), - affine_scale_3d((0.5, 2.0, 1/3.0)), - get_metadata(True, {"pixdim": (2.0, 1.0, 4.5), "src_pixdim": (1.0, 2.0, 1.5)}), - get_metadata(True, {"pixdim": (2.0, 1.0, 4.5), "src_pixdim": (1.0, 2.0, 1.5), - "shape_override": (1, 16, 64, 8)}) - ), - ( - get_arange_img((32, 32, 24)), - affine_scale_3d((2.0, 0.5, 3.0)), - get_metadata(True, {"pixdim": (1.0, 2.0, 1.5), "src_pixdim": (2.0, 1.0, 4.5)}), - get_metadata(True, {"pixdim": (1.0, 2.0, 1.5), "src_pixdim": (2.0, 1.0, 4.5), - "shape_override": (1, 64, 16, 72)}) - ), - ] - - def _test_functional_spacing_impl( - self, img, expected_transform, call_params, expected_metadata - ): - img_, transform_, metadata = spacing(img, **call_params) - 
self.assertTrue(torch.allclose(transform_, expected_transform), - msg=f"{transform_} != {expected_transform}") - actual_keys = set(metadata.keys()) - expected_keys = set(expected_metadata.keys()) - self.assertSetEqual(actual_keys, expected_keys) - for k in actual_keys: - if isinstance(metadata[k], torch.Tensor): - self.assertTrue(torch.allclose(metadata[k], expected_metadata[k]), - msg=f"{metadata[k]} != {expected_metadata[k]} for key {k}") - else: - self.assertEqual(metadata[k], expected_metadata[k], - msg=f"{metadata[k]} != {expected_metadata[k]} for key {k}") - - def test_functional_spacing(self): - for icase, case in enumerate(self.SPACING_CASES): - with self.subTest(f"{icase}"): - self._test_functional_spacing_impl(*case) From 11772631907e6d66a4e59c9e49188fc087f1463d Mon Sep 17 00:00:00 2001 From: monai-bot Date: Sat, 12 Nov 2022 10:35:57 +0000 Subject: [PATCH 11/19] [MONAI] code formatting Signed-off-by: monai-bot --- monai/transforms/__init__.py | 8 +--- monai/transforms/lazy/array.py | 2 +- monai/transforms/lazy/functional.py | 2 +- monai/transforms/meta_matrix.py | 43 ++++++++---------- monai/transforms/utility/functional.py | 8 +--- monai/transforms/utils.py | 63 +++++++++----------------- tests/test_apply.py | 25 ++++------ tests/test_matmul.py | 30 ++++-------- tests/test_resample.py | 9 +--- 9 files changed, 66 insertions(+), 124 deletions(-) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 6416414c05..e52009ed98 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -229,13 +229,7 @@ from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict from .lazy.array import Apply from .lazy.functional import apply -from .meta_matrix import ( - Grid, - matmul, - Matrix, - MatrixFactory, - MetaMatrix, -) +from .meta_matrix import Grid, Matrix, MatrixFactory, MetaMatrix, matmul from .meta_utility.dictionary import ( FromMetaTensord, FromMetaTensorD, diff --git a/monai/transforms/lazy/array.py b/monai/transforms/lazy/array.py index ae165cf566..fe3dd7a211 100644 --- a/monai/transforms/lazy/array.py +++ b/monai/transforms/lazy/array.py @@ -9,8 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
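One hunk in this otherwise mechanical formatting commit changes behaviour rather than layout: in the `monai/transforms/lazy/functional.py` diff below, the dict branch of `apply` goes from recursing with `apply(v, pending)` to calling the value itself, `v(*pending)`, and `tests/test_apply.py` receives the matching rewrite; a later patch in this series restores the `apply(...)` form in the tests. For reference, the dict handling as it stood before this hunk (a sketch, not a drop-in replacement):

    # Recurse into dict values, applying the same pending transforms to each.
    if isinstance(data, dict):
        rd = dict()
        for k, v in data.items():
            rd[k] = apply(v, pending)  # v(*pending) would call the tensor instead
        return rd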
-from monai.transforms.lazy.functional import apply from monai.transforms.inverse import InvertibleTransform +from monai.transforms.lazy.functional import apply __all__ = ["Apply"] diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py index 02e44150ec..ce9fc0fe84 100644 --- a/monai/transforms/lazy/functional.py +++ b/monai/transforms/lazy/functional.py @@ -127,7 +127,7 @@ def apply(data: Union[torch.Tensor, MetaTensor, dict], pending: Optional[Union[d if isinstance(data, dict): rd = dict() for k, v in data.items(): - result = apply(v, pending) + result = v(*pending) rd[k] = result return rd diff --git a/monai/transforms/meta_matrix.py b/monai/transforms/meta_matrix.py index dc6a8c5902..9733361c55 100644 --- a/monai/transforms/meta_matrix.py +++ b/monai/transforms/meta_matrix.py @@ -13,16 +13,23 @@ import numpy as np import torch -from monai.transforms.utils import _create_rotate, _create_rotate_90, _create_flip, _create_shear, _create_scale, \ - _create_translate - -from monai.transforms.utils import get_backend_from_tensor_like, get_device_from_tensor_like -from monai.utils import TransformBackends from monai.config import NdarrayOrTensor +from monai.transforms.utils import ( + _create_flip, + _create_rotate, + _create_rotate_90, + _create_scale, + _create_shear, + _create_translate, + get_backend_from_tensor_like, + get_device_from_tensor_like, +) +from monai.utils import TransformBackends __all__ = ["Grid", "matmul", "Matrix", "MatrixFactory", "MetaMatrix"] + def is_matrix_shaped(data): return ( @@ -36,11 +43,7 @@ def is_grid_shaped(data): class MatrixFactory: - - def __init__(self, - dims: int, - backend: TransformBackends, - device: Optional[torch.device] = None): + def __init__(self, dims: int, backend: TransformBackends, device: Optional[torch.device] = None): if backend == TransformBackends.NUMPY: if device is not None: @@ -54,27 +57,17 @@ def __init__(self, if device is None: raise ValueError("'device' must be set with TransformBackends.TORCH") self._device = device - self._sin = lambda th: torch.sin(torch.as_tensor(th, - dtype=torch.float32, - device=self._device)) - self._cos = lambda th: torch.cos(torch.as_tensor(th, - dtype=torch.float32, - device=self._device)) - self._eye = lambda rank: torch.eye(rank, - device=self._device, - dtype=torch.float32); - self._diag = lambda size: torch.diag(torch.as_tensor(size, - device=self._device, - dtype=torch.float32)) + self._sin = lambda th: torch.sin(torch.as_tensor(th, dtype=torch.float32, device=self._device)) + self._cos = lambda th: torch.cos(torch.as_tensor(th, dtype=torch.float32, device=self._device)) + self._eye = lambda rank: torch.eye(rank, device=self._device, dtype=torch.float32) + self._diag = lambda size: torch.diag(torch.as_tensor(size, device=self._device, dtype=torch.float32)) self._backend = backend self._dims = dims @staticmethod def from_tensor(data): - return MatrixFactory(len(data.shape)-1, - get_backend_from_tensor_like(data), - get_device_from_tensor_like(data)) + return MatrixFactory(len(data.shape) - 1, get_backend_from_tensor_like(data), get_device_from_tensor_like(data)) def identity(self): matrix = self._eye(self._dims + 1) diff --git a/monai/transforms/utility/functional.py b/monai/transforms/utility/functional.py index 7aec22c723..56570799f7 100644 --- a/monai/transforms/utility/functional.py +++ b/monai/transforms/utility/functional.py @@ -11,17 +11,13 @@ from typing import Optional, Union import torch -from monai.transforms import Affine from monai.config import 
NdarrayOrTensor +from monai.transforms import Affine from monai.transforms.meta_matrix import Grid, Matrix -def resample( - data: torch.Tensor, - matrix: Union[NdarrayOrTensor, Matrix, Grid], - kwargs: Optional[dict] = None -): +def resample(data: torch.Tensor, matrix: Union[NdarrayOrTensor, Matrix, Grid], kwargs: Optional[dict] = None): """ This is a minimal implementation of resample that always uses Affine. """ diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index bdf392779a..1f632c7343 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -890,24 +890,17 @@ def _create_translate( def _create_rotate_90( - spatial_dims: int, - axis: Tuple[int, int], - steps: Optional[int] = 1, - eye_func: Callable = np.eye + spatial_dims: int, axis: Tuple[int, int], steps: Optional[int] = 1, eye_func: Callable = np.eye ) -> NdarrayOrTensor: - values = [(1, 0, 0, 1), - (0, -1, 1, 0), - (-1, 0, 0, -1), - (0, 1, -1, 0)] + values = [(1, 0, 0, 1), (0, -1, 1, 0), (-1, 0, 0, -1), (0, 1, -1, 0)] if spatial_dims == 2: if axis != (0, 1): raise ValueError(f"if 'spatial_dims' is 2, 'axis' must be (0, 1) but is {axis}") elif spatial_dims == 3: if axis not in ((0, 1), (0, 2), (1, 2)): - raise ValueError("if 'spatial_dims' is 3, 'axis' must be (0,1), (0, 2), or (1, 2) " - f"but is {axis}") + raise ValueError("if 'spatial_dims' is 3, 'axis' must be (0,1), (0, 2), or (1, 2) " f"but is {axis}") else: raise ValueError(f"'spatial_dims' must be 2 or 3 but is {spatial_dims}") @@ -946,38 +939,31 @@ def create_rotate_90( """ _backend = look_up_option(backend, TransformBackends) if _backend == TransformBackends.NUMPY: - return _create_rotate_90( - spatial_dims=spatial_dims, - axis=axis, - steps=steps, - eye_func=np.eye) + return _create_rotate_90(spatial_dims=spatial_dims, axis=axis, steps=steps, eye_func=np.eye) if _backend == TransformBackends.TORCH: return _create_rotate_90( - spatial_dims=spatial_dims, - axis=axis, - steps=steps, - eye_func=lambda rank: torch.eye(rank, device=device), + spatial_dims=spatial_dims, axis=axis, steps=steps, eye_func=lambda rank: torch.eye(rank, device=device) ) raise ValueError(f"backend {backend} is not supported") -def _create_flip( - spatial_dims: int, - spatial_axis: Union[Sequence[int], int], - eye_func: Callable = np.eye -): +def _create_flip(spatial_dims: int, spatial_axis: Union[Sequence[int], int], eye_func: Callable = np.eye): affine = eye_func(spatial_dims + 1) if isinstance(spatial_axis, int): if spatial_axis < -spatial_dims or spatial_axis >= spatial_dims: - raise ValueError("'spatial_axis' values must be between " - f"{-spatial_dims} and {spatial_dims-1} inclusive " - f"('spatial_axis' is {spatial_axis})") + raise ValueError( + "'spatial_axis' values must be between " + f"{-spatial_dims} and {spatial_dims-1} inclusive " + f"('spatial_axis' is {spatial_axis})" + ) affine[spatial_axis, spatial_axis] = -1 else: if any((s < -spatial_dims or s >= spatial_dims) for s in spatial_axis): - raise ValueError("'spatial_axis' values must be between " - f"{-spatial_dims} and {spatial_dims-1} inclusive " - f"('spatial_axis' is {spatial_axis})") + raise ValueError( + "'spatial_axis' values must be between " + f"{-spatial_dims} and {spatial_dims-1} inclusive " + f"('spatial_axis' is {spatial_axis})" + ) for i in range(spatial_dims): if i in spatial_axis: @@ -987,22 +973,17 @@ def _create_flip( def create_flip( - spatial_dims: int, - spatial_axis: Union[Sequence[int], int], - device: Optional[torch.device] = None, - backend: str = TransformBackends.NUMPY, + 
spatial_dims: int, + spatial_axis: Union[Sequence[int], int], + device: Optional[torch.device] = None, + backend: str = TransformBackends.NUMPY, ) -> NdarrayOrTensor: _backend = look_up_option(backend, TransformBackends) if _backend == TransformBackends.NUMPY: - return _create_flip( - spatial_dims=spatial_dims, - spatial_axis=spatial_axis, - eye_func=np.eye) + return _create_flip(spatial_dims=spatial_dims, spatial_axis=spatial_axis, eye_func=np.eye) if _backend == TransformBackends.TORCH: return _create_flip( - spatial_dims=spatial_dims, - spatial_axis=spatial_axis, - eye_func=lambda rank: torch.eye(rank, device=device), + spatial_dims=spatial_dims, spatial_axis=spatial_axis, eye_func=lambda rank: torch.eye(rank, device=device) ) raise ValueError(f"backend {backend} is not supported") diff --git a/tests/test_apply.py b/tests/test_apply.py index 2a5d5aa2a7..27426bdced 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -12,43 +12,37 @@ import unittest import torch -from monai.utils import convert_to_tensor, TransformBackends from monai.transforms.lazy.functional import apply -from monai.transforms.meta_matrix import MetaMatrix, MatrixFactory +from monai.transforms.meta_matrix import MatrixFactory, MetaMatrix +from monai.utils import TransformBackends, convert_to_tensor def single_2d_transform_cases(): f = MatrixFactory(2, TransformBackends.TORCH, "cpu") cases = [ - ( - torch.randn((1, 32, 32)), - [MetaMatrix(f.rotate_euler(torch.pi / 4).matrix, {"id": "rotate"})], - (1, 32, 32) - ), + (torch.randn((1, 32, 32)), [MetaMatrix(f.rotate_euler(torch.pi / 4).matrix, {"id": "rotate"})], (1, 32, 32)), ( torch.randn((1, 16, 16)), - [MetaMatrix(f.rotate_euler(torch.pi / 4).matrix, - {"id": "rotate", "shape_override": (1, 45, 45)})], - (1, 45, 45) - ) + [MetaMatrix(f.rotate_euler(torch.pi / 4).matrix, {"id": "rotate", "shape_override": (1, 45, 45)})], + (1, 45, 45), + ), ] return cases class TestApply(unittest.TestCase): - def _test_apply_impl(self, tensor, pending_transforms): print(tensor.shape) - result = apply(tensor, pending_transforms) + result = tensor(*pending_transforms) self.assertListEqual(result[1], pending_transforms) def _test_apply_metatensor_impl(self, tensor, pending_transforms, expected_shape, pending_as_parameter): tensor_ = convert_to_tensor(tensor, track_meta=True) if pending_as_parameter: - result, transforms = apply(tensor_, pending_transforms) + result, transforms = tensor_(*pending_transforms) else: for p in pending_transforms: tensor_.push_pending_operation(p) @@ -56,7 +50,6 @@ def _test_apply_metatensor_impl(self, tensor, pending_transforms, expected_shape self.assertEqual(result.shape, expected_shape) - SINGLE_TRANSFORM_CASES = single_2d_transform_cases() def test_apply_single_transform(self): @@ -72,5 +65,5 @@ def test_apply_single_transform_metatensor_override(self): self._test_apply_metatensor_impl(*case, True) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_matmul.py b/tests/test_matmul.py index 33250ae05a..a9a7852046 100644 --- a/tests/test_matmul.py +++ b/tests/test_matmul.py @@ -11,10 +11,9 @@ import unittest -from parameterized import parameterized - import numpy as np import torch +from parameterized import parameterized from monai.transforms.meta_matrix import ( Grid, @@ -149,24 +148,15 @@ def get_matmul_2d_test_cases(): ( f.rotate_euler(torch.pi / 4), f.scale((0.5, 0.5)), - torch.FloatTensor([[0.35355339, -0.35355339, 0], [0.35355339, 0.35355339, 0], [0, 0, 1]]) + torch.FloatTensor([[0.35355339, 
-0.35355339, 0], [0.35355339, 0.35355339, 0], [0, 0, 1]]), ), ( f.scale((0.5, 0.5)), f.rotate_euler(torch.pi / 4), - torch.FloatTensor([[0.35355339, -0.35355339, 0], [0.35355339, 0.35355339, 0], [0, 0, 1]]) - - ), - ( - f.translate((8, 8)), - f.rotate_euler(torch.pi / 2), - torch.FloatTensor([[0, -1, 8], [1, 0, 8], [0, 0, 1]]) - ), - ( - f.rotate_euler(torch.pi / 2), - f.translate((8, 8)), - torch.FloatTensor([[0, -1, -8], [1, 0, 8], [0, 0, 1]]) + torch.FloatTensor([[0.35355339, -0.35355339, 0], [0.35355339, 0.35355339, 0], [0, 0, 1]]), ), + (f.translate((8, 8)), f.rotate_euler(torch.pi / 2), torch.FloatTensor([[0, -1, 8], [1, 0, 8], [0, 0, 1]])), + (f.rotate_euler(torch.pi / 2), f.translate((8, 8)), torch.FloatTensor([[0, -1, -8], [1, 0, 8], [0, 0, 1]])), ] return cases @@ -178,8 +168,10 @@ def get_matmul_2d_test_cases(): class TestMatmulOutputs(unittest.TestCase): def _test_matmul_outputs_impl(self, left, right, expected): actual = matmul(left, right) - self.assertTrue(torch.allclose(actual.matrix.data, expected, atol=1e-7), - msg=f"{actual.matrix.data} is not close to {expected}") + self.assertTrue( + torch.allclose(actual.matrix.data, expected, atol=1e-7), + msg=f"{actual.matrix.data} is not close to {expected}", + ) @parameterized.expand(MATMUL_2D_TEST_CASES) def test_matmul_outputs(self, left, right, expected): @@ -192,11 +184,9 @@ def test_all_matmul_outputs(self): class TestMatrixMatrixOutputs(unittest.TestCase): - def _test_matrix_matrix_outputs_impl(self, left, right, expected): actual = matmul_matrix_matrix(left.matrix.data, right.matrix.data) - self.assertTrue(torch.allclose(actual, expected, atol=1e-7), - msg=f"{actual} is not close to {expected}") + self.assertTrue(torch.allclose(actual, expected, atol=1e-7), msg=f"{actual} is not close to {expected}") @parameterized.expand(MATMUL_2D_TEST_CASES) def test_matrix_matrix_outputs(self, left, right, expected): diff --git a/tests/test_resample.py b/tests/test_resample.py index fc8dab06bf..5e85c741ef 100644 --- a/tests/test_resample.py +++ b/tests/test_resample.py @@ -14,9 +14,7 @@ import torch from monai.transforms.utility.functional import resample - from monai.utils import convert_to_tensor - from tests.utils import get_arange_img @@ -28,19 +26,16 @@ def rotate_45_2d(): class TestResampleFunction(unittest.TestCase): - def _test_resample_function_impl(self, img, matrix): result = resample(convert_to_tensor(img), matrix) print(result) - RESAMPLE_FUNCTION_CASES = [ - (get_arange_img((1, 16, 16)), rotate_45_2d()) - ] + RESAMPLE_FUNCTION_CASES = [(get_arange_img((1, 16, 16)), rotate_45_2d())] def test_resample_function(self): for case in self.RESAMPLE_FUNCTION_CASES: self._test_resample_function_impl(*case) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() From 1492f4099de647d37f6c4982d559e83d660c2ddb Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 12 Nov 2022 10:39:24 +0000 Subject: [PATCH 12/19] minor updates - `_pending_operations` defaults to `MetaObj.get_default_applied_operations()` - removes the unused `Apply` class - fixes unit tests, style checks - torch.pi doesn't exist before torch 1.10 Signed-off-by: Wenqi Li --- monai/data/meta_obj.py | 2 +- monai/transforms/__init__.py | 1 - monai/transforms/lazy/array.py | 48 ---------------------------------- monai/transforms/utils.py | 10 ++----- tests/test_apply.py | 13 ++++----- tests/test_matmul.py | 8 +++--- tests/test_resample.py | 7 +++-- tests/utils.py | 16 +++--------- 8 files changed, 21 insertions(+), 84 deletions(-) delete mode 100644 
monai/transforms/lazy/array.py diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index f804fc2e3d..74daf59ba8 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -214,7 +214,7 @@ def pop_pending_operation(self) -> Any: return self._pending_operations.pop() def clear_pending_operations(self) -> Any: - self._pending_operations = list() + self._pending_operations = MetaObj.get_default_applied_operations() @property def is_batch(self) -> bool: diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index e52009ed98..2bee4eacce 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -227,7 +227,6 @@ from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict from .io.array import SUPPORTED_READERS, LoadImage, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict -from .lazy.array import Apply from .lazy.functional import apply from .meta_matrix import Grid, Matrix, MatrixFactory, MetaMatrix, matmul from .meta_utility.dictionary import ( diff --git a/monai/transforms/lazy/array.py b/monai/transforms/lazy/array.py deleted file mode 100644 index fe3dd7a211..0000000000 --- a/monai/transforms/lazy/array.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from monai.transforms.inverse import InvertibleTransform -from monai.transforms.lazy.functional import apply - -__all__ = ["Apply"] - - -class Apply(InvertibleTransform): - """ - Apply wraps the apply method and can function as a Transform in either array or dictionary - mode. 
- """ - - def __init__(self): - super().__init__() - - def __call__(self, *args, **kwargs): - return apply(*args, **kwargs) - - def inverse(self, data): - return NotImplementedError() - - -# class Applyd(MapTransform, InvertibleTransform): -# -# def __init__(self): -# super().__init__() -# -# def __call__( -# self, -# d: dict -# ): -# rd = dict() -# for k, v in d.items(): -# rd[k] = apply(v) -# -# def inverse(self, data): -# return NotImplementedError() diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 1f632c7343..ee7c9d1645 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -904,16 +904,10 @@ def _create_rotate_90( else: raise ValueError(f"'spatial_dims' must be 2 or 3 but is {spatial_dims}") - steps_ = steps % 4 - affine = eye_func(spatial_dims + 1) - if spatial_dims == 2: - a, b = 0, 1 - else: - a, b = axis - - affine[a, a], affine[a, b], affine[b, a], affine[b, b] = values[steps] + a, b = (0, 1) if spatial_dims == 2 else axis + affine[a, a], affine[a, b], affine[b, a], affine[b, b] = values[steps % 4] return affine diff --git a/tests/test_apply.py b/tests/test_apply.py index 27426bdced..49130944cb 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -11,6 +11,7 @@ import unittest +import numpy as np import torch from monai.transforms.lazy.functional import apply @@ -22,10 +23,10 @@ def single_2d_transform_cases(): f = MatrixFactory(2, TransformBackends.TORCH, "cpu") cases = [ - (torch.randn((1, 32, 32)), [MetaMatrix(f.rotate_euler(torch.pi / 4).matrix, {"id": "rotate"})], (1, 32, 32)), + (torch.randn((1, 32, 32)), [MetaMatrix(f.rotate_euler(np.pi / 4).matrix, {"id": "rotate"})], (1, 32, 32)), ( torch.randn((1, 16, 16)), - [MetaMatrix(f.rotate_euler(torch.pi / 4).matrix, {"id": "rotate", "shape_override": (1, 45, 45)})], + [MetaMatrix(f.rotate_euler(np.pi / 4).matrix, {"id": "rotate", "shape_override": (1, 45, 45)})], (1, 45, 45), ), ] @@ -34,15 +35,15 @@ def single_2d_transform_cases(): class TestApply(unittest.TestCase): - def _test_apply_impl(self, tensor, pending_transforms): - print(tensor.shape) - result = tensor(*pending_transforms) + def _test_apply_impl(self, tensor, pending_transforms, expected_shape): + result = apply(tensor, pending_transforms) self.assertListEqual(result[1], pending_transforms) + self.assertEqual(result[0].shape, expected_shape) def _test_apply_metatensor_impl(self, tensor, pending_transforms, expected_shape, pending_as_parameter): tensor_ = convert_to_tensor(tensor, track_meta=True) if pending_as_parameter: - result, transforms = tensor_(*pending_transforms) + result, transforms = apply(tensor_, pending_transforms) else: for p in pending_transforms: tensor_.push_pending_operation(p) diff --git a/tests/test_matmul.py b/tests/test_matmul.py index a9a7852046..661291d4ad 100644 --- a/tests/test_matmul.py +++ b/tests/test_matmul.py @@ -146,17 +146,17 @@ def get_matmul_2d_test_cases(): f = MatrixFactory(2, TransformBackends.TORCH, "cpu") cases = [ ( - f.rotate_euler(torch.pi / 4), + f.rotate_euler(np.pi / 4), f.scale((0.5, 0.5)), torch.FloatTensor([[0.35355339, -0.35355339, 0], [0.35355339, 0.35355339, 0], [0, 0, 1]]), ), ( f.scale((0.5, 0.5)), - f.rotate_euler(torch.pi / 4), + f.rotate_euler(np.pi / 4), torch.FloatTensor([[0.35355339, -0.35355339, 0], [0.35355339, 0.35355339, 0], [0, 0, 1]]), ), - (f.translate((8, 8)), f.rotate_euler(torch.pi / 2), torch.FloatTensor([[0, -1, 8], [1, 0, 8], [0, 0, 1]])), - (f.rotate_euler(torch.pi / 2), f.translate((8, 8)), torch.FloatTensor([[0, -1, -8], [1, 0, 8], [0, 0, 
diff --git a/tests/test_apply.py b/tests/test_apply.py
index 27426bdced..49130944cb 100644
--- a/tests/test_apply.py
+++ b/tests/test_apply.py
@@ -11,6 +11,7 @@
 
 import unittest
 
+import numpy as np
 import torch
 
 from monai.transforms.lazy.functional import apply
@@ -22,10 +23,10 @@ def single_2d_transform_cases():
     f = MatrixFactory(2, TransformBackends.TORCH, "cpu")
 
     cases = [
-        (torch.randn((1, 32, 32)), [MetaMatrix(f.rotate_euler(torch.pi / 4).matrix, {"id": "rotate"})], (1, 32, 32)),
+        (torch.randn((1, 32, 32)), [MetaMatrix(f.rotate_euler(np.pi / 4).matrix, {"id": "rotate"})], (1, 32, 32)),
         (
             torch.randn((1, 16, 16)),
-            [MetaMatrix(f.rotate_euler(torch.pi / 4).matrix, {"id": "rotate", "shape_override": (1, 45, 45)})],
+            [MetaMatrix(f.rotate_euler(np.pi / 4).matrix, {"id": "rotate", "shape_override": (1, 45, 45)})],
             (1, 45, 45),
         ),
     ]
@@ -34,15 +35,15 @@ def single_2d_transform_cases():
 
 class TestApply(unittest.TestCase):
-    def _test_apply_impl(self, tensor, pending_transforms):
-        print(tensor.shape)
-        result = tensor(*pending_transforms)
+    def _test_apply_impl(self, tensor, pending_transforms, expected_shape):
+        result = apply(tensor, pending_transforms)
         self.assertListEqual(result[1], pending_transforms)
+        self.assertEqual(result[0].shape, expected_shape)
 
     def _test_apply_metatensor_impl(self, tensor, pending_transforms, expected_shape, pending_as_parameter):
         tensor_ = convert_to_tensor(tensor, track_meta=True)
         if pending_as_parameter:
-            result, transforms = tensor_(*pending_transforms)
+            result, transforms = apply(tensor_, pending_transforms)
         else:
             for p in pending_transforms:
                 tensor_.push_pending_operation(p)
diff --git a/tests/test_matmul.py b/tests/test_matmul.py
index a9a7852046..661291d4ad 100644
--- a/tests/test_matmul.py
+++ b/tests/test_matmul.py
@@ -146,17 +146,17 @@ def get_matmul_2d_test_cases():
     f = MatrixFactory(2, TransformBackends.TORCH, "cpu")
     cases = [
         (
-            f.rotate_euler(torch.pi / 4),
+            f.rotate_euler(np.pi / 4),
             f.scale((0.5, 0.5)),
             torch.FloatTensor([[0.35355339, -0.35355339, 0], [0.35355339, 0.35355339, 0], [0, 0, 1]]),
         ),
         (
             f.scale((0.5, 0.5)),
-            f.rotate_euler(torch.pi / 4),
+            f.rotate_euler(np.pi / 4),
             torch.FloatTensor([[0.35355339, -0.35355339, 0], [0.35355339, 0.35355339, 0], [0, 0, 1]]),
         ),
-        (f.translate((8, 8)), f.rotate_euler(torch.pi / 2), torch.FloatTensor([[0, -1, 8], [1, 0, 8], [0, 0, 1]])),
-        (f.rotate_euler(torch.pi / 2), f.translate((8, 8)), torch.FloatTensor([[0, -1, -8], [1, 0, 8], [0, 0, 1]])),
+        (f.translate((8, 8)), f.rotate_euler(np.pi / 2), torch.FloatTensor([[0, -1, 8], [1, 0, 8], [0, 0, 1]])),
+        (f.rotate_euler(np.pi / 2), f.translate((8, 8)), torch.FloatTensor([[0, -1, -8], [1, 0, 8], [0, 0, 1]])),
     ]
 
     return cases
diff --git a/tests/test_resample.py b/tests/test_resample.py
index 5e85c741ef..d427be12d7 100644
--- a/tests/test_resample.py
+++ b/tests/test_resample.py
@@ -18,7 +18,7 @@
 from tests.utils import get_arange_img
 
 
-def rotate_45_2d():
+def rotate_90_2d():
     t = torch.eye(3)
     t[:, 0] = torch.FloatTensor([0, -1, 0])
     t[:, 1] = torch.FloatTensor([1, 0, 0])
@@ -27,10 +27,9 @@ class TestResampleFunction(unittest.TestCase):
     def _test_resample_function_impl(self, img, matrix):
-        result = resample(convert_to_tensor(img), matrix)
-        print(result)
+        resample(convert_to_tensor(img), matrix)
 
-    RESAMPLE_FUNCTION_CASES = [(get_arange_img((1, 16, 16)), rotate_45_2d())]
+    RESAMPLE_FUNCTION_CASES = [(get_arange_img((1, 16, 16)), rotate_90_2d())]
 
     def test_resample_function(self):
         for case in self.RESAMPLE_FUNCTION_CASES:
diff --git a/tests/utils.py b/tests/utils.py
index 89766657ba..e963e03668 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -348,21 +348,13 @@ def make_rand_affine(ndim: int = 3, random_state: Optional[np.random.RandomState
     return af
 
 
-def get_arange_img(size, dtype=torch.float32, offset=0):
+def get_arange_img(size, dtype=np.float32, offset=0):
     """
-    Returns an 2d or 3d image as a numpy tensor (complete with channel as dim 0)
+    Returns an image as a numpy array (complete with channel as dim 0)
     with contents that iterate like an arange.
     """
-    img = torch.zeros(size, dtype=dtype)
-    if len(size) == 2:
-        for j in range(size[0]):
-            for i in range(size[1]):
-                img[j, i] = i + j * size[0] + offset
-    else:
-        for k in range(size[0]):
-            for j in range(size[1]):
-                for i in range(size[2]):
-                    img[k, j, i] = i + j * size[0] + k * size[0] * size[1] + offset
+    n_elem = np.prod(size)
+    img = np.arange(offset, offset + n_elem, dtype=dtype).reshape(size)
     return np.expand_dims(img, 0)
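The `get_arange_img` rewrite above replaces the nested loops with a single `np.arange(...).reshape(...)`. Note the old 2D loop used `size[0]` as the row stride, so the two forms only coincided for square images; the vectorized version is the well-defined one. A quick check of the new behaviour (a sketch, not part of the patch):

import numpy as np

def get_arange_img(size, dtype=np.float32, offset=0):
    # channel-first image whose contents iterate like an arange
    n_elem = np.prod(size)
    img = np.arange(offset, offset + n_elem, dtype=dtype).reshape(size)
    return np.expand_dims(img, 0)

expected = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=np.float32)
assert np.array_equal(get_arange_img((3, 3))[0], expected)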
From 2d4122bf9c22ae438c91af895b4cc4b6b95c6b8a Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Sat, 12 Nov 2022 23:02:24 +0000
Subject: [PATCH 13/19] apply/matmul/resample MVP

Signed-off-by: Wenqi Li
---
 monai/transforms/__init__.py           |   4 +-
 monai/transforms/lazy/functional.py    | 222 +++++---------------
 monai/transforms/meta_matrix.py        | 275 +------------------------
 monai/transforms/utility/functional.py |  30 ++-
 monai/transforms/utils.py              |  88 --------
 monai/utils/enums.py                   |   1 +
 tests/test_apply.py                    |  15 +-
 tests/test_matmul.py                   | 202 ------------------
 8 files changed, 92 insertions(+), 745 deletions(-)
 delete mode 100644 tests/test_matmul.py

diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py
index 2bee4eacce..eb0fd897aa 100644
--- a/monai/transforms/__init__.py
+++ b/monai/transforms/__init__.py
@@ -228,7 +228,7 @@
 from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
 from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
 from .lazy.functional import apply
-from .meta_matrix import Grid, Matrix, MatrixFactory, MetaMatrix, matmul
+from .meta_matrix import matmul
 from .meta_utility.dictionary import (
     FromMetaTensord,
     FromMetaTensorD,
@@ -634,8 +634,6 @@
     generate_label_classes_crop_centers,
     generate_pos_neg_label_crop_centers,
     generate_spatial_bounding_box,
-    get_backend_from_tensor_like,
-    get_device_from_tensor_like,
     get_extreme_points,
     get_largest_connected_component_mask,
     get_number_image_type_conversions,
diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py
index ce9fc0fe84..733d7961fc 100644
--- a/monai/transforms/lazy/functional.py
+++ b/monai/transforms/lazy/functional.py
@@ -9,193 +9,77 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import itertools as it
-from typing import Optional, Sequence, Union
+from typing import Optional, Union
 
 import numpy as np
 import torch
 
-from monai.config import DtypeLike
 from monai.data.meta_tensor import MetaTensor
-from monai.transforms.meta_matrix import Matrix, MatrixFactory, MetaMatrix, matmul
-from monai.transforms.spatial.array import Affine
-from monai.transforms.utils import dtypes_to_str_or_identity, get_backend_from_tensor_like, get_device_from_tensor_like
-from monai.utils import GridSampleMode, GridSamplePadMode
+from monai.data.utils import to_affine_nd
+from monai.transforms.meta_matrix import matmul
+from monai.transforms.utility.functional import resample
+from monai.utils import LazyAttr
 
 __all__ = ["apply"]
 
-# TODO: This should move to a common place to be shared with dictionary
-# from monai.utils.type_conversion import dtypes_to_str_or_identity
-
-GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str]
-GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str]
-DtypeSequence = Union[Sequence[DtypeLike], DtypeLike]
-
-
-# TODO: move to mapping_stack.py
-def extents_from_shape(shape, dtype=np.float32):
-    extents = [[0, shape[i]] for i in range(1, len(shape))]
-
-    extents = it.product(*extents)
-    return [np.asarray(e + (1,), dtype=dtype) for e in extents]
-
-
-# TODO: move to mapping_stack.py
-def shape_from_extents(
-    src_shape: Sequence, extents: Union[Sequence[np.ndarray], Sequence[torch.Tensor], np.ndarray, torch.Tensor]
-):
-    if isinstance(extents, (list, tuple)):
-        if isinstance(extents[0], np.ndarray):
-            aextents = np.asarray(extents)
-        else:
-            aextents = torch.stack(extents)
-            aextents = aextents.numpy()
-    else:
-        if isinstance(extents, np.ndarray):
-            aextents = extents
-        else:
-            aextents = extents.numpy()
-
-    mins = aextents.min(axis=0)
-    maxes = aextents.max(axis=0)
-    values = np.round(maxes - mins).astype(int)[:-1].tolist()
-    return (src_shape[0],) + tuple(values)
-
-
-def metadata_is_compatible(value_1, value_2):
-    if value_1 is None:
-        return True
-    else:
-        if value_2 is None:
-            return True
-        return value_1 == value_2
-
-
-def metadata_dtype_is_compatible(value_1, value_2):
-    if value_1 is None:
-        return True
-    else:
-        if value_2 is None:
-            return True
-
-    # if we are here, value_1 and value_2 are both set
-    # TODO: this is not a good enough solution
-    value_1_ = dtypes_to_str_or_identity(value_1)
-    value_2_ = dtypes_to_str_or_identity(value_2)
-    return value_1_ == value_2_
-
-
-def starting_matrix_and_extents(matrix_factory, data):
-    # set up the identity matrix and metadata
-    cumulative_matrix = matrix_factory.identity()
-    cumulative_extents = extents_from_shape(data.shape)
-    return cumulative_matrix, cumulative_extents
-
-
-def prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype):
-    kwargs = {}
-    if cur_mode is not None:
-        kwargs["mode"] = cur_mode
-    if cur_padding_mode is not None:
-        kwargs["padding_mode"] = cur_padding_mode
-    if cur_device is not None:
-        kwargs["device"] = cur_device
-    if cur_dtype is not None:
-        kwargs["dtype"] = cur_dtype
-
-    return kwargs
-
-
-def matrix_from_matrix_container(matrix):
-    if isinstance(matrix, MetaMatrix):
-        return matrix.matrix.data
-    elif isinstance(matrix, Matrix):
-        return matrix.data
-    else:
-        return matrix
-
-
-def apply(data: Union[torch.Tensor, MetaTensor, dict], pending: Optional[Union[dict, list]] = None):
-    """
-    This method applies pending transforms to tensors.
-    Args:
-        data: A torch Tensor, monai MetaTensor, or a dictionary containing Tensors / MetaTensors
-        pending: Optional arg containing pending transforms. This must be set if data is a Tensor
-            or dictionary of Tensors, but is optional if data is a MetaTensor / dictionary of
-            MetaTensors.
-    """
-    if isinstance(data, dict):
-        rd = dict()
-        for k, v in data.items():
-            result = v(*pending)
-            rd[k] = result
-        return rd
-    if isinstance(data, MetaTensor) and pending is None:
-        pending_ = data.pending_operations
-    else:
-        pending_ = [] if pending is None else pending
+def mat_from_pending(pending_item):
+    if isinstance(pending_item, (torch.Tensor, np.ndarray)):
+        return pending_item
+    if isinstance(pending_item, dict):
+        return pending_item[LazyAttr.AFFINE]
+    return pending_item
 
-    if len(pending_) == 0:
-        return data
 
-    dim_count = len(data.shape) - 1
-    matrix_factory = MatrixFactory(dim_count, get_backend_from_tensor_like(data), get_device_from_tensor_like(data))
-
-    # set up the identity matrix and metadata
-    cumulative_matrix, cumulative_extents = starting_matrix_and_extents(matrix_factory, data)
-
-    # set the various resampling parameters to an initial state
-    cur_mode = None
-    cur_padding_mode = None
-    cur_device = None
-    cur_dtype = None
-    cur_shape = data.shape
-
-    for meta_matrix in pending_:
-        next_matrix = meta_matrix.matrix
-        # print("intermediate matrix\n", matrix_from_matrix_container(cumulative_matrix))
-        cumulative_matrix = matmul(cumulative_matrix, next_matrix)
-        cumulative_extents = [matmul(e, cumulative_matrix) for e in cumulative_extents]
-
-        new_mode = meta_matrix.metadata.get("mode", None)
-        new_padding_mode = meta_matrix.metadata.get("padding_mode", None)
-        new_device = meta_matrix.metadata.get("device", None)
-        new_dtype = meta_matrix.metadata.get("dtype", None)
-        new_shape = meta_matrix.metadata.get("shape_override", None)
-
-        mode_compat = metadata_is_compatible(cur_mode, new_mode)
-        padding_mode_compat = metadata_is_compatible(cur_padding_mode, new_padding_mode)
-        device_compat = metadata_is_compatible(cur_device, new_device)
-        dtype_compat = metadata_dtype_is_compatible(cur_dtype, new_dtype)
-
-        if mode_compat is False or padding_mode_compat is False or device_compat is False or dtype_compat is False:
-            # carry out an intermediate resample here due to incompatibility between arguments
-            kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype)
+def kwargs_from_pending(pending_item):
+    if not isinstance(pending_item, dict):
+        return {}
+    ret = {
+        LazyAttr.INTERP_MODE: pending_item.get(LazyAttr.INTERP_MODE, None),  # interpolation mode
+        LazyAttr.PADDING_MODE: pending_item.get(LazyAttr.PADDING_MODE, None),  # padding mode
+    }
+    if LazyAttr.SHAPE in pending_item:
+        ret[LazyAttr.SHAPE] = pending_item[LazyAttr.SHAPE]
+    if LazyAttr.DTYPE in pending_item:
+        ret[LazyAttr.DTYPE] = pending_item[LazyAttr.DTYPE]
+    return ret
 
-            cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix)
-            a = Affine(affine=cumulative_matrix_, **kwargs)
-            data, _ = a(img=data)
 
-            cumulative_matrix, cumulative_extents = starting_matrix_and_extents(matrix_factory, data)
+def is_compatible_kwargs(kwargs_1, kwargs_2):
+    return True
 
-        cur_mode = cur_mode if new_mode is None else new_mode
-        cur_padding_mode = cur_padding_mode if new_padding_mode is None else new_padding_mode
-        cur_device = cur_device if new_device is None else new_device
-        cur_dtype = cur_dtype if new_dtype is None else new_dtype
-        cur_shape = cur_shape if new_shape is None else new_shape
 
-    kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype)
+def apply(data: Union[torch.Tensor, MetaTensor], pending: Optional[list] = None):
+    """
+    This method applies pending transforms to tensors.
+
+    Args:
+        data: A torch Tensor, monai MetaTensor
+        pending: pending transforms. This must be set if data is a Tensor, but is optional if data is a MetaTensor.
+    """
+    if isinstance(data, MetaTensor) and pending is None:
+        pending = data.pending_operations
+    pending = [] if pending is None else pending
 
-    cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix)
+    if not pending:
+        return data
 
-    # print(f"applying with cumulative matrix\n {cumulative_matrix_}")
-    a = Affine(affine=cumulative_matrix_, spatial_size=cur_shape[1:], normalized=False, **kwargs)
-    data, tx = a(img=data)
+    cumulative_xform = mat_from_pending(pending[0])
+    cur_kwargs = kwargs_from_pending(pending[0])
+
+    for p in pending[1:]:
+        new_kwargs = kwargs_from_pending(p)
+        if not is_compatible_kwargs(cur_kwargs, new_kwargs):
+            # carry out an intermediate resample here due to incompatibility between arguments
+            data = resample(data, cumulative_xform, cur_kwargs)
+        next_matrix = mat_from_pending(p)
+        cumulative_xform = matmul(cumulative_xform, next_matrix)
+        cur_kwargs.update(new_kwargs)
+    data = resample(data, cumulative_xform, cur_kwargs)
     if isinstance(data, MetaTensor):
         data.clear_pending_operations()
-        for p in pending_:
-            data.affine = p.matrix.data
+        data.affine = data.affine @ to_affine_nd(3, cumulative_xform)
+        for p in pending:
             data.push_applied_operation(p)
 
-    return data, None if pending is None else pending_
+    return data, pending
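To make the MVP concrete: `apply` folds a list of pending items into one cumulative matrix and resamples once (or more often, if `is_compatible_kwargs` ever reports a conflict). A minimal usage sketch, assuming the tree as of this patch (raw 3x3 tensors are valid pending items and pass through `mat_from_pending` unchanged):

import torch

from monai.transforms.lazy.functional import apply

img = torch.randn(1, 16, 16)  # channel-first 2D image
rot90 = torch.tensor([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
inv = torch.linalg.inv(rot90)  # rotating back: the two should fuse to identity

# both pending transforms are combined and realised by a single resample
out, applied = apply(img, [rot90, inv])
assert out.shape == (1, 16, 16) and len(applied) == 2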
diff --git a/monai/transforms/meta_matrix.py b/monai/transforms/meta_matrix.py
index 9733361c55..0bad720294 100644
--- a/monai/transforms/meta_matrix.py
+++ b/monai/transforms/meta_matrix.py
@@ -9,277 +9,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Sequence, Union
 
-import numpy as np
 import torch
 
 from monai.config import NdarrayOrTensor
-from monai.transforms.utils import (
-    _create_flip,
-    _create_rotate,
-    _create_rotate_90,
-    _create_scale,
-    _create_shear,
-    _create_translate,
-    get_backend_from_tensor_like,
-    get_device_from_tensor_like,
-)
-from monai.utils import TransformBackends
 
-__all__ = ["Grid", "matmul", "Matrix", "MatrixFactory", "MetaMatrix"]
+__all__ = ["is_affine_shaped", "matmul"]
 
 
-def is_matrix_shaped(data):
+def is_affine_shaped(data):
+    """Check if the data is a square matrix for the last two dimensions."""
+    if not hasattr(data, "shape") or len(data.shape) < 2:
+        return False
+    return data.shape[-1] in (3, 4) and data.shape[-2] in (3, 4) and data.shape[-1] == data.shape[-2]
 
-    return (
-        len(data.shape) == 2 and data.shape[0] in (3, 4) and data.shape[1] in (3, 4) and data.shape[0] == data.shape[1]
-    )
 
-
-def is_grid_shaped(data):
-
-    return len(data.shape) == 3 and data.shape[0] == 3 or len(data.shape) == 4 and data.shape[0] == 4
-
-
-class MatrixFactory:
-    def __init__(self, dims: int, backend: TransformBackends, device: Optional[torch.device] = None):
-
-        if backend == TransformBackends.NUMPY:
-            if device is not None:
-                raise ValueError("'device' must be None with TransformBackends.NUMPY")
-            self._device = None
-            self._sin = lambda th: np.sin(th, dtype=np.float32)
-            self._cos = lambda th: np.cos(th, dtype=np.float32)
-            self._eye = lambda th: np.eye(th, dtype=np.float32)
-            self._diag = lambda th: np.diag(th).astype(np.float32)
-        else:
-            if device is None:
-                raise ValueError("'device' must be set with TransformBackends.TORCH")
-            self._device = device
-            self._sin = lambda th: torch.sin(torch.as_tensor(th, dtype=torch.float32, device=self._device))
-            self._cos = lambda th: torch.cos(torch.as_tensor(th, dtype=torch.float32, device=self._device))
-            self._eye = lambda rank: torch.eye(rank, device=self._device, dtype=torch.float32)
-            self._diag = lambda size: torch.diag(torch.as_tensor(size, device=self._device, dtype=torch.float32))
-
-        self._backend = backend
-        self._dims = dims
-
-    @staticmethod
-    def from_tensor(data):
-        return MatrixFactory(len(data.shape) - 1, get_backend_from_tensor_like(data), get_device_from_tensor_like(data))
-
-    def identity(self):
-        matrix = self._eye(self._dims + 1)
-        return MetaMatrix(matrix, {})
-
-    def rotate_euler(self, radians: Union[Sequence[float], float], **extra_args):
-        matrix = _create_rotate(self._dims, radians, self._sin, self._cos, self._eye)
-        return MetaMatrix(matrix, extra_args)
-
-    def rotate_90(self, rotations, axis, **extra_args):
-        matrix = _create_rotate_90(self._dims, rotations, axis)
-        return MetaMatrix(matrix, extra_args)
-
-    def flip(self, axis, **extra_args):
-        matrix = _create_flip(self._dims, axis, self._eye)
-        return MetaMatrix(matrix, extra_args)
-
-    def shear(self, coefs: Union[Sequence[float], float], **extra_args):
-        matrix = _create_shear(self._dims, coefs, self._eye)
-        return MetaMatrix(matrix, extra_args)
-
-    def scale(self, factors: Union[Sequence[float], float], **extra_args):
-        matrix = _create_scale(self._dims, factors, self._diag)
-        return MetaMatrix(matrix, extra_args)
-
-    def translate(self, offsets: Union[Sequence[float], float], **extra_args):
-        matrix = _create_translate(self._dims, offsets, self._eye)
-        return MetaMatrix(matrix, extra_args)
-
-
-def ensure_tensor(data: NdarrayOrTensor):
-    if isinstance(data, torch.Tensor):
-        return data
-
-    return torch.as_tensor(data)
-
-
-def apply_align_corners(matrix, spatial_size, factory):
-    inflated_spatial_size = tuple(s + 1 for s in spatial_size)
-    scale_factors = tuple(s / i for s, i in zip(spatial_size, inflated_spatial_size))
-    scale_mat = factory.scale(scale_factors)
-    return matmul(scale_mat, matrix)
-
-
-class Matrix:
-    def __init__(self, matrix: NdarrayOrTensor):
-        self.data = ensure_tensor(matrix)
-
-    # def __matmul__(self, other):
-    #     if isinstance(other, Matrix):
-    #         other_matrix = other.data
-    #     else:
-    #         other_matrix = other
-    #     return self.data @ other_matrix
-    #
-    # def __rmatmul__(self, other):
-    #     return other.__matmul__(self.data)
-
-
-class Grid:
-    def __init__(self, grid):
-        self.data = ensure_tensor(grid)
-
-    # def __matmul__(self, other):
-    #     raise NotImplementedError()
-
-
-class MetaMatrix:
-    def __init__(self, matrix: Union[NdarrayOrTensor, Matrix, Grid], metadata: Optional[dict] = None):
-
-        if not isinstance(matrix, (Matrix, Grid)):
-            if matrix.shape == 2:
-                if matrix.shape[0] != matrix.shape[1] or matrix.shape[0] not in (3, 4):
-                    raise ValueError(
-                        "If 'matrix' is passed a numpy ndarray/torch Tensor, it must"
-                        f" be 3x3 or 4x4 ('matrix' has has shape {matrix.shape})"
-                    )
-            matrix_ = Matrix(matrix)
-        else:
-            matrix_ = matrix
-        self.matrix = matrix_
-
-        self.metadata = metadata or {}
-
-    def __matmul__(self, other):
-        if isinstance(other, MetaMatrix):
-            other_ = other.matrix
-        else:
-            other_ = other
-        return MetaMatrix(self.matrix @ other_)
-
-    def __rmatmul__(self, other):
-        if isinstance(other, MetaMatrix):
-            other_ = other.matrix
-        else:
-            other_ = other
-        return MetaMatrix(other_ @ self.matrix)
-
-
-def matmul(
-    left: Union[MetaMatrix, Grid, Matrix, NdarrayOrTensor], right: Union[MetaMatrix, Grid, Matrix, NdarrayOrTensor]
-):
-    matrix_types = (MetaMatrix, Grid, Matrix, torch.Tensor, np.ndarray)
-
-    if not isinstance(left, matrix_types):
-        raise TypeError(f"'left' must be one of {matrix_types} but is {type(left)}")
-    if not isinstance(right, matrix_types):
-        raise TypeError(f"'second' must be one of {matrix_types} but is {type(right)}")
-
-    left_ = left.matrix if isinstance(left, MetaMatrix) else left
-    right_ = right.matrix if isinstance(right, MetaMatrix) else right
-
-    # TODO: it might be better to not return a metamatrix, unless we pass in the resulting
-    # metadata also
-    put_in_metamatrix = isinstance(left, MetaMatrix) or isinstance(right, MetaMatrix)
-
-    put_in_grid = isinstance(left, Grid) or isinstance(right, Grid)
-
-    put_in_matrix = isinstance(left, Matrix) or isinstance(right, Matrix)
-    put_in_matrix = False if put_in_grid is True else put_in_matrix
-
-    promote_to_tensor = not (isinstance(left_, np.ndarray) and isinstance(right_, np.ndarray))
-
-    left_raw = left_.data if isinstance(left_, (Matrix, Grid)) else left_
-    right_raw = right_.data if isinstance(right_, (Matrix, Grid)) else right_
-
-    if promote_to_tensor:
-        left_raw = torch.as_tensor(left_raw)
-        right_raw = torch.as_tensor(right_raw)
-
-    if isinstance(left_, Grid):
-        if isinstance(right_, Grid):
-            raise RuntimeError("Unable to matrix multiply two Grids")
-        else:
-            result = matmul_grid_matrix(left_raw, right_raw)
-    else:
-        if isinstance(right_, Grid):
-            result = matmul_matrix_grid(left_raw, right_raw)
-        else:
-            result = matmul_matrix_matrix(left_raw, right_raw)
-
-    if put_in_grid:
-        result = Grid(result)
-    elif put_in_matrix:
-        result = Matrix(result)
-
-    if put_in_metamatrix:
-        result = MetaMatrix(result)
-
-    return result
-
-
-def matmul_matrix_grid(left: NdarrayOrTensor, right: NdarrayOrTensor):
-    if not is_matrix_shaped(left):
-        raise ValueError(f"'left' should be a 2D or 3D homogenous matrix but has shape {left.shape}")
-
-    if not is_grid_shaped(right):
-        raise ValueError(
-            "'right' should be a 3D array with shape[0] == 2 or a "
-            f"4D array with shape[0] == 3 but has shape {right.shape}"
-        )
-
-    # flatten the grid to take advantage of torch batch matrix multiply
-    right_flat = right.reshape(right.shape[0], -1)
-    result_flat = left @ right_flat
-    # restore the grid shape
-    result = result_flat.reshape((-1,) + result_flat.shape[1:])
-    return result
-
-
-def matmul_grid_matrix(left: NdarrayOrTensor, right: NdarrayOrTensor):
-    if not is_grid_shaped(left):
-        raise ValueError(
-            "'left' should be a 3D array with shape[0] == 2 or a "
-            f"4D array with shape[0] == 3 but has shape {left.shape}"
-        )
-
-    if not is_matrix_shaped(right):
-        raise ValueError(f"'right' should be a 2D or 3D homogenous matrix but has shape {right.shape}")
-
-    try:
-        inv_matrix = torch.inverse(right)
-    except RuntimeError:
-        # the matrix is not invertible, so we will have to perform a slow grid to matrix operation
-        return matmul_grid_matrix_slow(left, right)
-
-    # invert the matrix and swap the arguments, taking advantage of
-    # matrix @ vector == vector_transposed @ matrix_inverse
-    return matmul_matrix_grid(inv_matrix, left)
-
-
-def matmul_grid_matrix_slow(left: NdarrayOrTensor, right: NdarrayOrTensor):
-    if not is_grid_shaped(left):
-        raise ValueError(
-            "'left' should be a 3D array with shape[0] == 2 or a "
-            f"4D array with shape[0] == 3 but has shape {left.shape}"
-        )
-
-    if not is_matrix_shaped(right):
-        raise ValueError(f"'right' should be a 2D or 3D homogenous matrix but has shape {right.shape}")
-
-    flat_left = left.reshape(left.shape[0], -1)
-    result_flat = torch.zeros_like(flat_left)
-    for i in range(flat_left.shape[1]):
-        vector = flat_left[:, i][None, :]
-        result_vector = vector @ right
-        result_flat[:, i] = result_vector[0, :]
-
-    # restore the grid shape
-    result = result_flat.reshape((-1,) + result_flat.shape[1:])
-    return result
-
-
-def matmul_matrix_matrix(left: NdarrayOrTensor, right: NdarrayOrTensor):
-    return left @ right
+def matmul(left: NdarrayOrTensor, right: NdarrayOrTensor):
+    if is_affine_shaped(left) and is_affine_shaped(right):
+        return torch.matmul(left, right)
+    raise NotImplementedError
diff --git a/monai/transforms/utility/functional.py b/monai/transforms/utility/functional.py
index 56570799f7..00045d25e9 100644
--- a/monai/transforms/utility/functional.py
+++ b/monai/transforms/utility/functional.py
@@ -8,21 +8,33 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Union
+from typing import Optional
 
 import torch
 
+import monai
 from monai.config import NdarrayOrTensor
-from monai.transforms import Affine
-from monai.transforms.meta_matrix import Grid, Matrix
+from monai.transforms.meta_matrix import is_affine_shaped
+from monai.utils import LazyAttr
 
+__all__ = ["resample"]
 
-def resample(data: torch.Tensor, matrix: Union[NdarrayOrTensor, Matrix, Grid], kwargs: Optional[dict] = None):
+
+def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: Optional[dict] = None):
     """
     This is a minimal implementation of resample that always uses Affine.
     """
-    if kwargs is not None:
-        a = Affine(affine=matrix, **kwargs)
-    else:
-        a = Affine(affine=matrix)
-    return a(img=data)
+    if not is_affine_shaped(matrix):
+        raise NotImplementedError("calling dense grid resample API not implemented")
+    kwargs = {} if kwargs is None else kwargs
+    init_kwargs = {
+        "spatial_size": kwargs.pop(LazyAttr.SHAPE, data.shape)[1:],
+        "dtype": kwargs.pop(LazyAttr.DTYPE, data.dtype),
+    }
+    call_kwargs = {
+        "mode": kwargs.pop(LazyAttr.INTERP_MODE, None),
+        "padding_mode": kwargs.pop(LazyAttr.PADDING_MODE, None),
+    }
+    resampler = monai.transforms.Affine(affine=matrix, image_only=True, **init_kwargs)
+    with resampler.trace_transform(False):  # don't track this transform in `data`
+        return resampler(img=data, **call_kwargs)
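The `LazyAttr` keys are the contract between a pending item and `resample`: `SHAPE` and `DTYPE` are consumed when the `Affine` resampler is constructed, while the interpolation and padding modes are forwarded to the call. A hedged sketch of an explicit invocation (again assuming the PATCH 13 state of the tree):

import torch

from monai.transforms.utility.functional import resample
from monai.utils import LazyAttr

img = torch.arange(16.0).reshape(1, 4, 4)
rot90 = torch.tensor([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])

pending_kwargs = {
    LazyAttr.SHAPE: (1, 4, 4),        # becomes Affine(spatial_size=(4, 4), ...)
    LazyAttr.DTYPE: torch.float32,    # becomes Affine(dtype=...)
    LazyAttr.INTERP_MODE: "nearest",  # forwarded to the resampler call
    LazyAttr.PADDING_MODE: "zeros",
}
out = resample(img, rot90, pending_kwargs)
assert out.shape == (1, 4, 4)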
""" - if kwargs is not None: - a = Affine(affine=matrix, **kwargs) - else: - a = Affine(affine=matrix) - return a(img=data) + if not is_affine_shaped(matrix): + raise NotImplementedError("calling dense grid resample API not implemented") + kwargs = {} if kwargs is None else kwargs + init_kwargs = { + "spatial_size": kwargs.pop(LazyAttr.SHAPE, data.shape)[1:], + "dtype": kwargs.pop(LazyAttr.DTYPE, data.dtype), + } + call_kwargs = { + "mode": kwargs.pop(LazyAttr.INTERP_MODE, None), + "padding_mode": kwargs.pop(LazyAttr.PADDING_MODE, None), + } + resampler = monai.transforms.Affine(affine=matrix, image_only=True, **init_kwargs) + with resampler.trace_transform(False): # don't track this transform in `data` + return resampler(img=data, **call_kwargs) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index ee7c9d1645..e8a9791999 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1883,93 +1883,5 @@ def squarepulse(sig, duty: float = 0.5): return y -def get_device_from_tensor_like(data: NdarrayOrTensor): - """ - This function returns the device of `data`, which must either be a numpy ndarray or a - pytorch Tensor. - Args: - data: the ndarray/tensor to return the device of - Returns: - None if `data` is a numpy array, or the device of the pytorch Tensor - """ - if isinstance(data, np.ndarray): - return None - elif isinstance(data, torch.Tensor): - return data.device - else: - msg = "'data' must be one of numpy ndarray or torch Tensor but is {}" - raise ValueError(msg.format(type(data))) - - -def get_backend_from_tensor_like(data: NdarrayOrTensor): - """ - This function returns the backend of `data`, which must either be a numpy ndarray or a - pytorch Tensor. - Args: - data: the ndarray/tensor to return the device of - Returns: - None if `data` is a numpy array, or the device of the pytorch Tensor - """ - if isinstance(data, np.ndarray): - return TransformBackends.NUMPY - elif isinstance(data, torch.Tensor): - return TransformBackends.TORCH - else: - msg = "'data' must be one of numpy ndarray or torch Tensor but is {}" - raise ValueError(msg.format(type(data))) - - -def dtype_torch_to_numpy(dtype: torch.dtype) -> np.dtype: - """Convert a torch dtype to its numpy equivalent.""" - return torch.empty([], dtype=dtype).numpy().dtype # type: ignore - - -def dtype_numpy_to_torch(dtype: np.dtype) -> torch.dtype: - """Convert a numpy dtype to its torch equivalent.""" - return torch.from_numpy(np.empty([], dtype=dtype)).dtype - - -__dtype_dict = { - np.int8: "int8", - torch.int8: "int8", - np.int16: "int16", - torch.int16: "int16", - int: "int32", - np.int32: "int32", - torch.int32: "int32", - np.int64: "int64", - torch.int64: "int64", - np.uint8: "uint8", - torch.uint8: "uint8", - np.uint16: "uint16", - np.uint32: "uint32", - np.uint64: "uint64", - float: "float32", - np.float16: "float16", - np.float: "float32", - np.float32: "float32", - np.float64: "float64", - torch.float16: "float16", - torch.float: "float32", - torch.float32: "float32", - torch.double: "float64", - torch.float64: "float64", -} - - -def dtypes_to_str_or_identity(dtype: Any) -> Any: - return __dtype_dict.get(dtype, dtype) - - -def get_numpy_dtype_from_string(dtype: str) -> np.dtype: - """Get a numpy dtype (e.g., `np.float32`) from its string (e.g., `"float32"`).""" - return np.empty([], dtype=dtype).dtype - - -def get_torch_dtype_from_string(dtype: str) -> torch.dtype: - """Get a torch dtype (e.g., `torch.float32`) from its string (e.g., `"float32"`).""" - return 
dtype_numpy_to_torch(get_numpy_dtype_from_string(dtype)) - - if __name__ == "__main__": print_transform_backends() diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 4fd9bea557..b606cd8667 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -630,3 +630,4 @@ class LazyAttr(StrEnum): AFFINE = "lazy_affine" PADDING_MODE = "lazy_padding_mode" INTERP_MODE = "lazy_interpolation_mode" + DTYPE = "lazy_dtype" diff --git a/tests/test_apply.py b/tests/test_apply.py index 49130944cb..45abb00326 100644 --- a/tests/test_apply.py +++ b/tests/test_apply.py @@ -15,24 +15,21 @@ import torch from monai.transforms.lazy.functional import apply -from monai.transforms.meta_matrix import MatrixFactory, MetaMatrix -from monai.utils import TransformBackends, convert_to_tensor +from monai.transforms.utils import create_rotate +from monai.utils import LazyAttr, convert_to_tensor def single_2d_transform_cases(): - f = MatrixFactory(2, TransformBackends.TORCH, "cpu") - - cases = [ - (torch.randn((1, 32, 32)), [MetaMatrix(f.rotate_euler(np.pi / 4).matrix, {"id": "rotate"})], (1, 32, 32)), + return [ + (torch.randn((1, 32, 32)), [{LazyAttr.AFFINE: create_rotate(2, np.pi / 4)}], (1, 32, 32)), + (torch.randn((1, 32, 32)), [create_rotate(2, np.pi / 4)], (1, 32, 32)), ( torch.randn((1, 16, 16)), - [MetaMatrix(f.rotate_euler(np.pi / 4).matrix, {"id": "rotate", "shape_override": (1, 45, 45)})], + [{LazyAttr.AFFINE: create_rotate(2, np.pi / 4), LazyAttr.SHAPE: (1, 45, 45)}], (1, 45, 45), ), ] - return cases - class TestApply(unittest.TestCase): def _test_apply_impl(self, tensor, pending_transforms, expected_shape): diff --git a/tests/test_matmul.py b/tests/test_matmul.py deleted file mode 100644 index 661291d4ad..0000000000 --- a/tests/test_matmul.py +++ /dev/null @@ -1,202 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import unittest - -import numpy as np -import torch -from parameterized import parameterized - -from monai.transforms.meta_matrix import ( - Grid, - Matrix, - MatrixFactory, - is_grid_shaped, - is_matrix_shaped, - matmul, - matmul_grid_matrix, - matmul_grid_matrix_slow, - matmul_matrix_grid, - matmul_matrix_matrix, -) -from monai.utils import TransformBackends - - -class TestMatmulFunctions(unittest.TestCase): - def test_matrix_grid_and_grid_inv_matrix_are_equivalent(self): - grid = torch.randn((3, 32, 32)) - - matrix = torch.eye(3, 3) - matrix[:, 0] = torch.FloatTensor([0, -1, 0]) - matrix[:, 1] = torch.FloatTensor([1, 0, 0]) - - inv_matrix = torch.inverse(matrix) - - result1 = matmul_matrix_grid(matrix, grid) - result2 = matmul_grid_matrix(grid, inv_matrix) - self.assertTrue(torch.allclose(result1, result2)) - - def test_matmul_grid_matrix_slow(self): - grid = torch.randn((3, 32, 32)) - - matrix = torch.eye(3, 3) - matrix[:, 0] = torch.FloatTensor([0, -1, 0]) - matrix[:, 1] = torch.FloatTensor([1, 0, 0]) - - result1 = matmul_grid_matrix_slow(grid, matrix) - result2 = matmul_grid_matrix(grid, matrix) - self.assertTrue(torch.allclose(result1, result2)) - - MATMUL_TEST_CASES = [ - [np.eye(3, dtype=np.float32), np.eye(3, dtype=np.float32), np.ndarray], - [np.eye(3, dtype=np.float32), torch.eye(3), torch.Tensor], - [np.eye(3, dtype=np.float32), Matrix(torch.eye(3)), Matrix], - [np.eye(3, dtype=np.float32), Grid(torch.randn((3, 8, 8))), Grid], - [torch.eye(3), np.eye(3, dtype=np.float32), torch.Tensor], - [torch.eye(3), torch.eye(3), torch.Tensor], - [torch.eye(3), Matrix(torch.eye(3)), Matrix], - [torch.eye(3), Grid(torch.randn((3, 8, 8))), Grid], - [Matrix(torch.eye(3)), np.eye(3, dtype=np.float32), Matrix], - [Matrix(torch.eye(3)), torch.eye(3), Matrix], - [Matrix(torch.eye(3)), Matrix(torch.eye(3)), Matrix], - [Matrix(torch.eye(3)), Grid(torch.randn((3, 8, 8))), Grid], - [Grid(torch.randn((3, 8, 8))), np.eye(3, dtype=np.float32), Grid], - [Grid(torch.randn((3, 8, 8))), torch.eye(3), Grid], - [Grid(torch.randn((3, 8, 8))), Matrix(torch.eye(3)), Grid], - [Grid(torch.randn((3, 8, 8))), Grid(torch.randn((3, 8, 8))), None], - ] - - def _test_matmul_correct_return_type_impl(self, left, right, expected): - if expected is None: - with self.assertRaises(RuntimeError): - result = matmul(left, right) - else: - result = matmul(left, right) - self.assertIsInstance(result, expected) - - @parameterized.expand(MATMUL_TEST_CASES) - def test_matmul_correct_return_type(self, left, right, expected): - self._test_matmul_correct_return_type_impl(left, right, expected) - - # def test_all_matmul_correct_return_type(self): - # for case in self.MATMUL_TEST_CASES: - # with self.subTest(f"{case}"): - # self._test_matmul_correct_return_type_impl(*case) - - MATRIX_SHAPE_TESTCASES = [ - (torch.randn(2, 2), False), - (torch.randn(3, 3), True), - (torch.randn(4, 4), True), - (torch.randn(5, 5), False), - (torch.randn(3, 4), False), - (torch.randn(4, 3), False), - (torch.randn(3), False), - (torch.randn(4), False), - (torch.randn(5), False), - (torch.randn(3, 3, 3), False), - (torch.randn(4, 4, 4), False), - (torch.randn(5, 5, 5), False), - ] - - def _test_is_matrix_shaped_impl(self, matrix, expected): - self.assertEqual(is_matrix_shaped(matrix), expected) - - @parameterized.expand(MATRIX_SHAPE_TESTCASES) - def test_is_matrix_shaped(self, matrix, expected): - self._test_is_matrix_shaped_impl(matrix, expected) - - # def test_all_is_matrix_shaped(self): - # for case in self.MATRIX_SHAPE_TESTCASES: - # with 
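Deleting `tests/test_matmul.py` drops the composition expectations along with the `Matrix`/`Grid` machinery, but the 2D cases can still be checked against the new `matmul`; a sketch reproducing one of the deleted expectations (translate composed with rotate-90):

import torch

from monai.transforms.meta_matrix import matmul

translate = torch.tensor([[1.0, 0.0, 8.0], [0.0, 1.0, 8.0], [0.0, 0.0, 1.0]])
rotate90 = torch.tensor([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])

expected = torch.tensor([[0.0, -1.0, 8.0], [1.0, 0.0, 8.0], [0.0, 0.0, 1.0]])
assert torch.allclose(matmul(translate, rotate90), expected)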
-    #         with self.subTest(f"{case[0].shape}"):
-    #             self._test_is_matrix_shaped_impl(*case)
-
-    GRID_SHAPE_TESTCASES = [
-        (torch.randn(1, 16, 32), False),
-        (torch.randn(2, 16, 32), False),
-        (torch.randn(3, 16, 32), True),
-        (torch.randn(4, 16, 32), False),
-        (torch.randn(5, 16, 32), False),
-        (torch.randn(3, 16, 32, 64), False),
-        (torch.randn(4, 16, 32, 64), True),
-        (torch.randn(5, 16, 32, 64), False),
-    ]
-
-    def _test_is_grid_shaped_impl(self, grid, expected):
-        self.assertEqual(is_grid_shaped(grid), expected)
-
-    @parameterized.expand(GRID_SHAPE_TESTCASES)
-    def test_is_grid_shaped(self, grid, expected):
-        self._test_is_grid_shaped_impl(grid, expected)
-
-    # def test_all_is_grid_shaped(self):
-    #     for case in self.GRID_SHAPE_TESTCASES:
-    #         with self.subTest(f"{case[0].shape}"):
-    #             self._test_is_grid_shaped_impl(*case)
-
-
-def get_matmul_2d_test_cases():
-    f = MatrixFactory(2, TransformBackends.TORCH, "cpu")
-    cases = [
-        (
-            f.rotate_euler(np.pi / 4),
-            f.scale((0.5, 0.5)),
-            torch.FloatTensor([[0.35355339, -0.35355339, 0], [0.35355339, 0.35355339, 0], [0, 0, 1]]),
-        ),
-        (
-            f.scale((0.5, 0.5)),
-            f.rotate_euler(np.pi / 4),
-            torch.FloatTensor([[0.35355339, -0.35355339, 0], [0.35355339, 0.35355339, 0], [0, 0, 1]]),
-        ),
-        (f.translate((8, 8)), f.rotate_euler(np.pi / 2), torch.FloatTensor([[0, -1, 8], [1, 0, 8], [0, 0, 1]])),
-        (f.rotate_euler(np.pi / 2), f.translate((8, 8)), torch.FloatTensor([[0, -1, -8], [1, 0, 8], [0, 0, 1]])),
-    ]
-
-    return cases
-
-
-MATMUL_2D_TEST_CASES = get_matmul_2d_test_cases()
-
-
-class TestMatmulOutputs(unittest.TestCase):
-    def _test_matmul_outputs_impl(self, left, right, expected):
-        actual = matmul(left, right)
-        self.assertTrue(
-            torch.allclose(actual.matrix.data, expected, atol=1e-7),
-            msg=f"{actual.matrix.data} is not close to {expected}",
-        )
-
-    @parameterized.expand(MATMUL_2D_TEST_CASES)
-    def test_matmul_outputs(self, left, right, expected):
-        self._test_matmul_outputs_impl(left, right, expected)
-
-    def test_all_matmul_outputs(self):
-        cases = MATMUL_2D_TEST_CASES
-        for case in cases:
-            self._test_matmul_outputs_impl(*case)
-
-
-class TestMatrixMatrixOutputs(unittest.TestCase):
-    def _test_matrix_matrix_outputs_impl(self, left, right, expected):
-        actual = matmul_matrix_matrix(left.matrix.data, right.matrix.data)
-        self.assertTrue(torch.allclose(actual, expected, atol=1e-7), msg=f"{actual} is not close to {expected}")
-
-    @parameterized.expand(MATMUL_2D_TEST_CASES)
-    def test_matrix_matrix_outputs(self, left, right, expected):
-        self._test_matrix_matrix_outputs_impl(left, right, expected)
-
-    def test_all_matrix_matrix_outputs(self):
-        cases = MATMUL_2D_TEST_CASES
-        for case in cases:
-            self._test_matrix_matrix_outputs_impl(*case)
-
-
-if __name__ == "__main__":
-    unittest.main()

From 6c4db356f9cb5071e745214c5da084e99bf4dcd8 Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Fri, 18 Nov 2022 13:40:12 +0000
Subject: [PATCH 14/19] rearrange modules, adds back matrix/grid types

Signed-off-by: Wenqi Li
---
 monai/transforms/__init__.py           |   2 +-
 monai/transforms/lazy/functional.py    |  31 +------
 monai/transforms/lazy/utils.py         | 119 +++++++++++++++++++++++++
 monai/transforms/meta_matrix.py        |  30 -------
 monai/transforms/utility/functional.py |  40 ---------
 monai/transforms/utils.py              |  39 ++++----
 tests/test_resample.py                 |   2 +-
 7 files changed, 141 insertions(+), 122 deletions(-)
 create mode 100644 monai/transforms/lazy/utils.py
 delete mode 100644 monai/transforms/meta_matrix.py
 delete mode 100644 monai/transforms/utility/functional.py

diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py
index eb0fd897aa..196b137702 100644
--- a/monai/transforms/__init__.py
+++ b/monai/transforms/__init__.py
@@ -228,7 +228,7 @@
 from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
 from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
 from .lazy.functional import apply
-from .meta_matrix import matmul
+from .lazy.utils import matmul, resample
 from .meta_utility.dictionary import (
     FromMetaTensord,
     FromMetaTensorD,
diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py
index 733d7961fc..ccf2787749 100644
--- a/monai/transforms/lazy/functional.py
+++ b/monai/transforms/lazy/functional.py
@@ -11,44 +11,15 @@
 
 from typing import Optional, Union
 
-import numpy as np
 import torch
 
 from monai.data.meta_tensor import MetaTensor
 from monai.data.utils import to_affine_nd
-from monai.transforms.meta_matrix import matmul
-from monai.transforms.utility.functional import resample
-from monai.utils import LazyAttr
+from monai.transforms.lazy.utils import is_compatible_kwargs, kwargs_from_pending, mat_from_pending, matmul, resample
 
 __all__ = ["apply"]
 
 
-def mat_from_pending(pending_item):
-    if isinstance(pending_item, (torch.Tensor, np.ndarray)):
-        return pending_item
-    if isinstance(pending_item, dict):
-        return pending_item[LazyAttr.AFFINE]
-    return pending_item
-
-
-def kwargs_from_pending(pending_item):
-    if not isinstance(pending_item, dict):
-        return {}
-    ret = {
-        LazyAttr.INTERP_MODE: pending_item.get(LazyAttr.INTERP_MODE, None),  # interpolation mode
-        LazyAttr.PADDING_MODE: pending_item.get(LazyAttr.PADDING_MODE, None),  # padding mode
-    }
-    if LazyAttr.SHAPE in pending_item:
-        ret[LazyAttr.SHAPE] = pending_item[LazyAttr.SHAPE]
-    if LazyAttr.DTYPE in pending_item:
-        ret[LazyAttr.DTYPE] = pending_item[LazyAttr.DTYPE]
-    return ret
-
-
-def is_compatible_kwargs(kwargs_1, kwargs_2):
-    return True
-
-
 def apply(data: Union[torch.Tensor, MetaTensor], pending: Optional[list] = None):
diff --git a/monai/transforms/lazy/utils.py b/monai/transforms/lazy/utils.py
new file mode 100644
index 0000000000..0578dda94a
--- /dev/null
+++ b/monai/transforms/lazy/utils.py
@@ -0,0 +1,119 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import numpy as np
+import torch
+
+import monai
+from monai.config import NdarrayOrTensor
+from monai.utils import LazyAttr
+
+__all__ = ["resample", "matmul"]
+
+
+class Affine:
+    """A class to represent an affine transform matrix."""
+
+    __slots__ = ("data",)
+
+    def __init__(self, data):
+        self.data = data
+
+    @staticmethod
+    def is_affine_shaped(data):
+        """Check if the data is an affine matrix."""
+        if isinstance(data, Affine):
+            return True
+        if isinstance(data, DDF):
+            return False
+        if not hasattr(data, "shape") or len(data.shape) < 2:
+            return False
+        return data.shape[-1] in (3, 4) and data.shape[-2] in (3, 4) and data.shape[-1] == data.shape[-2]
+
+
+class DDF:
+    """A class to represent a dense displacement field."""
+
+    __slots__ = ("data",)
+
+    def __init__(self, data):
+        self.data = data
+
+    @staticmethod
+    def is_ddf_shaped(data):
+        """Check if the data is a DDF."""
+        if isinstance(data, DDF):
+            return True
+        if isinstance(data, Affine):
+            return False
+        if not hasattr(data, "shape") or len(data.shape) < 3:
+            return False
+        return not Affine.is_affine_shaped(data)
+
+
+def matmul(left: torch.Tensor, right: torch.Tensor):
+    if Affine.is_affine_shaped(left) and Affine.is_affine_shaped(right):  # linear transforms
+        if isinstance(left, Affine):
+            left = left.data
+        if isinstance(right, Affine):
+            right = right.data
+        return torch.matmul(left, right)
+    if DDF.is_ddf_shaped(left) and DDF.is_ddf_shaped(right):  # adds DDFs
+        return left + right
+    raise NotImplementedError
+
+
+def mat_from_pending(pending_item):
+    if isinstance(pending_item, (torch.Tensor, np.ndarray)):
+        return pending_item
+    if isinstance(pending_item, dict):
+        return pending_item[LazyAttr.AFFINE]
+    return pending_item
+
+
+def kwargs_from_pending(pending_item):
+    if not isinstance(pending_item, dict):
+        return {}
+    ret = {
+        LazyAttr.INTERP_MODE: pending_item.get(LazyAttr.INTERP_MODE, None),  # interpolation mode
+        LazyAttr.PADDING_MODE: pending_item.get(LazyAttr.PADDING_MODE, None),  # padding mode
+    }
+    if LazyAttr.SHAPE in pending_item:
+        ret[LazyAttr.SHAPE] = pending_item[LazyAttr.SHAPE]
+    if LazyAttr.DTYPE in pending_item:
+        ret[LazyAttr.DTYPE] = pending_item[LazyAttr.DTYPE]
+    return ret
+
+
+def is_compatible_kwargs(kwargs_1, kwargs_2):
+    return True
+
+
+def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: Optional[dict] = None):
+    """
+    This is a minimal implementation of resample that always uses Affine.
+    """
+    if not Affine.is_affine_shaped(matrix):
+        raise NotImplementedError("calling dense grid resample API not implemented")
+    kwargs = {} if kwargs is None else kwargs
+    init_kwargs = {
+        "spatial_size": kwargs.pop(LazyAttr.SHAPE, data.shape)[1:],
+        "dtype": kwargs.pop(LazyAttr.DTYPE, data.dtype),
+    }
+    call_kwargs = {
+        "mode": kwargs.pop(LazyAttr.INTERP_MODE, None),
+        "padding_mode": kwargs.pop(LazyAttr.PADDING_MODE, None),
+    }
+    resampler = monai.transforms.Affine(affine=matrix, image_only=True, **init_kwargs)
+    with resampler.trace_transform(False):  # don't track this transform in `data`
+        return resampler(img=data, **call_kwargs)
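The `Affine`/`DDF` wrappers in the new `lazy/utils.py` exist mainly to disambiguate 3x3/4x4 tensors (linear transforms, combined by matrix product) from dense displacement fields (combined by addition); note this `Affine` is a lightweight tag type, distinct from the `monai.transforms.Affine` resampling transform. A small sketch of the dispatch, assuming the module state as of this patch (the DDF branch is exercised with raw tensors here, since unwrapping `DDF.data` only lands in the later commits):

import torch

from monai.transforms.lazy.utils import Affine, DDF, matmul

m1, m2 = torch.eye(3), torch.eye(3)
combined = matmul(Affine(m1), Affine(m2))  # affine-shaped -> matrix product
assert torch.allclose(combined, torch.eye(3))

ddf1, ddf2 = torch.zeros(2, 8, 8), torch.zeros(2, 8, 8)
summed = matmul(ddf1, ddf2)  # ddf-shaped -> elementwise sum
assert summed.shape == (2, 8, 8)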
diff --git a/monai/transforms/meta_matrix.py b/monai/transforms/meta_matrix.py
deleted file mode 100644
index 0bad720294..0000000000
--- a/monai/transforms/meta_matrix.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# Copyright (c) MONAI Consortium
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import torch
-
-from monai.config import NdarrayOrTensor
-
-__all__ = ["is_affine_shaped", "matmul"]
-
-
-def is_affine_shaped(data):
-    """Check if the data is a square matrix for the last two dimensions."""
-    if not hasattr(data, "shape") or len(data.shape) < 2:
-        return False
-    return data.shape[-1] in (3, 4) and data.shape[-2] in (3, 4) and data.shape[-1] == data.shape[-2]
-
-
-def matmul(left: NdarrayOrTensor, right: NdarrayOrTensor):
-    if is_affine_shaped(left) and is_affine_shaped(right):
-        return torch.matmul(left, right)
-    raise NotImplementedError
diff --git a/monai/transforms/utility/functional.py b/monai/transforms/utility/functional.py
deleted file mode 100644
index 00045d25e9..0000000000
--- a/monai/transforms/utility/functional.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright (c) MONAI Consortium
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import Optional
-
-import torch
-
-import monai
-from monai.config import NdarrayOrTensor
-from monai.transforms.meta_matrix import is_affine_shaped
-from monai.utils import LazyAttr
-
-__all__ = ["resample"]
-
-
-def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: Optional[dict] = None):
-    """
-    This is a minimal implementation of resample that always uses Affine.
-    """
-    if not is_affine_shaped(matrix):
-        raise NotImplementedError("calling dense grid resample API not implemented")
-    kwargs = {} if kwargs is None else kwargs
-    init_kwargs = {
-        "spatial_size": kwargs.pop(LazyAttr.SHAPE, data.shape)[1:],
-        "dtype": kwargs.pop(LazyAttr.DTYPE, data.dtype),
-    }
-    call_kwargs = {
-        "mode": kwargs.pop(LazyAttr.INTERP_MODE, None),
-        "padding_mode": kwargs.pop(LazyAttr.PADDING_MODE, None),
-    }
-    resampler = monai.transforms.Affine(affine=matrix, image_only=True, **init_kwargs)
-    with resampler.trace_transform(False):  # don't track this transform in `data`
-        return resampler(img=data, **call_kwargs)
diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py
index e8a9791999..b2e693428e 100644
--- a/monai/transforms/utils.py
+++ b/monai/transforms/utils.py
@@ -890,53 +890,52 @@ def _create_translate(
 
 
 def _create_rotate_90(
-    spatial_dims: int, axis: Tuple[int, int], steps: Optional[int] = 1, eye_func: Callable = np.eye
+    spatial_dims: int, axes: Tuple[int, int], steps: int = 1, eye_func: Callable = np.eye
 ) -> NdarrayOrTensor:
 
     values = [(1, 0, 0, 1), (0, -1, 1, 0), (-1, 0, 0, -1), (0, 1, -1, 0)]
 
     if spatial_dims == 2:
-        if axis != (0, 1):
-            raise ValueError(f"if 'spatial_dims' is 2, 'axis' must be (0, 1) but is {axis}")
+        if axes != (0, 1):
+            raise ValueError(f"if 'spatial_dims' is 2, 'axis' must be (0, 1) but is {axes}")
     elif spatial_dims == 3:
-        if axis not in ((0, 1), (0, 2), (1, 2)):
-            raise ValueError("if 'spatial_dims' is 3, 'axis' must be (0,1), (0, 2), or (1, 2) " f"but is {axis}")
+        if axes not in ((0, 1), (0, 2), (1, 2)):
+            raise ValueError("if 'spatial_dims' is 3, 'axis' must be (0,1), (0, 2), or (1, 2) " f"but is {axes}")
     else:
         raise ValueError(f"'spatial_dims' must be 2 or 3 but is {spatial_dims}")
 
     affine = eye_func(spatial_dims + 1)
 
-    a, b = (0, 1) if spatial_dims == 2 else axis
+    a, b = (0, 1) if spatial_dims == 2 else axes
     affine[a, a], affine[a, b], affine[b, a], affine[b, b] = values[steps % 4]
-    return affine
+    return affine  # type: ignore
 
 
 def create_rotate_90(
     spatial_dims: int,
-    axis: int,
-    steps: Optional[int] = 1,
+    axes: Tuple[int, int] = (0, 1),
+    steps: int = 1,
     device: Optional[torch.device] = None,
     backend: str = TransformBackends.NUMPY,
 ) -> NdarrayOrTensor:
     """
-    create a 2D or 3D rotation matrix
+    create a 2D or 3D rotation90 matrix.
+
     Args:
         spatial_dims: {``2``, ``3``} spatial rank
-        radians: rotation radians
-        when spatial_dims == 3, the `radians` sequence corresponds to
-        rotation in the 1st, 2nd, and 3rd dim respectively.
+        axes: 2 int numbers, defines the plane to rotate with 2 spatial axes.
+            Default: (0, 1), this is the first two axis in spatial dimensions.
+            If axis is negative it counts from the last to the first axis.
+        steps: number of times to rotate by 90 degrees
         device: device to compute and store the output (when the backend is "torch").
         backend: APIs to use, ``numpy`` or ``torch``.
-
-    Raises:
-        ValueError: When ``radians`` is empty.
-        ValueError: When ``spatial_dims`` is not one of [2, 3].
     """
     _backend = look_up_option(backend, TransformBackends)
     if _backend == TransformBackends.NUMPY:
-        return _create_rotate_90(spatial_dims=spatial_dims, axis=axis, steps=steps, eye_func=np.eye)
+        return _create_rotate_90(spatial_dims=spatial_dims, axes=axes, steps=steps, eye_func=np.eye)
     if _backend == TransformBackends.TORCH:
         return _create_rotate_90(
-            spatial_dims=spatial_dims, axis=axis, steps=steps, eye_func=lambda rank: torch.eye(rank, device=device)
+            spatial_dims=spatial_dims, axes=axes, steps=steps, eye_func=lambda rank: torch.eye(rank, device=device)
         )
     raise ValueError(f"backend {backend} is not supported")
 
@@ -974,9 +973,9 @@ def create_flip(
 ) -> NdarrayOrTensor:
     _backend = look_up_option(backend, TransformBackends)
     if _backend == TransformBackends.NUMPY:
-        return _create_flip(spatial_dims=spatial_dims, spatial_axis=spatial_axis, eye_func=np.eye)
+        return _create_flip(spatial_dims=spatial_dims, spatial_axis=spatial_axis, eye_func=np.eye)  # type: ignore
     if _backend == TransformBackends.TORCH:
-        return _create_flip(
+        return _create_flip(  # type: ignore
             spatial_dims=spatial_dims, spatial_axis=spatial_axis, eye_func=lambda rank: torch.eye(rank, device=device)
         )
     raise ValueError(f"backend {backend} is not supported")
diff --git a/tests/test_resample.py b/tests/test_resample.py
index d427be12d7..aa48cee13b 100644
--- a/tests/test_resample.py
+++ b/tests/test_resample.py
@@ -13,7 +13,7 @@
 
 import torch
 
-from monai.transforms.utility.functional import resample
+from monai.transforms.lazy.functional import resample
 from monai.utils import convert_to_tensor
 from tests.utils import get_arange_img

From 9eb52791d1500e1ecb3b66177329db21ef3becfe Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Tue, 22 Nov 2022 15:28:47 +0000
Subject: [PATCH 15/19] matmul -> combine_transforms

Signed-off-by: Wenqi Li
---
 monai/transforms/__init__.py        |  2 +-
 monai/transforms/lazy/functional.py | 10 ++++++++--
 monai/transforms/lazy/utils.py      |  9 +++++++--
 3 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py
index 196b137702..29884da00a 100644
--- a/monai/transforms/__init__.py
+++ b/monai/transforms/__init__.py
@@ -228,7 +228,7 @@
 from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
 from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
 from .lazy.functional import apply
-from .lazy.utils import matmul, resample
+from .lazy.utils import combine_transforms, resample
 from .meta_utility.dictionary import (
     FromMetaTensord,
     FromMetaTensorD,
diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py
index ccf2787749..97a18cc08d 100644
--- a/monai/transforms/lazy/functional.py
+++ b/monai/transforms/lazy/functional.py
@@ -15,7 +15,13 @@
 
 from monai.data.meta_tensor import MetaTensor
 from monai.data.utils import to_affine_nd
-from monai.transforms.lazy.utils import is_compatible_kwargs, kwargs_from_pending, mat_from_pending, matmul, resample
+from monai.transforms.lazy.utils import (
+    combine_transforms,
+    is_compatible_kwargs,
+    kwargs_from_pending,
+    mat_from_pending,
+    resample,
+)
 
 __all__ = ["apply"]
 
@@ -44,7 +50,7 @@ def apply(data: Union[torch.Tensor, MetaTensor], pending: Optional[list] = None)
             # carry out an intermediate resample here due to incompatibility between arguments
             data = resample(data, cumulative_xform, cur_kwargs)
         next_matrix = mat_from_pending(p)
-        cumulative_xform = matmul(cumulative_xform, next_matrix)
+        cumulative_xform = combine_transforms(cumulative_xform, next_matrix)
         cur_kwargs.update(new_kwargs)
     data = resample(data, cumulative_xform, cur_kwargs)
     if isinstance(data, MetaTensor):
diff --git a/monai/transforms/lazy/utils.py b/monai/transforms/lazy/utils.py
index 0578dda94a..5f30ba840c 100644
--- a/monai/transforms/lazy/utils.py
+++ b/monai/transforms/lazy/utils.py
@@ -18,7 +18,7 @@
 from monai.config import NdarrayOrTensor
 from monai.utils import LazyAttr
 
-__all__ = ["resample", "matmul"]
+__all__ = ["resample", "combine_transforms"]
 
 
 class Affine:
@@ -61,7 +61,8 @@ def is_ddf_shaped(data):
         return not Affine.is_affine_shaped(data)
 
 
-def matmul(left: torch.Tensor, right: torch.Tensor):
+def combine_transforms(left: torch.Tensor, right: torch.Tensor):
+    """Given transforms A and B to be applied to x, return the combined transform (AB), so that A(B(x)) becomes AB(x)"""
     if Affine.is_affine_shaped(left) and Affine.is_affine_shaped(right):  # linear transforms
         if isinstance(left, Affine):
             left = left.data
@@ -69,6 +70,10 @@ def matmul(left: torch.Tensor, right: torch.Tensor):
             right = right.data
         return torch.matmul(left, right)
     if DDF.is_ddf_shaped(left) and DDF.is_ddf_shaped(right):  # adds DDFs
+        if isinstance(left, DDF):
+            left = left.data
+        if isinstance(right, DDF):
+            right = right.data
         return left + right
     raise NotImplementedError
From 17574cb79d00a9c55b3d4dbb5cbbad579ec1f810 Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Wed, 23 Nov 2022 11:10:36 +0000
Subject: [PATCH 16/19] remove unused

Signed-off-by: Wenqi Li
---
 monai/transforms/utils.py | 92 ---------------------------------------
 1 file changed, 92 deletions(-)

diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py
index b2e693428e..e96d906f20 100644
--- a/monai/transforms/utils.py
+++ b/monai/transforms/utils.py
@@ -889,98 +889,6 @@ def _create_translate(
     return array_func(affine)  # type: ignore
 
 
-def _create_rotate_90(
-    spatial_dims: int, axes: Tuple[int, int], steps: int = 1, eye_func: Callable = np.eye
-) -> NdarrayOrTensor:
-
-    values = [(1, 0, 0, 1), (0, -1, 1, 0), (-1, 0, 0, -1), (0, 1, -1, 0)]
-
-    if spatial_dims == 2:
-        if axes != (0, 1):
-            raise ValueError(f"if 'spatial_dims' is 2, 'axis' must be (0, 1) but is {axes}")
-    elif spatial_dims == 3:
-        if axes not in ((0, 1), (0, 2), (1, 2)):
-            raise ValueError("if 'spatial_dims' is 3, 'axis' must be (0,1), (0, 2), or (1, 2) " f"but is {axes}")
-    else:
-        raise ValueError(f"'spatial_dims' must be 2 or 3 but is {spatial_dims}")
-
-    affine = eye_func(spatial_dims + 1)
-
-    a, b = (0, 1) if spatial_dims == 2 else axes
-    affine[a, a], affine[a, b], affine[b, a], affine[b, b] = values[steps % 4]
-    return affine  # type: ignore
-
-
-def create_rotate_90(
-    spatial_dims: int,
-    axes: Tuple[int, int] = (0, 1),
-    steps: int = 1,
-    device: Optional[torch.device] = None,
-    backend: str = TransformBackends.NUMPY,
-) -> NdarrayOrTensor:
-    """
-    create a 2D or 3D rotation90 matrix.
-
-    Args:
-        spatial_dims: {``2``, ``3``} spatial rank
-        axes: 2 int numbers, defines the plane to rotate with 2 spatial axes.
-            Default: (0, 1), this is the first two axis in spatial dimensions.
-            If axis is negative it counts from the last to the first axis.
-        steps: number of times to rotate by 90 degrees
-        device: device to compute and store the output (when the backend is "torch").
-        backend: APIs to use, ``numpy`` or ``torch``.
-    """
-    _backend = look_up_option(backend, TransformBackends)
-    if _backend == TransformBackends.NUMPY:
-        return _create_rotate_90(spatial_dims=spatial_dims, axes=axes, steps=steps, eye_func=np.eye)
-    if _backend == TransformBackends.TORCH:
-        return _create_rotate_90(
-            spatial_dims=spatial_dims, axes=axes, steps=steps, eye_func=lambda rank: torch.eye(rank, device=device)
-        )
-    raise ValueError(f"backend {backend} is not supported")
-
-
-def _create_flip(spatial_dims: int, spatial_axis: Union[Sequence[int], int], eye_func: Callable = np.eye):
-    affine = eye_func(spatial_dims + 1)
-    if isinstance(spatial_axis, int):
-        if spatial_axis < -spatial_dims or spatial_axis >= spatial_dims:
-            raise ValueError(
-                "'spatial_axis' values must be between "
-                f"{-spatial_dims} and {spatial_dims-1} inclusive "
-                f"('spatial_axis' is {spatial_axis})"
-            )
-        affine[spatial_axis, spatial_axis] = -1
-    else:
-        if any((s < -spatial_dims or s >= spatial_dims) for s in spatial_axis):
-            raise ValueError(
-                "'spatial_axis' values must be between "
-                f"{-spatial_dims} and {spatial_dims-1} inclusive "
-                f"('spatial_axis' is {spatial_axis})"
-            )
-
-        for i in range(spatial_dims):
-            if i in spatial_axis:
-                affine[i, i] = -1
-
-    return affine
-
-
-def create_flip(
-    spatial_dims: int,
-    spatial_axis: Union[Sequence[int], int],
-    device: Optional[torch.device] = None,
-    backend: str = TransformBackends.NUMPY,
-) -> NdarrayOrTensor:
-    _backend = look_up_option(backend, TransformBackends)
-    if _backend == TransformBackends.NUMPY:
-        return _create_flip(spatial_dims=spatial_dims, spatial_axis=spatial_axis, eye_func=np.eye)  # type: ignore
-    if _backend == TransformBackends.TORCH:
-        return _create_flip(  # type: ignore
-            spatial_dims=spatial_dims, spatial_axis=spatial_axis, eye_func=lambda rank: torch.eye(rank, device=device)
-        )
-    raise ValueError(f"backend {backend} is not supported")
-
-
 def generate_spatial_bounding_box(
     img: NdarrayOrTensor,
     select_fn: Callable = is_positive,

From e067b6e96027d486c586492efa49a6f88d360e34 Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Wed, 23 Nov 2022 11:26:59 +0000
Subject: [PATCH 17/19] adds docstrings

Signed-off-by: Wenqi Li
---
 monai/transforms/lazy/functional.py | 14 +++++++-------
 monai/transforms/lazy/utils.py      |  7 +++++--
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py
index 97a18cc08d..abb4413836 100644
--- a/monai/transforms/lazy/functional.py
+++ b/monai/transforms/lazy/functional.py
@@ -17,9 +17,9 @@
 from monai.data.utils import to_affine_nd
 from monai.transforms.lazy.utils import (
     combine_transforms,
-    is_compatible_kwargs,
+    is_compatible_apply_kwargs,
     kwargs_from_pending,
-    mat_from_pending,
+    affine_from_pending,
     resample,
 )
 
@@ -28,10 +28,10 @@
 
 def apply(data: Union[torch.Tensor, MetaTensor], pending: Optional[list] = None):
     """
-    This method applies pending transforms to tensors.
+    This method applies pending transforms to `data` tensors.
 
     Args:
-        data: A torch Tensor, monai MetaTensor
+        data: A torch Tensor or a monai MetaTensor.
         pending: pending transforms. This must be set if data is a Tensor, but is optional if data is a MetaTensor.
     """
     if isinstance(data, MetaTensor) and pending is None:
@@ -41,15 +41,15 @@ def apply(data: Union[torch.Tensor, MetaTensor], pending: Optional[list] = None)
     if not pending:
         return data
 
-    cumulative_xform = mat_from_pending(pending[0])
+    cumulative_xform = affine_from_pending(pending[0])
     cur_kwargs = kwargs_from_pending(pending[0])
 
     for p in pending[1:]:
         new_kwargs = kwargs_from_pending(p)
-        if not is_compatible_kwargs(cur_kwargs, new_kwargs):
+        if not is_compatible_apply_kwargs(cur_kwargs, new_kwargs):
             # carry out an intermediate resample here due to incompatibility between arguments
             data = resample(data, cumulative_xform, cur_kwargs)
-        next_matrix = mat_from_pending(p)
+        next_matrix = affine_from_pending(p)
         cumulative_xform = combine_transforms(cumulative_xform, next_matrix)
         cur_kwargs.update(new_kwargs)
     data = resample(data, cumulative_xform, cur_kwargs)
diff --git a/monai/transforms/lazy/utils.py b/monai/transforms/lazy/utils.py
index 5f30ba840c..7bb01c8a6d 100644
--- a/monai/transforms/lazy/utils.py
+++ b/monai/transforms/lazy/utils.py
@@ -78,7 +78,8 @@ def combine_transforms(left: torch.Tensor, right: torch.Tensor):
     raise NotImplementedError
 
 
-def mat_from_pending(pending_item):
+def affine_from_pending(pending_item):
+    """Extract the affine matrix from a pending transform item."""
     if isinstance(pending_item, (torch.Tensor, np.ndarray)):
         return pending_item
     if isinstance(pending_item, dict):
@@ -87,6 +88,7 @@ def mat_from_pending(pending_item):
 
 
 def kwargs_from_pending(pending_item):
+    """Extract kwargs from a pending transform item."""
     if not isinstance(pending_item, dict):
         return {}
     ret = {
@@ -100,7 +102,8 @@ def kwargs_from_pending(pending_item):
     return ret
 
 
-def is_compatible_kwargs(kwargs_1, kwargs_2):
+def is_compatible_apply_kwargs(kwargs_1, kwargs_2):
+    """Check if two sets of kwargs are compatible (to be combined in `apply`)."""
     return True
From c5aa26cb7b9f32eb9ad130991cdddf451a2cb08c Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Wed, 23 Nov 2022 13:08:37 +0000
Subject: [PATCH 18/19] adds testing

Signed-off-by: Wenqi Li
---
 monai/transforms/lazy/functional.py |  2 +-
 monai/transforms/lazy/utils.py      | 18 +++++++-----------
 tests/test_apply.py                 | 14 +++++++++-----
 tests/test_resample.py              | 16 ++++++++--------
 4 files changed, 25 insertions(+), 25 deletions(-)

diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py
index abb4413836..0536bbe85b 100644
--- a/monai/transforms/lazy/functional.py
+++ b/monai/transforms/lazy/functional.py
@@ -16,10 +16,10 @@
 from monai.data.meta_tensor import MetaTensor
 from monai.data.utils import to_affine_nd
 from monai.transforms.lazy.utils import (
+    affine_from_pending,
     combine_transforms,
     is_compatible_apply_kwargs,
     kwargs_from_pending,
-    affine_from_pending,
     resample,
 )
 
diff --git a/monai/transforms/lazy/utils.py b/monai/transforms/lazy/utils.py
index 7bb01c8a6d..155c39de1c 100644
--- a/monai/transforms/lazy/utils.py
+++ b/monai/transforms/lazy/utils.py
@@ -16,7 +16,7 @@
 import monai
 from monai.config import NdarrayOrTensor
-from monai.utils import LazyAttr
+from monai.utils import LazyAttr, convert_to_tensor
 
 __all__ = ["resample", "combine_transforms"]
 
@@ -61,19 +61,15 @@ def is_ddf_shaped(data):
         return not Affine.is_affine_shaped(data)
 
 
-def combine_transforms(left: torch.Tensor, right: torch.Tensor):
+def combine_transforms(left: torch.Tensor, right: torch.Tensor) -> torch.Tensor:
     """Given transforms A and B to be applied to x, return the combined transform (AB), so that A(B(x)) becomes AB(x)"""
     if Affine.is_affine_shaped(left) and Affine.is_affine_shaped(right):  # linear transforms
-        if isinstance(left, Affine):
-            left = left.data
-        if isinstance(right, Affine):
-            right = right.data
+        left = convert_to_tensor(left.data if isinstance(left, Affine) else left, wrap_sequence=True)
+        right = convert_to_tensor(right.data if isinstance(right, Affine) else right, wrap_sequence=True)
         return torch.matmul(left, right)
-    if DDF.is_ddf_shaped(left) and DDF.is_ddf_shaped(right):  # adds DDFs
-        if isinstance(left, DDF):
-            left = left.data
-        if isinstance(right, DDF):
-            right = right.data
+    if DDF.is_ddf_shaped(left) and DDF.is_ddf_shaped(right):  # adds DDFs, do we need metadata if metatensor input?
+        left = convert_to_tensor(left.data if isinstance(left, DDF) else left, wrap_sequence=True)
+        right = convert_to_tensor(right.data if isinstance(right, DDF) else right, wrap_sequence=True)
         return left + right
     raise NotImplementedError
diff --git a/tests/test_apply.py b/tests/test_apply.py
index 45abb00326..afb29ad576 100644
--- a/tests/test_apply.py
+++ b/tests/test_apply.py
@@ -17,15 +17,20 @@
 from monai.transforms.lazy.functional import apply
 from monai.transforms.utils import create_rotate
 from monai.utils import LazyAttr, convert_to_tensor
+from tests.utils import get_arange_img
 
 
 def single_2d_transform_cases():
     return [
-        (torch.randn((1, 32, 32)), [{LazyAttr.AFFINE: create_rotate(2, np.pi / 4)}], (1, 32, 32)),
-        (torch.randn((1, 32, 32)), [create_rotate(2, np.pi / 4)], (1, 32, 32)),
         (
-            torch.randn((1, 16, 16)),
-            [{LazyAttr.AFFINE: create_rotate(2, np.pi / 4), LazyAttr.SHAPE: (1, 45, 45)}],
+            torch.as_tensor(get_arange_img((32, 32))),
+            [{LazyAttr.AFFINE: create_rotate(2, np.pi / 4)}, {LazyAttr.AFFINE: create_rotate(2, -np.pi / 4)}],
+            (1, 32, 32),
+        ),
+        (torch.as_tensor(get_arange_img((32, 32))), [create_rotate(2, np.pi / 2)], (1, 32, 32)),
+        (
+            torch.as_tensor(get_arange_img((16, 16))),
+            [{LazyAttr.AFFINE: create_rotate(2, np.pi / 2), LazyAttr.SHAPE: (1, 45, 45)}],
             (1, 45, 45),
         ),
     ]
@@ -45,7 +50,6 @@ def _test_apply_metatensor_impl(self, tensor, pending_transforms, expected_shape
         for p in pending_transforms:
             tensor_.push_pending_operation(p)
         result, transforms = apply(tensor_)
-        self.assertEqual(result.shape, expected_shape)
 
 
 SINGLE_TRANSFORM_CASES = single_2d_transform_cases()
diff --git a/tests/test_resample.py b/tests/test_resample.py
index aa48cee13b..0136552334 100644
--- a/tests/test_resample.py
+++ b/tests/test_resample.py
@@ -12,10 +12,11 @@
 import unittest
 
 import torch
+from parameterized import parameterized
 
 from monai.transforms.lazy.functional import resample
 from monai.utils import convert_to_tensor
-from tests.utils import get_arange_img
+from tests.utils import assert_allclose, get_arange_img
 
 
 def rotate_90_2d():
@@ -25,15 +26,14 @@ def rotate_90_2d():
     return t
 
 
-class TestResampleFunction(unittest.TestCase):
-    def _test_resample_function_impl(self, img, matrix):
-        resample(convert_to_tensor(img), matrix)
+RESAMPLE_FUNCTION_CASES = [(get_arange_img((3, 3)), rotate_90_2d(), [[2, 5, 8], [1, 4, 7], [0, 3, 6]])]
 
-    RESAMPLE_FUNCTION_CASES = [(get_arange_img((1, 16, 16)), rotate_90_2d())]
 
-    def test_resample_function(self):
-        for case in self.RESAMPLE_FUNCTION_CASES:
-            self._test_resample_function_impl(*case)
+class TestResampleFunction(unittest.TestCase):
+    @parameterized.expand(RESAMPLE_FUNCTION_CASES)
+    def test_resample_function_impl(self, img, matrix, expected):
+        out = resample(convert_to_tensor(img), matrix)
+        assert_allclose(out[0], expected, type_test=False)
 
 
 if __name__ == "__main__":
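The expected matrix in `RESAMPLE_FUNCTION_CASES` is simply the 3x3 arange image turned one quarter turn counter-clockwise; a numpy cross-check of the fixture (assuming `get_arange_img` fills each channel with consecutive `arange` values):

    import numpy as np

    img = np.arange(9).reshape(3, 3)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    assert (np.rot90(img) == [[2, 5, 8], [1, 4, 7], [0, 3, 6]]).all()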
"__main__": From 3c2d1ec20aae7cdea8dabf1b71e91d577924772a Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 23 Nov 2022 16:05:35 +0000 Subject: [PATCH 19/19] DDF -> DisplacementField Signed-off-by: Wenqi Li --- monai/networks/blocks/warp.py | 6 +++--- monai/transforms/lazy/utils.py | 14 ++++++++------ 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py index 7a28e86301..66e39b18f2 100644 --- a/monai/networks/blocks/warp.py +++ b/monai/networks/blocks/warp.py @@ -28,7 +28,7 @@ class Warp(nn.Module): """ - Warp an image with given dense displacement field (DDF). + Warp an image with given dense displacement field (DisplacementField). """ def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePadMode.BORDER.value): @@ -113,7 +113,7 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor): ddf_shape = (image.shape[0], spatial_dims) + tuple(image.shape[2:]) if ddf.shape != ddf_shape: raise ValueError( - f"Given input {spatial_dims}-d image shape {image.shape}, the input DDF shape must be {ddf_shape}, " + f"Given input {spatial_dims}-d image shape {image.shape}, the input DisplacementField shape must be {ddf_shape}, " f"Got {ddf.shape} instead." ) grid = self.get_reference_grid(ddf) + ddf @@ -134,7 +134,7 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor): class DVF2DDF(nn.Module): """ - Layer calculates a dense displacement field (DDF) from a dense velocity field (DVF) + Layer calculates a dense displacement field (DisplacementField) from a dense velocity field (DVF) with scaling and squaring. Adapted from: diff --git a/monai/transforms/lazy/utils.py b/monai/transforms/lazy/utils.py index 155c39de1c..4e37e78833 100644 --- a/monai/transforms/lazy/utils.py +++ b/monai/transforms/lazy/utils.py @@ -34,14 +34,14 @@ def is_affine_shaped(data): """Check if the data is an affine matrix.""" if isinstance(data, Affine): return True - if isinstance(data, DDF): + if isinstance(data, DisplacementField): return False if not hasattr(data, "shape") or len(data.shape) < 2: return False return data.shape[-1] in (3, 4) and data.shape[-2] in (3, 4) and data.shape[-1] == data.shape[-2] -class DDF: +class DisplacementField: """A class to represent a dense displacement field.""" __slots__ = ("data",) @@ -52,7 +52,7 @@ def __init__(self, data): @staticmethod def is_ddf_shaped(data): """Check if the data is a DDF.""" - if isinstance(data, DDF): + if isinstance(data, DisplacementField): return True if isinstance(data, Affine): return False @@ -67,9 +67,11 @@ def combine_transforms(left: torch.Tensor, right: torch.Tensor) -> torch.Tensor: left = convert_to_tensor(left.data if isinstance(left, Affine) else left, wrap_sequence=True) right = convert_to_tensor(right.data if isinstance(right, Affine) else right, wrap_sequence=True) return torch.matmul(left, right) - if DDF.is_ddf_shaped(left) and DDF.is_ddf_shaped(right): # adds DDFs, do we need metadata if metatensor input? - left = convert_to_tensor(left.data if isinstance(left, DDF) else left, wrap_sequence=True) - right = convert_to_tensor(right.data if isinstance(right, DDF) else right, wrap_sequence=True) + if DisplacementField.is_ddf_shaped(left) and DisplacementField.is_ddf_shaped( + right + ): # adds DDFs, do we need metadata if metatensor input? 
+        left = convert_to_tensor(left.data if isinstance(left, DisplacementField) else left, wrap_sequence=True)
+        right = convert_to_tensor(right.data if isinstance(right, DisplacementField) else right, wrap_sequence=True)
         return left + right
     raise NotImplementedError
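Taken together, the series enables the following end-to-end usage; a sketch mirroring `tests/test_apply.py` (the two opposite rotations are fused by `combine_transforms`, so only one resample runs):

    import numpy as np
    import torch
    from monai.transforms.lazy.functional import apply
    from monai.transforms.utils import create_rotate
    from monai.utils import LazyAttr, convert_to_tensor

    img = convert_to_tensor(torch.rand(1, 32, 32), track_meta=True)  # MetaTensor
    for p in (
        {LazyAttr.AFFINE: create_rotate(2, np.pi / 4)},
        {LazyAttr.AFFINE: create_rotate(2, -np.pi / 4)},
    ):
        img.push_pending_operation(p)
    result, transforms = apply(img)  # pending ops fused, then resampled once
    print(result.shape)              # expected: (1, 32, 32)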