@@ -8,10 +8,13 @@ use std::{
8
8
9
9
use chalk_ir:: { cast:: Cast , fold:: Shift , Mutability , TyVariableKind } ;
10
10
use hir_def:: {
11
- expr:: { Array , BinaryOp , Expr , ExprId , Literal , MatchGuard , Statement , UnaryOp } ,
11
+ expr:: {
12
+ ArithOp , Array , BinaryOp , CmpOp , Expr , ExprId , Literal , MatchGuard , Ordering , Statement ,
13
+ UnaryOp ,
14
+ } ,
12
15
path:: { GenericArg , GenericArgs } ,
13
16
resolver:: resolver_for_expr,
14
- AssocContainerId , FieldId , Lookup ,
17
+ AssocContainerId , FieldId , FunctionId , Lookup ,
15
18
} ;
16
19
use hir_expand:: name:: { name, Name } ;
17
20
use stdx:: always;
@@ -23,7 +26,7 @@ use crate::{
23
26
infer:: coerce:: CoerceMany ,
24
27
lower:: lower_to_chalk_mutability,
25
28
mapping:: from_chalk,
26
- method_resolution, op ,
29
+ method_resolution,
27
30
primitive:: { self , UintTy } ,
28
31
static_lifetime, to_chalk_trait_id,
29
32
traits:: FnTrait ,
@@ -669,34 +672,21 @@ impl<'a> InferenceContext<'a> {
669
672
}
670
673
}
671
674
Expr :: BinaryOp { lhs, rhs, op } => match op {
672
- Some ( op) => {
673
- let lhs_expectation = match op {
674
- BinaryOp :: LogicOp ( ..) => {
675
- Expectation :: has_type ( TyKind :: Scalar ( Scalar :: Bool ) . intern ( & Interner ) )
676
- }
677
- _ => Expectation :: none ( ) ,
678
- } ;
679
- let lhs_ty = self . infer_expr ( * lhs, & lhs_expectation) ;
680
- let lhs_ty = self . resolve_ty_shallow ( & lhs_ty) ;
681
- let rhs_expectation = op:: binary_op_rhs_expectation ( * op, lhs_ty. clone ( ) ) ;
682
- let rhs_ty =
683
- self . infer_expr_coerce ( * rhs, & Expectation :: has_type ( rhs_expectation) ) ;
684
- let rhs_ty = self . resolve_ty_shallow ( & rhs_ty) ;
685
-
686
- let ret = op:: binary_op_return_ty ( * op, lhs_ty. clone ( ) , rhs_ty. clone ( ) ) ;
687
-
688
- if ret. is_unknown ( ) {
689
- cov_mark:: hit!( infer_expr_inner_binary_operator_overload) ;
690
-
691
- self . resolve_associated_type_with_params (
692
- lhs_ty,
693
- self . resolve_binary_op_output ( op) ,
694
- & [ rhs_ty] ,
695
- )
696
- } else {
697
- ret
698
- }
675
+ Some ( BinaryOp :: Assignment { op : None } ) => {
676
+ let lhs_ty = self . infer_expr ( * lhs, & Expectation :: none ( ) ) ;
677
+ self . infer_expr_coerce ( * rhs, & Expectation :: has_type ( lhs_ty) ) ;
678
+ self . result . standard_types . unit . clone ( )
679
+ }
680
+ Some ( BinaryOp :: LogicOp ( _) ) => {
681
+ let bool_ty = self . result . standard_types . bool_ . clone ( ) ;
682
+ self . infer_expr_coerce ( * lhs, & Expectation :: HasType ( bool_ty. clone ( ) ) ) ;
683
+ let lhs_diverges = self . diverges ;
684
+ self . infer_expr_coerce ( * rhs, & Expectation :: HasType ( bool_ty. clone ( ) ) ) ;
685
+ // Depending on the LHS' value, the RHS can never execute.
686
+ self . diverges = lhs_diverges;
687
+ bool_ty
699
688
}
689
+ Some ( op) => self . infer_overloadable_binop ( * lhs, * op, * rhs, tgt_expr) ,
700
690
_ => self . err_ty ( ) ,
701
691
} ,
702
692
Expr :: Range { lhs, rhs, range_type } => {
@@ -862,6 +852,62 @@ impl<'a> InferenceContext<'a> {
862
852
ty
863
853
}
864
854
855
+ fn infer_overloadable_binop (
856
+ & mut self ,
857
+ lhs : ExprId ,
858
+ op : BinaryOp ,
859
+ rhs : ExprId ,
860
+ tgt_expr : ExprId ,
861
+ ) -> Ty {
862
+ let lhs_expectation = Expectation :: none ( ) ;
863
+ let lhs_ty = self . infer_expr ( lhs, & lhs_expectation) ;
864
+ let rhs_ty = self . table . new_type_var ( ) ;
865
+
866
+ let func = self . resolve_binop_method ( op) ;
867
+ let func = match func {
868
+ Some ( func) => func,
869
+ None => {
870
+ let rhs_ty = self . builtin_binary_op_rhs_expectation ( op, lhs_ty. clone ( ) ) ;
871
+ let rhs_ty = self . infer_expr_coerce ( rhs, & Expectation :: from_option ( rhs_ty) ) ;
872
+ return self
873
+ . builtin_binary_op_return_ty ( op, lhs_ty, rhs_ty)
874
+ . unwrap_or_else ( || self . err_ty ( ) ) ;
875
+ }
876
+ } ;
877
+
878
+ let subst = TyBuilder :: subst_for_def ( self . db , func)
879
+ . push ( lhs_ty. clone ( ) )
880
+ . push ( rhs_ty. clone ( ) )
881
+ . build ( ) ;
882
+ self . write_method_resolution ( tgt_expr, func, subst. clone ( ) ) ;
883
+
884
+ let method_ty = self . db . value_ty ( func. into ( ) ) . substitute ( & Interner , & subst) ;
885
+ self . register_obligations_for_call ( & method_ty) ;
886
+
887
+ self . infer_expr_coerce ( rhs, & Expectation :: has_type ( rhs_ty. clone ( ) ) ) ;
888
+
889
+ let ret_ty = match method_ty. callable_sig ( self . db ) {
890
+ Some ( sig) => sig. ret ( ) . clone ( ) ,
891
+ None => self . err_ty ( ) ,
892
+ } ;
893
+
894
+ let ret_ty = self . normalize_associated_types_in ( ret_ty) ;
895
+
896
+ // FIXME: record autoref adjustments
897
+
898
+ // use knowledge of built-in binary ops, which can sometimes help inference
899
+ if let Some ( builtin_rhs) = self . builtin_binary_op_rhs_expectation ( op, lhs_ty. clone ( ) ) {
900
+ self . unify ( & builtin_rhs, & rhs_ty) ;
901
+ }
902
+ if let Some ( builtin_ret) =
903
+ self . builtin_binary_op_return_ty ( op, lhs_ty. clone ( ) , rhs_ty. clone ( ) )
904
+ {
905
+ self . unify ( & builtin_ret, & ret_ty) ;
906
+ }
907
+
908
+ ret_ty
909
+ }
910
+
865
911
fn infer_block (
866
912
& mut self ,
867
913
expr : ExprId ,
@@ -1136,4 +1182,141 @@ impl<'a> InferenceContext<'a> {
1136
1182
}
1137
1183
}
1138
1184
}
1185
+
1186
+ fn builtin_binary_op_return_ty ( & mut self , op : BinaryOp , lhs_ty : Ty , rhs_ty : Ty ) -> Option < Ty > {
1187
+ let lhs_ty = self . resolve_ty_shallow ( & lhs_ty) ;
1188
+ let rhs_ty = self . resolve_ty_shallow ( & rhs_ty) ;
1189
+ match op {
1190
+ BinaryOp :: LogicOp ( _) | BinaryOp :: CmpOp ( _) => {
1191
+ Some ( TyKind :: Scalar ( Scalar :: Bool ) . intern ( & Interner ) )
1192
+ }
1193
+ BinaryOp :: Assignment { .. } => Some ( TyBuilder :: unit ( ) ) ,
1194
+ BinaryOp :: ArithOp ( ArithOp :: Shl | ArithOp :: Shr ) => {
1195
+ // all integer combinations are valid here
1196
+ if matches ! (
1197
+ lhs_ty. kind( & Interner ) ,
1198
+ TyKind :: Scalar ( Scalar :: Int ( _) | Scalar :: Uint ( _) )
1199
+ | TyKind :: InferenceVar ( _, TyVariableKind :: Integer )
1200
+ ) && matches ! (
1201
+ rhs_ty. kind( & Interner ) ,
1202
+ TyKind :: Scalar ( Scalar :: Int ( _) | Scalar :: Uint ( _) )
1203
+ | TyKind :: InferenceVar ( _, TyVariableKind :: Integer )
1204
+ ) {
1205
+ Some ( lhs_ty)
1206
+ } else {
1207
+ None
1208
+ }
1209
+ }
1210
+ BinaryOp :: ArithOp ( _) => match ( lhs_ty. kind ( & Interner ) , rhs_ty. kind ( & Interner ) ) {
1211
+ // (int, int) | (uint, uint) | (float, float)
1212
+ ( TyKind :: Scalar ( Scalar :: Int ( _) ) , TyKind :: Scalar ( Scalar :: Int ( _) ) )
1213
+ | ( TyKind :: Scalar ( Scalar :: Uint ( _) ) , TyKind :: Scalar ( Scalar :: Uint ( _) ) )
1214
+ | ( TyKind :: Scalar ( Scalar :: Float ( _) ) , TyKind :: Scalar ( Scalar :: Float ( _) ) ) => {
1215
+ Some ( rhs_ty)
1216
+ }
1217
+ // ({int}, int) | ({int}, uint)
1218
+ (
1219
+ TyKind :: InferenceVar ( _, TyVariableKind :: Integer ) ,
1220
+ TyKind :: Scalar ( Scalar :: Int ( _) | Scalar :: Uint ( _) ) ,
1221
+ ) => Some ( rhs_ty) ,
1222
+ // (int, {int}) | (uint, {int})
1223
+ (
1224
+ TyKind :: Scalar ( Scalar :: Int ( _) | Scalar :: Uint ( _) ) ,
1225
+ TyKind :: InferenceVar ( _, TyVariableKind :: Integer ) ,
1226
+ ) => Some ( lhs_ty) ,
1227
+ // ({float} | float)
1228
+ (
1229
+ TyKind :: InferenceVar ( _, TyVariableKind :: Float ) ,
1230
+ TyKind :: Scalar ( Scalar :: Float ( _) ) ,
1231
+ ) => Some ( rhs_ty) ,
1232
+ // (float, {float})
1233
+ (
1234
+ TyKind :: Scalar ( Scalar :: Float ( _) ) ,
1235
+ TyKind :: InferenceVar ( _, TyVariableKind :: Float ) ,
1236
+ ) => Some ( lhs_ty) ,
1237
+ // ({int}, {int}) | ({float}, {float})
1238
+ (
1239
+ TyKind :: InferenceVar ( _, TyVariableKind :: Integer ) ,
1240
+ TyKind :: InferenceVar ( _, TyVariableKind :: Integer ) ,
1241
+ )
1242
+ | (
1243
+ TyKind :: InferenceVar ( _, TyVariableKind :: Float ) ,
1244
+ TyKind :: InferenceVar ( _, TyVariableKind :: Float ) ,
1245
+ ) => Some ( rhs_ty) ,
1246
+ _ => None ,
1247
+ } ,
1248
+ }
1249
+ }
1250
+
1251
+ fn builtin_binary_op_rhs_expectation ( & mut self , op : BinaryOp , lhs_ty : Ty ) -> Option < Ty > {
1252
+ Some ( match op {
1253
+ BinaryOp :: LogicOp ( ..) => TyKind :: Scalar ( Scalar :: Bool ) . intern ( & Interner ) ,
1254
+ BinaryOp :: Assignment { op : None } => lhs_ty,
1255
+ BinaryOp :: CmpOp ( CmpOp :: Eq { .. } ) => match self
1256
+ . resolve_ty_shallow ( & lhs_ty)
1257
+ . kind ( & Interner )
1258
+ {
1259
+ TyKind :: Scalar ( _) | TyKind :: Str => lhs_ty,
1260
+ TyKind :: InferenceVar ( _, TyVariableKind :: Integer | TyVariableKind :: Float ) => lhs_ty,
1261
+ _ => return None ,
1262
+ } ,
1263
+ BinaryOp :: ArithOp ( ArithOp :: Shl | ArithOp :: Shr ) => return None ,
1264
+ BinaryOp :: CmpOp ( CmpOp :: Ord { .. } )
1265
+ | BinaryOp :: Assignment { op : Some ( _) }
1266
+ | BinaryOp :: ArithOp ( _) => match self . resolve_ty_shallow ( & lhs_ty) . kind ( & Interner ) {
1267
+ TyKind :: Scalar ( Scalar :: Int ( _) | Scalar :: Uint ( _) | Scalar :: Float ( _) ) => lhs_ty,
1268
+ TyKind :: InferenceVar ( _, TyVariableKind :: Integer | TyVariableKind :: Float ) => lhs_ty,
1269
+ _ => return None ,
1270
+ } ,
1271
+ } )
1272
+ }
1273
+
1274
+ fn resolve_binop_method ( & self , op : BinaryOp ) -> Option < FunctionId > {
1275
+ let ( name, lang_item) = match op {
1276
+ BinaryOp :: LogicOp ( _) => return None ,
1277
+ BinaryOp :: ArithOp ( aop) => match aop {
1278
+ ArithOp :: Add => ( name ! ( add) , "add" ) ,
1279
+ ArithOp :: Mul => ( name ! ( mul) , "mul" ) ,
1280
+ ArithOp :: Sub => ( name ! ( sub) , "sub" ) ,
1281
+ ArithOp :: Div => ( name ! ( div) , "div" ) ,
1282
+ ArithOp :: Rem => ( name ! ( rem) , "rem" ) ,
1283
+ ArithOp :: Shl => ( name ! ( shl) , "shl" ) ,
1284
+ ArithOp :: Shr => ( name ! ( shr) , "shr" ) ,
1285
+ ArithOp :: BitXor => ( name ! ( bitxor) , "bitxor" ) ,
1286
+ ArithOp :: BitOr => ( name ! ( bitor) , "bitor" ) ,
1287
+ ArithOp :: BitAnd => ( name ! ( bitand) , "bitand" ) ,
1288
+ } ,
1289
+ BinaryOp :: Assignment { op : Some ( aop) } => match aop {
1290
+ ArithOp :: Add => ( name ! ( add_assign) , "add_assign" ) ,
1291
+ ArithOp :: Mul => ( name ! ( mul_assign) , "mul_assign" ) ,
1292
+ ArithOp :: Sub => ( name ! ( sub_assign) , "sub_assign" ) ,
1293
+ ArithOp :: Div => ( name ! ( div_assign) , "div_assign" ) ,
1294
+ ArithOp :: Rem => ( name ! ( rem_assign) , "rem_assign" ) ,
1295
+ ArithOp :: Shl => ( name ! ( shl_assign) , "shl_assign" ) ,
1296
+ ArithOp :: Shr => ( name ! ( shr_assign) , "shr_assign" ) ,
1297
+ ArithOp :: BitXor => ( name ! ( bitxor_assign) , "bitxor_assign" ) ,
1298
+ ArithOp :: BitOr => ( name ! ( bitor_assign) , "bitor_assign" ) ,
1299
+ ArithOp :: BitAnd => ( name ! ( bitand_assign) , "bitand_assign" ) ,
1300
+ } ,
1301
+ BinaryOp :: CmpOp ( cop) => match cop {
1302
+ CmpOp :: Eq { negated : false } => ( name ! ( eq) , "eq" ) ,
1303
+ CmpOp :: Eq { negated : true } => ( name ! ( ne) , "eq" ) ,
1304
+ CmpOp :: Ord { ordering : Ordering :: Less , strict : false } => {
1305
+ ( name ! ( le) , "partial_ord" )
1306
+ }
1307
+ CmpOp :: Ord { ordering : Ordering :: Less , strict : true } => ( name ! ( lt) , "partial_ord" ) ,
1308
+ CmpOp :: Ord { ordering : Ordering :: Greater , strict : false } => {
1309
+ ( name ! ( ge) , "partial_ord" )
1310
+ }
1311
+ CmpOp :: Ord { ordering : Ordering :: Greater , strict : true } => {
1312
+ ( name ! ( gt) , "partial_ord" )
1313
+ }
1314
+ } ,
1315
+ BinaryOp :: Assignment { op : None } => return None ,
1316
+ } ;
1317
+
1318
+ let trait_ = self . resolve_lang_item ( lang_item) ?. as_trait ( ) ?;
1319
+
1320
+ self . db . trait_data ( trait_) . method_by_name ( & name)
1321
+ }
1139
1322
}
0 commit comments