From a9a5e93ca794980c61dfed3d25900f410f348f6b Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 5 Dec 2019 12:16:10 -0600 Subject: [PATCH 1/4] Fix map_blocks HLG layering This fixes an issue with the HighLevelGraph noted in https://github.com/pydata/xarray/pull/3584, and exposed by a recent change in Dask to do more HLG fusion. --- xarray/core/parallel.py | 13 ++++++++++--- xarray/tests/test_dask.py | 7 +++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index fbb5ef94ca2..183eb27f85d 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -7,12 +7,14 @@ except ImportError: pass +import collections import itertools import operator from typing import ( Any, Callable, Dict, + DefaultDict, Hashable, Mapping, Sequence, @@ -222,6 +224,7 @@ def _wrapper(func, obj, to_array, args, kwargs): indexes.update({k: template.indexes[k] for k in new_indexes}) graph: Dict[Any, Any] = {} + new_layers: DefaultDict[str, Dict[Any, Any]] = collections.defaultdict(dict) gname = "{}-{}".format( dask.utils.funcname(func), dask.base.tokenize(dataset, args, kwargs) ) @@ -310,9 +313,13 @@ def _wrapper(func, obj, to_array, args, kwargs): # unchunked dimensions in the input have one chunk in the result key += (0,) - graph[key] = (operator.getitem, from_wrapper, name) + new_layers[gname_l][key] = (operator.getitem, from_wrapper, name) - graph = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset]) + hlg = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset]) + + for gname_l, layer in new_layers.items(): + hlg.dependencies[gname_l] = {gname} + hlg.layers[gname_l] = layer result = Dataset(coords=indexes, attrs=template.attrs) for name, gname_l in var_key_map.items(): @@ -325,7 +332,7 @@ def _wrapper(func, obj, to_array, args, kwargs): var_chunks.append((len(indexes[dim]),)) data = dask.array.Array( - graph, name=gname_l, chunks=var_chunks, dtype=template[name].dtype + hlg, name=gname_l, chunks=var_chunks, dtype=template[name].dtype ) result[name] = (dims, data, template[name].attrs) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index f3b10e3370c..c229326df42 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1189,6 +1189,13 @@ def func(obj): assert_identical(expected.compute(), actual.compute()) +def test_map_blocks_hlg_layers(): + ds = xr.Dataset({"x": (("y",), dask.array.ones(10, chunks=(5,)))}) + mapped = ds.map_blocks(lambda x: x) + + xr.testing.assert_equal(mapped, ds) # does not work + + def test_make_meta(map_ds): from ..core.parallel import make_meta From 1a304a723cbc123bfc0cde5c301fa5480b04a2dd Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 5 Dec 2019 14:38:49 -0600 Subject: [PATCH 2/4] update --- doc/whats-new.rst | 2 ++ xarray/core/parallel.py | 7 +++++++ xarray/tests/test_dask.py | 8 ++++++-- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index d4d8ab8f3e5..8b246608fe0 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -36,6 +36,8 @@ Bug fixes ~~~~~~~~~ - Fix plotting with transposed 2D non-dimensional coordinates. (:issue:`3138`, :pull:`3441`) By `Deepak Cherian `_. +- Fix issue with Dask-backed datasets raising a ``KeyError`` on some computations involving ``map_blocks`` (:pull:`3598`) + By `Tom Augspurger `_. Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 183eb27f85d..8e593d18fe6 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -223,6 +223,10 @@ def _wrapper(func, obj, to_array, args, kwargs): indexes = {dim: dataset.indexes[dim] for dim in preserved_indexes} indexes.update({k: template.indexes[k] for k in new_indexes}) + # We're building a new HighLevelGraph hlg. We'll have one new layer + # for each variable in the dataset, which is the result of the + # func applied to the values. + graph: Dict[Any, Any] = {} new_layers: DefaultDict[str, Dict[Any, Any]] = collections.defaultdict(dict) gname = "{}-{}".format( @@ -318,6 +322,9 @@ def _wrapper(func, obj, to_array, args, kwargs): hlg = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset]) for gname_l, layer in new_layers.items(): + # Ensure we have a valid HighLevelGraph. + # This adds in the getitems for each variable in the dataset. + # This just depends on the layer we created earlier ("graph") hlg.dependencies[gname_l] = {gname} hlg.layers[gname_l] = layer diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index c229326df42..9c6d4d65a95 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1190,10 +1190,14 @@ def func(obj): def test_map_blocks_hlg_layers(): - ds = xr.Dataset({"x": (("y",), dask.array.ones(10, chunks=(5,)))}) + # regression test for #3599 + ds = xr.Dataset({ + "x": (("a",), dask.array.ones(10, chunks=(5,))), + "z": (("b",), dask.array.ones(10, chunks=(5,))), + }) mapped = ds.map_blocks(lambda x: x) - xr.testing.assert_equal(mapped, ds) # does not work + xr.testing.assert_equal(mapped, ds) def test_make_meta(map_ds): From 0777591157c1578fb0fb5bd096859c8bbac4089b Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 5 Dec 2019 14:39:03 -0600 Subject: [PATCH 3/4] black --- xarray/tests/test_dask.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 9c6d4d65a95..6122e987154 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -1191,10 +1191,12 @@ def func(obj): def test_map_blocks_hlg_layers(): # regression test for #3599 - ds = xr.Dataset({ - "x": (("a",), dask.array.ones(10, chunks=(5,))), - "z": (("b",), dask.array.ones(10, chunks=(5,))), - }) + ds = xr.Dataset( + { + "x": (("a",), dask.array.ones(10, chunks=(5,))), + "z": (("b",), dask.array.ones(10, chunks=(5,))), + } + ) mapped = ds.map_blocks(lambda x: x) xr.testing.assert_equal(mapped, ds) From 0ea4ff84efecee4e788a7cb188dc461c3aba9f91 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 5 Dec 2019 14:41:25 -0600 Subject: [PATCH 4/4] update --- xarray/core/parallel.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 8e593d18fe6..dd6c67338d8 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -317,14 +317,18 @@ def _wrapper(func, obj, to_array, args, kwargs): # unchunked dimensions in the input have one chunk in the result key += (0,) + # We're adding multiple new layers to the graph: + # The first new layer is the result of the computation on + # the array. + # Then we add one layer per variable, which extracts the + # result for that variable, and depends on just the first new + # layer. new_layers[gname_l][key] = (operator.getitem, from_wrapper, name) hlg = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset]) for gname_l, layer in new_layers.items(): - # Ensure we have a valid HighLevelGraph. # This adds in the getitems for each variable in the dataset. - # This just depends on the layer we created earlier ("graph") hlg.dependencies[gname_l] = {gname} hlg.layers[gname_l] = layer