Skip to content

Commit 3bc3e30

Browse files
committed
Added tuple result tests, extracting tuple elements as semantic result types.
1 parent 72866b6 commit 3bc3e30

File tree

4 files changed

+76
-2
lines changed

4 files changed

+76
-2
lines changed

lib/AST/AutoDiff.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,16 @@ void autodiff::getFunctionSemanticResultTypes(
196196
functionType->getResult()->getAs<AnyFunctionType>()) {
197197
formalResultType = resultFunctionType->getResult();
198198
}
199-
if (!formalResultType->isEqual(ctx.TheEmptyTupleType))
200-
result.push_back({remap(formalResultType), /*isInout*/ false});
199+
if (!formalResultType->isEqual(ctx.TheEmptyTupleType)) {
200+
// Separate tuple elements into individual results.
201+
if (formalResultType->is<TupleType>()) {
202+
for (auto elt : formalResultType->castTo<TupleType>()->getElements()) {
203+
result.push_back({remap(elt.getType()), /*isInout*/ false});
204+
}
205+
} else {
206+
result.push_back({remap(formalResultType), /*isInout*/ false});
207+
}
208+
}
201209

202210
// Collect `inout` parameters as semantic results.
203211
for (auto param : functionType->getParams())

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,22 @@ extension InoutParameters {
871871
) { fatalError() }
872872
}
873873

874+
// Test tuple results.
875+
876+
extension InoutParameters {
877+
func tupleResults(_ x: Float) -> (Float, Float) { (x, x) }
878+
@derivative(of: tupleResults, wrt: x)
879+
func vjpTupleResults(_ x: Float) -> (
880+
value: (Float, Float), pullback: (Float, Float) -> Float
881+
) { fatalError() }
882+
883+
func tupleResultsInt(_ x: Float) -> (Int, Float) { (1, x) }
884+
@derivative(of: tupleResultsInt, wrt: x)
885+
func vjpTupleResults(_ x: Float) -> (
886+
value: (Int, Float), pullback: (Float) -> Float
887+
) { fatalError() }
888+
}
889+
874890
// Test original/derivative function `inout` parameter mismatches.
875891

876892
extension InoutParameters {

test/AutoDiff/Sema/differentiable_attr_type_checking.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,19 @@ extension InoutParameters {
685685
mutating func mutatingMethod(_ other: Self) -> Self {}
686686
}
687687

688+
// Test tuple results.
689+
690+
extension InoutParameters {
691+
@differentiable(reverse)
692+
static func tupleResults(_ x: Self) -> (Self, Self) {}
693+
694+
@differentiable(reverse)
695+
static func tupleResultsInt(_ x: Self) -> (Int, Self) {}
696+
697+
@differentiable(reverse)
698+
static func tupleResultsInt2(_ x: Self) -> (Self, Int) {}
699+
}
700+
688701
// Test accessors: `set`, `_read`, `_modify`.
689702

690703
struct Accessors: Differentiable {

test/AutoDiff/validation-test/simple_math.swift

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,43 @@ SimpleMathTests.test("MultipleResultsWithCustomPullback") {
147147
expectEqual((10, 5), gradient(at: 5, 10, of: multiply_swapCustom))
148148
}
149149

150+
// Test functions returning tuples.
151+
@differentiable(reverse)
152+
func swapTuple(_ x: Float, _ y: Float) -> (Float, Float) {
153+
return (y, x)
154+
}
155+
156+
@differentiable(reverse)
157+
func swapTupleCustom(_ x: Float, _ y: Float) -> (Float, Float) {
158+
return (y, x)
159+
}
160+
@derivative(of: swapTupleCustom)
161+
func vjpSwapTupleCustom(_ x: Float, _ y: Float) -> (
162+
value: (Float, Float), pullback: (Float, Float) -> (Float, Float)
163+
) {
164+
return (swapTupleCustom(x, y), {v1, v2 in
165+
return (v2, v1)
166+
})
167+
}
168+
169+
SimpleMathTests.test("ReturningTuples") {
170+
func multiply_swapTuple(_ x: Float, _ y: Float) -> Float {
171+
let result = swapTuple(x, y)
172+
return result.0 * result.1
173+
}
174+
175+
expectEqual((4, 3), gradient(at: 3, 4, of: multiply_swapTuple))
176+
expectEqual((10, 5), gradient(at: 5, 10, of: multiply_swapTuple))
177+
178+
func multiply_swapTupleCustom(_ x: Float, _ y: Float) -> Float {
179+
let result = swapTupleCustom(x, y)
180+
return result.0 * result.1
181+
}
182+
183+
expectEqual((4, 3), gradient(at: 3, 4, of: multiply_swapTupleCustom))
184+
expectEqual((10, 5), gradient(at: 5, 10, of: multiply_swapTupleCustom))
185+
}
186+
150187
SimpleMathTests.test("CaptureLocal") {
151188
let z: Float = 10
152189
func foo(_ x: Float) -> Float {

0 commit comments

Comments
 (0)