-
Notifications
You must be signed in to change notification settings - Fork 1.7k
perf: Optimize CASE for any WHEN false #17835
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
95e6858
b0e3e9b
e370969
befdc48
01e3b6e
a0c4b2e
b25c38f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1436,33 +1436,49 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> { | |
|
||
// CASE WHEN true THEN A ... END --> A | ||
// CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END | ||
// CASE WHEN false THEN A END --> NULL | ||
// CASE WHEN false THEN A ELSE B END --> B | ||
// CASE WHEN X THEN A WHEN false THEN B END --> CASE WHEN X THEN A ELSE B END | ||
Expr::Case(Case { | ||
expr: None, | ||
mut when_then_expr, | ||
else_expr: _, | ||
// if let guard is not stabilized so we can't use it yet: https://github.com/rust-lang/rust/issues/51114 | ||
// Once it's supported we can avoid searching through when_then_expr twice in the below .any() and .position() calls | ||
// }) if let Some(i) = when_then_expr.iter().position(|(when, _)| is_true(when.as_ref())) => { | ||
when_then_expr, | ||
mut else_expr, | ||
}) if when_then_expr | ||
.iter() | ||
.any(|(when, _)| is_true(when.as_ref())) => | ||
.any(|(when, _)| is_true(when.as_ref()) || is_false(when.as_ref())) => | ||
{ | ||
let i = when_then_expr | ||
.iter() | ||
.position(|(when, _)| is_true(when.as_ref())) | ||
.unwrap(); | ||
let (_, then_) = when_then_expr.swap_remove(i); | ||
// CASE WHEN true THEN A ... END --> A | ||
if i == 0 { | ||
return Ok(Transformed::yes(*then_)); | ||
let out_type = info.get_data_type(&when_then_expr[0].1)?; | ||
let mut new_when_then_expr = Vec::with_capacity(when_then_expr.len()); | ||
|
||
for (when, then) in when_then_expr.into_iter() { | ||
if is_true(when.as_ref()) { | ||
// Skip adding the rest of the when-then expressions after WHEN true | ||
// CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END | ||
else_expr = Some(then); | ||
break; | ||
} else if !is_false(when.as_ref()) { | ||
new_when_then_expr.push((when, then)); | ||
} | ||
// else: skip WHEN false cases | ||
} | ||
|
||
// Exclude CASE statement altogether if there are no when-then expressions left | ||
if new_when_then_expr.is_empty() { | ||
// CASE WHEN false THEN A ELSE B END --> B | ||
if let Some(else_expr) = else_expr { | ||
return Ok(Transformed::yes(*else_expr)); | ||
// CASE WHEN false THEN A END --> NULL | ||
} else { | ||
let null = | ||
Expr::Literal(ScalarValue::try_new_null(&out_type)?, None); | ||
return Ok(Transformed::yes(null)); | ||
} | ||
} | ||
|
||
// CASE WHEN X THEN A WHEN TRUE THEN B ... END --> CASE WHEN X THEN A ELSE B END | ||
when_then_expr.truncate(i); | ||
Transformed::yes(Expr::Case(Case { | ||
expr: None, | ||
when_then_expr, | ||
else_expr: Some(then_), | ||
when_then_expr: new_when_then_expr, | ||
else_expr, | ||
})) | ||
} | ||
|
||
|
@@ -3810,53 +3826,53 @@ mod tests { | |
|
||
#[test] | ||
fn simplify_expr_case_when_first_true() { | ||
// CASE WHEN true THEN 1 ELSE x END --> 1 | ||
// CASE WHEN true THEN 1 ELSE c1 END --> 1 | ||
assert_eq!( | ||
simplify(Expr::Case(Case::new( | ||
None, | ||
vec![(Box::new(lit(true)), Box::new(lit(1)),)], | ||
Some(Box::new(col("x"))), | ||
Some(Box::new(col("c1"))), | ||
))), | ||
lit(1) | ||
); | ||
|
||
// CASE WHEN true THEN col("a") ELSE col("b") END --> col("a") | ||
// CASE WHEN true THEN col('a') ELSE col('b') END --> col('a') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was worried about this change, as it seems to potentially change the intent of the test -- to use literals rather than column references. However, I see the issue is that the DataTypes need to match and after spending some time rewriting these tests to use column references rather than literals I think the literals are fine |
||
assert_eq!( | ||
simplify(Expr::Case(Case::new( | ||
None, | ||
vec![(Box::new(lit(true)), Box::new(col("a")),)], | ||
Some(Box::new(col("b"))), | ||
vec![(Box::new(lit(true)), Box::new(lit("a")),)], | ||
Some(Box::new(lit("b"))), | ||
))), | ||
col("a") | ||
lit("a") | ||
); | ||
|
||
// CASE WHEN true THEN col("a") WHEN col("x") > 5 THEN col("b") ELSE col("c") END --> col("a") | ||
// CASE WHEN true THEN col('a') WHEN col('x') > 5 THEN col('b') ELSE col('c') END --> col('a') | ||
assert_eq!( | ||
simplify(Expr::Case(Case::new( | ||
None, | ||
vec![ | ||
(Box::new(lit(true)), Box::new(col("a"))), | ||
(Box::new(col("x").gt(lit(5))), Box::new(col("b"))), | ||
(Box::new(lit(true)), Box::new(lit("a"))), | ||
(Box::new(lit("x").gt(lit(5))), Box::new(lit("b"))), | ||
], | ||
Some(Box::new(col("c"))), | ||
Some(Box::new(lit("c"))), | ||
))), | ||
col("a") | ||
lit("a") | ||
); | ||
|
||
// CASE WHEN true THEN col("a") END --> col("a") (no else clause) | ||
// CASE WHEN true THEN col('a') END --> col('a') (no else clause) | ||
assert_eq!( | ||
simplify(Expr::Case(Case::new( | ||
None, | ||
vec![(Box::new(lit(true)), Box::new(col("a")),)], | ||
vec![(Box::new(lit(true)), Box::new(lit("a")),)], | ||
None, | ||
))), | ||
col("a") | ||
lit("a") | ||
); | ||
|
||
// Negative test: CASE WHEN a THEN 1 ELSE 2 END should not be simplified | ||
// Negative test: CASE WHEN c2 THEN 1 ELSE 2 END should not be simplified | ||
let expr = Expr::Case(Case::new( | ||
None, | ||
vec![(Box::new(col("a")), Box::new(lit(1)))], | ||
vec![(Box::new(col("c2")), Box::new(lit(1)))], | ||
Some(Box::new(lit(2))), | ||
)); | ||
assert_eq!(simplify(expr.clone()), expr); | ||
|
@@ -3869,87 +3885,135 @@ mod tests { | |
)); | ||
assert_ne!(simplify(expr), lit(1)); | ||
|
||
// Negative test: CASE WHEN col("x") > 5 THEN 1 ELSE 2 END should not be simplified | ||
// Negative test: CASE WHEN col('c1') > 5 THEN 1 ELSE 2 END should not be simplified | ||
let expr = Expr::Case(Case::new( | ||
None, | ||
vec![(Box::new(col("x").gt(lit(5))), Box::new(lit(1)))], | ||
vec![(Box::new(col("c1").gt(lit(5))), Box::new(lit(1)))], | ||
Some(Box::new(lit(2))), | ||
)); | ||
assert_eq!(simplify(expr.clone()), expr); | ||
} | ||
|
||
#[test] | ||
fn simplify_expr_case_when_any_true() { | ||
// CASE WHEN x > 0 THEN a WHEN true THEN b ELSE c END --> CASE WHEN x > 0 THEN a ELSE b END | ||
// CASE WHEN c3 > 0 THEN 'a' WHEN true THEN 'b' ELSE 'c' END --> CASE WHEN c3 > 0 THEN 'a' ELSE 'b' END | ||
assert_eq!( | ||
simplify(Expr::Case(Case::new( | ||
None, | ||
vec![ | ||
(Box::new(col("x").gt(lit(0))), Box::new(col("a"))), | ||
(Box::new(lit(true)), Box::new(col("b"))), | ||
(Box::new(col("c3").gt(lit(0))), Box::new(lit("a"))), | ||
(Box::new(lit(true)), Box::new(lit("b"))), | ||
], | ||
Some(Box::new(col("c"))), | ||
Some(Box::new(lit("c"))), | ||
))), | ||
Expr::Case(Case::new( | ||
None, | ||
vec![(Box::new(col("x").gt(lit(0))), Box::new(col("a")))], | ||
Some(Box::new(col("b"))), | ||
vec![(Box::new(col("c3").gt(lit(0))), Box::new(lit("a")))], | ||
Some(Box::new(lit("b"))), | ||
)) | ||
); | ||
|
||
// CASE WHEN x > 0 THEN a WHEN y < 0 THEN b WHEN true THEN c WHEN z = 0 THEN d ELSE e END | ||
// --> CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END | ||
// CASE WHEN c3 > 0 THEN 'a' WHEN c4 < 0 THEN 'b' WHEN true THEN 'c' WHEN c3 = 0 THEN 'd' ELSE 'e' END | ||
// --> CASE WHEN c3 > 0 THEN 'a' WHEN c4 < 0 THEN 'b' ELSE 'c' END | ||
assert_eq!( | ||
simplify(Expr::Case(Case::new( | ||
None, | ||
vec![ | ||
(Box::new(col("x").gt(lit(0))), Box::new(col("a"))), | ||
(Box::new(col("y").lt(lit(0))), Box::new(col("b"))), | ||
(Box::new(lit(true)), Box::new(col("c"))), | ||
(Box::new(col("z").eq(lit(0))), Box::new(col("d"))), | ||
(Box::new(col("c3").gt(lit(0))), Box::new(lit("a"))), | ||
(Box::new(col("c4").lt(lit(0))), Box::new(lit("b"))), | ||
(Box::new(lit(true)), Box::new(lit("c"))), | ||
(Box::new(col("c3").eq(lit(0))), Box::new(lit("d"))), | ||
], | ||
Some(Box::new(col("e"))), | ||
Some(Box::new(lit("e"))), | ||
))), | ||
Expr::Case(Case::new( | ||
None, | ||
vec![ | ||
(Box::new(col("x").gt(lit(0))), Box::new(col("a"))), | ||
(Box::new(col("y").lt(lit(0))), Box::new(col("b"))), | ||
(Box::new(col("c3").gt(lit(0))), Box::new(lit("a"))), | ||
(Box::new(col("c4").lt(lit(0))), Box::new(lit("b"))), | ||
], | ||
Some(Box::new(col("c"))), | ||
Some(Box::new(lit("c"))), | ||
)) | ||
); | ||
|
||
// CASE WHEN x > 0 THEN a WHEN y < 0 THEN b WHEN true THEN c END (no else) | ||
// --> CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END | ||
// CASE WHEN c3 > 0 THEN 1 WHEN c4 < 0 THEN 2 WHEN true THEN 3 END (no else) | ||
// --> CASE WHEN c3 > 0 THEN 1 WHEN c4 < 0 THEN 2 ELSE 3 END | ||
assert_eq!( | ||
simplify(Expr::Case(Case::new( | ||
None, | ||
vec![ | ||
(Box::new(col("x").gt(lit(0))), Box::new(col("a"))), | ||
(Box::new(col("y").lt(lit(0))), Box::new(col("b"))), | ||
(Box::new(lit(true)), Box::new(col("c"))), | ||
(Box::new(col("c3").gt(lit(0))), Box::new(lit(1))), | ||
(Box::new(col("c4").lt(lit(0))), Box::new(lit(2))), | ||
(Box::new(lit(true)), Box::new(lit(3))), | ||
], | ||
None, | ||
))), | ||
Expr::Case(Case::new( | ||
None, | ||
vec![ | ||
(Box::new(col("x").gt(lit(0))), Box::new(col("a"))), | ||
(Box::new(col("y").lt(lit(0))), Box::new(col("b"))), | ||
(Box::new(col("c3").gt(lit(0))), Box::new(lit(1))), | ||
(Box::new(col("c4").lt(lit(0))), Box::new(lit(2))), | ||
], | ||
Some(Box::new(col("c"))), | ||
Some(Box::new(lit(3))), | ||
)) | ||
); | ||
|
||
// Negative test: CASE WHEN x > 0 THEN a WHEN y < 0 THEN b ELSE c END should not be simplified | ||
// Negative test: CASE WHEN c3 > 0 THEN c3 WHEN c4 < 0 THEN 2 ELSE 3 END should not be simplified | ||
let expr = Expr::Case(Case::new( | ||
None, | ||
vec![ | ||
(Box::new(col("x").gt(lit(0))), Box::new(col("a"))), | ||
(Box::new(col("y").lt(lit(0))), Box::new(col("b"))), | ||
(Box::new(col("c3").gt(lit(0))), Box::new(col("c3"))), | ||
(Box::new(col("c4").lt(lit(0))), Box::new(lit(2))), | ||
], | ||
Some(Box::new(col("c"))), | ||
Some(Box::new(lit(3))), | ||
)); | ||
assert_eq!(simplify(expr.clone()), expr); | ||
} | ||
|
||
#[test] | ||
fn simplify_expr_case_when_any_false() { | ||
// CASE WHEN false THEN 'a' END --> NULL | ||
assert_eq!( | ||
simplify(Expr::Case(Case::new( | ||
None, | ||
vec![(Box::new(lit(false)), Box::new(lit("a")))], | ||
None, | ||
))), | ||
Expr::Literal(ScalarValue::Utf8(None), None) | ||
); | ||
|
||
// CASE WHEN false THEN 2 ELSE 1 END --> 1 | ||
assert_eq!( | ||
simplify(Expr::Case(Case::new( | ||
None, | ||
vec![(Box::new(lit(false)), Box::new(lit(2)))], | ||
Some(Box::new(lit(1))), | ||
))), | ||
lit(1), | ||
); | ||
|
||
// CASE WHEN c3 < 10 THEN 'b' WHEN false then c3 ELSE c4 END --> CASE WHEN c3 < 10 THEN b ELSE c4 END | ||
assert_eq!( | ||
simplify(Expr::Case(Case::new( | ||
None, | ||
vec![ | ||
(Box::new(col("c3").lt(lit(10))), Box::new(lit("b"))), | ||
(Box::new(lit(false)), Box::new(col("c3"))), | ||
], | ||
Some(Box::new(col("c4"))), | ||
))), | ||
Expr::Case(Case::new( | ||
None, | ||
vec![(Box::new(col("c3").lt(lit(10))), Box::new(lit("b")))], | ||
Some(Box::new(col("c4"))), | ||
)) | ||
); | ||
|
||
// Negative test: CASE WHEN c3 = 4 THEN 1 ELSE 2 END should not be simplified | ||
let expr = Expr::Case(Case::new( | ||
None, | ||
vec![(Box::new(col("c3").eq(lit(4))), Box::new(lit(1)))], | ||
Some(Box::new(lit(2))), | ||
)); | ||
assert_eq!(simplify(expr.clone()), expr); | ||
} | ||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Introducing this
get_data_type
call made some of the existing tests fail because it was trying to get the data type of a column that didn't exist in the schema. I updated the existing tests to use the actual column names e.g (col("c1")
,col("c3")
) or string literals (e.g lit("a")) instead of the invalid column names (e.gcol("a")
) hence why so many random changes in the old tests. When I ran queries in the CLI, it seemed like Datafusion was catching the invalid column names before it got to this code, so I think this should be safe.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It took me moment to convince myself that we did not need to gate on
!when_then_expr.empty()
to ensurewhen_then_expr[0]
doesn't panic -- and that is because.any()
needs at least one expr to evaluate to true.TLDR this is fine, I am just recording my thought process for anyone else who is interested