From d71f7be21812203f343285e40456975ff5cb8ae9 Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Tue, 26 Dec 2023 16:42:40 +0000
Subject: [PATCH] Couple of random coroutine pass simplifications

---
 compiler/rustc_mir_transform/src/coroutine.rs | 35 +++++++------------
 1 file changed, 13 insertions(+), 22 deletions(-)

diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs
index 5e434d30b003b..ce1a36cf67021 100644
--- a/compiler/rustc_mir_transform/src/coroutine.rs
+++ b/compiler/rustc_mir_transform/src/coroutine.rs
@@ -1417,20 +1417,18 @@ fn create_coroutine_resume_function<'tcx>(
     cases.insert(0, (UNRESUMED, START_BLOCK));
 
     // Panic when resumed on the returned or poisoned state
-    let coroutine_kind = body.coroutine_kind().unwrap();
-
     if can_unwind {
         cases.insert(
             1,
-            (POISONED, insert_panic_block(tcx, body, ResumedAfterPanic(coroutine_kind))),
+            (POISONED, insert_panic_block(tcx, body, ResumedAfterPanic(transform.coroutine_kind))),
         );
     }
 
     if can_return {
-        let block = match coroutine_kind {
+        let block = match transform.coroutine_kind {
             CoroutineKind::Desugared(CoroutineDesugaring::Async, _)
             | CoroutineKind::Coroutine(_) => {
-                insert_panic_block(tcx, body, ResumedAfterReturn(coroutine_kind))
+                insert_panic_block(tcx, body, ResumedAfterReturn(transform.coroutine_kind))
             }
             CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)
             | CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
@@ -1444,7 +1442,7 @@ fn create_coroutine_resume_function<'tcx>(
 
     make_coroutine_state_argument_indirect(tcx, body);
 
-    match coroutine_kind {
+    match transform.coroutine_kind {
         // Iterator::next doesn't accept a pinned argument,
         // unlike for all other coroutine kinds.
         CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
@@ -1614,12 +1612,6 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
             }
         };
 
-        let is_async_kind =
-            matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::Async, _));
-        let is_async_gen_kind =
-            matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _));
-        let is_gen_kind =
-            matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::Gen, _));
         let new_ret_ty = match coroutine_kind {
             CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
                 // Compute Poll<return_ty>
@@ -1653,7 +1645,10 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
         let old_ret_local = replace_local(RETURN_PLACE, new_ret_ty, body, tcx);
 
         // Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
-        if is_async_kind || is_async_gen_kind {
+        if matches!(
+            coroutine_kind,
+            CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _)
+        ) {
             transform_async_context(tcx, body);
         }
 
@@ -1662,11 +1657,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
         // case there is no `Assign` to it that the transform can turn into a store to the coroutine
         // state. After the yield the slot in the coroutine state would then be uninitialized.
         let resume_local = Local::new(2);
-        let resume_ty = if is_async_kind {
-            Ty::new_task_context(tcx)
-        } else {
-            body.local_decls[resume_local].ty
-        };
+        let resume_ty = body.local_decls[resume_local].ty;
         let old_resume_local = replace_local(resume_local, resume_ty, body, tcx);
 
         // When first entering the coroutine, move the resume argument into its old local
@@ -1709,11 +1700,11 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
         // Run the transformation which converts Places from Local to coroutine struct
         // accesses for locals in `remap`.
         // It also rewrites `return x` and `yield y` as writing a new coroutine state and returning
-        // either CoroutineState::Complete(x) and CoroutineState::Yielded(y),
-        // or Poll::Ready(x) and Poll::Pending respectively depending on `is_async_kind`.
+        // either `CoroutineState::Complete(x)` and `CoroutineState::Yielded(y)`,
+        // or `Poll::Ready(x)` and `Poll::Pending` respectively depending on the coroutine kind.
         let mut transform = TransformVisitor {
             tcx,
-            coroutine_kind: body.coroutine_kind().unwrap(),
+            coroutine_kind,
             remap,
             storage_liveness,
             always_live_locals,
@@ -1730,7 +1721,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
         body.spread_arg = None;
 
         // Remove the context argument within generator bodies.
-        if is_gen_kind {
+        if matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) {
             transform_gen_context(tcx, body);
         }