From 064e0c6264e2019f68d5658cfbefe2c6872e5379 Mon Sep 17 00:00:00 2001
From: Doug Gregor <dgregor@apple.com>
Date: Mon, 3 Mar 2025 17:41:04 -0800
Subject: [PATCH 1/2] Ensure that isolated conformances originate in the same
 isolation domain

This is the missing check for "rule #1" in the isolated conformances proposal,
which states that an isolated conformance can only be referenced within
the same isolation domain as the conformance. For example, a
main-actor-isolated conformance can only be used within main-actor code.
---
 include/swift/AST/DiagnosticsSema.def       |   3 +
 lib/Sema/TypeCheckConcurrency.cpp           | 112 ++++++++++++++++++++
 lib/Sema/TypeCheckConcurrency.h             |  31 ++++++
 lib/Sema/TypeCheckProtocol.cpp              | 104 ++++++++++++++++++
 lib/Sema/TypeCheckProtocol.h                |  32 ++++++
 test/Concurrency/isolated_conformance.swift |   6 ++
 6 files changed, 288 insertions(+)

diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def
index 500a5b51ff544..835d38359f09a 100644
--- a/include/swift/AST/DiagnosticsSema.def
+++ b/include/swift/AST/DiagnosticsSema.def
@@ -8320,6 +8320,9 @@ ERROR(isolated_conformance_with_sendable_simple,none,
       "isolated conformance of %0 to %1 cannot be used to satisfy conformance "
       "requirement for a `Sendable` type parameter ",
       (Type, DeclName))
+ERROR(isolated_conformance_wrong_domain,none,
+      "%0 isolated conformance of %1 to %2 cannot be used in %3 context",
+      (ActorIsolation, Type, DeclName, ActorIsolation))
 
 //===----------------------------------------------------------------------===//
 // MARK: @execution Attribute
diff --git a/lib/Sema/TypeCheckConcurrency.cpp b/lib/Sema/TypeCheckConcurrency.cpp
index fcee92580b536..2a76c83a0ac3b 100644
--- a/lib/Sema/TypeCheckConcurrency.cpp
+++ b/lib/Sema/TypeCheckConcurrency.cpp
@@ -18,6 +18,7 @@
 #include "MiscDiagnostics.h"
 #include "TypeCheckDistributed.h"
 #include "TypeCheckInvertible.h"
+#include "TypeCheckProtocol.h"
 #include "TypeCheckType.h"
 #include "TypeChecker.h"
 #include "swift/AST/ASTWalker.h"
@@ -3175,6 +3176,18 @@ namespace {
         checkDefaultArgument(defaultArg);
       }
 
+      if (auto erasureExpr = dyn_cast<ErasureExpr>(expr)) {
+        checkIsolatedConformancesInContext(
+            erasureExpr->getConformances(), erasureExpr->getLoc(),
+            getDeclContext());
+      }
+
+      if (auto *underlyingToOpaque = dyn_cast<UnderlyingToOpaqueExpr>(expr)) {
+        checkIsolatedConformancesInContext(
+            underlyingToOpaque->substitutions, underlyingToOpaque->getLoc(),
+            getDeclContext());
+      }
+
       return Action::Continue(expr);
     }
 
@@ -4282,6 +4295,9 @@ namespace {
       if (!declRef)
         return false;
 
+      // Make sure isolated conformances are formed in the right context.
+      checkIsolatedConformancesInContext(declRef, loc, getDeclContext());
+
       auto decl = declRef.getDecl();
 
       // If this declaration is a callee from the enclosing application,
@@ -7684,3 +7700,99 @@ ActorIsolation swift::getConformanceIsolation(ProtocolConformance *conformance)
 
   return getActorIsolation(nominal);
 }
+
+namespace {
+  /// Identifies isolated conformances whose isolation differs from the
+  /// context's isolation.
+  class MismatchedIsolatedConformances {
+    llvm::TinyPtrVector<ProtocolConformance *> badIsolatedConformances;
+    DeclContext *fromDC;
+    mutable std::optional<ActorIsolation> fromIsolation;
+
+  public:
+    MismatchedIsolatedConformances(const DeclContext *fromDC)
+      : fromDC(const_cast<DeclContext *>(fromDC)) { }
+
+    ActorIsolation getContextIsolation() const {
+      if (!fromIsolation)
+        fromIsolation = getActorIsolationOfContext(fromDC);
+
+      return *fromIsolation;
+    }
+
+    ArrayRef<ProtocolConformance *> getBadIsolatedConformances() const {
+      return badIsolatedConformances;
+    }
+
+    explicit operator bool() const { return !badIsolatedConformances.empty(); }
+
+    bool operator()(ProtocolConformanceRef conformance) {
+      if (conformance.isAbstract() || conformance.isPack())
+        return false;
+
+      auto concrete = conformance.getConcrete();
+      auto normal = dyn_cast<NormalProtocolConformance>(
+          concrete->getRootConformance());
+      if (!normal)
+        return false;
+
+      if (!normal->isIsolated())
+        return false;
+
+      auto conformanceIsolation = getConformanceIsolation(concrete);
+      if (conformanceIsolation == getContextIsolation())
+        return true;
+
+      badIsolatedConformances.push_back(concrete);
+      return false;
+    }
+
+    /// If there were any bad isolated conformances, diagnose them and return
+    /// true. Otherwise, returns false.
+    bool diagnose(SourceLoc loc) const {
+      if (badIsolatedConformances.empty())
+        return false;
+
+      ASTContext &ctx = fromDC->getASTContext();
+      auto firstConformance = badIsolatedConformances.front();
+      ctx.Diags.diagnose(
+          loc, diag::isolated_conformance_wrong_domain,
+          getConformanceIsolation(firstConformance),
+          firstConformance->getType(),
+          firstConformance->getProtocol()->getName(),
+          getContextIsolation());
+      return true;
+    }
+  };
+
+}
+
+bool swift::checkIsolatedConformancesInContext(
+    ConcreteDeclRef declRef, SourceLoc loc, const DeclContext *dc) {
+  MismatchedIsolatedConformances mismatched(dc);
+  forEachConformance(declRef, mismatched);
+  return mismatched.diagnose(loc);
+}
+
+bool swift::checkIsolatedConformancesInContext(
+    ArrayRef<ProtocolConformanceRef> conformances, SourceLoc loc,
+    const DeclContext *dc) {
+  MismatchedIsolatedConformances mismatched(dc);
+  for (auto conformance: conformances)
+    forEachConformance(conformance, mismatched);
+  return mismatched.diagnose(loc);
+}
+
+bool swift::checkIsolatedConformancesInContext(
+    SubstitutionMap subs, SourceLoc loc, const DeclContext *dc) {
+  MismatchedIsolatedConformances mismatched(dc);
+  forEachConformance(subs, mismatched);
+  return mismatched.diagnose(loc);
+}
+
+bool swift::checkIsolatedConformancesInContext(
+    Type type, SourceLoc loc, const DeclContext *dc) {
+  MismatchedIsolatedConformances mismatched(dc);
+  forEachConformance(type, mismatched);
+  return mismatched.diagnose(loc);
+}
diff --git a/lib/Sema/TypeCheckConcurrency.h b/lib/Sema/TypeCheckConcurrency.h
index d00eb95870b24..0eea9b1ebdae4 100644
--- a/lib/Sema/TypeCheckConcurrency.h
+++ b/lib/Sema/TypeCheckConcurrency.h
@@ -703,6 +703,37 @@ void introduceUnsafeInheritExecutorReplacements(
 /// the immediate conformance, not any conformances on which it depends.
 ActorIsolation getConformanceIsolation(ProtocolConformance *conformance);
 
+/// Check for correct use of isolated conformances in the given reference.
+///
+/// This checks that any isolated conformances that occur in the given
+/// declaration reference match the isolated of the context.
+bool checkIsolatedConformancesInContext(
+    ConcreteDeclRef declRef, SourceLoc loc, const DeclContext *dc);
+
+/// Check for correct use of isolated conformances in the set given set of
+/// protocol conformances.
+///
+/// This checks that any isolated conformances that occur in the given
+/// declaration reference match the isolated of the context.
+bool checkIsolatedConformancesInContext(
+    ArrayRef<ProtocolConformanceRef> conformances, SourceLoc loc,
+    const DeclContext *dc);
+
+/// Check for correct use of isolated conformances in the given substitution
+/// map.
+///
+/// This checks that any isolated conformances that occur in the given
+/// substitution map match the isolated of the context.
+bool checkIsolatedConformancesInContext(
+    SubstitutionMap subs, SourceLoc loc, const DeclContext *dc);
+
+/// Check for correct use of isolated conformances in the given type.
+///
+/// This checks that any isolated conformances that occur in the given
+/// type match the isolated of the context.
+bool checkIsolatedConformancesInContext(
+    Type type, SourceLoc loc, const DeclContext *dc);
+
 } // end namespace swift
 
 namespace llvm {
diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp
index 0e17cc2d634a0..cfd06a93d4275 100644
--- a/lib/Sema/TypeCheckProtocol.cpp
+++ b/lib/Sema/TypeCheckProtocol.cpp
@@ -43,6 +43,7 @@
 #include "swift/AST/GenericSignature.h"
 #include "swift/AST/NameLookup.h"
 #include "swift/AST/NameLookupRequests.h"
+#include "swift/AST/PackConformance.h"
 #include "swift/AST/ParameterList.h"
 #include "swift/AST/PotentialMacroExpansions.h"
 #include "swift/AST/PrettyStackTrace.h"
@@ -7159,3 +7160,106 @@ void TypeChecker::inferDefaultWitnesses(ProtocolDecl *proto) {
         req.getFirstType()->getCanonicalType(), requirementProto, conformance);
   }
 }
+
+bool swift::forEachConformance(
+    SubstitutionMap subs,
+    llvm::function_ref<bool(ProtocolConformanceRef)> body) {
+  if (!subs)
+    return false;
+
+  for (auto type: subs.getReplacementTypes()) {
+    if (forEachConformance(type, body))
+      return true;
+  }
+
+  for (auto conformance: subs.getConformances()) {
+    if (forEachConformance(conformance, body))
+      return true;
+  }
+
+  return false;
+}
+
+bool swift::forEachConformance(
+    ProtocolConformanceRef conformance,
+    llvm::function_ref<bool(ProtocolConformanceRef)> body) {
+  // Visit this conformance.
+  if (body(conformance))
+    return true;
+
+  if (conformance.isInvalid() || conformance.isAbstract())
+    return false;
+
+  if (conformance.isPack()) {
+    auto pack = conformance.getPack()->getPatternConformances();
+    for (auto conformance : pack) {
+      if (forEachConformance(conformance, body))
+        return true;
+    }
+
+    return false;
+  }
+
+  // Check the substitution make within this conformance.
+  auto concrete = conformance.getConcrete();
+  if (forEachConformance(concrete->getSubstitutionMap(), body))
+    return true;
+
+
+  return false;
+}
+
+bool swift::forEachConformance(
+    Type type, llvm::function_ref<bool(ProtocolConformanceRef)> body) {
+  return type.findIf([&](Type type) {
+    if (auto typeAlias = dyn_cast<TypeAliasType>(type.getPointer())) {
+      if (forEachConformance(typeAlias->getSubstitutionMap(), body))
+        return true;
+
+      return false;
+    }
+
+    if (auto opaqueArchetype =
+            dyn_cast<OpaqueTypeArchetypeType>(type.getPointer())) {
+      if (forEachConformance(opaqueArchetype->getSubstitutions(), body))
+        return true;
+
+      return false;
+    }
+
+    // Look through type sugar.
+    if (auto sugarType = dyn_cast<SyntaxSugarType>(type.getPointer())) {
+      type = sugarType->getImplementationType();
+    }
+
+    if (auto boundGeneric = dyn_cast<BoundGenericType>(type.getPointer())) {
+      auto subs = boundGeneric->getContextSubstitutionMap();
+      if (forEachConformance(subs, body))
+        return true;
+
+      return false;
+    }
+
+    return false;
+  });
+}
+
+bool swift::forEachConformance(
+    ConcreteDeclRef declRef,
+    llvm::function_ref<bool(ProtocolConformanceRef)> body) {
+  if (!declRef)
+    return false;
+
+  Type type = declRef.getDecl()->getInterfaceType();
+  if (auto subs = declRef.getSubstitutions()) {
+    if (forEachConformance(subs, body))
+      return true;
+
+    type = type.subst(subs);
+  }
+
+  if (forEachConformance(type, body))
+    return true;
+
+  return false;
+}
diff --git a/lib/Sema/TypeCheckProtocol.h b/lib/Sema/TypeCheckProtocol.h
index 0e9360b2d3cdf..2291fa25fa076 100644
--- a/lib/Sema/TypeCheckProtocol.h
+++ b/lib/Sema/TypeCheckProtocol.h
@@ -240,6 +240,38 @@ bool witnessHasImplementsAttrForRequiredName(ValueDecl *witness,
 bool witnessHasImplementsAttrForExactRequirement(ValueDecl *witness,
                                                  ValueDecl *requirement);
 
+/// Visit each conformance within the given type.
+///
+/// If `body` returns true for any conformance, this function stops the
+/// traversal and returns true.
+bool forEachConformance(
+    Type type, llvm::function_ref<bool(ProtocolConformanceRef)> body);
+
+/// Visit each conformance within the given conformance (including the given
+/// one).
+///
+/// If `body` returns true for any conformance, this function stops the
+/// traversal and returns true.
+bool forEachConformance(
+    ProtocolConformanceRef conformance,
+    llvm::function_ref<bool(ProtocolConformanceRef)> body);
+
+/// Visit each conformance within the given substitution map.
+///
+/// If `body` returns true for any conformance, this function stops the
+/// traversal and returns true.
+bool forEachConformance(
+    SubstitutionMap subs,
+    llvm::function_ref<bool(ProtocolConformanceRef)> body);
+
+/// Visit each conformance within the given declaration reference.
+///
+/// If `body` returns true for any conformance, this function stops the
+/// traversal and returns true.
+bool forEachConformance(
+    ConcreteDeclRef declRef,
+    llvm::function_ref<bool(ProtocolConformanceRef)> body);
+
 }
 
 #endif // SWIFT_SEMA_PROTOCOL_H
diff --git a/test/Concurrency/isolated_conformance.swift b/test/Concurrency/isolated_conformance.swift
index f8aeea1b4ede6..872fb4c02a5e7 100644
--- a/test/Concurrency/isolated_conformance.swift
+++ b/test/Concurrency/isolated_conformance.swift
@@ -119,3 +119,9 @@ func testIsolationConformancesInCall(c: C) {
   acceptSendableP(c) // expected-error{{isolated conformance of 'C' to 'P' cannot be used to satisfy conformance requirement for a `Sendable` type parameter}}
   acceptSendableMetaP(c) // expected-error{{isolated conformance of 'C' to 'P' cannot be used to satisfy conformance requirement for a `Sendable` type parameter}}
 }
+
+func testIsolationConformancesFromOutside(c: C) {
+  acceptP(c) // expected-error{{main actor-isolated isolated conformance of 'C' to 'P' cannot be used in nonisolated context}}
+  let _: any P = c // expected-error{{main actor-isolated isolated conformance of 'C' to 'P' cannot be used in nonisolated context}}
+  let _ = PWrapper<C>() // expected-error{{main actor-isolated isolated conformance of 'C' to 'P' cannot be used in nonisolated context}}
+}

From 5c67cffbc0e171ce6fc911950fcce84b45a92e6d Mon Sep 17 00:00:00 2001
From: Doug Gregor <dgregor@apple.com>
Date: Mon, 3 Mar 2025 22:14:54 -0800
Subject: [PATCH 2/2] Prevent infinite recursion with conformance enumeration.

---
 lib/Sema/TypeCheckProtocol.cpp | 69 +++++++++++++++++++++++++---------
 lib/Sema/TypeCheckProtocol.h   | 14 +++++--
 2 files changed, 62 insertions(+), 21 deletions(-)

diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp
index cfd06a93d4275..4137133059fa6 100644
--- a/lib/Sema/TypeCheckProtocol.cpp
+++ b/lib/Sema/TypeCheckProtocol.cpp
@@ -7163,17 +7163,22 @@ void TypeChecker::inferDefaultWitnesses(ProtocolDecl *proto) {
 
 bool swift::forEachConformance(
     SubstitutionMap subs,
-    llvm::function_ref<bool(ProtocolConformanceRef)> body) {
+    llvm::function_ref<bool(ProtocolConformanceRef)> body,
+    VisitedConformances *visitedConformances) {
   if (!subs)
     return false;
 
+  VisitedConformances visited;
+  if (!visitedConformances)
+    visitedConformances = &visited;
+
   for (auto type: subs.getReplacementTypes()) {
-    if (forEachConformance(type, body))
+    if (forEachConformance(type, body, visitedConformances))
       return true;
   }
 
   for (auto conformance: subs.getConformances()) {
-    if (forEachConformance(conformance, body))
+    if (forEachConformance(conformance, body, visitedConformances))
       return true;
   }
 
@@ -7182,10 +7187,12 @@ bool swift::forEachConformance(
 
 bool swift::forEachConformance(
     ProtocolConformanceRef conformance,
-    llvm::function_ref<bool(ProtocolConformanceRef)> body) {
-  // Visit this conformance.
-  if (body(conformance))
-    return true;
+    llvm::function_ref<bool(ProtocolConformanceRef)> body,
+    VisitedConformances *visitedConformances) {
+  // Make sure we can store visited conformances.
+  VisitedConformances visited;
+  if (!visitedConformances)
+    visitedConformances = &visited;
 
   if (conformance.isInvalid() || conformance.isAbstract())
     return false;
@@ -7193,27 +7200,48 @@ bool swift::forEachConformance(
   if (conformance.isPack()) {
     auto pack = conformance.getPack()->getPatternConformances();
     for (auto conformance : pack) {
-      if (forEachConformance(conformance, body))
+      if (forEachConformance(conformance, body, visitedConformances))
         return true;
     }
 
     return false;
   }
 
-  // Check the substitution make within this conformance.
+  // Extract the concrete conformance.
   auto concrete = conformance.getConcrete();
-  if (forEachConformance(concrete->getSubstitutionMap(), body))
+
+  // Prevent recursion.
+  if (!visitedConformances->insert(concrete).second)
+    return false;
+
+  // Visit this conformance.
+  if (body(conformance))
     return true;
 
+  // Check the substitution map within this conformance.
+  if (forEachConformance(concrete->getSubstitutionMap(), body,
+                         visitedConformances))
+    return true;
 
   return false;
 }
 
 bool swift::forEachConformance(
-    Type type, llvm::function_ref<bool(ProtocolConformanceRef)> body) {
+    Type type, llvm::function_ref<bool(ProtocolConformanceRef)> body,
+    VisitedConformances *visitedConformances) {
+  // Make sure we can store visited conformances.
+  VisitedConformances visited;
+  if (!visitedConformances)
+    visitedConformances = &visited;
+
+  // Prevent recursion.
+  if (!visitedConformances->insert(type.getPointer()).second)
+    return false;
+
   return type.findIf([&](Type type) {
     if (auto typeAlias = dyn_cast<TypeAliasType>(type.getPointer())) {
-      if (forEachConformance(typeAlias->getSubstitutionMap(), body))
+      if (forEachConformance(typeAlias->getSubstitutionMap(), body,
+                             visitedConformances))
         return true;
 
       return false;
@@ -7221,7 +7249,8 @@ bool swift::forEachConformance(
 
     if (auto opaqueArchetype =
             dyn_cast<OpaqueTypeArchetypeType>(type.getPointer())) {
-      if (forEachConformance(opaqueArchetype->getSubstitutions(), body))
+      if (forEachConformance(opaqueArchetype->getSubstitutions(), body,
+                             visitedConformances))
         return true;
 
       return false;
@@ -7234,7 +7263,7 @@ bool swift::forEachConformance(
 
     if (auto boundGeneric = dyn_cast<BoundGenericType>(type.getPointer())) {
       auto subs = boundGeneric->getContextSubstitutionMap();
-      if (forEachConformance(subs, body))
+      if (forEachConformance(subs, body, visitedConformances))
         return true;
 
       return false;
@@ -7246,19 +7275,25 @@ bool swift::forEachConformance(
 
 bool swift::forEachConformance(
     ConcreteDeclRef declRef,
-    llvm::function_ref<bool(ProtocolConformanceRef)> body) {
+    llvm::function_ref<bool(ProtocolConformanceRef)> body,
+    VisitedConformances *visitedConformances) {
   if (!declRef)
     return false;
 
+  // Make sure we can store visited conformances.
+  VisitedConformances visited;
+  if (!visitedConformances)
+    visitedConformances = &visited;
+
   Type type = declRef.getDecl()->getInterfaceType();
   if (auto subs = declRef.getSubstitutions()) {
-    if (forEachConformance(subs, body))
+    if (forEachConformance(subs, body, visitedConformances))
       return true;
 
     type = type.subst(subs);
   }
 
-  if (forEachConformance(type, body))
+  if (forEachConformance(type, body, visitedConformances))
     return true;
 
   return false;
diff --git a/lib/Sema/TypeCheckProtocol.h b/lib/Sema/TypeCheckProtocol.h
index 2291fa25fa076..f71caa8ded30e 100644
--- a/lib/Sema/TypeCheckProtocol.h
+++ b/lib/Sema/TypeCheckProtocol.h
@@ -240,12 +240,15 @@ bool witnessHasImplementsAttrForRequiredName(ValueDecl *witness,
 bool witnessHasImplementsAttrForExactRequirement(ValueDecl *witness,
                                                  ValueDecl *requirement);
 
+using VisitedConformances = llvm::SmallPtrSet<void *, 16>;
+
 /// Visit each conformance within the given type.
 ///
 /// If `body` returns true for any conformance, this function stops the
 /// traversal and returns true.
 bool forEachConformance(
-    Type type, llvm::function_ref<bool(ProtocolConformanceRef)> body);
+    Type type, llvm::function_ref<bool(ProtocolConformanceRef)> body,
+    VisitedConformances *visitedConformances = nullptr);
 
 /// Visit each conformance within the given conformance (including the given
 /// one).
@@ -254,7 +257,8 @@ bool forEachConformance(
 /// traversal and returns true.
 bool forEachConformance(
     ProtocolConformanceRef conformance,
-    llvm::function_ref<bool(ProtocolConformanceRef)> body);
+    llvm::function_ref<bool(ProtocolConformanceRef)> body,
+    VisitedConformances *visitedConformances = nullptr);
 
 /// Visit each conformance within the given substitution map.
 ///
@@ -262,7 +266,8 @@ bool forEachConformance(
 /// traversal and returns true.
 bool forEachConformance(
     SubstitutionMap subs,
-    llvm::function_ref<bool(ProtocolConformanceRef)> body);
+    llvm::function_ref<bool(ProtocolConformanceRef)> body,
+    VisitedConformances *visitedConformances = nullptr);
 
 /// Visit each conformance within the given declaration reference.
 ///
@@ -270,7 +275,8 @@ bool forEachConformance(
 /// traversal and returns true.
 bool forEachConformance(
     ConcreteDeclRef declRef,
-    llvm::function_ref<bool(ProtocolConformanceRef)> body);
+    llvm::function_ref<bool(ProtocolConformanceRef)> body,
+    VisitedConformances *visitedConformances = nullptr);
 
 }