Skip to content

Commit ed5900b

Browse files
tsanonatiagodcherian
authored
Extend padding functionalities (#9353)
* add functionality to pad dataset data variables with unique constant values * clean up implementation of variable specific padding for dataset. Add tests * more expressive docsting and symplefy type signature with alias in dataset pad func. enforce number values to be converted to tuples for all in `_pad_options_dim_to_index`. make variable pad funtion consistent with dataset. extend tests * fix typing * add terms to conf.py, make docstrings more accurate, expand tests for dataset pad function * filter constant value types without mutating input map * add todo to change default padding for missing variables in constant_values * add changes to whats new --------- Co-authored-by: tiago <[email protected]> Co-authored-by: Deepak Cherian <[email protected]>
1 parent a56a407 commit ed5900b

File tree

6 files changed

+114
-30
lines changed

6 files changed

+114
-30
lines changed

doc/conf.py

+3
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,9 @@
153153
"matplotlib colormap name": ":doc:`matplotlib colormap name <matplotlib:gallery/color/colormap_reference>`",
154154
"matplotlib axes object": ":py:class:`matplotlib axes object <matplotlib.axes.Axes>`",
155155
"colormap": ":py:class:`colormap <matplotlib.colors.Colormap>`",
156+
# xarray terms
157+
"dim name": ":term:`dimension name <name>`",
158+
"var name": ":term:`variable name <name>`",
156159
# objects without namespace: xarray
157160
"DataArray": "~xarray.DataArray",
158161
"Dataset": "~xarray.Dataset",

doc/whats-new.rst

+2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ New Features
2424
~~~~~~~~~~~~
2525
- Make chunk manager an option in ``set_options`` (:pull:`9362`).
2626
By `Tom White <https://github.com/tomwhite>`_.
27+
- Allow data variable specific ``constant_values`` in the dataset ``pad`` function (:pull:`9353``).
28+
By `Tiago Sanona <https://github.com/tsanona>`_.
2729

2830
Breaking changes
2931
~~~~~~~~~~~~~~~~

xarray/core/dataset.py

+22-10
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@
163163
ReindexMethodOptions,
164164
SideOptions,
165165
T_ChunkDimFreq,
166+
T_DatasetPadConstantValues,
166167
T_Xarray,
167168
)
168169
from xarray.core.weighted import DatasetWeighted
@@ -9153,9 +9154,7 @@ def pad(
91539154
stat_length: (
91549155
int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None
91559156
) = None,
9156-
constant_values: (
9157-
float | tuple[float, float] | Mapping[Any, tuple[float, float]] | None
9158-
) = None,
9157+
constant_values: T_DatasetPadConstantValues | None = None,
91599158
end_values: int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None = None,
91609159
reflect_type: PadReflectOptions = None,
91619160
keep_attrs: bool | None = None,
@@ -9211,17 +9210,19 @@ def pad(
92119210
(stat_length,) or int is a shortcut for before = after = statistic
92129211
length for all axes.
92139212
Default is ``None``, to use the entire axis.
9214-
constant_values : scalar, tuple or mapping of hashable to tuple, default: 0
9215-
Used in 'constant'. The values to set the padded values for each
9216-
axis.
9213+
constant_values : scalar, tuple, mapping of dim name to scalar or tuple, or \
9214+
mapping of var name to scalar, tuple or to mapping of dim name to scalar or tuple, default: None
9215+
Used in 'constant'. The values to set the padded values for each data variable / axis.
9216+
``{var_1: {dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}, ...
9217+
var_M: (before, after)}`` unique pad constants per data variable.
92179218
``{dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}`` unique
92189219
pad constants along each dimension.
92199220
``((before, after),)`` yields same before and after constants for each
92209221
dimension.
92219222
``(constant,)`` or ``constant`` is a shortcut for ``before = after = constant`` for
92229223
all dimensions.
9223-
Default is 0.
9224-
end_values : scalar, tuple or mapping of hashable to tuple, default: 0
9224+
Default is ``None``, pads with ``np.nan``.
9225+
end_values : scalar, tuple or mapping of hashable to tuple, default: None
92259226
Used in 'linear_ramp'. The values used for the ending value of the
92269227
linear_ramp and that will form the edge of the padded array.
92279228
``{dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}`` unique
@@ -9230,7 +9231,7 @@ def pad(
92309231
axis.
92319232
``(constant,)`` or ``constant`` is a shortcut for ``before = after = constant`` for
92329233
all axes.
9233-
Default is 0.
9234+
Default is None.
92349235
reflect_type : {"even", "odd", None}, optional
92359236
Used in "reflect", and "symmetric". The "even" style is the
92369237
default with an unaltered reflection around the edge value. For
@@ -9304,11 +9305,22 @@ def pad(
93049305
if not var_pad_width:
93059306
variables[name] = var
93069307
elif name in self.data_vars:
9308+
if utils.is_dict_like(constant_values):
9309+
if name in constant_values.keys():
9310+
filtered_constant_values = constant_values[name]
9311+
elif not set(var.dims).isdisjoint(constant_values.keys()):
9312+
filtered_constant_values = {
9313+
k: v for k, v in constant_values.items() if k in var.dims
9314+
}
9315+
else:
9316+
filtered_constant_values = 0 # TODO: https://github.com/pydata/xarray/pull/9353#discussion_r1724018352
9317+
else:
9318+
filtered_constant_values = constant_values
93079319
variables[name] = var.pad(
93089320
pad_width=var_pad_width,
93099321
mode=mode,
93109322
stat_length=stat_length,
9311-
constant_values=constant_values,
9323+
constant_values=filtered_constant_values,
93129324
end_values=end_values,
93139325
reflect_type=reflect_type,
93149326
keep_attrs=keep_attrs,

xarray/core/types.py

+5
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,11 @@ def copy(
243243
"symmetric",
244244
"wrap",
245245
]
246+
T_PadConstantValues = float | tuple[float, float]
247+
T_VarPadConstantValues = T_PadConstantValues | Mapping[Any, T_PadConstantValues]
248+
T_DatasetPadConstantValues = (
249+
T_VarPadConstantValues | Mapping[Any, T_VarPadConstantValues]
250+
)
246251
PadReflectOptions = Literal["even", "odd", None]
247252

248253
CFCalendar = Literal[

xarray/core/variable.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
Self,
6666
T_Chunks,
6767
T_DuckArray,
68+
T_VarPadConstantValues,
6869
)
6970
from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint
7071

@@ -1121,9 +1122,14 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs):
11211122

11221123
def _pad_options_dim_to_index(
11231124
self,
1124-
pad_option: Mapping[Any, int | tuple[int, int]],
1125+
pad_option: Mapping[Any, int | float | tuple[int, int] | tuple[float, float]],
11251126
fill_with_shape=False,
11261127
):
1128+
# change number values to a tuple of two of those values
1129+
for k, v in pad_option.items():
1130+
if isinstance(v, numbers.Number):
1131+
pad_option[k] = (v, v)
1132+
11271133
if fill_with_shape:
11281134
return [
11291135
(n, n) if d not in pad_option else pad_option[d]
@@ -1138,9 +1144,7 @@ def pad(
11381144
stat_length: (
11391145
int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None
11401146
) = None,
1141-
constant_values: (
1142-
float | tuple[float, float] | Mapping[Any, tuple[float, float]] | None
1143-
) = None,
1147+
constant_values: T_VarPadConstantValues | None = None,
11441148
end_values: int | tuple[int, int] | Mapping[Any, tuple[int, int]] | None = None,
11451149
reflect_type: PadReflectOptions = None,
11461150
keep_attrs: bool | None = None,
@@ -1160,7 +1164,7 @@ def pad(
11601164
stat_length : int, tuple or mapping of hashable to tuple
11611165
Used in 'maximum', 'mean', 'median', and 'minimum'. Number of
11621166
values at edge of each axis used to calculate the statistic value.
1163-
constant_values : scalar, tuple or mapping of hashable to tuple
1167+
constant_values : scalar, tuple or mapping of hashable to scalar or tuple
11641168
Used in 'constant'. The values to set the padded values for each
11651169
axis.
11661170
end_values : scalar, tuple or mapping of hashable to tuple
@@ -1207,10 +1211,6 @@ def pad(
12071211
if stat_length is None and mode in ["maximum", "mean", "median", "minimum"]:
12081212
stat_length = [(n, n) for n in self.data.shape] # type: ignore[assignment]
12091213

1210-
# change integer values to a tuple of two of those values and change pad_width to index
1211-
for k, v in pad_width.items():
1212-
if isinstance(v, numbers.Number):
1213-
pad_width[k] = (v, v)
12141214
pad_width_by_index = self._pad_options_dim_to_index(pad_width)
12151215

12161216
# create pad_options_kwargs, numpy/dask requires only relevant kwargs to be nonempty

xarray/tests/test_dataset.py

+73-11
Original file line numberDiff line numberDiff line change
@@ -6704,18 +6704,80 @@ def test_polyfit_warnings(self) -> None:
67046704
ds.var1.polyfit("dim2", 10, full=True)
67056705
assert len(ws) == 1
67066706

6707-
def test_pad(self) -> None:
6708-
ds = create_test_data(seed=1)
6709-
padded = ds.pad(dim2=(1, 1), constant_values=42)
6710-
6711-
assert padded["dim2"].shape == (11,)
6712-
assert padded["var1"].shape == (8, 11)
6713-
assert padded["var2"].shape == (8, 11)
6714-
assert padded["var3"].shape == (10, 8)
6715-
assert dict(padded.sizes) == {"dim1": 8, "dim2": 11, "dim3": 10, "time": 20}
6707+
@staticmethod
6708+
def _test_data_var_interior(
6709+
original_data_var, padded_data_var, padded_dim_name, expected_pad_values
6710+
):
6711+
np.testing.assert_equal(
6712+
np.unique(padded_data_var.isel({padded_dim_name: [0, -1]})),
6713+
expected_pad_values,
6714+
)
6715+
np.testing.assert_array_equal(
6716+
padded_data_var.isel({padded_dim_name: slice(1, -1)}), original_data_var
6717+
)
67166718

6717-
np.testing.assert_equal(padded["var1"].isel(dim2=[0, -1]).data, 42)
6718-
np.testing.assert_equal(padded["dim2"][[0, -1]].data, np.nan)
6719+
@pytest.mark.parametrize("padded_dim_name", ["dim1", "dim2", "dim3", "time"])
6720+
@pytest.mark.parametrize(
6721+
["constant_values"],
6722+
[
6723+
pytest.param(None, id="default"),
6724+
pytest.param(42, id="scalar"),
6725+
pytest.param((42, 43), id="tuple"),
6726+
pytest.param({"dim1": 42, "dim2": 43}, id="per dim scalar"),
6727+
pytest.param({"dim1": (42, 43), "dim2": (43, 44)}, id="per dim tuple"),
6728+
pytest.param({"var1": 42, "var2": (42, 43)}, id="per var"),
6729+
pytest.param({"var1": 42, "dim1": (42, 43)}, id="mixed"),
6730+
],
6731+
)
6732+
def test_pad(self, padded_dim_name, constant_values) -> None:
6733+
ds = create_test_data(seed=1)
6734+
padded = ds.pad({padded_dim_name: (1, 1)}, constant_values=constant_values)
6735+
6736+
# test padded dim values and size
6737+
for ds_dim_name, ds_dim in ds.sizes.items():
6738+
if ds_dim_name == padded_dim_name:
6739+
np.testing.assert_equal(padded.sizes[ds_dim_name], ds_dim + 2)
6740+
if ds_dim_name in padded.coords:
6741+
assert padded[ds_dim_name][[0, -1]].isnull().all()
6742+
else:
6743+
np.testing.assert_equal(padded.sizes[ds_dim_name], ds_dim)
6744+
6745+
# check if coord "numbers" with dimention dim3 is paded correctly
6746+
if padded_dim_name == "dim3":
6747+
assert padded["numbers"][[0, -1]].isnull().all()
6748+
# twarning: passes but dtype changes from int to float
6749+
np.testing.assert_array_equal(padded["numbers"][1:-1], ds["numbers"])
6750+
6751+
# test if data_vars are paded with correct values
6752+
for data_var_name, data_var in padded.data_vars.items():
6753+
if padded_dim_name in data_var.dims:
6754+
if utils.is_dict_like(constant_values):
6755+
if (
6756+
expected := constant_values.get(data_var_name, None)
6757+
) is not None:
6758+
self._test_data_var_interior(
6759+
ds[data_var_name], data_var, padded_dim_name, expected
6760+
)
6761+
elif (
6762+
expected := constant_values.get(padded_dim_name, None)
6763+
) is not None:
6764+
self._test_data_var_interior(
6765+
ds[data_var_name], data_var, padded_dim_name, expected
6766+
)
6767+
else:
6768+
self._test_data_var_interior(
6769+
ds[data_var_name], data_var, padded_dim_name, 0
6770+
)
6771+
elif constant_values:
6772+
self._test_data_var_interior(
6773+
ds[data_var_name], data_var, padded_dim_name, constant_values
6774+
)
6775+
else:
6776+
self._test_data_var_interior(
6777+
ds[data_var_name], data_var, padded_dim_name, np.nan
6778+
)
6779+
else:
6780+
assert_array_equal(data_var, ds[data_var_name])
67196781

67206782
@pytest.mark.parametrize(
67216783
["keep_attrs", "attrs", "expected"],

0 commit comments

Comments
 (0)