From d53548e7708852ff76b7486fd496b06045979686 Mon Sep 17 00:00:00 2001
From: Sam Gardner
Date: Mon, 24 Feb 2025 12:52:33 -0600
Subject: [PATCH 1/3] add traversal tests from glmtools

---
 tests/test_traversal.py | 172 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 172 insertions(+)
 create mode 100644 tests/test_traversal.py

diff --git a/tests/test_traversal.py b/tests/test_traversal.py
new file mode 100644
index 0000000..216386e
--- /dev/null
+++ b/tests/test_traversal.py
@@ -0,0 +1,172 @@
+import numpy as np
+from numpy.testing import assert_equal
+import xarray as xr
+
+from pyxlma.lmalib.traversal import OneToManyTraversal
+
+def get_four_level_data():
+    storm_id = [0,1,2]
+    flash_parent_storm_id = [0,0,0,2,2,2,2,2]
+    flash_id = [1,2,3,4,5,6,7,8]
+    stroke_parent_flash_id = [1,1,2,3, 4, 4, 4, 6, 8, 8, 8]
+    stroke_id = [4,6,7,9,13,14,15,19,20,23,46]
+    trig_parent_stroke_id = [4,4,4,4,6, 7, 7, 9,13,13,14,14,15,19,20,20,23,46]
+    trig_id = [1,3,5,8,9,10,12,16,18,19,20,22,23,25,26,30,31,32]
+    trig_parent_storm_id = [0,0,0,0,0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
+    trig_parent_flash_id = [1,1,1,1,1, 2, 2, 3, 4, 4, 4, 4, 4, 6, 8, 8, 8, 8]
+
+    storm_child_flash_count = [3,0,5]
+    flash_child_stroke_count = [2,1,1,3,0,1,0,3]
+    stroke_child_trig_count = [4,1,2,1,2,2,1,1,2,1,1]
+
+    storm_child_stroke_count = [4,0,7]
+    storm_child_trig_count = [8,0,10]
+
+    storm_dim = 'nstorms'
+    flash_dim = 'nflashes'
+    stroke_dim = 'nstrokes'
+    trig_dim = 'ntrigs'
+
+    d = xr.Dataset({
+        'storm_id': (storm_dim, storm_id),
+        'flash_id': (flash_dim, flash_id),
+        'flash_parent_storm_id': (flash_dim, flash_parent_storm_id),
+        'stroke_id': (stroke_dim, stroke_id),
+        'stroke_parent_flash_id': (stroke_dim, stroke_parent_flash_id),
+        'trig_id': (trig_dim, trig_id),
+        'trig_parent_stroke_id': (trig_dim, trig_parent_stroke_id),
+        'trig_parent_flash_id': (trig_dim, trig_parent_flash_id),
+        'trig_parent_storm_id': (trig_dim, trig_parent_storm_id),
+        'storm_child_flash_count': (storm_dim, storm_child_flash_count),
+        'storm_child_stroke_count': (storm_dim, storm_child_stroke_count),
+        'storm_child_trig_count': (storm_dim, storm_child_trig_count),
+        'flash_child_stroke_count': (flash_dim, flash_child_stroke_count),
+        'stroke_child_trig_count': (stroke_dim, stroke_child_trig_count),
+    })
+    d = d.set_coords(['stroke_id', 'flash_id', 'storm_id', 'trig_id',
+                      'stroke_parent_flash_id', 'trig_parent_stroke_id',
+                      'flash_parent_storm_id'])
+    assert len(flash_id) == len(flash_parent_storm_id)
+    assert len(stroke_id) == len(stroke_parent_flash_id)
+    assert len(trig_id) == len(trig_parent_stroke_id)
+    assert sum(storm_child_flash_count) == len(flash_id)
+    assert sum(storm_child_stroke_count) == len(stroke_id)
+    assert sum(storm_child_trig_count) == len(trig_id)
+    assert sum(flash_child_stroke_count) == len(stroke_id)
+    assert sum(stroke_child_trig_count) == len(trig_id)
+    return d
+
+def get_four_level_data_traversal():
+    d = get_four_level_data()
+
+    entity_vars = ('storm_id', 'flash_id', 'stroke_id', 'trig_id')
+    parent_vars = ('flash_parent_storm_id',
+                   'stroke_parent_flash_id',
+                   'trig_parent_stroke_id')
+    traversal = OneToManyTraversal(d, entity_vars, parent_vars)
+    return d, traversal
+
+def get_four_level_empty_data_traversal():
+    d = get_four_level_data()
+    empty_all = {dim:slice(0,0) for dim in d.dims}
+    d = d[empty_all]
+
+    entity_vars = ('storm_id', 'flash_id', 'stroke_id', 'trig_id')
+    parent_vars = ('flash_parent_storm_id',
+                   'stroke_parent_flash_id',
+                   'trig_parent_stroke_id')
+    traversal = OneToManyTraversal(d, entity_vars, parent_vars)
+    return d, traversal
+
+def test_empty_traversal_init():
+    d, traversal = get_four_level_empty_data_traversal()
+
+def test_count_children():
+    d, traversal = get_four_level_data_traversal()
+
+    # validation data
+    storm_child_count = d['storm_child_flash_count'].data
+    flash_child_count = d['flash_child_stroke_count'].data
+    stroke_child_count = d['stroke_child_trig_count'].data
+    storm_child_stroke_count = d['storm_child_stroke_count'].data
+    storm_child_trig_count = d['storm_child_trig_count'].data
+
+    n_storms = traversal.count_children('storm_id')[0]
+    assert_equal(storm_child_count, n_storms)
+    n_flashes = traversal.count_children('flash_id')[0]
+    assert_equal(flash_child_count, n_flashes)
+    n_strokes = traversal.count_children('stroke_id')[0]
+    assert_equal(stroke_child_count, n_strokes)
+
+    all_counts = traversal.count_children('storm_id', 'trig_id')
+    assert_equal(storm_child_count, all_counts[0])
+    assert_equal(flash_child_count, all_counts[1])
+    assert_equal(stroke_child_count, all_counts[2])
+
+    grouper = d.groupby('trig_parent_storm_id').groups
+    count = [len(grouper[eid]) if (eid in grouper) else 0
+             for eid in d['storm_id'].data]
+    assert_equal(storm_child_trig_count, count)
+
+def test_replicate_parent_ids():
+    d, traversal = get_four_level_data_traversal()
+    trig_parent_storm_ids = traversal.replicate_parent_ids('storm_id',
+                                                           'trig_parent_stroke_id')
+    trig_parent_flash_ids = traversal.replicate_parent_ids('flash_id',
+                                                           'trig_parent_stroke_id')
+    trig_parent_stroke_ids = traversal.replicate_parent_ids('stroke_id',
+                                                            'trig_parent_stroke_id')
+
+    assert_equal(d['trig_parent_storm_id'].data, trig_parent_storm_ids)
+    assert_equal(d['trig_parent_flash_id'].data, trig_parent_flash_ids)
+    assert_equal(d['trig_parent_stroke_id'].data, trig_parent_stroke_ids)
+
+
+def test_prune_from_middle():
+    d, traversal = get_four_level_data_traversal()
+
+    reduced_stroke_id = [13,15,23]
+    d = traversal.reduce_to_entities('stroke_id', reduced_stroke_id)
+    reduced_storm_id = [2,]
+    reduced_flash_id = [4,8]
+    reduced_trig_id = [18,19,23,31]
+    assert_equal(d['storm_id'].data, reduced_storm_id)
+    assert_equal(d['flash_id'].data, reduced_flash_id)
+    assert_equal(d['stroke_id'].data, reduced_stroke_id)
+    assert_equal(d['trig_id'].data, reduced_trig_id)
+
+def test_prune_from_bottom():
+    d, traversal = get_four_level_data_traversal()
+
+    trig_idx = slice(7,10)
+    reduced_storm_id = np.unique(d['trig_parent_storm_id'][trig_idx].data)
+    reduced_flash_id = np.unique(d['trig_parent_flash_id'][trig_idx].data)
+    reduced_stroke_id = np.unique(d['trig_parent_stroke_id'][trig_idx].data)
+    reduced_trig_id = d['trig_id'][trig_idx].data
+    d = traversal.reduce_to_entities('trig_id', reduced_trig_id)
+    assert_equal(d['trig_id'].data, reduced_trig_id)
+    assert_equal(d['stroke_id'].data, reduced_stroke_id)
+    assert_equal(d['flash_id'].data, reduced_flash_id)
+    assert_equal(d['storm_id'].data, reduced_storm_id)
+
+def test_prune_from_top():
+    d, traversal = get_four_level_data_traversal()
+    reduced_storm_id = [1,]
+    d = traversal.reduce_to_entities('storm_id', reduced_storm_id)
+    reduced_stroke_id = np.asarray([])
+    reduced_flash_id = np.asarray([])
+    reduced_trig_id = np.asarray([])
+    assert_equal(d['storm_id'], reduced_storm_id)
+    assert_equal(d['flash_id'], reduced_flash_id)
+    assert_equal(d['stroke_id'], reduced_stroke_id)
+    assert_equal(d['trig_id'], reduced_trig_id)
+
+    reduced_storm_id = [2,]
+    d = traversal.reduce_to_entities('storm_id', reduced_storm_id)
+    reduced_flash_id = [4,5,6,7,8]
+    reduced_stroke_id = [13,14,15,19,20,23,46]
+    reduced_trig_id = [18,19,20,22,23,25,26,30,31,32]
+    assert_equal(d['storm_id'].data, reduced_storm_id)
+    assert_equal(d['flash_id'].data, reduced_flash_id)
+    assert_equal(d['stroke_id'].data, reduced_stroke_id)
+    assert_equal(d['trig_id'].data, reduced_trig_id)

From e5b9c95a92cdd61a7d70c5fc305b24454ac3f53c Mon Sep 17 00:00:00 2001
From: Sam Gardner
Date: Mon, 24 Feb 2025 15:03:45 -0600
Subject: [PATCH 2/3] fix traversal for new xarray giving slices to represent
 group indices instead of integers

---
 pyxlma/lmalib/traversal.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/pyxlma/lmalib/traversal.py b/pyxlma/lmalib/traversal.py
index 54cdc01..ec11fa3 100644
--- a/pyxlma/lmalib/traversal.py
+++ b/pyxlma/lmalib/traversal.py
@@ -222,6 +222,8 @@ def reduce_to_entities(self, entity_id_var, entity_ids):
             e_iter = (np.atleast_1d(p_group[eid]) for eid in last_entity_ids
                       if eid in p_group)
             e_idx = list(itertools.chain.from_iterable(e_iter))
+            if isinstance(e_idx[0], slice):
+                e_idx = np.array([e.start for e in e_idx])  # NOTE(review): assumes single-element group slices -- confirm
             if len(e_idx) == 0:
                 # xarray doesn't accept an empty array as a valid index
                 e_idx = slice(0, 0)
@@ -242,6 +244,8 @@ def reduce_to_entities(self, entity_id_var, entity_ids):
             e_iter = (np.atleast_1d(e_group[eid]) for eid in entity_ids
                       if eid in e_group)
             e_idx = list(itertools.chain.from_iterable(e_iter))
+            if isinstance(e_idx[0], slice):
+                e_idx = np.array([e.start for e in e_idx])  # NOTE(review): assumes single-element group slices -- confirm
             last_entity_ids = entity_ids # == dataset[e_var].data
             if len(e_idx) == 0:
                 # xarray doesn't accept an empty array as a valid index

From 31a3f751fdda8745976b34c4f8b4402f9fb5fe13 Mon Sep 17 00:00:00 2001
From: Sam Gardner
Date: Mon, 24 Feb 2025 15:09:14 -0600
Subject: [PATCH 3/3] check if e_idx contains any values before checking first
 type

---
 pyxlma/lmalib/traversal.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyxlma/lmalib/traversal.py b/pyxlma/lmalib/traversal.py
index ec11fa3..d5d19d7 100644
--- a/pyxlma/lmalib/traversal.py
+++ b/pyxlma/lmalib/traversal.py
@@ -222,7 +222,7 @@ def reduce_to_entities(self, entity_id_var, entity_ids):
             e_iter = (np.atleast_1d(p_group[eid]) for eid in last_entity_ids
                       if eid in p_group)
             e_idx = list(itertools.chain.from_iterable(e_iter))
-            if isinstance(e_idx[0], slice):
+            if e_idx and isinstance(e_idx[0], slice):
                 e_idx = np.array([e.start for e in e_idx])  # NOTE(review): assumes single-element group slices -- confirm
             if len(e_idx) == 0:
                 # xarray doesn't accept an empty array as a valid index
@@ -244,7 +244,7 @@ def reduce_to_entities(self, entity_id_var, entity_ids):
             e_iter = (np.atleast_1d(e_group[eid]) for eid in entity_ids
                       if eid in e_group)
             e_idx = list(itertools.chain.from_iterable(e_iter))
-            if isinstance(e_idx[0], slice):
+            if e_idx and isinstance(e_idx[0], slice):
                 e_idx = np.array([e.start for e in e_idx])  # NOTE(review): assumes single-element group slices -- confirm
             last_entity_ids = entity_ids # == dataset[e_var].data
             if len(e_idx) == 0: