Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions src/aggregation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,10 @@ function _domain_label(transformation::ViewTransformation, index::Int)
_array_domain_label(asℝ, dims, index)
end

inverse_eltype(transformation::ViewTransformation, y) = eltype(y)
function inverse_eltype(transformation::ViewTransformation,
::Type{T}) where T <: AbstractArray
_ensure_float(eltype(T))
end

function inverse_at!(x::AbstractVector, index, transformation::ViewTransformation,
y::AbstractArray)
Expand Down Expand Up @@ -210,8 +213,9 @@ function transform_with(flag::LogJacFlag, transformation::StaticArrayTransformat
end

function inverse_eltype(transformation::Union{ArrayTransformation,StaticArrayTransformation},
x::AbstractArray)
inverse_eltype(transformation.inner_transformation, first(x)) # FIXME shortcut
::Type{T}) where T <: AbstractArray
inverse_eltype(transformation.inner_transformation,
_ensure_float(eltype(T)))
end

function inverse_at!(x::AbstractVector, index,
Expand Down Expand Up @@ -341,12 +345,17 @@ internally.

*Performs no argument validation, caller should do this.*
"""
_inverse_eltype_tuple(ts::NTransforms, ys::Tuple) =
reduce(promote_type, map(inverse_eltype, ts, ys))
# NOTE: See https://github.com/tpapp/TransformVariables.jl/pull/80
# `map` and `reduce` both have specializations on `Tuple`s that make them type stable
# even when the `Tuple` is heterogenous, but that is not currently the case with
# `mapreduce`, therefore separate `reduce` and `map` are preferred as a workaround.
function _inverse_eltype_tuple(ts::NTransforms{N}, ::Type{T}) where {N,T<:Tuple}
@argcheck T <: NTuple{N,Any} "Incompatible input length."
__inverse_eltype_tuple(ts, T)
end
function __inverse_eltype_tuple(ts::NTransforms, ::Type{Tuple{}})
Union{}
end
function __inverse_eltype_tuple(ts::NTransforms, ::Type{T}) where {T<:Tuple}
promote_type(inverse_eltype(Base.first(ts), fieldtype(T, 1)),
__inverse_eltype_tuple(Base.tail(ts), Tuple{Base.tail(fieldtypes(T))...}))
end

"""
$(SIGNATURES)
Expand All @@ -366,10 +375,9 @@ function transform_with(flag::LogJacFlag, tt::TransformTuple{<:Tuple}, x, index)
transform_tuple(flag, tt.transformations, x, index)
end

function inverse_eltype(tt::TransformTuple{<:Tuple}, y::Tuple)
function inverse_eltype(tt::TransformTuple{<:Tuple}, ::Type{T}) where T <: Tuple
(; transformations) = tt
@argcheck length(transformations) == length(y)
_inverse_eltype_tuple(transformations, y)
_inverse_eltype_tuple(transformations, T)
end

function inverse_at!(x::AbstractVector, index, tt::TransformTuple{<:Tuple}, y::Tuple)
Expand All @@ -378,19 +386,19 @@ function inverse_at!(x::AbstractVector, index, tt::TransformTuple{<:Tuple}, y::T
_inverse!_tuple(x, index, tt.transformations, y)
end

as(transformations::NamedTuple{N,<:NTransforms}) where N =
function as(transformations::NamedTuple{N,<:NTransforms}) where N
TransformTuple(transformations)
end

function transform_with(flag::LogJacFlag, tt::TransformTuple{<:NamedTuple}, x, index)
(; transformations) = tt
y, ℓ, index′ = transform_tuple(flag, values(transformations), x, index)
NamedTuple{keys(transformations)}(y), ℓ, index′
end

function inverse_eltype(tt::TransformTuple{<:NamedTuple}, y::NamedTuple)
function inverse_eltype(tt::TransformTuple{<:NamedTuple}, ::Type{NamedTuple{N,T}}) where {N,T}
(; transformations) = tt
@argcheck _same_set_of_names(transformations, y)
_inverse_eltype_tuple(values(transformations), values(NamedTuple{keys(transformations)}(y)))
_inverse_eltype_tuple(values(transformations), T)
end

function inverse_at!(x::AbstractVector, index, tt::TransformTuple{<:NamedTuple}, y::NamedTuple)
Expand Down
2 changes: 1 addition & 1 deletion src/constant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function transform_with(logjac_flag::LogJacFlag, t::Constant, x::AbstractVector,
t.value, logjac_zero(logjac_flag, eltype(x)), index
end

inverse_eltype(t::Constant, _) = Union{}
inverse_eltype(t::Constant, ::Type) = Union{}

function inverse_at!(x::AbstractVector, index, t::Constant, y)
@argcheck t.value == y
Expand Down
40 changes: 24 additions & 16 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,31 @@
inverse(f::CallableInverse) = Base.Fix1(transform, f.x)

"""
$(FUNCTIONNAME)(t::AbstractTransform, y)
```
$(FUNCTIONNAME)(t::AbstractTransform, y)
$(FUNCTIONNAME)(t::AbstractTransform, ::Type{T})
```

The element type for vector `x` so that `inverse!(x, t, y)` works.
The element type for vector `x` so that `inverse!(x, t, y::T)` works.

!!! note
It is not guaranteed that the result is the narrowest possible type, and may change
without warning between versions. Some effort is made to come up with a reasonable
concrete type even in corner cases.
# Notes

1. It is not guaranteed that the result is the narrowest possible type, and may change
without warning between versions. Some effort is made to come up with a reasonable
concrete type even in corner cases.

2. Transformations should provide a method for *types*, not values.

3. No dimension or input compatibility checks are guaranteed to be performed, even for
values.
"""
function inverse_eltype end
function inverse_eltype(t::AbstractTransform, y::T) where T
inverse_eltype(t, T)
end

function inverse_eltype(t::AbstractTransform, T::Type)
throw(MethodError(inverse_eltype, (t, T)))

Check warning on line 207 in src/generic.jl

View check run for this annotation

Codecov / codecov/patch

src/generic.jl#L206-L207

Added lines #L206 - L207 were not covered by tests
end

"""
$(SIGNATURES)
Expand Down Expand Up @@ -283,15 +298,8 @@

# We want to avoid vectors with non-numerical element types
# Ref https://github.com/tpapp/TransformVariables.jl/issues/132
function inverse(t::VectorTransform, y)
inverse!(Vector{_float_or_Float64(inverse_eltype(t, y))}(undef, dimension(t)), t, y)
end
function _float_or_Float64(::Type{T}) where T
if T !== Union{} && T <: Number # heuristic: it is assumed that every `Number` type defines `float`
return float(T)
else
return Float64
end
function inverse(t::VectorTransform, y::T) where T
inverse!(Vector{_ensure_float(inverse_eltype(t, T))}(undef, dimension(t)), t, y)
end

"""
Expand Down
32 changes: 21 additions & 11 deletions src/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ function inverse_at!(x::AbstractVector, index::Int, t::ScalarTransform, y::Real)
index + 1
end

function inverse_eltype(t::ScalarTransform, y::Real)
function inverse_eltype(t::ScalarTransform, ::Type{T}) where T <: Real
# NOTE this is a shortcut to get sensible types for all subtypes of ScalarTransform, which
# we test for. If it breaks it should be extended accordingly.
return Base.promote_typejoin_union(Base.promote_op(inverse, typeof(t), typeof(y)))
return Base.promote_typejoin_union(Base.promote_op(inverse, typeof(t), T))
end

_domain_label(::ScalarTransform, index::Int) = ()
Expand Down Expand Up @@ -66,43 +66,50 @@ $(TYPEDEF)

Exponential transformation `x ↦ eˣ`. Maps from all reals to the positive reals.
"""
struct TVExp <: ScalarTransform
end
struct TVExp <: ScalarTransform end

transform(::TVExp, x::Real) = exp(x)

transform_and_logjac(t::TVExp, x::Real) = transform(t, x), x

function inverse(::TVExp, x::Number)
log(x)
end

inverse_and_logjac(t::TVExp, x::Number) = inverse(t, x), -log(x)

"""
$(TYPEDEF)

Logistic transformation `x ↦ logit(x)`. Maps from all reals to (0, 1).
"""
struct TVLogistic <: ScalarTransform
end
struct TVLogistic <: ScalarTransform end

transform(::TVLogistic, x::Real) = logistic(x)

transform_and_logjac(t::TVLogistic, x::Real) = transform(t, x), logistic_logjac(x)

function inverse(::TVLogistic, x::Number)
logit(x)
end

inverse_and_logjac(t::TVLogistic, x::Number) = inverse(t, x), logit_logjac(x)

"""
$(TYPEDEF)

Shift transformation `x ↦ x + shift`.
Shift transformation `x ↦ x + shift`.
"""
struct TVShift{T <: Real} <: ScalarTransform
shift::T
end

transform(t::TVShift, x::Real) = x + t.shift

transform_and_logjac(t::TVShift, x::Real) = transform(t, x), logjac_zero(LogJac(), typeof(x))

inverse(t::TVShift, x::Number) = x - t.shift

inverse_and_logjac(t::TVShift, x::Number) = inverse(t, x), logjac_zero(LogJac(), typeof(x))

"""
Expand All @@ -117,12 +124,15 @@ struct TVScale{T} <: ScalarTransform
new(scale)
end
end

TVScale(scale::T) where {T} = TVScale{T}(scale)

transform(t::TVScale, x::Real) = t.scale * x
transform_and_logjac(t::TVScale{<:Real}, x::Real) = transform(t, x), log(t.scale)

transform_and_logjac(t::TVScale{<:Real}, x::Real) = transform(t, x), log(t.scale)

inverse(t::TVScale, x::Number) = x / t.scale

inverse_and_logjac(t::TVScale{<:Real}, x::Number) = inverse(t, x), -log(t.scale)

"""
Expand Down Expand Up @@ -155,15 +165,15 @@ struct CompositeScalarTransform{Ts <: Tuple} <: ScalarTransform
end

transform(t::CompositeScalarTransform, x) = foldr(transform, t.transforms, init=x)
function transform_and_logjac(ts::CompositeScalarTransform, x)
function transform_and_logjac(ts::CompositeScalarTransform, x)
foldr(ts.transforms, init=(x, logjac_zero(LogJac(), typeof(x)))) do t, (x, logjac)
nx, nlogjac = transform_and_logjac(t, x)
(nx, logjac + nlogjac)
end
end

inverse(ts::CompositeScalarTransform, x) = foldl((y, t) -> inverse(t, y), ts.transforms, init=x)
function inverse_and_logjac(ts::CompositeScalarTransform, x)
function inverse_and_logjac(ts::CompositeScalarTransform, x)
foldl(ts.transforms, init=(x, logjac_zero(LogJac(), typeof(x)))) do (x, logjac), t
nx, nlogjac = inverse_and_logjac(t, x)
(nx, logjac + nlogjac)
Expand Down Expand Up @@ -283,7 +293,7 @@ function Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVShift{T}, TVNeg,
print(io, "as(Real, -∞, ", ct.transforms[1].shift, ")")
end
function Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVShift{T1}, TVScale{T2}, TVLogistic}}) where {T1, T2}
print(io, "as(Real, ", ct.transforms[1].shift, ", ", ct.transforms[1].shift +
print(io, "as(Real, ", ct.transforms[1].shift, ", ", ct.transforms[1].shift +
ct.transforms[2].scale, ")")
end

Expand Down
15 changes: 12 additions & 3 deletions src/special_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,10 @@ function transform_with(flag::LogJacFlag, t::UnitVector, x::AbstractVector, inde
y, ℓ, index
end

inverse_eltype(t::UnitVector, y::AbstractVector) = robust_eltype(y)
function inverse_eltype(t::UnitVector,
::Type{T}) where T <: AbstractVector
_ensure_float(eltype(T))
end

function inverse_at!(x::AbstractVector, index, t::UnitVector, y::AbstractVector)
(; n) = t
Expand Down Expand Up @@ -157,7 +160,10 @@ function transform_with(flag::LogJacFlag, t::UnitSimplex, x::AbstractVector, ind
y, ℓ, index
end

inverse_eltype(t::UnitSimplex, y::AbstractVector) = robust_eltype(y)
function inverse_eltype(t::UnitSimplex,
::Type{T}) where T <: AbstractVector
_ensure_float(eltype(T))
end

function inverse_at!(x::AbstractVector, index, t::UnitSimplex, y::AbstractVector)
(; n) = t
Expand Down Expand Up @@ -297,7 +303,10 @@ function transform_with(flag::LogJacFlag, transformation::StaticCorrCholeskyFact
UpperTriangular(SMatrix{S,S}(U)), ℓ, index′
end

inverse_eltype(t::Union{CorrCholeskyFactor,StaticCorrCholeskyFactor}, U::UpperTriangular) = robust_eltype(U)
function inverse_eltype(t::Union{CorrCholeskyFactor,StaticCorrCholeskyFactor},
::Type{T}) where {T<:UpperTriangular}
_ensure_float(eltype(T))
end

function inverse_at!(x::AbstractVector, index,
t::Union{CorrCholeskyFactor,StaticCorrCholeskyFactor}, U::UpperTriangular)
Expand Down
28 changes: 28 additions & 0 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,31 @@
end

robust_eltype(x::T) where T = robust_eltype(T)

"""
$(SIGNATURES)

Regularize input type, preferring a floating point, falling back to `Float64`.

Internal, not exported.

# Motivation

Type calculations occasionally give types that are too narrow (eg `Union{}` for empty
vectors) or broad. Since this package is primarily intended for *numerical*
calculations, we fall back to something sensible. This function implements the
heuristics for this, and is currently used in inverse element type calculations.
"""
function _ensure_float(::Type{T}) where T
if T <: Number # heuristic: it is assumed that every `Number` type defines `float`
return float(T)
else
return Float64

Check warning on line 59 in src/utilities.jl

View check run for this annotation

Codecov / codecov/patch

src/utilities.jl#L59

Added line #L59 was not covered by tests
end
end

# pass through containers
_ensure_float(::Type{T}) where {T<:AbstractArray} = T

# special case Union{}
_ensure_float(::Type{Union{}}) = Float64
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
Expand Down
Loading
Loading