PR 451 - modified and added tests to statespace #466


Open · wants to merge 7 commits into base: main

Conversation

Dekermanjian

I reduced the complexity of the tests that involve testing exogenous forecasting.

@Dekermanjian
Author

Hey @jessegrabowski, did I push this incorrectly? I assume this was not supposed to create a new pull request. How can I fix this?

Jesse, in terms of the exogenous forecasts, I am still running into an assertion error:

AssertionError: The first dimension of a time varying matrix (the time dimension) must be equal to the first dimension of the data (the time dimension)

The odd thing is that if I specify mode="JAX" in build_statespace_graph() I don't get that error and the test passes. However, it only passes because we are sampling the prior. If you try to sample the model, then you get an error:

AttributeError: 'jaxlib.xla_extension.ArrayImpl' object has no attribute 'data'

Any idea what is going on?
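For context, the invariant that the assertion enforces can be sketched in plain numpy. All shapes here are hypothetical, chosen only to illustrate the check; this is not the pymc-extras code itself:

```python
import numpy as np

# Hypothetical shapes, for illustration only: 10 observed time steps,
# 1 endogenous series, 2 hidden states.
T, k_endog, k_states = 10, 1, 2

data = np.zeros((T, k_endog))         # observed data
Z = np.zeros((T, k_endog, k_states))  # a time-varying design matrix

# The runtime check enforces that the time dimensions line up:
assert Z.shape[0] == data.shape[0], (
    "The first dimension of a time varying matrix (the time dimension) "
    "must be equal to the first dimension of the data"
)
```

One plausible reading of the forecast failure is that one side of this comparison was rebuilt at the forecast length while the other was left at the training length.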

@jessegrabowski
Member

jessegrabowski commented May 1, 2025

Looks like you did it right! You can see your commit on the git history now.

Oh, I see that we're in a new PR. That's fine, I'll close the other one and we can continue work here :)

For future reference, you need to make sure you check out my actual branch, from my fork of pymc-extras. Then you can push directly into it, and it will show up where you expect. If you use PyCharm, you can check out PR branches directly from the git sidebar, which is handy.

Can you give a fuller traceback on the JAX error? I've run into this one before, but I need to know the context to remember the solution.

@Dekermanjian
Author

@jessegrabowski does this help?

self = <pytensor.compile.function.types.Function object at 0x320ce0b90>, output_subset = None, args = (array([-0.34611857, -1.16511424, -0.77565797, -0.09904973]),), kwargs = {}, trust_input = False
vm = <pytensor.link.c.cvm.CVM object at 0x323c00770>, profile = None, arg_container = <array([[1.],
       [1.],
       [1.],
       [1.],
       [1.],
       [0.],
       [1.],
       [0.],
       [1.],
       [1.]])>
arg = array([-0.34611857, -1.16511424, -0.77565797, -0.09904973])

    def __call__(self, *args, output_subset=None, **kwargs):
        """
        Evaluates value of a function on given arguments.
    
        Parameters
        ----------
        args : list
            List of inputs to the function. All inputs are required, even when
            some of them are not necessary to calculate requested subset of
            outputs.
    
        kwargs : dict
            The function inputs can be passed as keyword argument. For this, use
            the name of the input or the input instance as the key.
    
            Keyword argument ``output_subset`` is a list of either indices of the
            function's outputs or the keys belonging to the `output_keys` dict
            and represent outputs that are requested to be calculated. Regardless
            of the presence of ``output_subset``, the updates are always calculated
            and processed. To disable the updates, you should use the ``copy``
            method with ``delete_updates=True``.
    
        Returns
        -------
        list
            List of outputs on indices/keys from ``output_subset`` or all of them,
            if ``output_subset`` is not passed.
        """
        trust_input = self.trust_input
        input_storage = self.input_storage
        vm = self.vm
        profile = self.profile
    
        if profile:
            t0 = time.perf_counter()
    
        if output_subset is not None:
            warnings.warn("output_subset is deprecated.", FutureWarning)
            if self.output_keys is not None:
                output_subset = [self.output_keys.index(key) for key in output_subset]
    
        # Reinitialize each container's 'provided' counter
        if trust_input:
            for arg_container, arg in zip(input_storage, args, strict=False):
                arg_container.storage[0] = arg
        else:
            for arg_container in input_storage:
                arg_container.provided = 0
    
            if len(args) + len(kwargs) > len(input_storage):
                raise TypeError("Too many parameter passed to pytensor function")
    
            # Set positional arguments
            for arg_container, arg in zip(input_storage, args, strict=False):
                # See discussion about None as input
                # https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
                if arg is None:
                    arg_container.storage[0] = arg
                else:
                    try:
                        arg_container.storage[0] = arg_container.type.filter(
                            arg,
                            strict=arg_container.strict,
                            allow_downcast=arg_container.allow_downcast,
                        )
    
                    except Exception as e:
                        i = input_storage.index(arg_container)
                        function_name = "pytensor function"
                        argument_name = "argument"
                        if self.name:
                            function_name += ' with name "' + self.name + '"'
                        if hasattr(arg, "name") and arg.name:
                            argument_name += ' with name "' + arg.name + '"'
                        where = get_variable_trace_string(self.maker.inputs[i].variable)
                        if len(e.args) == 1:
                            e.args = (
                                "Bad input "
                                + argument_name
                                + " to "
                                + function_name
                                + f" at index {int(i)} (0-based). {where}"
                                + e.args[0],
                            )
                        else:
                            e.args = (
                                "Bad input "
                                + argument_name
                                + " to "
                                + function_name
                                + f" at index {int(i)} (0-based). {where}"
                            ) + e.args
                        self._restore_defaults()
                        raise
                arg_container.provided += 1
    
        # Set keyword arguments
        if kwargs:  # for speed, skip the items for empty kwargs
            for k, arg in kwargs.items():
                self[k] = arg
    
        if not trust_input:
            # Collect aliased inputs among the storage space
            for potential_group in self._potential_aliased_input_groups:
                args_share_memory: list[list[int]] = []
                for i in potential_group:
                    i_type = self.maker.inputs[i].variable.type
                    i_val = input_storage[i].storage[0]
    
                    # Check if value is aliased with any of the values in one of the groups
                    for j_group in args_share_memory:
                        if any(
                            i_type.may_share_memory(input_storage[j].storage[0], i_val)
                            for j in j_group
                        ):
                            j_group.append(i)
                            break
                    else:  # no break
                        # Create a new group
                        args_share_memory.append([i])
    
                # Check for groups of more than one argument that share memory
                for group in args_share_memory:
                    if len(group) > 1:
                        # copy all but the first
                        for i in group[1:]:
                            input_storage[i].storage[0] = copy.copy(
                                input_storage[i].storage[0]
                            )
    
            # Check if inputs are missing, or if inputs were set more than once, or
            # if we tried to provide inputs that are supposed to be implicit.
            for arg_container in input_storage:
                if arg_container.required and not arg_container.provided:
                    self._restore_defaults()
                    raise TypeError(
                        f"Missing required input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
                    )
                if arg_container.provided > 1:
                    self._restore_defaults()
                    raise TypeError(
                        f"Multiple values for input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
                    )
                if arg_container.implicit and arg_container.provided > 0:
                    self._restore_defaults()
                    raise TypeError(
                        f"Tried to provide value for implicit input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
                    )
    
        # Do the actual work
        if profile:
            t0_fn = time.perf_counter()
        try:
>           outputs = vm() if output_subset is None else vm(output_subset=output_subset)

/opt/miniconda3/envs/pymc-extras/lib/python3.12/site-packages/pytensor/compile/function/types.py:1037: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/opt/miniconda3/envs/pymc-extras/lib/python3.12/site-packages/pytensor/scan/op.py:1649: in rval
    r = p(n, [x[0] for x in i], o)
/opt/miniconda3/envs/pymc-extras/lib/python3.12/site-packages/pytensor/scan/op.py:1577: in p
    t_fn, n_steps = scan_perform_ext.perform(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

>   ???
E   AttributeError: 'jaxlib.xla_extension.ArrayImpl' object has no attribute 'data'

pytensor/scan/scan_perform.pyx:397: AttributeError

During handling of the above exception, another exception occurred:

exog_pymc_mod = <pymc.model.core.Model object at 0x16e86adb0>, rng = Generator(PCG64) at 0x16E2C9B60

    @pytest.fixture(scope="session")
    def idata_exog(exog_pymc_mod, rng):
        with exog_pymc_mod:
>           idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)

tests/statespace/test_statespace.py:227: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/opt/miniconda3/envs/pymc-extras/lib/python3.12/site-packages/pymc/sampling/mcmc.py:832: in sample
    initial_points, step = init_nuts(
/opt/miniconda3/envs/pymc-extras/lib/python3.12/site-packages/pymc/sampling/mcmc.py:1605: in init_nuts
    initial_points = _init_jitter(
/opt/miniconda3/envs/pymc-extras/lib/python3.12/site-packages/pymc/sampling/mcmc.py:1482: in _init_jitter
    point_logp = model_logp_fn(point)
/opt/miniconda3/envs/pymc-extras/lib/python3.12/site-packages/pymc/sampling/mcmc.py:1603: in model_logp_fn
    return logp_dlogp_func([q], extra_vars={})[0]
/opt/miniconda3/envs/pymc-extras/lib/python3.12/site-packages/pymc/model/core.py:289: in __call__
    return self._pytensor_function(*grad_vars)
/opt/miniconda3/envs/pymc-extras/lib/python3.12/site-packages/pytensor/compile/function/types.py:1047: in __call__
    raise_with_op(
/opt/miniconda3/envs/pymc-extras/lib/python3.12/site-packages/pytensor/link/utils.py:526: in raise_with_op
    raise exc_value.with_traceback(exc_trace)
/opt/miniconda3/envs/pymc-extras/lib/python3.12/site-packages/pytensor/compile/function/types.py:1037: in __call__
    outputs = vm() if output_subset is None else vm(output_subset=output_subset)
/opt/miniconda3/envs/pymc-extras/lib/python3.12/site-packages/pytensor/scan/op.py:1649: in rval
    r = p(n, [x[0] for x in i], o)
/opt/miniconda3/envs/pymc-extras/lib/python3.12/site-packages/pytensor/scan/op.py:1577: in p
    t_fn, n_steps = scan_perform_ext.perform(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

>   ???
E   AttributeError: 'jaxlib.xla_extension.ArrayImpl' object has no attribute 'data'
E   Apply node that caused the error: Scan{grad_of_forward_kalman_pass, while_loop=False, inplace=all}(Composite{...}.10, Composite{...}.0, Subtensor{start:stop:step}.0, Composite{...}.1, Transpose{axes=[0, 2, 1]}.0, Transpose{axes=[0, 2, 1]}.0, ExpandDims{axis=1}.0, Alloc.0, ExpandDims{axis=1}.0, Subtensor{start:stop:step}.0, Subtensor{start:stop:step}.0, Subtensor{start:stop:step}.0, Alloc.0, Subtensor{start:stop:step}.0, Subtensor{start:stop:step}.0, Alloc.0, Alloc.0)
E   Toposort index: 82
E   Inputs types: [TensorType(int64, shape=()), TensorType(bool, shape=(None, 1)), TensorType(float64, shape=(None, 1, 1)), TensorType(float64, shape=(None, 1)), TensorType(float64, shape=(None, 2, 2)), TensorType(float64, shape=(None, 2, 2)), TensorType(float64, shape=(None, 1, 2)), TensorType(float64, shape=(None, 1)), TensorType(float64, shape=(None, 1, 2)), TensorType(float64, shape=(None, 1)), TensorType(float64, shape=(None, 1, 2)), TensorType(float64, shape=(None, 2, 2)), TensorType(float64, shape=(None,)), TensorType(float64, shape=(None, 1)), TensorType(float64, shape=(None, 2, 2)), TensorType(float64, shape=(None, 2)), TensorType(float64, shape=(None, 2, 2))]
E   Inputs shapes: [(), (10, 1), (10, 1, 1), (10, 1), (10, 2, 2), (10, 2, 2), (10, 1, 2), (10, 1), (10, 1, 2), (10, 1), (10, 1, 2), (10, 2, 2), (10,), (10, 1), (10, 2, 2), (11, 2), (11, 2, 2)]
E   Inputs strides: [(), (1, 1), (-8, 8, 8), (8, 8), (-32, 8, 16), (-32, 8, 16), (-16, 8, 8), (8, 8), (-16, 8, 8), (-8, 8), (-16, 16, 8), (-32, 16, 8), (8,), (-8, 8), (-32, 16, 8), (16, 8), (32, 16, 8)]
E   Inputs values: [array(10), 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown']
E   Outputs clients: [[Subtensor{start:stop:step}(Scan{grad_of_forward_kalman_pass, while_loop=False, inplace=all}.0, ScalarFromTensor.0, ScalarFromTensor.0, -1)], [Subtensor{start:stop:step}(Scan{grad_of_forward_kalman_pass, while_loop=False, inplace=all}.1, ScalarFromTensor.0, ScalarFromTensor.0, -1)]]
E   
E   HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'.
E   HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

pytensor/scan/scan_perform.pyx:397: AttributeError

@Dekermanjian
Author

> For future reference, you need to make sure you check out my actual branch, from my fork of pymc-extras. Then you can push directly into it, and it will show up where you expect. If you use PyCharm, you can check out PR branches directly from the git sidebar, which is handy.

Ah okay, I checked out the PR from inside a fork I made of pymc-extras. I will make sure to do it the right way moving forward!

@jessegrabowski
Member

Actually you did it the right way; it's very unusual to push into someone else's PR. But it's all good!

@Dekermanjian
Author

Okay, @jessegrabowski! I figured out the JAX issue. It was unhappy because I was building the graph with JAX and then trying to sample the model using the default PyMC sampler. When I use the numpyro sampler, it runs without any issues.

Okay, so here is the weird thing. With JAX and numpyro, everything works, including the exogenous forecasts. However, if you build the graph in the default mode and sample with native PyMC, then the forecasts fail and return the assertion error about the first dimension of the time-varying matrix that I posted above.

I have been digging through the code trying to find where exactly this issue arises, but I am struggling to pinpoint the location. Do you have any hypotheses as to where it might originate?

@jessegrabowski
Member

jessegrabowski commented May 2, 2025

The assertion gets added here, in add_check_on_time_varying_shapes().

I'll copy what I was worried about with set_data from the other thread, because it's relevant here. We need to make sure we're computing the initial states of the forecast with the old data, then changing it:

I made this change then undid it because I thought it wasn't doing the right thing. I need to double-check, but the forecasting logic basically goes like this:
1. The original user graph -- using the training data -- is reconstructed in its entirety.
2. Using the original graph, we compute the value of the hidden states at the requested x0 for the forecasts. We do not want new scenario data here, otherwise the value of the requested x0 will be wrong.
3. Starting from this x0, we construct a new graph which iterates the statespace equations forward for the requested number of time steps. If there are exogenous regressors, this is where we do want them.

I think I had the impression that this change put the scenario data in step (2), but I looked at it a few weeks ago now and I forget.

So I think an important test is to make sure that t=0 of the forecast always matches the "data" in the provided hidden state. That will let us know if doing the set_data is somehow doing something wrong.

I bring this up because the assert is a runtime check. If the lengths of things don't agree, the computation is probably going wrong silently in JAX mode.
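The t=0 consistency check described above could be sketched in plain numpy with a made-up two-state system. Everything here (T_mat, B, the exogenous arrays) is hypothetical and only illustrates the three steps; it is not the library's API:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical system, for illustration only:
T_mat = np.array([[0.9, 0.1], [0.0, 0.8]])  # state transition matrix
B = np.array([[0.5, 0.0], [0.0, -0.2]])     # loads exogenous data into the states
x_train = rng.normal(size=(10, 2))          # training exogenous data
x_scenario = rng.normal(size=(5, 2))        # scenario exogenous data

# Steps (1)-(2): run the recursion on the *training* data to get x0.
x0 = np.zeros(2)
for t in range(len(x_train)):
    x0 = T_mat @ x0 + B @ x_train[t]

# Step (3): iterate forward from x0 using the *scenario* data.
s = x0.copy()
forecast = []
for t in range(len(x_scenario)):
    s = T_mat @ s + B @ x_scenario[t]
    forecast.append(s.copy())
forecast = np.array(forecast)

# The proposed check: the first forecast step must be reachable from x0
# with the first row of scenario data.  If set_data leaked scenario data
# into the x0 computation, this equality would break.
assert np.allclose(forecast[0], T_mat @ x0 + B @ x_scenario[0])
```

A unit test along these lines would distinguish "the shapes disagree" from "the forecast silently starts from the wrong state".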

Comment on lines +2072 to +2078

    for name in self.data_names:
        if name in scenario.keys():
            pm.set_data(
                {"data": np.zeros((len(forecast_index), self.k_endog))},
                coords={"data_time": np.arange(len(forecast_index))},
            )
            break
Author

@Dekermanjian commented May 2, 2025


I added this logic to update the static shape of the target variable when forecasting. I realize that this logic is naive in a few respects:
1. It makes an assumption about how the scenario data is constructed (I think I resolved that).
2. The timing of when this is called may be inappropriate.
3. There are probably other things that I am not thinking of right now.

With your suggestions I can make this more robust; I just wanted to confirm my suspicion: is the issue that the static shape of the target needs to be updated to reflect the shape of the forecast index?

EDIT:
Sorry about the multiple pings, I didn't highlight all of the code.

I should also mention that these changes do pass the unit tests in test_statespace.py.
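The shape bookkeeping that the commented snippet performs can be sketched in plain numpy. The sizes here (forecast_index, k_endog) are hypothetical placeholders, used only to show the invariant being restored:

```python
import numpy as np

# Hypothetical sizes, for illustration only: the observed-data placeholder
# is resized to the forecast horizon so that its time dimension matches the
# time-varying matrices built for the forecast.
k_endog = 1
forecast_index = np.arange(8)  # e.g. 8 forecast steps

placeholder = np.zeros((len(forecast_index), k_endog))
coords = {"data_time": np.arange(len(forecast_index))}

# This is the shape invariant the runtime assert checks: the data's time
# dimension must equal the forecast horizon.
assert placeholder.shape[0] == len(coords["data_time"]) == len(forecast_index)
```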
