Skip to content

Replace MacroTools with ExprTools #121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ version = "0.7.3"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractMCMC = "1.0"
Bijectors = "0.5.2, 0.6, 0.7"
Distributions = "0.22, 0.23"
MacroTools = "0.5.1"
ExprTools = "0.1.1"
ZygoteRules = "0.2"
julia = "1"

Expand Down
2 changes: 1 addition & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module DynamicPPL
using AbstractMCMC: AbstractSampler, AbstractChains, AbstractModel
using Distributions
using Bijectors
using MacroTools
using ExprTools

import AbstractMCMC
import ZygoteRules
Expand Down
78 changes: 46 additions & 32 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,25 +81,33 @@ Builds the `model_info` dictionary from the model's expression.
"""
function build_model_info(input_expr)
# Extract model name (:name), arguments (:args), (:kwargs) and definition (:body)
modeldef = MacroTools.splitdef(input_expr)
modeldef = ExprTools.splitdef(input_expr)
modeldef[:whereparams] = get(modeldef,:whereparams, [])
modeldef[:args] = get(modeldef,:args, [])
# Function body of the model is empty
warn_empty(modeldef[:body])
# Construct model_info dictionary

# Extracting the argument symbols from the model definition
arg_syms = map(modeldef[:args]) do arg
# @model demo(x)
if (arg isa Symbol)
if arg isa Symbol
arg
# @model demo(::Type{T}) where {T}
elseif MacroTools.@capture(arg, ::Type{T_} = Tval_)
T
# @model demo(x::T = 1)
elseif MacroTools.@capture(arg, x_::T_ = val_)
x
# @model demo(x = 1)
elseif MacroTools.@capture(arg, x_ = val_)
x
elseif arg == Expr(:kw, IsEqual(x->true), IsEqual(x->true))
expr, default = arg.args
# @model demo(::Type{T}=Float64) where {T}
if !isnothing(get_type(expr))
# TODO support t::Type{T}
!Meta.isexpr(expr, :(::), 2) || throw(ArgumentError("The syntax `t::Type{T}` is currently unsupported please use `::Type{T}` instead."))
get_type(expr)
# @model demo(x::Int = 1) or demo(x::Vector{Int} = [1,2,3])
elseif !isnothing(get_symbol(expr))
get_symbol(expr)
# @model demo(x = 1)
elseif expr isa Symbol
expr
else
throw(ArgumentError("Unsupported default argument $arg to the `@model` macro."))
end
else
throw(ArgumentError("Unsupported argument $arg to the `@model` macro."))
end
Expand All @@ -113,18 +121,21 @@ function build_model_info(input_expr)
)
args_nt = Expr(:call, :($namedtuple), nt_type, Expr(:tuple, arg_syms...))
end
args = map(modeldef[:args]) do arg
args = map(get(modeldef, :args, [])) do arg
if (arg isa Symbol)
arg
elseif MacroTools.@capture(arg, ::Type{T_} = Tval_)
if in(T, modeldef[:whereparams])
S = :Any
else
ind = findfirst(modeldef[:whereparams]) do x
MacroTools.@capture(x, T1_ <: S_) && T1 == T
elseif !isnothing(get_type(arg.args[1]))
Texpr, Tval = arg.args
T = get_type(Texpr)
S = nothing
for x in modeldef[:whereparams]
if x == T
S = :Any
elseif x == Expr(:<:, T, IsEqual(x->true))
S = x.args[2]
end
ind !== nothing || throw(ArgumentError("Please make sure type parameters are properly used. Every `Type{T}` argument need to have `T` in the a `where` clause"))
end
!isnothing(S) || throw(ArgumentError("Please make sure type parameters are properly used. Every `Type{T}` argument need to have `T` in the a `where` clause"))
Expr(:kw, :($T::Type{<:$S}), Tval)
else
arg
Expand All @@ -134,18 +145,19 @@ function build_model_info(input_expr)

default_syms = []
default_vals = []
foreach(modeldef[:args]) do arg
# @model demo(::Type{T}) where {T}
if MacroTools.@capture(arg, ::Type{T_} = Tval_)
push!(default_syms, T)
push!(default_vals, Tval)
# @model demo(x::T = 1)
elseif MacroTools.@capture(arg, x_::T_ = val_)
push!(default_syms, x)
push!(default_vals, val)
# @model demo(x = 1)
elseif MacroTools.@capture(arg, x_ = val_)
push!(default_syms, x)
foreach(get(modeldef, :args, [])) do arg
if arg == Expr(:kw, IsEqual(x->true), IsEqual(x->true))
var, val = arg.args
sym = if var isa Symbol
var
elseif var == Expr(:(::), IsEqual(issymbol), IsEqual(x->true))
var.args[1]
elseif !isnothing(get_type(var))
get_type(var)
else
error("Unsupported keyword argument given $arg")
end
push!(default_syms, sym)
push!(default_vals, val)
end
end
Expand All @@ -164,6 +176,8 @@ function build_model_info(input_expr)
return model_info
end



"""
generate_mainbody(expr, args, warn)

Expand Down
74 changes: 58 additions & 16 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,64 @@
import Base: ==

"""
IsEqual(fn)

Takes a funciton of the form `fn(x)::Bool`
"""
struct IsEqual
fn::Function
end

(==)(x::IsEqual, y) = x.fn(y)
(==)(y , x::IsEqual) = x == y
Comment on lines +3 to +13
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the motivation for IsEqual? To me it seems it would be much easier (also to read) if we just check fn(y) if needed. In functions such as all or any you could even just provide fn as first argument directly. In calls such as expr == Expr(:(::), IsEqual(issymbol), IsEqual(x->true)) you rely on the fact that the internal implementation of (==)(::Expr, ::Expr) will check == recursively, which is not guaranteed in general. IMO it would be safer and less prone to surprises if one explicitly checks the parts of an expression one is interested in.

Suggested change
"""
IsEqual(fn)
Takes a funciton of the form `fn(x)::Bool`
"""
struct IsEqual
fn::Function
end
(==)(x::IsEqual, y) = x.fn(y)
(==)(y , x::IsEqual) = x == y

Copy link
Author

@onetonfoot onetonfoot May 17, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found it to be more readable since the shape of the Expr your checking is clearly declared.

For example

function get_type(expr)
    if expr == Expr(:(::), Expr(:curly, :Type, IsEqual(issymbol)))
        expr.args[1].args[2]
    elseif expr == Expr(:(::), IsEqual(issymbol) , Expr(:curly, :Type, IsEqual(issymbol)))
        expr.args[2].args[2]
    else
        nothing
    end
end

Vs

function get_type(expr)
    type_expr = if Meta.isexpr(expr, :(::), 2)
        expr.args[1]
    elseif Meta.isexpr(expr, :(::), 1)
        expr.args[2]
    else
        nothing
    end
    isnothing(type_expr) ? nothing : type_expr[2]
end

But yeah it could be refactored into a functions instead if that's the preferred style.

I played with this idea a bit more and the recursion seems to work ok.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a problem with the style per se (although I'm not sure if it will make it easier for people to understand DynamicPPL's source code). The main problem I see is just that the expression matching happens implicitly, by relying on how ==(::Expr, ::Expr) is implemented which we can't control.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it assumes your familiar with the S-Expr form which is probably the case if you're writing the functions but may not be so if your reading the source or just calling said functions.

I'd think ==(::Expr, ::Expr) would remain stable since it's quite a fundamental to the language but you're right in that it's not under our control which could be a problem long term if it changes across Julia versions.

I'll update the pull to remove the use of IsEqual.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it's a good idea if we write a couple of generic helper function s in the style of splitdef. Something like f, (arg1, arg2) = splitcall(expr), t, T = splittyping(expr), etc., in a form that allow tuple unpacking and mark absence with nothing.


issymbol(x) = x isa Symbol
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you remove IsEqual, I'd also get rid of this. If it stays, let's at least use it consistenly across the file instead of isa Symbol.


"""
get_symbol(expr)

Return `x` for expressions of form `x::Type` otherwise return nothing
"""
function get_symbol(expr)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_symbol is a too ambiguous name, IMO. What about get_argument_name?

if expr == Expr(:(::), IsEqual(issymbol), IsEqual(x->true))
expr.args[1]
else
nothing
end
end

"""
get_type(x)

Return `T` if an expresion is of the form `:(x::Type{T})` or `:(::Type{T})` when `T` is a symbol otherwise returns `nothing`
"""
function get_type(expr)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly here: I propose get_argument_type for consistency.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about getarg_type and getarg_symbol? Since is closer to getargs_dottile.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With what I proposed above, you could actually subsume it under splittyping, with a helper splitcurly:

t, T = splittypeing(:(t::Type{T}))
@assert T !== nothing
_Type, Ts = splitcurly(T)
@assert Ts == (:T,)

Although this might get out of hand quickly -- now that I write it, I start to reimplement pattern matching in my head...

if expr == Expr(:(::), Expr(:curly, :Type, IsEqual(issymbol)))
expr.args[1].args[2]
elseif expr == Expr(:(::), IsEqual(issymbol) , Expr(:curly, :Type, IsEqual(issymbol)))
expr.args[2].args[2]
else
nothing
end
end

"""
getargs_dottilde(x)

Return the arguments `L` and `R`, if `x` is an expression of the form `L .~ R` or
`(~).(L, R)`, or `nothing` otherwise.
"""
getargs_dottilde(x) = nothing
function getargs_dottilde(expr::Expr)
function getargs_dottilde(expr)
any_arg = IsEqual(x->true)
# Check if the expression is of the form `L .~ R`.
if Meta.isexpr(expr, :call, 3) && expr.args[1] === :.~
return expr.args[2], expr.args[3]
end

if expr == Expr(:call, :.~, any_arg, any_arg)
expr.args[2], expr.args[3]
# Check if the expression is of the form `(~).(L, R)`.
if Meta.isexpr(expr, :., 2) && expr.args[1] === :~ &&
Meta.isexpr(expr.args[2], :tuple, 2)
return expr.args[2].args[1], expr.args[2].args[2]
elseif expr == Expr(:., :~, Expr(:tuple, any_arg, any_arg))
expr.args[2].args[1], expr.args[2].args[2]
else
nothing
end

return
end

"""
Expand All @@ -26,12 +67,13 @@ end
Return the arguments `L` and `R`, if `x` is an expression of the form `L ~ R`, or `nothing`
otherwise.
"""
getargs_tilde(x) = nothing
function getargs_tilde(expr::Expr)
if Meta.isexpr(expr, :call, 3) && expr.args[1] === :~
return expr.args[2], expr.args[3]
function getargs_tilde(expr)
any_arg = IsEqual(x->true)
if expr == Expr(:call, :~, any_arg, any_arg)
expr.args[2], expr.args[3]
else
nothing
end
return
end

############################################
Expand Down
2 changes: 1 addition & 1 deletion test/Turing/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ module Turing
using Requires, Reexport, ForwardDiff
using DistributionsAD, Bijectors, StatsFuns, SpecialFunctions
using Statistics, LinearAlgebra
using Markdown, Libtask, MacroTools
using Markdown, Libtask
@reexport using Distributions, MCMCChains, Libtask
using Tracker: Tracker

Expand Down
2 changes: 1 addition & 1 deletion test/Turing/core/Core.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module Core

using DistributionsAD, Bijectors
using MacroTools, Libtask, ForwardDiff, Random
using Libtask, ForwardDiff, Random
using Distributions, LinearAlgebra
using ..Utilities, Reexport
using Tracker: Tracker
Expand Down
91 changes: 89 additions & 2 deletions test/compiler.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using .Turing, Random, MacroTools, Distributions, Test
using DynamicPPL: DynamicPPL, vsym, vinds, @varname, VarInfo, VarName
using .Turing, Random, ExprTools, Distributions, Test
using DynamicPPL: DynamicPPL, vsym, vinds, @varname, VarInfo, VarName, build_model_info, namedtuple, build_model_info

dir = splitdir(splitdir(pathof(DynamicPPL))[1])[1]
include(dir*"/test/test_utils/AllUtils.jl")
Expand All @@ -16,6 +16,93 @@ macro custom(expr)
end
end

@testset "build_model_info" begin
@testset "single arg" begin
expr = :(model(obs) = begin
obs
end)

result = build_model_info(expr)
@test result[:args] == [:obs]
@test result[:arg_syms] == [:obs]
@test result[:name] == :model
result
end

@testset "default arg missing" begin
expr = :(model(x = missing) = begin
x
end)

result = build_model_info(expr)

@test result[:args] == [:($(Expr(:kw, :x, :missing)))]
@test result[:arg_syms] == [:x]
args_nt = result[:args_nt]
@test args_nt.args[2] == :(NamedTuple{(:x,), Tuple{Core.Typeof(x)}})
@test args_nt.args[3] == :((x,))
defaults_nt = result[:defaults_nt]
@test defaults_nt.args[2] == :(NamedTuple{(:x,), Tuple{Core.Typeof(missing)}})
@test defaults_nt.args[3] == :((missing,))
end

@testset "default arg with type annotation" begin
expr1 = :(model(x::Int = 2) = begin
x
end)

result1 = build_model_info(expr1)

@test result1[:args] == [:($(Expr(:kw, :(x::Int), 2)))]
@test result1[:arg_syms] == [:x]
args_nt = result1[:args_nt]
@test args_nt.args[2] == :(NamedTuple{(:x,), Tuple{Core.Typeof(x)}})
@test args_nt.args[3] == :((x,))
defaults_nt = result1[:defaults_nt]
@test defaults_nt.args[2] == :(NamedTuple{(:x,), Tuple{Core.Typeof(2)}})
@test defaults_nt.args[3] == :((2,))

expr2 = :(model(x::Vector{Float64} = [1,2,3]) = begin
x
end)

result2 = build_model_info(expr2)
@test result2[:arg_syms] == [:x]
end

@testset "default arg type" begin
expr1 = :(model(::Type{T}=Float64) where {T <: Float64} = begin
T
end)

result1 = build_model_info(expr1)
@test result1[:args] == [:($(Expr(:kw, :(T::Type{<:Float64}), :Float64)))]
@test result1[:arg_syms] == [:T]

expr2 = :(model(::Type{T}=Float64) where {T} = begin
T
end)

result2 = build_model_info(expr2)
@test result2[:args] == [:($(Expr(:kw, :(T::Type{<:Any}), :Float64)))]
@test result2[:arg_syms] == [:T]

#TODO support t::Type{T}
expr3 = :(model(t::Type{T}=Float64) where {T <: Float64} = begin
T
end)

@test_throws ArgumentError build_model_info(expr3)

#Invalid expression
expr4 = :(model(::Type{T}=Float64) where {X} = begin
T
end)
@test_throws ArgumentError build_model_info(expr4)
end
end


@testset "compiler.jl" begin
@testset "assume" begin
@model test_assume() = begin
Expand Down
17 changes: 14 additions & 3 deletions test/utils.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
using DynamicPPL
using DynamicPPL: getargs_dottilde, getargs_tilde

using Test
using DynamicPPL: getargs_dottilde, getargs_tilde, get_type, get_symbol

@testset "getargs_dottilde" begin
# Some things that are not expressions.
Expand Down Expand Up @@ -32,3 +30,16 @@ end
@test getargs_tilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === nothing
@test getargs_tilde(:(@~ Normal.(μ, σ))) === nothing
end

@testset "get_type" begin
@test get_type(:(::Type{T})) == :T
@test get_type(:(a::Type{A})) == :A
@test get_type(:(::Type{T < Float64})) === nothing
end

@testset "get_symbol" begin
@test get_symbol(:(x::Int)) == :x
@test get_symbol(:(a::Type{A})) == :a
@test get_symbol(:(::Type{A})) === nothing
@test get_symbol(:(y::Vector{Int})) == :y
end