diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 4ec7b254b8..50abac7b3a 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -21,7 +21,18 @@ import warnings from collections import defaultdict -from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union +from typing import ( + Any, + Dict, + Iterator, + List, + Literal, + Optional, + Sequence, + Tuple, + Union, + overload, +) import numpy as np import pytensor.gradient as tg @@ -318,6 +329,65 @@ def _sample_external_nuts( ) +@overload +def sample( + draws: int = 1000, + *, + tune: int = 1000, + chains: Optional[int] = None, + cores: Optional[int] = None, + random_seed: RandomState = None, + progressbar: bool = True, + step=None, + nuts_sampler: str = "pymc", + initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, + init: str = "auto", + jitter_max_retries: int = 10, + n_init: int = 200_000, + trace: Optional[TraceOrBackend] = None, + discard_tuned_samples: bool = True, + compute_convergence_checks: bool = True, + keep_warning_stat: bool = False, + return_inferencedata: Literal[True], + idata_kwargs: Optional[Dict[str, Any]] = None, + nuts_sampler_kwargs: Optional[Dict[str, Any]] = None, + callback=None, + mp_ctx=None, + **kwargs, +) -> InferenceData: + ... + + +@overload +def sample( + draws: int = 1000, + *, + tune: int = 1000, + chains: Optional[int] = None, + cores: Optional[int] = None, + random_seed: RandomState = None, + progressbar: bool = True, + step=None, + nuts_sampler: str = "pymc", + initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, + init: str = "auto", + jitter_max_retries: int = 10, + n_init: int = 200_000, + trace: Optional[TraceOrBackend] = None, + discard_tuned_samples: bool = True, + compute_convergence_checks: bool = True, + keep_warning_stat: bool = False, + return_inferencedata: Literal[False], + idata_kwargs: Optional[Dict[str, Any]] = None, + nuts_sampler_kwargs: Optional[Dict[str, Any]] = None, + callback=None, + mp_ctx=None, + model: Optional[Model] = None, + **kwargs, +) -> MultiTrace: + ... + + def sample( draws: int = 1000, *,