|
7 | 7 | import numpy as np
|
8 | 8 |
|
9 | 9 | from xarray.namedarray._typing import (
|
| 10 | + Default, |
10 | 11 | _arrayapi,
|
| 12 | + _Axis, |
| 13 | + _default, |
| 14 | + _Dim, |
11 | 15 | _DType,
|
12 | 16 | _ScalarType,
|
13 | 17 | _ShapeType,
|
@@ -144,3 +148,51 @@ def real(
|
144 | 148 | xp = _get_data_namespace(x)
|
145 | 149 | out = x._new(data=xp.real(x._data))
|
146 | 150 | return out
|
| 151 | + |
| 152 | + |
| 153 | +# %% Manipulation functions |
| 154 | +def expand_dims( |
| 155 | + x: NamedArray[Any, _DType], |
| 156 | + /, |
| 157 | + *, |
| 158 | + dim: _Dim | Default = _default, |
| 159 | + axis: _Axis = 0, |
| 160 | +) -> NamedArray[Any, _DType]: |
| 161 | + """ |
| 162 | + Expands the shape of an array by inserting a new dimension of size one at the |
| 163 | + position specified by dims. |
| 164 | +
|
| 165 | + Parameters |
| 166 | + ---------- |
| 167 | + x : |
| 168 | + Array to expand. |
| 169 | + dim : |
| 170 | + Dimension name. New dimension will be stored in the axis position. |
| 171 | + axis : |
| 172 | + (Not recommended) Axis position (zero-based). Default is 0. |
| 173 | +
|
| 174 | + Returns |
| 175 | + ------- |
| 176 | + out : |
| 177 | + An expanded output array having the same data type as x. |
| 178 | +
|
| 179 | + Examples |
| 180 | + -------- |
| 181 | + >>> x = NamedArray(("x", "y"), nxp.asarray([[1.0, 2.0], [3.0, 4.0]])) |
| 182 | + >>> expand_dims(x) |
| 183 | + <xarray.NamedArray (dim_2: 1, x: 2, y: 2)> |
| 184 | + Array([[[1., 2.], |
| 185 | + [3., 4.]]], dtype=float64) |
| 186 | + >>> expand_dims(x, dim="z") |
| 187 | + <xarray.NamedArray (z: 1, x: 2, y: 2)> |
| 188 | + Array([[[1., 2.], |
| 189 | + [3., 4.]]], dtype=float64) |
| 190 | + """ |
| 191 | + xp = _get_data_namespace(x) |
| 192 | + dims = x.dims |
| 193 | + if dim is _default: |
| 194 | + dim = f"dim_{len(dims)}" |
| 195 | + d = list(dims) |
| 196 | + d.insert(axis, dim) |
| 197 | + out = x._new(dims=tuple(d), data=xp.expand_dims(x._data, axis=axis)) |
| 198 | + return out |
0 commit comments