diff --git a/changes/2819.chore.rst b/changes/2819.chore.rst new file mode 100644 index 0000000000..f9a3358309 --- /dev/null +++ b/changes/2819.chore.rst @@ -0,0 +1,4 @@ +Ensure that invocations of ``create_array`` use consistent keyword arguments, with consistent defaults. +Specifically, ``zarr.api.synchronous.create_array`` now takes a ``write_data`` keyword argument; the +``create_array`` method on ``zarr.Group`` takes ``data`` and ``write_data`` keyword arguments. The ``fill_value`` +keyword argument of the various invocations of ``create_array`` has been consistently set to ``None``, where previously it was either ``None`` or ``0``. \ No newline at end of file diff --git a/src/zarr/api/asynchronous.py b/src/zarr/api/asynchronous.py index 3b53095636..6b573fd033 100644 --- a/src/zarr/api/asynchronous.py +++ b/src/zarr/api/asynchronous.py @@ -11,6 +11,7 @@ from zarr.abc.store import Store from zarr.core.array import ( + DEFAULT_FILL_VALUE, Array, AsyncArray, CompressorLike, @@ -860,10 +861,10 @@ async def open_group( async def create( shape: ChunkCoords | int, *, # Note: this is a change from v2 - chunks: ChunkCoords | int | None = None, # TODO: v2 allowed chunks=True + chunks: ChunkCoords | int | bool | None = None, dtype: ZDTypeLike | None = None, compressor: CompressorLike = "auto", - fill_value: Any | None = 0, # TODO: need type + fill_value: Any | None = DEFAULT_FILL_VALUE, order: MemoryOrder | None = None, store: str | StoreLike | None = None, synchronizer: Any | None = None, diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index bad710ed43..0f57495e61 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -20,6 +20,7 @@ from zarr.abc.store import Store, set_or_delete from zarr.core._info import GroupInfo from zarr.core.array import ( + DEFAULT_FILL_VALUE, Array, AsyncArray, CompressorLike, @@ -71,6 +72,7 @@ from zarr.core.buffer import Buffer, BufferPrototype from zarr.core.chunk_key_encodings import ChunkKeyEncodingLike from 
zarr.core.common import MemoryOrder + from zarr.core.dtype import ZDTypeLike logger = logging.getLogger("zarr.group") @@ -999,22 +1001,24 @@ async def create_array( self, name: str, *, - shape: ShapeLike, - dtype: npt.DTypeLike, + shape: ShapeLike | None = None, + dtype: ZDTypeLike | None = None, + data: np.ndarray[Any, np.dtype[Any]] | None = None, chunks: ChunkCoords | Literal["auto"] = "auto", shards: ShardsLike | None = None, filters: FiltersLike = "auto", compressors: CompressorsLike = "auto", compressor: CompressorLike = "auto", serializer: SerializerLike = "auto", - fill_value: Any | None = 0, + fill_value: Any | None = DEFAULT_FILL_VALUE, order: MemoryOrder | None = None, attributes: dict[str, JSON] | None = None, chunk_key_encoding: ChunkKeyEncodingLike | None = None, dimension_names: DimensionNames = None, storage_options: dict[str, Any] | None = None, overwrite: bool = False, - config: ArrayConfig | ArrayConfigLike | None = None, + config: ArrayConfigLike | None = None, + write_data: bool = True, ) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]: """Create an array within this group. @@ -1102,6 +1106,11 @@ async def create_array( Whether to overwrite an array with the same name in the store, if one exists. config : ArrayConfig or ArrayConfigLike, optional Runtime configuration for the array. + write_data : bool + If a pre-existing array-like object was provided to this function via the ``data`` parameter + then ``write_data`` determines whether the values in that array-like object should be + written to the Zarr array created by this function. If ``write_data`` is ``False``, then the + array will be left empty. 
Returns ------- @@ -1116,6 +1125,7 @@ async def create_array( name=name, shape=shape, dtype=dtype, + data=data, chunks=chunks, shards=shards, filters=filters, @@ -1130,6 +1140,7 @@ async def create_array( storage_options=storage_options, overwrite=overwrite, config=config, + write_data=write_data, ) @deprecated("Use AsyncGroup.create_array instead.") @@ -2411,22 +2422,24 @@ def create_array( self, name: str, *, - shape: ShapeLike, - dtype: npt.DTypeLike, + shape: ShapeLike | None = None, + dtype: ZDTypeLike | None = None, + data: np.ndarray[Any, np.dtype[Any]] | None = None, chunks: ChunkCoords | Literal["auto"] = "auto", shards: ShardsLike | None = None, filters: FiltersLike = "auto", compressors: CompressorsLike = "auto", compressor: CompressorLike = "auto", serializer: SerializerLike = "auto", - fill_value: Any | None = 0, - order: MemoryOrder | None = "C", + fill_value: Any | None = DEFAULT_FILL_VALUE, + order: MemoryOrder | None = None, attributes: dict[str, JSON] | None = None, chunk_key_encoding: ChunkKeyEncodingLike | None = None, dimension_names: DimensionNames = None, storage_options: dict[str, Any] | None = None, overwrite: bool = False, - config: ArrayConfig | ArrayConfigLike | None = None, + config: ArrayConfigLike | None = None, + write_data: bool = True, ) -> Array: """Create an array within this group. @@ -2437,10 +2450,13 @@ def create_array( name : str The name of the array relative to the group. If ``path`` is ``None``, the array will be located at the root of the store. - shape : ChunkCoords - Shape of the array. - dtype : npt.DTypeLike - Data type of the array. + shape : ShapeLike, optional + Shape of the array. Can be ``None`` if ``data`` is provided. + dtype : ZDTypeLike, optional + Data type of the array. Can be ``None`` if ``data`` is provided. + data : np.ndarray, optional + Array-like data to use for initializing the array. 
If provided, ``shape`` and ``dtype`` + must be identical to ``data.shape`` and ``data.dtype``, or ``None``. chunks : ChunkCoords, optional Chunk shape of the array. If not specified, default are guessed based on the shape and dtype. @@ -2514,6 +2530,11 @@ def create_array( Whether to overwrite an array with the same name in the store, if one exists. config : ArrayConfig or ArrayConfigLike, optional Runtime configuration for the array. + write_data : bool + If a pre-existing array-like object was provided to this function via the ``data`` parameter + then ``write_data`` determines whether the values in that array-like object should be + written to the Zarr array created by this function. If ``write_data`` is ``False``, then the + array will be left empty. Returns ------- @@ -2528,6 +2549,7 @@ def create_array( name=name, shape=shape, dtype=dtype, + data=data, chunks=chunks, shards=shards, fill_value=fill_value, @@ -2541,6 +2563,7 @@ def create_array( overwrite=overwrite, storage_options=storage_options, config=config, + write_data=write_data, ) ) ) @@ -2813,7 +2836,7 @@ def array( compressors: CompressorsLike = "auto", compressor: CompressorLike = None, serializer: SerializerLike = "auto", - fill_value: Any | None = 0, + fill_value: Any | None = DEFAULT_FILL_VALUE, order: MemoryOrder | None = "C", attributes: dict[str, JSON] | None = None, chunk_key_encoding: ChunkKeyEncodingLike | None = None, diff --git a/tests/test_api.py b/tests/test_api.py index e6cb612a82..da61f97847 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,5 +1,7 @@ from __future__ import annotations +import inspect +import pathlib import re from typing import TYPE_CHECKING @@ -8,6 +10,7 @@ if TYPE_CHECKING: import pathlib + from collections.abc import Callable from zarr.abc.store import Store from zarr.core.common import JSON, MemoryOrder, ZarrFormat @@ -1216,6 +1219,43 @@ def test_open_array_with_mode_r_plus(store: Store, zarr_format: ZarrFormat) -> N
z2[:] = 3 +@pytest.mark.parametrize( + ("a_func", "b_func"), + [ + (zarr.api.asynchronous.create_array, zarr.api.synchronous.create_array), + (zarr.api.asynchronous.save, zarr.api.synchronous.save), + (zarr.api.asynchronous.save_array, zarr.api.synchronous.save_array), + (zarr.api.asynchronous.save_group, zarr.api.synchronous.save_group), + (zarr.api.asynchronous.open_group, zarr.api.synchronous.open_group), + (zarr.api.asynchronous.create, zarr.api.synchronous.create), + ], +) +def test_consistent_signatures( + a_func: Callable[[object], object], b_func: Callable[[object], object] +) -> None: + """ + Ensure that each pair of functions has identical parameters (name, kind, default, annotation) + """ + base_sig = inspect.signature(a_func) + test_sig = inspect.signature(b_func) + wrong: dict[str, list[object]] = { + "missing_from_test": [], + "missing_from_base": [], + "wrong_type": [], + } + for key, value in base_sig.parameters.items(): + if key not in test_sig.parameters: + wrong["missing_from_test"].append((key, value)) + for key, value in test_sig.parameters.items(): + if key not in base_sig.parameters: + wrong["missing_from_base"].append((key, value)) + elif base_sig.parameters[key] != value: + wrong["wrong_type"].append({key: {"test": value, "base": base_sig.parameters[key]}}) + assert wrong["missing_from_base"] == [] + assert wrong["missing_from_test"] == [] + assert wrong["wrong_type"] == [] + + def test_api_exports() -> None: """ Test that the sync API and the async API export the same objects diff --git a/tests/test_array.py b/tests/test_array.py index fe23bc1284..0bca860e84 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -970,6 +970,43 @@ def test_auto_partition_auto_shards( assert auto_shards == expected_shards +def test_chunks_and_shards() -> None: + store = StorePath(MemoryStore()) + shape = (100, 100) + chunks = (5, 5) + shards = (10, 10) + + arr_v3 = zarr.create_array(store=store / "v3", shape=shape, chunks=chunks, dtype="i4") + assert arr_v3.chunks == chunks + assert 
arr_v3.shards is None + + arr_v3_sharding = zarr.create_array( + store=store / "v3_sharding", + shape=shape, + chunks=chunks, + shards=shards, + dtype="i4", + ) + assert arr_v3_sharding.chunks == chunks + assert arr_v3_sharding.shards == shards + + arr_v2 = zarr.create_array( + store=store / "v2", shape=shape, chunks=chunks, zarr_format=2, dtype="i4" + ) + assert arr_v2.chunks == chunks + assert arr_v2.shards is None + + +@pytest.mark.parametrize("store", ["memory"], indirect=True) +@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning") +@pytest.mark.parametrize( + ("dtype", "fill_value_expected"), [(" None: + a = zarr.create_array(store, shape=(5,), chunks=(5,), dtype=dtype) + assert a.fill_value == fill_value_expected + + @pytest.mark.parametrize("store", ["memory"], indirect=True) class TestCreateArray: @staticmethod @@ -1769,6 +1806,25 @@ def test_multiprocessing(store: Store, method: Literal["fork", "spawn", "forkser assert all(np.array_equal(r, data) for r in results) +def test_create_array_method_signature() -> None: + """ + Test that the signature of the ``AsyncGroup.create_array`` function has nearly the same signature + as the ``create_array`` function. ``AsyncGroup.create_array`` should take all of the same keyword + arguments as ``create_array`` except ``store``. + """ + + base_sig = inspect.signature(create_array) + meth_sig = inspect.signature(AsyncGroup.create_array) + # ignore keyword arguments that are either missing or have different semantics when + # create_array is invoked as a group method + ignore_kwargs = {"zarr_format", "store", "name"} + # TODO: make this test stronger. right now, it only checks that all the parameters in the + # function signature are used in the method signature. we can be more strict and check that + # the method signature uses no extra parameters. 
+ base_params = dict(filter(lambda kv: kv[0] not in ignore_kwargs, base_sig.parameters.items())) + assert (set(base_params.items()) - set(meth_sig.parameters.items())) == set() + + async def test_sharding_coordinate_selection() -> None: store = MemoryStore() g = zarr.open_group(store, mode="w") diff --git a/tests/test_group.py b/tests/test_group.py index 60a1fcb9bf..ee2317ade4 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -1531,6 +1531,7 @@ def test_create_nodes_concurrency_limit(store: MemoryStore) -> None: @pytest.mark.parametrize( ("a_func", "b_func"), [ + (zarr.core.group.AsyncGroup.create_array, zarr.core.group.Group.create_array), (zarr.core.group.AsyncGroup.create_hierarchy, zarr.core.group.Group.create_hierarchy), (zarr.core.group.create_hierarchy, zarr.core.sync_group.create_hierarchy), (zarr.core.group.create_nodes, zarr.core.sync_group.create_nodes), @@ -1546,7 +1547,22 @@ def test_consistent_signatures( """ base_sig = inspect.signature(a_func) test_sig = inspect.signature(b_func) - assert test_sig.parameters == base_sig.parameters + wrong: dict[str, list[object]] = { + "missing_from_test": [], + "missing_from_base": [], + "wrong_type": [], + } + for key, value in base_sig.parameters.items(): + if key not in test_sig.parameters: + wrong["missing_from_test"].append((key, value)) + for key, value in test_sig.parameters.items(): + if key not in base_sig.parameters: + wrong["missing_from_base"].append((key, value)) + elif base_sig.parameters[key] != value: + wrong["wrong_type"].append({key: {"test": value, "base": base_sig.parameters[key]}}) + assert wrong["missing_from_base"] == [] + assert wrong["missing_from_test"] == [] + assert wrong["wrong_type"] == [] @pytest.mark.parametrize("store", ["memory"], indirect=True)