-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Redesign PatchWSIDataset #4152
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Redesign PatchWSIDataset #4152
Changes from all commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
6ac6545
Implement PatchWSIDataset
bhashemian d227b55
Add unittests
bhashemian a2d9a11
Add docs
bhashemian f1324d1
Reorder imports
bhashemian 39a679a
formatting:
bhashemian 688ff04
Address comments
bhashemian 4092159
Update to be compatible with Dataset
bhashemian d9d7b82
Update reader to accept str, class, object
bhashemian 31d722c
Add test cases for various reader and level arguments
bhashemian 15a9b28
Update comment about OpenSlide cache
bhashemian 48dc0a3
Rename reader_name to backend
bhashemian 4162493
Merge dev branch
bhashemian ab3472c
Add new test cases
bhashemian 10ee84f
Add unittests for openslide
bhashemian bab87be
Add new test cases
bhashemian c725103
sorts
bhashemian 4838ec1
Add docstring for kwargs
bhashemian File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import inspect | ||
from typing import Callable, Dict, List, Optional, Tuple, Union | ||
|
||
import numpy as np | ||
|
||
from monai.data import Dataset | ||
from monai.data.wsi_reader import BaseWSIReader, WSIReader | ||
from monai.transforms import apply_transform | ||
from monai.utils import ensure_tuple_rep | ||
|
||
__all__ = ["PatchWSIDataset"] | ||
|
||
|
||
class PatchWSIDataset(Dataset):
    """
    This dataset extracts patches from whole slide images (without loading the whole image).
    It also reads labels for each patch and provides each patch with its associated class labels.

    Args:
        data: the list of input samples including image, location, and label (see the note below for more details).
        size: the size of patch to be extracted from the whole slide image. An int is expanded to a
            two-element tuple. If None, each sample must provide its own "size" entry.
        level: the level at which the patches are extracted. If given, it overrides any per-sample
            "level" entry. If None, the per-sample "level" is used (falling back to 0).
        transform: transforms to be executed on input data.
        reader: the module to be used for loading whole slide imaging,

            - if `reader` is a string, it defines the backend of `monai.data.WSIReader`. Defaults to cuCIM.
            - if `reader` is a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader.
            - if `reader` is an instance of a class inherited from `BaseWSIReader`, it is set as the wsi_reader.

        kwargs: additional arguments to pass to `WSIReader` or provided whole slide reader class.
            They are not used when `reader` is already an instance.

    Raises:
        ValueError: when `reader` is not a string, a `BaseWSIReader` subclass, or a `BaseWSIReader` instance.

    Note:
        The input data has the following form as an example:

        .. code-block:: python

            [
                {"image": "path/to/image1.tiff", "location": [200, 500], "label": 0},
                {"image": "path/to/image2.tiff", "location": [100, 700], "label": 1}
            ]

        Each sample may additionally carry "size" and "level" entries. The "location" of a sample is
        shifted by half of the patch size before it is handed to the reader, i.e. it is treated as
        the center of the patch.

    """

    def __init__(
        self,
        data: List,
        size: Optional[Union[int, Tuple[int, int]]] = None,
        level: Optional[int] = None,
        transform: Optional[Callable] = None,
        reader="cuCIM",
        **kwargs,
    ):
        super().__init__(data, transform)

        # Ensure patch size is a two dimensional tuple (None means "take it from each sample")
        self.size = None if size is None else ensure_tuple_rep(size, 2)

        # A non-None level overrides the level of every sample
        self.level = level
        # The reader itself still needs a concrete default level
        reader_level = 0 if level is None else level

        # Setup the WSI reader
        self.wsi_reader: Union[WSIReader, BaseWSIReader]
        self.backend = ""
        if isinstance(reader, str):
            self.backend = reader.lower()
            self.wsi_reader = WSIReader(backend=self.backend, level=reader_level, **kwargs)
        elif inspect.isclass(reader) and issubclass(reader, BaseWSIReader):
            self.wsi_reader = reader(level=reader_level, **kwargs)
        elif isinstance(reader, BaseWSIReader):
            self.wsi_reader = reader
        else:
            raise ValueError(f"Unsupported reader type: {reader}.")

        # When the reader was given as a class or an instance, pick up its backend name if it
        # exposes one, so that the OpenSlide handling in `_get_data` is not silently skipped.
        # NOTE(review): assumes readers advertise their backend through a string `backend`
        # attribute; readers without it keep the previous behaviour (empty backend).
        if not self.backend:
            reader_backend = getattr(self.wsi_reader, "backend", "")
            if isinstance(reader_backend, str):
                self.backend = reader_backend.lower()

        # Initialize an empty whole slide image object dict (image path -> opened WSI object)
        self.wsi_object_dict: Dict = {}

    def _get_wsi_object(self, sample: Dict):
        """Return the opened whole slide image of `sample`, reading it only if it is not cached yet."""
        image_path = sample["image"]
        if image_path not in self.wsi_object_dict:
            self.wsi_object_dict[image_path] = self.wsi_reader.read(image_path)
        return self.wsi_object_dict[image_path]

    def _get_label(self, sample: Dict):
        """Return the label of `sample` as a float32 numpy array."""
        return np.array(sample["label"], dtype=np.float32)

    def _get_location(self, sample: Dict):
        """Return the sample location shifted back by half of the patch size along each dimension."""
        size = self._get_size(sample)
        return [sample["location"][i] - extent // 2 for i, extent in enumerate(size)]

    def _get_level(self, sample: Dict):
        """Return the dataset-wide level if set, otherwise the level of `sample` (default 0)."""
        if self.level is None:
            return sample.get("level", 0)
        return self.level

    def _get_size(self, sample: Dict):
        """
        Return the dataset-wide patch size if set, otherwise the size of `sample` as a 2-tuple.

        Raises:
            ValueError: when neither the dataset nor the sample defines a patch size.
        """
        if self.size is not None:
            return self.size
        sample_size = sample.get("size")
        if sample_size is None:
            # Without this check the failure surfaces later as an opaque TypeError on `None // 2`
            raise ValueError(
                "Patch size is not defined: provide `size` to PatchWSIDataset or a 'size' entry in each sample."
            )
        return ensure_tuple_rep(sample_size, 2)

    def _get_data(self, sample: Dict):
        """Extract the patch described by `sample` and return whatever the reader's `get_data` returns."""
        # Don't store OpenSlide objects to avoid issues with OpenSlide internal cache
        if self.backend == "openslide":
            self.wsi_object_dict = {}
        wsi_obj = self._get_wsi_object(sample)
        location = self._get_location(sample)
        level = self._get_level(sample)
        size = self._get_size(sample)
        return self.wsi_reader.get_data(wsi=wsi_obj, location=location, size=size, level=level)

    def _transform(self, index: int):
        """Build the patch dictionary for the entry at `index` and apply the transform, if any."""
        # Get a single entry of data
        sample: Dict = self.data[index]
        # Extract patch image and associated metadata
        image, metadata = self._get_data(sample)
        # Get the label
        label = self._get_label(sample)

        # Put all patch information together and apply transforms
        patch = {"image": image, "label": label, "metadata": metadata}
        return apply_transform(self.transform, patch) if self.transform else patch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import unittest | ||
from unittest import skipUnless | ||
|
||
import numpy as np | ||
from numpy.testing import assert_array_equal | ||
from parameterized import parameterized | ||
|
||
from monai.data import PatchWSIDataset | ||
from monai.data.wsi_reader import CuCIMWSIReader, OpenSlideWSIReader | ||
from monai.utils import optional_import | ||
from tests.utils import download_url_or_skip_test, testing_data_config | ||
|
||
# Optional backends: each flag records whether the corresponding package could be imported.
cucim, has_cucim = optional_import("cucim")
if has_cucim:
    # An importable `cucim` only counts when it actually exposes `CuImage`
    has_cucim = hasattr(cucim, "CuImage")

openslide, has_osl = optional_import("openslide")

imwrite, has_tiff = optional_import("tifffile", name="imwrite")
_, has_codec = optional_import("imagecodecs")
if has_tiff:
    # tifffile is considered usable only together with imagecodecs
    has_tiff = has_codec

# Location of the whole slide image used by every test case in this module
FILE_KEY = "wsi_img"
FILE_URL = testing_data_config("images", FILE_KEY, "url")
base_name = os.path.basename(f"{FILE_URL}")
extension = ".tiff"
FILE_PATH = os.path.join(
    os.path.dirname(__file__),
    "testing_data",
    "temp_" + base_name + extension,
)
|
||
# Every test case is a pair: [keyword arguments for PatchWSIDataset, expected output].
# The expected output is a dict (single sample) or a list of dicts (one per sample).

# Level given per sample, size given to the dataset as a tuple
TEST_CASE_0 = [
    {
        "data": [{"image": FILE_PATH, "location": [0, 0], "label": [1], "level": 0}],
        "size": (1, 1),
    },
    {"image": np.full((3, 1, 1), 239, dtype=np.uint8), "label": np.array([1])},
]

# Level given to the dataset
TEST_CASE_0_L1 = [
    {
        "data": [{"image": FILE_PATH, "location": [0, 0], "label": [1]}],
        "size": (1, 1),
        "level": 1,
    },
    {"image": np.full((3, 1, 1), 239, dtype=np.uint8), "label": np.array([1])},
]

# NOTE(review): identical to TEST_CASE_0_L1 (level 1); the name suggests level 2 was intended - confirm
TEST_CASE_0_L2 = [
    {
        "data": [{"image": FILE_PATH, "location": [0, 0], "label": [1]}],
        "size": (1, 1),
        "level": 1,
    },
    {"image": np.full((3, 1, 1), 239, dtype=np.uint8), "label": np.array([1])},
]

# Size given per sample as an int
TEST_CASE_1 = [
    {"data": [{"image": FILE_PATH, "location": [0, 0], "size": 1, "label": [1]}]},
    {"image": np.full((3, 1, 1), 239, dtype=np.uint8), "label": np.array([1])},
]

# Size (int) and level both given to the dataset
TEST_CASE_2 = [
    {
        "data": [{"image": FILE_PATH, "location": [0, 0], "label": [1]}],
        "size": 1,
        "level": 0,
    },
    {"image": np.full((3, 1, 1), 239, dtype=np.uint8), "label": np.array([1])},
]

# Multi-dimensional label
TEST_CASE_3 = [
    {
        "data": [{"image": FILE_PATH, "location": [0, 0], "label": [[[0, 1], [1, 0]]]}],
        "size": 1,
    },
    {
        "image": np.full((3, 1, 1), 239, dtype=np.uint8),
        "label": np.array([[[0, 1], [1, 0]]]),
    },
]

# Two samples sharing the dataset-level size
TEST_CASE_4 = [
    {
        "data": [
            {"image": FILE_PATH, "location": [0, 0], "label": [[[0, 1], [1, 0]]]},
            {"image": FILE_PATH, "location": [0, 0], "label": [[[1, 0], [0, 0]]]},
        ],
        "size": 1,
    },
    [
        {
            "image": np.full((3, 1, 1), 239, dtype=np.uint8),
            "label": np.array([[[0, 1], [1, 0]]]),
        },
        {
            "image": np.full((3, 1, 1), 239, dtype=np.uint8),
            "label": np.array([[[1, 0], [0, 0]]]),
        },
    ],
]

# Two samples, each carrying its own size and level
TEST_CASE_5 = [
    {
        "data": [
            {
                "image": FILE_PATH,
                "location": [0, 0],
                "label": [[[0, 1], [1, 0]]],
                "size": 1,
                "level": 1,
            },
            {
                "image": FILE_PATH,
                "location": [100, 100],
                "label": [[[1, 0], [0, 0]]],
                "size": 1,
                "level": 1,
            },
        ]
    },
    [
        {
            "image": np.full((3, 1, 1), 239, dtype=np.uint8),
            "label": np.array([[[0, 1], [1, 0]]]),
        },
        {
            "image": np.full((3, 1, 1), 243, dtype=np.uint8),
            "label": np.array([[[1, 0], [0, 0]]]),
        },
    ],
]
|
||
|
||
@skipUnless(has_cucim or has_osl or has_tiff, "Requires cucim, openslide, or tifffile!")
def setUpModule():  # noqa: N802
    """
    Module-level fixture: fetch the test slide from FILE_URL into FILE_PATH.

    The hash type and value used for the download are looked up in the testing data config.
    NOTE(review): presumably the helper skips the tests when the download fails - confirm in tests.utils.
    """
    download_url_or_skip_test(
        FILE_URL,
        FILE_PATH,
        hash_type=testing_data_config("images", FILE_KEY, "hash_type"),
        hash_val=testing_data_config("images", FILE_KEY, "hash_val"),
    )
|
||
|
||
class PatchWSIDatasetTests:
    """
    Container for the backend-agnostic test case.

    NOTE(review): the base test case is presumably nested here so that test discovery does not
    run it directly with `backend = None`; only the backend-specific subclasses are module-level.
    """

    class Tests(unittest.TestCase):
        """Shared PatchWSIDataset tests; subclasses select the backend through `cls.backend`."""

        # Name of the backend under test ("cucim" or "openslide"), set by subclasses
        backend = None

        # Test cases that produce exactly one patch
        _SINGLE_PATCH_CASES = [TEST_CASE_0, TEST_CASE_0_L1, TEST_CASE_0_L2, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]

        def _get_reader_class(self):
            """Return the reader class matching `self.backend`, or raise ValueError if unsupported."""
            if self.backend == "openslide":
                return OpenSlideWSIReader
            if self.backend == "cucim":
                return CuCIMWSIReader
            raise ValueError(f"Unsupported backend: {self.backend}")

        def _assert_sample_equal(self, sample, expected):
            """Check that the label and image of `sample` match `expected` in both shape and content."""
            self.assertTupleEqual(sample["label"].shape, expected["label"].shape)
            self.assertTupleEqual(sample["image"].shape, expected["image"].shape)
            self.assertIsNone(assert_array_equal(sample["label"], expected["label"]))
            self.assertIsNone(assert_array_equal(sample["image"], expected["image"]))

        @parameterized.expand(_SINGLE_PATCH_CASES)
        def test_read_patches_str(self, input_parameters, expected):
            # Reader given as a backend name
            dataset = PatchWSIDataset(reader=self.backend, **input_parameters)
            self._assert_sample_equal(dataset[0], expected)

        @parameterized.expand(_SINGLE_PATCH_CASES)
        def test_read_patches_class(self, input_parameters, expected):
            # Reader given as a class, instantiated by the dataset
            dataset = PatchWSIDataset(reader=self._get_reader_class(), **input_parameters)
            self._assert_sample_equal(dataset[0], expected)

        @parameterized.expand(_SINGLE_PATCH_CASES)
        def test_read_patches_object(self, input_parameters, expected):
            # Reader given as an already constructed instance
            reader = self._get_reader_class()(level=input_parameters.get("level", 0))
            dataset = PatchWSIDataset(reader=reader, **input_parameters)
            self._assert_sample_equal(dataset[0], expected)

        @parameterized.expand([TEST_CASE_4, TEST_CASE_5])
        def test_read_patches_str_multi(self, input_parameters, expected):
            dataset = PatchWSIDataset(reader=self.backend, **input_parameters)
            self.assertEqual(len(dataset), len(expected))
            for index, expected_sample in enumerate(expected):
                # Read each patch only once instead of once per assertion
                self._assert_sample_equal(dataset[index], expected_sample)
|
||
|
||
@skipUnless(has_cucim, "Requires cucim")
class TestPatchWSIDatasetCuCIM(PatchWSIDatasetTests.Tests):
    """Runs the shared PatchWSIDataset tests with the cuCIM backend."""

    @classmethod
    def setUpClass(cls) -> None:
        # Every inherited test reads the backend from this class attribute
        cls.backend = "cucim"
|
||
|
||
# The skip reason now names the package that is actually checked (openslide, not cucim)
@skipUnless(has_osl, "Requires openslide")
class TestPatchWSIDatasetOpenSlide(PatchWSIDatasetTests.Tests):
    """Runs the shared PatchWSIDataset tests with the OpenSlide backend."""

    @classmethod
    def setUpClass(cls):
        # Every inherited test reads the backend from this class attribute
        cls.backend = "openslide"
|
||
|
||
# Allow running this test module directly as a script
if __name__ == "__main__":
    unittest.main()
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.