diff --git a/src/array_partition.jl b/src/array_partition.jl index ed669f3c..85c9d792 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -103,22 +103,29 @@ end Base.mapreduce(f,op,A::ArrayPartition) = mapreduce(f,op,(mapreduce(f,op,x) for x in A.x)) Base.any(f,A::ArrayPartition) = any(f,(any(f,x) for x in A.x)) Base.any(f::Function,A::ArrayPartition) = any(f,(any(f,x) for x in A.x)) -function Base.copyto!(dest::Array,A::ArrayPartition) +function Base.copyto!(dest::AbstractArray,A::ArrayPartition) @assert length(dest) == length(A) cur = 1 @inbounds for i in 1:length(A.x) - dest[cur:(cur+length(A.x[i])-1)] .= A.x[i] - cur += length(A.x[i]) + dest[cur:(cur+length(A.x[i])-1)] .= vec(A.x[i]) + cur += length(A.x[i]) end dest end function Base.copyto!(A::ArrayPartition,src::ArrayPartition) @assert length(src) == length(A) - cur = 1 - @inbounds for i in 1:length(A.x) - A.x[i] .= @view(src[cur:(cur+length(A.x[i])-1)]) - cur += length(A.x[i]) + if size.(A.x) == size.(src.x) + A .= src + else + cnt = 0 + for i in eachindex(A.x) + x = A.x[i] + for k in eachindex(x) + cnt += 1 + x[k] = src[cnt] + end + end end A end diff --git a/test/partitions_test.jl b/test/partitions_test.jl index 9021fe6d..0f9f3ff3 100644 --- a/test/partitions_test.jl +++ b/test/partitions_test.jl @@ -85,3 +85,22 @@ _scalar_op(y) = y + 1 _broadcast_wrapper(y) = _scalar_op.(y) # Issue #8 # @inferred _broadcast_wrapper(x) + +#### testing copyto! +S = [ + ((1,),(2,)) => ((1,),(2,)), + ((3,2),(2,)) => ((3,2),(2,)), + ((3,2),(2,)) => ((3,),(3,),(2,)) + ] + +for sizes in S + x = ArrayPartition( randn.(sizes[1]) ) + y = ArrayPartition( zeros.(sizes[2]) ) + y_array = zeros(length(x)) + copyto!(y,x) #testing Base.copyto!(dest::ArrayPartition,A::ArrayPartition) + copyto!(y_array,x) #testing Base.copyto!(dest::Array,A::ArrayPartition) + @test all([x[i] == y[i] for i in eachindex(x)]) + @test all([x[i] == y_array[i] for i in eachindex(x)]) +end + +