From 5ff4b5e36f410f791f869af9064b53239f012dd7 Mon Sep 17 00:00:00 2001 From: Adam Gutglick Date: Tue, 16 Sep 2025 16:59:18 +0100 Subject: [PATCH 1/2] SQL stuff --- datafusion/sql/src/expr/value.rs | 49 +++++++++++++++++++++---- datafusion/sql/tests/sql_integration.rs | 22 ++++++++--- 2 files changed, 59 insertions(+), 12 deletions(-) diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 7075a1afd9dd..a48227693b0f 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -21,6 +21,7 @@ use arrow::compute::kernels::cast_utils::{ }; use arrow::datatypes::{ i256, DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, + DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, }; use bigdecimal::num_bigint::BigInt; use bigdecimal::{BigDecimal, Signed, ToPrimitive}; @@ -372,7 +373,32 @@ fn parse_decimal(unsigned_number: &str, negative: bool) -> Result { } else { digits }; - if precision <= DECIMAL128_MAX_PRECISION as u64 { + + if precision <= DECIMAL32_MAX_PRECISION as u64 { + let val = int_val.to_i32().ok_or_else(|| { + // Failures are unexpected here as we have already checked the precision + internal_datafusion_err!( + "Unexpected overflow when converting {} to i32", + int_val + ) + })?; + Ok(Expr::Literal( + ScalarValue::Decimal32(Some(val), precision as u8, scale as i8), + None, + )) + } else if precision <= DECIMAL64_MAX_PRECISION as u64 { + let val = int_val.to_i64().ok_or_else(|| { + // Failures are unexpected here as we have already checked the precision + internal_datafusion_err!( + "Unexpected overflow when converting {} to i64", + int_val + ) + })?; + Ok(Expr::Literal( + ScalarValue::Decimal64(Some(val), precision as u8, scale as i8), + None, + )) + } else if precision <= DECIMAL128_MAX_PRECISION as u64 { let val = int_val.to_i128().ok_or_else(|| { // Failures are unexpected here as we have already checked the precision internal_datafusion_err!( @@ -460,15 +486,24 @@ mod tests { fn test_parse_decimal() { // Supported cases let cases = [ - ("0", ScalarValue::Decimal128(Some(0), 1, 0)), - ("1", ScalarValue::Decimal128(Some(1), 1, 0)), - ("123.45", ScalarValue::Decimal128(Some(12345), 5, 2)), + ("0", ScalarValue::Decimal32(Some(0), 1, 0)), + ("1", ScalarValue::Decimal32(Some(1), 1, 0)), + ("123.45", ScalarValue::Decimal32(Some(12345), 5, 2)), // Digit count is less than scale - ("0.001", ScalarValue::Decimal128(Some(1), 3, 3)), + ("0.001", ScalarValue::Decimal32(Some(1), 3, 3)), // Scientific notation - ("123.456e-2", ScalarValue::Decimal128(Some(123456), 6, 5)), + ("123.456e-2", ScalarValue::Decimal32(Some(123456), 6, 5)), // Negative scale - ("123456e128", ScalarValue::Decimal128(Some(123456), 6, -128)), + ("123456e128", ScalarValue::Decimal32(Some(123456), 6, -128)), + // Decimal128 + ( + &("9".repeat(19) + "." + "99999"), + ScalarValue::Decimal128( + Some(i128::from_str(&"9".repeat(24)).unwrap()), + 24, + 5, + ), + ), // Decimal256 ( &("9".repeat(39) + "." + "99999"), diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index f66af28f436e..0e5409398e60 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -87,7 +87,7 @@ fn parse_decimals_3() { assert_snapshot!( plan, @r" - Projection: Decimal128(Some(1),1,1) + Projection: Decimal32(Some(1),1,1) EmptyRelation: rows=1 " ); @@ -101,7 +101,7 @@ fn parse_decimals_4() { assert_snapshot!( plan, @r" - Projection: Decimal128(Some(1),2,2) + Projection: Decimal32(Some(1),2,2) EmptyRelation: rows=1 " ); @@ -115,7 +115,7 @@ fn parse_decimals_5() { assert_snapshot!( plan, @r" - Projection: Decimal128(Some(10),2,1) + Projection: Decimal32(Some(10),2,1) EmptyRelation: rows=1 " ); @@ -129,7 +129,7 @@ fn parse_decimals_6() { assert_snapshot!( plan, @r" - Projection: Decimal128(Some(1001),4,2) + Projection: Decimal32(Some(1001),4,2) EmptyRelation: rows=1 " ); @@ -318,7 +318,19 @@ fn test_int_decimal_no_scale() { assert_snapshot!( plan, @r" - Projection: CAST(Int64(10) AS Decimal128(5, 0)) + Projection: CAST(Int64(10) AS Decimal32(5, 0)) + EmptyRelation: rows=1 + " + ); +} + +#[test] +fn test_int_decimal128_no_scale() { + let plan = logical_plan("SELECT CAST(10 AS DECIMAL(20))").unwrap(); + assert_snapshot!( + plan, + @r" + Projection: CAST(Int64(10) AS Decimal128(20, 0)) EmptyRelation: rows=1 " ); From b7273d1024d84dc055640ef2e69a856e52c5fcf4 Mon Sep 17 00:00:00 2001 From: Adam Gutglick Date: Wed, 24 Sep 2025 22:52:54 +0100 Subject: [PATCH 2/2] Rest of the stuff --- datafusion/common/src/scalar/mod.rs | 6 + .../expr-common/src/type_coercion/binary.rs | 97 ++++++++++- .../type_coercion/binary/tests/arithmetic.rs | 90 +++++++++++ .../type_coercion/binary/tests/comparison.rs | 152 ++++++++++++++++++ datafusion/functions/src/math/abs.rs | 12 +- .../src/joins/sort_merge_join/stream.rs | 6 + .../spark/src/function/math/width_bucket.rs | 9 +- datafusion/sql/src/expr/value.rs | 18 +++ datafusion/sql/src/utils.rs | 15 +- .../sqllogictest/src/engines/conversion.rs | 22 ++- .../engines/datafusion_engine/normalize.rs | 10 ++ .../sqllogictest/test_files/aggregate.slt | 63 ++++++-- .../sqllogictest/test_files/decimal.slt | 58 +++---- .../sqllogictest/test_files/explain_tree.slt | 2 +- .../sqllogictest/test_files/functions.slt | 2 +- datafusion/sqllogictest/test_files/joins.slt | 10 +- datafusion/sqllogictest/test_files/math.slt | 6 +- .../sqllogictest/test_files/options.slt | 12 +- .../sqllogictest/test_files/predicates.slt | 6 +- .../sqllogictest/test_files/qualify.slt | 10 +- .../test_files/spark/math/mod.slt | 2 +- .../sqllogictest/test_files/subquery.slt | 4 +- 22 files changed, 524 insertions(+), 88 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 8c079056e21d..bba994dd11b5 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1362,6 +1362,12 @@ impl ScalarValue { DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(0.0))), DataType::Float32 => ScalarValue::Float32(Some(0.0)), DataType::Float64 => ScalarValue::Float64(Some(0.0)), + DataType::Decimal32(precision, scale) => { + ScalarValue::Decimal32(Some(0), *precision, *scale) + } + DataType::Decimal64(precision, scale) => { + ScalarValue::Decimal64(Some(0), *precision, *scale) + } DataType::Decimal128(precision, scale) => { ScalarValue::Decimal128(Some(0), *precision, *scale) } diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index 1c99f49d26cf..4ef18ece560d 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -351,12 +351,21 @@ fn math_decimal_coercion( Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), Null, ) => Some((lhs_type.clone(), lhs_type.clone())), + // Same variant decimals - no coercion needed (Decimal32(_, _), Decimal32(_, _)) | (Decimal64(_, _), Decimal64(_, _)) | (Decimal128(_, _), Decimal128(_, _)) | (Decimal256(_, _), Decimal256(_, _)) => { Some((lhs_type.clone(), rhs_type.clone())) } + // Cross-variant decimal coercion - choose larger variant with appropriate precision/scale + (Decimal32(_, _), Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _)) + | (Decimal64(_, _), Decimal32(_, _) | Decimal128(_, _) | Decimal256(_, _)) + | (Decimal128(_, _), Decimal32(_, _) | Decimal64(_, _) | Decimal256(_, _)) + | (Decimal256(_, _), Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _)) => { + let coerced_type = get_wider_decimal_type_cross_variant(lhs_type, rhs_type)?; + Some((coerced_type.clone(), coerced_type)) + } // Unlike with comparison we don't coerce to a decimal in the case of floating point // numbers, instead falling back to floating point arithmetic instead ( @@ -955,20 +964,90 @@ pub fn decimal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option get_wider_decimal_type(lhs_type, rhs_type), + (Decimal32(_, _), Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _)) => { + get_wider_decimal_type_cross_variant(lhs_type, rhs_type) + } + (Decimal32(_, _), _) => get_common_decimal_type(lhs_type, rhs_type), + (Decimal64(_, _), Decimal64(_, _)) => get_wider_decimal_type(lhs_type, rhs_type), + (Decimal64(_, _), Decimal32(_, _) | Decimal128(_, _) | Decimal256(_, _)) => { + get_wider_decimal_type_cross_variant(lhs_type, rhs_type) + } + (Decimal64(_, _), _) => get_common_decimal_type(lhs_type, rhs_type), (Decimal128(_, _), Decimal128(_, _)) => { get_wider_decimal_type(lhs_type, rhs_type) } + (Decimal128(_, _), Decimal32(_, _) | Decimal64(_, _) | Decimal256(_, _)) => { + get_wider_decimal_type_cross_variant(lhs_type, rhs_type) + } (Decimal128(_, _), _) => get_common_decimal_type(lhs_type, rhs_type), - (_, Decimal128(_, _)) => get_common_decimal_type(rhs_type, lhs_type), (Decimal256(_, _), Decimal256(_, _)) => { get_wider_decimal_type(lhs_type, rhs_type) } + (Decimal256(_, _), Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _)) => { + get_wider_decimal_type_cross_variant(lhs_type, rhs_type) + } (Decimal256(_, _), _) => get_common_decimal_type(lhs_type, rhs_type), + (_, Decimal32(_, _)) => get_common_decimal_type(rhs_type, lhs_type), + (_, Decimal64(_, _)) => get_common_decimal_type(rhs_type, lhs_type), + (_, Decimal128(_, _)) => get_common_decimal_type(rhs_type, lhs_type), (_, Decimal256(_, _)) => get_common_decimal_type(rhs_type, lhs_type), (_, _) => None, } } +/// Handle cross-variant decimal widening by choosing the larger variant +fn get_wider_decimal_type_cross_variant( + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option { + use arrow::datatypes::DataType::*; + + let (p1, s1) = match lhs_type { + Decimal32(p, s) => (*p, *s), + Decimal64(p, s) => (*p, *s), + Decimal128(p, s) => (*p, *s), + Decimal256(p, s) => (*p, *s), + _ => return None, + }; + + let (p2, s2) = match rhs_type { + Decimal32(p, s) => (*p, *s), + Decimal64(p, s) => (*p, *s), + Decimal128(p, s) => (*p, *s), + Decimal256(p, s) => (*p, *s), + _ => return None, + }; + + // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) + let s = s1.max(s2); + let range = (p1 as i8 - s1).max(p2 as i8 - s2); + let required_precision = (range + s) as u8; + + // Choose the larger variant between the two input types + match (lhs_type, rhs_type) { + (Decimal32(_, _), Decimal64(_, _)) | (Decimal64(_, _), Decimal32(_, _)) => { + Some(Decimal64(required_precision, s)) + } + (Decimal32(_, _), Decimal128(_, _)) | (Decimal128(_, _), Decimal32(_, _)) => { + Some(Decimal128(required_precision, s)) + } + (Decimal32(_, _), Decimal256(_, _)) | (Decimal256(_, _), Decimal32(_, _)) => { + Some(Decimal256(required_precision, s)) + } + (Decimal64(_, _), Decimal128(_, _)) | (Decimal128(_, _), Decimal64(_, _)) => { + Some(Decimal128(required_precision, s)) + } + (Decimal64(_, _), Decimal256(_, _)) | (Decimal256(_, _), Decimal64(_, _)) => { + Some(Decimal256(required_precision, s)) + } + (Decimal128(_, _), Decimal256(_, _)) | (Decimal256(_, _), Decimal128(_, _)) => { + Some(Decimal256(required_precision, s)) + } + _ => None, + } +} + /// Coerce `lhs_type` and `rhs_type` to a common type. fn get_common_decimal_type( decimal_type: &DataType, @@ -976,7 +1055,15 @@ fn get_common_decimal_type( ) -> Option { use arrow::datatypes::DataType::*; match decimal_type { - Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) => { + Decimal32(_, _) => { + let other_decimal_type = coerce_numeric_type_to_decimal32(other_type)?; + get_wider_decimal_type(decimal_type, &other_decimal_type) + } + Decimal64(_, _) => { + let other_decimal_type = coerce_numeric_type_to_decimal64(other_type)?; + get_wider_decimal_type(decimal_type, &other_decimal_type) + } + Decimal128(_, _) => { let other_decimal_type = coerce_numeric_type_to_decimal128(other_type)?; get_wider_decimal_type(decimal_type, &other_decimal_type) } @@ -988,7 +1075,7 @@ fn get_common_decimal_type( } } -/// Returns a `DataType::Decimal128` that can store any value from either +/// Returns a decimal [`DataType`] that can store any value from either /// `lhs_decimal_type` and `rhs_decimal_type` /// /// The result decimal type is `(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))`. @@ -1209,14 +1296,14 @@ fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option DataType { - DataType::Decimal128( + DataType::Decimal32( DECIMAL32_MAX_PRECISION.min(precision), DECIMAL32_MAX_SCALE.min(scale), ) } fn create_decimal64_type(precision: u8, scale: i8) -> DataType { - DataType::Decimal128( + DataType::Decimal64( DECIMAL64_MAX_PRECISION.min(precision), DECIMAL64_MAX_SCALE.min(scale), ) diff --git a/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs b/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs index e6238ba0078d..9bf3ba5f78f6 100644 --- a/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs +++ b/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs @@ -291,3 +291,93 @@ fn test_coercion_arithmetic_decimal() -> Result<()> { Ok(()) } + +#[test] +fn test_coercion_arithmetic_decimal_cross_variant() -> Result<()> { + let test_cases = [ + ( + DataType::Decimal32(5, 2), + DataType::Decimal64(10, 3), + DataType::Decimal64(10, 3), + DataType::Decimal64(10, 3), + ), + ( + DataType::Decimal32(7, 1), + DataType::Decimal128(15, 4), + DataType::Decimal128(15, 4), + DataType::Decimal128(15, 4), + ), + ( + DataType::Decimal32(9, 0), + DataType::Decimal256(20, 5), + DataType::Decimal256(20, 5), + DataType::Decimal256(20, 5), + ), + ( + DataType::Decimal64(12, 3), + DataType::Decimal128(18, 2), + DataType::Decimal128(19, 3), + DataType::Decimal128(19, 3), + ), + ( + DataType::Decimal64(15, 4), + DataType::Decimal256(25, 6), + DataType::Decimal256(25, 6), + DataType::Decimal256(25, 6), + ), + ( + DataType::Decimal128(20, 5), + DataType::Decimal256(30, 8), + DataType::Decimal256(30, 8), + DataType::Decimal256(30, 8), + ), + // Reverse order cases + ( + DataType::Decimal64(10, 3), + DataType::Decimal32(5, 2), + DataType::Decimal64(10, 3), + DataType::Decimal64(10, 3), + ), + ( + DataType::Decimal128(15, 4), + DataType::Decimal32(7, 1), + DataType::Decimal128(15, 4), + DataType::Decimal128(15, 4), + ), + ( + DataType::Decimal256(20, 5), + DataType::Decimal32(9, 0), + DataType::Decimal256(20, 5), + DataType::Decimal256(20, 5), + ), + ( + DataType::Decimal128(18, 2), + DataType::Decimal64(12, 3), + DataType::Decimal128(19, 3), + DataType::Decimal128(19, 3), + ), + ( + DataType::Decimal256(25, 6), + DataType::Decimal64(15, 4), + DataType::Decimal256(25, 6), + DataType::Decimal256(25, 6), + ), + ( + DataType::Decimal256(30, 8), + DataType::Decimal128(20, 5), + DataType::Decimal256(30, 8), + DataType::Decimal256(30, 8), + ), + ]; + + for (lhs_type, rhs_type, expected_lhs_type, expected_rhs_type) in test_cases { + test_math_decimal_coercion_rule( + lhs_type, + rhs_type, + expected_lhs_type, + expected_rhs_type, + ); + } + + Ok(()) +} diff --git a/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs b/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs index 208edae4ffc2..20667b1257ba 100644 --- a/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs +++ b/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs @@ -67,6 +67,158 @@ fn test_decimal_binary_comparison_coercion() -> Result<()> { Ok(()) } +#[test] +fn test_decimal_cross_variant_comparison_coercion() -> Result<()> { + // Test cross-variant decimal coercion (different decimal types) + let test_cases = [ + // (lhs, rhs, expected_result) + ( + DataType::Decimal32(5, 2), + DataType::Decimal64(10, 3), + DataType::Decimal64(10, 3), + ), + ( + DataType::Decimal32(7, 1), + DataType::Decimal128(15, 4), + DataType::Decimal128(15, 4), + ), + ( + DataType::Decimal32(9, 0), + DataType::Decimal256(20, 5), + DataType::Decimal256(20, 5), + ), + ( + DataType::Decimal64(12, 3), + DataType::Decimal128(18, 2), + DataType::Decimal128(19, 3), + ), + ( + DataType::Decimal64(15, 4), + DataType::Decimal256(25, 6), + DataType::Decimal256(25, 6), + ), + ( + DataType::Decimal128(20, 5), + DataType::Decimal256(30, 8), + DataType::Decimal256(30, 8), + ), + // Reverse order cases + ( + DataType::Decimal64(10, 3), + DataType::Decimal32(5, 2), + DataType::Decimal64(10, 3), + ), + ( + DataType::Decimal128(15, 4), + DataType::Decimal32(7, 1), + DataType::Decimal128(15, 4), + ), + ( + DataType::Decimal256(20, 5), + DataType::Decimal32(9, 0), + DataType::Decimal256(20, 5), + ), + ( + DataType::Decimal128(18, 2), + DataType::Decimal64(12, 3), + DataType::Decimal128(19, 3), + ), + ( + DataType::Decimal256(25, 6), + DataType::Decimal64(15, 4), + DataType::Decimal256(25, 6), + ), + ( + DataType::Decimal256(30, 8), + DataType::Decimal128(20, 5), + DataType::Decimal256(30, 8), + ), + ]; + + let comparison_op_types = [ + Operator::NotEq, + Operator::Eq, + Operator::Gt, + Operator::GtEq, + Operator::Lt, + Operator::LtEq, + ]; + + for (lhs_type, rhs_type, expected_type) in test_cases { + for op in comparison_op_types { + let (lhs, rhs) = + BinaryTypeCoercer::new(&lhs_type, &op, &rhs_type).get_input_types()?; + assert_eq!(expected_type, lhs, "Coercion of type {lhs_type:?} with {rhs_type:?} resulted in unexpected type: {lhs:?}"); + assert_eq!(expected_type, rhs, "Coercion of type {rhs_type:?} with {lhs_type:?} resulted in unexpected type: {rhs:?}"); + } + } + + Ok(()) +} + +#[test] +fn test_decimal_variants_with_numeric_comparison_coercion() -> Result<()> { + let input_decimal32 = DataType::Decimal32(7, 2); + let input_types = [DataType::Int8, DataType::Int16, DataType::Float16]; + let result_types = [ + DataType::Decimal32(7, 2), + DataType::Decimal32(7, 2), + DataType::Decimal32(8, 3), + ]; + + let input_decimal64 = DataType::Decimal64(12, 3); + let input_types_64 = [ + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Float16, + DataType::Float32, + ]; + let result_types_64 = [ + DataType::Decimal64(12, 3), + DataType::Decimal64(12, 3), + DataType::Decimal64(13, 3), + DataType::Decimal64(12, 3), + DataType::Decimal64(16, 7), + ]; + + let comparison_op_types = [ + Operator::NotEq, + Operator::Eq, + Operator::Gt, + Operator::GtEq, + Operator::Lt, + Operator::LtEq, + ]; + + // Test Decimal32 cases + for (i, input_type) in input_types.iter().enumerate() { + let expect_type = &result_types[i]; + for op in comparison_op_types { + let (lhs, rhs) = BinaryTypeCoercer::new(&input_decimal32, &op, input_type) + .get_input_types()?; + assert_eq!( + expect_type, &lhs, + "Coercion of type {input_decimal32:?} with {input_type:?} resulted in unexpected type: {lhs:?}" + ); + assert_eq!(expect_type, &rhs, "Coercion of type {input_decimal32:?} with {input_type:?} resulted in unexpected type: {rhs:?}"); + } + } + + // Test Decimal64 cases + for (i, input_type) in input_types_64.iter().enumerate() { + let expect_type = &result_types_64[i]; + for op in comparison_op_types { + let (lhs, rhs) = BinaryTypeCoercer::new(&input_decimal64, &op, input_type) + .get_input_types()?; + assert_eq!(expect_type, &lhs, "Coercion of type {input_decimal64:?} with {input_type:?} resulted in unexpected type: {lhs:?}"); + assert_eq!(expect_type, &rhs, "Coercion of type {input_decimal64:?} with {input_type:?} resulted in unexpected type: {rhs:?}"); + } + } + + Ok(()) +} + #[test] fn test_like_coercion() { // string coerce to strings diff --git a/datafusion/functions/src/math/abs.rs b/datafusion/functions/src/math/abs.rs index 8af8e4c2c849..040f13c01449 100644 --- a/datafusion/functions/src/math/abs.rs +++ b/datafusion/functions/src/math/abs.rs @@ -21,8 +21,8 @@ use std::any::Any; use std::sync::Arc; use arrow::array::{ - ArrayRef, Decimal128Array, Decimal256Array, Float32Array, Float64Array, Int16Array, - Int32Array, Int64Array, Int8Array, + ArrayRef, Decimal128Array, Decimal256Array, Decimal32Array, Decimal64Array, + Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, }; use arrow::datatypes::DataType; use arrow::error::ArrowError; @@ -98,6 +98,8 @@ fn create_abs_function(input_data_type: &DataType) -> Result | DataType::UInt64 => Ok(|input: &ArrayRef| Ok(Arc::clone(input))), // Decimal types + DataType::Decimal32(_, _) => Ok(make_decimal_abs_function!(Decimal32Array)), + DataType::Decimal64(_, _) => Ok(make_decimal_abs_function!(Decimal64Array)), DataType::Decimal128(_, _) => Ok(make_decimal_abs_function!(Decimal128Array)), DataType::Decimal256(_, _) => Ok(make_decimal_abs_function!(Decimal256Array)), @@ -162,6 +164,12 @@ impl ScalarUDFImpl for AbsFunc { DataType::UInt16 => Ok(DataType::UInt16), DataType::UInt32 => Ok(DataType::UInt32), DataType::UInt64 => Ok(DataType::UInt64), + DataType::Decimal32(precision, scale) => { + Ok(DataType::Decimal32(precision, scale)) + } + DataType::Decimal64(precision, scale) => { + Ok(DataType::Decimal64(precision, scale)) + } DataType::Decimal128(precision, scale) => { Ok(DataType::Decimal128(precision, scale)) } diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs index d28a9bad17ec..0ce920f447be 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs @@ -1922,7 +1922,10 @@ fn compare_join_arrays( DataType::BinaryView => compare_value!(BinaryViewArray), DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray), DataType::LargeBinary => compare_value!(LargeBinaryArray), + DataType::Decimal32(..) => compare_value!(Decimal32Array), + DataType::Decimal64(..) => compare_value!(Decimal64Array), DataType::Decimal128(..) => compare_value!(Decimal128Array), + DataType::Decimal256(..) => compare_value!(Decimal256Array), DataType::Timestamp(time_unit, None) => match time_unit { TimeUnit::Second => compare_value!(TimestampSecondArray), TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray), @@ -1994,7 +1997,10 @@ fn is_join_arrays_equal( DataType::BinaryView => compare_value!(BinaryViewArray), DataType::FixedSizeBinary(_) => compare_value!(FixedSizeBinaryArray), DataType::LargeBinary => compare_value!(LargeBinaryArray), + DataType::Decimal32(..) => compare_value!(Decimal32Array), + DataType::Decimal64(..) => compare_value!(Decimal64Array), DataType::Decimal128(..) => compare_value!(Decimal128Array), + DataType::Decimal256(..) => compare_value!(Decimal256Array), DataType::Timestamp(time_unit, None) => match time_unit { TimeUnit::Second => compare_value!(TimestampSecondArray), TimeUnit::Millisecond => compare_value!(TimestampMillisecondArray), diff --git a/datafusion/spark/src/function/math/width_bucket.rs b/datafusion/spark/src/function/math/width_bucket.rs index 24f8fe6b2456..3781e9c1159d 100644 --- a/datafusion/spark/src/function/math/width_bucket.rs +++ b/datafusion/spark/src/function/math/width_bucket.rs @@ -96,7 +96,14 @@ impl ScalarUDFImpl for SparkWidthBucket { let is_num = |t: &DataType| { matches!( t, - Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _) + Int8 | Int16 + | Int32 + | Int64 + | Float32 + | Float64 + | Decimal32(_, _) + | Decimal64(_, _) + | Decimal128(_, _) ) }; diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index a48227693b0f..d03abd2cb7e3 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -495,6 +495,24 @@ mod tests { ("123.456e-2", ScalarValue::Decimal32(Some(123456), 6, 5)), // Negative scale ("123456e128", ScalarValue::Decimal32(Some(123456), 6, -128)), + // Decimal32 + ( + &("9".repeat(2) + "." + "99999"), + ScalarValue::Decimal32( + Some(i32::from_str(&"9".repeat(7)).unwrap()), + 7, + 5, + ), + ), + // Decimal64 + ( + &("9".repeat(10) + "." + "99999"), + ScalarValue::Decimal64( + Some(i64::from_str(&"9".repeat(15)).unwrap()), + 15, + 5, + ), + ), // Decimal128 ( &("9".repeat(19) + "." + "99999"), diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index daee722526d9..5b198a4ad964 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -20,7 +20,8 @@ use std::vec; use arrow::datatypes::{ - DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE, + DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, + DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, DECIMAL_DEFAULT_SCALE, }; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, @@ -302,12 +303,14 @@ pub(crate) fn make_decimal_type( plan_err!( "Decimal(precision = {precision}, scale = {scale}) should satisfy `0 < precision <= 76`, and `scale <= precision`." ) - } else if precision > DECIMAL128_MAX_PRECISION - && precision <= DECIMAL256_MAX_PRECISION - { - Ok(DataType::Decimal256(precision, scale)) - } else { + } else if precision <= DECIMAL32_MAX_PRECISION { + Ok(DataType::Decimal32(precision, scale)) + } else if precision <= DECIMAL64_MAX_PRECISION { + Ok(DataType::Decimal64(precision, scale)) + } else if precision <= DECIMAL128_MAX_PRECISION { Ok(DataType::Decimal128(precision, scale)) + } else { + Ok(DataType::Decimal256(precision, scale)) } } diff --git a/datafusion/sqllogictest/src/engines/conversion.rs b/datafusion/sqllogictest/src/engines/conversion.rs index de3acbee93b1..dc4cad803535 100644 --- a/datafusion/sqllogictest/src/engines/conversion.rs +++ b/datafusion/sqllogictest/src/engines/conversion.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::{i256, Decimal128Type, Decimal256Type, DecimalType}; +use arrow::datatypes::{ + i256, Decimal128Type, Decimal256Type, Decimal32Type, Decimal64Type, DecimalType, +}; use bigdecimal::BigDecimal; use half::f16; use rust_decimal::prelude::*; @@ -96,6 +98,24 @@ pub(crate) fn spark_f64_to_str(value: f64) -> String { } } +pub(crate) fn decimal_32_to_str(value: i32, scale: i8) -> String { + let precision = u8::MAX; // does not matter + big_decimal_to_str( + BigDecimal::from_str(&Decimal32Type::format_decimal(value, precision, scale)) + .unwrap(), + None, + ) +} + +pub(crate) fn decimal_64_to_str(value: i64, scale: i8) -> String { + let precision = u8::MAX; // does not matter + big_decimal_to_str( + BigDecimal::from_str(&Decimal64Type::format_decimal(value, precision, scale)) + .unwrap(), + None, + ) +} + pub(crate) fn decimal_128_to_str(value: i128, scale: i8) -> String { let precision = u8::MAX; // does not matter big_decimal_to_str( diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs index 05e1f284c560..dc21f764a251 100644 --- a/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs @@ -212,6 +212,14 @@ pub fn cell_to_string(col: &ArrayRef, row: usize, is_spark_path: bool) -> Result Ok(f64_to_str(result)) } } + DataType::Decimal32(_, scale) => { + let value = get_row_value!(array::Decimal32Array, col, row); + Ok(decimal_32_to_str(value, *scale)) + } + DataType::Decimal64(_, scale) => { + let value = get_row_value!(array::Decimal64Array, col, row); + Ok(decimal_64_to_str(value, *scale)) + } DataType::Decimal128(_, scale) => { let value = get_row_value!(array::Decimal128Array, col, row); Ok(decimal_128_to_str(value, *scale)) @@ -274,6 +282,8 @@ pub fn convert_schema_to_types(columns: &Fields) -> Vec { DataType::Float16 | DataType::Float32 | DataType::Float64 + | DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) | DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => DFColumnType::Float, DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 4601d595e310..4a86d68e6838 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -887,14 +887,31 @@ SELECT approx_median(col_f64_nan) FROM median_table ---- NaN -# median decimal +# median decimal64 statement ok create table t(c decimal(10, 4)) as values (0.0001), (0.0002), (0.0003), (0.0004), (0.0005), (0.0006); query RT select median(c), arrow_typeof(median(c)) from t; ---- -0.0003 Decimal128(10, 4) +0.0003 Decimal64(10, 4) + +query RT +select approx_median(c), arrow_typeof(approx_median(c)) from t; +---- +0.00035 Float64 + +statement ok +drop table t; + +# median decimal128 +statement ok +create table t(c decimal(20, 4)) as values (0.0001), (0.0002), (0.0003), (0.0004), (0.0005), (0.0006); + +query RT +select median(c), arrow_typeof(median(c)) from t; +---- +0.0003 Decimal128(20, 4) query RT select approx_median(c), arrow_typeof(approx_median(c)) from t; @@ -911,7 +928,7 @@ create table t(c decimal(10, 4)) as values (0.0001), (null), (0.0003), (0.0004), query RT select median(c), arrow_typeof(median(c)) from t; ---- -0.0003 Decimal128(10, 4) +0.0003 Decimal64(10, 4) statement ok drop table t; @@ -923,7 +940,7 @@ create table t(c decimal(10, 4)) as values (null), (null), (null); query RT select median(c), arrow_typeof(median(c)) from t; ---- -NULL Decimal128(10, 4) +NULL Decimal64(10, 4) statement ok drop table t; @@ -2529,14 +2546,26 @@ select avg(c1) from test ---- 1.75 -# avg_decimal +# avg_decimal64 statement ok create table t (c1 decimal(10, 0)) as values (1), (2), (3), (4), (5), (6); query RT select avg(c1), arrow_typeof(avg(c1)) from t; ---- -3.5 Decimal128(14, 4) +3.5 Decimal64(14, 4) + +statement ok +drop table t; + +# avg_decimal128 +statement ok +create table t (c1 decimal(20, 0)) as values (1), (2), (3), (4), (5), (6); + +query RT +select avg(c1), arrow_typeof(avg(c1)) from t; +---- +3.5 Decimal128(24, 4) statement ok drop table t; @@ -2548,7 +2577,7 @@ create table t (c1 decimal(10, 0)) as values (1), (NULL), (3), (4), (5); query RT select avg(c1), arrow_typeof(avg(c1)) from t; ---- -3.25 Decimal128(14, 4) +3.25 Decimal64(14, 4) statement ok drop table t; @@ -2560,7 +2589,7 @@ create table t (c1 decimal(10, 0)) as values (NULL), (NULL), (NULL), (NULL), (NU query RT select avg(c1), arrow_typeof(avg(c1)) from t; ---- -NULL Decimal128(14, 4) +NULL Decimal64(14, 4) statement ok drop table t; @@ -5028,7 +5057,7 @@ NULL query RT select sum(c1), arrow_typeof(sum(c1)) from d_table; ---- -100 Decimal128(20, 3) +100 Decimal64(18, 3) # aggregate sum with decimal statement ok @@ -5064,7 +5093,7 @@ select sum(c2), arrow_typeof(sum(c2)) from t; ---- -NULL Decimal128(20, 0) NULL Int64 +NULL Decimal64(18, 0) NULL Int64 statement ok drop table t; @@ -5108,21 +5137,21 @@ drop table t; query TRT select c2, sum(c1), arrow_typeof(sum(c1)) from d_table GROUP BY c2 ORDER BY c2; ---- -A 1100.045 Decimal128(20, 3) -B -1000.045 Decimal128(20, 3) +A 1100.045 Decimal64(18, 3) +B -1000.045 Decimal64(18, 3) # aggregate_decimal_avg query RT select avg(c1), arrow_typeof(avg(c1)) from d_table ---- -5 Decimal128(14, 7) +5 Decimal64(14, 7) query TRT select c2, avg(c1), arrow_typeof(avg(c1)) from d_table GROUP BY c2 ORDER BY c2 ---- -A 110.0045 Decimal128(14, 7) -B -100.0045 Decimal128(14, 7) +A 110.0045 Decimal64(14, 7) +B -100.0045 Decimal64(14, 7) # aggregate_decimal_count_distinct query I @@ -7341,7 +7370,7 @@ create table t_decimal (c decimal(10, 4)) as values (100.00), (125.00), (175.00) query RT select avg(distinct c), arrow_typeof(avg(distinct c)) from t_decimal; ---- -180 Decimal128(14, 8) +180 Decimal64(14, 8) statement ok drop table t_decimal; @@ -7520,7 +7549,7 @@ select avg(d) from distinct_avg; ---- -3 Float64 37.4255 Float64 698.56005 Decimal128(14, 8) 15041.868333 Decimal256(54, 6) 4 56.52525 957.11074444 1272562.81625 +3 Float64 37.4255 Float64 698.56005 Decimal64(14, 8) 15041.868333 Decimal256(54, 6) 4 56.52525 957.11074444 1272562.81625 query RRRR rowsort select diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt index 502821fcc304..c0b1d25db075 100644 --- a/datafusion/sqllogictest/test_files/decimal.slt +++ b/datafusion/sqllogictest/test_files/decimal.slt @@ -19,14 +19,14 @@ query TR select arrow_typeof(cast(1.23 as decimal(10,4))), cast(1.23 as decimal(10,4)); ---- -Decimal128(10, 4) 1.23 +Decimal64(10, 4) 1.23 query TR select arrow_typeof(cast(cast(1.23 as decimal(10,3)) as decimal(10,4))), cast(cast(1.23 as decimal(10,3)) as decimal(10,4)); ---- -Decimal128(10, 4) 1.23 +Decimal64(10, 4) 1.23 query TR @@ -37,7 +37,7 @@ Decimal128(24, 2) 1.23 statement ok CREATE EXTERNAL TABLE decimal_simple ( -c1 DECIMAL(10,6) NOT NULL, +c1 DECIMAL(20,6) NOT NULL, c2 DOUBLE NOT NULL, c3 BIGINT NOT NULL, c4 BOOLEAN NOT NULL, @@ -51,7 +51,7 @@ OPTIONS ('format.has_header' 'true'); query TT select arrow_typeof(c1), arrow_typeof(c5) from decimal_simple where c1 > c5 limit 1; ---- -Decimal128(10, 6) Decimal128(12, 7) +Decimal128(20, 6) Decimal64(12, 7) query R rowsort @@ -99,13 +99,13 @@ select * from decimal_simple where c1 > c5; query TR select arrow_typeof(min(c1)), min(c1) from decimal_simple where c4=false; ---- -Decimal128(10, 6) 0.00002 +Decimal128(20, 6) 0.00002 query TR select arrow_typeof(max(c1)), max(c1) from decimal_simple where c4=false; ---- -Decimal128(10, 6) 0.00005 +Decimal128(20, 6) 0.00005 # inferred precision is 10+10 @@ -113,7 +113,7 @@ Decimal128(10, 6) 0.00005 query TR select arrow_typeof(sum(c1)), sum(c1) from decimal_simple; ---- -Decimal128(20, 6) 0.00055 +Decimal128(30, 6) 0.00055 # inferred precision is original precision + 4 @@ -121,13 +121,13 @@ Decimal128(20, 6) 0.00055 query TR select arrow_typeof(avg(c1)), avg(c1) from decimal_simple; ---- -Decimal128(14, 10) 0.0000366666 +Decimal128(24, 10) 0.0000366666 query TR select arrow_typeof(median(c1)), median(c1) from decimal_simple; ---- -Decimal128(10, 6) 0.00004 +Decimal128(20, 6) 0.00004 query RRIBR rowsort @@ -235,7 +235,7 @@ select c1+1 from decimal_simple; query T select arrow_typeof(c1+c5) from decimal_simple limit 1; ---- -Decimal128(13, 7) +Decimal128(22, 7) query R rowsort @@ -287,7 +287,7 @@ select c1-1 from decimal_simple; query T select arrow_typeof(c1-c5) from decimal_simple limit 1; ---- -Decimal128(13, 7) +Decimal128(22, 7) query R rowsort @@ -313,7 +313,7 @@ select c1-c5 from decimal_simple; query T select arrow_typeof(c1*20) from decimal_simple limit 1; ---- -Decimal128(31, 6) +Decimal128(38, 6) query R rowsort @@ -339,7 +339,7 @@ select c1*20 from decimal_simple; query T select arrow_typeof(c1*c5) from decimal_simple limit 1; ---- -Decimal128(23, 13) +Decimal128(38, 14) query R rowsort @@ -365,7 +365,7 @@ select c1*c5 from decimal_simple; query T select arrow_typeof(c1/cast(0.00001 as decimal(5,5))) from decimal_simple limit 1; ---- -Decimal128(19, 10) +Decimal128(30, 10) query R rowsort @@ -391,33 +391,33 @@ select c1/cast(0.00001 as decimal(5,5)) from decimal_simple; query T select arrow_typeof(c1/c5) from decimal_simple limit 1; ---- -Decimal128(21, 10) +Decimal128(32, 11) query R rowsort select c1/c5 from decimal_simple; ---- 0.5 -0.641025641 -0.7142857142 -0.7352941176 +0.64102564102 +0.71428571428 +0.73529411764 0.8 -0.8571428571 -0.909090909 -0.909090909 +0.85714285714 +0.90909090909 +0.90909090909 0.9375 -0.9615384615 +0.96153846153 1 1 -1.0526315789 -1.5151515151 -2.7272727272 +1.05263157894 +1.51515151515 +2.72727272727 query T select arrow_typeof(c5%cast(0.00001 as decimal(5,5))) from decimal_simple limit 1; ---- -Decimal128(7, 7) +Decimal64(12, 7) query R rowsort @@ -443,7 +443,7 @@ select c5%cast(0.00001 as decimal(5,5)) from decimal_simple; query T select arrow_typeof(c1%c5) from decimal_simple limit 1; ---- -Decimal128(11, 7) +Decimal128(21, 7) query R rowsort @@ -469,7 +469,7 @@ select c1%c5 from decimal_simple; query T select arrow_typeof(abs(c1)) from decimal_simple limit 1; ---- -Decimal128(10, 6) +Decimal128(20, 6) query R rowsort @@ -578,7 +578,7 @@ select count(*),c1,c4 from decimal_simple group by c1,c4 order by c1,c4; query TR select arrow_typeof(cast(400420638.54 as decimal(12,2))), cast(400420638.54 as decimal(12,2)); ---- -Decimal128(12, 2) 400420638.54 +Decimal64(12, 2) 400420638.54 query TR diff --git a/datafusion/sqllogictest/test_files/explain_tree.slt b/datafusion/sqllogictest/test_files/explain_tree.slt index 7d70a892af0c..d3ef1268decd 100644 --- a/datafusion/sqllogictest/test_files/explain_tree.slt +++ b/datafusion/sqllogictest/test_files/explain_tree.slt @@ -1314,7 +1314,7 @@ physical_plan 11)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ 12)│ DataSourceExec ││ DataSourceExec │ 13)│ -------------------- ││ -------------------- │ -14)│ bytes: 5932 ││ bytes: 5932 │ +14)│ bytes: 5884 ││ bytes: 5900 │ 15)│ format: memory ││ format: memory │ 16)│ rows: 1 ││ rows: 1 │ 17)└───────────────────────────┘└───────────────────────────┘ diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 20f79622a62c..9ba1e3d8c23d 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -575,7 +575,7 @@ statement ok CREATE TABLE products ( product_id INT PRIMARY KEY, product_name VARCHAR(100), -price DECIMAL(10, 2)) +price DECIMAL(20, 2)) statement ok INSERT INTO products (product_id, product_name, price) VALUES diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index c24b0777cc13..4aae397c06ef 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -2767,18 +2767,18 @@ query TT explain select * from hashjoin_datatype_table_t1 t1 right join hashjoin_datatype_table_t2 t2 on t1.c3 = t2.c3 ---- logical_plan -01)Right Join: CAST(t1.c3 AS Decimal128(10, 2)) = t2.c3 +01)Right Join: CAST(t1.c3 AS Decimal64(10, 2)) = t2.c3 02)--SubqueryAlias: t1 03)----TableScan: hashjoin_datatype_table_t1 projection=[c1, c2, c3, c4] 04)--SubqueryAlias: t2 05)----TableScan: hashjoin_datatype_table_t2 projection=[c1, c2, c3, c4] physical_plan 01)ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, c1@5 as c1, c2@6 as c2, c3@7 as c3, c4@8 as c4] -02)--SortMergeJoin: join_type=Right, on=[(CAST(t1.c3 AS Decimal128(10, 2))@4, c3@2)] -03)----SortExec: expr=[CAST(t1.c3 AS Decimal128(10, 2))@4 ASC], preserve_partitioning=[true] +02)--SortMergeJoin: join_type=Right, on=[(CAST(t1.c3 AS Decimal64(10, 2))@4, c3@2)] +03)----SortExec: expr=[CAST(t1.c3 AS Decimal64(10, 2))@4 ASC], preserve_partitioning=[true] 04)------CoalesceBatchesExec: target_batch_size=2 -05)--------RepartitionExec: partitioning=Hash([CAST(t1.c3 AS Decimal128(10, 2))@4], 2), input_partitions=2 -06)----------ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, CAST(c3@2 AS Decimal128(10, 2)) as CAST(t1.c3 AS Decimal128(10, 2))] +05)--------RepartitionExec: partitioning=Hash([CAST(t1.c3 AS Decimal64(10, 2))@4], 2), input_partitions=2 +06)----------ProjectionExec: expr=[c1@0 as c1, c2@1 as c2, c3@2 as c3, c4@3 as c4, CAST(c3@2 AS Decimal64(10, 2)) as CAST(t1.c3 AS Decimal64(10, 2))] 07)------------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1 08)--------------DataSourceExec: partitions=1, partition_sizes=[1] 09)----SortExec: expr=[c3@2 ASC], preserve_partitioning=[true] diff --git a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index e206aa16b8a9..bf04b2a1133f 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -477,7 +477,7 @@ drop table test_non_nullable_float statement ok CREATE TABLE test_nullable_decimal( - c1 DECIMAL(10, 2), /* Decimal128 */ + c1 DECIMAL(20, 2), /* Decimal128 */ c2 DECIMAL(38, 10), /* Decimal128 with max precision */ c3 DECIMAL(40, 2), /* Decimal256 */ c4 DECIMAL(76, 10) /* Decimal256 with max precision */ @@ -540,7 +540,7 @@ SELECT arrow_typeof(abs(c4)) FROM test_nullable_decimal limit 1 ---- -Decimal128(10, 2) Decimal128(38, 10) Decimal256(40, 2) Decimal256(76, 10) +Decimal128(20, 2) Decimal128(38, 10) Decimal256(40, 2) Decimal256(76, 10) # abs: decimals query RRRR rowsort @@ -556,7 +556,7 @@ drop table test_nullable_decimal statement ok -CREATE TABLE test_non_nullable_decimal(c1 DECIMAL(9,2) NOT NULL); +CREATE TABLE test_non_nullable_decimal(c1 DECIMAL(19,2) NOT NULL); query I INSERT INTO test_non_nullable_decimal VALUES(1) diff --git a/datafusion/sqllogictest/test_files/options.slt b/datafusion/sqllogictest/test_files/options.slt index 71ff12e8cc50..7cabba2c47c6 100644 --- a/datafusion/sqllogictest/test_files/options.slt +++ b/datafusion/sqllogictest/test_files/options.slt @@ -147,7 +147,7 @@ select .0 as c1, 0. as c2, 0000. as c3, 00000.00 as c4 query TTTT select arrow_typeof(.0) as c1, arrow_typeof(0.) as c2, arrow_typeof(0000.) as c3, arrow_typeof(00000.00) as c4 ---- -Decimal128(1, 1) Decimal128(1, 0) Decimal128(1, 0) Decimal128(2, 2) +Decimal32(1, 1) Decimal32(1, 0) Decimal32(1, 0) Decimal32(2, 2) query RR select 999999999999999999999999999999999999, -999999999999999999999999999999999999 @@ -198,14 +198,14 @@ select 1.23e3, arrow_typeof(1.23e3), +1.23e1, arrow_typeof(+1.23e1), -1234.56e-3, arrow_typeof(-1234.56e-3) ---- -1230 Decimal128(3, -1) 12.3 Decimal128(3, 1) -1.23456 Decimal128(6, 5) +1230 Decimal32(3, -1) 12.3 Decimal32(3, 1) -1.23456 Decimal32(6, 5) query RTRTRT select 1.23e-2, arrow_typeof(1.23e-2), 1.23456e0, arrow_typeof(1.23456e0), -.0123e2, arrow_typeof(-.0123e2) ---- -0.0123 Decimal128(4, 4) 1.23456 Decimal128(6, 5) -1.23 Decimal128(3, 2) +0.0123 Decimal32(4, 4) 1.23456 Decimal32(6, 5) -1.23 Decimal32(3, 2) # Decimal256 cases query RT @@ -265,12 +265,12 @@ select 1.00000000000000000000000000000000000000000000000000000000000000000000000 query TR select arrow_typeof(1e77), 1e77 ---- -Decimal128(1, -77) 100000000000000000000000000000000000000000000000000000000000000000000000000000 +Decimal32(1, -77) 100000000000000000000000000000000000000000000000000000000000000000000000000000 query T select arrow_typeof(1e128) ---- -Decimal128(1, -128) +Decimal32(1, -128) query error Decimal scale \-129 exceeds the minimum supported scale: \-128 select 1e129 @@ -280,7 +280,7 @@ query RTRT select 1e40 + 1e40, arrow_typeof(1e40 + 1e40), 1e-40 + -1e-40, arrow_typeof(1e-40 + -1e-40) ---- -20000000000000000000000000000000000000000 Decimal128(2, -40) 0 Decimal256(41, 40) +20000000000000000000000000000000000000000 Decimal32(2, -40) 0 Decimal256(41, 40) # Restore option to default value statement ok diff --git a/datafusion/sqllogictest/test_files/predicates.slt b/datafusion/sqllogictest/test_files/predicates.slt index 77ee3e4f05a0..e59c0667be74 100644 --- a/datafusion/sqllogictest/test_files/predicates.slt +++ b/datafusion/sqllogictest/test_files/predicates.slt @@ -662,9 +662,9 @@ OR ---- logical_plan 01)Projection: lineitem.l_partkey -02)--Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8View("Brand#12") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) -03)----Filter: lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) -04)------TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)] +02)--Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8View("Brand#12") AND lineitem.l_quantity >= Decimal64(Some(100),15,2) AND lineitem.l_quantity <= Decimal64(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND lineitem.l_quantity >= Decimal64(Some(1000),15,2) AND lineitem.l_quantity <= Decimal64(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND lineitem.l_quantity >= Decimal64(Some(2000),15,2) AND lineitem.l_quantity <= Decimal64(Some(3000),15,2) AND part.p_size <= Int32(15) +03)----Filter: lineitem.l_quantity >= Decimal64(Some(100),15,2) AND lineitem.l_quantity <= Decimal64(Some(1100),15,2) OR lineitem.l_quantity >= Decimal64(Some(1000),15,2) AND lineitem.l_quantity <= Decimal64(Some(2000),15,2) OR lineitem.l_quantity >= Decimal64(Some(2000),15,2) AND lineitem.l_quantity <= Decimal64(Some(3000),15,2) +04)------TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal64(Some(100),15,2) AND lineitem.l_quantity <= Decimal64(Some(1100),15,2) OR lineitem.l_quantity >= Decimal64(Some(1000),15,2) AND lineitem.l_quantity <= Decimal64(Some(2000),15,2) OR lineitem.l_quantity >= Decimal64(Some(2000),15,2) AND lineitem.l_quantity <= Decimal64(Some(3000),15,2)] 05)----Filter: (part.p_brand = Utf8View("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) 06)------TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8View("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_size <= Int32(15)] physical_plan diff --git a/datafusion/sqllogictest/test_files/qualify.slt b/datafusion/sqllogictest/test_files/qualify.slt index d53b56ce58de..80abcbd76d0b 100644 --- a/datafusion/sqllogictest/test_files/qualify.slt +++ b/datafusion/sqllogictest/test_files/qualify.slt @@ -308,25 +308,25 @@ QUALIFY r > 60000 ---- logical_plan 01)Projection: users.dept, avg(users.salary) PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING AS r -02)--Filter: avg(users.salary) PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING > Decimal128(Some(60000000000),14,6) +02)--Filter: avg(users.salary) PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING > Decimal64(Some(60000000000),14,6) 03)----Projection: users.dept, avg(users.salary) PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING 04)------WindowAggr: windowExpr=[[avg(users.salary) PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] 05)--------Projection: users.dept, users.salary -06)----------Filter: sum(users.salary) > Decimal128(Some(2000000),20,2) +06)----------Filter: sum(users.salary) > Decimal64(Some(2000000),18,2) 07)------------Aggregate: groupBy=[[users.dept, users.salary]], aggr=[[sum(users.salary)]] -08)--------------Filter: users.salary > Decimal128(Some(500000),10,2) +08)--------------Filter: users.salary > Decimal64(Some(500000),10,2) 09)----------------TableScan: users projection=[salary, dept] physical_plan 01)ProjectionExec: expr=[dept@0 as dept, avg(users.salary) PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as r] 02)--CoalesceBatchesExec: target_batch_size=8192 03)----FilterExec: avg(users.salary) PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 > Some(60000000000),14,6 04)------ProjectionExec: expr=[dept@0 as dept, avg(users.salary) PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@2 as avg(users.salary) PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING] -05)--------WindowAggExec: wdw=[avg(users.salary) PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "avg(users.salary) PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Decimal128(14, 6), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] +05)--------WindowAggExec: wdw=[avg(users.salary) PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "avg(users.salary) PARTITION BY [users.dept] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Decimal64(14, 6), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }] 06)----------SortExec: expr=[dept@0 ASC NULLS LAST], preserve_partitioning=[true] 07)------------CoalesceBatchesExec: target_batch_size=8192 08)--------------RepartitionExec: partitioning=Hash([dept@0], 4), input_partitions=4 09)----------------CoalesceBatchesExec: target_batch_size=8192 -10)------------------FilterExec: sum(users.salary)@2 > Some(2000000),20,2, projection=[dept@0, salary@1] +10)------------------FilterExec: sum(users.salary)@2 > Some(2000000),18,2, projection=[dept@0, salary@1] 11)--------------------AggregateExec: mode=FinalPartitioned, gby=[dept@0 as dept, salary@1 as salary], aggr=[sum(users.salary)] 12)----------------------CoalesceBatchesExec: target_batch_size=8192 13)------------------------RepartitionExec: partitioning=Hash([dept@0, salary@1], 4), input_partitions=4 diff --git a/datafusion/sqllogictest/test_files/spark/math/mod.slt b/datafusion/sqllogictest/test_files/spark/math/mod.slt index 2780b3e1053d..3043cc617ab9 100644 --- a/datafusion/sqllogictest/test_files/spark/math/mod.slt +++ b/datafusion/sqllogictest/test_files/spark/math/mod.slt @@ -24,7 +24,7 @@ ## Original Query: SELECT MOD(2, 1.8); ## PySpark 3.5.5 Result: {'mod(2, 1.8)': Decimal('0.2'), 'typeof(mod(2, 1.8))': 'decimal(2,1)', 'typeof(2)': 'int', 'typeof(1.8)': 'decimal(2,1)'} query R -SELECT MOD(2::int, 1.8::decimal(2,1)); +SELECT MOD(2::int, 1.8::decimal(10,1)); ---- 0.2 diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 43f85d1e2014..f6c9ffe7652f 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -348,13 +348,13 @@ where c_acctbal < ( logical_plan 01)Sort: customer.c_custkey ASC NULLS LAST 02)--Projection: customer.c_custkey -03)----Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_1.sum(orders.o_totalprice) +03)----Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey Filter: CAST(customer.c_acctbal AS Decimal64(18, 2)) < __scalar_sq_1.sum(orders.o_totalprice) 04)------TableScan: customer projection=[c_custkey, c_acctbal] 05)------SubqueryAlias: __scalar_sq_1 06)--------Projection: sum(orders.o_totalprice), orders.o_custkey 07)----------Aggregate: groupBy=[[orders.o_custkey]], aggr=[[sum(orders.o_totalprice)]] 08)------------Projection: orders.o_custkey, orders.o_totalprice -09)--------------Inner Join: orders.o_orderkey = __scalar_sq_2.l_orderkey Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < __scalar_sq_2.price +09)--------------Inner Join: orders.o_orderkey = __scalar_sq_2.l_orderkey Filter: CAST(orders.o_totalprice AS Decimal64(18, 2)) < __scalar_sq_2.price 10)----------------TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice] 11)----------------SubqueryAlias: __scalar_sq_2 12)------------------Projection: sum(lineitem.l_extendedprice) AS price, lineitem.l_orderkey