diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml
index 484026626f..a710e8e46e 100644
--- a/.github/workflows/pythonapp.yml
+++ b/.github/workflows/pythonapp.yml
@@ -88,6 +88,10 @@ jobs:
         name: Install torch cpu from pytorch.org (Windows only)
         run: |
           python -m pip install torch==1.12.1+cpu torchvision==0.13.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
+      - if: runner.os == 'Linux'
+        name: Install itk pre-release (Linux only)
+        run: |
+          python -m pip install --pre -U itk
       - name: Install the dependencies
         run: |
           python -m pip install torch==1.12.1 torchvision==0.13.1
diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index 874f01a945..7b728fde48 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -22,11 +22,31 @@ Generic Interfaces
     :members:
     :special-members: __call__
 
+`RandomizableTrait`
+^^^^^^^^^^^^^^^^^^^
+.. autoclass:: RandomizableTrait
+    :members:
+
+`LazyTrait`
+^^^^^^^^^^^
+.. autoclass:: LazyTrait
+    :members:
+
+`MultiSampleTrait`
+^^^^^^^^^^^^^^^^^^
+.. autoclass:: MultiSampleTrait
+    :members:
+
 `Randomizable`
 ^^^^^^^^^^^^^^
 .. autoclass:: Randomizable
     :members:
 
+`LazyTransform`
+^^^^^^^^^^^^^^^
+.. autoclass:: LazyTransform
+    :members:
+
 `RandomizableTransform`
 ^^^^^^^^^^^^^^^^^^^^^^^
 .. autoclass:: RandomizableTransform
diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py
index 087d4d1950..34e1368fe2 100644
--- a/monai/data/image_reader.py
+++ b/monai/data/image_reader.py
@@ -318,7 +318,12 @@ def _get_meta_dict(self, img) -> Dict:
 
         """
         img_meta_dict = img.GetMetaDataDictionary()
-        meta_dict = {key: img_meta_dict[key] for key in img_meta_dict.GetKeys() if not key.startswith("ITK_")}
+        meta_dict = {}
+        for key in img_meta_dict.GetKeys():
+            if key.startswith("ITK_"):
+                continue
+            val = img_meta_dict[key]
+            meta_dict[key] = np.asarray(val) if type(val).__name__.startswith("itk") else val
 
         meta_dict["spacing"] = np.asarray(img.GetSpacing())
         return meta_dict
diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py
index 5061efc1ce..6aab05dc94 100644
--- a/monai/data/meta_obj.py
+++ b/monai/data/meta_obj.py
@@ -82,6 +82,7 @@ class MetaObj:
     def __init__(self):
         self._meta: dict = MetaObj.get_default_meta()
         self._applied_operations: list = MetaObj.get_default_applied_operations()
+        self._pending_operations: list = MetaObj.get_default_applied_operations()  # the same default as applied_ops
         self._is_batch: bool = False
 
     @staticmethod
@@ -199,6 +200,19 @@ def push_applied_operation(self, t: Any) -> None:
     def pop_applied_operation(self) -> Any:
         return self._applied_operations.pop()
 
+    @property
+    def pending_operations(self) -> list[dict]:
+        """Get the pending operations. Defaults to ``[]``."""
+        if hasattr(self, "_pending_operations"):
+            return self._pending_operations
+        return MetaObj.get_default_applied_operations()  # the same default as applied_ops
+
+    def push_pending_operation(self, t: Any) -> None:
+        self._pending_operations.append(t)
+
+    def pop_pending_operation(self) -> Any:
+        return self._pending_operations.pop()
+
     @property
     def is_batch(self) -> bool:
         """Return whether object is part of batch or not."""
diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py
index 5a7d81ad8e..493aef848b 100644
--- a/monai/data/meta_tensor.py
+++ b/monai/data/meta_tensor.py
@@ -23,8 +23,8 @@
 from monai.data.meta_obj import MetaObj, get_track_meta
 from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
 from monai.utils import look_up_option
-from monai.utils.enums import MetaKeys, PostFix, SpaceKeys
-from monai.utils.type_conversion import convert_data_type, convert_to_tensor
+from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys
+from monai.utils.type_conversion import convert_data_type, convert_to_numpy, convert_to_tensor
 
 __all__ = ["MetaTensor"]
 
@@ -445,6 +445,20 @@ def pixdim(self):
             return [affine_to_spacing(a) for a in self.affine]
         return affine_to_spacing(self.affine)
 
+    def peek_pending_shape(self):
+        """Get the currently expected spatial shape as if all the pending operations are executed."""
+        res = None
+        if self.pending_operations:
+            res = self.pending_operations[-1].get(LazyAttr.SHAPE, None)
+        # default to spatial shape (assuming channel-first input)
+        return tuple(convert_to_numpy(self.shape, wrap_sequence=True).tolist()[1:]) if res is None else res
+
+    def peek_pending_affine(self):
+        res = None
+        if self.pending_operations:
+            res = self.pending_operations[-1].get(LazyAttr.AFFINE, None)
+        return self.affine if res is None else res
+
     def new_empty(self, size, dtype=None, device=None, requires_grad=False):
         """
         must be defined for deepcopy to work
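The new peek methods report the spatial shape/affine as they would be after any pending (lazy) operations, falling back to the current values when nothing is pending; `tests/test_meta_tensor.py` below exercises this. A minimal sketch (illustrative values):

    import torch
    from monai.data.meta_tensor import MetaTensor

    m = MetaTensor(torch.zeros(1, 10, 8))  # channel-first image
    assert m.pending_operations == []
    assert m.peek_pending_shape() == (10, 8)  # spatial dims only, channel dropped
    assert isinstance(m.peek_pending_affine(), torch.Tensor)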
diff --git a/monai/networks/nets/attentionunet.py b/monai/networks/nets/attentionunet.py
index 177a54e105..a57b57425e 100644
--- a/monai/networks/nets/attentionunet.py
+++ b/monai/networks/nets/attentionunet.py
@@ -143,12 +143,27 @@ def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
 
 
 class AttentionLayer(nn.Module):
-    def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, submodule: nn.Module, dropout=0.0):
+    def __init__(
+        self,
+        spatial_dims: int,
+        in_channels: int,
+        out_channels: int,
+        submodule: nn.Module,
+        up_kernel_size=3,
+        strides=2,
+        dropout=0.0,
+    ):
         super().__init__()
         self.attention = AttentionBlock(
             spatial_dims=spatial_dims, f_g=in_channels, f_l=in_channels, f_int=in_channels // 2
         )
-        self.upconv = UpConv(spatial_dims=spatial_dims, in_channels=out_channels, out_channels=in_channels, strides=2)
+        self.upconv = UpConv(
+            spatial_dims=spatial_dims,
+            in_channels=out_channels,
+            out_channels=in_channels,
+            strides=strides,
+            kernel_size=up_kernel_size,
+        )
         self.merge = Convolution(
             spatial_dims=spatial_dims, in_channels=2 * in_channels, out_channels=in_channels, dropout=dropout
         )
@@ -174,7 +189,7 @@ class AttentionUnet(nn.Module):
         channels (Sequence[int]): sequence of channels. Top block first. The length of `channels` should be no less than 2.
         strides (Sequence[int]): stride to use for convolutions.
         kernel_size: convolution kernel size.
-        upsample_kernel_size: convolution kernel size for transposed convolution layers.
+        up_kernel_size: convolution kernel size for transposed convolution layers.
         dropout: dropout ratio. Defaults to no dropout.
     """
 
@@ -210,9 +225,9 @@ def __init__(
         )
         self.up_kernel_size = up_kernel_size
 
-        def _create_block(channels: Sequence[int], strides: Sequence[int], level: int = 0) -> nn.Module:
+        def _create_block(channels: Sequence[int], strides: Sequence[int]) -> nn.Module:
             if len(channels) > 2:
-                subblock = _create_block(channels[1:], strides[1:], level=level + 1)
+                subblock = _create_block(channels[1:], strides[1:])
                 return AttentionLayer(
                     spatial_dims=spatial_dims,
                     in_channels=channels[0],
@@ -227,17 +242,19 @@ def _create_block(channels: Sequence[int], strides: Sequence[int], level: int =
                         ),
                         subblock,
                     ),
+                    up_kernel_size=self.up_kernel_size,
+                    strides=strides[0],
                     dropout=dropout,
                 )
             else:
                 # the next layer is the bottom so stop recursion,
-                # create the bottom layer as the sublock for this layer
-                return self._get_bottom_layer(channels[0], channels[1], strides[0], level=level + 1)
+                # create the bottom layer as the subblock for this layer
+                return self._get_bottom_layer(channels[0], channels[1], strides[0])
 
         encdec = _create_block(self.channels, self.strides)
         self.model = nn.Sequential(head, encdec, reduce_channels)
 
-    def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int, level: int) -> nn.Module:
+    def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int) -> nn.Module:
         return AttentionLayer(
             spatial_dims=self.dimensions,
             in_channels=in_channels,
@@ -249,6 +266,8 @@ def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int, l
                 strides=strides,
                 dropout=self.dropout,
             ),
+            up_kernel_size=self.up_kernel_size,
+            strides=strides,
             dropout=self.dropout,
         )
diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py
index 389571d16f..307b5cda28 100644
--- a/monai/transforms/__init__.py
+++ b/monai/transforms/__init__.py
@@ -227,6 +227,8 @@
 from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict
 from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
 from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
+from .lazy.array import Apply
+from .lazy.functional import apply
 from .meta_utility.dictionary import (
     FromMetaTensord,
     FromMetaTensorD,
@@ -235,6 +237,13 @@
     ToMetaTensorD,
     ToMetaTensorDict,
 )
+from .meta_matrix import (
+    Grid,
+    Matrix,
+    MatrixFactory,
+    MetaMatrix,
+    matmul,
+)
 from .nvtx import (
     Mark,
     Markd,
@@ -449,7 +458,18 @@
     ZoomD,
     ZoomDict,
 )
-from .transform import MapTransform, Randomizable, RandomizableTransform, ThreadUnsafe, Transform, apply_transform
+from .transform import (
+    LazyTrait,
+    LazyTransform,
+    MapTransform,
+    MultiSampleTrait,
+    Randomizable,
+    RandomizableTrait,
+    RandomizableTransform,
+    ThreadUnsafe,
+    Transform,
+    apply_transform,
+)
 from .utility.array import (
     AddChannel,
     AddCoordinateChannels,
@@ -621,6 +641,8 @@
     generate_label_classes_crop_centers,
     generate_pos_neg_label_crop_centers,
     generate_spatial_bounding_box,
+    get_backend_from_tensor_like,
+    get_device_from_tensor_like,
     get_extreme_points,
     get_largest_connected_component_mask,
     get_number_image_type_conversions,
diff --git a/monai/transforms/lazy/__init__.py b/monai/transforms/lazy/__init__.py
new file mode 100644
index 0000000000..1e97f89407
--- /dev/null
+++ b/monai/transforms/lazy/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/monai/transforms/lazy/array.py b/monai/transforms/lazy/array.py
new file mode 100644
index 0000000000..ae165cf566
--- /dev/null
+++ b/monai/transforms/lazy/array.py
@@ -0,0 +1,48 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from monai.transforms.inverse import InvertibleTransform
+from monai.transforms.lazy.functional import apply
+
+__all__ = ["Apply"]
+
+
+class Apply(InvertibleTransform):
+    """
+    Apply wraps the `apply` function and can be used as a Transform in either array or dictionary
+    mode.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def __call__(self, *args, **kwargs):
+        return apply(*args, **kwargs)
+
+    def inverse(self, data):
+        raise NotImplementedError()
+
+
+# class Applyd(MapTransform, InvertibleTransform):
+#
+#     def __init__(self):
+#         super().__init__()
+#
+#     def __call__(self, d: dict):
+#         rd = dict()
+#         for k, v in d.items():
+#             rd[k] = apply(v)
+#         return rd
+#
+#     def inverse(self, data):
+#         raise NotImplementedError()
diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py
new file mode 100644
index 0000000000..adb44fc9a3
--- /dev/null
+++ b/monai/transforms/lazy/functional.py
@@ -0,0 +1,202 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
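For orientation (illustrative, mirroring `tests/test_apply.py` below): pending operations are described as `MetaMatrix` objects, and `apply` (or the `Apply` transform wrapping it) fuses them into a single resample:

    import torch
    from monai.transforms.lazy.functional import apply
    from monai.transforms.meta_matrix import MetaMatrix

    pending = [MetaMatrix(torch.eye(3), {"id": "identity"})]  # a no-op 2D homogeneous matrix
    data, applied = apply(torch.randn(1, 16, 16), pending)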
+
+import itertools as it
+from typing import Optional, Sequence, Union
+
+import numpy as np
+import torch
+
+from monai.config import DtypeLike
+from monai.data.meta_tensor import MetaTensor
+from monai.transforms.meta_matrix import Matrix, MatrixFactory, MetaMatrix, matmul
+from monai.transforms.spatial.array import Affine
+from monai.transforms.utils import dtypes_to_str_or_identity, get_backend_from_tensor_like, get_device_from_tensor_like
+from monai.utils import GridSampleMode, GridSamplePadMode
+
+__all__ = ["apply"]
+
+# TODO: This should move to a common place to be shared with dictionary
+# from monai.utils.type_conversion import dtypes_to_str_or_identity
+
+GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str]
+GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str]
+DtypeSequence = Union[Sequence[DtypeLike], DtypeLike]
+
+
+# TODO: move to mapping_stack.py
+def extents_from_shape(shape, dtype=np.float32):
+    extents = [[0, shape[i]] for i in range(1, len(shape))]
+
+    extents = it.product(*extents)
+    return [np.asarray(e + (1,), dtype=dtype) for e in extents]
+
+
+# TODO: move to mapping_stack.py
+def shape_from_extents(
+    src_shape: Sequence, extents: Union[Sequence[np.ndarray], Sequence[torch.Tensor], np.ndarray, torch.Tensor]
+):
+    if isinstance(extents, (list, tuple)):
+        if isinstance(extents[0], np.ndarray):
+            aextents = np.asarray(extents)
+            aextents = torch.from_numpy(aextents)
+        else:
+            aextents = torch.stack(extents)
+    else:
+        if isinstance(extents, np.ndarray):
+            aextents = torch.from_numpy(extents)
+        else:
+            aextents = extents
+
+    mins = aextents.min(axis=0)[0]
+    maxes = aextents.max(axis=0)[0]
+    values = torch.round(maxes - mins).type(torch.IntTensor)[:-1]
+    return torch.cat((torch.as_tensor([src_shape[0]]), values))
+
+
+def metadata_is_compatible(value_1, value_2):
+    if value_1 is None or value_2 is None:
+        return True
+    return value_1 == value_2
+
+
+def metadata_dtype_is_compatible(value_1, value_2):
+    if value_1 is None or value_2 is None:
+        return True
+
+    # if we are here, value_1 and value_2 are both set
+    # TODO: this is not a good enough solution
+    value_1_ = dtypes_to_str_or_identity(value_1)
+    value_2_ = dtypes_to_str_or_identity(value_2)
+    return value_1_ == value_2_
+
+
+def starting_matrix_and_extents(matrix_factory, data):
+    # set up the identity matrix and metadata
+    cumulative_matrix = matrix_factory.identity()
+    cumulative_extents = extents_from_shape(data.shape)
+    return cumulative_matrix, cumulative_extents
+
+
+def prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype):
+    kwargs = {}
+    if cur_mode is not None:
+        kwargs["mode"] = cur_mode
+    if cur_padding_mode is not None:
+        kwargs["padding_mode"] = cur_padding_mode
+    if cur_device is not None:
+        kwargs["device"] = cur_device
+    if cur_dtype is not None:
+        kwargs["dtype"] = cur_dtype
+
+    return kwargs
+
+
+def matrix_from_matrix_container(matrix):
+    if isinstance(matrix, MetaMatrix):
+        return matrix.matrix.data
+    elif isinstance(matrix, Matrix):
+        return matrix.data
+    else:
+        return matrix
+
+
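As a worked example of the extent helpers above (illustrative): the homogeneous corner points of a channel-first `(1, 4, 4)` image are

    >>> extents_from_shape((1, 4, 4))
    [array([0., 0., 1.], dtype=float32), array([0., 4., 1.], dtype=float32),
     array([4., 0., 1.], dtype=float32), array([4., 4., 1.], dtype=float32)]

and `shape_from_extents` maps such (possibly transformed) corners back to an output shape, re-attaching the channel count.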
+def apply(data: Union[torch.Tensor, MetaTensor, dict], pending: Optional[Union[dict, list]] = None):
+    """
+    This function applies pending transforms to tensors.
+
+    Args:
+        data: a torch Tensor, monai MetaTensor, or a dictionary containing Tensors / MetaTensors
+        pending: optional list of pending transforms. This must be set if ``data`` is a Tensor
+            or dictionary of Tensors, but is optional if ``data`` is a MetaTensor / dictionary of
+            MetaTensors.
+    """
+    if isinstance(data, dict):
+        rd = dict()
+        for k, v in data.items():
+            result = apply(v, pending)
+            rd[k] = result
+        return rd
+
+    if isinstance(data, MetaTensor) and pending is None:
+        pending_ = data.pending_operations
+    else:
+        pending_ = [] if pending is None else pending
+
+    if len(pending_) == 0:
+        return data
+
+    dim_count = len(data.shape) - 1
+    matrix_factory = MatrixFactory(dim_count, get_backend_from_tensor_like(data), get_device_from_tensor_like(data))
+
+    # set up the identity matrix and metadata
+    cumulative_matrix, cumulative_extents = starting_matrix_and_extents(matrix_factory, data)
+
+    # set the various resampling parameters to an initial state
+    cur_mode = None
+    cur_padding_mode = None
+    cur_device = None
+    cur_dtype = None
+    cur_shape = data.shape
+
+    for meta_matrix in pending_:
+        next_matrix = meta_matrix.matrix
+        cumulative_matrix = matmul(cumulative_matrix, next_matrix)
+        cumulative_extents = [matmul(e, cumulative_matrix) for e in cumulative_extents]
+
+        new_mode = meta_matrix.metadata.get("mode", None)
+        new_padding_mode = meta_matrix.metadata.get("padding_mode", None)
+        new_device = meta_matrix.metadata.get("device", None)
+        new_dtype = meta_matrix.metadata.get("dtype", None)
+        new_shape = meta_matrix.metadata.get("shape_override", None)
+
+        mode_compat = metadata_is_compatible(cur_mode, new_mode)
+        padding_mode_compat = metadata_is_compatible(cur_padding_mode, new_padding_mode)
+        device_compat = metadata_is_compatible(cur_device, new_device)
+        dtype_compat = metadata_dtype_is_compatible(cur_dtype, new_dtype)
+
+        if mode_compat is False or padding_mode_compat is False or device_compat is False or dtype_compat is False:
+            # carry out an intermediate resample here due to incompatibility between arguments
+            kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype)
+
+            cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix)
+            a = Affine(norm_coords=False, affine=cumulative_matrix_, **kwargs)
+            data, _ = a(img=data)
+
+            cumulative_matrix, cumulative_extents = starting_matrix_and_extents(matrix_factory, data)
+
+        cur_mode = cur_mode if new_mode is None else new_mode
+        cur_padding_mode = cur_padding_mode if new_padding_mode is None else new_padding_mode
+        cur_device = cur_device if new_device is None else new_device
+        cur_dtype = cur_dtype if new_dtype is None else new_dtype
+        cur_shape = cur_shape if new_shape is None else new_shape
+
+    kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype)
+
+    cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix)
+
+    a = Affine(norm_coords=False, affine=cumulative_matrix_, spatial_size=cur_shape[1:], normalized=False, **kwargs)
+    data, tx = a(img=data)
+    if isinstance(data, MetaTensor):
+        # consume the pending operations now that they have been resampled,
+        # then record them as applied operations
+        while data.pending_operations:
+            data.pop_pending_operation()
+        for p in pending_:
+            data.affine = p.matrix.data
+            data.push_applied_operation(p)
+
+    return data, pending_
diff --git a/monai/transforms/meta_matrix.py b/monai/transforms/meta_matrix.py
new file mode 100644
index 0000000000..845fe3f2c5
--- /dev/null
+++ b/monai/transforms/meta_matrix.py
@@ -0,0 +1,285 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Sequence, Union
+
+import numpy as np
+import torch
+
+from monai.config import NdarrayOrTensor
+from monai.transforms.utils import (
+    _create_flip,
+    _create_rotate,
+    _create_rotate_90,
+    _create_scale,
+    _create_shear,
+    _create_translate,
+    get_backend_from_tensor_like,
+    get_device_from_tensor_like,
+)
+from monai.utils import TransformBackends
+
+__all__ = ["Grid", "matmul", "Matrix", "MatrixFactory", "MetaMatrix"]
+
+
+def is_matrix_shaped(data):
+    return (
+        len(data.shape) == 2 and data.shape[0] in (3, 4) and data.shape[1] in (3, 4) and data.shape[0] == data.shape[1]
+    )
+
+
+def is_grid_shaped(data):
+    return (len(data.shape) == 3 and data.shape[0] == 3) or (len(data.shape) == 4 and data.shape[0] == 4)
+
+
+class MatrixFactory:
+    def __init__(self, dims: int, backend: TransformBackends, device: Optional[torch.device] = None):
+        if backend == TransformBackends.NUMPY:
+            if device is not None:
+                raise ValueError("'device' must be None with TransformBackends.NUMPY")
+            self._device = None
+            self._sin = lambda th: np.sin(th, dtype=np.float32)
+            self._cos = lambda th: np.cos(th, dtype=np.float32)
+            self._eye = lambda rank: np.eye(rank, dtype=np.float32)
+            # np.diag does not accept a dtype argument, so cast after construction
+            self._diag = lambda v: np.diag(v).astype(np.float32)
+        else:
+            if device is None:
+                raise ValueError("'device' must be set with TransformBackends.TORCH")
+            self._device = device
+            self._sin = lambda th: torch.sin(torch.as_tensor(th, dtype=torch.float32, device=self._device))
+            self._cos = lambda th: torch.cos(torch.as_tensor(th, dtype=torch.float32, device=self._device))
+            self._eye = lambda rank: torch.eye(rank, device=self._device, dtype=torch.float32)
+            self._diag = lambda v: torch.diag(torch.as_tensor(v, device=self._device, dtype=torch.float32))
+
+        self._backend = backend
+        self._dims = dims
+
+    @staticmethod
+    def from_tensor(data):
+        return MatrixFactory(
+            len(data.shape) - 1, get_backend_from_tensor_like(data), get_device_from_tensor_like(data)
+        )
+
+    def identity(self):
+        matrix = self._eye(self._dims + 1)
+        return MetaMatrix(matrix, {})
+
+    def rotate_euler(self, radians: Union[Sequence[float], float], **extra_args):
+        matrix = _create_rotate(self._dims, radians, self._sin, self._cos, self._eye)
+        return MetaMatrix(matrix, extra_args)
+
+    def rotate_90(self, rotations, axis, **extra_args):
+        # note the argument order of _create_rotate_90: (spatial_dims, axis, steps)
+        matrix = _create_rotate_90(self._dims, axis, rotations)
+        return MetaMatrix(matrix, extra_args)
+
+    def flip(self, axis, **extra_args):
+        matrix = _create_flip(self._dims, axis, self._eye)
+        return MetaMatrix(matrix, extra_args)
+
+    def shear(self, coefs: Union[Sequence[float], float], **extra_args):
+        matrix = _create_shear(self._dims, coefs, self._eye)
+        return MetaMatrix(matrix, extra_args)
+
+    def scale(self, factors: Union[Sequence[float], float], **extra_args):
+        matrix = _create_scale(self._dims, factors, self._diag)
+        return MetaMatrix(matrix, extra_args)
+
+    def translate(self, offsets: Union[Sequence[float], float], **extra_args):
+        matrix = _create_translate(self._dims, offsets, self._eye)
+        return
 MetaMatrix(matrix, extra_args)
+
+
+def ensure_tensor(data: NdarrayOrTensor):
+    if isinstance(data, torch.Tensor):
+        return data
+
+    return torch.as_tensor(data)
+
+
+class Matrix:
+    def __init__(self, matrix: NdarrayOrTensor):
+        self.data = ensure_tensor(matrix)
+
+    # def __matmul__(self, other):
+    #     if isinstance(other, Matrix):
+    #         other_matrix = other.data
+    #     else:
+    #         other_matrix = other
+    #     return self.data @ other_matrix
+    #
+    # def __rmatmul__(self, other):
+    #     return other.__matmul__(self.data)
+
+
+class Grid:
+    def __init__(self, grid):
+        self.data = ensure_tensor(grid)
+
+    # def __matmul__(self, other):
+    #     raise NotImplementedError()
+
+
+class MetaMatrix:
+    def __init__(self, matrix: Union[NdarrayOrTensor, Matrix, Grid], metadata: Optional[dict] = None):
+        if not isinstance(matrix, (Matrix, Grid)):
+            if len(matrix.shape) == 2:
+                if matrix.shape[0] != matrix.shape[1] or matrix.shape[0] not in (3, 4):
+                    raise ValueError(
+                        "If 'matrix' is passed a numpy ndarray/torch Tensor, it must"
+                        f" be 3x3 or 4x4 ('matrix' has shape {matrix.shape})"
+                    )
+            matrix_ = Matrix(matrix)
+        else:
+            matrix_ = matrix
+        self.matrix = matrix_
+
+        self.metadata = metadata or {}
+
+    def __matmul__(self, other):
+        if isinstance(other, MetaMatrix):
+            other_ = other.matrix
+        else:
+            other_ = other
+        return MetaMatrix(self.matrix @ other_)
+
+    def __rmatmul__(self, other):
+        if isinstance(other, MetaMatrix):
+            other_ = other.matrix
+        else:
+            other_ = other
+        return MetaMatrix(other_ @ self.matrix)
+
+
+def matmul(
+    left: Union[MetaMatrix, Grid, Matrix, NdarrayOrTensor], right: Union[MetaMatrix, Grid, Matrix, NdarrayOrTensor]
+):
+    matrix_types = (MetaMatrix, Grid, Matrix, torch.Tensor, np.ndarray)
+
+    if not isinstance(left, matrix_types):
+        raise TypeError(f"'left' must be one of {matrix_types} but is {type(left)}")
+    if not isinstance(right, matrix_types):
+        raise TypeError(f"'right' must be one of {matrix_types} but is {type(right)}")
+
+    left_ = left.matrix if isinstance(left, MetaMatrix) else left
+    right_ = right.matrix if isinstance(right, MetaMatrix) else right
+
+    # TODO: it might be better to not return a metamatrix, unless we pass in the resulting
+    # metadata also
+    put_in_metamatrix = isinstance(left, MetaMatrix) or isinstance(right, MetaMatrix)
+
+    put_in_grid = isinstance(left, Grid) or isinstance(right, Grid)
+
+    put_in_matrix = isinstance(left, Matrix) or isinstance(right, Matrix)
+    put_in_matrix = False if put_in_grid is True else put_in_matrix
+
+    promote_to_tensor = not (isinstance(left_, np.ndarray) and isinstance(right_, np.ndarray))
+
+    left_raw = left_.data if isinstance(left_, (Matrix, Grid)) else left_
+    right_raw = right_.data if isinstance(right_, (Matrix, Grid)) else right_
+
+    if promote_to_tensor:
+        left_raw = torch.as_tensor(left_raw)
+        right_raw = torch.as_tensor(right_raw)
+
+    if isinstance(left_, Grid):
+        if isinstance(right_, Grid):
+            raise RuntimeError("Unable to matrix multiply two Grids")
+        else:
+            result = matmul_grid_matrix(left_raw, right_raw)
+    else:
+        if isinstance(right_, Grid):
+            result = matmul_matrix_grid(left_raw, right_raw)
+        else:
+            result = matmul_matrix_matrix(left_raw, right_raw)
+
+    if put_in_grid:
+        result = Grid(result)
+    elif put_in_matrix:
+        result = Matrix(result)
+
+    if put_in_metamatrix:
+        result = MetaMatrix(result)
+
+    return result
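The result container mirrors the "richest" operand type, as enumerated by `MATMUL_TEST_CASES` in `tests/test_matmul.py` below; a short sketch (illustrative):

    import numpy as np
    import torch
    from monai.transforms.meta_matrix import Grid, Matrix, matmul

    matmul(np.eye(3, dtype=np.float32), torch.eye(3))         # -> torch.Tensor
    matmul(Matrix(torch.eye(3)), torch.eye(3))                # -> Matrix
    matmul(Matrix(torch.eye(3)), Grid(torch.randn(3, 8, 8)))  # -> Grid
    # matmul(Grid(...), Grid(...)) raises RuntimeError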
+def matmul_matrix_grid(left: NdarrayOrTensor, right: NdarrayOrTensor):
+    if not is_matrix_shaped(left):
+        raise ValueError(f"'left' should be a 2D or 3D homogeneous matrix but has shape {left.shape}")
+
+    if not is_grid_shaped(right):
+        raise ValueError(
+            "'right' should be a 3D array with shape[0] == 3 or a "
+            f"4D array with shape[0] == 4 but has shape {right.shape}"
+        )
+
+    # flatten the grid to take advantage of torch batch matrix multiply
+    right_flat = right.reshape(right.shape[0], -1)
+    result_flat = left @ right_flat
+    # restore the grid shape
+    result = result_flat.reshape((-1,) + right.shape[1:])
+    return result
+
+
+def matmul_grid_matrix(left: NdarrayOrTensor, right: NdarrayOrTensor):
+    if not is_grid_shaped(left):
+        raise ValueError(
+            "'left' should be a 3D array with shape[0] == 3 or a "
+            f"4D array with shape[0] == 4 but has shape {left.shape}"
+        )
+
+    if not is_matrix_shaped(right):
+        raise ValueError(f"'right' should be a 2D or 3D homogeneous matrix but has shape {right.shape}")
+
+    try:
+        inv_matrix = torch.inverse(right)
+    except RuntimeError:
+        # the matrix is not invertible, so we will have to perform a slow grid to matrix operation
+        return matmul_grid_matrix_slow(left, right)
+
+    # invert the matrix and swap the arguments, taking advantage of
+    # matrix @ vector == vector_transposed @ matrix_inverse
+    return matmul_matrix_grid(inv_matrix, left)
+
+
+def matmul_grid_matrix_slow(left: NdarrayOrTensor, right: NdarrayOrTensor):
+    if not is_grid_shaped(left):
+        raise ValueError(
+            "'left' should be a 3D array with shape[0] == 3 or a "
+            f"4D array with shape[0] == 4 but has shape {left.shape}"
+        )
+
+    if not is_matrix_shaped(right):
+        raise ValueError(f"'right' should be a 2D or 3D homogeneous matrix but has shape {right.shape}")
+
+    flat_left = left.reshape(left.shape[0], -1)
+    result_flat = torch.zeros_like(flat_left)
+    for i in range(flat_left.shape[1]):
+        vector = flat_left[:, i][None, :]
+        result_vector = vector @ right
+        result_flat[:, i] = result_vector[0, :]
+
+    # restore the grid shape
+    result = result_flat.reshape((-1,) + left.shape[1:])
+    return result
+
+
+def matmul_matrix_matrix(left: NdarrayOrTensor, right: NdarrayOrTensor):
+    return left @ right
+ """ + raise NotImplementedError() + + @lazy_evaluation.setter + def lazy_evaluation(self, enabled: bool): + """ + Set whether lazy_evaluation is enabled for this transform instance. + Args: + enabled: True if the transform should operate in a lazy fashion, False if not. + """ + raise NotImplementedError() + + +class RandomizableTrait: + """ + An interface to indicate that the transform has the capability to perform + randomized transforms to the data that it is called upon. This interface + can be extended from by people adapting transforms to the MONAI framework as well as by + implementors of MONAI transforms. + """ + + pass + + +class MultiSampleTrait: + """ + An interface to indicate that the transform has the capability to return multiple samples + given an input, such as when performing random crops of a sample. This interface can be + extended from by people adapting transforms to the MONAI framework as well as by implementors + of MONAI transforms. + """ + + pass + + class ThreadUnsafe: """ A class to denote that the transform will mutate its member variables, @@ -251,7 +312,27 @@ def __call__(self, data: Any): raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") -class RandomizableTransform(Randomizable, Transform): +class LazyTransform(Transform, LazyTrait): + """ + An implementation of functionality for lazy transforms that can be subclassed by array and + dictionary transforms to simplify implementation of new lazy transforms. + """ + + def __init__(self, lazy_evaluation: Optional[bool] = True): + self.lazy_evaluation = lazy_evaluation + + @property + def lazy_evaluation(self): + return self.lazy_evaluation + + @lazy_evaluation.setter + def lazy_evaluation(self, lazy_evaluation: bool): + if not isinstance(lazy_evaluation, bool): + raise TypeError("'lazy_evaluation must be a bool but is of " f"type {type(lazy_evaluation)}'") + self.lazy_evaluation = lazy_evaluation + + +class RandomizableTransform(Randomizable, Transform, RandomizableTrait): """ An interface for handling random state locally, currently based on a class variable `R`, which is an instance of `np.random.RandomState`. diff --git a/monai/transforms/utility/functional.py b/monai/transforms/utility/functional.py new file mode 100644 index 0000000000..7aec22c723 --- /dev/null +++ b/monai/transforms/utility/functional.py @@ -0,0 +1,32 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Union + +import torch +from monai.transforms import Affine + +from monai.config import NdarrayOrTensor +from monai.transforms.meta_matrix import Grid, Matrix + + +def resample( + data: torch.Tensor, + matrix: Union[NdarrayOrTensor, Matrix, Grid], + kwargs: Optional[dict] = None +): + """ + This is a minimal implementation of resample that always uses Affine. 
+ """ + if kwargs is not None: + a = Affine(affine=matrix, **kwargs) + else: + a = Affine(affine=matrix) + return a(img=data) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index e96d906f20..a982471cfe 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -85,6 +85,8 @@ "generate_label_classes_crop_centers", "generate_pos_neg_label_crop_centers", "generate_spatial_bounding_box", + "get_backend_from_tensor_like", + "get_device_from_tensor_like", "get_extreme_points", "get_largest_connected_component_mask", "remove_small_objects", @@ -889,6 +891,127 @@ def _create_translate( return array_func(affine) # type: ignore +def _create_rotate_90( + spatial_dims: int, + axis: Tuple[int, int], + steps: Optional[int] = 1, + eye_func: Callable = np.eye +) -> NdarrayOrTensor: + + values = [(1, 0, 0, 1), + (0, -1, 1, 0), + (-1, 0, 0, -1), + (0, 1, -1, 0)] + + if spatial_dims == 2: + if axis != (0, 1): + raise ValueError(f"if 'spatial_dims' is 2, 'axis' must be (0, 1) but is {axis}") + elif spatial_dims == 3: + if axis not in ((0, 1), (0, 2), (1, 2)): + raise ValueError("if 'spatial_dims' is 3, 'axis' must be (0,1), (0, 2), or (1, 2) " + f"but is {axis}") + else: + raise ValueError(f"'spatial_dims' must be 2 or 3 but is {spatial_dims}") + + steps_ = steps % 4 + + affine = eye_func(spatial_dims + 1) + + if spatial_dims == 2: + a, b = 0, 1 + else: + a, b = axis + + affine[a, a], affine[a, b], affine[b, a], affine[b, b] = values[steps] + return affine + + +def create_rotate_90( + spatial_dims: int, + axis: int, + steps: Optional[int] = 1, + device: Optional[torch.device] = None, + backend: str = TransformBackends.NUMPY, +) -> NdarrayOrTensor: + """ + create a 2D or 3D rotation matrix + + Args: + spatial_dims: {``2``, ``3``} spatial rank + radians: rotation radians + when spatial_dims == 3, the `radians` sequence corresponds to + rotation in the 1st, 2nd, and 3rd dim respectively. + device: device to compute and store the output (when the backend is "torch"). + backend: APIs to use, ``numpy`` or ``torch``. + + Raises: + ValueError: When ``radians`` is empty. + ValueError: When ``spatial_dims`` is not one of [2, 3]. 
+
+
+def create_rotate_90(
+    spatial_dims: int,
+    axis: Tuple[int, int],
+    steps: Optional[int] = 1,
+    device: Optional[torch.device] = None,
+    backend: str = TransformBackends.NUMPY,
+) -> NdarrayOrTensor:
+    """
+    create a 2D or 3D rotation matrix for rotation by multiples of 90 degrees
+
+    Args:
+        spatial_dims: {``2``, ``3``} spatial rank
+        axis: the pair of spatial axes defining the plane of rotation
+        steps: the number of 90-degree rotation steps to apply. Defaults to 1.
+        device: device to compute and store the output (when the backend is "torch").
+        backend: APIs to use, ``numpy`` or ``torch``.
+
+    Raises:
+        ValueError: When ``axis`` is not valid for the given ``spatial_dims``.
+        ValueError: When ``spatial_dims`` is not one of [2, 3].
+
+    """
+    _backend = look_up_option(backend, TransformBackends)
+    if _backend == TransformBackends.NUMPY:
+        return _create_rotate_90(
+            spatial_dims=spatial_dims,
+            axis=axis,
+            steps=steps,
+            eye_func=np.eye)
+    if _backend == TransformBackends.TORCH:
+        return _create_rotate_90(
+            spatial_dims=spatial_dims,
+            axis=axis,
+            steps=steps,
+            eye_func=lambda rank: torch.eye(rank, device=device),
+        )
+    raise ValueError(f"backend {backend} is not supported")
+
+
+def _create_flip(
+    spatial_dims: int,
+    spatial_axis: Union[Sequence[int], int],
+    eye_func: Callable = np.eye,
+):
+    affine = eye_func(spatial_dims + 1)
+    if isinstance(spatial_axis, int):
+        if spatial_axis < -spatial_dims or spatial_axis >= spatial_dims:
+            raise ValueError("'spatial_axis' values must be between "
+                             f"{-spatial_dims} and {spatial_dims - 1} inclusive "
+                             f"('spatial_axis' is {spatial_axis})")
+        affine[spatial_axis, spatial_axis] = -1
+    else:
+        if any((s < -spatial_dims or s >= spatial_dims) for s in spatial_axis):
+            raise ValueError("'spatial_axis' values must be between "
+                             f"{-spatial_dims} and {spatial_dims - 1} inclusive "
+                             f"('spatial_axis' is {spatial_axis})")
+
+        for i in range(spatial_dims):
+            if i in spatial_axis:
+                affine[i, i] = -1
+
+    return affine
+
+
+def create_flip(
+    spatial_dims: int,
+    spatial_axis: Union[Sequence[int], int],
+    device: Optional[torch.device] = None,
+    backend: str = TransformBackends.NUMPY,
+) -> NdarrayOrTensor:
+    _backend = look_up_option(backend, TransformBackends)
+    if _backend == TransformBackends.NUMPY:
+        return _create_flip(
+            spatial_dims=spatial_dims,
+            spatial_axis=spatial_axis,
+            eye_func=np.eye)
+    if _backend == TransformBackends.TORCH:
+        return _create_flip(
+            spatial_dims=spatial_dims,
+            spatial_axis=spatial_axis,
+            eye_func=lambda rank: torch.eye(rank, device=device),
+        )
+    raise ValueError(f"backend {backend} is not supported")
+
+
 def generate_spatial_bounding_box(
     img: NdarrayOrTensor,
     select_fn: Callable = is_positive,
@@ -1790,5 +1913,97 @@ def squarepulse(sig, duty: float = 0.5):
     return y
 
 
+def get_device_from_tensor_like(data: NdarrayOrTensor):
+    """
+    This function returns the device of `data`, which must either be a numpy ndarray or a
+    pytorch Tensor.
+
+    Args:
+        data: the ndarray/tensor to return the device of
+
+    Returns:
+        None if `data` is a numpy array, or the device of the pytorch Tensor
+    """
+    if isinstance(data, np.ndarray):
+        return None
+    elif isinstance(data, torch.Tensor):
+        return data.device
+    else:
+        msg = "'data' must be one of numpy ndarray or torch Tensor but is {}"
+        raise ValueError(msg.format(type(data)))
+
+
+def get_backend_from_tensor_like(data: NdarrayOrTensor):
+    """
+    This function returns the backend of `data`, which must either be a numpy ndarray or a
+    pytorch Tensor.
+
+    Args:
+        data: the ndarray/tensor to return the backend of
+
+    Returns:
+        ``TransformBackends.NUMPY`` if `data` is a numpy array, or ``TransformBackends.TORCH``
+        if it is a pytorch Tensor
+    """
+    if isinstance(data, np.ndarray):
+        return TransformBackends.NUMPY
+    elif isinstance(data, torch.Tensor):
+        return TransformBackends.TORCH
+    else:
+        msg = "'data' must be one of numpy ndarray or torch Tensor but is {}"
+        raise ValueError(msg.format(type(data)))
+
+
+def dtype_torch_to_numpy(dtype: torch.dtype) -> np.dtype:
+    """Convert a torch dtype to its numpy equivalent."""
+    return torch.empty([], dtype=dtype).numpy().dtype  # type: ignore
+
+
+def dtype_numpy_to_torch(dtype: np.dtype) -> torch.dtype:
+    """Convert a numpy dtype to its torch equivalent."""
+    return torch.from_numpy(np.empty([], dtype=dtype)).dtype
+
+
+__dtype_dict = {
+    np.int8: "int8",
+    torch.int8: "int8",
+    np.int16: "int16",
+    torch.int16: "int16",
+    int: "int32",
+    np.int32: "int32",
+    torch.int32: "int32",
+    np.int64: "int64",
+    torch.int64: "int64",
+    np.uint8: "uint8",
+    torch.uint8: "uint8",
+    np.uint16: "uint16",
+    np.uint32: "uint32",
+    np.uint64: "uint64",
+    float: "float32",
+    np.float16: "float16",
+    np.float32: "float32",
+    np.float64: "float64",
+    torch.float16: "float16",
+    torch.float: "float32",
+    torch.float32: "float32",
+    torch.double: "float64",
+    torch.float64: "float64",
+}
+
+
+def dtypes_to_str_or_identity(dtype: Any) -> Any:
+    return __dtype_dict.get(dtype, dtype)
+
+
+def get_numpy_dtype_from_string(dtype: str) -> np.dtype:
+    """Get a numpy dtype (e.g., `np.float32`) from its string (e.g., `"float32"`)."""
+    return np.empty([], dtype=dtype).dtype
+
+
+def get_torch_dtype_from_string(dtype: str) -> torch.dtype:
+    """Get a torch dtype (e.g., `torch.float32`) from its string (e.g., `"float32"`)."""
+    return dtype_numpy_to_torch(get_numpy_dtype_from_string(dtype))
+
+
 if __name__ == "__main__":
     print_transform_backends()
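The dtype helpers round-trip through a zero-dimensional array; for example (illustrative):

    >>> dtype_torch_to_numpy(torch.float32)
    dtype('float32')
    >>> dtype_numpy_to_torch(np.dtype("int16"))
    torch.int16
    >>> get_torch_dtype_from_string("float64")
    torch.float64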
+ """ + + SHAPE = "lazy_shape" # spatial shape + AFFINE = "lazy_affine" + PADDING_MODE = "lazy_padding_mode" + INTERP_MODE = "lazy_interpolation_mode" diff --git a/monai/visualize/gradient_based.py b/monai/visualize/gradient_based.py index 7f4ddce1d0..7ab6ef260d 100644 --- a/monai/visualize/gradient_based.py +++ b/monai/visualize/gradient_based.py @@ -90,7 +90,7 @@ def get_grad(self, x: torch.Tensor, index: torch.Tensor | int | None, retain_gra x.requires_grad = True self._model(x, class_idx=index, retain_graph=retain_graph, **kwargs) - grad: torch.Tensor = x.grad.detach() + grad: torch.Tensor = x.grad.detach() # type: ignore return grad def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor: diff --git a/tests/test_apply.py b/tests/test_apply.py new file mode 100644 index 0000000000..019bae42f1 --- /dev/null +++ b/tests/test_apply.py @@ -0,0 +1,69 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from monai.utils import convert_to_tensor + +from monai.transforms.lazy.functional import apply +from monai.transforms.meta_matrix import MetaMatrix + + +def rotate_45_2d(): + t = torch.eye(3) + t[:, 0] = torch.FloatTensor([0, -1, 0]) + t[:, 1] = torch.FloatTensor([1, 0, 0]) + return t + + +class TestApply(unittest.TestCase): + + def _test_apply_impl(self, tensor, pending_transforms): + print(tensor.shape) + result = apply(tensor, pending_transforms) + self.assertListEqual(result[1], pending_transforms) + + def _test_apply_metatensor_impl(self, tensor, pending_transforms, pending_as_parameter): + tensor_ = convert_to_tensor(tensor) + if pending_as_parameter: + result = apply(tensor_, pending_transforms) + else: + for p in pending_transforms: + # TODO: cannot do the next part until the MetaTensor PR #5107 is in + # tensor_.push_pending(p) + raise NotImplementedError() + + def _test_apply_metatensor_impl(self, tensor, pending_transforms, pending_as_parameter): + tensor_ = convert_to_tensor(tensor) + if pending_as_parameter: + result = apply(tensor_, pending_transforms) + else: + for p in pending_transforms: + # TODO: cannot do the next part until the MetaTensor PR #5107 is in + # tensor_.push_pending(p) + raise NotImplementedError() + + SINGLE_TRANSFORM_CASES = [ + (torch.randn((1, 16, 16)), [MetaMatrix(rotate_45_2d(), {"id": "rotate"})]) + ] + + def test_apply_single_transform(self): + for case in self.SINGLE_TRANSFORM_CASES: + self._test_apply_impl(*case) + + def test_apply_single_transform_metatensor(self): + for case in self.SINGLE_TRANSFORM_CASES: + self._test_apply_metatensor_impl(*case, False) + + def test_apply_single_transform_metatensor_override(self): + for case in self.SINGLE_TRANSFORM_CASES: + self._test_apply_metatensor_impl(*case, True) diff --git a/tests/test_attentionunet.py b/tests/test_attentionunet.py index b2f53f9c16..e1df7b8acd 100644 --- a/tests/test_attentionunet.py +++ b/tests/test_attentionunet.py @@ -39,7 +39,7 @@ def test_attentionunet(self): shape = (3, 1) + (92,) * 
diff --git a/tests/test_attentionunet.py b/tests/test_attentionunet.py
index b2f53f9c16..e1df7b8acd 100644
--- a/tests/test_attentionunet.py
+++ b/tests/test_attentionunet.py
@@ -39,7 +39,7 @@ def test_attentionunet(self):
         shape = (3, 1) + (92,) * dims
         input = torch.rand(*shape)
         model = att.AttentionUnet(
-            spatial_dims=dims, in_channels=1, out_channels=2, channels=(3, 4, 5), strides=(2, 2)
+            spatial_dims=dims, in_channels=1, out_channels=2, channels=(3, 4, 5), up_kernel_size=5, strides=(1, 2)
         )
         output = model(input)
         self.assertEqual(output.shape[2:], input.shape[2:])
diff --git a/tests/test_load_image.py b/tests/test_load_image.py
index 834c9f9d59..1db39a310b 100644
--- a/tests/test_load_image.py
+++ b/tests/test_load_image.py
@@ -315,7 +315,7 @@ def test_channel_dim(self, input_param, filename, expected_shape):
         with tempfile.TemporaryDirectory() as tempdir:
             filename = os.path.join(tempdir, filename)
             nib.save(nib.Nifti1Image(test_image, np.eye(4)), filename)
-            result = LoadImage(image_only=True, **input_param)(filename)
+            result = LoadImage(image_only=True, **input_param)(filename)  # with itk, meta has 'qto_xyz': itkMatrixF44
 
         self.assertTupleEqual(
             result.shape, (3, 128, 128, 128) if input_param.get("ensure_channel_first", False) else expected_shape
diff --git a/tests/test_matmul.py b/tests/test_matmul.py
new file mode 100644
index 0000000000..2ac3b70a2e
--- /dev/null
+++ b/tests/test_matmul.py
@@ -0,0 +1,143 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.transforms.meta_matrix import (
+    Grid,
+    Matrix,
+    is_grid_shaped,
+    is_matrix_shaped,
+    matmul,
+    matmul_grid_matrix,
+    matmul_grid_matrix_slow,
+    matmul_matrix_grid,
+)
+
+
+class TestMatmulFunctions(unittest.TestCase):
+    def test_matrix_grid_and_grid_inv_matrix_are_equivalent(self):
+        grid = torch.randn((3, 32, 32))
+
+        matrix = torch.eye(3, 3)
+        matrix[:, 0] = torch.FloatTensor([0, -1, 0])
+        matrix[:, 1] = torch.FloatTensor([1, 0, 0])
+
+        inv_matrix = torch.inverse(matrix)
+
+        result1 = matmul_matrix_grid(matrix, grid)
+        result2 = matmul_grid_matrix(grid, inv_matrix)
+        self.assertTrue(torch.allclose(result1, result2))
+
+    def test_matmul_grid_matrix_slow(self):
+        grid = torch.randn((3, 32, 32))
+
+        matrix = torch.eye(3, 3)
+        matrix[:, 0] = torch.FloatTensor([0, -1, 0])
+        matrix[:, 1] = torch.FloatTensor([1, 0, 0])
+
+        result1 = matmul_grid_matrix_slow(grid, matrix)
+        result2 = matmul_grid_matrix(grid, matrix)
+        self.assertTrue(torch.allclose(result1, result2))
+
+    MATMUL_TEST_CASES = [
+        [np.eye(3, dtype=np.float32), np.eye(3, dtype=np.float32), np.ndarray],
+        [np.eye(3, dtype=np.float32), torch.eye(3), torch.Tensor],
+        [np.eye(3, dtype=np.float32), Matrix(torch.eye(3)), Matrix],
+        [np.eye(3, dtype=np.float32), Grid(torch.randn((3, 8, 8))), Grid],
+        [torch.eye(3), np.eye(3, dtype=np.float32), torch.Tensor],
+        [torch.eye(3), torch.eye(3), torch.Tensor],
+        [torch.eye(3), Matrix(torch.eye(3)), Matrix],
+        [torch.eye(3), Grid(torch.randn((3, 8, 8))), Grid],
+        [Matrix(torch.eye(3)), np.eye(3, dtype=np.float32), Matrix],
+        [Matrix(torch.eye(3)), torch.eye(3), Matrix],
+        [Matrix(torch.eye(3)), Matrix(torch.eye(3)), Matrix],
+        [Matrix(torch.eye(3)), Grid(torch.randn((3, 8, 8))), Grid],
+        [Grid(torch.randn((3, 8, 8))), np.eye(3, dtype=np.float32), Grid],
+        [Grid(torch.randn((3, 8, 8))), torch.eye(3), Grid],
+        [Grid(torch.randn((3, 8, 8))), Matrix(torch.eye(3)), Grid],
+        [Grid(torch.randn((3, 8, 8))), Grid(torch.randn((3, 8, 8))), None],
+    ]
+
+    def _test_matmul_correct_return_type_impl(self, left, right, expected):
+        if expected is None:
+            with self.assertRaises(RuntimeError):
+                result = matmul(left, right)
+        else:
+            result = matmul(left, right)
+            self.assertIsInstance(result, expected)
+
+    @parameterized.expand(MATMUL_TEST_CASES)
+    def test_matmul_correct_return_type(self, left, right, expected):
+        self._test_matmul_correct_return_type_impl(left, right, expected)
+
+    # def test_all_matmul_correct_return_type(self):
+    #     for case in self.MATMUL_TEST_CASES:
+    #         with self.subTest(f"{case}"):
+    #             self._test_matmul_correct_return_type_impl(*case)
+
+    MATRIX_SHAPE_TESTCASES = [
+        (torch.randn(2, 2), False),
+        (torch.randn(3, 3), True),
+        (torch.randn(4, 4), True),
+        (torch.randn(5, 5), False),
+        (torch.randn(3, 4), False),
+        (torch.randn(4, 3), False),
+        (torch.randn(3), False),
+        (torch.randn(4), False),
+        (torch.randn(5), False),
+        (torch.randn(3, 3, 3), False),
+        (torch.randn(4, 4, 4), False),
+        (torch.randn(5, 5, 5), False),
+    ]
+
+    def _test_is_matrix_shaped_impl(self, matrix, expected):
+        self.assertEqual(is_matrix_shaped(matrix), expected)
+
+    @parameterized.expand(MATRIX_SHAPE_TESTCASES)
+    def test_is_matrix_shaped(self, matrix, expected):
+        self._test_is_matrix_shaped_impl(matrix, expected)
+
+    # def test_all_is_matrix_shaped(self):
+    #     for case in self.MATRIX_SHAPE_TESTCASES:
+    #         with self.subTest(f"{case[0].shape}"):
+    #             self._test_is_matrix_shaped_impl(*case)
+
+    GRID_SHAPE_TESTCASES = [
+        (torch.randn(1, 16, 32), False),
+        (torch.randn(2, 16, 32), False),
+        (torch.randn(3, 16, 32), True),
+        (torch.randn(4, 16, 32), False),
+        (torch.randn(5, 16, 32), False),
+        (torch.randn(3, 16, 32, 64), False),
+        (torch.randn(4, 16, 32, 64), True),
+        (torch.randn(5, 16, 32, 64), False),
+    ]
+
+    def _test_is_grid_shaped_impl(self, grid, expected):
+        self.assertEqual(is_grid_shaped(grid), expected)
+
+    @parameterized.expand(GRID_SHAPE_TESTCASES)
+    def test_is_grid_shaped(self, grid, expected):
+        self._test_is_grid_shaped_impl(grid, expected)
+
+    # def test_all_is_grid_shaped(self):
+    #     for case in self.GRID_SHAPE_TESTCASES:
+    #         with self.subTest(f"{case[0].shape}"):
+    #             self._test_is_grid_shaped_impl(*case)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py
index b46905f3c1..20d25ef61c 100644
--- a/tests/test_meta_tensor.py
+++ b/tests/test_meta_tensor.py
@@ -495,6 +495,15 @@ def test_construct_with_pre_applied_transforms(self):
         m = MetaTensor(im, applied_operations=data["im"].applied_operations)
         self.assertEqual(len(m.applied_operations), len(tr.transforms))
 
+    def test_pending_ops(self):
+        m, _ = self.get_im()
+        self.assertEqual(m.pending_operations, [])
+        self.assertEqual(m.peek_pending_shape(), (10, 8))
+        self.assertIsInstance(m.peek_pending_affine(), torch.Tensor)
+        m.push_pending_operation({})
+        self.assertEqual(m.peek_pending_shape(), (10, 8))
+        self.assertIsInstance(m.peek_pending_affine(), torch.Tensor)
+
     @parameterized.expand(TESTS)
     def test_multiprocessing(self, device=None, dtype=None):
         """multiprocessing sharing with 'device' and 'dtype'"""
diff --git a/tests/test_nifti_rw.py b/tests/test_nifti_rw.py
index 31e1de3fe9..6edf53d339 100644
--- a/tests/test_nifti_rw.py
+++ b/tests/test_nifti_rw.py
@@ -157,8 +157,8 @@ def test_write_2d(self):
             writer_obj.set_metadata({"affine": np.diag([1, 1, 1]), "original_affine": np.diag([1.4, 1, 1])})
             writer_obj.write(image_name, verbose=True)
             out = nib.load(image_name)
-            np.testing.assert_allclose(out.get_fdata(), [[0, 1, 2], [3.0, 4, 5]])
-            np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1]))
+            np.testing.assert_allclose(out.get_fdata(), [[0, 1, 2], [3.0, 4, 5]], atol=1e-4, rtol=1e-4)
+            np.testing.assert_allclose(out.affine, np.diag([1.4, 1, 1, 1]), atol=1e-4, rtol=1e-4)
 
             image_name = os.path.join(out_dir, "test1.nii.gz")
             img = np.arange(5).reshape((1, 5))
@@ -168,8 +168,8 @@ def test_write_2d(self):
             )
             writer_obj.write(image_name, verbose=True)
             out = nib.load(image_name)
-            np.testing.assert_allclose(out.get_fdata(), [[0, 2, 4]])
-            np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 1, 1]))
+            np.testing.assert_allclose(out.get_fdata(), [[0, 2, 4]], atol=1e-4, rtol=1e-4)
+            np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 1, 1]), atol=1e-4, rtol=1e-4)
 
     def test_write_3d(self):
         with tempfile.TemporaryDirectory() as out_dir:
@@ -192,8 +192,8 @@ def test_write_3d(self):
             )
             writer_obj.write(image_name, verbose=True)
             out = nib.load(image_name)
-            np.testing.assert_allclose(out.get_fdata(), [[[0, 2, 4]]])
-            np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1]))
+            np.testing.assert_allclose(out.get_fdata(), [[[0, 2, 4]]], atol=1e-4, rtol=1e-4)
+            np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1]), atol=1e-4, rtol=1e-4)
 
     def test_write_4d(self):
         with tempfile.TemporaryDirectory() as out_dir:
@@ -216,8 +216,8 @@ def test_write_4d(self):
             )
             writer_obj.write(image_name, verbose=True)
             out = nib.load(image_name)
-            np.testing.assert_allclose(out.get_fdata(), [[[[0], [2], [4]]]])
-            np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1]))
+            np.testing.assert_allclose(out.get_fdata(), [[[[0], [2], [4]]]], atol=1e-4, rtol=1e-4)
+            np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1]), atol=1e-4, rtol=1e-4)
 
     def test_write_5d(self):
         with tempfile.TemporaryDirectory() as out_dir:
@@ -241,8 +241,10 @@ def test_write_5d(self):
             writer_obj.set_metadata({"affine": np.diag([1, 1, 1, 3]), "original_affine": np.diag([1.4, 2.0, 2, 3])})
             writer_obj.write(image_name, verbose=True)
             out = nib.load(image_name)
-            np.testing.assert_allclose(out.get_fdata(), np.array([[[[[0.0, 2.0]], [[4.0, 5.0]], [[7.0, 9.0]]]]]))
-            np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1]))
+            np.testing.assert_allclose(
+                out.get_fdata(), np.array([[[[[0.0, 2.0]], [[4.0, 5.0]], [[7.0, 9.0]]]]]), atol=1e-4, rtol=1e-4
+            )
+            np.testing.assert_allclose(out.affine, np.diag([1.4, 2, 2, 1]), atol=1e-4, rtol=1e-4)
 
 
 if __name__ == "__main__":
diff --git a/tests/test_randomizable_transform_type.py b/tests/test_randomizable_transform_type.py
new file mode 100644
index 0000000000..9f77d2cd5a
--- /dev/null
+++ b/tests/test_randomizable_transform_type.py
@@ -0,0 +1,33 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from monai.transforms.transform import RandomizableTrait, RandomizableTransform
+
+
+class InheritsInterface(RandomizableTrait):
+    pass
+
+
+class InheritsImplementation(RandomizableTransform):
+    def __call__(self, data):
+        return data
+
+
+class TestRandomizableTransformType(unittest.TestCase):
+    def test_is_randomizable_transform_type(self):
+        inst = InheritsInterface()
+        self.assertIsInstance(inst, RandomizableTrait)
+
+    def test_set_random_state_randomizable_transform(self):
+        inst = InheritsImplementation()
+        inst.set_random_state(0)
diff --git a/tests/test_resample.py b/tests/test_resample.py
new file mode 100644
index 0000000000..4e33517bd1
--- /dev/null
+++ b/tests/test_resample.py
@@ -0,0 +1,42 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from monai.transforms.utility.functional import resample
+from monai.utils import convert_to_tensor
+
+from tests.utils import get_arange_img
+
+
+def rotate_45_2d():
+    t = torch.eye(3)
+    t[:, 0] = torch.FloatTensor([0, -1, 0])
+    t[:, 1] = torch.FloatTensor([1, 0, 0])
+    return t
+
+
+class TestResampleFunction(unittest.TestCase):
+    def _test_resample_function_impl(self, img, matrix):
+        result = resample(convert_to_tensor(img), matrix)
+        print(result)
+
+    RESAMPLE_FUNCTION_CASES = [(get_arange_img((1, 16, 16)), rotate_45_2d())]
+
+    def test_resample_function(self):
+        for case in self.RESAMPLE_FUNCTION_CASES:
+            self._test_resample_function_impl(*case)
diff --git a/tests/utils.py b/tests/utils.py
index b16b4b13fb..bbceb9cfc4 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -347,6 +347,24 @@ def make_rand_affine(ndim: int = 3, random_state: Optional[np.random.RandomState
     return af
 
 
+def get_arange_img(size, dtype=torch.float32, offset=0):
+    """
+    Returns a 2d or 3d image as a numpy array (complete with channel as dim 0)
+    with contents that iterate like an arange.
+    """
+    img = torch.zeros(size, dtype=dtype)
+    if len(size) == 2:
+        for j in range(size[0]):
+            for i in range(size[1]):
+                img[j, i] = i + j * size[1] + offset
+    else:
+        for k in range(size[0]):
+            for j in range(size[1]):
+                for i in range(size[2]):
+                    img[k, j, i] = i + j * size[2] + k * size[1] * size[2] + offset
+    return np.expand_dims(img, 0)
+
+
 class DistTestCase(unittest.TestCase):
     """
     testcase without _outcome, so that it's picklable.
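With the row-major indexing in `get_arange_img` above, a small example (illustrative):

    >>> get_arange_img((2, 3))
    array([[[0., 1., 2.],
            [3., 4., 5.]]], dtype=float32)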