 import concurrent.futures
 import fnmatch
+import functools
 import itertools
 import logging
+import operator
 import os
 import re
 import uuid
@@ -2542,36 +2544,6 @@ class _TablePartition:
     arrow_table_partition: pa.Table
 
 
-def _get_table_partitions(
-    arrow_table: pa.Table,
-    partition_spec: PartitionSpec,
-    schema: Schema,
-    slice_instructions: list[dict[str, Any]],
-) -> list[_TablePartition]:
-    sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x["offset"])
-
-    partition_fields = partition_spec.fields
-
-    offsets = [inst["offset"] for inst in sorted_slice_instructions]
-    projected_and_filtered = {
-        partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
-        .take(offsets)
-        .to_pylist()
-        for partition_field in partition_fields
-    }
-
-    table_partitions = []
-    for idx, inst in enumerate(sorted_slice_instructions):
-        partition_slice = arrow_table.slice(**inst)
-        fieldvalues = [
-            PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx])
-            for partition_field in partition_fields
-        ]
-        partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
-        table_partitions.append(_TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))
-    return table_partitions
-
-
 def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[_TablePartition]:
     """Based on the iceberg table partition spec, slice the arrow table into partitions with their keys.
 
@@ -2594,42 +2566,46 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
     We then retrieve the partition keys by offsets.
     And slice the arrow table by offsets and lengths of each partition.
     """
-    partition_columns: List[Tuple[PartitionField, NestedField]] = [
-        (partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields
-    ]
-    partition_values_table = pa.table(
-        {
-            str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
-            for partition, field in partition_columns
-        }
-    )
+    # Assign unique names to columns where the partition transform has been applied
+    # to avoid conflicts
+    partition_fields = [f"_partition_{field.name}" for field in spec.fields]
+
+    for partition, name in zip(spec.fields, partition_fields):
+        source_field = schema.find_field(partition.source_id)
+        arrow_table = arrow_table.append_column(
+            name, partition.transform.pyarrow_transform(source_field.field_type)(arrow_table[source_field.name])
+        )
+
+    unique_partition_fields = arrow_table.select(partition_fields).group_by(partition_fields).aggregate([])
+
+    table_partitions = []
+    # TODO: As a next step, we could also play around with yielding instead of materializing the full list
+    for unique_partition in unique_partition_fields.to_pylist():
+        partition_key = PartitionKey(
+            raw_partition_field_values=[
+                PartitionFieldValue(field=field, value=unique_partition[name])
+                for field, name in zip(spec.fields, partition_fields)
+            ],
+            partition_spec=spec,
+            schema=schema,
+        )
+        filtered_table = arrow_table.filter(
+            functools.reduce(
+                operator.and_,
+                [
+                    pc.field(partition_field_name) == unique_partition[partition_field_name]
+                    if unique_partition[partition_field_name] is not None
+                    else pc.field(partition_field_name).is_null()
+                    for field, partition_field_name in zip(spec.fields, partition_fields)
+                ],
+            )
+        )
+        filtered_table = filtered_table.drop_columns(partition_fields)
 
-    # Sort by partitions
-    sort_indices = pa.compute.sort_indices(
-        partition_values_table,
-        sort_keys=[(col, "ascending") for col in partition_values_table.column_names],
-        null_placement="at_end",
-    ).to_pylist()
-    arrow_table = arrow_table.take(sort_indices)
-
-    # Get slice_instructions to group by partitions
-    partition_values_table = partition_values_table.take(sort_indices)
-    reversed_indices = pa.compute.sort_indices(
-        partition_values_table,
-        sort_keys=[(col, "descending") for col in partition_values_table.column_names],
-        null_placement="at_start",
-    ).to_pylist()
-    slice_instructions: List[Dict[str, Any]] = []
-    last = len(reversed_indices)
-    reversed_indices_size = len(reversed_indices)
-    ptr = 0
-    while ptr < reversed_indices_size:
-        group_size = last - reversed_indices[ptr]
-        offset = reversed_indices[ptr]
-        slice_instructions.append({"offset": offset, "length": group_size})
-        last = reversed_indices[ptr]
-        ptr = ptr + group_size
-
-    table_partitions: List[_TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)
+        # The combine_chunks seems to be counter-intuitive to do, but it actually returns
+        # fresh buffers that don't interfere with each other when it is written out to file
+        table_partitions.append(
+            _TablePartition(partition_key=partition_key, arrow_table_partition=filtered_table.combine_chunks())
+        )
 
     return table_partitions
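
For reference, below is a minimal standalone sketch of the grouping approach the new _determine_partitions takes, using plain PyArrow on a toy table. The table, the "_partition_day" column name, and the identity transform are illustrative only (not part of the PR), and expression-based Table.filter needs a reasonably recent PyArrow.

# Standalone illustration of the group-by/filter partitioning pattern (not PR code).
import functools
import operator

import pyarrow as pa
import pyarrow.compute as pc

table = pa.table({"day": ["mon", "tue", "mon", None], "value": [1, 2, 3, 4]})

# Append the "transformed" partition column; here an identity transform stands in
# for the Iceberg partition transform applied in the real code.
table = table.append_column("_partition_day", table["day"])

partition_cols = ["_partition_day"]
# Unique partition values, found by grouping on the partition columns only.
unique_keys = table.select(partition_cols).group_by(partition_cols).aggregate([])

partitions = []
for key in unique_keys.to_pylist():
    # One predicate per partition column, using is_null() for None keys,
    # ANDed together with functools.reduce / operator.and_.
    predicate = functools.reduce(
        operator.and_,
        [
            pc.field(col) == key[col] if key[col] is not None else pc.field(col).is_null()
            for col in partition_cols
        ],
    )
    part = table.filter(predicate).drop_columns(partition_cols)
    # combine_chunks() gives each partition its own contiguous buffers.
    partitions.append((key, part.combine_chunks()))

for key, part in partitions:
    print(key, part.num_rows)
# e.g. {'_partition_day': 'mon'} 2  (group order is not guaranteed)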