@@ -29,7 +29,7 @@ use datafusion_common::Result;
29
29
use datafusion_execution:: memory_pool:: proxy:: VecAllocExt ;
30
30
use datafusion_expr:: EmitTo ;
31
31
use half:: f16;
32
- use hashbrown:: raw :: RawTable ;
32
+ use hashbrown:: hash_table :: HashTable ;
33
33
use std:: mem:: size_of;
34
34
use std:: sync:: Arc ;
35
35
@@ -86,7 +86,7 @@ pub struct GroupValuesPrimitive<T: ArrowPrimitiveType> {
86
86
///
87
87
/// We don't store the hashes as hashing fixed width primitives
88
88
/// is fast enough for this not to benefit performance
89
- map : RawTable < usize > ,
89
+ map : HashTable < usize > ,
90
90
/// The group index of the null value if any
91
91
null_group : Option < usize > ,
92
92
/// The values for each group index
@@ -100,7 +100,7 @@ impl<T: ArrowPrimitiveType> GroupValuesPrimitive<T> {
100
100
assert ! ( PrimitiveArray :: <T >:: is_compatible( & data_type) ) ;
101
101
Self {
102
102
data_type,
103
- map : RawTable :: with_capacity ( 128 ) ,
103
+ map : HashTable :: with_capacity ( 128 ) ,
104
104
values : Vec :: with_capacity ( 128 ) ,
105
105
null_group : None ,
106
106
random_state : Default :: default ( ) ,
@@ -126,22 +126,19 @@ where
126
126
Some ( key) => {
127
127
let state = & self . random_state ;
128
128
let hash = key. hash ( state) ;
129
- let insert = self . map . find_or_find_insert_slot (
129
+ let insert = self . map . entry (
130
130
hash,
131
131
|g| unsafe { self . values . get_unchecked ( * g) . is_eq ( key) } ,
132
132
|g| unsafe { self . values . get_unchecked ( * g) . hash ( state) } ,
133
133
) ;
134
134
135
- // SAFETY: No mutation occurred since find_or_find_insert_slot
136
- unsafe {
137
- match insert {
138
- Ok ( v) => * v. as_ref ( ) ,
139
- Err ( slot) => {
140
- let g = self . values . len ( ) ;
141
- self . map . insert_in_slot ( hash, slot, g) ;
142
- self . values . push ( key) ;
143
- g
144
- }
135
+ match insert {
136
+ hashbrown:: hash_table:: Entry :: Occupied ( o) => * o. get ( ) ,
137
+ hashbrown:: hash_table:: Entry :: Vacant ( v) => {
138
+ let g = self . values . len ( ) ;
139
+ v. insert ( g) ;
140
+ self . values . push ( key) ;
141
+ g
145
142
}
146
143
}
147
144
}
@@ -183,18 +180,18 @@ where
183
180
build_primitive ( std:: mem:: take ( & mut self . values ) , self . null_group . take ( ) )
184
181
}
185
182
EmitTo :: First ( n) => {
186
- // SAFETY: self.map outlives iterator and is not modified concurrently
187
- unsafe {
188
- for bucket in self . map . iter ( ) {
189
- // Decrement group index by n
190
- match bucket. as_ref ( ) . checked_sub ( n) {
191
- // Group index was >= n, shift value down
192
- Some ( sub) => * bucket. as_mut ( ) = sub,
193
- // Group index was < n, so remove from table
194
- None => self . map . erase ( bucket) ,
183
+ self . map . retain ( |group_idx| {
184
+ // Decrement group index by n
185
+ match group_idx. checked_sub ( n) {
186
+ // Group index was >= n, shift value down
187
+ Some ( sub) => {
188
+ * group_idx = sub;
189
+ true
195
190
}
191
+ // Group index was < n, so remove from table
192
+ None => false ,
196
193
}
197
- }
194
+ } ) ;
198
195
let null_group = match & mut self . null_group {
199
196
Some ( v) if * v >= n => {
200
197
* v -= n;
0 commit comments