Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 16 additions & 22 deletions src/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,37 +292,31 @@ append!(a::PyVector{T}, items) where {T} = PyVector{T}(append!(a.o, items))
#########################################################################
# Lists and 1d arrays.

if VERSION < v"1.1.0-DEV.392" # #29440
cirange(I,J) = CartesianIndices(map((i,j) -> i:j, Tuple(I), Tuple(J)))
else
cirange(I,J) = I:J
end

# recursive conversion of A to a list of list of lists... starting
# with dimension dim and index i in A.
function array2py(A::AbstractArray{T, N}, dim::Integer, i::Integer) where {T, N}
if dim > N
# with dimension dim and Cartesian index i in A.
function array2py(A::AbstractArray{<:Any, N}, dim::Integer, i::CartesianIndex{N}) where {N}
if dim > N # base case
return PyObject(A[i])
elseif dim == N # special case last dim to coarsen recursion leaves
len = size(A, dim)
s = N == 1 ? 1 : stride(A, dim)
o = PyObject(@pycheckn ccall((@pysym :PyList_New), PyPtr, (Int,), len))
for j = 0:len-1
oi = PyObject(A[i+j*s])
@pycheckz ccall((@pysym :PyList_SetItem), Cint, (PyPtr,Int,PyPtr),
o, j, oi)
pyincref(oi) # PyList_SetItem steals the reference
end
return o
else # dim < N: store multidimensional array as list of lists
len = size(A, dim)
s = stride(A, dim)
o = PyObject(@pycheckn ccall((@pysym :PyList_New), PyPtr, (Int,), len))
for j = 0:len-1
oi = array2py(A, dim+1, i+j*s)
else # recursively store multidimensional array as list of lists
ilast = CartesianIndex(ntuple(j -> j == dim ? lastindex(A, dim) : i[j], Val{N}()))
o = PyObject(@pycheckn ccall((@pysym :PyList_New), PyPtr, (Int,), size(A, dim)))
for icur in cirange(i,ilast)
oi = array2py(A, dim+1, icur)
@pycheckz ccall((@pysym :PyList_SetItem), Cint, (PyPtr,Int,PyPtr),
o, j, oi)
o, icur[dim]-i[dim], oi)
pyincref(oi) # PyList_SetItem steals the reference
end
return o
end
end

array2py(A::AbstractArray) = array2py(A, 1, 1)
array2py(A::AbstractArray) = array2py(A, 1, first(CartesianIndices(A)))

PyObject(A::AbstractArray) =
ndims(A) <= 1 || hasmethod(stride, Tuple{typeof(A),Int}) ? array2py(A) :
Expand Down
20 changes: 18 additions & 2 deletions src/numpy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,23 @@ function PyObject(a::StridedArray{T}) where T<:PYARR_TYPES
try
return NpyArray(a, false)
catch
array2py(a) # fallback to non-NumPy version
return array2py(a) # fallback to non-NumPy version
end
end

PyReverseDims(a::StridedArray{T}) where {T<:PYARR_TYPES} = NpyArray(a, true)
function PyReverseDims(a::StridedArray{T,N}) where {T<:PYARR_TYPES,N}
try
return NpyArray(a, true)
catch
return array2py(permutedims(a, N:-1:1)) # fallback to non-NumPy version
end
end
PyReverseDims(a::BitArray) = PyReverseDims(Array(a))

# fallback to physically transposing the array
PyReverseDims(a::AbstractArray{<:Any,N}) where {N} = PyObject(permutedims(a, N:-1:1))
PyReverseDims(a::AbstractMatrix) = PyObject(permutedims(a))

"""
PyReverseDims(array)

Expand All @@ -209,3 +219,9 @@ libraries that expect row-major data.
PyReverseDims(a::AbstractArray)

#########################################################################

# transposed arrays can be passed to NumPy without copying
PyObject(a::Union{LinearAlgebra.Adjoint{<:Real},LinearAlgebra.Transpose}) =
PyReverseDims(a.parent)

PyObject(a::LinearAlgebra.Adjoint) = PyObject(Matrix(a)) # non-real arrays require a copy
2 changes: 1 addition & 1 deletion src/pybuffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ function array_format(pybuf::PyBuffer)
use_native_sizes = false
elseif fmt_str[1] == '='
use_native_sizes = false
elseif fmt_str[1] == "Z"
elseif fmt_str[1] == 'Z'
type_start_idx = 1
else
error("Unsupported format string: \"$fmt_str\"")
Expand Down
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ const PyInt = pyversion < v"3" ? Int : Clonglong
@test roundtripeq(C_NULL) && roundtripeq(convert(Ptr{Cvoid}, 12345))
@test roundtripeq([1,3,4,5]) && roundtripeq([1,3.2,"hello",true])
@test roundtripeq([1 2 3;4 5 6]) && roundtripeq([1. 2 3;4 5 6])
@test roundtripeq([1. 2 3;4 5 6]')
@test roundtripeq([1.0+2im 2+3im 3;4 5 6]')
@test roundtripeq((1,(3.2,"hello"),true)) && roundtripeq(())
@test roundtripeq(Int32)
@test roundtripeq(Dict(1 => "hello", 2 => "goodbye")) && roundtripeq(Dict())
Expand Down Expand Up @@ -119,6 +121,7 @@ const PyInt = pyversion < v"3" ? Int : Clonglong
array2py2arrayeq(x) = PyCall.py2array(Float64,PyCall.array2py(x)) == x
@test array2py2arrayeq(rand(3))
@test array2py2arrayeq(rand(3,4))
@test array2py2arrayeq(rand(3,4)')
@test array2py2arrayeq(rand(3,4,5))

@test roundtripeq(2:10) && roundtripeq(10:-1:2)
Expand Down