diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index 6f8f81401f3b..7d6c66769358 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -41,7 +41,7 @@ object_store = { workspace = true } pbjson-types = "0.7" # TODO use workspace version prost = "0.13" -substrait = { version = "0.42", features = ["serde"] } +substrait = { version = "0.45", features = ["serde"] } url = { workspace = true } [dev-dependencies] diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 888480774996..9536bc696ba5 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -197,6 +197,65 @@ fn split_eq_and_noneq_join_predicate_with_nulls_equality( (accum_join_keys, nulls_equal_nulls, join_filter) } +async fn union_rels( + rels: &[Rel], + ctx: &SessionContext, + extensions: &Extensions, + is_all: bool, +) -> Result { + let mut union_builder = Ok(LogicalPlanBuilder::from( + from_substrait_rel(ctx, &rels[0], extensions).await?, + )); + for input in &rels[1..] { + let rel_plan = from_substrait_rel(ctx, input, extensions).await?; + + union_builder = if is_all { + union_builder?.union(rel_plan) + } else { + union_builder?.union_distinct(rel_plan) + }; + } + union_builder?.build() +} + +async fn intersect_rels( + rels: &[Rel], + ctx: &SessionContext, + extensions: &Extensions, + is_all: bool, +) -> Result { + let mut rel = from_substrait_rel(ctx, &rels[0], extensions).await?; + + for input in &rels[1..] { + rel = LogicalPlanBuilder::intersect( + rel, + from_substrait_rel(ctx, input, extensions).await?, + is_all, + )? + } + + Ok(rel) +} + +async fn except_rels( + rels: &[Rel], + ctx: &SessionContext, + extensions: &Extensions, + is_all: bool, +) -> Result { + let mut rel = from_substrait_rel(ctx, &rels[0], extensions).await?; + + for input in &rels[1..] { + rel = LogicalPlanBuilder::except( + rel, + from_substrait_rel(ctx, input, extensions).await?, + is_all, + )? + } + + Ok(rel) +} + /// Convert Substrait Plan to DataFusion LogicalPlan pub async fn from_substrait_plan( ctx: &SessionContext, @@ -513,6 +572,7 @@ fn make_renamed_schema( } /// Convert Substrait Rel to DataFusion DataFrame +#[allow(deprecated)] #[async_recursion] pub async fn from_substrait_rel( ctx: &SessionContext, @@ -875,27 +935,65 @@ pub async fn from_substrait_rel( Ok(set_op) => match set_op { set_rel::SetOp::UnionAll => { if !set.inputs.is_empty() { - let mut union_builder = Ok(LogicalPlanBuilder::from( - from_substrait_rel(ctx, &set.inputs[0], extensions).await?, - )); - for input in &set.inputs[1..] { - union_builder = union_builder? - .union(from_substrait_rel(ctx, input, extensions).await?); - } - union_builder?.build() + union_rels(&set.inputs, ctx, extensions, true).await + } else { + not_impl_err!("Union relation requires at least one input") + } + } + set_rel::SetOp::UnionDistinct => { + if !set.inputs.is_empty() { + union_rels(&set.inputs, ctx, extensions, false).await } else { not_impl_err!("Union relation requires at least one input") } } set_rel::SetOp::IntersectionPrimary => { - if set.inputs.len() == 2 { + if set.inputs.len() >= 2 { LogicalPlanBuilder::intersect( from_substrait_rel(ctx, &set.inputs[0], extensions).await?, - from_substrait_rel(ctx, &set.inputs[1], extensions).await?, + union_rels(&set.inputs[1..], ctx, extensions, true).await?, false, ) } else { - not_impl_err!("Primary Intersect relation with more than two inputs isn't supported") + not_impl_err!( + "Primary Intersect relation requires at least two inputs" + ) + } + } + set_rel::SetOp::IntersectionMultiset => { + if set.inputs.len() >= 2 { + intersect_rels(&set.inputs, ctx, extensions, false).await + } else { + not_impl_err!( + "Multiset Intersect relation requires at least two inputs" + ) + } + } + set_rel::SetOp::IntersectionMultisetAll => { + if set.inputs.len() >= 2 { + intersect_rels(&set.inputs, ctx, extensions, true).await + } else { + not_impl_err!( + "MultisetAll Intersect relation requires at least two inputs" + ) + } + } + set_rel::SetOp::MinusPrimary => { + if set.inputs.len() >= 2 { + except_rels(&set.inputs, ctx, extensions, false).await + } else { + not_impl_err!( + "Primary Minus relation requires at least two inputs" + ) + } + } + set_rel::SetOp::MinusPrimaryAll => { + if set.inputs.len() >= 2 { + except_rels(&set.inputs, ctx, extensions, true).await + } else { + not_impl_err!( + "PrimaryAll Minus relation requires at least two inputs" + ) } } _ => not_impl_err!("Unsupported set operator: {set_op:?}"), diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 1165ce13d236..0e1375a8e0ea 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -172,6 +172,7 @@ pub fn to_substrait_extended_expr( } /// Convert DataFusion LogicalPlan to Substrait Rel +#[allow(deprecated)] pub fn to_substrait_rel( plan: &LogicalPlan, ctx: &SessionContext, @@ -227,6 +228,7 @@ pub fn to_substrait_rel( advanced_extension: None, read_type: Some(ReadType::VirtualTable(VirtualTable { values: vec![], + expressions: vec![], })), }))), })) @@ -263,7 +265,10 @@ pub fn to_substrait_rel( best_effort_filter: None, projection: None, advanced_extension: None, - read_type: Some(ReadType::VirtualTable(VirtualTable { values })), + read_type: Some(ReadType::VirtualTable(VirtualTable { + values, + expressions: vec![], + })), }))), })) } @@ -359,6 +364,7 @@ pub fn to_substrait_rel( rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { common: None, input: Some(input), + grouping_expressions: vec![], groupings, measures, advanced_extension: None, @@ -377,8 +383,10 @@ pub fn to_substrait_rel( rel_type: Some(RelType::Aggregate(Box::new(AggregateRel { common: None, input: Some(input), + grouping_expressions: vec![], groupings: vec![Grouping { grouping_expressions: grouping, + expression_references: vec![], }], measures: vec![], advanced_extension: None, @@ -764,6 +772,7 @@ pub fn operator_to_name(op: Operator) -> &'static str { } } +#[allow(deprecated)] pub fn parse_flat_grouping_exprs( ctx: &SessionContext, exprs: &[Expr], @@ -776,6 +785,7 @@ pub fn parse_flat_grouping_exprs( .collect::>>()?; Ok(Grouping { grouping_expressions, + expression_references: vec![], }) } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 80caaafad6f6..9297459b81c4 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -675,6 +675,72 @@ async fn simple_intersect_consume() -> Result<()> { .await } +#[tokio::test] +async fn primary_intersect_consume() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/intersect_primary.substrait.json"); + + assert_substrait_sql( + proto_plan, + "SELECT a FROM data INTERSECT (SELECT a FROM data2 UNION ALL SELECT a FROM data2)", + ) + .await +} + +#[tokio::test] +async fn multiset_intersect_consume() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/intersect_multiset.substrait.json"); + + assert_substrait_sql( + proto_plan, + "SELECT a FROM data INTERSECT SELECT a FROM data2 INTERSECT SELECT a FROM data2", + ) + .await +} + +#[tokio::test] +async fn multiset_intersect_all_consume() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/intersect_multiset_all.substrait.json"); + + assert_substrait_sql( + proto_plan, + "SELECT a FROM data INTERSECT ALL SELECT a FROM data2 INTERSECT ALL SELECT a FROM data2", + ) + .await +} + +#[tokio::test] +async fn primary_except_consume() -> Result<()> { + let proto_plan = read_json("tests/testdata/test_plans/minus_primary.substrait.json"); + + assert_substrait_sql( + proto_plan, + "SELECT a FROM data EXCEPT SELECT a FROM data2 EXCEPT SELECT a FROM data2", + ) + .await +} + +#[tokio::test] +async fn primary_except_all_consume() -> Result<()> { + let proto_plan = + read_json("tests/testdata/test_plans/minus_primary_all.substrait.json"); + + assert_substrait_sql( + proto_plan, + "SELECT a FROM data EXCEPT ALL SELECT a FROM data2 EXCEPT ALL SELECT a FROM data2", + ) + .await +} + +#[tokio::test] +async fn union_distinct_consume() -> Result<()> { + let proto_plan = read_json("tests/testdata/test_plans/union_distinct.substrait.json"); + + assert_substrait_sql(proto_plan, "SELECT a FROM data UNION SELECT a FROM data2").await +} + #[tokio::test] async fn simple_intersect_table_reuse() -> Result<()> { // Substrait does currently NOT maintain the alias of the tables. diff --git a/datafusion/substrait/tests/testdata/test_plans/intersect_multiset.substrait.json b/datafusion/substrait/tests/testdata/test_plans/intersect_multiset.substrait.json new file mode 100644 index 000000000000..8ff69bd82c3a --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/intersect_multiset.substrait.json @@ -0,0 +1,166 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_INTERSECTION_MULTISET" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/intersect_multiset_all.substrait.json b/datafusion/substrait/tests/testdata/test_plans/intersect_multiset_all.substrait.json new file mode 100644 index 000000000000..56daf6ed46f4 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/intersect_multiset_all.substrait.json @@ -0,0 +1,166 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_INTERSECTION_MULTISET_ALL" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/intersect_primary.substrait.json b/datafusion/substrait/tests/testdata/test_plans/intersect_primary.substrait.json new file mode 100644 index 000000000000..229dd7251705 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/intersect_primary.substrait.json @@ -0,0 +1,166 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_INTERSECTION_PRIMARY" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/minus_primary.substrait.json b/datafusion/substrait/tests/testdata/test_plans/minus_primary.substrait.json new file mode 100644 index 000000000000..33b0e2ab8c80 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/minus_primary.substrait.json @@ -0,0 +1,166 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_MINUS_PRIMARY" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/minus_primary_all.substrait.json b/datafusion/substrait/tests/testdata/test_plans/minus_primary_all.substrait.json new file mode 100644 index 000000000000..229f78ab5bf6 --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/minus_primary_all.substrait.json @@ -0,0 +1,166 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_MINUS_PRIMARY_ALL" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } + } \ No newline at end of file diff --git a/datafusion/substrait/tests/testdata/test_plans/union_distinct.substrait.json b/datafusion/substrait/tests/testdata/test_plans/union_distinct.substrait.json new file mode 100644 index 000000000000..e8b02749660d --- /dev/null +++ b/datafusion/substrait/tests/testdata/test_plans/union_distinct.substrait.json @@ -0,0 +1,118 @@ +{ + "relations": [ + { + "root": { + "input": { + "set": { + "inputs": [ + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_NULLABLE" + } + }, + "namedTable": { + "names": [ + "data2" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + } + ], + "op": "SET_OP_UNION_DISTINCT" + } + }, + "names": [ + "a" + ] + } + } + ], + "version": { + "minorNumber": 54, + "producer": "subframe" + } +} \ No newline at end of file diff --git a/datafusion/substrait/tests/utils.rs b/datafusion/substrait/tests/utils.rs index 9f63b74ef0fc..00cbfb0c412c 100644 --- a/datafusion/substrait/tests/utils.rs +++ b/datafusion/substrait/tests/utils.rs @@ -147,6 +147,7 @@ pub mod test { Ok(()) } + #[allow(deprecated)] fn collect_schemas_from_rel(&mut self, rel: &Rel) -> Result<()> { let rel_type = rel .rel_type