Skip to content

Commit c0c14b6

Browse files
committed
AbstractInterpreter: enable selective pure/concrete eval for external AbstractInterpreter with overlayed method table
Built on top of #44511 and #44561, and solves <JuliaGPU/GPUCompiler.jl#309>. This commit allows external `AbstractInterpreter` to selectively use pure/concrete evals even if it uses an overlayed method table. More specifically, such `AbstractInterpreter` can use pure/concrete evals as far as any callees used in a call in question doesn't come from the overlayed method table: ```julia @test Base.return_types((), MTOverlayInterp()) do isbitstype(Int) ? nothing : missing end == Any[Nothing] Base.@assume_effects :terminates_globally function issue41694(x) res = 1 1 < x < 20 || throw("bad") while x > 1 res *= x x -= 1 end return res end @test Base.return_types((), MTOverlayInterp()) do issue41694(3) == 6 ? nothing : missing end == Any[Nothing] ``` In order to check if a call is tainted by any overlayed call, our effect system now additionally tracks `overlayed::Bool` property. This effect property is required to prevents concrete-eval in the following kind of situation: ```julia strangesin(x) = sin(x) @overlay OverlayedMT strangesin(x::Float64) = iszero(x) ? nothing : cos(x) Base.@assume_effects :total totalcall(f, args...) = f(args...) @test Base.return_types(; interp=MTOverlayInterp()) do # we need to disable partial pure/concrete evaluation when tainted by any overlayed call if totalcall(strangesin, 1.0) == cos(1.0) return nothing else return missing end end |> only === Nothing ```
1 parent b2890d5 commit c0c14b6

File tree

8 files changed

+185
-98
lines changed

8 files changed

+185
-98
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 78 additions & 42 deletions
Large diffs are not rendered by default.

base/compiler/inferencestate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ mutable struct InferenceState
134134
#=parent=#nothing,
135135
#=cached=#cache === :global,
136136
#=inferred=#false, #=dont_work_on_me=#false,
137-
#=ipo_effects=#Effects(consistent, ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE, inbounds_taints_consistency),
137+
#=ipo_effects=#Effects(consistent, ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE, false, inbounds_taints_consistency),
138138
interp)
139139
result.result = frame
140140
cache !== :no && push!(get_inference_cache(interp), result)

base/compiler/methodtable.jl

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,18 @@ end
4040
getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch
4141

4242
"""
43-
findall(sig::Type, view::MethodTableView; limit::Int=typemax(Int)) -> MethodLookupResult or missing
43+
findall(sig::Type, view::MethodTableView; limit::Int=typemax(Int)) ->
44+
(matches::MethodLookupResult, overlayed::Bool) or missing
4445
45-
Find all methods in the given method table `view` that are applicable to the
46-
given signature `sig`. If no applicable methods are found, an empty result is
47-
returned. If the number of applicable methods exceeded the specified limit,
48-
`missing` is returned.
46+
Find all methods in the given method table `view` that are applicable to the given signature `sig`.
47+
If no applicable methods are found, an empty result is returned.
48+
If the number of applicable methods exceeded the specified limit, `missing` is returned.
49+
`overlayed` indicates if any matching method is defined in an overlayed method table.
4950
"""
5051
function findall(@nospecialize(sig::Type), table::InternalMethodTable; limit::Int=Int(typemax(Int32)))
51-
return _findall(sig, nothing, table.world, limit)
52+
result = _findall(sig, nothing, table.world, limit)
53+
result === missing && return missing
54+
return result, false
5255
end
5356

5457
function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int=Int(typemax(Int32)))
@@ -57,7 +60,7 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int
5760
nr = length(result)
5861
if nr 1 && result[nr].fully_covers
5962
# no need to fall back to the internal method table
60-
return result
63+
return result, true
6164
end
6265
# fall back to the internal method table
6366
fallback_result = _findall(sig, nothing, table.world, limit)
@@ -68,7 +71,7 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int
6871
WorldRange(
6972
max(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world),
7073
min(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)),
71-
result.ambig | fallback_result.ambig)
74+
result.ambig | fallback_result.ambig), !isempty(result)
7275
end
7376

7477
function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt, limit::Int)
@@ -83,31 +86,38 @@ function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable},
8386
end
8487

8588
"""
86-
findsup(sig::Type, view::MethodTableView) -> Tuple{MethodMatch, WorldRange} or nothing
87-
88-
Find the (unique) method `m` such that `sig <: m.sig`, while being more
89-
specific than any other method with the same property. In other words, find
90-
the method which is the least upper bound (supremum) under the specificity/subtype
91-
relation of the queried `signature`. If `sig` is concrete, this is equivalent to
92-
asking for the method that will be called given arguments whose types match the
93-
given signature. This query is also used to implement `invoke`.
94-
95-
Such a method `m` need not exist. It is possible that no method is an
96-
upper bound of `sig`, or it is possible that among the upper bounds, there
97-
is no least element. In both cases `nothing` is returned.
89+
findsup(sig::Type, view::MethodTableView) ->
90+
(match::MethodMatch, valid_worlds::WorldRange, overlayed::Bool) or nothing
91+
92+
Find the (unique) method such that `sig <: match.method.sig`, while being more
93+
specific than any other method with the same property. In other words, find the method
94+
which is the least upper bound (supremum) under the specificity/subtype relation of
95+
the queried `sig`nature. If `sig` is concrete, this is equivalent to asking for the method
96+
that will be called given arguments whose types match the given signature.
97+
Note that this query is also used to implement `invoke`.
98+
99+
Such a matching method `match` doesn't necessarily exist.
100+
It is possible that no method is an upper bound of `sig`, or
101+
it is possible that among the upper bounds, there is no least element.
102+
In both cases `nothing` is returned.
103+
104+
`overlayed` indicates if the matching method is defined in an overlayed method table.
98105
"""
99106
function findsup(@nospecialize(sig::Type), table::InternalMethodTable)
100-
return _findsup(sig, nothing, table.world)
107+
return (_findsup(sig, nothing, table.world)..., false)
101108
end
102109

103110
function findsup(@nospecialize(sig::Type), table::OverlayMethodTable)
104111
match, valid_worlds = _findsup(sig, table.mt, table.world)
105-
match !== nothing && return match, valid_worlds
112+
match !== nothing && return match, valid_worlds, true
106113
# fall back to the internal method table
107114
fallback_match, fallback_valid_worlds = _findsup(sig, nothing, table.world)
108-
return fallback_match, WorldRange(
109-
max(valid_worlds.min_world, fallback_valid_worlds.min_world),
110-
min(valid_worlds.max_world, fallback_valid_worlds.max_world))
115+
return (
116+
fallback_match,
117+
WorldRange(
118+
max(valid_worlds.min_world, fallback_valid_worlds.min_world),
119+
min(valid_worlds.max_world, fallback_valid_worlds.max_world)),
120+
false)
111121
end
112122

113123
function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt)

base/compiler/ssair/show.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,7 @@ function Base.show(io::IO, e::Core.Compiler.Effects)
803803
print(io, ',')
804804
printstyled(io, string(tristate_letter(e.terminates), 't'); color=tristate_color(e.terminates))
805805
print(io, ')')
806+
e.overlayed && printstyled(io, ''; color=:red)
806807
end
807808

808809
@specialize

base/compiler/tfuncs.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1789,11 +1789,11 @@ function builtin_effects(f::Builtin, argtypes::Vector{Any}, rt)
17891789
if (f === Core.getfield || f === Core.isdefined) && length(argtypes) >= 3
17901790
# consistent if the argtype is immutable
17911791
if isvarargtype(argtypes[2])
1792-
return Effects(Effects(), effect_free=ALWAYS_TRUE, terminates=ALWAYS_TRUE)
1792+
return Effects(; effect_free=ALWAYS_TRUE, terminates=ALWAYS_TRUE, overlayed=false)
17931793
end
17941794
s = widenconst(argtypes[2])
17951795
if isType(s) || !isa(s, DataType) || isabstracttype(s)
1796-
return Effects(Effects(), effect_free=ALWAYS_TRUE, terminates=ALWAYS_TRUE)
1796+
return Effects(; effect_free=ALWAYS_TRUE, terminates=ALWAYS_TRUE, overlayed=false)
17971797
end
17981798
s = s::DataType
17991799
ipo_consistent = !ismutabletype(s)
@@ -1826,7 +1826,9 @@ function builtin_effects(f::Builtin, argtypes::Vector{Any}, rt)
18261826
ipo_consistent ? ALWAYS_TRUE : ALWAYS_FALSE,
18271827
effect_free ? ALWAYS_TRUE : ALWAYS_FALSE,
18281828
nothrow ? ALWAYS_TRUE : TRISTATE_UNKNOWN,
1829-
ALWAYS_TRUE)
1829+
#=terminates=#ALWAYS_TRUE,
1830+
#=overlayed=#false,
1831+
)
18301832
end
18311833

18321834
function builtin_nothrow(@nospecialize(f), argtypes::Array{Any, 1}, @nospecialize(rt))
@@ -2007,7 +2009,9 @@ function intrinsic_effects(f::IntrinsicFunction, argtypes::Vector{Any})
20072009
ipo_consistent ? ALWAYS_TRUE : ALWAYS_FALSE,
20082010
effect_free ? ALWAYS_TRUE : ALWAYS_FALSE,
20092011
nothrow ? ALWAYS_TRUE : TRISTATE_UNKNOWN,
2010-
ALWAYS_TRUE)
2012+
#=terminates=#ALWAYS_TRUE,
2013+
#=overlayed=#false,
2014+
)
20112015
end
20122016

20132017
# TODO: this function is a very buggy and poor model of the return_type function

base/compiler/typeinfer.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ function rt_adjust_effects(@nospecialize(rt), ipo_effects::Effects)
431431
# but we don't currently model idempontency using dataflow, so we don't notice.
432432
# Fix that up here to improve precision.
433433
if !ipo_effects.inbounds_taints_consistency && rt === Union{}
434-
return Effects(ipo_effects, consistent=ALWAYS_TRUE)
434+
return Effects(ipo_effects; consistent=ALWAYS_TRUE)
435435
end
436436
return ipo_effects
437437
end
@@ -755,11 +755,11 @@ function merge_call_chain!(parent::InferenceState, ancestor::InferenceState, chi
755755
# and ensure that walking the parent list will get the same result (DAG) from everywhere
756756
# Also taint the termination effect, because we can no longer guarantee the absence
757757
# of recursion.
758-
tristate_merge!(parent, Effects(EFFECTS_TOTAL, terminates=TRISTATE_UNKNOWN))
758+
tristate_merge!(parent, Effects(EFFECTS_TOTAL; terminates=TRISTATE_UNKNOWN))
759759
while true
760760
add_cycle_backedge!(child, parent, parent.currpc)
761761
union_caller_cycle!(ancestor, child)
762-
tristate_merge!(child, Effects(EFFECTS_TOTAL, terminates=TRISTATE_UNKNOWN))
762+
tristate_merge!(child, Effects(EFFECTS_TOTAL; terminates=TRISTATE_UNKNOWN))
763763
child = parent
764764
child === ancestor && break
765765
parent = child.parent::InferenceState

base/compiler/types.jl

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ struct Effects
3838
effect_free::TriState
3939
nothrow::TriState
4040
terminates::TriState
41+
overlayed::Bool
4142
# This effect is currently only tracked in inference and modified
4243
# :consistent before caching. We may want to track it in the future.
4344
inbounds_taints_consistency::Bool
@@ -46,27 +47,33 @@ function Effects(
4647
consistent::TriState,
4748
effect_free::TriState,
4849
nothrow::TriState,
49-
terminates::TriState)
50+
terminates::TriState,
51+
overlayed::Bool)
5052
return Effects(
5153
consistent,
5254
effect_free,
5355
nothrow,
5456
terminates,
57+
overlayed,
5558
false)
5659
end
57-
Effects() = Effects(TRISTATE_UNKNOWN, TRISTATE_UNKNOWN, TRISTATE_UNKNOWN, TRISTATE_UNKNOWN)
5860

59-
function Effects(e::Effects;
61+
const EFFECTS_TOTAL = Effects(ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE, false)
62+
const EFFECTS_UNKNOWN = Effects(TRISTATE_UNKNOWN, TRISTATE_UNKNOWN, TRISTATE_UNKNOWN, TRISTATE_UNKNOWN, true)
63+
64+
function Effects(e::Effects = EFFECTS_UNKNOWN;
6065
consistent::TriState = e.consistent,
6166
effect_free::TriState = e.effect_free,
6267
nothrow::TriState = e.nothrow,
6368
terminates::TriState = e.terminates,
69+
overlayed::Bool = e.overlayed,
6470
inbounds_taints_consistency::Bool = e.inbounds_taints_consistency)
6571
return Effects(
6672
consistent,
6773
effect_free,
6874
nothrow,
6975
terminates,
76+
overlayed,
7077
inbounds_taints_consistency)
7178
end
7279

@@ -82,20 +89,20 @@ is_removable_if_unused(effects::Effects) =
8289
effects.terminates === ALWAYS_TRUE &&
8390
effects.nothrow === ALWAYS_TRUE
8491

85-
const EFFECTS_TOTAL = Effects(ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE)
86-
8792
function encode_effects(e::Effects)
88-
return e.consistent.state |
89-
(e.effect_free.state << 2) |
90-
(e.nothrow.state << 4) |
91-
(e.terminates.state << 6)
93+
return (e.consistent.state << 1) |
94+
(e.effect_free.state << 3) |
95+
(e.nothrow.state << 5) |
96+
(e.terminates.state << 7) |
97+
(e.overlayed)
9298
end
9399
function decode_effects(e::UInt8)
94100
return Effects(
95-
TriState(e & 0x3),
96-
TriState((e >> 2) & 0x3),
97-
TriState((e >> 4) & 0x3),
98-
TriState((e >> 6) & 0x3),
101+
TriState((e >> 1) & 0x03),
102+
TriState((e >> 3) & 0x03),
103+
TriState((e >> 5) & 0x03),
104+
TriState((e >> 7) & 0x03),
105+
e & 0x01 0x00,
99106
false)
100107
end
101108

@@ -109,6 +116,7 @@ function tristate_merge(old::Effects, new::Effects)
109116
old.nothrow, new.nothrow),
110117
tristate_merge(
111118
old.terminates, new.terminates),
119+
old.overlayed | new.overlayed,
112120
old.inbounds_taints_consistency | new.inbounds_taints_consistency)
113121
end
114122

@@ -158,7 +166,7 @@ mutable struct InferenceResult
158166
arginfo#=::Union{Nothing,Tuple{ArgInfo,InferenceState}}=# = nothing)
159167
argtypes, overridden_by_const = matching_cache_argtypes(linfo, arginfo)
160168
return new(linfo, argtypes, overridden_by_const, Any, nothing,
161-
WorldRange(), Effects(), Effects(), nothing)
169+
WorldRange(), Effects(; overlayed=false), Effects(; overlayed=false), nothing)
162170
end
163171
end
164172

test/compiler/AbstractInterpreter.jl

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,25 +41,53 @@ import Base.Experimental: @MethodTable, @overlay
4141
@MethodTable(OverlayedMT)
4242
CC.method_table(interp::MTOverlayInterp) = CC.OverlayMethodTable(CC.get_world_counter(interp), OverlayedMT)
4343

44-
@overlay OverlayedMT sin(x::Float64) = 1
45-
@test Base.return_types((Int,), MTOverlayInterp()) do x
46-
sin(x)
47-
end == Any[Int]
44+
strangesin(x) = sin(x)
45+
@overlay OverlayedMT strangesin(x::Float64) = iszero(x) ? nothing : cos(x)
46+
@test Base.return_types((Float64,), MTOverlayInterp()) do x
47+
strangesin(x)
48+
end |> only === Union{Float64,Nothing}
4849
@test Base.return_types((Any,), MTOverlayInterp()) do x
49-
Base.@invoke sin(x::Float64)
50-
end == Any[Int]
50+
Base.@invoke strangesin(x::Float64)
51+
end |> only === Union{Float64,Nothing}
5152

5253
# fallback to the internal method table
5354
@test Base.return_types((Int,), MTOverlayInterp()) do x
5455
cos(x)
55-
end == Any[Float64]
56+
end |> only === Float64
5657
@test Base.return_types((Any,), MTOverlayInterp()) do x
5758
Base.@invoke cos(x::Float64)
58-
end == Any[Float64]
59+
end |> only === Float64
5960

6061
# not fully covered overlay method match
6162
overlay_match(::Any) = nothing
6263
@overlay OverlayedMT overlay_match(::Int) = missing
6364
@test Base.return_types((Any,), MTOverlayInterp()) do x
6465
overlay_match(x)
65-
end == Any[Union{Nothing,Missing}]
66+
end |> only === Union{Nothing,Missing}
67+
68+
# partial pure/concrete evaluation
69+
@test Base.return_types((), MTOverlayInterp()) do
70+
isbitstype(Int) ? nothing : missing
71+
end |> only === Nothing
72+
Base.@assume_effects :terminates_globally function issue41694(x)
73+
res = 1
74+
1 < x < 20 || throw("bad")
75+
while x > 1
76+
res *= x
77+
x -= 1
78+
end
79+
return res
80+
end
81+
@test Base.return_types((), MTOverlayInterp()) do
82+
issue41694(3) == 6 ? nothing : missing
83+
end |> only === Nothing
84+
85+
# disable partial pure/concrete evaluation when tainted by any overlayed call
86+
Base.@assume_effects :total totalcall(f, args...) = f(args...)
87+
@test Base.return_types((), MTOverlayInterp()) do
88+
if totalcall(strangesin, 1.0) == cos(1.0)
89+
return nothing
90+
else
91+
return missing
92+
end
93+
end |> only === Nothing

0 commit comments

Comments
 (0)