diff --git a/include/swift/SIL/SILModule.h b/include/swift/SIL/SILModule.h index 5266f38d3ad93..3503dc32e643a 100644 --- a/include/swift/SIL/SILModule.h +++ b/include/swift/SIL/SILModule.h @@ -600,6 +600,13 @@ class SILModule { /// Erase a global SIL variable from the module. void eraseGlobalVariable(SILGlobalVariable *G); + /// Erase a differentiability witness from the module. + void eraseDifferentiabilityWitness(SILDifferentiabilityWitness *dw); + + /// Erase all differentiability witnesses for function f. + /// Does nothing if f has no differentiability witnesses. + void eraseAllDifferentiabilityWitnesses(SILFunction *f); + /// Create and return an empty SIL module suitable for generating or parsing /// SIL into. /// diff --git a/lib/SIL/IR/SILModule.cpp b/lib/SIL/IR/SILModule.cpp index 6b15487e17355..111b186ed5595 100644 --- a/lib/SIL/IR/SILModule.cpp +++ b/lib/SIL/IR/SILModule.cpp @@ -501,6 +501,34 @@ void SILModule::eraseGlobalVariable(SILGlobalVariable *gv) { getSILGlobalList().erase(gv); } +void SILModule::eraseDifferentiabilityWitness(SILDifferentiabilityWitness *dw) { + getSILLoader()->invalidateDifferentiabilityWitness(dw); + + Mangle::ASTMangler mangler(getASTContext()); + auto *originalFunction = dw->getOriginalFunction(); + auto mangledKey = mangler.mangleSILDifferentiabilityWitness( + originalFunction->getName(), dw->getKind(), dw->getConfig()); + DifferentiabilityWitnessMap.erase(mangledKey); + llvm::erase(DifferentiabilityWitnessesByFunction[originalFunction->getName()], dw); + + getDifferentiabilityWitnessList().erase(dw); +} + +void SILModule::eraseAllDifferentiabilityWitnesses(SILFunction *f) { + Mangle::ASTMangler mangler(getASTContext()); + + // lookup() returns an empty list when f has no witnesses (at() would abort). + for (auto *dw : DifferentiabilityWitnessesByFunction.lookup(f->getName())) { + getSILLoader()->invalidateDifferentiabilityWitness(dw); + auto mangledKey = mangler.mangleSILDifferentiabilityWitness( + f->getName(), dw->getKind(), dw->getConfig()); + DifferentiabilityWitnessMap.erase(mangledKey); + getDifferentiabilityWitnessList().erase(dw); + } + +
DifferentiabilityWitnessesByFunction.erase(f->getName()); +} + SILVTable *SILModule::lookUpVTable(const ClassDecl *C, bool deserializeLazily) { if (!C) diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp index 87c17bf3fb179..2fad7ee3ee6c8 100644 --- a/lib/SILGen/SILGen.cpp +++ b/lib/SILGen/SILGen.cpp @@ -1331,7 +1331,7 @@ void SILGenModule::postEmitFunction(SILDeclRef constant, void SILGenModule::emitDifferentiabilityWitnessesForFunction( SILDeclRef constant, SILFunction *F) { - // Visit `@derivative` attributes and generate SIL differentiability + // Visit `@differentiable` attributes and generate SIL differentiability // witnesses. // Skip if the SILDeclRef is a: // - Default argument generator function. @@ -1361,33 +1361,6 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction( config, /*jvp*/ nullptr, /*vjp*/ nullptr, diffAttr); } - for (auto *derivAttr : Attrs.getAttributes<DerivativeAttr>()) { - SILFunction *jvp = nullptr; - SILFunction *vjp = nullptr; - switch (derivAttr->getDerivativeKind()) { - case AutoDiffDerivativeFunctionKind::JVP: - jvp = F; - break; - case AutoDiffDerivativeFunctionKind::VJP: - vjp = F; - break; - } - auto *origAFD = derivAttr->getOriginalFunction(getASTContext()); - auto origDeclRef = - SILDeclRef(origAFD).asForeign(requiresForeignEntryPoint(origAFD)); - auto *origFn = getFunction(origDeclRef, NotForDefinition); - auto witnessGenSig = - autodiff::getDifferentiabilityWitnessGenericSignature( - origAFD->getGenericSignature(), AFD->getGenericSignature()); - auto *resultIndices = - autodiff::getFunctionSemanticResultIndices(origAFD, - derivAttr->getParameterIndices()); - AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices, - witnessGenSig); - emitDifferentiabilityWitness(origAFD, origFn, - DifferentiabilityKind::Reverse, config, jvp, - vjp, derivAttr); - } }; if (auto *accessor = dyn_cast<AccessorDecl>(AFD)) if (accessor->isGetter()) @@ -1492,6 +1465,36 @@ void SILGenModule::emitAbstractFuncDecl(AbstractFunctionDecl *AFD) {
SILDeclRef::BackDeploymentKind::Thunk); emitBackDeploymentThunk(thunk); } + + // Emit differentiability witness for the function referenced in + // @derivative(of:) attribute registering current function as VJP / JVP. + for (auto *derivAttr : AFD->getAttrs().getAttributes<DerivativeAttr>()) { + auto *F = getFunction(SILDeclRef(AFD), NotForDefinition); + SILFunction *jvp = nullptr, *vjp = nullptr; + switch (derivAttr->getDerivativeKind()) { + case AutoDiffDerivativeFunctionKind::JVP: + jvp = F; + break; + case AutoDiffDerivativeFunctionKind::VJP: + vjp = F; + break; + } + auto *origAFD = derivAttr->getOriginalFunction(getASTContext()); + auto origDeclRef = + SILDeclRef(origAFD).asForeign(requiresForeignEntryPoint(origAFD)); + auto *origFn = getFunction(origDeclRef, NotForDefinition); + auto witnessGenSig = + autodiff::getDifferentiabilityWitnessGenericSignature( + origAFD->getGenericSignature(), AFD->getGenericSignature()); + auto *resultIndices = + autodiff::getFunctionSemanticResultIndices(origAFD, + derivAttr->getParameterIndices()); + AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices, + witnessGenSig); + emitDifferentiabilityWitness(origAFD, origFn, + DifferentiabilityKind::Reverse, config, jvp, + vjp, derivAttr); + } } void SILGenModule::emitFunction(FuncDecl *fd) { diff --git a/lib/SILOptimizer/Mandatory/CapturePromotion.cpp b/lib/SILOptimizer/Mandatory/CapturePromotion.cpp index 70e6f813a123c..d153fe27b2851 100644 --- a/lib/SILOptimizer/Mandatory/CapturePromotion.cpp +++ b/lib/SILOptimizer/Mandatory/CapturePromotion.cpp @@ -1473,9 +1473,23 @@ processPartialApplyInst(SILOptFunctionBuilder &funcBuilder, funcBuilder, pai, fri, promotableIndices, f->getResilienceExpansion()); worklist.push_back(clonedFn); + SILFunction *origFn = fri->getReferencedFunction(); + for (auto *w : mod.lookUpDifferentiabilityWitnessesForFunction( + origFn->getName())) { + assert(!w->getJVP() && !w->getVJP() && "does not expect custom derivatives here"); + auto linkage =
stripExternalFromLinkage(clonedFn->getLinkage()); + SILDifferentiabilityWitness::createDefinition( + mod, linkage, clonedFn, + w->getKind(), w->getParameterIndices(), w->getResultIndices(), + w->getDerivativeGenericSignature(), + /*jvp*/ nullptr, /*vjp*/ nullptr, + /*isSerialized*/ hasPublicVisibility(clonedFn->getLinkage()), + w->getAttribute()); + } + // Mark the original partial apply function as deletable if it doesn't have // uses later. - fri->getReferencedFunction()->addSemanticsAttr(semantics::DELETE_IF_UNUSED); + origFn->addSemanticsAttr(semantics::DELETE_IF_UNUSED); // Initialize a SILBuilder and create a function_ref referencing the cloned // closure. diff --git a/lib/SILOptimizer/Mandatory/DiagnosticDeadFunctionElimination.cpp b/lib/SILOptimizer/Mandatory/DiagnosticDeadFunctionElimination.cpp index 70a56d6a6175e..5404bdd8c7594 100644 --- a/lib/SILOptimizer/Mandatory/DiagnosticDeadFunctionElimination.cpp +++ b/lib/SILOptimizer/Mandatory/DiagnosticDeadFunctionElimination.cpp @@ -36,6 +36,7 @@ namespace { struct DiagnosticDeadFunctionEliminator : SILFunctionTransform { void run() override { auto *fn = getFunction(); + auto &mod = fn->getModule(); // If an earlier pass asked us to eliminate the function body if it's // unused, and the function is in fact unused, do that now. @@ -67,6 +68,10 @@ struct DiagnosticDeadFunctionEliminator : SILFunctionTransform { b.createUnreachable(loc); } + // Drop differentiability witnesses, if any + if (!mod.lookUpDifferentiabilityWitnessesForFunction(fn->getName()).empty()) + mod.eraseAllDifferentiabilityWitnesses(fn); + // If the function has shared linkage, reduce this version to private // linkage, because we don't want the deleted-body form to win in any // ODR shootouts.
diff --git a/test/AutoDiff/compiler_crashers_fixed/issue-59135-nested-function-diff-wittness-capture-promotion.swift b/test/AutoDiff/compiler_crashers_fixed/issue-59135-nested-function-diff-wittness-capture-promotion.swift new file mode 100644 index 0000000000000..c81240143c878 --- /dev/null +++ b/test/AutoDiff/compiler_crashers_fixed/issue-59135-nested-function-diff-wittness-capture-promotion.swift @@ -0,0 +1,33 @@ +// RUN: %target-swift-emit-sil -Xllvm -debug-only=differentiation -o /dev/null 2>&1 %s | %FileCheck %s + +// The differentiability witness for y in s(h:) will be generated by SILGen. However, the capture +// promotion pass later specializes y, since it only captures an integer and therefore does not need to +// box the capture. Ensure we create a differentiability witness for the specialized function. In addition, +// since the original function is no longer used, its body is removed (leaving only an unreachable +// terminator). Ensure the original differentiability witness is removed as well, as it would otherwise +// lead to non-differentiable diagnostics further on.
+ +// CHECK-LABEL: differentiability witness for specialized y #1 (_:) in s(h:) +// CHECK: sil_differentiability_witness private [reverse] [parameters 0] [results 0] @$s4null1s1hAA1BVAE_tF1yL_yAA1WVAHFTf2ni_n : $@convention(thin) (@guaranteed W, Int) -> @owned W { +// CHECK-NOT: sil_differentiability_witness private [reverse] [parameters 0] [results 0] @$s4null1s1hAA1BVAE_tF1yL_yAA1WVAHF : $@convention(thin) (@guaranteed W, @guaranteed { var Int }) -> @owned W { + +import _Differentiation +struct B: Differentiable{} +struct X { var j = [Float]()} +struct W: Differentiable { + @noDerivative var z: X + var h: B +} +func o(_ x: T, _ f: @differentiable(reverse) (T) -> R) -> R {f(x)} +func m(_ f: @escaping @differentiable(reverse) (T) -> R) -> @differentiable(reverse) (T) -> R {{ x in o(x, f) }} +@differentiable(reverse) + func s(h: B) -> B { + var (_, e) = (0,0) + @differentiable(reverse) + func y(_ i: W) -> W { + let _ = e; + return i + } + let w = m(y) + return B() +} diff --git a/test/AutoDiff/compiler_crashers_fixed/issue-59135-usableFromInline-VJP.swift b/test/AutoDiff/compiler_crashers_fixed/issue-59135-usableFromInline-VJP.swift new file mode 100644 index 0000000000000..5a763b3888d01 --- /dev/null +++ b/test/AutoDiff/compiler_crashers_fixed/issue-59135-usableFromInline-VJP.swift @@ -0,0 +1,38 @@ +// RUN: %empty-directory(%t) +// RUN: %target-swift-emit-sil -Xllvm -debug-only=differentiation -emit-module -module-name M -emit-module-path %t/M.swiftmodule 2>&1 %s | %FileCheck %s + +// The original function Tensor.subscriptIndexPath() is not marked as @differentiable. As a result, no explicit differentiability witness is generated for it. +// However, the witness is generated as a side effect of providing a derivative via @derivative(of: subscriptIndexPath) on _vjpSubscriptIndexPath. +// Since _vjpSubscriptIndexPath is not emitted when -emit-module is used, we need to ensure we still generate a witness.
+ +import _Differentiation + +// CHECK-LABEL: differentiability witness for Tensor.subscriptIndexPath() +// CHECK: sil_differentiability_witness [serialized] [reverse] [parameters 0] [results 0] @$s1M6TensorV18subscriptIndexPathACyF : $@convention(method) (Tensor) -> Tensor { +// CHECK: vjp: @$s1M6TensorV18subscriptIndexPathACyFTJrSpSr : $@convention(method) (Tensor) -> (Tensor, @owned @callee_guaranteed (Tensor) -> Tensor) + +// CHECK-LABEL: reverse-mode derivative of Tensor.subscriptIndexPath() +// CHECK: sil [thunk] [always_inline] [ossa] @$s1M6TensorV18subscriptIndexPathACyFTJrSpSr : $@convention(method) (Tensor) -> (Tensor, @owned @callee_guaranteed (Tensor) -> Tensor) { +// CHECK: function_ref Tensor._vjpSubscriptIndexPath() +// CHECK: function_ref @$s1M6TensorV22_vjpSubscriptIndexPathAC5value_A2Cc8pullbacktyF : $@convention(method) (Tensor) -> (Tensor, @owned @callee_guaranteed (Tensor) -> Tensor) + +public struct Tensor: Differentiable & AdditiveArithmetic { + @inlinable + func subscriptIndexPath() -> Tensor { + fatalError() + } + + @inlinable + @differentiable(reverse, wrt: self) + func subscriptRanges() -> Tensor { + subscriptIndexPath() + } + + @usableFromInline + @derivative(of: subscriptIndexPath) + func _vjpSubscriptIndexPath() -> ( + value: Tensor, pullback: (Tensor) -> Tensor + ) { + fatalError() + } +}