diff --git a/lib/SILOptimizer/Differentiation/Thunk.cpp b/lib/SILOptimizer/Differentiation/Thunk.cpp index 52b505681ed43..487bce2929183 100644 --- a/lib/SILOptimizer/Differentiation/Thunk.cpp +++ b/lib/SILOptimizer/Differentiation/Thunk.cpp @@ -786,10 +786,8 @@ getOrCreateSubsetParametersThunkForDerivativeFunction( // Extract all direct results. SmallVector directResults; extractAllElements(apply, builder, directResults); - auto originalDirectResults = ArrayRef(directResults).drop_back(1); - auto originalDirectResult = - joinElements(originalDirectResults, builder, apply->getLoc()); auto linearMap = directResults.back(); + directResults.pop_back(); auto linearMapType = linearMap->getType().castTo(); auto linearMapTargetType = targetType->getResults() @@ -830,8 +828,8 @@ getOrCreateSubsetParametersThunkForDerivativeFunction( 0); if (origFnType->getNumResults() > 0 && origFnType->getResults().front().isFormalDirect()) { - auto result = - joinElements({originalDirectResult, thunkedLinearMap}, builder, loc); + directResults.push_back(thunkedLinearMap); + auto result = joinElements(directResults, builder, loc); builder.createReturn(loc, result); } else { builder.createReturn(loc, thunkedLinearMap); diff --git a/test/AutoDiff/SILOptimizer/param_thunk_tuple.swift b/test/AutoDiff/SILOptimizer/param_thunk_tuple.swift new file mode 100644 index 0000000000000..8a0402a099989 --- /dev/null +++ b/test/AutoDiff/SILOptimizer/param_thunk_tuple.swift @@ -0,0 +1,33 @@ +// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s + +// Verify the result type of a subset parameters thunk matches the declaration: +// +// CHECK: // autodiff subset parameters thunk for forward-mode derivative from f(x:) +// CHECK-NEXT: sil shared [transparent] [thunk] @$s17param_thunk_tuple{{.*}} : $@convention(thin) (X) +// CHECK-SAME: -> (Float, Double, @owned @callee_guaranteed (X.TangentVector) -> Float) +// CHECK: return +// CHECK-SAME: %{{.*}} : $(Float, Double, @callee_guaranteed (X.TangentVector) -> Float) +// +// CHECK: // autodiff subset parameters thunk for reverse-mode derivative from f(x:) +// CHECK-NEXT: sil shared [transparent] [thunk] @$s17param_thunk_tuple{{.*}} : $@convention(thin) (X) +// CHECK-SAME: -> (Float, Double, @owned @callee_guaranteed (Float) -> X.TangentVector) +// CHECK: return +// CHECK-SAME: %{{.*}} : $(Float, Double, @callee_guaranteed (Float) -> X.TangentVector) + +import _Differentiation + +struct X: Differentiable { + var a: Float + var b: Double +} + +@differentiable(reverse) +func f(x: X) -> (Float, Double) { + (x.a, x.b) +} + +@differentiable(reverse) +func g1(x: X) -> Float { + f(x: x).0 +} +