Skip to content

Commit 000f035

Browse files
authored
Merge branch 'dev' into metatensor-croppad
2 parents 0c0c32d + d34fa14 commit 000f035

File tree

9 files changed

+48
-47
lines changed

9 files changed

+48
-47
lines changed

monai/data/png_writer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,9 @@ def write_png(
8181
if scale is not None:
8282
data = np.clip(data, 0.0, 1.0) # png writer only can scale data in range [0, 1]
8383
if scale == np.iinfo(np.uint8).max:
84-
data = convert_data_type((scale * data), np.ndarray, dtype=np.uint8, drop_meta=True)[0]
84+
data = convert_data_type((scale * data), np.ndarray, dtype=np.uint8)[0]
8585
elif scale == np.iinfo(np.uint16).max:
86-
data = convert_data_type((scale * data), np.ndarray, dtype=np.uint16, drop_meta=True)[0]
86+
data = convert_data_type((scale * data), np.ndarray, dtype=np.uint16)[0]
8787
else:
8888
raise ValueError(f"Unsupported scale: {scale}, available options are [255, 65535]")
8989

monai/data/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,7 @@ def to_affine_nd(r: Union[np.ndarray, int], affine: NdarrayTensor, dtype=np.floa
856856
an (r+1) x (r+1) matrix (tensor or ndarray depends on the input ``affine`` data type)
857857
858858
"""
859-
affine_np = convert_data_type(affine, output_type=np.ndarray, dtype=dtype, wrap_sequence=True, drop_meta=True)[0]
859+
affine_np = convert_data_type(affine, output_type=np.ndarray, dtype=dtype, wrap_sequence=True)[0]
860860
affine_np = affine_np.copy()
861861
if affine_np.ndim != 2:
862862
raise ValueError(f"affine must have 2 dimensions, got {affine_np.ndim}.")

monai/inferers/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def sliding_window_inference(
175175
for idx in slice_range
176176
]
177177
window_data = torch.cat(
178-
[convert_data_type(inputs[win_slice], torch.Tensor, drop_meta=True)[0] for win_slice in unravel_slice]
178+
[convert_data_type(inputs[win_slice], torch.Tensor)[0] for win_slice in unravel_slice]
179179
).to(sw_device)
180180
seg_prob_out = predictor(window_data, *args, **kwargs) # batched patch segmentation
181181

monai/metrics/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,8 @@ def get_mask_edges(seg_pred, seg_gt, label_idx: int = 1, crop: bool = True) -> T
155155
seg_pred, seg_gt = np.expand_dims(seg_pred, axis=channel_dim), np.expand_dims(seg_gt, axis=channel_dim)
156156
box_start, box_end = generate_spatial_bounding_box(np.asarray(seg_pred | seg_gt))
157157
cropper = SpatialCrop(roi_start=box_start, roi_end=box_end)
158-
seg_pred = convert_data_type(np.squeeze(cropper(seg_pred), axis=channel_dim), np.ndarray, drop_meta=True)[0]
159-
seg_gt = convert_data_type(np.squeeze(cropper(seg_gt), axis=channel_dim), np.ndarray, drop_meta=True)[0]
158+
seg_pred = convert_data_type(np.squeeze(cropper(seg_pred), axis=channel_dim), np.ndarray)[0]
159+
seg_gt = convert_data_type(np.squeeze(cropper(seg_gt), axis=channel_dim), np.ndarray)[0]
160160

161161
# Do binary erosion and use XOR to get edges
162162
edges_pred = binary_erosion(seg_pred) ^ seg_pred

monai/transforms/utils_pytorch_numpy_unification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ def mode(x: NdarrayTensor, dim: int = -1, to_long: bool = True) -> NdarrayTensor
376376
to_long: convert input to long before performing mode.
377377
"""
378378
dtype = torch.int64 if to_long else None
379-
x_t, *_ = convert_data_type(x, torch.Tensor, dtype=dtype, drop_meta=True)
379+
x_t, *_ = convert_data_type(x, torch.Tensor, dtype=dtype)
380380
o_t = torch.mode(x_t, dim).values
381381
o, *_ = convert_to_dst_type(o_t, x)
382382
return o

monai/utils/type_conversion.py

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,8 @@ def convert_to_tensor(
105105
track_meta: bool = False,
106106
):
107107
"""
108-
Utility to convert the input data to a PyTorch Tensor, if tracking meta, convert to `MetaTensor`.
108+
Utility to convert the input data to a PyTorch Tensor, if `track_meta` is True, the output will be a `MetaTensor`,
109+
otherwise, the output will be a regular torch Tensor.
109110
If passing a dictionary, list or tuple, recursively check every item and convert it to PyTorch Tensor.
110111
111112
Args:
@@ -121,13 +122,14 @@ def convert_to_tensor(
121122
122123
"""
123124

124-
def _convert_tensor(tensor):
125+
def _convert_tensor(tensor, **kwargs):
125126
if not isinstance(tensor, torch.Tensor):
126-
tensor = torch.as_tensor(tensor)
127+
# if input data is not Tensor, convert it to Tensor first
128+
tensor = torch.as_tensor(tensor, **kwargs)
127129
if track_meta and not isinstance(tensor, monai.data.MetaTensor):
128130
return monai.data.MetaTensor(tensor)
129131
if not track_meta and isinstance(tensor, monai.data.MetaTensor):
130-
return tensor.as_tensor(tensor)
132+
return tensor.as_tensor()
131133
return tensor
132134

133135
if isinstance(data, torch.Tensor):
@@ -142,15 +144,15 @@ def _convert_tensor(tensor):
142144
data = np.ascontiguousarray(data)
143145
return _convert_tensor(data, dtype=dtype, device=device)
144146
elif (has_cp and isinstance(data, cp_ndarray)) or isinstance(data, (float, int, bool)):
145-
return _convert_tensor(data, dtype=dtype, device=device) # type: ignore
147+
return _convert_tensor(data, dtype=dtype, device=device)
146148
elif isinstance(data, list):
147-
list_ret = [convert_to_tensor(i, dtype=dtype, device=device) for i in data]
149+
list_ret = [convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta) for i in data]
148150
return _convert_tensor(list_ret, dtype=dtype, device=device) if wrap_sequence else list_ret
149151
elif isinstance(data, tuple):
150-
tuple_ret = tuple(convert_to_tensor(i, dtype=dtype, device=device) for i in data)
151-
return _convert_tensor(tuple_ret, dtype=dtype, device=device) if wrap_sequence else tuple_ret # type: ignore
152+
tuple_ret = tuple(convert_to_tensor(i, dtype=dtype, device=device, track_meta=track_meta) for i in data)
153+
return _convert_tensor(tuple_ret, dtype=dtype, device=device) if wrap_sequence else tuple_ret
152154
elif isinstance(data, dict):
153-
return {k: convert_to_tensor(v, dtype=dtype, device=device) for k, v in data.items()}
155+
return {k: convert_to_tensor(v, dtype=dtype, device=device, track_meta=track_meta) for k, v in data.items()}
154156

155157
return data
156158

@@ -230,7 +232,6 @@ def convert_data_type(
230232
device: Optional[torch.device] = None,
231233
dtype: Union[DtypeLike, torch.dtype] = None,
232234
wrap_sequence: bool = False,
233-
drop_meta: bool = True,
234235
) -> Tuple[NdarrayTensor, type, Optional[torch.device]]:
235236
"""
236237
Convert to `torch.Tensor`/`np.ndarray` from `torch.Tensor`/`np.ndarray`/`float`/`int` etc.
@@ -244,9 +245,6 @@ def convert_data_type(
244245
If left blank, it remains unchanged.
245246
wrap_sequence: if `False`, then lists will recursively call this function.
246247
E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`.
247-
drop_meta: whether to drop the meta information of the input data, default to `True`.
248-
If `True`, then the meta information will be dropped quietly, unless the output type is MetaTensor.
249-
If `False`, converting a MetaTensor into a non-tensor instance will raise an error.
250248
251249
Returns:
252250
modified data, orig_type, orig_device
@@ -278,18 +276,10 @@ def convert_data_type(
278276

279277
dtype_ = get_equivalent_dtype(dtype, output_type)
280278

281-
if not drop_meta and not issubclass(output_type, monai.data.MetaObj) and isinstance(data, monai.data.MetaObj):
282-
# input has a MetaObj, user chose keep the metadata, but the output type cannot take a MetaObj.
283-
if issubclass(output_type, torch.Tensor):
284-
# user-specified MetaTensor to torch tensor keep the MetaTensor type, for backward compatibility
285-
output_type = type(data) # type: ignore
286-
else:
287-
raise RuntimeError(f"the specified output_type {output_type} cannot have the metaobj, but drop_meta=False.")
288-
289279
data_: NdarrayTensor
290280

291281
if issubclass(output_type, torch.Tensor):
292-
track_meta = True if issubclass(output_type, monai.data.MetaTensor) else False
282+
track_meta = issubclass(output_type, monai.data.MetaTensor)
293283
data_ = convert_to_tensor(data, dtype=dtype_, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta)
294284
return data_, orig_type, orig_device
295285
if issubclass(output_type, np.ndarray):
@@ -302,11 +292,7 @@ def convert_data_type(
302292

303293

304294
def convert_to_dst_type(
305-
src: Any,
306-
dst: NdarrayTensor,
307-
dtype: Union[DtypeLike, torch.dtype, None] = None,
308-
wrap_sequence: bool = False,
309-
drop_meta: bool = True,
295+
src: Any, dst: NdarrayTensor, dtype: Union[DtypeLike, torch.dtype, None] = None, wrap_sequence: bool = False
310296
) -> Tuple[NdarrayTensor, type, Optional[torch.device]]:
311297
"""
312298
Convert source data to the same data type and device as the destination data.
@@ -320,9 +306,6 @@ def convert_to_dst_type(
320306
dtype: an optional argument if the target `dtype` is different from the original `dst`'s data type.
321307
wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`.
322308
If `True`, then `[1, 2]` -> `array([1, 2])`.
323-
drop_meta: whether to drop the meta information of the input data, default to `True`.
324-
If `True`, then the meta information will be dropped quietly, unless the output type is MetaTensor.
325-
If `False`, converting a MetaTensor into a non-tensor instance will raise an error.
326309
327310
See Also:
328311
:func:`convert_data_type`
@@ -346,7 +329,7 @@ def convert_to_dst_type(
346329
output_type = type(dst)
347330
output: NdarrayTensor
348331
output, _type, _device = convert_data_type(
349-
data=src, output_type=output_type, device=device, dtype=dtype, wrap_sequence=wrap_sequence, drop_meta=drop_meta
332+
data=src, output_type=output_type, device=device, dtype=dtype, wrap_sequence=wrap_sequence
350333
)
351334
if copy_meta and isinstance(output, monai.data.MetaTensor): # type: ignore
352335
output.meta, output.applied_operations = deepcopy(dst.meta), deepcopy(dst.applied_operations) # type: ignore

tests/test_convert_data_type.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,25 @@
1616
import torch
1717
from parameterized import parameterized
1818

19+
from monai.data import MetaTensor
1920
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type
20-
from tests.utils import TEST_NDARRAYS
21+
from tests.utils import TEST_NDARRAYS_ALL
2122

2223
TESTS: List[Tuple] = []
23-
for in_type in TEST_NDARRAYS + (int, float):
24-
for out_type in TEST_NDARRAYS:
24+
for in_type in TEST_NDARRAYS_ALL + (int, float):
25+
for out_type in TEST_NDARRAYS_ALL:
2526
TESTS.append((in_type(np.array(1.0)), out_type(np.array(1.0)))) # type: ignore
2627

2728
TESTS_LIST: List[Tuple] = []
28-
for in_type in TEST_NDARRAYS + (int, float):
29-
for out_type in TEST_NDARRAYS:
29+
for in_type in TEST_NDARRAYS_ALL + (int, float):
30+
for out_type in TEST_NDARRAYS_ALL:
3031
TESTS_LIST.append(
3132
([in_type(np.array(1.0)), in_type(np.array(1.0))], out_type(np.array([1.0, 1.0])), True) # type: ignore
3233
)
3334
TESTS_LIST.append(
3435
(
3536
[in_type(np.array(1.0)), in_type(np.array(1.0))], # type: ignore
36-
[out_type(np.array(1.0)), out_type(np.array(1.0))],
37+
[out_type(np.array(1.0)), out_type(np.array(1.0))], # type: ignore
3738
False,
3839
)
3940
)
@@ -83,14 +84,17 @@ def test_convert_data_type(self, in_image, im_out):
8384
self.assertEqual(type(in_image), orig_type)
8485
if isinstance(in_image, torch.Tensor):
8586
self.assertEqual(in_image.device, orig_device)
87+
8688
# check output is desired type
87-
if isinstance(im_out, torch.Tensor):
89+
if isinstance(im_out, MetaTensor):
90+
output_type = MetaTensor
91+
elif isinstance(im_out, torch.Tensor):
8892
output_type = torch.Tensor
8993
else:
9094
output_type = np.ndarray
9195
self.assertEqual(type(converted_im), output_type)
9296
# check dtype is unchanged
93-
if isinstance(in_type, (np.ndarray, torch.Tensor)):
97+
if isinstance(in_type, (np.ndarray, torch.Tensor, MetaTensor)):
9498
self.assertEqual(converted_im.dtype, im_out.dtype)
9599

96100

tests/test_utils_pytorch_numpy_unification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_percentile(self):
3636
for p in TEST_NDARRAYS:
3737
arr = p(np.arange(100 * 101).reshape(1, 100, 101).astype(np.float32))
3838
results.append(percentile(arr, q))
39-
assert_allclose(results[0], results[-1], type_test=False, atol=1e-4)
39+
assert_allclose(results[0], results[-1], type_test=False, atol=1e-4, rtol=1e-4)
4040

4141
def test_fails(self):
4242
for p in TEST_NDARRAYS:

tests/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from monai.config.deviceconfig import USE_COMPILED
3939
from monai.config.type_definitions import NdarrayOrTensor
4040
from monai.data import create_test_image_2d, create_test_image_3d
41+
from monai.data.meta_tensor import MetaTensor
4142
from monai.networks import convert_to_torchscript
4243
from monai.utils import optional_import
4344
from monai.utils.module import pytorch_after, version_leq
@@ -712,6 +713,19 @@ def query_memory(n=2):
712713
gpu_tensor: Callable = partial(torch.as_tensor, device="cuda")
713714
TEST_NDARRAYS = TEST_NDARRAYS + (gpu_tensor,) # type: ignore
714715

716+
TEST_TORCH_TENSORS: Tuple[Callable] = (torch.as_tensor,) # type: ignore
717+
if torch.cuda.is_available():
718+
gpu_tensor: Callable = partial(torch.as_tensor, device="cuda") # type: ignore
719+
TEST_NDARRAYS = TEST_TORCH_TENSORS + (gpu_tensor,) # type: ignore
720+
721+
DEFAULT_TEST_AFFINE = torch.tensor(
722+
[[2.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 0.0, 2.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
723+
)
724+
_metatensor_creator = partial(MetaTensor, meta={"a": "b", "affine": DEFAULT_TEST_AFFINE})
725+
TEST_NDARRAYS_NO_META_TENSOR: Tuple[Callable] = (np.array,) + TEST_TORCH_TENSORS # type: ignore
726+
TEST_TORCH_AND_META_TENSORS: Tuple[Callable] = TEST_TORCH_TENSORS + (_metatensor_creator,) # type: ignore
727+
TEST_NDARRAYS_ALL: Tuple[Callable] = TEST_NDARRAYS_NO_META_TENSOR + (_metatensor_creator,) # type: ignore
728+
715729

716730
TEST_DEVICES = [[torch.device("cpu")]]
717731
if torch.cuda.is_available():

0 commit comments

Comments
 (0)