diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 3de610b3046..ab934ee9061 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -21,6 +21,10 @@ v0.12.1 (unreleased) Enhancements ~~~~~~~~~~~~ +- Allow ``expand_dims`` method to support inserting/broadcasting dimensions + with size > 1. (:issue:`2710`) + By `Martin Pletcher `_. + Bug fixes ~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index e7e12ae3da4..75f3298104f 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1,4 +1,5 @@ import functools +import sys import warnings from collections import OrderedDict @@ -1138,7 +1139,7 @@ def swap_dims(self, dims_dict): ds = self._to_temp_dataset().swap_dims(dims_dict) return self._from_temp_dataset(ds) - def expand_dims(self, dim, axis=None): + def expand_dims(self, dim=None, axis=None, **dim_kwargs): """Return a new object with an additional axis (or axes) inserted at the corresponding position in the array shape. @@ -1147,21 +1148,53 @@ def expand_dims(self, dim, axis=None): Parameters ---------- - dim : str or sequence of str. + dim : str, sequence of str, dict, or None Dimensions to include on the new variable. - dimensions are inserted with length 1. + If provided as str or sequence of str, then dimensions are inserted + with length 1. If provided as a dict, then the keys are the new + dimensions and the values are either integers (giving the length of + the new dimensions) or sequence/ndarray (giving the coordinates of + the new dimensions). **WARNING** for python 3.5, if ``dim`` is + dict-like, then it must be an ``OrderedDict``. This is to ensure + that the order in which the dims are given is maintained. axis : integer, list (or tuple) of integers, or None Axis position(s) where new axis is to be inserted (position(s) on the result array). If a list (or tuple) of integers is passed, multiple axes are inserted. In this case, dim arguments should be same length list. If axis=None is passed, all the axes will be inserted to the start of the result array. + **dim_kwargs : int or sequence/ndarray + The keywords are arbitrary dimensions being inserted and the values + are either the lengths of the new dims (if int is given), or their + coordinates. Note, this is an alternative to passing a dict to the + dim kwarg and will only be used if dim is None. **WARNING** for + python 3.5 ``dim_kwargs`` is not available. Returns ------- expanded : same type as caller This object, but with an additional dimension(s). """ + if isinstance(dim, int): + raise TypeError('dim should be str or sequence of strs or dict') + elif isinstance(dim, str): + dim = OrderedDict(((dim, 1),)) + elif isinstance(dim, (list, tuple)): + if len(dim) != len(set(dim)): + raise ValueError('dims should not contain duplicate values.') + dim = OrderedDict(((d, 1) for d in dim)) + + # TODO: get rid of the below code block when python 3.5 is no longer + # supported. + python36_plus = sys.version_info[0] == 3 and sys.version_info[1] > 5 + not_ordereddict = dim is not None and not isinstance(dim, OrderedDict) + if not python36_plus and not_ordereddict: + raise TypeError("dim must be an OrderedDict for python <3.6") + elif not python36_plus and dim_kwargs: + raise ValueError("dim_kwargs isn't available for python <3.6") + dim_kwargs = OrderedDict(dim_kwargs) + + dim = either_dict_or_kwargs(dim, dim_kwargs, 'expand_dims') ds = self._to_temp_dataset().expand_dims(dim, axis) return self._from_temp_dataset(ds) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 12c5d139fdc..9dbcd8a8f70 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -2329,7 +2329,7 @@ def swap_dims(self, dims_dict, inplace=None): return self._replace_with_new_dims(variables, coord_names, indexes=indexes, inplace=inplace) - def expand_dims(self, dim, axis=None): + def expand_dims(self, dim=None, axis=None, **dim_kwargs): """Return a new object with an additional axis (or axes) inserted at the corresponding position in the array shape. @@ -2338,15 +2338,27 @@ def expand_dims(self, dim, axis=None): Parameters ---------- - dim : str or sequence of str. + dim : str, sequence of str, dict, or None Dimensions to include on the new variable. - dimensions are inserted with length 1. + If provided as str or sequence of str, then dimensions are inserted + with length 1. If provided as a dict, then the keys are the new + dimensions and the values are either integers (giving the length of + the new dimensions) or sequence/ndarray (giving the coordinates of + the new dimensions). **WARNING** for python 3.5, if ``dim`` is + dict-like, then it must be an ``OrderedDict``. This is to ensure + that the order in which the dims are given is maintained. axis : integer, list (or tuple) of integers, or None Axis position(s) where new axis is to be inserted (position(s) on the result array). If a list (or tuple) of integers is passed, multiple axes are inserted. In this case, dim arguments should be - the same length list. If axis=None is passed, all the axes will - be inserted to the start of the result array. + same length list. If axis=None is passed, all the axes will be + inserted to the start of the result array. + **dim_kwargs : int or sequence/ndarray + The keywords are arbitrary dimensions being inserted and the values + are either the lengths of the new dims (if int is given), or their + coordinates. Note, this is an alternative to passing a dict to the + dim kwarg and will only be used if dim is None. **WARNING** for + python 3.5 ``dim_kwargs`` is not available. Returns ------- @@ -2354,10 +2366,25 @@ def expand_dims(self, dim, axis=None): This object, but with an additional dimension(s). """ if isinstance(dim, int): - raise ValueError('dim should be str or sequence of strs or dict') + raise TypeError('dim should be str or sequence of strs or dict') + elif isinstance(dim, str): + dim = OrderedDict(((dim, 1),)) + elif isinstance(dim, (list, tuple)): + if len(dim) != len(set(dim)): + raise ValueError('dims should not contain duplicate values.') + dim = OrderedDict(((d, 1) for d in dim)) + + # TODO: get rid of the below code block when python 3.5 is no longer + # supported. + python36_plus = sys.version_info[0] == 3 and sys.version_info[1] > 5 + not_ordereddict = dim is not None and not isinstance(dim, OrderedDict) + if not python36_plus and not_ordereddict: + raise TypeError("dim must be an OrderedDict for python <3.6") + elif not python36_plus and dim_kwargs: + raise ValueError("dim_kwargs isn't available for python <3.6") + + dim = either_dict_or_kwargs(dim, dim_kwargs, 'expand_dims') - if isinstance(dim, str): - dim = [dim] if axis is not None and not isinstance(axis, (list, tuple)): axis = [axis] @@ -2376,10 +2403,24 @@ def expand_dims(self, dim, axis=None): '{dim} already exists as coordinate or' ' variable name.'.format(dim=d)) - if len(dim) != len(set(dim)): - raise ValueError('dims should not contain duplicate values.') - variables = OrderedDict() + # If dim is a dict, then ensure that the values are either integers + # or iterables. + for k, v in dim.items(): + if hasattr(v, "__iter__"): + # If the value for the new dimension is an iterable, then + # save the coordinates to the variables dict, and set the + # value within the dim dict to the length of the iterable + # for later use. + variables[k] = xr.IndexVariable((k,), v) + self._coord_names.add(k) + dim[k] = variables[k].size + elif isinstance(v, int): + pass # Do nothing if the dimensions value is just an int + else: + raise TypeError('The value of new dimension {k} must be ' + 'an iterable or an int'.format(k=k)) + for k, v in self._variables.items(): if k not in dim: if k in self._coord_names: # Do not change coordinates @@ -2400,11 +2441,13 @@ def expand_dims(self, dim, axis=None): ' values.') # We need to sort them to make sure `axis` equals to the # axis positions of the result array. - zip_axis_dim = sorted(zip(axis_pos, dim)) + zip_axis_dim = sorted(zip(axis_pos, dim.items())) + + all_dims = list(zip(v.dims, v.shape)) + for d, c in zip_axis_dim: + all_dims.insert(d, c) + all_dims = OrderedDict(all_dims) - all_dims = list(v.dims) - for a, d in zip_axis_dim: - all_dims.insert(a, d) variables[k] = v.set_dims(all_dims) else: # If dims includes a label of a non-dimension coordinate, diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 4975071dad8..b1ecf160533 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3,6 +3,7 @@ from collections import OrderedDict from copy import deepcopy from textwrap import dedent +import sys import numpy as np import pandas as pd @@ -1303,7 +1304,7 @@ def test_expand_dims_error(self): coords={'x': np.linspace(0.0, 1.0, 3)}, attrs={'key': 'entry'}) - with raises_regex(ValueError, 'dim should be str or'): + with raises_regex(TypeError, 'dim should be str or'): array.expand_dims(0) with raises_regex(ValueError, 'lengths of dim and axis'): # dims and axis argument should be the same length @@ -1328,6 +1329,16 @@ def test_expand_dims_error(self): array.expand_dims(dim=['y', 'z'], axis=[2, -4]) array.expand_dims(dim=['y', 'z'], axis=[2, 3]) + array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'], + coords={'x': np.linspace(0.0, 1.0, 3)}, + attrs={'key': 'entry'}) + with pytest.raises(TypeError): + array.expand_dims(OrderedDict((("new_dim", 3.2),))) + + # Attempt to use both dim and kwargs + with pytest.raises(ValueError): + array.expand_dims(OrderedDict((("d", 4),)), e=4) + def test_expand_dims(self): array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'], coords={'x': np.linspace(0.0, 1.0, 3)}, @@ -1392,6 +1403,46 @@ def test_expand_dims_with_scalar_coordinate(self): roundtripped = actual.squeeze(['z'], drop=False) assert_identical(array, roundtripped) + def test_expand_dims_with_greater_dim_size(self): + array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'], + coords={'x': np.linspace(0.0, 1.0, 3), 'z': 1.0}, + attrs={'key': 'entry'}) + # For python 3.5 and earlier this has to be an ordered dict, to + # maintain insertion order. + actual = array.expand_dims( + OrderedDict((('y', 2), ('z', 1), ('dim_1', ['a', 'b', 'c'])))) + + expected_coords = OrderedDict(( + ('y', [0, 1]), ('z', [1.0]), ('dim_1', ['a', 'b', 'c']), + ('x', np.linspace(0, 1, 3)), ('dim_0', range(4)))) + expected = DataArray(array.values * np.ones([2, 1, 3, 3, 4]), + coords=expected_coords, + dims=list(expected_coords.keys()), + attrs={'key': 'entry'} + ).drop(['y', 'dim_0']) + assert_identical(expected, actual) + + # Test with kwargs instead of passing dict to dim arg. + + # TODO: only the code under the if-statement is needed when python 3.5 + # is no longer supported. + python36_plus = sys.version_info[0] == 3 and sys.version_info[1] > 5 + if python36_plus: + other_way = array.expand_dims(dim_1=['a', 'b', 'c']) + + other_way_expected = DataArray( + array.values * np.ones([3, 3, 4]), + coords={'dim_1': ['a', 'b', 'c'], + 'x': np.linspace(0, 1, 3), + 'dim_0': range(4), 'z': 1.0}, + dims=['dim_1', 'x', 'dim_0'], + attrs={'key': 'entry'}).drop('dim_0') + assert_identical(other_way_expected, other_way) + else: + # In python 3.5, using dim_kwargs should raise a ValueError. + with raises_regex(ValueError, "dim_kwargs isn't"): + array.expand_dims(e=["l", "m", "n"]) + def test_set_index(self): indexes = [self.mindex.get_level_values(n) for n in self.mindex.names] coords = {idx.name: ('x', idx) for idx in indexes} diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 8e8c6c4b419..75b736239e6 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -2033,6 +2033,27 @@ def test_expand_dims_error(self): with raises_regex(ValueError, 'already exists'): original.expand_dims(dim=['z']) + original = Dataset({'x': ('a', np.random.randn(3)), + 'y': (['b', 'a'], np.random.randn(4, 3)), + 'z': ('a', np.random.randn(3))}, + coords={'a': np.linspace(0, 1, 3), + 'b': np.linspace(0, 1, 4), + 'c': np.linspace(0, 1, 5)}, + attrs={'key': 'entry'}) + with raises_regex(TypeError, 'value of new dimension'): + original.expand_dims(OrderedDict((("d", 3.2),))) + + # TODO: only the code under the if-statement is needed when python 3.5 + # is no longer supported. + python36_plus = sys.version_info[0] == 3 and sys.version_info[1] > 5 + if python36_plus: + with raises_regex(ValueError, 'both keyword and positional'): + original.expand_dims(OrderedDict((("d", 4),)), e=4) + else: + # In python 3.5, using dim_kwargs should raise a ValueError. + with raises_regex(ValueError, "dim_kwargs isn't"): + original.expand_dims(OrderedDict((("d", 4),)), e=4) + def test_expand_dims(self): original = Dataset({'x': ('a', np.random.randn(3)), 'y': (['b', 'a'], np.random.randn(4, 3))}, @@ -2066,6 +2087,53 @@ def test_expand_dims(self): roundtripped = actual.squeeze('z') assert_identical(original, roundtripped) + # Test expanding one dimension to have size > 1 that doesn't have + # coordinates, and also expanding another dimension to have size > 1 + # that DOES have coordinates. + actual = original.expand_dims( + OrderedDict((("d", 4), ("e", ["l", "m", "n"])))) + + expected = Dataset( + {'x': xr.DataArray(original['x'].values * np.ones([4, 3, 3]), + coords=dict(d=range(4), + e=['l', 'm', 'n'], + a=np.linspace(0, 1, 3)), + dims=['d', 'e', 'a']).drop('d'), + 'y': xr.DataArray(original['y'].values * np.ones([4, 3, 4, 3]), + coords=dict(d=range(4), + e=['l', 'm', 'n'], + b=np.linspace(0, 1, 4), + a=np.linspace(0, 1, 3)), + dims=['d', 'e', 'b', 'a']).drop('d')}, + coords={'c': np.linspace(0, 1, 5)}, + attrs={'key': 'entry'}) + assert_identical(actual, expected) + + # Test with kwargs instead of passing dict to dim arg. + + # TODO: only the code under the if-statement is needed when python 3.5 + # is no longer supported. + python36_plus = sys.version_info[0] == 3 and sys.version_info[1] > 5 + if python36_plus: + other_way = original.expand_dims(e=["l", "m", "n"]) + other_way_expected = Dataset( + {'x': xr.DataArray(original['x'].values * np.ones([3, 3]), + coords=dict(e=['l', 'm', 'n'], + a=np.linspace(0, 1, 3)), + dims=['e', 'a']), + 'y': xr.DataArray(original['y'].values * np.ones([3, 4, 3]), + coords=dict(e=['l', 'm', 'n'], + b=np.linspace(0, 1, 4), + a=np.linspace(0, 1, 3)), + dims=['e', 'b', 'a'])}, + coords={'c': np.linspace(0, 1, 5)}, + attrs={'key': 'entry'}) + assert_identical(other_way_expected, other_way) + else: + # In python 3.5, using dim_kwargs should raise a ValueError. + with raises_regex(ValueError, "dim_kwargs isn't"): + original.expand_dims(e=["l", "m", "n"]) + def test_set_index(self): expected = create_test_multiindex() mindex = expected['x'].to_index()