From 9cd6c68033929d60441ad4286d1270eeafbb8620 Mon Sep 17 00:00:00 2001
From: lcnr <rust@lcnr.de>
Date: Mon, 5 Feb 2024 10:51:18 +0100
Subject: [PATCH] cleanup effect var handling

---
 .../src/infer/canonical/canonicalizer.rs      |  2 +-
 .../rustc_infer/src/infer/canonical/mod.rs    |  8 +++-
 compiler/rustc_infer/src/infer/freshen.rs     |  9 +---
 compiler/rustc_infer/src/infer/mod.rs         | 15 ++++---
 .../rustc_infer/src/infer/relate/combine.rs   | 38 +++--------------
 compiler/rustc_infer/src/infer/resolve.rs     |  9 ++--
 compiler/rustc_middle/src/infer/unify_key.rs  | 42 ++++++++++---------
 7 files changed, 52 insertions(+), 71 deletions(-)

diff --git a/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs b/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs
index e4b37f05b778f..30ca70326fa87 100644
--- a/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs
+++ b/compiler/rustc_infer/src/infer/canonical/canonicalizer.rs
@@ -480,7 +480,7 @@ impl<'cx, 'tcx> TypeFolder<TyCtxt<'tcx>> for Canonicalizer<'cx, 'tcx> {
             }
             ty::ConstKind::Infer(InferConst::EffectVar(vid)) => {
                 match self.infcx.unwrap().probe_effect_var(vid) {
-                    Some(value) => return self.fold_const(value.as_const(self.tcx)),
+                    Some(value) => return self.fold_const(value),
                     None => {
                         return self.canonicalize_const_var(
                             CanonicalVarInfo { kind: CanonicalVarKind::Effect },
diff --git a/compiler/rustc_infer/src/infer/canonical/mod.rs b/compiler/rustc_infer/src/infer/canonical/mod.rs
index 386fdb09ba5d5..1f68a5a9c6179 100644
--- a/compiler/rustc_infer/src/infer/canonical/mod.rs
+++ b/compiler/rustc_infer/src/infer/canonical/mod.rs
@@ -24,6 +24,7 @@
 use crate::infer::{ConstVariableOrigin, ConstVariableOriginKind};
 use crate::infer::{InferCtxt, RegionVariableOrigin, TypeVariableOrigin, TypeVariableOriginKind};
 use rustc_index::IndexVec;
+use rustc_middle::infer::unify_key::EffectVarValue;
 use rustc_middle::ty::fold::TypeFoldable;
 use rustc_middle::ty::GenericArg;
 use rustc_middle::ty::{self, List, Ty, TyCtxt};
@@ -152,7 +153,12 @@ impl<'tcx> InferCtxt<'tcx> {
                 )
                 .into(),
             CanonicalVarKind::Effect => {
-                let vid = self.inner.borrow_mut().effect_unification_table().new_key(None).vid;
+                let vid = self
+                    .inner
+                    .borrow_mut()
+                    .effect_unification_table()
+                    .new_key(EffectVarValue::Unknown)
+                    .vid;
                 ty::Const::new_infer(self.tcx, ty::InferConst::EffectVar(vid), self.tcx.types.bool)
                     .into()
             }
diff --git a/compiler/rustc_infer/src/infer/freshen.rs b/compiler/rustc_infer/src/infer/freshen.rs
index d256994d8d1fd..2d5fa1b5c7001 100644
--- a/compiler/rustc_infer/src/infer/freshen.rs
+++ b/compiler/rustc_infer/src/infer/freshen.rs
@@ -151,13 +151,8 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for TypeFreshener<'a, 'tcx> {
                 self.freshen_const(opt_ct, ty::InferConst::Var(v), ty::InferConst::Fresh, ct.ty())
             }
             ty::ConstKind::Infer(ty::InferConst::EffectVar(v)) => {
-                let opt_ct = self
-                    .infcx
-                    .inner
-                    .borrow_mut()
-                    .effect_unification_table()
-                    .probe_value(v)
-                    .map(|effect| effect.as_const(self.infcx.tcx));
+                let opt_ct =
+                    self.infcx.inner.borrow_mut().effect_unification_table().probe_value(v).known();
                 self.freshen_const(
                     opt_ct,
                     ty::InferConst::EffectVar(v),
diff --git a/compiler/rustc_infer/src/infer/mod.rs b/compiler/rustc_infer/src/infer/mod.rs
index 0a39fe007fd22..13baecf6002c0 100644
--- a/compiler/rustc_infer/src/infer/mod.rs
+++ b/compiler/rustc_infer/src/infer/mod.rs
@@ -8,6 +8,7 @@ pub use self::ValuePairs::*;
 pub use relate::combine::ObligationEmittingRelation;
 use rustc_data_structures::captures::Captures;
 use rustc_data_structures::undo_log::UndoLogs;
+use rustc_middle::infer::unify_key::EffectVarValue;
 use rustc_middle::infer::unify_key::{ConstVidKey, EffectVidKey};
 
 use self::opaque_types::OpaqueTypeStorage;
@@ -25,8 +26,8 @@ use rustc_data_structures::unify as ut;
 use rustc_errors::{DiagCtxt, DiagnosticBuilder, ErrorGuaranteed};
 use rustc_hir::def_id::{DefId, LocalDefId};
 use rustc_middle::infer::canonical::{Canonical, CanonicalVarValues};
+use rustc_middle::infer::unify_key::ConstVariableValue;
 use rustc_middle::infer::unify_key::{ConstVariableOrigin, ConstVariableOriginKind, ToType};
-use rustc_middle::infer::unify_key::{ConstVariableValue, EffectVarValue};
 use rustc_middle::mir::interpret::{ErrorHandled, EvalToValTreeResult};
 use rustc_middle::mir::ConstraintCategory;
 use rustc_middle::traits::{select, DefiningAnchor};
@@ -818,7 +819,7 @@ impl<'tcx> InferCtxt<'tcx> {
 
         (0..table.len())
             .map(|i| ty::EffectVid::from_usize(i))
-            .filter(|&vid| table.probe_value(vid).is_none())
+            .filter(|&vid| table.probe_value(vid).is_unknown())
             .map(|v| {
                 ty::Const::new_infer(self.tcx, ty::InferConst::EffectVar(v), self.tcx.types.bool)
             })
@@ -1236,7 +1237,8 @@ impl<'tcx> InferCtxt<'tcx> {
     }
 
     pub fn var_for_effect(&self, param: &ty::GenericParamDef) -> GenericArg<'tcx> {
-        let effect_vid = self.inner.borrow_mut().effect_unification_table().new_key(None).vid;
+        let effect_vid =
+            self.inner.borrow_mut().effect_unification_table().new_key(EffectVarValue::Unknown).vid;
         let ty = self
             .tcx
             .type_of(param.def_id)
@@ -1416,8 +1418,8 @@ impl<'tcx> InferCtxt<'tcx> {
         }
     }
 
-    pub fn probe_effect_var(&self, vid: EffectVid) -> Option<EffectVarValue<'tcx>> {
-        self.inner.borrow_mut().effect_unification_table().probe_value(vid)
+    pub fn probe_effect_var(&self, vid: EffectVid) -> Option<ty::Const<'tcx>> {
+        self.inner.borrow_mut().effect_unification_table().probe_value(vid).known()
     }
 
     /// Attempts to resolve all type/region/const variables in
@@ -1893,7 +1895,8 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for ShallowResolver<'a, 'tcx> {
                 .borrow_mut()
                 .effect_unification_table()
                 .probe_value(vid)
-                .map_or(ct, |val| val.as_const(self.infcx.tcx)),
+                .known()
+                .unwrap_or(ct),
             _ => ct,
         }
     }
diff --git a/compiler/rustc_infer/src/infer/relate/combine.rs b/compiler/rustc_infer/src/infer/relate/combine.rs
index 1c120646f1f90..13c17ee1cd2b8 100644
--- a/compiler/rustc_infer/src/infer/relate/combine.rs
+++ b/compiler/rustc_infer/src/infer/relate/combine.rs
@@ -202,11 +202,7 @@ impl<'tcx> InferCtxt<'tcx> {
                 ty::ConstKind::Infer(InferConst::EffectVar(a_vid)),
                 ty::ConstKind::Infer(InferConst::EffectVar(b_vid)),
             ) => {
-                self.inner
-                    .borrow_mut()
-                    .effect_unification_table()
-                    .unify_var_var(a_vid, b_vid)
-                    .map_err(|a| effect_unification_error(self.tcx, relation.a_is_expected(), a))?;
+                self.inner.borrow_mut().effect_unification_table().union(a_vid, b_vid);
                 return Ok(a);
             }
 
@@ -233,19 +229,11 @@ impl<'tcx> InferCtxt<'tcx> {
             }
 
             (ty::ConstKind::Infer(InferConst::EffectVar(vid)), _) => {
-                return self.unify_effect_variable(
-                    relation.a_is_expected(),
-                    vid,
-                    EffectVarValue::Const(b),
-                );
+                return Ok(self.unify_effect_variable(vid, b));
             }
 
             (_, ty::ConstKind::Infer(InferConst::EffectVar(vid))) => {
-                return self.unify_effect_variable(
-                    !relation.a_is_expected(),
-                    vid,
-                    EffectVarValue::Const(a),
-                );
+                return Ok(self.unify_effect_variable(vid, a));
             }
 
             (ty::ConstKind::Unevaluated(..), _) | (_, ty::ConstKind::Unevaluated(..))
@@ -366,18 +354,12 @@ impl<'tcx> InferCtxt<'tcx> {
         Ok(Ty::new_float(self.tcx, val))
     }
 
-    fn unify_effect_variable(
-        &self,
-        vid_is_expected: bool,
-        vid: ty::EffectVid,
-        val: EffectVarValue<'tcx>,
-    ) -> RelateResult<'tcx, ty::Const<'tcx>> {
+    fn unify_effect_variable(&self, vid: ty::EffectVid, val: ty::Const<'tcx>) -> ty::Const<'tcx> {
         self.inner
             .borrow_mut()
             .effect_unification_table()
-            .unify_var_value(vid, Some(val))
-            .map_err(|e| effect_unification_error(self.tcx, vid_is_expected, e))?;
-        Ok(val.as_const(self.tcx))
+            .union_value(vid, EffectVarValue::Known(val));
+        val
     }
 }
 
@@ -579,11 +561,3 @@ fn float_unification_error<'tcx>(
     let (ty::FloatVarValue(a), ty::FloatVarValue(b)) = v;
     TypeError::FloatMismatch(ExpectedFound::new(a_is_expected, a, b))
 }
-
-fn effect_unification_error<'tcx>(
-    _tcx: TyCtxt<'tcx>,
-    _a_is_expected: bool,
-    (_a, _b): (EffectVarValue<'tcx>, EffectVarValue<'tcx>),
-) -> TypeError<'tcx> {
-    bug!("unexpected effect unification error")
-}
diff --git a/compiler/rustc_infer/src/infer/resolve.rs b/compiler/rustc_infer/src/infer/resolve.rs
index 959b09031277c..d5999331dfab6 100644
--- a/compiler/rustc_infer/src/infer/resolve.rs
+++ b/compiler/rustc_infer/src/infer/resolve.rs
@@ -237,14 +237,13 @@ impl<'tcx> TypeFolder<TyCtxt<'tcx>> for EagerResolver<'_, 'tcx> {
             }
             ty::ConstKind::Infer(ty::InferConst::EffectVar(vid)) => {
                 debug_assert_eq!(c.ty(), self.infcx.tcx.types.bool);
-                match self.infcx.probe_effect_var(vid) {
-                    Some(c) => c.as_const(self.infcx.tcx),
-                    None => ty::Const::new_infer(
+                self.infcx.probe_effect_var(vid).unwrap_or_else(|| {
+                    ty::Const::new_infer(
                         self.infcx.tcx,
                         ty::InferConst::EffectVar(self.infcx.root_effect_var(vid)),
                         self.infcx.tcx.types.bool,
-                    ),
-                }
+                    )
+                })
             }
             _ => {
                 if c.has_infer() {
diff --git a/compiler/rustc_middle/src/infer/unify_key.rs b/compiler/rustc_middle/src/infer/unify_key.rs
index c35799ef47f2b..63c0ebd5f6b79 100644
--- a/compiler/rustc_middle/src/infer/unify_key.rs
+++ b/compiler/rustc_middle/src/infer/unify_key.rs
@@ -194,33 +194,37 @@ impl<'tcx> UnifyValue for ConstVariableValue<'tcx> {
 /// values for the effect inference variable
 #[derive(Clone, Copy, Debug)]
 pub enum EffectVarValue<'tcx> {
-    /// The host effect is on, enabling access to syscalls, filesystem access, etc.
-    Host,
-    /// The host effect is off. Execution is restricted to const operations only.
-    NoHost,
-    Const(ty::Const<'tcx>),
+    Unknown,
+    Known(ty::Const<'tcx>),
 }
 
 impl<'tcx> EffectVarValue<'tcx> {
-    pub fn as_const(self, tcx: TyCtxt<'tcx>) -> ty::Const<'tcx> {
+    pub fn known(self) -> Option<ty::Const<'tcx>> {
         match self {
-            EffectVarValue::Host => tcx.consts.true_,
-            EffectVarValue::NoHost => tcx.consts.false_,
-            EffectVarValue::Const(c) => c,
+            EffectVarValue::Unknown => None,
+            EffectVarValue::Known(value) => Some(value),
+        }
+    }
+
+    pub fn is_unknown(self) -> bool {
+        match self {
+            EffectVarValue::Unknown => true,
+            EffectVarValue::Known(_) => false,
         }
     }
 }
 
 impl<'tcx> UnifyValue for EffectVarValue<'tcx> {
-    type Error = (EffectVarValue<'tcx>, EffectVarValue<'tcx>);
+    type Error = NoError;
     fn unify_values(value1: &Self, value2: &Self) -> Result<Self, Self::Error> {
-        match (value1, value2) {
-            (EffectVarValue::Host, EffectVarValue::Host) => Ok(EffectVarValue::Host),
-            (EffectVarValue::NoHost, EffectVarValue::NoHost) => Ok(EffectVarValue::NoHost),
-            (EffectVarValue::NoHost | EffectVarValue::Host, _)
-            | (_, EffectVarValue::NoHost | EffectVarValue::Host) => Err((*value1, *value2)),
-            (EffectVarValue::Const(_), EffectVarValue::Const(_)) => {
-                bug!("equating two const variables, both of which have known values")
+        match (*value1, *value2) {
+            (EffectVarValue::Unknown, EffectVarValue::Unknown) => Ok(EffectVarValue::Unknown),
+            (EffectVarValue::Unknown, EffectVarValue::Known(val))
+            | (EffectVarValue::Known(val), EffectVarValue::Unknown) => {
+                Ok(EffectVarValue::Known(val))
+            }
+            (EffectVarValue::Known(_), EffectVarValue::Known(_)) => {
+                bug!("equating known inference variables: {value1:?} {value2:?}")
             }
         }
     }
@@ -229,7 +233,7 @@ impl<'tcx> UnifyValue for EffectVarValue<'tcx> {
 #[derive(PartialEq, Copy, Clone, Debug)]
 pub struct EffectVidKey<'tcx> {
     pub vid: ty::EffectVid,
-    pub phantom: PhantomData<EffectVarValue<'tcx>>,
+    pub phantom: PhantomData<ty::Const<'tcx>>,
 }
 
 impl<'tcx> From<ty::EffectVid> for EffectVidKey<'tcx> {
@@ -239,7 +243,7 @@ impl<'tcx> From<ty::EffectVid> for EffectVidKey<'tcx> {
 }
 
 impl<'tcx> UnifyKey for EffectVidKey<'tcx> {
-    type Value = Option<EffectVarValue<'tcx>>;
+    type Value = EffectVarValue<'tcx>;
     #[inline]
     fn index(&self) -> u32 {
         self.vid.as_u32()