Skip to content

Conversation

thchr
Copy link
Contributor

@thchr thchr commented Feb 16, 2022

While cat(A, B; dims=2) does not infer (cf. #5339), cat(A, B, dims=Val(2)) does. The latter version is preferable when the concatenation dimension is known statically.

The generalization of this for concatenation along multiple dimensions, e.g. cat(A, B, dims=(1,2)) to cat(A, B, dims=Val((1,2)), doesn't work currently due to a missing method of Base.dims2cat for Val{(i,j,...)}. This PR adds that and makes cat(A, B, dims=Val((1,2,...))) work.

On this PR:

julia> @code_warntype cat(rand(2,2), rand(2,2), dims=Val((1,2)))
MethodInstance for (::Base.var"#cat##kw")(::NamedTuple{(:dims,), Tuple{Val{(1, 2)}}}, ::typeof(cat), ::Matrix{Float64}, ::Matrix{Float64})
  from (::Base.var"#cat##kw")(::Any, ::typeof(cat), A...) in Base at abstractarray.jl:1861
Arguments
  _::Core.Const(Base.var"#cat##kw"())
  @_2::Core.Const((dims = Val{(1, 2)}(),))
  @_3::Core.Const(cat)
  A::Tuple{Matrix{Float64}, Matrix{Float64}}
Locals
  dims::Val{(1, 2)}
  @_6::Val{(1, 2)}
Body::Matrix{Float64}
[... omitted]

Cc. @mcabbott who suggested the dims=Val((1,2)) form on Slack.

@JeffBezanson
Copy link
Member

Looks like a bug fix, so that's all good, but I've long wanted cat(A, B, dims=2) to be inferrable --- we should look into why constant prop isn't working, or how the code can be rewritten, or whatever it would take.

@JeffBezanson JeffBezanson added the backport 1.8 Change should be backported to release-1.8 label Feb 17, 2022
@JeffBezanson JeffBezanson merged commit 1ad2396 into JuliaLang:master Feb 17, 2022
KristofferC pushed a commit that referenced this pull request Feb 18, 2022
@martinholters
Copy link
Member

we should look into why constant prop isn't working, or how the code can be rewritten, or whatever it would take.

As a starting point:

--- a/stdlib/LinearAlgebra/src/special.jl
+++ b/stdlib/LinearAlgebra/src/special.jl
@@ -372,7 +372,7 @@ hcat(A::Vector...) = Base.typed_hcat(promote_eltype(A...), A...)
 hcat(A::_DenseConcatGroup...) = Base.typed_hcat(promote_eltype(A...), A...)
 hvcat(rows::Tuple{Vararg{Int}}, xs::_DenseConcatGroup...) = Base.typed_hvcat(promote_eltype(xs...), rows, xs...)
 # For performance, specially handle the case where the matrices/vectors have homogeneous eltype
-Base._cat(dims, xs::_TypedDenseConcatGroup{T}...) where {T} = Base.cat_t(T, xs...; dims=dims)
+Base.@constprop :aggressive Base._cat(dims, xs::_TypedDenseConcatGroup{T}...) where {T} = Base._cat_t(dims, T, xs...)
 vcat(A::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_vcat(T, A...)
 hcat(A::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_hcat(T, A...)
 hvcat(rows::Tuple{Vararg{Int}}, xs::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_hvcat(T, rows, xs...)

That gives me:

julia> (@code_typed (() -> cat(rand(2,2), rand(2,2), dims=1))())[2]
Matrix{Float64} (alias for Array{Float64, 2})

julia> (@code_typed (() -> cat(rand(2,2), rand(2,2), dims=2))())[2]
Matrix{Float64} (alias for Array{Float64, 2})

julia> (@code_typed (() -> cat(rand(2,2), rand(2,2), dims=3))())[2]
Array{Float64, 3}

julia> (@code_typed (() -> cat(rand(2,2), rand(2,2), dims=(1,2,4)))())[2]
Array{Float64, 4}

All of these are Any on master.

@KristofferC KristofferC removed the backport 1.8 Change should be backported to release-1.8 label Feb 24, 2022
staticfloat pushed a commit to JuliaCI/julia-buildkite-testing that referenced this pull request Mar 2, 2022
LilithHafner pushed a commit to LilithHafner/julia that referenced this pull request Mar 8, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants