diff --git a/asv_bench/benchmarks/import_xarray.py b/asv_bench/benchmarks/import_xarray.py new file mode 100644 index 00000000000..94652e3b82a --- /dev/null +++ b/asv_bench/benchmarks/import_xarray.py @@ -0,0 +1,9 @@ +class ImportXarray: + def setup(self, *args, **kwargs): + def import_xr(): + import xarray # noqa: F401 + + self._import_xr = import_xr + + def time_import_xarray(self): + self._import_xr() diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index c1aedd570bc..7288a368e47 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -12,6 +12,7 @@ _process_cmap_cbar_kwargs, get_axis, label_from_attrs, + plt, ) # copied from seaborn @@ -134,8 +135,7 @@ def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, size_mapping=None) # copied from seaborn def _parse_size(data, norm): - - import matplotlib as mpl + mpl = plt.matplotlib if data is None: return None @@ -544,8 +544,6 @@ def quiver(ds, x, y, ax, u, v, **kwargs): Wraps :py:func:`matplotlib:matplotlib.pyplot.quiver`. """ - import matplotlib as mpl - if x is None or y is None or u is None or v is None: raise ValueError("Must specify x, y, u, v for quiver plots.") @@ -560,7 +558,7 @@ def quiver(ds, x, y, ax, u, v, **kwargs): # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params if not cmap_params["norm"]: - cmap_params["norm"] = mpl.colors.Normalize( + cmap_params["norm"] = plt.Normalize( cmap_params.pop("vmin"), cmap_params.pop("vmax") ) @@ -576,8 +574,6 @@ def streamplot(ds, x, y, ax, u, v, **kwargs): Wraps :py:func:`matplotlib:matplotlib.pyplot.streamplot`. """ - import matplotlib as mpl - if x is None or y is None or u is None or v is None: raise ValueError("Must specify x, y, u, v for streamplot plots.") @@ -613,7 +609,7 @@ def streamplot(ds, x, y, ax, u, v, **kwargs): # TODO: Fix this by always returning a norm with vmin, vmax in cmap_params if not cmap_params["norm"]: - cmap_params["norm"] = mpl.colors.Normalize( + cmap_params["norm"] = plt.Normalize( cmap_params.pop("vmin"), cmap_params.pop("vmax") ) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 28dd82e76f5..b384dea0571 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -9,8 +9,8 @@ _get_nice_quiver_magnitude, _infer_xy_labels, _process_cmap_cbar_kwargs, - import_matplotlib_pyplot, label_from_attrs, + plt, ) # Overrides axes.labelsize, xtick.major.size, ytick.major.size @@ -116,8 +116,6 @@ def __init__( """ - plt = import_matplotlib_pyplot() - # Handle corner case of nonunique coordinates rep_col = col is not None and not data[col].to_index().is_unique rep_row = row is not None and not data[row].to_index().is_unique @@ -519,10 +517,8 @@ def set_titles(self, template="{coord} = {value}", maxchar=30, size=None, **kwar self: FacetGrid object """ - import matplotlib as mpl - if size is None: - size = mpl.rcParams["axes.labelsize"] + size = plt.rcParams["axes.labelsize"] nicetitle = functools.partial(_nicetitle, maxchar=maxchar, template=template) @@ -619,8 +615,6 @@ def map(self, func, *args, **kwargs): self : FacetGrid object """ - plt = import_matplotlib_pyplot() - for ax, namedict in zip(self.axes.flat, self.name_dicts.flat): if namedict is not None: data = self.data.loc[namedict] diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index dffdde25db4..1e1e59e2f71 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -29,9 +29,9 @@ _resolve_intervals_2dplot, _update_axes, get_axis, - import_matplotlib_pyplot, label_from_attrs, legend_elements, + plt, ) # copied from seaborn @@ -83,8 +83,6 @@ def _parse_size(data, norm, width): If the data is categorical, normalize it to numbers. """ - plt = import_matplotlib_pyplot() - if data is None: return None @@ -682,8 +680,6 @@ def scatter( **kwargs : optional Additional keyword arguments to matplotlib """ - plt = import_matplotlib_pyplot() - # Handle facetgrids first if row or col: allargs = locals().copy() @@ -1111,8 +1107,6 @@ def newplotfunc( allargs["plotfunc"] = globals()[plotfunc.__name__] return _easy_facetgrid(darray, kind="dataarray", **allargs) - plt = import_matplotlib_pyplot() - if ( plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False) diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 594f2e5360e..6fbbe9d4bca 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -47,6 +47,12 @@ def import_matplotlib_pyplot(): return plt +try: + plt = import_matplotlib_pyplot() +except ImportError: + plt = None + + def _determine_extend(calc_data, vmin, vmax): extend_min = calc_data.min() < vmin extend_max = calc_data.max() > vmax @@ -64,7 +70,7 @@ def _build_discrete_cmap(cmap, levels, extend, filled): """ Build a discrete colormap and normalization of the data. """ - import matplotlib as mpl + mpl = plt.matplotlib if len(levels) == 1: levels = [levels[0], levels[0]] @@ -115,8 +121,7 @@ def _build_discrete_cmap(cmap, levels, extend, filled): def _color_palette(cmap, n_colors): - import matplotlib.pyplot as plt - from matplotlib.colors import ListedColormap + ListedColormap = plt.matplotlib.colors.ListedColormap colors_i = np.linspace(0, 1.0, n_colors) if isinstance(cmap, (list, tuple)): @@ -177,7 +182,7 @@ def _determine_cmap_params( cmap_params : dict Use depends on the type of the plotting function """ - import matplotlib as mpl + mpl = plt.matplotlib if isinstance(levels, Iterable): levels = sorted(levels) @@ -285,13 +290,13 @@ def _determine_cmap_params( levels = np.asarray([(vmin + vmax) / 2]) else: # N in MaxNLocator refers to bins, not ticks - ticker = mpl.ticker.MaxNLocator(levels - 1) + ticker = plt.MaxNLocator(levels - 1) levels = ticker.tick_values(vmin, vmax) vmin, vmax = levels[0], levels[-1] # GH3734 if vmin == vmax: - vmin, vmax = mpl.ticker.LinearLocator(2).tick_values(vmin, vmax) + vmin, vmax = plt.LinearLocator(2).tick_values(vmin, vmax) if extend is None: extend = _determine_extend(calc_data, vmin, vmax) @@ -421,10 +426,7 @@ def _assert_valid_xy(darray, xy, name): def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs): - try: - import matplotlib as mpl - import matplotlib.pyplot as plt - except ImportError: + if plt is None: raise ImportError("matplotlib is required for plot.utils.get_axis") if figsize is not None: @@ -437,7 +439,7 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs): if ax is not None: raise ValueError("cannot provide both `size` and `ax` arguments") if aspect is None: - width, height = mpl.rcParams["figure.figsize"] + width, height = plt.rcParams["figure.figsize"] aspect = width / height figsize = (size * aspect, size) _, ax = plt.subplots(figsize=figsize) @@ -454,9 +456,6 @@ def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs): def _maybe_gca(**kwargs): - - import matplotlib.pyplot as plt - # can call gcf unconditionally: either it exists or would be created by plt.axes f = plt.gcf() @@ -912,9 +911,7 @@ def _process_cmap_cbar_kwargs( def _get_nice_quiver_magnitude(u, v): - import matplotlib as mpl - - ticker = mpl.ticker.MaxNLocator(3) + ticker = plt.MaxNLocator(3) mean = np.mean(np.hypot(u.to_numpy(), v.to_numpy())) magnitude = ticker.tick_values(0, mean)[-2] return magnitude @@ -989,7 +986,7 @@ def legend_elements( """ import warnings - import matplotlib as mpl + mpl = plt.matplotlib mlines = mpl.lines @@ -1126,7 +1123,6 @@ def _legend_add_subtitle(handles, labels, text, func): def _adjust_legend_subtitles(legend): """Make invisible-handle "subtitles" entries look more like titles.""" - plt = import_matplotlib_pyplot() # Legend title not in rcParams until 3.0 font_size = plt.rcParams.get("legend.title_fontsize", None)