From 9029cf7eac2816ab19bbf2f182f602d9743bb5e6 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 13 Aug 2024 16:07:44 -0400 Subject: [PATCH 1/2] Add when function --- python/datafusion/functions.py | 11 +++++++++++ src/functions.rs | 9 +++++++++ 2 files changed, 20 insertions(+) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 59a1974fd..ec0c1104d 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -245,6 +245,7 @@ "var", "var_pop", "var_samp", + "when", "window", ] @@ -364,6 +365,16 @@ def case(expr: Expr) -> CaseBuilder: return CaseBuilder(f.case(expr.expr)) +def when(when: Expr, then: Expr) -> CaseBuilder: + """Create a case expression that has no base expression. + + Create a :py:class:`~datafusion.expr.CaseBuilder` whose first branch yields + ``then`` when ``when`` is true. See :py:class:`~datafusion.expr.CaseBuilder` for + detailed usage. + """ + return CaseBuilder(f.when(when.expr, then.expr)) + + def window( name: str, args: list[Expr], diff --git a/src/functions.rs b/src/functions.rs index c53d4ad92..252563621 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -533,6 +533,14 @@ fn case(expr: PyExpr) -> PyResult { }) } +/// Create a CASE WHEN statement with no base expression, starting from an initial WHEN/THEN pair. +#[pyfunction] +fn when(when: PyExpr, then: PyExpr) -> PyResult<PyCaseBuilder> { + Ok(PyCaseBuilder { + case_builder: datafusion_expr::when(when.expr, then.expr), + }) +} + /// Helper function to find the appropriate window function. 
/// /// Search procedure: @@ -910,6 +918,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(char_length))?; m.add_wrapped(wrap_pyfunction!(coalesce))?; m.add_wrapped(wrap_pyfunction!(case))?; + m.add_wrapped(wrap_pyfunction!(when))?; m.add_wrapped(wrap_pyfunction!(col))?; m.add_wrapped(wrap_pyfunction!(concat_ws))?; m.add_wrapped(wrap_pyfunction!(concat))?; From 5236b5d0a33bd26fe23d1c41c483b7627715cd37 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 25 Aug 2024 09:06:42 -0400 Subject: [PATCH 2/2] Add unit test for when statements that have no base case statement --- python/datafusion/tests/test_functions.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/python/datafusion/tests/test_functions.py b/python/datafusion/tests/test_functions.py index 732136eaa..e5429bd60 100644 --- a/python/datafusion/tests/test_functions.py +++ b/python/datafusion/tests/test_functions.py @@ -836,6 +836,25 @@ def test_case(df): assert result.column(2) == pa.array(["Hola", "Mundo", None]) +def test_when_with_no_base(df): + df.show() + df = df.select( + column("b"), + f.when(column("b") > literal(5), literal("too big")) + .when(column("b") < literal(5), literal("too small")) + .otherwise(literal("just right")) + .alias("goldilocks"), + f.when(column("a") == literal("Hello"), column("a")).end().alias("greeting"), + ) + df.show() + + result = df.collect() + result = result[0] + assert result.column(0) == pa.array([4, 5, 6]) + assert result.column(1) == pa.array(["too small", "just right", "too big"]) + assert result.column(2) == pa.array(["Hello", None, None]) + + def test_regr_funcs_sql(df): # test case base on # https://github.com/apache/arrow-datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2330