diff --git a/dask_expr/_concat.py b/dask_expr/_concat.py index f99503121..e532659fe 100644 --- a/dask_expr/_concat.py +++ b/dask_expr/_concat.py @@ -15,9 +15,11 @@ Blockwise, Expr, Projection, + ToFrame, are_co_aligned, determine_column_projection, ) +from dask_expr._util import _convert_to_list class Concat(Expr): @@ -235,6 +237,7 @@ def get_columns_or_name(e: Expr): return e.columns if e.ndim == 2 else [e.name] columns = determine_column_projection(self, parent, dependents) + columns = _convert_to_list(columns) columns_frame = [ [col for col in get_columns_or_name(frame) if col in columns] for frame in self._frames @@ -252,18 +255,22 @@ def get_columns_or_name(e: Expr): for frame, cols in zip(self._frames, columns_frame) if len(cols) > 0 ] - return type(parent)( - type(self)( - self.join, - self.ignore_order, - self._kwargs, - self.axis, - self.ignore_unknown_divisions, - self.interleave_partitions, - *frames, - ), - *parent.operands[1:], + result = type(self)( + self.join, + self.ignore_order, + self._kwargs, + self.axis, + self.ignore_unknown_divisions, + self.interleave_partitions, + *frames, ) + if result.columns == _convert_to_list(parent.operand("columns")): + if result.ndim == parent.ndim: + return result + elif result.ndim < parent.ndim: + return ToFrame(result) + + return type(parent)(result, *parent.operands[1:]) class StackPartition(Concat): diff --git a/dask_expr/tests/test_concat.py b/dask_expr/tests/test_concat.py index e75f63e37..394c4c0cc 100644 --- a/dask_expr/tests/test_concat.py +++ b/dask_expr/tests/test_concat.py @@ -339,3 +339,13 @@ def test_concat_series(pdf): expected = concat([df2.y, df2.x], axis=1)[["x", "y"]] assert q.optimize(fuse=False)._name == expected.optimize(fuse=False)._name assert_eq(q, pd.concat([pdf.y, pdf.x, pdf.z], axis=1)[["x", "y"]]) + + +def test_concat_series_and_projection(df, pdf): + result = concat([df.x, df.y], axis=1)["x"] + expected = pd.concat([pdf.x, pdf.y], axis=1)["x"] + assert_eq(result, expected) + + result = concat([df.x, df.y], axis=1)[["x"]] + expected = pd.concat([pdf.x, pdf.y], axis=1)[["x"]] + assert_eq(result, expected)