Skip to content

Commit 4cb7d02

Browse files
committed
Merge non-concatenated variables.
1 parent 4c994e2 commit 4cb7d02

File tree

1 file changed

+68
-50
lines changed

1 file changed

+68
-50
lines changed

xarray/core/concat.py

Lines changed: 68 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@
55

66
from . import dtypes, utils
77
from .alignment import align
8+
from .merge import (
9+
determine_coords,
10+
merge_variables,
11+
expand_variable_dicts,
12+
_VALID_COMPAT,
13+
)
814
from .variable import IndexVariable, Variable, as_variable
915
from .variable import concat as concat_vars
1016

@@ -65,7 +71,7 @@ def concat(
6571
in addition to the 'minimal' coordinates.
6672
compat : {'equals', 'identical', 'override'}, optional
6773
String indicating how to compare non-concatenated variables and
68-
dataset global attributes for potential conflicts.
74+
dataset global attributes for potential conflicts. This is passed down to merge.
6975
* 'equals' means that all variable values and dimensions must be the same;
7076
* 'identical' means that variable attributes and global attributes
7177
must also be equal.
@@ -145,7 +151,7 @@ def concat(
145151
"`data_vars` and `coords` arguments"
146152
)
147153

148-
if compat not in ["equals", "identical", "override", "no_conflicts"]:
154+
if compat not in _VALID_COMPAT:
149155
raise ValueError(
150156
"compat=%r invalid: must be 'equals', 'identical or 'override'" % compat
151157
)
@@ -186,22 +192,28 @@ def _calc_concat_dim_coord(dim):
186192
return dim, coord
187193

188194

189-
def _calc_concat_over(datasets, dim, data_vars, coords):
195+
def _calc_concat_over(datasets, dim, dim_names, data_vars, coords, compat):
190196
"""
191197
Determine which dataset variables need to be concatenated in the result,
192-
and which can simply be taken from the first dataset.
193198
"""
194199
# Return values
195200
concat_over = set()
196201
equals = {}
197202

198-
if dim in datasets[0]:
203+
if dim in dim_names:
199204
concat_over_existing_dim = True
200205
concat_over.add(dim)
201206
else:
202207
concat_over_existing_dim = False
203208

204209
for ds in datasets:
210+
if concat_over_existing_dim:
211+
if dim not in ds.dims:
212+
# TODO: why did I do this
213+
if dim in ds:
214+
ds = ds.set_coords(dim)
215+
else:
216+
raise ValueError("%r is not a dimension in all datasets" % dim)
205217
concat_over.update(k for k, v in ds.variables.items() if dim in v.dims)
206218

207219
def process_subset_opt(opt, subset):
@@ -225,7 +237,7 @@ def process_subset_opt(opt, subset):
225237
for ds_rhs in datasets[1:]:
226238
v_rhs = ds_rhs.variables[k].compute()
227239
computed.append(v_rhs)
228-
if not v_lhs.equals(v_rhs):
240+
if not getattr(v_lhs, compat)(v_rhs):
229241
concat_over.add(k)
230242
equals[k] = False
231243
# computed variables are not to be re-computed
@@ -291,68 +303,74 @@ def _dataset_concat(
291303
*datasets, join=join, copy=False, exclude=[dim], fill_value=fill_value
292304
)
293305

294-
concat_over, equals = _calc_concat_over(datasets, dim, data_vars, coords)
306+
# determine dimensional coordinate names and a dict mapping name to DataArray
307+
def determine_dims(datasets, result_coord_names):
308+
dims = set()
309+
coords = dict()
310+
for ds in datasets:
311+
for dim in set(ds.dims) - dims:
312+
if dim not in coords:
313+
coords[dim] = ds.coords[dim].variable
314+
dims = dims | set(ds.dims)
315+
return dims, coords
316+
317+
result_coord_names, noncoord_names = determine_coords(datasets)
318+
both_data_and_coords = result_coord_names & noncoord_names
319+
if both_data_and_coords:
320+
raise ValueError(
321+
"%r is a coordinate in some datasets and not in others."
322+
% list(both_data_and_coords)[0] # preserve previous behaviour
323+
)
324+
dim_names, result_coords = determine_dims(datasets, result_coord_names)
325+
# we don't want the concat dimension in the result dataset yet
326+
result_coords.pop(dim, None)
295327

296-
def insert_result_variable(k, v):
297-
assert isinstance(v, Variable)
298-
if k in datasets[0].coords:
299-
result_coord_names.add(k)
300-
result_vars[k] = v
328+
# determine which variables to concatenate
329+
concat_over, equals = _calc_concat_over(
330+
datasets, dim, dim_names, data_vars, coords, compat
331+
)
332+
333+
# determine which variables to merge
334+
variables_to_merge = (result_coord_names | noncoord_names) - concat_over - dim_names
335+
if variables_to_merge:
336+
to_merge = []
337+
for ds in datasets:
338+
to_merge.append(ds.reset_coords()[list(variables_to_merge)])
339+
# TODO: Provide equals as an argument and thread that down to merge.unique_variable
340+
result_vars = merge_variables(
341+
expand_variable_dicts(to_merge), priority_vars=None, compat=compat
342+
)
343+
else:
344+
result_vars = OrderedDict()
345+
result_vars.update(result_coords)
301346

302-
# create the new dataset and add constant variables
303-
result_vars = OrderedDict()
304-
result_coord_names = set(datasets[0].coords)
347+
# assign attrs and encoding from first dataset
305348
result_attrs = datasets[0].attrs
306349
result_encoding = datasets[0].encoding
307350

308-
for k, v in datasets[0].variables.items():
309-
if k not in concat_over:
310-
insert_result_variable(k, v)
351+
def insert_result_variable(k, v):
352+
assert isinstance(v, Variable)
353+
result_vars[k] = v
311354

312-
# check that global attributes and non-concatenated variables are fixed
313-
# across all datasets
355+
# check that global attributes are fixed across all datasets if necessary
314356
for ds in datasets[1:]:
315357
if compat == "identical" and not utils.dict_equiv(ds.attrs, result_attrs):
316358
raise ValueError("Dataset global attributes are not equal.")
317-
for k, v in ds.variables.items():
318-
if k not in result_vars and k not in concat_over:
319-
raise ValueError("Encountered unexpected variable %r" % k)
320-
elif (k in result_coord_names) != (k in ds.coords):
321-
raise ValueError(
322-
"%r is a coordinate in some datasets but not others." % k
323-
)
324-
elif compat != "override" and k in result_vars and k != dim:
325-
# Don't use Variable.identical as it internally invokes
326-
# Variable.equals, and we may already know the answer
327-
if compat == "identical" and not utils.dict_equiv(
328-
v.attrs, result_vars[k].attrs
329-
):
330-
raise ValueError(
331-
"Variable '%s' is not identical across datasets. "
332-
"You can skip this check by specifying compat='override'." % k
333-
)
334-
335-
# Proceed with equals()
336-
try:
337-
# May be populated when using the "different" method
338-
is_equal = equals[k]
339-
except KeyError:
340-
result_vars[k].load()
341-
is_equal = v.equals(result_vars[k])
342-
if not is_equal:
343-
raise ValueError(
344-
"Variable '%s' is not equal across datasets. "
345-
"You can skip this check by specifying compat='override'." % k
346-
)
347359

360+
##############
361+
# TODO: do this stuff earlier so we loop over datasets only once
362+
#############
348363
# we've already verified everything is consistent; now, calculate
349364
# shared dimension sizes so we can expand the necessary variables
350365
dim_lengths = [ds.dims.get(dim, 1) for ds in datasets]
366+
# non_concat_dims = dim_names - concat_over
351367
non_concat_dims = {}
352368
for ds in datasets:
353369
non_concat_dims.update(ds.dims)
354370
non_concat_dims.pop(dim, None)
355371

372+
# seems like there should be a helper function for this. We would need to add
373+
# an exclude kwarg to exclude comparing along concat_dim
356374
def ensure_common_dims(vars):
357375
# ensure each variable with the given name shares the same
358376
# dimensions and the same shape for all of them except along the

0 commit comments

Comments
 (0)