Skip to content

[Prototype] 4855 compose with lazy resampling #4911

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 58 additions & 2 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
from monai.utils import look_up_option
from monai.utils.enums import MetaKeys, PostFix, SpaceKeys
from monai.utils.type_conversion import convert_data_type, convert_to_tensor
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor

__all__ = ["MetaTensor"]

Expand Down Expand Up @@ -125,7 +125,7 @@ def __init__(
super().__init__()
# set meta
if meta is not None:
self.meta = meta
self.meta = dict(meta)
elif isinstance(x, MetaObj):
self.__dict__ = deepcopy(x.__dict__)
# set the affine
Expand All @@ -150,6 +150,62 @@ def __init__(

if MetaKeys.SPACE not in self.meta:
self.meta[MetaKeys.SPACE] = SpaceKeys.RAS # defaulting to the right-anterior-superior space
if MetaKeys.EVALUATED not in self.meta:
self.meta[MetaKeys.EVALUATED] = True
if MetaKeys.ORIGINAL_CHANNEL_DIM not in self.meta:
self.meta[MetaKeys.ORIGINAL_CHANNEL_DIM] = "no_channel" # defaulting to channel first

@property
def evaluated(self) -> bool:
    """Whether the array content is up-to-date with the affine/spatial_shape metadata.

    A tensor that has never been flagged defaults to the eager (evaluated) state.
    """
    # setdefault both backfills the missing flag and reads it in one step
    return bool(self.meta.setdefault(MetaKeys.EVALUATED, True))

@evaluated.setter
def evaluated(self, value: bool):
    """Switch the tensor between eager (``True``) and lazy (``False``) status.

    Moving from eager to lazy stores the current affine/spatial shape so that a
    later ``evaluate()`` call can resample from the original grid.
    """
    if not value:
        # lazy mode is only meaningful when both pieces of spatial metadata exist
        has_spatial_meta = MetaKeys.SPATIAL_SHAPE in self.meta and MetaKeys.AFFINE in self.meta
        if not has_spatial_meta:
            warnings.warn("Setting MetaTensor to lazy evaluation requires spatial_shape and affine.")
        if self.evaluated:
            # snapshot the eager-state geometry before going lazy
            self.meta[MetaKeys.ORIGINAL_AFFINE] = self.affine
            self.meta[MetaKeys.SPATIAL_SHAPE] = self.spatial_shape
    self.meta[MetaKeys.EVALUATED] = value

def evaluate(self, mode="bilinear", padding_mode="border"):
    """Materialize any pending lazy spatial operations by resampling the array.

    If the tensor is already evaluated, only refreshes the stored spatial shape
    from the array. Otherwise, resamples the array from the stored original
    affine to the current (target) affine using ``SpatialResample``.

    Args:
        mode: interpolation mode passed to ``SpatialResample``.
        padding_mode: padding mode passed to ``SpatialResample``.
    """
    if self.evaluated:
        # keep the cached spatial shape consistent with the actual array
        # (assumes dim 0 is the channel dim — TODO confirm, see note below)
        self.spatial_shape = self.array.shape[1:]
        return
    # how to ensure channel first?
    resampler = monai.transforms.SpatialResample(mode=mode, padding_mode=padding_mode)
    # target grid is the tensor's current affine; resample starts from the
    # affine captured when the tensor was switched to lazy status
    dst_affine, self.affine = self.affine, self.meta[MetaKeys.ORIGINAL_AFFINE]
    # disable transform tracing: this resample realizes pending ops and must
    # not itself be recorded as an applied transform
    with resampler.trace_transform(False):
        output = resampler(self, dst_affine=dst_affine, spatial_size=self.spatial_shape, align_corners=True)
    self.array = output.array
    self.spatial_shape = self.array.shape[1:]
    self.affine = dst_affine
    self.evaluated = True
    return

@property
def spatial_shape(self):
    """Spatial shape of the tensor.

    When the shape is missing from ``meta`` it is inferred from the array shape
    with the original channel dim removed; the result is cached in ``meta`` as a
    CPU tensor.
    """
    if MetaKeys.SPATIAL_SHAPE in self.meta:
        shape = self.meta.get(MetaKeys.SPATIAL_SHAPE)
    else:
        shape = list(self.array.shape)
        ch_dim = self.meta.get(MetaKeys.ORIGINAL_CHANNEL_DIM, 0)
        if shape and ch_dim != "no_channel":
            # drop the channel axis so only spatial dims remain
            shape.pop(int(ch_dim))
    if not isinstance(shape, torch.Tensor):
        # normalize/cache as a plain CPU tensor (not a MetaTensor)
        self.meta[MetaKeys.SPATIAL_SHAPE] = convert_to_tensor(
            shape, device=torch.device("cpu"), wrap_sequence=True, track_meta=False
        )
    return self.meta[MetaKeys.SPATIAL_SHAPE]

@spatial_shape.setter
def spatial_shape(self, value):
    """Store ``value`` converted to the same type/device as the current spatial shape."""
    converted, *_ = convert_to_dst_type(value, self.spatial_shape, wrap_sequence=True)
    self.meta[MetaKeys.SPATIAL_SHAPE] = converted

@staticmethod
def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
Expand Down
59 changes: 58 additions & 1 deletion monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,44 @@

# For backwards compatibility (so this still works: from monai.transforms.compose import MapTransform)
from monai.transforms.transform import ( # noqa: F401
LazyTransform,
MapTransform,
Randomizable,
RandomizableTransform,
Transform,
apply_transform,
)
from monai.utils import MAX_SEED, ensure_tuple, get_seed
from monai.utils.enums import TraceKeys
from monai.utils.enums import GridSampleMode, GridSamplePadMode, TraceKeys

__all__ = ["Compose", "OneOf"]


def eval_lazy_stack(
    data, upcoming, lazy_resample: bool = False, mode=GridSampleMode.BILINEAR, padding_mode=GridSamplePadMode.BORDER
):
    """
    Given the upcoming transform ``upcoming``, if ``lazy_resample`` is True, recursively walk
    ``data`` (MetaTensors, mappings, lists/tuples) and evaluate lazily applied operations so the
    returned ``data`` is ready for the ``upcoming`` transform.

    Args:
        data: input data; a ``MetaTensor``, a mapping, a list/tuple, or anything else (returned as-is).
        upcoming: the next transform to be applied (or ``None`` at the end of a pipeline).
        lazy_resample: when ``False``, this is a no-op and ``data`` is returned unchanged.
        mode: interpolation mode used when evaluating pending operations.
        padding_mode: padding mode used when evaluating pending operations.
    """
    if not lazy_resample:
        return data  # eager evaluation: nothing deferred to resolve
    if isinstance(data, monai.data.MetaTensor):
        # a LazyTransform can consume the lazy representation directly; anything
        # else needs the pending operations materialized first
        if not isinstance(upcoming, LazyTransform):
            data.evaluate(mode=mode, padding_mode=padding_mode)
        return data
    if isinstance(data, Mapping):
        if isinstance(upcoming, MapTransform):
            # only the keys the upcoming MapTransform reads need evaluating
            return {
                k: eval_lazy_stack(v, upcoming, lazy_resample, mode, padding_mode) if k in upcoming.keys else v
                for k, v in data.items()
            }
        return {k: eval_lazy_stack(v, upcoming, lazy_resample, mode, padding_mode) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        # NOTE: tuples are intentionally returned as lists (matches prior behavior)
        return [eval_lazy_stack(v, upcoming, lazy_resample, mode, padding_mode) for v in data]
    return data


class Compose(Randomizable, InvertibleTransform):
"""
``Compose`` provides the ability to chain a series of callables together in
Expand Down Expand Up @@ -111,6 +137,16 @@ class Compose(Randomizable, InvertibleTransform):
log_stats: whether to log the detailed information of data and applied transform when error happened,
for NumPy array and PyTorch Tensor, log the data shape and value range,
for other metadata, log the values directly. default to `False`.
lazy_resample: whether to compute consecutive spatial transforms resampling lazily. Default to False.
mode: {``"bilinear"``, ``"nearest"``}
Interpolation mode when ``lazy_resample=True``. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
When `USE_COMPILED` is `True`, this argument uses
``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations.
See also: https://docs.monai.io/en/stable/networks.html#grid-pull
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values when ``lazy_resample=True``. Defaults to ``"border"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html

"""

Expand All @@ -120,15 +156,26 @@ def __init__(
map_items: bool = True,
unpack_items: bool = False,
log_stats: bool = False,
lazy_resample: bool = False,
mode=GridSampleMode.BILINEAR,
padding_mode=GridSamplePadMode.BORDER,
) -> None:
if transforms is None:
transforms = []
self.transforms = ensure_tuple(transforms)
self.map_items = map_items
self.unpack_items = unpack_items
self.log_stats = log_stats
self.lazy_resample = lazy_resample
self.mode = mode
self.padding_mode = padding_mode
self.set_random_state(seed=get_seed())

if self.lazy_resample:
for t in self.flatten().transforms: # TODO: test Compose of Compose/OneOf
if isinstance(t, LazyTransform):
t.set_eager_mode(False)

def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> "Compose":
super().set_random_state(seed=seed, state=state)
for _transform in self.transforms:
Expand Down Expand Up @@ -171,7 +218,17 @@ def __len__(self):

def __call__(self, input_):
    """Apply each transform in sequence.

    When ``lazy_resample`` is enabled, pending lazy operations are evaluated
    before any transform that cannot consume them, and once more after the
    final transform so the returned data is fully materialized.
    """
    for xform in self.transforms:
        input_ = eval_lazy_stack(
            input_,
            upcoming=xform,
            lazy_resample=self.lazy_resample,
            mode=self.mode,
            padding_mode=self.padding_mode,
        )
        input_ = apply_transform(xform, input_, self.map_items, self.unpack_items, self.log_stats)
    # no upcoming transform: force evaluation of anything still pending
    return eval_lazy_stack(
        input_, upcoming=None, lazy_resample=self.lazy_resample, mode=self.mode, padding_mode=self.padding_mode
    )

def inverse(self, data):
Expand Down
Loading