Skip to content

Fix similar for VectorOfArray #402

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ext/RecursiveArrayToolsReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ end
return Array(VA), Array_adjoint
end

@adjoint function Base.view(A::AbstractVectorOfArray{<:ReverseDiff.TrackedReal, N}, I::Colon...) where {N}
@adjoint function Base.view(
A::AbstractVectorOfArray{<:ReverseDiff.TrackedReal, N}, I::Colon...) where {N}
view_adjoint = let A = A, I = I
function (y)
A = recursivecopy(A)
Expand Down
5 changes: 3 additions & 2 deletions ext/RecursiveArrayToolsSparseArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ module RecursiveArrayToolsSparseArraysExt
import SparseArrays
import RecursiveArrayTools

function Base.copyto!(dest::SparseArrays.AbstractCompressedVector, A::RecursiveArrayTools.ArrayPartition)
function Base.copyto!(
dest::SparseArrays.AbstractCompressedVector, A::RecursiveArrayTools.ArrayPartition)
@assert length(dest) == length(A)
cur = 1
@inbounds for i in 1:length(A.x)
Expand All @@ -17,4 +18,4 @@ function Base.copyto!(dest::SparseArrays.AbstractCompressedVector, A::RecursiveA
dest
end

end
end
1 change: 0 additions & 1 deletion ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ end
view(A, I...), view_adjoint
end


@adjoint function Broadcast.broadcasted(::typeof(+), x::AbstractVectorOfArray,
y::Union{Zygote.Numeric, AbstractVectorOfArray})
broadcast(+, x, y), ȳ -> (nothing, map(x -> Zygote.unbroadcast(x, ȳ), (x, y))...)
Expand Down
60 changes: 35 additions & 25 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ returns a vector of the series for each component, that is, `A[i,:]` for each `i
A plot recipe is provided, which plots the `A[i,:]` series.

There is also support for `VectorOfArray` constructed from multi-dimensional arrays

```julia
VectorOfArray(u::AbstractArray{AT}) where {T, N, AT <: AbstractArray{T, N}}
```
Expand Down Expand Up @@ -60,8 +61,9 @@ A[1, :] # all time periods for f(t)
A.t
```
"""
mutable struct DiffEqArray{T, N, A, B, F, S, D <: Union{Nothing, ParameterTimeseriesCollection}} <:
AbstractDiffEqArray{T, N, A}
mutable struct DiffEqArray{
T, N, A, B, F, S, D <: Union{Nothing, ParameterTimeseriesCollection}} <:
AbstractDiffEqArray{T, N, A}
u::A # A <: AbstractVector{<: AbstractArray{T, N - 1}}
t::B
p::F
Expand Down Expand Up @@ -177,7 +179,9 @@ function DiffEqArray(vec::AbstractVector{T},
::NTuple{N, Int},
p = nothing,
sys = nothing; discretes = nothing) where {T, N}
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(vec,
DiffEqArray{
eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys), typeof(discretes)}(
vec,
ts,
p,
sys,
Expand All @@ -197,7 +201,8 @@ end
function DiffEqArray(vec::AbstractVector{VT},
ts::AbstractVector,
::NTuple{N, Int}, p; discretes = nothing) where {T, N, VT <: AbstractArray{T, N}}
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}(vec,
DiffEqArray{
eltype(T), N, typeof(vec), typeof(ts), typeof(p), Nothing, typeof(discretes)}(vec,
ts,
p,
nothing,
Expand Down Expand Up @@ -253,7 +258,7 @@ function DiffEqArray(vec::AbstractVector{VT},
typeof(ts),
typeof(p),
typeof(sys),
typeof(discretes),
typeof(discretes)
}(vec,
ts,
p,
Expand Down Expand Up @@ -375,19 +380,23 @@ Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::NotSymboli
end

struct ParameterIndexingError <: Exception
sym
sym::Any
end

function Base.showerror(io::IO, pie::ParameterIndexingError)
print(io, "Indexing with parameters is deprecated. Use `getp(A, $(pie.sym))` for parameter indexing.")
print(io,
"Indexing with parameters is deprecated. Use `getp(A, $(pie.sym))` for parameter indexing.")
end

# Symbolic Indexing Methods
for (symtype, elsymtype, valtype, errcheck) in [
(ScalarSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))),
(ArraySymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any, :(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))),
(NotSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Union{<:Tuple, <:AbstractArray},
:(all(x -> is_parameter(A, x) && !is_timeseries_parameter(A, x), sym))),
(ScalarSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any,
:(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))),
(ArraySymbolic, SymbolicIndexingInterface.SymbolicTypeTrait, Any,
:(is_parameter(A, sym) && !is_timeseries_parameter(A, sym))),
(NotSymbolic, SymbolicIndexingInterface.SymbolicTypeTrait,
Union{<:Tuple, <:AbstractArray},
:(all(x -> is_parameter(A, x) && !is_timeseries_parameter(A, x), sym)))
]
@eval Base.@propagate_inbounds function _getindex(A::AbstractDiffEqArray, ::$symtype,
::$elsymtype, sym::$valtype, arg...)
Expand All @@ -413,8 +422,9 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg,
elsymtype = symbolic_type(eltype(_arg))

if symtype == NotSymbolic() && elsymtype == NotSymbolic()
if _arg isa Union{Tuple, AbstractArray} && any(x -> symbolic_type(x) != NotSymbolic(), _arg)
_getindex(A, symtype, elsymtype, _arg, args...)
if _arg isa Union{Tuple, AbstractArray} &&
any(x -> symbolic_type(x) != NotSymbolic(), _arg)
_getindex(A, symtype, elsymtype, _arg, args...)
else
_getindex(A, symtype, _arg, args...)
end
Expand Down Expand Up @@ -536,9 +546,7 @@ end

function Base.zero(VA::AbstractVectorOfArray)
val = copy(VA)
for i in eachindex(VA.u)
val.u[i] = zero(VA.u[i])
end
val.u = zero.(VA.u)
return val
end

Expand Down Expand Up @@ -707,30 +715,32 @@ end

# Tools for creating similar objects
Base.eltype(::Type{<:AbstractVectorOfArray{T}}) where {T} = T
# TODO: Is there a better way to do this?

@inline function Base.similar(VA::AbstractVectorOfArray, args...)
if args[end] isa Type
return Base.similar(eltype(VA)[], args..., size(VA))
else
return Base.similar(eltype(VA)[], args...)
end
end
@inline function Base.similar(VA::VectorOfArray, ::Type{T} = eltype(VA)) where {T}
VectorOfArray([similar(VA[:, i], T) for i in eachindex(VA.u)])
end

# for VectorOfArray with multi-dimensional parent arrays of arrays where all elements are the same type
function Base.similar(vec::VectorOfArray{
T, N, AT}) where {T, N, AT <: AbstractArray{<:AbstractArray{T}}}
return VectorOfArray(similar.(Base.parent(vec)))
end

# special-case when the multi-dimensional parent array is just an AbstractVector (call the old method)
function Base.similar(vec::VectorOfArray{
T, N, AT}) where {T, N, AT <: AbstractVector{<:AbstractArray{T}}}
return Base.similar(vec, eltype(vec))
@inline function Base.similar(VA::VectorOfArray, ::Type{T} = eltype(VA)) where {T}
VectorOfArray(similar.(VA.u, T))
end

@inline function Base.similar(VA::VectorOfArray, dims::N) where {N <: Number}
l = length(VA)
if dims <= l
VectorOfArray(similar.(VA.u[1:dims]))
else
VectorOfArray([similar.(VA.u); [similar(VA.u[end]) for _ in (l + 1):dims]])
end
end

# fill!
# For DiffEqArray it ignores ts and fills only u
Expand Down
4 changes: 2 additions & 2 deletions test/basic_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,13 @@ function foo!(u)
end
foo!(u_matrix)
foo!(u_vector)
@test all(u_matrix .== [3, 10])
@test all(u_matrix .== [3, 10])
@test all(vec(u_matrix) .≈ vec(u_vector))

# test that, for VectorOfArray with multi-dimensional parent arrays,
# broadcast and `similar` preserve the structure of the parent array
@test typeof(parent(similar(u_matrix))) == typeof(parent(u_matrix))
@test typeof(parent((x->x).(u_matrix))) == typeof(parent(u_matrix))
@test typeof(parent((x -> x).(u_matrix))) == typeof(parent(u_matrix))

# test efficiency
num_allocs = @allocations foo!(u_matrix)
Expand Down
4 changes: 4 additions & 0 deletions test/interface_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ testva2 = similar(testva)
@test typeof(testva2) == typeof(testva)
@test size(testva2) == size(testva)

testva3 = similar(testva, 10)
@test typeof(testva3) == typeof(testva)
@test length(testva3) == 10

# Fill AbstractVectorOfArray and check all
testval = 3.0
fill!(testva2, testval)
Expand Down
7 changes: 3 additions & 4 deletions test/partitions_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ y = ArrayPartition(ArrayPartition([1], [2.0]), ArrayPartition([3], [4.0]))
@test all(isnan, ArrayPartition([NaN], [NaN]))
@test all(isnan, ArrayPartition([NaN], ArrayPartition([NaN])))


# broadcasting
_scalar_op(y) = y + 1
# Can't do `@inferred(_scalar_op.(x))` so we wrap that in a function:
Expand Down Expand Up @@ -303,7 +302,7 @@ end
end

@testset "Scalar copyto!" begin
u = [2.0,1.0]
copyto!(u, ArrayPartition(1.0,-1.2))
@test u == [1.0,-1.2]
u = [2.0, 1.0]
copyto!(u, ArrayPartition(1.0, -1.2))
@test u == [1.0, -1.2]
end
Loading