# See the License for the specific language governing permissions and
# limitations under the License.

-import itertools as it
-from typing import Optional, Sequence, Union
+from typing import Optional, Union

import numpy as np
import torch

-from monai.config import DtypeLike
from monai.data.meta_tensor import MetaTensor
-from monai.transforms.meta_matrix import Matrix, MatrixFactory, MetaMatrix, matmul
-from monai.transforms.spatial.array import Affine
-from monai.transforms.utils import dtypes_to_str_or_identity, get_backend_from_tensor_like, get_device_from_tensor_like
-from monai.utils import GridSampleMode, GridSamplePadMode
+from monai.data.utils import to_affine_nd
+from monai.transforms.meta_matrix import matmul
+from monai.transforms.utility.functional import resample
+from monai.utils import LazyAttr

__all__ = ["apply"]

-# TODO: This should move to a common place to be shared with dictionary
-# from monai.utils.type_conversion import dtypes_to_str_or_identity
-
-GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str]
-GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str]
-DtypeSequence = Union[Sequence[DtypeLike], DtypeLike]
-
-
-# TODO: move to mapping_stack.py
-def extents_from_shape(shape, dtype=np.float32):
-    extents = [[0, shape[i]] for i in range(1, len(shape))]
-
-    extents = it.product(*extents)
-    return [np.asarray(e + (1,), dtype=dtype) for e in extents]
-
-
-# TODO: move to mapping_stack.py
-def shape_from_extents(
-    src_shape: Sequence, extents: Union[Sequence[np.ndarray], Sequence[torch.Tensor], np.ndarray, torch.Tensor]
-):
-    if isinstance(extents, (list, tuple)):
-        if isinstance(extents[0], np.ndarray):
-            aextents = np.asarray(extents)
-        else:
-            aextents = torch.stack(extents)
-            aextents = aextents.numpy()
-    else:
-        if isinstance(extents, np.ndarray):
-            aextents = extents
-        else:
-            aextents = extents.numpy()
-
-    mins = aextents.min(axis=0)
-    maxes = aextents.max(axis=0)
-    values = np.round(maxes - mins).astype(int)[:-1].tolist()
-    return (src_shape[0],) + tuple(values)
-
-
-def metadata_is_compatible(value_1, value_2):
-    if value_1 is None:
-        return True
-    else:
-        if value_2 is None:
-            return True
-        return value_1 == value_2
-
-
-def metadata_dtype_is_compatible(value_1, value_2):
-    if value_1 is None:
-        return True
-    else:
-        if value_2 is None:
-            return True
-
-    # if we are here, value_1 and value_2 are both set
-    # TODO: this is not a good enough solution
-    value_1_ = dtypes_to_str_or_identity(value_1)
-    value_2_ = dtypes_to_str_or_identity(value_2)
-    return value_1_ == value_2_
-
-
-def starting_matrix_and_extents(matrix_factory, data):
-    # set up the identity matrix and metadata
-    cumulative_matrix = matrix_factory.identity()
-    cumulative_extents = extents_from_shape(data.shape)
-    return cumulative_matrix, cumulative_extents
-
-
-def prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype):
-    kwargs = {}
-    if cur_mode is not None:
-        kwargs["mode"] = cur_mode
-    if cur_padding_mode is not None:
-        kwargs["padding_mode"] = cur_padding_mode
-    if cur_device is not None:
-        kwargs["device"] = cur_device
-    if cur_dtype is not None:
-        kwargs["dtype"] = cur_dtype
-
-    return kwargs
-
-
-def matrix_from_matrix_container(matrix):
-    if isinstance(matrix, MetaMatrix):
-        return matrix.matrix.data
-    elif isinstance(matrix, Matrix):
-        return matrix.data
-    else:
-        return matrix
-
-
-def apply(data: Union[torch.Tensor, MetaTensor, dict], pending: Optional[Union[dict, list]] = None):
-    """
-    This method applies pending transforms to tensors.
-    Args:
-        data: A torch Tensor, monai MetaTensor, or a dictionary containing Tensors / MetaTensors
-        pending: Optional arg containing pending transforms. This must be set if data is a Tensor
-            or dictionary of Tensors, but is optional if data is a MetaTensor / dictionary of
-            MetaTensors.
-    """
-    if isinstance(data, dict):
-        rd = dict()
-        for k, v in data.items():
-            result = v(*pending)
-            rd[k] = result
-        return rd

-    if isinstance(data, MetaTensor) and pending is None:
-        pending_ = data.pending_operations
-    else:
-        pending_ = [] if pending is None else pending
+def mat_from_pending(pending_item):
+    if isinstance(pending_item, (torch.Tensor, np.ndarray)):
+        return pending_item
+    if isinstance(pending_item, dict):
+        return pending_item[LazyAttr.AFFINE]
+    return pending_item

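For reference, a pending item is either a bare affine matrix (torch.Tensor / np.ndarray) or a dict keyed by LazyAttr; a minimal sketch of what mat_from_pending extracts, assuming a 3D image and hence a 4x4 homogeneous affine:

    # hypothetical pending entry recorded by a lazily-evaluated spatial transform
    pending_item = {
        LazyAttr.AFFINE: torch.eye(4),       # 4x4 homogeneous affine for 3D data
        LazyAttr.INTERP_MODE: "bilinear",    # assumed interpolation mode value
        LazyAttr.PADDING_MODE: "border",     # assumed padding mode value
    }
    mat_from_pending(pending_item)  # returns the tensor stored under LazyAttr.AFFINE
    mat_from_pending(torch.eye(4))  # bare tensors / arrays pass through unchanged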
-    if len(pending_) == 0:
-        return data

-    dim_count = len(data.shape) - 1
-    matrix_factory = MatrixFactory(dim_count, get_backend_from_tensor_like(data), get_device_from_tensor_like(data))
-
-    # set up the identity matrix and metadata
-    cumulative_matrix, cumulative_extents = starting_matrix_and_extents(matrix_factory, data)
-
-    # set the various resampling parameters to an initial state
-    cur_mode = None
-    cur_padding_mode = None
-    cur_device = None
-    cur_dtype = None
-    cur_shape = data.shape
-
-    for meta_matrix in pending_:
-        next_matrix = meta_matrix.matrix
-        # print("intermediate matrix\n", matrix_from_matrix_container(cumulative_matrix))
-        cumulative_matrix = matmul(cumulative_matrix, next_matrix)
-        cumulative_extents = [matmul(e, cumulative_matrix) for e in cumulative_extents]
-
-        new_mode = meta_matrix.metadata.get("mode", None)
-        new_padding_mode = meta_matrix.metadata.get("padding_mode", None)
-        new_device = meta_matrix.metadata.get("device", None)
-        new_dtype = meta_matrix.metadata.get("dtype", None)
-        new_shape = meta_matrix.metadata.get("shape_override", None)
-
-        mode_compat = metadata_is_compatible(cur_mode, new_mode)
-        padding_mode_compat = metadata_is_compatible(cur_padding_mode, new_padding_mode)
-        device_compat = metadata_is_compatible(cur_device, new_device)
-        dtype_compat = metadata_dtype_is_compatible(cur_dtype, new_dtype)
-
-        if mode_compat is False or padding_mode_compat is False or device_compat is False or dtype_compat is False:
-            # carry out an intermediate resample here due to incompatibility between arguments
-            kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype)
+def kwargs_from_pending(pending_item):
+    if not isinstance(pending_item, dict):
+        return {}
+    ret = {
+        LazyAttr.INTERP_MODE: pending_item.get(LazyAttr.INTERP_MODE, None),  # interpolation mode
+        LazyAttr.PADDING_MODE: pending_item.get(LazyAttr.PADDING_MODE, None),  # padding mode
+    }
+    if LazyAttr.SHAPE in pending_item:
+        ret[LazyAttr.SHAPE] = pending_item[LazyAttr.SHAPE]
+    if LazyAttr.DTYPE in pending_item:
+        ret[LazyAttr.DTYPE] = pending_item[LazyAttr.DTYPE]
+    return ret
+

-            cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix)
-            a = Affine(affine=cumulative_matrix_, **kwargs)
-            data, _ = a(img=data)
+def is_compatible_kwargs(kwargs_1, kwargs_2):
+    return True

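kwargs_from_pending collects only the resampling-relevant keys (interpolation mode, padding mode, and optionally output shape and dtype), while is_compatible_kwargs is for now a stub that always reports compatibility, so the fusion loop in apply below never triggers an intermediate resample. A small sketch, reusing the hypothetical pending_item above:

    kwargs_from_pending(pending_item)
    # -> {LazyAttr.INTERP_MODE: "bilinear", LazyAttr.PADDING_MODE: "border"}
    kwargs_from_pending(torch.eye(4))  # non-dict entries contribute no kwargs -> {}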
-            cumulative_matrix, cumulative_extents = starting_matrix_and_extents(matrix_factory, data)

-        cur_mode = cur_mode if new_mode is None else new_mode
-        cur_padding_mode = cur_padding_mode if new_padding_mode is None else new_padding_mode
-        cur_device = cur_device if new_device is None else new_device
-        cur_dtype = cur_dtype if new_dtype is None else new_dtype
-        cur_shape = cur_shape if new_shape is None else new_shape
+def apply(data: Union[torch.Tensor, MetaTensor], pending: Optional[list] = None):
+    """
+    This method applies pending transforms to tensors.
+
+    Args:
+        data: A torch Tensor, monai MetaTensor
+        pending: pending transforms. This must be set if data is a Tensor, but is optional if data is a MetaTensor.
+    """
+    if isinstance(data, MetaTensor) and pending is None:
+        pending = data.pending_operations
+    pending = [] if pending is None else pending

-    kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype)
+    if not pending:
+        return data

-    cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix)
+    cumulative_xform = mat_from_pending(pending[0])
+    cur_kwargs = kwargs_from_pending(pending[0])

-    # print(f"applying with cumulative matrix\n {cumulative_matrix_}")
-    a = Affine(affine=cumulative_matrix_, spatial_size=cur_shape[1:], normalized=False, **kwargs)
-    data, tx = a(img=data)
+    for p in pending[1:]:
+        new_kwargs = kwargs_from_pending(p)
+        if not is_compatible_kwargs(cur_kwargs, new_kwargs):
+            # carry out an intermediate resample here due to incompatibility between arguments
+            data = resample(data, cumulative_xform, cur_kwargs)
+        next_matrix = mat_from_pending(p)
+        cumulative_xform = matmul(cumulative_xform, next_matrix)
+        cur_kwargs.update(new_kwargs)
+    data = resample(data, cumulative_xform, cur_kwargs)
    if isinstance(data, MetaTensor):
        data.clear_pending_operations()
-        for p in pending_:
-            data.affine = p.matrix.data
+        data.affine = data.affine @ to_affine_nd(3, cumulative_xform)
+        for p in pending:
            data.push_applied_operation(p)

-    return data, None if pending is None else pending_
+    return data, pending
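A minimal usage sketch of the new apply, assuming two pending affines expressed as LazyAttr dicts (with a MetaTensor input, pending can be omitted and data.pending_operations is read instead); the matrices are fused with matmul and applied through a single resample call:

    img = torch.rand(1, 32, 32, 32)  # channel-first 3D volume
    pending = [
        {LazyAttr.AFFINE: torch.eye(4), LazyAttr.INTERP_MODE: "bilinear", LazyAttr.PADDING_MODE: "border"},
        {LazyAttr.AFFINE: torch.eye(4), LazyAttr.INTERP_MODE: "bilinear", LazyAttr.PADDING_MODE: "border"},
    ]
    result, applied = apply(img, pending)  # both affines applied in one resample pass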