diff --git a/xarray/backends/plugins.py b/xarray/backends/plugins.py index b8cd2bf6378..88c29306d18 100644 --- a/xarray/backends/plugins.py +++ b/xarray/backends/plugins.py @@ -8,19 +8,19 @@ from .common import BACKEND_ENTRYPOINTS +STANDARD_BACKENDS_ORDER = ["netcdf4", "h5netcdf", "scipy"] -def remove_duplicates(backend_entrypoints): + +def remove_duplicates(pkg_entrypoints): # sort and group entrypoints by name - backend_entrypoints = sorted(backend_entrypoints, key=lambda ep: ep.name) - backend_entrypoints_grouped = itertools.groupby( - backend_entrypoints, key=lambda ep: ep.name - ) + pkg_entrypoints = sorted(pkg_entrypoints, key=lambda ep: ep.name) + pkg_entrypoints_grouped = itertools.groupby(pkg_entrypoints, key=lambda ep: ep.name) # check if there are multiple entrypoints for the same name - unique_backend_entrypoints = [] - for name, matches in backend_entrypoints_grouped: + unique_pkg_entrypoints = [] + for name, matches in pkg_entrypoints_grouped: matches = list(matches) - unique_backend_entrypoints.append(matches[0]) + unique_pkg_entrypoints.append(matches[0]) matches_len = len(matches) if matches_len > 1: selected_module_name = matches[0].module_name @@ -30,7 +30,7 @@ def remove_duplicates(backend_entrypoints): f"\n {all_module_names}.\n It will be used: {selected_module_name}.", RuntimeWarning, ) - return unique_backend_entrypoints + return unique_pkg_entrypoints def detect_parameters(open_dataset): @@ -51,13 +51,13 @@ def detect_parameters(open_dataset): return tuple(parameters_list) -def create_engines_dict(backend_entrypoints): - engines = {} - for backend_ep in backend_entrypoints: - name = backend_ep.name - backend = backend_ep.load() - engines[name] = backend - return engines +def backends_dict_from_pkg(pkg_entrypoints): + backend_entrypoints = {} + for pkg_ep in pkg_entrypoints: + name = pkg_ep.name + backend = pkg_ep.load() + backend_entrypoints[name] = backend + return backend_entrypoints def set_missing_parameters(backend_entrypoints): @@ -67,11 +67,23 @@ def set_missing_parameters(backend_entrypoints): backend.open_dataset_parameters = detect_parameters(open_dataset) -def build_engines(entrypoints): +def sort_backends(backend_entrypoints): + ordered_backends_entrypoints = {} + for be_name in STANDARD_BACKENDS_ORDER: + if backend_entrypoints.get(be_name, None) is not None: + ordered_backends_entrypoints[be_name] = backend_entrypoints.pop(be_name) + ordered_backends_entrypoints.update( + {name: backend_entrypoints[name] for name in sorted(backend_entrypoints)} + ) + return ordered_backends_entrypoints + + +def build_engines(pkg_entrypoints): backend_entrypoints = BACKEND_ENTRYPOINTS.copy() - pkg_entrypoints = remove_duplicates(entrypoints) - external_backend_entrypoints = create_engines_dict(pkg_entrypoints) + pkg_entrypoints = remove_duplicates(pkg_entrypoints) + external_backend_entrypoints = backends_dict_from_pkg(pkg_entrypoints) backend_entrypoints.update(external_backend_entrypoints) + backend_entrypoints = sort_backends(backend_entrypoints) set_missing_parameters(backend_entrypoints) engines = {} for name, backend in backend_entrypoints.items(): @@ -81,8 +93,8 @@ def build_engines(entrypoints): @functools.lru_cache(maxsize=1) def list_engines(): - entrypoints = pkg_resources.iter_entry_points("xarray.backends") - return build_engines(entrypoints) + pkg_entrypoints = pkg_resources.iter_entry_points("xarray.backends") + return build_engines(pkg_entrypoints) def guess_engine(store_spec): diff --git a/xarray/tests/test_plugins.py b/xarray/tests/test_plugins.py index 64a1c563dba..0cda2901cee 100644 --- a/xarray/tests/test_plugins.py +++ b/xarray/tests/test_plugins.py @@ -58,13 +58,13 @@ def test_remove_duplicates_warnings(dummy_duplicated_entrypoints): @mock.patch("pkg_resources.EntryPoint.load", mock.MagicMock(return_value=None)) -def test_create_engines_dict(): +def test_backends_dict_from_pkg(): specs = [ "engine1 = xarray.tests.test_plugins:backend_1", "engine2 = xarray.tests.test_plugins:backend_2", ] entrypoints = [pkg_resources.EntryPoint.parse(spec) for spec in specs] - engines = plugins.create_engines_dict(entrypoints) + engines = plugins.backends_dict_from_pkg(entrypoints) assert len(engines) == 2 assert engines.keys() == set(("engine1", "engine2")) @@ -111,8 +111,38 @@ def test_build_engines(): "cfgrib = xarray.tests.test_plugins:backend_1" ) backend_entrypoints = plugins.build_engines([dummy_pkg_entrypoint]) + assert isinstance(backend_entrypoints["cfgrib"], DummyBackendEntrypoint1) assert backend_entrypoints["cfgrib"].open_dataset_parameters == ( "filename_or_obj", "decoder", ) + + +@mock.patch( + "pkg_resources.EntryPoint.load", + mock.MagicMock(return_value=DummyBackendEntrypoint1), +) +def test_build_engines_sorted(): + dummy_pkg_entrypoints = [ + pkg_resources.EntryPoint.parse( + "dummy2 = xarray.tests.test_plugins:backend_1", + ), + pkg_resources.EntryPoint.parse( + "dummy1 = xarray.tests.test_plugins:backend_1", + ), + ] + backend_entrypoints = plugins.build_engines(dummy_pkg_entrypoints) + backend_entrypoints = list(backend_entrypoints) + + indices = [] + for be in plugins.STANDARD_BACKENDS_ORDER: + try: + index = backend_entrypoints.index(be) + backend_entrypoints.pop(index) + indices.append(index) + except ValueError: + pass + + assert set(indices) < {0, -1} + assert list(backend_entrypoints) == sorted(backend_entrypoints)