Skip to content

Fix two issues related with emission of differentiability witnesses #80983

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/swift/SIL/SILModule.h
Original file line number Diff line number Diff line change
@@ -600,6 +600,12 @@ class SILModule {
/// Erase a global SIL variable from the module.
void eraseGlobalVariable(SILGlobalVariable *G);

/// Erase a differentiability witness from the module.
void eraseDifferentiabilityWittness(SILDifferentiabilityWitness *dw);

/// Erase all differentiability witnesses for function f.
void eraseAllDifferentiabilityWittnesses(SILFunction *f);

/// Create and return an empty SIL module suitable for generating or parsing
/// SIL into.
///
27 changes: 27 additions & 0 deletions lib/SIL/IR/SILModule.cpp
Original file line number Diff line number Diff line change
@@ -501,6 +501,33 @@ void SILModule::eraseGlobalVariable(SILGlobalVariable *gv) {
getSILGlobalList().erase(gv);
}

void SILModule::eraseDifferentiabilityWittness(SILDifferentiabilityWitness *dw) {
getSILLoader()->invalidateDifferentiabilityWitness(dw);

Mangle::ASTMangler mangler(getASTContext());
auto *originalFunction = dw->getOriginalFunction();
auto mangledKey = mangler.mangleSILDifferentiabilityWitness(
originalFunction->getName(), dw->getKind(), dw->getConfig());
DifferentiabilityWitnessMap.erase(mangledKey);
llvm::erase(DifferentiabilityWitnessesByFunction[originalFunction->getName()], dw);

getDifferentiabilityWitnessList().erase(dw);
}

void SILModule::eraseAllDifferentiabilityWittnesses(SILFunction *f) {
Mangle::ASTMangler mangler(getASTContext());

for (auto *dw : DifferentiabilityWitnessesByFunction.at(f->getName())) {
getSILLoader()->invalidateDifferentiabilityWitness(dw);
auto mangledKey = mangler.mangleSILDifferentiabilityWitness(
f->getName(), dw->getKind(), dw->getConfig());
DifferentiabilityWitnessMap.erase(mangledKey);
getDifferentiabilityWitnessList().erase(dw);
}

DifferentiabilityWitnessesByFunction.erase(f->getName());
}

SILVTable *SILModule::lookUpVTable(const ClassDecl *C,
bool deserializeLazily) {
if (!C)
59 changes: 31 additions & 28 deletions lib/SILGen/SILGen.cpp
Original file line number Diff line number Diff line change
@@ -1331,7 +1331,7 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,

void SILGenModule::emitDifferentiabilityWitnessesForFunction(
SILDeclRef constant, SILFunction *F) {
// Visit `@derivative` attributes and generate SIL differentiability
// Visit `@differentiable` attributes and generate SIL differentiability
// witnesses.
// Skip if the SILDeclRef is a:
// - Default argument generator function.
@@ -1361,33 +1361,6 @@ void SILGenModule::emitDifferentiabilityWitnessesForFunction(
config, /*jvp*/ nullptr,
/*vjp*/ nullptr, diffAttr);
}
for (auto *derivAttr : Attrs.getAttributes<DerivativeAttr>()) {
SILFunction *jvp = nullptr;
SILFunction *vjp = nullptr;
switch (derivAttr->getDerivativeKind()) {
case AutoDiffDerivativeFunctionKind::JVP:
jvp = F;
break;
case AutoDiffDerivativeFunctionKind::VJP:
vjp = F;
break;
}
auto *origAFD = derivAttr->getOriginalFunction(getASTContext());
auto origDeclRef =
SILDeclRef(origAFD).asForeign(requiresForeignEntryPoint(origAFD));
auto *origFn = getFunction(origDeclRef, NotForDefinition);
auto witnessGenSig =
autodiff::getDifferentiabilityWitnessGenericSignature(
origAFD->getGenericSignature(), AFD->getGenericSignature());
auto *resultIndices =
autodiff::getFunctionSemanticResultIndices(origAFD,
derivAttr->getParameterIndices());
AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices,
witnessGenSig);
emitDifferentiabilityWitness(origAFD, origFn,
DifferentiabilityKind::Reverse, config, jvp,
vjp, derivAttr);
}
};
if (auto *accessor = dyn_cast<AccessorDecl>(AFD))
if (accessor->isGetter())
@@ -1492,6 +1465,36 @@ void SILGenModule::emitAbstractFuncDecl(AbstractFunctionDecl *AFD) {
SILDeclRef::BackDeploymentKind::Thunk);
emitBackDeploymentThunk(thunk);
}

// Emit differentiability witness for the function referenced in
// @derivative(of:) attribute registering current function as VJP / JVP.
for (auto *derivAttr : AFD->getAttrs().getAttributes<DerivativeAttr>()) {
auto *F = getFunction(SILDeclRef(AFD), NotForDefinition);
SILFunction *jvp = nullptr, *vjp = nullptr;
switch (derivAttr->getDerivativeKind()) {
case AutoDiffDerivativeFunctionKind::JVP:
jvp = F;
break;
case AutoDiffDerivativeFunctionKind::VJP:
vjp = F;
break;
}
auto *origAFD = derivAttr->getOriginalFunction(getASTContext());
auto origDeclRef =
SILDeclRef(origAFD).asForeign(requiresForeignEntryPoint(origAFD));
auto *origFn = getFunction(origDeclRef, NotForDefinition);
auto witnessGenSig =
autodiff::getDifferentiabilityWitnessGenericSignature(
origAFD->getGenericSignature(), AFD->getGenericSignature());
auto *resultIndices =
autodiff::getFunctionSemanticResultIndices(origAFD,
derivAttr->getParameterIndices());
AutoDiffConfig config(derivAttr->getParameterIndices(), resultIndices,
witnessGenSig);
emitDifferentiabilityWitness(origAFD, origFn,
DifferentiabilityKind::Reverse, config, jvp,
vjp, derivAttr);
}
}

void SILGenModule::emitFunction(FuncDecl *fd) {
16 changes: 15 additions & 1 deletion lib/SILOptimizer/Mandatory/CapturePromotion.cpp
Original file line number Diff line number Diff line change
@@ -1473,9 +1473,23 @@ processPartialApplyInst(SILOptFunctionBuilder &funcBuilder,
funcBuilder, pai, fri, promotableIndices, f->getResilienceExpansion());
worklist.push_back(clonedFn);

SILFunction *origFn = fri->getReferencedFunction();
for (auto *w : mod.lookUpDifferentiabilityWitnessesForFunction(
origFn->getName())) {
assert(!w->getJVP() && !w->getVJP() && "does not expect custom derivatives here");
auto linkage = stripExternalFromLinkage(clonedFn->getLinkage());
SILDifferentiabilityWitness::createDefinition(
mod, linkage, clonedFn,
w->getKind(), w->getParameterIndices(), w->getResultIndices(),
w->getDerivativeGenericSignature(),
/*jvp*/ nullptr, /*vjp*/ nullptr,
/*isSerialized*/ hasPublicVisibility(clonedFn->getLinkage()),
w->getAttribute());
}

// Mark the original partial apply function as deletable if it doesn't have
// uses later.
fri->getReferencedFunction()->addSemanticsAttr(semantics::DELETE_IF_UNUSED);
origFn->addSemanticsAttr(semantics::DELETE_IF_UNUSED);

// Initialize a SILBuilder and create a function_ref referencing the cloned
// closure.
Original file line number Diff line number Diff line change
@@ -36,6 +36,7 @@ namespace {
struct DiagnosticDeadFunctionEliminator : SILFunctionTransform {
void run() override {
auto *fn = getFunction();
auto &mod = fn->getModule();

// If an earlier pass asked us to eliminate the function body if it's
// unused, and the function is in fact unused, do that now.
@@ -67,6 +68,10 @@ struct DiagnosticDeadFunctionEliminator : SILFunctionTransform {
b.createUnreachable(loc);
}

// Drop differentiability witnesses, if any
if (!mod.lookUpDifferentiabilityWitnessesForFunction(fn->getName()).empty())
mod.eraseAllDifferentiabilityWittnesses(fn);

// If the function has shared linkage, reduce this version to private
// linkage, because we don't want the deleted-body form to win in any
// ODR shootouts.
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: %target-swift-emit-sil -Xllvm -debug-only=differentiation -o /dev/null 2>&1 %s | %FileCheck %s

// The differentiability witness for y in s(h:) will be generated by silgen. However, later the capture
// promotion pass would specialize it since it only captures an integer and therefore does not need to
// box the capture. Ensure we create differentiability witness for specialized function. In addition to
// this, since the original function is not used anymore, the body of it is removed (with only unreachable
// terminator inside). Remove original differentiability witness as it would lead to non-differentiable
// diagnostics further on.

// CHECK-LABEL: differentiability witness for specialized y #1 (_:) in s(h:)
// CHECK: sil_differentiability_witness private [reverse] [parameters 0] [results 0] @$s4null1s1hAA1BVAE_tF1yL_yAA1WVAHFTf2ni_n : $@convention(thin) (@guaranteed W, Int) -> @owned W {
// CHECK-NOT: sil_differentiability_witness private [reverse] [parameters 0] [results 0] @$s4null1s1hAA1BVAE_tF1yL_yAA1WVAHF : $@convention(thin) (@guaranteed W, @guaranteed { var Int }) -> @owned W {

import _Differentiation
struct B: Differentiable{}
struct X { var j = [Float]()}
struct W: Differentiable {
@noDerivative var z: X
var h: B
}
func o<T, R>(_ x: T, _ f: @differentiable(reverse) (T) -> R) -> R {f(x)}
func m<T, R>(_ f: @escaping @differentiable(reverse) (T) -> R) -> @differentiable(reverse) (T) -> R {{ x in o(x, f) }}
@differentiable(reverse)
func s(h: B) -> B {
var (_, e) = (0,0)
@differentiable(reverse)
func y(_ i: W) -> W {
let _ = e;
return i
}
let w = m(y)
return B()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// RUN: %empty-directory(%t)
// RUN: %target-swift-emit-sil -Xllvm -debug-only=differentiation -emit-module -module-name M -emit-module-path %t/M.swiftmodule 2>&1 %s | %FileCheck %s

// The original function Tensor.subscriptIndexPath() is not marked as @differentiable. As a result, no explicit differentiable witness is generated for it.
// However, the witness is generated as a side effect of providing a derivative via @derivative(of: subscriptIndexPath) on _vjpSubscriptIndexPath.
// Since _vjpSubscriptIndexPath is not emitted when -emit-module is used, we need to ensure we still generate a wittness.

import _Differentiation

// CHECK-LABEL: differentiability witness for Tensor.subscriptIndexPath()
// CHECK: sil_differentiability_witness [serialized] [reverse] [parameters 0] [results 0] @$s1M6TensorV18subscriptIndexPathACyF : $@convention(method) (Tensor) -> Tensor {
// CHECK: vjp: @$s1M6TensorV18subscriptIndexPathACyFTJrSpSr : $@convention(method) (Tensor) -> (Tensor, @owned @callee_guaranteed (Tensor) -> Tensor)

// CHECK-LABEL: reverse-mode derivative of Tensor.subscriptIndexPath()
// CHECK: sil [thunk] [always_inline] [ossa] @$s1M6TensorV18subscriptIndexPathACyFTJrSpSr : $@convention(method) (Tensor) -> (Tensor, @owned @callee_guaranteed (Tensor) -> Tensor) {
// CHECK: function_ref Tensor._vjpSubscriptIndexPath()
// CHECK: function_ref @$s1M6TensorV22_vjpSubscriptIndexPathAC5value_A2Cc8pullbacktyF : $@convention(method) (Tensor) -> (Tensor, @owned @callee_guaranteed (Tensor) -> Tensor)

public struct Tensor: Differentiable & AdditiveArithmetic {
@inlinable
func subscriptIndexPath() -> Tensor {
fatalError()
}

@inlinable
@differentiable(reverse, wrt: self)
func subscriptRanges() -> Tensor {
subscriptIndexPath()
}

@usableFromInline
@derivative(of: subscriptIndexPath)
func _vjpSubscriptIndexPath() -> (
value: Tensor, pullback: (Tensor) -> Tensor
) {
fatalError()
}
}