diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 3de5a80a9782..97dfc09c4f2a 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -1399,6 +1399,16 @@ impl TreeNodeRewriter for Simplifier<'_, S> {
             // Rules for Case
             //
 
+            // CASE WHEN true THEN A ... END --> A (later WHEN branches and ELSE can never be selected)
+            Expr::Case(Case {
+                expr: None,
+                mut when_then_expr,
+                else_expr: _,
+            }) if !when_then_expr.is_empty() && is_true(when_then_expr[0].0.as_ref()) => {
+                let (_, then_) = when_then_expr.swap_remove(0); // O(1) take; the rest of the Vec is dropped, so its order no longer matters
+                Transformed::yes(*then_)
+            }
+
             // CASE
             //   WHEN X THEN A
             //   WHEN Y THEN B
@@ -3552,6 +3562,76 @@ mod tests {
         );
     }
 
+    #[test]
+    fn simplify_expr_case_when_true() {
+        // CASE WHEN true THEN 1 ELSE x END --> 1
+        assert_eq!(
+            simplify(Expr::Case(Case::new(
+                None,
+                vec![(Box::new(lit(true)), Box::new(lit(1)),)],
+                Some(Box::new(col("x"))),
+            ))),
+            lit(1)
+        );
+
+        // CASE WHEN true THEN col("a") ELSE col("b") END --> col("a")
+        assert_eq!(
+            simplify(Expr::Case(Case::new(
+                None,
+                vec![(Box::new(lit(true)), Box::new(col("a")),)],
+                Some(Box::new(col("b"))),
+            ))),
+            col("a")
+        );
+
+        // CASE WHEN true THEN col("a") WHEN col("x") > 5 THEN col("b") ELSE col("c") END --> col("a")
+        assert_eq!(
+            simplify(Expr::Case(Case::new(
+                None,
+                vec![
+                    (Box::new(lit(true)), Box::new(col("a"))),
+                    (Box::new(col("x").gt(lit(5))), Box::new(col("b"))),
+                ],
+                Some(Box::new(col("c"))),
+            ))),
+            col("a")
+        );
+
+        // CASE WHEN true THEN col("a") END --> col("a") (no else clause)
+        assert_eq!(
+            simplify(Expr::Case(Case::new(
+                None,
+                vec![(Box::new(lit(true)), Box::new(col("a")),)],
+                None,
+            ))),
+            col("a")
+        );
+
+        // Negative test: CASE WHEN a THEN 1 ELSE 2 END should not be simplified
+        let expr = Expr::Case(Case::new(
+            None,
+            vec![(Box::new(col("a")), Box::new(lit(1)))],
+            Some(Box::new(lit(2))),
+        ));
+        assert_eq!(simplify(expr.clone()), expr);
+
+        // Negative test: CASE WHEN false THEN 1 ELSE 2 END should not use this rule
+        // (this only pins that the THEN value is not chosen; the exact simplified form is left open)
+        let expr = Expr::Case(Case::new(
+            None,
+            vec![(Box::new(lit(false)), Box::new(lit(1)))],
+            Some(Box::new(lit(2))),
+        ));
+        assert_ne!(simplify(expr), lit(1));
+
+        // Negative test: CASE WHEN col("x") > 5 THEN 1 ELSE 2 END should not be simplified
+        let expr = Expr::Case(Case::new(
+            None,
+            vec![(Box::new(col("x").gt(lit(5))), Box::new(lit(1)))],
+            Some(Box::new(lit(2))),
+        ));
+        assert_eq!(simplify(expr.clone()), expr);
+    }
+
     fn distinct_from(left: impl Into<Expr>, right: impl Into<Expr>) -> Expr {
         Expr::BinaryExpr(BinaryExpr {
             left: Box::new(left.into()),
diff --git a/datafusion/sqllogictest/test_files/projection.slt b/datafusion/sqllogictest/test_files/projection.slt
index 97ebe2340dc2..9f840e7bdc2f 100644
--- a/datafusion/sqllogictest/test_files/projection.slt
+++ b/datafusion/sqllogictest/test_files/projection.slt
@@ -253,7 +253,7 @@ physical_plan
 statement ok
 drop table t;
 
-# Regression test for 
+# Regression test for
 # https://github.com/apache/datafusion/issues/17513
 
 query I
diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt
index 989f7df7b49e..cd1f90c42efd 100644
--- a/datafusion/sqllogictest/test_files/select.slt
+++ b/datafusion/sqllogictest/test_files/select.slt
@@ -1656,10 +1656,10 @@ query TT
 explain select coalesce(1, y/x), coalesce(2, y/x) from t;
 ----
 logical_plan
-01)Projection: CASE WHEN Boolean(true) THEN Int64(1) ELSE CAST(t.y / t.x AS Int64) END AS coalesce(Int64(1),t.y / t.x), CASE WHEN Boolean(true) THEN Int64(2) ELSE CAST(t.y / t.x AS Int64) END AS coalesce(Int64(2),t.y / t.x)
-02)--TableScan: t projection=[x, y]
+01)Projection: Int64(1) AS coalesce(Int64(1),t.y / t.x), Int64(2) AS coalesce(Int64(2),t.y / t.x)
+02)--TableScan: t projection=[]
 physical_plan
-01)ProjectionExec: expr=[CASE WHEN true THEN 1 ELSE CAST(y@1 / x@0 AS Int64) END as coalesce(Int64(1),t.y / t.x), CASE WHEN true THEN 2 ELSE CAST(y@1 / x@0 AS Int64) END as coalesce(Int64(2),t.y / t.x)]
+01)ProjectionExec: expr=[1 as coalesce(Int64(1),t.y / t.x), 2 as coalesce(Int64(2),t.y / t.x)]
 02)--DataSourceExec: partitions=1, partition_sizes=[1]
 
 query TT