Update-schema: Add support for initial-default #1770


Merged: 15 commits, Apr 24, 2025
2 changes: 1 addition & 1 deletion pyiceberg/avro/resolver.py
@@ -290,7 +290,7 @@ def struct(self, file_schema: StructType, record_struct: Optional[IcebergType],
# There is a default value
if file_field.write_default is not None:
# The field is not in the record, but there is a write default value
results.append((None, DefaultWriter(writer=writer, value=file_field.write_default))) # type: ignore
results.append((None, DefaultWriter(writer=writer, value=file_field.write_default)))
elif file_field.required:
Comment on lines +293 to 294
Contributor:

Are there tests for the round-trip writing/reading of default values? Or are we doing that separately, with this PR focusing only on the schema-update changes?

Contributor Author:

That's covered in PR #1644, but I thought splitting the schema-update changes out into a separate PR might make them easier to review 👍

raise ValueError(f"Field is required, and there is no write default: {file_field}")
else:
63 changes: 46 additions & 17 deletions pyiceberg/conversions.py
@@ -503,27 +503,47 @@ def _(_: Union[IntegerType, LongType], val: int) -> int:


@from_json.register(DateType)
def _(_: DateType, val: str) -> date:
def _(_: DateType, val: Union[str, int, date]) -> date:
"""JSON date is string encoded."""
return days_to_date(date_str_to_days(val))
if isinstance(val, str):
val = date_str_to_days(val)
if isinstance(val, int):
return days_to_date(val)
else:
return val


@from_json.register(TimeType)
def _(_: TimeType, val: str) -> time:
def _(_: TimeType, val: Union[str, int, time]) -> time:
"""JSON ISO8601 string into Python time."""
return micros_to_time(time_str_to_micros(val))
if isinstance(val, str):
val = time_str_to_micros(val)
if isinstance(val, int):
return micros_to_time(val)
else:
return val


@from_json.register(TimestampType)
def _(_: PrimitiveType, val: str) -> datetime:
def _(_: PrimitiveType, val: Union[str, int, datetime]) -> datetime:
"""JSON ISO8601 string into Python datetime."""
return micros_to_timestamp(timestamp_to_micros(val))
if isinstance(val, str):
val = timestamp_to_micros(val)
if isinstance(val, int):
return micros_to_timestamp(val)
else:
return val


@from_json.register(TimestamptzType)
def _(_: TimestamptzType, val: str) -> datetime:
def _(_: TimestamptzType, val: Union[str, int, datetime]) -> datetime:
"""JSON ISO8601 string into Python datetime."""
return micros_to_timestamptz(timestamptz_to_micros(val))
if isinstance(val, str):
val = timestamptz_to_micros(val)
if isinstance(val, int):
return micros_to_timestamptz(val)
else:
return val


@from_json.register(FloatType)
@@ -540,20 +560,24 @@ def _(_: StringType, val: str) -> str:


@from_json.register(FixedType)
def _(t: FixedType, val: str) -> bytes:
def _(t: FixedType, val: Union[str, bytes]) -> bytes:
"""JSON hexadecimal encoded string into bytes."""
b = codecs.decode(val.encode(UTF8), "hex")
if isinstance(val, str):
val = codecs.decode(val.encode(UTF8), "hex")

if len(t) != len(b):
raise ValueError(f"FixedType has length {len(t)}, which is different from the value: {len(b)}")
if len(t) != len(val):
raise ValueError(f"FixedType has length {len(t)}, which is different from the value: {len(val)}")

return b
return val


@from_json.register(BinaryType)
def _(_: BinaryType, val: str) -> bytes:
def _(_: BinaryType, val: Union[bytes, str]) -> bytes:
"""JSON hexadecimal encoded string into bytes."""
return codecs.decode(val.encode(UTF8), "hex")
if isinstance(val, str):
return codecs.decode(val.encode(UTF8), "hex")
else:
return val


@from_json.register(DecimalType)
@@ -563,6 +587,11 @@ def _(_: DecimalType, val: str) -> Decimal:


@from_json.register(UUIDType)
def _(_: UUIDType, val: str) -> uuid.UUID:
def _(_: UUIDType, val: Union[str, bytes, uuid.UUID]) -> uuid.UUID:
"""Convert JSON string into Python UUID."""
return uuid.UUID(val)
if isinstance(val, str):
return uuid.UUID(val)
elif isinstance(val, bytes):
return uuid.UUID(bytes=val)
else:
return val
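
These `from_json` registrations now tolerate values that are already decoded, not only JSON strings, which is what lets default values round-trip through schema metadata. A minimal sketch of the resulting behavior, assuming the dispatch shown in this diff (the literal values are illustrative):

```python
import uuid
from datetime import date

from pyiceberg.conversions import from_json
from pyiceberg.types import DateType, UUIDType

# All three input forms should resolve to the same Python date:
assert from_json(DateType(), "2021-01-01") == date(2021, 1, 1)      # ISO string
assert from_json(DateType(), 18628) == date(2021, 1, 1)             # days since epoch
assert from_json(DateType(), date(2021, 1, 1)) == date(2021, 1, 1)  # passthrough

# UUIDs may arrive as a string, as raw bytes, or already constructed:
u = uuid.uuid4()
assert from_json(UUIDType(), str(u)) == u
assert from_json(UUIDType(), u.bytes) == u
assert from_json(UUIDType(), u) == u
```
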
5 changes: 4 additions & 1 deletion pyiceberg/expressions/literals.py
@@ -23,7 +23,7 @@

import struct
from abc import ABC, abstractmethod
from datetime import date, datetime
from datetime import date, datetime, time
from decimal import ROUND_HALF_UP, Decimal
from functools import singledispatchmethod
from math import isnan
@@ -54,6 +54,7 @@
datetime_to_micros,
micros_to_days,
time_str_to_micros,
time_to_micros,
timestamp_to_micros,
timestamptz_to_micros,
)
@@ -152,6 +153,8 @@ def literal(value: L) -> Literal[L]:
return TimestampLiteral(datetime_to_micros(value)) # type: ignore
elif isinstance(value, date):
return DateLiteral(date_to_days(value)) # type: ignore
elif isinstance(value, time):
return TimeLiteral(time_to_micros(value)) # type: ignore
else:
raise TypeError(f"Invalid literal value: {repr(value)}")
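
The new `time` branch means `literal()` can now wrap a `datetime.time` as a `TimeLiteral` (microseconds since midnight), which `add_column` relies on to validate time-typed defaults. A small sketch of the assumed behavior:

```python
from datetime import time

from pyiceberg.expressions.literals import literal
from pyiceberg.types import TimeType

# 12:30:45 becomes 45_045_000_000 microseconds since midnight:
lit = literal(time(12, 30, 45))
assert lit.to(TimeType()).value == 45_045_000_000
```
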

132 changes: 121 additions & 11 deletions pyiceberg/table/update/schema.py
@@ -20,9 +20,10 @@
from copy import copy
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union

from pyiceberg.exceptions import ResolveError, ValidationError
from pyiceberg.expressions import literal # type: ignore
from pyiceberg.schema import (
PartnerAccessor,
Schema,
@@ -47,6 +48,7 @@
UpdatesAndRequirements,
UpdateTableMetadata,
)
from pyiceberg.typedef import L
from pyiceberg.types import IcebergType, ListType, MapType, NestedField, PrimitiveType, StructType

if TYPE_CHECKING:
@@ -153,7 +155,12 @@ def union_by_name(self, new_schema: Union[Schema, "pa.Schema"]) -> UpdateSchema:
return self

def add_column(
self, path: Union[str, Tuple[str, ...]], field_type: IcebergType, doc: Optional[str] = None, required: bool = False
self,
path: Union[str, Tuple[str, ...]],
field_type: IcebergType,
doc: Optional[str] = None,
required: bool = False,
default_value: Optional[L] = None,
) -> UpdateSchema:
"""Add a new column to a nested struct or Add a new top-level column.

@@ -168,6 +175,7 @@ def add_column(
field_type: Type for the new column.
doc: Documentation string for the new column.
required: Whether the new column is required.
default_value: Default value for the new column.

Returns:
This for method chaining.
@@ -177,10 +185,6 @@
raise ValueError(f"Cannot add column with ambiguous name: {path}, provide a tuple instead")
path = (path,)

if required and not self._allow_incompatible_changes:
# Table format version 1 and 2 cannot add required column because there is no initial value
raise ValueError(f"Incompatible change: cannot add required column: {'.'.join(path)}")

name = path[-1]
parent = path[:-1]

@@ -212,13 +216,34 @@

# assign new IDs in order
new_id = self.assign_new_column_id()
new_type = assign_fresh_schema_ids(field_type, self.assign_new_column_id)

if default_value is not None:
try:
# To make sure that the value is valid for the type
initial_default = literal(default_value).to(new_type).value
except ValueError as e:
raise ValueError(f"Invalid default value: {e}") from e
else:
initial_default = default_value # type: ignore

if (required and initial_default is None) and not self._allow_incompatible_changes:
# Table format version 1 and 2 cannot add required column because there is no initial value
raise ValueError(f"Incompatible change: cannot add required column: {'.'.join(path)}")

# update tracking for moves
self._added_name_to_id[full_name] = new_id
self._id_to_parent[new_id] = parent_full_path

new_type = assign_fresh_schema_ids(field_type, self.assign_new_column_id)
field = NestedField(field_id=new_id, name=name, field_type=new_type, required=required, doc=doc)
field = NestedField(
field_id=new_id,
name=name,
field_type=new_type,
required=required,
doc=doc,
initial_default=initial_default,
write_default=initial_default,
)

if parent_id in self._adds:
self._adds[parent_id].append(field)
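
With the validation above, `add_column` stores the checked value as both `initial_default` and `write_default`, and a required column becomes legal as long as a default is supplied. A hedged usage sketch (the table handle and column names are illustrative):

```python
from pyiceberg.types import StringType

# `table` is assumed to be a previously loaded pyiceberg table.
with table.update_schema() as update:
    # The value is validated against the type via literal(...).to(...)
    # and stored as both initial_default and write_default.
    update.add_column("region", StringType(), doc="deployment region", default_value="us-east-1")

# Previously rejected, now allowed because the default backfills old rows:
with table.update_schema() as update:
    update.add_column("source", StringType(), required=True, default_value="unknown")
```
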
@@ -250,6 +275,19 @@ def delete_column(self, path: Union[str, Tuple[str, ...]]) -> UpdateSchema:

return self

def set_default_value(self, path: Union[str, Tuple[str, ...]], default_value: Optional[L]) -> UpdateSchema:
"""Set the default value of a column.

Args:
path: The path to the column.
default_value: The new default value, or None to clear the default.

Returns:
The UpdateSchema with the default-value update staged.
"""
self._set_column_default_value(path, default_value)

return self
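
A short usage sketch for the new API; note that, per `_set_column_default_value` below, only the `write_default` is updated, while the `initial_default` stays untouched (names are illustrative):

```python
with table.update_schema() as update:
    update.set_default_value("region", "eu-west-1")  # future writes pick this up
    update.set_default_value("comment", None)        # clearing is fine for optional columns
```
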

def rename_column(self, path_from: Union[str, Tuple[str, ...]], new_name: str) -> UpdateSchema:
"""Update the name of a column.

@@ -273,6 +311,8 @@ def rename_column(self, path_from: Union[str, Tuple[str, ...]], new_name: str) -
field_type=updated.field_type,
doc=updated.doc,
required=updated.required,
initial_default=updated.initial_default,
write_default=updated.write_default,
)
else:
self._updates[field_from.field_id] = NestedField(
Expand All @@ -281,6 +321,8 @@ def rename_column(self, path_from: Union[str, Tuple[str, ...]], new_name: str) -
field_type=field_from.field_type,
doc=field_from.doc,
required=field_from.required,
initial_default=field_from.initial_default,
write_default=field_from.write_default,
)

# Lookup the field because of casing
@@ -330,6 +372,8 @@ def _set_column_requirement(self, path: Union[str, Tuple[str, ...]], required: b
field_type=updated.field_type,
doc=updated.doc,
required=required,
initial_default=updated.initial_default,
write_default=updated.write_default,
)
else:
self._updates[field.field_id] = NestedField(
@@ -338,6 +382,52 @@ def _set_column_requirement(self, path: Union[str, Tuple[str, ...]], required: b
field_type=field.field_type,
doc=field.doc,
required=required,
initial_default=field.initial_default,
write_default=field.write_default,
)

def _set_column_default_value(self, path: Union[str, Tuple[str, ...]], default_value: Any) -> None:
path = (path,) if isinstance(path, str) else path
name = ".".join(path)

field = self._schema.find_field(name, self._case_sensitive)

if default_value is not None:
try:
# To make sure that the value is valid for the type
default_value = literal(default_value).to(field.field_type).value
except ValueError as e:
raise ValueError(f"Invalid default value: {e}") from e

if field.required and default_value == field.write_default:
# if the change is a no-op, allow it even if allow_incompatible_changes is false
return

if not self._allow_incompatible_changes and field.required and default_value is None:
raise ValueError("Cannot change change default-value of a required column to None")

if field.field_id in self._deletes:
raise ValueError(f"Cannot update a column that will be deleted: {name}")

if updated := self._updates.get(field.field_id):
self._updates[field.field_id] = NestedField(
field_id=updated.field_id,
name=updated.name,
field_type=updated.field_type,
doc=updated.doc,
required=updated.required,
initial_default=updated.initial_default,
write_default=default_value,
)
else:
self._updates[field.field_id] = NestedField(
field_id=field.field_id,
name=field.name,
field_type=field.field_type,
doc=field.doc,
required=field.required,
initial_default=field.initial_default,
write_default=default_value,
)
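
Clearing a required column's default is rejected unless incompatible changes are explicitly allowed; a hedged sketch of the expected failure (assuming the `source` column from the earlier example):

```python
import pytest

with pytest.raises(ValueError, match="default value of a required column"):
    with table.update_schema() as update:
        update.set_default_value("source", None)
```
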

def update_column(
@@ -387,6 +477,8 @@ def update_column(
field_type=field_type or updated.field_type,
doc=doc if doc is not None else updated.doc,
required=updated.required,
initial_default=updated.initial_default,
write_default=updated.write_default,
)
else:
self._updates[field.field_id] = NestedField(
Expand All @@ -395,6 +487,8 @@ def update_column(
field_type=field_type or field.field_type,
doc=doc if doc is not None else field.doc,
required=field.required,
initial_default=field.initial_default,
write_default=field.write_default,
)

if required is not None:
Expand Down Expand Up @@ -636,19 +730,35 @@ def struct(self, struct: StructType, field_results: List[Optional[IcebergType]])
name = field.name
doc = field.doc
required = field.required
write_default = field.write_default

# There is an update
if update := self._updates.get(field.field_id):
name = update.name
doc = update.doc
required = update.required

if field.name == name and field.field_type == result_type and field.required == required and field.doc == doc:
write_default = update.write_default

if (
field.name == name
and field.field_type == result_type
and field.required == required
and field.doc == doc
and field.write_default == write_default
):
new_fields.append(field)
else:
has_changes = True
new_fields.append(
NestedField(field_id=field.field_id, name=name, field_type=result_type, required=required, doc=doc)
NestedField(
field_id=field.field_id,
name=name,
field_type=result_type,
required=required,
doc=doc,
initial_default=field.initial_default,
write_default=write_default,
)
)

if has_changes:
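
An end-to-end sketch of what this PR enables, assuming the read path applies defaults (the round-trip tests live in PR #1644, per the review thread; all names are illustrative):

```python
from pyiceberg.types import StringType

# Add a column with a default after data has already been written:
with table.update_schema() as update:
    update.add_column("region", StringType(), default_value="us-east-1")

# Rows written before the column existed surface the initial_default;
# new writes fill in the write_default via DefaultWriter (resolver.py above).
for row in table.scan().to_arrow().to_pylist():
    assert row["region"] == "us-east-1"
```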