diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index b69a5015bb..4396efc547 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -120,7 +120,7 @@ class SupervisedTrainer(Trainer): #ignite.engine.engine.Engine.register_events. decollate: whether to decollate the batch-first data to a list of data after model computation, recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`. - default to `True`. + default to `False` as training slows due to tensor movement to CPU for decollation when enabled. optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None. more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html. to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for @@ -154,7 +154,7 @@ def __init__( amp: bool = False, event_names: list[str | EventEnum | type[EventEnum]] | None = None, event_to_attr: dict | None = None, - decollate: bool = True, + decollate: bool = False, optim_set_to_none: bool = False, to_kwargs: dict | None = None, amp_kwargs: dict | None = None, diff --git a/tests/apps/deepgrow/transforms/test_deepgrow_interaction.py b/tests/apps/deepgrow/transforms/test_deepgrow_interaction.py index 35759699f8..9a33a9a068 100644 --- a/tests/apps/deepgrow/transforms/test_deepgrow_interaction.py +++ b/tests/apps/deepgrow/transforms/test_deepgrow_interaction.py @@ -78,6 +78,7 @@ def run_interaction(self, train, compose): optimizer=opt, loss_function=loss, iteration_update=i, + decollate=True, ) engine.add_event_handler(IterationEvents.INNER_ITERATION_STARTED, add_one) engine.add_event_handler(IterationEvents.INNER_ITERATION_COMPLETED, add_one) diff --git a/tests/integration/test_deepedit_interaction.py b/tests/integration/test_deepedit_interaction.py index 8baf4dc827..8c5f45d463 100644 --- a/tests/integration/test_deepedit_interaction.py +++ b/tests/integration/test_deepedit_interaction.py @@ -103,6 +103,7 @@ def run_interaction(self, train): loss_function=loss, postprocessing=post_transforms, iteration_update=i, + decollate=True, ) engine.add_event_handler(IterationEvents.INNER_ITERATION_STARTED, add_one) engine.add_event_handler(IterationEvents.INNER_ITERATION_COMPLETED, add_one) diff --git a/tests/testing_data/config_fl_train.json b/tests/testing_data/config_fl_train.json index 5b7fb6608e..8dbe9cc77f 100644 --- a/tests/testing_data/config_fl_train.json +++ b/tests/testing_data/config_fl_train.json @@ -119,7 +119,8 @@ "loss_function": "@loss", "optimizer": "@optimizer", "inferer": "@train#inferer", - "train_handlers": "@train#handlers" + "train_handlers": "@train#handlers", + "decollate": true } }, "validate": {