From 3099a9aac722fbbfb160865b967bef3389f8026c Mon Sep 17 00:00:00 2001 From: Adrian Qin <147659252+jqin61@users.noreply.github.com> Date: Wed, 20 Mar 2024 17:11:45 +0000 Subject: [PATCH 1/6] validation for static overwrite with filter --- pyiceberg/table/__init__.py | 221 +++++++++++++++++++++- tests/table/test_init.py | 355 +++++++++++++++++++++++++++++++++++- 2 files changed, 564 insertions(+), 12 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 2fab4b7cf5..ad03bc0177 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -134,6 +134,42 @@ from pyiceberg.utils.concurrent import ExecutorFactory from pyiceberg.utils.datetime import datetime_to_millis +from pyiceberg.expressions import ( + AlwaysFalse, + AlwaysTrue, + And, + BooleanExpression, + BoundEqualTo, + BoundGreaterThan, + BoundGreaterThanOrEqual, + BoundIn, + BoundIsNaN, + BoundIsNull, + BoundLessThan, + BoundLessThanOrEqual, + BoundNotEqualTo, + BoundNotIn, + BoundNotNaN, + BoundNotNull, + BoundReference, + EqualTo, + GreaterThan, + GreaterThanOrEqual, + In, + IsNaN, + IsNull, + LessThan, + LessThanOrEqual, + Not, + NotEqualTo, + NotIn, + NotNaN, + NotNull, + Or, + Reference, + UnboundPredicate, +) + if TYPE_CHECKING: import daft import pandas as pd @@ -149,8 +185,75 @@ _JAVA_LONG_MAX = 9223372036854775807 - -def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None: +# to do: make the types not expression but unbound predicates, this should be more precise, we already know it could only be unboundisnull and unboundequalto +# actually could make it union[isnull, equalto] and make the return as union[boundisnull,boundequalto] +def _validate_static_overwrite_filter_field(unbound_expr: BooleanExpression, table_schema: Schema, spec:PartitionSpec)-> BoundPredicate: + # step 1: check the unbound_expr is within the schema and the value matches the schema + print(f"unbound_expr is {unbound_expr=}") + bound_expr = unbound_expr.bind(table_schema) + print(f"{bound_expr=}") + print(f"{bound_expr.term=}, {bound_expr.term.ref()=}, {bound_expr.term.ref().unbound_expr=}, {bound_expr.term.ref().unbound_expr.field_id=}") + nested_field = bound_expr.term.ref().unbound_expr + + # step 2: check the unbound_expr is within the partition spec + # this is the part unbound_expr + part_fields = spec.fields_by_source_id(nested_field.field_id) + print(f"{part_fields=}") + if len(part_fields) != 1: + raise ValueError(f"get {len(part_fields)=}, not 1, if this number is 0, indicating the static filter is not within the partition fields, which is invalid") + part_field = part_fields[0] + + + # step 3: check the unbound_expr is with identity transform + print(f"{part_field.transform=}") + if not isinstance(part_field.transform, IdentityTransform): + raise ValueError(f"static overwrite partition filter can only apply to partition fields which are without hidden transform, but get {part_field.transform=} for {bound_field.term.ref().unbound_expr=}") + + return bound_expr # nested_field + + +def _validate_static_overwrite_filter(table_schema: Schema, overwrite_filter: BooleanExpression, spec:PartitionSpec) -> None: + is_null_predicates, eq_to_predicates = _validate_static_overwrite_filter_expr_type(expr=overwrite_filter) + + bound_is_null_predicates = [] + bound_eq_to_predicates = [] + + print("---check is null fields with rules of filter,part,transform") + # these are expressions (not as name - fields) + for unbound_expr in is_null_predicates: + bp:BoundPredicate = _validate_static_overwrite_filter_field(unbound_expr = unbound_expr, table_schema = table_schema, spec = spec) + #bound_field.term.ref().field + bound_is_null_predicates.append(bp) + + print("----check eq to fields with rules of filter,part,transform") + for unbound_expr in eq_to_predicates: + bp:BoundPredicate = _validate_static_overwrite_filter_field(unbound_expr = unbound_expr, table_schema = table_schema, spec = spec) + bound_eq_to_predicates.append(bp) + + return bound_is_null_predicates, bound_eq_to_predicates + +def _fill_in_df(df: pa.Table, bound_is_null_predicates: List[BoundIsNull], bound_eq_to_predicates: List[BoundEqualTo]): + is_null_nested_fields = [bound_field.term.ref().field for bound_predicate in bound_is_null_predicates] + eq_to_nested_fields = [bound_field.term.ref().field for bound_predicate in bound_eq_to_predicates] + from itertools import chain + schema = Schema(*chain(is_null_nested_fields, eq_to_nested_fields)) + print(f"{schema=}") + from pyiceberg.io.pyarrow import schema_to_pyarrow + + pa_schema = schema_to_pyarrow(schema) + print(f"{pa_schema=}") + + is_null_nested_field_name_values = zip([nested_field.name for nested_field in is_null_nested_fields], [None]*len(bound_is_null_predicates)) + eq_to_nested_field_name_values = zip([nested_field.name for nested_field in eq_to_nested_fields], [predicate.literal for predicate in bound_eq_to_predicates]) + + num_rows = df.num_rows + for field_name, value in is_null_nested_fields + eq_to_nested_fields: + pa_field = pa_schema.field(field_name) + literal_array = pa.array([None] * num_rows, type=pa_field.type) + df = df.add_column(df.num_columns, field.name, null_array) + return df + +def _check_schema(table_schema: Schema, other_schema: "pa.Schema", to_truncate: List[NestedField]=[]): from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema name_mapping = table_schema.name_mapping @@ -162,8 +265,18 @@ def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None: raise ValueError( f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)." ) from e + print(f"{task_schema=}, {task_schema.as_struct()=}") + + # check fields dont step into itself, and do not step into each other, maybe we could move this to other 1+3(here) fields check - if table_schema.as_struct() != task_schema.as_struct(): + # check mutual + union + remaining_schema = table_schema + if len(to_truncate) != 0: + remaining_schema = _truncate_fields(table_schema, to_truncate) + + print(f"{remaining_schema.as_struct()=}") + print(f"{task_schema.as_struct()=}") + if remaining_schema.as_struct() != task_schema.as_struct(): from rich.console import Console from rich.table import Table as RichTable @@ -174,15 +287,91 @@ def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None: rich_table.add_column("Table field") rich_table.add_column("Dataframe field") + partition_filter_field_names = set([nested_field.name for nested_field in to_truncate]) + print(f"{partition_filter_field_names=}") + for lhs in table_schema.fields: - try: - rhs = task_schema.find_field(lhs.field_id) - rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs)) - except ValueError: - rich_table.add_row("❌", str(lhs), "Missing") + if lhs.name in partition_filter_field_names: + try: + rhs = task_schema.find_field(lhs.field_id) + rich_table.add_row("❌", str(lhs), "Appears in both filter and arrow table.") + except ValueError: + rich_table.add_row("✅", str(lhs), "Appears in filter but not arrow table.") + else: + try: + rhs = task_schema.find_field(lhs.field_id) + rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs)) + except ValueError: + rich_table.add_row("❌", str(lhs), "Missing") console.print(rich_table) - raise ValueError(f"Mismatch in fields:\n{console.export_text()}") + raise ValueError(f"With partition fields in static overwrite filter as {[nested_field.name for nested_field in to_truncate]}, mismatch in fields:\n{console.export_text()}") + +def _truncate_fields(table_schema: Schema, to_truncate: List[NestedField]) -> Schema: + to_truncate_fields_source_ids = set(nested_field.field_id for nested_field in to_truncate) + truncated = [field for field in table_schema.fields if field.field_id not in to_truncate_fields_source_ids] + print(f"{truncated=}") + return Schema(*truncated) + + +def _validate_static_overwrite_filter_expr_type(expr: BooleanExpression): + '''Validate whether expression only has 1)And 2)IsNull and 3)EqualTo and break down the raw expression into IsNull and EqualTo. + ''' + from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema + + def _recursively_fetch_fields(expr, is_null_predicates, eq_to_predicates) -> None: + print(f"get expression as {expr=}") + if isinstance(expr, EqualTo): + eq_to_predicates.append(expr) + elif isinstance(expr, IsNull): + is_null_predicates.append(expr) + elif isinstance(expr, And): + _recursively_fetch_fields(expr.left, is_null_predicates, eq_to_predicates) + _recursively_fetch_fields(expr.right, is_null_predicates, eq_to_predicates) + else: + raise ValueError(f"static overwrite partitioning filter can only be isequalto, is null, and, alwaysTrue, but get {expr=}") + + is_null_predicates = [] + eq_to_predicates = [] + _recursively_fetch_fields(expr, is_null_predicates, eq_to_predicates) + print(f"{is_null_predicates=}") + print(f"{eq_to_predicates=}") + return is_null_predicates, eq_to_predicates + + +# def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None: +# from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema + +# name_mapping = table_schema.name_mapping +# try: +# task_schema = pyarrow_to_schema(other_schema, name_mapping=name_mapping) +# except ValueError as e: +# other_schema = _pyarrow_to_schema_without_ids(other_schema) +# additional_names = set(other_schema.column_names) - set(table_schema.column_names) +# raise ValueError( +# f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)." +# ) from e + +# if table_schema.as_struct() != task_schema.as_struct(): +# from rich.console import Console +# from rich.table import Table as RichTable + +# console = Console(record=True) + +# rich_table = RichTable(show_header=True, header_style="bold") +# rich_table.add_column("") +# rich_table.add_column("Table field") +# rich_table.add_column("Dataframe field") + +# for lhs in table_schema.fields: +# try: +# rhs = task_schema.find_field(lhs.field_id) +# rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs)) +# except ValueError: +# rich_table.add_row("❌", str(lhs), "Missing") + +# console.print(rich_table) +# raise ValueError(f"Mismatch in fields:\n{console.export_text()}") class TableProperties: @@ -1159,8 +1348,20 @@ def overwrite(self, df: pa.Table, overwrite_filter: Union[str, BooleanExpression if not isinstance(df, pa.Table): raise ValueError(f"Expected PyArrow table, got: {df}") - _check_schema(self.schema(), other_schema=df.schema) + if not overwrite_filter == ALWAYS_TRUE: + bound_is_null_predicates, bound_eq_to_predicates = _validate_static_overwrite_filter(table_schema=self.schema(), overwrite_filter=overwrite_filter, spec=self.metadata.spec()) -> None: + print("----check fields in partition filter and in the dataframe are mutual exclusive but unioning to full schema") + + is_null_nested_fields = [bound_field.term.ref().field for bound_predicate in bound_is_null_predicates] + eq_to_nested_fields = [bound_field.term.ref().field for bound_predicate in bound_eq_to_predicates] + print(f"{is_null_nested_fields=}") + print(f"{eq_to_nested_fields=}") + _check_schema(table_schema, other_schema, to_truncate = is_null_nested_fields + eq_to_nested_fields) + _fill_in_df(df, bound_is_null_predicates, bound_eq_to_predicates) + else: + _check_schema(table_schema, other_schema) + with self.transaction() as txn: with txn.update_snapshot().overwrite(overwrite_filter) as update_snapshot: # skip writing data files if the dataframe is empty diff --git a/tests/table/test_init.py b/tests/table/test_init.py index be3a28199a..564fc94c93 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -17,12 +17,27 @@ # pylint:disable=redefined-outer-name import uuid from copy import copy -from typing import Any, Dict +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generic, + Iterable, + List, + Literal, + Optional, + Set, + Tuple, + TypeVar, + Union, +) import pyarrow as pa import pytest from pydantic import ValidationError from sortedcontainers import SortedList +import re from pyiceberg.catalog.noop import NoopCatalog from pyiceberg.exceptions import CommitFailedException @@ -67,6 +82,10 @@ _TableMetadataUpdateContext, update_table_metadata, verify_table_already_sorted, + _fetch_fields_and_validate_expression_type, + _validate_static_overwrite_filter_field, + _check_schema, + # _validate_static_overwrite_filter ) from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadataUtil, TableMetadataV2, _generate_snapshot_id from pyiceberg.table.snapshots import ( @@ -81,7 +100,7 @@ SortField, SortOrder, ) -from pyiceberg.transforms import BucketTransform, IdentityTransform +from pyiceberg.transforms import BucketTransform, IdentityTransform, TruncateTransform from pyiceberg.types import ( BinaryType, BooleanType, @@ -1013,6 +1032,338 @@ def test_correct_schema() -> None: assert "Snapshot not found: -1" in str(exc_info.value) +from pyiceberg.expressions import ( + AlwaysFalse, + AlwaysTrue, + And, + BooleanExpression, + BoundEqualTo, + BoundGreaterThan, + BoundGreaterThanOrEqual, + BoundIn, + BoundIsNaN, + BoundIsNull, + BoundLessThan, + BoundLessThanOrEqual, + BoundNotEqualTo, + BoundNotIn, + BoundNotNaN, + BoundNotNull, + BoundReference, + EqualTo, + GreaterThan, + GreaterThanOrEqual, + In, + IsNaN, + IsNull, + LessThan, + LessThanOrEqual, + Not, + NotEqualTo, + NotIn, + NotNaN, + NotNull, + Or, + Reference, + UnboundPredicate, +) + + +#_validate_static_overwrite_filter_field +@pytest.mark.french +def test__validate_static_overwrite_filter_field_fail_on_non_schema_fields_in_filter()-> None: #pred: BooleanExpression, table_schema: Schema, spec:PartitionSpec)-> None: + # todo: is it possible to make boolean expression more specific, like bound expression? no it is not bound yet + test_schema = Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + schema_id=1, + identifier_field_ids=[2], + ) + pred = EqualTo(Reference("not a field"), "hello") + partition_spec=PartitionSpec( + PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="test_part_col") + ) + with pytest.raises(ValueError, match=f"Could not find field with name {pred.term.name}, case_sensitive=True"): + _validate_static_overwrite_filter_field(field=pred, table_schema=test_schema, spec=partition_spec) + + +@pytest.mark.french +def test__validate_static_overwrite_filter_field_fail_on_non_part_fields_in_filter()-> None: #pred: BooleanExpression, table_schema: Schema, spec:PartitionSpec)-> None: + # todo: is it possible to make boolean expression more specific, like bound expression? no it is not bound yet + test_schema = Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + schema_id=1, + identifier_field_ids=[2], + ) + pred = EqualTo(Reference("foo"), "hello") + partition_spec=PartitionSpec( + PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="bar") + ) + import re + with pytest.raises(ValueError, match=re.escape("get len(part_fields)=0, not 1, if this number is 0, indicating the static filter is not within the partition fields, which is invalid")): + _validate_static_overwrite_filter_field(field=pred, table_schema=test_schema, spec=partition_spec) + +#to do add one test that the partition fields passed +@pytest.mark.french +def test__validate_static_overwrite_filter_field_fail_on_non_identity_transorm_filter()-> None: #pred: BooleanExpression, table_schema: Schema, spec:PartitionSpec)-> None: + # todo: is it possible to make boolean expression more specific, like bound expression? no it is not bound yet + test_schema = Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + schema_id=1, + identifier_field_ids=[2], + ) + pred = EqualTo(Reference("foo"), "hello") + partition_spec=PartitionSpec( + PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="bar"), + PartitionField(source_id=1, field_id=1002, transform=TruncateTransform(2), name="foo_trunc") + ) + # import re + with pytest.raises(ValueError, match="static overwrite partition filter can only apply to partition fields which are without hidden transform, but get.*"): + _validate_static_overwrite_filter_field(field=pred, table_schema=test_schema, spec=partition_spec) + +# combine this with above +@pytest.mark.french +def test__validate_static_overwrite_filter_field_succeed_on_an_identity_field_although_table_has_hidden_partition()-> None: #pred: BooleanExpression, table_schema: Schema, spec:PartitionSpec)-> None: + # todo: is it possible to make boolean expression more specific, like bound expression? no it is not bound yet + test_schema = Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + schema_id=1, + identifier_field_ids=[2], + ) + pred = EqualTo(Reference("bar"), 3) + partition_spec=PartitionSpec( + PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="bar"), + PartitionField(source_id=1, field_id=1002, transform=TruncateTransform(2), name="foo_trunc") + ) + # import re + #with pytest.raises(ValueError, match="static overwrite partition filter can only apply to partition fields which are without hidden transform, but get.*"): + _validate_static_overwrite_filter_field(field=pred, table_schema=test_schema, spec=partition_spec) + +@pytest.mark.french +def test__validate_static_overwrite_filter_field_fail_to_bind()-> None: #pred: BooleanExpression, table_schema: Schema, spec:PartitionSpec)-> None: + # todo: is it possible to make boolean expression more specific, like bound expression? no it is not bound yet + test_schema = Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + schema_id=1, + identifier_field_ids=[2], + ) + pred = EqualTo(Reference("bar"), "an incompatible type") + partition_spec=PartitionSpec( + PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="bar"), + PartitionField(source_id=1, field_id=1002, transform=TruncateTransform(2), name="foo_trunc") + ) + with pytest.raises(ValueError, match="Could not convert an incompatible type into a int"): + _validate_static_overwrite_filter_field(field=pred, table_schema=test_schema, spec=partition_spec) + +# # to do, raises should use full match to check whether it is overlapping error or it is unioning error +# # to do, in the raw code and test, should add that filter is not conflicting with itself,, e.g. the same field can only be referenced once +# # advanced: as long as the values are different +# # as to do, a filter can and a same expr multiple times +# @pytest.mark.integration +# @pytest.mark.parametrize( +# "task_field_names, eq_to_fields_names, null_field_names, schema_field_names, raises", +# [ +# # good case +# (["f1"], ["f2"], ["f3"], ["f1", "f2", "f3"], False), +# # filter has overlapping fields with arrow fields +# (["f1", "f2"], ["f2"], ["f3"], ["f1", "f2", "f3"], True), +# (["f1", "f2"], ["f3"], ["f1"], ["f1", "f2", "f3"], True), +# # filter has one field referenced multiple times +# #(["f1", "f2"], ["f3"], ["f3"], ["f1", "f2", "f3"], True), +# #(["f1", "f2"], ["f3, f3"], ["f4"], ["f1", "f2", "f3", "f4"], True), +# #(["f1", "f2"], ["f4"], ["f3, f3"], ["f1", "f2", "f3", "f4"], True), +# # shortage +# (["f1"], ["f2"], ["f3"], ["f1", "f2", "f3", "f4"], True), +# ] +# ) +# def test__validate_mutual_exclusive_and_union(task_field_names, eq_to_fields_names, null_field_names, schema_field_names, raises): +# if raises: +# with pytest.raises(ValueError): # match=expected): to do +# _check_schema(task_field_names, eq_to_fields_names, null_field_names, schema_field_names) +# else: +# _check_schema(task_field_names, eq_to_fields_names, null_field_names, schema_field_names) + +# to do the truncated is sorted or not? arrow table schema does not have to match exactly in order right while now it has to be in order. +@pytest.mark.french +def test__validate_mutual_exclusive_and_union_succeed(): + table_schema: Schema = Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + schema_id=1, + identifier_field_ids=[2], + ) + other_schema: pa.Schema = pa.schema([ + pa.field('foo', pa.string()), + pa.field('baz', pa.bool_()), + # pa.field('name', pa.string()), + # pa.field('age', pa.int32()), + # pa.field('email', pa.string()), + # pa.field('is_subscribed', pa.bool_()) + ]) + print(f"{other_schema=}") + # pa.schema([ + # pa.field("foo", pa.string(), nullable=False), + # pa.field("bar", pa.int32(), nullable=True), + # pa.field("baz", pa.bool_(), nullable=True), + # ]) + + null_nested_fields: List[NestedField] = [] + eq_to_nested_fields: List[NestedField] = [NestedField(field_id=2, name="bar", field_type=IntegerType(), required=False)] + _check_schema(table_schema, other_schema, null_nested_fields, eq_to_nested_fields) + +# to do rename: with filter +@pytest.mark.french +def test__validate_mutual_exclusive_and_union_with_filter_fail_on_missing_field(): + table_schema: Schema = Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + schema_id=1, + identifier_field_ids=[2], + ) + other_schema: pa.Schema = pa.schema([ + pa.field('baz', pa.bool_()), + ]) + + null_nested_fields: List[NestedField] = [] + eq_to_nested_fields: List[NestedField] = [NestedField(field_id=2, name="bar", field_type=IntegerType(), required=False)] + + expected = re.escape('''With partition fields in static overwrite filter as ['bar'], mismatch in fields: +┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ ┃ Table field ┃ Dataframe field ┃ +┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ ❌ │ 1: foo: optional string │ Missing │ +│ ✅ │ 2: bar: required int │ Appears in filter but not arrow table. │ +│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ +└────┴──────────────────────────┴────────────────────────────────────────┘ +''') + with pytest.raises(ValueError, match=expected): + _check_schema(table_schema, other_schema, null_nested_fields, eq_to_nested_fields) + +@pytest.mark.french +def test__validate_mutual_exclusive_and_union_with_filter_fail_on_nullability_mismatch(): + table_schema: Schema = Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + schema_id=1, + identifier_field_ids=[2], + ) + other_schema: pa.Schema = pa.schema([ + pa.field('foo', pa.string()), + pa.field('bar', pa.int32()), + ]) + + + null_nested_fields: List[NestedField] = [] + eq_to_nested_fields: List[NestedField] = [NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False)] + expected = re.escape('''With partition fields in static overwrite filter as ['baz'], mismatch in fields: +┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ ┃ Table field ┃ Dataframe field ┃ +┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ ✅ │ 1: foo: optional string │ 1: foo: optional string │ +│ ❌ │ 2: bar: required int │ 2: bar: optional int │ +│ ✅ │ 3: baz: optional boolean │ Appears in filter but not arrow table. │ +└────┴──────────────────────────┴────────────────────────────────────────┘ +''') + with pytest.raises(ValueError, match=expected): + _check_schema(table_schema, other_schema, null_nested_fields, eq_to_nested_fields) + +@pytest.mark.french +def test__validate_mutual_exclusive_and_union_with_filter_fail_on_type_mismatch(): + table_schema: Schema = Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + schema_id=1, + identifier_field_ids=[2], + ) + other_schema: pa.Schema = pa.schema([ + pa.field('foo', pa.string()), + pa.field('bar', pa.string(), nullable=False), + ]) + + null_nested_fields: List[NestedField] = [] + eq_to_nested_fields: List[NestedField] = [NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False)] + expected = re.escape('''With partition fields in static overwrite filter as ['baz'], mismatch in fields: +┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ ┃ Table field ┃ Dataframe field ┃ +┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ ✅ │ 1: foo: optional string │ 1: foo: optional string │ +│ ❌ │ 2: bar: required int │ 2: bar: required string │ +│ ✅ │ 3: baz: optional boolean │ Appears in filter but not arrow table. │ +└────┴──────────────────────────┴────────────────────────────────────────┘ +''') + with pytest.raises(ValueError, match=expected): + _check_schema(table_schema, other_schema, null_nested_fields, eq_to_nested_fields) + + +@pytest.mark.french +def test__validate_mutual_exclusive_and_union_with_field_appear_in_both_filter_and_dataframe(): + table_schema: Schema = Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + schema_id=1, + identifier_field_ids=[2], + ) + other_schema: pa.Schema = pa.schema([ + pa.field('foo', pa.string()), + pa.field('bar', pa.int32(), nullable=False), + ]) + + null_nested_fields: List[NestedField] = [NestedField(field_id=1, name="foo", field_type=StringType(), required=False)] + eq_to_nested_fields: List[NestedField] = [NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False)] + expected = re.escape('''With partition fields in static overwrite filter as ['foo', 'baz'], mismatch in fields: +┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ ┃ Table field ┃ Dataframe field ┃ +┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ ❌ │ 1: foo: optional string │ Appears in both filter and arrow table. │ +│ ✅ │ 2: bar: required int │ 2: bar: required int │ +│ ✅ │ 3: baz: optional boolean │ Appears in filter but not arrow table. │ +└────┴──────────────────────────┴─────────────────────────────────────────┘ +''') + with pytest.raises(ValueError, match=expected): + _check_schema(table_schema, other_schema, null_nested_fields, eq_to_nested_fields) + + +# to do, add tests for bind value failure?????? +@pytest.mark.integration +@pytest.mark.parametrize( + "pred, raises, is_null_list, eq_to_list", + [ + (EqualTo(Reference("foo"), "hello"), False, [], [EqualTo(Reference("foo"), "hello")]), + (IsNull(Reference("foo")), False, [IsNull(Reference("foo"))], []), + (And(IsNull(Reference("foo")),EqualTo(Reference("foo"), "hello")), False, [IsNull(Reference("foo"))], [EqualTo(Reference("foo"), "hello")]), + (NotNull, True, [], []), + (NotEqualTo, True, [], []), + (LessThan(Reference("foo"), 5), True, [], []), + (Or(IsNull(Reference("foo")),EqualTo(Reference("foo"), "hello")), True, [], []), + (And(EqualTo(Reference("foo"), "hello"), And(IsNull(Reference("foo")), EqualTo(Reference("boo"), "hello"))), False, [IsNull(Reference("foo"))], [EqualTo(Reference("foo"), "hello"), EqualTo(Reference("boo"), "hello")]), + ], +) +def test__fetch_fields_and_validate_expression_type(pred, raises, is_null_list, eq_to_list)-> None: + if raises: + with pytest.raises(ValueError): # match=expected): to do + res = _fetch_fields_and_validate_expression_type(pred) + else: + res = _fetch_fields_and_validate_expression_type(pred) + print(f":::::::::::::adrian, {res=}") + print(f"lets check the strings, {set([str(e) for e in res[1]])=}") + assert set([str(e) for e in res[0]]) == set([str(e) for e in is_null_list]) + assert set([str(e) for e in res[1]]) == set([str(e) for e in eq_to_list]) + + def test_schema_mismatch_type(table_schema_simple: Schema) -> None: other_schema = pa.schema(( From 64fbc2a2b9e8efc786b7b8bd4652388061c8b567 Mon Sep 17 00:00:00 2001 From: Adrian Qin <147659252+jqin61@users.noreply.github.com> Date: Thu, 21 Mar 2024 18:12:47 +0000 Subject: [PATCH 2/6] fix tests --- pyiceberg/expressions/__init__.py | 7 + pyiceberg/io/pyarrow.py | 5 +- pyiceberg/table/__init__.py | 192 +++++++----- tests/integration/test_partitioned_writes.py | 40 ++- tests/table/test_init.py | 302 +++++++++++-------- 5 files changed, 329 insertions(+), 217 deletions(-) diff --git a/pyiceberg/expressions/__init__.py b/pyiceberg/expressions/__init__.py index 5adf3a8a48..2c46ffb562 100644 --- a/pyiceberg/expressions/__init__.py +++ b/pyiceberg/expressions/__init__.py @@ -383,6 +383,10 @@ def __repr__(self) -> str: @abstractmethod def as_bound(self) -> Type[BoundUnaryPredicate[Any]]: ... + def __hash__(self) -> int: + return hash(str(self)) + + class BoundUnaryPredicate(BoundPredicate[L], ABC): def __repr__(self) -> str: @@ -698,6 +702,9 @@ def __repr__(self) -> str: @abstractmethod def as_bound(self) -> Type[BoundLiteralPredicate[L]]: ... + def __hash__(self) -> int: + return hash(str(self)) + class BoundLiteralPredicate(BoundPredicate[L], ABC): literal: Literal[L] diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 5c26f1f96c..43a7132175 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1725,6 +1725,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}' # generate_data_file_filename schema = table_metadata.schema() + arrow_file_schema = schema_to_pyarrow(schema) fo = io.new_output(file_path) @@ -1735,7 +1736,9 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT ) with fo.create(overwrite=True) as fos: with pq.ParquetWriter(fos, schema=arrow_file_schema, **parquet_writer_kwargs) as writer: - writer.write_table(task.df, row_group_size=row_group_size) + # align the columns accordingly in case input arrow table has columns in order different from iceberg table + df_to_write = task.df.select(arrow_file_schema.names) + writer.write_table(df_to_write, row_group_size=row_group_size) data_file = DataFile( content=DataFileContent.DATA, diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index ad03bc0177..ff2dca4360 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -191,12 +191,15 @@ def _validate_static_overwrite_filter_field(unbound_expr: BooleanExpression, tab # step 1: check the unbound_expr is within the schema and the value matches the schema print(f"unbound_expr is {unbound_expr=}") bound_expr = unbound_expr.bind(table_schema) - print(f"{bound_expr=}") - print(f"{bound_expr.term=}, {bound_expr.term.ref()=}, {bound_expr.term.ref().unbound_expr=}, {bound_expr.term.ref().unbound_expr.field_id=}") - nested_field = bound_expr.term.ref().unbound_expr - - # step 2: check the unbound_expr is within the partition spec + + # step 2: check non nullable column is not partitioned overwriten with isNull. + # It has to break cuz we cannot add null value for that column if it is non-nullable. + if isinstance(bound_expr, AlwaysFalse): + raise ValueError(f"Static overwriting with part of the explicit partition filter not meaningful (e.g. specifing a non-nullable partition field to be null).") + + # step 3: check the unbound_expr is within the partition spec # this is the part unbound_expr + nested_field = bound_expr.term.ref().field part_fields = spec.fields_by_source_id(nested_field.field_id) print(f"{part_fields=}") if len(part_fields) != 1: @@ -204,10 +207,10 @@ def _validate_static_overwrite_filter_field(unbound_expr: BooleanExpression, tab part_field = part_fields[0] - # step 3: check the unbound_expr is with identity transform + # step 4: check the unbound_expr is with identity transform print(f"{part_field.transform=}") if not isinstance(part_field.transform, IdentityTransform): - raise ValueError(f"static overwrite partition filter can only apply to partition fields which are without hidden transform, but get {part_field.transform=} for {bound_field.term.ref().unbound_expr=}") + raise ValueError(f"static overwrite partition filter can only apply to partition fields which are without hidden transform, but get {part_field.transform=} for {nested_field=}") return bound_expr # nested_field @@ -232,9 +235,12 @@ def _validate_static_overwrite_filter(table_schema: Schema, overwrite_filter: Bo return bound_is_null_predicates, bound_eq_to_predicates +import pyarrow as pa def _fill_in_df(df: pa.Table, bound_is_null_predicates: List[BoundIsNull], bound_eq_to_predicates: List[BoundEqualTo]): - is_null_nested_fields = [bound_field.term.ref().field for bound_predicate in bound_is_null_predicates] - eq_to_nested_fields = [bound_field.term.ref().field for bound_predicate in bound_eq_to_predicates] + """Use bound filter predicates to extend the pyarrow with correct schema matching the iceberg schema and fill in the values. + """ + is_null_nested_fields = [bound_predicate.term.ref().field for bound_predicate in bound_is_null_predicates] + eq_to_nested_fields = [bound_predicate.term.ref().field for bound_predicate in bound_eq_to_predicates] from itertools import chain schema = Schema(*chain(is_null_nested_fields, eq_to_nested_fields)) print(f"{schema=}") @@ -244,16 +250,17 @@ def _fill_in_df(df: pa.Table, bound_is_null_predicates: List[BoundIsNull], bound print(f"{pa_schema=}") is_null_nested_field_name_values = zip([nested_field.name for nested_field in is_null_nested_fields], [None]*len(bound_is_null_predicates)) - eq_to_nested_field_name_values = zip([nested_field.name for nested_field in eq_to_nested_fields], [predicate.literal for predicate in bound_eq_to_predicates]) + eq_to_nested_field_name_values = zip([nested_field.name for nested_field in eq_to_nested_fields], [predicate.literal.value for predicate in bound_eq_to_predicates]) num_rows = df.num_rows - for field_name, value in is_null_nested_fields + eq_to_nested_fields: + for field_name, value in chain(is_null_nested_field_name_values,eq_to_nested_field_name_values): + print(f"{field_name=}, {value=}") pa_field = pa_schema.field(field_name) - literal_array = pa.array([None] * num_rows, type=pa_field.type) - df = df.add_column(df.num_columns, field.name, null_array) + literal_array = pa.array([value] * num_rows, type=pa_field.type) + df = df.add_column(df.num_columns, field_name, literal_array) return df -def _check_schema(table_schema: Schema, other_schema: "pa.Schema", to_truncate: List[NestedField]=[]): +def _check_schema(table_schema: Schema, other_schema: "pa.Schema", filter_predicates: List[BoundPredicate]=[]): from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema name_mapping = table_schema.name_mapping @@ -265,52 +272,75 @@ def _check_schema(table_schema: Schema, other_schema: "pa.Schema", to_truncate: raise ValueError( f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)." ) from e - print(f"{task_schema=}, {task_schema.as_struct()=}") - - # check fields dont step into itself, and do not step into each other, maybe we could move this to other 1+3(here) fields check - - # check mutual + union - remaining_schema = table_schema - if len(to_truncate) != 0: - remaining_schema = _truncate_fields(table_schema, to_truncate) - print(f"{remaining_schema.as_struct()=}") - print(f"{task_schema.as_struct()=}") - if remaining_schema.as_struct() != task_schema.as_struct(): - from rich.console import Console - from rich.table import Table as RichTable + def compare_and_rich_print(table_schema: Schema, task_schema: Schema): + sorted_table_schema = Schema(*sorted(table_schema.fields, key=lambda field:field.field_id)) + sorted_task_schema = Schema(*sorted(task_schema.fields, key=lambda field:field.field_id)) + if sorted_table_schema.as_struct() != sorted_task_schema.as_struct(): + from rich.console import Console + from rich.table import Table as RichTable - console = Console(record=True) + console = Console(record=True) - rich_table = RichTable(show_header=True, header_style="bold") - rich_table.add_column("") - rich_table.add_column("Table field") - rich_table.add_column("Dataframe field") + rich_table = RichTable(show_header=True, header_style="bold") + rich_table.add_column("") + rich_table.add_column("Table field") + rich_table.add_column("Dataframe field") - partition_filter_field_names = set([nested_field.name for nested_field in to_truncate]) - print(f"{partition_filter_field_names=}") - - for lhs in table_schema.fields: - if lhs.name in partition_filter_field_names: - try: - rhs = task_schema.find_field(lhs.field_id) - rich_table.add_row("❌", str(lhs), "Appears in both filter and arrow table.") - except ValueError: - rich_table.add_row("✅", str(lhs), "Appears in filter but not arrow table.") - else: + for lhs in table_schema.fields: try: rhs = task_schema.find_field(lhs.field_id) rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs)) except ValueError: rich_table.add_row("❌", str(lhs), "Missing") - console.print(rich_table) - raise ValueError(f"With partition fields in static overwrite filter as {[nested_field.name for nested_field in to_truncate]}, mismatch in fields:\n{console.export_text()}") + console.print(rich_table) + raise ValueError(f"Mismatch in fields:\n{console.export_text()}") + + def compare_and_rich_print_with_filter(table_schema: Schema, task_schema: Schema, filter_predicates: List[BoundPredicate]): + filter_fields = [bound_predicate.term.ref().field for bound_predicate in filter_predicates] + remaining_schema = _truncate_fields(table_schema, to_truncate = filter_fields) + sorted_remaining_schema = Schema(*sorted(remaining_schema.fields, key=lambda field:field.field_id)) + sorted_task_schema = Schema(*sorted(task_schema.fields, key=lambda field:field.field_id)) + if sorted_remaining_schema.as_struct() != sorted_task_schema.as_struct(): + from rich.console import Console + from rich.table import Table as RichTable + + console = Console(record=True) + + rich_table = RichTable(show_header=True, header_style="bold") + rich_table.add_column("") + rich_table.add_column("Table field") + rich_table.add_column("Dataframe field") + rich_table.add_column("Overwrite filter field") + + filter_field_names = [field.name for field in filter_fields] + for lhs in table_schema.fields: + if lhs.name in filter_field_names: + try: + rhs = task_schema.find_field(lhs.field_id) + rich_table.add_row("❌", str(lhs), str(rhs), lhs.name) + except ValueError: + rich_table.add_row("✅", str(lhs), "N/A", lhs.name) + else: + try: + rhs = task_schema.find_field(lhs.field_id) + rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs), "N/A") + except ValueError: + rich_table.add_row("❌", str(lhs), "Missing", "N/A") + + console.print(rich_table) + raise ValueError(f"Mismatch in fields:\n{console.export_text()}") + + if len(filter_predicates) != 0: + compare_and_rich_print_with_filter(table_schema, task_schema, filter_predicates) + else: + compare_and_rich_print(table_schema, task_schema) + def _truncate_fields(table_schema: Schema, to_truncate: List[NestedField]) -> Schema: to_truncate_fields_source_ids = set(nested_field.field_id for nested_field in to_truncate) truncated = [field for field in table_schema.fields if field.field_id not in to_truncate_fields_source_ids] - print(f"{truncated=}") return Schema(*truncated) @@ -318,24 +348,33 @@ def _validate_static_overwrite_filter_expr_type(expr: BooleanExpression): '''Validate whether expression only has 1)And 2)IsNull and 3)EqualTo and break down the raw expression into IsNull and EqualTo. ''' from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema + from collections import defaultdict def _recursively_fetch_fields(expr, is_null_predicates, eq_to_predicates) -> None: print(f"get expression as {expr=}") if isinstance(expr, EqualTo): - eq_to_predicates.append(expr) + duplication_check[expr.term.name].add(expr) + eq_to_predicates.add(expr) elif isinstance(expr, IsNull): - is_null_predicates.append(expr) + duplication_check[expr.term.name].add(expr) + is_null_predicates.add(expr) elif isinstance(expr, And): _recursively_fetch_fields(expr.left, is_null_predicates, eq_to_predicates) _recursively_fetch_fields(expr.right, is_null_predicates, eq_to_predicates) else: raise ValueError(f"static overwrite partitioning filter can only be isequalto, is null, and, alwaysTrue, but get {expr=}") - is_null_predicates = [] - eq_to_predicates = [] + duplication_check = defaultdict(set) + is_null_predicates = set() + eq_to_predicates = set() _recursively_fetch_fields(expr, is_null_predicates, eq_to_predicates) + for field, expr_set in duplication_check.items(): + if len(expr_set) != 1: + raise ValueError(f"static overwrite partitioning filter has more than 1 different predicates with same field {expr_set}") print(f"{is_null_predicates=}") print(f"{eq_to_predicates=}") + + # check fields don't step into itself, and do not step into each other, maybe we could move this to other 1+3(here) fields check return is_null_predicates, eq_to_predicates @@ -1347,19 +1386,21 @@ def overwrite(self, df: pa.Table, overwrite_filter: Union[str, BooleanExpression if not isinstance(df, pa.Table): raise ValueError(f"Expected PyArrow table, got: {df}") - + overwrite_filter = _parse_row_filter(overwrite_filter) if not overwrite_filter == ALWAYS_TRUE: - bound_is_null_predicates, bound_eq_to_predicates = _validate_static_overwrite_filter(table_schema=self.schema(), overwrite_filter=overwrite_filter, spec=self.metadata.spec()) -> None: - print("----check fields in partition filter and in the dataframe are mutual exclusive but unioning to full schema") + bound_is_null_predicates, bound_eq_to_predicates = _validate_static_overwrite_filter(table_schema=self.schema(), overwrite_filter=overwrite_filter, spec=self.metadata.spec()) + print("----check fields in partition filter and in the dataframe are mutual exclusive but unioning to full schema") - is_null_nested_fields = [bound_field.term.ref().field for bound_predicate in bound_is_null_predicates] - eq_to_nested_fields = [bound_field.term.ref().field for bound_predicate in bound_eq_to_predicates] + is_null_nested_fields = [bound_predicate.term.ref().field for bound_predicate in bound_is_null_predicates] + eq_to_nested_fields = [bound_predicate.term.ref().field for bound_predicate in bound_eq_to_predicates] print(f"{is_null_nested_fields=}") print(f"{eq_to_nested_fields=}") - _check_schema(table_schema, other_schema, to_truncate = is_null_nested_fields + eq_to_nested_fields) - _fill_in_df(df, bound_is_null_predicates, bound_eq_to_predicates) + #_check_schema(table_schema, other_schema, to_truncate = is_null_nested_fields + eq_to_nested_fields) + # bound_is_null_predicates + _check_schema(table_schema=self.schema(), other_schema=df.schema, filter_predicates = bound_is_null_predicates + bound_eq_to_predicates) + df = _fill_in_df(df, bound_is_null_predicates, bound_eq_to_predicates) else: - _check_schema(table_schema, other_schema) + _check_schema(table_schema=self.schema(), other_schema=df.schema) with self.transaction() as txn: @@ -2887,7 +2928,9 @@ def _get_partition_sort_order(partition_columns: list[str], reverse: bool = Fals def get_partition_columns(iceberg_table_metadata: TableMetadata, arrow_table: pa.Table) -> list[str]: + print(f"{arrow_table=}") arrow_table_cols = set(arrow_table.column_names) + print(f"{arrow_table_cols=}") partition_cols = [] for transform_field in iceberg_table_metadata.spec().fields: column_name = iceberg_table_metadata.schema().find_column_name(transform_field.source_id) @@ -2958,28 +3001,29 @@ def partition(iceberg_table_metadata: TableMetadata, arrow_table: pa.Table) -> I partition_columns = get_partition_columns(iceberg_table_metadata, arrow_table) sort_order_options = _get_partition_sort_order(partition_columns, reverse=False) + print(f"{sort_order_options=}") sorted_arrow_table_indices = pc.sort_indices(arrow_table, **sort_order_options) - # Efficiently avoid applying the grouping algorithm when the table is already sorted - slice_instructions: list[dict[str, Any]] = [] + # group table by partition scheme if verify_table_already_sorted(sorted_arrow_table_indices): - slice_instructions = [{"offset": 0, "length": 1}] + arrow_table_grouped_by_partition = arrow_table else: - # group table by partition scheme arrow_table_grouped_by_partition = arrow_table.take(sorted_arrow_table_indices) - reversing_sort_order_options = _get_partition_sort_order(partition_columns, reverse=True) - reverse_sort_indices = pc.sort_indices(arrow_table_grouped_by_partition, **reversing_sort_order_options).to_pylist() - - last = len(reverse_sort_indices) - reverse_sort_indices_size = len(reverse_sort_indices) - ptr = 0 - while ptr < reverse_sort_indices_size: - group_size = last - reverse_sort_indices[ptr] - offset = reverse_sort_indices[ptr] - slice_instructions.append({"offset": offset, "length": group_size}) - last = reverse_sort_indices[ptr] - ptr = ptr + group_size + slice_instructions: list[dict[str, Any]] = [] + + reversing_sort_order_options = _get_partition_sort_order(partition_columns, reverse=True) + reverse_sort_indices = pc.sort_indices(arrow_table_grouped_by_partition, **reversing_sort_order_options).to_pylist() + + last = len(reverse_sort_indices) + reverse_sort_indices_size = len(reverse_sort_indices) + ptr = 0 + while ptr < reverse_sort_indices_size: + group_size = last - reverse_sort_indices[ptr] + offset = reverse_sort_indices[ptr] + slice_instructions.append({"offset": offset, "length": group_size}) + last = reverse_sort_indices[ptr] + ptr = ptr + group_size table_partitions: list[TablePartition] = _get_table_partitions( arrow_table_grouped_by_partition, iceberg_table_metadata.spec(), iceberg_table_metadata.schema(), slice_instructions diff --git a/tests/integration/test_partitioned_writes.py b/tests/integration/test_partitioned_writes.py index 599366a24e..fee04ebe62 100644 --- a/tests/integration/test_partitioned_writes.py +++ b/tests/integration/test_partitioned_writes.py @@ -445,7 +445,8 @@ def test_summaries_with_null(spark: SparkSession, session_catalog: Catalog, arro tbl.append(arrow_table_with_null) tbl.overwrite(arrow_table_with_null) tbl.append(arrow_table_with_null) - tbl.overwrite(arrow_table_with_null, overwrite_filter="int=1") + valid_arrow_table_with_null_to_overwrite = arrow_table_with_null.drop(['int']) + tbl.overwrite(valid_arrow_table_with_null_to_overwrite, overwrite_filter="int=1") rows = spark.sql( f""" @@ -513,16 +514,16 @@ def test_summaries_with_null(spark: SparkSession, session_catalog: Catalog, arro # static overwrite which deletes 2 record (one from step3, one from step4) and 2 datafile, adding 3 new data files and 3 records, so total data files and records are 6 - 2 + 3 = 7 assert summaries[4] == { 'removed-files-size': '10790', - 'added-data-files': '3', + 'added-data-files': '1', 'total-equality-deletes': '0', 'added-records': '3', 'deleted-data-files': '2', 'total-position-deletes': '0', - 'added-files-size': '15029', + 'added-files-size': '5455', 'total-delete-files': '0', 'deleted-records': '2', - 'total-files-size': '34297', - 'total-data-files': '7', + 'total-files-size': '24723', + 'total-data-files': '5', 'total-records': '7', } @@ -547,7 +548,9 @@ def test_data_files_with_table_partitioned_with_null( tbl.append(arrow_table_with_null) tbl.overwrite(arrow_table_with_null) tbl.append(arrow_table_with_null) - tbl.overwrite(arrow_table_with_null, overwrite_filter="int=1") + + valid_arrow_table_with_null_to_overwrite = arrow_table_with_null.drop(['int']) + tbl.overwrite(valid_arrow_table_with_null_to_overwrite, overwrite_filter="int=1") # first append links to 1 manifest file (M1) # second append's manifest list links to 2 manifest files (M1, M2) @@ -562,7 +565,7 @@ def test_data_files_with_table_partitioned_with_null( # M3 0 0 6 S3 # M4 3 0 0 S3 # M5 3 0 0 S4 - # M6 3 0 0 S5 + # M6 1 0 0 S5 # M7 0 4 2 S5 spark.sql( @@ -576,7 +579,7 @@ def test_data_files_with_table_partitioned_with_null( FROM {identifier}.all_manifests """ ).collect() - assert [row.added_data_files_count for row in rows] == [3, 3, 3, 3, 0, 3, 3, 3, 0] + assert [row.added_data_files_count for row in rows] == [3, 3, 3, 3, 0, 3, 3, 1, 0] assert [row.existing_data_files_count for row in rows] == [0, 0, 0, 0, 0, 0, 0, 0, 4] assert [row.deleted_data_files_count for row in rows] == [0, 0, 0, 0, 6, 0, 0, 0, 2] @@ -669,11 +672,24 @@ def test_query_filter_after_append_overwrite_table_with_expr( properties={'format-version': '1'}, ) - for _ in range(2): + + for i in range(3): tbl.append(arrow_table_with_null) - tbl.overwrite(arrow_table_with_null, expr) + print("this is ", i) + spark.sql(f"refresh table {identifier}") + spark.sql(f"select file_path from {identifier}.files").show(20, False) + spark.sql(f"select * from {identifier}").show(20, False) + + + valid_arrow_table_with_null_to_overwrite = arrow_table_with_null.drop([part_col]) + tbl.overwrite(valid_arrow_table_with_null_to_overwrite, expr) iceberg_table = session_catalog.load_table(identifier=identifier) + spark.sql(f"refresh table {identifier}") + print("this is 3") + spark.sql(f"select file_path from {identifier}.files").show(20, False) spark.sql(f"select * from {identifier}").show(20, False) - assert iceberg_table.scan(row_filter=expr).to_arrow().num_rows == 1 - assert iceberg_table.scan().to_arrow().num_rows == 7 + + assert iceberg_table.scan().to_arrow().num_rows == 9 + assert iceberg_table.scan(row_filter=expr).to_arrow().num_rows == 3 + diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 564fc94c93..116ab53ca6 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -82,10 +82,10 @@ _TableMetadataUpdateContext, update_table_metadata, verify_table_already_sorted, - _fetch_fields_and_validate_expression_type, + _validate_static_overwrite_filter_expr_type, _validate_static_overwrite_filter_field, _check_schema, - # _validate_static_overwrite_filter + _fill_in_df ) from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadataUtil, TableMetadataV2, _generate_snapshot_id from pyiceberg.table.snapshots import ( @@ -1070,7 +1070,6 @@ def test_correct_schema() -> None: #_validate_static_overwrite_filter_field -@pytest.mark.french def test__validate_static_overwrite_filter_field_fail_on_non_schema_fields_in_filter()-> None: #pred: BooleanExpression, table_schema: Schema, spec:PartitionSpec)-> None: # todo: is it possible to make boolean expression more specific, like bound expression? no it is not bound yet test_schema = Schema( @@ -1085,10 +1084,40 @@ def test__validate_static_overwrite_filter_field_fail_on_non_schema_fields_in_fi PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="test_part_col") ) with pytest.raises(ValueError, match=f"Could not find field with name {pred.term.name}, case_sensitive=True"): - _validate_static_overwrite_filter_field(field=pred, table_schema=test_schema, spec=partition_spec) + _validate_static_overwrite_filter_field(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) +@pytest.mark.zy +def test_mine(table_schema_simple) -> None: + pred = IsNull(Reference("bar")) + pred.bind(table_schema_simple) + print("xxx", pred.term.name) + + from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema + pa_table = pa.table( + {"bar": [1, 2, 3], "foo": ["a", "b", "c"], "baz": [True, False, None]}, + ) + name_mapping = table_schema_simple.name_mapping + print("xxxx!", pyarrow_to_schema(pa_table.schema, name_mapping=name_mapping)) + + +def test__fill_in_df(table_schema_simple) -> None: + df = pa.table( + {"baz": [True, False, None]} + ) + unbound_is_null_predicates = [IsNull(Reference("foo"))] + unbound_eq_to_predicates = [EqualTo(Reference("bar"), 3)] + bound_is_null_predicates = [unbound_predicate.bind(table_schema_simple) for unbound_predicate in unbound_is_null_predicates] + bound_eq_to_predicates = [unbound_predicate.bind(table_schema_simple) for unbound_predicate in unbound_eq_to_predicates] + filled_df = _fill_in_df(df = df, bound_is_null_predicates = bound_is_null_predicates, bound_eq_to_predicates = bound_eq_to_predicates) + print(f"{filled_df=}") + expected = pa.table( + { + "baz": [True, False, None], + "foo": [None, None, None], + "bar": [3,3,3] + }, schema=pa.schema([pa.field('baz', pa.bool_()), pa.field('foo', pa.string()), pa.field('bar', pa.int32())])) + assert filled_df == expected -@pytest.mark.french def test__validate_static_overwrite_filter_field_fail_on_non_part_fields_in_filter()-> None: #pred: BooleanExpression, table_schema: Schema, spec:PartitionSpec)-> None: # todo: is it possible to make boolean expression more specific, like bound expression? no it is not bound yet test_schema = Schema( @@ -1104,10 +1133,9 @@ def test__validate_static_overwrite_filter_field_fail_on_non_part_fields_in_filt ) import re with pytest.raises(ValueError, match=re.escape("get len(part_fields)=0, not 1, if this number is 0, indicating the static filter is not within the partition fields, which is invalid")): - _validate_static_overwrite_filter_field(field=pred, table_schema=test_schema, spec=partition_spec) + _validate_static_overwrite_filter_field(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) #to do add one test that the partition fields passed -@pytest.mark.french def test__validate_static_overwrite_filter_field_fail_on_non_identity_transorm_filter()-> None: #pred: BooleanExpression, table_schema: Schema, spec:PartitionSpec)-> None: # todo: is it possible to make boolean expression more specific, like bound expression? no it is not bound yet test_schema = Schema( @@ -1124,10 +1152,9 @@ def test__validate_static_overwrite_filter_field_fail_on_non_identity_transorm_f ) # import re with pytest.raises(ValueError, match="static overwrite partition filter can only apply to partition fields which are without hidden transform, but get.*"): - _validate_static_overwrite_filter_field(field=pred, table_schema=test_schema, spec=partition_spec) + _validate_static_overwrite_filter_field(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) # combine this with above -@pytest.mark.french def test__validate_static_overwrite_filter_field_succeed_on_an_identity_field_although_table_has_hidden_partition()-> None: #pred: BooleanExpression, table_schema: Schema, spec:PartitionSpec)-> None: # todo: is it possible to make boolean expression more specific, like bound expression? no it is not bound yet test_schema = Schema( @@ -1144,9 +1171,9 @@ def test__validate_static_overwrite_filter_field_succeed_on_an_identity_field_al ) # import re #with pytest.raises(ValueError, match="static overwrite partition filter can only apply to partition fields which are without hidden transform, but get.*"): - _validate_static_overwrite_filter_field(field=pred, table_schema=test_schema, spec=partition_spec) + _validate_static_overwrite_filter_field(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) + -@pytest.mark.french def test__validate_static_overwrite_filter_field_fail_to_bind()-> None: #pred: BooleanExpression, table_schema: Schema, spec:PartitionSpec)-> None: # todo: is it possible to make boolean expression more specific, like bound expression? no it is not bound yet test_schema = Schema( @@ -1162,39 +1189,30 @@ def test__validate_static_overwrite_filter_field_fail_to_bind()-> None: #pred: B PartitionField(source_id=1, field_id=1002, transform=TruncateTransform(2), name="foo_trunc") ) with pytest.raises(ValueError, match="Could not convert an incompatible type into a int"): - _validate_static_overwrite_filter_field(field=pred, table_schema=test_schema, spec=partition_spec) - -# # to do, raises should use full match to check whether it is overlapping error or it is unioning error -# # to do, in the raw code and test, should add that filter is not conflicting with itself,, e.g. the same field can only be referenced once -# # advanced: as long as the values are different -# # as to do, a filter can and a same expr multiple times -# @pytest.mark.integration -# @pytest.mark.parametrize( -# "task_field_names, eq_to_fields_names, null_field_names, schema_field_names, raises", -# [ -# # good case -# (["f1"], ["f2"], ["f3"], ["f1", "f2", "f3"], False), -# # filter has overlapping fields with arrow fields -# (["f1", "f2"], ["f2"], ["f3"], ["f1", "f2", "f3"], True), -# (["f1", "f2"], ["f3"], ["f1"], ["f1", "f2", "f3"], True), -# # filter has one field referenced multiple times -# #(["f1", "f2"], ["f3"], ["f3"], ["f1", "f2", "f3"], True), -# #(["f1", "f2"], ["f3, f3"], ["f4"], ["f1", "f2", "f3", "f4"], True), -# #(["f1", "f2"], ["f4"], ["f3, f3"], ["f1", "f2", "f3", "f4"], True), -# # shortage -# (["f1"], ["f2"], ["f3"], ["f1", "f2", "f3", "f4"], True), -# ] -# ) -# def test__validate_mutual_exclusive_and_union(task_field_names, eq_to_fields_names, null_field_names, schema_field_names, raises): -# if raises: -# with pytest.raises(ValueError): # match=expected): to do -# _check_schema(task_field_names, eq_to_fields_names, null_field_names, schema_field_names) -# else: -# _check_schema(task_field_names, eq_to_fields_names, null_field_names, schema_field_names) - -# to do the truncated is sorted or not? arrow table schema does not have to match exactly in order right while now it has to be in order. -@pytest.mark.french -def test__validate_mutual_exclusive_and_union_succeed(): + _validate_static_overwrite_filter_field(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) + + +def test__validate_static_overwrite_filter_field_fail_to_bind_due_to_non_nullable()-> None: #pred: BooleanExpression, table_schema: Schema, spec:PartitionSpec)-> None: + # todo: is it possible to make boolean expression more specific, like bound expression? no it is not bound yet + test_schema = Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + schema_id=1, + identifier_field_ids=[2], + ) + pred = IsNull(Reference("bar")) + partition_spec=PartitionSpec( + PartitionField(source_id=3, field_id=1001, transform=IdentityTransform(), name="baz"), + PartitionField(source_id=1, field_id=1002, transform=TruncateTransform(2), name="foo_trunc") + ) + with pytest.raises(ValueError, match=re.escape("Static overwriting with part of the explicit partition filter not meaningful (e.g. specifing a non-nullable partition field to be null)")): + _validate_static_overwrite_filter_field(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) + + + + +def test__check_schema_with_filter_succeed(): table_schema: Schema = Schema( NestedField(field_id=1, name="foo", field_type=StringType(), required=False), NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), @@ -1205,25 +1223,13 @@ def test__validate_mutual_exclusive_and_union_succeed(): other_schema: pa.Schema = pa.schema([ pa.field('foo', pa.string()), pa.field('baz', pa.bool_()), - # pa.field('name', pa.string()), - # pa.field('age', pa.int32()), - # pa.field('email', pa.string()), - # pa.field('is_subscribed', pa.bool_()) ]) - print(f"{other_schema=}") - # pa.schema([ - # pa.field("foo", pa.string(), nullable=False), - # pa.field("bar", pa.int32(), nullable=True), - # pa.field("baz", pa.bool_(), nullable=True), - # ]) - - null_nested_fields: List[NestedField] = [] - eq_to_nested_fields: List[NestedField] = [NestedField(field_id=2, name="bar", field_type=IntegerType(), required=False)] - _check_schema(table_schema, other_schema, null_nested_fields, eq_to_nested_fields) - -# to do rename: with filter -@pytest.mark.french -def test__validate_mutual_exclusive_and_union_with_filter_fail_on_missing_field(): + + unbound_preds = [EqualTo(Reference("bar"), 15)] + filter_predicates = [pred.bind(table_schema) for pred in unbound_preds] + _check_schema(table_schema, other_schema, filter_predicates = filter_predicates) + +def test__check_schema_with_filter_succeed_on_pyarrow_table_with_random_column_order() -> None: table_schema: Schema = Schema( NestedField(field_id=1, name="foo", field_type=StringType(), required=False), NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), @@ -1233,25 +1239,44 @@ def test__validate_mutual_exclusive_and_union_with_filter_fail_on_missing_field( ) other_schema: pa.Schema = pa.schema([ pa.field('baz', pa.bool_()), + pa.field('foo', pa.string()), ]) - null_nested_fields: List[NestedField] = [] - eq_to_nested_fields: List[NestedField] = [NestedField(field_id=2, name="bar", field_type=IntegerType(), required=False)] - - expected = re.escape('''With partition fields in static overwrite filter as ['bar'], mismatch in fields: -┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ -┃ ┃ Table field ┃ Dataframe field ┃ -┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ -│ ❌ │ 1: foo: optional string │ Missing │ -│ ✅ │ 2: bar: required int │ Appears in filter but not arrow table. │ -│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ -└────┴──────────────────────────┴────────────────────────────────────────┘ + unbound_preds = [EqualTo(Reference("bar"), 15)] + filter_predicates = [pred.bind(table_schema) for pred in unbound_preds] + _check_schema(table_schema, other_schema, filter_predicates = filter_predicates) + + +def test__check_schema_with_filter_fail_on_missing_field(): + table_schema: Schema = Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + schema_id=1, + identifier_field_ids=[2], + ) + other_schema: pa.Schema = pa.schema([ + pa.field('baz', pa.bool_()), + ]) + + unbound_preds = [EqualTo(Reference("bar"), 15)] + filter_predicates = [pred.bind(table_schema) for pred in unbound_preds] + + expected = re.escape('''Mismatch in fields: +┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ ┃ Table field ┃ Dataframe field ┃ Overwrite filter field ┃ +┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ ❌ │ 1: foo: optional string │ Missing │ N/A │ +│ ✅ │ 2: bar: required int │ N/A │ bar │ +│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ N/A │ +└────┴──────────────────────────┴──────────────────────────┴────────────────────────┘ ''') with pytest.raises(ValueError, match=expected): - _check_schema(table_schema, other_schema, null_nested_fields, eq_to_nested_fields) + _check_schema(table_schema, other_schema, filter_predicates = filter_predicates) -@pytest.mark.french -def test__validate_mutual_exclusive_and_union_with_filter_fail_on_nullability_mismatch(): + + +def test__check_schema_with_filter_fail_on_nullability_mismatch(): table_schema: Schema = Schema( NestedField(field_id=1, name="foo", field_type=StringType(), required=False), NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), @@ -1264,23 +1289,22 @@ def test__validate_mutual_exclusive_and_union_with_filter_fail_on_nullability_mi pa.field('bar', pa.int32()), ]) - - null_nested_fields: List[NestedField] = [] - eq_to_nested_fields: List[NestedField] = [NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False)] - expected = re.escape('''With partition fields in static overwrite filter as ['baz'], mismatch in fields: -┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ -┃ ┃ Table field ┃ Dataframe field ┃ -┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ -│ ✅ │ 1: foo: optional string │ 1: foo: optional string │ -│ ❌ │ 2: bar: required int │ 2: bar: optional int │ -│ ✅ │ 3: baz: optional boolean │ Appears in filter but not arrow table. │ -└────┴──────────────────────────┴────────────────────────────────────────┘ + unbound_preds = [EqualTo(Reference("baz"), True)] + filter_predicates = [pred.bind(table_schema) for pred in unbound_preds] + expected = re.escape('''Mismatch in fields: +┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ ┃ Table field ┃ Dataframe field ┃ Overwrite filter field ┃ +┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ ✅ │ 1: foo: optional string │ 1: foo: optional string │ N/A │ +│ ❌ │ 2: bar: required int │ 2: bar: optional int │ N/A │ +│ ✅ │ 3: baz: optional boolean │ N/A │ baz │ +└────┴──────────────────────────┴─────────────────────────┴────────────────────────┘ ''') with pytest.raises(ValueError, match=expected): - _check_schema(table_schema, other_schema, null_nested_fields, eq_to_nested_fields) + _check_schema(table_schema, other_schema, filter_predicates = filter_predicates) -@pytest.mark.french -def test__validate_mutual_exclusive_and_union_with_filter_fail_on_type_mismatch(): + +def test__check_schema_with_filter_fail_on_type_mismatch(): table_schema: Schema = Schema( NestedField(field_id=1, name="foo", field_type=StringType(), required=False), NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), @@ -1293,23 +1317,22 @@ def test__validate_mutual_exclusive_and_union_with_filter_fail_on_type_mismatch( pa.field('bar', pa.string(), nullable=False), ]) - null_nested_fields: List[NestedField] = [] - eq_to_nested_fields: List[NestedField] = [NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False)] - expected = re.escape('''With partition fields in static overwrite filter as ['baz'], mismatch in fields: -┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ -┃ ┃ Table field ┃ Dataframe field ┃ -┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ -│ ✅ │ 1: foo: optional string │ 1: foo: optional string │ -│ ❌ │ 2: bar: required int │ 2: bar: required string │ -│ ✅ │ 3: baz: optional boolean │ Appears in filter but not arrow table. │ -└────┴──────────────────────────┴────────────────────────────────────────┘ + unbound_preds = [EqualTo(Reference("baz"), True)] + filter_predicates = [pred.bind(table_schema) for pred in unbound_preds] + expected = re.escape('''Mismatch in fields: +┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ ┃ Table field ┃ Dataframe field ┃ Overwrite filter field ┃ +┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ ✅ │ 1: foo: optional string │ 1: foo: optional string │ N/A │ +│ ❌ │ 2: bar: required int │ 2: bar: required string │ N/A │ +│ ✅ │ 3: baz: optional boolean │ N/A │ baz │ +└────┴──────────────────────────┴─────────────────────────┴────────────────────────┘ ''') with pytest.raises(ValueError, match=expected): - _check_schema(table_schema, other_schema, null_nested_fields, eq_to_nested_fields) + _check_schema(table_schema, other_schema, filter_predicates = filter_predicates) -@pytest.mark.french -def test__validate_mutual_exclusive_and_union_with_field_appear_in_both_filter_and_dataframe(): +def test__check_schema_with_field_fail_due_to_filter_and_dataframe_hold_same_field(): table_schema: Schema = Schema( NestedField(field_id=1, name="foo", field_type=StringType(), required=False), NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), @@ -1322,44 +1345,46 @@ def test__validate_mutual_exclusive_and_union_with_field_appear_in_both_filter_a pa.field('bar', pa.int32(), nullable=False), ]) - null_nested_fields: List[NestedField] = [NestedField(field_id=1, name="foo", field_type=StringType(), required=False)] - eq_to_nested_fields: List[NestedField] = [NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False)] - expected = re.escape('''With partition fields in static overwrite filter as ['foo', 'baz'], mismatch in fields: -┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ -┃ ┃ Table field ┃ Dataframe field ┃ -┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ -│ ❌ │ 1: foo: optional string │ Appears in both filter and arrow table. │ -│ ✅ │ 2: bar: required int │ 2: bar: required int │ -│ ✅ │ 3: baz: optional boolean │ Appears in filter but not arrow table. │ -└────┴──────────────────────────┴─────────────────────────────────────────┘ + unbound_preds = [IsNull(Reference("foo")), EqualTo(Reference("baz"), True)] + filter_predicates = [pred.bind(table_schema) for pred in unbound_preds] + expected = re.escape('''Mismatch in fields: +┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ ┃ Table field ┃ Dataframe field ┃ Overwrite filter field ┃ +┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ ❌ │ 1: foo: optional string │ 1: foo: optional string │ foo │ +│ ✅ │ 2: bar: required int │ 2: bar: required int │ N/A │ +│ ✅ │ 3: baz: optional boolean │ N/A │ baz │ +└────┴──────────────────────────┴─────────────────────────┴────────────────────────┘ ''') with pytest.raises(ValueError, match=expected): - _check_schema(table_schema, other_schema, null_nested_fields, eq_to_nested_fields) + _check_schema(table_schema, other_schema, filter_predicates = filter_predicates) -# to do, add tests for bind value failure?????? -@pytest.mark.integration @pytest.mark.parametrize( "pred, raises, is_null_list, eq_to_list", [ - (EqualTo(Reference("foo"), "hello"), False, [], [EqualTo(Reference("foo"), "hello")]), - (IsNull(Reference("foo")), False, [IsNull(Reference("foo"))], []), - (And(IsNull(Reference("foo")),EqualTo(Reference("foo"), "hello")), False, [IsNull(Reference("foo"))], [EqualTo(Reference("foo"), "hello")]), - (NotNull, True, [], []), - (NotEqualTo, True, [], []), - (LessThan(Reference("foo"), 5), True, [], []), - (Or(IsNull(Reference("foo")),EqualTo(Reference("foo"), "hello")), True, [], []), - (And(EqualTo(Reference("foo"), "hello"), And(IsNull(Reference("foo")), EqualTo(Reference("boo"), "hello"))), False, [IsNull(Reference("foo"))], [EqualTo(Reference("foo"), "hello"), EqualTo(Reference("boo"), "hello")]), + (EqualTo(Reference("foo"), "hello"), False, {}, {EqualTo(Reference("foo"), "hello")}), + (IsNull(Reference("foo")), False, {IsNull(Reference("foo"))}, {}), + (And(IsNull(Reference("foo")),EqualTo(Reference("boo"), "hello")), False, {IsNull(Reference("foo"))}, {EqualTo(Reference("boo"), "hello")}), + (NotNull, True, {}, {}), + (NotEqualTo, True, {}, {}), + (LessThan(Reference("foo"), 5), True, {}, {}), + (Or(IsNull(Reference("foo")),EqualTo(Reference("foo"), "hello")), True, {}, {}), + (And(EqualTo(Reference("foo"), "hello"), And(IsNull(Reference("baz")), EqualTo(Reference("boo"), "hello"))), False, {IsNull(Reference("baz"))}, {EqualTo(Reference("foo"), "hello"), EqualTo(Reference("boo"), "hello")}), + # Below are crowd-crush tests: a same field can only be with same literal/null, not different literals or both literal and null + # A false crush: when there are duplicated isnull/equalto, the collector should deduplicate them. + (And(EqualTo(Reference("foo"), "hello"), EqualTo(Reference("foo"), "hello")), False, {}, {EqualTo(Reference("foo"), "hello")}), + # When crush happens + (And(EqualTo(Reference("foo"), "hello"), EqualTo(Reference("foo"), "bye")), True, {}, {EqualTo(Reference("foo"), "hello"), EqualTo(Reference("foo"), "bye")}), + (And(EqualTo(Reference("foo"), "hello"), IsNull(Reference("foo"))), True, {IsNull(Reference("foo"))}, {}) ], ) -def test__fetch_fields_and_validate_expression_type(pred, raises, is_null_list, eq_to_list)-> None: +def test__validate_static_overwrite_filter_expr_type(pred, raises, is_null_list, eq_to_list)-> None: if raises: with pytest.raises(ValueError): # match=expected): to do - res = _fetch_fields_and_validate_expression_type(pred) + res = _validate_static_overwrite_filter_expr_type(pred) else: - res = _fetch_fields_and_validate_expression_type(pred) - print(f":::::::::::::adrian, {res=}") - print(f"lets check the strings, {set([str(e) for e in res[1]])=}") + res = _validate_static_overwrite_filter_expr_type(pred) assert set([str(e) for e in res[0]]) == set([str(e) for e in is_null_list]) assert set([str(e) for e in res[1]]) == set([str(e) for e in eq_to_list]) @@ -1426,6 +1451,24 @@ def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None: with pytest.raises(ValueError, match=expected): _check_schema(table_schema_simple, other_schema) +def test_schema_succeed(table_schema_simple: Schema) -> None: + other_schema = pa.schema(( + pa.field("foo", pa.string(), nullable=True), + pa.field("bar", pa.int32(), nullable=False), + pa.field("baz", pa.bool_(), nullable=True), + )) + + _check_schema(table_schema_simple, other_schema) + +def test_schema_succeed_on_pyarrow_table_reversed_order(table_schema_simple: Schema) -> None: + other_schema = pa.schema(( + pa.field("baz", pa.bool_(), nullable=True), + pa.field("bar", pa.int32(), nullable=False), + pa.field("foo", pa.string(), nullable=True), + )) + + _check_schema(table_schema_simple, other_schema) + def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None: other_schema = pa.schema(( @@ -1467,7 +1510,6 @@ def test_table_properties_raise_for_none_value(example_table_metadata_v2: Dict[s assert "None type is not a supported value in properties: property_name" in str(exc_info.value) -@pytest.mark.integration @pytest.mark.parametrize( "input_sorted_indices, expected_sorted_or_not", [ From b901444a6f20cd2b6aa7bcfe4b600c5bd154a9e3 Mon Sep 17 00:00:00 2001 From: Adrian Qin <147659252+jqin61@users.noreply.github.com> Date: Fri, 22 Mar 2024 06:52:04 +0000 Subject: [PATCH 3/6] finally fixed linting --- pyiceberg/expressions/__init__.py | 16 +- pyiceberg/table/__init__.py | 285 ++++++-------- tests/integration/test_partitioned_writes.py | 13 +- tests/table/test_init.py | 390 +++++++++---------- 4 files changed, 320 insertions(+), 384 deletions(-) diff --git a/pyiceberg/expressions/__init__.py b/pyiceberg/expressions/__init__.py index 2c46ffb562..29503eff26 100644 --- a/pyiceberg/expressions/__init__.py +++ b/pyiceberg/expressions/__init__.py @@ -373,7 +373,10 @@ def as_bound(self) -> Type[BoundPredicate[L]]: ... class UnaryPredicate(UnboundPredicate[Any], ABC): def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundUnaryPredicate[Any]: bound_term = self.term.bind(schema, case_sensitive) - return self.as_bound(bound_term) + print(f"{bound_term=}") + res = self.as_bound(bound_term) + print(f"{res=}") + return res def __repr__(self) -> str: """Return the string representation of the UnaryPredicate class.""" @@ -384,10 +387,10 @@ def __repr__(self) -> str: def as_bound(self) -> Type[BoundUnaryPredicate[Any]]: ... def __hash__(self) -> int: + """Return hash value of the UnaryPredicate class.""" return hash(str(self)) - class BoundUnaryPredicate(BoundPredicate[L], ABC): def __repr__(self) -> str: """Return the string representation of the BoundUnaryPredicate class.""" @@ -416,6 +419,10 @@ def __invert__(self) -> BoundNotNull[L]: def as_unbound(self) -> Type[IsNull]: return IsNull + def __hash__(self) -> int: + """Return hash value of the BoundIsNull class.""" + return hash(str(self)) + class BoundNotNull(BoundUnaryPredicate[L]): def __new__(cls, term: BoundTerm[L]): # type: ignore # pylint: disable=W0221 @@ -703,6 +710,7 @@ def __repr__(self) -> str: def as_bound(self) -> Type[BoundLiteralPredicate[L]]: ... def __hash__(self) -> int: + """Return hash value of the UnaryPredicate class.""" return hash(str(self)) @@ -738,6 +746,10 @@ def __invert__(self) -> BoundNotEqualTo[L]: def as_unbound(self) -> Type[EqualTo[L]]: return EqualTo + def __hash__(self) -> int: + """Return hash value of the BoundEqualTo class.""" + return hash(str(self)) + class BoundNotEqualTo(BoundLiteralPredicate[L]): def __invert__(self) -> BoundEqualTo[L]: diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index ff2dca4360..0cb4f60468 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -51,10 +51,15 @@ from pyiceberg.exceptions import CommitFailedException, ResolveError, ValidationError from pyiceberg.expressions import ( + AlwaysFalse, AlwaysTrue, And, BooleanExpression, + BoundEqualTo, + BoundIsNull, + BoundPredicate, EqualTo, + IsNull, Reference, parser, visitors, @@ -115,14 +120,7 @@ ) from pyiceberg.table.sorting import SortOrder from pyiceberg.transforms import TimeTransform, Transform, VoidTransform -from pyiceberg.typedef import ( - EMPTY_DICT, - IcebergBaseModel, - IcebergRootModel, - Identifier, - KeyDefaultDict, - Properties, -) +from pyiceberg.typedef import EMPTY_DICT, IcebergBaseModel, IcebergRootModel, Identifier, KeyDefaultDict, L, Properties from pyiceberg.types import ( IcebergType, ListType, @@ -134,42 +132,6 @@ from pyiceberg.utils.concurrent import ExecutorFactory from pyiceberg.utils.datetime import datetime_to_millis -from pyiceberg.expressions import ( - AlwaysFalse, - AlwaysTrue, - And, - BooleanExpression, - BoundEqualTo, - BoundGreaterThan, - BoundGreaterThanOrEqual, - BoundIn, - BoundIsNaN, - BoundIsNull, - BoundLessThan, - BoundLessThanOrEqual, - BoundNotEqualTo, - BoundNotIn, - BoundNotNaN, - BoundNotNull, - BoundReference, - EqualTo, - GreaterThan, - GreaterThanOrEqual, - In, - IsNaN, - IsNull, - LessThan, - LessThanOrEqual, - Not, - NotEqualTo, - NotIn, - NotNaN, - NotNull, - Or, - Reference, - UnboundPredicate, -) - if TYPE_CHECKING: import daft import pandas as pd @@ -185,82 +147,106 @@ _JAVA_LONG_MAX = 9223372036854775807 + # to do: make the types not expression but unbound predicates, this should be more precise, we already know it could only be unboundisnull and unboundequalto # actually could make it union[isnull, equalto] and make the return as union[boundisnull,boundequalto] -def _validate_static_overwrite_filter_field(unbound_expr: BooleanExpression, table_schema: Schema, spec:PartitionSpec)-> BoundPredicate: +def _bind_and_validate_static_overwrite_filter_predicate( + unbound_expr: Union[IsNull, EqualTo[L]], table_schema: Schema, spec: PartitionSpec +) -> Union[BoundIsNull[L], BoundEqualTo[L]]: # step 1: check the unbound_expr is within the schema and the value matches the schema - print(f"unbound_expr is {unbound_expr=}") - bound_expr = unbound_expr.bind(table_schema) + bound_expr: Union[BoundIsNull[L], BoundEqualTo[L], AlwaysFalse] = unbound_expr.bind(table_schema) # type: ignore # The bind returns upcast types. - # step 2: check non nullable column is not partitioned overwriten with isNull. - # It has to break cuz we cannot add null value for that column if it is non-nullable. + # step 2: check non nullable column is not partitioned overwriten with isNull. + # It has to break because we cannot fill null values into input arrow table (and parquets to write) for an iceberg field which is non-nullable. if isinstance(bound_expr, AlwaysFalse): - raise ValueError(f"Static overwriting with part of the explicit partition filter not meaningful (e.g. specifing a non-nullable partition field to be null).") + raise ValueError( + "Static overwriting with part of the explicit partition filter not meaningful (specifing a non-nullable partition field to be null)." + ) # step 3: check the unbound_expr is within the partition spec - # this is the part unbound_expr - nested_field = bound_expr.term.ref().field - part_fields = spec.fields_by_source_id(nested_field.field_id) - print(f"{part_fields=}") + if not isinstance(bound_expr, (BoundIsNull, BoundEqualTo)): + raise ValueError( + f"{unbound_expr=} binds to {bound_expr=} whose type is not expected. Expecting BoundIsNull or BoundEqualTo" + ) + nested_field: NestedField = bound_expr.term.ref().field + part_fields: List[PartitionField] = spec.fields_by_source_id(nested_field.field_id) if len(part_fields) != 1: - raise ValueError(f"get {len(part_fields)=}, not 1, if this number is 0, indicating the static filter is not within the partition fields, which is invalid") + raise ValueError(f"Get {len(part_fields)} partition fields from filter predicate {str(unbound_expr)}, expecting 1.") part_field = part_fields[0] - # step 4: check the unbound_expr is with identity transform - print(f"{part_field.transform=}") if not isinstance(part_field.transform, IdentityTransform): - raise ValueError(f"static overwrite partition filter can only apply to partition fields which are without hidden transform, but get {part_field.transform=} for {nested_field=}") + raise ValueError( + f"static overwrite partition filter can only apply to partition fields which are without hidden transform, but get {part_field.transform=} for {nested_field=}" + ) - return bound_expr # nested_field - + return bound_expr -def _validate_static_overwrite_filter(table_schema: Schema, overwrite_filter: BooleanExpression, spec:PartitionSpec) -> None: + +def _validate_static_overwrite_filter( + table_schema: Schema, overwrite_filter: BooleanExpression, spec: PartitionSpec +) -> Tuple[Set[BoundIsNull[L]], Set[BoundEqualTo[L]]]: is_null_predicates, eq_to_predicates = _validate_static_overwrite_filter_expr_type(expr=overwrite_filter) - bound_is_null_predicates = [] - bound_eq_to_predicates = [] - - print("---check is null fields with rules of filter,part,transform") - # these are expressions (not as name - fields) - for unbound_expr in is_null_predicates: - bp:BoundPredicate = _validate_static_overwrite_filter_field(unbound_expr = unbound_expr, table_schema = table_schema, spec = spec) - #bound_field.term.ref().field - bound_is_null_predicates.append(bp) - - print("----check eq to fields with rules of filter,part,transform") - for unbound_expr in eq_to_predicates: - bp:BoundPredicate = _validate_static_overwrite_filter_field(unbound_expr = unbound_expr, table_schema = table_schema, spec = spec) - bound_eq_to_predicates.append(bp) - - return bound_is_null_predicates, bound_eq_to_predicates - -import pyarrow as pa -def _fill_in_df(df: pa.Table, bound_is_null_predicates: List[BoundIsNull], bound_eq_to_predicates: List[BoundEqualTo]): - """Use bound filter predicates to extend the pyarrow with correct schema matching the iceberg schema and fill in the values. - """ + bound_is_null_preds = set() + bound_eq_to_preds = set() + for unbound_is_null in is_null_predicates: + bound_pred = _bind_and_validate_static_overwrite_filter_predicate( + unbound_expr=unbound_is_null, table_schema=table_schema, spec=spec + ) + if not isinstance(bound_pred, BoundIsNull): + raise ValueError(f"Expecting IsNull after binding {unbound_is_null} to schema but get {bound_pred}.") + bound_is_null_preds.add(bound_pred) + + for unbound_eq_to in eq_to_predicates: + bound_pred = _bind_and_validate_static_overwrite_filter_predicate( + unbound_expr=unbound_eq_to, table_schema=table_schema, spec=spec + ) + if not isinstance(bound_pred, BoundEqualTo): + raise ValueError(f"Expecting IsNull after binding {unbound_eq_to} to schema but get {bound_pred}.") + bound_eq_to_preds.add(bound_pred) + return (bound_is_null_preds, bound_eq_to_preds) # type: ignore + + +def _fill_in_df( + df: pa.Table, bound_is_null_predicates: Set[BoundIsNull[L]], bound_eq_to_predicates: Set[BoundEqualTo[L]] +) -> pa.Table: + """Use bound filter predicates to extend the pyarrow with correct schema matching the iceberg schema and fill in the values.""" + try: + import pyarrow as pa + except ModuleNotFoundError as e: + raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e + is_null_nested_fields = [bound_predicate.term.ref().field for bound_predicate in bound_is_null_predicates] eq_to_nested_fields = [bound_predicate.term.ref().field for bound_predicate in bound_eq_to_predicates] - from itertools import chain + schema = Schema(*chain(is_null_nested_fields, eq_to_nested_fields)) - print(f"{schema=}") from pyiceberg.io.pyarrow import schema_to_pyarrow pa_schema = schema_to_pyarrow(schema) - print(f"{pa_schema=}") - - is_null_nested_field_name_values = zip([nested_field.name for nested_field in is_null_nested_fields], [None]*len(bound_is_null_predicates)) - eq_to_nested_field_name_values = zip([nested_field.name for nested_field in eq_to_nested_fields], [predicate.literal.value for predicate in bound_eq_to_predicates]) - - num_rows = df.num_rows - for field_name, value in chain(is_null_nested_field_name_values,eq_to_nested_field_name_values): - print(f"{field_name=}, {value=}") + + is_null_nested_field_name_values = zip( + [nested_field.name for nested_field in is_null_nested_fields], [None] * len(bound_is_null_predicates) + ) + eq_to_nested_field_name_values = zip( + [nested_field.name for nested_field in eq_to_nested_fields], + [predicate.literal.value for predicate in bound_eq_to_predicates], + ) + + for field_name, value in chain(is_null_nested_field_name_values, eq_to_nested_field_name_values): pa_field = pa_schema.field(field_name) - literal_array = pa.array([value] * num_rows, type=pa_field.type) + literal_array = pa.array([value] * df.num_rows, type=pa_field.type) df = df.add_column(df.num_columns, field_name, literal_array) return df -def _check_schema(table_schema: Schema, other_schema: "pa.Schema", filter_predicates: List[BoundPredicate]=[]): + +# linttodo: break this down into 2 functions +def _check_schema( + table_schema: Schema, other_schema: "pa.Schema", filter_predicates: Set[BoundPredicate[L]] | None = None +) -> None: + if filter_predicates is None: + filter_predicates = set() + from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema name_mapping = table_schema.name_mapping @@ -273,9 +259,9 @@ def _check_schema(table_schema: Schema, other_schema: "pa.Schema", filter_predic f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)." ) from e - def compare_and_rich_print(table_schema: Schema, task_schema: Schema): - sorted_table_schema = Schema(*sorted(table_schema.fields, key=lambda field:field.field_id)) - sorted_task_schema = Schema(*sorted(task_schema.fields, key=lambda field:field.field_id)) + def compare_and_rich_print(table_schema: Schema, task_schema: Schema) -> None: + sorted_table_schema = Schema(*sorted(table_schema.fields, key=lambda field: field.field_id)) + sorted_task_schema = Schema(*sorted(task_schema.fields, key=lambda field: field.field_id)) if sorted_table_schema.as_struct() != sorted_task_schema.as_struct(): from rich.console import Console from rich.table import Table as RichTable @@ -297,11 +283,13 @@ def compare_and_rich_print(table_schema: Schema, task_schema: Schema): console.print(rich_table) raise ValueError(f"Mismatch in fields:\n{console.export_text()}") - def compare_and_rich_print_with_filter(table_schema: Schema, task_schema: Schema, filter_predicates: List[BoundPredicate]): + def compare_and_rich_print_with_filter( + table_schema: Schema, task_schema: Schema, filter_predicates: Set[BoundPredicate[L]] + ) -> None: filter_fields = [bound_predicate.term.ref().field for bound_predicate in filter_predicates] - remaining_schema = _truncate_fields(table_schema, to_truncate = filter_fields) - sorted_remaining_schema = Schema(*sorted(remaining_schema.fields, key=lambda field:field.field_id)) - sorted_task_schema = Schema(*sorted(task_schema.fields, key=lambda field:field.field_id)) + remaining_schema = _truncate_fields(table_schema, to_truncate=filter_fields) + sorted_remaining_schema = Schema(*sorted(remaining_schema.fields, key=lambda field: field.field_id)) + sorted_task_schema = Schema(*sorted(task_schema.fields, key=lambda field: field.field_id)) if sorted_remaining_schema.as_struct() != sorted_task_schema.as_struct(): from rich.console import Console from rich.table import Table as RichTable @@ -336,81 +324,51 @@ def compare_and_rich_print_with_filter(table_schema: Schema, task_schema: Schema compare_and_rich_print_with_filter(table_schema, task_schema, filter_predicates) else: compare_and_rich_print(table_schema, task_schema) - + def _truncate_fields(table_schema: Schema, to_truncate: List[NestedField]) -> Schema: - to_truncate_fields_source_ids = set(nested_field.field_id for nested_field in to_truncate) + to_truncate_fields_source_ids = {nested_field.field_id for nested_field in to_truncate} truncated = [field for field in table_schema.fields if field.field_id not in to_truncate_fields_source_ids] return Schema(*truncated) -def _validate_static_overwrite_filter_expr_type(expr: BooleanExpression): - '''Validate whether expression only has 1)And 2)IsNull and 3)EqualTo and break down the raw expression into IsNull and EqualTo. - ''' - from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema +def _validate_static_overwrite_filter_expr_type(expr: BooleanExpression) -> Tuple[Set[IsNull], Set[EqualTo[L]]]: + """Validate whether expression only has 1)And 2)IsNull and 3)EqualTo and break down the raw expression into IsNull and EqualTo.""" from collections import defaultdict - def _recursively_fetch_fields(expr, is_null_predicates, eq_to_predicates) -> None: - print(f"get expression as {expr=}") + def _recursively_fetch_fields( + expr: BooleanExpression, is_null_predicates: Set[IsNull], eq_to_predicates: Set[EqualTo[L]] + ) -> None: if isinstance(expr, EqualTo): + if not isinstance(expr.term, Reference): + raise ValueError(f"Unsupported unbound term {expr.term} in {expr}, expecting a refernce.") duplication_check[expr.term.name].add(expr) eq_to_predicates.add(expr) elif isinstance(expr, IsNull): + if not isinstance(expr.term, Reference): + raise ValueError(f"Unsupported unbound term {expr.term} in {expr}, expecting a refernce.") duplication_check[expr.term.name].add(expr) is_null_predicates.add(expr) elif isinstance(expr, And): _recursively_fetch_fields(expr.left, is_null_predicates, eq_to_predicates) _recursively_fetch_fields(expr.right, is_null_predicates, eq_to_predicates) else: - raise ValueError(f"static overwrite partitioning filter can only be isequalto, is null, and, alwaysTrue, but get {expr=}") + raise ValueError( + f"static overwrite partitioning filter can only be isequalto, is null, and, alwaysTrue, but get {expr=}" + ) - duplication_check = defaultdict(set) - is_null_predicates = set() - eq_to_predicates = set() + duplication_check: Dict[str, Set[Union[IsNull, EqualTo[L]]]] = defaultdict(set) + is_null_predicates: Set[IsNull] = set() + eq_to_predicates: Set[EqualTo[L]] = set() _recursively_fetch_fields(expr, is_null_predicates, eq_to_predicates) - for field, expr_set in duplication_check.items(): + for _, expr_set in duplication_check.items(): if len(expr_set) != 1: - raise ValueError(f"static overwrite partitioning filter has more than 1 different predicates with same field {expr_set}") - print(f"{is_null_predicates=}") - print(f"{eq_to_predicates=}") + raise ValueError( + f"static overwrite partitioning filter has more than 1 different predicates with same field {expr_set}" + ) # check fields don't step into itself, and do not step into each other, maybe we could move this to other 1+3(here) fields check return is_null_predicates, eq_to_predicates - - -# def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None: -# from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema - -# name_mapping = table_schema.name_mapping -# try: -# task_schema = pyarrow_to_schema(other_schema, name_mapping=name_mapping) -# except ValueError as e: -# other_schema = _pyarrow_to_schema_without_ids(other_schema) -# additional_names = set(other_schema.column_names) - set(table_schema.column_names) -# raise ValueError( -# f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)." -# ) from e - -# if table_schema.as_struct() != task_schema.as_struct(): -# from rich.console import Console -# from rich.table import Table as RichTable - -# console = Console(record=True) - -# rich_table = RichTable(show_header=True, header_style="bold") -# rich_table.add_column("") -# rich_table.add_column("Table field") -# rich_table.add_column("Dataframe field") - -# for lhs in table_schema.fields: -# try: -# rhs = task_schema.find_field(lhs.field_id) -# rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs)) -# except ValueError: -# rich_table.add_row("❌", str(lhs), "Missing") - -# console.print(rich_table) -# raise ValueError(f"Mismatch in fields:\n{console.export_text()}") class TableProperties: @@ -1388,21 +1346,19 @@ def overwrite(self, df: pa.Table, overwrite_filter: Union[str, BooleanExpression raise ValueError(f"Expected PyArrow table, got: {df}") overwrite_filter = _parse_row_filter(overwrite_filter) if not overwrite_filter == ALWAYS_TRUE: - bound_is_null_predicates, bound_eq_to_predicates = _validate_static_overwrite_filter(table_schema=self.schema(), overwrite_filter=overwrite_filter, spec=self.metadata.spec()) - print("----check fields in partition filter and in the dataframe are mutual exclusive but unioning to full schema") - - is_null_nested_fields = [bound_predicate.term.ref().field for bound_predicate in bound_is_null_predicates] - eq_to_nested_fields = [bound_predicate.term.ref().field for bound_predicate in bound_eq_to_predicates] - print(f"{is_null_nested_fields=}") - print(f"{eq_to_nested_fields=}") - #_check_schema(table_schema, other_schema, to_truncate = is_null_nested_fields + eq_to_nested_fields) - # bound_is_null_predicates - _check_schema(table_schema=self.schema(), other_schema=df.schema, filter_predicates = bound_is_null_predicates + bound_eq_to_predicates) + bound_is_null_predicates, bound_eq_to_predicates = _validate_static_overwrite_filter( + table_schema=self.schema(), overwrite_filter=overwrite_filter, spec=self.metadata.spec() + ) + + _check_schema( + table_schema=self.schema(), + other_schema=df.schema, + filter_predicates=bound_is_null_predicates.union(bound_eq_to_predicates), + ) df = _fill_in_df(df, bound_is_null_predicates, bound_eq_to_predicates) else: _check_schema(table_schema=self.schema(), other_schema=df.schema) - with self.transaction() as txn: with txn.update_snapshot().overwrite(overwrite_filter) as update_snapshot: # skip writing data files if the dataframe is empty @@ -2928,9 +2884,7 @@ def _get_partition_sort_order(partition_columns: list[str], reverse: bool = Fals def get_partition_columns(iceberg_table_metadata: TableMetadata, arrow_table: pa.Table) -> list[str]: - print(f"{arrow_table=}") arrow_table_cols = set(arrow_table.column_names) - print(f"{arrow_table_cols=}") partition_cols = [] for transform_field in iceberg_table_metadata.spec().fields: column_name = iceberg_table_metadata.schema().find_column_name(transform_field.source_id) @@ -3001,7 +2955,6 @@ def partition(iceberg_table_metadata: TableMetadata, arrow_table: pa.Table) -> I partition_columns = get_partition_columns(iceberg_table_metadata, arrow_table) sort_order_options = _get_partition_sort_order(partition_columns, reverse=False) - print(f"{sort_order_options=}") sorted_arrow_table_indices = pc.sort_indices(arrow_table, **sort_order_options) # group table by partition scheme diff --git a/tests/integration/test_partitioned_writes.py b/tests/integration/test_partitioned_writes.py index fee04ebe62..02086f5f63 100644 --- a/tests/integration/test_partitioned_writes.py +++ b/tests/integration/test_partitioned_writes.py @@ -446,7 +446,7 @@ def test_summaries_with_null(spark: SparkSession, session_catalog: Catalog, arro tbl.overwrite(arrow_table_with_null) tbl.append(arrow_table_with_null) valid_arrow_table_with_null_to_overwrite = arrow_table_with_null.drop(['int']) - tbl.overwrite(valid_arrow_table_with_null_to_overwrite, overwrite_filter="int=1") + tbl.overwrite(valid_arrow_table_with_null_to_overwrite, overwrite_filter="int=1") rows = spark.sql( f""" @@ -550,7 +550,7 @@ def test_data_files_with_table_partitioned_with_null( tbl.append(arrow_table_with_null) valid_arrow_table_with_null_to_overwrite = arrow_table_with_null.drop(['int']) - tbl.overwrite(valid_arrow_table_with_null_to_overwrite, overwrite_filter="int=1") + tbl.overwrite(valid_arrow_table_with_null_to_overwrite, overwrite_filter="int=1") # first append links to 1 manifest file (M1) # second append's manifest list links to 2 manifest files (M1, M2) @@ -672,24 +672,19 @@ def test_query_filter_after_append_overwrite_table_with_expr( properties={'format-version': '1'}, ) - - for i in range(3): + for _ in range(3): tbl.append(arrow_table_with_null) - print("this is ", i) spark.sql(f"refresh table {identifier}") spark.sql(f"select file_path from {identifier}.files").show(20, False) spark.sql(f"select * from {identifier}").show(20, False) - valid_arrow_table_with_null_to_overwrite = arrow_table_with_null.drop([part_col]) tbl.overwrite(valid_arrow_table_with_null_to_overwrite, expr) iceberg_table = session_catalog.load_table(identifier=identifier) spark.sql(f"refresh table {identifier}") - print("this is 3") spark.sql(f"select file_path from {identifier}.files").show(20, False) spark.sql(f"select * from {identifier}").show(20, False) - + assert iceberg_table.scan().to_arrow().num_rows == 9 assert iceberg_table.scan(row_filter=expr).to_arrow().num_rows == 3 - diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 116ab53ca6..d72a8e1534 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -15,37 +15,32 @@ # specific language governing permissions and limitations # under the License. # pylint:disable=redefined-outer-name +import re import uuid from copy import copy -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Generic, - Iterable, - List, - Literal, - Optional, - Set, - Tuple, - TypeVar, - Union, -) +from typing import Any, Dict, Set, Union import pyarrow as pa import pytest from pydantic import ValidationError from sortedcontainers import SortedList -import re from pyiceberg.catalog.noop import NoopCatalog from pyiceberg.exceptions import CommitFailedException from pyiceberg.expressions import ( AlwaysTrue, And, + BoundEqualTo, + BoundIsNull, + BoundPredicate, EqualTo, In, + IsNull, + LessThan, + NotEqualTo, + NotNull, + Or, + Reference, ) from pyiceberg.io import PY_IO_IMPL, load_file_io from pyiceberg.manifest import ( @@ -77,15 +72,14 @@ Table, UpdateSchema, _apply_table_update, + _bind_and_validate_static_overwrite_filter_predicate, _check_schema, + _fill_in_df, _match_deletes_to_data_file, _TableMetadataUpdateContext, + _validate_static_overwrite_filter_expr_type, update_table_metadata, verify_table_already_sorted, - _validate_static_overwrite_filter_expr_type, - _validate_static_overwrite_filter_field, - _check_schema, - _fill_in_df ) from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadataUtil, TableMetadataV2, _generate_snapshot_id from pyiceberg.table.snapshots import ( @@ -101,6 +95,7 @@ SortOrder, ) from pyiceberg.transforms import BucketTransform, IdentityTransform, TruncateTransform +from pyiceberg.typedef import L from pyiceberg.types import ( BinaryType, BooleanType, @@ -1032,46 +1027,9 @@ def test_correct_schema() -> None: assert "Snapshot not found: -1" in str(exc_info.value) -from pyiceberg.expressions import ( - AlwaysFalse, - AlwaysTrue, - And, - BooleanExpression, - BoundEqualTo, - BoundGreaterThan, - BoundGreaterThanOrEqual, - BoundIn, - BoundIsNaN, - BoundIsNull, - BoundLessThan, - BoundLessThanOrEqual, - BoundNotEqualTo, - BoundNotIn, - BoundNotNaN, - BoundNotNull, - BoundReference, - EqualTo, - GreaterThan, - GreaterThanOrEqual, - In, - IsNaN, - IsNull, - LessThan, - LessThanOrEqual, - Not, - NotEqualTo, - NotIn, - NotNaN, - NotNull, - Or, - Reference, - UnboundPredicate, -) - -#_validate_static_overwrite_filter_field -def test__validate_static_overwrite_filter_field_fail_on_non_schema_fields_in_filter()-> None: #pred: BooleanExpression, table_schema: Schema, spec:PartitionSpec)-> None: - # todo: is it possible to make boolean expression more specific, like bound expression? no it is not bound yet +# _bind_and_validate_static_overwrite_filter_predicate +def test__bind_and_validate_static_overwrite_filter_predicate_fails_on_non_schema_fields_in_filter() -> None: test_schema = Schema( NestedField(field_id=1, name="foo", field_type=StringType(), required=False), NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), @@ -1080,46 +1038,53 @@ def test__validate_static_overwrite_filter_field_fail_on_non_schema_fields_in_fi identifier_field_ids=[2], ) pred = EqualTo(Reference("not a field"), "hello") - partition_spec=PartitionSpec( + partition_spec = PartitionSpec( PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="test_part_col") ) - with pytest.raises(ValueError, match=f"Could not find field with name {pred.term.name}, case_sensitive=True"): - _validate_static_overwrite_filter_field(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) + with pytest.raises(ValueError, match="Could not find field with name not a field, case_sensitive=True"): + _bind_and_validate_static_overwrite_filter_predicate(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) -@pytest.mark.zy -def test_mine(table_schema_simple) -> None: - pred = IsNull(Reference("bar")) - pred.bind(table_schema_simple) - print("xxx", pred.term.name) - from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema - pa_table = pa.table( - {"bar": [1, 2, 3], "foo": ["a", "b", "c"], "baz": [True, False, None]}, - ) - name_mapping = table_schema_simple.name_mapping - print("xxxx!", pyarrow_to_schema(pa_table.schema, name_mapping=name_mapping)) +# @pytest.mark.zy +# def test_mine(table_schema_simple) -> None: +# pred = IsNull("bar") +# print("xxxx!", pred.term) +# pred.bind(table_schema_simple) -def test__fill_in_df(table_schema_simple) -> None: - df = pa.table( - {"baz": [True, False, None]} - ) +# from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema +# pa_table = pa.table( +# {"bar": [1, 2, 3], "foo": ["a", "b", "c"], "baz": [True, False, None]}, +# ) +# name_mapping = table_schema_simple.name_mapping +# print("xxxx!", pyarrow_to_schema(pa_table.schema, name_mapping=name_mapping)) + + +def test__fill_in_df(table_schema_simple: Schema) -> None: + df = pa.table({"baz": [True, False, None]}) unbound_is_null_predicates = [IsNull(Reference("foo"))] unbound_eq_to_predicates = [EqualTo(Reference("bar"), 3)] - bound_is_null_predicates = [unbound_predicate.bind(table_schema_simple) for unbound_predicate in unbound_is_null_predicates] - bound_eq_to_predicates = [unbound_predicate.bind(table_schema_simple) for unbound_predicate in unbound_eq_to_predicates] - filled_df = _fill_in_df(df = df, bound_is_null_predicates = bound_is_null_predicates, bound_eq_to_predicates = bound_eq_to_predicates) - print(f"{filled_df=}") + bound_is_null_predicates: Set[BoundIsNull[Any]] = { + unbound_predicate.bind(table_schema_simple) # type: ignore # because bind returns super type and python could not downcast implicitly using type annotation + for unbound_predicate in unbound_is_null_predicates + } + bound_eq_to_predicates: Set[BoundEqualTo[Any]] = { + unbound_predicate.bind(table_schema_simple) # type: ignore # because bind returns super type and python could not downcast implicitly using type annotation + for unbound_predicate in unbound_eq_to_predicates + } + filled_df = _fill_in_df( + df=df, + bound_is_null_predicates=bound_is_null_predicates, + bound_eq_to_predicates=bound_eq_to_predicates, + ) expected = pa.table( - { - "baz": [True, False, None], - "foo": [None, None, None], - "bar": [3,3,3] - }, schema=pa.schema([pa.field('baz', pa.bool_()), pa.field('foo', pa.string()), pa.field('bar', pa.int32())])) + {"baz": [True, False, None], "foo": [None, None, None], "bar": [3, 3, 3]}, + schema=pa.schema([pa.field('baz', pa.bool_()), pa.field('foo', pa.string()), pa.field('bar', pa.int32())]), + ) assert filled_df == expected -def test__validate_static_overwrite_filter_field_fail_on_non_part_fields_in_filter()-> None: #pred: BooleanExpression, table_schema: Schema, spec:PartitionSpec)-> None: - # todo: is it possible to make boolean expression more specific, like bound expression? no it is not bound yet + +def test__bind_and_validate_static_overwrite_filter_predicate_fails_on_non_part_fields_in_filter() -> None: test_schema = Schema( NestedField(field_id=1, name="foo", field_type=StringType(), required=False), NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), @@ -1128,16 +1093,19 @@ def test__validate_static_overwrite_filter_field_fail_on_non_part_fields_in_filt identifier_field_ids=[2], ) pred = EqualTo(Reference("foo"), "hello") - partition_spec=PartitionSpec( - PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="bar") - ) + partition_spec = PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="bar")) import re - with pytest.raises(ValueError, match=re.escape("get len(part_fields)=0, not 1, if this number is 0, indicating the static filter is not within the partition fields, which is invalid")): - _validate_static_overwrite_filter_field(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) -#to do add one test that the partition fields passed -def test__validate_static_overwrite_filter_field_fail_on_non_identity_transorm_filter()-> None: #pred: BooleanExpression, table_schema: Schema, spec:PartitionSpec)-> None: - # todo: is it possible to make boolean expression more specific, like bound expression? no it is not bound yet + with pytest.raises( + ValueError, + match=re.escape( + "Get 0 partition fields from filter predicate EqualTo(term=Reference(name='foo'), literal=literal('hello')), expecting 1." + ), + ): + _bind_and_validate_static_overwrite_filter_predicate(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) + + +def test__bind_and_validate_static_overwrite_filter_predicate_fails_on_non_identity_transorm_filter() -> None: test_schema = Schema( NestedField(field_id=1, name="foo", field_type=StringType(), required=False), NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), @@ -1146,17 +1114,21 @@ def test__validate_static_overwrite_filter_field_fail_on_non_identity_transorm_f identifier_field_ids=[2], ) pred = EqualTo(Reference("foo"), "hello") - partition_spec=PartitionSpec( + partition_spec = PartitionSpec( PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="bar"), - PartitionField(source_id=1, field_id=1002, transform=TruncateTransform(2), name="foo_trunc") + PartitionField(source_id=1, field_id=1002, transform=TruncateTransform(2), name="foo_trunc"), ) # import re - with pytest.raises(ValueError, match="static overwrite partition filter can only apply to partition fields which are without hidden transform, but get.*"): - _validate_static_overwrite_filter_field(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) + with pytest.raises( + ValueError, + match="static overwrite partition filter can only apply to partition fields which are without hidden transform, but get.*", + ): + _bind_and_validate_static_overwrite_filter_predicate(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) + -# combine this with above -def test__validate_static_overwrite_filter_field_succeed_on_an_identity_field_although_table_has_hidden_partition()-> None: #pred: BooleanExpression, table_schema: Schema, spec:PartitionSpec)-> None: - # todo: is it possible to make boolean expression more specific, like bound expression? no it is not bound yet +def test__bind_and_validate_static_overwrite_filter_predicate_succeeds_on_an_identity_transform_field_although_table_has_other_hidden_partition_fields() -> ( + None +): test_schema = Schema( NestedField(field_id=1, name="foo", field_type=StringType(), required=False), NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), @@ -1165,17 +1137,15 @@ def test__validate_static_overwrite_filter_field_succeed_on_an_identity_field_al identifier_field_ids=[2], ) pred = EqualTo(Reference("bar"), 3) - partition_spec=PartitionSpec( + partition_spec = PartitionSpec( PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="bar"), - PartitionField(source_id=1, field_id=1002, transform=TruncateTransform(2), name="foo_trunc") + PartitionField(source_id=1, field_id=1002, transform=TruncateTransform(2), name="foo_trunc"), ) - # import re - #with pytest.raises(ValueError, match="static overwrite partition filter can only apply to partition fields which are without hidden transform, but get.*"): - _validate_static_overwrite_filter_field(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) + + _bind_and_validate_static_overwrite_filter_predicate(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) -def test__validate_static_overwrite_filter_field_fail_to_bind()-> None: #pred: BooleanExpression, table_schema: Schema, spec:PartitionSpec)-> None: - # todo: is it possible to make boolean expression more specific, like bound expression? no it is not bound yet +def test__bind_and_validate_static_overwrite_filter_predicate_fails_to_bind() -> None: test_schema = Schema( NestedField(field_id=1, name="foo", field_type=StringType(), required=False), NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), @@ -1184,16 +1154,15 @@ def test__validate_static_overwrite_filter_field_fail_to_bind()-> None: #pred: B identifier_field_ids=[2], ) pred = EqualTo(Reference("bar"), "an incompatible type") - partition_spec=PartitionSpec( + partition_spec = PartitionSpec( PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="bar"), - PartitionField(source_id=1, field_id=1002, transform=TruncateTransform(2), name="foo_trunc") + PartitionField(source_id=1, field_id=1002, transform=TruncateTransform(2), name="foo_trunc"), ) with pytest.raises(ValueError, match="Could not convert an incompatible type into a int"): - _validate_static_overwrite_filter_field(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) + _bind_and_validate_static_overwrite_filter_predicate(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) -def test__validate_static_overwrite_filter_field_fail_to_bind_due_to_non_nullable()-> None: #pred: BooleanExpression, table_schema: Schema, spec:PartitionSpec)-> None: - # todo: is it possible to make boolean expression more specific, like bound expression? no it is not bound yet +def test__bind_and_validate_static_overwrite_filter_predicate_fails_to_bind_due_to_non_nullable() -> None: test_schema = Schema( NestedField(field_id=1, name="foo", field_type=StringType(), required=False), NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), @@ -1202,67 +1171,60 @@ def test__validate_static_overwrite_filter_field_fail_to_bind_due_to_non_nullabl identifier_field_ids=[2], ) pred = IsNull(Reference("bar")) - partition_spec=PartitionSpec( + partition_spec = PartitionSpec( PartitionField(source_id=3, field_id=1001, transform=IdentityTransform(), name="baz"), - PartitionField(source_id=1, field_id=1002, transform=TruncateTransform(2), name="foo_trunc") + PartitionField(source_id=1, field_id=1002, transform=TruncateTransform(2), name="foo_trunc"), ) - with pytest.raises(ValueError, match=re.escape("Static overwriting with part of the explicit partition filter not meaningful (e.g. specifing a non-nullable partition field to be null)")): - _validate_static_overwrite_filter_field(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) - - + with pytest.raises( + ValueError, + match=re.escape( + "Static overwriting with part of the explicit partition filter not meaningful (specifing a non-nullable partition field to be null)" + ), + ): + _bind_and_validate_static_overwrite_filter_predicate(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) -def test__check_schema_with_filter_succeed(): - table_schema: Schema = Schema( - NestedField(field_id=1, name="foo", field_type=StringType(), required=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), - schema_id=1, - identifier_field_ids=[2], - ) +def test__check_schema_with_filter_succeed(iceberg_schema_simple: Schema) -> None: other_schema: pa.Schema = pa.schema([ pa.field('foo', pa.string()), pa.field('baz', pa.bool_()), ]) unbound_preds = [EqualTo(Reference("bar"), 15)] - filter_predicates = [pred.bind(table_schema) for pred in unbound_preds] - _check_schema(table_schema, other_schema, filter_predicates = filter_predicates) + filter_predicates: Set[BoundPredicate[int]] = {pred.bind(iceberg_schema_simple) for pred in unbound_preds} + # upcast set[BoundLiteralPredicate[int]] to set[BoundPredicate[int]] + # because _check_schema expects set[BoundPredicate[int]] and set is not covariant + # (meaning although BoundLiteralPredicate is subclass of BoundPredicate, list[BoundLiteralPredicate] is not that of list[BoundPredicate]) + _check_schema(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) -def test__check_schema_with_filter_succeed_on_pyarrow_table_with_random_column_order() -> None: - table_schema: Schema = Schema( - NestedField(field_id=1, name="foo", field_type=StringType(), required=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), - schema_id=1, - identifier_field_ids=[2], - ) + +def test__check_schema_with_filter_succeed_on_pyarrow_table_with_random_column_order(iceberg_schema_simple: Schema) -> None: other_schema: pa.Schema = pa.schema([ pa.field('baz', pa.bool_()), pa.field('foo', pa.string()), ]) unbound_preds = [EqualTo(Reference("bar"), 15)] - filter_predicates = [pred.bind(table_schema) for pred in unbound_preds] - _check_schema(table_schema, other_schema, filter_predicates = filter_predicates) + filter_predicates: Set[BoundPredicate[int]] = {pred.bind(iceberg_schema_simple) for pred in unbound_preds} + # upcast set[BoundLiteralPredicate[int]] to set[BoundPredicate[int]] + # because _check_schema expects set[BoundPredicate[int]] and set is not covariant + # (meaning although BoundLiteralPredicate is subclass of BoundPredicate, list[BoundLiteralPredicate] is not that of list[BoundPredicate]) + _check_schema(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) -def test__check_schema_with_filter_fail_on_missing_field(): - table_schema: Schema = Schema( - NestedField(field_id=1, name="foo", field_type=StringType(), required=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), - schema_id=1, - identifier_field_ids=[2], - ) +def test__check_schema_with_filter_fails_on_missing_field(iceberg_schema_simple: Schema) -> None: other_schema: pa.Schema = pa.schema([ pa.field('baz', pa.bool_()), ]) unbound_preds = [EqualTo(Reference("bar"), 15)] - filter_predicates = [pred.bind(table_schema) for pred in unbound_preds] + filter_predicates: Set[BoundPredicate[int]] = {pred.bind(iceberg_schema_simple) for pred in unbound_preds} + # upcast set[BoundLiteralPredicate[int]] to set[BoundPredicate[int]] + # because _check_schema expects set[BoundPredicate[int]] and set is not covariant + # (meaning although BoundLiteralPredicate is subclass of BoundPredicate, list[BoundLiteralPredicate] is not that of list[BoundPredicate]) - expected = re.escape('''Mismatch in fields: + expected = re.escape( + """Mismatch in fields: ┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ ┃ Table field ┃ Dataframe field ┃ Overwrite filter field ┃ ┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩ @@ -1270,28 +1232,25 @@ def test__check_schema_with_filter_fail_on_missing_field(): │ ✅ │ 2: bar: required int │ N/A │ bar │ │ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │ N/A │ └────┴──────────────────────────┴──────────────────────────┴────────────────────────┘ -''') +""" + ) with pytest.raises(ValueError, match=expected): - _check_schema(table_schema, other_schema, filter_predicates = filter_predicates) - + _check_schema(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) -def test__check_schema_with_filter_fail_on_nullability_mismatch(): - table_schema: Schema = Schema( - NestedField(field_id=1, name="foo", field_type=StringType(), required=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), - schema_id=1, - identifier_field_ids=[2], - ) +def test__check_schema_with_filter_fails_on_nullability_mismatch(iceberg_schema_simple: Schema) -> None: other_schema: pa.Schema = pa.schema([ pa.field('foo', pa.string()), pa.field('bar', pa.int32()), ]) unbound_preds = [EqualTo(Reference("baz"), True)] - filter_predicates = [pred.bind(table_schema) for pred in unbound_preds] - expected = re.escape('''Mismatch in fields: + filter_predicates: Set[BoundPredicate[bool]] = {pred.bind(iceberg_schema_simple) for pred in unbound_preds} + # upcast set[BoundLiteralPredicate[bool]] to set[BoundPredicate[bool]] + # because _check_schema expects set[BoundPredicate[bool]] and set is not covariant + # (meaning although BoundLiteralPredicate is subclass of BoundPredicate, list[BoundLiteralPredicate] is not that of list[BoundPredicate]) + expected = re.escape( + """Mismatch in fields: ┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ ┃ Table field ┃ Dataframe field ┃ Overwrite filter field ┃ ┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩ @@ -1299,27 +1258,25 @@ def test__check_schema_with_filter_fail_on_nullability_mismatch(): │ ❌ │ 2: bar: required int │ 2: bar: optional int │ N/A │ │ ✅ │ 3: baz: optional boolean │ N/A │ baz │ └────┴──────────────────────────┴─────────────────────────┴────────────────────────┘ -''') +""" + ) with pytest.raises(ValueError, match=expected): - _check_schema(table_schema, other_schema, filter_predicates = filter_predicates) + _check_schema(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) -def test__check_schema_with_filter_fail_on_type_mismatch(): - table_schema: Schema = Schema( - NestedField(field_id=1, name="foo", field_type=StringType(), required=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), - schema_id=1, - identifier_field_ids=[2], - ) +def test__check_schema_with_filter_fails_on_type_mismatch(iceberg_schema_simple: Schema) -> None: other_schema: pa.Schema = pa.schema([ pa.field('foo', pa.string()), pa.field('bar', pa.string(), nullable=False), ]) unbound_preds = [EqualTo(Reference("baz"), True)] - filter_predicates = [pred.bind(table_schema) for pred in unbound_preds] - expected = re.escape('''Mismatch in fields: + filter_predicates: Set[BoundPredicate[bool]] = {pred.bind(iceberg_schema_simple) for pred in unbound_preds} + # upcast set[BoundLiteralPredicate[bool]] to set[BoundPredicate[bool]] + # because _check_schema expects set[BoundPredicate[bool]] and set is not covariant + # (meaning although BoundLiteralPredicate is subclass of BoundPredicate, list[BoundLiteralPredicate] is not that of list[BoundPredicate]) + expected = re.escape( + """Mismatch in fields: ┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ ┃ Table field ┃ Dataframe field ┃ Overwrite filter field ┃ ┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩ @@ -1327,27 +1284,22 @@ def test__check_schema_with_filter_fail_on_type_mismatch(): │ ❌ │ 2: bar: required int │ 2: bar: required string │ N/A │ │ ✅ │ 3: baz: optional boolean │ N/A │ baz │ └────┴──────────────────────────┴─────────────────────────┴────────────────────────┘ -''') +""" + ) with pytest.raises(ValueError, match=expected): - _check_schema(table_schema, other_schema, filter_predicates = filter_predicates) + _check_schema(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) -def test__check_schema_with_field_fail_due_to_filter_and_dataframe_hold_same_field(): - table_schema: Schema = Schema( - NestedField(field_id=1, name="foo", field_type=StringType(), required=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), - schema_id=1, - identifier_field_ids=[2], - ) +def test__check_schema_with_filter_fails_due_to_filter_and_dataframe_holding_same_field(iceberg_schema_simple: Schema) -> None: other_schema: pa.Schema = pa.schema([ pa.field('foo', pa.string()), pa.field('bar', pa.int32(), nullable=False), ]) unbound_preds = [IsNull(Reference("foo")), EqualTo(Reference("baz"), True)] - filter_predicates = [pred.bind(table_schema) for pred in unbound_preds] - expected = re.escape('''Mismatch in fields: + filter_predicates: Set[BoundPredicate[Any]] = {pred.bind(iceberg_schema_simple) for pred in unbound_preds} # type: ignore # bind returns BoundLiteralPredicate and BoundUnaryPredicate and thus set has type inferred as Set[BooleanExpression] which could not be downcast to Set[BoundPredicate[Any]] implicitly using :. + expected = re.escape( + """Mismatch in fields: ┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ ┃ Table field ┃ Dataframe field ┃ Overwrite filter field ┃ ┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩ @@ -1355,42 +1307,64 @@ def test__check_schema_with_field_fail_due_to_filter_and_dataframe_hold_same_fie │ ✅ │ 2: bar: required int │ 2: bar: required int │ N/A │ │ ✅ │ 3: baz: optional boolean │ N/A │ baz │ └────┴──────────────────────────┴─────────────────────────┴────────────────────────┘ -''') +""" + ) with pytest.raises(ValueError, match=expected): - _check_schema(table_schema, other_schema, filter_predicates = filter_predicates) + _check_schema(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) @pytest.mark.parametrize( - "pred, raises, is_null_list, eq_to_list", + "pred, raises, is_null_preds, eq_to_preds", [ (EqualTo(Reference("foo"), "hello"), False, {}, {EqualTo(Reference("foo"), "hello")}), (IsNull(Reference("foo")), False, {IsNull(Reference("foo"))}, {}), - (And(IsNull(Reference("foo")),EqualTo(Reference("boo"), "hello")), False, {IsNull(Reference("foo"))}, {EqualTo(Reference("boo"), "hello")}), + ( + And(IsNull(Reference("foo")), EqualTo(Reference("boo"), "hello")), + False, + {IsNull(Reference("foo"))}, + {EqualTo(Reference("boo"), "hello")}, + ), (NotNull, True, {}, {}), (NotEqualTo, True, {}, {}), (LessThan(Reference("foo"), 5), True, {}, {}), - (Or(IsNull(Reference("foo")),EqualTo(Reference("foo"), "hello")), True, {}, {}), - (And(EqualTo(Reference("foo"), "hello"), And(IsNull(Reference("baz")), EqualTo(Reference("boo"), "hello"))), False, {IsNull(Reference("baz"))}, {EqualTo(Reference("foo"), "hello"), EqualTo(Reference("boo"), "hello")}), + (Or(IsNull(Reference("foo")), EqualTo(Reference("foo"), "hello")), True, {}, {}), + ( + And(EqualTo(Reference("foo"), "hello"), And(IsNull(Reference("baz")), EqualTo(Reference("boo"), "hello"))), + False, + {IsNull(Reference("baz"))}, + {EqualTo(Reference("foo"), "hello"), EqualTo(Reference("boo"), "hello")}, + ), # Below are crowd-crush tests: a same field can only be with same literal/null, not different literals or both literal and null # A false crush: when there are duplicated isnull/equalto, the collector should deduplicate them. - (And(EqualTo(Reference("foo"), "hello"), EqualTo(Reference("foo"), "hello")), False, {}, {EqualTo(Reference("foo"), "hello")}), + ( + And(EqualTo(Reference("foo"), "hello"), EqualTo(Reference("foo"), "hello")), + False, + {}, + {EqualTo(Reference("foo"), "hello")}, + ), # When crush happens - (And(EqualTo(Reference("foo"), "hello"), EqualTo(Reference("foo"), "bye")), True, {}, {EqualTo(Reference("foo"), "hello"), EqualTo(Reference("foo"), "bye")}), - (And(EqualTo(Reference("foo"), "hello"), IsNull(Reference("foo"))), True, {IsNull(Reference("foo"))}, {}) + ( + And(EqualTo(Reference("foo"), "hello"), EqualTo(Reference("foo"), "bye")), + True, + {}, + {EqualTo(Reference("foo"), "hello"), EqualTo(Reference("foo"), "bye")}, + ), + (And(EqualTo(Reference("foo"), "hello"), IsNull(Reference("foo"))), True, {IsNull(Reference("foo"))}, {}), ], ) -def test__validate_static_overwrite_filter_expr_type(pred, raises, is_null_list, eq_to_list)-> None: +def test__validate_static_overwrite_filter_expr_type( + pred: Union[IsNull, EqualTo[Any]], raises: bool, is_null_preds: Set[IsNull], eq_to_preds: Set[EqualTo[L]] +) -> None: if raises: - with pytest.raises(ValueError): # match=expected): to do + with pytest.raises(ValueError): res = _validate_static_overwrite_filter_expr_type(pred) else: res = _validate_static_overwrite_filter_expr_type(pred) - assert set([str(e) for e in res[0]]) == set([str(e) for e in is_null_list]) - assert set([str(e) for e in res[1]]) == set([str(e) for e in eq_to_list]) + assert {str(e) for e in res[0]} == {str(e) for e in is_null_preds} + assert {str(e) for e in res[1]} == {str(e) for e in eq_to_preds} - -def test_schema_mismatch_type(table_schema_simple: Schema) -> None: +def test_check_schema_mismatch_type(table_schema_simple: Schema) -> None: other_schema = pa.schema(( pa.field("foo", pa.string(), nullable=True), pa.field("bar", pa.decimal128(18, 6), nullable=False), @@ -1411,7 +1385,7 @@ def test_schema_mismatch_type(table_schema_simple: Schema) -> None: _check_schema(table_schema_simple, other_schema) -def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None: +def test_check_schema_mismatch_nullability(table_schema_simple: Schema) -> None: other_schema = pa.schema(( pa.field("foo", pa.string(), nullable=True), pa.field("bar", pa.int32(), nullable=True), @@ -1432,7 +1406,7 @@ def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None: _check_schema(table_schema_simple, other_schema) -def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None: +def test_check_schema_mismatch_missing_field(table_schema_simple: Schema) -> None: other_schema = pa.schema(( pa.field("foo", pa.string(), nullable=True), pa.field("baz", pa.bool_(), nullable=True), @@ -1451,7 +1425,8 @@ def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None: with pytest.raises(ValueError, match=expected): _check_schema(table_schema_simple, other_schema) -def test_schema_succeed(table_schema_simple: Schema) -> None: + +def test_check_schema_succeed(table_schema_simple: Schema) -> None: other_schema = pa.schema(( pa.field("foo", pa.string(), nullable=True), pa.field("bar", pa.int32(), nullable=False), @@ -1460,7 +1435,8 @@ def test_schema_succeed(table_schema_simple: Schema) -> None: _check_schema(table_schema_simple, other_schema) -def test_schema_succeed_on_pyarrow_table_reversed_order(table_schema_simple: Schema) -> None: + +def test_schema_succeed_on_pyarrow_table_reversed_column_order(table_schema_simple: Schema) -> None: other_schema = pa.schema(( pa.field("baz", pa.bool_(), nullable=True), pa.field("bar", pa.int32(), nullable=False), From ea7b138534559454a4c0885130ca3294ba15d745 Mon Sep 17 00:00:00 2001 From: Adrian Qin <147659252+jqin61@users.noreply.github.com> Date: Fri, 22 Mar 2024 06:54:43 +0000 Subject: [PATCH 4/6] clean up --- pyiceberg/expressions/__init__.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pyiceberg/expressions/__init__.py b/pyiceberg/expressions/__init__.py index 29503eff26..8829b4a3fb 100644 --- a/pyiceberg/expressions/__init__.py +++ b/pyiceberg/expressions/__init__.py @@ -373,10 +373,7 @@ def as_bound(self) -> Type[BoundPredicate[L]]: ... class UnaryPredicate(UnboundPredicate[Any], ABC): def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundUnaryPredicate[Any]: bound_term = self.term.bind(schema, case_sensitive) - print(f"{bound_term=}") - res = self.as_bound(bound_term) - print(f"{res=}") - return res + return self.as_bound(bound_term) def __repr__(self) -> str: """Return the string representation of the UnaryPredicate class.""" From f5c5e68b65125e4ae0316e2514bd379c7cbacd56 Mon Sep 17 00:00:00 2001 From: Adrian Qin <147659252+jqin61@users.noreply.github.com> Date: Fri, 22 Mar 2024 15:56:20 +0000 Subject: [PATCH 5/6] more clean up --- pyiceberg/table/__init__.py | 232 ++++++++++++++++++++++++------------ tests/table/test_init.py | 29 ++--- 2 files changed, 161 insertions(+), 100 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 0cb4f60468..a730b2bab4 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -148,8 +148,6 @@ _JAVA_LONG_MAX = 9223372036854775807 -# to do: make the types not expression but unbound predicates, this should be more precise, we already know it could only be unboundisnull and unboundequalto -# actually could make it union[isnull, equalto] and make the return as union[boundisnull,boundequalto] def _bind_and_validate_static_overwrite_filter_predicate( unbound_expr: Union[IsNull, EqualTo[L]], table_schema: Schema, spec: PartitionSpec ) -> Union[BoundIsNull[L], BoundEqualTo[L]]: @@ -240,13 +238,7 @@ def _fill_in_df( return df -# linttodo: break this down into 2 functions -def _check_schema( - table_schema: Schema, other_schema: "pa.Schema", filter_predicates: Set[BoundPredicate[L]] | None = None -) -> None: - if filter_predicates is None: - filter_predicates = set() - +def _arrow_schema_to_iceberg_schema_with_field_ids(table_schema: Schema, other_schema: "pa.Schema") -> Schema: from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema name_mapping = table_schema.name_mapping @@ -258,72 +250,157 @@ def _check_schema( raise ValueError( f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)." ) from e + return task_schema - def compare_and_rich_print(table_schema: Schema, task_schema: Schema) -> None: - sorted_table_schema = Schema(*sorted(table_schema.fields, key=lambda field: field.field_id)) - sorted_task_schema = Schema(*sorted(task_schema.fields, key=lambda field: field.field_id)) - if sorted_table_schema.as_struct() != sorted_task_schema.as_struct(): - from rich.console import Console - from rich.table import Table as RichTable - - console = Console(record=True) - rich_table = RichTable(show_header=True, header_style="bold") - rich_table.add_column("") - rich_table.add_column("Table field") - rich_table.add_column("Dataframe field") - - for lhs in table_schema.fields: +def _check_schema_with_filter_predicates( + table_schema: Schema, other_schema: "pa.Schema", filter_predicates: Set[BoundPredicate[L]] +) -> None: + task_schema = _arrow_schema_to_iceberg_schema_with_field_ids(table_schema, other_schema) + + filter_fields = [bound_predicate.term.ref().field for bound_predicate in filter_predicates] + remaining_schema = _truncate_fields(table_schema, to_truncate=filter_fields) + sorted_remaining_schema = Schema(*sorted(remaining_schema.fields, key=lambda field: field.field_id)) + sorted_task_schema = Schema(*sorted(task_schema.fields, key=lambda field: field.field_id)) + if sorted_remaining_schema.as_struct() != sorted_task_schema.as_struct(): + from rich.console import Console + from rich.table import Table as RichTable + + console = Console(record=True) + + rich_table = RichTable(show_header=True, header_style="bold") + rich_table.add_column("") + rich_table.add_column("Table field") + rich_table.add_column("Dataframe field") + rich_table.add_column("Overwrite filter field") + + filter_field_names = [field.name for field in filter_fields] + for lhs in table_schema.fields: + if lhs.name in filter_field_names: try: rhs = task_schema.find_field(lhs.field_id) - rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs)) + rich_table.add_row("❌", str(lhs), str(rhs), lhs.name) except ValueError: - rich_table.add_row("❌", str(lhs), "Missing") + rich_table.add_row("✅", str(lhs), "N/A", lhs.name) + else: + try: + rhs = task_schema.find_field(lhs.field_id) + rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs), "N/A") + except ValueError: + rich_table.add_row("❌", str(lhs), "Missing", "N/A") - console.print(rich_table) - raise ValueError(f"Mismatch in fields:\n{console.export_text()}") + console.print(rich_table) + raise ValueError(f"Mismatch in fields:\n{console.export_text()}") - def compare_and_rich_print_with_filter( - table_schema: Schema, task_schema: Schema, filter_predicates: Set[BoundPredicate[L]] - ) -> None: - filter_fields = [bound_predicate.term.ref().field for bound_predicate in filter_predicates] - remaining_schema = _truncate_fields(table_schema, to_truncate=filter_fields) - sorted_remaining_schema = Schema(*sorted(remaining_schema.fields, key=lambda field: field.field_id)) - sorted_task_schema = Schema(*sorted(task_schema.fields, key=lambda field: field.field_id)) - if sorted_remaining_schema.as_struct() != sorted_task_schema.as_struct(): - from rich.console import Console - from rich.table import Table as RichTable - - console = Console(record=True) - - rich_table = RichTable(show_header=True, header_style="bold") - rich_table.add_column("") - rich_table.add_column("Table field") - rich_table.add_column("Dataframe field") - rich_table.add_column("Overwrite filter field") - - filter_field_names = [field.name for field in filter_fields] - for lhs in table_schema.fields: - if lhs.name in filter_field_names: - try: - rhs = task_schema.find_field(lhs.field_id) - rich_table.add_row("❌", str(lhs), str(rhs), lhs.name) - except ValueError: - rich_table.add_row("✅", str(lhs), "N/A", lhs.name) - else: - try: - rhs = task_schema.find_field(lhs.field_id) - rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs), "N/A") - except ValueError: - rich_table.add_row("❌", str(lhs), "Missing", "N/A") - console.print(rich_table) - raise ValueError(f"Mismatch in fields:\n{console.export_text()}") +def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None: + task_schema = _arrow_schema_to_iceberg_schema_with_field_ids(table_schema, other_schema) - if len(filter_predicates) != 0: - compare_and_rich_print_with_filter(table_schema, task_schema, filter_predicates) - else: - compare_and_rich_print(table_schema, task_schema) + if table_schema.as_struct() != task_schema.as_struct(): + from rich.console import Console + from rich.table import Table as RichTable + + console = Console(record=True) + + rich_table = RichTable(show_header=True, header_style="bold") + rich_table.add_column("") + rich_table.add_column("Table field") + rich_table.add_column("Dataframe field") + + for lhs in table_schema.fields: + try: + rhs = task_schema.find_field(lhs.field_id) + rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs)) + except ValueError: + rich_table.add_row("❌", str(lhs), "Missing") + + console.print(rich_table) + raise ValueError(f"Mismatch in fields:\n{console.export_text()}") + + +# def _check_schema( +# table_schema: Schema, other_schema: "pa.Schema", filter_predicates: Set[BoundPredicate[L]] | None = None +# ) -> None: +# if filter_predicates is None: +# filter_predicates = set() + +# from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema + +# name_mapping = table_schema.name_mapping +# try: +# task_schema = pyarrow_to_schema(other_schema, name_mapping=name_mapping) +# except ValueError as e: +# other_schema = _pyarrow_to_schema_without_ids(other_schema) +# additional_names = set(other_schema.column_names) - set(table_schema.column_names) +# raise ValueError( +# f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)." +# ) from e + +# def compare_and_rich_print(table_schema: Schema, task_schema: Schema) -> None: +# sorted_table_schema = Schema(*sorted(table_schema.fields, key=lambda field: field.field_id)) +# sorted_task_schema = Schema(*sorted(task_schema.fields, key=lambda field: field.field_id)) +# if sorted_table_schema.as_struct() != sorted_task_schema.as_struct(): +# from rich.console import Console +# from rich.table import Table as RichTable + +# console = Console(record=True) + +# rich_table = RichTable(show_header=True, header_style="bold") +# rich_table.add_column("") +# rich_table.add_column("Table field") +# rich_table.add_column("Dataframe field") + +# for lhs in table_schema.fields: +# try: +# rhs = task_schema.find_field(lhs.field_id) +# rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs)) +# except ValueError: +# rich_table.add_row("❌", str(lhs), "Missing") + +# console.print(rich_table) +# raise ValueError(f"Mismatch in fields:\n{console.export_text()}") + +# def compare_and_rich_print_with_filter( +# table_schema: Schema, task_schema: Schema, filter_predicates: Set[BoundPredicate[L]] +# ) -> None: +# filter_fields = [bound_predicate.term.ref().field for bound_predicate in filter_predicates] +# remaining_schema = _truncate_fields(table_schema, to_truncate=filter_fields) +# sorted_remaining_schema = Schema(*sorted(remaining_schema.fields, key=lambda field: field.field_id)) +# sorted_task_schema = Schema(*sorted(task_schema.fields, key=lambda field: field.field_id)) +# if sorted_remaining_schema.as_struct() != sorted_task_schema.as_struct(): +# from rich.console import Console +# from rich.table import Table as RichTable + +# console = Console(record=True) + +# rich_table = RichTable(show_header=True, header_style="bold") +# rich_table.add_column("") +# rich_table.add_column("Table field") +# rich_table.add_column("Dataframe field") +# rich_table.add_column("Overwrite filter field") + +# filter_field_names = [field.name for field in filter_fields] +# for lhs in table_schema.fields: +# if lhs.name in filter_field_names: +# try: +# rhs = task_schema.find_field(lhs.field_id) +# rich_table.add_row("❌", str(lhs), str(rhs), lhs.name) +# except ValueError: +# rich_table.add_row("✅", str(lhs), "N/A", lhs.name) +# else: +# try: +# rhs = task_schema.find_field(lhs.field_id) +# rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs), "N/A") +# except ValueError: +# rich_table.add_row("❌", str(lhs), "Missing", "N/A") + +# console.print(rich_table) +# raise ValueError(f"Mismatch in fields:\n{console.export_text()}") + +# if len(filter_predicates) != 0: +# compare_and_rich_print_with_filter(table_schema, task_schema, filter_predicates) +# else: +# compare_and_rich_print(table_schema, task_schema) def _truncate_fields(table_schema: Schema, to_truncate: List[NestedField]) -> Schema: @@ -424,11 +501,11 @@ class PartitionProjector: def __init__( self, table_metadata: TableMetadata, - row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE, + row_filter: BooleanExpression = ALWAYS_TRUE, case_sensitive: bool = True, ): self.table_metadata = table_metadata - self.row_filter = _parse_row_filter(row_filter) + self.row_filter = _parse_row_filter(row_filter) # todo make it BooleanExpression self.case_sensitive = case_sensitive def _build_partition_projection(self, spec_id: int) -> BooleanExpression: @@ -1344,13 +1421,15 @@ def overwrite(self, df: pa.Table, overwrite_filter: Union[str, BooleanExpression if not isinstance(df, pa.Table): raise ValueError(f"Expected PyArrow table, got: {df}") + overwrite_filter = _parse_row_filter(overwrite_filter) + if not overwrite_filter == ALWAYS_TRUE: bound_is_null_predicates, bound_eq_to_predicates = _validate_static_overwrite_filter( table_schema=self.schema(), overwrite_filter=overwrite_filter, spec=self.metadata.spec() ) - _check_schema( + _check_schema_with_filter_predicates( table_schema=self.schema(), other_schema=df.schema, filter_predicates=bound_is_null_predicates.union(bound_eq_to_predicates), @@ -2678,17 +2757,16 @@ class _MergingSnapshotProducer(UpdateTableMetadata["_MergingSnapshotProducer"]): _io: FileIO _deleted_data_files: Optional[DeletedDataFiles] - # _manifests_compositions: Any #list[Callable[[_MergingSnapshotProducer], List[ManifestFile]]] def __init__( self, - operation: Operation, # done, inited + operation: Operation, transaction: Transaction, - io: FileIO, # done, inited - overwrite_filter: Union[str, BooleanExpression] = ALWAYS_TRUE, + io: FileIO, + overwrite_filter: BooleanExpression = ALWAYS_TRUE, commit_uuid: Optional[uuid.UUID] = None, ) -> None: super().__init__(transaction) - self.commit_uuid = commit_uuid or uuid.uuid4() # done + self.commit_uuid = commit_uuid or uuid.uuid4() self._io = io self._operation = operation self._snapshot_id = self._transaction.table_metadata.new_snapshot_id() @@ -3089,7 +3167,7 @@ def _get_deleted_entries(manifest: ManifestFile) -> List[ManifestEntry]: class PartialOverwriteFiles(_MergingSnapshotProducer): - def __init__(self, overwrite_filter: Union[str, BooleanExpression], **kwargs: Any) -> None: + def __init__(self, overwrite_filter: BooleanExpression, **kwargs: Any) -> None: super().__init__(**kwargs) self._deleted_data_files = ExplicitlyDeletedDataFiles() self.overwrite_filter = overwrite_filter @@ -3173,9 +3251,7 @@ def __init__(self, transaction: Transaction, io: FileIO) -> None: def fast_append(self) -> FastAppendFiles: return FastAppendFiles(operation=Operation.APPEND, transaction=self._transaction, io=self._io) - def overwrite( - self, overwrite_filter: Union[str, BooleanExpression] = ALWAYS_TRUE - ) -> Union[OverwriteFiles, PartialOverwriteFiles]: + def overwrite(self, overwrite_filter: BooleanExpression = ALWAYS_TRUE) -> Union[OverwriteFiles, PartialOverwriteFiles]: if overwrite_filter == ALWAYS_TRUE: return OverwriteFiles( operation=Operation.OVERWRITE diff --git a/tests/table/test_init.py b/tests/table/test_init.py index d72a8e1534..5acdc483aa 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -74,6 +74,7 @@ _apply_table_update, _bind_and_validate_static_overwrite_filter_predicate, _check_schema, + _check_schema_with_filter_predicates, _fill_in_df, _match_deletes_to_data_file, _TableMetadataUpdateContext, @@ -1028,7 +1029,6 @@ def test_correct_schema() -> None: assert "Snapshot not found: -1" in str(exc_info.value) -# _bind_and_validate_static_overwrite_filter_predicate def test__bind_and_validate_static_overwrite_filter_predicate_fails_on_non_schema_fields_in_filter() -> None: test_schema = Schema( NestedField(field_id=1, name="foo", field_type=StringType(), required=False), @@ -1045,21 +1045,6 @@ def test__bind_and_validate_static_overwrite_filter_predicate_fails_on_non_schem _bind_and_validate_static_overwrite_filter_predicate(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) -# @pytest.mark.zy -# def test_mine(table_schema_simple) -> None: -# pred = IsNull("bar") -# print("xxxx!", pred.term) -# pred.bind(table_schema_simple) - - -# from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema -# pa_table = pa.table( -# {"bar": [1, 2, 3], "foo": ["a", "b", "c"], "baz": [True, False, None]}, -# ) -# name_mapping = table_schema_simple.name_mapping -# print("xxxx!", pyarrow_to_schema(pa_table.schema, name_mapping=name_mapping)) - - def test__fill_in_df(table_schema_simple: Schema) -> None: df = pa.table({"baz": [True, False, None]}) unbound_is_null_predicates = [IsNull(Reference("foo"))] @@ -1195,7 +1180,7 @@ def test__check_schema_with_filter_succeed(iceberg_schema_simple: Schema) -> Non # upcast set[BoundLiteralPredicate[int]] to set[BoundPredicate[int]] # because _check_schema expects set[BoundPredicate[int]] and set is not covariant # (meaning although BoundLiteralPredicate is subclass of BoundPredicate, list[BoundLiteralPredicate] is not that of list[BoundPredicate]) - _check_schema(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) + _check_schema_with_filter_predicates(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) def test__check_schema_with_filter_succeed_on_pyarrow_table_with_random_column_order(iceberg_schema_simple: Schema) -> None: @@ -1209,7 +1194,7 @@ def test__check_schema_with_filter_succeed_on_pyarrow_table_with_random_column_o # upcast set[BoundLiteralPredicate[int]] to set[BoundPredicate[int]] # because _check_schema expects set[BoundPredicate[int]] and set is not covariant # (meaning although BoundLiteralPredicate is subclass of BoundPredicate, list[BoundLiteralPredicate] is not that of list[BoundPredicate]) - _check_schema(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) + _check_schema_with_filter_predicates(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) def test__check_schema_with_filter_fails_on_missing_field(iceberg_schema_simple: Schema) -> None: @@ -1235,7 +1220,7 @@ def test__check_schema_with_filter_fails_on_missing_field(iceberg_schema_simple: """ ) with pytest.raises(ValueError, match=expected): - _check_schema(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) + _check_schema_with_filter_predicates(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) def test__check_schema_with_filter_fails_on_nullability_mismatch(iceberg_schema_simple: Schema) -> None: @@ -1261,7 +1246,7 @@ def test__check_schema_with_filter_fails_on_nullability_mismatch(iceberg_schema_ """ ) with pytest.raises(ValueError, match=expected): - _check_schema(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) + _check_schema_with_filter_predicates(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) def test__check_schema_with_filter_fails_on_type_mismatch(iceberg_schema_simple: Schema) -> None: @@ -1287,7 +1272,7 @@ def test__check_schema_with_filter_fails_on_type_mismatch(iceberg_schema_simple: """ ) with pytest.raises(ValueError, match=expected): - _check_schema(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) + _check_schema_with_filter_predicates(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) def test__check_schema_with_filter_fails_due_to_filter_and_dataframe_holding_same_field(iceberg_schema_simple: Schema) -> None: @@ -1310,7 +1295,7 @@ def test__check_schema_with_filter_fails_due_to_filter_and_dataframe_holding_sam """ ) with pytest.raises(ValueError, match=expected): - _check_schema(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) + _check_schema_with_filter_predicates(iceberg_schema_simple, other_schema, filter_predicates=filter_predicates) @pytest.mark.parametrize( From b788d38cadaa0789ca49ad8afb7b581cb000e599 Mon Sep 17 00:00:00 2001 From: Adrian Qin <147659252+jqin61@users.noreply.github.com> Date: Sat, 23 Mar 2024 21:43:53 +0000 Subject: [PATCH 6/6] clean up --- pyiceberg/table/__init__.py | 5 +- tests/table/test_init.py | 96 +++++++++++++++---------------------- 2 files changed, 42 insertions(+), 59 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index a730b2bab4..c60850c1af 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -296,7 +296,10 @@ def _check_schema_with_filter_predicates( def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None: task_schema = _arrow_schema_to_iceberg_schema_with_field_ids(table_schema, other_schema) - if table_schema.as_struct() != task_schema.as_struct(): + sorted_table_schema = Schema(*sorted(table_schema.fields, key=lambda field: field.field_id)) + sorted_task_schema = Schema(*sorted(task_schema.fields, key=lambda field: field.field_id)) + + if sorted_table_schema.as_struct() != sorted_task_schema.as_struct(): from rich.console import Console from rich.table import Table as RichTable diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 5acdc483aa..f5519a0f7e 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -1029,20 +1029,17 @@ def test_correct_schema() -> None: assert "Snapshot not found: -1" in str(exc_info.value) -def test__bind_and_validate_static_overwrite_filter_predicate_fails_on_non_schema_fields_in_filter() -> None: - test_schema = Schema( - NestedField(field_id=1, name="foo", field_type=StringType(), required=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), - schema_id=1, - identifier_field_ids=[2], - ) +def test__bind_and_validate_static_overwrite_filter_predicate_fails_on_non_schema_fields_in_filter( + iceberg_schema_simple: Schema, +) -> None: pred = EqualTo(Reference("not a field"), "hello") partition_spec = PartitionSpec( PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="test_part_col") ) with pytest.raises(ValueError, match="Could not find field with name not a field, case_sensitive=True"): - _bind_and_validate_static_overwrite_filter_predicate(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) + _bind_and_validate_static_overwrite_filter_predicate( + unbound_expr=pred, table_schema=iceberg_schema_simple, spec=partition_spec + ) def test__fill_in_df(table_schema_simple: Schema) -> None: @@ -1069,14 +1066,9 @@ def test__fill_in_df(table_schema_simple: Schema) -> None: assert filled_df == expected -def test__bind_and_validate_static_overwrite_filter_predicate_fails_on_non_part_fields_in_filter() -> None: - test_schema = Schema( - NestedField(field_id=1, name="foo", field_type=StringType(), required=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), - schema_id=1, - identifier_field_ids=[2], - ) +def test__bind_and_validate_static_overwrite_filter_predicate_fails_on_non_part_fields_in_filter( + iceberg_schema_simple: Schema, +) -> None: pred = EqualTo(Reference("foo"), "hello") partition_spec = PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="bar")) import re @@ -1087,17 +1079,14 @@ def test__bind_and_validate_static_overwrite_filter_predicate_fails_on_non_part_ "Get 0 partition fields from filter predicate EqualTo(term=Reference(name='foo'), literal=literal('hello')), expecting 1." ), ): - _bind_and_validate_static_overwrite_filter_predicate(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) + _bind_and_validate_static_overwrite_filter_predicate( + unbound_expr=pred, table_schema=iceberg_schema_simple, spec=partition_spec + ) -def test__bind_and_validate_static_overwrite_filter_predicate_fails_on_non_identity_transorm_filter() -> None: - test_schema = Schema( - NestedField(field_id=1, name="foo", field_type=StringType(), required=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), - schema_id=1, - identifier_field_ids=[2], - ) +def test__bind_and_validate_static_overwrite_filter_predicate_fails_on_non_identity_transorm_filter( + iceberg_schema_simple: Schema, +) -> None: pred = EqualTo(Reference("foo"), "hello") partition_spec = PartitionSpec( PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="bar"), @@ -1108,53 +1097,42 @@ def test__bind_and_validate_static_overwrite_filter_predicate_fails_on_non_ident ValueError, match="static overwrite partition filter can only apply to partition fields which are without hidden transform, but get.*", ): - _bind_and_validate_static_overwrite_filter_predicate(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) + _bind_and_validate_static_overwrite_filter_predicate( + unbound_expr=pred, table_schema=iceberg_schema_simple, spec=partition_spec + ) -def test__bind_and_validate_static_overwrite_filter_predicate_succeeds_on_an_identity_transform_field_although_table_has_other_hidden_partition_fields() -> ( - None -): - test_schema = Schema( - NestedField(field_id=1, name="foo", field_type=StringType(), required=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), - schema_id=1, - identifier_field_ids=[2], - ) +def test__bind_and_validate_static_overwrite_filter_predicate_succeeds_on_an_identity_transform_field_although_table_has_other_hidden_partition_fields( + iceberg_schema_simple: Schema, +) -> None: pred = EqualTo(Reference("bar"), 3) partition_spec = PartitionSpec( PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="bar"), PartitionField(source_id=1, field_id=1002, transform=TruncateTransform(2), name="foo_trunc"), ) - _bind_and_validate_static_overwrite_filter_predicate(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) + _bind_and_validate_static_overwrite_filter_predicate( + unbound_expr=pred, table_schema=iceberg_schema_simple, spec=partition_spec + ) -def test__bind_and_validate_static_overwrite_filter_predicate_fails_to_bind() -> None: - test_schema = Schema( - NestedField(field_id=1, name="foo", field_type=StringType(), required=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), - schema_id=1, - identifier_field_ids=[2], - ) +def test__bind_and_validate_static_overwrite_filter_predicate_fails_to_bind_due_to_incompatible_predicate_value( + iceberg_schema_simple: Schema, +) -> None: pred = EqualTo(Reference("bar"), "an incompatible type") partition_spec = PartitionSpec( PartitionField(source_id=2, field_id=1001, transform=IdentityTransform(), name="bar"), PartitionField(source_id=1, field_id=1002, transform=TruncateTransform(2), name="foo_trunc"), ) with pytest.raises(ValueError, match="Could not convert an incompatible type into a int"): - _bind_and_validate_static_overwrite_filter_predicate(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) + _bind_and_validate_static_overwrite_filter_predicate( + unbound_expr=pred, table_schema=iceberg_schema_simple, spec=partition_spec + ) -def test__bind_and_validate_static_overwrite_filter_predicate_fails_to_bind_due_to_non_nullable() -> None: - test_schema = Schema( - NestedField(field_id=1, name="foo", field_type=StringType(), required=False), - NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), - schema_id=1, - identifier_field_ids=[2], - ) +def test__bind_and_validate_static_overwrite_filter_predicate_fails_to_bind_due_to_non_nullable( + iceberg_schema_simple: Schema, +) -> None: pred = IsNull(Reference("bar")) partition_spec = PartitionSpec( PartitionField(source_id=3, field_id=1001, transform=IdentityTransform(), name="baz"), @@ -1166,7 +1144,9 @@ def test__bind_and_validate_static_overwrite_filter_predicate_fails_to_bind_due_ "Static overwriting with part of the explicit partition filter not meaningful (specifing a non-nullable partition field to be null)" ), ): - _bind_and_validate_static_overwrite_filter_predicate(unbound_expr=pred, table_schema=test_schema, spec=partition_spec) + _bind_and_validate_static_overwrite_filter_predicate( + unbound_expr=pred, table_schema=iceberg_schema_simple, spec=partition_spec + ) def test__check_schema_with_filter_succeed(iceberg_schema_simple: Schema) -> None: @@ -1421,14 +1401,14 @@ def test_check_schema_succeed(table_schema_simple: Schema) -> None: _check_schema(table_schema_simple, other_schema) -def test_schema_succeed_on_pyarrow_table_reversed_column_order(table_schema_simple: Schema) -> None: +def test_schema_succeed_on_pyarrow_table_reversed_column_order(iceberg_schema_simple: Schema) -> None: other_schema = pa.schema(( pa.field("baz", pa.bool_(), nullable=True), pa.field("bar", pa.int32(), nullable=False), pa.field("foo", pa.string(), nullable=True), )) - _check_schema(table_schema_simple, other_schema) + _check_schema(iceberg_schema_simple, other_schema) def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None: