8 changes: 6 additions & 2 deletions pyiceberg/avro/writer.py
@@ -32,6 +32,7 @@
List,
Optional,
Tuple,
Union,
)
from uuid import UUID

@@ -121,8 +122,11 @@ def write(self, encoder: BinaryEncoder, val: Any) -> None:

@dataclass(frozen=True)
class UUIDWriter(Writer):
def write(self, encoder: BinaryEncoder, val: UUID) -> None:
encoder.write(val.bytes)
def write(self, encoder: BinaryEncoder, val: Union[UUID, bytes]) -> None:
if isinstance(val, UUID):
encoder.write(val.bytes)
else:
encoder.write(val)


@dataclass(frozen=True)
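For context, a minimal standalone sketch of the dual-input dispatch that `UUIDWriter.write` now performs; the `_BufferEncoder` below is a hypothetical stand-in for pyiceberg's `BinaryEncoder`, used only so the snippet runs without the library:

```python
import io
import uuid
from dataclasses import dataclass


@dataclass
class _BufferEncoder:
    """Hypothetical stand-in for pyiceberg's BinaryEncoder (writes raw bytes to a buffer)."""

    buf: io.BytesIO

    def write(self, b: bytes) -> None:
        self.buf.write(b)


def write_uuid(encoder: _BufferEncoder, val: "uuid.UUID | bytes") -> None:
    # Mirrors the dispatch added to UUIDWriter.write: accept either a uuid.UUID
    # or its raw 16-byte representation and emit the same 16 bytes.
    if isinstance(val, uuid.UUID):
        encoder.write(val.bytes)
    else:
        encoder.write(val)


u = uuid.UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")
for value in (u, u.bytes):
    enc = _BufferEncoder(io.BytesIO())
    write_uuid(enc, value)
    assert enc.buf.getvalue() == u.bytes  # both inputs yield the same on-disk bytes
```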
4 changes: 3 additions & 1 deletion pyiceberg/io/pyarrow.py
@@ -684,7 +684,7 @@ def visit_string(self, _: StringType) -> pa.DataType:
return pa.large_string()

def visit_uuid(self, _: UUIDType) -> pa.DataType:
return pa.binary(16)
return pa.uuid()

def visit_unknown(self, _: UnknownType) -> pa.DataType:
return pa.null()
@@ -1252,6 +1252,8 @@ def primitive(self, primitive: pa.DataType) -> PrimitiveType:
return FixedType(primitive.byte_width)
elif pa.types.is_null(primitive):
return UnknownType()
elif isinstance(primitive, pa.UuidType):
return UUIDType()

raise TypeError(f"Unsupported type: {primitive}")

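As a rough sketch of the new Arrow mapping (assuming PyArrow >= 18, where `pa.uuid()` and `pa.UuidType` are available), an Iceberg `UUIDType` column now maps to Arrow's UUID extension type, whose storage remains `fixed_size_binary(16)`:

```python
import uuid

import pyarrow as pa

# The UUID extension type wraps fixed_size_binary(16) storage.
uuid_type = pa.uuid()
assert isinstance(uuid_type, pa.UuidType)

storage = pa.array(
    [uuid.UUID(int=0).bytes, uuid.UUID(int=1).bytes],
    type=pa.binary(16),  # 16-byte fixed-width storage
)
uuid_array = pa.ExtensionArray.from_storage(uuid_type, storage)
assert uuid_array.type == pa.uuid()
assert len(uuid_array) == 2
```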
13 changes: 11 additions & 2 deletions pyiceberg/partitioning.py
@@ -467,8 +467,17 @@ def _(type: IcebergType, value: Optional[time]) -> Optional[int]:


@_to_partition_representation.register(UUIDType)
def _(type: IcebergType, value: Optional[uuid.UUID]) -> Optional[str]:
return str(value) if value is not None else None
def _(type: IcebergType, value: Optional[Union[uuid.UUID, int, bytes]]) -> Optional[Union[bytes, int]]:
if value is None:
return None
elif isinstance(value, bytes):
return value # IdentityTransform
elif isinstance(value, uuid.UUID):
return value.bytes # IdentityTransform
elif isinstance(value, int):
return value # BucketTransform
else:
raise ValueError(f"Type not recognized: {value}")


@_to_partition_representation.register(PrimitiveType)
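A hedged, standalone mirror of the new `UUIDType` branch (not the pyiceberg dispatch itself), showing which inputs each transform is expected to hand over and what the partition representation becomes:

```python
import uuid
from typing import Optional, Union


def uuid_partition_representation(value: Optional[Union[uuid.UUID, int, bytes]]) -> Optional[Union[bytes, int]]:
    # IdentityTransform hands over the UUID itself (or its raw bytes), stored as 16 bytes;
    # BucketTransform has already reduced the value to an int bucket number.
    if value is None:
        return None
    if isinstance(value, bytes):
        return value
    if isinstance(value, uuid.UUID):
        return value.bytes
    if isinstance(value, int):
        return value
    raise ValueError(f"Type not recognized: {value}")


u = uuid.UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")
assert uuid_partition_representation(u) == u.bytes        # IdentityTransform, UUID object
assert uuid_partition_representation(u.bytes) == u.bytes  # IdentityTransform, raw bytes
assert uuid_partition_representation(7) == 7              # BucketTransform bucket number
assert uuid_partition_representation(None) is None
```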
7 changes: 5 additions & 2 deletions tests/conftest.py
@@ -2788,7 +2788,7 @@ def pyarrow_schema_with_promoted_types() -> "pa.Schema":
pa.field("list", pa.list_(pa.int32()), nullable=False), # can support upcasting integer to long
pa.field("map", pa.map_(pa.string(), pa.int32()), nullable=False), # can support upcasting integer to long
pa.field("double", pa.float32(), nullable=True), # can support upcasting float to double
pa.field("uuid", pa.binary(length=16), nullable=True), # can support upcasting float to double
pa.field("uuid", pa.binary(length=16), nullable=True), # can support upcasting fixed to uuid
)
)

@@ -2804,7 +2804,10 @@ def pyarrow_table_with_promoted_types(pyarrow_schema_with_promoted_types: "pa.Sc
"list": [[1, 1], [2, 2]],
"map": [{"a": 1}, {"b": 2}],
"double": [1.1, 9.2],
"uuid": [b"qZx\xefNS@\x89\x9b\xf9:\xd0\xee\x9b\xf5E", b"\x97]\x87T^JDJ\x96\x97\xf4v\xe4\x03\x0c\xde"],
"uuid": [
uuid.UUID("00000000-0000-0000-0000-000000000000").bytes,
uuid.UUID("11111111-1111-1111-1111-111111111111").bytes,
],
},
schema=pyarrow_schema_with_promoted_types,
)
4 changes: 2 additions & 2 deletions tests/integration/test_add_files.py
@@ -737,7 +737,7 @@ def test_add_files_with_valid_upcast(
with pq.ParquetWriter(fos, schema=pyarrow_schema_with_promoted_types) as writer:
writer.write_table(pyarrow_table_with_promoted_types)

tbl.add_files(file_paths=[file_path])
tbl.add_files(file_paths=[file_path], check_duplicate_files=False)
# table's long field should cast to long on read
written_arrow_table = tbl.scan().to_arrow()
assert written_arrow_table == pyarrow_table_with_promoted_types.cast(
@@ -747,7 +747,7 @@
pa.field("list", pa.list_(pa.int64()), nullable=False),
pa.field("map", pa.map_(pa.string(), pa.int64()), nullable=False),
pa.field("double", pa.float64(), nullable=True),
pa.field("uuid", pa.binary(length=16), nullable=True), # can UUID is read as fixed length binary of length 16
pa.field("uuid", pa.uuid(), nullable=True),
)
)
)
20 changes: 0 additions & 20 deletions tests/integration/test_partitioning_key.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
# pylint:disable=redefined-outer-name
import uuid
from datetime import date, datetime, timedelta, timezone
from decimal import Decimal
from typing import Any, List
@@ -308,25 +307,6 @@
(CAST('2023-01-01' AS DATE), 'Associated string value for date 2023-01-01')
""",
),
(
[PartitionField(source_id=14, field_id=1001, transform=IdentityTransform(), name="uuid_field")],
[uuid.UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")],
Record("f47ac10b-58cc-4372-a567-0e02b2c3d479"),
"uuid_field=f47ac10b-58cc-4372-a567-0e02b2c3d479",
f"""CREATE TABLE {identifier} (
uuid_field string,
string_field string
)
USING iceberg
PARTITIONED BY (
identity(uuid_field)
)
""",
f"""INSERT INTO {identifier}
VALUES
('f47ac10b-58cc-4372-a567-0e02b2c3d479', 'Associated string value for UUID f47ac10b-58cc-4372-a567-0e02b2c3d479')
""",
),
(
[PartitionField(source_id=11, field_id=1001, transform=IdentityTransform(), name="binary_field")],
[b"example"],
8 changes: 4 additions & 4 deletions tests/integration/test_reads.py
@@ -588,15 +588,15 @@ def test_partitioned_tables(catalog: Catalog) -> None:
def test_unpartitioned_uuid_table(catalog: Catalog) -> None:
unpartitioned_uuid = catalog.load_table("default.test_uuid_and_fixed_unpartitioned")
arrow_table_eq = unpartitioned_uuid.scan(row_filter="uuid_col == '102cb62f-e6f8-4eb0-9973-d9b012ff0967'").to_arrow()
assert arrow_table_eq["uuid_col"].to_pylist() == [uuid.UUID("102cb62f-e6f8-4eb0-9973-d9b012ff0967").bytes]
assert arrow_table_eq["uuid_col"].to_pylist() == [uuid.UUID("102cb62f-e6f8-4eb0-9973-d9b012ff0967")]

arrow_table_neq = unpartitioned_uuid.scan(
row_filter="uuid_col != '102cb62f-e6f8-4eb0-9973-d9b012ff0967' and uuid_col != '639cccce-c9d2-494a-a78c-278ab234f024'"
).to_arrow()
assert arrow_table_neq["uuid_col"].to_pylist() == [
uuid.UUID("ec33e4b2-a834-4cc3-8c4a-a1d3bfc2f226").bytes,
uuid.UUID("c1b0d8e0-0b0e-4b1e-9b0a-0e0b0d0c0a0b").bytes,
uuid.UUID("923dae77-83d6-47cd-b4b0-d383e64ee57e").bytes,
uuid.UUID("ec33e4b2-a834-4cc3-8c4a-a1d3bfc2f226"),
uuid.UUID("c1b0d8e0-0b0e-4b1e-9b0a-0e0b0d0c0a0b"),
uuid.UUID("923dae77-83d6-47cd-b4b0-d383e64ee57e"),
]


59 changes: 57 additions & 2 deletions tests/integration/test_writes/test_writes.py
@@ -19,6 +19,7 @@
import os
import random
import time
import uuid
from datetime import date, datetime, timedelta
from decimal import Decimal
from pathlib import Path
@@ -49,7 +50,7 @@
from pyiceberg.schema import Schema
from pyiceberg.table import TableProperties
from pyiceberg.table.sorting import SortDirection, SortField, SortOrder
from pyiceberg.transforms import DayTransform, HourTransform, IdentityTransform
from pyiceberg.transforms import DayTransform, HourTransform, IdentityTransform, Transform
from pyiceberg.types import (
DateType,
DecimalType,
@@ -59,6 +60,7 @@
LongType,
NestedField,
StringType,
UUIDType,
)
from utils import _create_table

@@ -1272,7 +1274,7 @@ def test_table_write_schema_with_valid_upcast(
pa.field("list", pa.list_(pa.int64()), nullable=False),
pa.field("map", pa.map_(pa.string(), pa.int64()), nullable=False),
pa.field("double", pa.float64(), nullable=True), # can support upcasting float to double
pa.field("uuid", pa.binary(length=16), nullable=True), # can UUID is read as fixed length binary of length 16
pa.field("uuid", pa.uuid(), nullable=True),
)
)
)
@@ -1844,6 +1846,59 @@ def test_read_write_decimals(session_catalog: Catalog) -> None:
assert tbl.scan().to_arrow() == arrow_table


@pytest.mark.integration
@pytest.mark.parametrize(
"transform",
[
IdentityTransform(),
# Bucket is disabled because of an issue in Iceberg Java:
# https://github.com/apache/iceberg/pull/13324
# BucketTransform(32)
],
)
def test_uuid_partitioning(session_catalog: Catalog, spark: SparkSession, transform: Transform) -> None: # type: ignore
identifier = f"default.test_uuid_partitioning_{str(transform).replace('[32]', '')}"

schema = Schema(NestedField(field_id=1, name="uuid", field_type=UUIDType(), required=True))

try:
session_catalog.drop_table(identifier=identifier)
except NoSuchTableError:
pass

partition_spec = PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=transform, name="uuid_identity"))

import pyarrow as pa

arr_table = pa.Table.from_pydict(
{
"uuid": [
uuid.UUID("00000000-0000-0000-0000-000000000000").bytes,
uuid.UUID("11111111-1111-1111-1111-111111111111").bytes,
],
},
schema=pa.schema(
[
# Uuid not yet supported, so we have to stick with `binary(16)`
# https://github.com/apache/arrow/issues/46468
pa.field("uuid", pa.binary(16), nullable=False),
]
),
)

tbl = session_catalog.create_table(
identifier=identifier,
schema=schema,
partition_spec=partition_spec,
)

tbl.append(arr_table)

lhs = [r[0] for r in spark.table(identifier).collect()]
rhs = [str(u.as_py()) for u in tbl.scan().to_arrow()["uuid"].combine_chunks()]
assert lhs == rhs


@pytest.mark.integration
def test_avro_compression_codecs(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
identifier = "default.test_avro_compression_codecs"