diff --git a/Cargo.toml b/Cargo.toml
index 51ae9ef..41d63ca 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -11,4 +11,4 @@ serde = { version = "1.0.217", features = ["derive"] }
 futures = { version = "0.3.31" }
 log = "0.4.22"
 env_logger = "0.11.6"
-reqwest = { version = "0.12.12", features = ["json"] }
+reqwest = { version = "0.12.12", features = ["json"] }
\ No newline at end of file
diff --git a/src/llm/async_func.rs b/src/llm/async_func.rs
new file mode 100644
index 0000000..54b58da
--- /dev/null
+++ b/src/llm/async_func.rs
@@ -0,0 +1,87 @@
+use crate::llm::functions::AsyncScalarUDF;
+use datafusion::arrow::array::{ArrayRef, RecordBatch};
+use datafusion::arrow::datatypes::{Field, Schema};
+use datafusion::common::{internal_err, not_impl_err, Result};
+use datafusion::logical_expr::ScalarUDF;
+use datafusion::physical_expr::{PhysicalExpr, ScalarFunctionExpr};
+use std::fmt::Display;
+use std::sync::Arc;
+
+/// Wrapper for an async function that can be used in a DataFusion query
+#[derive(Debug, Clone)]
+pub struct AsyncFuncExpr {
+    /// The name of the output column this function will generate
+    pub name: String,
+    /// The actual function (always `ScalarFunctionExpr`)
+    pub func: Arc<dyn PhysicalExpr>,
+}
+
+impl Display for AsyncFuncExpr {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "async_expr(name={}, expr={})", self.name, self.func)
+    }
+}
+
+impl AsyncFuncExpr {
+    /// create a new AsyncFuncExpr
+    pub fn new(name: impl Into<String>, func: Arc<dyn PhysicalExpr>) -> Self {
+        Self {
+            name: name.into(),
+            func,
+        }
+    }
+
+    /// return true if this is an async function
+    pub fn is_async_func(func: &ScalarUDF) -> bool {
+        func.inner()
+            .as_any()
+            .downcast_ref::<AsyncScalarUDF>()
+            .is_some()
+    }
+
+    /// return the name of the output column
+    pub fn name(&self) -> &str {
+        &self.name
+    }
+
+    /// Return the output field generated by evaluating this function
+    pub fn field(&self, input_schema: &Schema) -> Field {
+        Field::new(
+            &self.name,
+            self.func.data_type(input_schema).unwrap(),
+            self.func.nullable(input_schema).unwrap(),
+        )
+    }
+
+    /// This (async) function is called for each record batch to evaluate the LLM expressions
+    ///
+    /// The output is the result of evaluating the LLM expression against the input record batch
+    pub async fn invoke_async(&self, batch: &RecordBatch) -> Result<ArrayRef> {
+        let Some(llm_function) = self.func.as_any().downcast_ref::<ScalarFunctionExpr>() else {
+            return internal_err!(
+                "unexpected function type, expected ScalarFunctionExpr, got: {:?}",
+                self.func
+            );
+        };
+
+        let Some(async_udf) = llm_function
+            .fun()
+            .inner()
+            .as_any()
+            .downcast_ref::<AsyncScalarUDF>()
+        else {
+            return not_impl_err!(
+                "Don't know how to evaluate async function: {:?}",
+                llm_function
+            );
+        };
+
+        async_udf.invoke_async(batch).await
+    }
+}
+
+impl PartialEq<Arc<dyn PhysicalExpr>> for AsyncFuncExpr {
+    fn eq(&self, other: &Arc<dyn PhysicalExpr>) -> bool {
+        &self.func == other
+    }
+}
diff --git a/src/llm/exec.rs b/src/llm/exec.rs
new file mode 100644
index 0000000..d864ff3
--- /dev/null
+++ b/src/llm/exec.rs
@@ -0,0 +1,166 @@
+use crate::llm::async_func::AsyncFuncExpr;
+use datafusion::arrow::datatypes::{Fields, Schema, SchemaRef};
+use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::common::Result;
+use datafusion::execution::{SendableRecordBatchStream, TaskContext};
+use datafusion::physical_expr::EquivalenceProperties;
+use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
+use datafusion::physical_plan::{
+    DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning,
+    PlanProperties,
+};
+use futures::stream::StreamExt;
+use log::trace;
+use std::any::Any;
+use std::sync::Arc;
+
+/// This structure evaluates a set of async expressions on a record
+/// batch, producing a new record batch
+///
+/// This is similar to a ProjectionExec except that the functions can be async
+///
+/// The schema of the output of the AsyncFuncExec is:
+/// the input columns followed by one column for each async expression
+#[derive(Debug)]
+pub struct AsyncFuncExec {
+    /// The async expressions to evaluate
+    async_exprs: Vec<AsyncFuncExpr>,
+    input: Arc<dyn ExecutionPlan>,
+    /// Cache holding plan properties like equivalences, output partitioning etc.
+    cache: PlanProperties,
+    metrics: ExecutionPlanMetricsSet,
+}
+
+impl AsyncFuncExec {
+    pub fn new(async_exprs: Vec<AsyncFuncExpr>, input: Arc<dyn ExecutionPlan>) -> Self {
+        // compute the output schema: input schema then async expressions
+        let fields: Fields = input
+            .schema()
+            .fields()
+            .iter()
+            .cloned()
+            .chain(
+                async_exprs
+                    .iter()
+                    .map(|async_expr| Arc::new(async_expr.field(input.schema().as_ref()))),
+            )
+            .collect();
+        let schema = Arc::new(Schema::new(fields));
+        let cache = AsyncFuncExec::compute_properties(&input, schema).unwrap();
+
+        Self {
+            input,
+            async_exprs,
+            cache,
+            metrics: ExecutionPlanMetricsSet::new(),
+        }
+    }
+
+    /// This function creates the cache object that stores the plan properties
+    /// such as schema, equivalence properties, ordering, partitioning, etc.
+    fn compute_properties(
+        input: &Arc<dyn ExecutionPlan>,
+        schema: SchemaRef,
+    ) -> Result<PlanProperties> {
+        let eq_properties = EquivalenceProperties::new(schema);
+
+        // TODO: This is a dummy partitioning. We need to figure out the actual partitioning.
+        let output_partitioning = Partitioning::RoundRobinBatch(1);
+
+        Ok(PlanProperties::new(
+            eq_properties,
+            output_partitioning,
+            input.pipeline_behavior(),
+            input.boundedness(),
+        ))
+    }
+}
+
+impl DisplayAs for AsyncFuncExec {
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default | DisplayFormatType::Verbose => {
+                let expr: Vec<String> = self
+                    .async_exprs
+                    .iter()
+                    .map(|async_expr| async_expr.to_string())
+                    .collect();
+
+                write!(f, "AsyncFuncExec: async_expr=[{}]", expr.join(", "))
+            }
+        }
+    }
+}
+
+impl ExecutionPlan for AsyncFuncExec {
+    fn name(&self) -> &str {
+        "async_func"
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.cache
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![&self.input]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        Ok(Arc::new(AsyncFuncExec::new(
+            self.async_exprs.clone(),
+            Arc::clone(&self.input),
+        )))
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        trace!(
+            "Start AsyncFuncExec::execute for partition {} of context session_id {} and task_id {:?}",
+            partition,
+            context.session_id(),
+            context.task_id()
+        );
+        // TODO figure out how to record metrics
+
+        // first execute the input stream
+        let input_stream = self.input.execute(partition, context.clone())?;
+
+        // now, for each record batch, evaluate the async expressions and add the columns to the result
+        let async_exprs_captured = Arc::new(self.async_exprs.clone());
+        let schema_captured = self.schema();
+
+        let stream_with_async_functions = input_stream.then(move |batch| {
+            // need to clone *again* to capture the async_exprs and schema in the
+            // stream and satisfy lifetime requirements.
+            let async_exprs_captured = Arc::clone(&async_exprs_captured);
+            let schema_captured = schema_captured.clone();
+
+            async move {
+                let batch = batch?;
+                // append the result of evaluating the async expressions to the output
+                let mut output_arrays = batch.columns().to_vec();
+                for async_expr in async_exprs_captured.iter() {
+                    let output_array = async_expr.invoke_async(&batch).await?;
+                    output_arrays.push(output_array);
+                }
+                let batch = RecordBatch::try_new(schema_captured, output_arrays)?;
+                Ok(batch)
+            }
+        });
+
+        // Adapt the stream with the output schema
+        let adapter = RecordBatchStreamAdapter::new(self.schema(), stream_with_async_functions);
+        Ok(Box::pin(adapter))
+    }
+}
diff --git a/src/llm/functions.rs b/src/llm/functions.rs
index f234b3c..851a237 100644
--- a/src/llm/functions.rs
+++ b/src/llm/functions.rs
@@ -1,49 +1,82 @@
+use async_trait::async_trait;
+use datafusion::arrow::array::{ArrayRef, RecordBatch};
 use datafusion::arrow::datatypes::DataType;
 use datafusion::common::internal_err;
 use datafusion::common::Result;
-use datafusion::logical_expr::{
-    ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, Volatility,
-};
+use datafusion::logical_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature};
 use std::any::Any;
+use std::fmt::Debug;
+use std::sync::Arc;
 
-/// A scalar UDF that will be bypassed when planning logical plan.
-/// This is used to register the remote function to the context. The function should not be
-/// invoked by DataFusion. It's only used to generate the logical plan and unparsed them to SQL.
+/// A scalar UDF that can be invoked using async methods
+///
+/// Note this is less efficient than ScalarUDFImpl, but it can be used
+/// to register remote functions in the context.
+///
+/// The name is chosen to mirror ScalarUDFImpl
+#[async_trait]
+pub trait AsyncScalarUDFImpl: Debug + Send + Sync {
+    /// the function cast as any
+    fn as_any(&self) -> &dyn Any;
+
+    /// The name of the function
+    fn name(&self) -> &str;
+
+    /// The signature of the function
+    fn signature(&self) -> &Signature;
+
+    /// The return type of the function
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType>;
+
+    /// Invoke the function asynchronously with the async arguments
+    async fn invoke_async(&self, args: &RecordBatch) -> Result<ArrayRef>;
+}
+
+/// A scalar UDF that must be invoked using async methods
+///
+/// Note this is not meant to be used directly, but is meant to be an implementation detail
+/// for AsyncScalarUDFImpl.
+///
+/// This is used to register remote functions in the context. The function
+/// should not be invoked by DataFusion. It's only used to generate the logical
+/// plan and unparse it to SQL.
 #[derive(Debug)]
-pub struct ByPassScalarUDF {
-    name: String,
-    return_type: DataType,
-    signature: Signature,
+pub struct AsyncScalarUDF {
+    inner: Arc<dyn AsyncScalarUDFImpl>,
 }
 
-impl ByPassScalarUDF {
-    pub fn new(name: &str, return_type: DataType) -> Self {
-        Self {
-            name: name.to_string(),
-            return_type,
-            signature: Signature::one_of(
-                vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
-                Volatility::Volatile,
-            ),
-        }
+impl AsyncScalarUDF {
+    pub fn new(inner: Arc<dyn AsyncScalarUDFImpl>) -> Self {
+        Self { inner }
+    }
+
+    /// Turn this AsyncScalarUDF into a ScalarUDF, suitable for
+    /// registering in the context
+    pub fn into_scalar_udf(self) -> Arc<ScalarUDF> {
+        Arc::new(ScalarUDF::new_from_impl(self))
+    }
+
+    /// Invoke the function asynchronously with the async arguments
+    pub async fn invoke_async(&self, args: &RecordBatch) -> Result<ArrayRef> {
+        self.inner.invoke_async(args).await
     }
 }
 
-impl ScalarUDFImpl for ByPassScalarUDF {
+impl ScalarUDFImpl for AsyncScalarUDF {
     fn as_any(&self) -> &dyn Any {
         self
     }
 
     fn name(&self) -> &str {
-        &self.name
+        self.inner.name()
     }
 
     fn signature(&self) -> &Signature {
-        &self.signature
+        self.inner.signature()
     }
 
     fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
-        Ok(self.return_type.clone())
+        self.inner.return_type(_arg_types)
     }
 
     fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
diff --git a/src/llm/logical.rs b/src/llm/logical.rs
deleted file mode 100644
index bdd177b..0000000
--- a/src/llm/logical.rs
+++ /dev/null
@@ -1,214 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */ - -use datafusion::arrow::datatypes::{DataType, Field}; -use datafusion::common::alias::AliasGenerator; -use datafusion::common::tree_node::{Transformed, TransformedResult}; -use datafusion::common::{not_impl_err, plan_err, Column, DFSchema, DFSchemaRef, Result}; -use datafusion::config::ConfigOptions; -use datafusion::logical_expr::{ - col, Aggregate, Expr, Extension, Filter, LogicalPlan, LogicalPlanBuilder, Partitioning, - Projection, UserDefinedLogicalNodeCore, -}; -use datafusion::optimizer::AnalyzerRule; -use std::cmp::Ordering; -use std::collections::HashMap; -use std::fmt::{format, Debug, Formatter}; -use std::sync::Arc; - -#[derive(PartialEq, Eq, Hash, Debug, Clone)] -pub struct LLMPlan { - pub schema: DFSchemaRef, - pub input: Arc, - pub required_columns: Vec, -} - -impl PartialOrd for LLMPlan { - fn partial_cmp(&self, _other: &Self) -> Option { - None - } -} - -impl UserDefinedLogicalNodeCore for LLMPlan { - fn name(&self) -> &str { - "LLM" - } - - fn inputs(&self) -> Vec<&LogicalPlan> { - vec![self.input.as_ref()] - } - - fn schema(&self) -> &DFSchemaRef { - &self.schema - } - - fn expressions(&self) -> Vec { - self.schema - .fields() - .iter() - .map(|field| col(field.name())) - .collect() - } - - fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { - let columns: Vec = self - .required_columns - .iter() - .map(|c| c.name().to_string()) - .collect(); - write!(f, "LLM({})", columns.join(", ")) - } - - fn with_exprs_and_inputs(&self, _exprs: Vec, inputs: Vec) -> Result { - Ok(LLMPlan { - schema: self.schema.clone(), - input: Arc::new(inputs[0].clone()), - required_columns: self.required_columns.clone(), - }) - } -} - -pub struct LLMFunctionAnalyzeRule {} - -impl Debug for LLMFunctionAnalyzeRule { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - f.debug_struct("LLMFunctionAnalyzeRule").finish() - } -} - -impl AnalyzerRule for LLMFunctionAnalyzeRule { - fn analyze(&self, plan: LogicalPlan, config: &ConfigOptions) -> Result { - plan.transform_down_with_subqueries(|plan| self.analyze_llm_function(plan)) - .data() - } - - fn name(&self) -> &str { - "LLMFunctionAnalyzeRule" - } -} - -impl LLMFunctionAnalyzeRule { - fn analyze_llm_function(&self, plan: LogicalPlan) -> Result> { - let alias_generator = AliasGenerator::new(); - let mut collected: HashMap> = HashMap::new(); - - // Find the scalar functions that are semantic functions and collect the required columns - // The first argument is the promotion, the rest are the columns - // e.g. `llm_bool("If all of them are Aisa countries: {}, {}, {}", column1, column2, column3)` - let plan = plan - .map_expressions(|expr| { - if let Expr::ScalarFunction(scalar_function) = &expr { - if scalar_function.name().starts_with("llm_") { - if scalar_function.args.len() == 0 { - return plan_err!("LLM function must have at least one argument"); - } - // let promotion = scalar_function.args[0].clone(); - let columns: Vec<_> = scalar_function.args[1..].iter().cloned().collect(); - if !columns.is_empty() { - collected.insert(alias_generator.next("__llm"), columns); - } - } - } - Ok(Transformed::no(expr)) - })? 
- .data; - - if collected.is_empty() { - return Ok(Transformed::no(plan)); - } - - if collected.len() > 1 { - return not_impl_err!("Only one LLM function is allowed in a query"); - } - // TODO: currently support one LLM function - let first = collected - .get("__llm_1") - .map(|v| v.clone()) - .unwrap_or_default(); - let schema: DFSchemaRef = Arc::new(DFSchema::from_unqualified_fields( - vec![Field::new("c1", DataType::Boolean, false)].into(), - HashMap::new(), - )?); - - let input = input(&plan)?; - let projection = LogicalPlanBuilder::from(input) - .project(first.clone())? - .build()?; - - let first = first - .iter() - .map(|e| { - if let Expr::Column(c) = e { - Ok(c.clone()) - } else { - return not_impl_err!("Only column references are allowed in LLM functions"); - } - }) - .collect::>()?; - - let node = LLMPlan { - schema, - input: Arc::new(projection), - required_columns: first, - }; - let llm_plan = LogicalPlan::Extension(Extension { - node: Arc::new(node), - }); - Ok(Transformed::yes(llm_plan)) - } -} - -fn input(plan: &LogicalPlan) -> Result> { - Ok(match plan { - LogicalPlan::Projection(Projection { input, .. }) => Arc::clone(input), - LogicalPlan::Aggregate(Aggregate { input, .. }) => Arc::clone(input), - LogicalPlan::Filter(Filter { input, .. }) => Arc::clone(input), - _ => return not_impl_err!("Unsupported plan: {:?}", plan), - }) -} - -#[cfg(test)] -mod test { - use crate::llm::functions::ByPassScalarUDF; - use crate::llm::logical::LLMFunctionAnalyzeRule; - use datafusion::common::Result; - use datafusion::execution::{FunctionRegistry, SessionStateBuilder}; - use datafusion::logical_expr::ScalarUDF; - use datafusion::prelude::SessionContext; - use std::sync::Arc; - - #[tokio::test] - async fn test_simple() -> Result<()> { - let mut state = SessionStateBuilder::default() - .with_analyzer_rules(vec![Arc::new(LLMFunctionAnalyzeRule {})]) - .with_optimizer_rules(vec![]) - .build(); - let udf = ByPassScalarUDF::new("llm_bool", datafusion::arrow::datatypes::DataType::Boolean); - state.register_udf(Arc::new(ScalarUDF::new_from_impl(udf)))?; - let ctx = SessionContext::new_with_state(state); - ctx.sql("create table t1 (c1 int, c2 int, c3 int)") - .await? 
-            .show()
-            .await?;
-        let plan = ctx.sql("select llm_bool('If all of them are Aisa countries: {}, {}, {}', c1, c2, c3) from t1")
-            .await?.into_optimized_plan()?;
-        println!("{}", plan);
-        Ok(())
-    }
-}
diff --git a/src/llm/mod.rs b/src/llm/mod.rs
index 666e26c..1b24121 100644
--- a/src/llm/mod.rs
+++ b/src/llm/mod.rs
@@ -1,4 +1,4 @@
+mod async_func;
+pub mod exec;
 pub mod functions;
-pub mod logical;
-pub mod physical_planner;
-pub mod plan;
+pub mod physical_optimizer;
diff --git a/src/llm/physical_optimizer.rs b/src/llm/physical_optimizer.rs
new file mode 100644
index 0000000..5b25fa4
--- /dev/null
+++ b/src/llm/physical_optimizer.rs
@@ -0,0 +1,148 @@
+use crate::llm::async_func::AsyncFuncExpr;
+use crate::llm::exec::AsyncFuncExec;
+use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
+use datafusion::config::ConfigOptions;
+use datafusion::physical_expr::expressions::Column;
+use datafusion::physical_expr::{PhysicalExpr, ScalarFunctionExpr};
+use datafusion::physical_optimizer::PhysicalOptimizerRule;
+use datafusion::physical_plan::projection::ProjectionExec;
+use datafusion::physical_plan::ExecutionPlan;
+use std::sync::Arc;
+
+#[derive(Debug)]
+pub struct AsyncFuncRule {}
+
+impl PhysicalOptimizerRule for AsyncFuncRule {
+    /// Insert an AsyncFuncExec node in front of this projection if there are any async functions in it
+    ///
+    /// For example, if the projection is:
+    /// ```text
+    /// ProjectionExec(["A", "B", llm_func('foo', "C") + 1])
+    /// ```
+    ///
+    /// Rewrite to
+    /// ProjectionExec(["A", "B", "__async_fn_0" + 1]) <-- note here that the async function is not evaluated and is instead a new column
+    ///   AsyncFuncExec(["A", "B", llm_func('foo', "C")])
+    ///
+    fn optimize(
+        &self,
+        plan: Arc<dyn ExecutionPlan>,
+        _config: &ConfigOptions,
+    ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
+        // replace ProjectionExec with an async exec if there are any async functions
+        // TODO: handle other types of ExecutionPlans (like Filter)
+        let Some(proj_exec) = plan.as_any().downcast_ref::<ProjectionExec>() else {
+            return Ok(plan);
+        };
+
+        // find any instances of async functions in the expressions
+        let num_input_columns = proj_exec.input().schema().fields().len();
+        let mut async_map = AsyncMapper::new(num_input_columns);
+        proj_exec.expr().iter().for_each(|(expr, _column_name)| {
+            async_map.find_references(expr);
+        });
+
+        if async_map.is_empty() {
+            return Ok(plan);
+        }
+
+        // rewrite the projection's expressions in terms of the columns with the result of async evaluation
+        let new_exprs = proj_exec
+            .expr()
+            .iter()
+            .map(|(expr, column_name)| {
+                let new_expr = Arc::clone(expr)
+                    .transform_up(|e| Ok(async_map.map_expr(e)))
+                    .expect("no failures as closure is infallible");
+                (new_expr.data, column_name.to_string())
+            })
+            .collect::<Vec<_>>();
+
+        let async_exec = AsyncFuncExec::new(async_map.async_exprs, Arc::clone(proj_exec.input()));
+
+        let new_proj_exec = ProjectionExec::try_new(new_exprs, Arc::new(async_exec))?;
+
+        Ok(Arc::new(new_proj_exec) as _)
+    }
+
+    fn name(&self) -> &str {
+        "async_func_rule"
+    }
+
+    /// verify the schema has not changed
+    fn schema_check(&self) -> bool {
+        true
+    }
+}
+
+/// Maps async_expressions to new columns
+///
+/// The output of the async functions are appended, in order, to the end of the input schema
+#[derive(Debug)]
+struct AsyncMapper {
+    /// the number of columns in the input plan,
+    /// used to generate the output column names.
+    /// the first async expr is `__async_fn_0`, the second is `__async_fn_1`, etc
+    num_input_columns: usize,
+    /// the expressions to map
+    async_exprs: Vec<AsyncFuncExpr>,
+}
+
+impl AsyncMapper {
+    pub fn new(num_input_columns: usize) -> Self {
+        Self {
+            num_input_columns,
+            async_exprs: Vec::new(),
+        }
+    }
+    pub fn is_empty(&self) -> bool {
+        self.async_exprs.is_empty()
+    }
+
+    /// Finds any references to async functions in the expression and adds them to the map
+    pub fn find_references(&mut self, proj_expr: &Arc<dyn PhysicalExpr>) {
+        // recursively look for references to async functions
+        proj_expr
+            .apply(|expr| {
+                if let Some(func) = expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
+                    if AsyncFuncExpr::is_async_func(func.fun()) {
+                        let next_name = format!("__async_fn_{}", self.async_exprs.len());
+                        self.async_exprs
+                            .push(AsyncFuncExpr::new(next_name, Arc::clone(expr)));
+                    }
+                }
+                Ok(TreeNodeRecursion::Continue)
+            })
+            .expect("no failures as closure is infallible");
+    }
+
+    /// If the expression matches any of the async functions, return the new column
+    pub fn map_expr(&self, expr: Arc<dyn PhysicalExpr>) -> Transformed<Arc<dyn PhysicalExpr>> {
+        // find the first matching async function if any
+        let Some(idx) = self
+            .async_exprs
+            .iter()
+            .enumerate()
+            .find_map(
+                |(idx, async_expr)| {
+                    if async_expr == &expr {
+                        Some(idx)
+                    } else {
+                        None
+                    }
+                },
+            )
+        else {
+            return Transformed::no(expr);
+        };
+        // rewrite in terms of the output column
+        Transformed::yes(self.output_column(idx))
+    }
+
+    /// return the output column for the async function at index idx
+    pub fn output_column(&self, idx: usize) -> Arc<dyn PhysicalExpr> {
+        let async_expr = &self.async_exprs[idx];
+        let output_idx = self.num_input_columns + idx;
+        Arc::new(Column::new(async_expr.name(), output_idx))
+    }
+}
diff --git a/src/llm/physical_planner.rs b/src/llm/physical_planner.rs
deleted file mode 100644
index 0098a54..0000000
--- a/src/llm/physical_planner.rs
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */ -use crate::llm::logical::LLMPlan; -use crate::llm::plan::LLMExec; -use async_trait::async_trait; -use datafusion::common::Result; -use datafusion::execution::context::QueryPlanner; -use datafusion::execution::SessionState; -use datafusion::logical_expr::{LogicalPlan, UserDefinedLogicalNode}; -use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; -use datafusion::physical_plan::ExecutionPlan; -use datafusion::physical_planner::{DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner}; -use std::fmt::Debug; -use std::sync::Arc; - -#[derive(Debug)] -pub struct LLMQueryPlanner {} - -#[async_trait] -impl QueryPlanner for LLMQueryPlanner { - async fn create_physical_plan( - &self, - logical_plan: &LogicalPlan, - session_state: &SessionState, - ) -> Result> { - // Teach the default physical planner how to plan TopK nodes. - let physical_planner = - DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new(LLMPlanner {})]); - // Delegate most work of physical planning to the default physical planner - physical_planner - .create_physical_plan(logical_plan, session_state) - .await - } -} - -struct LLMPlanner {} - -#[async_trait] -impl ExtensionPlanner for LLMPlanner { - async fn plan_extension( - &self, - _planner: &dyn PhysicalPlanner, - node: &dyn UserDefinedLogicalNode, - _logical_inputs: &[&LogicalPlan], - physical_inputs: &[Arc], - _session_state: &SessionState, - ) -> Result>> { - if let Some(LLMPlan { - schema, - input, - required_columns, - }) = node.as_any().downcast_ref::() - { - let exec = LLMExec::new(physical_inputs[0].clone(), ExecutionPlanMetricsSet::new()); - Ok(Some(Arc::new(exec))) - } else { - Ok(None) - } - } -} diff --git a/src/llm/plan.rs b/src/llm/plan.rs deleted file mode 100644 index 8679024..0000000 --- a/src/llm/plan.rs +++ /dev/null @@ -1,212 +0,0 @@ -use datafusion::arrow::array::{record_batch, BooleanArray, RecordBatchIterator}; -use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use datafusion::arrow::record_batch::RecordBatch; -use datafusion::common::{internal_err, Column, Result}; -use datafusion::execution::{SendableRecordBatchStream, TaskContext}; -use datafusion::logical_expr::ColumnarValue; -use datafusion::physical_expr::EquivalenceProperties; -use datafusion::physical_optimizer::pruning::RequiredColumns; -use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet}; -use datafusion::physical_plan::{ - DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning, - PlanProperties, RecordBatchStream, -}; -use futures::stream::{Stream, StreamExt}; -use log::trace; -use serde::Serialize; -use std::any::Any; -use std::env; -use std::fmt::{Display, Formatter}; -use std::pin::Pin; -use std::sync::{Arc, Mutex}; -use std::task::{Context, Poll}; - -#[derive(Debug)] -pub struct LLMExec { - input: Arc, - /// Cache holding plan properties like equivalences, output partitioning etc. - cache: PlanProperties, - metrics: ExecutionPlanMetricsSet, -} - -impl LLMExec { - pub fn new(input: Arc, metrics: ExecutionPlanMetricsSet) -> Self { - let cache = LLMExec::compute_properties(&input, input.schema().clone()).unwrap(); - Self { - input, - cache, - metrics, - } - } - - /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. - fn compute_properties( - input: &Arc, - schema: SchemaRef, - ) -> Result { - let eq_properties = EquivalenceProperties::new(schema); - - // TODO: This is a dummy partitioning. 
We need to figure out the actual partitioning. - let output_partitioning = Partitioning::RoundRobinBatch(1); - - Ok(PlanProperties::new( - eq_properties, - output_partitioning, - input.pipeline_behavior(), - input.boundedness(), - )) - } -} - -impl DisplayAs for LLMExec { - fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - match t { - DisplayFormatType::Default | DisplayFormatType::Verbose => { - write!(f, "LLM()") - } - } - } -} - -impl ExecutionPlan for LLMExec { - fn name(&self) -> &str { - "llm" - } - - fn as_any(&self) -> &dyn Any { - self - } - - fn properties(&self) -> &PlanProperties { - &self.cache - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.input] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - Ok(Arc::new(LLMExec::new( - children[0].clone(), - self.metrics.clone(), - ))) - } - - fn execute( - &self, - partition: usize, - context: Arc, - ) -> Result { - trace!("Start ProjectionExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id()); - Ok(Box::pin(LLMStream { - required_columns: vec![], - promotion: "".to_string(), - input: self.input.execute(partition, context)?, - baseline_metrics: BaselineMetrics::new(&self.metrics, partition), - })) - } -} - -pub struct LLMStream { - required_columns: Vec, - promotion: String, - input: SendableRecordBatchStream, - baseline_metrics: BaselineMetrics, -} - -impl Stream for LLMStream { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let poll = self.input.poll_next_unpin(cx).map(|x| match x { - Some(Ok(batch)) => { - let result = tokio::task::block_in_place(|| { - tokio::runtime::Handle::current().block_on(self.mock_llm_bool(&batch)) - }); - Some(result) - } - other => other, - }); - self.baseline_metrics.record_poll(poll) - } -} - -impl RecordBatchStream for LLMStream { - fn schema(&self) -> SchemaRef { - self.input.schema() - } -} - -impl LLMStream { - async fn mock_llm_bool(&self, batch: &RecordBatch) -> Result { - let shared_result = Arc::new(Mutex::new(None)); - let result_clone = Arc::clone(&shared_result); - let num = batch.num_rows(); - - // try to simulate the LLM async call - let handle = tokio::spawn(async move { - println!("Called LLM mock function with {} rows", num); - let record = vec![true; num]; - let bool_array = BooleanArray::from(record); - let schema = Schema::new(vec![Field::new("llm_bool", DataType::Boolean, false)]); - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(bool_array)]).unwrap(); - let mut result = result_clone.lock().unwrap(); - *result = Some(batch); - }); - - handle.await.unwrap(); - let result = Arc::clone(&shared_result).lock().unwrap().take().unwrap(); - Ok(result) - } -} - -// TODO: -// pretty batch and concat with promotion -// as the input of the LLM function -// async fn ask_openai() -> Result> { -// let messages = vec![Message { -// role: "user".to_string(), -// content: "Say this is a test!".to_string(), -// }]; -// let body = MySendBody { -// model: "gpt-4o-mini".to_string(), -// messages, -// temperature: 0.7, -// }; -// -// let key = match env::var("OPENAI_API_KEY") { -// Ok(key) => key, -// Err(e) => panic!("OPENAI_API_KEY is not set: {}", e), -// }; -// let key = format!("Bearer {}", key); -// let client = reqwest::Client::new(); -// let response = client.post("https://api.openai.com/v1/chat/completions") -// .header("Authorization", key) -// .json(&body) -// .send().await -// 
.map_err(|e| { -// internal_err!("Failed to send request to OpenAI: {}", e) -// })?; -// if response.status().is_success() { -// response.text() -// } -// else { -// internal_err!("Failed to send request to OpenAI: {}", response.status()) -// } -// } - -#[derive(Serialize)] -struct MySendBody { - model: String, - messages: Vec, - temperature: f32, -} - -#[derive(Serialize)] -struct Message { - role: String, - content: String, -} diff --git a/src/main.rs b/src/main.rs index 4d2f710..76386f1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,11 +1,14 @@ -use crate::llm::functions::ByPassScalarUDF; -use crate::llm::logical::LLMFunctionAnalyzeRule; -use crate::llm::physical_planner::LLMQueryPlanner; +use crate::llm::functions::{AsyncScalarUDF, AsyncScalarUDFImpl}; +use crate::llm::physical_optimizer::AsyncFuncRule; +use async_trait::async_trait; +use datafusion::arrow::array::{ArrayRef, AsArray, BooleanArray, RecordBatch}; +use datafusion::arrow::datatypes::{DataType, Int32Type}; use datafusion::common::Result; use datafusion::execution::{FunctionRegistry, SessionStateBuilder}; use datafusion::functions_aggregate::min_max::max_udaf; -use datafusion::logical_expr::ScalarUDF; +use datafusion::logical_expr::{Signature, TypeSignature, Volatility}; use datafusion::prelude::SessionContext; +use std::any::Any; use std::sync::Arc; mod llm; @@ -13,34 +16,141 @@ mod llm; #[tokio::main] async fn main() -> Result<()> { env_logger::init(); + let mut state = SessionStateBuilder::default() - .with_analyzer_rules(vec![Arc::new(LLMFunctionAnalyzeRule {})]) - .with_optimizer_rules(vec![]) - .with_query_planner(Arc::new(LLMQueryPlanner {})) - .with_physical_optimizer_rules(vec![]) + .with_physical_optimizer_rule(Arc::new(AsyncFuncRule {})) .build(); - let udf = ByPassScalarUDF::new("llm_bool", datafusion::arrow::datatypes::DataType::Boolean); - state.register_udf(Arc::new(ScalarUDF::new_from_impl(udf)))?; + + let llm_bool = LLMBool::new(); + let udf = AsyncScalarUDF::new(Arc::new(llm_bool)); + state.register_udf(udf.into_scalar_udf())?; state.register_udaf(max_udaf())?; let ctx = SessionContext::new_with_state(state); ctx.sql("create table t1 (c1 int, c2 int, c3 int)") .await? .show() .await?; - ctx.sql("insert into t1 values (1, 2, 3), (1, 2, 3), (1, 2, 3)") + ctx.sql("insert into t1 values (1, 2, 3), (11, 2, 3), (1, 2, 3)") .await? .show() .await?; - ctx.sql("insert into t1 values (1, 2, 3), (1, 2, 3), (1, 2, 3)") + ctx.sql("insert into t1 values (1, 2, 3), (1, 2, 3), (21, 2, 3)") .await? .show() .await?; - ctx.sql("insert into t1 values (1, 2, 3), (1, 2, 3), (1, 2, 3)") + ctx.sql("insert into t1 values (31, 2, 3), (1, 2, 3), (1, 2, 3)") .await? 
         .show()
         .await?;
     ctx.sql("explain select llm_bool('If all of them are Aisa countries: {}, {}, {}', t1.c1, t1.c2, t1.c3) from t1")
         .await?.show().await?;
+
+    ctx.sql("select llm_bool('If all of them are Aisa countries: {}, {}, {}', t1.c1, t1.c2, t1.c3) from t1")
+        .await?.show().await?;
+
     Ok(())
 }
+
+/// This is a simple example of a UDF that takes a string, invokes a (remote) LLM function
+/// and returns a boolean
+#[derive(Debug)]
+struct LLMBool {
+    signature: Signature,
+}
+
+impl LLMBool {
+    fn new() -> Self {
+        Self {
+            signature: Signature::one_of(
+                vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
+                Volatility::Volatile,
+            ),
+        }
+    }
+}
+
+#[async_trait]
+impl AsyncScalarUDFImpl for LLMBool {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "llm_bool"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Boolean)
+    }
+
+    async fn invoke_async(&self, args: &RecordBatch) -> Result<ArrayRef> {
+        // TODO make the actual async call to OpenAI
+        //
+        // Note this is an async function
+
+        // this example calls the function with three integer args. We return true if
+        // the first argument is greater than 10
+        let first_arg = args.columns()[0].as_primitive::<Int32Type>();
+
+        let output: BooleanArray = first_arg
+            .iter()
+            .map(|arg| arg.map(|arg| arg > 10))
+            .collect();
+
+        Ok(Arc::new(output))
+    }
+}
+
+// TODO:
+// pretty print the batch and concatenate it with the prompt
+// as the input of the LLM function
+// async fn ask_openai() -> Result<String> {
+//     let messages = vec![Message {
+//         role: "user".to_string(),
+//         content: "Say this is a test!".to_string(),
+//     }];
+//     let body = MySendBody {
+//         model: "gpt-4o-mini".to_string(),
+//         messages,
+//         temperature: 0.7,
+//     };
+//
+//     let key = match env::var("OPENAI_API_KEY") {
+//         Ok(key) => key,
+//         Err(e) => panic!("OPENAI_API_KEY is not set: {}", e),
+//     };
+//     let key = format!("Bearer {}", key);
+//     let client = reqwest::Client::new();
+//     let response = client.post("https://api.openai.com/v1/chat/completions")
+//         .header("Authorization", key)
+//         .json(&body)
+//         .send().await
+//         .map_err(|e| {
+//             internal_err!("Failed to send request to OpenAI: {}", e)
+//         })?;
+//     if response.status().is_success() {
+//         response.text()
+//     }
+//     else {
+//         internal_err!("Failed to send request to OpenAI: {}", response.status())
+//     }
+// }
+/*
+#[derive(Serialize)]
+struct MySendBody {
+    model: String,
+    messages: Vec<Message>,
+    temperature: f32,
+}
+
+#[derive(Serialize)]
+struct Message {
+    role: String,
+    content: String,
+}
+*/
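
For reference, a minimal sketch of how the commented-out `ask_openai` helper might be completed, using the `reqwest` and `serde` dependencies already declared in Cargo.toml. The endpoint, model name, `OPENAI_API_KEY` environment variable, temperature, and the `MySendBody`/`Message` types are taken from the commented-out code above; the `prompt` parameter and the error mapping through `DataFusionError::Internal` are assumptions, not part of this diff.

```rust
use datafusion::common::{DataFusionError, Result};
use serde::Serialize;
use std::env;

#[derive(Serialize)]
struct MySendBody {
    model: String,
    messages: Vec<Message>,
    temperature: f32,
}

#[derive(Serialize)]
struct Message {
    role: String,
    content: String,
}

/// Sends `prompt` to the OpenAI chat completions endpoint and returns the raw
/// response body. Hypothetical helper: the signature and error handling are
/// assumptions layered on top of the commented-out sketch above.
async fn ask_openai(prompt: String) -> Result<String> {
    let body = MySendBody {
        model: "gpt-4o-mini".to_string(),
        messages: vec![Message {
            role: "user".to_string(),
            content: prompt,
        }],
        temperature: 0.7,
    };

    // Read the API key from the environment, as in the commented-out code.
    let key = env::var("OPENAI_API_KEY")
        .map_err(|e| DataFusionError::Internal(format!("OPENAI_API_KEY is not set: {e}")))?;

    let client = reqwest::Client::new();
    let response = client
        .post("https://api.openai.com/v1/chat/completions")
        .header("Authorization", format!("Bearer {key}"))
        .json(&body)
        .send()
        .await
        .map_err(|e| DataFusionError::Internal(format!("Failed to send request to OpenAI: {e}")))?;

    if response.status().is_success() {
        response
            .text()
            .await
            .map_err(|e| DataFusionError::Internal(format!("Failed to read OpenAI response: {e}")))
    } else {
        Err(DataFusionError::Internal(format!(
            "OpenAI request failed with status {}",
            response.status()
        )))
    }
}
```

A helper along these lines could then be called from `LLMBool::invoke_async` in place of the `arg > 10` placeholder logic, with the record batch pretty-printed into the prompt as the TODO suggests.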