Poor inference of `[1, 2] + [3.0, missing]` #28382

@andyferris

I have been trying to use missing and fast Union types in my workflow and have run into some issues. The problem is most easily highlighted by:

julia> @code_warntype [1, 2] + [3.0, missing]
Body::Array
44 1 ── %1  = (Base.getfield)(Bs, 1, true)::Array{Union{Missing, Float64},1}                                                                               │╻╷╷         iterate
   └───       goto #7 if not true                                                                                                                          │           
   2 ┄─ %3  = φ (#1 => %1, #6 => %12)::Array{Union{Missing, Float64},1}                                                                                    │           
   │    %4  = φ (#1 => 2, #6 => %13)::Int64                                                                                                                │           
45 │          invoke Base.promote_shape(_2::Array{Int64,1}, %3::Array{Union{Missing, Float64},1})                                                          │           
   │    %6  = (Base.slt_int)(1, %4)::Bool                                                                                                                  ││╻           <
   └───       goto #4 if not %6                                                                                                                            ││          
   3 ──       goto #5                                                                                                                                      ││          
   4 ── %9  = (Base.getfield)(Bs, %4, true)::Array{Union{Missing, Float64},1}                                                                              ││╻           getindex
   │    %10 = (Base.add_int)(%4, 1)::Int64                                                                                                                 ││╻           +
   └───       goto #5                                                                                                                                      ││          
   5 ┄─ %12 = φ (#4 => %9)::Array{Union{Missing, Float64},1}                                                                                               │           
   │    %13 = φ (#4 => %10)::Int64                                                                                                                         │           
   │    %14 = φ (#3 => true, #4 => false)::Bool                                                                                                            │           
   │    %15 = (Base.not_int)(%14)::Bool                                                                                                                    │           
   └───       goto #7 if not %15                                                                                                                           │           
   6 ──       goto #2                                                                                                                                      │           
47 7 ── %18 = (getfield)(Bs, 1)::Array{Union{Missing, Float64},1}                                                                                          │           
   │    %19 = (Core.tuple)(A, %18)::Tuple{Array{Int64,1},Array{Union{Missing, Float64},1}}                                                                 ││╻           broadcasted
   │    %20 = (Base.arraysize)(A, 1)::Int64                                                                                                                │││╻╷╷╷╷       instantiate
   │    %21 = (Base.slt_int)(%20, 0)::Bool                                                                                                                 ││││╻╷╷╷        combine_axes
   │    %22 = (Base.ifelse)(%21, 0, %20)::Int64                                                                                                            │││││┃│││││      broadcast_axes
   │    %23 = %new(Base.OneTo{Int64}, %22)::Base.OneTo{Int64}                                                                                              ││││││┃│││        axes
   │    %24 = (Base.arraysize)(%18, 1)::Int64                                                                                                              ││││││╻╷╷         broadcast_axes
   │    %25 = (Base.slt_int)(%24, 0)::Bool                                                                                                                 │││││││╻╷╷╷        axes
   │    %26 = (Base.ifelse)(%25, 0, %24)::Int64                                                                                                            ││││││││┃│││        map
   │    %27 = %new(Base.OneTo{Int64}, %26)::Base.OneTo{Int64}                                                                                              │││││││││┃│          Type
   │    %28 = (%26 === %22)::Bool                                                                                                                          ││││││╻╷╷╷╷       _bcs
   │    %29 = (Base.and_int)(true, %28)::Bool                                                                                                              │││││││╻           _bcs1
   └───       goto #9 if not %29                                                                                                                           ││││││││┃           _bcsm
   8 ──       goto #10                                                                                                                                     │││││││││   
   9 ── %32 = (%22 === 1)::Bool                                                                                                                            │││││││││╻           ==
   └───       goto #10                                                                                                                                     │││││││││   
   10 ┄ %34 = φ (#8 => %29, #9 => %32)::Bool                                                                                                               ││││││││    
   └───       goto #12 if not %34                                                                                                                          ││││││││    
   11 ─       goto #18                                                                                                                                     ││││││││    
   12 ─ %37 = (%22 === %26)::Bool                                                                                                                          │││││││││╻╷          ==
   │    %38 = (Base.and_int)(true, %37)::Bool                                                                                                              ││││││││││╻           &
   └───       goto #14 if not %38                                                                                                                          │││││││││   
   13 ─       goto #15                                                                                                                                     │││││││││   
   14 ─ %41 = (%26 === 1)::Bool                                                                                                                            │││││││││╻           ==
   └───       goto #15                                                                                                                                     │││││││││   
   15 ┄ %43 = φ (#13 => %38, #14 => %41)::Bool                                                                                                             ││││││││    
   └───       goto #17 if not %43                                                                                                                          ││││││││    
   16 ─       goto #18                                                                                                                                     ││││││││    
   17 ─ %46 = %new(Base.DimensionMismatch, "arrays could not be broadcast to a common size")::DimensionMismatch                                            ││││││││╻           Type
   │          (Base.Broadcast.throw)(%46)                                                                                                                  ││││││││    
   └───       $(Expr(:unreachable))                                                                                                                        ││││││││    
   18 ┄ %49 = φ (#11 => %27, #16 => %23)::Base.OneTo{Int64}                                                                                                │││││││     
   │    %50 = (Core.tuple)(%49)::Tuple{Base.OneTo{Int64}}                                                                                                  │││││││     
   └───       goto #19                                                                                                                                     │││││││     
   19 ─       goto #20                                                                                                                                     ││││││      
   20 ─       goto #21                                                                                                                                     │││││       
   21 ─ %54 = %new(Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Tuple{Base.OneTo{Int64}},typeof(+),Tuple{Array{Int64,1},Array{Union{Missing, Float64},1}}}, +, %19, %50)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Tuple{Base.OneTo{Int64}},typeof(+),Tuple{Array{Int64,1},Array{Union{Missing, Float64},1}}}
   └───       goto #22                                                                                                                                     ││││        
   22 ─ %56 = invoke Base.Broadcast.copy(%54::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Tuple{Base.OneTo{Int64}},typeof(+),Tuple{Array{Int64,1},Array{Union{Missing, Float64},1}}})::Array
   └───       goto #23                                                                                                                                     │││         
   23 ─       goto #24                                                                                                                                     ││          
   24 ─       return %56                  

As you can see, the resultant output is (correctly but imprecisely) inferred to be Array (not even Vector?).
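For readers who want to check the inference result without reading the full @code_warntype listing, here is a minimal sketch. It uses Base.return_types, which is an unexported reflection helper, and the variable names are only illustrative. What it prints depends on the Julia version, so no output is shown.

```julia
a = [1, 2]
b = [3.0, missing]

# Ask the compiler which return type(s) it infers for +(a, b).
# Base.return_types returns one entry per matching method.
inferred = Base.return_types(+, (typeof(a), typeof(b)))
println("inferred: ", inferred)

# The concrete type actually produced at run time, for comparison.
println("run time: ", typeof(a + b))
```

If the inferred type is abstract (as with Array above) while the run-time type is concrete, any code that consumes the result in the same scope has to fall back on dynamic dispatch.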

I find this example interesting for two reasons:

  1. Performance (resulting from imprecise inference)
  2. Predictable interface provided by output container

Regarding performance, I find it quite usual to combine "functional" operations like +, map, etc. with hand-written loops in the same scope. In fact, predictably good performance across loops, recursion, higher-order programming, etc. is one of our main selling points. In the above, performance will be pessimised if I loop over the output, since neither the container type (even the array dimensionality is lost, somehow!) nor the element type is known to inference. I'm guessing, but I'd predict that Vector{Union{Missing, Float64}} would be fastest for subsequent iteration; it's possible that if inference found something like Vector{<:Union{Missing, Float64}} then this would also be quite fast (I'm not sure, really).
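One common way to limit the cost of a poorly inferred intermediate is a function barrier. The sketch below is only an illustration of that general technique, not a fix for the inference problem itself; sum_nonmissing and total are hypothetical names.

```julia
# Hypothetical kernel: compiled separately for each concrete type of `v`,
# so the loop body is specialized even if the caller could not infer `v`.
function sum_nonmissing(v::AbstractVector)
    s = 0.0
    for x in v
        if x !== missing
            s += x
        end
    end
    return s
end

# Hypothetical caller that mixes a "functional" operation with a loop.
function total(a, b)
    c = a + b                  # the inferred type of `c` may be imprecise
    return sum_nonmissing(c)   # one dynamic dispatch here, then a specialized loop
end

println(total([1, 2], [3.0, missing]))
```

The dispatch at the barrier still costs something, and it does nothing for code that loops over c directly in the same scope, which is the case described above.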

Regarding the second, more semantic issue of what the output actually becomes at run time, I'm a bit confused about what I'm meant to assume I can do with the output container. For these input types, we can get

julia> Int[] + Union{Float64, Missing}[]
0-element Array{Union{Missing, Float64},1}

julia> Int[1] + Union{Float64, Missing}[3.0]
1-element Array{Float64,1}:
 4.0

julia> Int[2] + Union{Float64, Missing}[missing]
1-element Array{Missing,1}:
 missing

julia> Int[1, 2] + Union{Float64, Missing}[3.0, missing]
2-element Array{Union{Missing, Float64},1}:
 4.0     
  missing
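The four cases above can be condensed into a loop that prints only the element types. This is a small sketch; the alias U and the inputs list are illustrative, and the element types it reports may differ between Julia versions.

```julia
U = Union{Missing, Float64}

# Same static input types every time; only the run-time values differ.
inputs = [
    (Int[],     U[]),
    (Int[1],    U[3.0]),
    (Int[2],    U[missing]),
    (Int[1, 2], U[3.0, missing]),
]

for (x, y) in inputs
    z = x + y
    println(length(z), " element(s): eltype(x + y) = ", eltype(z))
end
```

Every pair has the same argument types, so any difference in the printed element types comes purely from the values.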

As a programmer, this scares me somewhat because I might get subtly different program behavior (or program correctness) depending on my input data. My concern is that I might perform tests with mixtures of Missing and Float64 input values but be caught out in rare cases in production when all (or none) of the inputs are missing.

In this hypothetical scenario, my (tested, production) code might look something like

vec3 = vec1 + vec2
for i in keys(vec3)
    if isless(vec3[i], 0) # vec3 can't have any negative elements
        vec3[i] = missing
    end
end
return vec3

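But unfortunately this would sometimes fail, depending on my input data, which may vary in size and quality (maybe 1 time in 10,000 all the values are non-missing). When vec3 comes back as Array{Float64,1}, the assignment vec3[i] = missing throws, because missing cannot be converted to Float64.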
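One way to sidestep the problem in user code today is to allocate the output with the element type you intend to store and broadcast into it. This is a sketch under assumptions: add_then_mask is a hypothetical helper, not part of Base, and nonmissingtype is assumed to be available (it is exported from Base in Julia 1.3 and later).

```julia
# Hypothetical helper: the output eltype is fixed by the input eltypes,
# not by the values that happen to be present at run time.
function add_then_mask(vec1::AbstractVector, vec2::AbstractVector)
    T = Union{Missing,
              promote_type(nonmissingtype(eltype(vec1)),
                           nonmissingtype(eltype(vec2)))}
    vec3 = similar(vec1, T)    # same shape as vec1, eltype T
    vec3 .= vec1 .+ vec2       # in-place broadcast, no eltype narrowing
    for i in eachindex(vec3)
        if isless(vec3[i], 0)  # isless(missing, 0) is false
            vec3[i] = missing  # always allowed, since Missing <: T
        end
    end
    return vec3
end

U = Union{Missing, Float64}
println(add_then_mask(Int[1, -5], U[3.0, 2.0]))
```

Here the second element is negative and gets replaced by missing even though neither input contained a missing value, which is exactly the case where the original loop would throw.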

As a coder, I need to be able to predict the interface provided by the output of an operation so I can write reliable code. Generally, I've noted that the outputs of +, map, broadcast, filter (for Array inputs) tend to be mutable Arrays, and I am generally unafraid to mutate them. Naively, I would track in my head the eltypes of vec1 and vec2 and assume I can populate vec3 with anything consistent with them.

I'm not sure what the best solution is, but from my perspective any potential (performance?) gain from returning the narrowest container that fits the run-time data is of questionable value compared to being able to rely on semantic guarantees as a code author.


Labels: broadcast (Applying a function over a collection), compiler:inference (Type inference), missing data (Base.missing and related functionality)
