Skip to content

Commit c9958b8

Browse files
Merge pull request #89 from SciML/allocs
fix broadcast allocations
2 parents 53bcecb + 160d10c commit c9958b8

File tree

3 files changed

+24
-8
lines changed

3 files changed

+24
-8
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@ version = "2.3.0"
66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
88
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
9-
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
10-
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
119
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1210
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1311
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
@@ -23,8 +21,10 @@ ZygoteRules = "0.2"
2321
julia = "1.3"
2422

2523
[extras]
24+
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
25+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
2626
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2727
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
2828

2929
[targets]
30-
test = ["Test", "Unitful"]
30+
test = ["NLsolve", "OrdinaryDiffEq", "Test", "Unitful"]

src/array_partition.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -230,15 +230,15 @@ ArrayPartitionStyle(::S, ::Val{N}) where {S,N} = ArrayPartitionStyle(S(Val(N)))
230230
ArrayPartitionStyle(::Val{N}) where N = ArrayPartitionStyle{Broadcast.DefaultArrayStyle{N}}()
231231

232232
# promotion rules
233-
function Broadcast.BroadcastStyle(::ArrayPartitionStyle{AStyle}, ::ArrayPartitionStyle{BStyle}) where {AStyle, BStyle}
233+
@inline function Broadcast.BroadcastStyle(::ArrayPartitionStyle{AStyle}, ::ArrayPartitionStyle{BStyle}) where {AStyle, BStyle}
234234
ArrayPartitionStyle(Broadcast.BroadcastStyle(AStyle(), BStyle()))
235235
end
236-
Broadcast.BroadcastStyle(::ArrayPartitionStyle, ::Broadcast.DefaultArrayStyle{0}) = Broadcast.DefaultArrayStyle{1}()
236+
Broadcast.BroadcastStyle(::ArrayPartitionStyle{Style}, ::Broadcast.DefaultArrayStyle{0}) where Style<:Broadcast.BroadcastStyle = ArrayPartitionStyle{Style}()
237237
Broadcast.BroadcastStyle(::ArrayPartitionStyle, ::Broadcast.DefaultArrayStyle{N}) where N = Broadcast.DefaultArrayStyle{N}()
238238

239239
combine_styles(args::Tuple{}) = Broadcast.DefaultArrayStyle{0}()
240-
combine_styles(args::Tuple{Any}) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]))
241-
combine_styles(args::Tuple{Any, Any}) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]), Broadcast.BroadcastStyle(args[2]))
240+
@inline combine_styles(args::Tuple{Any}) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]))
241+
@inline combine_styles(args::Tuple{Any, Any}) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]), Broadcast.BroadcastStyle(args[2]))
242242
@inline combine_styles(args::Tuple) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]), combine_styles(Base.tail(args)))
243243

244244
function Broadcast.BroadcastStyle(::Type{ArrayPartition{T,S}}) where {T, S}
@@ -256,7 +256,7 @@ end
256256

257257
@inline function Base.copyto!(dest::ArrayPartition, bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}) where Style
258258
N = npartitions(dest, bc)
259-
for i in 1:N
259+
@inbounds for i in 1:N
260260
copyto!(dest.x[i], unpack(bc, i))
261261
end
262262
dest

test/partitions_test.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,19 @@ for sizes in S
119119
@test all([x[i] == y[i] for i in eachindex(x)])
120120
@test all([x[i] == y_array[i] for i in eachindex(x)])
121121
end
122+
123+
# Non-allocating broadcast
124+
xce0 = ArrayPartition(zeros(2),[0.])
125+
xcde0 = copy(xce0)
126+
function foo(y, x)
127+
y .= y .+ x
128+
nothing
129+
end
130+
foo(xcde0, xce0)
131+
#@test 0 == @allocated foo(xcde0, xce0)
132+
function foo(y, x)
133+
y .= y .+ 2 .* x
134+
nothing
135+
end
136+
foo(xcde0, xce0)
137+
#@test 0 == @allocated foo(xcde0, xce0)

0 commit comments

Comments
 (0)