Skip to content

Commit 2f7e03c

Browse files
committed
Add logp parametrization for Categorical
1 parent 28b5539 commit 2f7e03c

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

src/parameterized/categorical.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,22 @@ ncategories(d::Categorical) = length(d.p)
1919

2020
logdensity(d::Categorical{(:p)}, y) = log(d.p[y])
2121

22-
# Very inefficient because of the heavy implementation of Dists.DiscreteNonParametric
22+
# The implementation of Dists.DiscreteNonParametric has heavy argument checks
23+
# But I think since the values of Categorical are 1:n the sortperm has no effect
24+
# So it might be OK
2325
distproxy(d::Categorical{(:p)}) = Dists.Categorical(d.p)
2426

2527
Base.rand(rng::AbstractRNG, T::Type, d::Categorical{(:p)}) = rand(rng, distproxy(d))
2628

2729
asparams(::Type{<:Categorical}, ::Val{:p}) = as𝕀
30+
31+
###############################################################################
32+
@kwstruct Categorical(logp)
33+
34+
logdensity(d::Categorical{(:logp)}, y) = d.logp[y]
35+
36+
distproxy(d::Categorical{(:logp)}) = Dists.Categorical(exp.(d.logp)) # inefficient
37+
38+
Base.rand(rng::AbstractRNG, T::Type, d::Categorical{(:logp)}) = rand(rng, distproxy(d))
39+
40+
asparams(::Type{<:Categorical}, ::Val{:logp}) = asℝ

0 commit comments

Comments
 (0)