-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[AutoDiff] Lookup for custom derivatives in non-primary source files #58965
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1186,11 +1186,31 @@ bool CompilerInstance::forEachFileToTypeCheck( | |
return false; | ||
} | ||
|
||
bool CompilerInstance::forEachSourceFile( | ||
llvm::function_ref<bool(SourceFile &)> fn) { | ||
for (auto fileName : getMainModule()->getFiles()) { | ||
auto *SF = dyn_cast<SourceFile>(fileName); | ||
if (!SF) { | ||
continue; | ||
} | ||
if (fn(*SF)) | ||
return true; | ||
; | ||
} | ||
|
||
return false; | ||
} | ||
|
||
void CompilerInstance::finishTypeChecking() { | ||
forEachFileToTypeCheck([](SourceFile &SF) { | ||
performWholeModuleTypeChecking(SF); | ||
return false; | ||
}); | ||
|
||
forEachSourceFile([](SourceFile &SF) { | ||
loadDerivativeConfigurations(SF); | ||
return false; | ||
}); | ||
Comment on lines
+1210
to
+1213
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you have any sense of the performance implications (empirical or theoretical) of this derivative function resolution approach, compared with the old approach? A high-level performance summary seems to be:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right. And any custom derivative declaration is typechecked. Regardless whether the derivative is used or not. This is quite a big hammer, but for now I do not see other way to "expose" them. Old approach had its own pros and cons:
|
||
} | ||
|
||
SourceFile::ParsingOptions | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -375,6 +375,44 @@ void swift::performWholeModuleTypeChecking(SourceFile &SF) { | |
} | ||
} | ||
|
||
void swift::loadDerivativeConfigurations(SourceFile &SF) { | ||
if (!isDifferentiableProgrammingEnabled(SF)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When is this true? The recursive walk over the AST in secondary files is going to trigger delayed parsing in all nominal types and extensions, which will impact performance in non-WMO builds. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah ok. Can you add a comment to that effect? Something that mentions that this defeats delayed parsing, but only triggers with 'import _Differentiation'. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @slavapestov This is true when there is an What is the viable alternative? Essentially the task is: for function There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The only real alternative would be to change the design of this attribute so that this kind of search does not have to be performed. By definition, it requires parsing all source files ahead of time which is incompatible with our strategy of skipping bodies of nominal types and extensions in secondary files. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah. Though the alternative design would not easily allow "pluggable" derivatives.... (see #58644 (comment) for the design explanation, etc. by @dan-zheng ) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One alternative is to handle this like we do operators and AnyObject lookup. When skipping over the body of a nominal type or extension, make a note if the lexer sees '@_derivative', then eagerly parse only those delayed bodies when performing this global lookup. That's still not great but better than disabling delayed parsing entirely when 'import _Differentiation' appears in the program. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @slavapestov Yes, I though about this as well, thanks! Though I was a bit hesitant to introduce such derivative-specific logic into lexer and parser :) But if you're thinking it's acceptable, I can try to do something around this in the subsequent PRs. |
||
return; | ||
|
||
auto &Ctx = SF.getASTContext(); | ||
FrontendStatsTracer tracer(Ctx.Stats, | ||
"load-derivative-configurations"); | ||
|
||
class DerivativeFinder : public ASTWalker { | ||
public: | ||
DerivativeFinder() {} | ||
|
||
bool walkToDeclPre(Decl *D) override { | ||
if (auto *afd = dyn_cast<AbstractFunctionDecl>(D)) { | ||
for (auto *derAttr : afd->getAttrs().getAttributes<DerivativeAttr>()) { | ||
// Resolve derivative function configurations from `@derivative` | ||
// attributes by type-checking them. | ||
(void)derAttr->getOriginalFunction(D->getASTContext()); | ||
} | ||
} | ||
|
||
return true; | ||
} | ||
}; | ||
|
||
switch (SF.Kind) { | ||
case SourceFileKind::Library: | ||
case SourceFileKind::Main: { | ||
DerivativeFinder finder; | ||
SF.walkContext(finder); | ||
return; | ||
} | ||
case SourceFileKind::SIL: | ||
case SourceFileKind::Interface: | ||
return; | ||
} | ||
} | ||
|
||
bool swift::isAdditiveArithmeticConformanceDerivationEnabled(SourceFile &SF) { | ||
auto &ctx = SF.getASTContext(); | ||
// Return true if `AdditiveArithmetic` derived conformances are explicitly | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import _Differentiation | ||
|
||
@inlinable | ||
@derivative(of: min) | ||
func minVJP<T: Comparable & Differentiable>( | ||
_ x: T, | ||
_ y: T | ||
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) { | ||
func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) { | ||
if x <= y { | ||
return (v, .zero) | ||
} | ||
else { | ||
return (.zero, v) | ||
} | ||
} | ||
return (value: min(x, y), pullback: pullback) | ||
} | ||
|
||
@inlinable | ||
@derivative(of: max) | ||
func maxVJP<T: Comparable & Differentiable>( | ||
_ x: T, | ||
_ y: T | ||
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) { | ||
func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) { | ||
if x < y { | ||
return (.zero, v) | ||
} | ||
else { | ||
return (v, .zero) | ||
} | ||
} | ||
return (value: max(x, y), pullback: pullback) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
// RUN: %target-swift-frontend -emit-sil -verify -primary-file %s %S/Inputs/derivatives.swift -module-name main -o /dev/null | ||
|
||
import _Differentiation | ||
|
||
@differentiable(reverse) | ||
func clamp(_ value: Double, _ lowerBound: Double, _ upperBound: Double) -> Double { | ||
// No error expected | ||
return max(min(value, upperBound), lowerBound) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you please clarify the functionality added by this PR on top of #58644? My understanding:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a way to avoid the walk over all source files in the case where none of them have differentiable programming enabled?