Skip to content

Commit 8c5cccc

Browse files
Merge pull request #402 from ErikQQY/qqy/voa_similar
Fix similar for VectorOfArray
2 parents 8e727d1 + da37b6f commit 8c5cccc

7 files changed

+49
-35
lines changed

ext/RecursiveArrayToolsReverseDiffExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ end
2424
return Array(VA), Array_adjoint
2525
end
2626

27-
@adjoint function Base.view(A::AbstractVectorOfArray{<:ReverseDiff.TrackedReal, N}, I::Colon...) where {N}
27+
@adjoint function Base.view(
28+
A::AbstractVectorOfArray{<:ReverseDiff.TrackedReal, N}, I::Colon...) where {N}
2829
view_adjoint = let A = A, I = I
2930
function (y)
3031
A = recursivecopy(A)

ext/RecursiveArrayToolsSparseArraysExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ module RecursiveArrayToolsSparseArraysExt
33
import SparseArrays
44
import RecursiveArrayTools
55

6-
function Base.copyto!(dest::SparseArrays.AbstractCompressedVector, A::RecursiveArrayTools.ArrayPartition)
6+
function Base.copyto!(
7+
dest::SparseArrays.AbstractCompressedVector, A::RecursiveArrayTools.ArrayPartition)
78
@assert length(dest) == length(A)
89
cur = 1
910
@inbounds for i in 1:length(A.x)
@@ -17,4 +18,4 @@ function Base.copyto!(dest::SparseArrays.AbstractCompressedVector, A::RecursiveA
1718
dest
1819
end
1920

20-
end
21+
end

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,6 @@ end
146146
view(A, I...), view_adjoint
147147
end
148148

149-
150149
@adjoint function Broadcast.broadcasted(::typeof(+), x::AbstractVectorOfArray,
151150
y::Union{Zygote.Numeric, AbstractVectorOfArray})
152151
broadcast(+, x, y), ȳ -> (nothing, map(x -> Zygote.unbroadcast(x, ȳ), (x, y))...)

src/vector_of_array.jl

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ returns a vector of the series for each component, that is, `A[i,:]` for each `i
2929
A plot recipe is provided, which plots the `A[i,:]` series.
3030
3131
There is also support for `VectorOfArray` constructed from multi-dimensional arrays
32+
3233
```julia
3334
VectorOfArray(u::AbstractArray{AT}) where {T, N, AT <: AbstractArray{T, N}}
3435
```
@@ -60,8 +61,9 @@ A[1, :] # all time periods for f(t)
6061
A.t
6162
```
6263
"""
63-
mutable struct DiffEqArray{T, N, A, B, F, S, D <: Union{Nothing, ParameterTimeseriesCollection}} <:
64-
AbstractDiffEqArray{T, N, A}
64+
mutable struct DiffEqArray{
65+
T, N, A, B, F, S, D <: Union{Nothing, ParameterTimeseriesCollection}} <:
66+
AbstractDiffEqArray{T, N, A}
6567
u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}}
6668
t::B
6769
p::F
@@ -177,7 +179,9 @@ function DiffEqArray(vec::AbstractVector{T},
177179
::NTuple{N, Int},
178180
p = nothing,
179181
sys = nothing; discretes = nothing) where {T, N}
180-
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(vec,
182+
DiffEqArray{
183+
eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(
184+
vec,
181185
ts,
182186
p,
183187
sys,
@@ -197,7 +201,8 @@ end
197201
function DiffEqArray(vec::AbstractVector{VT},
198202
ts::AbstractVector,
199203
::NTuple{N, Int}, p; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}}
200-
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}(vec,
204+
DiffEqArray{
205+
eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}(vec,
201206
ts,
202207
p,
203208
nothing,
@@ -253,7 +258,7 @@ function DiffEqArray(vec::AbstractVector{VT},
253258
typeof(ts),
254259
typeof(p),
255260
typeof(sys),
256-
typeof(discretes),
261+
typeof(discretes)
257262
}(vec,
258263
ts,
259264
p,
@@ -375,19 +380,23 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::NotSymboli
375380
end
376381

377382
struct ParameterIndexingError <: Exception
378-
sym
383+
sym::Any
379384
end
380385

381386
function Base.showerror(io::IO, pie::ParameterIndexingError)
382-
print(io, "Indexing with parameters is deprecated. Use `getp(A, $(pie.sym))` for parameter indexing.")
387+
print(io,
388+
"Indexing with parameters is deprecated. Use `getp(A, $(pie.sym))` for parameter indexing.")
383389
end
384390

385391
# Symbolic Indexing Methods
386392
for (symtype, elsymtype, valtype, errcheck) in [
387-
(ScalarSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))),
388-
(ArraySymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))),
389-
(NotSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Union{<:Tuple, <:AbstractArray},
390-
:(all(x -> is_parameter(A, x) && !is_timeseries_parameter(A, x), sym))),
393+
(ScalarSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any,
394+
:(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))),
395+
(ArraySymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any,
396+
:(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))),
397+
(NotSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait,
398+
Union{<:Tuple, <:AbstractArray},
399+
:(all(x -> is_parameter(A, x) && !is_timeseries_parameter(A, x), sym)))
391400
]
392401
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
393402
::$elsymtype, sym::$valtype, arg...)
@@ -413,8 +422,9 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg,
413422
elsymtype = symbolic_type(eltype(_arg))
414423

415424
if symtype == NotSymbolic() && elsymtype == NotSymbolic()
416-
if _arg isa Union{Tuple, AbstractArray} && any(x -> symbolic_type(x) != NotSymbolic(), _arg)
417-
_getindex(A, symtype, elsymtype, _arg, args...)
425+
if _arg isa Union{Tuple, AbstractArray} &&
426+
any(x -> symbolic_type(x) != NotSymbolic(), _arg)
427+
_getindex(A, symtype, elsymtype, _arg, args...)
418428
else
419429
_getindex(A, symtype, _arg, args...)
420430
end
@@ -536,9 +546,7 @@ end
536546

537547
function Base.zero(VA::AbstractVectorOfArray)
538548
val = copy(VA)
539-
for i in eachindex(VA.u)
540-
val.u[i] = zero(VA.u[i])
541-
end
549+
val.u = zero.(VA.u)
542550
return val
543551
end
544552

@@ -707,30 +715,32 @@ end
707715

708716
# Tools for creating similar objects
709717
Base.eltype(::Type{<:AbstractVectorOfArray{T}}) where {T} = T
710-
# TODO: Is there a better way to do this?
718+
711719
@inline function Base.similar(VA::AbstractVectorOfArray, args...)
712720
if args[end] isa Type
713721
return Base.similar(eltype(VA)[], args..., size(VA))
714722
else
715723
return Base.similar(eltype(VA)[], args...)
716724
end
717725
end
718-
@inline function Base.similar(VA::VectorOfArray, ::Type{T} = eltype(VA)) where {T}
719-
VectorOfArray([similar(VA[:, i], T) for i in eachindex(VA.u)])
720-
end
721726

722-
# for VectorOfArray with multi-dimensional parent arrays of arrays where all elements are the same type
723727
function Base.similar(vec::VectorOfArray{
724728
T, N, AT}) where {T, N, AT <: AbstractArray{<:AbstractArray{T}}}
725729
return VectorOfArray(similar.(Base.parent(vec)))
726730
end
727731

728-
# special-case when the multi-dimensional parent array is just an AbstractVector (call the old method)
729-
function Base.similar(vec::VectorOfArray{
730-
T, N, AT}) where {T, N, AT <: AbstractVector{<:AbstractArray{T}}}
731-
return Base.similar(vec, eltype(vec))
732+
@inline function Base.similar(VA::VectorOfArray, ::Type{T} = eltype(VA)) where {T}
733+
VectorOfArray(similar.(VA.u, T))
732734
end
733735

736+
@inline function Base.similar(VA::VectorOfArray, dims::N) where {N <: Number}
737+
l = length(VA)
738+
if dims <= l
739+
VectorOfArray(similar.(VA.u[1:dims]))
740+
else
741+
VectorOfArray([similar.(VA.u); [similar(VA.u[end]) for _ in (l + 1):dims]])
742+
end
743+
end
734744

735745
# fill!
736746
# For DiffEqArray it ignores ts and fills only u

test/basic_indexing.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,13 +248,13 @@ function foo!(u)
248248
end
249249
foo!(u_matrix)
250250
foo!(u_vector)
251-
@test all(u_matrix .== [3, 10])
251+
@test all(u_matrix .== [3, 10])
252252
@test all(vec(u_matrix) .≈ vec(u_vector))
253253

254254
# test that, for VectorOfArray with multi-dimensional parent arrays,
255255
# broadcast and `similar` preserve the structure of the parent array
256256
@test typeof(parent(similar(u_matrix))) == typeof(parent(u_matrix))
257-
@test typeof(parent((x->x).(u_matrix))) == typeof(parent(u_matrix))
257+
@test typeof(parent((x -> x).(u_matrix))) == typeof(parent(u_matrix))
258258

259259
# test efficiency
260260
num_allocs = @allocations foo!(u_matrix)

test/interface_tests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,10 @@ testva2 = similar(testva)
145145
@test typeof(testva2) == typeof(testva)
146146
@test size(testva2) == size(testva)
147147

148+
testva3 = similar(testva, 10)
149+
@test typeof(testva3) == typeof(testva)
150+
@test length(testva3) == 10
151+
148152
# Fill AbstractVectorOfArray and check all
149153
testval = 3.0
150154
fill!(testva2, testval)

test/partitions_test.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,6 @@ y = ArrayPartition(ArrayPartition([1], [2.0]), ArrayPartition([3], [4.0]))
155155
@test all(isnan, ArrayPartition([NaN], [NaN]))
156156
@test all(isnan, ArrayPartition([NaN], ArrayPartition([NaN])))
157157

158-
159158
# broadcasting
160159
_scalar_op(y) = y + 1
161160
# Can't do `@inferred(_scalar_op.(x))` so we wrap that in a function:
@@ -303,7 +302,7 @@ end
303302
end
304303

305304
@testset "Scalar copyto!" begin
306-
u = [2.0,1.0]
307-
copyto!(u, ArrayPartition(1.0,-1.2))
308-
@test u == [1.0,-1.2]
305+
u = [2.0, 1.0]
306+
copyto!(u, ArrayPartition(1.0, -1.2))
307+
@test u == [1.0, -1.2]
309308
end

0 commit comments

Comments
 (0)