diff --git a/compiler/rustc_smir/src/rustc_smir/mod.rs b/compiler/rustc_smir/src/rustc_smir/mod.rs index e772ae942fa5a..1eb8b0ca40617 100644 --- a/compiler/rustc_smir/src/rustc_smir/mod.rs +++ b/compiler/rustc_smir/src/rustc_smir/mod.rs @@ -287,9 +287,8 @@ impl<'tcx> Stable<'tcx> for mir::Body<'tcx> { type T = stable_mir::mir::Body; fn stable(&self, tables: &mut Tables<'tcx>) -> Self::T { - stable_mir::mir::Body { - blocks: self - .basic_blocks + stable_mir::mir::Body::new( + self.basic_blocks .iter() .map(|block| stable_mir::mir::BasicBlock { terminator: block.terminator().stable(tables), @@ -300,15 +299,15 @@ impl<'tcx> Stable<'tcx> for mir::Body<'tcx> { .collect(), }) .collect(), - locals: self - .local_decls + self.local_decls .iter() .map(|decl| stable_mir::mir::LocalDecl { ty: decl.ty.stable(tables), span: decl.source_info.span.stable(tables), }) .collect(), - } + self.arg_count, + ) } } diff --git a/compiler/stable_mir/src/mir/body.rs b/compiler/stable_mir/src/mir/body.rs index fc617513aeeed..d86d56b3dc724 100644 --- a/compiler/stable_mir/src/mir/body.rs +++ b/compiler/stable_mir/src/mir/body.rs @@ -2,10 +2,60 @@ use crate::ty::{AdtDef, ClosureDef, Const, CoroutineDef, GenericArgs, Movability use crate::Opaque; use crate::{ty::Ty, Span}; +/// The SMIR representation of a single function. #[derive(Clone, Debug)] pub struct Body { pub blocks: Vec, - pub locals: LocalDecls, + + // Declarations of locals within the function. + // + // The first local is the return value pointer, followed by `arg_count` + // locals for the function arguments, followed by any user-declared + // variables and temporaries. + locals: LocalDecls, + + // The number of arguments this function takes. + arg_count: usize, +} + +impl Body { + /// Constructs a `Body`. + /// + /// A constructor is required to build a `Body` from outside the crate + /// because the `arg_count` and `locals` fields are private. + pub fn new(blocks: Vec, locals: LocalDecls, arg_count: usize) -> Self { + // If locals doesn't contain enough entries, it can lead to panics in + // `ret_local`, `arg_locals`, and `inner_locals`. + assert!( + locals.len() > arg_count, + "A Body must contain at least a local for the return value and each of the function's arguments" + ); + Self { blocks, locals, arg_count } + } + + /// Return local that holds this function's return value. + pub fn ret_local(&self) -> &LocalDecl { + &self.locals[0] + } + + /// Locals in `self` that correspond to this function's arguments. + pub fn arg_locals(&self) -> &[LocalDecl] { + &self.locals[1..][..self.arg_count] + } + + /// Inner locals for this function. These are the locals that are + /// neither the return local nor the argument locals. + pub fn inner_locals(&self) -> &[LocalDecl] { + &self.locals[self.arg_count + 1..] + } + + /// Convenience function to get all the locals in this function. + /// + /// Locals are typically accessed via the more specific methods `ret_local`, + /// `arg_locals`, and `inner_locals`. + pub fn locals(&self) -> &[LocalDecl] { + &self.locals + } } type LocalDecls = Vec; @@ -467,7 +517,7 @@ pub enum NullOp { } impl Operand { - pub fn ty(&self, locals: &LocalDecls) -> Ty { + pub fn ty(&self, locals: &[LocalDecl]) -> Ty { match self { Operand::Copy(place) | Operand::Move(place) => place.ty(locals), Operand::Constant(c) => c.ty(), @@ -482,7 +532,7 @@ impl Constant { } impl Place { - pub fn ty(&self, locals: &LocalDecls) -> Ty { + pub fn ty(&self, locals: &[LocalDecl]) -> Ty { let _start_ty = locals[self.local].ty; todo!("Implement projection") } diff --git a/tests/ui-fulldeps/stable-mir/check_instance.rs b/tests/ui-fulldeps/stable-mir/check_instance.rs index ee82bc77aedae..a340877752d8f 100644 --- a/tests/ui-fulldeps/stable-mir/check_instance.rs +++ b/tests/ui-fulldeps/stable-mir/check_instance.rs @@ -59,7 +59,7 @@ fn test_body(body: mir::Body) { for term in body.blocks.iter().map(|bb| &bb.terminator) { match &term.kind { Call { func, .. } => { - let TyKind::RigidTy(ty) = func.ty(&body.locals).kind() else { unreachable!() }; + let TyKind::RigidTy(ty) = func.ty(body.locals()).kind() else { unreachable!() }; let RigidTy::FnDef(def, args) = ty else { unreachable!() }; let result = Instance::resolve(def, &args); assert!(result.is_ok()); diff --git a/tests/ui-fulldeps/stable-mir/crate-info.rs b/tests/ui-fulldeps/stable-mir/crate-info.rs index 60c6053d295b4..ed6b786f5e1de 100644 --- a/tests/ui-fulldeps/stable-mir/crate-info.rs +++ b/tests/ui-fulldeps/stable-mir/crate-info.rs @@ -47,7 +47,7 @@ fn test_stable_mir(_tcx: TyCtxt<'_>) -> ControlFlow<()> { let bar = get_item(&items, (DefKind::Fn, "bar")).unwrap(); let body = bar.body(); - assert_eq!(body.locals.len(), 2); + assert_eq!(body.locals().len(), 2); assert_eq!(body.blocks.len(), 1); let block = &body.blocks[0]; assert_eq!(block.statements.len(), 1); @@ -62,7 +62,7 @@ fn test_stable_mir(_tcx: TyCtxt<'_>) -> ControlFlow<()> { let foo_bar = get_item(&items, (DefKind::Fn, "foo_bar")).unwrap(); let body = foo_bar.body(); - assert_eq!(body.locals.len(), 5); + assert_eq!(body.locals().len(), 5); assert_eq!(body.blocks.len(), 4); let block = &body.blocks[0]; match &block.terminator.kind { @@ -72,29 +72,29 @@ fn test_stable_mir(_tcx: TyCtxt<'_>) -> ControlFlow<()> { let types = get_item(&items, (DefKind::Fn, "types")).unwrap(); let body = types.body(); - assert_eq!(body.locals.len(), 6); + assert_eq!(body.locals().len(), 6); assert_matches!( - body.locals[0].ty.kind(), + body.locals()[0].ty.kind(), stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Bool) ); assert_matches!( - body.locals[1].ty.kind(), + body.locals()[1].ty.kind(), stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Bool) ); assert_matches!( - body.locals[2].ty.kind(), + body.locals()[2].ty.kind(), stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Char) ); assert_matches!( - body.locals[3].ty.kind(), + body.locals()[3].ty.kind(), stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Int(stable_mir::ty::IntTy::I32)) ); assert_matches!( - body.locals[4].ty.kind(), + body.locals()[4].ty.kind(), stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Uint(stable_mir::ty::UintTy::U64)) ); assert_matches!( - body.locals[5].ty.kind(), + body.locals()[5].ty.kind(), stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Float( stable_mir::ty::FloatTy::F64 )) @@ -123,10 +123,10 @@ fn test_stable_mir(_tcx: TyCtxt<'_>) -> ControlFlow<()> { for block in instance.body().blocks { match &block.terminator.kind { stable_mir::mir::TerminatorKind::Call { func, .. } => { - let TyKind::RigidTy(ty) = func.ty(&body.locals).kind() else { unreachable!() }; + let TyKind::RigidTy(ty) = func.ty(&body.locals()).kind() else { unreachable!() }; let RigidTy::FnDef(def, args) = ty else { unreachable!() }; let next_func = Instance::resolve(def, &args).unwrap(); - match next_func.body().locals[1].ty.kind() { + match next_func.body().locals()[1].ty.kind() { TyKind::RigidTy(RigidTy::Uint(_)) | TyKind::RigidTy(RigidTy::Tuple(_)) => {} other => panic!("{other:?}"), } @@ -140,6 +140,29 @@ fn test_stable_mir(_tcx: TyCtxt<'_>) -> ControlFlow<()> { // Ensure we don't panic trying to get the body of a constant. foo_const.body(); + let locals_fn = get_item(&items, (DefKind::Fn, "locals")).unwrap(); + let body = locals_fn.body(); + assert_eq!(body.locals().len(), 4); + assert_matches!( + body.ret_local().ty.kind(), + stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Char) + ); + assert_eq!(body.arg_locals().len(), 2); + assert_matches!( + body.arg_locals()[0].ty.kind(), + stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Int(stable_mir::ty::IntTy::I32)) + ); + assert_matches!( + body.arg_locals()[1].ty.kind(), + stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Uint(stable_mir::ty::UintTy::U64)) + ); + assert_eq!(body.inner_locals().len(), 1); + // If conditions have an extra inner local to hold their results + assert_matches!( + body.inner_locals()[0].ty.kind(), + stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Bool) + ); + ControlFlow::Continue(()) } @@ -211,6 +234,14 @@ fn generate_input(path: &str) -> std::io::Result<()> { pub fn assert(x: i32) -> i32 {{ x + 1 + }} + + pub fn locals(a: i32, _: u64) -> char {{ + if a > 5 {{ + 'a' + }} else {{ + 'b' + }} }}"# )?; Ok(()) diff --git a/tests/ui-fulldeps/stable-mir/smir_internal.rs b/tests/ui-fulldeps/stable-mir/smir_internal.rs index 5ad05559cb4bb..b0596b1882383 100644 --- a/tests/ui-fulldeps/stable-mir/smir_internal.rs +++ b/tests/ui-fulldeps/stable-mir/smir_internal.rs @@ -29,7 +29,7 @@ const CRATE_NAME: &str = "input"; fn test_translation(_tcx: TyCtxt<'_>) -> ControlFlow<()> { let main_fn = stable_mir::entry_fn().unwrap(); let body = main_fn.body(); - let orig_ty = body.locals[0].ty; + let orig_ty = body.locals()[0].ty; let rustc_ty = rustc_internal::internal(&orig_ty); assert!(rustc_ty.is_unit()); ControlFlow::Continue(())