@@ -105,7 +105,8 @@ def convert_to_tensor(
105
105
track_meta : bool = False ,
106
106
):
107
107
"""
108
- Utility to convert the input data to a PyTorch Tensor, if tracking meta, convert to `MetaTensor`.
108
+ Utility to convert the input data to a PyTorch Tensor, if `track_meta` is True, the output will be a `MetaTensor`,
109
+ otherwise, the output will be a regular torch Tensor.
109
110
If passing a dictionary, list or tuple, recursively check every item and convert it to PyTorch Tensor.
110
111
111
112
Args:
@@ -121,13 +122,14 @@ def convert_to_tensor(
121
122
122
123
"""
123
124
124
- def _convert_tensor (tensor ):
125
+ def _convert_tensor (tensor , ** kwargs ):
125
126
if not isinstance (tensor , torch .Tensor ):
126
- tensor = torch .as_tensor (tensor )
127
+ # if input data is not Tensor, convert it to Tensor first
128
+ tensor = torch .as_tensor (tensor , ** kwargs )
127
129
if track_meta and not isinstance (tensor , monai .data .MetaTensor ):
128
130
return monai .data .MetaTensor (tensor )
129
131
if not track_meta and isinstance (tensor , monai .data .MetaTensor ):
130
- return tensor .as_tensor (tensor )
132
+ return tensor .as_tensor ()
131
133
return tensor
132
134
133
135
if isinstance (data , torch .Tensor ):
@@ -142,15 +144,15 @@ def _convert_tensor(tensor):
142
144
data = np .ascontiguousarray (data )
143
145
return _convert_tensor (data , dtype = dtype , device = device )
144
146
elif (has_cp and isinstance (data , cp_ndarray )) or isinstance (data , (float , int , bool )):
145
- return _convert_tensor (data , dtype = dtype , device = device ) # type: ignore
147
+ return _convert_tensor (data , dtype = dtype , device = device )
146
148
elif isinstance (data , list ):
147
- list_ret = [convert_to_tensor (i , dtype = dtype , device = device ) for i in data ]
149
+ list_ret = [convert_to_tensor (i , dtype = dtype , device = device , track_meta = track_meta ) for i in data ]
148
150
return _convert_tensor (list_ret , dtype = dtype , device = device ) if wrap_sequence else list_ret
149
151
elif isinstance (data , tuple ):
150
- tuple_ret = tuple (convert_to_tensor (i , dtype = dtype , device = device ) for i in data )
151
- return _convert_tensor (tuple_ret , dtype = dtype , device = device ) if wrap_sequence else tuple_ret # type: ignore
152
+ tuple_ret = tuple (convert_to_tensor (i , dtype = dtype , device = device , track_meta = track_meta ) for i in data )
153
+ return _convert_tensor (tuple_ret , dtype = dtype , device = device ) if wrap_sequence else tuple_ret
152
154
elif isinstance (data , dict ):
153
- return {k : convert_to_tensor (v , dtype = dtype , device = device ) for k , v in data .items ()}
155
+ return {k : convert_to_tensor (v , dtype = dtype , device = device , track_meta = track_meta ) for k , v in data .items ()}
154
156
155
157
return data
156
158
@@ -230,7 +232,6 @@ def convert_data_type(
230
232
device : Optional [torch .device ] = None ,
231
233
dtype : Union [DtypeLike , torch .dtype ] = None ,
232
234
wrap_sequence : bool = False ,
233
- drop_meta : bool = True ,
234
235
) -> Tuple [NdarrayTensor , type , Optional [torch .device ]]:
235
236
"""
236
237
Convert to `torch.Tensor`/`np.ndarray` from `torch.Tensor`/`np.ndarray`/`float`/`int` etc.
@@ -244,9 +245,6 @@ def convert_data_type(
244
245
If left blank, it remains unchanged.
245
246
wrap_sequence: if `False`, then lists will recursively call this function.
246
247
E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`.
247
- drop_meta: whether to drop the meta information of the input data, default to `True`.
248
- If `True`, then the meta information will be dropped quietly, unless the output type is MetaTensor.
249
- If `False`, converting a MetaTensor into a non-tensor instance will raise an error.
250
248
251
249
Returns:
252
250
modified data, orig_type, orig_device
@@ -278,18 +276,10 @@ def convert_data_type(
278
276
279
277
dtype_ = get_equivalent_dtype (dtype , output_type )
280
278
281
- if not drop_meta and not issubclass (output_type , monai .data .MetaObj ) and isinstance (data , monai .data .MetaObj ):
282
- # input has a MetaObj, user chose keep the metadata, but the output type cannot take a MetaObj.
283
- if issubclass (output_type , torch .Tensor ):
284
- # user-specified MetaTensor to torch tensor keep the MetaTensor type, for backward compatibility
285
- output_type = type (data ) # type: ignore
286
- else :
287
- raise RuntimeError (f"the specified output_type { output_type } cannot have the metaobj, but drop_meta=False." )
288
-
289
279
data_ : NdarrayTensor
290
280
291
281
if issubclass (output_type , torch .Tensor ):
292
- track_meta = True if issubclass (output_type , monai .data .MetaTensor ) else False
282
+ track_meta = issubclass (output_type , monai .data .MetaTensor )
293
283
data_ = convert_to_tensor (data , dtype = dtype_ , device = device , wrap_sequence = wrap_sequence , track_meta = track_meta )
294
284
return data_ , orig_type , orig_device
295
285
if issubclass (output_type , np .ndarray ):
@@ -302,11 +292,7 @@ def convert_data_type(
302
292
303
293
304
294
def convert_to_dst_type (
305
- src : Any ,
306
- dst : NdarrayTensor ,
307
- dtype : Union [DtypeLike , torch .dtype , None ] = None ,
308
- wrap_sequence : bool = False ,
309
- drop_meta : bool = True ,
295
+ src : Any , dst : NdarrayTensor , dtype : Union [DtypeLike , torch .dtype , None ] = None , wrap_sequence : bool = False
310
296
) -> Tuple [NdarrayTensor , type , Optional [torch .device ]]:
311
297
"""
312
298
Convert source data to the same data type and device as the destination data.
@@ -320,9 +306,6 @@ def convert_to_dst_type(
320
306
dtype: an optional argument if the target `dtype` is different from the original `dst`'s data type.
321
307
wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`.
322
308
If `True`, then `[1, 2]` -> `array([1, 2])`.
323
- drop_meta: whether to drop the meta information of the input data, default to `True`.
324
- If `True`, then the meta information will be dropped quietly, unless the output type is MetaTensor.
325
- If `False`, converting a MetaTensor into a non-tensor instance will raise an error.
326
309
327
310
See Also:
328
311
:func:`convert_data_type`
@@ -346,7 +329,7 @@ def convert_to_dst_type(
346
329
output_type = type (dst )
347
330
output : NdarrayTensor
348
331
output , _type , _device = convert_data_type (
349
- data = src , output_type = output_type , device = device , dtype = dtype , wrap_sequence = wrap_sequence , drop_meta = drop_meta
332
+ data = src , output_type = output_type , device = device , dtype = dtype , wrap_sequence = wrap_sequence
350
333
)
351
334
if copy_meta and isinstance (output , monai .data .MetaTensor ): # type: ignore
352
335
output .meta , output .applied_operations = deepcopy (dst .meta ), deepcopy (dst .applied_operations ) # type: ignore
0 commit comments