1
- # #############
2
- # DiffResult #
3
- # #############
1
+ # ########
2
+ # Types #
3
+ # ########
4
4
5
- mutable struct DiffResult{O,V,D<: Tuple }
5
+ abstract type DiffResult{O,V,D<: Tuple } end
6
+
7
+ struct ImmutableDiffResult{O,V,D<: Tuple } <: DiffResult{O,V,D}
6
8
value:: V
7
9
derivs:: D # ith element = ith-order derivative
8
- function DiffResult {O,V,D} (value:: V , derivs:: NTuple{O,Any} ) where {O,V,D }
9
- return new {O,V,D } (value, derivs)
10
+ function ImmutableDiffResult (value:: V , derivs:: NTuple{O,Any} ) where {O,V}
11
+ return new {O,V,typeof(derivs) } (value, derivs)
10
12
end
11
13
end
12
14
13
- """
14
- DiffResult(value, derivs::Tuple)
15
+ mutable struct MutableDiffResult{O,V,D<: Tuple } <: DiffResult{O,V,D}
16
+ value:: V
17
+ derivs:: D # ith element = ith-order derivative
18
+ function MutableDiffResult (value:: V , derivs:: NTuple{O,Any} ) where {O,V}
19
+ return new {O,V,typeof(derivs)} (value, derivs)
20
+ end
21
+ end
15
22
16
- Return a `DiffResult` instance where values will be stored in the provided `value` storage
17
- and derivatives will be stored in the provided `derivs` storage.
23
+ # ###############
24
+ # Constructors #
25
+ # ###############
18
26
19
- Note that the arguments can be `Number`s or `AbstractArray`s, depending on the dimensionality
20
- of your target function.
21
27
"""
22
- DiffResult {V,O} (value:: V , derivs:: NTuple{O,Any} ) = DiffResult {O,V,typeof(derivs)} (value, derivs)
28
+ DiffResult(value::Union{Number,AbstractArray}, derivs::Tuple{Vararg{Number}})
29
+ DiffResult(value::Union{Number,AbstractArray}, derivs::Tuple{Vararg{AbstractArray}})
23
30
24
- """
25
- DiffResult(value, derivs...)
31
+ Return `r::DiffResult`, with output value storage provided by `value` and output derivative
32
+ storage provided by `derivs`.
33
+
34
+ In reality, `DiffResult` is an abstract supertype of two concrete types, `MutableDiffResult`
35
+ and `ImmutableDiffResult`. If all `value`/`derivs` are all `Number`s or `SArray`s, then `r`
36
+ will be immutable (i.e. `r::ImmutableDiffResult`). Otherwise, `r` will be mutable
37
+ (i.e. `r::MutableDiffResult`).
26
38
27
- Equivalent to `DiffResult(value, derivs::Tuple)`, where `derivs...` is the splatted form of ` derivs::Tuple `.
39
+ Note that ` derivs` can be provide in splatted form, i.e. `DiffResult(value, derivs...) `.
28
40
"""
29
- DiffResult (value, derivs... ) = DiffResult (value, derivs)
41
+ DiffResult
42
+
43
+ DiffResult (value:: Number , derivs:: Tuple{Vararg{Number}} ) = ImmutableDiffResult (value, derivs)
44
+ DiffResult (value:: Number , derivs:: Tuple{Vararg{SArray}} ) = ImmutableDiffResult (value, derivs)
45
+ DiffResult (value:: SArray , derivs:: Tuple{Vararg{SArray}} ) = ImmutableDiffResult (value, derivs)
46
+ DiffResult (value:: Number , derivs:: Tuple{Vararg{AbstractArray}} ) = MutableDiffResult (value, derivs)
47
+ DiffResult (value:: AbstractArray , derivs:: Tuple{Vararg{AbstractArray}} ) = MutableDiffResult (value, derivs)
48
+ DiffResult (value:: Union{Number,AbstractArray} , derivs:: Union{Number,AbstractArray} ...) = DiffResult (value, derivs)
30
49
31
50
"""
32
51
GradientResult(x::AbstractArray)
33
52
34
- Construct a `DiffResult` that can be used for gradient calculations where `x` is the
35
- input to the target function.
53
+ Construct a `DiffResult` that can be used for gradient calculations where `x` is the input
54
+ to the target function.
36
55
37
56
Note that `GradientResult` allocates its own storage; `x` is only used for type and
38
57
shape information. If you want to allocate storage yourself, use the `DiffResult`
39
58
constructor instead.
40
59
"""
41
60
GradientResult (x:: AbstractArray ) = DiffResult (first (x), similar (x))
61
+ GradientResult (x:: SArray ) = DiffResult (first (x), x)
42
62
43
63
"""
44
64
JacobianResult(x::AbstractArray)
45
65
46
- Construct a `DiffResult` that can be used for Jacobian calculations where `x` is the
47
- input to the target function. This method assumes that the target function's output
48
- dimension equals its input dimension.
66
+ Construct a `DiffResult` that can be used for Jacobian calculations where `x` is the input
67
+ to the target function. This method assumes that the target function's output dimension
68
+ equals its input dimension.
49
69
50
70
Note that `JacobianResult` allocates its own storage; `x` is only used for type and
51
71
shape information. If you want to allocate storage yourself, use the `DiffResult`
52
72
constructor instead.
53
73
"""
54
74
JacobianResult (x:: AbstractArray ) = DiffResult (similar (x), similar (x, length (x), length (x)))
75
+ JacobianResult (x:: SArray{<:Any,T,<:Any,L} ) where {T,L} = DiffResult (x, zeros (SMatrix{L,L,T}))
55
76
56
77
"""
57
78
JacobianResult(y::AbstractArray, x::AbstractArray)
@@ -64,6 +85,7 @@ Like the single argument version, `y` and `x` are only used for type and
64
85
shape information and are not stored in the returned `DiffResult`.
65
86
"""
66
87
JacobianResult (y:: AbstractArray , x:: AbstractArray ) = DiffResult (similar (y), similar (y, length (y), length (x)))
88
+ JacobianResult (y:: SArray{<:Any,<:Any,<:Any,Y} , x:: SArray{<:Any,T,<:Any,X} ) where {T,Y,X} = DiffResult (y, zeros (SMatrix{Y,X,T}))
67
89
68
90
"""
69
91
HessianResult(x::AbstractArray)
@@ -76,9 +98,30 @@ shape information. If you want to allocate storage yourself, use the `DiffResult
76
98
constructor instead.
77
99
"""
78
100
HessianResult (x:: AbstractArray ) = DiffResult (first (x), similar (x), similar (x, length (x), length (x)))
101
+ HessianResult (x:: SArray{<:Any,T,<:Any,L} ) where {T,L} = DiffResult (first (x), x, zeros (SMatrix{L,L,T}))
102
+
103
+ # ############
104
+ # Interface #
105
+ # ############
106
+
107
+ @generated function tuple_eltype (x:: Tuple , :: Type{Val{i}} ) where {i}
108
+ return quote
109
+ $ (Expr (:meta , :inline ))
110
+ return $ (x. parameters[i])
111
+ end
112
+ end
113
+
114
+ @generated function tuple_setindex (x:: NTuple{N,Any} , y, :: Type{Val{i}} ) where {N,i}
115
+ new_tuple = Expr (:tuple , [ifelse (i == n, :y , :(x[$ n])) for n in 1 : N]. .. )
116
+ return quote
117
+ $ (Expr (:meta , :inline ))
118
+ return $ new_tuple
119
+ end
120
+ end
79
121
80
122
Base. eltype (r:: DiffResult ) = eltype (typeof (r))
81
- Base. eltype {O,V,D} (:: Type{DiffResult{O,V,D}} ) = eltype (V)
123
+
124
+ Base. eltype (:: Type{D} ) where {O,V,D<: DiffResult{O,V} } = eltype (V)
82
125
83
126
Base.:(== )(a:: DiffResult , b:: DiffResult ) = a. value == b. value && a. derivs == b. derivs
84
127
@@ -92,26 +135,35 @@ Base.copy(r::DiffResult) = DiffResult(copy(r.value), map(copy, r.derivs))
92
135
93
136
Return the primal value stored in `r`.
94
137
95
- Note that this method returns a reference, not a copy. Thus, if `value(r)` is mutable,
96
- mutating `value(r)` will mutate `r`.
138
+ Note that this method returns a reference, not a copy.
97
139
"""
98
140
value (r:: DiffResult ) = r. value
99
141
100
142
"""
101
143
value!(r::DiffResult, x)
102
144
103
- Copy `x` into `r`'s value storage, such that `value(r) == x`.
145
+ Return `s::DiffResult` with the same data as `r`, except for `value(s) == x`.
146
+
147
+ This function may or may not mutate `r`. If `r::ImmutableDiffResult`, a totally new
148
+ instance will be created and returned, whereas if `r::MutableDiffResult`, then `r` will be
149
+ mutated in-place and returned. Thus, this function should be called as `r = value!(r, x)`.
104
150
"""
105
- value! (r:: DiffResult , x:: Number ) = (r. value = x; return r)
106
- value! (r:: DiffResult , x:: AbstractArray ) = (copy! (value (r), x); return r)
151
+ value! (r:: MutableDiffResult , x:: Number ) = (r. value = x; return r)
152
+ value! (r:: MutableDiffResult , x:: AbstractArray ) = (copy! (value (r), x); return r)
153
+ value! (r:: ImmutableDiffResult , x:: Union{Number,SArray} ) = ImmutableDiffResult (x, r. derivs)
154
+ value! (r:: ImmutableDiffResult , x:: AbstractArray ) = ImmutableDiffResult (typeof (value (r))(x), r. derivs)
107
155
108
156
"""
109
157
value!(f, r::DiffResult, x)
110
158
111
- Like `value!(r::DiffResult, x)`, but with `f` applied to each element, such that `value(r) == map(f, x)`.
159
+ Equivalent to `value!(r::DiffResult, map(f, x))`, but without the implied temporary
160
+ allocation (when possible).
112
161
"""
113
- value! (f, r:: DiffResult , x:: Number ) = (r. value = f (x); return r)
114
- value! (f, r:: DiffResult , x:: AbstractArray ) = (map! (f, value (r), x); return r)
162
+ value! (f, r:: MutableDiffResult , x:: Number ) = (r. value = f (x); return r)
163
+ value! (f, r:: MutableDiffResult , x:: AbstractArray ) = (map! (f, value (r), x); return r)
164
+ value! (f, r:: ImmutableDiffResult , x:: Number ) = value! (r, f (x))
165
+ value! (f, r:: ImmutableDiffResult , x:: SArray ) = value! (r, map (f, x))
166
+ value! (f, r:: ImmutableDiffResult , x:: AbstractArray ) = value! (r, map (f, typeof (value (r))(x)))
115
167
116
168
# derivative/derivative! #
117
169
# ------------------------#
@@ -121,122 +173,159 @@ value!(f, r::DiffResult, x::AbstractArray) = (map!(f, value(r), x); return r)
121
173
122
174
Return the `ith` derivative stored in `r`, defaulting to the first derivative.
123
175
124
- Note that this method returns a reference, not a copy. Thus, if `derivative(r)` is mutable,
125
- mutating `derivative(r)` will mutate `r`.
176
+ Note that this method returns a reference, not a copy.
126
177
"""
127
- derivative {i} (r:: DiffResult , :: Type{Val{i}} = Val{1 }) = r. derivs[i]
178
+ derivative (r:: DiffResult , :: Type{Val{i}} = Val{1 }) where {i} = r. derivs[i]
128
179
129
180
"""
130
181
derivative!(r::DiffResult, x, ::Type{Val{i}} = Val{1})
131
182
132
- Copy `x` into `r`'s `ith` derivative storage, such that `derivative(r, Val{i}) == x`.
183
+ Return `s::DiffResult` with the same data as `r`, except `derivative(s, Val{i}) == x`.
184
+
185
+ This function may or may not mutate `r`. If `r::ImmutableDiffResult`, a totally new
186
+ instance will be created and returned, whereas if `r::MutableDiffResult`, then `r` will be
187
+ mutated in-place and returned. Thus, this function should be called as
188
+ `r = derivative!(r, x, Val{i})`.
133
189
"""
134
- @generated function derivative! {O,i} (r:: DiffResult{O} , x:: Number , :: Type{Val{i}} = Val{1 })
135
- newderivs = Expr (:tuple , [i == n ? :(x) : :(derivative (r, Val{$ n})) for n in 1 : O]. .. )
136
- return quote
137
- r. derivs = $ newderivs
138
- return r
139
- end
190
+ function derivative! (r:: MutableDiffResult , x:: Number , :: Type{Val{i}} = Val{1 }) where {i}
191
+ r. derivs = tuple_setindex (r. derivs, x, Val{i})
192
+ return r
140
193
end
141
194
142
- function derivative! {i} (r:: DiffResult , x:: AbstractArray , :: Type{Val{i}} = Val{1 })
195
+ function derivative! (r:: MutableDiffResult , x:: AbstractArray , :: Type{Val{i}} = Val{1 }) where {i}
143
196
copy! (derivative (r, Val{i}), x)
144
197
return r
145
198
end
146
199
200
+ function derivative! (r:: ImmutableDiffResult , x:: Union{Number,SArray} , :: Type{Val{i}} = Val{1 }) where {i}
201
+ return ImmutableDiffResult (value (r), tuple_setindex (r. derivs, x, Val{i}))
202
+ end
203
+
204
+ function derivative! (r:: ImmutableDiffResult , x:: AbstractArray , :: Type{Val{i}} = Val{1 }) where {i}
205
+ T = tuple_eltype (r. derivs, Val{i})
206
+ return ImmutableDiffResult (value (r), tuple_setindex (r. derivs, T (x), Val{i}))
207
+ end
208
+
147
209
"""
148
210
derivative!(f, r::DiffResult, x, ::Type{Val{i}} = Val{1})
149
211
150
- Like `derivative!(r::DiffResult, x , Val{i})`, but with `f` applied to each element,
151
- such that `derivative(r, Val{i}) == map(f, x)` .
212
+ Equivalent to `derivative!(r::DiffResult, map(f, x) , Val{i})`, but without the implied
213
+ temporary allocation (when possible) .
152
214
"""
153
- @generated function derivative! {O,i} (f, r:: DiffResult{O} , x:: Number , :: Type{Val{i}} = Val{1 })
154
- newderivs = Expr (:tuple , [i == n ? :(f (x)) : :(derivative (r, Val{$ n})) for n in 1 : O]. .. )
155
- return quote
156
- r. derivs = $ newderivs
157
- return r
158
- end
215
+ function derivative! (f, r:: MutableDiffResult , x:: Number , :: Type{Val{i}} = Val{1 }) where {i}
216
+ r. derivs = tuple_setindex (r. derivs, f (x), Val{i})
217
+ return r
159
218
end
160
219
161
- function derivative! {i} (f, r:: DiffResult , x:: AbstractArray , :: Type{Val{i}} = Val{1 })
220
+ function derivative! (f, r:: MutableDiffResult , x:: AbstractArray , :: Type{Val{i}} = Val{1 }) where {i}
162
221
map! (f, derivative (r, Val{i}), x)
163
222
return r
164
223
end
165
224
225
+ function derivative! (f, r:: ImmutableDiffResult , x:: Number , :: Type{Val{i}} = Val{1 }) where {i}
226
+ return derivative! (r, f (x), Val{i})
227
+ end
228
+
229
+ function derivative! (f, r:: ImmutableDiffResult , x:: SArray , :: Type{Val{i}} = Val{1 }) where {i}
230
+ return derivative! (r, map (f, x), Val{i})
231
+ end
232
+
233
+ function derivative! (f, r:: ImmutableDiffResult , x:: AbstractArray , :: Type{Val{i}} = Val{1 }) where {i}
234
+ T = tuple_eltype (r. derivs, Val{i})
235
+ return derivative! (r, map (f, T (x)), Val{i})
236
+ end
237
+
166
238
# special-cased methods #
167
239
# -----------------------#
168
240
169
241
"""
170
242
gradient(r::DiffResult)
171
243
172
- Return the gradient stored in `r` (equivalent to `derivative(r)`) .
244
+ Return the gradient stored in `r`.
173
245
174
- Note that this method returns a reference, not a copy. Thus, if `gradient(r)` is mutable,
175
- mutating `gradient(r)` will mutate `r`.
246
+ Equivalent to `derivative(r, Val{1})`; see `derivative` docs for aliasing behavior.
176
247
"""
177
248
gradient (r:: DiffResult ) = derivative (r)
178
249
179
250
"""
180
251
gradient!(r::DiffResult, x)
181
252
182
- Copy `x` into `r`'s gradient storage, such that `gradient(r) == x`.
253
+ Return `s::DiffResult` with the same data as `r`, except `gradient(s) == x`.
254
+
255
+ Equivalent to `derivative!(r, x, Val{1})`; see `derivative!` docs for aliasing behavior.
183
256
"""
184
257
gradient! (r:: DiffResult , x) = derivative! (r, x)
185
258
186
259
"""
187
260
gradient!(f, r::DiffResult, x)
188
261
189
- Like `gradient!(r::DiffResult, x)`, but with `f` applied to each element,
190
- such that `gradient(r) == map(f, x)`.
262
+ Equivalent to `gradient!(r::DiffResult, map(f, x))`, but without the implied temporary
263
+ allocation (when possible).
264
+
265
+ Equivalent to `derivative!(f, r, x, Val{1})`; see `derivative!` docs for aliasing behavior.
191
266
"""
192
267
gradient! (f, r:: DiffResult , x) = derivative! (f, r, x)
193
268
194
269
"""
195
270
jacobian(r::DiffResult)
196
271
197
- Return the Jacobian stored in `r` (equivalent to `derivative(r)`) .
272
+ Return the Jacobian stored in `r`.
198
273
199
- Note that this method returns a reference, not a copy. Thus, if `jacobian(r)` is mutable,
200
- mutating `jacobian(r)` will mutate `r`.
274
+ Equivalent to `derivative(r, Val{1})`; see `derivative` docs for aliasing behavior.
201
275
"""
202
276
jacobian (r:: DiffResult ) = derivative (r)
203
277
204
278
"""
205
279
jacobian!(r::DiffResult, x)
206
280
207
- Copy `x` into `r`'s Jacobian storage, such that `jacobian(r) == x`.
281
+ Return `s::DiffResult` with the same data as `r`, except `jacobian(s) == x`.
282
+
283
+ Equivalent to `derivative!(r, x, Val{1})`; see `derivative!` docs for aliasing behavior.
208
284
"""
209
285
jacobian! (r:: DiffResult , x) = derivative! (r, x)
210
286
211
287
"""
212
288
jacobian!(f, r::DiffResult, x)
213
289
214
- Like `jacobian!(r::DiffResult, x)`, but with `f` applied to each element,
215
- such that `jacobian(r) == map(f, x)`.
290
+ Equivalent to `jacobian!(r::DiffResult, map(f, x))`, but without the implied temporary
291
+ allocation (when possible).
292
+
293
+ Equivalent to `derivative!(f, r, x, Val{1})`; see `derivative!` docs for aliasing behavior.
216
294
"""
217
295
jacobian! (f, r:: DiffResult , x) = derivative! (f, r, x)
218
296
219
297
"""
220
298
hessian(r::DiffResult)
221
299
222
- Return the Hessian stored in `r` (equivalent to `derivative(r, Val{2})`) .
300
+ Return the Hessian stored in `r`.
223
301
224
- Note that this method returns a reference, not a copy. Thus, if `hessian(r)` is mutable,
225
- mutating `hessian(r)` will mutate `r`.
302
+ Equivalent to `derivative(r, Val{2})`; see `derivative` docs for aliasing behavior.
226
303
"""
227
304
hessian (r:: DiffResult ) = derivative (r, Val{2 })
228
305
229
306
"""
230
307
hessian!(r::DiffResult, x)
231
308
232
- Copy `x` into `r`'s Hessian storage, such that `hessian(r) == x`.
309
+ Return `s::DiffResult` with the same data as `r`, except `hessian(s) == x`.
310
+
311
+ Equivalent to `derivative!(r, x, Val{2})`; see `derivative!` docs for aliasing behavior.
233
312
"""
234
313
hessian! (r:: DiffResult , x) = derivative! (r, x, Val{2 })
235
314
236
315
"""
237
316
hessian!(f, r::DiffResult, x)
238
317
239
- Like `hessian!(r::DiffResult, x)`, but with `f` applied to each element,
240
- such that `hessian(r) == map(f, x)`.
318
+ Equivalent to `hessian!(r::DiffResult, map(f, x))`, but without the implied temporary
319
+ allocation (when possible).
320
+
321
+ Equivalent to `derivative!(f, r, x, Val{2})`; see `derivative!` docs for aliasing behavior.
241
322
"""
242
323
hessian! (f, r:: DiffResult , x) = derivative! (f, r, x, Val{2 })
324
+
325
+ # ##################
326
+ # Pretty Printing #
327
+ # ##################
328
+
329
+ Base. show (io:: IO , r:: ImmutableDiffResult ) = print (io, " ImmutableDiffResult($(r. value) , $(r. derivs) )" )
330
+
331
+ Base. show (io:: IO , r:: MutableDiffResult ) = print (io, " MutableDiffResult($(r. value) , $(r. derivs) )" )
0 commit comments