Skip to content

Commit f632a34

Browse files
Type overloading for return_inferencedata in pm.sample() (#6709)
1 parent 9b72c2e commit f632a34

File tree

1 file changed

+71
-1
lines changed

1 file changed

+71
-1
lines changed

pymc/sampling/mcmc.py

+71-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,18 @@
2121
import warnings
2222

2323
from collections import defaultdict
24-
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
24+
from typing import (
25+
Any,
26+
Dict,
27+
Iterator,
28+
List,
29+
Literal,
30+
Optional,
31+
Sequence,
32+
Tuple,
33+
Union,
34+
overload,
35+
)
2536

2637
import numpy as np
2738
import pytensor.gradient as tg
@@ -318,6 +329,65 @@ def _sample_external_nuts(
318329
)
319330

320331

332+
@overload
333+
def sample(
334+
draws: int = 1000,
335+
*,
336+
tune: int = 1000,
337+
chains: Optional[int] = None,
338+
cores: Optional[int] = None,
339+
random_seed: RandomState = None,
340+
progressbar: bool = True,
341+
step=None,
342+
nuts_sampler: str = "pymc",
343+
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
344+
init: str = "auto",
345+
jitter_max_retries: int = 10,
346+
n_init: int = 200_000,
347+
trace: Optional[TraceOrBackend] = None,
348+
discard_tuned_samples: bool = True,
349+
compute_convergence_checks: bool = True,
350+
keep_warning_stat: bool = False,
351+
return_inferencedata: Literal[True],
352+
idata_kwargs: Optional[Dict[str, Any]] = None,
353+
nuts_sampler_kwargs: Optional[Dict[str, Any]] = None,
354+
callback=None,
355+
mp_ctx=None,
356+
**kwargs,
357+
) -> InferenceData:
358+
...
359+
360+
361+
@overload
362+
def sample(
363+
draws: int = 1000,
364+
*,
365+
tune: int = 1000,
366+
chains: Optional[int] = None,
367+
cores: Optional[int] = None,
368+
random_seed: RandomState = None,
369+
progressbar: bool = True,
370+
step=None,
371+
nuts_sampler: str = "pymc",
372+
initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None,
373+
init: str = "auto",
374+
jitter_max_retries: int = 10,
375+
n_init: int = 200_000,
376+
trace: Optional[TraceOrBackend] = None,
377+
discard_tuned_samples: bool = True,
378+
compute_convergence_checks: bool = True,
379+
keep_warning_stat: bool = False,
380+
return_inferencedata: Literal[False],
381+
idata_kwargs: Optional[Dict[str, Any]] = None,
382+
nuts_sampler_kwargs: Optional[Dict[str, Any]] = None,
383+
callback=None,
384+
mp_ctx=None,
385+
model: Optional[Model] = None,
386+
**kwargs,
387+
) -> MultiTrace:
388+
...
389+
390+
321391
def sample(
322392
draws: int = 1000,
323393
*,

0 commit comments

Comments
 (0)