Skip to content

Commit f76169b

Browse files
dan-zhengrxwei
andauthored
[AutoDiff] Improve Array.TangentVector precondition messages. (swiftlang#32154)
Improve count-related precondition messages. Co-authored-by: Richard Wei <[email protected]>
1 parent d049281 commit f76169b

File tree

1 file changed

+15
-18
lines changed

1 file changed

+15
-18
lines changed

stdlib/public/Differentiation/ArrayDifferentiation.swift

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ extension Array.DifferentiableView: Differentiable
3030
where Element: Differentiable {
3131
/// The viewed array.
3232
public var base: [Element] {
33-
get { return _base }
33+
get { _base }
3434
_modify { yield &_base }
3535
}
3636

@@ -58,9 +58,10 @@ where Element: Differentiable {
5858

5959
public mutating func move(along direction: TangentVector) {
6060
precondition(
61-
base.count == direction.base.count,
62-
"cannot move Array.DifferentiableView with count \(base.count) along "
63-
+ "direction with different count \(direction.base.count)")
61+
base.count == direction.base.count, """
62+
Count mismatch: \(base.count) ('self') and \(direction.base.count) \
63+
('direction')
64+
""")
6465
for i in base.indices {
6566
base[i].move(along: direction.base[i])
6667
}
@@ -106,35 +107,31 @@ where Element: AdditiveArithmetic & Differentiable {
106107
lhs: Array.DifferentiableView,
107108
rhs: Array.DifferentiableView
108109
) -> Array.DifferentiableView {
109-
precondition(
110-
lhs.base.count == 0 || rhs.base.count == 0
111-
|| lhs.base.count == rhs.base.count,
112-
"cannot add Array.DifferentiableViews with different counts: "
113-
+ "\(lhs.base.count) and \(rhs.base.count)")
114110
if lhs.base.count == 0 {
115111
return rhs
116112
}
117113
if rhs.base.count == 0 {
118114
return lhs
119115
}
116+
precondition(
117+
lhs.base.count == rhs.base.count,
118+
"Count mismatch: \(lhs.base.count) and \(rhs.base.count)")
120119
return Array.DifferentiableView(zip(lhs.base, rhs.base).map(+))
121120
}
122121

123122
public static func - (
124123
lhs: Array.DifferentiableView,
125124
rhs: Array.DifferentiableView
126125
) -> Array.DifferentiableView {
127-
precondition(
128-
lhs.base.count == 0 || rhs.base.count == 0
129-
|| lhs.base.count == rhs.base.count,
130-
"cannot subtract Array.DifferentiableViews with different counts: "
131-
+ "\(lhs.base.count) and \(rhs.base.count)")
132126
if lhs.base.count == 0 {
133127
return rhs
134128
}
135129
if rhs.base.count == 0 {
136130
return lhs
137131
}
132+
precondition(
133+
lhs.base.count == rhs.base.count,
134+
"Count mismatch: \(lhs.base.count) and \(rhs.base.count)")
138135
return Array.DifferentiableView(zip(lhs.base, rhs.base).map(-))
139136
}
140137

@@ -202,10 +199,10 @@ extension Array where Element: Differentiable {
202199
) {
203200
func pullback(_ v: TangentVector) -> (TangentVector, TangentVector) {
204201
precondition(
205-
v.base.count == lhs.count + rhs.count,
206-
"+ should receive gradient with count equal to sum of operand "
207-
+ "counts, but counts are: gradient \(v.base.count), "
208-
+ "lhs \(lhs.count), rhs \(rhs.count)")
202+
v.base.count == lhs.count + rhs.count, """
203+
Tangent vector with invalid count; expected to equal the sum of \
204+
operand counts \(lhs.count) and \(rhs.count)
205+
""")
209206
return (
210207
TangentVector([Element.TangentVector](v.base[0..<lhs.count])),
211208
TangentVector([Element.TangentVector](v.base[lhs.count...]))

0 commit comments

Comments
 (0)