Skip to content

Commit cf68d28

Browse files
authored
Merge pull request #76951 from kovdan01/issue60102
[AutoDiff] Enhance performance of custom derivatives lookup
2 parents 0678829 + 0d7e37e commit cf68d28

File tree

10 files changed

+160
-57
lines changed

10 files changed

+160
-57
lines changed

include/swift/AST/DeclContext.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,10 @@ class IterableDeclContext {
802802
/// We must restore this when delayed parsing the body.
803803
unsigned InFreestandingMacroArgument : 1;
804804

805+
/// Whether delayed parsing detect a possible custom derivative definition
806+
/// while skipping the body of this context.
807+
unsigned HasDerivativeDeclarations : 1;
808+
805809
template<class A, class B, class C>
806810
friend struct ::llvm::CastInfo;
807811

@@ -817,6 +821,7 @@ class IterableDeclContext {
817821
: LastDeclAndKind(nullptr, kind) {
818822
AddedParsedMembers = 0;
819823
HasOperatorDeclarations = 0;
824+
HasDerivativeDeclarations = 0;
820825
HasNestedClassDeclarations = 0;
821826
InFreestandingMacroArgument = 0;
822827
}
@@ -855,6 +860,15 @@ class IterableDeclContext {
855860
InFreestandingMacroArgument = 1;
856861
}
857862

863+
bool maybeHasDerivativeDeclarations() const {
864+
return HasDerivativeDeclarations;
865+
}
866+
867+
void setMaybeHasDerivativeDeclarations() {
868+
assert(hasUnparsedMembers());
869+
HasDerivativeDeclarations = 1;
870+
}
871+
858872
/// Retrieve the current set of members in this context.
859873
///
860874
/// NOTE: This operation is an alias of \c getCurrentMembers() that is considered

include/swift/AST/Module.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ class ModuleDecl
240240
: public DeclContext, public TypeDecl, public ASTAllocated<ModuleDecl> {
241241
friend class DirectOperatorLookupRequest;
242242
friend class DirectPrecedenceGroupLookupRequest;
243+
friend class CustomDerivativesRequest;
243244

244245
/// The ABI name of the module, if it differs from the module name.
245246
mutable Identifier ModuleABIName;

include/swift/AST/TypeCheckRequests.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5208,6 +5208,22 @@ class GenericTypeParamDeclGetValueTypeRequest
52085208
bool isCached() const { return true; }
52095209
};
52105210

5211+
class CustomDerivativesRequest
5212+
: public SimpleRequest<CustomDerivativesRequest,
5213+
evaluator::SideEffect(SourceFile *),
5214+
RequestFlags::Cached> {
5215+
public:
5216+
using SimpleRequest::SimpleRequest;
5217+
5218+
private:
5219+
friend SimpleRequest;
5220+
5221+
evaluator::SideEffect evaluate(Evaluator &evaluator, SourceFile *sf) const;
5222+
5223+
public:
5224+
bool isCached() const { return true; }
5225+
};
5226+
52115227
#define SWIFT_TYPEID_ZONE TypeChecker
52125228
#define SWIFT_TYPEID_HEADER "swift/AST/TypeCheckerTypeIDZone.def"
52135229
#include "swift/Basic/DefineTypeIDZone.h"

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,9 @@ SWIFT_REQUEST(TypeChecker, ParamCaptureInfoRequest,
611611
SWIFT_REQUEST(TypeChecker, IsUnsafeRequest,
612612
bool(Decl *),
613613
SeparatelyCached, NoLocationInfo)
614+
SWIFT_REQUEST(TypeChecker, CustomDerivativesRequest,
615+
CustomDerivativesResult(SourceFile *),
616+
Cached, NoLocationInfo)
614617

615618
SWIFT_REQUEST(TypeChecker, GenericTypeParamDeclGetValueTypeRequest,
616619
Type(GenericTypeParamDecl *), Cached, NoLocationInfo)

include/swift/Parse/Parser.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -968,7 +968,8 @@ class Parser {
968968
IterableDeclContext *IDC);
969969

970970
bool canDelayMemberDeclParsing(bool &HasOperatorDeclarations,
971-
bool &HasNestedClassDeclarations);
971+
bool &HasNestedClassDeclarations,
972+
bool &HasDerivativeDeclarations);
972973

973974
bool canDelayFunctionBodyParsing(bool &HasNestedTypeDeclarations);
974975

lib/AST/Module.cpp

Lines changed: 73 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -166,15 +166,17 @@ class swift::SourceLookupCache {
166166
ValueDeclMap TopLevelValues;
167167
ValueDeclMap ClassMembers;
168168
bool MemberCachePopulated = false;
169+
llvm::SmallVector<AbstractFunctionDecl *, 0> CustomDerivatives;
169170
DeclName UniqueMacroNamePlaceholder;
170171

171172
template<typename T>
172173
using OperatorMap = llvm::DenseMap<Identifier, TinyPtrVector<T *>>;
173174
OperatorMap<OperatorDecl> Operators;
174175
OperatorMap<PrecedenceGroupDecl> PrecedenceGroups;
175176

176-
template<typename Range>
177-
void addToUnqualifiedLookupCache(Range decls, bool onlyOperators);
177+
template <typename Range>
178+
void addToUnqualifiedLookupCache(Range decls, bool onlyOperators,
179+
bool onlyDerivatives);
178180
template<typename Range>
179181
void addToMemberCache(Range decls);
180182

@@ -205,6 +207,10 @@ class swift::SourceLookupCache {
205207
/// guaranteed to be meaningful.
206208
void getPrecedenceGroups(SmallVectorImpl<PrecedenceGroupDecl *> &results);
207209

210+
/// Retrieves all the function decls marked as @derivative. The order of the
211+
/// results is not guaranteed to be meaningful.
212+
llvm::SmallVector<AbstractFunctionDecl *, 0> getCustomDerivativeDecls();
213+
208214
/// Look up an operator declaration.
209215
///
210216
/// \param name The operator name ("+", ">>", etc.)
@@ -249,9 +255,10 @@ static Decl *getAsDecl(Decl *decl) { return decl; }
249255
static Expr *getAsExpr(ASTNode node) { return node.dyn_cast<Expr *>(); }
250256
static Decl *getAsDecl(ASTNode node) { return node.dyn_cast<Decl *>(); }
251257

252-
template<typename Range>
258+
template <typename Range>
253259
void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
254-
bool onlyOperators) {
260+
bool onlyOperators,
261+
bool onlyDerivatives) {
255262
for (auto item : items) {
256263
// In script mode, we'll see macro expansion expressions for freestanding
257264
// macros.
@@ -268,19 +275,36 @@ void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
268275
continue;
269276

270277
if (auto *VD = dyn_cast<ValueDecl>(D)) {
271-
if (onlyOperators ? VD->isOperator() : VD->hasName()) {
272-
// Cache the value under both its compound name and its full name.
278+
auto getDerivative = [onlyDerivatives, VD]() -> AbstractFunctionDecl * {
279+
if (auto *AFD = dyn_cast<AbstractFunctionDecl>(VD))
280+
if (AFD->getAttrs().hasAttribute<DerivativeAttr>())
281+
return AFD;
282+
return nullptr;
283+
};
284+
if (onlyOperators && VD->isOperator())
273285
TopLevelValues.add(VD);
274-
275-
if (!onlyOperators && VD->getAttrs().hasAttribute<CustomAttr>()) {
286+
if (onlyDerivatives)
287+
if (AbstractFunctionDecl *AFD = getDerivative())
288+
CustomDerivatives.push_back(AFD);
289+
if (!onlyOperators && !onlyDerivatives && VD->hasName()) {
290+
TopLevelValues.add(VD);
291+
if (VD->getAttrs().hasAttribute<CustomAttr>())
276292
MayHaveAuxiliaryDecls.push_back(VD);
277-
}
293+
if (AbstractFunctionDecl *AFD = getDerivative())
294+
CustomDerivatives.push_back(AFD);
278295
}
279296
}
280297

281-
if (auto *NTD = dyn_cast<NominalTypeDecl>(D))
282-
if (!NTD->hasUnparsedMembers() || NTD->maybeHasOperatorDeclarations())
283-
addToUnqualifiedLookupCache(NTD->getMembers(), true);
298+
if (auto *NTD = dyn_cast<NominalTypeDecl>(D)) {
299+
bool onlyOperatorsArg =
300+
(!NTD->hasUnparsedMembers() || NTD->maybeHasOperatorDeclarations());
301+
bool onlyDerivativesArg =
302+
(!NTD->hasUnparsedMembers() || NTD->maybeHasDerivativeDeclarations());
303+
if (onlyOperatorsArg || onlyDerivativesArg) {
304+
addToUnqualifiedLookupCache(NTD->getMembers(), onlyOperatorsArg,
305+
onlyDerivativesArg);
306+
}
307+
}
284308

285309
if (auto *ED = dyn_cast<ExtensionDecl>(D)) {
286310
// Avoid populating the cache with the members of invalid extension
@@ -292,8 +316,14 @@ void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
292316
MayHaveAuxiliaryDecls.push_back(ED);
293317
}
294318

295-
if (!ED->hasUnparsedMembers() || ED->maybeHasOperatorDeclarations())
296-
addToUnqualifiedLookupCache(ED->getMembers(), true);
319+
bool onlyOperatorsArg =
320+
(!ED->hasUnparsedMembers() || ED->maybeHasOperatorDeclarations());
321+
bool onlyDerivativesArg =
322+
(!ED->hasUnparsedMembers() || ED->maybeHasDerivativeDeclarations());
323+
if (onlyOperatorsArg || onlyDerivativesArg) {
324+
addToUnqualifiedLookupCache(ED->getMembers(), onlyOperatorsArg,
325+
onlyDerivativesArg);
326+
}
297327
}
298328

299329
if (auto *OD = dyn_cast<OperatorDecl>(D))
@@ -307,7 +337,8 @@ void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
307337
MayHaveAuxiliaryDecls.push_back(MED);
308338
} else if (auto TLCD = dyn_cast<TopLevelCodeDecl>(D)) {
309339
if (auto body = TLCD->getBody()){
310-
addToUnqualifiedLookupCache(body->getElements(), onlyOperators);
340+
addToUnqualifiedLookupCache(body->getElements(), onlyOperators,
341+
onlyDerivatives);
311342
}
312343
}
313344
}
@@ -488,8 +519,8 @@ SourceLookupCache::SourceLookupCache(const SourceFile &SF)
488519
{
489520
FrontendStatsTracer tracer(SF.getASTContext().Stats,
490521
"source-file-populate-cache");
491-
addToUnqualifiedLookupCache(SF.getTopLevelItems(), false);
492-
addToUnqualifiedLookupCache(SF.getHoistedDecls(), false);
522+
addToUnqualifiedLookupCache(SF.getTopLevelItems(), false, false);
523+
addToUnqualifiedLookupCache(SF.getHoistedDecls(), false, false);
493524
}
494525

495526
SourceLookupCache::SourceLookupCache(const ModuleDecl &M)
@@ -499,11 +530,11 @@ SourceLookupCache::SourceLookupCache(const ModuleDecl &M)
499530
"module-populate-cache");
500531
for (const FileUnit *file : M.getFiles()) {
501532
auto *SF = cast<SourceFile>(file);
502-
addToUnqualifiedLookupCache(SF->getTopLevelItems(), false);
503-
addToUnqualifiedLookupCache(SF->getHoistedDecls(), false);
533+
addToUnqualifiedLookupCache(SF->getTopLevelItems(), false, false);
534+
addToUnqualifiedLookupCache(SF->getHoistedDecls(), false, false);
504535

505536
if (auto *SFU = file->getSynthesizedFile()) {
506-
addToUnqualifiedLookupCache(SFU->getTopLevelDecls(), false);
537+
addToUnqualifiedLookupCache(SFU->getTopLevelDecls(), false, false);
507538
}
508539
}
509540
}
@@ -572,6 +603,11 @@ void SourceLookupCache::getOperatorDecls(
572603
results.append(ops.second.begin(), ops.second.end());
573604
}
574605

606+
llvm::SmallVector<AbstractFunctionDecl *, 0>
607+
SourceLookupCache::getCustomDerivativeDecls() {
608+
return CustomDerivatives;
609+
}
610+
575611
void SourceLookupCache::lookupOperator(Identifier name, OperatorFixity fixity,
576612
TinyPtrVector<OperatorDecl *> &results) {
577613
auto ops = Operators.find(name);
@@ -4026,6 +4062,23 @@ bool IsNonUserModuleRequest::evaluate(Evaluator &evaluator, ModuleDecl *mod) con
40264062
return false;
40274063
}
40284064

4065+
evaluator::SideEffect CustomDerivativesRequest::evaluate(Evaluator &evaluator,
4066+
SourceFile *sf) const {
4067+
ModuleDecl *module = sf->getParentModule();
4068+
assert(isParsedModule(module));
4069+
llvm::SmallVector<AbstractFunctionDecl *, 0> decls =
4070+
module->getSourceLookupCache().getCustomDerivativeDecls();
4071+
for (const AbstractFunctionDecl *afd : decls) {
4072+
for (const auto *derAttr :
4073+
afd->getAttrs().getAttributes<DerivativeAttr>()) {
4074+
// Resolve derivative function configurations from `@derivative`
4075+
// attributes by type-checking them.
4076+
(void)derAttr->getOriginalFunction(sf->getASTContext());
4077+
}
4078+
}
4079+
return {};
4080+
}
4081+
40294082
version::Version ModuleDecl::getLanguageVersionBuiltWith() const {
40304083
for (auto *F : getFiles()) {
40314084
auto *LD = dyn_cast<LoadedFile>(F);

lib/Parse/ParseDecl.cpp

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5524,19 +5524,20 @@ static void diagnoseOperatorFixityAttributes(Parser &P,
55245524
}
55255525
}
55265526

5527-
static unsigned skipUntilMatchingRBrace(Parser &P,
5528-
bool &HasPoundDirective,
5527+
static unsigned skipUntilMatchingRBrace(Parser &P, bool &HasPoundDirective,
55295528
bool &HasPoundSourceLocation,
55305529
bool &HasOperatorDeclarations,
55315530
bool &HasNestedClassDeclarations,
55325531
bool &HasNestedTypeDeclarations,
5533-
bool &HasPotentialRegexLiteral) {
5532+
bool &HasPotentialRegexLiteral,
5533+
bool &HasDerivativeDeclarations) {
55345534
HasPoundDirective = false;
55355535
HasPoundSourceLocation = false;
55365536
HasOperatorDeclarations = false;
55375537
HasNestedClassDeclarations = false;
55385538
HasNestedTypeDeclarations = false;
55395539
HasPotentialRegexLiteral = false;
5540+
HasDerivativeDeclarations = false;
55405541

55415542
unsigned OpenBraces = 1;
55425543
unsigned OpenPoundIf = 0;
@@ -5565,6 +5566,18 @@ static unsigned skipUntilMatchingRBrace(Parser &P,
55655566
tok::kw_protocol)
55665567
|| P.Tok.isContextualKeyword("actor");
55675568

5569+
if (P.consumeIf(tok::at_sign)) {
5570+
if (P.Tok.is(tok::identifier)) {
5571+
std::optional<DeclAttrKind> DK =
5572+
DeclAttribute::getAttrKindFromString(P.Tok.getText());
5573+
if (DK && *DK == DeclAttrKind::Derivative) {
5574+
HasDerivativeDeclarations = true;
5575+
P.consumeToken();
5576+
}
5577+
}
5578+
continue;
5579+
}
5580+
55685581
// HACK: Bail if we encounter what could potentially be a regex literal.
55695582
// This is necessary as:
55705583
// - We might encounter an invalid Swift token that might be valid in a
@@ -6913,13 +6926,17 @@ bool Parser::parseMemberDeclList(SourceLoc &LBLoc, SourceLoc &RBLoc,
69136926

69146927
bool HasOperatorDeclarations = false;
69156928
bool HasNestedClassDeclarations = false;
6929+
bool HasDerivativeDeclarations = false;
69166930

69176931
if (canDelayMemberDeclParsing(HasOperatorDeclarations,
6918-
HasNestedClassDeclarations)) {
6932+
HasNestedClassDeclarations,
6933+
HasDerivativeDeclarations)) {
69196934
if (HasOperatorDeclarations)
69206935
IDC->setMaybeHasOperatorDeclarations();
69216936
if (HasNestedClassDeclarations)
69226937
IDC->setMaybeHasNestedClassDeclarations();
6938+
if (HasDerivativeDeclarations)
6939+
IDC->setMaybeHasDerivativeDeclarations();
69236940
if (InFreestandingMacroArgument)
69246941
IDC->setInFreestandingMacroArgument();
69256942

@@ -6932,6 +6949,7 @@ bool Parser::parseMemberDeclList(SourceLoc &LBLoc, SourceLoc &RBLoc,
69326949
auto membersAndHash = parseDeclList(LBLoc, RBLoc, RBraceDiag, hadError);
69336950
IDC->setMaybeHasOperatorDeclarations();
69346951
IDC->setMaybeHasNestedClassDeclarations();
6952+
IDC->setMaybeHasDerivativeDeclarations();
69356953
Context.evaluator.cacheOutput(
69366954
ParseMembersRequest{IDC},
69376955
FingerprintAndMembers{
@@ -6992,7 +7010,8 @@ Parser::parseDeclList(SourceLoc LBLoc, SourceLoc &RBLoc, Diag<> ErrorDiag,
69927010
}
69937011

69947012
bool Parser::canDelayMemberDeclParsing(bool &HasOperatorDeclarations,
6995-
bool &HasNestedClassDeclarations) {
7013+
bool &HasNestedClassDeclarations,
7014+
bool &HasDerivativeDeclarations) {
69967015
// If explicitly disabled, respect the flag.
69977016
if (!isDelayedParsingEnabled())
69987017
return false;
@@ -7004,13 +7023,10 @@ bool Parser::canDelayMemberDeclParsing(bool &HasOperatorDeclarations,
70047023
bool HasPoundSourceLocation;
70057024
bool HasNestedTypeDeclarations;
70067025
bool HasPotentialRegexLiteral;
7007-
skipUntilMatchingRBrace(*this,
7008-
HasPoundDirective,
7009-
HasPoundSourceLocation,
7010-
HasOperatorDeclarations,
7011-
HasNestedClassDeclarations,
7012-
HasNestedTypeDeclarations,
7013-
HasPotentialRegexLiteral);
7026+
skipUntilMatchingRBrace(*this, HasPoundDirective, HasPoundSourceLocation,
7027+
HasOperatorDeclarations, HasNestedClassDeclarations,
7028+
HasNestedTypeDeclarations, HasPotentialRegexLiteral,
7029+
HasDerivativeDeclarations);
70147030
if (!HasPoundDirective && !HasPotentialRegexLiteral) {
70157031
// If we didn't see any pound directive, we must not have seen
70167032
// #sourceLocation either.
@@ -7622,9 +7638,11 @@ bool Parser::canDelayFunctionBodyParsing(bool &HasNestedTypeDeclarations) {
76227638
bool HasOperatorDeclarations;
76237639
bool HasNestedClassDeclarations;
76247640
bool HasPotentialRegexLiteral;
7641+
bool HasDerivativeDeclarations;
76257642
skipUntilMatchingRBrace(*this, HasPoundDirectives, HasPoundSourceLocation,
76267643
HasOperatorDeclarations, HasNestedClassDeclarations,
7627-
HasNestedTypeDeclarations, HasPotentialRegexLiteral);
7644+
HasNestedTypeDeclarations, HasPotentialRegexLiteral,
7645+
HasDerivativeDeclarations);
76287646
if (HasPoundSourceLocation || HasPotentialRegexLiteral)
76297647
return false;
76307648

0 commit comments

Comments
 (0)