Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 78 additions & 33 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2136,6 +2136,7 @@ _hvncat(dimsshape::Union{Tuple, Int}, row_first::Bool, xs::Number...) = _typed_h
_hvncat(dimsshape::Union{Tuple, Int}, row_first::Bool, xs::AbstractArray...) = _typed_hvncat(promote_eltype(xs...), dimsshape, row_first, xs...)
_hvncat(dimsshape::Union{Tuple, Int}, row_first::Bool, xs::AbstractArray{T}...) where T = _typed_hvncat(T, dimsshape, row_first, xs...)


typed_hvncat(T::Type, dimsshape::Tuple, row_first::Bool, xs...) = _typed_hvncat(T, dimsshape, row_first, xs...)
typed_hvncat(T::Type, dim::Int, xs...) = _typed_hvncat(T, Val(dim), xs...)

Expand All @@ -2152,9 +2153,9 @@ _typed_hvncat(::Type, ::Val{0}, ::AbstractArray...) = _typed_hvncat_0d_only_one(
_typed_hvncat_0d_only_one() =
throw(ArgumentError("a 0-dimensional array may only contain exactly one element"))

_typed_hvncat(::Type{T}, ::Val{N}) where {T, N} = Array{T, N}(undef, ntuple(x -> 0, Val(N)))

function _typed_hvncat(::Type{T}, dims::Tuple{Vararg{Int, N}}, row_first::Bool, xs::Number...) where {T, N}
function _typed_hvncat(::Type{T}, dims::NTuple{N, Int}, row_first::Bool, xs::Number...) where {T, N}
all(>(0), dims) ||
throw(ArgumentError("`dims` argument must contain positive integers"))
A = Array{T, N}(undef, dims...)
lengtha = length(A) # Necessary to store result because throw blocks are being deoptimized right now, which leads to excessive allocations
lengthx = length(xs) # Cuts from 3 allocations to 1.
Expand Down Expand Up @@ -2191,9 +2192,28 @@ function hvncat_fill!(A::Array, row_first::Bool, xs::Tuple)
end

_typed_hvncat(T::Type, dim::Int, ::Bool, xs...) = _typed_hvncat(T, Val(dim), xs...) # catches from _hvncat type promoters

function _typed_hvncat(::Type{T}, ::Val{N}) where {T, N}
N < 0 &&
throw(ArgumentError("concatenation dimension must be nonnegative"))
return Array{T, N}(undef, ntuple(x -> 0, Val(N)))
end

function _typed_hvncat(T::Type, ::Val{N}, xs::Number...) where N
N < 0 &&
throw(ArgumentError("concatenation dimension must be nonnegative"))
A = cat_similar(xs[1], T, (ntuple(x -> 1, Val(N - 1))..., length(xs)))
hvncat_fill!(A, false, xs)
return A
end

function _typed_hvncat(::Type{T}, ::Val{N}, as::AbstractArray...) where {T, N}
# optimization for arrays that can be concatenated by copying them linearly into the destination
# conditions: the elements must all have 1- or 0-length dimensions above N
# conditions: the elements must all have 1-length dimensions above N
length(as) > 0 ||
throw(ArgumentError("must have at least one element"))
N < 0 &&
throw(ArgumentError("concatenation dimension must be nonnegative"))
for a ∈ as
ndims(a) <= N || all(x -> size(a, x) == 1, (N + 1):ndims(a)) ||
return _typed_hvncat(T, (ntuple(x -> 1, N - 1)..., length(as), 1), false, as...)
Expand All @@ -2203,10 +2223,13 @@ function _typed_hvncat(::Type{T}, ::Val{N}, as::AbstractArray...) where {T, N}
nd = max(N, ndims(as[1]))

Ndim = 0
for i ∈ 1:lastindex(as)
Ndim += cat_size(as[i], N)
for d ∈ 1:N - 1
cat_size(as[1], d) == cat_size(as[i], d) || throw(ArgumentError("mismatched size along axis $d in element $i"))
for i ∈ eachindex(as)
a = as[i]
Ndim += size(a, N)
nd = max(nd, ndims(a))
for d ∈ 1:N-1
size(a, d) == size(as[1], d) ||
throw(ArgumentError("all dimensions of element $i other than $N must be of length 1"))
end
end

Expand All @@ -2222,17 +2245,20 @@ function _typed_hvncat(::Type{T}, ::Val{N}, as::AbstractArray...) where {T, N}
end

function _typed_hvncat(::Type{T}, ::Val{N}, as...) where {T, N}
# optimization for scalars and 1-length arrays that can be concatenated by copying them linearly
# into the destination
length(as) > 0 ||
throw(ArgumentError("must have at least one element"))
N < 0 &&
throw(ArgumentError("concatenation dimension must be nonnegative"))
nd = N
Ndim = 0
for a ∈ as
if a isa AbstractArray
cat_size(a, N) == length(a) ||
throw(ArgumentError("all dimensions of elements other than $N must be of length 1"))
nd = max(nd, cat_ndims(a))
end
for i ∈ eachindex(as)
a = as[i]
Ndim += cat_size(a, N)
nd = max(nd, cat_ndims(a))
for d ∈ 1:N-1
cat_size(a, d) == 1 ||
throw(ArgumentError("all dimensions of element $i other than $N must be of length 1"))
end
end

A = Array{T, nd}(undef, ntuple(x -> 1, N - 1)..., Ndim, ntuple(x -> 1, nd - N)...)
Expand Down Expand Up @@ -2276,7 +2302,12 @@ function _typed_hvncat_1d(::Type{T}, ds::Int, ::Val{row_first}, as...) where {T,
end
end

function _typed_hvncat(::Type{T}, dims::Tuple{Vararg{Int, N}}, row_first::Bool, as...) where {T, N}
function _typed_hvncat(::Type{T}, dims::NTuple{N, Int}, row_first::Bool, as...) where {T, N}
length(as) > 0 ||
throw(ArgumentError("must have at least one element"))
all(>(0), dims) ||
throw(ArgumentError("`dims` argument must contain positive integers"))

d1 = row_first ? 2 : 1
d2 = row_first ? 1 : 2

Expand All @@ -2291,7 +2322,9 @@ function _typed_hvncat(::Type{T}, dims::Tuple{Vararg{Int, N}}, row_first::Bool,

currentdims = zeros(Int, nd)
blockcount = 0
elementcount = 0
for i ∈ eachindex(as)
elementcount += cat_length(as[i])
currentdims[d1] += cat_size(as[i], d1)
if currentdims[d1] == outdims[d1]
currentdims[d1] = 0
Expand Down Expand Up @@ -2321,14 +2354,9 @@ function _typed_hvncat(::Type{T}, dims::Tuple{Vararg{Int, N}}, row_first::Bool,
end
end

# calling sum() leads to 3 extra allocations
len = 0
for a ∈ as
len += cat_length(a)
end
outlen = prod(outdims)
outlen == 0 && throw(ArgumentError("too few elements in arguments, unable to infer dimensions"))
len == outlen || throw(ArgumentError("too many elements in arguments; expected $(outlen), got $(len)"))
elementcount == outlen ||
throw(ArgumentError("mismatched number of elements; expected $(outlen), got $(elementcount)"))

# copy into final array
A = cat_similar(as[1], T, outdims)
Expand All @@ -2347,22 +2375,32 @@ function _typed_hvncat(T::Type, shape::Tuple{Tuple}, row_first::Bool, xs...)
return _typed_hvncat_1d(T, shape[1][1], Val(row_first), xs...)
end

function _typed_hvncat(T::Type, shape::NTuple{N, Tuple}, row_first::Bool, as...) where {N}
function _typed_hvncat(::Type{T}, shape::NTuple{N, Tuple}, row_first::Bool, as...) where {T, N}
length(as) > 0 ||
throw(ArgumentError("must have at least one element"))
all(>(0), tuple((shape...)...)) ||
throw(ArgumentError("`shape` argument must consist of positive integers"))

d1 = row_first ? 2 : 1
d2 = row_first ? 1 : 2
shape = collect(shape) # saves allocations later
shapelength = shape[end][1]
shapev = collect(shape) # saves allocations later
all(!isempty, shapev) ||
throw(ArgumentError("each level of `shape` argument must have at least one value"))
length(shapev[end]) == 1 ||
throw(ArgumentError("last level of shape must contain only one integer"))
shapelength = shapev[end][1]
lengthas = length(as)
shapelength == lengthas || throw(ArgumentError("number of elements does not match shape; expected $(shapelength), got $lengthas)"))

# discover dimensions
nd = max(N, cat_ndims(as[1]))
outdims = zeros(Int, nd)
currentdims = zeros(Int, nd)
blockcounts = zeros(Int, nd)
shapepos = ones(Int, nd)

elementcount = 0
for i ∈ eachindex(as)
elementcount += cat_length(as[i])
wasstartblock = false
for d ∈ 1:N
ad = (d < 3 && row_first) ? (d == 1 ? 2 : 1) : d
Expand All @@ -2372,27 +2410,34 @@ function _typed_hvncat(T::Type, shape::NTuple{N, Tuple}, row_first::Bool, as...)
if d == 1 || i == 1 || wasstartblock
currentdims[d] += dsize
elseif dsize != cat_size(as[i - 1], ad)
throw(ArgumentError("""argument $i has a mismatched number of elements along axis $ad; \
expected $(cat_size(as[i - 1], ad)), got $dsize"""))
throw(ArgumentError("argument $i has a mismatched number of elements along axis $ad; \
expected $(cat_size(as[i - 1], ad)), got $dsize"))
end

wasstartblock = blockcounts[d] == 1 # remember for next dimension

isendblock = blockcounts[d] == shape[d][shapepos[d]]
isendblock = blockcounts[d] == shapev[d][shapepos[d]]
if isendblock
if outdims[d] == 0
outdims[d] = currentdims[d]
elseif outdims[d] != currentdims[d]
throw(ArgumentError("""argument $i has a mismatched number of elements along axis $ad; \
expected $(abs(outdims[d] - (currentdims[d] - dsize))), got $dsize"""))
throw(ArgumentError("argument $i has a mismatched number of elements along axis $ad; \
expected $(abs(outdims[d] - (currentdims[d] - dsize))), got $dsize"))
end
currentdims[d] = 0
blockcounts[d] = 0
shapepos[d] += 1
d > 1 && (blockcounts[d - 1] == 0 ||
throw(ArgumentError("shape in level $d is inconsistent; level counts must nest \
evenly into each other")))
end
end
end

outlen = prod(outdims)
elementcount == outlen ||
throw(ArgumentError("mismatched number of elements; expected $(outlen), got $(elementcount)"))

if row_first
outdims[1], outdims[2] = outdims[2], outdims[1]
end
Expand Down
63 changes: 63 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1390,6 +1390,69 @@ using Base: typed_hvncat
@test [v v;;; fill(v, 1, 2)] == fill(v, 1, 2, 2)
end

# dims form
for v ∈ ((), (1,), ([1],), (1, [1]), ([1], 1), ([1], [1]))
# reject dimension < 0
@test_throws ArgumentError hvncat(-1, v...)

# reject shape tuple with no elements
@test_throws ArgumentError hvncat(((),), true, v...)
end

# reject dims or shape with negative or zero values
for v1 ∈ (-1, 0, 1)
for v2 ∈ (-1, 0, 1)
v1 == v2 == 1 && continue
for v3 ∈ ((), (1,), ([1],), (1, [1]), ([1], 1), ([1], [1]))
@test_throws ArgumentError hvncat((v1, v2), true, v3...)
@test_throws ArgumentError hvncat(((v1,), (v2,)), true, v3...)
end
end
end

for v ∈ ((1, [1]), ([1], 1), ([1], [1]))
# reject shape with more than one end value
@test_throws ArgumentError hvncat(((1, 1),), true, v...)
end

for v ∈ ((1, 2, 3), (1, 2, [3]), ([1], [2], [3]))
# reject shape with more values in later level
@test_throws ArgumentError hvncat(((2, 1), (1, 1, 1)), true, v...)
end

# reject shapes that don't nest evenly between levels (e.g. 1 + 2 does not fit into 2)
@test_throws ArgumentError hvncat(((1, 2, 1), (2, 2), (4,)), true, [1 2], [3], [4], [1 2; 3 4])

# zero-length arrays are handled appropriately
@test [zeros(Int, 1, 2, 0) ;;; 1 3] == [1 3;;;]
@test [[] ;;; [] ;;; []] == Array{Any}(undef, 0, 1, 3)
@test [[] ; 1 ;;; 2 ; []] == [1 ;;; 2]
@test [[] ; [] ;;; [] ; []] == Array{Any}(undef, 0, 1, 2)
@test [[] ; 1 ;;; 2] == [1 ;;; 2]
@test [[] ; [] ;;; [] ;;; []] == Array{Any}(undef, 0, 1, 3)
z = zeros(Int, 0, 0, 0)
[z z ; z ;;; z ;;; z] == Array{Int}(undef, 0, 0, 0)

for v1 ∈ (zeros(Int, 0, 0), zeros(Int, 0, 0, 0, 0), zeros(Int, 0, 0, 0, 0, 0, 0, 0))
for v2 ∈ (1, [1])
for v3 ∈ (2, [2])
@test_throws ArgumentError [v1 ;;; v2]
@test_throws ArgumentError [v1 ;;; v2 v3]
@test_throws ArgumentError [v1 v1 ;;; v2 v3]
end
end
end
v1 = zeros(Int, 0, 0, 0)
for v2 ∈ (1, [1])
for v3 ∈ (2, [2])
# current behavior, not potentially dangerous.
# should throw error like above loop
@test [v1 ;;; v2 v3] == [v2 v3;;;]
@test_throws ArgumentError [v1 ;;; v2]
@test_throws ArgumentError [v1 v1 ;;; v2 v3]
end
end

# 0-dimension behaviors
# exactly one argument, placed in an array
# if already an array, copy, with type conversion as necessary
Expand Down