From 0d2548a173359da5075585e1f5ccb678d87ad24b Mon Sep 17 00:00:00 2001
From: Stein Somers <git@steinsomers.be>
Date: Mon, 23 Nov 2020 14:41:53 +0100
Subject: [PATCH] BTreeMap: declare exclusive access to arrays when copying
 from them

---
 library/alloc/src/collections/btree/node.rs | 81 +++++----------------
 1 file changed, 17 insertions(+), 64 deletions(-)

diff --git a/library/alloc/src/collections/btree/node.rs b/library/alloc/src/collections/btree/node.rs
index 769383515b7f1..16b4b1091eff7 100644
--- a/library/alloc/src/collections/btree/node.rs
+++ b/library/alloc/src/collections/btree/node.rs
@@ -295,15 +295,6 @@ impl<BorrowType, K, V> NodeRef<BorrowType, K, V, marker::Internal> {
     }
 }
 
-impl<'a, K, V> NodeRef<marker::Immut<'a>, K, V, marker::Internal> {
-    /// Exposes the data of an internal node in an immutable tree.
-    fn as_internal(this: &Self) -> &'a InternalNode<K, V> {
-        let ptr = Self::as_internal_ptr(this);
-        // SAFETY: there can be no mutable references into this tree borrowed as `Immut`.
-        unsafe { &*ptr }
-    }
-}
-
 impl<'a, K, V> NodeRef<marker::Mut<'a>, K, V, marker::Internal> {
     /// Borrows exclusive access to the data of an internal node.
     fn as_internal_mut(&mut self) -> &mut InternalNode<K, V> {
@@ -368,17 +359,6 @@ impl<'a, K: 'a, V: 'a, Type> NodeRef<marker::Immut<'a>, K, V, Type> {
     }
 }
 
-impl<'a, K, V> NodeRef<marker::Immut<'a>, K, V, marker::Internal> {
-    /// Exposes the contents of one of the edges in the node.
-    ///
-    /// # Safety
-    /// The node has more than `idx` initialized elements.
-    unsafe fn edge_at(self, idx: usize) -> &'a BoxedNode<K, V> {
-        debug_assert!(idx <= self.len());
-        unsafe { Self::as_internal(&self).edges.get_unchecked(idx).assume_init_ref() }
-    }
-}
-
 impl<BorrowType, K, V, Type> NodeRef<BorrowType, K, V, Type> {
     /// Finds the parent of the current node. Returns `Ok(handle)` if the current
     /// node actually has a parent, where `handle` points to the edge of the parent
@@ -550,31 +530,6 @@ impl<'a, K: 'a, V: 'a> NodeRef<marker::Mut<'a>, K, V, marker::Internal> {
     }
 }
 
-impl<'a, K: 'a, V: 'a, Type> NodeRef<marker::Immut<'a>, K, V, Type> {
-    /// Exposes the entire key storage area in the node,
-    /// regardless of the node's current length,
-    /// having exclusive access to the entire node.
-    unsafe fn key_area(self) -> &'a [MaybeUninit<K>] {
-        self.into_leaf().keys.as_slice()
-    }
-
-    /// Exposes the entire value storage area in the node,
-    /// regardless of the node's current length,
-    /// having exclusive access to the entire node.
-    unsafe fn val_area(self) -> &'a [MaybeUninit<V>] {
-        self.into_leaf().vals.as_slice()
-    }
-}
-
-impl<'a, K: 'a, V: 'a> NodeRef<marker::Immut<'a>, K, V, marker::Internal> {
-    /// Exposes the entire storage area for edge contents in the node,
-    /// regardless of the node's current length,
-    /// having exclusive access to the entire node.
-    unsafe fn edge_area(self) -> &'a [MaybeUninit<BoxedNode<K, V>>] {
-        Self::as_internal(&self).edges.as_slice()
-    }
-}
-
 impl<'a, K, V, Type> NodeRef<marker::ValMut<'a>, K, V, Type> {
     /// # Safety
     /// - The node has more than `idx` initialized elements.
@@ -707,12 +662,12 @@ impl<'a, K: 'a, V: 'a> NodeRef<marker::Mut<'a>, K, V, marker::LeafOrInternal> {
         let idx = self.len() - 1;
 
         unsafe {
-            let key = ptr::read(self.reborrow().key_at(idx));
-            let val = ptr::read(self.reborrow().val_at(idx));
+            let key = self.key_area_mut_at(idx).assume_init_read();
+            let val = self.val_area_mut_at(idx).assume_init_read();
             let edge = match self.reborrow_mut().force() {
                 ForceResult::Leaf(_) => None,
-                ForceResult::Internal(internal) => {
-                    let node = ptr::read(internal.reborrow().edge_at(idx + 1));
+                ForceResult::Internal(mut internal) => {
+                    let node = internal.edge_area_mut_at(idx + 1).assume_init_read();
                     let mut edge = Root { node, height: internal.height - 1, _marker: PhantomData };
                     // Currently, clearing the parent link is superfluous, because we will
                     // insert the node elsewhere and set its parent link again.
@@ -1172,16 +1127,16 @@ impl<'a, K: 'a, V: 'a, NodeType> Handle<NodeRef<marker::Mut<'a>, K, V, NodeType>
         let new_len = self.node.len() - self.idx - 1;
         new_node.len = new_len as u16;
         unsafe {
-            let k = ptr::read(self.node.reborrow().key_at(self.idx));
-            let v = ptr::read(self.node.reborrow().val_at(self.idx));
+            let k = self.node.key_area_mut_at(self.idx).assume_init_read();
+            let v = self.node.val_area_mut_at(self.idx).assume_init_read();
 
             ptr::copy_nonoverlapping(
-                self.node.reborrow().key_area().as_ptr().add(self.idx + 1),
+                self.node.key_area_mut_at(self.idx + 1..).as_ptr(),
                 new_node.keys.as_mut_ptr(),
                 new_len,
             );
             ptr::copy_nonoverlapping(
-                self.node.reborrow().val_area().as_ptr().add(self.idx + 1),
+                self.node.val_area_mut_at(self.idx + 1..).as_ptr(),
                 new_node.vals.as_mut_ptr(),
                 new_len,
             );
@@ -1240,7 +1195,7 @@ impl<'a, K: 'a, V: 'a> Handle<NodeRef<marker::Mut<'a>, K, V, marker::Internal>,
             let kv = self.split_leaf_data(&mut new_node.data);
             let new_len = usize::from(new_node.data.len);
             ptr::copy_nonoverlapping(
-                self.node.reborrow().edge_area().as_ptr().add(self.idx + 1),
+                self.node.edge_area_mut_at(self.idx + 1..).as_ptr(),
                 new_node.edges.as_mut_ptr(),
                 new_len + 1,
             );
@@ -1352,7 +1307,7 @@ impl<'a, K: 'a, V: 'a> BalancingContext<'a, K, V> {
         let old_parent_len = parent_node.len();
         let mut left_node = self.left_child;
         let old_left_len = left_node.len();
-        let right_node = self.right_child;
+        let mut right_node = self.right_child;
         let right_len = right_node.len();
         let new_left_len = old_left_len + 1 + right_len;
 
@@ -1370,7 +1325,7 @@ impl<'a, K: 'a, V: 'a> BalancingContext<'a, K, V> {
                 slice_remove(parent_node.key_area_mut_at(..old_parent_len), parent_idx);
             left_node.key_area_mut_at(old_left_len).write(parent_key);
             ptr::copy_nonoverlapping(
-                right_node.reborrow().key_area().as_ptr(),
+                right_node.key_area_mut_at(..).as_ptr(),
                 left_node.key_area_mut_at(old_left_len + 1..).as_mut_ptr(),
                 right_len,
             );
@@ -1379,7 +1334,7 @@ impl<'a, K: 'a, V: 'a> BalancingContext<'a, K, V> {
                 slice_remove(parent_node.val_area_mut_at(..old_parent_len), parent_idx);
             left_node.val_area_mut_at(old_left_len).write(parent_val);
             ptr::copy_nonoverlapping(
-                right_node.reborrow().val_area().as_ptr(),
+                right_node.val_area_mut_at(..).as_ptr(),
                 left_node.val_area_mut_at(old_left_len + 1..).as_mut_ptr(),
                 right_len,
             );
@@ -1392,9 +1347,9 @@ impl<'a, K: 'a, V: 'a> BalancingContext<'a, K, V> {
                 // SAFETY: the height of the nodes being merged is one below the height
                 // of the node of this edge, thus above zero, so they are internal.
                 let mut left_node = left_node.reborrow_mut().cast_to_internal_unchecked();
-                let right_node = right_node.cast_to_internal_unchecked();
+                let mut right_node = right_node.cast_to_internal_unchecked();
                 ptr::copy_nonoverlapping(
-                    right_node.reborrow().edge_area().as_ptr(),
+                    right_node.edge_area_mut_at(..).as_ptr(),
                     left_node.edge_area_mut_at(old_left_len + 1..).as_mut_ptr(),
                     right_len + 1,
                 );
@@ -1503,7 +1458,6 @@ impl<'a, K: 'a, V: 'a> BalancingContext<'a, K, V> {
             match (left_node.reborrow_mut().force(), right_node.reborrow_mut().force()) {
                 (ForceResult::Internal(left), ForceResult::Internal(mut right)) => {
                     // Make room for stolen edges.
-                    let left = left.reborrow();
                     let right_edges = right.edge_area_mut_at(..).as_mut_ptr();
                     ptr::copy(right_edges, right_edges.add(count), old_right_len + 1);
                     right.correct_childrens_parent_links(count..new_right_len + 1);
@@ -1561,7 +1515,7 @@ impl<'a, K: 'a, V: 'a> BalancingContext<'a, K, V> {
             match (left_node.reborrow_mut().force(), right_node.reborrow_mut().force()) {
                 (ForceResult::Internal(left), ForceResult::Internal(mut right)) => {
                     // Steal edges.
-                    move_edges(right.reborrow(), 0, left, old_left_len + 1, count);
+                    move_edges(right.reborrow_mut(), 0, left, old_left_len + 1, count);
 
                     // Fill gap where stolen edges used to be.
                     let right_edges = right.edge_area_mut_at(..).as_mut_ptr();
@@ -1590,14 +1544,14 @@ unsafe fn move_kv<K, V>(
 
 // Source and destination must have the same height.
 unsafe fn move_edges<'a, K: 'a, V: 'a>(
-    source: NodeRef<marker::Immut<'a>, K, V, marker::Internal>,
+    mut source: NodeRef<marker::Mut<'a>, K, V, marker::Internal>,
     source_offset: usize,
     mut dest: NodeRef<marker::Mut<'a>, K, V, marker::Internal>,
     dest_offset: usize,
     count: usize,
 ) {
     unsafe {
-        let source_ptr = source.edge_area().as_ptr();
+        let source_ptr = source.edge_area_mut_at(..).as_ptr();
         let dest_ptr = dest.edge_area_mut_at(dest_offset..).as_mut_ptr();
         ptr::copy_nonoverlapping(source_ptr.add(source_offset), dest_ptr, count);
         dest.correct_childrens_parent_links(dest_offset..dest_offset + count);
@@ -1699,7 +1653,6 @@ impl<'a, K, V> Handle<NodeRef<marker::Mut<'a>, K, V, marker::LeafOrInternal>, ma
 
                 match (left_node.force(), right_node.force()) {
                     (ForceResult::Internal(left), ForceResult::Internal(right)) => {
-                        let left = left.reborrow();
                         move_edges(left, new_left_len + 1, right, 1, new_right_len);
                     }
                     (ForceResult::Leaf(_), ForceResult::Leaf(_)) => {}