Skip to content

Commit 0b06fb4

Browse files
committed
Lookup for custom derivatives in non-primary source files after typecheck is finished for the primary source.
This registers all custom derivatives before autodiff transformations and makes them available to them. Fully resolves swiftlang#55170
1 parent 42655e1 commit 0b06fb4

File tree

9 files changed

+111
-37
lines changed

9 files changed

+111
-37
lines changed

include/swift/AST/Decl.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6265,11 +6265,9 @@ class AbstractFunctionDecl : public GenericContext, public ValueDecl {
62656265
DerivativeFunctionConfigurationList *DerivativeFunctionConfigs = nullptr;
62666266

62676267
public:
6268-
/// Get all derivative function configurations. If `lookInNonPrimarySources`
6269-
/// is true then lookup is done in non-primary sources as well. Note that
6270-
/// such lookup might end in cycles if done during sema stages.
6268+
/// Get all derivative function configurations.
62716269
ArrayRef<AutoDiffConfig>
6272-
getDerivativeFunctionConfigurations(bool lookInNonPrimarySources = true);
6270+
getDerivativeFunctionConfigurations();
62736271

62746272
/// Add the given derivative function configuration.
62756273
void addDerivativeFunctionConfiguration(const AutoDiffConfig &config);

include/swift/Frontend/Frontend.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,7 @@ class CompilerInstance {
669669

670670
/// If \p fn returns true, exits early and returns true.
671671
bool forEachFileToTypeCheck(llvm::function_ref<bool(SourceFile &)> fn);
672+
bool forEachSourceFile(llvm::function_ref<bool(SourceFile &)> fn);
672673

673674
/// Whether the cancellation of the current operation has been requested.
674675
bool isCancellationRequested() const;

include/swift/Subsystems.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,10 @@ namespace swift {
157157
/// emitted.
158158
void performWholeModuleTypeChecking(SourceFile &SF);
159159

160+
/// Load derivative configurations from @derivative attributes (including
161+
/// those defined in non-primary sources).
162+
void loadDerivativeConfigurations(SourceFile &SF);
163+
160164
/// Resolve the given \c TypeRepr to an interface type.
161165
///
162166
/// This is used when dealing with partial source files (e.g. SIL parsing,

lib/AST/Decl.cpp

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8312,7 +8312,7 @@ void AbstractFunctionDecl::prepareDerivativeFunctionConfigurations() {
83128312
}
83138313

83148314
ArrayRef<AutoDiffConfig>
8315-
AbstractFunctionDecl::getDerivativeFunctionConfigurations(bool lookInNonPrimarySources) {
8315+
AbstractFunctionDecl::getDerivativeFunctionConfigurations() {
83168316
prepareDerivativeFunctionConfigurations();
83178317

83188318
// Resolve derivative function configurations from `@differentiable`
@@ -8336,36 +8336,6 @@ AbstractFunctionDecl::getDerivativeFunctionConfigurations(bool lookInNonPrimaryS
83368336
*DerivativeFunctionConfigs);
83378337
}
83388338

8339-
class DerivativeFinder : public ASTWalker {
8340-
const AbstractFunctionDecl *AFD;
8341-
public:
8342-
DerivativeFinder(const AbstractFunctionDecl *afd) : AFD(afd) {}
8343-
8344-
bool walkToDeclPre(Decl *D) override {
8345-
if (auto *afd = dyn_cast<AbstractFunctionDecl>(D)) {
8346-
for (auto *derAttr : afd->getAttrs().getAttributes<DerivativeAttr>()) {
8347-
// Resolve derivative function configurations from `@derivative`
8348-
// attributes by type-checking them.
8349-
if (AFD->getName().matchesRef(
8350-
derAttr->getOriginalFunctionName().Name.getFullName())) {
8351-
(void)derAttr->getOriginalFunction(afd->getASTContext());
8352-
return false;
8353-
}
8354-
}
8355-
}
8356-
8357-
return true;
8358-
}
8359-
};
8360-
8361-
// Load derivative configurations from @derivative attributes defined in
8362-
// non-primary sources. Note that it might trigger lookup cycles if called
8363-
// from inside Sema stages.
8364-
if (lookInNonPrimarySources) {
8365-
DerivativeFinder finder(this);
8366-
getParent()->walkContext(finder);
8367-
}
8368-
83698339
return DerivativeFunctionConfigs->getArrayRef();
83708340
}
83718341

lib/Frontend/Frontend.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1186,11 +1186,31 @@ bool CompilerInstance::forEachFileToTypeCheck(
11861186
return false;
11871187
}
11881188

1189+
bool CompilerInstance::forEachSourceFile(
1190+
llvm::function_ref<bool(SourceFile &)> fn) {
1191+
for (auto fileName : getMainModule()->getFiles()) {
1192+
auto *SF = dyn_cast<SourceFile>(fileName);
1193+
if (!SF) {
1194+
continue;
1195+
}
1196+
if (fn(*SF))
1197+
return true;
1198+
;
1199+
}
1200+
1201+
return false;
1202+
}
1203+
11891204
void CompilerInstance::finishTypeChecking() {
11901205
forEachFileToTypeCheck([](SourceFile &SF) {
11911206
performWholeModuleTypeChecking(SF);
11921207
return false;
11931208
});
1209+
1210+
forEachSourceFile([](SourceFile &SF) {
1211+
loadDerivativeConfigurations(SF);
1212+
return false;
1213+
});
11941214
}
11951215

11961216
SourceFile::ParsingOptions

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,8 +379,7 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
379379
bool foundExactConfig = false;
380380
Optional<AutoDiffConfig> supersetConfig = None;
381381
for (auto witnessConfig :
382-
witnessAFD->getDerivativeFunctionConfigurations(
383-
/*lookInNonPrimarySources*/ false)) {
382+
witnessAFD->getDerivativeFunctionConfigurations()) {
384383
// All the witness's derivative generic requirements must be satisfied
385384
// by the requirement's derivative generic requirements OR by the
386385
// conditional conformance requirements.

lib/Sema/TypeChecker.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,44 @@ void swift::performWholeModuleTypeChecking(SourceFile &SF) {
375375
}
376376
}
377377

378+
void swift::loadDerivativeConfigurations(SourceFile &SF) {
379+
if (!isDifferentiableProgrammingEnabled(SF))
380+
return;
381+
382+
auto &Ctx = SF.getASTContext();
383+
FrontendStatsTracer tracer(Ctx.Stats,
384+
"load-derivative-configurations");
385+
386+
class DerivativeFinder : public ASTWalker {
387+
public:
388+
DerivativeFinder() {}
389+
390+
bool walkToDeclPre(Decl *D) override {
391+
if (auto *afd = dyn_cast<AbstractFunctionDecl>(D)) {
392+
for (auto *derAttr : afd->getAttrs().getAttributes<DerivativeAttr>()) {
393+
// Resolve derivative function configurations from `@derivative`
394+
// attributes by type-checking them.
395+
(void)derAttr->getOriginalFunction(D->getASTContext());
396+
}
397+
}
398+
399+
return true;
400+
}
401+
};
402+
403+
switch (SF.Kind) {
404+
case SourceFileKind::Library:
405+
case SourceFileKind::Main: {
406+
DerivativeFinder finder;
407+
SF.walkContext(finder);
408+
return;
409+
}
410+
case SourceFileKind::SIL:
411+
case SourceFileKind::Interface:
412+
return;
413+
}
414+
}
415+
378416
bool swift::isAdditiveArithmeticConformanceDerivationEnabled(SourceFile &SF) {
379417
auto &ctx = SF.getASTContext();
380418
// Return true if `AdditiveArithmetic` derived conformances are explicitly
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import _Differentiation
2+
3+
@inlinable
4+
@derivative(of: min)
5+
func minVJP<T: Comparable & Differentiable>(
6+
_ x: T,
7+
_ y: T
8+
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
9+
func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
10+
if x <= y {
11+
return (v, .zero)
12+
}
13+
else {
14+
return (.zero, v)
15+
}
16+
}
17+
return (value: min(x, y), pullback: pullback)
18+
}
19+
20+
@inlinable
21+
@derivative(of: max)
22+
func maxVJP<T: Comparable & Differentiable>(
23+
_ x: T,
24+
_ y: T
25+
) -> (value: T, pullback: (T.TangentVector) -> (T.TangentVector, T.TangentVector)) {
26+
func pullback(_ v: T.TangentVector) -> (T.TangentVector, T.TangentVector) {
27+
if x < y {
28+
return (.zero, v)
29+
}
30+
else {
31+
return (v, .zero)
32+
}
33+
}
34+
return (value: max(x, y), pullback: pullback)
35+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify -primary-file %s %S/Inputs/derivatives.swift -module-name main -o /dev/null
2+
3+
import _Differentiation
4+
5+
@differentiable(reverse)
6+
func clamp(_ value: Double, _ lowerBound: Double, _ upperBound: Double) -> Double {
7+
// No error expected
8+
return max(min(value, upperBound), lowerBound)
9+
}

0 commit comments

Comments
 (0)