Skip to content

Commit 849c636

Browse files
authored
[AutoDiff] Fix differentiation for non-wrt inout parameters. (#33304)
Fix SIL differential function type calculation to handle non-wrt `inout` parameters. Patch `SILFunctionType::getDifferentiabilityResultIndices` to prevent returning empty result indices for `@differentiable` function types with no formal results where all `inout` parameters are `@noDerivative`. TF-1305 tracks a robust fix. Resolves SR-13305. Exposes TF-1305: parameter/result differentiability hole for `inout` parameters.
1 parent e4f5cc2 commit 849c636

File tree

3 files changed

+146
-4
lines changed

3 files changed

+146
-4
lines changed

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,25 @@ IndexSubset *SILFunctionType::getDifferentiabilityResultIndices() {
235235
resultIndices.push_back(resultAndIndex.index());
236236
// Check `inout` parameters.
237237
for (auto inoutParamAndIndex : enumerate(getIndirectMutatingParameters()))
238-
if (inoutParamAndIndex.value().getDifferentiability() !=
239-
SILParameterDifferentiability::NotDifferentiable)
238+
// FIXME(TF-1305): The `getResults().empty()` condition is a hack.
239+
//
240+
// Currently, an `inout` parameter can either be:
241+
// 1. Both a differentiability parameter and a differentiability result.
242+
// 2. `@noDerivative`: neither a differentiability parameter nor a
243+
// differentiability result.
244+
// However, there is no way to represent an `inout` parameter that:
245+
// 3. Is a differentiability result but not a differentiability parameter.
246+
// 4. Is a differentiability parameter but not a differentiability result.
247+
// This case is not currently expressible and does not yet have clear use
248+
// cases, so supporting it is a non-goal.
249+
//
250+
// See TF-1305 for solution ideas. For now, `@noDerivative` `inout`
251+
// parameters are not treated as differentiability results, unless the
252+
// original function has no formal results, in which case all `inout`
253+
// parameters are treated as differentiability results.
254+
if (getResults().empty() ||
255+
inoutParamAndIndex.value().getDifferentiability() !=
256+
SILParameterDifferentiability::NotDifferentiable)
240257
resultIndices.push_back(getNumResults() + inoutParamAndIndex.index());
241258
auto numSemanticResults =
242259
getNumResults() + getNumIndirectMutatingParameters();
@@ -432,8 +449,9 @@ static CanSILFunctionType getAutoDiffDifferentialType(
432449
}
433450
}
434451
SmallVector<SILResultInfo, 1> differentialResults;
435-
if (inoutParamIndices->isEmpty()) {
436-
for (auto resultIndex : resultIndices->getIndices()) {
452+
for (auto resultIndex : resultIndices->getIndices()) {
453+
// Handle formal original result.
454+
if (resultIndex < originalFnTy->getNumResults()) {
437455
auto &result = originalResults[resultIndex];
438456
auto resultTan =
439457
result.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
@@ -452,8 +470,27 @@ static CanSILFunctionType getAutoDiffDifferentialType(
452470
substReplacements.push_back(resultTanType);
453471
differentialResults.push_back({gpType, resultConv});
454472
}
473+
continue;
455474
}
475+
// Handle original `inout` parameter.
476+
auto inoutParamIndex = resultIndex - originalFnTy->getNumResults();
477+
auto inoutParamIt = std::next(
478+
originalFnTy->getIndirectMutatingParameters().begin(), inoutParamIndex);
479+
auto paramIndex =
480+
std::distance(originalFnTy->getParameters().begin(), &*inoutParamIt);
481+
// If the original `inout` parameter is a differentiability parameter, then
482+
// it already has a corresponding differential parameter. Skip adding a
483+
// corresponding differential result.
484+
if (parameterIndices->contains(paramIndex))
485+
continue;
486+
auto inoutParam = originalFnTy->getParameters()[paramIndex];
487+
auto paramTan = inoutParam.getInterfaceType()->getAutoDiffTangentSpace(
488+
lookupConformance);
489+
assert(paramTan && "Parameter type does not have a tangent space?");
490+
differentialResults.push_back(
491+
{paramTan->getCanonicalType(), ResultConvention::Indirect});
456492
}
493+
457494
SubstitutionMap substitutions;
458495
if (!substGenericParams.empty()) {
459496
auto genericSig =
@@ -714,7 +751,9 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
714751
CanGenericSignature derivativeFnInvocationGenSig,
715752
bool isReabstractionThunk) {
716753
assert(parameterIndices);
754+
assert(!parameterIndices->isEmpty() && "Parameter indices must not be empty");
717755
assert(resultIndices);
756+
assert(!resultIndices->isEmpty() && "Result indices must not be empty");
718757
auto &ctx = getASTContext();
719758

720759
// Look up result in cache.
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify %s
2+
// REQUIRES: asserts
3+
4+
import _Differentiation
5+
6+
// SR-13305: Test protocol witness thunk for `@differentiable` protocol
7+
// requirement, where the required method has a non-wrt `inout` parameter
8+
// that should be treated as a differentiability result.
9+
10+
protocol SR_13305_Protocol {
11+
@differentiable(wrt: x)
12+
func method(x: Float, y: inout Float)
13+
}
14+
15+
struct SR_13305_Struct: SR_13305_Protocol {
16+
@differentiable(wrt: x)
17+
func method(x: Float, y: inout Float) {
18+
y = y * x
19+
}
20+
}
21+
22+
// Original crash:
23+
// Assertion failed: (!array.empty() && "claiming next from empty array!"), function claimNext, file /Users/danielzheng/swift-build/swift/lib/SILGen/SILGenPoly.cpp, line 112.
24+
// Stack dump:
25+
// ...
26+
// 1. Swift version 5.3-dev (LLVM f8bd914aadc2e7b, Swift ba9c433c81d51ea)
27+
// 2. While evaluating request ASTLoweringRequest(Lowering AST to SIL for module main)
28+
// 3. While generating SIL witness table protocol conformance to 'SR_13305_Protocol' (at sr-13305.swift:7:1) for type 'SR_13305_Struct' (declared at [sr-13305.swift:12:1 - line:17:1] RangeText="struct SR_13305_Struct: SR_13305_Protocol {
29+
// @differentiable(wrt: x)
30+
// func method(x: Float, y: inout Float) {
31+
// y = y * x
32+
// }
33+
// ")
34+
// 4. While generating protocol witness thunk SIL function "@AD__$s4main15SR_13305_StructVAA0B15_13305_ProtocolA2aDP6method1x1yySf_SfztFTW_jvp_SUU".
35+
// for 'method(x:y:)' (at sr-13305.swift:14:3)
36+
// 5. While emitting reabstraction thunk in SIL function "@$sSfIegy_S2fIegyd_TR".
37+
// ...
38+
// 7 swift-frontend 0x0000000100fe80ad swift::SILResultInfo const& claimNext<swift::SILResultInfo>(llvm::ArrayRef<swift::SILResultInfo>&) + 93
39+
// 8 swift-frontend 0x0000000100fe6cc0 (anonymous namespace)::ResultPlanner::claimNextInnerResult((anonymous namespace)::ResultPlanner::PlanData&) + 32
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
import DifferentiationUnittest
5+
import StdlibUnittest
6+
7+
var InoutParameterAutoDiffTests = TestSuite("InoutParameterDifferentiation")
8+
9+
// SR-13305: Test function with non-wrt `inout` parameter, which should be
10+
// treated as a differentiability result.
11+
12+
protocol SR_13305_Protocol {
13+
@differentiable(wrt: x)
14+
func method(_ x: Float, _ y: inout Float)
15+
16+
@differentiable(wrt: x)
17+
func genericMethod<T: Differentiable>(_ x: T, _ y: inout T)
18+
}
19+
20+
InoutParameterAutoDiffTests.test("non-wrt inout parameter") {
21+
struct SR_13305_Struct: SR_13305_Protocol {
22+
@differentiable(wrt: x)
23+
func method(_ x: Float, _ y: inout Float) {
24+
y = y * x
25+
}
26+
27+
@differentiable(wrt: x)
28+
func genericMethod<T: Differentiable>(_ x: T, _ y: inout T) {
29+
y = x
30+
}
31+
}
32+
33+
@differentiable(wrt: x)
34+
func foo(_ s: SR_13305_Struct, _ x: Float, _ y: Float) -> Float {
35+
var y = y
36+
s.method(x, &y)
37+
return y
38+
}
39+
40+
@differentiable(wrt: x)
41+
func fooGeneric<T: SR_13305_Protocol>(_ s: T, _ x: Float, _ y: Float) -> Float {
42+
var y = y
43+
s.method(x, &y)
44+
return x
45+
}
46+
47+
let s = SR_13305_Struct()
48+
49+
do {
50+
let (value, (dx, dy)) = valueWithGradient(at: 2, 3, in: { foo(s, $0, $1) })
51+
expectEqual(6, value)
52+
expectEqual((3, 2), (dx, dy))
53+
}
54+
expectEqual((value: 6, gradient: 3), valueWithGradient(at: 2, in: { foo(s, $0, 3) }))
55+
56+
do {
57+
let (value, (dx, dy)) = valueWithGradient(at: 2, 3, in: { fooGeneric(s, $0, $1) })
58+
expectEqual(2, value)
59+
expectEqual((1, 0), (dx, dy))
60+
}
61+
expectEqual((value: 2, gradient: 1), valueWithGradient(at: 2, in: { fooGeneric(s, $0, 3) }))
62+
}
63+
64+
runAllTests()

0 commit comments

Comments
 (0)