Skip to content

Commit 2366c91

Browse files
committed
Internal refactor of XArray, with a new CoordXArray subtype
This allows us to simplify our internal model for XArray (it always cached internally as a base ndarray) and supports some previously tricky aspects involving pandas.Index objects. Noteably: 1. The dtype of arrays stored as pandas.Index objects can now be faithfully saved and restored. Doing math with XArray objects always yields objects with the right dtype, so `ds['latitude'] + 1` has dtype=float, not dtype=object. 2. It's no longer necessary to load index data into memory upon creating a new Dataset. Instead, the index data can be loaded on demand. 3. `var.data` is always an ndarray. `var.index` is always a pandas.Index. Related issues: #17, #39, #40.
1 parent fdbfb7c commit 2366c91

File tree

11 files changed

+411
-231
lines changed

11 files changed

+411
-231
lines changed

src/xray/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .xarray import XArray, broadcast_xarrays
1+
from .xarray import as_xarray, XArray, CoordXArray, broadcast_xarrays
22
from .dataset import Dataset, open_dataset
33
from .dataset_array import DatasetArray, align
44
from .utils import (orthogonal_indexer, decode_cf_datetime, encode_cf_datetime,

src/xray/common.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,6 @@ def func(self, dimension=cls._reduce_dimension_default,
1515

1616

1717
class AbstractArray(ImplementsReduce):
18-
@property
19-
def dtype(self):
20-
return self._data.dtype
21-
22-
@property
23-
def shape(self):
24-
return self._data.shape
25-
26-
@property
27-
def size(self):
28-
return self._data.size
29-
30-
@property
31-
def ndim(self):
32-
return self._data.ndim
33-
34-
def __len__(self):
35-
return len(self._data)
36-
3718
def __nonzero__(self):
3819
return bool(self.data)
3920

src/xray/conventions.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -260,20 +260,6 @@ def encode_cf_variable(array):
260260
data, encoding.pop('units', None), encoding.pop('calendar', None))
261261
attributes['units'] = units
262262
attributes['calendar'] = calendar
263-
elif data.dtype == np.dtype('O'):
264-
# Unfortunately, pandas.Index arrays often have dtype=object even if
265-
# they were created from an array with a sensible datatype (e.g.,
266-
# pandas.Float64Index always has dtype=object for some reason). Because
267-
# we allow for doing math with coordinates, these object arrays can
268-
# propagate onward to other variables, which is why we don't only apply
269-
# this check to XArrays with data that is a pandas.Index.
270-
# Accordingly, we convert object arrays to the type of their first
271-
# variable.
272-
dtype = np.array(data.reshape(-1)[0]).dtype
273-
# N.B. the "astype" call below will fail if data cannot be cast to the
274-
# type of its first element (which is probably the only sensible thing
275-
# to do).
276-
data = np.asarray(data).astype(dtype)
277263

278264
def get_to(source, dest, k):
279265
v = source.get(k)

src/xray/dataset.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,10 @@ def _get_virtual_variable(self, key):
7676
if ref_var in self._datetimeindices():
7777
if suffix == 'season':
7878
# seasons = np.array(['DJF', 'MAM', 'JJA', 'SON'])
79-
month = self[ref_var].data.month
79+
month = self[ref_var].index.month
8080
data = (month // 3) % 4 + 1
8181
else:
82-
data = getattr(self[ref_var].data, suffix)
82+
data = getattr(self[ref_var].index, suffix)
8383
return xarray.XArray(self[ref_var].dimensions, data)
8484
raise KeyError('virtual variable %r not found' % key)
8585

@@ -130,14 +130,15 @@ def __init__(self, variables=None, attributes=None, decode_cf=False):
130130

131131
def _as_variable(self, name, var, decode_cf=False):
132132
if isinstance(var, DatasetArray):
133-
var = var.array
134-
if not isinstance(var, xarray.XArray):
133+
var = xarray.as_xarray(var)
134+
elif not isinstance(var, xarray.XArray):
135135
try:
136136
var = xarray.XArray(*var)
137137
except TypeError:
138138
raise TypeError('Dataset variables must be of type '
139139
'DatasetArray or XArray, or a sequence of the '
140-
'form (dimensions, data[, attributes])')
140+
'form (dimensions, data[, attributes, '
141+
'encoding])')
141142
# this will unmask and rescale the data as well as convert
142143
# time variables to datetime indices.
143144
if decode_cf:
@@ -147,9 +148,7 @@ def _as_variable(self, name, var, decode_cf=False):
147148
if var.ndim != 1:
148149
raise ValueError('a coordinate variable must be defined with '
149150
'1-dimensional data')
150-
# create a new XArray object on which to modify the data
151-
var = xarray.XArray(var.dimensions, pd.Index(var.data),
152-
var.attributes, encoding=var.encoding)
151+
var = var.to_coord()
153152
return var
154153

155154
def set_variables(self, variables, decode_cf=False):
@@ -487,7 +486,7 @@ def labeled_by(self, **indexers):
487486
Dataset.indexed_by
488487
Array.indexed_by
489488
"""
490-
return self.indexed_by(**remap_loc_indexers(self.variables, indexers))
489+
return self.indexed_by(**remap_loc_indexers(self, indexers))
491490

492491
def renamed(self, name_dict):
493492
"""Returns a new object with renamed variables and dimensions.

src/xray/dataset_array.py

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -66,28 +66,47 @@ def __init__(self, dataset, focus):
6666
self.focus = focus
6767

6868
@property
69-
def array(self):
69+
def variable(self):
7070
return self.dataset.variables[self.focus]
71-
@array.setter
72-
def array(self, value):
71+
@variable.setter
72+
def variable(self, value):
7373
self.dataset[self.focus] = value
7474

75-
# _data is necessary for AbstractArray
7675
@property
77-
def _data(self):
78-
return self.array._data
76+
def dtype(self):
77+
return self.variable.dtype
78+
79+
@property
80+
def shape(self):
81+
return self.variable.shape
82+
83+
@property
84+
def size(self):
85+
return self.variable.size
86+
87+
@property
88+
def ndim(self):
89+
return self.variable.ndim
90+
91+
def __len__(self):
92+
return len(self.variable)
7993

8094
@property
8195
def data(self):
82-
"""The array's data as a numpy.ndarray"""
83-
return self.array.data
96+
"""The variables's data as a numpy.ndarray"""
97+
return self.variable.data
8498
@data.setter
8599
def data(self, value):
86-
self.array.data = value
100+
self.variable.data = value
101+
102+
@property
103+
def index(self):
104+
"""The variable's data as a pandas.Index"""
105+
return self.variable.index
87106

88107
@property
89108
def dimensions(self):
90-
return self.array.dimensions
109+
return self.variable.dimensions
91110

92111
def _key_to_indexers(self, key):
93112
return OrderedDict(
@@ -107,7 +126,7 @@ def __setitem__(self, key, value):
107126
self.dataset[key] = value
108127
else:
109128
# orthogonal array indexing
110-
self.array[key] = value
129+
self.variable[key] = value
111130

112131
def __delitem__(self, key):
113132
del self.dataset[key]
@@ -127,11 +146,11 @@ def __iter__(self):
127146

128147
@property
129148
def attributes(self):
130-
return self.array.attributes
149+
return self.variable.attributes
131150

132151
@property
133152
def encoding(self):
134-
return self.array.encoding
153+
return self.variable.encoding
135154

136155
@property
137156
def variables(self):
@@ -178,7 +197,7 @@ def indexed_by(self, **indexers):
178197
if self.focus not in ds:
179198
# always keep focus variable in the dataset, even if it was
180199
# unselected because indexing made it a scaler
181-
ds[self.focus] = self.array.indexed_by(**indexers)
200+
ds[self.focus] = self.variable.indexed_by(**indexers)
182201
return type(self)(ds, self.focus)
183202

184203
def labeled_by(self, **indexers):
@@ -236,7 +255,7 @@ def refocus(self, new_var, name=None):
236255
If `new_var` is a dataset array, its contents will be merged in.
237256
"""
238257
if not hasattr(new_var, 'dimensions'):
239-
new_var = type(self.array)(self.array.dimensions, new_var)
258+
new_var = type(self.variable)(self.variable.dimensions, new_var)
240259
if self.focus not in self.dimensions:
241260
# only unselect the focus from the dataset if it isn't a coordinate
242261
# variable
@@ -301,7 +320,7 @@ def transpose(self, *dimensions):
301320
numpy.transpose
302321
Array.transpose
303322
"""
304-
return self.refocus(self.array.transpose(*dimensions), self.focus)
323+
return self.refocus(self.variable.transpose(*dimensions), self.focus)
305324

306325
def squeeze(self, dimension=None):
307326
"""Return a new DatasetArray object with squeezed data.
@@ -361,7 +380,7 @@ def reduce(self, func, dimension=None, axis=None, **kwargs):
361380
DatasetArray with this object's array replaced with an array with
362381
summarized data and the indicated dimension(s) removed.
363382
"""
364-
var = self.array.reduce(func, dimension, axis, **kwargs)
383+
var = self.variable.reduce(func, dimension, axis, **kwargs)
365384
drop = set(self.dimensions) - set(var.dimensions)
366385
# For now, take an aggressive strategy of removing all variables
367386
# associated with any dropped dimensions
@@ -495,13 +514,13 @@ def to_series(self):
495514
return pd.Series(self.data.reshape(-1), index=index, name=self.focus)
496515

497516
def __array_wrap__(self, obj, context=None):
498-
return self.refocus(self.array.__array_wrap__(obj, context))
517+
return self.refocus(self.variable.__array_wrap__(obj, context))
499518

500519
@staticmethod
501520
def _unary_op(f):
502521
@functools.wraps(f)
503522
def func(self, *args, **kwargs):
504-
return self.refocus(f(self.array, *args, **kwargs),
523+
return self.refocus(f(self.variable, *args, **kwargs),
505524
self.focus + '_' + f.__name__)
506525
return func
507526

@@ -523,12 +542,12 @@ def func(self, other):
523542
ds = self.unselected()
524543
if hasattr(other, 'unselected'):
525544
ds.merge(other.unselected(), inplace=True)
526-
other_array = getattr(other, 'array', other)
545+
other_array = getattr(other, 'variable', other)
527546
other_focus = getattr(other, 'focus', 'other')
528547
focus = self.focus + '_' + f.__name__ + '_' + other_focus
529-
ds[focus] = (f(self.array, other_array)
548+
ds[focus] = (f(self.variable, other_array)
530549
if not reflexive
531-
else f(other_array, self.array))
550+
else f(other_array, self.variable))
532551
return type(self)(ds, focus)
533552
return func
534553

@@ -537,8 +556,8 @@ def _inplace_binary_op(f):
537556
@functools.wraps(f)
538557
def func(self, other):
539558
self._check_coordinates_compat(other)
540-
other_array = getattr(other, 'array', other)
541-
self.array = f(self.array, other_array)
559+
other_array = getattr(other, 'variable', other)
560+
self.variable = f(self.variable, other_array)
542561
if hasattr(other, 'unselected'):
543562
self.dataset.merge(other.unselected(), inplace=True)
544563
return self
@@ -555,8 +574,9 @@ def align(array1, array2):
555574
# TODO: automatically align when doing math with arrays, or better yet
556575
# calculate the union of the indices and fill in the mis-aligned data with
557576
# NaN.
558-
overlapping_coords = {k: (array1.coordinates[k].data
559-
& array2.coordinates[k].data)
577+
# TODO: generalize this function to any number of arguments
578+
overlapping_coords = {k: (array1.coordinates[k].index
579+
& array2.coordinates[k].index)
560580
for k in array1.coordinates
561581
if k in array2.coordinates}
562582
return tuple(ar.labeled_by(**overlapping_coords)

src/xray/utils.py

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import numpy as np
77
import pandas as pd
88

9+
import xarray
10+
911

1012
def expanded_indexer(key, ndim):
1113
"""Given a key for indexing an ndarray, return an equivalent key which is a
@@ -86,12 +88,12 @@ def all_full_slices(key_index):
8688

8789

8890
def remap_loc_indexers(indices, indexers):
89-
"""Given mappings of indices and label based indexers, return equivalent
90-
location based indexers.
91+
"""Given mappings of XArray indices and label based indexers, return
92+
equivalent location based indexers.
9193
"""
9294
new_indexers = OrderedDict()
9395
for dim, loc in indexers.iteritems():
94-
index = indices[dim].data
96+
index = indices[dim].index
9597
if isinstance(loc, slice):
9698
indexer = index.slice_indexer(loc.start, loc.stop, loc.step)
9799
else:
@@ -201,11 +203,12 @@ def encode_cf_datetime(dates, units=None, calendar=None):
201203
and np.issubdtype(dates.dtype, np.datetime64)):
202204
# for now, don't bother doing any trickery like decode_cf_datetime to
203205
# convert dates to numbers faster
204-
dates = dates.astype(datetime)
206+
# TODO: don't use pandas.DatetimeIndex to do the conversion
207+
dates = pd.Index(dates.reshape(-1)).to_pydatetime().reshape(dates.shape)
205208

206209
if hasattr(dates, 'ndim') and dates.ndim == 0:
207-
# unpack dates because date2num doesn't like 0-dimensional arguments
208-
dates = dates[()]
210+
# date2num doesn't like 0-dimensional arguments
211+
dates = dates.item()
209212

210213
num = nc4.date2num(dates, units, calendar)
211214
return (num, units, calendar)
@@ -235,33 +238,40 @@ def xarray_equal(v1, v2, rtol=1e-05, atol=1e-08):
235238
This function is necessary because `v1 == v2` for XArrays and DatasetArrays
236239
does element-wise comparisions (like numpy.ndarrays).
237240
"""
241+
v1, v2 = map(xarray.as_xarray, [v1, v2])
238242
if (v1.dimensions == v2.dimensions
239-
and dict_equal(v1.attributes, v2.attributes)):
240-
try:
243+
and dict_equal(v1.attributes, v2.attributes)):
244+
if v1._data is v2._data:
241245
# if _data is identical, skip checking arrays by value
242-
if v1._data is v2._data:
243-
return True
244-
except AttributeError:
245-
# _data is not part of the public interface, so it's okay if its
246-
# missing
247-
pass
248-
249-
def is_floating(arr):
250-
return np.issubdtype(arr.dtype, float)
251-
252-
data1 = v1.data
253-
data2 = v2.data
254-
if hasattr(data1, 'equals'):
255-
# handle pandas.Index objects
256-
return data1.equals(data2)
257-
elif is_floating(data1) or is_floating(data2):
258-
return allclose_or_equiv(data1, data2, rtol=rtol, atol=atol)
246+
return True
259247
else:
260-
return np.array_equal(data1, data2)
248+
def is_floating(arr):
249+
return np.issubdtype(arr.dtype, float)
250+
251+
data1 = v1.data
252+
data2 = v2.data
253+
if is_floating(data1) or is_floating(data2):
254+
return allclose_or_equiv(data1, data2, rtol=rtol, atol=atol)
255+
else:
256+
return np.array_equal(data1, data2)
261257
else:
262258
return False
263259

264260

261+
def safe_cast_to_index(array):
262+
"""Given an array, safely cast it to a pandas.Index
263+
264+
Unlike pandas.Index, if the array has dtype=object or dtype=timedelta64,
265+
this function will not attempt to do automatic type conversion but will
266+
always return an index with dtype=object.
267+
"""
268+
kwargs = {}
269+
if isinstance(array, np.ndarray):
270+
if array.dtype == object or array.dtype == np.timedelta64:
271+
kwargs['dtype'] = object
272+
return pd.Index(array, **kwargs)
273+
274+
265275
def update_safety_check(first_dict, second_dict, compat=operator.eq):
266276
"""Check the safety of updating one dictionary with another.
267277

0 commit comments

Comments
 (0)