Implement general purpose async functions #1

Merged (2 commits) on Feb 3, 2025
2 changes: 1 addition & 1 deletion Cargo.toml
@@ -11,4 +11,4 @@ serde = { version = "1.0.217", features = ["derive"] }
futures = { version = "0.3.31" }
log = "0.4.22"
env_logger = "0.11.6"
-reqwest = { version = "0.12.12", features = ["json"] }
reqwest = { version = "0.12.12", features = ["json"] }
87 changes: 87 additions & 0 deletions src/llm/async_func.rs
@@ -0,0 +1,87 @@
use crate::llm::functions::AsyncScalarUDF;
use datafusion::arrow::array::{ArrayRef, RecordBatch};
use datafusion::arrow::datatypes::{Field, Schema};
use datafusion::common::{internal_err, not_impl_err, Result};
use datafusion::logical_expr::ScalarUDF;
use datafusion::physical_expr::{PhysicalExpr, ScalarFunctionExpr};
use std::fmt::Display;
use std::sync::Arc;

/// Wrapper for an async function that can be used in a DataFusion query
#[derive(Debug, Clone)]
pub struct AsyncFuncExpr {
/// The name of the output column this function will generate
pub name: String,
/// The actual function (always `ScalarFunctionExpr`)
pub func: Arc<dyn PhysicalExpr>,
}

impl Display for AsyncFuncExpr {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "async_expr(name={}, expr={})", self.name, self.func)
}
}

impl AsyncFuncExpr {
/// create a new AsyncFuncExpr
pub fn new(name: impl Into<String>, func: Arc<dyn PhysicalExpr>) -> Self {
Self {
name: name.into(),
func,
}
}

/// Return true if the given ScalarUDF is an async function
pub fn is_async_func(func: &ScalarUDF) -> bool {
func.inner()
.as_any()
.downcast_ref::<AsyncScalarUDF>()
.is_some()
}

/// return the name of the output column
pub fn name(&self) -> &str {
&self.name
}

/// Return the output field generated by evaluating this function
pub fn field(&self, input_schema: &Schema) -> Field {
Field::new(
&self.name,
self.func.data_type(input_schema).unwrap(),
self.func.nullable(input_schema).unwrap(),
)
}

/// This (async) function is called for each record batch to evaluate the LLM expressions
///
/// Returns the array produced by evaluating the LLM expression against the input record batch
pub async fn invoke_async(&self, batch: &RecordBatch) -> Result<ArrayRef> {
let Some(llm_function) = self.func.as_any().downcast_ref::<ScalarFunctionExpr>() else {
return internal_err!(
"unexpected function type, expected ScalarFunctionExpr, got: {:?}",
self.func
);
};

let Some(async_udf) = llm_function
.fun()
.inner()
.as_any()
.downcast_ref::<AsyncScalarUDF>()
else {
return not_impl_err!(
"Don't know how to evaluate async function: {:?}",
llm_function
);
};

async_udf.invoke_async(batch).await
}
}

impl PartialEq<Arc<dyn PhysicalExpr>> for AsyncFuncExpr {
fn eq(&self, other: &Arc<dyn PhysicalExpr>) -> bool {
&self.func == other
}
}
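
For context, a minimal hypothetical driver (not in the diff): `scalar_fn_expr` is assumed to be an `Arc<dyn PhysicalExpr>` holding a `ScalarFunctionExpr` whose UDF is an `AsyncScalarUDF`, and `batch` an input `RecordBatch`:

    // hypothetical usage sketch
    let expr = AsyncFuncExpr::new("llm_output", scalar_fn_expr);
    assert_eq!(expr.name(), "llm_output");

    // evaluate against one batch; returns internal_err if the inner
    // expression is not a ScalarFunctionExpr
    let array: ArrayRef = expr.invoke_async(&batch).await?;
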
166 changes: 166 additions & 0 deletions src/llm/exec.rs
@@ -0,0 +1,166 @@
use crate::llm::async_func::AsyncFuncExpr;
use datafusion::arrow::datatypes::{Fields, Schema, SchemaRef};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::Result;
use datafusion::execution::{SendableRecordBatchStream, TaskContext};
use datafusion::physical_expr::EquivalenceProperties;
use datafusion::physical_plan::metrics::ExecutionPlanMetricsSet;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning,
PlanProperties,
};
use futures::stream::StreamExt;
use log::trace;
use std::any::Any;
use std::sync::Arc;

/// This structure evaluates a set of async expressions on a record
/// batch, producing a new record batch
///
/// This is similar to a ProjectionExec, except that the functions can be async
///
/// The output schema of the AsyncFuncExec is the input columns followed
/// by one column for each async expression
#[derive(Debug)]
pub struct AsyncFuncExec {
/// The async expressions to evaluate
async_exprs: Vec<AsyncFuncExpr>,
input: Arc<dyn ExecutionPlan>,
/// Cache holding plan properties like equivalences, output partitioning etc.
cache: PlanProperties,
metrics: ExecutionPlanMetricsSet,
}

impl AsyncFuncExec {
pub fn new(async_exprs: Vec<AsyncFuncExpr>, input: Arc<dyn ExecutionPlan>) -> Self {
// compute the output schema: input schema then async expressions
let fields: Fields = input
.schema()
.fields()
.iter()
.cloned()
.chain(
async_exprs
.iter()
.map(|async_expr| Arc::new(async_expr.field(input.schema().as_ref()))),
)
.collect();
let schema = Arc::new(Schema::new(fields));
let cache = AsyncFuncExec::compute_properties(&input, schema).unwrap();

Self {
input,
async_exprs,
cache,
metrics: ExecutionPlanMetricsSet::new(),
}
}

/// This function creates the cache object that stores the plan properties
/// such as schema, equivalence properties, ordering, partitioning, etc.
fn compute_properties(
input: &Arc<dyn ExecutionPlan>,
schema: SchemaRef,
) -> Result<PlanProperties> {
let eq_properties = EquivalenceProperties::new(schema);

// TODO: This is a dummy partitioning. We need to figure out the actual partitioning.
let output_partitioning = Partitioning::RoundRobinBatch(1);

Ok(PlanProperties::new(
eq_properties,
output_partitioning,
input.pipeline_behavior(),
input.boundedness(),
))
}
}

impl DisplayAs for AsyncFuncExec {
fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match t {
DisplayFormatType::Default | DisplayFormatType::Verbose => {
let expr: Vec<String> = self
.async_exprs
.iter()
.map(|async_expr| async_expr.to_string())
.collect();

write!(f, "AsyncFuncExec: async_expr=[{}]", expr.join(", "))
}
}
}
}

impl ExecutionPlan for AsyncFuncExec {
fn name(&self) -> &str {
"async_func"
}

fn as_any(&self) -> &dyn Any {
self
}

fn properties(&self) -> &PlanProperties {
&self.cache
}

fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.input]
}

fn with_new_children(
self: Arc<Self>,
_children: Vec<Arc<dyn ExecutionPlan>>,
) -> Result<Arc<dyn ExecutionPlan>> {
Ok(Arc::new(AsyncFuncExec::new(
self.async_exprs.clone(),
Arc::clone(&self.input),
)))
}

fn execute(
&self,
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
trace!(
"Start AsyncFuncExpr::execute for partition {} of context session_id {} and task_id {:?}",
partition,
context.session_id(),
context.task_id()
);
// TODO figure out how to record metrics

// first execute the input stream
let input_stream = self.input.execute(partition, context.clone())?;

// now, for each record batch, evaluate the async expressions and add the columns to the result
let async_exprs_captured = Arc::new(self.async_exprs.clone());
let schema_captured = self.schema();

let stream_with_async_functions = input_stream.then(move |batch| {
// need to clone *again* to capture the async_exprs and schema in the
// stream and satisfy lifetime requirements.
let async_exprs_captured = Arc::clone(&async_exprs_captured);
let schema_captured = schema_captured.clone();

async move {
let batch = batch?;

Review comment: minor, would moving this invocation of the `?` operator save a task in case of an error?

Contributor Author: yes, you are right -- that would be an improvement 👍
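
One way the suggestion could look (a sketch only, not the committed code): unwrap the `Result` in the synchronous closure so an input error short-circuits without building the async future, e.g. with `futures::future::Either`:

    // sketch; assumes the surrounding variables from execute()
    use futures::future::{ready, Either};

    let stream_with_async_functions = input_stream.then(move |batch| {
        // propagate input errors immediately, before any async work is created
        let batch = match batch {
            Ok(batch) => batch,
            Err(e) => return Either::Left(ready(Err(e))),
        };
        let async_exprs_captured = Arc::clone(&async_exprs_captured);
        let schema_captured = schema_captured.clone();
        Either::Right(async move {
            let mut output_arrays = batch.columns().to_vec();
            for async_expr in async_exprs_captured.iter() {
                output_arrays.push(async_expr.invoke_async(&batch).await?);
            }
            Ok(RecordBatch::try_new(schema_captured, output_arrays)?)
        })
    });
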

// append the result of evaluating the async expressions to the output
let mut output_arrays = batch.columns().to_vec();
for async_expr in async_exprs_captured.iter() {
let output_array = async_expr.invoke_async(&batch).await?;
output_arrays.push(output_array);
}
let batch = RecordBatch::try_new(schema_captured, output_arrays)?;
Ok(batch)
}
});

// Adapt the stream with the output schema
let adapter = RecordBatchStreamAdapter::new(self.schema(), stream_with_async_functions);
Ok(Box::pin(adapter))
}
}
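
For orientation, assembling and running the node might look like this (a hedged sketch; `async_expr`, `input_plan`, and `task_ctx` are assumed to exist, not taken from this PR):

    // hypothetical: wrap an existing plan so the async expressions are
    // evaluated against every batch it produces
    let exec = Arc::new(AsyncFuncExec::new(vec![async_expr], input_plan));
    // output schema = input columns plus one column per async expression
    let stream = exec.execute(0, task_ctx)?;
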
81 changes: 57 additions & 24 deletions src/llm/functions.rs
@@ -1,49 +1,82 @@
use async_trait::async_trait;
use datafusion::arrow::array::{ArrayRef, RecordBatch};
use datafusion::arrow::datatypes::DataType;
use datafusion::common::internal_err;
use datafusion::common::Result;
-use datafusion::logical_expr::{
-ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, Volatility,
-};
use datafusion::logical_expr::{ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature};
use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;

-/// A scalar UDF that will be bypassed when planning logical plan.
-/// This is used to register the remote function to the context. The function should not be
-/// invoked by DataFusion. It's only used to generate the logical plan and unparsed them to SQL.
/// A scalar UDF that can be invoked using async methods
Contributor Author: Here is the new API. At a high level it is meant to mimic ScalarUDFImpl, except that it has an async invoke function.

///
/// Note this is less efficient than ScalarUDFImpl, but it can be used
/// to register remote functions in the context.
///
/// The name is chosen to mirror ScalarUDFImpl
#[async_trait]
pub trait AsyncScalarUDFImpl: Debug + Send + Sync {
/// the function cast as any
fn as_any(&self) -> &dyn Any;

/// The name of the function
fn name(&self) -> &str;

/// The signature of the function
fn signature(&self) -> &Signature;

/// The return type of the function
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType>;

/// Invoke the function asynchronously with the given arguments
async fn invoke_async(&self, args: &RecordBatch) -> Result<ArrayRef>;

Review comment: I wonder whether this should return a Stream of ArrayRef, so that internally you can batch the calls to an external system with the right batch size? In the case of LLMs there might also be a problem with the context, I suppose...

Contributor Author: That is also an excellent question -- the current situation is that DataFusion handles the batching (aka target_size) -- so normally it will pass 8k rows or whatever to the function.

I think we could potentially make the API something like:

    fn invoke_async_stream(&self, input: SendableRecordBatchStream) -> Result<SendableRecordBatchStream>;

but I think that might be trickier to code / get right.

In terms of LLM context, this particular PR only adds async scalar functions. I think we could likely do something similar with window and aggregate functions, which might more naturally map to context 🤔

}
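
To make the trait concrete, here is a hypothetical implementation sketch (`AsyncUpper` and `async_upper` are invented for illustration, not part of this PR; a real implementation would `.await` an external service, e.g. an LLM endpoint, inside `invoke_async`):

    // Hypothetical example: an async UDF that upper-cases strings,
    // standing in for a call to a remote service.
    use async_trait::async_trait;
    use datafusion::arrow::array::{ArrayRef, RecordBatch, StringArray};
    use datafusion::arrow::datatypes::DataType;
    use datafusion::common::{exec_err, Result};
    use datafusion::logical_expr::{Signature, Volatility};
    use std::any::Any;
    use std::sync::Arc;

    #[derive(Debug)]
    struct AsyncUpper {
        signature: Signature,
    }

    impl AsyncUpper {
        fn new() -> Self {
            Self {
                signature: Signature::exact(vec![DataType::Utf8], Volatility::Volatile),
            }
        }
    }

    #[async_trait]
    impl AsyncScalarUDFImpl for AsyncUpper {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn name(&self) -> &str {
            "async_upper"
        }

        fn signature(&self) -> &Signature {
            &self.signature
        }

        fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
            Ok(DataType::Utf8)
        }

        async fn invoke_async(&self, args: &RecordBatch) -> Result<ArrayRef> {
            // a real implementation would .await a network call here
            let Some(input) = args.column(0).as_any().downcast_ref::<StringArray>() else {
                return exec_err!("async_upper expects a single Utf8 argument");
            };
            let result: StringArray = input
                .iter()
                .map(|v| v.map(|s| s.to_uppercase()))
                .collect();
            Ok(Arc::new(result))
        }
    }
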

/// A scalar UDF that must be invoked using async methods
///
/// Note this is not meant to be used directly, but is meant to be an implementation detail
/// for AsyncScalarUDFImpl.
///
/// This is used to register remote functions in the context. The function
/// should not be invoked by DataFusion. It's only used to generate the logical
/// plan and unparse it to SQL.
#[derive(Debug)]
-pub struct ByPassScalarUDF {
-name: String,
-return_type: DataType,
-signature: Signature,
pub struct AsyncScalarUDF {
inner: Arc<dyn AsyncScalarUDFImpl>,
}

-impl ByPassScalarUDF {
-pub fn new(name: &str, return_type: DataType) -> Self {
-Self {
-name: name.to_string(),
-return_type,
-signature: Signature::one_of(
-vec![TypeSignature::VariadicAny, TypeSignature::Nullary],
-Volatility::Volatile,
-),
-}
impl AsyncScalarUDF {
pub fn new(inner: Arc<dyn AsyncScalarUDFImpl>) -> Self {
Self { inner }
}

/// Turn this AsyncUDF into a ScalarUDF, suitable for
/// registering in the context
pub fn into_scalar_udf(self) -> Arc<ScalarUDF> {
Arc::new(ScalarUDF::new_from_impl(self))
}

/// Invoke the function asynchronously with the given arguments
pub async fn invoke_async(&self, args: &RecordBatch) -> Result<ArrayRef> {
self.inner.invoke_async(args).await
}
}
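
Registration would then look roughly like the following (assumes a `SessionContext` named `ctx` and the hypothetical `AsyncUpper` sketched above):

    // wrap the async impl and register it like any other ScalarUDF
    let udf = AsyncScalarUDF::new(Arc::new(AsyncUpper::new()));
    ctx.register_udf(udf.into_scalar_udf().as_ref().clone());
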

-impl ScalarUDFImpl for ByPassScalarUDF {
impl ScalarUDFImpl for AsyncScalarUDF {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
-&self.name
self.inner.name()
}

fn signature(&self) -> &Signature {
-&self.signature
self.inner.signature()
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
-Ok(self.return_type.clone())
self.inner.return_type(_arg_types)
}

fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
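
The diff view truncates the body of `invoke` here. Given the `internal_err` import and the doc comment above ("should not be invoked by DataFusion"), the synchronous path presumably rejects direct calls; a guess at the shape, not the actual diff:

    // hypothetical body; the real lines are cut off in this view
    internal_err!("AsyncScalarUDF can only be invoked through invoke_async")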