Skip to content

Commit 0339e7f

Browse files
authored
Support CreateTableTransaction for SqlCatalog (#684)
1 parent 91973f2 commit 0339e7f

File tree

2 files changed

+124 additions, -40 deletions

pyiceberg/catalog/sql.py

Lines changed: 64 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,7 @@
6060
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
6161
from pyiceberg.schema import Schema
6262
from pyiceberg.serializers import FromInputFile
63-
from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table, update_table_metadata
63+
from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table
6464
from pyiceberg.table.metadata import new_table_metadata
6565
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
6666
from pyiceberg.typedef import EMPTY_DICT, Identifier, Properties
@@ -402,59 +402,83 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableResponse:
402402
identifier_tuple = self.identifier_to_tuple_without_catalog(
403403
tuple(table_request.identifier.namespace.root + [table_request.identifier.name])
404404
)
405-
current_table = self.load_table(identifier_tuple)
406405
namespace_tuple = Catalog.namespace_from(identifier_tuple)
407406
namespace = Catalog.namespace_to_string(namespace_tuple)
408407
table_name = Catalog.table_name_from(identifier_tuple)
409-
base_metadata = current_table.metadata
410-
for requirement in table_request.requirements:
411-
requirement.validate(base_metadata)
412408

413-
updated_metadata = update_table_metadata(base_metadata, table_request.updates)
414-
if updated_metadata == base_metadata:
415-
# no changes, do nothing
416-
return CommitTableResponse(metadata=base_metadata, metadata_location=current_table.metadata_location)
409+
current_table: Optional[Table]
410+
try:
411+
current_table = self.load_table(identifier_tuple)
412+
except NoSuchTableError:
413+
current_table = None
417414

418-
# write new metadata
419-
new_metadata_version = self._parse_metadata_version(current_table.metadata_location) + 1
420-
new_metadata_location = self._get_metadata_location(current_table.metadata.location, new_metadata_version)
421-
self._write_metadata(updated_metadata, current_table.io, new_metadata_location)
415+
updated_staged_table = self._update_and_stage_table(current_table, table_request)
416+
if current_table and updated_staged_table.metadata == current_table.metadata:
417+
# no changes, do nothing
418+
return CommitTableResponse(metadata=current_table.metadata, metadata_location=current_table.metadata_location)
419+
self._write_metadata(
420+
metadata=updated_staged_table.metadata,
421+
io=updated_staged_table.io,
422+
metadata_path=updated_staged_table.metadata_location,
423+
)
422424

423425
with Session(self.engine) as session:
424-
if self.engine.dialect.supports_sane_rowcount:
425-
stmt = (
426-
update(IcebergTables)
427-
.where(
428-
IcebergTables.catalog_name == self.name,
429-
IcebergTables.table_namespace == namespace,
430-
IcebergTables.table_name == table_name,
431-
IcebergTables.metadata_location == current_table.metadata_location,
432-
)
433-
.values(metadata_location=new_metadata_location, previous_metadata_location=current_table.metadata_location)
434-
)
435-
result = session.execute(stmt)
436-
if result.rowcount < 1:
437-
raise CommitFailedException(f"Table has been updated by another process: {namespace}.{table_name}")
438-
else:
439-
try:
440-
tbl = (
441-
session.query(IcebergTables)
442-
.with_for_update(of=IcebergTables)
443-
.filter(
426+
if current_table:
427+
# table exists, update it
428+
if self.engine.dialect.supports_sane_rowcount:
429+
stmt = (
430+
update(IcebergTables)
431+
.where(
444432
IcebergTables.catalog_name == self.name,
445433
IcebergTables.table_namespace == namespace,
446434
IcebergTables.table_name == table_name,
447435
IcebergTables.metadata_location == current_table.metadata_location,
448436
)
449-
.one()
437+
.values(
438+
metadata_location=updated_staged_table.metadata_location,
439+
previous_metadata_location=current_table.metadata_location,
440+
)
450441
)
451-
tbl.metadata_location = new_metadata_location
452-
tbl.previous_metadata_location = current_table.metadata_location
453-
except NoResultFound as e:
454-
raise CommitFailedException(f"Table has been updated by another process: {namespace}.{table_name}") from e
455-
session.commit()
442+
result = session.execute(stmt)
443+
if result.rowcount < 1:
444+
raise CommitFailedException(f"Table has been updated by another process: {namespace}.{table_name}")
445+
else:
446+
try:
447+
tbl = (
448+
session.query(IcebergTables)
449+
.with_for_update(of=IcebergTables)
450+
.filter(
451+
IcebergTables.catalog_name == self.name,
452+
IcebergTables.table_namespace == namespace,
453+
IcebergTables.table_name == table_name,
454+
IcebergTables.metadata_location == current_table.metadata_location,
455+
)
456+
.one()
457+
)
458+
tbl.metadata_location = updated_staged_table.metadata_location
459+
tbl.previous_metadata_location = current_table.metadata_location
460+
except NoResultFound as e:
461+
raise CommitFailedException(f"Table has been updated by another process: {namespace}.{table_name}") from e
462+
session.commit()
463+
else:
464+
# table does not exist, create it
465+
try:
466+
session.add(
467+
IcebergTables(
468+
catalog_name=self.name,
469+
table_namespace=namespace,
470+
table_name=table_name,
471+
metadata_location=updated_staged_table.metadata_location,
472+
previous_metadata_location=None,
473+
)
474+
)
475+
session.commit()
476+
except IntegrityError as e:
477+
raise TableAlreadyExistsError(f"Table {namespace}.{table_name} already exists") from e
456478

457-
return CommitTableResponse(metadata=updated_metadata, metadata_location=new_metadata_location)
479+
return CommitTableResponse(
480+
metadata=updated_staged_table.metadata, metadata_location=updated_staged_table.metadata_location
481+
)
458482

459483
def _namespace_exists(self, identifier: Union[str, Identifier]) -> bool:
460484
namespace_tuple = Catalog.identifier_to_tuple(identifier)

tests/catalog/test_sql.py

Lines changed: 60 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1350,6 +1350,66 @@ def test_write_and_evolve(catalog: SqlCatalog, format_version: int) -> None:
13501350
snapshot_update.append_data_file(data_file)
13511351

13521352

1353+
@pytest.mark.parametrize(
1354+
"catalog",
1355+
[
1356+
lazy_fixture("catalog_memory"),
1357+
lazy_fixture("catalog_sqlite"),
1358+
lazy_fixture("catalog_sqlite_without_rowcount"),
1359+
],
1360+
)
1361+
@pytest.mark.parametrize("format_version", [1, 2])
1362+
def test_create_table_transaction(catalog: SqlCatalog, format_version: int) -> None:
1363+
identifier = f"default.arrow_create_table_transaction_{catalog.name}_{format_version}"
1364+
try:
1365+
catalog.create_namespace("default")
1366+
except NamespaceAlreadyExistsError:
1367+
pass
1368+
1369+
try:
1370+
catalog.drop_table(identifier=identifier)
1371+
except NoSuchTableError:
1372+
pass
1373+
1374+
pa_table = pa.Table.from_pydict(
1375+
{
1376+
"foo": ["a", None, "z"],
1377+
},
1378+
schema=pa.schema([pa.field("foo", pa.string(), nullable=True)]),
1379+
)
1380+
1381+
pa_table_with_column = pa.Table.from_pydict(
1382+
{
1383+
"foo": ["a", None, "z"],
1384+
"bar": [19, None, 25],
1385+
},
1386+
schema=pa.schema([
1387+
pa.field("foo", pa.string(), nullable=True),
1388+
pa.field("bar", pa.int32(), nullable=True),
1389+
]),
1390+
)
1391+
1392+
with catalog.create_table_transaction(
1393+
identifier=identifier, schema=pa_table.schema, properties={"format-version": str(format_version)}
1394+
) as txn:
1395+
with txn.update_snapshot().fast_append() as snapshot_update:
1396+
for data_file in _dataframe_to_data_files(table_metadata=txn.table_metadata, df=pa_table, io=txn._table.io):
1397+
snapshot_update.append_data_file(data_file)
1398+
1399+
with txn.update_schema() as schema_txn:
1400+
schema_txn.union_by_name(pa_table_with_column.schema)
1401+
1402+
with txn.update_snapshot().fast_append() as snapshot_update:
1403+
for data_file in _dataframe_to_data_files(
1404+
table_metadata=txn.table_metadata, df=pa_table_with_column, io=txn._table.io
1405+
):
1406+
snapshot_update.append_data_file(data_file)
1407+
1408+
tbl = catalog.load_table(identifier=identifier)
1409+
assert tbl.format_version == format_version
1410+
assert len(tbl.scan().to_arrow()) == 6
1411+
1412+
13531413
@pytest.mark.parametrize(
13541414
"catalog",
13551415
[

0 commit comments

Comments
 (0)