Skip to content

Allow expand_dims() method to support inserting/broadcasting dimensions with size>1 #2757

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Mar 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
b9663b2
Quarter offset implemented (base is now latest pydata-master). (#2721)
jwenfai Mar 2, 2019
ab268de
Add `Dataset.drop_dims` (#2767)
kmsquire Mar 3, 2019
b393754
Improve name concat (#2792)
TomNicholas Mar 4, 2019
872b49c
Don't use deprecated np.asscalar() (#2800)
TimoRoth Mar 5, 2019
849eb18
Add support for cftime.datetime coordinates with coarsen (#2778)
spencerkclark Mar 6, 2019
54883ba
some docs updates (#2746)
dcherian Mar 12, 2019
526a395
Drop failing tests writing multi-dimensional arrays as attributes (#2…
shoyer Mar 14, 2019
9f00c6f
Push back finalizing deprecations for 0.12 (#2809)
shoyer Mar 15, 2019
7d209f6
enable loading remote hdf5 files (#2782)
scottyhq Mar 16, 2019
ccb198f
Release 0.12.0
shoyer Mar 16, 2019
6ec9910
Add whats-new for 0.12.1
shoyer Mar 16, 2019
4ce03c2
Rework whats-new for 0.12
shoyer Mar 16, 2019
97fdb83
DOC: Update donation links
shoyer Mar 20, 2019
a74ecd6
DOC: remove outdated warning (#2818)
shoyer Mar 20, 2019
21fa6e0
Allow expand_dims() method to support inserting/broadcasting dimensio…
pletchm Feb 7, 2019
78608d4
Merge branch 'master' into feature/expand-dims-broadcast
pletchm Mar 21, 2019
4e47dd1
Allow expand_dims() method to support inserting/broadcasting dimensio…
pletchm Feb 7, 2019
1ffc64e
Merge branch 'feature/expand-dims-broadcast' of https://github.com/pl…
pletchm Mar 26, 2019
c319c47
Allow expand_dims() method to support inserting/broadcasting dimensio…
pletchm Feb 7, 2019
cac434e
Merge branch 'feature/expand-dims-broadcast' of https://github.com/pl…
pletchm Mar 26, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ v0.12.1 (unreleased)
Enhancements
~~~~~~~~~~~~

- Allow ``expand_dims`` method to support inserting/broadcasting dimensions
with size > 1. (:issue:`2710`)
By `Martin Pletcher <https://github.com/pletchm>`_.


Bug fixes
~~~~~~~~~
Expand Down
39 changes: 36 additions & 3 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import sys
import warnings
from collections import OrderedDict

Expand Down Expand Up @@ -1138,7 +1139,7 @@ def swap_dims(self, dims_dict):
ds = self._to_temp_dataset().swap_dims(dims_dict)
return self._from_temp_dataset(ds)

def expand_dims(self, dim, axis=None):
def expand_dims(self, dim=None, axis=None, **dim_kwargs):
"""Return a new object with an additional axis (or axes) inserted at
the corresponding position in the array shape.

Expand All @@ -1147,21 +1148,53 @@ def expand_dims(self, dim, axis=None):

Parameters
----------
dim : str or sequence of str.
dim : str, sequence of str, dict, or None
Dimensions to include on the new variable.
dimensions are inserted with length 1.
If provided as str or sequence of str, then dimensions are inserted
with length 1. If provided as a dict, then the keys are the new
dimensions and the values are either integers (giving the length of
the new dimensions) or sequence/ndarray (giving the coordinates of
the new dimensions). **WARNING** for python 3.5, if ``dim`` is
dict-like, then it must be an ``OrderedDict``. This is to ensure
that the order in which the dims are given is maintained.
axis : integer, list (or tuple) of integers, or None
Axis position(s) where new axis is to be inserted (position(s) on
the result array). If a list (or tuple) of integers is passed,
multiple axes are inserted. In this case, dim arguments should be
same length list. If axis=None is passed, all the axes will be
inserted to the start of the result array.
**dim_kwargs : int or sequence/ndarray
The keywords are arbitrary dimensions being inserted and the values
are either the lengths of the new dims (if int is given), or their
coordinates. Note, this is an alternative to passing a dict to the
dim kwarg and will only be used if dim is None. **WARNING** for
python 3.5 ``dim_kwargs`` is not available.

Returns
-------
expanded : same type as caller
This object, but with an additional dimension(s).
"""
if isinstance(dim, int):
raise TypeError('dim should be str or sequence of strs or dict')
elif isinstance(dim, str):
dim = OrderedDict(((dim, 1),))
elif isinstance(dim, (list, tuple)):
if len(dim) != len(set(dim)):
raise ValueError('dims should not contain duplicate values.')
dim = OrderedDict(((d, 1) for d in dim))

# TODO: get rid of the below code block when python 3.5 is no longer
# supported.
python36_plus = sys.version_info[0] == 3 and sys.version_info[1] > 5
not_ordereddict = dim is not None and not isinstance(dim, OrderedDict)
if not python36_plus and not_ordereddict:
raise TypeError("dim must be an OrderedDict for python <3.6")
elif not python36_plus and dim_kwargs:
raise ValueError("dim_kwargs isn't available for python <3.6")
dim_kwargs = OrderedDict(dim_kwargs)

dim = either_dict_or_kwargs(dim, dim_kwargs, 'expand_dims')
ds = self._to_temp_dataset().expand_dims(dim, axis)
return self._from_temp_dataset(ds)

Expand Down
73 changes: 58 additions & 15 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2329,7 +2329,7 @@ def swap_dims(self, dims_dict, inplace=None):
return self._replace_with_new_dims(variables, coord_names,
indexes=indexes, inplace=inplace)

def expand_dims(self, dim, axis=None):
def expand_dims(self, dim=None, axis=None, **dim_kwargs):
"""Return a new object with an additional axis (or axes) inserted at
the corresponding position in the array shape.

Expand All @@ -2338,26 +2338,53 @@ def expand_dims(self, dim, axis=None):

Parameters
----------
dim : str or sequence of str.
dim : str, sequence of str, dict, or None
Dimensions to include on the new variable.
dimensions are inserted with length 1.
If provided as str or sequence of str, then dimensions are inserted
with length 1. If provided as a dict, then the keys are the new
dimensions and the values are either integers (giving the length of
the new dimensions) or sequence/ndarray (giving the coordinates of
the new dimensions). **WARNING** for python 3.5, if ``dim`` is
dict-like, then it must be an ``OrderedDict``. This is to ensure
that the order in which the dims are given is maintained.
axis : integer, list (or tuple) of integers, or None
Axis position(s) where new axis is to be inserted (position(s) on
the result array). If a list (or tuple) of integers is passed,
multiple axes are inserted. In this case, dim arguments should be
the same length list. If axis=None is passed, all the axes will
be inserted to the start of the result array.
same length list. If axis=None is passed, all the axes will be
inserted to the start of the result array.
**dim_kwargs : int or sequence/ndarray
The keywords are arbitrary dimensions being inserted and the values
are either the lengths of the new dims (if int is given), or their
coordinates. Note, this is an alternative to passing a dict to the
dim kwarg and will only be used if dim is None. **WARNING** for
python 3.5 ``dim_kwargs`` is not available.

Returns
-------
expanded : same type as caller
This object, but with an additional dimension(s).
"""
if isinstance(dim, int):
raise ValueError('dim should be str or sequence of strs or dict')
raise TypeError('dim should be str or sequence of strs or dict')
elif isinstance(dim, str):
dim = OrderedDict(((dim, 1),))
elif isinstance(dim, (list, tuple)):
if len(dim) != len(set(dim)):
raise ValueError('dims should not contain duplicate values.')
dim = OrderedDict(((d, 1) for d in dim))

# TODO: get rid of the below code block when python 3.5 is no longer
# supported.
python36_plus = sys.version_info[0] == 3 and sys.version_info[1] > 5
not_ordereddict = dim is not None and not isinstance(dim, OrderedDict)
if not python36_plus and not_ordereddict:
raise TypeError("dim must be an OrderedDict for python <3.6")
elif not python36_plus and dim_kwargs:
raise ValueError("dim_kwargs isn't available for python <3.6")

dim = either_dict_or_kwargs(dim, dim_kwargs, 'expand_dims')

if isinstance(dim, str):
dim = [dim]
if axis is not None and not isinstance(axis, (list, tuple)):
axis = [axis]

Expand All @@ -2376,10 +2403,24 @@ def expand_dims(self, dim, axis=None):
'{dim} already exists as coordinate or'
' variable name.'.format(dim=d))

if len(dim) != len(set(dim)):
raise ValueError('dims should not contain duplicate values.')

variables = OrderedDict()
# If dim is a dict, then ensure that the values are either integers
# or iterables.
for k, v in dim.items():
if hasattr(v, "__iter__"):
# If the value for the new dimension is an iterable, then
# save the coordinates to the variables dict, and set the
# value within the dim dict to the length of the iterable
# for later use.
variables[k] = xr.IndexVariable((k,), v)
self._coord_names.add(k)
dim[k] = variables[k].size
elif isinstance(v, int):
pass # Do nothing if the dimensions value is just an int
else:
raise TypeError('The value of new dimension {k} must be '
'an iterable or an int'.format(k=k))

for k, v in self._variables.items():
if k not in dim:
if k in self._coord_names: # Do not change coordinates
Expand All @@ -2400,11 +2441,13 @@ def expand_dims(self, dim, axis=None):
' values.')
# We need to sort them to make sure `axis` equals to the
# axis positions of the result array.
zip_axis_dim = sorted(zip(axis_pos, dim))
zip_axis_dim = sorted(zip(axis_pos, dim.items()))

all_dims = list(zip(v.dims, v.shape))
for d, c in zip_axis_dim:
all_dims.insert(d, c)
all_dims = OrderedDict(all_dims)

all_dims = list(v.dims)
for a, d in zip_axis_dim:
all_dims.insert(a, d)
variables[k] = v.set_dims(all_dims)
else:
# If dims includes a label of a non-dimension coordinate,
Expand Down
53 changes: 52 additions & 1 deletion xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import OrderedDict
from copy import deepcopy
from textwrap import dedent
import sys

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -1303,7 +1304,7 @@ def test_expand_dims_error(self):
coords={'x': np.linspace(0.0, 1.0, 3)},
attrs={'key': 'entry'})

with raises_regex(ValueError, 'dim should be str or'):
with raises_regex(TypeError, 'dim should be str or'):
array.expand_dims(0)
with raises_regex(ValueError, 'lengths of dim and axis'):
# dims and axis argument should be the same length
Expand All @@ -1328,6 +1329,16 @@ def test_expand_dims_error(self):
array.expand_dims(dim=['y', 'z'], axis=[2, -4])
array.expand_dims(dim=['y', 'z'], axis=[2, 3])

array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'],
coords={'x': np.linspace(0.0, 1.0, 3)},
attrs={'key': 'entry'})
with pytest.raises(TypeError):
array.expand_dims(OrderedDict((("new_dim", 3.2),)))

# Attempt to use both dim and kwargs
with pytest.raises(ValueError):
array.expand_dims(OrderedDict((("d", 4),)), e=4)

def test_expand_dims(self):
array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'],
coords={'x': np.linspace(0.0, 1.0, 3)},
Expand Down Expand Up @@ -1392,6 +1403,46 @@ def test_expand_dims_with_scalar_coordinate(self):
roundtripped = actual.squeeze(['z'], drop=False)
assert_identical(array, roundtripped)

def test_expand_dims_with_greater_dim_size(self):
array = DataArray(np.random.randn(3, 4), dims=['x', 'dim_0'],
coords={'x': np.linspace(0.0, 1.0, 3), 'z': 1.0},
attrs={'key': 'entry'})
# For python 3.5 and earlier this has to be an ordered dict, to
# maintain insertion order.
actual = array.expand_dims(
OrderedDict((('y', 2), ('z', 1), ('dim_1', ['a', 'b', 'c']))))

expected_coords = OrderedDict((
('y', [0, 1]), ('z', [1.0]), ('dim_1', ['a', 'b', 'c']),
('x', np.linspace(0, 1, 3)), ('dim_0', range(4))))
expected = DataArray(array.values * np.ones([2, 1, 3, 3, 4]),
coords=expected_coords,
dims=list(expected_coords.keys()),
attrs={'key': 'entry'}
).drop(['y', 'dim_0'])
assert_identical(expected, actual)

# Test with kwargs instead of passing dict to dim arg.

# TODO: only the code under the if-statement is needed when python 3.5
# is no longer supported.
python36_plus = sys.version_info[0] == 3 and sys.version_info[1] > 5
if python36_plus:
other_way = array.expand_dims(dim_1=['a', 'b', 'c'])

other_way_expected = DataArray(
array.values * np.ones([3, 3, 4]),
coords={'dim_1': ['a', 'b', 'c'],
'x': np.linspace(0, 1, 3),
'dim_0': range(4), 'z': 1.0},
dims=['dim_1', 'x', 'dim_0'],
attrs={'key': 'entry'}).drop('dim_0')
assert_identical(other_way_expected, other_way)
else:
# In python 3.5, using dim_kwargs should raise a ValueError.
with raises_regex(ValueError, "dim_kwargs isn't"):
array.expand_dims(e=["l", "m", "n"])

def test_set_index(self):
indexes = [self.mindex.get_level_values(n) for n in self.mindex.names]
coords = {idx.name: ('x', idx) for idx in indexes}
Expand Down
68 changes: 68 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2033,6 +2033,27 @@ def test_expand_dims_error(self):
with raises_regex(ValueError, 'already exists'):
original.expand_dims(dim=['z'])

original = Dataset({'x': ('a', np.random.randn(3)),
'y': (['b', 'a'], np.random.randn(4, 3)),
'z': ('a', np.random.randn(3))},
coords={'a': np.linspace(0, 1, 3),
'b': np.linspace(0, 1, 4),
'c': np.linspace(0, 1, 5)},
attrs={'key': 'entry'})
with raises_regex(TypeError, 'value of new dimension'):
original.expand_dims(OrderedDict((("d", 3.2),)))

# TODO: only the code under the if-statement is needed when python 3.5
# is no longer supported.
python36_plus = sys.version_info[0] == 3 and sys.version_info[1] > 5
if python36_plus:
with raises_regex(ValueError, 'both keyword and positional'):
original.expand_dims(OrderedDict((("d", 4),)), e=4)
else:
# In python 3.5, using dim_kwargs should raise a ValueError.
with raises_regex(ValueError, "dim_kwargs isn't"):
original.expand_dims(OrderedDict((("d", 4),)), e=4)

def test_expand_dims(self):
original = Dataset({'x': ('a', np.random.randn(3)),
'y': (['b', 'a'], np.random.randn(4, 3))},
Expand Down Expand Up @@ -2066,6 +2087,53 @@ def test_expand_dims(self):
roundtripped = actual.squeeze('z')
assert_identical(original, roundtripped)

# Test expanding one dimension to have size > 1 that doesn't have
# coordinates, and also expanding another dimension to have size > 1
# that DOES have coordinates.
actual = original.expand_dims(
OrderedDict((("d", 4), ("e", ["l", "m", "n"]))))

expected = Dataset(
{'x': xr.DataArray(original['x'].values * np.ones([4, 3, 3]),
coords=dict(d=range(4),
e=['l', 'm', 'n'],
a=np.linspace(0, 1, 3)),
dims=['d', 'e', 'a']).drop('d'),
'y': xr.DataArray(original['y'].values * np.ones([4, 3, 4, 3]),
coords=dict(d=range(4),
e=['l', 'm', 'n'],
b=np.linspace(0, 1, 4),
a=np.linspace(0, 1, 3)),
dims=['d', 'e', 'b', 'a']).drop('d')},
coords={'c': np.linspace(0, 1, 5)},
attrs={'key': 'entry'})
assert_identical(actual, expected)

# Test with kwargs instead of passing dict to dim arg.

# TODO: only the code under the if-statement is needed when python 3.5
# is no longer supported.
python36_plus = sys.version_info[0] == 3 and sys.version_info[1] > 5
if python36_plus:
other_way = original.expand_dims(e=["l", "m", "n"])
other_way_expected = Dataset(
{'x': xr.DataArray(original['x'].values * np.ones([3, 3]),
coords=dict(e=['l', 'm', 'n'],
a=np.linspace(0, 1, 3)),
dims=['e', 'a']),
'y': xr.DataArray(original['y'].values * np.ones([3, 4, 3]),
coords=dict(e=['l', 'm', 'n'],
b=np.linspace(0, 1, 4),
a=np.linspace(0, 1, 3)),
dims=['e', 'b', 'a'])},
coords={'c': np.linspace(0, 1, 5)},
attrs={'key': 'entry'})
assert_identical(other_way_expected, other_way)
else:
# In python 3.5, using dim_kwargs should raise a ValueError.
with raises_regex(ValueError, "dim_kwargs isn't"):
original.expand_dims(e=["l", "m", "n"])

def test_set_index(self):
expected = create_test_multiindex()
mindex = expected['x'].to_index()
Expand Down