diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index 0260f788b..469fc7698 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -32,5 +32,5 @@ crate-type = ["cdylib"] [dependencies] iceberg = { path = "../../crates/iceberg" } -pyo3 = { version = "0.21.1", features = ["extension-module"] } -arrow = { version = "52.2.0", features = ["pyarrow"] } +pyo3 = { version = "0.21", features = ["extension-module"] } +arrow = { version = "52", features = ["pyarrow"] } diff --git a/bindings/python/src/error.rs b/bindings/python/src/error.rs new file mode 100644 index 000000000..a2d1424cc --- /dev/null +++ b/bindings/python/src/error.rs @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use pyo3::exceptions::PyValueError; +use pyo3::PyErr; + +/// Convert an iceberg error to a python error +pub fn to_py_err(err: iceberg::Error) -> PyErr { + PyValueError::new_err(err.to_string()) +} diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 5c3f77ff7..a16bdac4d 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -15,23 +15,13 @@ // specific language governing permissions and limitations // under the License. -use iceberg::io::FileIOBuilder; use pyo3::prelude::*; -use pyo3::wrap_pyfunction; +mod error; mod transform; -#[pyfunction] -fn hello_world() -> PyResult { - let _ = FileIOBuilder::new_fs_io().build().unwrap(); - Ok("Hello, world!".to_string()) -} - - #[pymodule] -fn pyiceberg_core_rust(m: &Bound<'_, PyModule>) -> PyResult<()> { - m.add_function(wrap_pyfunction!(hello_world, m)?)?; - - m.add_class::()?; +fn pyiceberg_core_rust(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { + transform::register_module(py, m)?; Ok(()) } diff --git a/bindings/python/src/transform.rs b/bindings/python/src/transform.rs index 8f4585b2a..5b0d82f22 100644 --- a/bindings/python/src/transform.rs +++ b/bindings/python/src/transform.rs @@ -15,24 +15,55 @@ // specific language governing permissions and limitations // under the License. +use arrow::array::{make_array, Array, ArrayData}; +use arrow::pyarrow::{FromPyArrow, ToPyArrow}; use iceberg::spec::Transform; use iceberg::transform::create_transform_function; +use pyo3::prelude::*; -use arrow::{ - array::{make_array, Array, ArrayData}, -}; -use arrow::pyarrow::{FromPyArrow, ToPyArrow}; -use pyo3::{exceptions::PyValueError, prelude::*}; +use crate::error::to_py_err; + +#[pyfunction] +pub fn identity(py: Python, array: PyObject) -> PyResult { + apply(py, array, Transform::Identity) +} + +#[pyfunction] +pub fn void(py: Python, array: PyObject) -> PyResult { + apply(py, array, Transform::Void) +} + +#[pyfunction] +pub fn year(py: Python, array: PyObject) -> PyResult { + apply(py, array, Transform::Year) +} + +#[pyfunction] +pub fn month(py: Python, array: PyObject) -> PyResult { + apply(py, array, Transform::Month) +} -fn to_py_err(err: iceberg::Error) -> PyErr { - PyValueError::new_err(err.to_string()) +#[pyfunction] +pub fn day(py: Python, array: PyObject) -> PyResult { + apply(py, array, Transform::Day) } -#[pyclass] -pub struct ArrowArrayTransform { +#[pyfunction] +pub fn hour(py: Python, array: PyObject) -> PyResult { + apply(py, array, Transform::Hour) } -fn apply(array: PyObject, transform: Transform, py: Python) -> PyResult { +#[pyfunction] +pub fn bucket(py: Python, array: PyObject, num_buckets: u32) -> PyResult { + apply(py, array, Transform::Bucket(num_buckets)) +} + +#[pyfunction] +pub fn truncate(py: Python, array: PyObject, width: u32) -> PyResult { + apply(py, array, Transform::Truncate(width)) +} + +fn apply(py: Python, array: PyObject, transform: Transform) -> PyResult { // import let array = ArrayData::from_pyarrow_bound(array.bind(py))?; let array = make_array(array); @@ -43,45 +74,20 @@ fn apply(array: PyObject, transform: Transform, py: Python) -> PyResult PyResult { - apply(array, Transform::Identity, py) - } - - #[staticmethod] - pub fn void(array: PyObject, py: Python) -> PyResult { - apply(array, Transform::Void, py) - } - - #[staticmethod] - pub fn year(array: PyObject, py: Python) -> PyResult { - apply(array, Transform::Year, py) - } - - #[staticmethod] - pub fn month(array: PyObject, py: Python) -> PyResult { - apply(array, Transform::Month, py) - } - - #[staticmethod] - pub fn day(array: PyObject, py: Python) -> PyResult { - apply(array, Transform::Day, py) - } - - #[staticmethod] - pub fn hour(array: PyObject, py: Python) -> PyResult { - apply(array, Transform::Hour, py) - } +pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { + let this = PyModule::new_bound(py, "transform")?; - #[staticmethod] - pub fn bucket(array: PyObject, num_buckets: u32, py: Python) -> PyResult { - apply(array, Transform::Bucket(num_buckets), py) - } + this.add_function(wrap_pyfunction!(identity, &this)?)?; + this.add_function(wrap_pyfunction!(void, &this)?)?; + this.add_function(wrap_pyfunction!(year, &this)?)?; + this.add_function(wrap_pyfunction!(month, &this)?)?; + this.add_function(wrap_pyfunction!(day, &this)?)?; + this.add_function(wrap_pyfunction!(hour, &this)?)?; + this.add_function(wrap_pyfunction!(bucket, &this)?)?; + this.add_function(wrap_pyfunction!(truncate, &this)?)?; - #[staticmethod] - pub fn truncate(array: PyObject, width: u32, py: Python) -> PyResult { - apply(array, Transform::Truncate(width), py) - } + m.add_submodule(&this)?; + py.import_bound("sys")? + .getattr("modules")? + .set_item("pyiceberg_core.transform", this) } diff --git a/bindings/python/tests/test_basic.py b/bindings/python/tests/test_basic.py deleted file mode 100644 index 817793ba8..000000000 --- a/bindings/python/tests/test_basic.py +++ /dev/null @@ -1,22 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from pyiceberg_core import hello_world - - -def test_hello_world(): - hello_world() diff --git a/bindings/python/tests/test_transform.py b/bindings/python/tests/test_transform.py index 1fa2d577a..4180b6902 100644 --- a/bindings/python/tests/test_transform.py +++ b/bindings/python/tests/test_transform.py @@ -19,18 +19,18 @@ import pyarrow as pa import pytest -from pyiceberg_core import ArrowArrayTransform +from pyiceberg_core import transform def test_identity_transform(): arr = pa.array([1, 2]) - result = ArrowArrayTransform.identity(arr) + result = transform.identity(arr) assert result == arr def test_bucket_transform(): arr = pa.array([1, 2]) - result = ArrowArrayTransform.bucket(arr, 10) + result = transform.bucket(arr, 10) expected = pa.array([6, 2], type=pa.int32()) assert result == expected @@ -41,14 +41,14 @@ def test_bucket_transform_fails_for_list_type_input(): ValueError, match=r"FeatureUnsupported => Unsupported data type for bucket transform", ): - ArrowArrayTransform.bucket(arr, 10) + transform.bucket(arr, 10) def test_bucket_chunked_array(): chunked = pa.chunked_array([pa.array([1, 2]), pa.array([3, 4])]) result_chunks = [] for arr in chunked.iterchunks(): - result_chunks.append(ArrowArrayTransform.bucket(arr, 10)) + result_chunks.append(transform.bucket(arr, 10)) expected = pa.chunked_array( [pa.array([6, 2], type=pa.int32()), pa.array([5, 0], type=pa.int32())] @@ -58,34 +58,42 @@ def test_bucket_chunked_array(): def test_year_transform(): arr = pa.array([date(1970, 1, 1), date(2000, 1, 1)]) - result = ArrowArrayTransform.year(arr) + result = transform.year(arr) expected = pa.array([0, 30], type=pa.int32()) assert result == expected def test_month_transform(): arr = pa.array([date(1970, 1, 1), date(2000, 4, 1)]) - result = ArrowArrayTransform.month(arr) + result = transform.month(arr) expected = pa.array([0, 30 * 12 + 3], type=pa.int32()) assert result == expected def test_day_transform(): arr = pa.array([date(1970, 1, 1), date(2000, 4, 1)]) - result = ArrowArrayTransform.day(arr) + result = transform.day(arr) expected = pa.array([0, 11048], type=pa.int32()) assert result == expected def test_hour_transform(): arr = pa.array([datetime(1970, 1, 1, 19, 1, 23), datetime(2000, 3, 1, 12, 1, 23)]) - result = ArrowArrayTransform.hour(arr) + result = transform.hour(arr) expected = pa.array([19, 264420], type=pa.int32()) assert result == expected def test_truncate_transform(): arr = pa.array(["this is a long string", "hi my name is sung"]) - result = ArrowArrayTransform.truncate(arr, 5) + result = transform.truncate(arr, 5) expected = pa.array(["this ", "hi my"]) assert result == expected + + +def test_identity_transform_with_direct_import(): + from pyiceberg_core.transform import identity + + arr = pa.array([1, 2]) + result = identity(arr) + assert result == arr