Skip to content

Commit 771c20c

Browse files
viiryaalamb
andauthored
Support round() function with two parameters (apache#5807)
* Physical round expression supports two parameters * fix * Update datafusion/physical-expr/src/math_expressions.rs Co-authored-by: Andrew Lamb <[email protected]> * fix format * Add sqllogictests test for math --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 5bc0051 commit 771c20c

File tree

6 files changed

+266
-69
lines changed

6 files changed

+266
-69
lines changed
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
##########
19+
## Math expression Tests
20+
##########
21+
22+
statement ok
23+
CREATE external table aggregate_simple(c1 real, c2 double, c3 boolean) STORED as CSV WITH HEADER ROW LOCATION 'tests/data/aggregate_simple.csv';
24+
25+
# Round
26+
query R
27+
SELECT ROUND(c1) FROM aggregate_simple
28+
----
29+
0
30+
0
31+
0
32+
0
33+
0
34+
0
35+
0
36+
0
37+
0
38+
0
39+
0
40+
0
41+
0
42+
0
43+
0
44+
45+
# Round
46+
query R
47+
SELECT round(c1/3, 2) FROM aggregate_simple order by c1
48+
----
49+
0
50+
0
51+
0
52+
0
53+
0
54+
0
55+
0
56+
0
57+
0
58+
0
59+
0
60+
0
61+
0
62+
0
63+
0
64+
65+
# Round
66+
query R
67+
SELECT round(c1, 4) FROM aggregate_simple order by c1
68+
----
69+
0
70+
0
71+
0
72+
0
73+
0
74+
0
75+
0
76+
0
77+
0
78+
0
79+
0.0001
80+
0.0001
81+
0.0001
82+
0.0001
83+
0.0001
84+
85+
# Round
86+
query RRRRRRRR
87+
SELECT round(125.2345, -3), round(125.2345, -2), round(125.2345, -1), round(125.2345), round(125.2345, 0), round(125.2345, 1), round(125.2345, 2), round(125.2345, 3)
88+
----
89+
0 100 130 125 125 125.2 125.23 125.235

datafusion/expr/src/expr_fn.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ scalar_expr!(
464464
num,
465465
"nearest integer greater than or equal to argument"
466466
);
467-
scalar_expr!(Round, round, num, "round to nearest integer");
467+
nary_scalar_expr!(Round, round, "round to nearest integer");
468468
scalar_expr!(Trunc, trunc, num, "truncate toward zero");
469469
scalar_expr!(Abs, abs, num, "absolute value");
470470
scalar_expr!(Signum, signum, num, "sign of the argument (-1, 0, +1) ");
@@ -766,7 +766,8 @@ mod test {
766766
test_unary_scalar_expr!(Atan, atan);
767767
test_unary_scalar_expr!(Floor, floor);
768768
test_unary_scalar_expr!(Ceil, ceil);
769-
test_unary_scalar_expr!(Round, round);
769+
test_nary_scalar_expr!(Round, round, input);
770+
test_nary_scalar_expr!(Round, round, input, decimal_places);
770771
test_unary_scalar_expr!(Trunc, trunc);
771772
test_unary_scalar_expr!(Abs, abs);
772773
test_unary_scalar_expr!(Signum, signum);

datafusion/physical-expr/src/functions.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,9 @@ pub fn create_physical_fun(
347347
BuiltinScalarFunction::Log10 => Arc::new(math_expressions::log10),
348348
BuiltinScalarFunction::Log2 => Arc::new(math_expressions::log2),
349349
BuiltinScalarFunction::Random => Arc::new(math_expressions::random),
350-
BuiltinScalarFunction::Round => Arc::new(math_expressions::round),
350+
BuiltinScalarFunction::Round => {
351+
Arc::new(|args| make_scalar_function(math_expressions::round)(args))
352+
}
351353
BuiltinScalarFunction::Signum => Arc::new(math_expressions::signum),
352354
BuiltinScalarFunction::Sin => Arc::new(math_expressions::sin),
353355
BuiltinScalarFunction::Sqrt => Arc::new(math_expressions::sqrt),

datafusion/physical-expr/src/math_expressions.rs

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,18 @@ macro_rules! make_function_inputs2 {
113113
})
114114
.collect::<$ARRAY_TYPE>()
115115
}};
116+
($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE1:ident, $ARRAY_TYPE2:ident, $FUNC: block) => {{
117+
let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE1);
118+
let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE2);
119+
120+
arg1.iter()
121+
.zip(arg2.iter())
122+
.map(|(a1, a2)| match (a1, a2) {
123+
(Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)),
124+
_ => None,
125+
})
126+
.collect::<$ARRAY_TYPE1>()
127+
}};
116128
}
117129

118130
math_unary_function!("sqrt", sqrt);
@@ -124,7 +136,6 @@ math_unary_function!("acos", acos);
124136
math_unary_function!("atan", atan);
125137
math_unary_function!("floor", floor);
126138
math_unary_function!("ceil", ceil);
127-
math_unary_function!("round", round);
128139
math_unary_function!("trunc", trunc);
129140
math_unary_function!("abs", abs);
130141
math_unary_function!("signum", signum);
@@ -149,6 +160,59 @@ pub fn random(args: &[ColumnarValue]) -> Result<ColumnarValue> {
149160
Ok(ColumnarValue::Array(Arc::new(array)))
150161
}
151162

163+
/// Round SQL function
164+
pub fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
165+
if args.len() != 1 && args.len() != 2 {
166+
return Err(DataFusionError::Internal(format!(
167+
"round function requires one or two arguments, got {}",
168+
args.len()
169+
)));
170+
}
171+
172+
let mut decimal_places =
173+
&(Arc::new(Int64Array::from_value(0, args[0].len())) as ArrayRef);
174+
175+
if args.len() == 2 {
176+
decimal_places = &args[1];
177+
}
178+
179+
match args[0].data_type() {
180+
DataType::Float64 => Ok(Arc::new(make_function_inputs2!(
181+
&args[0],
182+
decimal_places,
183+
"value",
184+
"decimal_places",
185+
Float64Array,
186+
Int64Array,
187+
{
188+
|value: f64, decimal_places: i64| {
189+
(value * 10.0_f64.powi(decimal_places.try_into().unwrap())).round()
190+
/ 10.0_f64.powi(decimal_places.try_into().unwrap())
191+
}
192+
}
193+
)) as ArrayRef),
194+
195+
DataType::Float32 => Ok(Arc::new(make_function_inputs2!(
196+
&args[0],
197+
decimal_places,
198+
"value",
199+
"decimal_places",
200+
Float32Array,
201+
Int64Array,
202+
{
203+
|value: f32, decimal_places: i64| {
204+
(value * 10.0_f32.powi(decimal_places.try_into().unwrap())).round()
205+
/ 10.0_f32.powi(decimal_places.try_into().unwrap())
206+
}
207+
}
208+
)) as ArrayRef),
209+
210+
other => Err(DataFusionError::Internal(format!(
211+
"Unsupported data type {other:?} for function round"
212+
))),
213+
}
214+
}
215+
152216
/// Power SQL function
153217
pub fn power(args: &[ArrayRef]) -> Result<ArrayRef> {
154218
match args[0].data_type() {
@@ -365,4 +429,40 @@ mod tests {
365429
assert_eq!(floats.value(2), 4.0);
366430
assert_eq!(floats.value(3), 4.0);
367431
}
432+
433+
#[test]
434+
fn test_round_f32() {
435+
let args: Vec<ArrayRef> = vec![
436+
Arc::new(Float32Array::from(vec![125.2345; 10])), // input
437+
Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places
438+
];
439+
440+
let result = round(&args).expect("failed to initialize function round");
441+
let floats =
442+
as_float32_array(&result).expect("failed to initialize function round");
443+
444+
let expected = Float32Array::from(vec![
445+
125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
446+
]);
447+
448+
assert_eq!(floats, &expected);
449+
}
450+
451+
#[test]
452+
fn test_round_f64() {
453+
let args: Vec<ArrayRef> = vec![
454+
Arc::new(Float64Array::from(vec![125.2345; 10])), // input
455+
Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places
456+
];
457+
458+
let result = round(&args).expect("failed to initialize function round");
459+
let floats =
460+
as_float64_array(&result).expect("failed to initialize function round");
461+
462+
let expected = Float64Array::from(vec![
463+
125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
464+
]);
465+
466+
assert_eq!(floats, &expected);
467+
}
368468
}

datafusion/proto/src/logical_plan/from_proto.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1139,7 +1139,12 @@ pub fn parse_expr(
11391139
ScalarFunction::Log10 => Ok(log10(parse_expr(&args[0], registry)?)),
11401140
ScalarFunction::Floor => Ok(floor(parse_expr(&args[0], registry)?)),
11411141
ScalarFunction::Ceil => Ok(ceil(parse_expr(&args[0], registry)?)),
1142-
ScalarFunction::Round => Ok(round(parse_expr(&args[0], registry)?)),
1142+
ScalarFunction::Round => Ok(round(
1143+
args.to_owned()
1144+
.iter()
1145+
.map(|expr| parse_expr(expr, registry))
1146+
.collect::<Result<Vec<_>, _>>()?,
1147+
)),
11431148
ScalarFunction::Trunc => Ok(trunc(parse_expr(&args[0], registry)?)),
11441149
ScalarFunction::Abs => Ok(abs(parse_expr(&args[0], registry)?)),
11451150
ScalarFunction::Signum => Ok(signum(parse_expr(&args[0], registry)?)),

0 commit comments

Comments
 (0)