Skip to content

Commit c86a34e

Browse files
committed
Add a test for getstepsize()
1 parent 406c2ad commit c86a34e

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

test/mcmc/hmc.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,25 @@ using Turing
329329
@test pvalue(ApproximateTwoSampleKSTest(vec(results), vec(results_prior))) > 0.001
330330
end
331331

332+
@testset "getstepsize: Turing.jl#2400" begin
333+
algs = [
334+
HMC(0.1, 10),
335+
HMCDA(0.8, 0.75),
336+
NUTS(0.5),
337+
NUTS(0, 0.5),
338+
]
339+
@testset "$(alg)" for alg in algs
340+
spl = Sampler(alg, gdemo_default)
341+
hmc_state = DynamicPPL.initialstep(
342+
Random.default_rng(),
343+
gdemo_default,
344+
spl,
345+
DynamicPPL.VarInfo(gdemo_default)
346+
)[2]
347+
@test Turing.Inference.getstepsize(spl, hmc_state) isa Float64
348+
end
349+
end
350+
332351
@testset "Check ADType" begin
333352
alg = HMC(0.1, 10; adtype=adbackend)
334353
m = DynamicPPL.contextualize(

0 commit comments

Comments
 (0)