Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1362,6 +1362,12 @@ impl ScalarValue {
DataType::Float16 => ScalarValue::Float16(Some(f16::from_f32(0.0))),
DataType::Float32 => ScalarValue::Float32(Some(0.0)),
DataType::Float64 => ScalarValue::Float64(Some(0.0)),
DataType::Decimal32(precision, scale) => {
ScalarValue::Decimal32(Some(0), *precision, *scale)
}
DataType::Decimal64(precision, scale) => {
ScalarValue::Decimal64(Some(0), *precision, *scale)
}
DataType::Decimal128(precision, scale) => {
ScalarValue::Decimal128(Some(0), *precision, *scale)
}
Expand Down
97 changes: 92 additions & 5 deletions datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,12 +351,21 @@ fn math_decimal_coercion(
Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _),
Null,
) => Some((lhs_type.clone(), lhs_type.clone())),
// Same variant decimals - no coercion needed
(Decimal32(_, _), Decimal32(_, _))
| (Decimal64(_, _), Decimal64(_, _))
| (Decimal128(_, _), Decimal128(_, _))
| (Decimal256(_, _), Decimal256(_, _)) => {
Some((lhs_type.clone(), rhs_type.clone()))
}
// Cross-variant decimal coercion - choose larger variant with appropriate precision/scale
(Decimal32(_, _), Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _))
| (Decimal64(_, _), Decimal32(_, _) | Decimal128(_, _) | Decimal256(_, _))
| (Decimal128(_, _), Decimal32(_, _) | Decimal64(_, _) | Decimal256(_, _))
| (Decimal256(_, _), Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _)) => {
let coerced_type = get_wider_decimal_type_cross_variant(lhs_type, rhs_type)?;
Some((coerced_type.clone(), coerced_type))
}
// Unlike with comparison we don't coerce to a decimal in the case of floating point
// numbers, instead falling back to floating point arithmetic instead
(
Expand Down Expand Up @@ -955,28 +964,106 @@ pub fn decimal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<Data

match (lhs_type, rhs_type) {
// Prefer decimal data type over floating point for comparison operation
(Decimal32(_, _), Decimal32(_, _)) => get_wider_decimal_type(lhs_type, rhs_type),
(Decimal32(_, _), Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _)) => {
get_wider_decimal_type_cross_variant(lhs_type, rhs_type)
}
(Decimal32(_, _), _) => get_common_decimal_type(lhs_type, rhs_type),
(Decimal64(_, _), Decimal64(_, _)) => get_wider_decimal_type(lhs_type, rhs_type),
(Decimal64(_, _), Decimal32(_, _) | Decimal128(_, _) | Decimal256(_, _)) => {
get_wider_decimal_type_cross_variant(lhs_type, rhs_type)
}
(Decimal64(_, _), _) => get_common_decimal_type(lhs_type, rhs_type),
(Decimal128(_, _), Decimal128(_, _)) => {
get_wider_decimal_type(lhs_type, rhs_type)
}
(Decimal128(_, _), Decimal32(_, _) | Decimal64(_, _) | Decimal256(_, _)) => {
get_wider_decimal_type_cross_variant(lhs_type, rhs_type)
}
(Decimal128(_, _), _) => get_common_decimal_type(lhs_type, rhs_type),
(_, Decimal128(_, _)) => get_common_decimal_type(rhs_type, lhs_type),
(Decimal256(_, _), Decimal256(_, _)) => {
get_wider_decimal_type(lhs_type, rhs_type)
}
(Decimal256(_, _), Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _)) => {
get_wider_decimal_type_cross_variant(lhs_type, rhs_type)
}
(Decimal256(_, _), _) => get_common_decimal_type(lhs_type, rhs_type),
(_, Decimal32(_, _)) => get_common_decimal_type(rhs_type, lhs_type),
(_, Decimal64(_, _)) => get_common_decimal_type(rhs_type, lhs_type),
(_, Decimal128(_, _)) => get_common_decimal_type(rhs_type, lhs_type),
(_, Decimal256(_, _)) => get_common_decimal_type(rhs_type, lhs_type),
(_, _) => None,
}
}

/// Handle cross-variant decimal widening by choosing the larger variant
fn get_wider_decimal_type_cross_variant(
lhs_type: &DataType,
rhs_type: &DataType,
) -> Option<DataType> {
use arrow::datatypes::DataType::*;

let (p1, s1) = match lhs_type {
Decimal32(p, s) => (*p, *s),
Decimal64(p, s) => (*p, *s),
Decimal128(p, s) => (*p, *s),
Decimal256(p, s) => (*p, *s),
_ => return None,
};

let (p2, s2) = match rhs_type {
Decimal32(p, s) => (*p, *s),
Decimal64(p, s) => (*p, *s),
Decimal128(p, s) => (*p, *s),
Decimal256(p, s) => (*p, *s),
_ => return None,
};

// max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2)
let s = s1.max(s2);
let range = (p1 as i8 - s1).max(p2 as i8 - s2);
let required_precision = (range + s) as u8;

// Choose the larger variant between the two input types
match (lhs_type, rhs_type) {
(Decimal32(_, _), Decimal64(_, _)) | (Decimal64(_, _), Decimal32(_, _)) => {
Some(Decimal64(required_precision, s))
}
(Decimal32(_, _), Decimal128(_, _)) | (Decimal128(_, _), Decimal32(_, _)) => {
Some(Decimal128(required_precision, s))
}
(Decimal32(_, _), Decimal256(_, _)) | (Decimal256(_, _), Decimal32(_, _)) => {
Some(Decimal256(required_precision, s))
}
(Decimal64(_, _), Decimal128(_, _)) | (Decimal128(_, _), Decimal64(_, _)) => {
Some(Decimal128(required_precision, s))
}
(Decimal64(_, _), Decimal256(_, _)) | (Decimal256(_, _), Decimal64(_, _)) => {
Some(Decimal256(required_precision, s))
}
(Decimal128(_, _), Decimal256(_, _)) | (Decimal256(_, _), Decimal128(_, _)) => {
Some(Decimal256(required_precision, s))
}
_ => None,
}
}

/// Coerce `lhs_type` and `rhs_type` to a common type.
fn get_common_decimal_type(
decimal_type: &DataType,
other_type: &DataType,
) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match decimal_type {
Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) => {
Decimal32(_, _) => {
let other_decimal_type = coerce_numeric_type_to_decimal32(other_type)?;
get_wider_decimal_type(decimal_type, &other_decimal_type)
}
Decimal64(_, _) => {
let other_decimal_type = coerce_numeric_type_to_decimal64(other_type)?;
get_wider_decimal_type(decimal_type, &other_decimal_type)
}
Decimal128(_, _) => {
let other_decimal_type = coerce_numeric_type_to_decimal128(other_type)?;
get_wider_decimal_type(decimal_type, &other_decimal_type)
}
Expand All @@ -988,7 +1075,7 @@ fn get_common_decimal_type(
}
}

/// Returns a `DataType::Decimal128` that can store any value from either
/// Returns a decimal [`DataType`] that can store any value from either
/// `lhs_decimal_type` and `rhs_decimal_type`
///
/// The result decimal type is `(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))`.
Expand Down Expand Up @@ -1209,14 +1296,14 @@ fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataTy
}

fn create_decimal32_type(precision: u8, scale: i8) -> DataType {
DataType::Decimal128(
DataType::Decimal32(
DECIMAL32_MAX_PRECISION.min(precision),
DECIMAL32_MAX_SCALE.min(scale),
)
}

fn create_decimal64_type(precision: u8, scale: i8) -> DataType {
DataType::Decimal128(
DataType::Decimal64(
DECIMAL64_MAX_PRECISION.min(precision),
DECIMAL64_MAX_SCALE.min(scale),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,93 @@ fn test_coercion_arithmetic_decimal() -> Result<()> {

Ok(())
}

#[test]
fn test_coercion_arithmetic_decimal_cross_variant() -> Result<()> {
let test_cases = [
(
DataType::Decimal32(5, 2),
DataType::Decimal64(10, 3),
DataType::Decimal64(10, 3),
DataType::Decimal64(10, 3),
),
(
DataType::Decimal32(7, 1),
DataType::Decimal128(15, 4),
DataType::Decimal128(15, 4),
DataType::Decimal128(15, 4),
),
(
DataType::Decimal32(9, 0),
DataType::Decimal256(20, 5),
DataType::Decimal256(20, 5),
DataType::Decimal256(20, 5),
),
(
DataType::Decimal64(12, 3),
DataType::Decimal128(18, 2),
DataType::Decimal128(19, 3),
DataType::Decimal128(19, 3),
),
(
DataType::Decimal64(15, 4),
DataType::Decimal256(25, 6),
DataType::Decimal256(25, 6),
DataType::Decimal256(25, 6),
),
(
DataType::Decimal128(20, 5),
DataType::Decimal256(30, 8),
DataType::Decimal256(30, 8),
DataType::Decimal256(30, 8),
),
// Reverse order cases
(
DataType::Decimal64(10, 3),
DataType::Decimal32(5, 2),
DataType::Decimal64(10, 3),
DataType::Decimal64(10, 3),
),
(
DataType::Decimal128(15, 4),
DataType::Decimal32(7, 1),
DataType::Decimal128(15, 4),
DataType::Decimal128(15, 4),
),
(
DataType::Decimal256(20, 5),
DataType::Decimal32(9, 0),
DataType::Decimal256(20, 5),
DataType::Decimal256(20, 5),
),
(
DataType::Decimal128(18, 2),
DataType::Decimal64(12, 3),
DataType::Decimal128(19, 3),
DataType::Decimal128(19, 3),
),
(
DataType::Decimal256(25, 6),
DataType::Decimal64(15, 4),
DataType::Decimal256(25, 6),
DataType::Decimal256(25, 6),
),
(
DataType::Decimal256(30, 8),
DataType::Decimal128(20, 5),
DataType::Decimal256(30, 8),
DataType::Decimal256(30, 8),
),
];

for (lhs_type, rhs_type, expected_lhs_type, expected_rhs_type) in test_cases {
test_math_decimal_coercion_rule(
lhs_type,
rhs_type,
expected_lhs_type,
expected_rhs_type,
);
}

Ok(())
}
Loading
Loading