diff --git a/Project.toml b/Project.toml index c0179083b..d3e7988be 100644 --- a/Project.toml +++ b/Project.toml @@ -7,7 +7,7 @@ version = "0.7.3" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" @@ -15,7 +15,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" AbstractMCMC = "1.0" Bijectors = "0.5.2, 0.6, 0.7" Distributions = "0.22, 0.23" -MacroTools = "0.5.1" +ExprTools = "0.1.1" ZygoteRules = "0.2" julia = "1" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index c7db4873f..bac5fdc48 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -3,7 +3,7 @@ module DynamicPPL using AbstractMCMC: AbstractSampler, AbstractChains, AbstractModel using Distributions using Bijectors -using MacroTools +using ExprTools import AbstractMCMC import ZygoteRules diff --git a/src/compiler.jl b/src/compiler.jl index 2e91afd52..f88f41b50 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -81,25 +81,33 @@ Builds the `model_info` dictionary from the model's expression. """ function build_model_info(input_expr) # Extract model name (:name), arguments (:args), (:kwargs) and definition (:body) - modeldef = MacroTools.splitdef(input_expr) + modeldef = ExprTools.splitdef(input_expr) + modeldef[:whereparams] = get(modeldef,:whereparams, []) + modeldef[:args] = get(modeldef,:args, []) # Function body of the model is empty warn_empty(modeldef[:body]) - # Construct model_info dictionary # Extracting the argument symbols from the model definition arg_syms = map(modeldef[:args]) do arg # @model demo(x) - if (arg isa Symbol) + if arg isa Symbol arg - # @model demo(::Type{T}) where {T} - elseif MacroTools.@capture(arg, ::Type{T_} = Tval_) - T - # @model demo(x::T = 1) - elseif MacroTools.@capture(arg, x_::T_ = val_) - x - # @model demo(x = 1) - elseif MacroTools.@capture(arg, x_ = val_) - x + elseif arg == Expr(:kw, IsEqual(x->true), IsEqual(x->true)) + expr, default = arg.args + # @model demo(::Type{T}=Float64) where {T} + if !isnothing(get_type(expr)) + # TODO support t::Type{T} + !Meta.isexpr(expr, :(::), 2) || throw(ArgumentError("The syntax `t::Type{T}` is currently unsupported please use `::Type{T}` instead.")) + get_type(expr) + # @model demo(x::Int = 1) or demo(x::Vector{Int} = [1,2,3]) + elseif !isnothing(get_symbol(expr)) + get_symbol(expr) + # @model demo(x = 1) + elseif expr isa Symbol + expr + else + throw(ArgumentError("Unsupported default argument $arg to the `@model` macro.")) + end else throw(ArgumentError("Unsupported argument $arg to the `@model` macro.")) end @@ -113,18 +121,21 @@ function build_model_info(input_expr) ) args_nt = Expr(:call, :($namedtuple), nt_type, Expr(:tuple, arg_syms...)) end - args = map(modeldef[:args]) do arg + args = map(get(modeldef, :args, [])) do arg if (arg isa Symbol) arg - elseif MacroTools.@capture(arg, ::Type{T_} = Tval_) - if in(T, modeldef[:whereparams]) - S = :Any - else - ind = findfirst(modeldef[:whereparams]) do x - MacroTools.@capture(x, T1_ <: S_) && T1 == T + elseif !isnothing(get_type(arg.args[1])) + Texpr, Tval = arg.args + T = get_type(Texpr) + S = nothing + for x in modeldef[:whereparams] + if x == T + S = :Any + elseif x == Expr(:<:, T, IsEqual(x->true)) + S = x.args[2] end - ind !== nothing || throw(ArgumentError("Please make sure type parameters are properly used. Every `Type{T}` argument need to have `T` in the a `where` clause")) end + !isnothing(S) || throw(ArgumentError("Please make sure type parameters are properly used. Every `Type{T}` argument need to have `T` in the a `where` clause")) Expr(:kw, :($T::Type{<:$S}), Tval) else arg @@ -134,18 +145,19 @@ function build_model_info(input_expr) default_syms = [] default_vals = [] - foreach(modeldef[:args]) do arg - # @model demo(::Type{T}) where {T} - if MacroTools.@capture(arg, ::Type{T_} = Tval_) - push!(default_syms, T) - push!(default_vals, Tval) - # @model demo(x::T = 1) - elseif MacroTools.@capture(arg, x_::T_ = val_) - push!(default_syms, x) - push!(default_vals, val) - # @model demo(x = 1) - elseif MacroTools.@capture(arg, x_ = val_) - push!(default_syms, x) + foreach(get(modeldef, :args, [])) do arg + if arg == Expr(:kw, IsEqual(x->true), IsEqual(x->true)) + var, val = arg.args + sym = if var isa Symbol + var + elseif var == Expr(:(::), IsEqual(issymbol), IsEqual(x->true)) + var.args[1] + elseif !isnothing(get_type(var)) + get_type(var) + else + error("Unsupported keyword argument given $arg") + end + push!(default_syms, sym) push!(default_vals, val) end end @@ -164,6 +176,8 @@ function build_model_info(input_expr) return model_info end + + """ generate_mainbody(expr, args, warn) diff --git a/src/utils.jl b/src/utils.jl index 9ad2a844d..dae04aedc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,23 +1,64 @@ +import Base: == + +""" + IsEqual(fn) + +Takes a funciton of the form `fn(x)::Bool` +""" +struct IsEqual + fn::Function +end + +(==)(x::IsEqual, y) = x.fn(y) +(==)(y , x::IsEqual) = x == y + +issymbol(x) = x isa Symbol + +""" + get_symbol(expr) + +Return `x` for expressions of form `x::Type` otherwise return nothing +""" +function get_symbol(expr) + if expr == Expr(:(::), IsEqual(issymbol), IsEqual(x->true)) + expr.args[1] + else + nothing + end +end + +""" + get_type(x) + +Return `T` if an expresion is of the form `:(x::Type{T})` or `:(::Type{T})` when `T` is a symbol otherwise returns `nothing` +""" +function get_type(expr) + if expr == Expr(:(::), Expr(:curly, :Type, IsEqual(issymbol))) + expr.args[1].args[2] + elseif expr == Expr(:(::), IsEqual(issymbol) , Expr(:curly, :Type, IsEqual(issymbol))) + expr.args[2].args[2] + else + nothing + end +end + """ getargs_dottilde(x) Return the arguments `L` and `R`, if `x` is an expression of the form `L .~ R` or `(~).(L, R)`, or `nothing` otherwise. """ -getargs_dottilde(x) = nothing -function getargs_dottilde(expr::Expr) +function getargs_dottilde(expr) + any_arg = IsEqual(x->true) # Check if the expression is of the form `L .~ R`. - if Meta.isexpr(expr, :call, 3) && expr.args[1] === :.~ - return expr.args[2], expr.args[3] - end - + if expr == Expr(:call, :.~, any_arg, any_arg) + expr.args[2], expr.args[3] # Check if the expression is of the form `(~).(L, R)`. - if Meta.isexpr(expr, :., 2) && expr.args[1] === :~ && - Meta.isexpr(expr.args[2], :tuple, 2) - return expr.args[2].args[1], expr.args[2].args[2] + elseif expr == Expr(:., :~, Expr(:tuple, any_arg, any_arg)) + expr.args[2].args[1], expr.args[2].args[2] + else + nothing end - - return end """ @@ -26,12 +67,13 @@ end Return the arguments `L` and `R`, if `x` is an expression of the form `L ~ R`, or `nothing` otherwise. """ -getargs_tilde(x) = nothing -function getargs_tilde(expr::Expr) - if Meta.isexpr(expr, :call, 3) && expr.args[1] === :~ - return expr.args[2], expr.args[3] +function getargs_tilde(expr) + any_arg = IsEqual(x->true) + if expr == Expr(:call, :~, any_arg, any_arg) + expr.args[2], expr.args[3] + else + nothing end - return end ############################################ diff --git a/test/Turing/Turing.jl b/test/Turing/Turing.jl index c83f663c6..3c554491e 100644 --- a/test/Turing/Turing.jl +++ b/test/Turing/Turing.jl @@ -11,7 +11,7 @@ module Turing using Requires, Reexport, ForwardDiff using DistributionsAD, Bijectors, StatsFuns, SpecialFunctions using Statistics, LinearAlgebra -using Markdown, Libtask, MacroTools +using Markdown, Libtask @reexport using Distributions, MCMCChains, Libtask using Tracker: Tracker diff --git a/test/Turing/core/Core.jl b/test/Turing/core/Core.jl index 0227a6498..551b67452 100644 --- a/test/Turing/core/Core.jl +++ b/test/Turing/core/Core.jl @@ -1,7 +1,7 @@ module Core using DistributionsAD, Bijectors -using MacroTools, Libtask, ForwardDiff, Random +using Libtask, ForwardDiff, Random using Distributions, LinearAlgebra using ..Utilities, Reexport using Tracker: Tracker diff --git a/test/compiler.jl b/test/compiler.jl index da180baea..2158cce7f 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -1,5 +1,5 @@ -using .Turing, Random, MacroTools, Distributions, Test -using DynamicPPL: DynamicPPL, vsym, vinds, @varname, VarInfo, VarName +using .Turing, Random, ExprTools, Distributions, Test +using DynamicPPL: DynamicPPL, vsym, vinds, @varname, VarInfo, VarName, build_model_info, namedtuple, build_model_info dir = splitdir(splitdir(pathof(DynamicPPL))[1])[1] include(dir*"/test/test_utils/AllUtils.jl") @@ -16,6 +16,93 @@ macro custom(expr) end end +@testset "build_model_info" begin + @testset "single arg" begin + expr = :(model(obs) = begin + obs + end) + + result = build_model_info(expr) + @test result[:args] == [:obs] + @test result[:arg_syms] == [:obs] + @test result[:name] == :model + result + end + + @testset "default arg missing" begin + expr = :(model(x = missing) = begin + x + end) + + result = build_model_info(expr) + + @test result[:args] == [:($(Expr(:kw, :x, :missing)))] + @test result[:arg_syms] == [:x] + args_nt = result[:args_nt] + @test args_nt.args[2] == :(NamedTuple{(:x,), Tuple{Core.Typeof(x)}}) + @test args_nt.args[3] == :((x,)) + defaults_nt = result[:defaults_nt] + @test defaults_nt.args[2] == :(NamedTuple{(:x,), Tuple{Core.Typeof(missing)}}) + @test defaults_nt.args[3] == :((missing,)) + end + + @testset "default arg with type annotation" begin + expr1 = :(model(x::Int = 2) = begin + x + end) + + result1 = build_model_info(expr1) + + @test result1[:args] == [:($(Expr(:kw, :(x::Int), 2)))] + @test result1[:arg_syms] == [:x] + args_nt = result1[:args_nt] + @test args_nt.args[2] == :(NamedTuple{(:x,), Tuple{Core.Typeof(x)}}) + @test args_nt.args[3] == :((x,)) + defaults_nt = result1[:defaults_nt] + @test defaults_nt.args[2] == :(NamedTuple{(:x,), Tuple{Core.Typeof(2)}}) + @test defaults_nt.args[3] == :((2,)) + + expr2 = :(model(x::Vector{Float64} = [1,2,3]) = begin + x + end) + + result2 = build_model_info(expr2) + @test result2[:arg_syms] == [:x] + end + + @testset "default arg type" begin + expr1 = :(model(::Type{T}=Float64) where {T <: Float64} = begin + T + end) + + result1 = build_model_info(expr1) + @test result1[:args] == [:($(Expr(:kw, :(T::Type{<:Float64}), :Float64)))] + @test result1[:arg_syms] == [:T] + + expr2 = :(model(::Type{T}=Float64) where {T} = begin + T + end) + + result2 = build_model_info(expr2) + @test result2[:args] == [:($(Expr(:kw, :(T::Type{<:Any}), :Float64)))] + @test result2[:arg_syms] == [:T] + + #TODO support t::Type{T} + expr3 = :(model(t::Type{T}=Float64) where {T <: Float64} = begin + T + end) + + @test_throws ArgumentError build_model_info(expr3) + + #Invalid expression + expr4 = :(model(::Type{T}=Float64) where {X} = begin + T + end) + @test_throws ArgumentError build_model_info(expr4) + end +end + + @testset "compiler.jl" begin @testset "assume" begin @model test_assume() = begin diff --git a/test/utils.jl b/test/utils.jl index 542eff9d3..6552f782c 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,7 +1,5 @@ using DynamicPPL -using DynamicPPL: getargs_dottilde, getargs_tilde - -using Test +using DynamicPPL: getargs_dottilde, getargs_tilde, get_type, get_symbol @testset "getargs_dottilde" begin # Some things that are not expressions. @@ -32,3 +30,16 @@ end @test getargs_tilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === nothing @test getargs_tilde(:(@~ Normal.(μ, σ))) === nothing end + +@testset "get_type" begin + @test get_type(:(::Type{T})) == :T + @test get_type(:(a::Type{A})) == :A + @test get_type(:(::Type{T < Float64})) === nothing +end + +@testset "get_symbol" begin + @test get_symbol(:(x::Int)) == :x + @test get_symbol(:(a::Type{A})) == :a + @test get_symbol(:(::Type{A})) === nothing + @test get_symbol(:(y::Vector{Int})) == :y +end \ No newline at end of file