Skip to content

RandomizableTransformType, LazyTransformType, and MultiSampleTransformType #5410

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 15 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,21 @@ Generic Interfaces
.. automodule:: monai.transforms
.. currentmodule:: monai.transforms

`RandomizableTransformType`
^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: RandomizableTransformType
:members:

`LazyTransformType`
^^^^^^^^^^^^^^^^^^^
.. autoclass:: LazyTransformType
:members:

`MultiSampleTransformType`
^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: MultiSampleTransformType
:members:

`Transform`
^^^^^^^^^^^
.. autoclass:: Transform
Expand Down
12 changes: 11 additions & 1 deletion monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,17 @@
ZoomD,
ZoomDict,
)
from .transform import MapTransform, Randomizable, RandomizableTransform, ThreadUnsafe, Transform, apply_transform
from .transform import (
LazyTransformType,
MapTransform,
MultiSampleTransformType,
Randomizable,
RandomizableTransform,
RandomizableTransformType,
ThreadUnsafe,
Transform,
apply_transform,
)
from .utility.array import (
AddChannel,
AddCoordinateChannels,
Expand Down
82 changes: 80 additions & 2 deletions monai/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,17 @@
from monai.utils.enums import TransformBackends
from monai.utils.misc import MONAIEnvVars

__all__ = ["ThreadUnsafe", "apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"]
__all__ = [
"ThreadUnsafe",
"apply_transform",
"LazyTransformType",
"RandomizableTransformType",
"MultiSampleTransformType",
"Randomizable",
"RandomizableTransform",
"Transform",
"MapTransform",
]

ReturnType = TypeVar("ReturnType")

Expand Down Expand Up @@ -118,6 +128,74 @@ def _log_stats(data, prefix: Optional[str] = "Data"):
raise RuntimeError(f"applying transform {transform}") from e


class LazyTransformType:
"""
An interface to indicate that the transform has the capability to describe
its operation as an affine matrix or grid with accompanying metadata. This
interface can be extended from by people adapting transforms to the MONAI framework as well as
by implementors of MONAI transforms.
"""

@property
def lazy_evaluation(self):
"""
Get whether lazy_evaluation is enabled for this transform instance.

Returns:
True if the transform is operating in a lazy fashion, False if not.
"""
raise NotImplementedError()

@lazy_evaluation.setter
def lazy_evaluation(self, enabled: bool):
"""
Set whether lazy_evaluation is enabled for this transform instance.

Args:
enabled: True if the transform should operate in a lazy fashion, False if not.
"""
raise NotImplementedError()


class RandomizableTransformType:
"""
An interface to indicate that the transform has the capability to perform
randomized transforms to the data that it is called upon. This interface
can be extended from by people adapting transforms to the MONAI framework as well as by
implementors of MONAI transforms.
"""

def set_random_state(
self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None
) -> "RandomizableTransformType":
"""
Set either the seed for an inbuilt random generator (assumed to be np.random.RandomState)
or set a random generator for this transform to use (again, assumed to be
np.random.RandomState). One one of these parameters should be set. If your random transform
that implements this interface doesn't support setting or reseeding of its random
generator, this method does not need to be implemented.

Args:
seed: set the random state with an integer seed.
state: set the random state with a `np.random.RandomState` object.

Returns:
self as a convenience for assignment
"""
raise TypeError(f"{self.__class__.__name__} does not support setting of random state via set_random_state.")


class MultiSampleTransformType:
"""
An interface to indicate that the transform has the capability to return multiple samples
given an input, such as when performing random crops of a sample. This interface can be
extended from by people adapting transforms to the MONAI framework as well as by implementors
of MONAI transforms.
"""

pass


class ThreadUnsafe:
"""
A class to denote that the transform will mutate its member variables,
Expand Down Expand Up @@ -251,7 +329,7 @@ def __call__(self, data: Any):
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")


class RandomizableTransform(Randomizable, Transform):
class RandomizableTransform(Randomizable, Transform, RandomizableTransformType):
"""
An interface for handling random state locally, currently based on a class variable `R`,
which is an instance of `np.random.RandomState`.
Expand Down
6 changes: 3 additions & 3 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,11 +505,11 @@ def generate_pos_neg_label_crop_centers(
raise ValueError("No sampling location available.")

if len(fg_indices) == 0 or len(bg_indices) == 0:
pos_ratio = 0 if len(fg_indices) == 0 else 1
warnings.warn(
f"N foreground {len(fg_indices)}, N background {len(bg_indices)},"
"unable to generate class balanced samples."
f"Num foregrounds {len(fg_indices)}, Num backgrounds {len(bg_indices)}, "
f"unable to generate class balanced samples, setting `pos_ratio` to {pos_ratio}."
)
pos_ratio = 0 if fg_indices.size == 0 else 1

for _ in range(num_samples):
indices_to_use = fg_indices if rand_state.rand() < pos_ratio else bg_indices
Expand Down
15 changes: 14 additions & 1 deletion tests/test_generate_pos_neg_label_crop_centers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,20 @@
list,
2,
3,
]
],
[
{
"spatial_size": [2, 2, 2],
"num_samples": 2,
"pos_ratio": 0.0,
"label_spatial_shape": [3, 3, 3],
"fg_indices": [],
"bg_indices": [3, 12, 21],
},
list,
2,
3,
],
]


Expand Down
38 changes: 38 additions & 0 deletions tests/test_randomizable_transform_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from monai.transforms.transform import RandomizableTransform, RandomizableTransformType


class InheritsInterface(RandomizableTransformType):
pass


class InheritsImplementation(RandomizableTransform):
def __call__(self, data):
return data


class TestRandomizableTransformType(unittest.TestCase):
def test_is_randomizable_transform_type(self):
inst = InheritsInterface()
self.assertIsInstance(inst, RandomizableTransformType)

def test_set_random_state_default_impl(self):
inst = InheritsInterface()
with self.assertRaises(TypeError):
inst.set_random_state(seed=0)

def test_set_random_state_randomizable_transform(self):
inst = InheritsImplementation()
inst.set_random_state(0)