Skip to content

Commit 0508667

Browse files
authored
Add support for categorical type (#693)
1 parent 990ce80 commit 0508667

File tree

3 files changed

+48
-0
lines changed

3 files changed

+48
-0
lines changed

pyiceberg/io/pyarrow.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,16 @@ def _(obj: pa.MapType, visitor: PyArrowSchemaVisitor[T]) -> T:
731731
return visitor.map(obj, key_result, value_result)
732732

733733

734+
@visit_pyarrow.register(pa.DictionaryType)
def _(obj: pa.DictionaryType, visitor: PyArrowSchemaVisitor[T]) -> T:
    """Visit an Arrow dictionary type by delegating to its value type.

    Parquet has no dictionary type: dictionary-encoding is handled as an
    encoding detail, not as a separate type. We follow this approach in
    determining the Iceberg type, as we only support Parquet in PyIceberg
    for now.
    """
    # Warn the user that the dictionary wrapper is dropped and the column
    # will come back as the plain value type on read.
    logger.warning(f"Iceberg does not have a dictionary type. {type(obj)} will be inferred as {obj.value_type} on read.")
    # Recurse into the underlying value type to produce the Iceberg type.
    return visit_pyarrow(obj.value_type, visitor)
742+
743+
734744
@visit_pyarrow.register(pa.DataType)
735745
def _(obj: pa.DataType, visitor: PyArrowSchemaVisitor[T]) -> T:
736746
if pa.types.is_nested(obj):

tests/integration/test_writes/test_writes.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,30 @@ def test_python_writes_special_character_column_with_spark_reads(
315315
assert spark_df.equals(pyiceberg_df)
316316

317317

318+
@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_python_writes_dictionary_encoded_column_with_spark_reads(
    spark: SparkSession, session_catalog: Catalog, format_version: int
) -> None:
    """Write dictionary-encoded columns with PyIceberg and read them back via Spark.

    The dictionary encoding is an Arrow-level detail, so the table written by
    PyIceberg must look identical whether scanned by Spark or by PyIceberg itself.
    """
    identifier = "default.python_writes_dictionary_encoded_column_with_spark_reads"

    # Repeated values make the dictionary encoding meaningful.
    data = {
        'id': [1, 2, 3, 1, 1],
        'name': ['AB', 'CD', 'EF', 'CD', 'EF'],
    }
    fields = [
        pa.field('id', pa.dictionary(pa.int32(), pa.int32(), False)),
        pa.field('name', pa.dictionary(pa.int32(), pa.string(), False)),
    ]
    schema = pa.schema(fields)
    arrow_table = pa.Table.from_pydict(data, schema=schema)

    tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, schema=schema)
    tbl.overwrite(arrow_table)

    # Spark's view of the table must match PyIceberg's own scan.
    spark_df = spark.sql(f"SELECT * FROM {identifier}").toPandas()
    pyiceberg_df = tbl.scan().to_pandas()
    assert spark_df.equals(pyiceberg_df)
340+
341+
318342
@pytest.mark.integration
319343
def test_write_bin_pack_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
320344
identifier = "default.write_bin_pack_data_files"

tests/io/test_pyarrow_visitor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
DoubleType,
4040
FixedType,
4141
FloatType,
42+
IcebergType,
4243
IntegerType,
4344
ListType,
4445
LongType,
@@ -280,6 +281,19 @@ def test_pyarrow_map_to_iceberg() -> None:
280281
assert visit_pyarrow(pyarrow_map, _ConvertToIceberg()) == expected
281282

282283

284+
@pytest.mark.parametrize(
    "value_type, expected_result",
    [
        (pa.string(), StringType()),
        (pa.int32(), IntegerType()),
        (pa.float64(), DoubleType()),
    ],
)
def test_pyarrow_dictionary_encoded_type_to_iceberg(value_type: pa.DataType, expected_result: IcebergType) -> None:
    """A dictionary type converts to the Iceberg type of its value type."""
    dict_type = pa.dictionary(pa.int32(), value_type)
    converted = visit_pyarrow(dict_type, _ConvertToIceberg())
    assert converted == expected_result
295+
296+
283297
def test_round_schema_conversion_simple(table_schema_simple: Schema) -> None:
284298
actual = str(pyarrow_to_schema(schema_to_pyarrow(table_schema_simple)))
285299
expected = """table {

0 commit comments

Comments (0)