diff --git a/bioimageio/core/__main__.py b/bioimageio/core/__main__.py index a26f43d1..1ae89a2b 100644 --- a/bioimageio/core/__main__.py +++ b/bioimageio/core/__main__.py @@ -4,16 +4,14 @@ import sys import warnings from glob import glob - from pathlib import Path -from pprint import pformat, pprint +from pprint import pformat from typing import List, Optional import typer -from bioimageio.core import __version__, prediction, commands, resource_tests, load_raw_resource_description +from bioimageio.core import __version__, commands, prediction, resource_tests from bioimageio.core.common import TestSummary -from bioimageio.core.prediction_pipeline import get_weight_formats from bioimageio.spec.__main__ import app, help_version as help_version_spec from bioimageio.spec.model.raw_nodes import WeightsFormat @@ -192,7 +190,6 @@ def predict_image( weight_format: Optional[WeightFormatEnum] = typer.Option(None, help="The weight format to use."), devices: Optional[List[str]] = typer.Option(None, help="Devices for running the model."), ): - if isinstance(padding, str): padding = json.loads(padding.replace("'", '"')) assert isinstance(padding, dict) @@ -244,7 +241,7 @@ def predict_images( tiling = json.loads(tiling.replace("'", '"')) assert isinstance(tiling, dict) - # this is a weird typer bug: default devices are empty tuple although they should be None + # this is a weird typer bug: default devices are empty tuple, although they should be None if len(devices) == 0: devices = None prediction.predict_images( diff --git a/bioimageio/core/image_helper.py b/bioimageio/core/image_helper.py index 0468b61f..5f2389f5 100644 --- a/bioimageio/core/image_helper.py +++ b/bioimageio/core/image_helper.py @@ -13,13 +13,13 @@ # -def transform_input_image(image: np.ndarray, tensor_axes: str, image_axes: Optional[str] = None): +def transform_input_image(image: np.ndarray, tensor_axes: Sequence[str], image_axes: Optional[Sequence[str]] = None): """Transform input image into output tensor with desired axes. Args: image: the input image tensor_axes: the desired tensor axes - input_axes: the axes of the input image (optional) + image_axes: the axes of the input image (optional) """ # if the image axes are not given deduce them from the required axes and image shape if image_axes is None: @@ -35,7 +35,16 @@ def transform_input_image(image: np.ndarray, tensor_axes: str, image_axes: Optio image_axes = "bczyx" else: raise ValueError(f"Invalid number of image dimensions: {ndim}") - tensor = DataArray(image, dims=tuple(image_axes)) + + # instead of 'b' we might want 'batch', etc... + axis_letter_map = { + letter: name + for letter, name in {"b": "batch", "c": "channel", "i": "index", "t": "time"}.items() + if name in tensor_axes # only do this mapping if the full name is in the desired tensor_axes + } + image_axes = tuple(axis_letter_map.get(a, a) for a in image_axes) + + tensor = DataArray(image, dims=image_axes) # expand the missing image axes missing_axes = tuple(set(tensor_axes) - set(image_axes)) tensor = tensor.expand_dims(dim=missing_axes) @@ -75,9 +84,10 @@ def transform_output_tensor(tensor: np.ndarray, tensor_axes: str, output_axes: s def to_channel_last(image): - chan_id = image.dims.index("c") + c = "c" if "c" in image.dims else "channel" + chan_id = image.dims.index(c) if chan_id != image.ndim - 1: - target_axes = tuple(ax for ax in image.dims if ax != "c") + ("c",) + target_axes = tuple(ax for ax in image.dims if ax != c) + (c,) image = image.transpose(*target_axes) return image @@ -95,17 +105,17 @@ def load_image(in_path, axes: Sequence[str]) -> DataArray: is_volume = "z" in axes im = imageio.volread(in_path) if is_volume else imageio.imread(in_path) im = transform_input_image(im, axes) - return DataArray(im, dims=axes) + return DataArray(im, dims=tuple(axes)) def load_tensors(sources, tensor_specs: List[Union[InputTensor, OutputTensor]]) -> List[DataArray]: return [load_image(s, sspec.axes) for s, sspec in zip(sources, tensor_specs)] -def save_image(out_path, image): - ext = os.path.splitext(out_path)[1] +def save_image(out_path: os.PathLike, image): + ext = os.path.splitext(str(out_path))[1] if ext == ".npy": - np.save(out_path, image) + np.save(str(out_path), image) else: is_volume = "z" in image.dims @@ -113,9 +123,9 @@ def save_image(out_path, image): squeeze = {ax: 0 if (ax in "bc" and sh == 1) else slice(None) for ax, sh in zip(image.dims, image.shape)} image = image[squeeze] - if "b" in image.dims: + if "b" in image.dims or "batch" in image.dims: raise RuntimeError(f"Cannot save prediction with batchsize > 1 as {ext}-file") - if "c" in image.dims: # image formats need channel last + if "c" in image.dims or "channel" in image.dims: # image formats need channel last image = to_channel_last(image) save_function = imageio.volsave if is_volume else imageio.imsave @@ -157,7 +167,6 @@ def pad(image, axes: Sequence[str], padding, pad_right=True) -> Tuple[np.ndarray pad_width = [] crop = {} for ax, dlen, pr in zip(axes, image.shape, pad_right): - if ax in "zyx": pad_to = padding_[ax] diff --git a/bioimageio/core/prediction_pipeline/_model_adapters/_pytorch_model_adapter.py b/bioimageio/core/prediction_pipeline/_model_adapters/_pytorch_model_adapter.py index f47aa1d7..b3709f30 100644 --- a/bioimageio/core/prediction_pipeline/_model_adapters/_pytorch_model_adapter.py +++ b/bioimageio/core/prediction_pipeline/_model_adapters/_pytorch_model_adapter.py @@ -55,7 +55,7 @@ def _unload(self) -> None: def get_nn_instance(model_node: nodes.Model, **kwargs): weight_spec = model_node.weights.get("pytorch_state_dict") assert weight_spec is not None - assert isinstance(weight_spec.architecture, nodes.ImportedSource) + assert isinstance(weight_spec.architecture, nodes.ImportedCallable) model_kwargs = weight_spec.kwargs joined_kwargs = {} if model_kwargs is missing else dict(model_kwargs) joined_kwargs.update(kwargs) diff --git a/bioimageio/core/resource_io/nodes.py b/bioimageio/core/resource_io/nodes.py index 47e2035f..8c26c66b 100644 --- a/bioimageio/core/resource_io/nodes.py +++ b/bioimageio/core/resource_io/nodes.py @@ -6,10 +6,12 @@ from marshmallow import missing from marshmallow.utils import _Missing -from bioimageio.spec.model import raw_nodes as model_raw_nodes -from bioimageio.spec.rdf import raw_nodes as rdf_raw_nodes from bioimageio.spec.collection import raw_nodes as collection_raw_nodes +from bioimageio.spec.dataset import raw_nodes as dataset_raw_nodes +from bioimageio.spec.model.v0_4 import raw_nodes as model_raw_nodes +from bioimageio.spec.rdf import raw_nodes as rdf_raw_nodes from bioimageio.spec.shared import raw_nodes +from bioimageio.spec.workflow import raw_nodes as workflow_raw_nodes @dataclass @@ -48,12 +50,12 @@ class CiteEntry(Node, rdf_raw_nodes.CiteEntry): @dataclass -class Author(Node, model_raw_nodes.Author): +class Author(Node, rdf_raw_nodes.Author): pass @dataclass -class Maintainer(Node, model_raw_nodes.Maintainer): +class Maintainer(Node, rdf_raw_nodes.Maintainer): pass @@ -62,10 +64,19 @@ class Badge(Node, rdf_raw_nodes.Badge): pass +@dataclass +class Attachments(Node, rdf_raw_nodes.Attachments): + files: Union[_Missing, List[Path]] = missing + unknown: Union[_Missing, Dict[str, Any]] = missing + + @dataclass class RDF(rdf_raw_nodes.RDF, ResourceDescription): + authors: Union[_Missing, List[Author]] = missing + attachments: Union[_Missing, Attachments] = missing badges: Union[_Missing, List[Badge]] = missing - covers: Union[_Missing, List[Path]] = missing + cite: Union[_Missing, List[CiteEntry]] = missing + maintainers: Union[_Missing, List[Maintainer]] = missing @dataclass @@ -74,17 +85,22 @@ class CollectionEntry(Node, collection_raw_nodes.CollectionEntry): @dataclass -class LinkedDataset(Node, model_raw_nodes.LinkedDataset): +class Collection(collection_raw_nodes.Collection, RDF): + collection: List[CollectionEntry] = missing + + +@dataclass +class Dataset(Node, dataset_raw_nodes.Dataset): pass @dataclass -class ModelParent(Node, model_raw_nodes.ModelParent): +class LinkedDataset(Node, model_raw_nodes.LinkedDataset): pass @dataclass -class Collection(collection_raw_nodes.Collection, RDF): +class ModelParent(Node, model_raw_nodes.ModelParent): pass @@ -106,6 +122,7 @@ class Postprocessing(Node, model_raw_nodes.Postprocessing): @dataclass class InputTensor(Node, model_raw_nodes.InputTensor): axes: Tuple[str, ...] = missing + preprocessing: Union[_Missing, List[Preprocessing]] = missing def __post_init__(self): super().__post_init__() @@ -116,6 +133,7 @@ def __post_init__(self): @dataclass class OutputTensor(Node, model_raw_nodes.OutputTensor): axes: Tuple[str, ...] = missing + postprocessing: Union[_Missing, List[Postprocessing]] = missing def __post_init__(self): super().__post_init__() @@ -124,48 +142,47 @@ def __post_init__(self): @dataclass -class ImportedSource(Node): - factory: Callable +class ImportedCallable(Node): + call: Callable def __call__(self, *args, **kwargs): - return self.factory(*args, **kwargs) + return self.call(*args, **kwargs) @dataclass -class KerasHdf5WeightsEntry(Node, model_raw_nodes.KerasHdf5WeightsEntry): - source: Path = missing +class WeightsEntryBase(model_raw_nodes._WeightsEntryBase): + dependencies: Union[_Missing, Dependencies] = missing @dataclass -class OnnxWeightsEntry(Node, model_raw_nodes.OnnxWeightsEntry): +class KerasHdf5WeightsEntry(WeightsEntryBase, model_raw_nodes.KerasHdf5WeightsEntry): source: Path = missing @dataclass -class PytorchStateDictWeightsEntry(Node, model_raw_nodes.PytorchStateDictWeightsEntry): +class OnnxWeightsEntry(WeightsEntryBase, model_raw_nodes.OnnxWeightsEntry): source: Path = missing - architecture: Union[_Missing, ImportedSource] = missing @dataclass -class TorchscriptWeightsEntry(Node, model_raw_nodes.TorchscriptWeightsEntry): +class PytorchStateDictWeightsEntry(WeightsEntryBase, model_raw_nodes.PytorchStateDictWeightsEntry): source: Path = missing + architecture: Union[_Missing, ImportedCallable] = missing @dataclass -class TensorflowJsWeightsEntry(Node, model_raw_nodes.TensorflowJsWeightsEntry): +class TorchscriptWeightsEntry(WeightsEntryBase, model_raw_nodes.TorchscriptWeightsEntry): source: Path = missing @dataclass -class TensorflowSavedModelBundleWeightsEntry(Node, model_raw_nodes.TensorflowSavedModelBundleWeightsEntry): +class TensorflowJsWeightsEntry(WeightsEntryBase, model_raw_nodes.TensorflowJsWeightsEntry): source: Path = missing @dataclass -class Attachments(Node, rdf_raw_nodes.Attachments): - files: Union[_Missing, List[Path]] = missing - unknown: Union[_Missing, Dict[str, Any]] = missing +class TensorflowSavedModelBundleWeightsEntry(WeightsEntryBase, model_raw_nodes.TensorflowSavedModelBundleWeightsEntry): + source: Path = missing WeightsEntry = Union[ @@ -180,8 +197,43 @@ class Attachments(Node, rdf_raw_nodes.Attachments): @dataclass class Model(model_raw_nodes.Model, RDF): - authors: List[Author] = missing - maintainers: Union[_Missing, List[Maintainer]] = missing + inputs: List[InputTensor] = missing + outputs: List[OutputTensor] = missing + parent: Union[_Missing, ModelParent] = missing + run_mode: Union[_Missing, RunMode] = missing test_inputs: List[Path] = missing test_outputs: List[Path] = missing + training_data: Union[_Missing, Dataset, LinkedDataset] = missing weights: Dict[model_raw_nodes.WeightsFormat, WeightsEntry] = missing + + +@dataclass +class Axis(Node, workflow_raw_nodes.Axis): + pass + + +@dataclass +class BatchAxis(Node, workflow_raw_nodes.Axis): + pass + + +@dataclass +class Input(Node, workflow_raw_nodes.Input): + pass + + +@dataclass +class Option(Node, workflow_raw_nodes.Option): + pass + + +@dataclass +class Output(Node, workflow_raw_nodes.Output): + pass + + +@dataclass +class Workflow(workflow_raw_nodes.Workflow, RDF): + inputs: List[Input] = missing + options: List[Option] = missing + outputs: List[Output] = missing diff --git a/bioimageio/core/resource_io/utils.py b/bioimageio/core/resource_io/utils.py index 575f049b..c8582141 100644 --- a/bioimageio/core/resource_io/utils.py +++ b/bioimageio/core/resource_io/utils.py @@ -6,6 +6,8 @@ import typing from types import ModuleType +from marshmallow import missing + from bioimageio.spec.shared import raw_nodes, resolve_source, source_available from bioimageio.spec.shared.node_transformer import ( GenericRawNode, @@ -54,9 +56,9 @@ def generic_visit(self, node): super().generic_visit(node) -class SourceNodeTransformer(NodeTransformer): +class CallableNodeTransformer(NodeTransformer): """ - Imports all source callables + Import all callables note: Requires previous transformation by UriNodeTransformer """ @@ -70,21 +72,23 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): sys.path.remove(self.path) - def transform_LocalImportableModule(self, node: raw_nodes.LocalImportableModule) -> nodes.ImportedSource: + def transform_LocalCallableFromModule(self, node: raw_nodes.LocalCallableFromModule) -> nodes.ImportedCallable: with self.TemporaryInsertionIntoPythonPath(str(node.root_path)): module = importlib.import_module(node.module_name) - return nodes.ImportedSource(factory=getattr(module, node.callable_name)) + return nodes.ImportedCallable(call=getattr(module, node.callable_name)) @staticmethod - def transform_ResolvedImportableSourceFile(node: raw_nodes.ResolvedImportableSourceFile) -> nodes.ImportedSource: + def transform_ResolvedCallableFromSourceFile( + node: raw_nodes.ResolvedCallableFromSourceFile, + ) -> nodes.ImportedCallable: module_path = resolve_source(node.source_file) module_name = f"module_from_source.{module_path.stem}" importlib_spec = importlib.util.spec_from_file_location(module_name, module_path) assert importlib_spec is not None dep = importlib.util.module_from_spec(importlib_spec) importlib_spec.loader.exec_module(dep) # type: ignore # todo: possible to use "loader.load_module"? - return nodes.ImportedSource(factory=getattr(dep, node.callable_name)) + return nodes.ImportedCallable(call=getattr(dep, node.callable_name)) class RawNodeTypeTransformer(NodeTransformer): @@ -95,7 +99,9 @@ def __init__(self, nodes_module: ModuleType): def generic_transformer(self, node: GenericRawNode) -> GenericResolvedNode: if isinstance(node, raw_nodes.RawNode): resolved_data = { - field.name: self.transform(getattr(node, field.name)) for field in dataclasses.fields(node) + field.name: self.transform(getattr(node, field.name)) + for field in dataclasses.fields(node) + if getattr(node, field.name) is not missing # exclude missing fields to respect for node defaults } resolved_node_type: typing.Type[GenericResolvedNode] = getattr(self.nodes, node.__class__.__name__) return resolved_node_type(**resolved_data) # type: ignore @@ -115,10 +121,15 @@ def all_sources_available( def resolve_raw_node( - raw_rd: GenericRawNode, nodes_module: typing.Any, uri_only_if_in_package: bool = True + raw_rd: GenericRawNode, + nodes_module: typing.Any, + uri_only_if_in_package: bool = True, + root_path: typing.Optional[pathlib.Path] = None, ) -> GenericResolvedNode: """resolve all uris and paths (that are included when packaging)""" - rd = UriNodeTransformer(root_path=raw_rd.root_path, uri_only_if_in_package=uri_only_if_in_package).transform(raw_rd) - rd = SourceNodeTransformer().transform(rd) + rd = UriNodeTransformer( + root_path=root_path or raw_rd.root_path, uri_only_if_in_package=uri_only_if_in_package + ).transform(raw_rd) + rd = CallableNodeTransformer().transform(rd) rd = RawNodeTypeTransformer(nodes_module).transform(rd) return rd diff --git a/tests/build_spec/test_add_weights.py b/tests/build_spec/test_add_weights.py index 2f8300b0..482d167f 100644 --- a/tests/build_spec/test_add_weights.py +++ b/tests/build_spec/test_add_weights.py @@ -1,4 +1,7 @@ import os + +import pytest + from bioimageio.core import export_resource_package, load_raw_resource_description, load_resource_description from bioimageio.core.resource_tests import test_model as _test_model @@ -45,5 +48,6 @@ def test_add_torchscript(unet2d_nuclei_broad_model, tmp_path): _test_add_weights(unet2d_nuclei_broad_model, tmp_path, "pytorch_state_dict", "torchscript") +@pytest.mark.skipif(pytest.skip_onnx, reason="onnx") def test_add_onnx(unet2d_nuclei_broad_model, tmp_path): _test_add_weights(unet2d_nuclei_broad_model, tmp_path, "pytorch_state_dict", "onnx", opset_version=12) diff --git a/tests/resource_io/test_utils.py b/tests/resource_io/test_utils.py index 30889a1d..7174a8d5 100644 --- a/tests/resource_io/test_utils.py +++ b/tests/resource_io/test_utils.py @@ -12,9 +12,9 @@ def test_resolve_import_path(tmpdir): (tmpdir / str(source_file)).write_text("class Foo: pass", encoding="utf8") node = raw_nodes.ImportableSourceFile(source_file=source_file, callable_name="Foo") uri_transformed = utils.UriNodeTransformer(root_path=tmpdir).transform(node) - source_transformed = utils.SourceNodeTransformer().transform(uri_transformed) - assert isinstance(source_transformed, nodes.ImportedSource) - Foo = source_transformed.factory + source_transformed = utils.CallableNodeTransformer().transform(uri_transformed) + assert isinstance(source_transformed, nodes.ImportedCallable) + Foo = source_transformed.call assert Foo.__name__ == "Foo" assert isinstance(Foo, type) diff --git a/tests/test_cli.py b/tests/test_cli.py index c0de99d4..55f205f5 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -98,7 +98,7 @@ def _test_cli_predict_images(model, tmp_path, extra_kwargs=None): expected_outputs.append(out_folder / f"im-{i}.npy") input_pattern = str(in_folder / "*.npy") - cmd = ["bioimageio", "predict-images", model, input_pattern, str(out_folder)] + cmd = ["bioimageio", "predict-images", str(model), input_pattern, str(out_folder)] if extra_kwargs is not None: cmd.extend(extra_kwargs) ret = run_subprocess(cmd) @@ -126,7 +126,7 @@ def test_torch_to_torchscript(unet2d_nuclei_broad_model, tmp_path): assert out_path.exists() -@pytest.mark.skipif(pytest.skip_onnx, reason="requires torch and onnx") +@pytest.mark.skipif(pytest.skip_onnx, reason="requires onnx") def test_torch_to_onnx(unet2d_nuclei_broad_model, tmp_path): out_path = tmp_path.with_suffix(".onnx") ret = run_subprocess(["bioimageio", "convert-torch-weights-to-onnx", str(unet2d_nuclei_broad_model), str(out_path)])