Description
Describe the bug
If a list of transforms includes more than one MultiSampleTrait-like operation (i.e. operations that each return a list of generated items), passing that list to execute_compose() leads to a list of lists being returned by this line of apply_transform. The next transform in the list then breaks, because it tries to interpret one entry of the outer list (itself a list) as a single item, e.g. a dict.
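To illustrate the mechanism without MONAI or torch, here is a minimal pure-Python model of the list-mapping behavior described above. The names `apply_transform_current`, `foo` and `center_crop` are hypothetical stand-ins, not MONAI or torchvision code, and strings play the role of images.

```python
# Hypothetical model (not MONAI code) of the list-mapping branch of
# apply_transform before the fix: each item of a list/tuple input is
# transformed, and each result is kept as ONE element of the output list.
def apply_transform_current(transform, data):
    if isinstance(data, (list, tuple)):
        return [transform(item) for item in data]
    return transform(data)


def foo(data):
    # MultiSampleTrait-like transform: returns a list of generated samples
    return [data]


def center_crop(item):
    # "normal" transform that only accepts a single item (here: a string)
    if not isinstance(item, str):
        raise TypeError(f"Unexpected type {type(item)}")
    return item + "_cropped"


data = "img"
for t in (foo, foo):
    data = apply_transform_current(t, data)

print(data)  # [['img']]  (a list of lists after two multi-sample transforms)

try:
    apply_transform_current(center_crop, data)
except TypeError as err:
    print(err)  # Unexpected type <class 'list'>
```

The second `foo` receives a list, maps over it, and wraps each of its own results in another list, so the crop receives `['img']` instead of `'img'`.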
To Reproduce
Steps to reproduce the behavior (a sketch):
```python
import torch
from monai.transforms.traits import MultiSampleTrait
from monai.transforms.compose import execute_compose
from torchvision.transforms import CenterCrop


class Foo(MultiSampleTrait):
    def __init__(self):
        super().__init__()

    def __call__(self, data):
        return [data]


center_crop = CenterCrop([128, 128])  # just some "normal" transform
img = torch.zeros([1, 512, 512])

res = execute_compose(img, [center_crop])  # -> tensor 1x128x128 (as expected)
res = execute_compose(img, [Foo(), center_crop])  # -> list containing one element, tensor 1x128x128 (as expected)
res = execute_compose(img, [Foo(), Foo(), center_crop])  # -> TypeError: Unexpected type <class 'list'>
```
Expected behavior
Since unwrapping of list or tuple inputs is implemented in apply_transform, and MultiSampleTrait indicates that such transforms can exist (I know it does not do anything by itself, nor is it currently used to drive any behavior inside MONAI), I expect apply_transform to handle transform return values that are lists or tuples correctly, i.e. to flatten them so that the next transform receives single items rather than nested lists.
Environment
(providing a partial output since the rest seems irrelevant)
================================
Printing MONAI config...
================================
MONAI version: 1.4.0
Numpy version: 1.26.4
Pytorch version: 2.7.1+cu126
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008
MONAI __file__: /opt/venv/lib/python3.12/site-packages/monai/__init__.py
(I checked that the code on MONAI's main branch has the same issue.)
Suggested fix
Replace these lines of apply_transform with:
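```python
res = []
for item in data:
    res_item = _apply_transform(transform, item, unpack_items, lazy, overrides, log_stats)
    # flatten one level: a transform returning several samples contributes
    # each of them as a separate item instead of a nested list
    if isinstance(res_item, (list, tuple)):
        res.extend(res_item)
    else:
        res.append(res_item)
return res
```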
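A quick way to sanity-check the suggested fix is to drop it into a simplified copy of apply_transform and rerun the reproduction chain with plain Python stand-ins. Everything here is hypothetical: `_apply_transform` is a stub that only calls the transform, `apply_transform_fixed` and `run` are simplified stand-ins for apply_transform and execute_compose, and strings play the role of images.

```python
# Hypothetical stub for MONAI's private _apply_transform: it only calls the
# transform on a single item (the real function does additional work, e.g.
# argument unpacking).
def _apply_transform(transform, data, unpack_items=False, lazy=False, overrides=None, log_stats=False):
    return transform(data)


def apply_transform_fixed(transform, data, unpack_items=False, lazy=False, overrides=None, log_stats=False):
    if isinstance(data, (list, tuple)):
        # suggested fix: flatten one level of list/tuple results
        res = []
        for item in data:
            res_item = _apply_transform(transform, item, unpack_items, lazy, overrides, log_stats)
            if isinstance(res_item, (list, tuple)):
                res.extend(res_item)
            else:
                res.append(res_item)
        return res
    return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)


def run(data, transforms):
    # simplified stand-in for execute_compose: apply transforms in order
    for t in transforms:
        data = apply_transform_fixed(t, data)
    return data


def foo(data):
    return [data]  # MultiSampleTrait-like: one generated sample


def two_samples(data):
    return [data + "_a", data + "_b"]  # MultiSampleTrait-like: two samples


def crop(data):
    # "normal" transform that only accepts a single item
    if not isinstance(data, str):
        raise TypeError(f"Unexpected type {type(data)}")
    return data + "_crop"


print(run("img", [crop]))  # img_crop
print(run("img", [foo, foo, crop]))  # ['img_crop']
print(run("img", [two_samples, two_samples, crop]))
# ['img_a_a_crop', 'img_a_b_crop', 'img_b_a_crop', 'img_b_b_crop']
```

With the fix, sample counts multiply across chained multi-sample transforms (2 × 2 = 4 flat items above) instead of nesting. Note that the check is purely type-based, so a transform that legitimately returns a single result as a tuple or list (for example an (image, label) pair) would also be split; restricting the flattening, e.g. to transforms that are MultiSampleTrait instances, may be worth considering.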