diff --git a/pyiceberg/catalog/__init__.py b/pyiceberg/catalog/__init__.py index 18d803fe1c..5bb9ec277a 100644 --- a/pyiceberg/catalog/__init__.py +++ b/pyiceberg/catalog/__init__.py @@ -779,7 +779,7 @@ def _get_updated_props_and_update_summary( def _resolve_table_location(self, location: Optional[str], database_name: str, table_name: str) -> str: if not location: return self._get_default_warehouse_location(database_name, table_name) - return location + return location.rstrip("/") def _get_default_warehouse_location(self, database_name: str, table_name: str) -> str: database_properties = self.load_namespace_properties(database_name) diff --git a/pyiceberg/catalog/rest.py b/pyiceberg/catalog/rest.py index 53e3f6a123..565d809194 100644 --- a/pyiceberg/catalog/rest.py +++ b/pyiceberg/catalog/rest.py @@ -519,6 +519,8 @@ def _create_table( fresh_sort_order = assign_fresh_sort_order_ids(sort_order, iceberg_schema, fresh_schema) namespace_and_table = self._split_identifier_for_path(identifier) + if location: + location = location.rstrip("/") request = CreateTableRequest( name=namespace_and_table["table"], location=location, diff --git a/tests/catalog/test_base.py b/tests/catalog/test_base.py index 7d5e0a973c..06e9a8a3aa 100644 --- a/tests/catalog/test_base.py +++ b/tests/catalog/test_base.py @@ -105,6 +105,7 @@ def create_table( if not location: location = f'{self._warehouse_location}/{"/".join(identifier)}' + location = location.rstrip("/") metadata_location = self._get_metadata_location(location=location) metadata = new_table_metadata( @@ -353,6 +354,19 @@ def test_create_table_location_override(catalog: InMemoryCatalog) -> None: assert table.location() == new_location +def test_create_table_removes_trailing_slash_from_location(catalog: InMemoryCatalog) -> None: + new_location = f"{catalog._warehouse_location}/new_location" + table = catalog.create_table( + identifier=TEST_TABLE_IDENTIFIER, + schema=TEST_TABLE_SCHEMA, + location=f"{new_location}/", + partition_spec=TEST_TABLE_PARTITION_SPEC, + properties=TEST_TABLE_PROPERTIES, + ) + assert catalog.load_table(TEST_TABLE_IDENTIFIER) == table + assert table.location() == new_location + + @pytest.mark.parametrize( "schema,expected", [ diff --git a/tests/catalog/test_dynamodb.py b/tests/catalog/test_dynamodb.py index 1c647cf828..f4b16d343b 100644 --- a/tests/catalog/test_dynamodb.py +++ b/tests/catalog/test_dynamodb.py @@ -117,6 +117,21 @@ def test_create_table_with_given_location( assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location) +@mock_aws +def test_create_table_removes_trailing_slash_in_location( + _bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str +) -> None: + catalog_name = "test_ddb_catalog" + identifier = (database_name, table_name) + test_catalog = DynamoDbCatalog(catalog_name, **{"s3.endpoint": moto_endpoint_url}) + test_catalog.create_namespace(namespace=database_name) + location = f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}" + table = test_catalog.create_table(identifier=identifier, schema=table_schema_nested, location=f"{location}/") + assert table.identifier == (catalog_name,) + identifier + assert table.location() == location + assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location) + + @mock_aws def test_create_table_with_no_location( _bucket_initialize: None, table_schema_nested: Schema, database_name: str, table_name: str diff --git a/tests/catalog/test_glue.py b/tests/catalog/test_glue.py index 5999b192a2..5b67b92c68 100644 --- a/tests/catalog/test_glue.py +++ b/tests/catalog/test_glue.py @@ -137,6 +137,22 @@ def test_create_table_with_given_location( assert test_catalog._parse_metadata_version(table.metadata_location) == 0 +@mock_aws +def test_create_table_removes_trailing_slash_in_location( + _bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str +) -> None: + catalog_name = "glue" + identifier = (database_name, table_name) + test_catalog = GlueCatalog(catalog_name, **{"s3.endpoint": moto_endpoint_url}) + test_catalog.create_namespace(namespace=database_name) + location = f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}" + table = test_catalog.create_table(identifier=identifier, schema=table_schema_nested, location=f"{location}/") + assert table.identifier == (catalog_name,) + identifier + assert table.location() == location + assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location) + assert test_catalog._parse_metadata_version(table.metadata_location) == 0 + + @mock_aws def test_create_table_with_pyarrow_schema( _bucket_initialize: None, diff --git a/tests/catalog/test_hive.py b/tests/catalog/test_hive.py index 70927ea1bc..af3a380100 100644 --- a/tests/catalog/test_hive.py +++ b/tests/catalog/test_hive.py @@ -365,6 +365,181 @@ def test_create_table( assert metadata.model_dump() == expected.model_dump() +@pytest.mark.parametrize("hive2_compatible", [True, False]) +@patch("time.time", MagicMock(return_value=12345)) +def test_create_table_with_given_location_removes_trailing_slash( + table_schema_with_all_types: Schema, hive_database: HiveDatabase, hive_table: HiveTable, hive2_compatible: bool +) -> None: + catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL) + if hive2_compatible: + catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL, **{"hive.hive2-compatible": "true"}) + + location = f"{hive_database.locationUri}/table-given-location" + + catalog._client = MagicMock() + catalog._client.__enter__().create_table.return_value = None + catalog._client.__enter__().get_table.return_value = hive_table + catalog._client.__enter__().get_database.return_value = hive_database + catalog.create_table( + ("default", "table"), schema=table_schema_with_all_types, properties={"owner": "javaberg"}, location=f"{location}/" + ) + + called_hive_table: HiveTable = catalog._client.__enter__().create_table.call_args[0][0] + # This one is generated within the function itself, so we need to extract + # it to construct the assert_called_with + metadata_location: str = called_hive_table.parameters["metadata_location"] + assert metadata_location.endswith(".metadata.json") + assert "/database/table-given-location/metadata/" in metadata_location + catalog._client.__enter__().create_table.assert_called_with( + HiveTable( + tableName="table", + dbName="default", + owner="javaberg", + createTime=12345, + lastAccessTime=12345, + retention=None, + sd=StorageDescriptor( + cols=[ + FieldSchema(name='boolean', type='boolean', comment=None), + FieldSchema(name='integer', type='int', comment=None), + FieldSchema(name='long', type='bigint', comment=None), + FieldSchema(name='float', type='float', comment=None), + FieldSchema(name='double', type='double', comment=None), + FieldSchema(name='decimal', type='decimal(32,3)', comment=None), + FieldSchema(name='date', type='date', comment=None), + FieldSchema(name='time', type='string', comment=None), + FieldSchema(name='timestamp', type='timestamp', comment=None), + FieldSchema( + name='timestamptz', + type='timestamp' if hive2_compatible else 'timestamp with local time zone', + comment=None, + ), + FieldSchema(name='string', type='string', comment=None), + FieldSchema(name='uuid', type='string', comment=None), + FieldSchema(name='fixed', type='binary', comment=None), + FieldSchema(name='binary', type='binary', comment=None), + FieldSchema(name='list', type='array', comment=None), + FieldSchema(name='map', type='map', comment=None), + FieldSchema(name='struct', type='struct', comment=None), + ], + location=f"{hive_database.locationUri}/table-given-location", + inputFormat="org.apache.hadoop.mapred.FileInputFormat", + outputFormat="org.apache.hadoop.mapred.FileOutputFormat", + compressed=None, + numBuckets=None, + serdeInfo=SerDeInfo( + name=None, + serializationLib="org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe", + parameters=None, + description=None, + serializerClass=None, + deserializerClass=None, + serdeType=None, + ), + bucketCols=None, + sortCols=None, + parameters=None, + skewedInfo=None, + storedAsSubDirectories=None, + ), + partitionKeys=None, + parameters={"EXTERNAL": "TRUE", "table_type": "ICEBERG", "metadata_location": metadata_location}, + viewOriginalText=None, + viewExpandedText=None, + tableType="EXTERNAL_TABLE", + privileges=None, + temporary=False, + rewriteEnabled=None, + creationMetadata=None, + catName=None, + ownerType=1, + writeId=-1, + isStatsCompliant=None, + colStats=None, + accessType=None, + requiredReadCapabilities=None, + requiredWriteCapabilities=None, + id=None, + fileMetadata=None, + dictionary=None, + txnId=None, + ) + ) + + with open(metadata_location, encoding=UTF8) as f: + payload = f.read() + + metadata = TableMetadataUtil.parse_raw(payload) + + assert "database/table-given-location" in metadata.location + + expected = TableMetadataV2( + location=metadata.location, + table_uuid=metadata.table_uuid, + last_updated_ms=metadata.last_updated_ms, + last_column_id=22, + schemas=[ + Schema( + NestedField(field_id=1, name='boolean', field_type=BooleanType(), required=True), + NestedField(field_id=2, name='integer', field_type=IntegerType(), required=True), + NestedField(field_id=3, name='long', field_type=LongType(), required=True), + NestedField(field_id=4, name='float', field_type=FloatType(), required=True), + NestedField(field_id=5, name='double', field_type=DoubleType(), required=True), + NestedField(field_id=6, name='decimal', field_type=DecimalType(precision=32, scale=3), required=True), + NestedField(field_id=7, name='date', field_type=DateType(), required=True), + NestedField(field_id=8, name='time', field_type=TimeType(), required=True), + NestedField(field_id=9, name='timestamp', field_type=TimestampType(), required=True), + NestedField(field_id=10, name='timestamptz', field_type=TimestamptzType(), required=True), + NestedField(field_id=11, name='string', field_type=StringType(), required=True), + NestedField(field_id=12, name='uuid', field_type=UUIDType(), required=True), + NestedField(field_id=13, name='fixed', field_type=FixedType(length=12), required=True), + NestedField(field_id=14, name='binary', field_type=BinaryType(), required=True), + NestedField( + field_id=15, + name='list', + field_type=ListType(type='list', element_id=18, element_type=StringType(), element_required=True), + required=True, + ), + NestedField( + field_id=16, + name='map', + field_type=MapType( + type='map', key_id=19, key_type=StringType(), value_id=20, value_type=IntegerType(), value_required=True + ), + required=True, + ), + NestedField( + field_id=17, + name='struct', + field_type=StructType( + NestedField(field_id=21, name='inner_string', field_type=StringType(), required=False), + NestedField(field_id=22, name='inner_int', field_type=IntegerType(), required=True), + ), + required=False, + ), + schema_id=0, + identifier_field_ids=[2], + ) + ], + current_schema_id=0, + last_partition_id=999, + properties={"owner": "javaberg", 'write.parquet.compression-codec': 'zstd'}, + partition_specs=[PartitionSpec()], + default_spec_id=0, + current_snapshot_id=None, + snapshots=[], + snapshot_log=[], + metadata_log=[], + sort_orders=[SortOrder(order_id=0)], + default_sort_order_id=0, + refs={}, + format_version=2, + last_sequence_number=0, + ) + + assert metadata.model_dump() == expected.model_dump() + + @patch("time.time", MagicMock(return_value=12345)) def test_create_v1_table(table_schema_simple: Schema, hive_database: HiveDatabase, hive_table: HiveTable) -> None: catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL) diff --git a/tests/catalog/test_rest.py b/tests/catalog/test_rest.py index 15ddb01b25..b8410d6841 100644 --- a/tests/catalog/test_rest.py +++ b/tests/catalog/test_rest.py @@ -732,6 +732,31 @@ def test_create_table_200( assert actual == expected +def test_create_table_with_given_location_removes_trailing_slash_200( + rest_mock: Mocker, table_schema_simple: Schema, example_table_metadata_no_snapshot_v1_rest_json: Dict[str, Any] +) -> None: + rest_mock.post( + f"{TEST_URI}v1/namespaces/fokko/tables", + json=example_table_metadata_no_snapshot_v1_rest_json, + status_code=200, + request_headers=TEST_HEADERS, + ) + catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN) + location = "s3://warehouse/database/table-custom-location" + catalog.create_table( + identifier=("fokko", "fokko2"), + schema=table_schema_simple, + location=f"{location}/", + partition_spec=PartitionSpec( + PartitionField(source_id=1, field_id=1000, transform=TruncateTransform(width=3), name="id"), spec_id=1 + ), + sort_order=SortOrder(SortField(source_id=2, transform=IdentityTransform())), + properties={"owner": "fokko"}, + ) + assert rest_mock.last_request + assert rest_mock.last_request.json()["location"] == location + + def test_create_table_409(rest_mock: Mocker, table_schema_simple: Schema) -> None: rest_mock.post( f"{TEST_URI}v1/namespaces/fokko/tables", diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py index 40a1566e2f..9796526887 100644 --- a/tests/catalog/test_sql.py +++ b/tests/catalog/test_sql.py @@ -264,6 +264,28 @@ def test_create_table_with_default_warehouse_location( catalog.drop_table(random_identifier) +@pytest.mark.parametrize( + 'catalog', + [ + lazy_fixture('catalog_memory'), + lazy_fixture('catalog_sqlite'), + ], +) +def test_create_table_with_given_location_removes_trailing_slash( + warehouse: Path, catalog: SqlCatalog, table_schema_nested: Schema, random_identifier: Identifier +) -> None: + database_name, table_name = random_identifier + location = f"file://{warehouse}/{database_name}.db/{table_name}-given" + catalog.create_namespace(database_name) + catalog.create_table(random_identifier, table_schema_nested, location=f"{location}/") + table = catalog.load_table(random_identifier) + assert table.identifier == (catalog.name,) + random_identifier + assert table.metadata_location.startswith(f"file://{warehouse}") + assert os.path.exists(table.metadata_location[len("file://") :]) + assert table.location() == location + catalog.drop_table(random_identifier) + + @pytest.mark.parametrize( 'catalog', [