Skip to content

Commit b92b2ce

Browse files
FinnBehrendt, KumoLiu, and ericspod
authored
Add support for optional conditioning in PatchInferer, SliceInferer, and SlidingWindowInferer (#8400)
Fixes [#8220](#8220)

### Description

This PR adds support for optional conditioning in MONAI’s inferers, allowing models to receive auxiliary inputs for conditioning that are processed (patched, sliced) the same way as the inputs. This is particularly relevant for generative models such as conditional GANs or diffusion models (DMs).

Example usage:

```python
# Given a conditioned model, inputs of shape (1, C, H, W, D) and condition of shape (1, C, H, W, D)
output = SliceInferer(...)(inputs, model, condition=cond_tensor)
```

### Types of changes

- Extended `PatchInferer`, `SliceInferer`, and `SlidingWindowInferer` to optionally accept a `condition` tensor (passed as a kwarg).
- The `condition` can now be:
  - `None` (default)
  - A tensor of the same shape as `inputs`
- The inferers now slice/patch the conditions alongside the corresponding inputs and feed them to the network.
- Updated unit tests for each inferer:
  - Verified with and without conditioning

<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Additional extensions, such as support for dense vector conditioning (e.g., shape (1, C, Z), with Z being the conditioning dimension), could be explored in a follow-up PR if there’s interest.
<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **New Features**
  * Added support for an optional "condition" tensor in patch-based, sliding window, and slice inference, allowing conditional inference workflows.
  * The "condition" tensor is validated for shape and type consistency with inputs and is processed in sync during inference.
* **Tests**
  * Introduced extensive new tests for conditional inference across patch, sliding window, and slice inferers to ensure correct behavior and output validation when using the "condition" argument.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: FinnBehrendt <[email protected]>
Co-authored-by: YunLiu <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
1 parent 2a12c4b commit b92b2ce

File tree

5 files changed

+673
-15
lines changed

5 files changed

+673
-15
lines changed

monai/inferers/inferer.py

Lines changed: 88 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -322,8 +322,36 @@ def __call__(
322322
supports callables such as ``lambda x: my_torch_model(x, additional_config)``
323323
args: optional args to be passed to ``network``.
324324
kwargs: optional keyword args to be passed to ``network``.
325+
condition (torch.Tensor, optional): If provided via `**kwargs`,
326+
this tensor must match the shape of `inputs` and will be sliced, patched, or windowed alongside the inputs.
327+
The resulting segments will be passed to the model together with the corresponding input segments.
325328
326329
"""
330+
# check if there is a conditioning signal
331+
condition = kwargs.pop("condition", None)
332+
# shape check for condition
333+
if condition is not None:
334+
if isinstance(inputs, torch.Tensor) and isinstance(condition, torch.Tensor):
335+
if condition.shape != inputs.shape:
336+
raise ValueError(
337+
f"`condition` must match shape of `inputs` ({inputs.shape}), but got {condition.shape}"
338+
)
339+
elif isinstance(inputs, list) and isinstance(condition, list):
340+
if len(inputs) != len(condition):
341+
raise ValueError(
342+
f"Length of `condition` must match `inputs`. Got {len(inputs)} and {len(condition)}."
343+
)
344+
for (in_patch, _), (cond_patch, _) in zip(inputs, condition):
345+
if cond_patch.shape != in_patch.shape:
346+
raise ValueError(
347+
"Each `condition` patch must match the shape of the corresponding input patch. "
348+
f"Got {cond_patch.shape} and {in_patch.shape}."
349+
)
350+
else:
351+
raise ValueError(
352+
"`condition` and `inputs` must be of the same type (both Tensor or both list of patches)."
353+
)
354+
327355
patches_locations: Iterable[tuple[torch.Tensor, Sequence[int]]] | MetaTensor
328356
if self.splitter is None:
329357
# handle situations where the splitter is not provided
@@ -344,20 +372,39 @@ def __call__(
344372
f"The provided inputs type is {type(inputs)}."
345373
)
346374
patches_locations = inputs
375+
if condition is not None:
376+
condition_locations = condition
347377
else:
348378
# apply splitter
349379
patches_locations = self.splitter(inputs)
380+
if condition is not None:
381+
# apply splitter to condition
382+
condition_locations = self.splitter(condition)
350383

351384
ratios: list[float] = []
352385
mergers: list[Merger] = []
353-
for patches, locations, batch_size in self._batch_sampler(patches_locations):
354-
# run inference
355-
outputs = self._run_inference(network, patches, *args, **kwargs)
356-
# initialize the mergers
357-
if not mergers:
358-
mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size)
359-
# aggregate outputs
360-
self._aggregate(outputs, locations, batch_size, mergers, ratios)
386+
if condition is not None:
387+
for (patches, locations, batch_size), (condition_patches, _, _) in zip(
388+
self._batch_sampler(patches_locations), self._batch_sampler(condition_locations)
389+
):
390+
# add patched condition to kwargs
391+
kwargs["condition"] = condition_patches
392+
# run inference
393+
outputs = self._run_inference(network, patches, *args, **kwargs)
394+
# initialize the mergers
395+
if not mergers:
396+
mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size)
397+
# aggregate outputs
398+
self._aggregate(outputs, locations, batch_size, mergers, ratios)
399+
else:
400+
for patches, locations, batch_size in self._batch_sampler(patches_locations):
401+
# run inference
402+
outputs = self._run_inference(network, patches, *args, **kwargs)
403+
# initialize the mergers
404+
if not mergers:
405+
mergers, ratios = self._initialize_mergers(inputs, outputs, patches, batch_size)
406+
# aggregate outputs
407+
self._aggregate(outputs, locations, batch_size, mergers, ratios)
361408

362409
# finalize the mergers and get the results
363410
merged_outputs = [merger.finalize() for merger in mergers]
@@ -519,8 +566,14 @@ def __call__(
519566
supports callables such as ``lambda x: my_torch_model(x, additional_config)``
520567
args: optional args to be passed to ``network``.
521568
kwargs: optional keyword args to be passed to ``network``.
522-
569+
condition (torch.Tensor, optional): If provided via `**kwargs`,
570+
this tensor must match the shape of `inputs` and will be sliced, patched, or windowed alongside the inputs.
571+
The resulting segments will be passed to the model together with the corresponding input segments.
523572
"""
573+
# shape check for condition
574+
condition = kwargs.get("condition", None)
575+
if condition is not None and condition.shape != inputs.shape:
576+
raise ValueError(f"`condition` must match shape of `inputs` ({inputs.shape}), but got {condition.shape}")
524577

525578
device = kwargs.pop("device", self.device)
526579
buffer_steps = kwargs.pop("buffer_steps", self.buffer_steps)
@@ -728,7 +781,9 @@ def __call__(
728781
network: 2D model to execute inference on slices in the 3D input
729782
args: optional args to be passed to ``network``.
730783
kwargs: optional keyword args to be passed to ``network``.
731-
"""
784+
condition (torch.Tensor, optional): If provided via `**kwargs`,
785+
this tensor must match the shape of `inputs` and will be sliced, patched, or windowed alongside the inputs.
786+
The resulting segments will be passed to the model together with the corresponding input segments."""
732787
if self.spatial_dim > 2:
733788
raise ValueError("`spatial_dim` can only be `0, 1, 2` with `[H, W, D]` respectively.")
734789

@@ -742,12 +797,28 @@ def __call__(
742797
f"Currently, only 2D `roi_size` ({self.orig_roi_size}) with 3D `inputs` tensor (shape={inputs.shape}) is supported."
743798
)
744799

745-
return super().__call__(inputs=inputs, network=lambda x: self.network_wrapper(network, x, *args, **kwargs))
800+
# shape check for condition
801+
condition = kwargs.get("condition", None)
802+
if condition is not None and condition.shape != inputs.shape:
803+
raise ValueError(f"`condition` must match shape of `inputs` ({inputs.shape}), but got {condition.shape}")
804+
805+
# check if there is a conditioning signal
806+
if condition is not None:
807+
return super().__call__(
808+
inputs=inputs,
809+
network=lambda x, *args, **kwargs: self.network_wrapper(network, x, *args, **kwargs),
810+
condition=condition,
811+
)
812+
else:
813+
return super().__call__(
814+
inputs=inputs, network=lambda x, *args, **kwargs: self.network_wrapper(network, x, *args, **kwargs)
815+
)
746816

747817
def network_wrapper(
748818
self,
749819
network: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]],
750820
x: torch.Tensor,
821+
condition: torch.Tensor | None = None,
751822
*args: Any,
752823
**kwargs: Any,
753824
) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]:
@@ -756,7 +827,12 @@ def network_wrapper(
756827
"""
757828
# Pass 4D input [N, C, H, W]/[N, C, D, W]/[N, C, D, H] to the model as it is 2D.
758829
x = x.squeeze(dim=self.spatial_dim + 2)
759-
out = network(x, *args, **kwargs)
830+
831+
if condition is not None:
832+
condition = condition.squeeze(dim=self.spatial_dim + 2)
833+
out = network(x, condition, *args, **kwargs)
834+
else:
835+
out = network(x, *args, **kwargs)
760836

761837
# Unsqueeze the network output so it is [N, C, D, H, W] as expected by
762838
# the default SlidingWindowInferer class

monai/inferers/utils.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,8 @@ def sliding_window_inference(
153153
device = device or inputs.device
154154
sw_device = sw_device or inputs.device
155155

156+
condition = kwargs.pop("condition", None)
157+
156158
temp_meta = None
157159
if isinstance(inputs, MetaTensor):
158160
temp_meta = MetaTensor([]).copy_meta_from(inputs, copy_attr=False)
@@ -168,6 +170,8 @@ def sliding_window_inference(
168170
pad_size.extend([half, diff - half])
169171
if any(pad_size):
170172
inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)
173+
if condition is not None:
174+
condition = F.pad(condition, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)
171175

172176
# Store all slices
173177
scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
@@ -220,13 +224,19 @@ def sliding_window_inference(
220224
]
221225
if sw_batch_size > 1:
222226
win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
227+
if condition is not None:
228+
win_condition = torch.cat([condition[win_slice] for win_slice in unravel_slice]).to(sw_device)
229+
kwargs["condition"] = win_condition
223230
else:
224231
win_data = inputs[unravel_slice[0]].to(sw_device)
232+
if condition is not None:
233+
win_condition = condition[unravel_slice[0]].to(sw_device)
234+
kwargs["condition"] = win_condition
235+
225236
if with_coord:
226-
seg_prob_out = predictor(win_data, unravel_slice, *args, **kwargs) # batched patch
237+
seg_prob_out = predictor(win_data, unravel_slice, *args, **kwargs)
227238
else:
228-
seg_prob_out = predictor(win_data, *args, **kwargs) # batched patch
229-
239+
seg_prob_out = predictor(win_data, *args, **kwargs)
230240
# convert seg_prob_out to tuple seg_tuple, this does not allocate new memory.
231241
dict_keys, seg_tuple = _flatten_struct(seg_prob_out)
232242
if process_fn:

0 commit comments

Comments
 (0)