From 7803fcd997ffaff4c9694e2d631c3074d9bfb20a Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Wed, 10 Jul 2024 21:45:48 +0000 Subject: [PATCH 1/7] support Etc/UTC --- pyiceberg/io/pyarrow.py | 2 +- tests/integration/test_writes/test_writes.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index ae7799cfde..852b4a2a12 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -937,7 +937,7 @@ def primitive(self, primitive: pa.DataType) -> PrimitiveType: else: raise TypeError(f"Unsupported precision for timestamp type: {primitive.unit}") - if primitive.tz == "UTC" or primitive.tz == "+00:00": + if primitive.tz in ("UTC", "+00:00", "Etc/UTC"): return TimestamptzType() elif primitive.tz is None: return TimestampType() diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index 2542fbdb38..cb87e69327 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -979,6 +979,7 @@ def test_write_all_timestamp_precision(mocker: MockerFixture, session_catalog: C ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")), ("timestamp_ns", pa.timestamp(unit="ns")), ("timestamptz_ns", pa.timestamp(unit="ns", tz="UTC")), + ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="Etc/UTC")), ]) TEST_DATA_WITH_NULL = { "timestamp_s": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], @@ -1005,6 +1006,11 @@ def test_write_all_timestamp_precision(mocker: MockerFixture, session_catalog: C None, datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), ], + "timestamptz_us_etc_utc": [ + datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), + None, + datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), + ], } input_arrow_table = pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=arrow_table_schema_with_all_timestamp_precisions) mocker.patch.dict(os.environ, values={"PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE": "True"}) @@ -1028,6 +1034,7 @@ def test_write_all_timestamp_precision(mocker: MockerFixture, session_catalog: C ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")), ("timestamp_ns", pa.timestamp(unit="us")), ("timestamptz_ns", pa.timestamp(unit="us", tz="UTC")), + ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="UTC")), ]) assert written_arrow_table.schema == expected_schema_in_all_us assert written_arrow_table == input_arrow_table.cast(expected_schema_in_all_us) From fa3ba66b1fc23f95ee7d066316582cd3efc4270d Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Thu, 11 Jul 2024 08:33:49 -0400 Subject: [PATCH 2/7] Update pyiceberg/io/pyarrow.py Super Co-authored-by: Fokko Driesprong --- pyiceberg/io/pyarrow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 852b4a2a12..19e7bee1f0 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -937,7 +937,7 @@ def primitive(self, primitive: pa.DataType) -> PrimitiveType: else: raise TypeError(f"Unsupported precision for timestamp type: {primitive.unit}") - if primitive.tz in ("UTC", "+00:00", "Etc/UTC"): + if primitive.tz in {"UTC", "+00:00", "Etc/UTC"}: return TimestamptzType() elif primitive.tz is None: return TimestampType() From fd5313848c736aa97e04ed05e22ea5e301be3264 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Thu, 11 Jul 2024 15:14:23 +0000 Subject: [PATCH 3/7] tests --- pyiceberg/io/pyarrow.py | 2 +- pyiceberg/table/__init__.py | 8 ----- tests/integration/test_writes/test_writes.py | 37 ++++++++++++++++---- 3 files changed, 32 insertions(+), 15 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 19e7bee1f0..e939adca74 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -937,7 +937,7 @@ def primitive(self, primitive: pa.DataType) -> PrimitiveType: else: raise TypeError(f"Unsupported precision for timestamp type: {primitive.unit}") - if primitive.tz in {"UTC", "+00:00", "Etc/UTC"}: + if primitive.tz in {"UTC", "+00:00", "Etc/UTC", "Z"}: return TimestamptzType() elif primitive.tz is None: return TimestampType() diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 39bcfc2ef6..de2fb77028 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -529,10 +529,6 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) ) _check_schema_compatible(self._table.schema(), other_schema=df.schema) - # cast if the two schemas are compatible but not equal - table_arrow_schema = self._table.schema().as_arrow() - if table_arrow_schema != df.schema: - df = df.cast(table_arrow_schema) manifest_merge_enabled = PropertyUtil.property_as_bool( self.table_metadata.properties, @@ -588,10 +584,6 @@ def overwrite( ) _check_schema_compatible(self._table.schema(), other_schema=df.schema) - # cast if the two schemas are compatible but not equal - table_arrow_schema = self._table.schema().as_arrow() - if table_arrow_schema != df.schema: - df = df.cast(table_arrow_schema) self.delete(delete_filter=overwrite_filter, snapshot_properties=snapshot_properties) diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index cb87e69327..c03efa9dcd 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -23,6 +23,7 @@ from typing import Any, Dict from urllib.parse import urlparse +import pandas as pd import pyarrow as pa import pyarrow.parquet as pq import pytest @@ -968,7 +969,9 @@ def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null @pytest.mark.integration @pytest.mark.parametrize("format_version", [1, 2]) -def test_write_all_timestamp_precision(mocker: MockerFixture, session_catalog: Catalog, format_version: int) -> None: +def test_write_all_timestamp_precision( + mocker: MockerFixture, spark: SparkSession, session_catalog: Catalog, format_version: int +) -> None: identifier = "default.table_all_timestamp_precision" arrow_table_schema_with_all_timestamp_precisions = pa.schema([ ("timestamp_s", pa.timestamp(unit="s")), @@ -980,8 +983,9 @@ def test_write_all_timestamp_precision(mocker: MockerFixture, session_catalog: C ("timestamp_ns", pa.timestamp(unit="ns")), ("timestamptz_ns", pa.timestamp(unit="ns", tz="UTC")), ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="Etc/UTC")), + ("timestamptz_us_z", pa.timestamp(unit="us", tz="Z")), ]) - TEST_DATA_WITH_NULL = { + TEST_DATA_WITH_NULL = pd.DataFrame({ "timestamp_s": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], "timestamptz_s": [ datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), @@ -1000,7 +1004,11 @@ def test_write_all_timestamp_precision(mocker: MockerFixture, session_catalog: C None, datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), ], - "timestamp_ns": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], + "timestamp_ns": [ + pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6), + None, + pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=7), + ], "timestamptz_ns": [ datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), None, @@ -1011,8 +1019,13 @@ def test_write_all_timestamp_precision(mocker: MockerFixture, session_catalog: C None, datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), ], - } - input_arrow_table = pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=arrow_table_schema_with_all_timestamp_precisions) + "timestamptz_us_z": [ + datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), + None, + datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), + ], + }) + input_arrow_table = pa.Table.from_pandas(TEST_DATA_WITH_NULL, schema=arrow_table_schema_with_all_timestamp_precisions) mocker.patch.dict(os.environ, values={"PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE": "True"}) tbl = _create_table( @@ -1035,9 +1048,21 @@ def test_write_all_timestamp_precision(mocker: MockerFixture, session_catalog: C ("timestamp_ns", pa.timestamp(unit="us")), ("timestamptz_ns", pa.timestamp(unit="us", tz="UTC")), ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="UTC")), + ("timestamptz_us_z", pa.timestamp(unit="us", tz="UTC")), ]) assert written_arrow_table.schema == expected_schema_in_all_us - assert written_arrow_table == input_arrow_table.cast(expected_schema_in_all_us) + assert written_arrow_table == input_arrow_table.cast(expected_schema_in_all_us, safe=False) + lhs = spark.table(f"{identifier}").toPandas() + rhs = written_arrow_table.to_pandas() + + for column in written_arrow_table.column_names: + for left, right in zip(lhs[column].to_list(), rhs[column].to_list()): + if pd.isnull(left): + assert pd.isnull(right) + else: + # Check only upto microsecond precision since Spark loaded dtype is timezone unaware + # and supports upto microsecond precision + assert left.timestamp() == right.timestamp(), f"Difference in column {column}: {left} != {right}" @pytest.mark.integration From bf70252f163d14bfeca68514d5a00da87591d95b Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Thu, 11 Jul 2024 16:13:01 +0000 Subject: [PATCH 4/7] fix --- pyiceberg/io/pyarrow.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 3c696778f8..b7ecc9b36d 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -174,6 +174,7 @@ MAP_KEY_NAME = "key" MAP_VALUE_NAME = "value" DOC = "doc" +UTC_ALIASES = {"UTC", "+00:00", "Etc/UTC", "Z"} T = TypeVar("T") @@ -937,7 +938,7 @@ def primitive(self, primitive: pa.DataType) -> PrimitiveType: else: raise TypeError(f"Unsupported precision for timestamp type: {primitive.unit}") - if primitive.tz in {"UTC", "+00:00", "Etc/UTC", "Z"}: + if primitive.tz in UTC_ALIASES: return TimestamptzType() elif primitive.tz is None: return TimestampType() @@ -1320,7 +1321,16 @@ def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array: and pa.types.is_timestamp(values.type) and values.type.unit == "ns" ): - return values.cast(target_type, safe=False) + if target_type.tz == "UTC" and values.type.tz in UTC_ALIASES or not target_type.tz and not values.type.tz: + return values.cast(target_type, safe=False) + if ( + pa.types.is_timestamp(target_type) + and target_type.unit == "us" + and pa.types.is_timestamp(values.type) + and values.type.unit in {"s", "ms", "us"} + ): + if target_type.tz == "UTC" and values.type.tz in UTC_ALIASES or not target_type.tz and not values.type.tz: + return values.cast(target_type) return values def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field: From 9375e87b589afa5041e12a0622caceca51e984c2 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Thu, 11 Jul 2024 21:02:19 +0000 Subject: [PATCH 5/7] parenthesis - thanks Fokko! --- pyiceberg/io/pyarrow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index b7ecc9b36d..188d378a01 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1321,7 +1321,7 @@ def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array: and pa.types.is_timestamp(values.type) and values.type.unit == "ns" ): - if target_type.tz == "UTC" and values.type.tz in UTC_ALIASES or not target_type.tz and not values.type.tz: + if (target_type.tz == "UTC" and values.type.tz in UTC_ALIASES) or (not target_type.tz and not values.type.tz): return values.cast(target_type, safe=False) if ( pa.types.is_timestamp(target_type) @@ -1329,7 +1329,7 @@ def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array: and pa.types.is_timestamp(values.type) and values.type.unit in {"s", "ms", "us"} ): - if target_type.tz == "UTC" and values.type.tz in UTC_ALIASES or not target_type.tz and not values.type.tz: + if (target_type.tz == "UTC" and values.type.tz in UTC_ALIASES) or (not target_type.tz and not values.type.tz): return values.cast(target_type) return values From ce643f670f04e7ee4db0363cebdc5457e3323e56 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Thu, 11 Jul 2024 22:50:24 +0000 Subject: [PATCH 6/7] adopt review feedback: more tests, refactoring, stricter checks --- pyiceberg/io/pyarrow.py | 51 ++++---- tests/conftest.py | 116 +++++++++++++++++- tests/integration/test_add_files.py | 1 + .../test_writes/test_partitioned_writes.py | 14 +-- tests/integration/test_writes/test_writes.py | 85 +++---------- tests/io/test_pyarrow.py | 33 +++++ 6 files changed, 199 insertions(+), 101 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 188d378a01..976f727d63 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1297,16 +1297,17 @@ def to_requested_schema( class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]): - file_schema: Schema + _file_schema: Schema _include_field_ids: bool + _downcast_ns_timestamp_to_us: bool def __init__(self, file_schema: Schema, downcast_ns_timestamp_to_us: bool = False, include_field_ids: bool = False) -> None: - self.file_schema = file_schema + self._file_schema = file_schema self._include_field_ids = include_field_ids - self.downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us + self._downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array: - file_field = self.file_schema.find_field(field.field_id) + file_field = self._file_schema.find_field(field.field_id) if field.field_type.is_primitive: if field.field_type != file_field.field_type: @@ -1314,23 +1315,31 @@ def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array: schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=self._include_field_ids) ) elif (target_type := schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)) != values.type: - # Downcasting of nanoseconds to microseconds - if ( - pa.types.is_timestamp(target_type) - and target_type.unit == "us" - and pa.types.is_timestamp(values.type) - and values.type.unit == "ns" - ): - if (target_type.tz == "UTC" and values.type.tz in UTC_ALIASES) or (not target_type.tz and not values.type.tz): - return values.cast(target_type, safe=False) - if ( - pa.types.is_timestamp(target_type) - and target_type.unit == "us" - and pa.types.is_timestamp(values.type) - and values.type.unit in {"s", "ms", "us"} - ): - if (target_type.tz == "UTC" and values.type.tz in UTC_ALIASES) or (not target_type.tz and not values.type.tz): - return values.cast(target_type) + if field.field_type == TimestampType(): + # Downcasting of nanoseconds to microseconds + if ( + pa.types.is_timestamp(target_type) + and not target_type.tz + and pa.types.is_timestamp(values.type) + and not values.type.tz + ): + if target_type.unit == "us" and values.type.unit == "ns" and self._downcast_ns_timestamp_to_us: + return values.cast(target_type, safe=False) + elif target_type.unit == "us" and values.type.unit in {"s", "ms"}: + return values.cast(target_type) + raise ValueError(f"Unsupported schema projection from {values.type} to {target_type}") + elif field.field_type == TimestamptzType(): + if ( + pa.types.is_timestamp(target_type) + and target_type.tz == "UTC" + and pa.types.is_timestamp(values.type) + and values.type.tz in UTC_ALIASES + ): + if target_type.unit == "us" and values.type.unit == "ns" and self._downcast_ns_timestamp_to_us: + return values.cast(target_type, safe=False) + elif target_type.unit == "us" and values.type.unit in {"s", "ms", "us"}: + return values.cast(target_type) + raise ValueError(f"Unsupported schema projection from {values.type} to {target_type}") return values def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field: diff --git a/tests/conftest.py b/tests/conftest.py index 95e1128af6..6b1a2b43e2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2382,10 +2382,122 @@ def arrow_table_date_timestamps() -> "pa.Table": @pytest.fixture(scope="session") -def arrow_table_date_timestamps_schema() -> Schema: - """Pyarrow table Schema with only date, timestamp and timestamptz values.""" +def table_date_timestamps_schema() -> Schema: + """Iceberg table Schema with only date, timestamp and timestamptz values.""" return Schema( NestedField(field_id=1, name="date", field_type=DateType(), required=False), NestedField(field_id=2, name="timestamp", field_type=TimestampType(), required=False), NestedField(field_id=3, name="timestamptz", field_type=TimestamptzType(), required=False), ) + + +@pytest.fixture(scope="session") +def arrow_table_schema_with_all_timestamp_precisions() -> "pa.Schema": + """Pyarrow Schema with all supported timestamp types.""" + import pyarrow as pa + + return pa.schema([ + ("timestamp_s", pa.timestamp(unit="s")), + ("timestamptz_s", pa.timestamp(unit="s", tz="UTC")), + ("timestamp_ms", pa.timestamp(unit="ms")), + ("timestamptz_ms", pa.timestamp(unit="ms", tz="UTC")), + ("timestamp_us", pa.timestamp(unit="us")), + ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")), + ("timestamp_ns", pa.timestamp(unit="ns")), + ("timestamptz_ns", pa.timestamp(unit="ns", tz="UTC")), + ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="Etc/UTC")), + ("timestamptz_ns_z", pa.timestamp(unit="ns", tz="Z")), + ("timestamptz_s_0000", pa.timestamp(unit="s", tz="+00:00")), + ]) + + +@pytest.fixture(scope="session") +def arrow_table_with_all_timestamp_precisions(arrow_table_schema_with_all_timestamp_precisions: "pa.Schema") -> "pa.Table": + """Pyarrow table with all supported timestamp types.""" + import pandas as pd + import pyarrow as pa + + test_data = pd.DataFrame({ + "timestamp_s": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], + "timestamptz_s": [ + datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), + None, + datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), + ], + "timestamp_ms": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], + "timestamptz_ms": [ + datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), + None, + datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), + ], + "timestamp_us": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], + "timestamptz_us": [ + datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), + None, + datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), + ], + "timestamp_ns": [ + pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6), + None, + pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=7), + ], + "timestamptz_ns": [ + datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), + None, + datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), + ], + "timestamptz_us_etc_utc": [ + datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), + None, + datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), + ], + "timestamptz_ns_z": [ + pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6, tz="UTC"), + None, + pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=7, tz="UTC"), + ], + "timestamptz_s_0000": [ + datetime(2023, 1, 1, 19, 25, 1, tzinfo=timezone.utc), + None, + datetime(2023, 3, 1, 19, 25, 1, tzinfo=timezone.utc), + ], + }) + return pa.Table.from_pandas(test_data, schema=arrow_table_schema_with_all_timestamp_precisions) + + +@pytest.fixture(scope="session") +def arrow_table_schema_with_all_microseconds_timestamp_precisions() -> "pa.Schema": + """Pyarrow Schema with all microseconds timestamp.""" + import pyarrow as pa + + return pa.schema([ + ("timestamp_s", pa.timestamp(unit="us")), + ("timestamptz_s", pa.timestamp(unit="us", tz="UTC")), + ("timestamp_ms", pa.timestamp(unit="us")), + ("timestamptz_ms", pa.timestamp(unit="us", tz="UTC")), + ("timestamp_us", pa.timestamp(unit="us")), + ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")), + ("timestamp_ns", pa.timestamp(unit="us")), + ("timestamptz_ns", pa.timestamp(unit="us", tz="UTC")), + ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="UTC")), + ("timestamptz_ns_z", pa.timestamp(unit="us", tz="UTC")), + ("timestamptz_s_0000", pa.timestamp(unit="us", tz="UTC")), + ]) + + +@pytest.fixture(scope="session") +def table_schema_with_all_microseconds_timestamp_precision() -> Schema: + """Iceberg table Schema with only date, timestamp and timestamptz values.""" + return Schema( + NestedField(field_id=1, name="timestamp_s", field_type=TimestampType(), required=False), + NestedField(field_id=2, name="timestamptz_s", field_type=TimestamptzType(), required=False), + NestedField(field_id=3, name="timestamp_ms", field_type=TimestampType(), required=False), + NestedField(field_id=4, name="timestamptz_ms", field_type=TimestamptzType(), required=False), + NestedField(field_id=5, name="timestamp_us", field_type=TimestampType(), required=False), + NestedField(field_id=6, name="timestamptz_us", field_type=TimestamptzType(), required=False), + NestedField(field_id=7, name="timestamp_ns", field_type=TimestampType(), required=False), + NestedField(field_id=8, name="timestamptz_ns", field_type=TimestamptzType(), required=False), + NestedField(field_id=9, name="timestamptz_us_etc_utc", field_type=TimestamptzType(), required=False), + NestedField(field_id=10, name="timestamptz_ns_z", field_type=TimestamptzType(), required=False), + NestedField(field_id=11, name="timestamptz_s_0000", field_type=TimestamptzType(), required=False), + ) diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py index 825d17e924..1ef004577e 100644 --- a/tests/integration/test_add_files.py +++ b/tests/integration/test_add_files.py @@ -518,6 +518,7 @@ def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_ca assert table_schema == arrow_schema_large +@pytest.mark.integration def test_timestamp_tz_ns_downcast_on_read(session_catalog: Catalog, format_version: int, mocker: MockerFixture) -> None: nanoseconds_schema_iceberg = Schema(NestedField(1, "quux", TimestamptzType())) diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index 12da9c928b..b199f00210 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -461,7 +461,7 @@ def test_append_transform_partition_verify_partitions_count( session_catalog: Catalog, spark: SparkSession, arrow_table_date_timestamps: pa.Table, - arrow_table_date_timestamps_schema: Schema, + table_date_timestamps_schema: Schema, transform: Transform[Any, Any], expected_partitions: Set[Any], format_version: int, @@ -469,7 +469,7 @@ def test_append_transform_partition_verify_partitions_count( # Given part_col = "timestamptz" identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_transform_partitioned_on_col_{part_col}" - nested_field = arrow_table_date_timestamps_schema.find_field(part_col) + nested_field = table_date_timestamps_schema.find_field(part_col) partition_spec = PartitionSpec( PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col), ) @@ -481,7 +481,7 @@ def test_append_transform_partition_verify_partitions_count( properties={"format-version": str(format_version)}, data=[arrow_table_date_timestamps], partition_spec=partition_spec, - schema=arrow_table_date_timestamps_schema, + schema=table_date_timestamps_schema, ) # Then @@ -510,20 +510,20 @@ def test_append_multiple_partitions( session_catalog: Catalog, spark: SparkSession, arrow_table_date_timestamps: pa.Table, - arrow_table_date_timestamps_schema: Schema, + table_date_timestamps_schema: Schema, format_version: int, ) -> None: # Given identifier = f"default.arrow_table_v{format_version}_with_multiple_partitions" partition_spec = PartitionSpec( PartitionField( - source_id=arrow_table_date_timestamps_schema.find_field("date").field_id, + source_id=table_date_timestamps_schema.find_field("date").field_id, field_id=1001, transform=YearTransform(), name="date_year", ), PartitionField( - source_id=arrow_table_date_timestamps_schema.find_field("timestamptz").field_id, + source_id=table_date_timestamps_schema.find_field("timestamptz").field_id, field_id=1000, transform=HourTransform(), name="timestamptz_hour", @@ -537,7 +537,7 @@ def test_append_multiple_partitions( properties={"format-version": str(format_version)}, data=[arrow_table_date_timestamps], partition_spec=partition_spec, - schema=arrow_table_date_timestamps_schema, + schema=table_date_timestamps_schema, ) # Then diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index e8eb9f6ae6..41bc6fb5bf 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -18,7 +18,7 @@ import math import os import time -from datetime import date, datetime, timezone +from datetime import date, datetime from pathlib import Path from typing import Any, Dict from urllib.parse import urlparse @@ -979,88 +979,31 @@ def table_write_subset_of_schema(session_catalog: Catalog, arrow_table_with_null @pytest.mark.integration @pytest.mark.parametrize("format_version", [1, 2]) def test_write_all_timestamp_precision( - mocker: MockerFixture, spark: SparkSession, session_catalog: Catalog, format_version: int + mocker: MockerFixture, + spark: SparkSession, + session_catalog: Catalog, + format_version: int, + arrow_table_schema_with_all_timestamp_precisions: pa.Schema, + arrow_table_with_all_timestamp_precisions: pa.Table, + arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema, ) -> None: identifier = "default.table_all_timestamp_precision" - arrow_table_schema_with_all_timestamp_precisions = pa.schema([ - ("timestamp_s", pa.timestamp(unit="s")), - ("timestamptz_s", pa.timestamp(unit="s", tz="UTC")), - ("timestamp_ms", pa.timestamp(unit="ms")), - ("timestamptz_ms", pa.timestamp(unit="ms", tz="UTC")), - ("timestamp_us", pa.timestamp(unit="us")), - ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")), - ("timestamp_ns", pa.timestamp(unit="ns")), - ("timestamptz_ns", pa.timestamp(unit="ns", tz="UTC")), - ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="Etc/UTC")), - ("timestamptz_us_z", pa.timestamp(unit="us", tz="Z")), - ]) - TEST_DATA_WITH_NULL = pd.DataFrame({ - "timestamp_s": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], - "timestamptz_s": [ - datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), - None, - datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), - ], - "timestamp_ms": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], - "timestamptz_ms": [ - datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), - None, - datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), - ], - "timestamp_us": [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], - "timestamptz_us": [ - datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), - None, - datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), - ], - "timestamp_ns": [ - pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=6), - None, - pd.Timestamp(year=2024, month=7, day=11, hour=3, minute=30, second=0, microsecond=12, nanosecond=7), - ], - "timestamptz_ns": [ - datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), - None, - datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), - ], - "timestamptz_us_etc_utc": [ - datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), - None, - datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), - ], - "timestamptz_us_z": [ - datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), - None, - datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), - ], - }) - input_arrow_table = pa.Table.from_pandas(TEST_DATA_WITH_NULL, schema=arrow_table_schema_with_all_timestamp_precisions) mocker.patch.dict(os.environ, values={"PYICEBERG_DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE": "True"}) tbl = _create_table( session_catalog, identifier, {"format-version": format_version}, - data=[input_arrow_table], + data=[arrow_table_with_all_timestamp_precisions], schema=arrow_table_schema_with_all_timestamp_precisions, ) - tbl.overwrite(input_arrow_table) + tbl.overwrite(arrow_table_with_all_timestamp_precisions) written_arrow_table = tbl.scan().to_arrow() - expected_schema_in_all_us = pa.schema([ - ("timestamp_s", pa.timestamp(unit="us")), - ("timestamptz_s", pa.timestamp(unit="us", tz="UTC")), - ("timestamp_ms", pa.timestamp(unit="us")), - ("timestamptz_ms", pa.timestamp(unit="us", tz="UTC")), - ("timestamp_us", pa.timestamp(unit="us")), - ("timestamptz_us", pa.timestamp(unit="us", tz="UTC")), - ("timestamp_ns", pa.timestamp(unit="us")), - ("timestamptz_ns", pa.timestamp(unit="us", tz="UTC")), - ("timestamptz_us_etc_utc", pa.timestamp(unit="us", tz="UTC")), - ("timestamptz_us_z", pa.timestamp(unit="us", tz="UTC")), - ]) - assert written_arrow_table.schema == expected_schema_in_all_us - assert written_arrow_table == input_arrow_table.cast(expected_schema_in_all_us, safe=False) + assert written_arrow_table.schema == arrow_table_schema_with_all_microseconds_timestamp_precisions + assert written_arrow_table == arrow_table_with_all_timestamp_precisions.cast( + arrow_table_schema_with_all_microseconds_timestamp_precisions, safe=False + ) lhs = spark.table(f"{identifier}").toPandas() rhs = written_arrow_table.to_pandas() diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 1b9468993c..fde50a035e 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -68,6 +68,7 @@ expression_to_pyarrow, project_table, schema_to_pyarrow, + to_requested_schema, ) from pyiceberg.manifest import DataFile, DataFileContent, FileFormat from pyiceberg.partitioning import PartitionField, PartitionSpec @@ -1798,3 +1799,35 @@ def test_identity_partition_on_multi_columns() -> None: ("n_legs", "ascending"), ("animal", "ascending"), ]) == arrow_table.sort_by([("born_year", "ascending"), ("n_legs", "ascending"), ("animal", "ascending")]) + + +def test_to_requested_schema_timestamps( + arrow_table_schema_with_all_timestamp_precisions: pa.Schema, + arrow_table_with_all_timestamp_precisions: pa.Table, + arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema, + table_schema_with_all_microseconds_timestamp_precision: Schema, +) -> None: + requested_schema = table_schema_with_all_microseconds_timestamp_precision + file_schema = requested_schema + batch = arrow_table_with_all_timestamp_precisions.to_batches()[0] + result = to_requested_schema(requested_schema, file_schema, batch, downcast_ns_timestamp_to_us=True, include_field_ids=False) + + expected = arrow_table_with_all_timestamp_precisions.cast( + arrow_table_schema_with_all_microseconds_timestamp_precisions, safe=False + ).to_batches()[0] + assert result == expected + + +def test_to_requested_schema_timestamps_without_downcast_raises_exception( + arrow_table_schema_with_all_timestamp_precisions: pa.Schema, + arrow_table_with_all_timestamp_precisions: pa.Table, + arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema, + table_schema_with_all_microseconds_timestamp_precision: Schema, +) -> None: + requested_schema = table_schema_with_all_microseconds_timestamp_precision + file_schema = requested_schema + batch = arrow_table_with_all_timestamp_precisions.to_batches()[0] + with pytest.raises(ValueError) as exc_info: + to_requested_schema(requested_schema, file_schema, batch, downcast_ns_timestamp_to_us=False, include_field_ids=False) + + assert "Unsupported schema projection from timestamp[ns] to timestamp[us]" in str(exc_info.value) From 413b4f731442b6f7224834db046f9a579e6c6632 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Fri, 12 Jul 2024 00:53:28 +0000 Subject: [PATCH 7/7] make _to_requested_schema private --- pyiceberg/io/pyarrow.py | 6 +++--- tests/io/test_pyarrow.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index db83d09b91..7016316d93 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1074,7 +1074,7 @@ def _task_to_record_batches( arrow_table = pa.Table.from_batches([batch]) arrow_table = arrow_table.filter(pyarrow_filter) batch = arrow_table.to_batches()[0] - yield to_requested_schema(projected_schema, file_project_schema, batch, downcast_ns_timestamp_to_us=True) + yield _to_requested_schema(projected_schema, file_project_schema, batch, downcast_ns_timestamp_to_us=True) current_index += len(batch) @@ -1279,7 +1279,7 @@ def project_batches( total_row_count += len(batch) -def to_requested_schema( +def _to_requested_schema( requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch, @@ -1990,7 +1990,7 @@ def write_parquet(task: WriteTask) -> DataFile: downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False batches = [ - to_requested_schema( + _to_requested_schema( requested_schema=file_schema, file_schema=table_schema, batch=batch, diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 279b07d222..37198b7edb 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -65,11 +65,11 @@ _determine_partitions, _primitive_to_physical, _read_deletes, + _to_requested_schema, bin_pack_arrow_table, expression_to_pyarrow, project_table, schema_to_pyarrow, - to_requested_schema, ) from pyiceberg.manifest import DataFile, DataFileContent, FileFormat from pyiceberg.partitioning import PartitionField, PartitionSpec @@ -1892,7 +1892,7 @@ def test_identity_partition_on_multi_columns() -> None: ]) == arrow_table.sort_by([("born_year", "ascending"), ("n_legs", "ascending"), ("animal", "ascending")]) -def test_to_requested_schema_timestamps( +def test__to_requested_schema_timestamps( arrow_table_schema_with_all_timestamp_precisions: pa.Schema, arrow_table_with_all_timestamp_precisions: pa.Table, arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema, @@ -1901,7 +1901,7 @@ def test_to_requested_schema_timestamps( requested_schema = table_schema_with_all_microseconds_timestamp_precision file_schema = requested_schema batch = arrow_table_with_all_timestamp_precisions.to_batches()[0] - result = to_requested_schema(requested_schema, file_schema, batch, downcast_ns_timestamp_to_us=True, include_field_ids=False) + result = _to_requested_schema(requested_schema, file_schema, batch, downcast_ns_timestamp_to_us=True, include_field_ids=False) expected = arrow_table_with_all_timestamp_precisions.cast( arrow_table_schema_with_all_microseconds_timestamp_precisions, safe=False @@ -1909,7 +1909,7 @@ def test_to_requested_schema_timestamps( assert result == expected -def test_to_requested_schema_timestamps_without_downcast_raises_exception( +def test__to_requested_schema_timestamps_without_downcast_raises_exception( arrow_table_schema_with_all_timestamp_precisions: pa.Schema, arrow_table_with_all_timestamp_precisions: pa.Table, arrow_table_schema_with_all_microseconds_timestamp_precisions: pa.Schema, @@ -1919,6 +1919,6 @@ def test_to_requested_schema_timestamps_without_downcast_raises_exception( file_schema = requested_schema batch = arrow_table_with_all_timestamp_precisions.to_batches()[0] with pytest.raises(ValueError) as exc_info: - to_requested_schema(requested_schema, file_schema, batch, downcast_ns_timestamp_to_us=False, include_field_ids=False) + _to_requested_schema(requested_schema, file_schema, batch, downcast_ns_timestamp_to_us=False, include_field_ids=False) assert "Unsupported schema projection from timestamp[ns] to timestamp[us]" in str(exc_info.value)