From 951026c7dcb0e362682328a545961dc1b140865d Mon Sep 17 00:00:00 2001 From: Fabien Maussion Date: Sat, 15 Oct 2016 11:21:54 +0200 Subject: [PATCH 1/2] fixes https://github.com/pydata/xarray/pull/1027 --- xarray/core/combine.py | 15 ++++---- xarray/core/utils.py | 70 ++++++++++++++++++++++++++++++++++++- xarray/test/test_dataset.py | 15 ++++++++ xarray/test/test_utils.py | 14 ++++++++ 4 files changed, 107 insertions(+), 7 deletions(-) diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 674c19c9191..f47af6666c7 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -145,6 +145,8 @@ def _calc_concat_over(datasets, dim, data_vars, coords): Determine which dataset variables need to be concatenated in the result, and which can simply be taken from the first dataset. """ + from .utils import OrderedSet + def process_subset_opt(opt, subset): if subset == 'coords': subset_long_name = 'coordinates' @@ -161,11 +163,11 @@ def differs(vname): return any(not ds.variables[vname].equals(v) for ds in datasets[1:]) # all nonindexes that are not the same in each dataset - concat_new = set(k for k in getattr(datasets[0], subset) - if k not in concat_over and differs(k)) + concat_new = OrderedSet(k for k in getattr(datasets[0], subset) + if k not in concat_over and differs(k)) elif opt == 'all': - concat_new = (set(getattr(datasets[0], subset)) - - set(datasets[0].dims)) + concat_new = (OrderedSet(getattr(datasets[0], subset)) - + OrderedSet(datasets[0].dims)) elif opt == 'minimal': concat_new = set() else: @@ -182,11 +184,12 @@ def differs(vname): return concat_new concat_over = set() + ivars = list(process_subset_opt(coords, 'coords')) + \ + list(process_subset_opt(data_vars, 'data_vars')) + concat_over = OrderedSet(ivars) for ds in datasets: concat_over.update(k for k, v in ds.variables.items() if dim in v.dims) - concat_over.update(process_subset_opt(data_vars, 'data_vars')) - concat_over.update(process_subset_opt(coords, 'coords')) if dim in datasets[0]: concat_over.add(dim) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 3e40fe32ed0..943ac3f7394 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -6,7 +6,7 @@ import itertools import re import warnings -from collections import Mapping, MutableMapping, Iterable +from collections import Mapping, MutableMapping, Iterable, MutableSet import numpy as np import pandas as pd @@ -260,6 +260,74 @@ def ordered_dict_intersection(first_dict, second_dict, compat=equivalent): return new_dict +class OrderedSet(MutableSet): + # Recipe: https://code.activestate.com/recipes/576694/ + + def __init__(self, iterable=None): + self.end = end = [] + end += [None, end, end] # sentinel node for doubly linked list + self.map = {} # key --> [key, prev, next] + if iterable is not None: + self |= iterable + + def __len__(self): + return len(self.map) + + def __contains__(self, key): + return key in self.map + + def __add__(self, other): + self.update(other) + return self + + def add(self, key): + if key not in self.map: + end = self.end + curr = end[1] + curr[2] = end[1] = self.map[key] = [key, curr, end] + + def update(self, ite): + for k in ite: + self.add(k) + + def discard(self, key): + if key in self.map: + key, prev, next = self.map.pop(key) + prev[2] = next + next[1] = prev + + def __iter__(self): + end = self.end + curr = end[2] + while curr is not end: + yield curr[0] + curr = curr[2] + + def __reversed__(self): + end = self.end + curr = end[1] + while curr is not end: + yield curr[0] + curr = curr[1] + + def pop(self, last=True): + if not self: + raise KeyError('set is empty') + key = self.end[1][0] if last else self.end[2][0] + self.discard(key) + return key + + def __repr__(self): + if not self: + return '%s()' % (self.__class__.__name__,) + return '%s(%r)' % (self.__class__.__name__, list(self)) + + def __eq__(self, other): + if isinstance(other, OrderedSet): + return len(self) == len(other) and list(self) == list(other) + return set(self) == set(other) + + class SingleSlotPickleMixin(object): """Mixin class to add the ability to pickle objects whose state is defined by a single __slots__ attribute. Only necessary under Python 2. diff --git a/xarray/test/test_dataset.py b/xarray/test/test_dataset.py index ab0306915c5..723b0a72917 100644 --- a/xarray/test/test_dataset.py +++ b/xarray/test/test_dataset.py @@ -1858,6 +1858,21 @@ def test_groupby_nan(self): expected = Dataset({'foo': ('bar', [1.5, 3]), 'bar': [1, 2]}) self.assertDatasetIdentical(actual, expected) + def test_groupby_order(self): + # groupby should preserve variables order + ds = Dataset() + for vn in ['a', 'b', 'c']: + ds[vn] = DataArray(np.arange(10), dims=['t']) + all_vars_ref = list(ds.variables.keys()) + data_vars_ref = list(ds.data_vars.keys()) + ds = ds.groupby('t').mean() + all_vars = list(ds.variables.keys()) + data_vars = list(ds.data_vars.keys()) + self.assertEqual(data_vars, data_vars_ref) + # coords are now at the end of the list, so the test below fails + # self.assertEqual(all_vars, all_vars_ref) + + def test_resample_and_first(self): times = pd.date_range('2000-01-01', freq='6H', periods=10) ds = Dataset({'foo': (['time', 'x', 'y'], np.random.randn(10, 5, 3)), diff --git a/xarray/test/test_utils.py b/xarray/test/test_utils.py index 373940c97d4..a35c7572019 100644 --- a/xarray/test/test_utils.py +++ b/xarray/test/test_utils.py @@ -168,3 +168,17 @@ def test_hashable(self): self.assertTrue(utils.hashable(v)) for v in [[5, 6], ['seven', '8'], {9: 'ten'}]: self.assertFalse(utils.hashable(v)) + + +class Test_OrderedSet(TestCase): + + def test_set_op(self): + s1 = utils.OrderedSet(['a', 'b', 'c']) + s2 = utils.OrderedSet(['b']) + s3 = utils.OrderedSet(['d', 'e']) + self.assertEqual(s1 - s3, s1) + self.assertEqual(s1 - s2, set(['a', 'c'])) + self.assertEqual(s1 + s2, s1) + self.assertEqual(s1 + s3, set(['a', 'b', 'c', 'd', 'e'])) + s1.update(s3) + self.assertEqual(s1, set(['a', 'b', 'c', 'd', 'e'])) From a46c205fe21310d3ff0e81550100e79a0ac5040b Mon Sep 17 00:00:00 2001 From: Fabien Maussion Date: Sat, 15 Oct 2016 12:13:36 +0200 Subject: [PATCH 2/2] simplify class --- xarray/core/utils.py | 19 +++---------------- xarray/test/test_utils.py | 6 ++++++ 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 943ac3f7394..d775ecb4656 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -277,8 +277,9 @@ def __contains__(self, key): return key in self.map def __add__(self, other): - self.update(other) - return self + out = OrderedSet(self) + out.update(other) + return out def add(self, key): if key not in self.map: @@ -303,20 +304,6 @@ def __iter__(self): yield curr[0] curr = curr[2] - def __reversed__(self): - end = self.end - curr = end[1] - while curr is not end: - yield curr[0] - curr = curr[1] - - def pop(self, last=True): - if not self: - raise KeyError('set is empty') - key = self.end[1][0] if last else self.end[2][0] - self.discard(key) - return key - def __repr__(self): if not self: return '%s()' % (self.__class__.__name__,) diff --git a/xarray/test/test_utils.py b/xarray/test/test_utils.py index a35c7572019..fdeb40af991 100644 --- a/xarray/test/test_utils.py +++ b/xarray/test/test_utils.py @@ -182,3 +182,9 @@ def test_set_op(self): self.assertEqual(s1 + s3, set(['a', 'b', 'c', 'd', 'e'])) s1.update(s3) self.assertEqual(s1, set(['a', 'b', 'c', 'd', 'e'])) + s1.discard('e') + self.assertEqual(s1, set(['a', 'b', 'c', 'd'])) + + + +