Skip to content

Commit 6aacd72

Browse files
committed
Do not consider non-wrt semantic result parameters as semantic results
Fixes #67174
1 parent 3d130a5 commit 6aacd72

15 files changed

+84
-132
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,10 +253,9 @@ struct AutoDiffSemanticFunctionResultType {
253253
Type type;
254254
unsigned index : 30;
255255
bool isSemanticResultParameter : 1;
256-
bool isWrtParam : 1;
257256

258-
AutoDiffSemanticFunctionResultType(Type t, unsigned idx, bool param, bool wrt)
259-
: type(t), index(idx), isSemanticResultParameter(param), isWrtParam(wrt) { }
257+
AutoDiffSemanticFunctionResultType(Type t, unsigned idx, bool param)
258+
: type(t), index(idx), isSemanticResultParameter(param) { }
260259
};
261260

262261
/// Key for caching SIL derivative function types.

include/swift/AST/Types.h

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3500,19 +3500,16 @@ class AnyFunctionType : public TypeBase {
35003500
/// Preconditions:
35013501
/// - Parameters corresponding to parameter indices must conform to
35023502
/// `Differentiable`.
3503-
/// - There is one semantic function result type: either the formal original
3504-
/// result or an `inout` parameter. It must conform to `Differentiable`.
3503+
/// - There are semantic function result type: either the formal original
3504+
/// result or a "wrt" semantic result parameter.
35053505
///
35063506
/// Differential typing rules: takes "wrt" parameter derivatives and returns a
35073507
/// "wrt" result derivative.
35083508
///
35093509
/// - Case 1: original function has no `inout` parameters.
35103510
/// - Original: `(T0, T1, ...) -> R`
35113511
/// - Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan`
3512-
/// - Case 2: original function has a non-wrt `inout` parameter.
3513-
/// - Original: `(T0, inout T1, ...) -> Void`
3514-
/// - Differential: `(T0.Tan, ...) -> T1.Tan`
3515-
/// - Case 3: original function has a wrt `inout` parameter.
3512+
/// - Case 2: original function has a wrt `inout` parameter.
35163513
/// - Original: `(T0, inout T1, ...) -> Void`
35173514
/// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
35183515
///
@@ -3522,10 +3519,7 @@ class AnyFunctionType : public TypeBase {
35223519
/// - Case 1: original function has no `inout` parameters.
35233520
/// - Original: `(T0, T1, ...) -> R`
35243521
/// - Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)`
3525-
/// - Case 2: original function has a non-wrt `inout` parameter.
3526-
/// - Original: `(T0, inout T1, ...) -> Void`
3527-
/// - Pullback: `(T1.Tan) -> (T0.Tan, ...)`
3528-
/// - Case 3: original function has a wrt `inout` parameter.
3522+
/// - Case 2: original function has a wrt `inout` parameter.
35293523
/// - Original: `(T0, inout T1, ...) -> Void`
35303524
/// - Pullback: `(inout T1.Tan) -> (T0.Tan, ...)`
35313525
///

lib/AST/AutoDiff.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -199,19 +199,16 @@ void autodiff::getFunctionSemanticResults(
199199
if (formalResultType->is<TupleType>()) {
200200
for (auto elt : formalResultType->castTo<TupleType>()->getElements()) {
201201
resultTypes.emplace_back(elt.getType(), resultIdx++,
202-
/*isInout*/ false, /*isWrt*/ false);
202+
/*isParameter*/ false);
203203
}
204204
} else {
205205
resultTypes.emplace_back(formalResultType, resultIdx++,
206-
/*isInout*/ false, /*isWrt*/ false);
206+
/*isParameter*/ false);
207207
}
208208
}
209209

210-
bool addNonWrts = resultTypes.empty();
211-
212210
// Collect wrt semantic result (`inout`) parameters as
213-
// semantic results As an extention, collect all (including non-wrt) inouts as
214-
// results for functions returning void.
211+
// semantic results
215212
auto collectSemanticResults = [&](const AnyFunctionType *functionType,
216213
unsigned curryOffset = 0) {
217214
for (auto paramAndIndex : enumerate(functionType->getParams())) {
@@ -221,10 +218,9 @@ void autodiff::getFunctionSemanticResults(
221218
unsigned idx = paramAndIndex.index() + curryOffset;
222219
assert(idx < parameterIndices->getCapacity() &&
223220
"invalid parameter index");
224-
bool isWrt = parameterIndices->contains(idx);
225-
if (addNonWrts || isWrt)
221+
if (parameterIndices->contains(idx))
226222
resultTypes.emplace_back(paramAndIndex.value().getPlainType(),
227-
resultIdx, /*isInout*/ true, isWrt);
223+
resultIdx, /*isParameter*/ true);
228224
resultIdx += 1;
229225
}
230226
};

lib/AST/Type.cpp

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5557,7 +5557,7 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
55575557
return llvm::make_error<DerivativeFunctionTypeError>(
55585558
this, DerivativeFunctionTypeError::Kind::NoSemanticResults);
55595559

5560-
// Accumulate non-inout result tangent spaces.
5560+
// Accumulate non-semantic result tangent spaces.
55615561
SmallVector<Type, 1> resultTanTypes, inoutTanTypes;
55625562
for (auto i : range(originalResults.size())) {
55635563
auto originalResult = originalResults[i];
@@ -5578,14 +5578,8 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
55785578

55795579
if (!originalResult.isSemanticResultParameter)
55805580
resultTanTypes.push_back(resultTan->getType());
5581-
else if (originalResult.isSemanticResultParameter && !originalResult.isWrtParam)
5582-
inoutTanTypes.push_back(resultTan->getType());
55835581
}
55845582

5585-
// Treat non-wrt inouts as semantic results for functions returning Void
5586-
if (resultTanTypes.empty())
5587-
resultTanTypes = inoutTanTypes;
5588-
55895583
// Compute the result linear map function type.
55905584
FunctionType *linearMapType;
55915585
switch (kind) {
@@ -5596,11 +5590,7 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
55965590
// - Original: `(T0, T1, ...) -> R`
55975591
// - Differential: `(T0.Tan, T1.Tan, ...) -> R.Tan`
55985592
//
5599-
// Case 2: original function has a non-wrt `inout` parameter.
5600-
// - Original: `(T0, inout T1, ...) -> Void`
5601-
// - Differential: `(T0.Tan, ...) -> T1.Tan`
5602-
//
5603-
// Case 3: original function has a wrt `inout` parameter.
5593+
// Case 2: original function has a wrt `inout` parameter.
56045594
// - Original: `(T0, inout T1, ...) -> Void`
56055595
// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
56065596
SmallVector<AnyFunctionType::Param, 4> differentialParams;
@@ -5647,11 +5637,7 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
56475637
// - Original: `(T0, T1, ...) -> R`
56485638
// - Pullback: `R.Tan -> (T0.Tan, T1.Tan, ...)`
56495639
//
5650-
// Case 2: original function has a non-wrt `inout` parameter.
5651-
// - Original: `(T0, inout T1, ...) -> Void`
5652-
// - Pullback: `(T1.Tan) -> (T0.Tan, ...)`
5653-
//
5654-
// Case 3: original function has wrt `inout` parameters.
5640+
// Case 2: original function has wrt `inout` parameters.
56555641
// - Original: `(T0, inout T1, ...) -> R`
56565642
// - Pullback: `(R.Tan, inout T1.Tan) -> (T0.Tan, ...)`
56575643
SmallVector<TupleTypeElt, 4> pullbackResults;

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -253,11 +253,8 @@ IndexSubset *SILFunctionType::getDifferentiabilityResultIndices() {
253253
// cases, so supporting it is a non-goal.
254254
//
255255
// See TF-1305 for solution ideas. For now, `@noDerivative` `inout`
256-
// parameters are not treated as differentiability results, unless the
257-
// original function has no formal results, in which case all `inout`
258-
// parameters are treated as differentiability results.
259-
if (resultIndices.empty() ||
260-
resultParamAndIndex.value().getDifferentiability() !=
256+
// parameters are not treated as differentiability results.
257+
if (resultParamAndIndex.value().getDifferentiability() !=
261258
SILParameterDifferentiability::NotDifferentiable)
262259
resultIndices.push_back(getNumResults() + resultParamAndIndex.index());
263260

test/AutoDiff/SILGen/inout_differentiability_witness.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ func test3(x: Int, y: inout DiffableStruct, z: Float) -> (Float, Float) { return
1717
@differentiable(reverse, wrt: y)
1818
func test4(x: Int, y: inout DiffableStruct, z: Float) -> Void { }
1919

20-
@differentiable(reverse, wrt: z)
20+
@differentiable(reverse, wrt: (y, z))
2121
func test5(x: Int, y: inout DiffableStruct, z: Float) -> Void { }
2222

2323
@differentiable(reverse, wrt: (y, z))
@@ -48,9 +48,9 @@ func test6(x: Int, y: inout DiffableStruct, z: Float) -> (Float, Float) { return
4848
// CHECK: }
4949

5050
// CHECK-LABEL: differentiability witness for test5(x:y:z:)
51-
// CHECK: sil_differentiability_witness hidden [reverse] [parameters 2] [results 0] @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftF : $@convention(thin) (Int, @inout DiffableStruct, Float) -> () {
52-
// CHECK: jvp: @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftFTJfUUSpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (Float) -> @out DiffableStruct.TangentVector
53-
// CHECK: vjp: @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftFTJrUUSpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@in_guaranteed DiffableStruct.TangentVector) -> Float
51+
// CHECK: sil_differentiability_witness hidden [reverse] [parameters 1 2] [results 0] @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftF : $@convention(thin) (Int, @inout DiffableStruct, Float) -> () {
52+
// CHECK: jvp: @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftFTJfUSSpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@inout DiffableStruct.TangentVector, Float) -> ()
53+
// CHECK: vjp: @$s31inout_differentiability_witness5test51x1y1zySi_AA14DiffableStructVzSftFTJrUSSpSr : $@convention(thin) (Int, @inout DiffableStruct, Float) -> @owned @callee_guaranteed (@inout DiffableStruct.TangentVector) -> Float
5454
// CHECK: }
5555

5656
// CHECK-LABEL: differentiability witness for test6(x:y:z:)

test/AutoDiff/SILGen/witness_table.swift

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ protocol Protocol: Differentiable {
1212
@differentiable(reverse)
1313
var property: Float { get set }
1414

15-
@differentiable(reverse, wrt: x)
15+
@differentiable(reverse, wrt: (self, x))
1616
subscript(_ x: Float, _ y: Float) -> Float { get set }
1717
}
1818

@@ -82,22 +82,22 @@ struct Struct: Protocol {
8282
// CHECK: apply [[VJP_FN]]
8383
// CHECK: }
8484

85-
@differentiable(reverse, wrt: x)
85+
@differentiable(reverse, wrt: (self, x))
8686
subscript(_ x: Float, _ y: Float) -> Float {
8787
get { x }
8888
set {}
8989
}
9090

91-
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_jvp_SUU : $@convention(witness_method: Protocol) (Float, Float, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
91+
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_jvp_SUS : $@convention(witness_method: Protocol) (Float, Float, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float, @in_guaranteed τ_0_0) -> Float for <DummyTangentVector>)
9292
// CHECK: [[ORIG_FN:%.*]] = function_ref @$s13witness_table6StructVyS2f_Sftcig : $@convention(method) (Float, Float, Struct) -> Float
93-
// CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [results 0] [[ORIG_FN]]
93+
// CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0 2] [results 0] [[ORIG_FN]]
9494
// CHECK: [[JVP_FN:%.*]] = differentiable_function_extract [jvp] [[DIFF_FN]]
9595
// CHECK: apply [[JVP_FN]]
9696
// CHECK: }
9797

98-
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_vjp_SUU : $@convention(witness_method: Protocol) (Float, Float, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
98+
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_vjp_SUS : $@convention(witness_method: Protocol) (Float, Float, @in_guaranteed Struct) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> (Float, @out τ_0_0) for <DummyTangentVector>)
9999
// CHECK: [[ORIG_FN:%.*]] = function_ref @$s13witness_table6StructVyS2f_Sftcig : $@convention(method) (Float, Float, Struct) -> Float
100-
// CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0] [results 0] [[ORIG_FN]]
100+
// CHECK: [[DIFF_FN:%.*]] = differentiable_function [parameters 0 2] [results 0] [[ORIG_FN]]
101101
// CHECK: [[VJP_FN:%.*]] = differentiable_function_extract [vjp] [[DIFF_FN]]
102102
// CHECK: apply [[VJP_FN]]
103103
// CHECK: }
@@ -118,10 +118,10 @@ struct Struct: Protocol {
118118
// CHECK-NEXT: method #Protocol.property!setter.vjp.SS.<Self where Self : Protocol>: <Self where Self : Protocol> (inout Self) -> (Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDP8propertySfvsTW_vjp_SS
119119
// CHECK-NEXT: method #Protocol.property!modify: <Self where Self : Protocol> (inout Self) -> () -> () : @$s13witness_table6StructVAA8ProtocolA2aDP8propertySfvMTW
120120
// CHECK-NEXT: method #Protocol.subscript!getter: <Self where Self : Protocol> (Self) -> (Float, Float) -> Float : @$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW
121-
// CHECK-NEXT: method #Protocol.subscript!getter.jvp.SUU.<Self where Self : Protocol>: <Self where Self : Protocol> (Self) -> (Float, Float) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_jvp_SU
122-
// CHECK-NEXT: method #Protocol.subscript!getter.vjp.SUU.<Self where Self : Protocol>: <Self where Self : Protocol> (Self) -> (Float, Float) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_vjp_SUU
121+
// CHECK-NEXT: method #Protocol.subscript!getter.jvp.SUS.<Self where Self : Protocol>: <Self where Self : Protocol> (Self) -> (Float, Float) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_jvp_SUS
122+
// CHECK-NEXT: method #Protocol.subscript!getter.vjp.SUS.<Self where Self : Protocol>: <Self where Self : Protocol> (Self) -> (Float, Float) -> Float : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcigTW_vjp_SUS
123123
// CHECK-NEXT: method #Protocol.subscript!setter: <Self where Self : Protocol> (inout Self) -> (Float, Float, Float) -> () : @$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW
124-
// CHECK-NEXT: method #Protocol.subscript!setter.jvp.USUU.<Self where Self : Protocol>: <Self where Self : Protocol> (inout Self) -> (Float, Float, Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW_jvp_USUU
125-
// CHECK-NEXT: method #Protocol.subscript!setter.vjp.USUU.<Self where Self : Protocol>: <Self where Self : Protocol> (inout Self) -> (Float, Float, Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW_vjp_USUU
124+
// CHECK-NEXT: method #Protocol.subscript!setter.jvp.USUS.<Self where Self : Protocol>: <Self where Self : Protocol> (inout Self) -> (Float, Float, Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW_jvp_USUS
125+
// CHECK-NEXT: method #Protocol.subscript!setter.vjp.USUS.<Self where Self : Protocol>: <Self where Self : Protocol> (inout Self) -> (Float, Float, Float) -> () : @AD__$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftcisTW_vjp_USUS
126126
// CHECK-NEXT: method #Protocol.subscript!modify: <Self where Self : Protocol> (inout Self) -> (Float, Float) -> () : @$s13witness_table6StructVAA8ProtocolA2aDPyS2f_SftciMTW
127127
// CHECK: }

0 commit comments

Comments
 (0)