-
-
Notifications
You must be signed in to change notification settings - Fork 59
Pr 451 - modified and added tests to statespace #466
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Hey @jessegrabowski, did I push this incorrectly? I assume this was not suppose to create a new pull request. How can I fix this? Jesse, in terms of the exogenous forecasts, I am still experiencing an issue where an assertion error:
The odd thing is that if I specify
Any idea what is going on? |
Looks like you did it right! You can see your commit on the git history now. Oh, I see that we're in a new PR. That's fine, I'll close the other one and we can continue work here :) For future reference, you need to make sure you checkout my actual branch, from my fork of extras. Then you can directly push into it, and it will show up where you expect. If you use pycharm, you can check out PR branches directly from the git sidebar, which is handy. Can you give a fuller traceback on the JAX error? I've run into this one before, but I need to know the context to remember the solution. |
@jessegrabowski does this help?
|
Ah okay, I checkedout the PR from inside of a fork I made of pymc-extras. I will make sure to do it the right way moving forward! |
Actually you did it the right way, it's very unusal to push into someone else's PR. But it's all good! |
Okay, @jessegrabowski! I figured out the JAX issue. It was unhappy because I was building the graph with JAX and then trying to sample the model using the pymc sampler. When I use the numpyro sampler it runs without any issues. Okay, so now is the weird thing. With JAX and numpyro everything works including the exogenous forecasts. However, if you build the graph with the default mode and sample with native pymc, then the forecasts will fail and return the above assertion error about the first dimension of the time varying matrix that I posted above. I have been digging through the code trying to find where exactly this issue is arising but I am struggling to pinpoint the location. Do you have any hypotheses to where this might be originating? |
The assertion gets added here:
I'll copy what I was worried about with set data from the other thread, because it's relevant here. Need to make sure we're computing the initial states of the forecast with the old data, then changing it: I made this change then undid it because I thought it wasn't doing the right thing. I need to double-check, but the forecasting logic basically goes like this: So I think an important test is to make sure that t=0 of the forecast always matches the "data" in the provided hidden state. That will let us know if doing the I bring this up because the |
for name in self.data_names: | ||
if name in scenario.keys(): | ||
pm.set_data( | ||
{"data": np.zeros((len(forecast_index), self.k_endog))}, | ||
coords={"data_time": np.arange(len(forecast_index))}, | ||
) | ||
break |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added this logic to update the static shape of the target variable when forecasting. I realize that this logic is naive in a few aspects:
1). This is making an assumption on how the scenario data is constructed (I think I resolved that)
2). The timing of when this is being called may be inappropriate
3). Probably other things that I am not thinking of right now
With your suggestions I can make this more robust, I just wanted to confirm that my suspicion that the issue is that the static shape of the target needs to be updated to reflect the shape of the forecast index?
EDIT:
Sorry about the multiple pings, I didn't highlight all of the code.
I should also mention that these do pass the unit tests in test_statespace.py
I reduced the complexity of the tests that involve testing exogenous forecasting.