Skip to content

Commit 73984ad

Browse files
authored
[AutoDiff] [stdlib] Add differentiable map and reduce methods on 'Array'. (#26023)
Add variants of map and reduce that take a `@differentiable` closure and are themselves differentiable. ```swift extension Array { @differentiable(wrt: self) func differentiableMap<Result: Differentiable>( _ body: @differentiable (Element) -> Result ) -> [Result] @differentiable(wrt: (self, initialResult)) func differentiableReduce<Result: Differentiable>( _ initialResult: Result, _ nextPartialResult: @differentiable (Result, Element) -> Result ) -> Result } ``` Also make `Array.DifferentiableView` conform to `ExpressibleByArrayLiteral` so that tests and user code are easier to write.
1 parent 2ce6f9c commit 73984ad

File tree

3 files changed

+115
-0
lines changed

3 files changed

+115
-0
lines changed

stdlib/public/core/Array.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1994,6 +1994,12 @@ extension Array.DifferentiableView : Equatable where Element : Equatable {
19941994
}
19951995
}
19961996

1997+
extension Array.DifferentiableView : ExpressibleByArrayLiteral {
1998+
public init(arrayLiteral elements: Element...) {
1999+
self.init(elements)
2000+
}
2001+
}
2002+
19972003
extension Array.DifferentiableView : CustomStringConvertible {
19982004
public var description: String {
19992005
return base.description

stdlib/public/core/AutoDiff.swift

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -885,3 +885,76 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic {
885885
_box._move(along: direction._box)
886886
}
887887
}
888+
889+
//===----------------------------------------------------------------------===//
890+
// Differentiable higher order functions for collections
891+
//===----------------------------------------------------------------------===//
892+
893+
public extension Array where Element: Differentiable {
894+
@differentiable(wrt: (self, initialResult), vjp: _vjpDifferentiableReduce)
895+
func differentiableReduce<Result: Differentiable>(
896+
_ initialResult: Result,
897+
_ nextPartialResult: @differentiable (Result, Element) -> Result
898+
) -> Result {
899+
reduce(initialResult, nextPartialResult)
900+
}
901+
902+
@usableFromInline
903+
internal func _vjpDifferentiableReduce<Result: Differentiable>(
904+
_ initialResult: Result,
905+
_ nextPartialResult: @differentiable (Result, Element) -> Result
906+
) -> (value: Result,
907+
pullback: (Result.TangentVector)
908+
-> (Array.TangentVector, Result.TangentVector)) {
909+
var pullbacks:
910+
[(Result.TangentVector) -> (Result.TangentVector, Element.TangentVector)]
911+
= []
912+
let count = self.count
913+
pullbacks.reserveCapacity(count)
914+
var result = initialResult
915+
for element in self {
916+
let (y, pb) =
917+
Swift.valueWithPullback(at: result, element, in: nextPartialResult)
918+
result = y
919+
pullbacks.append(pb)
920+
}
921+
return (value: result, pullback: { tangent in
922+
var resultTangent = tangent
923+
var elementTangents = TangentVector([])
924+
elementTangents.base.reserveCapacity(count)
925+
for pullback in pullbacks.reversed() {
926+
let (newResultTangent, elementTangent) = pullback(resultTangent)
927+
resultTangent = newResultTangent
928+
elementTangents.base.append(elementTangent)
929+
}
930+
return (TangentVector(elementTangents.base.reversed()), resultTangent)
931+
})
932+
}
933+
}
934+
935+
public extension Array where Element: Differentiable {
936+
@differentiable(wrt: self, vjp: _vjpDifferentiableMap)
937+
func differentiableMap<Result: Differentiable>(
938+
_ body: @differentiable (Element) -> Result
939+
) -> [Result] {
940+
map(body)
941+
}
942+
943+
@usableFromInline
944+
internal func _vjpDifferentiableMap<Result: Differentiable>(
945+
_ body: @differentiable (Element) -> Result
946+
) -> (value: [Result],
947+
pullback: (Array<Result>.TangentVector) -> Array.TangentVector) {
948+
var values: [Result] = []
949+
var pullbacks: [(Result.TangentVector) -> Element.TangentVector] = []
950+
for x in self {
951+
let (y, pb) = Swift.valueWithPullback(at: x, in: body)
952+
values.append(y)
953+
pullbacks.append(pb)
954+
}
955+
func pullback(_ tans: Array<Result>.TangentVector) -> Array.TangentVector {
956+
.init(zip(tans.base, pullbacks).map { tan, pb in pb(tan) })
957+
}
958+
return (value: values, pullback: pullback)
959+
}
960+
}

test/AutoDiff/collection_hofs.swift

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
import StdlibUnittest
5+
#if os(macOS)
6+
import Darwin.C
7+
#else
8+
import Glibc
9+
#endif
10+
11+
// Test suite for differentiable higher order functions for collections
12+
// such as `differentiableMap(_:)` and `differentiableReduce(_:)`.
13+
var CollectionHOFTests = TestSuite("CollectionHOF")
14+
15+
let xx: [Float] = [1, 2, 3, 4, 5]
16+
17+
CollectionHOFTests.test("differentiableMap(_:)") {
18+
func double(_ xx: [Float]) -> [Float] {
19+
xx.differentiableMap { $0 * $0 }
20+
}
21+
expectEqual([], pullback(at: xx, in: double)([]))
22+
expectEqual([0], pullback(at: xx, in: double)([0]))
23+
expectEqual([2], pullback(at: xx, in: double)([1]))
24+
expectEqual([2, 4, 6, 8, 10], pullback(at: xx, in: double)([1, 1, 1, 1, 1]))
25+
}
26+
27+
CollectionHOFTests.test("differentiableReduce(_:)") {
28+
func product(_ xx: [Float]) -> Float {
29+
xx.differentiableReduce(1) { $0 * $1 }
30+
}
31+
expectEqual([1], gradient(at: [0], in: product))
32+
expectEqual([1], gradient(at: [1], in: product))
33+
expectEqual([120, 60, 40, 30, 24], gradient(at: xx, in: product))
34+
}
35+
36+
runAllTests()

0 commit comments

Comments
 (0)