-
Notifications
You must be signed in to change notification settings - Fork 2
Implement general purpose async functions #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
use crate::llm::functions::AsyncScalarUDF; | ||
use datafusion::arrow::array::{ArrayRef, RecordBatch}; | ||
use datafusion::arrow::datatypes::{Field, Schema}; | ||
use datafusion::common::{internal_err, not_impl_err, Result}; | ||
use datafusion::logical_expr::ScalarUDF; | ||
use datafusion::physical_expr::{PhysicalExpr, ScalarFunctionExpr}; | ||
use std::fmt::Display; | ||
use std::sync::Arc; | ||
|
||
/// Wrapper for an async function that can be used in a DataFusion query.
///
/// Pairs an output column name with the physical expression that computes it.
#[derive(Debug, Clone)]
pub struct AsyncFuncExpr {
    /// The name of the output column this function will generate
    pub name: String,
    /// The actual function to evaluate (always expected to be a
    /// `ScalarFunctionExpr` wrapping an `AsyncScalarUDF` — see `invoke_async`)
    pub func: Arc<dyn PhysicalExpr>,
}
|
||
impl Display for AsyncFuncExpr { | ||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { | ||
write!(f, "async_expr(name={}, expr={})", self.name, self.func) | ||
} | ||
} | ||
|
||
impl AsyncFuncExpr { | ||
/// create a new AsyncFuncExpr | ||
pub fn new(name: impl Into<String>, func: Arc<dyn PhysicalExpr>) -> Self { | ||
Self { | ||
name: name.into(), | ||
func, | ||
} | ||
} | ||
|
||
/// return if this is an async function | ||
pub fn is_async_func(func: &ScalarUDF) -> bool { | ||
func.inner() | ||
.as_any() | ||
.downcast_ref::<AsyncScalarUDF>() | ||
.is_some() | ||
} | ||
|
||
/// return the name of the output column | ||
pub fn name(&self) -> &str { | ||
&self.name | ||
} | ||
|
||
/// Return the output field generated by evaluating this function | ||
pub fn field(&self, input_schema: &Schema) -> Field { | ||
Field::new( | ||
&self.name, | ||
self.func.data_type(input_schema).unwrap(), | ||
self.func.nullable(input_schema).unwrap(), | ||
) | ||
} | ||
|
||
/// This (async) function is called for each record batch to evaluate the LLM expressions | ||
/// | ||
/// The output is the output of evaluating the llm expression and the input record batch | ||
pub async fn invoke_async(&self, batch: &RecordBatch) -> Result<ArrayRef> { | ||
let Some(llm_function) = self.func.as_any().downcast_ref::<ScalarFunctionExpr>() else { | ||
return internal_err!( | ||
"unexpected function type, expected ScalarFunctionExpr, got: {:?}", | ||
self.func | ||
); | ||
}; | ||
|
||
let Some(async_udf) = llm_function | ||
.fun() | ||
.inner() | ||
.as_any() | ||
.downcast_ref::<AsyncScalarUDF>() | ||
else { | ||
return not_impl_err!( | ||
"Don't know how to evaluate async function: {:?}", | ||
llm_function | ||
); | ||
}; | ||
|
||
async_udf.invoke_async(batch).await | ||
} | ||
} | ||
|
||
impl PartialEq<Arc<dyn PhysicalExpr>> for AsyncFuncExpr { | ||
fn eq(&self, other: &Arc<dyn PhysicalExpr>) -> bool { | ||
&self.func == other | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
use crate::llm::async_func::AsyncFuncExpr; | ||
use datafusion::arrow::datatypes::{Fields, Schema, SchemaRef}; | ||
use datafusion::arrow::record_batch::RecordBatch; | ||
use datafusion::common::Result; | ||
use datafusion::execution::{SendableRecordBatchStream, TaskContext}; | ||
use datafusion::physical_expr::EquivalenceProperties; | ||
use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet; | ||
use datafusion::physical_plan::stream::RecordBatchStreamAdapter; | ||
use datafusion::physical_plan::{ | ||
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning, | ||
PlanProperties, | ||
}; | ||
use futures::stream::StreamExt; | ||
use log::trace; | ||
use std::any::Any; | ||
use std::sync::Arc; | ||
|
||
/// This structure evaluates a set of async expressions on a record
/// batch producing a new record batch
///
/// This is similar to a ProjectionExec except that the functions can be async
///
/// The schema of the output of the AsyncFuncExec is:
/// Input columns followed by one column for each async expression
#[derive(Debug)]
pub struct AsyncFuncExec {
    /// The async expressions to evaluate; each appends one output column
    async_exprs: Vec<AsyncFuncExpr>,
    /// The child plan that supplies the input record batches
    input: Arc<dyn ExecutionPlan>,
    /// Cache holding plan properties like equivalences, output partitioning etc.
    cache: PlanProperties,
    /// Execution metrics (not yet populated — see the TODO in `execute`)
    metrics: ExecutionPlanMetricsSet,
}
|
||
impl AsyncFuncExec { | ||
pub fn new(async_exprs: Vec<AsyncFuncExpr>, input: Arc<dyn ExecutionPlan>) -> Self { | ||
// compute the output schema: input schema then async expressions | ||
let fields: Fields = input | ||
.schema() | ||
.fields() | ||
.iter() | ||
.cloned() | ||
.chain( | ||
async_exprs | ||
.iter() | ||
.map(|async_expr| Arc::new(async_expr.field(input.schema().as_ref()))), | ||
) | ||
.collect(); | ||
let schema = Arc::new(Schema::new(fields)); | ||
let cache = AsyncFuncExec::compute_properties(&input, schema).unwrap(); | ||
|
||
Self { | ||
input, | ||
async_exprs, | ||
cache, | ||
metrics: ExecutionPlanMetricsSet::new(), | ||
} | ||
} | ||
|
||
/// This function creates the cache object that stores the plan properties | ||
/// such as schema, equivalence properties, ordering, partitioning, etc. | ||
fn compute_properties( | ||
input: &Arc<dyn ExecutionPlan>, | ||
schema: SchemaRef, | ||
) -> Result<PlanProperties> { | ||
let eq_properties = EquivalenceProperties::new(schema); | ||
|
||
// TODO: This is a dummy partitioning. We need to figure out the actual partitioning. | ||
let output_partitioning = Partitioning::RoundRobinBatch(1); | ||
|
||
Ok(PlanProperties::new( | ||
eq_properties, | ||
output_partitioning, | ||
input.pipeline_behavior(), | ||
input.boundedness(), | ||
)) | ||
} | ||
} | ||
|
||
impl DisplayAs for AsyncFuncExec { | ||
fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { | ||
match t { | ||
DisplayFormatType::Default | DisplayFormatType::Verbose => { | ||
let expr: Vec<String> = self | ||
.async_exprs | ||
.iter() | ||
.map(|async_expr| async_expr.to_string()) | ||
.collect(); | ||
|
||
write!(f, "AsyncFuncExec: async_expr=[{}]", expr.join(", ")) | ||
} | ||
} | ||
} | ||
} | ||
|
||
impl ExecutionPlan for AsyncFuncExec { | ||
fn name(&self) -> &str { | ||
"async_func" | ||
} | ||
|
||
fn as_any(&self) -> &dyn Any { | ||
self | ||
} | ||
|
||
fn properties(&self) -> &PlanProperties { | ||
&self.cache | ||
} | ||
|
||
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { | ||
vec![&self.input] | ||
} | ||
|
||
fn with_new_children( | ||
self: Arc<Self>, | ||
_children: Vec<Arc<dyn ExecutionPlan>>, | ||
) -> Result<Arc<dyn ExecutionPlan>> { | ||
Ok(Arc::new(AsyncFuncExec::new( | ||
self.async_exprs.clone(), | ||
Arc::clone(&self.input), | ||
))) | ||
} | ||
|
||
fn execute( | ||
&self, | ||
partition: usize, | ||
context: Arc<TaskContext>, | ||
) -> Result<SendableRecordBatchStream> { | ||
trace!( | ||
"Start AsyncFuncExpr::execute for partition {} of context session_id {} and task_id {:?}", | ||
partition, | ||
context.session_id(), | ||
context.task_id() | ||
); | ||
// TODO figure out how to record metrics | ||
|
||
// first execute the input stream | ||
let input_stream = self.input.execute(partition, context.clone())?; | ||
|
||
// now, for each record batch, evaluate the async expressions and add the columns to the result | ||
let async_exprs_captured = Arc::new(self.async_exprs.clone()); | ||
let schema_captured = self.schema(); | ||
|
||
let stream_with_async_functions = input_stream.then(move |batch| { | ||
// need to clone *again* to capture the async_exprs and schema in the | ||
// stream and satisfy lifetime requirements. | ||
let async_exprs_captured = Arc::clone(&async_exprs_captured); | ||
let schema_captured = schema_captured.clone(); | ||
|
||
async move { | ||
let batch = batch?; | ||
// append the result of evaluating the async expressions to the output | ||
let mut output_arrays = batch.columns().to_vec(); | ||
for async_expr in async_exprs_captured.iter() { | ||
let output_array = async_expr.invoke_async(&batch).await?; | ||
output_arrays.push(output_array); | ||
} | ||
let batch = RecordBatch::try_new(schema_captured, output_arrays)?; | ||
Ok(batch) | ||
} | ||
}); | ||
|
||
// Adapt the stream with the output schema | ||
let adapter = RecordBatchStreamAdapter::new(self.schema(), stream_with_async_functions); | ||
Ok(Box::pin(adapter)) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,49 +1,82 @@ | ||
use async_trait::async_trait; | ||
use datafusion::arrow::array::{ArrayRef, RecordBatch}; | ||
use datafusion::arrow::datatypes::DataType; | ||
use datafusion::common::internal_err; | ||
use datafusion::common::Result; | ||
use datafusion::logical_expr::{ | ||
ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, Volatility, | ||
}; | ||
use datafusion::logical_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature}; | ||
use std::any::Any; | ||
use std::fmt::Debug; | ||
use std::sync::Arc; | ||
|
||
/// A scalar UDF that will be bypassed when planning logical plan. | ||
/// This is used to register the remote function to the context. The function should not be | ||
/// invoked by DataFusion. It's only used to generate the logical plan and unparsed them to SQL. | ||
/// A scalar UDF that can invoke using async methods | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Here is the new API. At a high level it is meant to mimic |
||
/// | ||
/// Note this is less efficient than the ScalarUDFImpl, but it can be used | ||
/// to register remote functions in the context. | ||
/// | ||
/// The name is chosen to mirror ScalarUDFImpl | ||
#[async_trait] | ||
pub trait AsyncScalarUDFImpl: Debug + Send + Sync { | ||
/// the function cast as any | ||
fn as_any(&self) -> &dyn Any; | ||
|
||
/// The name of the function | ||
fn name(&self) -> &str; | ||
|
||
/// The signature of the function | ||
fn signature(&self) -> &Signature; | ||
|
||
/// The return type of the function | ||
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType>; | ||
|
||
/// Invoke the function asynchronously with the async arguments | ||
async fn invoke_async(&self, args: &RecordBatch) -> Result<ArrayRef>; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I wonder whether this should return a Stream of ArrayRef, so that internally you can batch the calls to an external system with the right batch size? In the case of LLMs there might also be a problem with the context, I suppose... There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That is also an excellent question -- the current situation is that DataFusion handles the batching (aka I think we could potentially make the API something like: fn invoke_async_stream(&self, input: SendableRecordBatchStream) -> Result<SendableRecordBatchStream>; but I think that might be trickier to code / get right. In terms of LLM context, this particular PR only adds async scalar functions. I think we could likely do something similar with window and aggregate functions, which might more naturally map to context 🤔 |
||
} | ||
|
||
/// A scalar UDF that must be invoked using async methods | ||
/// | ||
/// Note this is not meant to be used directly, but is meant to be an implementation detail | ||
/// for AsyncUDFImpl. | ||
/// | ||
/// This is used to register remote functions in the context. The function | ||
/// should not be invoked by DataFusion. It's only used to generate the logical | ||
plan and unparse it to SQL. | ||
#[derive(Debug)] | ||
pub struct ByPassScalarUDF { | ||
name: String, | ||
return_type: DataType, | ||
signature: Signature, | ||
pub struct AsyncScalarUDF { | ||
inner: Arc<dyn AsyncScalarUDFImpl>, | ||
} | ||
|
||
impl ByPassScalarUDF { | ||
pub fn new(name: &str, return_type: DataType) -> Self { | ||
Self { | ||
name: name.to_string(), | ||
return_type, | ||
signature: Signature::one_of( | ||
vec![TypeSignature::VariadicAny, TypeSignature::Nullary], | ||
Volatility::Volatile, | ||
), | ||
} | ||
impl AsyncScalarUDF { | ||
pub fn new(inner: Arc<dyn AsyncScalarUDFImpl>) -> Self { | ||
Self { inner } | ||
} | ||
|
||
/// Turn this AsyncUDF into a ScalarUDF, suitable for | ||
/// registering in the context | ||
pub fn into_scalar_udf(self) -> Arc<ScalarUDF> { | ||
Arc::new(ScalarUDF::new_from_impl(self)) | ||
} | ||
|
||
/// Invoke the function asynchronously with the async arguments | ||
pub async fn invoke_async(&self, args: &RecordBatch) -> Result<ArrayRef> { | ||
self.inner.invoke_async(args).await | ||
} | ||
} | ||
|
||
impl ScalarUDFImpl for ByPassScalarUDF { | ||
impl ScalarUDFImpl for AsyncScalarUDF { | ||
fn as_any(&self) -> &dyn Any { | ||
self | ||
} | ||
|
||
fn name(&self) -> &str { | ||
&self.name | ||
self.inner.name() | ||
} | ||
|
||
fn signature(&self) -> &Signature { | ||
&self.signature | ||
self.inner.signature() | ||
} | ||
|
||
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { | ||
Ok(self.return_type.clone()) | ||
self.inner.return_type(_arg_types) | ||
} | ||
|
||
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
minor, would moving this invocation of the
?
operator save a task in case of an error?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, you are right -- that would be an improvement 👍