Skip to content

Leverage Iceberg-Rust for all the transforms #1833

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 28 additions & 23 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

128 changes: 40 additions & 88 deletions pyiceberg/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,36 @@ def parse_transform(v: Any) -> Any:
return v


def _pyiceberg_transform_wrapper(
    transform_func: Callable[["ArrayLike", Any], "ArrayLike"],
    *args: Any,
    expected_type: Optional["pa.DataType"] = None,
) -> Callable[["ArrayLike"], "ArrayLike"]:
    """Wrap a pyiceberg-core transform so it accepts both pa.Array and pa.ChunkedArray.

    Args:
        transform_func: The per-array transform to apply (e.g. ``pyiceberg_core.transform.bucket``).
        *args: Extra positional arguments forwarded to ``transform_func`` after the
            array itself (e.g. the number of buckets, or the truncate width).
        expected_type: When given, each transform result is cast to this PyArrow
            type before being returned.

    Returns:
        A callable applying ``transform_func`` to a PyArrow array, chunk-by-chunk
        for chunked arrays.

    Raises:
        ModuleNotFoundError: If PyArrow is not installed.
    """
    try:
        import pyarrow as pa
    except ModuleNotFoundError as e:
        # All pyarrow-backed transforms (bucket, truncate, year, month, day, hour)
        # go through this wrapper, so keep the message generic rather than
        # mentioning only bucket/truncate.
        raise ModuleNotFoundError("For PyArrow-backed transforms, PyArrow needs to be installed") from e

    # Hoisted out of _transform so the helper is built once, not per invocation.
    def _cast_if_needed(arr: "ArrayLike") -> "ArrayLike":
        # pyiceberg-core may return a wider type; cast to the type Iceberg
        # expects (e.g. int32 for the date/time transforms) when requested.
        return arr.cast(expected_type) if expected_type is not None else arr

    def _transform(array: "ArrayLike") -> "ArrayLike":
        if isinstance(array, pa.Array):
            return _cast_if_needed(transform_func(array, *args))
        elif isinstance(array, pa.ChunkedArray):
            return pa.chunked_array([_cast_if_needed(transform_func(arr, *args)) for arr in array.iterchunks()])
        else:
            raise ValueError(f"PyArrow array can only be of type pa.Array or pa.ChunkedArray, but found {type(array)}")

    return _transform


class Transform(IcebergRootModel[str], ABC, Generic[S, T]):
"""Transform base class for concrete transforms.

Expand Down Expand Up @@ -198,27 +228,6 @@ def supports_pyarrow_transform(self) -> bool:
@abstractmethod
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ...

def _pyiceberg_transform_wrapper(
    self, transform_func: Callable[["ArrayLike", Any], "ArrayLike"], *args: Any
) -> Callable[["ArrayLike"], "ArrayLike"]:
    """Adapt a per-array transform into one that also handles chunked arrays.

    Args:
        transform_func: The transform to apply to each plain ``pa.Array``.
        *args: Extra positional arguments forwarded to ``transform_func``.

    Returns:
        A callable that dispatches on the input's PyArrow container type.

    Raises:
        ModuleNotFoundError: If PyArrow is not installed.
    """
    try:
        import pyarrow as pa
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError("For bucket/truncate transforms, PyArrow needs to be installed") from e

    def _transform(array: "ArrayLike") -> "ArrayLike":
        if isinstance(array, pa.ChunkedArray):
            # Apply the transform chunk-by-chunk, then reassemble the result.
            return pa.chunked_array([transform_func(chunk, *args) for chunk in array.iterchunks()])
        if isinstance(array, pa.Array):
            return transform_func(array, *args)
        raise ValueError(f"PyArrow array can only be of type pa.Array or pa.ChunkedArray, but found {type(array)}")

    return _transform


class BucketTransform(Transform[S, int]):
"""Base Transform class to transform a value into a bucket partition value.
Expand Down Expand Up @@ -375,7 +384,7 @@ def __repr__(self) -> str:
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
from pyiceberg_core import transform as pyiceberg_core_transform

return self._pyiceberg_transform_wrapper(pyiceberg_core_transform.bucket, self._num_buckets)
return _pyiceberg_transform_wrapper(pyiceberg_core_transform.bucket, self._num_buckets)

@property
def supports_pyarrow_transform(self) -> bool:
Expand Down Expand Up @@ -501,22 +510,9 @@ def __repr__(self) -> str:

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
import pyarrow as pa
import pyarrow.compute as pc

if isinstance(source, DateType):
epoch = pa.scalar(datetime.EPOCH_DATE)
elif isinstance(source, TimestampType):
epoch = pa.scalar(datetime.EPOCH_TIMESTAMP)
elif isinstance(source, TimestamptzType):
epoch = pa.scalar(datetime.EPOCH_TIMESTAMPTZ)
elif isinstance(source, TimestampNanoType):
epoch = pa.scalar(datetime.EPOCH_TIMESTAMP).cast(pa.timestamp("ns"))
elif isinstance(source, TimestamptzNanoType):
epoch = pa.scalar(datetime.EPOCH_TIMESTAMPTZ).cast(pa.timestamp("ns"))
else:
raise ValueError(f"Cannot apply year transform for type: {source}")
from pyiceberg_core import transform as pyiceberg_core_transform

return lambda v: pc.years_between(epoch, v) if v is not None else None
return _pyiceberg_transform_wrapper(pyiceberg_core_transform.year, expected_type=pa.int32())


class MonthTransform(TimeTransform[S]):
Expand Down Expand Up @@ -575,28 +571,9 @@ def __repr__(self) -> str:

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
import pyarrow as pa
import pyarrow.compute as pc

if isinstance(source, DateType):
epoch = pa.scalar(datetime.EPOCH_DATE)
elif isinstance(source, TimestampType):
epoch = pa.scalar(datetime.EPOCH_TIMESTAMP)
elif isinstance(source, TimestamptzType):
epoch = pa.scalar(datetime.EPOCH_TIMESTAMPTZ)
elif isinstance(source, TimestampNanoType):
epoch = pa.scalar(datetime.EPOCH_TIMESTAMP).cast(pa.timestamp("ns"))
elif isinstance(source, TimestamptzNanoType):
epoch = pa.scalar(datetime.EPOCH_TIMESTAMPTZ).cast(pa.timestamp("ns"))
else:
raise ValueError(f"Cannot apply month transform for type: {source}")

def month_func(v: pa.Array) -> pa.Array:
return pc.add(
pc.multiply(pc.years_between(epoch, v), pa.scalar(12)),
pc.add(pc.month(v), pa.scalar(-1)),
)
from pyiceberg_core import transform as pyiceberg_core_transform

return lambda v: month_func(v) if v is not None else None
return _pyiceberg_transform_wrapper(pyiceberg_core_transform.month, expected_type=pa.int32())


class DayTransform(TimeTransform[S]):
Expand Down Expand Up @@ -663,22 +640,9 @@ def __repr__(self) -> str:

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
import pyarrow as pa
import pyarrow.compute as pc

if isinstance(source, DateType):
epoch = pa.scalar(datetime.EPOCH_DATE)
elif isinstance(source, TimestampType):
epoch = pa.scalar(datetime.EPOCH_TIMESTAMP)
elif isinstance(source, TimestamptzType):
epoch = pa.scalar(datetime.EPOCH_TIMESTAMPTZ)
elif isinstance(source, TimestampNanoType):
epoch = pa.scalar(datetime.EPOCH_TIMESTAMP).cast(pa.timestamp("ns"))
elif isinstance(source, TimestamptzNanoType):
epoch = pa.scalar(datetime.EPOCH_TIMESTAMPTZ).cast(pa.timestamp("ns"))
else:
raise ValueError(f"Cannot apply day transform for type: {source}")
from pyiceberg_core import transform as pyiceberg_core_transform

return lambda v: pc.days_between(epoch, v) if v is not None else None
return _pyiceberg_transform_wrapper(pyiceberg_core_transform.day, expected_type=pa.int32())


class HourTransform(TimeTransform[S]):
Expand Down Expand Up @@ -728,21 +692,9 @@ def __repr__(self) -> str:
return "HourTransform()"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
import pyarrow as pa
import pyarrow.compute as pc

if isinstance(source, TimestampType):
epoch = pa.scalar(datetime.EPOCH_TIMESTAMP)
elif isinstance(source, TimestamptzType):
epoch = pa.scalar(datetime.EPOCH_TIMESTAMPTZ)
elif isinstance(source, TimestampNanoType):
epoch = pa.scalar(datetime.EPOCH_TIMESTAMP).cast(pa.timestamp("ns"))
elif isinstance(source, TimestamptzNanoType):
epoch = pa.scalar(datetime.EPOCH_TIMESTAMPTZ).cast(pa.timestamp("ns"))
else:
raise ValueError(f"Cannot apply hour transform for type: {source}")
from pyiceberg_core import transform as pyiceberg_core_transform

return lambda v: pc.hours_between(epoch, v) if v is not None else None
return _pyiceberg_transform_wrapper(pyiceberg_core_transform.hour)


def _base64encode(buffer: bytes) -> str:
Expand Down Expand Up @@ -965,7 +917,7 @@ def __repr__(self) -> str:
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
from pyiceberg_core import transform as pyiceberg_core_transform

return self._pyiceberg_transform_wrapper(pyiceberg_core_transform.truncate, self._width)
return _pyiceberg_transform_wrapper(pyiceberg_core_transform.truncate, self._width)

@property
def supports_pyarrow_transform(self) -> bool:
Expand Down
10 changes: 9 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ psycopg2-binary = { version = ">=2.9.6", optional = true }
sqlalchemy = { version = "^2.0.18", optional = true }
getdaft = { version = ">=0.2.12", optional = true }
cachetools = "^5.5.0"
pyiceberg-core = { version = "^0.4.0", optional = true }
pyiceberg-core = { version = "0.4.0.dev20250326000154", source="testpypi", optional = true }
polars = { version = "^1.21.0", optional = true }
thrift-sasl = { version = ">=0.4.3", optional = true }

Expand Down Expand Up @@ -115,6 +115,14 @@ mkdocs-material = "9.6.9"
mkdocs-material-extensions = "1.3.1"
mkdocs-section-index = "0.3.9"

[[tool.poetry.source]]
name = "pypi"
priority = "primary"

[[tool.poetry.source]]
name = "testpypi"
url = "https://test.pypi.org/simple/"

[[tool.mypy.overrides]]
module = "pytest_mock.*"
ignore_missing_imports = true
Expand Down
4 changes: 2 additions & 2 deletions tests/table/test_partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,8 @@ def test_partition_type(table_schema_simple: Schema) -> None:
(DecimalType(5, 9), Decimal(19.25)),
(DateType(), datetime.date(1925, 5, 22)),
(TimeType(), datetime.time(19, 25, 00)),
(TimestampType(), datetime.datetime(19, 5, 1, 22, 1, 1)),
(TimestamptzType(), datetime.datetime(19, 5, 1, 22, 1, 1, tzinfo=datetime.timezone.utc)),
(TimestampType(), datetime.datetime(2022, 5, 1, 22, 1, 1)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is wrong with 19?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First, I thought it was some kind of clock correction somewhere in time that was accounted by one, and not the other. Digging into it deeper, it looks like there is an issue with negative numbers in general:

>                   assert t.transform(source_type)(value) == t.pyarrow_transform(source_type)(pa.array([value])).to_pylist()[0]
E                   assert -2 == -1
E                    +  where -2 = <function HourTransform.transform.<locals>.<lambda> at 0x143fa6e60>(datetime.datetime(1969, 12, 31, 22, 1, 1))
E                    +    where <function HourTransform.transform.<locals>.<lambda> at 0x143fa6e60> = <bound method HourTransform.transform of HourTransform()>(TimestampType())
E                    +      where <bound method HourTransform.transform of HourTransform()> = HourTransform().transform

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got a fix here: apache/iceberg-rust#1146

(TimestamptzType(), datetime.datetime(2022, 5, 1, 22, 1, 1, tzinfo=datetime.timezone.utc)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you want to merge #1592 into this PR?

(StringType(), "abc"),
(UUIDType(), UUID("12345678-1234-5678-1234-567812345678").bytes),
(FixedType(5), 'b"\x8e\xd1\x87\x01"'),
Expand Down