diff --git a/Project.toml b/Project.toml index c3723c8..42f84f2 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" keywords = ["probablistic programming"] license = "MIT" desc = "Common interfaces for probabilistic programming" -version = "0.10.1" +version = "0.11.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/docs/src/api.md b/docs/src/api.md index e4ca9e2..e7705da 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -14,6 +14,13 @@ vsym @vsym ``` +## VarName prefixing and unprefixing + +```@docs +prefix +unprefix +``` + ## VarName serialisation ```@docs diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index 86015a6..d0df706 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -14,7 +14,9 @@ export VarName, index_to_dict, dict_to_index, varname_to_string, - string_to_varname + string_to_varname, + prefix, + unprefix # Abstract model functions export AbstractProbabilisticProgram, diff --git a/src/varname.jl b/src/varname.jl index 0d0d62d..38d558c 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -766,6 +766,8 @@ function vsym(expr::Expr) end end +### Serialisation to JSON / string + # String constants for each index type that we support serialisation / # deserialisation of const _BASE_INTEGER_TYPE = "Base.Integer" @@ -936,3 +938,193 @@ Convert a string representation of a `VarName` back to a `VarName`. The string should have been generated by `varname_to_string`. """ string_to_varname(str::AbstractString) = dict_to_varname(JSON.parse(str)) + +### Prefixing and unprefixing + +""" + _strip_identity(optic) + +Remove identity lenses from composed optics. +""" +_strip_identity(::Base.ComposedFunction{typeof(identity),typeof(identity)}) = identity +function _strip_identity(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer} + return _strip_identity(o.outer) +end +function _strip_identity(o::Base.ComposedFunction{typeof(identity),Inner}) where {Inner} + return _strip_identity(o.inner) +end +_strip_identity(o::Base.ComposedFunction) = o +_strip_identity(o::Accessors.PropertyLens) = o +_strip_identity(o::Accessors.IndexLens) = o +_strip_identity(o::typeof(identity)) = o + +""" + _inner(optic) + +Get the innermost (non-identity) layer of an optic. + +```jldoctest; setup=:(using Accessors) +julia> AbstractPPL._inner(Accessors.@o _.a.b.c) +(@o _.a) + +julia> AbstractPPL._inner(Accessors.@o _[1][2][3]) +(@o _[1]) + +julia> AbstractPPL._inner(Accessors.@o _) +identity (generic function with 1 method) +``` +""" +_inner(o::Base.ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.inner +_inner(o::Accessors.PropertyLens) = o +_inner(o::Accessors.IndexLens) = o +_inner(o::typeof(identity)) = o + +""" + _outer(optic) + +Get the outer layer of an optic. + +```jldoctest; setup=:(using Accessors) +julia> AbstractPPL._outer(Accessors.@o _.a.b.c) +(@o _.b.c) + +julia> AbstractPPL._outer(Accessors.@o _[1][2][3]) +(@o _[2][3]) + +julia> AbstractPPL._outer(Accessors.@o _.a) +identity (generic function with 1 method) + +julia> AbstractPPL._outer(Accessors.@o _[1]) +identity (generic function with 1 method) + +julia> AbstractPPL._outer(Accessors.@o _) +identity (generic function with 1 method) +``` +""" +_outer(o::Base.ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.outer +_outer(::Accessors.PropertyLens) = identity +_outer(::Accessors.IndexLens) = identity +_outer(::typeof(identity)) = identity + +""" + optic_to_vn(optic) + +Convert an Accessors optic to a VarName. This is best explained through +examples. + +```jldoctest; setup=:(using Accessors) +julia> AbstractPPL.optic_to_vn(Accessors.@o _.a) +a + +julia> AbstractPPL.optic_to_vn(Accessors.@o _.a.b) +a.b + +julia> AbstractPPL.optic_to_vn(Accessors.@o _.a[1]) +a[1] +``` + +The outermost layer of the optic (technically, what Accessors.jl calls the +'innermost') must be a `PropertyLens`, or else it will fail. This is because a +VarName needs to have a symbol. + +```jldoctest; setup=:(using Accessors) +julia> AbstractPPL.optic_to_vn(Accessors.@o _[1]) +ERROR: ArgumentError: optic_to_vn: could not convert optic `(@o _[1])` to a VarName +[...] +``` +""" +function optic_to_vn(::Accessors.PropertyLens{sym}) where {sym} + return VarName{sym}() +end +function optic_to_vn(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer} + return optic_to_vn(o.outer) +end +function optic_to_vn( + o::Base.ComposedFunction{Outer,Accessors.PropertyLens{sym}} +) where {Outer,sym} + return VarName{sym}(o.outer) +end +function optic_to_vn(@nospecialize(o)) + msg = "optic_to_vn: could not convert optic `$o` to a VarName" + throw(ArgumentError(msg)) +end + +unprefix_optic(o, ::typeof(identity)) = o # Base case +function unprefix_optic(optic, optic_prefix) + # strip one layer of the optic and check for equality + inner = _inner(_strip_identity(optic)) + inner_prefix = _inner(_strip_identity(optic_prefix)) + if inner != inner_prefix + msg = "could not remove prefix $(optic_prefix) from optic $(optic)" + throw(ArgumentError(msg)) + end + # recurse + return unprefix_optic( + _outer(_strip_identity(optic)), _outer(_strip_identity(optic_prefix)) + ) +end + +""" + unprefix(vn::VarName, prefix::VarName) + +Remove a prefix from a VarName. + +```jldoctest +julia> AbstractPPL.unprefix(@varname(y.x), @varname(y)) +x + +julia> AbstractPPL.unprefix(@varname(y.x.a), @varname(y)) +x.a + +julia> AbstractPPL.unprefix(@varname(y[1].x), @varname(y[1])) +x + +julia> AbstractPPL.unprefix(@varname(y), @varname(n)) +ERROR: ArgumentError: could not remove prefix n from VarName y +[...] +``` +""" +function unprefix( + vn::VarName{sym_vn}, prefix::VarName{sym_prefix} +) where {sym_vn,sym_prefix} + if sym_vn != sym_prefix + msg = "could not remove prefix $(prefix) from VarName $(vn)" + throw(ArgumentError(msg)) + end + optic_vn = getoptic(vn) + optic_prefix = getoptic(prefix) + return optic_to_vn(unprefix_optic(optic_vn, optic_prefix)) +end + +""" + prefix(vn::VarName, prefix::VarName) + +Add a prefix to a VarName. + +```jldoctest +julia> AbstractPPL.prefix(@varname(x), @varname(y)) +y.x + +julia> AbstractPPL.prefix(@varname(x.a), @varname(y)) +y.x.a + +julia> AbstractPPL.prefix(@varname(x.a), @varname(y[1])) +y[1].x.a +``` +""" +function prefix(vn::VarName{sym_vn}, prefix::VarName{sym_prefix}) where {sym_vn,sym_prefix} + optic_vn = getoptic(vn) + optic_prefix = getoptic(prefix) + # Special case `identity` to avoid having ComposedFunctions with identity + if optic_vn == identity + new_inner_optic_vn = PropertyLens{sym_vn}() + else + new_inner_optic_vn = optic_vn ∘ PropertyLens{sym_vn}() + end + if optic_prefix == identity + new_optic_vn = new_inner_optic_vn + else + new_optic_vn = new_inner_optic_vn ∘ optic_prefix + end + return VarName{sym_prefix}(new_optic_vn) +end diff --git a/test/varname.jl b/test/varname.jl index 32ac1c1..3fb2733 100644 --- a/test/varname.jl +++ b/test/varname.jl @@ -233,4 +233,19 @@ end # Serialisation should now work @test string_to_varname(varname_to_string(vn)) == vn end + + @testset "prefix and unprefix" begin + @test prefix(@varname(y), @varname(x)) == @varname(x.y) + @test prefix(@varname(y), @varname(x[1])) == @varname(x[1].y) + @test prefix(@varname(y), @varname(x.a)) == @varname(x.a.y) + @test prefix(@varname(y[1]), @varname(x)) == @varname(x.y[1]) + @test prefix(@varname(y.a), @varname(x)) == @varname(x.y.a) + + @test unprefix(@varname(x.y[1]), @varname(x)) == @varname(y[1]) + @test unprefix(@varname(x[1].y), @varname(x[1])) == @varname(y) + @test unprefix(@varname(x.a.y), @varname(x.a)) == @varname(y) + @test unprefix(@varname(x.y.a), @varname(x)) == @varname(y.a) + @test_throws ArgumentError unprefix(@varname(x.y.a), @varname(n)) + @test_throws ArgumentError unprefix(@varname(x.y.a), @varname(x[1])) + end end