Skip to content

Commit e4a2044

Browse files
Merge pull request #316 from AayushSabharwal/as/size
fix: add support for adjoint of AbstractVectorOfArray
2 parents 2b69321 + cc14f2c commit e4a2044

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

src/vector_of_array.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,10 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractVectorOfArray, _arg,
347347
end
348348
end
349349

350+
Base.@propagate_inbounds function Base.getindex(A::Adjoint{T,<:AbstractVectorOfArray}, idxs...) where {T}
351+
return getindex(A.parent, reverse(to_indices(A, idxs))...)
352+
end
353+
350354
function _observed(A::AbstractDiffEqArray{T, N}, sym, i::Int) where {T, N}
351355
observed(A, sym)(A.u[i], A.p, A.t[i])
352356
end
@@ -395,6 +399,9 @@ end
395399

396400
# Interface for the two-dimensional indexing, a more standard AbstractArray interface
397401
@inline Base.size(VA::AbstractVectorOfArray) = (size(VA.u[1])..., length(VA.u))
402+
@inline Base.size(VA::AbstractVectorOfArray, i) = size(VA)[i]
403+
@inline Base.size(A::Adjoint{T,<:AbstractVectorOfArray}) where {T} = reverse(size(A.parent))
404+
@inline Base.size(A::Adjoint{T,<:AbstractVectorOfArray}, i) where {T} = size(A)[i]
398405
Base.axes(VA::AbstractVectorOfArray) = Base.OneTo.(size(VA))
399406
Base.axes(VA::AbstractVectorOfArray, d::Int) = Base.OneTo(size(VA)[d])
400407

@@ -592,6 +599,7 @@ end
592599
@inline Statistics.var(VA::AbstractVectorOfArray; kwargs...) = var(Array(VA); kwargs...)
593600
@inline Statistics.cov(VA::AbstractVectorOfArray; kwargs...) = cov(Array(VA); kwargs...)
594601
@inline Statistics.cor(VA::AbstractVectorOfArray; kwargs...) = cor(Array(VA); kwargs...)
602+
@inline Base.adjoint(VA::AbstractVectorOfArray) = Adjoint(VA)
595603

596604
# make it show just like its data
597605
function Base.show(io::IO, m::MIME"text/plain", x::AbstractVectorOfArray)

test/linalg.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,10 @@ for T in (Array{Float64}, Array{ComplexF64})
4848
@test d.x[i] == b.x[i] * c.x[i]
4949
end
5050
end
51+
52+
va = VectorOfArray([i * ones(3) for i in 1:4])
53+
mat = Array(va)
54+
55+
@test size(va') == (size(va', 1), size(va', 2)) == (size(va, 2), size(va, 1))
56+
@test all(va'[i] == mat'[i] for i in eachindex(mat'))
57+
@test Array(va') == mat'

0 commit comments

Comments
 (0)