5
5
6
6
from . import dtypes , utils
7
7
from .alignment import align
8
+ from .merge import (
9
+ determine_coords ,
10
+ merge_variables ,
11
+ expand_variable_dicts ,
12
+ _VALID_COMPAT ,
13
+ )
8
14
from .variable import IndexVariable , Variable , as_variable
9
15
from .variable import concat as concat_vars
10
16
@@ -65,7 +71,7 @@ def concat(
65
71
in addition to the 'minimal' coordinates.
66
72
compat : {'equals', 'identical', 'override'}, optional
67
73
String indicating how to compare non-concatenated variables and
68
- dataset global attributes for potential conflicts.
74
+ dataset global attributes for potential conflicts. This is passed down to merge.
69
75
* 'equals' means that all variable values and dimensions must be the same;
70
76
* 'identical' means that variable attributes and global attributes
71
77
must also be equal.
@@ -145,7 +151,7 @@ def concat(
145
151
"`data_vars` and `coords` arguments"
146
152
)
147
153
148
- if compat not in [ "equals" , "identical" , "override" , "no_conflicts" ] :
154
+ if compat not in _VALID_COMPAT :
149
155
raise ValueError (
150
156
"compat=%r invalid: must be 'equals', 'identical or 'override'" % compat
151
157
)
@@ -186,22 +192,28 @@ def _calc_concat_dim_coord(dim):
186
192
return dim , coord
187
193
188
194
189
- def _calc_concat_over (datasets , dim , data_vars , coords ):
195
+ def _calc_concat_over (datasets , dim , dim_names , data_vars , coords , compat ):
190
196
"""
191
197
Determine which dataset variables need to be concatenated in the result,
192
- and which can simply be taken from the first dataset.
193
198
"""
194
199
# Return values
195
200
concat_over = set ()
196
201
equals = {}
197
202
198
- if dim in datasets [ 0 ] :
203
+ if dim in dim_names :
199
204
concat_over_existing_dim = True
200
205
concat_over .add (dim )
201
206
else :
202
207
concat_over_existing_dim = False
203
208
204
209
for ds in datasets :
210
+ if concat_over_existing_dim :
211
+ if dim not in ds .dims :
212
+ # TODO: why did I do this
213
+ if dim in ds :
214
+ ds = ds .set_coords (dim )
215
+ else :
216
+ raise ValueError ("%r is not a dimension in all datasets" % dim )
205
217
concat_over .update (k for k , v in ds .variables .items () if dim in v .dims )
206
218
207
219
def process_subset_opt (opt , subset ):
@@ -225,7 +237,7 @@ def process_subset_opt(opt, subset):
225
237
for ds_rhs in datasets [1 :]:
226
238
v_rhs = ds_rhs .variables [k ].compute ()
227
239
computed .append (v_rhs )
228
- if not v_lhs . equals (v_rhs ):
240
+ if not getattr ( v_lhs , compat ) (v_rhs ):
229
241
concat_over .add (k )
230
242
equals [k ] = False
231
243
# computed variables are not to be re-computed
@@ -291,68 +303,74 @@ def _dataset_concat(
291
303
* datasets , join = join , copy = False , exclude = [dim ], fill_value = fill_value
292
304
)
293
305
294
- concat_over , equals = _calc_concat_over (datasets , dim , data_vars , coords )
306
+ # determine dimensional coordinate names and a dict mapping name to DataArray
307
+ def determine_dims (datasets , result_coord_names ):
308
+ dims = set ()
309
+ coords = dict ()
310
+ for ds in datasets :
311
+ for dim in set (ds .dims ) - dims :
312
+ if dim not in coords :
313
+ coords [dim ] = ds .coords [dim ].variable
314
+ dims = dims | set (ds .dims )
315
+ return dims , coords
316
+
317
+ result_coord_names , noncoord_names = determine_coords (datasets )
318
+ both_data_and_coords = result_coord_names & noncoord_names
319
+ if both_data_and_coords :
320
+ raise ValueError (
321
+ "%r is a coordinate in some datasets and not in others."
322
+ % list (both_data_and_coords )[0 ] # preserve previous behaviour
323
+ )
324
+ dim_names , result_coords = determine_dims (datasets , result_coord_names )
325
+ # we don't want the concat dimension in the result dataset yet
326
+ result_coords .pop (dim , None )
295
327
296
- def insert_result_variable (k , v ):
297
- assert isinstance (v , Variable )
298
- if k in datasets [0 ].coords :
299
- result_coord_names .add (k )
300
- result_vars [k ] = v
328
+ # determine which variables to concatenate
329
+ concat_over , equals = _calc_concat_over (
330
+ datasets , dim , dim_names , data_vars , coords , compat
331
+ )
332
+
333
+ # determine which variables to merge
334
+ variables_to_merge = (result_coord_names | noncoord_names ) - concat_over - dim_names
335
+ if variables_to_merge :
336
+ to_merge = []
337
+ for ds in datasets :
338
+ to_merge .append (ds .reset_coords ()[list (variables_to_merge )])
339
+ # TODO: Provide equals as an argument and thread that down to merge.unique_variable
340
+ result_vars = merge_variables (
341
+ expand_variable_dicts (to_merge ), priority_vars = None , compat = compat
342
+ )
343
+ else :
344
+ result_vars = OrderedDict ()
345
+ result_vars .update (result_coords )
301
346
302
- # create the new dataset and add constant variables
303
- result_vars = OrderedDict ()
304
- result_coord_names = set (datasets [0 ].coords )
347
+ # assign attrs and encoding from first dataset
305
348
result_attrs = datasets [0 ].attrs
306
349
result_encoding = datasets [0 ].encoding
307
350
308
- for k , v in datasets [ 0 ]. variables . items ( ):
309
- if k not in concat_over :
310
- insert_result_variable ( k , v )
351
+ def insert_result_variable ( k , v ):
352
+ assert isinstance ( v , Variable )
353
+ result_vars [ k ] = v
311
354
312
- # check that global attributes and non-concatenated variables are fixed
313
- # across all datasets
355
+ # check that global attributes are fixed across all datasets if necessary
314
356
for ds in datasets [1 :]:
315
357
if compat == "identical" and not utils .dict_equiv (ds .attrs , result_attrs ):
316
358
raise ValueError ("Dataset global attributes are not equal." )
317
- for k , v in ds .variables .items ():
318
- if k not in result_vars and k not in concat_over :
319
- raise ValueError ("Encountered unexpected variable %r" % k )
320
- elif (k in result_coord_names ) != (k in ds .coords ):
321
- raise ValueError (
322
- "%r is a coordinate in some datasets but not others." % k
323
- )
324
- elif compat != "override" and k in result_vars and k != dim :
325
- # Don't use Variable.identical as it internally invokes
326
- # Variable.equals, and we may already know the answer
327
- if compat == "identical" and not utils .dict_equiv (
328
- v .attrs , result_vars [k ].attrs
329
- ):
330
- raise ValueError (
331
- "Variable '%s' is not identical across datasets. "
332
- "You can skip this check by specifying compat='override'." % k
333
- )
334
-
335
- # Proceed with equals()
336
- try :
337
- # May be populated when using the "different" method
338
- is_equal = equals [k ]
339
- except KeyError :
340
- result_vars [k ].load ()
341
- is_equal = v .equals (result_vars [k ])
342
- if not is_equal :
343
- raise ValueError (
344
- "Variable '%s' is not equal across datasets. "
345
- "You can skip this check by specifying compat='override'." % k
346
- )
347
359
360
+ ##############
361
+ # TODO: do this stuff earlier so we loop over datasets only once
362
+ #############
348
363
# we've already verified everything is consistent; now, calculate
349
364
# shared dimension sizes so we can expand the necessary variables
350
365
dim_lengths = [ds .dims .get (dim , 1 ) for ds in datasets ]
366
+ # non_concat_dims = dim_names - concat_over
351
367
non_concat_dims = {}
352
368
for ds in datasets :
353
369
non_concat_dims .update (ds .dims )
354
370
non_concat_dims .pop (dim , None )
355
371
372
+ # seems like there should be a helper function for this. We would need to add
373
+ # an exclude kwarg to exclude comparing along concat_dim
356
374
def ensure_common_dims (vars ):
357
375
# ensure each variable with the given name shares the same
358
376
# dimensions and the same shape for all of them except along the
0 commit comments