Skip to content

Commit c3c02cf

Browse files
authored
Implementing math power function for SQL (#2324)
* Implementing POWER function * Delete pv.yaml * Delete build-ballista-docker.sh * Delete ballista.dockerfile * aligning with latest upstream changes * Re-adding docker files * Formatting * Leaving only 64-bit types * Adding tests, remove type conversion * fix for cast * Update functions.rs
1 parent e596236 commit c3c02cf

File tree

12 files changed

+218
-8
lines changed

12 files changed

+218
-8
lines changed

datafusion/core/src/logical_plan/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ pub use expr::{
4343
count, count_distinct, create_udaf, create_udf, date_part, date_trunc, digest,
4444
exists, exp, exprlist_to_fields, floor, in_list, in_subquery, initcap, left, length,
4545
lit, lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim, max, md5, min,
46-
not_exists, not_in_subquery, now, now_expr, nullif, octet_length, or, random,
46+
not_exists, not_in_subquery, now, now_expr, nullif, octet_length, or, power, random,
4747
regexp_match, regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim,
4848
scalar_subquery, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt,
4949
starts_with, strpos, substr, sum, tan, to_hex, to_timestamp_micros,

datafusion/core/src/physical_plan/functions.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,10 @@ pub fn create_physical_fun(
293293
BuiltinScalarFunction::Sqrt => Arc::new(math_expressions::sqrt),
294294
BuiltinScalarFunction::Tan => Arc::new(math_expressions::tan),
295295
BuiltinScalarFunction::Trunc => Arc::new(math_expressions::trunc),
296+
BuiltinScalarFunction::Power => {
297+
Arc::new(|args| make_scalar_function(math_expressions::power)(args))
298+
}
299+
296300
// string functions
297301
BuiltinScalarFunction::Array => Arc::new(array_expressions::array),
298302
BuiltinScalarFunction::Ascii => Arc::new(|args| match args[0].data_type() {

datafusion/core/tests/sql/functions.rs

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,3 +445,129 @@ async fn case_builtin_math_expression() {
445445
assert_batches_sorted_eq!(expected, &results);
446446
}
447447
}
448+
449+
/// End-to-end SQL test for the `power` builtin.
///
/// Builds a single-batch in-memory table named `test` with four nullable
/// numeric columns, runs `power` over each column (with an integer exponent
/// `exp_i = 3` or a float exponent `exp_f = 3.0`) as well as over two pairs of
/// literals, and then checks both the formatted result rows and the Arrow data
/// type of every output column.
#[tokio::test]
450+
async fn test_power() -> Result<()> {
451+
let schema = Arc::new(Schema::new(vec![
452+
// NOTE(review): this column is named "i32" but is declared as Int16 and
// populated from an Int16Array below — confirm whether Int32 was intended.
Field::new("i32", DataType::Int16, true),
453+
Field::new("i64", DataType::Int64, true),
454+
Field::new("f32", DataType::Float32, true),
455+
Field::new("f64", DataType::Float64, true),
456+
]));
457+
458+
// Five rows per column: positive values, zero, a negative value and a null.
let data = RecordBatch::try_new(
459+
schema.clone(),
460+
vec![
461+
Arc::new(Int16Array::from(vec![
462+
Some(2),
463+
Some(5),
464+
Some(0),
465+
Some(-14),
466+
None,
467+
])),
468+
Arc::new(Int64Array::from(vec![
469+
Some(2),
470+
Some(5),
471+
Some(0),
472+
Some(-14),
473+
None,
474+
])),
475+
Arc::new(Float32Array::from(vec![
476+
Some(1.0),
477+
Some(2.5),
478+
Some(0.0),
479+
Some(-14.5),
480+
None,
481+
])),
482+
Arc::new(Float64Array::from(vec![
483+
Some(1.0),
484+
Some(2.5),
485+
Some(0.0),
486+
Some(-14.5),
487+
None,
488+
])),
489+
],
490+
)?;
491+
492+
let table = MemTable::try_new(schema, vec![vec![data]])?;
493+
494+
let ctx = SessionContext::new();
495+
ctx.register_table("test", Arc::new(table))?;
496+
// The subquery adds the constant exponent columns `exp_i` (3) and `exp_f` (3.0)
// so that each base column is paired with either an integer or a float exponent.
let sql = r"SELECT power(i32, exp_i) as power_i32,
497+
power(i64, exp_f) as power_i64,
498+
power(f32, exp_i) as power_f32,
499+
power(f64, exp_f) as power_f64,
500+
power(2, 3) as power_int_scalar,
501+
power(2.5, 3.0) as power_float_scalar
502+
FROM (select test.*, 3 as exp_i, 3.0 as exp_f from test) a";
503+
let actual = execute_to_batches(&ctx, sql).await;
504+
// The null input row is expected to produce nulls in the four column-based
// outputs, while the literal-only outputs are the same on every row.
let expected = vec![
505+
"+-----------+-----------+-----------+-----------+------------------+--------------------+",
506+
"| power_i32 | power_i64 | power_f32 | power_f64 | power_int_scalar | power_float_scalar |",
507+
"+-----------+-----------+-----------+-----------+------------------+--------------------+",
508+
"| 8 | 8 | 1 | 1 | 8 | 15.625 |",
509+
"| 125 | 125 | 15.625 | 15.625 | 8 | 15.625 |",
510+
"| 0 | 0 | 0 | 0 | 8 | 15.625 |",
511+
"| -2744 | -2744 | -3048.625 | -3048.625 | 8 | 15.625 |",
512+
"| | | | | 8 | 15.625 |",
513+
"+-----------+-----------+-----------+-----------+------------------+--------------------+",
514+
];
515+
assert_batches_eq!(expected, &actual);
516+
//dbg!(actual[0].schema().fields());
517+
// Result types: the two integer/integer calls are expected to yield Int64;
// every call involving a float base or float exponent is expected to yield Float64.
assert_eq!(
518+
actual[0]
519+
.schema()
520+
.field_with_name("power_i32")
521+
.unwrap()
522+
.data_type()
523+
.to_owned(),
524+
DataType::Int64
525+
);
526+
assert_eq!(
527+
actual[0]
528+
.schema()
529+
.field_with_name("power_i64")
530+
.unwrap()
531+
.data_type()
532+
.to_owned(),
533+
DataType::Float64
534+
);
535+
assert_eq!(
536+
actual[0]
537+
.schema()
538+
.field_with_name("power_f32")
539+
.unwrap()
540+
.data_type()
541+
.to_owned(),
542+
DataType::Float64
543+
);
544+
assert_eq!(
545+
actual[0]
546+
.schema()
547+
.field_with_name("power_f64")
548+
.unwrap()
549+
.data_type()
550+
.to_owned(),
551+
DataType::Float64
552+
);
553+
assert_eq!(
554+
actual[0]
555+
.schema()
556+
.field_with_name("power_int_scalar")
557+
.unwrap()
558+
.data_type()
559+
.to_owned(),
560+
DataType::Int64
561+
);
562+
assert_eq!(
563+
actual[0]
564+
.schema()
565+
.field_with_name("power_float_scalar")
566+
.unwrap()
567+
.data_type()
568+
.to_owned(),
569+
DataType::Float64
570+
);
571+
572+
Ok(())
573+
}

datafusion/expr/src/built_in_function.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ pub enum BuiltinScalarFunction {
5454
Log10,
5555
/// log2
5656
Log2,
57+
/// power
58+
Power,
5759
/// round
5860
Round,
5961
/// signum
@@ -184,6 +186,7 @@ impl BuiltinScalarFunction {
184186
BuiltinScalarFunction::Log => Volatility::Immutable,
185187
BuiltinScalarFunction::Log10 => Volatility::Immutable,
186188
BuiltinScalarFunction::Log2 => Volatility::Immutable,
189+
BuiltinScalarFunction::Power => Volatility::Immutable,
187190
BuiltinScalarFunction::Round => Volatility::Immutable,
188191
BuiltinScalarFunction::Signum => Volatility::Immutable,
189192
BuiltinScalarFunction::Sin => Volatility::Immutable,
@@ -267,6 +270,7 @@ impl FromStr for BuiltinScalarFunction {
267270
"log" => BuiltinScalarFunction::Log,
268271
"log10" => BuiltinScalarFunction::Log10,
269272
"log2" => BuiltinScalarFunction::Log2,
273+
"power" => BuiltinScalarFunction::Power,
270274
"round" => BuiltinScalarFunction::Round,
271275
"signum" => BuiltinScalarFunction::Signum,
272276
"sin" => BuiltinScalarFunction::Sin,

datafusion/expr/src/expr_fn.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ unary_scalar_expr!(Log2, log2);
282282
unary_scalar_expr!(Log10, log10);
283283
unary_scalar_expr!(Ln, ln);
284284
unary_scalar_expr!(NullIf, nullif);
285+
scalar_expr!(Power, power, base, exponent);
285286

286287
// string functions
287288
scalar_expr!(Ascii, ascii, string);

datafusion/expr/src/function.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,11 @@ pub fn return_type(
217217
}
218218
}),
219219

220+
BuiltinScalarFunction::Power => match &input_expr_types[0] {
221+
DataType::Int64 => Ok(DataType::Int64),
222+
_ => Ok(DataType::Float64),
223+
},
224+
220225
BuiltinScalarFunction::Abs
221226
| BuiltinScalarFunction::Acos
222227
| BuiltinScalarFunction::Asin
@@ -505,6 +510,13 @@ pub fn signature(fun: &BuiltinScalarFunction) -> Signature {
505510
fun.volatility(),
506511
),
507512
BuiltinScalarFunction::Random => Signature::exact(vec![], fun.volatility()),
513+
BuiltinScalarFunction::Power => Signature::one_of(
514+
vec![
515+
TypeSignature::Exact(vec![DataType::Int64, DataType::Int64]),
516+
TypeSignature::Exact(vec![DataType::Float64, DataType::Float64]),
517+
],
518+
fun.volatility(),
519+
),
508520
// math expressions expect 1 argument of type f64 or f32
509521
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
510522
// return the best approximation for it (in f64).

datafusion/physical-expr/src/math_expressions.rs

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@
1717

1818
//! Math expressions
1919
20-
use arrow::array::{Float32Array, Float64Array};
20+
use arrow::array::ArrayRef;
21+
use arrow::array::{Float32Array, Float64Array, Int64Array};
2122
use arrow::datatypes::DataType;
2223
use datafusion_common::ScalarValue;
2324
use datafusion_common::{DataFusionError, Result};
2425
use datafusion_expr::ColumnarValue;
2526
use rand::{thread_rng, Rng};
27+
use std::any::type_name;
2628
use std::iter;
2729
use std::sync::Arc;
2830

@@ -86,6 +88,33 @@ macro_rules! math_unary_function {
8688
};
8789
}
8890

91+
/// Downcasts `$ARG` to a reference to the concrete array type `$ARRAY_TYPE`
/// using `as_any().downcast_ref`.
///
/// * `$ARG` - expression exposing `as_any()`.
/// * `$NAME` - human-readable argument name, used only in the error message.
/// * `$ARRAY_TYPE` - identifier of the target array type.
///
/// If the downcast fails, the expansion builds a `DataFusionError::Internal`
/// ("could not cast <name> to <type name>") and propagates it with `?`, so the
/// macro can only be used where `?` on that error is valid (i.e. inside a
/// function or closure returning a compatible `Result`).
macro_rules! downcast_arg {
92+
($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
93+
$ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
94+
DataFusionError::Internal(format!(
95+
"could not cast {} to {}",
96+
$NAME,
97+
type_name::<$ARRAY_TYPE>()
98+
))
99+
})?
100+
}};
101+
}
102+
103+
/// Applies the two-argument function `$FUNC` element-wise to a pair of arrays
/// and collects the results into a new `$ARRAY_TYPE`.
///
/// * `$ARG1`, `$ARG2` - the two input arrays; both are downcast to the same
///   `$ARRAY_TYPE` via `downcast_arg!`, so a failed downcast propagates an
///   error out of the enclosing function.
/// * `$NAME1`, `$NAME2` - argument names used in the downcast error messages.
/// * `$FUNC` - a block evaluating to a callable taking `(first, second)`.
///
/// Null handling: the output slot is null whenever either input slot is null.
/// The inputs are paired with `zip`, so iteration stops at the shorter array.
macro_rules! make_function_inputs2 {
104+
($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{
105+
let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE);
106+
let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE);
107+
108+
arg1.iter()
109+
.zip(arg2.iter())
110+
.map(|(a1, a2)| match (a1, a2) {
111+
// The second operand is converted with `try_into` to whatever type `$FUNC`
// expects. The `?` here applies to the closure's `Option` return value, so a
// failed conversion silently yields a null output slot instead of an error.
(Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)),
112+
_ => None,
113+
})
114+
.collect::<$ARRAY_TYPE>()
115+
}};
116+
}
117+
89118
math_unary_function!("sqrt", sqrt);
90119
math_unary_function!("sin", sin);
91120
math_unary_function!("cos", cos);
@@ -120,6 +149,33 @@ pub fn random(args: &[ColumnarValue]) -> Result<ColumnarValue> {
120149
Ok(ColumnarValue::Array(Arc::new(array)))
121150
}
122151

152+
/// Element-wise power: raises each value of `args[0]` (the base) to the
/// corresponding value of `args[1]` (the exponent).
///
/// The implementation is chosen from the data type of `args[0]` only:
/// * `Float64` - both arrays are treated as `Float64Array` and combined with
///   `f64::powf`.
/// * `Int64` - both arrays are treated as `Int64Array` and combined with
///   `i64::pow`.
///
/// # Errors
/// Returns `DataFusionError::Internal` when the base has any other data type,
/// or when either array cannot be downcast to the selected array type (for
/// example when the exponent's type differs from the base's).
///
/// NOTE(review): `i64::pow` takes a `u32` exponent; the helper macro converts
/// the exponent with `try_into().ok()?`, so a negative (or otherwise
/// out-of-range) integer exponent appears to produce a null result rather than
/// an error — confirm this is the intended SQL semantics.
/// NOTE(review): `i64::pow` is documented to panic on overflow in debug builds
/// and wrap in release builds; no checked arithmetic is used here — confirm.
/// NOTE(review): `args[0]` / `args[1]` are indexed without a length check;
/// presumably the registered function signature guarantees two arguments —
/// verify against the caller.
pub fn power(args: &[ArrayRef]) -> Result<ArrayRef> {
153+
match args[0].data_type() {
154+
DataType::Float64 => Ok(Arc::new(make_function_inputs2!(
155+
&args[0],
156+
&args[1],
157+
"base",
158+
"exponent",
159+
Float64Array,
160+
{ f64::powf }
161+
)) as ArrayRef),
162+
163+
DataType::Int64 => Ok(Arc::new(make_function_inputs2!(
164+
&args[0],
165+
&args[1],
166+
"base",
167+
"exponent",
168+
Int64Array,
169+
{ i64::pow }
170+
)) as ArrayRef),
171+
172+
other => Err(DataFusionError::Internal(format!(
173+
"Unsupported data type {:?} for function power",
174+
other
175+
))),
176+
}
177+
}
178+
123179
#[cfg(test)]
124180
mod tests {
125181

datafusion/proto/proto/datafusion.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ enum ScalarFunction {
184184
Trim=61;
185185
Upper=62;
186186
Coalesce=63;
187+
Power=64;
187188
}
188189

189190
message ScalarFunctionNode {

datafusion/proto/src/from_proto.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ use datafusion::{
3131
logical_plan::{
3232
abs, acos, ascii, asin, atan, ceil, character_length, chr, concat_expr,
3333
concat_ws_expr, cos, digest, exp, floor, left, ln, log10, log2, now_expr, nullif,
34-
random, regexp_replace, repeat, replace, reverse, right, round, signum, sin,
35-
split_part, sqrt, starts_with, strpos, substr, tan, to_hex, to_timestamp_micros,
36-
to_timestamp_millis, to_timestamp_seconds, translate, trunc,
34+
power, random, regexp_replace, repeat, replace, reverse, right, round, signum,
35+
sin, split_part, sqrt, starts_with, strpos, substr, tan, to_hex,
36+
to_timestamp_micros, to_timestamp_millis, to_timestamp_seconds, translate, trunc,
3737
window_frames::{WindowFrame, WindowFrameBound, WindowFrameUnits},
3838
Column, DFField, DFSchema, DFSchemaRef, Expr, Operator,
3939
},
@@ -466,6 +466,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
466466
ScalarFunction::Translate => Self::Translate,
467467
ScalarFunction::RegexpMatch => Self::RegexpMatch,
468468
ScalarFunction::Coalesce => Self::Coalesce,
469+
ScalarFunction::Power => Self::Power,
469470
}
470471
}
471472
}
@@ -1243,6 +1244,10 @@ pub fn parse_expr(
12431244
.map(|expr| parse_expr(expr, registry))
12441245
.collect::<Result<Vec<_>, _>>()?,
12451246
)),
1247+
ScalarFunction::Power => Ok(power(
1248+
parse_expr(&args[0], registry)?,
1249+
parse_expr(&args[1], registry)?,
1250+
)),
12461251
_ => Err(proto_error(
12471252
"Protobuf deserialization error: Unsupported scalar function",
12481253
)),

datafusion/proto/src/to_proto.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,6 +1069,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
10691069
BuiltinScalarFunction::Translate => Self::Translate,
10701070
BuiltinScalarFunction::RegexpMatch => Self::RegexpMatch,
10711071
BuiltinScalarFunction::Coalesce => Self::Coalesce,
1072+
BuiltinScalarFunction::Power => Self::Power,
10721073
};
10731074

10741075
Ok(scalar_function)

dev/build-ballista-docker.sh

100755100644
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,4 @@ set -e
2121

2222
. ./dev/build-set-env.sh
2323
docker build -t ballista-base:$BALLISTA_VERSION -f dev/docker/ballista-base.dockerfile .
24-
docker build -t ballista:$BALLISTA_VERSION -f dev/docker/ballista.dockerfile .
24+
docker build -t ballista:$BALLISTA_VERSION -f dev/docker/ballista.dockerfile .

dev/docker/ballista.dockerfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ ARG RELEASE_FLAG=--release
2525
FROM ballista-base:0.6.0 AS base
2626
WORKDIR /tmp/ballista
2727
RUN apt-get -y install cmake
28-
RUN cargo install cargo-chef --version 0.1.23
28+
RUN cargo install cargo-chef --version 0.1.34
2929

3030
FROM base as planner
3131
ADD Cargo.toml .
@@ -105,4 +105,4 @@ COPY benchmarks/queries/ /queries/
105105
ENV RUST_LOG=info
106106
ENV RUST_BACKTRACE=full
107107

108-
CMD ["/executor"]
108+
CMD ["/executor"]

0 commit comments

Comments
 (0)