diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f0e0e023..6e12beea 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -25,7 +25,7 @@ jobs: fail-fast: false matrix: version: - - '1.6' + - '1.10' - '1' - 'pre' os: diff --git a/Project.toml b/Project.toml index f005cb3d..881a80e0 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,6 @@ authors = ["Chad Scherrer ", "Oliver Schulz max_x && isapprox(x, max_x, atol = 4 * eps(T)) + max_x + else + x + end +end + + +@inline function _eval_dist_trafo_func(f::typeof(_trafo_cdf), d::Distribution{Univariate,Continuous}, src_v::Real) + R_V = float(promote_type(typeof(src_v), _dist_params_numtype(d))) + if insupport(d, src_v) + trg_v = f(d, src_v) + convert(R_V, trg_v) + else + convert(R_V, NaN) + end +end + +@inline function _eval_dist_trafo_func(f::typeof(_trafo_quantile), d::Distribution{Univariate,Continuous}, src_v::Real) + R_V = float(promote_type(typeof(src_v), _dist_params_numtype(d))) + if 0 <= src_v <= 1 + trg_v = f(d, src_v) + convert(R_V, trg_v) + else + convert(R_V, NaN) + end +end + + +std_dist_from(src_d::Distribution{Univariate,Continuous}) = StandardUvUniform() + +function apply_dist_trafo(::StandardUvUniform, src_d::Distribution{Univariate,Continuous}, src_v::Real) + _eval_dist_trafo_func(_trafo_cdf, src_d, src_v) +end + +std_dist_to(trg_d::Distribution{Univariate,Continuous}) = StandardUvUniform() + +function apply_dist_trafo(trg_d::Distribution{Univariate,Continuous}, ::StandardUvUniform, src_v::Real) + TV = float(typeof(src_v)) + # Avoid src_v ≈ 0 and src_v ≈ 1 to avoid infinite variate values for target distributions with infinite support: + mod_src_v = ifelse(src_v ≈ 0, zero(TV) + eps(TV), ifelse(src_v ≈ 1, one(TV) - eps(TV), convert(TV, src_v))) + _eval_dist_trafo_func(_trafo_quantile, trg_d, mod_src_v) +end + + + +function _dist_trafo_rescale_impl(trg_d, src_d, src_v::Real) + R = float(typeof(src_v)) + trg_offs, trg_scale = location(trg_d), scale(trg_d) + src_offs, src_scale = location(src_d), scale(src_d) + rescale_factor = trg_scale / src_scale + (src_v - src_offs) * rescale_factor + trg_offs +end + +@inline apply_dist_trafo(trg_d::Uniform, src_d::Uniform, src_v::Real) = _dist_trafo_rescale_impl(trg_d, src_d, src_v) +@inline apply_dist_trafo(trg_d::StandardUvUniform, src_d::Uniform, src_v::Real) = _dist_trafo_rescale_impl(trg_d, src_d, src_v) +@inline apply_dist_trafo(trg_d::Uniform, src_d::StandardUvUniform, src_v::Real) = _dist_trafo_rescale_impl(trg_d, src_d, src_v) + +# ToDo: Use StandardUvNormal as standard intermediate dist for Normal? Would +# be useful if StandardUvNormal would be a better standard intermediate than +# StandardUvUniform for some other uniform distributions as well. +# +# std_dist_from(src_d::Normal) = StandardUvNormal() +# std_dist_to(trg_d::Normal) = StandardUvNormal() + +@inline apply_dist_trafo(trg_d::Normal, src_d::Normal, src_v::Real) = _dist_trafo_rescale_impl(trg_d, src_d, src_v) +@inline apply_dist_trafo(trg_d::StandardUvNormal, src_d::Normal, src_v::Real) = _dist_trafo_rescale_impl(trg_d, src_d, src_v) +@inline apply_dist_trafo(trg_d::Normal, src_d::StandardUvNormal, src_v::Real) = _dist_trafo_rescale_impl(trg_d, src_d, src_v) + + +# ToDo: Optimized implementation for Distributions.Truncated <-> StandardUvUniform + + +@inline apply_dist_trafo(trg_d::StandardUvUniform, src_d::StandardUvUniform, src_v::Real) = src_v + +@inline apply_dist_trafo(trg_d::StandardUvNormal, src_d::StandardUvNormal, src_v::Real) = src_v + +@inline function apply_dist_trafo(trg_d::StandardUvUniform, src_d::StandardUvNormal, src_v::Real) + apply_dist_trafo(StandardUvUniform(), Normal(), src_v) +end + +@inline function apply_dist_trafo(trg_d::StandardUvNormal, src_d::StandardUvUniform, src_v::Real) + apply_dist_trafo(Normal(), StandardUvUniform(), src_v) +end + + +@inline function apply_dist_trafo(trg_d::StandardMvUniform, src_d::StandardMvNormal, src_v::AbstractVector{<:Real}) + @_adignore @argcheck eff_totalndof(trg_d) == eff_totalndof(src_d) + _product_dist_trafo_impl(StandardUvUniform(), StandardUvNormal(), src_v) +end + +@inline function apply_dist_trafo(trg_d::StandardMvNormal, src_d::StandardMvUniform, src_v::AbstractVector{<:Real}) + @_adignore @argcheck eff_totalndof(trg_d) == eff_totalndof(src_d) + _product_dist_trafo_impl(StandardUvNormal(), StandardUvUniform(), src_v) +end + + +std_dist_from(src_d::MvNormal) = StandardMvNormal(length(src_d)) + +_cholesky_L(A) = cholesky(A).L +_cholesky_L(A::Diagonal{<:Real}) = Diagonal(sqrt.(diag(A))) +_cholesky_L(A::PDiagMat{<:Real}) = Diagonal(sqrt.(A.diag)) +_cholesky_L(A::ScalMat{<:Real}) = Diagonal(Fill(sqrt(A.value), A.dim)) + +function apply_dist_trafo(trg_d::StandardMvNormal, src_d::MvNormal, src_v::AbstractVector{<:Real}) + @argcheck length(trg_d) == length(src_d) + _cholesky_L(src_d.Σ) \ (src_v - src_d.μ) +end + +std_dist_to(trg_d::MvNormal) = StandardMvNormal(length(trg_d)) + +function apply_dist_trafo(trg_d::MvNormal, src_d::StandardMvNormal, src_v::AbstractVector{<:Real}) + @argcheck length(trg_d) == length(src_d) + _cholesky_L(trg_d.Σ) * src_v + trg_d.μ +end + + +eff_totalndof(d::Dirichlet) = length(d) - 1 +eff_totalndof(d::DistributionsAD.TuringDirichlet) = length(d) - 1 + +std_dist_to(trg_d::Dirichlet) = StandardMvUniform(eff_totalndof(trg_d)) +std_dist_to(trg_d::DistributionsAD.TuringDirichlet) = StandardMvUniform(eff_totalndof(trg_d)) + +std_dist_from(trg_d::Dirichlet) = StandardMvUniform(eff_totalndof(trg_d)) +std_dist_from(trg_d::DistributionsAD.TuringDirichlet) = StandardMvUniform(eff_totalndof(trg_d)) + + +function apply_dist_trafo(trg_d::Dirichlet, src_d::StandardMvUniform, src_v::AbstractVector{<:Real}) + apply_dist_trafo(DistributionsAD.TuringDirichlet(trg_d.alpha), src_d, src_v) +end + +function apply_dist_trafo(trg_d::StandardMvUniform, src_d::Dirichlet, src_v::AbstractVector{<:Real}) + apply_dist_trafo(trg_d, DistributionsAD.TuringDirichlet(src_d.alpha), src_v) +end + +function _dirichlet_beta_trafo(α::Real, β::Real, src_v::Real) + R = float(promote_type(typeof(α), typeof(β), typeof(src_v))) + convert(R, apply_dist_trafo(Beta(α, β), StandardUvUniform(), src_v))::R +end + +_a_times_one_minus_b(a::Real, b::Real) = a * (1 - b) + +function apply_dist_trafo(trg_d::DistributionsAD.TuringDirichlet, src_d::StandardMvUniform, src_v::AbstractVector{<:Real}) + # See M. J. Betancourt, "Cruising The Simplex: Hamiltonian Monte Carlo and the Dirichlet Distribution", + # https://arxiv.org/abs/1010.3436 + + @_adignore @argcheck length(trg_d) == length(src_d) + 1 + αs = _dropfront(_rev_cumsum(trg_d.alpha)) + βs = _dropback(trg_d.alpha) + beta_v = fwddiff(_dirichlet_beta_trafo).(αs, βs, src_v) + beta_v_cp = _exp_cumsum_log(_pushfront(beta_v, 1)) + beta_v_ext = _pushback(beta_v, 0) + fwddiff(_a_times_one_minus_b).(beta_v_cp, beta_v_ext) +end + +function _inv_dirichlet_beta_trafo(α::Real, β::Real, beta_v::Real) + R = float(promote_type(typeof(α), typeof(β), typeof(beta_v))) + convert(R, apply_dist_trafo(StandardUvUniform(), Beta(α, β), beta_v))::R +end + +# ToDo: Find efficient pullback for this: +function _dirichlet_variate_to_beta_v(src_v::AbstractVector{<:Real}) + idxs = eachindex(src_v) + beta_v = similar(src_v, length(idxs) - 1) + @assert firstindex(beta_v) == firstindex(src_v) + @assert lastindex(beta_v) == lastindex(src_v) - 1 + T = eltype(src_v) + sum_log_beta_v::T = 0 + @inbounds for i in eachindex(beta_v) + beta_v[i] = 1 - src_v[i] / exp(sum_log_beta_v) + sum_log_beta_v += log(beta_v[i]) + end + return beta_v +end + +# ToDo: Make Zygote-compatible: +function apply_dist_trafo(trg_d::StandardMvUniform, src_d::DistributionsAD.TuringDirichlet, src_v::AbstractVector{<:Real}) + @_adignore @argcheck length(trg_d) == length(src_d) - 1 + αs = _dropfront(_rev_cumsum(src_d.alpha)) + βs = _dropback(src_d.alpha) + beta_v = _dirichlet_variate_to_beta_v(src_v) + fwddiff(_inv_dirichlet_beta_trafo).(αs, βs, beta_v) +end + + +function _product_dist_trafo_impl(trg_ds, src_ds, src_v::AbstractVector{<:Real}) + fwddiff(apply_dist_trafo).(trg_ds, src_ds, src_v) +end + +function apply_dist_trafo(trg_d::Distributions.Product, src_d::Distributions.Product, src_v::AbstractVector{<:Real}) + @_adignore @argcheck eff_totalndof(trg_d) == eff_totalndof(src_d) + _product_dist_trafo_impl(trg_d.v, src_d.v, src_v) +end + +function apply_dist_trafo(trg_d::StandardMvUniform, src_d::Distributions.Product, src_v::AbstractVector{<:Real}) + @_adignore @argcheck eff_totalndof(trg_d) == eff_totalndof(src_d) + _product_dist_trafo_impl(StandardUvUniform(), src_d.v, src_v) +end + +function apply_dist_trafo(trg_d::StandardMvNormal, src_d::Distributions.Product, src_v::AbstractVector{<:Real}) + @_adignore @argcheck eff_totalndof(trg_d) == eff_totalndof(src_d) + _product_dist_trafo_impl(StandardUvNormal(), src_d.v, src_v) +end + +function apply_dist_trafo(trg_d::Distributions.Product, src_d::StandardMvUniform, src_v::AbstractVector{<:Real}) + @_adignore @argcheck eff_totalndof(trg_d) == eff_totalndof(src_d) + _product_dist_trafo_impl(trg_d.v, StandardUvUniform(), src_v) +end + +function apply_dist_trafo(trg_d::Distributions.Product, src_d::StandardMvNormal, src_v::AbstractVector{<:Real}) + @_adignore @argcheck eff_totalndof(trg_d) == eff_totalndof(src_d) + _product_dist_trafo_impl(trg_d.v, StandardUvNormal(), src_v) +end + + +_flat_ntd_orig_elshape(d::Distribution) = ArrayShape{Real}(totalndof(varshape(d))) + +function _flat_ntd_orig_accessors(d::NamedTupleDist{names,DT,AT,VT}) where {names,DT,AT,VT} + shapes = map(_flat_ntd_orig_elshape, values(d)) + vs = NamedTupleShape(VT, NamedTuple{names}(shapes)) + values(vs) +end + +_flat_ntd_eff_elshape(d::Distribution) = ArrayShape{Real}(eff_totalndof(d)) + +function _flat_ntd_eff_accessors(d::NamedTupleDist{names,DT,AT,VT}) where {names,DT,AT,VT} + shapes = map(_flat_ntd_eff_elshape, values(d)) + vs = NamedTupleShape(VT, NamedTuple{names}(shapes)) + values(vs) +end + +function _flat_ntdistelem_to_stdmv(trg_d::StdMvDist, sd::Distribution, src_v_unshaped::AbstractVector{<:Real}, src_acc::ValueAccessor) + td = view(trg_d, Base.OneTo(eff_totalndof(sd))) + sv = src_acc(src_v_unshaped) + apply_dist_trafo(td, unshaped(sd), sv) +end + +function _flat_ntdistelem_to_stdmv(trg_d::StdMvDist, sd::ConstValueDist, src_v_unshaped::AbstractVector{<:Real}, src_acc::ValueAccessor) + Bool[] +end + +function apply_dist_trafo(trg_d::StdMvDist, src_d::ValueShapes.UnshapedNTD, src_v::AbstractVector{<:Real}) + @argcheck length(src_d) == length(eachindex(src_v)) + src_accessors = _flat_ntd_orig_accessors(src_d.shaped) + rs = map((src_acc, sd) -> _flat_ntdistelem_to_stdmv(trg_d, sd, src_v, src_acc), src_accessors, values(src_d.shaped)) + vcat(rs...) +end + +apply_dist_trafo(trg_d::StdMvDist, src_d::ValueShapes.UnshapedNTD, src_v) = throw(ArgumentError("Invalid variate type $(nameof(typeof(src_v)))) for NamedTupleDist")) + +function apply_dist_trafo(trg_d::StdMvDist, src_d::NamedTupleDist, src_v::Union{NamedTuple,ShapedAsNT}) + src_v_unshaped = unshaped(src_v, varshape(src_d)) + apply_dist_trafo(trg_d, unshaped(src_d), src_v_unshaped) +end + +apply_dist_trafo(trg_d::StdMvDist, src_d::NamedTupleDist, src_v) = throw(ArgumentError("Invalid variate type $(nameof(typeof(src_v))) for NamedTupleDist")) + + +function _stdmv_to_flat_ntdistelem(td::Distribution, src_d::StdMvDist, src_v::AbstractVector{<:Real}, src_acc::ValueAccessor) + sd = view(src_d, ValueShapes.view_idxs(Base.OneTo(length(src_d)), src_acc)) + sv = src_acc(src_v) + apply_dist_trafo(unshaped(td), sd, sv) +end + +function _stdmv_to_flat_ntdistelem(td::ConstValueDist, src_d::StdMvDist, src_v::AbstractVector{<:Real}, src_acc::ValueAccessor) + Bool[] +end + +function apply_dist_trafo(trg_d::ValueShapes.UnshapedNTD, src_d::StdMvDist, src_v::AbstractVector{<:Real}) + @argcheck length(src_d) == length(eachindex(src_v)) + src_accessors = _flat_ntd_eff_accessors(trg_d.shaped) + rs = map((acc, td) -> _stdmv_to_flat_ntdistelem(td, src_d, src_v, acc), src_accessors, values(trg_d.shaped)) + vcat(rs...) +end + +function apply_dist_trafo(trg_d::NamedTupleDist, src_d::StdMvDist, src_v::AbstractVector{<:Real}) + unshaped_result = apply_dist_trafo(unshaped(trg_d), src_d, src_v) + varshape(trg_d)(unshaped_result) +end + +@static if isdefined(Distributions, :ReshapedDistribution) + const AnyReshapedDist = Union{Distributions.ReshapedDistribution,ValueShapes.ReshapedDist} +else + const AnyReshapedDist = Union{Distributions.MatrixReshaped,ValueShapes.ReshapedDist} +end + +eff_totalndof(d::AnyReshapedDist) = eff_totalndof(unshaped(d)) +std_dist_from(src_d::AnyReshapedDist) = std_dist_from(unshaped(src_d)) +std_dist_to(trg_d::AnyReshapedDist) = std_dist_to(unshaped(trg_d)) + +function apply_dist_trafo(trg_d::Distribution{Multivariate}, src_d::AnyReshapedDist, src_v::Any) + src_vs = varshape(src_d) + @argcheck eff_totalndof(trg_d) == eff_totalndof(src_d) + apply_dist_trafo(trg_d, unshaped(src_d), unshaped(src_v, src_vs)) +end + +function apply_dist_trafo(trg_d::AnyReshapedDist, src_d::Distribution{Multivariate}, src_v::AbstractVector{<:Real}) + trg_vs = varshape(trg_d) + @argcheck eff_totalndof(trg_d) == eff_totalndof(src_d) + r = apply_dist_trafo(unshaped(trg_d), src_d, src_v) + trg_vs(r) +end + +function apply_dist_trafo(trg_d::AnyReshapedDist, src_d::AnyReshapedDist, src_v::AbstractVector{<:Real}) + trg_vs = varshape(trg_d) + src_vs = varshape(src_d) + @argcheck totalndof(trg_vs) == totalndof(src_vs) + r = apply_dist_trafo(unshaped(trg_d), unshaped(src_d), unshaped(src_v, src_vs)) + v = trg_vs(r) +end + + +function apply_dist_trafo(trg_d::StdMvDist, src_d::UnshapedHDist, src_v::AbstractVector{<:Real}) + src_v_primary, src_v_secondary = _hd_split(src_d, src_v) + trg_d_primary = typeof(trg_d)(length(eachindex(src_v_primary))) + trg_d_secondary = typeof(trg_d)(length(eachindex(src_v_secondary))) + trg_v_primary = apply_dist_trafo(trg_d_primary, _hd_pridist(src_d), src_v_primary) + trg_v_secondary = apply_dist_trafo(trg_d_secondary, _hd_secdist(src_d, src_v_primary), src_v_secondary) + vcat(trg_v_primary, trg_v_secondary) +end + +function apply_dist_trafo(trg_d::StdMvDist, src_d::HierarchicalDistribution, src_v::Any) + src_v_unshaped = unshaped(src_v, varshape(src_d)) + apply_dist_trafo(trg_d, unshaped(src_d), src_v_unshaped) +end + +function apply_dist_trafo(trg_d::UnshapedHDist, src_d::StdMvDist, src_v::AbstractVector{<:Real}) + src_v_primary, src_v_secondary = _hd_split_efftotalndof(trg_d, src_v) + src_d_primary = typeof(src_d)(length(eachindex(src_v_primary))) + src_d_secondary = typeof(src_d)(length(eachindex(src_v_secondary))) + trg_v_primary = apply_dist_trafo(_hd_pridist(trg_d), src_d_primary, src_v_primary) + trg_v_secondary = apply_dist_trafo(_hd_secdist(trg_d, trg_v_primary), src_d_secondary, src_v_secondary) + vcat(trg_v_primary, trg_v_secondary) +end + +function apply_dist_trafo(trg_d::HierarchicalDistribution, src_d::StdMvDist, src_v::AbstractVector{<:Real}) + unshaped_result = apply_dist_trafo(unshaped(trg_d), src_d, src_v) + varshape(trg_d)(unshaped_result) +end diff --git a/ext/distributions/autodiff_utils.jl b/ext/distributions/autodiff_utils.jl new file mode 100644 index 00000000..6f2e1516 --- /dev/null +++ b/ext/distributions/autodiff_utils.jl @@ -0,0 +1,75 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +@inline _adignore_call(f) = f() +@inline _adignore_call_pullback(@nospecialize ΔΩ) = (NoTangent(), NoTangent()) +ChainRulesCore.rrule(::typeof(_adignore_call), f) = _adignore_call(f), _adignore_call_pullback + +macro _adignore(expr) + :(_adignore_call(() -> $(esc(expr)))) +end + + +function _pushfront(v::AbstractVector, x) + T = promote_type(eltype(v), typeof(x)) + r = similar(v, T, length(eachindex(v)) + 1) + r[firstindex(r)] = x + r[firstindex(r)+1:lastindex(r)] = v + r +end + +function ChainRulesCore.rrule(::typeof(_pushfront), v::AbstractVector, x) + result = _pushfront(v, x) + function _pushfront_pullback(thunked_ΔΩ) + ΔΩ = unthunk(thunked_ΔΩ) + (NoTangent(), ΔΩ[firstindex(ΔΩ)+1:lastindex(ΔΩ)], ΔΩ[firstindex(ΔΩ)]) + end + return result, _pushfront_pullback +end + + +function _pushback(v::AbstractVector, x) + T = promote_type(eltype(v), typeof(x)) + r = similar(v, T, length(eachindex(v)) + 1) + r[lastindex(r)] = x + r[firstindex(r):lastindex(r)-1] = v + r +end + +function ChainRulesCore.rrule(::typeof(_pushback), v::AbstractVector, x) + result = _pushback(v, x) + function _pushback_pullback(thunked_ΔΩ) + ΔΩ = unthunk(thunked_ΔΩ) + (NoTangent(), ΔΩ[firstindex(ΔΩ):lastindex(ΔΩ)-1], ΔΩ[lastindex(ΔΩ)]) + end + return result, _pushback_pullback +end + + +_dropfront(v::AbstractVector) = v[firstindex(v)+1:lastindex(v)] + +_dropback(v::AbstractVector) = v[firstindex(v):lastindex(v)-1] + + +_rev_cumsum(xs::AbstractVector) = reverse(cumsum(reverse(xs))) + +function ChainRulesCore.rrule(::typeof(_rev_cumsum), xs::AbstractVector) + result = _rev_cumsum(xs) + function _rev_cumsum_pullback(ΔΩ) + ∂xs = @thunk cumsum(unthunk(ΔΩ)) + (NoTangent(), ∂xs) + end + return result, _rev_cumsum_pullback +end + + +# Equivalent to `cumprod(xs)``: +_exp_cumsum_log(xs::AbstractVector) = exp.(cumsum(log.(xs))) + +function ChainRulesCore.rrule(::typeof(_exp_cumsum_log), xs::AbstractVector) + result = _exp_cumsum_log(xs) + function _exp_cumsum_log_pullback(ΔΩ) + ∂xs = inv.(xs) .* _rev_cumsum(exp.(cumsum(log.(xs))) .* unthunk(ΔΩ)) + (NoTangent(), ∂xs) + end + return result, _exp_cumsum_log_pullback +end diff --git a/ext/distributions/dirac.jl b/ext/distributions/dirac.jl new file mode 100644 index 00000000..9967f46f --- /dev/null +++ b/ext/distributions/dirac.jl @@ -0,0 +1,14 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +MeasureBase.AbstractMeasure(obj::Distributions.Dirac) = MeasureBase.Dirac(obj.value) + +function AsMeasure{D}(::D) where {D<:Distributions.Dirac} + throw(ArgumentError("Don't wrap Distributions.Dirac into MeasureBase.AsMeasure, use asmeasure to convert instead.")) +end + + +Distributions.Distribution(m::MeasureBase.Dirac{<:Real}) = Distribtions.Dirac(m.x) + +function Distributions.Distribution(@nospecialize(m::MeasureBase.Dirac{T})) where T + throw(ArgumentError("Can only convert MeasureBase.Dirac{<:Real} to Distributions.Dirac, but not MeasureBase.Dirac{<:$(nameof(T))}")) +end diff --git a/ext/distributions/dirichlet.jl b/ext/distributions/dirichlet.jl new file mode 100644 index 00000000..226532c0 --- /dev/null +++ b/ext/distributions/dirichlet.jl @@ -0,0 +1,33 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +const DirichletMeasure = AsMeasure{<:Dirichlet} + +MeasureBase.getdof(m::DirichletMeasure) = length(m.obj) - 1 + +MeasureBase.transport_origin(m::DirichletMeasure) = StdUniform()^getdof(m) + + + +function _dirichlet_beta_trafo(α::Real, β::Real, x::Real) + R = float(promote_type(typeof(α), typeof(β), typeof(x))) + convert(R, transport_def(Beta(α, β), StdUniform(), x))::R +end + +_a_times_one_minus_b(a::Real, b::Real) = a * (1 - b) + +function MeasureBase.from_origin(ν::Dirichlet, x) + # See M. J. Betancourt, "Cruising The Simplex: Hamiltonian Monte Carlo and the Dirichlet Distribution", + # https://arxiv.org/abs/1010.3436 + + # Sanity check (TODO - remove?): + @_adignore @argcheck length(ν) == length(x) + 1 + + αs = _dropfront(_rev_cumsum(ν.alpha)) + βs = _dropback(ν.alpha) + beta_v = fwddiff(_dirichlet_beta_trafo).(αs, βs, x) + beta_v_cp = _exp_cumsum_log(_pushfront(beta_v, 1)) + beta_v_ext = _pushback(beta_v, 0) + fwddiff(_a_times_one_minus_b).(beta_v_cp, beta_v_ext) +end + +# ToDo: MeasureBase.to_origin(ν::Dirichlet, y) diff --git a/ext/distributions/dist_vartransform.jl b/ext/distributions/dist_vartransform.jl new file mode 100644 index 00000000..569fefa8 --- /dev/null +++ b/ext/distributions/dist_vartransform.jl @@ -0,0 +1,16 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +const _AnyStdUniform = Union{StandardUniform, Uniform} +const _AnyStdNormal = Union{StandardNormal, Normal} + +const _AnyStdDistribution = Union{_AnyStdUniform, _AnyStdNormal} + +_std_measure(::Type{<:_AnyStdUniform}) = StandardUniform +_std_measure(::Type{<:_AnyStdNormal}) = StandardNormal + +_std_measure(::Type{M}, ::StaticInt{1}) where {M<:_AnyStdDistribution} = M() +_std_measure(::Type{M}, dof::Integer) where {M<:_AnyStdDistribution} = M(dof) +_std_measure_for(::Type{M}, μ::Any) where {M<:_AnyStdDistribution} = _std_measure(_std_measure(M), getdof(μ)) + +MeasureBase.transport_to(::Type{NU}, μ) where {NU<:_AnyStdDistribution} = transport_to(_std_measure_for(NU, μ), μ) +MeasureBase.transport_to(ν, ::Type{MU}) where {MU<:_AnyStdDistribution} = transport_to(ν, _std_measure_for(MU, ν)) diff --git a/ext/distributions/distribution_measure.jl b/ext/distributions/distribution_measure.jl new file mode 100644 index 00000000..e15f66c8 --- /dev/null +++ b/ext/distributions/distribution_measure.jl @@ -0,0 +1,71 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + + +const DistributionMeasure{F<:VariateForm,S<:ValueSupport,D<:Distribution{F,S}} = AsMeasure{D} + +@inline MeasureBase.AbstractMeasure(obj::Distribution) = AsMeasure{typeof(obj)}(obj) +@inline Base.convert(::Type{AbstractMeasure}, obj::Distribution) = AbstractMeasure(obj) + +@inline Distributions.Distribution(m::DistributionMeasure) = m.obj +@inline Distributions.Distribution{F}(m::DistributionMeasure{F}) where {F<:VariateForm} = Distribution(m) +@inline Distributions.Distribution{F,S}(m::DistributionMeasure{F,S}) where {F<:VariateForm,S<:ValueSupport} = Distribution(m) + +@inline Base.convert(::Type{Distribution}, m::DistributionMeasure) = Distribution(m) +@inline Base.convert(::Type{Distribution{F}}, m::DistributionMeasure{F}) where {F<:VariateForm} = Distribution(m) +@inline Base.convert(::Type{Distribution{F,S}}, m::DistributionMeasure{F,S}) where {F<:VariateForm,S<:ValueSupport} = Distribution(m) + + +Base.rand(rng::AbstractRNG, ::Type{T}, m::DistributionMeasure) where {T<:Real} = convert_realtype(T, rand(m.obj)) + +function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution{<:ArrayLikeVariate{0}}, sz::Dims) where {T<:Real} + convert_realtype(T, reshape(rand(d, prod(sz)), sz...)) +end + +function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution{<:ArrayLikeVariate{1}}, sz::Dims) where {T<:Real} + convert_realtype(T, reshape(rand(rng, d, prod(sz)), size(d)..., sz...)) +end + +function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::ReshapedDistribution{N,<:Any,<:Distribution{<:ArrayLikeVariate{1}}}, sz::Dims) where {T<:Real,N} + convert_realtype(T, reshape(rand(rng, d.dist, prod(sz)), d.dims..., sz...)) +end + +function _flat_powrand(rng::AbstractRNG, ::Type{T}, d::Distribution, sz::Dims) where {T<:Real} + flatview(ArrayOfSimilarArrays(convert_realtype(T, rand(rng, d, sz)))) +end + +function Base.rand(rng::AbstractRNG, ::Type{T}, m::PowerMeasure{<:DistributionMeasure{<:ArrayLikeVariate{0}}, NTuple{N,Base.OneTo{Int}}}) where {T<:Real,N} + _flat_powrand(rng, T, m.parent.obj, map(length, m.axes)) +end + +function Base.rand(rng::AbstractRNG, ::Type{T}, m::PowerMeasure{<:DistributionMeasure{<:ArrayLikeVariate{M}}, NTuple{N,Base.OneTo{Int}}}) where {T<:Real,M,N} + flat_data = _flat_powrand(rng, T, m.parent.obj, map(length, m.axes)) + ArrayOfSimilarArrays{T,M,N}(flat_data) +end + + +@inline DensityInterface.densityof(m::DistributionMeasure) = densityof(m.obj) +@inline DensityInterface.logdensityof(m::DistributionMeasure) = logdensityof(m.obj) + +@inline MeasureBase.logdensity_def(m::DistributionMeasure, x) = DensityInterface.logdensityof(m.obj, x) +@inline MeasureBase.unsafe_logdensityof(m::DistributionMeasure, x) = DensityInterface.logdensityof(m.obj, x) +@inline MeasureBase.insupport(m::DistributionMeasure, x) = Distributions.insupport(m.obj, x) + +@inline MeasureBase.rootmeasure(m::DistributionMeasure{<:ArrayLikeVariate{0},<:Continuous}) = Lebesgue() +@inline MeasureBase.rootmeasure(m::DistributionMeasure{<:ArrayLikeVariate,<:Continuous}) = Lebesgue()^size(m.obj) +@inline MeasureBase.rootmeasure(m::DistributionMeasure{<:ArrayLikeVariate{0},<:Discrete}) = Counting() +@inline MeasureBase.rootmeasure(m::DistributionMeasure{<:ArrayLikeVariate,<:Discrete}) = Counting()^size(m.obj) + +@inline MeasureBase.basemeasure(m::DistributionMeasure) = rootmeasure(m) + +@inline MeasureBase.mspace_elsize(m::DistributionMeasure{<:ArrayLikeVariate}) = size(m.obj) + +@inline MeasureBase.getdof(m::DistributionMeasure{<:ArrayLikeVariate{0}}) = 1 + +@inline MeasureBase.paramnames(m::DistributionMeasure) = propertynames(m.obj) +@inline MeasureBase.params(m::DistributionMeasure) = NamedTuple{propertynames(m.obj)}(Distributions.params(m.obj)) + +# @inline MeasureBase.testvalue(m::DistributionMeasure) = testvalue(basemeasure(d)) + + +@inline MeasureBase.basemeasure(d::Distributions.Poisson) = Counting(MeasureBase.BoundedInts(static(0), static(Inf))) +@inline MeasureBase.basemeasure(d::Distributions.Product{<:Any,<:Distributions.Poisson}) = Counting(MeasureBase.BoundedInts(static(0), static(Inf)))^size(d) diff --git a/ext/distributions/distributions.jl b/ext/distributions/distributions.jl new file mode 100644 index 00000000..39a9620a --- /dev/null +++ b/ext/distributions/distributions.jl @@ -0,0 +1,66 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +using LinearAlgebra: Diagonal, dot, cholesky + +import Random +using Random: AbstractRNG, rand! + +import DensityInterface +using DensityInterface: logdensityof + +import MeasureBase +using MeasureBase: AbstractMeasure, AsMeasure +using MeasureBase: Lebesgue, Counting, ℝ +using MeasureBase: StdMeasure, StdUniform, StdExponential, StdLogistic +using MeasureBase: PowerMeasure, WeightedMeasure +using MeasureBase: basemeasure, testvalue +using MeasureBase: getdof, checked_arg +using MeasureBase: transport_to, transport_def, transport_origin, from_origin, to_origin +using MeasureBase: NoTransformOrigin, NoTransport + +import Distributions +using Distributions: Distribution, VariateForm, ValueSupport, ContinuousDistribution +using Distributions: Univariate, Multivariate, ArrayLikeVariate, Continuous, Discrete +using Distributions: Uniform, Exponential, Logistic, Normal +using Distributions: MvNormal, Beta, Dirichlet +using Distributions: ReshapedDistribution + +import Statistics +import StatsBase +import StatsFuns +import PDMats + +using IrrationalConstants: log2π, invsqrt2π + +using Static: True, False, StaticInt, static +using FillArrays: Fill, Ones, Zeros + +import ChainRulesCore +using ChainRulesCore: ZeroTangent, NoTangent, unthunk, @thunk + +import ForwardDiff +using ForwardDiffPullbacks: fwddiff + +import Functors +using Functors: fmap + +using ArgCheck: @argcheck + +using ArraysOfArrays: ArrayOfSimilarArrays, flatview + +include("utils.jl") +include("autodiff_utils.jl") +include("standard_dist.jl") +include("standard_uniform.jl") +include("standard_normal.jl") +include("distribution_measure.jl") +include("dist_vartransform.jl") +include("univariate.jl") +include("standardmv.jl") +include("product.jl") +include("reshaped.jl") +include("dirichlet.jl") + +export StdNormal +export DistributionMeasure +export StandardDist diff --git a/ext/distributions/mixture.jl b/ext/distributions/mixture.jl new file mode 100644 index 00000000..587d3e73 --- /dev/null +++ b/ext/distributions/mixture.jl @@ -0,0 +1,4 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +# ToDo: +# AbstractMixtureModel: MixtureModel, UnivariateGMM diff --git a/ext/distributions/product.jl b/ext/distributions/product.jl new file mode 100644 index 00000000..07b38299 --- /dev/null +++ b/ext/distributions/product.jl @@ -0,0 +1,17 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +@static if isdefined(Distributions, :Product) + MeasureBase.AbstractMeasure(obj::Distributions.Product) = productmeasure(map(asmeasure, obj.v)) + + function AsMeasure{D}(::D) where {D<:Distributions.Product} + throw(ArgumentError("Don't wrap Distributions.Product into MeasureBase.AsMeasure, use asmeasure to convert instead.")) + end +end + +@static if isdefined(Distributions, :ProductDistribution) + MeasureBase.AbstractMeasure(obj::Distributions.ProductDistribution) = productmeasure(map(asmeasure, obj.dists)) + + function AsMeasure{D}(::D) where {D<:Distributions.ProductDistribution} + throw(ArgumentError("Don't wrap Distributions.ProductDistribution into MeasureBase.AsMeasure, use asmeasure to convert instead.")) + end +end diff --git a/ext/distributions/reshaped.jl b/ext/distributions/reshaped.jl new file mode 100644 index 00000000..6cbd5ede --- /dev/null +++ b/ext/distributions/reshaped.jl @@ -0,0 +1,13 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +function MeasureBase.AbstractMeasure(d::Distributions.ReshapedDistribution) + orig_dist = d.dist + pushfwd(Reshape(size(d), size(orig_dist)), AbstractMeasure(orig_dist)) +end + +function AsMeasure{D}(::D) where {D<:Distributions.ReshapedDistribution} + throw(ArgumentError("Don't wrap Distributions.ReshapedDistribution into MeasureBase.AsMeasure, use asmeasure to convert instead.")) +end + + +# ToDo: Conversion back to Distribution diff --git a/ext/distributions/standardmv.jl b/ext/distributions/standardmv.jl new file mode 100644 index 00000000..2e1e36ee --- /dev/null +++ b/ext/distributions/standardmv.jl @@ -0,0 +1,33 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + + +MeasureBase.getdof(m::AsMeasure{<:AbstractMvNormal}) = length(m.obj) + +MeasureBase.transport_origin(ν::MvNormal) = StandardDist{Normal}(length(ν)) + +function MeasureBase.from_origin(ν::MvNormal, x) + A = cholesky(ν.Σ).L + b = ν.μ + muladd(A, x, b) +end + +function MeasureBase.to_origin(ν::MvNormal, y) + A = cholesky(ν.Σ).L + b = ν.μ + A \ (y - b) +end + + +AbstractMvNormal +AbstractMvLogNormal + +#DirichletMultinomial +#Distributions.AbstractMvLogNormal +#Distributions.AbstractMvTDist +#Distributions.ProductDistribution{1} +#Distributions.ReshapedDistribution{1, S, D} where {S<:ValueSupport, D<:(Distribution{<:ArrayLikeVariate, S})} +#JointOrderStatistics +#Multinomial +#MultivariateMixture (alias for AbstractMixtureModel{ArrayLikeVariate{1}}) +#MvLogitNormal +#VonMisesFisher diff --git a/ext/distributions/univariate.jl b/ext/distributions/univariate.jl new file mode 100644 index 00000000..899f615a --- /dev/null +++ b/ext/distributions/univariate.jl @@ -0,0 +1,176 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + + +@inline MeasureBase.getdof(::Distribution{Univariate}) = static(1) + +@inline MeasureBase.check_dof(a::Distribution{Univariate}, b::Distribution{Univariate}) = nothing + + +# Use ForwardDiff for univariate transformations: +@inline function ChainRulesCore.rrule(::typeof(transport_def), ν::Distribution{Univariate}, μ::Distribution{Univariate}, x::Any) + ChainRulesCore.rrule(fwddiff(transport_def), ν, μ, x) +end +@inline function ChainRulesCore.rrule(::typeof(transport_def), ν::MeasureBase.StdMeasure, μ::Distribution{Univariate}, x::Any) + ChainRulesCore.rrule(fwddiff(transport_def), ν, μ, x) +end +@inline function ChainRulesCore.rrule(::typeof(transport_def), ν::Distribution{Univariate}, μ::MeasureBase.StdMeasure, x::Any) + ChainRulesCore.rrule(fwddiff(transport_def), ν, μ, x) +end + + +# Generic transformations to/from StdUniform via cdf/quantile: + + +_dist_params_numtype(d::Distribution) = promote_type(map(typeof, Distributions.params(d))...) + + +@inline _trafo_cdf(d::Distribution{Univariate,Continuous}, x::Real) = _trafo_cdf_impl(_dist_params_numtype(d), d, x) + +@inline _trafo_cdf_impl(::Type{<:Real}, d::Distribution{Univariate,Continuous}, x::Real) = Distributions.cdf(d, x) + +@inline function _trafo_cdf_impl(::Type{<:Union{Integer,AbstractFloat}}, d::Distribution{Univariate,Continuous}, x::ForwardDiff.Dual{TAG}) where TAG + x_v = ForwardDiff.value(x) + u = Distributions.cdf(d, x_v) + dudx = Distributions.pdf(d, x_v) + ForwardDiff.Dual{TAG}(u, dudx * ForwardDiff.partials(x)) +end + + +@inline _trafo_quantile(d::Distribution{Univariate,Continuous}, u::Real) = _trafo_quantile_impl(_dist_params_numtype(d), d, u) + +@inline _trafo_quantile_impl(::Type{<:Real}, d::Distribution{Univariate,Continuous}, u::Real) = _trafo_quantile_impl_generic(d, u) + +@inline function _trafo_quantile_impl(::Type{<:Union{Integer,AbstractFloat}}, d::Distribution{Univariate,Continuous}, u::ForwardDiff.Dual{TAG}) where {TAG} + x = _trafo_quantile_impl_generic(d, ForwardDiff.value(u)) + dxdu = inv(Distributions.pdf(d, x)) + ForwardDiff.Dual{TAG}(x, dxdu * ForwardDiff.partials(u)) +end + + +@inline _trafo_quantile_impl_generic(d::Distribution{Univariate,Continuous}, u::Real) = Distributions.quantile(d, u) + +# Workaround for Beta dist, ForwardDiff doesn't work for parameters: +@inline _trafo_quantile_impl_generic(d::Beta{T}, u::Real) where {T<:ForwardDiff.Dual} = convert(float(typeof(u)), NaN) +# Workaround for Beta dist, current quantile implementation only supports Float64: +@inline function _trafo_quantile_impl_generic(d::Beta{T}, u::Union{Integer,AbstractFloat}) where {T<:Union{Integer,AbstractFloat}} + Distributions.quantile(d, convert(promote_type(Float64, typeof(u)), u)) +end + +#= +# ToDo: + +# Workaround for rounding errors that can result in quantile values outside of support of Truncated: +@inline function _trafo_quantile_impl_generic(d::Truncated{<:Distribution{Univariate,Continuous}}, u::Real) + x = Distributions.quantile(d, u) + T = typeof(x) + min_x = T(minimum(d)) + max_x = T(maximum(d)) + if x < min_x && isapprox(x, min_x, atol = 4 * eps(T)) + min_x + elseif x > max_x && isapprox(x, max_x, atol = 4 * eps(T)) + max_x + else + x + end +end + +# Workaround for rounding errors that can result in quantile values outside of support of Truncated: +@inline function _trafo_quantile_impl_generic(d::Truncated{<:Distribution{Univariate,Continuous}}, u::Real) + x = Distributions.quantile(d, u) + T = typeof(x) + min_x = T(minimum(d)) + max_x = T(maximum(d)) + if x < min_x && isapprox(x, min_x, atol = 4 * eps(T)) + min_x + elseif x > max_x && isapprox(x, max_x, atol = 4 * eps(T)) + max_x + else + x + end +end +=# + + +@inline function _result_numtype(d::Distribution{Univariate}, x::T) where {T<:Real} + float(promote_type(T, eltype(Distributions.params(d)))) + # firsttype(first(typeof(x), promote_type(map(eltype, Distributions.params(d))...))) +end + + +@inline function MeasureBase.transport_def(::StdUniform, μ::Distribution{Univariate,Continuous}, x) + R = _result_numtype(μ, x) + if Distributions.insupport(μ, x) + y = _trafo_cdf(μ, x) + convert(R, y) + else + convert(R, NaN) + end +end + + +@inline function MeasureBase.transport_def(ν::Distribution{Univariate,Continuous}, ::StdUniform, x::T) where T + R = _result_numtype(ν, x) + TF = float(T) + if 0 <= x <= 1 + # Avoid x ≈ 0 and x ≈ 1 to avoid infinite variate values for target distributions with infinite support: + mod_x = ifelse(x == 0, zero(TF) + eps(TF), ifelse(x == 1, one(TF) - eps(TF), convert(TF, x))) + y = _trafo_quantile(ν, mod_x) + convert(R, y) + else + convert(R, NaN) + end +end + + +# Use standard measures as transformation origin for scaled/translated equivalents: + +function _origin_to_affine(ν::Distribution{Univariate}, y::T) where {T<:Real} + trg_offs, trg_scale = Distributions.location(ν), Distributions.scale(ν) + x = muladd(y, trg_scale, trg_offs) + convert(_result_numtype(ν, y), x) +end + +function _affine_to_origin(μ::Distribution{Univariate}, x::T) where {T<:Real} + src_offs, src_scale = Distributions.location(μ), Distributions.scale(μ) + y = (x - src_offs) / src_scale + convert(_result_numtype(μ, x), y) +end + +for (A, B) in [ + (Uniform, StdUniform), + (Logistic, StdLogistic), + (Normal, StdNormal) +] + @eval begin + @inline MeasureBase.transport_origin(::$A) = $B() + @inline MeasureBase.to_origin(ν::$A, y) = _affine_to_origin(ν, y) + @inline MeasureBase.from_origin(ν::$A, x) = _origin_to_affine(ν, x) + end +end + +@inline MeasureBase.transport_origin(::Exponential) = StdExponential() +@inline MeasureBase.to_origin(ν::Exponential, y) = Distributions.scale(ν) \ y +@inline MeasureBase.from_origin(ν::Exponential, x) = Distributions.scale(ν) * x + + + +# Transform between univariate and single-element power measure + +function MeasureBase.transport_def(ν::Distribution{Univariate}, μ::PowerMeasure{<:StdMeasure}, x) + return transport_def(ν, μ.parent, only(x)) +end + +function MeasureBase.transport_def(ν::PowerMeasure{<:StdMeasure}, μ::Distribution{Univariate}, x) + return Fill(transport_def(ν.parent, μ, only(x)), map(length, ν.axes)...) +end + + +# Transform between univariate and single-element standard multivariate + +function MeasureBase.transport_def(ν::Distribution{Univariate}, μ::StandardDist{D,1}, x) where D + return transport_def(ν, StandardDist{D}(), only(x)) +end + +function MeasureBase.transport_def(ν::StandardDist{D,1}, μ::Distribution{Univariate}, x) where D + return Fill(transport_def(StandardDist{D}(), μ, only(x)), size(ν)...) +end diff --git a/ext/distributions/utils.jl b/ext/distributions/utils.jl new file mode 100644 index 00000000..be146786 --- /dev/null +++ b/ext/distributions/utils.jl @@ -0,0 +1,32 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + + +""" + convert_realtype(::Type{T}, x) where {T<:Real} + +Convert x to use `T` as it's underlying type for real numbers. +""" +function convert_realtype end + +_convert_realtype_pullback(ΔΩ) = NoTangent(), NoTangent, ΔΩ +ChainRulesCore.rrule(::typeof(convert_realtype), ::Type{T}, x) where T = convert_realtype(T, x), _convert_realtype_pullback + +@inline convert_realtype(::Type{T}, x::T) where {T<:Real} = x +@inline convert_realtype(::Type{T}, x::AbstractArray{T}) where {T<:Real} = x +@inline convert_realtype(::Type{T}, x::U) where {T<:Real,U<:Real} = T(x) +convert_realtype(::Type{T}, x::AbstractArray{U}) where {T<:Real,U<:Real} = T.(x) +convert_realtype(::Type{T}, x) where {T<:Real} = fmap(elem -> convert_realtype(T, elem), x) + + +""" + firsttype(::Type{T}, ::Type{U}) where {T<:Real,U<:Real} + +Return the first type, but as a dual number type if the second one is dual. + +If `U <: ForwardDiff.Dual{tag,<:Real,N}`, returns `ForwardDiff.Dual{tag,T,N}`, +otherwise returns `T` +""" +function firsttype end + +firsttype(::Type{T}, ::Type{U}) where {T<:Real,U<:Real} = T +firsttype(::Type{T}, ::Type{<:ForwardDiff.Dual{tag,<:Real,N}}) where {T<:Real,tag,N} = ForwardDiff.Dual{tag,T,N} diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index e29c4ae9..9f23fc8a 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -31,15 +31,12 @@ using IntervalSets using PrettyPrinting const Pretty = PrettyPrinting -using ChainRulesCore import FillArrays using Static using Static: StaticInteger using FunctionChains -export ≪ export gentype -export rebase export AbstractMeasure @@ -59,6 +56,43 @@ abstract type AbstractMeasure end AbstractMeasure(m::AbstractMeasure) = m + +""" + asmeasure(m) + +Turns a measure-like object `m` into an `AbstractMeasure`. + +Calls `convert(AbstractMeasure, m)` by default +""" +function asmeasure end + +@inline asmeasure(m::AbstractMeasure) = m +asmeasure(m) = convert(AbstractMeasure, m) +export asmeasure + + +""" + struct AsMeasure{T} + +Wrapes a measure-like object into an `AbstractMeasure`. + +Constructor: + +``` +AsMeasure{T}(obj::T) +``` + +User code should not create instances of `AsMeasure` directly, but should +call `asmeasure(obj)` instead. +""" +struct AsMeasure{T} + obj::T + + AsMeasure{T}(obj::T) = new(obj) +end + + + function Pretty.quoteof(d::M) where {M<:AbstractMeasure} the_names = fieldnames(typeof(d)) :($M($([getfield(d, n) for n in the_names]...))) @@ -108,6 +142,7 @@ using Compat using IrrationalConstants include("static.jl") +include("collection_utils.jl") include("smf.jl") include("getdof.jl") include("transport.jl") @@ -129,16 +164,15 @@ include("primitives/trivial.jl") include("combinators/bind.jl") include("combinators/transformedmeasure.jl") +include("combinators/reshape.jl") include("combinators/weighted.jl") include("combinators/superpose.jl") include("combinators/product.jl") include("combinators/power.jl") include("combinators/spikemixture.jl") include("combinators/likelihood.jl") -include("combinators/pointwise.jl") include("combinators/restricted.jl") include("combinators/smart-constructors.jl") -include("combinators/powerweighted.jl") include("combinators/conditional.jl") include("standard/stdmeasure.jl") @@ -156,6 +190,8 @@ include("density-core.jl") include("interface.jl") +include("measure_operators.jl") + using .Interface end # module MeasureBase diff --git a/src/absolutecontinuity.jl b/src/absolutecontinuity.jl index 8062198c..c65aeaf3 100644 --- a/src/absolutecontinuity.jl +++ b/src/absolutecontinuity.jl @@ -54,3 +54,6 @@ # representative(μ) ≪ representative(ν) && return true # return false # end + +# ≪(::M, ::WeightedMeasure{R,M}) where {R,M} = true +# ≪(::WeightedMeasure{R,M}, ::M) where {R,M} = true diff --git a/src/collection_utils.jl b/src/collection_utils.jl new file mode 100644 index 00000000..1de51f7e --- /dev/null +++ b/src/collection_utils.jl @@ -0,0 +1,24 @@ +function _pushfront(v::AbstractVector, x) + T = promote_type(eltype(v), typeof(x)) + r = similar(v, T, length(eachindex(v)) + 1) + r[firstindex(r)] = x + r[firstindex(r)+1:lastindex(r)] = v + r +end + +function _pushback(v::AbstractVector, x) + T = promote_type(eltype(v), typeof(x)) + r = similar(v, T, length(eachindex(v)) + 1) + r[lastindex(r)] = x + r[firstindex(r):lastindex(r)-1] = v + r +end + +_dropfront(v::AbstractVector) = v[firstindex(v)+1:lastindex(v)] + +_dropback(v::AbstractVector) = v[firstindex(v):lastindex(v)-1] + +_rev_cumsum(xs::AbstractVector) = reverse(cumsum(reverse(xs))) + +# Equivalent to `cumprod(xs)``: +_exp_cumsum_log(xs::AbstractVector) = exp.(cumsum(log.(xs))) diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index cc2022f2..465f2bf7 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -1,36 +1,45 @@ +""" + struct MeasureBase.Bind{M,K} <: AbstractMeasure + +Represents a monatic bind. User code should create instances of `Bind` +directly, but should call `mbind(k, μ)` instead. +""" struct Bind{M,K} <: AbstractMeasure - μ::M k::K + μ::M end -export ↣ +getdof(d::Bind) = NoDOF{typeof(d)}() + +function Base.rand(rng::AbstractRNG, ::Type{T}, d::Bind) where {T} + x = rand(rng, T, d.μ) + y = rand(rng, T, d.k(x)) + return y +end """ -If -- μ is an `AbstractMeasure` or satisfies the Measure interface, and -- k is a function taking values from the support of μ and returning a measure + mbind(k, μ)::AbstractMeasure + +Given -Then `μ ↣ k` is a measure, called a *monadic bind*. In a -probabilistic programming language like Soss.jl, this could be expressed as +- a measure μ +- a kernel function k that takes values from the support of μ and returns a + measure -Note that bind is usually written `>>=`, but this symbol is unavailable in Julia. +The *monadic bind* operation `mbind(k, μ)` returns is a new measure. +If `ν == mbind(k, μ)` and all measures involved are sampleable, then +samples from `rand(ν)` follow the same distribution as those from `rand(k(rand(μ)))`. + + +A monadic bind ofen written as `>>=` (e.g. in Haskell), but this symbol is +unavailable in Julia. ``` -bind = @model μ,k begin - x ~ μ - y ~ k(x) - return y +μ = StdExponential() +ν = mbind(μ) do scale + pushfwd(Base.Fix1(*, scale), StdNormal()) end ``` - -See also `bind` and `Bind` """ -↣(μ, k) = bind(μ, k) - -bind(μ, k) = Bind(μ, k) - -function Base.rand(rng::AbstractRNG, ::Type{T}, d::Bind) where {T} - x = rand(rng, T, d.μ) - y = rand(rng, T, d.k(x)) - return y -end +mbind(k, μ) = Bind(k, μ) +export mbind diff --git a/src/combinators/likelihood.jl b/src/combinators/likelihood.jl index 6dfd164f..b244fd0f 100644 --- a/src/combinators/likelihood.jl +++ b/src/combinators/likelihood.jl @@ -11,9 +11,9 @@ abstract type AbstractLikelihood end # insupport(ℓ::AbstractLikelihood, p) = insupport(ℓ.k(p), ℓ.x) @doc raw""" - Likelihood(k::AbstractTransitionKernel, x) + Likelihood(k, x) -"Observe" a value `x`, yielding a function from the parameters to ℝ. +Default result of [`likelihoodof(k, x)`](@ref). Likelihoods are most commonly used in conjunction with an existing _prior_ measure to yield a new measure, the _posterior_. In Bayes's Law, we have @@ -64,39 +64,12 @@ With several parameters, things work as expected: --------- - Likelihood(M<:ParameterizedMeasure, constraint::NamedTuple, x) - -In some cases the measure might have several parameters, and we may want the -(log-)likelihood with respect to some subset of them. In this case, we can use -the three-argument form, where the second argument is a constraint. For example, - - julia> ℓ = Likelihood(Normal{(:μ,:σ)}, (σ=3.0,), 2.0) - Likelihood(Normal{(:μ, :σ), T} where T, (σ = 3.0,), 2.0) - -Similarly to the above, we have - - julia> density_def(ℓ, (μ=2.0,)) - 0.3333333333333333 - - julia> logdensity_def(ℓ, (μ=2.0,)) - -1.0986122886681098 - - julia> density_def(ℓ, 2.0) - 0.3333333333333333 - - julia> logdensity_def(ℓ, 2.0) - -1.0986122886681098 - ------------------------ - Finally, let's return to the expression for Bayes's Law, -``P(θ|x) ∝ P(θ) P(x|θ)`` +``P(θ|x) ∝ P(x|θ) P(θ)`` -The product on the right side is computed pointwise. To work with this in -MeasureBase, we have a "pointwise product" `⊙`, which takes a measure and a -likelihood, and returns a new measure, that is, the unnormalized posterior that -has density ``P(θ) P(x|θ)`` with respect to the base measure of the prior. +In measure theory, the product on the right side is the Lebesgue integral +of the likelihood with respect to the prior. For example, say we have @@ -104,23 +77,27 @@ For example, say we have x ~ Normal(μ,σ) σ = 1 -and we observe `x=3`. We can compute the posterior measure on `μ` as +and we observe `x=3`. We can compute the (non-normalized) posterior measure on +`μ` as - julia> post = Normal() ⊙ Likelihood(Normal{(:μ, :σ)}, (σ=1,), 3) - Normal() ⊙ Likelihood(Normal{(:μ, :σ), T} where T, (σ = 1,), 3) - - julia> logdensity_def(post, 2) - -2.5 + julia> prior = Normal() + julia> likelihood = Likelihood(μ -> Normal(μ, 1), 3) + julia> post = mintegrate(likelihood, prior) + julia> post isa MeasureBase.DensityMeasure + true + julia> logdensity_rel(post, Lebesgue(), 2) + -4.337877066409345 """ struct Likelihood{K,X} <: AbstractLikelihood k::K x::X - Likelihood(k::K, x::X) where {K<:AbstractTransitionKernel,X} = new{K,X}(k, x) - Likelihood(k::K, x::X) where {K<:Function,X} = new{K,X}(k, x) - Likelihood(μ, x) = Likelihood(kernel(μ), x) + Likelihood{K,X}(k, x) where {K,X} = new{K,X}(k, x) end +# For type stability, in case k is a type (resp. a constructor): +Likelihood(k, x::X) where {X} = Likelihood{Core.Typeof(k),X}(k, x) + (lik::AbstractLikelihood)(p) = exp(ULogarithmic, logdensityof(lik.k(p), lik.x)) DensityInterface.DensityKind(::AbstractLikelihood) = IsDensity() @@ -150,58 +127,86 @@ end export likelihoodof -""" - likelihoodof(k::AbstractTransitionKernel, x; constraints...) - likelihoodof(k::AbstractTransitionKernel, x, constraints::NamedTuple) +@doc raw""" + likelihoodof(k, x) -A likelihood is *not* a measure. Rather, a likelihood acts on a measure, through -the "pointwise product" `⊙`, yielding another measure. -""" -function likelihoodof end +Returns the likelihood of observing `x` under a family of probability +measures that is generated by a transition kernel `k(θ)`. + +`k(θ)` maps points in the parameter space to measures (resp. objects that can +be converted to measures) on a implicit set `Χ` that contains values like `x`. -likelihoodof(k, x, ::NamedTuple{()}) = Likelihood(k, x) +`likelihoodof(k, x)` returns a likelihood object. A likelihhood is **not** a +measure, it is a function from the parameter space to `ℝ₊`. Likelihood +objects can also be interpreted as "generic densities" (but **not** as +probability densities). -likelihoodof(k, x; kwargs...) = likelihoodof(k, x, NamedTuple(kwargs)) +`likelihoodof(k, x)` implicitly chooses `ξ = rootmeasure(k(θ))` as the +reference measure on the observation set `Χ`. Note that this implicit +`ξ` **must** be independent of `θ`. -likelihoodof(k, x, pars::NamedTuple) = likelihoodof(kernel(k, pars), x) +`ℒₓ = likelihoodof(k, x)` has the mathematical interpretation -likelihoodof(k::AbstractTransitionKernel, x) = Likelihood(k, x) +```math +\mathcal{L}_x(\theta) = \frac{\rm{d}\, k(\theta)}{\rm{d}\, \chi}(x) +``` -export log_likelihood_ratio +`likelihoodof` must return an object that implements the +[`DensityInterface`](https://github.com/JuliaMath/DensityInterface.jl)` API +and `ℒₓ = likelihoodof(k, x)` must satisfy +```julia +log(ℒₓ(θ)) == logdensityof(ℒₓ, θ) ≈ logdensityof(k(θ), x) + +DensityKind(ℒₓ) isa IsDensity +``` + +By default, an instance of [`MeasureBase.Likelihood`](@ref) is returned. """ - log_likelihood_ratio(ℓ::Likelihood, p, q) +function likelihoodof end -Compute the log of the likelihood ratio, in order to compare two choices for -parameters. This is computed as +likelihoodof(k, x) = Likelihood(k, x) - logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x) +############################################################################### +# At the least, we need to think through in some more detail whether +# (log-)likelihood ratios expressed in this way are correct and useful. For now +# this code is commented out; we may remove it entirely in the future. -Since `logdensity_rel` can leave common base measure unevaluated, this can be -more efficient than +# export log_likelihood_ratio - logdensityof(ℓ.k(p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x) -""" -log_likelihood_ratio(ℓ::Likelihood, p, q) = logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x) +# """ +# log_likelihood_ratio(ℓ::Likelihood, p, q) -# likelihoodof(k, x; kwargs...) = likelihoodof(k, x, NamedTuple(kwargs)) +# Compute the log of the likelihood ratio, in order to compare two choices for +# parameters. This is computed as -export likelihood_ratio +# logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x) -""" - likelihood_ratio(ℓ::Likelihood, p, q) +# Since `logdensity_rel` can leave common base measure unevaluated, this can be +# more efficient than -Compute the log of the likelihood ratio, in order to compare two choices for -parameters. This is equal to +# logdensityof(ℓ.k(p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x) +# """ +# log_likelihood_ratio(ℓ::Likelihood, p, q) = logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x) - density_rel(ℓ.k(p), ℓ.k(q), ℓ.x) +# # likelihoodof(k, x; kwargs...) = likelihoodof(k, x, NamedTuple(kwargs)) -but is computed using LogarithmicNumbers.jl to avoid underflow and overflow. -Since `density_rel` can leave common base measure unevaluated, this can be -more efficient than +# export likelihood_ratio - logdensityof(ℓ.k(p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x) -""" -function likelihood_ratio(ℓ::Likelihood, p, q) - exp(ULogarithmic, logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x)) -end +# """ +# likelihood_ratio(ℓ::Likelihood, p, q) + +# Compute the log of the likelihood ratio, in order to compare two choices for +# parameters. This is equal to + +# density_rel(ℓ.k(p), ℓ.k(q), ℓ.x) + +# but is computed using LogarithmicNumbers.jl to avoid underflow and overflow. +# Since `density_rel` can leave common base measure unevaluated, this can be +# more efficient than + +# logdensityof(ℓ.k(p), ℓ.x) - logdensityof(ℓ.k(q), ℓ.x) +# """ +# function likelihood_ratio(ℓ::Likelihood, p, q) +# exp(ULogarithmic, logdensity_rel(ℓ.k(p), ℓ.k(q), ℓ.x)) +# end diff --git a/src/combinators/pointwise.jl b/src/combinators/pointwise.jl deleted file mode 100644 index 778e7f4e..00000000 --- a/src/combinators/pointwise.jl +++ /dev/null @@ -1,30 +0,0 @@ -export ⊙ - -struct PointwiseProductMeasure{P,L} <: AbstractMeasure - prior::P - likelihood::L -end - -iterate(p::PointwiseProductMeasure, i = 1) = iterate((p.prior, p.likelihood), i) - -function Pretty.tile(d::PointwiseProductMeasure) - Pretty.pair_layout(Pretty.tile(d.prior), Pretty.tile(d.likelihood), sep = " ⊙ ") -end - -⊙(prior, ℓ) = pointwiseproduct(prior, ℓ) - -@inbounds function insupport(d::PointwiseProductMeasure, p) - prior, ℓ = d - istrue(insupport(prior, p)) && istrue(insupport(ℓ, p)) -end - -@inline function logdensity_def(d::PointwiseProductMeasure, p) - prior, ℓ = d - unsafe_logdensityof(ℓ, p) -end - -basemeasure(d::PointwiseProductMeasure) = d.prior - -function gentype(d::PointwiseProductMeasure) - gentype(d.prior) -end diff --git a/src/combinators/powerweighted.jl b/src/combinators/powerweighted.jl deleted file mode 100644 index 47f50da4..00000000 --- a/src/combinators/powerweighted.jl +++ /dev/null @@ -1,37 +0,0 @@ -export ↑ - -struct PowerWeightedMeasure{M,A} <: AbstractMeasure - parent::M - exponent::A -end - -logdensity_def(d::PowerWeightedMeasure, x) = d.exponent * logdensity_def(d.parent, x) - -basemeasure(d::PowerWeightedMeasure, x) = basemeasure(d.parent, x)↑d.exponent - -basemeasure(d::PowerWeightedMeasure) = basemeasure(d.parent)↑d.exponent - -function powerweightedmeasure(d, α) - isone(α) && return d - PowerWeightedMeasure(d, α) -end - -(d::AbstractMeasure)↑α = powerweightedmeasure(d, α) - -insupport(d::PowerWeightedMeasure, x) = insupport(d.parent, x) - -function Base.show(io::IO, d::PowerWeightedMeasure) - print(io, d.parent, " ↑ ", d.exponent) -end - -function powerweightedmeasure(d::PowerWeightedMeasure, α) - powerweightedmeasure(d.parent, α * d.exponent) -end - -function powerweightedmeasure(d::WeightedMeasure, α) - weightedmeasure(α * d.logweight, powerweightedmeasure(d.base, α)) -end - -function Pretty.tile(d::PowerWeightedMeasure) - Pretty.pair_layout(Pretty.tile(d.parent), Pretty.tile(d.exponent), sep = " ↑ ") -end diff --git a/src/combinators/product.jl b/src/combinators/product.jl index cb7a0aaf..516678f5 100644 --- a/src/combinators/product.jl +++ b/src/combinators/product.jl @@ -167,19 +167,6 @@ function testvalue(::Type{T}, d::AbstractProductMeasure) where {T} _map(m -> testvalue(T, m), marginals(d)) end -export ⊗ - -""" - ⊗(μs::AbstractMeasure...) - -`⊗` is a binary operator for building product measures. This satisfies the law - -``` - basemeasure(μ ⊗ ν) == basemeasure(μ) ⊗ basemeasure(ν) -``` -""" -⊗(μs::AbstractMeasure...) = productmeasure(μs) - ############################################################################### # I <: Base.Generator diff --git a/src/combinators/reshape.jl b/src/combinators/reshape.jl new file mode 100644 index 00000000..1f24ca85 --- /dev/null +++ b/src/combinators/reshape.jl @@ -0,0 +1,49 @@ +# ToDo: Support static resizes for static arrays + +""" + struct MeasureBase.Reshape <: Function + +Represents a function that reshapes an array. + +Supports `InverseFunctions.inverse` and +`ChangesOfVariables.with_logabsdet_jacobian`. + +Constructor: + +```julia +Reshape(output_size::Dims, input_size::Dims) +``` +""" +struct Reshape{M,N} <: Function + output_size::NTuple{M,Int} + input_size::NTuple{N,Int} +end + +_throw_reshape_mismatch(sz, sz_x) = throw(DimensionMismatch("Reshape input size is $sz but got input of size $sz_x")) + +function (f::Reshape)(x::AbstractArray) + sz_x = size(x) + f.input_size == sz_x || _throw_reshape_mismatch(f.input_size, sz_x) + return reshape(x, f.output_size) +end + +InverseFunctions.inverse(f::Reshape) = Reshape(f.input_size, f.output_size) + +ChangesOfVariables.with_logabsdet_jacobian(::Reshape, x::AbstractArray) = zero(real_numtype(typeof(x))) + + +""" + mreshape(m::AbstractMeasure, sz::Vararg{N,Integer}) where N + mreshape(m::AbstractMeasure, sz::NTuple{N,Integer}) where N + +Reshape a measure `m` over an array-valued space, returning a measure over +a space of arrays with shape `sz`. +""" +function mreshape end + +_elsize_for_reshape(m::AbstractMeasure) = _elsize_for_reshape(some_mspace_elsize(m), m) +_elsize_for_reshape(sz::NTuple{<:Any,Integer}, ::AbstractMeasure) = sz +_elsize_for_reshape(::NoMSpaceElementSize, m::AbstractMeasure) = size(testvalue(m)) + +mreshape(m::AbstractMeasure, sz::Vararg{<:Any,Integer}) = mreshape(m, sz) +mreshape(m::AbstractMeasure, sz::NTuple{<:Any,Integer}) = pushfwd(Reshape(sz, _elsize_for_reshape(m)), m) diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index 803b404b..dab76d5f 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -140,7 +140,7 @@ end # pullback """ - pullback(f, μ, volcorr = WithVolCorr()) + pullbck(f, μ, volcorr = WithVolCorr()) A _pullback_ is a dual concept to a _pushforward_. While a pushforward needs a map _from_ the support of a measure, a pullback requires a map _into_ the @@ -152,8 +152,11 @@ in terms of the inverse function; the "forward" function is not used at all. In some cases, we may be focusing on log-density (and not, for example, sampling). To manually specify an inverse, call -`pullback(InverseFunctions.setinverse(f, finv), μ, volcorr)`. +`pullbck(InverseFunctions.setinverse(f, finv), μ, volcorr)`. """ -function pullback(f, μ, volcorr::TransformVolCorr = WithVolCorr()) - pushfwd(setinverse(inverse(f), f), μ, volcorr) +function pullbck(f, μ, volcorr::TransformVolCorr = WithVolCorr()) + PushforwardMeasure(inverse(f), f, μ, volcorr) end +export pullbck + +@deprecate pullback(f, μ, volcorr::TransformVolCorr = WithVolCorr()) pullbck(f, μ, volcorr) diff --git a/src/combinators/weighted.jl b/src/combinators/weighted.jl index db239b50..124662b6 100644 --- a/src/combinators/weighted.jl +++ b/src/combinators/weighted.jl @@ -46,9 +46,6 @@ end Base.:*(m::AbstractMeasure, k::Real) = k * m -≪(::M, ::WeightedMeasure{R,M}) where {R,M} = true -≪(::WeightedMeasure{R,M}, ::M) where {R,M} = true - gentype(μ::WeightedMeasure) = gentype(μ.base) insupport(μ::WeightedMeasure, x) = insupport(μ.base, x) diff --git a/src/density-core.jl b/src/density-core.jl index c8c861ee..f46c431b 100644 --- a/src/density-core.jl +++ b/src/density-core.jl @@ -35,14 +35,6 @@ end _checksupport(cond, result) = ifelse(cond == true, result, oftype(result, -Inf)) -import ChainRulesCore -@inline function ChainRulesCore.rrule(::typeof(_checksupport), cond, result) - y = _checksupport(cond, result) - function _checksupport_pullback(ȳ) - return NoTangent(), ZeroTangent(), one(ȳ) - end - y, _checksupport_pullback -end export unsafe_logdensityof diff --git a/src/density.jl b/src/density.jl index 4862dcb1..57367ec5 100644 --- a/src/density.jl +++ b/src/density.jl @@ -20,8 +20,7 @@ For measures `μ` and `ν`, `Density(μ,ν)` represents the _density function_ `dμ/dν`, also called the _Radom-Nikodym derivative_: https://en.wikipedia.org/wiki/Radon%E2%80%93Nikodym_theorem#Radon%E2%80%93Nikodym_derivative -Instead of calling this directly, users should call `density_rel(μ, ν)` or -its abbreviated form, `𝒹(μ,ν)`. +Instead of calling this directly, users should call `density_rel(μ, ν)`. """ struct Density{M,B} <: AbstractDensity μ::M @@ -32,16 +31,6 @@ Base.:∘(::typeof(log), d::Density) = logdensity_rel(d.μ, d.base) Base.log(d::Density) = log ∘ d -export 𝒹 - -""" - 𝒹(μ, base) - -Compute the density (Radom-Nikodym derivative) of μ with respect to `base`. This -is a shorthand form for `density_rel(μ, base)`. -""" -𝒹(μ, base) = density_rel(μ, base) - density_rel(μ, base) = Density(μ, base) (f::Density)(x) = density_rel(f.μ, f.base, x) @@ -73,16 +62,6 @@ Base.:∘(::typeof(exp), d::LogDensity) = density_rel(d.μ, d.base) Base.exp(d::LogDensity) = exp ∘ d -export log𝒹 - -""" - log𝒹(μ, base) - -Compute the log-density (Radom-Nikodym derivative) of μ with respect to `base`. -This is a shorthand form for `logdensity_rel(μ, base)` -""" -log𝒹(μ, base) = logdensity_rel(μ, base) - logdensity_rel(μ, base) = LogDensity(μ, base) (f::LogDensity)(x) = logdensity_rel(f.μ, f.base, x) @@ -98,12 +77,13 @@ DensityInterface.funcdensity(d::LogDensity) = throw(MethodError(funcdensity, (d, base :: B end -A `DensityMeasure` is a measure defined by a density or log-density with respect -to some other "base" measure. +A `DensityMeasure` is a measure defined by a density or log-density with +respect to some other "base" measure. -Users should not call `DensityMeasure` directly, but should instead call `∫(f, -base)` (if `f` is a density function or `DensityInterface.IsDensity` object) or -`∫exp(f, base)` (if `f` is a log-density function). +Users should not instantiate `DensityMeasure` directly, but should instead +call `mintegral_exp(f, base)` (if `f` is a density function or +`DensityInterface.IsDensity` object) or `mintegral_exp(f, base)` (if `f` +is a log-density function). """ struct DensityMeasure{F,B} <: AbstractMeasure f::F @@ -120,56 +100,84 @@ end end function Pretty.tile(μ::DensityMeasure{F,B}) where {F,B} - result = Pretty.literal("DensityMeasure ∫(") + result = Pretty.literal("mintegrate(") result *= Pretty.pair_layout(Pretty.tile(μ.f), Pretty.tile(μ.base); sep = ", ") result *= Pretty.literal(")") end -export ∫ +basemeasure(μ::DensityMeasure) = μ.base -""" - ∫(f, base::AbstractMeasure) +logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x) -Define a new measure in terms of a density `f` over some measure `base`. -""" -∫(f, base) = _densitymeasure(f, base, DensityKind(f)) +density_def(μ::DensityMeasure, x) = densityof(μ.f, x) -_densitymeasure(f, base, ::IsDensity) = DensityMeasure(f, base) -function _densitymeasure(f, base, ::HasDensity) - @error "`∫(f, base)` requires `DensityKind(f)` to be `IsDensity()` or `NoDensity()`." -end -_densitymeasure(f, base, ::NoDensity) = DensityMeasure(funcdensity(f), base) +@doc raw""" + mintegrate(f, μ::AbstractMeasure)::AbstractMeasure -export ∫exp +Returns a new measure that represents the indefinite +[integral](https://en.wikipedia.org/wiki/Radon%E2%80%93Nikodym_theorem) +of `f` with respect to `μ`. -""" - ∫exp(f, base::AbstractMeasure) +`ν = mintegrate(f, μ)` generates a measure `ν` that has the mathematical +interpretation -Define a new measure in terms of a log-density `f` over some measure `base`. +math``` +\nu(A) = \int_A f(a) \, \rm{d}\mu(a) +``` """ -∫exp(f, base) = _logdensitymeasure(f, base, DensityKind(f)) +function mintegrate end +export mintegrate -function _logdensitymeasure(f, base, ::IsDensity) - @error "`∫exp(f, base)` is not valid when `DensityKind(f) == IsDensity()`. Use `∫(f, base)` instead." -end -function _logdensitymeasure(f, base, ::HasDensity) - @error "`∫exp(f, base)` is not valid when `DensityKind(f) == HasDensity()`." +mintegrate(f, μ::AbstractMeasure) = _mintegrate_impl(f, μ, DensityKind(f)) + +_mintegrate_impl(f, μ, ::IsDensity) = DensityMeasure(f, μ) +function _mintegrate_impl(f, μ, ::HasDensity) + throw( + ArgumentError( + "`mintegrate(f, mu)` requires `DensityKind(f)` to be `IsDensity()` or `NoDensity()`.", + ), + ) end -_logdensitymeasure(f, base, ::NoDensity) = DensityMeasure(logfuncdensity(f), base) +_mintegrate_impl(f, μ, ::NoDensity) = DensityMeasure(funcdensity(f), μ) -basemeasure(μ::DensityMeasure) = μ.base +@doc raw""" + mintegrate_exp(log_f, μ::AbstractMeasure) -logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x) +Given a function `log_f` that semantically represents the log of a function +`f`, `mintegrate` returns a new measure that represents the indefinite +[integral](https://en.wikipedia.org/wiki/Radon%E2%80%93Nikodym_theorem) +of `f` with respect to `μ`. -density_def(μ::DensityMeasure, x) = densityof(μ.f, x) +`ν = mintegrate_exp(log_f, μ)` generates a measure `ν` that has the +mathematical interpretation -""" - rebase(μ, ν) - -Express `μ` in terms of a density over `ν`. Satisfies +math``` +\nu(A) = \int_A e^{log(f(a))} \, \rm{d}\mu(a) = \int_A f(a) \, \rm{d}\mu(a) ``` -basemeasure(rebase(μ, ν)) == ν -density(rebase(μ, ν)) == 𝒹(μ,ν) -``` + +Note that `exp(log_f(...))` is usually not run explicitly, calculations that +involve the resulting measure are typically performed in log-space, +internally. """ -rebase(μ, ν) = ∫(𝒹(μ, ν), ν) +function mintegrate_exp end +export mintegrate_exp + +function mintegrate_exp(log_f, μ::AbstractMeasure) + _mintegrate_exp_impl(log_f, μ, DensityKind(log_f)) +end + +function _mintegrate_exp_impl(log_f, μ, ::IsDensity) + throw( + ArgumentError( + "`mintegrate_exp(log_f, μ)` is not valid when `DensityKind(log_f) == IsDensity()`. Use `mintegrate(log_f, μ)` instead.", + ), + ) +end +function _mintegrate_exp_impl(log_f, μ, ::HasDensity) + throw( + ArgumentError( + "`mintegrate_exp(log_f, μ)` is not valid when `DensityKind(log_f) == HasDensity()`.", + ), + ) +end +_mintegrate_exp_impl(log_f, μ, ::NoDensity) = DensityMeasure(logfuncdensity(log_f), μ) diff --git a/src/getdof.jl b/src/getdof.jl index 4496b7f2..930c0aac 100644 --- a/src/getdof.jl +++ b/src/getdof.jl @@ -51,6 +51,7 @@ end _check_dof_pullback(ΔΩ) = NoTangent(), NoTangent(), NoTangent() ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = check_dof(ν, μ), _check_dof_pullback + """ MeasureBase.NoArgCheck{MU,T} @@ -78,6 +79,3 @@ end @propagate_inbounds function checked_arg(mu::MU, x) where {MU} _default_checked_arg(MU, basemeasure(mu), x) end - -_checked_arg_pullback(ΔΩ) = NoTangent(), NoTangent(), ΔΩ -ChainRulesCore.rrule(::typeof(checked_arg), ν, x) = checked_arg(ν, x), _checked_arg_pullback diff --git a/src/insupport.jl b/src/insupport.jl index 5184917d..a9a96363 100644 --- a/src/insupport.jl +++ b/src/insupport.jl @@ -18,11 +18,6 @@ Checks if `x` is in the support of distribution/measure `μ`, throws an """ function require_insupport end -_require_insupport_pullback(ΔΩ) = NoTangent(), ZeroTangent() -function ChainRulesCore.rrule(::typeof(require_insupport), μ, x) - return require_insupport(μ, x), _require_insupport_pullback -end - function require_insupport(μ, x) if !insupport(μ, x) throw(ArgumentError("x is not within the support of μ")) diff --git a/src/measure_operators.jl b/src/measure_operators.jl new file mode 100644 index 00000000..5822d4de --- /dev/null +++ b/src/measure_operators.jl @@ -0,0 +1,131 @@ +""" + module MeasureOperators + +Defines the following operators for measures: + +* `f ⋄ μ == pushfwd(f, μ)` + +* `μ ⊙ f == inverse(f) ⋄ μ` +""" +module MeasureOperators + +using MeasureBase: AbstractMeasure +using MeasureBase: pushfwd, pullbck, mbind, productmeasure +using MeasureBase: mintegrate, mintegrate_exp, density_rel, logdensity_rel +using InverseFunctions: inverse +using Reexport: @reexport + +@doc raw""" + ⋄(f, μ::AbstractMeasure) = pushfwd(f, μ) + +The `\\diamond` operator denotes a pushforward operation: `ν = f ⋄ μ` +generates a +[pushforward measure](https://en.wikipedia.org/wiki/Pushforward_measure). + +A common mathematical notation for a pushforward is ``f_*μ``, but as +there is no "subscript-star" operator in Julia, we use `⋄`. + +See [`pushfwd(f, μ)`](@ref) for details. + +Also see [`ν ⊙ f`](@ref), the pullback operator. +""" +⋄(f, μ::AbstractMeasure) = pushfwd(f, μ) +export ⋄ + +@doc raw""" + ⊙(ν::AbstractMeasure, f) = pullbck(f, ν) + +The `\\odot` operator denotes a pullback operation. + +See also [`pullbck(ν, f)`](@ref) for details. Note that `pullbck` takes it's +arguments in different order, in keeping with the Julia convention of +passing functions as the first argument. A pullback is mathematically the +precomposition of a measure `μ`` with the function `f` applied to sets. so +`⊙` takes the measure as the first and the function as the second argument, +as common in mathematical notation for precomposition. + +A common mathematical notation for pullback in measure theory is +``f \circ μ``, but as `∘` is used for function composition in Julia and as +`f` semantically acts point-wise on sets, we use `⊙`. + +Also see [f ⋄ μ](@ref), the pushforward operator. +""" +⊙(ν::AbstractMeasure, f) = pullbck(f, ν) +export ⊙ + +""" + μ ▷ k = mbind(k, μ) + +The `\\triangleright` operator denotes a measure monadic bind operation. + +A common operator choice for a monadics bind operator is `>>=` (e.g. in +the Haskell programming language), but this has a different meaning in +Julia and there is no close equivalent, so we use `▷`. + +See [`mbind(k, μ)`](@ref) for details. Note that `mbind` takes its +arguments in different order, in keeping with the Julia convention of +passing functions as the first argument. `▷`, on the other hand, takes +its arguments in the order common for monadic binds in functional +programming (like the Haskell `>>=` operator) and mathematics. +""" +▷(μ::AbstractMeasure, k) = mbind(k, μ) +export ▷ + +# ToDo: Use `⨂` instead of `⊗` for better readability? +""" + ⊗(μs::AbstractMeasure...) = productmeasure(μs) + +`⊗` is an operator for building product measures. + +See [`productmeasure(μs)`](@ref) for details. +""" +⊗(μs::AbstractMeasure...) = productmeasure(μs) +export ⊗ + +""" + ∫(f, μ::AbstractMeasure) = mintegrate(f, μ) + +Denotes an indefinite integral of the function `f` with respect to the +measure `μ`. + +See [`mintegrate(f, μ)`](@ref) for details. +""" +∫(f, μ::AbstractMeasure) = mintegrate(f, μ) +export ∫ + +""" + ∫exp(f, μ::AbstractMeasure) = mintegrate_exp(f, μ) + +Generates a new measure that is the indefinite integral of `exp` of `f` +with respect to the measure `μ`. + +See [`mintegrate_exp(f, μ)`](@ref) for details. +""" +∫exp(f, μ::AbstractMeasure) = mintegrate_exp(f, μ) +export ∫exp + +""" + 𝒹(ν, μ) = density_rel(ν, μ) + +Compute the density, i.e. the +[Radom-Nikodym derivative](https://en.wikipedia.org/wiki/Radon%E2%80%93Nikodym_theorem) +of `ν`` with respect to `μ`. + +For details, see [`density_rel(ν, μ)`}(@ref). +""" +𝒹(ν, μ::AbstractMeasure) = density_rel(ν, μ) +export 𝒹 + +""" + log𝒹(ν, μ) = logdensity_rel(ν, μ) + +Compute the log-density, i.e. the logarithm of the +[Radom-Nikodym derivative](https://en.wikipedia.org/wiki/Radon%E2%80%93Nikodym_theorem) +of `ν`` with respect to `μ`. + +For details, see [`logdensity_rel(ν, μ)`}(@ref). +""" +log𝒹(ν, μ::AbstractMeasure) = logdensity_rel(ν, μ) +export log𝒹 + +end # module MeasureOperators diff --git a/src/parameterized.jl b/src/parameterized.jl index 78e43995..8b1c8c88 100644 --- a/src/parameterized.jl +++ b/src/parameterized.jl @@ -127,14 +127,3 @@ params(::Type{PM}) where {N,PM<:ParameterizedMeasure{N}} = N function paramnames(μ, constraints::NamedTuple{N}) where {N} tuple((k for k in paramnames(μ) if k ∉ N)...) end - -############################################################################### -# kernelfactor - -function kernelfactor(::Type{P}) where {N,P<:ParameterizedMeasure{N}} - (constructorof(P), N) -end - -function kernelfactor(::P) where {N,P<:ParameterizedMeasure{N}} - (constructorof(P), N) -end diff --git a/src/static.jl b/src/static.jl index b723d043..da471b62 100644 --- a/src/static.jl +++ b/src/static.jl @@ -49,7 +49,9 @@ Returns the length of `x` as a dynamic or static integer. """ maybestatic_length(x) = length(x) maybestatic_length(x::AbstractUnitRange) = length(x) -function maybestatic_length(::Static.OptionallyStaticUnitRange{<:StaticInteger{A},<:StaticInteger{B}}) where {A,B} +function maybestatic_length( + ::Static.OptionallyStaticUnitRange{<:StaticInteger{A},<:StaticInteger{B}}, +) where {A,B} StaticInt{B - A + 1}() end diff --git a/src/transport.jl b/src/transport.jl index ce8ce1fd..d6754e1b 100644 --- a/src/transport.jl +++ b/src/transport.jl @@ -135,9 +135,6 @@ end return static(10) end -_origin_depth_pullback(ΔΩ) = NoTangent(), NoTangent() -ChainRulesCore.rrule(::typeof(_origin_depth), ν) = _origin_depth(ν), _origin_depth_pullback - # If both both measures have no origin: function _transport_between_origins(ν, ::StaticInteger{0}, ::StaticInteger{0}, μ, x) _transport_with_intermediate(ν, _transport_intermediate(ν, μ), μ, x) diff --git a/test/combinators/reshape.jl b/test/combinators/reshape.jl new file mode 100644 index 00000000..c6624582 --- /dev/null +++ b/test/combinators/reshape.jl @@ -0,0 +1,7 @@ +using Test + +using MeasureBase + +@testset "reshape" begin + +end diff --git a/test/distributions/getjacobian.jl b/test/distributions/getjacobian.jl new file mode 100644 index 00000000..87de7b86 --- /dev/null +++ b/test/distributions/getjacobian.jl @@ -0,0 +1,34 @@ +# This file is a part of ChangesOfVariables.jl, licensed under the MIT License (MIT). + +import ForwardDiff + +torv_and_back(V::AbstractVector{<:Real}) = V, identity +torv_and_back(x::Real) = [x], V -> V[1] +torv_and_back(x::Complex) = [real(x), imag(x)], V -> Complex(V[1], V[2]) +torv_and_back(x::NTuple{N}) where N = [x...], V -> ntuple(i -> V[i], Val(N)) + +function torv_and_back(x::Ref) + xval = x[] + V, to_xval = torv_and_back(xval) + back_to_ref(V) = Ref(to_xval(V)) + return (V, back_to_ref) +end + +torv_and_back(A::AbstractArray{<:Real}) = vec(A), V -> reshape(V, size(A)) + +function torv_and_back(A::AbstractArray{Complex{T}, N}) where {T<:Real, N} + RA = cat(real.(A), imag.(A), dims = N+1) + V, to_array = torv_and_back(RA) + function back_to_complex(V) + RA = to_array(V) + Complex.(view(RA, map(_ -> :, size(A))..., 1), view(RA, map(_ -> :, size(A))..., 2)) + end + return (V, back_to_complex) +end + + +function getjacobian(f, x) + V, to_x = torv_and_back(x) + vf(V) = torv_and_back(f(to_x(V)))[1] + ForwardDiff.jacobian(vf, V) +end diff --git a/test/distributions/test_autodiff_utils.jl b/test/distributions/test_autodiff_utils.jl new file mode 100644 index 00000000..6399bb80 --- /dev/null +++ b/test/distributions/test_autodiff_utils.jl @@ -0,0 +1,19 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +using DistributionMeasures +using Test + +using LinearAlgebra +using Distributions, ArraysOfArrays +import ForwardDiff, Zygote + + +@testset "trafo_utils" begin + xs = rand(5) + @test Zygote.jacobian(DistributionMeasures._pushfront, xs, 42)[1] ≈ ForwardDiff.jacobian(xs -> DistributionMeasures._pushfront(xs, 1), xs) + @test Zygote.jacobian(DistributionMeasures._pushfront, xs, 42)[2] ≈ vec(ForwardDiff.jacobian(x -> DistributionMeasures._pushfront(xs, x[1]), [42])) + @test Zygote.jacobian(DistributionMeasures._pushback, xs, 42)[1] ≈ ForwardDiff.jacobian(xs -> DistributionMeasures._pushback(xs, 1), xs) + @test Zygote.jacobian(DistributionMeasures._pushback, xs, 42)[2] ≈ vec(ForwardDiff.jacobian(x -> DistributionMeasures._pushback(xs, x[1]), [42])) + @test Zygote.jacobian(DistributionMeasures._rev_cumsum, xs)[1] ≈ ForwardDiff.jacobian(DistributionMeasures._rev_cumsum, xs) + @test Zygote.jacobian(DistributionMeasures._exp_cumsum_log, xs)[1] ≈ ForwardDiff.jacobian(DistributionMeasures._exp_cumsum_log, xs) ≈ ForwardDiff.jacobian(cumprod, xs) +end diff --git a/test/distributions/test_distribution_measure.jl b/test/distributions/test_distribution_measure.jl new file mode 100644 index 00000000..33b1ecbe --- /dev/null +++ b/test/distributions/test_distribution_measure.jl @@ -0,0 +1,54 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +using DistributionMeasures +using Test + +import Distributions +using Distributions: Distribution +import MeasureBase +using MeasureBase: AbstractMeasure + +@testset "Measure interface" begin + d = Distributions.Weibull() + @test @inferred(AbstractMeasure(d)) isa AbstractMeasure + @test @inferred(AbstractMeasure(d)) isa DistributionMeasure + @test @inferred(convert(AbstractMeasure, d)) isa AbstractMeasure + @test @inferred(convert(AbstractMeasure, d)) isa DistributionMeasure + @test @inferred(Distribution(AbstractMeasure(d))) === d + @test @inferred(convert(Distribution, convert(AbstractMeasure, d))) === d + + + c0 = AbstractMeasure(Distributions.Weibull(0.7, 1.3)) + c1 = AbstractMeasure(Distributions.MvNormal([0.7, 0.9], [1.4 0.5; 0.5 1.1])) + + d0 = AbstractMeasure(Distributions.Poisson(0.7)) + d1 = AbstractMeasure(Distributions.product_distribution(Distributions.Poisson.([0.7, 1.4]))) + + for μ in [c0, c1, d0, d1] + d = Distribution(μ) + x = rand(μ) + @test @inferred(MeasureBase.logdensity_def(μ, x)) == Distributions.logpdf(d, x) + @test @inferred(MeasureBase.unsafe_logdensityof(μ, x)) == Distributions.logpdf(d, x) + + MeasureBase.Interface.test_interface(d) + end + + @test @inferred(MeasureBase.basemeasure(c0)) == MeasureBase.Lebesgue(MeasureBase.ℝ) + @test @inferred(MeasureBase.basemeasure(c1)) == MeasureBase.Lebesgue(MeasureBase.ℝ) ^ 2 + + @test @inferred(MeasureBase.insupport(c0, 3)) == true + @test @inferred(MeasureBase.insupport(c0, -3)) == false + @test @inferred(MeasureBase.insupport(c1, [0.1, 0.2])) == true + @test @inferred(MeasureBase.insupport(d0, 3)) == true + @test @inferred(MeasureBase.insupport(d0, 3.2)) == false + @test @inferred(MeasureBase.insupport(d1, [1, 2])) == true + @test @inferred(MeasureBase.insupport(d1, [1.1, 2.2])) == false + + @test MeasureBase.paramnames(c0) == (:α, :θ) + if VERSION >= v"1.8" + @test @inferred(MeasureBase.params(c0)) == (α = 0.7, θ = 1.3) + else + # v1.6 can't type-infer this: + @test (MeasureBase.params(c0)) == (α = 0.7, θ = 1.3) + end +end diff --git a/test/distributions/test_distributions.jl b/test/distributions/test_distributions.jl new file mode 100644 index 00000000..6ad52a73 --- /dev/null +++ b/test/distributions/test_distributions.jl @@ -0,0 +1,12 @@ +using DistributionMeasures +using Test + +@testset "Distributions extension" begin + include("test_autodiff_utils.jl") + include("test_measure_interface.jl") + include("test_distribution_measure.jl") + include("test_standard_dist.jl") + include("test_standard_uniform.jl") + include("test_standard_normal.jl") + include("test_transport.jl") +end diff --git a/test/distributions/test_measure_interface.jl b/test/distributions/test_measure_interface.jl new file mode 100644 index 00000000..50873533 --- /dev/null +++ b/test/distributions/test_measure_interface.jl @@ -0,0 +1,44 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +using DistributionMeasures +using Test + +import Distributions +import MeasureBase + +@testset "Measure interface" begin + c0 = Distributions.Weibull(0.7, 1.3) + c1 = Distributions.MvNormal([0.7, 0.9], [1.4 0.5; 0.5 1.1]) + + d0 = Distributions.Poisson(0.7) + d1 = Distributions.product_distribution(Distributions.Poisson.([0.7, 1.4])) + + for d in [c0, c1, d0, d1] + x = rand(d) + @test @inferred(MeasureBase.logdensity_def(d, x)) == Distributions.logpdf(d, x) + @test @inferred(MeasureBase.unsafe_logdensityof(d, x)) == Distributions.logpdf(d, x) + + MeasureBase.Interface.test_interface(d) + end + + @test @inferred(MeasureBase.basemeasure(c0)) == MeasureBase.Lebesgue(MeasureBase.ℝ) + @test @inferred(MeasureBase.basemeasure(c1)) == MeasureBase.Lebesgue(MeasureBase.ℝ) ^ 2 + + @test @inferred(MeasureBase.insupport(c0, 3)) == true + @test @inferred(MeasureBase.insupport(c0, -3)) == false + @test @inferred(MeasureBase.insupport(c1, [0.1, 0.2])) == true + @test @inferred(MeasureBase.insupport(d0, 3)) == true + @test @inferred(MeasureBase.insupport(d0, 3.2)) == false + @test @inferred(MeasureBase.insupport(d1, [1, 2])) == true + @test @inferred(MeasureBase.insupport(d1, [1.1, 2.2])) == false + + @test MeasureBase.paramnames(c0) == (:α, :θ) + if VERSION >= v"1.8" + @test @inferred(MeasureBase.params(c0)) == (α = 0.7, θ = 1.3) + else + # v1.6 can't type-infer this: + @test (MeasureBase.params(c0)) == (α = 0.7, θ = 1.3) + end + + @test MeasureBase.∫(x -> Distributions.Normal(x, 0), Distributions.Normal()) isa MeasureBase.DensityMeasure +end diff --git a/test/distributions/test_standard_dist.jl b/test/distributions/test_standard_dist.jl new file mode 100644 index 00000000..4d211d87 --- /dev/null +++ b/test/distributions/test_standard_dist.jl @@ -0,0 +1,129 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +using DistributionMeasures +using Test + +using Random, Statistics, LinearAlgebra +using Distributions, PDMats +using StableRNGs +import ForwardDiff, ChainRulesTestUtils + + +@testset "standard_dist" begin + stblrng() = StableRNG(789990641) + + for (D, sz, dref) in [ + (Uniform, (), Uniform()), + (Uniform, (5,), product_distribution(fill(Uniform(0.0, 1.0), 5))), + (Uniform, (2, 3), reshape(product_distribution(fill(Uniform(0.0, 1.0), 6)), 2, 3)), + (Normal, (), Normal()), + (Normal, (), Normal(0., 1.0)), + (Normal, (5,), MvNormal(Diagonal(fill(1.0, 5)))), + (Normal, (2, 3), reshape(MvNormal(Diagonal(fill(1.0, 6))), 2, 3)), + (Exponential, (), Exponential()), + (Exponential, (5,), product_distribution(fill(Exponential(1.0), 5))), + (Exponential, (2, 3), reshape(product_distribution(fill(Exponential(1.0), 6)), 2, 3)), + ] + @testset "StandardDist{$D}($(join(sz,",")))" begin + N = length(sz) + + @test @inferred(StandardDist{D}(sz...)) isa StandardDist{D} + @test @inferred(StandardDist{D}(sz...)) isa StandardDist{D} + @test @inferred(size(StandardDist{D}(sz...))) == size(dref) + @test @inferred(size(StandardDist{D}(sz...))) == size(dref) + + d = StandardDist{D}(sz...) + + if size(d) == () + @test @inferred(DistributionMeasures.nonstddist(d)) == dref + end + + @test @inferred(length(d)) == length(dref) + @test @inferred(size(d)) == size(dref) + + @test @inferred(eltype(typeof(d))) == eltype(typeof(dref)) + @test @inferred(eltype(d)) == eltype(dref) + + @test @inferred(Distributions.params(d)) == () + @test @inferred(partype(d)) == partype(dref) + + for f in [minimum, maximum, mean, median, mode, modes, var, std, skewness, kurtosis, location, scale, entropy] + supported_by_dref = try f(dref); true catch MethodError; false; end + if supported_by_dref + @test @inferred(f(d)) ≈ f(dref) + end + end + + for x in [rand(dref) for i in 1:10] + ref_gradlogpdf = try + gradlogpdf(dref, x) + catch MethodError + ForwardDiff.gradient(x -> logpdf(dref, x), x) + end + @test @inferred(gradlogpdf(d, x)) ≈ ref_gradlogpdf + @test @inferred(logpdf(d, x)) ≈ logpdf(dref, x) + @test @inferred(pdf(d, x)) ≈ pdf(dref, x) + end + + if size(d) == () + for x in [minimum(dref), quantile(dref, 1//3), quantile(dref, 1//2), quantile(dref, 2//3), maximum(dref)] + for f in [logpdf, pdf, gradlogpdf, logcdf, cdf, logccdf, ccdf] + @test @inferred(f(d, x)) ≈ f(dref, x) + end + end + + for x in [0, 1//3, 1//2, 2//3, 1] + for f in [quantile, cquantile] + @test @inferred(f(d, x)) ≈ f(dref, x) + end + end + + for x in log.([0, 1//3, 1//2, 2//3, 1]) + for f in [invlogcdf, invlogccdf] + @test @inferred(f(d, x)) ≈ f(dref, x) + end + end + + for p in [0.0, 0.25, 0.75, 1.0] + @test @inferred(quantile(d, p)) == quantile(dref, p) + @test @inferred(cquantile(d, p)) == cquantile(dref, p) + end + + for t in [-3, 0, 3] + @test isapprox(@inferred(mgf(d, t)), mgf(dref, t), rtol = 1e-5) + @test isapprox(@inferred(cf(d, t)), cf(dref, t), rtol = 1e-5) + end + + @test @inferred(truncated(d, quantile(dref, 1//3), quantile(dref, 2//3))) == truncated(dref, quantile(dref, 1//3), quantile(dref, 2//3)) + + @test @inferred(product_distribution(fill(d, 3))) == StandardDist{typeof(d)}(3) + @test @inferred(product_distribution(fill(d, 3, 4))) == StandardDist{typeof(d)}(3, 4) + end + + if length(size(d)) == 1 + @test @inferred(convert(Distributions.Product, d)) isa Distributions.Product + d_as_prod = convert(Distributions.Product, d) + @test d_as_prod.v == fill(StandardDist{D}(), size(d)...) + end + + @test @inferred(rand(stblrng(), d)) == rand(stblrng(), d) + @test @inferred(rand(stblrng(), d, 5)) == rand(stblrng(), d, 5) + + @test @inferred(rand(stblrng(), d)) == rand(stblrng(), dref) + @test @inferred(rand(stblrng(), d, 5)) == rand(stblrng(), dref, 5) + @test @inferred(rand!(stblrng(), d, zeros(size(d)...))) == rand!(stblrng(), dref, zeros(size(dref)...)) + if length(size(d)) == 1 + @test @inferred(rand!(stblrng(), d, zeros(size(d)..., 5))) == rand!(stblrng(), dref, zeros(size(dref)..., 5)) + end + end + end + + @testset "StandardDist{Normal}()" begin + # TODO: Add @inferred + d = StandardDist{Normal}(4) + d_uv = StandardDist{Normal}() + dref = MvNormal(Diagonal(fill(1.0, 4))) + @test (MvNormal(d)) == dref + @test (Base.convert(MvNormal, d)) == dref + end +end diff --git a/test/distributions/test_standard_normal.jl b/test/distributions/test_standard_normal.jl new file mode 100644 index 00000000..4d5cbad4 --- /dev/null +++ b/test/distributions/test_standard_normal.jl @@ -0,0 +1,130 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +using DistributionMeasures +using Test + +using Random, Statistics, LinearAlgebra +using Distributions, PDMats +using StableRNGs + + +@testset "StandardDist{Normal}" begin + stblrng() = StableRNG(789990641) + + @testset "StandardDist{Normal,0}" begin + @test @inferred(Normal(StandardDist{Normal}())) isa Normal{Float64} + @test @inferred(Normal(StandardDist{Normal}())) == Normal() + @test @inferred(convert(Normal, StandardDist{Normal}())) == Normal() + + d = StandardDist{Normal}() + dref = Normal() + + @test @inferred(minimum(d)) == minimum(dref) + @test @inferred(maximum(d)) == maximum(dref) + + @test @inferred(Distributions.params(d)) == () + @test @inferred(partype(d)) == partype(dref) + + @test @inferred(location(d)) == location(dref) + @test @inferred(scale(d)) == scale(dref) + + @test @inferred(eltype(typeof(d))) == eltype(typeof(dref)) + @test @inferred(eltype(d)) == eltype(dref) + + @test @inferred(length(d)) == length(dref) + @test @inferred(size(d)) == size(dref) + + @test @inferred(mean(d)) == mean(dref) + @test @inferred(median(d)) == median(dref) + @test @inferred(mode(d)) == mode(dref) + @test @inferred(modes(d)) ≈ modes(dref) + + @test @inferred(var(d)) == var(dref) + @test @inferred(std(d)) == std(dref) + @test @inferred(skewness(d)) == skewness(dref) + @test @inferred(kurtosis(d)) == kurtosis(dref) + + @test @inferred(entropy(d)) == entropy(dref) + + for x in [-Inf, -1.3, 0.0, 1.3, +Inf] + @test @inferred(gradlogpdf(d, x)) == gradlogpdf(dref, x) + + @test @inferred(logpdf(d, x)) == logpdf(dref, x) + @test @inferred(pdf(d, x)) == pdf(dref, x) + @test @inferred(logcdf(d, x)) == logcdf(dref, x) + @test @inferred(cdf(d, x)) == cdf(dref, x) + @test @inferred(logccdf(d, x)) == logccdf(dref, x) + @test @inferred(ccdf(d, x)) == ccdf(dref, x) + end + + for p in [0.0, 0.25, 0.75, 1.0] + @test @inferred(quantile(d, p)) == quantile(dref, p) + @test @inferred(cquantile(d, p)) == cquantile(dref, p) + end + + for t in [-3, 0, 3] + @test @inferred(mgf(d, t)) == mgf(dref, t) + @test @inferred(cf(d, t)) == cf(dref, t) + end + + @test @inferred(rand(stblrng(), d)) == rand(stblrng(), dref) + @test @inferred(rand!(stblrng(), d, fill(0.0))) == rand!(stblrng(), dref, fill(0.0)) + @test @inferred(rand(stblrng(), d, 5)) == rand(stblrng(), dref, 5) + + @test @inferred(truncated(StandardDist{Normal}(), -2.2f0, 3.1f0)) isa Truncated{Normal{Float64}} + @test truncated(StandardDist{Normal}(), -2.2f0, 3.1f0) == truncated(Normal(0.0, 1.0), -2.2f0, 3.1f0) + + @test @inferred(product_distribution(fill(StandardDist{Normal}(), 3))) isa StandardDist{Normal,1} + @test product_distribution(fill(StandardDist{Normal}(), 3)) == StandardDist{Normal}(3) + end + + + @testset "StandardDist{Normal,1}" begin + @test @inferred(StandardDist{Normal}(3)) isa StandardDist{Normal,1} + @test @inferred(StandardDist{Normal}(3)) isa StandardDist{Normal,1} + @test @inferred(StandardDist{Normal}(3)) isa StandardDist{Normal,1} + + @test @inferred(MvNormal(StandardDist{Normal}(3))) isa MvNormal{Int} + @test @inferred(MvNormal(StandardDist{Normal}(3))) == MvNormal(ScalMat(3, 1.0)) + @test @inferred(convert(MvNormal, StandardDist{Normal}(3))) == MvNormal(ScalMat(3, 1.0)) + + d = StandardDist{Normal}(3) + dref = MvNormal(ScalMat(3, 1.0)) + + @test @inferred(eltype(typeof(d))) == eltype(typeof(dref)) + @test @inferred(eltype(d)) == eltype(dref) + + @test @inferred(length(d)) == length(dref) + @test @inferred(size(d)) == size(dref) + + @test @inferred(Distributions.params(d)) == () + @test @inferred(partype(d)) == partype(dref) + + @test @inferred(mean(d)) == mean(dref) + @test @inferred(var(d)) == var(dref) + @test @inferred(cov(d)) == cov(dref) + + @test @inferred(mode(d)) == mode(dref) + @test @inferred(modes(d)) == modes(dref) + + @test @inferred(invcov(d)) == invcov(dref) + @test @inferred(logdetcov(d)) == logdetcov(dref) + + @test @inferred(entropy(d)) == entropy(dref) + + for x in fill.([-Inf, -1.3, 0.0, 1.3, +Inf], 3) + # Distributions.insupport is inconsistent at +- Inf between Normal and MvNormal + if !any(isinf, x) + @test @inferred(Distributions.insupport(d, x)) == Distributions.insupport(dref, x) + end + @test @inferred(logpdf(d, x)) == logpdf(dref, x) + @test @inferred(pdf(d, x)) == pdf(dref, x) + @test @inferred(sqmahal(d, x)) == sqmahal(dref, x) + @test @inferred(gradlogpdf(d, x)) == gradlogpdf(dref, x) + end + + @test @inferred(rand(stblrng(), d)) == rand(stblrng(), d) + @test @inferred(rand!(stblrng(), d, zeros(3))) == rand!(stblrng(), d, zeros(3)) + @test @inferred(rand!(stblrng(), d, zeros(3, 10))) == rand!(stblrng(), d, zeros(3, 10)) + end +end diff --git a/test/distributions/test_standard_uniform.jl b/test/distributions/test_standard_uniform.jl new file mode 100644 index 00000000..66c8a99c --- /dev/null +++ b/test/distributions/test_standard_uniform.jl @@ -0,0 +1,119 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +using DistributionMeasures +using Test + +using Random, Statistics, LinearAlgebra +using Distributions, PDMats +using StableRNGs +using FillArrays +using ForwardDiff + + +@testset "StandardDist{Uniform}" begin + stblrng() = StableRNG(789990641) + + @testset "StandardDist{Uniform,0}" begin + @test @inferred(Uniform(StandardDist{Uniform}())) isa Uniform{Float64} + @test @inferred(Uniform(StandardDist{Uniform}())) == Uniform() + @test @inferred(convert(Uniform, StandardDist{Uniform}())) == Uniform() + + d = StandardDist{Uniform}() + dref = Uniform() + + @test @inferred(minimum(d)) == minimum(dref) + @test @inferred(maximum(d)) == maximum(dref) + + @test @inferred(Distributions.params(d)) == () + @test @inferred(partype(d)) == partype(dref) + + @test @inferred(location(d)) == location(dref) + @test @inferred(scale(d)) == scale(dref) + + @test @inferred(eltype(typeof(d))) == eltype(typeof(dref)) + @test @inferred(eltype(d)) == eltype(dref) + + @test @inferred(length(d)) == length(dref) + @test @inferred(size(d)) == size(dref) + + @test @inferred(mean(d)) == mean(dref) + @test @inferred(median(d)) == median(dref) + @test @inferred(mode(d)) == mode(dref) + @test @inferred(modes(d)) ≈ modes(dref) + + @test @inferred(var(d)) ≈ var(dref) + @test @inferred(std(d)) ≈ std(dref) + @test @inferred(skewness(d)) == skewness(dref) + @test @inferred(kurtosis(d)) ≈ kurtosis(dref) + + @test @inferred(entropy(d)) == entropy(dref) + + for x in [-0.5, 0.0, 0.25, 0.75, 1.0, 1.5] + @test @inferred(logpdf(d, x)) == logpdf(dref, x) + @test @inferred(pdf(d, x)) == pdf(dref, x) + @test @inferred(logcdf(d, x)) == logcdf(dref, x) + @test @inferred(cdf(d, x)) == cdf(dref, x) + @test @inferred(logccdf(d, x)) == logccdf(dref, x) + @test @inferred(ccdf(d, x)) == ccdf(dref, x) + end + + for p in [0.0, 0.25, 0.75, 1.0] + @test @inferred(quantile(d, p)) == quantile(dref, p) + @test @inferred(cquantile(d, p)) == cquantile(dref, p) + end + + for t in [-3, 0, 3] + @test @inferred(mgf(d, t)) == mgf(dref, t) + @test @inferred(cf(d, t)) == cf(dref, t) + end + + @test @inferred(rand(stblrng(), d)) == rand(stblrng(), dref) + @test @inferred(rand!(stblrng(), d, fill(0.0))) == rand!(stblrng(), dref, fill(0.0)) + @test @inferred(rand(stblrng(), d, 5)) == rand(stblrng(), dref, 5) + + @test @inferred(truncated(StandardDist{Uniform}(), -0.5f0, 0.7f0)) isa Uniform{Float64} + @test truncated(StandardDist{Uniform}(), -0.5f0, 0.7f0) == Uniform(0.0f0, 0.7f0) + @test truncated(StandardDist{Uniform}(), 0.2f0, 0.7f0) == Uniform(0.2f0, 0.7f0) + + @test @inferred(product_distribution(fill(StandardDist{Uniform}(), 3))) isa DistributionMeasures.StandardDist{Uniform,1} + @test product_distribution(fill(StandardDist{Uniform}(), 3)) == DistributionMeasures.StandardDist{Uniform}(3) + end + + + @testset "StandardDist{Uniform,1}" begin + d = DistributionMeasures.StandardDist{Uniform}(3) + dref = product_distribution(fill(Uniform(), 3)) + + @test @inferred(eltype(typeof(d))) == eltype(typeof(dref)) + @test @inferred(eltype(d)) == eltype(dref) + + @test @inferred(length(d)) == length(dref) + @test @inferred(size(d)) == size(dref) + + @test @inferred(Distributions.params(d)) == () + @test @inferred(partype(d)) == partype(dref) + + @test @inferred(mean(d)) == mean(dref) + @test @inferred(var(d)) ≈ var(dref) + @test @inferred(cov(d)) ≈ cov(dref) + + @test @inferred(mode(d)) == [0.5, 0.5, 0.5] + @test @inferred(modes(d)) == fill([0, 0,0 ]) + + @test @inferred(invcov(d)) == inv(cov(dref)) + @test @inferred(logdetcov(d)) == logdet(cov(dref)) + + @test @inferred(entropy(d)) == entropy(dref) + + for x in fill.([-Inf, -1.3, 0.0, 1.3, +Inf], 3) + @test @inferred(Distributions.insupport(d, x)) == Distributions.insupport(dref, x) + @test @inferred(logpdf(d, x)) == logpdf(dref, x) + @test @inferred(pdf(d, x)) == pdf(dref, x) + @test @inferred(gradlogpdf(d, x)) == ForwardDiff.gradient(x -> logpdf(d, x), x) + end + + @test @inferred(rand(stblrng(), d)) == rand(stblrng(), d) + @test @inferred(rand!(stblrng(), d, zeros(3))) == rand!(stblrng(), d, zeros(3)) + @test @inferred(rand!(stblrng(), d, zeros(3, 10))) == rand!(stblrng(), d, zeros(3, 10)) + end +end diff --git a/test/distributions/test_transport.jl b/test/distributions/test_transport.jl new file mode 100644 index 00000000..1542dcc2 --- /dev/null +++ b/test/distributions/test_transport.jl @@ -0,0 +1,149 @@ +# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT). + +using DistributionMeasures +using Test + +using LinearAlgebra +using InverseFunctions, ChangesOfVariables +using Distributions, ArraysOfArrays +import ForwardDiff, Zygote + +using MeasureBase: transport_to, transport_def, transport_origin +using MeasureBase: StdExponential +using DistributionMeasures: _trafo_cdf, _trafo_quantile + +include("getjacobian.jl") + + +@testset "test_distribution_transform" begin + function test_back_and_forth(trg, src) + @testset "transform $(typeof(trg).name) <-> $(typeof(src).name)" begin + x = rand(src) + y = transport_def(trg, src, x) + src_v_reco = transport_def(src, trg, y) + + @test x ≈ src_v_reco + + f = x -> transport_def(trg, src, x) + ref_ladj = logpdf(src, x) - logpdf(trg, y) + @test ref_ladj ≈ logabsdet(getjacobian(f, x))[1] + end + end + + reshaped_rand(d::Distribution{Univariate}, n) = rand(d, n) + reshaped_rand(d::Distribution{Multivariate}, n) = nestedview(rand(d, n)) + + function test_dist_trafo_moments(trg, src) + unshaped(x) = first(torv_and_back(x)) + @testset "check moments of trafo $(typeof(trg).name) <- $(typeof(src).name)" begin + X = reshaped_rand(src, 10^5) + Y = transport_to(trg, src).(X) + Y_ref = reshaped_rand(trg, 10^6) + @test isapprox(mean(unshaped.(Y)), mean(unshaped.(Y_ref)), rtol = 0.5) + @test isapprox(cov(unshaped.(Y)), cov(unshaped.(Y_ref)), rtol = 0.5) + end + end + + @testset "transforms-tests" begin + stduvuni = StandardDist{Uniform}() + stduvnorm = StandardDist{Uniform}() + + uniform1 = Uniform(-5.0, -0.01) + uniform2 = Uniform(0.01, 5.0) + + normal1 = Normal(-10, 1) + normal2 = Normal(10, 5) + + stdmvnorm1 = StandardDist{Normal}(1) + stdmvnorm2 = StandardDist{Normal}(2) + + stdmvuni2 = StandardDist{Uniform}(2) + + standnorm2_reshaped = reshape(stdmvnorm2, 1, 2) + + mvnorm = MvNormal([0.3, -2.9], [1.7 0.5; 0.5 2.3]) + beta = Beta(3,1) + gamma = Gamma(0.1,0.7) + dirich = Dirichlet([0.1,4]) + + test_back_and_forth(stduvuni, stduvuni) + test_back_and_forth(stduvnorm, stduvnorm) + test_back_and_forth(stduvuni, stduvnorm) + test_back_and_forth(stduvnorm, stduvuni) + + test_back_and_forth(stdmvuni2, stdmvuni2) + test_back_and_forth(stdmvnorm2, stdmvnorm2) + test_back_and_forth(stdmvuni2, stdmvnorm2) + test_back_and_forth(stdmvnorm2, stdmvuni2) + + test_back_and_forth(beta, stduvnorm) + test_back_and_forth(gamma, stduvnorm) + test_back_and_forth(gamma, beta) + + test_back_and_forth(mvnorm, stdmvuni2) + test_back_and_forth(stdmvuni2, mvnorm) + + test_back_and_forth(mvnorm, standnorm2_reshaped) + test_back_and_forth(standnorm2_reshaped, mvnorm) + test_back_and_forth(stdmvnorm2, standnorm2_reshaped) + test_back_and_forth(standnorm2_reshaped, standnorm2_reshaped) + + test_dist_trafo_moments(normal2, normal1) + test_dist_trafo_moments(uniform2, uniform1) + + test_dist_trafo_moments(beta, stduvnorm) + test_dist_trafo_moments(gamma, stduvnorm) + + test_dist_trafo_moments(mvnorm, stdmvnorm2) + test_dist_trafo_moments(dirich, stdmvnorm1) + + let + mvuni = product_distribution([Uniform(), Uniform()]) + + x = rand() + @test_throws ArgumentError transport_to(stduvnorm, mvnorm)(x) + @test_throws ArgumentError transport_to(stduvnorm, stdmvnorm1)(x) + @test_throws ArgumentError transport_to(stduvnorm, stdmvnorm2)(x) + + x = rand(2) + @test_throws ArgumentError transport_to(stduvnorm, mvnorm)(x) + @test_throws ArgumentError transport_to(stduvnorm, stdmvnorm1)(x) + @test_throws ArgumentError transport_to(stduvnorm, stdmvnorm2)(x) + end + end + + @testset "Custom cdf and quantile for dual numbers" begin + Dual = ForwardDiff.Dual + + @test isapprox(_trafo_cdf(Normal(Dual(0, 1, 0, 0), Dual(1, 0, 1, 0)), Dual(0.5, 0, 0, 1)), cdf(Normal(Dual(0, 1, 0, 0), Dual(1, 0, 1, 0)), Dual(0.5, 0, 0, 1)), rtol = 10^-6) + @test isapprox(_trafo_cdf(Normal(0, 1), Dual(0.5, 1)), cdf(Normal(0, 1), Dual(0.5, 1)), rtol = 10^-6) + + @test isapprox(_trafo_quantile(Normal(0, 1), Dual(0.5, 1)), quantile(Normal(0, 1), Dual(0.5, 1)), rtol = 10^-6) + @test isapprox(_trafo_quantile(Normal(Dual(0, 1, 0, 0), Dual(1, 0, 1, 0)), Dual(0.5, 0, 0, 1)), quantile(Normal(Dual(0, 1, 0, 0), Dual(1, 0, 1, 0)), Dual(0.5, 0, 0, 1)), rtol = 10^-6) + end + + @testset "trafo autodiff pullbacks" begin + x = [0.6, 0.7, 0.8, 0.9] + f = transport_to(Dirichlet([3.0, 4.0, 5.0, 6.0, 7.0]), Uniform) + @test isapprox(ForwardDiff.jacobian(f, x), Zygote.jacobian(f, x)[1], rtol = 10^-4) + f = inverse(transport_to(Normal, Dirichlet([3.0, 4.0, 5.0, 6.0, 7.0]))) + @test isapprox(ForwardDiff.jacobian(f, x), Zygote.jacobian(f, x)[1], rtol = 10^-4) + end + + + @testset "transport_to autosel" begin + for (M,R) in [ + (StandardNormal, StandardNormal) + (Normal, StandardNormal) + (StandardUniform, StandardUniform) + (Uniform, StandardUniform) + ] + @test @inferred(transport_to(M, Weibull())) == transport_to(R(), Weibull()) + @test @inferred(transport_to(Weibull(), M)) == transport_to(Weibull(), R()) + @test @inferred(transport_to(M, MvNormal(float(I(5))))) == transport_to(R(5), MvNormal(float(I(5)))) + @test @inferred(transport_to(MvNormal(float(I(5))), M)) == transport_to(MvNormal(float(I(5))), R(5)) + @test @inferred(transport_to(M, StdExponential()^(2,3))) == transport_to(R(6), StdExponential()^(2,3)) + @test @inferred(transport_to(StdExponential()^(2,3), M)) == transport_to(StdExponential()^(2,3), R(6)) + end + end +end diff --git a/test/measure_operators.jl b/test/measure_operators.jl new file mode 100644 index 00000000..a3adaa8f --- /dev/null +++ b/test/measure_operators.jl @@ -0,0 +1,24 @@ +using Test + +using MeasureBase: AbstractMeasure +using MeasureBase: StdExponential, StdLogistic, StdUniform +using MeasureBase: pushfwd, pullbck, mbind, productmeasure +using MeasureBase: mintegrate, mintegrate_exp, density_rel, logdensity_rel +using MeasureBase.MeasureOperators: ⋄, ⊙, ▷, ⊗, ∫, ∫exp, 𝒹, log𝒹 + +@testset "MeasureOperators" begin + μ = StdExponential() + ν = StdUniform() + k(σ) = pushfwd(x -> σ * x, StdNormal()) + μs = (StdExponential(), StdLogistic(), StdUniform()) + f = sqrt + + @test @inferred(f ⋄ μ) == pushfwd(f, μ) + @test @inferred(ν ⊙ f) == pullbck(f, ν) + @test @inferred(μ ▷ k) == mbind(k, μ) + @test @inferred(⊗(μs...)) == productmeasure(μs) + @test @inferred(∫(f, μ)) == mintegrate(f, μ) + @test @inferred(∫exp(f, μ)) == mintegrate_exp(f, μ) + @test @inferred(𝒹(ν, μ)) == density_rel(ν, μ) + @test @inferred(log𝒹(ν, μ)) == logdensity_rel(ν, μ) +end diff --git a/test/runtests.jl b/test/runtests.jl index f9263b6d..97314dd4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,5 +19,10 @@ include("smf.jl") include("combinators/weighted.jl") include("combinators/transformedmeasure.jl") +include("combinators/reshape.jl") + +include("test_distributions.jl") + +include("measure_operators.jl") include("test_docs.jl") diff --git a/test/static.jl b/test/static.jl index a6c50db2..f618124b 100644 --- a/test/static.jl +++ b/test/static.jl @@ -11,7 +11,7 @@ import FillArrays @test static(2) isa MeasureBase.IntegerLike @test true isa MeasureBase.IntegerLike @test static(true) isa MeasureBase.IntegerLike - + @test @inferred(MeasureBase.one_to(7)) isa Base.OneTo @test @inferred(MeasureBase.one_to(7)) == 1:7 @test @inferred(MeasureBase.one_to(static(7))) isa Static.SOneTo @@ -19,10 +19,13 @@ import FillArrays @test @inferred(MeasureBase.fill_with(4.2, (7,))) == FillArrays.Fill(4.2, 7) @test @inferred(MeasureBase.fill_with(4.2, (static(7),))) == FillArrays.Fill(4.2, 7) - @test @inferred(MeasureBase.fill_with(4.2, (3, static(7)))) == FillArrays.Fill(4.2, 3, 7) + @test @inferred(MeasureBase.fill_with(4.2, (3, static(7)))) == + FillArrays.Fill(4.2, 3, 7) @test @inferred(MeasureBase.fill_with(4.2, (3:7,))) == FillArrays.Fill(4.2, (3:7,)) - @test @inferred(MeasureBase.fill_with(4.2, (static(3):static(7),))) == FillArrays.Fill(4.2, (3:7,)) - @test @inferred(MeasureBase.fill_with(4.2, (3:7, static(2):static(5)))) == FillArrays.Fill(4.2, (3:7, 2:5)) + @test @inferred(MeasureBase.fill_with(4.2, (static(3):static(7),))) == + FillArrays.Fill(4.2, (3:7,)) + @test @inferred(MeasureBase.fill_with(4.2, (3:7, static(2):static(5)))) == + FillArrays.Fill(4.2, (3:7, 2:5)) @test MeasureBase.maybestatic_length(MeasureBase.one_to(7)) isa Int @test MeasureBase.maybestatic_length(MeasureBase.one_to(7)) == 7