diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 874f01a945..d178485e75 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -10,6 +10,21 @@ Generic Interfaces .. automodule:: monai.transforms .. currentmodule:: monai.transforms +`RandomizableTransformType` +^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. autoclass:: RandomizableTransformType + :members: + +`LazyTransformType` +^^^^^^^^^^^^^^^^^^^ +.. autoclass:: LazyTransformType + :members: + +`MultiSampleTransformType` +^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. autoclass:: MultiSampleTransformType + :members: + `Transform` ^^^^^^^^^^^ .. autoclass:: Transform diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 389571d16f..b51123d856 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -449,7 +449,17 @@ ZoomD, ZoomDict, ) -from .transform import MapTransform, Randomizable, RandomizableTransform, ThreadUnsafe, Transform, apply_transform +from .transform import ( + LazyTransformType, + MapTransform, + MultiSampleTransformType, + Randomizable, + RandomizableTransform, + RandomizableTransformType, + ThreadUnsafe, + Transform, + apply_transform, +) from .utility.array import ( AddChannel, AddCoordinateChannels, diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 21d057f5d3..560e55bed3 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -26,7 +26,17 @@ from monai.utils.enums import TransformBackends from monai.utils.misc import MONAIEnvVars -__all__ = ["ThreadUnsafe", "apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"] +__all__ = [ + "ThreadUnsafe", + "apply_transform", + "LazyTransformType", + "RandomizableTransformType", + "MultiSampleTransformType", + "Randomizable", + "RandomizableTransform", + "Transform", + "MapTransform", +] ReturnType = TypeVar("ReturnType") @@ -118,6 +128,74 @@ def _log_stats(data, prefix: Optional[str] = "Data"): raise RuntimeError(f"applying transform {transform}") from e +class LazyTransformType: + """ + An interface to indicate that the transform has the capability to describe + its operation as an affine matrix or grid with accompanying metadata. This + interface can be extended from by people adapting transforms to the MONAI framework as well as + by implementors of MONAI transforms. + """ + + @property + def lazy_evaluation(self): + """ + Get whether lazy_evaluation is enabled for this transform instance. + + Returns: + True if the transform is operating in a lazy fashion, False if not. + """ + raise NotImplementedError() + + @lazy_evaluation.setter + def lazy_evaluation(self, enabled: bool): + """ + Set whether lazy_evaluation is enabled for this transform instance. + + Args: + enabled: True if the transform should operate in a lazy fashion, False if not. + """ + raise NotImplementedError() + + +class RandomizableTransformType: + """ + An interface to indicate that the transform has the capability to perform + randomized transforms to the data that it is called upon. This interface + can be extended from by people adapting transforms to the MONAI framework as well as by + implementors of MONAI transforms. + """ + + def set_random_state( + self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None + ) -> "RandomizableTransformType": + """ + Set either the seed for an inbuilt random generator (assumed to be np.random.RandomState) + or set a random generator for this transform to use (again, assumed to be + np.random.RandomState). One one of these parameters should be set. If your random transform + that implements this interface doesn't support setting or reseeding of its random + generator, this method does not need to be implemented. + + Args: + seed: set the random state with an integer seed. + state: set the random state with a `np.random.RandomState` object. + + Returns: + self as a convenience for assignment + """ + raise TypeError(f"{self.__class__.__name__} does not support setting of random state via set_random_state.") + + +class MultiSampleTransformType: + """ + An interface to indicate that the transform has the capability to return multiple samples + given an input, such as when performing random crops of a sample. This interface can be + extended from by people adapting transforms to the MONAI framework as well as by implementors + of MONAI transforms. + """ + + pass + + class ThreadUnsafe: """ A class to denote that the transform will mutate its member variables, @@ -251,7 +329,7 @@ def __call__(self, data: Any): raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") -class RandomizableTransform(Randomizable, Transform): +class RandomizableTransform(Randomizable, Transform, RandomizableTransformType): """ An interface for handling random state locally, currently based on a class variable `R`, which is an instance of `np.random.RandomState`. diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 3096d76889..e96d906f20 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -505,11 +505,11 @@ def generate_pos_neg_label_crop_centers( raise ValueError("No sampling location available.") if len(fg_indices) == 0 or len(bg_indices) == 0: + pos_ratio = 0 if len(fg_indices) == 0 else 1 warnings.warn( - f"N foreground {len(fg_indices)}, N background {len(bg_indices)}," - "unable to generate class balanced samples." + f"Num foregrounds {len(fg_indices)}, Num backgrounds {len(bg_indices)}, " + f"unable to generate class balanced samples, setting `pos_ratio` to {pos_ratio}." ) - pos_ratio = 0 if fg_indices.size == 0 else 1 for _ in range(num_samples): indices_to_use = fg_indices if rand_state.rand() < pos_ratio else bg_indices diff --git a/tests/test_generate_pos_neg_label_crop_centers.py b/tests/test_generate_pos_neg_label_crop_centers.py index 91db0e9d96..d1a208770f 100644 --- a/tests/test_generate_pos_neg_label_crop_centers.py +++ b/tests/test_generate_pos_neg_label_crop_centers.py @@ -31,7 +31,20 @@ list, 2, 3, - ] + ], + [ + { + "spatial_size": [2, 2, 2], + "num_samples": 2, + "pos_ratio": 0.0, + "label_spatial_shape": [3, 3, 3], + "fg_indices": [], + "bg_indices": [3, 12, 21], + }, + list, + 2, + 3, + ], ] diff --git a/tests/test_randomizable_transform_type.py b/tests/test_randomizable_transform_type.py new file mode 100644 index 0000000000..92c69a3b4e --- /dev/null +++ b/tests/test_randomizable_transform_type.py @@ -0,0 +1,38 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from monai.transforms.transform import RandomizableTransform, RandomizableTransformType + + +class InheritsInterface(RandomizableTransformType): + pass + + +class InheritsImplementation(RandomizableTransform): + def __call__(self, data): + return data + + +class TestRandomizableTransformType(unittest.TestCase): + def test_is_randomizable_transform_type(self): + inst = InheritsInterface() + self.assertIsInstance(inst, RandomizableTransformType) + + def test_set_random_state_default_impl(self): + inst = InheritsInterface() + with self.assertRaises(TypeError): + inst.set_random_state(seed=0) + + def test_set_random_state_randomizable_transform(self): + inst = InheritsImplementation() + inst.set_random_state(0)