Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/substrait/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ object_store = { workspace = true }
pbjson-types = "0.7"
# TODO use workspace version
prost = "0.13"
substrait = { version = "0.42", features = ["serde"] }
substrait = { version = "0.45", features = ["serde"] }
url = { workspace = true }

[dev-dependencies]
Expand Down
120 changes: 109 additions & 11 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,65 @@ fn split_eq_and_noneq_join_predicate_with_nulls_equality(
(accum_join_keys, nulls_equal_nulls, join_filter)
}

async fn union_rels(
rels: &[Rel],
ctx: &SessionContext,
extensions: &Extensions,
is_all: bool,
) -> Result<LogicalPlan> {
let mut union_builder = Ok(LogicalPlanBuilder::from(
from_substrait_rel(ctx, &rels[0], extensions).await?,
));
for input in &rels[1..] {
let rel_plan = from_substrait_rel(ctx, input, extensions).await?;

union_builder = if is_all {
union_builder?.union(rel_plan)
} else {
union_builder?.union_distinct(rel_plan)
};
}
union_builder?.build()
}

async fn intersect_rels(
rels: &[Rel],
ctx: &SessionContext,
extensions: &Extensions,
is_all: bool,
) -> Result<LogicalPlan> {
let mut rel = from_substrait_rel(ctx, &rels[0], extensions).await?;

for input in &rels[1..] {
rel = LogicalPlanBuilder::intersect(
rel,
from_substrait_rel(ctx, input, extensions).await?,
is_all,
)?
}

Ok(rel)
}

async fn except_rels(
rels: &[Rel],
ctx: &SessionContext,
extensions: &Extensions,
is_all: bool,
) -> Result<LogicalPlan> {
let mut rel = from_substrait_rel(ctx, &rels[0], extensions).await?;

for input in &rels[1..] {
rel = LogicalPlanBuilder::except(
rel,
from_substrait_rel(ctx, input, extensions).await?,
is_all,
)?
}

Ok(rel)
}

/// Convert Substrait Plan to DataFusion LogicalPlan
pub async fn from_substrait_plan(
ctx: &SessionContext,
Expand Down Expand Up @@ -513,6 +572,7 @@ fn make_renamed_schema(
}

/// Convert Substrait Rel to DataFusion DataFrame
#[allow(deprecated)]
#[async_recursion]
pub async fn from_substrait_rel(
ctx: &SessionContext,
Expand Down Expand Up @@ -875,27 +935,65 @@ pub async fn from_substrait_rel(
Ok(set_op) => match set_op {
set_rel::SetOp::UnionAll => {
if !set.inputs.is_empty() {
let mut union_builder = Ok(LogicalPlanBuilder::from(
from_substrait_rel(ctx, &set.inputs[0], extensions).await?,
));
for input in &set.inputs[1..] {
union_builder = union_builder?
.union(from_substrait_rel(ctx, input, extensions).await?);
}
union_builder?.build()
union_rels(&set.inputs, ctx, extensions, true).await
} else {
not_impl_err!("Union relation requires at least one input")
}
}
set_rel::SetOp::UnionDistinct => {
if !set.inputs.is_empty() {
union_rels(&set.inputs, ctx, extensions, false).await
} else {
not_impl_err!("Union relation requires at least one input")
}
}
set_rel::SetOp::IntersectionPrimary => {
if set.inputs.len() == 2 {
if set.inputs.len() >= 2 {
LogicalPlanBuilder::intersect(
from_substrait_rel(ctx, &set.inputs[0], extensions).await?,
from_substrait_rel(ctx, &set.inputs[1], extensions).await?,
union_rels(&set.inputs[1..], ctx, extensions, true).await?,
false,
)
} else {
not_impl_err!("Primary Intersect relation with more than two inputs isn't supported")
not_impl_err!(
"Primary Intersect relation requires at least two inputs"
)
}
}
set_rel::SetOp::IntersectionMultiset => {
if set.inputs.len() >= 2 {
intersect_rels(&set.inputs, ctx, extensions, false).await
} else {
not_impl_err!(
"Multiset Intersect relation requires at least two inputs"
)
}
}
set_rel::SetOp::IntersectionMultisetAll => {
if set.inputs.len() >= 2 {
intersect_rels(&set.inputs, ctx, extensions, true).await
} else {
not_impl_err!(
"MultisetAll Intersect relation requires at least two inputs"
)
}
}
set_rel::SetOp::MinusPrimary => {
if set.inputs.len() >= 2 {
except_rels(&set.inputs, ctx, extensions, false).await
} else {
not_impl_err!(
"Primary Minus relation requires at least two inputs"
)
}
}
set_rel::SetOp::MinusPrimaryAll => {
if set.inputs.len() >= 2 {
except_rels(&set.inputs, ctx, extensions, true).await
} else {
not_impl_err!(
"PrimaryAll Minus relation requires at least two inputs"
)
}
}
_ => not_impl_err!("Unsupported set operator: {set_op:?}"),
Expand Down
12 changes: 11 additions & 1 deletion datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ pub fn to_substrait_extended_expr(
}

/// Convert DataFusion LogicalPlan to Substrait Rel
#[allow(deprecated)]
pub fn to_substrait_rel(
plan: &LogicalPlan,
ctx: &SessionContext,
Expand Down Expand Up @@ -227,6 +228,7 @@ pub fn to_substrait_rel(
advanced_extension: None,
read_type: Some(ReadType::VirtualTable(VirtualTable {
values: vec![],
expressions: vec![],
})),
}))),
}))
Expand Down Expand Up @@ -263,7 +265,10 @@ pub fn to_substrait_rel(
best_effort_filter: None,
projection: None,
advanced_extension: None,
read_type: Some(ReadType::VirtualTable(VirtualTable { values })),
read_type: Some(ReadType::VirtualTable(VirtualTable {
values,
expressions: vec![],
})),
}))),
}))
}
Expand Down Expand Up @@ -359,6 +364,7 @@ pub fn to_substrait_rel(
rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
common: None,
input: Some(input),
grouping_expressions: vec![],
groupings,
measures,
advanced_extension: None,
Expand All @@ -377,8 +383,10 @@ pub fn to_substrait_rel(
rel_type: Some(RelType::Aggregate(Box::new(AggregateRel {
common: None,
input: Some(input),
grouping_expressions: vec![],
groupings: vec![Grouping {
grouping_expressions: grouping,
expression_references: vec![],
}],
measures: vec![],
advanced_extension: None,
Expand Down Expand Up @@ -764,6 +772,7 @@ pub fn operator_to_name(op: Operator) -> &'static str {
}
}

#[allow(deprecated)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this deprecated use required for changes in the substrait dependency?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so.. grouping_expressions field was deprecated in the proto definitions, but should still be used for some time not to break backwards-compatibility. Maybe there's some other way.. I just saw similar deprecated markers and followed suit.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it's not a rush, and might be better to stick with the old approach for now for producers until consumers have had a chance to catch up.

The new approach puts the grouping_expressions in the AggregateRel (so in to_substrait_agg_measure) and not in the Grouping. Instead, the Grouping has a list of indices into the AggregateRel's grouping_expressions.

E.g. instead of...

AggregateRel = {
  "groupings": [
    { "grouping_expressions": [expr_1, epxr_2] }
  ]
}

You would have...

AggregateRel = {
  "grouping_expressions": [expr_1, expr_2],
  "groupings": [
    { "expression_references": [0, 1] }
   ]
}

This makes it easier to recognize something like a rollup:

AggregateRel = {
  "grouping_expressions": [expr_1, expr_2],
  "groupings": [
    { "expression_references": [0, 1] },
    { "expression_references": [0] },
    { "expression_references": [] }
   ]
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so.. grouping_expressions field was deprecated in the proto definitions, but should still be used for some time not to break backwards-compatibility.

+1 for retaining both for backwards compatability

I've created #12957 to track that we deferred implementing this.

pub fn parse_flat_grouping_exprs(
ctx: &SessionContext,
exprs: &[Expr],
Expand All @@ -776,6 +785,7 @@ pub fn parse_flat_grouping_exprs(
.collect::<Result<Vec<_>>>()?;
Ok(Grouping {
grouping_expressions,
expression_references: vec![],
})
}

Expand Down
66 changes: 66 additions & 0 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,72 @@ async fn simple_intersect_consume() -> Result<()> {
.await
}

#[tokio::test]
async fn primary_intersect_consume() -> Result<()> {
let proto_plan =
read_json("tests/testdata/test_plans/intersect_primary.substrait.json");

assert_substrait_sql(
proto_plan,
"SELECT a FROM data INTERSECT (SELECT a FROM data2 UNION ALL SELECT a FROM data2)",
)
.await
}

#[tokio::test]
async fn multiset_intersect_consume() -> Result<()> {
let proto_plan =
read_json("tests/testdata/test_plans/intersect_multiset.substrait.json");

assert_substrait_sql(
proto_plan,
"SELECT a FROM data INTERSECT SELECT a FROM data2 INTERSECT SELECT a FROM data2",
)
.await
}

#[tokio::test]
async fn multiset_intersect_all_consume() -> Result<()> {
let proto_plan =
read_json("tests/testdata/test_plans/intersect_multiset_all.substrait.json");

assert_substrait_sql(
proto_plan,
"SELECT a FROM data INTERSECT ALL SELECT a FROM data2 INTERSECT ALL SELECT a FROM data2",
)
.await
}

#[tokio::test]
async fn primary_except_consume() -> Result<()> {
let proto_plan = read_json("tests/testdata/test_plans/minus_primary.substrait.json");

assert_substrait_sql(
proto_plan,
"SELECT a FROM data EXCEPT SELECT a FROM data2 EXCEPT SELECT a FROM data2",
)
.await
}

#[tokio::test]
async fn primary_except_all_consume() -> Result<()> {
let proto_plan =
read_json("tests/testdata/test_plans/minus_primary_all.substrait.json");

assert_substrait_sql(
proto_plan,
"SELECT a FROM data EXCEPT ALL SELECT a FROM data2 EXCEPT ALL SELECT a FROM data2",
)
.await
}

#[tokio::test]
async fn union_distinct_consume() -> Result<()> {
let proto_plan = read_json("tests/testdata/test_plans/union_distinct.substrait.json");

assert_substrait_sql(proto_plan, "SELECT a FROM data UNION SELECT a FROM data2").await
}

#[tokio::test]
async fn simple_intersect_table_reuse() -> Result<()> {
// Substrait does currently NOT maintain the alias of the tables.
Expand Down
Loading