diff --git a/base/iddict.jl b/base/iddict.jl index 23ba65799a395..93e0409cbd5c4 100644 --- a/base/iddict.jl +++ b/base/iddict.jl @@ -85,8 +85,9 @@ function get(d::IdDict{K,V}, @nospecialize(key), @nospecialize(default)) where { val = ccall(:jl_eqtable_get, Any, (Any, Any, Any), d.ht, key, default) val === default ? default : val::V end + function getindex(d::IdDict{K,V}, @nospecialize(key)) where {K, V} - val = get(d, key, secret_table_token) + val = ccall(:jl_eqtable_get, Any, (Any, Any, Any), d.ht, key, secret_table_token) val === secret_table_token && throw(KeyError(key)) return val::V end @@ -134,23 +135,38 @@ length(d::IdDict) = d.count copy(d::IdDict) = typeof(d)(d) -get!(d::IdDict{K,V}, @nospecialize(key), @nospecialize(default)) where {K, V} = (d[key] = get(d, key, default))::V +function get!(d::IdDict{K,V}, @nospecialize(key), @nospecialize(default)) where {K, V} + val = ccall(:jl_eqtable_get, Any, (Any, Any, Any), d.ht, key, secret_table_token) + if val === secret_table_token + val = isa(default, V) ? default : convert(V, default) + setindex!(d, val, key) + return val + else + return val::V + end +end function get(default::Callable, d::IdDict{K,V}, @nospecialize(key)) where {K, V} - val = get(d, key, secret_table_token) + val = ccall(:jl_eqtable_get, Any, (Any, Any, Any), d.ht, key, secret_table_token) if val === secret_table_token - val = default() + return default() + else + return val::V end - return val end function get!(default::Callable, d::IdDict{K,V}, @nospecialize(key)) where {K, V} - val = get(d, key, secret_table_token) + val = ccall(:jl_eqtable_get, Any, (Any, Any, Any), d.ht, key, secret_table_token) if val === secret_table_token val = default() + if !isa(val, V) + val = convert(V, val) + end setindex!(d, val, key) + return val + else + return val::V end - return val end in(@nospecialize(k), v::KeySet{<:Any,<:IdDict}) = get(v.dict, k, secret_table_token) !== secret_table_token diff --git a/test/dict.jl b/test/dict.jl index 176314671f937..de455576b2bc4 100644 --- a/test/dict.jl +++ b/test/dict.jl @@ -554,7 +554,8 @@ end @test delete!(d, "a") === d @test !haskey(d, "a") @test_throws ArgumentError get!(IdDict{Symbol,Any}(), 2, "b") - + @test get!(IdDict{Int,Int}(), 1, 2.0) === 2 + @test get!(()->2.0, IdDict{Int,Int}(), 1) === 2 # sizehint! & rehash! d = IdDict()