Skip to content

Commit 7311068

Browse files
committed
[AutoDiff] Fix return type of subset parameters thunk function
The patch resolves #55703. When the original function has a tuple result type, we should append thunkedLinearMap as the last element of the tuple to match the function declaration. Before this patch, the compiler used to wrap the original result tuple and thunkedLinearMap into another tuple, and caused the verifier error. Before the patch: return %{{.*}} : $((Float, Double), @callee_guaranteed (X.TangentVector) -> Float) After the patch: return %{{.*}} : $(Float, Double, @callee_guaranteed (Float) -> X.TangentVector)
1 parent 9e73dad commit 7311068

File tree

2 files changed

+36
-5
lines changed

2 files changed

+36
-5
lines changed

lib/SILOptimizer/Differentiation/Thunk.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -786,10 +786,8 @@ getOrCreateSubsetParametersThunkForDerivativeFunction(
786786
// Extract all direct results.
787787
SmallVector<SILValue, 8> directResults;
788788
extractAllElements(apply, builder, directResults);
789-
auto originalDirectResults = ArrayRef<SILValue>(directResults).drop_back(1);
790-
auto originalDirectResult =
791-
joinElements(originalDirectResults, builder, apply->getLoc());
792789
auto linearMap = directResults.back();
790+
directResults.pop_back();
793791

794792
auto linearMapType = linearMap->getType().castTo<SILFunctionType>();
795793
auto linearMapTargetType = targetType->getResults()
@@ -830,8 +828,8 @@ getOrCreateSubsetParametersThunkForDerivativeFunction(
830828
0);
831829
if (origFnType->getNumResults() > 0 &&
832830
origFnType->getResults().front().isFormalDirect()) {
833-
auto result =
834-
joinElements({originalDirectResult, thunkedLinearMap}, builder, loc);
831+
directResults.push_back(thunkedLinearMap);
832+
auto result = joinElements(directResults, builder, loc);
835833
builder.createReturn(loc, result);
836834
} else {
837835
builder.createReturn(loc, thunkedLinearMap);
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s
2+
3+
// Verify the result type of a subset parameters thunk matches the declaration:
4+
//
5+
// CHECK: // autodiff subset parameters thunk for forward-mode derivative from f(x:)
6+
// CHECK-NEXT: sil shared [transparent] [thunk] @$s17param_thunk_tuple{{.*}} : $@convention(thin) (X)
7+
// CHECK-SAME: -> (Float, Double, @owned @callee_guaranteed (X.TangentVector) -> Float)
8+
// CHECK: return
9+
// CHECK-SAME: %{{.*}} : $(Float, Double, @callee_guaranteed (X.TangentVector) -> Float)
10+
//
11+
// CHECK: // autodiff subset parameters thunk for reverse-mode derivative from f(x:)
12+
// CHECK-NEXT: sil shared [transparent] [thunk] @$s17param_thunk_tuple{{.*}} : $@convention(thin) (X)
13+
// CHECK-SAME: -> (Float, Double, @owned @callee_guaranteed (Float) -> X.TangentVector)
14+
// CHECK: return
15+
// CHECK-SAME: %{{.*}} : $(Float, Double, @callee_guaranteed (Float) -> X.TangentVector)
16+
17+
import _Differentiation
18+
19+
struct X: Differentiable {
20+
var a: Float
21+
var b: Double
22+
}
23+
24+
@differentiable(reverse)
25+
func f(x: X) -> (Float, Double) {
26+
(x.a, x.b)
27+
}
28+
29+
@differentiable(reverse)
30+
func g1(x: X) -> Float {
31+
f(x: x).0
32+
}
33+

0 commit comments

Comments
 (0)