Skip to content

[AutoDiff] Enhance performance of custom derivatives lookup #76951

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions include/swift/AST/DeclContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,10 @@ class IterableDeclContext {
/// We must restore this when delayed parsing the body.
unsigned InFreestandingMacroArgument : 1;

/// Whether delayed parsing detect a possible custom derivative definition
/// while skipping the body of this context.
unsigned HasDerivativeDeclarations : 1;

template<class A, class B, class C>
friend struct ::llvm::CastInfo;

Expand All @@ -817,6 +821,7 @@ class IterableDeclContext {
: LastDeclAndKind(nullptr, kind) {
AddedParsedMembers = 0;
HasOperatorDeclarations = 0;
HasDerivativeDeclarations = 0;
HasNestedClassDeclarations = 0;
InFreestandingMacroArgument = 0;
}
Expand Down Expand Up @@ -855,6 +860,15 @@ class IterableDeclContext {
InFreestandingMacroArgument = 1;
}

bool maybeHasDerivativeDeclarations() const {
return HasDerivativeDeclarations;
}

void setMaybeHasDerivativeDeclarations() {
assert(hasUnparsedMembers());
HasDerivativeDeclarations = 1;
}

/// Retrieve the current set of members in this context.
///
/// NOTE: This operation is an alias of \c getCurrentMembers() that is considered
Expand Down
1 change: 1 addition & 0 deletions include/swift/AST/Module.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ class ModuleDecl
: public DeclContext, public TypeDecl, public ASTAllocated<ModuleDecl> {
friend class DirectOperatorLookupRequest;
friend class DirectPrecedenceGroupLookupRequest;
friend class CustomDerivativesRequest;

/// The ABI name of the module, if it differs from the module name.
mutable Identifier ModuleABIName;
Expand Down
16 changes: 16 additions & 0 deletions include/swift/AST/TypeCheckRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -5165,6 +5165,22 @@ class GenericTypeParamDeclGetValueTypeRequest
bool isCached() const { return true; }
};

class CustomDerivativesRequest
: public SimpleRequest<CustomDerivativesRequest,
evaluator::SideEffect(SourceFile *),
RequestFlags::Cached> {
public:
using SimpleRequest::SimpleRequest;

private:
friend SimpleRequest;

evaluator::SideEffect evaluate(Evaluator &evaluator, SourceFile *sf) const;

public:
bool isCached() const { return true; }
};

#define SWIFT_TYPEID_ZONE TypeChecker
#define SWIFT_TYPEID_HEADER "swift/AST/TypeCheckerTypeIDZone.def"
#include "swift/Basic/DefineTypeIDZone.h"
Expand Down
3 changes: 3 additions & 0 deletions include/swift/AST/TypeCheckerTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,9 @@ SWIFT_REQUEST(TypeChecker, ParamCaptureInfoRequest,
SWIFT_REQUEST(TypeChecker, IsUnsafeRequest,
bool(Decl *),
SeparatelyCached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, CustomDerivativesRequest,
CustomDerivativesResult(SourceFile *),
Cached, NoLocationInfo)

SWIFT_REQUEST(TypeChecker, GenericTypeParamDeclGetValueTypeRequest,
Type(GenericTypeParamDecl *), Cached, NoLocationInfo)
3 changes: 2 additions & 1 deletion include/swift/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,8 @@ class Parser {
IterableDeclContext *IDC);

bool canDelayMemberDeclParsing(bool &HasOperatorDeclarations,
bool &HasNestedClassDeclarations);
bool &HasNestedClassDeclarations,
bool &HasDerivativeDeclarations);

bool canDelayFunctionBodyParsing(bool &HasNestedTypeDeclarations);

Expand Down
93 changes: 73 additions & 20 deletions lib/AST/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,15 +166,17 @@ class swift::SourceLookupCache {
ValueDeclMap TopLevelValues;
ValueDeclMap ClassMembers;
bool MemberCachePopulated = false;
llvm::SmallVector<AbstractFunctionDecl *, 0> CustomDerivatives;
DeclName UniqueMacroNamePlaceholder;

template<typename T>
using OperatorMap = llvm::DenseMap<Identifier, TinyPtrVector<T *>>;
OperatorMap<OperatorDecl> Operators;
OperatorMap<PrecedenceGroupDecl> PrecedenceGroups;

template<typename Range>
void addToUnqualifiedLookupCache(Range decls, bool onlyOperators);
template <typename Range>
void addToUnqualifiedLookupCache(Range decls, bool onlyOperators,
bool onlyDerivatives);
template<typename Range>
void addToMemberCache(Range decls);

Expand Down Expand Up @@ -205,6 +207,10 @@ class swift::SourceLookupCache {
/// guaranteed to be meaningful.
void getPrecedenceGroups(SmallVectorImpl<PrecedenceGroupDecl *> &results);

/// Retrieves all the function decls marked as @derivative. The order of the
/// results is not guaranteed to be meaningful.
llvm::SmallVector<AbstractFunctionDecl *, 0> getCustomDerivativeDecls();

/// Look up an operator declaration.
///
/// \param name The operator name ("+", ">>", etc.)
Expand Down Expand Up @@ -249,9 +255,10 @@ static Decl *getAsDecl(Decl *decl) { return decl; }
static Expr *getAsExpr(ASTNode node) { return node.dyn_cast<Expr *>(); }
static Decl *getAsDecl(ASTNode node) { return node.dyn_cast<Decl *>(); }

template<typename Range>
template <typename Range>
void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
bool onlyOperators) {
bool onlyOperators,
bool onlyDerivatives) {
for (auto item : items) {
// In script mode, we'll see macro expansion expressions for freestanding
// macros.
Expand All @@ -268,19 +275,36 @@ void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
continue;

if (auto *VD = dyn_cast<ValueDecl>(D)) {
if (onlyOperators ? VD->isOperator() : VD->hasName()) {
// Cache the value under both its compound name and its full name.
auto getDerivative = [onlyDerivatives, VD]() -> AbstractFunctionDecl * {
if (auto *AFD = dyn_cast<AbstractFunctionDecl>(VD))
if (AFD->getAttrs().hasAttribute<DerivativeAttr>())
return AFD;
return nullptr;
};
if (onlyOperators && VD->isOperator())
TopLevelValues.add(VD);

if (!onlyOperators && VD->getAttrs().hasAttribute<CustomAttr>()) {
if (onlyDerivatives)
if (AbstractFunctionDecl *AFD = getDerivative())
CustomDerivatives.push_back(AFD);
if (!onlyOperators && !onlyDerivatives && VD->hasName()) {
TopLevelValues.add(VD);
if (VD->getAttrs().hasAttribute<CustomAttr>())
MayHaveAuxiliaryDecls.push_back(VD);
}
if (AbstractFunctionDecl *AFD = getDerivative())
CustomDerivatives.push_back(AFD);
Comment on lines +286 to +294
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a specific suggestion on how to improve it, but this logic was a little tangled before with onlyOperators and now it's... really hard to sort through.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for bringing attention to this! I agree that the logic is pretty complicated and should be simplified somehow. I was also unable to quickly find a nice way to fix this, but I'll think of it and submit a subsequent PR once the refactoring of this function is ready.

}
}

if (auto *NTD = dyn_cast<NominalTypeDecl>(D))
if (!NTD->hasUnparsedMembers() || NTD->maybeHasOperatorDeclarations())
addToUnqualifiedLookupCache(NTD->getMembers(), true);
if (auto *NTD = dyn_cast<NominalTypeDecl>(D)) {
bool onlyOperatorsArg =
(!NTD->hasUnparsedMembers() || NTD->maybeHasOperatorDeclarations());
bool onlyDerivativesArg =
(!NTD->hasUnparsedMembers() || NTD->maybeHasDerivativeDeclarations());
if (onlyOperatorsArg || onlyDerivativesArg) {
addToUnqualifiedLookupCache(NTD->getMembers(), onlyOperatorsArg,
onlyDerivativesArg);
}
}

if (auto *ED = dyn_cast<ExtensionDecl>(D)) {
// Avoid populating the cache with the members of invalid extension
Expand All @@ -292,8 +316,14 @@ void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
MayHaveAuxiliaryDecls.push_back(ED);
}

if (!ED->hasUnparsedMembers() || ED->maybeHasOperatorDeclarations())
addToUnqualifiedLookupCache(ED->getMembers(), true);
bool onlyOperatorsArg =
(!ED->hasUnparsedMembers() || ED->maybeHasOperatorDeclarations());
bool onlyDerivativesArg =
(!ED->hasUnparsedMembers() || ED->maybeHasDerivativeDeclarations());
if (onlyOperatorsArg || onlyDerivativesArg) {
addToUnqualifiedLookupCache(ED->getMembers(), onlyOperatorsArg,
onlyDerivativesArg);
}
}

if (auto *OD = dyn_cast<OperatorDecl>(D))
Expand All @@ -307,7 +337,8 @@ void SourceLookupCache::addToUnqualifiedLookupCache(Range items,
MayHaveAuxiliaryDecls.push_back(MED);
} else if (auto TLCD = dyn_cast<TopLevelCodeDecl>(D)) {
if (auto body = TLCD->getBody()){
addToUnqualifiedLookupCache(body->getElements(), onlyOperators);
addToUnqualifiedLookupCache(body->getElements(), onlyOperators,
onlyDerivatives);
}
}
}
Expand Down Expand Up @@ -488,8 +519,8 @@ SourceLookupCache::SourceLookupCache(const SourceFile &SF)
{
FrontendStatsTracer tracer(SF.getASTContext().Stats,
"source-file-populate-cache");
addToUnqualifiedLookupCache(SF.getTopLevelItems(), false);
addToUnqualifiedLookupCache(SF.getHoistedDecls(), false);
addToUnqualifiedLookupCache(SF.getTopLevelItems(), false, false);
addToUnqualifiedLookupCache(SF.getHoistedDecls(), false, false);
}

SourceLookupCache::SourceLookupCache(const ModuleDecl &M)
Expand All @@ -499,11 +530,11 @@ SourceLookupCache::SourceLookupCache(const ModuleDecl &M)
"module-populate-cache");
for (const FileUnit *file : M.getFiles()) {
auto *SF = cast<SourceFile>(file);
addToUnqualifiedLookupCache(SF->getTopLevelItems(), false);
addToUnqualifiedLookupCache(SF->getHoistedDecls(), false);
addToUnqualifiedLookupCache(SF->getTopLevelItems(), false, false);
addToUnqualifiedLookupCache(SF->getHoistedDecls(), false, false);

if (auto *SFU = file->getSynthesizedFile()) {
addToUnqualifiedLookupCache(SFU->getTopLevelDecls(), false);
addToUnqualifiedLookupCache(SFU->getTopLevelDecls(), false, false);
}
}
}
Expand Down Expand Up @@ -572,6 +603,11 @@ void SourceLookupCache::getOperatorDecls(
results.append(ops.second.begin(), ops.second.end());
}

llvm::SmallVector<AbstractFunctionDecl *, 0>
SourceLookupCache::getCustomDerivativeDecls() {
return CustomDerivatives;
}

void SourceLookupCache::lookupOperator(Identifier name, OperatorFixity fixity,
TinyPtrVector<OperatorDecl *> &results) {
auto ops = Operators.find(name);
Expand Down Expand Up @@ -4027,6 +4063,23 @@ bool IsNonUserModuleRequest::evaluate(Evaluator &evaluator, ModuleDecl *mod) con
return false;
}

evaluator::SideEffect CustomDerivativesRequest::evaluate(Evaluator &evaluator,
SourceFile *sf) const {
ModuleDecl *module = sf->getParentModule();
assert(isParsedModule(module));
llvm::SmallVector<AbstractFunctionDecl *, 0> decls =
module->getSourceLookupCache().getCustomDerivativeDecls();
for (const AbstractFunctionDecl *afd : decls) {
for (const auto *derAttr :
afd->getAttrs().getAttributes<DerivativeAttr>()) {
// Resolve derivative function configurations from `@derivative`
// attributes by type-checking them.
(void)derAttr->getOriginalFunction(sf->getASTContext());
}
}
return {};
}

version::Version ModuleDecl::getLanguageVersionBuiltWith() const {
for (auto *F : getFiles()) {
auto *LD = dyn_cast<LoadedFile>(F);
Expand Down
44 changes: 31 additions & 13 deletions lib/Parse/ParseDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5644,19 +5644,20 @@ static void diagnoseOperatorFixityAttributes(Parser &P,
}
}

static unsigned skipUntilMatchingRBrace(Parser &P,
bool &HasPoundDirective,
static unsigned skipUntilMatchingRBrace(Parser &P, bool &HasPoundDirective,
bool &HasPoundSourceLocation,
bool &HasOperatorDeclarations,
bool &HasNestedClassDeclarations,
bool &HasNestedTypeDeclarations,
bool &HasPotentialRegexLiteral) {
bool &HasPotentialRegexLiteral,
bool &HasDerivativeDeclarations) {
HasPoundDirective = false;
HasPoundSourceLocation = false;
HasOperatorDeclarations = false;
HasNestedClassDeclarations = false;
HasNestedTypeDeclarations = false;
HasPotentialRegexLiteral = false;
HasDerivativeDeclarations = false;

unsigned OpenBraces = 1;
unsigned OpenPoundIf = 0;
Expand Down Expand Up @@ -5685,6 +5686,18 @@ static unsigned skipUntilMatchingRBrace(Parser &P,
tok::kw_protocol)
|| P.Tok.isContextualKeyword("actor");

if (P.consumeIf(tok::at_sign)) {
if (P.Tok.is(tok::identifier)) {
std::optional<DeclAttrKind> DK =
DeclAttribute::getAttrKindFromString(P.Tok.getText());
if (DK && *DK == DeclAttrKind::Derivative) {
HasDerivativeDeclarations = true;
P.consumeToken();
}
}
continue;
}

// HACK: Bail if we encounter what could potentially be a regex literal.
// This is necessary as:
// - We might encounter an invalid Swift token that might be valid in a
Expand Down Expand Up @@ -7033,13 +7046,17 @@ bool Parser::parseMemberDeclList(SourceLoc &LBLoc, SourceLoc &RBLoc,

bool HasOperatorDeclarations = false;
bool HasNestedClassDeclarations = false;
bool HasDerivativeDeclarations = false;

if (canDelayMemberDeclParsing(HasOperatorDeclarations,
HasNestedClassDeclarations)) {
HasNestedClassDeclarations,
HasDerivativeDeclarations)) {
if (HasOperatorDeclarations)
IDC->setMaybeHasOperatorDeclarations();
if (HasNestedClassDeclarations)
IDC->setMaybeHasNestedClassDeclarations();
if (HasDerivativeDeclarations)
IDC->setMaybeHasDerivativeDeclarations();
if (InFreestandingMacroArgument)
IDC->setInFreestandingMacroArgument();

Expand All @@ -7052,6 +7069,7 @@ bool Parser::parseMemberDeclList(SourceLoc &LBLoc, SourceLoc &RBLoc,
auto membersAndHash = parseDeclList(LBLoc, RBLoc, RBraceDiag, hadError);
IDC->setMaybeHasOperatorDeclarations();
IDC->setMaybeHasNestedClassDeclarations();
IDC->setMaybeHasDerivativeDeclarations();
Context.evaluator.cacheOutput(
ParseMembersRequest{IDC},
FingerprintAndMembers{
Expand Down Expand Up @@ -7112,7 +7130,8 @@ Parser::parseDeclList(SourceLoc LBLoc, SourceLoc &RBLoc, Diag<> ErrorDiag,
}

bool Parser::canDelayMemberDeclParsing(bool &HasOperatorDeclarations,
bool &HasNestedClassDeclarations) {
bool &HasNestedClassDeclarations,
bool &HasDerivativeDeclarations) {
// If explicitly disabled, respect the flag.
if (!isDelayedParsingEnabled())
return false;
Expand All @@ -7124,13 +7143,10 @@ bool Parser::canDelayMemberDeclParsing(bool &HasOperatorDeclarations,
bool HasPoundSourceLocation;
bool HasNestedTypeDeclarations;
bool HasPotentialRegexLiteral;
skipUntilMatchingRBrace(*this,
HasPoundDirective,
HasPoundSourceLocation,
HasOperatorDeclarations,
HasNestedClassDeclarations,
HasNestedTypeDeclarations,
HasPotentialRegexLiteral);
skipUntilMatchingRBrace(*this, HasPoundDirective, HasPoundSourceLocation,
HasOperatorDeclarations, HasNestedClassDeclarations,
HasNestedTypeDeclarations, HasPotentialRegexLiteral,
HasDerivativeDeclarations);
if (!HasPoundDirective && !HasPotentialRegexLiteral) {
// If we didn't see any pound directive, we must not have seen
// #sourceLocation either.
Expand Down Expand Up @@ -7742,9 +7758,11 @@ bool Parser::canDelayFunctionBodyParsing(bool &HasNestedTypeDeclarations) {
bool HasOperatorDeclarations;
bool HasNestedClassDeclarations;
bool HasPotentialRegexLiteral;
bool HasDerivativeDeclarations;
skipUntilMatchingRBrace(*this, HasPoundDirectives, HasPoundSourceLocation,
HasOperatorDeclarations, HasNestedClassDeclarations,
HasNestedTypeDeclarations, HasPotentialRegexLiteral);
HasNestedTypeDeclarations, HasPotentialRegexLiteral,
HasDerivativeDeclarations);
if (HasPoundSourceLocation || HasPotentialRegexLiteral)
return false;

Expand Down
Loading