diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 84f45b26..0af4c507 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -156,13 +156,16 @@ function VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT <: AbstractArray VectorOfArray{T, N + 1, typeof(vec)}(vec) end -# allow multi-dimensional arrays as long as they're linearly indexed +# allow multi-dimensional arrays as long as they're linearly indexed. +# currently restricted to arrays whose elements are all the same type function VectorOfArray(array::AbstractArray{AT}) where {T, N, AT <: AbstractArray{T, N}} @assert IndexStyle(typeof(array)) isa IndexLinear return VectorOfArray{T, N + 1, typeof(array)}(array) end +Base.parent(vec::VectorOfArray) = vec.u + function DiffEqArray(vec::AbstractVector{T}, ts::AbstractVector, ::NTuple{N, Int}, @@ -721,6 +724,18 @@ end VectorOfArray([similar(VA[:, i], T) for i in eachindex(VA.u)]) end +# for VectorOfArray with multi-dimensional parent arrays of arrays where all elements are the same type +function Base.similar(vec::VectorOfArray{ + T, N, AT}) where {T, N, AT <: AbstractArray{<:AbstractArray{T}}} + return VectorOfArray(similar(Base.parent(vec))) +end + +# special-case when the multi-dimensional parent array is just an AbstractVector (call the old method) +function Base.similar(vec::VectorOfArray{ + T, N, AT}) where {T, N, AT <: AbstractVector{<:AbstractArray{T}}} + return Base.similar(vec, eltype(vec)) +end + # fill! # For DiffEqArray it ignores ts and fills only u function Base.fill!(VA::AbstractVectorOfArray, x) diff --git a/test/basic_indexing.jl b/test/basic_indexing.jl index 86fba66f..749f7419 100644 --- a/test/basic_indexing.jl +++ b/test/basic_indexing.jl @@ -250,6 +250,10 @@ foo!(u_matrix) foo!(u_vector) @test u_matrix ≈ u_vector +# test that, for VectorOfArray with multi-dimensional parent arrays, +# `similar` preserves the structure of the parent array +@test typeof(parent(similar(u_matrix))) == typeof(parent(u_matrix)) + # test efficiency num_allocs = @allocations foo!(u_matrix) @test num_allocs == 0 diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 79c99a1c..651c60cb 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -2,6 +2,7 @@ ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" @@ -10,4 +11,5 @@ ModelingToolkit = "8.33" MonteCarloMeasurements = "1.1" OrdinaryDiffEq = "6.31" Unitful = "1.17" -Tracker = "0.2" \ No newline at end of file +Tracker = "0.2" +StaticArrays = "1" diff --git a/test/upstream.jl b/test/downstream/odesolve.jl similarity index 84% rename from test/upstream.jl rename to test/downstream/odesolve.jl index 31f27a05..4aab1598 100644 --- a/test/upstream.jl +++ b/test/downstream/odesolve.jl @@ -1,4 +1,4 @@ -using OrdinaryDiffEq, NLsolve, RecursiveArrayTools, Test, ArrayInterface +using OrdinaryDiffEq, NLsolve, RecursiveArrayTools, Test, ArrayInterface, StaticArrays function lorenz(du, u, p, t) du[1] = 10.0 * (u[2] - u[1]) du[2] = u[1] * (28.0 - u[3]) - u[2] @@ -49,3 +49,14 @@ end ArrayPartition(zeros(1), [0.75])), (0.0, 1.0)), Rodas5()).retcode == ReturnCode.Success + +function rhs!(duu::VectorOfArray, uu::VectorOfArray, p, t) + du = parent(duu) + u = parent(uu) + du .= u +end + +u = fill(SVector{2}(ones(2)), 2, 3) +ode = ODEProblem(rhs!, VectorOfArray(u), (0.0, 1.0)) +sol = solve(ode, Tsit5()) +@test SciMLBase.successful_retcode(sol) diff --git a/test/runtests.jl b/test/runtests.jl index 69fda8b9..7ed2afcd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,67 +19,31 @@ end @time begin if GROUP == "Core" || GROUP == "All" - @time @safetestset "Quality Assurance" begin - include("qa.jl") - end - @time @safetestset "Utils Tests" begin - include("utils_test.jl") - end - @time @safetestset "NamedArrayPartition Tests" begin - include("named_array_partition_tests.jl") - end - @time @safetestset "Partitions Tests" begin - include("partitions_test.jl") - end - @time @safetestset "VecOfArr Indexing Tests" begin - include("basic_indexing.jl") - end - @time @safetestset "SymbolicIndexingInterface API test" begin - include("symbolic_indexing_interface_test.jl") - end - @time @safetestset "VecOfArr Interface Tests" begin - include("interface_tests.jl") - end - @time @safetestset "Table traits" begin - include("tabletraits.jl") - end - @time @safetestset "StaticArrays Tests" begin - include("copy_static_array_test.jl") - end - @time @safetestset "Linear Algebra Tests" begin - include("linalg.jl") - end - @time @safetestset "Upstream Tests" begin - include("upstream.jl") - end - @time @safetestset "Adjoint Tests" begin - include("adjoints.jl") - end - @time @safetestset "Measurement Tests" begin - include("measurements.jl") - end + @time @safetestset "Quality Assurance" include("qa.jl") + @time @safetestset "Utils Tests" include("utils_test.jl") + @time @safetestset "NamedArrayPartition Tests" include("named_array_partition_tests.jl") + @time @safetestset "Partitions Tests" include("partitions_test.jl") + @time @safetestset "VecOfArr Indexing Tests" include("basic_indexing.jl") + @time @safetestset "SymbolicIndexingInterface API test" include("symbolic_indexing_interface_test.jl") + @time @safetestset "VecOfArr Interface Tests" include("interface_tests.jl") + @time @safetestset "Table traits" include("tabletraits.jl") + @time @safetestset "StaticArrays Tests" include("copy_static_array_test.jl") + @time @safetestset "Linear Algebra Tests" include("linalg.jl") + @time @safetestset "Adjoint Tests" include("adjoints.jl") + @time @safetestset "Measurement Tests" include("measurements.jl") end if GROUP == "Downstream" activate_downstream_env() - @time @safetestset "DiffEqArray Indexing Tests" begin - include("downstream/symbol_indexing.jl") - end - @time @safetestset "Event Tests with ArrayPartition" begin - include("downstream/downstream_events.jl") - end - @time @safetestset "Measurements and Units" begin - include("downstream/measurements_and_units.jl") - end - @time @safetestset "TrackerExt" begin - include("downstream/TrackerExt.jl") - end + @time @safetestset "DiffEqArray Indexing Tests" include("downstream/symbol_indexing.jl") + @time @safetestset "ODE Solve Tests" include("downstream/odesolve.jl") + @time @safetestset "Event Tests with ArrayPartition" include("downstream/downstream_events.jl") + @time @safetestset "Measurements and Units" include("downstream/measurements_and_units.jl") + @time @safetestset "TrackerExt" include("downstream/TrackerExt.jl") end if GROUP == "GPU" activate_gpu_env() - @time @safetestset "VectorOfArray GPU" begin - include("gpu/vectorofarray_gpu.jl") - end + @time @safetestset "VectorOfArray GPU" include("gpu/vectorofarray_gpu.jl") end end