From 388580a8b1b519cc00934304ff634e631613d7c7 Mon Sep 17 00:00:00 2001 From: Fokko Date: Thu, 6 Mar 2025 12:21:59 +0100 Subject: [PATCH 01/10] Update-schema: Add support for `initial-default` --- pyiceberg/table/update/schema.py | 44 ++++++++++++++++++++++----- tests/integration/test_rest_schema.py | 36 ++++++++++++++++++---- 2 files changed, 66 insertions(+), 14 deletions(-) diff --git a/pyiceberg/table/update/schema.py b/pyiceberg/table/update/schema.py index 8ee3b43c24..6041a2ef6c 100644 --- a/pyiceberg/table/update/schema.py +++ b/pyiceberg/table/update/schema.py @@ -20,9 +20,10 @@ from copy import copy from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union from pyiceberg.exceptions import ResolveError, ValidationError +from pyiceberg.expressions import literal # type: ignore from pyiceberg.schema import ( PartnerAccessor, Schema, @@ -153,7 +154,12 @@ def union_by_name(self, new_schema: Union[Schema, "pa.Schema"]) -> UpdateSchema: return self def add_column( - self, path: Union[str, Tuple[str, ...]], field_type: IcebergType, doc: Optional[str] = None, required: bool = False + self, + path: Union[str, Tuple[str, ...]], + field_type: IcebergType, + doc: Optional[str] = None, + required: bool = False, + default_value: Optional[Any] = None, ) -> UpdateSchema: """Add a new column to a nested struct or Add a new top-level column. @@ -168,6 +174,7 @@ def add_column( field_type: Type for the new column. doc: Documentation string for the new column. required: Whether the new column is required. + default_value: Default value for the new column. Returns: This for method chaining. @@ -177,10 +184,6 @@ def add_column( raise ValueError(f"Cannot add column with ambiguous name: {path}, provide a tuple instead") path = (path,) - if required and not self._allow_incompatible_changes: - # Table format version 1 and 2 cannot add required column because there is no initial value - raise ValueError(f"Incompatible change: cannot add required column: {'.'.join(path)}") - name = path[-1] parent = path[:-1] @@ -212,13 +215,34 @@ def add_column( # assign new IDs in order new_id = self.assign_new_column_id() + new_type = assign_fresh_schema_ids(field_type, self.assign_new_column_id) + + if default_value is not None: + try: + # To make sure that the value is valid for the type + initial_default = literal(default_value).to(new_type).value + except ValueError as e: + raise ValueError(f"Invalid default value: {e}") from e + else: + initial_default = default_value + + if (required and initial_default is None) and not self._allow_incompatible_changes: + # Table format version 1 and 2 cannot add required column because there is no initial value + raise ValueError(f"Incompatible change: cannot add required column: {'.'.join(path)}") # update tracking for moves self._added_name_to_id[full_name] = new_id self._id_to_parent[new_id] = parent_full_path - new_type = assign_fresh_schema_ids(field_type, self.assign_new_column_id) - field = NestedField(field_id=new_id, name=name, field_type=new_type, required=required, doc=doc) + field = NestedField( + field_id=new_id, + name=name, + field_type=new_type, + required=required, + doc=doc, + initial_default=initial_default, + write_default=initial_default, + ) if parent_id in self._adds: self._adds[parent_id].append(field) @@ -330,6 +354,7 @@ def _set_column_requirement(self, path: Union[str, Tuple[str, ...]], required: b 
field_type=updated.field_type, doc=updated.doc, required=required, + initial_default=updated.initial_default, ) else: self._updates[field.field_id] = NestedField( @@ -338,6 +363,7 @@ def _set_column_requirement(self, path: Union[str, Tuple[str, ...]], required: b field_type=field.field_type, doc=field.doc, required=required, + initial_default=field.initial_default, ) def update_column( @@ -387,6 +413,7 @@ def update_column( field_type=field_type or updated.field_type, doc=doc if doc is not None else updated.doc, required=updated.required, + initial_default=updated.initial_default, ) else: self._updates[field.field_id] = NestedField( @@ -395,6 +422,7 @@ def update_column( field_type=field_type or field.field_type, doc=doc if doc is not None else field.doc, required=field.required, + initial_default=field.initial_default, ) if required is not None: diff --git a/tests/integration/test_rest_schema.py b/tests/integration/test_rest_schema.py index 6a704839e2..452900e1b4 100644 --- a/tests/integration/test_rest_schema.py +++ b/tests/integration/test_rest_schema.py @@ -27,6 +27,7 @@ from pyiceberg.table.sorting import SortField, SortOrder from pyiceberg.table.update.schema import UpdateSchema from pyiceberg.transforms import IdentityTransform +from pyiceberg.typedef import EMPTY_DICT, Properties from pyiceberg.types import ( BinaryType, BooleanType, @@ -69,7 +70,7 @@ def simple_table(catalog: Catalog, table_schema_simple: Schema) -> Table: return _create_table_with_schema(catalog, table_schema_simple) -def _create_table_with_schema(catalog: Catalog, schema: Schema) -> Table: +def _create_table_with_schema(catalog: Catalog, schema: Schema, properties: Properties = EMPTY_DICT) -> Table: tbl_name = "default.test_schema_evolution" try: catalog.drop_table(tbl_name) @@ -78,7 +79,7 @@ def _create_table_with_schema(catalog: Catalog, schema: Schema) -> Table: return catalog.create_table( identifier=tbl_name, schema=schema, - properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json()}, + properties={TableProperties.DEFAULT_NAME_MAPPING: create_mapping_from_schema(schema).model_dump_json(), **properties}, ) @@ -1076,9 +1077,8 @@ def test_add_required_column(catalog: Catalog) -> None: schema_ = Schema(NestedField(field_id=1, name="a", field_type=BooleanType(), required=False)) table = _create_table_with_schema(catalog, schema_) update = table.update_schema() - with pytest.raises(ValueError) as exc_info: + with pytest.raises(ValueError, match="Incompatible change: cannot add required column: data"): update.add_column(path="data", field_type=IntegerType(), required=True) - assert "Incompatible change: cannot add required column: data" in str(exc_info.value) new_schema = ( UpdateSchema(transaction=table.transaction(), allow_incompatible_changes=True) @@ -1091,16 +1091,40 @@ def test_add_required_column(catalog: Catalog) -> None: ) +@pytest.mark.integration +def test_add_required_column_initial_default(catalog: Catalog) -> None: + schema_ = Schema(NestedField(field_id=1, name="a", field_type=BooleanType(), required=False)) + table = _create_table_with_schema(catalog, schema_) + new_schema = ( + UpdateSchema(transaction=table.transaction()) + .add_column(path="data", field_type=IntegerType(), required=True, default_value=22) + ._apply() + ) + assert new_schema == Schema( + NestedField(field_id=1, name="a", field_type=BooleanType(), required=False), + NestedField(field_id=2, name="data", field_type=IntegerType(), required=True, initial_default=22, write_default=22), + 
        schema_id=1,
    )
 
+
+@pytest.mark.integration
+def test_add_required_column_initial_default_invalid_value(catalog: Catalog) -> None:
+    schema_ = Schema(NestedField(field_id=1, name="a", field_type=BooleanType(), required=False))
+    table = _create_table_with_schema(catalog, schema_)
+    update = table.update_schema()
+    with pytest.raises(ValueError, match="Invalid default value: Could not convert abc into a int"):
+        update.add_column(path="data", field_type=IntegerType(), required=True, default_value="abc")
+
+
 @pytest.mark.integration
 def test_add_required_column_case_insensitive(catalog: Catalog) -> None:
     schema_ = Schema(NestedField(field_id=1, name="id", field_type=BooleanType(), required=False))
     table = _create_table_with_schema(catalog, schema_)
-    with pytest.raises(ValueError) as exc_info:
+    with pytest.raises(ValueError, match="already exists: ID"):
         with table.transaction() as txn:
             with txn.update_schema(allow_incompatible_changes=True) as update:
                 update.case_sensitive(False).add_column(path="ID", field_type=IntegerType(), required=True)
-    assert "already exists: ID" in str(exc_info.value)
 
     new_schema = (
         UpdateSchema(transaction=table.transaction(), allow_incompatible_changes=True)

From 06e69deca5159bfe75e15e0273dfc35e3493d937 Mon Sep 17 00:00:00 2001
From: Fokko
Date: Tue, 11 Mar 2025 19:51:44 +0100
Subject: [PATCH 02/10] Add missing args

---
 pyiceberg/table/update/schema.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/pyiceberg/table/update/schema.py b/pyiceberg/table/update/schema.py
index 6041a2ef6c..8b2f312e87 100644
--- a/pyiceberg/table/update/schema.py
+++ b/pyiceberg/table/update/schema.py
@@ -355,6 +355,7 @@ def _set_column_requirement(self, path: Union[str, Tuple[str, ...]], required: b
             doc=updated.doc,
             required=required,
             initial_default=updated.initial_default,
+            write_default=updated.write_default,
         )
     else:
         self._updates[field.field_id] = NestedField(
@@ -364,6 +365,7 @@ def _set_column_requirement(self, path: Union[str, Tuple[str, ...]], required: b
             doc=field.doc,
             required=required,
             initial_default=field.initial_default,
+            write_default=field.write_default,
         )
 
     def update_column(
@@ -414,6 +416,7 @@ def update_column(
             doc=doc if doc is not None else updated.doc,
             required=updated.required,
             initial_default=updated.initial_default,
+            write_default=updated.write_default,
         )
     else:
         self._updates[field.field_id] = NestedField(
@@ -423,6 +426,7 @@ def update_column(
             doc=doc if doc is not None else field.doc,
             required=field.required,
             initial_default=field.initial_default,
+            write_default=field.write_default,
         )
 
     if required is not None:

From 3da569deaa2913c71f56ccc33892dc9f99a93bd3 Mon Sep 17 00:00:00 2001
From: Fokko
Date: Mon, 17 Mar 2025 14:30:42 +0100
Subject: [PATCH 03/10] WIP

---
 pyiceberg/conversions.py              | 67 ++++++++++++++++++++++++++-
 pyiceberg/expressions/literals.py     |  2 +
 pyiceberg/table/update/schema.py      | 63 +++++++++++++++++++++++++
 pyiceberg/types.py                    | 24 ++++++++--
 tests/integration/test_rest_schema.py | 61 ++++++++++++++++++++++++-
 5 files changed, 212 insertions(+), 5 deletions(-)

diff --git a/pyiceberg/conversions.py b/pyiceberg/conversions.py
index de67cdfff0..e6bcc558b5 100644
--- a/pyiceberg/conversions.py
+++ b/pyiceberg/conversions.py
@@ -20,6 +20,7 @@
     - Converting partition strings to built-in python objects.
     - Converting a value to a byte buffer.
     - Converting a byte buffer to a value.
+    - Converting a JSON single-value serialized field to a built-in python object.
 
 Note:
     Conversion logic varies based on the PrimitiveType implementation.
    Therefore conversion functions
 
@@ -59,7 +60,7 @@
     UUIDType,
     strtobool,
 )
-from pyiceberg.utils.datetime import date_to_days, datetime_to_micros, time_to_micros
+from pyiceberg.utils.datetime import date_to_days, datetime_to_micros, time_to_micros, days_to_date, to_human_day
 from pyiceberg.utils.decimal import decimal_to_bytes, unscaled_to_decimal
 
 _BOOL_STRUCT = Struct("<?")
@@ -296,3 +297,70 @@ def _(_: PrimitiveType, b: bytes) -> bytes:
 def _(primitive_type: DecimalType, buf: bytes) -> Decimal:
     unscaled = int.from_bytes(buf, "big", signed=True)
     return unscaled_to_decimal(unscaled, primitive_type.scale)
+
+
+@singledispatch  # type: ignore
+def to_json(primitive_type: PrimitiveType, b: Any) -> L:  # type: ignore
+    """Convert bytes to a built-in python value.
+
+    https://iceberg.apache.org/spec/#json-single-value-serialization
+
+    Args:
+        primitive_type (PrimitiveType): An implementation of the PrimitiveType base class.
+        b (bytes): The bytes to convert.
+    """
+    raise TypeError(f"Cannot deserialize bytes, type {primitive_type} not supported: {str(b)}")
+
+
+@from_bytes.register(BooleanType)
+def _(_: PrimitiveType, val: str) -> bool:
+    return bool(val)
+
+
+@from_bytes.register(IntegerType)
+@from_bytes.register(LongType)
+def _(_: PrimitiveType, val: str) -> int:
+    return int(val)
+
+
+@from_bytes.register(DateType)
+def _(_: PrimitiveType, val: Union[int, date]) -> str:
+    if isinstance(val, date):
+        val = date_to_days(val)
+
+    return to_human_day(val)
+
+@from_bytes.register(TimeType)
+def _(_: PrimitiveType, time) -> str:
+    return to_huma
+
+
+@from_bytes.register(FloatType)
+def _(_: FloatType, b: bytes) -> float:
+    return _FLOAT_STRUCT.unpack(b)[0]
+
+
+@from_bytes.register(DoubleType)
+def _(_: DoubleType, b: bytes) -> float:
+    return _DOUBLE_STRUCT.unpack(b)[0]
+
+
+@from_bytes.register(StringType)
+def _(_: StringType, b: bytes) -> str:
+    return bytes(b).decode(UTF8)
+
+
+@from_bytes.register(BinaryType)
+@from_bytes.register(FixedType)
+@from_bytes.register(UUIDType)
+def _(_: PrimitiveType, b: bytes) -> bytes:
+    return b
+
+
+@from_bytes.register(DecimalType)
+def _(primitive_type: DecimalType, buf: bytes) -> Decimal:
+    unscaled = int.from_bytes(buf, "big", signed=True)
+    return unscaled_to_decimal(unscaled, primitive_type.scale)
diff --git a/pyiceberg/expressions/literals.py b/pyiceberg/expressions/literals.py
index b29d0d9e48..985ec92011 100644
--- a/pyiceberg/expressions/literals.py
+++ b/pyiceberg/expressions/literals.py
@@ -152,6 +152,8 @@ def literal(value: L) -> Literal[L]:
         return TimestampLiteral(datetime_to_micros(value))  # type: ignore
     elif isinstance(value, date):
         return DateLiteral(date_to_days(value))  # type: ignore
+    elif isinstance(value, time):
+        return DateLiteral(date_to_days(value))  # type: ignore
     else:
         raise TypeError(f"Invalid literal value: {repr(value)}")
 
diff --git a/pyiceberg/table/update/schema.py b/pyiceberg/table/update/schema.py
index 8b2f312e87..14fcf3bb48 100644
--- a/pyiceberg/table/update/schema.py
+++ b/pyiceberg/table/update/schema.py
@@ -274,6 +274,20 @@ def delete_column(self, path: Union[str, Tuple[str, ...]]) -> UpdateSchema:
 
         return self
 
+    def set_default_value(self, path: Union[str, Tuple[str, ...]], default_value: Any) -> UpdateSchema:
+        """Set the default value of a column.
+
+        Args:
+            path: The path to the column.
+            default_value: Default value for the column.
+
+        Returns:
+            The UpdateSchema with the default-value update staged.
+        """
+        self._set_column_default_value(path, default_value)
+
+        return self
+
     def rename_column(self, path_from: Union[str, Tuple[str, ...]], new_name: str) -> UpdateSchema:
         """Update the name of a column.
@@ -297,6 +310,8 @@ def rename_column(self, path_from: Union[str, Tuple[str, ...]], new_name: str) - field_type=updated.field_type, doc=updated.doc, required=updated.required, + initial_default=updated.initial_default, + write_default=updated.write_default, ) else: self._updates[field_from.field_id] = NestedField( @@ -305,6 +320,8 @@ def rename_column(self, path_from: Union[str, Tuple[str, ...]], new_name: str) - field_type=field_from.field_type, doc=field_from.doc, required=field_from.required, + initial_default=field_from.initial_default, + write_default=field_from.write_default, ) # Lookup the field because of casing @@ -368,6 +385,52 @@ def _set_column_requirement(self, path: Union[str, Tuple[str, ...]], required: b write_default=field.write_default, ) + + def _set_column_default_value(self, path: Union[str, Tuple[str, ...]], default_value: Any) -> None: + path = (path,) if isinstance(path, str) else path + name = ".".join(path) + + field = self._schema.find_field(name, self._case_sensitive) + + if default_value is not None: + try: + # To make sure that the value is valid for the type + default_value = literal(default_value).to(field.field_type).value + except ValueError as e: + raise ValueError(f"Invalid default value: {e}") from e + + if field.required and default_value != field.write_default: + # if the change is a noop, allow it even if allowIncompatibleChanges is false + return + + if not self._allow_incompatible_changes and field.required and default_value is not None: + raise ValueError(f"Cannot change change default-value of a required column to None") + + if field.field_id in self._deletes: + raise ValueError(f"Cannot update a column that will be deleted: {name}") + + if updated := self._updates.get(field.field_id): + self._updates[field.field_id] = NestedField( + field_id=updated.field_id, + name=updated.name, + field_type=updated.field_type, + doc=updated.doc, + required=updated.required, + initial_default=updated.initial_default, + write_default=default_value, + ) + else: + self._updates[field.field_id] = NestedField( + field_id=field.field_id, + name=field.name, + field_type=field.field_type, + doc=field.doc, + required=field.required, + initial_default=field.initial_default, + write_default=default_value + ) + + def update_column( self, path: Union[str, Tuple[str, ...]], diff --git a/pyiceberg/types.py b/pyiceberg/types.py index bd0eb7a5e9..36b6fdf7f8 100644 --- a/pyiceberg/types.py +++ b/pyiceberg/types.py @@ -43,14 +43,16 @@ Tuple, ) +from typing_extensions import Self + from pydantic import ( Field, PrivateAttr, SerializeAsAny, model_serializer, - model_validator, + model_validator, field_validator, ) -from pydantic_core.core_schema import ValidatorFunctionWrapHandler +from pydantic_core.core_schema import ValidatorFunctionWrapHandler, ValidationInfo from pyiceberg.exceptions import ValidationError from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel, L @@ -308,7 +310,23 @@ class NestedField(IcebergType): required: bool = Field(default=False) doc: Optional[str] = Field(default=None, repr=False) initial_default: Optional[Any] = Field(alias="initial-default", default=None, repr=False) - write_default: Optional[L] = Field(alias="write-default", default=None, repr=False) # type: ignore + write_default: Optional[L] = Field(alias="write-default", default=None, repr=False) + + # @field_validator('initial_default', mode='after') + # @classmethod + # def check_passwords_match(cls, value: str, info: ValidationInfo) -> str: + # from pyiceberg.expressions import 
literal + # if value is not None: + # return literal(value).to(info.data['field_type']).value + # return value + # + # @field_validator('write_default', mode='after') + # @classmethod + # def check_passwords_match2(cls, value: str, info: ValidationInfo) -> str: + # from pyiceberg.expressions import literal + # if value is not None: + # return literal(value).to(info.data['field_type']).value + # return value def __init__( self, diff --git a/tests/integration/test_rest_schema.py b/tests/integration/test_rest_schema.py index 452900e1b4..5e03f40411 100644 --- a/tests/integration/test_rest_schema.py +++ b/tests/integration/test_rest_schema.py @@ -15,11 +15,16 @@ # specific language governing permissions and limitations # under the License. # pylint:disable=redefined-outer-name +from datetime import date, time, datetime, timezone +from decimal import Decimal +from typing import Any +from uuid import UUID import pytest from pyiceberg.catalog import Catalog, load_catalog from pyiceberg.exceptions import CommitFailedException, NoSuchTableError, ValidationError +from pyiceberg.expressions import literal from pyiceberg.partitioning import PartitionField, PartitionSpec from pyiceberg.schema import Schema, prune_columns from pyiceberg.table import Table, TableProperties @@ -1092,6 +1097,47 @@ def test_add_required_column(catalog: Catalog) -> None: @pytest.mark.integration +@pytest.mark.parametrize( + "iceberg_type, default_value, write_default", + [ + (BooleanType(), True, False), + (IntegerType(), 123, 456), + (LongType(), 123, 456), + (FloatType(), 19.25, 22.27), + (DoubleType(), 19.25, 22.27), + (DecimalType(10, 2), Decimal("19.25"), Decimal("22.27")), + (DecimalType(100, 2), Decimal("19.25"), Decimal("22.27")), + (StringType(), "abc", "def"), + (DateType(), date(1990, 3, 1), date(1991, 3, 1)), + (TimeType(), time(19, 25, 22), time(22, 25, 22)), + (TimestampType(), datetime(1990, 5, 1, 22, 1, 1), datetime(2000, 5, 1, 22, 1, 1)), + (TimestamptzType(), datetime(1990, 5, 1, 22, 1, 1, tzinfo=timezone.utc), datetime(2000, 5, 1, 22, 1, 1, tzinfo=timezone.utc)), + (BinaryType(), b"123", b"456"), + (FixedType(4), b"1234", b"5678"), + (UUIDType(), UUID(int=0x12345678123456781234567812345678), UUID(int=0x32145678123456781234567812345678)), + ] +) +def test_initial_default_all_columns(catalog: Catalog, iceberg_type: PrimitiveType, default_value: Any, write_default: Any) -> None: + # Round trips all the types through the rest catalog to check the serialization + table = _create_table_with_schema(catalog, Schema(), properties={TableProperties.FORMAT_VERSION: 3}) + + with table.update_schema() as tx: + tx.add_column(path="data", field_type=iceberg_type, required=True, default_value=default_value) + + field = table.schema().find_field(1) + physical_type = literal(default_value).to(iceberg_type).value + assert physical_type == field.initial_default + assert physical_type == field.write_default + + with table.update_schema() as tx: + tx.set_default_value("data", write_default) + + field = table.schema().find_field(1) + write_physical_type = literal(default_value).to(iceberg_type).value + assert physical_type == field.initial_default + assert write_physical_type == field.write_default + + def test_add_required_column_initial_default(catalog: Catalog) -> None: schema_ = Schema(NestedField(field_id=1, name="a", field_type=BooleanType(), required=False)) table = _create_table_with_schema(catalog, schema_) @@ -1101,11 +1147,24 @@ def test_add_required_column_initial_default(catalog: Catalog) -> None: ._apply() ) assert 
new_schema == Schema( - NestedField(field_id=1, name="a", field_type=BooleanType(), required=False), + NestedField(field_id=1, name="a", field_type=BooleanType(), required=True, initial_default=True), NestedField(field_id=2, name="data", field_type=IntegerType(), required=True, initial_default=22, write_default=22), schema_id=1, ) + # Update + new_schema = ( + UpdateSchema(transaction=table.transaction()) + .update_column(path="data", field_type=LongType()) + .rename_column("a", "bool") + ._apply() + ) + assert new_schema == Schema( + NestedField(field_id=1, name="bool", field_type=BooleanType(), required=False), + NestedField(field_id=2, name="data", field_type=LongType(), required=True, initial_default=22, write_default=22), + schema_id=1, + ) + @pytest.mark.integration def test_add_required_column_initial_default_invalid_value(catalog: Catalog) -> None: From d88e80ee43371999e1b7ab846f4d1a4267b6b14e Mon Sep 17 00:00:00 2001 From: Fokko Date: Tue, 18 Mar 2025 11:49:57 +0100 Subject: [PATCH 04/10] WIP --- pyiceberg/conversions.py | 222 +++++++++++++++++++++----- pyiceberg/expressions/literals.py | 3 +- pyiceberg/table/update/schema.py | 6 +- pyiceberg/types.py | 24 +-- tests/integration/test_rest_schema.py | 16 +- tests/test_conversions.py | 49 ++++++ 6 files changed, 252 insertions(+), 68 deletions(-) diff --git a/pyiceberg/conversions.py b/pyiceberg/conversions.py index e6bcc558b5..065c16e54f 100644 --- a/pyiceberg/conversions.py +++ b/pyiceberg/conversions.py @@ -29,6 +29,7 @@ implementations that share the same conversion logic, registrations can be stacked. """ +import codecs import uuid from datetime import date, datetime, time from decimal import Decimal @@ -60,7 +61,23 @@ UUIDType, strtobool, ) -from pyiceberg.utils.datetime import date_to_days, datetime_to_micros, time_to_micros, days_to_date, to_human_day +from pyiceberg.utils.datetime import ( + date_str_to_days, + date_to_days, + datetime_to_micros, + days_to_date, + micros_to_time, + micros_to_timestamp, + micros_to_timestamptz, + time_str_to_micros, + time_to_micros, + timestamp_to_micros, + timestamptz_to_micros, + to_human_day, + to_human_time, + to_human_timestamp, + to_human_timestamptz, +) from pyiceberg.utils.decimal import decimal_to_bytes, unscaled_to_decimal _BOOL_STRUCT = Struct(" L: # type: ignore primitive_type (PrimitiveType): An implementation of the PrimitiveType base class. b (bytes): The bytes to convert. """ - raise TypeError(f"Cannot deserialize bytes, type {primitive_type} not supported: {str(b)}") + raise TypeError(f"Cannot deserialize bytes, type {primitive_type} not supported: {b!r}") @from_bytes.register(BooleanType) @@ -328,64 +345,197 @@ def _(primitive_type: DecimalType, buf: bytes) -> Decimal: @singledispatch # type: ignore -def to_json(primitive_type: PrimitiveType, b: Any) -> L: # type: ignore - """Convert bytes to a built-in python value. +def to_json(primitive_type: PrimitiveType, val: Any) -> L: # type: ignore + """Convert built-in python values into JSON value types. https://iceberg.apache.org/spec/#json-single-value-serialization Args: primitive_type (PrimitiveType): An implementation of the PrimitiveType base class. - b (bytes): The bytes to convert. 
+        val (Any): The arbitrary built-in value to convert into the right form.
+    """
+    raise TypeError(f"Cannot serialize value, type {primitive_type} not supported: {val}")
+
+
+@to_json.register(BooleanType)
+def _(_: BooleanType, val: bool) -> bool:
+    """Python bool automatically converts into a JSON bool."""
+    return val
+
+
+@to_json.register(IntegerType)
+@to_json.register(LongType)
+def _(_: Union[IntegerType, LongType], val: int) -> int:
+    """Python int automatically converts to a JSON int."""
+    return val
+
+
+@to_json.register(DateType)
+def _(_: DateType, val: Union[date, int]) -> str:
+    """JSON date is string encoded."""
+    if isinstance(val, date):
+        val = date_to_days(val)
+    return to_human_day(val)
+
+
+@to_json.register(TimeType)
+def _(_: TimeType, val: Union[int, time]) -> str:
+    """Python time or microseconds since epoch serializes into an ISO8601 time."""
+    if isinstance(val, time):
+        val = time_to_micros(val)
+    return to_human_time(val)
+
+
+@to_json.register(TimestampType)
+def _(_: PrimitiveType, val: Union[int, datetime]) -> str:
+    """Python datetime (without timezone) or microseconds since epoch serializes into an ISO8601 timestamp."""
+    if isinstance(val, datetime):
+        val = datetime_to_micros(val)
+
+    return to_human_timestamp(val)
+
+
+@to_json.register(TimestamptzType)
+def _(_: TimestamptzType, val: Union[int, datetime]) -> str:
+    """Python datetime (with timezone) or microseconds since epoch serializes into an ISO8601 timestamp."""
+    if isinstance(val, datetime):
+        val = datetime_to_micros(val)
+    return to_human_timestamptz(val)
+
+
+@to_json.register(FloatType)
+@to_json.register(DoubleType)
+def _(_: Union[FloatType, DoubleType], val: float) -> float:
+    """Float serializes into JSON float."""
+    return val
+
+
+@to_json.register(StringType)
+def _(_: StringType, val: str) -> str:
+    """Python string serializes into JSON string."""
+    return val
+
+
+@to_json.register(FixedType)
+def _(t: FixedType, b: bytes) -> str:
+    """Python bytes serializes into hexadecimal encoded string."""
+    if len(t) != len(b):
+        raise ValueError(f"FixedType has length {len(t)}, which is different from the value: {len(b)}")
+
+    return codecs.encode(b, "hex").decode(UTF8)
+
+
+@to_json.register(BinaryType)
+def _(_: BinaryType, b: bytes) -> str:
+    """Python bytes serializes into hexadecimal encoded string."""
+    return codecs.encode(b, "hex").decode(UTF8)
+
+
+@to_json.register(DecimalType)
+def _(_: DecimalType, val: Decimal) -> str:
+    """Python decimal serializes into a string.
+
+    Stores the string representation of the decimal value; for values with a positive
+    scale, the number of digits to the right of the decimal point is used to indicate
+    the scale, and for values with a negative scale, scientific notation is used and
+    the exponent must equal the negated scale.
+    """
+    return str(val)
+
+
+@to_json.register(UUIDType)
+def _(_: UUIDType, val: uuid.UUID) -> str:
+    """Serializes into a JSON string."""
+    return str(val)
+
+
+@singledispatch  # type: ignore
+def from_json(primitive_type: PrimitiveType, val: Any) -> L:  # type: ignore
+    """Convert JSON value types into built-in python values.
+
+    https://iceberg.apache.org/spec/#json-single-value-serialization
+
+    Args:
+        primitive_type (PrimitiveType): An implementation of the PrimitiveType base class.
+        val (Any): The arbitrary JSON value to convert into the right form.
+    """
+    raise TypeError(f"Cannot deserialize value, type {primitive_type} not supported: {str(val)}")
+
+
+@from_json.register(BooleanType)
+def _(_: BooleanType, val: bool) -> bool:
+    """JSON bool automatically converts into a Python bool."""
+    return val
+
+
+@from_json.register(IntegerType)
+@from_json.register(LongType)
+def _(_: Union[IntegerType, LongType], val: int) -> int:
+    """JSON int automatically converts to a Python int."""
+    return val
+
+
+@from_json.register(DateType)
+def _(_: DateType, val: str) -> date:
+    """JSON date is string encoded."""
+    return days_to_date(date_str_to_days(val))
+
+
+@from_json.register(TimeType)
+def _(_: TimeType, val: str) -> time:
+    """JSON ISO8601 string into Python time."""
+    return micros_to_time(time_str_to_micros(val))
+
+
+@from_json.register(TimestampType)
+def _(_: PrimitiveType, val: str) -> datetime:
+    """JSON ISO8601 string into Python datetime."""
+    return micros_to_timestamp(timestamp_to_micros(val))
+
+
+@from_json.register(TimestamptzType)
+def _(_: TimestamptzType, val: str) -> datetime:
+    """JSON ISO8601 string into Python datetime."""
+    return micros_to_timestamptz(timestamptz_to_micros(val))
+
+
+@from_json.register(FloatType)
+@from_json.register(DoubleType)
+def _(_: Union[FloatType, DoubleType], val: float) -> float:
+    """JSON float deserializes into a Python float."""
+    return val
+
+
+@from_json.register(StringType)
+def _(_: StringType, val: str) -> str:
+    """JSON string deserializes into a Python string."""
+    return val
+
+
+@from_json.register(FixedType)
+def _(t: FixedType, val: str) -> bytes:
+    """JSON hexadecimal encoded string into bytes."""
+    b = codecs.decode(val.encode(UTF8), "hex")
+
+    if len(t) != len(b):
+        raise ValueError(f"FixedType has length {len(t)}, which is different from the value: {len(b)}")
+
+    return b
+
+
+@from_json.register(BinaryType)
+def _(_: BinaryType, val: str) -> bytes:
+    """JSON hexadecimal encoded string into bytes."""
+    return codecs.decode(val.encode(UTF8), "hex")
+
+
+@from_json.register(DecimalType)
+def _(_: DecimalType, val: str) -> Decimal:
+    """JSON string into a Python decimal."""
+    return Decimal(val)
+
+
+@from_json.register(UUIDType)
+def _(_: UUIDType, val: str) -> uuid.UUID:
+    """Convert JSON string into Python UUID."""
+    return uuid.UUID(val)
diff --git a/pyiceberg/expressions/literals.py b/pyiceberg/expressions/literals.py
index 985ec92011..8c1387a5be 100644
--- a/pyiceberg/expressions/literals.py
+++ b/pyiceberg/expressions/literals.py
@@ -152,6 +152,8 @@ def literal(value: L) -> Literal[L]:
         return
TimestampLiteral(datetime_to_micros(value)) # type: ignore elif isinstance(value, date): return DateLiteral(date_to_days(value)) # type: ignore - elif isinstance(value, time): - return DateLiteral(date_to_days(value)) # type: ignore + else: raise TypeError(f"Invalid literal value: {repr(value)}") diff --git a/pyiceberg/table/update/schema.py b/pyiceberg/table/update/schema.py index 14fcf3bb48..78d22795ac 100644 --- a/pyiceberg/table/update/schema.py +++ b/pyiceberg/table/update/schema.py @@ -385,7 +385,6 @@ def _set_column_requirement(self, path: Union[str, Tuple[str, ...]], required: b write_default=field.write_default, ) - def _set_column_default_value(self, path: Union[str, Tuple[str, ...]], default_value: Any) -> None: path = (path,) if isinstance(path, str) else path name = ".".join(path) @@ -404,7 +403,7 @@ def _set_column_default_value(self, path: Union[str, Tuple[str, ...]], default_v return if not self._allow_incompatible_changes and field.required and default_value is not None: - raise ValueError(f"Cannot change change default-value of a required column to None") + raise ValueError("Cannot change change default-value of a required column to None") if field.field_id in self._deletes: raise ValueError(f"Cannot update a column that will be deleted: {name}") @@ -427,10 +426,9 @@ def _set_column_default_value(self, path: Union[str, Tuple[str, ...]], default_v doc=field.doc, required=field.required, initial_default=field.initial_default, - write_default=default_value + write_default=default_value, ) - def update_column( self, path: Union[str, Tuple[str, ...]], diff --git a/pyiceberg/types.py b/pyiceberg/types.py index 36b6fdf7f8..bd0eb7a5e9 100644 --- a/pyiceberg/types.py +++ b/pyiceberg/types.py @@ -43,16 +43,14 @@ Tuple, ) -from typing_extensions import Self - from pydantic import ( Field, PrivateAttr, SerializeAsAny, model_serializer, - model_validator, field_validator, + model_validator, ) -from pydantic_core.core_schema import ValidatorFunctionWrapHandler, ValidationInfo +from pydantic_core.core_schema import ValidatorFunctionWrapHandler from pyiceberg.exceptions import ValidationError from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel, L @@ -310,23 +308,7 @@ class NestedField(IcebergType): required: bool = Field(default=False) doc: Optional[str] = Field(default=None, repr=False) initial_default: Optional[Any] = Field(alias="initial-default", default=None, repr=False) - write_default: Optional[L] = Field(alias="write-default", default=None, repr=False) - - # @field_validator('initial_default', mode='after') - # @classmethod - # def check_passwords_match(cls, value: str, info: ValidationInfo) -> str: - # from pyiceberg.expressions import literal - # if value is not None: - # return literal(value).to(info.data['field_type']).value - # return value - # - # @field_validator('write_default', mode='after') - # @classmethod - # def check_passwords_match2(cls, value: str, info: ValidationInfo) -> str: - # from pyiceberg.expressions import literal - # if value is not None: - # return literal(value).to(info.data['field_type']).value - # return value + write_default: Optional[L] = Field(alias="write-default", default=None, repr=False) # type: ignore def __init__( self, diff --git a/tests/integration/test_rest_schema.py b/tests/integration/test_rest_schema.py index 5e03f40411..f97ea25194 100644 --- a/tests/integration/test_rest_schema.py +++ b/tests/integration/test_rest_schema.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
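The to_json/from_json pair introduced in this patch is easiest to see end to end as a
round trip, which is what the tests below exercise. A minimal sketch, assuming only the
public pyiceberg modules touched here:

    from datetime import date

    from pyiceberg import conversions
    from pyiceberg.types import DateType

    # to_json renders the value per the Iceberg JSON single-value serialization spec...
    encoded = conversions.to_json(DateType(), date(2017, 11, 16))
    assert encoded == "2017-11-16"
    # ...and from_json restores the equivalent built-in Python value.
    assert conversions.from_json(DateType(), encoded) == date(2017, 11, 16)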
# pylint:disable=redefined-outer-name -from datetime import date, time, datetime, timezone +from datetime import date, datetime, time, timezone from decimal import Decimal from typing import Any from uuid import UUID @@ -1109,15 +1109,21 @@ def test_add_required_column(catalog: Catalog) -> None: (DecimalType(100, 2), Decimal("19.25"), Decimal("22.27")), (StringType(), "abc", "def"), (DateType(), date(1990, 3, 1), date(1991, 3, 1)), - (TimeType(), time(19, 25, 22), time(22, 25, 22)), + (TimeType(), time(19, 25, 22), time(22, 25, 22)), (TimestampType(), datetime(1990, 5, 1, 22, 1, 1), datetime(2000, 5, 1, 22, 1, 1)), - (TimestamptzType(), datetime(1990, 5, 1, 22, 1, 1, tzinfo=timezone.utc), datetime(2000, 5, 1, 22, 1, 1, tzinfo=timezone.utc)), + ( + TimestamptzType(), + datetime(1990, 5, 1, 22, 1, 1, tzinfo=timezone.utc), + datetime(2000, 5, 1, 22, 1, 1, tzinfo=timezone.utc), + ), (BinaryType(), b"123", b"456"), (FixedType(4), b"1234", b"5678"), (UUIDType(), UUID(int=0x12345678123456781234567812345678), UUID(int=0x32145678123456781234567812345678)), - ] + ], ) -def test_initial_default_all_columns(catalog: Catalog, iceberg_type: PrimitiveType, default_value: Any, write_default: Any) -> None: +def test_initial_default_all_columns( + catalog: Catalog, iceberg_type: PrimitiveType, default_value: Any, write_default: Any +) -> None: # Round trips all the types through the rest catalog to check the serialization table = _create_table_with_schema(catalog, Schema(), properties={TableProperties.FORMAT_VERSION: 3}) diff --git a/tests/test_conversions.py b/tests/test_conversions.py index f57998aa4e..0eafb96602 100644 --- a/tests/test_conversions.py +++ b/tests/test_conversions.py @@ -545,3 +545,52 @@ def test_datetime_obj_to_bytes(primitive_type: PrimitiveType, value: Union[datet bytes_from_value = conversions.to_bytes(primitive_type, value) assert bytes_from_value == expected_bytes + + +@pytest.mark.parametrize( + "primitive_type, value, expected", + [ + (BooleanType(), True, True), + (IntegerType(), 34, 34), + (LongType(), 34, 34), + (FloatType(), 1.0, 1.0), + (DoubleType(), 1.0, 1.0), + (DecimalType(9, 4), Decimal("123.4500"), "123.4500"), + (DecimalType(9, 0), Decimal("2"), "2"), + (DecimalType(9, -20), Decimal("2E+20"), "2E+20"), + (DateType(), date(2017, 11, 16), "2017-11-16"), + (TimeType(), time(22, 31, 8, 123456), "22:31:08.123456"), + (TimestampType(), datetime(2017, 11, 16, 22, 31, 8, 123456), "2017-11-16T22:31:08.123456"), + (TimestamptzType(), datetime(2017, 11, 16, 22, 31, 8, 123456, tzinfo=timezone.utc), "2017-11-16T22:31:08.123456+00:00"), + (StringType(), "iceberg", "iceberg"), + (BinaryType(), b"\x01\x02\x03\xff", "010203ff"), + (FixedType(4), b"\x01\x02\x03\xff", "010203ff"), + ], +) +def test_json_single_serialization(primitive_type: PrimitiveType, value: Any, expected: Any) -> None: + json_val = conversions.to_json(primitive_type, value) + assert json_val == expected + + +@pytest.mark.parametrize( + "primitive_type, value", + [ + (BooleanType(), True), + (IntegerType(), 34), + (LongType(), 34), + (FloatType(), 1.0), + (DoubleType(), 1.0), + (DecimalType(9, 4), Decimal("123.4500")), + (DecimalType(9, 0), Decimal("2")), + (DecimalType(9, -20), Decimal("2E+20")), + (DateType(), date(2017, 11, 16)), + (TimeType(), time(22, 31, 8, 123456)), + (TimestampType(), datetime(2017, 11, 16, 22, 31, 8, 123456)), + (TimestamptzType(), datetime(2017, 11, 16, 22, 31, 8, 123456, tzinfo=timezone.utc)), + (StringType(), "iceberg"), + (BinaryType(), b"\x01\x02\x03\xff"), + (FixedType(4), 
b"\x01\x02\x03\xff"), + ], +) +def test_json_serialize_roundtrip(primitive_type: PrimitiveType, value: Any) -> None: + assert value == conversions.from_json(primitive_type, conversions.to_json(primitive_type, value)) From 92ff4fed9f7f1eb09eca9f24649b8a59b1b4d93e Mon Sep 17 00:00:00 2001 From: Fokko Date: Mon, 24 Mar 2025 22:15:11 +0100 Subject: [PATCH 05/10] Cleanup --- pyiceberg/conversions.py | 2 +- tests/integration/test_rest_schema.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pyiceberg/conversions.py b/pyiceberg/conversions.py index 1805a5e0b7..fe208b4aca 100644 --- a/pyiceberg/conversions.py +++ b/pyiceberg/conversions.py @@ -472,7 +472,7 @@ def _(_: DecimalType, val: Decimal) -> str: @to_json.register(UUIDType) def _(_: UUIDType, val: uuid.UUID) -> str: - """Serializes into a JSON string.""" + """Serialize into a JSON string.""" return str(val) diff --git a/tests/integration/test_rest_schema.py b/tests/integration/test_rest_schema.py index f97ea25194..62a54700fd 100644 --- a/tests/integration/test_rest_schema.py +++ b/tests/integration/test_rest_schema.py @@ -24,7 +24,7 @@ from pyiceberg.catalog import Catalog, load_catalog from pyiceberg.exceptions import CommitFailedException, NoSuchTableError, ValidationError -from pyiceberg.expressions import literal +from pyiceberg.expressions import literal # type: ignore from pyiceberg.partitioning import PartitionField, PartitionSpec from pyiceberg.schema import Schema, prune_columns from pyiceberg.table import Table, TableProperties @@ -1144,6 +1144,7 @@ def test_initial_default_all_columns( assert write_physical_type == field.write_default +@pytest.mark.integration def test_add_required_column_initial_default(catalog: Catalog) -> None: schema_ = Schema(NestedField(field_id=1, name="a", field_type=BooleanType(), required=False)) table = _create_table_with_schema(catalog, schema_) From 85ffac3687674cda4b0477ebebd7fc9f5a9f862c Mon Sep 17 00:00:00 2001 From: Fokko Date: Tue, 25 Mar 2025 14:37:11 +0100 Subject: [PATCH 06/10] WIP --- pyiceberg/catalog/rest.py | 2 +- pyiceberg/expressions/literals.py | 1 - pyiceberg/types.py | 28 ++++++++++++--- tests/integration/test_rest_schema.py | 49 ++++++++++++++------------- 4 files changed, 49 insertions(+), 31 deletions(-) diff --git a/pyiceberg/catalog/rest.py b/pyiceberg/catalog/rest.py index ae00454000..e92a4532c9 100644 --- a/pyiceberg/catalog/rest.py +++ b/pyiceberg/catalog/rest.py @@ -781,7 +781,7 @@ def commit_table( 504: CommitStateUnknownException, }, ) - return CommitTableResponse(**response.json()) + return CommitTableResponse.model_validate_json(response.text) @retry(**_RETRY_ARGS) def create_namespace(self, namespace: Union[str, Identifier], properties: Properties = EMPTY_DICT) -> None: diff --git a/pyiceberg/expressions/literals.py b/pyiceberg/expressions/literals.py index 8c1387a5be..b29d0d9e48 100644 --- a/pyiceberg/expressions/literals.py +++ b/pyiceberg/expressions/literals.py @@ -152,7 +152,6 @@ def literal(value: L) -> Literal[L]: return TimestampLiteral(datetime_to_micros(value)) # type: ignore elif isinstance(value, date): return DateLiteral(date_to_days(value)) # type: ignore - else: raise TypeError(f"Invalid literal value: {repr(value)}") diff --git a/pyiceberg/types.py b/pyiceberg/types.py index 8e83b011bf..d898ad3ffd 100644 --- a/pyiceberg/types.py +++ b/pyiceberg/types.py @@ -40,7 +40,7 @@ Dict, Literal, Optional, - Tuple, + Tuple, Annotated, ) from pydantic import ( @@ -48,9 +48,9 @@ PrivateAttr, SerializeAsAny, model_serializer, - model_validator, + 
model_validator, PlainSerializer, BeforeValidator, ) -from pydantic_core.core_schema import ValidatorFunctionWrapHandler +from pydantic_core.core_schema import ValidatorFunctionWrapHandler, SerializationInfo, ValidationInfo from pyiceberg.exceptions import ValidationError from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel, L, TableVersion @@ -289,6 +289,24 @@ def __eq__(self, other: Any) -> bool: return self.root == other.root if isinstance(other, DecimalType) else False +# def _serialize_default_value(v: Any, context: SerializationInfo) -> Any: +# from pyiceberg.conversions import to_json, from_json +# return v +def _deserialize_default_value(v: Any, context: ValidationInfo) -> Any: + if context.mode != 'python': + if v is not None: + from pyiceberg.conversions import from_json + return from_json(context.data.get("field_type"), v) + else: + return None + else: + return v + +# PlainSerializer(_serialize_default_value, return_type=Any), +DefaultValue = Annotated[ + Any, BeforeValidator(_deserialize_default_value) +] + class NestedField(IcebergType): """Represents a field of a struct, a map key, a map value, or a list element. @@ -317,8 +335,8 @@ class NestedField(IcebergType): field_type: SerializeAsAny[IcebergType] = Field(alias="type") required: bool = Field(default=False) doc: Optional[str] = Field(default=None, repr=False) - initial_default: Optional[Any] = Field(alias="initial-default", default=None, repr=False) - write_default: Optional[L] = Field(alias="write-default", default=None, repr=False) # type: ignore + initial_default: DefaultValue = Field(alias="initial-default", default=None, repr=False) + write_default: DefaultValue = Field(alias="write-default", default=None, repr=False) # type: ignore def __init__( self, diff --git a/tests/integration/test_rest_schema.py b/tests/integration/test_rest_schema.py index 62a54700fd..ef6604668b 100644 --- a/tests/integration/test_rest_schema.py +++ b/tests/integration/test_rest_schema.py @@ -1100,25 +1100,25 @@ def test_add_required_column(catalog: Catalog) -> None: @pytest.mark.parametrize( "iceberg_type, default_value, write_default", [ - (BooleanType(), True, False), - (IntegerType(), 123, 456), - (LongType(), 123, 456), - (FloatType(), 19.25, 22.27), - (DoubleType(), 19.25, 22.27), + # (BooleanType(), True, False), + # (IntegerType(), 123, 456), + # (LongType(), 123, 456), + # (FloatType(), 19.25, 22.27), + # (DoubleType(), 19.25, 22.27), (DecimalType(10, 2), Decimal("19.25"), Decimal("22.27")), - (DecimalType(100, 2), Decimal("19.25"), Decimal("22.27")), - (StringType(), "abc", "def"), - (DateType(), date(1990, 3, 1), date(1991, 3, 1)), - (TimeType(), time(19, 25, 22), time(22, 25, 22)), - (TimestampType(), datetime(1990, 5, 1, 22, 1, 1), datetime(2000, 5, 1, 22, 1, 1)), - ( - TimestamptzType(), - datetime(1990, 5, 1, 22, 1, 1, tzinfo=timezone.utc), - datetime(2000, 5, 1, 22, 1, 1, tzinfo=timezone.utc), - ), - (BinaryType(), b"123", b"456"), - (FixedType(4), b"1234", b"5678"), - (UUIDType(), UUID(int=0x12345678123456781234567812345678), UUID(int=0x32145678123456781234567812345678)), + # (DecimalType(10, 2), Decimal("19.25"), Decimal("22.27")), + # (StringType(), "abc", "def"), + # (DateType(), date(1990, 3, 1), date(1991, 3, 1)), + # (TimeType(), time(19, 25, 22), time(22, 25, 22)), + # (TimestampType(), datetime(1990, 5, 1, 22, 1, 1), datetime(2000, 5, 1, 22, 1, 1)), + # ( + # TimestamptzType(), + # datetime(1990, 5, 1, 22, 1, 1, tzinfo=timezone.utc), + # datetime(2000, 5, 1, 22, 1, 1, tzinfo=timezone.utc), + # ), + # 
(BinaryType(), b"123", b"456"), + # (FixedType(4), b"1234", b"5678"), + # (UUIDType(), UUID(int=0x12345678123456781234567812345678), UUID(int=0x32145678123456781234567812345678)), ], ) def test_initial_default_all_columns( @@ -1127,21 +1127,22 @@ def test_initial_default_all_columns( # Round trips all the types through the rest catalog to check the serialization table = _create_table_with_schema(catalog, Schema(), properties={TableProperties.FORMAT_VERSION: 3}) - with table.update_schema() as tx: - tx.add_column(path="data", field_type=iceberg_type, required=True, default_value=default_value) + tx = table.update_schema() + tx.add_column(path="data", field_type=iceberg_type, required=True, default_value=default_value) + tx.commit() field = table.schema().find_field(1) physical_type = literal(default_value).to(iceberg_type).value - assert physical_type == field.initial_default - assert physical_type == field.write_default + assert field.initial_default == physical_type + assert field.write_default == physical_type with table.update_schema() as tx: tx.set_default_value("data", write_default) field = table.schema().find_field(1) write_physical_type = literal(default_value).to(iceberg_type).value - assert physical_type == field.initial_default - assert write_physical_type == field.write_default + assert field.initial_default == physical_type + assert field.write_default == write_physical_type @pytest.mark.integration From c911b9b2035bba0d8f0d2ab7012ec687b61ac864 Mon Sep 17 00:00:00 2001 From: Fokko Date: Wed, 26 Mar 2025 15:15:57 +0100 Subject: [PATCH 07/10] MOARRR CODE --- pyiceberg/avro/resolver.py | 2 +- pyiceberg/conversions.py | 65 ++++++++++++++++++++------- pyiceberg/expressions/literals.py | 5 ++- pyiceberg/table/update/schema.py | 26 ++++++++--- pyiceberg/types.py | 53 +++++++++++++++------- pyiceberg/utils/schema_conversion.py | 2 +- tests/integration/test_rest_schema.py | 47 +++++++++---------- tests/test_types.py | 4 +- tests/utils/test_manifest.py | 2 +- 9 files changed, 135 insertions(+), 71 deletions(-) diff --git a/pyiceberg/avro/resolver.py b/pyiceberg/avro/resolver.py index 9ed111ff40..c4ec393513 100644 --- a/pyiceberg/avro/resolver.py +++ b/pyiceberg/avro/resolver.py @@ -290,7 +290,7 @@ def struct(self, file_schema: StructType, record_struct: Optional[IcebergType], # There is a default value if file_field.write_default is not None: # The field is not in the record, but there is a write default value - results.append((None, DefaultWriter(writer=writer, value=file_field.write_default))) # type: ignore + results.append((None, DefaultWriter(writer=writer, value=file_field.write_default))) elif file_field.required: raise ValueError(f"Field is required, and there is no write default: {file_field}") else: diff --git a/pyiceberg/conversions.py b/pyiceberg/conversions.py index fe208b4aca..b9b05a00ee 100644 --- a/pyiceberg/conversions.py +++ b/pyiceberg/conversions.py @@ -503,27 +503,47 @@ def _(_: Union[IntegerType, LongType], val: int) -> int: @from_json.register(DateType) -def _(_: DateType, val: str) -> date: +def _(_: DateType, val: Union[str, int, date]) -> date: """JSON date is string encoded.""" - return days_to_date(date_str_to_days(val)) + if isinstance(val, str): + val = date_str_to_days(val) + if isinstance(val, int): + return days_to_date(val) + else: + return val @from_json.register(TimeType) -def _(_: TimeType, val: str) -> time: +def _(_: TimeType, val: Union[str, int, time]) -> time: """JSON ISO8601 string into Python time.""" - return 
micros_to_time(time_str_to_micros(val)) + if isinstance(val, str): + val = time_str_to_micros(val) + if isinstance(val, int): + return micros_to_time(val) + else: + return val @from_json.register(TimestampType) -def _(_: PrimitiveType, val: str) -> datetime: +def _(_: PrimitiveType, val: Union[str, int, datetime]) -> datetime: """JSON ISO8601 string into Python datetime.""" - return micros_to_timestamp(timestamp_to_micros(val)) + if isinstance(val, str): + val = timestamp_to_micros(val) + if isinstance(val, int): + return micros_to_timestamp(val) + else: + return val @from_json.register(TimestamptzType) -def _(_: TimestamptzType, val: str) -> datetime: +def _(_: TimestamptzType, val: Union[str, int, datetime]) -> datetime: """JSON ISO8601 string into Python datetime.""" - return micros_to_timestamptz(timestamptz_to_micros(val)) + if isinstance(val, str): + val = timestamptz_to_micros(val) + if isinstance(val, int): + return micros_to_timestamptz(val) + else: + return val @from_json.register(FloatType) @@ -540,20 +560,26 @@ def _(_: StringType, val: str) -> str: @from_json.register(FixedType) -def _(t: FixedType, val: str) -> bytes: +def _(t: FixedType, val: Union[str, bytes]) -> bytes: """JSON hexadecimal encoded string into bytes.""" - b = codecs.decode(val.encode(UTF8), "hex") + if isinstance(val, str): + b = codecs.decode(val.encode(UTF8), "hex") - if len(t) != len(b): - raise ValueError(f"FixedType has length {len(t)}, which is different from the value: {len(b)}") + if len(t) != len(b): + raise ValueError(f"FixedType has length {len(t)}, which is different from the value: {len(b)}") - return b + return b + else: + return val @from_json.register(BinaryType) -def _(_: BinaryType, val: str) -> bytes: +def _(_: BinaryType, val: Union[bytes, str]) -> bytes: """JSON hexadecimal encoded string into bytes.""" - return codecs.decode(val.encode(UTF8), "hex") + if isinstance(val, str): + return codecs.decode(val.encode(UTF8), "hex") + else: + return val @from_json.register(DecimalType) @@ -563,6 +589,11 @@ def _(_: DecimalType, val: str) -> Decimal: @from_json.register(UUIDType) -def _(_: UUIDType, val: str) -> uuid.UUID: +def _(_: UUIDType, val: Union[str, bytes, uuid.UUID]) -> uuid.UUID: """Convert JSON string into Python UUID.""" - return uuid.UUID(val) + if isinstance(val, str): + return uuid.UUID(val) + elif isinstance(val, bytes): + return uuid.UUID(bytes=val) + else: + return val diff --git a/pyiceberg/expressions/literals.py b/pyiceberg/expressions/literals.py index b29d0d9e48..a3b5242b97 100644 --- a/pyiceberg/expressions/literals.py +++ b/pyiceberg/expressions/literals.py @@ -23,7 +23,7 @@ import struct from abc import ABC, abstractmethod -from datetime import date, datetime +from datetime import date, datetime, time from decimal import ROUND_HALF_UP, Decimal from functools import singledispatchmethod from math import isnan @@ -54,6 +54,7 @@ datetime_to_micros, micros_to_days, time_str_to_micros, + time_to_micros, timestamp_to_micros, timestamptz_to_micros, ) @@ -152,6 +153,8 @@ def literal(value: L) -> Literal[L]: return TimestampLiteral(datetime_to_micros(value)) # type: ignore elif isinstance(value, date): return DateLiteral(date_to_days(value)) # type: ignore + elif isinstance(value, time): + return TimeLiteral(time_to_micros(value)) else: raise TypeError(f"Invalid literal value: {repr(value)}") diff --git a/pyiceberg/table/update/schema.py b/pyiceberg/table/update/schema.py index 78d22795ac..075298c504 100644 --- a/pyiceberg/table/update/schema.py +++ 
b/pyiceberg/table/update/schema.py @@ -398,11 +398,11 @@ def _set_column_default_value(self, path: Union[str, Tuple[str, ...]], default_v except ValueError as e: raise ValueError(f"Invalid default value: {e}") from e - if field.required and default_value != field.write_default: + if field.required and default_value == field.write_default: # if the change is a noop, allow it even if allowIncompatibleChanges is false return - if not self._allow_incompatible_changes and field.required and default_value is not None: + if not self._allow_incompatible_changes and field.required and default_value is None: raise ValueError("Cannot change change default-value of a required column to None") if field.field_id in self._deletes: @@ -729,19 +729,35 @@ def struct(self, struct: StructType, field_results: List[Optional[IcebergType]]) name = field.name doc = field.doc required = field.required + write_default = field.write_default # There is an update if update := self._updates.get(field.field_id): name = update.name doc = update.doc required = update.required - - if field.name == name and field.field_type == result_type and field.required == required and field.doc == doc: + write_default = update.write_default + + if ( + field.name == name + and field.field_type == result_type + and field.required == required + and field.doc == doc + and field.write_default == write_default + ): new_fields.append(field) else: has_changes = True new_fields.append( - NestedField(field_id=field.field_id, name=name, field_type=result_type, required=required, doc=doc) + NestedField( + field_id=field.field_id, + name=name, + field_type=result_type, + required=required, + doc=doc, + initial_default=field.initial_default, + write_default=write_default, + ) ) if has_changes: diff --git a/pyiceberg/types.py b/pyiceberg/types.py index d898ad3ffd..b4bf01bd91 100644 --- a/pyiceberg/types.py +++ b/pyiceberg/types.py @@ -35,22 +35,24 @@ import re from functools import cached_property from typing import ( + Annotated, Any, ClassVar, Dict, Literal, Optional, - Tuple, Annotated, + Tuple, ) from pydantic import ( + BeforeValidator, Field, PrivateAttr, SerializeAsAny, model_serializer, - model_validator, PlainSerializer, BeforeValidator, + model_validator, ) -from pydantic_core.core_schema import ValidatorFunctionWrapHandler, SerializationInfo, ValidationInfo +from pydantic_core.core_schema import ValidationInfo, ValidatorFunctionWrapHandler from pyiceberg.exceptions import ValidationError from pyiceberg.typedef import IcebergBaseModel, IcebergRootModel, L, TableVersion @@ -289,24 +291,21 @@ def __eq__(self, other: Any) -> bool: return self.root == other.root if isinstance(other, DecimalType) else False -# def _serialize_default_value(v: Any, context: SerializationInfo) -> Any: -# from pyiceberg.conversions import to_json, from_json -# return v def _deserialize_default_value(v: Any, context: ValidationInfo) -> Any: - if context.mode != 'python': - if v is not None: - from pyiceberg.conversions import from_json - return from_json(context.data.get("field_type"), v) - else: - return None + if v is not None: + from pyiceberg.conversions import from_json + + return from_json(context.data.get("field_type"), v) else: - return v + return None + -# PlainSerializer(_serialize_default_value, return_type=Any), DefaultValue = Annotated[ - Any, BeforeValidator(_deserialize_default_value) + L, + BeforeValidator(_deserialize_default_value), ] + class NestedField(IcebergType): """Represents a field of a struct, a map key, a map value, or a list element. 
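The DefaultValue annotation above leans on pydantic's BeforeValidator: the serialized
JSON form of a default is converted into the matching Python value before regular field
validation runs. A self-contained sketch of that pattern, independent of pyiceberg (the
IntLike/Point names are illustrative only):

    from typing import Annotated, Any

    from pydantic import BaseModel, BeforeValidator

    def _parse_int(v: Any) -> Any:
        # Runs before pydantic's own validation, analogous to
        # _deserialize_default_value turning a JSON default into a Python value.
        return int(v) if isinstance(v, str) else v

    IntLike = Annotated[int, BeforeValidator(_parse_int)]

    class Point(BaseModel):
        x: IntLike

    assert Point(x="3").x == 3  # coerced by the BeforeValidator
    assert Point(x=4).x == 4  # passed through untouched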
@@ -335,8 +334,8 @@ class NestedField(IcebergType): field_type: SerializeAsAny[IcebergType] = Field(alias="type") required: bool = Field(default=False) doc: Optional[str] = Field(default=None, repr=False) - initial_default: DefaultValue = Field(alias="initial-default", default=None, repr=False) - write_default: DefaultValue = Field(alias="write-default", default=None, repr=False) # type: ignore + initial_default: Optional[DefaultValue] = Field(alias="initial-default", default=None, repr=False) # type: ignore + write_default: Optional[DefaultValue] = Field(alias="write-default", default=None, repr=False) # type: ignore def __init__( self, @@ -360,6 +359,26 @@ def __init__( data["write-default"] = data["write-default"] if "write-default" in data else write_default super().__init__(**data) + @model_serializer() + def serialize_model(self) -> Dict[str, Any]: + from pyiceberg.conversions import to_json + + fields = { + "id": self.field_id, + "name": self.name, + "type": self.field_type, + "required": self.required, + } + + if self.doc is not None: + fields["doc"] = self.doc + if self.initial_default is not None: + fields["initial-default"] = to_json(self.field_type, self.initial_default) + if self.write_default is not None: + fields["write-default"] = to_json(self.field_type, self.write_default) + + return fields + def __str__(self) -> str: """Return the string representation of the NestedField class.""" doc = "" if not self.doc else f" ({self.doc})" diff --git a/pyiceberg/utils/schema_conversion.py b/pyiceberg/utils/schema_conversion.py index 6959380d63..ec2fccd509 100644 --- a/pyiceberg/utils/schema_conversion.py +++ b/pyiceberg/utils/schema_conversion.py @@ -530,7 +530,7 @@ def field(self, field: NestedField, field_result: AvroType) -> AvroType: } if field.write_default is not None: - result["default"] = field.write_default # type: ignore + result["default"] = field.write_default elif field.optional: result["default"] = None diff --git a/tests/integration/test_rest_schema.py b/tests/integration/test_rest_schema.py index ef6604668b..1c30331fac 100644 --- a/tests/integration/test_rest_schema.py +++ b/tests/integration/test_rest_schema.py @@ -24,7 +24,6 @@ from pyiceberg.catalog import Catalog, load_catalog from pyiceberg.exceptions import CommitFailedException, NoSuchTableError, ValidationError -from pyiceberg.expressions import literal # type: ignore from pyiceberg.partitioning import PartitionField, PartitionSpec from pyiceberg.schema import Schema, prune_columns from pyiceberg.table import Table, TableProperties @@ -1100,25 +1099,25 @@ def test_add_required_column(catalog: Catalog) -> None: @pytest.mark.parametrize( "iceberg_type, default_value, write_default", [ - # (BooleanType(), True, False), - # (IntegerType(), 123, 456), - # (LongType(), 123, 456), - # (FloatType(), 19.25, 22.27), - # (DoubleType(), 19.25, 22.27), + (BooleanType(), True, False), + (IntegerType(), 123, 456), + (LongType(), 123, 456), + (FloatType(), 19.25, 22.27), + (DoubleType(), 19.25, 22.27), (DecimalType(10, 2), Decimal("19.25"), Decimal("22.27")), - # (DecimalType(10, 2), Decimal("19.25"), Decimal("22.27")), - # (StringType(), "abc", "def"), - # (DateType(), date(1990, 3, 1), date(1991, 3, 1)), - # (TimeType(), time(19, 25, 22), time(22, 25, 22)), - # (TimestampType(), datetime(1990, 5, 1, 22, 1, 1), datetime(2000, 5, 1, 22, 1, 1)), - # ( - # TimestamptzType(), - # datetime(1990, 5, 1, 22, 1, 1, tzinfo=timezone.utc), - # datetime(2000, 5, 1, 22, 1, 1, tzinfo=timezone.utc), - # ), - # (BinaryType(), b"123", 
b"456"), - # (FixedType(4), b"1234", b"5678"), - # (UUIDType(), UUID(int=0x12345678123456781234567812345678), UUID(int=0x32145678123456781234567812345678)), + (DecimalType(10, 2), Decimal("19.25"), Decimal("22.27")), + (StringType(), "abc", "def"), + (DateType(), date(1990, 3, 1), date(1991, 3, 1)), + (TimeType(), time(19, 25, 22), time(22, 25, 22)), + (TimestampType(), datetime(1990, 5, 1, 22, 1, 1), datetime(2000, 5, 1, 22, 1, 1)), + ( + TimestamptzType(), + datetime(1990, 5, 1, 22, 1, 1, tzinfo=timezone.utc), + datetime(2000, 5, 1, 22, 1, 1, tzinfo=timezone.utc), + ), + (BinaryType(), b"123", b"456"), + (FixedType(4), b"1234", b"5678"), + (UUIDType(), UUID(int=0x12345678123456781234567812345678), UUID(int=0x32145678123456781234567812345678)), ], ) def test_initial_default_all_columns( @@ -1132,17 +1131,15 @@ def test_initial_default_all_columns( tx.commit() field = table.schema().find_field(1) - physical_type = literal(default_value).to(iceberg_type).value - assert field.initial_default == physical_type - assert field.write_default == physical_type + assert field.initial_default == default_value + assert field.write_default == default_value with table.update_schema() as tx: tx.set_default_value("data", write_default) field = table.schema().find_field(1) - write_physical_type = literal(default_value).to(iceberg_type).value - assert field.initial_default == physical_type - assert field.write_default == write_physical_type + assert field.initial_default == default_value + assert field.write_default == write_default @pytest.mark.integration diff --git a/tests/test_types.py b/tests/test_types.py index b19df17e08..75f2fe418f 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -225,11 +225,9 @@ def test_nested_field() -> None: assert str(field_var) == str(eval(repr(field_var))) assert field_var == pickle.loads(pickle.dumps(field_var)) - with pytest.raises(pydantic_core.ValidationError) as exc_info: + with pytest.raises(pydantic_core.ValidationError, match=".*validation errors for NestedField.*"): _ = (NestedField(1, "field", StringType(), required=True, write_default=(1, "a", True)),) # type: ignore - assert "validation errors for NestedField" in str(exc_info.value) - @pytest.mark.parametrize("input_index,input_type", non_parameterized_types) @pytest.mark.parametrize("check_index,check_type", non_parameterized_types) diff --git a/tests/utils/test_manifest.py b/tests/utils/test_manifest.py index 3b1fc6f013..42ebaebb49 100644 --- a/tests/utils/test_manifest.py +++ b/tests/utils/test_manifest.py @@ -416,7 +416,7 @@ def test_write_manifest( data_file = manifest_entry.data_file - assert data_file.content is DataFileContent.DATA + assert data_file.content == DataFileContent.DATA assert ( data_file.file_path == "/home/iceberg/warehouse/nyc/taxis_partitioned/data/VendorID=null/00000-633-d8a4223e-dc97-45a1-86e1-adaba6e8abd7-00001.parquet" From 8dab08e5161bfe0e2535986484a845d899a7f35a Mon Sep 17 00:00:00 2001 From: Fokko Date: Wed, 26 Mar 2025 15:17:50 +0100 Subject: [PATCH 08/10] MY HANDS ARE FIXING TESTS --- pyiceberg/types.py | 5 +---- tests/utils/test_manifest.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/pyiceberg/types.py b/pyiceberg/types.py index b4bf01bd91..ca8cac80ed 100644 --- a/pyiceberg/types.py +++ b/pyiceberg/types.py @@ -300,10 +300,7 @@ def _deserialize_default_value(v: Any, context: ValidationInfo) -> Any: return None -DefaultValue = Annotated[ - L, - BeforeValidator(_deserialize_default_value), -] +DefaultValue = Annotated[L, 
BeforeValidator(_deserialize_default_value)]
 
 
 class NestedField(IcebergType):
diff --git a/tests/utils/test_manifest.py b/tests/utils/test_manifest.py
index 42ebaebb49..70823508ed 100644
--- a/tests/utils/test_manifest.py
+++ b/tests/utils/test_manifest.py
@@ -79,7 +79,7 @@ def test_read_manifest_entry(generated_manifest_entry_file: str) -> None:
 
     data_file = manifest_entry.data_file
 
-    assert data_file.content is DataFileContent.DATA
+    assert data_file.content == DataFileContent.DATA
     assert (
         data_file.file_path
         == "/home/iceberg/warehouse/nyc/taxis_partitioned/data/VendorID=null/00000-633-d8a4223e-dc97-45a1-86e1-adaba6e8abd7-00001.parquet"

From 8b0f8a49fd89378a2897a484aec1cc864b4f72dd Mon Sep 17 00:00:00 2001
From: Fokko
Date: Wed, 26 Mar 2025 19:10:09 +0100
Subject: [PATCH 09/10] Type default values with the literal TypeVar and support time

---
 pyiceberg/expressions/literals.py     |  2 +-
 pyiceberg/table/update/schema.py      |  7 ++++---
 pyiceberg/typedef.py                  |  4 ++--
 tests/integration/test_rest_schema.py | 24 +++++++++---------------
 4 files changed, 16 insertions(+), 21 deletions(-)

diff --git a/pyiceberg/expressions/literals.py b/pyiceberg/expressions/literals.py
index a3b5242b97..490c0ba2da 100644
--- a/pyiceberg/expressions/literals.py
+++ b/pyiceberg/expressions/literals.py
@@ -154,7 +154,7 @@ def literal(value: L) -> Literal[L]:
     elif isinstance(value, date):
         return DateLiteral(date_to_days(value))  # type: ignore
     elif isinstance(value, time):
-        return TimeLiteral(time_to_micros(value))
+        return TimeLiteral(time_to_micros(value))  # type: ignore
     else:
         raise TypeError(f"Invalid literal value: {repr(value)}")
 
diff --git a/pyiceberg/table/update/schema.py b/pyiceberg/table/update/schema.py
index 075298c504..6ad01e97f2 100644
--- a/pyiceberg/table/update/schema.py
+++ b/pyiceberg/table/update/schema.py
@@ -48,6 +48,7 @@
     UpdatesAndRequirements,
     UpdateTableMetadata,
 )
+from pyiceberg.typedef import L
 from pyiceberg.types import IcebergType, ListType, MapType, NestedField, PrimitiveType, StructType
 
 if TYPE_CHECKING:
@@ -159,7 +160,7 @@ def add_column(
         field_type: IcebergType,
         doc: Optional[str] = None,
         required: bool = False,
-        default_value: Optional[Any] = None,
+        default_value: Optional[L] = None,
     ) -> UpdateSchema:
         """Add a new column to a nested struct or add a new top-level column.
 
@@ -224,7 +225,7 @@ def add_column(
             except ValueError as e:
                 raise ValueError(f"Invalid default value: {e}") from e
         else:
-            initial_default = default_value
+            initial_default = default_value  # type: ignore
 
         if (required and initial_default is None) and not self._allow_incompatible_changes:
             # Table format version 1 and 2 cannot add required column because there is no initial value
             raise ValueError(f"Incompatible change: cannot add required column: {'.'.join(path)}")
@@ -274,7 +275,7 @@ def delete_column(self, path: Union[str, Tuple[str, ...]]) -> UpdateSchema:
 
         return self
 
-    def set_default_value(self, path: Union[str, Tuple[str, ...]], default_value: Any) -> UpdateSchema:
+    def set_default_value(self, path: Union[str, Tuple[str, ...]], default_value: Optional[L]) -> UpdateSchema:
         """Set the default value of a column.
Args:
diff --git a/pyiceberg/typedef.py b/pyiceberg/typedef.py
index 07374887a3..82d0c901aa 100644
--- a/pyiceberg/typedef.py
+++ b/pyiceberg/typedef.py
@@ -17,7 +17,7 @@
 from __future__ import annotations
 
 from abc import abstractmethod
-from datetime import date, datetime
+from datetime import date, datetime, time
 from decimal import Decimal
 from functools import lru_cache
 from typing import (
@@ -94,7 +94,7 @@ def __missing__(self, key: K) -> V:
     """A recursive dictionary type for nested structures in PyIceberg."""
 
 
 # Represents the literal value
-L = TypeVar("L", str, bool, int, float, bytes, UUID, Decimal, datetime, date, covariant=True)
+L = TypeVar("L", str, bool, int, float, bytes, UUID, Decimal, datetime, date, time, covariant=True)
 
 
 @runtime_checkable
diff --git a/tests/integration/test_rest_schema.py b/tests/integration/test_rest_schema.py
index 1c30331fac..8f45e3d029 100644
--- a/tests/integration/test_rest_schema.py
+++ b/tests/integration/test_rest_schema.py
@@ -1145,26 +1145,20 @@ def test_initial_default_all_columns(
 @pytest.mark.integration
 def test_add_required_column_initial_default(catalog: Catalog) -> None:
     schema_ = Schema(NestedField(field_id=1, name="a", field_type=BooleanType(), required=False))
-    table = _create_table_with_schema(catalog, schema_)
-    new_schema = (
-        UpdateSchema(transaction=table.transaction())
-        .add_column(path="data", field_type=IntegerType(), required=True, default_value=22)
-        ._apply()
-    )
-    assert new_schema == Schema(
-        NestedField(field_id=1, name="a", field_type=BooleanType(), required=True, initial_default=True),
+    table = _create_table_with_schema(catalog, schema_, properties={TableProperties.FORMAT_VERSION: 3})
+
+    table.update_schema().add_column(path="data", field_type=IntegerType(), required=True, default_value=22).commit()
+
+    assert table.schema() == Schema(
+        NestedField(field_id=1, name="a", field_type=BooleanType(), required=False),
         NestedField(field_id=2, name="data", field_type=IntegerType(), required=True, initial_default=22, write_default=22),
         schema_id=1,
     )
 
     # Update
-    new_schema = (
-        UpdateSchema(transaction=table.transaction())
-        .update_column(path="data", field_type=LongType())
-        .rename_column("a", "bool")
-        ._apply()
-    )
-    assert new_schema == Schema(
+    table.update_schema().update_column(path="data", field_type=LongType()).rename_column("a", "bool").commit()
+
+    assert table.schema() == Schema(
         NestedField(field_id=1, name="bool", field_type=BooleanType(), required=False),
         NestedField(field_id=2, name="data", field_type=LongType(), required=True, initial_default=22, write_default=22),
         schema_id=1,

From 2b6bad20fec032bc6900df9b82180e8be3a4f63c Mon Sep 17 00:00:00 2001
From: Fokko
Date: Tue, 22 Apr 2025 21:09:47 +0200
Subject: [PATCH 10/10] Address review comments

---
 pyiceberg/conversions.py              | 10 ++++------
 tests/integration/test_rest_schema.py | 26 +++++++++++++++++---------
 2 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/pyiceberg/conversions.py b/pyiceberg/conversions.py
index b9b05a00ee..7bf7b462e2 100644
--- a/pyiceberg/conversions.py
+++ b/pyiceberg/conversions.py
@@ -563,14 +563,12 @@ def _(_: StringType, val: str) -> str:
 def _(t: FixedType, val: Union[str, bytes]) -> bytes:
     """JSON hexadecimal encoded string into bytes."""
     if isinstance(val, str):
-        b = codecs.decode(val.encode(UTF8), "hex")
+        val = codecs.decode(val.encode(UTF8), "hex")
 
-        if len(t) != len(b):
-            raise ValueError(f"FixedType has length {len(t)}, which is different from the value: {len(b)}")
+    if len(t) != len(val):
+        raise ValueError(f"FixedType 
has length {len(t)}, which is different from the value: {len(val)}") - return b - else: - return val + return val @from_json.register(BinaryType) diff --git a/tests/integration/test_rest_schema.py b/tests/integration/test_rest_schema.py index f8ae993fa1..4462da1c8c 100644 --- a/tests/integration/test_rest_schema.py +++ b/tests/integration/test_rest_schema.py @@ -1095,7 +1095,7 @@ def test_add_required_column(catalog: Catalog) -> None: @pytest.mark.integration @pytest.mark.parametrize( - "iceberg_type, default_value, write_default", + "iceberg_type, initial_default, write_default", [ (BooleanType(), True, False), (IntegerType(), 123, 456), @@ -1119,25 +1119,33 @@ def test_add_required_column(catalog: Catalog) -> None: ], ) def test_initial_default_all_columns( - catalog: Catalog, iceberg_type: PrimitiveType, default_value: Any, write_default: Any + catalog: Catalog, iceberg_type: PrimitiveType, initial_default: Any, write_default: Any ) -> None: # Round trips all the types through the rest catalog to check the serialization table = _create_table_with_schema(catalog, Schema(), properties={TableProperties.FORMAT_VERSION: 3}) tx = table.update_schema() - tx.add_column(path="data", field_type=iceberg_type, required=True, default_value=default_value) + tx.add_column(path="data", field_type=iceberg_type, required=True, default_value=initial_default) + tx.add_column(path="nested", field_type=StructType(), required=False) tx.commit() - field = table.schema().find_field(1) - assert field.initial_default == default_value - assert field.write_default == default_value + tx = table.update_schema() + tx.add_column(path=("nested", "data"), field_type=iceberg_type, required=True, default_value=initial_default) + tx.commit() + + for field_id in [1, 3]: + field = table.schema().find_field(field_id) + assert field.initial_default == initial_default + assert field.write_default == initial_default with table.update_schema() as tx: tx.set_default_value("data", write_default) + tx.set_default_value(("nested", "data"), write_default) - field = table.schema().find_field(1) - assert field.initial_default == default_value - assert field.write_default == write_default + for field_id in [1, 3]: + field = table.schema().find_field(field_id) + assert field.initial_default == initial_default + assert field.write_default == write_default @pytest.mark.integration
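
---

For reference, a minimal usage sketch of the default-value API this series adds. The catalog name and table identifier below are illustrative placeholders, not part of the patches; the calls themselves (add_column with default_value, set_default_value) and the format-version 3 requirement are taken directly from the diffs and integration tests above.

    from pyiceberg.catalog import load_catalog
    from pyiceberg.types import IntegerType

    catalog = load_catalog("default")  # any configured catalog
    table = catalog.load_table("default.events")  # assumes an existing format-version 3 table

    # Adding a required column is now allowed when a default is supplied:
    # existing rows read back the initial-default (42), and new writes fall
    # back to the write-default, which starts out equal to the initial-default.
    with table.update_schema() as update:
        update.add_column(path="retries", field_type=IntegerType(), required=True, default_value=42)

    # Changing the write-default later leaves the initial-default untouched,
    # mirroring the assertions in test_initial_default_all_columns.
    with table.update_schema() as update:
        update.set_default_value("retries", 0)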