@@ -30,10 +30,10 @@ function inverse_at!(x::AbstractVector, index::Int, t::ScalarTransform, y::Real)
30
30
index + 1
31
31
end
32
32
33
- function inverse_eltype (t:: ScalarTransform , y :: Real )
33
+ function inverse_eltype (t:: ScalarTransform , :: Type{T} ) where T <: Real
34
34
# NOTE this is a shortcut to get sensible types for all subtypes of ScalarTransform, which
35
35
# we test for. If it breaks it should be extended accordingly.
36
- return Base. promote_typejoin_union (Base. promote_op (inverse, typeof (t), typeof (y) ))
36
+ return Base. promote_typejoin_union (Base. promote_op (inverse, typeof (t), T ))
37
37
end
38
38
39
39
_domain_label (:: ScalarTransform , index:: Int ) = ()
@@ -66,43 +66,50 @@ $(TYPEDEF)
66
66
67
67
Exponential transformation `x ↦ eˣ`. Maps from all reals to the positive reals.
68
68
"""
69
- struct TVExp <: ScalarTransform
70
- end
69
+ struct TVExp <: ScalarTransform end
70
+
71
71
transform (:: TVExp , x:: Real ) = exp (x)
72
+
72
73
transform_and_logjac (t:: TVExp , x:: Real ) = transform (t, x), x
73
74
74
75
function inverse (:: TVExp , x:: Number )
75
76
log (x)
76
77
end
78
+
77
79
inverse_and_logjac (t:: TVExp , x:: Number ) = inverse (t, x), - log (x)
78
80
79
81
"""
80
82
$(TYPEDEF)
81
83
82
84
Logistic transformation `x ↦ logit(x)`. Maps from all reals to (0, 1).
83
85
"""
84
- struct TVLogistic <: ScalarTransform
85
- end
86
+ struct TVLogistic <: ScalarTransform end
87
+
86
88
transform (:: TVLogistic , x:: Real ) = logistic (x)
89
+
87
90
transform_and_logjac (t:: TVLogistic , x:: Real ) = transform (t, x), logistic_logjac (x)
88
91
89
92
function inverse (:: TVLogistic , x:: Number )
90
93
logit (x)
91
94
end
95
+
92
96
inverse_and_logjac (t:: TVLogistic , x:: Number ) = inverse (t, x), logit_logjac (x)
93
97
94
98
"""
95
99
$(TYPEDEF)
96
100
97
- Shift transformation `x ↦ x + shift`.
101
+ Shift transformation `x ↦ x + shift`.
98
102
"""
99
103
struct TVShift{T <: Real } <: ScalarTransform
100
104
shift:: T
101
105
end
106
+
102
107
transform (t:: TVShift , x:: Real ) = x + t. shift
108
+
103
109
transform_and_logjac (t:: TVShift , x:: Real ) = transform (t, x), logjac_zero (LogJac (), typeof (x))
104
110
105
111
inverse (t:: TVShift , x:: Number ) = x - t. shift
112
+
106
113
inverse_and_logjac (t:: TVShift , x:: Number ) = inverse (t, x), logjac_zero (LogJac (), typeof (x))
107
114
108
115
"""
@@ -117,12 +124,15 @@ struct TVScale{T} <: ScalarTransform
117
124
new (scale)
118
125
end
119
126
end
127
+
120
128
TVScale (scale:: T ) where {T} = TVScale {T} (scale)
121
129
122
130
transform (t:: TVScale , x:: Real ) = t. scale * x
123
- transform_and_logjac (t:: TVScale{<:Real} , x:: Real ) = transform (t, x), log (t. scale)
131
+
132
+ transform_and_logjac (t:: TVScale{<:Real} , x:: Real ) = transform (t, x), log (t. scale)
124
133
125
134
inverse (t:: TVScale , x:: Number ) = x / t. scale
135
+
126
136
inverse_and_logjac (t:: TVScale{<:Real} , x:: Number ) = inverse (t, x), - log (t. scale)
127
137
128
138
"""
@@ -155,15 +165,15 @@ struct CompositeScalarTransform{Ts <: Tuple} <: ScalarTransform
155
165
end
156
166
157
167
transform (t:: CompositeScalarTransform , x) = foldr (transform, t. transforms, init= x)
158
- function transform_and_logjac (ts:: CompositeScalarTransform , x)
168
+ function transform_and_logjac (ts:: CompositeScalarTransform , x)
159
169
foldr (ts. transforms, init= (x, logjac_zero (LogJac (), typeof (x)))) do t, (x, logjac)
160
170
nx, nlogjac = transform_and_logjac (t, x)
161
171
(nx, logjac + nlogjac)
162
172
end
163
173
end
164
174
165
175
inverse (ts:: CompositeScalarTransform , x) = foldl ((y, t) -> inverse (t, y), ts. transforms, init= x)
166
- function inverse_and_logjac (ts:: CompositeScalarTransform , x)
176
+ function inverse_and_logjac (ts:: CompositeScalarTransform , x)
167
177
foldl (ts. transforms, init= (x, logjac_zero (LogJac (), typeof (x)))) do (x, logjac), t
168
178
nx, nlogjac = inverse_and_logjac (t, x)
169
179
(nx, logjac + nlogjac)
@@ -283,7 +293,7 @@ function Base.show(io::IO, ct::CompositeScalarTransform{Tuple{TVShift{T}, TVNeg,
283
293
print (io, " as(Real, -∞, " , ct. transforms[1 ]. shift, " )" )
284
294
end
285
295
function Base. show (io:: IO , ct:: CompositeScalarTransform{Tuple{TVShift{T1}, TVScale{T2}, TVLogistic}} ) where {T1, T2}
286
- print (io, " as(Real, " , ct. transforms[1 ]. shift, " , " , ct. transforms[1 ]. shift +
296
+ print (io, " as(Real, " , ct. transforms[1 ]. shift, " , " , ct. transforms[1 ]. shift +
287
297
ct. transforms[2 ]. scale, " )" )
288
298
end
289
299
0 commit comments