 import concurrent.futures
 import fnmatch
+import functools
 import itertools
 import logging
+import operator
 import os
 import re
 import uuid
@@ -2174,7 +2176,10 @@ def _partition_value(self, partition_field: PartitionField, schema: Schema) -> A
             raise ValueError(
                 f"Cannot infer partition value from parquet metadata as there are more than one partition values for Partition Field: {partition_field.name}. {lower_value=}, {upper_value=}"
             )
-        return lower_value
+
+        source_field = schema.find_field(partition_field.source_id)
+        transform = partition_field.transform.transform(source_field.field_type)
+        return transform(lower_value)

     def partition(self, partition_spec: PartitionSpec, schema: Schema) -> Record:
         return Record(**{field.name: self._partition_value(field, schema) for field in partition_spec.fields})
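For context on the `_partition_value` change above: the bounds read from the Parquet column statistics are values of the source column, so the partition transform still has to be applied to them before they can serve as a partition value. Below is a minimal sketch of that last step under the assumption of a monthly-partitioned date column, using PyIceberg's `MonthTransform` and `DateType`; the concrete date and expected output are illustrative only.

```python
from datetime import date

from pyiceberg.transforms import MonthTransform
from pyiceberg.types import DateType

# PyIceberg represents dates internally as days since the epoch (1970-01-01),
# which is also how the Parquet statistics expose the bound.
days_since_epoch = (date(2021, 3, 15) - date(1970, 1, 1)).days

# transform(source_type) returns a callable that maps the source value
# to the partition value, mirroring the call added in the diff.
to_partition_value = MonthTransform().transform(DateType())

print(to_partition_value(days_since_epoch))
# expected: 614, i.e. the number of months from 1970-01 to 2021-03
```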
@@ -2558,38 +2563,8 @@ class _TablePartition:
     arrow_table_partition: pa.Table


-def _get_table_partitions(
-    arrow_table: pa.Table,
-    partition_spec: PartitionSpec,
-    schema: Schema,
-    slice_instructions: list[dict[str, Any]],
-) -> list[_TablePartition]:
-    sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x["offset"])
-
-    partition_fields = partition_spec.fields
-
-    offsets = [inst["offset"] for inst in sorted_slice_instructions]
-    projected_and_filtered = {
-        partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
-        .take(offsets)
-        .to_pylist()
-        for partition_field in partition_fields
-    }
-
-    table_partitions = []
-    for idx, inst in enumerate(sorted_slice_instructions):
-        partition_slice = arrow_table.slice(**inst)
-        fieldvalues = [
-            PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx])
-            for partition_field in partition_fields
-        ]
-        partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
-        table_partitions.append(_TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))
-    return table_partitions
-
-
 def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[_TablePartition]:
-    """Based on the iceberg table partition spec, slice the arrow table into partitions with their keys.
+    """Based on the iceberg table partition spec, filter the arrow table into partitions with their keys.

     Example:
     Input:
@@ -2598,54 +2573,50 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
      'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
      'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse","Brittle stars", "Centipede"]}.
     The algorithm:
-    Firstly we group the rows into partitions by sorting with sort order [('n_legs', 'descending'), ('year', 'descending')]
-    and null_placement of "at_end".
-    This gives the same table as raw input.
-    Then we sort_indices using reverse order of [('n_legs', 'descending'), ('year', 'descending')]
-    and null_placement : "at_start".
-    This gives:
-    [8, 7, 4, 5, 6, 3, 1, 2, 0]
-    Based on this we get partition groups of indices:
-    [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}]
-    We then retrieve the partition keys by offsets.
-    And slice the arrow table by offsets and lengths of each partition.
+    - We determine the set of unique partition keys
+    - Then we produce a set of partitions by filtering on each of the combinations
+    - We combine the chunks to create a copy to avoid GIL congestion on the original table
     """
-    partition_columns: List[Tuple[PartitionField, NestedField]] = [
-        (partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields
-    ]
-    partition_values_table = pa.table(
-        {
-            str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
-            for partition, field in partition_columns
-        }
-    )
+    # Assign unique names to columns where the partition transform has been applied
+    # to avoid conflicts
+    partition_fields = [f"_partition_{field.name}" for field in spec.fields]
+
+    for partition, name in zip(spec.fields, partition_fields):
+        source_field = schema.find_field(partition.source_id)
+        arrow_table = arrow_table.append_column(
+            name, partition.transform.pyarrow_transform(source_field.field_type)(arrow_table[source_field.name])
+        )
+
+    unique_partition_fields = arrow_table.select(partition_fields).group_by(partition_fields).aggregate([])
+
+    table_partitions = []
+    # TODO: As a next step, we could also play around with yielding instead of materializing the full list
+    for unique_partition in unique_partition_fields.to_pylist():
+        partition_key = PartitionKey(
+            field_values=[
+                PartitionFieldValue(field=field, value=unique_partition[name])
+                for field, name in zip(spec.fields, partition_fields)
+            ],
+            partition_spec=spec,
+            schema=schema,
+        )
+        filtered_table = arrow_table.filter(
+            functools.reduce(
+                operator.and_,
+                [
+                    pc.field(partition_field_name) == unique_partition[partition_field_name]
+                    if unique_partition[partition_field_name] is not None
+                    else pc.field(partition_field_name).is_null()
+                    for field, partition_field_name in zip(spec.fields, partition_fields)
+                ],
+            )
+        )
+        filtered_table = filtered_table.drop_columns(partition_fields)

-    # Sort by partitions
-    sort_indices = pa.compute.sort_indices(
-        partition_values_table,
-        sort_keys=[(col, "ascending") for col in partition_values_table.column_names],
-        null_placement="at_end",
-    ).to_pylist()
-    arrow_table = arrow_table.take(sort_indices)
-
-    # Get slice_instructions to group by partitions
-    partition_values_table = partition_values_table.take(sort_indices)
-    reversed_indices = pa.compute.sort_indices(
-        partition_values_table,
-        sort_keys=[(col, "descending") for col in partition_values_table.column_names],
-        null_placement="at_start",
-    ).to_pylist()
-    slice_instructions: List[Dict[str, Any]] = []
-    last = len(reversed_indices)
-    reversed_indices_size = len(reversed_indices)
-    ptr = 0
-    while ptr < reversed_indices_size:
-        group_size = last - reversed_indices[ptr]
-        offset = reversed_indices[ptr]
-        slice_instructions.append({"offset": offset, "length": group_size})
-        last = reversed_indices[ptr]
-        ptr = ptr + group_size
-
-    table_partitions: List[_TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)
+        # The combine_chunks seems to be counter-intuitive to do, but it actually returns
+        # fresh buffers that don't interfere with each other when it is written out to file
+        table_partitions.append(
+            _TablePartition(partition_key=partition_key, arrow_table_partition=filtered_table.combine_chunks())
+        )

     return table_partitions
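To see the new `_determine_partitions` strategy in isolation, here is a small self-contained PyArrow sketch of the same idea: group by the partition columns to find the unique key combinations, then filter the table once per combination and detach each partition's buffers with `combine_chunks()`. The table, column names, and the identity "transform" on plain columns are made up for illustration and involve no Iceberg objects.

```python
import functools
import operator

import pyarrow as pa
import pyarrow.compute as pc

table = pa.table(
    {
        "year": [2020, 2022, 2022, 2021, 2022],
        "n_legs": [2, 2, 2, 4, 4],
        "animal": ["Flamingo", "Parrot", "Parrot", "Dog", "Horse"],
    }
)
partition_cols = ["n_legs", "year"]

# 1. Determine the set of unique partition key combinations.
unique_keys = table.select(partition_cols).group_by(partition_cols).aggregate([])

# 2. Filter the table once per combination, treating nulls explicitly
#    (pc.field(col).is_null()) as the PR does.
for key in unique_keys.to_pylist():
    predicate = functools.reduce(
        operator.and_,
        [
            pc.field(col) == key[col] if key[col] is not None else pc.field(col).is_null()
            for col in partition_cols
        ],
    )
    # 3. combine_chunks() gives each partition its own buffers before it is written out.
    part = table.filter(predicate).combine_chunks()
    print(key, part.num_rows)
```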