diff --git a/ci/requirements/py37.yml b/ci/requirements/py37.yml index dd3267abe4c..f697b62dc9b 100644 --- a/ci/requirements/py37.yml +++ b/ci/requirements/py37.yml @@ -17,6 +17,7 @@ dependencies: - h5netcdf - h5py - hdf5 + - hvplot - hypothesis - iris - isort diff --git a/setup.cfg b/setup.cfg index e336f46e68c..5fb749259c9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -90,6 +90,10 @@ xarray = static/css/* static/html/* +[options.entry_points] +xarray_plotting_backends = + matplotlib = xarray:plot + [tool:pytest] python_files = test_*.py testpaths = xarray/tests properties @@ -197,4 +201,4 @@ ignore_errors = True test = pytest [pytest-watch] -nobeep = True \ No newline at end of file +nobeep = True diff --git a/xarray/core/options.py b/xarray/core/options.py index 72f9ad8e1fa..81f0bf6f513 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -9,6 +9,7 @@ CMAP_DIVERGENT = "cmap_divergent" KEEP_ATTRS = "keep_attrs" DISPLAY_STYLE = "display_style" +PLOTTING_BACKEND = "plotting_backend" OPTIONS = { @@ -21,6 +22,7 @@ CMAP_DIVERGENT: "RdBu_r", KEEP_ATTRS: "default", DISPLAY_STYLE: "text", + PLOTTING_BACKEND: "matplotlib", } _JOIN_OPTIONS = frozenset(["inner", "outer", "left", "right", "exact"]) @@ -39,6 +41,7 @@ def _positive_integer(value): WARN_FOR_UNCLOSED_FILES: lambda value: isinstance(value, bool), KEEP_ATTRS: lambda choice: choice in [True, False, "default"], DISPLAY_STYLE: _DISPLAY_OPTIONS.__contains__, + PLOTTING_BACKEND: lambda value: isinstance(value, str), } @@ -56,9 +59,16 @@ def _warn_on_setting_enable_cftimeindex(enable_cftimeindex): ) +def _set_plotting_backend(backend): + from ..plot.utils import _get_plot_backend + + return _get_plot_backend(backend) + + _SETTERS = { FILE_CACHE_MAXSIZE: _set_file_cache_maxsize, ENABLE_CFTIMEINDEX: _warn_on_setting_enable_cftimeindex, + PLOTTING_BACKEND: _set_plotting_backend, } @@ -104,6 +114,10 @@ class set_options: Default: ``'default'``. - ``display_style``: display style to use in jupyter for xarray objects. Default: ``'text'``. Other options are ``'html'``. + - ``plotting_backend``: The name of plotting backend to use. Backends can be implemented + as third-party libraries implementing the xarray plotting API. They can use other + plotting libraries like Bokeh, Holoviews, Hvplot, Altair, etc. + Default: ``'matplotlib'`` You can use ``set_options`` either as a context manager: diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 6eec7c6b433..10c4e050d06 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -575,7 +575,6 @@ def _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params): cbar_kwargs.setdefault("ax", ax) else: cbar_kwargs.setdefault("cax", cbar_ax) - cbar = plt.colorbar(primitive, **cbar_kwargs) return cbar @@ -772,3 +771,87 @@ def _process_cmap_cbar_kwargs( cmap_params = _determine_cmap_params(**cmap_kwargs) return cmap_params, cbar_kwargs + + +_backends = {} + + +def _find_backend(backend): + """ + Find an xarray plotting backend + + Parameters + ---------- + backend : str + The identifier for the backend. Either an entrypoint item registered + with pkg_resources, or a module name. + + Notes + ----- + Modifies _backends with imported backends as a side effect. + + Returns + ------- + types.ModuleType + The imported backend. + """ + + import pkg_resources # Delay import for performance. + import importlib + + for entry_point in pkg_resources.iter_entry_points("xarray_plotting_backends"): + if entry_point.name == "matplotlib": + # matplotlib is an optional dependency. When + # missing, this would raise. + continue + _backends[entry_point.name] = entry_point.load() + + try: + return _backends[backend] + except KeyError: + # Fall back to unregisted, module name approach. + try: + module = importlib.import_module(backend) + except ImportError: + # We re-raise later on. + pass + + else: + if hasattr(module, "plot"): + # Validate that the interface is implemented when the option is set, + # rather than at plot time + _backends[backend] = module + return module + msg = ( + "Could not find plotting backend '{name}'. Ensure that you've installed the " + "package providing the '{name}' entrypoint, or that the package has a" + "top-level `.plot` method." + ) + + raise ValueError(msg.format(name=backend)) + + +def _get_plot_backend(backend=None): + """ + Return the plotting backend to use + """ + + backend = backend or OPTIONS["plotting_backend"] + + if backend == "matplotlib": + try: + import xarray.plot as module + except ImportError: + raise ImportError( + "matplotlib is required for plotting when the " + 'default backend "matplotlib" is selected.' + ) from None + + _backends["matplotlib"] = module + + if backend in _backends: + return _backends[backend] + + module = _find_backend(backend) + _backends[backend] = module + return module diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 6592360cdf2..ff376e818bc 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -85,6 +85,8 @@ def LooseVersion(vstring): has_seaborn = False requires_seaborn = pytest.mark.skipif(not has_seaborn, reason="requires seaborn") +has_hvplot, requires_hvplot = _importorskip("hvplot") + # change some global options for tests set_options(warn_for_unclosed_files=True) diff --git a/xarray/tests/test_options.py b/xarray/tests/test_options.py index f155acbf494..ac859617568 100644 --- a/xarray/tests/test_options.py +++ b/xarray/tests/test_options.py @@ -77,6 +77,18 @@ def test_display_style(): assert OPTIONS["display_style"] == original +def test_plotting_backend(): + original = "matplotlib" + assert OPTIONS["plotting_backend"] == original + with pytest.raises(ValueError): + xarray.set_options(plotting_backend=5) + + with xarray.set_options(plotting_backend="holoviews"): + assert OPTIONS["plotting_backend"] == "holoviews" + + assert OPTIONS["plotting_backend"] == original + + def create_test_dataset_attrs(seed=0): ds = create_test_data(seed) ds.attrs = {"attr1": 5, "attr2": "history", "attr3": {"nested": "more_info"}} diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 71cb119f0d6..07dc376b415 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -24,6 +24,7 @@ has_nc_time_axis, raises_regex, requires_cftime, + requires_hvplot, requires_matplotlib, requires_nc_time_axis, requires_seaborn, @@ -2228,3 +2229,13 @@ def test_plot_transposes_properly(plotfunc): # pcolormesh returns 1D array but imshow returns a 2D array so it is necessary # to ravel() on the LHS assert np.all(hdl.get_array().ravel() == da.to_masked_array().ravel()) + + +@requires_matplotlib +@requires_hvplot +@pytest.mark.parametrize("plotting_backend", ["matplotlib", "hvplot.plotting"]) +def test_plotting_backend(plotting_backend): + air = xr.tutorial.open_dataset("air_temperature").load().air + + with xr.set_options(plotting_backend=plotting_backend): + air.isel(time=500).plot(add_colorbar=False)