Commit 8aabdc9

Improve performance for NormalizeIntensity (#6887)
### Description

In order to implement the "nonzero" functionality of the NormalizeIntensity transform, a mask is used. When nonzero is False, the mask is still used, but it is initialized to all True/1. This unnecessary masking causes a considerable performance hit. The changed implementation forgoes the mask when nonzero is False.

I ran a quick benchmark on my system comparing the old implementation, the new implementation, and normalization using the wrapper around the torchvision Normalize transform. The results were the following, showing a more than 10x performance improvement in most configurations (note that the times for the old normalize are in milliseconds, while the other times are in microseconds):

> [-------------- torchvision ---------------]
>                  |   cpu   |  cuda
> 1 threads: ---------------------------------
> (250, 250, 250)  | 18847.2 | 1440.5
> (100, 100, 100)  |   484.6 |  395.5
>
> Times are in microseconds (us).
>
> [--------------- monai ----------------]
>                  |  cpu  | cuda
> 1 threads: -----------------------------
> (250, 250, 250)  | 603.7 | 11.5
> (100, 100, 100)  |  39.9 |  1.5
>
> Times are in milliseconds (ms).
>
> [------------- monai_improved ------------]
>                  |   cpu   |  cuda
> 1 threads: --------------------------------
> (250, 250, 250)  | 17763.2 | 720.0
> (100, 100, 100)  |   938.0 | 185.2
>
> Times are in microseconds (us).
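As a rough check on the size of the speedup, the old and improved MONAI timings reported above can be put on a common unit and divided. This is only a minimal sketch: the numbers are copied from the tables, and the variable names are chosen for illustration.

```python
# Timings copied from the benchmark tables: the old implementation is
# reported in milliseconds, the improved one in microseconds.
old_ms = {
    ("(250, 250, 250)", "cpu"): 603.7,
    ("(250, 250, 250)", "cuda"): 11.5,
    ("(100, 100, 100)", "cpu"): 39.9,
    ("(100, 100, 100)", "cuda"): 1.5,
}
new_us = {
    ("(250, 250, 250)", "cpu"): 17763.2,
    ("(250, 250, 250)", "cuda"): 720.0,
    ("(100, 100, 100)", "cpu"): 938.0,
    ("(100, 100, 100)", "cuda"): 185.2,
}

# Convert milliseconds to microseconds, then take the ratio old / new.
speedup = {key: old_ms[key] * 1000.0 / new_us[key] for key in old_ms}

for (shape, device), factor in speedup.items():
    print(f"{shape:>15} {device:>4}: {factor:.1f}x")
```

With these numbers the ratio ranges from about 8x (the small CUDA case) to about 43x (the small CPU case).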
The benchmarks were created with the following code (the ImprovedNormalizeIntensity class does not exist in the PR; it was my quick fix to have both the old and the new implementation available):

```python
import torch.utils.benchmark as benchmark
import torch
from monai.transforms import TorchVision
# ImprovedNormalizeIntensity only existed locally for this benchmark; it is not part of the PR.
from monai.transforms.intensity.array import ImprovedNormalizeIntensity, NormalizeIntensity

shapes = [
    (250, 250, 250),
    (100, 100, 100),
]

normalizers = {
    'torchvision': TorchVision(name="Normalize", mean=1000, std=333),
    'monai': NormalizeIntensity(subtrahend=1000, divisor=333),
    'monai_improved': ImprovedNormalizeIntensity(subtrahend=1000, divisor=333),
}

results = []
for shape in shapes:
    for device in ['cpu', 'cuda']:
        torch_tensor = torch.rand((1, 1) + shape).to(device)
        for name, normalizer in normalizers.items():
            t = benchmark.Timer(
                stmt='normalizer(x)',
                globals={'normalizer': normalizer, 'x': torch_tensor},
                label=name,
                sub_label=str(shape),
                description=device,
                num_threads=1,
            )
            results.append(t.blocked_autorange(min_run_time=10))

compare = benchmark.Compare(results)
compare.print()
```

### Types of changes

<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: John Zielke <[email protected]>
1 parent 6e47140 commit 8aabdc9

1 file changed: +15 -11 lines changed

monai/transforms/intensity/array.py

Lines changed: 15 additions & 11 deletions
@@ -839,29 +839,33 @@ def _normalize(self, img: NdarrayOrTensor, sub=None, div=None) -> NdarrayOrTenso

         if self.nonzero:
             slices = img != 0
+            masked_img = img[slices]
+            if not slices.any():
+                return img
         else:
-            if isinstance(img, np.ndarray):
-                slices = np.ones_like(img, dtype=bool)
-            else:
-                slices = torch.ones_like(img, dtype=torch.bool)
-        if not slices.any():
-            return img
+            slices = None
+            masked_img = img

-        _sub = sub if sub is not None else self._mean(img[slices])
+        _sub = sub if sub is not None else self._mean(masked_img)
         if isinstance(_sub, (torch.Tensor, np.ndarray)):
             _sub, *_ = convert_to_dst_type(_sub, img)
-            _sub = _sub[slices]
+            if slices is not None:
+                _sub = _sub[slices]

-        _div = div if div is not None else self._std(img[slices])
+        _div = div if div is not None else self._std(masked_img)
         if np.isscalar(_div):
             if _div == 0.0:
                 _div = 1.0
         elif isinstance(_div, (torch.Tensor, np.ndarray)):
             _div, *_ = convert_to_dst_type(_div, img)
-            _div = _div[slices]
+            if slices is not None:
+                _div = _div[slices]
             _div[_div == 0.0] = 1.0

-        img[slices] = (img[slices] - _sub) / _div
+        if slices is not None:
+            img[slices] = (masked_img - _sub) / _div
+        else:
+            img = (img - _sub) / _div
         return img

     def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
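
To make the change concrete, here is a minimal NumPy-only sketch of the patched control flow next to the old all-True-mask path. `normalize_sketch` and `normalize_all_true_mask` are hypothetical helpers written for this illustration, not MONAI functions. They leave out the `sub`/`div` arguments, tensor support, and the `convert_to_dst_type` handling of the real `_normalize`, and always use the global mean and standard deviation.

```python
import numpy as np


def normalize_sketch(img, nonzero=False):
    """Simplified stand-in for the patched logic (illustrative only)."""
    img = img.astype(np.float64)  # astype copies, so the caller's array is untouched
    if nonzero:
        mask = img != 0
        if not mask.any():
            return img  # nothing to normalize
        values = img[mask]  # boolean indexing copies the selected elements
    else:
        mask = None
        values = img  # no mask is built and no elements are gathered

    sub = values.mean()
    div = values.std()
    if div == 0.0:
        div = 1.0  # guard against division by zero

    if mask is not None:
        img[mask] = (values - sub) / div  # write back only the masked elements
    else:
        img = (img - sub) / div  # plain elementwise arithmetic on the whole array
    return img


def normalize_all_true_mask(img):
    """Pre-patch behaviour for nonzero=False: index through an all-True mask."""
    img = img.astype(np.float64)
    mask = np.ones_like(img, dtype=bool)
    sub = img[mask].mean()
    div = img[mask].std()
    if div == 0.0:
        div = 1.0
    img[mask] = (img[mask] - sub) / div
    return img


rng = np.random.default_rng(0)
x = rng.random((1, 8, 8, 8))
fast = normalize_sketch(x, nonzero=False)
slow = normalize_all_true_mask(x)
print(np.allclose(fast, slow))
```

The two paths should agree up to floating-point rounding, since an all-True mask selects every element. The difference is that the old path builds the mask and gathers and scatters through it several times, while the new path works on the array directly. As in the diff, the unmasked branch rebinds `img` to a new array instead of writing into the existing one.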
