medsegpy/config.py: 1 addition, 11 deletions
@@ -758,7 +758,7 @@ def summary(self, additional_vars=None):
         super().summary(summary_vars)
 
 
-class ContextUNetConfig(Config):
+class ContextUNetConfig(ContextEncoderConfig):
     """
     Configuration for the ContextUNet model.
 
@@ -769,16 +769,6 @@ class ContextUNetConfig(Config):
     """
 
     MODEL_NAME = "ContextUNet"
-    NUM_FILTERS = [[32, 32], [64, 64], [128, 128], [256, 256]]
-
-    def __init__(self, state="training", create_dirs=True):
-        super().__init__(self.MODEL_NAME, state, create_dirs=create_dirs)
-
-    def summary(self, additional_vars=None):
-        summary_vars = ["NUM_FILTERS"]
-        if additional_vars:
-            summary_vars.extend(additional_vars)
-        super().summary(summary_vars)
 
 
 class ContextInpaintingConfig(ContextUNetConfig):
medsegpy/cross_validation/cv_util.py: 3 additions, 6 deletions
@@ -147,12 +147,9 @@ def init_cv_experiments(self, num_valid_bins=1, num_test_bins=1):
 
         for i in range(len(temp)):
             for j in range(i + 1, len(temp)):
-                assert (
-                    len(set(temp[i]) & set(temp[j])) == 0
-                ), "Test bins %d and %d not mutually exclusive - %d overlap" % (
-                    i,
-                    j,
-                    len(set(temp[i]) & set(temp[j])),
+                assert len(set(temp[i]) & set(temp[j])) == 0, (
+                    "Test bins %d and %d not mutually exclusive - %d overlap"
+                    % (i, j, len(set(temp[i]) & set(temp[j])))
                 )
 
         self.num_valid_bins = num_valid_bins
medsegpy/data/data_utils.py: 3 additions, 1 deletion
@@ -3,6 +3,7 @@
 from typing import Sequence, Union
 
 import numpy as np
+from numba import njit
 
 
 def collect_mask(mask: np.ndarray, index: Sequence[Union[int, Sequence[int], int]]):
@@ -195,7 +196,7 @@ def generate_poisson_disc_mask(
     x /= x.max()
     y = np.maximum(abs(y - img_shape[-2] / 2), 0)
     y /= y.max()
-    r = np.sqrt(x**2 + y**2)
+    r = np.sqrt(x ** 2 + y ** 2)
 
     # Quick checks
     assert int(num_samples) == num_samples, (
@@ -233,6 +234,7 @@
     return mask, patch_mask
 
 
+@njit
 def _poisson(nx, ny, K, R, num_samples=None, patch_size=0.0, seed=None):
     mask = np.zeros((ny, nx))
     patch_mask = np.zeros((ny, nx))
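A note on the data_utils.py hunk above: it imports njit from numba and decorates _poisson with it, so the sampling loops in the Poisson-disc mask generator can be JIT-compiled rather than run through the Python interpreter. A minimal sketch of the same decorator pattern, using a hypothetical loop-heavy kernel (fill_disc is illustrative only, not a medsegpy function):

import numpy as np
from numba import njit


@njit  # nopython-mode compilation; the first call pays the compile cost, later calls run natively
def fill_disc(ny, nx, radius):
    # Stand-in for a loop-heavy mask kernel like _poisson: plain loops, scalars,
    # and NumPy arrays, all of which numba's nopython mode supports.
    mask = np.zeros((ny, nx))
    cy, cx = ny / 2.0, nx / 2.0
    for y in range(ny):
        for x in range(nx):
            if (y - cy) ** 2 + (x - cx) ** 2 <= radius ** 2:
                mask[y, x] = 1.0
    return mask


mask = fill_disc(320, 256, 40.0)  # called exactly like an undecorated function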
medsegpy/data/datasets/abct.py: 2 additions, 0 deletions
@@ -153,4 +153,6 @@ def register_all_abct():
             txt_file_or_scan_root = os.path.join(
                 Cluster.working_cluster().data_dir, txt_file_or_scan_root
             )
+        if not os.path.exists(txt_file_or_scan_root):
+            continue
         register_abct(dataset_name, txt_file_or_scan_root)
medsegpy/data/datasets/oai.py: 2 additions, 0 deletions
@@ -176,4 +176,6 @@ def register_all_oai():
     for dataset_name, scan_root in _DATA_CATALOG.items():
         if not os.path.isabs(scan_root):
             scan_root = os.path.join(Cluster.working_cluster().data_dir, scan_root)
+        if not os.path.exists(scan_root):
+            continue
         register_oai(dataset_name, scan_root)
medsegpy/data/datasets/qdess_mri.py: 6 additions, 0 deletions
@@ -4,6 +4,7 @@
 import re
 
 from medsegpy.data.catalog import DatasetCatalog, MetadataCatalog
+from medsegpy.utils.cluster import Cluster
 
 logger = logging.getLogger(__name__)
 
@@ -161,6 +162,7 @@ def load_2d_from_filepaths(filepaths: list, source_path: str, dataset_name: str
             corresponding ground truth segmentations.
         total_num_slices: The total number of slices for this dataset.
         dataset_name: The name of the dataset.
+
     Returns:
         dataset_dicts: A list of dictionaries, described above in the
             docstring.
@@ -336,4 +338,8 @@ def register_all_qdess_datasets():
     Registers all qDESS MRI datasets listed in _DATA_CATALOG.
     """
     for dataset_name, scan_root in _DATA_CATALOG.items():
+        if not os.path.isabs(scan_root):
+            scan_root = os.path.join(Cluster.working_cluster().data_dir, scan_root)
+        if not os.path.exists(scan_root):
+            continue
         register_qdess_dataset(scan_root=scan_root, dataset_name=dataset_name)
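The three dataset hunks above (abct.py, oai.py, qdess_mri.py) apply the same guard: a relative catalog root is resolved against the working cluster's data directory, and any root that does not exist on the current machine is skipped instead of being registered. A minimal sketch of that pattern with placeholder names (_DATA_CATALOG contents and register_fn are hypothetical; the data_dir argument stands in for Cluster.working_cluster().data_dir shown in the diff):

import os

_DATA_CATALOG = {"qdess_sample_train": "qdess/train"}  # hypothetical catalog entry


def register_all(data_dir, register_fn):
    for dataset_name, scan_root in _DATA_CATALOG.items():
        if not os.path.isabs(scan_root):
            # Relative roots are interpreted against the cluster's data directory.
            scan_root = os.path.join(data_dir, scan_root)
        if not os.path.exists(scan_root):
            # Skip datasets that are absent on this machine rather than failing registration.
            continue
        register_fn(dataset_name, scan_root)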
medsegpy/data/im_gens.py: 3 additions, 5 deletions
@@ -1234,11 +1234,9 @@ def __validate_img_size__(self, total_volume_shape):
         # this means shape of total volume must be perfectly divisible into
         # cubes of size IMG_SIZE
         for dim in range(3):
-            assert (
-                total_volume_shape[dim] % self.config.IMG_SIZE[dim] == 0
-            ), "Cannot divide volume of size %s to blocks of size %s" % (
-                total_volume_shape,
-                self.config.IMG_SIZE,
+            assert total_volume_shape[dim] % self.config.IMG_SIZE[dim] == 0, (
+                "Cannot divide volume of size %s to blocks of size %s"
+                % (total_volume_shape, self.config.IMG_SIZE)
             )
 
     def img_generator_test(self, model=None):
medsegpy/data/transforms/transform.py: 5 additions, 0 deletions
@@ -67,6 +67,7 @@ def apply_image(self, img: np.ndarray):
             img (ndarray): of shape NxHxWxC, or HxWxC or HxW. The array can be
                 of type uint8 in range [0, 255], or floating point in range
                 [0, 1] or [0, 255].
+
         Returns:
             ndarray: image after apply the transformation.
         """
@@ -147,6 +148,7 @@ def _apply(self, x: _T, meth: str) -> _T:
         Args:
             x: input to apply the transform operations.
             meth (str): meth.
+
         Returns:
             x: after apply the transformation.
         """
@@ -167,6 +169,7 @@ def __add__(self, other: "TransformList") -> "TransformList":
         """
         Args:
             other (TransformList): transformation to add.
+
         Returns:
             TransformList: list of transforms.
         """
@@ -177,6 +180,7 @@ def __iadd__(self, other: "TransformList") -> "TransformList":
         """
         Args:
             other (TransformList): transformation to add.
+
         Returns:
             TransformList: list of transforms.
         """
@@ -188,6 +192,7 @@ def __radd__(self, other: "TransformList") -> "TransformList":
         """
         Args:
             other (TransformList): transformation to add.
+
         Returns:
             TransformList: list of transforms.
         """