Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 128 additions & 64 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1436,33 +1436,49 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {

// CASE WHEN true THEN A ... END --> A
// CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END
// CASE WHEN false THEN A END --> NULL
// CASE WHEN false THEN A ELSE B END --> B
// CASE WHEN X THEN A WHEN false THEN B END --> CASE WHEN X THEN A ELSE B END
Expr::Case(Case {
expr: None,
mut when_then_expr,
else_expr: _,
// if let guard is not stabilized so we can't use it yet: https://github.com/rust-lang/rust/issues/51114
// Once it's supported we can avoid searching through when_then_expr twice in the below .any() and .position() calls
// }) if let Some(i) = when_then_expr.iter().position(|(when, _)| is_true(when.as_ref())) => {
when_then_expr,
mut else_expr,
}) if when_then_expr
.iter()
.any(|(when, _)| is_true(when.as_ref())) =>
.any(|(when, _)| is_true(when.as_ref()) || is_false(when.as_ref())) =>
{
let i = when_then_expr
.iter()
.position(|(when, _)| is_true(when.as_ref()))
.unwrap();
let (_, then_) = when_then_expr.swap_remove(i);
// CASE WHEN true THEN A ... END --> A
if i == 0 {
return Ok(Transformed::yes(*then_));
let out_type = info.get_data_type(&when_then_expr[0].1)?;
Copy link
Contributor Author

@petern48 petern48 Sep 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Introducing this `get_data_type` call made some of the existing tests fail, because it tried to get the data type of a column that didn't exist in the schema. I updated the existing tests to use actual column names (e.g. `col("c1")`, `col("c3")`) or string literals (e.g. `lit("a")`) instead of the invalid column names (e.g. `col("a")`), which is why there are so many seemingly unrelated changes in the old tests. When I ran queries in the CLI, it looked like DataFusion catches invalid column names before reaching this code, so I think this should be safe.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It took me a moment to convince myself that we do not need to gate on `!when_then_expr.is_empty()` to ensure `when_then_expr[0]` doesn't panic -- that is because `.any()` can only return true when there is at least one expr for the predicate to match.

TL;DR: this is fine; I am just recording my thought process for anyone else who is interested.

let mut new_when_then_expr = Vec::with_capacity(when_then_expr.len());

for (when, then) in when_then_expr.into_iter() {
if is_true(when.as_ref()) {
// Skip adding the rest of the when-then expressions after WHEN true
// CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END
else_expr = Some(then);
break;
} else if !is_false(when.as_ref()) {
new_when_then_expr.push((when, then));
}
// else: skip WHEN false cases
}

// Exclude CASE statement altogether if there are no when-then expressions left
if new_when_then_expr.is_empty() {
// CASE WHEN false THEN A ELSE B END --> B
if let Some(else_expr) = else_expr {
return Ok(Transformed::yes(*else_expr));
// CASE WHEN false THEN A END --> NULL
} else {
let null =
Expr::Literal(ScalarValue::try_new_null(&out_type)?, None);
return Ok(Transformed::yes(null));
}
}

// CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END
when_then_expr.truncate(i);
Transformed::yes(Expr::Case(Case {
expr: None,
when_then_expr,
else_expr: Some(then_),
when_then_expr: new_when_then_expr,
else_expr,
}))
}

Expand Down Expand Up @@ -3810,53 +3826,53 @@ mod tests {

#[test]
fn simplify_expr_case_when_first_true() {
// CASE WHEN true THEN 1 ELSE x END --> 1
// CASE WHEN true THEN 1 ELSE c1 END --> 1
assert_eq!(
simplify(Expr::Case(Case::new(
None,
vec![(Box::new(lit(true)), Box::new(lit(1)),)],
Some(Box::new(col("x"))),
Some(Box::new(col("c1"))),
))),
lit(1)
);

// CASE WHEN true THEN col("a") ELSE col("b") END --> col("a")
// CASE WHEN true THEN 'a' ELSE 'b' END --> 'a'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was worried about this change, as it seems to potentially change the intent of the test -- to use literals rather than column references.

However, I see that the issue is that the DataTypes need to match, and after spending some time rewriting these tests to use column references rather than literals, I think the literals are fine.

assert_eq!(
simplify(Expr::Case(Case::new(
None,
vec![(Box::new(lit(true)), Box::new(col("a")),)],
Some(Box::new(col("b"))),
vec![(Box::new(lit(true)), Box::new(lit("a")),)],
Some(Box::new(lit("b"))),
))),
col("a")
lit("a")
);

// CASE WHEN true THEN col("a") WHEN col("x") > 5 THEN col("b") ELSE col("c") END --> col("a")
// CASE WHEN true THEN 'a' WHEN 'x' > 5 THEN 'b' ELSE 'c' END --> 'a'
assert_eq!(
simplify(Expr::Case(Case::new(
None,
vec![
(Box::new(lit(true)), Box::new(col("a"))),
(Box::new(col("x").gt(lit(5))), Box::new(col("b"))),
(Box::new(lit(true)), Box::new(lit("a"))),
(Box::new(lit("x").gt(lit(5))), Box::new(lit("b"))),
],
Some(Box::new(col("c"))),
Some(Box::new(lit("c"))),
))),
col("a")
lit("a")
);

// CASE WHEN true THEN col("a") END --> col("a") (no else clause)
// CASE WHEN true THEN 'a' END --> 'a' (no else clause)
assert_eq!(
simplify(Expr::Case(Case::new(
None,
vec![(Box::new(lit(true)), Box::new(col("a")),)],
vec![(Box::new(lit(true)), Box::new(lit("a")),)],
None,
))),
col("a")
lit("a")
);

// Negative test: CASE WHEN a THEN 1 ELSE 2 END should not be simplified
// Negative test: CASE WHEN c2 THEN 1 ELSE 2 END should not be simplified
let expr = Expr::Case(Case::new(
None,
vec![(Box::new(col("a")), Box::new(lit(1)))],
vec![(Box::new(col("c2")), Box::new(lit(1)))],
Some(Box::new(lit(2))),
));
assert_eq!(simplify(expr.clone()), expr);
Expand All @@ -3869,87 +3885,135 @@ mod tests {
));
assert_ne!(simplify(expr), lit(1));

// Negative test: CASE WHEN col("x") > 5 THEN 1 ELSE 2 END should not be simplified
// Negative test: CASE WHEN col('c1') > 5 THEN 1 ELSE 2 END should not be simplified
let expr = Expr::Case(Case::new(
None,
vec![(Box::new(col("x").gt(lit(5))), Box::new(lit(1)))],
vec![(Box::new(col("c1").gt(lit(5))), Box::new(lit(1)))],
Some(Box::new(lit(2))),
));
assert_eq!(simplify(expr.clone()), expr);
}

#[test]
fn simplify_expr_case_when_any_true() {
// Checks the rewrite for a literal `WHEN true` branch that is not the first
// branch: its THEN value is expected to become the ELSE of the simplified
// CASE, and every branch after it (plus the original ELSE) is dropped.
//
// NOTE(review): this span appears to interleave pre-change and post-change
// lines from a diff view (the `x`/`y`/`z`/`a`..`e` column variants look like
// the old lines, the `c3`/`c4`/literal variants like the new ones) -- confirm
// against the actual file before relying on it.
// CASE WHEN x > 0 THEN a WHEN true THEN b ELSE c END --> CASE WHEN x > 0 THEN a ELSE b END
// CASE WHEN c3 > 0 THEN 'a' WHEN true THEN 'b' ELSE 'c' END --> CASE WHEN c3 > 0 THEN 'a' ELSE 'b' END
assert_eq!(
simplify(Expr::Case(Case::new(
None,
vec![
(Box::new(col("x").gt(lit(0))), Box::new(col("a"))),
(Box::new(lit(true)), Box::new(col("b"))),
(Box::new(col("c3").gt(lit(0))), Box::new(lit("a"))),
(Box::new(lit(true)), Box::new(lit("b"))),
],
Some(Box::new(col("c"))),
Some(Box::new(lit("c"))),
))),
Expr::Case(Case::new(
None,
vec![(Box::new(col("x").gt(lit(0))), Box::new(col("a")))],
Some(Box::new(col("b"))),
vec![(Box::new(col("c3").gt(lit(0))), Box::new(lit("a")))],
Some(Box::new(lit("b"))),
))
);

// A `WHEN true` in the middle: the branch listed after it does not appear
// in the expected output.
// CASE WHEN x > 0 THEN a WHEN y < 0 THEN b WHEN true THEN c WHEN z = 0 THEN d ELSE e END
// --> CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END
// CASE WHEN c3 > 0 THEN 'a' WHEN c4 < 0 THEN 'b' WHEN true THEN 'c' WHEN c3 = 0 THEN 'd' ELSE 'e' END
// --> CASE WHEN c3 > 0 THEN 'a' WHEN c4 < 0 THEN 'b' ELSE 'c' END
assert_eq!(
simplify(Expr::Case(Case::new(
None,
vec![
(Box::new(col("x").gt(lit(0))), Box::new(col("a"))),
(Box::new(col("y").lt(lit(0))), Box::new(col("b"))),
(Box::new(lit(true)), Box::new(col("c"))),
(Box::new(col("z").eq(lit(0))), Box::new(col("d"))),
(Box::new(col("c3").gt(lit(0))), Box::new(lit("a"))),
(Box::new(col("c4").lt(lit(0))), Box::new(lit("b"))),
(Box::new(lit(true)), Box::new(lit("c"))),
(Box::new(col("c3").eq(lit(0))), Box::new(lit("d"))),
],
Some(Box::new(col("e"))),
Some(Box::new(lit("e"))),
))),
Expr::Case(Case::new(
None,
vec![
(Box::new(col("x").gt(lit(0))), Box::new(col("a"))),
(Box::new(col("y").lt(lit(0))), Box::new(col("b"))),
(Box::new(col("c3").gt(lit(0))), Box::new(lit("a"))),
(Box::new(col("c4").lt(lit(0))), Box::new(lit("b"))),
],
Some(Box::new(col("c"))),
Some(Box::new(lit("c"))),
))
);

// Same shape but with no ELSE on the input: the expected output still gains
// an ELSE taken from the `WHEN true` branch.
// CASE WHEN x > 0 THEN a WHEN y < 0 THEN b WHEN true THEN c END (no else)
// --> CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END
// CASE WHEN c3 > 0 THEN 1 WHEN c4 < 0 THEN 2 WHEN true THEN 3 END (no else)
// --> CASE WHEN c3 > 0 THEN 1 WHEN c4 < 0 THEN 2 ELSE 3 END
assert_eq!(
simplify(Expr::Case(Case::new(
None,
vec![
(Box::new(col("x").gt(lit(0))), Box::new(col("a"))),
(Box::new(col("y").lt(lit(0))), Box::new(col("b"))),
(Box::new(lit(true)), Box::new(col("c"))),
(Box::new(col("c3").gt(lit(0))), Box::new(lit(1))),
(Box::new(col("c4").lt(lit(0))), Box::new(lit(2))),
(Box::new(lit(true)), Box::new(lit(3))),
],
None,
))),
Expr::Case(Case::new(
None,
vec![
(Box::new(col("x").gt(lit(0))), Box::new(col("a"))),
(Box::new(col("y").lt(lit(0))), Box::new(col("b"))),
(Box::new(col("c3").gt(lit(0))), Box::new(lit(1))),
(Box::new(col("c4").lt(lit(0))), Box::new(lit(2))),
],
Some(Box::new(col("c"))),
Some(Box::new(lit(3))),
))
);

// The negative case asserts that `simplify` returns the input expression
// unchanged when no WHEN condition is a boolean literal.
// Negative test: CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END should not be simplified
// Negative test: CASE WHEN c3 > 0 THEN c3 WHEN c4 < 0 THEN 2 ELSE 3 END should not be simplified
let expr = Expr::Case(Case::new(
None,
vec![
(Box::new(col("x").gt(lit(0))), Box::new(col("a"))),
(Box::new(col("y").lt(lit(0))), Box::new(col("b"))),
(Box::new(col("c3").gt(lit(0))), Box::new(col("c3"))),
(Box::new(col("c4").lt(lit(0))), Box::new(lit(2))),
],
Some(Box::new(col("c"))),
Some(Box::new(lit(3))),
));
assert_eq!(simplify(expr.clone()), expr);
}

#[test]
fn simplify_expr_case_when_any_false() {
    // Builds a searched CASE (no base expression) from `(when, then)` pairs
    // and an optional ELSE, so each scenario below reads as data.
    fn searched_case(
        branches: Vec<(Box<Expr>, Box<Expr>)>,
        otherwise: Option<Box<Expr>>,
    ) -> Expr {
        Expr::Case(Case::new(None, branches, otherwise))
    }

    // The only branch is `WHEN false` and there is no ELSE: the expected
    // result is a NULL literal of the THEN expression's type (Utf8 here).
    // CASE WHEN false THEN 'a' END --> NULL
    let lone_false_without_else =
        searched_case(vec![(Box::new(lit(false)), Box::new(lit("a")))], None);
    assert_eq!(
        simplify(lone_false_without_else),
        Expr::Literal(ScalarValue::Utf8(None), None)
    );

    // The only branch is `WHEN false` but an ELSE exists: the ELSE value is
    // the expected result.
    // CASE WHEN false THEN 2 ELSE 1 END --> 1
    let lone_false_with_else = searched_case(
        vec![(Box::new(lit(false)), Box::new(lit(2)))],
        Some(Box::new(lit(1))),
    );
    assert_eq!(simplify(lone_false_with_else), lit(1));

    // A `WHEN false` branch among others is removed; the remaining branch and
    // the ELSE are expected to be kept as they were.
    // CASE WHEN c3 < 10 THEN 'b' WHEN false THEN c3 ELSE c4 END
    //   --> CASE WHEN c3 < 10 THEN 'b' ELSE c4 END
    let false_after_real_branch = searched_case(
        vec![
            (Box::new(col("c3").lt(lit(10))), Box::new(lit("b"))),
            (Box::new(lit(false)), Box::new(col("c3"))),
        ],
        Some(Box::new(col("c4"))),
    );
    let only_real_branch_left = searched_case(
        vec![(Box::new(col("c3").lt(lit(10))), Box::new(lit("b")))],
        Some(Box::new(col("c4"))),
    );
    assert_eq!(simplify(false_after_real_branch), only_real_branch_left);

    // Negative test: no WHEN condition is a boolean literal, so the
    // expression must come back unchanged.
    // CASE WHEN c3 = 4 THEN 1 ELSE 2 END should not be simplified
    let untouched = searched_case(
        vec![(Box::new(col("c3").eq(lit(4))), Box::new(lit(1)))],
        Some(Box::new(lit(2))),
    );
    assert_eq!(simplify(untouched.clone()), untouched);
}
Expand Down