Merge statespace module from http://github.com/jessegrabowski/pymc_statespace #174
Couple of small comments
Also need to add relevant entries in the docs
…equire data in `__init__`
Big refactor done. Maybe have a look at the Making a custom statespace notebook to see how things work now. It's way better, I think. First, it's the right way to do it in pytensor -- previously I was just copying a numpy implementation. Second, it gets rid of all the numpy internals in the

I think the tests will pass on this next run (the irony of a commit called "all tests pass" failing CI is not lost on me; pride goeth before the fall), and then I would ask for this to be approved for merging. Then I'll open PRs for the notebooks to go into pymc-examples, and for SolveDiscreteARE to go to pytensor. I'll also open an issue with a checklist of "to do" things still outstanding that I can work on (and maybe even drum up some contributors for).
Tweak tolerances for test in `test_structural.py`
Tests should pass once https://github.com/pymc-devs/pymc/releases/tag/v5.7.2 gets picked up by the CI
The integration with PyMC models looks sweet.
As is usual for big PRs I left a comment about an unimportant part of the code!
I suggest you run codecov locally (we don't seem to have it in the CI here :( ), just to see if there are some important lines that aren't being tested.
Let me know if there's some part that you want me to inspect more carefully
Also maybe have a look at how I handle missing values/data registration in
Add tests for impulse_response_function and forecast
Add tests for SARIMAX in "interpretable" mode
That's an artifact of the fact that our classes are just functions in disguise. Not a pattern that we should try to stick with unless necessary. pymc-devs/pymc#5308
As I mentioned above, I think the clean solution is to store a boolean nan/non-nan mask, and then the unraveled non-nan entries. Not sure how helpful, but you can have a look at the more recent
I could only have a superficial look, not being familiar enough with this kind of model and math :( The Asserts/SpecifyShapes should be pretty cheap / removable during compilation. The transpositions/concatenations are a tax we pay for every time series at the moment. Hopefully they aren't crazy expensive. Transpositions are as cheap as they get; what matters is how data layout plays into subsequent loops, but that's a bit too low-level an optimization to look at before we see unreasonable bottlenecks. Concatenations may or may not require copying arrays, I don't remember now. I don't see anything obvious from my myopic perspective :D
Add reference to Harvey (1989) in SARIMA docstring
Re-run notebooks
This is the promised statespace PR. It's a WIP, for several reasons, but it already represents an MVP that can be used. I'll try to explain what I was going for with everything here. I expect that a lot needs to be changed, though, and I look forward to working hard to improve it!
Overall design goal
Statespace models are a pair of linear equations of the form:

$$x_{t+1} = T_t x_t + c_t + R_t \epsilon_t, \quad \epsilon_t \sim N(0, Q_t)$$

$$y_t = Z_t x_t + d_t + \eta_t, \quad \eta_t \sim N(0, H_t)$$

The first equation is called the "transition equation" and the second equation is called the "observation equation". The state vector $x_t$ contains both observed and unobserved states, while the observation vector $y_t$ contains only the states observed in the data.
Since everything is linear and Gaussian, the posterior distribution over the observed states is just a multivariate normal. In fact, it can be shown that this is a special case of a Gaussian process. But since we might have many states or many time steps, we can compute the posterior mean and covariance more efficiently via recursion. This is the Kalman filter.
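The recursion can be sketched in plain numpy (a minimal, time-invariant sketch of the standard predict/update step; the intercepts c and d are dropped for brevity, and the names are mine, not the actual implementation's):

```python
import numpy as np

def kalman_step(x, P, y, T, Z, Q, H):
    """One predict/update step of the standard Kalman filter.

    x, P : prior state mean and covariance
    y    : observation at this time step
    T, Z : transition and observation matrices
    Q, H : state and observation noise covariances
    """
    # Predict: push the state and its covariance through the transition equation
    x_pred = T @ x
    P_pred = T @ P @ T.T + Q

    # Update: correct the prediction with the observation
    v = y - Z @ x_pred                   # innovation
    F = Z @ P_pred @ Z.T + H             # innovation covariance
    K = P_pred @ Z.T @ np.linalg.inv(F)  # Kalman gain
    x_filt = x_pred + K @ v
    P_filt = P_pred - K @ Z @ P_pred
    return x_filt, P_filt

# One step of a toy local-level model
x_f, P_f = kalman_step(
    x=np.array([0.0]), P=np.array([[1.0]]), y=np.array([1.0]),
    T=np.array([[1.0]]), Z=np.array([[1.0]]),
    Q=np.array([[0.5]]), H=np.array([[0.5]]),
)
```

Iterating this step over the data, accumulating the Gaussian log-likelihood of each innovation, is what produces the logp.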
Sidebar 1: Why not just scan?
Given the advancements in scan lately, it's fair to ask why we bother with filtering at all, rather than just letting PyMC automatically infer the logp of these two equations. This might be possible, and it's something I'd like to explore. It would also let us push out beyond Gaussian errors, which would be great. There was already a user on the forum looking for Poisson-distributed observations, for example.
For now it's not possible, because $Q$ might not be full rank -- that is, not all states in $x_t$ need to be stochastic. We could solve this with clever slicing, though. A second problem is that the quantity $R_t\epsilon_t$ is not measurable. For Gaussian errors this isn't a problem, because we can just fold it into the covariance matrix and write $x_{t+1} = T_tx_t + \epsilon_t, \quad \epsilon_t \sim N(0, R_tQ_tR_t^T)$, but this identity doesn't hold generally.
Sidebar over
Basically, this PR has the following goals:
Modules in this PR
I will briefly introduce what I've done in this PR, and try to justify my choices. I am confident they can be improved.
Core
The core module is responsible for representing arbitrary statespace models. There are two files, `representation.py` and `statespace.py`.

representation.py

This is the lowest-level object in the module. It is responsible for initializing and storing the matrices c, d, T, Z, R, H, Q, along with the initial state and covariance x0 and P0. The module overloads `__getitem__` and `__setitem__` to allow the user to slice into matrices by name; for example, `mod['transition', 0, 0]` gets the `[0, 0]` position of the `T` matrix.

In addition, it also has a bunch of checks and logic to handle time-varying matrices. If a user wants the statespace matrices to vary over time, it is necessary to duplicate them and store the whole stack to scan over later. If the model is not time-varying, it automatically slices around the time dimension so the user is never confronted with it.

Users should never have to touch this; it's all just low-level machinery.
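The name-based slicing can be illustrated with a toy stand-in (this `Representation` class is a hypothetical sketch for illustration, not the actual implementation, and it ignores the time-varying machinery entirely):

```python
import numpy as np

class Representation:
    """Toy stand-in: stores statespace matrices and lets you index
    into them by name, as in mod['transition', 0, 0]."""

    def __init__(self, k_states, k_obs):
        self.matrices = {
            'transition': np.zeros((k_states, k_states)),  # T
            'design': np.zeros((k_obs, k_states)),         # Z
            'state_cov': np.zeros((k_states, k_states)),   # Q
            'obs_cov': np.zeros((k_obs, k_obs)),           # H
        }

    def __getitem__(self, key):
        if isinstance(key, str):          # mod['design'] -> whole matrix
            return self.matrices[key]
        name, *idx = key                  # mod['transition', 0, 0] -> one entry
        return self.matrices[name][tuple(idx)]

    def __setitem__(self, key, value):
        if isinstance(key, str):
            self.matrices[key] = value
        else:
            name, *idx = key
            self.matrices[name][tuple(idx)] = value

mod = Representation(k_states=2, k_obs=1)
mod['transition', 0, 0] = 0.9   # set T[0, 0] by name
```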
statespace.py

This is the base class for all statespace models. It's responsible for combining the statespace matrices with a Kalman filter to make a logp graph. This is accomplished by the `gather_required_random_variables`, `update`, and `build_statespace_graph` methods. Only `update` will vary between models, and needs to be implemented. `update` is responsible for taking a flat vector of parameters and shuttling them to the correct places in the statespace matrices.

The property `param_names` also needs to be set for each model. This defines the names of the parameters that `gather_required_random_variables` will look for in the PyMC model. It also defines the order of the flat parameter vector that will be passed into `update`.
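The `param_names`/`update` contract might look something like this (a hypothetical local-level sketch in plain numpy, not the actual pytensor base class; the cursor-based slicing mimics how a flat vector gets carved up in order):

```python
import numpy as np

class LocalLevelSketch:
    """Hypothetical sketch: `param_names` fixes the order of the flat
    parameter vector, and `update` shuttles each entry into its slot
    in the statespace matrices."""

    param_names = ['sigma_obs', 'sigma_level']

    def __init__(self):
        self.T = np.array([[1.0]])  # transition
        self.Z = np.array([[1.0]])  # design
        self.Q = np.zeros((1, 1))   # state covariance
        self.H = np.zeros((1, 1))   # observation covariance

    def update(self, theta):
        # theta is flat, ordered exactly as in param_names;
        # walk a cursor through it and place each piece.
        cursor = 0
        sigma_obs = theta[cursor]; cursor += 1
        sigma_level = theta[cursor]; cursor += 1
        self.H[0, 0] = sigma_obs ** 2
        self.Q[0, 0] = sigma_level ** 2

mod = LocalLevelSketch()
mod.update(np.array([0.5, 0.1]))
```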
Filters
Filters holds all the implementations for the Kalman Filter.
distributions.py
Work in progress. Eventually, this should implement a PyMC distribution wrapper around the Kalman filter, so that we can directly sample from it. There are just a lot of wrinkles to iron out.
kalman_filter.py
The actual Kalman filters. Currently there are five implemented: standard, univariate, single time series, Cholesky, and steady state. These need to be benchmarked against each other. Single time series should be used when there is only a single observed state (e.g. ARIMA); otherwise use standard. Cholesky is supposed to be faster, but in my limited testing it's not; not sure why. In principle the univariate filter is the most robust, but it has a scan within a scan, so it's quite slow. I haven't benchmarked it in JAX, though.
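The steady-state trick is worth spelling out: the predicted covariance converges to a fixed point of the discrete algebraic Riccati equation, so the gain (and the matrix inverse inside it) can be computed once and reused at every time step. A sketch using scipy's reference solver (toy local-level matrices; the mapping of filter matrices onto scipy's control-theoretic arguments is the standard one):

```python
import numpy as np
from scipy.linalg import solve_discrete_are

# Toy local-level model: scalar state, scalar observation
T = np.array([[1.0]])   # transition
Z = np.array([[1.0]])   # design
Q = np.array([[0.1]])   # state noise covariance
H = np.array([[1.0]])   # observation noise covariance

# Steady-state predicted covariance solves the DARE; note the
# transposes mapping the filtering problem onto scipy's convention.
P_steady = solve_discrete_are(T.T, Z.T, Q, H)

# With P fixed, the gain and its inverse are computed once, then
# reused for every time step of the filter.
F = Z @ P_steady @ Z.T + H
K = P_steady @ Z.T @ np.linalg.inv(F)
```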
kalman_smoother.py
Pytensor implementation of the Kalman smoother. Good for hidden state inference, but split out because not all users will need it. It's a post-estimation thing.
numpy_filter.py
This is a re-implementation of the cholesky kalman filter in pure numpy. Potentially for use in
distribution.py
. I'm not sure whatpm.draw
should return when called on a statespace distribution, so it's there as an option.utilities.py
Shared functions between modules. Currently just holds a helper function to sort scan inputs into sequence and non-sequence. This is needed because if matrices are time-varying, they are sequences, otherwise they are non-sequences. I guess I could always make them sequences and just copy the matrix a bunch of times if its not time-varying, but this is more memory efficient (cope because I already sunk a lot of time into doing it this way?)
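The sorting logic amounts to something like this (a hypothetical helper for illustration, keying on dimensionality: a time-varying matrix carries one extra leading time axis over its core shape):

```python
import numpy as np

def sort_scan_inputs(matrices, core_ndims):
    """Split statespace matrices into scan sequences and non-sequences.

    A matrix with one more dimension than its core shape is assumed to
    be stacked over time and becomes a scan sequence; otherwise it is
    passed as a non-sequence. (Illustrative sketch only.)
    """
    sequences, non_sequences = [], []
    for matrix, core_ndim in zip(matrices, core_ndims):
        if matrix.ndim == core_ndim + 1:
            sequences.append(matrix)
        else:
            non_sequences.append(matrix)
    return sequences, non_sequences

T = np.eye(2)             # static transition matrix: non-sequence
Z = np.ones((100, 1, 2))  # time-varying design matrix: sequence
seqs, non_seqs = sort_scan_inputs([T, Z], core_ndims=[2, 2])
```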
Models
This module will hold actual implementations of state space models that users can call. Right now they are fully contained models, following the setup of statsmodels, but I could imagine a better, more modular API. Right now I have VARMA, ARIMA, and local level.

There's also a utility file for shared functions; right now it's a little function that's used in the `update` function to keep track of slicing up the flat parameter vector.

Utils
A hodge-podge of stuff.
numba_linalg.py

This holds numba implementations of linear algebra routines that don't have an existing numba overload. Currently it's just `scipy.linalg.block_diag`.

pytensor_scipy.py
This holds a pytensor `Op` for solving the discrete algebraic Riccati equation, which should be split off and pulled into `pytensor.tensor.slinalg`. It is currently used in the SteadyStateFilter, and can save a lot of time by pre-computing and re-using a single matrix inverse for all time steps in the Kalman filter. It doesn't have a jaxified version, though, so it's not actually that useful right now. Solving AREs is useful in general, though.

simulation.py
This holds numbafied routines for posterior predictive simulation. Since the current implementation returns a `pm.Potential` for the logp, it's not possible to use the usual posterior predictive sampling machinery, so I resorted to this. I hope it can be removed in the future. It contains separate functions for conditional and unconditional simulation.

Unconditional simulation just applies the observation and state transition equations to an initial state. This is useful for computing theoretical moments of the system, and also for forecasting.
Conditional simulation draws statespace matrices and runs the data through the kalman filter (and, optionally, smoother). This is useful for hidden state inference and missing data interpolation.
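Unconditional simulation is just a loop over the two statespace equations (a plain numpy sketch assuming time-invariant matrices; the real routines are numbafied and vectorized over posterior draws):

```python
import numpy as np

def unconditional_sim(x0, T, Z, R, Q, H, n_steps, rng):
    """Simulate states and observations forward from x0, with no
    conditioning on observed data."""
    k_states, k_obs, k_posdef = T.shape[0], Z.shape[0], Q.shape[0]
    xs = np.empty((n_steps, k_states))
    ys = np.empty((n_steps, k_obs))
    x = x0
    for t in range(n_steps):
        # transition equation: x_{t+1} = T x_t + R eps_t, eps ~ N(0, Q)
        x = T @ x + R @ rng.multivariate_normal(np.zeros(k_posdef), Q)
        # observation equation: y_t = Z x_t + eta_t, eta ~ N(0, H)
        ys[t] = Z @ x + rng.multivariate_normal(np.zeros(k_obs), H)
        xs[t] = x
    return xs, ys

rng = np.random.default_rng(0)
xs, ys = unconditional_sim(
    x0=np.zeros(1),
    T=np.array([[0.9]]), Z=np.array([[1.0]]), R=np.array([[1.0]]),
    Q=np.array([[0.1]]), H=np.array([[0.05]]),
    n_steps=50, rng=rng,
)
```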
Summary and To Do
Basically that's it. There's still a lot to do, but this is at least a start. I hope it can be useful to the community, and we can get it to be a super fast, reliable alternative to the statsmodels statespace module.
Here's a quick, non-exhaustive list of to-dos:
- Replace the `pm.Potential` term with a distribution wrapper around the Kalman filter
- Design a more modular model-building API, e.g. `mod = seasonal_part + trend_part + arima_part`