@@ -30,7 +30,7 @@ extension Array.DifferentiableView: Differentiable
30
30
where Element: Differentiable {
31
31
/// The viewed array.
32
32
public var base : [ Element ] {
33
- get { return _base }
33
+ get { _base }
34
34
_modify { yield & _base }
35
35
}
36
36
@@ -58,9 +58,10 @@ where Element: Differentiable {
58
58
59
59
public mutating func move( along direction: TangentVector ) {
60
60
precondition (
61
- base. count == direction. base. count,
62
- " cannot move Array.DifferentiableView with count \( base. count) along "
63
- + " direction with different count \( direction. base. count) " )
61
+ base. count == direction. base. count, """
62
+ Count mismatch: \( base. count) ('self') and \( direction. base. count) \
63
+ ('direction')
64
+ """ )
64
65
for i in base. indices {
65
66
base [ i] . move ( along: direction. base [ i] )
66
67
}
@@ -106,35 +107,31 @@ where Element: AdditiveArithmetic & Differentiable {
106
107
lhs: Array . DifferentiableView ,
107
108
rhs: Array . DifferentiableView
108
109
) -> Array . DifferentiableView {
109
- precondition (
110
- lhs. base. count == 0 || rhs. base. count == 0
111
- || lhs. base. count == rhs. base. count,
112
- " cannot add Array.DifferentiableViews with different counts: "
113
- + " \( lhs. base. count) and \( rhs. base. count) " )
114
110
if lhs. base. count == 0 {
115
111
return rhs
116
112
}
117
113
if rhs. base. count == 0 {
118
114
return lhs
119
115
}
116
+ precondition (
117
+ lhs. base. count == rhs. base. count,
118
+ " Count mismatch: \( lhs. base. count) and \( rhs. base. count) " )
120
119
return Array . DifferentiableView ( zip ( lhs. base, rhs. base) . map ( + ) )
121
120
}
122
121
123
122
public static func - (
124
123
lhs: Array . DifferentiableView ,
125
124
rhs: Array . DifferentiableView
126
125
) -> Array . DifferentiableView {
127
- precondition (
128
- lhs. base. count == 0 || rhs. base. count == 0
129
- || lhs. base. count == rhs. base. count,
130
- " cannot subtract Array.DifferentiableViews with different counts: "
131
- + " \( lhs. base. count) and \( rhs. base. count) " )
132
126
if lhs. base. count == 0 {
133
127
return rhs
134
128
}
135
129
if rhs. base. count == 0 {
136
130
return lhs
137
131
}
132
+ precondition (
133
+ lhs. base. count == rhs. base. count,
134
+ " Count mismatch: \( lhs. base. count) and \( rhs. base. count) " )
138
135
return Array . DifferentiableView ( zip ( lhs. base, rhs. base) . map ( - ) )
139
136
}
140
137
@@ -202,10 +199,10 @@ extension Array where Element: Differentiable {
202
199
) {
203
200
func pullback( _ v: TangentVector ) -> ( TangentVector , TangentVector ) {
204
201
precondition (
205
- v. base. count == lhs. count + rhs. count,
206
- " + should receive gradient with count equal to sum of operand "
207
- + " counts, but counts are: gradient \( v . base . count) , "
208
- + " lhs \( lhs . count ) , rhs \( rhs . count ) " )
202
+ v. base. count == lhs. count + rhs. count, """
203
+ Tangent vector with invalid count; expected to equal the sum of \
204
+ operand counts \( lhs . count ) and \( rhs . count)
205
+ "" ")
209
206
return (
210
207
TangentVector ( [ Element . TangentVector] ( v. base [ 0 ..< lhs. count] ) ) ,
211
208
TangentVector ( [ Element . TangentVector] ( v. base [ lhs. count... ] ) )
0 commit comments