Skip to content

Commit ef8f193

Browse files
committed
Add automatic chunking to open_rasterio
This uses the automatic chunking in dask 0.18+ to chunk rasterio datasets so that the chunks align with the file's internal block layout. Currently this doesn't implement tests due to the difficulty of creating tiled TIFF images.
1 parent 9491318 commit ef8f193

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

xarray/backends/rasterio_.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
176176
Chunk sizes along each dimension, e.g., ``5``, ``(5, 5)`` or
177177
``{'x': 5, 'y': 5}``. If chunks is provided, it is used to load the new
178178
DataArray into a dask array.
179+
Chunks can also be set to ``True`` or ``"auto"`` to choose sensible
180+
chunk sizes according to ``dask.config.get("array.chunk-size")``.
179181
cache : bool, optional
180182
If True, cache data loaded from the underlying datastore in memory as
181183
NumPy arrays when accessed to avoid reading from the underlying data-
@@ -283,6 +285,30 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None,
283285

284286
# this lets you write arrays loaded with rasterio
285287
data = indexing.CopyOnWriteArray(data)
288+
if chunks in (True, 'auto'):
289+
if not attrs.get('is_tiled', False):
290+
msg = "Data store is not tiled. Automatic chunking is not sensible"
291+
raise ValueError(msg)
292+
293+
import dask.array
294+
if dask.__version__ < '0.18.0':
295+
msg = ("Automatic chunking requires dask.__version__ >= 0.18.0 . "
296+
"You currently have version %s" % dask.__version__)
297+
raise NotImplementedError(msg)
298+
299+
img = riods._ds
300+
block_shapes = set(img.block_shapes)
301+
block_shape = (1,) + list(block_shapes)[0]
302+
previous_chunks = tuple((c,) for c in block_shape)
303+
shape = (img.count, img.height, img.width)
304+
dtype = img.dtypes[0]
305+
chunks = dask.array.core.normalize_chunks(
306+
'auto',
307+
shape=shape,
308+
previous_chunks=previous_chunks,
309+
dtype=dtype
310+
)
311+
286312
if cache and (chunks is None):
287313
data = indexing.MemoryCachedArray(data)
288314

xarray/tests/test_backends.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1340,7 +1340,6 @@ def test_write_uneven_dask_chunks(self):
13401340
print(k)
13411341
assert v.chunks == actual[k].chunks
13421342

1343-
13441343
def test_chunk_encoding(self):
13451344
# These datasets have no dask chunks. All chunking specified in
13461345
# encoding
@@ -3009,6 +3008,15 @@ def test_chunks(self):
30093008
ex = expected.sel(band=1).mean(dim='x')
30103009
assert_allclose(ac, ex)
30113010

3011+
@requires_dask
3012+
def test_chunks_auto(self):
3013+
import dask
3014+
with dask.config.set({'array.chunk-size': '1kiB'}):
3015+
with create_tmp_geotiff(1024, 1024, 3) as (tmp_file, expected):
3016+
with xr.open_rasterio(tmp_file, chunks=True) as actual:
3017+
assert actual.chunks
3018+
# TODO: enhance create_tmp_geotiff to support tiled images
3019+
30123020
def test_pickle_rasterio(self):
30133021
# regression test for https://github.com/pydata/xarray/issues/2121
30143022
with create_tmp_geotiff() as (tmp_file, expected):

0 commit comments

Comments
 (0)