From 11ee82de7b0e2e4990e271eb8e4b0bf96c5b99fb Mon Sep 17 00:00:00 2001 From: Pavel Yaskevich Date: Thu, 20 Jul 2023 13:30:24 -0700 Subject: [PATCH 01/37] [CSBindings] Prefer conjunctions over closure variables without bindings If a closure doesn't have a contextual type inferred yet it should be delayed in favor of already resolved closure conjunction because "resolving" such a closure early could miss result builder attribute attached to a parameter the closure is passed to. Partially resolves https://github.com/apple/swift/issues/67363 --- lib/Sema/CSBindings.cpp | 7 ++ test/Constraints/issue67363.swift | 85 +++++++++++++++++++++++++ test/expr/closure/multi_statement.swift | 24 +++++++ 3 files changed, 116 insertions(+) create mode 100644 test/Constraints/issue67363.swift diff --git a/lib/Sema/CSBindings.cpp b/lib/Sema/CSBindings.cpp index 9547f4c3a06cd..c4c77a0ffa10a 100644 --- a/lib/Sema/CSBindings.cpp +++ b/lib/Sema/CSBindings.cpp @@ -1094,6 +1094,13 @@ bool BindingSet::favoredOverConjunction(Constraint *conjunction) const { if (locator->directlyAt()) { auto *closure = castToExpr(locator->getAnchor()); + // If there are no bindings for the closure yet we cannot prioritize + // it because that runs into risk of missing a result builder transform. + if (TypeVar->getImpl().isClosureType()) { + if (Bindings.empty()) + return false; + } + if (auto transform = CS.getAppliedResultBuilderTransform(closure)) { // Conjunctions that represent closures with result builder transformed // bodies could be attempted right after their resolution if they meet diff --git a/test/Constraints/issue67363.swift b/test/Constraints/issue67363.swift new file mode 100644 index 0000000000000..b767bf646d729 --- /dev/null +++ b/test/Constraints/issue67363.swift @@ -0,0 +1,85 @@ +// RUN: %target-typecheck-verify-swift -disable-availability-checking + +// https://github.com/apple/swift/issues/67363 + +protocol UIView { + init() +} + +class UILabel : UIView { + required init() {} +} + +class UIStackView : UIView { + required init() {} +} + +protocol ViewRepresentable { + associatedtype View: UIView + func configure(view: View) +} + +struct StyledString: ViewRepresentable { + let content: String + func configure(view: UILabel) {} +} + +class StackViewOne: UIStackView { + var first = First() +} + +struct Stack { + struct One: ViewRepresentable { + let first: First + func configure(view: StackViewOne) { + first.configure(view: view.first) + } + } + + @resultBuilder + enum Builder { + static func buildBlock(_ first: First) -> Stack.One { + Stack.One(first: first) + } + } + + static func vertical(@Builder build builder: () -> StackType) -> StackType { + builder() + } +} + +struct ListItem { + let body: any ViewRepresentable +} + +@resultBuilder +enum ListBuilder { + static func buildExpression(_ expression: View?) -> [ListItem?] { + [expression.map { .init(body: $0) }] + } + + static func buildBlock(_ components: [ListItem?]...) -> [ListItem] { + components.flatMap { $0.compactMap { $0 } } + } +} + +struct WithFooter: ViewRepresentable { + let body: T + let footer: () -> [ListItem] + func configure(view: T.View) {} +} + +extension ViewRepresentable { + func withFooter(@ListBuilder build: @escaping () -> [ListItem]) -> WithFooter { + .init(body: self, footer: build) + } +} + +func testThatResultBuilderIsAppliedToWithFooterArgument() -> some ViewRepresentable { + Stack.vertical() { + StyledString(content: "vertical") + } + .withFooter { + StyledString(content: "footer") + } +} diff --git a/test/expr/closure/multi_statement.swift b/test/expr/closure/multi_statement.swift index af4b1dd33c4f4..71b3987832876 100644 --- a/test/expr/closure/multi_statement.swift +++ b/test/expr/closure/multi_statement.swift @@ -694,3 +694,27 @@ func test_recursive_var_reference_in_multistatement_closure() { } } } + +// https://github.com/apple/swift/issues/67363 +func test_result_builder_in_member_chaining() { + @resultBuilder + struct Builder { + static func buildBlock(_: T) -> Int { 42 } + } + + struct Test { + static func test(fn: () -> T) -> T { + fn() + } + + func builder(@Builder _: () -> Int) {} + } + + Test.test { + let test = Test() + return test + }.builder { // Ok + let result = "" + result + } +} From c210a08afb14ed9bb2c02356a50c31ceedbbf67f Mon Sep 17 00:00:00 2001 From: Sam Kortekaas Date: Thu, 29 Jun 2023 12:22:12 +0200 Subject: [PATCH 02/37] [SourceKit] Don't report types for implicit expressions Fixes incorrectly reporting an optional type for an expression when the contextual type is optional. fixes #66882 rdar://111462279 --- lib/IDE/IDETypeChecking.cpp | 9 +++++++++ test/SourceKit/ExpressionType/basic.swift | 11 +++++++++++ 2 files changed, 20 insertions(+) diff --git a/lib/IDE/IDETypeChecking.cpp b/lib/IDE/IDETypeChecking.cpp index cc22ce3737728..8823514c52f7d 100644 --- a/lib/IDE/IDETypeChecking.cpp +++ b/lib/IDE/IDETypeChecking.cpp @@ -667,6 +667,15 @@ class ExpressionTypeCollector: public SourceEntityWalker { if (E->getType().isNull()) return false; + // We should not report a type for implicit expressions, except for + // - `OptionalEvaluationExpr` to show the correct type when there is optional chaining + // - `DotSyntaxCallExpr` to report the method type without the metatype + if (E->isImplicit() && + !isa(E) && + !isa(E)) { + return false; + } + // If we have already reported types for this source range, we shouldn't // report again. This makes sure we always report the outtermost type of // several overlapping expressions. diff --git a/test/SourceKit/ExpressionType/basic.swift b/test/SourceKit/ExpressionType/basic.swift index 97887beb2e1dd..d0c81fb793ea7 100644 --- a/test/SourceKit/ExpressionType/basic.swift +++ b/test/SourceKit/ExpressionType/basic.swift @@ -20,6 +20,13 @@ func DictS(_ a: [Int: S]) { _ = a[2]?.val.advanced(by: 1).byteSwapped } +func optExpr(str: String?) -> String? { + let a: String? = str + let b: String? = "Hey" + let c: String = "Bye" + return c +} + // RUN: %sourcekitd-test -req=collect-type %s -- %s | %FileCheck %s // CHECK: (183, 202): Int // CHECK: (183, 196): String @@ -31,3 +38,7 @@ func DictS(_ a: [Int: S]) { // CHECK: (291, 292): Int? // CHECK: (295, 332): Int? // CHECK: (295, 320): Int +// CHECK: (395, 398): String? +// CHECK: (418, 423): String +// CHECK: (442, 447): String +// CHECK: (457, 458): String From 55892ef30d704f952478ed03cdf48ca664d46208 Mon Sep 17 00:00:00 2001 From: Michael Gottesman Date: Wed, 12 Jul 2023 13:47:22 -0700 Subject: [PATCH 03/37] [silgen] Add a special visitor for accessing the base of noncopyable types. We want these to be borrowed in most cases and to create an appropriate onion wrapping. Since we are doing this in more cases now, we fix a bunch of cases where we used to be forced to insert a copy since a coroutine or access would end too early. --- lib/SILGen/LValue.h | 2 +- lib/SILGen/SILGenLValue.cpp | 133 +++++++++++-- .../Mandatory/MoveOnlyAddressCheckerUtils.cpp | 29 ++- test/SILGen/moveonly.swift | 183 +++++------------- test/SILGen/moveonly_library_evolution.swift | 10 +- ...ly_addressonly_subscript_diagnostics.swift | 4 - ...eonly_loadable_subscript_diagnostics.swift | 4 - .../moveonly_partial_consumption.swift | 3 +- 8 files changed, 185 insertions(+), 183 deletions(-) diff --git a/lib/SILGen/LValue.h b/lib/SILGen/LValue.h index a2f1c286cdbf5..e16811623e729 100644 --- a/lib/SILGen/LValue.h +++ b/lib/SILGen/LValue.h @@ -530,7 +530,7 @@ class LValue { SGFAccessKind selfAccess, SGFAccessKind otherAccess); - void dump() const; + SWIFT_DEBUG_DUMP; void dump(raw_ostream &os, unsigned indent = 0) const; }; diff --git a/lib/SILGen/SILGenLValue.cpp b/lib/SILGen/SILGenLValue.cpp index 4ce894784af65..c7ec40a069cb6 100644 --- a/lib/SILGen/SILGenLValue.cpp +++ b/lib/SILGen/SILGenLValue.cpp @@ -2877,6 +2877,116 @@ static ManagedValue visitRecNonInOutBase(SILGenLValue &SGL, Expr *e, value); } +static CanType getBaseFormalType(Expr *baseExpr) { + return baseExpr->getType()->getWithoutSpecifierType()->getCanonicalType(); +} + +class LLVM_LIBRARY_VISIBILITY SILGenBorrowedBaseVisitor + : public Lowering::ExprVisitor { +public: + SILGenLValue &SGL; + SILGenFunction &SGF; + + SILGenBorrowedBaseVisitor(SILGenLValue &SGL, SILGenFunction &SGF) + : SGL(SGL), SGF(SGF) {} + + /// Returns the subexpr + static bool isNonCopyableBaseBorrow(SILGenFunction &SGF, Expr *e) { + if (auto *le = dyn_cast(e)) + return le->getType()->isPureMoveOnly(); + if (auto *m = dyn_cast(e)) { + // If our m is a pure noncopyable type or our base is, we need to perform + // a noncopyable base borrow. + // + // DISCUSSION: We can have a noncopyable member_ref_expr with a copyable + // base if the noncopyable member_ref_expr is from a computed method. In + // such a case, we want to ensure that we wrap things the right way. + return m->getType()->isPureMoveOnly() || + m->getBase()->getType()->isPureMoveOnly(); + } + return false; + } + + LValue visitExpr(Expr *e, SGFAccessKind accessKind, LValueOptions options) { + e->dump(llvm::errs()); + llvm::report_fatal_error("Unimplemented node!"); + } + + LValue visitMemberRefExpr(MemberRefExpr *e, SGFAccessKind accessKind, + LValueOptions options) { + // If we have a member_ref_expr, we create a component that will when we + // evaluate the lvalue, + VarDecl *var = cast(e->getMember().getDecl()); + + assert(!e->getType()->is()); + + auto accessSemantics = e->getAccessSemantics(); + AccessStrategy strategy = var->getAccessStrategy( + accessSemantics, getFormalAccessKind(accessKind), + SGF.SGM.M.getSwiftModule(), SGF.F.getResilienceExpansion()); + + auto baseFormalType = getBaseFormalType(e->getBase()); + LValue lv = visit( + e->getBase(), + getBaseAccessKind(SGF.SGM, var, accessKind, strategy, baseFormalType), + getBaseOptions(options, strategy)); + llvm::Optional actorIso; + if (e->isImplicitlyAsync()) + actorIso = getActorIsolation(var); + lv.addMemberVarComponent(SGF, e, var, e->getMember().getSubstitutions(), + options, e->isSuper(), accessKind, strategy, + getSubstFormalRValueType(e), + false /*is on self parameter*/, actorIso); + return lv; + } + + ManagedValue emitImmediateBaseValue(Expr *e) { + // We are going to immediately use this base value, so we want to borrow it. + ManagedValue mv = + SGF.emitRValueAsSingleValue(e, SGFContext::AllowImmediatePlusZero); + if (mv.isPlusZeroRValueOrTrivial()) + return mv; + + // Any temporaries needed to materialize the lvalue must be destroyed when + // at the end of the lvalue's formal evaluation scope. + // e.g. for foo(self.bar) + // %self = load [copy] %ptr_self + // %rvalue = barGetter(%self) + // destroy_value %self // self must be released before calling foo. + // foo(%rvalue) + SILValue value = mv.forward(SGF); + return SGF.emitFormalAccessManagedRValueWithCleanup(CleanupLocation(e), + value); + } + + LValue visitDeclRefExpr(DeclRefExpr *e, SGFAccessKind accessKind, + LValueOptions options) { + if (accessKind == SGFAccessKind::BorrowedObjectRead) { + auto rv = emitImmediateBaseValue(e); + CanType formalType = getSubstFormalRValueType(e); + auto typeData = getValueTypeData(accessKind, formalType, rv.getValue()); + LValue lv; + lv.add(rv, llvm::None, typeData, /*isRValue=*/true); + return lv; + } + + return SGL.visitDeclRefExpr(e, accessKind, options); + } + + LValue visitLoadExpr(LoadExpr *e, SGFAccessKind accessKind, + LValueOptions options) { + // TODO: orig abstraction pattern. + LValue lv = SGL.visitRec(e->getSubExpr(), + SGFAccessKind::BorrowedAddressRead, options); + CanType formalType = getSubstFormalRValueType(e); + LValueTypeData typeData{accessKind, AbstractionPattern(formalType), + formalType, lv.getTypeOfRValue().getASTType()}; + lv.add(typeData); + return lv; + } +}; + LValue SILGenLValue::visitRec(Expr *e, SGFAccessKind accessKind, LValueOptions options, AbstractionPattern orig) { // First see if we have an lvalue type. If we do, then quickly handle that and @@ -2889,19 +2999,14 @@ LValue SILGenLValue::visitRec(Expr *e, SGFAccessKind accessKind, // a `borrow x` operator, the operator is used on the base here), we want to // apply the lvalue within a formal access to the original value instead of // an actual loaded copy. - - if (e->getType()->isPureMoveOnly()) { - if (auto load = dyn_cast(e)) { - LValue lv = visitRec(load->getSubExpr(), SGFAccessKind::BorrowedAddressRead, - options, orig); - CanType formalType = getSubstFormalRValueType(e); - LValueTypeData typeData{accessKind, AbstractionPattern(formalType), - formalType, lv.getTypeOfRValue().getASTType()}; - lv.add(typeData); - return lv; - } + if (SILGenBorrowedBaseVisitor::isNonCopyableBaseBorrow(SGF, e)) { + SILGenBorrowedBaseVisitor visitor(*this, SGF); + auto accessKind = SGFAccessKind::BorrowedObjectRead; + if (e->getType()->is()) + accessKind = SGFAccessKind::BorrowedAddressRead; + return visitor.visit(e, accessKind, options); } - + // Otherwise we have a non-lvalue type (references, values, metatypes, // etc). These act as the root of a logical lvalue. Compute the root value, // wrap it in a ValueComponent, and return it for our caller. @@ -3554,10 +3659,6 @@ static SGFAccessKind getBaseAccessKind(SILGenModule &SGM, llvm_unreachable("bad access strategy"); } -static CanType getBaseFormalType(Expr *baseExpr) { - return baseExpr->getType()->getWithoutSpecifierType()->getCanonicalType(); -} - bool isCallToReplacedInDynamicReplacement(SILGenFunction &SGF, AbstractFunctionDecl *afd, bool &isObjCReplacementSelfCall); diff --git a/lib/SILOptimizer/Mandatory/MoveOnlyAddressCheckerUtils.cpp b/lib/SILOptimizer/Mandatory/MoveOnlyAddressCheckerUtils.cpp index ef8878db0be25..a3296d40f12dc 100644 --- a/lib/SILOptimizer/Mandatory/MoveOnlyAddressCheckerUtils.cpp +++ b/lib/SILOptimizer/Mandatory/MoveOnlyAddressCheckerUtils.cpp @@ -1504,19 +1504,30 @@ struct CopiedLoadBorrowEliminationVisitor final // We can only hit this if our load_borrow was copied. llvm_unreachable("We should never hit this"); - case OperandOwnership::GuaranteedForwarding: - // If we have a switch_enum, we always need to convert it to a load - // [copy] since we need to destructure through it. - shouldConvertToLoadCopy |= isa(nextUse->getUser()); - + case OperandOwnership::GuaranteedForwarding: { + SmallVector forwardedValues; + auto *fn = nextUse->getUser()->getFunction(); ForwardingOperand(nextUse).visitForwardedValues([&](SILValue value) { - for (auto *use : value->getUses()) { - useWorklist.push_back(use); - } + if (value->getType().isTrivial(fn)) + return true; + forwardedValues.push_back(value); return true; }); - continue; + // If we do not have any forwarded values, just continue. + if (forwardedValues.empty()) + continue; + + while (!forwardedValues.empty()) { + for (auto *use : forwardedValues.pop_back_val()->getUses()) + useWorklist.push_back(use); + } + + // If we have a switch_enum, we always need to convert it to a load + // [copy] since we need to destructure through it. + shouldConvertToLoadCopy |= isa(nextUse->getUser()); + continue; + } case OperandOwnership::Borrow: LLVM_DEBUG(llvm::dbgs() << " Found recursive borrow!\n"); // Look through borrows. diff --git a/test/SILGen/moveonly.swift b/test/SILGen/moveonly.swift index 10defcf0b840c..87837a9696c29 100644 --- a/test/SILGen/moveonly.swift +++ b/test/SILGen/moveonly.swift @@ -1170,17 +1170,17 @@ public struct LoadableSubscriptGetOnlyTesterNonCopyableStructParent : ~Copyable // CHECK: [[MARK:%.*]] = mark_must_check [no_consume_or_assign] [[ACCESS]] // CHECK: [[LOAD_BORROW:%.*]] = load_borrow [[MARK]] // CHECK: [[VALUE:%.*]] = apply {{%.*}}([[LOAD_BORROW]]) -// CHECK: end_borrow [[LOAD_BORROW]] -// CHECK: end_access [[ACCESS]] // // CHECK: [[BORROWED_VALUE:%.*]] = begin_borrow [[VALUE]] // CHECK: [[TEMP:%.*]] = alloc_stack $AddressOnlyProtocol // CHECK: [[TEMP_MARK:%.*]] = mark_must_check [consumable_and_assignable] [[TEMP]] // CHECK: apply {{%.*}}([[TEMP_MARK]], {{%.*}}, [[BORROWED_VALUE]]) // CHECK: end_borrow [[BORROWED_VALUE]] -// CHECK: destroy_value [[VALUE]] +// CHECK: end_borrow [[LOAD_BORROW]] +// CHECK: end_access [[ACCESS]] // CHECK: apply {{%.*}}([[TEMP_MARK]]) // CHECK: destroy_addr [[TEMP_MARK]] +// CHECK: destroy_value [[VALUE]] // } // end sil function '$s8moveonly077testSubscriptGetOnlyThroughNonCopyableParentStruct_BaseLoadable_ResultAddressE4_VaryyF' public func testSubscriptGetOnlyThroughNonCopyableParentStruct_BaseLoadable_ResultAddressOnly_Var() { var m = LoadableSubscriptGetOnlyTesterNonCopyableStructParent() @@ -1197,13 +1197,9 @@ public func testSubscriptGetOnlyThroughNonCopyableParentStruct_BaseLoadable_Resu // CHECK: [[MARK:%.*]] = mark_must_check [no_consume_or_assign] [[PROJECT]] // CHECK: [[LOAD:%.*]] = load_borrow [[MARK]] // CHECK: [[EXT:%.*]] = struct_extract [[LOAD]] -// CHECK: [[COPY:%.*]] = copy_value [[EXT]] -// CHECK: [[BORROW:%.*]] = begin_borrow [[COPY]] // CHECK: [[TEMP:%.*]] = alloc_stack $AddressOnlyProtocol // CHECK: [[TEMP_MARK:%.*]] = mark_must_check [consumable_and_assignable] [[TEMP]] -// CHECK: apply {{%.*}}([[TEMP_MARK]], {{%.*}}, [[BORROW]]) -// CHECK: end_borrow [[BORROW]] -// CHECK: destroy_value [[COPY]] +// CHECK: apply {{%.*}}([[TEMP_MARK]], {{%.*}}, [[EXT]]) // CHECK: apply {{%.*}}([[TEMP_MARK]]) // CHECK: destroy_addr [[TEMP_MARK]] // CHECK: end_borrow [[LOAD]] @@ -1257,11 +1253,6 @@ public class LoadableSubscriptGetOnlyTesterClassParent { var testerParent = LoadableSubscriptGetOnlyTesterNonCopyableStructParent() } -// TODO(MG): I am preparing a small pass that cleans up the copy_value -// below. The code in SILGen is in some very generic code that changing could -// have other unintentional side-effects, so it makes sense to instead just add -// a small cleanup transform before we do move checking to cleanup this pattern. -// // CHECK-LABEL: sil [ossa] @$s8moveonly065testSubscriptGetOnlyThroughParentClass_BaseLoadable_ResultAddressE4_VaryyF : $@convention(thin) () -> () { // CHECK: [[BOX:%.*]] = alloc_box $ // CHECK: [[BORROW:%.*]] = begin_borrow [lexical] [[BOX]] @@ -1306,34 +1297,17 @@ public class LoadableSubscriptGetOnlyTesterClassParent { // CHECK: apply {{%.*}}([[TEMP2_MARK]]) // CHECK: destroy_addr [[TEMP2_MARK]] // -// Third read. This is a case that we can't handle today due to the way the AST -// looks: -// -// (subscript_expr type='AddressOnlyProtocol' -// (member_ref_expr type='LoadableSubscriptGetOnlyTester' -// (load_expr implicit type='LoadableSubscriptGetOnlyTesterClassParent' -// (declref_expr type='@lvalue LoadableSubscriptGetOnlyTesterClassParent' -// (argument_list -// (argument -// (integer_literal_expr type='Int' -// -// due to the load_expr in the subscript base, SILGen emits a base rvalue for -// the load_expr and copies it, ending the coroutine. What we need is the -// ability to have an lvalue pseudo-component that treats the declref_expr (and -// any member_ref_expr) as a base and allows for a load_expr to be followed by N -// member_ref_expr. +// Third read. // // CHECK: [[ACCESS:%.*]] = begin_access [read] [unknown] [[PROJECT]] -// CHECK: [[COPYABLE_CLASS:%.*]] = load [copy] [[ACCESS]] -// CHECK: end_access [[ACCESS]] -// CHECK: [[BORROW_COPYABLE_CLASS:%.*]] = begin_borrow [[COPYABLE_CLASS]] -// CHECK: ([[CORO_RESULT:%.*]], [[CORO_TOKEN:%.*]]) = begin_apply {{%.*}}([[BORROW_COPYABLE_CLASS]]) -// CHECK: [[CORO_RESULT_COPY:%.*]] = copy_value [[CORO_RESULT]] -// CHECK: end_apply [[CORO_TOKEN]] -// CHECK: [[BORROW:%.*]] = begin_borrow [[CORO_RESULT_COPY]] +// CHECK: [[LOAD:%.*]] = load_borrow [[ACCESS]] +// CHECK: ([[CORO_RESULT:%.*]], [[CORO_TOKEN:%.*]]) = begin_apply {{%.*}}([[LOAD]]) // CHECK: [[TEMP:%.*]] = alloc_stack $ // CHECK: [[TEMP_MARK:%.*]] = mark_must_check [consumable_and_assignable] [[TEMP]] -// CHECK: apply {{%.*}}([[TEMP_MARK]], {{%.*}}, [[BORROW]]) +// CHECK: apply {{%.*}}([[TEMP_MARK]], {{%.*}}, [[CORO_RESULT]]) +// CHECK: end_apply [[CORO_TOKEN]] +// CHECK: end_borrow [[LOAD]] +// CHECK: apply {{%.*}}([[TEMP_MARK]]) // CHECK: destroy_addr [[TEMP_MARK]] // CHECK: } // end sil function '$s8moveonly065testSubscriptGetOnlyThroughParentClass_BaseLoadable_ResultAddressE4_VaryyF' @@ -1554,16 +1528,16 @@ public struct LoadableSubscriptGetSetTesterNonCopyableStructParent : ~Copyable { // CHECK: [[MARK:%.*]] = mark_must_check [no_consume_or_assign] [[ACCESS]] // CHECK: [[LOAD_BORROW:%.*]] = load_borrow [[MARK]] // CHECK: [[VALUE:%.*]] = apply {{%.*}}([[LOAD_BORROW]]) -// CHECK: end_borrow [[LOAD_BORROW]] -// CHECK: end_access [[ACCESS]] // CHECK: [[BORROWED_VALUE:%.*]] = begin_borrow [[VALUE]] // CHECK: [[TEMP:%.*]] = alloc_stack $AddressOnlyProtocol // CHECK: [[MARK_TEMP:%.*]] = mark_must_check [consumable_and_assignable] [[TEMP]] // CHECK: apply {{%.*}}([[MARK_TEMP]], {{%.*}}, [[BORROWED_VALUE]]) // CHECK: end_borrow [[BORROWED_VALUE]] -// CHECK: destroy_value [[VALUE]] +// CHECK: end_borrow [[LOAD_BORROW]] +// CHECK: end_access [[ACCESS]] // CHECK: apply {{%.*}}([[MARK_TEMP]]) // CHECK: destroy_addr [[MARK_TEMP]] +// CHECK: destroy_value [[VALUE]] // } // end sil function '$s8moveonly077testSubscriptGetSetThroughNonCopyableParentStruct_BaseLoadable_ResultAddressE4_VaryyF' public func testSubscriptGetSetThroughNonCopyableParentStruct_BaseLoadable_ResultAddressOnly_Var() { var m = LoadableSubscriptGetSetTesterNonCopyableStructParent() @@ -1581,13 +1555,9 @@ public func testSubscriptGetSetThroughNonCopyableParentStruct_BaseLoadable_Resul // CHECK: [[MARK:%.*]] = mark_must_check [no_consume_or_assign] [[PROJECT]] // CHECK: [[LOAD:%.*]] = load_borrow [[MARK]] // CHECK: [[EXT:%.*]] = struct_extract [[LOAD]] -// CHECK: [[COPY:%.*]] = copy_value [[EXT]] -// CHECK: [[BORROW:%.*]] = begin_borrow [[COPY]] // CHECK: [[TEMP:%.*]] = alloc_stack $AddressOnlyProtocol // CHECK: [[MARK_TEMP:%.*]] = mark_must_check [consumable_and_assignable] [[TEMP]] -// CHECK: apply {{%.*}}([[MARK_TEMP]], {{%.*}}, [[BORROW]]) -// CHECK: end_borrow [[BORROW]] -// CHECK: destroy_value [[COPY]] +// CHECK: apply {{%.*}}([[MARK_TEMP]], {{%.*}}, [[EXT]]) // CHECK: apply {{%.*}}([[MARK_TEMP]]) // CHECK: destroy_addr [[MARK_TEMP]] // CHECK: end_borrow [[LOAD]] @@ -1742,36 +1712,17 @@ public class LoadableSubscriptGetSetTesterClassParent { // CHECK: apply {{%.*}}([[MARK_TEMP]], {{%.*}}, [[GEP]]) // CHECK: end_apply [[CORO_TOKEN]] // -// Third read. This is a case that we can't handle today due to the way the AST -// looks: -// -// (subscript_expr type='AddressOnlyProtocol' -// (member_ref_expr type='LoadableSubscriptGetSetTester' -// (load_expr implicit type='LoadableSubscriptGetSetTesterClassParent' -// (declref_expr type='@lvalue LoadableSubscriptGetSetTesterClassParent' -// (argument_list -// (argument -// (integer_literal_expr type='Int' -// -// due to the load_expr in the subscript base, SILGen emits a base rvalue for -// the load_expr and copies it, ending the coroutine. What we need is the -// ability to have an lvalue pseudo-component that treats the declref_expr (and -// any member_ref_expr) as a base and allows for a load_expr to be followed by N -// member_ref_expr. +// Third read. // // CHECK: [[ACCESS:%.*]] = begin_access [read] [unknown] [[PROJECT]] -// CHECK: [[COPYABLE_CLASS:%.*]] = load [copy] [[ACCESS]] -// CHECK: end_access [[ACCESS]] -// CHECK: [[BORROW_COPYABLE_CLASS:%.*]] = begin_borrow [[COPYABLE_CLASS]] -// CHECK: ([[CORO_RESULT:%.*]], [[CORO_TOKEN:%.*]]) = begin_apply {{%.*}}([[BORROW_COPYABLE_CLASS]]) -// CHECK: [[CORO_RESULT_COPY:%.*]] = copy_value [[CORO_RESULT]] -// CHECK: end_apply [[CORO_TOKEN]] -// CHECK: [[BORROW:%.*]] = begin_borrow [[CORO_RESULT_COPY]] +// CHECK: [[CLASS:%.*]] = load_borrow [[ACCESS]] +// CHECK: ([[CORO_RESULT:%.*]], [[CORO_TOKEN:%.*]]) = begin_apply {{%.*}}([[CLASS]]) // CHECK: [[TEMP:%.*]] = alloc_stack $ // CHECK: [[MARK_TEMP:%.*]] = mark_must_check [consumable_and_assignable] [[TEMP]] -// CHECK: apply {{%.*}}([[MARK_TEMP]], {{%.*}}, [[BORROW]]) -// CHECK: end_borrow [[BORROW]] -// CHECK: destroy_value [[CORO_RESULT_COPY]] +// CHECK: apply {{%.*}}([[MARK_TEMP]], {{%.*}}, [[CORO_RESULT]]) +// CHECK: end_apply [[CORO_TOKEN]] +// CHECK: end_borrow [[CLASS]] +// CHECK: end_access [[ACCESS]] // CHECK: apply {{%.*}}([[MARK_TEMP]]) // CHECK: destroy_addr [[MARK_TEMP]] // @@ -2009,14 +1960,14 @@ public struct LoadableSubscriptReadModifyTesterNonCopyableStructParent : ~Copyab // CHECK: [[MARK:%.*]] = mark_must_check [no_consume_or_assign] [[ACCESS]] // CHECK: [[LOAD_BORROW:%.*]] = load_borrow [[MARK]] // CHECK: [[VALUE:%.*]] = apply {{%.*}}([[LOAD_BORROW]]) -// CHECK: end_borrow [[LOAD_BORROW]] -// CHECK: end_access [[ACCESS]] // CHECK: [[BORROWED_VALUE:%.*]] = begin_borrow [[VALUE]] // CHECK: ([[CORO_RESULT:%.*]], [[CORO_TOKEN:%.*]]) = begin_apply {{%.*}}({{%.*}}, [[BORROWED_VALUE]]) // CHECK: apply {{%.*}}([[CORO_RESULT]]) // CHECK: end_apply [[CORO_TOKEN]] // CHECK: end_borrow [[BORROWED_VALUE]] -// } // end sil function '$s8moveonly077testSubscriptReadModifyThroughNonCopyableParentStruct_BaseLoadable_ResultAddressE4_VaryyF' +// CHECK: end_borrow [[LOAD_BORROW]] +// CHECK: end_access [[ACCESS]] +// } // end sil function '$s8moveonly88testSubscriptReadModifyThroughNonCopyableParentStruct_BaseLoadable_ResultAddressOnly_VaryyF' public func testSubscriptReadModifyThroughNonCopyableParentStruct_BaseLoadable_ResultAddressOnly_Var() { var m = LoadableSubscriptReadModifyTesterNonCopyableStructParent() m = LoadableSubscriptReadModifyTesterNonCopyableStructParent() @@ -2034,13 +1985,9 @@ public func testSubscriptReadModifyThroughNonCopyableParentStruct_BaseLoadable_R // CHECK: [[MARK:%.*]] = mark_must_check [no_consume_or_assign] [[PROJECT]] // CHECK: [[LOAD:%.*]] = load_borrow [[MARK]] // CHECK: [[EXT:%.*]] = struct_extract [[LOAD]] -// CHECK: [[COPY:%.*]] = copy_value [[EXT]] -// CHECK: [[BORROW:%.*]] = begin_borrow [[COPY]] -// CHECK: ([[CORO_RESULT:%.*]], [[CORO_TOKEN:%.*]]) = begin_apply {{%.*}}({{%.*}}, [[BORROW]]) +// CHECK: ([[CORO_RESULT:%.*]], [[CORO_TOKEN:%.*]]) = begin_apply {{%.*}}({{%.*}}, [[EXT]]) // CHECK: apply {{%.*}}([[CORO_RESULT]]) // CHECK: end_apply [[CORO_TOKEN]] -// CHECK: end_borrow [[BORROW]] -// CHECK: destroy_value [[COPY]] // CHECK: end_borrow [[LOAD]] // CHECK: } // end sil function '$s8moveonly88testSubscriptReadModifyThroughNonCopyableParentStruct_BaseLoadable_ResultAddressOnly_LetyyF' public func testSubscriptReadModifyThroughNonCopyableParentStruct_BaseLoadable_ResultAddressOnly_Let() { @@ -2171,36 +2118,17 @@ public class LoadableSubscriptReadModifyTesterClassParent { // CHECK: end_borrow [[BORROW_COPYABLE_CLASS]] // CHECK: destroy_value [[COPYABLE_CLASS]] // -// Third read. This is a case that we can't handle today due to the way the AST -// looks: -// -// (subscript_expr type='AddressOnlyProtocol' -// (member_ref_expr type='LoadableSubscriptReadModifyTester' -// (load_expr implicit type='LoadableSubscriptReadModifyTesterClassParent' -// (declref_expr type='@lvalue LoadableSubscriptReadModifyTesterClassParent' -// (argument_list -// (argument -// (integer_literal_expr type='Int' -// -// due to the load_expr in the subscript base, SILGen emits a base rvalue for -// the load_expr and copies it, ending the coroutine. What we need is the -// ability to have an lvalue pseudo-component that treats the declref_expr (and -// any member_ref_expr) as a base and allows for a load_expr to be followed by N -// member_ref_expr. +// Third read. // // CHECK: [[ACCESS:%.*]] = begin_access [read] [unknown] [[PROJECT]] -// CHECK: [[COPYABLE_CLASS:%.*]] = load [copy] [[ACCESS]] -// CHECK: end_access [[ACCESS]] -// CHECK: [[BORROW_COPYABLE_CLASS:%.*]] = begin_borrow [[COPYABLE_CLASS]] -// CHECK: ([[CORO_RESULT:%.*]], [[CORO_TOKEN:%.*]]) = begin_apply {{%.*}}([[BORROW_COPYABLE_CLASS]]) -// CHECK: [[CORO_RESULT_COPY:%.*]] = copy_value [[CORO_RESULT]] -// CHECK: end_apply [[CORO_TOKEN]] -// CHECK: [[BORROW:%.*]] = begin_borrow [[CORO_RESULT_COPY]] -// CHECK: ([[CORO_RESULT:%.*]], [[CORO_TOKEN:%.*]]) = begin_apply {{%.*}}({{%.*}}, [[BORROW]]) -// CHECK: apply {{%.*}}([[CORO_RESULT]]) +// CHECK: [[CLASS:%.*]] = load_borrow [[ACCESS]] +// CHECK: ([[CORO_RESULT:%.*]], [[CORO_TOKEN:%.*]]) = begin_apply {{%.*}}([[CLASS]]) +// CHECK: ([[CORO_RESULT2:%.*]], [[CORO_TOKEN2:%.*]]) = begin_apply {{%.*}}({{%.*}}, [[CORO_RESULT]]) +// CHECK: apply {{%.*}}([[CORO_RESULT2]]) +// CHECK: end_apply [[CORO_TOKEN2]] // CHECK: end_apply [[CORO_TOKEN]] // CHECK: end_borrow [[BORROW]] -// CHECK: destroy_value [[CORO_RESULT_COPY]] +// CHECK: end_access [[ACCESS]] // // First read // CHECK: [[ACCESS:%.*]] = begin_access [read] [unknown] [[PROJECT]] @@ -2434,16 +2362,16 @@ public struct LoadableSubscriptGetModifyTesterNonCopyableStructParent : ~Copyabl // CHECK: [[MARK:%.*]] = mark_must_check [no_consume_or_assign] [[ACCESS]] // CHECK: [[LOAD_BORROW:%.*]] = load_borrow [[MARK]] // CHECK: [[VALUE:%.*]] = apply {{%.*}}([[LOAD_BORROW]]) -// CHECK: end_borrow [[LOAD_BORROW]] -// CHECK: end_access [[ACCESS]] // CHECK: [[BORROWED_VALUE:%.*]] = begin_borrow [[VALUE]] // CHECK: [[TEMP:%.*]] = alloc_stack $AddressOnlyProtocol // CHECK: [[MARK_TEMP:%.*]] = mark_must_check [consumable_and_assignable] [[TEMP]] // CHECK: apply {{%.*}}([[MARK_TEMP]], {{%.*}}, [[BORROWED_VALUE]]) // CHECK: end_borrow [[BORROWED_VALUE]] -// CHECK: destroy_value [[VALUE]] +// CHECK: end_borrow [[LOAD_BORROW]] +// CHECK: end_access [[ACCESS]] // CHECK: apply {{%.*}}([[MARK_TEMP]]) // CHECK: destroy_addr [[MARK_TEMP]] +// CHECK: destroy_value [[VALUE]] // } // end sil function '$s8moveonly077testSubscriptGetModifyThroughNonCopyableParentStruct_BaseLoadable_ResultAddressE4_VaryyF' public func testSubscriptGetModifyThroughNonCopyableParentStruct_BaseLoadable_ResultAddressOnly_Var() { var m = LoadableSubscriptGetModifyTesterNonCopyableStructParent() @@ -2461,13 +2389,9 @@ public func testSubscriptGetModifyThroughNonCopyableParentStruct_BaseLoadable_Re // CHECK: [[MARK:%.*]] = mark_must_check [no_consume_or_assign] [[PROJECT]] // CHECK: [[LOAD:%.*]] = load_borrow [[MARK]] // CHECK: [[EXT:%.*]] = struct_extract [[LOAD]] -// CHECK: [[COPY:%.*]] = copy_value [[EXT]] -// CHECK: [[BORROW:%.*]] = begin_borrow [[COPY]] // CHECK: [[TEMP:%.*]] = alloc_stack $AddressOnlyProtocol // CHECK: [[MARK_TEMP:%.*]] = mark_must_check [consumable_and_assignable] [[TEMP]] -// CHECK: apply {{%.*}}([[MARK_TEMP]], {{%.*}}, [[BORROW]]) -// CHECK: end_borrow [[BORROW]] -// CHECK: destroy_value [[COPY]] +// CHECK: apply {{%.*}}([[MARK_TEMP]], {{%.*}}, [[EXT]]) // CHECK: apply {{%.*}}([[MARK_TEMP]]) // CHECK: destroy_addr [[MARK_TEMP]] // CHECK: end_borrow [[LOAD]] @@ -2576,36 +2500,17 @@ public class LoadableSubscriptGetModifyTesterClassParent { // CHECK: end_apply [[CORO_TOKEN_2]] // CHECK: end_apply [[CORO_TOKEN]] // -// Third read. This is a case that we can't handle today due to the way the AST -// looks: -// -// (subscript_expr type='AddressOnlyProtocol' -// (member_ref_expr type='LoadableSubscriptGetModifyTester' -// (load_expr implicit type='LoadableSubscriptGetModifyTesterClassParent' -// (declref_expr type='@lvalue LoadableSubscriptGetModifyTesterClassParent' -// (argument_list -// (argument -// (integer_literal_expr type='Int' -// -// due to the load_expr in the subscript base, SILGen emits a base rvalue for -// the load_expr and copies it, ending the coroutine. What we need is the -// ability to have an lvalue pseudo-component that treats the declref_expr (and -// any member_ref_expr) as a base and allows for a load_expr to be followed by N -// member_ref_expr. +// Third read. // // CHECK: [[ACCESS:%.*]] = begin_access [read] [unknown] [[PROJECT]] -// CHECK: [[COPYABLE_CLASS:%.*]] = load [copy] [[ACCESS]] -// CHECK: end_access [[ACCESS]] -// CHECK: [[BORROW_COPYABLE_CLASS:%.*]] = begin_borrow [[COPYABLE_CLASS]] -// CHECK: ([[CORO_RESULT:%.*]], [[CORO_TOKEN:%.*]]) = begin_apply {{%.*}}([[BORROW_COPYABLE_CLASS]]) -// CHECK: [[CORO_RESULT_COPY:%.*]] = copy_value [[CORO_RESULT]] -// CHECK: end_apply [[CORO_TOKEN]] -// CHECK: [[BORROW:%.*]] = begin_borrow [[CORO_RESULT_COPY]] +// CHECK: [[CLASS:%.*]] = load_borrow [[ACCESS]] +// CHECK: ([[CORO_RESULT:%.*]], [[CORO_TOKEN:%.*]]) = begin_apply {{%.*}}([[CLASS]]) // CHECK: [[TEMP:%.*]] = alloc_stack $ // CHECK: [[MARK_TEMP:%.*]] = mark_must_check [consumable_and_assignable] [[TEMP]] -// CHECK: apply {{%.*}}([[MARK_TEMP]], {{%.*}}, [[BORROW]]) -// CHECK: end_borrow [[BORROW]] -// CHECK: destroy_value [[CORO_RESULT_COPY]] +// CHECK: apply {{%.*}}([[MARK_TEMP]], {{%.*}}, [[CORO_RESULT]]) +// CHECK: end_apply [[CORO_TOKEN]] +// CHECK: end_borrow [[CLASS]] +// CHECK: end_access [[ACCESS]] // CHECK: apply {{%.*}}([[MARK_TEMP]]) // CHECK: destroy_addr [[MARK_TEMP]] // diff --git a/test/SILGen/moveonly_library_evolution.swift b/test/SILGen/moveonly_library_evolution.swift index 3460127494e13..53256c4d20605 100644 --- a/test/SILGen/moveonly_library_evolution.swift +++ b/test/SILGen/moveonly_library_evolution.swift @@ -38,14 +38,8 @@ public struct DeinitTest : ~Copyable { // CHECK: bb0([[ARG:%.*]] : @guaranteed $CopyableKlass): // CHECK: [[ADDR:%.*]] = ref_element_addr [[ARG]] // CHECK: [[MARKED_ADDR:%.*]] = mark_must_check [no_consume_or_assign] [[ADDR]] -// CHECK: [[LOADED_VALUE:%.*]] = load [copy] [[MARKED_ADDR]] -// CHECK: [[BORROWED_LOADED_VALUE:%.*]] = begin_borrow [[LOADED_VALUE]] -// CHECK: [[EXT:%.*]] = struct_extract [[BORROWED_LOADED_VALUE]] -// CHECK: [[SPILL:%.*]] = alloc_stack $EmptyStruct -// CHECK: [[STORE_BORROW:%.*]] = store_borrow [[EXT]] to [[SPILL]] -// CHECK: apply {{%.*}}([[STORE_BORROW]]) : $@convention(thin) (@in_guaranteed EmptyStruct) -> () -// CHECK: end_borrow [[STORE_BORROW]] -// CHECK: end_borrow [[BORROWED_LOADED_VALUE]] +// CHECK: [[GEP:%.*]] = struct_element_addr [[MARKED_ADDR]] +// CHECK: apply {{%.*}}([[GEP]]) : $@convention(thin) (@in_guaranteed EmptyStruct) -> () // CHECK: } // end sil function '$s26moveonly_library_evolution29callerArgumentSpillingTestArgyyAA13CopyableKlassCF' public func callerArgumentSpillingTestArg(_ x: CopyableKlass) { borrowVal(x.letStruct.e) diff --git a/test/SILOptimizer/moveonly_addressonly_subscript_diagnostics.swift b/test/SILOptimizer/moveonly_addressonly_subscript_diagnostics.swift index 835ac51629f85..3c4cd85343037 100644 --- a/test/SILOptimizer/moveonly_addressonly_subscript_diagnostics.swift +++ b/test/SILOptimizer/moveonly_addressonly_subscript_diagnostics.swift @@ -87,7 +87,6 @@ public func testSubscriptGetOnlyThroughParentClass_BaseLoadable_ResultAddressOnl m.testerParent.tester[0].nonMutatingFunc() // expected-error @-1 {{copy of noncopyable typed value}} m.computedTester[0].nonMutatingFunc() - // expected-error @-1 {{copy of noncopyable typed value}} } // MARK: Getter + Setter. @@ -178,7 +177,6 @@ public func testSubscriptGetSetThroughParentClass_BaseLoadable_ResultAddressOnly m.testerParent.tester[0].nonMutatingFunc() m.testerParent.tester[0].mutatingFunc() m.computedTester[0].nonMutatingFunc() - // expected-error @-1 {{copy of noncopyable typed value}} m.computedTester2[0].nonMutatingFunc() m.computedTester2[0].mutatingFunc() } @@ -272,7 +270,6 @@ public func testSubscriptReadModifyThroughParentClass_BaseLoadable_ResultAddress m.testerParent.tester[0].nonMutatingFunc() m.testerParent.tester[0].mutatingFunc() m.computedTester[0].nonMutatingFunc() - // expected-error @-1 {{copy of noncopyable typed value}} m.computedTester2[0].nonMutatingFunc() m.computedTester2[0].mutatingFunc() } @@ -359,7 +356,6 @@ public func testSubscriptGetModifyThroughParentClass_BaseLoadable_ResultAddressO m.testerParent.tester[0].nonMutatingFunc() m.testerParent.tester[0].mutatingFunc() m.computedTester[0].nonMutatingFunc() - // expected-error @-1 {{copy of noncopyable typed value}} m.computedTester2[0].nonMutatingFunc() m.computedTester2[0].mutatingFunc() } diff --git a/test/SILOptimizer/moveonly_loadable_subscript_diagnostics.swift b/test/SILOptimizer/moveonly_loadable_subscript_diagnostics.swift index a29244eaab356..bf58a32b330b9 100644 --- a/test/SILOptimizer/moveonly_loadable_subscript_diagnostics.swift +++ b/test/SILOptimizer/moveonly_loadable_subscript_diagnostics.swift @@ -84,7 +84,6 @@ public func testSubscriptGetOnlyThroughParentClass_BaseLoadable_ResultLoadable_V m.testerParent.tester[0].nonMutatingFunc() // expected-error @-1 {{copy of noncopyable typed value}} m.computedTester[0].nonMutatingFunc() - // expected-error @-1 {{copy of noncopyable typed value}} } // MARK: Getter + Setter. @@ -175,7 +174,6 @@ public func testSubscriptGetSetThroughParentClass_BaseLoadable_ResultLoadable_Va m.testerParent.tester[0].nonMutatingFunc() m.testerParent.tester[0].mutatingFunc() m.computedTester[0].nonMutatingFunc() - // expected-error @-1 {{copy of noncopyable typed value}} m.computedTester2[0].nonMutatingFunc() m.computedTester2[0].mutatingFunc() } @@ -269,7 +267,6 @@ public func testSubscriptReadModifyThroughParentClass_BaseLoadable_ResultLoadabl m.testerParent.tester[0].nonMutatingFunc() m.testerParent.tester[0].mutatingFunc() m.computedTester[0].nonMutatingFunc() - // expected-error @-1 {{copy of noncopyable typed value}} m.computedTester2[0].nonMutatingFunc() m.computedTester2[0].mutatingFunc() } @@ -356,7 +353,6 @@ public func testSubscriptGetModifyThroughParentClass_BaseLoadable_ResultLoadable m.testerParent.tester[0].nonMutatingFunc() m.testerParent.tester[0].mutatingFunc() m.computedTester[0].nonMutatingFunc() - // expected-error @-1 {{copy of noncopyable typed value}} m.computedTester2[0].nonMutatingFunc() m.computedTester2[0].mutatingFunc() } diff --git a/test/SILOptimizer/moveonly_partial_consumption.swift b/test/SILOptimizer/moveonly_partial_consumption.swift index 9292b02f8f8cd..be19336eea69b 100644 --- a/test/SILOptimizer/moveonly_partial_consumption.swift +++ b/test/SILOptimizer/moveonly_partial_consumption.swift @@ -224,11 +224,10 @@ func addressOnlyTestArg(_ x: borrowing AddressOnlyType) { // expected-error @-1 {{'x' is borrowed and cannot be consumed}} // expected-error @-2 {{'x' is borrowed and cannot be consumed}} // expected-error @-3 {{'x' is borrowed and cannot be consumed}} - // expected-error @-4 {{'x' is borrowed and cannot be consumed}} let _ = x.e // expected-note {{consumed here}} let _ = x.k let _ = x.l.e // expected-note {{consumed here}} - let _ = x.l.k // expected-note {{consumed here}} + let _ = x.l.k switch x.lEnum { // expected-note {{consumed here}} case .first: break From 26081ffb8219ce6b8c412ee3e86682152ec0e6f1 Mon Sep 17 00:00:00 2001 From: Michael Gottesman Date: Mon, 24 Jul 2023 17:37:58 -0700 Subject: [PATCH 04/37] [silgen] Teach accessor projection to use store_borrow if it has a non-tuple. This prevents another type of copy of noncopyable value error. I also as a small change, changed the tuple version to use a formal access temporary since we are projecting a component out implying that the lifetime of the temporary must end within the formal access. Otherwise, we cause the lifetime of the temporary to outlive the access. This can be seen in the change to read_accessor.swift where we used to extend the lifetime of the destroy_addr outside of the coroutine access we are performing. --- lib/SIL/Verifier/SILVerifier.cpp | 7 ++++++- lib/SILGen/SILGenLValue.cpp | 17 +++++++++++++++-- test/SILGen/moveonly.swift | 14 ++++++-------- test/SILGen/read_accessor.swift | 1 + ...only_addressonly_subscript_diagnostics.swift | 2 -- ...oveonly_loadable_subscript_diagnostics.swift | 2 -- 6 files changed, 28 insertions(+), 15 deletions(-) diff --git a/lib/SIL/Verifier/SILVerifier.cpp b/lib/SIL/Verifier/SILVerifier.cpp index 820b47576cf59..88c1fee071fe9 100644 --- a/lib/SIL/Verifier/SILVerifier.cpp +++ b/lib/SIL/Verifier/SILVerifier.cpp @@ -2595,7 +2595,12 @@ class SILVerifier : public SILVerifierBase { require(SI->getDest()->getType().isAddress(), "Must store to an address dest"); // Note: This is the current implementation and the design is not final. - require(isa(SI->getDest()), + auto isLegal = [](SILValue value) { + if (auto *mmci = dyn_cast(value)) + value = mmci->getOperand(); + return isa(value); + }; + require(isLegal(SI->getDest()), "store_borrow destination can only be an alloc_stack"); requireSameType(SI->getDest()->getType().getObjectType(), SI->getSrc()->getType(), diff --git a/lib/SILGen/SILGenLValue.cpp b/lib/SILGen/SILGenLValue.cpp index c7ec40a069cb6..66bd8a3aff297 100644 --- a/lib/SILGen/SILGenLValue.cpp +++ b/lib/SILGen/SILGenLValue.cpp @@ -2156,12 +2156,25 @@ namespace { if (value.getType().isAddress() || !isReadAccessResultAddress(getAccessKind())) return value; + + // If we have a guaranteed object and our read access result requires an + // address, store it using a store_borrow. + if (value.getType().isObject() && + value.getOwnershipKind() == OwnershipKind::Guaranteed) { + SILValue alloc = SGF.emitTemporaryAllocation(loc, getTypeOfRValue()); + if (alloc->getType().isMoveOnly()) + alloc = SGF.B.createMarkMustCheckInst( + loc, alloc, MarkMustCheckInst::CheckKind::NoConsumeOrAssign); + return SGF.B.createFormalAccessStoreBorrow(loc, value, alloc); + } } // Otherwise, we need to make a temporary. + // TODO: This needs to be changed to use actual store_borrows. Noncopyable + // types do not support tuples today, so we can avoid this for now. // TODO: build a scalar tuple if possible. - auto temporary = - SGF.emitTemporary(loc, SGF.getTypeLowering(getTypeOfRValue())); + auto temporary = SGF.emitFormalAccessTemporary( + loc, SGF.getTypeLowering(getTypeOfRValue())); auto yieldsAsArray = llvm::makeArrayRef(yields); copyBorrowedYieldsIntoTemporary(SGF, loc, yieldsAsArray, getOrigFormalType(), getSubstFormalType(), diff --git a/test/SILGen/moveonly.swift b/test/SILGen/moveonly.swift index 87837a9696c29..fa1f0f39e7b48 100644 --- a/test/SILGen/moveonly.swift +++ b/test/SILGen/moveonly.swift @@ -1265,10 +1265,9 @@ public class LoadableSubscriptGetOnlyTesterClassParent { // CHECK: [[BORROW_COPYABLE_CLASS:%.*]] = begin_borrow [[COPYABLE_CLASS]] // CHECK: ([[CORO_RESULT:%.*]], [[CORO_TOKEN:%.*]]) = begin_apply {{%.*}}([[BORROW_COPYABLE_CLASS]]) // CHECK: [[TEMP:%.*]] = alloc_stack $LoadableSubscriptGetOnlyTester -// CHECK: [[TEMP_MARK:%.*]] = mark_must_check [consumable_and_assignable] [[TEMP]] -// CHECK: [[CORO_RESULT_COPY:%.*]] = copy_value [[CORO_RESULT]] -// CHECK: store [[CORO_RESULT_COPY]] to [init] [[TEMP_MARK]] -// CHECK: [[LOAD:%.*]] = load_borrow [[TEMP_MARK]] +// CHECK: [[TEMP_MARK:%.*]] = mark_must_check [no_consume_or_assign] [[TEMP]] +// CHECK: [[TEMP_MARK_BORROW:%.*]] = store_borrow [[CORO_RESULT]] to [[TEMP_MARK]] +// CHECK: [[LOAD:%.*]] = load_borrow [[TEMP_MARK_BORROW]] // CHECK: [[TEMP2:%.*]] = alloc_stack $ // CHECK: [[TEMP2_MARK:%.*]] = mark_must_check [consumable_and_assignable] [[TEMP2]] // CHECK: apply {{%.*}}([[TEMP2_MARK]], {{%.*}}, [[LOAD]]) @@ -1284,10 +1283,9 @@ public class LoadableSubscriptGetOnlyTesterClassParent { // CHECK: [[BORROW_COPYABLE_CLASS:%.*]] = begin_borrow [[COPYABLE_CLASS]] // CHECK: ([[CORO_RESULT:%.*]], [[CORO_TOKEN:%.*]]) = begin_apply {{%.*}}([[BORROW_COPYABLE_CLASS]]) // CHECK: [[TEMP:%.*]] = alloc_stack $LoadableSubscriptGetOnlyTester -// CHECK: [[TEMP_MARK:%.*]] = mark_must_check [consumable_and_assignable] [[TEMP]] -// CHECK: [[CORO_RESULT_COPY:%.*]] = copy_value [[CORO_RESULT]] -// CHECK: store [[CORO_RESULT_COPY]] to [init] [[TEMP_MARK]] -// CHECK: [[GEP:%.*]] = struct_element_addr [[TEMP_MARK]] +// CHECK: [[TEMP_MARK:%.*]] = mark_must_check [no_consume_or_assign] [[TEMP]] +// CHECK: [[TEMP_MARK_BORROW:%.*]] = store_borrow [[CORO_RESULT]] to [[TEMP_MARK]] +// CHECK: [[GEP:%.*]] = struct_element_addr [[TEMP_MARK_BORROW]] // CHECK: [[LOAD:%.*]] = load_borrow [[GEP]] // CHECK: [[TEMP2:%.*]] = alloc_stack $ // CHECK: [[TEMP2_MARK:%.*]] = mark_must_check [consumable_and_assignable] [[TEMP2]] diff --git a/test/SILGen/read_accessor.swift b/test/SILGen/read_accessor.swift index 85fa3dfbac49d..39d9f4e4b495b 100644 --- a/test/SILGen/read_accessor.swift +++ b/test/SILGen/read_accessor.swift @@ -112,6 +112,7 @@ struct TupleReader { // CHECK-NEXT: [[TUPLE:%.*]] = load [copy] [[TEMP]] // CHECK-NEXT: destructure_tuple // CHECK-NEXT: destructure_tuple +// CHECK-NEXT: destroy_addr [[TEMP]] // CHECK-NEXT: end_apply // CHECK-LABEL: } // end sil function '$s13read_accessor11TupleReaderV11useReadableyyF' func useReadable() { diff --git a/test/SILOptimizer/moveonly_addressonly_subscript_diagnostics.swift b/test/SILOptimizer/moveonly_addressonly_subscript_diagnostics.swift index 3c4cd85343037..212b5982900c2 100644 --- a/test/SILOptimizer/moveonly_addressonly_subscript_diagnostics.swift +++ b/test/SILOptimizer/moveonly_addressonly_subscript_diagnostics.swift @@ -83,9 +83,7 @@ public func testSubscriptGetOnlyThroughParentClass_BaseLoadable_ResultAddressOnl var m = LoadableSubscriptGetOnlyTesterClassParent() m = LoadableSubscriptGetOnlyTesterClassParent() m.tester[0].nonMutatingFunc() - // expected-error @-1 {{copy of noncopyable typed value}} m.testerParent.tester[0].nonMutatingFunc() - // expected-error @-1 {{copy of noncopyable typed value}} m.computedTester[0].nonMutatingFunc() } diff --git a/test/SILOptimizer/moveonly_loadable_subscript_diagnostics.swift b/test/SILOptimizer/moveonly_loadable_subscript_diagnostics.swift index bf58a32b330b9..17074c733e26e 100644 --- a/test/SILOptimizer/moveonly_loadable_subscript_diagnostics.swift +++ b/test/SILOptimizer/moveonly_loadable_subscript_diagnostics.swift @@ -80,9 +80,7 @@ public func testSubscriptGetOnlyThroughParentClass_BaseLoadable_ResultLoadable_V var m = LoadableSubscriptGetOnlyTesterClassParent() m = LoadableSubscriptGetOnlyTesterClassParent() m.tester[0].nonMutatingFunc() - // expected-error @-1 {{copy of noncopyable typed value}} m.testerParent.tester[0].nonMutatingFunc() - // expected-error @-1 {{copy of noncopyable typed value}} m.computedTester[0].nonMutatingFunc() } From df17f19979f3eeb347f119e014238731d48ff55b Mon Sep 17 00:00:00 2001 From: Amritpan Kaur Date: Mon, 31 Jul 2023 16:00:27 -0700 Subject: [PATCH 05/37] [ConstraintSystem] Set up key path root lookups. --- include/swift/Sema/ConstraintSystem.h | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/include/swift/Sema/ConstraintSystem.h b/include/swift/Sema/ConstraintSystem.h index 460e664238360..5227daff14c8b 100644 --- a/include/swift/Sema/ConstraintSystem.h +++ b/include/swift/Sema/ConstraintSystem.h @@ -3031,6 +3031,20 @@ class ConstraintSystem { return std::get<1>(result->second); return nullptr; } + + TypeVariableType *getKeyPathRootType(const KeyPathExpr *keyPath) const { + auto result = getKeyPathRootTypeIfAvailable(keyPath); + assert(result); + return result; + } + + TypeVariableType * + getKeyPathRootTypeIfAvailable(const KeyPathExpr *keyPath) const { + auto result = KeyPaths.find(keyPath); + if (result != KeyPaths.end()) + return std::get<0>(result->second); + return nullptr; + } TypeBase* getFavoredType(Expr *E) { assert(E != nullptr); From 1384ff0038d7c30bb2340eae3ae01987b90e589e Mon Sep 17 00:00:00 2001 From: Amritpan Kaur Date: Mon, 31 Jul 2023 15:57:21 -0700 Subject: [PATCH 06/37] [CSBinding] Allow inference to bind AnyKeyPath as a KeyPath that can be converted to AnyKeyPath later. --- include/swift/Sema/ConstraintLocator.h | 16 -------- .../swift/Sema/ConstraintLocatorPathElts.def | 2 +- lib/Sema/CSBindings.cpp | 39 +++++++++++-------- 3 files changed, 24 insertions(+), 33 deletions(-) diff --git a/include/swift/Sema/ConstraintLocator.h b/include/swift/Sema/ConstraintLocator.h index 909a09265f4b1..7089cf62f0ba5 100644 --- a/include/swift/Sema/ConstraintLocator.h +++ b/include/swift/Sema/ConstraintLocator.h @@ -1058,22 +1058,6 @@ class LocatorPathElt::ContextualType final : public StoredIntegerElement<1> { } }; -class LocatorPathElt::KeyPathType final - : public StoredPointerElement { -public: - KeyPathType(Type valueType) - : StoredPointerElement(PathElementKind::KeyPathType, - valueType.getPointer()) { - assert(valueType); - } - - Type getValueType() const { return getStoredPointer(); } - - static bool classof(const LocatorPathElt *elt) { - return elt->getKind() == PathElementKind::KeyPathType; - } -}; - class LocatorPathElt::ConstructorMemberType final : public StoredIntegerElement<1> { public: diff --git a/include/swift/Sema/ConstraintLocatorPathElts.def b/include/swift/Sema/ConstraintLocatorPathElts.def index c60640e114cd9..c412b998a827a 100644 --- a/include/swift/Sema/ConstraintLocatorPathElts.def +++ b/include/swift/Sema/ConstraintLocatorPathElts.def @@ -123,7 +123,7 @@ CUSTOM_LOCATOR_PATH_ELT(KeyPathDynamicMember) SIMPLE_LOCATOR_PATH_ELT(KeyPathRoot) /// The type of the key path expression. -CUSTOM_LOCATOR_PATH_ELT(KeyPathType) +SIMPLE_LOCATOR_PATH_ELT(KeyPathType) /// The value of a key path. SIMPLE_LOCATOR_PATH_ELT(KeyPathValue) diff --git a/lib/Sema/CSBindings.cpp b/lib/Sema/CSBindings.cpp index 9547f4c3a06cd..c2fe1e14f165f 100644 --- a/lib/Sema/CSBindings.cpp +++ b/lib/Sema/CSBindings.cpp @@ -1250,27 +1250,34 @@ PotentialBindings::inferFromRelational(Constraint *constraint) { return llvm::None; if (TypeVar->getImpl().isKeyPathType()) { - auto *BGT = type->lookThroughAllOptionalTypes()->getAs(); - if (!BGT || !isKnownKeyPathType(BGT)) + auto objectTy = type->lookThroughAllOptionalTypes(); + if (!isKnownKeyPathType(objectTy)) return llvm::None; - - // `PartialKeyPath` represents a type-erased version of `KeyPath`. + + auto &ctx = CS.getASTContext(); + auto *keyPathTypeLoc = TypeVar->getImpl().getLocator(); + auto *keyPath = castToExpr(keyPathTypeLoc->getAnchor()); + // `AnyKeyPath` and `PartialKeyPath` represent type-erased versions of + // `KeyPath`. // - // In situations where partial key path cannot be used directly i.e. - // passing an argument to a parameter represented by a partial key path, - // let's attempt a `KeyPath` binding which would then be converted to a - // partial key path since there is a subtype relationship between them. - if (BGT->isPartialKeyPath() && kind == AllowedBindingKind::Subtypes) { - auto &ctx = CS.getASTContext(); - auto *keyPathLoc = TypeVar->getImpl().getLocator(); - - auto rootTy = BGT->getGenericArgs()[0]; + // In situations where `AnyKeyPath` or `PartialKeyPath` cannot be used + // directly i.e. passing an argument to a parameter represented by a + // `AnyKeyPath` or `PartialKeyPath`, let's attempt a `KeyPath` binding which + // would then be converted to a `AnyKeyPath` or `PartialKeyPath` since there + // is a subtype relationship between them. + if (objectTy->isAnyKeyPath()) { + auto root = CS.getKeyPathRootType(keyPath); + auto value = CS.getKeyPathValueType(keyPath); + + type = BoundGenericType::get(ctx.getKeyPathDecl(), Type(), + {root, value}); + } else if (objectTy->isPartialKeyPath() && + kind == AllowedBindingKind::Subtypes) { + auto rootTy = objectTy->castTo()->getGenericArgs()[0]; // Since partial key path is an erased version of `KeyPath`, the value // type would never be used, which means that binding can use // type variable generated for a result of key path expression. - auto valueTy = - keyPathLoc->castLastElementTo() - .getValueType(); + auto valueTy = CS.getKeyPathValueType(keyPath); type = BoundGenericType::get(ctx.getKeyPathDecl(), Type(), {rootTy, valueTy}); From e8425bf4c83eb8f9242f928865f6493d5b15e20e Mon Sep 17 00:00:00 2001 From: Amritpan Kaur Date: Mon, 31 Jul 2023 16:07:32 -0700 Subject: [PATCH 07/37] [CSGen] Update LocatorPathElt::KeyPathType usage. --- lib/Sema/CSGen.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index bc4249f00e9bb..1a06ff7d53db3 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -3822,7 +3822,7 @@ namespace { // The type of key path depends on the overloads chosen for the key // path components. auto typeLoc = - CS.getConstraintLocator(locator, LocatorPathElt::KeyPathType(value)); + CS.getConstraintLocator(locator, LocatorPathElt::KeyPathType()); Type kpTy = CS.createTypeVariable(typeLoc, TVO_CanBindToNoEscape | TVO_CanBindToHole); From e2bac24023d7028168f7f10233ae4f8bd568a343 Mon Sep 17 00:00:00 2001 From: Pavel Yaskevich Date: Wed, 2 Aug 2023 10:30:41 -0700 Subject: [PATCH 08/37] [TypeChecker] Make sure that distributed actors always get "default" init Default initialization of stored properties doesn't play a role in default init synthesis for distributed actors. --- lib/Sema/CodeSynthesis.cpp | 6 ++++++ test/decl/protocol/special/DistributedActor.swift | 6 ------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/lib/Sema/CodeSynthesis.cpp b/lib/Sema/CodeSynthesis.cpp index 8dd7b5fd1a236..439ca84245564 100644 --- a/lib/Sema/CodeSynthesis.cpp +++ b/lib/Sema/CodeSynthesis.cpp @@ -1490,6 +1490,12 @@ HasDefaultInitRequest::evaluate(Evaluator &evaluator, if (hasUserDefinedDesignatedInit(evaluator, decl)) return false; + // Regardless of whether all of the properties are initialized or + // not distributed actors always get a special "default" init based + // on `id` and `actorSystem` synthesized properties. + if (decl->isDistributedActor()) + return true; + // We can only synthesize a default init if all the stored properties have an // initial value. return areAllStoredPropertiesDefaultInitializable(evaluator, decl); diff --git a/test/decl/protocol/special/DistributedActor.swift b/test/decl/protocol/special/DistributedActor.swift index cc141b5a7c908..4061801d9a136 100644 --- a/test/decl/protocol/special/DistributedActor.swift +++ b/test/decl/protocol/special/DistributedActor.swift @@ -35,10 +35,8 @@ extension DAP where ActorSystem.ActorID == String { } distributed actor D2 { - // expected-error@-1{{actor 'D2' has no initializers}} let actorSystem: String // expected-error@-1{{property 'actorSystem' cannot be defined explicitly, as it conflicts with distributed actor synthesized stored property}} - // expected-note@-2{{stored property 'actorSystem' without initial value prevents synthesized initializers}} } distributed actor D3 { @@ -49,14 +47,10 @@ distributed actor D3 { struct OtherActorIdentity: Sendable, Hashable, Codable {} distributed actor D4 { - // expected-error@-1{{actor 'D4' has no initializers}} - let actorSystem: String // expected-error@-1{{property 'actorSystem' cannot be defined explicitly, as it conflicts with distributed actor synthesized stored property}} - // expected-note@-2{{stored property 'actorSystem' without initial value prevents synthesized initializers}} let id: OtherActorIdentity // expected-error@-1{{property 'id' cannot be defined explicitly, as it conflicts with distributed actor synthesized stored property}} - // expected-note@-2{{stored property 'id' without initial value prevents synthesized initializers}} } protocol P1: DistributedActor { From c3d22762411fcb97655410d9e30ac55f9d995f5d Mon Sep 17 00:00:00 2001 From: Michael Gottesman Date: Mon, 24 Jul 2023 18:16:06 -0700 Subject: [PATCH 09/37] [silgen] Eliminate two more cases around subscripts where we were not borrowing. Also, the store_borrow work in the previous patch caused some additional issues to crop up. I fixed them in this PR and added some tests in the process. --- include/swift/SIL/SILInstruction.h | 10 ++ .../Utils/FieldSensitivePrunedLiveness.cpp | 5 + lib/SILGen/SILGenApply.cpp | 36 +++--- lib/SILGen/SILGenLValue.cpp | 2 + .../Mandatory/MoveOnlyAddressCheckerUtils.cpp | 43 ++++++- lib/SILOptimizer/Mandatory/MoveOnlyUtils.cpp | 3 + test/SILGen/moveonly.swift | 22 ++-- test/SILOptimizer/moveonly_addresschecker.sil | 15 +++ .../moveonly_addresschecker_diagnostics.sil | 115 +++++++++++++++++- ...ly_addressonly_subscript_diagnostics.swift | 1 - ...eonly_loadable_subscript_diagnostics.swift | 1 - 11 files changed, 215 insertions(+), 38 deletions(-) diff --git a/include/swift/SIL/SILInstruction.h b/include/swift/SIL/SILInstruction.h index 341a6ab9bd14e..0c1cc0e84600a 100644 --- a/include/swift/SIL/SILInstruction.h +++ b/include/swift/SIL/SILInstruction.h @@ -4392,8 +4392,18 @@ class StoreBorrowInst ArrayRef getAllOperands() const { return Operands.asArray(); } MutableArrayRef getAllOperands() { return Operands.asArray(); } + + using EndBorrowRange = + decltype(std::declval().getUsersOfType()); + + /// Return a range over all EndBorrow instructions for this BeginBorrow. + EndBorrowRange getEndBorrows() const; }; +inline auto StoreBorrowInst::getEndBorrows() const -> EndBorrowRange { + return getUsersOfType(); +} + /// Represents the end of a borrow scope of a value %val from a /// value or address %src. /// diff --git a/lib/SIL/Utils/FieldSensitivePrunedLiveness.cpp b/lib/SIL/Utils/FieldSensitivePrunedLiveness.cpp index baa3e424e883e..164d4bfc162c8 100644 --- a/lib/SIL/Utils/FieldSensitivePrunedLiveness.cpp +++ b/lib/SIL/Utils/FieldSensitivePrunedLiveness.cpp @@ -129,6 +129,11 @@ SubElementOffset::computeForAddress(SILValue projectionDerivedFromRoot, continue; } + if (auto *sbi = dyn_cast(projectionDerivedFromRoot)) { + projectionDerivedFromRoot = sbi->getDest(); + continue; + } + if (auto *m = dyn_cast( projectionDerivedFromRoot)) { projectionDerivedFromRoot = m->getOperand(); diff --git a/lib/SILGen/SILGenApply.cpp b/lib/SILGen/SILGenApply.cpp index d6d4a5ca7af5a..c2e3f81e7faae 100644 --- a/lib/SILGen/SILGenApply.cpp +++ b/lib/SILGen/SILGenApply.cpp @@ -3124,6 +3124,21 @@ Expr *ArgumentSource::findStorageReferenceExprForMoveOnly( if (kind == StorageReferenceOperationKind::Consume && !sawLoad) return nullptr; + // If we did not see a load or a subscript expr and our argExpr is a + // declref_expr, return nullptr. We have an object not something that will be + // in memory. This can happen with classes or with values captured by a + // closure. + // + // NOTE: If we see a member_ref_expr from a decl_ref_expr, we still process it + // since the declref_expr could be from a class. + if (!sawLoad && !subscriptExpr) { + if (auto *declRef = dyn_cast(argExpr)) { + assert(!declRef->getType()->is() && + "Shouldn't ever have an lvalue type here!"); + return nullptr; + } + } + auto result = ::findStorageReferenceExprForBorrow(argExpr); if (!result) @@ -3143,31 +3158,18 @@ Expr *ArgumentSource::findStorageReferenceExprForMoveOnly( } if (!storage) - return nullptr; + return nullptr; assert(type); SILType ty = SGF.getLoweredType(type->getWithoutSpecifierType()->getCanonicalType()); bool isMoveOnly = ty.isPureMoveOnly(); if (auto *pd = dyn_cast(storage)) { - isMoveOnly |= pd->getSpecifier() == ParamSpecifier::Borrowing; - isMoveOnly |= pd->getSpecifier() == ParamSpecifier::Consuming; + isMoveOnly |= pd->getSpecifier() == ParamSpecifier::Borrowing; + isMoveOnly |= pd->getSpecifier() == ParamSpecifier::Consuming; } if (!isMoveOnly) - return nullptr; - - // It makes sense to borrow any kind of storage we refer to at this stage, - // but SILGenLValue does not currently handle some kinds of references well. - // - // When rejecting to do the LValue-style borrow here, it'll end up going thru - // the RValue-style emission, after which the extra copy will get eliminated. - // - // If we did not see a LoadExpr around the argument expression, then only - // do the borrow if the storage is non-local. - // FIXME: I don't have a principled reason for why this matters and hope that - // we can fix the AST we're working with. - if (!sawLoad && storage->getDeclContext()->isLocalContext()) - return nullptr; + return nullptr; // Claim the value of this argument since we found a storage reference that // has a move only base. diff --git a/lib/SILGen/SILGenLValue.cpp b/lib/SILGen/SILGenLValue.cpp index 66bd8a3aff297..9b6c265a41567 100644 --- a/lib/SILGen/SILGenLValue.cpp +++ b/lib/SILGen/SILGenLValue.cpp @@ -2918,6 +2918,8 @@ class LLVM_LIBRARY_VISIBILITY SILGenBorrowedBaseVisitor return m->getType()->isPureMoveOnly() || m->getBase()->getType()->isPureMoveOnly(); } + if (auto *d = dyn_cast(e)) + return e->getType()->isPureMoveOnly(); return false; } diff --git a/lib/SILOptimizer/Mandatory/MoveOnlyAddressCheckerUtils.cpp b/lib/SILOptimizer/Mandatory/MoveOnlyAddressCheckerUtils.cpp index a3296d40f12dc..dae2f0ef191b3 100644 --- a/lib/SILOptimizer/Mandatory/MoveOnlyAddressCheckerUtils.cpp +++ b/lib/SILOptimizer/Mandatory/MoveOnlyAddressCheckerUtils.cpp @@ -280,6 +280,8 @@ using namespace swift; using namespace swift::siloptimizer; +#pragma clang optimize off + llvm::cl::opt DisableMoveOnlyAddressCheckerLifetimeExtension( "move-only-address-checker-disable-lifetime-extension", llvm::cl::init(false), @@ -342,7 +344,7 @@ static void convertMemoryReinitToInitForm(SILInstruction *memInst, break; } } - + // Insert a new debug_value instruction after the reinitialization, so that // the debugger knows that the variable is in a usable form again. insertDebugValueBefore(memInst->getNextInstruction(), debugVar, @@ -931,6 +933,7 @@ void UseState::initializeLiveness( // We begin by initializing all of our init uses. for (auto initInstAndValue : initInsts) { LLVM_DEBUG(llvm::dbgs() << "Found def: " << *initInstAndValue.first); + liveness.initializeDef(initInstAndValue.first, initInstAndValue.second); } @@ -944,7 +947,7 @@ void UseState::initializeLiveness( reinitInstAndValue.second); } } - + // Then check if our markedValue is from an argument that is in, // in_guaranteed, inout, or inout_aliasable, consider the marked address to be // the initialization point. @@ -1050,6 +1053,26 @@ void UseState::initializeLiveness( LLVM_DEBUG(llvm::dbgs() << "Liveness with just inits:\n"; liveness.print(llvm::dbgs())); + for (auto initInstAndValue : initInsts) { + // If our init inst is a store_borrow, treat the end_borrow as liveness + // uses. + // + // NOTE: We do not need to check for access scopes here since store_borrow + // can only apply to alloc_stack today. + if (auto *sbi = dyn_cast(initInstAndValue.first)) { + // We can only store_borrow if our mark_must_check is a + // no_consume_or_assign. + assert(address->getCheckKind() == + MarkMustCheckInst::CheckKind::NoConsumeOrAssign && + "store_borrow implies no_consume_or_assign since we cannot " + "consume a borrowed inited value"); + for (auto *ebi : sbi->getEndBorrows()) { + liveness.updateForUse(ebi, initInstAndValue.second, + false /*lifetime ending*/); + } + } + } + // Now at this point, we have defined all of our defs so we can start adding // uses to the liveness. for (auto reinitInstAndValue : reinitInsts) { @@ -1979,6 +2002,22 @@ bool GatherUsesVisitor::visitUse(Operand *op) { if (isa(user)) return true; + // This visitor looks through store_borrow instructions but does visit the + // end_borrow of the store_borrow. If we see such an end_borrow, register the + // store_borrow instead. Since we use sets, if we visit multiple end_borrows, + // we will only record the store_borrow once. + if (auto *ebi = dyn_cast(user)) { + if (auto *sbi = dyn_cast(ebi->getOperand())) { + LLVM_DEBUG(llvm::dbgs() << "Found store_borrow: " << *sbi); + auto leafRange = TypeTreeLeafTypeRange::get(op->get(), getRootAddress()); + if (!leafRange) + return false; + + useState.recordInitUse(user, op->get(), *leafRange); + return true; + } + } + if (auto *di = dyn_cast(user)) { // Save the debug_value if it is attached directly to this mark_must_check. // If the underlying storage we're checking is immutable, then the access diff --git a/lib/SILOptimizer/Mandatory/MoveOnlyUtils.cpp b/lib/SILOptimizer/Mandatory/MoveOnlyUtils.cpp index 7e12810e42503..172f07d0339b0 100644 --- a/lib/SILOptimizer/Mandatory/MoveOnlyUtils.cpp +++ b/lib/SILOptimizer/Mandatory/MoveOnlyUtils.cpp @@ -220,6 +220,9 @@ bool noncopyable::memInstMustInitialize(Operand *memOper) { case SILInstructionKind::Store##Name##Inst: \ return cast(memInst)->isInitializationOfDest(); #include "swift/AST/ReferenceStorage.def" + + case SILInstructionKind::StoreBorrowInst: + return true; } } diff --git a/test/SILGen/moveonly.swift b/test/SILGen/moveonly.swift index fa1f0f39e7b48..04129fc51560c 100644 --- a/test/SILGen/moveonly.swift +++ b/test/SILGen/moveonly.swift @@ -818,13 +818,10 @@ func enumSwitchTest1(_ e: borrowing EnumSwitchTests.E) { // // CHECK: [[GLOBAL:%.*]] = global_addr @$s8moveonly9letGlobalAA16NonTrivialStructVvp : // CHECK: [[MARKED_GLOBAL:%.*]] = mark_must_check [no_consume_or_assign] [[GLOBAL]] -// FIXME: this copy probably shouldn't be here when accessing through the letGlobal, but maybe it's cleaned up? -// CHECK: [[LOADED_VAL:%.*]] = load [copy] [[MARKED_GLOBAL]] : $*NonTrivialStruct -// CHECK: [[LOADED_BORROWED_VAL:%.*]] = begin_borrow [[LOADED_VAL]] -// CHECK: [[LOADED_GEP:%.*]] = struct_extract [[LOADED_BORROWED_VAL]] : $NonTrivialStruct, #NonTrivialStruct.nonTrivialStruct2 +// CHECK: [[LOADED_VAL:%.*]] = load_borrow [[MARKED_GLOBAL]] : $*NonTrivialStruct +// CHECK: [[LOADED_GEP:%.*]] = struct_extract [[LOADED_VAL]] : $NonTrivialStruct, #NonTrivialStruct.nonTrivialStruct2 // CHECK: apply {{%.*}}([[LOADED_GEP]]) -// CHECK: end_borrow [[LOADED_BORROWED_VAL]] -// CHECK: destroy_value [[LOADED_VAL]] +// CHECK: end_borrow [[LOADED_VAL]] // CHECK: } // end sil function '$s8moveonly16testGlobalBorrowyyF' func testGlobalBorrow() { borrowVal(varGlobal) @@ -856,13 +853,11 @@ func testGlobalBorrow() { // // CHECK: [[GLOBAL:%.*]] = global_addr @$s8moveonly9letGlobalAA16NonTrivialStructVvp : // CHECK: [[MARKED_GLOBAL:%.*]] = mark_must_check [no_consume_or_assign] [[GLOBAL]] -// CHECK: [[LOADED_VAL:%.*]] = load [copy] [[MARKED_GLOBAL]] -// CHECK: [[LOADED_BORROWED_VAL:%.*]] = begin_borrow [[LOADED_VAL]] -// CHECK: [[LOADED_GEP:%.*]] = struct_extract [[LOADED_BORROWED_VAL]] +// CHECK: [[LOADED_VAL:%.*]] = load_borrow [[MARKED_GLOBAL]] +// CHECK: [[LOADED_GEP:%.*]] = struct_extract [[LOADED_VAL]] // CHECK: [[LOADED_GEP_COPY:%.*]] = copy_value [[LOADED_GEP]] -// CHECK: end_borrow [[LOADED_BORROWED_VAL]] -// CHECK: destroy_value [[LOADED_VAL]] // CHECK: apply {{%.*}}([[LOADED_GEP_COPY]]) +// CHECK: end_borrow [[LOADED_VAL]] // // CHECK: } // end sil function '$s8moveonly17testGlobalConsumeyyF' func testGlobalConsume() { @@ -1833,11 +1828,8 @@ public func testSubscriptReadModify_BaseLoadable_ResultAddressOnly_Var() { // CHECK: [[MARK:%.*]] = mark_must_check [no_consume_or_assign] [[PROJECT]] // CHECK: [[LOAD_BORROW:%.*]] = load_borrow [[MARK]] // CHECK: ([[CORO_RESULT:%.*]], [[CORO_TOKEN:%.*]]) = begin_apply {{%.*}}({{%.*}}, [[LOAD_BORROW]]) -// CHECK: [[TEMP:%.*]] = alloc_stack $AddressOnlyProtocol -// CHECK: copy_addr [[CORO_RESULT]] to [init] [[TEMP]] +// CHECK: apply {{%.*}}([[CORO_RESULT]]) // CHECK: end_apply [[CORO_TOKEN]] -// CHECK: apply {{%.*}}([[TEMP]]) -// CHECK: destroy_addr [[TEMP]] // CHECK: end_borrow [[LOAD_BORROW]] // CHECK: } // end sil function '$s8moveonly58testSubscriptReadModify_BaseLoadable_ResultAddressOnly_LetyyF' public func testSubscriptReadModify_BaseLoadable_ResultAddressOnly_Let() { diff --git a/test/SILOptimizer/moveonly_addresschecker.sil b/test/SILOptimizer/moveonly_addresschecker.sil index ccc42c621f880..4feb4dbd1c42f 100644 --- a/test/SILOptimizer/moveonly_addresschecker.sil +++ b/test/SILOptimizer/moveonly_addresschecker.sil @@ -844,3 +844,18 @@ bb0(%0 : $Int, %1a : $*NonCopyableNativeObjectPair): %16 = tuple () return %16 : $() } + +sil [ossa] @testSupportStoreBorrow : $@convention(thin) (@guaranteed NonTrivialStruct) -> () { +bb0(%0 : @guaranteed $NonTrivialStruct): + %1 = alloc_stack $NonTrivialStruct + %1a = mark_must_check [no_consume_or_assign] %1 : $*NonTrivialStruct + %borrow = store_borrow %0 to %1a : $*NonTrivialStruct + %f = function_ref @useNonTrivialStruct : $@convention(thin) (@guaranteed NonTrivialStruct) -> () + %l = load_borrow %borrow : $*NonTrivialStruct + apply %f(%l) : $@convention(thin) (@guaranteed NonTrivialStruct) -> () + end_borrow %l : $NonTrivialStruct + end_borrow %borrow : $*NonTrivialStruct + dealloc_stack %1 : $*NonTrivialStruct + %9999 = tuple () + return %9999 : $() +} diff --git a/test/SILOptimizer/moveonly_addresschecker_diagnostics.sil b/test/SILOptimizer/moveonly_addresschecker_diagnostics.sil index 2e0f2a9348847..5490a9caefc27 100644 --- a/test/SILOptimizer/moveonly_addresschecker_diagnostics.sil +++ b/test/SILOptimizer/moveonly_addresschecker_diagnostics.sil @@ -1,8 +1,6 @@ // RUN: %target-sil-opt -sil-move-only-address-checker -enable-experimental-feature MoveOnlyPartialConsumption -enable-experimental-feature MoveOnlyClasses -enable-sil-verify-all %s -verify // RUN: %target-sil-opt -sil-move-only-address-checker -enable-experimental-feature MoveOnlyPartialConsumption -enable-experimental-feature MoveOnlyClasses -enable-sil-verify-all -move-only-diagnostics-silently-emit-diagnostics %s | %FileCheck %s -// TODO: Add FileCheck - // This file contains specific SIL test cases that we expect to emit // diagnostics. These are cases where we want to make it easy to validate // independent of potential changes in the frontend's emission that this @@ -77,6 +75,7 @@ sil @get_aggstruct : $@convention(thin) () -> @owned AggStruct sil @nonConsumingUseKlass : $@convention(thin) (@guaranteed Klass) -> () sil @nonConsumingUseNonTrivialStruct : $@convention(thin) (@guaranteed NonTrivialStruct) -> () sil @consumingUseNonTrivialStruct : $@convention(thin) (@owned NonTrivialStruct) -> () +sil @inUseNonTrivialStruct : $@convention(thin) (@in NonTrivialStruct) -> () sil @classConsume : $@convention(thin) (@owned Klass) -> () // user: %18 sil @copyableClassConsume : $@convention(thin) (@owned CopyableKlass) -> () // user: %24 sil @copyableClassUseMoveOnlyWithoutEscaping : $@convention(thin) (@guaranteed CopyableKlass) -> () // user: %16 @@ -529,3 +528,115 @@ bb0(%0 : @owned $NonTrivialStruct): %9999 = tuple() return %9999 : $() } + +// CHECK-LABEL: sil [ossa] @testSupportStoreBorrow1 : $@convention(thin) (@guaranteed NonTrivialStruct) -> () { +// CHECK: bb0([[ARG:%.*]] : @guaranteed +// CHECK: [[STACK:%.*]] = alloc_stack +// CHECK: [[BORROW:%.*]] = store_borrow [[ARG]] to [[STACK]] +// CHECK: [[LOAD:%.*]] = load_borrow [[BORROW]] +// CHECK: apply {{%.*}}([[LOAD]]) +// CHECK: end_borrow [[LOAD]] +// CHECK: end_borrow [[BORROW]] +// CHECK: dealloc_stack [[STACK]] +// CHECK: } // end sil function 'testSupportStoreBorrow1' +sil [ossa] @testSupportStoreBorrow1 : $@convention(thin) (@guaranteed NonTrivialStruct) -> () { +bb0(%0 : @guaranteed $NonTrivialStruct): + %1 = alloc_stack $NonTrivialStruct + %1a = mark_must_check [no_consume_or_assign] %1 : $*NonTrivialStruct + %borrow = store_borrow %0 to %1a : $*NonTrivialStruct + %f = function_ref @nonConsumingUseNonTrivialStruct : $@convention(thin) (@guaranteed NonTrivialStruct) -> () + %l = load_borrow %borrow : $*NonTrivialStruct + apply %f(%l) : $@convention(thin) (@guaranteed NonTrivialStruct) -> () + end_borrow %l : $NonTrivialStruct + end_borrow %borrow : $*NonTrivialStruct + dealloc_stack %1 : $*NonTrivialStruct + %9999 = tuple () + return %9999 : $() +} + +// CHECK-LABEL: sil [ossa] @testSupportStoreBorrow2 : $@convention(thin) (@guaranteed NonTrivialStruct) -> () { +// CHECK: bb0([[ARG:%.*]] : @guaranteed +// CHECK: [[STACK:%.*]] = alloc_stack +// CHECK: [[BORROW:%.*]] = store_borrow [[ARG]] to [[STACK]] +// CHECK: [[LOAD:%.*]] = load_borrow [[BORROW]] +// CHECK: apply {{%.*}}([[LOAD]]) +// CHECK: end_borrow [[LOAD]] +// CHECK: end_borrow [[BORROW]] +// CHECK: dealloc_stack [[STACK]] +// CHECK: } // end sil function 'testSupportStoreBorrow2' +sil [ossa] @testSupportStoreBorrow2 : $@convention(thin) (@guaranteed NonTrivialStruct) -> () { +bb0(%0 : @guaranteed $NonTrivialStruct): + %1 = alloc_stack $NonTrivialStruct + %1a = mark_must_check [no_consume_or_assign] %1 : $*NonTrivialStruct + %borrow = store_borrow %0 to %1a : $*NonTrivialStruct + %f = function_ref @nonConsumingUseNonTrivialStruct : $@convention(thin) (@guaranteed NonTrivialStruct) -> () + %l = load [copy] %borrow : $*NonTrivialStruct + apply %f(%l) : $@convention(thin) (@guaranteed NonTrivialStruct) -> () + destroy_value %l : $NonTrivialStruct + end_borrow %borrow : $*NonTrivialStruct + dealloc_stack %1 : $*NonTrivialStruct + %9999 = tuple () + return %9999 : $() +} + +// CHECK-LABEL: sil [ossa] @testSupportStoreBorrow3 : $@convention(thin) (@guaranteed NonTrivialStruct) -> () { +// CHECK: bb0([[ARG:%.*]] : @guaranteed +// CHECK: [[STACK:%.*]] = alloc_stack +// CHECK: [[BORROW:%.*]] = store_borrow [[ARG]] to [[STACK]] +// CHECK: [[LOAD:%.*]] = load_borrow [[BORROW]] +// CHECK: apply {{%.*}}([[LOAD]]) +// CHECK: end_borrow [[LOAD]] +// CHECK: end_borrow [[BORROW]] +// CHECK: dealloc_stack [[STACK]] +// CHECK: } // end sil function 'testSupportStoreBorrow3' +sil [ossa] @testSupportStoreBorrow3 : $@convention(thin) (@guaranteed NonTrivialStruct) -> () { +bb0(%0 : @guaranteed $NonTrivialStruct): + %1 = alloc_stack $NonTrivialStruct + %1a = mark_must_check [no_consume_or_assign] %1 : $*NonTrivialStruct + %borrow = store_borrow %0 to %1a : $*NonTrivialStruct + %f = function_ref @nonConsumingUseNonTrivialStruct : $@convention(thin) (@guaranteed NonTrivialStruct) -> () + %l = load_borrow %borrow : $*NonTrivialStruct + %l2 = copy_value %l : $NonTrivialStruct + apply %f(%l2) : $@convention(thin) (@guaranteed NonTrivialStruct) -> () + destroy_value %l2 : $NonTrivialStruct + end_borrow %l : $NonTrivialStruct + end_borrow %borrow : $*NonTrivialStruct + dealloc_stack %1 : $*NonTrivialStruct + %9999 = tuple () + return %9999 : $() +} + +sil [ossa] @testSupportStoreBorrow4 : $@convention(thin) (@guaranteed NonTrivialStruct) -> () { +bb0(%0 : @guaranteed $NonTrivialStruct): + %1 = alloc_stack $NonTrivialStruct + %1a = mark_must_check [no_consume_or_assign] %1 : $*NonTrivialStruct + // expected-error @-1 {{noncopyable 'unknown' cannot be consumed when captured by an escaping closure}} + %borrow = store_borrow %0 to %1a : $*NonTrivialStruct + %f = function_ref @consumingUseNonTrivialStruct : $@convention(thin) (@owned NonTrivialStruct) -> () + %l = load_borrow %borrow : $*NonTrivialStruct + %l2 = copy_value %l : $NonTrivialStruct + apply %f(%l2) : $@convention(thin) (@owned NonTrivialStruct) -> () + end_borrow %l : $NonTrivialStruct + end_borrow %borrow : $*NonTrivialStruct + dealloc_stack %1 : $*NonTrivialStruct + %9999 = tuple () + return %9999 : $() +} + +sil [ossa] @testSupportStoreBorrow5 : $@convention(thin) (@guaranteed NonTrivialStruct) -> () { +bb0(%0 : @guaranteed $NonTrivialStruct): + %1 = alloc_stack $NonTrivialStruct + %1a = mark_must_check [no_consume_or_assign] %1 : $*NonTrivialStruct + // expected-error @-1 {{'unknown' is borrowed and cannot be consumed}} + %borrow = store_borrow %0 to %1a : $*NonTrivialStruct + %f = function_ref @consumingUseNonTrivialStruct : $@convention(thin) (@owned NonTrivialStruct) -> () + %a = alloc_stack $NonTrivialStruct + copy_addr %borrow to [init] %a : $*NonTrivialStruct + // expected-note @-1 {{consumed here}} + destroy_addr %a : $*NonTrivialStruct + dealloc_stack %a : $*NonTrivialStruct + end_borrow %borrow : $*NonTrivialStruct + dealloc_stack %1 : $*NonTrivialStruct + %9999 = tuple () + return %9999 : $() +} diff --git a/test/SILOptimizer/moveonly_addressonly_subscript_diagnostics.swift b/test/SILOptimizer/moveonly_addressonly_subscript_diagnostics.swift index 212b5982900c2..923f76abeeb07 100644 --- a/test/SILOptimizer/moveonly_addressonly_subscript_diagnostics.swift +++ b/test/SILOptimizer/moveonly_addressonly_subscript_diagnostics.swift @@ -203,7 +203,6 @@ public func testSubscriptReadModify_BaseLoadable_ResultAddressOnly_Var() { public func testSubscriptReadModify_BaseLoadable_ResultAddressOnly_Let() { let m = LoadableSubscriptReadModifyTester() m[0].nonMutatingFunc() - // expected-error @-1 {{copy of noncopyable typed value}} } public func testSubscriptReadModify_BaseLoadable_ResultAddressOnly_InOut(m: inout LoadableSubscriptReadModifyTester) { diff --git a/test/SILOptimizer/moveonly_loadable_subscript_diagnostics.swift b/test/SILOptimizer/moveonly_loadable_subscript_diagnostics.swift index 17074c733e26e..ede31976475de 100644 --- a/test/SILOptimizer/moveonly_loadable_subscript_diagnostics.swift +++ b/test/SILOptimizer/moveonly_loadable_subscript_diagnostics.swift @@ -200,7 +200,6 @@ public func testSubscriptReadModify_BaseLoadable_ResultLoadable_Var() { public func testSubscriptReadModify_BaseLoadable_ResultLoadable_Let() { let m = LoadableSubscriptReadModifyTester() m[0].nonMutatingFunc() - // expected-error @-1 {{copy of noncopyable typed value}} } public func testSubscriptReadModify_BaseLoadable_ResultLoadable_InOut(m: inout LoadableSubscriptReadModifyTester) { From 79935f9720cf7b997c60e5e0d3e01f0b8a05e35c Mon Sep 17 00:00:00 2001 From: Tony Allevato Date: Wed, 2 Aug 2023 14:22:32 -0400 Subject: [PATCH 10/37] Add explicit ctors for aggregation for types that default or delete ctors. In C++20, types that declare or delete their default/copy/move constructors are no longer aggregates, so the aggregate uses of these types will not compile under C++20. Adding them fixes this, without affecting older language modes. --- include/swift/Driver/Compilation.h | 5 +++++ lib/AST/CASTBridging.cpp | 4 ++++ lib/Frontend/DependencyVerifier.cpp | 3 +++ 3 files changed, 12 insertions(+) diff --git a/include/swift/Driver/Compilation.h b/include/swift/Driver/Compilation.h index c0ea0ed65046a..7c94454d401d6 100644 --- a/include/swift/Driver/Compilation.h +++ b/include/swift/Driver/Compilation.h @@ -89,6 +89,11 @@ class Compilation { /// This data is used for cross-module module dependencies. fine_grained_dependencies::ModuleDepGraph depGraph; + Result(bool hadAbnormalExit, int exitCode, + fine_grained_dependencies::ModuleDepGraph depGraph) + : hadAbnormalExit(hadAbnormalExit), exitCode(exitCode), + depGraph(depGraph) {} + Result(const Result &) = delete; Result &operator=(const Result &) = delete; diff --git a/lib/AST/CASTBridging.cpp b/lib/AST/CASTBridging.cpp index 5b30cb1df8660..50215828c43bf 100644 --- a/lib/AST/CASTBridging.cpp +++ b/lib/AST/CASTBridging.cpp @@ -20,6 +20,10 @@ struct BridgedDiagnosticImpl { InFlightDiagnostic inFlight; std::vector textBlobs; + BridgedDiagnosticImpl(InFlightDiagnostic inFlight, + std::vector textBlobs) + : inFlight(std::move(inFlight)), textBlobs(std::move(textBlobs)) {} + BridgedDiagnosticImpl(const BridgedDiagnosticImpl &) = delete; BridgedDiagnosticImpl(BridgedDiagnosticImpl &&) = delete; BridgedDiagnosticImpl &operator=(const BridgedDiagnosticImpl &) = delete; diff --git a/lib/Frontend/DependencyVerifier.cpp b/lib/Frontend/DependencyVerifier.cpp index db40b0a00b4c0..81f4396412378 100644 --- a/lib/Frontend/DependencyVerifier.cpp +++ b/lib/Frontend/DependencyVerifier.cpp @@ -148,6 +148,9 @@ struct Obligation { public: Key() = delete; + private: + Key(StringRef Name, Expectation::Kind Kind) : Name(Name), Kind(Kind) {} + public: static Key forNegative(StringRef name) { return Key{name, Expectation::Kind::Negative}; From c71c1e193bbfb2f814d4eb2b895c0bb517ec0a0f Mon Sep 17 00:00:00 2001 From: Tony Allevato Date: Wed, 2 Aug 2023 15:03:48 -0400 Subject: [PATCH 11/37] Ensure types used as `std::vector` elements are complete. These were never allowed, but with C++20 making more `vector` functions `constexpr`, they would start causing build failures in that language mode. --- .../SILOptimizer/PassManager/PassPipeline.h | 45 +++++++++---------- lib/IRGen/Outlining.h | 2 +- 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/include/swift/SILOptimizer/PassManager/PassPipeline.h b/include/swift/SILOptimizer/PassManager/PassPipeline.h index e9ecc77e7a078..d6c963c2c7612 100644 --- a/include/swift/SILOptimizer/PassManager/PassPipeline.h +++ b/include/swift/SILOptimizer/PassManager/PassPipeline.h @@ -33,7 +33,28 @@ namespace swift { class SILPassPipelinePlan; -struct SILPassPipeline; + +struct SILPassPipeline final { + unsigned ID; + StringRef Name; + unsigned KindOffset; + bool isFunctionPassPipeline; + + friend bool operator==(const SILPassPipeline &lhs, + const SILPassPipeline &rhs) { + return lhs.ID == rhs.ID && lhs.Name.equals(rhs.Name) && + lhs.KindOffset == rhs.KindOffset; + } + + friend bool operator!=(const SILPassPipeline &lhs, + const SILPassPipeline &rhs) { + return !(lhs == rhs); + } + + friend llvm::hash_code hash_value(const SILPassPipeline &pipeline) { + return llvm::hash_combine(pipeline.ID, pipeline.Name, pipeline.KindOffset); + } +}; enum class PassPipelineKind { #define PASSPIPELINE(NAME, DESCRIPTION) NAME, @@ -123,28 +144,6 @@ class SILPassPipelinePlan final { } }; -struct SILPassPipeline final { - unsigned ID; - StringRef Name; - unsigned KindOffset; - bool isFunctionPassPipeline; - - friend bool operator==(const SILPassPipeline &lhs, - const SILPassPipeline &rhs) { - return lhs.ID == rhs.ID && lhs.Name.equals(rhs.Name) && - lhs.KindOffset == rhs.KindOffset; - } - - friend bool operator!=(const SILPassPipeline &lhs, - const SILPassPipeline &rhs) { - return !(lhs == rhs); - } - - friend llvm::hash_code hash_value(const SILPassPipeline &pipeline) { - return llvm::hash_combine(pipeline.ID, pipeline.Name, pipeline.KindOffset); - } -}; - inline void SILPassPipelinePlan:: startPipeline(StringRef Name, bool isFunctionPassPipeline) { PipelineStages.push_back(SILPassPipeline{ diff --git a/lib/IRGen/Outlining.h b/lib/IRGen/Outlining.h index 8cda830d97401..97433fff2390b 100644 --- a/lib/IRGen/Outlining.h +++ b/lib/IRGen/Outlining.h @@ -17,6 +17,7 @@ #ifndef SWIFT_IRGEN_OUTLINING_H #define SWIFT_IRGEN_OUTLINING_H +#include "LocalTypeDataKind.h" #include "swift/Basic/LLVM.h" #include "llvm/ADT/MapVector.h" @@ -37,7 +38,6 @@ class Address; class Explosion; class IRGenFunction; class IRGenModule; -class LocalTypeDataKey; class TypeInfo; /// A helper class for emitting outlined value operations. From 52e216e6ebad13065485c822f1d3dbe7f27fb897 Mon Sep 17 00:00:00 2001 From: Tony Allevato Date: Wed, 2 Aug 2023 15:43:15 -0400 Subject: [PATCH 12/37] Force `LinkEntity` to be zero-initialized to avoid use-of-uninitialized-value problems. --- include/swift/IRGen/Linking.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/swift/IRGen/Linking.h b/include/swift/IRGen/Linking.h index 87b7d69d22c75..cfec436b1862f 100644 --- a/include/swift/IRGen/Linking.h +++ b/include/swift/IRGen/Linking.h @@ -680,7 +680,7 @@ class LinkEntity { Data = LINKENTITY_SET_FIELD(Kind, unsigned(kind)); } - LinkEntity() = default; + LinkEntity() : Pointer(nullptr), SecondaryPointer(nullptr), Data(0) {} static bool isValidResilientMethodRef(SILDeclRef declRef) { if (declRef.isForeign) From b5d3e0b6f705f9e73a9ca504c06eb5adf0f026a7 Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Mon, 31 Jul 2023 14:50:15 -0700 Subject: [PATCH 13/37] [Type refinement context] Don't query property wrappers just for range info Querying property wrappers involves semantic analysis that can cause cyclic references while building the type refinement context, and it's unnecessary: we need only know that these are custom attributes to incorporate their source ranges. Switch to the simpler/cheaper query. A small part of fixing the cyclic references in rdar://112079160. --- lib/Sema/TypeCheckAvailability.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/lib/Sema/TypeCheckAvailability.cpp b/lib/Sema/TypeCheckAvailability.cpp index 5c35966d5caa2..7b9bc64e60630 100644 --- a/lib/Sema/TypeCheckAvailability.cpp +++ b/lib/Sema/TypeCheckAvailability.cpp @@ -679,13 +679,16 @@ class TypeRefinementContextBuilder : private ASTWalker { // For a variable declaration (without accessors) we use the range of the // containing pattern binding declaration to make sure that we include // any type annotation in the type refinement context range. We also - // need to include any attached property wrappers. + // need to include any custom attributes that were written on the + // declaration. if (auto *varDecl = dyn_cast(storageDecl)) { if (auto *PBD = varDecl->getParentPatternBinding()) Range = PBD->getSourceRange(); - for (auto *propertyWrapper : varDecl->getAttachedPropertyWrappers()) { - Range.widen(propertyWrapper->getRange()); + for (auto attr : varDecl->getOriginalAttrs()) { + if (auto customAttr = dyn_cast(attr)) { + Range.widen(customAttr->getRange()); + } } } @@ -696,7 +699,7 @@ class TypeRefinementContextBuilder : private ASTWalker { // locations and have callers of that method provide appropriate source // locations. SourceRange BracesRange = storageDecl->getBracesRange(); - if (storageDecl->hasParsedAccessors() && BracesRange.isValid()) { + if (BracesRange.isValid()) { Range.widen(BracesRange); } From 3079f3d0748eefa034f725ffb9c57038f298c175 Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Mon, 31 Jul 2023 22:43:01 -0700 Subject: [PATCH 14/37] [Type refinement context] Lazily expand TRCs for pattern bindings Eager expansion of type refinement contexts (TRCs) for variables within pattern binding declarations is causing cyclic references in some places involving macros. Make this expansion lazy, triggered by walking into these pattern binding declarations as part of (e.g.) availability queries. Another step toward fixing the cyclic references in rdar://112079160. --- include/swift/AST/TypeCheckRequests.h | 20 +++ include/swift/AST/TypeCheckerTypeIDZone.def | 3 + include/swift/AST/TypeRefinementContext.h | 3 + lib/AST/TypeRefinementContext.cpp | 18 +++ lib/Sema/TypeCheckAvailability.cpp | 164 +++++++++++++------- 5 files changed, 152 insertions(+), 56 deletions(-) diff --git a/include/swift/AST/TypeCheckRequests.h b/include/swift/AST/TypeCheckRequests.h index cf65ce8111e18..d035263fa78b0 100644 --- a/include/swift/AST/TypeCheckRequests.h +++ b/include/swift/AST/TypeCheckRequests.h @@ -63,6 +63,7 @@ class TypeAliasDecl; class TypeLoc; class Witness; class TypeResolution; +class TypeRefinementContext; struct TypeWitnessAndDecl; class ValueDecl; enum class OpaqueReadOwnership: uint8_t; @@ -4433,6 +4434,25 @@ class InitAccessorReferencedVariablesRequest bool isCached() const { return true; } }; +/// Expand the children of the type refinement context for the given +/// declaration. +class ExpandChildTypeRefinementContextsRequest + : public SimpleRequest { +public: + using SimpleRequest::SimpleRequest; + +private: + friend SimpleRequest; + + bool evaluate(Evaluator &evaluator, Decl *decl, + TypeRefinementContext *parentTRC) const; + +public: + bool isCached() const { return true; } +}; + #define SWIFT_TYPEID_ZONE TypeChecker #define SWIFT_TYPEID_HEADER "swift/AST/TypeCheckerTypeIDZone.def" #include "swift/Basic/DefineTypeIDZone.h" diff --git a/include/swift/AST/TypeCheckerTypeIDZone.def b/include/swift/AST/TypeCheckerTypeIDZone.def index e26034a3df926..0d322e7501c01 100644 --- a/include/swift/AST/TypeCheckerTypeIDZone.def +++ b/include/swift/AST/TypeCheckerTypeIDZone.def @@ -506,3 +506,6 @@ SWIFT_REQUEST(TypeChecker, InitAccessorReferencedVariablesRequest, ArrayRef(DeclAttribute *, AccessorDecl *, ArrayRef), Cached, NoLocationInfo) +SWIFT_REQUEST(TypeChecker, ExpandChildTypeRefinementContextsRequest, + bool(Decl *, TypeRefinementContext *), + Cached, NoLocationInfo) diff --git a/include/swift/AST/TypeRefinementContext.h b/include/swift/AST/TypeRefinementContext.h index dbe492476b8b7..22e78af500955 100644 --- a/include/swift/AST/TypeRefinementContext.h +++ b/include/swift/AST/TypeRefinementContext.h @@ -297,6 +297,9 @@ class TypeRefinementContext : public ASTAllocated { static StringRef getReasonName(Reason R); }; +void simple_display(llvm::raw_ostream &out, + const TypeRefinementContext *trc); + } // end namespace swift #endif diff --git a/lib/AST/TypeRefinementContext.cpp b/lib/AST/TypeRefinementContext.cpp index 99e84e26b6c55..d349b1114971f 100644 --- a/lib/AST/TypeRefinementContext.cpp +++ b/lib/AST/TypeRefinementContext.cpp @@ -20,6 +20,7 @@ #include "swift/AST/Stmt.h" #include "swift/AST/Expr.h" #include "swift/AST/SourceFile.h" +#include "swift/AST/TypeCheckRequests.h" #include "swift/AST/TypeRefinementContext.h" #include "swift/Basic/SourceManager.h" @@ -192,6 +193,18 @@ TypeRefinementContext::findMostRefinedSubContext(SourceLoc Loc, !rangeContainsTokenLocWithGeneratedSource(SM, SrcRange, Loc)) return nullptr; + // If this context is for a declaration, ensure that we've expanded the + // children of the declaration. + if (Node.getReason() == Reason::Decl || + Node.getReason() == Reason::DeclImplicit) { + if (auto decl = Node.getAsDecl()) { + ASTContext &ctx = decl->getASTContext(); + (void)evaluateOrDefault( + ctx.evaluator, ExpandChildTypeRefinementContextsRequest{decl, this}, + false); + } + } + // For the moment, we perform a linear search here, but we can and should // do something more efficient. for (TypeRefinementContext *Child : Children) { @@ -411,3 +424,8 @@ StringRef TypeRefinementContext::getReasonName(Reason R) { llvm_unreachable("Unhandled Reason in switch."); } + +void swift::simple_display( + llvm::raw_ostream &out, const TypeRefinementContext *trc) { + out << "TRC @" << trc; +} diff --git a/lib/Sema/TypeCheckAvailability.cpp b/lib/Sema/TypeCheckAvailability.cpp index 7b9bc64e60630..a6e166f42fb62 100644 --- a/lib/Sema/TypeCheckAvailability.cpp +++ b/lib/Sema/TypeCheckAvailability.cpp @@ -440,6 +440,8 @@ class TypeRefinementContextBuilder : private ASTWalker { return "building type refinement context for"; } + friend class swift::ExpandChildTypeRefinementContextsRequest; + public: TypeRefinementContextBuilder(TypeRefinementContext *TRC, ASTContext &Context) : Context(Context) { @@ -576,9 +578,10 @@ class TypeRefinementContextBuilder : private ASTWalker { // pattern. The context necessary for the pattern as a whole was already // introduced if necessary by the first var decl. if (auto *VD = dyn_cast(D)) { - if (auto *PBD = VD->getParentPatternBinding()) + if (auto *PBD = VD->getParentPatternBinding()) { if (VD != PBD->getAnchoringVarDecl(0)) return nullptr; + } } // Declarations with an explicit availability attribute always get a TRC. @@ -597,11 +600,16 @@ class TypeRefinementContextBuilder : private ASTWalker { // internal property in a public struct can be effectively less available // than the containing struct decl because the internal property will only // be accessed by code running at the deployment target or later. + // + // For declarations that have their child construction delayed, always + // create this implicit declaration context. It will be used to trigger + // lazy creation of the child TRCs. AvailabilityContext CurrentAvailability = getCurrentTRC()->getAvailabilityInfo(); AvailabilityContext EffectiveAvailability = getEffectiveAvailabilityForDeclSignature(D, CurrentAvailability); - if (CurrentAvailability.isSupersetOf(EffectiveAvailability)) + if (isa(D) || + CurrentAvailability.isSupersetOf(EffectiveAvailability)) return TypeRefinementContext::createForDeclImplicit( Context, D, getCurrentTRC(), EffectiveAvailability, refinementSourceRangeForDecl(D)); @@ -709,28 +717,79 @@ class TypeRefinementContextBuilder : private ASTWalker { return D->getSourceRange(); } + // Creates an implicit decl TRC specifying the deployment + // target for `range` in decl `D`. + TypeRefinementContext * + createImplicitDeclContextForDeploymentTarget(Decl *D, SourceRange range){ + AvailabilityContext Availability = + AvailabilityContext::forDeploymentTarget(Context); + Availability.intersectWith(getCurrentTRC()->getAvailabilityInfo()); + + return TypeRefinementContext::createForDeclImplicit( + Context, D, getCurrentTRC(), Availability, range); + } + + /// Build contexts for a VarDecl with the given initializer. + void buildContextsForPatternBindingDecl(PatternBindingDecl *pattern) { + // Build contexts for each of the pattern entries. + for (unsigned index : range(pattern->getNumPatternEntries())) { + auto var = pattern->getAnchoringVarDecl(index); + if (!var) + continue; + + // Var decls may have associated pattern binding decls or property wrappers + // with init expressions. Those expressions need to be constrained to the + // deployment target unless they are exposed to clients. + if (!var->hasInitialValue() || var->isInitExposedToClients()) + continue; + + auto *initExpr = pattern->getInit(index); + if (initExpr && !initExpr->isImplicit()) { + assert(initExpr->getSourceRange().isValid()); + + // Create a TRC for the init written in the source. The ASTWalker + // won't visit these expressions so instead of pushing these onto the + // stack we build them directly. + auto *initTRC = createImplicitDeclContextForDeploymentTarget( + var, initExpr->getSourceRange()); + TypeRefinementContextBuilder(initTRC, Context).build(initExpr); + } + } + + // Ideally any init expression would be returned by `getInit()` above. + // However, for property wrappers it doesn't get populated until + // typechecking completes (which is too late). Instead, we find the + // the property wrapper attribute and use its source range to create a + // TRC for the initializer expression. + // + // FIXME: Since we don't have an expression here, we can't build out its + // TRC. If the Expr that will eventually be created contains a closure + // expression, then it might have AST nodes that need to be refined. For + // example, property wrapper initializers that takes block arguments + // are not handled correctly because of this (rdar://77841331). + if (auto firstVar = pattern->getAnchoringVarDecl(0)) { + if (firstVar->hasInitialValue() && !firstVar->isInitExposedToClients()) { + for (auto *wrapper : firstVar->getAttachedPropertyWrappers()) { + createImplicitDeclContextForDeploymentTarget( + firstVar, wrapper->getRange()); + } + } + } + } + void buildContextsForBodyOfDecl(Decl *D) { // Are we already constrained by the deployment target? If not, adding // new contexts won't change availability. if (isCurrentTRCContainedByDeploymentTarget()) return; - // A lambda that creates an implicit decl TRC specifying the deployment - // target for `range` in decl `D`. - auto createContext = [this](Decl *D, SourceRange range) { - AvailabilityContext Availability = - AvailabilityContext::forDeploymentTarget(Context); - Availability.intersectWith(getCurrentTRC()->getAvailabilityInfo()); - - return TypeRefinementContext::createForDeclImplicit( - Context, D, getCurrentTRC(), Availability, range); - }; - // Top level code always uses the deployment target. if (auto tlcd = dyn_cast(D)) { if (auto bodyStmt = tlcd->getBody()) { - pushDeclBodyContext(createContext(tlcd, tlcd->getSourceRange()), tlcd, - bodyStmt); + pushDeclBodyContext( + createImplicitDeclContextForDeploymentTarget( + tlcd, tlcd->getSourceRange()), + tlcd, bodyStmt); } return; } @@ -741,51 +800,14 @@ class TypeRefinementContextBuilder : private ASTWalker { if (!afd->isImplicit() && afd->getResilienceExpansion() != ResilienceExpansion::Minimal) { if (auto body = afd->getBody(/*canSynthesize*/ false)) { - pushDeclBodyContext(createContext(afd, afd->getBodySourceRange()), - afd, body); + pushDeclBodyContext( + createImplicitDeclContextForDeploymentTarget( + afd, afd->getBodySourceRange()), + afd, body); } } return; } - - // Var decls may have associated pattern binding decls or property wrappers - // with init expressions. Those expressions need to be constrained to the - // deployment target unless they are exposed to clients. - if (auto vd = dyn_cast(D)) { - if (!vd->hasInitialValue() || vd->isInitExposedToClients()) - return; - - if (auto *pbd = vd->getParentPatternBinding()) { - int idx = pbd->getPatternEntryIndexForVarDecl(vd); - auto *initExpr = pbd->getInit(idx); - if (initExpr && !initExpr->isImplicit()) { - assert(initExpr->getSourceRange().isValid()); - - // Create a TRC for the init written in the source. The ASTWalker - // won't visit these expressions so instead of pushing these onto the - // stack we build them directly. - auto *initTRC = createContext(vd, initExpr->getSourceRange()); - TypeRefinementContextBuilder(initTRC, Context).build(initExpr); - } - - // Ideally any init expression would be returned by `getInit()` above. - // However, for property wrappers it doesn't get populated until - // typechecking completes (which is too late). Instead, we find the - // the property wrapper attribute and use its source range to create a - // TRC for the initializer expression. - // - // FIXME: Since we don't have an expression here, we can't build out its - // TRC. If the Expr that will eventually be created contains a closure - // expression, then it might have AST nodes that need to be refined. For - // example, property wrapper initializers that takes block arguments - // are not handled correctly because of this (rdar://77841331). - for (auto *wrapper : vd->getAttachedPropertyWrappers()) { - createContext(vd, wrapper->getRange()); - } - } - - return; - } } PreWalkResult walkToStmtPre(Stmt *S) override { @@ -1277,6 +1299,36 @@ TypeChecker::getOrBuildTypeRefinementContext(SourceFile *SF) { return TRC; } +bool ExpandChildTypeRefinementContextsRequest::evaluate( + Evaluator &evaluator, Decl *decl, TypeRefinementContext *parentTRC +) const { + // If the parent TRC is already contained by the deployment target, there's + // nothing more to do. + ASTContext &ctx = decl->getASTContext(); + if (computeContainedByDeploymentTarget(parentTRC, ctx)) + return false; + + // Variables can have children corresponding to property wrappers and + // the initial values provided in each pattern binding entry. + if (auto var = dyn_cast(decl)) { + if (auto *pattern = var->getParentPatternBinding()) { + // Only do this for the first variable in the pattern binding declaration. + auto anchorVar = pattern->getAnchoringVarDecl(0); + if (anchorVar != var) { + return evaluateOrDefault( + evaluator, + ExpandChildTypeRefinementContextsRequest{anchorVar, parentTRC}, + false); + } + + TypeRefinementContextBuilder builder(parentTRC, ctx); + builder.buildContextsForPatternBindingDecl(pattern); + } + } + + return false; +} + AvailabilityContext TypeChecker::overApproximateAvailabilityAtLocation(SourceLoc loc, const DeclContext *DC, From 50ca096cd8de72f49d9e137be5b9f9fe263799e4 Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Tue, 1 Aug 2023 09:49:47 -0700 Subject: [PATCH 15/37] Make `VarDecl::isLayoutExposedToClients` check property wrappers more lazily The check for "has property wrappers" as part of determining whether the layout of a variable is exposed to clients can trigger reference cycles. Push this check later, which eliminates these cycles for types that aren't frozen/fixed-layout. This is a hack, not a real fix, but it eliminates the cyclic references observed in rdar://112079160. --- lib/AST/Decl.cpp | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index 44f15ea3d4246..72168b6a28080 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -2079,19 +2079,23 @@ bool VarDecl::isLayoutExposedToClients() const { if (!parent) return false; if (isStatic()) return false; - if (!hasStorage() && - !getAttrs().hasAttribute() && - !hasAttachedPropertyWrapper()) { - return false; - } auto nominalAccess = parent->getFormalAccessScope(/*useDC=*/nullptr, /*treatUsableFromInlineAsPublic=*/true); if (!nominalAccess.isPublic()) return false; - return (parent->getAttrs().hasAttribute() || - parent->getAttrs().hasAttribute()); + if (!parent->getAttrs().hasAttribute() && + !parent->getAttrs().hasAttribute()) + return false; + + if (!hasStorage() && + !getAttrs().hasAttribute() && + !hasAttachedPropertyWrapper()) { + return false; + } + + return true; } /// Check whether the given type representation will be From 0d779dfd10934a45d2b87b791028a0920e6b67f1 Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Tue, 1 Aug 2023 14:11:32 -0700 Subject: [PATCH 16/37] [Type refinement context] Avoid creating implicit contexts with bad ranges --- lib/Sema/TypeCheckAvailability.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Sema/TypeCheckAvailability.cpp b/lib/Sema/TypeCheckAvailability.cpp index a6e166f42fb62..dd9efb3e08cfe 100644 --- a/lib/Sema/TypeCheckAvailability.cpp +++ b/lib/Sema/TypeCheckAvailability.cpp @@ -608,7 +608,7 @@ class TypeRefinementContextBuilder : private ASTWalker { getCurrentTRC()->getAvailabilityInfo(); AvailabilityContext EffectiveAvailability = getEffectiveAvailabilityForDeclSignature(D, CurrentAvailability); - if (isa(D) || + if ((isa(D) && refinementSourceRangeForDecl(D).isValid()) || CurrentAvailability.isSupersetOf(EffectiveAvailability)) return TypeRefinementContext::createForDeclImplicit( Context, D, getCurrentTRC(), EffectiveAvailability, From c09817505909d8d023c5ec3f143d6aa93059f77a Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Tue, 1 Aug 2023 16:10:04 -0700 Subject: [PATCH 17/37] Add test case involving circular references with `@Observable` Add a test case for Observable types that are extended from other source files. Prior to the recent changes to make `TypeRefinementContext` more lazy, this would trigger circular references through the `TypeRefinementContextBuilder`. Finishes rdar://112079160. --- .../Observation/Inputs/ObservableClass.swift | 7 +++++++ .../ObservableAvailabilityCycle.swift | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 test/stdlib/Observation/Inputs/ObservableClass.swift create mode 100644 test/stdlib/Observation/ObservableAvailabilityCycle.swift diff --git a/test/stdlib/Observation/Inputs/ObservableClass.swift b/test/stdlib/Observation/Inputs/ObservableClass.swift new file mode 100644 index 0000000000000..16819bdada192 --- /dev/null +++ b/test/stdlib/Observation/Inputs/ObservableClass.swift @@ -0,0 +1,7 @@ +import Foundation +import Observation + +@available(SwiftStdlib 5.9, *) +@Observable final public class ObservableClass { + public var state: State = .unused +} diff --git a/test/stdlib/Observation/ObservableAvailabilityCycle.swift b/test/stdlib/Observation/ObservableAvailabilityCycle.swift new file mode 100644 index 0000000000000..7343740f2b259 --- /dev/null +++ b/test/stdlib/Observation/ObservableAvailabilityCycle.swift @@ -0,0 +1,19 @@ +// REQUIRES: swift_swift_parser + +// RUN: %target-swift-frontend -typecheck -parse-as-library -enable-experimental-feature InitAccessors -external-plugin-path %swift-host-lib-dir/plugins#%swift-plugin-server -primary-file %s %S/Inputs/ObservableClass.swift + +// RUN: %target-swift-frontend -typecheck -parse-as-library -enable-experimental-feature InitAccessors -external-plugin-path %swift-host-lib-dir/plugins#%swift-plugin-server %s -primary-file %S/Inputs/ObservableClass.swift + +// REQUIRES: observation +// REQUIRES: concurrency +// REQUIRES: objc_interop +// UNSUPPORTED: use_os_stdlib +// UNSUPPORTED: back_deployment_runtime + +@available(SwiftStdlib 5.9, *) +extension ObservableClass { + @frozen public enum State: Sendable { + case unused + case used + } +} From 7f031dfdd496808041471c49a77edb880036d97a Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Wed, 2 Aug 2023 07:38:19 -0700 Subject: [PATCH 18/37] Drop unnecessary "parent context" state from TypeRefinementContextBuilder This state is a holdover from when accessors we stored "alongside" their variable declarations, rather than contained within them. That's no longer the case, so we don't need to track this information any more. --- lib/Sema/TypeCheckAvailability.cpp | 67 ------------------------------ 1 file changed, 67 deletions(-) diff --git a/lib/Sema/TypeCheckAvailability.cpp b/lib/Sema/TypeCheckAvailability.cpp index dd9efb3e08cfe..ba4172384ddc9 100644 --- a/lib/Sema/TypeCheckAvailability.cpp +++ b/lib/Sema/TypeCheckAvailability.cpp @@ -388,20 +388,6 @@ class TypeRefinementContextBuilder : private ASTWalker { }; std::vector DeclBodyContextStack; - /// A mapping from abstract storage declarations with accessors to - /// to the type refinement contexts for those declarations. We refer to - /// this map to determine the appropriate parent TRC to use when - /// walking the accessor function. - llvm::DenseMap - StorageContexts; - - /// A mapping from pattern binding storage declarations to the type refinement - /// contexts for those declarations. We refer to this map to determine the - /// appropriate parent TRC to use when walking a var decl that belongs to a - /// pattern containing multiple vars. - llvm::DenseMap - PatternBindingContexts; - TypeRefinementContext *getCurrentTRC() { return ContextStack.back().TRC; } @@ -482,19 +468,9 @@ class TypeRefinementContextBuilder : private ASTWalker { PreWalkAction walkToDeclPre(Decl *D) override { PrettyStackTraceDecl trace(stackTraceAction(), D); - // Adds in a parent TRC for decls which are syntactically nested but are not - // represented that way in the AST. (Particularly, AbstractStorageDecl - // parents for AccessorDecl children.) - if (auto ParentTRC = getEffectiveParentContextForDecl(D)) { - pushContext(ParentTRC, D); - } - // Adds in a TRC that covers the entire declaration. if (auto DeclTRC = getNewContextForSignatureOfDecl(D)) { pushContext(DeclTRC, D); - - // Possibly use this as an effective parent context later. - recordEffectiveParentContext(D, DeclTRC); } // Create TRCs that cover only the body of the declaration. @@ -515,49 +491,6 @@ class TypeRefinementContextBuilder : private ASTWalker { return Action::Continue(); } - TypeRefinementContext *getEffectiveParentContextForDecl(Decl *D) { - // FIXME: Can we assert that we won't walk parent decls later that should - // have been returned here? - if (auto *accessor = dyn_cast(D)) { - // Use TRC of the storage rather the current TRC when walking this - // function. - auto it = StorageContexts.find(accessor->getStorage()); - if (it != StorageContexts.end()) { - return it->second; - } - } else if (auto *VD = dyn_cast(D)) { - // Use the TRC of the pattern binding decl as the parent for var decls. - if (auto *PBD = VD->getParentPatternBinding()) { - auto it = PatternBindingContexts.find(PBD); - if (it != PatternBindingContexts.end()) { - return it->second; - } - } - } - - return nullptr; - } - - /// If necessary, records a TRC so it can be returned by subsequent calls to - /// `getEffectiveParentContextForDecl()`. - void recordEffectiveParentContext(Decl *D, TypeRefinementContext *NewTRC) { - if (auto *StorageDecl = dyn_cast(D)) { - // Stash the TRC for the storage declaration to use as the parent of - // accessor decls later. - if (StorageDecl->hasParsedAccessors()) - StorageContexts[StorageDecl] = NewTRC; - } - - if (auto *VD = dyn_cast(D)) { - // Stash the TRC for the var decl if its parent pattern binding decl has - // more than one entry so that the sibling var decls can reuse it. - if (auto *PBD = VD->getParentPatternBinding()) { - if (PBD->getNumPatternEntries() > 1) - PatternBindingContexts[PBD] = NewTRC; - } - } - } - /// Returns a new context to be introduced for the declaration, or nullptr /// if no new context should be introduced. TypeRefinementContext *getNewContextForSignatureOfDecl(Decl *D) { From 37959de29e7dd142f2a2d6a1c5e29fcf2b8898cd Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Wed, 2 Aug 2023 10:30:59 -0700 Subject: [PATCH 19/37] Establish type refinement contexts for pattern binding decls directly The type refinement context builder had a bunch of logic to try to model type refinement contexts for the first variable declaration that shows up within a pattern binding declaration. Instead, model this more syntactically by creating a type refinement context for the pattern binding declaration itself. This both addresses a regression in the handling of `if #available` within a closure that's part of an initializer, and fixes a bug in the same area where similar code has explicit availability annotations. --- include/swift/AST/Availability.h | 7 ++ lib/AST/Availability.cpp | 30 ++++++++ lib/AST/TypeRefinementContext.cpp | 4 + lib/Sema/TypeCheckAvailability.cpp | 106 +++++++------------------- test/Sema/availability_versions.swift | 22 ++++++ 5 files changed, 91 insertions(+), 78 deletions(-) diff --git a/include/swift/AST/Availability.h b/include/swift/AST/Availability.h index b36f7c674c2a5..6cdb09b250bff 100644 --- a/include/swift/AST/Availability.h +++ b/include/swift/AST/Availability.h @@ -377,6 +377,13 @@ class AvailabilityInference { annotatedAvailableRangeForAttr(const SpecializeAttr *attr, ASTContext &ctx); }; +/// Given a declaration upon which an availability attribute would appear in +/// concrete syntax, return a declaration to which the parser +/// actually attaches the attribute in the abstract syntax tree. We use this +/// function to determine whether the concrete syntax already has an +/// availability attribute. +const Decl *abstractSyntaxDeclForAvailableAttribute(const Decl *D); + } // end namespace swift #endif diff --git a/lib/AST/Availability.cpp b/lib/AST/Availability.cpp index eb233b0882e05..38a0732da6133 100644 --- a/lib/AST/Availability.cpp +++ b/lib/AST/Availability.cpp @@ -225,6 +225,8 @@ AvailabilityInference::attrForAnnotatedAvailableRange(const Decl *D, ASTContext &Ctx) { const AvailableAttr *bestAvailAttr = nullptr; + D = abstractSyntaxDeclForAvailableAttribute(D); + for (auto Attr : D->getAttrs()) { auto *AvailAttr = dyn_cast(Attr); if (AvailAttr == nullptr || !AvailAttr->Introduced.has_value() || @@ -749,3 +751,31 @@ ASTContext::getSwift5PlusAvailability(llvm::VersionTuple swiftVersion) { bool ASTContext::supportsVersionedAvailability() const { return minimumAvailableOSVersionForTriple(LangOpts.Target).has_value(); } + +const Decl * +swift::abstractSyntaxDeclForAvailableAttribute(const Decl *ConcreteSyntaxDecl) { + // This function needs to be kept in sync with its counterpart, + // concreteSyntaxDeclForAvailableAttribute(). + + if (auto *PBD = dyn_cast(ConcreteSyntaxDecl)) { + // Existing @available attributes in the AST are attached to VarDecls + // rather than PatternBindingDecls, so we return the first VarDecl for + // the pattern binding declaration. + // This is safe, even though there may be multiple VarDecls, because + // all parsed attribute that appear in the concrete syntax upon on the + // PatternBindingDecl are added to all of the VarDecls for the pattern + // binding. + for (auto index : range(PBD->getNumPatternEntries())) { + if (auto VD = PBD->getAnchoringVarDecl(index)) + return VD; + } + } else if (auto *ECD = dyn_cast(ConcreteSyntaxDecl)) { + // Similar to the PatternBindingDecl case above, we return the + // first EnumElementDecl. + if (auto *Elem = ECD->getFirstElement()) { + return Elem; + } + } + + return ConcreteSyntaxDecl; +} diff --git a/lib/AST/TypeRefinementContext.cpp b/lib/AST/TypeRefinementContext.cpp index d349b1114971f..cde3ebff744b2 100644 --- a/lib/AST/TypeRefinementContext.cpp +++ b/lib/AST/TypeRefinementContext.cpp @@ -367,6 +367,10 @@ void TypeRefinementContext::print(raw_ostream &OS, SourceManager &SrcMgr, OS << "extension." << ED->getExtendedType().getString(); } else if (isa(D)) { OS << ""; + } else if (auto PBD = dyn_cast(D)) { + if (auto VD = PBD->getAnchoringVarDecl(0)) { + OS << VD->getName(); + } } } diff --git a/lib/Sema/TypeCheckAvailability.cpp b/lib/Sema/TypeCheckAvailability.cpp index ba4172384ddc9..c5a52f71edb5d 100644 --- a/lib/Sema/TypeCheckAvailability.cpp +++ b/lib/Sema/TypeCheckAvailability.cpp @@ -308,6 +308,8 @@ ExportContext::getExportabilityReason() const { /// on the target platform. static const AvailableAttr *getActiveAvailableAttribute(const Decl *D, ASTContext &AC) { + D = abstractSyntaxDeclForAvailableAttribute(D); + for (auto Attr : D->getAttrs()) if (auto AvAttr = dyn_cast(Attr)) { if (!AvAttr->isInvalid() && AvAttr->isActivePlatform(AC)) { @@ -494,7 +496,10 @@ class TypeRefinementContextBuilder : private ASTWalker { /// Returns a new context to be introduced for the declaration, or nullptr /// if no new context should be introduced. TypeRefinementContext *getNewContextForSignatureOfDecl(Decl *D) { - if (!isa(D) && !isa(D) && !isa(D)) + if (!isa(D) && + !isa(D) && + !isa(D) && + !isa(D)) return nullptr; // Only introduce for an AbstractStorageDecl if it is not local. We @@ -503,20 +508,17 @@ class TypeRefinementContextBuilder : private ASTWalker { if (isa(D) && D->getDeclContext()->isLocalContext()) return nullptr; + // Don't introduce for variable declarations that have a parent pattern + // binding; all of the relevant information is on the pattern binding. + if (auto var = dyn_cast(D)) { + if (var->getParentPatternBinding()) + return nullptr; + } + // Ignore implicit declarations (mainly skips over `DeferStmt` functions). if (D->isImplicit()) return nullptr; - // Skip introducing additional contexts for var decls past the first in a - // pattern. The context necessary for the pattern as a whole was already - // introduced if necessary by the first var decl. - if (auto *VD = dyn_cast(D)) { - if (auto *PBD = VD->getParentPatternBinding()) { - if (VD != PBD->getAnchoringVarDecl(0)) - return nullptr; - } - } - // Declarations with an explicit availability attribute always get a TRC. if (hasActiveAvailableAttribute(D, Context)) { AvailabilityContext DeclaredAvailability = @@ -541,7 +543,8 @@ class TypeRefinementContextBuilder : private ASTWalker { getCurrentTRC()->getAvailabilityInfo(); AvailabilityContext EffectiveAvailability = getEffectiveAvailabilityForDeclSignature(D, CurrentAvailability); - if ((isa(D) && refinementSourceRangeForDecl(D).isValid()) || + if ((isa(D) && + refinementSourceRangeForDecl(D).isValid()) || CurrentAvailability.isSupersetOf(EffectiveAvailability)) return TypeRefinementContext::createForDeclImplicit( Context, D, getCurrentTRC(), EffectiveAvailability, @@ -617,22 +620,6 @@ class TypeRefinementContextBuilder : private ASTWalker { // the bodies of its accessors. SourceRange Range = storageDecl->getSourceRange(); - // For a variable declaration (without accessors) we use the range of the - // containing pattern binding declaration to make sure that we include - // any type annotation in the type refinement context range. We also - // need to include any custom attributes that were written on the - // declaration. - if (auto *varDecl = dyn_cast(storageDecl)) { - if (auto *PBD = varDecl->getParentPatternBinding()) - Range = PBD->getSourceRange(); - - for (auto attr : varDecl->getOriginalAttrs()) { - if (auto customAttr = dyn_cast(attr)) { - Range.widen(customAttr->getRange()); - } - } - } - // HACK: For synthesized trivial accessors we may have not a valid // location for the end of the braces, so in that case we will fall back // to using the range for the storage declaration. The right fix here is @@ -646,7 +633,13 @@ class TypeRefinementContextBuilder : private ASTWalker { return Range; } - + + // For pattern binding declarations, include the attributes in the source + // range so that we're sure to cover any property wrappers. + if (auto patternBinding = dyn_cast(D)) { + return D->getSourceRangeIncludingAttrs(); + } + return D->getSourceRange(); } @@ -662,7 +655,7 @@ class TypeRefinementContextBuilder : private ASTWalker { Context, D, getCurrentTRC(), Availability, range); } - /// Build contexts for a VarDecl with the given initializer. + /// Build contexts for a pattern binding declaration. void buildContextsForPatternBindingDecl(PatternBindingDecl *pattern) { // Build contexts for each of the pattern entries. for (unsigned index : range(pattern->getNumPatternEntries())) { @@ -1241,22 +1234,11 @@ bool ExpandChildTypeRefinementContextsRequest::evaluate( if (computeContainedByDeploymentTarget(parentTRC, ctx)) return false; - // Variables can have children corresponding to property wrappers and - // the initial values provided in each pattern binding entry. - if (auto var = dyn_cast(decl)) { - if (auto *pattern = var->getParentPatternBinding()) { - // Only do this for the first variable in the pattern binding declaration. - auto anchorVar = pattern->getAnchoringVarDecl(0); - if (anchorVar != var) { - return evaluateOrDefault( - evaluator, - ExpandChildTypeRefinementContextsRequest{anchorVar, parentTRC}, - false); - } - - TypeRefinementContextBuilder builder(parentTRC, ctx); - builder.buildContextsForPatternBindingDecl(pattern); - } + // Pattern binding declarations can have children corresponding to property + // wrappers and the initial values provided in each pattern binding entry. + if (auto *binding = dyn_cast(decl)) { + TypeRefinementContextBuilder builder(parentTRC, ctx); + builder.buildContextsForPatternBindingDecl(binding); } return false; @@ -1596,38 +1578,6 @@ concreteSyntaxDeclForAvailableAttribute(const Decl *AbstractSyntaxDecl) { return AbstractSyntaxDecl; } -/// Given a declaration upon which an availability attribute would appear in -/// concrete syntax, return a declaration to which the parser -/// actually attaches the attribute in the abstract syntax tree. We use this -/// function to determine whether the concrete syntax already has an -/// availability attribute. -static const Decl * -abstractSyntaxDeclForAvailableAttribute(const Decl *ConcreteSyntaxDecl) { - // This function needs to be kept in sync with its counterpart, - // concreteSyntaxDeclForAvailableAttribute(). - - if (auto *PBD = dyn_cast(ConcreteSyntaxDecl)) { - // Existing @available attributes in the AST are attached to VarDecls - // rather than PatternBindingDecls, so we return the first VarDecl for - // the pattern binding declaration. - // This is safe, even though there may be multiple VarDecls, because - // all parsed attribute that appear in the concrete syntax upon on the - // PatternBindingDecl are added to all of the VarDecls for the pattern - // binding. - if (PBD->getNumPatternEntries() != 0) { - return PBD->getAnchoringVarDecl(0); - } - } else if (auto *ECD = dyn_cast(ConcreteSyntaxDecl)) { - // Similar to the PatternBindingDecl case above, we return the - // first EnumElementDecl. - if (auto *Elem = ECD->getFirstElement()) { - return Elem; - } - } - - return ConcreteSyntaxDecl; -} - /// Given a declaration, return a better related declaration for which /// to suggest an @available fixit, or the original declaration /// if no such related declaration exists. diff --git a/test/Sema/availability_versions.swift b/test/Sema/availability_versions.swift index 7618d2943bf86..da27f7b9b695e 100644 --- a/test/Sema/availability_versions.swift +++ b/test/Sema/availability_versions.swift @@ -1741,3 +1741,25 @@ func useHasUnavailableExtension(_ s: HasUnavailableExtension) { s.inheritsUnavailable() // expected-error {{'inheritsUnavailable()' is unavailable in macOS}} s.moreAvailableButStillUnavailable() // expected-error {{'moreAvailableButStillUnavailable()' is unavailable in macOS}} } + +@available(macOS 10.15, *) +func f() -> Int { 17 } + +class StoredPropertiesWithAvailabilityInClosures { + private static let value: Int = { + if #available(macOS 10.15, *) { + return f() + } + + return 0 + }() + + @available(macOS 10.14, *) + private static let otherValue: Int = { + if #available(macOS 10.15, *) { + return f() + } + + return 0 + }() +} From 9f9929a1d0f86e08c28a1d8beb9bcb7cd3585d9a Mon Sep 17 00:00:00 2001 From: Mike Ash Date: Wed, 2 Aug 2023 16:15:22 -0400 Subject: [PATCH 20/37] [Concurrency] Fix crash when actor is dynamically subclassed. Dynamic subclasses have a NULL type descriptor. Make sure isDefaultActorClass doesn't try to dereference that NULL descriptor. rdar://112223265 --- stdlib/public/BackDeployConcurrency/Actor.cpp | 4 ++- stdlib/public/Concurrency/Actor.cpp | 3 +- .../Runtime/actor_dynamic_subclass.swift | 33 +++++++++++++++++++ 3 files changed, 38 insertions(+), 2 deletions(-) create mode 100644 test/Concurrency/Runtime/actor_dynamic_subclass.swift diff --git a/stdlib/public/BackDeployConcurrency/Actor.cpp b/stdlib/public/BackDeployConcurrency/Actor.cpp index 3db760534eb4b..a0c278cfdb25e 100644 --- a/stdlib/public/BackDeployConcurrency/Actor.cpp +++ b/stdlib/public/BackDeployConcurrency/Actor.cpp @@ -1712,8 +1712,10 @@ static bool isDefaultActorClass(const ClassMetadata *metadata) { assert(metadata->isTypeMetadata()); while (true) { // Trust the class descriptor if it says it's a default actor. - if (metadata->getDescription()->isDefaultActor()) + if (!metadata->isArtificialSubclass() && + metadata->getDescription()->isDefaultActor()) { return true; + } // Go to the superclass. metadata = metadata->Superclass; diff --git a/stdlib/public/Concurrency/Actor.cpp b/stdlib/public/Concurrency/Actor.cpp index 0df081f08b97e..e210fbb534bea 100644 --- a/stdlib/public/Concurrency/Actor.cpp +++ b/stdlib/public/Concurrency/Actor.cpp @@ -1775,7 +1775,8 @@ static bool isDefaultActorClass(const ClassMetadata *metadata) { assert(metadata->isTypeMetadata()); while (true) { // Trust the class descriptor if it says it's a default actor. - if (metadata->getDescription()->isDefaultActor()) { + if (!metadata->isArtificialSubclass() && + metadata->getDescription()->isDefaultActor()) { return true; } diff --git a/test/Concurrency/Runtime/actor_dynamic_subclass.swift b/test/Concurrency/Runtime/actor_dynamic_subclass.swift new file mode 100644 index 0000000000000..e1cbf96d68672 --- /dev/null +++ b/test/Concurrency/Runtime/actor_dynamic_subclass.swift @@ -0,0 +1,33 @@ +// RUN: %target-run-simple-swift(-Xfrontend -disable-availability-checking -parse-as-library) + +// REQUIRES: executable_test +// REQUIRES: concurrency +// REQUIRES: objc_interop + +// UNSUPPORTED: back_deployment_runtime +// UNSUPPORTED: use_os_stdlib + +// Make sure the concurrency runtime tolerates dynamically-subclassed actors. + +import ObjectiveC + +actor Foo: NSObject { + var x = 0 + + func doit() async { + x += 1 + try! await Task.sleep(nanoseconds: 1000) + x += 1 + } +} + +@main +enum Main { + static func main() async { + let FooSub = objc_allocateClassPair(Foo.self, "FooSub", 0) as! Foo.Type + objc_registerClassPair(FooSub) + let foosub = FooSub.init() + await foosub.doit() + } +} + From e96b2a65b2164c476d60b7f85f5e37cf383b901d Mon Sep 17 00:00:00 2001 From: Michael Gottesman Date: Wed, 2 Aug 2023 16:57:38 -0700 Subject: [PATCH 21/37] Update a test --- test/SIL/store_borrow_verify_errors.sil | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/SIL/store_borrow_verify_errors.sil b/test/SIL/store_borrow_verify_errors.sil index b9b9f3cc97692..be949d5730bc4 100644 --- a/test/SIL/store_borrow_verify_errors.sil +++ b/test/SIL/store_borrow_verify_errors.sil @@ -199,7 +199,7 @@ bb1: } // CHECK: Begin Error in function test_store_borrow_dest -// CHECK: SIL verification failed: store_borrow destination can only be an alloc_stack: isa(SI->getDest()) +// CHECK: SIL verification failed: store_borrow destination can only be an alloc_stack: isLegal(SI->getDest()) // CHECK: Verifying instruction: // CHECK: %0 = argument of bb0 : $Klass // user: %3 // CHECK: %2 = struct_element_addr %1 : $*NonTrivialStruct, #NonTrivialStruct.val // user: %3 From 8fb42945071a97bca80e593abbfcb02797f57291 Mon Sep 17 00:00:00 2001 From: Artem Chikin Date: Wed, 2 Aug 2023 16:14:40 -0700 Subject: [PATCH 22/37] [Compile Time Constant Extraction] Map types with archetypes out of context, before mangling them for printing. Matching logic in the ASTPrinter. Otherwise we attempt to mangle types with archetypes in them, which cannot be done, and causes the compiler to crash. Resolves rdar://113039215 --- lib/ConstExtract/ConstExtract.cpp | 5 ++- test/ConstExtraction/ExtractArchetype.swift | 36 +++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 test/ConstExtraction/ExtractArchetype.swift diff --git a/lib/ConstExtract/ConstExtract.cpp b/lib/ConstExtract/ConstExtract.cpp index 063ca59a66f2b..0291aa6b1e07f 100644 --- a/lib/ConstExtract/ConstExtract.cpp +++ b/lib/ConstExtract/ConstExtract.cpp @@ -85,7 +85,10 @@ std::string toFullyQualifiedProtocolNameString(const swift::ProtocolDecl &Protoc } std::string toMangledTypeNameString(const swift::Type &Type) { - return Mangle::ASTMangler().mangleTypeWithoutPrefix(Type->getCanonicalType()); + auto PrintingType = Type; + if (Type->hasArchetype()) + PrintingType = Type->mapTypeOutOfContext(); + return Mangle::ASTMangler().mangleTypeWithoutPrefix(PrintingType->getCanonicalType()); } } // namespace diff --git a/test/ConstExtraction/ExtractArchetype.swift b/test/ConstExtraction/ExtractArchetype.swift new file mode 100644 index 0000000000000..91127d4aa337c --- /dev/null +++ b/test/ConstExtraction/ExtractArchetype.swift @@ -0,0 +1,36 @@ +// RUN: %empty-directory(%t) +// RUN: echo "[MyProto]" > %t/protocols.json + +// RUN: %target-swift-frontend -typecheck -emit-const-values-path %t/ExtractEnums.swiftconstvalues -const-gather-protocols-file %t/protocols.json -primary-file %s +// RUN: cat %t/ExtractEnums.swiftconstvalues 2>&1 | %FileCheck %s + +protocol MyProto {} + +public struct Foo { + init(bar: Any) { + } +} + +public struct ArchetypalConformance: MyProto { + let baz: Foo = Foo(bar: T.self) + public init() {} +} + +// CHECK: [ +// CHECK-NEXT: { +// CHECK-NEXT: "typeName": "ExtractArchetype.ArchetypalConformance" +// CHECK: "valueKind": "InitCall", +// CHECK-NEXT: "value": { +// CHECK-NEXT: "type": "ExtractArchetype.Foo", +// CHECK-NEXT: "arguments": [ +// CHECK-NEXT: { +// CHECK-NEXT: "label": "bar", +// CHECK-NEXT: "type": "Any", +// CHECK-NEXT: "valueKind": "Type", +// CHECK-NEXT: "value": { +// CHECK-NEXT: "type": "T", +// CHECK-NEXT: "mangledName": "x" +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: ] +// CHECK-NEXT: } From a5d9b13ef0745d797edb6f15c24f68a7df5ed2e5 Mon Sep 17 00:00:00 2001 From: Hamish Knight Date: Thu, 3 Aug 2023 14:17:54 +0100 Subject: [PATCH 23/37] [CodeComplete] Avoid `let`/`var` completions in a few cases Don't suggest `let` or `var` in e.g the sequence expression of a `for` loop, or after a `return`. We ought to do a better job of checking whether we're in expression position before suggesting these (as opposed to a pattern), but I'm leaving that as future work for now. --- lib/IDE/CodeCompletion.cpp | 10 ++++++++-- test/IDE/complete_keywords.swift | 24 ++++++++++++++++-------- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/lib/IDE/CodeCompletion.cpp b/lib/IDE/CodeCompletion.cpp index 5dfb622622e83..e4cc1de623c46 100644 --- a/lib/IDE/CodeCompletion.cpp +++ b/lib/IDE/CodeCompletion.cpp @@ -1016,13 +1016,19 @@ void CodeCompletionCallbacksImpl::addKeywords(CodeCompletionResultSink &Sink, addStmtKeywords(Sink, CurDeclContext, MaybeFuncBody); addClosureSignatureKeywordsIfApplicable(Sink, CurDeclContext); + LLVM_FALLTHROUGH; + case CompletionKind::PostfixExprBeginning: + // We need to add 'let' and 'var' keywords in expression position here as + // we initially parse patterns as expressions. + // FIXME: We ought to be able to determine if we're in a pattern context and + // only enable 'let' and 'var' in that case. + addLetVarKeywords(Sink); + LLVM_FALLTHROUGH; case CompletionKind::ReturnStmtExpr: case CompletionKind::YieldStmtExpr: - case CompletionKind::PostfixExprBeginning: case CompletionKind::ForEachSequence: addSuperKeyword(Sink, CurDeclContext); - addLetVarKeywords(Sink); addExprKeywords(Sink, CurDeclContext); addAnyTypeKeyword(Sink, CurDeclContext->getASTContext().TheAnyType); break; diff --git a/test/IDE/complete_keywords.swift b/test/IDE/complete_keywords.swift index e2a63a6e5e822..8ea50c0149dbc 100644 --- a/test/IDE/complete_keywords.swift +++ b/test/IDE/complete_keywords.swift @@ -256,8 +256,11 @@ // // let and var // -// KW_EXPR-DAG: Keyword[let]/None: let{{; name=.+$}} -// KW_EXPR-DAG: Keyword[var]/None: var{{; name=.+$}} +// KW_LETVAR-DAG: Keyword[let]/None: let{{; name=.+$}} +// KW_LETVAR-DAG: Keyword[var]/None: var{{; name=.+$}} +// +// KW_LETVAR_NEG-NOT: Keyword[let]/None: let{{; name=.+$}} +// KW_LETVAR_NEG-NOT: Keyword[var]/None: var{{; name=.+$}} // // Literals // @@ -423,24 +426,29 @@ extension SubClass { } func inExpr1() { - (#^EXPR_1?check=KW_EXPR;check=KW_EXPR_NEG^#) + (#^EXPR_1?check=KW_EXPR;check=KW_LETVAR;check=KW_EXPR_NEG^#) } func inExpr2() { - let x = #^EXPR_2?check=KW_EXPR;check=KW_EXPR_NEG^# + let x = #^EXPR_2?check=KW_EXPR;check=KW_LETVAR;check=KW_EXPR_NEG^# } func inExpr3() { - if #^EXPR_3?check=KW_EXPR;check=KW_EXPR_NEG^# {} + if #^EXPR_3?check=KW_EXPR;check=KW_LETVAR;check=KW_EXPR_NEG^# {} } func inExpr4() { let x = 1 - x + #^EXPR_4?check=KW_EXPR;check=KW_EXPR_NEG^# + x + #^EXPR_4?check=KW_EXPR;check=KW_LETVAR;check=KW_EXPR_NEG^# } func inExpr5() { var x: Int - x = #^EXPR_5?check=KW_EXPR;check=KW_EXPR_NEG^# + x = #^EXPR_5?check=KW_EXPR;check=KW_LETVAR;check=KW_EXPR_NEG^# } func inExpr6() -> Int { - return #^EXPR_6?check=KW_EXPR;check=KW_EXPR_NEG^# + // Make sure we don't recommend 'let' and 'var' here. + return #^EXPR_6?check=KW_EXPR;check=KW_EXPR_NEG;check=KW_LETVAR_NEG^# +} +func inExpr7() { + // Make sure we don't recommend 'let' and 'var' here. + for x in #^EXPR_7?check=KW_EXPR;check=KW_EXPR_NEG;check=KW_LETVAR_NEG^# } func inSwitch(val: Int) { From 4c6fc1ae8b5473aac546cc05501e3c01fe4b0c0a Mon Sep 17 00:00:00 2001 From: Nate Chandler Date: Thu, 20 Jul 2023 15:12:21 -0700 Subject: [PATCH 24/37] [OpaqueValues] Skip keypath lvalue temporary. When emitting key-path functions on behalf of a base that's an l-value, copy the l-value into a temporary only when using lowered addresses. --- lib/SILGen/SILGenLValue.cpp | 13 +++++++++---- test/SILGen/opaque_values_silgen.swift | 21 +++++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/lib/SILGen/SILGenLValue.cpp b/lib/SILGen/SILGenLValue.cpp index b20190b7164cd..9106c04adb6a5 100644 --- a/lib/SILGen/SILGenLValue.cpp +++ b/lib/SILGen/SILGenLValue.cpp @@ -2180,10 +2180,15 @@ static ManagedValue makeBaseConsumableMaterializedRValue(SILGenFunction &SGF, SILLocation loc, ManagedValue base) { if (base.isLValue()) { - auto tmp = SGF.emitTemporaryAllocation(loc, base.getType()); - SGF.B.createCopyAddr(loc, base.getValue(), tmp, - IsNotTake, IsInitialization); - return SGF.emitManagedBufferWithCleanup(tmp); + if (SGF.useLoweredAddresses()) { + auto tmp = SGF.emitTemporaryAllocation(loc, base.getType()); + SGF.B.createCopyAddr(loc, base.getValue(), tmp, IsNotTake, + IsInitialization); + return SGF.emitManagedBufferWithCleanup(tmp); + } + return SGF.emitLoad(loc, base.getValue(), + SGF.getTypeLowering(base.getType()), SGFContext(), + IsNotTake); } bool isBorrowed = base.isPlusZeroRValueOrTrivial() diff --git a/test/SILGen/opaque_values_silgen.swift b/test/SILGen/opaque_values_silgen.swift index 436cc65577778..c024174a726ba 100644 --- a/test/SILGen/opaque_values_silgen.swift +++ b/test/SILGen/opaque_values_silgen.swift @@ -6,6 +6,10 @@ class C {} +struct MyInt { + var int: Int +} + func genericInout(_: inout T) {} func hasVarArg(_ args: Any...) {} @@ -678,3 +682,20 @@ func FormClassKeyPath() { } _ = \Q.q } + +// CHECK-LABEL: sil {{.*}}[ossa] @UseGetterOnInout : {{.*}} { +// CHECK: bb0([[CONTAINER_ADDR:%[^,]+]] : +// CHECK: [[KEYPATH:%[^,]+]] = keypath $WritableKeyPath, (root $MyInt; stored_property #MyInt.int : $Int) +// CHECK: [[CONTAINER_ACCESS:%[^,]+]] = begin_access [read] [unknown] [[CONTAINER_ADDR]] +// CHECK: [[KEYPATH_UP:%[^,]+]] = upcast [[KEYPATH]] +// CHECK: [[CONTAINER:%[^,]+]] = load [trivial] [[CONTAINER_ACCESS]] +// CHECK: [[GETTER:%[^,]+]] = function_ref @swift_getAtKeyPath +// CHECK: [[VALUE:%[^,]+]] = apply [[GETTER]]([[CONTAINER]], [[KEYPATH_UP]]) +// CHECK: end_access [[CONTAINER_ACCESS]] +// CHECK: destroy_value [[KEYPATH_UP]] +// CHECK: return [[VALUE]] : $Int +// CHECK-LABEL: } // end sil function 'UseGetterOnInout' +@_silgen_name("UseGetterOnInout") +func getInout(_ i: inout MyInt) -> Int { + return i[keyPath: \MyInt.int] +} From 399e5b469df491cd0c1e40e4fa587c3eb946438e Mon Sep 17 00:00:00 2001 From: Egor Zhdan Date: Thu, 3 Aug 2023 14:48:11 +0100 Subject: [PATCH 25/37] [cxx-interop] Add test for mutable `std.optional.pointee` See https://github.com/apple/swift/pull/67648 --- test/Interop/Cxx/stdlib/use-std-optional.swift | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/Interop/Cxx/stdlib/use-std-optional.swift b/test/Interop/Cxx/stdlib/use-std-optional.swift index 4d1d1878c4c0a..7a55d480529f2 100644 --- a/test/Interop/Cxx/stdlib/use-std-optional.swift +++ b/test/Interop/Cxx/stdlib/use-std-optional.swift @@ -12,6 +12,10 @@ StdOptionalTestSuite.test("pointee") { let nonNilOpt = getNonNilOptional() let pointee = nonNilOpt.pointee expectEqual(123, pointee) + + var modifiedOpt = getNilOptional() + modifiedOpt.pointee = 777 + expectEqual(777, modifiedOpt.pointee) } StdOptionalTestSuite.test("std::optional => Swift.Optional") { From 10e7312b9fa5340a225e22012bd68bb652ecc9d3 Mon Sep 17 00:00:00 2001 From: Alex Hoppen Date: Thu, 3 Aug 2023 08:18:32 -0700 Subject: [PATCH 26/37] Remove a leftover `print` statement --- lib/Macros/Sources/ObservationMacros/Extensions.swift | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/Macros/Sources/ObservationMacros/Extensions.swift b/lib/Macros/Sources/ObservationMacros/Extensions.swift index eb6589475daf3..65fe80f2a9bf4 100644 --- a/lib/Macros/Sources/ObservationMacros/Extensions.swift +++ b/lib/Macros/Sources/ObservationMacros/Extensions.swift @@ -76,7 +76,6 @@ extension VariableDeclSyntax { if accessorsMatching({ $0 == .keyword(.get) }).count > 0 { return true } else { - print("else branch") return bindings.contains { binding in if case .getter = binding.accessorBlock?.accessors { return true From 03334a8f92d290da62fb2f2ccb7b3426b9f84a33 Mon Sep 17 00:00:00 2001 From: Anton Korobeynikov Date: Thu, 3 Aug 2023 09:33:11 -0700 Subject: [PATCH 27/37] [AutoDiff] Generalize handling of semantic result parameters (#67230) Introduce the notion of "semantic result parameter". Handle differentiation of inouts via semantic result parameter abstraction. Do not consider non-wrt semantic result parameters as semantic results Fixes #67174 --- include/swift/AST/AutoDiff.h | 12 +- include/swift/AST/Types.h | 54 +++++-- include/swift/SIL/ApplySite.h | 12 ++ include/swift/SIL/SILInstruction.h | 29 ++++ lib/AST/AutoDiff.cpp | 18 +-- lib/AST/Type.cpp | 43 ++---- lib/SIL/IR/SILFunctionType.cpp | 136 +++++++++++------- .../DifferentiableActivityAnalysis.cpp | 4 +- lib/SILOptimizer/Differentiation/Common.cpp | 25 ++-- .../Differentiation/LinearMapInfo.cpp | 25 ++-- .../Differentiation/PullbackCloner.cpp | 16 ++- lib/SILOptimizer/Differentiation/Thunk.cpp | 32 +++-- .../Differentiation/VJPCloner.cpp | 74 +++++----- .../Mandatory/Differentiation.cpp | 10 +- .../inout_differentiability_witness.swift | 8 +- test/AutoDiff/SILGen/witness_table.swift | 20 +-- .../SILOptimizer/activity_analysis.swift | 48 +++---- .../differentiation_diagnostics.swift | 8 +- .../forward_mode_diagnostics.swift | 4 +- .../Sema/derivative_attr_type_checking.swift | 22 +-- .../differentiable_attr_type_checking.swift | 2 +- ...e-55745-noderivative-inout-parameter.swift | 5 +- .../validation-test/forward_mode_simple.swift | 24 +--- .../validation-test/inout_parameters.swift | 15 +- 24 files changed, 365 insertions(+), 281 deletions(-) diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index d70d91ebbdfd4..890bf4cd274c9 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -246,16 +246,16 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &s, return s; } -/// A semantic function result type: either a formal function result type or -/// an `inout` parameter type. Used in derivative function type calculation. +/// A semantic function result type: either a formal function result type or a +/// semantic result (an `inout`) parameter type. Used in derivative function type +/// calculation. struct AutoDiffSemanticFunctionResultType { Type type; unsigned index : 30; - bool isInout : 1; - bool isWrtParam : 1; + bool isSemanticResultParameter : 1; - AutoDiffSemanticFunctionResultType(Type t, unsigned idx, bool inout, bool wrt) - : type(t), index(idx), isInout(inout), isWrtParam(wrt) { } + AutoDiffSemanticFunctionResultType(Type t, unsigned idx, bool param) + : type(t), index(idx), isSemanticResultParameter(param) { } }; /// Key for caching SIL derivative function types. diff --git a/include/swift/AST/Types.h b/include/swift/AST/Types.h index c19f5a7a9706f..f8ff0bbafa643 100644 --- a/include/swift/AST/Types.h +++ b/include/swift/AST/Types.h @@ -3192,6 +3192,12 @@ class AnyFunctionType : public TypeBase { /// Whether the parameter is marked '@noDerivative'. bool isNoDerivative() const { return Flags.isNoDerivative(); } + /// Whether the parameter might be a semantic result for autodiff purposes. + /// This includes inout parameters. + bool isAutoDiffSemanticResult() const { + return isInOut(); + } + ValueOwnership getValueOwnership() const { return Flags.getValueOwnership(); } @@ -3509,8 +3515,8 @@ class AnyFunctionType : public TypeBase { /// Preconditions: /// - Parameters corresponding to parameter indices must conform to /// `Differentiable`. - /// - There is one semantic function result type: either the formal original - /// result or an `inout` parameter. It must conform to `Differentiable`. + /// - There are semantic function result type: either the formal original + /// result or a "wrt" semantic result parameter. /// /// Differential typing rules: takes "wrt" parameter derivatives and returns a /// "wrt" result derivative. @@ -3518,10 +3524,7 @@ class AnyFunctionType : public TypeBase { /// - Case 1: original function has no `inout` parameters. /// - Original: `(T0, T1, ...) -> R` /// - Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan` - /// - Case 2: original function has a non-wrt `inout` parameter. - /// - Original: `(T0, inout T1, ...) -> Void` - /// - Differential: `(T0.Tan, ...) -> T1.Tan` - /// - Case 3: original function has a wrt `inout` parameter. + /// - Case 2: original function has a wrt `inout` parameter. /// - Original: `(T0, inout T1, ...) -> Void` /// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void` /// @@ -3531,10 +3534,7 @@ class AnyFunctionType : public TypeBase { /// - Case 1: original function has no `inout` parameters. /// - Original: `(T0, T1, ...) -> R` /// - Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)` - /// - Case 2: original function has a non-wrt `inout` parameter. - /// - Original: `(T0, inout T1, ...) -> Void` - /// - Pullback: `(T1.Tan) -> (T0.Tan, ...)` - /// - Case 3: original function has a wrt `inout` parameter. + /// - Case 2: original function has a wrt `inout` parameter. /// - Original: `(T0, inout T1, ...) -> Void` /// - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)` /// @@ -4101,6 +4101,9 @@ class SILParameterInfo { return getConvention() == ParameterConvention::Indirect_Inout || getConvention() == ParameterConvention::Indirect_InoutAliasable; } + bool isAutoDiffSemanticResult() const { + return isIndirectMutating(); + } bool isPack() const { return isPackParameter(getConvention()); @@ -4836,6 +4839,37 @@ class SILFunctionType final return llvm::count_if(getParameters(), IndirectMutatingParameterFilter()); } + struct AutoDiffSemanticResultsParameterFilter { + bool operator()(SILParameterInfo param) const { + return param.isAutoDiffSemanticResult(); + } + }; + + using AutoDiffSemanticResultsParameterIter = + llvm::filter_iterator; + using AutoDiffSemanticResultsParameterRange = + iterator_range; + + /// A range of SILParameterInfo for all semantic results parameters. + AutoDiffSemanticResultsParameterRange + getAutoDiffSemanticResultsParameters() const { + return llvm::make_filter_range(getParameters(), + AutoDiffSemanticResultsParameterFilter()); + } + + /// Returns the number of semantic results parameters. + unsigned getNumAutoDiffSemanticResultsParameters() const { + return llvm::count_if(getParameters(), AutoDiffSemanticResultsParameterFilter()); + } + + /// Returns the number of function potential semantic results: + /// * Usual results + /// * Inout parameters + unsigned getNumAutoDiffSemanticResults() const { + return getNumResults() + getNumAutoDiffSemanticResultsParameters(); + } + /// Get the generic signature that the component types are specified /// in terms of, if any. CanGenericSignature getSubstGenericSignature() const { diff --git a/include/swift/SIL/ApplySite.h b/include/swift/SIL/ApplySite.h index a38e25bb3313a..06888dc4d681c 100644 --- a/include/swift/SIL/ApplySite.h +++ b/include/swift/SIL/ApplySite.h @@ -681,6 +681,18 @@ class FullApplySite : public ApplySite { llvm_unreachable("invalid apply kind"); } + AutoDiffSemanticResultArgumentRange getAutoDiffSemanticResultArguments() const { + switch (getKind()) { + case FullApplySiteKind::ApplyInst: + return cast(getInstruction())->getAutoDiffSemanticResultArguments(); + case FullApplySiteKind::TryApplyInst: + return cast(getInstruction())->getAutoDiffSemanticResultArguments(); + case FullApplySiteKind::BeginApplyInst: + return cast(getInstruction())->getAutoDiffSemanticResultArguments(); + } + llvm_unreachable("invalid apply kind"); + } + /// Returns true if \p op is the callee operand of this apply site /// and not an argument operand. bool isCalleeOperand(const Operand &op) const { diff --git a/include/swift/SIL/SILInstruction.h b/include/swift/SIL/SILInstruction.h index aa8762ba22767..0dd02756268be 100644 --- a/include/swift/SIL/SILInstruction.h +++ b/include/swift/SIL/SILInstruction.h @@ -2785,6 +2785,25 @@ struct OperandToInoutArgument { using InoutArgumentRange = OptionalTransformRange, OperandToInoutArgument>; +/// Predicate used to filter AutoDiffSemanticResultArgumentRange. +struct OperandToAutoDiffSemanticResultArgument { + ArrayRef paramInfos; + OperandValueArrayRef arguments; + OperandToAutoDiffSemanticResultArgument(ArrayRef paramInfos, + OperandValueArrayRef arguments) + : paramInfos(paramInfos), arguments(arguments) { + assert(paramInfos.size() == arguments.size()); + } + llvm::Optional operator()(size_t i) const { + if (paramInfos[i].isAutoDiffSemanticResult()) + return arguments[i]; + return llvm::None; + } +}; + +using AutoDiffSemanticResultArgumentRange = + OptionalTransformRange, OperandToAutoDiffSemanticResultArgument>; + /// The partial specialization of ApplyInstBase for full applications. /// Adds some methods relating to 'self' and to result types that don't /// make sense for partial applications. @@ -2894,6 +2913,16 @@ class ApplyInstBase impl.getArgumentsWithoutIndirectResults())); } + /// Returns all autodiff semantic result (`@inout`, `@inout_aliasable`) + /// arguments passed to the instruction. + AutoDiffSemanticResultArgumentRange getAutoDiffSemanticResultArguments() const { + auto &impl = asImpl(); + return AutoDiffSemanticResultArgumentRange( + indices(getArgumentsWithoutIndirectResults()), + OperandToAutoDiffSemanticResultArgument(impl.getSubstCalleeConv().getParameters(), + impl.getArgumentsWithoutIndirectResults())); + } + bool hasSemantics(StringRef semanticsString) const { return doesApplyCalleeHaveSemantics(getCallee(), semanticsString); } diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index c7a04a16939e7..3778543fc9acf 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -199,32 +199,28 @@ void autodiff::getFunctionSemanticResults( if (formalResultType->is()) { for (auto elt : formalResultType->castTo()->getElements()) { resultTypes.emplace_back(elt.getType(), resultIdx++, - /*isInout*/ false, /*isWrt*/ false); + /*isParameter*/ false); } } else { resultTypes.emplace_back(formalResultType, resultIdx++, - /*isInout*/ false, /*isWrt*/ false); + /*isParameter*/ false); } } - bool addNonWrts = resultTypes.empty(); - - // Collect wrt `inout` parameters as semantic results - // As an extention, collect all (including non-wrt) inouts as results for - // functions returning void. + // Collect wrt semantic result (`inout`) parameters as + // semantic results auto collectSemanticResults = [&](const AnyFunctionType *functionType, unsigned curryOffset = 0) { for (auto paramAndIndex : enumerate(functionType->getParams())) { - if (!paramAndIndex.value().isInOut()) + if (!paramAndIndex.value().isAutoDiffSemanticResult()) continue; unsigned idx = paramAndIndex.index() + curryOffset; assert(idx < parameterIndices->getCapacity() && "invalid parameter index"); - bool isWrt = parameterIndices->contains(idx); - if (addNonWrts || isWrt) + if (parameterIndices->contains(idx)) resultTypes.emplace_back(paramAndIndex.value().getPlainType(), - resultIdx, /*isInout*/ true, isWrt); + resultIdx, /*isParameter*/ true); resultIdx += 1; } }; diff --git a/lib/AST/Type.cpp b/lib/AST/Type.cpp index 36f0a9f95f134..f0385dfdbbed6 100644 --- a/lib/AST/Type.cpp +++ b/lib/AST/Type.cpp @@ -5558,7 +5558,7 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( return llvm::make_error( this, DerivativeFunctionTypeError::Kind::NoSemanticResults); - // Accumulate non-inout result tangent spaces. + // Accumulate non-semantic result tangent spaces. SmallVector resultTanTypes, inoutTanTypes; for (auto i : range(originalResults.size())) { auto originalResult = originalResults[i]; @@ -5577,16 +5577,10 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( this, DerivativeFunctionTypeError::Kind::NonDifferentiableResult, std::make_pair(originalResultType, unsigned(originalResult.index))); - if (!originalResult.isInout) + if (!originalResult.isSemanticResultParameter) resultTanTypes.push_back(resultTan->getType()); - else if (originalResult.isInout && !originalResult.isWrtParam) - inoutTanTypes.push_back(resultTan->getType()); } - // Treat non-wrt inouts as semantic results for functions returning Void - if (resultTanTypes.empty()) - resultTanTypes = inoutTanTypes; - // Compute the result linear map function type. FunctionType *linearMapType; switch (kind) { @@ -5597,11 +5591,7 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( // - Original: `(T0, T1, ...) -> R` // - Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan` // - // Case 2: original function has a non-wrt `inout` parameter. - // - Original: `(T0, inout T1, ...) -> Void` - // - Differential: `(T0.Tan, ...) -> T1.Tan` - // - // Case 3: original function has a wrt `inout` parameter. + // Case 2: original function has a wrt `inout` parameter. // - Original: `(T0, inout T1, ...) -> Void` // - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void` SmallVector differentialParams; @@ -5648,15 +5638,11 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( // - Original: `(T0, T1, ...) -> R` // - Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)` // - // Case 2: original function has a non-wrt `inout` parameter. - // - Original: `(T0, inout T1, ...) -> Void` - // - Pullback: `(T1.Tan) -> (T0.Tan, ...)` - // - // Case 3: original function has wrt `inout` parameters. + // Case 2: original function has wrt `inout` parameters. // - Original: `(T0, inout T1, ...) -> R` // - Pullback: `(R.Tan, inout T1.Tan) -> (T0.Tan, ...)` SmallVector pullbackResults; - SmallVector inoutParams; + SmallVector semanticResultParams; for (auto i : range(diffParams.size())) { auto diffParam = diffParams[i]; auto paramType = diffParam.getPlainType(); @@ -5669,10 +5655,10 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( NonDifferentiableDifferentiabilityParameter, std::make_pair(paramType, i)); - if (diffParam.isInOut()) { + if (diffParam.isAutoDiffSemanticResult()) { if (paramType->isVoid()) continue; - inoutParams.push_back(diffParam); + semanticResultParams.push_back(diffParam); continue; } pullbackResults.emplace_back(paramTan->getType()); @@ -5693,15 +5679,15 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( pullbackParams.push_back(AnyFunctionType::Param( resultTanType, Identifier(), flags)); } - // Then append inout parameters. - for (auto i : range(inoutParams.size())) { - auto inoutParam = inoutParams[i]; - auto inoutParamType = inoutParam.getPlainType(); - auto inoutParamTan = - inoutParamType->getAutoDiffTangentSpace(lookupConformance); + // Then append semantic result parameters. + for (auto i : range(semanticResultParams.size())) { + auto semanticResultParam = semanticResultParams[i]; + auto semanticResultParamType = semanticResultParam.getPlainType(); + auto semanticResultParamTan = + semanticResultParamType->getAutoDiffTangentSpace(lookupConformance); auto flags = ParameterTypeFlags().withInOut(true); pullbackParams.push_back(AnyFunctionType::Param( - inoutParamTan->getType(), Identifier(), flags)); + semanticResultParamTan->getType(), Identifier(), flags)); } // FIXME: Verify ExtInfo state is correct, not working by accident. FunctionType::ExtInfo info; @@ -5709,6 +5695,7 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType( break; } } + assert(linearMapType && "Expected linear map type"); return linearMapType; } diff --git a/lib/SIL/IR/SILFunctionType.cpp b/lib/SIL/IR/SILFunctionType.cpp index 6853badd31b2c..ad9f32baedf01 100644 --- a/lib/SIL/IR/SILFunctionType.cpp +++ b/lib/SIL/IR/SILFunctionType.cpp @@ -237,9 +237,11 @@ IndexSubset *SILFunctionType::getDifferentiabilityResultIndices() { if (resultAndIndex.value().getDifferentiability() != SILResultDifferentiability::NotDifferentiable) resultIndices.push_back(resultAndIndex.index()); + + auto numSemanticResults = getNumResults(); - // Check `inout` parameters. - for (auto inoutParamAndIndex : enumerate(getIndirectMutatingParameters())) + // Check semantic results (`inout`) parameters. + for (auto resultParamAndIndex : enumerate(getAutoDiffSemanticResultsParameters())) // Currently, an `inout` parameter can either be: // 1. Both a differentiability parameter and a differentiability result. // 2. `@noDerivative`: neither a differentiability parameter nor a @@ -251,16 +253,13 @@ IndexSubset *SILFunctionType::getDifferentiabilityResultIndices() { // cases, so supporting it is a non-goal. // // See TF-1305 for solution ideas. For now, `@noDerivative` `inout` - // parameters are not treated as differentiability results, unless the - // original function has no formal results, in which case all `inout` - // parameters are treated as differentiability results. - if (resultIndices.empty() || - inoutParamAndIndex.value().getDifferentiability() != + // parameters are not treated as differentiability results. + if (resultParamAndIndex.value().getDifferentiability() != SILParameterDifferentiability::NotDifferentiable) - resultIndices.push_back(getNumResults() + inoutParamAndIndex.index()); + resultIndices.push_back(getNumResults() + resultParamAndIndex.index()); + + numSemanticResults += getNumAutoDiffSemanticResultsParameters(); - auto numSemanticResults = - getNumResults() + getNumIndirectMutatingParameters(); return IndexSubset::get(getASTContext(), numSemanticResults, resultIndices); } @@ -369,18 +368,19 @@ getDifferentiabilityParameters(SILFunctionType *originalFnTy, /// Collects the semantic results of the given function type in /// `originalResults`. The semantic results are formal results followed by -/// `inout` parameters, in type order. +/// semantic result parameters, in type order. static void -getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices, +getSemanticResults(SILFunctionType *functionType, + IndexSubset *parameterIndices, SmallVectorImpl &originalResults) { // Collect original formal results. originalResults.append(functionType->getResults().begin(), functionType->getResults().end()); - // Collect original `inout` parameters. + // Collect original semantic result parameters. for (auto i : range(functionType->getNumParameters())) { auto param = functionType->getParameters()[i]; - if (!param.isIndirectMutating()) + if (!param.isAutoDiffSemanticResult()) continue; if (param.getDifferentiability() != SILParameterDifferentiability::NotDifferentiable) originalResults.emplace_back(param.getInterfaceType(), ResultConvention::Indirect); @@ -597,23 +597,25 @@ static CanSILFunctionType getAutoDiffDifferentialType( differentialResults.push_back({resultTanType, resultConv}); continue; } - // Handle original `inout` parameters. - auto inoutParamIndex = resultIndex - originalFnTy->getNumResults(); - auto inoutParamIt = std::next( - originalFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex); + // Handle original semantic result parameters. + auto resultParamIndex = resultIndex - originalFnTy->getNumResults(); + auto resultParamIt = std::next( + originalFnTy->getAutoDiffSemanticResultsParameters().begin(), + resultParamIndex); auto paramIndex = - std::distance(originalFnTy->getParameters().begin(), &*inoutParamIt); - // If the original `inout` parameter is a differentiability parameter, then - // it already has a corresponding differential parameter. Skip adding a - // corresponding differential result. + std::distance(originalFnTy->getParameters().begin(), &*resultParamIt); + // If the original semantic result parameter is a differentiability + // parameter, then it already has a corresponding differential + // parameter. Skip adding a corresponding differential result. if (parameterIndices->contains(paramIndex)) continue; - auto inoutParam = originalFnTy->getParameters()[paramIndex]; - auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap( - inoutParam.getInterfaceType(), lookupConformance, + + auto resultParam = originalFnTy->getParameters()[paramIndex]; + auto resultParamTanType = getAutoDiffTangentTypeForLinearMap( + resultParam.getInterfaceType(), lookupConformance, substGenericParams, substReplacements, ctx); - differentialResults.push_back( - {inoutParamTanType, ResultConvention::Indirect}); + differentialResults.emplace_back(resultParamTanType, + ResultConvention::Indirect); } SubstitutionMap substitutions; @@ -734,28 +736,29 @@ static CanSILFunctionType getAutoDiffPullbackType( ->getAutoDiffTangentSpace(lookupConformance) ->getCanonicalType(), origRes.getConvention()); - pullbackParams.push_back({resultTanType, paramConv}); + pullbackParams.emplace_back(resultTanType, paramConv); continue; } - // Handle `inout` parameters. - auto inoutParamIndex = resultIndex - originalFnTy->getNumResults(); - auto inoutParamIt = std::next( - originalFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex); + // Handle original semantic result parameters. + auto resultParamIndex = resultIndex - originalFnTy->getNumResults(); + auto resultParamIt = std::next( + originalFnTy->getAutoDiffSemanticResultsParameters().begin(), + resultParamIndex); auto paramIndex = - std::distance(originalFnTy->getParameters().begin(), &*inoutParamIt); - auto inoutParam = originalFnTy->getParameters()[paramIndex]; + std::distance(originalFnTy->getParameters().begin(), &*resultParamIt); + auto resultParam = originalFnTy->getParameters()[paramIndex]; // The pullback parameter convention depends on whether the original `inout` // parameter is a differentiability parameter. // - If yes, the pullback parameter convention is `@inout`. // - If no, the pullback parameter convention is `@in_guaranteed`. - auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap( - inoutParam.getInterfaceType(), lookupConformance, - substGenericParams, substReplacements, ctx); - bool isWrtInoutParameter = parameterIndices->contains(paramIndex); - auto paramTanConvention = isWrtInoutParameter - ? inoutParam.getConvention() - : ParameterConvention::Indirect_In_Guaranteed; - pullbackParams.push_back({inoutParamTanType, paramTanConvention}); + auto resultParamTanType = getAutoDiffTangentTypeForLinearMap( + resultParam.getInterfaceType(), lookupConformance, + substGenericParams, substReplacements, ctx); + ParameterConvention paramTanConvention = resultParam.getConvention(); + if (!parameterIndices->contains(paramIndex)) + paramTanConvention = ParameterConvention::Indirect_In_Guaranteed; + + pullbackParams.emplace_back(resultParamTanType, paramTanConvention); } // Collect pullback results. @@ -763,9 +766,9 @@ static CanSILFunctionType getAutoDiffPullbackType( getDifferentiabilityParameters(originalFnTy, parameterIndices, diffParams); SmallVector pullbackResults; for (auto ¶m : diffParams) { - // Skip `inout` parameters, which semantically behave as original results - // and always appear as pullback parameters. - if (param.isIndirectMutating()) + // Skip semantic result parameters, which semantically behave as original + // results and always appear as pullback parameters. + if (param.isAutoDiffSemanticResult()) continue; auto paramTanType = getAutoDiffTangentTypeForLinearMap( param.getInterfaceType(), lookupConformance, @@ -898,6 +901,7 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType( origTypeOfAbstraction, TC); break; } + // Compute the derivative function parameters. SmallVector newParameters; newParameters.reserve(constrainedOriginalFnTy->getNumParameters()); @@ -4091,6 +4095,40 @@ static llvm::cl::opt DisableConstantInfoCache("sil-disable-typelowering-constantinfo-cache", llvm::cl::init(false)); +static IndexSubset * +getLoweredResultIndices(const SILFunctionType *functionType, + const IndexSubset *parameterIndices) { + SmallVector resultIndices; + + // Check formal results. + for (auto resultAndIndex : enumerate(functionType->getResults())) + if (resultAndIndex.value().getDifferentiability() != + SILResultDifferentiability::NotDifferentiable) + resultIndices.push_back(resultAndIndex.index()); + + auto numResults = functionType->getNumResults(); + + // Collect semantic result parameters. + unsigned semResultParamIdx = 0; + for (auto resultParamAndIndex + : enumerate(functionType->getParameters())) { + if (!resultParamAndIndex.value().isAutoDiffSemanticResult()) + continue; + + if (resultParamAndIndex.value().getDifferentiability() != + SILParameterDifferentiability::NotDifferentiable && + parameterIndices->contains(resultParamAndIndex.index())) + resultIndices.push_back(numResults + semResultParamIdx); + semResultParamIdx += 1; + } + + numResults += semResultParamIdx; + + return IndexSubset::get(functionType->getASTContext(), + numResults, resultIndices); +} + + const SILConstantInfo & TypeConverter::getConstantInfo(TypeExpansionContext expansion, SILDeclRef constant) { @@ -4149,11 +4187,9 @@ TypeConverter::getConstantInfo(TypeExpansionContext expansion, // Use it to compute lowered derivative function type. auto *loweredParamIndices = autodiff::getLoweredParameterIndices( derivativeId->getParameterIndices(), formalInterfaceType); - auto numResults = - origFnConstantInfo.SILFnType->getNumResults() + - origFnConstantInfo.SILFnType->getNumIndirectMutatingParameters(); - auto *loweredResultIndices = IndexSubset::getDefault( - M.getASTContext(), numResults, /*includeAll*/ true); + auto *loweredResultIndices + = getLoweredResultIndices(origFnConstantInfo.SILFnType, loweredParamIndices); + silFnType = origFnConstantInfo.SILFnType->getAutoDiffDerivativeFunctionType( loweredParamIndices, loweredResultIndices, derivativeId->getKind(), *this, LookUpConformanceInModule(&M)); diff --git a/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp b/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp index 0cd474d74b521..2c1fbc5876bbf 100644 --- a/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp +++ b/lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp @@ -138,8 +138,8 @@ void DifferentiableActivityInfo::propagateVaried( if (isVaried(operand->get(), i)) { for (auto indRes : applySite.getIndirectSILResults()) propagateVariedInwardsThroughProjections(indRes, i); - for (auto inoutArg : applySite.getInoutArguments()) - propagateVariedInwardsThroughProjections(inoutArg, i); + for (auto semresArg : applySite.getAutoDiffSemanticResultArguments()) + propagateVariedInwardsThroughProjections(semresArg, i); // Propagate variedness to apply site direct results. forEachApplyDirectResult(applySite, [&](SILValue directResult) { setVariedAndPropagateToUsers(directResult, i); diff --git a/lib/SILOptimizer/Differentiation/Common.cpp b/lib/SILOptimizer/Differentiation/Common.cpp index a8f749b3187c3..aec5f1e09b34a 100644 --- a/lib/SILOptimizer/Differentiation/Common.cpp +++ b/lib/SILOptimizer/Differentiation/Common.cpp @@ -147,11 +147,11 @@ void collectAllFormalResultsInTypeOrder(SILFunction &function, for (auto &resInfo : convs.getResults()) results.push_back(resInfo.isFormalDirect() ? dirResults[dirResIdx++] : indResults[indResIdx++]); - // Treat `inout` parameters as semantic results. - // Append `inout` parameters after formal results. + // Treat semantic result parameters as semantic results. + // Append them` parameters after formal results. for (auto i : range(convs.getNumParameters())) { auto paramInfo = convs.getParameters()[i]; - if (!paramInfo.isIndirectMutating()) + if (!paramInfo.isAutoDiffSemanticResult()) continue; auto *argument = function.getArgumentsWithoutIndirectResults()[i]; results.push_back(argument); @@ -190,6 +190,7 @@ void collectMinimalIndicesForFunctionCall( SmallVectorImpl &resultIndices) { auto calleeFnTy = ai->getSubstCalleeType(); auto calleeConvs = ai->getSubstCalleeConv(); + // Parameter indices are indices (in the callee type signature) of parameter // arguments that are varied or are arguments. // Record all parameter indices in type order. @@ -199,6 +200,7 @@ void collectMinimalIndicesForFunctionCall( paramIndices.push_back(currentParamIdx); ++currentParamIdx; } + // Result indices are indices (in the callee type signature) of results that // are useful. SmallVector directResults; @@ -226,22 +228,21 @@ void collectMinimalIndicesForFunctionCall( ++indResIdx; } } - // Record all `inout` parameters as results. - auto inoutParamResultIndex = calleeFnTy->getNumResults(); + + // Record all semantic result parameters as results. + auto semanticResultParamResultIndex = calleeFnTy->getNumResults(); for (auto ¶mAndIdx : enumerate(calleeConvs.getParameters())) { auto ¶m = paramAndIdx.value(); - if (!param.isIndirectMutating()) + if (!param.isAutoDiffSemanticResult()) continue; unsigned idx = paramAndIdx.index() + calleeFnTy->getNumIndirectFormalResults(); - auto inoutArg = ai->getArgument(idx); - results.push_back(inoutArg); - resultIndices.push_back(inoutParamResultIndex++); + results.push_back(ai->getArgument(idx)); + resultIndices.push_back(semanticResultParamResultIndex++); } + // Make sure the function call has active results. #ifndef NDEBUG - auto numResults = calleeFnTy->getNumResults() + - calleeFnTy->getNumIndirectMutatingParameters(); - assert(results.size() == numResults); + assert(results.size() == calleeFnTy->getNumAutoDiffSemanticResults()); assert(llvm::any_of(results, [&](SILValue result) { return activityInfo.isActive(result, parentConfig); })); diff --git a/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp b/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp index f469607de2758..9797c3e982381 100644 --- a/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp +++ b/lib/SILOptimizer/Differentiation/LinearMapInfo.cpp @@ -177,7 +177,7 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) { auto hasActiveResults = llvm::any_of(allResults, [&](SILValue res) { return activityInfo.isActive(res, config); }); - bool hasActiveInoutArgument = false; + bool hasActiveSemanticResultArgument = false; bool hasActiveArguments = false; auto numIndirectResults = ai->getNumIndirectResults(); for (auto argIdx : range(ai->getSubstCalleeConv().getNumParameters())) { @@ -186,13 +186,13 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) { hasActiveArguments = true; auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg( numIndirectResults + argIdx); - if (paramInfo.isIndirectMutating()) - hasActiveInoutArgument = true; + if (paramInfo.isAutoDiffSemanticResult()) + hasActiveSemanticResultArgument = true; } } if (!hasActiveArguments) return {}; - if (!hasActiveResults && !hasActiveInoutArgument) + if (!hasActiveResults && !hasActiveSemanticResultArgument) return {}; // Compute differentiability parameters. @@ -213,9 +213,8 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) { ai->getArgumentsWithoutIndirectResults().size(), activeParamIndices); } // Compute differentiability results. - auto numResults = remappedOrigFnSubstTy->getNumResults() + - remappedOrigFnSubstTy->getNumIndirectMutatingParameters(); - auto *results = IndexSubset::get(original->getASTContext(), numResults, + auto *results = IndexSubset::get(original->getASTContext(), + remappedOrigFnSubstTy->getNumAutoDiffSemanticResults(), activeResultIndices); // Create autodiff indices for the `apply` instruction. AutoDiffConfig applyConfig(parameters, results); @@ -234,10 +233,11 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) { for (auto resultIndex : applyConfig.resultIndices->getIndices()) { SILType remappedResultType; if (resultIndex >= origFnTy->getNumResults()) { - auto inoutArgIdx = resultIndex - origFnTy->getNumResults(); - auto inoutArg = - *std::next(ai->getInoutArguments().begin(), inoutArgIdx); - remappedResultType = inoutArg->getType(); + auto semanticResultArgIdx = resultIndex - origFnTy->getNumResults(); + auto semanticResultArg = + *std::next(ai->getAutoDiffSemanticResultArguments().begin(), + semanticResultArgIdx); + remappedResultType = semanticResultArg->getType(); } else { remappedResultType = origFnTy->getResults()[resultIndex].getSILStorageInterfaceType(); @@ -277,8 +277,9 @@ Type LinearMapInfo::getLinearMapType(ADContext &context, ApplyInst *ai) { SmallVector params; for (auto ¶m : silFnTy->getParameters()) { ParameterTypeFlags flags; - if (param.isIndirectMutating()) + if (param.isAutoDiffSemanticResult()) flags = flags.withInOut(true); + params.push_back( AnyFunctionType::Param(param.getInterfaceType(), Identifier(), flags)); } diff --git a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp index c7eb9aa769424..f6ef9a33fdfb7 100644 --- a/lib/SILOptimizer/Differentiation/PullbackCloner.cpp +++ b/lib/SILOptimizer/Differentiation/PullbackCloner.cpp @@ -864,6 +864,7 @@ class PullbackCloner::Implementation final /// Adjoint: (adj[x0], adj[x1], ...) += apply @fn_pullback (adj[y0], ...) void visitApplyInst(ApplyInst *ai) { assert(getPullbackInfo().shouldDifferentiateApplySite(ai)); + // Skip `array.uninitialized_intrinsic` applications, which have special // `store` and `copy_addr` support. if (ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC)) @@ -901,11 +902,11 @@ class PullbackCloner::Implementation final }); SmallVector origAllResults; collectAllActualResultsInTypeOrder(ai, origDirectResults, origAllResults); - // Append `inout` arguments after original results. + // Append semantic result arguments after original results. for (auto paramIdx : applyInfo.config.parameterIndices->getIndices()) { auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg( ai->getNumIndirectResults() + paramIdx); - if (!paramInfo.isIndirectMutating()) + if (!paramInfo.isAutoDiffSemanticResult()) continue; origAllResults.push_back( ai->getArgumentsWithoutIndirectResults()[paramIdx]); @@ -981,10 +982,10 @@ class PullbackCloner::Implementation final auto allResultsIt = allResults.begin(); for (unsigned i : applyInfo.config.parameterIndices->getIndices()) { auto origArg = ai->getArgument(ai->getNumIndirectResults() + i); - // Skip adjoint accumulation for `inout` arguments. + // Skip adjoint accumulation for semantic results arguments. auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg( ai->getNumIndirectResults() + i); - if (paramInfo.isIndirectMutating()) + if (paramInfo.isAutoDiffSemanticResult()) continue; auto tan = *allResultsIt++; if (tan->getType().isAddress()) { @@ -2036,6 +2037,7 @@ bool PullbackCloner::Implementation::run() { // the adjoint buffer of the original result. auto seedParamInfo = pullback.getLoweredFunctionType()->getParameters()[seedIndex]; + if (seedParamInfo.isIndirectInOut()) { setAdjointBuffer(originalExitBlock, origResult, seed); } @@ -2123,7 +2125,7 @@ bool PullbackCloner::Implementation::run() { // Collect differentiation parameter adjoints. // Do a first pass to collect non-inout values. for (auto i : getConfig().parameterIndices->getIndices()) { - if (!conv.getParameters()[i].isIndirectMutating()) { + if (!conv.getParameters()[i].isAutoDiffSemanticResult()) { addRetElt(i); } } @@ -2136,14 +2138,14 @@ bool PullbackCloner::Implementation::run() { const auto &pullbackConv = pullback.getConventions(); SmallVector pullbackInOutArgs; for (auto pullbackArg : enumerate(pullback.getArgumentsWithoutIndirectResults())) { - if (pullbackConv.getParameters()[pullbackArg.index()].isIndirectMutating()) + if (pullbackConv.getParameters()[pullbackArg.index()].isAutoDiffSemanticResult()) pullbackInOutArgs.push_back(pullbackArg.value()); } unsigned pullbackInoutArgumentIdx = 0; for (auto i : getConfig().parameterIndices->getIndices()) { // Skip non-inout parameters. - if (!conv.getParameters()[i].isIndirectMutating()) + if (!conv.getParameters()[i].isAutoDiffSemanticResult()) continue; // For functions with multiple basic blocks, accumulation is needed diff --git a/lib/SILOptimizer/Differentiation/Thunk.cpp b/lib/SILOptimizer/Differentiation/Thunk.cpp index 487bce2929183..116ea3a2d228a 100644 --- a/lib/SILOptimizer/Differentiation/Thunk.cpp +++ b/lib/SILOptimizer/Differentiation/Thunk.cpp @@ -365,7 +365,9 @@ getOrCreateSubsetParametersThunkForLinearMap( const AutoDiffConfig &desiredConfig, const AutoDiffConfig &actualConfig, ADContext &adContext) { LLVM_DEBUG(getADDebugStream() - << "Getting a subset parameters thunk for " << linearMapType + << "Getting a subset parameters thunk for " + << (kind == AutoDiffDerivativeFunctionKind::JVP ? "jvp" : "vjp") + << " linear map " << linearMapType << " from " << actualConfig << " to " << desiredConfig << '\n'); assert(!linearMapType->getCombinedSubstitutions()); @@ -539,10 +541,10 @@ getOrCreateSubsetParametersThunkForLinearMap( unsigned pullbackResultIndex = 0; for (unsigned i : actualConfig.parameterIndices->getIndices()) { auto origParamInfo = origFnType->getParameters()[i]; - // Skip original `inout` parameters. All non-indirect-result pullback - // arguments (including `inout` arguments) are appended to `arguments` + // Skip original semantic result parameters. All non-indirect-result pullback + // arguments (including semantic result` arguments) are appended to `arguments` // later. - if (origParamInfo.isIndirectMutating()) + if (origParamInfo.isAutoDiffSemanticResult()) continue; auto resultInfo = linearMapType->getResults()[pullbackResultIndex]; assert(pullbackResultIndex < linearMapType->getNumResults()); @@ -619,16 +621,18 @@ getOrCreateSubsetParametersThunkForLinearMap( extractAllElements(ai, builder, pullbackDirectResults); SmallVector allResults; collectAllActualResultsInTypeOrder(ai, pullbackDirectResults, allResults); - // Collect pullback `inout` arguments in type order. - unsigned inoutArgIdx = 0; + // Collect pullback semantic result arguments in type order. + unsigned semanticResultArgIdx = 0; SILFunctionConventions origConv(origFnType, thunk->getModule()); for (auto paramIdx : actualConfig.parameterIndices->getIndices()) { auto paramInfo = origConv.getParameters()[paramIdx]; - if (!paramInfo.isIndirectMutating()) + if (!paramInfo.isAutoDiffSemanticResult()) continue; - auto inoutArg = *std::next(ai->getInoutArguments().begin(), inoutArgIdx++); + auto semanticResultArg = + *std::next(ai->getAutoDiffSemanticResultArguments().begin(), + semanticResultArgIdx++); unsigned mappedParamIdx = mapOriginalParameterIndex(paramIdx); - allResults.insert(allResults.begin() + mappedParamIdx, inoutArg); + allResults.insert(allResults.begin() + mappedParamIdx, semanticResultArg); } assert(allResults.size() == actualConfig.parameterIndices->getNumIndices() && "Number of pullback results should match number of differentiability " @@ -668,8 +672,10 @@ getOrCreateSubsetParametersThunkForDerivativeFunction( AutoDiffDerivativeFunctionKind kind, const AutoDiffConfig &desiredConfig, const AutoDiffConfig &actualConfig, ADContext &adContext) { LLVM_DEBUG(getADDebugStream() - << "Getting a subset parameters thunk for derivative function " - << derivativeFn << " of the original function " << origFnOperand + << "Getting a subset parameters thunk for derivative " + << (kind == AutoDiffDerivativeFunctionKind::JVP ? "jvp" : "vjp") + << " function " << derivativeFn + << " of the original function " << origFnOperand << " from " << actualConfig << " to " << desiredConfig << '\n'); auto origFnType = origFnOperand->getType().castTo(); @@ -823,9 +829,7 @@ getOrCreateSubsetParametersThunkForDerivativeFunction( SILType::getPrimitiveObjectType(linearMapTargetType), /*withoutActuallyEscaping*/ false); } - assert(origFnType->getNumResults() + - origFnType->getNumIndirectMutatingParameters() > - 0); + assert(origFnType->getNumAutoDiffSemanticResults() > 0); if (origFnType->getNumResults() > 0 && origFnType->getResults().front().isFormalDirect()) { directResults.push_back(thunkedLinearMap); diff --git a/lib/SILOptimizer/Differentiation/VJPCloner.cpp b/lib/SILOptimizer/Differentiation/VJPCloner.cpp index f91175c655e91..7cb715b278553 100644 --- a/lib/SILOptimizer/Differentiation/VJPCloner.cpp +++ b/lib/SILOptimizer/Differentiation/VJPCloner.cpp @@ -449,24 +449,22 @@ class VJPCloner::Implementation final activeResultIndices); assert(!activeParamIndices.empty() && "Parameter indices cannot be empty"); assert(!activeResultIndices.empty() && "Result indices cannot be empty"); - LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params={"; + LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params=("; llvm::interleave( activeParamIndices.begin(), activeParamIndices.end(), [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); - s << "}, results={"; llvm::interleave( + s << "), results=("; llvm::interleave( activeResultIndices.begin(), activeResultIndices.end(), [&s](unsigned i) { s << i; }, [&s] { s << ", "; }); - s << "}\n";); + s << ")\n";); // Form expected indices. - auto numSemanticResults = - ai->getSubstCalleeType()->getNumResults() + - ai->getSubstCalleeType()->getNumIndirectMutatingParameters(); AutoDiffConfig config( IndexSubset::get(getASTContext(), ai->getArgumentsWithoutIndirectResults().size(), activeParamIndices), - IndexSubset::get(getASTContext(), numSemanticResults, + IndexSubset::get(getASTContext(), + ai->getSubstCalleeType()->getNumAutoDiffSemanticResults(), activeResultIndices)); // Emit the VJP. @@ -537,10 +535,11 @@ class VJPCloner::Implementation final for (auto resultIndex : config.resultIndices->getIndices()) { SILType remappedResultType; if (resultIndex >= originalFnTy->getNumResults()) { - auto inoutArgIdx = resultIndex - originalFnTy->getNumResults(); - auto inoutArg = - *std::next(ai->getInoutArguments().begin(), inoutArgIdx); - remappedResultType = inoutArg->getType(); + auto semanticResultArgIdx = resultIndex - originalFnTy->getNumResults(); + auto semanticResultArg = + *std::next(ai->getAutoDiffSemanticResultArguments().begin(), + semanticResultArgIdx); + remappedResultType = semanticResultArg->getType(); } else { remappedResultType = originalFnTy->getResults()[resultIndex] .getSILStorageInterfaceType(); @@ -891,55 +890,57 @@ SILFunction *VJPCloner::Implementation::createEmptyPullback() { auto config = witness->getConfig(); // Add pullback parameters based on original result indices. - SmallVector inoutParamIndices; + SmallVector semanticResultParamIndices; for (auto i : range(origTy->getNumParameters())) { auto origParam = origParams[i]; - if (!origParam.isIndirectInOut()) + if (!origParam.isAutoDiffSemanticResult()) continue; - inoutParamIndices.push_back(i); + semanticResultParamIndices.push_back(i); } + for (auto resultIndex : config.resultIndices->getIndices()) { // Handle formal result. if (resultIndex < origTy->getNumResults()) { auto origResult = origTy->getResults()[resultIndex]; origResult = origResult.getWithInterfaceType( origResult.getInterfaceType()->getReducedType(witnessCanGenSig)); - pbParams.push_back(getTangentParameterInfoForOriginalResult( + auto paramInfo = getTangentParameterInfoForOriginalResult( origResult.getInterfaceType() ->getAutoDiffTangentSpace(lookupConformance) ->getType() ->getReducedType(witnessCanGenSig), - origResult.getConvention())); + origResult.getConvention()); + pbParams.push_back(paramInfo); continue; } - // Handle `inout` parameter. + + // Handle semantic result parameter. unsigned paramIndex = 0; - unsigned inoutParamIndex = 0; + unsigned resultParamIndex = 0; for (auto i : range(origTy->getNumParameters())) { auto origParam = origTy->getParameters()[i]; - if (!origParam.isIndirectMutating()) { + if (!origParam.isAutoDiffSemanticResult()) { ++paramIndex; continue; } - if (inoutParamIndex == resultIndex - origTy->getNumResults()) + if (resultParamIndex == resultIndex - origTy->getNumResults()) break; ++paramIndex; - ++inoutParamIndex; + ++resultParamIndex; } - auto inoutParam = origParams[paramIndex]; - auto origResult = inoutParam.getWithInterfaceType( - inoutParam.getInterfaceType()->getReducedType(witnessCanGenSig)); - auto inoutParamTanConvention = - config.isWrtParameter(paramIndex) - ? inoutParam.getConvention() - : ParameterConvention::Indirect_In_Guaranteed; - SILParameterInfo inoutParamTanParam( - origResult.getInterfaceType() - ->getAutoDiffTangentSpace(lookupConformance) - ->getType() - ->getReducedType(witnessCanGenSig), - inoutParamTanConvention); - pbParams.push_back(inoutParamTanParam); + auto resultParam = origParams[paramIndex]; + auto origResult = resultParam.getWithInterfaceType( + resultParam.getInterfaceType()->getReducedType(witnessCanGenSig)); + + auto resultParamTanConvention = resultParam.getConvention(); + if (!config.isWrtParameter(paramIndex)) + resultParamTanConvention = ParameterConvention::Indirect_In_Guaranteed; + + pbParams.emplace_back(origResult.getInterfaceType() + ->getAutoDiffTangentSpace(lookupConformance) + ->getType() + ->getReducedType(witnessCanGenSig), + resultParamTanConvention); } if (pullbackInfo.hasHeapAllocatedContext()) { @@ -961,7 +962,7 @@ SILFunction *VJPCloner::Implementation::createEmptyPullback() { // Add pullback results for the requested wrt parameters. for (auto i : config.parameterIndices->getIndices()) { auto origParam = origParams[i]; - if (origParam.isIndirectMutating()) + if (origParam.isAutoDiffSemanticResult()) continue; origParam = origParam.getWithInterfaceType( origParam.getInterfaceType()->getReducedType(witnessCanGenSig)); @@ -997,6 +998,7 @@ SILFunction *VJPCloner::Implementation::createEmptyPullback() { original->isRuntimeAccessible()); pullback->setDebugScope(new (module) SILDebugScope(original->getLocation(), pullback)); + return pullback; } diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 8a1f59a260a76..60f57f80ebe85 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -544,11 +544,11 @@ emitDerivativeFunctionReference( for (auto resultIndex : desiredResultIndices->getIndices()) { SILType resultType; if (resultIndex >= originalFnTy->getNumResults()) { - auto inoutParamIdx = resultIndex - originalFnTy->getNumResults(); - auto inoutParam = - *std::next(originalFnTy->getIndirectMutatingParameters().begin(), - inoutParamIdx); - resultType = inoutParam.getSILStorageInterfaceType(); + auto semanticResultParamIdx = resultIndex - originalFnTy->getNumResults(); + auto semanticResultParam = + *std::next(originalFnTy->getAutoDiffSemanticResultsParameters().begin(), + semanticResultParamIdx); + resultType = semanticResultParam.getSILStorageInterfaceType(); } else { resultType = originalFnTy->getResults()[resultIndex] .getSILStorageInterfaceType(); diff --git a/test/AutoDiff/SILGen/inout_differentiability_witness.swift b/test/AutoDiff/SILGen/inout_differentiability_witness.swift index e49b4e92a947d..f146f966a2b14 100644 --- a/test/AutoDiff/SILGen/inout_differentiability_witness.swift +++ b/test/AutoDiff/SILGen/inout_differentiability_witness.swift @@ -17,7 +17,7 @@ func test3(x: Int, y: inout DiffableStruct, z: Float) -> (Float, Float) { return @differentiable(reverse, wrt: y) func test4(x: Int, y: inout DiffableStruct, z: Float) -> Void { } -@differentiable(reverse, wrt: z) +@differentiable(reverse, wrt: (y, z)) func test5(x: Int, y: inout DiffableStruct, z: Float) -> Void { } @differentiable(reverse, wrt: (y, z)) @@ -48,9 +48,9 @@ func test6(x: Int, y: inout DiffableStruct, z: Float) -> (Float, Float) { return // CHECK: } // CHECK-LABEL: differentiability witness for test5(x:y:z:) -// CHECK: sil_differentiability_witness hidden [reverse] [parameters 2] [results 0] @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftF : $@convention(thin) (Int, @inout DiffableStruct, Float) -> () { -// CHECK: jvp: @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftFTJfUUSpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (Float) -> @out DiffableStruct.TangentVector -// CHECK: vjp: @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftFTJrUUSpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@in_guaranteed DiffableStruct.TangentVector) -> Float +// CHECK: sil_differentiability_witness hidden [reverse] [parameters 1 2] [results 0] @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftF : $@convention(thin) (Int, @inout DiffableStruct, Float) -> () { +// CHECK: jvp: @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftFTJfUSSpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@inout DiffableStruct.TangentVector, Float) -> () +// CHECK: vjp: @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftFTJrUSSpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@inout DiffableStruct.TangentVector) -> Float // CHECK: } // CHECK-LABEL: differentiability witness for test6(x:y:z:) diff --git a/test/AutoDiff/SILGen/witness_table.swift b/test/AutoDiff/SILGen/witness_table.swift index 5928bd844d09f..4f631f4067dc2 100644 --- a/test/AutoDiff/SILGen/witness_table.swift +++ b/test/AutoDiff/SILGen/witness_table.swift @@ -12,7 +12,7 @@ protocol Protocol: Differentiable { @differentiable(reverse) var property: Float { get set } - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (self, x)) subscript(_ x: Float, _ y: Float) -> Float { get set } } @@ -82,22 +82,22 @@ struct Struct: Protocol { // CHECK: apply [[VJP_FN]] // CHECK: } - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (self, x)) subscript(_ x: Float, _ y: Float) -> Float { get { x } set {} } - // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_jvp_SUU : $@convention(witness_method: Protocol) (Float, Float, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed (Float) -> Float) { + // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_jvp_SUS : $@convention(witness_method: Protocol) (Float, Float, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed @substituted <Ï„_0_0> (Float, @in_guaranteed Ï„_0_0) -> Float for ) // CHECK: [[ORIG_FN:%.*]] = function_ref @$s13witness_table6StructVyS2f_Sftcig : $@convention(method) (Float, Float, Struct) -> Float - // CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [results 0] [[ORIG_FN]] + // CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0 2] [results 0] [[ORIG_FN]] // CHECK: [[JVP_FN:%.*]] = differentiable_function_extract [jvp] [[DIFF_FN]] // CHECK: apply [[JVP_FN]] // CHECK: } - // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_vjp_SUU : $@convention(witness_method: Protocol) (Float, Float, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed (Float) -> Float) { + // CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_vjp_SUS : $@convention(witness_method: Protocol) (Float, Float, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed @substituted <Ï„_0_0> (Float) -> (Float, @out Ï„_0_0) for ) // CHECK: [[ORIG_FN:%.*]] = function_ref @$s13witness_table6StructVyS2f_Sftcig : $@convention(method) (Float, Float, Struct) -> Float - // CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [results 0] [[ORIG_FN]] + // CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0 2] [results 0] [[ORIG_FN]] // CHECK: [[VJP_FN:%.*]] = differentiable_function_extract [vjp] [[DIFF_FN]] // CHECK: apply [[VJP_FN]] // CHECK: } @@ -118,10 +118,10 @@ struct Struct: Protocol { // CHECK-NEXT: method #Protocol.property!setter.vjp.SS.: (inout Self) -> (Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDP8propertySfvsTW_vjp_SS // CHECK-NEXT: method #Protocol.property!modify: (inout Self) -> () -> () : @$s13witness_table6StructVAA8ProtocolA2aDP8propertySfvMTW // CHECK-NEXT: method #Protocol.subscript!getter: (Self) -> (Float, Float) -> Float : @$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW -// CHECK-NEXT: method #Protocol.subscript!getter.jvp.SUU.: (Self) -> (Float, Float) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_jvp_SU -// CHECK-NEXT: method #Protocol.subscript!getter.vjp.SUU.: (Self) -> (Float, Float) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_vjp_SUU +// CHECK-NEXT: method #Protocol.subscript!getter.jvp.SUS.: (Self) -> (Float, Float) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_jvp_SUS +// CHECK-NEXT: method #Protocol.subscript!getter.vjp.SUS.: (Self) -> (Float, Float) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_vjp_SUS // CHECK-NEXT: method #Protocol.subscript!setter: (inout Self) -> (Float, Float, Float) -> () : @$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW -// CHECK-NEXT: method #Protocol.subscript!setter.jvp.USUU.: (inout Self) -> (Float, Float, Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW_jvp_USUU -// CHECK-NEXT: method #Protocol.subscript!setter.vjp.USUU.: (inout Self) -> (Float, Float, Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW_vjp_USUU +// CHECK-NEXT: method #Protocol.subscript!setter.jvp.USUS.: (inout Self) -> (Float, Float, Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW_jvp_USUS +// CHECK-NEXT: method #Protocol.subscript!setter.vjp.USUS.: (inout Self) -> (Float, Float, Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW_vjp_USUS // CHECK-NEXT: method #Protocol.subscript!modify: (inout Self) -> (Float, Float) -> () : @$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftciMTW // CHECK: } diff --git a/test/AutoDiff/SILOptimizer/activity_analysis.swift b/test/AutoDiff/SILOptimizer/activity_analysis.swift index 02f75ebba20f4..f2cf4d0d9bf0c 100644 --- a/test/AutoDiff/SILOptimizer/activity_analysis.swift +++ b/test/AutoDiff/SILOptimizer/activity_analysis.swift @@ -407,23 +407,21 @@ func testArrayUninitializedIntrinsicApplyIndirectResult(_ x: T, _ y: T) -> [W struct Mut: Differentiable {} extension Mut { - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (self, x)) mutating func mutatingMethod(_ x: Mut) {} } -// CHECK-LABEL: [AD] Activity info for ${{.*}}3MutV14mutatingMethodyyACF at parameter indices (0) and result indices (0) +// CHECK-LABEL: [AD] Activity info for ${{.*}}3MutV14mutatingMethodyyACF at parameter indices (0, 1) and result indices (0) // CHECK: [VARIED] %0 = argument of bb0 : $Mut -// CHECK: [USEFUL] %1 = argument of bb0 : $*Mut +// CHECK: [ACTIVE] %1 = argument of bb0 : $*Mut -// TODO(TF-985): Find workaround to avoid marking non-wrt `inout` argument as -// active. -@differentiable(reverse, wrt: x) +@differentiable(reverse, wrt: (nonactive, x)) func nonActiveInoutArg(_ nonactive: inout Mut, _ x: Mut) { nonactive.mutatingMethod(x) nonactive = x } -// CHECK-LABEL: [AD] Activity info for ${{.*}}17nonActiveInoutArgyyAA3MutVz_ADtF at parameter indices (1) and result indices (0) +// CHECK-LABEL: [AD] Activity info for ${{.*}}17nonActiveInoutArgyyAA3MutVz_ADtF at parameter indices (0, 1) and result indices (0) // CHECK: [ACTIVE] %0 = argument of bb0 : $*Mut // CHECK: [ACTIVE] %1 = argument of bb0 : $Mut // CHECK: [ACTIVE] %4 = begin_access [modify] [static] %0 : $*Mut @@ -449,14 +447,14 @@ func activeInoutArgMutatingMethod(_ x: Mut) -> Mut { // CHECK: [ACTIVE] %11 = begin_access [read] [static] %2 : $*Mut // CHECK: [ACTIVE] %12 = load [trivial] %11 : $*Mut -@differentiable(reverse, wrt: x) +@differentiable(reverse, wrt: (nonactive, x)) func activeInoutArgMutatingMethodVar(_ nonactive: inout Mut, _ x: Mut) { var result = nonactive result.mutatingMethod(x) nonactive = result } -// CHECK-LABEL: [AD] Activity info for ${{.*}}31activeInoutArgMutatingMethodVaryyAA3MutVz_ADtF at parameter indices (1) and result indices (0) +// CHECK-LABEL: [AD] Activity info for ${{.*}}31activeInoutArgMutatingMethodVaryyAA3MutVz_ADtF at parameter indices (0, 1) and result indices (0) // CHECK: [ACTIVE] %0 = argument of bb0 : $*Mut // CHECK: [ACTIVE] %1 = argument of bb0 : $Mut // CHECK: [ACTIVE] %4 = alloc_stack $Mut, var, name "result" @@ -470,14 +468,14 @@ func activeInoutArgMutatingMethodVar(_ nonactive: inout Mut, _ x: Mut) { // CHECK: [ACTIVE] %15 = begin_access [modify] [static] %0 : $*Mut // CHECK: [NONE] %19 = tuple () -@differentiable(reverse, wrt: x) +@differentiable(reverse, wrt: (nonactive, x)) func activeInoutArgMutatingMethodTuple(_ nonactive: inout Mut, _ x: Mut) { var result = (nonactive, x) result.0.mutatingMethod(result.0) nonactive = result.0 } -// CHECK-LABEL: [AD] Activity info for ${{.*}}33activeInoutArgMutatingMethodTupleyyAA3MutVz_ADtF at parameter indices (1) and result indices (0) +// CHECK-LABEL: [AD] Activity info for ${{.*}}33activeInoutArgMutatingMethodTupleyyAA3MutVz_ADtF at parameter indices (0, 1) and result indices (0) // CHECK: [ACTIVE] %0 = argument of bb0 : $*Mut // CHECK: [ACTIVE] %1 = argument of bb0 : $Mut // CHECK: [ACTIVE] %4 = alloc_stack $(Mut, Mut), var, name "result" @@ -499,39 +497,39 @@ func activeInoutArgMutatingMethodTuple(_ nonactive: inout Mut, _ x: Mut) { // Check `inout` arguments. @differentiable(reverse) -func activeInoutArg(_ x: Float) -> Float { +func activeInoutArg(_ x: inout Float) -> Float { var result = x result += x return result } -// CHECK-LABEL: [AD] Activity info for ${{.*}}activeInoutArg{{.*}} at parameter indices (0) and result indices (0) -// CHECK: [ACTIVE] %0 = argument of bb0 : $Float +// CHECK-LABEL: [AD] Activity info for ${{.*}}activeInoutArg{{.*}} at parameter indices (0) and result indices (0, 1) +// CHECK: [ACTIVE] %0 = argument of bb0 : $*Float // CHECK: [ACTIVE] %2 = alloc_stack $Float, var, name "result" -// CHECK: [ACTIVE] %5 = begin_access [modify] [static] %2 : $*Float +// CHECK: [ACTIVE] %10 = begin_access [modify] [static] %2 : $*Float // CHECK: [NONE] // function_ref static Float.+= infix(_:_:) -// CHECK: [NONE] %7 = apply %6(%5, %0, %4) : $@convention(method) (@inout Float, Float, @thin Float.Type) -> () -// CHECK: [ACTIVE] %9 = begin_access [read] [static] %2 : $*Float -// CHECK: [ACTIVE] %10 = load [trivial] %9 : $*Float +// CHECK: [NONE] %12 = apply %11(%10, %8, %6) : $@convention(method) (@inout Float, Float, @thin Float.Type) -> () +// CHECK: [ACTIVE] %14 = begin_access [read] [static] %2 : $*Float +// CHECK: [ACTIVE] %15 = load [trivial] %14 : $*Float @differentiable(reverse) -func activeInoutArgNonactiveInitialResult(_ x: Float) -> Float { +func activeInoutArgNonactiveInitialResult(_ x: inout Float) -> Float { var result: Float = 1 result += x return result } -// CHECK-LABEL: [AD] Activity info for ${{.*}}activeInoutArgNonactiveInitialResult{{.*}} at parameter indices (0) and result indices (0) -// CHECK: [ACTIVE] %0 = argument of bb0 : $Float +// CHECK-LABEL: [AD] Activity info for ${{.*}}activeInoutArgNonactiveInitialResult{{.*}} at parameter indices (0) and result indices (0, 1) +// CHECK: [ACTIVE] %0 = argument of bb0 : $*Float // CHECK: [ACTIVE] %2 = alloc_stack $Float, var, name "result" // CHECK: [NONE] // function_ref Float.init(_builtinIntegerLiteral:) // CHECK: [USEFUL] %6 = apply %5(%3, %4) : $@convention(method) (Builtin.IntLiteral, @thin Float.Type) -> Float // CHECK: [USEFUL] %8 = metatype $@thin Float.Type -// CHECK: [ACTIVE] %9 = begin_access [modify] [static] %2 : $*Float +// CHECK: [ACTIVE] %12 = begin_access [modify] [static] %2 : $*Float // CHECK: [NONE] // function_ref static Float.+= infix(_:_:) -// CHECK: [NONE] %11 = apply %10(%9, %0, %8) : $@convention(method) (@inout Float, Float, @thin Float.Type) -> () -// CHECK: [ACTIVE] %13 = begin_access [read] [static] %2 : $*Float -// CHECK: [ACTIVE] %14 = load [trivial] %13 : $*Float +// CHECK: [NONE] %14 = apply %13(%12, %10, %8) : $@convention(method) (@inout Float, Float, @thin Float.Type) -> () +// CHECK: [ACTIVE] %16 = begin_access [read] [static] %2 : $*Float +// CHECK: [ACTIVE] %17 = load [trivial] %16 : $*Float //===----------------------------------------------------------------------===// // Throwing function differentiation (`try_apply`) diff --git a/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift b/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift index 0f106c55bff56..e149bcb3116db 100644 --- a/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift +++ b/test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift @@ -400,11 +400,11 @@ func activeInoutParamControlFlowComplex(_ array: [Float], _ bool: Bool) -> Float struct Mut: Differentiable {} extension Mut { - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (self, x)) mutating func mutatingMethod(_ x: Mut) {} } -@differentiable(reverse, wrt: x) +@differentiable(reverse, wrt: (nonactive, x)) func nonActiveInoutParam(_ nonactive: inout Mut, _ x: Mut) { nonactive.mutatingMethod(x) } @@ -416,14 +416,14 @@ func activeInoutParamMutatingMethod(_ x: Mut) -> Mut { return result } -@differentiable(reverse, wrt: x) +@differentiable(reverse, wrt: (nonactive, x)) func activeInoutParamMutatingMethodVar(_ nonactive: inout Mut, _ x: Mut) { var result = nonactive result.mutatingMethod(x) nonactive = result } -@differentiable(reverse, wrt: x) +@differentiable(reverse, wrt: (nonactive, x)) func activeInoutParamMutatingMethodTuple(_ nonactive: inout Mut, _ x: Mut) { var result = (nonactive, x) result.0.mutatingMethod(result.0) diff --git a/test/AutoDiff/SILOptimizer/forward_mode_diagnostics.swift b/test/AutoDiff/SILOptimizer/forward_mode_diagnostics.swift index 5fb4d13407bba..cd29f9019c45c 100644 --- a/test/AutoDiff/SILOptimizer/forward_mode_diagnostics.swift +++ b/test/AutoDiff/SILOptimizer/forward_mode_diagnostics.swift @@ -89,7 +89,7 @@ func activeInoutParamControlFlow(_ array: [Float]) -> Float { struct X: Differentiable { var x: Float - @differentiable(reverse, wrt: y) + @differentiable(reverse, wrt: (self, y)) mutating func mutate(_ y: X) { self.x = y.x } } @@ -104,7 +104,7 @@ func activeMutatingMethod(_ x: Float) -> Float { struct Mut: Differentiable {} extension Mut { - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (self, x)) mutating func mutatingMethod(_ x: Mut) {} } diff --git a/test/AutoDiff/Sema/derivative_attr_type_checking.swift b/test/AutoDiff/Sema/derivative_attr_type_checking.swift index f3e26e8a32b13..5bd615d1188a1 100644 --- a/test/AutoDiff/Sema/derivative_attr_type_checking.swift +++ b/test/AutoDiff/Sema/derivative_attr_type_checking.swift @@ -769,7 +769,8 @@ struct InoutParameters: Differentiable { } extension InoutParameters { - // expected-note @+1 4 {{'staticMethod(_:rhs:)' defined here}} + // expected-note @+2 {{'staticMethod(_:rhs:)' defined here}} + // expected-note @+1 {{'staticMethod(_:rhs:)' defined here}} static func staticMethod(_ lhs: inout Self, rhs: Self) {} // Test wrt `inout` parameter. @@ -800,33 +801,34 @@ extension InoutParameters { // Test non-wrt `inout` parameter. + // expected-error @+1 {{cannot differentiate void function 'staticMethod(_:rhs:)'}} @derivative(of: staticMethod, wrt: rhs) static func vjpNotWrtInout(_ lhs: inout Self, _ rhs: Self) -> ( value: Void, pullback: (TangentVector) -> TangentVector ) { fatalError() } - // expected-error @+1 {{function result's 'pullback' type does not match 'staticMethod(_:rhs:)'}} + // expected-error @+1 {{cannot differentiate void function 'staticMethod(_:rhs:)'}} @derivative(of: staticMethod, wrt: rhs) static func vjpNotWrtInoutMismatch(_ lhs: inout Self, _ rhs: Self) -> ( - // expected-note @+1 {{'pullback' does not have expected type '(InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(DummyTangentVector) -> DummyTangentVector')}} value: Void, pullback: (inout TangentVector) -> TangentVector ) { fatalError() } + // expected-error @+1 {{cannot differentiate void function 'staticMethod(_:rhs:)'}} @derivative(of: staticMethod, wrt: rhs) static func jvpNotWrtInout(_ lhs: inout Self, _ rhs: Self) -> ( value: Void, differential: (TangentVector) -> TangentVector ) { fatalError() } - // expected-error @+1 {{function result's 'differential' type does not match 'staticMethod(_:rhs:)'}} + // expected-error @+1 {{cannot differentiate void function 'staticMethod(_:rhs:)'}} @derivative(of: staticMethod, wrt: rhs) static func jvpNotWrtInout(_ lhs: inout Self, _ rhs: Self) -> ( - // expected-note @+1 {{'differential' does not have expected type '(InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(DummyTangentVector) -> DummyTangentVector')}} value: Void, differential: (inout TangentVector) -> TangentVector ) { fatalError() } } extension InoutParameters { - // expected-note @+1 4 {{'mutatingMethod' defined here}} + // expected-note @+2 {{'mutatingMethod' defined here}} + // expected-note @+1 {{'mutatingMethod' defined here}} mutating func mutatingMethod(_ other: Self) {} // Test wrt `inout` `self` parameter. @@ -857,27 +859,27 @@ extension InoutParameters { // Test non-wrt `inout` `self` parameter. + // expected-error @+1 {{cannot differentiate void function 'mutatingMethod'}} @derivative(of: mutatingMethod, wrt: other) mutating func vjpNotWrtInout(_ other: Self) -> ( value: Void, pullback: (TangentVector) -> TangentVector ) { fatalError() } - // expected-error @+1 {{function result's 'pullback' type does not match 'mutatingMethod'}} + // expected-error @+1 {{cannot differentiate void function 'mutatingMethod'}} @derivative(of: mutatingMethod, wrt: other) mutating func vjpNotWrtInoutMismatch(_ other: Self) -> ( - // expected-note @+1 {{'pullback' does not have expected type '(InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(DummyTangentVector) -> DummyTangentVector')}} value: Void, pullback: (inout TangentVector) -> TangentVector ) { fatalError() } + // expected-error @+1 {{cannot differentiate void function 'mutatingMethod'}} @derivative(of: mutatingMethod, wrt: other) mutating func jvpNotWrtInout(_ other: Self) -> ( value: Void, differential: (TangentVector) -> TangentVector ) { fatalError() } - // expected-error @+1 {{function result's 'differential' type does not match 'mutatingMethod'}} + // expected-error @+1 {{cannot differentiate void function 'mutatingMethod'}} @derivative(of: mutatingMethod, wrt: other) mutating func jvpNotWrtInoutMismatch(_ other: Self) -> ( - // expected-note @+1 {{'differential' does not have expected type '(InoutParameters.TangentVector) -> InoutParameters.TangentVector' (aka '(DummyTangentVector) -> DummyTangentVector')}} value: Void, differential: (TangentVector, TangentVector) -> Void ) { fatalError() } } diff --git a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift index 53299fe2f639b..eefa81f12a70b 100644 --- a/test/AutoDiff/Sema/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/Sema/differentiable_attr_type_checking.swift @@ -681,7 +681,7 @@ struct InoutParameters: Differentiable { } extension NonDiffableStruct { - // expected-error @+1 {{can only differentiate functions with results that conform to 'Differentiable', but 'NonDiffableStruct' does not conform to 'Differentiable'}} + // expected-error @+1 {{cannot differentiate void function 'nondiffResult(x:y:z:)'}} @differentiable(reverse) static func nondiffResult(x: Int, y: inout NonDiffableStruct, z: Float) {} diff --git a/test/AutoDiff/compiler_crashers_fixed/issue-55745-noderivative-inout-parameter.swift b/test/AutoDiff/compiler_crashers_fixed/issue-55745-noderivative-inout-parameter.swift index 62a7d2ea018cf..5ab20d9360a82 100644 --- a/test/AutoDiff/compiler_crashers_fixed/issue-55745-noderivative-inout-parameter.swift +++ b/test/AutoDiff/compiler_crashers_fixed/issue-55745-noderivative-inout-parameter.swift @@ -4,15 +4,16 @@ import _Differentiation // https://github.com/apple/swift/issues/55745 // Test protocol witness thunk for `@differentiable` protocol requirement, where -// the required method has a non-wrt `inout` parameter that should be treated as -// a differentiability result. +// the required method has a non-wrt `inout` parameter. protocol Proto { + // expected-error @+1 {{cannot differentiate void function 'method(x:y:)'}} @differentiable(reverse, wrt: x) func method(x: Float, y: inout Float) } struct Struct: Proto { + // expected-error @+1 {{cannot differentiate void function 'method(x:y:)'}} @differentiable(reverse, wrt: x) func method(x: Float, y: inout Float) { y = y * x diff --git a/test/AutoDiff/validation-test/forward_mode_simple.swift b/test/AutoDiff/validation-test/forward_mode_simple.swift index 0b4cb384e368c..90e52c4e71650 100644 --- a/test/AutoDiff/validation-test/forward_mode_simple.swift +++ b/test/AutoDiff/validation-test/forward_mode_simple.swift @@ -1320,29 +1320,14 @@ ForwardModeTests.test("ForceUnwrapping") { } ForwardModeTests.test("NonVariedResult") { - @differentiable(reverse, wrt: x) - func nonWrtInoutParam(_ x: T, _ y: inout T) { - y = x - } - @differentiable(reverse) func wrtInoutParam(_ x: T, _ y: inout T) { y = x } - @differentiable(reverse, wrt: x) - func nonWrtInoutParamNonVaried(_ x: T, _ y: inout T) {} - - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (x, y)) func wrtInoutParamNonVaried(_ x: T, _ y: inout T) {} - @differentiable(reverse) - func variedResultTracked(_ x: Tracked) -> Tracked { - var result: Tracked = 0 - nonWrtInoutParam(x, &result) - return result - } - @differentiable(reverse) func variedResultTracked2(_ x: Tracked) -> Tracked { var result: Tracked = 0 @@ -1352,13 +1337,6 @@ ForwardModeTests.test("NonVariedResult") { @differentiable(reverse) func nonVariedResultTracked(_ x: Tracked) -> Tracked { - var result: Tracked = 0 - nonWrtInoutParamNonVaried(x, &result) - return result - } - - @differentiable(reverse) - func nonVariedResultTracked2(_ x: Tracked) -> Tracked { // expected-warning @+1 {{variable 'result' was never mutated}} var result: Tracked = 0 return result diff --git a/test/AutoDiff/validation-test/inout_parameters.swift b/test/AutoDiff/validation-test/inout_parameters.swift index 800b373ffcec0..95ede7ee938c6 100644 --- a/test/AutoDiff/validation-test/inout_parameters.swift +++ b/test/AutoDiff/validation-test/inout_parameters.swift @@ -191,26 +191,27 @@ InoutParameterAutoDiffTests.test("InoutClassParameter") { } } -// https://github.com/apple/swift/issues/55745 -// Test function with non-wrt `inout` parameter, which should be -// treated as a differentiability result. +// Test function with wrt `inout` parameter, which should be treated as a differentiability result. +// Original issue https://github.com/apple/swift/issues/55745 deals with non-wrt `inout` which +// we explicitly disallow now + protocol P_55745 { - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (x, y)) func method(_ x: Float, _ y: inout Float) - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (x, y)) func genericMethod(_ x: T, _ y: inout T) } InoutParameterAutoDiffTests.test("non-wrt inout parameter") { struct Struct: P_55745 { - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (x, y)) func method(_ x: Float, _ y: inout Float) { y = y * x } - @differentiable(reverse, wrt: x) + @differentiable(reverse, wrt: (x, y)) func genericMethod(_ x: T, _ y: inout T) { y = x } From 9e9cbc1e95a8ff16c385156f5a004818b313b508 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Tue, 25 Jul 2023 22:50:18 -0400 Subject: [PATCH 28/37] AST: Remove unprintable character from ModuleNameLookup.cpp --- lib/AST/ModuleNameLookup.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/AST/ModuleNameLookup.cpp b/lib/AST/ModuleNameLookup.cpp index 894e91f063ce4..a836dd000ecdf 100644 --- a/lib/AST/ModuleNameLookup.cpp +++ b/lib/AST/ModuleNameLookup.cpp @@ -78,7 +78,7 @@ class LookupByName : public ModuleNameLookup { lookupKind(lookupKind) {} private: - /// Returns whether it's okay to stop recursively searching imports, given  + /// Returns whether it's okay to stop recursively searching imports, given /// that we found something non-overloadable. static bool canReturnEarly() { return true; From 92ad17231ee4995f4b4c5cae1458d1d34ca67411 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Thu, 3 Aug 2023 11:06:55 -0400 Subject: [PATCH 29/37] IRGen: Factor out emitFunctionTypeMetadataRef() from EmitTypeMetadataRef --- lib/IRGen/MetadataRequest.cpp | 498 +++++++++++++++++----------------- 1 file changed, 254 insertions(+), 244 deletions(-) diff --git a/lib/IRGen/MetadataRequest.cpp b/lib/IRGen/MetadataRequest.cpp index e226401fc7251..382fff5479c25 100644 --- a/lib/IRGen/MetadataRequest.cpp +++ b/lib/IRGen/MetadataRequest.cpp @@ -1364,6 +1364,258 @@ static void destroyGenericArgumentsArray(IRGenFunction &IGF, IGF.IGM.getPointerSize() * args.size()); } +static llvm::Value *getFunctionParameterRef(IRGenFunction &IGF, + AnyFunctionType::CanParam param) { + auto type = param.getPlainType()->getCanonicalType(); + return IGF.emitAbstractTypeMetadataRef(type); +} + +static MetadataResponse emitFunctionTypeMetadataRef(IRGenFunction &IGF, + CanFunctionType type, + DynamicMetadataRequest request) { + auto result = + IGF.emitAbstractTypeMetadataRef(type->getResult()->getCanonicalType()); + + auto params = type.getParams(); + auto numParams = params.size(); + + // Retrieve the ABI parameter flags from the type-level parameter + // flags. + auto getABIParameterFlags = [](ParameterTypeFlags flags) { + return ParameterFlags() + .withValueOwnership(flags.getValueOwnership()) + .withVariadic(flags.isVariadic()) + .withAutoClosure(flags.isAutoClosure()) + .withNoDerivative(flags.isNoDerivative()) + .withIsolated(flags.isIsolated()); + }; + + bool hasParameterFlags = false; + for (auto param : params) { + if (!getABIParameterFlags(param.getParameterFlags()).isNone()) { + hasParameterFlags = true; + break; + } + } + + // Map the convention to a runtime metadata value. + FunctionMetadataConvention metadataConvention; + bool isEscaping = false; + switch (type->getRepresentation()) { + case FunctionTypeRepresentation::Swift: + metadataConvention = FunctionMetadataConvention::Swift; + isEscaping = !type->isNoEscape(); + break; + case FunctionTypeRepresentation::Thin: + metadataConvention = FunctionMetadataConvention::Thin; + break; + case FunctionTypeRepresentation::Block: + metadataConvention = FunctionMetadataConvention::Block; + break; + case FunctionTypeRepresentation::CFunctionPointer: + metadataConvention = FunctionMetadataConvention::CFunctionPointer; + break; + } + + FunctionMetadataDifferentiabilityKind metadataDifferentiabilityKind; + switch (type->getDifferentiabilityKind()) { + case DifferentiabilityKind::NonDifferentiable: + metadataDifferentiabilityKind = + FunctionMetadataDifferentiabilityKind::NonDifferentiable; + break; + case DifferentiabilityKind::Normal: + metadataDifferentiabilityKind = + FunctionMetadataDifferentiabilityKind::Normal; + break; + case DifferentiabilityKind::Linear: + metadataDifferentiabilityKind = + FunctionMetadataDifferentiabilityKind::Linear; + break; + case DifferentiabilityKind::Forward: + metadataDifferentiabilityKind = + FunctionMetadataDifferentiabilityKind::Forward; + break; + case DifferentiabilityKind::Reverse: + metadataDifferentiabilityKind = + FunctionMetadataDifferentiabilityKind::Reverse; + break; + } + + auto flags = FunctionTypeFlags() + .withNumParameters(numParams) + .withConvention(metadataConvention) + .withAsync(type->isAsync()) + .withConcurrent(type->isSendable()) + .withThrows(type->isThrowing()) + .withParameterFlags(hasParameterFlags) + .withEscaping(isEscaping) + .withDifferentiable(type->isDifferentiable()) + .withGlobalActor(!type->getGlobalActor().isNull()); + + auto flagsVal = llvm::ConstantInt::get(IGF.IGM.SizeTy, + flags.getIntValue()); + llvm::Value *diffKindVal = nullptr; + if (type->isDifferentiable()) { + assert(metadataDifferentiabilityKind.isDifferentiable()); + diffKindVal = llvm::ConstantInt::get( + IGF.IGM.SizeTy, metadataDifferentiabilityKind.getIntValue()); + } else if (type->getGlobalActor()) { + diffKindVal = llvm::ConstantInt::get( + IGF.IGM.SizeTy, + FunctionMetadataDifferentiabilityKind::NonDifferentiable); + } + + auto collectParameters = + [&](llvm::function_ref + processor) { + for (auto index : indices(params)) { + auto param = params[index]; + auto flags = param.getParameterFlags(); + + auto parameterFlags = getABIParameterFlags(flags); + processor(index, getFunctionParameterRef(IGF, param), + parameterFlags); + } + }; + + auto constructSimpleCall = + [&](llvm::SmallVectorImpl &arguments) + -> FunctionPointer { + arguments.push_back(flagsVal); + + collectParameters([&](unsigned i, llvm::Value *typeRef, + ParameterFlags flags) { + arguments.push_back(typeRef); + if (hasParameterFlags) + arguments.push_back( + llvm::ConstantInt::get(IGF.IGM.Int32Ty, flags.getIntValue())); + }); + + arguments.push_back(result); + + switch (params.size()) { + case 0: + return IGF.IGM.getGetFunctionMetadata0FunctionPointer(); + + case 1: + return IGF.IGM.getGetFunctionMetadata1FunctionPointer(); + + case 2: + return IGF.IGM.getGetFunctionMetadata2FunctionPointer(); + + case 3: + return IGF.IGM.getGetFunctionMetadata3FunctionPointer(); + + default: + llvm_unreachable("supports only 1/2/3 parameter functions"); + } + }; + + switch (numParams) { + case 0: + case 1: + case 2: + case 3: { + if (!hasParameterFlags && !type->isDifferentiable() && + !type->getGlobalActor()) { + llvm::SmallVector arguments; + auto metadataFn = constructSimpleCall(arguments); + auto *call = IGF.Builder.CreateCall(metadataFn, arguments); + call->setDoesNotThrow(); + return MetadataResponse::forComplete(call); + } + + // If function type has parameter flags or is differentiable or has a + // global actor, emit the most general function to retrieve them. + LLVM_FALLTHROUGH; + } + + default: + assert((!params.empty() || type->isDifferentiable() || + type->getGlobalActor()) && + "0 parameter case should be specialized unless it is a " + "differentiable function or has a global actor"); + + auto *const Int32Ptr = IGF.IGM.Int32Ty->getPointerTo(); + llvm::SmallVector arguments; + + arguments.push_back(flagsVal); + + if (diffKindVal) { + arguments.push_back(diffKindVal); + } + + ConstantInitBuilder paramFlags(IGF.IGM); + auto flagsArr = paramFlags.beginArray(); + + Address parameters; + if (!params.empty()) { + auto arrayTy = + llvm::ArrayType::get(IGF.IGM.TypeMetadataPtrTy, numParams); + parameters = IGF.createAlloca( + arrayTy, IGF.IGM.getTypeMetadataAlignment(), "function-parameters"); + + IGF.Builder.CreateLifetimeStart(parameters, + IGF.IGM.getPointerSize() * numParams); + + collectParameters([&](unsigned i, llvm::Value *typeRef, + ParameterFlags flags) { + auto argPtr = IGF.Builder.CreateStructGEP(parameters, i, + IGF.IGM.getPointerSize()); + IGF.Builder.CreateStore(typeRef, argPtr); + if (i == 0) + arguments.push_back(argPtr.getAddress()); + + if (hasParameterFlags) + flagsArr.addInt32(flags.getIntValue()); + }); + } else { + auto parametersPtr = + llvm::ConstantPointerNull::get( + IGF.IGM.TypeMetadataPtrTy->getPointerTo()); + arguments.push_back(parametersPtr); + } + + if (hasParameterFlags) { + auto *flagsVar = flagsArr.finishAndCreateGlobal( + "parameter-flags", IGF.IGM.getPointerAlignment(), + /* constant */ true); + arguments.push_back(IGF.Builder.CreateBitCast(flagsVar, Int32Ptr)); + } else { + flagsArr.abandon(); + arguments.push_back(llvm::ConstantPointerNull::get(Int32Ptr)); + } + + arguments.push_back(result); + + if (Type globalActor = type->getGlobalActor()) { + arguments.push_back( + IGF.emitAbstractTypeMetadataRef(globalActor->getCanonicalType())); + } + + auto getMetadataFn = + type->getGlobalActor() + ? (IGF.IGM.isConcurrencyAvailable() + ? IGF.IGM + .getGetFunctionMetadataGlobalActorFunctionPointer() + : IGF.IGM + .getGetFunctionMetadataGlobalActorBackDeployFunctionPointer()) + : type->isDifferentiable() + ? IGF.IGM.getGetFunctionMetadataDifferentiableFunctionPointer() + : IGF.IGM.getGetFunctionMetadataFunctionPointer(); + + auto call = IGF.Builder.CreateCall(getMetadataFn, arguments); + call->setDoesNotThrow(); + + if (parameters.isValid()) + IGF.Builder.CreateLifetimeEnd(parameters, + IGF.IGM.getPointerSize() * numParams); + + return MetadataResponse::forComplete(call); + } +} + namespace { /// A visitor class for emitting a reference to a metatype object. /// This implements a "raw" access, useful for implementing cache @@ -1516,256 +1768,14 @@ namespace { return MetadataResponse::getUndef(IGF); } - llvm::Value *getFunctionParameterRef(AnyFunctionType::CanParam ¶m) { - auto type = param.getPlainType()->getCanonicalType(); - return IGF.emitAbstractTypeMetadataRef(type); - } - MetadataResponse visitFunctionType(CanFunctionType type, DynamicMetadataRequest request) { if (auto metatype = tryGetLocal(type, request)) return metatype; - auto result = - IGF.emitAbstractTypeMetadataRef(type->getResult()->getCanonicalType()); - - auto params = type.getParams(); - auto numParams = params.size(); - - // Retrieve the ABI parameter flags from the type-level parameter - // flags. - auto getABIParameterFlags = [](ParameterTypeFlags flags) { - return ParameterFlags() - .withValueOwnership(flags.getValueOwnership()) - .withVariadic(flags.isVariadic()) - .withAutoClosure(flags.isAutoClosure()) - .withNoDerivative(flags.isNoDerivative()) - .withIsolated(flags.isIsolated()); - }; - - bool hasParameterFlags = false; - for (auto param : params) { - if (!getABIParameterFlags(param.getParameterFlags()).isNone()) { - hasParameterFlags = true; - break; - } - } - - // Map the convention to a runtime metadata value. - FunctionMetadataConvention metadataConvention; - bool isEscaping = false; - switch (type->getRepresentation()) { - case FunctionTypeRepresentation::Swift: - metadataConvention = FunctionMetadataConvention::Swift; - isEscaping = !type->isNoEscape(); - break; - case FunctionTypeRepresentation::Thin: - metadataConvention = FunctionMetadataConvention::Thin; - break; - case FunctionTypeRepresentation::Block: - metadataConvention = FunctionMetadataConvention::Block; - break; - case FunctionTypeRepresentation::CFunctionPointer: - metadataConvention = FunctionMetadataConvention::CFunctionPointer; - break; - } - - FunctionMetadataDifferentiabilityKind metadataDifferentiabilityKind; - switch (type->getDifferentiabilityKind()) { - case DifferentiabilityKind::NonDifferentiable: - metadataDifferentiabilityKind = - FunctionMetadataDifferentiabilityKind::NonDifferentiable; - break; - case DifferentiabilityKind::Normal: - metadataDifferentiabilityKind = - FunctionMetadataDifferentiabilityKind::Normal; - break; - case DifferentiabilityKind::Linear: - metadataDifferentiabilityKind = - FunctionMetadataDifferentiabilityKind::Linear; - break; - case DifferentiabilityKind::Forward: - metadataDifferentiabilityKind = - FunctionMetadataDifferentiabilityKind::Forward; - break; - case DifferentiabilityKind::Reverse: - metadataDifferentiabilityKind = - FunctionMetadataDifferentiabilityKind::Reverse; - break; - } - - auto flags = FunctionTypeFlags() - .withNumParameters(numParams) - .withConvention(metadataConvention) - .withAsync(type->isAsync()) - .withConcurrent(type->isSendable()) - .withThrows(type->isThrowing()) - .withParameterFlags(hasParameterFlags) - .withEscaping(isEscaping) - .withDifferentiable(type->isDifferentiable()) - .withGlobalActor(!type->getGlobalActor().isNull()); - - auto flagsVal = llvm::ConstantInt::get(IGF.IGM.SizeTy, - flags.getIntValue()); - llvm::Value *diffKindVal = nullptr; - if (type->isDifferentiable()) { - assert(metadataDifferentiabilityKind.isDifferentiable()); - diffKindVal = llvm::ConstantInt::get( - IGF.IGM.SizeTy, metadataDifferentiabilityKind.getIntValue()); - } else if (type->getGlobalActor()) { - diffKindVal = llvm::ConstantInt::get( - IGF.IGM.SizeTy, - FunctionMetadataDifferentiabilityKind::NonDifferentiable); - } - - auto collectParameters = - [&](llvm::function_ref - processor) { - for (auto index : indices(params)) { - auto param = params[index]; - auto flags = param.getParameterFlags(); - - auto parameterFlags = getABIParameterFlags(flags); - processor(index, getFunctionParameterRef(param), parameterFlags); - } - }; - - auto constructSimpleCall = - [&](llvm::SmallVectorImpl &arguments) - -> FunctionPointer { - arguments.push_back(flagsVal); - - collectParameters([&](unsigned i, llvm::Value *typeRef, - ParameterFlags flags) { - arguments.push_back(typeRef); - if (hasParameterFlags) - arguments.push_back( - llvm::ConstantInt::get(IGF.IGM.Int32Ty, flags.getIntValue())); - }); - - arguments.push_back(result); - - switch (params.size()) { - case 0: - return IGF.IGM.getGetFunctionMetadata0FunctionPointer(); + auto response = emitFunctionTypeMetadataRef(IGF, type, request); - case 1: - return IGF.IGM.getGetFunctionMetadata1FunctionPointer(); - - case 2: - return IGF.IGM.getGetFunctionMetadata2FunctionPointer(); - - case 3: - return IGF.IGM.getGetFunctionMetadata3FunctionPointer(); - - default: - llvm_unreachable("supports only 1/2/3 parameter functions"); - } - }; - - switch (numParams) { - case 0: - case 1: - case 2: - case 3: { - if (!hasParameterFlags && !type->isDifferentiable() && - !type->getGlobalActor()) { - llvm::SmallVector arguments; - auto metadataFn = constructSimpleCall(arguments); - auto *call = IGF.Builder.CreateCall(metadataFn, arguments); - call->setDoesNotThrow(); - return setLocal(CanType(type), MetadataResponse::forComplete(call)); - } - - // If function type has parameter flags or is differentiable or has a - // global actor, emit the most general function to retrieve them. - LLVM_FALLTHROUGH; - } - - default: - assert((!params.empty() || type->isDifferentiable() || - type->getGlobalActor()) && - "0 parameter case should be specialized unless it is a " - "differentiable function or has a global actor"); - - auto *const Int32Ptr = IGF.IGM.Int32Ty->getPointerTo(); - llvm::SmallVector arguments; - - arguments.push_back(flagsVal); - - if (diffKindVal) { - arguments.push_back(diffKindVal); - } - - ConstantInitBuilder paramFlags(IGF.IGM); - auto flagsArr = paramFlags.beginArray(); - - Address parameters; - if (!params.empty()) { - auto arrayTy = - llvm::ArrayType::get(IGF.IGM.TypeMetadataPtrTy, numParams); - parameters = IGF.createAlloca( - arrayTy, IGF.IGM.getTypeMetadataAlignment(), "function-parameters"); - - IGF.Builder.CreateLifetimeStart(parameters, - IGF.IGM.getPointerSize() * numParams); - - collectParameters([&](unsigned i, llvm::Value *typeRef, - ParameterFlags flags) { - auto argPtr = IGF.Builder.CreateStructGEP(parameters, i, - IGF.IGM.getPointerSize()); - IGF.Builder.CreateStore(typeRef, argPtr); - if (i == 0) - arguments.push_back(argPtr.getAddress()); - - if (hasParameterFlags) - flagsArr.addInt32(flags.getIntValue()); - }); - } else { - auto parametersPtr = - llvm::ConstantPointerNull::get( - IGF.IGM.TypeMetadataPtrTy->getPointerTo()); - arguments.push_back(parametersPtr); - } - - if (hasParameterFlags) { - auto *flagsVar = flagsArr.finishAndCreateGlobal( - "parameter-flags", IGF.IGM.getPointerAlignment(), - /* constant */ true); - arguments.push_back(IGF.Builder.CreateBitCast(flagsVar, Int32Ptr)); - } else { - flagsArr.abandon(); - arguments.push_back(llvm::ConstantPointerNull::get(Int32Ptr)); - } - - arguments.push_back(result); - - if (Type globalActor = type->getGlobalActor()) { - arguments.push_back( - IGF.emitAbstractTypeMetadataRef(globalActor->getCanonicalType())); - } - - auto getMetadataFn = - type->getGlobalActor() - ? (IGF.IGM.isConcurrencyAvailable() - ? IGF.IGM - .getGetFunctionMetadataGlobalActorFunctionPointer() - : IGF.IGM - .getGetFunctionMetadataGlobalActorBackDeployFunctionPointer()) - : type->isDifferentiable() - ? IGF.IGM.getGetFunctionMetadataDifferentiableFunctionPointer() - : IGF.IGM.getGetFunctionMetadataFunctionPointer(); - - auto call = IGF.Builder.CreateCall(getMetadataFn, arguments); - call->setDoesNotThrow(); - - if (parameters.isValid()) - IGF.Builder.CreateLifetimeEnd(parameters, - IGF.IGM.getPointerSize() * numParams); - - return setLocal(type, MetadataResponse::forComplete(call)); - } + return setLocal(type, response); } MetadataResponse visitMetatypeType(CanMetatypeType type, From 0f3c43089aac58542838d143652278c61a09edc0 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Thu, 3 Aug 2023 11:42:42 -0400 Subject: [PATCH 30/37] IRGen: Refactor function type metadata emission --- lib/IRGen/MetadataRequest.cpp | 314 ++++++++++++++++++++-------------- 1 file changed, 189 insertions(+), 125 deletions(-) diff --git a/lib/IRGen/MetadataRequest.cpp b/lib/IRGen/MetadataRequest.cpp index 382fff5479c25..d1cf383451c39 100644 --- a/lib/IRGen/MetadataRequest.cpp +++ b/lib/IRGen/MetadataRequest.cpp @@ -1370,28 +1370,19 @@ static llvm::Value *getFunctionParameterRef(IRGenFunction &IGF, return IGF.emitAbstractTypeMetadataRef(type); } -static MetadataResponse emitFunctionTypeMetadataRef(IRGenFunction &IGF, - CanFunctionType type, - DynamicMetadataRequest request) { - auto result = - IGF.emitAbstractTypeMetadataRef(type->getResult()->getCanonicalType()); - - auto params = type.getParams(); - auto numParams = params.size(); - - // Retrieve the ABI parameter flags from the type-level parameter - // flags. - auto getABIParameterFlags = [](ParameterTypeFlags flags) { - return ParameterFlags() +/// Mapping type-level parameter flags to ABI parameter flags. +static ParameterFlags getABIParameterFlags(ParameterTypeFlags flags) { + return ParameterFlags() .withValueOwnership(flags.getValueOwnership()) .withVariadic(flags.isVariadic()) .withAutoClosure(flags.isAutoClosure()) .withNoDerivative(flags.isNoDerivative()) .withIsolated(flags.isIsolated()); - }; +} +static FunctionTypeFlags getFunctionTypeFlags(CanFunctionType type) { bool hasParameterFlags = false; - for (auto param : params) { + for (auto param : type.getParams()) { if (!getABIParameterFlags(param.getParameterFlags()).isNone()) { hasParameterFlags = true; break; @@ -1417,80 +1408,149 @@ static MetadataResponse emitFunctionTypeMetadataRef(IRGenFunction &IGF, break; } - FunctionMetadataDifferentiabilityKind metadataDifferentiabilityKind; - switch (type->getDifferentiabilityKind()) { - case DifferentiabilityKind::NonDifferentiable: - metadataDifferentiabilityKind = - FunctionMetadataDifferentiabilityKind::NonDifferentiable; - break; - case DifferentiabilityKind::Normal: - metadataDifferentiabilityKind = - FunctionMetadataDifferentiabilityKind::Normal; - break; - case DifferentiabilityKind::Linear: - metadataDifferentiabilityKind = - FunctionMetadataDifferentiabilityKind::Linear; - break; - case DifferentiabilityKind::Forward: - metadataDifferentiabilityKind = - FunctionMetadataDifferentiabilityKind::Forward; - break; - case DifferentiabilityKind::Reverse: - metadataDifferentiabilityKind = - FunctionMetadataDifferentiabilityKind::Reverse; - break; + return FunctionTypeFlags() + .withConvention(metadataConvention) + .withAsync(type->isAsync()) + .withConcurrent(type->isSendable()) + .withThrows(type->isThrowing()) + .withParameterFlags(hasParameterFlags) + .withEscaping(isEscaping) + .withDifferentiable(type->isDifferentiable()) + .withGlobalActor(!type->getGlobalActor().isNull()); +} + +namespace { +struct FunctionTypeMetadataParamInfo { + StackAddress parameters; + StackAddress paramFlags; + unsigned numParams; +}; +} + +static FunctionTypeMetadataParamInfo +emitFunctionTypeMetadataParams(IRGenFunction &IGF, + AnyFunctionType::CanParamArrayRef params, + FunctionTypeFlags flags, + DynamicMetadataRequest request, + SmallVectorImpl &arguments) { + FunctionTypeMetadataParamInfo info; + info.numParams = params.size(); + + ConstantInitBuilder paramFlags(IGF.IGM); + auto flagsArr = paramFlags.beginArray(); + + if (!params.empty()) { + auto arrayTy = + llvm::ArrayType::get(IGF.IGM.TypeMetadataPtrTy, info.numParams); + info.parameters = StackAddress(IGF.createAlloca( + arrayTy, IGF.IGM.getTypeMetadataAlignment(), "function-parameters")); + + IGF.Builder.CreateLifetimeStart(info.parameters.getAddress(), + IGF.IGM.getPointerSize() * info.numParams); + + for (unsigned i : indices(params)) { + auto param = params[i]; + auto paramFlags = getABIParameterFlags(param.getParameterFlags()); + + auto argPtr = IGF.Builder.CreateStructGEP(info.parameters.getAddress(), i, + IGF.IGM.getPointerSize()); + auto *typeRef = getFunctionParameterRef(IGF, param); + IGF.Builder.CreateStore(typeRef, argPtr); + if (i == 0) + arguments.push_back(argPtr.getAddress()); + + flagsArr.addInt32(paramFlags.getIntValue()); + } + } else { + auto parametersPtr = + llvm::ConstantPointerNull::get( + IGF.IGM.TypeMetadataPtrTy->getPointerTo()); + arguments.push_back(parametersPtr); + } + + auto *Int32Ptr = IGF.IGM.Int32Ty->getPointerTo(); + if (flags.hasParameterFlags()) { + auto *flagsVar = flagsArr.finishAndCreateGlobal( + "parameter-flags", IGF.IGM.getPointerAlignment(), + /* constant */ true); + arguments.push_back(IGF.Builder.CreateBitCast(flagsVar, Int32Ptr)); + } else { + flagsArr.abandon(); + arguments.push_back(llvm::ConstantPointerNull::get(Int32Ptr)); } - auto flags = FunctionTypeFlags() - .withNumParameters(numParams) - .withConvention(metadataConvention) - .withAsync(type->isAsync()) - .withConcurrent(type->isSendable()) - .withThrows(type->isThrowing()) - .withParameterFlags(hasParameterFlags) - .withEscaping(isEscaping) - .withDifferentiable(type->isDifferentiable()) - .withGlobalActor(!type->getGlobalActor().isNull()); - - auto flagsVal = llvm::ConstantInt::get(IGF.IGM.SizeTy, - flags.getIntValue()); - llvm::Value *diffKindVal = nullptr; - if (type->isDifferentiable()) { - assert(metadataDifferentiabilityKind.isDifferentiable()); - diffKindVal = llvm::ConstantInt::get( - IGF.IGM.SizeTy, metadataDifferentiabilityKind.getIntValue()); - } else if (type->getGlobalActor()) { - diffKindVal = llvm::ConstantInt::get( - IGF.IGM.SizeTy, - FunctionMetadataDifferentiabilityKind::NonDifferentiable); - } - - auto collectParameters = - [&](llvm::function_ref - processor) { - for (auto index : indices(params)) { - auto param = params[index]; - auto flags = param.getParameterFlags(); - - auto parameterFlags = getABIParameterFlags(flags); - processor(index, getFunctionParameterRef(IGF, param), - parameterFlags); - } - }; + return info; +} + +static FunctionTypeMetadataParamInfo +emitDynamicFunctionTypeMetadataParams(IRGenFunction &IGF, + AnyFunctionType::CanParamArrayRef params, + FunctionTypeFlags flags, + CanPackType packType, + DynamicMetadataRequest request, + SmallVectorImpl &arguments) { + assert(false); +} + +static void cleanupFunctionTypeMetadataParams(IRGenFunction &IGF, + FunctionTypeMetadataParamInfo info) { + if (info.parameters.isValid()) { + if (info.parameters.getExtraInfo()) { + IGF.emitDeallocateDynamicAlloca(info.parameters); + } else { + IGF.Builder.CreateLifetimeEnd(info.parameters.getAddress(), + IGF.IGM.getPointerSize() * info.numParams); + } + } +} + +static CanPackType getInducedPackType(AnyFunctionType::CanParamArrayRef params, + ASTContext &ctx) { + SmallVector elts; + for (auto param : params) + elts.push_back(param.getPlainType()); + + return CanPackType::get(ctx, elts); +} + +static MetadataResponse emitFunctionTypeMetadataRef(IRGenFunction &IGF, + CanFunctionType type, + DynamicMetadataRequest request) { + auto result = + IGF.emitAbstractTypeMetadataRef(type->getResult()->getCanonicalType()); + + auto params = type.getParams(); + bool hasPackExpansion = type->containsPackExpansionParam(); + + auto flags = getFunctionTypeFlags(type); + llvm::Value *flagsVal = nullptr; + llvm::Value *shapeExpression = nullptr; + CanPackType packType; + + if (!hasPackExpansion) { + flags = flags.withNumParameters(params.size()); + flagsVal = llvm::ConstantInt::get(IGF.IGM.SizeTy, + flags.getIntValue()); + } else { + packType = getInducedPackType(type.getParams(), type->getASTContext()); + auto *shapeExpression = IGF.emitPackShapeExpression(packType); + + flagsVal = llvm::ConstantInt::get(IGF.IGM.SizeTy, + flags.getIntValue()); + flagsVal = IGF.Builder.CreateOr(flagsVal, shapeExpression); + } auto constructSimpleCall = [&](llvm::SmallVectorImpl &arguments) -> FunctionPointer { + assert(!flags.hasParameterFlags()); + assert(!shapeExpression); + arguments.push_back(flagsVal); - collectParameters([&](unsigned i, llvm::Value *typeRef, - ParameterFlags flags) { - arguments.push_back(typeRef); - if (hasParameterFlags) - arguments.push_back( - llvm::ConstantInt::get(IGF.IGM.Int32Ty, flags.getIntValue())); - }); + for (auto param : params) { + arguments.push_back(getFunctionParameterRef(IGF, param)); + } arguments.push_back(result); @@ -1512,13 +1572,13 @@ static MetadataResponse emitFunctionTypeMetadataRef(IRGenFunction &IGF, } }; - switch (numParams) { + switch (params.size()) { case 0: case 1: case 2: case 3: { - if (!hasParameterFlags && !type->isDifferentiable() && - !type->getGlobalActor()) { + if (!flags.hasParameterFlags() && !type->isDifferentiable() && + !type->getGlobalActor() && !hasPackExpansion) { llvm::SmallVector arguments; auto metadataFn = constructSimpleCall(arguments); auto *call = IGF.Builder.CreateCall(metadataFn, arguments); @@ -1537,54 +1597,60 @@ static MetadataResponse emitFunctionTypeMetadataRef(IRGenFunction &IGF, "0 parameter case should be specialized unless it is a " "differentiable function or has a global actor"); - auto *const Int32Ptr = IGF.IGM.Int32Ty->getPointerTo(); llvm::SmallVector arguments; arguments.push_back(flagsVal); - if (diffKindVal) { - arguments.push_back(diffKindVal); - } - - ConstantInitBuilder paramFlags(IGF.IGM); - auto flagsArr = paramFlags.beginArray(); + llvm::Value *diffKindVal = nullptr; - Address parameters; - if (!params.empty()) { - auto arrayTy = - llvm::ArrayType::get(IGF.IGM.TypeMetadataPtrTy, numParams); - parameters = IGF.createAlloca( - arrayTy, IGF.IGM.getTypeMetadataAlignment(), "function-parameters"); - - IGF.Builder.CreateLifetimeStart(parameters, - IGF.IGM.getPointerSize() * numParams); + { + FunctionMetadataDifferentiabilityKind metadataDifferentiabilityKind; + switch (type->getDifferentiabilityKind()) { + case DifferentiabilityKind::NonDifferentiable: + metadataDifferentiabilityKind = + FunctionMetadataDifferentiabilityKind::NonDifferentiable; + break; + case DifferentiabilityKind::Normal: + metadataDifferentiabilityKind = + FunctionMetadataDifferentiabilityKind::Normal; + break; + case DifferentiabilityKind::Linear: + metadataDifferentiabilityKind = + FunctionMetadataDifferentiabilityKind::Linear; + break; + case DifferentiabilityKind::Forward: + metadataDifferentiabilityKind = + FunctionMetadataDifferentiabilityKind::Forward; + break; + case DifferentiabilityKind::Reverse: + metadataDifferentiabilityKind = + FunctionMetadataDifferentiabilityKind::Reverse; + break; + } - collectParameters([&](unsigned i, llvm::Value *typeRef, - ParameterFlags flags) { - auto argPtr = IGF.Builder.CreateStructGEP(parameters, i, - IGF.IGM.getPointerSize()); - IGF.Builder.CreateStore(typeRef, argPtr); - if (i == 0) - arguments.push_back(argPtr.getAddress()); + if (type->isDifferentiable()) { + assert(metadataDifferentiabilityKind.isDifferentiable()); + diffKindVal = llvm::ConstantInt::get( + IGF.IGM.SizeTy, metadataDifferentiabilityKind.getIntValue()); + } else if (type->getGlobalActor()) { + diffKindVal = llvm::ConstantInt::get( + IGF.IGM.SizeTy, + FunctionMetadataDifferentiabilityKind::NonDifferentiable); + } + } - if (hasParameterFlags) - flagsArr.addInt32(flags.getIntValue()); - }); - } else { - auto parametersPtr = - llvm::ConstantPointerNull::get( - IGF.IGM.TypeMetadataPtrTy->getPointerTo()); - arguments.push_back(parametersPtr); + if (diffKindVal) { + arguments.push_back(diffKindVal); } - if (hasParameterFlags) { - auto *flagsVar = flagsArr.finishAndCreateGlobal( - "parameter-flags", IGF.IGM.getPointerAlignment(), - /* constant */ true); - arguments.push_back(IGF.Builder.CreateBitCast(flagsVar, Int32Ptr)); + FunctionTypeMetadataParamInfo info; + if (!hasPackExpansion) { + assert(!shapeExpression); + info = emitFunctionTypeMetadataParams(IGF, params, flags, request, + arguments); } else { - flagsArr.abandon(); - arguments.push_back(llvm::ConstantPointerNull::get(Int32Ptr)); + info = emitDynamicFunctionTypeMetadataParams(IGF, params, flags, packType, + request, arguments); } arguments.push_back(result); @@ -1608,9 +1674,7 @@ static MetadataResponse emitFunctionTypeMetadataRef(IRGenFunction &IGF, auto call = IGF.Builder.CreateCall(getMetadataFn, arguments); call->setDoesNotThrow(); - if (parameters.isValid()) - IGF.Builder.CreateLifetimeEnd(parameters, - IGF.IGM.getPointerSize() * numParams); + cleanupFunctionTypeMetadataParams(IGF, info); return MetadataResponse::forComplete(call); } From 6ede5e050f84d342b23f6245c79422b6a437d32f Mon Sep 17 00:00:00 2001 From: Hamish Knight Date: Thu, 3 Aug 2023 21:16:10 +0100 Subject: [PATCH 31/37] [CodeComplete] Avoid dropping "is for code completion" bit This should no longer be needed now that we check for a code completion token when increasing the score. It should also allow us to skip more conjunction elements, as that requires the bit being set. --- lib/Sema/CSStep.cpp | 18 ------------------ lib/Sema/CSStep.h | 5 ----- 2 files changed, 23 deletions(-) diff --git a/lib/Sema/CSStep.cpp b/lib/Sema/CSStep.cpp index 78d777227713f..45bd75f268c77 100644 --- a/lib/Sema/CSStep.cpp +++ b/lib/Sema/CSStep.cpp @@ -880,21 +880,6 @@ bool ConjunctionStep::attempt(const ConjunctionElement &element) { CS.Timer.emplace(element.getLocator(), CS); } - assert(!ModifiedOptions.has_value() && - "Previously modified options should have been restored in resume"); - if (CS.isForCodeCompletion() && - !element.mightContainCodeCompletionToken(CS) && - !getLocator()->isForSingleValueStmtConjunctionOrBrace()) { - ModifiedOptions.emplace(CS.Options); - // If we know that this conjunction element doesn't contain the code - // completion token, type check it in normal mode without any special - // behavior that is intended for the code completion token. - // Avoid doing this for SingleValueStmtExprs, because we can more eagerly - // prune branches in that case, which requires us to detect the code - // completion option while solving the conjunction. - CS.Options -= ConstraintSystemFlags::ForCodeCompletion; - } - auto success = element.attempt(CS); // If element attempt has failed, mark whole conjunction @@ -906,9 +891,6 @@ bool ConjunctionStep::attempt(const ConjunctionElement &element) { } StepResult ConjunctionStep::resume(bool prevFailed) { - // Restore the old ConstraintSystemOptions if 'attempt' modified them. - ModifiedOptions.reset(); - // Return from the follow-up splitter step that // attempted to apply information gained from the // isolated constraint to the outer context. diff --git a/lib/Sema/CSStep.h b/lib/Sema/CSStep.h index 469922441e762..e35805f37b632 100644 --- a/lib/Sema/CSStep.h +++ b/lib/Sema/CSStep.h @@ -942,11 +942,6 @@ class ConjunctionStep : public BindingStep { /// in isolated mode. SmallVector IsolatedSolutions; - /// If \c ConjunctionStep::attempt modified the constraint system options, - /// it will store the original options in this \c llvm::SaveAndRestore. - /// Upon \c resume, these values will be restored. - llvm::Optional> ModifiedOptions; - public: ConjunctionStep(ConstraintSystem &cs, Constraint *conjunction, SmallVectorImpl &solutions) From 59fd29551e04ab0afcbff887b885b4b6e3718414 Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Thu, 3 Aug 2023 14:06:07 -0400 Subject: [PATCH 32/37] IRGen: Emitting metadata for function types containing pack expansion parameters --- lib/IRGen/GenPack.cpp | 60 +++++++++++++++++++ lib/IRGen/GenPack.h | 6 ++ lib/IRGen/MetadataRequest.cpp | 24 +++++++- lib/IRGen/MetadataRequest.h | 2 + .../variadic_generic_func_types.swift | 42 +++++++++++++ 5 files changed, 132 insertions(+), 2 deletions(-) create mode 100644 test/Interpreter/variadic_generic_func_types.swift diff --git a/lib/IRGen/GenPack.cpp b/lib/IRGen/GenPack.cpp index c48e5902803af..31cf1208597aa 100644 --- a/lib/IRGen/GenPack.cpp +++ b/lib/IRGen/GenPack.cpp @@ -1349,3 +1349,63 @@ irgen::emitDynamicTupleTypeLabels(IRGenFunction &IGF, return labelString; } + +StackAddress +irgen::emitDynamicFunctionParameterFlags(IRGenFunction &IGF, + AnyFunctionType::CanParamArrayRef params, + CanPackType packType, + llvm::Value *shapeExpression) { + auto array = + IGF.emitDynamicAlloca(IGF.IGM.Int32Ty, shapeExpression, + Alignment(4), /*allowTaskAlloc=*/true); + + unsigned numExpansions = 0; + + auto visitFn = [&](CanType eltTy, + unsigned scalarIndex, + llvm::Value *dynamicIndex, + llvm::Value *dynamicLength) { + if (scalarIndex != 0 || dynamicIndex == nullptr) { + auto *constant = llvm::ConstantInt::get(IGF.IGM.SizeTy, scalarIndex); + accumulateSum(IGF, dynamicIndex, constant); + } + + auto elt = params[scalarIndex + numExpansions]; + auto flags = getABIParameterFlags(elt.getParameterFlags()); + auto flagsVal = llvm::ConstantInt::get( + IGF.IGM.Int32Ty, flags.getIntValue()); + + assert(eltTy == elt.getPlainType()); + + // If we're looking at a pack expansion, insert the appropriate + // number of flags fields. + if (auto expansionTy = dyn_cast(eltTy)) { + emitPackExpansionPack(IGF, array.getAddress(), expansionTy, + dynamicIndex, dynamicLength, + [&](llvm::Value *) -> llvm::Value * { + return flagsVal; + }); + + // We consumed an expansion. + numExpansions += 1; + + return; + } + + // The destination address, where we put the current element's flags field. + Address eltAddr( + IGF.Builder.CreateInBoundsGEP(array.getAddress().getElementType(), + array.getAddressPointer(), + dynamicIndex), + array.getAddress().getElementType(), + array.getAlignment()); + + // Otherwise, we have a single scalar element, which deposits a single + // flags field. + IGF.Builder.CreateStore(flagsVal, eltAddr); + }; + + (void) visitPackExplosion(IGF, packType, visitFn); + + return array; +} \ No newline at end of file diff --git a/lib/IRGen/GenPack.h b/lib/IRGen/GenPack.h index 912d0960bfb54..504e3c77d4eda 100644 --- a/lib/IRGen/GenPack.h +++ b/lib/IRGen/GenPack.h @@ -120,6 +120,12 @@ emitDynamicTupleTypeLabels(IRGenFunction &IGF, CanPackType packType, llvm::Value *shapeExpression); +StackAddress +emitDynamicFunctionParameterFlags(IRGenFunction &IGF, + AnyFunctionType::CanParamArrayRef params, + CanPackType packType, + llvm::Value *shapeExpression); + } // end namespace irgen } // end namespace swift diff --git a/lib/IRGen/MetadataRequest.cpp b/lib/IRGen/MetadataRequest.cpp index d1cf383451c39..805a1399b1d99 100644 --- a/lib/IRGen/MetadataRequest.cpp +++ b/lib/IRGen/MetadataRequest.cpp @@ -1371,7 +1371,7 @@ static llvm::Value *getFunctionParameterRef(IRGenFunction &IGF, } /// Mapping type-level parameter flags to ABI parameter flags. -static ParameterFlags getABIParameterFlags(ParameterTypeFlags flags) { +ParameterFlags irgen::getABIParameterFlags(ParameterTypeFlags flags) { return ParameterFlags() .withValueOwnership(flags.getValueOwnership()) .withVariadic(flags.isVariadic()) @@ -1489,7 +1489,27 @@ emitDynamicFunctionTypeMetadataParams(IRGenFunction &IGF, CanPackType packType, DynamicMetadataRequest request, SmallVectorImpl &arguments) { - assert(false); + assert(!params.empty()); + + FunctionTypeMetadataParamInfo info; + + llvm::Value *shape; + std::tie(info.parameters, shape) = emitTypeMetadataPack( + IGF, packType, MetadataState::Abstract); + + arguments.push_back(info.parameters.getAddress().getAddress()); + + if (flags.hasParameterFlags()) { + info.paramFlags = emitDynamicFunctionParameterFlags( + IGF, params, packType, shape); + + arguments.push_back(info.paramFlags.getAddress().getAddress()); + } else { + arguments.push_back(llvm::ConstantPointerNull::get( + IGF.IGM.Int32Ty->getPointerTo())); + } + + return info; } static void cleanupFunctionTypeMetadataParams(IRGenFunction &IGF, diff --git a/lib/IRGen/MetadataRequest.h b/lib/IRGen/MetadataRequest.h index d8e08b78f50e2..f752e399c74ed 100644 --- a/lib/IRGen/MetadataRequest.h +++ b/lib/IRGen/MetadataRequest.h @@ -709,6 +709,8 @@ MetadataResponse emitCheckTypeMetadataState(IRGenFunction &IGF, OperationCost getCheckTypeMetadataStateCost(DynamicMetadataRequest request, MetadataResponse response); +ParameterFlags getABIParameterFlags(ParameterTypeFlags flags); + } // end namespace irgen } // end namespace swift diff --git a/test/Interpreter/variadic_generic_func_types.swift b/test/Interpreter/variadic_generic_func_types.swift new file mode 100644 index 0000000000000..3c7614ba87cd6 --- /dev/null +++ b/test/Interpreter/variadic_generic_func_types.swift @@ -0,0 +1,42 @@ +// RUN: %target-run-simple-swift + +// REQUIRES: executable_test + +import StdlibUnittest + +var funcs = TestSuite("VariadicGenericFuncTypes") + +func makeFunctionType1(_: repeat (each T).Type) -> Any.Type { + return ((repeat each T) -> ()).self +} + +func makeFunctionType2(_: repeat (each T).Type) -> Any.Type { + return ((Character, repeat each T, Bool) -> ()).self +} + +func makeFunctionType3(_: repeat (each T).Type) -> Any.Type { + return ((inout Character, repeat each T, inout Bool) -> ()).self +} + +funcs.test("makeFunctionType1") { + expectEqual("() -> ()", _typeName(makeFunctionType1())) + expectEqual("(Swift.Int) -> ()", _typeName(makeFunctionType1(Int.self))) + expectEqual("(Swift.Int, Swift.String) -> ()", _typeName(makeFunctionType1(Int.self, String.self))) + expectEqual("(Swift.Int, Swift.Float, Swift.String) -> ()", _typeName(makeFunctionType1(Int.self, Float.self, String.self))) +} + +funcs.test("makeFunctionType2") { + expectEqual("(Swift.Character, Swift.Bool) -> ()", _typeName(makeFunctionType2())) + expectEqual("(Swift.Character, Swift.Int, Swift.Bool) -> ()", _typeName(makeFunctionType2(Int.self))) + expectEqual("(Swift.Character, Swift.Int, Swift.String, Swift.Bool) -> ()", _typeName(makeFunctionType2(Int.self, String.self))) + expectEqual("(Swift.Character, Swift.Int, Swift.Float, Swift.String, Swift.Bool) -> ()", _typeName(makeFunctionType2(Int.self, Float.self, String.self))) +} + +funcs.test("makeFunctionType3") { + expectEqual("(inout Swift.Character, inout Swift.Bool) -> ()", _typeName(makeFunctionType3())) + expectEqual("(inout Swift.Character, Swift.Int, inout Swift.Bool) -> ()", _typeName(makeFunctionType3(Int.self))) + expectEqual("(inout Swift.Character, Swift.Int, Swift.String, inout Swift.Bool) -> ()", _typeName(makeFunctionType3(Int.self, String.self))) + expectEqual("(inout Swift.Character, Swift.Int, Swift.Float, Swift.String, inout Swift.Bool) -> ()", _typeName(makeFunctionType3(Int.self, Float.self, String.self))) +} + +runAllTests() From 0e80b4e710fd4292f18932f299a24171db020183 Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Thu, 3 Aug 2023 23:44:50 -0700 Subject: [PATCH 33/37] Expand type refinement contexts to encompass attributes --- lib/Sema/TypeCheckAvailability.cpp | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/lib/Sema/TypeCheckAvailability.cpp b/lib/Sema/TypeCheckAvailability.cpp index 37b8e0b321e15..724050fb7d950 100644 --- a/lib/Sema/TypeCheckAvailability.cpp +++ b/lib/Sema/TypeCheckAvailability.cpp @@ -634,13 +634,7 @@ class TypeRefinementContextBuilder : private ASTWalker { return Range; } - // For pattern binding declarations, include the attributes in the source - // range so that we're sure to cover any property wrappers. - if (auto patternBinding = dyn_cast(D)) { - return D->getSourceRangeIncludingAttrs(); - } - - return D->getSourceRange(); + return D->getSourceRangeIncludingAttrs(); } // Creates an implicit decl TRC specifying the deployment From c9225db1560152efcc6cb0959872692f3845984e Mon Sep 17 00:00:00 2001 From: Slava Pestov Date: Fri, 4 Aug 2023 10:19:21 -0400 Subject: [PATCH 34/37] Disable test/stdlib/Observation/ObservableAvailabilityCycle.swift --- test/stdlib/Observation/ObservableAvailabilityCycle.swift | 1 + 1 file changed, 1 insertion(+) diff --git a/test/stdlib/Observation/ObservableAvailabilityCycle.swift b/test/stdlib/Observation/ObservableAvailabilityCycle.swift index 7343740f2b259..f7b404a0f8a2e 100644 --- a/test/stdlib/Observation/ObservableAvailabilityCycle.swift +++ b/test/stdlib/Observation/ObservableAvailabilityCycle.swift @@ -1,4 +1,5 @@ // REQUIRES: swift_swift_parser +// REQUIRES: rdar113395709 // RUN: %target-swift-frontend -typecheck -parse-as-library -enable-experimental-feature InitAccessors -external-plugin-path %swift-host-lib-dir/plugins#%swift-plugin-server -primary-file %s %S/Inputs/ObservableClass.swift From 310105779b8f358d1798a08888158148fdb1bfad Mon Sep 17 00:00:00 2001 From: Mishal Shah Date: Fri, 4 Aug 2023 08:29:23 -0700 Subject: [PATCH 35/37] Add stable/20230725 to the rebranch alias --- utils/update_checkout/update-checkout-config.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/update_checkout/update-checkout-config.json b/utils/update_checkout/update-checkout-config.json index 48254afbf5031..bb90f068f110c 100644 --- a/utils/update_checkout/update-checkout-config.json +++ b/utils/update_checkout/update-checkout-config.json @@ -135,7 +135,7 @@ } }, "rebranch": { - "aliases": ["rebranch"], + "aliases": ["rebranch", "stable/20230725"], "repos": { "llvm-project": "stable/20230725", "swift-llvm-bindings": "stable/20230725", From 69493d1f4a9784e9c8e24b112811fae20438b339 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexis=20Laferri=C3=A8re?= Date: Thu, 3 Aug 2023 16:17:43 -0700 Subject: [PATCH 36/37] [Sema] @_private imports brings is all SPI of the imported module Within one module, SPI decls are always visible. Conceptually we want the same behavior for `@_private` imports where the client pretends it's part of the same module. rdar://81240984 --- lib/AST/Module.cpp | 3 ++ test/SPI/private-import-access-spi.swift | 68 ++++++++++++++++++++++++ 2 files changed, 71 insertions(+) create mode 100644 test/SPI/private-import-access-spi.swift diff --git a/lib/AST/Module.cpp b/lib/AST/Module.cpp index d9942ece6ccd8..b27796ba315ca 100644 --- a/lib/AST/Module.cpp +++ b/lib/AST/Module.cpp @@ -3419,6 +3419,9 @@ bool SourceFile::isImportedAsSPI(const ValueDecl *targetDecl) const { if (shouldImplicitImportAsSPI(targetDecl->getSPIGroups())) return true; + if (hasTestableOrPrivateImport(AccessLevel::Public, targetDecl, PrivateOnly)) + return true; + lookupImportedSPIGroups(targetModule, importedSPIGroups); if (importedSPIGroups.empty()) return false; diff --git a/test/SPI/private-import-access-spi.swift b/test/SPI/private-import-access-spi.swift new file mode 100644 index 0000000000000..763822c3b2224 --- /dev/null +++ b/test/SPI/private-import-access-spi.swift @@ -0,0 +1,68 @@ +/// An `@_private` import opens access to all SPIs of the imported module. +/// Exports of SPI in API are still reported. + +// RUN: %empty-directory(%t) +// RUN: split-file %s %t + +/// Build the library. +// RUN: %target-swift-frontend -emit-module %t/Lib_FileA.swift %t/Lib_FileB.swift \ +// RUN: -module-name Lib -emit-module-path %t/Lib.swiftmodule \ +// RUN: -enable-library-evolution -swift-version 5 \ +// RUN: -enable-private-imports + +/// Typecheck a @_private client. +// RUN: %target-swift-frontend -typecheck -verify -I %t %t/PrivateClient.swift + +/// Typecheck a regular client building against the same Lib with private imports enabled. +// RUN: %target-swift-frontend -typecheck -verify -I %t %t/RegularClient.swift + +//--- Lib_FileA.swift +@_spi(S) public func spiFuncA() {} +@_spi(S) public struct SPITypeA {} + +//--- Lib_FileB.swift +@_spi(S) public func spiFuncB() {} +@_spi(S) public struct SPITypeB {} + +//--- PrivateClient.swift +@_private(sourceFile: "Lib_FileA.swift") import Lib + +func useOnly(a: SPITypeA, b: SPITypeB) { + spiFuncA() + spiFuncB() +} + +public func export(a: SPITypeA, b: SPITypeB) { // expected-error {{cannot use struct 'SPITypeA' here; it is an SPI imported from 'Lib'}} + // expected-error @-1 {{cannot use struct 'SPITypeB' here; it is an SPI imported from 'Lib'}} + spiFuncA() + spiFuncB() +} + +@inlinable +public func inlinableExport(a: SPITypeA, b: SPITypeB) { // expected-error {{struct 'SPITypeA' cannot be used in an '@inlinable' function because it is an SPI imported from 'Lib'}} + // expected-error @-1 {{struct 'SPITypeB' cannot be used in an '@inlinable' function because it is an SPI imported from 'Lib'}} + spiFuncA() // expected-error {{global function 'spiFuncA()' cannot be used in an '@inlinable' function because it is an SPI imported from 'Lib'}} + spiFuncB() // expected-error {{global function 'spiFuncB()' cannot be used in an '@inlinable' function because it is an SPI imported from 'Lib'}} +} + +//--- RegularClient.swift +import Lib + +func useOnly(a: SPITypeA, b: SPITypeB) { // expected-error {{cannot find type 'SPITypeA' in scope}} + // expected-error @-1 {{cannot find type 'SPITypeB' in scope}} + spiFuncA() // expected-error {{cannot find 'spiFuncA' in scope}} + spiFuncB() // expected-error {{cannot find 'spiFuncB' in scope}} +} + +public func export(a: SPITypeA, b: SPITypeB) { // expected-error {{cannot find type 'SPITypeA' in scope}} + // expected-error @-1 {{cannot find type 'SPITypeB' in scope}} + spiFuncA() // expected-error {{cannot find 'spiFuncA' in scope}} + spiFuncB() // expected-error {{cannot find 'spiFuncB' in scope}} +} + +@inlinable +public func inlinableExport(a: SPITypeA, b: SPITypeB) { // expected-error {{cannot find type 'SPITypeA' in scope}} + // expected-error @-1 {{cannot find type 'SPITypeB' in scope}} + spiFuncA() // expected-error {{cannot find 'spiFuncA' in scope}} + spiFuncB() // expected-error {{cannot find 'spiFuncB' in scope}} +} From c32f022a49fb1eedbd3c97ea32a99ec8a07a8af6 Mon Sep 17 00:00:00 2001 From: Becca Royal-Gordon Date: Fri, 4 Aug 2023 16:18:15 -0700 Subject: [PATCH 37/37] Include cmath so we have ceil() --- lib/Sema/ConstraintSystem.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/Sema/ConstraintSystem.cpp b/lib/Sema/ConstraintSystem.cpp index 3d5893bc12c0a..1ce1c01ed94c9 100644 --- a/lib/Sema/ConstraintSystem.cpp +++ b/lib/Sema/ConstraintSystem.cpp @@ -38,6 +38,7 @@ #include "llvm/ADT/SmallString.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Format.h" +#include using namespace swift; using namespace constraints;