Skip to content

Commit 016ace3

Browse files
committed
FEAT/FIX: added support for proper stack(Session) (closes #1057)
I consider this as a bug fix because stack(Session) *seemed* to work but did not keep labels (because the Session was considered as a Sequence), which is surprising and caused issue in production code
1 parent 95fd27a commit 016ace3

File tree

3 files changed

+32
-6
lines changed

3 files changed

+32
-6
lines changed

doc/source/changes/version_0_34_1.rst.inc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ New features
2020

2121
* added support for Python 3.11.
2222

23+
* added support for stacking all arrays of a Session by simply doing: `stack(my_session)` instead of
24+
`stack(my_session.items())` (closes :issue:`1057`).
2325

2426
.. _misc:
2527

larray/core/array.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9586,13 +9586,12 @@ def stack(elements=None, axes=None, title=None, meta=None, dtype=None, res_axes=
95869586
95879587
Parameters
95889588
----------
9589-
elements : tuple, list or dict.
9589+
elements : tuple, list, dict or Session.
95909590
Elements to stack. Elements can be scalars, arrays, sessions, (label, value) pairs or a {label: value} mapping.
9591-
In the later case, axis must be defined and cannot be a name only, because we need to have labels order,
9592-
which the mapping does not provide.
95939591
9594-
Stacking sessions will return a new session containing the arrays of all sessions stacked together. An array
9595-
missing in a session will be replaced by NaN.
9592+
Stacking a single session will stack all its arrays in a single array.
9593+
Stacking several sessions will take the corresponding arrays in all the sessions and stack them, returning a
9594+
new session. An array missing in a session will be replaced by NaN.
95969595
axes : str, Axis, Group or sequence of Axis, optional
95979596
Axes to create. If None, defaults to a range() axis.
95989597
title : str, optional
@@ -9733,14 +9732,18 @@ def stack(elements=None, axes=None, title=None, meta=None, dtype=None, res_axes=
97339732
if kwargs:
97349733
elements = kwargs.items()
97359734

9736-
if isinstance(elements, dict):
9735+
if isinstance(elements, (dict, Session)):
97379736
elements = elements.items()
97389737

97399738
if isinstance(elements, Array):
97409739
if axes is None:
97419740
axes = -1
97429741
axes = elements.axes[axes]
97439742
items = elements.items(axes)
9743+
elif isinstance(elements, Session):
9744+
if axes is None:
9745+
axes = 'array'
9746+
items = elements.items()
97449747
elif isinstance(elements, Iterable):
97459748
if not isinstance(elements, Sequence):
97469749
elements = list(elements)

larray/tests/test_session.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,27 @@ def test_arrays():
657657

658658

659659
def test_stack():
660+
# stacking all arrays of a single session
661+
# =======================================
662+
# a) using explicit axis
663+
s = Session(arr1=ndtest(3), arr2=ndtest(3) + 10)
664+
axis = Axis("array=arr2,arr1")
665+
res = stack(s, axis)
666+
expected = stack(arr1=s.arr1, arr2=s.arr2, axes=axis)
667+
assert_larray_equal(res, expected)
668+
669+
# b) using explicit axis name
670+
s = Session(arr1=ndtest(3), arr2=ndtest(3) + 10)
671+
res = stack(s, "array")
672+
expected = stack(arr1=s.arr1, arr2=s.arr2, axes="array")
673+
assert_larray_equal(res, expected)
674+
675+
# c) using not axis information
676+
s = Session(arr1=ndtest(3), arr2=ndtest(3) + 10)
677+
res = stack(s)
678+
expected = stack(arr1=s.arr1, arr2=s.arr2)
679+
assert_larray_equal(res, expected)
680+
660681
# stacking two sessions (it will stack all arrays with corresponding names
661682
s1 = Session(arr1=ndtest(3), arr2=ndtest(3) + 10)
662683
s2 = Session(arr1=ndtest(3), arr2=ndtest(3) + 30)

0 commit comments

Comments
 (0)