Commit 30d522d

Add multidimensional binning demo (#203)

* Add nD binning notebook
* fix.

1 parent f2945f0 commit 30d522d

File tree

2 files changed: +374 −0 lines changed
docs/source/user-stories.md (1 addition, 0 deletions)

@@ -8,4 +8,5 @@
 user-stories/climatology.ipynb
 user-stories/climatology-hourly.ipynb
 user-stories/custom-aggregations.ipynb
+user-stories/nD-bins.ipynb
 ```
docs/source/user-stories/nD-bins.ipynb (new file, 373 additions, 0 deletions)
# Binning with multi-dimensional bins

```{warning}
This post is a proof-of-concept for discussion. Expect APIs to change to enable this use case.
```

Here we explore a binning problem where the bins are multidimensional
([xhistogram issue](https://github.com/xgcm/xhistogram/issues/28)):

> One such multi-dim bin application is the ranked probability score `rps` we
> use in `xskillscore.rps`, where we want to know how many forecasts fell into
> which bins. Bins are often defined as terciles of the forecast distribution,
> and the bins for these terciles
> (`forecast_with_lon_lat_time_dims.quantile(q=[.33,.66],dim='time')`) depend on
> `lon` and `lat`.
```python
import math

import numpy as np
import pandas as pd
import xarray as xr

import flox
import flox.xarray
```
## Create test data

Data to be reduced:
```python
array = xr.DataArray(
    np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),
    dims=("space", "time"),
    name="array",
)
array
```
Array to group by:
```python
by = xr.DataArray(
    np.array([[1, 2, 3], [3, 4, 5], [5, 6, 7], [6, 7, 9]]),
    dims=("space", "time"),
    name="by",
)
by
```
Multidimensional bins:
```python
bins = by + 0.5
bins = xr.DataArray(
    np.concatenate([bins, bins[:, [-1]] + 1], axis=-1)[:, :-1].T,
    dims=("time", "nbins"),
    name="bins",
)
bins
```
## Concept

The key idea is that GroupBy is two steps:

1. Factorize (a.k.a. "digitize"): convert the `by` data to a set of integer
   codes representing the bins.
2. Apply the reduction.

We treat multi-dimensional binning as a slightly complicated factorization
problem. Assume that bins are a function of `time`. So we

1. generate a set of appropriate integer codes by:
   1. looping over "time" and factorizing the data appropriately, then
   2. adding an offset to these codes so that "bin 0" for `time=0` is
      different from "bin 0" for `time=1`;
2. apply the groupby reduction to the "offset codes";
3. reshape the output to the right shape.

We will work at the xarray level, so it's easy to keep track of the different
dimensions.

### Factorizing

The core `factorize_` function (which wraps `pd.cut`) only handles 1D bins, so
we use `xr.apply_ufunc` to vectorize it for us.
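Before doing this with flox and xarray, the three steps can be sketched in plain numpy/pandas. This is a toy sketch: `by_2d`, `bins_2d`, and all the numbers in it are made-up stand-ins, not this notebook's data.

```python
import numpy as np
import pandas as pd

# Toy stand-ins (not this notebook's data): 2 timesteps, 4 values each,
# with bin edges that differ per timestep.
by_2d = np.array([[1.0, 2.0, 3.0, 4.0],
                  [2.0, 3.0, 4.0, 5.0]])
bins_2d = np.array([[0.5, 2.5, 4.5],   # edges for time=0 -> 2 bins
                    [1.5, 3.5, 5.5]])  # edges for time=1 -> 2 bins
nbins_2d = bins_2d.shape[1] - 1

# Step 1a: factorize each timestep separately.
codes_2d = np.stack([
    pd.cut(by_2d[t], pd.IntervalIndex.from_breaks(bins_2d[t])).codes
    for t in range(by_2d.shape[0])
])

# Step 1b: offset, so "bin 0" of time=0 differs from "bin 0" of time=1.
offset_codes_2d = codes_2d + nbins_2d * np.arange(by_2d.shape[0])[:, None]

# Step 2: one flat reduction over the offset codes (here: counts per bin).
flat_2d = np.bincount(offset_codes_2d.ravel(), minlength=nbins_2d * by_2d.shape[0])

# Step 3: reshape back to (time, bin).
counts_2d = flat_2d.reshape(by_2d.shape[0], nbins_2d)
print(counts_2d.tolist())  # [[2, 2], [2, 2]]
```

The rest of the notebook does the same thing, but with flox handling the factorization and reduction, and xarray tracking the dimension names.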
```python
factorize_loop_dim = "time"
```
```python
def factorize_nd_bins_core(by, bins):
    group_idx, *_, props = flox.core.factorize_(
        (by,),
        axes=(-1,),
        expected_groups=(pd.IntervalIndex.from_breaks(bins),),
    )
    # Use -1 as the NaN sentinel value
    group_idx[props.nanmask] = -1
    return group_idx


codes = xr.apply_ufunc(
    factorize_nd_bins_core,
    by,
    bins,
    # TODO: avoid hardcoded dim names
    input_core_dims=[["space"], ["nbins"]],
    output_core_dims=[["space"]],
    vectorize=True,
)
codes
```
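For intuition, the 1D factorization that each vectorized call performs can be reproduced with `pd.cut` alone. The values and edges below are a toy example, not this notebook's data; the point is the `-1` sentinel for out-of-bin values.

```python
import numpy as np
import pandas as pd

# Values -> integer bin codes; anything outside every interval gets -1.
values = np.array([1.0, 2.0, 5.0, 10.0])
edges = [1.5, 3.5, 5.5]  # two bins: (1.5, 3.5] and (3.5, 5.5]
codes_1d = pd.cut(values, pd.IntervalIndex.from_breaks(edges)).codes
print(codes_1d.tolist())  # [-1, 0, 1, -1]
```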
### Offset the codes

These are integer codes appropriate for a single timestep.

We now add an offset that changes in time, to make sure "bin 0" for `time=0` is
different from "bin 0" for `time=1` (taken from
[this StackOverflow thread](https://stackoverflow.com/questions/46256279/bin-elements-per-row-vectorized-2d-bincount-for-numpy)).
```python
N = math.prod([codes.sizes[d] for d in codes.dims if d != factorize_loop_dim])
offset = xr.DataArray(np.arange(codes.sizes[factorize_loop_dim]), dims=factorize_loop_dim)
# TODO: think about N-1 here
offset_codes = (codes + offset * (N - 1)).rename(by.name)
offset_codes.data[codes == -1] = -1
offset_codes
```
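A toy illustration of why the `-1` re-assignment above is needed: adding the offset shifts the sentinel too, turning it into a valid-looking code. The numbers below are made up for the illustration.

```python
import numpy as np

codes_toy = np.array([[0, 1, -1],
                      [1, 0, -1]])  # -1 marks out-of-bin values
nbins_toy = 2
offset_toy = np.arange(codes_toy.shape[0])[:, None] * nbins_toy
offset_codes_toy = codes_toy + offset_toy
# Without a fix, row 1's sentinel became 1, i.e. a real bin; restore it:
offset_codes_toy[codes_toy == -1] = -1
print(offset_codes_toy.tolist())  # [[0, 1, -1], [3, 2, -1]]
```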
### Reduce

Now that we have appropriate codes, let's apply the reduction.
```python
interim = flox.xarray.xarray_reduce(
    array,
    offset_codes,
    func="sum",
    # We use RangeIndex to indicate that the `-1` code can be safely ignored
    # (it indicates values outside the bins)
    # TODO: avoid hardcoding 9 = sizes["time"] x (sizes["nbins"] - 1)
    expected_groups=pd.RangeIndex(9),
)
interim
```
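The reduction performed here is equivalent to a flat weighted `np.bincount` over the offset codes, with `-1` codes dropped first. A toy sketch with made-up values:

```python
import numpy as np

vals_toy = np.array([10.0, 20.0, 30.0, 40.0, 50.0])
flat_codes = np.array([0, 0, 1, 2, -1])
keep = flat_codes >= 0  # drop values that fell outside every bin
sums = np.bincount(flat_codes[keep], weights=vals_toy[keep], minlength=3)
print(sums.tolist())  # [30.0, 30.0, 40.0]
```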
## Make final result

Now reshape that 1D result appropriately.
```python
final = (
    interim.coarsen(by=3)
    # bin_number dimension is last; this makes sense since it is the core dimension
    # and we vectorize over the loop dims.
    # So the first (Nbins-1) elements are for the first index of the loop dim
    .construct({"by": (factorize_loop_dim, "bin_number")})
    .transpose(..., factorize_loop_dim)
    .drop_vars("by")
)
final
```
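In plain numpy, the `coarsen(...).construct(...)` call amounts to a reshape followed by a transpose, because the flat group axis is ordered as (time, bin). A toy 3×3 sketch, with `np.arange` standing in for the 9 group sums:

```python
import numpy as np

interim_flat = np.arange(9.0)          # stand-in for the 9 group sums
reshaped = interim_flat.reshape(3, 3)  # -> (time, bin_number)
final_np = reshaped.T                  # -> (bin_number, time), like `final`
print(final_np.tolist())
```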
I think this is the expected answer.
```python
array.isel(space=slice(1, None)).rename({"space": "bin_number"}).identical(final)
```
## TODO

This could be extended to:

1. handle multiple `factorize_loop_dim`s,
2. avoid hardcoded dimension names in the `apply_ufunc` call for factorizing,
3. avoid the hardcoded number of output elements in the `xarray_reduce` call,
4. somehow propagate the bin edges to the final output.
