From 7a2cdf20e42c5b2a4d0f259eab8968cda7ae55c2 Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Sat, 27 May 2023 19:05:09 +0000
Subject: [PATCH 1/2] Move alias-relate to its own module

---
 .../src/solve/alias_relate.rs                 | 146 ++++++++++++++++++
 .../rustc_trait_selection/src/solve/mod.rs    | 137 +---------------
 2 files changed, 147 insertions(+), 136 deletions(-)
 create mode 100644 compiler/rustc_trait_selection/src/solve/alias_relate.rs

diff --git a/compiler/rustc_trait_selection/src/solve/alias_relate.rs b/compiler/rustc_trait_selection/src/solve/alias_relate.rs
new file mode 100644
index 0000000000000..59ab3bc3a727b
--- /dev/null
+++ b/compiler/rustc_trait_selection/src/solve/alias_relate.rs
@@ -0,0 +1,146 @@
+use super::{EvalCtxt, SolverMode};
+use rustc_infer::traits::query::NoSolution;
+use rustc_middle::traits::solve::{Certainty, Goal, QueryResult};
+use rustc_middle::ty;
+
+/// We may need to invert the alias relation direction if dealing an alias on the RHS.
+#[derive(Debug)]
+enum Invert {
+    No,
+    Yes,
+}
+
+impl<'tcx> EvalCtxt<'_, 'tcx> {
+    #[instrument(level = "debug", skip(self), ret)]
+    pub(super) fn compute_alias_relate_goal(
+        &mut self,
+        goal: Goal<'tcx, (ty::Term<'tcx>, ty::Term<'tcx>, ty::AliasRelationDirection)>,
+    ) -> QueryResult<'tcx> {
+        let tcx = self.tcx();
+        let Goal { param_env, predicate: (lhs, rhs, direction) } = goal;
+        if lhs.is_infer() || rhs.is_infer() {
+            bug!(
+                "`AliasRelate` goal with an infer var on lhs or rhs which should have been instantiated"
+            );
+        }
+
+        match (lhs.to_alias_ty(tcx), rhs.to_alias_ty(tcx)) {
+            (None, None) => bug!("`AliasRelate` goal without an alias on either lhs or rhs"),
+
+            // RHS is not a projection, only way this is true is if LHS normalizes-to RHS
+            (Some(alias_lhs), None) => self.assemble_normalizes_to_candidate(
+                param_env,
+                alias_lhs,
+                rhs,
+                direction,
+                Invert::No,
+            ),
+
+            // LHS is not a projection, only way this is true is if RHS normalizes-to LHS
+            (None, Some(alias_rhs)) => self.assemble_normalizes_to_candidate(
+                param_env,
+                alias_rhs,
+                lhs,
+                direction,
+                Invert::Yes,
+            ),
+
+            (Some(alias_lhs), Some(alias_rhs)) => {
+                debug!("both sides are aliases");
+
+                let mut candidates = Vec::new();
+                // LHS normalizes-to RHS
+                candidates.extend(self.assemble_normalizes_to_candidate(
+                    param_env,
+                    alias_lhs,
+                    rhs,
+                    direction,
+                    Invert::No,
+                ));
+                // RHS normalizes-to RHS
+                candidates.extend(self.assemble_normalizes_to_candidate(
+                    param_env,
+                    alias_rhs,
+                    lhs,
+                    direction,
+                    Invert::Yes,
+                ));
+                // Relate via substs
+                let subst_relate_response = self
+                    .assemble_subst_relate_candidate(param_env, alias_lhs, alias_rhs, direction);
+                candidates.extend(subst_relate_response);
+                debug!(?candidates);
+
+                if let Some(merged) = self.try_merge_responses(&candidates) {
+                    Ok(merged)
+                } else {
+                    // When relating two aliases and we have ambiguity, we prefer
+                    // relating the generic arguments of the aliases over normalizing
+                    // them. This is necessary for inference during typeck.
+                    //
+                    // As this is incomplete, we must not do so during coherence.
+                    match (self.solver_mode(), subst_relate_response) {
+                        (SolverMode::Normal, Ok(response)) => Ok(response),
+                        (SolverMode::Normal, Err(NoSolution)) | (SolverMode::Coherence, _) => {
+                            self.flounder(&candidates)
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    #[instrument(level = "debug", skip(self), ret)]
+    fn assemble_normalizes_to_candidate(
+        &mut self,
+        param_env: ty::ParamEnv<'tcx>,
+        alias: ty::AliasTy<'tcx>,
+        other: ty::Term<'tcx>,
+        direction: ty::AliasRelationDirection,
+        invert: Invert,
+    ) -> QueryResult<'tcx> {
+        self.probe(|ecx| {
+            let other = match direction {
+                // This is purely an optimization.
+                ty::AliasRelationDirection::Equate => other,
+
+                ty::AliasRelationDirection::Subtype => {
+                    let fresh = ecx.next_term_infer_of_kind(other);
+                    let (sub, sup) = match invert {
+                        Invert::No => (fresh, other),
+                        Invert::Yes => (other, fresh),
+                    };
+                    ecx.sub(param_env, sub, sup)?;
+                    fresh
+                }
+            };
+            ecx.add_goal(Goal::new(
+                ecx.tcx(),
+                param_env,
+                ty::Binder::dummy(ty::ProjectionPredicate { projection_ty: alias, term: other }),
+            ));
+            ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
+        })
+    }
+
+    fn assemble_subst_relate_candidate(
+        &mut self,
+        param_env: ty::ParamEnv<'tcx>,
+        alias_lhs: ty::AliasTy<'tcx>,
+        alias_rhs: ty::AliasTy<'tcx>,
+        direction: ty::AliasRelationDirection,
+    ) -> QueryResult<'tcx> {
+        self.probe(|ecx| {
+            match direction {
+                ty::AliasRelationDirection::Equate => {
+                    ecx.eq(param_env, alias_lhs, alias_rhs)?;
+                }
+                ty::AliasRelationDirection::Subtype => {
+                    ecx.sub(param_env, alias_lhs, alias_rhs)?;
+                }
+            }
+
+            ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
+        })
+    }
+}
diff --git a/compiler/rustc_trait_selection/src/solve/mod.rs b/compiler/rustc_trait_selection/src/solve/mod.rs
index 56a254d9c07e1..f4c29c837b884 100644
--- a/compiler/rustc_trait_selection/src/solve/mod.rs
+++ b/compiler/rustc_trait_selection/src/solve/mod.rs
@@ -20,6 +20,7 @@ use rustc_middle::ty::{
     CoercePredicate, RegionOutlivesPredicate, SubtypePredicate, TypeOutlivesPredicate,
 };
 
+mod alias_relate;
 mod assembly;
 mod canonicalize;
 mod eval_ctxt;
@@ -154,142 +155,6 @@ impl<'a, 'tcx> EvalCtxt<'a, 'tcx> {
         }
     }
 
-    #[instrument(level = "debug", skip(self), ret)]
-    fn compute_alias_relate_goal(
-        &mut self,
-        goal: Goal<'tcx, (ty::Term<'tcx>, ty::Term<'tcx>, ty::AliasRelationDirection)>,
-    ) -> QueryResult<'tcx> {
-        let tcx = self.tcx();
-        // We may need to invert the alias relation direction if dealing an alias on the RHS.
-        #[derive(Debug)]
-        enum Invert {
-            No,
-            Yes,
-        }
-        let evaluate_normalizes_to =
-            |ecx: &mut EvalCtxt<'_, 'tcx>, alias, other, direction, invert| {
-                let span = tracing::span!(
-                    tracing::Level::DEBUG,
-                    "compute_alias_relate_goal(evaluate_normalizes_to)",
-                    ?alias,
-                    ?other,
-                    ?direction,
-                    ?invert
-                );
-                let _enter = span.enter();
-                let result = ecx.probe(|ecx| {
-                    let other = match direction {
-                        // This is purely an optimization.
-                        ty::AliasRelationDirection::Equate => other,
-
-                        ty::AliasRelationDirection::Subtype => {
-                            let fresh = ecx.next_term_infer_of_kind(other);
-                            let (sub, sup) = match invert {
-                                Invert::No => (fresh, other),
-                                Invert::Yes => (other, fresh),
-                            };
-                            ecx.sub(goal.param_env, sub, sup)?;
-                            fresh
-                        }
-                    };
-                    ecx.add_goal(goal.with(
-                        tcx,
-                        ty::Binder::dummy(ty::ProjectionPredicate {
-                            projection_ty: alias,
-                            term: other,
-                        }),
-                    ));
-                    ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
-                });
-                debug!(?result);
-                result
-            };
-
-        let (lhs, rhs, direction) = goal.predicate;
-
-        if lhs.is_infer() || rhs.is_infer() {
-            bug!(
-                "`AliasRelate` goal with an infer var on lhs or rhs which should have been instantiated"
-            );
-        }
-
-        match (lhs.to_alias_ty(tcx), rhs.to_alias_ty(tcx)) {
-            (None, None) => bug!("`AliasRelate` goal without an alias on either lhs or rhs"),
-
-            // RHS is not a projection, only way this is true is if LHS normalizes-to RHS
-            (Some(alias_lhs), None) => {
-                evaluate_normalizes_to(self, alias_lhs, rhs, direction, Invert::No)
-            }
-
-            // LHS is not a projection, only way this is true is if RHS normalizes-to LHS
-            (None, Some(alias_rhs)) => {
-                evaluate_normalizes_to(self, alias_rhs, lhs, direction, Invert::Yes)
-            }
-
-            (Some(alias_lhs), Some(alias_rhs)) => {
-                debug!("both sides are aliases");
-
-                let mut candidates = Vec::new();
-                // LHS normalizes-to RHS
-                candidates.extend(evaluate_normalizes_to(
-                    self,
-                    alias_lhs,
-                    rhs,
-                    direction,
-                    Invert::No,
-                ));
-                // RHS normalizes-to RHS
-                candidates.extend(evaluate_normalizes_to(
-                    self,
-                    alias_rhs,
-                    lhs,
-                    direction,
-                    Invert::Yes,
-                ));
-                // Relate via substs
-                let subst_relate_response = self.probe(|ecx| {
-                    let span = tracing::span!(
-                        tracing::Level::DEBUG,
-                        "compute_alias_relate_goal(relate_via_substs)",
-                        ?alias_lhs,
-                        ?alias_rhs,
-                        ?direction
-                    );
-                    let _enter = span.enter();
-
-                    match direction {
-                        ty::AliasRelationDirection::Equate => {
-                            ecx.eq(goal.param_env, alias_lhs, alias_rhs)?;
-                        }
-                        ty::AliasRelationDirection::Subtype => {
-                            ecx.sub(goal.param_env, alias_lhs, alias_rhs)?;
-                        }
-                    }
-
-                    ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
-                });
-                candidates.extend(subst_relate_response);
-                debug!(?candidates);
-
-                if let Some(merged) = self.try_merge_responses(&candidates) {
-                    Ok(merged)
-                } else {
-                    // When relating two aliases and we have ambiguity, we prefer
-                    // relating the generic arguments of the aliases over normalizing
-                    // them. This is necessary for inference during typeck.
-                    //
-                    // As this is incomplete, we must not do so during coherence.
-                    match (self.solver_mode(), subst_relate_response) {
-                        (SolverMode::Normal, Ok(response)) => Ok(response),
-                        (SolverMode::Normal, Err(NoSolution)) | (SolverMode::Coherence, _) => {
-                            self.flounder(&candidates)
-                        }
-                    }
-                }
-            }
-        }
-    }
-
     #[instrument(level = "debug", skip(self), ret)]
     fn compute_const_arg_has_type_goal(
         &mut self,

From 3ea7c512bd1587006a1c196df841e9b7ec60fb0b Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Sat, 27 May 2023 19:27:59 +0000
Subject: [PATCH 2/2] Fall back to bidirectional normalizes-to if no subst-eq
 in alias-eq goal

---
 .../src/solve/alias_relate.rs                 | 95 ++++++++++++++-----
 tests/ui/traits/new-solver/tait-eq-proj-2.rs  | 21 ++++
 tests/ui/traits/new-solver/tait-eq-proj.rs    | 35 +++++++
 tests/ui/traits/new-solver/tait-eq-tait.rs    | 17 ++++
 4 files changed, 145 insertions(+), 23 deletions(-)
 create mode 100644 tests/ui/traits/new-solver/tait-eq-proj-2.rs
 create mode 100644 tests/ui/traits/new-solver/tait-eq-proj.rs
 create mode 100644 tests/ui/traits/new-solver/tait-eq-tait.rs

diff --git a/compiler/rustc_trait_selection/src/solve/alias_relate.rs b/compiler/rustc_trait_selection/src/solve/alias_relate.rs
index 59ab3bc3a727b..66a4d36a1e5a7 100644
--- a/compiler/rustc_trait_selection/src/solve/alias_relate.rs
+++ b/compiler/rustc_trait_selection/src/solve/alias_relate.rs
@@ -79,11 +79,21 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
                     // them. This is necessary for inference during typeck.
                     //
                     // As this is incomplete, we must not do so during coherence.
-                    match (self.solver_mode(), subst_relate_response) {
-                        (SolverMode::Normal, Ok(response)) => Ok(response),
-                        (SolverMode::Normal, Err(NoSolution)) | (SolverMode::Coherence, _) => {
-                            self.flounder(&candidates)
+                    match self.solver_mode() {
+                        SolverMode::Normal => {
+                            if let Ok(subst_relate_response) = subst_relate_response {
+                                Ok(subst_relate_response)
+                            } else if let Ok(bidirectional_normalizes_to_response) = self
+                                .assemble_bidirectional_normalizes_to_candidate(
+                                    param_env, lhs, rhs, direction,
+                                )
+                            {
+                                Ok(bidirectional_normalizes_to_response)
+                            } else {
+                                self.flounder(&candidates)
+                            }
                         }
+                        SolverMode::Coherence => self.flounder(&candidates),
                     }
                 }
             }
@@ -100,29 +110,42 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
         invert: Invert,
     ) -> QueryResult<'tcx> {
         self.probe(|ecx| {
-            let other = match direction {
-                // This is purely an optimization.
-                ty::AliasRelationDirection::Equate => other,
-
-                ty::AliasRelationDirection::Subtype => {
-                    let fresh = ecx.next_term_infer_of_kind(other);
-                    let (sub, sup) = match invert {
-                        Invert::No => (fresh, other),
-                        Invert::Yes => (other, fresh),
-                    };
-                    ecx.sub(param_env, sub, sup)?;
-                    fresh
-                }
-            };
-            ecx.add_goal(Goal::new(
-                ecx.tcx(),
-                param_env,
-                ty::Binder::dummy(ty::ProjectionPredicate { projection_ty: alias, term: other }),
-            ));
+            ecx.normalizes_to_inner(param_env, alias, other, direction, invert)?;
             ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
         })
     }
 
+    fn normalizes_to_inner(
+        &mut self,
+        param_env: ty::ParamEnv<'tcx>,
+        alias: ty::AliasTy<'tcx>,
+        other: ty::Term<'tcx>,
+        direction: ty::AliasRelationDirection,
+        invert: Invert,
+    ) -> Result<(), NoSolution> {
+        let other = match direction {
+            // This is purely an optimization.
+            ty::AliasRelationDirection::Equate => other,
+
+            ty::AliasRelationDirection::Subtype => {
+                let fresh = self.next_term_infer_of_kind(other);
+                let (sub, sup) = match invert {
+                    Invert::No => (fresh, other),
+                    Invert::Yes => (other, fresh),
+                };
+                self.sub(param_env, sub, sup)?;
+                fresh
+            }
+        };
+        self.add_goal(Goal::new(
+            self.tcx(),
+            param_env,
+            ty::Binder::dummy(ty::ProjectionPredicate { projection_ty: alias, term: other }),
+        ));
+
+        Ok(())
+    }
+
     fn assemble_subst_relate_candidate(
         &mut self,
         param_env: ty::ParamEnv<'tcx>,
@@ -143,4 +166,30 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
             ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
         })
     }
+
+    fn assemble_bidirectional_normalizes_to_candidate(
+        &mut self,
+        param_env: ty::ParamEnv<'tcx>,
+        lhs: ty::Term<'tcx>,
+        rhs: ty::Term<'tcx>,
+        direction: ty::AliasRelationDirection,
+    ) -> QueryResult<'tcx> {
+        self.probe(|ecx| {
+            ecx.normalizes_to_inner(
+                param_env,
+                lhs.to_alias_ty(ecx.tcx()).unwrap(),
+                rhs,
+                direction,
+                Invert::No,
+            )?;
+            ecx.normalizes_to_inner(
+                param_env,
+                rhs.to_alias_ty(ecx.tcx()).unwrap(),
+                lhs,
+                direction,
+                Invert::Yes,
+            )?;
+            ecx.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
+        })
+    }
 }
diff --git a/tests/ui/traits/new-solver/tait-eq-proj-2.rs b/tests/ui/traits/new-solver/tait-eq-proj-2.rs
new file mode 100644
index 0000000000000..99a3d02bd1aea
--- /dev/null
+++ b/tests/ui/traits/new-solver/tait-eq-proj-2.rs
@@ -0,0 +1,21 @@
+// compile-flags: -Ztrait-solver=next
+// check-pass
+
+#![feature(type_alias_impl_trait)]
+
+// Similar to tests/ui/traits/new-solver/tait-eq-proj.rs
+// but check the alias-sub relation in the other direction.
+
+type Tait = impl Iterator<Item = impl Sized>;
+
+fn mk<T>() -> T { todo!() }
+
+fn a() {
+    let x: Tait = mk();
+    let mut array = mk();
+    let mut z = IntoIterator::into_iter(array);
+    z = x;
+    array = [0i32; 32];
+}
+
+fn main() {}
diff --git a/tests/ui/traits/new-solver/tait-eq-proj.rs b/tests/ui/traits/new-solver/tait-eq-proj.rs
new file mode 100644
index 0000000000000..01141b2819a8d
--- /dev/null
+++ b/tests/ui/traits/new-solver/tait-eq-proj.rs
@@ -0,0 +1,35 @@
+// compile-flags: -Ztrait-solver=next
+// check-pass
+
+#![feature(type_alias_impl_trait)]
+
+type Tait = impl Iterator<Item = impl Sized>;
+
+/*
+
+Consider the goal - AliasRelate(Tait, <[i32; 32] as IntoIterator>::IntoIter)
+which is registered on the line above.
+
+A. SubstRelate - fails (of course).
+
+B. NormalizesToRhs - Tait normalizes-to <[i32; 32] as IntoIterator>::IntoIter
+    * infer definition - Tait := <[i32; 32] as IntoIterator>::IntoIter
+
+C. NormalizesToLhs - <[i32; 32] as IntoIterator>::IntoIter normalizes-to Tait
+    * Find impl candidate, after substitute - std::array::IntoIter<i32, 32>
+    * Equate std::array::IntoIter<i32, 32> and Tait
+        * infer definition - Tait := std::array::IntoIter<i32, 32>
+
+B and C are not equal, but they are equivalent modulo normalization.
+
+We get around this by evaluating both the NormalizesToRhs and NormalizesToLhs
+goals together. Essentially:
+    A alias-relate B if A normalizes-to B and B normalizes-to A.
+
+*/
+
+fn a() {
+    let _: Tait = IntoIterator::into_iter([0i32; 32]);
+}
+
+fn main() {}
diff --git a/tests/ui/traits/new-solver/tait-eq-tait.rs b/tests/ui/traits/new-solver/tait-eq-tait.rs
new file mode 100644
index 0000000000000..532c4c39bd499
--- /dev/null
+++ b/tests/ui/traits/new-solver/tait-eq-tait.rs
@@ -0,0 +1,17 @@
+// compile-flags: -Ztrait-solver=next
+// check-pass
+
+// Not exactly sure if this is the inference behavior we *want*,
+// but it is a side-effect of the lazy normalization of TAITs.
+
+#![feature(type_alias_impl_trait)]
+
+type Tait = impl Sized;
+type Tait2 = impl Sized;
+
+fn mk<T>() -> T { todo!() }
+
+fn main() {
+    let x: Tait = 1u32;
+    let y: Tait2 = x;
+}