Skip to content

Commit 056acaf

Browse files
committed
Added tests for custom type broadcasting
1 parent ec1e144 commit 056acaf

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

test/partitions_test.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,43 @@ end
136136
foo(xcde0, xce0)
137137
#@test 0 == @allocated foo(xcde0, xce0)
138138

139+
# Custom AbstractArray types broadcasting
140+
struct MyType{T} <: AbstractVector{T}
141+
data :: Vector{T}
142+
end
143+
Base.similar(A::MyType{T}) where {T} = MyType{T}(similar(A.data))
144+
Base.similar(A::MyType{T},::Type{S}) where {T,S} = MyType(similar(A.data,S))
145+
146+
Base.size(A::MyType) = size(A.data)
147+
Base.getindex(A::MyType, i::Int) = getindex(A.data,i)
148+
Base.setindex!(A::MyType, v, i::Int) = setindex!(A.data,v,i)
149+
Base.IndexStyle(::MyType) = IndexLinear()
150+
151+
Base.BroadcastStyle(::Type{<:MyType}) = Broadcast.ArrayStyle{MyType}()
152+
153+
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyType}},::Type{T}) where {T}
154+
similar(find_mt(bc),T)
155+
end
156+
157+
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyType}})
158+
similar(find_mt(bc))
159+
end
160+
161+
find_mt(bc::Base.Broadcast.Broadcasted) = find_mt(bc.args)
162+
find_mt(args::Tuple) = find_mt(find_mt(args[1]), Base.tail(args))
163+
find_mt(x) = x
164+
find_mt(::Tuple{}) = nothing
165+
find_mt(a::MyType, rest) = a
166+
find_mt(::Any, rest) = find_mt(rest)
167+
168+
ap = ArrayPartition(MyType(ones(10)),collect(1:2))
169+
up = ap .+ 1
170+
@test typeof(ap) == typeof(up)
171+
172+
up = 2 .* ap .+ 1
173+
@test typeof(ap) == typeof(up)
174+
175+
139176
@testset "ArrayInterface.ismutable(ArrayPartition($a, $b)) == $r" for (a, b, r) in ((1,2, false), ([1], 2, false), ([1], [2], true))
140177
@test ArrayInterface.ismutable(ArrayPartition(a, b)) == r
141178
end

0 commit comments

Comments
 (0)