Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use arrow::array::{
};
use arrow::datatypes::{ArrowNativeTypeOp, ArrowPrimitiveType, Float64Type, UInt32Type};
use arrow::record_batch::RecordBatch;
use arrow_schema::FieldRef;
use datafusion::common::{cast::as_float64_array, ScalarValue};
use datafusion::error::Result;
use datafusion::logical_expr::{
Expand Down Expand Up @@ -92,10 +93,10 @@ impl AggregateUDFImpl for GeoMeanUdaf {
}

/// This is the description of the state. accumulator's state() must match the types here.
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
Ok(vec![
Field::new("prod", args.return_type().clone(), true),
Field::new("n", DataType::UInt32, true),
Field::new("prod", args.return_type().clone(), true).into(),
Field::new("n", DataType::UInt32, true).into(),
])
}

Expand Down Expand Up @@ -401,7 +402,7 @@ impl AggregateUDFImpl for SimplifiedGeoMeanUdaf {
unimplemented!("should not be invoked")
}

fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
unimplemented!("should not be invoked")
}

Expand Down
9 changes: 5 additions & 4 deletions datafusion-examples/examples/advanced_udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use arrow::{
array::{ArrayRef, AsArray, Float64Array},
datatypes::Float64Type,
};
use arrow_schema::FieldRef;
use datafusion::common::ScalarValue;
use datafusion::error::Result;
use datafusion::functions_aggregate::average::avg_udaf;
Expand Down Expand Up @@ -87,8 +88,8 @@ impl WindowUDFImpl for SmoothItUdf {
Ok(Box::new(MyPartitionEvaluator::new()))
}

fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
Ok(Field::new(field_args.name(), DataType::Float64, true))
fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
Ok(Field::new(field_args.name(), DataType::Float64, true).into())
}
}

Expand Down Expand Up @@ -205,8 +206,8 @@ impl WindowUDFImpl for SimplifySmoothItUdf {
Some(Box::new(simplify))
}

fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
Ok(Field::new(field_args.name(), DataType::Float64, true))
fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
Ok(Field::new(field_args.name(), DataType::Float64, true).into())
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ fn test_update_matching_exprs() -> Result<()> {
Arc::new(Column::new("b", 1)),
)),
],
Field::new("f", DataType::Int32, true),
Field::new("f", DataType::Int32, true).into(),
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d", 2))),
Expand Down Expand Up @@ -193,7 +193,7 @@ fn test_update_matching_exprs() -> Result<()> {
Arc::new(Column::new("b", 1)),
)),
],
Field::new("f", DataType::Int32, true),
Field::new("f", DataType::Int32, true).into(),
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d", 3))),
Expand Down Expand Up @@ -261,7 +261,7 @@ fn test_update_projected_exprs() -> Result<()> {
Arc::new(Column::new("b", 1)),
)),
],
Field::new("f", DataType::Int32, true),
Field::new("f", DataType::Int32, true).into(),
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d", 2))),
Expand Down Expand Up @@ -326,7 +326,7 @@ fn test_update_projected_exprs() -> Result<()> {
Arc::new(Column::new("b_new", 1)),
)),
],
Field::new("f", DataType::Int32, true),
Field::new("f", DataType::Int32, true).into(),
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d_new", 3))),
Expand Down
10 changes: 6 additions & 4 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use arrow::array::{
StringArray, StructArray, UInt64Array,
};
use arrow::datatypes::{Fields, Schema};
use arrow_schema::FieldRef;
use datafusion::common::test_util::batches_to_string;
use datafusion::dataframe::DataFrame;
use datafusion::datasource::MemTable;
Expand Down Expand Up @@ -572,7 +573,7 @@ impl TimeSum {
// Returns the same type as its input
let return_type = timestamp_type.clone();

let state_fields = vec![Field::new("sum", timestamp_type, true)];
let state_fields = vec![Field::new("sum", timestamp_type, true).into()];

let volatility = Volatility::Immutable;

Expand Down Expand Up @@ -672,7 +673,7 @@ impl FirstSelector {
let state_fields = state_type
.into_iter()
.enumerate()
.map(|(i, t)| Field::new(format!("{i}"), t, true))
.map(|(i, t)| Field::new(format!("{i}"), t, true).into())
.collect::<Vec<_>>();

// Possible input signatures
Expand Down Expand Up @@ -932,9 +933,10 @@ impl AggregateUDFImpl for MetadataBasedAggregateUdf {
unimplemented!("this should never be called since return_field is implemented");
}

fn return_field(&self, _arg_fields: &[Field]) -> Result<Field> {
fn return_field(&self, _arg_fields: &[FieldRef]) -> Result<FieldRef> {
Ok(Field::new(self.name(), DataType::UInt64, true)
.with_metadata(self.metadata.clone()))
.with_metadata(self.metadata.clone())
.into())
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use arrow::array::{
use arrow::compute::kernels::numeric::add;
use arrow::datatypes::{DataType, Field, Schema};
use arrow_schema::extension::{Bool8, CanonicalExtensionType, ExtensionType};
use arrow_schema::ArrowError;
use arrow_schema::{ArrowError, FieldRef};
use datafusion::common::test_util::batches_to_string;
use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState};
use datafusion::prelude::*;
Expand Down Expand Up @@ -814,7 +814,7 @@ impl ScalarUDFImpl for TakeUDF {
///
/// 1. If the third argument is '0', return the type of the first argument
/// 2. If the third argument is '1', return the type of the second argument
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<Field> {
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
if args.arg_fields.len() != 3 {
return plan_err!("Expected 3 arguments, got {}.", args.arg_fields.len());
}
Expand Down Expand Up @@ -845,7 +845,8 @@ impl ScalarUDFImpl for TakeUDF {
self.name(),
args.arg_fields[take_idx].data_type().to_owned(),
true,
))
)
.into())
}

// The actual implementation
Expand Down Expand Up @@ -1412,9 +1413,10 @@ impl ScalarUDFImpl for MetadataBasedUdf {
);
}

fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result<Field> {
fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result<FieldRef> {
Ok(Field::new(self.name(), DataType::UInt64, true)
.with_metadata(self.metadata.clone()))
.with_metadata(self.metadata.clone())
.into())
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
Expand Down Expand Up @@ -1562,14 +1564,15 @@ impl ScalarUDFImpl for ExtensionBasedUdf {
Ok(DataType::Utf8)
}

fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result<Field> {
fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result<FieldRef> {
Ok(Field::new("canonical_extension_udf", DataType::Utf8, true)
.with_extension_type(MyUserExtentionType {}))
.with_extension_type(MyUserExtentionType {})
.into())
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
assert_eq!(args.arg_fields.len(), 1);
let input_field = args.arg_fields[0];
let input_field = args.arg_fields[0].as_ref();

let output_as_bool = matches!(
CanonicalExtensionType::try_from(input_field),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use arrow::array::{
UInt64Array,
};
use arrow::datatypes::{DataType, Field, Schema};
use arrow_schema::FieldRef;
use datafusion::common::test_util::batches_to_string;
use datafusion::common::{Result, ScalarValue};
use datafusion::prelude::SessionContext;
Expand Down Expand Up @@ -564,8 +565,8 @@ impl OddCounter {
&self.aliases
}

fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
Ok(Field::new(field_args.name(), DataType::Int64, true))
fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
Ok(Field::new(field_args.name(), DataType::Int64, true).into())
}
}

Expand Down Expand Up @@ -683,7 +684,7 @@ impl WindowUDFImpl for VariadicWindowUDF {
unimplemented!("unnecessary for testing");
}

fn field(&self, _: WindowUDFFieldArgs) -> Result<Field> {
fn field(&self, _: WindowUDFFieldArgs) -> Result<FieldRef> {
unimplemented!("unnecessary for testing");
}
}
Expand Down Expand Up @@ -809,9 +810,10 @@ impl WindowUDFImpl for MetadataBasedWindowUdf {
Ok(Box::new(MetadataBasedPartitionEvaluator { double_output }))
}

fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
Ok(Field::new(field_args.name(), DataType::UInt64, true)
.with_metadata(self.metadata.clone()))
.with_metadata(self.metadata.clone())
.into())
}
}

Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr-common/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

use crate::signature::TypeSignature;
use arrow::datatypes::{
DataType, Field, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
DataType, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
};

Expand Down Expand Up @@ -89,7 +89,7 @@ pub static TIMES: &[DataType] = &[
/// number of input types.
pub fn check_arg_count(
func_name: &str,
input_fields: &[Field],
input_fields: &[FieldRef],
signature: &TypeSignature,
) -> Result<()> {
match signature {
Expand Down
6 changes: 3 additions & 3 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::logical_plan::Subquery;
use crate::Volatility;
use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF};

use arrow::datatypes::{DataType, Field, FieldRef};
use arrow::datatypes::{DataType, FieldRef};
use datafusion_common::cse::{HashNode, NormalizeEq, Normalizeable};
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeContainer, TreeNodeRecursion,
Expand Down Expand Up @@ -846,10 +846,10 @@ impl WindowFunctionDefinition {
/// Returns the datatype of the window function
pub fn return_field(
&self,
input_expr_fields: &[Field],
input_expr_fields: &[FieldRef],
_input_expr_nullable: &[bool],
display_name: &str,
) -> Result<Field> {
) -> Result<FieldRef> {
match self {
WindowFunctionDefinition::AggregateUDF(fun) => {
fun.return_field(input_expr_fields)
Expand Down
17 changes: 9 additions & 8 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use crate::{
use arrow::compute::kernels::cast_utils::{
parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month,
};
use arrow::datatypes::{DataType, Field};
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::{plan_err, Column, Result, ScalarValue, Spans, TableReference};
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
Expand Down Expand Up @@ -492,6 +492,7 @@ pub fn create_udaf(
.into_iter()
.enumerate()
.map(|(i, t)| Field::new(format!("{i}"), t, true))
.map(Arc::new)
.collect::<Vec<_>>();
AggregateUDF::from(SimpleAggregateUDF::new(
name,
Expand All @@ -510,7 +511,7 @@ pub struct SimpleAggregateUDF {
signature: Signature,
return_type: DataType,
accumulator: AccumulatorFactoryFunction,
state_fields: Vec<Field>,
state_fields: Vec<FieldRef>,
}

impl Debug for SimpleAggregateUDF {
Expand All @@ -533,7 +534,7 @@ impl SimpleAggregateUDF {
return_type: DataType,
volatility: Volatility,
accumulator: AccumulatorFactoryFunction,
state_fields: Vec<Field>,
state_fields: Vec<FieldRef>,
) -> Self {
let name = name.into();
let signature = Signature::exact(input_type, volatility);
Expand All @@ -553,7 +554,7 @@ impl SimpleAggregateUDF {
signature: Signature,
return_type: DataType,
accumulator: AccumulatorFactoryFunction,
state_fields: Vec<Field>,
state_fields: Vec<FieldRef>,
) -> Self {
let name = name.into();
Self {
Expand Down Expand Up @@ -590,7 +591,7 @@ impl AggregateUDFImpl for SimpleAggregateUDF {
(self.accumulator)(acc_args)
}

fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is nice that this is now avoiding a deep copy of a bunch of Fields 👍

Ok(self.state_fields.clone())
}
}
Expand Down Expand Up @@ -678,12 +679,12 @@ impl WindowUDFImpl for SimpleWindowUDF {
(self.partition_evaluator_factory)()
}

fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
Ok(Field::new(
fn field(&self, field_args: WindowUDFFieldArgs) -> Result<FieldRef> {
Ok(Arc::new(Field::new(
field_args.name(),
self.return_type.clone(),
true,
))
)))
}
}

Expand Down
Loading