Skip to content

Commit c33c000

Browse files
Merge pull request #88 from SciML/linear_algebra
support linear algebra on ArrayPartition
2 parents 05a5a49 + 3a57acc commit c33c000

File tree

7 files changed

+68
-13
lines changed

7 files changed

+68
-13
lines changed

.travis.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ language: julia
33
os:
44
- linux
55
julia:
6-
- 1.0
7-
- nightly
8-
matrix:
9-
allow_failures:
10-
- julia: nightly
6+
- 1
7+
# - nightly
8+
#matrix:
9+
# allow_failures:
10+
# - julia: nightly
1111
notifications:
1212
email: false
1313
# uncomment the following lines to override the default test script

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,22 @@ version = "2.2.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
8+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9+
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
10+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
811
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
912
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1013
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1114
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1215
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1316

1417
[compat]
15-
ArrayInterface = "1.2, 2.0"
18+
ArrayInterface = "2.7"
1619
RecipesBase = "0.7, 0.8, 1.0"
1720
Requires = "0.5, 1.0"
1821
StaticArrays = "0.10, 0.11, 0.12"
1922
ZygoteRules = "0.2"
20-
julia = "1"
23+
julia = "1.3"
2124

2225
[extras]
2326
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

src/RecursiveArrayTools.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ __precompile__()
33
module RecursiveArrayTools
44

55
using Requires, RecipesBase, StaticArrays, Statistics,
6-
ArrayInterface, ZygoteRules
6+
ArrayInterface, ZygoteRules, LinearAlgebra
77

88
abstract type AbstractVectorOfArray{T, N, A} <: AbstractArray{T, N} end
99
abstract type AbstractDiffEqArray{T, N, A} <: AbstractVectorOfArray{T, N, A} end

src/array_partition.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,8 @@ ArrayPartitionStyle(::Val{N}) where N = ArrayPartitionStyle{Broadcast.DefaultArr
233233
function Broadcast.BroadcastStyle(::ArrayPartitionStyle{AStyle}, ::ArrayPartitionStyle{BStyle}) where {AStyle, BStyle}
234234
ArrayPartitionStyle(Broadcast.BroadcastStyle(AStyle(), BStyle()))
235235
end
236+
Broadcast.BroadcastStyle(::ArrayPartitionStyle, ::Broadcast.DefaultArrayStyle{0}) = Broadcast.DefaultArrayStyle{1}()
237+
Broadcast.BroadcastStyle(::ArrayPartitionStyle, ::Broadcast.DefaultArrayStyle{N}) where N = Broadcast.DefaultArrayStyle{N}()
236238

237239
combine_styles(args::Tuple{}) = Broadcast.DefaultArrayStyle{0}()
238240
combine_styles(args::Tuple{Any}) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]))
@@ -252,7 +254,7 @@ end
252254
ArrayPartition(f, N)
253255
end
254256

255-
@inline function Base.copyto!(dest::ArrayPartition, bc::Broadcast.Broadcasted)
257+
@inline function Base.copyto!(dest::ArrayPartition, bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}) where Style
256258
N = npartitions(dest, bc)
257259
for i in 1:N
258260
copyto!(dest.x[i], unpack(bc, i))
@@ -293,3 +295,10 @@ common_number(a, b) =
293295
(b == 0 ? a :
294296
(a == b ? a :
295297
throw(DimensionMismatch("number of partitions must be equal"))))
298+
299+
## Linear Algebra
300+
301+
ArrayInterface.zeromatrix(A::ArrayPartition) = ArrayInterface.zeromatrix(reduce(vcat,vec.(A.x)))
302+
LinearAlgebra.ldiv!(A::LinearAlgebra.LU,b::ArrayPartition) = ldiv!(A,Array(b))
303+
LinearAlgebra.ldiv!(A::LinearAlgebra.QR,b::ArrayPartition) = ldiv!(A,Array(b))
304+
LinearAlgebra.ldiv!(A::LinearAlgebra.SVD,b::ArrayPartition) = ldiv!(A,Array(b))

test/partitions_test.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,21 @@ a = 5
2626
@. p = p*p2
2727
K = p.*p2
2828

29-
@test_broken p.*rand(10)
29+
x = rand(10)
30+
y = p.*x
31+
@test y[1:5] == p.x[1] .* x[1:5]
32+
@test y[6:10] == p.x[2] .* x[6:10]
33+
y = p.*x'
34+
for i in 1:10
35+
@test y[1:5,i] == p.x[1] .* x[i]
36+
@test y[6:10,i] == p.x[2] .* x[i]
37+
end
38+
y = p .* p'
39+
@test y[1:5,1:5] == p.x[1] .* p.x[1]'
40+
@test y[6:10,6:10] == p.x[2] .* p.x[2]'
41+
@test y[1:5,6:10] == p.x[1] .* p.x[2]'
42+
@test y[6:10,1:5] == p.x[2] .* p.x[1]'
43+
3044
b = rand(10)
3145
c = rand(10)
3246
copyto!(b,p)
@@ -94,13 +108,11 @@ S = [
94108
]
95109

96110
for sizes in S
97-
x = ArrayPartition( randn.(sizes[1]) )
111+
x = ArrayPartition( randn.(sizes[1]) )
98112
y = ArrayPartition( zeros.(sizes[2]) )
99113
y_array = zeros(length(x))
100114
copyto!(y,x) #testing Base.copyto!(dest::ArrayPartition,A::ArrayPartition)
101115
copyto!(y_array,x) #testing Base.copyto!(dest::Array,A::ArrayPartition)
102116
@test all([x[i] == y[i] for i in eachindex(x)])
103117
@test all([x[i] == y_array[i] for i in eachindex(x)])
104118
end
105-
106-

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ using Test
77
@time @testset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end
88
@time @testset "VecOfArr Interface Tests" begin include("interface_tests.jl") end
99
@time @testset "StaticArrays Tests" begin include("copy_static_array_test.jl") end
10+
@time @testset "Upstream Tests" begin include("upstream.jl") end
1011
end

test/upstream.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
using OrdinaryDiffEq, NLsolve, RecursiveArrayTools, Test, ArrayInterface
2+
function lorenz(du,u,p,t)
3+
du[1] = 10.0*(u[2]-u[1])
4+
du[2] = u[1]*(28.0-u[3]) - u[2]
5+
du[3] = u[1]*u[2] - (8/3)*u[3]
6+
end
7+
u0 = ArrayPartition([1.0,0.0],[0.0])
8+
@test ArrayInterface.zeromatrix(u0) isa Matrix
9+
tspan = (0.0,100.0)
10+
prob = ODEProblem(lorenz,u0,tspan)
11+
sol = solve(prob,Tsit5())
12+
sol = solve(prob,AutoTsit5(Rosenbrock23(autodiff=false)))
13+
sol = solve(prob,AutoTsit5(Rosenbrock23()))
14+
15+
function f!(F, vars)
16+
x = vars.x[1]
17+
F.x[1][1] = (x[1]+3)*(x[2]^3-7)+18
18+
F.x[1][2] = sin(x[2]*exp(x[1])-1)
19+
y=vars.x[2]
20+
F.x[2][1] = (y[1]+3)*(y[2]^3-7)+18
21+
F.x[2][2] = sin(y[2]*exp(y[1])-1)
22+
end
23+
24+
# To show that the function works
25+
F = ArrayPartition([0.0 0.0],[0.0, 0.0])
26+
u0= ArrayPartition([0.1; 1.2], [0.1; 1.2])
27+
result = f!(F, u0)
28+
29+
# To show the NLsolve error that results with ArrayPartitions:
30+
nlsolve(f!, ArrayPartition([0.1; 1.2], [0.1; 1.2]))

0 commit comments

Comments
 (0)