From 7fde02ea53f423ca25740e9339d8ff1db8f73052 Mon Sep 17 00:00:00 2001
From: Strophox <strophox@gmail.com>
Date: Wed, 31 Jul 2024 09:41:14 +0200
Subject: [PATCH] enable Miri to pass const pointers through FFI

Co-authored-by: Ralf Jung <post@ralfj.de>
---
 .../rustc_const_eval/src/interpret/machine.rs | 12 +--
 .../src/mir/interpret/allocation.rs           |  3 +-
 src/tools/miri/src/alloc_addresses/mod.rs     | 68 ++++++++++++++-
 src/tools/miri/src/concurrency/thread.rs      |  2 +-
 src/tools/miri/src/lib.rs                     |  3 +
 src/tools/miri/src/machine.rs                 | 25 ++++++
 src/tools/miri/src/shims/native_lib.rs        | 13 +++
 .../{libtest.map => native-lib.map}           |  8 ++
 .../tests/native-lib/pass/ptr_read_access.rs  | 82 +++++++++++++++++++
 .../native-lib/pass/ptr_read_access.stdout    |  1 +
 ...all_extern_c_fn.rs => scalar_arguments.rs} |  0
 ...rn_c_fn.stdout => scalar_arguments.stdout} |  0
 .../miri/tests/native-lib/ptr_read_access.c   | 47 +++++++++++
 .../native-lib/{test.c => scalar_arguments.c} |  0
 src/tools/miri/tests/ui.rs                    | 13 ++-
 15 files changed, 260 insertions(+), 17 deletions(-)
 rename src/tools/miri/tests/native-lib/{libtest.map => native-lib.map} (62%)
 create mode 100644 src/tools/miri/tests/native-lib/pass/ptr_read_access.rs
 create mode 100644 src/tools/miri/tests/native-lib/pass/ptr_read_access.stdout
 rename src/tools/miri/tests/native-lib/pass/{call_extern_c_fn.rs => scalar_arguments.rs} (100%)
 rename src/tools/miri/tests/native-lib/pass/{call_extern_c_fn.stdout => scalar_arguments.stdout} (100%)
 create mode 100644 src/tools/miri/tests/native-lib/ptr_read_access.c
 rename src/tools/miri/tests/native-lib/{test.c => scalar_arguments.c} (100%)

diff --git a/compiler/rustc_const_eval/src/interpret/machine.rs b/compiler/rustc_const_eval/src/interpret/machine.rs
index 761ab81e22842..cff0612d58edb 100644
--- a/compiler/rustc_const_eval/src/interpret/machine.rs
+++ b/compiler/rustc_const_eval/src/interpret/machine.rs
@@ -357,17 +357,7 @@ pub trait Machine<'tcx>: Sized {
         ecx: &InterpCx<'tcx, Self>,
         id: AllocId,
         alloc: &'b Allocation,
-    ) -> InterpResult<'tcx, Cow<'b, Allocation<Self::Provenance, Self::AllocExtra, Self::Bytes>>>
-    {
-        // The default implementation does a copy; CTFE machines have a more efficient implementation
-        // based on their particular choice for `Provenance`, `AllocExtra`, and `Bytes`.
-        let kind = Self::GLOBAL_KIND
-            .expect("if GLOBAL_KIND is None, adjust_global_allocation must be overwritten");
-        let alloc = alloc.adjust_from_tcx(&ecx.tcx, |ptr| ecx.global_root_pointer(ptr))?;
-        let extra =
-            Self::init_alloc_extra(ecx, id, MemoryKind::Machine(kind), alloc.size(), alloc.align)?;
-        Ok(Cow::Owned(alloc.with_extra(extra)))
-    }
+    ) -> InterpResult<'tcx, Cow<'b, Allocation<Self::Provenance, Self::AllocExtra, Self::Bytes>>>;
 
     /// Initialize the extra state of an allocation.
     ///
diff --git a/compiler/rustc_middle/src/mir/interpret/allocation.rs b/compiler/rustc_middle/src/mir/interpret/allocation.rs
index 3e101c0c6354d..5fb8af576ae93 100644
--- a/compiler/rustc_middle/src/mir/interpret/allocation.rs
+++ b/compiler/rustc_middle/src/mir/interpret/allocation.rs
@@ -358,10 +358,11 @@ impl Allocation {
     pub fn adjust_from_tcx<Prov: Provenance, Bytes: AllocBytes, Err>(
         &self,
         cx: &impl HasDataLayout,
+        mut alloc_bytes: impl FnMut(&[u8], Align) -> Result<Bytes, Err>,
         mut adjust_ptr: impl FnMut(Pointer<CtfeProvenance>) -> Result<Pointer<Prov>, Err>,
     ) -> Result<Allocation<Prov, (), Bytes>, Err> {
         // Copy the data.
-        let mut bytes = Bytes::from_bytes(Cow::Borrowed(&*self.bytes), self.align);
+        let mut bytes = alloc_bytes(&*self.bytes, self.align)?;
         // Adjust provenance of pointers stored in this allocation.
         let mut new_provenance = Vec::with_capacity(self.provenance.ptrs().len());
         let ptr_size = cx.data_layout().pointer_size.bytes_usize();
diff --git a/src/tools/miri/src/alloc_addresses/mod.rs b/src/tools/miri/src/alloc_addresses/mod.rs
index ed955e78c3e9a..76c68add8cdc9 100644
--- a/src/tools/miri/src/alloc_addresses/mod.rs
+++ b/src/tools/miri/src/alloc_addresses/mod.rs
@@ -42,6 +42,11 @@ pub struct GlobalStateInner {
     /// they do not have an `AllocExtra`.
     /// This is the inverse of `int_to_ptr_map`.
     base_addr: FxHashMap<AllocId, u64>,
+    /// Temporarily store prepared memory space for global allocations the first time their memory
+    /// address is required. This is used to ensure that the memory is allocated before Miri assigns
+    /// it an internal address, which is important for matching the internal address to the machine
+    /// address so FFI can read from pointers.
+    prepared_alloc_bytes: FxHashMap<AllocId, MiriAllocBytes>,
     /// A pool of addresses we can reuse for future allocations.
     reuse: ReusePool,
     /// Whether an allocation has been exposed or not. This cannot be put
@@ -59,6 +64,7 @@ impl VisitProvenance for GlobalStateInner {
         let GlobalStateInner {
             int_to_ptr_map: _,
             base_addr: _,
+            prepared_alloc_bytes: _,
             reuse: _,
             exposed: _,
             next_base_addr: _,
@@ -78,6 +84,7 @@ impl GlobalStateInner {
         GlobalStateInner {
             int_to_ptr_map: Vec::default(),
             base_addr: FxHashMap::default(),
+            prepared_alloc_bytes: FxHashMap::default(),
             reuse: ReusePool::new(config),
             exposed: FxHashSet::default(),
             next_base_addr: stack_addr,
@@ -166,7 +173,39 @@ trait EvalContextExtPriv<'tcx>: crate::MiriInterpCxExt<'tcx> {
                 assert!(!matches!(kind, AllocKind::Dead));
 
                 // This allocation does not have a base address yet, pick or reuse one.
-                let base_addr = if let Some((reuse_addr, clock)) = global_state.reuse.take_addr(
+                let base_addr = if ecx.machine.native_lib.is_some() {
+                    // In native lib mode, we use the "real" address of the bytes for this allocation.
+                    // This ensures the interpreted program and native code have the same view of memory.
+                    match kind {
+                        AllocKind::LiveData => {
+                            let ptr = if ecx.tcx.try_get_global_alloc(alloc_id).is_some() {
+                                // For new global allocations, we always pre-allocate the memory to be able use the machine address directly.
+                                let prepared_bytes = MiriAllocBytes::zeroed(size, align)
+                                    .unwrap_or_else(|| {
+                                        panic!("Miri ran out of memory: cannot create allocation of {size:?} bytes")
+                                    });
+                                let ptr = prepared_bytes.as_ptr();
+                                    // Store prepared allocation space to be picked up for use later.
+                                    global_state.prepared_alloc_bytes.try_insert(alloc_id, prepared_bytes).unwrap();
+                                ptr
+                            } else {
+                                ecx.get_alloc_bytes_unchecked_raw(alloc_id)?
+                            };
+                            // Ensure this pointer's provenance is exposed, so that it can be used by FFI code.
+                            ptr.expose_provenance().try_into().unwrap()
+                        }
+                        AllocKind::Function | AllocKind::VTable => {
+                            // Allocate some dummy memory to get a unique address for this function/vtable.
+                            let alloc_bytes = MiriAllocBytes::from_bytes(&[0u8; 1], Align::from_bytes(1).unwrap());
+                            // We don't need to expose these bytes as nobody is allowed to access them.
+                            let addr = alloc_bytes.as_ptr().addr().try_into().unwrap();
+                            // Leak the underlying memory to ensure it remains unique.
+                            std::mem::forget(alloc_bytes);
+                            addr
+                        }
+                        AllocKind::Dead => unreachable!()
+                    }
+                } else if let Some((reuse_addr, clock)) = global_state.reuse.take_addr(                    
                     &mut *rng,
                     size,
                     align,
@@ -318,6 +357,33 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
         Ok(base_ptr.wrapping_offset(offset, ecx))
     }
 
+    // This returns some prepared `MiriAllocBytes`, either because `addr_from_alloc_id` reserved
+    // memory space in the past, or by doing the pre-allocation right upon being called.
+    fn get_global_alloc_bytes(&self, id: AllocId, kind: MemoryKind, bytes: &[u8], align: Align) -> InterpResult<'tcx, MiriAllocBytes> {
+        let ecx = self.eval_context_ref();
+        Ok(if ecx.machine.native_lib.is_some() {
+            // In native lib mode, MiriAllocBytes for global allocations are handled via `prepared_alloc_bytes`.
+            // This additional call ensures that some `MiriAllocBytes` are always prepared.
+            ecx.addr_from_alloc_id(id, kind)?;
+            let mut global_state = ecx.machine.alloc_addresses.borrow_mut();
+            // The memory we need here will have already been allocated during an earlier call to
+            // `addr_from_alloc_id` for this allocation. So don't create a new `MiriAllocBytes` here, instead
+            // fetch the previously prepared bytes from `prepared_alloc_bytes`.
+            let mut prepared_alloc_bytes = global_state
+                .prepared_alloc_bytes
+                .remove(&id)
+                .unwrap_or_else(|| panic!("alloc bytes for {id:?} have not been prepared"));
+            // Sanity-check that the prepared allocation has the right size and alignment.
+            assert!(prepared_alloc_bytes.as_ptr().is_aligned_to(align.bytes_usize()));
+            assert_eq!(prepared_alloc_bytes.len(), bytes.len());
+            // Copy allocation contents into prepared memory.
+            prepared_alloc_bytes.copy_from_slice(bytes);
+            prepared_alloc_bytes
+        } else {
+            MiriAllocBytes::from_bytes(std::borrow::Cow::Borrowed(&*bytes), align)
+        })
+    }
+
     /// When a pointer is used for a memory access, this computes where in which allocation the
     /// access is going.
     fn ptr_get_alloc(
diff --git a/src/tools/miri/src/concurrency/thread.rs b/src/tools/miri/src/concurrency/thread.rs
index a4d3e3f7af3ce..37d4a2663e012 100644
--- a/src/tools/miri/src/concurrency/thread.rs
+++ b/src/tools/miri/src/concurrency/thread.rs
@@ -887,7 +887,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
             let alloc = this.ctfe_query(|tcx| tcx.eval_static_initializer(def_id))?;
             // We make a full copy of this allocation.
             let mut alloc =
-                alloc.inner().adjust_from_tcx(&this.tcx, |ptr| this.global_root_pointer(ptr))?;
+                alloc.inner().adjust_from_tcx(&this.tcx, |bytes, align| Ok(MiriAllocBytes::from_bytes(std::borrow::Cow::Borrowed(bytes), align)), |ptr| this.global_root_pointer(ptr))?;
             // This allocation will be deallocated when the thread dies, so it is not in read-only memory.
             alloc.mutability = Mutability::Mut;
             // Create a fresh allocation with this content.
diff --git a/src/tools/miri/src/lib.rs b/src/tools/miri/src/lib.rs
index f796e845a2399..f70ca098fb95b 100644
--- a/src/tools/miri/src/lib.rs
+++ b/src/tools/miri/src/lib.rs
@@ -12,6 +12,9 @@
 #![feature(let_chains)]
 #![feature(trait_upcasting)]
 #![feature(strict_overflow_ops)]
+#![feature(strict_provenance)]
+#![feature(exposed_provenance)]
+#![feature(pointer_is_aligned_to)]
 // Configure clippy and other lints
 #![allow(
     clippy::collapsible_else_if,
diff --git a/src/tools/miri/src/machine.rs b/src/tools/miri/src/machine.rs
index a7493d48d6a2d..35bef5e6ed513 100644
--- a/src/tools/miri/src/machine.rs
+++ b/src/tools/miri/src/machine.rs
@@ -1,6 +1,7 @@
 //! Global machine state as well as implementation of the interpreter engine
 //! `Machine` trait.
 
+use std::borrow::Cow;
 use std::cell::RefCell;
 use std::collections::hash_map::Entry;
 use std::fmt;
@@ -1225,6 +1226,30 @@ impl<'tcx> Machine<'tcx> for MiriMachine<'tcx> {
         })
     }
 
+    /// Called to adjust global allocations to the Provenance and AllocExtra of this machine.
+    ///
+    /// If `alloc` contains pointers, then they are all pointing to globals.
+    ///
+    /// This should avoid copying if no work has to be done! If this returns an owned
+    /// allocation (because a copy had to be done to adjust things), machine memory will
+    /// cache the result. (This relies on `AllocMap::get_or` being able to add the
+    /// owned allocation to the map even when the map is shared.)
+    fn adjust_global_allocation<'b>(
+        ecx: &InterpCx<'tcx, Self>,
+        id: AllocId,
+        alloc: &'b Allocation,
+    ) -> InterpResult<'tcx, Cow<'b, Allocation<Self::Provenance, Self::AllocExtra, Self::Bytes>>>
+    {
+        let kind = Self::GLOBAL_KIND.unwrap().into();
+        let alloc = alloc.adjust_from_tcx(&ecx.tcx,
+            |bytes, align| ecx.get_global_alloc_bytes(id, kind, bytes, align),
+            |ptr| ecx.global_root_pointer(ptr),
+        )?;
+        let extra =
+            Self::init_alloc_extra(ecx, id, kind, alloc.size(), alloc.align)?;
+        Ok(Cow::Owned(alloc.with_extra(extra)))
+    }
+
     #[inline(always)]
     fn before_memory_read(
         _tcx: TyCtxtAt<'tcx>,
diff --git a/src/tools/miri/src/shims/native_lib.rs b/src/tools/miri/src/shims/native_lib.rs
index 40697e17ba193..25d96fbd7331f 100644
--- a/src/tools/miri/src/shims/native_lib.rs
+++ b/src/tools/miri/src/shims/native_lib.rs
@@ -194,6 +194,8 @@ enum CArg {
     UInt64(u64),
     /// usize.
     USize(usize),
+    /// Raw pointer, stored as C's `void*`.
+    RawPtr(*mut std::ffi::c_void),
 }
 
 impl<'a> CArg {
@@ -210,6 +212,7 @@ impl<'a> CArg {
             CArg::UInt32(i) => ffi::arg(i),
             CArg::UInt64(i) => ffi::arg(i),
             CArg::USize(i) => ffi::arg(i),
+            CArg::RawPtr(i) => ffi::arg(i),
         }
     }
 }
@@ -234,6 +237,16 @@ fn imm_to_carg<'tcx>(v: ImmTy<'tcx>, cx: &impl HasDataLayout) -> InterpResult<'t
         ty::Uint(UintTy::U64) => CArg::UInt64(v.to_scalar().to_u64()?),
         ty::Uint(UintTy::Usize) =>
             CArg::USize(v.to_scalar().to_target_usize(cx)?.try_into().unwrap()),
+        ty::RawPtr(_, mutability) => {
+            // Arbitrary mutable pointer accesses are not currently supported in Miri.
+            if mutability.is_mut() {
+                throw_unsup_format!("unsupported mutable pointer type for native call: {}", v.layout.ty);
+            } else {
+                let s = v.to_scalar().to_pointer(cx)?.addr();
+                // This relies on the `expose_provenance` in `addr_from_alloc_id`.
+                CArg::RawPtr(std::ptr::with_exposed_provenance_mut(s.bytes_usize()))
+            }
+        },
         _ => throw_unsup_format!("unsupported argument type for native call: {}", v.layout.ty),
     })
 }
diff --git a/src/tools/miri/tests/native-lib/libtest.map b/src/tools/miri/tests/native-lib/native-lib.map
similarity index 62%
rename from src/tools/miri/tests/native-lib/libtest.map
rename to src/tools/miri/tests/native-lib/native-lib.map
index a57a4dc149feb..7e3bd19622af3 100644
--- a/src/tools/miri/tests/native-lib/libtest.map
+++ b/src/tools/miri/tests/native-lib/native-lib.map
@@ -1,12 +1,20 @@
 CODEABI_1.0 {
     # Define which symbols to export.
     global:
+        # scalar_arguments.c
         add_one_int;
         printer;
         test_stack_spill;
         get_unsigned_int;
         add_int16;
         add_short_to_long;
+
+        # ptr_read_access.c
+        print_pointer;
+        access_simple;
+        access_nested;
+        access_static;
+
     # The rest remains private.
     local: *;
 };
diff --git a/src/tools/miri/tests/native-lib/pass/ptr_read_access.rs b/src/tools/miri/tests/native-lib/pass/ptr_read_access.rs
new file mode 100644
index 0000000000000..d8e6209839e2a
--- /dev/null
+++ b/src/tools/miri/tests/native-lib/pass/ptr_read_access.rs
@@ -0,0 +1,82 @@
+//@only-target-linux
+//@only-on-host
+
+fn main() {
+    test_pointer();
+
+    test_simple();
+
+    test_nested();
+
+    test_static();
+}
+
+// Test void function that dereferences a pointer and prints its contents from C.
+fn test_pointer() {
+    extern "C" {
+        fn print_pointer(ptr: *const i32);
+    }
+
+    let x = 42;
+
+    unsafe { print_pointer(&x) };
+}
+
+// Test function that dereferences a simple struct pointer and accesses a field.
+fn test_simple() {
+    #[repr(C)]
+    struct Simple {
+        field: i32
+    }
+
+    extern "C" {
+        fn access_simple(s_ptr: *const Simple) -> i32;
+    }
+
+    let simple = Simple { field: -42 };
+
+    assert_eq!(unsafe { access_simple(&simple) }, -42);
+}
+
+// Test function that dereferences nested struct pointers and accesses fields.
+fn test_nested() {
+    use std::ptr::NonNull;
+    
+    #[derive(Debug, PartialEq, Eq)]
+    #[repr(C)]
+    struct Nested {
+        value: i32,
+        next: Option<NonNull<Nested>>,
+    }
+
+    extern "C" {
+        fn access_nested(n_ptr: *const Nested) -> i32;
+    }
+
+    let mut nested_0 = Nested { value: 97, next: None };
+    let mut nested_1 = Nested { value: 98, next: NonNull::new(&mut nested_0) };
+    let nested_2 = Nested { value: 99, next: NonNull::new(&mut nested_1) };
+
+    assert_eq!(unsafe { access_nested(&nested_2) }, 97);
+}
+
+// Test function that dereferences static struct pointers and accesses fields.
+fn test_static() {
+
+    #[repr(C)]
+    struct Static {
+        value: i32,
+        recurse: &'static Static,
+    }
+
+    extern "C" {
+        fn access_static(n_ptr: *const Static) -> i32;
+    }
+    
+    static STATIC: Static = Static {
+        value: 9001,
+        recurse: &STATIC,
+    };
+
+    assert_eq!(unsafe { access_static(&STATIC) }, 9001);
+}
diff --git a/src/tools/miri/tests/native-lib/pass/ptr_read_access.stdout b/src/tools/miri/tests/native-lib/pass/ptr_read_access.stdout
new file mode 100644
index 0000000000000..1a8799abfc93e
--- /dev/null
+++ b/src/tools/miri/tests/native-lib/pass/ptr_read_access.stdout
@@ -0,0 +1 @@
+printing pointer dereference from C: 42
diff --git a/src/tools/miri/tests/native-lib/pass/call_extern_c_fn.rs b/src/tools/miri/tests/native-lib/pass/scalar_arguments.rs
similarity index 100%
rename from src/tools/miri/tests/native-lib/pass/call_extern_c_fn.rs
rename to src/tools/miri/tests/native-lib/pass/scalar_arguments.rs
diff --git a/src/tools/miri/tests/native-lib/pass/call_extern_c_fn.stdout b/src/tools/miri/tests/native-lib/pass/scalar_arguments.stdout
similarity index 100%
rename from src/tools/miri/tests/native-lib/pass/call_extern_c_fn.stdout
rename to src/tools/miri/tests/native-lib/pass/scalar_arguments.stdout
diff --git a/src/tools/miri/tests/native-lib/ptr_read_access.c b/src/tools/miri/tests/native-lib/ptr_read_access.c
new file mode 100644
index 0000000000000..03b9189e2e86d
--- /dev/null
+++ b/src/tools/miri/tests/native-lib/ptr_read_access.c
@@ -0,0 +1,47 @@
+#include <stdio.h>
+
+/* Test: test_pointer */
+
+void print_pointer(const int *ptr) {
+  printf("printing pointer dereference from C: %d\n", *ptr);
+}
+
+/* Test: test_simple */
+
+typedef struct Simple {
+  int field;
+} Simple;
+
+int access_simple(const Simple *s_ptr) {
+  return s_ptr->field;
+}
+
+/* Test: test_nested */
+
+typedef struct Nested {
+  int value;
+  struct Nested *next;
+} Nested;
+
+// Returns the innermost/last value of a Nested pointer chain.
+int access_nested(const Nested *n_ptr) {
+  // Edge case: `n_ptr == NULL` (i.e. first Nested is None).
+  if (!n_ptr) { return 0; }
+
+  while (n_ptr->next) {
+    n_ptr = n_ptr->next;
+  }
+
+  return n_ptr->value;
+}
+
+/* Test: test_static */
+
+typedef struct Static {
+    int value;
+    struct Static *recurse;
+} Static;
+
+int access_static(const Static *s_ptr) {
+  return s_ptr->recurse->recurse->value;
+}
diff --git a/src/tools/miri/tests/native-lib/test.c b/src/tools/miri/tests/native-lib/scalar_arguments.c
similarity index 100%
rename from src/tools/miri/tests/native-lib/test.c
rename to src/tools/miri/tests/native-lib/scalar_arguments.c
diff --git a/src/tools/miri/tests/ui.rs b/src/tools/miri/tests/ui.rs
index 9cbcf6e42a795..c510ef95c30e8 100644
--- a/src/tools/miri/tests/ui.rs
+++ b/src/tools/miri/tests/ui.rs
@@ -36,18 +36,25 @@ fn build_native_lib() -> PathBuf {
     // Create the directory if it does not already exist.
     std::fs::create_dir_all(&so_target_dir)
         .expect("Failed to create directory for shared object file");
-    let so_file_path = so_target_dir.join("libtestlib.so");
+    let so_file_path = so_target_dir.join("native-lib.so");
     let cc_output = Command::new(cc)
         .args([
             "-shared",
             "-o",
             so_file_path.to_str().unwrap(),
-            "tests/native-lib/test.c",
+            // FIXME: Automate gathering of all relevant C source files in the directory.
+            "tests/native-lib/scalar_arguments.c",
+            "tests/native-lib/ptr_read_access.c",
             // Only add the functions specified in libcode.version to the shared object file.
             // This is to avoid automatically adding `malloc`, etc.
             // Source: https://anadoxin.org/blog/control-over-symbol-exports-in-gcc.html/
             "-fPIC",
-            "-Wl,--version-script=tests/native-lib/libtest.map",
+            "-Wl,--version-script=tests/native-lib/native-lib.map",
+            // Ensure we notice serious problems in the C code.
+            "-Wall",
+            "-Wextra",
+            "-Wpedantic",
+            "-Werror",
         ])
         .output()
         .expect("failed to generate shared object file for testing native function calls");