Skip to content

Commit dd46d07

Browse files
committed
Support overloaded deref MIR lowering
1 parent 9564773 commit dd46d07

File tree

17 files changed

+389
-60
lines changed

17 files changed

+389
-60
lines changed

crates/hir-def/src/body/lower.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,8 +1030,9 @@ impl ExprCollector<'_> {
10301030
.collect(),
10311031
}
10321032
}
1033-
ast::Pat::LiteralPat(lit) => 'b: {
1034-
if let Some(ast_lit) = lit.literal() {
1033+
// FIXME: rustfmt removes this label if it is a block and not a loop
1034+
ast::Pat::LiteralPat(lit) => 'b: loop {
1035+
break if let Some(ast_lit) = lit.literal() {
10351036
let mut hir_lit: Literal = ast_lit.kind().into();
10361037
if lit.minus_token().is_some() {
10371038
let Some(h) = hir_lit.negate() else {
@@ -1045,8 +1046,8 @@ impl ExprCollector<'_> {
10451046
Pat::Lit(expr_id)
10461047
} else {
10471048
Pat::Missing
1048-
}
1049-
}
1049+
};
1050+
},
10501051
ast::Pat::RestPat(_) => {
10511052
// `RestPat` requires special handling and should not be mapped
10521053
// to a Pat. Here we are using `Pat::Missing` as a fallback for

crates/hir-ty/src/consteval/tests.rs

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,7 @@ fn reference_autoderef() {
166166

167167
#[test]
168168
fn overloaded_deref() {
169-
// FIXME: We should support this.
170-
check_fail(
169+
check_number(
171170
r#"
172171
//- minicore: deref_mut
173172
struct Foo;
@@ -185,9 +184,7 @@ fn overloaded_deref() {
185184
*y + *x
186185
};
187186
"#,
188-
ConstEvalError::MirLowerError(MirLowerError::NotSupported(
189-
"explicit overloaded deref".into(),
190-
)),
187+
10,
191188
);
192189
}
193190

@@ -698,7 +695,7 @@ fn pattern_matching_literal() {
698695
}
699696
const GOAL: i32 = f(-1) + f(1) + f(0) + f(-5);
700697
"#,
701-
211
698+
211,
702699
);
703700
check_number(
704701
r#"
@@ -711,7 +708,7 @@ fn pattern_matching_literal() {
711708
}
712709
const GOAL: u8 = f("foo") + f("bar");
713710
"#,
714-
11
711+
11,
715712
);
716713
}
717714

@@ -1116,6 +1113,22 @@ fn function_traits() {
11161113
"#,
11171114
15,
11181115
);
1116+
check_number(
1117+
r#"
1118+
//- minicore: coerce_unsized, fn
1119+
fn add2(x: u8) -> u8 {
1120+
x + 2
1121+
}
1122+
fn call(f: &dyn Fn(u8) -> u8, x: u8) -> u8 {
1123+
f(x)
1124+
}
1125+
fn call_mut(f: &mut dyn FnMut(u8) -> u8, x: u8) -> u8 {
1126+
f(x)
1127+
}
1128+
const GOAL: u8 = call(&add2, 3) + call_mut(&mut add2, 3);
1129+
"#,
1130+
10,
1131+
);
11191132
check_number(
11201133
r#"
11211134
//- minicore: fn

crates/hir-ty/src/infer.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ mod expr;
5757
mod pat;
5858
mod coerce;
5959
mod closure;
60+
mod mutability;
6061

6162
/// The entry point of type inference.
6263
pub(crate) fn infer_query(db: &dyn HirDatabase, def: DefWithBodyId) -> Arc<InferenceResult> {
@@ -99,6 +100,8 @@ pub(crate) fn infer_query(db: &dyn HirDatabase, def: DefWithBodyId) -> Arc<Infer
99100

100101
ctx.infer_body();
101102

103+
ctx.infer_mut_body();
104+
102105
Arc::new(ctx.resolve_all())
103106
}
104107

crates/hir-ty/src/infer/expr.rs

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -390,14 +390,28 @@ impl<'a> InferenceContext<'a> {
390390
if let Some(fn_x) = func {
391391
match fn_x {
392392
FnTrait::FnOnce => (),
393-
FnTrait::FnMut => adjustments.push(Adjustment::borrow(
394-
Mutability::Mut,
395-
derefed_callee.clone(),
396-
)),
397-
FnTrait::Fn => adjustments.push(Adjustment::borrow(
398-
Mutability::Not,
399-
derefed_callee.clone(),
400-
)),
393+
FnTrait::FnMut => {
394+
if !matches!(
395+
derefed_callee.kind(Interner),
396+
TyKind::Ref(Mutability::Mut, _, _)
397+
) {
398+
adjustments.push(Adjustment::borrow(
399+
Mutability::Mut,
400+
derefed_callee.clone(),
401+
));
402+
}
403+
}
404+
FnTrait::Fn => {
405+
if !matches!(
406+
derefed_callee.kind(Interner),
407+
TyKind::Ref(Mutability::Not, _, _)
408+
) {
409+
adjustments.push(Adjustment::borrow(
410+
Mutability::Not,
411+
derefed_callee.clone(),
412+
));
413+
}
414+
}
401415
}
402416
let trait_ = fn_x
403417
.get_id(self.db, self.trait_env.krate)
@@ -673,6 +687,23 @@ impl<'a> InferenceContext<'a> {
673687
// FIXME: Note down method resolution her
674688
match op {
675689
UnaryOp::Deref => {
690+
if let Some(deref_trait) = self
691+
.db
692+
.lang_item(self.table.trait_env.krate, LangItem::Deref)
693+
.and_then(|l| l.as_trait())
694+
{
695+
if let Some(deref_fn) =
696+
self.db.trait_data(deref_trait).method_by_name(&name![deref])
697+
{
698+
// FIXME: this is wrong in multiple ways, subst is empty, and we emit it even for builtin deref (note that
699+
// the mutability is not wrong, and will be fixed in `self.infer_mut`).
700+
self.write_method_resolution(
701+
tgt_expr,
702+
deref_fn,
703+
Substitution::empty(Interner),
704+
);
705+
}
706+
}
676707
autoderef::deref(&mut self.table, inner_ty).unwrap_or_else(|| self.err_ty())
677708
}
678709
UnaryOp::Neg => {

crates/hir-ty/src/infer/mutability.rs

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
//! Finds if an expression is an immutable context or a mutable context, which is used in selecting
2+
//! between `Deref` and `DerefMut` or `Index` and `IndexMut` or similar.
3+
4+
use chalk_ir::Mutability;
5+
use hir_def::{
6+
expr::{Array, BindingAnnotation, Expr, ExprId, PatId, Statement, UnaryOp},
7+
lang_item::LangItem,
8+
};
9+
use hir_expand::name;
10+
11+
use crate::{lower::lower_to_chalk_mutability, Adjust, AutoBorrow, OverloadedDeref};
12+
13+
use super::InferenceContext;
14+
15+
impl<'a> InferenceContext<'a> {
16+
pub(crate) fn infer_mut_body(&mut self) {
17+
self.infer_mut_expr(self.body.body_expr, Mutability::Not);
18+
}
19+
20+
fn infer_mut_expr(&mut self, tgt_expr: ExprId, mut mutability: Mutability) {
21+
let mut v = vec![];
22+
let adjustments = self.result.expr_adjustments.get_mut(&tgt_expr).unwrap_or(&mut v);
23+
for adj in adjustments.iter_mut().rev() {
24+
match &mut adj.kind {
25+
Adjust::NeverToAny | Adjust::Deref(None) | Adjust::Pointer(_) => (),
26+
Adjust::Deref(Some(d)) => *d = OverloadedDeref(Some(mutability)),
27+
Adjust::Borrow(b) => match b {
28+
AutoBorrow::Ref(m) | AutoBorrow::RawPtr(m) => mutability = *m,
29+
},
30+
}
31+
}
32+
self.infer_mut_expr_without_adjust(tgt_expr, mutability);
33+
}
34+
35+
fn infer_mut_expr_without_adjust(&mut self, tgt_expr: ExprId, mutability: Mutability) {
36+
match &self.body[tgt_expr] {
37+
Expr::Missing => (),
38+
&Expr::If { condition, then_branch, else_branch } => {
39+
self.infer_mut_expr(condition, Mutability::Not);
40+
self.infer_mut_expr(then_branch, Mutability::Not);
41+
if let Some(else_branch) = else_branch {
42+
self.infer_mut_expr(else_branch, Mutability::Not);
43+
}
44+
}
45+
Expr::Let { pat, expr } => self.infer_mut_expr(*expr, self.pat_bound_mutability(*pat)),
46+
Expr::Block { id: _, statements, tail, label: _ }
47+
| Expr::TryBlock { id: _, statements, tail }
48+
| Expr::Async { id: _, statements, tail }
49+
| Expr::Const { id: _, statements, tail }
50+
| Expr::Unsafe { id: _, statements, tail } => {
51+
for st in statements.iter() {
52+
match st {
53+
Statement::Let { pat, type_ref: _, initializer, else_branch } => {
54+
if let Some(i) = initializer {
55+
self.infer_mut_expr(*i, self.pat_bound_mutability(*pat));
56+
}
57+
if let Some(e) = else_branch {
58+
self.infer_mut_expr(*e, Mutability::Not);
59+
}
60+
}
61+
Statement::Expr { expr, has_semi: _ } => {
62+
self.infer_mut_expr(*expr, Mutability::Not);
63+
}
64+
}
65+
}
66+
if let Some(tail) = tail {
67+
self.infer_mut_expr(*tail, Mutability::Not);
68+
}
69+
}
70+
&Expr::For { iterable: c, pat: _, body, label: _ }
71+
| &Expr::While { condition: c, body, label: _ } => {
72+
self.infer_mut_expr(c, Mutability::Not);
73+
self.infer_mut_expr(body, Mutability::Not);
74+
}
75+
Expr::MethodCall { receiver: x, method_name: _, args, generic_args: _ }
76+
| Expr::Call { callee: x, args, is_assignee_expr: _ } => {
77+
self.infer_mut_not_expr_iter(args.iter().copied().chain(Some(*x)));
78+
}
79+
Expr::Match { expr, arms } => {
80+
let m = self.pat_iter_bound_mutability(arms.iter().map(|x| x.pat));
81+
self.infer_mut_expr(*expr, m);
82+
for arm in arms.iter() {
83+
self.infer_mut_expr(arm.expr, Mutability::Not);
84+
}
85+
}
86+
Expr::Yield { expr }
87+
| Expr::Yeet { expr }
88+
| Expr::Return { expr }
89+
| Expr::Break { expr, label: _ } => {
90+
if let &Some(expr) = expr {
91+
self.infer_mut_expr(expr, Mutability::Not);
92+
}
93+
}
94+
Expr::RecordLit { path: _, fields, spread, ellipsis: _, is_assignee_expr: _ } => {
95+
self.infer_mut_not_expr_iter(fields.iter().map(|x| x.expr).chain(*spread))
96+
}
97+
&Expr::Index { base, index } => {
98+
self.infer_mut_expr(base, mutability);
99+
self.infer_mut_expr(index, Mutability::Not);
100+
}
101+
Expr::UnaryOp { expr, op: UnaryOp::Deref } => {
102+
if let Some((f, _)) = self.result.method_resolutions.get_mut(&tgt_expr) {
103+
if mutability == Mutability::Mut {
104+
if let Some(deref_trait) = self
105+
.db
106+
.lang_item(self.table.trait_env.krate, LangItem::DerefMut)
107+
.and_then(|l| l.as_trait())
108+
{
109+
if let Some(deref_fn) =
110+
self.db.trait_data(deref_trait).method_by_name(&name![deref_mut])
111+
{
112+
*f = deref_fn;
113+
}
114+
}
115+
}
116+
}
117+
self.infer_mut_expr(*expr, mutability);
118+
}
119+
Expr::Field { expr, name: _ } => {
120+
self.infer_mut_expr(*expr, mutability);
121+
}
122+
Expr::UnaryOp { expr, op: _ }
123+
| Expr::Range { lhs: Some(expr), rhs: None, range_type: _ }
124+
| Expr::Range { rhs: Some(expr), lhs: None, range_type: _ }
125+
| Expr::Await { expr }
126+
| Expr::Box { expr }
127+
| Expr::Loop { body: expr, label: _ }
128+
| Expr::Cast { expr, type_ref: _ } => {
129+
self.infer_mut_expr(*expr, Mutability::Not);
130+
}
131+
Expr::Ref { expr, rawness: _, mutability } => {
132+
let mutability = lower_to_chalk_mutability(*mutability);
133+
self.infer_mut_expr(*expr, mutability);
134+
}
135+
Expr::Array(Array::Repeat { initializer: lhs, repeat: rhs })
136+
| Expr::BinaryOp { lhs, rhs, op: _ }
137+
| Expr::Range { lhs: Some(lhs), rhs: Some(rhs), range_type: _ } => {
138+
self.infer_mut_expr(*lhs, Mutability::Not);
139+
self.infer_mut_expr(*rhs, Mutability::Not);
140+
}
141+
// not implemented
142+
Expr::Closure { .. } => (),
143+
Expr::Tuple { exprs, is_assignee_expr: _ }
144+
| Expr::Array(Array::ElementList { elements: exprs, is_assignee_expr: _ }) => {
145+
self.infer_mut_not_expr_iter(exprs.iter().copied());
146+
}
147+
// These don't need any action, as they don't have sub expressions
148+
Expr::Range { lhs: None, rhs: None, range_type: _ }
149+
| Expr::Literal(_)
150+
| Expr::Path(_)
151+
| Expr::Continue { .. }
152+
| Expr::Underscore => (),
153+
}
154+
}
155+
156+
fn infer_mut_not_expr_iter(&mut self, exprs: impl Iterator<Item = ExprId>) {
157+
for expr in exprs {
158+
self.infer_mut_expr(expr, Mutability::Not);
159+
}
160+
}
161+
162+
fn pat_iter_bound_mutability(&self, mut pat: impl Iterator<Item = PatId>) -> Mutability {
163+
if pat.any(|p| self.pat_bound_mutability(p) == Mutability::Mut) {
164+
Mutability::Mut
165+
} else {
166+
Mutability::Not
167+
}
168+
}
169+
170+
/// Checks if the pat contains a `ref mut` binding. Such paths makes the context of bounded expressions
171+
/// mutable. For example in `let (ref mut x0, ref x1) = *x;` we need to use `DerefMut` for `*x` but in
172+
/// `let (ref x0, ref x1) = *x;` we should use `Deref`.
173+
fn pat_bound_mutability(&self, pat: PatId) -> Mutability {
174+
let mut r = Mutability::Not;
175+
self.body.walk_bindings_in_pat(pat, |b| {
176+
if self.body.bindings[b].mode == BindingAnnotation::RefMut {
177+
r = Mutability::Mut;
178+
}
179+
});
180+
r
181+
}
182+
}

0 commit comments

Comments
 (0)