diff --git a/include/swift/AST/Decl.h b/include/swift/AST/Decl.h index 573a640eb2164..0a764b44f06c1 100644 --- a/include/swift/AST/Decl.h +++ b/include/swift/AST/Decl.h @@ -6265,11 +6265,9 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl { DerivativeFunctionConfigurationList *DerivativeFunctionConfigs = nullptr; public: - /// Get all derivative function configurations. If `lookInNonPrimarySources` - /// is true then lookup is done in non-primary sources as well. Note that - /// such lookup might end in cycles if done during sema stages. + /// Get all derivative function configurations. ArrayRef - getDerivativeFunctionConfigurations(bool lookInNonPrimarySources = true); + getDerivativeFunctionConfigurations(); /// Add the given derivative function configuration. void addDerivativeFunctionConfiguration(const AutoDiffConfig &config); diff --git a/include/swift/Frontend/Frontend.h b/include/swift/Frontend/Frontend.h index 81fb962ad843c..b0f6e4dcf4fc6 100644 --- a/include/swift/Frontend/Frontend.h +++ b/include/swift/Frontend/Frontend.h @@ -669,6 +669,7 @@ class CompilerInstance { /// If \p fn returns true, exits early and returns true. bool forEachFileToTypeCheck(llvm::function_ref fn); + bool forEachSourceFile(llvm::function_ref fn); /// Whether the cancellation of the current operation has been requested. bool isCancellationRequested() const; diff --git a/include/swift/Subsystems.h b/include/swift/Subsystems.h index 41a152bec3889..16b6d589b034a 100644 --- a/include/swift/Subsystems.h +++ b/include/swift/Subsystems.h @@ -157,6 +157,10 @@ namespace swift { /// emitted. void performWholeModuleTypeChecking(SourceFile &SF); + /// Load derivative configurations from @derivative attributes (including + /// those defined in non-primary sources). + void loadDerivativeConfigurations(SourceFile &SF); + /// Resolve the given \c TypeRepr to an interface type. /// /// This is used when dealing with partial source files (e.g. SIL parsing, diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index a391f2a550eef..80538d89e428a 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -8312,7 +8312,7 @@ void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() { } ArrayRef -AbstractFunctionDecl::getDerivativeFunctionConfigurations(bool lookInNonPrimarySources) { +AbstractFunctionDecl::getDerivativeFunctionConfigurations() { prepareDerivativeFunctionConfigurations(); // Resolve derivative function configurations from `@differentiable` @@ -8336,36 +8336,6 @@ AbstractFunctionDecl::getDerivativeFunctionConfigurations(bool lookInNonPrimaryS *DerivativeFunctionConfigs); } - class DerivativeFinder : public ASTWalker { - const AbstractFunctionDecl *AFD; - public: - DerivativeFinder(const AbstractFunctionDecl *afd) : AFD(afd) {} - - bool walkToDeclPre(Decl *D) override { - if (auto *afd = dyn_cast(D)) { - for (auto *derAttr : afd->getAttrs().getAttributes()) { - // Resolve derivative function configurations from `@derivative` - // attributes by type-checking them. - if (AFD->getName().matchesRef( - derAttr->getOriginalFunctionName().Name.getFullName())) { - (void)derAttr->getOriginalFunction(afd->getASTContext()); - return false; - } - } - } - - return true; - } - }; - - // Load derivative configurations from @derivative attributes defined in - // non-primary sources. Note that it might trigger lookup cycles if called - // from inside Sema stages. - if (lookInNonPrimarySources) { - DerivativeFinder finder(this); - getParent()->walkContext(finder); - } - return DerivativeFunctionConfigs->getArrayRef(); } diff --git a/lib/Frontend/Frontend.cpp b/lib/Frontend/Frontend.cpp index e11946d535b71..22c530742ce1c 100644 --- a/lib/Frontend/Frontend.cpp +++ b/lib/Frontend/Frontend.cpp @@ -1186,11 +1186,31 @@ bool CompilerInstance::forEachFileToTypeCheck( return false; } +bool CompilerInstance::forEachSourceFile( + llvm::function_ref fn) { + for (auto fileName : getMainModule()->getFiles()) { + auto *SF = dyn_cast(fileName); + if (!SF) { + continue; + } + if (fn(*SF)) + return true; + ; + } + + return false; +} + void CompilerInstance::finishTypeChecking() { forEachFileToTypeCheck([](SourceFile &SF) { performWholeModuleTypeChecking(SF); return false; }); + + forEachSourceFile([](SourceFile &SF) { + loadDerivativeConfigurations(SF); + return false; + }); } SourceFile::ParsingOptions diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index 2d37c9b919048..43c6882b3929d 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -379,8 +379,7 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req, bool foundExactConfig = false; Optional supersetConfig = None; for (auto witnessConfig : - witnessAFD->getDerivativeFunctionConfigurations( - /*lookInNonPrimarySources*/ false)) { + witnessAFD->getDerivativeFunctionConfigurations()) { // All the witness's derivative generic requirements must be satisfied // by the requirement's derivative generic requirements OR by the // conditional conformance requirements. diff --git a/lib/Sema/TypeChecker.cpp b/lib/Sema/TypeChecker.cpp index 1f7e19d8a8257..bc1fece472e96 100644 --- a/lib/Sema/TypeChecker.cpp +++ b/lib/Sema/TypeChecker.cpp @@ -375,6 +375,44 @@ void swift::performWholeModuleTypeChecking(SourceFile &SF) { } } +void swift::loadDerivativeConfigurations(SourceFile &SF) { + if (!isDifferentiableProgrammingEnabled(SF)) + return; + + auto &Ctx = SF.getASTContext(); + FrontendStatsTracer tracer(Ctx.Stats, + "load-derivative-configurations"); + + class DerivativeFinder : public ASTWalker { + public: + DerivativeFinder() {} + + bool walkToDeclPre(Decl *D) override { + if (auto *afd = dyn_cast(D)) { + for (auto *derAttr : afd->getAttrs().getAttributes()) { + // Resolve derivative function configurations from `@derivative` + // attributes by type-checking them. + (void)derAttr->getOriginalFunction(D->getASTContext()); + } + } + + return true; + } + }; + + switch (SF.Kind) { + case SourceFileKind::Library: + case SourceFileKind::Main: { + DerivativeFinder finder; + SF.walkContext(finder); + return; + } + case SourceFileKind::SIL: + case SourceFileKind::Interface: + return; + } +} + bool swift::isAdditiveArithmeticConformanceDerivationEnabled(SourceFile &SF) { auto &ctx = SF.getASTContext(); // Return true if `AdditiveArithmetic` derived conformances are explicitly diff --git a/test/AutoDiff/Sema/DerivativeRegistrationCrossFile/Inputs/derivatives.swift b/test/AutoDiff/Sema/DerivativeRegistrationCrossFile/Inputs/derivatives.swift new file mode 100644 index 0000000000000..71dffec74e834 --- /dev/null +++ b/test/AutoDiff/Sema/DerivativeRegistrationCrossFile/Inputs/derivatives.swift @@ -0,0 +1,35 @@ +import _Differentiation + +@inlinable +@derivative(of: min) +func minVJP( + _ x: T, + _ y: T +) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) { + func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) { + if x <= y { + return (v, .zero) + } + else { + return (.zero, v) + } + } + return (value: min(x, y), pullback: pullback) +} + +@inlinable +@derivative(of: max) +func maxVJP( + _ x: T, + _ y: T +) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) { + func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) { + if x < y { + return (.zero, v) + } + else { + return (v, .zero) + } + } + return (value: max(x, y), pullback: pullback) +} diff --git a/test/AutoDiff/Sema/DerivativeRegistrationCrossFile/main.swift b/test/AutoDiff/Sema/DerivativeRegistrationCrossFile/main.swift new file mode 100644 index 0000000000000..071510d508373 --- /dev/null +++ b/test/AutoDiff/Sema/DerivativeRegistrationCrossFile/main.swift @@ -0,0 +1,9 @@ +// RUN: %target-swift-frontend -emit-sil -verify -primary-file %s %S/Inputs/derivatives.swift -module-name main -o /dev/null + +import _Differentiation + +@differentiable(reverse) +func clamp(_ value: Double, _ lowerBound: Double, _ upperBound: Double) -> Double { + // No error expected + return max(min(value, upperBound), lowerBound) +}