Skip to content

feat: Compute closure captures #14470

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 10, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 78 additions & 23 deletions crates/hir-def/src/body/lower.rs
Original file line number Diff line number Diff line change
@@ -28,9 +28,9 @@ use crate::{
data::adt::StructKind,
db::DefDatabase,
hir::{
dummy_expr_id, Array, Binding, BindingAnnotation, BindingId, ClosureKind, Expr, ExprId,
Label, LabelId, Literal, MatchArm, Movability, Pat, PatId, RecordFieldPat, RecordLitField,
Statement,
dummy_expr_id, Array, Binding, BindingAnnotation, BindingId, CaptureBy, ClosureKind, Expr,
ExprId, Label, LabelId, Literal, MatchArm, Movability, Pat, PatId, RecordFieldPat,
RecordLitField, Statement,
},
item_scope::BuiltinShadowMode,
lang_item::LangItem,
@@ -67,6 +67,7 @@ pub(super) fn lower(
is_lowering_assignee_expr: false,
is_lowering_generator: false,
label_ribs: Vec::new(),
current_binding_owner: None,
}
.collect(params, body, is_async_fn)
}
@@ -92,6 +93,7 @@ struct ExprCollector<'a> {

// resolution
label_ribs: Vec<LabelRib>,
current_binding_owner: Option<ExprId>,
}

#[derive(Clone, Debug)]
@@ -261,11 +263,16 @@ impl ExprCollector<'_> {
}
Some(ast::BlockModifier::Const(_)) => {
self.with_label_rib(RibKind::Constant, |this| {
this.collect_block_(e, |id, statements, tail| Expr::Const {
id,
statements,
tail,
})
this.collect_as_a_binding_owner_bad(
|this| {
this.collect_block_(e, |id, statements, tail| Expr::Const {
id,
statements,
tail,
})
},
syntax_ptr,
)
})
}
None => self.collect_block(e),
@@ -461,6 +468,8 @@ impl ExprCollector<'_> {
}
}
ast::Expr::ClosureExpr(e) => self.with_label_rib(RibKind::Closure, |this| {
let (result_expr_id, prev_binding_owner) =
this.initialize_binding_owner(syntax_ptr);
let mut args = Vec::new();
let mut arg_types = Vec::new();
if let Some(pl) = e.param_list() {
@@ -494,17 +503,19 @@ impl ExprCollector<'_> {
ClosureKind::Closure
};
this.is_lowering_generator = prev_is_lowering_generator;

this.alloc_expr(
Expr::Closure {
args: args.into(),
arg_types: arg_types.into(),
ret_type,
body,
closure_kind,
},
syntax_ptr,
)
let capture_by =
if e.move_token().is_some() { CaptureBy::Value } else { CaptureBy::Ref };
this.is_lowering_generator = prev_is_lowering_generator;
this.current_binding_owner = prev_binding_owner;
this.body.exprs[result_expr_id] = Expr::Closure {
args: args.into(),
arg_types: arg_types.into(),
ret_type,
body,
closure_kind,
capture_by,
};
result_expr_id
}),
ast::Expr::BinExpr(e) => {
let op = e.op_kind();
@@ -545,7 +556,15 @@ impl ExprCollector<'_> {
ArrayExprKind::Repeat { initializer, repeat } => {
let initializer = self.collect_expr_opt(initializer);
let repeat = self.with_label_rib(RibKind::Constant, |this| {
this.collect_expr_opt(repeat)
if let Some(repeat) = repeat {
let syntax_ptr = AstPtr::new(&repeat);
this.collect_as_a_binding_owner_bad(
|this| this.collect_expr(repeat),
syntax_ptr,
)
} else {
this.missing_expr()
}
});
self.alloc_expr(
Expr::Array(Array::Repeat { initializer, repeat }),
@@ -592,6 +611,32 @@ impl ExprCollector<'_> {
})
}

fn initialize_binding_owner(
&mut self,
syntax_ptr: AstPtr<ast::Expr>,
) -> (ExprId, Option<ExprId>) {
let result_expr_id = self.alloc_expr(Expr::Missing, syntax_ptr);
let prev_binding_owner = self.current_binding_owner.take();
self.current_binding_owner = Some(result_expr_id);
(result_expr_id, prev_binding_owner)
}

/// FIXME: This function is bad. It will produce a dangling `Missing` expr which wastes memory. Currently
/// it is used only for const blocks and repeat expressions, which are also hacky and ideally should have
/// their own body. Don't add more usage for this function so that we can remove this function after
/// separating those bodies.
fn collect_as_a_binding_owner_bad(
&mut self,
job: impl FnOnce(&mut ExprCollector<'_>) -> ExprId,
syntax_ptr: AstPtr<ast::Expr>,
) -> ExprId {
let (id, prev_owner) = self.initialize_binding_owner(syntax_ptr);
let tmp = job(self);
self.body.exprs[id] = mem::replace(&mut self.body.exprs[tmp], Expr::Missing);
self.current_binding_owner = prev_owner;
id
}

/// Desugar `try { <stmts>; <expr> }` into `'<new_label>: { <stmts>; ::std::ops::Try::from_output(<expr>) }`,
/// `try { <stmts>; }` into `'<new_label>: { <stmts>; ::std::ops::Try::from_output(()) }`
/// and save the `<new_label>` to use it as a break target for desugaring of the `?` operator.
@@ -1112,8 +1157,13 @@ impl ExprCollector<'_> {
}
ast::Pat::ConstBlockPat(const_block_pat) => {
if let Some(block) = const_block_pat.block_expr() {
let expr_id =
self.with_label_rib(RibKind::Constant, |this| this.collect_block(block));
let expr_id = self.with_label_rib(RibKind::Constant, |this| {
let syntax_ptr = AstPtr::new(&block.clone().into());
this.collect_as_a_binding_owner_bad(
|this| this.collect_block(block),
syntax_ptr,
)
});
Pat::ConstBlock(expr_id)
} else {
Pat::Missing
@@ -1272,7 +1322,12 @@ impl ExprCollector<'_> {
}

fn alloc_binding(&mut self, name: Name, mode: BindingAnnotation) -> BindingId {
self.body.bindings.alloc(Binding { name, mode, definitions: SmallVec::new() })
self.body.bindings.alloc(Binding {
name,
mode,
definitions: SmallVec::new(),
owner: self.current_binding_owner,
})
}

fn alloc_pat(&mut self, pat: Pat, ptr: PatPtr) -> PatId {
12 changes: 10 additions & 2 deletions crates/hir-def/src/body/pretty.rs
Original file line number Diff line number Diff line change
@@ -5,7 +5,9 @@ use std::fmt::{self, Write};
use syntax::ast::HasName;

use crate::{
hir::{Array, BindingAnnotation, BindingId, ClosureKind, Literal, Movability, Statement},
hir::{
Array, BindingAnnotation, BindingId, CaptureBy, ClosureKind, Literal, Movability, Statement,
},
pretty::{print_generic_args, print_path, print_type_ref},
type_ref::TypeRef,
};
@@ -360,7 +362,7 @@ impl<'a> Printer<'a> {
self.print_expr(*index);
w!(self, "]");
}
Expr::Closure { args, arg_types, ret_type, body, closure_kind } => {
Expr::Closure { args, arg_types, ret_type, body, closure_kind, capture_by } => {
match closure_kind {
ClosureKind::Generator(Movability::Static) => {
w!(self, "static ");
@@ -370,6 +372,12 @@ impl<'a> Printer<'a> {
}
_ => (),
}
match capture_by {
CaptureBy::Value => {
w!(self, "move ");
}
CaptureBy::Ref => (),
}
w!(self, "|");
for (i, (pat, ty)) in args.iter().zip(arg_types.iter()).enumerate() {
if i != 0 {
25 changes: 25 additions & 0 deletions crates/hir-def/src/hir.rs
Original file line number Diff line number Diff line change
@@ -275,6 +275,7 @@ pub enum Expr {
ret_type: Option<Interned<TypeRef>>,
body: ExprId,
closure_kind: ClosureKind,
capture_by: CaptureBy,
},
Tuple {
exprs: Box<[ExprId]>,
@@ -292,6 +293,14 @@ pub enum ClosureKind {
Async,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CaptureBy {
/// `move |x| y + x`.
Value,
/// `move` keyword was not specified.
Ref,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Movability {
Static,
@@ -484,6 +493,22 @@ pub struct Binding {
pub name: Name,
pub mode: BindingAnnotation,
pub definitions: SmallVec<[PatId; 1]>,
/// Id of the closure/generator that owns this binding. If it is owned by the
/// top level expression, this field would be `None`.
pub owner: Option<ExprId>,
}

impl Binding {
pub fn is_upvar(&self, relative_to: ExprId) -> bool {
match self.owner {
Some(x) => {
// We assign expression ids in a way that outer closures will recieve
// a lower id
x.into_raw() < relative_to.into_raw()
}
None => true,
}
}
}

#[derive(Debug, Clone, Eq, PartialEq)]
13 changes: 11 additions & 2 deletions crates/hir-ty/src/chalk_ext.rs
Original file line number Diff line number Diff line change
@@ -12,8 +12,9 @@ use hir_def::{
use crate::{
db::HirDatabase, from_assoc_type_id, from_chalk_trait_id, from_foreign_def_id,
from_placeholder_idx, to_chalk_trait_id, utils::generics, AdtId, AliasEq, AliasTy, Binders,
CallableDefId, CallableSig, DynTy, FnPointer, ImplTraitId, Interner, Lifetime, ProjectionTy,
QuantifiedWhereClause, Substitution, TraitRef, Ty, TyBuilder, TyKind, TypeFlags, WhereClause,
CallableDefId, CallableSig, ClosureId, DynTy, FnPointer, ImplTraitId, Interner, Lifetime,
ProjectionTy, QuantifiedWhereClause, Substitution, TraitRef, Ty, TyBuilder, TyKind, TypeFlags,
WhereClause,
};

pub trait TyExt {
@@ -28,6 +29,7 @@ pub trait TyExt {
fn as_adt(&self) -> Option<(hir_def::AdtId, &Substitution)>;
fn as_builtin(&self) -> Option<BuiltinType>;
fn as_tuple(&self) -> Option<&Substitution>;
fn as_closure(&self) -> Option<ClosureId>;
fn as_fn_def(&self, db: &dyn HirDatabase) -> Option<FunctionId>;
fn as_reference(&self) -> Option<(&Ty, Lifetime, Mutability)>;
fn as_reference_or_ptr(&self) -> Option<(&Ty, Rawness, Mutability)>;
@@ -128,6 +130,13 @@ impl TyExt for Ty {
}
}

fn as_closure(&self) -> Option<ClosureId> {
match self.kind(Interner) {
TyKind::Closure(id, _) => Some(*id),
_ => None,
}
}

fn as_fn_def(&self, db: &dyn HirDatabase) -> Option<FunctionId> {
match self.callable_def(db) {
Some(CallableDefId::FunctionId(func)) => Some(func),
75 changes: 75 additions & 0 deletions crates/hir-ty/src/consteval/tests.rs
Original file line number Diff line number Diff line change
@@ -1105,6 +1105,81 @@ fn try_block() {
);
}

#[test]
fn closures() {
check_number(
r#"
//- minicore: fn, copy
const GOAL: i32 = {
let y = 5;
let c = |x| x + y;
c(2)
};
"#,
7,
);
check_number(
r#"
//- minicore: fn, copy
const GOAL: i32 = {
let y = 5;
let c = |(a, b): &(i32, i32)| *a + *b + y;
c(&(2, 3))
};
"#,
10,
);
check_number(
r#"
//- minicore: fn, copy
const GOAL: i32 = {
let mut y = 5;
let c = |x| {
y = y + x;
};
c(2);
c(3);
y
};
"#,
10,
);
check_number(
r#"
//- minicore: fn, copy
struct X(i32);
impl X {
fn mult(&mut self, n: i32) {
self.0 = self.0 * n
}
}
const GOAL: i32 = {
let x = X(1);
let c = || {
x.mult(2);
|| {
x.mult(3);
|| {
|| {
x.mult(4);
|| {
x.mult(x.0);
|| {
x.0
}
}
}
}
}
};
let r = c()()()()()();
r + x.0
};
"#,
24 * 24 * 2,
);
}

#[test]
fn or_pattern() {
check_number(
11 changes: 7 additions & 4 deletions crates/hir-ty/src/db.rs
Original file line number Diff line number Diff line change
@@ -19,9 +19,9 @@ use crate::{
consteval::ConstEvalError,
method_resolution::{InherentImpls, TraitImpls, TyFingerprint},
mir::{BorrowckResult, MirBody, MirLowerError},
Binders, CallableDefId, Const, FnDefId, GenericArg, ImplTraitId, InferenceResult, Interner,
PolyFnSig, QuantifiedWhereClause, ReturnTypeImplTraits, Substitution, TraitRef, Ty, TyDefId,
ValueTyDefId,
Binders, CallableDefId, ClosureId, Const, FnDefId, GenericArg, ImplTraitId, InferenceResult,
Interner, PolyFnSig, QuantifiedWhereClause, ReturnTypeImplTraits, Substitution, TraitRef, Ty,
TyDefId, ValueTyDefId,
};
use hir_expand::name::Name;

@@ -38,8 +38,11 @@ pub trait HirDatabase: DefDatabase + Upcast<dyn DefDatabase> {
#[salsa::cycle(crate::mir::mir_body_recover)]
fn mir_body(&self, def: DefWithBodyId) -> Result<Arc<MirBody>, MirLowerError>;

#[salsa::invoke(crate::mir::mir_body_for_closure_query)]
fn mir_body_for_closure(&self, def: ClosureId) -> Result<Arc<MirBody>, MirLowerError>;

#[salsa::invoke(crate::mir::borrowck_query)]
fn borrowck(&self, def: DefWithBodyId) -> Result<Arc<BorrowckResult>, MirLowerError>;
fn borrowck(&self, def: DefWithBodyId) -> Result<Arc<[BorrowckResult]>, MirLowerError>;

#[salsa::invoke(crate::lower::ty_query)]
#[salsa::cycle(crate::lower::ty_recover)]
Loading