Skip to content

Commit ae22050

Browse files
committed
Determine expected parameters from expected return in calls
Fixes #9560
1 parent eb2cc10 commit ae22050

File tree

3 files changed

+98
-11
lines changed

3 files changed

+98
-11
lines changed

crates/hir_ty/src/infer.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -845,8 +845,9 @@ impl Expectation {
845845
/// which still is useful, because it informs integer literals and the like.
846846
/// See the test case `test/ui/coerce-expect-unsized.rs` and #20169
847847
/// for examples of where this comes up,.
848-
fn rvalue_hint(ty: Ty) -> Self {
849-
match ty.strip_references().kind(&Interner) {
848+
fn rvalue_hint(table: &mut unify::InferenceTable, ty: Ty) -> Self {
849+
// FIXME: do struct_tail_without_normalization
850+
match table.resolve_ty_shallow(&ty).kind(&Interner) {
850851
TyKind::Slice(_) | TyKind::Str | TyKind::Dyn(_) => Expectation::RValueLikeUnsized(ty),
851852
_ => Expectation::has_type(ty),
852853
}

crates/hir_ty/src/infer/expr.rs

Lines changed: 83 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -340,11 +340,25 @@ impl<'a> InferenceContext<'a> {
340340
None => (Vec::new(), self.err_ty()),
341341
};
342342
self.register_obligations_for_call(&callee_ty);
343-
self.check_call_arguments(args, &param_tys);
343+
344+
let expected_inputs = self.expected_inputs_for_expected_output(
345+
expected,
346+
ret_ty.clone(),
347+
param_tys.clone(),
348+
);
349+
350+
self.check_call_arguments(args, &expected_inputs, &param_tys);
344351
self.normalize_associated_types_in(ret_ty)
345352
}
346353
Expr::MethodCall { receiver, args, method_name, generic_args } => self
347-
.infer_method_call(tgt_expr, *receiver, args, method_name, generic_args.as_deref()),
354+
.infer_method_call(
355+
tgt_expr,
356+
*receiver,
357+
args,
358+
method_name,
359+
generic_args.as_deref(),
360+
expected,
361+
),
348362
Expr::Match { expr, arms } => {
349363
let input_ty = self.infer_expr(*expr, &Expectation::none());
350364

@@ -575,7 +589,7 @@ impl<'a> InferenceContext<'a> {
575589
// FIXME: record type error - expected reference but found ptr,
576590
// which cannot be coerced
577591
}
578-
Expectation::rvalue_hint(Ty::clone(exp_inner))
592+
Expectation::rvalue_hint(&mut self.table, Ty::clone(exp_inner))
579593
} else {
580594
Expectation::none()
581595
};
@@ -902,6 +916,7 @@ impl<'a> InferenceContext<'a> {
902916
args: &[ExprId],
903917
method_name: &Name,
904918
generic_args: Option<&GenericArgs>,
919+
expected: &Expectation,
905920
) -> Ty {
906921
let receiver_ty = self.infer_expr(receiver, &Expectation::none());
907922
let canonicalized_receiver = self.canonicalize(receiver_ty.clone());
@@ -935,7 +950,7 @@ impl<'a> InferenceContext<'a> {
935950
};
936951
let method_ty = method_ty.substitute(&Interner, &substs);
937952
self.register_obligations_for_call(&method_ty);
938-
let (expected_receiver_ty, param_tys, ret_ty) = match method_ty.callable_sig(self.db) {
953+
let (formal_receiver_ty, param_tys, ret_ty) = match method_ty.callable_sig(self.db) {
939954
Some(sig) => {
940955
if !sig.params().is_empty() {
941956
(sig.params()[0].clone(), sig.params()[1..].to_vec(), sig.ret().clone())
@@ -945,28 +960,87 @@ impl<'a> InferenceContext<'a> {
945960
}
946961
None => (self.err_ty(), Vec::new(), self.err_ty()),
947962
};
948-
self.unify(&expected_receiver_ty, &receiver_ty);
963+
self.unify(&formal_receiver_ty, &receiver_ty);
964+
965+
let expected_inputs =
966+
self.expected_inputs_for_expected_output(expected, ret_ty.clone(), param_tys.clone());
949967

950-
self.check_call_arguments(args, &param_tys);
968+
self.check_call_arguments(args, &expected_inputs, &param_tys);
951969
self.normalize_associated_types_in(ret_ty)
952970
}
953971

954-
fn check_call_arguments(&mut self, args: &[ExprId], param_tys: &[Ty]) {
972+
fn expected_inputs_for_expected_output(
973+
&mut self,
974+
expected_output: &Expectation,
975+
output: Ty,
976+
inputs: Vec<Ty>,
977+
) -> Vec<Ty> {
978+
// rustc does a snapshot here and rolls back the unification, but since
979+
// we actually want to keep unbound variables in the result it then
980+
// needs to do 'fudging' to recreate them. So I'm not sure rustc's
981+
// approach is cleaner than ours, which is to create independent copies
982+
// of the variables before unifying. It might be more performant though,
983+
// so we might want to benchmark when we can actually do
984+
// snapshot/rollback.
985+
if let Some(expected_ty) = expected_output.to_option(&mut self.table) {
986+
let (expected_ret_ty, expected_params) = self.table.reinstantiate((output, inputs));
987+
if self.table.try_unify(&expected_ty, &expected_ret_ty).is_ok() {
988+
expected_params
989+
} else {
990+
Vec::new()
991+
}
992+
} else {
993+
Vec::new()
994+
}
995+
}
996+
997+
fn check_call_arguments(&mut self, args: &[ExprId], expected_inputs: &[Ty], param_tys: &[Ty]) {
955998
// Quoting https://github.com/rust-lang/rust/blob/6ef275e6c3cb1384ec78128eceeb4963ff788dca/src/librustc_typeck/check/mod.rs#L3325 --
956999
// We do this in a pretty awful way: first we type-check any arguments
9571000
// that are not closures, then we type-check the closures. This is so
9581001
// that we have more information about the types of arguments when we
9591002
// type-check the functions. This isn't really the right way to do this.
9601003
for &check_closures in &[false, true] {
9611004
let param_iter = param_tys.iter().cloned().chain(repeat(self.err_ty()));
962-
for (&arg, param_ty) in args.iter().zip(param_iter) {
1005+
let expected_iter = expected_inputs
1006+
.iter()
1007+
.cloned()
1008+
.chain(param_iter.clone().skip(expected_inputs.len()));
1009+
for ((&arg, param_ty), expected_ty) in args.iter().zip(param_iter).zip(expected_iter) {
9631010
let is_closure = matches!(&self.body[arg], Expr::Lambda { .. });
9641011
if is_closure != check_closures {
9651012
continue;
9661013
}
9671014

1015+
// the difference between param_ty and expected here is that
1016+
// expected is the parameter when the expected *return* type is
1017+
// taken into account. So in `let _: &[i32] = identity(&[1, 2])`
1018+
// the expected type is already `&[i32]`, whereas param_ty is
1019+
// still an unbound type variable. We don't always want to force
1020+
// the parameter to coerce to the expected type (for example in
1021+
// `coerce_unsize_expected_type_4`).
9681022
let param_ty = self.normalize_associated_types_in(param_ty);
969-
self.infer_expr_coerce(arg, &Expectation::has_type(param_ty.clone()));
1023+
let expected = Expectation::rvalue_hint(&mut self.table, expected_ty);
1024+
// infer with the expected type we have...
1025+
let ty = self.infer_expr_inner(arg, &expected);
1026+
1027+
// then coerce to either the expected type or just the formal parameter type
1028+
let coercion_target = if let Some(ty) = expected.only_has_type(&mut self.table) {
1029+
// if we are coercing to the expectation, unify with the
1030+
// formal parameter type to connect everything
1031+
self.unify(&ty, &param_ty);
1032+
ty
1033+
} else {
1034+
param_ty
1035+
};
1036+
if !coercion_target.is_unknown() {
1037+
if self.coerce(Some(arg), &ty, &coercion_target).is_err() {
1038+
self.result.type_mismatches.insert(
1039+
arg.into(),
1040+
TypeMismatch { expected: coercion_target, actual: ty.clone() },
1041+
);
1042+
}
1043+
}
9701044
}
9711045
}
9721046
}

crates/hir_ty/src/infer/unify.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,18 @@ impl<'a> InferenceTable<'a> {
302302
self.resolve_with_fallback(t, |_, _, d, _| d)
303303
}
304304

305+
/// This makes a copy of the given `t` where all unbound inference variables
306+
/// have been replaced by fresh ones. This is useful for 'speculatively'
307+
/// unifying the result with something, without affecting the original types.
308+
pub(crate) fn reinstantiate<T>(&mut self, t: T) -> T::Result
309+
where
310+
T: HasInterner<Interner = Interner> + Fold<Interner>,
311+
T::Result: HasInterner<Interner = Interner> + Fold<Interner, Result = T::Result>,
312+
{
313+
let canonicalized = self.canonicalize(t);
314+
self.var_unification_table.instantiate_canonical(&Interner, canonicalized.value)
315+
}
316+
305317
/// Unify two types and register new trait goals that arise from that.
306318
pub(crate) fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> bool {
307319
let result = if let Ok(r) = self.try_unify(ty1, ty2) {

0 commit comments

Comments
 (0)