1
+ from collections import defaultdict
2
+ import xarray as xr
3
+ import numpy as np
4
+
5
+ def get_1d_dims (d ):
6
+ """
7
+ Find all dimensions in a dataset that are purely 1-dimensional,
8
+ i.e., those dimensions that are not part of a 2D or higher-D
9
+ variable.
10
+
11
+ arguments
12
+ d: xarray Dataset
13
+ returns
14
+ dims1d: a list of dimension names
15
+ """
16
+ # Assume all dims coorespond to 1D vars
17
+ dims1d = list (d .dims .keys ())
18
+ for varname , var in d .variables .items ():
19
+ if len (var .dims ) > 1 :
20
+ for vardim in var .dims :
21
+ if vardim in dims1d :
22
+ dims1d .remove (str (vardim ))
23
+ return dims1d
24
+
25
+ def gen_1d_datasets (d ):
26
+ """
27
+ Generate a sequence of datasets having only those variables
28
+ along each dimension that is only used for 1-dimensional variables.
29
+
30
+ arguments
31
+ d: xarray Dataset
32
+ returns
33
+ generator function yielding a sequence of single-dimension datasets
34
+ """
35
+ dims1d = get_1d_dims (d )
36
+ # print(dims1d)
37
+ for dim in dims1d :
38
+ all_dims = list (d .dims .keys ())
39
+ all_dims .remove (dim )
40
+ yield d .drop_dims (all_dims )
41
+
42
+ def get_1d_datasets (d ):
43
+ """
44
+ Generate a list of datasets having only those variables
45
+ along each dimension that is only used for 1-dimensional variables.
46
+
47
+ arguments
48
+ d: xarray Dataset
49
+ returns
50
+ a list of single-dimension datasets
51
+ """
52
+ return [d1 for d1 in gen_1d_datasets (d , * args , ** kwargs )]
53
+
54
+ def get_scalar_vars (d ):
55
+ scalars = []
56
+ for varname , var in d .variables .items ():
57
+ if len (var .dims ) == 0 :
58
+ scalars .append (varname )
59
+ return scalars
60
+
61
+ def concat_1d_dims (datasets , stack_scalars = None ):
62
+ """
63
+ For each xarray Dataset in datasets, concatenate (preserving the order of datasets)
64
+ all variables along dimensions that are only used for one-dimensional variables.
65
+
66
+ arguments
67
+ d: iterable of xarray Datasets
68
+ stack_scalars: create a new dimension named with this value
69
+ that aggregates all scalar variables and coordinates
70
+ returns
71
+ a new xarray Dataset with only the single-dimension variables
72
+ """
73
+ # dictionary mapping dimension names to a list of all
74
+ # datasets having only that dimension
75
+ all_1d_datasets = defaultdict (list )
76
+
77
+ for d in datasets :
78
+ scalars = get_scalar_vars (d )
79
+ for d_1d_initial in gen_1d_datasets (d ):
80
+ # Get rid of scalars
81
+ d_1d = d_1d_initial .drop (scalars )
82
+ dims = tuple (d_1d .dims .keys ())
83
+ all_1d_datasets [dims [0 ]].append (d_1d )
84
+ if stack_scalars :
85
+ # restore scalars along new dimension stack_scalars
86
+ scalar_dataset = xr .Dataset ()
87
+ for scalar_var in scalars :
88
+ # promote from scalar to an array with a dimension, and remove
89
+ # the coordinate info so that it's just a regular variable.
90
+ as_1d = d [scalar_var ].expand_dims (stack_scalars ).reset_coords (drop = True )
91
+ scalar_dataset [scalar_var ] = as_1d # xr.DataArray(as_1d, dims=[stack_scalars])
92
+ all_1d_datasets [stack_scalars ].append (scalar_dataset )
93
+
94
+ unified = xr .Dataset ()
95
+ for dim in all_1d_datasets :
96
+ combined = xr .concat (all_1d_datasets [dim ], dim , coords = 'minimal' , data_vars = 'minimal' )
97
+ unified .update (combined )
98
+ return unified
99
+
100
+ # datasets=[]
101
+ # for i, size in enumerate((4, 6)):
102
+ # a = xr.DataArray(10*i + np.arange(size), dims='x')
103
+ # b = xr.DataArray(10*i + np.arange(size/2), dims='y')
104
+ # c = xr.DataArray(20*i + np.arange(size*3), dims='t')
105
+ # d = xr.DataArray(11*i + np.arange(size*3), dims='t')
106
+ # T = xr.DataArray(10*i + np.arange(size)**2, dims='x')
107
+ # D = xr.DataArray(10*i + np.arange(size/2)**2, dims='y')
108
+ # z = xr.DataArray(10*i + np.arange(size*4)**2, dims='z')
109
+ # u = xr.DataArray(10*i + np.arange(size*5)**2, dims='u')
110
+ # v = xr.DataArray(12*i + np.arange(size*5)**2, dims='u')
111
+ # P = xr.DataArray(10*i + np.ones((size,int(size/2))), dims=['x', 'y'])
112
+ # Q = xr.DataArray(20*i + np.ones((size,int(size/2))), dims=['x', 'y'])
113
+ # d = xr.Dataset({'x':a,'y':b, 't':c, 'd':d, 'u':u, 'v':v, 'z':z, 'T':T, 'D':D, 'P':P, 'Q':Q})
114
+ # datasets.append(d)
115
+ # # datasets.append(d[{'x':slice(None, None), 'y':slice(0,0)}])
116
+ # for d in datasets: print(d,'\n')
117
+ # # xr.combine_by_coords(datasets, coords='all')
118
+ # # xr.combine_nested(datasets, coords='all', data_vars='all')
119
+
120
+ # # print(get_1d_dims(d))
121
+ # assert(get_1d_dims(d)==['t', 'u', 'z'])
122
+ # # for d1 in get_1d_datasets(d):
123
+ # # print(d1,'\n')
124
+
125
+ # combined = concat_1d_dims(datasets)
126
+ # print(combined)
0 commit comments