
[Early WIP] - Lazy resampling #4922


Closed

wants to merge 37 commits into from

37 commits
6c9e081
Pulling across more from standalone prototype
atbenmurray Jul 28, 2022
0b0670d
Applyd function
atbenmurray Aug 1, 2022
b1e6f3c
more atmostonce functionality; baseline atmostonce (non-dictionary) t…
atbenmurray Aug 11, 2022
02a95a4
Pulling across more from standalone prototype
atbenmurray Jul 28, 2022
4fa26c5
Applyd function
atbenmurray Aug 1, 2022
65fafa9
more atmostonce functionality; baseline atmostonce (non-dictionary) t…
atbenmurray Aug 11, 2022
3530432
Working on ground-up refactor of array transforms
atbenmurray Aug 15, 2022
d74fb56
Re-re-re-factored function / array / dict based rotate and others; re…
atbenmurray Aug 17, 2022
2dab7e5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 17, 2022
6c41489
Resolving merge conflicts
atbenmurray Aug 18, 2022
999698e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 18, 2022
7af9558
Simplified array and dictionary transforms; working on generic translate
atbenmurray Aug 19, 2022
e47b221
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 19, 2022
d53d95d
Base croppad debugged and tests added
atbenmurray Aug 19, 2022
b6a0b05
Partial implementation of non-metatensor based lazy resampling
atbenmurray Aug 19, 2022
77756c8
Local changes on 034
atbenmurray Aug 29, 2022
73b185b
Resolving conflicts
atbenmurray Aug 29, 2022
886057e
Minor fix to Apply; minor fix to enumerate_results_of_op
atbenmurray Aug 29, 2022
f4ff23d
Merge branch 'lazy_resampling' of github.com:project-monai/monai into…
atbenmurray Aug 29, 2022
1421773
Working on transform based compose compilers
atbenmurray Aug 30, 2022
57ae027
Work on rotate_90 functional, and associated tests
atbenmurray Aug 31, 2022
2f1bf91
Work on apply
atbenmurray Sep 1, 2022
7a8f1de
Resolving conflicts
atbenmurray Sep 1, 2022
d0b490b
bug fixes
atbenmurray Sep 6, 2022
97216af
Removing dead code from apply
atbenmurray Sep 7, 2022
bb3a60e
Additional work on array
atbenmurray Sep 7, 2022
623889c
Merge branch 'lazy_resampling' of github.com:project-monai/monai into…
atbenmurray Sep 7, 2022
f53a56e
More lazy transforms
atbenmurray Sep 9, 2022
eb7692d
Further work on transforms
atbenmurray Sep 14, 2022
b353e3e
Fixes for zoom and rotate; rename of spaced to spacingd; introduction…
atbenmurray Oct 7, 2022
7653631
Merge branch 'lazy_resampling' of github.com:project-monai/monai into…
atbenmurray Oct 7, 2022
bcbfb68
Removing unnecessary comments in apply
atbenmurray Oct 7, 2022
af043e9
Adding utility transforms for compose compiler
atbenmurray Oct 12, 2022
6e8b745
Resolving merge conflicts
atbenmurray Oct 12, 2022
6295e02
Adding RandCropPad and RandCropPadMultiSample array implementations
atbenmurray Oct 14, 2022
b76e965
Compose compile; initial multisample generic croppad (array and dict)…
atbenmurray Oct 19, 2022
62f8172
More work towards lazy resampling
atbenmurray Oct 26, 2022
26 changes: 26 additions & 0 deletions monai/data/meta_tensor.py
@@ -11,6 +11,8 @@

from __future__ import annotations

import copy

import warnings
from copy import deepcopy
from typing import Any, Sequence
@@ -151,6 +153,30 @@ def __init__(
        if MetaKeys.SPACE not in self.meta:
            self.meta[MetaKeys.SPACE] = SpaceKeys.RAS  # defaulting to the right-anterior-superior space

        self._pending_transforms = list()

    def push_pending_transform(self, meta_matrix):
        # queue a transform to be applied later rather than resampling now
        self._pending_transforms.append(meta_matrix)

    @property
    def has_pending_transforms(self):
        return len(self._pending_transforms) > 0

    def peek_pending_transform(self):
        # return a copy of the most recently pushed transform without removing it
        return copy.deepcopy(self._pending_transforms[-1])

    def pop_pending_transform(self):
        # remove and return the oldest pending transform (FIFO order)
        transform = self._pending_transforms[0]
        self._pending_transforms.pop(0)
        return transform

    @property
    def pending_transforms(self):
        # copies of the queued transforms, oldest first
        return copy.deepcopy(self._pending_transforms)

    def clear_pending_transforms(self):
        self._pending_transforms = list()

    @staticmethod
    def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
        """
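For context, a minimal sketch of how this pending-transform queue might be exercised (illustrative only; a real lazy transform would push a MetaMatrix, for which a string placeholder stands in here since the queue does not inspect its contents):

import torch
from monai.data import MetaTensor

img = MetaTensor(torch.rand(1, 32, 32, 32))

# a lazy transform would push a MetaMatrix describing its effect
img.push_pending_transform("placeholder-for-a-MetaMatrix")

if img.has_pending_transforms:
    queued = img.pending_transforms   # deep copies, oldest first
    img.clear_pending_transforms()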
Empty file.
222 changes: 222 additions & 0 deletions monai/transforms/atmostonce/apply.py
@@ -0,0 +1,222 @@
import itertools as it
from typing import Optional, Sequence, Union

import numpy as np
import torch

from monai.config import DtypeLike
from monai.data import MetaTensor
from monai.transforms import Affine
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import MapTransform
from monai.transforms.atmostonce.utils import matmul
from monai.utils import GridSampleMode, GridSamplePadMode
from monai.utils.misc import get_backend_from_data, get_device_from_data
from monai.utils.mapping_stack import MatrixFactory, MetaMatrix, Matrix

# TODO: This should move to a common place to be shared with dictionary
from monai.utils.type_conversion import dtypes_to_str_or_identity

GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str]
GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str]
DtypeSequence = Union[Sequence[DtypeLike], DtypeLike]


# TODO: move to mapping_stack.py
def extents_from_shape(shape, dtype=np.float64):
    # build the corner points of the spatial extent (shape[0] is the channel
    # dimension) as homogeneous coordinates
    extents = [[0, shape[i]] for i in range(1, len(shape))]

    extents = it.product(*extents)
    return list(np.asarray(e + (1,), dtype=dtype) for e in extents)


# TODO: move to mapping_stack.py
def shape_from_extents(
        src_shape: Sequence,
        extents: Union[Sequence[np.ndarray], Sequence[torch.Tensor], np.ndarray, torch.Tensor]
):
    # normalize the extents to a single tensor of corner points
    if isinstance(extents, (list, tuple)):
        if isinstance(extents[0], np.ndarray):
            aextents = np.asarray(extents)
            aextents = torch.from_numpy(aextents)
        else:
            aextents = torch.stack(extents)
    else:
        if isinstance(extents, np.ndarray):
            aextents = torch.from_numpy(extents)
        else:
            aextents = extents

    # the resulting shape is the rounded bounding box of the (possibly
    # transformed) corners, with the channel count carried over from src_shape
    mins = aextents.min(axis=0)[0]
    maxes = aextents.max(axis=0)[0]
    values = torch.round(maxes - mins).type(torch.IntTensor)[:-1]
    return torch.cat((torch.as_tensor([src_shape[0]]), values))
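As a worked example of the round trip between these two helpers (the values follow directly from the definitions above):

corners = extents_from_shape((1, 2, 3))
# [array([0., 0., 1.]), array([0., 3., 1.]),
#  array([2., 0., 1.]), array([2., 3., 1.])]

shape_from_extents((1, 2, 3), corners)
# tensor([1, 2, 3])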


def metadata_is_compatible(value_1, value_2):
    # two values are compatible if either is unset or they compare equal
    if value_1 is None or value_2 is None:
        return True
    return value_1 == value_2


def metadata_dtype_is_compatible(value_1, value_2):
    if value_1 is None or value_2 is None:
        return True

    # if we are here, value_1 and value_2 are both set
    # TODO: this is not a good enough solution
    value_1_ = dtypes_to_str_or_identity(value_1)
    value_2_ = dtypes_to_str_or_identity(value_2)
    return value_1_ == value_2_
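In other words (illustrative values):

metadata_is_compatible(None, "nearest")         # True: unset values never force a flush
metadata_is_compatible("bilinear", "bilinear")  # True
metadata_is_compatible("bilinear", "nearest")   # False: forces an intermediate resample in apply()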


def starting_matrix_and_extents(matrix_factory, data):
    # set up the identity matrix and metadata
    cumulative_matrix = matrix_factory.identity()
    cumulative_extents = extents_from_shape(data.shape)
    return cumulative_matrix, cumulative_extents


def prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype):
    # build a kwargs dict containing only the parameters that have been set
    kwargs = {}
    if cur_mode is not None:
        kwargs['mode'] = cur_mode
    if cur_padding_mode is not None:
        kwargs['padding_mode'] = cur_padding_mode
    if cur_device is not None:
        kwargs['device'] = cur_device
    if cur_dtype is not None:
        kwargs['dtype'] = cur_dtype

    return kwargs


def matrix_from_matrix_container(matrix):
    # unwrap MetaMatrix / Matrix containers down to the raw matrix
    if isinstance(matrix, MetaMatrix):
        return matrix.matrix.matrix
    elif isinstance(matrix, Matrix):
        return matrix.matrix
    else:
        return matrix


def apply(data: Union[torch.Tensor, MetaTensor, dict],
          pending: Optional[list] = None):

    if isinstance(data, dict):
        rd = dict()
        for k, v in data.items():
            result = apply(v)
            rd[k] = result
        return rd

    pending_ = data.pending_transforms if pending is None else pending

    if len(pending_) == 0:
        return data

    dim_count = len(data.shape) - 1
    matrix_factory = MatrixFactory(dim_count,
                                   get_backend_from_data(data),
                                   get_device_from_data(data))

    # set up the identity matrix and metadata
    cumulative_matrix, cumulative_extents = starting_matrix_and_extents(matrix_factory, data)

    # set the various resampling parameters to an initial state
    cur_mode = None
    cur_padding_mode = None
    cur_device = None
    cur_dtype = None
    cur_shape = data.shape

    for meta_matrix in pending_:
        next_matrix = meta_matrix.matrix
        # fold the next transform into the cumulative matrix and update the extents
        cumulative_matrix = matmul(cumulative_matrix, next_matrix)
        cumulative_extents = [matmul(e, cumulative_matrix) for e in cumulative_extents]

        new_mode = meta_matrix.metadata.get('mode', None)
        new_padding_mode = meta_matrix.metadata.get('padding_mode', None)
        new_device = meta_matrix.metadata.get('device', None)
        new_dtype = meta_matrix.metadata.get('dtype', None)
        new_shape = meta_matrix.metadata.get('shape_override', None)

        mode_compat = metadata_is_compatible(cur_mode, new_mode)
        padding_mode_compat = metadata_is_compatible(cur_padding_mode, new_padding_mode)
        device_compat = metadata_is_compatible(cur_device, new_device)
        dtype_compat = metadata_dtype_is_compatible(cur_dtype, new_dtype)

        if (mode_compat is False or padding_mode_compat is False or
                device_compat is False or dtype_compat is False):
            # carry out an intermediate resample here due to incompatibility between arguments
            kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype)

            cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix)
            a = Affine(norm_coords=False,
                       affine=cumulative_matrix_,
                       **kwargs)
            data, _ = a(img=data)

            # restart accumulation from the freshly resampled data
            cumulative_matrix, cumulative_extents = \
                starting_matrix_and_extents(matrix_factory, data)

        cur_mode = cur_mode if new_mode is None else new_mode
        cur_padding_mode = cur_padding_mode if new_padding_mode is None else new_padding_mode
        cur_device = cur_device if new_device is None else new_device
        cur_dtype = cur_dtype if new_dtype is None else new_dtype
        cur_shape = cur_shape if new_shape is None else new_shape

    kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype)

    cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix)

    # perform the single deferred resample with the accumulated matrix
    a = Affine(norm_coords=False,
               affine=cumulative_matrix_,
               spatial_size=cur_shape[1:],
               normalized=False,
               **kwargs)
    data, _ = a(img=data)
    data.clear_pending_transforms()

    return data
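A sketch of the intended call pattern (the lazy transform names are hypothetical; the point is that each pushes a MetaMatrix onto the tensor instead of resampling immediately):

img = some_lazy_rotate(img)   # hypothetical lazy transform: pushes a MetaMatrix
img = some_lazy_zoom(img)     # hypothetical lazy transform: pushes a MetaMatrix

# all queued transforms are fused and resampled in one interpolation pass
img = apply(img)
assert not img.has_pending_transforms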


# make Apply universal for arrays and dictionaries; it just calls through to functional apply
class Apply(InvertibleTransform):

    def __init__(self):
        super().__init__()

    def __call__(self, *args, **kwargs):
        return apply(*args, **kwargs)

    def inverse(self, data):
        raise NotImplementedError()


class Applyd(MapTransform, InvertibleTransform):

    def __init__(self):
        super().__init__()

    def __call__(
            self,
            d: dict
    ):
        rd = dict()
        for k, v in d.items():
            rd[k] = apply(v)
        return rd

    def inverse(self, data):
        raise NotImplementedError()
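Illustrative only: one way a trailing Apply might sit at the end of a pipeline (the lazy dictionary transforms named here are hypothetical placeholders, not part of this PR's API):

from monai.transforms import Compose

# hypothetical lazy pipeline: the spatial transforms only queue their
# matrices, and the trailing Apply() performs the single fused resample
pipeline = Compose([
    LazyRotated(keys=["img"], angle=0.5),   # hypothetical lazy transform
    LazyZoomd(keys=["img"], zoom=1.5),      # hypothetical lazy transform
    Apply(),
])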