diff --git a/include/swift/AST/DeclContext.h b/include/swift/AST/DeclContext.h index 2ae5ba9044a50..d40c6a74b3f2b 100644 --- a/include/swift/AST/DeclContext.h +++ b/include/swift/AST/DeclContext.h @@ -802,6 +802,10 @@ class IterableDeclContext { /// We must restore this when delayed parsing the body. unsigned InFreestandingMacroArgument : 1; + /// Whether delayed parsing detect a possible custom derivative definition + /// while skipping the body of this context. + unsigned HasDerivativeDeclarations : 1; + template friend struct ::llvm::CastInfo; @@ -817,6 +821,7 @@ class IterableDeclContext { : LastDeclAndKind(nullptr, kind) { AddedParsedMembers = 0; HasOperatorDeclarations = 0; + HasDerivativeDeclarations = 0; HasNestedClassDeclarations = 0; InFreestandingMacroArgument = 0; } @@ -855,6 +860,15 @@ class IterableDeclContext { InFreestandingMacroArgument = 1; } + bool maybeHasDerivativeDeclarations() const { + return HasDerivativeDeclarations; + } + + void setMaybeHasDerivativeDeclarations() { + assert(hasUnparsedMembers()); + HasDerivativeDeclarations = 1; + } + /// Retrieve the current set of members in this context. /// /// NOTE: This operation is an alias of \c getCurrentMembers() that is considered diff --git a/include/swift/AST/Module.h b/include/swift/AST/Module.h index 9d5bae570ffd6..c5724a6893e96 100644 --- a/include/swift/AST/Module.h +++ b/include/swift/AST/Module.h @@ -240,6 +240,7 @@ class ModuleDecl : public DeclContext, public TypeDecl, public ASTAllocated { friend class DirectOperatorLookupRequest; friend class DirectPrecedenceGroupLookupRequest; + friend class CustomDerivativesRequest; /// The ABI name of the module, if it differs from the module name. mutable Identifier ModuleABIName; diff --git a/include/swift/AST/TypeCheckRequests.h b/include/swift/AST/TypeCheckRequests.h index b876c842368f1..f20af7b53343d 100644 --- a/include/swift/AST/TypeCheckRequests.h +++ b/include/swift/AST/TypeCheckRequests.h @@ -5165,6 +5165,22 @@ class GenericTypeParamDeclGetValueTypeRequest bool isCached() const { return true; } }; +class CustomDerivativesRequest + : public SimpleRequest { +public: + using SimpleRequest::SimpleRequest; + +private: + friend SimpleRequest; + + evaluator::SideEffect evaluate(Evaluator &evaluator, SourceFile *sf) const; + +public: + bool isCached() const { return true; } +}; + #define SWIFT_TYPEID_ZONE TypeChecker #define SWIFT_TYPEID_HEADER "swift/AST/TypeCheckerTypeIDZone.def" #include "swift/Basic/DefineTypeIDZone.h" diff --git a/include/swift/AST/TypeCheckerTypeIDZone.def b/include/swift/AST/TypeCheckerTypeIDZone.def index e2b20ee4cb086..a3cc4c195dacb 100644 --- a/include/swift/AST/TypeCheckerTypeIDZone.def +++ b/include/swift/AST/TypeCheckerTypeIDZone.def @@ -605,6 +605,9 @@ SWIFT_REQUEST(TypeChecker, ParamCaptureInfoRequest, SWIFT_REQUEST(TypeChecker, IsUnsafeRequest, bool(Decl *), SeparatelyCached, NoLocationInfo) +SWIFT_REQUEST(TypeChecker, CustomDerivativesRequest, + CustomDerivativesResult(SourceFile *), + Cached, NoLocationInfo) SWIFT_REQUEST(TypeChecker, GenericTypeParamDeclGetValueTypeRequest, Type(GenericTypeParamDecl *), Cached, NoLocationInfo) diff --git a/include/swift/Parse/Parser.h b/include/swift/Parse/Parser.h index ca475d44826cb..c739c468bcb26 100644 --- a/include/swift/Parse/Parser.h +++ b/include/swift/Parse/Parser.h @@ -974,7 +974,8 @@ class Parser { IterableDeclContext *IDC); bool canDelayMemberDeclParsing(bool &HasOperatorDeclarations, - bool &HasNestedClassDeclarations); + bool &HasNestedClassDeclarations, + bool &HasDerivativeDeclarations); bool canDelayFunctionBodyParsing(bool &HasNestedTypeDeclarations); diff --git a/lib/AST/Module.cpp b/lib/AST/Module.cpp index 9f5d9572aaf78..c0a1b086bfbf1 100644 --- a/lib/AST/Module.cpp +++ b/lib/AST/Module.cpp @@ -166,6 +166,7 @@ class swift::SourceLookupCache { ValueDeclMap TopLevelValues; ValueDeclMap ClassMembers; bool MemberCachePopulated = false; + llvm::SmallVector CustomDerivatives; DeclName UniqueMacroNamePlaceholder; template @@ -173,8 +174,9 @@ class swift::SourceLookupCache { OperatorMap Operators; OperatorMap PrecedenceGroups; - template - void addToUnqualifiedLookupCache(Range decls, bool onlyOperators); + template + void addToUnqualifiedLookupCache(Range decls, bool onlyOperators, + bool onlyDerivatives); template void addToMemberCache(Range decls); @@ -205,6 +207,10 @@ class swift::SourceLookupCache { /// guaranteed to be meaningful. void getPrecedenceGroups(SmallVectorImpl &results); + /// Retrieves all the function decls marked as @derivative. The order of the + /// results is not guaranteed to be meaningful. + llvm::SmallVector getCustomDerivativeDecls(); + /// Look up an operator declaration. /// /// \param name The operator name ("+", ">>", etc.) @@ -249,9 +255,10 @@ static Decl *getAsDecl(Decl *decl) { return decl; } static Expr *getAsExpr(ASTNode node) { return node.dyn_cast(); } static Decl *getAsDecl(ASTNode node) { return node.dyn_cast(); } -template +template void SourceLookupCache::addToUnqualifiedLookupCache(Range items, - bool onlyOperators) { + bool onlyOperators, + bool onlyDerivatives) { for (auto item : items) { // In script mode, we'll see macro expansion expressions for freestanding // macros. @@ -268,19 +275,36 @@ void SourceLookupCache::addToUnqualifiedLookupCache(Range items, continue; if (auto *VD = dyn_cast(D)) { - if (onlyOperators ? VD->isOperator() : VD->hasName()) { - // Cache the value under both its compound name and its full name. + auto getDerivative = [onlyDerivatives, VD]() -> AbstractFunctionDecl * { + if (auto *AFD = dyn_cast(VD)) + if (AFD->getAttrs().hasAttribute()) + return AFD; + return nullptr; + }; + if (onlyOperators && VD->isOperator()) TopLevelValues.add(VD); - - if (!onlyOperators && VD->getAttrs().hasAttribute()) { + if (onlyDerivatives) + if (AbstractFunctionDecl *AFD = getDerivative()) + CustomDerivatives.push_back(AFD); + if (!onlyOperators && !onlyDerivatives && VD->hasName()) { + TopLevelValues.add(VD); + if (VD->getAttrs().hasAttribute()) MayHaveAuxiliaryDecls.push_back(VD); - } + if (AbstractFunctionDecl *AFD = getDerivative()) + CustomDerivatives.push_back(AFD); } } - if (auto *NTD = dyn_cast(D)) - if (!NTD->hasUnparsedMembers() || NTD->maybeHasOperatorDeclarations()) - addToUnqualifiedLookupCache(NTD->getMembers(), true); + if (auto *NTD = dyn_cast(D)) { + bool onlyOperatorsArg = + (!NTD->hasUnparsedMembers() || NTD->maybeHasOperatorDeclarations()); + bool onlyDerivativesArg = + (!NTD->hasUnparsedMembers() || NTD->maybeHasDerivativeDeclarations()); + if (onlyOperatorsArg || onlyDerivativesArg) { + addToUnqualifiedLookupCache(NTD->getMembers(), onlyOperatorsArg, + onlyDerivativesArg); + } + } if (auto *ED = dyn_cast(D)) { // Avoid populating the cache with the members of invalid extension @@ -292,8 +316,14 @@ void SourceLookupCache::addToUnqualifiedLookupCache(Range items, MayHaveAuxiliaryDecls.push_back(ED); } - if (!ED->hasUnparsedMembers() || ED->maybeHasOperatorDeclarations()) - addToUnqualifiedLookupCache(ED->getMembers(), true); + bool onlyOperatorsArg = + (!ED->hasUnparsedMembers() || ED->maybeHasOperatorDeclarations()); + bool onlyDerivativesArg = + (!ED->hasUnparsedMembers() || ED->maybeHasDerivativeDeclarations()); + if (onlyOperatorsArg || onlyDerivativesArg) { + addToUnqualifiedLookupCache(ED->getMembers(), onlyOperatorsArg, + onlyDerivativesArg); + } } if (auto *OD = dyn_cast(D)) @@ -307,7 +337,8 @@ void SourceLookupCache::addToUnqualifiedLookupCache(Range items, MayHaveAuxiliaryDecls.push_back(MED); } else if (auto TLCD = dyn_cast(D)) { if (auto body = TLCD->getBody()){ - addToUnqualifiedLookupCache(body->getElements(), onlyOperators); + addToUnqualifiedLookupCache(body->getElements(), onlyOperators, + onlyDerivatives); } } } @@ -488,8 +519,8 @@ SourceLookupCache::SourceLookupCache(const SourceFile &SF) { FrontendStatsTracer tracer(SF.getASTContext().Stats, "source-file-populate-cache"); - addToUnqualifiedLookupCache(SF.getTopLevelItems(), false); - addToUnqualifiedLookupCache(SF.getHoistedDecls(), false); + addToUnqualifiedLookupCache(SF.getTopLevelItems(), false, false); + addToUnqualifiedLookupCache(SF.getHoistedDecls(), false, false); } SourceLookupCache::SourceLookupCache(const ModuleDecl &M) @@ -499,11 +530,11 @@ SourceLookupCache::SourceLookupCache(const ModuleDecl &M) "module-populate-cache"); for (const FileUnit *file : M.getFiles()) { auto *SF = cast(file); - addToUnqualifiedLookupCache(SF->getTopLevelItems(), false); - addToUnqualifiedLookupCache(SF->getHoistedDecls(), false); + addToUnqualifiedLookupCache(SF->getTopLevelItems(), false, false); + addToUnqualifiedLookupCache(SF->getHoistedDecls(), false, false); if (auto *SFU = file->getSynthesizedFile()) { - addToUnqualifiedLookupCache(SFU->getTopLevelDecls(), false); + addToUnqualifiedLookupCache(SFU->getTopLevelDecls(), false, false); } } } @@ -572,6 +603,11 @@ void SourceLookupCache::getOperatorDecls( results.append(ops.second.begin(), ops.second.end()); } +llvm::SmallVector +SourceLookupCache::getCustomDerivativeDecls() { + return CustomDerivatives; +} + void SourceLookupCache::lookupOperator(Identifier name, OperatorFixity fixity, TinyPtrVector &results) { auto ops = Operators.find(name); @@ -4027,6 +4063,23 @@ bool IsNonUserModuleRequest::evaluate(Evaluator &evaluator, ModuleDecl *mod) con return false; } +evaluator::SideEffect CustomDerivativesRequest::evaluate(Evaluator &evaluator, + SourceFile *sf) const { + ModuleDecl *module = sf->getParentModule(); + assert(isParsedModule(module)); + llvm::SmallVector decls = + module->getSourceLookupCache().getCustomDerivativeDecls(); + for (const AbstractFunctionDecl *afd : decls) { + for (const auto *derAttr : + afd->getAttrs().getAttributes()) { + // Resolve derivative function configurations from `@derivative` + // attributes by type-checking them. + (void)derAttr->getOriginalFunction(sf->getASTContext()); + } + } + return {}; +} + version::Version ModuleDecl::getLanguageVersionBuiltWith() const { for (auto *F : getFiles()) { auto *LD = dyn_cast(F); diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp index 06db030c112fd..310112ab082e1 100644 --- a/lib/Parse/ParseDecl.cpp +++ b/lib/Parse/ParseDecl.cpp @@ -5644,19 +5644,20 @@ static void diagnoseOperatorFixityAttributes(Parser &P, } } -static unsigned skipUntilMatchingRBrace(Parser &P, - bool &HasPoundDirective, +static unsigned skipUntilMatchingRBrace(Parser &P, bool &HasPoundDirective, bool &HasPoundSourceLocation, bool &HasOperatorDeclarations, bool &HasNestedClassDeclarations, bool &HasNestedTypeDeclarations, - bool &HasPotentialRegexLiteral) { + bool &HasPotentialRegexLiteral, + bool &HasDerivativeDeclarations) { HasPoundDirective = false; HasPoundSourceLocation = false; HasOperatorDeclarations = false; HasNestedClassDeclarations = false; HasNestedTypeDeclarations = false; HasPotentialRegexLiteral = false; + HasDerivativeDeclarations = false; unsigned OpenBraces = 1; unsigned OpenPoundIf = 0; @@ -5685,6 +5686,18 @@ static unsigned skipUntilMatchingRBrace(Parser &P, tok::kw_protocol) || P.Tok.isContextualKeyword("actor"); + if (P.consumeIf(tok::at_sign)) { + if (P.Tok.is(tok::identifier)) { + std::optional DK = + DeclAttribute::getAttrKindFromString(P.Tok.getText()); + if (DK && *DK == DeclAttrKind::Derivative) { + HasDerivativeDeclarations = true; + P.consumeToken(); + } + } + continue; + } + // HACK: Bail if we encounter what could potentially be a regex literal. // This is necessary as: // - We might encounter an invalid Swift token that might be valid in a @@ -7033,13 +7046,17 @@ bool Parser::parseMemberDeclList(SourceLoc &LBLoc, SourceLoc &RBLoc, bool HasOperatorDeclarations = false; bool HasNestedClassDeclarations = false; + bool HasDerivativeDeclarations = false; if (canDelayMemberDeclParsing(HasOperatorDeclarations, - HasNestedClassDeclarations)) { + HasNestedClassDeclarations, + HasDerivativeDeclarations)) { if (HasOperatorDeclarations) IDC->setMaybeHasOperatorDeclarations(); if (HasNestedClassDeclarations) IDC->setMaybeHasNestedClassDeclarations(); + if (HasDerivativeDeclarations) + IDC->setMaybeHasDerivativeDeclarations(); if (InFreestandingMacroArgument) IDC->setInFreestandingMacroArgument(); @@ -7052,6 +7069,7 @@ bool Parser::parseMemberDeclList(SourceLoc &LBLoc, SourceLoc &RBLoc, auto membersAndHash = parseDeclList(LBLoc, RBLoc, RBraceDiag, hadError); IDC->setMaybeHasOperatorDeclarations(); IDC->setMaybeHasNestedClassDeclarations(); + IDC->setMaybeHasDerivativeDeclarations(); Context.evaluator.cacheOutput( ParseMembersRequest{IDC}, FingerprintAndMembers{ @@ -7112,7 +7130,8 @@ Parser::parseDeclList(SourceLoc LBLoc, SourceLoc &RBLoc, Diag<> ErrorDiag, } bool Parser::canDelayMemberDeclParsing(bool &HasOperatorDeclarations, - bool &HasNestedClassDeclarations) { + bool &HasNestedClassDeclarations, + bool &HasDerivativeDeclarations) { // If explicitly disabled, respect the flag. if (!isDelayedParsingEnabled()) return false; @@ -7124,13 +7143,10 @@ bool Parser::canDelayMemberDeclParsing(bool &HasOperatorDeclarations, bool HasPoundSourceLocation; bool HasNestedTypeDeclarations; bool HasPotentialRegexLiteral; - skipUntilMatchingRBrace(*this, - HasPoundDirective, - HasPoundSourceLocation, - HasOperatorDeclarations, - HasNestedClassDeclarations, - HasNestedTypeDeclarations, - HasPotentialRegexLiteral); + skipUntilMatchingRBrace(*this, HasPoundDirective, HasPoundSourceLocation, + HasOperatorDeclarations, HasNestedClassDeclarations, + HasNestedTypeDeclarations, HasPotentialRegexLiteral, + HasDerivativeDeclarations); if (!HasPoundDirective && !HasPotentialRegexLiteral) { // If we didn't see any pound directive, we must not have seen // #sourceLocation either. @@ -7742,9 +7758,11 @@ bool Parser::canDelayFunctionBodyParsing(bool &HasNestedTypeDeclarations) { bool HasOperatorDeclarations; bool HasNestedClassDeclarations; bool HasPotentialRegexLiteral; + bool HasDerivativeDeclarations; skipUntilMatchingRBrace(*this, HasPoundDirectives, HasPoundSourceLocation, HasOperatorDeclarations, HasNestedClassDeclarations, - HasNestedTypeDeclarations, HasPotentialRegexLiteral); + HasNestedTypeDeclarations, HasPotentialRegexLiteral, + HasDerivativeDeclarations); if (HasPoundSourceLocation || HasPotentialRegexLiteral) return false; diff --git a/lib/Sema/TypeChecker.cpp b/lib/Sema/TypeChecker.cpp index 2131dd4e17fe4..ee3a982d8dec3 100644 --- a/lib/Sema/TypeChecker.cpp +++ b/lib/Sema/TypeChecker.cpp @@ -404,34 +404,13 @@ void swift::loadDerivativeConfigurations(SourceFile &SF) { FrontendStatsTracer tracer(Ctx.Stats, "load-derivative-configurations"); - class DerivativeFinder : public ASTWalker { - public: - DerivativeFinder() {} - - MacroWalking getMacroWalkingBehavior() const override { - return MacroWalking::Expansion; - } - - PreWalkAction walkToDeclPre(Decl *D) override { - if (auto *afd = dyn_cast(D)) { - for (auto *derAttr : afd->getAttrs().getAttributes()) { - // Resolve derivative function configurations from `@derivative` - // attributes by type-checking them. - (void)derAttr->getOriginalFunction(D->getASTContext()); - } - } - - return Action::Continue(); - } - }; - switch (SF.Kind) { case SourceFileKind::DefaultArgument: case SourceFileKind::Library: case SourceFileKind::MacroExpansion: case SourceFileKind::Main: { - DerivativeFinder finder; - SF.walkContext(finder); + CustomDerivativesRequest request(&SF); + evaluateOrDefault(SF.getASTContext().evaluator, request, {}); return; } case SourceFileKind::SIL: diff --git a/test/AutoDiff/SILOptimizer/Inputs/differentiation_diagnostics_other_file.swift b/test/AutoDiff/SILOptimizer/Inputs/differentiation_diagnostics_other_file.swift index 568218e1f155f..37499434c67aa 100644 --- a/test/AutoDiff/SILOptimizer/Inputs/differentiation_diagnostics_other_file.swift +++ b/test/AutoDiff/SILOptimizer/Inputs/differentiation_diagnostics_other_file.swift @@ -26,6 +26,18 @@ extension Protocol { } } +struct Struct: Differentiable { + func identityDerivativeAttr() -> Self { self } + + // Test cross-file `@derivative` attribute. + @derivative(of: identityDerivativeAttr) + func vjpIdentityDerivativeAttr() -> ( + value: Self, pullback: (TangentVector) -> TangentVector + ) { + fatalError() + } +} + class Class: Differentiable { // Test `@differentiable` propagation from storage declaration to accessors. @differentiable(reverse) diff --git a/test/AutoDiff/SILOptimizer/differentiation_diagnostics_cross_file.swift b/test/AutoDiff/SILOptimizer/differentiation_diagnostics_cross_file.swift index d9b5f93f0a04b..4482fe8093270 100644 --- a/test/AutoDiff/SILOptimizer/differentiation_diagnostics_cross_file.swift +++ b/test/AutoDiff/SILOptimizer/differentiation_diagnostics_cross_file.swift @@ -21,6 +21,12 @@ func crossFileDerivativeAttr( return input.identityDerivativeAttr() } +@differentiable(reverse) +func crossFileDerivativeAttr(_ input: Struct) -> Struct { + // No error expected + return input.identityDerivativeAttr() +} + // TF-1234: Test `@differentiable` propagation from protocol requirement storage // declarations to their accessors in other file. @differentiable(reverse)