Skip to content

Commit 87c4356

Browse files
committed
apply/matmul/resample MVP
Signed-off-by: Wenqi Li <[email protected]>
1 parent b2ebc6d commit 87c4356

File tree

8 files changed

+92
-745
lines changed

8 files changed

+92
-745
lines changed

monai/transforms/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@
228228
from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
229229
from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
230230
from .lazy.functional import apply
231-
from .meta_matrix import Grid, Matrix, MatrixFactory, MetaMatrix, matmul
231+
from .meta_matrix import matmul
232232
from .meta_utility.dictionary import (
233233
FromMetaTensord,
234234
FromMetaTensorD,
@@ -634,8 +634,6 @@
634634
generate_label_classes_crop_centers,
635635
generate_pos_neg_label_crop_centers,
636636
generate_spatial_bounding_box,
637-
get_backend_from_tensor_like,
638-
get_device_from_tensor_like,
639637
get_extreme_points,
640638
get_largest_connected_component_mask,
641639
get_number_image_type_conversions,

monai/transforms/lazy/functional.py

Lines changed: 53 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -9,193 +9,77 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
import itertools as it
13-
from typing import Optional, Sequence, Union
12+
from typing import Optional, Union
1413

1514
import numpy as np
1615
import torch
1716

18-
from monai.config import DtypeLike
1917
from monai.data.meta_tensor import MetaTensor
20-
from monai.transforms.meta_matrix import Matrix, MatrixFactory, MetaMatrix, matmul
21-
from monai.transforms.spatial.array import Affine
22-
from monai.transforms.utils import dtypes_to_str_or_identity, get_backend_from_tensor_like, get_device_from_tensor_like
23-
from monai.utils import GridSampleMode, GridSamplePadMode
18+
from monai.data.utils import to_affine_nd
19+
from monai.transforms.meta_matrix import matmul
20+
from monai.transforms.utility.functional import resample
21+
from monai.utils import LazyAttr
2422

2523
__all__ = ["apply"]
2624

27-
# TODO: This should move to a common place to be shared with dictionary
28-
# from monai.utils.type_conversion import dtypes_to_str_or_identity
29-
30-
GridSampleModeSequence = Union[Sequence[Union[GridSampleMode, str]], GridSampleMode, str]
31-
GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str]
32-
DtypeSequence = Union[Sequence[DtypeLike], DtypeLike]
33-
34-
35-
# TODO: move to mapping_stack.py
36-
def extents_from_shape(shape, dtype=np.float32):
37-
extents = [[0, shape[i]] for i in range(1, len(shape))]
38-
39-
extents = it.product(*extents)
40-
return [np.asarray(e + (1,), dtype=dtype) for e in extents]
41-
42-
43-
# TODO: move to mapping_stack.py
44-
def shape_from_extents(
45-
src_shape: Sequence, extents: Union[Sequence[np.ndarray], Sequence[torch.Tensor], np.ndarray, torch.Tensor]
46-
):
47-
if isinstance(extents, (list, tuple)):
48-
if isinstance(extents[0], np.ndarray):
49-
aextents = np.asarray(extents)
50-
else:
51-
aextents = torch.stack(extents)
52-
aextents = aextents.numpy()
53-
else:
54-
if isinstance(extents, np.ndarray):
55-
aextents = extents
56-
else:
57-
aextents = extents.numpy()
58-
59-
mins = aextents.min(axis=0)
60-
maxes = aextents.max(axis=0)
61-
values = np.round(maxes - mins).astype(int)[:-1].tolist()
62-
return (src_shape[0],) + tuple(values)
63-
64-
65-
def metadata_is_compatible(value_1, value_2):
66-
if value_1 is None:
67-
return True
68-
else:
69-
if value_2 is None:
70-
return True
71-
return value_1 == value_2
72-
73-
74-
def metadata_dtype_is_compatible(value_1, value_2):
75-
if value_1 is None:
76-
return True
77-
else:
78-
if value_2 is None:
79-
return True
80-
81-
# if we are here, value_1 and value_2 are both set
82-
# TODO: this is not a good enough solution
83-
value_1_ = dtypes_to_str_or_identity(value_1)
84-
value_2_ = dtypes_to_str_or_identity(value_2)
85-
return value_1_ == value_2_
86-
87-
88-
def starting_matrix_and_extents(matrix_factory, data):
89-
# set up the identity matrix and metadata
90-
cumulative_matrix = matrix_factory.identity()
91-
cumulative_extents = extents_from_shape(data.shape)
92-
return cumulative_matrix, cumulative_extents
93-
94-
95-
def prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype):
96-
kwargs = {}
97-
if cur_mode is not None:
98-
kwargs["mode"] = cur_mode
99-
if cur_padding_mode is not None:
100-
kwargs["padding_mode"] = cur_padding_mode
101-
if cur_device is not None:
102-
kwargs["device"] = cur_device
103-
if cur_dtype is not None:
104-
kwargs["dtype"] = cur_dtype
105-
106-
return kwargs
107-
108-
109-
def matrix_from_matrix_container(matrix):
110-
if isinstance(matrix, MetaMatrix):
111-
return matrix.matrix.data
112-
elif isinstance(matrix, Matrix):
113-
return matrix.data
114-
else:
115-
return matrix
116-
117-
118-
def apply(data: Union[torch.Tensor, MetaTensor, dict], pending: Optional[Union[dict, list]] = None):
119-
"""
120-
This method applies pending transforms to tensors.
121-
Args:
122-
data: A torch Tensor, monai MetaTensor, or a dictionary containing Tensors / MetaTensors
123-
pending: Optional arg containing pending transforms. This must be set if data is a Tensor
124-
or dictionary of Tensors, but is optional if data is a MetaTensor / dictionary of
125-
MetaTensors.
126-
"""
127-
if isinstance(data, dict):
128-
rd = dict()
129-
for k, v in data.items():
130-
result = v(*pending)
131-
rd[k] = result
132-
return rd
13325

134-
if isinstance(data, MetaTensor) and pending is None:
135-
pending_ = data.pending_operations
136-
else:
137-
pending_ = [] if pending is None else pending
26+
def mat_from_pending(pending_item):
27+
if isinstance(pending_item, (torch.Tensor, np.ndarray)):
28+
return pending_item
29+
if isinstance(pending_item, dict):
30+
return pending_item[LazyAttr.AFFINE]
31+
return pending_item
13832

139-
if len(pending_) == 0:
140-
return data
14133

142-
dim_count = len(data.shape) - 1
143-
matrix_factory = MatrixFactory(dim_count, get_backend_from_tensor_like(data), get_device_from_tensor_like(data))
144-
145-
# set up the identity matrix and metadata
146-
cumulative_matrix, cumulative_extents = starting_matrix_and_extents(matrix_factory, data)
147-
148-
# set the various resampling parameters to an initial state
149-
cur_mode = None
150-
cur_padding_mode = None
151-
cur_device = None
152-
cur_dtype = None
153-
cur_shape = data.shape
154-
155-
for meta_matrix in pending_:
156-
next_matrix = meta_matrix.matrix
157-
# print("intermediate matrix\n", matrix_from_matrix_container(cumulative_matrix))
158-
cumulative_matrix = matmul(cumulative_matrix, next_matrix)
159-
cumulative_extents = [matmul(e, cumulative_matrix) for e in cumulative_extents]
160-
161-
new_mode = meta_matrix.metadata.get("mode", None)
162-
new_padding_mode = meta_matrix.metadata.get("padding_mode", None)
163-
new_device = meta_matrix.metadata.get("device", None)
164-
new_dtype = meta_matrix.metadata.get("dtype", None)
165-
new_shape = meta_matrix.metadata.get("shape_override", None)
166-
167-
mode_compat = metadata_is_compatible(cur_mode, new_mode)
168-
padding_mode_compat = metadata_is_compatible(cur_padding_mode, new_padding_mode)
169-
device_compat = metadata_is_compatible(cur_device, new_device)
170-
dtype_compat = metadata_dtype_is_compatible(cur_dtype, new_dtype)
171-
172-
if mode_compat is False or padding_mode_compat is False or device_compat is False or dtype_compat is False:
173-
# carry out an intermediate resample here due to incompatibility between arguments
174-
kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype)
34+
def kwargs_from_pending(pending_item):
35+
if not isinstance(pending_item, dict):
36+
return {}
37+
ret = {
38+
LazyAttr.INTERP_MODE: pending_item.get(LazyAttr.INTERP_MODE, None), # interpolation mode
39+
LazyAttr.PADDING_MODE: pending_item.get(LazyAttr.PADDING_MODE, None), # padding mode
40+
}
41+
if LazyAttr.SHAPE in pending_item:
42+
ret[LazyAttr.SHAPE] = pending_item[LazyAttr.SHAPE]
43+
if LazyAttr.DTYPE in pending_item:
44+
ret[LazyAttr.DTYPE] = pending_item[LazyAttr.DTYPE]
45+
return ret
46+
17547

176-
cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix)
177-
a = Affine(affine=cumulative_matrix_, **kwargs)
178-
data, _ = a(img=data)
48+
def is_compatible_kwargs(kwargs_1, kwargs_2):
49+
return True
17950

180-
cumulative_matrix, cumulative_extents = starting_matrix_and_extents(matrix_factory, data)
18151

182-
cur_mode = cur_mode if new_mode is None else new_mode
183-
cur_padding_mode = cur_padding_mode if new_padding_mode is None else new_padding_mode
184-
cur_device = cur_device if new_device is None else new_device
185-
cur_dtype = cur_dtype if new_dtype is None else new_dtype
186-
cur_shape = cur_shape if new_shape is None else new_shape
52+
def apply(data: Union[torch.Tensor, MetaTensor], pending: Optional[list] = None):
53+
"""
54+
This method applies pending transforms to tensors.
55+
56+
Args:
57+
data: A torch Tensor, monai MetaTensor
58+
pending: pending transforms. This must be set if data is a Tensor, but is optional if data is a MetaTensor.
59+
"""
60+
if isinstance(data, MetaTensor) and pending is None:
61+
pending = data.pending_operations
62+
pending = [] if pending is None else pending
18763

188-
kwargs = prepare_args_dict_for_apply(cur_mode, cur_padding_mode, cur_device, cur_dtype)
64+
if not pending:
65+
return data
18966

190-
cumulative_matrix_ = matrix_from_matrix_container(cumulative_matrix)
67+
cumulative_xform = mat_from_pending(pending[0])
68+
cur_kwargs = kwargs_from_pending(pending[0])
19169

192-
# print(f"applying with cumulative matrix\n {cumulative_matrix_}")
193-
a = Affine(affine=cumulative_matrix_, spatial_size=cur_shape[1:], normalized=False, **kwargs)
194-
data, tx = a(img=data)
70+
for p in pending[1:]:
71+
new_kwargs = kwargs_from_pending(p)
72+
if not is_compatible_kwargs(cur_kwargs, new_kwargs):
73+
# carry out an intermediate resample here due to incompatibility between arguments
74+
data = resample(data, cumulative_xform, cur_kwargs)
75+
next_matrix = mat_from_pending(p)
76+
cumulative_xform = matmul(cumulative_xform, next_matrix)
77+
cur_kwargs.update(new_kwargs)
78+
data = resample(data, cumulative_xform, cur_kwargs)
19579
if isinstance(data, MetaTensor):
19680
data.clear_pending_operations()
197-
for p in pending_:
198-
data.affine = p.matrix.data
81+
data.affine = data.affine @ to_affine_nd(3, cumulative_xform)
82+
for p in pending:
19983
data.push_applied_operation(p)
20084

201-
return data, None if pending is None else pending_
85+
return data, pending

0 commit comments

Comments
 (0)