Skip to content

Commit 0d7e37e

Browse files
committed
[AutoDiff] Enhance performance of custom derivatives lookup
In #58965, lookup for custom derivatives in non-primary source files was introduced. It required triggering delayed members parsing of nominal types in a file if the file was compiled with differential programming enabled. This patch introduces `CustomDerivativesRequest` to address the issue. We only parse delayed members if tokens `@` and `derivative` appear together inside skipped nominal type body (similar to how member operators are handled). Resolves #60102
1 parent 1941996 commit 0d7e37e

File tree

10 files changed

+160
-57
lines changed

10 files changed

+160
-57
lines changed

include/swift/AST/DeclContext.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,10 @@ class IterableDeclContext {
802802
/// We must restore this when delayed parsing the body.
803803
unsigned InFreestandingMacroArgument : 1;
804804

805+
/// Whether delayed parsing detect a possible custom derivative definition
806+
/// while skipping the body of this context.
807+
unsigned HasDerivativeDeclarations : 1;
808+
805809
template<class A, class B, class C>
806810
friend struct ::llvm::CastInfo;
807811

@@ -817,6 +821,7 @@ class IterableDeclContext {
817821
: LastDeclAndKind(nullptr, kind) {
818822
AddedParsedMembers = 0;
819823
HasOperatorDeclarations = 0;
824+
HasDerivativeDeclarations = 0;
820825
HasNestedClassDeclarations = 0;
821826
InFreestandingMacroArgument = 0;
822827
}
@@ -855,6 +860,15 @@ class IterableDeclContext {
855860
InFreestandingMacroArgument = 1;
856861
}
857862

863+
bool maybeHasDerivativeDeclarations() const {
864+
return HasDerivativeDeclarations;
865+
}
866+
867+
void setMaybeHasDerivativeDeclarations() {
868+
assert(hasUnparsedMembers());
869+
HasDerivativeDeclarations = 1;
870+
}
871+
858872
/// Retrieve the current set of members in this context.
859873
///
860874
/// NOTE: This operation is an alias of \c getCurrentMembers() that is considered

include/swift/AST/Module.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ class ModuleDecl
240240
: public DeclContext, public TypeDecl, public ASTAllocated<ModuleDecl> {
241241
friend class DirectOperatorLookupRequest;
242242
friend class DirectPrecedenceGroupLookupRequest;
243+
friend class CustomDerivativesRequest;
243244

244245
/// The ABI name of the module, if it differs from the module name.
245246
mutable Identifier ModuleABIName;

include/swift/AST/TypeCheckRequests.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5165,6 +5165,22 @@ class GenericTypeParamDeclGetValueTypeRequest
51655165
bool isCached() const { return true; }
51665166
};
51675167

5168+
class CustomDerivativesRequest
5169+
: public SimpleRequest<CustomDerivativesRequest,
5170+
evaluator::SideEffect(SourceFile *),
5171+
RequestFlags::Cached> {
5172+
public:
5173+
using SimpleRequest::SimpleRequest;
5174+
5175+
private:
5176+
friend SimpleRequest;
5177+
5178+
evaluator::SideEffect evaluate(Evaluator &evaluator, SourceFile *sf) const;
5179+
5180+
public:
5181+
bool isCached() const { return true; }
5182+
};
5183+
51685184
#define SWIFT_TYPEID_ZONE TypeChecker
51695185
#define SWIFT_TYPEID_HEADER "swift/AST/TypeCheckerTypeIDZone.def"
51705186
#include "swift/Basic/DefineTypeIDZone.h"

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,9 @@ SWIFT_REQUEST(TypeChecker, ParamCaptureInfoRequest,
605605
SWIFT_REQUEST(TypeChecker, IsUnsafeRequest,
606606
bool(Decl *),
607607
SeparatelyCached, NoLocationInfo)
608+
SWIFT_REQUEST(TypeChecker, CustomDerivativesRequest,
609+
CustomDerivativesResult(SourceFile *),
610+
Cached, NoLocationInfo)
608611

609612
SWIFT_REQUEST(TypeChecker, GenericTypeParamDeclGetValueTypeRequest,
610613
Type(GenericTypeParamDecl *), Cached, NoLocationInfo)

include/swift/Parse/Parser.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,8 @@ class Parser {
974974
IterableDeclContext *IDC);
975975

976976
bool canDelayMemberDeclParsing(bool &HasOperatorDeclarations,
977-
bool &HasNestedClassDeclarations);
977+
bool &HasNestedClassDeclarations,
978+
bool &HasDerivativeDeclarations);
978979

979980
bool canDelayFunctionBodyParsing(bool &HasNestedTypeDeclarations);
980981

lib/AST/Module.cpp

Lines changed: 73 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -166,15 +166,17 @@ class swift::SourceLookupCache {
166166
ValueDeclMap TopLevelValues;
167167
ValueDeclMap ClassMembers;
168168
bool MemberCachePopulated = false;
169+
llvm::SmallVector<AbstractFunctionDecl *, 0> CustomDerivatives;
169170
DeclName UniqueMacroNamePlaceholder;
170171

171172
template<typename T>
172173
using OperatorMap = llvm::DenseMap<Identifier, TinyPtrVector<T *>>;
173174
OperatorMap<OperatorDecl> Operators;
174175
OperatorMap<PrecedenceGroupDecl> PrecedenceGroups;
175176

176-
template<typename Range>
177-
void addToUnqualifiedLookupCache(Range decls, bool onlyOperators);
177+
template <typename Range>
178+
void addToUnqualifiedLookupCache(Range decls, bool onlyOperators,
179+
bool onlyDerivatives);
178180
template<typename Range>
179181
void addToMemberCache(Range decls);
180182

@@ -205,6 +207,10 @@ class swift::SourceLookupCache {
205207
/// guaranteed to be meaningful.
206208
void getPrecedenceGroups(SmallVectorImpl<PrecedenceGroupDecl *> &results);
207209

210+
/// Retrieves all the function decls marked as @derivative. The order of the
211+
/// results is not guaranteed to be meaningful.
212+
llvm::SmallVector<AbstractFunctionDecl *, 0> getCustomDerivativeDecls();
213+
208214
/// Look up an operator declaration.
209215
///
210216
/// \param name The operator name ("+", ">>", etc.)
@@ -249,9 +255,10 @@ static Decl *getAsDecl(Decl *decl) { return decl; }
249255
static Expr *getAsExpr(ASTNode node) { return node.dyn_cast<Expr *>(); }
250256
static Decl *getAsDecl(ASTNode node) { return node.dyn_cast<Decl *>(); }
251257

252-
template<typename Range>
258+
template <typename Range>
253259
void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
254-
bool onlyOperators) {
260+
bool onlyOperators,
261+
bool onlyDerivatives) {
255262
for (auto item : items) {
256263
// In script mode, we'll see macro expansion expressions for freestanding
257264
// macros.
@@ -268,19 +275,36 @@ void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
268275
continue;
269276

270277
if (auto *VD = dyn_cast<ValueDecl>(D)) {
271-
if (onlyOperators ? VD->isOperator() : VD->hasName()) {
272-
// Cache the value under both its compound name and its full name.
278+
auto getDerivative = [onlyDerivatives, VD]() -> AbstractFunctionDecl * {
279+
if (auto *AFD = dyn_cast<AbstractFunctionDecl>(VD))
280+
if (AFD->getAttrs().hasAttribute<DerivativeAttr>())
281+
return AFD;
282+
return nullptr;
283+
};
284+
if (onlyOperators && VD->isOperator())
273285
TopLevelValues.add(VD);
274-
275-
if (!onlyOperators && VD->getAttrs().hasAttribute<CustomAttr>()) {
286+
if (onlyDerivatives)
287+
if (AbstractFunctionDecl *AFD = getDerivative())
288+
CustomDerivatives.push_back(AFD);
289+
if (!onlyOperators && !onlyDerivatives && VD->hasName()) {
290+
TopLevelValues.add(VD);
291+
if (VD->getAttrs().hasAttribute<CustomAttr>())
276292
MayHaveAuxiliaryDecls.push_back(VD);
277-
}
293+
if (AbstractFunctionDecl *AFD = getDerivative())
294+
CustomDerivatives.push_back(AFD);
278295
}
279296
}
280297

281-
if (auto *NTD = dyn_cast<NominalTypeDecl>(D))
282-
if (!NTD->hasUnparsedMembers() || NTD->maybeHasOperatorDeclarations())
283-
addToUnqualifiedLookupCache(NTD->getMembers(), true);
298+
if (auto *NTD = dyn_cast<NominalTypeDecl>(D)) {
299+
bool onlyOperatorsArg =
300+
(!NTD->hasUnparsedMembers() || NTD->maybeHasOperatorDeclarations());
301+
bool onlyDerivativesArg =
302+
(!NTD->hasUnparsedMembers() || NTD->maybeHasDerivativeDeclarations());
303+
if (onlyOperatorsArg || onlyDerivativesArg) {
304+
addToUnqualifiedLookupCache(NTD->getMembers(), onlyOperatorsArg,
305+
onlyDerivativesArg);
306+
}
307+
}
284308

285309
if (auto *ED = dyn_cast<ExtensionDecl>(D)) {
286310
// Avoid populating the cache with the members of invalid extension
@@ -292,8 +316,14 @@ void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
292316
MayHaveAuxiliaryDecls.push_back(ED);
293317
}
294318

295-
if (!ED->hasUnparsedMembers() || ED->maybeHasOperatorDeclarations())
296-
addToUnqualifiedLookupCache(ED->getMembers(), true);
319+
bool onlyOperatorsArg =
320+
(!ED->hasUnparsedMembers() || ED->maybeHasOperatorDeclarations());
321+
bool onlyDerivativesArg =
322+
(!ED->hasUnparsedMembers() || ED->maybeHasDerivativeDeclarations());
323+
if (onlyOperatorsArg || onlyDerivativesArg) {
324+
addToUnqualifiedLookupCache(ED->getMembers(), onlyOperatorsArg,
325+
onlyDerivativesArg);
326+
}
297327
}
298328

299329
if (auto *OD = dyn_cast<OperatorDecl>(D))
@@ -307,7 +337,8 @@ void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
307337
MayHaveAuxiliaryDecls.push_back(MED);
308338
} else if (auto TLCD = dyn_cast<TopLevelCodeDecl>(D)) {
309339
if (auto body = TLCD->getBody()){
310-
addToUnqualifiedLookupCache(body->getElements(), onlyOperators);
340+
addToUnqualifiedLookupCache(body->getElements(), onlyOperators,
341+
onlyDerivatives);
311342
}
312343
}
313344
}
@@ -488,8 +519,8 @@ SourceLookupCache::SourceLookupCache(const SourceFile &SF)
488519
{
489520
FrontendStatsTracer tracer(SF.getASTContext().Stats,
490521
"source-file-populate-cache");
491-
addToUnqualifiedLookupCache(SF.getTopLevelItems(), false);
492-
addToUnqualifiedLookupCache(SF.getHoistedDecls(), false);
522+
addToUnqualifiedLookupCache(SF.getTopLevelItems(), false, false);
523+
addToUnqualifiedLookupCache(SF.getHoistedDecls(), false, false);
493524
}
494525

495526
SourceLookupCache::SourceLookupCache(const ModuleDecl &M)
@@ -499,11 +530,11 @@ SourceLookupCache::SourceLookupCache(const ModuleDecl &M)
499530
"module-populate-cache");
500531
for (const FileUnit *file : M.getFiles()) {
501532
auto *SF = cast<SourceFile>(file);
502-
addToUnqualifiedLookupCache(SF->getTopLevelItems(), false);
503-
addToUnqualifiedLookupCache(SF->getHoistedDecls(), false);
533+
addToUnqualifiedLookupCache(SF->getTopLevelItems(), false, false);
534+
addToUnqualifiedLookupCache(SF->getHoistedDecls(), false, false);
504535

505536
if (auto *SFU = file->getSynthesizedFile()) {
506-
addToUnqualifiedLookupCache(SFU->getTopLevelDecls(), false);
537+
addToUnqualifiedLookupCache(SFU->getTopLevelDecls(), false, false);
507538
}
508539
}
509540
}
@@ -572,6 +603,11 @@ void SourceLookupCache::getOperatorDecls(
572603
results.append(ops.second.begin(), ops.second.end());
573604
}
574605

606+
llvm::SmallVector<AbstractFunctionDecl *, 0>
607+
SourceLookupCache::getCustomDerivativeDecls() {
608+
return CustomDerivatives;
609+
}
610+
575611
void SourceLookupCache::lookupOperator(Identifier name, OperatorFixity fixity,
576612
TinyPtrVector<OperatorDecl *> &results) {
577613
auto ops = Operators.find(name);
@@ -4027,6 +4063,23 @@ bool IsNonUserModuleRequest::evaluate(Evaluator &evaluator, ModuleDecl *mod) con
40274063
return false;
40284064
}
40294065

4066+
evaluator::SideEffect CustomDerivativesRequest::evaluate(Evaluator &evaluator,
4067+
SourceFile *sf) const {
4068+
ModuleDecl *module = sf->getParentModule();
4069+
assert(isParsedModule(module));
4070+
llvm::SmallVector<AbstractFunctionDecl *, 0> decls =
4071+
module->getSourceLookupCache().getCustomDerivativeDecls();
4072+
for (const AbstractFunctionDecl *afd : decls) {
4073+
for (const auto *derAttr :
4074+
afd->getAttrs().getAttributes<DerivativeAttr>()) {
4075+
// Resolve derivative function configurations from `@derivative`
4076+
// attributes by type-checking them.
4077+
(void)derAttr->getOriginalFunction(sf->getASTContext());
4078+
}
4079+
}
4080+
return {};
4081+
}
4082+
40304083
version::Version ModuleDecl::getLanguageVersionBuiltWith() const {
40314084
for (auto *F : getFiles()) {
40324085
auto *LD = dyn_cast<LoadedFile>(F);

lib/Parse/ParseDecl.cpp

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5644,19 +5644,20 @@ static void diagnoseOperatorFixityAttributes(Parser &P,
56445644
}
56455645
}
56465646

5647-
static unsigned skipUntilMatchingRBrace(Parser &P,
5648-
bool &HasPoundDirective,
5647+
static unsigned skipUntilMatchingRBrace(Parser &P, bool &HasPoundDirective,
56495648
bool &HasPoundSourceLocation,
56505649
bool &HasOperatorDeclarations,
56515650
bool &HasNestedClassDeclarations,
56525651
bool &HasNestedTypeDeclarations,
5653-
bool &HasPotentialRegexLiteral) {
5652+
bool &HasPotentialRegexLiteral,
5653+
bool &HasDerivativeDeclarations) {
56545654
HasPoundDirective = false;
56555655
HasPoundSourceLocation = false;
56565656
HasOperatorDeclarations = false;
56575657
HasNestedClassDeclarations = false;
56585658
HasNestedTypeDeclarations = false;
56595659
HasPotentialRegexLiteral = false;
5660+
HasDerivativeDeclarations = false;
56605661

56615662
unsigned OpenBraces = 1;
56625663
unsigned OpenPoundIf = 0;
@@ -5685,6 +5686,18 @@ static unsigned skipUntilMatchingRBrace(Parser &P,
56855686
tok::kw_protocol)
56865687
|| P.Tok.isContextualKeyword("actor");
56875688

5689+
if (P.consumeIf(tok::at_sign)) {
5690+
if (P.Tok.is(tok::identifier)) {
5691+
std::optional<DeclAttrKind> DK =
5692+
DeclAttribute::getAttrKindFromString(P.Tok.getText());
5693+
if (DK && *DK == DeclAttrKind::Derivative) {
5694+
HasDerivativeDeclarations = true;
5695+
P.consumeToken();
5696+
}
5697+
}
5698+
continue;
5699+
}
5700+
56885701
// HACK: Bail if we encounter what could potentially be a regex literal.
56895702
// This is necessary as:
56905703
// - We might encounter an invalid Swift token that might be valid in a
@@ -7033,13 +7046,17 @@ bool Parser::parseMemberDeclList(SourceLoc &LBLoc, SourceLoc &RBLoc,
70337046

70347047
bool HasOperatorDeclarations = false;
70357048
bool HasNestedClassDeclarations = false;
7049+
bool HasDerivativeDeclarations = false;
70367050

70377051
if (canDelayMemberDeclParsing(HasOperatorDeclarations,
7038-
HasNestedClassDeclarations)) {
7052+
HasNestedClassDeclarations,
7053+
HasDerivativeDeclarations)) {
70397054
if (HasOperatorDeclarations)
70407055
IDC->setMaybeHasOperatorDeclarations();
70417056
if (HasNestedClassDeclarations)
70427057
IDC->setMaybeHasNestedClassDeclarations();
7058+
if (HasDerivativeDeclarations)
7059+
IDC->setMaybeHasDerivativeDeclarations();
70437060
if (InFreestandingMacroArgument)
70447061
IDC->setInFreestandingMacroArgument();
70457062

@@ -7052,6 +7069,7 @@ bool Parser::parseMemberDeclList(SourceLoc &LBLoc, SourceLoc &RBLoc,
70527069
auto membersAndHash = parseDeclList(LBLoc, RBLoc, RBraceDiag, hadError);
70537070
IDC->setMaybeHasOperatorDeclarations();
70547071
IDC->setMaybeHasNestedClassDeclarations();
7072+
IDC->setMaybeHasDerivativeDeclarations();
70557073
Context.evaluator.cacheOutput(
70567074
ParseMembersRequest{IDC},
70577075
FingerprintAndMembers{
@@ -7112,7 +7130,8 @@ Parser::parseDeclList(SourceLoc LBLoc, SourceLoc &RBLoc, Diag<> ErrorDiag,
71127130
}
71137131

71147132
bool Parser::canDelayMemberDeclParsing(bool &HasOperatorDeclarations,
7115-
bool &HasNestedClassDeclarations) {
7133+
bool &HasNestedClassDeclarations,
7134+
bool &HasDerivativeDeclarations) {
71167135
// If explicitly disabled, respect the flag.
71177136
if (!isDelayedParsingEnabled())
71187137
return false;
@@ -7124,13 +7143,10 @@ bool Parser::canDelayMemberDeclParsing(bool &HasOperatorDeclarations,
71247143
bool HasPoundSourceLocation;
71257144
bool HasNestedTypeDeclarations;
71267145
bool HasPotentialRegexLiteral;
7127-
skipUntilMatchingRBrace(*this,
7128-
HasPoundDirective,
7129-
HasPoundSourceLocation,
7130-
HasOperatorDeclarations,
7131-
HasNestedClassDeclarations,
7132-
HasNestedTypeDeclarations,
7133-
HasPotentialRegexLiteral);
7146+
skipUntilMatchingRBrace(*this, HasPoundDirective, HasPoundSourceLocation,
7147+
HasOperatorDeclarations, HasNestedClassDeclarations,
7148+
HasNestedTypeDeclarations, HasPotentialRegexLiteral,
7149+
HasDerivativeDeclarations);
71347150
if (!HasPoundDirective && !HasPotentialRegexLiteral) {
71357151
// If we didn't see any pound directive, we must not have seen
71367152
// #sourceLocation either.
@@ -7742,9 +7758,11 @@ bool Parser::canDelayFunctionBodyParsing(bool &HasNestedTypeDeclarations) {
77427758
bool HasOperatorDeclarations;
77437759
bool HasNestedClassDeclarations;
77447760
bool HasPotentialRegexLiteral;
7761+
bool HasDerivativeDeclarations;
77457762
skipUntilMatchingRBrace(*this, HasPoundDirectives, HasPoundSourceLocation,
77467763
HasOperatorDeclarations, HasNestedClassDeclarations,
7747-
HasNestedTypeDeclarations, HasPotentialRegexLiteral);
7764+
HasNestedTypeDeclarations, HasPotentialRegexLiteral,
7765+
HasDerivativeDeclarations);
77487766
if (HasPoundSourceLocation || HasPotentialRegexLiteral)
77497767
return false;
77507768

0 commit comments

Comments
 (0)