diff --git a/NEWS.md b/NEWS.md index 16afb8c168443..a0a3d84744880 100644 --- a/NEWS.md +++ b/NEWS.md @@ -3,6 +3,7 @@ Julia v1.11 Release Notes New language features --------------------- +* `ScopedValue` implement dynamic scope with inheritance across tasks ([#50958]). Language changes ---------------- diff --git a/base/Base.jl b/base/Base.jl index 0673a1081ae69..62b6d21589ed7 100644 --- a/base/Base.jl +++ b/base/Base.jl @@ -331,10 +331,6 @@ using .Libc: getpid, gethostname, time, memcpy, memset, memmove, memcmp const libblas_name = "libblastrampoline" * (Sys.iswindows() ? "-5" : "") const liblapack_name = libblas_name -# Logging -include("logging.jl") -using .CoreLogging - # Concurrency (part 2) # Note that `atomics.jl` here should be deprecated Core.eval(Threads, :(include("atomics.jl"))) @@ -344,6 +340,14 @@ include("task.jl") include("threads_overloads.jl") include("weakkeydict.jl") +# ScopedValues +include("scopedvalues.jl") +using .ScopedValues + +# Logging +include("logging.jl") +using .CoreLogging + include("env.jl") # functions defined in Random diff --git a/base/boot.jl b/base/boot.jl index e24a6f4ffc0e0..a7630fe1cbb60 100644 --- a/base/boot.jl +++ b/base/boot.jl @@ -163,7 +163,7 @@ # result::Any # exception::Any # backtrace::Any -# logstate::Any +# scope::Any # code::Any #end diff --git a/base/exports.jl b/base/exports.jl index 0959fa1c391e2..abf140d0ad81f 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -648,6 +648,11 @@ export sprint, summary, +# ScopedValue + scoped, + @scoped, + ScopedValue, + # logging @debug, @info, diff --git a/base/logging.jl b/base/logging.jl index c42af08d8f4ae..8792074fe8b01 100644 --- a/base/logging.jl +++ b/base/logging.jl @@ -492,8 +492,10 @@ end LogState(logger) = LogState(LogLevel(_invoked_min_enabled_level(logger)), logger) +const CURRENT_LOGSTATE = ScopedValue{Union{Nothing, LogState}}(nothing) + function current_logstate() - logstate = current_task().logstate + logstate = CURRENT_LOGSTATE[] return (logstate !== nothing ? logstate : _global_logstate)::LogState end @@ -506,17 +508,7 @@ end return nothing end -function with_logstate(f::Function, logstate) - @nospecialize - t = current_task() - old = t.logstate - try - t.logstate = logstate - f() - finally - t.logstate = old - end -end +with_logstate(f::Function, logstate) = @scoped(CURRENT_LOGSTATE => logstate, f()) #------------------------------------------------------------------------------- # Control of the current logger and early log filtering diff --git a/base/scopedvalues.jl b/base/scopedvalues.jl new file mode 100644 index 0000000000000..755f3a03446ed --- /dev/null +++ b/base/scopedvalues.jl @@ -0,0 +1,175 @@ +# This file is a part of Julia. License is MIT: https://julialang.org/license + +module ScopedValues + +export ScopedValue, scoped, @scoped + +""" + ScopedValue(x) + +Create a container that propagates values across dynamic scopes. +Use [`scoped`](@ref) to create and enter a new dynamic scope. + +Values can only be set when entering a new dynamic scope, +and the value referred to will be constant during the +execution of a dynamic scope. + +Dynamic scopes are propagated across tasks. + +# Examples +```jldoctest +julia> const svar = ScopedValue(1); + +julia> svar[] +1 + +julia> scoped(svar => 2) do + svar[] + end +2 + +julia> svar[] +1 +``` + +!!! compat "Julia 1.11" + Scoped values were introduced in Julia 1.11. In Julia 1.8+ a compatible + implementation is available from the package ScopedValues.jl. +""" +mutable struct ScopedValue{T} + const initial_value::T +end + +Base.eltype(::Type{ScopedValue{T}}) where {T} = T + +## +# Notes on the implementation. +# We want lookup to be unreasonably fast. +# - IDDict/Dict are ~10ns +# - ImmutableDict is faster up to about ~15 entries +# - ScopedValue are meant to be constant, Immutabilty +# is thus a boon +# - If we were to use IDDict/Dict we would need to split +# the cache portion and the value portion of the hash-table, +# the value portion is read-only/write-once, but the cache version +# would need a lock which makes ImmutableDict incredibly attractive. +# We could also use task-local-storage, but that added about 12ns. +# - Values are GC'd when scopes become unreachable, one could use +# a WeakKeyDict to also ensure that values get GC'd when ScopedValues +# become unreachable. +# - Scopes are an inline implementation of an ImmutableDict, if we wanted +# be really fancy we could use a CTrie or HAMT. + +mutable struct Scope + const parent::Union{Nothing, Scope} + const key::ScopedValue + const value::Any + Scope(parent, key::ScopedValue{T}, value::T) where T = new(parent, key, value) +end +Scope(parent, key::ScopedValue{T}, value) where T = + Scope(parent, key, convert(T, value)) + +function Scope(scope, pairs::Pair{<:ScopedValue}...) + for pair in pairs + scope = Scope(scope, pair...) + end + return scope +end + +""" + current_scope()::Union{Nothing, Scope} + +Return the current dynamic scope. +""" +current_scope() = current_task().scope::Union{Nothing, Scope} + +function Base.show(io::IO, scope::Scope) + print(io, Scope, "(") + seen = Set{ScopedValue}() + while scope !== nothing + if scope.key ∉ seen + if !isempty(seen) + print(io, ", ") + end + print(io, typeof(scope.key), "@") + show(io, Base.objectid(scope.key)) + print(io, " => ") + show(IOContext(io, :typeinfo => eltype(scope.key)), scope.value) + push!(seen, scope.key) + end + scope = scope.parent + end + print(io, ")") +end + +function Base.getindex(var::ScopedValue{T})::T where T + scope = current_scope() + while scope !== nothing + if scope.key === var + return scope.value::T + end + scope = scope.parent + end + return var.initial_value +end + +function Base.show(io::IO, var::ScopedValue) + print(io, ScopedValue) + print(io, '{', eltype(var), '}') + print(io, '(') + show(IOContext(io, :typeinfo => eltype(var)), var[]) + print(io, ')') +end + +""" + scoped(f, (var::ScopedValue{T} => val::T)...) + +Execute `f` in a new scope with `var` set to `val`. +""" +function scoped(f, pair::Pair{<:ScopedValue}, rest::Pair{<:ScopedValue}...) + @nospecialize + ct = Base.current_task() + current_scope = ct.scope::Union{Nothing, Scope} + scope = Scope(current_scope, pair, rest...) + ct.scope = scope + try + return f() + finally + ct.scope = current_scope + end +end + +scoped(@nospecialize(f)) = f() + +""" + @scoped vars... expr + +Macro version of `scoped(f, vars...)` but with `expr` instead of `f` function. +This is similar to using [`scoped`](@ref) with a `do` block, but avoids creating +a closure. +""" +macro scoped(exprs...) + ex = last(exprs) + if length(exprs) > 1 + exprs = exprs[1:end-1] + else + exprs = () + end + for expr in exprs + if expr.head !== :call || first(expr.args) !== :(=>) + error("@scoped expects arguments of the form `A => 2` got $expr") + end + end + exprs = map(esc, exprs) + ct = gensym(:ct) + current_scope = gensym(:current_scope) + body = Expr(:tryfinally, esc(ex), :($(ct).scope = $(current_scope))) + quote + $(ct) = $(Base.current_task)() + $(current_scope) = $(ct).scope::$(Union{Nothing, Scope}) + $(ct).scope = $(Scope)($(current_scope), $(exprs...)) + $body + end +end + +end # module ScopedValues diff --git a/doc/make.jl b/doc/make.jl index 087b033fcf79c..0ae74a55aceee 100644 --- a/doc/make.jl +++ b/doc/make.jl @@ -112,6 +112,7 @@ BaseDocs = [ "base/arrays.md", "base/parallel.md", "base/multi-threading.md", + "base/scopedvalues.md", "base/constants.md", "base/file.md", "base/io-network.md", diff --git a/doc/src/base/scopedvalues.md b/doc/src/base/scopedvalues.md new file mode 100644 index 0000000000000..78ad91d63a538 --- /dev/null +++ b/doc/src/base/scopedvalues.md @@ -0,0 +1,174 @@ +# Scoped Values + +Scoped values provide an implementation of dynamic scoping in Julia. +Dynamic scope means that the state of the value is dependent on the execution path +of the program. This means that for a scoped value you may observe +multiple different values at the same time. + +!!! compat "Julia 1.11" + Scoped values were introduced in Julia 1.11. In Julia 1.8+ a compatible + implementation is available from the package ScopedValues.jl. + +In its simplest form you can create a [`ScopedValue`](@ref) with a +default value and then use [`scoped`](@ref) or [`@scoped`](@ref) to +enter a new dynamic scope. +The new scope will inherit all values from the parent scope +(and recursively from all outer scopes) with the provided scoped +value taking priority over previous definitions. + +```julia +const scoped_val = ScopedValue(1) +const scoped_val2 = ScopedValue(0) + +# Enter a new dynamic scope and set value +@show scoped_val[] # 1 +@show scoped_val2[] # 0 +scoped(scoped_val => 2) do + @show scoped_val[] # 2 + @show scoped_val2[] # 0 + scoped(scoped_val => 3, scoped_val2 => 5) do + @show scoped_val[] # 3 + @show scoped_val2[] # 5 + end + @show scoped_val[] # 2 + @show scoped_val2[] # 0 +end +@show scoped_val[] # 1 +@show scoped_val2[] # 0 +``` + +Since `scoped` requires a closure or a function and creates another call-frame, +it can sometimes be beneficial to use the macro form. + +```julia +const STATE = ScopedValue{Union{Nothing, State}}() +with_state(f, state::State) = @scoped(STATE => state, f()) +``` + +!!! note + Dynamic scopes are propagated through [`Task`](@ref)s. + +In the example below we open a new dynamic scope before launching a task. +The parent task and the two child tasks observe independent values of the +same scoped value at the same time. + +```julia +import Base.Threads: @spawn +const scoped_val = ScopedValue(1) +@sync begin + scoped(scoped_val => 2) + @spawn @show scoped_val[] # 2 + end + scoped(scoped_val => 3) + @spawn @show scoped_val[] # 3 + end + @show scoped_val[] # 1 +end +``` + +Scoped values are constant throughout a scope, but you can store mutable +state in a scoped value. Just keep in mind that the usual caveats +for global variables apply in the context of concurrent programming. + +Care is also required when storing references to mutable state in scoped +values. You might want to explicitly [unshare mutable state](@ref unshare_mutable_state) +when entering a new dynamic scope. + +```julia +import Base.Threads: @spawn +const sval_dict = ScopedValue(Dict()) + +# Example of using a mutable value wrongly +@sync begin + # `Dict` is not thread-safe the usage below is invalid + @spawn (sval_dict[][:a] = 3) + @spawn (sval_dict[][:b] = 3) +end + +@sync begin + # If we instead pass a unique dictionary to each + # task we can access the dictonaries race free. + scoped(sval_dict => Dict()) + @spawn (sval_dict[][:a] = 3) + end + scoped(sval_dict => Dict()) + @spawn (sval_dict[][:b] = 3) + end +end +``` + +## Example + +In the example below we use a scoped value to implement a permission check in +a web-application. After determining the permissions of the request, +a new dynamic scope is entered and the scoped value `LEVEL` is set. +Other parts of the application can query the scoped value and will receive +the appropriate value. Other alternatives like task-local storage and global variables +are not well suited for this kind of propagation; our only alternative would have +been to thread a value through the entire call-chain. + +```julia +const LEVEL = ScopedValue(:GUEST) + +function serve(request, response) + level = isAdmin(request) ? :ADMIN : :GUEST + scoped(LEVEL => level) do + Threads.@spawn handle(request, respone) + end +end + +function open(connection::Database) + level = LEVEL[] + if level !== :ADMIN + error("Access disallowed") + end + # ... open connection +end + +function handle(request, response) + # ... + open(Database(#=...=#)) + # ... +end +``` + +## Idioms +### [Unshare mutable state]((@id unshare_mutable_state)) + +```julia +import Base.Threads: @spawn +const sval_dict = ScopedValue(Dict()) + +# If you want to add new values to the dict, instead of replacing +# it, unshare the values explicitly. In this example we use `merge` +# to unshare the state of the dictonary in parent scope. +@sync begin + scoped(sval_dict => merge(sval_dict, Dict(:a => 10))) do + @spawn @show sval_dict[][:a] + end + @spawn sval_dict[][:a] = 3 # Not a race since they are unshared. +end +``` + +### Local caching + +Since lookup of a scoped variable is linear in scope depth, it can be beneficial +for a library at an API boundary to cache the state of the scoped value. + +```julia +const DEVICE = ScopedValue(:CPU) + +function solve_problem(args...) + # Cache current device + @scoped DEVICE => DEVICE[] begin + # call functions that use `DEVICE[]` repeatedly. + end +``` + +## API docs + +```@docs +Base.ScopedValues.ScopedValue +Base.ScopedValues.scoped +Base.ScopedValues.@scoped +``` diff --git a/src/jltypes.c b/src/jltypes.c index f3273ae936db3..de28631ef95be 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -3232,7 +3232,7 @@ void jl_init_types(void) JL_GC_DISABLED "storage", "donenotify", "result", - "logstate", + "scope", "code", "rngState0", "rngState1", diff --git a/src/julia.h b/src/julia.h index a96b4a1f5e562..3d94c44c4f2ad 100644 --- a/src/julia.h +++ b/src/julia.h @@ -2013,7 +2013,7 @@ typedef struct _jl_task_t { jl_value_t *tls; jl_value_t *donenotify; jl_value_t *result; - jl_value_t *logstate; + jl_value_t *scope; jl_function_t *start; // 4 byte padding on 32-bit systems // uint32_t padding0; diff --git a/src/task.c b/src/task.c index 1dab8688cb079..2e95a6f1770c4 100644 --- a/src/task.c +++ b/src/task.c @@ -1068,8 +1068,8 @@ JL_DLLEXPORT jl_task_t *jl_new_task(jl_function_t *start, jl_value_t *completion t->result = jl_nothing; t->donenotify = completion_future; jl_atomic_store_relaxed(&t->_isexception, 0); - // Inherit logger state from parent task - t->logstate = ct->logstate; + // Inherit scope from parent task + t->scope = ct->scope; // Fork task-local random state from parent jl_rng_split(t->rngState, ct->rngState); // there is no active exception handler available on this stack yet @@ -1670,7 +1670,7 @@ jl_task_t *jl_init_root_task(jl_ptls_t ptls, void *stack_lo, void *stack_hi) ct->result = jl_nothing; ct->donenotify = jl_nothing; jl_atomic_store_relaxed(&ct->_isexception, 0); - ct->logstate = jl_nothing; + ct->scope = jl_nothing; ct->eh = NULL; ct->gcstack = NULL; ct->excstack = NULL; diff --git a/stdlib/Test/src/Test.jl b/stdlib/Test/src/Test.jl index 622c696b383a0..29b9f4576d4c8 100644 --- a/stdlib/Test/src/Test.jl +++ b/stdlib/Test/src/Test.jl @@ -982,7 +982,11 @@ end A simple fallback test set that throws immediately on a failure. """ struct FallbackTestSet <: AbstractTestSet end -fallback_testset = FallbackTestSet() + +const CURRENT_TESTSET = ScopedValue{AbstractTestSet}(FallbackTestSet()) +const CURRENT_DEPTH = ScopedValue(0) +get_testset() = CURRENT_TESTSET[] +get_testset_depth() = CURRENT_DEPTH[] struct FallbackTestSetException <: Exception msg::String @@ -1520,14 +1524,13 @@ function testset_context(args, ex, source) test_ex = ex.args[2] ex.args[2] = quote + testset = $(CURRENT_TESTSET)[] $(map(contexts) do context - :($push_testset($(ContextTestSet)($(QuoteNode(context)), $context; $options...))) + :(testset = $(ContextTestSet)(testset, $(QuoteNode(context)), $context; $options...)) end...) - try - $(test_ex) - finally - $(map(_->:($pop_testset()), contexts)...) - end + @scoped($(CURRENT_TESTSET) => testset, + $(CURRENT_DEPTH) => $CURRENT_DEPTH[]+1, + $(test_ex)) end return esc(ex) @@ -1563,34 +1566,34 @@ function testset_beginend_call(args, tests, source) else $(testsettype)($desc; $options...) end - push_testset(ts) - # we reproduce the logic of guardseed, but this function - # cannot be used as it changes slightly the semantic of @testset, - # by wrapping the body in a function - local RNG = default_rng() - local oldrng = copy(RNG) - local oldseed = Random.GLOBAL_SEED - try - # RNG is re-seeded with its own seed to ease reproduce a failed test - Random.seed!(Random.GLOBAL_SEED) - let - $(esc(tests)) - end - catch err - err isa InterruptException && rethrow() - # something in the test block threw an error. Count that as an - # error in this test set - trigger_test_failure_break(err) - if err isa FailFastError - get_testset_depth() > 1 ? rethrow() : failfast_print() - else - record(ts, Error(:nontest_error, Expr(:tuple), err, Base.current_exceptions(), $(QuoteNode(source)))) + @scoped $(CURRENT_TESTSET) => ts $(CURRENT_DEPTH) => $(CURRENT_DEPTH)[] + 1 begin + # we reproduce the logic of guardseed, but this function + # cannot be used as it changes slightly the semantic of @testset, + # by wrapping the body in a function + local RNG = default_rng() + local oldrng = copy(RNG) + local oldseed = Random.GLOBAL_SEED + try + # RNG is re-seeded with its own seed to ease reproduce a failed test + Random.seed!(Random.GLOBAL_SEED) + let + $(esc(tests)) + end + catch err + err isa InterruptException && rethrow() + # something in the test block threw an error. Count that as an + # error in this test set + trigger_test_failure_break(err) + if err isa FailFastError + get_testset_depth() > 1 ? rethrow() : failfast_print() + else + record(ts, Error(:nontest_error, Expr(:tuple), err, Base.current_exceptions(), $(QuoteNode(source)))) + end + finally + copy!(RNG, oldrng) + Random.set_global_seed!(oldseed) + ret = finish(ts) end - finally - copy!(RNG, oldrng) - Random.set_global_seed!(oldseed) - pop_testset() - ret = finish(ts) end ret end @@ -1649,7 +1652,6 @@ function testset_forloop(args, testloop, source) # Trick to handle `break` and `continue` in the test code before # they can be handled properly by `finally` lowering. if !first_iteration - pop_testset() finish_errored = true push!(arr, finish(ts)) finish_errored = false @@ -1663,17 +1665,18 @@ function testset_forloop(args, testloop, source) else $(testsettype)($desc; $options...) end - push_testset(ts) - first_iteration = false - try - $(esc(tests)) - catch err - err isa InterruptException && rethrow() - # Something in the test block threw an error. Count that as an - # error in this test set - trigger_test_failure_break(err) - if !isa(err, FailFastError) - record(ts, Error(:nontest_error, Expr(:tuple), err, Base.current_exceptions(), $(QuoteNode(source)))) + @scoped $(CURRENT_TESTSET) => ts $(CURRENT_DEPTH) => $(CURRENT_DEPTH)[] + 1 begin + first_iteration = false + try + $(esc(tests)) + catch err + err isa InterruptException && rethrow() + # Something in the test block threw an error. Count that as an + # error in this test set + trigger_test_failure_break(err) + if !isa(err, FailFastError) + record(ts, Error(:nontest_error, Expr(:tuple), err, Base.current_exceptions(), $(QuoteNode(source)))) + end end end end @@ -1694,7 +1697,6 @@ function testset_forloop(args, testloop, source) finally # Handle `return` in test body if !first_iteration && !finish_errored - pop_testset() push!(arr, finish(ts)) end copy!(RNG, oldrng) @@ -1736,51 +1738,6 @@ end #----------------------------------------------------------------------- # Various helper methods for test sets -""" - get_testset() - -Retrieve the active test set from the task's local storage. If no -test set is active, use the fallback default test set. -""" -function get_testset() - testsets = get(task_local_storage(), :__BASETESTNEXT__, AbstractTestSet[]) - return isempty(testsets) ? fallback_testset : testsets[end] -end - -""" - push_testset(ts::AbstractTestSet) - -Adds the test set to the `task_local_storage`. -""" -function push_testset(ts::AbstractTestSet) - testsets = get(task_local_storage(), :__BASETESTNEXT__, AbstractTestSet[]) - push!(testsets, ts) - setindex!(task_local_storage(), testsets, :__BASETESTNEXT__) -end - -""" - pop_testset() - -Pops the last test set added to the `task_local_storage`. If there are no -active test sets, returns the fallback default test set. -""" -function pop_testset() - testsets = get(task_local_storage(), :__BASETESTNEXT__, AbstractTestSet[]) - ret = isempty(testsets) ? fallback_testset : pop!(testsets) - setindex!(task_local_storage(), testsets, :__BASETESTNEXT__) - return ret -end - -""" - get_testset_depth() - -Return the number of active test sets, not including the default test set -""" -function get_testset_depth() - testsets = get(task_local_storage(), :__BASETESTNEXT__, AbstractTestSet[]) - return length(testsets) -end - _args_and_call(args...; kwargs...) = (args[1:end-1], kwargs, args[end](args[1:end-1]...; kwargs...)) _materialize_broadcasted(f, args...) = Broadcast.materialize(Broadcast.broadcasted(f, args...)) diff --git a/test/choosetests.jl b/test/choosetests.jl index c38817bb4eeb9..4d493a16d0484 100644 --- a/test/choosetests.jl +++ b/test/choosetests.jl @@ -29,6 +29,7 @@ const TESTNAMES = [ "channels", "iostream", "secretbuffer", "specificity", "reinterpretarray", "syntax", "corelogging", "missing", "asyncmap", "smallarrayshrink", "opaque_closure", "filesystem", "download", + "scopedvalues", ] const INTERNET_REQUIRED_LIST = [ diff --git a/test/runtests.jl b/test/runtests.jl index 1264acae985b0..e5c9c64913e55 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -373,15 +373,14 @@ cd(@__DIR__) do Test.TESTSET_PRINT_ENABLE[] = false o_ts = Test.DefaultTestSet("Overall") o_ts.time_end = o_ts.time_start + o_ts_duration # manually populate the timing - Test.push_testset(o_ts) completed_tests = Set{String}() - for (testname, (resp,), duration) in results + @scoped Test.CURRENT_TESTSET => o_ts for (testname, (resp,), duration) in results push!(completed_tests, testname) if isa(resp, Test.DefaultTestSet) resp.time_end = resp.time_start + duration - Test.push_testset(resp) - Test.record(o_ts, resp) - Test.pop_testset() + @scoped Test.CURRENT_TESTSET => resp begin + Test.record(o_ts, resp) + end elseif isa(resp, Test.TestSetException) fake = Test.DefaultTestSet(testname) fake.time_end = fake.time_start + duration @@ -394,9 +393,9 @@ cd(@__DIR__) do for t in resp.errors_and_fails Test.record(fake, t) end - Test.push_testset(fake) - Test.record(o_ts, fake) - Test.pop_testset() + @scoped Test.CURRENT_TESTSET => fake begin + Test.record(o_ts, fake) + end else if !isa(resp, Exception) resp = ErrorException(string("Unknown result type : ", typeof(resp))) @@ -408,18 +407,18 @@ cd(@__DIR__) do fake = Test.DefaultTestSet(testname) fake.time_end = fake.time_start + duration Test.record(fake, Test.Error(:nontest_error, testname, nothing, Any[(resp, [])], LineNumberNode(1))) - Test.push_testset(fake) - Test.record(o_ts, fake) - Test.pop_testset() + @scoped Test.CURRENT_TESTSET => fake begin + Test.record(o_ts, fake) + end end end for test in all_tests (test in completed_tests) && continue fake = Test.DefaultTestSet(test) Test.record(fake, Test.Error(:test_interrupted, test, nothing, [("skipped", [])], LineNumberNode(1))) - Test.push_testset(fake) - Test.record(o_ts, fake) - Test.pop_testset() + @scoped Test.CURRENT_TESTSET => fake begin + Test.record(o_ts, fake) + end end Test.TESTSET_PRINT_ENABLE[] = true println() diff --git a/test/scopedvalues.jl b/test/scopedvalues.jl new file mode 100644 index 0000000000000..367249cf74dbd --- /dev/null +++ b/test/scopedvalues.jl @@ -0,0 +1,121 @@ +# This file is a part of Julia. License is MIT: https://julialang.org/license +import Base: ScopedValues + +@testset "errors" begin + @test ScopedValue{Float64}(1)[] == 1.0 + @test_throws InexactError ScopedValue{Int}(1.5) + var = ScopedValue(1) + @test_throws MethodError var[] = 2 + scoped() do + @test_throws MethodError var[] = 2 + end + @test_throws MethodError ScopedValue{Int}() + @test_throws MethodError ScopedValue() +end + +const svar = ScopedValue(1) +@testset "inheritance" begin + @test svar[] == 1 + scoped() do + @test svar[] == 1 + scoped() do + @test svar[] == 1 + end + scoped(svar => 2) do + @test svar[] == 2 + end + @test svar[] == 1 + end + @test svar[] == 1 +end + +const svar_float = ScopedValue(1.0) + +@testset "multiple scoped values" begin + scoped(svar => 2, svar_float => 2.0) do + @test svar[] == 2 + @test svar_float[] == 2.0 + end + scoped(svar => 2, svar => 3) do + @test svar[] == 3 + end +end + +emptyf() = nothing + +@testset "conversion" begin + scoped(emptyf, svar_float=>2) + @test_throws MethodError scoped(emptyf, svar_float=>"hello") +end + +import Base.Threads: @spawn +@testset "tasks" begin + @test fetch(@spawn begin + svar[] + end) == 1 + scoped(svar => 2) do + @test fetch(@spawn begin + svar[] + end) == 2 + end +end + +@testset "show" begin + @test sprint(show, svar) == "ScopedValue{$Int}(1)" + @test sprint(show, ScopedValues.current_scope()) == "nothing" + scoped(svar => 2.0) do + @test sprint(show, svar) == "ScopedValue{$Int}(2)" + objid = sprint(show, Base.objectid(svar)) + @test sprint(show, ScopedValues.current_scope()) == "Base.ScopedValues.Scope(ScopedValue{$Int}@$objid => 2)" + end +end + +const depth = ScopedValue(0) +function nth_scoped(f, n) + if n <= 0 + f() + else + scoped(depth => n) do + nth_scoped(f, n-1) + end + end +end + + +@testset "nested scoped" begin + @testset for depth in 1:16 + nth_scoped(depth) do + @test svar_float[] == 1.0 + end + scoped(svar_float=>2.0) do + nth_scoped(depth) do + @test svar_float[] == 2.0 + end + end + nth_scoped(depth) do + scoped(svar_float=>2.0) do + @test svar_float[] == 2.0 + end + end + end + scoped(svar_float=>2.0) do + nth_scoped(15) do + @test svar_float[] == 2.0 + scoped(svar_float => 3.0) do + @test svar_float[] == 3.0 + end + end + end +end + +@testset "macro" begin + @scoped svar=>2 svar_float=>2.0 begin + @test svar[] == 2 + @test svar_float[] == 2.0 + end + # Doesn't do much... + @scoped begin + @test svar[] == 1 + @test svar_float[] == 1.0 + end +end