Skip to content

Commit 3a57acc

Browse files
fix broadcast against non APs and test upstream
1 parent fc06f6a commit 3a57acc

File tree

3 files changed

+23
-9
lines changed

3 files changed

+23
-9
lines changed

src/array_partition.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,8 @@ ArrayPartitionStyle(::Val{N}) where N = ArrayPartitionStyle{Broadcast.DefaultArr
233233
function Broadcast.BroadcastStyle(::ArrayPartitionStyle{AStyle}, ::ArrayPartitionStyle{BStyle}) where {AStyle, BStyle}
234234
ArrayPartitionStyle(Broadcast.BroadcastStyle(AStyle(), BStyle()))
235235
end
236+
Broadcast.BroadcastStyle(::ArrayPartitionStyle, ::Broadcast.DefaultArrayStyle{0}) = Broadcast.DefaultArrayStyle{1}()
237+
Broadcast.BroadcastStyle(::ArrayPartitionStyle, ::Broadcast.DefaultArrayStyle{N}) where N = Broadcast.DefaultArrayStyle{N}()
236238

237239
combine_styles(args::Tuple{}) = Broadcast.DefaultArrayStyle{0}()
238240
combine_styles(args::Tuple{Any}) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]))
@@ -252,7 +254,7 @@ end
252254
ArrayPartition(f, N)
253255
end
254256

255-
@inline function Base.copyto!(dest::ArrayPartition, bc::Broadcast.Broadcasted)
257+
@inline function Base.copyto!(dest::ArrayPartition, bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}) where Style
256258
N = npartitions(dest, bc)
257259
for i in 1:N
258260
copyto!(dest.x[i], unpack(bc, i))

test/partitions_test.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,21 @@ a = 5
2626
@. p = p*p2
2727
K = p.*p2
2828

29-
@test_broken p.*rand(10)
29+
x = rand(10)
30+
y = p.*x
31+
@test y[1:5] == p.x[1] .* x[1:5]
32+
@test y[6:10] == p.x[2] .* x[6:10]
33+
y = p.*x'
34+
for i in 1:10
35+
@test y[1:5,i] == p.x[1] .* x[i]
36+
@test y[6:10,i] == p.x[2] .* x[i]
37+
end
38+
y = p .* p'
39+
@test y[1:5,1:5] == p.x[1] .* p.x[1]'
40+
@test y[6:10,6:10] == p.x[2] .* p.x[2]'
41+
@test y[1:5,6:10] == p.x[1] .* p.x[2]'
42+
@test y[6:10,1:5] == p.x[2] .* p.x[1]'
43+
3044
b = rand(10)
3145
c = rand(10)
3246
copyto!(b,p)
@@ -94,13 +108,11 @@ S = [
94108
]
95109

96110
for sizes in S
97-
x = ArrayPartition( randn.(sizes[1]) )
111+
x = ArrayPartition( randn.(sizes[1]) )
98112
y = ArrayPartition( zeros.(sizes[2]) )
99113
y_array = zeros(length(x))
100114
copyto!(y,x) #testing Base.copyto!(dest::ArrayPartition,A::ArrayPartition)
101115
copyto!(y_array,x) #testing Base.copyto!(dest::Array,A::ArrayPartition)
102116
@test all([x[i] == y[i] for i in eachindex(x)])
103117
@test all([x[i] == y_array[i] for i in eachindex(x)])
104118
end
105-
106-

test/upstream.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@ u0 = ArrayPartition([1.0,0.0],[0.0])
88
@test ArrayInterface.zeromatrix(u0) isa Matrix
99
tspan = (0.0,100.0)
1010
prob = ODEProblem(lorenz,u0,tspan)
11-
sol = solve(prob,Tsit5());
12-
sol = solve(prob,AutoTsit5(Rosenbrock23(autodiff=false)));
13-
@test_broken sol = solve(prob,AutoTsit5(Rosenbrock23()));
11+
sol = solve(prob,Tsit5())
12+
sol = solve(prob,AutoTsit5(Rosenbrock23(autodiff=false)))
13+
sol = solve(prob,AutoTsit5(Rosenbrock23()))
1414

1515
function f!(F, vars)
1616
x = vars.x[1]
@@ -24,7 +24,7 @@ end
2424
# To show that the function works
2525
F = ArrayPartition([0.0 0.0],[0.0, 0.0])
2626
u0= ArrayPartition([0.1; 1.2], [0.1; 1.2])
27-
result = mymodel(F, u0)
27+
result = f!(F, u0)
2828

2929
# To show the NLsolve error that results with ArrayPartitions:
3030
nlsolve(f!, ArrayPartition([0.1; 1.2], [0.1; 1.2]))

0 commit comments

Comments
 (0)