diff --git a/base/boot.jl b/base/boot.jl index 98b8cf2e9cf40..27b2daac0598d 100644 --- a/base/boot.jl +++ b/base/boot.jl @@ -222,7 +222,7 @@ primitive type Char <: AbstractChar 32 end primitive type Int8 <: Signed 8 end #primitive type UInt8 <: Unsigned 8 end primitive type Int16 <: Signed 16 end -primitive type UInt16 <: Unsigned 16 end +#primitive type UInt16 <: Unsigned 16 end #primitive type Int32 <: Signed 32 end #primitive type UInt32 <: Unsigned 32 end #primitive type Int64 <: Signed 64 end diff --git a/base/task.jl b/base/task.jl index e84c344c28d17..ce5506f224be1 100644 --- a/base/task.jl +++ b/base/task.jl @@ -166,11 +166,21 @@ end elseif field === :exception # TODO: this field name should be deprecated in 2.0 return t._isexception ? t.result : nothing + elseif field === :sticky + return getfield(t, :sticky_count) != 0 else return getfield(t, field) end end +function setproperty!(t::Task, field::Symbol, x) + if field === :sticky + t.sticky_count = convert(Bool, x) + else + setfield!(t, field, convert(fieldtype(Task, field), x)) + end +end + """ istaskdone(t::Task) -> Bool @@ -611,6 +621,37 @@ function __preinit_threads__() nothing end +# Factored out so that the behavior after saturation can be tested: +is_sticky_count_saturated(t::Task) = t.sticky_count === typemax(t.sticky_count) + +# This is a struct rather than a closure so that `serialize` can be dispatched +# to ignore `parent` field. +struct StickyCountDecrementer + code::Any + parent::Union{Nothing,Task} +end + +unset_parent(f::StickyCountDecrementer) = StickyCountDecrementer(f.code, nothing) + +function (f::StickyCountDecrementer)() + try + f.code() + finally + parent_task = f.parent + if parent_task !== nothing && !is_sticky_count_saturated(parent_task) + # Once `parent_task.sticky_count` hits the typemax (which + # practically never happens), we stop un-sticking the parent task. + # This only affects the performance in rare cases (which already + # torturing the scheulder anyway) and does not sacrifice the + # correctness. Checking saturation should be done for all tasks + # includding those started with `parent_task.sticky_count < typemax + # -1` since there may be sticky tasks started realying on that the + # counter is saturated. + parent_task.sticky_count -= 1 + end + end +end + function enq_work(t::Task) (t._state === task_state_runnable && t.queue === nothing) || error("schedule: Task not runnable") tid = Threads.threadid(t) @@ -625,8 +666,12 @@ function enq_work(t::Task) # t.sticky && tid == 0 is a task that needs to be co-scheduled with # the parent task. If the parent (current_task) is not sticky we must # set it to be sticky. - # XXX: Ideally we would be able to unset this - current_task().sticky = true + parent_task = current_task() + if t.sticky && !is_sticky_count_saturated(parent_task) + parent_task.sticky_count += 1 + t.code = StickyCountDecrementer(t.code, parent_task) + end + tid = Threads.threadid() ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid-1) end diff --git a/src/builtins.c b/src/builtins.c index f40d694d23529..f6cbe654594f1 100644 --- a/src/builtins.c +++ b/src/builtins.c @@ -1868,6 +1868,7 @@ void jl_init_primitives(void) JL_GC_DISABLED add_builtin("UInt8", (jl_value_t*)jl_uint8_type); add_builtin("Int32", (jl_value_t*)jl_int32_type); add_builtin("Int64", (jl_value_t*)jl_int64_type); + add_builtin("UInt16", (jl_value_t*)jl_uint16_type); add_builtin("UInt32", (jl_value_t*)jl_uint32_type); add_builtin("UInt64", (jl_value_t*)jl_uint64_type); #ifdef _P64 diff --git a/src/init.c b/src/init.c index 602583a9221fd..2ad6bdf4a0101 100644 --- a/src/init.c +++ b/src/init.c @@ -805,7 +805,6 @@ static void post_boot_hooks(void) jl_char_type = (jl_datatype_t*)core("Char"); jl_int8_type = (jl_datatype_t*)core("Int8"); jl_int16_type = (jl_datatype_t*)core("Int16"); - jl_uint16_type = (jl_datatype_t*)core("UInt16"); jl_float16_type = (jl_datatype_t*)core("Float16"); jl_float32_type = (jl_datatype_t*)core("Float32"); jl_float64_type = (jl_datatype_t*)core("Float64"); @@ -819,6 +818,7 @@ static void post_boot_hooks(void) jl_uint8_type->super = jl_unsigned_type; jl_int32_type->super = jl_signed_type; jl_int64_type->super = jl_signed_type; + jl_uint16_type->super = jl_unsigned_type; jl_uint32_type->super = jl_unsigned_type; jl_uint64_type->super = jl_unsigned_type; diff --git a/src/jltypes.c b/src/jltypes.c index f9f60f1227b9a..f637c52d61408 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -2092,6 +2092,8 @@ void jl_init_types(void) JL_GC_DISABLED jl_any_type, jl_emptysvec, 32); jl_int64_type = jl_new_primitivetype((jl_value_t*)jl_symbol("Int64"), core, jl_any_type, jl_emptysvec, 64); + jl_uint16_type = jl_new_primitivetype((jl_value_t*)jl_symbol("UInt16"), core, + jl_any_type, jl_emptysvec, 16); jl_uint32_type = jl_new_primitivetype((jl_value_t*)jl_symbol("UInt32"), core, jl_any_type, jl_emptysvec, 32); jl_uint64_type = jl_new_primitivetype((jl_value_t*)jl_symbol("UInt64"), core, @@ -2543,8 +2545,8 @@ void jl_init_types(void) JL_GC_DISABLED "rngState1", "rngState2", "rngState3", + "sticky_count", "_state", - "sticky", "_isexception"), jl_svec(14, jl_any_type, @@ -2558,8 +2560,8 @@ void jl_init_types(void) JL_GC_DISABLED jl_uint64_type, jl_uint64_type, jl_uint64_type, + jl_uint16_type, jl_uint8_type, - jl_bool_type, jl_bool_type), jl_emptysvec, 0, 1, 6); diff --git a/src/julia.h b/src/julia.h index 3455817cf1a92..a34e3d94b7b97 100644 --- a/src/julia.h +++ b/src/julia.h @@ -1833,8 +1833,8 @@ typedef struct _jl_task_t { uint64_t rngState1; uint64_t rngState2; uint64_t rngState3; + uint16_t sticky; // 0 means this Task can be migrated to a new thread uint8_t _state; - uint8_t sticky; // record whether this Task can be migrated to a new thread uint8_t _isexception; // set if `result` is an exception to throw or that we exited with // hidden state: diff --git a/stdlib/Serialization/src/Serialization.jl b/stdlib/Serialization/src/Serialization.jl index 592db96565c7a..6bd37b97481e0 100644 --- a/stdlib/Serialization/src/Serialization.jl +++ b/stdlib/Serialization/src/Serialization.jl @@ -456,6 +456,9 @@ function serialize(s::AbstractSerializer, linfo::Core.MethodInstance) nothing end +serialize(s::AbstractSerializer, f::Base.StickyCountDecrementer) = + invoke(serialize, Tuple{typeof(s),Any}, s, Base.unset_parent(f)) + function serialize(s::AbstractSerializer, t::Task) serialize_cycle(s, t) && return if istaskstarted(t) && !istaskdone(t) diff --git a/test/threads_exec.jl b/test/threads_exec.jl index f3d2dc9577c64..f21202cd796fc 100644 --- a/test/threads_exec.jl +++ b/test/threads_exec.jl @@ -850,7 +850,7 @@ fib34666(x) = wait(child) end wait(parent) - @test parent.sticky == true + @test parent.sticky == false end function jitter_channel(f, k, delay, ntasks, schedule) @@ -912,3 +912,37 @@ end @test reproducible_rand(r, 10) == val end end + +# [ADD TESTS ABOVE THIS COMMENT] +# +# The following tests must be done at the end, since they need to monkey-patch runtime. +const MAX_STICKY_COUNT = 3 +@assert MAX_STICKY_COUNT <= typemax(fieldtype(Task, :sticky_count)) +Base.is_sticky_count_saturated(t::Task) = t.sticky_count == MAX_STICKY_COUNT + +@testset "Saturated sticky_count" begin + @testset for nchild in MAX_STICKY_COUNT-1:MAX_STICKY_COUNT+1 + local is_sticky_pre, is_sticky_post, sticky_count_pre, sticky_count_post + @sync Threads.@spawn begin + is_sticky_pre = current_task().sticky + sticky_count_pre = current_task().sticky_count + @sync for _ in 1:nchild + @async nothing + end + is_sticky_post = current_task().sticky + sticky_count_post = current_task().sticky_count + end + @test !is_sticky_pre + @test sticky_count_pre == 0 + if nchild < MAX_STICKY_COUNT + @test !is_sticky_post + @test sticky_count_post == 0 + else + @test is_sticky_post + @test sticky_count_post == MAX_STICKY_COUNT + end + end +end + +# Please do not add tests at the end of this file. Pleaes add tests above the above +# comment [ADD TESTS ABOVE THIS COMMENT].