diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 8897371903..39fac39ca2 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -11,6 +11,8 @@ from __future__ import annotations +import copy + import warnings from copy import deepcopy from typing import Any, Sequence @@ -151,6 +153,30 @@ def __init__( if MetaKeys.SPACE not in self.meta: self.meta[MetaKeys.SPACE] = SpaceKeys.RAS # defaulting to the right-anterior-superior space + self._pending_transforms = list() + + def push_pending_transform(self, meta_matrix): + self._pending_transforms.append(meta_matrix) + + @property + def has_pending_transforms(self): + return len(self._pending_transforms) > 0 + + def peek_pending_transform(self): + return copy.deepcopy(self._pending_transforms[-1]) + + def pop_pending_transform(self): + transform = self._pending_transforms[0] + self._pending_transforms.pop(0) + return transform + + @property + def pending_transforms(self): + return copy.deepcopy(self._pending_transforms) + + def clear_pending_transforms(self): + self._pending_transforms = list() + @staticmethod def update_meta(rets: Sequence, func, args, kwargs) -> Sequence: """ diff --git a/monai/transforms/atmostonce/__init__.py b/monai/transforms/atmostonce/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/monai/transforms/atmostonce/apply.py b/monai/transforms/atmostonce/apply.py new file mode 100644 index 0000000000..56c5802dda --- /dev/null +++ b/monai/transforms/atmostonce/apply.py @@ -0,0 +1,222 @@ +from typing import Optional, Sequence, Union + +import itertools as it + +import numpy as np + +import torch + +from monai.config import DtypeLike +from monai.data import MetaTensor +from monai.transforms import Affine +from monai.transforms.inverse import InvertibleTransform +from monai.transforms.transform import MapTransform +from monai.transforms.atmostonce.utils import matmul +from monai.utils import GridSampleMode, GridSamplePadMode +from monai.utils.misc import get_backend_from_data, get_device_from_data +from monai.utils.mapping_stack import MatrixFactory, MetaMatrix, Matrix + +# TODO: This should move to a common place to be shared with dictionary +from monai.utils.type_conversion import dtypes_to_str_or_identity + +GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str] +GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str] +DtypeSequence = Union[Sequence[DtypeLike], DtypeLike] + + +# TODO: move to mapping_stack.py +def extents_from_shape(shape, dtype=np.float64): + extents = [[0, shape[i]] for i in range(1, len(shape))] + + extents = it.product(*extents) + return list(np.asarray(e + (1,), dtype=dtype) for e in extents) + + +# TODO: move to mapping_stack.py +def shape_from_extents( + src_shape: Sequence, + extents: Union[Sequence[np.ndarray], Sequence[torch.Tensor], np.ndarray, torch.Tensor] +): + if isinstance(extents, (list, tuple)): + if isinstance(extents[0], np.ndarray): + aextents = np.asarray(extents) + aextents = torch.from_numpy(aextents) + else: + aextents = torch.stack(extents) + else: + if isinstance(extents, np.ndarray): + aextents = torch.from_numpy(extents) + else: + aextents = extents + + mins = aextents.min(axis=0)[0] + maxes = aextents.max(axis=0)[0] + values = torch.round(maxes - mins).type(torch.IntTensor)[:-1] + return torch.cat((torch.as_tensor([src_shape[0]]), values)) + + +def metadata_is_compatible(value_1, value_2): + if value_1 is None: + return True + else: + if 
value_2 is None:
+            return True
+        return value_1 == value_2
+
+
+def metadata_dtype_is_compatible(value_1, value_2):
+    if value_1 is None:
+        return True
+    else:
+        if value_2 is None:
+            return True
+
+    # if we are here, value_1 and value_2 are both set
+    # TODO: this is not a good enough solution
+    value_1_ = dtypes_to_str_or_identity(value_1)
+    value_2_ = dtypes_to_str_or_identity(value_2)
+    return value_1_ == value_2_
+
+
+def starting_matrix_and_extents(matrix_factory, data):
+    # set up the identity matrix and metadata
+    cumulative_matrix = matrix_factory.identity()
+    cumulative_extents = extents_from_shape(data.shape)
+    return cumulative_matrix, cumulative_extents
+
+
+def prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype):
+    kwargs = {}
+    if cur_mode is not None:
+        kwargs['mode'] = cur_mode
+    if cur_padding_mode is not None:
+        kwargs['padding_mode'] = cur_padding_mode
+    if cur_device is not None:
+        kwargs['device'] = cur_device
+    if cur_dtype is not None:
+        kwargs['dtype'] = cur_dtype
+
+    return kwargs
+
+
+def matrix_from_matrix_container(matrix):
+    if isinstance(matrix, MetaMatrix):
+        return matrix.matrix.matrix
+    elif isinstance(matrix, Matrix):
+        return matrix.matrix
+    else:
+        return matrix
+
+
+def apply(data: Union[torch.Tensor, MetaTensor],
+          pending: Optional[list] = None):
+
+    if isinstance(data, dict):
+        rd = dict()
+        for k, v in data.items():
+            rd[k] = apply(v)
+        return rd
+
+    # fall back to the transforms queued on the tensor if none are supplied
+    pending_ = data.pending_transforms if pending is None else pending
+
+    if len(pending_) == 0:
+        return data
+
+    dim_count = len(data.shape) - 1
+    matrix_factory = MatrixFactory(dim_count,
+                                   get_backend_from_data(data),
+                                   get_device_from_data(data))
+
+    # set up the identity matrix and metadata
+    cumulative_matrix, cumulative_extents = starting_matrix_and_extents(matrix_factory, data)
+
+    # set the various resampling parameters to an initial state
+    cur_mode = None
+    cur_padding_mode = None
+    cur_device = None
+    cur_dtype = None
+    cur_shape = data.shape
+
+    for meta_matrix in pending_:
+        next_matrix = meta_matrix.matrix
+        cumulative_matrix = matmul(cumulative_matrix, next_matrix)
+        cumulative_extents = [matmul(e, cumulative_matrix) for e in cumulative_extents]
+
+        new_mode = meta_matrix.metadata.get('mode', None)
+        new_padding_mode = meta_matrix.metadata.get('padding_mode', None)
+        new_device = meta_matrix.metadata.get('device', None)
+        new_dtype = meta_matrix.metadata.get('dtype', None)
+        new_shape = meta_matrix.metadata.get('shape_override', None)
+
+        mode_compat = metadata_is_compatible(cur_mode, new_mode)
+        padding_mode_compat = metadata_is_compatible(cur_padding_mode, new_padding_mode)
+        device_compat = metadata_is_compatible(cur_device, new_device)
+        dtype_compat = metadata_dtype_is_compatible(cur_dtype, new_dtype)
+
+        if (mode_compat is False or padding_mode_compat is False or
+                device_compat is False or dtype_compat is False):
+            # carry out an intermediate resample here due to incompatibility between arguments
+            kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype)
+
+            cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix)
+            a = Affine(norm_coords=False,
+                       affine=cumulative_matrix_,
+                       **kwargs)
+            data, _ = a(img=data)
+
+            cumulative_matrix, cumulative_extents =\
+                starting_matrix_and_extents(matrix_factory, data)
+
+        cur_mode = cur_mode if new_mode is None else new_mode
+        cur_padding_mode = cur_padding_mode if new_padding_mode is None else new_padding_mode
+        cur_device = cur_device if new_device is None else new_device
+        cur_dtype = cur_dtype if new_dtype is None else new_dtype
+        cur_shape = cur_shape if new_shape is None else new_shape
+
+    kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype)
+
+    cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix)
+
+    a = Affine(norm_coords=False,
+               affine=cumulative_matrix_,
+               spatial_size=cur_shape[1:],
+               normalized=False,
+               **kwargs)
+    data, tx = a(img=data)
+    data.clear_pending_transforms()
+
+    return data
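+
+
+# A minimal usage sketch of the lazy pipeline (illustrative only; the tensor
+# shape and transform arguments are hypothetical). Lazy transforms queue a
+# MetaMatrix per call instead of resampling immediately:
+#
+#     img = MetaTensor(torch.rand(1, 32, 32, 32))
+#     img = Rotate(0.5, lazy_evaluation=True)(img)   # queues a rotation matrix
+#     img = Zoom(2.0, lazy_evaluation=True)(img)     # queues a scaling matrix
+#     img = apply(img)                               # one fused Affine resample
+#
+# Both matrices are fused before Affine is invoked, so the data is
+# interpolated once rather than once per transform.
+
+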
+# make Apply universal for arrays and dictionaries; it just calls through to functional apply
+class Apply(InvertibleTransform):
+
+    def __init__(self):
+        super().__init__()
+
+    def __call__(self, *args, **kwargs):
+        return apply(*args, **kwargs)
+
+    def inverse(self, data):
+        raise NotImplementedError()
+
+
+class Applyd(MapTransform, InvertibleTransform):
+
+    def __init__(self):
+        super().__init__()
+
+    def __call__(
+        self,
+        d: dict
+    ):
+        rd = dict()
+        for k, v in d.items():
+            rd[k] = apply(v)
+
+        return rd
+
+    def inverse(self, data):
+        raise NotImplementedError()
diff --git a/monai/transforms/atmostonce/array.py b/monai/transforms/atmostonce/array.py
new file mode 100644
index 0000000000..aa6c4ed3ff
--- /dev/null
+++ b/monai/transforms/atmostonce/array.py
@@ -0,0 +1,969 @@
+from typing import Any, Optional, Sequence, Tuple, Union
+
+import numpy as np
+
+import torch
+from monai.networks.utils import meshgrid_ij
+
+from monai.transforms.spatial.array import RandRange
+
+from monai.config import DtypeLike, NdarrayOrTensor
+from monai.data import MetaTensor
+
+from monai.transforms import InvertibleTransform, RandomizableTransform
+
+from monai.transforms.atmostonce.apply import apply
+from monai.transforms.atmostonce.functional import resize, rotate, zoom, spacing, croppad, translate, rotate90, flip, \
+    identity, grid_distortion, elastic_3d
+from monai.transforms.atmostonce.lazy_transform import LazyTransform
+from monai.transforms.atmostonce.randomizers import RotateRandomizer, Elastic3DRandomizer
+from monai.transforms.atmostonce.utility import IMultiSampleTransform, ILazyTransform, IRandomizableTransform
+from monai.transforms.atmostonce.utils import value_to_tuple_range
+
+from monai.utils import (GridSampleMode, GridSamplePadMode,
+                         InterpolateMode, NumpyPadMode, PytorchPadMode, look_up_option)
+from monai.utils.mapping_stack import MetaMatrix
+from monai.utils.misc import ensure_tuple, ensure_tuple_rep
+
+
+# TODO: these transforms are intended to replace array transforms once development is done
+
+
+class Identity(LazyTransform, InvertibleTransform):
+
+    def __init__(
+        self,
+        mode: Optional[Union[GridSampleMode, str]] = None,
+        padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
+        dtype: Optional[Union[DtypeLike, torch.dtype]] = None,
+        lazy_evaluation: Optional[bool] = False
+    ):
+        LazyTransform.__init__(self, lazy_evaluation)
+        self.mode = mode
+        self.padding_mode = padding_mode
+        self.dtype = dtype
+
+    def __call__(
+        self,
+        img: torch.Tensor,
+        mode: Optional[Union[GridSampleMode, str]] = None,
+        padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
+        dtype: Optional[Union[DtypeLike, torch.dtype]] = None
+    ):
+        mode_ = mode or self.mode
+        padding_mode_ = padding_mode or self.padding_mode
+        dtype_ = dtype or self.dtype
+
+        img_t, transform, metadata = identity(img, mode_, padding_mode_, dtype_)
+
+        # TODO: candidate for refactoring into a LazyTransform method
+        img_t.push_pending_transform(MetaMatrix(transform, metadata))
+        if not self.lazy_evaluation:
+            img_t = apply(img_t)
+
+        return img_t
+
+    def inverse(self, data):
+        raise NotImplementedError()
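+
+
+# Every transform in this module follows the pattern shown in Identity above:
+# compute a transform matrix, push it onto the MetaTensor's pending stack, and
+# only resample when lazy evaluation is off. A sketch (hypothetical shapes):
+#
+#     t = Identity(lazy_evaluation=True)
+#     img = t(MetaTensor(torch.rand(1, 16, 16)))
+#     img.has_pending_transforms   # True - nothing has been resampled yet
+#
+#     t = Identity(lazy_evaluation=False)
+#     img = t(MetaTensor(torch.rand(1, 16, 16)))
+#     img.has_pending_transforms   # False - apply() ran eagerly
+
+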
+# spatial
+# =======
+
+# TODO: why doesn't Spacing have antialiasing options?
+class Spacing(LazyTransform, InvertibleTransform):
+
+    def __init__(
+        self,
+        pixdim: Union[Sequence[float], float, np.ndarray],
+        src_pixdim: Optional[Union[Sequence[float], float, np.ndarray]],
+        diagonal: Optional[bool] = False,
+        mode: Optional[Union[GridSampleMode, str]] = GridSampleMode.BILINEAR,
+        padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER,
+        align_corners: Optional[bool] = False,
+        dtype: Optional[DtypeLike] = np.float64,
+        lazy_evaluation: Optional[bool] = False,
+        shape_override: Optional[Sequence] = None
+    ):
+        LazyTransform.__init__(self, lazy_evaluation)
+        self.pixdim = pixdim
+        self.src_pixdim = src_pixdim
+        self.diagonal = diagonal
+        self.mode = mode
+        self.padding_mode = padding_mode
+        self.align_corners = align_corners
+        self.dtype = dtype
+
+    def __call__(
+        self,
+        img: NdarrayOrTensor,
+        mode: Optional[Union[GridSampleMode, str]] = None,
+        padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
+        align_corners: Optional[bool] = None,
+        dtype: Optional[Union[DtypeLike, torch.dtype]] = None,
+        shape_override: Optional[Sequence] = None
+    ):
+        mode_ = mode or self.mode
+        padding_mode_ = padding_mode or self.padding_mode
+        align_corners_ = align_corners if align_corners is not None else self.align_corners
+        dtype_ = dtype or self.dtype
+
+        shape_override_ = shape_override
+        if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms:
+            shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None)
+
+        img_t, transform, metadata = spacing(img, self.pixdim, self.src_pixdim, self.diagonal,
+                                             mode_, padding_mode_, align_corners_, dtype_,
+                                             shape_override_)
+
+        # TODO: candidate for refactoring into a LazyTransform method
+        img_t.push_pending_transform(MetaMatrix(transform, metadata))
+        if not self.lazy_evaluation:
+            img_t = apply(img_t)
+
+        return img_t
+
+    def inverse(self, data):
+        raise NotImplementedError()
+
+
+class Flip(LazyTransform, InvertibleTransform):
+
+    def __init__(
+        self,
+        spatial_axis: Optional[Union[Sequence[int], int]] = None,
+        lazy_evaluation: Optional[bool] = True
+    ) -> None:
+        LazyTransform.__init__(self, lazy_evaluation)
+        self.spatial_axis = spatial_axis
+
+    def __call__(
+        self,
+        img: NdarrayOrTensor,
+        spatial_axis: Optional[Union[Sequence[int], int]] = None,
+        shape_override: Optional[Sequence] = None
+    ):
+        spatial_axis_ = spatial_axis if spatial_axis is not None else self.spatial_axis
+        shape_override_ = shape_override
+        if (shape_override_ is None and
+                isinstance(img, MetaTensor) and img.has_pending_transforms):
+            shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None)
+
+        img_t, transform, metadata = flip(img, spatial_axis_, shape_override_)
+
+        # TODO: candidate for refactoring into a LazyTransform method
+        img_t.push_pending_transform(MetaMatrix(transform, metadata))
+        if not self.lazy_evaluation:
+            img_t = apply(img_t)
+
+        return img_t
+
+    def inverse(self, data):
+        raise NotImplementedError()
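+
+
+# Note on `shape_override`: once earlier lazy transforms are pending, the
+# current spatial shape can no longer be read from `img.shape`, so each
+# transform peeks at the metadata of the most recently queued operation.
+# This is the lookup used throughout this module:
+#
+#     if isinstance(img, MetaTensor) and img.has_pending_transforms:
+#         shape = img.peek_pending_transform().metadata.get("shape_override", None)
+
+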
+class Resize(LazyTransform, InvertibleTransform):
+
+    def __init__(
+        self,
+        spatial_size: Union[Sequence[int], int],
+        size_mode: Optional[str] = "all",
+        mode: Union[InterpolateMode, str] = InterpolateMode.AREA,
+        align_corners: Optional[bool] = False,
+        anti_aliasing: Optional[bool] = False,
+        anti_aliasing_sigma: Optional[Union[Sequence[float], float, None]] = None,
+        dtype: Optional[Union[DtypeLike, torch.dtype]] = np.float32,
+        lazy_evaluation: Optional[bool] = False
+    ):
+        LazyTransform.__init__(self, lazy_evaluation)
+        self.spatial_size = spatial_size
+        self.size_mode = size_mode
+        self.mode = mode
+        self.align_corners = align_corners
+        self.anti_aliasing = anti_aliasing
+        self.anti_aliasing_sigma = anti_aliasing_sigma
+        self.dtype = dtype
+
+    def __call__(
+        self,
+        img: NdarrayOrTensor,
+        mode: Optional[Union[InterpolateMode, str]] = None,
+        align_corners: Optional[bool] = None,
+        anti_aliasing: Optional[bool] = None,
+        anti_aliasing_sigma: Union[Sequence[float], float, None] = None,
+        shape_override: Optional[Sequence] = None
+    ) -> NdarrayOrTensor:
+        mode_ = mode or self.mode
+        align_corners_ = align_corners if align_corners is not None else self.align_corners
+        anti_aliasing_ = anti_aliasing if anti_aliasing is not None else self.anti_aliasing
+        anti_aliasing_sigma_ = anti_aliasing_sigma or self.anti_aliasing_sigma
+
+        shape_override_ = shape_override
+        if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms:
+            shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None)
+
+        img_t, transform, metadata = resize(img, self.spatial_size, self.size_mode, mode_,
+                                            align_corners_, anti_aliasing_, anti_aliasing_sigma_,
+                                            self.dtype, shape_override_)
+
+        # TODO: candidate for refactoring into a LazyTransform method
+        img_t.push_pending_transform(MetaMatrix(transform, metadata))
+        if not self.lazy_evaluation:
+            img_t = apply(img_t)
+
+        return img_t
+
+    def inverse(self, data):
+        raise NotImplementedError()
+
+
+class Rotate(LazyTransform, InvertibleTransform):
+
+    def __init__(
+        self,
+        angle: Union[Sequence[float], float],
+        keep_size: bool = True,
+        mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR,
+        padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
+        align_corners: bool = False,
+        dtype: Union[DtypeLike, torch.dtype] = np.float32,
+        lazy_evaluation: Optional[bool] = False
+    ):
+        LazyTransform.__init__(self, lazy_evaluation)
+        self.angle = angle
+        self.keep_size = keep_size
+        self.mode = mode
+        self.padding_mode = padding_mode
+        self.align_corners = align_corners
+        self.dtype = dtype
+
+    def __call__(
+        self,
+        img: NdarrayOrTensor,
+        angle: Optional[Union[Sequence[float], float]] = None,
+        mode: Optional[Union[InterpolateMode, str]] = None,
+        padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None,
+        align_corners: Optional[bool] = None,
+        shape_override: Optional[Sequence] = None
+    ) -> NdarrayOrTensor:
+        # call-time arguments take precedence over the values set at construction
+        angle_ = angle if angle is not None else self.angle
+        mode_ = mode or self.mode
+        padding_mode_ = padding_mode or self.padding_mode
+        align_corners_ = align_corners if align_corners is not None else self.align_corners
+        keep_size = self.keep_size
+        dtype_ = self.dtype
+
+        shape_override_ = shape_override
+        if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms:
+            shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None)
+
+        # TODO: We should be tracking random rotate rather than just rotate
+        img_t, transform, metadata = rotate(img, angle_, keep_size, mode_, padding_mode_,
+                                            align_corners_, dtype_, shape_override_)
+
+        # TODO: candidate for refactoring into a LazyTransform method
+        img_t.push_pending_transform(MetaMatrix(transform, metadata))
+        if not self.lazy_evaluation:
+            img_t = apply(img_t)
+
+        return img_t
+
+    def inverse(self, data):
+        raise NotImplementedError()
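+
+
+# When consecutive pending transforms disagree on resampling arguments (for
+# example a bilinear rotation followed by a nearest-neighbour zoom), apply()
+# cannot fuse them into one resample: metadata_is_compatible detects the clash
+# and inserts an intermediate resample. An illustrative sketch:
+#
+#     img = Rotate(0.5, mode="bilinear", lazy_evaluation=True)(img)
+#     img = Zoom(2.0, mode="nearest", lazy_evaluation=True)(img)
+#     img = apply(img)   # resamples twice, once per interpolation mode
+
+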
+class Zoom(LazyTransform, InvertibleTransform):
+    """
+    Zoom into / out of the image by applying `factor` as a scalar, or, if
+    `factor` is a tuple of values, applying each zoom factor to the
+    corresponding dimension.
+    """
+
+    def __init__(
+        self,
+        factor: Union[Sequence[float], float],
+        mode: Union[InterpolateMode, str] = InterpolateMode.BILINEAR,
+        padding_mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.EDGE,
+        align_corners: Optional[bool] = None,
+        keep_size: Optional[bool] = True,
+        dtype: Union[DtypeLike, torch.dtype] = np.float32,
+        lazy_evaluation: Optional[bool] = True,
+        **kwargs
+    ):
+        LazyTransform.__init__(self, lazy_evaluation)
+        self.factor = factor
+        self.mode: InterpolateMode = InterpolateMode(mode)
+        self.padding_mode = padding_mode
+        self.align_corners = align_corners
+        self.keep_size = keep_size
+        self.dtype = dtype
+        self.kwargs = kwargs
+
+    def __call__(
+        self,
+        img: NdarrayOrTensor,
+        mode: Optional[Union[InterpolateMode, str]] = None,
+        padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None,
+        align_corners: Optional[bool] = None,
+        factor: Optional[Union[Sequence[float], float]] = None,
+        shape_override: Optional[Sequence] = None
+    ) -> NdarrayOrTensor:
+
+        factor = self.factor if factor is None else factor
+        mode = self.mode if mode is None else mode
+        padding_mode = self.padding_mode if padding_mode is None else padding_mode
+        align_corners = self.align_corners if align_corners is None else align_corners
+        keep_size = self.keep_size
+        dtype = self.dtype
+
+        shape_override_ = shape_override
+        if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms:
+            shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None)
+
+        img_t, transform, metadata = zoom(img, factor, mode, padding_mode, align_corners,
+                                          keep_size, dtype, shape_override_)
+
+        # TODO: candidate for refactoring into a LazyTransform method
+        img_t.push_pending_transform(MetaMatrix(transform, metadata))
+        if not self.lazy_evaluation:
+            img_t = apply(img_t)
+
+        return img_t
+
+    def inverse(self, data):
+        raise NotImplementedError()
+
+
+class Rotate90(InvertibleTransform, LazyTransform):
+
+    def __init__(
+        self,
+        k: Optional[int] = 1,
+        spatial_axes: Optional[Tuple[int, int]] = (0, 1),
+        lazy_evaluation: Optional[bool] = True,
+    ) -> None:
+        LazyTransform.__init__(self, lazy_evaluation)
+        self.k = k
+        self.spatial_axes = spatial_axes
+
+    def __call__(
+        self,
+        img: torch.Tensor,
+        k: Optional[int] = None,
+        spatial_axes: Optional[Tuple[int, int]] = None,
+        shape_override: Optional[Sequence[int]] = None
+    ) -> torch.Tensor:
+        # `k` may legitimately be 0, so test against None rather than truthiness
+        k_ = k if k is not None else self.k
+        spatial_axes_ = spatial_axes if spatial_axes is not None else self.spatial_axes
+
+        shape_override_ = shape_override
+        if (shape_override_ is None and
+                isinstance(img, MetaTensor) and img.has_pending_transforms):
+            shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None)
+
+        img_t, transform, metadata = rotate90(img, k_, spatial_axes_, shape_override_)
+
+        # TODO: candidate for refactoring into a LazyTransform method
+        img_t.push_pending_transform(MetaMatrix(transform, metadata))
+        if not self.lazy_evaluation:
+            img_t = apply(img_t)
+
+        return img_t
+
+    def inverse(self, data):
+        raise NotImplementedError()
+
+
+class RandRotate90(RandomizableTransform, InvertibleTransform, LazyTransform):
+
+    def __init__(
+        self,
+        prob: float = 0.1,
+        max_k: int = 3,
+        spatial_axes: Tuple[int, int] = (0, 1),
+        lazy_evaluation: Optional[bool] = True
+    ) -> None:
+        RandomizableTransform.__init__(self, prob)
+        self.max_k = max_k
+        self.spatial_axes = spatial_axes
+
+        self.k = 0
+
+        self.op = Rotate90(0, spatial_axes, lazy_evaluation)
+
+    def randomize(self, data: Optional[Any] = None) -> None:
+        super().randomize(None)
+        if self._do_transform:
+            self.k = self.R.randint(self.max_k) + 1
+
+    def __call__(
+        self,
+        img: torch.Tensor,
+        randomize: bool = True,
+        shape_override: Optional[Sequence] = None
+    ) -> torch.Tensor:
+
+        if randomize:
+            self.randomize()
+
+        k = self.k if self._do_transform else 0
+
+        return self.op(img, k, shape_override=shape_override)
+
+    def inverse(
+        self,
+        data: NdarrayOrTensor,
+    ):
+        raise NotImplementedError()
+
+
+class RandRotate(InvertibleTransform, ILazyTransform, IRandomizableTransform):
+
+    def __init__(
+        self,
+        range_x: Optional[Union[Tuple[float, float], float]] = 0.0,
+        range_y: Optional[Union[Tuple[float, float], float]] = 0.0,
+        range_z: Optional[Union[Tuple[float, float], float]] = 0.0,
+        prob: Optional[float] = 0.1,
+        keep_size: Optional[bool] = True,
+        mode: Optional[Union[GridSampleMode, str]] = GridSampleMode.BILINEAR,
+        padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER,
+        align_corners: Optional[bool] = False,
+        dtype: Optional[Union[DtypeLike, torch.dtype]] = np.float32,
+        lazy_evaluation: Optional[bool] = True
+    ):
+        self.randomizer = RotateRandomizer(value_to_tuple_range(range_x),
+                                           value_to_tuple_range(range_y),
+                                           value_to_tuple_range(range_z),
+                                           prob)
+
+        self.op = Rotate(0, keep_size, mode, padding_mode, align_corners, dtype, lazy_evaluation)
+
+    def __call__(
+        self,
+        img: NdarrayOrTensor,
+        mode: Optional[Union[InterpolateMode, str]] = None,
+        padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None,
+        align_corners: Optional[bool] = None,
+        dtype: Optional[Union[DtypeLike, torch.dtype]] = None,
+        randomize: Optional[bool] = True,
+        shape_override: Optional[Sequence] = None
+    ) -> NdarrayOrTensor:
+
+        angles = self.randomizer.sample(img)
+
+        return self.op(img, angles, mode, padding_mode, align_corners, shape_override)
+
+    @property
+    def lazy_evaluation(self):
+        return self.op.lazy_evaluation
+
+    @lazy_evaluation.setter
+    def lazy_evaluation(self, value):
+        self.op.lazy_evaluation = value
+
+    def inverse(
+        self,
+        data: NdarrayOrTensor,
+    ):
+        raise NotImplementedError()
+
+
+class RandFlip(RandomizableTransform, InvertibleTransform, LazyTransform):
+
+    def __init__(
+        self,
+        prob: float = 0.1,
+        spatial_axis: Optional[Union[Sequence[int], int]] = None
+    ) -> None:
+        RandomizableTransform.__init__(self, prob)
+        self.prob = prob
+        self.spatial_axis = spatial_axis
+        self.do_flip = False
+        self.op = Flip(spatial_axis)
+        self.nop = Identity()
+
+    def randomize(self, data: Optional[Any] = None) -> None:
+        super().randomize(None)
+        self.do_flip = self._do_transform
+
+    def __call__(
+        self,
+        img: NdarrayOrTensor,
+        randomize: Optional[bool] = True
+    ):
+        if randomize:
+            self.randomize()
+
+        if self.do_flip is True:
+            return self.op(img, self.spatial_axis)
+
+        return self.nop(img)
+
+    def inverse(
+        self,
+        data: NdarrayOrTensor,
+    ):
+        raise NotImplementedError()
+
+
+class RandAxisFlip(RandomizableTransform, InvertibleTransform, LazyTransform):
+
+    def __init__(
+        self,
+        prob: float = 0.1
+    ) -> None:
+        RandomizableTransform.__init__(self, prob)
+        self.prob = prob
+        self.spatial_axis = None
+        self.op = Flip(self.spatial_axis)
+
+    def randomize(
+        self,
+        data: Optional[Any] = None
+    ) -> None:
+        super().randomize(None)
+        if self._do_transform:
+            self.spatial_axis = self.R.randint(0, data.ndim - 1)
+
+    def __call__(
+        self,
+        img: NdarrayOrTensor,
+        randomize: Optional[bool] = True
+    ) -> NdarrayOrTensor:
+        if randomize:
+            # randomize needs the image in order to sample a valid axis
+            self.randomize(img)
+
+        if self._do_transform:
+            spatial_axis = self.spatial_axis
+        else:
+            spatial_axis = None
+
+        return self.op(img, spatial_axis)
+
+    def inverse(
+        self,
+        data: NdarrayOrTensor,
+    ):
+        raise NotImplementedError()
+
+
+class RandZoom(RandomizableTransform, InvertibleTransform, ILazyTransform):
+
+    def __init__(
+        self,
+        prob: float = 0.1,
+        min_zoom: Optional[Union[Sequence[float], float]] = 0.9,
+        max_zoom: Optional[Union[Sequence[float], float]] = 1.1,
+        mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.AREA,
+        padding_mode: Optional[Union[GridSamplePadMode, NumpyPadMode, str]] = NumpyPadMode.EDGE,
+        align_corners: Optional[bool] = None,
+        keep_size: bool = True,
+        lazy_evaluation: Optional[bool] = True,
+        **kwargs
+    ) -> None:
+        RandomizableTransform.__init__(self, prob)
+        self.prob = prob
+        self.min_zoom = ensure_tuple(min_zoom)
+        self.max_zoom = ensure_tuple(max_zoom)
+        if len(self.min_zoom) != len(self.max_zoom):
+            raise AssertionError("min_zoom and max_zoom must have the same length "
+                                 f"but are {min_zoom} and {max_zoom} respectively")
+        self.mode = look_up_option(mode, InterpolateMode)
+        self.padding_mode = padding_mode
+        self.align_corners = align_corners
+        self.keep_size = keep_size
+        self.factors = None
+
+        self.op = Zoom(1.0, self.mode, self.padding_mode, self.align_corners, self.keep_size,
+                       lazy_evaluation=lazy_evaluation)
+
+    def randomize(
+        self,
+        data: Optional[Any] = None
+    ) -> None:
+        super().randomize(None)
+        if self._do_transform:
+            self.factors = [self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)]
+            if len(self.factors) == 1:
+                # to keep the spatial shape ratio, use same random zoom factor for all dims
+                self.factors = ensure_tuple_rep(self.factors[0], data.ndim - 1)
+            elif len(self.factors) == 2 and data.ndim > 3:
+                # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim
+                self.factors =\
+                    ensure_tuple_rep(self.factors[0], data.ndim - 2) + ensure_tuple(self.factors[-1])
+
+    def __call__(
+        self,
+        img: NdarrayOrTensor,
+        randomize: Optional[bool] = True
+    ) -> NdarrayOrTensor:
+        if randomize:
+            self.randomize(img)
+
+        if self._do_transform:
+            factors_ = self.factors
+        else:
+            factors_ = 1.0
+
+        return self.op(img, factor=factors_)
+
+    def inverse(
+        self,
+        data: NdarrayOrTensor,
+    ):
+        raise NotImplementedError()
+
+
+class GridDistortion(LazyTransform):
+
+    def __init__(
+        self,
+        num_cells: Union[Tuple[int], int],
+        distort_steps: Sequence[Sequence[float]],
+        mode: Optional[Union[GridSampleMode, str]] = GridSampleMode.BILINEAR,
+        padding_mode: Optional[Union[GridSamplePadMode, NumpyPadMode, str]] = NumpyPadMode.EDGE,
+        lazy_evaluation: Optional[bool] = True
+    ):
+        LazyTransform.__init__(self, lazy_evaluation)
+        self.num_cells = num_cells
+        self.distort_steps = distort_steps
+        self.mode = mode
+        self.padding_mode = padding_mode
+
+    def __call__(
+        self,
+        img: torch.Tensor,
+        distort_steps: Optional[Sequence[Sequence[float]]] = None,
+        mode: Optional[Union[GridSampleMode, str]] = None,
+        padding_mode: Optional[Union[GridSamplePadMode, NumpyPadMode, str]] = None,
+    ):
+        distort_steps_ = self.distort_steps if distort_steps is None else distort_steps
+        num_cells_ = ensure_tuple_rep(self.num_cells, len(img.shape) - 1)
+        mode_ = mode or self.mode
+        padding_mode_ = padding_mode or self.padding_mode
+
+        shape_override_ = None
+        if isinstance(img, MetaTensor) and img.has_pending_transforms:
+            shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None)
+
+        img_t, transform, metadata = grid_distortion(img, num_cells_, distort_steps_,
+                                                     mode_, padding_mode_, shape_override_)
+
+        # TODO: candidate for refactoring into a LazyTransform method
+        img_t.push_pending_transform(MetaMatrix(transform, metadata))
+        if not self.lazy_evaluation:
+            img_t = apply(img_t)
+
+        return img_t
+
+    def inverse(self, data):
+        raise NotImplementedError()
+
+
+class Rand3DElastic(LazyTransform, IRandomizableTransform):
+
+    def __init__(
+        self,
+        sigma_range: Tuple[float, float],
+        magnitude_range: Tuple[float, float],
+        prob: float = 0.1,
+        rotate_range: RandRange = None,
+        shear_range: RandRange = None,
+        translate_range: RandRange = None,
+        scale_range: RandRange = None,
+        spatial_size: Optional[Union[Tuple[int, int, int], int]] = None,
+        mode: Optional[Union[GridSampleMode, str]] = GridSampleMode.BILINEAR,
+        padding_mode: Optional[Union[GridSamplePadMode, NumpyPadMode, str]] = GridSamplePadMode.REFLECTION,
+        as_tensor_output: bool = False,
+        device: Optional[torch.device] = None,
+        lazy_evaluation: Optional[bool] = True
+    ):
+        LazyTransform.__init__(self, lazy_evaluation=lazy_evaluation)
+        self.spatial_size = spatial_size
+        self.mode = mode
+        self.padding_mode = padding_mode
+        self.device = device
+
+        self.randomizer = Elastic3DRandomizer(sigma_range, magnitude_range, prob)
+
+        self.nop = Identity()
+
+    def __call__(
+        self,
+        img: torch.Tensor,
+        spatial_size: Optional[Union[Tuple[int, int, int], int]] = None,
+        mode: Optional[Union[GridSampleMode, str]] = None,
+        padding_mode: Optional[Union[GridSamplePadMode, NumpyPadMode, str]] = None,
+        randomize: Optional[bool] = True,
+        shape_override: Optional[Tuple[int]] = None
+    ):
+        mode_ = mode or self.mode
+        padding_mode_ = padding_mode or self.padding_mode
+        spatial_size_ = spatial_size or self.spatial_size or img.shape[1:]
+
+        shape_override_ = shape_override
+        if shape_override is None and (isinstance(img, MetaTensor) and img.has_pending_transforms):
+            shape_override_ = img.peek_pending_transform().metadata.get("shape_override")
+
+        rand_offsets, magnitude, sigma = self.randomizer.sample(spatial_size_, self.device)
+        if rand_offsets is None:
+            return self.nop(img)
+
+        img_t, transform, metadata = elastic_3d(img,
+                                                sigma, magnitude, rand_offsets,
+                                                spatial_size_, mode_, padding_mode_,
+                                                self.device, shape_override=shape_override_)
+
+        # TODO: candidate for refactoring into a LazyTransform method
+        img_t.push_pending_transform(MetaMatrix(transform, metadata))
+        if not self.lazy_evaluation:
+            img_t = apply(img_t)
+
+        return img_t
+
+    def inverse(self, data):
+        raise NotImplementedError()
+
+
+class Translate(LazyTransform, InvertibleTransform):
+    def __init__(
+        self,
+        translation: Union[Sequence[float], float],
+        mode: Union[InterpolateMode, str] = InterpolateMode.AREA,
+        padding_mode: Union[NumpyPadMode, PytorchPadMode, str] = NumpyPadMode.EDGE,
+        dtype: Union[DtypeLike, torch.dtype] = np.float32,
+        lazy_evaluation: Optional[bool] = True,
+        **kwargs
+    ):
+        LazyTransform.__init__(self, lazy_evaluation)
+        self.translation = translation
+        self.mode: InterpolateMode = InterpolateMode(mode)
+        self.padding_mode = padding_mode
+        self.dtype = dtype
+        self.kwargs = kwargs
+
+    def __call__(
+        self,
+        img: NdarrayOrTensor,
+        mode: Optional[Union[InterpolateMode, str]] = None,
+        padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None,
+        shape_override: Optional[Sequence] = None
+    ) -> NdarrayOrTensor:
+        # call-time arguments take precedence over the values set at construction
+        mode_ = mode or self.mode
+        padding_mode_ = padding_mode or self.padding_mode
+        dtype_ = self.dtype
+
+        shape_override_ = shape_override
+        if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms:
+            shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None)
+
+        img_t, transform, metadata = translate(img, self.translation,
+                                               mode_, padding_mode_, dtype_, shape_override_)
+
+        # TODO: candidate for refactoring into a LazyTransform method
+        img_t.push_pending_transform(MetaMatrix(transform, metadata))
+        if not self.lazy_evaluation:
+            img_t = apply(img_t)
+
+        return img_t
+
+    def inverse(self, data):
+        raise NotImplementedError()
+
+
+# croppad
+# =======
+
+class CropPad(LazyTransform, InvertibleTransform):
+
+    def __init__(
+        self,
+        slices: Optional[Sequence[slice]] = None,
+        padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER,
+        lazy_evaluation: Optional[bool] = True
+    ):
+        LazyTransform.__init__(self, lazy_evaluation)
+        self.slices = slices
+        self.padding_mode = padding_mode
+
+    def __call__(
+        self,
+        img: NdarrayOrTensor,
+        slices: Optional[Sequence[slice]] = None,
+        shape_override: Optional[Sequence] = None
+    ):
+        # call-time slices take precedence over the value set at construction
+        slices_ = self.slices if slices is None else slices
+
+        shape_override_ = shape_override
+        if shape_override_ is None and isinstance(img, MetaTensor) and img.has_pending_transforms:
+            shape_override_ = img.peek_pending_transform().metadata.get("shape_override", None)
+
+        img_t, transform, metadata = croppad(img, slices_, self.padding_mode, shape_override_)
+
+        # TODO: candidate for refactoring into a LazyTransform method
+        img_t.push_pending_transform(MetaMatrix(transform, metadata))
+        if not self.lazy_evaluation:
+            img_t = apply(img_t)
+
+        return img_t
+
+    def inverse(
+        self,
+        data: NdarrayOrTensor
+    ):
+        raise NotImplementedError()
+
+
+class RandomCropPad(InvertibleTransform, RandomizableTransform, ILazyTransform):
+
+    def __init__(
+        self,
+        sizes: Union[Sequence[int], int],
+        prob: Optional[float] = 0.1,
+        padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER,
+        lazy_evaluation: Optional[bool] = True
+    ):
+        RandomizableTransform.__init__(self, prob)
+        self.sizes = sizes
+        self.padding_mode = padding_mode
+        self.offsets = None
+
+        self.op = CropPad(padding_mode=padding_mode, lazy_evaluation=lazy_evaluation)
+
+    def randomize(
+        self,
+        img: torch.Tensor
+    ):
+        super().randomize(None)
+        if self._do_transform:
+            img_shape = img.shape[1:]
+            if isinstance(self.sizes, int):
+                crop_shape = tuple(self.sizes for _ in range(len(img_shape)))
+            else:
+                crop_shape = self.sizes
+
+            valid_ranges = tuple(i - c for i, c in zip(img_shape, crop_shape))
+            self.offsets = tuple(self.R.randint(0, r + 1) if r > 0 else r for r in valid_ranges)
+
+    def __call__(
+        self,
+        img: torch.Tensor,
+        randomize: Optional[bool] = True
+    ):
+        # spatial shape excludes the channel dimension
+        img_shape = img.shape[1:]
+        sizes = ensure_tuple_rep(self.sizes, len(img_shape)) if isinstance(self.sizes, int) else self.sizes
+
+        if randomize:
+            self.randomize(img)
+
+        if self._do_transform:
+            offsets_ = self.offsets
+        else:
+            # center crop if this sample isn't random
+            offsets_ = tuple((i - s) // 2 for i, s in zip(img_shape, sizes))
+
+        slices = tuple(slice(o, o + s) for o, s in zip(offsets_, sizes))
+        return self.op(img, slices=slices)
+
+    def inverse(
+        self,
+        data: NdarrayOrTensor
+    ):
+        raise NotImplementedError()
+
+    @property
+    def lazy_evaluation(self):
+        return self.op.lazy_evaluation
+
+
+class RandomCropPadMultiSample(
+    InvertibleTransform, ILazyTransform, IRandomizableTransform, IMultiSampleTransform
+):
+
+    def __init__(
+        self,
+        sizes: Union[Sequence[int], int],
+        sample_count: int,
+        padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER,
+        lazy_evaluation: Optional[bool] = True
+    ):
+        self.sample_count = sample_count
+        self.op = RandomCropPad(sizes, 1.0, padding_mode, lazy_evaluation)
+
+    def __call__(
+        self,
+        img: torch.Tensor,
+        randomize: Optional[bool] = True
+    ):
+        for _ in range(self.sample_count):
+            yield self.op(img, randomize)
+
+    def inverse(
+        self,
+        data: NdarrayOrTensor
+    ):
+        raise NotImplementedError()
+
+    def set_random_state(self, seed=None, state=None):
+        self.op.set_random_state(seed, state)
+
+    @property
+    def lazy_evaluation(self):
+        return self.op.lazy_evaluation
diff --git a/monai/transforms/atmostonce/compose.py b/monai/transforms/atmostonce/compose.py
new file mode 100644
index 0000000000..cd6597769c
--- /dev/null
+++ b/monai/transforms/atmostonce/compose.py
@@ -0,0 +1,261 @@
+import warnings
+from typing import Any, Callable, Optional, Sequence, Union
+
+import numpy as np
+
+from monai.transforms.atmostonce.apply import Apply
+from monai.transforms.atmostonce.lazy_transform import LazyTransform, compile_lazy_transforms, flatten_sequences
+from monai.transforms.atmostonce.utility import CachedTransformCompose, MultiSampleTransformCompose, \
+    IMultiSampleTransform, IRandomizableTransform, ILazyTransform
+from monai.utils import GridSampleMode, GridSamplePadMode, ensure_tuple, get_seed, MAX_SEED
+
+from monai.transforms import Randomizable, InvertibleTransform, OneOf, apply_transform
+
+
+# TODO: this is intended to replace Compose once development is done
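+
+
+# A usage sketch for the lazy Compose (transform arguments are illustrative):
+#
+#     pipeline = Compose(
+#         [Spacing((1.0, 1.0, 1.0), (1.5, 1.5, 2.0)), Rotate(0.5), RandZoom(prob=0.5)],
+#         lazy_evaluation=True,
+#     )
+#     result = pipeline(img)   # consecutive spatial transforms are fused
+
+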
+class ComposeCompiler:
+
+    def compile(self, transforms, cache_mechanism):
+        # stage 1: wrap the leading deterministic transforms in a caching container
+        transforms1 = self.compile_caching(transforms, cache_mechanism)
+
+        # stage 2: wrap each multi-sampling transform around the transforms that follow it
+        transforms2 = self.compile_multisampling(transforms1)
+
+        # stage 3: terminate each run of lazy transforms with an Apply()
+        transforms3 = self.compile_lazy_resampling(transforms2)
+
+        return transforms3
+
+    def compile_caching(self, transforms, cache_mechanism):
+        # TODO: handle being passed a transform list with containers
+        # given a list of transforms, determine where to add a cached transform object
+        # and what transforms to put in it
+        cacheable = list()
+        for t in transforms:
+            if self.transform_is_random(t) is False:
+                cacheable.append(t)
+            else:
+                break
+
+        if len(cacheable) == 0:
+            return list(transforms)
+        else:
+            return [CachedTransformCompose(cacheable, cache_mechanism)] + transforms[len(cacheable):]
+
+    def compile_multisampling(self, transforms):
+        for i in reversed(range(len(transforms))):
+            if self.transform_is_multisampling(transforms[i]) is True:
+                transforms_ = transforms[:i] + [MultiSampleTransformCompose(transforms[i],
+                                                                            transforms[i+1:])]
+                return self.compile_multisampling(transforms_)
+
+        return list(transforms)
+
+    def compile_lazy_resampling(self, transforms):
+        result = list()
+        lazy = list()
+        for i in range(len(transforms)):
+            if self.transform_is_lazy(transforms[i]):
+                lazy.append(transforms[i])
+            else:
+                if len(lazy) > 0:
+                    result.extend(lazy)
+                    result.append(Apply())
+                    lazy = list()
+                result.append(transforms[i])
+        if len(lazy) > 0:
+            result.extend(lazy)
+            result.append(Apply())
+        return result
+
+    def transform_is_random(self, t):
+        return isinstance(t, IRandomizableTransform)
+
+    def transform_is_container(self, t):
+        return isinstance(t, (CachedTransformCompose, MultiSampleTransformCompose))
+
+    def transform_is_multisampling(self, t):
+        return isinstance(t, IMultiSampleTransform)
+
+    def transform_is_lazy(self, t):
+        return isinstance(t, ILazyTransform)
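+
+
+# A sketch of what compile() produces for a hypothetical pipeline, assuming
+# LoadImage and NormalizeIntensity are deterministic and
+# RandomCropPadMultiSample yields several samples:
+#
+#     [LoadImage, NormalizeIntensity, RandomCropPadMultiSample, RandFlip]
+#
+# becomes
+#
+#     [CachedTransformCompose([LoadImage, NormalizeIntensity], cache_mechanism),
+#      MultiSampleTransformCompose(RandomCropPadMultiSample, [RandFlip, Apply()])]
+#
+# with Apply() appended after the trailing lazy transforms by
+# compile_lazy_resampling.
+
+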
+class Compose(Randomizable, InvertibleTransform):
+    """
+    ``Compose`` provides the ability to chain a series of callables together in
+    a sequential manner. Each transform in the sequence must take a single
+    argument and return a single value.
+
+    ``Compose`` can be used in two ways:
+
+    #. With a series of transforms that accept and return a single
+       ndarray / tensor / tensor-like parameter.
+    #. With a series of transforms that accept and return a dictionary that
+       contains one or more parameters. Such transforms must have pass-through
+       semantics: unused values in the input dictionary must be copied to the
+       return dictionary. It is required that the dictionary is copied between
+       input and output of each transform.
+
+    If some transform takes a data item dictionary as input, and returns a
+    sequence of data items in the transform chain, all following transforms
+    will be applied to each item of this list if `map_items` is `True` (the
+    default). If `map_items` is `False`, the returned sequence is passed whole
+    to the next callable in the chain.
+
+    For example:
+
+    A `Compose([transformA, transformB, transformC],
+    map_items=True)(data_dict)` could achieve the following patch-based
+    transformation on the `data_dict` input:
+
+    #. transformA normalizes the intensity of 'img' field in the `data_dict`.
+    #. transformB crops out image patches from the 'img' and 'seg' of
+       `data_dict`, and returns a list of three patch samples::
+
+        {'img': 3x100x100 data, 'seg': 1x100x100 data, 'shape': (100, 100)}
+                             applying transformB
+                                 ---------->
+        [{'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)},
+         {'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)},
+         {'img': 3x20x20 data, 'seg': 1x20x20 data, 'shape': (20, 20)},]
+
+    #. transformC then randomly rotates or flips 'img' and 'seg' of
+       each dictionary item in the list returned by transformB.
+
+    The composed transforms share the same global random seed if the user
+    calls `set_determinism()`.
+
+    When using the pass-through dictionary operation, you can make use of
+    :class:`monai.transforms.adaptors.adaptor` to wrap transforms that don't conform
+    to the requirements. This approach allows you to use transforms from
+    otherwise incompatible libraries with minimal additional work.
+
+    Note:
+
+        In many cases, Compose is not the best way to create pre-processing
+        pipelines. Pre-processing is often not a strictly sequential series of
+        operations, and much of the complexity arises when a not-sequential
+        set of functions must be called as if it were a sequence.
+
+        Example: images and labels
+        Images typically require some kind of normalization that labels do not.
+        Both are then typically augmented through the use of random rotations,
+        flips, and deformations.
+        Compose can be used with a series of transforms that take a dictionary
+        that contains 'image' and 'label' entries. This might require wrapping
+        `torchvision` transforms before passing them to compose.
+        Alternatively, one can create a class with a `__call__` function that
+        calls your pre-processing functions taking into account that not all of
+        them are called on the labels.
+
+    Args:
+        transforms: sequence of callables.
+        map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
+            defaults to `True`.
+        unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
+            defaults to `False`.
+        log_stats: whether to log the detailed information of data and applied transform when error happened,
+            for NumPy array and PyTorch Tensor, log the data shape and value range,
+            for other metadata, log the values directly. default to `False`.
+        lazy_evaluation: whether to resample consecutive spatial transforms lazily, fusing them where
+            possible. Defaults to `False`.
+        mode: {``"bilinear"``, ``"nearest"``}
+            Interpolation mode when ``lazy_evaluation=True``. Defaults to ``"bilinear"``.
+            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+            When `USE_COMPILED` is `True`, this argument uses
+            ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations.
+            See also: https://docs.monai.io/en/stable/networks.html#grid-pull
+        padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
+            Padding mode for outside grid values when ``lazy_evaluation=True``. Defaults to ``"border"``.
+ See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html + + """ + + def __init__( + self, + transforms: Optional[Union[Sequence[Callable], Callable]] = None, + map_items: bool = True, + unpack_items: bool = False, + log_stats: bool = False, + mode=GridSampleMode.BILINEAR, + padding_mode=GridSamplePadMode.BORDER, + lazy_evaluation: bool = False + ) -> None: + if transforms is None: + transforms = [] + self.transforms = ensure_tuple(transforms) + + if lazy_evaluation is True: + self.dst_transforms = compile_lazy_transforms(self.transforms) + else: + self.dst_transforms = flatten_sequences(self.transforms) + + self.map_items = map_items + self.unpack_items = unpack_items + self.log_stats = log_stats + self.mode = mode + self.padding_mode = padding_mode + self.lazy_evaluation = lazy_evaluation + self.set_random_state(seed=get_seed()) + + if self.lazy_evaluation: + for t in self.dst_transforms: + if isinstance(t, LazyTransform): + t.lazy_evaluation = True + + def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> "Compose": + super().set_random_state(seed=seed, state=state) + for _transform in self.transforms: + if not isinstance(_transform, Randomizable): + continue + _transform.set_random_state(seed=self.R.randint(MAX_SEED, dtype="uint32")) + return self + + def randomize(self, data: Optional[Any] = None) -> None: + for _transform in self.transforms: + if not isinstance(_transform, Randomizable): + continue + try: + _transform.randomize(data) + except TypeError as type_error: + tfm_name: str = type(_transform).__name__ + warnings.warn( + f'Transform "{tfm_name}" in Compose not randomized\n{tfm_name}.{type_error}.', RuntimeWarning + ) + + # TODO: this is a more general function that could be implemented elsewhere + def flatten(self): + """Return a Composition with a simple list of transforms, as opposed to any nested Compositions. + + e.g., `t1 = Compose([x, x, x, x, Compose([Compose([x, x]), x, x])]).flatten()` + will result in the equivalent of `t1 = Compose([x, x, x, x, x, x, x, x])`. 
+
+        """
+        new_transforms = []
+        for t in self.transforms:
+            if isinstance(t, Compose) and not isinstance(t, OneOf):
+                new_transforms += t.flatten().transforms
+            else:
+                new_transforms.append(t)
+
+        return Compose(new_transforms)
+
+    def __len__(self):
+        """Return number of transformations."""
+        return len(self.flatten().transforms)
+
+    def __call__(self, input_):
+        for _transform in self.dst_transforms:
+            input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items, self.log_stats)
+        return input_
+
+    def inverse(self, data):
+        invertible_transforms = [t for t in self.flatten().transforms if isinstance(t, InvertibleTransform)]
+        if not invertible_transforms:
+            warnings.warn("inverse has been called but no invertible transforms have been supplied")
+
+        # loop backwards over transforms
+        for t in reversed(invertible_transforms):
+            data = apply_transform(t.inverse, data, self.map_items, self.unpack_items, self.log_stats)
+        return data
diff --git a/monai/transforms/atmostonce/dictionary.py b/monai/transforms/atmostonce/dictionary.py
new file mode 100644
index 0000000000..59f6a5c102
--- /dev/null
+++ b/monai/transforms/atmostonce/dictionary.py
@@ -0,0 +1,457 @@
+from typing import Any, Hashable, Mapping, Optional, Sequence, Tuple, Union
+
+import numpy as np
+
+import torch
+
+from monai.transforms.atmostonce.array import CropPad, Resize, Rotate, Spacing, Translate, Zoom
+from monai.transforms.atmostonce.randomizers import RotateRandomizer
+from monai.transforms.atmostonce.utility import ILazyTransform, IRandomizableTransform, IMultiSampleTransform
+from monai.transforms.atmostonce.utils import value_to_tuple_range
+from monai.utils import ensure_tuple_rep
+
+from monai.config import KeysCollection, DtypeLike, SequenceStr
+from monai.transforms.atmostonce.lazy_transform import LazyTransform
+from monai.transforms.inverse import InvertibleTransform
+from monai.transforms.transform import MapTransform, RandomizableTransform
+from monai.utils.enums import TransformBackends, GridSampleMode, GridSamplePadMode, InterpolateMode, NumpyPadMode, \
+    PytorchPadMode
+from monai.utils.mapping_stack import MatrixFactory
+from monai.utils.type_conversion import expand_scalar_to_tuple
+
+
+def get_device_from_data(data):
+    if isinstance(data, np.ndarray):
+        return None
+    elif isinstance(data, torch.Tensor):
+        return data.device
+    else:
+        msg = "'data' must be one of numpy ndarray or torch Tensor but is {}"
+        raise ValueError(msg.format(type(data)))
+
+
+def get_backend_from_data(data):
+    if isinstance(data, np.ndarray):
+        return TransformBackends.NUMPY
+    elif isinstance(data, torch.Tensor):
+        return TransformBackends.TORCH
+    else:
+        msg = "'data' must be one of numpy ndarray or torch Tensor but is {}"
+        raise ValueError(msg.format(type(data)))
+
+
+# TODO: reconcile multiple definitions to one in utils
+def expand_potential_tuple(keys, value):
+    if not isinstance(value, (tuple, list)):
+        return tuple(value for _ in keys)
+    return value
+
+
+def keys_to_process(
+    keys: KeysCollection,
+    dictionary: Mapping[Hashable, torch.Tensor],
+    allow_missing_keys: bool,
+):
+    if allow_missing_keys is True:
+        # preserve the ordering of `keys`; a set comprehension would lose it
+        return [k for k in keys if k in dictionary]
+    return keys
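+
+
+# For example, with a hypothetical two-key transform over a partial dictionary:
+#
+#     keys_to_process(("image", "label"), {"image": img}, allow_missing_keys=True)
+#     # -> ["image"]
+#     keys_to_process(("image", "label"), {"image": img}, allow_missing_keys=False)
+#     # -> ("image", "label")
+
+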
+# class MappingStackTransformd(MapTransform, InvertibleTransform):
+#
+#     def __init__(self,
+#                  keys: KeysCollection):
+#         super().__init__(self)
+#         self.keys = keys
+#
+#     def __call__(self,
+#                  d: Mapping,
+#                  *args,
+#                  **kwargs):
+#         mappings = d.get("mappings", dict())
+#         rd = dict()
+#         for k in self.keys:
+#             data = d[k]
+#             dims = len(data.shape)-1
+#             device = get_device_from_data(data)
+#             backend = get_backend_from_data(data)
+#             v = None  # mappings.get(k, MappingStack(MatrixFactory(dims, backend, device)))
+#             v.push(self.get_matrix(dims, backend, device, *args, **kwargs))
+#             mappings[k] = v
+#             rd[k] = data
+#
+#         rd["mappings"] = mappings
+#
+#         return rd
+#
+#     def get_matrix(self, dims, backend, device, *args, **kwargs):
+#         msg = "get_matrix must be implemented in a subclass of MappingStackTransform"
+#         raise NotImplementedError(msg)
+
+
+class Spacingd(LazyTransform, MapTransform, InvertibleTransform):
+
+    def __init__(self,
+                 keys: KeysCollection,
+                 pixdim: Union[Sequence[float], float, np.ndarray],
+                 src_pixdim: Optional[Union[Sequence[float], float, np.ndarray]],
+                 diagonal: Optional[bool] = False,
+                 mode: Optional[Union[GridSampleMode, str]] = GridSampleMode.BILINEAR,
+                 padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER,
+                 align_corners: Optional[bool] = False,
+                 dtype: Optional[DtypeLike] = np.float64,
+                 allow_missing_keys: Optional[bool] = False,
+                 lazy_evaluation: Optional[bool] = False
+                 ):
+        LazyTransform.__init__(self, lazy_evaluation)
+        MapTransform.__init__(self, keys, allow_missing_keys)
+        InvertibleTransform.__init__(self)
+        self.pixdim = pixdim
+        self.src_pixdim = src_pixdim
+        self.diagonal = diagonal
+        self.modes = ensure_tuple_rep(mode, len(self.keys))
+        self.padding_modes = ensure_tuple_rep(padding_mode, len(self.keys))
+        self.align_corners = align_corners
+        self.dtypes = ensure_tuple_rep(dtype, len(self.keys))
+
+    def __call__(self, d: Mapping):
+        keys = keys_to_process(self.keys, d, self.allow_missing_keys)
+        rd = dict(d)
+
+        for ik, k in enumerate(keys):
+            tx = Spacing(self.pixdim, self.src_pixdim, self.diagonal,
+                         self.modes[ik], self.padding_modes[ik],
+                         self.align_corners, self.dtypes[ik],
+                         lazy_evaluation=self.lazy_evaluation)
+
+            rd[k] = tx(d[k])
+
+        return rd
+
+    def inverse(self, data: Any):
+        raise NotImplementedError()
+
+
+class Rotated(LazyTransform, MapTransform, InvertibleTransform):
+
+    def __init__(self,
+                 keys: KeysCollection,
+                 angle: Union[Sequence[float], float],
+                 keep_size: bool = True,
+                 mode: Optional[SequenceStr] = GridSampleMode.BILINEAR,
+                 padding_mode: Optional[SequenceStr] = GridSamplePadMode.BORDER,
+                 align_corners: Optional[Union[Sequence[bool], bool]] = False,
+                 dtype: Optional[Union[Sequence[Union[DtypeLike, torch.dtype]], DtypeLike, torch.dtype]] = np.float32,
+                 allow_missing_keys: Optional[bool] = False,
+                 lazy_evaluation: Optional[bool] = False
+                 ):
+        LazyTransform.__init__(self, lazy_evaluation)
+        MapTransform.__init__(self, keys, allow_missing_keys)
+        self.angle = angle
+        self.keep_size = keep_size
+        self.modes = ensure_tuple_rep(mode, len(keys))
+        self.padding_modes = ensure_tuple_rep(padding_mode, len(keys))
+        self.align_corners = align_corners
+        self.dtypes = ensure_tuple_rep(dtype, len(keys))
+
+    def __call__(self, d: Mapping):
+        keys = keys_to_process(self.keys, d, self.allow_missing_keys)
+        rd = dict(d)
+
+        for ik, k in enumerate(keys):
+            tx = Rotate(self.angle, self.keep_size,
+                        self.modes[ik], self.padding_modes[ik],
+                        self.align_corners, self.dtypes[ik],
+                        lazy_evaluation=self.lazy_evaluation)
+            rd[k] = tx(d[k])
+
+        return rd
+
+    def inverse(self, data: Any):
+        raise NotImplementedError()
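+
+
+# Dictionary transforms mirror their array counterparts; a usage sketch with a
+# hypothetical sample dictionary:
+#
+#     d = {"image": image, "label": label}
+#     d = Rotated(keys=("image", "label"), angle=0.5,
+#                 mode=("bilinear", "nearest"))(d)
+#
+# Each key can use its own interpolation mode while sharing the same rotation
+# angle.
+
+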
+class RandRotated(MapTransform, InvertibleTransform, LazyTransform, IRandomizableTransform):
+
+    def __init__(
+        self,
+        keys: KeysCollection,
+        range_x: Union[Tuple[float, float], float] = 0.0,
+        range_y: Union[Tuple[float, float], float] = 0.0,
+        range_z: Union[Tuple[float, float], float] = 0.0,
+        prob: float = 0.1,
+        keep_size: bool = True,
+        mode: SequenceStr = GridSampleMode.BILINEAR,
+        padding_mode: SequenceStr = GridSamplePadMode.BORDER,
+        align_corners: Union[Sequence[bool], bool] = False,
+        dtype: Union[Sequence[Union[DtypeLike, torch.dtype]], DtypeLike, torch.dtype] = np.float32,
+        lazy_evaluation: Optional[bool] = True,
+        allow_missing_keys: Optional[bool] = False,
+    ):
+        MapTransform.__init__(self, keys, allow_missing_keys)
+        self.randomizer = RotateRandomizer(value_to_tuple_range(range_x),
+                                           value_to_tuple_range(range_y),
+                                           value_to_tuple_range(range_z),
+                                           prob)
+        self.op = Rotate(0, keep_size, mode, padding_mode, align_corners, dtype, lazy_evaluation)
+
+    def __call__(
+        self,
+        data: Mapping[Hashable, torch.Tensor]
+    ):
+        keys = keys_to_process(self.keys, data, self.allow_missing_keys)
+        rd = dict(data)
+
+        # sample one set of angles and apply it to every key
+        angles = self.randomizer.sample(data[keys[0]])
+
+        for ik, k in enumerate(keys):
+            rd[k] = self.op(data[k], angles)
+
+        return rd
+
+    def inverse(self, data):
+        raise NotImplementedError()
+
+
+class Resized(LazyTransform, MapTransform, InvertibleTransform):
+
+    def __init__(self,
+                 keys: KeysCollection,
+                 spatial_size: Union[Sequence[int], int],
+                 size_mode: Optional[str] = "all",
+                 mode: Union[InterpolateMode, str] = InterpolateMode.AREA,
+                 align_corners: Optional[bool] = False,
+                 anti_aliasing: Optional[bool] = False,
+                 anti_aliasing_sigma: Optional[Union[Sequence[float], float, None]] = None,
+                 dtype: Optional[Union[DtypeLike, torch.dtype]] = np.float32,
+                 allow_missing_keys: Optional[bool] = False,
+                 lazy_evaluation: Optional[bool] = False
+                 ):
+        LazyTransform.__init__(self, lazy_evaluation)
+        MapTransform.__init__(self, keys, allow_missing_keys)
+        self.spatial_size = spatial_size
+        self.size_mode = size_mode
+        self.modes = ensure_tuple_rep(mode, len(keys))
+        self.align_corners = align_corners
+        self.anti_aliasing = anti_aliasing
+        self.anti_aliasing_sigma = anti_aliasing_sigma
+        self.dtype = dtype
+
+    def __call__(self, d: Mapping):
+        keys = keys_to_process(self.keys, d, self.allow_missing_keys)
+        rd = dict(d)
+
+        for ik, k in enumerate(keys):
+            tx = Resize(self.spatial_size, self.size_mode, self.modes[ik], self.align_corners,
+                        self.anti_aliasing, self.anti_aliasing_sigma, self.dtype,
+                        lazy_evaluation=self.lazy_evaluation)
+            rd[k] = tx(d[k])
+
+        return rd
+
+    def inverse(self, data: Any):
+        raise NotImplementedError()
+
+
+class Zoomd(LazyTransform, MapTransform, InvertibleTransform):
+
+    def __init__(self,
+                 keys: KeysCollection,
+                 zoom: Union[Sequence[float], float],
+                 mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.AREA,
+                 padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = NumpyPadMode.EDGE,
+                 align_corners: Optional[bool] = None,
+                 keep_size: Optional[bool] = True,
+                 dtype: Optional[Union[DtypeLike, torch.dtype]] = np.float32,
+                 allow_missing_keys: Optional[bool] = False,
+                 lazy_evaluation: Optional[bool] = False,
+                 **kwargs
+                 ):
+        LazyTransform.__init__(self, lazy_evaluation)
+        MapTransform.__init__(self, keys, allow_missing_keys)
+        self.zoom = zoom
+        self.modes = ensure_tuple_rep(mode, len(keys))
+        self.padding_modes = ensure_tuple_rep(padding_mode, len(keys))
+        self.align_corners = align_corners
+        self.keep_size = keep_size
+        self.dtype = dtype
+
+    def __call__(self, d: Mapping):
+        keys = keys_to_process(self.keys, d, self.allow_missing_keys)
+        rd = dict(d)
+
+        for ik, k in enumerate(keys):
+            tx = Zoom(self.zoom, self.modes[ik], self.padding_modes[ik], self.align_corners,
+                      self.keep_size, self.dtype, lazy_evaluation=self.lazy_evaluation)
+            rd[k] = tx(d[k])
+
+        return rd
+
+    def inverse(self, data: Any):
+        raise NotImplementedError()
translate: Union[Sequence[float], float]): + super().__init__(self) + self.keys = keys + self.translate = expand_scalar_to_tuple(translate, len(keys)) + + def __call__(self, d: Mapping): + mappings = d.get("mappings", dict()) + rd = dict() + for k in self.keys: + data = d[k] + dims = len(data.shape)-1 + device = get_device_from_data(data) + backend = get_backend_from_data(data) + matrix_factory = MatrixFactory(dims, backend, device) + v = None # mappings.get(k, MappingStack(matrix_factory)) + v.push(matrix_factory.translate(self.translate)) + mappings[k] = v + rd[k] = data + + return rd + + +class CropPadd(MapTransform, InvertibleTransform, ILazyTransform): + + def __init__( + self, + keys: KeysCollection, + slices: Optional[Sequence[slice]] = None, + padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER, + lazy_evaluation: Optional[bool] = True + ): + self.keys = keys + self.slices = slices + self.padding_modes = padding_mode + self.lazy_evaluation = lazy_evaluation + + + def __call__( + self, + d: dict + ): + keys = keys_to_process(self.keys, d, self.allow_missing_keys) + + rd = dict(d) + for ik, k in enumerate(keys): + tx = CropPad(slices=self.slices, + padding_mode=self.padding_modes, + lazy_evaluation=self.lazy_evaluation) + + rd[k] = tx(d[k]) + + return rd + + +class RandomCropPadd(MapTransform, InvertibleTransform, RandomizableTransform, ILazyTransform): + + def __init__( + self, + keys: KeysCollection, + sizes: Union[Sequence[int], int], + prob: Optional[float] = 0.1, + padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER, + allow_missing_keys: bool=False, + lazy_evaluation: Optional[bool] = True + ): + RandomizableTransform.__init__(self, prob) + self.keys = keys + self.sizes = sizes + self.padding_mode = padding_mode + self.offsets = None + self.allow_missing_keys = allow_missing_keys + + self.op = CropPad(None, padding_mode) + + def randomize( + self, + img: torch.Tensor, + ): + super().randomize(None) + if self._do_transform: + img_shape = img.shape[1:] + if isinstance(self.sizes, int): + crop_shape = tuple(self.sizes for _ in range(len(img_shape))) + else: + crop_shape = self.sizes + + valid_ranges = tuple(i - c for i, c in zip(img_shape, crop_shape)) + self.offsets = tuple(self.R.randint(0, r+1) if r > 0 else r for r in valid_ranges) + + def __call__( + self, + d: dict, + randomize: Optional[bool] = True + ): + keys = keys_to_process(self.keys, d, self.allow_missing_keys) + + img = d[keys[0]] + img_shape = img.shape[:1] + + if randomize: + self.randomize(img) + + if self._do_transform: + offsets_ = self.offsets + else: + # center crop if this sample isn't random + offsets_ = tuple((i - s) // 2 for i, s in zip(img_shape, self.sizes)) + + slices = tuple(slice(o, o + s) for o, s in zip(offsets_, self.sizes)) + + rd = dict(d) + for k in keys: + rd[k] = self.op(img, slices=slices) + + return rd + + @property + def lazy_evaluation(self): + return self.op.lazy_evaluation + + +class RandomCropPadMultiSampled( + InvertibleTransform, IRandomizableTransform, ILazyTransform, IMultiSampleTransform +): + + def __init__( + self, + keys: Sequence[str], + sizes: Union[Sequence[int], int], + sample_count: int, + padding_mode: Optional[Union[GridSamplePadMode, str]] = GridSamplePadMode.BORDER, + lazy_evaluation: Optional[bool] = True + ): + self.sample_count = sample_count + self.op = RandomCropPadd(keys, sizes, 1.0, padding_mode, lazy_evaluation) + + def __call__( + self, + d: dict, + randomize: Optional[bool] = True + ): + for _ in 
+            yield self.op(d, randomize)
+
+    def inverse(
+            self,
+            data: dict
+    ):
+        raise NotImplementedError()
+
+    def set_random_state(self, seed=None, state=None):
+        self.op.set_random_state(seed, state)
+
+    @property
+    def lazy_evaluation(self):
+        return self.op.lazy_evaluation
diff --git a/monai/transforms/atmostonce/functional.py b/monai/transforms/atmostonce/functional.py
new file mode 100644
index 0000000000..22f5c81eb8
--- /dev/null
+++ b/monai/transforms/atmostonce/functional.py
@@ -0,0 +1,509 @@
+from typing import Optional, Sequence, Tuple, Union
+
+import numpy as np
+
+import torch
+from monai.networks.layers import GaussianFilter
+
+from monai.networks.utils import meshgrid_ij
+
+from monai.transforms import create_rotate, create_translate, map_spatial_axes, create_grid
+
+from monai.data import get_track_meta
+from monai.transforms.atmostonce.apply import extents_from_shape, shape_from_extents
+from monai.utils import convert_to_tensor, get_equivalent_dtype, ensure_tuple_rep, look_up_option, \
+    GridSampleMode, GridSamplePadMode, fall_back_tuple, ensure_tuple_size, ensure_tuple, InterpolateMode, NumpyPadMode
+
+from monai.config import DtypeLike
+from monai.utils.mapping_stack import MatrixFactory
+
+
+def identity(
+    img: torch.Tensor,
+    mode: Optional[Union[InterpolateMode, str]] = None,
+    padding_mode: Optional[Union[NumpyPadMode, GridSamplePadMode, str]] = None,
+    dtype: Optional[Union[DtypeLike, torch.dtype]] = None
+):
+    img_ = convert_to_tensor(img, track_meta=get_track_meta())
+
+    mode_ = None if mode is None else look_up_option(mode, GridSampleMode)
+    padding_mode_ = None if padding_mode is None else look_up_option(padding_mode, GridSamplePadMode)
+    dtype_ = get_equivalent_dtype(dtype or img_.dtype, torch.Tensor)
+
+    transform = MatrixFactory.from_tensor(img_).identity().matrix.matrix
+
+    metadata = dict()
+    if mode_ is not None:
+        metadata["mode"] = mode_
+    if padding_mode_ is not None:
+        metadata["padding_mode"] = padding_mode_
+    metadata["dtype"] = dtype_
+    return img_, transform, metadata
+
+
+def spacing(
+    img: torch.Tensor,
+    pixdim: Union[Sequence[float], float],
+    src_pixdim: Union[Sequence[float], float],
+    diagonal: Optional[bool] = False,
+    mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.AREA,
+    padding_mode: Optional[Union[NumpyPadMode, GridSamplePadMode, str]] = NumpyPadMode.EDGE,
+    align_corners: Optional[bool] = False,
+    dtype: Optional[Union[DtypeLike, torch.dtype]] = None,
+    shape_override: Optional[Sequence] = None
+):
+    """
+    Args:
+        img: channel first array, must have shape: (num_channels, H[, W, ..., ]).
+        pixdim: the desired output voxel spacing for each spatial dimension.
+        src_pixdim: the source voxel spacing for each spatial dimension.
+        diagonal: whether to resample the input so that it has a diagonal affine
+            matrix; a value of True is not currently supported.
+        mode: the interpolation mode used when resampling.
+            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+        padding_mode: the padding mode for values that fall outside the grid.
+            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+        align_corners: This only has an effect when mode is
+            'linear', 'bilinear', 'bicubic' or 'trilinear'.
+            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
+        dtype: data type for resampling computation. If None, use the data type
+            of the input data.
+        shape_override: a shape to use in place of ``img.shape``, e.g. when
+            pending lazy transforms have already changed the effective shape.
+
+    Raises:
+        ValueError: When ``diagonal`` is True, as this is not currently supported.
+
+    """
+
+    img_ = convert_to_tensor(img, track_meta=get_track_meta())
+    input_shape = img_.shape if shape_override is None else shape_override
+    input_ndim = len(input_shape) - 1
+
+    pixdim_ = ensure_tuple_rep(pixdim, input_ndim)
+    src_pixdim_ = ensure_tuple_rep(src_pixdim, input_ndim)
+
+    if diagonal is True:
+        raise ValueError("'diagonal' value of True is not currently supported")
+
+    mode_ = look_up_option(mode, GridSampleMode)
+    padding_mode_ = look_up_option(padding_mode, GridSamplePadMode)
+    dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor)
+    zoom_factors = [i / j for i, j in zip(src_pixdim_, pixdim_)]
+
+    # TODO: decide whether we are consistently returning MetaMatrix or concrete transforms
+    transform = MatrixFactory.from_tensor(img).scale(zoom_factors).matrix.matrix
+    im_extents = extents_from_shape(input_shape)
+    im_extents = [transform @ e for e in im_extents]
+    shape_override_ = shape_from_extents(input_shape, im_extents)
+
+    metadata = {
+        "pixdim": pixdim_,
+        "src_pixdim": src_pixdim_,
+        "diagonal": diagonal,
+        "mode": mode_,
+        "padding_mode": padding_mode_,
+        "align_corners": align_corners,
+        "dtype": dtype_,
+        "im_extents": im_extents,
+        "shape_override": shape_override_
+    }
+    return img_, transform, metadata
+
+
+def orientation(
+    img: torch.Tensor
+):
+    pass
+
+
+def flip(
+    img: torch.Tensor,
+    spatial_axis: Union[Sequence[int], int],
+    shape_override: Optional[Sequence] = None
+):
+    img_ = convert_to_tensor(img, track_meta=get_track_meta())
+    input_shape = img_.shape if shape_override is None else shape_override
+
+    spatial_axis_ = spatial_axis
+    if spatial_axis_ is None:
+        spatial_axis_ = tuple(i for i in range(len(input_shape[1:])))
+    transform = MatrixFactory.from_tensor(img).flip(spatial_axis_).matrix.matrix
+    im_extents = extents_from_shape(input_shape)
+    im_extents = [transform @ e for e in im_extents]
+
+    shape_override_ = shape_from_extents(input_shape, im_extents)
+
+    metadata = {
+        "spatial_axes": spatial_axis,
+        "im_extents": im_extents,
+        "shape_override": shape_override_
+    }
+    return img_, transform, metadata
+
+
+def resize(
+    img: torch.Tensor,
+    spatial_size: Union[Sequence[int], int],
+    size_mode: str = "all",
+    mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.AREA,
+    align_corners: Optional[bool] = False,
+    anti_aliasing: Optional[bool] = None,
+    anti_aliasing_sigma: Optional[Union[Sequence[float], float]] = None,
+    dtype: Optional[Union[DtypeLike, torch.dtype]] = None,
+    shape_override: Optional[Sequence] = None
+):
+    """
+    Args:
+        img: channel first array, must have shape: (num_channels, H[, W, ..., ]).
+        mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``,
+            ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
+            The interpolation mode. Defaults to ``self.mode``.
+            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
+        align_corners: This only has an effect when mode is
+            'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``.
+            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
+        anti_aliasing: bool, optional
+            Whether to apply a Gaussian filter to smooth the image prior
+            to downsampling.
+            It is crucial to filter when downsampling
+            the image to avoid aliasing artifacts. See also ``skimage.transform.resize``
+        anti_aliasing_sigma: {float, tuple of floats}, optional
+            Standard deviation for Gaussian filtering used when anti-aliasing.
+            By default, this value is chosen as (s - 1) / 2 where s is the
+            downsampling factor, where s > 1. For the up-size case, s < 1, no
+            anti-aliasing is performed prior to rescaling.
+
+    Raises:
+        ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions.
+
+    """
+
+    img_ = convert_to_tensor(img, track_meta=get_track_meta())
+    input_shape = img_.shape if shape_override is None else shape_override
+    input_ndim = len(input_shape) - 1
+
+    if size_mode == "all":
+        output_ndim = len(ensure_tuple(spatial_size))
+        if output_ndim > input_ndim:
+            input_shape = ensure_tuple_size(input_shape, output_ndim + 1, 1)
+            img = img.reshape(input_shape)
+        elif output_ndim < input_ndim:
+            raise ValueError(
+                "len(spatial_size) must be greater or equal to img spatial dimensions, "
+                f"got spatial_size={output_ndim} img={input_ndim}."
+            )
+        spatial_size_ = fall_back_tuple(spatial_size, input_shape[1:])
+    else:  # for the "longest" mode
+        img_size = input_shape[1:]
+        if not isinstance(spatial_size, int):
+            raise ValueError("spatial_size must be an int number if size_mode is 'longest'.")
+        scale = spatial_size / max(img_size)
+        spatial_size_ = tuple(int(round(s * scale)) for s in img_size)
+
+    mode_ = look_up_option(mode, GridSampleMode)
+    dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor)
+    # use the normalised spatial_size_; the raw argument may be an int in
+    # "longest" mode, which cannot be zipped against the input shape
+    zoom_factors = [i / j for i, j in zip(spatial_size_, input_shape[1:])]
+    transform = MatrixFactory.from_tensor(img).scale(zoom_factors).matrix.matrix
+    im_extents = extents_from_shape(input_shape)
+    im_extents = [transform @ e for e in im_extents]
+    shape_override_ = shape_from_extents(input_shape, im_extents)
+
+    metadata = {
+        "spatial_size": spatial_size_,
+        "size_mode": size_mode,
+        "mode": mode_,
+        "align_corners": align_corners,
+        "anti_aliasing": anti_aliasing,
+        "anti_aliasing_sigma": anti_aliasing_sigma,
+        "dtype": dtype_,
+        "im_extents": im_extents,
+        "shape_override": shape_override_
+    }
+    return img_, transform, metadata
+
+
+def rotate(
+    img: torch.Tensor,
+    angle: Union[Sequence[float], float],
+    keep_size: Optional[bool] = True,
+    mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.AREA,
+    padding_mode: Optional[Union[NumpyPadMode, GridSamplePadMode, str]] = NumpyPadMode.EDGE,
+    align_corners: Optional[bool] = False,
+    dtype: Optional[Union[DtypeLike, torch.dtype]] = None,
+    shape_override: Optional[Sequence] = None
+):
+    """
+    Args:
+        img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D].
+        angle: Rotation angle(s) in radians. Should be a float for 2D, three floats for 3D.
+        keep_size: If it is True, the output shape is kept the same as the input.
+            If it is False, the output shape is adapted so that the
+            input array is contained completely in the output. Default is True.
+        mode: {``"bilinear"``, ``"nearest"``}
+            Interpolation mode to calculate output values. Defaults to ``self.mode``.
+            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+        padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
+            Padding mode for outside grid values. Defaults to ``self.padding_mode``.
+            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+        align_corners: Defaults to ``self.align_corners``.
+            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
+        dtype: data type for resampling computation. Defaults to ``self.dtype``.
+            If None, use the data type of input data. To be compatible with other modules,
+            the output data type is always ``np.float32``.
+
+    Raises:
+        ValueError: When ``img`` spatially is not one of [2D, 3D].
+
+    """
+
+    img_ = convert_to_tensor(img, track_meta=get_track_meta())
+    mode_ = look_up_option(mode, GridSampleMode)
+    padding_mode_ = look_up_option(padding_mode, GridSamplePadMode)
+    dtype_ = get_equivalent_dtype(dtype or img.dtype, torch.Tensor)
+    input_shape = img_.shape if shape_override is None else shape_override
+    input_ndim = len(input_shape) - 1
+    if input_ndim not in (2, 3):
+        raise ValueError(f"Unsupported image dimension: {input_ndim}, available options are [2, 3].")
+
+    angle_ = ensure_tuple_rep(angle, 1 if input_ndim == 2 else 3)
+    rotate_tx = torch.from_numpy(create_rotate(input_ndim, angle_))
+    im_extents = extents_from_shape(input_shape)
+    if not keep_size:
+        im_extents = [rotate_tx @ e for e in im_extents]
+        spatial_shape = shape_from_extents(input_shape, im_extents)
+    else:
+        spatial_shape = input_shape
+    transform = rotate_tx
+    metadata = {
+        "angle": angle_,
+        "keep_size": keep_size,
+        "mode": mode_,
+        "padding_mode": padding_mode_,
+        "align_corners": align_corners,
+        "dtype": dtype_,
+        "im_extents": im_extents,
+        "shape_override": spatial_shape
+    }
+    return img_, transform, metadata
+
+
+def zoom(
+    img: torch.Tensor,
+    factor: Union[Sequence[float], float],
+    mode: Optional[Union[InterpolateMode, str]] = InterpolateMode.BILINEAR,
+    padding_mode: Optional[Union[NumpyPadMode, GridSamplePadMode, str]] = NumpyPadMode.EDGE,
+    align_corners: Optional[bool] = False,
+    keep_size: Optional[bool] = True,
+    dtype: Optional[Union[DtypeLike, torch.dtype]] = None,
+    shape_override: Optional[Sequence] = None
+):
+    """
+    Args:
+        img: channel first array, must have shape: (num_channels, H[, W, ..., ]).
+        factor: the zoom factor(s) to apply to each spatial dimension.
+        mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``,
+            ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
+            The interpolation mode. Defaults to ``self.mode``.
+            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
+        align_corners: This only has an effect when mode is
+            'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``.
+            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
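+
+    Note (editor's addition): the matrix recorded for ``zoom`` uses the
+    reciprocal of each factor, so a 2D call with ``factor=2`` records the scale
+    matrix ``diag(0.5, 0.5, 1)`` (see ``test_zoom`` in tests/test_atmostonce.py).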
+ + """ + + img_ = convert_to_tensor(img, track_meta=get_track_meta()) + input_shape = img_.shape if shape_override is None else shape_override + input_ndim = len(input_shape) - 1 + + zoom_factors = ensure_tuple_rep(factor, input_ndim) + zoom_factors = [1 / f for f in zoom_factors] + + mode_ = look_up_option(mode, GridSampleMode) + padding_mode_ = look_up_option(padding_mode, GridSamplePadMode) + dtype_ = get_equivalent_dtype(dtype or img_.dtype, torch.Tensor) + + transform = MatrixFactory.from_tensor(img_).scale(zoom_factors).matrix.matrix + im_extents = extents_from_shape(input_shape) + if keep_size is False: + im_extents = [transform @ e for e in im_extents] + shape_override_ = shape_from_extents(input_shape, im_extents) + else: + shape_override_ = input_shape + + metadata = { + "factor": zoom_factors, + "mode": mode_, + "padding_mode": padding_mode_, + "align_corners": align_corners, + "keep_size": keep_size, + "dtype": dtype_, + "im_extents": im_extents, + "shape_override": shape_override_ + } + return img_, transform, metadata + + +def rotate90( + img: torch.Tensor, + k: Optional[int] = 1, + spatial_axes: Optional[Tuple[int, int]] = (0, 1), + shape_override: Optional[bool] = None +): + if len(spatial_axes) != 2: + raise ValueError("'spatial_axes' must be a tuple of two integers indicating") + + img_ = convert_to_tensor(img, track_meta=get_track_meta()) + # axes = map_spatial_axes(img.ndim, spatial_axes) + # ori_shape = img.shape[1:] + input_shape = img_.shape if shape_override is None else shape_override + input_ndim = len(input_shape) - 1 + + transform = MatrixFactory.from_tensor(img_).rotate_90(k, ) + + metadata = { + "k": k, + "spatial_axes": spatial_axes, + "shape_override": shape_override + } + return img_, transform, metadata + + +def grid_distortion( + img: torch.Tensor, + num_cells: Union[Tuple[int], int], + distort_steps: Sequence[Sequence[float]], + mode: str = GridSampleMode.BILINEAR, + padding_mode: str = GridSamplePadMode.BORDER, + shape_override: Optional[Tuple[int]] = None +): + all_ranges = [] + num_cells = ensure_tuple_rep(num_cells, len(img.shape) - 1) + for dim_idx, dim_size in enumerate(img.shape[1:]): + dim_distort_steps = distort_steps[dim_idx] + ranges = torch.zeros(dim_size, dtype=torch.float32) + cell_size = dim_size // num_cells[dim_idx] + prev = 0 + for idx in range(num_cells[dim_idx] + 1): + start = int(idx * cell_size) + end = start + cell_size + if end > dim_size: + end = dim_size + cur = dim_size + else: + cur = prev + cell_size * dim_distort_steps[idx] + prev = cur + ranges = range - (dim_size - 1.0) / 2.0 + all_ranges.append() + coords = meshgrid_ij(*all_ranges) + grid = torch.stack([*coords, torch.ones_like(coords[0])]) + + metadata = { + "num_cells": num_cells, + "distort_steps": distort_steps, + "mode": mode, + "padding_mode": padding_mode + } + + return img, grid, metadata + + +def elastic_3d( + img: torch.Tensor, + sigma: float, + magnitude: float, + offsets: torch.Tensor, + spatial_size: Optional[Union[Tuple[int, int, int], int]] = None, + mode: str = GridSampleMode.BILINEAR, + padding_mode: str = GridSamplePadMode.REFLECTION, + device: Optional[torch.device] = None, + shape_override: Optional[Tuple[float]] = None +): + img_ = convert_to_tensor(img, track_meta=get_track_meta()) + + sp_size = fall_back_tuple(spatial_size, img.shape[1:]) + device_ = img.device if isinstance(img, torch.Tensor) else device + grid = create_grid(spatial_size=sp_size, device=device_, backend="torch") + gaussian = GaussianFilter(3, sigma, 3.0).to(device=device_) + 
+    grid[:3] += gaussian(offsets)[0] * magnitude
+
+    metadata = {
+        "sigma": sigma,
+        "magnitude": magnitude,
+        "offsets": offsets,
+    }
+    if spatial_size is not None:
+        metadata["spatial_size"] = spatial_size
+    if mode is not None:
+        metadata["mode"] = mode
+    if padding_mode is not None:
+        metadata["padding_mode"] = padding_mode
+    if shape_override is not None:
+        metadata["shape_override"] = shape_override
+
+    return img_, grid, metadata
+
+
+def translate(
+    img: torch.Tensor,
+    translation: Sequence[float],
+    mode: Optional[Union[GridSampleMode, str]] = GridSampleMode.BILINEAR,
+    padding_mode: Optional[Union[GridSamplePadMode, str]] = NumpyPadMode.EDGE,
+    dtype: Union[DtypeLike, torch.dtype] = np.float32,
+    shape_override: Optional[Sequence] = None
+):
+    img_ = convert_to_tensor(img, track_meta=get_track_meta())
+    input_shape = img_.shape if shape_override is None else shape_override
+    input_ndim = len(input_shape) - 1
+    if len(translation) != input_ndim:
+        raise ValueError(f"'translate' length {len(translation)} must be equal to 'img' "
+                         f"spatial dimensions of {input_ndim}")
+
+    transform = MatrixFactory.from_tensor(img).translate(translation).matrix.matrix
+    im_extents = extents_from_shape(input_shape)
+    im_extents = [transform @ e for e in im_extents]
+    # shape_override_ = shape_from_extents(input_shape, im_extents)
+
+    metadata = {
+        "translation": translation,
+        # record the requested mode and dtype so that apply() can honour them
+        "mode": mode,
+        "padding_mode": padding_mode,
+        "dtype": dtype,
+        "im_extents": im_extents,
+        # "shape_override": shape_override_
+    }
+    return img_, transform, metadata
+
+
+def croppad(
+    img: torch.Tensor,
+    slices: Union[Sequence[slice], slice],
+    padding_mode: Optional[Union[GridSamplePadMode, str]] = NumpyPadMode.EDGE,
+    shape_override: Optional[Sequence] = None
+):
+    img_ = convert_to_tensor(img, track_meta=get_track_meta())
+    input_shape = img_.shape if shape_override is None else shape_override
+    input_ndim = len(input_shape) - 1
+    if len(slices) != input_ndim:
+        raise ValueError(f"'slices' length {len(slices)} must be equal to 'img' "
+                         f"spatial dimensions of {input_ndim}")
+
+    img_centers = [i / 2 for i in input_shape[1:]]
+    slice_centers = [(s.stop + s.start) / 2 for s in slices]
+    deltas = [s - i for i, s in zip(img_centers, slice_centers)]
+    transform = MatrixFactory.from_tensor(img).translate(deltas).matrix.matrix
+    im_extents = extents_from_shape([input_shape[0]] + [s.stop - s.start for s in slices])
+    im_extents = [transform @ e for e in im_extents]
+    shape_override_ = shape_from_extents(input_shape, im_extents)
+
+    metadata = {
+        "slices": slices,
+        "padding_mode": padding_mode,
+        "dtype": img.dtype,
+        "im_extents": im_extents,
+        "shape_override": shape_override_
+    }
+    return img_, transform, metadata
diff --git a/monai/transforms/atmostonce/lazy_transform.py b/monai/transforms/atmostonce/lazy_transform.py
new file mode 100644
index 0000000000..4f2179dd2c
--- /dev/null
+++ b/monai/transforms/atmostonce/lazy_transform.py
@@ -0,0 +1,80 @@
+from monai.config import NdarrayOrTensor
+from monai.data import MetaTensor
+from monai.transforms import Randomizable
+from monai.transforms.atmostonce.apply import Applyd
+from monai.transforms.atmostonce.utility import ILazyTransform
+from monai.utils.mapping_stack import MetaMatrix
+
+
+# TODO: move to mapping_stack.py
+def push_transform(
+    data: MetaTensor,
+    meta_matrix: MetaMatrix
+):
+    data.push_pending_transform(meta_matrix)
+
+
+# TODO: move to mapping_stack.py
+def update_metadata(
+    data: MetaTensor,
+    transform: NdarrayOrTensor,
+    extra_info
+):
+    pass
+
+
+# TODO: move to utils
+def flatten_sequences(seq):
+
+    def flatten_impl(s, accum):
+        if isinstance(s, (list, tuple)):
+            for inner_t in s:
+                accum = flatten_impl(inner_t, accum)
+        else:
+            accum.append(s)
+        return accum
+
+    dest = []
+    for s in seq:
+        dest = flatten_impl(s, dest)
+
+    return dest
+
+
+def transforms_compatible(current, next):
+    raise NotImplementedError()
+
+
+def compile_lazy_transforms(transforms):
+    flat = flatten_sequences(transforms)
+    for i in range(len(flat)-1):
+        cur_t, next_t = flat[i], flat[i + 1]
+        if not transforms_compatible(cur_t, next_t):
+            flat.insert(i + 1, Applyd())
+    if not isinstance(flat[-1], Applyd):
+        # append an instance, not the Applyd class itself
+        flat.append(Applyd())
+    return flat
+
+
+def compile_cached_dataloading_transforms(transforms):
+    flat = flatten_sequences(transforms)
+    for i in range(len(flat)):
+        cur_t = flat[i]
+        if isinstance(cur_t, Randomizable):
+            # NOTE (editor's assumption): the body of this loop was left
+            # unfinished; the completion below assumes that everything before
+            # the first randomizable transform is deterministic and therefore
+            # cacheable, so pending transforms are materialised at that boundary
+            flat.insert(i, Applyd())
+            return flat
+    return flat
+
+
+class LazyTransform(ILazyTransform):
+
+    def __init__(self, lazy_evaluation):
+        self.lazy_evaluation = lazy_evaluation
+
+    # TODO: determine whether to have an 'eval' defined here that implements laziness
+    # def __call__(self, *args, **kwargs):
+    #     """Call this method after calculating your meta data"""
+    #     if self.lazily_evaluate:
+    #         # forward the transform to metatensor
+    #         pass
+    #     else:
+    #         # apply the transform and reset the stack on metatensor
+    #         pass
diff --git a/monai/transforms/atmostonce/randomizers.py b/monai/transforms/atmostonce/randomizers.py
new file mode 100644
index 0000000000..fbcb6381aa
--- /dev/null
+++ b/monai/transforms/atmostonce/randomizers.py
@@ -0,0 +1,105 @@
+import numpy as np
+
+import torch
+
+
+class Randomizer:
+
+    def __init__(
+            self,
+            prob: float = 1.0,
+            seed=None,
+            state=None
+    ):
+        self.R = None
+        self.set_random_state(seed, state)
+
+        if not 0.0 <= prob <= 1.0:
+            raise ValueError(f"'prob' must be between 0.0 and 1.0 inclusive but is {prob}")
+        self.prob = prob
+
+    def set_random_state(self, seed=None, state=None):
+        if seed is not None:
+            self.R = np.random.RandomState(seed)
+        elif state is not None:
+            self.R = state
+        else:
+            self.R = np.random.RandomState()
+
+    def do_random(self):
+        return self.R.uniform() < self.prob
+
+    def sample(self):
+        return self.R.uniform()
+
+
+class RotateRandomizer(Randomizer):
+
+    def __init__(
+            self,
+            range_x,
+            range_y,
+            range_z,
+            prob: float = 1.0,
+            seed=None,
+            state=None,
+    ):
+        # keep the (prob, seed, state) argument order of the base class
+        super().__init__(prob, seed, state)
+        self.range_x = range_x
+        self.range_y = range_y
+        self.range_z = range_z
+
+    def sample(
+            self,
+            data: torch.Tensor = None
+    ):
+        if not isinstance(data, (np.ndarray, torch.Tensor)):
+            raise ValueError("data must be a numpy ndarray or torch tensor but is of "
+                             f"type {type(data)}")
+
+        spatial_shape = len(data.shape[1:])
+        if spatial_shape == 2:
+            if self.do_random():
+                return self.R.uniform(self.range_x[0], self.range_x[1])
+            return 0.0
+        elif spatial_shape == 3:
+            if self.do_random():
+                x = self.R.uniform(self.range_x[0], self.range_x[1])
+                y = self.R.uniform(self.range_y[0], self.range_y[1])
+                z = self.R.uniform(self.range_z[0], self.range_z[1])
+                return x, y, z
+            return 0.0, 0.0, 0.0
+        else:
+            raise ValueError("data should be a tensor with 2 or 3 spatial dimensions but it "
+                             f"has {spatial_shape} spatial dimensions")
+
+
+class Elastic3DRandomizer(Randomizer):
+
+    def __init__(
+            self,
+            sigma_range,
+            magnitude_range,
+            prob=1.0,
+            grid_size=None,
+            seed=None,
+            state=None,
+    ):
+        super().__init__(prob, seed, state)
+        self.grid_size = grid_size
+        self.sigma_range = sigma_range
+        self.magnitude_range = magnitude_range
+
+    def sample(
+            self,
+            grid_size,
+            device
+    ):
+        if self.do_random():
+            rand_offsets = self.R.uniform(-1.0, 1.0, [3] + list(grid_size)).astype(np.float32, copy=False)
+            rand_offsets = torch.as_tensor(rand_offsets, device=device).unsqueeze(0)
+            sigma = self.R.uniform(self.sigma_range[0], self.sigma_range[1])
+            magnitude = self.R.uniform(self.magnitude_range[0], self.magnitude_range[1])
+            return rand_offsets, magnitude, sigma
+
+        return None, None, None
diff --git a/monai/transforms/atmostonce/utility.py b/monai/transforms/atmostonce/utility.py
new file mode 100644
index 0000000000..43e84ec4fa
--- /dev/null
+++ b/monai/transforms/atmostonce/utility.py
@@ -0,0 +1,119 @@
+from typing import Callable, Sequence
+
+import abc
+from abc import ABC
+
+import torch
+
+
+class ILazyTransform(abc.ABC):
+
+    # NOTE (editor's assumption): a default implementation backed by a private
+    # field; a property whose setter raises NotImplementedError would make any
+    # `self.lazy_evaluation = ...` assignment in implementing classes fail
+    @property
+    def lazy_evaluation(self):
+        return getattr(self, "_lazy_evaluation", False)
+
+    @lazy_evaluation.setter
+    def lazy_evaluation(self, lazy_evaluation):
+        self._lazy_evaluation = lazy_evaluation
+
+
+class IMultiSampleTransform(abc.ABC):
+    pass
+
+
+class IRandomizableTransform(abc.ABC):
+    pass
+
+
+class CacheMechanism(ABC):
+    """
+    The interface for caching mechanisms to be used with CachedTransformCompose. This
+    interface provides the ability to check whether cached objects are present, to test
+    for and fetch an item in a single call, and to store items. It makes no other
+    assumptions about the caching mechanism: capacity, cache eviction strategies, and
+    all other aspects of the cache implementation are left to the implementing class
+    """
+
+    def try_fetch(
+            self,
+            key
+    ):
+        raise NotImplementedError()
+
+    def store(
+            self,
+            key,
+            value
+    ):
+        raise NotImplementedError()
+
+
+class CachedTransformCompose:
+    """
+    CachedTransformCompose provides the functionality to cache the output of one or more
+    transforms such that they only need to be run once. Each time CachedTransformCompose
+    is run, it checks whether a cached entity is present; if it is, it loads that entity
+    and returns the resulting tensor / tensors as output. If that entity is not present
+    in the cache, it executes the transforms in its internal pipeline and caches the
+    result before returning it.
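+
+    A minimal usage sketch (editor's illustration; ``my_pipeline`` and ``my_cache``
+    are hypothetical stand-ins for a callable pipeline and a ``CacheMechanism``
+    implementation)::
+
+        cached = CachedTransformCompose(transforms=my_pipeline, cache=my_cache)
+        first = cached("case_001", sample)   # executes my_pipeline and stores the result
+        again = cached("case_001", sample)   # served straight from my_cache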
+ """ + + def __init__( + self, + transforms: Callable, + cache: CacheMechanism + ): + """ + Args: + transforms: A sequence of callable objects + cache: A caching mechanism that implements the `CacheMechanism` interface + """ + self.transforms = transforms + self.cache = cache + + def __call__( + self, + key, + *args, + **kwargs + ): + is_present, value = self.cache.try_fetch(key) + + if is_present: + return value + + result = self.transforms(*args, **kwargs) + self.cache.store(key, result) + + return result + + +class MultiSampleTransformCompose: + """ + MultiSampleTransformCompose takes the output of a transform that generates multiple samples + and executes each sample separately in a depth first fashion, gathering the results into an + array that is finally returned after all samples are processed + """ + def __init__( + self, + multi_sample: Callable, + transforms: Callable, + ): + self.multi_sample = multi_sample + self.transforms = transforms + + def __call__( + self, + t, + *args, + **kwargs + ): + output = list() + for mt in self.multi_sample(t): + mt_out = self.transforms(mt) + if isinstance(mt_out, (torch.Tensor, dict)): + output.append(mt_out) + elif isinstance(mt_out, list): + output += mt_out + else: + raise ValueError(f"self.transform must return a Tensor or list of Tensors, but returned {mt_out}") + + return output diff --git a/monai/transforms/atmostonce/utils.py b/monai/transforms/atmostonce/utils.py new file mode 100644 index 0000000000..ac8a1f1f68 --- /dev/null +++ b/monai/transforms/atmostonce/utils.py @@ -0,0 +1,55 @@ +from typing import Union + +import numpy as np + +import torch + +from monai.config import NdarrayOrTensor +from monai.utils.mapping_stack import Matrix, MetaMatrix + + +def matmul( + first: Union[MetaMatrix, Matrix, NdarrayOrTensor], + second: Union[MetaMatrix, Matrix, NdarrayOrTensor] +): + matrix_types = (MetaMatrix, Matrix, torch.Tensor, np.ndarray) + + if not isinstance(first, matrix_types): + raise TypeError(f"'first' must be one of {matrix_types} but is {type(first)}") + if not isinstance(second, matrix_types): + raise TypeError(f"'second' must be one of {matrix_types} but is {type(second)}") + + first_ = first + if isinstance(first_, MetaMatrix): + first_ = first_.matrix.matrix + elif isinstance(first_, Matrix): + first_ = first_.matrix + + second_ = second + if isinstance(second_, MetaMatrix): + second_ = second_.matrix.matrix + elif isinstance(second_, Matrix): + second_ = second_.matrix + + if isinstance(first_, np.ndarray): + if isinstance(second_, np.ndarray): + return first_ @ second_ + else: + return torch.from_numpy(first_) @ second_ + else: + if isinstance(second_, np.ndarray): + return first_ @ torch.from_numpy(second_) + else: + return first_ @ second_ + + +def value_to_tuple_range(value): + if isinstance(value, (tuple, list)): + if len(value) == 2: + return (value[0], value[1]) if value[0] <= value[1] else (value[1], value[0]) + elif len(value) == 1: + return -value[0], value[0] + else: + raise ValueError(f"parameter 'value' must be of length 1 or 2 but is {len(value)}") + else: + return -value, value diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 730cb634c0..4c284619bd 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -22,6 +22,7 @@ from monai import config, transforms from monai.config import KeysCollection from monai.data.meta_tensor import MetaTensor +from monai.transforms.atmostonce.utility import IRandomizableTransform from monai.utils import MAX_SEED, 
 from monai.utils.enums import TransformBackends
@@ -243,7 +244,7 @@ def __call__(self, data: Any):
         raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
 
 
-class RandomizableTransform(Randomizable, Transform):
+class RandomizableTransform(Randomizable, Transform, IRandomizableTransform):
     """
     An interface for handling random state locally, currently based on a class variable `R`,
     which is an instance of `np.random.RandomState`.
diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py
index ae550e7ce6..93939a98c4 100644
--- a/monai/transforms/utils.py
+++ b/monai/transforms/utils.py
@@ -758,6 +758,127 @@ def _create_rotate(
     raise ValueError(f"Unsupported spatial_dims: {spatial_dims}, available options are [2, 3].")
 
 
+def create_rotate_90(
+    spatial_dims: int,
+    axis: int,
+    steps: Optional[int] = 1,
+    device: Optional[torch.device] = None,
+    backend: str = TransformBackends.NUMPY,
+) -> NdarrayOrTensor:
+    """
+    create a 2D or 3D matrix for rotation by multiples of 90 degrees
+
+    Args:
+        spatial_dims: {``2``, ``3``} spatial rank
+        axis: the pair of spatial axes that define the plane of rotation
+        steps: the number of 90 degree steps to rotate by. Defaults to 1.
+        device: device to compute and store the output (when the backend is "torch").
+        backend: APIs to use, ``numpy`` or ``torch``.
+
+    Raises:
+        ValueError: When ``axis`` is not valid for the given ``spatial_dims``.
+        ValueError: When ``spatial_dims`` is not one of [2, 3].
+
+    """
+    _backend = look_up_option(backend, TransformBackends)
+    if _backend == TransformBackends.NUMPY:
+        return _create_rotate_90(
+            spatial_dims=spatial_dims,
+            axis=axis,
+            steps=steps,
+            eye_func=np.eye)
+    if _backend == TransformBackends.TORCH:
+        return _create_rotate_90(
+            spatial_dims=spatial_dims,
+            axis=axis,
+            steps=steps,
+            eye_func=lambda rank: torch.eye(rank, device=device),
+        )
+    raise ValueError(f"backend {backend} is not supported")
+
+
+def _create_rotate_90(
+    spatial_dims: int,
+    axis: Tuple[int, int],
+    steps: Optional[int] = 1,
+    eye_func: Callable = np.eye
+) -> NdarrayOrTensor:
+
+    values = [(1, 0, 0, 1),
+              (0, -1, 1, 0),
+              (-1, 0, 0, -1),
+              (0, 1, -1, 0)]
+
+    if spatial_dims == 2:
+        if axis != (0, 1):
+            raise ValueError(f"if 'spatial_dims' is 2, 'axis' must be (0, 1) but is {axis}")
+    elif spatial_dims == 3:
+        if axis not in ((0, 1), (0, 2), (1, 2)):
+            raise ValueError("if 'spatial_dims' is 3, 'axis' must be (0, 1), (0, 2), or (1, 2) "
+                             f"but is {axis}")
+    else:
+        raise ValueError(f"'spatial_dims' must be 2 or 3 but is {spatial_dims}")
+
+    steps_ = steps % 4
+
+    affine = eye_func(spatial_dims + 1)
+
+    if spatial_dims == 2:
+        a, b = 0, 1
+    else:
+        a, b = axis
+
+    # index with the step count normalised to [0, 4)
+    affine[a, a], affine[a, b], affine[b, a], affine[b, b] = values[steps_]
+    return affine
+
+
+def create_flip(
+    spatial_dims: int,
+    spatial_axis: Union[Sequence[int], int],
+    device: Optional[torch.device] = None,
+    backend: str = TransformBackends.NUMPY,
+) -> NdarrayOrTensor:
+    _backend = look_up_option(backend, TransformBackends)
+    if _backend == TransformBackends.NUMPY:
+        return _create_flip(
+            spatial_dims=spatial_dims,
+            spatial_axis=spatial_axis,
+            eye_func=np.eye)
+    if _backend == TransformBackends.TORCH:
+        return _create_flip(
+            spatial_dims=spatial_dims,
+            spatial_axis=spatial_axis,
+            eye_func=lambda rank: torch.eye(rank, device=device),
+        )
+    raise ValueError(f"backend {backend} is not supported")
+
+
+def _create_flip(
+    spatial_dims: int,
+    spatial_axis: Union[Sequence[int], int],
+    eye_func: Callable = np.eye
+):
+    affine = eye_func(spatial_dims + 1)
+    if isinstance(spatial_axis, int):
+        if spatial_axis < -spatial_dims or spatial_axis >= spatial_dims:
+            raise ValueError("'spatial_axis' values must be between "
+                             f"{-spatial_dims} and {spatial_dims-1} inclusive "
+                             f"('spatial_axis' is {spatial_axis})")
+        # normalise negative axes so the homogeneous row/column is never touched
+        affine[spatial_axis % spatial_dims, spatial_axis % spatial_dims] = -1
+    else:
+        if any((s < -spatial_dims or s >= spatial_dims) for s in spatial_axis):
+            raise ValueError("'spatial_axis' values must be between "
+                             f"{-spatial_dims} and {spatial_dims-1} inclusive "
+                             f"('spatial_axis' is {spatial_axis})")
+
+        spatial_axis_ = tuple(s % spatial_dims for s in spatial_axis)
+        for i in range(spatial_dims):
+            if i in spatial_axis_:
+                affine[i, i] = -1
+
+    return affine
+
+
 def create_shear(
     spatial_dims: int,
     coefs: Union[Sequence[float], float],
diff --git a/monai/utils/mapping_stack.py b/monai/utils/mapping_stack.py
new file mode 100644
index 0000000000..78f6bb3e8a
--- /dev/null
+++ b/monai/utils/mapping_stack.py
@@ -0,0 +1,201 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Sequence, Union
+
+import numpy as np
+
+import torch
+from monai.config import NdarrayOrTensor
+
+from monai.utils.enums import TransformBackends
+from monai.transforms.utils import (_create_rotate, _create_scale, _create_shear,
+                                    _create_translate, _create_rotate_90, _create_flip)
+from monai.utils.misc import get_backend_from_data, get_device_from_data
+
+
+def ensure_tensor(data: NdarrayOrTensor):
+    if isinstance(data, torch.Tensor):
+        return data
+
+    return torch.as_tensor(data)
+
+
+class MatrixFactory:
+
+    def __init__(self,
+                 dims: int,
+                 backend: TransformBackends,
+                 device: Optional[torch.device] = None):
+
+        if backend == TransformBackends.NUMPY:
+            if device is not None:
+                raise ValueError("'device' must be None with TransformBackends.NUMPY")
+            self._device = None
+            self._sin = np.sin
+            self._cos = np.cos
+            self._eye = np.eye
+            self._diag = np.diag
+        else:
+            if device is None:
+                raise ValueError("'device' must be set with TransformBackends.TORCH")
+            self._device = device
+            self._sin = lambda th: torch.sin(torch.as_tensor(th,
+                                                             dtype=torch.float64,
+                                                             device=self._device))
+            self._cos = lambda th: torch.cos(torch.as_tensor(th,
+                                                             dtype=torch.float64,
+                                                             device=self._device))
+            self._eye = lambda rank: torch.eye(rank,
+                                               device=self._device,
+                                               dtype=torch.float64)
+            self._diag = lambda size: torch.diag(torch.as_tensor(size,
+                                                                 device=self._device,
+                                                                 dtype=torch.float64))
+
+        self._backend = backend
+        self._dims = dims
+
+    @staticmethod
+    def from_tensor(data):
+        return MatrixFactory(len(data.shape)-1,
+                             get_backend_from_data(data),
+                             get_device_from_data(data))
+
+    def identity(self):
+        matrix = self._eye(self._dims + 1)
+        return MetaMatrix(matrix, {})
+
+    def rotate_euler(self, radians: Union[Sequence[float], float], **extra_args):
+        matrix = _create_rotate(self._dims, radians, self._sin, self._cos, self._eye)
+        return MetaMatrix(matrix, extra_args)
+
+    def rotate_90(self, rotations, axis, **extra_args):
+        # _create_rotate_90 takes (spatial_dims, axis, steps, eye_func), so the
+        # axis must precede the rotation count
+        matrix = _create_rotate_90(self._dims, axis, rotations, self._eye)
+        return MetaMatrix(matrix, extra_args)
+
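+    # Editor's note (illustration): each factory method wraps a homogeneous
+    # (dims + 1) x (dims + 1) matrix in a MetaMatrix; for example,
+    # MatrixFactory(2, TransformBackends.NUMPY).rotate_euler(torch.pi / 2)
+    # wraps [[0, -1, 0], [1, 0, 0], [0, 0, 1]], as exercised by
+    # TestMappingStack.test_rotation_pi_by_2 in tests/test_atmostonce.py.
+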
+    def flip(self, axis, **extra_args):
+        matrix = _create_flip(self._dims, axis, self._eye)
+        return MetaMatrix(matrix, extra_args)
+
+    def shear(self, coefs: Union[Sequence[float], float], **extra_args):
+        matrix = _create_shear(self._dims, coefs, self._eye)
+        return MetaMatrix(matrix, extra_args)
+
+    def scale(self, factors: Union[Sequence[float], float], **extra_args):
+        matrix = _create_scale(self._dims, factors, self._diag)
+        return MetaMatrix(matrix, extra_args)
+
+    def translate(self, offsets: Union[Sequence[float], float], **extra_args):
+        matrix = _create_translate(self._dims, offsets, self._eye)
+        return MetaMatrix(matrix, extra_args)
+
+
+class Mapping:
+
+    def __init__(self, matrix):
+        self._matrix = matrix
+
+    def apply(self, other):
+        return Mapping(other @ self._matrix)
+
+
+class Dimensions:
+
+    def __init__(self, flips, permutes):
+        raise NotImplementedError()
+
+    def __matmul__(self, other):
+        raise NotImplementedError()
+
+    def __rmatmul__(self, other):
+        raise NotImplementedError()
+
+
+class Matrix:
+
+    def __init__(self, matrix: NdarrayOrTensor):
+        self.matrix = ensure_tensor(matrix)
+
+    def __matmul__(self, other):
+        if isinstance(other, Matrix):
+            other_matrix = other.matrix
+        else:
+            other_matrix = other
+        return self.matrix @ other_matrix
+
+    def __rmatmul__(self, other):
+        return other.__matmul__(self.matrix)
+
+
+# TODO: remove if the existing Grid is fine for our purposes
+class Grid:
+    def __init__(self, grid):
+        raise NotImplementedError()
+
+    def __matmul__(self, other):
+        raise NotImplementedError()
+
+
+class MetaMatrix:
+
+    def __init__(self, matrix, metadata=None):
+        if not isinstance(matrix, (Matrix, Grid)):
+            matrix_ = Matrix(matrix)
+        else:
+            matrix_ = matrix
+        self.matrix = matrix_
+
+        self.metadata = metadata or {}
+
+    def __matmul__(self, other):
+        if isinstance(other, MetaMatrix):
+            other_ = other.matrix
+        else:
+            other_ = other
+        return MetaMatrix(self.matrix @ other_)
+
+    def __rmatmul__(self, other):
+        if isinstance(other, MetaMatrix):
+            other_ = other.matrix
+        else:
+            other_ = other
+        return MetaMatrix(other_ @ self.matrix)
+
+
+class MappingStack:
+    """
+    This class keeps track of a series of mappings, applies them, and calculates their
+    inverses (when the mappings are invertible). Mapping stacks are used to generate a
+    mapping that gets applied during a `Resample` / `Resampled` transform.
+
+    A mapping is one of:
+    - a description of a change to a numpy array that only requires index manipulation
+      instead of an actual resample
+    - a homogeneous matrix representing a geometric transform to be applied during a resample
+    - a field representing a deformation to be applied during a resample
+    """
+
+    def __init__(self, factory: MatrixFactory):
+        self.factory = factory
+        self.stack = []
+        self.applied_stack = []
+
+    def push(self, mapping):
+        self.stack.append(mapping)
+
+    def pop(self):
+        raise NotImplementedError()
+
+    def transform(self):
+        m = Mapping(self.factory.identity())
+        for t in self.stack:
+            m = m.apply(t)
+        return m
diff --git a/monai/utils/misc.py b/monai/utils/misc.py
index fc38dc5056..b2d5dd6e32 100644
--- a/monai/utils/misc.py
+++ b/monai/utils/misc.py
@@ -25,6 +25,7 @@
 import numpy as np
 import torch
 
+from monai.utils import TransformBackends
 from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike
 from monai.utils.module import version_leq
 
@@ -52,6 +53,8 @@
     "sample_slices",
     "check_parent_dir",
     "save_obj",
+    "get_backend_from_data",
+    "get_device_from_data",
 ]
 
 _seed = None
@@ -471,3 +474,23 @@ def save_obj(
             shutil.move(str(temp_path), path)
         except PermissionError:  # project-monai/monai issue #3613
             pass
+
+
+def get_device_from_data(data):
+    if isinstance(data, np.ndarray):
+        return None
+    elif isinstance(data, torch.Tensor):
+        return data.device
+    else:
+        msg = "'data' must be one of numpy ndarray or torch Tensor but is {}"
+        raise ValueError(msg.format(type(data)))
+
+
+def get_backend_from_data(data):
+    if isinstance(data, np.ndarray):
+        return TransformBackends.NUMPY
+    elif isinstance(data, torch.Tensor):
+        return TransformBackends.TORCH
+    else:
+        msg = "'data' must be one of numpy ndarray or torch Tensor but is {}"
+        raise ValueError(msg.format(type(data)))
diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py
index 9c9fb1a4b2..7e12639b76 100644
--- a/monai/utils/type_conversion.py
+++ b/monai/utils/type_conversion.py
@@ -10,7 +10,7 @@
 # limitations under the License.
 
 import re
-from typing import Any, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Optional, Sequence, Tuple, Type, TypeVar, Union
 
 import numpy as np
 import torch
@@ -34,8 +34,40 @@
     "convert_to_numpy",
     "convert_to_tensor",
     "convert_to_dst_type",
+    "expand_scalar_to_tuple"
 ]
 
+__dtype_dict = {
+    np.int8: 'int8',
+    torch.int8: 'int8',
+    np.int16: 'int16',
+    torch.int16: 'int16',
+    int: 'int32',
+    np.int32: 'int32',
+    torch.int32: 'int32',
+    np.int64: 'int64',
+    torch.int64: 'int64',
+    np.uint8: 'uint8',
+    torch.uint8: 'uint8',
+    np.uint16: 'uint16',
+    np.uint32: 'uint32',
+    np.uint64: 'uint64',
+    float: 'float32',
+    np.float16: 'float16',
+    # np.float is a removed NumPy alias for the builtin float, so it is
+    # deliberately not listed here
+    np.float32: 'float32',
+    np.float64: 'float64',
+    torch.float16: 'float16',
+    torch.float: 'float32',
+    torch.float32: 'float32',
+    torch.double: 'float64',
+    torch.float64: 'float64'
+}
+
+
+def dtypes_to_str_or_identity(dtype: Any) -> Any:
+    return __dtype_dict.get(dtype, dtype)
+
+
 def get_numpy_dtype_from_string(dtype: str) -> np.dtype:
     """Get a numpy dtype (e.g., `np.float32`) from its string (e.g., `"float32"`)."""
 
@@ -347,3 +379,28 @@ def convert_to_list(data: Union[Sequence, torch.Tensor, np.ndarray]) -> list:
 
     """
     return data.tolist() if isinstance(data, (torch.Tensor, np.ndarray)) else list(data)
+
+
+TValue = TypeVar('TValue')
+
+
+def expand_scalar_to_tuple(value: Union[Tuple[TValue], float],
+                           length: int):
+    """
+    If `value` is not a tuple, it will be converted to a tuple of the given `length`.
+    Otherwise, it is returned as is.
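+    For example, ``expand_scalar_to_tuple(0.5, 3)`` evaluates to ``(0.5, 0.5, 0.5)``,
+    while ``expand_scalar_to_tuple((1, 2), 2)`` returns ``(1, 2)`` unchanged.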
+    Note that if `value` is a tuple, its length must be the same as the `length`
+    parameter or the conversion will fail.
+    Args:
+        value: the value to be converted to a tuple if necessary
+        length: the length of the resulting tuple
+    Returns:
+        If `value` is already a tuple, then `value` is returned. Otherwise, return a tuple
+        of length `length`, each element of which is `value`.
+    """
+    if not isinstance(length, int):
+        raise ValueError("'length' must be an integer value")
+
+    if not isinstance(value, tuple):
+        return tuple(value for _ in range(length))
+    else:
+        if length != len(value):
+            raise ValueError("if 'value' is a tuple it must be the same length as 'length'")
+        return value
diff --git a/tests/test_atmostonce.py b/tests/test_atmostonce.py
new file mode 100644
index 0000000000..beabfe3da6
--- /dev/null
+++ b/tests/test_atmostonce.py
@@ -0,0 +1,1041 @@
+import unittest
+
+import math
+
+import numpy as np
+
+import torch
+
+from monai.transforms.atmostonce import array as amoa
+from monai.transforms.atmostonce import dictionary as amod
+from monai.transforms.atmostonce.array import Rotate, CropPad
+from monai.transforms.atmostonce.compose import ComposeCompiler
+from monai.transforms.atmostonce.lazy_transform import compile_lazy_transforms
+from monai.transforms.atmostonce.utils import value_to_tuple_range
+from monai.utils import TransformBackends
+
+from monai.transforms.spatial import array as spatialarray
+from monai.transforms import Affined, Affine, Flip, RandSpatialCropSamplesd, RandRotated
+from monai.transforms.atmostonce.functional import croppad, resize, rotate, zoom, spacing, flip
+from monai.transforms.atmostonce.apply import Applyd, extents_from_shape, shape_from_extents, apply, Apply
+from monai.transforms.atmostonce.dictionary import Rotated
+import monai.transforms.croppad.array as ocpa
+from monai.transforms.compose import Compose
+from monai.utils.enums import GridSampleMode, GridSamplePadMode
+from monai.utils.mapping_stack import MatrixFactory, MetaMatrix
+
+from monai.transforms.atmostonce.utility import CachedTransformCompose, CacheMechanism, MultiSampleTransformCompose, \
+    IMultiSampleTransform, IRandomizableTransform, ILazyTransform
+
+
+class FakeRand(np.random.RandomState):
+
+    def __init__(self,
+                 rands=tuple(),
+                 randints=tuple(),
+                 uniforms=tuple()
+                 ):
+        super().__init__()
+        self.rands = rands
+        self.randind = 0
+        self.randints = randints
+        self.randintind = 0
+        self.uniforms = uniforms
+        self.uniformind = 0
+
+    def rand(self, *_, **__):
+        value = self.rands[self.randind]
+        self.randind += 1
+        return value
+
+    def randint(self, *_, **__):
+        value = self.randints[self.randintind]
+        self.randintind += 1
+        return value
+
+    def uniform(self, *_, **__):
+        value = self.uniforms[self.uniformind]
+        self.uniformind += 1
+        return value
+
+
+def get_img(size, dtype=torch.float32, offset=0):
+    img = torch.zeros(size, dtype=dtype)
+    if len(size) == 2:
+        for j in range(size[0]):
+            for i in range(size[1]):
+                # stride by the row length (size[1]) so values are unique for
+                # non-square images as well
+                img[j, i] = i + j * size[1] + offset
+    else:
+        for k in range(size[0]):
+            for j in range(size[1]):
+                for i in range(size[2]):
+                    img[k, j, i] = i + j * size[2] + k * size[1] * size[2]
+    return np.expand_dims(img, 0)
+
+
+def enumerate_results_of_op(results):
+    if isinstance(results, dict):
+        for k, v in results.items():
+            if isinstance(v, (np.ndarray, torch.Tensor)):
+                print(k, v.shape, v[tuple(slice(0, 8) for _ in v.shape)])
+            else:
+                print(k, v)
+    else:
+        for ir, v in enumerate(results):
+            if isinstance(v, (np.ndarray, torch.Tensor)):
+                print(ir, v.shape, v[tuple(slice(0, 8) for _ in v.shape)])
+            else:
+                print(ir, v)
+
+
+def matrices_nearly_equal(actual, expected):
+    if actual.shape != expected.shape:
+        raise ValueError("actual matrix does not match expected matrix size; "
+                         f"{actual} vs {expected} respectively")
+    # return the element-wise comparison result
+    return np.allclose(actual, expected)
+
+
+def test_array_op_multi_sample(tester, op, img, expected):
+
+    s = 0
+    for actual in op(img):
+        e = expected[s]
+        s += 1
+        if op.lazy_evaluation is True:
+            actual = apply(actual)
+
+        if isinstance(e, dict):
+            for k, v in e.items():
+                if not torch.allclose(actual[k], v):
+                    print("torch.allclose test returned False")
+                    print(actual)
+                    print(e)
+                    tester.assertTrue(False)
+        else:
+            if not torch.allclose(actual, e):
+                print("torch.allclose test returned False")
+                print(actual)
+                print(e)
+                tester.assertTrue(False)
+
+
+def test_array_op(tester, op, img, expected):
+    actual = op(img)
+
+    if op.lazy_evaluation is True:
+        actual = apply(actual)
+
+    if not torch.allclose(actual, expected):
+        print("torch.allclose test returned False")
+        print(actual)
+        print(expected)
+        tester.assertTrue(False)
+
+
+class TestLowLevel(unittest.TestCase):
+
+    def test_extents_2(self):
+        actual = extents_from_shape([1, 24, 32])
+        expected = [np.asarray(v) for v in ((0, 0, 1), (0, 32, 1), (24, 0, 1), (24, 32, 1))]
+        self.assertTrue(np.all([np.array_equal(a, e) for a, e in zip(actual, expected)]))
+
+    def test_extents_3(self):
+        actual = extents_from_shape([1, 12, 16, 8])
+        expected = [np.asarray(v) for v in ((0, 0, 0, 1), (0, 0, 8, 1), (0, 16, 0, 1), (0, 16, 8, 1),
+                                            (12, 0, 0, 1), (12, 0, 8, 1), (12, 16, 0, 1), (12, 16, 8, 1))]
+        self.assertTrue(np.all([np.array_equal(a, e) for a, e in zip(actual, expected)]))
+
+    def test_shape_from_extents(self):
+        # shape_from_extents takes the source shape as its first argument
+        actual = shape_from_extents([1, 32, 40],
+                                    [np.asarray([-16, -20, 1]),
+                                     np.asarray([-16, 20, 1]),
+                                     np.asarray([16, -20, 1]),
+                                     np.asarray([16, 20, 1])])
+        print(actual)
+
+    def test_compile_transforms(self):
+        values = ["a", "b", ["c", ["d"], "e"], "f", ["g", "h"], "i"]
+        result = compile_lazy_transforms(values)
+        print(result)
+
+
+class TestMappingStack(unittest.TestCase):
+
+    def test_rotation_pi_by_2(self):
+
+        fac = MatrixFactory(2, TransformBackends.NUMPY)
+        mat = fac.rotate_euler(torch.pi / 2)
+        expected = np.asarray([[0, -1, 0],
+                               [1, 0, 0],
+                               [0, 0, 1]])
+        self.assertTrue(np.allclose(mat.matrix.matrix, expected))
+
+    def test_rotation_pi_by_4(self):
+
+        fac = MatrixFactory(2, TransformBackends.NUMPY)
+        mat = fac.rotate_euler(torch.pi / 4)
+        piby4 = math.cos(torch.pi / 4)
+        expected = np.asarray([[piby4, -piby4, 0],
+                               [piby4, piby4, 0],
+                               [0, 0, 1]])
+        self.assertTrue(np.allclose(mat.matrix.matrix, expected))
+
+    def test_rotation_pi_by_8(self):
+        fac = MatrixFactory(2, TransformBackends.NUMPY)
+        mat = fac.rotate_euler(torch.pi / 8)
+        cospi = math.cos(torch.pi / 8)
+        sinpi = math.sin(torch.pi / 8)
+        expected = np.asarray([[cospi, -sinpi, 0],
+                               [sinpi, cospi, 0],
+                               [0, 0, 1]])
+        self.assertTrue(np.allclose(mat.matrix.matrix, expected))
+
+    def scale_by_2(self):
+        fac = MatrixFactory(2, TransformBackends.NUMPY)
+        mat = fac.scale(2)
+        expected = np.asarray([[2, 0, 0],
+                               [0, 2, 0],
+                               [0, 0, 1]])
+        self.assertTrue(np.allclose(mat.matrix.matrix, expected))
+
+    # TODO: turn into proper test
+    def test_mult_matrices(self):
+
+        fac = MatrixFactory(2, TransformBackends.NUMPY)
+        matrix1 = fac.translate((-16, -16))
+        matrix2 = fac.rotate_euler(torch.pi / 4)
+
+        matrix12 = matrix1 @ matrix2
+        matrix21 = matrix2 @ matrix1
+
+        print("matrix12\n", matrix12.matrix.matrix)
+        print("matrix21\n", matrix21.matrix.matrix)
+
+        extents = extents_from_shape([1, 32, 32])
+
+        print("matrix1")
+        for e in extents:
+            print("  ", e, matrix1.matrix.matrix @ e)
+        print("matrix2")
+        for e in extents:
+            print("  ", e, matrix2.matrix.matrix @ e)
+        print("matrix12")
+        for e in extents:
+            print("  ", e, matrix12.matrix.matrix @ e)
+        print("matrix21")
+        for e in extents:
+            print("  ", e, matrix21.matrix.matrix @ e)
+
+
+class TestFunctional(unittest.TestCase):
+
+    def _test_functional_impl(self,
+                              op,
+                              image,
+                              params,
+                              expected_matrix):
+        r_image, r_transform, r_metadata = op(image, **params)
+        enumerate_results_of_op((r_image, r_transform, r_metadata))
+        self.assertTrue(torch.allclose(r_transform, expected_matrix))
+
+    # TODO: turn into proper test
+    def test_spacing(self):
+        kwargs = {
+            "pixdim": (0.5, 0.6), "src_pixdim": (1.0, 1.0), "diagonal": False,
+            "mode": "bilinear", "padding_mode": "border", "align_corners": None
+        }
+        expected_tx = torch.DoubleTensor([[2.0, 0.0, 0.0],
+                                          [0.0, 1.66666667, 0.0],
+                                          [0.0, 0.0, 1.0]])
+        self._test_functional_impl(spacing, get_img((24, 32)), kwargs, expected_tx)
+
+    # TODO: turn into proper test
+    def test_resize(self):
+        kwargs = {
+            "spatial_size": (40, 40), "size_mode": "all",
+            "mode": "bilinear", "align_corners": None
+        }
+        expected_tx = torch.DoubleTensor([[1.66666667, 0.0, 0.0],
+                                          [0.0, 1.25, 0.0],
+                                          [0.0, 0.0, 1.0]])
+        self._test_functional_impl(resize, get_img((24, 32)), kwargs, expected_tx)
+
+    # TODO: turn into proper test
+    def test_rotate(self):
+        kwargs = {
+            "angle": torch.pi / 4, "keep_size": True,
+            "mode": "bilinear", "padding_mode": "border"
+        }
+        expected_tx = torch.DoubleTensor([[0.70710678, -0.70710678, 0.0],
+                                          [0.70710678, 0.70710678, 0.0],
+                                          [0.0, 0.0, 1.0]])
+        self._test_functional_impl(rotate, get_img((24, 32)), kwargs, expected_tx)
+
+    def test_zoom(self):
+        # results = zoom(np.zeros((1, 64, 64), dtype=np.float32),
+        #                2,
+        #                "bilinear",
+        #                "zeros")
+        # enumerate_results_of_op(results)
+        kwargs = {
+            "factor": 2, "mode": "nearest", "padding_mode": "border", "keep_size": True
+        }
+        expected_tx = torch.DoubleTensor([[0.5, 0.0, 0.0],
+                                          [0.0, 0.5, 0.0],
+                                          [0.0, 0.0, 1.0]])
+        self._test_functional_impl(zoom, get_img((24, 32)), kwargs, expected_tx)
+
+    def _check_matrix(self, actual, expected):
+        # assert the comparison; evaluating np.allclose alone discards the result
+        self.assertTrue(np.allclose(actual, expected))
+
+    def _test_rotate_90_impl(self, values, keep_dims, expected):
+        results = rotate(np.zeros((1, 64, 64, 32), dtype=np.float32),
+                         values,
+                         keep_dims,
+                         "bilinear",
+                         "border")
+        # enumerate_results_of_op(results)
+        self._check_matrix(results[1], expected)
+
+    def test_rotate_d0_r1(self):
+        expected = np.asarray([[1, 0, 0, 0],
+                               [0, 0, -1, 0],
+                               [0, 1, 0, 0],
+                               [0, 0, 0, 1]])
+        self._test_rotate_90_impl((torch.pi / 2, 0, 0), True, expected)
+
+    def test_rotate_d0_r2(self):
+        expected = np.asarray([[1, 0, 0, 0],
+                               [0, -1, 0, 0],
+                               [0, 0, -1, 0],
+                               [0, 0, 0, 1]])
+        self._test_rotate_90_impl((torch.pi, 0, 0), True, expected)
+
+    def test_rotate_d0_r3(self):
+        expected = np.asarray([[1, 0, 0, 0],
+                               [0, 0, 1, 0],
+                               [0, -1, 0, 0],
+                               [0, 0, 0, 1]])
+        self._test_rotate_90_impl((3 * torch.pi / 2, 0, 0), True, expected)
+
+    def test_rotate_d2_r1(self):
+        expected = np.asarray([[0, -1, 0, 0],
+                               [1, 0, 0, 0],
+                               [0, 0, 1, 0],
+                               [0, 0, 0, 1]])
+        self._test_rotate_90_impl((0, 0, torch.pi / 2), True, expected)
+
+    def test_rotate_d2_r2(self):
+        expected = np.asarray([[-1, 0, 0, 0],
+                               [0, -1, 0, 0],
+                               [0, 0, 1, 0],
+                               [0, 0, 0, 1]])
+        self._test_rotate_90_impl((0, 0, torch.pi), True, expected)
+
+    def test_rotate_d2_r3(self):
+        expected = np.asarray([[0, 1, 0, 0],
+                               [-1, 0, 0, 0],
+                               [0, 0, 1, 0],
+                               [0, 0, 0, 1]])
+        self._test_rotate_90_impl((0, 0, 3 * torch.pi / 2), True, expected)
+
+    def test_croppad_identity(self):
+        img = get_img((16, 16)).astype(int)
+        results = croppad(img,
+                          (slice(0, 16), slice(0, 16)))
+        enumerate_results_of_op(results)
+        # croppad returns the concrete matrix and records the output shape
+        # (channel first) under 'shape_override'
+        m = results[1]
+        print(m)
+        result_size = results[2]['shape_override'][1:]
+        a = Affine(affine=m,
+                   padding_mode=GridSamplePadMode.ZEROS,
+                   spatial_size=result_size)
+        img_, _ = a(img)
+        print(img_)
+
+    def _croppad_impl(self, img_ext, slices, expected):
+        img = get_img(img_ext).astype(int)
+        results = croppad(img, slices)
+        enumerate_results_of_op(results)
+        m = results[1]
+        print(m)
+        result_size = results[2]['shape_override'][1:]
+        a = Affine(affine=m,
+                   padding_mode=GridSamplePadMode.ZEROS,
+                   spatial_size=result_size)
+        img_, _ = a(img)
+        if expected is None:
+            print(img_.numpy())
+        else:
+            self.assertTrue(torch.allclose(img_, expected))
+
+    def test_croppad_img_odd_crop_odd(self):
+        expected = torch.as_tensor([[63., 64., 65., 66., 67., 68., 69.],
+                                    [78., 79., 80., 81., 82., 83., 84.],
+                                    [93., 94., 95., 96., 97., 98., 99.],
+                                    [108., 109., 110., 111., 112., 113., 114.],
+                                    [123., 124., 125., 126., 127., 128., 129.]])
+        self._croppad_impl((15, 15), (slice(4, 9), slice(3, 10)), expected)
+
+    def test_croppad_img_odd_crop_even(self):
+        expected = torch.as_tensor([[63., 64., 65., 66., 67., 68.],
+                                    [78., 79., 80., 81., 82., 83.],
+                                    [93., 94., 95., 96., 97., 98.],
+                                    [108., 109., 110., 111., 112., 113.]])
+        self._croppad_impl((15, 15), (slice(4, 8), slice(3, 9)), expected)
+
+    def test_croppad_img_even_crop_odd(self):
+        expected = torch.as_tensor([[67., 68., 69., 70., 71., 72., 73.],
+                                    [83., 84., 85., 86., 87., 88., 89.],
+                                    [99., 100., 101., 102., 103., 104., 105.],
+                                    [115., 116., 117., 118., 119., 120., 121.],
+                                    [131., 132., 133., 134., 135., 136., 137.]])
+        self._croppad_impl((16, 16), (slice(4, 9), slice(3, 10)), expected)
+
+    def test_croppad_img_even_crop_even(self):
+        expected = torch.as_tensor([[67., 68., 69., 70., 71., 72.],
+                                    [83., 84., 85., 86., 87., 88.],
+                                    [99., 100., 101., 102., 103., 104.],
+                                    [115., 116., 117., 118., 119., 120.]])
+        self._croppad_impl((16, 16), (slice(4, 8), slice(3, 9)), expected)
+
+    def _test_flip_impl(self, dims, spatial_axis, expected, verbose=False):
+        if dims == 2:
+            img = get_img((32, 32))
+        else:
+            img = get_img((32, 32, 8))
+
+        actual = flip(img, spatial_axis=spatial_axis)
+        if verbose:
+            print("expected\n", expected)
+            print("actual\n", actual[1])
+        self.assertTrue(np.allclose(expected, actual[1]))
+
+    def test_flip(self):
+
+        tests = [
+            (2, None, {(0, 0): -1, (1, 1): -1}),
+            (2, 0, {(0, 0): -1}),
+            (2, 1, {(1, 1): -1}),
+            (2, (0,), {(0, 0): -1}),
+            (2, (1,), {(1, 1): -1}),
+            (2, (0, 1), {(0, 0): -1, (1, 1): -1}),
+            (3, None, {(0, 0): -1, (1, 1): -1, (2, 2): -1}),
+            (3, 0, {(0, 0): -1}),
+            (3, 1, {(1, 1): -1}),
+            (3, 2, {(2, 2): -1}),
+            (3, (0,), {(0, 0): -1}),
+            (3, (1,), {(1, 1): -1}),
+            (3, (2,), {(2, 2): -1}),
+            (3, (0, 1), {(0, 0): -1, (1, 1): -1}),
+            (3, (0, 2), {(0, 0): -1, (2, 2): -1}),
+            (3, (1, 2), {(1, 1): -1, (2, 2): -1}),
+            (3, (0, 1, 2), {(0, 0): -1, (1, 1): -1, (2, 2): -1}),
+        ]
+
+        for t in tests:
+            with self.subTest(f"{t}"):
+                expected = np.eye(t[0] + 1)
+                for ke, kv in t[2].items():
+                    expected[ke] = kv
+                self._test_flip_impl(t[0], t[1], expected)
+
+
+class TestArrayTransforms(unittest.TestCase):
+
+    def test_apply_function(self):
+        img = get_img((16, 16))
+        r = Rotate(torch.pi / 4,
+
+
+class TestArrayTransforms(unittest.TestCase):
+
+    def test_apply_function(self):
+        img = get_img((16, 16))
+        r = Rotate(torch.pi / 4,
+                   keep_size=False,
+                   mode="bilinear",
+                   padding_mode="zeros",
+                   lazy_evaluation=True)
+        c = CropPad((slice(4, 12), slice(6, 14)),
+                    lazy_evaluation=True)
+
+        img_r = r(img)
+        cur_op = img_r.peek_pending_transform()
+        img_rc = c(img_r,
+                   shape_override=cur_op.metadata.get("shape_override", None))
+
+        img_rca = apply(img_rc)
+
+    def test_rand_rotate(self):
+        r = amoa.RandRotate((-torch.pi / 4, torch.pi / 4),
+                            prob=0.0,
+                            keep_size=True,
+                            mode="bilinear",
+                            padding_mode="border",
+                            align_corners=False)
+        img = np.zeros((1, 32, 32), dtype=np.float32)
+        results = r(img)
+        enumerate_results_of_op(results)
+        enumerate_results_of_op(results.pending_transforms[-1].metadata)
+
+    def test_rand_zoom(self):
+        r = amoa.RandZoom(prob=1.0,
+                          min_zoom=0.9,
+                          max_zoom=1.1,
+                          mode="nearest",
+                          padding_mode="zeros",
+                          keep_size=True)
+
+        r.set_random_state(state=FakeRand((0.5,), (1.05,)))
+        img = np.zeros((1, 32, 32))
+        results = r(img)
+        enumerate_results_of_op(results)
+        enumerate_results_of_op(results.pending_transforms[-1].metadata)
+
+    # TODO: amo: add tests for matrix and result size
+    def test_croppad(self):
+        img = get_img((15, 15)).astype(int)
+        results = croppad(img, (slice(4, 8), slice(3, 9)))
+        enumerate_results_of_op(results)
+        m = results[1].matrix.matrix
+        # print(m)
+        result_size = results[2]['spatial_shape']
+        a = Affine(affine=m,
+                   padding_mode=GridSamplePadMode.ZEROS,
+                   spatial_size=result_size)
+        img_, _ = a(img)
+        # print(img_.numpy())
+
+    def test_rotate_apply_not_lazy(self):
+        r = amoa.Rotate(-torch.pi / 4,
+                        mode="bilinear",
+                        padding_mode="border",
+                        keep_size=False)
+        data = get_img((32, 32))
+        data = r(data)
+        # data = apply(data)
+        print(data.shape)
+        print(data)
+
+    def test_rotate_apply_lazy(self):
+        r = amoa.Rotate(-torch.pi / 4,
+                        mode="bilinear",
+                        padding_mode="border",
+                        keep_size=False)
+        r.lazy_evaluation = True
+        data = get_img((32, 32))
+        data = r(data)
+        data = apply(data)
+        expected = torch.DoubleTensor([[0.70710677, 0.70710677, 0.0, -15.61269784],
+                                       [-0.70710677, 0.70710677, 0.0, 15.5],
+                                       [0.0, 0.0, 1.0, 0.0],
+                                       [0.0, 0.0, 0.0, 1.0]])
+        self.assertTrue(torch.allclose(expected, data.affine))
+
+    def test_zoom_apply_lazy(self):
+        r = amoa.Zoom(2,
+                      mode="bilinear",
+                      padding_mode="border",
+                      keep_size=False)
+        r.lazy_evaluation = True
+        data = get_img((32, 32))
+        data = r(data)
+        data = apply(data)
+        expected = torch.DoubleTensor([[0.5, 0.0, 0.0, 11.75],
+                                       [0.0, 0.5, 0.0, 11.75],
+                                       [0.0, 0.0, 1.0, 0.0],
+                                       [0.0, 0.0, 0.0, 1.0]])
+        self.assertTrue(torch.allclose(expected, data.affine))
+
+    def test_crop_then_rotate_apply_lazy(self):
+        data = get_img((32, 32))
+        print(data.shape)
+
+        lc1 = amoa.CropPad(lazy_evaluation=True,
+                           padding_mode="zeros")
+        lr1 = amoa.Rotate(torch.pi / 4,
+                          keep_size=False,
+                          padding_mode="zeros",
+                          lazy_evaluation=False)
+        datas = []
+        datas.append(data)
+        data1 = lc1(data, slices=(slice(0, 16), slice(0, 16)))
+        datas.append(data1)
+        data2 = lr1(data1)
+        datas.append(data2)
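The lazy tests above all follow one pattern: with lazy_evaluation enabled, a transform only queues a pending MetaMatrix on the image, and apply() performs a single fused resample at the end. A sketch of that flow, assuming this patch's import layout (the array-transform module path is a guess based on the amoa alias used in these tests):

    import torch
    from monai.data import MetaTensor
    from monai.transforms.atmostonce.apply import apply
    import monai.transforms.atmostonce.array as amoa  # assumed module path

    r = amoa.Rotate(-torch.pi / 4, mode="bilinear", padding_mode="border", keep_size=False)
    r.lazy_evaluation = True

    img = MetaTensor(torch.zeros((1, 32, 32), dtype=torch.float32))
    img = r(img)                      # queues a pending transform; no resample yet
    assert img.has_pending_transforms
    img = apply(img)                  # one resample; img.affine now holds the rotation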
+
+
+class TestOldTransforms(unittest.TestCase):
+
+    def test_rand_zoom(self):
+        r = spatialarray.RandZoom(1.0, 0.9, 1.1)
+        t = torch.rand((1, 32, 32))
+        t_out = r(t)
+        print(t_out.shape)
+
+        r = spatialarray.RandZoom(1.0, (0.9, 0.9, 0.9), (1.1, 1.1, 1.1))
+        t_out = r(t)
+        print(t_out.shape)
+
+    def test_deform_grid(self):
+        r = spatialarray.Rand2DElastic((1, 1),
+                                       (0.1, 0.2),
+                                       1.0)
+        img = get_img((16, 16))
+        result = r(img)
+        print(result)
+
+    def test_center_spatial_crop(self):
+        r = ocpa.CenterSpatialCrop(4)
+        img = get_img((8, 8))
+        result = r(img)
+        print(result)
+
+        img = get_img((9, 9))
+        result = r(img)
+        print(result)
+
+
+class TestDictionaryTransforms(unittest.TestCase):
+
+    def test_rotate_numpy(self):
+        r = Rotated(('image', 'label'), [0.0, 1.0, 0.0])
+
+        d = {
+            'image': np.zeros((1, 64, 64, 32), dtype=np.float32),
+            'label': np.ones((1, 64, 64, 32), dtype=np.int8)
+        }
+        d = r(d)
+
+        for k, v in d.items():
+            if isinstance(v, np.ndarray):
+                print(k, v.shape)
+            else:
+                print(k, v)
+
+    def test_rotate_tensor(self):
+        r = Rotated(('image', 'label'), [0.0, 1.0, 0.0])
+
+        d = {
+            'image': torch.zeros((1, 64, 64, 32), device="cpu", dtype=torch.float32),
+            'label': torch.ones((1, 64, 64, 32), device="cpu", dtype=torch.int8)
+        }
+        d = r(d)
+
+        for k, v in d.items():
+            if isinstance(v, (np.ndarray, torch.Tensor)):
+                print(k, v.shape)
+            else:
+                print(k, v)
+
+    def test_rotate_apply(self):
+        c = Compose([
+            Rotated(('image', 'label'), (0.0, 3.14159265 / 2, 0.0)),
+            Applyd(('image', 'label'),
+                   modes=(GridSampleMode.BILINEAR, GridSampleMode.NEAREST),
+                   padding_modes=(GridSamplePadMode.BORDER, GridSamplePadMode.BORDER))
+        ])
+
+        image = torch.zeros((1, 16, 16, 4), device="cpu", dtype=torch.float32)
+        for y in range(image.shape[-2]):
+            for z in range(image.shape[-1]):
+                image[0, :, y, z] = y + z * 16
+        label = torch.ones((1, 16, 16, 4), device="cpu", dtype=torch.int8)
+        d = {
+            'image': image,
+            'label': label
+        }
+        # plt.imshow(d['image'][0, ..., d['image'].shape[-1]//2])
+        d = c(d)
+        # plt.imshow(d['image'][0, ..., d['image'].shape[-1]//2])
+        print(d['image'].shape)
+
+    def test_old_affine(self):
+        c = Compose([
+            Affined(('image', 'label'),
+                    rotate_params=(0.0, 0.0, 3.14159265 / 2))
+        ])
+
+        d = {
+            'image': torch.zeros((1, 64, 64, 32), device="cpu", dtype=torch.float32),
+            'label': torch.ones((1, 64, 64, 32), device="cpu", dtype=torch.int8)
+        }
+        d = c(d)
+        print(d['image'].shape)
+
+
+class TestUtils(unittest.TestCase):
+
+    def test_value_to_tuple_range(self):
+        self.assertTupleEqual(value_to_tuple_range(5), (-5, 5))
+        self.assertTupleEqual(value_to_tuple_range([5]), (-5, 5))
+        self.assertTupleEqual(value_to_tuple_range((5,)), (-5, 5))
+        self.assertTupleEqual(value_to_tuple_range([-2.1, 4.3]), (-2.1, 4.3))
+        self.assertTupleEqual(value_to_tuple_range((-2.1, 4.3)), (-2.1, 4.3))
+        self.assertTupleEqual(value_to_tuple_range([4.3, -2.1]), (-2.1, 4.3))
+        self.assertTupleEqual(value_to_tuple_range((4.3, -2.1)), (-2.1, 4.3))
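The assertions above pin down value_to_tuple_range's contract: a scalar or single-element sequence v becomes the symmetric range (-v, v), and a two-element sequence is returned sorted as (min, max). A sketch of equivalent logic, written only to make the contract concrete (not the actual implementation):

    def value_to_tuple_range_sketch(value):
        # scalar or single-element sequence -> symmetric range around zero
        if not isinstance(value, (list, tuple)):
            return -value, value
        if len(value) == 1:
            return -value[0], value[0]
        lo, hi = value
        return (lo, hi) if lo <= hi else (hi, lo)

    assert value_to_tuple_range_sketch(5) == (-5, 5)
    assert value_to_tuple_range_sketch((4.3, -2.1)) == (-2.1, 4.3)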
+
+
+class TestRotate(unittest.TestCase):
+
+    def _test_rotate_array_nonlazy(self, r, t, expected):
+        t_out = r(t)
+        self.assertTrue(torch.allclose(t_out.affine, expected))
+        self.assertFalse(t_out.has_pending_transforms)
+
+    def _test_rotate_array_lazy(self, r, t, expected):
+        t_out = r(t)
+        self.assertTrue(torch.allclose(t_out.affine, torch.eye(4, 4, dtype=torch.double)))
+        self.assertTrue(t_out.has_pending_transforms)
+        self.assertTrue(torch.allclose(t_out.peek_pending_transform().matrix.matrix, expected))
+
+    def test_rotate(self):
+        r = amoa.Rotate(torch.pi,
+                        keep_size=True,
+                        mode="nearest",
+                        padding_mode="zeros",
+                        lazy_evaluation=False)
+        t = get_img((16, 16))
+
+        expected = torch.eye(4, 4, dtype=torch.double)
+        expected[0, :] = torch.DoubleTensor([-1, 0, 0, 15])
+        expected[1, :] = torch.DoubleTensor([0, -1, 0, 15])
+        self._test_rotate_array_nonlazy(r, t, expected)
+
+    def test_rand_rotate(self):
+        r = amoa.RandRotate((0, torch.pi * 2),
+                            (0, torch.pi * 2),
+                            (0, torch.pi * 2),
+                            prob=0.5,
+                            keep_size=True,
+                            mode="nearest",
+                            padding_mode="zeros",
+                            lazy_evaluation=False)
+        t = get_img((16, 16))
+
+        expected = torch.eye(4, 4, dtype=torch.double)
+        expected[0, :] = torch.DoubleTensor([-1, 0, 0, 15])
+        expected[1, :] = torch.DoubleTensor([0, -1, 0, 15])
+        r.randomizer.set_random_state(state=FakeRand(uniforms=(0.25, torch.pi)))
+        self._test_rotate_array_nonlazy(r, t, expected)
+
+        expected = torch.eye(4, 4, dtype=torch.double)
+        r.randomizer.set_random_state(state=FakeRand(uniforms=(0.75, torch.pi)))
+        self._test_rotate_array_nonlazy(r, t, expected)
+
+        r.lazy_evaluation = True
+
+        expected = torch.eye(3, 3, dtype=torch.double)
+        expected[0, :] = torch.DoubleTensor([-1, 0, 0])
+        expected[1, :] = torch.DoubleTensor([0, -1, 0])
+        r.randomizer.set_random_state(state=FakeRand(uniforms=(0.25, torch.pi)))
+        self._test_rotate_array_lazy(r, t, expected)
+
+        expected = torch.eye(3, 3, dtype=torch.double)
+        r.randomizer.set_random_state(state=FakeRand(uniforms=(0.75, torch.pi)))
+        self._test_rotate_array_lazy(r, t, expected)
+
+
+class TestRand3DElastic(unittest.TestCase):
+
+    def test_array(self):
+        img = get_img((16, 16, 8))
+        r = amoa.Rand3DElastic((0.5, 1.5), (0, 1), 1.0, mode="nearest", padding_mode="zeros")
+        result = r(img)
+        print(result.shape)
+
+
+class TestCropPad(unittest.TestCase):
+
+    def _test_functional(self, targs, img, expected):
+        result, tx, md = amoa.croppad(img, **targs)
+        result.push_pending_transform(MetaMatrix(tx, md))
+        actual = apply(result)
+        if not torch.allclose(actual, expected):
+            print(actual)
+            print(expected)
+            self.fail("torch.allclose test returned False")
+
+    def _test_rand(self, targs, rng_fac, img, expected):
+        targs['lazy_evaluation'] = False
+        r = amoa.RandomCropPad(**targs)
+        r.set_random_state(state=rng_fac())
+        actual = r(img)
+
+        if not torch.allclose(actual, expected):
+            print(actual)
+            print(expected)
+            self.fail("torch.allclose test returned False")
+
+        targs['lazy_evaluation'] = True
+        r = amoa.RandomCropPad(**targs)
+        # a = amoa.apply()
+        r.set_random_state(state=rng_fac())
+        actual = amoa.apply(r(img))
+
+        if not torch.allclose(actual, expected):
+            print(actual)
+            print(expected)
+            self.fail("torch.allclose test returned False")
+
+    def test_croppad_all_valid(self):
+        targs = {'slices': None, 'padding_mode': 'zeros'}
+        img = get_img((16, 16))
+        for j in range(8):
+            for i in range(8):
+                expected = torch.FloatTensor(
+                    [[(i + j * 16) + ii + jj * 16 for ii in range(8)] for jj in range(8)])
+                targs['slices'] = (slice(i, i + 8), slice(j, j + 8))
+                self._test_functional(targs, img, expected)
+
+    def test_randcroppad(self):
+        targs = {'sizes': (8, 8), 'prob': 1.0, 'padding_mode': 'zeros'}
+        rng_fac = lambda: FakeRand(rands=(0.5,), randints=(2, 6))
+        img = get_img((16, 16))
+        expected = torch.FloatTensor([[98 + i + j * 16 for i in range(8)] for j in range(8)])
+
+        self._test_rand(targs, rng_fac, img, expected)
+
+    def test_randcroppad_ysmall(self):
+        targs = {'sizes': (8, 8), 'prob': 1.0, 'padding_mode': 'zeros'}
+        rng_fac = lambda: FakeRand(rands=(0.5,), randints=(6,))
+        img = get_img((16, 6))
+        expected = torch.FloatTensor([[102 + i + j * 16 for i in range(8)] for j in range(8)])
+
+        self._test_rand(targs, rng_fac, img, expected)
+
+    def test_rand_croppad(self):
+        r = amoa.RandomCropPad((8, 8), 1.0, padding_mode="zeros", lazy_evaluation=False)
+        rng = FakeRand(rands=(0.5,), randints=(2, 6))
+        r.set_random_state(state=rng)
+
+        img = get_img((16, 16))
+
+        actual = r(img)
+        expected = torch.FloatTensor([[102 + i + j * 16 for i in range(8)] for j in range(8)])
+        print(actual)
+        print(expected)
+        self.assertTrue(torch.allclose(actual, expected))
+
+    def test_randcroppadmulti(self):
+        op = amoa.RandomCropPadMultiSample((8, 8), 4, padding_mode="zeros", lazy_evaluation=False)
+        rng = FakeRand(rands=(0.5, 0.5, 0.5, 0.5), randints=(2, 6, 3, 5, 4, 4, 5, 3))
+        op.set_random_state(state=rng)
+        img = get_img((16, 16))
+        expected = [
+            torch.FloatTensor([[38 + i + j * 16 for i in range(8)] for j in range(8)]),
+            torch.FloatTensor([[53 + i + j * 16 for i in range(8)] for j in range(8)]),
+            torch.FloatTensor([[68 + i + j * 16 for i in range(8)] for j in range(8)]),
+            torch.FloatTensor([[83 + i + j * 16 for i in range(8)] for j in range(8)])
+        ]
+        test_array_op_multi_sample(self, op, img, expected)
+
+    def test_randcroppadmultid(self):
+        op = amod.RandomCropPadMultiSampled(('img', 'lbl'),
+                                            (8, 8),
+                                            4,
+                                            padding_mode="zeros",
+                                            lazy_evaluation=False)
+        rng = FakeRand(rands=(0.5, 0.5, 0.5, 0.5), randints=(2, 6, 3, 5, 4, 4, 5, 3))
+        op.set_random_state(state=rng)
+        img = get_img((16, 16))
+        lbl = get_img((16, 16))
+        d = {'img': img, 'lbl': lbl}
+        expected_ts = [
+            torch.FloatTensor([[38 + i + j * 16 for i in range(8)] for j in range(8)]),
+            torch.FloatTensor([[53 + i + j * 16 for i in range(8)] for j in range(8)]),
+            torch.FloatTensor([[68 + i + j * 16 for i in range(8)] for j in range(8)]),
+            torch.FloatTensor([[83 + i + j * 16 for i in range(8)] for j in range(8)])
+        ]
+        expected = [{'img': e, 'lbl': e} for e in expected_ts]
+        test_array_op_multi_sample(self, op, d, expected)
+
+
+# Utility transforms for compose compiler
+# =================================================================================================
+
+class TestMemoryCacheMechanism(CacheMechanism):
+
+    def __init__(self, max_count: int):
+        self.max_count = max_count
+        self.contents = dict()
+        self.order = list()
+
+    def try_fetch(self, key):
+        if key in self.contents:
+            return True, self.contents[key]
+
+        return False, None
+
+    def store(self, key, value):
+        if key in self.contents:
+            self.contents[key] = value
+        else:
+            if len(self.contents) >= self.max_count:
+                last = self.order.pop()
+                del self.contents[last]
+
+            self.contents[key] = value
+            self.order.append(key)
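TestMemoryCacheMechanism shows the CacheMechanism contract these tests rely on: try_fetch returns a (hit, value) pair and store inserts with bounded capacity. A sketch of how a caller such as CachedTransformCompose might drive it; the producer function is hypothetical and stands in for any expensive transform:

    cache = TestMemoryCacheMechanism(max_count=4)

    def fetch_or_compute(key, producer):
        # consult the cache first; compute and store only on a miss
        hit, value = cache.try_fetch(key)
        if not hit:
            value = producer(key)  # hypothetical expensive computation
            cache.store(key, value)
        return value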
+
+
+class TestUtilityTransforms(unittest.TestCase):
+
+    def test_cached_transform(self):
+
+        def generate_noise(shape):
+            def _inner(*args, **kwargs):
+                return np.random.normal(size=shape)
+            return _inner
+
+        ct = CachedTransformCompose(transforms=generate_noise((1, 16, 16)),
+                                    cache=TestMemoryCacheMechanism(4))
+
+        first = ct("foo")
+        second = ct("foo")
+        third = ct("bar")
+
+        self.assertIs(first, second)
+        self.assertIsNot(first, third)
+
+    def test_multi_transform(self):
+
+        def fake_multi_sample(keys, num_samples, roi_size):
+            def _inner(t):
+                for i in range(num_samples):
+                    yield {'image': t['image'][0:1, i:i + roi_size[0], i:i + roi_size[1]]}
+            return _inner
+
+        # t1 = RandSpatialCropSamplesd(keys=('image',), num_samples=4, roi_size=(32, 32))
+        t1 = fake_multi_sample(keys=('image',), num_samples=4, roi_size=(32, 32))
+        t2 = RandRotated(keys=('image',), range_z=(-torch.pi / 2, torch.pi / 2))
+        mst = MultiSampleTransformCompose(t1, Compose([t2]))
+        c = Compose([mst])
+
+        d = torch.rand((1, 64, 64))
+
+        _d = d.data
+        _dd = d.data.clone()
+        d.data = _dd
+        r = c({'image': d})
+        print(r)
+
+    def test_compile_caching(self):
+        class NotRandomizable:
+            def __init__(self, name):
+                self.name = name
+
+            def __repr__(self):
+                return f"NR<{self.name}>"
+
+        class Randomizable(IRandomizableTransform):
+            def __init__(self, name):
+                self.name = name
+
+            def __repr__(self):
+                return f"R<{self.name}>"
+
+        a = NotRandomizable("a")
+        b = NotRandomizable("b")
+        c = Randomizable("c")
+        d = Randomizable("d")
+        e = NotRandomizable("e")
+
+        source_transforms = [a, b, c, d, e]
+
+        cc = ComposeCompiler()
+
+        actual = cc.compile_caching(source_transforms, CacheMechanism())
+
+        self.assertIsInstance(actual[0], CachedTransformCompose)
+        self.assertEqual(len(actual[0].transforms), 2)
+        self.assertEqual(actual[0].transforms[0], a)
+        self.assertEqual(actual[0].transforms[1], b)
+        self.assertEqual(actual[1], c)
+        self.assertEqual(actual[2], d)
+        self.assertEqual(actual[3], e)
+
+    def test_compile_multisampling(self):
+        class NotMultiSampling:
+            def __init__(self, name):
+                self.name = name
+
+            def __repr__(self):
+                return f"NMS<{self.name}>"
+
+        class MultiSampling(IMultiSampleTransform):
+            def __init__(self, name):
+                self.name = name
+
+            def __repr__(self):
+                return f"MS<{self.name}>"
+
+        a = NotMultiSampling("a")
+        b = NotMultiSampling("b")
+        c = MultiSampling("c")
+        d = NotMultiSampling("d")
+        e = MultiSampling("e")
+        f = NotMultiSampling("f")
+
+        source_transforms = [a, b, c, d, e, f]
+
+        cc = ComposeCompiler()
+
+        actual = cc.compile_multisampling(source_transforms)
+        print(actual)
+
+        self.assertEqual(actual[0], a)
+        self.assertEqual(actual[1], b)
+        self.assertIsInstance(actual[2], MultiSampleTransformCompose)
+        self.assertEqual(actual[2].multi_sample, c)
+        self.assertEqual(len(actual[2].transforms), 2)
+        self.assertEqual(actual[2].transforms[0], d)
+        self.assertIsInstance(actual[2].transforms[1], MultiSampleTransformCompose)
+        self.assertEqual(actual[2].transforms[1].multi_sample, e)
+        self.assertEqual(len(actual[2].transforms[1].transforms), 1)
+        self.assertEqual(actual[2].transforms[1].transforms[0], f)
+
+    def test_compile_lazy_resampling(self):
+        class NotLazy:
+            def __init__(self, name):
+                self.name = name
+
+            def __repr__(self):
+                return f"NL<{self.name}>"
+
+        class Lazy(ILazyTransform):
+            def __init__(self, name):
+                self.name = name
+
+            def __repr__(self):
+                return f"L<{self.name}>"
+
+        a = NotLazy("a")
+        b = Lazy("b")
+        c = Lazy("c")
+        d = NotLazy("d")
+        e = Lazy("e")
+        f = Lazy("f")
+
+        source_transforms = [a, b, c, d, e, f]
+
+        cc = ComposeCompiler()
+
+        actual = cc.compile_lazy_resampling(source_transforms)
+
+        print(actual)
\ No newline at end of file
diff --git a/tests/test_create_grid_and_affine.py b/tests/test_create_grid_and_affine.py
index d70db45468..87901494cc 100644
--- a/tests/test_create_grid_and_affine.py
+++ b/tests/test_create_grid_and_affine.py
@@ -22,6 +22,7 @@
     create_shear,
     create_translate,
 )
+from monai.transforms.utils import create_rotate_90
 from tests.utils import assert_allclose, is_tf32_env
@@ -220,6 +221,22 @@ def test_create_rotate(self):
             np.array([[0.0, -1.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]),
         )
+    def test_create_rotate_90(self):
+        expected = np.eye(3)
+        test_assert(create_rotate_90, (2, (0, 1), 0), expected)
+
+        expected = np.eye(3)
+        expected[0:2, 0:2] = [[0, -1], [1, 0]]
+        test_assert(create_rotate_90, (2, (0, 1), 1), expected)
+
+        expected = np.eye(3)
+        expected[0:2, 0:2] = [[-1, 0], [0, -1]]
+        test_assert(create_rotate_90, (2, (0, 1), 2), expected)
+
+        expected = np.eye(3)
+        expected[0:2, 0:2] = [[0, 1], [-1, 0]]
+        test_assert(create_rotate_90, (2, (0, 1), 3), expected)
+
     def test_create_shear(self):
         test_assert(create_shear, (2, 1.0), np.array([[1.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]))
         test_assert(create_shear, (2, (2.0, 3.0)), np.array([[1.0, 2.0, 0.0], [3.0, 1.0, 0.0], [0.0, 0.0, 1.0]]))
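The create_rotate_90 expectations added above are just the integer powers of the single quarter-turn matrix, which gives a quick way to sanity-check the k = 2 and k = 3 cases:

    import numpy as np

    r90 = np.eye(3)
    r90[0:2, 0:2] = [[0, -1], [1, 0]]  # the k = 1 case from test_create_rotate_90

    # k = 2 and k = 3 follow as matrix powers of the quarter turn
    assert np.allclose(np.linalg.matrix_power(r90, 2)[0:2, 0:2], [[-1, 0], [0, -1]])
    assert np.allclose(np.linalg.matrix_power(r90, 3)[0:2, 0:2], [[0, 1], [-1, 0]])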
diff --git a/tests/test_mapping_stack.py b/tests/test_mapping_stack.py
new file mode 100644
index 0000000000..6c6fb1ec11
--- /dev/null
+++ b/tests/test_mapping_stack.py
@@ -0,0 +1,30 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from monai.transforms.utils import TransformBackends
+
+from monai.utils.mapping_stack import MappingStack, MatrixFactory
+
+
+class MappingStackTest(unittest.TestCase):
+
+    def test_scale_then_translate(self):
+        f = MatrixFactory(3, TransformBackends.NUMPY)
+        m_scale = f.scale((2, 2, 2))
+        m_trans = f.translate((20, 20, 0))
+        ms = MappingStack(f)
+        ms.push(m_scale)
+        ms.push(m_trans)
+
+        print(ms.transform()._matrix)
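For reference, the product the scale-then-translate test should yield can be checked by hand with plain homogeneous matrices. This sketch assumes MappingStack applies the translate after the scale; the actual multiplication order depends on the MappingStack implementation:

    import numpy as np

    scale = np.diag([2.0, 2.0, 2.0, 1.0])    # scale by 2 on each spatial axis
    translate = np.eye(4)
    translate[0:3, 3] = [20.0, 20.0, 0.0]    # shift by (20, 20, 0)

    combined = translate @ scale             # translate applied after scale
    # the scale stays on the diagonal; the translation sits in the last column
    print(combined)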