diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index 6f205427d..397558047 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -792,6 +792,7 @@ include("dimensions.jl") include("axes.jl") include("size.jl") include("stridelayout.jl") +include("broadcast.jl") abstract type AbstractArray2{T,N} <: AbstractArray{T,N} end diff --git a/src/broadcast.jl b/src/broadcast.jl new file mode 100644 index 000000000..91c20e342 --- /dev/null +++ b/src/broadcast.jl @@ -0,0 +1,61 @@ + +""" + BroadcastAxis + +An abstract trait that is used to determine how axes are combined when calling `broadcast_axis`. +""" +abstract type BroadcastAxis end + +struct BroadcastAxisDefault <: BroadcastAxis end + +BroadcastAxis(x) = BroadcastAxis(typeof(x)) +BroadcastAxis(::Type{T}) where {T} = BroadcastAxisDefault() + +""" + broadcast_axis(x, y) + +Broadcast axis `x` and `y` into a common space. The resulting axis should be equal in length +to both `x` and `y` unless one has a length of `1`, in which case the longest axis will be +equal to the output. + +```julia +julia> ArrayInterface.broadcast_axis(1:10, 1:10) + +julia> ArrayInterface.broadcast_axis(1:10, 1) +1:10 + +``` +""" +broadcast_axis(x, y) = broadcast_axis(BroadcastAxis(x), x, y) +# stagger default broadcasting in case y has something other than default +broadcast_axis(::BroadcastAxisDefault, x, y) = _broadcast_axis(BroadcastAxis(y), x, y) +function _broadcast_axis(::BroadcastAxisDefault, x, y) + return One():_combine_length(static_length(x), static_length(y)) +end +_broadcast_axis(s::BroadcastAxis, x, y) = broadcast_axis(s, x, y) + +# we can use a similar trick as we do with `indices` where unequal sizes error and we just +# keep the static value. However, axes can be unequal if one of them is `1` so we have to +# fall back to dynamic values in those cases +_combine_length(x::StaticInt{X}, y::StaticInt{Y}) where {X,Y} = static(_combine_length(X, Y)) +_combine_length(x::StaticInt{X}, ::Int) where {X} = x +_combine_length(x::StaticInt{1}, y::Int) = y +_combine_length(x::StaticInt{1}, y::StaticInt{1}) = y +_combine_length(x::Int, y::StaticInt{Y}) where {Y} = y +_combine_length(x::Int, y::StaticInt{1}) = x +@inline function _combine_length(x::Int, y::Int) + if x === y + return x + elseif y === 1 + return x + elseif x === 1 + return y + else + _dimerr(x, y) + end +end + +function _dimerr(@nospecialize(x), @nospecialize(y)) + throw(DimensionMismatch("axes could not be broadcast to a common size; " * + "got axes with lengths $(x) and $(y)")) +end diff --git a/test/broadcast.jl b/test/broadcast.jl new file mode 100644 index 000000000..a82373c06 --- /dev/null +++ b/test/broadcast.jl @@ -0,0 +1,30 @@ + +s5 = static(1):static(5) +s4 = static(1):static(4) +s1 = static(1):static(1) +d5 = static(1):5 +d4 = static(1):static(4) +d1 = static(1):static(1) + +struct DummyBroadcast <: ArrayInterface.BroadcastAxis end + +struct DummyAxis end + +ArrayInterface.BroadcastAxis(::Type{DummyAxis}) = DummyBroadcast() + +ArrayInterface.broadcast_axis(::DummyBroadcast, x, y) = y + +@inferred(ArrayInterface.broadcast_axis(s1, s1)) === s1 +@inferred(ArrayInterface.broadcast_axis(s5, s5)) === s5 +@inferred(ArrayInterface.broadcast_axis(s5, s1)) === s5 +@inferred(ArrayInterface.broadcast_axis(s1, s5)) === s5 +@inferred(ArrayInterface.broadcast_axis(s5, d5)) === s5 +@inferred(ArrayInterface.broadcast_axis(d5, s5)) === s5 +@inferred(ArrayInterface.broadcast_axis(d5, d1)) === d5 +@inferred(ArrayInterface.broadcast_axis(d1, d5)) === d5 +@inferred(ArrayInterface.broadcast_axis(s1, d5)) === d5 +@inferred(ArrayInterface.broadcast_axis(d5, s1)) === d5 +@inferred(ArrayInterface.broadcast_axis(s5, DummyAxis())) === s5 + +@test_throws DimensionMismatch ArrayInterface.broadcast_axis(s5, s4) + diff --git a/test/runtests.jl b/test/runtests.jl index d13544b93..885425a59 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -716,3 +716,8 @@ end include("indexing.jl") include("dimensions.jl") +@testset "broadcast" begin + include("broadcast.jl") +end + +