|
15 | 15 | // specific language governing permissions and limitations
|
16 | 16 | // under the License.
|
17 | 17 |
|
18 |
| -use arrow::array::{new_null_array, BooleanArray}; |
19 |
| -use arrow::compute::kernels::zip::zip; |
20 |
| -use arrow::compute::{and, is_not_null, is_null}; |
21 | 18 | use arrow::datatypes::{DataType, Field, FieldRef};
|
22 |
| -use datafusion_common::{exec_err, internal_err, Result}; |
| 19 | +use datafusion_common::{exec_err, internal_err, plan_err, Result}; |
23 | 20 | use datafusion_expr::binary::try_type_union_resolution;
|
| 21 | +use datafusion_expr::conditional_expressions::CaseBuilder; |
| 22 | +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; |
24 | 23 | use datafusion_expr::{
|
25 |
| - ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, |
| 24 | + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, |
26 | 25 | };
|
27 | 26 | use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
|
28 | 27 | use datafusion_macros::user_doc;
|
@@ -95,61 +94,36 @@ impl ScalarUDFImpl for CoalesceFunc {
|
95 | 94 | Ok(Field::new(self.name(), return_type, nullable).into())
|
96 | 95 | }
|
97 | 96 |
|
98 |
| - /// coalesce evaluates to the first value which is not NULL |
99 |
| - fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { |
100 |
| - let args = args.args; |
101 |
| - // do not accept 0 arguments. |
| 97 | + fn simplify( |
| 98 | + &self, |
| 99 | + args: Vec<Expr>, |
| 100 | + _info: &dyn SimplifyInfo, |
| 101 | + ) -> Result<ExprSimplifyResult> { |
102 | 102 | if args.is_empty() {
|
103 |
| - return exec_err!( |
104 |
| - "coalesce was called with {} arguments. It requires at least 1.", |
105 |
| - args.len() |
106 |
| - ); |
| 103 | + return plan_err!("coalesce must have at least one argument"); |
107 | 104 | }
|
108 |
| - |
109 |
| - let return_type = args[0].data_type(); |
110 |
| - let mut return_array = args.iter().filter_map(|x| match x { |
111 |
| - ColumnarValue::Array(array) => Some(array.len()), |
112 |
| - _ => None, |
113 |
| - }); |
114 |
| - |
115 |
| - if let Some(size) = return_array.next() { |
116 |
| - // start with nulls as default output |
117 |
| - let mut current_value = new_null_array(&return_type, size); |
118 |
| - let mut remainder = BooleanArray::from(vec![true; size]); |
119 |
| - |
120 |
| - for arg in args { |
121 |
| - match arg { |
122 |
| - ColumnarValue::Array(ref array) => { |
123 |
| - let to_apply = and(&remainder, &is_not_null(array.as_ref())?)?; |
124 |
| - current_value = zip(&to_apply, array, ¤t_value)?; |
125 |
| - remainder = and(&remainder, &is_null(array)?)?; |
126 |
| - } |
127 |
| - ColumnarValue::Scalar(value) => { |
128 |
| - if value.is_null() { |
129 |
| - continue; |
130 |
| - } else { |
131 |
| - let last_value = value.to_scalar()?; |
132 |
| - current_value = zip(&remainder, &last_value, ¤t_value)?; |
133 |
| - break; |
134 |
| - } |
135 |
| - } |
136 |
| - } |
137 |
| - if remainder.iter().all(|x| x == Some(false)) { |
138 |
| - break; |
139 |
| - } |
140 |
| - } |
141 |
| - Ok(ColumnarValue::Array(current_value)) |
142 |
| - } else { |
143 |
| - let result = args |
144 |
| - .iter() |
145 |
| - .filter_map(|x| match x { |
146 |
| - ColumnarValue::Scalar(s) if !s.is_null() => Some(x.clone()), |
147 |
| - _ => None, |
148 |
| - }) |
149 |
| - .next() |
150 |
| - .unwrap_or_else(|| args[0].clone()); |
151 |
| - Ok(result) |
| 105 | + if args.len() == 1 { |
| 106 | + return Ok(ExprSimplifyResult::Simplified( |
| 107 | + args.into_iter().next().unwrap(), |
| 108 | + )); |
152 | 109 | }
|
| 110 | + |
| 111 | + let n = args.len(); |
| 112 | + let (init, last_elem) = args.split_at(n - 1); |
| 113 | + let whens = init |
| 114 | + .iter() |
| 115 | + .map(|x| x.clone().is_not_null()) |
| 116 | + .collect::<Vec<_>>(); |
| 117 | + let cases = init.to_vec(); |
| 118 | + Ok(ExprSimplifyResult::Simplified( |
| 119 | + CaseBuilder::new(None, whens, cases, Some(Box::new(last_elem[0].clone()))) |
| 120 | + .end()?, |
| 121 | + )) |
| 122 | + } |
| 123 | + |
| 124 | + /// coalesce evaluates to the first value which is not NULL |
| 125 | + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> { |
| 126 | + internal_err!("coalesce should have been simplified to case") |
153 | 127 | }
|
154 | 128 |
|
155 | 129 | fn short_circuits(&self) -> bool {
|
|
0 commit comments