diff --git a/examples/network_samples/OR_GLM-L2-LCFA_G16_s20233580057200_e20233580057400_c20233580057420.nc b/examples/network_samples/OR_GLM-L2-LCFA_G16_s20233580057200_e20233580057400_c20233580057420.nc new file mode 100644 index 0000000..1abde3e Binary files /dev/null and b/examples/network_samples/OR_GLM-L2-LCFA_G16_s20233580057200_e20233580057400_c20233580057420.nc differ diff --git a/examples/network_samples/OR_GLM-L2-LCFA_G16_s20233580057400_e20233580058000_c20233580058020.nc b/examples/network_samples/OR_GLM-L2-LCFA_G16_s20233580057400_e20233580058000_c20233580058020.nc new file mode 100644 index 0000000..beee790 Binary files /dev/null and b/examples/network_samples/OR_GLM-L2-LCFA_G16_s20233580057400_e20233580058000_c20233580058020.nc differ diff --git a/pyxlma/xarray_util.py b/pyxlma/xarray_util.py index 79b67cc..ccf9cb8 100644 --- a/pyxlma/xarray_util.py +++ b/pyxlma/xarray_util.py @@ -49,7 +49,7 @@ def get_1d_datasets(d): returns a list of single-dimension datasets """ - return [d1 for d1 in gen_1d_datasets(d, *args, **kwargs)] + return [d1 for d1 in gen_1d_datasets(d)] def get_scalar_vars(d): scalars = [] diff --git a/tests/test_xarray_util.py b/tests/test_xarray_util.py new file mode 100644 index 0000000..904d7ff --- /dev/null +++ b/tests/test_xarray_util.py @@ -0,0 +1,44 @@ +from glob import glob +import xarray as xr +from pyxlma.xarray_util import * + +def test_get_1d_dims(): + lma = xr.open_dataset('tests/truth/lma_netcdf/lma_stats.nc') + dims_1d = get_1d_dims(lma) + print(lma) + assert dims_1d == ['number_of_flashes'] + + +def test_gen_1d_datasets(): + lma = xr.open_dataset('tests/truth/lma_netcdf/lma_stats.nc') + datasets_1d = list(gen_1d_datasets(lma)) + assert len(datasets_1d) == 1 + assert list(datasets_1d[0].dims.keys()) == ['number_of_flashes'] + + +def test_get_1d_datasets(): + lma = xr.open_dataset('tests/truth/lma_netcdf/lma_stats.nc') + datasets_1d = get_1d_datasets(lma) + assert len(datasets_1d) == 1 + assert list(datasets_1d[0].dims.keys()) == ['number_of_flashes'] + + +def test_get_scalar_vars(): + lma = xr.open_dataset('tests/truth/lma_netcdf/lma_stats.nc') + scalar_vars = get_scalar_vars(lma) + assert scalar_vars == ['network_center_latitude', 'network_center_longitude', 'network_center_altitude', + 'flash_distance_separation_threshold', 'flash_time_separation_threshold'] + + +def test_concat_1d_dims_no_scalars(): + glm_datasets = [xr.open_dataset(path) for path in glob('examples/network_samples/OR_GLM*.nc')] + concatenated = concat_1d_dims(glm_datasets) + assert dict(concatenated.dims) == {'number_of_events': 29664, 'number_of_groups': 10420, 'number_of_flashes': 606, + 'number_of_time_bounds': 6, 'number_of_wavelength_bounds': 6, 'number_of_field_of_view_bounds': 6} + + +def test_concat_1d_dims_scalars(): + glm_datasets = [xr.open_dataset(path) for path in glob('examples/network_samples/OR_GLM*.nc')] + concatenated = concat_1d_dims(glm_datasets, stack_scalars='scalars') + assert dict(concatenated.dims) == {'number_of_events': 29664, 'number_of_groups': 10420, 'number_of_flashes': 606, 'number_of_time_bounds': 6, + 'number_of_wavelength_bounds': 6, 'number_of_field_of_view_bounds': 6, 'scalars': 3}