|
21 | 21 | import warnings
|
22 | 22 |
|
23 | 23 | from collections import defaultdict
|
24 |
| -from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union |
| 24 | +from typing import ( |
| 25 | + Any, |
| 26 | + Dict, |
| 27 | + Iterator, |
| 28 | + List, |
| 29 | + Literal, |
| 30 | + Optional, |
| 31 | + Sequence, |
| 32 | + Tuple, |
| 33 | + Union, |
| 34 | + overload, |
| 35 | +) |
25 | 36 |
|
26 | 37 | import numpy as np
|
27 | 38 | import pytensor.gradient as tg
|
@@ -318,6 +329,65 @@ def _sample_external_nuts(
|
318 | 329 | )
|
319 | 330 |
|
320 | 331 |
|
| 332 | +@overload |
| 333 | +def sample( |
| 334 | + draws: int = 1000, |
| 335 | + *, |
| 336 | + tune: int = 1000, |
| 337 | + chains: Optional[int] = None, |
| 338 | + cores: Optional[int] = None, |
| 339 | + random_seed: RandomState = None, |
| 340 | + progressbar: bool = True, |
| 341 | + step=None, |
| 342 | + nuts_sampler: str = "pymc", |
| 343 | + initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, |
| 344 | + init: str = "auto", |
| 345 | + jitter_max_retries: int = 10, |
| 346 | + n_init: int = 200_000, |
| 347 | + trace: Optional[TraceOrBackend] = None, |
| 348 | + discard_tuned_samples: bool = True, |
| 349 | + compute_convergence_checks: bool = True, |
| 350 | + keep_warning_stat: bool = False, |
| 351 | + return_inferencedata: Literal[True], |
| 352 | + idata_kwargs: Optional[Dict[str, Any]] = None, |
| 353 | + nuts_sampler_kwargs: Optional[Dict[str, Any]] = None, |
| 354 | + callback=None, |
| 355 | + mp_ctx=None, |
| 356 | + **kwargs, |
| 357 | +) -> InferenceData: |
| 358 | + ... |
| 359 | + |
| 360 | + |
| 361 | +@overload |
| 362 | +def sample( |
| 363 | + draws: int = 1000, |
| 364 | + *, |
| 365 | + tune: int = 1000, |
| 366 | + chains: Optional[int] = None, |
| 367 | + cores: Optional[int] = None, |
| 368 | + random_seed: RandomState = None, |
| 369 | + progressbar: bool = True, |
| 370 | + step=None, |
| 371 | + nuts_sampler: str = "pymc", |
| 372 | + initvals: Optional[Union[StartDict, Sequence[Optional[StartDict]]]] = None, |
| 373 | + init: str = "auto", |
| 374 | + jitter_max_retries: int = 10, |
| 375 | + n_init: int = 200_000, |
| 376 | + trace: Optional[TraceOrBackend] = None, |
| 377 | + discard_tuned_samples: bool = True, |
| 378 | + compute_convergence_checks: bool = True, |
| 379 | + keep_warning_stat: bool = False, |
| 380 | + return_inferencedata: Literal[False], |
| 381 | + idata_kwargs: Optional[Dict[str, Any]] = None, |
| 382 | + nuts_sampler_kwargs: Optional[Dict[str, Any]] = None, |
| 383 | + callback=None, |
| 384 | + mp_ctx=None, |
| 385 | + model: Optional[Model] = None, |
| 386 | + **kwargs, |
| 387 | +) -> MultiTrace: |
| 388 | + ... |
| 389 | + |
| 390 | + |
321 | 391 | def sample(
|
322 | 392 | draws: int = 1000,
|
323 | 393 | *,
|
|
0 commit comments