From 58d77e1efcb6710cfcea7ca59a6b98dc61873f10 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 11 Aug 2022 16:17:36 +0200 Subject: [PATCH 1/2] Implement table traits for DiffEq arrays --- Project.toml | 6 ++- src/RecursiveArrayTools.jl | 3 ++ src/tabletraits.jl | 77 ++++++++++++++++++++++++++++++ test/downstream/symbol_indexing.jl | 32 ++++++++++--- test/runtests.jl | 1 + test/tabletraits.jl | 15 ++++++ test/testutils.jl | 60 +++++++++++++++++++++++ 7 files changed, 187 insertions(+), 7 deletions(-) create mode 100644 src/tabletraits.jl create mode 100644 test/tabletraits.jl create mode 100644 test/testutils.jl diff --git a/Project.toml b/Project.toml index 6cc771fc..13ec3a98 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "RecursiveArrayTools" uuid = "731186ca-8d62-57ce-b412-fbd966d074cd" authors = ["Chris Rackauckas "] -version = "2.31.3" +version = "2.32.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -11,10 +11,12 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] @@ -25,8 +27,10 @@ ChainRulesCore = "0.10.7, 1" DocStringExtensions = "0.8, 0.9" FillArrays = "0.11, 0.12, 0.13" GPUArraysCore = "0.1" +IteratorInterfaceExtensions = "1" RecipesBase = "0.7, 0.8, 1.0" StaticArraysCore = "1" +Tables = "1" ZygoteRules = "0.2" julia = "1.6" diff --git a/src/RecursiveArrayTools.jl b/src/RecursiveArrayTools.jl index 25529e1e..407346ad 100644 --- a/src/RecursiveArrayTools.jl +++ b/src/RecursiveArrayTools.jl @@ -19,11 +19,14 @@ import ArrayInterfaceStaticArraysCore using FillArrays +import Tables, IteratorInterfaceExtensions + abstract type AbstractVectorOfArray{T, N, A} <: AbstractArray{T, N} end abstract type AbstractDiffEqArray{T, N, A} <: AbstractVectorOfArray{T, N, A} end include("utils.jl") include("vector_of_array.jl") +include("tabletraits.jl") include("array_partition.jl") include("zygote.jl") diff --git a/src/tabletraits.jl b/src/tabletraits.jl new file mode 100644 index 00000000..e7a70492 --- /dev/null +++ b/src/tabletraits.jl @@ -0,0 +1,77 @@ +# Tables traits for AbstractDiffEqArray +Tables.istable(::Type{<:AbstractDiffEqArray}) = true +Tables.rowaccess(::Type{<:AbstractDiffEqArray}) = true +function Tables.rows(A::AbstractDiffEqArray) + VT = eltype(A.u) + if VT <: AbstractArray + N = length(A.u[1]) + names = [ + :timestamp, + (A.syms !== nothing ? (A.syms[i] for i in 1:N) : + (Symbol("value", i) for i in 1:N))..., + ] + types = Type[eltype(A.t), (eltype(A.u[1]) for _ in 1:N)...] + else + names = [:timestamp, A.syms !== nothing ? A.syms[1] : :value] + types = Type[eltype(A.t), VT] + end + return AbstractDiffEqArrayRows(names, types, A.t, A.u) +end + +# Override fallback definitions for AbstractMatrix +Tables.istable(::AbstractDiffEqArray) = true # Ref: https://github.com/JuliaData/Tables.jl/pull/198 +Tables.columns(x::AbstractDiffEqArray) = Tables.columntable(Tables.rows(x)) + +# Iterator of Tables.AbstractRow rows +struct AbstractDiffEqArrayRows{T, U} + names::Vector{Symbol} + types::Vector{Type} + lookup::Dict{Symbol, Int} + t::T + u::U +end +function AbstractDiffEqArrayRows(names, types, t, u) + AbstractDiffEqArrayRows(names, types, + Dict(nm => i for (i, nm) in enumerate(names)), t, u) +end + +Base.length(x::AbstractDiffEqArrayRows) = length(x.u) +function Base.eltype(::Type{AbstractDiffEqArrayRows{T, U}}) where {T, U} + AbstractDiffEqArrayRow{eltype(T), eltype(U)} +end +function Base.iterate(x::AbstractDiffEqArrayRows, (t_state, u_state)=(iterate(x.t), iterate(x.u))) + t_state === nothing && return nothing + u_state === nothing && return nothing + t, _t_state = t_state + u, _u_state = u_state + st = (iterate(x.t, _t_state), iterate(x.u, _u_state)) + return (AbstractDiffEqArrayRow(x.names, x.lookup, t, u), st) +end + +Tables.istable(::Type{<:AbstractDiffEqArrayRows}) = true +Tables.rowaccess(::Type{<:AbstractDiffEqArrayRows}) = true +Tables.rows(x::AbstractDiffEqArrayRows) = x +Tables.schema(x::AbstractDiffEqArrayRows) = Tables.Schema(x.names, x.types) + +# AbstractRow subtype +struct AbstractDiffEqArrayRow{T, U} <: Tables.AbstractRow + names::Vector{Symbol} + lookup::Dict{Symbol, Int} + t::T + u::U +end + +Tables.columnnames(x::AbstractDiffEqArrayRow) = getfield(x, :names) +function Tables.getcolumn(x::AbstractDiffEqArrayRow, i::Int) + i == 1 ? getfield(x, :t) : getfield(x, :u)[i - 1] +end +function Tables.getcolumn(x::AbstractDiffEqArrayRow, nm::Symbol) + nm === :timestamp ? getfield(x, :t) : getfield(x, :u)[getfield(x, :lookup)[nm] - 1] +end + +# Iterator interface for QueryVerse +# (see also https://tables.juliadata.org/stable/#Tables.datavaluerows) +IteratorInterfaceExtensions.isiterable(::AbstractDiffEqArray) = true +function IteratorInterfaceExtensions.getiterator(A::AbstractDiffEqArray) + Tables.datavaluerows(Tables.rows(A)) +end diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index 7f32198e..2b7723ef 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -1,25 +1,45 @@ using RecursiveArrayTools, ModelingToolkit, OrdinaryDiffEq, Test +include("../testutils.jl") + @variables t x(t) @parameters τ D = Differential(t) @variables RHS(t) @named fol_separate = ODESystem([ RHS ~ (1 - x)/τ, D(x) ~ RHS ]) -fol_simplified = structural_simplify(fol_separate) +fol_simplified = structural_simplify(fol_separate) prob = ODEProblem(fol_simplified, [x => 0.0], (0.0,10.0), [τ => 3.0]) sol = solve(prob, Tsit5()) sol_new = DiffEqArray( sol.u[1:10], - sol.t[1:10], - sol.prob.f.syms, - sol.prob.f.indepsym, - sol.prob.f.observed, + sol.t[1:10], + sol.prob.f.syms, + sol.prob.f.indepsym, + sol.prob.f.observed, sol.prob.p ) @test sol_new[RHS] ≈ (1 .- sol_new[x])./3.0 @test sol_new[t] ≈ sol_new.t -@test sol_new[t, 1:5] ≈ sol_new.t[1:5] \ No newline at end of file +@test sol_new[t, 1:5] ≈ sol_new.t[1:5] + +# Tables interface +test_tables_interface(sol_new, [:timestamp, Symbol("x(t)")], hcat(sol_new[t], sol_new[x])) + +# Two components +@variables y(t) +@parameters α β γ δ +@named lv = ODESystem([ D(x) ~ α * x - β * x * y, + D(y) ~ δ * x * y - γ * x * y]) + +prob = ODEProblem(lv, [x => 1.0, y => 1.0], (0.0, 10.0), + [α => 1.5, β => 1.0, γ => 3.0, δ => 1.0]) +sol = solve(prob, Tsit5()) + +ts = 0:0.5:10 +sol_ts = sol(ts) +@assert sol_ts isa DiffEqArray +test_tables_interface(sol_ts, [:timestamp, Symbol("x(t)"), Symbol("y(t)")], hcat(ts, Array(sol_ts)')) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index efa80e65..69eec144 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,6 +24,7 @@ if GROUP == "Core" || GROUP == "All" @time @testset "Partitions Tests" begin include("partitions_test.jl") end @time @testset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end @time @testset "VecOfArr Interface Tests" begin include("interface_tests.jl") end + @time @testset "Table traits" begin include("tabletraits.jl") end @time @testset "StaticArrays Tests" begin include("copy_static_array_test.jl") end @time @testset "Linear Algebra Tests" begin include("linalg.jl") end @time @testset "Upstream Tests" begin include("upstream.jl") end diff --git a/test/tabletraits.jl b/test/tabletraits.jl new file mode 100644 index 00000000..a9524483 --- /dev/null +++ b/test/tabletraits.jl @@ -0,0 +1,15 @@ +using RecursiveArrayTools, Random, Test + +include("testutils.jl") + +Random.seed!(1234) + +n = 20 +t = sort(randn(n)) +u = randn(n) +A = DiffEqArray(u, t) +test_tables_interface(A, [:timestamp, :value], hcat(t, u)) + +u = [randn(3) for _ in 1:n] +A = DiffEqArray(u, t) +test_tables_interface(A, [:timestamp, :value1, :value2, :value3], hcat(t, reduce(vcat, u'))) diff --git a/test/testutils.jl b/test/testutils.jl new file mode 100644 index 00000000..0830849b --- /dev/null +++ b/test/testutils.jl @@ -0,0 +1,60 @@ +using RecursiveArrayTools +using RecursiveArrayTools: Tables, IteratorInterfaceExtensions + +# Test Tables interface with row access + IteratorInterfaceExtensions for QueryVerse +# (see https://tables.juliadata.org/stable/#Testing-Tables.jl-Implementations) +function test_tables_interface(x::AbstractDiffEqArray, names::Vector{Symbol}, values::Matrix) + @assert length(names) == size(values, 2) + + # AbstractDiffEqArray is a table with row access + @test Tables.istable(x) + @test Tables.istable(typeof(x)) + @test Tables.rowaccess(x) + @test Tables.rowaccess(typeof(x)) + @test !Tables.columnaccess(x) + @test !Tables.columnaccess(typeof(x)) + + # Check implementation of AbstractRow iterator + tbl = Tables.rows(x) + @test length(tbl) == size(values, 1) + @test Tables.istable(tbl) + @test Tables.istable(typeof(tbl)) + @test Tables.rowaccess(tbl) + @test Tables.rowaccess(typeof(tbl)) + @test Tables.rows(tbl) === tbl + + # Check implementation of AbstractRow subtype + for (i, row) in enumerate(tbl) + @test eltype(tbl) === typeof(row) + @test propertynames(row) == Tables.columnnames(row) == names + for (j, name) in enumerate(names) + @test getproperty(row, name) == Tables.getcolumn(row, name) == Tables.getcolumn(row, j) == values[i, j] + end + end + + # Check column access + coltbl = Tables.columns(x) + @test length(coltbl) == size(values, 2) + @test Tables.istable(coltbl) + @test Tables.istable(typeof(coltbl)) + @test Tables.columnaccess(coltbl) + @test Tables.columnaccess(typeof(coltbl)) + @test Tables.columns(coltbl) === coltbl + @test propertynames(coltbl) == Tables.columnnames(coltbl) == Tuple(names) + for (i, name) in enumerate(names) + @test getproperty(coltbl, name) == Tables.getcolumn(coltbl, name) == Tables.getcolumn(coltbl, i) == values[:, i] + end + + # IteratorInterfaceExtensions + @test IteratorInterfaceExtensions.isiterable(x) + iterator = IteratorInterfaceExtensions.getiterator(x) + for (i, row) in enumerate(iterator) + @test row isa NamedTuple + @test propertynames(row) == Tuple(names) + for (j, name) in enumerate(names) + @test getproperty(row, name) == row[j] == values[i, j] + end + end + + nothing +end From b26e35e5ce28e435415e9df1ef6ba480c5da0f56 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 11 Aug 2022 16:46:23 +0200 Subject: [PATCH 2/2] Revert unintentional changes --- test/downstream/symbol_indexing.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/downstream/symbol_indexing.jl b/test/downstream/symbol_indexing.jl index 2b7723ef..b8fce572 100644 --- a/test/downstream/symbol_indexing.jl +++ b/test/downstream/symbol_indexing.jl @@ -15,10 +15,10 @@ sol = solve(prob, Tsit5()) sol_new = DiffEqArray( sol.u[1:10], - sol.t[1:10], - sol.prob.f.syms, - sol.prob.f.indepsym, - sol.prob.f.observed, + sol.t[1:10], + sol.prob.f.syms, + sol.prob.f.indepsym, + sol.prob.f.observed, sol.prob.p ) @@ -42,4 +42,4 @@ sol = solve(prob, Tsit5()) ts = 0:0.5:10 sol_ts = sol(ts) @assert sol_ts isa DiffEqArray -test_tables_interface(sol_ts, [:timestamp, Symbol("x(t)"), Symbol("y(t)")], hcat(ts, Array(sol_ts)')) \ No newline at end of file +test_tables_interface(sol_ts, [:timestamp, Symbol("x(t)"), Symbol("y(t)")], hcat(ts, Array(sol_ts)'))