Skip to content

Implement TPCH substrait integration test, support tpch_2 #11234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 94 additions & 48 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,38 @@
// specific language governing permissions and limitations
// under the License.

use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
use async_recursion::async_recursion;
use datafusion::arrow::array::GenericListArray;
use datafusion::arrow::datatypes::{
DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit,
};
use datafusion::common::plan_err;
use datafusion::common::{
not_impl_datafusion_err, not_impl_err, plan_datafusion_err, plan_err,
substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef,
not_impl_datafusion_err, not_impl_err, plan_datafusion_err, substrait_datafusion_err,
substrait_err, DFSchema, DFSchemaRef,
};
use substrait::proto::expression::literal::IntervalDayToSecond;
use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile;
use url::Url;

use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::expr::{InSubquery, Sort};

use datafusion::logical_expr::{
aggregate_function, expr::find_df_window_func, Aggregate, BinaryExpr, Case,
EmptyRelation, Expr, ExprSchemable, LogicalPlan, Operator, Projection, Values,
};
use url::Url;

use crate::variation_const::{
DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF,
DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF,
DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF,
INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF,
INTERVAL_YEAR_MONTH_TYPE_REF, LARGE_CONTAINER_TYPE_VARIATION_REF,
TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF,
TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF,
UNSIGNED_INTEGER_TYPE_VARIATION_REF,
};
use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::logical_expr::expr::InList;
use datafusion::logical_expr::{
col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning,
Repartition, Subquery, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
Expand All @@ -46,10 +59,15 @@ use datafusion::{
prelude::{Column, SessionContext},
scalar::ScalarValue,
};
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;
use substrait::proto::exchange_rel::ExchangeKind;
use substrait::proto::expression::literal::user_defined::Val;
use substrait::proto::expression::literal::IntervalDayToSecond;
use substrait::proto::expression::subquery::SubqueryType;
use substrait::proto::expression::{self, FieldReference, Literal, ScalarFunction};
use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile;
use substrait::proto::{
aggregate_function::AggregationInvocation,
expression::{
Expand All @@ -70,24 +88,6 @@ use substrait::proto::{
};
use substrait::proto::{FunctionArgument, SortField};

use datafusion::arrow::array::GenericListArray;
use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::logical_expr::expr::{InList, InSubquery, Sort};
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;

use crate::variation_const::{
DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF,
DECIMAL_128_TYPE_VARIATION_REF, DECIMAL_256_TYPE_VARIATION_REF,
DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF,
INTERVAL_DAY_TIME_TYPE_REF, INTERVAL_MONTH_DAY_NANO_TYPE_REF,
INTERVAL_YEAR_MONTH_TYPE_REF, LARGE_CONTAINER_TYPE_VARIATION_REF,
TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF,
TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF,
UNSIGNED_INTEGER_TYPE_VARIATION_REF,
};

pub fn name_to_op(name: &str) -> Result<Operator> {
match name {
"equal" => Ok(Operator::Eq),
Expand Down Expand Up @@ -1125,17 +1125,32 @@ pub async fn from_substrait_rex(
expr::ScalarFunction::new_udf(func.to_owned(), args),
)))
} else if let Ok(op) = name_to_op(fn_name) {
if args.len() != 2 {
if f.arguments.len() < 2 {
return not_impl_err!(
"Expect two arguments for binary operator {op:?}"
"Expect at least two arguments for binary operator {op:?}, the provided number of operators is {:?}",
f.arguments.len()
);
}
// Some expressions are binary in DataFusion but take in a variadic number of args in Substrait.
// In those cases we iterate through all the arguments, applying the binary expression against them all
let combined_expr = args
.into_iter()
.fold(None, |combined_expr: Option<Arc<Expr>>, arg: Expr| {
Some(match combined_expr {
Some(expr) => Arc::new(Expr::BinaryExpr(BinaryExpr {
left: Box::new(
Arc::try_unwrap(expr)
.unwrap_or_else(|arc: Arc<Expr>| (*arc).clone()),
), // Avoid cloning if possible
op: op.clone(),
right: Box::new(arg),
})),
None => Arc::new(arg),
})
})
.unwrap();

Ok(Arc::new(Expr::BinaryExpr(BinaryExpr {
left: Box::new(args[0].to_owned()),
op,
right: Box::new(args[1].to_owned()),
})))
Ok(combined_expr)
} else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) {
builder.build(ctx, f, input_schema, extensions).await
} else {
Expand Down Expand Up @@ -1269,7 +1284,22 @@ pub async fn from_substrait_rex(
}
}
}
_ => substrait_err!("Subquery type not implemented"),
SubqueryType::Scalar(query) => {
let plan = from_substrait_rel(
ctx,
&(query.input.clone()).unwrap_or_default(),
extensions,
)
.await?;
let outer_ref_columns = plan.all_out_ref_exprs();
Ok(Arc::new(Expr::ScalarSubquery(Subquery {
subquery: Arc::new(plan),
outer_ref_columns,
})))
}
other_type => {
substrait_err!("Subquery type {:?} not implemented", other_type)
}
},
None => {
substrait_err!("Subquery experssion without SubqueryType is not allowed")
Expand Down Expand Up @@ -1699,6 +1729,7 @@ fn from_substrait_literal(
})) => {
ScalarValue::new_interval_dt(*days, (seconds * 1000) + (microseconds / 1000))
}
Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())),
Some(LiteralType::UserDefined(user_defined)) => {
match user_defined.type_reference {
INTERVAL_YEAR_MONTH_TYPE_REF => {
Expand Down Expand Up @@ -1988,8 +2019,8 @@ impl BuiltinExprBuilder {
extensions: &HashMap<u32, &String>,
) -> Result<Arc<Expr>> {
let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" };
if f.arguments.len() != 3 {
return substrait_err!("Expect three arguments for `{fn_name}` expr");
if f.arguments.len() != 2 && f.arguments.len() != 3 {
return substrait_err!("Expect two or three arguments for `{fn_name}` expr");
}

let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else {
Expand All @@ -2007,25 +2038,40 @@ impl BuiltinExprBuilder {
.await?
.as_ref()
.clone();
let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type else {
return substrait_err!("Invalid arguments type for `{fn_name}` expr");
};
let escape_char_expr =
from_substrait_rex(ctx, escape_char_substrait, input_schema, extensions)
.await?
.as_ref()
.clone();
let Expr::Literal(ScalarValue::Utf8(escape_char)) = escape_char_expr else {
return substrait_err!(
"Expect Utf8 literal for escape char, but found {escape_char_expr:?}"
);

// Default case: escape character is Literal(Utf8(None))
let escape_char = if f.arguments.len() == 3 {
let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type
else {
return substrait_err!("Invalid arguments type for `{fn_name}` expr");
};

let escape_char_expr =
from_substrait_rex(ctx, escape_char_substrait, input_schema, extensions)
.await?
.as_ref()
.clone();

match escape_char_expr {
Expr::Literal(ScalarValue::Utf8(escape_char_string)) => {
// Convert Option<String> to Option<char>
escape_char_string.and_then(|s| s.chars().next())
}
_ => {
return substrait_err!(
"Expect Utf8 literal for escape char, but found {escape_char_expr:?}"
)
}
}
} else {
None
};

Ok(Arc::new(Expr::Like(Like {
negated: false,
expr: Box::new(expr),
pattern: Box::new(pattern),
escape_char: escape_char.map(|c| c.chars().next().unwrap()),
escape_char,
case_insensitive,
})))
}
Expand Down
93 changes: 83 additions & 10 deletions datafusion/substrait/tests/cases/consumer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,51 @@ mod tests {
use std::io::BufReader;
use substrait::proto::Plan;

/// Registers the CSV file at `file_path` as a table named `table_name`
/// in the given session context, using default CSV read options.
async fn register_csv(
    ctx: &SessionContext,
    table_name: &str,
    file_path: &str,
) -> Result<()> {
    let options = CsvReadOptions::default();
    ctx.register_csv(table_name, file_path, options).await
}

/// Builds a `SessionContext` with every table referenced by the TPC-H
/// query 2 substrait plan registered under its placeholder name.
async fn create_context_tpch2() -> Result<SessionContext> {
    let ctx = SessionContext::new();

    // CSV files backing FILENAME_PLACEHOLDER_0 .. FILENAME_PLACEHOLDER_8, in
    // order. Several files appear more than once because the plan scans the
    // same table under distinct placeholder names.
    let files = [
        "tests/testdata/tpch/part.csv",
        "tests/testdata/tpch/supplier.csv",
        "tests/testdata/tpch/partsupp.csv",
        "tests/testdata/tpch/nation.csv",
        "tests/testdata/tpch/region.csv",
        "tests/testdata/tpch/partsupp.csv",
        "tests/testdata/tpch/supplier.csv",
        "tests/testdata/tpch/nation.csv",
        "tests/testdata/tpch/region.csv",
    ];

    for (index, file_path) in files.iter().enumerate() {
        let table_name = format!("FILENAME_PLACEHOLDER_{index}");
        register_csv(&ctx, &table_name, file_path).await?;
    }

    Ok(ctx)
}

/// Builds a `SessionContext` with the single lineitem table required by
/// the TPC-H query 1 substrait plan.
async fn create_context_tpch1() -> Result<SessionContext> {
    let ctx = SessionContext::new();
    let lineitem_path = "tests/testdata/tpch/lineitem.csv";
    register_csv(&ctx, "FILENAME_PLACEHOLDER_0", lineitem_path).await?;
    Ok(ctx)
}

#[tokio::test]
async fn tpch_test_1() -> Result<()> {
let ctx = create_context().await?;
let ctx = create_context_tpch1().await?;
let path = "tests/testdata/tpch_substrait_plans/query_1.json";
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
Expand All @@ -56,14 +98,45 @@ mod tests {
Ok(())
}

async fn create_context() -> datafusion::common::Result<SessionContext> {
let ctx = SessionContext::new();
ctx.register_csv(
"FILENAME_PLACEHOLDER_0",
"tests/testdata/tpch/lineitem.csv",
CsvReadOptions::default(),
)
.await?;
Ok(ctx)
// Consumes the substrait JSON plan for TPC-H query 2 and checks that the
// resulting DataFusion logical plan matches the expected textual form,
// including the correlated scalar subquery (MIN(ps_supplycost)).
#[tokio::test]
async fn tpch_test_2() -> Result<()> {
// Context with all nine placeholder tables the plan references.
let ctx = create_context_tpch2().await?;
let path = "tests/testdata/tpch_substrait_plans/query_2.json";
// Deserialize the substrait plan protobuf from its JSON encoding.
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
))
.expect("failed to parse json");

// Convert substrait -> DataFusion logical plan and compare the Debug
// rendering against the golden string below.
let plan = from_substrait_plan(&ctx, &proto).await?;
let plan_str = format!("{:?}", plan);
assert_eq!(
plan_str,
"Projection: FILENAME_PLACEHOLDER_1.s_acctbal AS S_ACCTBAL, FILENAME_PLACEHOLDER_1.s_name AS S_NAME, FILENAME_PLACEHOLDER_3.n_name AS N_NAME, FILENAME_PLACEHOLDER_0.p_partkey AS P_PARTKEY, FILENAME_PLACEHOLDER_0.p_mfgr AS P_MFGR, FILENAME_PLACEHOLDER_1.s_address AS S_ADDRESS, FILENAME_PLACEHOLDER_1.s_phone AS S_PHONE, FILENAME_PLACEHOLDER_1.s_comment AS S_COMMENT\
\n  Limit: skip=0, fetch=100\
\n    Sort: FILENAME_PLACEHOLDER_1.s_acctbal DESC NULLS FIRST, FILENAME_PLACEHOLDER_3.n_name ASC NULLS LAST, FILENAME_PLACEHOLDER_1.s_name ASC NULLS LAST, FILENAME_PLACEHOLDER_0.p_partkey ASC NULLS LAST\
\n      Projection: FILENAME_PLACEHOLDER_1.s_acctbal, FILENAME_PLACEHOLDER_1.s_name, FILENAME_PLACEHOLDER_3.n_name, FILENAME_PLACEHOLDER_0.p_partkey, FILENAME_PLACEHOLDER_0.p_mfgr, FILENAME_PLACEHOLDER_1.s_address, FILENAME_PLACEHOLDER_1.s_phone, FILENAME_PLACEHOLDER_1.s_comment\
\n        Filter: FILENAME_PLACEHOLDER_0.p_partkey = FILENAME_PLACEHOLDER_2.ps_partkey AND FILENAME_PLACEHOLDER_1.s_suppkey = FILENAME_PLACEHOLDER_2.ps_suppkey AND FILENAME_PLACEHOLDER_0.p_size = Int32(15) AND FILENAME_PLACEHOLDER_0.p_type LIKE CAST(Utf8(\"%BRASS\") AS Utf8) AND FILENAME_PLACEHOLDER_1.s_nationkey = FILENAME_PLACEHOLDER_3.n_nationkey AND FILENAME_PLACEHOLDER_3.n_regionkey = FILENAME_PLACEHOLDER_4.r_regionkey AND FILENAME_PLACEHOLDER_4.r_name = CAST(Utf8(\"EUROPE\") AS Utf8) AND FILENAME_PLACEHOLDER_2.ps_supplycost = (<subquery>)\
\n          Subquery:\
\n            Aggregate: groupBy=[[]], aggr=[[MIN(FILENAME_PLACEHOLDER_5.ps_supplycost)]]\
\n              Projection: FILENAME_PLACEHOLDER_5.ps_supplycost\
\n                Filter: FILENAME_PLACEHOLDER_5.ps_partkey = FILENAME_PLACEHOLDER_5.ps_partkey AND FILENAME_PLACEHOLDER_6.s_suppkey = FILENAME_PLACEHOLDER_5.ps_suppkey AND FILENAME_PLACEHOLDER_6.s_nationkey = FILENAME_PLACEHOLDER_7.n_nationkey AND FILENAME_PLACEHOLDER_7.n_regionkey = FILENAME_PLACEHOLDER_8.r_regionkey AND FILENAME_PLACEHOLDER_8.r_name = CAST(Utf8(\"EUROPE\") AS Utf8)\
\n                  Inner Join:  Filter: Boolean(true)\
\n                    Inner Join:  Filter: Boolean(true)\
\n                      Inner Join:  Filter: Boolean(true)\
\n                        TableScan: FILENAME_PLACEHOLDER_5 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\
\n                        TableScan: FILENAME_PLACEHOLDER_6 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\
\n                      TableScan: FILENAME_PLACEHOLDER_7 projection=[n_nationkey, n_name, n_regionkey, n_comment]\
\n                    TableScan: FILENAME_PLACEHOLDER_8 projection=[r_regionkey, r_name, r_comment]\
\n          Inner Join:  Filter: Boolean(true)\
\n            Inner Join:  Filter: Boolean(true)\
\n              Inner Join:  Filter: Boolean(true)\
\n                Inner Join:  Filter: Boolean(true)\
\n                  TableScan: FILENAME_PLACEHOLDER_0 projection=[p_partkey, p_name, p_mfgr, p_brand, p_type, p_size, p_container, p_retailprice, p_comment]\
\n                  TableScan: FILENAME_PLACEHOLDER_1 projection=[s_suppkey, s_name, s_address, s_nationkey, s_phone, s_acctbal, s_comment]\
\n                TableScan: FILENAME_PLACEHOLDER_2 projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost, ps_comment]\
\n              TableScan: FILENAME_PLACEHOLDER_3 projection=[n_nationkey, n_name, n_regionkey, n_comment]\
\n            TableScan: FILENAME_PLACEHOLDER_4 projection=[r_regionkey, r_name, r_comment]"
);
Ok(())
}
}
2 changes: 2 additions & 0 deletions datafusion/substrait/tests/testdata/tpch/nation.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
n_nationkey,n_name,n_regionkey,n_comment
0,ALGERIA,0, haggle. carefully final deposits detect slyly agai
2 changes: 2 additions & 0 deletions datafusion/substrait/tests/testdata/tpch/part.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
p_partkey,p_name,p_mfgr,p_brand,p_type,p_size,p_container,p_retailprice,p_comment
1,pink powder puff,Manufacturer#1,Brand#13,SMALL PLATED COPPER,7,JUMBO PKG,901.00,ly final dependencies: slyly bold
2 changes: 2 additions & 0 deletions datafusion/substrait/tests/testdata/tpch/partsupp.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ps_partkey,ps_suppkey,ps_availqty,ps_supplycost,ps_comment
1,1,1000,50.00,slyly final packages boost against the slyly regular
2 changes: 2 additions & 0 deletions datafusion/substrait/tests/testdata/tpch/region.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
r_regionkey,r_name,r_comment
0,AFRICA,lar deposits. blithely final packages cajole. regular waters are final requests. regular accounts are according to
2 changes: 2 additions & 0 deletions datafusion/substrait/tests/testdata/tpch/supplier.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
s_suppkey,s_name,s_address,s_nationkey,s_phone,s_acctbal,s_comment
1,Supplier#1,123 Main St,0,555-1234,1000.00,No comments
Loading