Skip to content

Preserve custom AbstractArray types during broadcasting of ArrayPartition #136

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 16, 2021

Conversation

jdeldre
Copy link
Contributor

@jdeldre jdeldre commented Apr 16, 2021

It would be useful to allow a custom AbstractArray type inside an ArrayPartition to be preserved during broadcasting, but currently it is not. Understandably, one needs to extend some of the broadcasting functions to the new type, as I show below. However, there is a small fix needed to the package to enable this. This is because unpack, which strips out each partition one by one and creates a separate Broadcasted wrapper for it, gives the same promoted BroadcastStyle to each one, even if they are different types, so some get dispatched to the wrong similar function in copy. In the PR, I don't pass the Style during the unpack to allow it to decide naturally which style it is.

If you'd like further detail, for example, consider

struct MyType{T} <: AbstractVector{T}
    data :: Vector{T}
end
Base.similar(A::MyType{T}) where {T} = MyType{T}(similar(A.data))
Base.similar(A::MyType{T},::Type{S}) where {T,S} = MyType(similar(A.data,S))

Base.size(A::MyType) = size(A.data)
Base.getindex(A::MyType, i::Int) = getindex(A.data,i)
Base.setindex!(A::MyType, v, i::Int) = setindex!(A.data,v,i)
Base.IndexStyle(::MyType) = IndexLinear()

and then I define

julia> ap = ArrayPartition(MyType(ones(10)),collect(1:2));

julia> typeof(ap)
ArrayPartition{Float64, Tuple{MyType{Float64}, Vector{Int64}}}

and calculate

up = ap .+ 1

The MyType wrapper is converted to just a Vector{Float64}

julia> typeof(up)
ArrayPartition{Float64, Tuple{Vector{Float64}, Vector{Int64}}}

No problem. We haven't extended similar for the Broadcasted wrapper to do anything for our new type, so it just falls back to the default. So we endow it with some BroadcastStyle functionality:

Base.BroadcastStyle(::Type{<:MyType}) = Broadcast.ArrayStyle{MyType}()

function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyType}},::Type{T}) where {T}
    similar(find_mt(bc),T)
end

function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{MyType}})
    similar(find_mt(bc))
end

find_mt(bc::Base.Broadcast.Broadcasted) = find_mt(bc.args)
find_mt(args::Tuple) = find_mt(find_mt(args[1]), Base.tail(args))
find_mt(x) = x
find_mt(::Tuple{}) = nothing
find_mt(a::MyType, rest) = a
find_mt(::Any, rest) = find_mt(rest)

Except that now,

julia> up = ap .+ 1

produces an error, because our new similar function gets called on the non MyType part of the ArrayPartition, too. The PR fixes this.

@ChrisRackauckas ChrisRackauckas merged commit c2b5da2 into SciML:master Apr 16, 2021
@ChrisRackauckas
Copy link
Member

This is great! Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants