From ebacfa0e4feb78dadf4865ec68e67e680a8ae8d3 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Mon, 8 Mar 2021 12:50:53 -0500 Subject: [PATCH 1/4] Implement broadcast_axis --- src/broadcast.jl | 36 ++++++++++++++++++++++++++++++++++++ test/broadcast.jl | 13 +++++++++++++ 2 files changed, 49 insertions(+) create mode 100644 src/broadcast.jl create mode 100644 test/broadcast.jl diff --git a/src/broadcast.jl b/src/broadcast.jl new file mode 100644 index 000000000..5d6e4e501 --- /dev/null +++ b/src/broadcast.jl @@ -0,0 +1,36 @@ + +""" BroadcastAxis """ +abstract type BroadcastAxis end + +struct BroadcastAxisDefault <: BroadcastAxis end + +BroadcastAxis(x) = BroadcastAxis(typeof(x)) +BroadcastAxis(::Type{T}) where {T} = BroadcastAxisDefault() + +""" broadcast_axis(x, y) """ +broadcast_axis(x, y) = broadcast_axis(BroadcastAxis(x), x, y) +# stagger default broadcasting in case y has something other than default +broadcast_axis(::BroadcastAxisDefault, x, y) = _broadcast_axis(BroadcastAxis(y), x, y) +function _broadcast_axis(::BroadcastAxisDefault, x, y) + return One():_combine_length(static_length(x), static_length(y)) +end +_broadcast_axis(s::BroadcastAxis, x, y) = broadcasted_axis(s, x, y) +_combine_length(::StaticInt{X}, ::StaticInt{Y}) where {X,Y} = static(_combine_length(X, Y)) +_combine_length(::StaticInt{X}, y::Int) where {X} = _combine_length(X, y) +_combine_length(x::Int, ::StaticInt{Y}) where {Y} = _combine_length(x, Y) +@inline function _combine_length(x::Int, y::Int) + if x === y + return x + elseif y === 1 + return x + elseif x === 1 + return y + else + _dimerr(x, y) + end +end + +function _dimerr(@nospecialize(x), @nospecialize(y)) + throw(DimensionMismatch("axes could not be broadcast to a common size; " * + "got axes with lengths $(x) and $(y)")) +end diff --git a/test/broadcast.jl b/test/broadcast.jl new file mode 100644 index 000000000..eb97d0628 --- /dev/null +++ b/test/broadcast.jl @@ -0,0 +1,13 @@ + +s5 = static(1):static(5) +s4 = static(1):static(4) +s1 = static(1):static(1) +d5 = static(1):5 + +@inferred(ArrayInterface.broadcast_axis(s5, s5)) === s5 +@inferred(ArrayInterface.broadcast_axis(s5, s1)) === s5 +@inferred(ArrayInterface.broadcast_axis(s1, s5)) === s5 +@inferred(ArrayInterface.broadcast_axis(s5, d5)) === d5 + +@test_throws DimensionMismatch ArrayInterface.broadcast_axis(s5, s4) + From f7a8dc5d018702c5c1e02fd538cfe22ef1269888 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Mon, 8 Mar 2021 13:16:31 -0500 Subject: [PATCH 2/4] Tests --- src/ArrayInterface.jl | 1 + test/runtests.jl | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index 6f205427d..397558047 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -792,6 +792,7 @@ include("dimensions.jl") include("axes.jl") include("size.jl") include("stridelayout.jl") +include("broadcast.jl") abstract type AbstractArray2{T,N} <: AbstractArray{T,N} end diff --git a/test/runtests.jl b/test/runtests.jl index d13544b93..885425a59 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -716,3 +716,8 @@ end include("indexing.jl") include("dimensions.jl") +@testset "broadcast" begin + include("broadcast.jl") +end + + From 62b6f7a97ed660ecb2f784ce3fdec99048210121 Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Mon, 8 Mar 2021 19:52:27 -0500 Subject: [PATCH 3/4] docstrings + try to preserve static lengths --- src/broadcast.jl | 35 ++++++++++++++++++++++++++++++----- test/broadcast.jl | 18 +++++++++++++++++- 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/src/broadcast.jl b/src/broadcast.jl index 5d6e4e501..2258d5cf3 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -1,5 +1,9 @@ -""" BroadcastAxis """ +""" + BroadcastAxis + +An abstract trait that is used to determine how axes are combined when calling `broadcast_axis`. +""" abstract type BroadcastAxis end struct BroadcastAxisDefault <: BroadcastAxis end @@ -7,7 +11,21 @@ struct BroadcastAxisDefault <: BroadcastAxis end BroadcastAxis(x) = BroadcastAxis(typeof(x)) BroadcastAxis(::Type{T}) where {T} = BroadcastAxisDefault() -""" broadcast_axis(x, y) """ +""" + broadcast_axis(x, y) + +Broadcast axis `x` and `y` into a common space. The resulting axis should be equal in length +to both `x` and `y` unless one has a length of `1`, in which case the longest axis will be +equal to the output. + +```julia +julia> ArrayInterface.broadcast_axis(1:10, 1:10) + +julia> ArrayInterface.broadcast_axis(1:10, 1) +1:10 + +``` +""" broadcast_axis(x, y) = broadcast_axis(BroadcastAxis(x), x, y) # stagger default broadcasting in case y has something other than default broadcast_axis(::BroadcastAxisDefault, x, y) = _broadcast_axis(BroadcastAxis(y), x, y) @@ -15,9 +33,16 @@ function _broadcast_axis(::BroadcastAxisDefault, x, y) return One():_combine_length(static_length(x), static_length(y)) end _broadcast_axis(s::BroadcastAxis, x, y) = broadcasted_axis(s, x, y) -_combine_length(::StaticInt{X}, ::StaticInt{Y}) where {X,Y} = static(_combine_length(X, Y)) -_combine_length(::StaticInt{X}, y::Int) where {X} = _combine_length(X, y) -_combine_length(x::Int, ::StaticInt{Y}) where {Y} = _combine_length(x, Y) + +# we can use a similar trick as we do with `indices` where unequal sizes error and we just +# keep the static value. However, axes can be unequal if one of them is `1` so we have to +# fall back to dynamic values in those cases +_combine_length(x::StaticInt{X}, y::StaticInt{Y}) where {X,Y} = static(_combine_length(X, Y)) +_combine_length(x::StaticInt{X}, ::Int) where {X} = x +_combine_length(x::StaticInt{1}, y::Int) = y +_combine_length(x::StaticInt{1}, y::StaticInt{1}) = y +_combine_length(x::Int, y::StaticInt{Y}) where {Y} = y +_combine_length(x::Int, y::StaticInt{1}) = x @inline function _combine_length(x::Int, y::Int) if x === y return x diff --git a/test/broadcast.jl b/test/broadcast.jl index eb97d0628..9b27ce2c8 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -3,11 +3,27 @@ s5 = static(1):static(5) s4 = static(1):static(4) s1 = static(1):static(1) d5 = static(1):5 +d4 = static(1):static(4) +d1 = static(1):static(1) + +struct DummyBroadcast <: ArrayInterface.BroadcastAxis end + +struct DummyAxis end + +ArrayInterface.BroadcastAxis(::Type{DummyAxis}) = DummyBroadcast() + +ArrayInterface.broadcast_axis(::DummyBroadcast, x, y) = y @inferred(ArrayInterface.broadcast_axis(s5, s5)) === s5 @inferred(ArrayInterface.broadcast_axis(s5, s1)) === s5 @inferred(ArrayInterface.broadcast_axis(s1, s5)) === s5 -@inferred(ArrayInterface.broadcast_axis(s5, d5)) === d5 +@inferred(ArrayInterface.broadcast_axis(s5, d5)) === s5 +@inferred(ArrayInterface.broadcast_axis(d5, s5)) === s5 +@inferred(ArrayInterface.broadcast_axis(d5, d1)) === d5 +@inferred(ArrayInterface.broadcast_axis(d1, d5)) === d5 +@inferred(ArrayInterface.broadcast_axis(s1, d5)) === d5 +@inferred(ArrayInterface.broadcast_axis(d5, s1)) === d5 +@inferred(ArrayInterface.broadcast_axis(DummyAxis(), s5)) === s5 @test_throws DimensionMismatch ArrayInterface.broadcast_axis(s5, s4) From 94f7e3edea6cdef0dd9b40247f9e0d1a1d83ac3b Mon Sep 17 00:00:00 2001 From: "Zachary P. Christensen" Date: Mon, 8 Mar 2021 21:02:55 -0500 Subject: [PATCH 4/4] Test staggered check for non default style --- src/broadcast.jl | 2 +- test/broadcast.jl | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/broadcast.jl b/src/broadcast.jl index 2258d5cf3..91c20e342 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -32,7 +32,7 @@ broadcast_axis(::BroadcastAxisDefault, x, y) = _broadcast_axis(BroadcastAxis(y), function _broadcast_axis(::BroadcastAxisDefault, x, y) return One():_combine_length(static_length(x), static_length(y)) end -_broadcast_axis(s::BroadcastAxis, x, y) = broadcasted_axis(s, x, y) +_broadcast_axis(s::BroadcastAxis, x, y) = broadcast_axis(s, x, y) # we can use a similar trick as we do with `indices` where unequal sizes error and we just # keep the static value. However, axes can be unequal if one of them is `1` so we have to diff --git a/test/broadcast.jl b/test/broadcast.jl index 9b27ce2c8..a82373c06 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -14,6 +14,7 @@ ArrayInterface.BroadcastAxis(::Type{DummyAxis}) = DummyBroadcast() ArrayInterface.broadcast_axis(::DummyBroadcast, x, y) = y +@inferred(ArrayInterface.broadcast_axis(s1, s1)) === s1 @inferred(ArrayInterface.broadcast_axis(s5, s5)) === s5 @inferred(ArrayInterface.broadcast_axis(s5, s1)) === s5 @inferred(ArrayInterface.broadcast_axis(s1, s5)) === s5 @@ -23,7 +24,7 @@ ArrayInterface.broadcast_axis(::DummyBroadcast, x, y) = y @inferred(ArrayInterface.broadcast_axis(d1, d5)) === d5 @inferred(ArrayInterface.broadcast_axis(s1, d5)) === d5 @inferred(ArrayInterface.broadcast_axis(d5, s1)) === d5 -@inferred(ArrayInterface.broadcast_axis(DummyAxis(), s5)) === s5 +@inferred(ArrayInterface.broadcast_axis(s5, DummyAxis())) === s5 @test_throws DimensionMismatch ArrayInterface.broadcast_axis(s5, s4)