Skip to content

Commit a94d692

Browse files
committed
[AutoDiff] Enhance performance of custom derivatives lookup
In swiftlang#58965, lookup for custom derivatives in non-primary source files was introduced. It required traversing all delayed parsed function bodies of a file if the file was compiled with differential programming enabled (even for functions with no `@derivative` attribute). This patch introduces `CustomDerivativesRequest` to address the issue. Resolves swiftlang#60102
1 parent 3230a68 commit a94d692

File tree

10 files changed

+159
-57
lines changed

10 files changed

+159
-57
lines changed

include/swift/AST/DeclContext.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -798,6 +798,10 @@ class IterableDeclContext {
798798
/// while skipping the body of this context.
799799
unsigned HasNestedClassDeclarations : 1;
800800

801+
/// Whether delayed parsing detect a possible custom derivative definition
802+
/// while skipping the body of this context.
803+
unsigned HasDerivativeDeclarations : 1;
804+
801805
template<class A, class B, class C>
802806
friend struct ::llvm::CastInfo;
803807

@@ -813,6 +817,7 @@ class IterableDeclContext {
813817
: LastDeclAndKind(nullptr, kind) {
814818
AddedParsedMembers = 0;
815819
HasOperatorDeclarations = 0;
820+
HasDerivativeDeclarations = 0;
816821
HasNestedClassDeclarations = 0;
817822
}
818823

@@ -841,6 +846,15 @@ class IterableDeclContext {
841846
HasNestedClassDeclarations = 1;
842847
}
843848

849+
bool maybeHasDerivativeDeclarations() const {
850+
return HasDerivativeDeclarations;
851+
}
852+
853+
void setMaybeHasDerivativeDeclarations() {
854+
assert(hasUnparsedMembers());
855+
HasDerivativeDeclarations = 1;
856+
}
857+
844858
/// Retrieve the current set of members in this context.
845859
///
846860
/// NOTE: This operation is an alias of \c getCurrentMembers() that is considered

include/swift/AST/Module.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ class ModuleDecl
240240
: public DeclContext, public TypeDecl, public ASTAllocated<ModuleDecl> {
241241
friend class DirectOperatorLookupRequest;
242242
friend class DirectPrecedenceGroupLookupRequest;
243+
friend class CustomDerivativesRequest;
243244

244245
/// The ABI name of the module, if it differs from the module name.
245246
mutable Identifier ModuleABIName;

include/swift/AST/TypeCheckRequests.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5165,6 +5165,22 @@ class GenericTypeParamDeclGetValueTypeRequest
51655165
bool isCached() const { return true; }
51665166
};
51675167

5168+
class CustomDerivativesRequest
5169+
: public SimpleRequest<CustomDerivativesRequest,
5170+
evaluator::SideEffect(SourceFile *),
5171+
RequestFlags::Cached> {
5172+
public:
5173+
using SimpleRequest::SimpleRequest;
5174+
5175+
private:
5176+
friend SimpleRequest;
5177+
5178+
evaluator::SideEffect evaluate(Evaluator &evaluator, SourceFile *sf) const;
5179+
5180+
public:
5181+
bool isCached() const { return true; }
5182+
};
5183+
51685184
#define SWIFT_TYPEID_ZONE TypeChecker
51695185
#define SWIFT_TYPEID_HEADER "swift/AST/TypeCheckerTypeIDZone.def"
51705186
#include "swift/Basic/DefineTypeIDZone.h"

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,9 @@ SWIFT_REQUEST(TypeChecker, ParamCaptureInfoRequest,
605605
SWIFT_REQUEST(TypeChecker, IsUnsafeRequest,
606606
bool(Decl *),
607607
SeparatelyCached, NoLocationInfo)
608+
SWIFT_REQUEST(TypeChecker, CustomDerivativesRequest,
609+
CustomDerivativesResult(SourceFile *),
610+
Cached, NoLocationInfo)
608611

609612
SWIFT_REQUEST(TypeChecker, GenericTypeParamDeclGetValueTypeRequest,
610613
Type(GenericTypeParamDecl *), Cached, NoLocationInfo)

include/swift/Parse/Parser.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,8 @@ class Parser {
974974
IterableDeclContext *IDC);
975975

976976
bool canDelayMemberDeclParsing(bool &HasOperatorDeclarations,
977-
bool &HasNestedClassDeclarations);
977+
bool &HasNestedClassDeclarations,
978+
bool &HasDerivativeDeclarations);
978979

979980
bool canDelayFunctionBodyParsing(bool &HasNestedTypeDeclarations);
980981

lib/AST/Module.cpp

Lines changed: 72 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -166,15 +166,17 @@ class swift::SourceLookupCache {
166166
ValueDeclMap TopLevelValues;
167167
ValueDeclMap ClassMembers;
168168
bool MemberCachePopulated = false;
169+
llvm::SmallVector<AbstractFunctionDecl *, 0> CustomDerivatives;
169170
DeclName UniqueMacroNamePlaceholder;
170171

171172
template<typename T>
172173
using OperatorMap = llvm::DenseMap<Identifier, TinyPtrVector<T *>>;
173174
OperatorMap<OperatorDecl> Operators;
174175
OperatorMap<PrecedenceGroupDecl> PrecedenceGroups;
175176

176-
template<typename Range>
177-
void addToUnqualifiedLookupCache(Range decls, bool onlyOperators);
177+
template <typename Range>
178+
void addToUnqualifiedLookupCache(Range decls, bool onlyOperators,
179+
bool onlyDerivatives);
178180
template<typename Range>
179181
void addToMemberCache(Range decls);
180182

@@ -205,6 +207,9 @@ class swift::SourceLookupCache {
205207
/// guaranteed to be meaningful.
206208
void getPrecedenceGroups(SmallVectorImpl<PrecedenceGroupDecl *> &results);
207209

210+
// TODO: is it valid to return const reference from here?
211+
llvm::SmallVector<AbstractFunctionDecl *, 0> getCustomDerivativeDecls();
212+
208213
/// Look up an operator declaration.
209214
///
210215
/// \param name The operator name ("+", ">>", etc.)
@@ -249,9 +254,10 @@ static Decl *getAsDecl(Decl *decl) { return decl; }
249254
static Expr *getAsExpr(ASTNode node) { return node.dyn_cast<Expr *>(); }
250255
static Decl *getAsDecl(ASTNode node) { return node.dyn_cast<Decl *>(); }
251256

252-
template<typename Range>
257+
template <typename Range>
253258
void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
254-
bool onlyOperators) {
259+
bool onlyOperators,
260+
bool onlyDerivatives) {
255261
for (auto item : items) {
256262
// In script mode, we'll see macro expansion expressions for freestanding
257263
// macros.
@@ -268,19 +274,36 @@ void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
268274
continue;
269275

270276
if (auto *VD = dyn_cast<ValueDecl>(D)) {
271-
if (onlyOperators ? VD->isOperator() : VD->hasName()) {
272-
// Cache the value under both its compound name and its full name.
277+
auto getDerivative = [onlyDerivatives, VD]() -> AbstractFunctionDecl * {
278+
if (auto *AFD = dyn_cast<AbstractFunctionDecl>(VD))
279+
if (AFD->getAttrs().hasAttribute<DerivativeAttr>())
280+
return AFD;
281+
return nullptr;
282+
};
283+
if (onlyOperators && VD->isOperator())
273284
TopLevelValues.add(VD);
274-
275-
if (!onlyOperators && VD->getAttrs().hasAttribute<CustomAttr>()) {
285+
if (onlyDerivatives)
286+
if (AbstractFunctionDecl *AFD = getDerivative())
287+
CustomDerivatives.push_back(AFD);
288+
if (!onlyOperators && !onlyDerivatives && VD->hasName()) {
289+
TopLevelValues.add(VD);
290+
if (VD->getAttrs().hasAttribute<CustomAttr>())
276291
MayHaveAuxiliaryDecls.push_back(VD);
277-
}
292+
if (AbstractFunctionDecl *AFD = getDerivative())
293+
CustomDerivatives.push_back(AFD);
278294
}
279295
}
280296

281-
if (auto *NTD = dyn_cast<NominalTypeDecl>(D))
282-
if (!NTD->hasUnparsedMembers() || NTD->maybeHasOperatorDeclarations())
283-
addToUnqualifiedLookupCache(NTD->getMembers(), true);
297+
if (auto *NTD = dyn_cast<NominalTypeDecl>(D)) {
298+
bool onlyOperatorsArg =
299+
(!NTD->hasUnparsedMembers() || NTD->maybeHasOperatorDeclarations());
300+
bool onlyDerivativesArg =
301+
(!NTD->hasUnparsedMembers() || NTD->maybeHasDerivativeDeclarations());
302+
if (onlyOperatorsArg || onlyDerivativesArg) {
303+
addToUnqualifiedLookupCache(NTD->getMembers(), onlyOperatorsArg,
304+
onlyDerivativesArg);
305+
}
306+
}
284307

285308
if (auto *ED = dyn_cast<ExtensionDecl>(D)) {
286309
// Avoid populating the cache with the members of invalid extension
@@ -292,8 +315,14 @@ void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
292315
MayHaveAuxiliaryDecls.push_back(ED);
293316
}
294317

295-
if (!ED->hasUnparsedMembers() || ED->maybeHasOperatorDeclarations())
296-
addToUnqualifiedLookupCache(ED->getMembers(), true);
318+
bool onlyOperatorsArg =
319+
(!ED->hasUnparsedMembers() || ED->maybeHasOperatorDeclarations());
320+
bool onlyDerivativesArg =
321+
(!ED->hasUnparsedMembers() || ED->maybeHasDerivativeDeclarations());
322+
if (onlyOperatorsArg || onlyDerivativesArg) {
323+
addToUnqualifiedLookupCache(ED->getMembers(), onlyOperatorsArg,
324+
onlyDerivativesArg);
325+
}
297326
}
298327

299328
if (auto *OD = dyn_cast<OperatorDecl>(D))
@@ -307,7 +336,8 @@ void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
307336
MayHaveAuxiliaryDecls.push_back(MED);
308337
} else if (auto TLCD = dyn_cast<TopLevelCodeDecl>(D)) {
309338
if (auto body = TLCD->getBody()){
310-
addToUnqualifiedLookupCache(body->getElements(), onlyOperators);
339+
addToUnqualifiedLookupCache(body->getElements(), onlyOperators,
340+
onlyDerivatives);
311341
}
312342
}
313343
}
@@ -488,8 +518,8 @@ SourceLookupCache::SourceLookupCache(const SourceFile &SF)
488518
{
489519
FrontendStatsTracer tracer(SF.getASTContext().Stats,
490520
"source-file-populate-cache");
491-
addToUnqualifiedLookupCache(SF.getTopLevelItems(), false);
492-
addToUnqualifiedLookupCache(SF.getHoistedDecls(), false);
521+
addToUnqualifiedLookupCache(SF.getTopLevelItems(), false, false);
522+
addToUnqualifiedLookupCache(SF.getHoistedDecls(), false, false);
493523
}
494524

495525
SourceLookupCache::SourceLookupCache(const ModuleDecl &M)
@@ -499,11 +529,11 @@ SourceLookupCache::SourceLookupCache(const ModuleDecl &M)
499529
"module-populate-cache");
500530
for (const FileUnit *file : M.getFiles()) {
501531
auto *SF = cast<SourceFile>(file);
502-
addToUnqualifiedLookupCache(SF->getTopLevelItems(), false);
503-
addToUnqualifiedLookupCache(SF->getHoistedDecls(), false);
532+
addToUnqualifiedLookupCache(SF->getTopLevelItems(), false, false);
533+
addToUnqualifiedLookupCache(SF->getHoistedDecls(), false, false);
504534

505535
if (auto *SFU = file->getSynthesizedFile()) {
506-
addToUnqualifiedLookupCache(SFU->getTopLevelDecls(), false);
536+
addToUnqualifiedLookupCache(SFU->getTopLevelDecls(), false, false);
507537
}
508538
}
509539
}
@@ -572,6 +602,11 @@ void SourceLookupCache::getOperatorDecls(
572602
results.append(ops.second.begin(), ops.second.end());
573603
}
574604

605+
llvm::SmallVector<AbstractFunctionDecl *, 0>
606+
SourceLookupCache::getCustomDerivativeDecls() {
607+
return CustomDerivatives;
608+
}
609+
575610
void SourceLookupCache::lookupOperator(Identifier name, OperatorFixity fixity,
576611
TinyPtrVector<OperatorDecl *> &results) {
577612
auto ops = Operators.find(name);
@@ -4008,6 +4043,23 @@ bool IsNonUserModuleRequest::evaluate(Evaluator &evaluator, ModuleDecl *mod) con
40084043
(!sdkOrPlatform.empty() && pathStartsWith(sdkOrPlatform, modulePath));
40094044
}
40104045

4046+
evaluator::SideEffect CustomDerivativesRequest::evaluate(Evaluator &evaluator,
4047+
SourceFile *sf) const {
4048+
ModuleDecl *module = sf->getParentModule();
4049+
assert(isParsedModule(module));
4050+
llvm::SmallVector<AbstractFunctionDecl *, 0> decls =
4051+
module->getSourceLookupCache().getCustomDerivativeDecls();
4052+
for (const AbstractFunctionDecl *afd : decls) {
4053+
for (const auto *derAttr :
4054+
afd->getAttrs().getAttributes<DerivativeAttr>()) {
4055+
// Resolve derivative function configurations from `@derivative`
4056+
// attributes by type-checking them.
4057+
(void)derAttr->getOriginalFunction(sf->getASTContext());
4058+
}
4059+
}
4060+
return {};
4061+
}
4062+
40114063
version::Version ModuleDecl::getLanguageVersionBuiltWith() const {
40124064
for (auto *F : getFiles()) {
40134065
auto *LD = dyn_cast<LoadedFile>(F);

lib/Parse/ParseDecl.cpp

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5644,19 +5644,20 @@ static void diagnoseOperatorFixityAttributes(Parser &P,
56445644
}
56455645
}
56465646

5647-
static unsigned skipUntilMatchingRBrace(Parser &P,
5648-
bool &HasPoundDirective,
5647+
static unsigned skipUntilMatchingRBrace(Parser &P, bool &HasPoundDirective,
56495648
bool &HasPoundSourceLocation,
56505649
bool &HasOperatorDeclarations,
56515650
bool &HasNestedClassDeclarations,
56525651
bool &HasNestedTypeDeclarations,
5653-
bool &HasPotentialRegexLiteral) {
5652+
bool &HasPotentialRegexLiteral,
5653+
bool &HasDerivativeDeclarations) {
56545654
HasPoundDirective = false;
56555655
HasPoundSourceLocation = false;
56565656
HasOperatorDeclarations = false;
56575657
HasNestedClassDeclarations = false;
56585658
HasNestedTypeDeclarations = false;
56595659
HasPotentialRegexLiteral = false;
5660+
HasDerivativeDeclarations = false;
56605661

56615662
unsigned OpenBraces = 1;
56625663
unsigned OpenPoundIf = 0;
@@ -5685,6 +5686,18 @@ static unsigned skipUntilMatchingRBrace(Parser &P,
56855686
tok::kw_protocol)
56865687
|| P.Tok.isContextualKeyword("actor");
56875688

5689+
if (P.consumeIf(tok::at_sign)) {
5690+
if (P.Tok.is(tok::identifier)) {
5691+
std::optional<DeclAttrKind> DK =
5692+
DeclAttribute::getAttrKindFromString(P.Tok.getText());
5693+
if (DK && *DK == DeclAttrKind::Derivative) {
5694+
HasDerivativeDeclarations = true;
5695+
P.consumeToken();
5696+
}
5697+
}
5698+
continue;
5699+
}
5700+
56885701
// HACK: Bail if we encounter what could potentially be a regex literal.
56895702
// This is necessary as:
56905703
// - We might encounter an invalid Swift token that might be valid in a
@@ -7033,13 +7046,17 @@ bool Parser::parseMemberDeclList(SourceLoc &LBLoc, SourceLoc &RBLoc,
70337046

70347047
bool HasOperatorDeclarations;
70357048
bool HasNestedClassDeclarations;
7049+
bool HasDerivativeDeclarations;
70367050

70377051
if (canDelayMemberDeclParsing(HasOperatorDeclarations,
7038-
HasNestedClassDeclarations)) {
7052+
HasNestedClassDeclarations,
7053+
HasDerivativeDeclarations)) {
70397054
if (HasOperatorDeclarations)
70407055
IDC->setMaybeHasOperatorDeclarations();
70417056
if (HasNestedClassDeclarations)
70427057
IDC->setMaybeHasNestedClassDeclarations();
7058+
if (HasDerivativeDeclarations)
7059+
IDC->setMaybeHasDerivativeDeclarations();
70437060

70447061
if (delayParsingDeclList(LBLoc, RBLoc, IDC))
70457062
return true;
@@ -7050,6 +7067,7 @@ bool Parser::parseMemberDeclList(SourceLoc &LBLoc, SourceLoc &RBLoc,
70507067
auto membersAndHash = parseDeclList(LBLoc, RBLoc, RBraceDiag, hadError);
70517068
IDC->setMaybeHasOperatorDeclarations();
70527069
IDC->setMaybeHasNestedClassDeclarations();
7070+
IDC->setMaybeHasDerivativeDeclarations();
70537071
Context.evaluator.cacheOutput(
70547072
ParseMembersRequest{IDC},
70557073
FingerprintAndMembers{
@@ -7110,7 +7128,8 @@ Parser::parseDeclList(SourceLoc LBLoc, SourceLoc &RBLoc, Diag<> ErrorDiag,
71107128
}
71117129

71127130
bool Parser::canDelayMemberDeclParsing(bool &HasOperatorDeclarations,
7113-
bool &HasNestedClassDeclarations) {
7131+
bool &HasNestedClassDeclarations,
7132+
bool &HasDerivativeDeclarations) {
71147133
// If explicitly disabled, respect the flag.
71157134
if (!isDelayedParsingEnabled())
71167135
return false;
@@ -7122,13 +7141,10 @@ bool Parser::canDelayMemberDeclParsing(bool &HasOperatorDeclarations,
71227141
bool HasPoundSourceLocation;
71237142
bool HasNestedTypeDeclarations;
71247143
bool HasPotentialRegexLiteral;
7125-
skipUntilMatchingRBrace(*this,
7126-
HasPoundDirective,
7127-
HasPoundSourceLocation,
7128-
HasOperatorDeclarations,
7129-
HasNestedClassDeclarations,
7130-
HasNestedTypeDeclarations,
7131-
HasPotentialRegexLiteral);
7144+
skipUntilMatchingRBrace(*this, HasPoundDirective, HasPoundSourceLocation,
7145+
HasOperatorDeclarations, HasNestedClassDeclarations,
7146+
HasNestedTypeDeclarations, HasPotentialRegexLiteral,
7147+
HasDerivativeDeclarations);
71327148
if (!HasPoundDirective && !HasPotentialRegexLiteral) {
71337149
// If we didn't see any pound directive, we must not have seen
71347150
// #sourceLocation either.
@@ -7740,9 +7756,11 @@ bool Parser::canDelayFunctionBodyParsing(bool &HasNestedTypeDeclarations) {
77407756
bool HasOperatorDeclarations;
77417757
bool HasNestedClassDeclarations;
77427758
bool HasPotentialRegexLiteral;
7759+
bool HasDerivativeDeclarations;
77437760
skipUntilMatchingRBrace(*this, HasPoundDirectives, HasPoundSourceLocation,
77447761
HasOperatorDeclarations, HasNestedClassDeclarations,
7745-
HasNestedTypeDeclarations, HasPotentialRegexLiteral);
7762+
HasNestedTypeDeclarations, HasPotentialRegexLiteral,
7763+
HasDerivativeDeclarations);
77467764
if (HasPoundSourceLocation || HasPotentialRegexLiteral)
77477765
return false;
77487766

0 commit comments

Comments
 (0)