Skip to content

Commit ce5e0e7

Browse files
author
marcrasi
authored
[AutoDiff] Fix SR-12641: Handle address-only types in derivative fn types (#31496)
* The update in `SILFunctionType.cpp` fixes SR-12641 by making address-only parameters/results in differentials have indirect convention. * I updated the crasher test to use a resilient struct defined in the test, instead of `Tracked<Float>`, so that the test does not need to depend on `DifferentiationUnittest`. * The update in `VJPEmitter.cpp` fixes a similar issue with pullbacks that I discovered while investigating. * I added code that exposes this new issue to the SR-12641 crasher test.
1 parent 6a5d8bb commit ce5e0e7

File tree

4 files changed

+144
-50
lines changed

4 files changed

+144
-50
lines changed

lib/SIL/IR/SILFunctionType.cpp

+62-6
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,59 @@ getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices,
303303
static CanSILFunctionType
304304
getAutoDiffDifferentialType(SILFunctionType *originalFnTy,
305305
IndexSubset *parameterIndices, unsigned resultIndex,
306-
LookupConformanceFn lookupConformance) {
306+
LookupConformanceFn lookupConformance,
307+
TypeConverter &TC) {
308+
// Given the tangent type and the corresponding original parameter's
309+
// convention, returns the tangent parameter's convention.
310+
auto getTangentParameterConvention =
311+
[&](CanType tanType,
312+
ParameterConvention origParamConv) -> ParameterConvention {
313+
tanType =
314+
tanType->getCanonicalType(originalFnTy->getSubstGenericSignature());
315+
AbstractionPattern pattern(originalFnTy->getSubstGenericSignature(),
316+
tanType);
317+
auto &tl =
318+
TC.getTypeLowering(pattern, tanType, TypeExpansionContext::minimal());
319+
// When the tangent type is address only, we must ensure that the tangent
320+
// parameter's convention is indirect.
321+
if (tl.isAddressOnly() && !isIndirectFormalParameter(origParamConv)) {
322+
switch (origParamConv) {
323+
case ParameterConvention::Direct_Guaranteed:
324+
return ParameterConvention::Indirect_In_Guaranteed;
325+
case ParameterConvention::Direct_Owned:
326+
case ParameterConvention::Direct_Unowned:
327+
return ParameterConvention::Indirect_In;
328+
default:
329+
llvm_unreachable("unhandled parameter convention");
330+
}
331+
}
332+
return origParamConv;
333+
};
334+
335+
// Given the tangent type and the corresponding original result's convention,
336+
// returns the tangent result's convention.
337+
auto getTangentResultConvention =
338+
[&](CanType tanType,
339+
ResultConvention origResConv) -> ResultConvention {
340+
tanType =
341+
tanType->getCanonicalType(originalFnTy->getSubstGenericSignature());
342+
AbstractionPattern pattern(originalFnTy->getSubstGenericSignature(),
343+
tanType);
344+
auto &tl =
345+
TC.getTypeLowering(pattern, tanType, TypeExpansionContext::minimal());
346+
// When the tangent type is address only, we must ensure that the tangent
347+
// result's convention is indirect.
348+
if (tl.isAddressOnly() && !isIndirectFormalResult(origResConv)) {
349+
switch (origResConv) {
350+
case ResultConvention::Owned:
351+
return ResultConvention::Indirect;
352+
default:
353+
llvm_unreachable("unhandled result convention");
354+
}
355+
}
356+
return origResConv;
357+
};
358+
307359
auto &ctx = originalFnTy->getASTContext();
308360
SmallVector<GenericTypeParamType *, 4> substGenericParams;
309361
SmallVector<Requirement, 4> substRequirements;
@@ -324,15 +376,17 @@ getAutoDiffDifferentialType(SILFunctionType *originalFnTy,
324376
param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
325377
assert(paramTan && "Parameter type does not have a tangent space?");
326378
auto paramTanType = paramTan->getCanonicalType();
379+
auto paramConv = getTangentParameterConvention(paramTanType,
380+
param.getConvention());
327381
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
328382
differentialParams.push_back(
329-
{paramTan->getCanonicalType(), param.getConvention()});
383+
{paramTan->getCanonicalType(), paramConv});
330384
} else {
331385
auto gpIndex = substGenericParams.size();
332386
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
333387
substGenericParams.push_back(gpType);
334388
substReplacements.push_back(paramTanType);
335-
differentialParams.push_back({gpType, param.getConvention()});
389+
differentialParams.push_back({gpType, paramConv});
336390
}
337391
}
338392
SmallVector<SILResultInfo, 1> differentialResults;
@@ -342,15 +396,17 @@ getAutoDiffDifferentialType(SILFunctionType *originalFnTy,
342396
result.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
343397
assert(resultTan && "Result type does not have a tangent space?");
344398
auto resultTanType = resultTan->getCanonicalType();
399+
auto resultConv = getTangentResultConvention(resultTanType,
400+
result.getConvention());
345401
if (!resultTanType->hasArchetype() && !resultTanType->hasTypeParameter()) {
346402
differentialResults.push_back(
347-
{resultTan->getCanonicalType(), result.getConvention()});
403+
{resultTan->getCanonicalType(), resultConv});
348404
} else {
349405
auto gpIndex = substGenericParams.size();
350406
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
351407
substGenericParams.push_back(gpType);
352408
substReplacements.push_back(resultTanType);
353-
differentialResults.push_back({gpType, result.getConvention()});
409+
differentialResults.push_back({gpType, resultConv});
354410
}
355411
}
356412
SubstitutionMap substitutions;
@@ -620,7 +676,7 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
620676
case AutoDiffDerivativeFunctionKind::JVP:
621677
closureType =
622678
getAutoDiffDifferentialType(constrainedOriginalFnTy, parameterIndices,
623-
resultIndex, lookupConformance);
679+
resultIndex, lookupConformance, TC);
624680
break;
625681
case AutoDiffDerivativeFunctionKind::VJP:
626682
closureType =

lib/SILOptimizer/Differentiation/VJPEmitter.cpp

+12-4
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,12 @@ SILFunction *VJPEmitter::createEmptyPullback() {
9797
switch (origResConv) {
9898
case ResultConvention::Owned:
9999
case ResultConvention::Autoreleased:
100-
conv = tl.isTrivial() ? ParameterConvention::Direct_Unowned
101-
: ParameterConvention::Direct_Guaranteed;
100+
if (tl.isAddressOnly()) {
101+
conv = ParameterConvention::Indirect_In_Guaranteed;
102+
} else {
103+
conv = tl.isTrivial() ? ParameterConvention::Direct_Unowned
104+
: ParameterConvention::Direct_Guaranteed;
105+
}
102106
break;
103107
case ResultConvention::Unowned:
104108
case ResultConvention::UnownedInnerPointer:
@@ -123,8 +127,12 @@ SILFunction *VJPEmitter::createEmptyPullback() {
123127
case ParameterConvention::Direct_Owned:
124128
case ParameterConvention::Direct_Guaranteed:
125129
case ParameterConvention::Direct_Unowned:
126-
conv =
127-
tl.isTrivial() ? ResultConvention::Unowned : ResultConvention::Owned;
130+
if (tl.isAddressOnly()) {
131+
conv = ResultConvention::Indirect;
132+
} else {
133+
conv = tl.isTrivial() ? ResultConvention::Unowned
134+
: ResultConvention::Owned;
135+
}
128136
break;
129137
case ParameterConvention::Indirect_In:
130138
case ParameterConvention::Indirect_Inout:

test/AutoDiff/compiler_crashers/sr12641-silgen-immutable-address-use-verification-failure.swift

-40
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// RUN: %target-swift-frontend -enable-resilience -emit-sil -verify %s
2+
// REQUIRES: asserts
3+
4+
// SR-12641: SILGen verification error regarding `ImmutableAddressUseVerifier` and AutoDiff-generated code.
5+
6+
import _Differentiation
7+
8+
public struct Resilient: Differentiable {
9+
var x: Float
10+
}
11+
12+
public class Class: Differentiable {
13+
var x: Resilient
14+
init(_ x: Resilient) {
15+
self.x = x
16+
}
17+
}
18+
19+
public func f(_ c: Class) -> Resilient {
20+
return Resilient(x: 0)
21+
}
22+
23+
_ = pullback(at: Class(Resilient(x: 10)), in: f)
24+
25+
// swift/lib/SIL/Verifier/SILVerifier.cpp:456: bool (anonymous namespace)::ImmutableAddressUseVerifier::isConsumingOrMutatingArgumentConvention(swift::SILArgumentConvention): Assertion `conv.isIndirectConvention() && "Expect an indirect convention"' failed.
26+
// Stack dump:
27+
// ...
28+
// 1. Swift version 5.3-dev (LLVM be43a34c3c, Swift 6d5b2f5220)
29+
// 2. While evaluating request SILGenWholeModuleRequest(SIL Generation for module main)
30+
// 3. While verifying SIL function "@$s4main5ClassC13TangentVectorVAA9ResilientVADVIeggr_AeHIegnr_TR".
31+
// ...
32+
// #8 0x00000000011e7a3e (anonymous namespace)::ImmutableAddressUseVerifier::isConsumingOrMutatingApplyUse(swift::Operand*)
33+
// #9 0x00000000011e6add (anonymous namespace)::ImmutableAddressUseVerifier::isMutatingOrConsuming(swift::SILValue)
34+
// #10 0x00000000011ce0b4 (anonymous namespace)::SILVerifier::visitSILBasicBlock(swift::SILBasicBlock*)
35+
36+
// Related crasher discovered while fixing SR-12641.
37+
38+
class LoadableOriginal<T: Differentiable>: Differentiable {
39+
var x: T
40+
init(_ x: T) { self.x = x }
41+
}
42+
43+
@differentiable
44+
func loadableOriginal<T: AdditiveArithmetic>(_ loadable: LoadableOriginal<T>) -> T {
45+
return T.zero
46+
}
47+
48+
// swift/include/swift/SIL/TypeLowering.h:845: swift::SILType swift::Lowering::TypeConverter::getLoweredLoadableType(swift::Type, swift::TypeExpansionContext, swift::SILModule &): Assertion `(ti.isLoadable() || !SILModuleConventions(M).useLoweredAddresses()) && "unexpected address-only type"' failed.
49+
// Stack dump:
50+
// ...
51+
// 2. While evaluating request ExecuteSILPipelineRequest(Run pipelines { Guaranteed Passes } on SIL for main.main)
52+
// 3. While running pass #153 SILModuleTransform "Differentiation".
53+
// 4. While processing // differentiability witness for loadableOriginal<A>(_:)
54+
// sil_differentiability_witness hidden [parameters 0] [results 0] <T where T : AdditiveArithmetic, T : Differentiable> @$s4main16loadableOriginalyxAA08LoadableC0CyxGs18AdditiveArithmeticRz16_Differentiation14DifferentiableRzlF : $@convention(thin) <T where T : Additive
55+
// Arithmetic, T : Differentiable> (@guaranteed LoadableOriginal<T>) -> @out T {
56+
// }
57+
//
58+
// on SIL function "@$s4main16loadableOriginalyxAA08LoadableC0CyxGs18AdditiveArithmeticRz16_Differentiation14DifferentiableRzlF".
59+
// for 'loadableOriginal(_:)'
60+
// 5. While generating VJP for SIL function "@$s4main16loadableOriginalyxAA08LoadableC0CyxGs18AdditiveArithmeticRz16_Differentiation14DifferentiableRzlF".
61+
// for 'loadableOriginal(_:)'
62+
// 6. While generating pullback for SIL function "@$s4main16loadableOriginalyxAA08LoadableC0CyxGs18AdditiveArithmeticRz16_Differentiation14DifferentiableRzlF".
63+
// for 'loadableOriginal(_:)'
64+
// ...
65+
// #9 0x0000000000f83fbb swift::autodiff::PullbackEmitter::emitZeroDirect(swift::CanType, swift::SILLocation)
66+
// #10 0x0000000000f8248b swift::autodiff::PullbackEmitter::emitZeroDerivativesForNonvariedResult(swift::SILValue)
67+
// #11 0x0000000000f7fcae swift::autodiff::PullbackEmitter::run()
68+
// #12 0x0000000000f3fba4 swift::autodiff::VJPEmitter::run()
69+
// #13 0x0000000000eb1669 (anonymous namespace)::DifferentiationTransformer::canonicalizeDifferentiabilityWitness(swift::SILFunction*, swift::SILDifferentiabilityWitness*, swift::autodiff::DifferentiationInvoker, swift::IsSerialized_t)
70+
// #14 0x0000000000eaea5e (anonymous namespace)::Differentiation::run()

0 commit comments

Comments
 (0)