Skip to content

Commit f8cfd1c

Browse files
fix: update compat entires, reduce method ambiguities
1 parent 094c2bb commit f8cfd1c

File tree

8 files changed

+74
-112
lines changed

8 files changed

+74
-112
lines changed

Project.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
1212
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1313
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1414
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
15+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1516
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1617
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1718
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
@@ -39,15 +40,17 @@ GPUArraysCore = "0.1"
3940
IteratorInterfaceExtensions = "1"
4041
LabelledArrays = "1"
4142
LinearAlgebra = "1"
42-
Measurements = "2"
43-
MonteCarloMeasurements = "1"
43+
Measurements = "2.3"
44+
MonteCarloMeasurements = "1.1"
4445
NLsolve = "4"
46+
OrdinaryDiffEq = "6"
4547
Pkg = "1"
4648
Random = "1"
4749
RecipesBase = "0.7, 0.8, 1.0"
4850
Requires = "1.0"
4951
SafeTestsets = "0.1"
50-
StaticArrays = "0.12"
52+
SparseArrays = "1"
53+
StaticArrays = "1.6"
5154
StaticArraysCore = "1.1"
5255
Statistics = "1"
5356
StructArrays = "0.6"

src/RecursiveArrayTools.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using DocStringExtensions
88
using RecipesBase, StaticArraysCore, Statistics,
99
ArrayInterface, LinearAlgebra
1010
using SymbolicIndexingInterface
11+
using SparseArrays
1112

1213
import Adapt
1314

src/array_partition.jl

Lines changed: 55 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -176,14 +176,16 @@ Base.all(f, A::ArrayPartition) = all(f, (all(f, x) for x in A.x))
176176
Base.all(f::Function, A::ArrayPartition) = all((all(f, x) for x in A.x))
177177
Base.all(A::ArrayPartition) = all(identity, A)
178178

179-
function Base.copyto!(dest::AbstractArray, A::ArrayPartition)
180-
@assert length(dest) == length(A)
181-
cur = 1
182-
@inbounds for i in 1:length(A.x)
183-
dest[cur:(cur + length(A.x[i]) - 1)] .= vec(A.x[i])
184-
cur += length(A.x[i])
179+
for type in [AbstractArray, SparseArrays.AbstractCompressedVector, PermutedDimsArray]
180+
@eval function Base.copyto!(dest::$(type), A::ArrayPartition)
181+
@assert length(dest) == length(A)
182+
cur = 1
183+
@inbounds for i in 1:length(A.x)
184+
dest[cur:(cur + length(A.x[i]) - 1)] .= vec(A.x[i])
185+
cur += length(A.x[i])
186+
end
187+
dest
185188
end
186-
dest
187189
end
188190

189191
function Base.copyto!(A::ArrayPartition, src::ArrayPartition)
@@ -419,30 +421,38 @@ end
419421

420422
ArrayInterface.zeromatrix(A::ArrayPartition) = ArrayInterface.zeromatrix(Vector(A))
421423

422-
function LinearAlgebra.ldiv!(A::Factorization, b::ArrayPartition)
423-
(x = ldiv!(A, Array(b)); copyto!(b, x))
424+
function __get_subtypes_in_module(mod, supertype; include_supertype = true, all=false, except=[])
425+
return filter([getproperty(mod, name) for name in names(mod; all) if !in(name, except)]) do value
426+
return value isa Type && (value <: supertype) && (include_supertype || value != supertype) && !in(value, except)
427+
end
424428
end
425429

426-
@static if VERSION >= v"1.9"
427-
function LinearAlgebra.ldiv!(A::LinearAlgebra.SVD{T, Tr, M},
428-
b::ArrayPartition) where {Tr, T, M <: AbstractArray{T}}
430+
for factorization in vcat(__get_subtypes_in_module(LinearAlgebra, Factorization; include_supertype = false, all=true, except=[:LU, :LAPACKFactorizations]), LDLt{T,<:SymTridiagonal{T,V} where {V<:AbstractVector{T}}} where {T})
431+
@eval function LinearAlgebra.ldiv!(A::T, b::ArrayPartition) where {T<:$factorization}
429432
(x = ldiv!(A, Array(b)); copyto!(b, x))
430433
end
434+
end
431435

432-
function LinearAlgebra.ldiv!(A::LinearAlgebra.QRCompactWY{T, M, C},
433-
b::ArrayPartition) where {
434-
T <: Union{Float32, Float64, ComplexF64, ComplexF32},
435-
M <: AbstractMatrix{T},
436-
C <: AbstractMatrix{T},
437-
}
438-
(x = ldiv!(A, Array(b)); copyto!(b, x))
439-
end
436+
function LinearAlgebra.ldiv!(A::LinearAlgebra.SVD{T, Tr, M},
437+
b::ArrayPartition) where {Tr, T, M <: AbstractArray{T}}
438+
(x = ldiv!(A, Array(b)); copyto!(b, x))
440439
end
441440

442-
function LinearAlgebra.ldiv!(A::LU, b::ArrayPartition)
443-
LinearAlgebra._ipiv_rows!(A, 1:length(A.ipiv), b)
444-
ldiv!(UpperTriangular(A.factors), ldiv!(UnitLowerTriangular(A.factors), b))
445-
return b
441+
function LinearAlgebra.ldiv!(A::LinearAlgebra.QRCompactWY{T, M, C},
442+
b::ArrayPartition) where {
443+
T <: Union{Float32, Float64, ComplexF64, ComplexF32},
444+
M <: AbstractMatrix{T},
445+
C <: AbstractMatrix{T},
446+
}
447+
(x = ldiv!(A, Array(b)); copyto!(b, x))
448+
end
449+
450+
for type in [LU, LU{T,Tridiagonal{T,V}} where {T,V}]
451+
@eval function LinearAlgebra.ldiv!(A::$type, b::ArrayPartition)
452+
LinearAlgebra._ipiv_rows!(A, 1:length(A.ipiv), b)
453+
ldiv!(UpperTriangular(A.factors), ldiv!(UnitLowerTriangular(A.factors), b))
454+
return b
455+
end
446456
end
447457

448458
# block matrix indexing
@@ -458,78 +468,31 @@ end
458468
# [U11 U12 U13] [ b1 ]
459469
# [ 0 U22 U23] \ [ b2 ]
460470
# [ 0 0 U33] [ b3 ]
461-
function LinearAlgebra.ldiv!(A::UnitUpperTriangular, bb::ArrayPartition)
462-
A = A.data
463-
n = npartitions(bb)
464-
b = bb.x
465-
lens = map(length, b)
466-
@inbounds for j in n:-1:1
467-
Ajj = UnitUpperTriangular(getblock(A, lens, j, j))
468-
xj = ldiv!(Ajj, vec(b[j]))
469-
for i in (j - 1):-1:1
470-
Aij = getblock(A, lens, i, j)
471-
# bi = -Aij * xj + bi
472-
mul!(vec(b[i]), Aij, xj, -1, true)
473-
end
474-
end
475-
return bb
476-
end
477-
478-
function LinearAlgebra.ldiv!(A::UpperTriangular, bb::ArrayPartition)
479-
A = A.data
480-
n = npartitions(bb)
481-
b = bb.x
482-
lens = map(length, b)
483-
@inbounds for j in n:-1:1
484-
Ajj = UpperTriangular(getblock(A, lens, j, j))
485-
xj = ldiv!(Ajj, vec(b[j]))
486-
for i in (j - 1):-1:1
487-
Aij = getblock(A, lens, i, j)
488-
# bi = -Aij * xj + bi
489-
mul!(vec(b[i]), Aij, xj, -1, true)
490-
end
491-
end
492-
return bb
493-
end
494-
495-
function LinearAlgebra.ldiv!(A::UnitLowerTriangular, bb::ArrayPartition)
496-
A = A.data
497-
n = npartitions(bb)
498-
b = bb.x
499-
lens = map(length, b)
500-
@inbounds for j in 1:n
501-
Ajj = UnitLowerTriangular(getblock(A, lens, j, j))
502-
xj = ldiv!(Ajj, vec(b[j]))
503-
for i in (j + 1):n
504-
Aij = getblock(A, lens, i, j)
505-
# bi = -Aij * xj + b[i]
506-
mul!(vec(b[i]), Aij, xj, -1, true)
471+
for basetype in [UnitUpperTriangular, UpperTriangular, UnitLowerTriangular, LowerTriangular]
472+
for type in [basetype, basetype{T, <:Adjoint{T}} where {T}, basetype{T, <:Transpose{T}} where {T}]
473+
j_iter, i_iter = if basetype <: UnitUpperTriangular || basetype <: UpperTriangular
474+
(:(n:-1:1), :(j-1:-1:1))
475+
else
476+
(:(1:n), :((j+1):n))
507477
end
508-
end
509-
return bb
510-
end
511-
function _ldiv!(A::LowerTriangular, bb::ArrayPartition)
512-
A = A.data
513-
n = npartitions(bb)
514-
b = bb.x
515-
lens = map(length, b)
516-
@inbounds for j in 1:n
517-
Ajj = LowerTriangular(getblock(A, lens, j, j))
518-
xj = ldiv!(Ajj, vec(b[j]))
519-
for i in (j + 1):n
520-
Aij = getblock(A, lens, i, j)
521-
# bi = -Aij * xj + b[i]
522-
mul!(vec(b[i]), Aij, xj, -1, true)
478+
@eval function LinearAlgebra.ldiv!(A::$type, bb::ArrayPartition)
479+
A = A.data
480+
n = npartitions(bb)
481+
b = bb.x
482+
lens = map(length, b)
483+
@inbounds for j in $j_iter
484+
Ajj = $basetype(getblock(A, lens, j, j))
485+
xj = ldiv!(Ajj, vec(b[j]))
486+
for i in $i_iter
487+
Aij = getblock(A, lens, i, j)
488+
# bi = -Aij * xj + bi
489+
mul!(vec(b[i]), Aij, xj, -1, true)
490+
end
491+
end
492+
return bb
523493
end
524494
end
525-
return bb
526-
end
527-
528-
function LinearAlgebra.ldiv!(A::LowerTriangular{T, <:LinearAlgebra.Adjoint{T}},
529-
bb::ArrayPartition) where {T}
530-
_ldiv!(A, bb)
531495
end
532-
LinearAlgebra.ldiv!(A::LowerTriangular, bb::ArrayPartition) = _ldiv!(A, bb)
533496

534497
# TODO: optimize
535498
function LinearAlgebra._ipiv_rows!(A::LU, order::OrdinalRange, B::ArrayPartition)

src/vector_of_array.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ function VectorOfArray(vec::AbstractVector{VT}) where {T, N, VT <: AbstractArray
139139
end
140140

141141
function DiffEqArray(vec::AbstractVector{T},
142-
ts,
142+
ts::AbstractVector,
143143
::NTuple{N, Int},
144144
p = nothing,
145145
sys = nothing) where {T, N}
@@ -532,7 +532,7 @@ function Base.CartesianIndices(VA::AbstractVectorOfArray)
532532
end
533533

534534
# Tools for creating similar objects
535-
Base.eltype(::VectorOfArray{T}) where {T} = T
535+
Base.eltype(::Type{<:AbstractVectorOfArray{T}}) where {T} = T
536536
# TODO: Is there a better way to do this?
537537
@inline function Base.similar(VA::AbstractVectorOfArray, args...)
538538
if args[end] isa Type

test/interface_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using RecursiveArrayTools, Test
1+
using RecursiveArrayTools, StaticArrays, Test
22

33
t = 1:3
44
testva = VectorOfArray([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

test/qa.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
using RecursiveArrayTools, Aqua
22
@testset "Aqua" begin
33
Aqua.find_persistent_tasks_deps(RecursiveArrayTools)
4-
Aqua.test_ambiguities(RecursiveArrayTools, recursive = false, broken = true)
4+
ambs = Aqua.detect_ambiguities(RecursiveArrayTools; recursive = true)
5+
@warn "Number of method ambiguities: $(length(ambs))"
6+
@test length(ambs) <= 2
57
Aqua.test_deps_compat(RecursiveArrayTools)
68
Aqua.test_piracies(RecursiveArrayTools)
79
Aqua.test_project_extras(RecursiveArrayTools)

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ end
5151
@time @safetestset "Upstream Tests" begin
5252
include("upstream.jl")
5353
end
54-
# @time @safetestset "Adjoint Tests" begin include("adjoints.jl") end
54+
@time @safetestset "Adjoint Tests" begin include("adjoints.jl") end
5555
@time @safetestset "Measurement Tests" begin
5656
include("measurements.jl")
5757
end
@@ -65,7 +65,7 @@ end
6565
@time @safetestset "Event Tests with ArrayPartition" begin
6666
include("downstream/downstream_events.jl")
6767
end
68-
VERSION >= v"1.9" && @time @safetestset "Measurements and Units" begin
68+
@time @safetestset "Measurements and Units" begin
6969
include("downstream/measurements_and_units.jl")
7070
end
7171
@time @safetestset "TrackerExt" begin

test/upstream.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,7 @@ end
4141
ArrayPartition(zeros(1), [0.75])),
4242
(0.0, 1.0)), AutoTsit5(Rodas5())).retcode == ReturnCode.Success
4343

44-
if VERSION < v"1.7"
45-
@test solve(ODEProblem(dyn,
46-
ArrayPartition(ArrayPartition(zeros(1), [-1.0]),
47-
ArrayPartition(zeros(1), [0.75])),
48-
(0.0, 1.0)), Rodas5()).retcode == ReturnCode.Success
49-
else
50-
@test_broken solve(ODEProblem(dyn,
51-
ArrayPartition(ArrayPartition(zeros(1), [-1.0]),
52-
ArrayPartition(zeros(1), [0.75])),
53-
(0.0, 1.0)), Rodas5()).retcode == ReturnCode.Success
54-
end
44+
@test_broken solve(ODEProblem(dyn,
45+
ArrayPartition(ArrayPartition(zeros(1), [-1.0]),
46+
ArrayPartition(zeros(1), [0.75])),
47+
(0.0, 1.0)), Rodas5()).retcode == ReturnCode.Success

0 commit comments

Comments
 (0)