Skip to content

Commit 71ab5fa

Browse files
committed
ensure sparams are cached correctly for widened methods
Follow-up issue found while working on #47476
1 parent 16d3b92 commit 71ab5fa

File tree

9 files changed

+79
-50
lines changed

9 files changed

+79
-50
lines changed

base/compiler/typeinfer.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ function maybe_compress_codeinfo(interp::AbstractInterpreter, linfo::MethodInsta
347347
return ci
348348
end
349349
if may_discard_trees(interp)
350-
cache_the_tree = ci.inferred && (is_inlineable(ci) || isa_compileable_sig(linfo.specTypes, def))
350+
cache_the_tree = ci.inferred && (is_inlineable(ci) || isa_compileable_sig(linfo.specTypes, linfo.sparam_vals, def))
351351
else
352352
cache_the_tree = true
353353
end

base/compiler/utilities.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,8 @@ function get_compileable_sig(method::Method, @nospecialize(atype), sparams::Simp
152152
mt, atype, sparams, method)
153153
end
154154

155-
isa_compileable_sig(@nospecialize(atype), method::Method) =
156-
!iszero(ccall(:jl_isa_compileable_sig, Int32, (Any, Any), atype, method))
155+
isa_compileable_sig(@nospecialize(atype), sparams::SimpleVector, method::Method) =
156+
!iszero(ccall(:jl_isa_compileable_sig, Int32, (Any, Any, Any), atype, sparams, method))
157157

158158
# eliminate UnionAll vars that might be degenerate due to having identical bounds,
159159
# or a concrete upper bound and appearing covariantly.
@@ -200,7 +200,12 @@ function specialize_method(method::Method, @nospecialize(atype), sparams::Simple
200200
if compilesig
201201
new_atype = get_compileable_sig(method, atype, sparams)
202202
new_atype === nothing && return nothing
203-
atype = new_atype
203+
if atype !== new_atype
204+
sp_ = ccall(:jl_type_intersection_with_env, Any, (Any, Any), new_atype, method.sig)::SimpleVector
205+
if sparams === sp_[2]::SimpleVector
206+
atype = new_atype
207+
end
208+
end
204209
end
205210
if preexisting
206211
# check cached specializations

src/gf.c

Lines changed: 59 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -637,13 +637,14 @@ static void jl_compilation_sig(
637637
for (i = 0; i < np; i++) {
638638
jl_value_t *elt = jl_tparam(tt, i);
639639
jl_value_t *decl_i = jl_nth_slot_type(decl, i);
640+
jl_value_t *type_i = jl_rewrap_unionall(decl_i, decl);
640641
size_t i_arg = (i < nargs - 1 ? i : nargs - 1);
641642

642-
if (jl_is_kind(decl_i)) {
643+
if (jl_is_kind(type_i)) {
643644
// if we can prove the match was against the kind (not a Type)
644645
// we want to put that in the cache instead
645646
if (!*newparams) *newparams = jl_svec_copy(tt->parameters);
646-
elt = decl_i;
647+
elt = type_i;
647648
jl_svecset(*newparams, i, elt);
648649
}
649650
else if (jl_is_type_type(elt)) {
@@ -652,7 +653,7 @@ static void jl_compilation_sig(
652653
// and the result of matching the type signature
653654
// needs to be restricted to the concrete type 'kind'
654655
jl_value_t *kind = jl_typeof(jl_tparam0(elt));
655-
if (jl_subtype(kind, decl_i) && !jl_subtype((jl_value_t*)jl_type_type, decl_i)) {
656+
if (jl_subtype(kind, type_i) && !jl_subtype((jl_value_t*)jl_type_type, type_i)) {
656657
// if we can prove the match was against the kind (not a Type)
657658
// it's simpler (and thus better) to put that cache instead
658659
if (!*newparams) *newparams = jl_svec_copy(tt->parameters);
@@ -664,7 +665,7 @@ static void jl_compilation_sig(
664665
// not triggered for isdispatchtuple(tt), this attempts to handle
665666
// some cases of adapting a random signature into a compilation signature
666667
// if we get a kind, where we don't expect to accept one, widen it to something more expected (Type{T})
667-
if (!(jl_subtype(elt, decl_i) && !jl_subtype((jl_value_t*)jl_type_type, decl_i))) {
668+
if (!(jl_subtype(elt, type_i) && !jl_subtype((jl_value_t*)jl_type_type, type_i))) {
668669
if (!*newparams) *newparams = jl_svec_copy(tt->parameters);
669670
elt = (jl_value_t*)jl_type_type;
670671
jl_svecset(*newparams, i, elt);
@@ -703,7 +704,7 @@ static void jl_compilation_sig(
703704
jl_svecset(*newparams, i, jl_type_type);
704705
}
705706
else if (jl_is_type_type(elt)) { // elt isa Type{T}
706-
if (very_general_type(decl_i)) {
707+
if (!jl_has_free_typevars(decl_i) && very_general_type(type_i)) {
707708
/*
708709
Here's a fairly simple heuristic: if this argument slot's
709710
declared type is general (Type or Any),
@@ -742,15 +743,13 @@ static void jl_compilation_sig(
742743
*/
743744
if (!*newparams) *newparams = jl_svec_copy(tt->parameters);
744745
if (i < nargs || !definition->isva) {
745-
jl_value_t *di = jl_type_intersection(decl_i, (jl_value_t*)jl_type_type);
746+
jl_value_t *di = jl_type_intersection(type_i, (jl_value_t*)jl_type_type);
746747
assert(di != (jl_value_t*)jl_bottom_type);
747748
// issue #11355: DataType has a UID and so would take precedence in the cache
748749
if (jl_is_kind(di))
749750
jl_svecset(*newparams, i, (jl_value_t*)jl_type_type);
750751
else
751752
jl_svecset(*newparams, i, di);
752-
// TODO: recompute static parameter values, so in extreme cases we
753-
// can give `T=Type` instead of `T=Type{Type{Type{...`. /* make editors happy:}}} */
754753
}
755754
else {
756755
jl_svecset(*newparams, i, (jl_value_t*)jl_type_type);
@@ -759,14 +758,15 @@ static void jl_compilation_sig(
759758
}
760759

761760
int notcalled_func = (i_arg > 0 && i_arg <= 8 && !(definition->called & (1 << (i_arg - 1))) &&
761+
!jl_has_free_typevars(decl_i) &&
762762
jl_subtype(elt, (jl_value_t*)jl_function_type));
763-
if (notcalled_func && (decl_i == (jl_value_t*)jl_any_type ||
764-
decl_i == (jl_value_t*)jl_function_type ||
765-
(jl_is_uniontype(decl_i) && // Base.Callable
766-
((((jl_uniontype_t*)decl_i)->a == (jl_value_t*)jl_function_type &&
767-
((jl_uniontype_t*)decl_i)->b == (jl_value_t*)jl_type_type) ||
768-
(((jl_uniontype_t*)decl_i)->b == (jl_value_t*)jl_function_type &&
769-
((jl_uniontype_t*)decl_i)->a == (jl_value_t*)jl_type_type))))) {
763+
if (notcalled_func && (type_i == (jl_value_t*)jl_any_type ||
764+
type_i == (jl_value_t*)jl_function_type ||
765+
(jl_is_uniontype(type_i) && // Base.Callable
766+
((((jl_uniontype_t*)type_i)->a == (jl_value_t*)jl_function_type &&
767+
((jl_uniontype_t*)type_i)->b == (jl_value_t*)jl_type_type) ||
768+
(((jl_uniontype_t*)type_i)->b == (jl_value_t*)jl_function_type &&
769+
((jl_uniontype_t*)type_i)->a == (jl_value_t*)jl_type_type))))) {
770770
// and attempt to despecialize types marked Function, Callable, or Any
771771
// when called with a subtype of Function but is not called
772772
if (!*newparams) *newparams = jl_svec_copy(tt->parameters);
@@ -833,6 +833,7 @@ static void jl_compilation_sig(
833833
// compute whether this type signature is a possible return value from jl_compilation_sig given a concrete-type for `tt`
834834
JL_DLLEXPORT int jl_isa_compileable_sig(
835835
jl_tupletype_t *type,
836+
jl_svec_t *sparams,
836837
jl_method_t *definition)
837838
{
838839
jl_value_t *decl = definition->sig;
@@ -886,6 +887,7 @@ JL_DLLEXPORT int jl_isa_compileable_sig(
886887
for (i = 0; i < np; i++) {
887888
jl_value_t *elt = jl_tparam(type, i);
888889
jl_value_t *decl_i = jl_nth_slot_type((jl_value_t*)decl, i);
890+
jl_value_t *type_i = jl_rewrap_unionall(decl_i, decl);
889891
size_t i_arg = (i < nargs - 1 ? i : nargs - 1);
890892

891893
if (jl_is_vararg(elt)) {
@@ -919,25 +921,26 @@ JL_DLLEXPORT int jl_isa_compileable_sig(
919921

920922
if (jl_is_kind(elt)) {
921923
// kind slots always get guard entries (checking for subtypes of Type)
922-
if (jl_subtype(elt, decl_i) && !jl_subtype((jl_value_t*)jl_type_type, decl_i))
924+
if (jl_subtype(elt, type_i) && !jl_subtype((jl_value_t*)jl_type_type, type_i))
923925
continue;
924926
// TODO: other code paths that could reach here
925927
return 0;
926928
}
927-
else if (jl_is_kind(decl_i)) {
929+
else if (jl_is_kind(type_i)) {
928930
return 0;
929931
}
930932

931933
if (jl_is_type_type(jl_unwrap_unionall(elt))) {
932-
int iscalled = i_arg > 0 && i_arg <= 8 && (definition->called & (1 << (i_arg - 1)));
934+
int iscalled = (i_arg > 0 && i_arg <= 8 && (definition->called & (1 << (i_arg - 1)))) ||
935+
jl_has_free_typevars(decl_i);
933936
if (jl_types_equal(elt, (jl_value_t*)jl_type_type)) {
934-
if (!iscalled && very_general_type(decl_i))
937+
if (!iscalled && very_general_type(type_i))
935938
continue;
936939
if (i >= nargs && definition->isva)
937940
continue;
938941
return 0;
939942
}
940-
if (!iscalled && very_general_type(decl_i))
943+
if (!iscalled && very_general_type(type_i))
941944
return 0;
942945
if (!jl_is_datatype(elt))
943946
return 0;
@@ -949,7 +952,7 @@ JL_DLLEXPORT int jl_isa_compileable_sig(
949952
jl_value_t *kind = jl_typeof(jl_tparam0(elt));
950953
if (kind == jl_bottom_type)
951954
return 0; // Type{Union{}} gets normalized to typeof(Union{})
952-
if (jl_subtype(kind, decl_i) && !jl_subtype((jl_value_t*)jl_type_type, decl_i))
955+
if (jl_subtype(kind, type_i) && !jl_subtype((jl_value_t*)jl_type_type, type_i))
953956
return 0; // gets turned into a kind
954957

955958
else if (jl_is_type_type(jl_tparam0(elt)) &&
@@ -963,7 +966,7 @@ JL_DLLEXPORT int jl_isa_compileable_sig(
963966
this can be determined using a type intersection.
964967
*/
965968
if (i < nargs || !definition->isva) {
966-
jl_value_t *di = jl_type_intersection(decl_i, (jl_value_t*)jl_type_type);
969+
jl_value_t *di = jl_type_intersection(type_i, (jl_value_t*)jl_type_type);
967970
JL_GC_PUSH1(&di);
968971
assert(di != (jl_value_t*)jl_bottom_type);
969972
if (jl_is_kind(di)) {
@@ -984,14 +987,15 @@ JL_DLLEXPORT int jl_isa_compileable_sig(
984987
}
985988

986989
int notcalled_func = (i_arg > 0 && i_arg <= 8 && !(definition->called & (1 << (i_arg - 1))) &&
990+
!jl_has_free_typevars(decl_i) &&
987991
jl_subtype(elt, (jl_value_t*)jl_function_type));
988-
if (notcalled_func && (decl_i == (jl_value_t*)jl_any_type ||
989-
decl_i == (jl_value_t*)jl_function_type ||
990-
(jl_is_uniontype(decl_i) && // Base.Callable
991-
((((jl_uniontype_t*)decl_i)->a == (jl_value_t*)jl_function_type &&
992-
((jl_uniontype_t*)decl_i)->b == (jl_value_t*)jl_type_type) ||
993-
(((jl_uniontype_t*)decl_i)->b == (jl_value_t*)jl_function_type &&
994-
((jl_uniontype_t*)decl_i)->a == (jl_value_t*)jl_type_type))))) {
992+
if (notcalled_func && (type_i == (jl_value_t*)jl_any_type ||
993+
type_i == (jl_value_t*)jl_function_type ||
994+
(jl_is_uniontype(type_i) && // Base.Callable
995+
((((jl_uniontype_t*)type_i)->a == (jl_value_t*)jl_function_type &&
996+
((jl_uniontype_t*)type_i)->b == (jl_value_t*)jl_type_type) ||
997+
(((jl_uniontype_t*)type_i)->b == (jl_value_t*)jl_function_type &&
998+
((jl_uniontype_t*)type_i)->a == (jl_value_t*)jl_type_type))))) {
995999
// and attempt to despecialize types marked Function, Callable, or Any
9961000
// when called with a subtype of Function but is not called
9971001
if (elt == (jl_value_t*)jl_function_type)
@@ -1087,7 +1091,7 @@ static jl_method_instance_t *cache_method(
10871091
// cache miss. Alternatively, we may use the original signature in the
10881092
// cache, but use this return for compilation.
10891093
//
1090-
// In most cases `!jl_isa_compileable_sig(tt, definition)`,
1094+
// In most cases `!jl_isa_compileable_sig(tt, sparams, definition)`,
10911095
// although for some cases, (notably Varargs)
10921096
// we might choose a replacement type that's preferable but not strictly better
10931097
int issubty;
@@ -1099,7 +1103,7 @@ static jl_method_instance_t *cache_method(
10991103
}
11001104
newparams = NULL;
11011105
}
1102-
// TODO: maybe assert(jl_isa_compileable_sig(compilationsig, definition));
1106+
// TODO: maybe assert(jl_isa_compileable_sig(compilationsig, sparams, definition));
11031107
newmeth = jl_specializations_get_linfo(definition, (jl_value_t*)compilationsig, sparams);
11041108

11051109
jl_tupletype_t *cachett = tt;
@@ -2281,9 +2285,21 @@ JL_DLLEXPORT jl_value_t *jl_normalize_to_compilable_sig(jl_methtable_t *mt, jl_t
22812285
jl_methtable_t *kwmt = mt == jl_kwcall_mt ? jl_kwmethod_table_for(m->sig) : mt;
22822286
intptr_t nspec = (kwmt == NULL || kwmt == jl_type_type_mt || kwmt == jl_nonfunction_mt || kwmt == jl_kwcall_mt ? m->nargs + 1 : jl_atomic_load_relaxed(&kwmt->max_args) + 2 + 2 * (mt == jl_kwcall_mt));
22832287
jl_compilation_sig(ti, env, m, nspec, &newparams);
2284-
tt = (newparams ? jl_apply_tuple_type(newparams) : ti);
2285-
int is_compileable = ((jl_datatype_t*)ti)->isdispatchtuple ||
2286-
jl_isa_compileable_sig(tt, m);
2288+
int is_compileable = ((jl_datatype_t*)ti)->isdispatchtuple;
2289+
if (newparams) {
2290+
tt = jl_apply_tuple_type(newparams);
2291+
if (!is_compileable) {
2292+
// compute new env, if used below
2293+
jl_value_t *ti = jl_type_intersection_env((jl_value_t*)tt, (jl_value_t*)m->sig, &newparams);
2294+
assert(ti != jl_bottom_type); (void)ti;
2295+
env = newparams;
2296+
}
2297+
}
2298+
else {
2299+
tt = ti;
2300+
}
2301+
if (!is_compileable)
2302+
is_compileable = jl_isa_compileable_sig(tt, env, m);
22872303
JL_GC_POP();
22882304
return is_compileable ? (jl_value_t*)tt : jl_nothing;
22892305
}
@@ -2301,7 +2317,7 @@ jl_method_instance_t *jl_normalize_to_compilable_mi(jl_method_instance_t *mi JL_
23012317
return mi;
23022318
jl_svec_t *env = NULL;
23032319
JL_GC_PUSH2(&compilationsig, &env);
2304-
jl_value_t *ti = jl_type_intersection_env((jl_value_t*)mi->specTypes, (jl_value_t*)def->sig, &env);
2320+
jl_value_t *ti = jl_type_intersection_env((jl_value_t*)compilationsig, (jl_value_t*)def->sig, &env);
23052321
assert(ti != jl_bottom_type); (void)ti;
23062322
mi = jl_specializations_get_linfo(def, (jl_value_t*)compilationsig, env);
23072323
JL_GC_POP();
@@ -2318,7 +2334,7 @@ jl_method_instance_t *jl_method_match_to_mi(jl_method_match_t *match, size_t wor
23182334
if (jl_is_datatype(ti)) {
23192335
jl_methtable_t *mt = jl_method_get_table(m);
23202336
if ((jl_value_t*)mt != jl_nothing) {
2321-
// get the specialization without caching it
2337+
// get the specialization, possibly also caching it
23222338
if (mt_cache && ((jl_datatype_t*)ti)->isdispatchtuple) {
23232339
// Since we also use this presence in the cache
23242340
// to trigger compilation when producing `.ji` files,
@@ -2330,11 +2346,15 @@ jl_method_instance_t *jl_method_match_to_mi(jl_method_match_t *match, size_t wor
23302346
}
23312347
else {
23322348
jl_value_t *tt = jl_normalize_to_compilable_sig(mt, ti, env, m);
2333-
JL_GC_PUSH1(&tt);
23342349
if (tt != jl_nothing) {
2350+
JL_GC_PUSH2(&tt, &env);
2351+
if (!jl_egal(tt, (jl_value_t*)ti)) {
2352+
jl_value_t *ti = jl_type_intersection_env((jl_value_t*)tt, (jl_value_t*)m->sig, &env);
2353+
assert(ti != jl_bottom_type); (void)ti;
2354+
}
23352355
mi = jl_specializations_get_linfo(m, (jl_value_t*)tt, env);
2356+
JL_GC_POP();
23362357
}
2337-
JL_GC_POP();
23382358
}
23392359
}
23402360
}
@@ -2397,7 +2417,7 @@ jl_method_instance_t *jl_get_compile_hint_specialization(jl_tupletype_t *types J
23972417
size_t count = 0;
23982418
for (i = 0; i < n; i++) {
23992419
jl_method_match_t *match1 = (jl_method_match_t*)jl_array_ptr_ref(matches, i);
2400-
if (jl_isa_compileable_sig(types, match1->method))
2420+
if (jl_isa_compileable_sig(types, match1->sparams, match1->method))
24012421
jl_array_ptr_set(matches, count++, (jl_value_t*)match1);
24022422
}
24032423
jl_array_del_end((jl_array_t*)matches, n - count);

src/julia.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1433,7 +1433,7 @@ STATIC_INLINE int jl_is_concrete_type(jl_value_t *v) JL_NOTSAFEPOINT
14331433
return jl_is_datatype(v) && ((jl_datatype_t*)v)->isconcretetype;
14341434
}
14351435

1436-
JL_DLLEXPORT int jl_isa_compileable_sig(jl_tupletype_t *type, jl_method_t *definition);
1436+
JL_DLLEXPORT int jl_isa_compileable_sig(jl_tupletype_t *type, jl_svec_t *sparams, jl_method_t *definition);
14371437

14381438
// type constructors
14391439
JL_DLLEXPORT jl_typename_t *jl_new_typename_in(jl_sym_t *name, jl_module_t *inmodule, int abstract, int mutabl);

src/precompile.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ static void jl_compile_all_defs(jl_array_t *mis)
269269
size_t i, l = jl_array_len(allmeths);
270270
for (i = 0; i < l; i++) {
271271
jl_method_t *m = (jl_method_t*)jl_array_ptr_ref(allmeths, i);
272-
if (jl_isa_compileable_sig((jl_tupletype_t*)m->sig, m)) {
272+
if (jl_is_datatype(m->sig) && jl_isa_compileable_sig((jl_tupletype_t*)m->sig, jl_emptysvec, m)) {
273273
// method has a single compilable specialization, e.g. its definition
274274
// signature is concrete. in this case we can just hint it.
275275
jl_compile_hint((jl_tupletype_t*)m->sig);
@@ -354,7 +354,7 @@ static void *jl_precompile_(jl_array_t *m)
354354
mi = (jl_method_instance_t*)item;
355355
size_t min_world = 0;
356356
size_t max_world = ~(size_t)0;
357-
if (mi != jl_atomic_load_relaxed(&mi->def.method->unspecialized) && !jl_isa_compileable_sig((jl_tupletype_t*)mi->specTypes, mi->def.method))
357+
if (mi != jl_atomic_load_relaxed(&mi->def.method->unspecialized) && !jl_isa_compileable_sig((jl_tupletype_t*)mi->specTypes, mi->sparam_vals, mi->def.method))
358358
mi = jl_get_specialization1((jl_tupletype_t*)mi->specTypes, jl_atomic_load_acquire(&jl_world_counter), &min_world, &max_world, 0);
359359
if (mi)
360360
jl_array_ptr_1d_push(m2, (jl_value_t*)mi);

stdlib/Random/src/Random.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ rand(rng::AbstractRNG, ::UniformT{T}) where {T} = rand(rng, T)
256256
rand(rng::AbstractRNG, X) = rand(rng, Sampler(rng, X, Val(1)))
257257
# this is needed to disambiguate
258258
rand(rng::AbstractRNG, X::Dims) = rand(rng, Sampler(rng, X, Val(1)))
259-
rand(rng::AbstractRNG=default_rng(), ::Type{X}=Float64) where {X} = rand(rng, Sampler(rng, X, Val(1)))::X
259+
rand(rng::AbstractRNG=default_rng(), ::Type{X}=Float64) where {X} = rand(rng, Sampler(rng, X, Val(1)))::X
260260

261261
rand(X) = rand(default_rng(), X)
262262
rand(::Type{X}) where {X} = rand(default_rng(), X)

test/compiler/inference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ f11366(x::Type{Ref{T}}) where {T} = Ref{x}
406406

407407

408408
let f(T) = Type{T}
409-
@test Base.return_types(f, Tuple{Type{Int}}) == [Type{Type{Int}}]
409+
@test Base.return_types(f, Tuple{Type{Int}}) == Any[Type{Type{Int}}]
410410
end
411411

412412
# issue #9222

test/core.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7914,3 +7914,7 @@ code_typed(f47476, (Int, Int, Int, Vararg{Union{Int, NTuple{2,Int}}},))
79147914
code_typed(f47476, (Int, Int, Int, Int, Vararg{Union{Int, NTuple{2,Int}}},))
79157915
@test f47476(1, 2, 3, 4, 5, 6, (7, 8)) === 2
79167916
@test_throws UndefVarError(:N) f47476(1, 2, 3, 4, 5, 6, 7)
7917+
7918+
vect47476(::Type{T}) where {T} = T
7919+
@test vect47476(Type{Type{Type{Int32}}}) === Type{Type{Type{Int32}}}
7920+
@test vect47476(Type{Type{Type{Int64}}}) === Type{Type{Type{Int64}}}

test/precompile.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1493,8 +1493,8 @@ end
14931493
f(x, y) = x + y
14941494
f(x::Int, y) = 2x + y
14951495
end
1496-
precompile(M.f, (Int, Any))
1497-
precompile(M.f, (AbstractFloat, Any))
1496+
@test precompile(M.f, (Int, Any))
1497+
@test precompile(M.f, (AbstractFloat, Any))
14981498
mis = map(methods(M.f)) do m
14991499
m.specializations[1]
15001500
end

0 commit comments

Comments
 (0)