Skip to content

Commit 1c27243

Browse files
Merge #9988
9988: fix: Refactor & improve handling of overloaded binary operators r=flodiebold a=flodiebold Fixes #9971. Also records them as method resolutions, which we could use later. Co-authored-by: Florian Diebold <[email protected]>
2 parents c8fd4fd + 424dda8 commit 1c27243

File tree

11 files changed

+317
-261
lines changed

11 files changed

+317
-261
lines changed

crates/hir_def/src/data.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,13 @@ impl TraitData {
199199
_ => None,
200200
})
201201
}
202+
203+
pub fn method_by_name(&self, name: &Name) -> Option<FunctionId> {
204+
self.items.iter().find_map(|(item_name, item)| match item {
205+
AssocItemId::FunctionId(t) if item_name == name => Some(*t),
206+
_ => None,
207+
})
208+
}
202209
}
203210

204211
#[derive(Debug, Clone, PartialEq, Eq)]

crates/hir_expand/src/name.rs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,33 @@ pub mod known {
285285
wrapping_add,
286286
wrapping_mul,
287287
wrapping_sub,
288+
// known methods of lang items
289+
add,
290+
mul,
291+
sub,
292+
div,
293+
rem,
294+
shl,
295+
shr,
296+
bitxor,
297+
bitor,
298+
bitand,
299+
add_assign,
300+
mul_assign,
301+
sub_assign,
302+
div_assign,
303+
rem_assign,
304+
shl_assign,
305+
shr_assign,
306+
bitxor_assign,
307+
bitor_assign,
308+
bitand_assign,
309+
eq,
310+
ne,
311+
ge,
312+
gt,
313+
le,
314+
lt,
288315
);
289316

290317
// self/Self cannot be used as an identifier

crates/hir_ty/src/infer.rs

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616
use std::ops::Index;
1717
use std::sync::Arc;
1818

19-
use chalk_ir::{cast::Cast, DebruijnIndex, Mutability, Safety};
19+
use chalk_ir::{cast::Cast, DebruijnIndex, Mutability, Safety, Scalar};
2020
use hir_def::{
2121
body::Body,
2222
data::{ConstData, FunctionData, StaticData},
23-
expr::{ArithOp, BinaryOp, BindingAnnotation, ExprId, PatId},
23+
expr::{BindingAnnotation, ExprId, PatId},
2424
lang_item::LangItemTarget,
2525
path::{path, Path},
2626
resolver::{HasResolver, ResolveValueResult, Resolver, TypeNs, ValueNs},
@@ -134,11 +134,17 @@ pub struct TypeMismatch {
134134
#[derive(Clone, PartialEq, Eq, Debug)]
135135
struct InternedStandardTypes {
136136
unknown: Ty,
137+
bool_: Ty,
138+
unit: Ty,
137139
}
138140

139141
impl Default for InternedStandardTypes {
140142
fn default() -> Self {
141-
InternedStandardTypes { unknown: TyKind::Error.intern(&Interner) }
143+
InternedStandardTypes {
144+
unknown: TyKind::Error.intern(&Interner),
145+
bool_: TyKind::Scalar(Scalar::Bool).intern(&Interner),
146+
unit: TyKind::Tuple(0, Substitution::empty(&Interner)).intern(&Interner),
147+
}
142148
}
143149
}
144150
/// Represents coercing a value to a different type of value.
@@ -751,28 +757,6 @@ impl<'a> InferenceContext<'a> {
751757
self.db.trait_data(trait_).associated_type_by_name(&name![Output])
752758
}
753759

754-
fn resolve_binary_op_output(&self, bop: &BinaryOp) -> Option<TypeAliasId> {
755-
let lang_item = match bop {
756-
BinaryOp::ArithOp(aop) => match aop {
757-
ArithOp::Add => "add",
758-
ArithOp::Sub => "sub",
759-
ArithOp::Mul => "mul",
760-
ArithOp::Div => "div",
761-
ArithOp::Shl => "shl",
762-
ArithOp::Shr => "shr",
763-
ArithOp::Rem => "rem",
764-
ArithOp::BitXor => "bitxor",
765-
ArithOp::BitOr => "bitor",
766-
ArithOp::BitAnd => "bitand",
767-
},
768-
_ => return None,
769-
};
770-
771-
let trait_ = self.resolve_lang_item(lang_item)?.as_trait();
772-
773-
self.db.trait_data(trait_?).associated_type_by_name(&name![Output])
774-
}
775-
776760
fn resolve_boxed_box(&self) -> Option<AdtId> {
777761
let struct_ = self.resolve_lang_item("owned_box")?.as_struct()?;
778762
Some(struct_.into())
@@ -846,6 +830,10 @@ impl Expectation {
846830
}
847831
}
848832

833+
fn from_option(ty: Option<Ty>) -> Self {
834+
ty.map_or(Expectation::None, Expectation::HasType)
835+
}
836+
849837
/// The following explanation is copied straight from rustc:
850838
/// Provides an expectation for an rvalue expression given an *optional*
851839
/// hint, which is not required for type safety (the resulting type might

crates/hir_ty/src/infer/expr.rs

Lines changed: 213 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@ use std::{
88

99
use chalk_ir::{cast::Cast, fold::Shift, Mutability, TyVariableKind};
1010
use hir_def::{
11-
expr::{Array, BinaryOp, Expr, ExprId, Literal, MatchGuard, Statement, UnaryOp},
11+
expr::{
12+
ArithOp, Array, BinaryOp, CmpOp, Expr, ExprId, Literal, MatchGuard, Ordering, Statement,
13+
UnaryOp,
14+
},
1215
path::{GenericArg, GenericArgs},
1316
resolver::resolver_for_expr,
14-
AssocContainerId, FieldId, Lookup,
17+
AssocContainerId, FieldId, FunctionId, Lookup,
1518
};
1619
use hir_expand::name::{name, Name};
1720
use stdx::always;
@@ -23,7 +26,7 @@ use crate::{
2326
infer::coerce::CoerceMany,
2427
lower::lower_to_chalk_mutability,
2528
mapping::from_chalk,
26-
method_resolution, op,
29+
method_resolution,
2730
primitive::{self, UintTy},
2831
static_lifetime, to_chalk_trait_id,
2932
traits::FnTrait,
@@ -669,34 +672,21 @@ impl<'a> InferenceContext<'a> {
669672
}
670673
}
671674
Expr::BinaryOp { lhs, rhs, op } => match op {
672-
Some(op) => {
673-
let lhs_expectation = match op {
674-
BinaryOp::LogicOp(..) => {
675-
Expectation::has_type(TyKind::Scalar(Scalar::Bool).intern(&Interner))
676-
}
677-
_ => Expectation::none(),
678-
};
679-
let lhs_ty = self.infer_expr(*lhs, &lhs_expectation);
680-
let lhs_ty = self.resolve_ty_shallow(&lhs_ty);
681-
let rhs_expectation = op::binary_op_rhs_expectation(*op, lhs_ty.clone());
682-
let rhs_ty =
683-
self.infer_expr_coerce(*rhs, &Expectation::has_type(rhs_expectation));
684-
let rhs_ty = self.resolve_ty_shallow(&rhs_ty);
685-
686-
let ret = op::binary_op_return_ty(*op, lhs_ty.clone(), rhs_ty.clone());
687-
688-
if ret.is_unknown() {
689-
cov_mark::hit!(infer_expr_inner_binary_operator_overload);
690-
691-
self.resolve_associated_type_with_params(
692-
lhs_ty,
693-
self.resolve_binary_op_output(op),
694-
&[rhs_ty],
695-
)
696-
} else {
697-
ret
698-
}
675+
Some(BinaryOp::Assignment { op: None }) => {
676+
let lhs_ty = self.infer_expr(*lhs, &Expectation::none());
677+
self.infer_expr_coerce(*rhs, &Expectation::has_type(lhs_ty));
678+
self.result.standard_types.unit.clone()
679+
}
680+
Some(BinaryOp::LogicOp(_)) => {
681+
let bool_ty = self.result.standard_types.bool_.clone();
682+
self.infer_expr_coerce(*lhs, &Expectation::HasType(bool_ty.clone()));
683+
let lhs_diverges = self.diverges;
684+
self.infer_expr_coerce(*rhs, &Expectation::HasType(bool_ty.clone()));
685+
// Depending on the LHS' value, the RHS can never execute.
686+
self.diverges = lhs_diverges;
687+
bool_ty
699688
}
689+
Some(op) => self.infer_overloadable_binop(*lhs, *op, *rhs, tgt_expr),
700690
_ => self.err_ty(),
701691
},
702692
Expr::Range { lhs, rhs, range_type } => {
@@ -862,6 +852,62 @@ impl<'a> InferenceContext<'a> {
862852
ty
863853
}
864854

855+
fn infer_overloadable_binop(
856+
&mut self,
857+
lhs: ExprId,
858+
op: BinaryOp,
859+
rhs: ExprId,
860+
tgt_expr: ExprId,
861+
) -> Ty {
862+
let lhs_expectation = Expectation::none();
863+
let lhs_ty = self.infer_expr(lhs, &lhs_expectation);
864+
let rhs_ty = self.table.new_type_var();
865+
866+
let func = self.resolve_binop_method(op);
867+
let func = match func {
868+
Some(func) => func,
869+
None => {
870+
let rhs_ty = self.builtin_binary_op_rhs_expectation(op, lhs_ty.clone());
871+
let rhs_ty = self.infer_expr_coerce(rhs, &Expectation::from_option(rhs_ty));
872+
return self
873+
.builtin_binary_op_return_ty(op, lhs_ty, rhs_ty)
874+
.unwrap_or_else(|| self.err_ty());
875+
}
876+
};
877+
878+
let subst = TyBuilder::subst_for_def(self.db, func)
879+
.push(lhs_ty.clone())
880+
.push(rhs_ty.clone())
881+
.build();
882+
self.write_method_resolution(tgt_expr, func, subst.clone());
883+
884+
let method_ty = self.db.value_ty(func.into()).substitute(&Interner, &subst);
885+
self.register_obligations_for_call(&method_ty);
886+
887+
self.infer_expr_coerce(rhs, &Expectation::has_type(rhs_ty.clone()));
888+
889+
let ret_ty = match method_ty.callable_sig(self.db) {
890+
Some(sig) => sig.ret().clone(),
891+
None => self.err_ty(),
892+
};
893+
894+
let ret_ty = self.normalize_associated_types_in(ret_ty);
895+
896+
// FIXME: record autoref adjustments
897+
898+
// use knowledge of built-in binary ops, which can sometimes help inference
899+
if let Some(builtin_rhs) = self.builtin_binary_op_rhs_expectation(op, lhs_ty.clone()) {
900+
self.unify(&builtin_rhs, &rhs_ty);
901+
}
902+
if let Some(builtin_ret) =
903+
self.builtin_binary_op_return_ty(op, lhs_ty.clone(), rhs_ty.clone())
904+
{
905+
self.unify(&builtin_ret, &ret_ty);
906+
}
907+
908+
ret_ty
909+
}
910+
865911
fn infer_block(
866912
&mut self,
867913
expr: ExprId,
@@ -1136,4 +1182,141 @@ impl<'a> InferenceContext<'a> {
11361182
}
11371183
}
11381184
}
1185+
1186+
fn builtin_binary_op_return_ty(&mut self, op: BinaryOp, lhs_ty: Ty, rhs_ty: Ty) -> Option<Ty> {
1187+
let lhs_ty = self.resolve_ty_shallow(&lhs_ty);
1188+
let rhs_ty = self.resolve_ty_shallow(&rhs_ty);
1189+
match op {
1190+
BinaryOp::LogicOp(_) | BinaryOp::CmpOp(_) => {
1191+
Some(TyKind::Scalar(Scalar::Bool).intern(&Interner))
1192+
}
1193+
BinaryOp::Assignment { .. } => Some(TyBuilder::unit()),
1194+
BinaryOp::ArithOp(ArithOp::Shl | ArithOp::Shr) => {
1195+
// all integer combinations are valid here
1196+
if matches!(
1197+
lhs_ty.kind(&Interner),
1198+
TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_))
1199+
| TyKind::InferenceVar(_, TyVariableKind::Integer)
1200+
) && matches!(
1201+
rhs_ty.kind(&Interner),
1202+
TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_))
1203+
| TyKind::InferenceVar(_, TyVariableKind::Integer)
1204+
) {
1205+
Some(lhs_ty)
1206+
} else {
1207+
None
1208+
}
1209+
}
1210+
BinaryOp::ArithOp(_) => match (lhs_ty.kind(&Interner), rhs_ty.kind(&Interner)) {
1211+
// (int, int) | (uint, uint) | (float, float)
1212+
(TyKind::Scalar(Scalar::Int(_)), TyKind::Scalar(Scalar::Int(_)))
1213+
| (TyKind::Scalar(Scalar::Uint(_)), TyKind::Scalar(Scalar::Uint(_)))
1214+
| (TyKind::Scalar(Scalar::Float(_)), TyKind::Scalar(Scalar::Float(_))) => {
1215+
Some(rhs_ty)
1216+
}
1217+
// ({int}, int) | ({int}, uint)
1218+
(
1219+
TyKind::InferenceVar(_, TyVariableKind::Integer),
1220+
TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_)),
1221+
) => Some(rhs_ty),
1222+
// (int, {int}) | (uint, {int})
1223+
(
1224+
TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_)),
1225+
TyKind::InferenceVar(_, TyVariableKind::Integer),
1226+
) => Some(lhs_ty),
1227+
// ({float} | float)
1228+
(
1229+
TyKind::InferenceVar(_, TyVariableKind::Float),
1230+
TyKind::Scalar(Scalar::Float(_)),
1231+
) => Some(rhs_ty),
1232+
// (float, {float})
1233+
(
1234+
TyKind::Scalar(Scalar::Float(_)),
1235+
TyKind::InferenceVar(_, TyVariableKind::Float),
1236+
) => Some(lhs_ty),
1237+
// ({int}, {int}) | ({float}, {float})
1238+
(
1239+
TyKind::InferenceVar(_, TyVariableKind::Integer),
1240+
TyKind::InferenceVar(_, TyVariableKind::Integer),
1241+
)
1242+
| (
1243+
TyKind::InferenceVar(_, TyVariableKind::Float),
1244+
TyKind::InferenceVar(_, TyVariableKind::Float),
1245+
) => Some(rhs_ty),
1246+
_ => None,
1247+
},
1248+
}
1249+
}
1250+
1251+
fn builtin_binary_op_rhs_expectation(&mut self, op: BinaryOp, lhs_ty: Ty) -> Option<Ty> {
1252+
Some(match op {
1253+
BinaryOp::LogicOp(..) => TyKind::Scalar(Scalar::Bool).intern(&Interner),
1254+
BinaryOp::Assignment { op: None } => lhs_ty,
1255+
BinaryOp::CmpOp(CmpOp::Eq { .. }) => match self
1256+
.resolve_ty_shallow(&lhs_ty)
1257+
.kind(&Interner)
1258+
{
1259+
TyKind::Scalar(_) | TyKind::Str => lhs_ty,
1260+
TyKind::InferenceVar(_, TyVariableKind::Integer | TyVariableKind::Float) => lhs_ty,
1261+
_ => return None,
1262+
},
1263+
BinaryOp::ArithOp(ArithOp::Shl | ArithOp::Shr) => return None,
1264+
BinaryOp::CmpOp(CmpOp::Ord { .. })
1265+
| BinaryOp::Assignment { op: Some(_) }
1266+
| BinaryOp::ArithOp(_) => match self.resolve_ty_shallow(&lhs_ty).kind(&Interner) {
1267+
TyKind::Scalar(Scalar::Int(_) | Scalar::Uint(_) | Scalar::Float(_)) => lhs_ty,
1268+
TyKind::InferenceVar(_, TyVariableKind::Integer | TyVariableKind::Float) => lhs_ty,
1269+
_ => return None,
1270+
},
1271+
})
1272+
}
1273+
1274+
fn resolve_binop_method(&self, op: BinaryOp) -> Option<FunctionId> {
1275+
let (name, lang_item) = match op {
1276+
BinaryOp::LogicOp(_) => return None,
1277+
BinaryOp::ArithOp(aop) => match aop {
1278+
ArithOp::Add => (name!(add), "add"),
1279+
ArithOp::Mul => (name!(mul), "mul"),
1280+
ArithOp::Sub => (name!(sub), "sub"),
1281+
ArithOp::Div => (name!(div), "div"),
1282+
ArithOp::Rem => (name!(rem), "rem"),
1283+
ArithOp::Shl => (name!(shl), "shl"),
1284+
ArithOp::Shr => (name!(shr), "shr"),
1285+
ArithOp::BitXor => (name!(bitxor), "bitxor"),
1286+
ArithOp::BitOr => (name!(bitor), "bitor"),
1287+
ArithOp::BitAnd => (name!(bitand), "bitand"),
1288+
},
1289+
BinaryOp::Assignment { op: Some(aop) } => match aop {
1290+
ArithOp::Add => (name!(add_assign), "add_assign"),
1291+
ArithOp::Mul => (name!(mul_assign), "mul_assign"),
1292+
ArithOp::Sub => (name!(sub_assign), "sub_assign"),
1293+
ArithOp::Div => (name!(div_assign), "div_assign"),
1294+
ArithOp::Rem => (name!(rem_assign), "rem_assign"),
1295+
ArithOp::Shl => (name!(shl_assign), "shl_assign"),
1296+
ArithOp::Shr => (name!(shr_assign), "shr_assign"),
1297+
ArithOp::BitXor => (name!(bitxor_assign), "bitxor_assign"),
1298+
ArithOp::BitOr => (name!(bitor_assign), "bitor_assign"),
1299+
ArithOp::BitAnd => (name!(bitand_assign), "bitand_assign"),
1300+
},
1301+
BinaryOp::CmpOp(cop) => match cop {
1302+
CmpOp::Eq { negated: false } => (name!(eq), "eq"),
1303+
CmpOp::Eq { negated: true } => (name!(ne), "eq"),
1304+
CmpOp::Ord { ordering: Ordering::Less, strict: false } => {
1305+
(name!(le), "partial_ord")
1306+
}
1307+
CmpOp::Ord { ordering: Ordering::Less, strict: true } => (name!(lt), "partial_ord"),
1308+
CmpOp::Ord { ordering: Ordering::Greater, strict: false } => {
1309+
(name!(ge), "partial_ord")
1310+
}
1311+
CmpOp::Ord { ordering: Ordering::Greater, strict: true } => {
1312+
(name!(gt), "partial_ord")
1313+
}
1314+
},
1315+
BinaryOp::Assignment { op: None } => return None,
1316+
};
1317+
1318+
let trait_ = self.resolve_lang_item(lang_item)?.as_trait()?;
1319+
1320+
self.db.trait_data(trait_).method_by_name(&name)
1321+
}
11391322
}

0 commit comments

Comments
 (0)