Commit dafbef9

committed
Add multidimensional binning demo
1 parent 892af67 commit dafbef9

File tree

1 file changed (+339, -0 lines)

@@ -0,0 +1,339 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e970d800-c612-482a-bb3a-b1eb7ad53d88",
   "metadata": {},
   "source": [
    "# Multi-dimensional Binning\n",
    "\n",
    "```{warning}\n",
    "This post is a proof-of-concept for discussion. Expect APIs to change to enable this use case.\n",
    "```\n",
    "\n",
    "Here we explore a binning problem where the bins are multidimensional\n",
    "([xhistogram issue](https://github.com/xgcm/xhistogram/issues/28)):\n",
    "\n",
    "> One such multi-dim bin application is the ranked probability score `rps` we\n",
    "> use in `xskillscore.rps`, where we want to know how many forecasts fell into\n",
    "> which bins. Bins are often defined as terciles of the forecast distribution,\n",
    "> and the bins for these terciles\n",
    "> (`forecast_with_lon_lat_time_dims.quantile(q=[.33,.66],dim='time')`) depend on\n",
    "> `lon` and `lat`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01f1a2ef-de62-45d0-a04e-343cd78debc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import flox\n",
    "import flox.xarray\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import xarray as xr"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0be3e214-0cf0-426f-8ebb-669cc5322310",
   "metadata": {},
   "source": [
    "## Create test data\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce239000-e053-4fc3-ad14-e9e0160da869",
   "metadata": {},
   "source": [
    "The data to be reduced:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7659c24e-f5a1-4e59-84c0-5ec965ef92d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "array = xr.DataArray(\n",
    "    np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),\n",
    "    dims=(\"space\", \"time\"),\n",
    "    name=\"array\",\n",
    ")\n",
    "array"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da0c0ac9-ad75-42cd-a1ea-99069f5bef00",
   "metadata": {},
   "source": [
    "The array to group by:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4601e744-5d22-447e-97ce-9644198d485e",
   "metadata": {},
   "outputs": [],
   "source": [
    "by = xr.DataArray(\n",
    "    np.array([[1, 2, 3], [3, 4, 5], [5, 6, 7], [6, 7, 9]]),\n",
    "    dims=(\"space\", \"time\"),\n",
    "    name=\"by\",\n",
    ")\n",
    "by"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61c21c94-7b6e-46a6-b9c2-59d7b2d40c81",
   "metadata": {},
   "source": [
    "The multidimensional bins:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "863a1991-ab8d-47c0-aa48-22b422fcea8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "bins = by + 0.5\n",
    "bins = xr.DataArray(\n",
    "    np.concatenate([bins, bins[:, [-1]] + 1], axis=-1)[:, :-1].T,\n",
    "    dims=(\"time\", \"nbins\"),\n",
    ")\n",
    "# array += xr.DataArray([0, 1, 2, 4], dims=\"space\")\n",
    "\n",
    "display(by)\n",
    "display(bins)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e65ecaba-d1cc-4485-ae58-c390cb2ebfab",
   "metadata": {},
   "source": [
    "## Concept\n",
    "\n",
    "The key idea is that GroupBy is two steps:\n",
    "\n",
    "1. Factorize (a.k.a. \"digitize\"): convert the `by` data to a set of integer\n",
    "   codes representing the bins.\n",
    "2. Apply the reduction.\n",
    "\n",
    "We treat multi-dimensional binning as a slightly complicated factorization\n",
    "problem. Assume that bins are a function of `time`. So we\n",
    "\n",
    "1. generate a set of appropriate integer codes by:\n",
    "   1. looping over \"time\" and factorizing the data appropriately, then\n",
    "   2. adding an offset to these codes so that \"bin 0\" for `time=0` is\n",
    "      different from \"bin 0\" for `time=1`;\n",
    "2. apply the groupby reduction to the \"offset codes\";\n",
    "3. reshape the output to the right shape.\n",
    "\n",
    "We will work at the xarray level, so it's easy to keep track of the different\n",
    "dimensions.\n",
    "\n",
    "### Factorizing\n",
    "\n",
    "The core `factorize_` function (which wraps `pd.cut`) only handles 1D bins, so\n",
    "we use `xr.apply_ufunc` to vectorize it for us.\n"
   ]
  },
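  {
   "cell_type": "markdown",
   "id": "concept-sketch-md",
   "metadata": {},
   "source": [
    "As a minimal 1D sketch of these two steps (illustrative only; `pd.cut` and\n",
    "`np.bincount` stand in here for flox's internal factorize and reduce):\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "concept-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: factorize values into integer bin codes\n",
    "values = np.array([1, 3, 5, 7])\n",
    "codes_1d = pd.cut(values, bins=[0, 4, 8], labels=False)  # [0, 0, 1, 1]\n",
    "# Step 2: reduce per code (here, a per-bin sum)\n",
    "np.bincount(codes_1d, weights=values)  # [4., 12.]\n"
   ]
  },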
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa33ab2c-0ecf-4198-a033-2a77f5d83c99",
   "metadata": {},
   "outputs": [],
   "source": [
    "factorize_loop_dim = \"time\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afcddcc1-dd57-461e-a649-1f8bcd30342f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def factorize_nd_bins_core(by, bins):\n",
    "    group_idx, *_, props = flox.core.factorize_(\n",
    "        (by,),\n",
    "        axis=-1,\n",
    "        expected_groups=(pd.IntervalIndex.from_breaks(bins),),\n",
    "    )\n",
    "    # Use -1 as the NaN sentinel value\n",
    "    group_idx[props.nanmask] = -1\n",
    "    return group_idx\n",
    "\n",
    "\n",
    "codes = xr.apply_ufunc(\n",
    "    factorize_nd_bins_core,\n",
    "    by,\n",
    "    bins,\n",
    "    # TODO: avoid hardcoded dim names\n",
    "    input_core_dims=[[\"space\"], [\"nbins\"]],\n",
    "    output_core_dims=[[\"space\"]],\n",
    "    vectorize=True,\n",
    ")\n",
    "codes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1661312a-dc61-4a26-bfd8-12c2dc01eb15",
   "metadata": {},
   "source": [
    "### Offset the codes\n",
    "\n",
    "These are integer codes appropriate for a single timestep.\n",
    "\n",
    "We now add an offset that changes in time, to make sure \"bin 0\" for `time=0` is\n",
    "different from \"bin 0\" for `time=1` (taken from\n",
    "[this StackOverflow thread](https://stackoverflow.com/questions/46256279/bin-elements-per-row-vectorized-2d-bincount-for-numpy)).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e5801cb-a79c-4670-ad10-36bb19f1a6ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = math.prod([codes.sizes[d] for d in codes.dims if d != factorize_loop_dim])\n",
    "offset = xr.DataArray(\n",
    "    np.arange(codes.sizes[factorize_loop_dim]), dims=factorize_loop_dim\n",
    ")\n",
    "# TODO: think about N-1 here\n",
    "offset_codes = (codes + offset * (N - 1)).rename(by.name)\n",
    "offset_codes.data[codes == -1] = -1\n",
    "offset_codes"
   ]
  },
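  {
   "cell_type": "markdown",
   "id": "offset-sketch-md",
   "metadata": {},
   "source": [
    "The offset trick in isolation, as a plain-numpy sketch (a vectorized 2D\n",
    "bincount; the variable names here are illustrative):\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "offset-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each row gets its own block of bin ids, so a single flat bincount\n",
    "# reduces all rows at once.\n",
    "row_codes = np.array([[0, 1, 1], [1, 0, 2]])\n",
    "nbins = 3\n",
    "flat = row_codes + np.arange(2)[:, None] * nbins\n",
    "np.bincount(flat.ravel(), minlength=2 * nbins).reshape(2, nbins)\n"
   ]
  },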
  {
   "cell_type": "markdown",
   "id": "6c06c48b-316b-4a33-9bc3-921acd10bcba",
   "metadata": {},
   "source": [
    "### Reduce\n",
    "\n",
    "Now that we have appropriate codes, let's apply the reduction.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cf1295e-4585-48b9-ac2b-9e00d03b2b9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "interim = flox.xarray.xarray_reduce(\n",
    "    array,\n",
    "    offset_codes,\n",
    "    func=\"sum\",\n",
    "    # We use RangeIndex to indicate that the `-1` code can be safely ignored\n",
    "    # (it indicates values outside the bins).\n",
    "    # TODO: Avoid hardcoding 9 = sizes[\"time\"] x (sizes[\"nbins\"] - 1)\n",
    "    expected_groups=pd.RangeIndex(9),\n",
    ")\n",
    "interim"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3539509b-d9b4-4342-a679-6ada6f285dfb",
   "metadata": {},
   "source": [
    "## Make final result\n",
    "\n",
    "Now reshape that 1D result appropriately.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1389d37-d76d-4a50-9dfb-8710258de3fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "final = (\n",
    "    interim.coarsen(by=3)\n",
    "    # The bin_number dimension is last; this makes sense since it is the core\n",
    "    # dimension and we vectorize over the loop dims. So the first (nbins - 1)\n",
    "    # elements are for the first index of the loop dim.\n",
    "    .construct({\"by\": (factorize_loop_dim, \"bin_number\")})\n",
    "    .transpose(..., factorize_loop_dim)\n",
    "    .drop_vars(\"by\")\n",
    ")\n",
    "final"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a98b5e60-94af-45ae-be1b-4cb47e2d77ba",
   "metadata": {},
   "source": [
    "I think this is the expected answer.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "053a8643-f6d9-4fd1-b014-230fa716449c",
   "metadata": {},
   "outputs": [],
   "source": [
    "array.isel(space=slice(1, None)).rename({\"space\": \"bin_number\"}).identical(\n",
    "    final\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "619ba4c4-7c87-459a-ab86-c187d3a86c67",
   "metadata": {},
   "source": [
    "## TODO\n",
    "\n",
    "This could be extended to:\n",
    "\n",
    "1. Handle multiple `factorize_loop_dim` dimensions.\n",
    "2. Avoid hard-coded dimension names in the `apply_ufunc` call for factorizing.\n",
    "3. Avoid a hard-coded number of output elements in the `xarray_reduce` call.\n",
    "4. Somehow propagate the bin edges to the final output.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "flox-tests",
   "language": "python",
   "name": "conda-env-flox-tests-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
