diff --git a/stdlib/public/Differentiation/ArrayDifferentiation.swift b/stdlib/public/Differentiation/ArrayDifferentiation.swift index fbaef9c34fe80..a4b1149a20f56 100644 --- a/stdlib/public/Differentiation/ArrayDifferentiation.swift +++ b/stdlib/public/Differentiation/ArrayDifferentiation.swift @@ -30,7 +30,7 @@ extension Array.DifferentiableView: Differentiable where Element: Differentiable { /// The viewed array. public var base: [Element] { - get { return _base } + get { _base } _modify { yield &_base } } @@ -58,9 +58,10 @@ where Element: Differentiable { public mutating func move(along direction: TangentVector) { precondition( - base.count == direction.base.count, - "cannot move Array.DifferentiableView with count \(base.count) along " - + "direction with different count \(direction.base.count)") + base.count == direction.base.count, """ + Count mismatch: \(base.count) ('self') and \(direction.base.count) \ + ('direction') + """) for i in base.indices { base[i].move(along: direction.base[i]) } @@ -106,17 +107,15 @@ where Element: AdditiveArithmetic & Differentiable { lhs: Array.DifferentiableView, rhs: Array.DifferentiableView ) -> Array.DifferentiableView { - precondition( - lhs.base.count == 0 || rhs.base.count == 0 - || lhs.base.count == rhs.base.count, - "cannot add Array.DifferentiableViews with different counts: " - + "\(lhs.base.count) and \(rhs.base.count)") if lhs.base.count == 0 { return rhs } if rhs.base.count == 0 { return lhs } + precondition( + lhs.base.count == rhs.base.count, + "Count mismatch: \(lhs.base.count) and \(rhs.base.count)") return Array.DifferentiableView(zip(lhs.base, rhs.base).map(+)) } @@ -124,17 +123,15 @@ where Element: AdditiveArithmetic & Differentiable { lhs: Array.DifferentiableView, rhs: Array.DifferentiableView ) -> Array.DifferentiableView { - precondition( - lhs.base.count == 0 || rhs.base.count == 0 - || lhs.base.count == rhs.base.count, - "cannot subtract Array.DifferentiableViews with different counts: " - + "\(lhs.base.count) and \(rhs.base.count)") if lhs.base.count == 0 { return rhs } if rhs.base.count == 0 { return lhs } + precondition( + lhs.base.count == rhs.base.count, + "Count mismatch: \(lhs.base.count) and \(rhs.base.count)") return Array.DifferentiableView(zip(lhs.base, rhs.base).map(-)) } @@ -202,10 +199,10 @@ extension Array where Element: Differentiable { ) { func pullback(_ v: TangentVector) -> (TangentVector, TangentVector) { precondition( - v.base.count == lhs.count + rhs.count, - "+ should receive gradient with count equal to sum of operand " - + "counts, but counts are: gradient \(v.base.count), " - + "lhs \(lhs.count), rhs \(rhs.count)") + v.base.count == lhs.count + rhs.count, """ + Tangent vector with invalid count; expected to equal the sum of \ + operand counts \(lhs.count) and \(rhs.count) + """) return ( TangentVector([Element.TangentVector](v.base[0..