Skip to content

Commit 1f15e9c

Browse files
authored
allow for immutable DiffResults to support totally stack-allocated computations (#9)
1 parent 6ed8220 commit 1f15e9c

File tree

4 files changed

+337
-111
lines changed

4 files changed

+337
-111
lines changed

REQUIRE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
julia 0.6-
2+
StaticArrays 0.5.0

src/DiffBase.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ __precompile__()
22

33
module DiffBase
44

5+
using StaticArrays
6+
57
include("results.jl")
68
include("testfuncs.jl")
79
include("rules.jl")

src/results.jl

Lines changed: 158 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,78 @@
1-
##############
2-
# DiffResult #
3-
##############
1+
#########
2+
# Types #
3+
#########
44

5-
mutable struct DiffResult{O,V,D<:Tuple}
5+
abstract type DiffResult{O,V,D<:Tuple} end
6+
7+
struct ImmutableDiffResult{O,V,D<:Tuple} <: DiffResult{O,V,D}
68
value::V
79
derivs::D # ith element = ith-order derivative
8-
function DiffResult{O,V,D}(value::V, derivs::NTuple{O,Any}) where {O,V,D}
9-
return new{O,V,D}(value, derivs)
10+
function ImmutableDiffResult(value::V, derivs::NTuple{O,Any}) where {O,V}
11+
return new{O,V,typeof(derivs)}(value, derivs)
1012
end
1113
end
1214

13-
"""
14-
DiffResult(value, derivs::Tuple)
15+
mutable struct MutableDiffResult{O,V,D<:Tuple} <: DiffResult{O,V,D}
16+
value::V
17+
derivs::D # ith element = ith-order derivative
18+
function MutableDiffResult(value::V, derivs::NTuple{O,Any}) where {O,V}
19+
return new{O,V,typeof(derivs)}(value, derivs)
20+
end
21+
end
1522

16-
Return a `DiffResult` instance where values will be stored in the provided `value` storage
17-
and derivatives will be stored in the provided `derivs` storage.
23+
################
24+
# Constructors #
25+
################
1826

19-
Note that the arguments can be `Number`s or `AbstractArray`s, depending on the dimensionality
20-
of your target function.
2127
"""
22-
DiffResult{V,O}(value::V, derivs::NTuple{O,Any}) = DiffResult{O,V,typeof(derivs)}(value, derivs)
28+
DiffResult(value::Union{Number,AbstractArray}, derivs::Tuple{Vararg{Number}})
29+
DiffResult(value::Union{Number,AbstractArray}, derivs::Tuple{Vararg{AbstractArray}})
2330
24-
"""
25-
DiffResult(value, derivs...)
31+
Return `r::DiffResult`, with output value storage provided by `value` and output derivative
32+
storage provided by `derivs`.
33+
34+
In reality, `DiffResult` is an abstract supertype of two concrete types, `MutableDiffResult`
35+
and `ImmutableDiffResult`. If all `value`/`derivs` are all `Number`s or `SArray`s, then `r`
36+
will be immutable (i.e. `r::ImmutableDiffResult`). Otherwise, `r` will be mutable
37+
(i.e. `r::MutableDiffResult`).
2638
27-
Equivalent to `DiffResult(value, derivs::Tuple)`, where `derivs...` is the splatted form of `derivs::Tuple`.
39+
Note that `derivs` can be provide in splatted form, i.e. `DiffResult(value, derivs...)`.
2840
"""
29-
DiffResult(value, derivs...) = DiffResult(value, derivs)
41+
DiffResult
42+
43+
DiffResult(value::Number, derivs::Tuple{Vararg{Number}}) = ImmutableDiffResult(value, derivs)
44+
DiffResult(value::Number, derivs::Tuple{Vararg{SArray}}) = ImmutableDiffResult(value, derivs)
45+
DiffResult(value::SArray, derivs::Tuple{Vararg{SArray}}) = ImmutableDiffResult(value, derivs)
46+
DiffResult(value::Number, derivs::Tuple{Vararg{AbstractArray}}) = MutableDiffResult(value, derivs)
47+
DiffResult(value::AbstractArray, derivs::Tuple{Vararg{AbstractArray}}) = MutableDiffResult(value, derivs)
48+
DiffResult(value::Union{Number,AbstractArray}, derivs::Union{Number,AbstractArray}...) = DiffResult(value, derivs)
3049

3150
"""
3251
GradientResult(x::AbstractArray)
3352
34-
Construct a `DiffResult` that can be used for gradient calculations where `x` is the
35-
input to the target function.
53+
Construct a `DiffResult` that can be used for gradient calculations where `x` is the input
54+
to the target function.
3655
3756
Note that `GradientResult` allocates its own storage; `x` is only used for type and
3857
shape information. If you want to allocate storage yourself, use the `DiffResult`
3958
constructor instead.
4059
"""
4160
GradientResult(x::AbstractArray) = DiffResult(first(x), similar(x))
61+
GradientResult(x::SArray) = DiffResult(first(x), x)
4262

4363
"""
4464
JacobianResult(x::AbstractArray)
4565
46-
Construct a `DiffResult` that can be used for Jacobian calculations where `x` is the
47-
input to the target function. This method assumes that the target function's output
48-
dimension equals its input dimension.
66+
Construct a `DiffResult` that can be used for Jacobian calculations where `x` is the input
67+
to the target function. This method assumes that the target function's output dimension
68+
equals its input dimension.
4969
5070
Note that `JacobianResult` allocates its own storage; `x` is only used for type and
5171
shape information. If you want to allocate storage yourself, use the `DiffResult`
5272
constructor instead.
5373
"""
5474
JacobianResult(x::AbstractArray) = DiffResult(similar(x), similar(x, length(x), length(x)))
75+
JacobianResult(x::SArray{<:Any,T,<:Any,L}) where {T,L} = DiffResult(x, zeros(SMatrix{L,L,T}))
5576

5677
"""
5778
JacobianResult(y::AbstractArray, x::AbstractArray)
@@ -64,6 +85,7 @@ Like the single argument version, `y` and `x` are only used for type and
6485
shape information and are not stored in the returned `DiffResult`.
6586
"""
6687
JacobianResult(y::AbstractArray, x::AbstractArray) = DiffResult(similar(y), similar(y, length(y), length(x)))
88+
JacobianResult(y::SArray{<:Any,<:Any,<:Any,Y}, x::SArray{<:Any,T,<:Any,X}) where {T,Y,X} = DiffResult(y, zeros(SMatrix{Y,X,T}))
6789

6890
"""
6991
HessianResult(x::AbstractArray)
@@ -76,9 +98,30 @@ shape information. If you want to allocate storage yourself, use the `DiffResult
7698
constructor instead.
7799
"""
78100
HessianResult(x::AbstractArray) = DiffResult(first(x), similar(x), similar(x, length(x), length(x)))
101+
HessianResult(x::SArray{<:Any,T,<:Any,L}) where {T,L} = DiffResult(first(x), x, zeros(SMatrix{L,L,T}))
102+
103+
#############
104+
# Interface #
105+
#############
106+
107+
@generated function tuple_eltype(x::Tuple, ::Type{Val{i}}) where {i}
108+
return quote
109+
$(Expr(:meta, :inline))
110+
return $(x.parameters[i])
111+
end
112+
end
113+
114+
@generated function tuple_setindex(x::NTuple{N,Any}, y, ::Type{Val{i}}) where {N,i}
115+
new_tuple = Expr(:tuple, [ifelse(i == n, :y, :(x[$n])) for n in 1:N]...)
116+
return quote
117+
$(Expr(:meta, :inline))
118+
return $new_tuple
119+
end
120+
end
79121

80122
Base.eltype(r::DiffResult) = eltype(typeof(r))
81-
Base.eltype{O,V,D}(::Type{DiffResult{O,V,D}}) = eltype(V)
123+
124+
Base.eltype(::Type{D}) where {O,V,D<:DiffResult{O,V}} = eltype(V)
82125

83126
Base.:(==)(a::DiffResult, b::DiffResult) = a.value == b.value && a.derivs == b.derivs
84127

@@ -92,26 +135,35 @@ Base.copy(r::DiffResult) = DiffResult(copy(r.value), map(copy, r.derivs))
92135
93136
Return the primal value stored in `r`.
94137
95-
Note that this method returns a reference, not a copy. Thus, if `value(r)` is mutable,
96-
mutating `value(r)` will mutate `r`.
138+
Note that this method returns a reference, not a copy.
97139
"""
98140
value(r::DiffResult) = r.value
99141

100142
"""
101143
value!(r::DiffResult, x)
102144
103-
Copy `x` into `r`'s value storage, such that `value(r) == x`.
145+
Return `s::DiffResult` with the same data as `r`, except for `value(s) == x`.
146+
147+
This function may or may not mutate `r`. If `r::ImmutableDiffResult`, a totally new
148+
instance will be created and returned, whereas if `r::MutableDiffResult`, then `r` will be
149+
mutated in-place and returned. Thus, this function should be called as `r = value!(r, x)`.
104150
"""
105-
value!(r::DiffResult, x::Number) = (r.value = x; return r)
106-
value!(r::DiffResult, x::AbstractArray) = (copy!(value(r), x); return r)
151+
value!(r::MutableDiffResult, x::Number) = (r.value = x; return r)
152+
value!(r::MutableDiffResult, x::AbstractArray) = (copy!(value(r), x); return r)
153+
value!(r::ImmutableDiffResult, x::Union{Number,SArray}) = ImmutableDiffResult(x, r.derivs)
154+
value!(r::ImmutableDiffResult, x::AbstractArray) = ImmutableDiffResult(typeof(value(r))(x), r.derivs)
107155

108156
"""
109157
value!(f, r::DiffResult, x)
110158
111-
Like `value!(r::DiffResult, x)`, but with `f` applied to each element, such that `value(r) == map(f, x)`.
159+
Equivalent to `value!(r::DiffResult, map(f, x))`, but without the implied temporary
160+
allocation (when possible).
112161
"""
113-
value!(f, r::DiffResult, x::Number) = (r.value = f(x); return r)
114-
value!(f, r::DiffResult, x::AbstractArray) = (map!(f, value(r), x); return r)
162+
value!(f, r::MutableDiffResult, x::Number) = (r.value = f(x); return r)
163+
value!(f, r::MutableDiffResult, x::AbstractArray) = (map!(f, value(r), x); return r)
164+
value!(f, r::ImmutableDiffResult, x::Number) = value!(r, f(x))
165+
value!(f, r::ImmutableDiffResult, x::SArray) = value!(r, map(f, x))
166+
value!(f, r::ImmutableDiffResult, x::AbstractArray) = value!(r, map(f, typeof(value(r))(x)))
115167

116168
# derivative/derivative! #
117169
#------------------------#
@@ -121,122 +173,159 @@ value!(f, r::DiffResult, x::AbstractArray) = (map!(f, value(r), x); return r)
121173
122174
Return the `ith` derivative stored in `r`, defaulting to the first derivative.
123175
124-
Note that this method returns a reference, not a copy. Thus, if `derivative(r)` is mutable,
125-
mutating `derivative(r)` will mutate `r`.
176+
Note that this method returns a reference, not a copy.
126177
"""
127-
derivative{i}(r::DiffResult, ::Type{Val{i}} = Val{1}) = r.derivs[i]
178+
derivative(r::DiffResult, ::Type{Val{i}} = Val{1}) where {i} = r.derivs[i]
128179

129180
"""
130181
derivative!(r::DiffResult, x, ::Type{Val{i}} = Val{1})
131182
132-
Copy `x` into `r`'s `ith` derivative storage, such that `derivative(r, Val{i}) == x`.
183+
Return `s::DiffResult` with the same data as `r`, except `derivative(s, Val{i}) == x`.
184+
185+
This function may or may not mutate `r`. If `r::ImmutableDiffResult`, a totally new
186+
instance will be created and returned, whereas if `r::MutableDiffResult`, then `r` will be
187+
mutated in-place and returned. Thus, this function should be called as
188+
`r = derivative!(r, x, Val{i})`.
133189
"""
134-
@generated function derivative!{O,i}(r::DiffResult{O}, x::Number, ::Type{Val{i}} = Val{1})
135-
newderivs = Expr(:tuple, [i == n ? :(x) : :(derivative(r, Val{$n})) for n in 1:O]...)
136-
return quote
137-
r.derivs = $newderivs
138-
return r
139-
end
190+
function derivative!(r::MutableDiffResult, x::Number, ::Type{Val{i}} = Val{1}) where {i}
191+
r.derivs = tuple_setindex(r.derivs, x, Val{i})
192+
return r
140193
end
141194

142-
function derivative!{i}(r::DiffResult, x::AbstractArray, ::Type{Val{i}} = Val{1})
195+
function derivative!(r::MutableDiffResult, x::AbstractArray, ::Type{Val{i}} = Val{1}) where {i}
143196
copy!(derivative(r, Val{i}), x)
144197
return r
145198
end
146199

200+
function derivative!(r::ImmutableDiffResult, x::Union{Number,SArray}, ::Type{Val{i}} = Val{1}) where {i}
201+
return ImmutableDiffResult(value(r), tuple_setindex(r.derivs, x, Val{i}))
202+
end
203+
204+
function derivative!(r::ImmutableDiffResult, x::AbstractArray, ::Type{Val{i}} = Val{1}) where {i}
205+
T = tuple_eltype(r.derivs, Val{i})
206+
return ImmutableDiffResult(value(r), tuple_setindex(r.derivs, T(x), Val{i}))
207+
end
208+
147209
"""
148210
derivative!(f, r::DiffResult, x, ::Type{Val{i}} = Val{1})
149211
150-
Like `derivative!(r::DiffResult, x, Val{i})`, but with `f` applied to each element,
151-
such that `derivative(r, Val{i}) == map(f, x)`.
212+
Equivalent to `derivative!(r::DiffResult, map(f, x), Val{i})`, but without the implied
213+
temporary allocation (when possible).
152214
"""
153-
@generated function derivative!{O,i}(f, r::DiffResult{O}, x::Number, ::Type{Val{i}} = Val{1})
154-
newderivs = Expr(:tuple, [i == n ? :(f(x)) : :(derivative(r, Val{$n})) for n in 1:O]...)
155-
return quote
156-
r.derivs = $newderivs
157-
return r
158-
end
215+
function derivative!(f, r::MutableDiffResult, x::Number, ::Type{Val{i}} = Val{1}) where {i}
216+
r.derivs = tuple_setindex(r.derivs, f(x), Val{i})
217+
return r
159218
end
160219

161-
function derivative!{i}(f, r::DiffResult, x::AbstractArray, ::Type{Val{i}} = Val{1})
220+
function derivative!(f, r::MutableDiffResult, x::AbstractArray, ::Type{Val{i}} = Val{1}) where {i}
162221
map!(f, derivative(r, Val{i}), x)
163222
return r
164223
end
165224

225+
function derivative!(f, r::ImmutableDiffResult, x::Number, ::Type{Val{i}} = Val{1}) where {i}
226+
return derivative!(r, f(x), Val{i})
227+
end
228+
229+
function derivative!(f, r::ImmutableDiffResult, x::SArray, ::Type{Val{i}} = Val{1}) where {i}
230+
return derivative!(r, map(f, x), Val{i})
231+
end
232+
233+
function derivative!(f, r::ImmutableDiffResult, x::AbstractArray, ::Type{Val{i}} = Val{1}) where {i}
234+
T = tuple_eltype(r.derivs, Val{i})
235+
return derivative!(r, map(f, T(x)), Val{i})
236+
end
237+
166238
# special-cased methods #
167239
#-----------------------#
168240

169241
"""
170242
gradient(r::DiffResult)
171243
172-
Return the gradient stored in `r` (equivalent to `derivative(r)`).
244+
Return the gradient stored in `r`.
173245
174-
Note that this method returns a reference, not a copy. Thus, if `gradient(r)` is mutable,
175-
mutating `gradient(r)` will mutate `r`.
246+
Equivalent to `derivative(r, Val{1})`; see `derivative` docs for aliasing behavior.
176247
"""
177248
gradient(r::DiffResult) = derivative(r)
178249

179250
"""
180251
gradient!(r::DiffResult, x)
181252
182-
Copy `x` into `r`'s gradient storage, such that `gradient(r) == x`.
253+
Return `s::DiffResult` with the same data as `r`, except `gradient(s) == x`.
254+
255+
Equivalent to `derivative!(r, x, Val{1})`; see `derivative!` docs for aliasing behavior.
183256
"""
184257
gradient!(r::DiffResult, x) = derivative!(r, x)
185258

186259
"""
187260
gradient!(f, r::DiffResult, x)
188261
189-
Like `gradient!(r::DiffResult, x)`, but with `f` applied to each element,
190-
such that `gradient(r) == map(f, x)`.
262+
Equivalent to `gradient!(r::DiffResult, map(f, x))`, but without the implied temporary
263+
allocation (when possible).
264+
265+
Equivalent to `derivative!(f, r, x, Val{1})`; see `derivative!` docs for aliasing behavior.
191266
"""
192267
gradient!(f, r::DiffResult, x) = derivative!(f, r, x)
193268

194269
"""
195270
jacobian(r::DiffResult)
196271
197-
Return the Jacobian stored in `r` (equivalent to `derivative(r)`).
272+
Return the Jacobian stored in `r`.
198273
199-
Note that this method returns a reference, not a copy. Thus, if `jacobian(r)` is mutable,
200-
mutating `jacobian(r)` will mutate `r`.
274+
Equivalent to `derivative(r, Val{1})`; see `derivative` docs for aliasing behavior.
201275
"""
202276
jacobian(r::DiffResult) = derivative(r)
203277

204278
"""
205279
jacobian!(r::DiffResult, x)
206280
207-
Copy `x` into `r`'s Jacobian storage, such that `jacobian(r) == x`.
281+
Return `s::DiffResult` with the same data as `r`, except `jacobian(s) == x`.
282+
283+
Equivalent to `derivative!(r, x, Val{1})`; see `derivative!` docs for aliasing behavior.
208284
"""
209285
jacobian!(r::DiffResult, x) = derivative!(r, x)
210286

211287
"""
212288
jacobian!(f, r::DiffResult, x)
213289
214-
Like `jacobian!(r::DiffResult, x)`, but with `f` applied to each element,
215-
such that `jacobian(r) == map(f, x)`.
290+
Equivalent to `jacobian!(r::DiffResult, map(f, x))`, but without the implied temporary
291+
allocation (when possible).
292+
293+
Equivalent to `derivative!(f, r, x, Val{1})`; see `derivative!` docs for aliasing behavior.
216294
"""
217295
jacobian!(f, r::DiffResult, x) = derivative!(f, r, x)
218296

219297
"""
220298
hessian(r::DiffResult)
221299
222-
Return the Hessian stored in `r` (equivalent to `derivative(r, Val{2})`).
300+
Return the Hessian stored in `r`.
223301
224-
Note that this method returns a reference, not a copy. Thus, if `hessian(r)` is mutable,
225-
mutating `hessian(r)` will mutate `r`.
302+
Equivalent to `derivative(r, Val{2})`; see `derivative` docs for aliasing behavior.
226303
"""
227304
hessian(r::DiffResult) = derivative(r, Val{2})
228305

229306
"""
230307
hessian!(r::DiffResult, x)
231308
232-
Copy `x` into `r`'s Hessian storage, such that `hessian(r) == x`.
309+
Return `s::DiffResult` with the same data as `r`, except `hessian(s) == x`.
310+
311+
Equivalent to `derivative!(r, x, Val{2})`; see `derivative!` docs for aliasing behavior.
233312
"""
234313
hessian!(r::DiffResult, x) = derivative!(r, x, Val{2})
235314

236315
"""
237316
hessian!(f, r::DiffResult, x)
238317
239-
Like `hessian!(r::DiffResult, x)`, but with `f` applied to each element,
240-
such that `hessian(r) == map(f, x)`.
318+
Equivalent to `hessian!(r::DiffResult, map(f, x))`, but without the implied temporary
319+
allocation (when possible).
320+
321+
Equivalent to `derivative!(f, r, x, Val{2})`; see `derivative!` docs for aliasing behavior.
241322
"""
242323
hessian!(f, r::DiffResult, x) = derivative!(f, r, x, Val{2})
324+
325+
###################
326+
# Pretty Printing #
327+
###################
328+
329+
Base.show(io::IO, r::ImmutableDiffResult) = print(io, "ImmutableDiffResult($(r.value), $(r.derivs))")
330+
331+
Base.show(io::IO, r::MutableDiffResult) = print(io, "MutableDiffResult($(r.value), $(r.derivs))")

0 commit comments

Comments
 (0)