Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions crates/hir-ty/src/infer/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,7 @@ impl InferenceContext<'_> {
}
}

fn closure_kind(&self) -> FnTrait {
fn closure_kind_from_capture(&self) -> FnTrait {
let mut r = FnTrait::Fn;
for it in &self.current_captures {
r = cmp::min(
Expand All @@ -941,7 +941,7 @@ impl InferenceContext<'_> {
r
}

fn analyze_closure(&mut self, closure: ClosureId) -> FnTrait {
fn analyze_closure(&mut self, closure: ClosureId, predicate: Option<FnTrait>) -> FnTrait {
let InternedClosure(_, root) = self.db.lookup_intern_closure(closure.into());
self.current_closure = Some(closure);
let Expr::Closure { body, capture_by, .. } = &self.body[root] else {
Expand All @@ -959,7 +959,14 @@ impl InferenceContext<'_> {
}
self.restrict_precision_for_unsafe();
// closure_kind should be done before adjust_for_move_closure
let closure_kind = self.closure_kind();
let closure_kind = {
let from_capture = self.closure_kind_from_capture();
// if predicate.unwrap_or(FnTrait::Fn) < from_capture {
// // Diagnostics here, like compiler does in
// // https://github.com/rust-lang/rust/blob/11f32b73e0dc9287e305b5b9980d24aecdc8c17f/compiler/rustc_hir_typeck/src/upvar.rs#L264
// }
predicate.unwrap_or(from_capture)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should use unwrap_or_else(|| self.closure_kind_from_capture()) for lazyiness

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was planning to implement a diagnostic here as the comments above says. I'll try it and will convert into unwrap_or_else if it is not possible

};
match capture_by {
CaptureBy::Value => self.adjust_for_move_closure(),
CaptureBy::Ref => (),
Expand All @@ -975,7 +982,9 @@ impl InferenceContext<'_> {
let deferred_closures = self.sort_closures();
for (closure, exprs) in deferred_closures.into_iter().rev() {
self.current_captures = vec![];
let kind = self.analyze_closure(closure);

let predicate = self.table.get_closure_fn_trait_predicate(closure);
let kind = self.analyze_closure(closure, predicate);

for (derefed_callee, callee_ty, params, expr) in exprs {
if let &Expr::Call { callee, .. } = &self.body[expr] {
Expand Down
69 changes: 62 additions & 7 deletions crates/hir-ty/src/infer/unify.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
//! Unification and canonicalization logic.

use std::{fmt, iter, mem};
use std::{cmp, fmt, iter, mem};

use chalk_ir::{
cast::Cast, fold::TypeFoldable, interner::HasInterner, zip::Zip, CanonicalVarKind, FloatTy,
IntTy, TyVariableKind, UniverseIndex,
IntTy, TyVariableKind, UniverseIndex, WhereClause,
};
use chalk_solve::infer::ParameterEnaVariableExt;
use either::Either;
Expand All @@ -14,11 +14,12 @@ use triomphe::Arc;

use super::{InferOk, InferResult, InferenceContext, TypeError};
use crate::{
consteval::unknown_const, db::HirDatabase, fold_tys_and_consts, static_lifetime,
to_chalk_trait_id, traits::FnTrait, AliasEq, AliasTy, BoundVar, Canonical, Const, ConstValue,
DebruijnIndex, GenericArg, GenericArgData, Goal, Guidance, InEnvironment, InferenceVar,
Interner, Lifetime, ParamKind, ProjectionTy, ProjectionTyExt, Scalar, Solution, Substitution,
TraitEnvironment, Ty, TyBuilder, TyExt, TyKind, VariableKind,
chalk_db::TraitId, consteval::unknown_const, db::HirDatabase, fold_tys_and_consts,
static_lifetime, to_chalk_trait_id, traits::FnTrait, AliasEq, AliasTy, BoundVar, Canonical,
ClosureId, Const, ConstValue, DebruijnIndex, DomainGoal, GenericArg, GenericArgData, Goal,
GoalData, Guidance, InEnvironment, InferenceVar, Interner, Lifetime, ParamKind, ProjectionTy,
ProjectionTyExt, Scalar, Solution, Substitution, TraitEnvironment, Ty, TyBuilder, TyExt,
TyKind, VariableKind,
};

impl InferenceContext<'_> {
Expand Down Expand Up @@ -181,6 +182,8 @@ pub(crate) struct InferenceTable<'a> {
/// Double buffer used in [`Self::resolve_obligations_as_possible`] to cut down on
/// temporary allocations.
resolve_obligations_buffer: Vec<Canonicalized<InEnvironment<Goal>>>,
fn_trait_predicates: Vec<(Ty, FnTrait)>,
cached_fn_trait_ids: Option<CachedFnTraitIds>,
Comment on lines +185 to +186
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not too happy with this, can't we just iterate through the pending obligations when necessary? (as is what rustc does)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The chalk resolves obligations like FnOnce immediately, so they are not remained in the pending obligations. The most problematic thing is that it infers the closures with FnOnce predicate as implementing FnMut and Fn

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I should look into it more, but I'm not sure that it is intended chalk functionality or not

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Anyway, I'm gonna take a deep look into the chalk closure inferences. I might fix this from upstream or maybe RA is asking chalk without some details

Copy link
Member Author

@ShoyuVanilla ShoyuVanilla Feb 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Veykril I had a closer into RA and chalk, and I think that we couldn't use pending obligations like rustc does.

When the chalk has to solve a goal that whether some closure implements some FnTrait;

1. It asks RustIrDatabase for the kind of our closure and checkes if the kind of closure is compatible with the given FnTrait
2. If so, it pushes the clause for the "FACT" that our closure implements that trait.

So, solving whether the given closure implements FnTrait is heavily depends on RustIrDatabase and I think that's reasonable for it to do so because what chalk wants to be(though it is not going to be actively maintained due to new trait solver in rustc) is an abstract trait solver, not a very definite rustc inference maker.

But since RA's ChalkContext, which is our RustIrDatabase implementation, always returns Fn as the kind of a closure. (I think that this is sort of a heuristic to avoid having other inferences and coercions delayed or unresolved.)
Thus, the obligations like closure{n}[...] implements FnX are always solved into Unique solutions and not remain as "pending" because they are immediately solved.

Giving the "correct" kind of a closure doesn't work for this because;

  1. We need to use "correct" pending obligations for closure kind inferencing
  2. But the "correct" pending obligations are what we are going to do with "correct" kind of a closure, a cyclic dependency!

I've tried just giving FnOnce as a closure_kind but aside from that it makes lots of inferences in our test cases unresolved - if the later were to worked correctly I would have tried fixing those inferences -, it still does not "pend" obligations.
The obligations for closure implementation FnMut and Fn are solved as None, and they are not pended as well.

So, I have concluded that we would have to modify lots of RA inferencing code to do this as rustc does, and this was why I wrote somewhat sad codes for similar thing 😢 (When I wrote it, I hadn't debugged the chalk much, but tried using pending obligations and it was always empty; so I had similar conclusion without much evidence then 😅 )

}

pub(crate) struct InferenceTableSnapshot {
Expand All @@ -189,15 +192,34 @@ pub(crate) struct InferenceTableSnapshot {
type_variable_table_snapshot: Vec<TypeVariableFlags>,
}

#[derive(Clone)]
struct CachedFnTraitIds {
fn_trait: TraitId,
fn_mut_trait: TraitId,
fn_once_trait: TraitId,
}

impl CachedFnTraitIds {
fn new(db: &dyn HirDatabase, trait_env: &Arc<TraitEnvironment>) -> Option<Self> {
let fn_trait = FnTrait::Fn.get_id(db, trait_env.krate).map(to_chalk_trait_id)?;
let fn_mut_trait = FnTrait::FnMut.get_id(db, trait_env.krate).map(to_chalk_trait_id)?;
let fn_once_trait = FnTrait::FnOnce.get_id(db, trait_env.krate).map(to_chalk_trait_id)?;
Some(Self { fn_trait, fn_mut_trait, fn_once_trait })
}
}

impl<'a> InferenceTable<'a> {
pub(crate) fn new(db: &'a dyn HirDatabase, trait_env: Arc<TraitEnvironment>) -> Self {
let cached_fn_trait_ids = CachedFnTraitIds::new(db, &trait_env);
InferenceTable {
db,
trait_env,
var_unification_table: ChalkInferenceTable::new(),
type_variable_table: Vec::new(),
pending_obligations: Vec::new(),
resolve_obligations_buffer: Vec::new(),
fn_trait_predicates: Vec::new(),
cached_fn_trait_ids,
}
}

Expand Down Expand Up @@ -547,6 +569,22 @@ impl<'a> InferenceTable<'a> {
}

fn register_obligation_in_env(&mut self, goal: InEnvironment<Goal>) {
if let Some(fn_trait_ids) = &self.cached_fn_trait_ids {
if let GoalData::DomainGoal(DomainGoal::Holds(WhereClause::Implemented(trait_ref))) =
goal.goal.data(Interner)
{
if let Some(ty) = trait_ref.substitution.type_parameters(Interner).next() {
if trait_ref.trait_id == fn_trait_ids.fn_trait {
self.fn_trait_predicates.push((ty, FnTrait::Fn));
} else if trait_ref.trait_id == fn_trait_ids.fn_mut_trait {
self.fn_trait_predicates.push((ty, FnTrait::FnMut));
} else if trait_ref.trait_id == fn_trait_ids.fn_once_trait {
self.fn_trait_predicates.push((ty, FnTrait::FnOnce));
}
}
}
}

let canonicalized = self.canonicalize(goal);
let solution = self.try_resolve_obligation(&canonicalized);
if matches!(solution, Some(Solution::Ambig(_))) {
Expand Down Expand Up @@ -838,6 +876,23 @@ impl<'a> InferenceTable<'a> {
_ => c,
}
}

pub(super) fn get_closure_fn_trait_predicate(
&mut self,
closure_id: ClosureId,
) -> Option<FnTrait> {
let predicates = mem::take(&mut self.fn_trait_predicates);
let res = predicates.iter().filter_map(|(ty, fn_trait)| {
if matches!(self.resolve_completely(ty.clone()).kind(Interner), TyKind::Closure(c, ..) if *c == closure_id) {
Some(*fn_trait)
} else {
None
}
}).fold(None, |acc, x| Some(cmp::max(acc.unwrap_or(FnTrait::FnOnce), x)));
self.fn_trait_predicates = predicates;

res
}
}

impl fmt::Debug for InferenceTable<'_> {
Expand Down
8 changes: 4 additions & 4 deletions crates/hir-ty/src/tests/patterns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -702,25 +702,25 @@ fn test() {
51..58 'loop {}': !
56..58 '{}': ()
72..171 '{ ... x); }': ()
78..81 'foo': fn foo<&(i32, &str), i32, impl Fn(&(i32, &str)) -> i32>(&(i32, &str), impl Fn(&(i32, &str)) -> i32) -> i32
78..81 'foo': fn foo<&(i32, &str), i32, impl FnOnce(&(i32, &str)) -> i32>(&(i32, &str), impl FnOnce(&(i32, &str)) -> i32) -> i32
78..105 'foo(&(...y)| x)': i32
82..91 '&(1, "a")': &(i32, &str)
83..91 '(1, "a")': (i32, &str)
84..85 '1': i32
87..90 '"a"': &str
93..104 '|&(x, y)| x': impl Fn(&(i32, &str)) -> i32
93..104 '|&(x, y)| x': impl FnOnce(&(i32, &str)) -> i32
94..101 '&(x, y)': &(i32, &str)
95..101 '(x, y)': (i32, &str)
96..97 'x': i32
99..100 'y': &str
103..104 'x': i32
142..145 'foo': fn foo<&(i32, &str), &i32, impl Fn(&(i32, &str)) -> &i32>(&(i32, &str), impl Fn(&(i32, &str)) -> &i32) -> &i32
142..145 'foo': fn foo<&(i32, &str), &i32, impl FnOnce(&(i32, &str)) -> &i32>(&(i32, &str), impl FnOnce(&(i32, &str)) -> &i32) -> &i32
142..168 'foo(&(...y)| x)': &i32
146..155 '&(1, "a")': &(i32, &str)
147..155 '(1, "a")': (i32, &str)
148..149 '1': i32
151..154 '"a"': &str
157..167 '|(x, y)| x': impl Fn(&(i32, &str)) -> &i32
157..167 '|(x, y)| x': impl FnOnce(&(i32, &str)) -> &i32
158..164 '(x, y)': (i32, &str)
159..160 'x': &i32
162..163 'y': &&str
Expand Down
2 changes: 1 addition & 1 deletion crates/hir-ty/src/tests/regression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ fn main() {
123..126 'S()': S<i32>
132..133 's': S<i32>
132..144 's.g(|_x| {})': ()
136..143 '|_x| {}': impl Fn(&i32)
136..143 '|_x| {}': impl FnOnce(&i32)
137..139 '_x': &i32
141..143 '{}': ()
150..151 's': S<i32>
Expand Down
27 changes: 25 additions & 2 deletions crates/hir-ty/src/tests/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2190,9 +2190,9 @@ fn main() {
149..151 'Ok': extern "rust-call" Ok<(), ()>(()) -> Result<(), ()>
149..155 'Ok(())': Result<(), ()>
152..154 '()': ()
167..171 'test': fn test<(), (), impl Fn() -> impl Future<Output = Result<(), ()>>, impl Future<Output = Result<(), ()>>>(impl Fn() -> impl Future<Output = Result<(), ()>>)
167..171 'test': fn test<(), (), impl FnMut() -> impl Future<Output = Result<(), ()>>, impl Future<Output = Result<(), ()>>>(impl FnMut() -> impl Future<Output = Result<(), ()>>)
167..228 'test(|... })': ()
172..227 '|| asy... }': impl Fn() -> impl Future<Output = Result<(), ()>>
172..227 '|| asy... }': impl FnMut() -> impl Future<Output = Result<(), ()>>
175..227 'async ... }': impl Future<Output = Result<(), ()>>
191..205 'return Err(())': !
198..201 'Err': extern "rust-call" Err<(), ()>(()) -> Result<(), ()>
Expand Down Expand Up @@ -2743,6 +2743,29 @@ impl B for Astruct {}
)
}

#[test]
fn closures_kinds_with_predicates() {
check_types(
r#"
//- minicore: fn
struct A<F: FnOnce()>(F);
struct B<'a, F: FnMut()>(&'a F);

fn f() {
let c1 = || {};
//^^ impl Fn()
let a1 = A(|| {});
let c2 = a1.0;
//^^ impl FnOnce()
let c3 = || {};
//^^ impl FnMut()
let a2 = A(c3);
let b1 = B(&a2.0);
}
"#,
)
}

#[test]
fn capture_kinds_simple() {
check_types(
Expand Down
34 changes: 17 additions & 17 deletions crates/hir-ty/src/tests/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1333,9 +1333,9 @@ fn foo<const C: u8, T>() -> (impl FnOnce(&str, T), impl Trait<u8>) {
}
"#,
expect![[r#"
134..165 '{ ...(C)) }': (impl Fn(&str, T), Bar<u8>)
140..163 '(|inpu...ar(C))': (impl Fn(&str, T), Bar<u8>)
141..154 '|input, t| {}': impl Fn(&str, T)
134..165 '{ ...(C)) }': (impl FnOnce(&str, T), Bar<u8>)
140..163 '(|inpu...ar(C))': (impl FnOnce(&str, T), Bar<u8>)
141..154 '|input, t| {}': impl FnOnce(&str, T)
142..147 'input': &str
149..150 't': T
152..154 '{}': ()
Expand Down Expand Up @@ -1963,20 +1963,20 @@ fn test() {
163..167 '1u32': u32
174..175 'x': Option<u32>
174..190 'x.map(...v + 1)': Option<u32>
180..189 '|v| v + 1': impl Fn(u32) -> u32
180..189 '|v| v + 1': impl FnOnce(u32) -> u32
181..182 'v': u32
184..185 'v': u32
184..189 'v + 1': u32
188..189 '1': u32
196..197 'x': Option<u32>
196..212 'x.map(... 1u64)': Option<u64>
202..211 '|_v| 1u64': impl Fn(u32) -> u64
202..211 '|_v| 1u64': impl FnOnce(u32) -> u64
203..205 '_v': u32
207..211 '1u64': u64
222..223 'y': Option<i64>
239..240 'x': Option<u32>
239..252 'x.map(|_v| 1)': Option<i64>
245..251 '|_v| 1': impl Fn(u32) -> i64
245..251 '|_v| 1': impl FnOnce(u32) -> i64
246..248 '_v': u32
250..251 '1': i64
"#]],
Expand Down Expand Up @@ -2062,17 +2062,17 @@ fn test() {
312..314 '{}': ()
330..489 '{ ... S); }': ()
340..342 'x1': u64
345..349 'foo1': fn foo1<S, u64, impl Fn(S) -> u64>(S, impl Fn(S) -> u64) -> u64
345..349 'foo1': fn foo1<S, u64, impl FnOnce(S) -> u64>(S, impl FnOnce(S) -> u64) -> u64
345..368 'foo1(S...hod())': u64
350..351 'S': S
353..367 '|s| s.method()': impl Fn(S) -> u64
353..367 '|s| s.method()': impl FnOnce(S) -> u64
354..355 's': S
357..358 's': S
357..367 's.method()': u64
378..380 'x2': u64
383..387 'foo2': fn foo2<S, u64, impl Fn(S) -> u64>(impl Fn(S) -> u64, S) -> u64
383..387 'foo2': fn foo2<S, u64, impl FnOnce(S) -> u64>(impl FnOnce(S) -> u64, S) -> u64
383..406 'foo2(|...(), S)': u64
388..402 '|s| s.method()': impl Fn(S) -> u64
388..402 '|s| s.method()': impl FnOnce(S) -> u64
389..390 's': S
392..393 's': S
392..402 's.method()': u64
Expand All @@ -2081,14 +2081,14 @@ fn test() {
421..422 'S': S
421..446 'S.foo1...hod())': u64
428..429 'S': S
431..445 '|s| s.method()': impl Fn(S) -> u64
431..445 '|s| s.method()': impl FnOnce(S) -> u64
432..433 's': S
435..436 's': S
435..445 's.method()': u64
456..458 'x4': u64
461..462 'S': S
461..486 'S.foo2...(), S)': u64
468..482 '|s| s.method()': impl Fn(S) -> u64
468..482 '|s| s.method()': impl FnOnce(S) -> u64
469..470 's': S
472..473 's': S
472..482 's.method()': u64
Expand Down Expand Up @@ -2562,9 +2562,9 @@ fn main() {
72..74 '_v': F
117..120 '{ }': ()
132..163 '{ ... }); }': ()
138..148 'f::<(), _>': fn f<(), impl Fn(&())>(impl Fn(&()))
138..148 'f::<(), _>': fn f<(), impl FnOnce(&())>(impl FnOnce(&()))
138..160 'f::<()... z; })': ()
149..159 '|z| { z; }': impl Fn(&())
149..159 '|z| { z; }': impl FnOnce(&())
150..151 'z': &()
153..159 '{ z; }': ()
155..156 'z': &()
Expand Down Expand Up @@ -2749,9 +2749,9 @@ fn main() {
983..998 'Vec::<i32>::new': fn new<i32>() -> Vec<i32>
983..1000 'Vec::<...:new()': Vec<i32>
983..1012 'Vec::<...iter()': IntoIter<i32>
983..1075 'Vec::<...one })': FilterMap<IntoIter<i32>, impl Fn(i32) -> Option<u32>>
983..1075 'Vec::<...one })': FilterMap<IntoIter<i32>, impl FnMut(i32) -> Option<u32>>
983..1101 'Vec::<... y; })': ()
1029..1074 '|x| if...None }': impl Fn(i32) -> Option<u32>
1029..1074 '|x| if...None }': impl FnMut(i32) -> Option<u32>
1030..1031 'x': i32
1033..1074 'if x >...None }': Option<u32>
1036..1037 'x': i32
Expand All @@ -2764,7 +2764,7 @@ fn main() {
1049..1057 'x as u32': u32
1066..1074 '{ None }': Option<u32>
1068..1072 'None': Option<u32>
1090..1100 '|y| { y; }': impl Fn(u32)
1090..1100 '|y| { y; }': impl FnMut(u32)
1091..1092 'y': u32
1094..1100 '{ y; }': ()
1096..1097 'y': u32
Expand Down
14 changes: 7 additions & 7 deletions crates/ide/src/hover/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,9 +353,9 @@ fn main() {
expect![[r#"
```rust
{closure#0} // size = 8, align = 8, niches = 1
impl FnOnce() -> S2
impl Fn() -> S2
```
Coerced to: &impl FnOnce() -> S2
Coerced to: &impl Fn() -> S2

## Captures
* `x` by move"#]],
Expand Down Expand Up @@ -401,17 +401,17 @@ fn main() {
},
},
HoverGotoTypeData {
mod_path: "core::ops::function::FnOnce",
mod_path: "core::ops::function::Fn",
nav: NavigationTarget {
file_id: FileId(
1,
),
full_range: 632..867,
focus_range: 693..699,
name: "FnOnce",
full_range: 254..425,
focus_range: 310..312,
name: "Fn",
kind: Trait,
container_name: "function",
description: "pub trait FnOnce<Args>\nwhere\n Args: Tuple,",
description: "pub trait Fn<Args>\nwhere\n Self: FnMut<Args>,\n Args: Tuple,",
},
},
],
Expand Down