Skip to content

[AutoDiff] Lookup for custom derivatives in non-primary source files #58965

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 18, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
@@ -6265,11 +6265,9 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
DerivativeFunctionConfigurationList *DerivativeFunctionConfigs = nullptr;

public:
/// Get all derivative function configurations. If `lookInNonPrimarySources`
/// is true then lookup is done in non-primary sources as well. Note that
/// such lookup might end in cycles if done during sema stages.
/// Get all derivative function configurations.
ArrayRef<AutoDiffConfig>
getDerivativeFunctionConfigurations(bool lookInNonPrimarySources = true);
getDerivativeFunctionConfigurations();

/// Add the given derivative function configuration.
void addDerivativeFunctionConfiguration(const AutoDiffConfig &config);
1 change: 1 addition & 0 deletions include/swift/Frontend/Frontend.h
Original file line number Diff line number Diff line change
@@ -669,6 +669,7 @@ class CompilerInstance {

/// If \p fn returns true, exits early and returns true.
bool forEachFileToTypeCheck(llvm::function_ref<bool(SourceFile &)> fn);
bool forEachSourceFile(llvm::function_ref<bool(SourceFile &)> fn);

/// Whether the cancellation of the current operation has been requested.
bool isCancellationRequested() const;
4 changes: 4 additions & 0 deletions include/swift/Subsystems.h
Original file line number Diff line number Diff line change
@@ -157,6 +157,10 @@ namespace swift {
/// emitted.
void performWholeModuleTypeChecking(SourceFile &SF);

/// Load derivative configurations from @derivative attributes (including
/// those defined in non-primary sources).
void loadDerivativeConfigurations(SourceFile &SF);

/// Resolve the given \c TypeRepr to an interface type.
///
/// This is used when dealing with partial source files (e.g. SIL parsing,
32 changes: 1 addition & 31 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
@@ -8312,7 +8312,7 @@ void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() {
}

ArrayRef<AutoDiffConfig>
AbstractFunctionDecl::getDerivativeFunctionConfigurations(bool lookInNonPrimarySources) {
AbstractFunctionDecl::getDerivativeFunctionConfigurations() {
prepareDerivativeFunctionConfigurations();

// Resolve derivative function configurations from `@differentiable`
@@ -8336,36 +8336,6 @@ AbstractFunctionDecl::getDerivativeFunctionConfigurations(bool lookInNonPrimaryS
*DerivativeFunctionConfigs);
}

class DerivativeFinder : public ASTWalker {
const AbstractFunctionDecl *AFD;
public:
DerivativeFinder(const AbstractFunctionDecl *afd) : AFD(afd) {}

bool walkToDeclPre(Decl *D) override {
if (auto *afd = dyn_cast<AbstractFunctionDecl>(D)) {
for (auto *derAttr : afd->getAttrs().getAttributes<DerivativeAttr>()) {
// Resolve derivative function configurations from `@derivative`
// attributes by type-checking them.
if (AFD->getName().matchesRef(
derAttr->getOriginalFunctionName().Name.getFullName())) {
(void)derAttr->getOriginalFunction(afd->getASTContext());
return false;
}
}
}

return true;
}
};

// Load derivative configurations from @derivative attributes defined in
// non-primary sources. Note that it might trigger lookup cycles if called
// from inside Sema stages.
if (lookInNonPrimarySources) {
DerivativeFinder finder(this);
getParent()->walkContext(finder);
}

return DerivativeFunctionConfigs->getArrayRef();
}

20 changes: 20 additions & 0 deletions lib/Frontend/Frontend.cpp
Original file line number Diff line number Diff line change
@@ -1186,11 +1186,31 @@ bool CompilerInstance::forEachFileToTypeCheck(
return false;
}

bool CompilerInstance::forEachSourceFile(
llvm::function_ref<bool(SourceFile &)> fn) {
for (auto fileName : getMainModule()->getFiles()) {
auto *SF = dyn_cast<SourceFile>(fileName);
if (!SF) {
continue;
}
if (fn(*SF))
return true;
;
}

return false;
}

void CompilerInstance::finishTypeChecking() {
forEachFileToTypeCheck([](SourceFile &SF) {
performWholeModuleTypeChecking(SF);
return false;
});

forEachSourceFile([](SourceFile &SF) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to avoid the walk over all source files in the case where none of them have differentiable programming enabled?

loadDerivativeConfigurations(SF);
return false;
});
Comment on lines +1210 to +1213
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have any sense of the performance implications (empirical or theoretical) of this derivative function resolution approach, compared with the old approach?

A high-level performance summary seems to be:

  • No overhead if import _Differentiation is missing.
  • Otherwise, every declaration in every source file is visited.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. And any custom derivative declaration is typechecked. Regardless whether the derivative is used or not. This is quite a big hammer, but for now I do not see other way to "expose" them.

Old approach had its own pros and cons:

  • Pros: other source files were visited only when derivative was needed. And only those required were typechecked
  • Cons: other source files were visited on every derivative lookup. And we had typechecking cycles.

}

SourceFile::ParsingOptions
3 changes: 1 addition & 2 deletions lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
@@ -379,8 +379,7 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
bool foundExactConfig = false;
Optional<AutoDiffConfig> supersetConfig = None;
for (auto witnessConfig :
witnessAFD->getDerivativeFunctionConfigurations(
/*lookInNonPrimarySources*/ false)) {
witnessAFD->getDerivativeFunctionConfigurations()) {
// All the witness's derivative generic requirements must be satisfied
// by the requirement's derivative generic requirements OR by the
// conditional conformance requirements.
38 changes: 38 additions & 0 deletions lib/Sema/TypeChecker.cpp
Original file line number Diff line number Diff line change
@@ -375,6 +375,44 @@ void swift::performWholeModuleTypeChecking(SourceFile &SF) {
}
}

void swift::loadDerivativeConfigurations(SourceFile &SF) {
if (!isDifferentiableProgrammingEnabled(SF))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When is this true? The recursive walk over the AST in secondary files is going to trigger delayed parsing in all nominal types and extensions, which will impact performance in non-WMO builds.

Copy link
Contributor

@dan-zheng dan-zheng May 19, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isDifferentiableProgrammingEnabled (implementation) returns true only for files that have import _Differentiation, so I hope there would be negligible impact for other files.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok. Can you add a comment to that effect? Something that mentions that this defeats delayed parsing, but only triggers with 'import _Differentiation'.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@slavapestov This is true when there is an import _Differentiation in the given SourceFile (so, corresponding ImportDecl).

What is the viable alternative? Essentially the task is: for function f find all functions that has @derivative(of : f) attribute specified. Both in primary and secondary sources.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the viable alternative? Essentially the task is: for function f find all functions that has @Derivative(of : f) attribute specified. Both in primary and secondary sources.

The only real alternative would be to change the design of this attribute so that this kind of search does not have to be performed. By definition, it requires parsing all source files ahead of time which is incompatible with our strategy of skipping bodies of nominal types and extensions in secondary files.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. Though the alternative design would not easily allow "pluggable" derivatives.... (see #58644 (comment) for the design explanation, etc. by @dan-zheng )

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One alternative is to handle this like we do operators and AnyObject lookup. When skipping over the body of a nominal type or extension, make a note if the lexer sees '@_derivative', then eagerly parse only those delayed bodies when performing this global lookup. That's still not great but better than disabling delayed parsing entirely when 'import _Differentiation' appears in the program.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@slavapestov Yes, I though about this as well, thanks! Though I was a bit hesitant to introduce such derivative-specific logic into lexer and parser :) But if you're thinking it's acceptable, I can try to do something around this in the subsequent PRs.

return;

auto &Ctx = SF.getASTContext();
FrontendStatsTracer tracer(Ctx.Stats,
"load-derivative-configurations");

class DerivativeFinder : public ASTWalker {
public:
DerivativeFinder() {}

bool walkToDeclPre(Decl *D) override {
if (auto *afd = dyn_cast<AbstractFunctionDecl>(D)) {
for (auto *derAttr : afd->getAttrs().getAttributes<DerivativeAttr>()) {
// Resolve derivative function configurations from `@derivative`
// attributes by type-checking them.
(void)derAttr->getOriginalFunction(D->getASTContext());
}
}

return true;
}
};

switch (SF.Kind) {
case SourceFileKind::Library:
case SourceFileKind::Main: {
DerivativeFinder finder;
SF.walkContext(finder);
return;
}
case SourceFileKind::SIL:
case SourceFileKind::Interface:
return;
}
}

bool swift::isAdditiveArithmeticConformanceDerivationEnabled(SourceFile &SF) {
auto &ctx = SF.getASTContext();
// Return true if `AdditiveArithmetic` derived conformances are explicitly
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import _Differentiation

@inlinable
@derivative(of: min)
func minVJP<T: Comparable & Differentiable>(
_ x: T,
_ y: T
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
if x <= y {
return (v, .zero)
}
else {
return (.zero, v)
}
}
return (value: min(x, y), pullback: pullback)
}

@inlinable
@derivative(of: max)
func maxVJP<T: Comparable & Differentiable>(
_ x: T,
_ y: T
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
if x < y {
return (.zero, v)
}
else {
return (v, .zero)
}
}
return (value: max(x, y), pullback: pullback)
}
9 changes: 9 additions & 0 deletions test/AutoDiff/Sema/DerivativeRegistrationCrossFile/main.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: %target-swift-frontend -emit-sil -verify -primary-file %s %S/Inputs/derivatives.swift -module-name main -o /dev/null

import _Differentiation

@differentiable(reverse)
func clamp(_ value: Double, _ lowerBound: Double, _ upperBound: Double) -> Double {
// No error expected
return max(min(value, upperBound), lowerBound)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please clarify the functionality added by this PR on top of #58644?

My understanding:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Yes, as I mentioned, this avoids circular requests. As custom derivative resolution is decoupled from obtaining derivatives configurations. Also, after this PR the differentiation will not trigger typechecking
  2. After this PR indeed, we lookup and register all explicit derivates regardless of their original function and regardless whether this function is used. The nice side effect is that we will error out in case of @derivative(of: foo) with foo not defined anywhere.
  3. Effectively this resolves the following cross-file derivate requests:
    • File A.swift is using (wants a derivative of) function from B.swift with derivatives defined in C.swift
    • The important special case is when B.swift is effectively a part of stdlib (like in the testcase). Here C.swift is a non-primary source but we have no way to resolve the derivative in other way as there is no mechanism to "pull" anything from C.swift – there are simply no forward dependency edges in the dependency graph, only backward.

}