From a8890b63eb1d38f9e53f35a63aa3c9eb1667e8c1 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 4 Mar 2025 09:47:02 +0000 Subject: [PATCH 1/6] Implement prefix / unprefix --- src/AbstractPPL.jl | 4 +- src/varname.jl | 151 +++++++++++++++++++++++++++++++++++++++++++++ test/varname.jl | 14 +++++ 3 files changed, 168 insertions(+), 1 deletion(-) diff --git a/src/AbstractPPL.jl b/src/AbstractPPL.jl index 86015a6..d0df706 100644 --- a/src/AbstractPPL.jl +++ b/src/AbstractPPL.jl @@ -14,7 +14,9 @@ export VarName, index_to_dict, dict_to_index, varname_to_string, - string_to_varname + string_to_varname, + prefix, + unprefix # Abstract model functions export AbstractProbabilisticProgram, diff --git a/src/varname.jl b/src/varname.jl index 0d0d62d..62d71f9 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -766,6 +766,8 @@ function vsym(expr::Expr) end end +### Serialisation to JSON / string + # String constants for each index type that we support serialisation / # deserialisation of const _BASE_INTEGER_TYPE = "Base.Integer" @@ -936,3 +938,152 @@ Convert a string representation of a `VarName` back to a `VarName`. The string should have been generated by `varname_to_string`. """ string_to_varname(str::AbstractString) = dict_to_varname(JSON.parse(str)) + +### Prefixing and unprefixing + +""" + _strip_identity(optic) + +Remove an inner layer of the identity lens from a composed optic. +""" +_strip_identity(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer} = o.outer +_strip_identity(o::Base.ComposedFunction) = o +_strip_identity(o::Accessors.PropertyLens) = o +_strip_identity(o::Accessors.IndexLens) = o +_strip_identity(o::typeof(identity)) = o + +""" + _inner(optic) + +Get the innermost (non-identity) layer of an optic. + +```jldoctest; setup=:(using Accessors) +julia> AbstractPPL._inner(Accessors.@o _.a.b.c) +(@o _.a) + +julia> AbstractPPL._inner(Accessors.@o _[1][2][3]) +(@o _[1]) + +julia> AbstractPPL._inner(Accessors.@o _) +identity (generic function with 1 method) +``` +""" +_inner(o::Base.ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.inner +function _inner(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer} + return _strip_identity(o.outer) +end +_inner(o::Accessors.PropertyLens) = o +_inner(o::Accessors.IndexLens) = o +_inner(o::typeof(identity)) = o + +""" + _outer(optic) + +Get the outer layer of an optic. + +```jldoctest; setup=:(using Accessors) +julia> AbstractPPL._outer(Accessors.@o _.a.b.c) +(@o _.b.c) + +julia> AbstractPPL._outer(Accessors.@o _[1][2][3]) +(@o _[2][3]) + +julia> AbstractPPL._outer(Accessors.@o _.a) +identity (generic function with 1 method) + +julia> AbstractPPL._outer(Accessors.@o _[1]) +identity (generic function with 1 method) + +julia> AbstractPPL._outer(Accessors.@o _) +identity (generic function with 1 method) +``` +""" +_outer(o::Base.ComposedFunction{Outer,Inner}) where {Outer,Inner} = _strip_identity(o.outer) +_outer(::Accessors.PropertyLens) = identity +_outer(::Accessors.IndexLens) = identity +_outer(::typeof(identity)) = identity + +""" + optic_to_vn(optic) + +Convert an Accessors optic to a VarName. This is best explained through +examples. + +```jldoctest; setup=:(using Accessors) +julia> AbstractPPL.optic_to_vn(Accessors.@o _.a) +a + +julia> AbstractPPL.optic_to_vn(Accessors.@o _.a.b) +a.b + +julia> AbstractPPL.optic_to_vn(Accessors.@o _.a[1]) +a[1] +``` + +The outermost layer of the optic (technically, what Accessors.jl calls the +'innermost') must be a `PropertyLens`, or else it will fail. This is because a +VarName needs to have a symbol. + +```jldoctest; setup=:(using Accessors) +julia> AbstractPPL.optic_to_vn(Accessors.@o _[1]) +ERROR: ArgumentError: optic_to_vn: could not convert optic `(@o _[1])` to a VarName +[...] +``` +""" +function optic_to_vn(::Accessors.PropertyLens{sym}) where {sym} + return VarName{sym}() +end +function optic_to_vn(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer} + return optic_to_vn(o.outer) +end +function optic_to_vn( + o::Base.ComposedFunction{Outer,Accessors.PropertyLens{sym}} +) where {Outer,sym} + return VarName{sym}(o.outer) +end +function optic_to_vn(@nospecialize(o)) + msg = "optic_to_vn: could not convert optic `$o` to a VarName" + throw(ArgumentError(msg)) +end + +unprefix_optic(o, ::typeof(identity)) = o # Base case +function unprefix_optic(optic, optic_prefix) + # strip one layer of the optic and check for equality + inner = _inner(optic) + inner_prefix = _inner(optic_prefix) + if inner != inner_prefix + msg = "could not remove prefix $(optic_prefix) from optic $(optic)" + throw(ArgumentError(msg)) + end + # recurse + return unprefix_optic(_outer(optic), _outer(optic_prefix)) +end + +function unprefix( + vn::VarName{sym_vn}, prefix::VarName{sym_prefix} +) where {sym_vn,sym_prefix} + if sym_vn != sym_prefix + msg = "could not remove prefix $(prefix) from VarName $(vn)" + throw(ArgumentError(msg)) + end + optic_vn = getoptic(vn) + optic_prefix = getoptic(prefix) + return optic_to_vn(unprefix_optic(optic_vn, optic_prefix)) +end + +function prefix(vn::VarName{sym_vn}, prefix::VarName{sym_prefix}) where {sym_vn,sym_prefix} + optic_vn = getoptic(vn) + optic_prefix = getoptic(prefix) + # Special case `identity` to avoid having ComposedFunctions with identity + if optic_vn == identity + new_inner_optic_vn = PropertyLens{sym_vn}() + else + new_inner_optic_vn = optic_vn ∘ PropertyLens{sym_vn}() + end + if optic_prefix == identity + new_optic_vn = new_inner_optic_vn + else + new_optic_vn = new_inner_optic_vn ∘ optic_prefix + end + return VarName{sym_prefix}(new_optic_vn) +end diff --git a/test/varname.jl b/test/varname.jl index 32ac1c1..1eeb454 100644 --- a/test/varname.jl +++ b/test/varname.jl @@ -233,4 +233,18 @@ end # Serialisation should now work @test string_to_varname(varname_to_string(vn)) == vn end + + @testset "prefix and unprefix" begin + @test prefix(@varname(y), @varname(x)) == @varname(x.y) + @test prefix(@varname(y), @varname(x[1])) == @varname(x[1].y) + @test prefix(@varname(y), @varname(x.a)) == @varname(x.a.y) + @test prefix(@varname(y[1]), @varname(x)) == @varname(x.y[1]) + @test prefix(@varname(y.a), @varname(x)) == @varname(x.y.a) + + @test unprefix(@varname(x.y[1]), @varname(x)) == @varname(y[1]) + @test unprefix(@varname(x[1].y), @varname(x[1])) == @varname(y) + @test unprefix(@varname(x.a.y), @varname(x.a)) == @varname(y) + @test unprefix(@varname(x.y.a), @varname(x)) == @varname(y.a) + @test_throws ArgumentError unprefix(@varname(x.y.a), @varname(n)) + end end From 9d0ff7189087e6b4f7d385f361a10b59a76482aa Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 4 Mar 2025 10:11:41 +0000 Subject: [PATCH 2/6] Document --- docs/src/api.md | 7 +++++++ src/varname.jl | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/docs/src/api.md b/docs/src/api.md index e4ca9e2..e7705da 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -14,6 +14,13 @@ vsym @vsym ``` +## VarName prefixing and unprefixing + +```@docs +prefix +unprefix +``` + ## VarName serialisation ```@docs diff --git a/src/varname.jl b/src/varname.jl index 62d71f9..a2ca93e 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -1059,6 +1059,26 @@ function unprefix_optic(optic, optic_prefix) return unprefix_optic(_outer(optic), _outer(optic_prefix)) end +""" + unprefix(vn::VarName, prefix::VarName) + +Remove a prefix from a VarName. + +```jldoctest +julia> AbstractPPL.unprefix(@varname(y.x), @varname(y)) +x + +julia> AbstractPPL.unprefix(@varname(y.x.a), @varname(y)) +x.a + +julia> AbstractPPL.unprefix(@varname(y[1].x), @varname(y[1])) +x + +julia> AbstractPPL.unprefix(@varname(y), @varname(n)) +ERROR: ArgumentError: could not remove prefix n from VarName y +[...] +``` +""" function unprefix( vn::VarName{sym_vn}, prefix::VarName{sym_prefix} ) where {sym_vn,sym_prefix} @@ -1071,6 +1091,22 @@ function unprefix( return optic_to_vn(unprefix_optic(optic_vn, optic_prefix)) end +""" + prefix(vn::VarName, prefix::VarName) + +Add a prefix to a VarName. + +```jldoctest +julia> AbstractPPL.prefix(@varname(x), @varname(y)) +y.x + +julia> AbstractPPL.prefix(@varname(x.a), @varname(y)) +y.x.a + +julia> AbstractPPL.prefix(@varname(x.a), @varname(y[1])) +y[1].x.a +``` +""" function prefix(vn::VarName{sym_vn}, prefix::VarName{sym_prefix}) where {sym_vn,sym_prefix} optic_vn = getoptic(vn) optic_prefix = getoptic(prefix) From 89b2b425bf88f3e561ee855c840d4f57f5b7e5f9 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 4 Mar 2025 10:23:10 +0000 Subject: [PATCH 3/6] Clean up code a bit --- src/varname.jl | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/varname.jl b/src/varname.jl index a2ca93e..09599b1 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -944,9 +944,15 @@ string_to_varname(str::AbstractString) = dict_to_varname(JSON.parse(str)) """ _strip_identity(optic) -Remove an inner layer of the identity lens from a composed optic. +Remove identity lenses from composed optics. """ -_strip_identity(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer} = o.outer +_strip_identity(::Base.ComposedFunction{typeof(identity),typeof(identity)}) = identity +function _strip_identity(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer} + return _strip_identity(o.outer) +end +function _strip_identity(o::Base.ComposedFunction{typeof(identity),Inner}) where {Inner} + return _strip_identity(o.inner) +end _strip_identity(o::Base.ComposedFunction) = o _strip_identity(o::Accessors.PropertyLens) = o _strip_identity(o::Accessors.IndexLens) = o @@ -969,9 +975,7 @@ identity (generic function with 1 method) ``` """ _inner(o::Base.ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.inner -function _inner(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer} - return _strip_identity(o.outer) -end +_inner(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer} = o.outer _inner(o::Accessors.PropertyLens) = o _inner(o::Accessors.IndexLens) = o _inner(o::typeof(identity)) = o @@ -998,7 +1002,7 @@ julia> AbstractPPL._outer(Accessors.@o _) identity (generic function with 1 method) ``` """ -_outer(o::Base.ComposedFunction{Outer,Inner}) where {Outer,Inner} = _strip_identity(o.outer) +_outer(o::Base.ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.outer _outer(::Accessors.PropertyLens) = identity _outer(::Accessors.IndexLens) = identity _outer(::typeof(identity)) = identity @@ -1049,14 +1053,16 @@ end unprefix_optic(o, ::typeof(identity)) = o # Base case function unprefix_optic(optic, optic_prefix) # strip one layer of the optic and check for equality - inner = _inner(optic) - inner_prefix = _inner(optic_prefix) + inner = _inner(_strip_identity(optic)) + inner_prefix = _inner(_strip_identity(optic_prefix)) if inner != inner_prefix msg = "could not remove prefix $(optic_prefix) from optic $(optic)" throw(ArgumentError(msg)) end # recurse - return unprefix_optic(_outer(optic), _outer(optic_prefix)) + return unprefix_optic( + _outer(_strip_identity(optic)), _outer(_strip_identity(optic_prefix)) + ) end """ From b4a58fafd3e699a7134558b570b31b8ba8164f62 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 4 Mar 2025 10:26:26 +0000 Subject: [PATCH 4/6] Add another test --- test/varname.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/varname.jl b/test/varname.jl index 1eeb454..3fb2733 100644 --- a/test/varname.jl +++ b/test/varname.jl @@ -246,5 +246,6 @@ end @test unprefix(@varname(x.a.y), @varname(x.a)) == @varname(y) @test unprefix(@varname(x.y.a), @varname(x)) == @varname(y.a) @test_throws ArgumentError unprefix(@varname(x.y.a), @varname(n)) + @test_throws ArgumentError unprefix(@varname(x.y.a), @varname(x[1])) end end From 64612e6e962d09a8d3f939b5b74fda2b483b9473 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 4 Mar 2025 11:31:00 +0000 Subject: [PATCH 5/6] Remove unused method --- src/varname.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/varname.jl b/src/varname.jl index 09599b1..38d558c 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -975,7 +975,6 @@ identity (generic function with 1 method) ``` """ _inner(o::Base.ComposedFunction{Outer,Inner}) where {Outer,Inner} = o.inner -_inner(o::Base.ComposedFunction{Outer,typeof(identity)}) where {Outer} = o.outer _inner(o::Accessors.PropertyLens) = o _inner(o::Accessors.IndexLens) = o _inner(o::typeof(identity)) = o From b10bd1b0807e417b15571a66b8357ba39ddfa473 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 4 Mar 2025 11:31:14 +0000 Subject: [PATCH 6/6] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c3723c8..42f84f2 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" keywords = ["probablistic programming"] license = "MIT" desc = "Common interfaces for probabilistic programming" -version = "0.10.1" +version = "0.11.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"