Skip to content

Remove trailing slash from table location when creating a table #702

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyiceberg/catalog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ def _get_updated_props_and_update_summary(
def _resolve_table_location(self, location: Optional[str], database_name: str, table_name: str) -> str:
    """Return the location to use for a table that is being created.

    Args:
        location: Location requested by the caller, or a falsy value
            (``None`` / empty string) when none was given.
        database_name: Name of the database the table belongs to.
        table_name: Name of the table.

    Returns:
        The result of ``_get_default_warehouse_location`` when no location
        was given; otherwise the given location with every trailing ``/``
        removed.
    """
    if not location:
        return self._get_default_warehouse_location(database_name, table_name)
    # Strip trailing slashes so that "<location>/" and "<location>" resolve
    # to the same table location. Previously an unconditional
    # ``return location`` preceded this line and made it unreachable.
    return location.rstrip("/")

def _get_default_warehouse_location(self, database_name: str, table_name: str) -> str:
database_properties = self.load_namespace_properties(database_name)
Expand Down
2 changes: 2 additions & 0 deletions pyiceberg/catalog/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,8 @@ def _create_table(
fresh_sort_order = assign_fresh_sort_order_ids(sort_order, iceberg_schema, fresh_schema)

namespace_and_table = self._split_identifier_for_path(identifier)
if location:
location = location.rstrip("/")
request = CreateTableRequest(
name=namespace_and_table["table"],
location=location,
Expand Down
14 changes: 14 additions & 0 deletions tests/catalog/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def create_table(

if not location:
location = f'{self._warehouse_location}/{"/".join(identifier)}'
location = location.rstrip("/")

metadata_location = self._get_metadata_location(location=location)
metadata = new_table_metadata(
Expand Down Expand Up @@ -353,6 +354,19 @@ def test_create_table_location_override(catalog: InMemoryCatalog) -> None:
assert table.location() == new_location


def test_create_table_removes_trailing_slash_from_location(catalog: InMemoryCatalog) -> None:
    """A slash-terminated location passed to ``create_table`` is stored without the slash."""
    expected_location = f"{catalog._warehouse_location}/new_location"
    location_with_slash = f"{expected_location}/"

    created = catalog.create_table(
        identifier=TEST_TABLE_IDENTIFIER,
        schema=TEST_TABLE_SCHEMA,
        location=location_with_slash,
        partition_spec=TEST_TABLE_PARTITION_SPEC,
        properties=TEST_TABLE_PROPERTIES,
    )

    # Reloading through the catalog must yield the table that was just created.
    assert catalog.load_table(TEST_TABLE_IDENTIFIER) == created
    # The reported location is the requested one minus the trailing slash.
    assert created.location() == expected_location


@pytest.mark.parametrize(
"schema,expected",
[
Expand Down
15 changes: 15 additions & 0 deletions tests/catalog/test_dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,21 @@ def test_create_table_with_given_location(
assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location)


@mock_aws
def test_create_table_removes_trailing_slash_in_location(
    _bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str
) -> None:
    """DynamoDB catalog: a trailing slash on the given location is dropped on table creation."""
    catalog_name = "test_ddb_catalog"
    ddb_catalog = DynamoDbCatalog(catalog_name, **{"s3.endpoint": moto_endpoint_url})
    ddb_catalog.create_namespace(namespace=database_name)

    expected_location = f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}"
    created = ddb_catalog.create_table(
        identifier=(database_name, table_name),
        schema=table_schema_nested,
        location=f"{expected_location}/",
    )

    # The table identifier is the catalog name followed by (database, table).
    assert created.identifier == (catalog_name, database_name, table_name)
    assert created.location() == expected_location
    assert TABLE_METADATA_LOCATION_REGEX.match(created.metadata_location)


@mock_aws
def test_create_table_with_no_location(
_bucket_initialize: None, table_schema_nested: Schema, database_name: str, table_name: str
Expand Down
16 changes: 16 additions & 0 deletions tests/catalog/test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,22 @@ def test_create_table_with_given_location(
assert test_catalog._parse_metadata_version(table.metadata_location) == 0


@mock_aws
def test_create_table_removes_trailing_slash_in_location(
    _bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str
) -> None:
    """Glue catalog: a trailing slash on the given location is dropped on table creation."""
    catalog_name = "glue"
    glue_catalog = GlueCatalog(catalog_name, **{"s3.endpoint": moto_endpoint_url})
    glue_catalog.create_namespace(namespace=database_name)

    expected_location = f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}"
    created = glue_catalog.create_table(
        identifier=(database_name, table_name),
        schema=table_schema_nested,
        location=f"{expected_location}/",
    )

    # The table identifier is the catalog name followed by (database, table).
    assert created.identifier == (catalog_name, database_name, table_name)
    assert created.location() == expected_location
    assert TABLE_METADATA_LOCATION_REGEX.match(created.metadata_location)
    # The metadata file written at creation time parses as version 0.
    assert glue_catalog._parse_metadata_version(created.metadata_location) == 0


@mock_aws
def test_create_table_with_pyarrow_schema(
_bucket_initialize: None,
Expand Down
175 changes: 175 additions & 0 deletions tests/catalog/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,181 @@ def test_create_table(
assert metadata.model_dump() == expected.model_dump()


@pytest.mark.parametrize("hive2_compatible", [True, False])
@patch("time.time", MagicMock(return_value=12345))
def test_create_table_with_given_location_removes_trailing_slash(
    table_schema_with_all_types: Schema, hive_database: HiveDatabase, hive_table: HiveTable, hive2_compatible: bool
) -> None:
    """Hive catalog: a trailing slash on the given location is dropped on table creation.

    Checks both the ``HiveTable`` handed to the (mocked) metastore client and
    the table metadata file written during creation, with and without the
    ``hive.hive2-compatible`` property.
    """
    catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
    if hive2_compatible:
        catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL, **{"hive.hive2-compatible": "true"})

    location = f"{hive_database.locationUri}/table-given-location"

    # Replace the metastore client with a mock and configure the object
    # returned when the client is entered as a context manager.
    catalog._client = MagicMock()
    mocked_client = catalog._client.__enter__()
    mocked_client.create_table.return_value = None
    mocked_client.get_table.return_value = hive_table
    mocked_client.get_database.return_value = hive_database

    catalog.create_table(
        ("default", "table"), schema=table_schema_with_all_types, properties={"owner": "javaberg"}, location=f"{location}/"
    )

    # The metadata location is produced inside create_table, so read it back
    # from the recorded call in order to build the expected HiveTable.
    called_hive_table: HiveTable = mocked_client.create_table.call_args[0][0]
    metadata_location: str = called_hive_table.parameters["metadata_location"]
    assert metadata_location.endswith(".metadata.json")
    assert "/database/table-given-location/metadata/" in metadata_location

    # The timestamptz column type depends on the hive2-compatible setting.
    timestamptz_hive_type = 'timestamp' if hive2_compatible else 'timestamp with local time zone'
    expected_columns = [
        FieldSchema(name='boolean', type='boolean', comment=None),
        FieldSchema(name='integer', type='int', comment=None),
        FieldSchema(name='long', type='bigint', comment=None),
        FieldSchema(name='float', type='float', comment=None),
        FieldSchema(name='double', type='double', comment=None),
        FieldSchema(name='decimal', type='decimal(32,3)', comment=None),
        FieldSchema(name='date', type='date', comment=None),
        FieldSchema(name='time', type='string', comment=None),
        FieldSchema(name='timestamp', type='timestamp', comment=None),
        FieldSchema(
            name='timestamptz',
            type=timestamptz_hive_type,
            comment=None,
        ),
        FieldSchema(name='string', type='string', comment=None),
        FieldSchema(name='uuid', type='string', comment=None),
        FieldSchema(name='fixed', type='binary', comment=None),
        FieldSchema(name='binary', type='binary', comment=None),
        FieldSchema(name='list', type='array<string>', comment=None),
        FieldSchema(name='map', type='map<string,int>', comment=None),
        FieldSchema(name='struct', type='struct<inner_string:string,inner_int:int>', comment=None),
    ]
    expected_serde_info = SerDeInfo(
        name=None,
        serializationLib="org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe",
        parameters=None,
        description=None,
        serializerClass=None,
        deserializerClass=None,
        serdeType=None,
    )
    # The storage descriptor must carry the location WITHOUT the trailing slash.
    expected_storage_descriptor = StorageDescriptor(
        cols=expected_columns,
        location=f"{hive_database.locationUri}/table-given-location",
        inputFormat="org.apache.hadoop.mapred.FileInputFormat",
        outputFormat="org.apache.hadoop.mapred.FileOutputFormat",
        compressed=None,
        numBuckets=None,
        serdeInfo=expected_serde_info,
        bucketCols=None,
        sortCols=None,
        parameters=None,
        skewedInfo=None,
        storedAsSubDirectories=None,
    )
    expected_hive_table = HiveTable(
        tableName="table",
        dbName="default",
        owner="javaberg",
        createTime=12345,
        lastAccessTime=12345,
        retention=None,
        sd=expected_storage_descriptor,
        partitionKeys=None,
        parameters={"EXTERNAL": "TRUE", "table_type": "ICEBERG", "metadata_location": metadata_location},
        viewOriginalText=None,
        viewExpandedText=None,
        tableType="EXTERNAL_TABLE",
        privileges=None,
        temporary=False,
        rewriteEnabled=None,
        creationMetadata=None,
        catName=None,
        ownerType=1,
        writeId=-1,
        isStatsCompliant=None,
        colStats=None,
        accessType=None,
        requiredReadCapabilities=None,
        requiredWriteCapabilities=None,
        id=None,
        fileMetadata=None,
        dictionary=None,
        txnId=None,
    )
    mocked_client.create_table.assert_called_with(expected_hive_table)

    # Read back the metadata file that create_table wrote.
    with open(metadata_location, encoding=UTF8) as metadata_file:
        raw_metadata = metadata_file.read()

    metadata = TableMetadataUtil.parse_raw(raw_metadata)

    assert "database/table-given-location" in metadata.location

    expected_schema = Schema(
        NestedField(field_id=1, name='boolean', field_type=BooleanType(), required=True),
        NestedField(field_id=2, name='integer', field_type=IntegerType(), required=True),
        NestedField(field_id=3, name='long', field_type=LongType(), required=True),
        NestedField(field_id=4, name='float', field_type=FloatType(), required=True),
        NestedField(field_id=5, name='double', field_type=DoubleType(), required=True),
        NestedField(field_id=6, name='decimal', field_type=DecimalType(precision=32, scale=3), required=True),
        NestedField(field_id=7, name='date', field_type=DateType(), required=True),
        NestedField(field_id=8, name='time', field_type=TimeType(), required=True),
        NestedField(field_id=9, name='timestamp', field_type=TimestampType(), required=True),
        NestedField(field_id=10, name='timestamptz', field_type=TimestamptzType(), required=True),
        NestedField(field_id=11, name='string', field_type=StringType(), required=True),
        NestedField(field_id=12, name='uuid', field_type=UUIDType(), required=True),
        NestedField(field_id=13, name='fixed', field_type=FixedType(length=12), required=True),
        NestedField(field_id=14, name='binary', field_type=BinaryType(), required=True),
        NestedField(
            field_id=15,
            name='list',
            field_type=ListType(type='list', element_id=18, element_type=StringType(), element_required=True),
            required=True,
        ),
        NestedField(
            field_id=16,
            name='map',
            field_type=MapType(
                type='map', key_id=19, key_type=StringType(), value_id=20, value_type=IntegerType(), value_required=True
            ),
            required=True,
        ),
        NestedField(
            field_id=17,
            name='struct',
            field_type=StructType(
                NestedField(field_id=21, name='inner_string', field_type=StringType(), required=False),
                NestedField(field_id=22, name='inner_int', field_type=IntegerType(), required=True),
            ),
            required=False,
        ),
        schema_id=0,
        identifier_field_ids=[2],
    )

    # Values generated at creation time (location, uuid, timestamp) are taken
    # from the parsed metadata; everything else is pinned explicitly.
    expected = TableMetadataV2(
        location=metadata.location,
        table_uuid=metadata.table_uuid,
        last_updated_ms=metadata.last_updated_ms,
        last_column_id=22,
        schemas=[expected_schema],
        current_schema_id=0,
        last_partition_id=999,
        properties={"owner": "javaberg", 'write.parquet.compression-codec': 'zstd'},
        partition_specs=[PartitionSpec()],
        default_spec_id=0,
        current_snapshot_id=None,
        snapshots=[],
        snapshot_log=[],
        metadata_log=[],
        sort_orders=[SortOrder(order_id=0)],
        default_sort_order_id=0,
        refs={},
        format_version=2,
        last_sequence_number=0,
    )

    assert metadata.model_dump() == expected.model_dump()


@patch("time.time", MagicMock(return_value=12345))
def test_create_v1_table(table_schema_simple: Schema, hive_database: HiveDatabase, hive_table: HiveTable) -> None:
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
Expand Down
25 changes: 25 additions & 0 deletions tests/catalog/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,31 @@ def test_create_table_200(
assert actual == expected


def test_create_table_with_given_location_removes_trailing_slash_200(
    rest_mock: Mocker, table_schema_simple: Schema, example_table_metadata_no_snapshot_v1_rest_json: Dict[str, Any]
) -> None:
    """REST catalog: the create-table request body carries the location without a trailing slash."""
    rest_mock.post(
        f"{TEST_URI}v1/namespaces/fokko/tables",
        json=example_table_metadata_no_snapshot_v1_rest_json,
        status_code=200,
        request_headers=TEST_HEADERS,
    )
    catalog = RestCatalog("rest", uri=TEST_URI, token=TEST_TOKEN)

    expected_location = "s3://warehouse/database/table-custom-location"
    spec = PartitionSpec(
        PartitionField(source_id=1, field_id=1000, transform=TruncateTransform(width=3), name="id"), spec_id=1
    )
    order = SortOrder(SortField(source_id=2, transform=IdentityTransform()))

    catalog.create_table(
        identifier=("fokko", "fokko2"),
        schema=table_schema_simple,
        location=f"{expected_location}/",
        partition_spec=spec,
        sort_order=order,
        properties={"owner": "fokko"},
    )

    # Inspect the JSON payload of the last request the mock recorded.
    sent_request = rest_mock.last_request
    assert sent_request
    assert sent_request.json()["location"] == expected_location


def test_create_table_409(rest_mock: Mocker, table_schema_simple: Schema) -> None:
rest_mock.post(
f"{TEST_URI}v1/namespaces/fokko/tables",
Expand Down
22 changes: 22 additions & 0 deletions tests/catalog/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,28 @@ def test_create_table_with_default_warehouse_location(
catalog.drop_table(random_identifier)


@pytest.mark.parametrize(
    'catalog',
    [
        lazy_fixture('catalog_memory'),
        lazy_fixture('catalog_sqlite'),
    ],
)
def test_create_table_with_given_location_removes_trailing_slash(
    warehouse: Path, catalog: SqlCatalog, table_schema_nested: Schema, random_identifier: Identifier
) -> None:
    """SQL catalog: a trailing slash on the given location is dropped on table creation."""
    database_name, table_name = random_identifier
    expected_location = f"file://{warehouse}/{database_name}.db/{table_name}-given"

    catalog.create_namespace(database_name)
    catalog.create_table(random_identifier, table_schema_nested, location=f"{expected_location}/")
    loaded = catalog.load_table(random_identifier)

    assert loaded.identifier == (catalog.name,) + random_identifier
    assert loaded.metadata_location.startswith(f"file://{warehouse}")
    # Drop the URI scheme to obtain a filesystem path that can be checked.
    metadata_path = loaded.metadata_location[len("file://") :]
    assert os.path.exists(metadata_path)
    assert loaded.location() == expected_location

    catalog.drop_table(random_identifier)


@pytest.mark.parametrize(
'catalog',
[
Expand Down